fix: 修复 StartToGoal 命令无法通过 CTRL+C 中断的问题

- 为 AI 相关方法添加 context.Context 参数支持中断

- 在重试循环中添加上下文取消检查

- 创建可取消的上下文并监听中断信号

- 更新 MCP 工具调用使用带上下文的方法

现在用户可以通过 CTRL+C 正常中断长时间运行的 AI 自动化任务
This commit is contained in:
lilong.129
2025-06-05 19:57:31 +08:00
parent d883aa6a21
commit 5f400735fc
18 changed files with 89 additions and 162 deletions

View File

@@ -130,18 +130,21 @@ func GetModelConfig(modelType option.LLMServiceType) (*ModelConfig, error) {
func validateModelType(modelType option.LLMServiceType, modelName string) error {
switch modelType {
case option.DOUBAO_1_5_UI_TARS_250428:
if !strings.Contains(modelName, "ui-tars") {
if !strings.Contains(modelName, string(modelType)) {
return fmt.Errorf("model name %s is not supported for %s", modelName, modelType)
}
return nil
case option.DOUBAO_1_5_THINKING_VISION_PRO_250428:
if !strings.Contains(modelName, "doubao") || !strings.Contains(modelName, "vision") {
if !strings.Contains(modelName, string(modelType)) {
return fmt.Errorf("model name %s is not supported", modelName)
}
return nil
}
return fmt.Errorf("model type %s is not supported", modelType)
return fmt.Errorf("model type %s is not supported, supported types: %s, %s",
modelType,
option.DOUBAO_1_5_UI_TARS_250428,
option.DOUBAO_1_5_THINKING_VISION_PRO_250428)
}
// maskAPIKey masks the API key

View File

@@ -1,109 +0,0 @@
# HttpRunner UIXT Cache Test Suite Summary
## 概述
`httprunner/uixt/cache.go` 编写了全面的单元测试用例,覆盖了统一缓存系统的所有核心功能。
## 测试覆盖范围
### 1. GetOrCreateXTDriver 测试
- **TestGetOrCreateXTDriver_EmptySerial**: 测试空 serial 参数的错误处理
- **TestGetOrCreateXTDriver_WithUnifiedDeviceOptions**: 测试使用统一 DeviceOptions 创建驱动配置
- **TestGetOrCreateXTDriver_DifferentPlatformConfigs**: 测试不同平台Android、iOS、Harmony、Browser的配置
### 2. RegisterXTDriver 测试
- **TestRegisterXTDriver_EmptySerial**: 测试空 serial 参数的错误处理
- **TestRegisterXTDriver_NilDriver**: 测试 nil driver 参数的错误处理
- **TestRegisterXTDriver_Success**: 测试成功注册外部驱动
### 3. ReleaseXTDriver 测试
- **TestReleaseXTDriver_NonExistentSerial**: 测试释放不存在的驱动(应该不报错)
- **TestReleaseXTDriver_CleanupWhenZero**: 测试引用计数为 0 时的自动清理
### 4. 缓存管理测试
- **TestCleanupAllDrivers**: 测试清理所有缓存驱动
- **TestListCachedDrivers_Empty**: 测试空缓存的列表功能
- **TestListCachedDrivers_Multiple**: 测试多个驱动的列表功能
### 5. 配置测试
- **TestDriverCacheConfig_WithoutDeviceOpts**: 测试不使用 DeviceOpts 的配置
- **TestDriverCacheConfig_DefaultAIOptions**: 测试默认 AI 选项的配置
### 6. 并发测试
- **TestConcurrentAccess**: 测试并发访问缓存的安全性和正确性
### 7. 集成测试
- **TestIntegrationExample_BasicUsage**: 测试基本使用场景
- **TestIntegrationExample_TraditionalWay**: 测试传统方式(向后兼容)
- **TestIntegrationExample_MultipleDevices**: 测试多设备场景
### 8. DeviceOptions 集成测试
- **TestDeviceOptionsIntegration**: 测试统一 DeviceOptions 的平台自动检测功能
### 9. 引用计数管理测试
- **TestCacheReferenceCountManagement**: 测试引用计数的增减和资源管理
## 测试特点
### 1. 简化的测试方法
- 避免了复杂的 mock 实现
- 使用最小化的 `XTDriver{}` 实例进行测试
- 专注于缓存逻辑而非设备创建逻辑
### 2. 错误处理覆盖
- 测试了所有主要的错误场景
- 验证了空指针保护机制
- 确保了资源清理的安全性
### 3. 并发安全性
- 验证了 `sync.Map` 的并发访问安全性
- 测试了引用计数在并发环境下的正确性
### 4. 向后兼容性
- 验证了传统 API 的继续支持
- 测试了新旧方式的互操作性
## 修复的问题
### 1. 空指针保护
`CleanupAllDrivers``ReleaseXTDriver` 函数中添加了空指针检查:
```go
if cached.Driver != nil && cached.Driver.IDriver != nil {
if err := cached.Driver.DeleteSession(); err != nil {
// handle error
}
}
```
### 2. 并发测试逻辑
修正了并发测试的预期行为,从测试注册冲突改为测试缓存复用。
## 运行结果
所有 18 个测试用例全部通过:
- 基础功能测试:✅
- 错误处理测试:✅
- 并发安全测试:✅
- 集成场景测试:✅
- 引用计数管理:✅
## 测试命令
```bash
# 运行所有缓存相关测试
go test -v ./uixt -run "^Test.*Cache.*|^TestGetOrCreateXTDriver|^TestRegisterXTDriver|^TestReleaseXTDriver|^TestCleanupAllDrivers|^TestListCachedDrivers|^TestDriverCacheConfig|^TestConcurrentAccess|^TestIntegrationExample|^TestDeviceOptionsIntegration$"
# 运行特定测试
go test -v ./uixt -run TestConcurrentAccess
```
## 总结
这套测试用例全面覆盖了 HttpRunner UIXT 缓存系统的核心功能,确保了:
1. 缓存的正确性和一致性
2. 错误处理的健壮性
3. 并发访问的安全性
4. 资源管理的可靠性
5. API 的向后兼容性
测试设计简洁高效,避免了复杂的 mock 依赖,专注于验证缓存逻辑本身。

View File

@@ -18,13 +18,23 @@ import (
"github.com/rs/zerolog/log"
)
func (dExt *XTDriver) StartToGoal(text string, opts ...option.ActionOption) error {
func (dExt *XTDriver) StartToGoal(ctx context.Context, text string, opts ...option.ActionOption) error {
options := option.NewActionOptions(opts...)
log.Info().Int("max_retry_times", options.MaxRetryTimes).Msg("StartToGoal")
var attempt int
for {
attempt++
log.Info().Int("attempt", attempt).Msg("planning attempt")
if err := dExt.AIAction(text, opts...); err != nil {
// Check for context cancellation (interrupt signal)
select {
case <-ctx.Done():
log.Warn().Msg("interrupted in StartToGoal")
return errors.Wrap(code.InterruptError, "StartToGoal interrupted")
default:
}
if err := dExt.AIAction(ctx, text, opts...); err != nil {
// Check if this is a LLM service request error that should be retried
if errors.Is(err, code.LLMRequestServiceError) {
log.Warn().Err(err).Int("attempt", attempt).
@@ -40,15 +50,23 @@ func (dExt *XTDriver) StartToGoal(text string, opts ...option.ActionOption) erro
}
}
func (dExt *XTDriver) AIAction(text string, opts ...option.ActionOption) error {
func (dExt *XTDriver) AIAction(ctx context.Context, text string, opts ...option.ActionOption) error {
// plan next action
result, err := dExt.PlanNextAction(text, opts...)
result, err := dExt.PlanNextAction(ctx, text, opts...)
if err != nil {
return err
}
// do actions
for _, action := range result.ToolCalls {
// Check for context cancellation before each action
select {
case <-ctx.Done():
log.Warn().Msg("interrupted in AIAction")
return errors.Wrap(code.InterruptError, "AIAction interrupted")
default:
}
// call eino tool
arguments := make(map[string]interface{})
err := json.Unmarshal([]byte(action.Function.Arguments), &arguments)
@@ -68,7 +86,7 @@ func (dExt *XTDriver) AIAction(text string, opts ...option.ActionOption) error {
},
}
_, err = dExt.client.CallTool(context.Background(), req)
_, err = dExt.client.CallTool(ctx, req)
if err != nil {
return err
}
@@ -77,7 +95,7 @@ func (dExt *XTDriver) AIAction(text string, opts ...option.ActionOption) error {
return nil
}
func (dExt *XTDriver) PlanNextAction(text string, opts ...option.ActionOption) (*ai.PlanningResult, error) {
func (dExt *XTDriver) PlanNextAction(ctx context.Context, text string, opts ...option.ActionOption) (*ai.PlanningResult, error) {
if dExt.LLMService == nil {
return nil, errors.New("LLM service is not initialized")
}
@@ -124,7 +142,7 @@ func (dExt *XTDriver) PlanNextAction(text string, opts ...option.ActionOption) (
Size: size,
}
result, err := dExt.LLMService.Call(context.Background(), planningOpts)
result, err := dExt.LLMService.Call(ctx, planningOpts)
if err != nil {
return nil, errors.Wrap(err, "failed to get next action from planner")
}

View File

@@ -4,6 +4,7 @@ package uixt
import (
"bytes"
"context"
"image"
"os"
"testing"
@@ -130,7 +131,7 @@ func TestDriverExt_TapByOCR(t *testing.T) {
func TestDriverExt_TapByLLM(t *testing.T) {
driver := setupDriverExt(t)
err := driver.AIAction("点击第一个帖子的作者头像")
err := driver.AIAction(context.Background(), "点击第一个帖子的作者头像")
assert.Nil(t, err)
err = driver.AIAssert("当前在个人介绍页")
@@ -161,13 +162,13 @@ func TestDriverExt_StartToGoal(t *testing.T) {
userInstruction += "\n\n请严格按照以上游戏规则开始游戏注意请只做点击操作"
err := driver.StartToGoal(userInstruction)
err := driver.StartToGoal(context.Background(), userInstruction)
assert.Nil(t, err)
}
func TestDriverExt_PlanNextAction(t *testing.T) {
driver := setupDriverExt(t)
result, err := driver.PlanNextAction("启动抖音")
result, err := driver.PlanNextAction(context.Background(), "启动抖音")
assert.Nil(t, err)
t.Log(result)
}

View File

@@ -40,7 +40,7 @@ func (t *ToolStartToGoal) Implement() server.ToolHandlerFunc {
// Start to goal logic
log.Info().Str("prompt", unifiedReq.Prompt).Msg("starting to goal")
err = driverExt.StartToGoal(unifiedReq.Prompt)
err = driverExt.StartToGoal(ctx, unifiedReq.Prompt)
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("Failed to achieve goal: %s", err.Error())), nil
}
@@ -99,7 +99,7 @@ func (t *ToolAIAction) Implement() server.ToolHandlerFunc {
// AI action logic
log.Info().Str("prompt", unifiedReq.Prompt).Msg("performing AI action")
err = driverExt.AIAction(unifiedReq.Prompt)
err = driverExt.AIAction(ctx, unifiedReq.Prompt)
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("AI action failed: %s", err.Error())), nil
}

View File

@@ -53,8 +53,6 @@ func (t *ToolTapXY) Implement() server.ToolHandlerFunc {
}
// Tap action logic
log.Info().Float64("x", unifiedReq.X).Float64("y", unifiedReq.Y).Msg("tapping at coordinates")
err = driverExt.TapXY(unifiedReq.X, unifiedReq.Y, opts...)
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("Tap failed: %s", err.Error())), nil
@@ -354,7 +352,6 @@ func (t *ToolDoubleTapXY) Implement() server.ToolHandlerFunc {
}
// Double tap XY action logic
log.Info().Float64("x", unifiedReq.X).Float64("y", unifiedReq.Y).Msg("double tapping at coordinates")
err = driverExt.DoubleTap(unifiedReq.X, unifiedReq.Y)
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("Double tap failed: %s", err.Error())), nil

View File

@@ -105,6 +105,8 @@ const (
// AI actions
ACTION_StartToGoal ActionName = "start_to_goal" // start to goal action
ACTION_AIAction ActionName = "ai_action" // action with ai
ACTION_AIAssert ActionName = "ai_assert" // assert with ai
ACTION_Query ActionName = "ai_query" // query with ai
ACTION_Finished ActionName = "finished" // finished action
// anti-risk actions

View File

@@ -9,6 +9,7 @@ import (
"github.com/httprunner/httprunner/v5/uixt/option"
"github.com/mark3labs/mcp-go/client"
"github.com/mark3labs/mcp-go/mcp"
"github.com/pkg/errors"
"github.com/rs/zerolog/log"
)
@@ -34,8 +35,7 @@ func NewXTDriver(driver IDriver, opts ...option.AIServiceOption) (*XTDriver, err
if services.LLMService != "" {
driverExt.LLMService, err = ai.NewLLMService(services.LLMService)
if err != nil {
log.Error().Err(err).Msg("init llm service failed")
return nil, err
return nil, errors.Wrap(err, "init llm service failed")
}
}