diff --git a/pkg/uixt/driver_session.go b/pkg/uixt/driver_session.go index 50497ccb..b4d1ece0 100644 --- a/pkg/uixt/driver_session.go +++ b/pkg/uixt/driver_session.go @@ -110,26 +110,65 @@ func (s *DriverSession) History() []*DriverRequests { } func (s *DriverSession) concatURL(elem ...string) (string, error) { + if len(elem) == 0 { + if s.baseUrl == "" { + return "", fmt.Errorf("base URL is empty") + } + return s.baseUrl, nil + } + + // 处理完整 URL + if strings.HasPrefix(elem[0], "http://") || strings.HasPrefix(elem[0], "https://") { + u, err := url.Parse(elem[0]) + if err != nil { + return "", fmt.Errorf("failed to parse URL: %w", err) + } + if len(elem) > 1 { + u.Path = path.Join(u.Path, path.Join(elem[1:]...)) + } + return u.String(), nil + } + + // 处理相对路径 if s.baseUrl == "" { return "", fmt.Errorf("base URL is empty") } - u, err := url.Parse(s.baseUrl) if err != nil { return "", fmt.Errorf("failed to parse base URL: %w", err) } - // 分离路径和查询参数 - lastElem := elem[len(elem)-1] - parts := strings.SplitN(lastElem, "?", 2) - elem[len(elem)-1] = parts[0] + // 保存原始查询参数 + baseQuery := u.Query() - // 合并基础路径 - u.Path = path.Join(append([]string{u.Path}, elem...)...) + // 处理路径和查询参数 + var paths []string + for i, e := range elem { + if i == len(elem)-1 { + // 处理最后一个元素的查询参数 + parts := strings.SplitN(e, "?", 2) + paths = append(paths, parts[0]) + if len(parts) > 1 { + newQuery, err := url.ParseQuery(parts[1]) + if err != nil { + return "", fmt.Errorf("failed to parse query params: %w", err) + } + // 合并查询参数 + for k, v := range newQuery { + baseQuery[k] = v + } + } + } else { + paths = append(paths, e) + } + } - // 如果有查询参数,添加到 URL - if len(parts) > 1 { - u.RawQuery = parts[1] + // 合并路径 + u.Path = path.Join(append([]string{u.Path}, paths...)...) + + // 设置合并后的查询参数 + if len(baseQuery) > 0 { + u.RawQuery = baseQuery.Encode() } return u.String(), nil diff --git a/pkg/uixt/driver_session_test.go b/pkg/uixt/driver_session_test.go index ea0582e3..a846cb6d 100644 --- a/pkg/uixt/driver_session_test.go +++ b/pkg/uixt/driver_session_test.go @@ -6,6 +6,99 @@ import ( "github.com/stretchr/testify/assert" ) +func TestDriverSession_concatURL(t *testing.T) { + tests := []struct { + name string + baseUrl string + elem []string + want string + wantErr bool + errMsg string + }{ + { + name: "empty elements with empty base url", + baseUrl: "", + elem: []string{}, + wantErr: true, + errMsg: "base URL is empty", + }, + { + name: "empty elements with valid base url", + baseUrl: "http://localhost:8080", + elem: []string{}, + want: "http://localhost:8080", + }, + { + name: "absolute url in first element", + baseUrl: "http://localhost:8080", + elem: []string{"https://example.com/api", "users"}, + want: "https://example.com/api/users", + }, + { + name: "invalid absolute url", + baseUrl: "http://localhost:8080", + elem: []string{"http://[invalid-url", "users"}, + wantErr: true, + errMsg: "failed to parse URL", + }, + { + name: "relative path with empty base url", + baseUrl: "", + elem: []string{"api", "users"}, + wantErr: true, + errMsg: "base URL is empty", + }, + { + name: "relative path with invalid base url", + baseUrl: "http://[invalid-url", + elem: []string{"api", "users"}, + wantErr: true, + errMsg: "failed to parse base URL", + }, + { + name: "relative path with query params", + baseUrl: "http://localhost:8080", + elem: []string{"api", "users?id=1&name=test"}, + want: "http://localhost:8080/api/users?id=1&name=test", + }, + { + name: "base url with query params", + baseUrl: "http://localhost:8080?token=123", + elem: []string{"api", "users?id=1"}, + want: "http://localhost:8080/api/users?id=1&token=123", + }, + { + name: "invalid query params", + baseUrl: "http://localhost:8080", + elem: []string{"api", "users?id=%invalid"}, + wantErr: true, + errMsg: "failed to parse query params", + }, + { + name: "multiple path segments", + baseUrl: "http://localhost:8080", + elem: []string{"api", "v1", "users", "profile"}, + want: "http://localhost:8080/api/v1/users/profile", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &DriverSession{baseUrl: tt.baseUrl} + got, err := s.concatURL(tt.elem...) + + if tt.wantErr { + assert.Error(t, err) + assert.Contains(t, err.Error(), tt.errMsg) + return + } + + assert.NoError(t, err) + assert.Equal(t, tt.want, got) + }) + } +} + func TestDriverSession(t *testing.T) { session := NewDriverSession() session.SetBaseURL("https://postman-echo.com") @@ -23,4 +116,8 @@ func TestDriverSession(t *testing.T) { session.Reset() driverRequests = session.History() assert.Equal(t, 0, len(driverRequests)) + + resp, err = session.GET("https://postman-echo.com/get") + assert.Nil(t, err) + t.Log(resp) }