From 4b1cd1b7278676da34a65f90de1b9472ce517d71 Mon Sep 17 00:00:00 2001 From: Syngnat Date: Thu, 25 Jun 2026 10:50:07 +0800 Subject: [PATCH] =?UTF-8?q?=F0=9F=90=9B=20fix(oracle):=20=E4=BF=AE?= =?UTF-8?q?=E5=A4=8D=E8=A7=A6=E5=8F=91=E5=99=A8=E8=84=9A=E6=9C=AC=E6=98=BE?= =?UTF-8?q?=E7=A4=BA=E4=B8=BA=E7=A9=BA?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 触发器列表查询补齐 OWNER、TABLE_NAME、TRIGGER_BODY 等 Oracle 元数据字段 - 优先使用 DBMS_METADATA.GET_DDL 返回完整 CREATE TRIGGER 脚本 - 在 DDL 不可用时基于 USER_TRIGGERS/ALL_TRIGGERS 重建可编辑触发器语句 - 补充 Oracle 触发器 DDL 获取与回退重建回归测试 --- internal/db/oracle_impl.go | 163 ++++++++++++++++++++++++++-- internal/db/oracle_triggers_test.go | 100 +++++++++++++++++ 2 files changed, 253 insertions(+), 10 deletions(-) create mode 100644 internal/db/oracle_triggers_test.go diff --git a/internal/db/oracle_impl.go b/internal/db/oracle_impl.go index ef51496..37e98f8 100644 --- a/internal/db/oracle_impl.go +++ b/internal/db/oracle_impl.go @@ -6,6 +6,7 @@ import ( "fmt" "net" "net/url" + "regexp" "strconv" "strings" "time" @@ -28,6 +29,11 @@ type OracleDB struct { var _ SessionExecerProvider = (*OracleDB)(nil) var _ TransactionExecerProvider = (*OracleDB)(nil) +var ( + oracleTriggerCreatePattern = regexp.MustCompile(`(?is)^\s*CREATE\s+(?:OR\s+REPLACE\s+)?TRIGGER\b`) + oracleTriggerTimingPattern = regexp.MustCompile(`(?is)^\s*(?:BEFORE|AFTER|INSTEAD\s+OF)\b`) +) + func oracleRuntimeError(key string, params map[string]any) error { return fmt.Errorf("%s", localizedDriverRuntimeText(key, params)) } @@ -917,7 +923,7 @@ func (o *OracleDB) GetTriggers(dbName, tableName string) ([]connection.TriggerDe if len(data) == 0 { continue } - return parseOracleTriggers(data), nil + return o.parseOracleTriggers(data), nil } return []connection.TriggerDefinition{}, nil } @@ -926,30 +932,167 @@ func buildOracleTriggersQuery(schema string, table string) string { metadataTableName := escapeOracleMetadataLiteralExact(table) metadataSchemaName := escapeOracleMetadataLiteralExact(schema) if strings.TrimSpace(schema) == "" { - return fmt.Sprintf(`SELECT trigger_name, trigger_type, triggering_event + return fmt.Sprintf(`SELECT USER AS "OWNER", USER AS "TABLE_OWNER", table_name AS "TABLE_NAME", trigger_name AS "TRIGGER_NAME", trigger_type AS "TRIGGER_TYPE", triggering_event AS "TRIGGERING_EVENT", when_clause AS "WHEN_CLAUSE", trigger_body AS "TRIGGER_BODY" FROM user_triggers - WHERE table_name = '%s'`, metadataTableName) + WHERE table_name = '%s' + ORDER BY trigger_name`, metadataTableName) } - return fmt.Sprintf(`SELECT trigger_name, trigger_type, triggering_event + return fmt.Sprintf(`SELECT owner AS "OWNER", table_owner AS "TABLE_OWNER", table_name AS "TABLE_NAME", trigger_name AS "TRIGGER_NAME", trigger_type AS "TRIGGER_TYPE", triggering_event AS "TRIGGERING_EVENT", when_clause AS "WHEN_CLAUSE", trigger_body AS "TRIGGER_BODY" FROM all_triggers - WHERE table_owner = '%s' AND table_name = '%s'`, + WHERE table_owner = '%s' AND table_name = '%s' + ORDER BY owner, trigger_name`, metadataSchemaName, metadataTableName) } -func parseOracleTriggers(data []map[string]interface{}) []connection.TriggerDefinition { +func (o *OracleDB) parseOracleTriggers(data []map[string]interface{}) []connection.TriggerDefinition { var triggers []connection.TriggerDefinition for _, row := range data { + owner := oracleRowString(row, "OWNER") + triggerName := oracleRowString(row, "TRIGGER_NAME") + statement := strings.TrimSpace(o.fetchOracleTriggerDDL(owner, triggerName)) + if statement == "" { + statement = buildOracleTriggerDDLFromMetadata(row) + } + trig := connection.TriggerDefinition{ - Name: fmt.Sprintf("%v", row["TRIGGER_NAME"]), - Timing: fmt.Sprintf("%v", row["TRIGGER_TYPE"]), - Event: fmt.Sprintf("%v", row["TRIGGERING_EVENT"]), - Statement: "SOURCE HIDDEN", // Requires more complex query to get body + Name: triggerName, + Timing: oracleRowString(row, "TRIGGER_TYPE"), + Event: oracleRowString(row, "TRIGGERING_EVENT"), + Statement: statement, } triggers = append(triggers, trig) } return triggers } +func (o *OracleDB) fetchOracleTriggerDDL(owner string, triggerName string) string { + if strings.TrimSpace(triggerName) == "" { + return "" + } + for _, candidate := range oracleMetadataNamePairs(owner, triggerName) { + metadataTriggerName := escapeOracleMetadataLiteralExact(candidate.table) + metadataOwnerName := escapeOracleMetadataLiteralExact(candidate.schema) + query := fmt.Sprintf("SELECT DBMS_METADATA.GET_DDL('TRIGGER', '%s', '%s') as ddl FROM DUAL", + metadataTriggerName, metadataOwnerName) + if candidate.schema == "" { + query = fmt.Sprintf("SELECT DBMS_METADATA.GET_DDL('TRIGGER', '%s') as ddl FROM DUAL", metadataTriggerName) + } + + data, _, err := o.Query(query) + if err != nil || len(data) == 0 { + continue + } + ddl := oracleRowString(data[0], "DDL", "ddl", "TRIGGER_DEFINITION", "trigger_definition") + if ddl != "" { + return ensureOracleDDLStatementTerminator(ddl) + } + } + return "" +} + +func buildOracleTriggerDDLFromMetadata(row map[string]interface{}) string { + body := strings.TrimSpace(oracleRowString(row, "TRIGGER_BODY")) + if body == "" || strings.EqualFold(body, "SOURCE HIDDEN") { + return "" + } + + if startsWithOracleTriggerCreate(body) { + return ensureOracleDDLStatementTerminator(body) + } + + triggerName := oracleRowString(row, "TRIGGER_NAME") + if triggerName == "" { + return "" + } + + if strings.HasPrefix(strings.ToUpper(body), "TRIGGER ") { + return ensureOracleDDLStatementTerminator("CREATE OR REPLACE " + body) + } + + triggerOwner := oracleRowString(row, "OWNER") + tableOwner := oracleRowString(row, "TABLE_OWNER") + tableName := oracleRowString(row, "TABLE_NAME") + triggerRef := quoteOracleTableRef(triggerOwner, triggerName) + + if startsWithOracleTriggerTiming(body) { + return ensureOracleDDLStatementTerminator(fmt.Sprintf("CREATE OR REPLACE TRIGGER %s\n%s", triggerRef, body)) + } + + triggerClause := buildOracleTriggerClause( + oracleRowString(row, "TRIGGER_TYPE"), + oracleRowString(row, "TRIGGERING_EVENT"), + oracleTriggerTableRef(tableOwner, tableName), + ) + if triggerClause == "" { + return "" + } + + lines := []string{ + fmt.Sprintf("CREATE OR REPLACE TRIGGER %s", triggerRef), + triggerClause, + } + if shouldAppendOracleForEachRow(oracleRowString(row, "TRIGGER_TYPE")) { + lines = append(lines, "FOR EACH ROW") + } + if whenClause := normalizeOracleTriggerWhenClause(oracleRowString(row, "WHEN_CLAUSE")); whenClause != "" { + lines = append(lines, whenClause) + } + lines = append(lines, body) + return ensureOracleDDLStatementTerminator(strings.Join(lines, "\n")) +} + +func startsWithOracleTriggerCreate(sql string) bool { + return oracleTriggerCreatePattern.MatchString(sql) +} + +func startsWithOracleTriggerTiming(sql string) bool { + return oracleTriggerTimingPattern.MatchString(sql) +} + +func oracleTriggerTableRef(tableOwner string, tableName string) string { + if strings.TrimSpace(tableName) == "" { + return "" + } + return quoteOracleTableRef(tableOwner, tableName) +} + +func buildOracleTriggerClause(triggerType string, event string, tableRef string) string { + normalizedType := strings.ToUpper(strings.TrimSpace(triggerType)) + normalizedEvent := strings.TrimSpace(event) + if tableRef == "" || normalizedEvent == "" { + return "" + } + + switch { + case strings.HasPrefix(normalizedType, "BEFORE"): + return fmt.Sprintf("BEFORE %s ON %s", normalizedEvent, tableRef) + case strings.HasPrefix(normalizedType, "AFTER"): + return fmt.Sprintf("AFTER %s ON %s", normalizedEvent, tableRef) + case strings.HasPrefix(normalizedType, "INSTEAD OF"): + return fmt.Sprintf("INSTEAD OF %s ON %s", normalizedEvent, tableRef) + case strings.Contains(normalizedType, "COMPOUND"): + return fmt.Sprintf("FOR %s ON %s", normalizedEvent, tableRef) + default: + return fmt.Sprintf("%s %s ON %s", strings.TrimSpace(triggerType), normalizedEvent, tableRef) + } +} + +func shouldAppendOracleForEachRow(triggerType string) bool { + normalizedType := strings.ToUpper(strings.TrimSpace(triggerType)) + return strings.Contains(normalizedType, "EACH ROW") && !strings.HasPrefix(normalizedType, "INSTEAD OF") +} + +func normalizeOracleTriggerWhenClause(whenClause string) string { + trimmed := strings.TrimSpace(whenClause) + if trimmed == "" { + return "" + } + if strings.HasPrefix(trimmed, "(") && strings.HasSuffix(trimmed, ")") { + return "WHEN " + trimmed + } + return "WHEN (" + trimmed + ")" +} + func splitOracleQualifiedTableName(raw string) (string, string) { table := strings.TrimSpace(raw) schema := "" diff --git a/internal/db/oracle_triggers_test.go b/internal/db/oracle_triggers_test.go new file mode 100644 index 0000000..032743c --- /dev/null +++ b/internal/db/oracle_triggers_test.go @@ -0,0 +1,100 @@ +package db + +import ( + "database/sql/driver" + "slices" + "strings" + "testing" +) + +func TestOracleGetTriggersUsesDBMSMetadataDDL(t *testing.T) { + t.Parallel() + + dbConn, state := openOracleRecordingDB(t) + triggerListQuery := buildOracleTriggersQuery("H2", "T_MEMCARD_CASH") + triggerDDLQuery := `SELECT DBMS_METADATA.GET_DDL('TRIGGER', 'TR_T_MEMCARD_CASH', 'H2') as ddl FROM DUAL` + metadataDDL := `CREATE OR REPLACE TRIGGER "H2"."TR_T_MEMCARD_CASH" +BEFORE INSERT ON "H2"."T_MEMCARD_CASH" +BEGIN + NULL; +END;` + + state.mu.Lock() + state.queryResults[triggerListQuery] = oracleRecordingQueryResult{ + columns: []string{"OWNER", "TABLE_OWNER", "TABLE_NAME", "TRIGGER_NAME", "TRIGGER_TYPE", "TRIGGERING_EVENT", "WHEN_CLAUSE", "TRIGGER_BODY"}, + rows: [][]driver.Value{ + {"H2", "H2", "T_MEMCARD_CASH", "TR_T_MEMCARD_CASH", "BEFORE EACH ROW", "INSERT", nil, "SOURCE HIDDEN"}, + }, + } + state.queryResults[triggerDDLQuery] = oracleRecordingQueryResult{ + columns: []string{"DDL"}, + rows: [][]driver.Value{ + {metadataDDL}, + }, + } + state.mu.Unlock() + + oracleDB := &OracleDB{conn: dbConn} + triggers, err := oracleDB.GetTriggers("H2", "T_MEMCARD_CASH") + if err != nil { + t.Fatalf("GetTriggers 返回错误: %v", err) + } + if len(triggers) != 1 { + t.Fatalf("期望返回 1 个触发器,实际 %#v", triggers) + } + if !strings.Contains(triggers[0].Statement, `CREATE OR REPLACE TRIGGER "H2"."TR_T_MEMCARD_CASH"`) { + t.Fatalf("期望返回 DBMS_METADATA 完整 DDL,实际: %s", triggers[0].Statement) + } + if strings.Contains(triggers[0].Statement, "SOURCE HIDDEN") { + t.Fatalf("触发器语句不应继续返回 SOURCE HIDDEN: %s", triggers[0].Statement) + } + if queries := state.snapshotQueries(); !slices.Contains(queries, triggerDDLQuery) { + t.Fatalf("期望查询 DBMS_METADATA 获取触发器 DDL,实际 queries=%v", queries) + } +} + +func TestOracleGetTriggersRebuildsDDLFromTriggerBodyWhenMetadataDDLIsEmpty(t *testing.T) { + t.Parallel() + + dbConn, state := openOracleRecordingDB(t) + triggerListQuery := buildOracleTriggersQuery("H2", "T_MEMCARD_CASH") + triggerDDLQuery := `SELECT DBMS_METADATA.GET_DDL('TRIGGER', 'TR_T_MEMCARD_CASH', 'H2') as ddl FROM DUAL` + + state.mu.Lock() + state.queryResults[triggerListQuery] = oracleRecordingQueryResult{ + columns: []string{"OWNER", "TABLE_OWNER", "TABLE_NAME", "TRIGGER_NAME", "TRIGGER_TYPE", "TRIGGERING_EVENT", "WHEN_CLAUSE", "TRIGGER_BODY"}, + rows: [][]driver.Value{ + {"H2", "H2", "T_MEMCARD_CASH", "TR_T_MEMCARD_CASH", "BEFORE EACH ROW", "INSERT OR UPDATE", "NEW.ID IS NOT NULL", "BEGIN\n :NEW.UPDATED_AT := SYSDATE;\nEND;"}, + }, + } + state.queryResults[triggerDDLQuery] = oracleRecordingQueryResult{ + columns: []string{"DDL"}, + rows: [][]driver.Value{}, + } + state.mu.Unlock() + + oracleDB := &OracleDB{conn: dbConn} + triggers, err := oracleDB.GetTriggers("H2", "T_MEMCARD_CASH") + if err != nil { + t.Fatalf("GetTriggers 返回错误: %v", err) + } + if len(triggers) != 1 { + t.Fatalf("期望返回 1 个触发器,实际 %#v", triggers) + } + + statement := triggers[0].Statement + for _, want := range []string{ + `CREATE OR REPLACE TRIGGER "H2"."TR_T_MEMCARD_CASH"`, + `BEFORE INSERT OR UPDATE ON "H2"."T_MEMCARD_CASH"`, + `FOR EACH ROW`, + `WHEN (NEW.ID IS NOT NULL)`, + `:NEW.UPDATED_AT := SYSDATE;`, + } { + if !strings.Contains(statement, want) { + t.Fatalf("期望重建后的触发器 DDL 包含 %q,实际: %s", want, statement) + } + } + if strings.Contains(statement, "SOURCE HIDDEN") { + t.Fatalf("触发器语句不应继续返回 SOURCE HIDDEN: %s", statement) + } +}