fix: 修复 StartToGoal 命令无法通过 CTRL+C 中断的问题

- 为 AI 相关方法添加 context.Context 参数支持中断

- 在重试循环中添加上下文取消检查

- 创建可取消的上下文并监听中断信号

- 更新 MCP 工具调用使用带上下文的方法

现在用户可以通过 CTRL+C 正常中断长时间运行的 AI 自动化任务
This commit is contained in:
lilong.129
2025-06-05 19:57:31 +08:00
parent d883aa6a21
commit 5f400735fc
18 changed files with 89 additions and 162 deletions

View File

@@ -1 +1 @@
v5.0.0-beta-2506051809
v5.0.0-beta-2506052000

View File

@@ -53,7 +53,7 @@ func runStepFunction(r *SessionRunner, step IStep) (stepResult *StepResult, err
start := time.Now()
stepResult = &StepResult{
Name: step.Name(),
StepType: StepTypeFunction,
StepType: step.Type(),
Success: false,
ContentSize: 0,
StartTime: start.Unix(),

View File

@@ -42,8 +42,8 @@ func (s *StepRendezvous) Run(r *SessionRunner) (*StepResult, error) {
Msg("rendezvous")
stepResult := &StepResult{
Name: rendezvous.Name,
StepType: StepTypeRendezvous,
Name: s.Name(),
StepType: s.Type(),
Success: true,
}

View File

@@ -282,8 +282,8 @@ func runStepRequest(r *SessionRunner, step IStep) (stepResult *StepResult, err e
stepRequest := step.(*StepRequestWithOptionalArgs)
start := time.Now()
stepResult = &StepResult{
Name: stepRequest.StepName,
StepType: StepTypeRequest,
Name: step.Name(),
StepType: step.Type(),
Success: false,
ContentSize: 0,
StartTime: start.Unix(),
@@ -925,7 +925,7 @@ func (s *StepRequestWithOptionalArgs) Name() string {
}
func (s *StepRequestWithOptionalArgs) Type() StepType {
return StepType(fmt.Sprintf("request-%v", s.Request.Method))
return StepType(fmt.Sprintf("%s-%v", StepTypeRequest, s.Request.Method))
}
func (s *StepRequestWithOptionalArgs) Config() *StepConfig {
@@ -959,7 +959,7 @@ func (s *StepRequestExtraction) Name() string {
}
func (s *StepRequestExtraction) Type() StepType {
stepType := StepType(fmt.Sprintf("request-%v", s.Request.Method))
stepType := StepType(fmt.Sprintf("%s-%v", StepTypeRequest, s.Request.Method))
return stepType + stepTypeSuffixExtraction
}
@@ -987,7 +987,7 @@ func (s *StepRequestValidation) Name() string {
}
func (s *StepRequestValidation) Type() StepType {
stepType := StepType(fmt.Sprintf("request-%v", s.Request.Method))
stepType := StepType(fmt.Sprintf("%s-%v", StepTypeRequest, s.Request.Method))
return stepType + stepTypeSuffixValidation
}

View File

@@ -91,14 +91,14 @@ func runStepShell(r *SessionRunner, step IStep) (stepResult *StepResult, err err
log.Info().
Str("name", step.Name()).
Str("type", string(StepTypeShell)).
Str("type", string(step.Type())).
Str("content", shell.String).
Msg("run shell string")
start := time.Now()
stepResult = &StepResult{
Name: step.Name(),
StepType: StepTypeShell,
StepType: step.Type(),
Success: false,
ContentSize: 0,
StartTime: start.Unix(),

View File

@@ -48,8 +48,8 @@ func (s *StepTestCaseWithOptionalArgs) Config() *StepConfig {
func (s *StepTestCaseWithOptionalArgs) Run(r *SessionRunner) (stepResult *StepResult, err error) {
start := time.Now()
stepResult = &StepResult{
Name: s.StepName,
StepType: StepTypeTestCase,
Name: s.Name(),
StepType: s.Type(),
Success: false,
StartTime: start.Unix(),
}

View File

@@ -36,8 +36,8 @@ func (s *StepThinkTime) Run(r *SessionRunner) (*StepResult, error) {
log.Info().Float64("time", thinkTime.Time).Msg("think time")
stepResult := &StepResult{
Name: s.StepName,
StepType: StepTypeThinkTime,
Name: s.Name(),
StepType: s.Type(),
Success: true,
}

View File

@@ -48,8 +48,8 @@ func (s *StepTransaction) Run(r *SessionRunner) (*StepResult, error) {
Msg("transaction")
stepResult := &StepResult{
Name: transaction.Name,
StepType: StepTypeTransaction,
Name: s.Name(),
StepType: s.Type(),
Success: true,
Elapsed: 0,
ContentSize: 0, // TODO: record transaction total response length

View File

@@ -690,6 +690,15 @@ func (s *StepMobileUIValidation) Run(r *SessionRunner) (*StepResult, error) {
}
func runStepMobileUI(s *SessionRunner, step IStep) (stepResult *StepResult, err error) {
start := time.Now()
stepResult = &StepResult{
Name: step.Name(),
StepType: step.Type(),
Success: false,
ContentSize: 0,
StartTime: start.Unix(),
}
var stepVariables map[string]interface{}
var stepValidators []interface{}
var ignorePopup bool
@@ -706,7 +715,7 @@ func runStepMobileUI(s *SessionRunner, step IStep) (stepResult *StepResult, err
stepValidators = stepMobile.Validators
ignorePopup = stepMobile.StepMobile.IgnorePopup
default:
return nil, errors.New("invalid mobile UI step type")
return stepResult, errors.New("invalid mobile UI step type")
}
// report GA event
@@ -744,7 +753,7 @@ func runStepMobileUI(s *SessionRunner, step IStep) (stepResult *StepResult, err
uiDriver, err := uixt.GetOrCreateXTDriver(config)
if err != nil {
return nil, err
return stepResult, err
}
identifier := mobileStep.Identifier
@@ -759,16 +768,7 @@ func runStepMobileUI(s *SessionRunner, step IStep) (stepResult *StepResult, err
}
}
}
start := time.Now()
stepResult = &StepResult{
Name: step.Name(),
Identifier: identifier,
StepType: step.Type(),
Success: false,
ContentSize: 0,
StartTime: start.Unix(),
}
stepResult.Identifier = identifier
defer func() {
attachments := uixt.Attachments{}
@@ -859,7 +859,8 @@ func runStepMobileUI(s *SessionRunner, step IStep) (stepResult *StepResult, err
}
// Apply global LLM service configuration for AI actions
if action.Method == option.ACTION_AIAction || action.Method == option.ACTION_StartToGoal {
if action.Method == option.ACTION_AIAction || action.Method == option.ACTION_StartToGoal ||
action.Method == option.ACTION_AIAssert || action.Method == option.ACTION_Query {
if config.LLMService != "" && action.Options.LLMService == "" {
action.Options.LLMService = string(config.LLMService)
log.Debug().Str("action", string(action.Method)).Str("llmService", action.Options.LLMService).Msg("Applied global LLM service config to action")
@@ -891,8 +892,22 @@ func runStepMobileUI(s *SessionRunner, step IStep) (stepResult *StepResult, err
continue
}
// call MCP tool to execute action
err = uiDriver.ExecuteAction(context.Background(), action)
// call MCP tool to execute action with cancellable context
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// Create a goroutine to monitor for interrupt signals
go func() {
select {
case <-s.caseRunner.hrpRunner.interruptSignal:
log.Warn().Msg("cancelling action due to interrupt signal")
cancel()
case <-ctx.Done():
// Context already cancelled
}
}()
err = uiDriver.ExecuteAction(ctx, action)
actionResult.Elapsed = time.Since(actionStartTime).Milliseconds()
stepResult.Actions = append(stepResult.Actions, actionResult)
if err != nil {

View File

@@ -378,7 +378,7 @@ func runStepWebSocket(r *SessionRunner, step IStep) (stepResult *StepResult, err
start := time.Now()
stepResult = &StepResult{
Name: step.Name(),
StepType: StepTypeWebSocket,
StepType: step.Type(),
Success: false,
ContentSize: 0,
StartTime: start.Unix(),

View File

@@ -130,18 +130,21 @@ func GetModelConfig(modelType option.LLMServiceType) (*ModelConfig, error) {
func validateModelType(modelType option.LLMServiceType, modelName string) error {
switch modelType {
case option.DOUBAO_1_5_UI_TARS_250428:
if !strings.Contains(modelName, "ui-tars") {
if !strings.Contains(modelName, string(modelType)) {
return fmt.Errorf("model name %s is not supported for %s", modelName, modelType)
}
return nil
case option.DOUBAO_1_5_THINKING_VISION_PRO_250428:
if !strings.Contains(modelName, "doubao") || !strings.Contains(modelName, "vision") {
if !strings.Contains(modelName, string(modelType)) {
return fmt.Errorf("model name %s is not supported", modelName)
}
return nil
}
return fmt.Errorf("model type %s is not supported", modelType)
return fmt.Errorf("model type %s is not supported, supported types: %s, %s",
modelType,
option.DOUBAO_1_5_UI_TARS_250428,
option.DOUBAO_1_5_THINKING_VISION_PRO_250428)
}
// maskAPIKey masks the API key

View File

@@ -1,109 +0,0 @@
# HttpRunner UIXT Cache Test Suite Summary
## 概述
为 `httprunner/uixt/cache.go` 编写了全面的单元测试用例,覆盖了统一缓存系统的所有核心功能。
## 测试覆盖范围
### 1. GetOrCreateXTDriver 测试
- **TestGetOrCreateXTDriver_EmptySerial**: 测试空 serial 参数的错误处理
- **TestGetOrCreateXTDriver_WithUnifiedDeviceOptions**: 测试使用统一 DeviceOptions 创建驱动配置
- **TestGetOrCreateXTDriver_DifferentPlatformConfigs**: 测试不同平台(Android、iOS、Harmony、Browser)的配置
### 2. RegisterXTDriver 测试
- **TestRegisterXTDriver_EmptySerial**: 测试空 serial 参数的错误处理
- **TestRegisterXTDriver_NilDriver**: 测试 nil driver 参数的错误处理
- **TestRegisterXTDriver_Success**: 测试成功注册外部驱动
### 3. ReleaseXTDriver 测试
- **TestReleaseXTDriver_NonExistentSerial**: 测试释放不存在的驱动(应该不报错)
- **TestReleaseXTDriver_CleanupWhenZero**: 测试引用计数为 0 时的自动清理
### 4. 缓存管理测试
- **TestCleanupAllDrivers**: 测试清理所有缓存驱动
- **TestListCachedDrivers_Empty**: 测试空缓存的列表功能
- **TestListCachedDrivers_Multiple**: 测试多个驱动的列表功能
### 5. 配置测试
- **TestDriverCacheConfig_WithoutDeviceOpts**: 测试不使用 DeviceOpts 的配置
- **TestDriverCacheConfig_DefaultAIOptions**: 测试默认 AI 选项的配置
### 6. 并发测试
- **TestConcurrentAccess**: 测试并发访问缓存的安全性和正确性
### 7. 集成测试
- **TestIntegrationExample_BasicUsage**: 测试基本使用场景
- **TestIntegrationExample_TraditionalWay**: 测试传统方式(向后兼容)
- **TestIntegrationExample_MultipleDevices**: 测试多设备场景
### 8. DeviceOptions 集成测试
- **TestDeviceOptionsIntegration**: 测试统一 DeviceOptions 的平台自动检测功能
### 9. 引用计数管理测试
- **TestCacheReferenceCountManagement**: 测试引用计数的增减和资源管理
## 测试特点
### 1. 简化的测试方法
- 避免了复杂的 mock 实现
- 使用最小化的 `XTDriver{}` 实例进行测试
- 专注于缓存逻辑而非设备创建逻辑
### 2. 错误处理覆盖
- 测试了所有主要的错误场景
- 验证了空指针保护机制
- 确保了资源清理的安全性
### 3. 并发安全性
- 验证了 `sync.Map` 的并发访问安全性
- 测试了引用计数在并发环境下的正确性
### 4. 向后兼容性
- 验证了传统 API 的继续支持
- 测试了新旧方式的互操作性
## 修复的问题
### 1. 空指针保护
在 `CleanupAllDrivers` 和 `ReleaseXTDriver` 函数中添加了空指针检查:
```go
if cached.Driver != nil && cached.Driver.IDriver != nil {
if err := cached.Driver.DeleteSession(); err != nil {
// handle error
}
}
```
### 2. 并发测试逻辑
修正了并发测试的预期行为,从测试注册冲突改为测试缓存复用。
## 运行结果
所有 18 个测试用例全部通过:
- 基础功能测试:✅
- 错误处理测试:✅
- 并发安全测试:✅
- 集成场景测试:✅
- 引用计数管理:✅
## 测试命令
```bash
# 运行所有缓存相关测试
go test -v ./uixt -run "^Test.*Cache.*|^TestGetOrCreateXTDriver|^TestRegisterXTDriver|^TestReleaseXTDriver|^TestCleanupAllDrivers|^TestListCachedDrivers|^TestDriverCacheConfig|^TestConcurrentAccess|^TestIntegrationExample|^TestDeviceOptionsIntegration$"
# 运行特定测试
go test -v ./uixt -run TestConcurrentAccess
```
## 总结
这套测试用例全面覆盖了 HttpRunner UIXT 缓存系统的核心功能,确保了:
1. 缓存的正确性和一致性
2. 错误处理的健壮性
3. 并发访问的安全性
4. 资源管理的可靠性
5. API 的向后兼容性
测试设计简洁高效,避免了复杂的 mock 依赖,专注于验证缓存逻辑本身。

View File

@@ -18,13 +18,23 @@ import (
"github.com/rs/zerolog/log"
)
func (dExt *XTDriver) StartToGoal(text string, opts ...option.ActionOption) error {
func (dExt *XTDriver) StartToGoal(ctx context.Context, text string, opts ...option.ActionOption) error {
options := option.NewActionOptions(opts...)
log.Info().Int("max_retry_times", options.MaxRetryTimes).Msg("StartToGoal")
var attempt int
for {
attempt++
log.Info().Int("attempt", attempt).Msg("planning attempt")
if err := dExt.AIAction(text, opts...); err != nil {
// Check for context cancellation (interrupt signal)
select {
case <-ctx.Done():
log.Warn().Msg("interrupted in StartToGoal")
return errors.Wrap(code.InterruptError, "StartToGoal interrupted")
default:
}
if err := dExt.AIAction(ctx, text, opts...); err != nil {
// Check if this is a LLM service request error that should be retried
if errors.Is(err, code.LLMRequestServiceError) {
log.Warn().Err(err).Int("attempt", attempt).
@@ -40,15 +50,23 @@ func (dExt *XTDriver) StartToGoal(text string, opts ...option.ActionOption) erro
}
}
func (dExt *XTDriver) AIAction(text string, opts ...option.ActionOption) error {
func (dExt *XTDriver) AIAction(ctx context.Context, text string, opts ...option.ActionOption) error {
// plan next action
result, err := dExt.PlanNextAction(text, opts...)
result, err := dExt.PlanNextAction(ctx, text, opts...)
if err != nil {
return err
}
// do actions
for _, action := range result.ToolCalls {
// Check for context cancellation before each action
select {
case <-ctx.Done():
log.Warn().Msg("interrupted in AIAction")
return errors.Wrap(code.InterruptError, "AIAction interrupted")
default:
}
// call eino tool
arguments := make(map[string]interface{})
err := json.Unmarshal([]byte(action.Function.Arguments), &arguments)
@@ -68,7 +86,7 @@ func (dExt *XTDriver) AIAction(text string, opts ...option.ActionOption) error {
},
}
_, err = dExt.client.CallTool(context.Background(), req)
_, err = dExt.client.CallTool(ctx, req)
if err != nil {
return err
}
@@ -77,7 +95,7 @@ func (dExt *XTDriver) AIAction(text string, opts ...option.ActionOption) error {
return nil
}
func (dExt *XTDriver) PlanNextAction(text string, opts ...option.ActionOption) (*ai.PlanningResult, error) {
func (dExt *XTDriver) PlanNextAction(ctx context.Context, text string, opts ...option.ActionOption) (*ai.PlanningResult, error) {
if dExt.LLMService == nil {
return nil, errors.New("LLM service is not initialized")
}
@@ -124,7 +142,7 @@ func (dExt *XTDriver) PlanNextAction(text string, opts ...option.ActionOption) (
Size: size,
}
result, err := dExt.LLMService.Call(context.Background(), planningOpts)
result, err := dExt.LLMService.Call(ctx, planningOpts)
if err != nil {
return nil, errors.Wrap(err, "failed to get next action from planner")
}

View File

@@ -4,6 +4,7 @@ package uixt
import (
"bytes"
"context"
"image"
"os"
"testing"
@@ -130,7 +131,7 @@ func TestDriverExt_TapByOCR(t *testing.T) {
func TestDriverExt_TapByLLM(t *testing.T) {
driver := setupDriverExt(t)
err := driver.AIAction("点击第一个帖子的作者头像")
err := driver.AIAction(context.Background(), "点击第一个帖子的作者头像")
assert.Nil(t, err)
err = driver.AIAssert("当前在个人介绍页")
@@ -161,13 +162,13 @@ func TestDriverExt_StartToGoal(t *testing.T) {
userInstruction += "\n\n请严格按照以上游戏规则开始游戏注意请只做点击操作"
err := driver.StartToGoal(userInstruction)
err := driver.StartToGoal(context.Background(), userInstruction)
assert.Nil(t, err)
}
func TestDriverExt_PlanNextAction(t *testing.T) {
driver := setupDriverExt(t)
result, err := driver.PlanNextAction("启动抖音")
result, err := driver.PlanNextAction(context.Background(), "启动抖音")
assert.Nil(t, err)
t.Log(result)
}

View File

@@ -40,7 +40,7 @@ func (t *ToolStartToGoal) Implement() server.ToolHandlerFunc {
// Start to goal logic
log.Info().Str("prompt", unifiedReq.Prompt).Msg("starting to goal")
err = driverExt.StartToGoal(unifiedReq.Prompt)
err = driverExt.StartToGoal(ctx, unifiedReq.Prompt)
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("Failed to achieve goal: %s", err.Error())), nil
}
@@ -99,7 +99,7 @@ func (t *ToolAIAction) Implement() server.ToolHandlerFunc {
// AI action logic
log.Info().Str("prompt", unifiedReq.Prompt).Msg("performing AI action")
err = driverExt.AIAction(unifiedReq.Prompt)
err = driverExt.AIAction(ctx, unifiedReq.Prompt)
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("AI action failed: %s", err.Error())), nil
}

View File

@@ -53,8 +53,6 @@ func (t *ToolTapXY) Implement() server.ToolHandlerFunc {
}
// Tap action logic
log.Info().Float64("x", unifiedReq.X).Float64("y", unifiedReq.Y).Msg("tapping at coordinates")
err = driverExt.TapXY(unifiedReq.X, unifiedReq.Y, opts...)
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("Tap failed: %s", err.Error())), nil
@@ -354,7 +352,6 @@ func (t *ToolDoubleTapXY) Implement() server.ToolHandlerFunc {
}
// Double tap XY action logic
log.Info().Float64("x", unifiedReq.X).Float64("y", unifiedReq.Y).Msg("double tapping at coordinates")
err = driverExt.DoubleTap(unifiedReq.X, unifiedReq.Y)
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("Double tap failed: %s", err.Error())), nil

View File

@@ -105,6 +105,8 @@ const (
// AI actions
ACTION_StartToGoal ActionName = "start_to_goal" // start to goal action
ACTION_AIAction ActionName = "ai_action" // action with ai
ACTION_AIAssert ActionName = "ai_assert" // assert with ai
ACTION_Query ActionName = "ai_query" // query with ai
ACTION_Finished ActionName = "finished" // finished action
// anti-risk actions

View File

@@ -9,6 +9,7 @@ import (
"github.com/httprunner/httprunner/v5/uixt/option"
"github.com/mark3labs/mcp-go/client"
"github.com/mark3labs/mcp-go/mcp"
"github.com/pkg/errors"
"github.com/rs/zerolog/log"
)
@@ -34,8 +35,7 @@ func NewXTDriver(driver IDriver, opts ...option.AIServiceOption) (*XTDriver, err
if services.LLMService != "" {
driverExt.LLMService, err = ai.NewLLMService(services.LLMService)
if err != nil {
log.Error().Err(err).Msg("init llm service failed")
return nil, err
return nil, errors.Wrap(err, "init llm service failed")
}
}