diff --git a/frontend/src/components/TableDesigner.tsx b/frontend/src/components/TableDesigner.tsx index 4ca8c50..941f99b 100644 --- a/frontend/src/components/TableDesigner.tsx +++ b/frontend/src/components/TableDesigner.tsx @@ -11,6 +11,7 @@ import { DBGetColumns, DBGetIndexes, DBQuery, DBGetForeignKeys, DBGetTriggers, D import { hasIndexFormChanged, normalizeIndexFormFromRow, shouldRestoreOriginalIndex, toggleIndexSelection as getNextIndexSelection, type IndexDisplaySnapshot } from './tableDesignerIndexUtils'; import { buildIndexCreateSqlPreview } from './tableDesignerIndexSql'; import { buildAlterTablePreviewSql, buildCreateTablePreviewSql, hasAlterTableDraftChanges, type StarRocksCreateTableOptions, type StarRocksDistributionType, type StarRocksKeyModel, type StarRocksTableKind } from './tableDesignerSchemaSql'; +import { normalizeSchemaStatementForExecution, parseTableCommentFromDDL, splitSchemaExecutionStatements } from './tableDesignerExecutionSql'; import TableDesignerSqlPreview from './TableDesignerSqlPreview'; import { buildRpcConnectionConfig } from '../utils/connectionRpcConfig'; import { noAutoCapInputProps } from '../utils/inputAutoCap'; @@ -832,8 +833,7 @@ const TableDesigner: React.FC<{ tab: TabData }> = ({ tab }) => { if (ddlRes && ddlRes.success) { const ddlText = String(ddlRes.data || ''); setDdl(ddlText); - const commentMatch = ddlText.replace(/\r?\n/g, ' ').match(/COMMENT\s*=\s*'((?:\\'|''|[^'])*)'/i); - const parsedTableComment = commentMatch ? commentMatch[1].replace(/\\'/g, "'").replace(/''/g, "'") : ''; + const parsedTableComment = parseTableCommentFromDDL(ddlText); setTableComment(parsedTableComment); if (!isTableCommentModalOpen) { setTableCommentDraft(parsedTableComment); @@ -1617,10 +1617,10 @@ ${selectedTrigger.statement}`; useSSH: conn.config.useSSH || false, ssh: conn.config.ssh || { host: "", port: 22, user: "", password: "", keyPath: "" } }; - const statements = sqlText.split(/;\s*\n/).map(s => s.trim()).filter(Boolean); + const dbType = resolveTableInfo().dbType; + const statements = splitSchemaExecutionStatements(sqlText); for (let i = 0; i < statements.length; i++) { - let stmt = statements[i]; - if (!stmt.endsWith(';')) stmt += ';'; + const stmt = normalizeSchemaStatementForExecution(statements[i], dbType); const res = await DBQuery(buildRpcConnectionConfig(config) as any, tab.dbName || '', stmt); if (!res.success) { const prefix = statements.length > 1 ? `第 ${i + 1}/${statements.length} 条语句执行失败: ` : '执行失败: '; diff --git a/frontend/src/components/tableDesignerExecutionSql.test.ts b/frontend/src/components/tableDesignerExecutionSql.test.ts new file mode 100644 index 0000000..6c5b6d3 --- /dev/null +++ b/frontend/src/components/tableDesignerExecutionSql.test.ts @@ -0,0 +1,33 @@ +import { describe, expect, it } from 'vitest'; + +import { + normalizeSchemaStatementForExecution, + parseTableCommentFromDDL, + splitSchemaExecutionStatements, +} from './tableDesignerExecutionSql'; + +describe('tableDesignerExecutionSql', () => { + it('strips trailing semicolons before executing oracle schema statements', () => { + expect( + normalizeSchemaStatementForExecution(`COMMENT ON COLUMN "H2"."D_YS_MEMCARD_CX"."ID" IS 'ID';`, 'oracle'), + ).toBe(`COMMENT ON COLUMN "H2"."D_YS_MEMCARD_CX"."ID" IS 'ID'`); + }); + + it('keeps trailing semicolons for non-oracle schema statements', () => { + expect(normalizeSchemaStatementForExecution('ALTER TABLE `users` ADD COLUMN `age` int', 'mysql')) + .toBe('ALTER TABLE `users` ADD COLUMN `age` int;'); + }); + + it('splits generated schema SQL into individual statements', () => { + expect(splitSchemaExecutionStatements('ALTER TABLE users ADD age int;\nCOMMENT ON COLUMN users.age IS \'年龄\';')) + .toEqual(['ALTER TABLE users ADD age int', "COMMENT ON COLUMN users.age IS '年龄';"]); + }); + + it('parses mysql and oracle table comments from DDL', () => { + expect(parseTableCommentFromDDL("CREATE TABLE `users` (`id` int) COMMENT='用户\\'表';")) + .toBe("用户'表"); + expect(parseTableCommentFromDDL(`CREATE TABLE "HR"."EMPLOYEES" ("ID" NUMBER); +COMMENT ON TABLE "HR"."EMPLOYEES" IS '员工''表'; +COMMENT ON COLUMN "HR"."EMPLOYEES"."ID" IS '主键';`)).toBe("员工'表"); + }); +}); diff --git a/frontend/src/components/tableDesignerExecutionSql.ts b/frontend/src/components/tableDesignerExecutionSql.ts new file mode 100644 index 0000000..8926c11 --- /dev/null +++ b/frontend/src/components/tableDesignerExecutionSql.ts @@ -0,0 +1,37 @@ +import { isOracleLikeDialect } from '../utils/sqlDialect'; + +export const splitSchemaExecutionStatements = (sqlText: string): string[] => ( + String(sqlText || '') + .replace(/;/g, ';') + .split(/;\s*\n/) + .map(statement => statement.trim()) + .filter(Boolean) +); + +export const normalizeSchemaStatementForExecution = (statement: string, dbType: string): string => { + const trimmed = String(statement || '').trim(); + if (!trimmed) return ''; + if (isOracleLikeDialect(dbType)) { + return trimmed.replace(/;+\s*$/, '').trim(); + } + return trimmed.endsWith(';') ? trimmed : `${trimmed};`; +}; + +const unescapeSqlComment = (text: string, mysqlBackslashEscapes = false): string => { + const unescaped = text.replace(/''/g, "'"); + return mysqlBackslashEscapes ? unescaped.replace(/\\'/g, "'") : unescaped; +}; + +export const parseTableCommentFromDDL = (ddlText: string): string => { + const ddl = String(ddlText || '').replace(/\r?\n/g, ' '); + const mysqlMatch = ddl.match(/COMMENT\s*=\s*'((?:\\'|''|[^'])*)'/i); + if (mysqlMatch) { + return unescapeSqlComment(mysqlMatch[1], true); + } + + const commentOnTableMatch = ddl.match(/\bCOMMENT\s+ON\s+TABLE\s+.+?\s+IS\s+(NULL|'((?:''|[^'])*)')/i); + if (!commentOnTableMatch || commentOnTableMatch[1].toUpperCase() === 'NULL') { + return ''; + } + return unescapeSqlComment(commentOnTableMatch[2] || ''); +}; diff --git a/internal/db/oracle_applychanges_test.go b/internal/db/oracle_applychanges_test.go index f88ba5a..23891bc 100644 --- a/internal/db/oracle_applychanges_test.go +++ b/internal/db/oracle_applychanges_test.go @@ -27,6 +27,7 @@ type oracleRecordingState struct { mu sync.Mutex execQueries []string execArgs [][]driver.NamedValue + queries []string rowsAffected int64 queryResults map[string]oracleRecordingQueryResult queryError error @@ -54,6 +55,12 @@ func (s *oracleRecordingState) snapshotExecArgs() [][]driver.NamedValue { return result } +func (s *oracleRecordingState) snapshotQueries() []string { + s.mu.Lock() + defer s.mu.Unlock() + return append([]string(nil), s.queries...) +} + type oracleRecordingDriver struct{} func (oracleRecordingDriver) Open(name string) (driver.Conn, error) { @@ -88,6 +95,7 @@ func (c *oracleRecordingConn) ExecContext(_ context.Context, query string, args func (c *oracleRecordingConn) QueryContext(_ context.Context, query string, _ []driver.NamedValue) (driver.Rows, error) { c.state.mu.Lock() + c.state.queries = append(c.state.queries, query) if err := c.state.queryError; err != nil { c.state.mu.Unlock() return nil, err @@ -103,10 +111,10 @@ func (c *oracleRecordingConn) QueryContext(_ context.Context, query string, _ [] if strings.Contains(strings.ToLower(query), "tab_columns") { return &oracleRecordingRows{ - columns: []string{"COLUMN_NAME", "DATA_TYPE", "NULLABLE", "DATA_DEFAULT"}, + columns: []string{"COLUMN_NAME", "DATA_TYPE", "NULLABLE", "DATA_DEFAULT", "COLUMN_KEY", "COMMENT"}, rows: [][]driver.Value{ - {"UPDATED_AT", "TIMESTAMP", "YES", nil}, - {"CREATED_AT", "DATE", "NO", nil}, + {"UPDATED_AT", "TIMESTAMP", "YES", nil, "", "更新时间"}, + {"CREATED_AT", "DATE", "NO", nil, "", nil}, }, }, nil } diff --git a/internal/db/oracle_get_tables_test.go b/internal/db/oracle_get_tables_test.go index c971bd0..2db3b61 100644 --- a/internal/db/oracle_get_tables_test.go +++ b/internal/db/oracle_get_tables_test.go @@ -3,6 +3,7 @@ package db import ( "database/sql/driver" "reflect" + "strings" "testing" ) @@ -82,3 +83,73 @@ func TestOracleGetTablesSkipsRowsWithNullTableName(t *testing.T) { t.Fatalf("NULL TABLE_NAME 应被跳过,期望 %v,实际 %v", want, tables) } } + +func TestOracleGetColumnsIncludesColumnComments(t *testing.T) { + t.Parallel() + + dbConn, state := openOracleRecordingDB(t) + oracleDB := &OracleDB{conn: dbConn} + columns, err := oracleDB.GetColumns("MYCIMLED", "EDC_LOG") + if err != nil { + t.Fatalf("GetColumns 返回错误: %v", err) + } + if len(columns) == 0 { + t.Fatalf("expected columns") + } + if columns[0].Name != "UPDATED_AT" || columns[0].Comment != "更新时间" { + t.Fatalf("expected first column comment from Oracle metadata, got %#v", columns[0]) + } + + queries := state.snapshotQueries() + if len(queries) == 0 || !strings.Contains(queries[0], "all_col_comments") { + t.Fatalf("expected GetColumns to join all_col_comments, queries=%v", queries) + } +} + +func TestOracleGetCreateStatementAppendsTableAndColumnComments(t *testing.T) { + t.Parallel() + + dbConn, state := openOracleRecordingDB(t) + state.mu.Lock() + state.queryResults[`SELECT DBMS_METADATA.GET_DDL('TABLE', 'EDC_LOG', 'MYCIMLED') as ddl FROM DUAL`] = oracleRecordingQueryResult{ + columns: []string{"DDL"}, + rows: [][]driver.Value{ + {`CREATE TABLE "MYCIMLED"."EDC_LOG" ( + "ID" NUMBER NOT NULL +)`}, + }, + } + state.queryResults[`SELECT comments AS "COMMENT" FROM all_tab_comments WHERE owner = 'MYCIMLED' AND table_name = 'EDC_LOG' AND comments IS NOT NULL`] = oracleRecordingQueryResult{ + columns: []string{"COMMENT"}, + rows: [][]driver.Value{ + {"日志表"}, + }, + } + state.queryResults[`SELECT c.column_name AS "COLUMN_NAME", cc.comments AS "COMMENT" +FROM all_tab_columns c +JOIN all_col_comments cc + ON cc.owner = c.owner AND cc.table_name = c.table_name AND cc.column_name = c.column_name +WHERE c.owner = 'MYCIMLED' AND c.table_name = 'EDC_LOG' AND cc.comments IS NOT NULL +ORDER BY c.column_id`] = oracleRecordingQueryResult{ + columns: []string{"COLUMN_NAME", "COMMENT"}, + rows: [][]driver.Value{ + {"ID", "主键's"}, + }, + } + state.mu.Unlock() + + oracleDB := &OracleDB{conn: dbConn} + ddl, err := oracleDB.GetCreateStatement("MYCIMLED", "EDC_LOG") + if err != nil { + t.Fatalf("GetCreateStatement 返回错误: %v", err) + } + for _, want := range []string{ + `CREATE TABLE "MYCIMLED"."EDC_LOG"`, + `COMMENT ON TABLE "MYCIMLED"."EDC_LOG" IS '日志表';`, + `COMMENT ON COLUMN "MYCIMLED"."EDC_LOG"."ID" IS '主键''s';`, + } { + if !strings.Contains(ddl, want) { + t.Fatalf("expected DDL to contain %q, got: %s", want, ddl) + } + } +} diff --git a/internal/db/oracle_impl.go b/internal/db/oracle_impl.go index 52763ef..d1ec431 100644 --- a/internal/db/oracle_impl.go +++ b/internal/db/oracle_impl.go @@ -272,7 +272,7 @@ func (o *OracleDB) GetTables(dbName string) ([]string, error) { // 列别名用双引号包裹强制大写,避免不同驱动版本返回不一致 case 导致 row map 取值失败 var query string if dbName != "" { - query = fmt.Sprintf(`SELECT owner AS "OWNER", table_name AS "TABLE_NAME" FROM all_tables WHERE owner = '%s' ORDER BY table_name`, strings.ToUpper(dbName)) + query = fmt.Sprintf(`SELECT owner AS "OWNER", table_name AS "TABLE_NAME" FROM all_tables WHERE owner = '%s' ORDER BY table_name`, escapeOracleMetadataLiteral(dbName)) } else { query = `SELECT USER AS "OWNER", table_name AS "TABLE_NAME" FROM user_tables ORDER BY table_name` } @@ -300,11 +300,13 @@ func (o *OracleDB) GetTables(dbName string) ([]string, error) { func (o *OracleDB) GetCreateStatement(dbName, tableName string) (string, error) { // Oracle provides DBMS_METADATA.GET_DDL // Note: LONG type might be tricky, but basic string scan should work for smaller DDLs + metadataTableName := escapeOracleMetadataLiteral(tableName) + metadataSchemaName := escapeOracleMetadataLiteral(dbName) query := fmt.Sprintf("SELECT DBMS_METADATA.GET_DDL('TABLE', '%s', '%s') as ddl FROM DUAL", - strings.ToUpper(tableName), strings.ToUpper(dbName)) + metadataTableName, metadataSchemaName) if dbName == "" { - query = fmt.Sprintf("SELECT DBMS_METADATA.GET_DDL('TABLE', '%s') as ddl FROM DUAL", strings.ToUpper(tableName)) + query = fmt.Sprintf("SELECT DBMS_METADATA.GET_DDL('TABLE', '%s') as ddl FROM DUAL", metadataTableName) } data, _, err := o.Query(query) @@ -314,16 +316,21 @@ func (o *OracleDB) GetCreateStatement(dbName, tableName string) (string, error) if len(data) > 0 { if val, ok := data[0]["DDL"]; ok { - return fmt.Sprintf("%v", val), nil + return o.appendOracleCommentDDL(fmt.Sprintf("%v", val), dbName, tableName), nil } } return "", fmt.Errorf("未找到建表语句") } func (o *OracleDB) GetColumns(dbName, tableName string) ([]connection.ColumnDefinition, error) { + metadataTableName := escapeOracleMetadataLiteral(tableName) + metadataSchemaName := escapeOracleMetadataLiteral(dbName) query := fmt.Sprintf(`SELECT c.column_name, c.data_type, c.nullable, c.data_default, - CASE WHEN pk.column_name IS NOT NULL THEN 'PRI' ELSE '' END AS column_key + CASE WHEN pk.column_name IS NOT NULL THEN 'PRI' ELSE '' END AS column_key, + cc.comments AS comment FROM all_tab_columns c + LEFT JOIN all_col_comments cc + ON cc.owner = c.owner AND cc.table_name = c.table_name AND cc.column_name = c.column_name LEFT JOIN ( SELECT cols.owner, cols.table_name, cols.column_name FROM all_constraints cons @@ -332,12 +339,15 @@ func (o *OracleDB) GetColumns(dbName, tableName string) ([]connection.ColumnDefi WHERE cons.constraint_type = 'P' ) pk ON c.owner = pk.owner AND c.table_name = pk.table_name AND c.column_name = pk.column_name WHERE c.owner = '%s' AND c.table_name = '%s' - ORDER BY c.column_id`, strings.ToUpper(dbName), strings.ToUpper(tableName)) + ORDER BY c.column_id`, metadataSchemaName, metadataTableName) if dbName == "" { query = fmt.Sprintf(`SELECT c.column_name, c.data_type, c.nullable, c.data_default, - CASE WHEN pk.column_name IS NOT NULL THEN 'PRI' ELSE '' END AS column_key + CASE WHEN pk.column_name IS NOT NULL THEN 'PRI' ELSE '' END AS column_key, + cc.comments AS comment FROM user_tab_columns c + LEFT JOIN user_col_comments cc + ON cc.table_name = c.table_name AND cc.column_name = c.column_name LEFT JOIN ( SELECT cols.table_name, cols.column_name FROM user_constraints cons @@ -345,7 +355,7 @@ func (o *OracleDB) GetColumns(dbName, tableName string) ([]connection.ColumnDefi WHERE cons.constraint_type = 'P' ) pk ON c.table_name = pk.table_name AND c.column_name = pk.column_name WHERE c.table_name = '%s' - ORDER BY c.column_id`, strings.ToUpper(tableName)) + ORDER BY c.column_id`, metadataTableName) } data, _, err := o.Query(query) @@ -356,14 +366,15 @@ func (o *OracleDB) GetColumns(dbName, tableName string) ([]connection.ColumnDefi var columns []connection.ColumnDefinition for _, row := range data { col := connection.ColumnDefinition{ - Name: fmt.Sprintf("%v", row["COLUMN_NAME"]), - Type: fmt.Sprintf("%v", row["DATA_TYPE"]), - Nullable: fmt.Sprintf("%v", row["NULLABLE"]), - Key: fmt.Sprintf("%v", row["COLUMN_KEY"]), + Name: oracleRowString(row, "COLUMN_NAME"), + Type: oracleRowString(row, "DATA_TYPE"), + Nullable: oracleRowString(row, "NULLABLE"), + Key: oracleRowString(row, "COLUMN_KEY"), + Comment: oracleRowString(row, "COMMENT"), } - if row["DATA_DEFAULT"] != nil { - d := fmt.Sprintf("%v", row["DATA_DEFAULT"]) + if defaultValue := oracleRowValue(row, "DATA_DEFAULT"); defaultValue != nil { + d := fmt.Sprintf("%v", defaultValue) col.Default = &d } @@ -372,6 +383,134 @@ func (o *OracleDB) GetColumns(dbName, tableName string) ([]connection.ColumnDefi return columns, nil } +func oracleRowValue(row map[string]interface{}, names ...string) interface{} { + for _, name := range names { + if value, ok := row[name]; ok { + return value + } + for key, value := range row { + if strings.EqualFold(key, name) { + return value + } + } + } + return nil +} + +func oracleRowString(row map[string]interface{}, names ...string) string { + value := oracleRowValue(row, names...) + if value == nil { + return "" + } + return fmt.Sprintf("%v", value) +} + +func (o *OracleDB) appendOracleCommentDDL(baseDDL string, dbName string, tableName string) string { + table := strings.ToUpper(strings.TrimSpace(tableName)) + if strings.TrimSpace(baseDDL) == "" || table == "" { + return baseDDL + } + + schema := strings.ToUpper(strings.TrimSpace(dbName)) + tableRef := quoteOracleDDLIdentifier(table) + if schema != "" { + tableRef = quoteOracleDDLIdentifier(schema) + "." + tableRef + } + existingDDLUpper := strings.ToUpper(baseDDL) + commentLines := make([]string, 0, 4) + + if tableComment := strings.TrimSpace(o.fetchOracleTableComment(schema, table)); tableComment != "" { + marker := "COMMENT ON TABLE " + strings.ToUpper(tableRef) + if !strings.Contains(existingDDLUpper, marker) { + commentLines = append(commentLines, fmt.Sprintf("COMMENT ON TABLE %s IS '%s';", tableRef, escapeOracleCommentLiteral(tableComment))) + } + } + + for _, colComment := range o.fetchOracleColumnComments(schema, table) { + columnName := strings.TrimSpace(colComment.columnName) + comment := strings.TrimSpace(colComment.comment) + if columnName == "" || comment == "" { + continue + } + columnRef := fmt.Sprintf("%s.%s", tableRef, quoteOracleDDLIdentifier(columnName)) + marker := "COMMENT ON COLUMN " + strings.ToUpper(columnRef) + if strings.Contains(existingDDLUpper, marker) { + continue + } + commentLines = append(commentLines, fmt.Sprintf("COMMENT ON COLUMN %s IS '%s';", columnRef, escapeOracleCommentLiteral(comment))) + } + + if len(commentLines) == 0 { + return baseDDL + } + return strings.TrimRight(baseDDL, " \t\r\n") + "\n" + strings.Join(commentLines, "\n") +} + +func (o *OracleDB) fetchOracleTableComment(schema string, table string) string { + escapedTable := escapeOracleMetadataLiteral(table) + var query string + if strings.TrimSpace(schema) != "" { + query = fmt.Sprintf(`SELECT comments AS "COMMENT" FROM all_tab_comments WHERE owner = '%s' AND table_name = '%s' AND comments IS NOT NULL`, escapeOracleMetadataLiteral(schema), escapedTable) + } else { + query = fmt.Sprintf(`SELECT comments AS "COMMENT" FROM user_tab_comments WHERE table_name = '%s' AND comments IS NOT NULL`, escapedTable) + } + data, _, err := o.Query(query) + if err != nil || len(data) == 0 { + return "" + } + return oracleRowString(data[0], "COMMENT", "COMMENTS") +} + +type oracleColumnComment struct { + columnName string + comment string +} + +func (o *OracleDB) fetchOracleColumnComments(schema string, table string) []oracleColumnComment { + escapedTable := escapeOracleMetadataLiteral(table) + var query string + if strings.TrimSpace(schema) != "" { + query = fmt.Sprintf(`SELECT c.column_name AS "COLUMN_NAME", cc.comments AS "COMMENT" +FROM all_tab_columns c +JOIN all_col_comments cc + ON cc.owner = c.owner AND cc.table_name = c.table_name AND cc.column_name = c.column_name +WHERE c.owner = '%s' AND c.table_name = '%s' AND cc.comments IS NOT NULL +ORDER BY c.column_id`, escapeOracleMetadataLiteral(schema), escapedTable) + } else { + query = fmt.Sprintf(`SELECT c.column_name AS "COLUMN_NAME", cc.comments AS "COMMENT" +FROM user_tab_columns c +JOIN user_col_comments cc + ON cc.table_name = c.table_name AND cc.column_name = c.column_name +WHERE c.table_name = '%s' AND cc.comments IS NOT NULL +ORDER BY c.column_id`, escapedTable) + } + + data, _, err := o.Query(query) + if err != nil { + return nil + } + comments := make([]oracleColumnComment, 0, len(data)) + for _, row := range data { + comments = append(comments, oracleColumnComment{ + columnName: oracleRowString(row, "COLUMN_NAME"), + comment: oracleRowString(row, "COMMENT", "COMMENTS"), + }) + } + return comments +} + +func quoteOracleDDLIdentifier(ident string) string { + return `"` + strings.ReplaceAll(strings.TrimSpace(ident), `"`, `""`) + `"` +} + +func escapeOracleCommentLiteral(text string) string { + return strings.ReplaceAll(text, "'", "''") +} + +func escapeOracleMetadataLiteral(text string) string { + return strings.ReplaceAll(strings.ToUpper(strings.TrimSpace(text)), "'", "''") +} + func (o *OracleDB) GetIndexes(dbName, tableName string) ([]connection.IndexDefinition, error) { esc := func(s string) string { return strings.ReplaceAll(strings.ToUpper(strings.TrimSpace(s)), "'", "''") } table := esc(tableName)