diff --git a/frontend/src/components/ConnectionModal.tsx b/frontend/src/components/ConnectionModal.tsx index 7f9efa3..c55c0b2 100644 --- a/frontend/src/components/ConnectionModal.tsx +++ b/frontend/src/components/ConnectionModal.tsx @@ -1259,24 +1259,30 @@ const ConnectionModal: React.FC<{ ? await RedisConnect(config as any) : await TestConnection(config as any); - if (res.success) { - setTestResult({ type: 'success', message: res.message }); - if (isRedisType) { - setRedisDbList(Array.from({ length: 16 }, (_, i) => i)); - } else { - // Other databases: fetch database list - const dbRes = await DBGetDatabases(config as any); - if (dbRes.success) { - const dbRows = Array.isArray(dbRes.data) ? dbRes.data : []; - const dbs = dbRows - .map((row: any) => row?.Database || row?.database) - .filter((name: any) => typeof name === 'string' && name.trim() !== ''); - setDbList(dbs); + if (res.success) { + setTestResult({ type: 'success', message: res.message }); + if (isRedisType) { + setRedisDbList(Array.from({ length: 16 }, (_, i) => i)); } else { - setDbList([]); + // Other databases: fetch database list + const dbRes = await DBGetDatabases(config as any); + if (dbRes.success) { + const dbRows = Array.isArray(dbRes.data) ? dbRes.data : []; + const dbs = dbRows + .map((row: any) => row?.Database || row?.database) + .filter((name: any) => typeof name === 'string' && name.trim() !== ''); + setDbList(dbs); + if (dbs.length === 0) { + message.warning(values.type === 'dameng' + ? '连接成功,但未获取到可见 schema;请检查当前账号权限或默认 schema 配置' + : '连接成功,但未获取到可见数据库列表'); + } + } else { + setDbList([]); + message.warning(`连接成功,但获取数据库列表失败:${dbRes.message || '未知错误'}`); + } } - } - } else { + } else { const failMessage = buildTestFailureMessage( res?.message, '连接被拒绝或参数无效,请检查后重试' diff --git a/internal/db/dameng_impl.go b/internal/db/dameng_impl.go index 5cceb0a..1cf27e6 100644 --- a/internal/db/dameng_impl.go +++ b/internal/db/dameng_impl.go @@ -8,7 +8,6 @@ import ( "fmt" "net" "net/url" - "sort" "strconv" "strings" "time" @@ -205,80 +204,9 @@ func (d *DamengDB) Exec(query string) (int64, error) { } func (d *DamengDB) GetDatabases() ([]string, error) { - // 达梦将「用户/模式」作为数据库列表来源,不同权限下可见口径不同。 - // 这里采用多查询口径聚合,避免仅依赖单一视图导致“少库”。 - queries := []string{ - "SELECT USERNAME AS DATABASE_NAME FROM SYS.DBA_USERS ORDER BY USERNAME", - "SELECT USERNAME AS DATABASE_NAME FROM DBA_USERS ORDER BY USERNAME", - "SELECT USERNAME AS DATABASE_NAME FROM ALL_USERS ORDER BY USERNAME", - "SELECT USERNAME AS DATABASE_NAME FROM USER_USERS", - "SELECT DISTINCT OWNER AS DATABASE_NAME FROM ALL_TABLES ORDER BY OWNER", - } - - seen := make(map[string]struct{}) - dbs := make([]string, 0, 64) - var lastErr error - success := false - - for _, q := range queries { - data, _, err := d.Query(q) - if err != nil { - lastErr = err - continue - } - success = true - for _, row := range data { - name := getDamengRowString(row, "DATABASE_NAME", "USERNAME", "OWNER", "SCHEMA_NAME") - if name == "" { - // 回退到第一列,兼容驱动返回列名差异。 - for _, v := range row { - text := strings.TrimSpace(fmt.Sprintf("%v", v)) - if text == "" || strings.EqualFold(text, "") { - continue - } - name = text - break - } - } - if name == "" { - continue - } - key := strings.ToUpper(name) - if _, ok := seen[key]; ok { - continue - } - seen[key] = struct{}{} - dbs = append(dbs, name) - } - } - - if !success && lastErr != nil { - return nil, lastErr - } - - sort.Slice(dbs, func(i, j int) bool { - return strings.ToUpper(dbs[i]) < strings.ToUpper(dbs[j]) - }) - return dbs, nil -} - -func getDamengRowString(row map[string]interface{}, keys ...string) string { - if len(row) == 0 { - return "" - } - for _, key := range keys { - for k, v := range row { - if !strings.EqualFold(strings.TrimSpace(k), strings.TrimSpace(key)) { - continue - } - text := strings.TrimSpace(fmt.Sprintf("%v", v)) - if text == "" || strings.EqualFold(text, "") { - return "" - } - return text - } - } - return "" + // 达梦在本项目中将 schema/owner 作为“数据库”展示口径。 + // 先查当前 schema / 当前用户,再聚合可见用户与 owner,避免权限受限时返回空列表。 + return collectDamengDatabaseNames(d.Query) } func (d *DamengDB) GetTables(dbName string) ([]string, error) { diff --git a/internal/db/dameng_metadata.go b/internal/db/dameng_metadata.go new file mode 100644 index 0000000..c963da1 --- /dev/null +++ b/internal/db/dameng_metadata.go @@ -0,0 +1,91 @@ +package db + +import ( + "fmt" + "sort" + "strings" +) + +var damengDatabaseQueries = []string{ + "SELECT SYS_CONTEXT('USERENV', 'CURRENT_SCHEMA') AS DATABASE_NAME FROM DUAL", + "SELECT SYS_CONTEXT('USERENV', 'CURRENT_USER') AS DATABASE_NAME FROM DUAL", + "SELECT USERNAME AS DATABASE_NAME FROM USER_USERS", + "SELECT USERNAME AS DATABASE_NAME FROM ALL_USERS ORDER BY USERNAME", + "SELECT USERNAME AS DATABASE_NAME FROM DBA_USERS ORDER BY USERNAME", + "SELECT USERNAME AS DATABASE_NAME FROM SYS.DBA_USERS ORDER BY USERNAME", + "SELECT DISTINCT OWNER AS DATABASE_NAME FROM ALL_OBJECTS ORDER BY OWNER", + "SELECT DISTINCT OWNER AS DATABASE_NAME FROM ALL_TABLES ORDER BY OWNER", +} + +type damengQueryFunc func(query string) ([]map[string]interface{}, []string, error) + +func collectDamengDatabaseNames(query damengQueryFunc) ([]string, error) { + seen := make(map[string]struct{}) + dbs := make([]string, 0, 64) + var lastErr error + + for _, q := range damengDatabaseQueries { + data, _, err := query(q) + if err != nil { + lastErr = err + continue + } + for _, row := range data { + name := getDamengRowString(row, + "DATABASE_NAME", + "USERNAME", + "OWNER", + "SCHEMA_NAME", + "CURRENT_SCHEMA", + "CURRENT_USER", + ) + if name == "" { + for _, v := range row { + text := strings.TrimSpace(fmt.Sprintf("%v", v)) + if text == "" || strings.EqualFold(text, "") { + continue + } + name = text + break + } + } + if name == "" { + continue + } + key := strings.ToUpper(name) + if _, ok := seen[key]; ok { + continue + } + seen[key] = struct{}{} + dbs = append(dbs, name) + } + } + + if len(dbs) == 0 && lastErr != nil { + return nil, lastErr + } + + sort.Slice(dbs, func(i, j int) bool { + return strings.ToUpper(dbs[i]) < strings.ToUpper(dbs[j]) + }) + return dbs, nil +} + +func getDamengRowString(row map[string]interface{}, keys ...string) string { + if len(row) == 0 { + return "" + } + for _, key := range keys { + for k, v := range row { + if !strings.EqualFold(strings.TrimSpace(k), strings.TrimSpace(key)) { + continue + } + text := strings.TrimSpace(fmt.Sprintf("%v", v)) + if text == "" || strings.EqualFold(text, "") { + return "" + } + return text + } + } + return "" +} diff --git a/internal/db/dameng_metadata_test.go b/internal/db/dameng_metadata_test.go new file mode 100644 index 0000000..5310679 --- /dev/null +++ b/internal/db/dameng_metadata_test.go @@ -0,0 +1,73 @@ +package db + +import ( + "errors" + "reflect" + "testing" +) + +func TestCollectDamengDatabaseNames_UsesCurrentSchemaFallback(t *testing.T) { + t.Parallel() + + got, err := collectDamengDatabaseNames(func(query string) ([]map[string]interface{}, []string, error) { + switch query { + case damengDatabaseQueries[0]: + return []map[string]interface{}{{"DATABASE_NAME": "APP_SCHEMA"}}, nil, nil + case damengDatabaseQueries[1]: + return []map[string]interface{}{{"DATABASE_NAME": "app_schema"}}, nil, nil + default: + return nil, nil, errors.New("permission denied") + } + }) + if err != nil { + t.Fatalf("collectDamengDatabaseNames 返回错误: %v", err) + } + + want := []string{"APP_SCHEMA"} + if !reflect.DeepEqual(got, want) { + t.Fatalf("unexpected database names, got=%v want=%v", got, want) + } +} + +func TestCollectDamengDatabaseNames_CollectsOwnersWhenVisible(t *testing.T) { + t.Parallel() + + got, err := collectDamengDatabaseNames(func(query string) ([]map[string]interface{}, []string, error) { + switch query { + case damengDatabaseQueries[0], damengDatabaseQueries[1], damengDatabaseQueries[2], damengDatabaseQueries[3], damengDatabaseQueries[4], damengDatabaseQueries[5]: + return []map[string]interface{}{}, nil, nil + case damengDatabaseQueries[6]: + return []map[string]interface{}{{"OWNER": "BIZ"}, {"OWNER": "audit"}}, nil, nil + case damengDatabaseQueries[7]: + return []map[string]interface{}{{"OWNER": "BIZ"}}, nil, nil + default: + return nil, nil, nil + } + }) + if err != nil { + t.Fatalf("collectDamengDatabaseNames 返回错误: %v", err) + } + + want := []string{"audit", "BIZ"} + if !reflect.DeepEqual(got, want) { + t.Fatalf("unexpected database names, got=%v want=%v", got, want) + } +} + +func TestCollectDamengDatabaseNames_ReturnsErrorWhenNoNameResolved(t *testing.T) { + t.Parallel() + + expectErr := errors.New("last query failed") + got, err := collectDamengDatabaseNames(func(query string) ([]map[string]interface{}, []string, error) { + if query == damengDatabaseQueries[len(damengDatabaseQueries)-1] { + return nil, nil, expectErr + } + return nil, nil, errors.New("permission denied") + }) + if err == nil { + t.Fatalf("期望返回错误,实际 got=%v", got) + } + if !errors.Is(err, expectErr) { + t.Fatalf("错误不符合预期: %v", err) + } +}