🐛 fix(batch-truncate/query): 修复批量清空表安全隐患并优化多语句执行错误反馈

- 安全加固:TruncateTables 增加审计日志(Warnf 级别)和参数校验(上限 200 张)
- 容错增强:批量清空部分失败时返回已执行 SQL 列表并提示已清空表不可恢复
- 错误优化:DBQueryMulti 逐条执行失败时附带语句序号和已成功条数
- 性能优化:splitSQLStatements 从 string 拼接改为 strings.Builder,消除 O(n²) 分配
- 转义修复:splitSQLStatements 支持 SQL 标准转义单引号 '' 防止误拆分
- 前端修复:handleBatchClear 统一取消判断字符串为 '已取消' 并移除冗余变量声明
- refs #244
This commit is contained in:
Syngnat
2026-03-18 14:32:11 +08:00
parent fbd785400f
commit 64021ffd2a
6 changed files with 101 additions and 35 deletions

View File

@@ -1 +1 @@
d0f9366af59a6367ad3c7e2d4185ead4
5b8157374dae5f9340e31b2d0bd2c00e

View File

@@ -1838,8 +1838,8 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
try {
const app = (window as any).go.app.App;
const res = await app.TruncateTables(normalizeConnConfig(conn.config), dbName, objectNames);
const duration = Date.now() - startTime;
hide();
const duration = Date.now() - startTime;
if (res.success) {
message.success('清空成功');
// 构造 SQL 日志
@@ -1859,10 +1859,9 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
dbName,
affectedRows: res.data?.count || 0
});
} else if (res.message !== 'Cancelled') {
} else if (res.message !== '已取消') {
message.error('清空失败: ' + res.message);
// 记录失败的日志
const duration = Date.now() - startTime;
let logSql = `/* Truncate Tables (${objectNames.length} tables) - FAILED */\n`;
if (res.data && res.data.executedSQLs && Array.isArray(res.data.executedSQLs)) {
logSql += res.data.executedSQLs.join(';\n') + ';';

View File

@@ -566,7 +566,7 @@ func (a *App) DBQueryMulti(config connection.ConnectionConfig, dbName string, qu
}
var resultSets []connection.ResultSetData
for _, stmt := range statements {
for idx, stmt := range statements {
stmt = strings.TrimSpace(stmt)
if stmt == "" {
continue
@@ -583,8 +583,12 @@ func (a *App) DBQueryMulti(config connection.ConnectionConfig, dbName string, qu
data, columns, err = dbInst.Query(stmt)
}
if err != nil {
logger.Error(err, "DBQueryMulti 逐条查询失败:%s SQL片段=%q", formatConnSummary(runConfig), sqlSnippet(stmt))
return connection.QueryResult{Success: false, Message: err.Error(), QueryID: queryID}
logger.Error(err, "DBQueryMulti 逐条查询失败(第 %d/%d 条)%s SQL片段=%q", idx+1, len(statements), formatConnSummary(runConfig), sqlSnippet(stmt))
errMsg := fmt.Sprintf("第 %d 条语句执行失败: %v", idx+1, err)
if len(resultSets) > 0 {
errMsg += fmt.Sprintf("(前 %d 条已执行成功)", len(resultSets))
}
return connection.QueryResult{Success: false, Message: errMsg, QueryID: queryID}
}
if data == nil {
data = make([]map[string]interface{}, 0)
@@ -603,8 +607,12 @@ func (a *App) DBQueryMulti(config connection.ConnectionConfig, dbName string, qu
affected, err = dbInst.Exec(stmt)
}
if err != nil {
logger.Error(err, "DBQueryMulti 逐条执行失败:%s SQL片段=%q", formatConnSummary(runConfig), sqlSnippet(stmt))
return connection.QueryResult{Success: false, Message: err.Error(), QueryID: queryID}
logger.Error(err, "DBQueryMulti 逐条执行失败(第 %d/%d 条)%s SQL片段=%q", idx+1, len(statements), formatConnSummary(runConfig), sqlSnippet(stmt))
errMsg := fmt.Sprintf("第 %d 条语句执行失败: %v", idx+1, err)
if len(resultSets) > 0 {
errMsg += fmt.Sprintf("(前 %d 条已执行成功)", len(resultSets))
}
return connection.QueryResult{Success: false, Message: errMsg, QueryID: queryID}
}
resultSets = append(resultSets, connection.ResultSetData{
Rows: []map[string]interface{}{{"affectedRows": affected}},

View File

@@ -775,13 +775,15 @@ func (a *App) ExportDatabaseSQL(config connection.ConnectionConfig, dbName strin
return connection.QueryResult{Success: true, Message: "导出完成"}
}
// TruncateTables 清空指定表的数据(针对 MySQL 使用 TRUNCATEMongoDB 使用 delete否则使用 DELETE
// TruncateTables 清空指定表的数据(针对 MySQL 使用 TRUNCATEMongoDB 使用 delete否则使用 DELETE
// 注意MySQL 的 TRUNCATE TABLE 是 DDL 操作,无法事务回滚;批量清空为逐表执行,
// 如果中途失败,已清空的表无法恢复。错误结果会附带已执行的 SQL 列表供排查。
func (a *App) TruncateTables(config connection.ConnectionConfig, dbName string, tableNames []string) connection.QueryResult {
runConfig := normalizeRunConfig(config, dbName)
dbInst, err := a.getDatabase(runConfig)
if err != nil {
return connection.QueryResult{Success: false, Message: err.Error()}
// 参数校验
if len(tableNames) == 0 {
return connection.QueryResult{Success: false, Message: "未指定要清空的表"}
}
objects := make([]string, 0, len(tableNames))
@@ -798,9 +800,25 @@ func (a *App) TruncateTables(config connection.ConnectionConfig, dbName string,
objects = append(objects, tt)
}
if len(objects) == 0 {
return connection.QueryResult{Success: false, Message: "未指定要清空的表"}
}
const maxBatchSize = 200
if len(objects) > maxBatchSize {
return connection.QueryResult{Success: false, Message: fmt.Sprintf("单次最多清空 %d 张表,当前选中 %d 张", maxBatchSize, len(objects))}
}
dbInst, err := a.getDatabase(runConfig)
if err != nil {
return connection.QueryResult{Success: false, Message: err.Error()}
}
// 审计日志:记录清空操作的发起
logger.Warnf("TruncateTables 开始:%s db=%s tables=%v共 %d 张)", formatConnSummary(runConfig), dbName, objects, len(objects))
dbType := strings.ToLower(strings.TrimSpace(runConfig.Type))
var executedSQLs []string
for _, objectName := range objects {
for i, objectName := range objects {
var sql string
if dbType == "mysql" || dbType == "mariadb" {
sql = fmt.Sprintf("TRUNCATE TABLE %s", quoteQualifiedIdentByType(runConfig.Type, objectName))
@@ -813,11 +831,25 @@ func (a *App) TruncateTables(config connection.ConnectionConfig, dbName string,
}
if _, err := dbInst.Exec(sql); err != nil {
return connection.QueryResult{Success: false, Message: fmt.Sprintf("清空 %s 失败: %v", objectName, err)}
logger.Warnf("TruncateTables 第 %d/%d 张表失败:%s table=%s err=%v已成功清空 %d 张)", i+1, len(objects), formatConnSummary(runConfig), objectName, err, len(executedSQLs))
errMsg := fmt.Sprintf("清空 %s 失败: %v", objectName, err)
if len(executedSQLs) > 0 {
errMsg += fmt.Sprintf("(注意:前 %d 张表已清空且无法恢复)", len(executedSQLs))
}
return connection.QueryResult{
Success: false,
Message: errMsg,
Data: map[string]interface{}{
"executedSQLs": executedSQLs,
"count": len(executedSQLs),
},
}
}
executedSQLs = append(executedSQLs, sql)
}
logger.Warnf("TruncateTables 完成:%s db=%s 共清空 %d 张表", formatConnSummary(runConfig), dbName, len(executedSQLs))
return connection.QueryResult{
Success: true,
Message: "清空成功",

View File

@@ -5,11 +5,12 @@ import "strings"
// splitSQLStatements 按分号拆分 SQL 文本为独立语句。
// 正确处理单引号/双引号/反引号字符串、行注释(-- / #)、块注释(/* */)和
// PostgreSQL/Kingbase 的 $$...$$ dollar-quoting避免在这些上下文中错误拆分。
// 同时支持 SQL 标准的转义单引号(两个连续单引号 '' 表示字面量引号)。
func splitSQLStatements(sql string) []string {
text := strings.ReplaceAll(sql, "\r\n", "\n")
var statements []string
cur := ""
var cur strings.Builder
inSingle := false
inDouble := false
inBacktick := false
@@ -19,11 +20,11 @@ func splitSQLStatements(sql string) []string {
var dollarTag string // postgres/kingbase: $$...$$ or $tag$...$tag$
push := func() {
s := strings.TrimSpace(cur)
s := strings.TrimSpace(cur.String())
if s != "" {
statements = append(statements, s)
}
cur = ""
cur.Reset()
}
for i := 0; i < len(text); i++ {
@@ -38,15 +39,15 @@ func splitSQLStatements(sql string) []string {
if ch == '\n' {
inLineComment = false
}
cur += string(ch)
cur.WriteByte(ch)
continue
}
// 块注释
if inBlockComment {
cur += string(ch)
cur.WriteByte(ch)
if ch == '*' && next == '/' {
cur += "/"
cur.WriteByte('/')
i++
inBlockComment = false
}
@@ -56,66 +57,73 @@ func splitSQLStatements(sql string) []string {
// Dollar-quoting
if dollarTag != "" {
if strings.HasPrefix(text[i:], dollarTag) {
cur += dollarTag
cur.WriteString(dollarTag)
i += len(dollarTag) - 1
dollarTag = ""
} else {
cur += string(ch)
cur.WriteByte(ch)
}
continue
}
// 转义字符
// 转义字符反斜杠转义MySQL 风格)
if escaped {
escaped = false
cur += string(ch)
cur.WriteByte(ch)
continue
}
if (inSingle || inDouble) && ch == '\\' {
escaped = true
cur += string(ch)
cur.WriteByte(ch)
continue
}
// 字符串开闭
if !inDouble && !inBacktick && ch == '\'' {
if inSingle && next == '\'' {
// SQL 标准转义:两个连续单引号 '' 表示字面量引号,保持在引号内
cur.WriteByte(ch)
cur.WriteByte(next)
i++
continue
}
inSingle = !inSingle
cur += string(ch)
cur.WriteByte(ch)
continue
}
if !inSingle && !inBacktick && ch == '"' {
inDouble = !inDouble
cur += string(ch)
cur.WriteByte(ch)
continue
}
if !inSingle && !inDouble && ch == '`' {
inBacktick = !inBacktick
cur += string(ch)
cur.WriteByte(ch)
continue
}
// 在引号/反引号内部不做任何判断
if inSingle || inDouble || inBacktick {
cur += string(ch)
cur.WriteByte(ch)
continue
}
// 行注释开始
if ch == '-' && next == '-' {
inLineComment = true
cur += string(ch)
cur.WriteByte(ch)
continue
}
if ch == '#' {
inLineComment = true
cur += string(ch)
cur.WriteByte(ch)
continue
}
// 块注释开始
if ch == '/' && next == '*' {
inBlockComment = true
cur += "/*"
cur.WriteString("/*")
i++
continue
}
@@ -124,7 +132,7 @@ func splitSQLStatements(sql string) []string {
if ch == '$' {
if tag := parseSQLDollarTag(text[i:]); tag != "" {
dollarTag = tag
cur += tag
cur.WriteString(tag)
i += len(tag) - 1
continue
}
@@ -142,7 +150,7 @@ func splitSQLStatements(sql string) []string {
continue
}
cur += string(ch)
cur.WriteByte(ch)
}
push()

View File

@@ -92,3 +92,22 @@ func TestSplitSQLStatements_TrailingSemicolon(t *testing.T) {
t.Errorf("splitSQLStatements(%q) = %v, want %v", input, got, want)
}
}
func TestSplitSQLStatements_SQLEscapedQuote(t *testing.T) {
input := "SELECT 'it''s a test'; SELECT 2"
got := splitSQLStatements(input)
want := []string{"SELECT 'it''s a test'", "SELECT 2"}
if !reflect.DeepEqual(got, want) {
t.Errorf("splitSQLStatements(%q) = %v, want %v", input, got, want)
}
}
func TestSplitSQLStatements_SQLEscapedQuoteMultiple(t *testing.T) {
input := "INSERT INTO t VALUES ('O''Brien', 'it''s OK'); SELECT 1"
got := splitSQLStatements(input)
want := []string{"INSERT INTO t VALUES ('O''Brien', 'it''s OK')", "SELECT 1"}
if !reflect.DeepEqual(got, want) {
t.Errorf("splitSQLStatements(%q) = %v, want %v", input, got, want)
}
}