From 64021ffd2a9eb6549a8222ec99a2e5025819bad8 Mon Sep 17 00:00:00 2001 From: Syngnat Date: Wed, 18 Mar 2026 14:32:11 +0800 Subject: [PATCH] =?UTF-8?q?=F0=9F=90=9B=20fix(batch-truncate/query):=20?= =?UTF-8?q?=E4=BF=AE=E5=A4=8D=E6=89=B9=E9=87=8F=E6=B8=85=E7=A9=BA=E8=A1=A8?= =?UTF-8?q?=E5=AE=89=E5=85=A8=E9=9A=90=E6=82=A3=E5=B9=B6=E4=BC=98=E5=8C=96?= =?UTF-8?q?=E5=A4=9A=E8=AF=AD=E5=8F=A5=E6=89=A7=E8=A1=8C=E9=94=99=E8=AF=AF?= =?UTF-8?q?=E5=8F=8D=E9=A6=88?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 安全加固:TruncateTables 增加审计日志(Warnf 级别)和参数校验(上限 200 张) - 容错增强:批量清空部分失败时返回已执行 SQL 列表并提示已清空表不可恢复 - 错误优化:DBQueryMulti 逐条执行失败时附带语句序号和已成功条数 - 性能优化:splitSQLStatements 从 string 拼接改为 strings.Builder,消除 O(n²) 分配 - 转义修复:splitSQLStatements 支持 SQL 标准转义单引号 '' 防止误拆分 - 前端修复:handleBatchClear 统一取消判断字符串为 '已取消' 并移除冗余变量声明 - refs #244 --- frontend/package.json.md5 | 2 +- frontend/src/components/Sidebar.tsx | 5 ++- internal/app/methods_db.go | 18 ++++++++--- internal/app/methods_file.go | 44 ++++++++++++++++++++++---- internal/app/sql_split.go | 48 +++++++++++++++++------------ internal/app/sql_split_test.go | 19 ++++++++++++ 6 files changed, 101 insertions(+), 35 deletions(-) diff --git a/frontend/package.json.md5 b/frontend/package.json.md5 index a7661c0..0f8f4fe 100755 --- a/frontend/package.json.md5 +++ b/frontend/package.json.md5 @@ -1 +1 @@ -d0f9366af59a6367ad3c7e2d4185ead4 \ No newline at end of file +5b8157374dae5f9340e31b2d0bd2c00e \ No newline at end of file diff --git a/frontend/src/components/Sidebar.tsx b/frontend/src/components/Sidebar.tsx index 564d3e4..520602d 100644 --- a/frontend/src/components/Sidebar.tsx +++ b/frontend/src/components/Sidebar.tsx @@ -1838,8 +1838,8 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }> try { const app = (window as any).go.app.App; const res = await app.TruncateTables(normalizeConnConfig(conn.config), dbName, objectNames); - const duration = Date.now() - startTime; hide(); + const duration = Date.now() - startTime; if (res.success) { message.success('清空成功'); // 构造 SQL 日志 @@ -1859,10 +1859,9 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }> dbName, affectedRows: res.data?.count || 0 }); - } else if (res.message !== 'Cancelled') { + } else if (res.message !== '已取消') { message.error('清空失败: ' + res.message); // 记录失败的日志 - const duration = Date.now() - startTime; let logSql = `/* Truncate Tables (${objectNames.length} tables) - FAILED */\n`; if (res.data && res.data.executedSQLs && Array.isArray(res.data.executedSQLs)) { logSql += res.data.executedSQLs.join(';\n') + ';'; diff --git a/internal/app/methods_db.go b/internal/app/methods_db.go index c591f7f..601a6a5 100644 --- a/internal/app/methods_db.go +++ b/internal/app/methods_db.go @@ -566,7 +566,7 @@ func (a *App) DBQueryMulti(config connection.ConnectionConfig, dbName string, qu } var resultSets []connection.ResultSetData - for _, stmt := range statements { + for idx, stmt := range statements { stmt = strings.TrimSpace(stmt) if stmt == "" { continue @@ -583,8 +583,12 @@ func (a *App) DBQueryMulti(config connection.ConnectionConfig, dbName string, qu data, columns, err = dbInst.Query(stmt) } if err != nil { - logger.Error(err, "DBQueryMulti 逐条查询失败:%s SQL片段=%q", formatConnSummary(runConfig), sqlSnippet(stmt)) - return connection.QueryResult{Success: false, Message: err.Error(), QueryID: queryID} + logger.Error(err, "DBQueryMulti 逐条查询失败(第 %d/%d 条):%s SQL片段=%q", idx+1, len(statements), formatConnSummary(runConfig), sqlSnippet(stmt)) + errMsg := fmt.Sprintf("第 %d 条语句执行失败: %v", idx+1, err) + if len(resultSets) > 0 { + errMsg += fmt.Sprintf("(前 %d 条已执行成功)", len(resultSets)) + } + return connection.QueryResult{Success: false, Message: errMsg, QueryID: queryID} } if data == nil { data = make([]map[string]interface{}, 0) @@ -603,8 +607,12 @@ func (a *App) DBQueryMulti(config connection.ConnectionConfig, dbName string, qu affected, err = dbInst.Exec(stmt) } if err != nil { - logger.Error(err, "DBQueryMulti 逐条执行失败:%s SQL片段=%q", formatConnSummary(runConfig), sqlSnippet(stmt)) - return connection.QueryResult{Success: false, Message: err.Error(), QueryID: queryID} + logger.Error(err, "DBQueryMulti 逐条执行失败(第 %d/%d 条):%s SQL片段=%q", idx+1, len(statements), formatConnSummary(runConfig), sqlSnippet(stmt)) + errMsg := fmt.Sprintf("第 %d 条语句执行失败: %v", idx+1, err) + if len(resultSets) > 0 { + errMsg += fmt.Sprintf("(前 %d 条已执行成功)", len(resultSets)) + } + return connection.QueryResult{Success: false, Message: errMsg, QueryID: queryID} } resultSets = append(resultSets, connection.ResultSetData{ Rows: []map[string]interface{}{{"affectedRows": affected}}, diff --git a/internal/app/methods_file.go b/internal/app/methods_file.go index 0c57ac1..3e2c581 100644 --- a/internal/app/methods_file.go +++ b/internal/app/methods_file.go @@ -775,13 +775,15 @@ func (a *App) ExportDatabaseSQL(config connection.ConnectionConfig, dbName strin return connection.QueryResult{Success: true, Message: "导出完成"} } -// TruncateTables 清空指定表的数据(针对 MySQL 使用 TRUNCATE,MongoDB 使用 delete,否则使用 DELETE) +// TruncateTables 清空指定表的数据(针对 MySQL 使用 TRUNCATE,MongoDB 使用 delete,否则使用 DELETE)。 +// 注意:MySQL 的 TRUNCATE TABLE 是 DDL 操作,无法事务回滚;批量清空为逐表执行, +// 如果中途失败,已清空的表无法恢复。错误结果会附带已执行的 SQL 列表供排查。 func (a *App) TruncateTables(config connection.ConnectionConfig, dbName string, tableNames []string) connection.QueryResult { runConfig := normalizeRunConfig(config, dbName) - dbInst, err := a.getDatabase(runConfig) - if err != nil { - return connection.QueryResult{Success: false, Message: err.Error()} + // 参数校验 + if len(tableNames) == 0 { + return connection.QueryResult{Success: false, Message: "未指定要清空的表"} } objects := make([]string, 0, len(tableNames)) @@ -798,9 +800,25 @@ func (a *App) TruncateTables(config connection.ConnectionConfig, dbName string, objects = append(objects, tt) } + if len(objects) == 0 { + return connection.QueryResult{Success: false, Message: "未指定要清空的表"} + } + const maxBatchSize = 200 + if len(objects) > maxBatchSize { + return connection.QueryResult{Success: false, Message: fmt.Sprintf("单次最多清空 %d 张表,当前选中 %d 张", maxBatchSize, len(objects))} + } + + dbInst, err := a.getDatabase(runConfig) + if err != nil { + return connection.QueryResult{Success: false, Message: err.Error()} + } + + // 审计日志:记录清空操作的发起 + logger.Warnf("TruncateTables 开始:%s db=%s tables=%v(共 %d 张)", formatConnSummary(runConfig), dbName, objects, len(objects)) + dbType := strings.ToLower(strings.TrimSpace(runConfig.Type)) var executedSQLs []string - for _, objectName := range objects { + for i, objectName := range objects { var sql string if dbType == "mysql" || dbType == "mariadb" { sql = fmt.Sprintf("TRUNCATE TABLE %s", quoteQualifiedIdentByType(runConfig.Type, objectName)) @@ -813,11 +831,25 @@ func (a *App) TruncateTables(config connection.ConnectionConfig, dbName string, } if _, err := dbInst.Exec(sql); err != nil { - return connection.QueryResult{Success: false, Message: fmt.Sprintf("清空 %s 失败: %v", objectName, err)} + logger.Warnf("TruncateTables 第 %d/%d 张表失败:%s table=%s err=%v(已成功清空 %d 张)", i+1, len(objects), formatConnSummary(runConfig), objectName, err, len(executedSQLs)) + errMsg := fmt.Sprintf("清空 %s 失败: %v", objectName, err) + if len(executedSQLs) > 0 { + errMsg += fmt.Sprintf("(注意:前 %d 张表已清空且无法恢复)", len(executedSQLs)) + } + return connection.QueryResult{ + Success: false, + Message: errMsg, + Data: map[string]interface{}{ + "executedSQLs": executedSQLs, + "count": len(executedSQLs), + }, + } } executedSQLs = append(executedSQLs, sql) } + logger.Warnf("TruncateTables 完成:%s db=%s 共清空 %d 张表", formatConnSummary(runConfig), dbName, len(executedSQLs)) + return connection.QueryResult{ Success: true, Message: "清空成功", diff --git a/internal/app/sql_split.go b/internal/app/sql_split.go index de2ee8c..f73cebd 100644 --- a/internal/app/sql_split.go +++ b/internal/app/sql_split.go @@ -5,11 +5,12 @@ import "strings" // splitSQLStatements 按分号拆分 SQL 文本为独立语句。 // 正确处理单引号/双引号/反引号字符串、行注释(-- / #)、块注释(/* */)和 // PostgreSQL/Kingbase 的 $$...$$ dollar-quoting,避免在这些上下文中错误拆分。 +// 同时支持 SQL 标准的转义单引号(两个连续单引号 '' 表示字面量引号)。 func splitSQLStatements(sql string) []string { text := strings.ReplaceAll(sql, "\r\n", "\n") var statements []string - cur := "" + var cur strings.Builder inSingle := false inDouble := false inBacktick := false @@ -19,11 +20,11 @@ func splitSQLStatements(sql string) []string { var dollarTag string // postgres/kingbase: $$...$$ or $tag$...$tag$ push := func() { - s := strings.TrimSpace(cur) + s := strings.TrimSpace(cur.String()) if s != "" { statements = append(statements, s) } - cur = "" + cur.Reset() } for i := 0; i < len(text); i++ { @@ -38,15 +39,15 @@ func splitSQLStatements(sql string) []string { if ch == '\n' { inLineComment = false } - cur += string(ch) + cur.WriteByte(ch) continue } // 块注释 if inBlockComment { - cur += string(ch) + cur.WriteByte(ch) if ch == '*' && next == '/' { - cur += "/" + cur.WriteByte('/') i++ inBlockComment = false } @@ -56,66 +57,73 @@ func splitSQLStatements(sql string) []string { // Dollar-quoting if dollarTag != "" { if strings.HasPrefix(text[i:], dollarTag) { - cur += dollarTag + cur.WriteString(dollarTag) i += len(dollarTag) - 1 dollarTag = "" } else { - cur += string(ch) + cur.WriteByte(ch) } continue } - // 转义字符 + // 转义字符(反斜杠转义,MySQL 风格) if escaped { escaped = false - cur += string(ch) + cur.WriteByte(ch) continue } if (inSingle || inDouble) && ch == '\\' { escaped = true - cur += string(ch) + cur.WriteByte(ch) continue } // 字符串开闭 if !inDouble && !inBacktick && ch == '\'' { + if inSingle && next == '\'' { + // SQL 标准转义:两个连续单引号 '' 表示字面量引号,保持在引号内 + cur.WriteByte(ch) + cur.WriteByte(next) + i++ + continue + } inSingle = !inSingle - cur += string(ch) + cur.WriteByte(ch) continue } if !inSingle && !inBacktick && ch == '"' { inDouble = !inDouble - cur += string(ch) + cur.WriteByte(ch) continue } if !inSingle && !inDouble && ch == '`' { inBacktick = !inBacktick - cur += string(ch) + cur.WriteByte(ch) continue } // 在引号/反引号内部不做任何判断 if inSingle || inDouble || inBacktick { - cur += string(ch) + cur.WriteByte(ch) continue } // 行注释开始 if ch == '-' && next == '-' { inLineComment = true - cur += string(ch) + cur.WriteByte(ch) continue } if ch == '#' { inLineComment = true - cur += string(ch) + cur.WriteByte(ch) continue } // 块注释开始 if ch == '/' && next == '*' { inBlockComment = true - cur += "/*" + cur.WriteString("/*") i++ continue } @@ -124,7 +132,7 @@ func splitSQLStatements(sql string) []string { if ch == '$' { if tag := parseSQLDollarTag(text[i:]); tag != "" { dollarTag = tag - cur += tag + cur.WriteString(tag) i += len(tag) - 1 continue } @@ -142,7 +150,7 @@ func splitSQLStatements(sql string) []string { continue } - cur += string(ch) + cur.WriteByte(ch) } push() diff --git a/internal/app/sql_split_test.go b/internal/app/sql_split_test.go index 3c248fe..7c9c3d3 100644 --- a/internal/app/sql_split_test.go +++ b/internal/app/sql_split_test.go @@ -92,3 +92,22 @@ func TestSplitSQLStatements_TrailingSemicolon(t *testing.T) { t.Errorf("splitSQLStatements(%q) = %v, want %v", input, got, want) } } + +func TestSplitSQLStatements_SQLEscapedQuote(t *testing.T) { + input := "SELECT 'it''s a test'; SELECT 2" + got := splitSQLStatements(input) + want := []string{"SELECT 'it''s a test'", "SELECT 2"} + if !reflect.DeepEqual(got, want) { + t.Errorf("splitSQLStatements(%q) = %v, want %v", input, got, want) + } +} + +func TestSplitSQLStatements_SQLEscapedQuoteMultiple(t *testing.T) { + input := "INSERT INTO t VALUES ('O''Brien', 'it''s OK'); SELECT 1" + got := splitSQLStatements(input) + want := []string{"INSERT INTO t VALUES ('O''Brien', 'it''s OK')", "SELECT 1"} + if !reflect.DeepEqual(got, want) { + t.Errorf("splitSQLStatements(%q) = %v, want %v", input, got, want) + } +} +