Merge remote-tracking branch 'origin/dev' into feature/20260602_connection_driver_i18n

This commit is contained in:
tianqijiuyun-latiao
2026-06-23 13:14:43 +08:00
30 changed files with 2095 additions and 320 deletions

View File

@@ -36,6 +36,7 @@ type agentResponse struct {
Error string `json:"error,omitempty"`
Data interface{} `json:"data,omitempty"`
Fields []string `json:"fields,omitempty"`
Messages []string `json:"messages,omitempty"`
ChunkType string `json:"chunkType,omitempty"`
RowsAffected int64 `json:"rowsAffected,omitempty"`
}
@@ -48,6 +49,7 @@ const (
agentMethodOpenSession = "openSession"
agentMethodCloseSession = "closeSession"
agentMethodQuery = "query"
agentMethodQueryMulti = "queryMulti"
agentMethodStreamQuery = "streamQuery"
agentMethodExec = "exec"
agentMethodGetDatabases = "getDatabases"
@@ -64,9 +66,9 @@ const (
const legacyClickHouseDefaultTimeout = 2 * time.Hour
const (
agentChunkColumns = "columns"
agentChunkRows = "rows"
agentChunkDone = "done"
agentChunkColumns = "columns"
agentChunkRows = "rows"
agentChunkDone = "done"
// agentStreamBatchSize 控制 driver-agent 向主进程发送 row chunk 的批次大小。
// 调小到 64单批 JSON 编码 + 主进程解码的瞬时内存峰值降为原来的 1/4
// 代价是 IPC 次数变为 4 倍,但每批仅一次 stdin/stdout 行读写,整体影响可忽略。
@@ -236,12 +238,23 @@ func handleRequest(runtimeState *agentRuntime, req agentRequest) agentResponse {
} else if ok {
switch method {
case agentMethodQuery:
data, fields, err := queryStatementWithOptionalTimeout(session, req.Query, req.TimeoutMs)
data, fields, messages, err := queryStatementWithMessagesOptionalTimeout(session, req.Query, req.TimeoutMs)
if err != nil {
return fail(resp, err.Error())
}
resp.Data = data
resp.Fields = fields
resp.Messages = messages
case agentMethodQueryMulti:
data, messages, supported, err := queryMultiStatementWithMessagesOptionalTimeout(session, req.Query, req.TimeoutMs)
if err != nil {
return fail(resp, err.Error())
}
if !supported {
return fail(resp, "当前事务会话不支持多结果集查询")
}
resp.Data = data
resp.Messages = messages
case agentMethodExec:
affected, err := execStatementWithOptionalTimeout(session, req.Query, req.TimeoutMs)
if err != nil {
@@ -260,12 +273,23 @@ func handleRequest(runtimeState *agentRuntime, req agentRequest) agentResponse {
return fail(resp, err.Error())
}
case agentMethodQuery:
data, fields, err := queryWithOptionalTimeout(runtimeState.inst, req.Query, req.TimeoutMs)
data, fields, messages, err := queryWithMessagesOptionalTimeout(runtimeState.inst, req.Query, req.TimeoutMs)
if err != nil {
return fail(resp, err.Error())
}
resp.Data = data
resp.Fields = fields
resp.Messages = messages
case agentMethodQueryMulti:
data, messages, supported, err := queryMultiWithMessagesOptionalTimeout(runtimeState.inst, req.Query, req.TimeoutMs)
if err != nil {
return fail(resp, err.Error())
}
if !supported {
return fail(resp, "当前驱动不支持原生多结果集查询")
}
resp.Data = data
resp.Messages = messages
case agentMethodExec:
affected, err := execWithOptionalTimeout(runtimeState.inst, req.Query, req.TimeoutMs)
if err != nil {
@@ -581,6 +605,30 @@ type agentQueryContextRunner interface {
QueryContext(context.Context, string) ([]map[string]interface{}, []string, error)
}
type agentQueryMessageRunner interface {
QueryWithMessages(query string) ([]map[string]interface{}, []string, []string, error)
}
type agentQueryMessageContextRunner interface {
QueryContextWithMessages(context.Context, string) ([]map[string]interface{}, []string, []string, error)
}
type agentMultiResultMessageRunner interface {
QueryMultiWithMessages(query string) ([]connection.ResultSetData, []string, error)
}
type agentMultiResultMessageContextRunner interface {
QueryMultiContextWithMessages(context.Context, string) ([]connection.ResultSetData, []string, error)
}
type agentMultiResultRunner interface {
QueryMulti(query string) ([]connection.ResultSetData, error)
}
type agentMultiResultContextRunner interface {
QueryMultiContext(context.Context, string) ([]connection.ResultSetData, error)
}
type agentExecRunner interface {
Exec(string) (int64, error)
}
@@ -589,20 +637,39 @@ type agentExecContextRunner interface {
ExecContext(context.Context, string) (int64, error)
}
func queryWithOptionalTimeout(inst agentQueryRunner, query string, timeoutMs int64) ([]map[string]interface{}, []string, error) {
func queryWithMessagesOptionalTimeout(inst agentQueryRunner, query string, timeoutMs int64) ([]map[string]interface{}, []string, []string, error) {
effectiveTimeoutMs := timeoutMs
if effectiveTimeoutMs <= 0 && strings.EqualFold(strings.TrimSpace(agentDriverType), "clickhouse") {
effectiveTimeoutMs = int64(legacyClickHouseDefaultTimeout / time.Millisecond)
}
if effectiveTimeoutMs <= 0 {
return inst.Query(query)
if q, ok := inst.(agentQueryMessageRunner); ok {
return q.QueryWithMessages(query)
}
data, fields, err := inst.Query(query)
return data, fields, nil, err
}
if q, ok := inst.(agentQueryMessageContextRunner); ok {
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(effectiveTimeoutMs)*time.Millisecond)
defer cancel()
return q.QueryContextWithMessages(ctx, query)
}
if q, ok := inst.(agentQueryContextRunner); ok {
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(effectiveTimeoutMs)*time.Millisecond)
defer cancel()
return q.QueryContext(ctx, query)
data, fields, err := q.QueryContext(ctx, query)
return data, fields, nil, err
}
return inst.Query(query)
if q, ok := inst.(agentQueryMessageRunner); ok {
return q.QueryWithMessages(query)
}
data, fields, err := inst.Query(query)
return data, fields, nil, err
}
func queryWithOptionalTimeout(inst agentQueryRunner, query string, timeoutMs int64) ([]map[string]interface{}, []string, error) {
data, fields, _, err := queryWithMessagesOptionalTimeout(inst, query, timeoutMs)
return data, fields, err
}
func queryStatementWithOptionalTimeout(inst db.StatementExecer, query string, timeoutMs int64) ([]map[string]interface{}, []string, error) {
@@ -613,6 +680,74 @@ func queryStatementWithOptionalTimeout(inst db.StatementExecer, query string, ti
return queryWithOptionalTimeout(queryRunner, query, timeoutMs)
}
func queryStatementWithMessagesOptionalTimeout(inst db.StatementExecer, query string, timeoutMs int64) ([]map[string]interface{}, []string, []string, error) {
queryRunner, ok := inst.(agentQueryRunner)
if !ok {
return nil, nil, nil, fmt.Errorf("当前事务会话不支持查询语句")
}
return queryWithMessagesOptionalTimeout(queryRunner, query, timeoutMs)
}
func queryMultiWithMessagesOptionalTimeout(inst db.Database, query string, timeoutMs int64) ([]connection.ResultSetData, []string, bool, error) {
effectiveTimeoutMs := timeoutMs
if effectiveTimeoutMs <= 0 && strings.EqualFold(strings.TrimSpace(agentDriverType), "clickhouse") {
effectiveTimeoutMs = int64(legacyClickHouseDefaultTimeout / time.Millisecond)
}
if effectiveTimeoutMs > 0 {
if q, ok := inst.(agentMultiResultMessageContextRunner); ok {
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(effectiveTimeoutMs)*time.Millisecond)
defer cancel()
data, messages, err := q.QueryMultiContextWithMessages(ctx, query)
return data, messages, true, err
}
if q, ok := inst.(agentMultiResultContextRunner); ok {
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(effectiveTimeoutMs)*time.Millisecond)
defer cancel()
data, err := q.QueryMultiContext(ctx, query)
return data, nil, true, err
}
}
if q, ok := inst.(agentMultiResultMessageRunner); ok {
data, messages, err := q.QueryMultiWithMessages(query)
return data, messages, true, err
}
if q, ok := inst.(agentMultiResultRunner); ok {
data, err := q.QueryMulti(query)
return data, nil, true, err
}
return nil, nil, false, nil
}
func queryMultiStatementWithMessagesOptionalTimeout(inst db.StatementExecer, query string, timeoutMs int64) ([]connection.ResultSetData, []string, bool, error) {
effectiveTimeoutMs := timeoutMs
if effectiveTimeoutMs <= 0 && strings.EqualFold(strings.TrimSpace(agentDriverType), "clickhouse") {
effectiveTimeoutMs = int64(legacyClickHouseDefaultTimeout / time.Millisecond)
}
if effectiveTimeoutMs > 0 {
if q, ok := inst.(agentMultiResultMessageContextRunner); ok {
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(effectiveTimeoutMs)*time.Millisecond)
defer cancel()
data, messages, err := q.QueryMultiContextWithMessages(ctx, query)
return data, messages, true, err
}
if q, ok := inst.(agentMultiResultContextRunner); ok {
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(effectiveTimeoutMs)*time.Millisecond)
defer cancel()
data, err := q.QueryMultiContext(ctx, query)
return data, nil, true, err
}
}
if q, ok := inst.(agentMultiResultMessageRunner); ok {
data, messages, err := q.QueryMultiWithMessages(query)
return data, messages, true, err
}
if q, ok := inst.(agentMultiResultRunner); ok {
data, err := q.QueryMulti(query)
return data, nil, true, err
}
return nil, nil, false, nil
}
func streamWithOptionalTimeout(inst db.StreamQueryExecer, query string, timeoutMs int64, consumer db.QueryStreamConsumer) error {
effectiveTimeoutMs := timeoutMs
if effectiveTimeoutMs <= 0 && strings.EqualFold(strings.TrimSpace(agentDriverType), "clickhouse") {

View File

@@ -101,6 +101,9 @@ type fakeAgentTimeoutDB struct {
execCalled bool
execContextCalled bool
deadlineSet bool
queryMessages []string
multiResults []connection.ResultSetData
multiMessages []string
}
func (f *fakeAgentTimeoutDB) Connect(config connection.ConnectionConfig) error { return nil }
@@ -117,6 +120,14 @@ func (f *fakeAgentTimeoutDB) QueryContext(ctx context.Context, query string) ([]
}
return []map[string]interface{}{{"ok": 1}}, []string{"ok"}, nil
}
func (f *fakeAgentTimeoutDB) QueryWithMessages(query string) ([]map[string]interface{}, []string, []string, error) {
data, fields, err := f.QueryContext(context.Background(), query)
return data, fields, append([]string(nil), f.queryMessages...), err
}
func (f *fakeAgentTimeoutDB) QueryContextWithMessages(ctx context.Context, query string) ([]map[string]interface{}, []string, []string, error) {
data, fields, err := f.QueryContext(ctx, query)
return data, fields, append([]string(nil), f.queryMessages...), err
}
func (f *fakeAgentTimeoutDB) Exec(query string) (int64, error) {
f.execCalled = true
return 0, errors.New("exec should not be called")
@@ -150,6 +161,15 @@ func (f *fakeAgentTimeoutDB) GetForeignKeys(dbName, tableName string) ([]connect
func (f *fakeAgentTimeoutDB) GetTriggers(dbName, tableName string) ([]connection.TriggerDefinition, error) {
return nil, nil
}
func (f *fakeAgentTimeoutDB) QueryMultiWithMessages(query string) ([]connection.ResultSetData, []string, error) {
return append([]connection.ResultSetData(nil), f.multiResults...), append([]string(nil), f.multiMessages...), nil
}
func (f *fakeAgentTimeoutDB) QueryMultiContextWithMessages(ctx context.Context, query string) ([]connection.ResultSetData, []string, error) {
if _, ok := ctx.Deadline(); ok {
f.deadlineSet = true
}
return f.QueryMultiWithMessages(query)
}
type fakeAgentSessionDB struct {
fakeAgentTimeoutDB
@@ -165,6 +185,7 @@ type fakeAgentStatementSession struct {
queryCalls int
execCalls int
closed bool
messages []string
}
func (f *fakeAgentStatementSession) Query(query string) ([]map[string]interface{}, []string, error) {
@@ -175,6 +196,14 @@ func (f *fakeAgentStatementSession) QueryContext(ctx context.Context, query stri
f.queryCalls++
return []map[string]interface{}{{"session_ok": 1}}, []string{"session_ok"}, nil
}
func (f *fakeAgentStatementSession) QueryWithMessages(query string) ([]map[string]interface{}, []string, []string, error) {
data, fields, err := f.QueryContext(context.Background(), query)
return data, fields, append([]string(nil), f.messages...), err
}
func (f *fakeAgentStatementSession) QueryContextWithMessages(ctx context.Context, query string) ([]map[string]interface{}, []string, []string, error) {
data, fields, err := f.QueryContext(ctx, query)
return data, fields, append([]string(nil), f.messages...), err
}
func (f *fakeAgentStatementSession) Exec(query string) (int64, error) {
return f.ExecContext(context.Background(), query)
@@ -297,6 +326,77 @@ func TestQueryWithOptionalTimeout_ClickHouseLegacyModeUsesQueryContext(t *testin
}
}
func TestHandleRequest_QueryIncludesServerMessages(t *testing.T) {
old := agentDriverType
defer func() { agentDriverType = old }()
agentDriverType = "sqlserver"
fake := &fakeAgentTimeoutDB{
queryMessages: []string{"PRINT sql line 1", "PRINT sql line 2"},
}
runtimeState := &agentRuntime{inst: fake, sessions: make(map[string]db.StatementExecer)}
resp := handleRequest(runtimeState, agentRequest{
ID: 11,
Method: agentMethodQuery,
Query: "exec dbo.p_get_select",
TimeoutMs: int64((2 * time.Second).Milliseconds()),
})
if !resp.Success {
t.Fatalf("query request failed: %s", resp.Error)
}
if len(resp.Messages) != 2 || resp.Messages[0] != "PRINT sql line 1" {
t.Fatalf("expected query messages to be preserved, got %#v", resp.Messages)
}
}
func TestHandleRequest_QueryMultiIncludesResultSetsAndMessages(t *testing.T) {
old := agentDriverType
defer func() { agentDriverType = old }()
agentDriverType = "sqlserver"
fake := &fakeAgentTimeoutDB{
multiResults: []connection.ResultSetData{
{
StatementIndex: 1,
Rows: []map[string]interface{}{{"name": "master"}},
Columns: []string{"name"},
},
{
StatementIndex: 1,
Rows: []map[string]interface{}{},
Columns: []string{},
Messages: []string{"PRINT generated sql"},
},
},
multiMessages: []string{"batch top-level message"},
}
runtimeState := &agentRuntime{inst: fake, sessions: make(map[string]db.StatementExecer)}
resp := handleRequest(runtimeState, agentRequest{
ID: 12,
Method: agentMethodQueryMulti,
Query: "exec dbo.p_get_select",
TimeoutMs: int64((2 * time.Second).Milliseconds()),
})
if !resp.Success {
t.Fatalf("queryMulti request failed: %s", resp.Error)
}
if len(resp.Messages) != 1 || resp.Messages[0] != "batch top-level message" {
t.Fatalf("expected top-level messages to be preserved, got %#v", resp.Messages)
}
resultSets, ok := resp.Data.([]connection.ResultSetData)
if !ok {
t.Fatalf("expected []connection.ResultSetData, got %T", resp.Data)
}
if len(resultSets) != 2 {
t.Fatalf("expected 2 result sets, got %#v", resultSets)
}
if len(resultSets[1].Messages) != 1 || resultSets[1].Messages[0] != "PRINT generated sql" {
t.Fatalf("expected message-only result set to be preserved, got %#v", resultSets[1])
}
}
func TestHandleRequest_UsesPinnedSessionForSessionScopedQueryAndExec(t *testing.T) {
old := agentDriverType
defer func() { agentDriverType = old }()
@@ -329,6 +429,9 @@ func TestHandleRequest_UsesPinnedSessionForSessionScopedQueryAndExec(t *testing.
if !queryResp.Success {
t.Fatalf("session query failed: %s", queryResp.Error)
}
if len(queryResp.Messages) != 0 {
t.Fatalf("expected empty default session messages, got %#v", queryResp.Messages)
}
if fake.queryCalled || fake.queryContextCalled {
t.Fatalf("expected session query to bypass database-level query path, got Query=%v QueryContext=%v", fake.queryCalled, fake.queryContextCalled)
}

View File

@@ -1 +1 @@
1d8f9adbde8018f90d013cc740e0405b
84ec3a6d42105c92f224232f0d83a33b

View File

@@ -12,6 +12,7 @@ import DataGrid, {
import DataGridToolbarFrame from './DataGridToolbarFrame';
import { V2CellContextMenuView, V2ColumnHeaderContextMenuView, V2TableGroupContextMenuView } from './V2TableContextMenu';
import { setCurrentLanguage, t } from '../i18n';
import { parseMongoEditedValue } from '../utils/mongodb';
import { DUCKDB_ROWID_LOCATOR_COLUMN, ORACLE_ROWID_LOCATOR_COLUMN } from '../utils/rowLocator';
const storeState = vi.hoisted(() => ({
@@ -648,6 +649,47 @@ describe('DataGrid commit change set', () => {
});
});
it('keeps MongoDB explicit typed edit values in the final commit payload', () => {
const result = buildDataGridCommitChangeSet({
addedRows: [{
[GONAVI_ROW_KEY]: 'new-1',
_id: '507f1f77bcf86cd799439013',
age: '{"$numberLong":"12"}',
ratio: '1.5',
}],
modifiedRows: {},
deletedRowKeys: new Set(),
data: [],
editLocator: {
strategy: 'primary-key',
columns: ['_id'],
valueColumns: ['_id'],
readOnly: false,
},
visibleColumnNames: ['_id', 'age', 'ratio'],
rowKeyToString,
normalizeCommitCellValue: (columnName, value) => parseMongoEditedValue(
columnName,
value,
columnName === 'ratio' ? { $numberDouble: '0.5' } : undefined,
),
shouldCommitColumn: commitColumnGuard,
});
expect(result).toEqual({
ok: true,
changes: {
inserts: [{
_id: { $oid: '507f1f77bcf86cd799439013' },
age: { $numberLong: '12' },
ratio: { $numberDouble: '1.5' },
}],
updates: [],
deletes: [],
},
});
});
it('fails closed when no safe locator is available', () => {
const result = buildDataGridCommitChangeSet({
addedRows: [],

View File

@@ -158,6 +158,7 @@ import { useDataGridColumnResize } from './useDataGridColumnResize';
import { useDataGridPreviewPanel } from './useDataGridPreviewPanel';
import { buildTableExportTab } from '../utils/tableExportTab';
import { buildDataGridCssText } from './dataGridStyles';
import { formatMongoEditableValue, parseMongoEditedValue } from '../utils/mongodb';
// --- Error Boundary ---
import {
@@ -533,6 +534,7 @@ const DataGrid: React.FC<DataGridProps> = ({
const supportsApproximateTableCount = dataSourceCaps.supportsApproximateTableCount;
const supportsApproximateTotalPages = dataSourceCaps.supportsApproximateTotalPages;
const dbType = dataSourceCaps.type;
const isMongoDBConnection = dbType === 'mongodb';
const isDuckDBConnection = dataSourceCaps.type === 'duckdb';
const supportsCopyInsert = dataSourceCaps.supportsCopyInsert;
const supportsSqlQueryExport = dataSourceCaps.supportsSqlQueryExport;
@@ -544,6 +546,33 @@ const DataGrid: React.FC<DataGridProps> = ({
const filteredExportSql = useMemo(() => String(exportSqlWithFilter || '').trim(), [exportSqlWithFilter]);
const hasFilteredExportSql = exportScope === 'table' && filteredExportSql.length > 0;
const mongoAwareEditableText = useCallback((value: any): string => (
isMongoDBConnection ? formatMongoEditableValue(value) : toEditableText(value)
), [isMongoDBConnection]);
const mongoAwareFormText = useCallback((value: any): string => (
isMongoDBConnection ? formatMongoEditableValue(value) : toFormText(value)
), [isMongoDBConnection]);
const normalizeMongoEditedCellValue = useCallback((columnName: string, value: any, currentValue?: any) => (
isMongoDBConnection ? parseMongoEditedValue(columnName, value, currentValue) : value
), [isMongoDBConnection]);
const normalizeMongoEditedRow = useCallback((row: any, currentRow?: any) => {
if (!isMongoDBConnection || !row || typeof row !== 'object') return row;
let changed = false;
const nextRow: any = { ...row };
Object.keys(row).forEach((columnName) => {
if (columnName === GONAVI_ROW_KEY) return;
const normalizedValue = normalizeMongoEditedCellValue(columnName, row[columnName], currentRow?.[columnName]);
if (normalizedValue !== row[columnName]) {
nextRow[columnName] = normalizedValue;
changed = true;
}
});
return changed ? nextRow : row;
}, [isMongoDBConnection, normalizeMongoEditedCellValue]);
// --- 主题样式变量(仅在 darkMode / opacity / blur 变化时重算) ---
const themeStyles = useMemo(() => {
const _getBg = (darkHex: string) => {
@@ -679,7 +708,7 @@ const DataGrid: React.FC<DataGridProps> = ({
openBatchEditModal,
closeBatchEditModal,
} = useDataGridModalEditors({
toEditableText,
toEditableText: mongoAwareEditableText,
looksLikeJsonText,
});
const [virtualEditingCell, setVirtualEditingCell] = useState<VirtualEditingCellState | null>(null);
@@ -699,7 +728,7 @@ const DataGrid: React.FC<DataGridProps> = ({
updateFocusedCell,
handleDataPanelFormatJson,
} = useDataGridPreviewPanel({
toEditableText,
toEditableText: mongoAwareEditableText,
looksLikeJsonText,
normalizeDateTimeString,
});
@@ -954,6 +983,9 @@ const DataGrid: React.FC<DataGridProps> = ({
const normalizeCommitCellValue = useCallback(
(columnName: string, value: any, mode: 'insert' | 'update') => {
if (value === undefined) return undefined;
if (isMongoDBConnection) {
return parseMongoEditedValue(columnName, value, undefined);
}
const normalizedName = String(columnName || '').trim();
const meta = columnMetaMap[normalizedName] || columnMetaMapByLowerName[normalizedName.toLowerCase()];
const temporal = isTemporalColumnType(meta?.type, dbType);
@@ -977,7 +1009,7 @@ const DataGrid: React.FC<DataGridProps> = ({
return value;
},
[columnMetaMap, columnMetaMapByLowerName, dbType]
[columnMetaMap, columnMetaMapByLowerName, dbType, isMongoDBConnection]
);
const openTableByName = useCallback((nextTableName: string) => {
@@ -1568,19 +1600,23 @@ const DataGrid: React.FC<DataGridProps> = ({
const keyStr = rowKeyStr(rowKey);
const isAdded = addedRows.some(r => r?.[GONAVI_ROW_KEY] === rowKey);
if (isAdded) {
setAddedRows(prev => prev.map(r => r?.[GONAVI_ROW_KEY] === rowKey ? { ...r, ...row } : r));
const currentAddedRow = addedRows.find(r => r?.[GONAVI_ROW_KEY] === rowKey);
const normalizedRow = normalizeMongoEditedRow(row, currentAddedRow);
setAddedRows(prev => prev.map(r => r?.[GONAVI_ROW_KEY] === rowKey ? { ...r, ...normalizedRow } : r));
return;
}
if (deletedRowKeys.has(keyStr)) return;
// 查找原始行数据,对比是否真正有值变更
const originalRow = data.find(r => r?.[GONAVI_ROW_KEY] === rowKey);
if (originalRow) {
const currentRow = modifiedRows[keyStr] ? { ...originalRow, ...modifiedRows[keyStr] } : originalRow;
const normalizedRow = normalizeMongoEditedRow(row, currentRow);
const changedFields: Record<string, any> = {};
for (const col of Object.keys(row)) {
for (const col of Object.keys(normalizedRow)) {
if (col === GONAVI_ROW_KEY) continue;
if (!isWritableResultColumn(col, effectiveEditLocator)) continue;
if (!isCellValueEqualForDiff(originalRow[col], row[col])) {
changedFields[col] = row[col];
if (!isCellValueEqualForDiff(originalRow[col], normalizedRow[col])) {
changedFields[col] = normalizedRow[col];
}
}
if (Object.keys(changedFields).length === 0) {
@@ -1610,9 +1646,9 @@ const DataGrid: React.FC<DataGridProps> = ({
}
return { ...prev, [keyStr]: newCols };
});
setModifiedRows(prev => ({ ...prev, [keyStr]: row }));
setModifiedRows(prev => ({ ...prev, [keyStr]: normalizedRow }));
}
}, [addedRows, data, rowKeyStr, deletedRowKeys, effectiveEditLocator]);
}, [addedRows, data, rowKeyStr, deletedRowKeys, effectiveEditLocator, modifiedRows, normalizeMongoEditedRow]);
const handleDataPanelSave = useCallback(() => {
if (!focusedCellInfo) return;
@@ -1730,7 +1766,9 @@ const DataGrid: React.FC<DataGridProps> = ({
if (isDateTimeField) {
setCellFieldValue(form, fieldName, parseToDayjs(raw, pickerType));
} else {
const initialValue = typeof raw === 'string' ? normalizeDateTimeString(raw) : raw;
const initialValue = isMongoDBConnection
? mongoAwareEditableText(raw)
: (typeof raw === 'string' ? normalizeDateTimeString(raw) : raw);
setCellFieldValue(form, fieldName, initialValue);
}
setVirtualEditingCell({
@@ -1739,7 +1777,7 @@ const DataGrid: React.FC<DataGridProps> = ({
title,
columnType,
});
}, [canModifyData, columnMetaMap, columnMetaMapByLowerName, currentConnConfig, dbType, form, openCellEditor, rowKeyStr]);
}, [canModifyData, columnMetaMap, columnMetaMapByLowerName, currentConnConfig, dbType, form, isMongoDBConnection, mongoAwareEditableText, openCellEditor, rowKeyStr]);
const handleVirtualCellActivate = useCallback((record: Item, dataIndex: string, title: React.ReactNode) => {
if (!canModifyData) return;
@@ -2015,7 +2053,7 @@ const DataGrid: React.FC<DataGridProps> = ({
const baseVal = (baseRow as any)?.[col];
const displayVal = (displayRow as any)?.[col];
baseRawMap[col] = baseVal;
displayMap[col] = toFormText(displayVal);
displayMap[col] = mongoAwareFormText(displayVal);
// 日期时间类型: 将字符串值转为 dayjs 对象供 DatePicker 使用
const colMeta = columnMetaMap[col] || columnMetaMapByLowerName[col.toLowerCase()];
const rowPickerType = getTemporalPickerType(colMeta?.type, dbType, currentConnConfig);
@@ -2023,7 +2061,7 @@ const DataGrid: React.FC<DataGridProps> = ({
const dVal = parseToDayjs(displayVal, rowPickerType);
formMap[col] = dVal;
} else {
formMap[col] = displayVal === null || displayVal === undefined ? undefined : toFormText(displayVal);
formMap[col] = displayVal === null || displayVal === undefined ? undefined : mongoAwareFormText(displayVal);
}
if (baseVal === null || baseVal === undefined) nullCols.add(col);
});
@@ -2035,7 +2073,7 @@ const DataGrid: React.FC<DataGridProps> = ({
nullCols,
formValues: formMap,
});
}, [addedRows, canModifyData, columnMetaMap, columnMetaMapByLowerName, currentConnConfig, data, dbType, mergedDisplayData, openRowEditor, rowKeyStr, translateDataGrid, visibleColumnNames]);
}, [addedRows, canModifyData, columnMetaMap, columnMetaMapByLowerName, currentConnConfig, data, dbType, mergedDisplayData, mongoAwareFormText, openRowEditor, rowKeyStr, translateDataGrid, visibleColumnNames]);
const openCurrentViewRowEditor = useCallback(() => {
if (!canModifyData) return;
@@ -2193,6 +2231,7 @@ const DataGrid: React.FC<DataGridProps> = ({
const keyStr = rowEditorRowKey;
if (!keyStr) return;
const values = rowEditorForm.getFieldsValue(true) || {};
const baseRawMap = rowEditorBaseRawRef.current || {};
const isAdded = addedRows.some(r => rowKeyStr(r?.[GONAVI_ROW_KEY]) === keyStr);
if (isAdded) {
@@ -2200,12 +2239,13 @@ const DataGrid: React.FC<DataGridProps> = ({
const convertedValues: Record<string, any> = {};
Object.entries(values).forEach(([col, val]) => {
if (!isWritableResultColumn(col, effectiveEditLocator)) return;
const baseVal = baseRawMap[col];
if (val && dayjs.isDayjs(val)) {
const colMeta = columnMetaMap[col] || columnMetaMapByLowerName[col.toLowerCase()];
const rowPickerType = getTemporalPickerType(colMeta?.type, dbType, currentConnConfig);
convertedValues[col] = formatFromDayjs(val as dayjs.Dayjs, rowPickerType);
} else {
convertedValues[col] = val;
convertedValues[col] = normalizeMongoEditedCellValue(col, val, baseVal);
}
});
setAddedRows(prev => prev.map(r => rowKeyStr(r?.[GONAVI_ROW_KEY]) === keyStr ? { ...r, ...convertedValues } : r));
@@ -2213,7 +2253,6 @@ const DataGrid: React.FC<DataGridProps> = ({
return;
}
const baseRawMap = rowEditorBaseRawRef.current || {};
const patch: Record<string, any> = {};
visibleColumnNames.forEach((col) => {
if (!isWritableResultColumn(col, effectiveEditLocator)) return;
@@ -2223,6 +2262,8 @@ const DataGrid: React.FC<DataGridProps> = ({
const colMeta = columnMetaMap[col] || columnMetaMapByLowerName[col.toLowerCase()];
const rowPickerType = getTemporalPickerType(colMeta?.type, dbType, currentConnConfig);
nextVal = formatFromDayjs(nextVal as dayjs.Dayjs, rowPickerType);
} else {
nextVal = normalizeMongoEditedCellValue(col, nextVal, baseRawMap[col]);
}
const baseVal = baseRawMap[col];
if (!isCellValueEqualForDiff(baseVal, nextVal)) patch[col] = nextVal;
@@ -2236,7 +2277,7 @@ const DataGrid: React.FC<DataGridProps> = ({
});
closeRowEditor();
}, [addedRows, closeRowEditor, columnMetaMap, columnMetaMapByLowerName, currentConnConfig, dbType, effectiveEditLocator, rowEditorForm, rowEditorRowKey, rowKeyStr, visibleColumnNames]);
}, [addedRows, closeRowEditor, columnMetaMap, columnMetaMapByLowerName, currentConnConfig, dbType, effectiveEditLocator, normalizeMongoEditedCellValue, rowEditorForm, rowEditorRowKey, rowKeyStr, visibleColumnNames]);
const enableVirtual = isTableSurfaceActive;

View File

@@ -73,6 +73,7 @@ import {
} from './dataGridClipboardExport';
import { applyNoAutoCapAttributesWithin, noAutoCapInputProps } from '../utils/inputAutoCap';
import { DEFAULT_SHORTCUT_OPTIONS, getShortcutPlatform, resolveShortcutDisplay } from '../utils/shortcuts';
import { formatMongoValueForDisplay } from '../utils/mongodb';
import {
TEMPORAL_FORMATS,
formatFromDayjs,
@@ -355,6 +356,10 @@ export const formatCellDisplayText = (val: any, columnType?: string, connectionC
if (val === null) return 'NULL';
const bitText = normalizeBitHexDisplayText(val, columnType);
if (bitText !== null) return bitText;
if (String(connectionConfig?.type || '').trim().toLowerCase() === 'mongodb') {
const mongoText = formatMongoValueForDisplay(val);
return mongoText.length > TABLE_CELL_PREVIEW_MAX_CHARS ? `${mongoText.slice(0, TABLE_CELL_PREVIEW_MAX_CHARS)}` : mongoText;
}
if (typeof val === 'object') {
if (!Array.isArray(val) && !isPlainObject(val)) {
return String(val);
@@ -398,6 +403,9 @@ const formatClipboardCellText = (val: any, columnType?: string, connectionConfig
if (val === null || val === undefined) return 'NULL';
const bitText = normalizeBitHexDisplayText(val, columnType);
if (bitText !== null) return bitText;
if (String(connectionConfig?.type || '').trim().toLowerCase() === 'mongodb') {
return formatMongoValueForDisplay(val);
}
if (typeof val === 'string') {
const oceanBaseDateOnly = normalizeOceanBaseOracleDateDisplayText(val, columnType, connectionConfig);
if (oceanBaseDateOnly !== null) return oceanBaseDateOnly;

View File

@@ -164,6 +164,14 @@ describe('DataViewer safe editing locator', () => {
expect(source).toContain('data_viewer.sql_log.phase.sort_buffer_retry');
});
it('caps viewer filter snapshots so long-running sessions do not retain unbounded table state', () => {
const source = readFileSync(new URL('./DataViewer.tsx', import.meta.url), 'utf8');
expect(source).toContain('const MAX_VIEWER_FILTER_SNAPSHOTS = 64;');
expect(source).toContain('const trimViewerFilterSnapshots = () => {');
expect(source).toContain('setViewerFilterSnapshot(normalizedTabId, {');
});
it('enables table preview editing after primary keys are loaded', async () => {
backendApp.DBGetColumns.mockResolvedValue({
success: true,

View File

@@ -289,8 +289,32 @@ type ViewerScrollSnapshot = {
};
const viewerFilterSnapshotsByTab = new Map<string, ViewerFilterSnapshot>();
const MAX_VIEWER_FILTER_SNAPSHOTS = 64;
const VIEWER_SCROLL_SNAPSHOT_PERSIST_DELAY_MS = 160;
const trimViewerFilterSnapshots = () => {
while (viewerFilterSnapshotsByTab.size > MAX_VIEWER_FILTER_SNAPSHOTS) {
const oldestKey = viewerFilterSnapshotsByTab.keys().next().value;
if (!oldestKey) {
break;
}
viewerFilterSnapshotsByTab.delete(oldestKey);
}
};
const setViewerFilterSnapshot = (
tabId: string,
snapshot: ViewerFilterSnapshot,
) => {
const normalizedTabId = String(tabId || '').trim();
if (!normalizedTabId) return;
if (viewerFilterSnapshotsByTab.has(normalizedTabId)) {
viewerFilterSnapshotsByTab.delete(normalizedTabId);
}
viewerFilterSnapshotsByTab.set(normalizedTabId, snapshot);
trimViewerFilterSnapshots();
};
const normalizeViewerFilterConditions = (conditions: FilterCondition[] | undefined): FilterCondition[] => {
if (!Array.isArray(conditions)) return [];
return conditions.map((cond) => ({
@@ -389,7 +413,7 @@ const DataViewer: React.FC<{ tab: TabData; isActive?: boolean }> = React.memo(({
const persistViewerSnapshot = useCallback((tabId: string, overrides?: Partial<ViewerFilterSnapshot>) => {
const normalizedTabId = String(tabId || '').trim();
if (!normalizedTabId) return;
viewerFilterSnapshotsByTab.set(normalizedTabId, {
setViewerFilterSnapshot(normalizedTabId, {
showFilter,
conditions: normalizeViewerFilterConditions(filterConditions),
quickWhereCondition: normalizeQuickWhereCondition(quickWhereCondition),

View File

@@ -2878,6 +2878,62 @@ describe('QueryEditor external SQL save', () => {
expect(messageApi.warning).not.toHaveBeenCalled();
});
it('auto aliases Oracle duplicate explicit columns before alias star expansion', async () => {
storeState.connections[0].config.type = 'oracle';
storeState.connections[0].config.database = 'APP';
backendApp.DBQueryMulti.mockResolvedValueOnce({
success: true,
data: [{
columns: ['EHR_USERID_1', 'USERID', 'EHR_USERID', 'USERNAME'],
rows: [{
EHR_USERID_1: 'emp-1',
USERID: 7,
EHR_USERID: 'emp-1',
USERNAME: 'alice',
}],
}],
});
backendApp.DBGetColumns.mockResolvedValueOnce({
success: true,
data: [
{ name: 'USERID', key: 'PRI' },
{ name: 'EHR_USERID', key: '' },
{ name: 'USERNAME', key: '' },
],
});
let renderer: ReactTestRenderer;
await act(async () => {
renderer = create(<QueryEditor tab={createTab({
dbName: 'APP',
query: 'SELECT EHR_USERID, a.* FROM S_USER_BASE a',
})} />);
});
await act(async () => {
await findButton(renderer!, '运行').props.onClick();
});
await act(async () => {
await Promise.resolve();
await Promise.resolve();
});
expect(dataGridState.latestProps?.readOnly).toBe(false);
expect(dataGridState.latestProps?.editLocator).toMatchObject({
strategy: 'primary-key',
columns: ['USERID'],
valueColumns: ['USERID'],
writableColumns: {
USERID: 'USERID',
EHR_USERID: 'EHR_USERID',
USERNAME: 'USERNAME',
},
readOnly: false,
});
expect(String(backendApp.DBQueryMulti.mock.calls[0][2])).toContain('EHR_USERID AS EHR_USERID_1, a.*');
expect(messageApi.warning).not.toHaveBeenCalled();
});
it.each([
'mysql',
'mariadb',

View File

@@ -2414,9 +2414,6 @@ describe('Sidebar locate toolbar', () => {
const loadTablesStart = source.indexOf('const loadTables = async (node: any) => {');
const loadTablesEnd = source.indexOf('const config = {', loadTablesStart);
const loadTablesSource = source.slice(loadTablesStart, loadTablesEnd);
const externalSqlReadStart = source.indexOf('const externalSQLDirectoryResults = await Promise.all(', loadTablesStart);
const externalSqlReadEnd = source.indexOf('const externalSQLTrees = externalSQLDirectoryResults.reduce', externalSqlReadStart);
const externalSqlReadSource = source.slice(externalSqlReadStart, externalSqlReadEnd);
const externalSqlFlowStart = source.indexOf('const handleAddExternalSQLDirectory = async (node: any) => {');
const externalSqlFlowEnd = source.indexOf('const cancelSQLFileExecution = () => {', externalSqlFlowStart);
const externalSqlFlowSource = source.slice(externalSqlFlowStart, externalSqlFlowEnd);
@@ -2439,8 +2436,6 @@ describe('Sidebar locate toolbar', () => {
[
loadTablesStart,
loadTablesEnd,
externalSqlReadStart,
externalSqlReadEnd,
externalSqlFlowStart,
externalSqlFlowEnd,
treeTitleStart,
@@ -2457,9 +2452,7 @@ describe('Sidebar locate toolbar', () => {
expect(loadTablesSource).toContain("title: t('sidebar.tree.saved_queries')");
expect(loadTablesSource).not.toContain("title: '已存查询'");
expect(externalSqlReadSource).toContain("t('sidebar.message.external_sql_directory_read_failed'");
expect(externalSqlReadSource).not.toContain('SQL 目录读取失败');
expect(source).not.toContain('const externalSQLDirectoryResults = await Promise.all(');
expect(loadTablesSource).not.toContain('SQL 目录读取失败');
expect(loadTablesSource).not.toContain("'SQL目录'");

View File

@@ -164,6 +164,7 @@ import {
resolveSidebarDropInsertBefore,
resolveSidebarDropNodeFromDomEvent,
resolveSidebarDropTargetMetricsFromDomEvent,
resolveSidebarDatabaseTreePruneKeys,
resolveSidebarNodeConnectionId,
resolveSidebarTagDropInsertBefore,
resolveV2ActiveConnectionId,
@@ -190,6 +191,7 @@ export {
resolveSidebarDropInsertBefore,
resolveSidebarDropNodeFromDomEvent,
resolveSidebarDropTargetMetricsFromDomEvent,
resolveSidebarDatabaseTreePruneKeys,
resolveSidebarNodeConnectionId,
resolveSidebarTagDropInsertBefore,
resolveV2ActiveConnectionId,
@@ -205,6 +207,7 @@ export type { V2CommandSearchItem, V2RailConnectionGroup } from './sidebarV2Util
const { Search } = Input;
const SIDEBAR_LOCATE_LOAD_WAIT_INTERVAL_MS = 50;
const SIDEBAR_LOCATE_LOAD_WAIT_ATTEMPTS = 160;
const SIDEBAR_CACHED_DATABASE_TREE_LIMIT = 12;
// resolveV2ObjectGroupTitle 已迁移到 ./sidebar/sidebarHelpers
@@ -506,6 +509,8 @@ const Sidebar: React.FC<{
const [selectedKeys, setSelectedKeys] = useState<React.Key[]>([]);
const selectedNodesRef = useRef<any[]>([]);
const loadingNodesRef = useRef<Set<string>>(new Set());
const databaseTreeTouchedAtRef = useRef<Record<string, number>>({});
const pruneLoadedDatabaseTreesRef = useRef<() => void>(() => {});
const clickTimerRef = useRef<ReturnType<typeof setTimeout> | null>(null);
const treeDragSelectSuppressUntilRef = useRef(0);
const treeDragSelectionSnapshotRef = useRef<{
@@ -544,6 +549,7 @@ const Sidebar: React.FC<{
}, [setActiveContext]);
const openV2CommandSearch = useCallback(() => {
pruneLoadedDatabaseTreesRef.current();
setIsV2CommandSearchOpen(true);
setV2CommandActiveIndex(0);
}, []);
@@ -984,6 +990,55 @@ const Sidebar: React.FC<{
return nextTreeData;
};
const clearTreeNodeChildrenByKeys = useCallback((keysToClear: string[]) => {
const keysToClearSet = new Set(keysToClear.map((key) => String(key || '').trim()).filter(Boolean));
if (keysToClearSet.size === 0) {
return;
}
const clearChildren = (nodes: TreeNode[]): TreeNode[] => (
nodes.map((node) => {
const nodeKey = String(node.key || '').trim();
if (keysToClearSet.has(nodeKey)) {
return { ...node, children: undefined };
}
if (node.children?.length) {
return { ...node, children: clearChildren(node.children) };
}
return node;
})
);
setTreeData((prev) => {
const nextTreeData = clearChildren(prev);
treeDataRef.current = nextTreeData;
return nextTreeData;
});
setLoadedKeys((prev) => prev.filter((key) => !keysToClearSet.has(String(key))));
keysToClearSet.forEach((key) => {
delete databaseTreeTouchedAtRef.current[key];
});
}, []);
const pruneLoadedDatabaseTrees = useCallback(() => {
const activeDatabaseKey = activeContext?.connectionId && activeContext?.dbName
? `${activeContext.connectionId}-${activeContext.dbName}`
: '';
const keysToClear = resolveSidebarDatabaseTreePruneKeys({
treeData: treeDataRef.current,
expandedKeys,
selectedKeys,
activeDatabaseKey,
touchedAtByDatabaseKey: databaseTreeTouchedAtRef.current,
maxLoadedDatabases: SIDEBAR_CACHED_DATABASE_TREE_LIMIT,
});
if (keysToClear.length === 0) {
return;
}
clearTreeNodeChildrenByKeys(keysToClear);
}, [activeContext?.connectionId, activeContext?.dbName, clearTreeNodeChildrenByKeys, expandedKeys, selectedKeys]);
pruneLoadedDatabaseTreesRef.current = pruneLoadedDatabaseTrees;
const mergeExpandedTreeKeys = (requiredKeys: React.Key[]) => {
setExpandedKeys(prev => {
const merged = [...prev];
@@ -1727,7 +1782,6 @@ const Sidebar: React.FC<{
loadTables,
} = useSidebarTreeLoaders({
savedQueries,
externalSQLDirectories,
tableSortPreference,
tableAccessCount,
pinnedSidebarTables,
@@ -1740,7 +1794,10 @@ const Sidebar: React.FC<{
buildJVMRuntimeConfig,
buildJVMDiagnosticTreeNodes,
resolveSavedQueryDisplayName,
decorateExternalSQLTreeNode,
onDatabaseTreeLoaded: (databaseKey: string) => {
databaseTreeTouchedAtRef.current[databaseKey] = Date.now();
pruneLoadedDatabaseTrees();
},
});
const {
@@ -1950,6 +2007,7 @@ const Sidebar: React.FC<{
treeViewportWidth,
treeHeight,
isV2Ui,
isV2CommandSearchOpen,
connections,
connectionIds,
selectedKeys,

View File

@@ -210,7 +210,13 @@ export const getLastIdentifierPart = (path: string): string => {
return parts[parts.length - 1] || '';
};
export const resolveSimpleSelectItemColumn = (item: string): { resultName: string; sourceName: string } | 'all' | undefined => {
export type SelectItemInfo = {
expression: string;
resultName: string;
sourceName?: string;
};
export const resolveSelectItemInfo = (item: string): SelectItemInfo | 'all' | undefined => {
const text = String(item || '').trim();
if (!text) return undefined;
if (text === '*' || /\.\s*\*$/.test(text)) return 'all';
@@ -232,10 +238,16 @@ export const resolveSimpleSelectItemColumn = (item: string): { resultName: strin
}
}
if (!SIMPLE_IDENTIFIER_PATH_RE.test(expr)) return undefined;
const sourceName = getLastIdentifierPart(expr);
if (!alias && !SIMPLE_IDENTIFIER_PATH_RE.test(expr)) return undefined;
const sourceName = SIMPLE_IDENTIFIER_PATH_RE.test(expr) ? getLastIdentifierPart(expr) : '';
const resultName = alias || sourceName;
return sourceName && resultName ? { resultName, sourceName } : undefined;
return resultName ? { expression: expr, resultName, sourceName: sourceName || undefined } : undefined;
};
export const resolveSimpleSelectItemColumn = (item: string): { resultName: string; sourceName: string } | 'all' | undefined => {
const resolved = resolveSelectItemInfo(item);
if (!resolved || resolved === 'all' || !resolved.sourceName) return resolved === 'all' ? 'all' : undefined;
return { resultName: resolved.resultName, sourceName: resolved.sourceName };
};
export const parseSimpleSelectInfo = (sql: string): SimpleSelectInfo | undefined => {
@@ -354,6 +366,57 @@ export const rewriteOracleSelectAllWithExpressions = (sql: string, expressions:
return `${prefix}${finalSelectItems.join(', ')}${fromKeyword}${tableText}${aliasClause}${parsedAlias.remainder}`;
};
export const rewriteOracleDuplicateSelectColumns = (sql: string, tableColumnNames: string[]): string | undefined => {
const metadataNames = new Set(
tableColumnNames
.map((name) => String(name || '').trim().toLowerCase())
.filter(Boolean),
);
if (metadataNames.size === 0) return undefined;
const match = String(sql || '').match(/^(\s*SELECT\s+)([\s\S]+?)(\s+FROM\s+[\s\S]*)$/i);
if (!match) return undefined;
const prefix = match[1];
const selectList = match[2].trim();
const rest = match[3];
const selectItems = splitTopLevelComma(selectList);
if (selectItems.length === 0) return undefined;
const parsedItems = selectItems.map((item) => ({
raw: String(item || '').trimEnd(),
info: resolveSelectItemInfo(item),
}));
const hasWildcard = parsedItems.some(({ info }) => info === 'all');
if (!hasWildcard) return undefined;
const usedResultNames = new Set<string>(metadataNames);
parsedItems.forEach(({ info }) => {
if (!info || info === 'all') return;
const normalizedResult = String(info.resultName || '').trim().toLowerCase();
if (normalizedResult) usedResultNames.add(normalizedResult);
});
let changed = false;
const rewrittenItems = parsedItems.map(({ raw, info }) => {
if (!info || info === 'all') return raw;
const normalizedResult = String(info.resultName || '').trim().toLowerCase();
if (!metadataNames.has(normalizedResult)) return raw;
let nextIndex = 1;
let alias = `${info.resultName}_${nextIndex}`;
while (usedResultNames.has(alias.toLowerCase())) {
nextIndex++;
alias = `${info.resultName}_${nextIndex}`;
}
usedResultNames.add(alias.toLowerCase());
changed = true;
return `${info.expression} AS ${alias}`;
});
return changed ? `${prefix}${rewrittenItems.join(', ')}${rest}` : undefined;
};
export const findWritableResultColumnForSource = (writableColumns: Record<string, string>, target: string): string | undefined => {
const normalizedTarget = String(target || '').trim().toLowerCase();
return Object.entries(writableColumns || {}).find(([, sourceColumn]) => (
@@ -1968,6 +2031,11 @@ export const resolveQueryLocatorPlan = async ({
const tableColumns = resCols.data as ColumnDefinition[];
const tableColumnNames = tableColumns.map(getColumnDefinitionName).filter(Boolean);
let executableStatement = statement;
if (isOracleLikeDialect(dbType) && selectInfo.selectsAll) {
const rewritten = rewriteOracleDuplicateSelectColumns(executableStatement, tableColumnNames);
if (rewritten) executableStatement = rewritten;
}
const primaryKeys = tableColumns
.filter((column: any) => getColumnDefinitionKey(column) === 'PRI')
.map(getColumnDefinitionName)
@@ -2058,7 +2126,7 @@ export const resolveQueryLocatorPlan = async ({
];
if (executableAppendExpressions.length > 0 && isOracleLikeDialect(dbType) && selectInfo.selectsBareAll) {
const rewritten = rewriteOracleSelectAllWithExpressions(statement, executableAppendExpressions);
const rewritten = rewriteOracleSelectAllWithExpressions(executableStatement, executableAppendExpressions);
if (rewritten) {
plan.executedSql = rewritten;
return plan;
@@ -2070,7 +2138,7 @@ export const resolveQueryLocatorPlan = async ({
return plan;
}
plan.executedSql = appendQuerySelectExpressions(statement, executableAppendExpressions);
plan.executedSql = appendQuerySelectExpressions(executableStatement, executableAppendExpressions);
return plan;
} catch {
const reason = translate('query_editor.message.read_only_table_locator_metadata_unavailable', {

View File

@@ -30,6 +30,7 @@ import {
} from './sidebarHelpers';
import type { SearchScope } from '../sidebarCoreUtils';
import {
buildV2CommandSearchTreeIndex,
V2_TREE_HORIZONTAL_SCROLL_BOTTOM_RESERVE,
estimateV2TreeHorizontalScrollWidth,
filterV2CommandSearchTreeItems,
@@ -74,6 +75,7 @@ type SidebarSearchModelArgs = {
treeViewportWidth: number;
treeHeight: number;
isV2Ui: boolean;
isV2CommandSearchOpen: boolean;
connections: SavedConnection[];
connectionIds: string[];
selectedKeys: React.Key[];
@@ -111,6 +113,7 @@ export const useSidebarSearchModel = ({
treeViewportWidth,
treeHeight,
isV2Ui,
isV2CommandSearchOpen,
connections,
connectionIds,
selectedKeys,
@@ -179,6 +182,10 @@ export const useSidebarSearchModel = ({
};
const currentLanguage = getCurrentLanguage();
const connectionById = useMemo(
() => new Map(connections.map((connection) => [connection.id, connection])),
[connections],
);
const searchScopeSummary = useMemo(() => {
if (searchScopes.includes('smart')) {
@@ -360,6 +367,9 @@ export const useSidebarSearchModel = ({
}, [deferredSearchValue, searchScopes, treeData]);
const commandSearchTreeItems = useMemo(() => {
if (!isV2CommandSearchOpen) {
return [];
}
const result: V2CommandSearchItem[] = [];
const visit = (nodes: TreeNode[]) => {
nodes.forEach((node) => {
@@ -375,7 +385,7 @@ export const useSidebarSearchModel = ({
node,
});
} else if (node.type === 'database') {
const conn = connections.find((item) => item.id === dataRef.id);
const conn = connectionById.get(String(dataRef.id || ''));
result.push({
key: `node-${node.key}`,
kind: 'node',
@@ -392,7 +402,7 @@ export const useSidebarSearchModel = ({
|| node.type === 'db-event'
|| node.type === 'routine'
) {
const conn = connections.find((item) => item.id === dataRef.id);
const conn = connectionById.get(String(dataRef.id || ''));
const objectName = String(dataRef.tableName || dataRef.viewName || dataRef.triggerName || dataRef.eventName || dataRef.routineName || node.title || '').trim();
const displayName = String(node.title || extractObjectName(objectName) || objectName).trim();
result.push({
@@ -412,7 +422,11 @@ export const useSidebarSearchModel = ({
visit(treeData);
return result;
}, [connections, treeData]);
}, [connectionById, extractObjectName, isV2CommandSearchOpen, treeData]);
const commandSearchTreeIndex = useMemo(
() => buildV2CommandSearchTreeIndex(commandSearchTreeItems),
[commandSearchTreeItems],
);
const commandSearchRecentItems = useMemo<V2CommandSearchItem[]>(() => {
return sqlLogs.slice(0, 5).map((log) => ({
@@ -473,8 +487,8 @@ export const useSidebarSearchModel = ({
const v2CommandSearchObjectMode = v2CommandSearchQuery.mode === 'object';
const v2CommandSearchAiMode = v2CommandSearchQuery.mode === 'ai';
const filteredCommandSearchTreeItems = useMemo(() => {
return filterV2CommandSearchTreeItems(commandSearchTreeItems, v2CommandSearchQuery);
}, [commandSearchTreeItems, v2CommandSearchQuery]);
return filterV2CommandSearchTreeItems(commandSearchTreeIndex, v2CommandSearchQuery);
}, [commandSearchTreeIndex, v2CommandSearchQuery]);
const filteredCommandSearchActionItems = useMemo(() => {
if (v2CommandSearchObjectMode || v2CommandSearchAiMode) return [];

View File

@@ -13,14 +13,13 @@ import {
TableOutlined,
ThunderboltOutlined,
} from '@ant-design/icons';
import type { SavedConnection, SavedQuery, ExternalSQLDirectory, ExternalSQLTreeEntry, JVMCapability, JVMResourceSummary } from '../../types';
import type { SavedConnection, SavedQuery, JVMCapability, JVMResourceSummary } from '../../types';
import { useStore } from '../../store';
import { t } from '../../i18n';
import { buildRpcConnectionConfig } from '../../utils/connectionRpcConfig';
import { buildRedisDbNodeLabel, getRedisDbAlias } from '../../utils/redisDbAlias';
import { buildJVMMonitoringActionDescriptors } from '../../utils/jvmSidebarActions';
import { type SidebarViewMetadataEntry } from '../../utils/sidebarMetadata';
import { buildExternalSQLRootNode, type ExternalSQLTreeNode } from '../../utils/externalSqlTree';
import {
buildQualifiedName,
buildSidebarObjectKeyName,
@@ -47,7 +46,7 @@ import {
sortSidebarTableEntries,
type SidebarTreeNode as TreeNode,
} from '../sidebarV2Utils';
import { DBGetDatabases, DBGetTables, DBQuery, GetDriverStatusList, JVMProbeCapabilities, ListSQLDirectory } from '../../../wailsjs/go/app/App';
import { DBGetDatabases, DBGetTables, DBQuery, GetDriverStatusList, JVMProbeCapabilities } from '../../../wailsjs/go/app/App';
type DriverStatusSnapshot = {
type: string;
@@ -119,7 +118,6 @@ const resolveSavedConnectionDriverType = (conn: SavedConnection | undefined): st
type UseSidebarTreeLoadersOptions = {
savedQueries: SavedQuery[];
externalSQLDirectories: ExternalSQLDirectory[];
tableSortPreference: Record<string, any>;
tableAccessCount: Record<string, any>;
pinnedSidebarTables: any[];
@@ -132,12 +130,11 @@ type UseSidebarTreeLoadersOptions = {
buildJVMRuntimeConfig: (conn: SavedConnection & { dbName?: string }, providerMode: string) => any;
buildJVMDiagnosticTreeNodes: (conn: SavedConnection) => TreeNode[];
resolveSavedQueryDisplayName: (name: string | null | undefined) => string;
decorateExternalSQLTreeNode: (node: ExternalSQLTreeNode) => TreeNode;
onDatabaseTreeLoaded?: (databaseKey: string) => void;
};
export const useSidebarTreeLoaders = ({
savedQueries,
externalSQLDirectories,
tableSortPreference,
tableAccessCount,
pinnedSidebarTables,
@@ -150,7 +147,7 @@ export const useSidebarTreeLoaders = ({
buildJVMRuntimeConfig,
buildJVMDiagnosticTreeNodes,
resolveSavedQueryDisplayName,
decorateExternalSQLTreeNode,
onDatabaseTreeLoaded,
}: UseSidebarTreeLoadersOptions) => {
const driverStatusCacheRef = useRef<{
fetchedAt: number;
@@ -516,40 +513,6 @@ export const useSidebarTreeLoaders = ({
loadFunctions(conn, conn.dbName),
loadDatabaseEvents(conn, conn.dbName),
]);
const externalSQLDirectoryResults = await Promise.all(
externalSQLDirectories.map(async (directory: ExternalSQLDirectory) => {
const directoryRes = await ListSQLDirectory(directory.path);
if (!directoryRes.success) {
message.warning({
key: `external-sql-${directory.id}`,
content: t('sidebar.message.external_sql_directory_read_failed', {
name: directory.name,
error: directoryRes.message,
}),
});
return { id: directory.id, entries: [] as ExternalSQLTreeEntry[] };
}
return {
id: directory.id,
entries: Array.isArray(directoryRes.data) ? directoryRes.data as ExternalSQLTreeEntry[] : [],
};
}),
);
const externalSQLTrees = externalSQLDirectoryResults.reduce<Record<string, ExternalSQLTreeEntry[]>>((accumulator, item) => {
accumulator[item.id] = item.entries;
return accumulator;
}, {});
const externalSQLRootNode = decorateExternalSQLTreeNode(buildExternalSQLRootNode({
dbNodeKey: String(key),
connectionId: String(conn.id),
dbName: String(conn.dbName),
directories: externalSQLDirectories,
directoryTrees: externalSQLTrees,
labels: {
root: t('sidebar.external_sql.root'),
directoryFallback: t('sidebar.external_sql.directory_fallback'),
},
}));
const viewRows: SidebarViewMetadataEntry[] = Array.isArray(viewsResult.views) ? viewsResult.views : [];
const materializedViewRows: SidebarViewMetadataEntry[] = Array.isArray(materializedViewsResult.views) ? materializedViewsResult.views : [];
const triggerRows: any[] = Array.isArray(triggersResult.triggers) ? triggersResult.triggers : [];
@@ -855,6 +818,7 @@ export const useSidebarTreeLoaders = ({
replaceTreeNodeChildren(key, [queriesNode, ...groupedNodes]);
}
onDatabaseTreeLoaded?.(String(key));
} else {
setConnectionStates(prev => ({ ...prev, [key as string]: 'error' }));
message.error({ content: res.message, key: `db-${key}-tables` });

View File

@@ -0,0 +1,116 @@
import { describe, expect, it } from 'vitest';
import {
V2_COMMAND_SEARCH_INITIAL_TREE_LIMIT,
V2_COMMAND_SEARCH_MAX_TREE_RESULTS,
buildV2CommandSearchTreeIndex,
filterV2CommandSearchTreeItems,
parseV2CommandSearchQuery,
resolveSidebarDatabaseTreePruneKeys,
type V2CommandSearchItem,
} from './sidebarV2Utils';
const buildNodeItems = (count: number): V2CommandSearchItem[] => {
return Array.from({ length: count }, (_, index) => ({
key: `node-table-${index}`,
kind: 'node' as const,
title: `fs_order_${index}`,
meta: `开发240 · front_end_sys_${index % 4}`,
icon: null,
node: {
type: index % 6 === 0 ? 'view' : 'table',
key: `table-${index}`,
title: `fs_order_${index}`,
dataRef: {
tableName: `fs_order_${index}`,
viewName: index % 6 === 0 ? `v_order_${index}` : undefined,
dbName: `front_end_sys_${index % 4}`,
name: `obj_${index}`,
config: {
host: `10.0.0.${index % 16}`,
},
},
},
}));
};
describe('sidebarV2 command search performance helpers', () => {
it('keeps the initial tree result limit when the query is empty', () => {
const items = buildNodeItems(V2_COMMAND_SEARCH_INITIAL_TREE_LIMIT + 80);
expect(
filterV2CommandSearchTreeItems(items, parseV2CommandSearchQuery('')),
).toHaveLength(V2_COMMAND_SEARCH_INITIAL_TREE_LIMIT);
});
it('caps broad keyword matches to avoid rendering the full loaded tree', () => {
const items = buildNodeItems(V2_COMMAND_SEARCH_MAX_TREE_RESULTS + 160);
const result = filterV2CommandSearchTreeItems(
items,
parseV2CommandSearchQuery('fs_order'),
);
expect(result).toHaveLength(V2_COMMAND_SEARCH_MAX_TREE_RESULTS);
expect(result[0]?.key).toBe('node-table-0');
expect(result[result.length - 1]?.key).toBe(`node-table-${V2_COMMAND_SEARCH_MAX_TREE_RESULTS - 1}`);
});
it('returns the same matches when filtering with a prebuilt search index', () => {
const items = buildNodeItems(200);
const index = buildV2CommandSearchTreeIndex(items);
const query = parseV2CommandSearchQuery('@fs_order_1');
expect(filterV2CommandSearchTreeItems(index, query)).toEqual(
filterV2CommandSearchTreeItems(items, query),
);
});
it('prunes only cold collapsed database trees when too many object trees stay loaded', () => {
expect(resolveSidebarDatabaseTreePruneKeys({
treeData: [
{
key: 'conn-1',
title: 'conn-1',
type: 'connection',
children: [
{
key: 'conn-1-db-a',
title: 'db-a',
type: 'database',
children: [{ key: 'a-tables', title: '表', type: 'object-group' }],
},
{
key: 'conn-1-db-b',
title: 'db-b',
type: 'database',
children: [{ key: 'b-tables', title: '表', type: 'object-group' }],
},
{
key: 'conn-1-db-c',
title: 'db-c',
type: 'database',
children: [{ key: 'c-tables', title: '表', type: 'object-group' }],
},
{
key: 'conn-1-db-d',
title: 'db-d',
type: 'database',
children: [{ key: 'd-tables', title: '表', type: 'object-group' }],
},
],
},
],
expandedKeys: ['conn-1-db-c'],
selectedKeys: [],
activeDatabaseKey: 'conn-1-db-d',
touchedAtByDatabaseKey: {
'conn-1-db-a': 10,
'conn-1-db-b': 20,
'conn-1-db-c': 30,
'conn-1-db-d': 40,
},
maxLoadedDatabases: 2,
})).toEqual(['conn-1-db-a', 'conn-1-db-b']);
});
});

View File

@@ -427,6 +427,13 @@ export type V2CommandSearchItem =
dbName?: string;
};
export interface V2CommandSearchTreeIndexEntry {
item: Extract<V2CommandSearchItem, { kind: 'node' }>;
normalizedSearchText: string;
normalizedObjectText: string;
objectNode: boolean;
}
export type V2CommandSearchMode = 'default' | 'object' | 'ai';
export interface V2CommandSearchQuery {
@@ -479,40 +486,69 @@ const isV2CommandSearchObjectNode = (node: SidebarTreeNode): boolean => {
|| node.type === 'materialized-view';
};
const V2_COMMAND_SEARCH_INITIAL_TREE_LIMIT = 24;
export const V2_COMMAND_SEARCH_INITIAL_TREE_LIMIT = 24;
export const V2_COMMAND_SEARCH_MAX_TREE_RESULTS = 120;
export const buildV2CommandSearchTreeIndex = (
items: V2CommandSearchItem[],
): V2CommandSearchTreeIndexEntry[] => {
return items.flatMap((item) => {
if (item.kind !== 'node') {
return [];
}
const dataRef = item.node.dataRef || {};
const normalizedTitle = String(item.title || '').toLowerCase();
const normalizedPrimaryObjectText = String(
dataRef.tableName || dataRef.viewName || item.title || '',
).toLowerCase();
return [{
item,
normalizedSearchText: [
item.title,
item.meta,
dataRef.tableName,
dataRef.viewName,
dataRef.dbName,
dataRef.name,
dataRef.config?.host,
].filter(Boolean).join(' ').toLowerCase(),
normalizedObjectText: `${normalizedPrimaryObjectText} ${normalizedTitle}`.trim(),
objectNode: isV2CommandSearchObjectNode(item.node),
}];
});
};
export const filterV2CommandSearchTreeItems = (
items: V2CommandSearchItem[],
items: V2CommandSearchItem[] | V2CommandSearchTreeIndexEntry[],
query: V2CommandSearchQuery,
): V2CommandSearchItem[] => {
if (query.mode === 'ai') return [];
const index = items.length > 0 && 'item' in items[0]
? items as V2CommandSearchTreeIndexEntry[]
: buildV2CommandSearchTreeIndex(items as V2CommandSearchItem[]);
const normalizedKeyword = query.normalizedKeyword;
const objectMode = query.mode === 'object';
const matchedItems = items.filter((item) => {
if (item.kind !== 'node') return false;
const node = item.node;
const dataRef = node.dataRef || {};
if (objectMode && !isV2CommandSearchObjectNode(node)) {
return false;
const result: V2CommandSearchItem[] = [];
const maxResults = normalizedKeyword
? V2_COMMAND_SEARCH_MAX_TREE_RESULTS
: V2_COMMAND_SEARCH_INITIAL_TREE_LIMIT;
for (const entry of index) {
if (objectMode && !entry.objectNode) {
continue;
}
if (!normalizedKeyword) return true;
const objectName = String(dataRef.tableName || dataRef.viewName || item.title || '').toLowerCase();
if (objectMode) {
return objectName.includes(normalizedKeyword)
|| String(item.title || '').toLowerCase().includes(normalizedKeyword);
if (!normalizedKeyword) {
result.push(entry.item);
} else if (objectMode ? entry.normalizedObjectText.includes(normalizedKeyword) : entry.normalizedSearchText.includes(normalizedKeyword)) {
result.push(entry.item);
}
const haystack = [
item.title,
item.meta,
dataRef.tableName,
dataRef.viewName,
dataRef.dbName,
dataRef.name,
dataRef.config?.host,
].filter(Boolean).join(' ').toLowerCase();
return haystack.includes(normalizedKeyword);
});
return normalizedKeyword ? matchedItems : matchedItems.slice(0, V2_COMMAND_SEARCH_INITIAL_TREE_LIMIT);
if (result.length >= maxResults) {
break;
}
}
return result;
};
export interface V2CommandSearchEnterState {
@@ -777,4 +813,63 @@ export const resolveV2ActiveConnectionId = ({
|| '';
};
export const resolveSidebarDatabaseTreePruneKeys = ({
treeData,
expandedKeys,
selectedKeys,
activeDatabaseKey,
touchedAtByDatabaseKey,
maxLoadedDatabases,
}: {
treeData: SidebarTreeNode[];
expandedKeys: React.Key[];
selectedKeys: React.Key[];
activeDatabaseKey?: string;
touchedAtByDatabaseKey?: Record<string, number>;
maxLoadedDatabases: number;
}): string[] => {
if (!Number.isFinite(maxLoadedDatabases) || maxLoadedDatabases <= 0) {
return [];
}
const loadedDatabaseKeys: string[] = [];
const visit = (nodes: SidebarTreeNode[]) => {
nodes.forEach((node) => {
if (node.type === 'database' && Array.isArray(node.children) && node.children.length > 0) {
loadedDatabaseKeys.push(String(node.key || '').trim());
return;
}
if (node.children?.length) {
visit(node.children);
}
});
};
visit(treeData);
if (loadedDatabaseKeys.length <= maxLoadedDatabases) {
return [];
}
const expandedKeySet = new Set(expandedKeys.map((key) => String(key || '').trim()).filter(Boolean));
const selectedKeySet = new Set(selectedKeys.map((key) => String(key || '').trim()).filter(Boolean));
const protectedDatabaseKeys = new Set<string>();
if (activeDatabaseKey) {
protectedDatabaseKeys.add(String(activeDatabaseKey).trim());
}
const candidates = loadedDatabaseKeys
.filter((key) => key && !expandedKeySet.has(key) && !selectedKeySet.has(key) && !protectedDatabaseKeys.has(key))
.sort((left, right) => {
const leftTouchedAt = Number(touchedAtByDatabaseKey?.[left] || 0);
const rightTouchedAt = Number(touchedAtByDatabaseKey?.[right] || 0);
if (leftTouchedAt !== rightTouchedAt) {
return leftTouchedAt - rightTouchedAt;
}
return left.localeCompare(right);
});
const pruneCount = loadedDatabaseKeys.length - maxLoadedDatabases;
return candidates.slice(0, pruneCount);
};
export const shouldClearSidebarActiveContextOnEmptySelect = (isV2Ui: boolean): boolean => !isV2Ui;

View File

@@ -1337,33 +1337,80 @@ describe('store appearance persistence', () => {
expect(useStore.getState().activeTabId).toBe('query-1');
});
it('persists recent SQL execution logs and trims oversized entries', async () => {
it('keeps only the most recent runtime SQL logs and trims oversized entries', async () => {
const { useStore } = await importStore();
const longSql = `select '${'x'.repeat(120 * 1024)}'`;
const longSql = `select '${'x'.repeat(20 * 1024)}'`;
useStore.getState().addSqlLog({
id: 'log-1',
timestamp: 100,
sql: longSql,
status: 'success',
duration: 12,
for (let i = 0; i < 140; i += 1) {
useStore.getState().addSqlLog({
id: `log-${i}`,
timestamp: 100 + i,
sql: longSql,
status: 'success',
duration: 12 + i,
dbName: 'main',
});
}
expect(useStore.getState().sqlLogs).toHaveLength(120);
expect(useStore.getState().sqlLogs[0]).toEqual(expect.objectContaining({
id: 'log-139',
dbName: 'main',
});
}));
expect(useStore.getState().sqlLogs[119]).toEqual(expect.objectContaining({
id: 'log-20',
}));
expect(useStore.getState().sqlLogs[0]?.sql.length).toBe(12 * 1024);
const persisted = JSON.parse(storage.getItem('lite-db-storage') || '{}');
expect(persisted.state.sqlLogs).toHaveLength(1);
expect(persisted.state.sqlLogs[0].sql.length).toBe(100 * 1024);
expect(persisted.state.sqlLogs).toHaveLength(120);
expect(persisted.state.sqlLogs[0].sql.length).toBe(12 * 1024);
expect(persisted.state.sqlLogs[0].dbName).toBe('main');
vi.resetModules();
const reloaded = await importStore();
expect(reloaded.useStore.getState().sqlLogs[0]).toEqual(expect.objectContaining({
id: 'log-1',
id: 'log-139',
status: 'success',
duration: 12,
duration: 151,
dbName: 'main',
}));
expect(reloaded.useStore.getState().sqlLogs[0]?.sql.length).toBe(100 * 1024);
expect(reloaded.useStore.getState().sqlLogs).toHaveLength(120);
expect(reloaded.useStore.getState().sqlLogs[119]).toEqual(expect.objectContaining({
id: 'log-20',
}));
expect(reloaded.useStore.getState().sqlLogs[0]?.sql.length).toBe(12 * 1024);
});
it('shrinks oversized SQL logs from older persisted snapshots during hydration', async () => {
storage.setItem('lite-db-storage', JSON.stringify({
state: {
sqlLogs: Array.from({ length: 200 }, (_, index) => ({
id: `legacy-log-${index}`,
timestamp: 500 + index,
sql: `select '${'x'.repeat(18 * 1024)}'`,
status: index % 2 === 0 ? 'success' : 'error',
duration: index,
dbName: 'legacy',
message: 'm'.repeat(3 * 1024),
})),
},
version: 12,
}));
const { useStore } = await importStore();
const sqlLogs = useStore.getState().sqlLogs;
expect(sqlLogs).toHaveLength(120);
expect(sqlLogs[0]).toEqual(expect.objectContaining({
id: 'legacy-log-0',
dbName: 'legacy',
}));
expect(sqlLogs[119]).toEqual(expect.objectContaining({
id: 'legacy-log-119',
}));
expect(sqlLogs[0]?.sql.length).toBe(12 * 1024);
expect(sqlLogs[0]?.message?.length).toBe(1024);
});
it('defaults AI chat send shortcut to Enter in shared shortcut options', async () => {

View File

@@ -137,14 +137,16 @@ const MIN_KEEPALIVE_INTERVAL_MINUTES = 1;
const MAX_KEEPALIVE_INTERVAL_MINUTES = 1440;
const DEFAULT_DIAGNOSTIC_TIMEOUT_SECONDS = 15;
const MAX_DIAGNOSTIC_TIMEOUT_SECONDS = 300;
const PERSIST_VERSION = 12;
const PERSIST_VERSION = 13;
const PERSIST_STORAGE_KEY = "lite-db-storage";
const PERSIST_WRITE_DEBOUNCE_MS = 160;
const MAX_PERSISTED_QUERY_TABS = 20;
const MAX_PERSISTED_QUERY_LENGTH = 1024 * 1024;
const MAX_SQL_LOGS = 1000;
const MAX_RUNTIME_SQL_LOGS = 120;
const MAX_RUNTIME_SQL_LOG_LENGTH = 12 * 1024;
const MAX_RUNTIME_SQL_LOG_MESSAGE_LENGTH = 1024;
const MAX_PERSISTED_SQL_LOGS = 200;
const MAX_PERSISTED_SQL_LOG_LENGTH = 100 * 1024;
const MAX_PERSISTED_SQL_LOG_LENGTH = 24 * 1024;
const MAX_PERSISTED_SQL_LOG_MESSAGE_LENGTH = 2 * 1024;
const MAX_TABLE_EXPORT_HISTORY_PER_TARGET = 20;
const MAX_TABLE_EXPORT_HISTORY_TARGETS = 200;
@@ -1725,50 +1727,101 @@ const resolveActiveContextForTabId = (
return fallbackContext;
};
const sanitizeSqlLogs = (value: unknown, limit = MAX_PERSISTED_SQL_LOGS): SqlLog[] => {
type SqlLogSanitizeOptions = {
limit: number;
sqlLength: number;
messageLength: number;
};
const RUNTIME_SQL_LOG_SANITIZE_OPTIONS: SqlLogSanitizeOptions = {
limit: MAX_RUNTIME_SQL_LOGS,
sqlLength: MAX_RUNTIME_SQL_LOG_LENGTH,
messageLength: MAX_RUNTIME_SQL_LOG_MESSAGE_LENGTH,
};
const PERSISTED_SQL_LOG_SANITIZE_OPTIONS: SqlLogSanitizeOptions = {
limit: MAX_PERSISTED_SQL_LOGS,
sqlLength: MAX_PERSISTED_SQL_LOG_LENGTH,
messageLength: MAX_PERSISTED_SQL_LOG_MESSAGE_LENGTH,
};
const sanitizeSqlLogEntry = (
entry: unknown,
index: number,
options: SqlLogSanitizeOptions,
): SqlLog | null => {
if (!entry || typeof entry !== "object") return null;
const raw = entry as Record<string, unknown>;
const sql = typeof raw.sql === "string" ? raw.sql.slice(0, options.sqlLength) : "";
if (!sql.trim()) return null;
const status = raw.status === "error" ? "error" : "success";
const timestamp = Number(raw.timestamp);
const duration = Number(raw.duration);
const affectedRows = Number(raw.affectedRows);
const message = typeof raw.message === "string"
? raw.message.slice(0, options.messageLength)
: "";
const log: SqlLog = {
id: toTrimmedString(raw.id, `log-${index + 1}`) || `log-${index + 1}`,
timestamp: Number.isFinite(timestamp) && timestamp > 0 ? timestamp : Date.now(),
sql,
status,
duration: Number.isFinite(duration) && duration >= 0 ? duration : 0,
dbName: toTrimmedString(raw.dbName) || undefined,
};
if (message) {
log.message = message;
}
if (Number.isFinite(affectedRows)) {
log.affectedRows = affectedRows;
}
return log;
};
const sanitizeSqlLogs = (
value: unknown,
options: SqlLogSanitizeOptions = PERSISTED_SQL_LOG_SANITIZE_OPTIONS,
): SqlLog[] => {
if (!Array.isArray(value)) return [];
const result: SqlLog[] = [];
const seenIds = new Set<string>();
value.forEach((entry, index) => {
if (!entry || typeof entry !== "object") return;
const raw = entry as Record<string, unknown>;
const sql = typeof raw.sql === "string" ? raw.sql.slice(0, MAX_PERSISTED_SQL_LOG_LENGTH) : "";
if (!sql.trim()) return;
const log = sanitizeSqlLogEntry(entry, index, options);
if (!log) return;
let id = toTrimmedString(raw.id, `log-${index + 1}`) || `log-${index + 1}`;
let id = log.id;
if (seenIds.has(id)) {
id = `${id}-${index + 1}`;
}
seenIds.add(id);
const status = raw.status === "error" ? "error" : "success";
const timestamp = Number(raw.timestamp);
const duration = Number(raw.duration);
const affectedRows = Number(raw.affectedRows);
const log: SqlLog = {
id,
timestamp: Number.isFinite(timestamp) && timestamp > 0 ? timestamp : Date.now(),
sql,
status,
duration: Number.isFinite(duration) && duration >= 0 ? duration : 0,
dbName: toTrimmedString(raw.dbName) || undefined,
};
const message = typeof raw.message === "string"
? raw.message.slice(0, MAX_PERSISTED_SQL_LOG_MESSAGE_LENGTH)
: "";
if (message) {
log.message = message;
}
if (Number.isFinite(affectedRows)) {
log.affectedRows = affectedRows;
}
result.push(log);
result.push(id === log.id ? log : { ...log, id });
});
return result.slice(0, limit);
return result.slice(0, options.limit);
};
const sanitizeRuntimeSqlLogs = (value: unknown) =>
sanitizeSqlLogs(value, RUNTIME_SQL_LOG_SANITIZE_OPTIONS);
const sanitizePersistedSqlLogs = (value: unknown) =>
sanitizeSqlLogs(value, PERSISTED_SQL_LOG_SANITIZE_OPTIONS);
const appendRuntimeSqlLog = (existing: SqlLog[], entry: SqlLog): SqlLog[] => {
const nextEntry = sanitizeSqlLogEntry(entry, 0, RUNTIME_SQL_LOG_SANITIZE_OPTIONS);
if (!nextEntry) {
return existing;
}
const nextLogs = [nextEntry, ...existing.slice(0, MAX_RUNTIME_SQL_LOGS - 1)];
return existing.some((item) => item.id === nextEntry.id)
? sanitizeRuntimeSqlLogs(nextLogs)
: nextLogs;
};
const hasLegacyConnectionSecrets = (
@@ -3173,7 +3226,7 @@ export const useStore = create<AppState>()(
}),
addSqlLog: (log) =>
set((state) => ({ sqlLogs: sanitizeSqlLogs([log, ...state.sqlLogs], MAX_SQL_LOGS) })),
set((state) => ({ sqlLogs: appendRuntimeSqlLog(state.sqlLogs, log) })),
clearSqlLogs: () => set({ sqlLogs: [] }),
upsertTableExportHistory: (historyKey, entry) =>
set((state) => {
@@ -3573,7 +3626,7 @@ export const useStore = create<AppState>()(
nextState.shortcutOptions = sanitizeShortcutOptions(
state.shortcutOptions,
);
nextState.sqlLogs = sanitizeSqlLogs(state.sqlLogs);
nextState.sqlLogs = sanitizeRuntimeSqlLogs(state.sqlLogs);
nextState.tableExportHistories = sanitizeTableExportHistories(
state.tableExportHistories,
);
@@ -3686,7 +3739,7 @@ export const useStore = create<AppState>()(
state.sqlEditorTransactionOptions,
),
shortcutOptions: sanitizeShortcutOptions(state.shortcutOptions),
sqlLogs: sanitizeSqlLogs(state.sqlLogs),
sqlLogs: sanitizeRuntimeSqlLogs(state.sqlLogs),
sqlSnippets: sanitizeSqlSnippets(state.sqlSnippets),
tableAccessCount: sanitizeTableAccessCount(state.tableAccessCount),
@@ -3718,7 +3771,7 @@ export const useStore = create<AppState>()(
dataEditTransactionOptions: state.dataEditTransactionOptions,
sqlEditorTransactionOptions: state.sqlEditorTransactionOptions,
shortcutOptions: resolveShortcutOptionsForPersistence(state.shortcutOptions),
sqlLogs: sanitizeSqlLogs(state.sqlLogs),
sqlLogs: sanitizePersistedSqlLogs(state.sqlLogs),
tableExportHistories: sanitizeTableExportHistories(
state.tableExportHistories,
),

View File

@@ -1,6 +1,12 @@
import { describe, expect, it } from 'vitest';
import { applyMongoQueryAutoLimit, buildMongoFindCommand, convertMongoShellToJsonCommand } from './mongodb';
import {
applyMongoQueryAutoLimit,
buildMongoFindCommand,
convertMongoShellToJsonCommand,
formatMongoEditableValue,
parseMongoEditedValue,
} from './mongodb';
const parseCommand = (command: string | undefined) => JSON.parse(command || '{}');
@@ -134,3 +140,37 @@ describe('buildMongoFindCommand', () => {
});
});
});
describe('Mongo edit value helpers', () => {
it('formats common extended JSON wrappers to editable literals', () => {
expect(formatMongoEditableValue({ $oid: '507f1f77bcf86cd799439011' })).toBe('ObjectId("507f1f77bcf86cd799439011")');
expect(formatMongoEditableValue({ $date: { $numberLong: '1719100800000' } })).toBe('ISODate("2024-06-23T00:00:00.000Z")');
expect(formatMongoEditableValue({ $numberInt: '7' })).toBe('NumberInt(7)');
expect(formatMongoEditableValue({ $numberLong: '8' })).toBe('NumberLong("8")');
expect(formatMongoEditableValue({ $numberDouble: '1.5' })).toBe('1.5');
expect(formatMongoEditableValue({ $numberDecimal: '9.99' })).toBe('NumberDecimal("9.99")');
expect(formatMongoEditableValue({
$binary: {
base64: 'EjRWeBI0RniSNFZ4EjRWeA==',
subType: '04',
},
})).toBe('UUID("12345678-1234-4678-9234-567812345678")');
});
it('parses typed Mongo edit text back to extended JSON wrappers', () => {
expect(parseMongoEditedValue('_id', '507f1f77bcf86cd799439011')).toEqual({ $oid: '507f1f77bcf86cd799439011' });
expect(parseMongoEditedValue('createdAt', '2024-06-23T00:00:00.000Z', { $date: { $numberLong: '1719100800000' } })).toEqual({
$date: { $numberLong: '1719100800000' },
});
expect(parseMongoEditedValue('count32', '7', { $numberInt: '1' })).toEqual({ $numberInt: '7' });
expect(parseMongoEditedValue('count64', '8', { $numberLong: '1' })).toEqual({ $numberLong: '8' });
expect(parseMongoEditedValue('ratio', '1.5', { $numberDouble: '0.5' })).toEqual({ $numberDouble: '1.5' });
expect(parseMongoEditedValue('price', '9.99', { $numberDecimal: '1.23' })).toEqual({ $numberDecimal: '9.99' });
expect(parseMongoEditedValue('uid', 'UUID("12345678-1234-4678-9234-567812345678")')).toEqual({
$binary: {
base64: 'EjRWeBI0RniSNFZ4EjRWeA==',
subType: '04',
},
});
});
});

View File

@@ -16,8 +16,168 @@ type ShellConvertResult = {
};
const HEX24_RE = /^[0-9a-fA-F]{24}$/;
const UUID_RE = /^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[1-5][0-9a-fA-F]{3}-[89abAB][0-9a-fA-F]{3}-[0-9a-fA-F]{12}$/;
const INTEGER_RE = /^[+-]?\d+$/;
const FLOAT_RE = /^[+-]?(?:\d+\.\d+|\d+\.|\.\d+)$/;
const SCIENTIFIC_RE = /^[+-]?(?:\d+(?:\.\d+)?|\.\d+)[eE][+-]?\d+$/;
const isPlainMongoObject = (value: unknown): value is Record<string, unknown> => (
!!value && typeof value === 'object' && !Array.isArray(value)
);
const getSingleMongoOperatorEntry = (value: unknown): [string, unknown] | null => {
if (!isPlainMongoObject(value)) return null;
const entries = Object.entries(value);
if (entries.length !== 1) return null;
return entries[0] || null;
};
const byteArrayToBase64 = (bytes: Uint8Array): string => {
const BufferCtor = (globalThis as any)?.Buffer;
if (BufferCtor) {
return BufferCtor.from(bytes).toString('base64');
}
let binary = '';
bytes.forEach((byte) => {
binary += String.fromCharCode(byte);
});
return globalThis.btoa(binary);
};
const base64ToByteArray = (base64: string): Uint8Array => {
const BufferCtor = (globalThis as any)?.Buffer;
if (BufferCtor) {
return Uint8Array.from(BufferCtor.from(base64, 'base64'));
}
const binary = globalThis.atob(base64);
const bytes = new Uint8Array(binary.length);
for (let index = 0; index < binary.length; index += 1) {
bytes[index] = binary.charCodeAt(index);
}
return bytes;
};
const uuidToBytes = (uuid: string): Uint8Array => {
const hex = String(uuid || '').trim().replace(/-/g, '').toLowerCase();
const bytes = new Uint8Array(16);
for (let index = 0; index < 16; index += 1) {
bytes[index] = Number.parseInt(hex.slice(index * 2, index * 2 + 2), 16);
}
return bytes;
};
const bytesToUuid = (bytes: Uint8Array): string => {
const hex = Array.from(bytes).map((byte) => byte.toString(16).padStart(2, '0')).join('');
if (hex.length !== 32) return '';
return [
hex.slice(0, 8),
hex.slice(8, 12),
hex.slice(12, 16),
hex.slice(16, 20),
hex.slice(20, 32),
].join('-');
};
const buildMongoBinaryUUID = (uuidText: string): { $binary: { base64: string; subType: string } } => ({
$binary: {
base64: byteArrayToBase64(uuidToBytes(uuidText)),
subType: '04',
},
});
const buildMongoDateLiteralText = (raw?: unknown): string => {
const millis = typeof raw === 'object' && raw && !Array.isArray(raw)
? parseMongoDateToMillis((raw as Record<string, unknown>)?.$numberLong ?? raw)
: parseMongoDateToMillis(raw);
if (millis !== null) {
return new Date(millis).toISOString();
}
return String(raw ?? '');
};
const buildMongoBinaryLiteralText = (raw: unknown): string | null => {
if (!isPlainMongoObject(raw)) return null;
const binary = raw.$binary;
if (!isPlainMongoObject(binary)) return null;
const subType = String(binary.subType ?? '').trim().toLowerCase();
const base64 = String(binary.base64 ?? '').trim();
if (subType !== '04' || !base64) return null;
try {
const uuidText = bytesToUuid(base64ToByteArray(base64));
return UUID_RE.test(uuidText) ? `UUID("${uuidText}")` : null;
} catch {
return null;
}
};
const looksLikeExplicitMongoTypedLiteral = (raw: string): boolean => (
/^(?:ObjectId|ISODate|NumberInt|NumberLong|NumberDouble|NumberDecimal|UUID|MaxKey|MinKey)\s*\(/i.test(String(raw || '').trim())
);
const looksLikeMongoStructuredLiteral = (raw: string): boolean => {
const text = String(raw || '').trim();
if (!text) return false;
const first = text[0];
const last = text[text.length - 1];
return (first === '{' && last === '}') || (first === '[' && last === ']');
};
type MongoValueKind =
| 'nullish'
| 'string'
| 'boolean'
| 'number'
| 'object'
| 'array'
| 'objectId'
| 'date'
| 'int32'
| 'int64'
| 'double'
| 'decimal128'
| 'uuid'
| 'binary'
| 'maxKey'
| 'minKey';
const resolveMongoValueKind = (value: unknown): MongoValueKind => {
if (value === null || typeof value === 'undefined') return 'nullish';
if (Array.isArray(value)) return 'array';
if (typeof value === 'string') return 'string';
if (typeof value === 'boolean') return 'boolean';
if (typeof value === 'number') return 'number';
const singleEntry = getSingleMongoOperatorEntry(value);
if (singleEntry) {
switch (singleEntry[0]) {
case '$oid':
return 'objectId';
case '$date':
return 'date';
case '$numberInt':
return 'int32';
case '$numberLong':
return 'int64';
case '$numberDouble':
return 'double';
case '$numberDecimal':
return 'decimal128';
case '$binary': {
const binary = singleEntry[1];
if (isPlainMongoObject(binary) && String(binary.subType ?? '').trim().toLowerCase() === '04') {
return 'uuid';
}
return 'binary';
}
case '$maxKey':
return 'maxKey';
case '$minKey':
return 'minKey';
default:
break;
}
}
return typeof value === 'object' ? 'object' : 'string';
};
const escapeRegex = (raw: string) => raw.replace(/[.*+?^${}()|[\]\\]/g, '\\$&');
@@ -69,13 +229,31 @@ const parseBooleanLiteral = (raw: string): boolean | null => {
return null;
};
const normalizeMongoDoubleLiteral = (raw: string): string | null => {
const text = String(raw || '').trim();
if (!text) return null;
const lower = text.toLowerCase();
if (lower === 'nan') return 'NaN';
if (lower === 'infinity' || lower === '+infinity') return 'Infinity';
if (lower === '-infinity') return '-Infinity';
if (INTEGER_RE.test(text) || FLOAT_RE.test(text) || SCIENTIFIC_RE.test(text)) {
const parsed = Number(text);
return Number.isFinite(parsed) ? String(parsed) : null;
}
return null;
};
const normalizeExtendedJSON = (raw: string): string => {
let text = String(raw || '');
text = text.replace(/ObjectId\s*\(\s*["']([0-9a-fA-F]{24})["']\s*\)/g, (_m, oid: string) => JSON.stringify({ $oid: oid }));
text = text.replace(/ISODate\s*\(\s*["']([^"']+)["']\s*\)/g, (_m, dateText: string) => JSON.stringify(buildMongoExtendedDate(dateText)));
text = text.replace(/NumberLong\s*\(\s*["']?([+-]?\d+)["']?\s*\)/g, '{"$numberLong":"$1"}');
text = text.replace(/NumberInt\s*\(\s*["']?([+-]?\d+)["']?\s*\)/g, '{"$numberInt":"$1"}');
text = text.replace(/NumberDouble\s*\(\s*["']?([^"')]+)["']?\s*\)/g, '{"$numberDouble":"$1"}');
text = text.replace(/NumberDecimal\s*\(\s*["']?([+-]?(?:\d+(?:\.\d+)?|\.\d+))["']?\s*\)/g, '{"$numberDecimal":"$1"}');
text = text.replace(/UUID\s*\(\s*["']([0-9a-fA-F-]{36})["']\s*\)/g, (_m, uuidText: string) => JSON.stringify(buildMongoBinaryUUID(uuidText)));
text = text.replace(/MaxKey\s*\(\s*\)/g, '{"$maxKey":1}');
text = text.replace(/MinKey\s*\(\s*\)/g, '{"$minKey":1}');
return text;
};
@@ -130,21 +308,39 @@ const evalMongoLikeLiteral = (raw: string): unknown => {
if (!INTEGER_RE.test(text)) throw new Error(`NumberLong invalid value: ${text}`);
return { $numberLong: text };
};
const NumberDouble = (value: unknown) => {
const normalized = normalizeMongoDoubleLiteral(String(value ?? '').trim());
if (!normalized) throw new Error(`NumberDouble invalid value: ${String(value)}`);
return { $numberDouble: normalized };
};
const NumberDecimal = (value: unknown) => {
const text = String(value ?? '').trim();
if (!text) throw new Error('NumberDecimal invalid value');
return { $numberDecimal: text };
};
const UUID = (value: unknown) => {
const text = String(value ?? '').trim().replace(/^['"]|['"]$/g, '');
if (!UUID_RE.test(text)) {
throw new Error(`UUID invalid value: ${text}`);
}
return buildMongoBinaryUUID(text.toLowerCase());
};
const MaxKey = () => ({ $maxKey: 1 });
const MinKey = () => ({ $minKey: 1 });
const parser = new Function(
'ObjectId',
'ISODate',
'NumberInt',
'NumberLong',
'NumberDouble',
'NumberDecimal',
'UUID',
'MaxKey',
'MinKey',
'"use strict"; return (' + expression + ');',
);
const evaluated = parser(ObjectId, ISODate, NumberInt, NumberLong, NumberDecimal);
const evaluated = parser(ObjectId, ISODate, NumberInt, NumberLong, NumberDouble, NumberDecimal, UUID, MaxKey, MinKey);
return normalizeEvaluatedMongoValue(evaluated);
};
@@ -183,6 +379,135 @@ const parseMongoJSONValue = (raw: string): unknown => {
}
};
export const formatMongoValueForDisplay = (value: unknown): string => {
if (value === null) return 'NULL';
if (typeof value === 'undefined') return '';
const singleEntry = getSingleMongoOperatorEntry(value);
if (singleEntry) {
switch (singleEntry[0]) {
case '$oid':
return `ObjectId("${String(singleEntry[1] ?? '')}")`;
case '$date':
return `ISODate("${buildMongoDateLiteralText(singleEntry[1])}")`;
case '$numberInt':
return `NumberInt(${String(singleEntry[1] ?? '')})`;
case '$numberLong':
return `NumberLong("${String(singleEntry[1] ?? '')}")`;
case '$numberDouble':
return String(singleEntry[1] ?? '');
case '$numberDecimal':
return `NumberDecimal("${String(singleEntry[1] ?? '')}")`;
case '$binary': {
const binaryText = buildMongoBinaryLiteralText(value);
if (binaryText) return binaryText;
break;
}
case '$maxKey':
return 'MaxKey()';
case '$minKey':
return 'MinKey()';
default:
break;
}
}
if (Array.isArray(value) || isPlainMongoObject(value)) {
try {
return JSON.stringify(value);
} catch {
return String(value);
}
}
return String(value);
};
export const formatMongoEditableValue = (value: unknown): string => {
if (value === null || typeof value === 'undefined') return '';
const singleEntry = getSingleMongoOperatorEntry(value);
if (singleEntry) {
return formatMongoValueForDisplay(value);
}
if (Array.isArray(value) || isPlainMongoObject(value)) {
try {
return JSON.stringify(value, null, 2);
} catch {
return String(value);
}
}
return String(value);
};
export const parseMongoEditedValue = (
columnName: string,
rawValue: unknown,
currentValue?: unknown,
): unknown => {
if (typeof rawValue !== 'string') return rawValue;
const currentKind = resolveMongoValueKind(currentValue);
const text = rawValue.trim();
const structuredLiteral = looksLikeMongoStructuredLiteral(rawValue);
const explicitLiteral = looksLikeExplicitMongoTypedLiteral(rawValue);
if (structuredLiteral || explicitLiteral) {
return parseMongoJSONValue(rawValue);
}
switch (currentKind) {
case 'objectId':
if (HEX24_RE.test(text)) return { $oid: text.toLowerCase() };
return rawValue;
case 'date':
if (!text) return rawValue;
return buildMongoExtendedDate(text);
case 'int32':
if (INTEGER_RE.test(text)) return { $numberInt: String(Number.parseInt(text, 10)) };
if (text.toLowerCase() === 'null') return null;
return rawValue;
case 'int64':
if (INTEGER_RE.test(text)) return { $numberLong: text };
if (text.toLowerCase() === 'null') return null;
return rawValue;
case 'double': {
const normalized = normalizeMongoDoubleLiteral(text);
if (normalized !== null) return { $numberDouble: normalized };
if (text.toLowerCase() === 'null') return null;
return rawValue;
}
case 'decimal128':
if (INTEGER_RE.test(text) || FLOAT_RE.test(text)) return { $numberDecimal: text };
if (text.toLowerCase() === 'null') return null;
return rawValue;
case 'boolean': {
const boolValue = parseBooleanLiteral(text);
if (boolValue !== null) return boolValue;
if (text.toLowerCase() === 'null') return null;
return rawValue;
}
case 'number':
if (INTEGER_RE.test(text) || FLOAT_RE.test(text)) {
const parsed = Number(text);
return Number.isFinite(parsed) ? parsed : rawValue;
}
if (text.toLowerCase() === 'null') return null;
return rawValue;
case 'array':
case 'object':
case 'uuid':
case 'binary':
case 'maxKey':
case 'minKey':
if (text.toLowerCase() === 'null') return null;
return rawValue;
case 'string':
case 'nullish':
default:
if (String(columnName || '').trim() === '_id' && HEX24_RE.test(text)) {
return { $oid: text.toLowerCase() };
}
return rawValue;
}
};
const splitTopLevelComma = (raw: string): string[] => {
const text = String(raw || '');
const result: string[] = [];
@@ -1098,4 +1423,3 @@ export const convertMongoShellToJsonCommand = (raw: string): ShellConvertResult
};
}
};

View File

@@ -1401,6 +1401,7 @@ export namespace sync {
targetConfig: connection.ConnectionConfig;
sourceDatabase?: string;
targetDatabase?: string;
targetSchema?: string;
tables: string[];
sourceQuery?: string;
content?: string;
@@ -1422,6 +1423,7 @@ export namespace sync {
this.targetConfig = this.convertValues(source["targetConfig"], connection.ConnectionConfig);
this.sourceDatabase = source["sourceDatabase"];
this.targetDatabase = source["targetDatabase"];
this.targetSchema = source["targetSchema"];
this.tables = source["tables"];
this.sourceQuery = source["sourceQuery"];
this.content = source["content"];

View File

@@ -1062,6 +1062,13 @@ func (a *App) DBQueryMulti(config connection.ConnectionConfig, dbName string, qu
return connection.QueryResult{Success: false, Message: err.Error(), QueryID: queryID}
}
// 某些 optional driver-agent 的原生多结果集路径会异常返回“成功但无任何结果集”。
// 对只读查询这是不可信信号,回退到逐条执行可以避免普通 SELECT 在结果面板中被吃空。
if useNativeMultiResult && allReadOnly && results != nil && len(results) == 0 && len(resultMessages) == 0 {
logger.Warnf("DBQueryMulti 原生多结果集返回空结果,将回退逐条执行:%s SQL片段=%q", formatConnSummary(runConfig), sqlSnippet(query))
results = nil
}
// 驱动支持多结果集,直接返回
if results != nil {
return connection.QueryResult{Success: true, Data: results, Messages: resultMessages, QueryID: queryID}

View File

@@ -34,6 +34,11 @@ type fakeNativeMultiResultDB struct {
multiCalls int
}
type fakeEmptyNativeMultiResultDB struct {
*fakeBatchWriteDB
multiCalls int
}
func (f *fakeNativeMultiResultDB) QueryMulti(query string) ([]connection.ResultSetData, error) {
results, _, err := f.QueryMultiWithMessages(query)
return results, err
@@ -67,6 +72,28 @@ func (f *fakeNativeMultiResultDB) QueryMultiContextWithMessages(ctx context.Cont
}}, append([]string(nil), messages...), nil
}
func (f *fakeEmptyNativeMultiResultDB) QueryMulti(query string) ([]connection.ResultSetData, error) {
results, _, err := f.QueryMultiWithMessages(query)
return results, err
}
func (f *fakeEmptyNativeMultiResultDB) QueryMultiWithMessages(query string) ([]connection.ResultSetData, []string, error) {
return f.QueryMultiContextWithMessages(context.Background(), query)
}
func (f *fakeEmptyNativeMultiResultDB) QueryMultiContext(ctx context.Context, query string) ([]connection.ResultSetData, error) {
results, _, err := f.QueryMultiContextWithMessages(ctx, query)
return results, err
}
func (f *fakeEmptyNativeMultiResultDB) QueryMultiContextWithMessages(ctx context.Context, query string) ([]connection.ResultSetData, []string, error) {
f.multiCalls++
if err := f.queryErr[query]; err != nil {
return nil, nil, err
}
return []connection.ResultSetData{}, nil, nil
}
func (f *fakeBatchWriteDB) Connect(config connection.ConnectionConfig) error {
return nil
}
@@ -1332,6 +1359,57 @@ func TestDBQueryMultiRunsSQLServerStatisticsBatchNatively(t *testing.T) {
}
}
func TestDBQueryMultiFallsBackWhenNativeReadOnlyBatchReturnsEmptyResults(t *testing.T) {
originalNewDatabaseFunc := newDatabaseFunc
t.Cleanup(func() {
newDatabaseFunc = originalNewDatabaseFunc
})
query := "SELECT 1 AS value"
baseDB := &fakeBatchWriteDB{
queryMap: map[string][]map[string]interface{}{
query: {
{"value": 1},
},
},
fieldMap: map[string][]string{
query: {"value"},
},
queryErr: map[string]error{},
}
fakeDB := &fakeEmptyNativeMultiResultDB{fakeBatchWriteDB: baseDB}
newDatabaseFunc = func(dbType string) (db.Database, error) {
return fakeDB, nil
}
app := NewAppWithSecretStore(secretstore.NewUnavailableStore("test"))
config := connection.ConnectionConfig{Type: "sqlserver", Host: "127.0.0.1", Port: 1433, User: "sa"}
result := app.DBQueryMulti(config, "master", query, "sqlserver-empty-native-read-fallback-test")
if !result.Success {
t.Fatalf("expected DBQueryMulti success, got failure: %s", result.Message)
}
if fakeDB.multiCalls != 1 {
t.Fatalf("expected one native multi-result attempt, got %d", fakeDB.multiCalls)
}
if baseDB.session == nil {
t.Fatal("expected empty native result to fall back to pinned session query")
}
if baseDB.session.queryCalls != 1 {
t.Fatalf("expected fallback to query through pinned session once, got %d", baseDB.session.queryCalls)
}
resultSets, ok := result.Data.([]connection.ResultSetData)
if !ok {
t.Fatalf("expected []connection.ResultSetData, got %T", result.Data)
}
if len(resultSets) != 1 {
t.Fatalf("expected one fallback result set, got %#v", resultSets)
}
if got := resultSets[0].Rows[0]["value"]; got != 1 {
t.Fatalf("expected fallback SELECT result value=1, got %#v", got)
}
}
func TestDBQueryMultiUsesPinnedSessionForSequentialFallback(t *testing.T) {
originalNewDatabaseFunc := newDatabaseFunc
t.Cleanup(func() {

View File

@@ -4,26 +4,26 @@ package db
func init() {
optionalDriverAgentRevisions = map[string]string{
"mariadb": "src-b23e2ce1581a5064",
"oceanbase": "src-5067dbdf0ca7b9c4",
"diros": "src-db43faca6bf15d9b",
"starrocks": "src-01e9f06c0fab09d5",
"sphinx": "src-38ee5cae952cc809",
"sqlserver": "src-7a87f6deb816f110",
"sqlite": "src-d3d439cd788880e2",
"duckdb": "src-b11506b8706bfb73",
"dameng": "src-1638124bfd7fce09",
"kingbase": "src-fb3a404cf4eb1bd9",
"highgo": "src-72fe51afa884f6bc",
"vastbase": "src-3d48607603bfd8b7",
"opengauss": "src-709acf442f016e30",
"gaussdb": "src-f6beccc924d71031",
"iris": "src-9ebf5b970a73b341",
"mongodb": "src-367d11cd04e982c1",
"tdengine": "src-3c13c42f18ba01e1",
"iotdb": "src-5ba9da13c6a272f9",
"clickhouse": "src-99c8babfefdf142c",
"elasticsearch": "src-36b2e2b5f49db9d1",
"trino": "src-d264ceca132c185c",
"mariadb": "src-cc133d2524ceb634",
"oceanbase": "src-ac17327184366ff0",
"diros": "src-7d4fe439271d0c56",
"starrocks": "src-ce9ee22641a32f46",
"sphinx": "src-08f5ae54efb3d9df",
"sqlserver": "src-33b3b2c6dad5b3e6",
"sqlite": "src-96dfa25b3042b2d5",
"duckdb": "src-8804eb2cdbc89433",
"dameng": "src-016e77082aea6718",
"kingbase": "src-17728b2ebda94dc9",
"highgo": "src-da2e8a9d2e661d3b",
"vastbase": "src-da186ac367206c16",
"opengauss": "src-54dc852e4c502947",
"gaussdb": "src-3bbbffc6991dc8ae",
"iris": "src-e798713e492e9a09",
"mongodb": "src-2610395b35c2e708",
"tdengine": "src-779b9b537f08856f",
"iotdb": "src-7edea4aba8d4869e",
"clickhouse": "src-0197342ca5afa8b5",
"elasticsearch": "src-08e8e80cb17a409a",
"trino": "src-ba947f211ce7b19f",
}
}

View File

@@ -4,6 +4,7 @@ package db
import (
"context"
"encoding/json"
"fmt"
"net"
"net/url"
@@ -1058,7 +1059,16 @@ func (m *MongoDB) execCount(ctx context.Context, cmd bson.D) ([]map[string]inter
// convertBsonValue 将 BSON 特殊类型转换为前端可读的 JSON 友好值
func convertBsonValue(v interface{}) interface{} {
switch val := v.(type) {
case map[string]interface{}:
result := make(map[string]interface{}, len(val))
for k, v2 := range val {
result[k] = convertBsonValue(v2)
}
return result
case bson.ObjectID:
if converted, ok := encodeMongoExtendedJSONFieldValue(val); ok {
return converted
}
return val.Hex()
case bson.M:
result := make(map[string]interface{}, len(val))
@@ -1078,11 +1088,75 @@ func convertBsonValue(v interface{}) interface{} {
result[i] = convertBsonValue(v2)
}
return result
case []interface{}:
result := make([]interface{}, len(val))
for i, v2 := range val {
result[i] = convertBsonValue(v2)
}
return result
default:
if !shouldEncodeMongoExtendedJSONFieldValue(v) {
return v
}
if converted, ok := encodeMongoExtendedJSONFieldValue(v); ok {
return converted
}
return v
}
}
func shouldEncodeMongoExtendedJSONFieldValue(v interface{}) bool {
switch v.(type) {
case bson.DateTime,
bson.Decimal128,
bson.Binary,
bson.Regex,
bson.Timestamp,
bson.MaxKey,
bson.MinKey,
bson.Undefined,
int32,
int64,
[]byte,
time.Time:
return true
default:
return false
}
}
func encodeMongoExtendedJSONFieldValue(v interface{}) (interface{}, bool) {
payload, err := bson.MarshalExtJSON(bson.M{"v": v}, true, false)
if err != nil {
return nil, false
}
var wrapped map[string]interface{}
if err := json.Unmarshal(payload, &wrapped); err != nil {
return nil, false
}
converted, ok := wrapped["v"]
return converted, ok
}
func decodeMongoExtendedJSONFieldValue(v interface{}) interface{} {
payload, err := json.Marshal(map[string]interface{}{"v": v})
if err != nil {
return v
}
var wrapped bson.M
if err := bson.UnmarshalExtJSON(payload, false, &wrapped); err != nil {
return v
}
if converted, ok := wrapped["v"]; ok {
return converted
}
return v
}
func (m *MongoDB) Exec(query string) (int64, error) {
_, _, err := m.Query(query)
if err != nil {
@@ -1220,7 +1294,7 @@ func (m *MongoDB) GetTriggers(dbName, tableName string) ([]connection.TriggerDef
func copyMongoChangeDocument(row map[string]interface{}) bson.M {
doc := bson.M{}
for k, v := range row {
doc[k] = v
doc[k] = decodeMongoExtendedJSONFieldValue(v)
}
return doc
}
@@ -1228,46 +1302,11 @@ func copyMongoChangeDocument(row map[string]interface{}) bson.M {
func buildMongoChangeFilter(row map[string]interface{}) bson.M {
filter := bson.M{}
for k, v := range row {
filter[k] = normalizeMongoChangeFilterValue(k, v)
filter[k] = decodeMongoExtendedJSONFieldValue(v)
}
return filter
}
func normalizeMongoChangeFilterValue(key string, value interface{}) interface{} {
if strings.TrimSpace(key) != "_id" {
return value
}
switch val := value.(type) {
case map[string]interface{}:
if raw, ok := val["$oid"]; ok {
if oid, parsed := parseMongoObjectIDHex(fmt.Sprintf("%v", raw)); parsed {
return oid
}
}
case bson.M:
if raw, ok := val["$oid"]; ok {
if oid, parsed := parseMongoObjectIDHex(fmt.Sprintf("%v", raw)); parsed {
return oid
}
}
}
return value
}
func parseMongoObjectIDHex(value string) (bson.ObjectID, bool) {
text := strings.TrimSpace(value)
var zero bson.ObjectID
if len(text) != 24 {
return zero, false
}
oid, err := bson.ObjectIDFromHex(text)
if err != nil {
return zero, false
}
return oid, true
}
// ApplyChanges implements batch changes for MongoDB
func (m *MongoDB) ApplyChanges(tableName string, changes connection.ChangeSet) error {
if m.client == nil {
@@ -1300,10 +1339,7 @@ func (m *MongoDB) ApplyChanges(tableName string, changes connection.ChangeSet) e
return fmt.Errorf("更新操作需要主键条件")
}
updateDoc := bson.M{"$set": bson.M{}}
for k, v := range update.Values {
updateDoc["$set"].(bson.M)[k] = v
}
updateDoc := bson.M{"$set": copyMongoChangeDocument(update.Values)}
result, err := collection.UpdateOne(ctx, filter, updateDoc)
if err != nil {

View File

@@ -128,3 +128,138 @@ func TestCopyMongoChangeDocument_LeavesInsertIDStringUntouched(t *testing.T) {
t.Fatalf("insert _id string should stay string, got %T %v", doc["_id"], doc["_id"])
}
}
func TestConvertBsonValue_EncodesMongoTypedValues(t *testing.T) {
const oidHex = "507f1f77bcf86cd799439011"
oid, err := bson.ObjectIDFromHex(oidHex)
if err != nil {
t.Fatal(err)
}
decimalValue, err := bson.ParseDecimal128("12.34")
if err != nil {
t.Fatal(err)
}
converted, ok := convertBsonValue(bson.M{
"_id": oid,
"createdAt": bson.DateTime(1719100800000),
"count32": int32(7),
"count64": int64(8),
"ratio": 1.5,
"price": decimalValue,
"uid": bson.Binary{
Subtype: 0x04,
Data: []byte{0x12, 0x34, 0x56, 0x78, 0x12, 0x34, 0x56, 0x78, 0x12, 0x34, 0x56, 0x78, 0x12, 0x34, 0x56, 0x78},
},
"nested": bson.M{
"innerId": oid,
},
"items": bson.A{int32(1), int64(2)},
}).(map[string]interface{})
if !ok {
t.Fatalf("expected converted document map, got %T", converted)
}
if converted["_id"].(map[string]interface{})["$oid"] != oidHex {
t.Fatalf("unexpected ObjectID wrapper: %#v", converted["_id"])
}
if converted["createdAt"].(map[string]interface{})["$date"].(map[string]interface{})["$numberLong"] != "1719100800000" {
t.Fatalf("unexpected date wrapper: %#v", converted["createdAt"])
}
if converted["count32"].(map[string]interface{})["$numberInt"] != "7" {
t.Fatalf("unexpected int32 wrapper: %#v", converted["count32"])
}
if converted["count64"].(map[string]interface{})["$numberLong"] != "8" {
t.Fatalf("unexpected int64 wrapper: %#v", converted["count64"])
}
if converted["ratio"] != 1.5 {
t.Fatalf("plain double should stay float64, got %T %#v", converted["ratio"], converted["ratio"])
}
if converted["price"].(map[string]interface{})["$numberDecimal"] != "12.34" {
t.Fatalf("unexpected decimal wrapper: %#v", converted["price"])
}
if converted["uid"].(map[string]interface{})["$binary"].(map[string]interface{})["base64"] != "EjRWeBI0VngSNFZ4EjRWeA==" {
t.Fatalf("unexpected binary wrapper: %#v", converted["uid"])
}
nestedDoc, ok := converted["nested"].(map[string]interface{})
if !ok {
t.Fatalf("expected nested map, got %T", converted["nested"])
}
if nestedDoc["innerId"].(map[string]interface{})["$oid"] != oidHex {
t.Fatalf("unexpected nested ObjectID wrapper: %#v", nestedDoc["innerId"])
}
items, ok := converted["items"].([]interface{})
if !ok || len(items) != 2 {
t.Fatalf("unexpected items wrapper: %#v", converted["items"])
}
if items[0].(map[string]interface{})["$numberInt"] != "1" || items[1].(map[string]interface{})["$numberLong"] != "2" {
t.Fatalf("unexpected numeric array wrappers: %#v", items)
}
}
func TestCopyMongoChangeDocument_DecodesExtendedJSONWrappers(t *testing.T) {
doc := copyMongoChangeDocument(map[string]interface{}{
"_id": map[string]interface{}{"$oid": "507f1f77bcf86cd799439011"},
"createdAt": map[string]interface{}{"$date": map[string]interface{}{"$numberLong": "1719100800000"}},
"count32": map[string]interface{}{"$numberInt": "7"},
"count64": map[string]interface{}{"$numberLong": "8"},
"ratio": map[string]interface{}{"$numberDouble": "1.5"},
"price": map[string]interface{}{"$numberDecimal": "12.34"},
"uid": map[string]interface{}{
"$binary": map[string]interface{}{
"base64": "EjRWeBI0VngSNFZ4EjRWeA==",
"subType": "04",
},
},
"nested": map[string]interface{}{
"innerId": map[string]interface{}{"$oid": "507f1f77bcf86cd799439012"},
},
"items": []interface{}{
map[string]interface{}{"$numberInt": "1"},
map[string]interface{}{"$numberLong": "2"},
},
})
if _, ok := doc["_id"].(bson.ObjectID); !ok {
t.Fatalf("expected _id to decode to bson.ObjectID, got %T", doc["_id"])
}
if got, ok := doc["createdAt"].(bson.DateTime); !ok || got != bson.DateTime(1719100800000) {
t.Fatalf("expected createdAt bson.DateTime, got %T %#v", doc["createdAt"], doc["createdAt"])
}
if got, ok := doc["count32"].(int32); !ok || got != 7 {
t.Fatalf("expected count32 int32, got %T %#v", doc["count32"], doc["count32"])
}
if got, ok := doc["count64"].(int64); !ok || got != 8 {
t.Fatalf("expected count64 int64, got %T %#v", doc["count64"], doc["count64"])
}
if got, ok := doc["ratio"].(float64); !ok || got != 1.5 {
t.Fatalf("expected ratio float64, got %T %#v", doc["ratio"], doc["ratio"])
}
if _, ok := doc["price"].(bson.Decimal128); !ok {
t.Fatalf("expected price bson.Decimal128, got %T", doc["price"])
}
if binaryValue, ok := doc["uid"].(bson.Binary); !ok || binaryValue.Subtype != 0x04 || len(binaryValue.Data) != 16 {
t.Fatalf("expected uid bson.Binary UUID, got %T %#v", doc["uid"], doc["uid"])
}
nestedDoc, ok := doc["nested"].(bson.D)
if !ok || len(nestedDoc) != 1 || nestedDoc[0].Key != "innerId" {
t.Fatalf("expected nested bson.D, got %T %#v", doc["nested"], doc["nested"])
}
if _, ok := nestedDoc[0].Value.(bson.ObjectID); !ok {
t.Fatalf("expected nested innerId ObjectID, got %T", nestedDoc[0].Value)
}
items, ok := doc["items"].(bson.A)
if !ok || len(items) != 2 {
t.Fatalf("expected items bson.A, got %T %#v", doc["items"], doc["items"])
}
if got, ok := items[0].(int32); !ok || got != 1 {
t.Fatalf("expected items[0] int32, got %T %#v", items[0], items[0])
}
if got, ok := items[1].(int64); !ok || got != 2 {
t.Fatalf("expected items[1] int64, got %T %#v", items[1], items[1])
}
}

View File

@@ -4,6 +4,7 @@ package db
import (
"context"
"encoding/json"
"fmt"
"net"
"net/url"
@@ -1061,7 +1062,16 @@ func (m *MongoDBV1) execCount(ctx context.Context, cmd bson.D) ([]map[string]int
// convertBsonValue 将 BSON 特殊类型转换为前端可读的 JSON 友好值
func convertBsonValue(v interface{}) interface{} {
switch val := v.(type) {
case map[string]interface{}:
result := make(map[string]interface{}, len(val))
for k, v2 := range val {
result[k] = convertBsonValue(v2)
}
return result
case primitive.ObjectID:
if converted, ok := encodeMongoExtendedJSONFieldValue(val); ok {
return converted
}
return val.Hex()
case bson.M:
result := make(map[string]interface{}, len(val))
@@ -1081,11 +1091,75 @@ func convertBsonValue(v interface{}) interface{} {
result[i] = convertBsonValue(v2)
}
return result
case []interface{}:
result := make([]interface{}, len(val))
for i, v2 := range val {
result[i] = convertBsonValue(v2)
}
return result
default:
if !shouldEncodeMongoExtendedJSONFieldValue(v) {
return v
}
if converted, ok := encodeMongoExtendedJSONFieldValue(v); ok {
return converted
}
return v
}
}
func shouldEncodeMongoExtendedJSONFieldValue(v interface{}) bool {
switch v.(type) {
case primitive.DateTime,
primitive.Decimal128,
primitive.Binary,
primitive.Regex,
primitive.Timestamp,
primitive.MaxKey,
primitive.MinKey,
primitive.Undefined,
int32,
int64,
[]byte,
time.Time:
return true
default:
return false
}
}
func encodeMongoExtendedJSONFieldValue(v interface{}) (interface{}, bool) {
payload, err := bson.MarshalExtJSON(bson.M{"v": v}, true, false)
if err != nil {
return nil, false
}
var wrapped map[string]interface{}
if err := json.Unmarshal(payload, &wrapped); err != nil {
return nil, false
}
converted, ok := wrapped["v"]
return converted, ok
}
func decodeMongoExtendedJSONFieldValue(v interface{}) interface{} {
payload, err := json.Marshal(map[string]interface{}{"v": v})
if err != nil {
return v
}
var wrapped bson.M
if err := bson.UnmarshalExtJSON(payload, false, &wrapped); err != nil {
return v
}
if converted, ok := wrapped["v"]; ok {
return converted
}
return v
}
func (m *MongoDBV1) Exec(query string) (int64, error) {
_, _, err := m.Query(query)
if err != nil {
@@ -1223,7 +1297,7 @@ func (m *MongoDBV1) GetTriggers(dbName, tableName string) ([]connection.TriggerD
func copyMongoChangeDocument(row map[string]interface{}) bson.M {
doc := bson.M{}
for k, v := range row {
doc[k] = v
doc[k] = decodeMongoExtendedJSONFieldValue(v)
}
return doc
}
@@ -1231,46 +1305,11 @@ func copyMongoChangeDocument(row map[string]interface{}) bson.M {
func buildMongoChangeFilter(row map[string]interface{}) bson.M {
filter := bson.M{}
for k, v := range row {
filter[k] = normalizeMongoChangeFilterValue(k, v)
filter[k] = decodeMongoExtendedJSONFieldValue(v)
}
return filter
}
func normalizeMongoChangeFilterValue(key string, value interface{}) interface{} {
if strings.TrimSpace(key) != "_id" {
return value
}
switch val := value.(type) {
case map[string]interface{}:
if raw, ok := val["$oid"]; ok {
if oid, parsed := parseMongoObjectIDHex(fmt.Sprintf("%v", raw)); parsed {
return oid
}
}
case bson.M:
if raw, ok := val["$oid"]; ok {
if oid, parsed := parseMongoObjectIDHex(fmt.Sprintf("%v", raw)); parsed {
return oid
}
}
}
return value
}
func parseMongoObjectIDHex(value string) (primitive.ObjectID, bool) {
text := strings.TrimSpace(value)
var zero primitive.ObjectID
if len(text) != 24 {
return zero, false
}
oid, err := primitive.ObjectIDFromHex(text)
if err != nil {
return zero, false
}
return oid, true
}
// ApplyChanges implements batch changes for MongoDB
func (m *MongoDBV1) ApplyChanges(tableName string, changes connection.ChangeSet) error {
if m.client == nil {
@@ -1303,10 +1342,7 @@ func (m *MongoDBV1) ApplyChanges(tableName string, changes connection.ChangeSet)
return fmt.Errorf("更新操作需要主键条件")
}
updateDoc := bson.M{"$set": bson.M{}}
for k, v := range update.Values {
updateDoc["$set"].(bson.M)[k] = v
}
updateDoc := bson.M{"$set": copyMongoChangeDocument(update.Values)}
result, err := collection.UpdateOne(ctx, filter, updateDoc)
if err != nil {

View File

@@ -87,3 +87,138 @@ func TestCopyMongoChangeDocumentV1_LeavesInsertIDStringUntouched(t *testing.T) {
t.Fatalf("insert _id string should stay string, got %T %v", doc["_id"], doc["_id"])
}
}
func TestConvertBsonValueV1_EncodesMongoTypedValues(t *testing.T) {
const oidHex = "507f1f77bcf86cd799439011"
oid, err := primitive.ObjectIDFromHex(oidHex)
if err != nil {
t.Fatal(err)
}
decimalValue, err := primitive.ParseDecimal128("12.34")
if err != nil {
t.Fatal(err)
}
converted, ok := convertBsonValue(bson.M{
"_id": oid,
"createdAt": primitive.DateTime(1719100800000),
"count32": int32(7),
"count64": int64(8),
"ratio": 1.5,
"price": decimalValue,
"uid": primitive.Binary{
Subtype: 0x04,
Data: []byte{0x12, 0x34, 0x56, 0x78, 0x12, 0x34, 0x56, 0x78, 0x12, 0x34, 0x56, 0x78, 0x12, 0x34, 0x56, 0x78},
},
"nested": bson.M{
"innerId": oid,
},
"items": bson.A{int32(1), int64(2)},
}).(map[string]interface{})
if !ok {
t.Fatalf("expected converted document map, got %T", converted)
}
if converted["_id"].(map[string]interface{})["$oid"] != oidHex {
t.Fatalf("unexpected ObjectID wrapper: %#v", converted["_id"])
}
if converted["createdAt"].(map[string]interface{})["$date"].(map[string]interface{})["$numberLong"] != "1719100800000" {
t.Fatalf("unexpected date wrapper: %#v", converted["createdAt"])
}
if converted["count32"].(map[string]interface{})["$numberInt"] != "7" {
t.Fatalf("unexpected int32 wrapper: %#v", converted["count32"])
}
if converted["count64"].(map[string]interface{})["$numberLong"] != "8" {
t.Fatalf("unexpected int64 wrapper: %#v", converted["count64"])
}
if converted["ratio"] != 1.5 {
t.Fatalf("plain double should stay float64, got %T %#v", converted["ratio"], converted["ratio"])
}
if converted["price"].(map[string]interface{})["$numberDecimal"] != "12.34" {
t.Fatalf("unexpected decimal wrapper: %#v", converted["price"])
}
if converted["uid"].(map[string]interface{})["$binary"].(map[string]interface{})["base64"] != "EjRWeBI0VngSNFZ4EjRWeA==" {
t.Fatalf("unexpected binary wrapper: %#v", converted["uid"])
}
nestedDoc, ok := converted["nested"].(map[string]interface{})
if !ok {
t.Fatalf("expected nested map, got %T", converted["nested"])
}
if nestedDoc["innerId"].(map[string]interface{})["$oid"] != oidHex {
t.Fatalf("unexpected nested ObjectID wrapper: %#v", nestedDoc["innerId"])
}
items, ok := converted["items"].([]interface{})
if !ok || len(items) != 2 {
t.Fatalf("unexpected items wrapper: %#v", converted["items"])
}
if items[0].(map[string]interface{})["$numberInt"] != "1" || items[1].(map[string]interface{})["$numberLong"] != "2" {
t.Fatalf("unexpected numeric array wrappers: %#v", items)
}
}
func TestCopyMongoChangeDocumentV1_DecodesExtendedJSONWrappers(t *testing.T) {
doc := copyMongoChangeDocument(map[string]interface{}{
"_id": map[string]interface{}{"$oid": "507f1f77bcf86cd799439011"},
"createdAt": map[string]interface{}{"$date": map[string]interface{}{"$numberLong": "1719100800000"}},
"count32": map[string]interface{}{"$numberInt": "7"},
"count64": map[string]interface{}{"$numberLong": "8"},
"ratio": map[string]interface{}{"$numberDouble": "1.5"},
"price": map[string]interface{}{"$numberDecimal": "12.34"},
"uid": map[string]interface{}{
"$binary": map[string]interface{}{
"base64": "EjRWeBI0VngSNFZ4EjRWeA==",
"subType": "04",
},
},
"nested": map[string]interface{}{
"innerId": map[string]interface{}{"$oid": "507f1f77bcf86cd799439012"},
},
"items": []interface{}{
map[string]interface{}{"$numberInt": "1"},
map[string]interface{}{"$numberLong": "2"},
},
})
if _, ok := doc["_id"].(primitive.ObjectID); !ok {
t.Fatalf("expected _id to decode to primitive.ObjectID, got %T", doc["_id"])
}
if got, ok := doc["createdAt"].(primitive.DateTime); !ok || got != primitive.DateTime(1719100800000) {
t.Fatalf("expected createdAt primitive.DateTime, got %T %#v", doc["createdAt"], doc["createdAt"])
}
if got, ok := doc["count32"].(int32); !ok || got != 7 {
t.Fatalf("expected count32 int32, got %T %#v", doc["count32"], doc["count32"])
}
if got, ok := doc["count64"].(int64); !ok || got != 8 {
t.Fatalf("expected count64 int64, got %T %#v", doc["count64"], doc["count64"])
}
if got, ok := doc["ratio"].(float64); !ok || got != 1.5 {
t.Fatalf("expected ratio float64, got %T %#v", doc["ratio"], doc["ratio"])
}
if _, ok := doc["price"].(primitive.Decimal128); !ok {
t.Fatalf("expected price primitive.Decimal128, got %T", doc["price"])
}
if binaryValue, ok := doc["uid"].(primitive.Binary); !ok || binaryValue.Subtype != 0x04 || len(binaryValue.Data) != 16 {
t.Fatalf("expected uid primitive.Binary UUID, got %T %#v", doc["uid"], doc["uid"])
}
nestedDoc, ok := doc["nested"].(primitive.M)
if !ok {
t.Fatalf("expected nested primitive.M, got %T %#v", doc["nested"], doc["nested"])
}
if _, ok := nestedDoc["innerId"].(primitive.ObjectID); !ok {
t.Fatalf("expected nested innerId ObjectID, got %T", nestedDoc["innerId"])
}
items, ok := doc["items"].(bson.A)
if !ok || len(items) != 2 {
t.Fatalf("expected items bson.A, got %T %#v", doc["items"], doc["items"])
}
if got, ok := items[0].(int32); !ok || got != 1 {
t.Fatalf("expected items[0] int32, got %T %#v", items[0], items[0])
}
if got, ok := items[1].(int64); !ok || got != 2 {
t.Fatalf("expected items[1] int64, got %T %#v", items[1], items[1])
}
}

View File

@@ -29,6 +29,7 @@ const (
optionalAgentMethodOpenSession = "openSession"
optionalAgentMethodCloseSession = "closeSession"
optionalAgentMethodQuery = "query"
optionalAgentMethodQueryMulti = "queryMulti"
optionalAgentMethodStreamQuery = "streamQuery"
optionalAgentMethodExec = "exec"
optionalAgentMethodGetDatabases = "getDatabases"
@@ -75,6 +76,7 @@ type optionalAgentResponse struct {
Error string `json:"error,omitempty"`
Data json.RawMessage `json:"data,omitempty"`
Fields []string `json:"fields,omitempty"`
Messages []string `json:"messages,omitempty"`
ChunkType string `json:"chunkType,omitempty"`
RowsAffected int64 `json:"rowsAffected,omitempty"`
}
@@ -106,7 +108,7 @@ func ProbeOptionalDriverAgentMetadata(driverType string, executablePath string)
}()
var metadata OptionalDriverAgentMetadata
if err := client.callWithTimeout(optionalAgentRequest{Method: optionalAgentMethodMetadata}, &metadata, nil, nil, optionalAgentMetadataProbeTimeout); err != nil {
if err := client.callWithTimeout(optionalAgentRequest{Method: optionalAgentMethodMetadata}, &metadata, nil, nil, nil, optionalAgentMetadataProbeTimeout); err != nil {
return OptionalDriverAgentMetadata{}, err
}
metadata.DriverType = normalizeRuntimeDriverType(metadata.DriverType)
@@ -208,7 +210,7 @@ func (c *optionalDriverAgentClient) stderrText() string {
return strings.TrimSpace(c.stderr.String())
}
func (c *optionalDriverAgentClient) call(req optionalAgentRequest, out interface{}, fields *[]string, rowsAffected *int64) error {
func (c *optionalDriverAgentClient) call(req optionalAgentRequest, out interface{}, fields *[]string, messages *[]string, rowsAffected *int64) error {
c.mu.Lock()
defer c.mu.Unlock()
@@ -252,6 +254,9 @@ func (c *optionalDriverAgentClient) call(req optionalAgentRequest, out interface
if fields != nil {
*fields = resp.Fields
}
if messages != nil {
*messages = append((*messages)[:0], resp.Messages...)
}
if rowsAffected != nil {
*rowsAffected = resp.RowsAffected
}
@@ -263,14 +268,14 @@ func (c *optionalDriverAgentClient) call(req optionalAgentRequest, out interface
return nil
}
func (c *optionalDriverAgentClient) callWithTimeout(req optionalAgentRequest, out interface{}, fields *[]string, rowsAffected *int64, timeout time.Duration) error {
func (c *optionalDriverAgentClient) callWithTimeout(req optionalAgentRequest, out interface{}, fields *[]string, messages *[]string, rowsAffected *int64, timeout time.Duration) error {
if timeout <= 0 {
return c.call(req, out, fields, rowsAffected)
return c.call(req, out, fields, messages, rowsAffected)
}
errCh := make(chan error, 1)
go func() {
errCh <- c.call(req, out, fields, rowsAffected)
errCh <- c.call(req, out, fields, messages, rowsAffected)
}()
timer := time.NewTimer(timeout)
@@ -469,7 +474,7 @@ func (d *OptionalDriverAgentDB) Connect(config connection.ConnectionConfig) erro
if err := client.call(optionalAgentRequest{
Method: optionalAgentMethodConnect,
Config: &config,
}, nil, nil, nil); err != nil {
}, nil, nil, nil, nil); err != nil {
_ = client.close()
return err
}
@@ -482,7 +487,7 @@ func (d *OptionalDriverAgentDB) Close() error {
if d.client == nil {
return nil
}
_ = d.client.call(optionalAgentRequest{Method: optionalAgentMethodClose}, nil, nil, nil)
_ = d.client.call(optionalAgentRequest{Method: optionalAgentMethodClose}, nil, nil, nil, nil)
err := d.client.close()
d.client = nil
return err
@@ -493,10 +498,87 @@ func (d *OptionalDriverAgentDB) Ping() error {
if err != nil {
return err
}
return client.call(optionalAgentRequest{Method: optionalAgentMethodPing}, nil, nil, nil)
return client.call(optionalAgentRequest{Method: optionalAgentMethodPing}, nil, nil, nil, nil)
}
func (d *OptionalDriverAgentDB) QueryContext(ctx context.Context, query string) ([]map[string]interface{}, []string, error) {
data, fields, _, err := d.QueryContextWithMessages(ctx, query)
return data, fields, err
}
func (d *OptionalDriverAgentDB) QueryContextWithMessages(ctx context.Context, query string) ([]map[string]interface{}, []string, []string, error) {
if err := ctx.Err(); err != nil {
return nil, nil, nil, err
}
client, err := d.requireClient()
if err != nil {
return nil, nil, nil, err
}
var data []map[string]interface{}
var fields []string
var messages []string
if err := client.call(optionalAgentRequest{
Method: optionalAgentMethodQuery,
Query: query,
TimeoutMs: timeoutMsFromContext(ctx),
}, &data, &fields, &messages, nil); err != nil {
return nil, nil, nil, err
}
return data, fields, messages, nil
}
func (d *OptionalDriverAgentDB) Query(query string) ([]map[string]interface{}, []string, error) {
data, fields, _, err := d.QueryWithMessages(query)
return data, fields, err
}
func (d *OptionalDriverAgentDB) QueryWithMessages(query string) ([]map[string]interface{}, []string, []string, error) {
client, err := d.requireClient()
if err != nil {
return nil, nil, nil, err
}
var data []map[string]interface{}
var fields []string
var messages []string
if err := client.call(optionalAgentRequest{
Method: optionalAgentMethodQuery,
Query: query,
}, &data, &fields, &messages, nil); err != nil {
return nil, nil, nil, err
}
return data, fields, messages, nil
}
func (d *OptionalDriverAgentDB) QueryMulti(query string) ([]connection.ResultSetData, error) {
results, _, err := d.QueryMultiWithMessages(query)
return results, err
}
func (d *OptionalDriverAgentDB) QueryMultiWithMessages(query string) ([]connection.ResultSetData, []string, error) {
client, err := d.requireClient()
if err != nil {
return nil, nil, err
}
var results []connection.ResultSetData
var messages []string
if err := client.call(optionalAgentRequest{
Method: optionalAgentMethodQueryMulti,
Query: query,
}, &results, nil, &messages, nil); err != nil {
if isOptionalAgentMultiResultUnsupportedError(err) {
return nil, nil, nil
}
return nil, nil, err
}
return results, messages, nil
}
func (d *OptionalDriverAgentDB) QueryMultiContext(ctx context.Context, query string) ([]connection.ResultSetData, error) {
results, _, err := d.QueryMultiContextWithMessages(ctx, query)
return results, err
}
func (d *OptionalDriverAgentDB) QueryMultiContextWithMessages(ctx context.Context, query string) ([]connection.ResultSetData, []string, error) {
if err := ctx.Err(); err != nil {
return nil, nil, err
}
@@ -504,32 +586,19 @@ func (d *OptionalDriverAgentDB) QueryContext(ctx context.Context, query string)
if err != nil {
return nil, nil, err
}
var data []map[string]interface{}
var fields []string
var results []connection.ResultSetData
var messages []string
if err := client.call(optionalAgentRequest{
Method: optionalAgentMethodQuery,
Method: optionalAgentMethodQueryMulti,
Query: query,
TimeoutMs: timeoutMsFromContext(ctx),
}, &data, &fields, nil); err != nil {
}, &results, nil, &messages, nil); err != nil {
if isOptionalAgentMultiResultUnsupportedError(err) {
return nil, nil, nil
}
return nil, nil, err
}
return data, fields, nil
}
func (d *OptionalDriverAgentDB) Query(query string) ([]map[string]interface{}, []string, error) {
client, err := d.requireClient()
if err != nil {
return nil, nil, err
}
var data []map[string]interface{}
var fields []string
if err := client.call(optionalAgentRequest{
Method: optionalAgentMethodQuery,
Query: query,
}, &data, &fields, nil); err != nil {
return nil, nil, err
}
return data, fields, nil
return results, messages, nil
}
func (d *OptionalDriverAgentDB) StreamQuery(query string, consumer QueryStreamConsumer) error {
@@ -581,7 +650,7 @@ func (d *OptionalDriverAgentDB) ExecContext(ctx context.Context, query string) (
Method: optionalAgentMethodExec,
Query: query,
TimeoutMs: timeoutMsFromContext(ctx),
}, nil, nil, &affected); err != nil {
}, nil, nil, nil, &affected); err != nil {
return 0, err
}
return affected, nil
@@ -596,7 +665,7 @@ func (d *OptionalDriverAgentDB) Exec(query string) (int64, error) {
if err := client.call(optionalAgentRequest{
Method: optionalAgentMethodExec,
Query: query,
}, nil, nil, &affected); err != nil {
}, nil, nil, nil, &affected); err != nil {
return 0, err
}
return affected, nil
@@ -611,7 +680,7 @@ func (d *OptionalDriverAgentDB) OpenSessionExecer(ctx context.Context) (Statemen
if err := client.call(optionalAgentRequest{
Method: optionalAgentMethodOpenSession,
TimeoutMs: timeoutMsFromContext(ctx),
}, &sessionID, nil, nil); err != nil {
}, &sessionID, nil, nil, nil); err != nil {
return nil, err
}
sessionID = strings.TrimSpace(sessionID)
@@ -629,6 +698,10 @@ func (s *optionalDriverAgentSession) Query(query string) ([]map[string]interface
return s.QueryContext(context.Background(), query)
}
func (s *optionalDriverAgentSession) QueryWithMessages(query string) ([]map[string]interface{}, []string, []string, error) {
return s.QueryContextWithMessages(context.Background(), query)
}
func (s *optionalDriverAgentSession) StreamQuery(query string, consumer QueryStreamConsumer) error {
return s.StreamQueryContext(context.Background(), query, consumer)
}
@@ -663,20 +736,26 @@ func (s *optionalDriverAgentSession) StreamQueryContext(ctx context.Context, que
}
func (s *optionalDriverAgentSession) QueryContext(ctx context.Context, query string) ([]map[string]interface{}, []string, error) {
data, fields, _, err := s.QueryContextWithMessages(ctx, query)
return data, fields, err
}
func (s *optionalDriverAgentSession) QueryContextWithMessages(ctx context.Context, query string) ([]map[string]interface{}, []string, []string, error) {
if err := s.ensureOpen(); err != nil {
return nil, nil, err
return nil, nil, nil, err
}
var data []map[string]interface{}
var fields []string
var messages []string
if err := s.client.call(optionalAgentRequest{
Method: optionalAgentMethodQuery,
SessionID: s.sessionID,
Query: query,
TimeoutMs: timeoutMsFromContext(ctx),
}, &data, &fields, nil); err != nil {
return nil, nil, err
}, &data, &fields, &messages, nil); err != nil {
return nil, nil, nil, err
}
return data, fields, nil
return data, fields, messages, nil
}
func (s *optionalDriverAgentSession) Exec(query string) (int64, error) {
@@ -693,7 +772,7 @@ func (s *optionalDriverAgentSession) ExecContext(ctx context.Context, query stri
SessionID: s.sessionID,
Query: query,
TimeoutMs: timeoutMsFromContext(ctx),
}, nil, nil, &affected); err != nil {
}, nil, nil, nil, &affected); err != nil {
return 0, err
}
return affected, nil
@@ -714,7 +793,7 @@ func (s *optionalDriverAgentSession) Close() error {
return s.client.call(optionalAgentRequest{
Method: optionalAgentMethodCloseSession,
SessionID: sessionID,
}, nil, nil, nil)
}, nil, nil, nil, nil)
}
func (s *optionalDriverAgentSession) ensureOpen() error {
@@ -740,6 +819,19 @@ func isOptionalAgentStreamUnsupportedError(err error) bool {
return strings.Contains(text, "不支持的方法") || strings.Contains(text, "不支持流式查询")
}
func isOptionalAgentMultiResultUnsupportedError(err error) bool {
if err == nil {
return false
}
text := strings.TrimSpace(err.Error())
if text == "" {
return false
}
return strings.Contains(text, "不支持的方法") ||
strings.Contains(text, "不支持原生多结果集查询") ||
strings.Contains(text, "不支持多结果集查询")
}
func (d *OptionalDriverAgentDB) GetDatabases() ([]string, error) {
client, err := d.requireClient()
if err != nil {
@@ -748,7 +840,7 @@ func (d *OptionalDriverAgentDB) GetDatabases() ([]string, error) {
var dbs []string
if err := client.call(optionalAgentRequest{
Method: optionalAgentMethodGetDatabases,
}, &dbs, nil, nil); err != nil {
}, &dbs, nil, nil, nil); err != nil {
return nil, err
}
return dbs, nil
@@ -763,7 +855,7 @@ func (d *OptionalDriverAgentDB) GetTables(dbName string) ([]string, error) {
if err := client.call(optionalAgentRequest{
Method: optionalAgentMethodGetTables,
DBName: dbName,
}, &tables, nil, nil); err != nil {
}, &tables, nil, nil, nil); err != nil {
return nil, err
}
return tables, nil
@@ -779,7 +871,7 @@ func (d *OptionalDriverAgentDB) GetCreateStatement(dbName, tableName string) (st
Method: optionalAgentMethodGetCreateStmt,
DBName: dbName,
TableName: tableName,
}, &sqlText, nil, nil); err != nil {
}, &sqlText, nil, nil, nil); err != nil {
return "", err
}
return sqlText, nil
@@ -795,7 +887,7 @@ func (d *OptionalDriverAgentDB) GetColumns(dbName, tableName string) ([]connecti
Method: optionalAgentMethodGetColumns,
DBName: dbName,
TableName: tableName,
}, &columns, nil, nil); err != nil {
}, &columns, nil, nil, nil); err != nil {
return nil, err
}
return columns, nil
@@ -810,7 +902,7 @@ func (d *OptionalDriverAgentDB) GetAllColumns(dbName string) ([]connection.Colum
if err := client.call(optionalAgentRequest{
Method: optionalAgentMethodGetAllColumns,
DBName: dbName,
}, &columns, nil, nil); err != nil {
}, &columns, nil, nil, nil); err != nil {
return nil, err
}
return columns, nil
@@ -826,7 +918,7 @@ func (d *OptionalDriverAgentDB) GetIndexes(dbName, tableName string) ([]connecti
Method: optionalAgentMethodGetIndexes,
DBName: dbName,
TableName: tableName,
}, &indexes, nil, nil); err != nil {
}, &indexes, nil, nil, nil); err != nil {
return nil, err
}
return indexes, nil
@@ -842,7 +934,7 @@ func (d *OptionalDriverAgentDB) GetForeignKeys(dbName, tableName string) ([]conn
Method: optionalAgentMethodGetForeignKeys,
DBName: dbName,
TableName: tableName,
}, &keys, nil, nil); err != nil {
}, &keys, nil, nil, nil); err != nil {
return nil, err
}
return keys, nil
@@ -858,7 +950,7 @@ func (d *OptionalDriverAgentDB) GetTriggers(dbName, tableName string) ([]connect
Method: optionalAgentMethodGetTriggers,
DBName: dbName,
TableName: tableName,
}, &triggers, nil, nil); err != nil {
}, &triggers, nil, nil, nil); err != nil {
return nil, err
}
return triggers, nil
@@ -883,7 +975,7 @@ func (d *OptionalDriverAgentDB) ApplyChanges(tableName string, changes connectio
Method: optionalAgentMethodApplyChanges,
TableName: tableName,
Changes: &changes,
}, nil, nil, nil)
}, nil, nil, nil, nil)
}
func (d *OptionalDriverAgentDB) requireClient() (*optionalDriverAgentClient, error) {

View File

@@ -136,3 +136,68 @@ func TestOptionalDriverAgentClientCallStreamQueryConsumesChunks(t *testing.T) {
t.Fatalf("请求未使用 streamQuery 方法: %s", stdin.String())
}
}
func TestOptionalDriverAgentDBQueryWithMessagesParsesAgentMessages(t *testing.T) {
var stdin optionalAgentTestWriteCloser
stdout := `{"id":1,"success":true,"data":[{"sql_text":"select 1"}],"fields":["sql_text"],"messages":["PRINT sql line 1","PRINT sql line 2"]}` + "\n"
dbInst := &OptionalDriverAgentDB{
driverType: "sqlserver",
client: &optionalDriverAgentClient{
stdin: &stdin,
reader: bufio.NewReader(strings.NewReader(stdout)),
driver: "sqlserver",
},
}
rows, fields, messages, err := dbInst.QueryWithMessages("exec dbo.p_get_select")
if err != nil {
t.Fatalf("QueryWithMessages 返回错误: %v", err)
}
if len(rows) != 1 || rows[0]["sql_text"] != "select 1" {
t.Fatalf("查询结果异常: %#v", rows)
}
if len(fields) != 1 || fields[0] != "sql_text" {
t.Fatalf("字段异常: %#v", fields)
}
if len(messages) != 2 || messages[0] != "PRINT sql line 1" {
t.Fatalf("消息异常: %#v", messages)
}
if !strings.Contains(stdin.String(), `"method":"query"`) {
t.Fatalf("请求未使用 query 方法: %s", stdin.String())
}
}
func TestOptionalDriverAgentDBQueryMultiWithMessagesParsesResultSets(t *testing.T) {
var stdin optionalAgentTestWriteCloser
stdout := `{"id":1,"success":true,"data":[{"statementIndex":1,"rows":[{"name":"master"}],"columns":["name"]},{"statementIndex":1,"rows":[],"columns":[],"messages":["PRINT generated sql"]}],"messages":["batch top-level message"]}` + "\n"
dbInst := &OptionalDriverAgentDB{
driverType: "sqlserver",
client: &optionalDriverAgentClient{
stdin: &stdin,
reader: bufio.NewReader(strings.NewReader(stdout)),
driver: "sqlserver",
},
}
resultSets, messages, err := dbInst.QueryMultiWithMessages("exec dbo.p_get_select")
if err != nil {
t.Fatalf("QueryMultiWithMessages 返回错误: %v", err)
}
if len(resultSets) != 2 {
t.Fatalf("结果集数量异常: %#v", resultSets)
}
if got := resultSets[0].Rows[0]["name"]; got != "master" {
t.Fatalf("首个结果集异常got=%v", got)
}
if len(resultSets[1].Messages) != 1 || resultSets[1].Messages[0] != "PRINT generated sql" {
t.Fatalf("消息结果集异常: %#v", resultSets[1])
}
if len(messages) != 1 || messages[0] != "batch top-level message" {
t.Fatalf("顶层消息异常: %#v", messages)
}
if !strings.Contains(stdin.String(), `"method":"queryMulti"`) {
t.Fatalf("请求未使用 queryMulti 方法: %s", stdin.String())
}
}