🐛 fix(oracle): 修复表结构注释读取与保存报错

- 补齐 Oracle 表字段注释元数据读取

- 在表结构 DDL 中追加表和字段注释信息

- 规范表设计器 Oracle DDL 执行前的分号处理

Refs #482
This commit is contained in:
Syngnat
2026-05-23 17:41:46 +08:00
parent b9c743d67e
commit 56b3112a07
6 changed files with 310 additions and 22 deletions

View File

@@ -11,6 +11,7 @@ import { DBGetColumns, DBGetIndexes, DBQuery, DBGetForeignKeys, DBGetTriggers, D
import { hasIndexFormChanged, normalizeIndexFormFromRow, shouldRestoreOriginalIndex, toggleIndexSelection as getNextIndexSelection, type IndexDisplaySnapshot } from './tableDesignerIndexUtils';
import { buildIndexCreateSqlPreview } from './tableDesignerIndexSql';
import { buildAlterTablePreviewSql, buildCreateTablePreviewSql, hasAlterTableDraftChanges, type StarRocksCreateTableOptions, type StarRocksDistributionType, type StarRocksKeyModel, type StarRocksTableKind } from './tableDesignerSchemaSql';
import { normalizeSchemaStatementForExecution, parseTableCommentFromDDL, splitSchemaExecutionStatements } from './tableDesignerExecutionSql';
import TableDesignerSqlPreview from './TableDesignerSqlPreview';
import { buildRpcConnectionConfig } from '../utils/connectionRpcConfig';
import { noAutoCapInputProps } from '../utils/inputAutoCap';
@@ -832,8 +833,7 @@ const TableDesigner: React.FC<{ tab: TabData }> = ({ tab }) => {
if (ddlRes && ddlRes.success) {
const ddlText = String(ddlRes.data || '');
setDdl(ddlText);
const commentMatch = ddlText.replace(/\r?\n/g, ' ').match(/COMMENT\s*=\s*'((?:\\'|''|[^'])*)'/i);
const parsedTableComment = commentMatch ? commentMatch[1].replace(/\\'/g, "'").replace(/''/g, "'") : '';
const parsedTableComment = parseTableCommentFromDDL(ddlText);
setTableComment(parsedTableComment);
if (!isTableCommentModalOpen) {
setTableCommentDraft(parsedTableComment);
@@ -1617,10 +1617,10 @@ ${selectedTrigger.statement}`;
useSSH: conn.config.useSSH || false,
ssh: conn.config.ssh || { host: "", port: 22, user: "", password: "", keyPath: "" }
};
const statements = sqlText.split(/;\s*\n/).map(s => s.trim()).filter(Boolean);
const dbType = resolveTableInfo().dbType;
const statements = splitSchemaExecutionStatements(sqlText);
for (let i = 0; i < statements.length; i++) {
let stmt = statements[i];
if (!stmt.endsWith(';')) stmt += ';';
const stmt = normalizeSchemaStatementForExecution(statements[i], dbType);
const res = await DBQuery(buildRpcConnectionConfig(config) as any, tab.dbName || '', stmt);
if (!res.success) {
const prefix = statements.length > 1 ? `${i + 1}/${statements.length} 条语句执行失败: ` : '执行失败: ';

View File

@@ -0,0 +1,33 @@
import { describe, expect, it } from 'vitest';
import {
normalizeSchemaStatementForExecution,
parseTableCommentFromDDL,
splitSchemaExecutionStatements,
} from './tableDesignerExecutionSql';
describe('tableDesignerExecutionSql', () => {
it('strips trailing semicolons before executing oracle schema statements', () => {
expect(
normalizeSchemaStatementForExecution(`COMMENT ON COLUMN "H2"."D_YS_MEMCARD_CX"."ID" IS 'ID';`, 'oracle'),
).toBe(`COMMENT ON COLUMN "H2"."D_YS_MEMCARD_CX"."ID" IS 'ID'`);
});
it('keeps trailing semicolons for non-oracle schema statements', () => {
expect(normalizeSchemaStatementForExecution('ALTER TABLE `users` ADD COLUMN `age` int', 'mysql'))
.toBe('ALTER TABLE `users` ADD COLUMN `age` int;');
});
it('splits generated schema SQL into individual statements', () => {
expect(splitSchemaExecutionStatements('ALTER TABLE users ADD age int;\nCOMMENT ON COLUMN users.age IS \'年龄\';'))
.toEqual(['ALTER TABLE users ADD age int', "COMMENT ON COLUMN users.age IS '年龄';"]);
});
it('parses mysql and oracle table comments from DDL', () => {
expect(parseTableCommentFromDDL("CREATE TABLE `users` (`id` int) COMMENT='用户\\'表';"))
.toBe("用户'表");
expect(parseTableCommentFromDDL(`CREATE TABLE "HR"."EMPLOYEES" ("ID" NUMBER);
COMMENT ON TABLE "HR"."EMPLOYEES" IS '员工''表';
COMMENT ON COLUMN "HR"."EMPLOYEES"."ID" IS '主键';`)).toBe("员工'表");
});
});

View File

@@ -0,0 +1,37 @@
import { isOracleLikeDialect } from '../utils/sqlDialect';
export const splitSchemaExecutionStatements = (sqlText: string): string[] => (
String(sqlText || '')
.replace(//g, ';')
.split(/;\s*\n/)
.map(statement => statement.trim())
.filter(Boolean)
);
export const normalizeSchemaStatementForExecution = (statement: string, dbType: string): string => {
const trimmed = String(statement || '').trim();
if (!trimmed) return '';
if (isOracleLikeDialect(dbType)) {
return trimmed.replace(/;+\s*$/, '').trim();
}
return trimmed.endsWith(';') ? trimmed : `${trimmed};`;
};
const unescapeSqlComment = (text: string, mysqlBackslashEscapes = false): string => {
const unescaped = text.replace(/''/g, "'");
return mysqlBackslashEscapes ? unescaped.replace(/\\'/g, "'") : unescaped;
};
export const parseTableCommentFromDDL = (ddlText: string): string => {
const ddl = String(ddlText || '').replace(/\r?\n/g, ' ');
const mysqlMatch = ddl.match(/COMMENT\s*=\s*'((?:\\'|''|[^'])*)'/i);
if (mysqlMatch) {
return unescapeSqlComment(mysqlMatch[1], true);
}
const commentOnTableMatch = ddl.match(/\bCOMMENT\s+ON\s+TABLE\s+.+?\s+IS\s+(NULL|'((?:''|[^'])*)')/i);
if (!commentOnTableMatch || commentOnTableMatch[1].toUpperCase() === 'NULL') {
return '';
}
return unescapeSqlComment(commentOnTableMatch[2] || '');
};

View File

@@ -27,6 +27,7 @@ type oracleRecordingState struct {
mu sync.Mutex
execQueries []string
execArgs [][]driver.NamedValue
queries []string
rowsAffected int64
queryResults map[string]oracleRecordingQueryResult
queryError error
@@ -54,6 +55,12 @@ func (s *oracleRecordingState) snapshotExecArgs() [][]driver.NamedValue {
return result
}
func (s *oracleRecordingState) snapshotQueries() []string {
s.mu.Lock()
defer s.mu.Unlock()
return append([]string(nil), s.queries...)
}
type oracleRecordingDriver struct{}
func (oracleRecordingDriver) Open(name string) (driver.Conn, error) {
@@ -88,6 +95,7 @@ func (c *oracleRecordingConn) ExecContext(_ context.Context, query string, args
func (c *oracleRecordingConn) QueryContext(_ context.Context, query string, _ []driver.NamedValue) (driver.Rows, error) {
c.state.mu.Lock()
c.state.queries = append(c.state.queries, query)
if err := c.state.queryError; err != nil {
c.state.mu.Unlock()
return nil, err
@@ -103,10 +111,10 @@ func (c *oracleRecordingConn) QueryContext(_ context.Context, query string, _ []
if strings.Contains(strings.ToLower(query), "tab_columns") {
return &oracleRecordingRows{
columns: []string{"COLUMN_NAME", "DATA_TYPE", "NULLABLE", "DATA_DEFAULT"},
columns: []string{"COLUMN_NAME", "DATA_TYPE", "NULLABLE", "DATA_DEFAULT", "COLUMN_KEY", "COMMENT"},
rows: [][]driver.Value{
{"UPDATED_AT", "TIMESTAMP", "YES", nil},
{"CREATED_AT", "DATE", "NO", nil},
{"UPDATED_AT", "TIMESTAMP", "YES", nil, "", "更新时间"},
{"CREATED_AT", "DATE", "NO", nil, "", nil},
},
}, nil
}

View File

@@ -3,6 +3,7 @@ package db
import (
"database/sql/driver"
"reflect"
"strings"
"testing"
)
@@ -82,3 +83,73 @@ func TestOracleGetTablesSkipsRowsWithNullTableName(t *testing.T) {
t.Fatalf("NULL TABLE_NAME 应被跳过,期望 %v实际 %v", want, tables)
}
}
func TestOracleGetColumnsIncludesColumnComments(t *testing.T) {
t.Parallel()
dbConn, state := openOracleRecordingDB(t)
oracleDB := &OracleDB{conn: dbConn}
columns, err := oracleDB.GetColumns("MYCIMLED", "EDC_LOG")
if err != nil {
t.Fatalf("GetColumns 返回错误: %v", err)
}
if len(columns) == 0 {
t.Fatalf("expected columns")
}
if columns[0].Name != "UPDATED_AT" || columns[0].Comment != "更新时间" {
t.Fatalf("expected first column comment from Oracle metadata, got %#v", columns[0])
}
queries := state.snapshotQueries()
if len(queries) == 0 || !strings.Contains(queries[0], "all_col_comments") {
t.Fatalf("expected GetColumns to join all_col_comments, queries=%v", queries)
}
}
func TestOracleGetCreateStatementAppendsTableAndColumnComments(t *testing.T) {
t.Parallel()
dbConn, state := openOracleRecordingDB(t)
state.mu.Lock()
state.queryResults[`SELECT DBMS_METADATA.GET_DDL('TABLE', 'EDC_LOG', 'MYCIMLED') as ddl FROM DUAL`] = oracleRecordingQueryResult{
columns: []string{"DDL"},
rows: [][]driver.Value{
{`CREATE TABLE "MYCIMLED"."EDC_LOG" (
"ID" NUMBER NOT NULL
)`},
},
}
state.queryResults[`SELECT comments AS "COMMENT" FROM all_tab_comments WHERE owner = 'MYCIMLED' AND table_name = 'EDC_LOG' AND comments IS NOT NULL`] = oracleRecordingQueryResult{
columns: []string{"COMMENT"},
rows: [][]driver.Value{
{"日志表"},
},
}
state.queryResults[`SELECT c.column_name AS "COLUMN_NAME", cc.comments AS "COMMENT"
FROM all_tab_columns c
JOIN all_col_comments cc
ON cc.owner = c.owner AND cc.table_name = c.table_name AND cc.column_name = c.column_name
WHERE c.owner = 'MYCIMLED' AND c.table_name = 'EDC_LOG' AND cc.comments IS NOT NULL
ORDER BY c.column_id`] = oracleRecordingQueryResult{
columns: []string{"COLUMN_NAME", "COMMENT"},
rows: [][]driver.Value{
{"ID", "主键's"},
},
}
state.mu.Unlock()
oracleDB := &OracleDB{conn: dbConn}
ddl, err := oracleDB.GetCreateStatement("MYCIMLED", "EDC_LOG")
if err != nil {
t.Fatalf("GetCreateStatement 返回错误: %v", err)
}
for _, want := range []string{
`CREATE TABLE "MYCIMLED"."EDC_LOG"`,
`COMMENT ON TABLE "MYCIMLED"."EDC_LOG" IS '日志表';`,
`COMMENT ON COLUMN "MYCIMLED"."EDC_LOG"."ID" IS '主键''s';`,
} {
if !strings.Contains(ddl, want) {
t.Fatalf("expected DDL to contain %q, got: %s", want, ddl)
}
}
}

View File

@@ -272,7 +272,7 @@ func (o *OracleDB) GetTables(dbName string) ([]string, error) {
// 列别名用双引号包裹强制大写,避免不同驱动版本返回不一致 case 导致 row map 取值失败
var query string
if dbName != "" {
query = fmt.Sprintf(`SELECT owner AS "OWNER", table_name AS "TABLE_NAME" FROM all_tables WHERE owner = '%s' ORDER BY table_name`, strings.ToUpper(dbName))
query = fmt.Sprintf(`SELECT owner AS "OWNER", table_name AS "TABLE_NAME" FROM all_tables WHERE owner = '%s' ORDER BY table_name`, escapeOracleMetadataLiteral(dbName))
} else {
query = `SELECT USER AS "OWNER", table_name AS "TABLE_NAME" FROM user_tables ORDER BY table_name`
}
@@ -300,11 +300,13 @@ func (o *OracleDB) GetTables(dbName string) ([]string, error) {
func (o *OracleDB) GetCreateStatement(dbName, tableName string) (string, error) {
// Oracle provides DBMS_METADATA.GET_DDL
// Note: LONG type might be tricky, but basic string scan should work for smaller DDLs
metadataTableName := escapeOracleMetadataLiteral(tableName)
metadataSchemaName := escapeOracleMetadataLiteral(dbName)
query := fmt.Sprintf("SELECT DBMS_METADATA.GET_DDL('TABLE', '%s', '%s') as ddl FROM DUAL",
strings.ToUpper(tableName), strings.ToUpper(dbName))
metadataTableName, metadataSchemaName)
if dbName == "" {
query = fmt.Sprintf("SELECT DBMS_METADATA.GET_DDL('TABLE', '%s') as ddl FROM DUAL", strings.ToUpper(tableName))
query = fmt.Sprintf("SELECT DBMS_METADATA.GET_DDL('TABLE', '%s') as ddl FROM DUAL", metadataTableName)
}
data, _, err := o.Query(query)
@@ -314,16 +316,21 @@ func (o *OracleDB) GetCreateStatement(dbName, tableName string) (string, error)
if len(data) > 0 {
if val, ok := data[0]["DDL"]; ok {
return fmt.Sprintf("%v", val), nil
return o.appendOracleCommentDDL(fmt.Sprintf("%v", val), dbName, tableName), nil
}
}
return "", fmt.Errorf("未找到建表语句")
}
func (o *OracleDB) GetColumns(dbName, tableName string) ([]connection.ColumnDefinition, error) {
metadataTableName := escapeOracleMetadataLiteral(tableName)
metadataSchemaName := escapeOracleMetadataLiteral(dbName)
query := fmt.Sprintf(`SELECT c.column_name, c.data_type, c.nullable, c.data_default,
CASE WHEN pk.column_name IS NOT NULL THEN 'PRI' ELSE '' END AS column_key
CASE WHEN pk.column_name IS NOT NULL THEN 'PRI' ELSE '' END AS column_key,
cc.comments AS comment
FROM all_tab_columns c
LEFT JOIN all_col_comments cc
ON cc.owner = c.owner AND cc.table_name = c.table_name AND cc.column_name = c.column_name
LEFT JOIN (
SELECT cols.owner, cols.table_name, cols.column_name
FROM all_constraints cons
@@ -332,12 +339,15 @@ func (o *OracleDB) GetColumns(dbName, tableName string) ([]connection.ColumnDefi
WHERE cons.constraint_type = 'P'
) pk ON c.owner = pk.owner AND c.table_name = pk.table_name AND c.column_name = pk.column_name
WHERE c.owner = '%s' AND c.table_name = '%s'
ORDER BY c.column_id`, strings.ToUpper(dbName), strings.ToUpper(tableName))
ORDER BY c.column_id`, metadataSchemaName, metadataTableName)
if dbName == "" {
query = fmt.Sprintf(`SELECT c.column_name, c.data_type, c.nullable, c.data_default,
CASE WHEN pk.column_name IS NOT NULL THEN 'PRI' ELSE '' END AS column_key
CASE WHEN pk.column_name IS NOT NULL THEN 'PRI' ELSE '' END AS column_key,
cc.comments AS comment
FROM user_tab_columns c
LEFT JOIN user_col_comments cc
ON cc.table_name = c.table_name AND cc.column_name = c.column_name
LEFT JOIN (
SELECT cols.table_name, cols.column_name
FROM user_constraints cons
@@ -345,7 +355,7 @@ func (o *OracleDB) GetColumns(dbName, tableName string) ([]connection.ColumnDefi
WHERE cons.constraint_type = 'P'
) pk ON c.table_name = pk.table_name AND c.column_name = pk.column_name
WHERE c.table_name = '%s'
ORDER BY c.column_id`, strings.ToUpper(tableName))
ORDER BY c.column_id`, metadataTableName)
}
data, _, err := o.Query(query)
@@ -356,14 +366,15 @@ func (o *OracleDB) GetColumns(dbName, tableName string) ([]connection.ColumnDefi
var columns []connection.ColumnDefinition
for _, row := range data {
col := connection.ColumnDefinition{
Name: fmt.Sprintf("%v", row["COLUMN_NAME"]),
Type: fmt.Sprintf("%v", row["DATA_TYPE"]),
Nullable: fmt.Sprintf("%v", row["NULLABLE"]),
Key: fmt.Sprintf("%v", row["COLUMN_KEY"]),
Name: oracleRowString(row, "COLUMN_NAME"),
Type: oracleRowString(row, "DATA_TYPE"),
Nullable: oracleRowString(row, "NULLABLE"),
Key: oracleRowString(row, "COLUMN_KEY"),
Comment: oracleRowString(row, "COMMENT"),
}
if row["DATA_DEFAULT"] != nil {
d := fmt.Sprintf("%v", row["DATA_DEFAULT"])
if defaultValue := oracleRowValue(row, "DATA_DEFAULT"); defaultValue != nil {
d := fmt.Sprintf("%v", defaultValue)
col.Default = &d
}
@@ -372,6 +383,134 @@ func (o *OracleDB) GetColumns(dbName, tableName string) ([]connection.ColumnDefi
return columns, nil
}
func oracleRowValue(row map[string]interface{}, names ...string) interface{} {
for _, name := range names {
if value, ok := row[name]; ok {
return value
}
for key, value := range row {
if strings.EqualFold(key, name) {
return value
}
}
}
return nil
}
func oracleRowString(row map[string]interface{}, names ...string) string {
value := oracleRowValue(row, names...)
if value == nil {
return ""
}
return fmt.Sprintf("%v", value)
}
func (o *OracleDB) appendOracleCommentDDL(baseDDL string, dbName string, tableName string) string {
table := strings.ToUpper(strings.TrimSpace(tableName))
if strings.TrimSpace(baseDDL) == "" || table == "" {
return baseDDL
}
schema := strings.ToUpper(strings.TrimSpace(dbName))
tableRef := quoteOracleDDLIdentifier(table)
if schema != "" {
tableRef = quoteOracleDDLIdentifier(schema) + "." + tableRef
}
existingDDLUpper := strings.ToUpper(baseDDL)
commentLines := make([]string, 0, 4)
if tableComment := strings.TrimSpace(o.fetchOracleTableComment(schema, table)); tableComment != "" {
marker := "COMMENT ON TABLE " + strings.ToUpper(tableRef)
if !strings.Contains(existingDDLUpper, marker) {
commentLines = append(commentLines, fmt.Sprintf("COMMENT ON TABLE %s IS '%s';", tableRef, escapeOracleCommentLiteral(tableComment)))
}
}
for _, colComment := range o.fetchOracleColumnComments(schema, table) {
columnName := strings.TrimSpace(colComment.columnName)
comment := strings.TrimSpace(colComment.comment)
if columnName == "" || comment == "" {
continue
}
columnRef := fmt.Sprintf("%s.%s", tableRef, quoteOracleDDLIdentifier(columnName))
marker := "COMMENT ON COLUMN " + strings.ToUpper(columnRef)
if strings.Contains(existingDDLUpper, marker) {
continue
}
commentLines = append(commentLines, fmt.Sprintf("COMMENT ON COLUMN %s IS '%s';", columnRef, escapeOracleCommentLiteral(comment)))
}
if len(commentLines) == 0 {
return baseDDL
}
return strings.TrimRight(baseDDL, " \t\r\n") + "\n" + strings.Join(commentLines, "\n")
}
func (o *OracleDB) fetchOracleTableComment(schema string, table string) string {
escapedTable := escapeOracleMetadataLiteral(table)
var query string
if strings.TrimSpace(schema) != "" {
query = fmt.Sprintf(`SELECT comments AS "COMMENT" FROM all_tab_comments WHERE owner = '%s' AND table_name = '%s' AND comments IS NOT NULL`, escapeOracleMetadataLiteral(schema), escapedTable)
} else {
query = fmt.Sprintf(`SELECT comments AS "COMMENT" FROM user_tab_comments WHERE table_name = '%s' AND comments IS NOT NULL`, escapedTable)
}
data, _, err := o.Query(query)
if err != nil || len(data) == 0 {
return ""
}
return oracleRowString(data[0], "COMMENT", "COMMENTS")
}
type oracleColumnComment struct {
columnName string
comment string
}
func (o *OracleDB) fetchOracleColumnComments(schema string, table string) []oracleColumnComment {
escapedTable := escapeOracleMetadataLiteral(table)
var query string
if strings.TrimSpace(schema) != "" {
query = fmt.Sprintf(`SELECT c.column_name AS "COLUMN_NAME", cc.comments AS "COMMENT"
FROM all_tab_columns c
JOIN all_col_comments cc
ON cc.owner = c.owner AND cc.table_name = c.table_name AND cc.column_name = c.column_name
WHERE c.owner = '%s' AND c.table_name = '%s' AND cc.comments IS NOT NULL
ORDER BY c.column_id`, escapeOracleMetadataLiteral(schema), escapedTable)
} else {
query = fmt.Sprintf(`SELECT c.column_name AS "COLUMN_NAME", cc.comments AS "COMMENT"
FROM user_tab_columns c
JOIN user_col_comments cc
ON cc.table_name = c.table_name AND cc.column_name = c.column_name
WHERE c.table_name = '%s' AND cc.comments IS NOT NULL
ORDER BY c.column_id`, escapedTable)
}
data, _, err := o.Query(query)
if err != nil {
return nil
}
comments := make([]oracleColumnComment, 0, len(data))
for _, row := range data {
comments = append(comments, oracleColumnComment{
columnName: oracleRowString(row, "COLUMN_NAME"),
comment: oracleRowString(row, "COMMENT", "COMMENTS"),
})
}
return comments
}
func quoteOracleDDLIdentifier(ident string) string {
return `"` + strings.ReplaceAll(strings.TrimSpace(ident), `"`, `""`) + `"`
}
func escapeOracleCommentLiteral(text string) string {
return strings.ReplaceAll(text, "'", "''")
}
func escapeOracleMetadataLiteral(text string) string {
return strings.ReplaceAll(strings.ToUpper(strings.TrimSpace(text)), "'", "''")
}
func (o *OracleDB) GetIndexes(dbName, tableName string) ([]connection.IndexDefinition, error) {
esc := func(s string) string { return strings.ReplaceAll(strings.ToUpper(strings.TrimSpace(s)), "'", "''") }
table := esc(tableName)