feat(query): 支持多条SQL语句执行返回多结果集

- 新增 ResultSetData 结构体承载单个结果集数据
- 新增 MultiResultQuerier/MultiResultQuerierContext 可选接口
- 新增 scanMultiRows 函数利用 NextResultSet() 遍历所有结果集
- MySQL 驱动 DSN 开启 multiStatements=true 并实现多结果集接口
- 新增 DBQueryMulti Wails 方法,支持驱动原生多结果集及自动回退逐条执行
- 新增 Go 版 SQL 拆分函数 splitSQLStatements 及 10 个单元测试
- 前端 QueryEditor handleRun 改为一次性调用 DBQueryMulti
- MongoDB 保持独立的逐条执行路径不受影响
- refs #235
This commit is contained in:
杨国锋
2026-03-17 22:21:49 +08:00
parent 064cdc34be
commit 0ab10d2e80
10 changed files with 682 additions and 106 deletions

View File

@@ -6,7 +6,7 @@ import { format } from 'sql-formatter';
import { v4 as uuidv4 } from 'uuid';
import { TabData, ColumnDefinition } from '../types';
import { useStore } from '../store';
import { DBQuery, DBQueryWithCancel, DBGetTables, DBGetAllColumns, DBGetDatabases, DBGetColumns, CancelQuery, GenerateQueryID } from '../../wailsjs/go/app/App';
import { DBQuery, DBQueryWithCancel, DBQueryMulti, DBGetTables, DBGetAllColumns, DBGetDatabases, DBGetColumns, CancelQuery, GenerateQueryID } from '../../wailsjs/go/app/App';
import DataGrid, { GONAVI_ROW_KEY } from './DataGrid';
import { getDataSourceCapabilities } from '../utils/dataSourceCapabilities';
import { convertMongoShellToJsonCommand } from '../utils/mongodb';
@@ -1157,36 +1157,33 @@ const QueryEditor: React.FC<{ tab: TabData }> = ({ tab }) => {
const dbType = String((config as any).type || 'mysql');
const normalizedDbType = dbType.trim().toLowerCase();
const normalizedRawSQL = String(rawSQL || '').replace(//g, ';');
const splitInput = normalizedDbType === 'mongodb'
? normalizedRawSQL
// MongoDB 仍走逐条执行的旧路径
const isMongoDB = normalizedDbType === 'mongodb';
if (isMongoDB) {
// MongoDB: 保持逐条执行
const splitInput = normalizedRawSQL
.replace(/^\s*\/\/.*$/gm, '')
.replace(/^\s*#.*$/gm, '')
: normalizedRawSQL;
const statements = splitSQLStatements(splitInput);
if (statements.length === 0) {
message.info('没有可执行的 SQL。');
setResultSets([]);
setActiveResultKey('');
return;
}
.replace(/^\s*#.*$/gm, '');
const statements = splitSQLStatements(splitInput);
if (statements.length === 0) {
message.info('没有可执行的 SQL。');
setResultSets([]);
setActiveResultKey('');
return;
}
const nextResultSets: ResultSet[] = [];
const maxRows = Number(queryOptions?.maxRows) || 0;
const forceReadOnlyResult = connCaps.forceReadOnlyQueryResult;
const wantsLimitProbe = Number.isFinite(maxRows) && maxRows > 0;
const probeLimit = wantsLimitProbe ? (maxRows + 1) : 0;
let anyTruncated = false;
const pendingPk: Array<{ resultKey: string; tableName: string }> = [];
const nextResultSets: ResultSet[] = [];
const maxRows = Number(queryOptions?.maxRows) || 0;
const forceReadOnlyResult = connCaps.forceReadOnlyQueryResult;
const wantsLimitProbe = Number.isFinite(maxRows) && maxRows > 0;
const probeLimit = wantsLimitProbe ? (maxRows + 1) : 0;
let anyTruncated = false;
for (let idx = 0; idx < statements.length; idx++) {
const rawStatement = statements[idx];
const leadingKeyword = getLeadingKeyword(rawStatement);
const shouldAutoLimit = leadingKeyword === 'select' || leadingKeyword === 'with';
const limitApplied = shouldAutoLimit && wantsLimitProbe;
const limited = limitApplied ? applyAutoLimit(rawStatement, dbType, probeLimit) : { sql: rawStatement, applied: false, maxRows: probeLimit };
let executedSql = limited.sql;
if (String(dbType || '').trim().toLowerCase() === 'mongodb') {
for (let idx = 0; idx < statements.length; idx++) {
const rawStatement = statements[idx];
let executedSql = rawStatement;
const shellConvert = convertMongoShellToJsonCommand(executedSql);
if (shellConvert.recognized) {
if (shellConvert.error) {
@@ -1200,10 +1197,97 @@ const QueryEditor: React.FC<{ tab: TabData }> = ({ tab }) => {
executedSql = shellConvert.command;
}
}
}
const startTime = Date.now();
const startTime = Date.now();
let queryId: string;
try {
queryId = await GenerateQueryID();
} catch (error) {
console.warn('GenerateQueryID failed, using local UUID fallback:', error);
queryId = 'query-' + uuidv4();
}
setQueryId(queryId);
// Generate query ID for cancellation using backend UUID with fallback
const res = await DBQueryWithCancel(config as any, currentDb, executedSql, queryId);
const duration = Date.now() - startTime;
addSqlLog({
id: `log-${Date.now()}-query-${idx + 1}`,
timestamp: Date.now(),
sql: executedSql,
status: res.success ? 'success' : 'error',
duration,
message: res.success ? '' : res.message,
affectedRows: (res.success && !Array.isArray(res.data)) ? (res.data as any).affectedRows : (Array.isArray(res.data) ? res.data.length : undefined),
dbName: currentDb
});
if (!res.success) {
const prefix = statements.length > 1 ? `${idx + 1} 条语句执行失败:` : '';
message.error(prefix + res.message);
setResultSets([]);
setActiveResultKey('');
return;
}
if (Array.isArray(res.data)) {
let rows = (res.data as any[]) || [];
let truncated = false;
if (wantsLimitProbe && Number.isFinite(maxRows) && maxRows > 0 && rows.length > maxRows) {
truncated = true;
anyTruncated = true;
rows = rows.slice(0, maxRows);
}
const cols = (res.fields && res.fields.length > 0)
? (res.fields as string[])
: (rows.length > 0 ? Object.keys(rows[0]) : []);
rows.forEach((row: any, i: number) => {
if (row && typeof row === 'object') row[GONAVI_ROW_KEY] = i;
});
nextResultSets.push({
key: `result-${idx + 1}`,
sql: rawStatement,
exportSql: rawStatement,
rows,
columns: cols,
pkColumns: [],
readOnly: true,
truncated
});
} else {
const affected = Number((res.data as any)?.affectedRows);
if (Number.isFinite(affected)) {
const row = { affectedRows: affected };
(row as any)[GONAVI_ROW_KEY] = 0;
nextResultSets.push({
key: `result-${idx + 1}`,
sql: rawStatement,
exportSql: rawStatement,
rows: [row],
columns: ['affectedRows'],
pkColumns: [],
readOnly: true
});
}
}
}
setResultSets(nextResultSets);
setActiveResultKey(nextResultSets[0]?.key || '');
if (statements.length > 1) {
message.success(`已执行 ${statements.length} 条语句,生成 ${nextResultSets.length} 个结果集。`);
} else if (nextResultSets.length === 0) {
message.success('执行成功。');
}
if (anyTruncated && maxRows > 0) {
message.warning(`结果集已自动限制为最多 ${maxRows} 行(可在工具栏调整)。`);
}
} else {
// 非 MongoDB使用 DBQueryMulti 一次性执行多条 SQL后端返回多结果集
const fullSQL = normalizedRawSQL;
if (!fullSQL.trim()) {
message.info('没有可执行的 SQL。');
setResultSets([]);
setActiveResultKey('');
return;
}
const startTime = Date.now();
let queryId: string;
try {
queryId = await GenerateQueryID();
@@ -1213,22 +1297,20 @@ const QueryEditor: React.FC<{ tab: TabData }> = ({ tab }) => {
}
setQueryId(queryId);
const res = await DBQueryWithCancel(config as any, currentDb, executedSql, queryId);
const res = await DBQueryMulti(config as any, currentDb, fullSQL, queryId);
const duration = Date.now() - startTime;
addSqlLog({
id: `log-${Date.now()}-query-${idx + 1}`,
id: `log-${Date.now()}-query-multi`,
timestamp: Date.now(),
sql: executedSql,
sql: fullSQL,
status: res.success ? 'success' : 'error',
duration,
message: res.success ? '' : res.message,
affectedRows: (res.success && !Array.isArray(res.data)) ? (res.data as any).affectedRows : (Array.isArray(res.data) ? res.data.length : undefined),
dbName: currentDb
});
if (!res.success) {
// 检查是否为查询取消错误
const errorMsg = res.message.toLowerCase();
const isCancelledError = errorMsg.includes('context canceled') ||
errorMsg.includes('查询已取消') ||
@@ -1236,72 +1318,49 @@ const QueryEditor: React.FC<{ tab: TabData }> = ({ tab }) => {
errorMsg.includes('cancelled') ||
errorMsg.includes('statement canceled') ||
errorMsg.includes('sql: statement canceled');
// 确保不是超时错误
const isTimeoutError = errorMsg.includes('context deadline exceeded') ||
errorMsg.includes('timeout') ||
errorMsg.includes('超时') ||
errorMsg.includes('deadline exceeded');
if (isCancelledError && !isTimeoutError) {
// 查询已被用户取消,不显示错误消息,清理状态
setResultSets([]);
setActiveResultKey('');
// 清除查询ID与handleCancel保持一致
if (currentQueryIdRef.current) {
clearQueryId();
}
return;
}
const prefix = statements.length > 1 ? `${idx + 1} 条语句执行失败:` : '';
message.error(prefix + res.message);
message.error(res.message);
setResultSets([]);
setActiveResultKey('');
return;
}
if (Array.isArray(res.data)) {
let rows = (res.data as any[]) || [];
let truncated = false;
if (limited.applied && Number.isFinite(maxRows) && maxRows > 0 && rows.length > maxRows) {
truncated = true;
anyTruncated = true;
rows = rows.slice(0, maxRows);
}
const cols = (res.fields && res.fields.length > 0)
? (res.fields as string[])
: (rows.length > 0 ? Object.keys(rows[0]) : []);
// res.data 是 ResultSetData[] 数组
const resultSetDataArray = Array.isArray(res.data) ? (res.data as any[]) : [];
const nextResultSets: ResultSet[] = [];
const maxRows = Number(queryOptions?.maxRows) || 0;
const forceReadOnlyResult = connCaps.forceReadOnlyQueryResult;
let anyTruncated = false;
const pendingPk: Array<{ resultKey: string; tableName: string }> = [];
rows.forEach((row: any, i: number) => {
if (row && typeof row === 'object') row[GONAVI_ROW_KEY] = i;
});
// 前端也拆分语句用于匹配原始 SQL展示和表名检测
const statements = splitSQLStatements(fullSQL);
let simpleTableName: string | undefined = undefined;
const tableMatch = rawStatement.match(/^\s*SELECT\s+\*\s+FROM\s+[`"]?(\w+)[`"]?\s*(?:WHERE.*)?(?:ORDER BY.*)?(?:LIMIT.*)?$/i);
if (tableMatch) {
simpleTableName = tableMatch[1];
if (!forceReadOnlyResult) {
pendingPk.push({ resultKey: `result-${idx + 1}`, tableName: simpleTableName });
}
}
for (let idx = 0; idx < resultSetDataArray.length; idx++) {
const rsData = resultSetDataArray[idx];
const rawStatement = (idx < statements.length) ? statements[idx] : '';
nextResultSets.push({
key: `result-${idx + 1}`,
sql: rawStatement,
exportSql: limited.applied ? applyAutoLimit(rawStatement, dbType, Math.max(1, Number(maxRows) || 1)).sql : rawStatement,
rows,
columns: cols,
tableName: simpleTableName,
pkColumns: [],
readOnly: true,
pkLoading: !!simpleTableName,
truncated
});
} else {
const affected = Number((res.data as any)?.affectedRows);
if (Number.isFinite(affected)) {
const row = { affectedRows: affected };
// 检查是否为 affectedRows 类结果集
const isAffectedResult = Array.isArray(rsData.rows) && rsData.rows.length === 1
&& rsData.columns && rsData.columns.length === 1
&& rsData.columns[0] === 'affectedRows';
if (isAffectedResult) {
const affected = Number(rsData.rows[0]?.affectedRows);
const row = { affectedRows: Number.isFinite(affected) ? affected : 0 };
(row as any)[GONAVI_ROW_KEY] = 0;
nextResultSets.push({
key: `result-${idx + 1}`,
@@ -1312,37 +1371,76 @@ const QueryEditor: React.FC<{ tab: TabData }> = ({ tab }) => {
pkColumns: [],
readOnly: true
});
} else {
let rows = Array.isArray(rsData.rows) ? rsData.rows : [];
let truncated = false;
if (Number.isFinite(maxRows) && maxRows > 0 && rows.length > maxRows) {
truncated = true;
anyTruncated = true;
rows = rows.slice(0, maxRows);
}
const cols = (rsData.columns && rsData.columns.length > 0)
? rsData.columns
: (rows.length > 0 ? Object.keys(rows[0]) : []);
rows.forEach((row: any, i: number) => {
if (row && typeof row === 'object') row[GONAVI_ROW_KEY] = i;
});
let simpleTableName: string | undefined = undefined;
if (rawStatement) {
const tableMatch = rawStatement.match(/^\s*SELECT\s+\*\s+FROM\s+[`"]?(\w+)[`"]?\s*(?:WHERE.*)?(?:ORDER BY.*)?(?:LIMIT.*)?$/i);
if (tableMatch) {
simpleTableName = tableMatch[1];
if (!forceReadOnlyResult) {
pendingPk.push({ resultKey: `result-${idx + 1}`, tableName: simpleTableName });
}
}
}
nextResultSets.push({
key: `result-${idx + 1}`,
sql: rawStatement,
exportSql: rawStatement,
rows,
columns: cols,
tableName: simpleTableName,
pkColumns: [],
readOnly: true,
pkLoading: !!simpleTableName,
truncated
});
}
}
}
setResultSets(nextResultSets);
setActiveResultKey(nextResultSets[0]?.key || '');
setResultSets(nextResultSets);
setActiveResultKey(nextResultSets[0]?.key || '');
pendingPk.forEach(({ resultKey, tableName }) => {
DBGetColumns(config as any, currentDb, tableName)
.then((resCols: any) => {
if (runSeqRef.current !== runSeq) return;
if (!resCols?.success) {
pendingPk.forEach(({ resultKey, tableName }) => {
DBGetColumns(config as any, currentDb, tableName)
.then((resCols: any) => {
if (runSeqRef.current !== runSeq) return;
if (!resCols?.success) {
setResultSets(prev => prev.map(rs => rs.key === resultKey ? { ...rs, pkLoading: false, readOnly: false } : rs));
return;
}
const primaryKeys = (resCols.data as ColumnDefinition[]).filter(c => c.key === 'PRI').map(c => c.name);
setResultSets(prev => prev.map(rs => rs.key === resultKey ? { ...rs, pkColumns: primaryKeys, pkLoading: false, readOnly: false } : rs));
})
.catch(() => {
if (runSeqRef.current !== runSeq) return;
setResultSets(prev => prev.map(rs => rs.key === resultKey ? { ...rs, pkLoading: false, readOnly: false } : rs));
return;
}
const primaryKeys = (resCols.data as ColumnDefinition[]).filter(c => c.key === 'PRI').map(c => c.name);
setResultSets(prev => prev.map(rs => rs.key === resultKey ? { ...rs, pkColumns: primaryKeys, pkLoading: false, readOnly: false } : rs));
})
.catch(() => {
if (runSeqRef.current !== runSeq) return;
setResultSets(prev => prev.map(rs => rs.key === resultKey ? { ...rs, pkLoading: false, readOnly: false } : rs));
});
});
});
});
if (statements.length > 1) {
message.success(`已执行 ${statements.length} 条语句,生成 ${nextResultSets.length} 个结果集。`);
} else if (nextResultSets.length === 0) {
message.success('执行成功。');
}
if (anyTruncated && maxRows > 0) {
message.warning(`结果集已自动限制为最多 ${maxRows} 行(可在工具栏调整)。`);
if (resultSetDataArray.length > 1) {
message.success(`已执行完成,生成 ${nextResultSets.length} 个结果集。`);
} else if (nextResultSets.length === 0) {
message.success('执行成功。');
}
if (anyTruncated && maxRows > 0) {
message.warning(`结果集已自动限制为最多 ${maxRows} 行(可在工具栏调整)。`);
}
}
} catch (e: any) {
message.error("Error executing query: " + e.message);

View File

@@ -41,6 +41,8 @@ export function DBQuery(arg1:connection.ConnectionConfig,arg2:string,arg3:string
export function DBQueryIsolated(arg1:connection.ConnectionConfig,arg2:string,arg3:string):Promise<connection.QueryResult>;
export function DBQueryMulti(arg1:connection.ConnectionConfig,arg2:string,arg3:string,arg4:string):Promise<connection.QueryResult>;
export function DBQueryWithCancel(arg1:connection.ConnectionConfig,arg2:string,arg3:string,arg4:string):Promise<connection.QueryResult>;
export function DBShowCreateTable(arg1:connection.ConnectionConfig,arg2:string,arg3:string):Promise<connection.QueryResult>;

View File

@@ -74,6 +74,10 @@ export function DBQueryIsolated(arg1, arg2, arg3) {
return window['go']['app']['App']['DBQueryIsolated'](arg1, arg2, arg3);
}
export function DBQueryMulti(arg1, arg2, arg3, arg4) {
return window['go']['app']['App']['DBQueryMulti'](arg1, arg2, arg3, arg4);
}
export function DBQueryWithCancel(arg1, arg2, arg3, arg4) {
return window['go']['app']['App']['DBQueryWithCancel'](arg1, arg2, arg3, arg4);
}

View File

@@ -487,6 +487,138 @@ func (a *App) DBQueryWithCancel(config connection.ConnectionConfig, dbName strin
}
}
// DBQueryMulti 执行可能包含多条 SQL 语句的查询,返回多个结果集。
// 如果底层驱动支持 MultiResultQuerier一次性执行所有语句
// 否则按分号拆分后逐条执行,模拟多结果集。
func (a *App) DBQueryMulti(config connection.ConnectionConfig, dbName string, query string, queryID string) connection.QueryResult {
runConfig := normalizeRunConfig(config, dbName)
if queryID == "" {
queryID = generateQueryID()
}
dbInst, err := a.getDatabase(runConfig)
if err != nil {
logger.Error(err, "DBQueryMulti 获取连接失败:%s", formatConnSummary(runConfig))
return connection.QueryResult{Success: false, Message: err.Error(), QueryID: queryID}
}
query = sanitizeSQLForPgLike(runConfig.Type, query)
timeoutSeconds := runConfig.Timeout
if timeoutSeconds <= 0 {
timeoutSeconds = 30
}
ctx, cancel := utils.ContextWithTimeout(time.Duration(timeoutSeconds) * time.Second)
defer cancel()
a.queryMu.Lock()
a.runningQueries[queryID] = queryContext{
cancel: cancel,
started: time.Now(),
}
a.queryMu.Unlock()
defer func() {
a.queryMu.Lock()
delete(a.runningQueries, queryID)
a.queryMu.Unlock()
}()
// 尝试使用驱动原生多结果集支持
runMultiQuery := func(inst db.Database) ([]connection.ResultSetData, error) {
if q, ok := inst.(db.MultiResultQuerierContext); ok {
return q.QueryMultiContext(ctx, query)
}
if q, ok := inst.(db.MultiResultQuerier); ok {
return q.QueryMulti(query)
}
return nil, nil // 返回 nil 表示不支持
}
results, err := runMultiQuery(dbInst)
if err != nil && shouldRefreshCachedConnection(err) {
if a.invalidateCachedDatabase(runConfig, err) {
retryInst, retryErr := a.getDatabaseForcePing(runConfig)
if retryErr != nil {
logger.Error(retryErr, "DBQueryMulti 重建连接失败:%s SQL片段=%q", formatConnSummary(runConfig), sqlSnippet(query))
return connection.QueryResult{Success: false, Message: retryErr.Error(), QueryID: queryID}
}
results, err = runMultiQuery(retryInst)
}
}
if err != nil {
logger.Error(err, "DBQueryMulti 执行失败:%s SQL片段=%q", formatConnSummary(runConfig), sqlSnippet(query))
return connection.QueryResult{Success: false, Message: err.Error(), QueryID: queryID}
}
// 驱动支持多结果集,直接返回
if results != nil {
return connection.QueryResult{Success: true, Data: results, QueryID: queryID}
}
// 驱动不支持多结果集,回退到逐条执行
statements := splitSQLStatements(query)
if len(statements) == 0 {
return connection.QueryResult{
Success: true,
Data: []connection.ResultSetData{},
QueryID: queryID,
}
}
var resultSets []connection.ResultSetData
for _, stmt := range statements {
stmt = strings.TrimSpace(stmt)
if stmt == "" {
continue
}
if isReadOnlySQLQuery(runConfig.Type, stmt) {
var data []map[string]interface{}
var columns []string
if q, ok := dbInst.(interface {
QueryContext(context.Context, string) ([]map[string]interface{}, []string, error)
}); ok {
data, columns, err = q.QueryContext(ctx, stmt)
} else {
data, columns, err = dbInst.Query(stmt)
}
if err != nil {
logger.Error(err, "DBQueryMulti 逐条查询失败:%s SQL片段=%q", formatConnSummary(runConfig), sqlSnippet(stmt))
return connection.QueryResult{Success: false, Message: err.Error(), QueryID: queryID}
}
if data == nil {
data = make([]map[string]interface{}, 0)
}
if columns == nil {
columns = []string{}
}
resultSets = append(resultSets, connection.ResultSetData{Rows: data, Columns: columns})
} else {
var affected int64
if e, ok := dbInst.(interface {
ExecContext(context.Context, string) (int64, error)
}); ok {
affected, err = e.ExecContext(ctx, stmt)
} else {
affected, err = dbInst.Exec(stmt)
}
if err != nil {
logger.Error(err, "DBQueryMulti 逐条执行失败:%s SQL片段=%q", formatConnSummary(runConfig), sqlSnippet(stmt))
return connection.QueryResult{Success: false, Message: err.Error(), QueryID: queryID}
}
resultSets = append(resultSets, connection.ResultSetData{
Rows: []map[string]interface{}{{"affectedRows": affected}},
Columns: []string{"affectedRows"},
})
}
}
if resultSets == nil {
resultSets = []connection.ResultSetData{}
}
return connection.QueryResult{Success: true, Data: resultSets, QueryID: queryID}
}
func (a *App) DBQueryIsolated(config connection.ConnectionConfig, dbName string, query string) connection.QueryResult {
runConfig := normalizeRunConfig(config, dbName)

167
internal/app/sql_split.go Normal file
View File

@@ -0,0 +1,167 @@
package app
import "strings"
// splitSQLStatements 按分号拆分 SQL 文本为独立语句。
// 正确处理单引号/双引号/反引号字符串、行注释(-- / #)、块注释(/* */)和
// PostgreSQL/Kingbase 的 $$...$$ dollar-quoting避免在这些上下文中错误拆分。
func splitSQLStatements(sql string) []string {
text := strings.ReplaceAll(sql, "\r\n", "\n")
var statements []string
cur := ""
inSingle := false
inDouble := false
inBacktick := false
escaped := false
inLineComment := false
inBlockComment := false
var dollarTag string // postgres/kingbase: $$...$$ or $tag$...$tag$
push := func() {
s := strings.TrimSpace(cur)
if s != "" {
statements = append(statements, s)
}
cur = ""
}
for i := 0; i < len(text); i++ {
ch := text[i]
next := byte(0)
if i+1 < len(text) {
next = text[i+1]
}
// 行注释
if inLineComment {
if ch == '\n' {
inLineComment = false
}
cur += string(ch)
continue
}
// 块注释
if inBlockComment {
cur += string(ch)
if ch == '*' && next == '/' {
cur += "/"
i++
inBlockComment = false
}
continue
}
// Dollar-quoting
if dollarTag != "" {
if strings.HasPrefix(text[i:], dollarTag) {
cur += dollarTag
i += len(dollarTag) - 1
dollarTag = ""
} else {
cur += string(ch)
}
continue
}
// 转义字符
if escaped {
escaped = false
cur += string(ch)
continue
}
if (inSingle || inDouble) && ch == '\\' {
escaped = true
cur += string(ch)
continue
}
// 字符串开闭
if !inDouble && !inBacktick && ch == '\'' {
inSingle = !inSingle
cur += string(ch)
continue
}
if !inSingle && !inBacktick && ch == '"' {
inDouble = !inDouble
cur += string(ch)
continue
}
if !inSingle && !inDouble && ch == '`' {
inBacktick = !inBacktick
cur += string(ch)
continue
}
// 在引号/反引号内部不做任何判断
if inSingle || inDouble || inBacktick {
cur += string(ch)
continue
}
// 行注释开始
if ch == '-' && next == '-' {
inLineComment = true
cur += string(ch)
continue
}
if ch == '#' {
inLineComment = true
cur += string(ch)
continue
}
// 块注释开始
if ch == '/' && next == '*' {
inBlockComment = true
cur += "/*"
i++
continue
}
// Dollar-quoting 开始
if ch == '$' {
if tag := parseSQLDollarTag(text[i:]); tag != "" {
dollarTag = tag
cur += tag
i += len(tag) - 1
continue
}
}
// 分号分隔(支持全角分号""
if ch == ';' {
push()
continue
}
// 全角分号 UTF-8 序列: 0xEF 0xBC 0x9B
if ch == 0xEF && i+2 < len(text) && text[i+1] == 0xBC && text[i+2] == 0x9B {
push()
i += 2
continue
}
cur += string(ch)
}
push()
return statements
}
// parseSQLDollarTag 解析 PostgreSQL/Kingbase 的 dollar-quoting 标签。
func parseSQLDollarTag(s string) string {
if len(s) < 2 || s[0] != '$' {
return ""
}
for i := 1; i < len(s); i++ {
c := s[i]
if c == '$' {
return s[:i+1]
}
if !((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || (c >= '0' && c <= '9') || c == '_') {
return ""
}
}
return ""
}

View File

@@ -0,0 +1,94 @@
package app
import (
"reflect"
"testing"
)
func TestSplitSQLStatements_BasicSplit(t *testing.T) {
input := "SELECT 1; SELECT 2; SELECT 3"
got := splitSQLStatements(input)
want := []string{"SELECT 1", "SELECT 2", "SELECT 3"}
if !reflect.DeepEqual(got, want) {
t.Errorf("splitSQLStatements(%q) = %v, want %v", input, got, want)
}
}
func TestSplitSQLStatements_QuotedSemicolon(t *testing.T) {
input := `SELECT 'hello;world'; SELECT 2`
got := splitSQLStatements(input)
want := []string{`SELECT 'hello;world'`, "SELECT 2"}
if !reflect.DeepEqual(got, want) {
t.Errorf("splitSQLStatements(%q) = %v, want %v", input, got, want)
}
}
func TestSplitSQLStatements_LineComment(t *testing.T) {
input := "SELECT 1; -- this is a comment;\nSELECT 2"
got := splitSQLStatements(input)
want := []string{"SELECT 1", "-- this is a comment;\nSELECT 2"}
if !reflect.DeepEqual(got, want) {
t.Errorf("splitSQLStatements(%q) = %v, want %v", input, got, want)
}
}
func TestSplitSQLStatements_BlockComment(t *testing.T) {
input := "SELECT /* ; */ 1; SELECT 2"
got := splitSQLStatements(input)
want := []string{"SELECT /* ; */ 1", "SELECT 2"}
if !reflect.DeepEqual(got, want) {
t.Errorf("splitSQLStatements(%q) = %v, want %v", input, got, want)
}
}
func TestSplitSQLStatements_EmptyInput(t *testing.T) {
got := splitSQLStatements("")
if len(got) != 0 {
t.Errorf("splitSQLStatements(\"\") = %v, want empty slice", got)
}
}
func TestSplitSQLStatements_SingleStatement(t *testing.T) {
input := "SELECT * FROM users WHERE id = 1"
got := splitSQLStatements(input)
want := []string{"SELECT * FROM users WHERE id = 1"}
if !reflect.DeepEqual(got, want) {
t.Errorf("splitSQLStatements(%q) = %v, want %v", input, got, want)
}
}
func TestSplitSQLStatements_DollarQuoting(t *testing.T) {
input := "SELECT $tag$hello;world$tag$; SELECT 2"
got := splitSQLStatements(input)
want := []string{"SELECT $tag$hello;world$tag$", "SELECT 2"}
if !reflect.DeepEqual(got, want) {
t.Errorf("splitSQLStatements(%q) = %v, want %v", input, got, want)
}
}
func TestSplitSQLStatements_FullWidthSemicolon(t *testing.T) {
input := "SELECT 1SELECT 2"
got := splitSQLStatements(input)
want := []string{"SELECT 1", "SELECT 2"}
if !reflect.DeepEqual(got, want) {
t.Errorf("splitSQLStatements(%q) = %v, want %v", input, got, want)
}
}
func TestSplitSQLStatements_Backtick(t *testing.T) {
input := "SELECT `col;name` FROM t; SELECT 2"
got := splitSQLStatements(input)
want := []string{"SELECT `col;name` FROM t", "SELECT 2"}
if !reflect.DeepEqual(got, want) {
t.Errorf("splitSQLStatements(%q) = %v, want %v", input, got, want)
}
}
func TestSplitSQLStatements_TrailingSemicolon(t *testing.T) {
input := "SELECT 1; SELECT 2;"
got := splitSQLStatements(input)
want := []string{"SELECT 1", "SELECT 2"}
if !reflect.DeepEqual(got, want) {
t.Errorf("splitSQLStatements(%q) = %v, want %v", input, got, want)
}
}

View File

@@ -63,6 +63,12 @@ type ConnectionConfig struct {
MongoReplicaPassword string `json:"mongoReplicaPassword,omitempty"` // MongoDB replica auth password
}
// ResultSetData 表示一个查询结果集(行 + 列名),用于多结果集场景。
type ResultSetData struct {
Rows []map[string]interface{} `json:"rows"`
Columns []string `json:"columns"`
}
// QueryResult 是 Wails 绑定方法的统一响应格式,前端通过此结构体接收后端结果。
type QueryResult struct {
Success bool `json:"success"`

View File

@@ -2,6 +2,7 @@ package db
import (
"GoNavi-Wails/internal/connection"
"context"
"fmt"
"strings"
)
@@ -38,6 +39,17 @@ type Database interface {
GetTriggers(dbName, tableName string) ([]connection.TriggerDefinition, error)
}
// MultiResultQuerier 是可选接口,支持多结果集的驱动实现此接口。
// 执行可能包含多条 SQL 语句的查询,返回所有结果集。
type MultiResultQuerier interface {
QueryMulti(query string) ([]connection.ResultSetData, error)
}
// MultiResultQuerierContext 是带 context 的多结果集查询接口。
type MultiResultQuerierContext interface {
QueryMultiContext(ctx context.Context, query string) ([]connection.ResultSetData, error)
}
// BatchApplier 定义了批量变更提交接口。
// 支持批量编辑的驱动实现此接口,用于一次性提交前端 DataGrid 中的增删改操作。
type BatchApplier interface {

View File

@@ -186,7 +186,7 @@ func (m *MySQLDB) getDSN(config connection.ConnectionConfig) (string, error) {
tlsMode := resolveMySQLTLSMode(config)
return fmt.Sprintf(
"%s:%s@%s(%s)/%s?charset=utf8mb4&parseTime=True&loc=Local&timeout=%ds&tls=%s",
"%s:%s@%s(%s)/%s?charset=utf8mb4&parseTime=True&loc=Local&timeout=%ds&tls=%s&multiStatements=true",
config.User, config.Password, protocol, address, database, timeout, url.QueryEscape(tlsMode),
), nil
}
@@ -278,6 +278,30 @@ func (m *MySQLDB) Ping() error {
return m.conn.PingContext(ctx)
}
func (m *MySQLDB) QueryMulti(query string) ([]connection.ResultSetData, error) {
if m.conn == nil {
return nil, fmt.Errorf("连接未打开")
}
rows, err := m.conn.Query(query)
if err != nil {
return nil, err
}
defer rows.Close()
return scanMultiRows(rows)
}
func (m *MySQLDB) QueryMultiContext(ctx context.Context, query string) ([]connection.ResultSetData, error) {
if m.conn == nil {
return nil, fmt.Errorf("连接未打开")
}
rows, err := m.conn.QueryContext(ctx, query)
if err != nil {
return nil, err
}
defer rows.Close()
return scanMultiRows(rows)
}
func (m *MySQLDB) QueryContext(ctx context.Context, query string) ([]map[string]interface{}, []string, error) {
if m.conn == nil {
return nil, nil, fmt.Errorf("连接未打开")

View File

@@ -2,6 +2,8 @@ package db
import (
"database/sql"
"GoNavi-Wails/internal/connection"
)
func scanRows(rows *sql.Rows) ([]map[string]interface{}, []string, error) {
@@ -44,3 +46,38 @@ func scanRows(rows *sql.Rows) ([]map[string]interface{}, []string, error) {
}
return resultData, columns, nil
}
// scanMultiRows 遍历 sql.Rows 中的所有结果集,将每个结果集作为 ResultSetData 返回。
// 利用 rows.NextResultSet() 支持一次 query 返回多个结果集的场景。
func scanMultiRows(rows *sql.Rows) ([]connection.ResultSetData, error) {
var results []connection.ResultSetData
for {
data, cols, err := scanRows(rows)
if err != nil {
return results, err
}
if data == nil {
data = make([]map[string]interface{}, 0)
}
if cols == nil {
cols = []string{}
}
results = append(results, connection.ResultSetData{
Rows: data,
Columns: cols,
})
if !rows.NextResultSet() {
break
}
}
if len(results) == 0 {
results = []connection.ResultSetData{{
Rows: make([]map[string]interface{}, 0),
Columns: []string{},
}}
}
if err := rows.Err(); err != nil {
return results, err
}
return results, nil
}