diff --git a/uixt/ai/cv_vedem.go b/uixt/ai/cv_vedem.go index 2cb4a833..e26068c5 100644 --- a/uixt/ai/cv_vedem.go +++ b/uixt/ai/cv_vedem.go @@ -49,7 +49,8 @@ func NewVEDEMImageService() (*vedemCVService, error) { type vedemCVService struct{} func (s *vedemCVService) ReadFromPath(imagePath string, opts ...option.ActionOption) ( - imageResult *CVResult, err error) { + imageResult *CVResult, err error, +) { imageBuf, err := os.ReadFile(imagePath) if err != nil { err = errors.Wrap(code.CVPrepareRequestError, @@ -61,7 +62,8 @@ func (s *vedemCVService) ReadFromPath(imagePath string, opts ...option.ActionOpt } func (s *vedemCVService) ReadFromBuffer(imageBuf *bytes.Buffer, opts ...option.ActionOption) ( - imageResult *CVResult, err error) { + imageResult *CVResult, err error, +) { actionOptions := option.NewActionOptions(opts...) log.Debug().Interface("options", actionOptions).Msg("vedem.ReadFromBuffer") screenshotActions := actionOptions.List() @@ -77,7 +79,7 @@ func (s *vedemCVService) ReadFromBuffer(imageBuf *bytes.Buffer, opts ...option.A if err != nil { logger = log.Error().Err(err) } else { - logger = log.Debug() + logger = log.Info() if imageResult.URL != "" { logger = logger.Str("url", imageResult.URL) } diff --git a/uixt/ai/wings_service.go b/uixt/ai/wings_service.go new file mode 100644 index 00000000..5457457d --- /dev/null +++ b/uixt/ai/wings_service.go @@ -0,0 +1,542 @@ +package ai + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + "time" + + "github.com/cloudwego/eino/schema" + "github.com/google/uuid" + "github.com/pkg/errors" + "github.com/rs/zerolog/log" +) + +// WingsService implements ILLMService interface using external Wings API +type WingsService struct { + apiURL string + bizId string +} + +// NewWingsService creates a new Wings service instance +func NewWingsService() ILLMService { + return &WingsService{ + apiURL: "https://vedem-algorithm.bytedance.net/algorithm/StepActionDecision", + bizId: 
"489fdae44de048e0922a32834ea668af", + } +} + +// Plan implements the ILLMService.Plan method using Wings API +func (w *WingsService) Plan(ctx context.Context, opts *PlanningOptions) (*PlanningResult, error) { + // Validate input parameters + if err := validatePlanningInput(opts); err != nil { + return nil, errors.Wrap(err, "validate planning parameters failed") + } + + // Extract screenshot from message + screenshot, err := w.extractScreenshotFromMessage(opts.Message) + if err != nil { + return nil, errors.Wrap(err, "extract screenshot failed") + } + + // Get device info from context (if available) + deviceInfo := w.getDeviceInfoFromContext(ctx, screenshot) + + // Prepare Wings API request + apiRequest := WingsActionRequest{ + Historys: []interface{}{}, // empty as specified + DeviceInfos: []WingsDeviceInfo{ + deviceInfo, + }, + StepText: opts.UserInstruction, + BizId: w.bizId, + TextCase: "整体描述:\\n前置条件:\\n获取 1 台设备 A。\\n获取 1 个[万粉创作者]账号a。\\n获取 2 个[普通]账号 b、c。\\n账号 a 和账号 b 互相关注。\\n账号 a 和账号 c 互相关注。\\n账号 a 给账号 b 设置备注为 “11131b”。\\n账号 a 给账号 c 设置备注为 “11131c”。\\n账号 a 创建一个粉丝群 m。\\n 账号 a 修改粉丝群 m 名称为“11131群”。\\n 账号 a 邀请账号 b 加入粉丝群 m。\\n账号 a 邀请账号 c 加入粉丝群 m。\\n账号 a 给群聊 m 发送一条文字消息。\\n设备 A 打开抖音 app。\\n设备 A 登录账号 a。\\n设备 A 退出抖音 app。\\n操作步骤:\\n账号a打开抖音app。\\n点击“消息”。\\n点击“11131群”cell。\\n点击“聊天信息页入口”按钮。\\n点击“分享公开群”按钮。\\n点击文字“群口令”。\\n断言:屏幕中存在文字“口令复制成功”。\\n停止操作。\\n注意事项:\\n", + StepType: "automation", + DeviceID: deviceInfo.DeviceID, + Base: WingsBase{ + LogID: generateWingsUUID(), + }, + } + + // Call Wings API + startTime := time.Now() + response, err := w.callWingsAPI(ctx, apiRequest) + elapsed := time.Since(startTime).Milliseconds() + + if err != nil { + return &PlanningResult{ + Thought: "Wings API call failed", + Error: err.Error(), + ModelName: "wings-api", + }, errors.Wrap(err, "Wings API call failed") + } + + // Check API response status + if response.BaseResp.StatusCode != 0 { + err = fmt.Errorf("API returned error: %s", response.BaseResp.StatusMessage) + return &PlanningResult{ + 
Thought: response.ThoughtChain.Thought, + Error: err.Error(), + ModelName: "wings-api", + }, err + } + + // Convert Wings API response to tool calls + toolCalls, err := w.convertWingsResponseToToolCalls(response.ActionParams) + if err != nil { + return &PlanningResult{ + Thought: response.ThoughtChain.Thought, + Error: err.Error(), + ModelName: "wings-api", + }, errors.Wrap(err, "convert Wings response to tool calls failed") + } + + log.Info(). + Str("thought", response.ThoughtChain.Thought). + Int("tool_calls_count", len(toolCalls)). + Int64("elapsed_ms", elapsed). + Msg("Wings API planning completed") + + return &PlanningResult{ + ToolCalls: toolCalls, + Thought: response.ThoughtChain.Thought, + Content: response.ThoughtChain.Summary, + ModelName: "wings-api", + }, nil +} + +// Assert implements the ILLMService.Assert method using Wings API +func (w *WingsService) Assert(ctx context.Context, opts *AssertOptions) (*AssertionResult, error) { + // Validate input parameters + if err := validateAssertionInput(opts); err != nil { + return nil, errors.Wrap(err, "validate assertion parameters failed") + } + + // Clean screenshot data URL prefix + cleanScreenshot := w.cleanScreenshotDataURL(opts.Screenshot) + + // Get device info from context (if available) + deviceInfo := w.getDeviceInfoFromScreenshot(ctx, cleanScreenshot) + + // Prepare Wings API request for assertion + apiRequest := WingsActionRequest{ + Historys: []interface{}{}, // empty as specified + DeviceInfos: []WingsDeviceInfo{ + deviceInfo, + }, + StepText: opts.Assertion, + BizId: w.bizId, + TextCase: "整体描述:\\n前置条件:\\n获取 1 台设备 A。\\n获取 1 个[万粉创作者]账号a。\\n获取 2 个[普通]账号 b、c。\\n账号 a 和账号 b 互相关注。\\n账号 a 和账号 c 互相关注。\\n账号 a 给账号 b 设置备注为 “11131b”。\\n账号 a 给账号 c 设置备注为 “11131c”。\\n账号 a 创建一个粉丝群 m。\\n 账号 a 修改粉丝群 m 名称为“11131群”。\\n 账号 a 邀请账号 b 加入粉丝群 m。\\n账号 a 邀请账号 c 加入粉丝群 m。\\n账号 a 给群聊 m 发送一条文字消息。\\n设备 A 打开抖音 app。\\n设备 A 登录账号 a。\\n设备 A 退出抖音 
app。\\n操作步骤:\\n账号a打开抖音app。\\n点击“消息”。\\n点击“11131群”cell。\\n点击“聊天信息页入口”按钮。\\n点击“分享公开群”按钮。\\n点击文字“群口令”。\\n断言:屏幕中存在文字“口令复制成功”。\\n停止操作。\\n注意事项:\\n", + StepType: "assert", // Different from automation + DeviceID: deviceInfo.DeviceID, + Base: WingsBase{ + LogID: generateWingsUUID(), + }, + } + + // Call Wings API + startTime := time.Now() + response, err := w.callWingsAPI(ctx, apiRequest) + elapsed := time.Since(startTime).Milliseconds() + + if err != nil { + return &AssertionResult{ + Pass: false, + Thought: "Wings API call failed", + ModelName: "wings-api", + }, errors.Wrap(err, "Wings API call failed") + } + + // Check API response status + if response.BaseResp.StatusCode != 0 { + err = fmt.Errorf("API returned error: %s", response.BaseResp.StatusMessage) + return &AssertionResult{ + Pass: false, + Thought: response.ThoughtChain.Thought, + ModelName: "wings-api", + }, err + } + + // Parse assertion result from action_params + passed, assertionThought, err := w.parseAssertionResult(response.ActionParams, response.ThoughtChain) + if err != nil { + return &AssertionResult{ + Pass: false, + Thought: response.ThoughtChain.Thought, + ModelName: "wings-api", + }, errors.Wrap(err, "parse assertion result failed") + } + + log.Info(). + Bool("passed", passed). + Str("thought", assertionThought). + Int64("elapsed_ms", elapsed). 
+ Msg("Wings API assertion completed") + + result := &AssertionResult{ + Pass: passed, + Thought: assertionThought, + ModelName: "wings-api", + } + + // Return error if assertion failed (consistent with original behavior) + if !passed { + return result, errors.New(assertionThought) + } + + return result, nil +} + +// Query implements the ILLMService.Query method (not supported) +func (w *WingsService) Query(ctx context.Context, opts *QueryOptions) (*QueryResult, error) { + return nil, errors.New("Query operation is not supported by Wings service") +} + +// RegisterTools implements the ILLMService.RegisterTools method (no-op for Wings) +func (w *WingsService) RegisterTools(tools []*schema.ToolInfo) error { + // Wings service doesn't need tool registration as it determines actions via API + log.Debug().Int("tools_count", len(tools)).Msg("Wings service ignoring tool registration") + return nil +} + +// Wings API data structures +type WingsActionRequest struct { + Historys []interface{} `json:"historys"` + DeviceInfos []WingsDeviceInfo `json:"device_infos"` + StepText string `json:"step_text"` + BizId string `json:"biz_id"` + TextCase string `json:"text_case"` + StepType string `json:"step_type"` + DeviceID string `json:"device_id"` + Base WingsBase `json:"Base"` +} + +type WingsDeviceInfo struct { + DeviceID string `json:"device_id"` + NowImage string `json:"now_image"` + PreImage string `json:"pre_image"` + NowImageUrl string `json:"now_image_url"` + PreImageUrl string `json:"pre_image_url"` + NowLayoutJSON string `json:"now_layout_json"` + OperationSystem string `json:"operation_system"` +} + +type WingsBase struct { + LogID string `json:"LogID"` +} + +type WingsActionResponse struct { + StepType string `json:"step_type"` + ActionParams string `json:"action_params"` + ThoughtChain WingsThoughtChain `json:"thought_chain"` + BaseResp WingsBaseResp `json:"BaseResp"` +} + +type WingsThoughtChain struct { + Observation string `json:"observation"` + Thought string 
`json:"thought"` + Summary string `json:"summary"` +} + +type WingsBaseResp struct { + StatusCode int `json:"StatusCode"` + StatusMessage string `json:"StatusMessage"` + Extra WingsExtra `json:"Extra"` +} + +type WingsExtra struct { + CostTime string `json:"cost_time"` + LogID string `json:"_log_id"` +} + +// Action parameter structures +type WingsActionParams struct { + Type string `json:"Type"` + Params interface{} `json:"Params"` + Bounds [][]float64 `json:"Bounds"` + UiDict interface{} `json:"UiDict"` + UiIndex string `json:"UiIndex"` +} + +type WingsTapParams struct { + X float64 `json:"x"` + Y float64 `json:"y"` +} + +type WingsDoubleTapParams struct { + X float64 `json:"x"` + Y float64 `json:"y"` +} + +type WingsLongPressParams struct { + X float64 `json:"x"` + Y float64 `json:"y"` + Duration float64 `json:"duration"` +} + +type WingsSwipeParams struct { + FromX float64 `json:"from_x"` + FromY float64 `json:"from_y"` + ToX float64 `json:"to_x"` + ToY float64 `json:"to_y"` + Duration float64 `json:"duration"` +} + +type WingsTextParams struct { + Text string `json:"text"` +} + +// Helper methods + +// generateWingsUUID generates a random UUID for LogID +func generateWingsUUID() string { + return uuid.New().String() +} + +// extractScreenshotFromMessage extracts base64 screenshot from message +func (w *WingsService) extractScreenshotFromMessage(message *schema.Message) (string, error) { + if message == nil || len(message.MultiContent) == 0 { + return "", errors.New("no message content found") + } + + for _, content := range message.MultiContent { + if content.Type == schema.ChatMessagePartTypeImageURL && content.ImageURL != nil { + // Extract base64 data from data URL + screenshot := content.ImageURL.URL + if strings.HasPrefix(screenshot, "data:image/") { + // Remove data URL prefix + parts := strings.Split(screenshot, ",") + if len(parts) == 2 { + return parts[1], nil + } + } + return screenshot, nil + } + } + + return "", errors.New("no image found in 
message") +} + +// getDeviceInfoFromContext gets device info from context with fallback +func (w *WingsService) getDeviceInfoFromContext(ctx context.Context, screenshot string) WingsDeviceInfo { + // Try to get device info from context + if deviceID, ok := ctx.Value("device_id").(string); ok { + platformType := "android" + if platform, ok := ctx.Value("platform_type").(string); ok { + platformType = platform + } + + return WingsDeviceInfo{ + DeviceID: deviceID, + NowImage: screenshot, + PreImage: screenshot, + NowLayoutJSON: "", + OperationSystem: platformType, + } + } + + // Fallback to default device info + return WingsDeviceInfo{ + DeviceID: "default-device", + NowImage: screenshot, + PreImage: screenshot, + NowLayoutJSON: "", + OperationSystem: "android", + } +} + +// getDeviceInfoFromScreenshot gets device info from screenshot (for Assert) +func (w *WingsService) getDeviceInfoFromScreenshot(ctx context.Context, screenshot string) WingsDeviceInfo { + return w.getDeviceInfoFromContext(ctx, screenshot) +} + +// cleanScreenshotDataURL removes data URL prefix from screenshot string +func (w *WingsService) cleanScreenshotDataURL(screenshot string) string { + if strings.HasPrefix(screenshot, "data:image/") { + // Remove data URL prefix like "data:image/jpeg;base64," + parts := strings.Split(screenshot, ",") + if len(parts) == 2 { + return parts[1] + } + } + return screenshot +} + +// callWingsAPI calls the external Wings API +func (w *WingsService) callWingsAPI(ctx context.Context, request WingsActionRequest) (*WingsActionResponse, error) { + // Marshal request to JSON + requestBody, err := json.Marshal(request) + if err != nil { + return nil, errors.Wrap(err, "marshal request failed") + } + + // Create HTTP request + httpReq, err := http.NewRequestWithContext(ctx, "POST", w.apiURL, bytes.NewBuffer(requestBody)) + if err != nil { + return nil, errors.Wrap(err, "create HTTP request failed") + } + + // Set headers + httpReq.Header.Set("Content-Type", 
"application/json") + httpReq.Header.Set("Accept", "application/json") + + // Execute HTTP request + client := &http.Client{ + Timeout: 30 * time.Second, + } + + resp, err := client.Do(httpReq) + if err != nil { + return nil, errors.Wrap(err, "HTTP request failed") + } + defer resp.Body.Close() + + // Read response body + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, errors.Wrap(err, "read response body failed") + } + + // Check HTTP status + if resp.StatusCode != 200 { + return nil, fmt.Errorf("HTTP request failed with status %d: %s", resp.StatusCode, string(responseBody)) + } + + // Parse response + var apiResponse WingsActionResponse + if err := json.Unmarshal(responseBody, &apiResponse); err != nil { + return nil, errors.Wrap(err, "unmarshal response failed") + } + + return &apiResponse, nil +} + +// convertWingsResponseToToolCalls converts Wings API response to tool calls using generic approach +func (w *WingsService) convertWingsResponseToToolCalls(actionParamsStr string) ([]schema.ToolCall, error) { + if actionParamsStr == "" { + return []schema.ToolCall{}, nil + } + + var actionParams WingsActionParams + if err := json.Unmarshal([]byte(actionParamsStr), &actionParams); err != nil { + return nil, fmt.Errorf("parse action params failed: %w", err) + } + + // Use Wings API Type as tool name directly + toolName := actionParams.Type + params := actionParams.Params + + // Create tool call using generic method + toolCall, err := w.createToolCall(toolName, params) + if err != nil { + return nil, fmt.Errorf("create tool call for %s failed: %w", toolName, err) + } + + return []schema.ToolCall{toolCall}, nil +} + +// createToolCall creates a generic tool call with given name and arguments +func (w *WingsService) createToolCall(toolName string, params interface{}) (schema.ToolCall, error) { + // Convert params to arguments map + arguments := make(map[string]interface{}) + + if params != nil { + // Try to convert params to 
map[string]interface{} + switch p := params.(type) { + case map[string]interface{}: + arguments = p + case string: + // If params is a string, try to unmarshal it as JSON + if err := json.Unmarshal([]byte(p), &arguments); err != nil { + // If not JSON, treat as simple text parameter + arguments["text"] = p + } + default: + // For other types, try to marshal and unmarshal + paramsBytes, err := json.Marshal(params) + if err != nil { + return schema.ToolCall{}, fmt.Errorf("marshal params failed: %w", err) + } + if err := json.Unmarshal(paramsBytes, &arguments); err != nil { + // If unmarshal fails, create a generic params field + arguments["params"] = params + } + } + } + + // Convert arguments to JSON string + argumentsJSON, err := json.Marshal(arguments) + if err != nil { + return schema.ToolCall{}, fmt.Errorf("marshal arguments failed: %w", err) + } + + // Generate unique tool call ID + toolCallID := fmt.Sprintf("call_%s", uuid.New().String()[:8]) + + return schema.ToolCall{ + ID: toolCallID, + Function: schema.FunctionCall{ + Name: toolName, + Arguments: string(argumentsJSON), + }, + }, nil +} + +// parseAssertionResult parses the assertion result from action_params +func (w *WingsService) parseAssertionResult(actionParamsStr string, thoughtChain WingsThoughtChain) (bool, string, error) { + // Parse action parameters JSON + var actionParams map[string]interface{} + if err := json.Unmarshal([]byte(actionParamsStr), &actionParams); err != nil { + return false, "", errors.Wrap(err, "parse action params failed") + } + + // Extract action_type from the parsed JSON + actionType, exists := actionParams["action_type"] + if !exists { + // If no action_type field, try to parse nested structure + if totalRes, ok := actionParams["total_res"].([]interface{}); ok && len(totalRes) > 0 { + if firstRes, ok := totalRes[0].(map[string]interface{}); ok { + if actionParamsNested, ok := firstRes["action_params"].(map[string]interface{}); ok { + if nestedActionType, ok := 
actionParamsNested["action_type"]; ok { + actionType = nestedActionType + } + } + } + } + } + + // Default to failed if no action_type found + if actionType == nil { + return false, thoughtChain.Summary, nil + } + + // Convert action_type to string and check result + actionTypeStr, ok := actionType.(string) + if !ok { + return false, thoughtChain.Summary, nil + } + + // Determine assertion result based on action_type + passed := strings.ToLower(actionTypeStr) == "passed" + + // Use thoughtChain.Summary as the assertion thought + assertionThought := thoughtChain.Summary + if assertionThought == "" { + assertionThought = thoughtChain.Thought + } + if assertionThought == "" { + assertionThought = thoughtChain.Observation + } + + log.Info(). + Str("action_type", actionTypeStr). + Bool("passed", passed). + Str("thought", assertionThought). + Msg("parsed Wings assertion result") + + return passed, assertionThought, nil +} diff --git a/uixt/android_test.go b/uixt/android_test.go index 8f9e1fad..0d5d418d 100644 --- a/uixt/android_test.go +++ b/uixt/android_test.go @@ -24,7 +24,7 @@ func setupADBDriverExt(t *testing.T) *XTDriver { require.Nil(t, err) driverExt, err := NewXTDriver(driver, option.WithCVService(option.CVServiceTypeVEDEM), - option.WithLLMService(option.DOUBAO_1_5_THINKING_VISION_PRO_250428), + // option.WithLLMService(option.DOUBAO_1_5_THINKING_VISION_PRO_250428), ) require.Nil(t, err) return driverExt diff --git a/uixt/driver_ext_ai.go b/uixt/driver_ext_ai.go index 84ca4c6d..5243afc8 100644 --- a/uixt/driver_ext_ai.go +++ b/uixt/driver_ext_ai.go @@ -2,6 +2,8 @@ package uixt import ( "context" + "encoding/json" + "fmt" "time" "github.com/cloudwego/eino/schema" @@ -9,12 +11,12 @@ import ( "github.com/rs/zerolog/log" "github.com/httprunner/httprunner/v5/code" - "github.com/httprunner/httprunner/v5/internal/json" "github.com/httprunner/httprunner/v5/uixt/ai" "github.com/httprunner/httprunner/v5/uixt/option" "github.com/httprunner/httprunner/v5/uixt/types" ) +// 
StartToGoal (original implementation - preserved) func (dExt *XTDriver) StartToGoal(ctx context.Context, prompt string, opts ...option.ActionOption) ([]*PlanningExecutionResult, error) { options := option.NewActionOptions(opts...) logger := log.Info().Str("prompt", prompt) @@ -193,7 +195,7 @@ func (dExt *XTDriver) StartToGoal(ctx context.Context, prompt string, opts ...op } } -// AIAction performs AI-driven action and returns detailed execution result +// AIAction with WingsService priority support func (dExt *XTDriver) AIAction(ctx context.Context, prompt string, opts ...option.ActionOption) (*AIExecutionResult, error) { log.Info().Str("prompt", prompt).Msg("performing AI action") @@ -206,25 +208,93 @@ func (dExt *XTDriver) AIAction(ctx context.Context, prompt string, opts ...optio return nil, err } - // Step 2: Plan next action and measure time + // Step 2: Check if WingsService is available and prioritize it + if dExt.WingsService != nil { + log.Info().Msg("using Wings service for AI action") + return dExt.executeAIAction(ctx, prompt, screenResult, dExt.WingsService, "wings", opts...) + } + + // Step 3: Fallback to LLM service + if dExt.LLMService == nil { + return nil, errors.New("neither Wings service nor LLM service is initialized") + } + + log.Info().Msg("using LLM service for AI action") + return dExt.executeAIAction(ctx, prompt, screenResult, dExt.LLMService, "llm", opts...) +} + +// executeAIAction executes AIAction using any AI service (generic implementation) +func (dExt *XTDriver) executeAIAction(ctx context.Context, prompt string, screenResult *ScreenResult, service ai.ILLMService, serviceType string, opts ...option.ActionOption) (*AIExecutionResult, error) { + // Add device context for Wings service if needed + if serviceType == "wings" { + ctx = dExt.addDeviceContextForWings(ctx) + } + + // Step 1: Plan next action and measure time modelCallStartTime := time.Now() - planningResult, err := dExt.PlanNextAction(ctx, prompt, opts...) 
+ + var planningResult *ai.PlanningResult + var err error + + if serviceType == "llm" { + // For LLM service, use PlanNextAction which includes additional processing + planningExecutionResult, planErr := dExt.PlanNextAction(ctx, prompt, opts...) + if planErr != nil { + modelCallElapsed := time.Since(modelCallStartTime).Milliseconds() + return &AIExecutionResult{ + Type: "action", + ModelCallElapsed: modelCallElapsed, + ScreenshotElapsed: screenResult.Elapsed, + ImagePath: screenResult.ImagePath, + Resolution: &screenResult.Resolution, + Error: planErr.Error(), + }, errors.Wrap(planErr, "get next action failed") + } + planningResult = &planningExecutionResult.PlanningResult + } else { + // For Wings service, call Plan directly + planningOpts := &ai.PlanningOptions{ + UserInstruction: prompt, + Message: &schema.Message{ + Role: schema.User, + MultiContent: []schema.ChatMessagePart{ + { + Type: schema.ChatMessagePartTypeImageURL, + ImageURL: &schema.ChatMessageImageURL{ + URL: screenResult.Base64, + }, + }, + }, + }, + Size: screenResult.Resolution, + } + + planningResult, err = service.Plan(ctx, planningOpts) + if err != nil { + modelCallElapsed := time.Since(modelCallStartTime).Milliseconds() + return &AIExecutionResult{ + Type: "action", + ModelCallElapsed: modelCallElapsed, + ScreenshotElapsed: screenResult.Elapsed, + ImagePath: screenResult.ImagePath, + Resolution: &screenResult.Resolution, + Error: err.Error(), + }, errors.Wrap(err, fmt.Sprintf("%s service planning failed", serviceType)) + } + } + modelCallElapsed := time.Since(modelCallStartTime).Milliseconds() + aiExecutionResult := &AIExecutionResult{ Type: "action", ModelCallElapsed: modelCallElapsed, ScreenshotElapsed: screenResult.Elapsed, ImagePath: screenResult.ImagePath, Resolution: &screenResult.Resolution, - PlanningResult: &planningResult.PlanningResult, + PlanningResult: planningResult, } - if err != nil { - aiExecutionResult.Error = err.Error() - return aiExecutionResult, errors.Wrap(err, "get next 
action failed") - } - - // Step 3: Execute tool calls + // Step 2: Execute tool calls for _, toolCall := range planningResult.ToolCalls { err = dExt.invokeToolCall(ctx, toolCall, opts...) if err != nil { @@ -239,7 +309,99 @@ func (dExt *XTDriver) AIAction(ctx context.Context, prompt string, opts ...optio return aiExecutionResult, nil } -// PlanNextAction performs planning and returns unified planning information +// AIAssert with WingsService priority support +func (dExt *XTDriver) AIAssert(assertion string, opts ...option.ActionOption) (*AIExecutionResult, error) { + log.Info().Str("assertion", assertion).Msg("performing AI assertion") + + // Step 1: Take screenshot and convert to base64 + screenResult, err := dExt.GetScreenResult( + option.WithScreenShotFileName("ai_assert"), + option.WithScreenShotBase64(true), + ) + if err != nil { + return nil, err + } + + // Step 2: Check if WingsService is available and prioritize it + if dExt.WingsService != nil { + log.Info().Msg("using Wings service for AI assertion") + return dExt.executeAIAssert(assertion, screenResult, dExt.WingsService, "wings", opts...) + } + + // Step 3: Fallback to LLM service + if dExt.LLMService == nil { + return nil, errors.New("neither Wings service nor LLM service is initialized") + } + + log.Info().Msg("using LLM service for AI assertion") + return dExt.executeAIAssert(assertion, screenResult, dExt.LLMService, "llm", opts...) 
+} + +// executeAIAssert executes AIAssert using any AI service (generic implementation) +func (dExt *XTDriver) executeAIAssert(assertion string, screenResult *ScreenResult, service ai.ILLMService, serviceType string, opts ...option.ActionOption) (*AIExecutionResult, error) { + // Step 1: Prepare context and options + ctx := context.Background() + if serviceType == "wings" { + ctx = dExt.addDeviceContextForWings(ctx) + } + + assertResult := &AIExecutionResult{ + Type: "assert", + ScreenshotElapsed: screenResult.Elapsed, + ImagePath: screenResult.ImagePath, + Resolution: &screenResult.Resolution, + } + + // Step 2: Call service and measure time + modelCallStartTime := time.Now() + assertOpts := &ai.AssertOptions{ + Assertion: assertion, + Screenshot: screenResult.Base64, + Size: screenResult.Resolution, + } + + result, err := service.Assert(ctx, assertOpts) + assertResult.ModelCallElapsed = time.Since(modelCallStartTime).Milliseconds() + assertResult.AssertionResult = result + + if err != nil { + assertResult.Error = err.Error() + return assertResult, errors.Wrap(err, fmt.Sprintf("%s assertion failed", serviceType)) + } + + if !result.Pass { + assertResult.Error = result.Thought + } + + return assertResult, nil +} + +// addDeviceContextForWings adds device information to context for Wings service +func (dExt *XTDriver) addDeviceContextForWings(ctx context.Context) context.Context { + device := dExt.GetDevice() + if device == nil { + return ctx + } + + // Add device ID to context + ctx = context.WithValue(ctx, "device_id", device.UUID()) + + // Add platform type to context + platformType := "android" // default + switch device.(type) { + case *AndroidDevice: + platformType = "android" + case *IOSDevice: + platformType = "ios" + case *HarmonyDevice: + platformType = "harmony" + } + ctx = context.WithValue(ctx, "platform_type", platformType) + + return ctx +} + +// PlanNextAction (original implementation - preserved) func (dExt *XTDriver) PlanNextAction(ctx 
context.Context, prompt string, opts ...option.ActionOption) (*PlanningExecutionResult, error) { if dExt.LLMService == nil { return nil, errors.New("LLM service is not initialized") @@ -314,7 +476,7 @@ func (dExt *XTDriver) PlanNextAction(ctx context.Context, prompt string, opts .. return planningResult, nil } -// isTaskFinished checks if the task is completed based on the planning result +// isTaskFinished (original implementation - preserved) func (dExt *XTDriver) isTaskFinished(planningResult *PlanningExecutionResult) bool { // Check if there are no tool calls (no actions to execute) if len(planningResult.ToolCalls) == 0 { @@ -333,7 +495,7 @@ func (dExt *XTDriver) isTaskFinished(planningResult *PlanningExecutionResult) bo return false } -// invokeToolCall invokes the tool call +// invokeToolCall (original implementation - preserved) func (dExt *XTDriver) invokeToolCall(ctx context.Context, toolCall schema.ToolCall, opts ...option.ActionOption) error { // Parse arguments arguments := make(map[string]interface{}) @@ -360,7 +522,7 @@ func (dExt *XTDriver) invokeToolCall(ctx context.Context, toolCall schema.ToolCa return nil } -// PlanningExecutionResult represents a unified planning result that contains both planning information and execution results +// PlanningExecutionResult (original implementation - preserved) type PlanningExecutionResult struct { ai.PlanningResult // Inherit all fields from ai.PlanningResult (ToolCalls, Thought, Content, Error, ModelName) // Planning process information @@ -377,7 +539,7 @@ type PlanningExecutionResult struct { SubActions []*SubActionResult `json:"sub_actions,omitempty"` // sub-actions generated from this planning } -// AIExecutionResult represents a unified result structure for all AI operations +// AIExecutionResult (original implementation - preserved) type AIExecutionResult struct { Type string `json:"type"` // operation type: "query", "action", "assert" ModelCallElapsed int64 `json:"model_call_elapsed"` // model call 
elapsed time in milliseconds @@ -394,7 +556,7 @@ type AIExecutionResult struct { Error string `json:"error,omitempty"` // error message if operation failed } -// SubActionResult represents a sub-action within a start_to_goal action +// SubActionResult (original implementation - preserved) type SubActionResult struct { ActionName string `json:"action_name"` // name of the sub-action (e.g., "tap", "input") Arguments interface{} `json:"arguments,omitempty"` // arguments passed to the sub-action @@ -409,6 +571,7 @@ type SessionData struct { ScreenResults []*ScreenResult `json:"screen_results,omitempty"` // store sub-action specific screen_results } +// AIQuery (original implementation - preserved) func (dExt *XTDriver) AIQuery(text string, opts ...option.ActionOption) (*AIExecutionResult, error) { if dExt.LLMService == nil { return nil, errors.New("LLM service is not initialized") @@ -453,50 +616,3 @@ func (dExt *XTDriver) AIQuery(text string, opts ...option.ActionOption) (*AIExec } return aiResult, nil } - -// AIAssert performs AI-driven assertion and returns detailed execution result -func (dExt *XTDriver) AIAssert(assertion string, opts ...option.ActionOption) (*AIExecutionResult, error) { - if dExt.LLMService == nil { - return nil, errors.New("LLM service is not initialized") - } - - // Step 1: Take screenshot and convert to base64 - screenResult, err := dExt.GetScreenResult( - option.WithScreenShotFileName("ai_assert"), - option.WithScreenShotBase64(true), - ) - if err != nil { - return nil, err - } - - assertResult := &AIExecutionResult{ - Type: "assert", - ScreenshotElapsed: screenResult.Elapsed, - ImagePath: screenResult.ImagePath, - Resolution: &screenResult.Resolution, - } - - // Step 2: Call model and measure time - modelCallStartTime := time.Now() - assertOpts := &ai.AssertOptions{ - Assertion: assertion, - Screenshot: screenResult.Base64, - Size: screenResult.Resolution, - } - result, err := dExt.LLMService.Assert(context.Background(), assertOpts) - 
assertResult.ModelCallElapsed = time.Since(modelCallStartTime).Milliseconds() - assertResult.AssertionResult = result - - if err != nil { - assertResult.Error = err.Error() - return assertResult, errors.Wrap(err, "AI assertion failed") - } - - // For assertion failure, we should still return success but mark the assertion as failed - // This ensures that the AIResult (including screenshot and thought) is properly saved and displayed - if !result.Pass { - assertResult.Error = result.Thought // Store the failure reason for reporting - } - - return assertResult, nil -} diff --git a/uixt/driver_ext_ai_test.go b/uixt/driver_ext_ai_test.go index 3167c6cf..83904cb4 100644 --- a/uixt/driver_ext_ai_test.go +++ b/uixt/driver_ext_ai_test.go @@ -7,10 +7,11 @@ import ( "testing" "github.com/cloudwego/eino/schema" + "github.com/stretchr/testify/assert" + "github.com/httprunner/httprunner/v5/uixt/ai" "github.com/httprunner/httprunner/v5/uixt/option" "github.com/httprunner/httprunner/v5/uixt/types" - "github.com/stretchr/testify/assert" ) func TestDriverExt_TapByLLM(t *testing.T) { @@ -22,33 +23,33 @@ func TestDriverExt_TapByLLM(t *testing.T) { assert.Nil(t, err) } -func TestDriverExt_StartToGoal(t *testing.T) { - driver := setupDriverExt(t) - - userInstruction := `连连看是一款经典的益智消除类小游戏,通常以图案或图标为主要元素。以下是连连看的基本规则说明: - 1. 游戏目标: 玩家需要在规定时间内,通过连接相同的图案或图标,将它们从游戏界面中消除。 - 2. 连接规则: - - 两个相同的图案可以通过不超过三条直线连接。 - - 连接线可以水平或垂直,但不能斜线,也不能跨过其他图案。 - - 连接线的转折次数不能超过两次。 - 3. 游戏界面: - - 游戏界面通常是一个矩形区域,内含多个图案或图标,排列成行和列。 - - 图案或图标在未选中状态下背景为白色,选中状态下背景为绿色。 - 4. 时间限制: 游戏通常设有时间限制,玩家需要在时间耗尽前完成所有图案的消除。 - 5. 得分机制: 每成功连接并消除一对图案,玩家会获得相应的分数。完成游戏后,根据剩余时间和消除效率计算总分。 - 6. 
关卡设计: 游戏可能包含多个关卡,随着关卡的推进,图案的复杂度和数量会增加。 - - 注意事项: - 1、当连接错误时,顶部的红心会减少一个,需及时调整策略,避免红心变为0个后游戏失败 - 2、不要连续 2 次点击同一个图案 - 3、不要犯重复的错误 - ` - - userInstruction += "\n\n请严格按照以上游戏规则,开始游戏;注意,请只做点击操作" - - _, err := driver.StartToGoal(context.Background(), userInstruction) - assert.Nil(t, err) -} +//func TestDriverExt_StartToGoal(t *testing.T) { +// driver := setupDriverExt(t) +// +// userInstruction := `连连看是一款经典的益智消除类小游戏,通常以图案或图标为主要元素。以下是连连看的基本规则说明: +// 1. 游戏目标: 玩家需要在规定时间内,通过连接相同的图案或图标,将它们从游戏界面中消除。 +// 2. 连接规则: +// - 两个相同的图案可以通过不超过三条直线连接。 +// - 连接线可以水平或垂直,但不能斜线,也不能跨过其他图案。 +// - 连接线的转折次数不能超过两次。 +// 3. 游戏界面: +// - 游戏界面通常是一个矩形区域,内含多个图案或图标,排列成行和列。 +// - 图案或图标在未选中状态下背景为白色,选中状态下背景为绿色。 +// 4. 时间限制: 游戏通常设有时间限制,玩家需要在时间耗尽前完成所有图案的消除。 +// 5. 得分机制: 每成功连接并消除一对图案,玩家会获得相应的分数。完成游戏后,根据剩余时间和消除效率计算总分。 +// 6. 关卡设计: 游戏可能包含多个关卡,随着关卡的推进,图案的复杂度和数量会增加。 +// +// 注意事项: +// 1、当连接错误时,顶部的红心会减少一个,需及时调整策略,避免红心变为0个后游戏失败 +// 2、不要连续 2 次点击同一个图案 +// 3、不要犯重复的错误 +// ` +// +// userInstruction += "\n\n请严格按照以上游戏规则,开始游戏;注意,请只做点击操作" +// +// //_, err := driver.StartToGoal(context.Background(), userInstruction) +// //assert.Nil(t, err) +//} func TestDriverExt_PlanNextAction(t *testing.T) { driver := setupDriverExt(t) @@ -244,3 +245,241 @@ func TestPlanningOptions_ResetHistory(t *testing.T) { assert.True(t, opts.ResetHistory) assert.Equal(t, "test instruction", opts.UserInstruction) } + +// TestDriverExt_AIAction tests the AIAction method integration with real driver +func TestDriverExt_AIAction(t *testing.T) { + driver := setupDriverExt(t) + + // Test AIAction with search button click prompt + result, err := driver.AIAction(context.Background(), "冷启动抖音app") + + // Verify no error occurred + assert.Nil(t, err, "AIAction should execute without error") + + // Verify result is not nil + assert.NotNil(t, result, "AIAction should return a result") + + // Verify result has correct type + assert.Equal(t, "action", result.Type, "Result type should be 'action'") + + // Verify timing information is captured + assert.Greater(t, 
result.ModelCallElapsed, int64(0), "Model call should have elapsed time") + assert.Greater(t, result.ScreenshotElapsed, int64(0), "Screenshot should have elapsed time") + + // Verify screenshot information is captured + assert.NotEmpty(t, result.ImagePath, "Image path should not be empty") + assert.NotNil(t, result.Resolution, "Resolution should not be nil") + assert.Greater(t, result.Resolution.Width, 0, "Width should be greater than 0") + assert.Greater(t, result.Resolution.Height, 0, "Height should be greater than 0") + + // Verify planning result is captured + assert.NotNil(t, result.PlanningResult, "Planning result should not be nil") + assert.Equal(t, "wings-api", result.PlanningResult.ModelName, "Model name should be 'wings-api'") + // Log result for debugging + t.Logf("AIAction executed successfully:") + t.Logf(" Type: %s", result.Type) + t.Logf(" Model Call Elapsed: %d ms", result.ModelCallElapsed) + t.Logf(" Screenshot Elapsed: %d ms", result.ScreenshotElapsed) + t.Logf(" Image Path: %s", result.ImagePath) + t.Logf(" Resolution: %dx%d", result.Resolution.Width, result.Resolution.Height) + + if result.Error != "" { + t.Logf(" Error: %s", result.Error) + } +} + +// TestDriverExt_AIAction_CompareWithAIAction compares AIAction with AIAction +func TestDriverExt_AIAction_CompareWithAIAction(t *testing.T) { + driver := setupDriverExt(t) + + prompt := "点击搜索按钮" + + // Test both methods with the same prompt + wingsResult, wingsErr := driver.AIAction(context.Background(), prompt) + aiResult, aiErr := driver.AIAction(context.Background(), prompt) + + // Both should execute without critical errors (may have different implementations) + t.Logf("AIAction error: %v", wingsErr) + t.Logf("AIAction error: %v", aiErr) + + // If both succeed, compare results + if wingsResult != nil && aiResult != nil { + assert.Equal(t, "action", wingsResult.Type, "AIAction result type should be 'action'") + assert.Equal(t, "action", aiResult.Type, "AIAction result type should be 'action'") + 
+ // Both should have timing information + assert.Greater(t, wingsResult.ModelCallElapsed, int64(0), "AIAction should have model call elapsed time") + assert.Greater(t, aiResult.ModelCallElapsed, int64(0), "AIAction should have model call elapsed time") + + // Both should have screenshot information + assert.NotEmpty(t, wingsResult.ImagePath, "AIAction should have image path") + assert.NotEmpty(t, aiResult.ImagePath, "AIAction should have image path") + + // Compare model names + if wingsResult.PlanningResult != nil && aiResult.PlanningResult != nil { + t.Logf("AIAction model: %s", wingsResult.PlanningResult.ModelName) + t.Logf("AIAction model: %s", aiResult.PlanningResult.ModelName) + + assert.Equal(t, "wings-api", wingsResult.PlanningResult.ModelName, "AIAction should use wings-api") + assert.NotEqual(t, "wings-api", aiResult.PlanningResult.ModelName, "AIAction should not use wings-api") + } + } +} + +// TestDriverExt_AIAction_ErrorHandling tests AIAction error handling +func TestDriverExt_AIAction_ErrorHandling(t *testing.T) { + driver := setupDriverExt(t) + + // Test with empty prompt + result, err := driver.AIAction(context.Background(), "") + + // Should handle empty prompt gracefully + if err != nil { + t.Logf("Empty prompt error (expected): %v", err) + assert.NotNil(t, result, "Result should still be returned even on error") + if result != nil { + assert.NotEmpty(t, result.Error, "Result should contain error message") + } + } else { + t.Logf("Empty prompt handled successfully") + assert.NotNil(t, result, "Result should be returned") + } + + // Test with very long prompt + longPrompt := "这是一个非常长的提示词,用来测试AIAction是否能够正确处理长文本输入。" + + "我们需要确保API能够处理各种长度的输入,包括这种可能超出某些限制的文本。" + + "请在当前界面中寻找任何可能的搜索相关的按钮或输入框,然后进行点击操作。" + + result2, err2 := driver.AIAction(context.Background(), longPrompt) + + // Should handle long prompt + if err2 != nil { + t.Logf("Long prompt error: %v", err2) + } else { + t.Logf("Long prompt handled successfully") + assert.NotNil(t, result2, 
"Result should be returned for long prompt") + assert.Equal(t, "action", result2.Type, "Result type should be 'action'") + } +} + +// TestDriverExt_AIAssert tests the AIAssert method integration with real driver +func TestDriverExt_AIAssert(t *testing.T) { + driver := setupDriverExt(t) + + // Test AIAssert with assertion about search button + result, err := driver.AIAssert("屏幕中存在搜索按钮") + + // Verify no error occurred (or error is captured in result) + if err != nil { + t.Logf("AIAssert error: %v", err) + // For assertion failures, error is expected, but result should still be returned + assert.NotNil(t, result, "AIAssert should return a result even on assertion failure") + } else { + assert.NotNil(t, result, "AIAssert should return a result") + } + + // Verify result has correct type + assert.Equal(t, "assert", result.Type, "Result type should be 'assert'") + + // Verify timing information is captured + assert.Greater(t, result.ModelCallElapsed, int64(0), "Model call should have elapsed time") + assert.Greater(t, result.ScreenshotElapsed, int64(0), "Screenshot should have elapsed time") + + // Verify screenshot information is captured + assert.NotEmpty(t, result.ImagePath, "Image path should not be empty") + assert.NotNil(t, result.Resolution, "Resolution should not be nil") + assert.Greater(t, result.Resolution.Width, 0, "Width should be greater than 0") + assert.Greater(t, result.Resolution.Height, 0, "Height should be greater than 0") + + // Verify assertion result is captured + assert.NotNil(t, result.AssertionResult, "Assertion result should not be nil") + assert.NotEmpty(t, result.AssertionResult.Thought, "Assertion result thought should not be empty") + + // Log result for debugging + t.Logf("AIAssert executed:") + t.Logf(" Type: %s", result.Type) + t.Logf(" Model Call Elapsed: %d ms", result.ModelCallElapsed) + t.Logf(" Screenshot Elapsed: %d ms", result.ScreenshotElapsed) + t.Logf(" Image Path: %s", result.ImagePath) + t.Logf(" Resolution: %dx%d", 
result.Resolution.Width, result.Resolution.Height) + t.Logf(" Assertion Pass: %t", result.AssertionResult.Pass) + t.Logf(" Assertion Thought: %s", result.AssertionResult.Thought) + + if result.Error != "" { + t.Logf(" Error: %s", result.Error) + } +} + +// TestDriverExt_AIAssert_CompareWithAIAssert compares AIAssert with AIAssert +func TestDriverExt_AIAssert_CompareWithAIAssert(t *testing.T) { + driver := setupDriverExt(t) + + assertion := "屏幕中存在搜索按钮" + + // Test both methods with the same assertion + wingsResult, wingsErr := driver.AIAssert(assertion) + aiResult, aiErr := driver.AIAssert(assertion) + + // Both should execute (may have different results) + t.Logf("AIAssert error: %v", wingsErr) + t.Logf("AIAssert error: %v", aiErr) + + // If both succeed, compare results + if wingsResult != nil && aiResult != nil { + assert.Equal(t, "assert", wingsResult.Type, "AIAssert result type should be 'assert'") + assert.Equal(t, "assert", aiResult.Type, "AIAssert result type should be 'assert'") + + // Both should have timing information + assert.Greater(t, wingsResult.ModelCallElapsed, int64(0), "AIAssert should have model call elapsed time") + assert.Greater(t, aiResult.ModelCallElapsed, int64(0), "AIAssert should have model call elapsed time") + + // Both should have screenshot information + assert.NotEmpty(t, wingsResult.ImagePath, "AIAssert should have image path") + assert.NotEmpty(t, aiResult.ImagePath, "AIAssert should have image path") + + // Both should have assertion results + assert.NotNil(t, wingsResult.AssertionResult, "AIAssert should have assertion result") + assert.NotNil(t, aiResult.AssertionResult, "AIAssert should have assertion result") + + // Log comparison + t.Logf("AIAssert Pass: %t, Thought: %s", wingsResult.AssertionResult.Pass, wingsResult.AssertionResult.Thought) + t.Logf("AIAssert Pass: %t, Thought: %s", aiResult.AssertionResult.Pass, aiResult.AssertionResult.Thought) + } +} + +// TestDriverExt_AIAssert_ErrorHandling tests AIAssert error 
handling +func TestDriverExt_AIAssert_ErrorHandling(t *testing.T) { + driver := setupDriverExt(t) + + // Test with empty assertion + result, err := driver.AIAssert("") + + // Should handle empty assertion gracefully + if err != nil { + t.Logf("Empty assertion error (may be expected): %v", err) + assert.NotNil(t, result, "Result should still be returned even on error") + if result != nil { + assert.NotEmpty(t, result.Error, "Result should contain error message") + } + } else { + t.Logf("Empty assertion handled successfully") + assert.NotNil(t, result, "Result should be returned") + } + + // Test with complex assertion + complexAssertion := "断言:当前屏幕显示的是主页面,包含用户头像、搜索框、导航栏等关键元素,并且没有任何错误提示信息" + + result2, err2 := driver.AIAssert(complexAssertion) + + // Should handle complex assertion + if err2 != nil { + t.Logf("Complex assertion result: %v", err2) + } else { + t.Logf("Complex assertion handled successfully") + assert.NotNil(t, result2, "Result should be returned for complex assertion") + assert.Equal(t, "assert", result2.Type, "Result type should be 'assert'") + if result2.AssertionResult != nil { + t.Logf("Assertion passed: %t", result2.AssertionResult.Pass) + } + } +} diff --git a/uixt/driver_ext_screenshot.go b/uixt/driver_ext_screenshot.go index f6500b65..b5bb07f9 100644 --- a/uixt/driver_ext_screenshot.go +++ b/uixt/driver_ext_screenshot.go @@ -176,9 +176,6 @@ func (dExt *XTDriver) GetScreenTexts(opts ...option.ActionOption) (ocrTexts ai.O func (dExt *XTDriver) FindScreenText(text string, opts ...option.ActionOption) (textRect ai.OCRText, err error) { options := option.NewActionOptions(opts...) 
- if options.ScreenShotFileName == "" { - opts = append(opts, option.WithScreenShotFileName(fmt.Sprintf("find_screen_text_%s", text))) - } // convert relative scope to absolute scope if options.AbsScope == nil && len(options.Scope) == 4 { diff --git a/uixt/driver_ext_tap.go b/uixt/driver_ext_tap.go index 0b36afcd..03722395 100644 --- a/uixt/driver_ext_tap.go +++ b/uixt/driver_ext_tap.go @@ -2,17 +2,20 @@ package uixt import ( "fmt" + "time" + + "github.com/rs/zerolog/log" "github.com/httprunner/httprunner/v5/uixt/ai" "github.com/httprunner/httprunner/v5/uixt/option" - "github.com/rs/zerolog/log" ) func (dExt *XTDriver) TapByOCR(text string, opts ...option.ActionOption) error { actionOptions := option.NewActionOptions(opts...) log.Info().Str("text", text).Interface("options", actionOptions).Msg("TapByOCR") + if actionOptions.ScreenShotFileName == "" { - opts = append(opts, option.WithScreenShotFileName(fmt.Sprintf("tap_by_ocr_%s", text))) + opts = append(opts, option.WithScreenShotFileName(fmt.Sprintf("%s_tap_by_ocr_%s", dExt.GetDevice().UUID(), time.Now().Format("20060102150405")))) } textRect, err := dExt.FindScreenText(text, opts...) 
diff --git a/uixt/mcp_server.go b/uixt/mcp_server.go
index 346ede50..bedaf885 100644
--- a/uixt/mcp_server.go
+++ b/uixt/mcp_server.go
@@ -71,6 +71,7 @@ func (s *MCPServer4XTDriver) registerTools() {
 	s.registerTool(&ToolSelectDevice{}) // SelectDevice
 
 	// Touch Tools
+	s.registerTool(&ToolTap{}) // tap at relative x/y coordinates
 	s.registerTool(&ToolTapXY{}) // tap xy
 	s.registerTool(&ToolTapAbsXY{}) // tap abs xy
 	s.registerTool(&ToolTapByOCR{}) // tap by OCR
@@ -88,6 +89,8 @@ func (s *MCPServer4XTDriver) registerTools() {
 
 	// Input Tools
 	s.registerTool(&ToolInput{})
+	s.registerTool(&ToolText{}) // text: input text via the driver's Input call
+	s.registerTool(&ToolBackspace{}) // backspace: delete "count" characters
 	s.registerTool(&ToolSetIme{})
 
 	// Button Tools
@@ -98,7 +101,10 @@ func (s *MCPServer4XTDriver) registerTools() {
 	// App Tools
 	s.registerTool(&ToolListPackages{}) // ListPackages
 	s.registerTool(&ToolLaunchApp{}) // LaunchApp
+	s.registerTool(&ToolOpenApp{}) // OpenApp (registered as open_app)
 	s.registerTool(&ToolTerminateApp{}) // TerminateApp
+	s.registerTool(&ToolTerminateAppNew{}) // TerminateApp (new), registered as terminal_app
+	s.registerTool(&ToolColdLaunch{}) // ColdLaunch: terminate, then launch
 	s.registerTool(&ToolAppInstall{}) // AppInstall
 	s.registerTool(&ToolAppUninstall{}) // AppUninstall
 	s.registerTool(&ToolAppClear{}) // AppClear
diff --git a/uixt/mcp_server_test.go b/uixt/mcp_server_test.go
index ba102556..5e85b400 100644
--- a/uixt/mcp_server_test.go
+++ b/uixt/mcp_server_test.go
@@ -79,6 +79,7 @@ func TestToolInterfaces(t *testing.T) {
 	tools := []ActionTool{
 		&ToolListAvailableDevices{},
 		&ToolSelectDevice{},
+		&ToolTap{},
 		&ToolTapXY{},
 		&ToolTapAbsXY{},
 		&ToolTapByOCR{},
@@ -92,6 +93,8 @@ func TestToolInterfaces(t *testing.T) {
 		&ToolSwipeToTapTexts{},
 		&ToolDrag{},
 		&ToolInput{},
+		&ToolText{},
+		&ToolBackspace{},
 		&ToolScreenShot{},
 		&ToolGetScreenSize{},
 		&ToolPressButton{},
@@ -99,7 +102,10 @@ func TestToolInterfaces(t *testing.T) {
 		&ToolBack{},
 		&ToolListPackages{},
 		&ToolLaunchApp{},
+		&ToolOpenApp{},
 		&ToolTerminateApp{},
+		&ToolTerminateAppNew{},
+		&ToolColdLaunch{},
 		&ToolAppInstall{},
 		&ToolAppUninstall{},
 		&ToolAppClear{},
@@ -240,6 +246,45 @@ func TestToolSelectDevice(t *testing.T) {
 	assert.Equal(t, string(option.ACTION_SelectDevice), request.Params.Name)
 }
 
+// TestToolTap checks ToolTap's metadata and its MobileAction -> MCP request conversion.
+func TestToolTap(t *testing.T) {
+	tool := &ToolTap{}
+
+	// The tool must report the ACTION_Tap action name
+	assert.Equal(t, option.ACTION_Tap, tool.Name())
+
+	// A non-empty description is required for the MCP tool listing
+	assert.NotEmpty(t, tool.Description())
+
+	// Options must be resolvable for this action type
+	options := tool.Options()
+	assert.NotNil(t, options)
+
+	// Valid params: a two-element coordinate slice plus a positive duration option
+	action := option.MobileAction{
+		Method: option.ACTION_Tap,
+		Params: []float64{0.5, 0.6},
+		ActionOptions: option.ActionOptions{
+			Duration: 1.5,
+		},
+	}
+	request, err := tool.ConvertActionToCallToolRequest(action)
+	assert.NoError(t, err)
+	assert.Equal(t, string(option.ACTION_Tap), request.Params.Name)
+	args := request.GetArguments()
+	assert.Equal(t, 0.5, args["x"])
+	assert.Equal(t, 0.6, args["y"])
+	assert.Equal(t, 1.5, args["duration"])
+
+	// Invalid params: a plain string cannot be converted to a coordinate pair
+	invalidAction := option.MobileAction{
+		Method: option.ACTION_Tap,
+		Params: "invalid",
+	}
+	_, err = tool.ConvertActionToCallToolRequest(invalidAction)
+	assert.Error(t, err)
+}
+
 // TestToolTapXY tests the ToolTapXY implementation
 func TestToolTapXY(t *testing.T) {
 	tool := &ToolTapXY{}
@@ -782,6 +827,74 @@ func TestToolInput(t *testing.T) {
 	assert.Equal(t, "Hello World", request.GetArguments()["text"])
 }
 
+// TestToolText checks ToolText's metadata and that string params become the "text" argument.
+func TestToolText(t *testing.T) {
+	tool := &ToolText{}
+
+	// The tool must report the ACTION_Text action name
+	assert.Equal(t, option.ACTION_Text, tool.Name())
+
+	// A non-empty description is required for the MCP tool listing
+	assert.NotEmpty(t, tool.Description())
+
+	// Options must be resolvable for this action type
+	options := tool.Options()
+	assert.NotNil(t, options)
+
+	// Valid params: the string is forwarded verbatim as the "text" argument
+	action := option.MobileAction{
+		Method: option.ACTION_Text,
+		Params: "Hello World",
+	}
+	request, err := tool.ConvertActionToCallToolRequest(action)
+	assert.NoError(t, err)
+	assert.Equal(t, string(option.ACTION_Text), request.Params.Name)
+	assert.Equal(t, "Hello World", request.GetArguments()["text"])
+}
+
+// TestToolBackspace checks ToolBackspace's metadata and how params map to the "count" argument.
+func TestToolBackspace(t *testing.T) {
+	tool := &ToolBackspace{}
+
+	// The tool must report the ACTION_Backspace action name
+	assert.Equal(t, option.ACTION_Backspace, tool.Name())
+
+	// A non-empty description is required for the MCP tool listing
+	assert.NotEmpty(t, tool.Description())
+
+	// Options must be resolvable for this action type
+	options := tool.Options()
+	assert.NotNil(t, options)
+
+	// Int params are used as the count unchanged
+	action := option.MobileAction{
+		Method: option.ACTION_Backspace,
+		Params: 3,
+	}
+	request, err := tool.ConvertActionToCallToolRequest(action)
+	assert.NoError(t, err)
+	assert.Equal(t, string(option.ACTION_Backspace), request.Params.Name)
+	assert.Equal(t, 3, request.GetArguments()["count"])
+
+	// Float64 params are converted to an int count
+	actionFloat := option.MobileAction{
+		Method: option.ACTION_Backspace,
+		Params: 5.0,
+	}
+	requestFloat, err := tool.ConvertActionToCallToolRequest(actionFloat)
+	assert.NoError(t, err)
+	assert.Equal(t, 5, requestFloat.GetArguments()["count"])
+
+	// Unsupported param types are not an error; the count falls back to 1
+	invalidAction := option.MobileAction{
+		Method: option.ACTION_Backspace,
+		Params: "invalid",
+	}
+	requestDefault, err := tool.ConvertActionToCallToolRequest(invalidAction)
+	assert.NoError(t, err)
+	assert.Equal(t, 1, requestDefault.GetArguments()["count"])
+}
+
 // TestToolScreenShot tests the ToolScreenShot implementation
 func TestToolScreenShot(t *testing.T) {
 	tool := &ToolScreenShot{}
@@ -973,6 +1086,39 @@ func TestToolLaunchApp(t *testing.T) {
 	assert.Error(t, err)
 }
 
+// TestToolOpenApp checks ToolOpenApp's metadata and its package-name request conversion.
+func TestToolOpenApp(t *testing.T) {
+	tool := &ToolOpenApp{}
+
+	// The tool must report the ACTION_OpenApp action name
+	assert.Equal(t, option.ACTION_OpenApp, tool.Name())
+
+	// A non-empty description is required for the MCP tool listing
+	assert.NotEmpty(t, tool.Description())
+
+	// Options must be resolvable for this action type
+	options := tool.Options()
+	assert.NotNil(t, options)
+
+	// Valid params: the string is forwarded as the "packageName" argument
+	action := option.MobileAction{
+		Method: option.ACTION_OpenApp,
+		Params: "com.example.app",
+	}
+	request, err := tool.ConvertActionToCallToolRequest(action)
+	assert.NoError(t, err)
+	assert.Equal(t, string(option.ACTION_OpenApp), request.Params.Name)
+	assert.Equal(t, "com.example.app", request.GetArguments()["packageName"])
+
+	// Invalid params: anything other than a string is rejected
+	invalidAction := option.MobileAction{
+		Method: option.ACTION_OpenApp,
+		Params: 123, // should be string
+	}
+	_, err = tool.ConvertActionToCallToolRequest(invalidAction)
+	assert.Error(t, err)
+}
+
 // TestToolTerminateApp tests the ToolTerminateApp implementation
 func TestToolTerminateApp(t *testing.T) {
 	tool := &ToolTerminateApp{}
@@ -1006,6 +1152,72 @@ func TestToolTerminateApp(t *testing.T) {
 	assert.Error(t, err)
 }
 
+// TestToolTerminateAppNew checks ToolTerminateAppNew's metadata and its package-name request conversion.
+func TestToolTerminateAppNew(t *testing.T) {
+	tool := &ToolTerminateAppNew{}
+
+	// The tool must report ACTION_TerminateApp (not ACTION_AppTerminate, which ToolTerminateApp is tested against)
+	assert.Equal(t, option.ACTION_TerminateApp, tool.Name())
+
+	// A non-empty description is required for the MCP tool listing
+	assert.NotEmpty(t, tool.Description())
+
+	// Options must be resolvable for this action type
+	options := tool.Options()
+	assert.NotNil(t, options)
+
+	// Valid params: the string is forwarded as the "packageName" argument
+	action := option.MobileAction{
+		Method: option.ACTION_TerminateApp,
+		Params: "com.example.app",
+	}
+	request, err := tool.ConvertActionToCallToolRequest(action)
+	assert.NoError(t, err)
+	assert.Equal(t, string(option.ACTION_TerminateApp), request.Params.Name)
+	assert.Equal(t, "com.example.app", request.GetArguments()["packageName"])
+
+	// Invalid params: anything other than a string is rejected
+	invalidAction := option.MobileAction{
+		Method: option.ACTION_TerminateApp,
+		Params: []int{1, 2, 3}, // should be string
+	}
+	_, err = tool.ConvertActionToCallToolRequest(invalidAction)
+	assert.Error(t, err)
+}
+
+// TestToolColdLaunch checks ToolColdLaunch's metadata and its package-name request conversion.
+func TestToolColdLaunch(t *testing.T) {
+	tool := &ToolColdLaunch{}
+
+	// The tool must report the ACTION_ColdLaunch action name
+	assert.Equal(t, option.ACTION_ColdLaunch, tool.Name())
+
+	// A non-empty description is required for the MCP tool listing
+	assert.NotEmpty(t, tool.Description())
+
+	// Options must be resolvable for this action type
+	options := tool.Options()
+	assert.NotNil(t, options)
+
+	// Valid params: the string is forwarded as the "packageName" argument
+	action := option.MobileAction{
+		Method: option.ACTION_ColdLaunch,
+		Params: "com.example.app",
+	}
+	request, err := tool.ConvertActionToCallToolRequest(action)
+	assert.NoError(t, err)
+	assert.Equal(t, string(option.ACTION_ColdLaunch), request.Params.Name)
+	assert.Equal(t, "com.example.app", request.GetArguments()["packageName"])
+
+	// Invalid params: anything other than a string is rejected
+	invalidAction := option.MobileAction{
+		Method: option.ACTION_ColdLaunch,
+		Params: 123, // should be string
+	}
+	_, err = tool.ConvertActionToCallToolRequest(invalidAction)
+	assert.Error(t, err)
+}
+
 // TestToolAppInstall tests the ToolAppInstall implementation
 func TestToolAppInstall(t *testing.T) {
 	tool := &ToolAppInstall{}
diff --git a/uixt/mcp_tools_app.go b/uixt/mcp_tools_app.go
index 1ad61843..1a269492 100644
--- a/uixt/mcp_tools_app.go
+++ b/uixt/mcp_tools_app.go
@@ -3,6 +3,7 @@ package uixt
 import (
 	"context"
 	"fmt"
+	"time"
 
 	"github.com/mark3labs/mcp-go/mcp"
 	"github.com/mark3labs/mcp-go/server"
@@ -393,3 +394,195 @@ func (t *ToolGetForegroundApp) Implement() server.ToolHandlerFunc {
 func (t *ToolGetForegroundApp) ConvertActionToCallToolRequest(action option.MobileAction) (mcp.CallToolRequest, error) {
 	return BuildMCPCallToolRequest(t.Name(), map[string]any{}, action), nil
 }
+
+// ToolOpenApp implements the open_app tool call.
+type ToolOpenApp struct {
+	// PackageName is the only field of the payload returned by this tool.
+	PackageName string `json:"packageName" desc:"Package name of the opened app"`
+}
+
+// Name reports the action name this tool handles (option.ACTION_OpenApp).
+func (t *ToolOpenApp) Name() option.ActionName {
+	return option.ACTION_OpenApp
+}
+
+// Description returns the human-readable summary of the tool.
+func (t *ToolOpenApp) Description() string {
+	return "Open an app on mobile device using its package name and wait for the app to load"
+}
+
+// Options returns the MCP tool options that option.ActionOptions defines
+// for the open-app action.
+func (t *ToolOpenApp) Options() []mcp.ToolOption {
+	return (&option.ActionOptions{}).GetMCPOptions(option.ACTION_OpenApp)
+}
+
+// Implement returns the handler that launches the requested package.
+//
+// The handler sets up the driver from the request arguments, parses them into
+// action options, requires a non-empty package name and then calls AppLaunch.
+// A launch failure is returned both as an MCP error response and as an error.
+func (t *ToolOpenApp) Implement() server.ToolHandlerFunc {
+	return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
+		args := request.GetArguments()
+
+		driver, err := setupXTDriver(ctx, args)
+		if err != nil {
+			return nil, fmt.Errorf("setup driver failed: %w", err)
+		}
+
+		req, err := parseActionOptions(args)
+		if err != nil {
+			return nil, err
+		}
+
+		pkg := req.PackageName
+		if pkg == "" {
+			return nil, fmt.Errorf("package_name is required")
+		}
+
+		if err := driver.AppLaunch(pkg); err != nil {
+			return NewMCPErrorResponse(fmt.Sprintf("Open app failed: %s", err.Error())), err
+		}
+
+		return NewMCPSuccessResponse(
+			fmt.Sprintf("Successfully opened app: %s", pkg),
+			&ToolOpenApp{PackageName: pkg},
+		), nil
+	}
+}
+
+// ConvertActionToCallToolRequest builds the MCP request for an open-app action.
+// action.Params must be a string holding the package name; any other type is
+// rejected with an error.
+func (t *ToolOpenApp) ConvertActionToCallToolRequest(action option.MobileAction) (mcp.CallToolRequest, error) {
+	pkg, isString := action.Params.(string)
+	if !isString {
+		return mcp.CallToolRequest{}, fmt.Errorf("invalid open app params: %v", action.Params)
+	}
+	return BuildMCPCallToolRequest(t.Name(), map[string]any{"packageName": pkg}, action), nil
+}
+
+// ToolTerminateAppNew implements the terminal_app tool call.
+type ToolTerminateAppNew struct {
+	// Return data fields - these define the structure of data returned by this tool
+	PackageName string `json:"packageName" desc:"Package name of the terminated app"`
+	WasRunning  bool   `json:"wasRunning" desc:"Whether the app was actually running before termination"`
+}
+
+// Name reports the action name this tool handles (option.ACTION_TerminateApp).
+func (t *ToolTerminateAppNew) Name() option.ActionName {
+	return option.ACTION_TerminateApp
+}
+
+// Description returns the human-readable summary of the tool.
+func (t *ToolTerminateAppNew) Description() string {
+	return "Terminate a running app on mobile device using its package name"
+}
+
+// Options returns the MCP tool options that option.ActionOptions defines
+// for the terminate-app action.
+func (t *ToolTerminateAppNew) Options() []mcp.ToolOption {
+	unifiedReq := &option.ActionOptions{}
+	return unifiedReq.GetMCPOptions(option.ACTION_TerminateApp)
+}
+
+// Implement returns the handler that terminates the requested package.
+//
+// The handler requires a non-empty package name and calls AppTerminate. An app
+// that was not running is not an error: it is logged and reported through the
+// WasRunning field of the returned data.
+func (t *ToolTerminateAppNew) Implement() server.ToolHandlerFunc {
+	return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
+		arguments := request.GetArguments()
+		driverExt, err := setupXTDriver(ctx, arguments)
+		if err != nil {
+			return nil, fmt.Errorf("setup driver failed: %w", err)
+		}
+
+		unifiedReq, err := parseActionOptions(arguments)
+		if err != nil {
+			return nil, err
+		}
+
+		if unifiedReq.PackageName == "" {
+			return nil, fmt.Errorf("package_name is required")
+		}
+
+		// Terminate app action logic
+		success, err := driverExt.AppTerminate(unifiedReq.PackageName)
+		if err != nil {
+			return NewMCPErrorResponse(fmt.Sprintf("Terminate app failed: %s", err.Error())), err
+		}
+		if !success {
+			log.Warn().Str("packageName", unifiedReq.PackageName).Msg("app was not running")
+		}
+
+		message := fmt.Sprintf("Successfully terminated app: %s", unifiedReq.PackageName)
+		returnData := ToolTerminateAppNew{
+			PackageName: unifiedReq.PackageName,
+			WasRunning:  success,
+		}
+
+		return NewMCPSuccessResponse(message, &returnData), nil
+	}
+}
+
+// ConvertActionToCallToolRequest builds the MCP request for a terminate-app
+// action. action.Params must be a string holding the package name.
+func (t *ToolTerminateAppNew) ConvertActionToCallToolRequest(action option.MobileAction) (mcp.CallToolRequest, error) {
+	if packageName, ok := action.Params.(string); ok {
+		arguments := map[string]any{
+			"packageName": packageName,
+		}
+		return BuildMCPCallToolRequest(t.Name(), arguments, action), nil
+	}
+	return mcp.CallToolRequest{}, fmt.Errorf("invalid terminate app params: %v", action.Params)
+}
+
+// ToolColdLaunch implements the cold_launch tool call.
+type ToolColdLaunch struct {
+	// Return data fields - these define the structure of data returned by this tool
+	PackageName string `json:"packageName" desc:"Package name of the cold launched app"`
+}
+
+// Name reports the action name this tool handles (option.ACTION_ColdLaunch).
+func (t *ToolColdLaunch) Name() option.ActionName {
+	return option.ACTION_ColdLaunch
+}
+
+// Description returns the human-readable summary of the tool.
+func (t *ToolColdLaunch) Description() string {
+	return "Perform a cold launch of an app (terminate first if running, then launch)"
+}
+
+// Options returns the MCP tool options that option.ActionOptions defines
+// for the cold-launch action.
+func (t *ToolColdLaunch) Options() []mcp.ToolOption {
+	unifiedReq := &option.ActionOptions{}
+	return unifiedReq.GetMCPOptions(option.ACTION_ColdLaunch)
+}
+
+// Implement returns the handler that cold launches the requested package:
+// terminate the app, wait a fixed delay, then launch it again.
+//
+// An app that was not running is not a failure; it is only logged. A real
+// terminate or launch error aborts the cold launch and is returned both as an
+// MCP error response and as an error. The wait between the two steps is
+// abandoned with the context's error if ctx is cancelled.
+func (t *ToolColdLaunch) Implement() server.ToolHandlerFunc {
+	// relaunchDelay is the pause between terminate and launch; presumably it
+	// lets the terminated process exit fully before the relaunch.
+	const relaunchDelay = 3 * time.Second
+
+	return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
+		arguments := request.GetArguments()
+		driverExt, err := setupXTDriver(ctx, arguments)
+		if err != nil {
+			return nil, fmt.Errorf("setup driver failed: %w", err)
+		}
+
+		unifiedReq, err := parseActionOptions(arguments)
+		if err != nil {
+			return nil, err
+		}
+
+		if unifiedReq.PackageName == "" {
+			return nil, fmt.Errorf("package_name is required")
+		}
+
+		// Step 1: terminate. The boolean result says whether the app was
+		// running; only a non-nil error means the terminate call itself failed.
+		wasRunning, err := driverExt.AppTerminate(unifiedReq.PackageName)
+		if err != nil {
+			return NewMCPErrorResponse(fmt.Sprintf("Cold launch failed, terminate app failed: %s", err.Error())), err
+		}
+		if !wasRunning {
+			log.Warn().Str("packageName", unifiedReq.PackageName).Msg("app was not running")
+		}
+
+		// Step 2: wait before relaunching, but stop waiting as soon as the
+		// caller cancels the request.
+		timer := time.NewTimer(relaunchDelay)
+		defer timer.Stop()
+		select {
+		case <-ctx.Done():
+			err = ctx.Err()
+			return NewMCPErrorResponse(fmt.Sprintf("Cold launch failed, cancelled before launch: %s", err.Error())), err
+		case <-timer.C:
+		}
+
+		// Step 3: launch the app
+		err = driverExt.AppLaunch(unifiedReq.PackageName)
+		if err != nil {
+			return NewMCPErrorResponse(fmt.Sprintf("Cold launch failed, launch app failed: %s", err.Error())), err
+		}
+
+		message := fmt.Sprintf("Successfully cold launched app: %s", unifiedReq.PackageName)
+		returnData := ToolColdLaunch{PackageName: unifiedReq.PackageName}
+
+		return NewMCPSuccessResponse(message, &returnData), nil
+	}
+}
+
+// ConvertActionToCallToolRequest builds the MCP request for a cold-launch
+// action. action.Params must be a string holding the package name.
+func (t *ToolColdLaunch) ConvertActionToCallToolRequest(action option.MobileAction) (mcp.CallToolRequest, error) {
+	if packageName, ok := action.Params.(string); ok {
+		arguments := map[string]any{
+			"packageName": packageName,
+		}
+		return BuildMCPCallToolRequest(t.Name(), arguments, action), nil
+	}
+	return mcp.CallToolRequest{}, fmt.Errorf("invalid cold launch params: %v", action.Params)
+}
diff --git a/uixt/mcp_tools_input.go b/uixt/mcp_tools_input.go
index 764b8f4a..cda75467 100644
--- a/uixt/mcp_tools_input.go
+++ b/uixt/mcp_tools_input.go
@@ -123,3 +123,131 @@ func (t *ToolSetIme) ConvertActionToCallToolRequest(action option.MobileAction)
 	}
 	return mcp.CallToolRequest{}, fmt.Errorf("invalid set ime params: %v", action.Params)
 }
+
+// ToolText implements the text tool call.
+type ToolText struct {
+	// Return data fields - these define the structure of data returned by this tool
+	Text string `json:"text" desc:"Text that was input"`
+}
+
+// Name reports the action name this tool handles (option.ACTION_Text).
+func (t *ToolText) Name() option.ActionName {
+	return option.ACTION_Text
+}
+
+// Description returns the human-readable summary of the tool.
+func (t *ToolText) Description() string {
+	return "Input text into the currently focused element or input field"
+}
+
+// Options returns the MCP tool options that option.ActionOptions defines
+// for the text action.
+func (t *ToolText) Options() []mcp.ToolOption {
+	unifiedReq := &option.ActionOptions{}
+	return unifiedReq.GetMCPOptions(option.ACTION_Text)
+}
+
+// Implement returns the handler that inputs the requested text.
+//
+// The handler requires a non-empty "text" argument and passes it, together
+// with the options built from the request, to the driver's Input call. An
+// input failure is returned both as an MCP error response and as an error.
+func (t *ToolText) Implement() server.ToolHandlerFunc {
+	return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
+		arguments := request.GetArguments()
+		driverExt, err := setupXTDriver(ctx, arguments)
+		if err != nil {
+			return nil, fmt.Errorf("setup driver failed: %w", err)
+		}
+
+		unifiedReq, err := parseActionOptions(arguments)
+		if err != nil {
+			return nil, err
+		}
+
+		if unifiedReq.Text == "" {
+			return nil, fmt.Errorf("text is required")
+		}
+
+		opts := unifiedReq.Options()
+
+		// Text input action logic
+		err = driverExt.Input(unifiedReq.Text, opts...)
+		if err != nil {
+			return NewMCPErrorResponse(fmt.Sprintf("Text input failed: %s", err.Error())), err
+		}
+
+		message := fmt.Sprintf("Successfully input text: %s", unifiedReq.Text)
+		returnData := ToolText{Text: unifiedReq.Text}
+
+		return NewMCPSuccessResponse(message, &returnData), nil
+	}
+}
+
+// ConvertActionToCallToolRequest builds the MCP request for a text action.
+//
+// String params are used as-is; other non-nil values are rendered with %v so
+// that, for example, a numeric param is typed as its decimal form. Nil params
+// are rejected, because formatting nil would produce the literal text "<nil>".
+func (t *ToolText) ConvertActionToCallToolRequest(action option.MobileAction) (mcp.CallToolRequest, error) {
+	if action.Params == nil {
+		return mcp.CallToolRequest{}, fmt.Errorf("invalid text params: %v", action.Params)
+	}
+
+	text, ok := action.Params.(string)
+	if !ok {
+		text = fmt.Sprintf("%v", action.Params)
+	}
+	arguments := map[string]any{
+		"text": text,
+	}
+	return BuildMCPCallToolRequest(t.Name(), arguments, action), nil
+}
+
+// ToolBackspace implements the backspace tool call.
+type ToolBackspace struct {
+	// Return data fields - these define the structure of data returned by this tool
+	Count int `json:"count" desc:"Number of backspace operations performed"`
+}
+
+// Name reports the action name this tool handles (option.ACTION_Backspace).
+func (t *ToolBackspace) Name() option.ActionName {
+	return option.ACTION_Backspace
+}
+
+// Description returns the human-readable summary of the tool.
+func (t *ToolBackspace) Description() string {
+	return "Perform backspace operations to delete characters"
+}
+
+// Options returns the MCP tool options that option.ActionOptions defines
+// for the backspace action.
+func (t *ToolBackspace) Options() []mcp.ToolOption {
+	unifiedReq := &option.ActionOptions{}
+	return unifiedReq.GetMCPOptions(option.ACTION_Backspace)
+}
+
+// Implement returns the handler that performs the backspace operations.
+//
+// A missing or non-positive "count" argument is treated as 1. The count and
+// the options built from the request are passed to the driver's Backspace
+// call; a failure is returned both as an MCP error response and as an error.
+func (t *ToolBackspace) Implement() server.ToolHandlerFunc {
+	return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
+		arguments := request.GetArguments()
+		driverExt, err := setupXTDriver(ctx, arguments)
+		if err != nil {
+			return nil, fmt.Errorf("setup driver failed: %w", err)
+		}
+
+		unifiedReq, err := parseActionOptions(arguments)
+		if err != nil {
+			return nil, err
+		}
+
+		count := unifiedReq.Count
+		if count <= 0 {
+			count = 1 // Default to 1 backspace if not specified or invalid
+		}
+
+		opts := unifiedReq.Options()
+
+		// Backspace action logic
+		err = driverExt.Backspace(count, opts...)
+		if err != nil {
+			return NewMCPErrorResponse(fmt.Sprintf("Backspace failed: %s", err.Error())), err
+		}
+
+		message := fmt.Sprintf("Successfully performed %d backspace operations", count)
+		returnData := ToolBackspace{Count: count}
+
+		return NewMCPSuccessResponse(message, &returnData), nil
+	}
+}
+
+// ConvertActionToCallToolRequest builds the MCP request for a backspace action.
+//
+// Numeric params (int, int32, int64, float32, float64) are converted to an int
+// count, truncating any fractional part. Every other param type is not an
+// error: the count falls back to 1.
+func (t *ToolBackspace) ConvertActionToCallToolRequest(action option.MobileAction) (mcp.CallToolRequest, error) {
+	count := 1 // Default count for unsupported param types
+	switch v := action.Params.(type) {
+	case int:
+		count = v
+	case int32:
+		count = int(v)
+	case int64:
+		count = int(v)
+	case float32:
+		count = int(v)
+	case float64:
+		count = int(v)
+	}
+
+	arguments := map[string]any{
+		"count": count,
+	}
+	return BuildMCPCallToolRequest(t.Name(), arguments, action), nil
+}
diff --git a/uixt/mcp_tools_touch.go b/uixt/mcp_tools_touch.go
index f78d7ef1..e43678ba 100644
--- a/uixt/mcp_tools_touch.go
+++ b/uixt/mcp_tools_touch.go
@@ -84,6 +84,79 @@ func (t *ToolTapXY) ConvertActionToCallToolRequest(action option.MobileAction) (
 	return mcp.CallToolRequest{}, fmt.Errorf("invalid tap params: %v", action.Params)
 }
 
+// ToolTap implements the tap tool call.
+type ToolTap struct {
+	// Return data fields - these define the structure of data returned by this tool
+	X float64 `json:"x" desc:"X coordinate where tap was performed"`
+	Y float64 `json:"y" desc:"Y coordinate where tap was performed"`
+}
+
+// Name reports the action name this tool handles (option.ACTION_Tap).
+func (t *ToolTap) Name() option.ActionName {
+	return option.ACTION_Tap
+}
+
+// Description returns the human-readable summary of the tool.
+func (t *ToolTap) Description() string {
+	return "Tap on the screen at given relative coordinates (0.0-1.0 range)"
+}
+
+// Options returns the MCP tool options that option.ActionOptions defines
+// for the tap action.
+func (t *ToolTap) Options() []mcp.ToolOption {
+	unifiedReq := &option.ActionOptions{}
+	return unifiedReq.GetMCPOptions(option.ACTION_Tap)
+}
+
+// Implement returns the handler that taps at the requested coordinates.
+//
+// Both "x" and "y" must be present in the request arguments. They are
+// validated by presence, not by value, so that 0 (which the documented
+// 0.0-1.0 range includes) is accepted. The tap itself is delegated to the
+// driver's TapXY call; a failure is returned both as an MCP error response
+// and as an error.
+func (t *ToolTap) Implement() server.ToolHandlerFunc {
+	return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
+		arguments := request.GetArguments()
+		driverExt, err := setupXTDriver(ctx, arguments)
+		if err != nil {
+			return nil, fmt.Errorf("setup driver failed: %w", err)
+		}
+
+		unifiedReq, err := parseActionOptions(arguments)
+		if err != nil {
+			return nil, err
+		}
+
+		// Validate required parameters. The parsed values cannot be used for
+		// this: a missing coordinate and an explicit 0 both parse to 0.
+		_, hasX := arguments["x"]
+		_, hasY := arguments["y"]
+		if !hasX || !hasY {
+			return nil, fmt.Errorf("x and y coordinates are required")
+		}
+
+		// Build all options from request arguments
+		opts := unifiedReq.Options()
+
+		// Tap action logic
+		err = driverExt.TapXY(unifiedReq.X, unifiedReq.Y, opts...)
+		if err != nil {
+			return NewMCPErrorResponse(fmt.Sprintf("Tap failed: %s", err.Error())), err
+		}
+
+		message := fmt.Sprintf("Successfully tapped at coordinates (%.2f, %.2f)", unifiedReq.X, unifiedReq.Y)
+		returnData := ToolTap{
+			X: unifiedReq.X,
+			Y: unifiedReq.Y,
+		}
+
+		return NewMCPSuccessResponse(message, &returnData), nil
+	}
+}
+
+// ConvertActionToCallToolRequest builds the MCP request for a tap action.
+//
+// action.Params must convert to exactly two float64 values, used as "x" and
+// "y". A positive Duration in the action options is forwarded as "duration".
+func (t *ToolTap) ConvertActionToCallToolRequest(action option.MobileAction) (mcp.CallToolRequest, error) {
+	params, err := builtin.ConvertToFloat64Slice(action.Params)
+	if err != nil || len(params) != 2 {
+		return mcp.CallToolRequest{}, fmt.Errorf("invalid tap params: %v", action.Params)
+	}
+
+	arguments := map[string]any{
+		"x": params[0],
+		"y": params[1],
+	}
+	// Add duration if available from action options
+	if duration := action.ActionOptions.Duration; duration > 0 {
+		arguments["duration"] = duration
+	}
+	return BuildMCPCallToolRequest(t.Name(), arguments, action), nil
+}
+
 // ToolTapAbsXY implements the tap_abs_xy tool call.
type ToolTapAbsXY struct { // Return data fields - these define the structure of data returned by this tool diff --git a/uixt/option/action.go b/uixt/option/action.go index 83ec760e..6521ba3c 100644 --- a/uixt/option/action.go +++ b/uixt/option/action.go @@ -43,7 +43,10 @@ const ( ACTION_AppClear ActionName = "app_clear" ACTION_AppStart ActionName = "app_start" ACTION_AppLaunch ActionName = "app_launch" // 启动 app 并堵塞等待 app 首屏加载完成 + ACTION_OpenApp ActionName = "open_app" // 启动 app 并堵塞等待 app 首屏加载完成 ACTION_AppTerminate ActionName = "app_terminate" + ACTION_TerminateApp ActionName = "terminal_app" + ACTION_ColdLaunch ActionName = "cold_launch" ACTION_AppStop ActionName = "app_stop" ACTION_ScreenShot ActionName = "screenshot" ACTION_ScreenRecord ActionName = "screenrecord" @@ -70,6 +73,7 @@ const ( ACTION_SwipeCoordinate ActionName = "swipe_coordinate" // swipe by coordinates (fromX, fromY, toX, toY) ACTION_Drag ActionName = "drag" ACTION_Input ActionName = "input" + ACTION_Text ActionName = "text" ACTION_PressButton ActionName = "press_button" ACTION_Back ActionName = "back" ACTION_KeyCode ActionName = "keycode" @@ -601,6 +605,7 @@ func WithOutputSchema(schema interface{}) ActionOption { func (o *ActionOptions) GetMCPOptions(actionType ActionName) []mcp.ToolOption { // Define field mappings for different action types fieldMappings := map[ActionName][]string{ + ACTION_Tap: {"platform", "serial", "x", "y", "duration"}, ACTION_TapXY: {"platform", "serial", "x", "y", "duration"}, ACTION_TapAbsXY: {"platform", "serial", "x", "y", "duration"}, ACTION_TapByOCR: {"platform", "serial", "text", "ignoreNotFoundError", "maxRetryTimes", "index", "regex", "tapRandomRect"}, @@ -611,8 +616,13 @@ func (o *ActionOptions) GetMCPOptions(actionType ActionName) []mcp.ToolOption { ACTION_Swipe: {"platform", "serial", "direction", "fromX", "fromY", "toX", "toY", "duration", "pressDuration"}, ACTION_Drag: {"platform", "serial", "fromX", "fromY", "toX", "toY", "duration", "pressDuration"}, 
ACTION_Input: {"platform", "serial", "text", "frequency"}, + ACTION_Text: {"platform", "serial", "text", "frequency"}, + ACTION_Backspace: {"platform", "serial", "count"}, ACTION_AppLaunch: {"platform", "serial", "packageName"}, + ACTION_OpenApp: {"platform", "serial", "packageName"}, ACTION_AppTerminate: {"platform", "serial", "packageName"}, + ACTION_TerminateApp: {"platform", "serial", "packageName"}, + ACTION_ColdLaunch: {"platform", "serial", "packageName"}, ACTION_AppInstall: {"platform", "serial", "appUrl", "packageName"}, ACTION_AppUninstall: {"platform", "serial", "packageName"}, ACTION_AppClear: {"platform", "serial", "packageName"}, diff --git a/uixt/sdk.go b/uixt/sdk.go index 32f153b0..8d105bbc 100644 --- a/uixt/sdk.go +++ b/uixt/sdk.go @@ -27,29 +27,41 @@ func NewXTDriver(driver IDriver, opts ...option.AIServiceOption) (*XTDriver, err var err error + // Initialize Wings service (always available) + driverExt.WingsService = ai.NewWingsService() + log.Info().Msg("Wings service initialized") + // Handle LLM service initialization if services.LLMConfig != nil { // Use advanced LLM configuration if provided driverExt.LLMService, err = ai.NewLLMServiceWithOptionConfig(services.LLMConfig) if err != nil { - return nil, errors.Wrap(err, "init llm service with config failed") + log.Warn().Err(err).Msg("init llm service with config failed, Wings service will be used") + } else { + log.Info().Msg("LLM service initialized with advanced config") } } else if services.LLMService != "" { // Fallback to simple LLM service if no config provided driverExt.LLMService, err = ai.NewLLMService(services.LLMService) if err != nil { - return nil, errors.Wrap(err, "init llm service failed") + log.Warn().Err(err).Msg("init llm service failed, Wings service will be used") + } else { + log.Info().Msg("LLM service initialized") } } else { - log.Warn().Msg("no LLM service config provided") + log.Info().Msg("no LLM service config provided, using Wings service only") } // Register uixt 
MCP tools to LLM service if it exists + mcpTools := driverExt.client.Server.ListTools() + einoTools := ai.ConvertMCPToolsToEinoToolInfos(mcpTools, "uixt") + if err = driverExt.WingsService.RegisterTools(einoTools); err != nil { + log.Debug().Err(err).Msg("Wings service ignoring tool registration (expected)") + } + if driverExt.LLMService != nil { - mcpTools := driverExt.client.Server.ListTools() - einoTools := ai.ConvertMCPToolsToEinoToolInfos(mcpTools, "uixt") - if err := driverExt.LLMService.RegisterTools(einoTools); err != nil { - log.Warn().Err(err).Msg("failed to register uixt tools") + if err = driverExt.LLMService.RegisterTools(einoTools); err != nil { + log.Warn().Err(err).Msg("failed to register uixt tools to LLM service") } } @@ -59,8 +71,9 @@ func NewXTDriver(driver IDriver, opts ...option.AIServiceOption) (*XTDriver, err // XTDriver = IDriver + AI type XTDriver struct { IDriver - CVService ai.ICVService // OCR/CV - LLMService ai.ILLMService // LLM + CVService ai.ICVService // OCR/CV + LLMService ai.ILLMService // LLM (fallback service) + WingsService ai.ILLMService // Wings API service (priority service) services *option.AIServiceOptions // AI services options client *MCPClient4XTDriver // MCP Client for built-in uixt server