🐛 fix(oceanbase/data-grid): 修复 Oracle 时间字段显示编辑与结果视图异常

- 修复 OceanBase Oracle DATE 与 TIMESTAMP 的解码、展示和编辑精度丢失问题
- 修复查询结果与数据视图的行号显示、分页页数和日期列展示口径
- 打通 Oracle 与 OceanBase 会话执行链路的扫描方言透传
- 补齐 DBQuery、DataGrid temporal 和 OceanBase 结果链路回归测试
This commit is contained in:
Syngnat
2026-06-17 09:49:15 +08:00
parent 6421662f5d
commit 0632c5242c
25 changed files with 2702 additions and 198 deletions

View File

@@ -222,7 +222,9 @@ describe('DataGrid layout', () => {
paginationSummaryText="当前 24 条 / 共 24 条"
paginationControlTotal={24}
paginationTotalPages={1}
paginationPageText="第 1 / 1 页"
paginationPageSizeOptions={['100', '200']}
showKnownPageCount
onPageChange={() => {}}
onPageSizeChange={() => {}}
onV2PageStep={() => {}}
@@ -233,6 +235,37 @@ describe('DataGrid layout', () => {
expect(markup).not.toContain('第 1 / 1 页');
});
it('keeps unknown-total pagination in sequential mode instead of pretending total pages are known', () => {
const markup = renderToStaticMarkup(
<DataGrid
data={[
{
__gonavi_row_key__: 'row-1',
id: 1,
name: 'alpha',
},
]}
columnNames={['id', 'name']}
loading={false}
tableName="users"
dbName="main"
connectionId="conn-1"
readOnly
pagination={{
current: 3,
pageSize: 100,
total: 400,
totalKnown: false,
}}
onPageChange={() => {}}
/>,
);
expect(markup).toContain('第 3 页');
expect(markup).not.toContain('<strong>3</strong><span>/</span><span>4</span>');
expect(markup).not.toContain('data-grid-pagination-jump="true"');
});
it('renders the v2 DataGrid toolbar using the redesigned topbar hooks', () => {
const markup = renderToStaticMarkup(
<DataGrid
@@ -281,12 +314,62 @@ describe('DataGrid layout', () => {
expect(markup).toContain('AI 洞察');
});
it('renders a non-data row number column when enabled', () => {
const markup = renderToStaticMarkup(
<DataGrid
data={[
{
__gonavi_row_key__: 'row-1',
id: 1,
name: 'alpha',
},
]}
columnNames={['id', 'name']}
loading={false}
tableName="events"
dbName="main"
connectionId="conn-1"
readOnly
showRowNumberColumn
pagination={{
current: 2,
pageSize: 50,
total: 51,
}}
onPageChange={() => {}}
/>,
);
expect(markup).toContain('aria-label="行号"');
expect(markup).toContain('<span aria-label="行号">#</span>');
expect(markup).not.toContain('>行号<');
expect(markup).toContain('data-grid-row-number-title="true"');
expect(markup).toContain('data-grid-column-title-single-line="true"');
expect(markup).toContain('justify-content:center');
expect(markup).toContain('align-items:center');
expect(markup).toContain('min-height:var(--gonavi-header-min-height, 40px)');
expect(markup).toContain('text-align:center');
expect(markup).toContain('padding-inline:0');
expect(markup).toContain('vertical-align:middle');
expect(markup).toContain('data-grid-row-number="true"');
expect(markup).toContain('51');
});
it('clears modified cell markers when refreshing the grid', () => {
const source = readFileSync(new URL('./DataGrid.tsx', import.meta.url), 'utf8');
expect(source).toMatch(/const handleRefreshGrid = useCallback\(\(\) => \{[\s\S]*setModifiedColumns\(\{\}\);[\s\S]*if \(onReload\) onReload\(\);[\s\S]*\}, \[clearAutoCommitTimer, onReload\]\);/);
});
it('routes temporal inline editors through the current connection config', () => {
const source = readFileSync(new URL('./DataGrid.tsx', import.meta.url), 'utf8');
expect(source).toContain('const pickerType = getTemporalPickerType(columnType, dbType, connectionConfig);');
expect(source).toContain('const pickerType = getTemporalPickerType(columnType, dbType, currentConnConfig);');
expect(source).toContain('cellProps.connectionConfig = currentConnConfig;');
expect(source).toContain('format={getTemporalPickerFormat(pickerType)}');
});
it('renders a cell-level undo action in the v2 context menu for modified cells', () => {
const markup = renderToStaticMarkup(
<V2CellContextMenuView
@@ -305,6 +388,18 @@ describe('DataGrid layout', () => {
expect(formatCellDisplayText('2026-05-10T09:12:33.456+08:00')).toBe('2026-05-10 09:12:33.456');
});
it('collapses OceanBase Oracle DATE midnight values to date-only text', () => {
const oceanBaseOracleConfig = {
type: 'oceanbase',
oceanBaseProtocol: 'oracle',
} as any;
expect(formatCellDisplayText('2026-06-16T00:00:00Z', 'DATE', oceanBaseOracleConfig)).toBe('2026-06-16');
expect(formatCellDisplayText('2026-06-16 00:00:00', 'DATE', oceanBaseOracleConfig)).toBe('2026-06-16');
expect(formatCellDisplayText('2026-06-16T13:14:15Z', 'DATE', oceanBaseOracleConfig)).toBe('2026-06-16 13:14:15');
expect(formatCellDisplayText('2026-06-16T00:00:00Z', 'DATE', { type: 'oracle' } as any)).toBe('2026-06-16 00:00:00');
});
it('renders bit column hex values as decimal flags', () => {
expect(formatCellDisplayText('0x00', 'bit(1)')).toBe('0');
expect(formatCellDisplayText('0x01', 'bit(1)')).toBe('1');

View File

@@ -31,12 +31,13 @@ import { buildOrderBySQL, buildPaginatedSelectSQL, buildWhereSQL, escapeLiteral,
import { isMacLikePlatform, normalizeOpacityForPlatform, resolveAppearanceValues } from '../utils/appearance';
import { getDataSourceCapabilities, resolveDataSourceType } from '../utils/dataSourceCapabilities';
import { buildRpcConnectionConfig } from '../utils/connectionRpcConfig';
import { normalizeOceanBaseProtocol } from '../utils/oceanBaseProtocol';
import {
getDensityParams,
resolveDataTableColumnWidth,
resolveDataTableVerticalBorderColor,
} from '../utils/dataGridDisplay';
import { resolvePaginationSummaryText, resolvePaginationTotalForControl } from '../utils/dataGridPagination';
import { resolvePaginationPageText, resolvePaginationSummaryText, resolvePaginationTotalForControl } from '../utils/dataGridPagination';
import { resolveGridSortInfoFromTableSorter } from '../utils/dataGridSort';
import {
calculateExternalHorizontalScrollInnerWidth,
@@ -71,10 +72,12 @@ import { DEFAULT_SHORTCUT_OPTIONS, getShortcutPlatform, resolveShortcutDisplay }
import {
TEMPORAL_FORMATS,
formatFromDayjs,
getTemporalPickerFormat,
getTemporalPickerType,
isTemporalColumnType,
parseToDayjs,
resolveTemporalEditorSaveValue,
type TemporalConnectionLike,
type TemporalPickerType,
} from './dataGridTemporal';
import {
@@ -182,6 +185,7 @@ class DataGridErrorBoundary extends React.Component<
// 内部行标识字段:避免与真实业务字段(如 `key` 列)冲突。
export const GONAVI_ROW_KEY = '__gonavi_row_key__';
export const GONAVI_ROW_NUMBER_COLUMN_KEY = '__gonavi_row_number__';
// Cell key helpers for batch selection/fill.
// Use a control character separator to avoid collisions with rowKey/columnName contents (e.g. `new-123`).
@@ -189,6 +193,7 @@ const CELL_KEY_SEP = '\u0001';
const CELL_SELECTION_DRAG_THRESHOLD_PX = 4;
const DATE_TIME_CACHE_LIMIT = 2000;
const TABLE_CELL_PREVIEW_MAX_CHARS = 240;
const ROW_NUMBER_COLUMN_WIDTH = 58;
const DATA_EDIT_AUTO_COMMIT_DELAY_OPTIONS = [
{ value: 3000, label: '3 秒' },
{ value: 5000, label: '5 秒' },
@@ -281,7 +286,46 @@ const normalizeBitHexDisplayText = (val: any, columnType?: string): string | nul
}
};
export const formatCellDisplayText = (val: any, columnType?: string): string => {
type CellDisplayConnectionLike = TemporalConnectionLike;
const isDateOnlyColumnType = (columnType?: string): boolean => {
const normalized = String(columnType || '').trim().toLowerCase();
if (!normalized) return false;
const base = normalized.split(/[ (]/)[0];
return base === 'date' || base === 'newdate';
};
const isOceanBaseOracleDisplayConnection = (connectionConfig?: CellDisplayConnectionLike): boolean => {
if (!connectionConfig) return false;
const type = String(connectionConfig.type || '').trim().toLowerCase();
const driver = String(connectionConfig.driver || '').trim().toLowerCase();
return (type === 'oceanbase' || driver === 'oceanbase')
&& normalizeOceanBaseProtocol(connectionConfig.oceanBaseProtocol) === 'oracle';
};
const normalizeOceanBaseOracleDateDisplayText = (
val: string,
columnType?: string,
connectionConfig?: CellDisplayConnectionLike,
): string | null => {
if (!isDateOnlyColumnType(columnType) || !isOceanBaseOracleDisplayConnection(connectionConfig)) {
return null;
}
const trimmed = String(val || '').trim();
if (!trimmed) return trimmed;
const match = trimmed.match(
/^(\d{4}-\d{2}-\d{2})(?:[T ](\d{2}:\d{2}:\d{2})(\.\d+)?(?:\s*(?:Z|[+-]\d{2}:?\d{2})(?:\s+[A-Za-z_\/+-]+)?)?)?$/
);
if (!match) return null;
const [, datePart, timePart, fractionPart] = match;
if (!timePart) return datePart;
if (timePart === '00:00:00' && (!fractionPart || /^\.0+$/.test(fractionPart))) {
return datePart;
}
return null;
};
export const formatCellDisplayText = (val: any, columnType?: string, connectionConfig?: CellDisplayConnectionLike): string => {
try {
if (val === null) return 'NULL';
const bitText = normalizeBitHexDisplayText(val, columnType);
@@ -310,6 +354,10 @@ export const formatCellDisplayText = (val: any, columnType?: string): string =>
}
}
if (typeof val === 'string') {
const oceanBaseDateOnly = normalizeOceanBaseOracleDateDisplayText(val, columnType, connectionConfig);
if (oceanBaseDateOnly !== null) {
return oceanBaseDateOnly.length > TABLE_CELL_PREVIEW_MAX_CHARS ? `${oceanBaseDateOnly.slice(0, TABLE_CELL_PREVIEW_MAX_CHARS)}` : oceanBaseDateOnly;
}
const normalized = normalizeDateTimeString(val);
return normalized.length > TABLE_CELL_PREVIEW_MAX_CHARS ? `${normalized.slice(0, TABLE_CELL_PREVIEW_MAX_CHARS)}` : normalized;
}
@@ -320,12 +368,16 @@ export const formatCellDisplayText = (val: any, columnType?: string): string =>
}
};
const formatClipboardCellText = (val: any, columnType?: string): string => {
const formatClipboardCellText = (val: any, columnType?: string, connectionConfig?: CellDisplayConnectionLike): string => {
try {
if (val === null || val === undefined) return 'NULL';
const bitText = normalizeBitHexDisplayText(val, columnType);
if (bitText !== null) return bitText;
if (typeof val === 'string') return normalizeDateTimeString(val);
if (typeof val === 'string') {
const oceanBaseDateOnly = normalizeOceanBaseOracleDateDisplayText(val, columnType, connectionConfig);
if (oceanBaseDateOnly !== null) return oceanBaseDateOnly;
return normalizeDateTimeString(val);
}
if (typeof val === 'object') {
try {
return JSON.stringify(val);
@@ -346,6 +398,7 @@ const buildClipboardTsv = (
rows: Array<Record<string, any>>,
columnNames: string[],
getColumnType?: (columnName: string) => string | undefined,
connectionConfig?: CellDisplayConnectionLike,
): string => {
if (!Array.isArray(rows) || rows.length === 0 || !Array.isArray(columnNames) || columnNames.length === 0) {
return '';
@@ -353,7 +406,7 @@ const buildClipboardTsv = (
const header = columnNames.map(normalizeClipboardTsvCell).join('\t');
const lines = rows.map((row) => (
columnNames
.map((columnName) => normalizeClipboardTsvCell(formatClipboardCellText(row?.[columnName], getColumnType?.(columnName))))
.map((columnName) => normalizeClipboardTsvCell(formatClipboardCellText(row?.[columnName], getColumnType?.(columnName), connectionConfig)))
.join('\t')
));
return [header, ...lines].join('\n');
@@ -382,8 +435,8 @@ const renderHighlightedCellText = (text: string, query: string): React.ReactNode
return <>{nodes}</>;
};
const renderCellDisplayValue = (val: any, query: string, columnType?: string): React.ReactNode => {
const text = formatCellDisplayText(val, columnType);
const renderCellDisplayValue = (val: any, query: string, columnType?: string, connectionConfig?: CellDisplayConnectionLike): React.ReactNode => {
const text = formatCellDisplayText(val, columnType, connectionConfig);
const content = renderHighlightedCellText(text, query);
if (val === null) return <span style={{ color: '#ccc' }}>{content}</span>;
return content;
@@ -778,6 +831,7 @@ interface EditableCellProps {
focusCell?: (record: Item, dataIndex: string, title: React.ReactNode) => void;
columnType?: string;
dbType?: string;
connectionConfig?: CellDisplayConnectionLike;
inputCellPadding?: React.CSSProperties;
as?: any;
modifiedColumns?: Record<string, Set<string>>;
@@ -828,6 +882,9 @@ const areEditableCellPropsEqual = (prevProps: EditableCellProps, nextProps: Edit
if (prevProps.title !== nextProps.title) return false;
if (prevProps.columnType !== nextProps.columnType) return false;
if (prevProps.dbType !== nextProps.dbType) return false;
if ((prevProps.connectionConfig?.type ?? null) !== (nextProps.connectionConfig?.type ?? null)) return false;
if ((prevProps.connectionConfig?.driver ?? null) !== (nextProps.connectionConfig?.driver ?? null)) return false;
if ((prevProps.connectionConfig?.oceanBaseProtocol ?? null) !== (nextProps.connectionConfig?.oceanBaseProtocol ?? null)) return false;
if (prevProps.darkMode !== nextProps.darkMode) return false;
if (prevProps.as !== nextProps.as) return false;
if (prevProps.handleSave !== nextProps.handleSave) return false;
@@ -866,6 +923,7 @@ const EditableCell: React.FC<EditableCellProps> = React.memo(({
focusCell,
columnType,
dbType,
connectionConfig,
inputCellPadding,
as: Component = 'td',
modifiedColumns,
@@ -953,7 +1011,7 @@ const EditableCell: React.FC<EditableCellProps> = React.memo(({
let childNode = children;
const pickerType = getTemporalPickerType(columnType, dbType);
const pickerType = getTemporalPickerType(columnType, dbType, connectionConfig);
const isDateTimeField = !!pickerType && !(/^0{4}-0{2}-0{2}/.test(String(record?.[dataIndex] || '')));
const isRowDeleted = deletedRowKeys && rowKeyStr && record?.[GONAVI_ROW_KEY] !== undefined
@@ -988,7 +1046,7 @@ const EditableCell: React.FC<EditableCellProps> = React.memo(({
style={{ width: '100%' }}
showTime
showNow={false}
format={TEMPORAL_FORMATS[pickerType]}
format={getTemporalPickerFormat(pickerType)}
renderExtraFooter={() => (
<a
style={{ padding: '0 2px' }}
@@ -1210,6 +1268,7 @@ interface DataGridProps {
pkColumns?: string[];
editLocator?: EditRowLocator;
readOnly?: boolean;
showRowNumberColumn?: boolean;
onReload?: () => void;
onSort?: (field: string, order: string) => void;
onPageChange?: (page: number, size: number) => void;
@@ -1499,7 +1558,7 @@ const DataGrid: React.FC<DataGridProps> = ({
resultExportAllSql,
onReload, onSort, onPageChange, pagination, onRequestTotalCount, onCancelTotalCount, sortInfoExternal, showFilter, onToggleFilter, exportSqlWithFilter, onApplyFilter, appliedFilterConditions, quickWhereCondition,
onApplyQuickWhereCondition,
scrollSnapshot, onScrollSnapshotChange, toolbarExtraActions
scrollSnapshot, onScrollSnapshotChange, toolbarExtraActions, showRowNumberColumn = false
}) => {
const connections = useStore(state => state.connections);
const addTab = useStore(state => state.addTab);
@@ -4448,7 +4507,7 @@ const DataGrid: React.FC<DataGridProps> = ({
}
const columnType = (columnMetaMap[dataIndex] || columnMetaMapByLowerName[dataIndex.toLowerCase()])?.type;
const pickerType = getTemporalPickerType(columnType, dbType);
const pickerType = getTemporalPickerType(columnType, dbType, currentConnConfig);
const isDateTimeField = !!pickerType && !(/^0{4}-0{2}-0{2}/.test(String(raw || '')));
const fieldName = getCellFieldName(record, dataIndex);
if (isDateTimeField) {
@@ -4463,7 +4522,7 @@ const DataGrid: React.FC<DataGridProps> = ({
title,
columnType,
});
}, [canModifyData, columnMetaMap, columnMetaMapByLowerName, dbType, form, openCellEditor, rowKeyStr]);
}, [canModifyData, columnMetaMap, columnMetaMapByLowerName, currentConnConfig, dbType, form, openCellEditor, rowKeyStr]);
const handleVirtualCellActivate = useCallback((record: Item, dataIndex: string, title: React.ReactNode) => {
if (!canModifyData) return;
@@ -4593,7 +4652,7 @@ const DataGrid: React.FC<DataGridProps> = ({
return;
}
const pickerType = getTemporalPickerType(editingCell.columnType, dbType);
const pickerType = getTemporalPickerType(editingCell.columnType, dbType, currentConnConfig);
const isDateTimeField = !!pickerType && !(/^0{4}-0{2}-0{2}/.test(String(record?.[editingCell.dataIndex] || '')));
const fieldName = getCellFieldName(record, editingCell.dataIndex);
try {
@@ -4612,22 +4671,30 @@ const DataGrid: React.FC<DataGridProps> = ({
closeVirtualInlineEditor();
}
}
}, [closeVirtualInlineEditor, dbType, form, handleCellSave, virtualEditingCell]);
}, [closeVirtualInlineEditor, currentConnConfig, dbType, form, handleCellSave, virtualEditingCell]);
const pageFindMatches = useMemo(() => collectDataGridFindMatches(
mergedDisplayData,
displayColumnNames,
normalizedPageFindText,
(value, _row, columnName) => formatCellDisplayText(value, (columnMetaMap[columnName] || columnMetaMapByLowerName[columnName.toLowerCase()])?.type),
(value, _row, columnName) => formatCellDisplayText(
value,
(columnMetaMap[columnName] || columnMetaMapByLowerName[columnName.toLowerCase()])?.type,
currentConnConfig,
),
(row, rowIndex) => String(row?.[GONAVI_ROW_KEY] ?? `row-${rowIndex}`),
), [mergedDisplayData, displayColumnNames, normalizedPageFindText, columnMetaMap, columnMetaMapByLowerName]);
), [mergedDisplayData, displayColumnNames, normalizedPageFindText, columnMetaMap, columnMetaMapByLowerName, currentConnConfig]);
const pageFindSummary = useMemo(() => summarizeDataGridFindMatches(
mergedDisplayData,
displayColumnNames,
normalizedPageFindText,
(value, _row, columnName) => formatCellDisplayText(value, (columnMetaMap[columnName] || columnMetaMapByLowerName[columnName.toLowerCase()])?.type),
), [mergedDisplayData, displayColumnNames, normalizedPageFindText, columnMetaMap, columnMetaMapByLowerName]);
(value, _row, columnName) => formatCellDisplayText(
value,
(columnMetaMap[columnName] || columnMetaMapByLowerName[columnName.toLowerCase()])?.type,
currentConnConfig,
),
), [mergedDisplayData, displayColumnNames, normalizedPageFindText, columnMetaMap, columnMetaMapByLowerName, currentConnConfig]);
useEffect(() => {
setActivePageFindMatchIndex(-1);
@@ -4734,7 +4801,7 @@ const DataGrid: React.FC<DataGridProps> = ({
displayMap[col] = toFormText(displayVal);
// 日期时间类型: 将字符串值转为 dayjs 对象供 DatePicker 使用
const colMeta = columnMetaMap[col] || columnMetaMapByLowerName[col.toLowerCase()];
const rowPickerType = getTemporalPickerType(colMeta?.type, dbType);
const rowPickerType = getTemporalPickerType(colMeta?.type, dbType, currentConnConfig);
if (rowPickerType && displayVal !== null && displayVal !== undefined) {
const dVal = parseToDayjs(displayVal, rowPickerType);
formMap[col] = dVal;
@@ -4751,7 +4818,7 @@ const DataGrid: React.FC<DataGridProps> = ({
nullCols,
formValues: formMap,
});
}, [canModifyData, mergedDisplayData, data, addedRows, visibleColumnNames, rowKeyStr, columnMetaMap, columnMetaMapByLowerName, dbType, openRowEditor]);
}, [addedRows, canModifyData, columnMetaMap, columnMetaMapByLowerName, currentConnConfig, data, dbType, mergedDisplayData, openRowEditor, rowKeyStr, visibleColumnNames]);
const openCurrentViewRowEditor = useCallback(() => {
if (!canModifyData) return;
@@ -4916,7 +4983,7 @@ const DataGrid: React.FC<DataGridProps> = ({
if (!isWritableResultColumn(col, effectiveEditLocator)) return;
if (val && dayjs.isDayjs(val)) {
const colMeta = columnMetaMap[col] || columnMetaMapByLowerName[col.toLowerCase()];
const rowPickerType = getTemporalPickerType(colMeta?.type, dbType);
const rowPickerType = getTemporalPickerType(colMeta?.type, dbType, currentConnConfig);
convertedValues[col] = formatFromDayjs(val as dayjs.Dayjs, rowPickerType);
} else {
convertedValues[col] = val;
@@ -4935,7 +5002,7 @@ const DataGrid: React.FC<DataGridProps> = ({
// 日期时间类型: 将 dayjs 对象转回格式化字符串
if (nextVal && dayjs.isDayjs(nextVal)) {
const colMeta = columnMetaMap[col] || columnMetaMapByLowerName[col.toLowerCase()];
const rowPickerType = getTemporalPickerType(colMeta?.type, dbType);
const rowPickerType = getTemporalPickerType(colMeta?.type, dbType, currentConnConfig);
nextVal = formatFromDayjs(nextVal as dayjs.Dayjs, rowPickerType);
}
const baseVal = baseRawMap[col];
@@ -4950,7 +5017,7 @@ const DataGrid: React.FC<DataGridProps> = ({
});
closeRowEditor();
}, [rowEditorRowKey, rowEditorForm, addedRows, visibleColumnNames, rowKeyStr, closeRowEditor, effectiveEditLocator, columnMetaMap, columnMetaMapByLowerName, dbType]);
}, [addedRows, closeRowEditor, columnMetaMap, columnMetaMapByLowerName, currentConnConfig, dbType, effectiveEditLocator, rowEditorForm, rowEditorRowKey, rowKeyStr, visibleColumnNames]);
const enableVirtual = isTableSurfaceActive;
@@ -4985,7 +5052,7 @@ const DataGrid: React.FC<DataGridProps> = ({
sortOrder: (sortInfo.find(s => s.columnKey === key && s.enabled !== false)?.order || null) as SortOrder | undefined,
editable: canModifyData && isWritableResultColumn(key, effectiveEditLocator),
render: (text: any) => {
const renderedContent = renderCellDisplayValue(text, normalizedPageFindText, displayColumnTypeMap[key]);
const renderedContent = renderCellDisplayValue(text, normalizedPageFindText, displayColumnTypeMap[key], currentConnConfig);
if (enableVirtual) {
return renderedContent;
}
@@ -5037,7 +5104,7 @@ const DataGrid: React.FC<DataGridProps> = ({
},
}),
}));
}, [displayColumnNames, columnWidths, sortInfo, handleResizeStart, handleResizeAutoFit, isV2Ui, showColumnHeaderContextMenu, canModifyData, onSort, renderColumnTitle, dataTableDensity, normalizedPageFindText, displayColumnTypeMap, enableVirtual, showColumnComment, showColumnType]);
}, [canModifyData, columnWidths, currentConnConfig, dataTableDensity, displayColumnNames, displayColumnTypeMap, enableVirtual, handleResizeAutoFit, handleResizeStart, isV2Ui, normalizedPageFindText, onSort, renderColumnTitle, showColumnComment, showColumnHeaderContextMenu, showColumnType, sortInfo]);
const mergedColumns = useMemo(() => columns.map((col): ColumnType<any> => {
const dataIndex = String(col.dataIndex);
@@ -5067,6 +5134,7 @@ const DataGrid: React.FC<DataGridProps> = ({
cellProps.focusCell = openCellEditor;
cellProps.columnType = displayColumnTypeMap[dataIndex];
cellProps.dbType = dbType;
cellProps.connectionConfig = currentConnConfig;
cellProps.inputCellPadding = inputCellPadding;
cellProps.modifiedColumns = modifiedColumns;
cellProps.rowKeyStr = rowKeyStr;
@@ -5097,7 +5165,7 @@ const DataGrid: React.FC<DataGridProps> = ({
: undefined;
const shouldUsePlainVirtualContent = isV2Ui && !modifiedStyle;
if (enableVirtual && enableInlineEditableCell) {
const pickerType = getTemporalPickerType(columnType, dbType);
const pickerType = getTemporalPickerType(columnType, dbType, currentConnConfig);
const isDateTimeField = !!pickerType && !(/^0{4}-0{2}-0{2}/.test(String(record?.[dataIndex] || '')));
const virtualCellStyle = modifiedStyle ? { ...virtualCellWrapperStyle, ...modifiedStyle } : virtualCellWrapperStyle;
const virtualEditable = !!col.editable && !rowDeletedForRender;
@@ -5126,7 +5194,7 @@ const DataGrid: React.FC<DataGridProps> = ({
style={{ width: '100%' }}
showTime
showNow={false}
format={TEMPORAL_FORMATS[pickerType]}
format={getTemporalPickerFormat(pickerType)}
renderExtraFooter={() => (
<a
style={{ padding: '0 2px' }}
@@ -5211,7 +5279,58 @@ const DataGrid: React.FC<DataGridProps> = ({
return originalRenderContent;
}
};
}), [columns, useInlineEditableBodyCell, enableInlineEditableCell, enableVirtual, handleCellSave, openCellEditor, handleVirtualCellActivate, handleSharedCellContextMenu, displayColumnTypeMap, dbType, inputCellPadding, virtualCellWrapperStyle, modifiedColumns, rowKeyStr, deletedRowKeys, darkMode, virtualEditingCell, form, saveVirtualInlineEditor, lockVirtualInlineTableScroll, closeVirtualInlineEditor, updateFocusedCell]);
}), [closeVirtualInlineEditor, columns, currentConnConfig, darkMode, dbType, deletedRowKeys, displayColumnTypeMap, enableInlineEditableCell, enableVirtual, form, handleCellSave, handleSharedCellContextMenu, handleVirtualCellActivate, inputCellPadding, lockVirtualInlineTableScroll, modifiedColumns, openCellEditor, rowKeyStr, saveVirtualInlineEditor, updateFocusedCell, useInlineEditableBodyCell, virtualCellWrapperStyle, virtualEditingCell]);
const rowNumberColumn = useMemo<ColumnType<any>>(() => ({
title: (
<div
className="gn-v2-column-title is-single-line"
data-grid-row-number-title="true"
data-grid-column-title-single-line="true"
style={{
display: 'flex',
flexDirection: 'column',
alignItems: 'center',
justifyContent: 'center',
minWidth: 0,
width: '100%',
maxWidth: '100%',
minHeight: 'var(--gonavi-header-min-height, 40px)',
lineHeight: 1.2,
textAlign: 'center',
}}
>
<span aria-label="行号">#</span>
</div>
),
key: GONAVI_ROW_NUMBER_COLUMN_KEY,
dataIndex: GONAVI_ROW_NUMBER_COLUMN_KEY,
width: ROW_NUMBER_COLUMN_WIDTH,
className: 'data-grid-row-number-cell',
align: 'center',
onHeaderCell: () => ({
style: {
textAlign: 'center' as const,
paddingInline: 0,
verticalAlign: 'middle' as const,
},
}),
render: (_value: unknown, _record: Item, index: number) => {
const currentPage = Math.max(1, Number(pagination?.current) || 1);
const pageSize = Math.max(1, Number(pagination?.pageSize) || 0);
const offset = pageSize > 0 ? (currentPage - 1) * pageSize : 0;
return (
<span className="data-grid-row-number" data-grid-row-number="true">
{offset + index + 1}
</span>
);
},
}), [pagination?.current, pagination?.pageSize]);
const tableColumns = useMemo(
() => (showRowNumberColumn ? [rowNumberColumn, ...mergedColumns] : mergedColumns),
[mergedColumns, rowNumberColumn, showRowNumberColumn]
);
const handleAddRow = () => {
const newKey = `new-${Date.now()}`;
@@ -5529,10 +5648,10 @@ const DataGrid: React.FC<DataGridProps> = ({
const columnType = (columnMetaMap[normalizedColumnName] || columnMetaMapByLowerName[normalizedColumnName.toLowerCase()])?.type;
const text = mergedDisplayData
.map((row) => normalizeClipboardTsvCell(formatClipboardCellText(row?.[normalizedColumnName], columnType)))
.map((row) => normalizeClipboardTsvCell(formatClipboardCellText(row?.[normalizedColumnName], columnType, currentConnConfig)))
.join('\n');
copyToClipboard(text);
}, [columnMetaMap, columnMetaMapByLowerName, copyToClipboard, displayOutputColumnNames, mergedDisplayData]);
}, [columnMetaMap, columnMetaMapByLowerName, copyToClipboard, currentConnConfig, displayOutputColumnNames, mergedDisplayData]);
const handleV2ColumnHeaderContextMenuAction = useCallback((action: V2ColumnHeaderContextMenuActionKey) => {
const columnName = resolveContextMenuFieldName(cellContextMenu.dataIndex, cellContextMenu.title);
@@ -5872,13 +5991,14 @@ const DataGrid: React.FC<DataGridProps> = ({
rows,
columns,
(columnName) => (columnMetaMap[columnName] || columnMetaMapByLowerName[columnName.toLowerCase()])?.type,
currentConnConfig,
);
if (!text) {
void message.info('当前行没有可复制内容');
return;
}
copyToClipboard(text);
}, [columnMetaMap, columnMetaMapByLowerName, copyToClipboard, displayOutputColumnNames, getContextMenuTargetRows]);
}, [columnMetaMap, columnMetaMapByLowerName, copyToClipboard, currentConnConfig, displayOutputColumnNames, getContextMenuTargetRows]);
const buildConnConfig = useCallback(() => {
if (!connectionId) return null;
@@ -6435,7 +6555,7 @@ const DataGrid: React.FC<DataGridProps> = ({
const rowPropsFactory = useCallback((record: any) => ({ record } as any), []);
const totalWidth = columns.reduce((sum: number, col: any) => sum + (Number(col.width) || densityParams.defaultColumnWidth), 0) + selectionColumnWidth;
const totalWidth = tableColumns.reduce((sum: number, col: any) => sum + (Number(col.width) || densityParams.defaultColumnWidth), 0) + selectionColumnWidth;
const useContextMenuRow = false;
const tableScrollX = useMemo(() => {
// rc-table 在 scroll.x 小于容器宽度时会把实际列宽按视口补齐。
@@ -7330,6 +7450,14 @@ const DataGrid: React.FC<DataGridProps> = ({
});
}, [pagination, supportsApproximateTotalPages]);
const paginationHasKnownTotalPages = useMemo(() => {
if (!pagination) return false;
if (pagination.totalKnown !== false) return true;
if (!supportsApproximateTotalPages || !pagination.totalApprox) return false;
const approximateTotal = Number(pagination.approximateTotal);
return Number.isFinite(approximateTotal) && approximateTotal > 0;
}, [pagination, supportsApproximateTotalPages]);
const paginationTotalPages = useMemo(() => {
if (!pagination) return 1;
if (!Number.isFinite(paginationControlTotal) || paginationControlTotal <= 0) {
@@ -7361,6 +7489,14 @@ const DataGrid: React.FC<DataGridProps> = ({
supportsApproximateTotalPages,
]);
const paginationPageText = useMemo(() => {
if (!pagination) return '';
return resolvePaginationPageText({
pagination,
supportsApproximateTotalPages,
});
}, [pagination, supportsApproximateTotalPages]);
const handlePageSizeChange = useCallback((value: string) => {
if (!pagination || !onPageChange) return;
const nextSize = Number(value);
@@ -7412,7 +7548,7 @@ const DataGrid: React.FC<DataGridProps> = ({
ref={tableRef}
components={tableComponents}
dataSource={tableRenderData}
columns={mergedColumns}
columns={tableColumns}
{...(enableVirtual && typeof virtualListItemHeight === 'number'
? { listItemHeight: virtualListItemHeight }
: {})}
@@ -7502,7 +7638,9 @@ const DataGrid: React.FC<DataGridProps> = ({
paginationSummaryText={paginationSummaryText}
paginationControlTotal={paginationControlTotal}
paginationTotalPages={paginationTotalPages}
paginationPageText={paginationPageText}
paginationPageSizeOptions={paginationPageSizeOptions}
showKnownPageCount={paginationHasKnownTotalPages}
onPageChange={onPageChange}
onPageSizeChange={handlePageSizeChange}
onV2PageStep={handleV2PageStep}
@@ -7516,7 +7654,7 @@ const DataGrid: React.FC<DataGridProps> = ({
const isJson = looksLikeJsonText(sample);
const useTextArea = isJson || sample.includes('\n') || sample.length >= 160;
const colMeta = columnMetaMap[col] || columnMetaMapByLowerName[col.toLowerCase()];
const pickerType = getTemporalPickerType(colMeta?.type, dbType);
const pickerType = getTemporalPickerType(colMeta?.type, dbType, currentConnConfig);
const isTemporalValue = !!pickerType && !(/^0{4}-0{2}-0{2}/.test(String(sample || '')));
const isWritable = isWritableResultColumn(col, effectiveEditLocator);
return {
@@ -7530,7 +7668,7 @@ const DataGrid: React.FC<DataGridProps> = ({
isWritable,
};
})
), [displayColumnNames, columnMetaMap, columnMetaMapByLowerName, dbType, effectiveEditLocator, rowEditorOpen, rowEditorRowKey]);
), [columnMetaMap, columnMetaMapByLowerName, currentConnConfig, dbType, displayColumnNames, effectiveEditLocator, rowEditorOpen, rowEditorRowKey]);
const handleRefreshGrid = useCallback(() => {
clearAutoCommitTimer();

View File

@@ -20,7 +20,9 @@ export interface DataGridPaginationBarProps {
paginationSummaryText: string;
paginationControlTotal: number;
paginationTotalPages: number;
paginationPageText: string;
paginationPageSizeOptions: string[];
showKnownPageCount: boolean;
onPageChange?: (page: number, size: number) => void;
onPageSizeChange: (value: string) => void;
onV2PageStep: (direction: 'previous' | 'next') => void;
@@ -33,7 +35,9 @@ const DataGridPaginationBar: React.FC<DataGridPaginationBarProps> = ({
paginationSummaryText,
paginationControlTotal,
paginationTotalPages,
paginationPageText,
paginationPageSizeOptions,
showKnownPageCount,
onPageChange,
onPageSizeChange,
onV2PageStep,
@@ -58,7 +62,7 @@ const DataGridPaginationBar: React.FC<DataGridPaginationBarProps> = ({
if (normalizedJumpPage === pagination.current) return;
onPageChange(normalizedJumpPage, pagination.pageSize);
};
const jumpPageControl = (
const jumpPageControl = showKnownPageCount ? (
<div className="data-grid-pagination-jump" data-grid-pagination-jump="true">
<span className="data-grid-pagination-jump-label"></span>
<InputNumber
@@ -83,7 +87,7 @@ const DataGridPaginationBar: React.FC<DataGridPaginationBarProps> = ({
</Button>
</div>
);
) : null;
return (
<div
@@ -103,9 +107,15 @@ const DataGridPaginationBar: React.FC<DataGridPaginationBarProps> = ({
onClick={() => onV2PageStep('previous')}
/>
<div className="data-grid-pagination-page-chip" data-grid-v2-page-chip="true">
<strong>{pagination.current}</strong>
<span>/</span>
<span>{paginationTotalPages}</span>
{showKnownPageCount ? (
<>
<strong>{pagination.current}</strong>
<span>/</span>
<span>{paginationTotalPages}</span>
</>
) : (
<span>{paginationPageText}</span>
)}
</div>
<Button
data-grid-v2-pagination-next="true"

View File

@@ -331,6 +331,38 @@ describe('DataViewer safe editing locator', () => {
renderer.unmount();
});
it('uses hidden OceanBase Oracle ROWID when no primary or unique key is available', async () => {
storeState.connections[0].config.type = 'oceanbase';
(storeState.connections[0].config as any).oceanBaseProtocol = 'oracle';
storeState.connections[0].config.user = 'dev';
storeState.connections[0].config.database = 'ORCLPDB1';
backendApp.DBGetColumns.mockResolvedValue({
success: true,
data: [{ name: 'ID', key: '' }, { name: 'NAME', key: '' }],
});
backendApp.DBQuery.mockResolvedValue({
success: true,
fields: ['ID', 'NAME', ORACLE_ROWID_LOCATOR_COLUMN],
data: [{ ID: 7, NAME: 'old-name', [ORACLE_ROWID_LOCATOR_COLUMN]: 'AAAA' }],
});
const renderer = await renderAndReload(createTab({ id: 'tab-ob-oracle-rowid', dbName: 'ORCLPDB1', tableName: 'EDC_LOG', title: 'EDC_LOG' }));
expect(dataGridState.latestProps?.pkColumns).toEqual([]);
expect(dataGridState.latestProps?.editLocator).toMatchObject({
strategy: 'oracle-rowid',
columns: ['ROWID'],
valueColumns: [ORACLE_ROWID_LOCATOR_COLUMN],
hiddenColumns: [ORACLE_ROWID_LOCATOR_COLUMN],
readOnly: false,
});
expect(dataGridState.latestProps?.readOnly).toBe(false);
expect(dataGridState.latestProps?.showRowNumberColumn).toBe(true);
expect(messageApi.warning).not.toHaveBeenCalled();
expect(backendApp.DBQuery.mock.calls.some((call: any[]) => String(call[2]).includes(`ROWID AS "${ORACLE_ROWID_LOCATOR_COLUMN}"`))).toBe(true);
renderer.unmount();
});
it('does not add fallback ORDER BY for DuckDB table preview when a primary key is available', async () => {
storeState.connections[0].config.type = 'duckdb';
storeState.connections[0].config.database = 'main';

View File

@@ -7,7 +7,7 @@ import DataGrid, { GONAVI_ROW_KEY } from './DataGrid';
import { buildOrderBySQL, buildPaginatedSelectSQL, buildWhereSQL, hasExplicitSort, quoteIdentPart, quoteQualifiedIdent, reverseOrderBySQL, withSortBufferTuningSQL, type FilterCondition } from '../utils/sql';
import { buildMongoCountCommand, buildMongoFilter, buildMongoFindCommand, buildMongoSort } from '../utils/mongodb';
import { buildOracleApproximateTotalSql, parseApproximateTableCountRow, resolveApproximateTableCountStrategy } from '../utils/approximateTableCount';
import { getDataSourceCapabilities, resolveDataSourceType } from '../utils/dataSourceCapabilities';
import { getDataSourceCapabilities, resolveDataSourceType, shouldShowOceanBaseRowNumberColumn } from '../utils/dataSourceCapabilities';
import { resolveDataViewerAutoFetchAction } from '../utils/dataViewerAutoFetch';
import { buildRpcConnectionConfig } from '../utils/connectionRpcConfig';
import {
@@ -324,6 +324,7 @@ const DataViewer: React.FC<{ tab: TabData; isActive?: boolean }> = React.memo(({
const [quickWhereCondition, setQuickWhereCondition] = useState<string>(initialViewerSnapshot.quickWhereCondition);
const duckdbSafeSelectCacheRef = useRef<Record<string, string>>({});
const currentConnConfig = connections.find(c => c.id === tab.connectionId)?.config;
const showRowNumberColumn = shouldShowOceanBaseRowNumberColumn(currentConnConfig);
const currentConnCaps = getDataSourceCapabilities(currentConnConfig);
const forceReadOnly = currentConnCaps.forceReadOnlyQueryResult;
const preferManualTotalCount = currentConnCaps.preferManualTotalCount;
@@ -1110,6 +1111,7 @@ const DataViewer: React.FC<{ tab: TabData; isActive?: boolean }> = React.memo(({
connectionId={tab.connectionId}
pkColumns={pkColumns}
editLocator={editLocator}
showRowNumberColumn={showRowNumberColumn}
onReload={handleReload}
onSort={handleSort}
onPageChange={handlePageChange}

View File

@@ -1,7 +1,13 @@
import dayjs from 'dayjs';
import { describe, expect, it } from 'vitest';
import { getTemporalPickerType, resolveTemporalEditorSaveValue } from './dataGridTemporal';
import {
formatFromDayjs,
getTemporalPickerFormat,
getTemporalPickerType,
parseToDayjs,
resolveTemporalEditorSaveValue,
} from './dataGridTemporal';
describe('dataGridTemporal helpers', () => {
it('prefers the picker selected date when form store has not caught up yet', () => {
@@ -23,4 +29,39 @@ describe('dataGridTemporal helpers', () => {
expect(resolveTemporalEditorSaveValue(undefined, dayjs('2026-06-11 19:42:13'), pickerType))
.toBe('2026-06-11');
});
it('keeps OceanBase Oracle DATE columns as date-only editors', () => {
const pickerType = getTemporalPickerType('DATE', 'oracle', {
type: 'oceanbase',
oceanBaseProtocol: 'oracle',
} as any);
expect(pickerType).toBe('date');
expect(resolveTemporalEditorSaveValue(undefined, dayjs('2026-06-11 19:42:13'), pickerType))
.toBe('2026-06-11');
});
it('preserves datetime fractional seconds when round-tripping through the editor helper', () => {
const parsed = parseToDayjs('2026-06-16T16:46:23.158844Z', 'datetime');
expect(parsed?.isValid()).toBe(true);
expect(formatFromDayjs(parsed, 'datetime')).toBe('2026-06-16 16:46:23.158844');
expect(resolveTemporalEditorSaveValue(undefined, parsed, 'datetime')).toBe('2026-06-16 16:46:23.158844');
});
it('keeps RFC3339 wall clock text instead of applying local timezone conversion in datetime editors', () => {
const parsed = parseToDayjs('2026-06-17T05:00:00Z', 'datetime');
expect(parsed?.isValid()).toBe(true);
expect(formatFromDayjs(parsed, 'datetime')).toBe('2026-06-17 05:00:00');
});
it('uses a custom datetime picker format that can display preserved microseconds', () => {
const parsed = parseToDayjs('2026-06-16T16:46:23.158844Z', 'datetime');
const format = getTemporalPickerFormat('datetime');
expect(Array.isArray(format)).toBe(true);
expect(typeof format[0]).toBe('function');
expect((format[0] as (value: dayjs.Dayjs) => string)(parsed!)).toBe('2026-06-16 16:46:23.158844');
});
});

View File

@@ -1,7 +1,10 @@
import dayjs from 'dayjs';
import type { ConnectionConfig } from '../types';
import { normalizeOceanBaseProtocol } from '../utils/oceanBaseProtocol';
import { isOracleLikeDialect } from '../utils/sqlDialect';
export type TemporalPickerType = 'datetime' | 'date' | 'time' | 'year' | null;
export type TemporalConnectionLike = Pick<ConnectionConfig, 'type' | 'driver' | 'oceanBaseProtocol'> | null | undefined;
export const TEMPORAL_FORMATS: Record<string, string> = {
datetime: 'YYYY-MM-DD HH:mm:ss',
@@ -10,16 +13,100 @@ export const TEMPORAL_FORMATS: Record<string, string> = {
year: 'YYYY',
};
export const isTemporalColumnType = (columnType?: string, dbType?: string): boolean => {
return !!getTemporalPickerType(columnType, dbType);
const TEMPORAL_DATE_TIME_RE =
/^(\d{4}-\d{2}-\d{2})[T ](\d{2}:\d{2}:\d{2})(?:\.(\d{1,9}))?(?:\s*(?:Z|[+-]\d{2}:?\d{2})(?:\s+[A-Za-z_\/+-]+)?)?$/;
const temporalFractionMetaKey = Symbol('temporalFractionMeta');
type DayjsWithTemporalFractionMeta = dayjs.Dayjs & {
[temporalFractionMetaKey]?: string;
};
export const getTemporalPickerType = (columnType?: string, dbType?: string): TemporalPickerType => {
const parseTemporalDateTimeParts = (value: string): { datePart: string; timePart: string; fractionDigits: string } | null => {
const match = String(value || '').trim().match(TEMPORAL_DATE_TIME_RE);
if (!match) return null;
return {
datePart: match[1],
timePart: match[2],
fractionDigits: match[3] || '',
};
};
const attachTemporalFractionMeta = (value: dayjs.Dayjs, fractionDigits: string): dayjs.Dayjs => {
if (!value?.isValid?.()) return value;
const normalizedDigits = String(fractionDigits || '');
if (!normalizedDigits) return value;
(value as DayjsWithTemporalFractionMeta)[temporalFractionMetaKey] = normalizedDigits;
return value;
};
const getTemporalFractionMeta = (value: dayjs.Dayjs | null | undefined): string => {
if (!value || !value.isValid()) return '';
return String((value as DayjsWithTemporalFractionMeta)[temporalFractionMetaKey] || '');
};
const buildDayjsParseTextForDateTime = (datePart: string, timePart: string, fractionDigits: string): string => {
if (!fractionDigits) return `${datePart} ${timePart}`;
const milliseconds = fractionDigits.slice(0, 3).padEnd(3, '0');
return `${datePart} ${timePart}.${milliseconds}`;
};
const formatDateTimeWithFractionMeta = (value: dayjs.Dayjs): string => {
const base = value.format(TEMPORAL_FORMATS.datetime);
const fractionDigits = getTemporalFractionMeta(value);
if (!fractionDigits) return base;
const milliseconds = String(value.millisecond()).padStart(3, '0');
if (fractionDigits.length <= 3) {
return `${base}.${milliseconds.slice(0, fractionDigits.length)}`;
}
return `${base}.${milliseconds}${fractionDigits.slice(3)}`;
};
export const getTemporalPickerFormat = (
pickerType: TemporalPickerType,
): string | ((value: dayjs.Dayjs) => string) | Array<string | ((value: dayjs.Dayjs) => string)> => {
if (pickerType !== 'datetime') {
return TEMPORAL_FORMATS[pickerType || 'datetime'];
}
return [
(value: dayjs.Dayjs) => formatDateTimeWithFractionMeta(value),
'YYYY-MM-DD HH:mm:ss.SSSSSS',
'YYYY-MM-DD HH:mm:ss.SSS',
'YYYY-MM-DD HH:mm:ss',
];
};
export const isTemporalColumnType = (
columnType?: string,
dbType?: string,
connectionConfig?: TemporalConnectionLike,
): boolean => {
return !!getTemporalPickerType(columnType, dbType, connectionConfig);
};
const isOceanBaseOracleDateOnlyConnection = (connectionConfig?: TemporalConnectionLike): boolean => {
if (!connectionConfig) return false;
const type = String(connectionConfig.type || '').trim().toLowerCase();
const driver = String(connectionConfig.driver || '').trim().toLowerCase();
return (type === 'oceanbase' || driver === 'oceanbase')
&& normalizeOceanBaseProtocol(connectionConfig.oceanBaseProtocol) === 'oracle';
};
export const getTemporalPickerType = (
columnType?: string,
dbType?: string,
connectionConfig?: TemporalConnectionLike,
): TemporalPickerType => {
const raw = String(columnType || '').trim().toLowerCase();
if (!raw) return null;
if (raw.includes('datetime') || raw.includes('timestamp')) return 'datetime';
const base = raw.split(/[ (]/)[0];
if (base === 'date') return isOracleLikeDialect(String(dbType || '')) ? 'datetime' : 'date';
if (base === 'date') {
if (isOracleLikeDialect(String(dbType || '')) && !isOceanBaseOracleDateOnlyConnection(connectionConfig)) {
return 'datetime';
}
return 'date';
}
if (base === 'time') return 'time';
if (base === 'year') return 'year';
return null;
@@ -29,13 +116,32 @@ export const parseToDayjs = (val: any, pickerType: TemporalPickerType): dayjs.Da
if (val === null || val === undefined || val === '') return null;
const str = String(val).trim();
if (!str || /^0{4}-0{2}-0{2}/.test(str)) return null;
if (pickerType === 'datetime') {
const parts = parseTemporalDateTimeParts(str);
if (parts) {
const parsed = dayjs(buildDayjsParseTextForDateTime(parts.datePart, parts.timePart, parts.fractionDigits));
if (parsed.isValid()) {
return attachTemporalFractionMeta(parsed, parts.fractionDigits);
}
}
}
const fmt = TEMPORAL_FORMATS[pickerType || 'datetime'];
const d = dayjs(str, fmt);
return d.isValid() ? d : dayjs(str).isValid() ? dayjs(str) : null;
if (d.isValid()) {
const parts = pickerType === 'datetime' ? parseTemporalDateTimeParts(str) : null;
return parts ? attachTemporalFractionMeta(d, parts.fractionDigits) : d;
}
const fallback = dayjs(str);
if (!fallback.isValid()) return null;
const parts = pickerType === 'datetime' ? parseTemporalDateTimeParts(str) : null;
return parts ? attachTemporalFractionMeta(fallback, parts.fractionDigits) : fallback;
};
export const formatFromDayjs = (val: dayjs.Dayjs | null, pickerType: TemporalPickerType): string => {
if (!val || !val.isValid()) return '';
if (pickerType === 'datetime') {
return formatDateTimeWithFractionMeta(val);
}
const fmt = TEMPORAL_FORMATS[pickerType || 'datetime'];
return val.format(fmt);
};

View File

@@ -1,6 +1,6 @@
import { describe, expect, it } from 'vitest';
import { getDataSourceCapabilities } from './dataSourceCapabilities';
import { getDataSourceCapabilities, shouldShowOceanBaseRowNumberColumn } from './dataSourceCapabilities';
describe('dataSourceCapabilities', () => {
it('treats Oracle table preview totals as manual exact count plus approximate metadata count', () => {
@@ -258,6 +258,14 @@ describe('dataSourceCapabilities', () => {
});
});
it('shows row numbers for OceanBase datasources regardless of protocol normalization', () => {
expect(shouldShowOceanBaseRowNumberColumn({ type: 'oceanbase' })).toBe(true);
expect(shouldShowOceanBaseRowNumberColumn({ type: 'oceanbase', oceanBaseProtocol: 'oracle' })).toBe(true);
expect(shouldShowOceanBaseRowNumberColumn({ type: 'custom', driver: 'oceanbase', oceanBaseProtocol: 'oracle' })).toBe(true);
expect(shouldShowOceanBaseRowNumberColumn({ type: 'oracle' })).toBe(false);
expect(shouldShowOceanBaseRowNumberColumn({ type: 'mysql' })).toBe(false);
});
it('treats custom OceanBase Oracle driver as Oracle capabilities', () => {
expect(getDataSourceCapabilities({
type: 'custom',

View File

@@ -82,6 +82,13 @@ export const resolveDataSourceType = (config: ConnectionLike): string => {
return type;
};
export const shouldShowOceanBaseRowNumberColumn = (config: ConnectionLike): boolean => {
if (!config) return false;
const type = normalizeDataSourceToken(String(config.type || ''));
const driver = normalizeDataSourceToken(String(config.driver || ''));
return type === 'oceanbase' || driver === 'oceanbase';
};
const SQL_QUERY_EXPORT_TYPES = new Set([
'mysql',
'goldendb',

View File

@@ -1511,9 +1511,23 @@ func resolveCreateStatementWithFallback(dbInst db.Database, config connection.Co
sqlStr, sourceErr := dbInst.GetCreateStatement(metadataSchemaName, metadataTableName)
if sourceErr == nil && !shouldFallbackCreateStatement(dbType, sqlStr) {
if strings.TrimSpace(sqlStr) != "" {
return sqlStr, nil
}
if isOceanBaseOracleProtocol(config) {
if showDDL, ok := tryGetOceanBaseOracleShowCreateStatement(dbInst, metadataSchemaName, metadataTableName); ok {
return showDDL, nil
}
}
return sqlStr, nil
}
if isOceanBaseOracleProtocol(config) {
if showDDL, ok := tryGetOceanBaseOracleShowCreateStatement(dbInst, metadataSchemaName, metadataTableName); ok {
return showDDL, nil
}
}
if supportsViewCreateStatementLookup(dbType) {
if viewDDL, ok := tryGetViewCreateStatement(dbInst, config, dbName, ddlSchemaName, ddlTableName); ok {
return viewDDL, nil
@@ -1545,6 +1559,42 @@ func resolveCreateStatementWithFallback(dbInst db.Database, config connection.Co
return fallbackDDL, nil
}
func tryGetOceanBaseOracleShowCreateStatement(dbInst db.Database, schemaName string, tableName string) (string, bool) {
query := "SHOW CREATE TABLE " + quoteOracleMetadataTableRef(schemaName, tableName)
data, _, err := dbInst.Query(query)
if err != nil {
return "", false
}
for _, row := range data {
for _, key := range []string{"Create Table", "CREATE TABLE", "CREATE_TABLE", "DDL", "ddl"} {
if val, ok := row[key]; ok {
text := strings.TrimSpace(fmt.Sprintf("%v", val))
if text != "" && !strings.EqualFold(text, "<nil>") {
return text, true
}
}
}
for _, val := range row {
text := strings.TrimSpace(fmt.Sprintf("%v", val))
lower := strings.ToLower(text)
if strings.HasPrefix(lower, "create table") ||
strings.HasPrefix(lower, "create view") ||
strings.HasPrefix(lower, "create or replace view") {
return text, true
}
}
if len(row) == 1 {
for _, val := range row {
text := strings.TrimSpace(fmt.Sprintf("%v", val))
if text != "" && !strings.EqualFold(text, "<nil>") {
return text, true
}
}
}
}
return "", false
}
func supportsCreateStatementFallback(dbType string) bool {
switch dbType {
case "postgres", "kingbase", "highgo", "vastbase", "opengauss", "gaussdb", "sqlserver":
@@ -1720,10 +1770,318 @@ func (a *App) DBGetColumns(config connection.ConnectionConfig, dbName string, ta
logger.Error(err, "DBGetColumns 获取列定义失败:%s 表=%s.%s schema=%s pureTable=%s", formatConnSummary(runConfig), dbName, tableName, schemaName, pureTableName)
return connection.QueryResult{Success: false, Message: err.Error()}
}
if len(columns) == 0 && resolveDDLDBType(config) == "oracle" {
if inferred, inferErr := inferOracleColumnsFromDictionary(dbInst, schemaName, pureTableName); inferErr == nil && len(inferred) > 0 {
columns = inferred
}
if len(columns) == 0 {
if inferred, inferErr := inferOracleColumnsFromEmptySelect(dbInst, schemaName, pureTableName); inferErr == nil && len(inferred) > 0 {
columns = inferred
}
}
}
return connection.QueryResult{Success: true, Data: ensureNonNilSlice(columns)}
}
func inferOracleColumnsFromDictionary(dbInst db.Database, schemaName string, tableName string) ([]connection.ColumnDefinition, error) {
var lastErr error
for _, candidate := range appOracleMetadataNamePairs(schemaName, tableName) {
data, _, err := dbInst.Query(buildAppOracleColumnsQuery(candidate.schema, candidate.table))
if err != nil {
lastErr = err
continue
}
columns := parseAppOracleColumns(data)
if len(columns) > 0 {
return columns, nil
}
}
if lastErr != nil {
return nil, lastErr
}
return nil, fmt.Errorf("未获取到字段定义")
}
type appOracleMetadataNamePair struct {
schema string
table string
}
func appOracleMetadataNamePairs(schemaName string, tableName string) []appOracleMetadataNamePair {
rawSchema := strings.TrimSpace(schemaName)
rawTable := strings.TrimSpace(tableName)
if rawTable == "" {
return nil
}
upperSchema := strings.ToUpper(rawSchema)
upperTable := strings.ToUpper(rawTable)
pairs := make([]appOracleMetadataNamePair, 0, 4)
seen := map[string]struct{}{}
add := func(schema string, table string) {
key := schema + "\x00" + table
if _, exists := seen[key]; exists {
return
}
seen[key] = struct{}{}
pairs = append(pairs, appOracleMetadataNamePair{schema: schema, table: table})
}
add(rawSchema, rawTable)
add(upperSchema, upperTable)
add(rawSchema, upperTable)
add(upperSchema, rawTable)
return pairs
}
func buildAppOracleColumnsQuery(schema string, table string) string {
metadataTableName := escapeAppOracleMetadataLiteral(table)
metadataSchemaName := escapeAppOracleMetadataLiteral(schema)
if strings.TrimSpace(schema) == "" {
return fmt.Sprintf(`SELECT c.column_name AS "COLUMN_NAME", c.data_type AS "DATA_TYPE", c.data_length AS "DATA_LENGTH", c.char_length AS "CHAR_LENGTH", c.data_precision AS "DATA_PRECISION", c.data_scale AS "DATA_SCALE", c.nullable AS "NULLABLE", c.data_default AS "DATA_DEFAULT",
CASE WHEN pk.column_name IS NOT NULL THEN 'PRI' ELSE '' END AS "COLUMN_KEY",
cc.comments AS "COMMENT"
FROM user_tab_columns c
LEFT JOIN user_col_comments cc
ON cc.table_name = c.table_name AND cc.column_name = c.column_name
LEFT JOIN (
SELECT cols.table_name, cols.column_name
FROM user_constraints cons
JOIN user_cons_columns cols
ON cons.constraint_name = cols.constraint_name
WHERE cons.constraint_type = 'P'
) pk ON c.table_name = pk.table_name AND c.column_name = pk.column_name
WHERE c.table_name = '%s'
ORDER BY c.column_id`, metadataTableName)
}
return fmt.Sprintf(`SELECT c.column_name AS "COLUMN_NAME", c.data_type AS "DATA_TYPE", c.data_length AS "DATA_LENGTH", c.char_length AS "CHAR_LENGTH", c.data_precision AS "DATA_PRECISION", c.data_scale AS "DATA_SCALE", c.nullable AS "NULLABLE", c.data_default AS "DATA_DEFAULT",
CASE WHEN pk.column_name IS NOT NULL THEN 'PRI' ELSE '' END AS "COLUMN_KEY",
cc.comments AS "COMMENT"
FROM all_tab_columns c
LEFT JOIN all_col_comments cc
ON cc.owner = c.owner AND cc.table_name = c.table_name AND cc.column_name = c.column_name
LEFT JOIN (
SELECT cols.owner, cols.table_name, cols.column_name
FROM all_constraints cons
JOIN all_cons_columns cols
ON cons.owner = cols.owner AND cons.constraint_name = cols.constraint_name
WHERE cons.constraint_type = 'P'
) pk ON c.owner = pk.owner AND c.table_name = pk.table_name AND c.column_name = pk.column_name
WHERE c.owner = '%s' AND c.table_name = '%s'
ORDER BY c.column_id`, metadataSchemaName, metadataTableName)
}
func parseAppOracleColumns(data []map[string]interface{}) []connection.ColumnDefinition {
columns := make([]connection.ColumnDefinition, 0, len(data))
for _, row := range data {
name := appOracleRowString(row, "COLUMN_NAME", "column_name")
if strings.TrimSpace(name) == "" {
continue
}
defaultValue := appOracleRowString(row, "DATA_DEFAULT", "COLUMN_DEFAULT", "data_default", "column_default")
col := connection.ColumnDefinition{
Name: name,
Type: formatAppOracleColumnType(row),
Nullable: normalizeAppOracleNullable(appOracleRowString(row, "NULLABLE", "nullable")),
Key: appOracleRowString(row, "COLUMN_KEY", "column_key", "KEY", "key"),
Extra: appOracleAutoIncrementExtra(defaultValue),
Comment: appOracleRowString(row, "COMMENT", "COMMENTS", "comment", "comments"),
}
if defaultValue != "" {
col.Default = &defaultValue
}
columns = append(columns, col)
}
return columns
}
func formatAppOracleColumnType(row map[string]interface{}) string {
dataType := appOracleRowString(row, "DATA_TYPE", "TYPE_NAME", "data_type", "type_name")
if dataType == "" || strings.Contains(dataType, "(") {
return dataType
}
upperType := strings.ToUpper(dataType)
if isAppOracleLengthQualifiedType(upperType) {
if charLength, ok := appOracleRowInt(row, "CHAR_LENGTH", "CHAR_COL_DECL_LENGTH", "char_length", "char_col_decl_length"); ok && charLength > 0 {
return fmt.Sprintf("%s(%d)", dataType, charLength)
}
if dataLength, ok := appOracleRowInt(row, "DATA_LENGTH", "data_length"); ok && dataLength > 0 {
return fmt.Sprintf("%s(%d)", dataType, dataLength)
}
}
if strings.Contains(upperType, "NUMBER") || strings.Contains(upperType, "DECIMAL") || strings.Contains(upperType, "NUMERIC") {
precision, hasPrecision := appOracleRowInt(row, "DATA_PRECISION", "NUMERIC_PRECISION", "data_precision", "numeric_precision")
if hasPrecision && precision > 0 {
scale, hasScale := appOracleRowInt(row, "DATA_SCALE", "NUMERIC_SCALE", "data_scale", "numeric_scale")
if hasScale && scale > 0 {
return fmt.Sprintf("%s(%d,%d)", dataType, precision, scale)
}
return fmt.Sprintf("%s(%d)", dataType, precision)
}
}
return dataType
}
func isAppOracleLengthQualifiedType(upperType string) bool {
switch strings.TrimSpace(upperType) {
case "CHAR", "NCHAR", "VARCHAR", "VARCHAR2", "NVARCHAR", "NVARCHAR2", "RAW", "BINARY", "VARBINARY":
return true
default:
return strings.Contains(upperType, "CHARACTER")
}
}
func normalizeAppOracleNullable(nullable string) string {
switch strings.ToUpper(strings.TrimSpace(nullable)) {
case "N", "NO":
return "NO"
case "Y", "YES":
return "YES"
default:
return strings.TrimSpace(nullable)
}
}
func appOracleAutoIncrementExtra(defaultValue string) string {
if strings.Contains(strings.ToUpper(strings.TrimSpace(defaultValue)), "NEXTVAL") {
return "auto_increment"
}
return ""
}
func appOracleRowValue(row map[string]interface{}, names ...string) interface{} {
for _, name := range names {
if value, ok := row[name]; ok {
return value
}
}
for key, value := range row {
for _, name := range names {
if strings.EqualFold(key, name) {
return value
}
}
}
return nil
}
func appOracleRowString(row map[string]interface{}, names ...string) string {
return appOracleValueString(appOracleRowValue(row, names...))
}
func appOracleValueString(value interface{}) string {
if value == nil {
return ""
}
switch typed := value.(type) {
case []byte:
return strings.TrimSpace(string(typed))
case string:
return strings.TrimSpace(typed)
default:
text := strings.TrimSpace(fmt.Sprintf("%v", typed))
if strings.EqualFold(text, "<nil>") {
return ""
}
return text
}
}
func appOracleRowInt(row map[string]interface{}, names ...string) (int, bool) {
value := appOracleRowValue(row, names...)
switch typed := value.(type) {
case int:
return typed, true
case int8:
return int(typed), true
case int16:
return int(typed), true
case int32:
return int(typed), true
case int64:
return int(typed), true
case uint:
return int(typed), true
case uint8:
return int(typed), true
case uint16:
return int(typed), true
case uint32:
return int(typed), true
case uint64:
return int(typed), true
case float32:
return int(typed), true
case float64:
return int(typed), true
case []byte:
parsed, err := strconv.Atoi(strings.TrimSpace(string(typed)))
return parsed, err == nil
case string:
parsed, err := strconv.Atoi(strings.TrimSpace(typed))
return parsed, err == nil
default:
return 0, false
}
}
func escapeAppOracleMetadataLiteral(text string) string {
return strings.ReplaceAll(strings.TrimSpace(text), "'", "''")
}
func inferOracleColumnsFromEmptySelect(dbInst db.Database, schemaName string, tableName string) ([]connection.ColumnDefinition, error) {
table := strings.TrimSpace(tableName)
if table == "" {
return nil, fmt.Errorf("表名不能为空")
}
query := "SELECT * FROM " + quoteOracleMetadataTableRef(schemaName, table) + " WHERE 1 = 0"
_, fields, err := dbInst.Query(query)
if err != nil {
return nil, err
}
if len(fields) == 0 {
return nil, fmt.Errorf("未获取到字段定义")
}
columns := make([]connection.ColumnDefinition, 0, len(fields))
for _, field := range fields {
name := strings.TrimSpace(field)
if name == "" {
continue
}
columns = append(columns, connection.ColumnDefinition{
Name: name,
Nullable: "",
Key: "",
Extra: "",
Comment: "",
})
}
if len(columns) == 0 {
return nil, fmt.Errorf("未获取到字段定义")
}
return columns, nil
}
func quoteOracleMetadataIdentifier(ident string) string {
return `"` + strings.ReplaceAll(strings.TrimSpace(ident), `"`, `""`) + `"`
}
func quoteOracleMetadataTableRef(schemaName string, tableName string) string {
tableRef := quoteOracleMetadataIdentifier(tableName)
if strings.TrimSpace(schemaName) != "" {
return quoteOracleMetadataIdentifier(schemaName) + "." + tableRef
}
return tableRef
}
func (a *App) DBGetIndexes(config connection.ConnectionConfig, dbName string, tableName string) connection.QueryResult {
runConfig := normalizeRunConfig(config, dbName)

View File

@@ -167,6 +167,14 @@ func TestFormatConnSummary_DefaultTimeout(t *testing.T) {
}
func TestDBReleaseConnectionClosesAllDatabaseCacheEntriesForSameInstance(t *testing.T) {
proxySnapshot := currentGlobalProxyConfig()
if _, err := setGlobalProxyConfig(false, proxySnapshot.Proxy); err != nil {
t.Fatalf("disable global proxy failed: %v", err)
}
t.Cleanup(func() {
_, _ = setGlobalProxyConfig(proxySnapshot.Enabled, proxySnapshot.Proxy)
})
app := NewApp()
mainConfig := connection.ConnectionConfig{Type: "mysql", Host: "127.0.0.1", Port: 3306, User: "root", Database: "main"}
analyticsConfig := mainConfig

View File

@@ -332,6 +332,37 @@ func TestResolveCreateStatementWithFallback_ReturnsCreateViewDirectly(t *testing
}
}
func TestResolveCreateStatementWithFallback_OceanBaseOracleUsesShowCreateWhenAgentDDLIsEmpty(t *testing.T) {
t.Parallel()
dbInst := &fakeCreateStatementDB{
createSQL: "",
queryRows: []map[string]interface{}{
{"Create Table": `CREATE TABLE "SYS"."test" ("id" NUMBER)`},
},
}
ddl, err := resolveCreateStatementWithFallback(dbInst, connection.ConnectionConfig{
Type: "oceanbase",
ConnectionParams: "protocol=oracle",
}, "SYS", "SYS.test")
if err != nil {
t.Fatalf("resolveCreateStatementWithFallback() unexpected error: %v", err)
}
if ddl != `CREATE TABLE "SYS"."test" ("id" NUMBER)` {
t.Fatalf("expected SHOW CREATE TABLE fallback DDL, got: %s", ddl)
}
if dbInst.createSchema != "SYS" || dbInst.createTable != "test" {
t.Fatalf("expected metadata target SYS.test, got %q.%q", dbInst.createSchema, dbInst.createTable)
}
if len(dbInst.queries) != 1 || dbInst.queries[0] != `SHOW CREATE TABLE "SYS"."test"` {
t.Fatalf("expected SHOW CREATE TABLE query, got: %v", dbInst.queries)
}
if dbInst.columnsCalls != 0 {
t.Fatalf("OceanBase Oracle SHOW CREATE fallback should not call GetColumns, calls=%d", dbInst.columnsCalls)
}
}
func TestResolveCreateStatementWithFallback_PGLikeViewHelperBeforeColumnFallback(t *testing.T) {
t.Parallel()

View File

@@ -28,6 +28,11 @@ type fakeMetadataRetryDB struct {
indexes []connection.IndexDefinition
columnsErr error
indexesErr error
queryResults []fakeMetadataQueryResult
queryRows []map[string]interface{}
queryFields []string
queryErr error
queries []string
columnCalls int
indexCalls int
columnSchema string
@@ -36,11 +41,27 @@ type fakeMetadataRetryDB struct {
indexTable string
}
type fakeMetadataQueryResult struct {
match string
rows []map[string]interface{}
fields []string
err error
}
func (f *fakeMetadataRetryDB) Connect(config connection.ConnectionConfig) error { return nil }
func (f *fakeMetadataRetryDB) Close() error { return nil }
func (f *fakeMetadataRetryDB) Ping() error { return nil }
func (f *fakeMetadataRetryDB) Query(query string) ([]map[string]interface{}, []string, error) {
return nil, nil, nil
f.queries = append(f.queries, query)
for _, result := range f.queryResults {
if result.match == "" || strings.Contains(query, result.match) {
return result.rows, result.fields, result.err
}
}
if f.queryErr != nil {
return nil, nil, f.queryErr
}
return f.queryRows, f.queryFields, nil
}
func (f *fakeMetadataRetryDB) Exec(query string) (int64, error) { return 0, nil }
func (f *fakeMetadataRetryDB) GetDatabases() ([]string, error) { return nil, nil }
@@ -238,6 +259,132 @@ func TestDBGetColumnsKeepsDatabaseForMySQLMetadata(t *testing.T) {
}
}
func TestDBGetColumnsInfersOceanBaseOracleFieldsWhenAgentMetadataIsEmpty(t *testing.T) {
originalNewDatabaseFunc := newDatabaseFunc
originalResolveDialConfigWithProxyFunc := resolveDialConfigWithProxyFunc
t.Cleanup(func() {
newDatabaseFunc = originalNewDatabaseFunc
resolveDialConfigWithProxyFunc = originalResolveDialConfigWithProxyFunc
})
dbInst := &fakeMetadataRetryDB{
columns: []connection.ColumnDefinition{},
queryResults: []fakeMetadataQueryResult{
{
match: "FROM all_tab_columns c",
rows: []map[string]interface{}{
{
"COLUMN_NAME": "id",
"DATA_TYPE": "NUMBER",
"NULLABLE": "N",
"DATA_DEFAULT": "SEQUENCE.NEXTVAL",
"COLUMN_KEY": "PRI",
"COMMENT": "",
"DATA_PRECISION": nil,
"DATA_SCALE": nil,
},
{
"COLUMN_NAME": "new_col_1",
"DATA_TYPE": "VARCHAR2",
"CHAR_LENGTH": 255,
"NULLABLE": "Y",
"COLUMN_KEY": "",
"COMMENT": "",
},
},
fields: []string{"COLUMN_NAME", "DATA_TYPE", "DATA_LENGTH", "CHAR_LENGTH", "DATA_PRECISION", "DATA_SCALE", "NULLABLE", "DATA_DEFAULT", "COLUMN_KEY", "COMMENT"},
},
},
}
newDatabaseFunc = func(dbType string) (db.Database, error) {
return dbInst, nil
}
resolveDialConfigWithProxyFunc = func(raw connection.ConnectionConfig) (connection.ConnectionConfig, error) {
return raw, nil
}
app := NewAppWithSecretStore(secretstore.NewUnavailableStore("test"))
result := app.DBGetColumns(connection.ConnectionConfig{
Type: "oceanbase",
Host: "127.0.0.1",
Port: 12881,
User: "SYS",
ConnectionParams: "protocol=oracle",
}, "SYS", "SYS.test")
if !result.Success {
t.Fatalf("expected DBGetColumns success, got failure: %s", result.Message)
}
if dbInst.columnSchema != "SYS" || dbInst.columnTable != "test" {
t.Fatalf("expected OceanBase Oracle metadata to split schema/table, got %q.%q", dbInst.columnSchema, dbInst.columnTable)
}
if len(dbInst.queries) != 1 || !strings.Contains(dbInst.queries[0], "FROM all_tab_columns c") {
t.Fatalf("expected dictionary metadata fallback query, got %v", dbInst.queries)
}
columns, ok := result.Data.([]connection.ColumnDefinition)
if !ok {
t.Fatalf("expected []connection.ColumnDefinition, got %T", result.Data)
}
if len(columns) != 2 || columns[0].Name != "id" || columns[1].Name != "new_col_1" {
t.Fatalf("unexpected inferred columns: %#v", columns)
}
if columns[0].Type != "NUMBER" || columns[0].Nullable != "NO" || columns[0].Key != "PRI" || columns[0].Extra != "auto_increment" {
t.Fatalf("expected id to keep type/not-null/primary-key/auto-increment metadata, got %#v", columns[0])
}
if columns[0].Default == nil || *columns[0].Default != "SEQUENCE.NEXTVAL" {
t.Fatalf("expected id default to keep sequence nextval, got %#v", columns[0].Default)
}
if columns[1].Type != "VARCHAR2(255)" || columns[1].Nullable != "YES" || columns[1].Key != "" {
t.Fatalf("expected new_col_1 to keep varchar nullable metadata, got %#v", columns[1])
}
}
func TestDBGetColumnsFallsBackToEmptySelectWhenOceanBaseOracleDictionaryIsEmpty(t *testing.T) {
originalNewDatabaseFunc := newDatabaseFunc
originalResolveDialConfigWithProxyFunc := resolveDialConfigWithProxyFunc
t.Cleanup(func() {
newDatabaseFunc = originalNewDatabaseFunc
resolveDialConfigWithProxyFunc = originalResolveDialConfigWithProxyFunc
})
dbInst := &fakeMetadataRetryDB{
columns: []connection.ColumnDefinition{},
queryResults: []fakeMetadataQueryResult{
{match: "FROM all_tab_columns c", rows: []map[string]interface{}{}},
{match: `SELECT * FROM "SYS"."test" WHERE 1 = 0`, fields: []string{"id", "new_col_1"}},
},
}
newDatabaseFunc = func(dbType string) (db.Database, error) {
return dbInst, nil
}
resolveDialConfigWithProxyFunc = func(raw connection.ConnectionConfig) (connection.ConnectionConfig, error) {
return raw, nil
}
app := NewAppWithSecretStore(secretstore.NewUnavailableStore("test"))
result := app.DBGetColumns(connection.ConnectionConfig{
Type: "oceanbase",
Host: "127.0.0.1",
Port: 12881,
User: "SYS",
ConnectionParams: "protocol=oracle",
}, "SYS", "SYS.test")
if !result.Success {
t.Fatalf("expected DBGetColumns success, got failure: %s", result.Message)
}
if len(dbInst.queries) < 2 {
t.Fatalf("expected dictionary and empty-select fallback queries, got %v", dbInst.queries)
}
columns, ok := result.Data.([]connection.ColumnDefinition)
if !ok {
t.Fatalf("expected []connection.ColumnDefinition, got %T", result.Data)
}
if len(columns) != 2 || columns[0].Name != "id" || columns[1].Name != "new_col_1" {
t.Fatalf("unexpected inferred columns: %#v", columns)
}
}
func TestDBGetColumnsKeepsDuckDBQualifiedTableMetadata(t *testing.T) {
originalNewDatabaseFunc := newDatabaseFunc
originalResolveDialConfigWithProxyFunc := resolveDialConfigWithProxyFunc

View File

@@ -135,11 +135,16 @@ type TransactionExecerProvider interface {
}
type sqlConnStatementExecer struct {
conn *sql.Conn
conn *sql.Conn
scanDialect string
}
func NewSQLConnStatementExecer(conn *sql.Conn) StatementExecer {
return &sqlConnStatementExecer{conn: conn}
return NewSQLConnStatementExecerWithDialect(conn, "")
}
func NewSQLConnStatementExecerWithDialect(conn *sql.Conn, scanDialect string) StatementExecer {
return &sqlConnStatementExecer{conn: conn, scanDialect: scanDialect}
}
func (e *sqlConnStatementExecer) ExecContext(ctx context.Context, query string) (int64, error) {
@@ -166,7 +171,7 @@ func (e *sqlConnStatementExecer) QueryContext(ctx context.Context, query string)
return nil, nil, err
}
defer rows.Close()
return scanRows(rows)
return scanRowsForDialect(rows, e.scanDialect)
}
func (e *sqlConnStatementExecer) Query(query string) ([]map[string]interface{}, []string, error) {
@@ -182,7 +187,7 @@ func (e *sqlConnStatementExecer) QueryMultiContext(ctx context.Context, query st
return nil, err
}
defer rows.Close()
return scanMultiRows(rows)
return scanMultiRowsForDialect(rows, e.scanDialect)
}
func (e *sqlConnStatementExecer) QueryMulti(query string) ([]connection.ResultSetData, error) {
@@ -206,13 +211,19 @@ type sqlConnTransactionExecer struct {
done bool
commitSQL string
rollbackSQL string
scanDialect string
}
func NewSQLConnTransactionExecer(conn *sql.Conn, commitSQL string, rollbackSQL string) TransactionExecer {
return NewSQLConnTransactionExecerWithDialect(conn, commitSQL, rollbackSQL, "")
}
func NewSQLConnTransactionExecerWithDialect(conn *sql.Conn, commitSQL string, rollbackSQL string, scanDialect string) TransactionExecer {
return &sqlConnTransactionExecer{
conn: conn,
commitSQL: strings.TrimSpace(commitSQL),
rollbackSQL: strings.TrimSpace(rollbackSQL),
scanDialect: scanDialect,
}
}
@@ -257,7 +268,7 @@ func (e *sqlConnTransactionExecer) QueryContext(ctx context.Context, query strin
return nil, nil, err
}
defer rows.Close()
return scanRows(rows)
return scanRowsForDialect(rows, e.scanDialect)
}
func (e *sqlConnTransactionExecer) Query(query string) ([]map[string]interface{}, []string, error) {
@@ -274,7 +285,7 @@ func (e *sqlConnTransactionExecer) QueryMultiContext(ctx context.Context, query
return nil, err
}
defer rows.Close()
return scanMultiRows(rows)
return scanMultiRowsForDialect(rows, e.scanDialect)
}
func (e *sqlConnTransactionExecer) QueryMulti(query string) ([]connection.ResultSetData, error) {

View File

@@ -0,0 +1,87 @@
package db
import (
"context"
"database/sql"
"reflect"
"testing"
)
func openScanRowsDuplicateSQLConn(t *testing.T) *sql.Conn {
t.Helper()
registerScanRowsDuplicateDriverOnce.Do(func() {
sql.Register(scanRowsDuplicateDriverName, scanRowsDuplicateDriver{})
})
dbConn, err := sql.Open(scanRowsDuplicateDriverName, "")
if err != nil {
t.Fatalf("open duplicate scan rows db failed: %v", err)
}
t.Cleanup(func() {
_ = dbConn.Close()
})
conn, err := dbConn.Conn(context.Background())
if err != nil {
t.Fatalf("open sql conn failed: %v", err)
}
t.Cleanup(func() {
_ = conn.Close()
})
return conn
}
func TestSQLConnStatementExecerWithDialectDecodesOceanBaseOracleTimestamp(t *testing.T) {
t.Parallel()
conn := openScanRowsDuplicateSQLConn(t)
execer, ok := NewSQLConnStatementExecerWithDialect(conn, oceanBaseOracleScanDialect).(StatementMultiResultQueryExecer)
if !ok {
t.Fatal("statement execer should support multi-result query")
}
results, err := execer.QueryMultiContext(context.Background(), "SELECT timestamp_precision_columns")
if err != nil {
t.Fatalf("query multi failed: %v", err)
}
if len(results) != 1 {
t.Fatalf("expected one result set, got=%d", len(results))
}
if !reflect.DeepEqual(results[0].Columns, []string{"created_at"}) {
t.Fatalf("unexpected columns: %v", results[0].Columns)
}
if len(results[0].Rows) != 1 {
t.Fatalf("expected one row, got=%d", len(results[0].Rows))
}
if got := results[0].Rows[0]["created_at"]; got != "2026-06-16T12:34:56.123456Z" {
t.Fatalf("statement execer should decode OceanBase Oracle TIMESTAMP(6), got=%v(%T)", got, got)
}
}
func TestSQLConnTransactionExecerWithDialectDecodesOceanBaseOracleTimestamp(t *testing.T) {
t.Parallel()
conn := openScanRowsDuplicateSQLConn(t)
execer, ok := NewSQLConnTransactionExecerWithDialect(conn, "COMMIT", "ROLLBACK", oceanBaseOracleScanDialect).(StatementMultiResultQueryExecer)
if !ok {
t.Fatal("transaction execer should support multi-result query")
}
results, err := execer.QueryMultiContext(context.Background(), "SELECT timestamp_precision_columns")
if err != nil {
t.Fatalf("query multi failed: %v", err)
}
if len(results) != 1 {
t.Fatalf("expected one result set, got=%d", len(results))
}
if !reflect.DeepEqual(results[0].Columns, []string{"created_at"}) {
t.Fatalf("unexpected columns: %v", results[0].Columns)
}
if len(results[0].Rows) != 1 {
t.Fatalf("expected one row, got=%d", len(results[0].Rows))
}
if got := results[0].Rows[0]["created_at"]; got != "2026-06-16T12:34:56.123456Z" {
t.Fatalf("transaction execer should decode OceanBase Oracle TIMESTAMP(6), got=%v(%T)", got, got)
}
}

View File

@@ -566,7 +566,7 @@ func (o *OceanBaseDB) connectOracleViaTNS(config connection.ConnectionConfig) er
if strings.TrimSpace(runConfig.Database) == "" {
return fmt.Errorf("OceanBase Oracle 协议TNS 路径需要填写服务名Service Name请在连接配置中填写租户监听的服务名例如 ORCL / tenant_oracle 等)")
}
oracleDB := &OracleDB{}
oracleDB := &OracleDB{scanDialect: oceanBaseOracleScanDialect}
if err := oracleDB.Connect(runConfig); err != nil {
return annotateOceanBaseOracleConnectError(err)
}
@@ -660,7 +660,7 @@ func (o *OceanBaseDB) bindConnectedDatabase(db *sql.DB, timeout time.Duration, p
o.conn = nil
o.pingTimeout = 0
if protocol == oceanBaseProtocolOracle {
o.oracle = &OracleDB{conn: db, pingTimeout: timeout}
o.oracle = &OracleDB{conn: db, pingTimeout: timeout, scanDialect: oceanBaseOracleScanDialect}
o.protocol = oceanBaseProtocolOracle
return
}
@@ -871,9 +871,79 @@ func (o *OceanBaseDB) GetTables(dbName string) ([]string, error) {
}
func (o *OceanBaseDB) GetCreateStatement(dbName, tableName string) (string, error) {
if o.protocol == oceanBaseProtocolOracle && o.oracle != nil {
ddl, err := o.oracle.GetCreateStatement(dbName, tableName)
if err == nil && strings.TrimSpace(ddl) != "" {
return ddl, nil
}
showDDL, showErr := o.getOceanBaseOracleShowCreateStatement(dbName, tableName)
if showErr == nil {
return showDDL, nil
}
if err != nil {
return "", fmt.Errorf("%wOceanBase Oracle SHOW CREATE TABLE 兜底失败:%v", err, showErr)
}
return "", showErr
}
return o.activeDatabase().GetCreateStatement(dbName, tableName)
}
func (o *OceanBaseDB) getOceanBaseOracleShowCreateStatement(dbName string, tableName string) (string, error) {
var firstErr error
for _, candidate := range oracleMetadataNamePairs(dbName, tableName) {
query := buildOceanBaseOracleShowCreateTableQuery(candidate.schema, candidate.table)
data, _, err := o.oracle.Query(query)
if err != nil {
if firstErr == nil {
firstErr = err
}
continue
}
if ddl := extractOceanBaseOracleCreateStatement(data); ddl != "" {
return o.oracle.appendOracleCommentDDL(ddl, candidate.schema, candidate.table), nil
}
}
if firstErr != nil {
return "", firstErr
}
return "", fmt.Errorf("未找到建表语句")
}
func buildOceanBaseOracleShowCreateTableQuery(schema string, table string) string {
return "SHOW CREATE TABLE " + quoteOracleTableRef(schema, table)
}
func extractOceanBaseOracleCreateStatement(data []map[string]interface{}) string {
for _, row := range data {
for _, key := range []string{"Create Table", "CREATE TABLE", "CREATE_TABLE", "DDL", "ddl"} {
if val, ok := row[key]; ok {
text := strings.TrimSpace(fmt.Sprintf("%v", val))
if text != "" && !strings.EqualFold(text, "<nil>") {
return text
}
}
}
for _, val := range row {
text := strings.TrimSpace(fmt.Sprintf("%v", val))
lower := strings.ToLower(text)
if strings.HasPrefix(lower, "create table") ||
strings.HasPrefix(lower, "create view") ||
strings.HasPrefix(lower, "create or replace view") {
return text
}
}
if len(row) == 1 {
for _, val := range row {
text := strings.TrimSpace(fmt.Sprintf("%v", val))
if text != "" && !strings.EqualFold(text, "<nil>") {
return text
}
}
}
}
return ""
}
func (o *OceanBaseDB) GetColumns(dbName, tableName string) ([]connection.ColumnDefinition, error) {
return o.activeDatabase().GetColumns(dbName, tableName)
}
@@ -907,6 +977,60 @@ func (o *OceanBaseDB) ApplyChanges(tableName string, changes connection.ChangeSe
return fmt.Errorf("当前 OceanBase %s 协议不支持 ApplyChanges", o.protocol)
}
func buildOceanBaseOracleTemporalBind(columnType string, value interface{}) (string, interface{}, bool) {
if value == nil {
return "?", nil, false
}
rawType := strings.ToUpper(strings.TrimSpace(columnType))
if !isOracleTemporalColumnType(rawType) {
return "?", value, false
}
var parsed time.Time
switch typed := value.(type) {
case time.Time:
parsed = typed
case string:
text := strings.TrimSpace(typed)
if text == "" {
return "?", nil, false
}
var ok bool
parsed, ok = parseOracleTemporalString(text)
if !ok {
return "?", value, false
}
default:
return "?", value, false
}
if strings.Contains(rawType, "TIMESTAMP") {
text := parsed.Format("2006-01-02 15:04:05")
format := "YYYY-MM-DD HH24:MI:SS"
if parsed.Nanosecond() != 0 {
text = parsed.Format("2006-01-02 15:04:05.999999999")
text = strings.TrimRight(strings.TrimRight(text, "0"), ".")
format = "YYYY-MM-DD HH24:MI:SS.FF"
}
return fmt.Sprintf("TO_TIMESTAMP(?, '%s')", format), text, true
}
if parsed.Hour() == 0 && parsed.Minute() == 0 && parsed.Second() == 0 && parsed.Nanosecond() == 0 {
return "TO_DATE(?, 'YYYY-MM-DD')", parsed.Format("2006-01-02"), true
}
return "TO_DATE(?, 'YYYY-MM-DD HH24:MI:SS')", parsed.Format("2006-01-02 15:04:05"), true
}
func buildOceanBaseOracleAssignment(columnName string, value interface{}, columnTypeMap map[string]string) (string, []interface{}) {
columnType := columnTypeMap[strings.ToLower(strings.TrimSpace(columnName))]
normalized := normalizeOracleValueForWrite(columnName, value, columnTypeMap)
if expr, bind, ok := buildOceanBaseOracleTemporalBind(columnType, normalized); ok {
return expr, []interface{}{bind}
}
return "?", []interface{}{normalized}
}
// applyOracleChangesMySQLWire 在 OceanBase Oracle 租户的 mysql wire 连接上执行
// DELETE/UPDATE/INSERT使用 Oracle 风格双引号引用标识符 + mysql wire 风格 "?" 占位符。
func (o *OceanBaseDB) applyOracleChangesMySQLWire(tableName string, changes connection.ChangeSet) error {
@@ -959,8 +1083,9 @@ func (o *OceanBaseDB) applyOracleChangesMySQLWire(tableName string, changes conn
args = append(args, v)
continue
}
wheres = append(wheres, fmt.Sprintf("%s = ?", quoteIdent(k)))
args = append(args, normalizeOracleValueForWrite(k, v, columnTypeMap))
valueExpr, valueArgs := buildOceanBaseOracleAssignment(k, v, columnTypeMap)
wheres = append(wheres, fmt.Sprintf("%s = %s", quoteIdent(k), valueExpr))
args = append(args, valueArgs...)
}
return wheres, args
}
@@ -985,8 +1110,9 @@ func (o *OceanBaseDB) applyOracleChangesMySQLWire(tableName string, changes conn
var args []interface{}
for k, v := range update.Values {
sets = append(sets, fmt.Sprintf("%s = ?", quoteIdent(k)))
args = append(args, normalizeOracleValueForWrite(k, v, columnTypeMap))
valueExpr, valueArgs := buildOceanBaseOracleAssignment(k, v, columnTypeMap)
sets = append(sets, fmt.Sprintf("%s = %s", quoteIdent(k), valueExpr))
args = append(args, valueArgs...)
}
if len(sets) == 0 {
@@ -1017,8 +1143,9 @@ func (o *OceanBaseDB) applyOracleChangesMySQLWire(tableName string, changes conn
for k, v := range row {
cols = append(cols, quoteIdent(k))
placeholders = append(placeholders, "?")
args = append(args, normalizeOracleValueForWrite(k, v, columnTypeMap))
valueExpr, valueArgs := buildOceanBaseOracleAssignment(k, v, columnTypeMap)
placeholders = append(placeholders, valueExpr)
args = append(args, valueArgs...)
}
if len(cols) == 0 {

View File

@@ -4,9 +4,11 @@ package db
import (
"context"
"database/sql/driver"
"errors"
"net"
"net/url"
"slices"
"strconv"
"strings"
"testing"
@@ -782,6 +784,81 @@ func TestOceanBaseOracleOBClientApplyChangesUsesMySQLWirePlaceholders(t *testing
}
}
func TestOceanBaseOracleOBClientApplyChangesFormatsTemporalValuesExplicitly(t *testing.T) {
t.Parallel()
dbConn, state := openOracleRecordingDB(t)
oceanbaseDB := &OceanBaseDB{}
oceanbaseDB.bindConnectedDatabase(dbConn, 0, oceanBaseProtocolOracle)
changes := connection.ChangeSet{
Updates: []connection.UpdateRow{{
Keys: map[string]interface{}{
"ID": int64(7),
},
Values: map[string]interface{}{
"UPDATED_AT": "2026-06-16 17:37:08",
},
}},
}
if err := oceanbaseDB.ApplyChanges("APP.USERS", changes); err != nil {
t.Fatalf("ApplyChanges() unexpected error: %v", err)
}
queries := state.snapshotExecQueries()
if len(queries) != 1 {
t.Fatalf("expected one exec query, got %#v", queries)
}
if !strings.Contains(queries[0], `"UPDATED_AT" = TO_TIMESTAMP(?, 'YYYY-MM-DD HH24:MI:SS')`) {
t.Fatalf("expected explicit TO_TIMESTAMP binding for temporal update, got %q", queries[0])
}
executions := state.snapshotExecArgs()
if len(executions) != 1 || len(executions[0]) != 2 {
t.Fatalf("unexpected exec args: %#v", executions)
}
if got, ok := executions[0][0].Value.(string); !ok || got != "2026-06-16 17:37:08" {
t.Fatalf("expected temporal bind arg kept as canonical string, got %#v (%T)", executions[0][0].Value, executions[0][0].Value)
}
}
func TestOceanBaseOracleGetCreateStatementFallsBackToShowCreateTable(t *testing.T) {
t.Parallel()
dbConn, state := openOracleRecordingDB(t)
state.mu.Lock()
state.queryResults[`SHOW CREATE TABLE "SYS"."test"`] = oracleRecordingQueryResult{
columns: []string{"Create Table"},
rows: [][]driver.Value{
{`CREATE TABLE "SYS"."test" ("ID" NUMBER)`},
},
}
state.mu.Unlock()
oceanbaseDB := &OceanBaseDB{}
oceanbaseDB.bindConnectedDatabase(dbConn, 0, oceanBaseProtocolOracle)
ddl, err := oceanbaseDB.GetCreateStatement("SYS", "test")
if err != nil {
t.Fatalf("GetCreateStatement() unexpected error: %v", err)
}
if !strings.Contains(ddl, `CREATE TABLE "SYS"."test"`) {
t.Fatalf("expected SHOW CREATE TABLE DDL, got: %s", ddl)
}
queries := state.snapshotQueries()
if len(queries) < 3 {
t.Fatalf("expected DBMS_METADATA attempts followed by SHOW CREATE TABLE, got: %v", queries)
}
if queries[0] != `SELECT DBMS_METADATA.GET_DDL('TABLE', 'test', 'SYS') as ddl FROM DUAL` {
t.Fatalf("expected original-case DBMS_METADATA first, got: %v", queries)
}
if !slices.Contains(queries, `SHOW CREATE TABLE "SYS"."test"`) {
t.Fatalf("expected SHOW CREATE TABLE fallback, got: %v", queries)
}
}
// 用户通过 ConnectionParams 设置 connectionAttributes 时OceanBase MySQL wire 路径必须把
// 这些 attribute 透传到 go-sql-driver/mysql DSN让 driver 在握手响应里发 CLIENT_CONNECT_ATTRS。
// 这是 OBClient 协议握手探索的入口:高级用户/DBA 可以试错不同 attribute 组合而不需要改 GoNavi 代码。

View File

@@ -25,19 +25,22 @@ var (
)
type oracleRecordingState struct {
mu sync.Mutex
execQueries []string
execArgs [][]driver.NamedValue
queries []string
beginCalls int
rowsAffected int64
queryResults map[string]oracleRecordingQueryResult
queryError error
mu sync.Mutex
execQueries []string
execArgs [][]driver.NamedValue
queries []string
beginCalls int
rowsAffected int64
queryResults map[string]oracleRecordingQueryResult
queryError error
disableDefaultTabColumns bool
}
type oracleRecordingQueryResult struct {
columns []string
rows [][]driver.Value
columns []string
columnTypes []string
nullable []bool
rows [][]driver.Value
}
func (s *oracleRecordingState) snapshotExecQueries() []string {
@@ -109,6 +112,7 @@ func (c *oracleRecordingConn) ExecContext(_ context.Context, query string, args
func (c *oracleRecordingConn) QueryContext(_ context.Context, query string, _ []driver.NamedValue) (driver.Rows, error) {
c.state.mu.Lock()
c.state.queries = append(c.state.queries, query)
disableDefaultTabColumns := c.state.disableDefaultTabColumns
if err := c.state.queryError; err != nil {
c.state.mu.Unlock()
return nil, err
@@ -116,13 +120,15 @@ func (c *oracleRecordingConn) QueryContext(_ context.Context, query string, _ []
if result, ok := c.state.queryResults[query]; ok {
c.state.mu.Unlock()
return &oracleRecordingRows{
columns: append([]string(nil), result.columns...),
rows: cloneOracleRecordingRows(result.rows),
columns: append([]string(nil), result.columns...),
columnTypes: append([]string(nil), result.columnTypes...),
nullable: append([]bool(nil), result.nullable...),
rows: cloneOracleRecordingRows(result.rows),
}, nil
}
c.state.mu.Unlock()
if strings.Contains(strings.ToLower(query), "tab_columns") {
if strings.Contains(strings.ToLower(query), "tab_columns") && !disableDefaultTabColumns {
return &oracleRecordingRows{
columns: []string{"COLUMN_NAME", "DATA_TYPE", "NULLABLE", "DATA_DEFAULT", "COLUMN_KEY", "COMMENT"},
rows: [][]driver.Value{
@@ -151,9 +157,11 @@ func (oracleRecordingTx) Commit() error { return nil }
func (oracleRecordingTx) Rollback() error { return nil }
type oracleRecordingRows struct {
columns []string
rows [][]driver.Value
index int
columns []string
columnTypes []string
nullable []bool
rows [][]driver.Value
index int
}
func (r *oracleRecordingRows) Columns() []string {
@@ -162,6 +170,20 @@ func (r *oracleRecordingRows) Columns() []string {
func (r *oracleRecordingRows) Close() error { return nil }
func (r *oracleRecordingRows) ColumnTypeDatabaseTypeName(index int) string {
if index < 0 || index >= len(r.columnTypes) {
return ""
}
return r.columnTypes[index]
}
func (r *oracleRecordingRows) ColumnTypeNullable(index int) (nullable, ok bool) {
if index < 0 || index >= len(r.nullable) {
return false, false
}
return r.nullable[index], true
}
func (r *oracleRecordingRows) Next(dest []driver.Value) error {
if r.index >= len(r.rows) {
return io.EOF

View File

@@ -3,6 +3,7 @@ package db
import (
"database/sql/driver"
"reflect"
"slices"
"strings"
"testing"
)
@@ -111,6 +112,59 @@ func TestOracleGetColumnsIncludesColumnComments(t *testing.T) {
}
}
func TestOracleGetColumnsPreservesMetadataNameCaseBeforeUppercaseFallback(t *testing.T) {
t.Parallel()
dbConn, state := openOracleRecordingDB(t)
oracleDB := &OracleDB{conn: dbConn}
if _, err := oracleDB.GetColumns("SYS", "test"); err != nil {
t.Fatalf("GetColumns 返回错误: %v", err)
}
queries := state.snapshotQueries()
if len(queries) == 0 {
t.Fatalf("expected metadata query")
}
if !strings.Contains(queries[0], `WHERE c.owner = 'SYS' AND c.table_name = 'test'`) {
t.Fatalf("expected first metadata query to preserve table case, got: %s", queries[0])
}
}
func TestOracleGetColumnsFallsBackToSelectMetadataWhenDictionaryIsEmpty(t *testing.T) {
t.Parallel()
dbConn, state := openOracleRecordingDB(t)
state.mu.Lock()
state.disableDefaultTabColumns = true
state.queryResults[`SELECT * FROM "SYS"."test" WHERE 1 = 0`] = oracleRecordingQueryResult{
columns: []string{"id", "new_col_1"},
columnTypes: []string{"NUMBER", "VARCHAR2"},
nullable: []bool{false, true},
rows: [][]driver.Value{},
}
state.mu.Unlock()
oracleDB := &OracleDB{conn: dbConn}
columns, err := oracleDB.GetColumns("SYS", "test")
if err != nil {
t.Fatalf("GetColumns 返回错误: %v", err)
}
if len(columns) != 2 {
t.Fatalf("expected fallback columns, got %#v", columns)
}
if columns[0].Name != "id" || columns[0].Type != "NUMBER" || columns[0].Nullable != "NO" {
t.Fatalf("unexpected first fallback column: %#v", columns[0])
}
if columns[1].Name != "new_col_1" || columns[1].Type != "VARCHAR2" || columns[1].Nullable != "YES" {
t.Fatalf("unexpected second fallback column: %#v", columns[1])
}
queries := state.snapshotQueries()
if !slices.Contains(queries, `SELECT * FROM "SYS"."test" WHERE 1 = 0`) {
t.Fatalf("expected SELECT metadata fallback query, got: %v", queries)
}
}
func TestFormatOracleColumnTypeIncludesLengthAndPrecision(t *testing.T) {
t.Parallel()
@@ -157,6 +211,66 @@ func TestFormatOracleColumnTypeIncludesLengthAndPrecision(t *testing.T) {
}
}
func TestOracleGetCreateStatementPreservesMetadataNameCase(t *testing.T) {
t.Parallel()
dbConn, state := openOracleRecordingDB(t)
state.mu.Lock()
state.queryResults[`SELECT DBMS_METADATA.GET_DDL('TABLE', 'test', 'SYS') as ddl FROM DUAL`] = oracleRecordingQueryResult{
columns: []string{"DDL"},
rows: [][]driver.Value{
{`CREATE TABLE "SYS"."test" ("ID" NUMBER)`},
},
}
state.mu.Unlock()
oracleDB := &OracleDB{conn: dbConn}
ddl, err := oracleDB.GetCreateStatement("SYS", "test")
if err != nil {
t.Fatalf("GetCreateStatement 返回错误: %v", err)
}
if !strings.Contains(ddl, `CREATE TABLE "SYS"."test"`) {
t.Fatalf("expected lowercase metadata DDL, got: %s", ddl)
}
queries := state.snapshotQueries()
if len(queries) == 0 || queries[0] != `SELECT DBMS_METADATA.GET_DDL('TABLE', 'test', 'SYS') as ddl FROM DUAL` {
t.Fatalf("expected first DDL query to preserve case, got: %v", queries)
}
}
func TestOracleGetCreateStatementFallsBackToUppercaseMetadataName(t *testing.T) {
t.Parallel()
dbConn, state := openOracleRecordingDB(t)
state.mu.Lock()
state.queryResults[`SELECT DBMS_METADATA.GET_DDL('TABLE', 'TEST', 'SYS') as ddl FROM DUAL`] = oracleRecordingQueryResult{
columns: []string{"DDL"},
rows: [][]driver.Value{
{`CREATE TABLE "SYS"."TEST" ("ID" NUMBER)`},
},
}
state.mu.Unlock()
oracleDB := &OracleDB{conn: dbConn}
ddl, err := oracleDB.GetCreateStatement("SYS", "test")
if err != nil {
t.Fatalf("GetCreateStatement 返回错误: %v", err)
}
if !strings.Contains(ddl, `CREATE TABLE "SYS"."TEST"`) {
t.Fatalf("expected uppercase fallback DDL, got: %s", ddl)
}
queries := state.snapshotQueries()
if len(queries) < 2 {
t.Fatalf("expected original-case query followed by uppercase fallback, got: %v", queries)
}
if queries[0] != `SELECT DBMS_METADATA.GET_DDL('TABLE', 'test', 'SYS') as ddl FROM DUAL` ||
queries[1] != `SELECT DBMS_METADATA.GET_DDL('TABLE', 'TEST', 'SYS') as ddl FROM DUAL` {
t.Fatalf("unexpected DDL fallback query order: %v", queries)
}
}
func TestOracleGetCreateStatementAppendsTableAndColumnComments(t *testing.T) {
t.Parallel()

View File

@@ -22,6 +22,7 @@ type OracleDB struct {
conn *sql.DB
pingTimeout time.Duration
forwarder *ssh.LocalForwarder // Store SSH tunnel forwarder
scanDialect string
}
var _ SessionExecerProvider = (*OracleDB)(nil)
@@ -216,7 +217,7 @@ func (o *OracleDB) QueryContext(ctx context.Context, query string) ([]map[string
}
defer rows.Close()
return scanRows(rows)
return scanRowsForDialect(rows, o.scanDialect)
}
func (o *OracleDB) Query(query string) ([]map[string]interface{}, []string, error) {
@@ -229,7 +230,7 @@ func (o *OracleDB) Query(query string) ([]map[string]interface{}, []string, erro
return nil, nil, err
}
defer rows.Close()
return scanRows(rows)
return scanRowsForDialect(rows, o.scanDialect)
}
func (o *OracleDB) ExecContext(ctx context.Context, query string) (int64, error) {
@@ -262,7 +263,7 @@ func (o *OracleDB) OpenTransactionExecer(ctx context.Context) (TransactionExecer
if err != nil {
return nil, err
}
return NewSQLConnTransactionExecer(conn, "COMMIT", "ROLLBACK"), nil
return NewSQLConnTransactionExecerWithDialect(conn, "COMMIT", "ROLLBACK", o.scanDialect), nil
}
func (o *OracleDB) OpenSessionExecer(ctx context.Context) (StatementExecer, error) {
@@ -273,7 +274,7 @@ func (o *OracleDB) OpenSessionExecer(ctx context.Context) (StatementExecer, erro
if err != nil {
return nil, err
}
return NewSQLConnStatementExecer(conn), nil
return NewSQLConnStatementExecerWithDialect(conn, o.scanDialect), nil
}
func (o *OracleDB) GetDatabases() ([]string, error) {
@@ -325,89 +326,153 @@ func (o *OracleDB) GetTables(dbName string) ([]string, error) {
func (o *OracleDB) GetCreateStatement(dbName, tableName string) (string, error) {
// Oracle provides DBMS_METADATA.GET_DDL
// Note: LONG type might be tricky, but basic string scan should work for smaller DDLs
metadataTableName := escapeOracleMetadataLiteral(tableName)
metadataSchemaName := escapeOracleMetadataLiteral(dbName)
query := fmt.Sprintf("SELECT DBMS_METADATA.GET_DDL('TABLE', '%s', '%s') as ddl FROM DUAL",
metadataTableName, metadataSchemaName)
var firstErr error
for _, candidate := range oracleMetadataNamePairs(dbName, tableName) {
metadataTableName := escapeOracleMetadataLiteralExact(candidate.table)
metadataSchemaName := escapeOracleMetadataLiteralExact(candidate.schema)
query := fmt.Sprintf("SELECT DBMS_METADATA.GET_DDL('TABLE', '%s', '%s') as ddl FROM DUAL",
metadataTableName, metadataSchemaName)
if dbName == "" {
query = fmt.Sprintf("SELECT DBMS_METADATA.GET_DDL('TABLE', '%s') as ddl FROM DUAL", metadataTableName)
}
data, _, err := o.Query(query)
if err != nil {
return "", err
}
if len(data) > 0 {
if val, ok := data[0]["DDL"]; ok {
return o.appendOracleCommentDDL(fmt.Sprintf("%v", val), dbName, tableName), nil
if candidate.schema == "" {
query = fmt.Sprintf("SELECT DBMS_METADATA.GET_DDL('TABLE', '%s') as ddl FROM DUAL", metadataTableName)
}
data, _, err := o.Query(query)
if err != nil {
if firstErr == nil {
firstErr = err
}
continue
}
if len(data) > 0 {
if val, ok := data[0]["DDL"]; ok {
ddl := strings.TrimSpace(fmt.Sprintf("%v", val))
if ddl != "" {
return o.appendOracleCommentDDL(ddl, candidate.schema, candidate.table), nil
}
}
}
}
if firstErr != nil {
return "", firstErr
}
return "", fmt.Errorf("未找到建表语句")
}
func (o *OracleDB) GetColumns(dbName, tableName string) ([]connection.ColumnDefinition, error) {
metadataTableName := escapeOracleMetadataLiteral(tableName)
metadataSchemaName := escapeOracleMetadataLiteral(dbName)
query := fmt.Sprintf(`SELECT c.column_name AS "COLUMN_NAME", c.data_type AS "DATA_TYPE", c.data_length AS "DATA_LENGTH", c.char_length AS "CHAR_LENGTH", c.data_precision AS "DATA_PRECISION", c.data_scale AS "DATA_SCALE", c.nullable AS "NULLABLE", c.data_default AS "DATA_DEFAULT",
CASE WHEN pk.column_name IS NOT NULL THEN 'PRI' ELSE '' END AS "COLUMN_KEY",
cc.comments AS "COMMENT"
FROM all_tab_columns c
LEFT JOIN all_col_comments cc
ON cc.owner = c.owner AND cc.table_name = c.table_name AND cc.column_name = c.column_name
LEFT JOIN (
SELECT cols.owner, cols.table_name, cols.column_name
FROM all_constraints cons
JOIN all_cons_columns cols
ON cons.owner = cols.owner AND cons.constraint_name = cols.constraint_name
WHERE cons.constraint_type = 'P'
) pk ON c.owner = pk.owner AND c.table_name = pk.table_name AND c.column_name = pk.column_name
WHERE c.owner = '%s' AND c.table_name = '%s'
ORDER BY c.column_id`, metadataSchemaName, metadataTableName)
for _, candidate := range oracleMetadataNamePairs(dbName, tableName) {
query := buildOracleColumnsQuery(candidate.schema, candidate.table)
data, _, err := o.Query(query)
if err != nil {
return nil, err
}
if len(data) == 0 {
continue
}
if dbName == "" {
query = fmt.Sprintf(`SELECT c.column_name AS "COLUMN_NAME", c.data_type AS "DATA_TYPE", c.data_length AS "DATA_LENGTH", c.char_length AS "CHAR_LENGTH", c.data_precision AS "DATA_PRECISION", c.data_scale AS "DATA_SCALE", c.nullable AS "NULLABLE", c.data_default AS "DATA_DEFAULT",
CASE WHEN pk.column_name IS NOT NULL THEN 'PRI' ELSE '' END AS "COLUMN_KEY",
cc.comments AS "COMMENT"
FROM user_tab_columns c
LEFT JOIN user_col_comments cc
ON cc.table_name = c.table_name AND cc.column_name = c.column_name
LEFT JOIN (
SELECT cols.table_name, cols.column_name
FROM user_constraints cons
JOIN user_cons_columns cols USING (constraint_name)
WHERE cons.constraint_type = 'P'
) pk ON c.table_name = pk.table_name AND c.column_name = pk.column_name
WHERE c.table_name = '%s'
ORDER BY c.column_id`, metadataTableName)
return parseOracleColumns(data), nil
}
if columns, err := o.inferOracleColumnsFromSelect(dbName, tableName); err == nil && len(columns) > 0 {
return columns, nil
}
return []connection.ColumnDefinition{}, nil
}
func (o *OracleDB) inferOracleColumnsFromSelect(dbName string, tableName string) ([]connection.ColumnDefinition, error) {
if o.conn == nil {
return nil, fmt.Errorf("连接未打开")
}
data, _, err := o.Query(query)
var firstErr error
for _, candidate := range oracleMetadataNamePairs(dbName, tableName) {
query := "SELECT * FROM " + quoteOracleTableRef(candidate.schema, candidate.table) + " WHERE 1 = 0"
rows, err := o.conn.Query(query)
if err != nil {
if firstErr == nil {
firstErr = err
}
continue
}
columns, parseErr := oracleColumnsFromSQLRows(rows)
closeErr := rows.Close()
if parseErr != nil {
if firstErr == nil {
firstErr = parseErr
}
continue
}
if closeErr != nil {
if firstErr == nil {
firstErr = closeErr
}
continue
}
if len(columns) > 0 {
return columns, nil
}
}
if firstErr != nil {
return nil, firstErr
}
return nil, fmt.Errorf("未获取到字段定义")
}
func oracleColumnsFromSQLRows(rows *sql.Rows) ([]connection.ColumnDefinition, error) {
names, err := rows.Columns()
if err != nil {
return nil, err
}
colTypes, err := rows.ColumnTypes()
if err != nil || len(colTypes) != len(names) {
colTypes = nil
}
var columns []connection.ColumnDefinition
for _, row := range data {
columns := make([]connection.ColumnDefinition, 0, len(names))
for idx, name := range names {
col := connection.ColumnDefinition{
Name: oracleRowString(row, "COLUMN_NAME"),
Type: formatOracleColumnType(row),
Nullable: oracleRowString(row, "NULLABLE"),
Key: oracleRowString(row, "COLUMN_KEY"),
Comment: oracleRowString(row, "COMMENT"),
Name: strings.TrimSpace(name),
Nullable: "",
Key: "",
Extra: "",
Comment: "",
}
if defaultValue := oracleRowValue(row, "DATA_DEFAULT"); defaultValue != nil {
d := fmt.Sprintf("%v", defaultValue)
col.Default = &d
if colTypes != nil && idx < len(colTypes) && colTypes[idx] != nil {
col.Type = formatOracleSQLColumnType(colTypes[idx])
if nullable, ok := colTypes[idx].Nullable(); ok {
if nullable {
col.Nullable = "YES"
} else {
col.Nullable = "NO"
}
}
}
columns = append(columns, col)
}
return columns, nil
}
func formatOracleSQLColumnType(colType *sql.ColumnType) string {
if colType == nil {
return ""
}
typeName := strings.TrimSpace(colType.DatabaseTypeName())
if typeName == "" {
return ""
}
upperType := strings.ToUpper(typeName)
if length, ok := colType.Length(); ok && length > 0 && strings.Contains(upperType, "CHAR") {
return fmt.Sprintf("%s(%d)", typeName, length)
}
if precision, scale, ok := colType.DecimalSize(); ok && precision > 0 && (strings.Contains(upperType, "NUMBER") || strings.Contains(upperType, "DECIMAL") || strings.Contains(upperType, "NUMERIC")) {
if scale > 0 {
return fmt.Sprintf("%s(%d,%d)", typeName, precision, scale)
}
return fmt.Sprintf("%s(%d)", typeName, precision)
}
return typeName
}
func oracleRowValue(row map[string]interface{}, names ...string) interface{} {
for _, name := range names {
if value, ok := row[name]; ok {
@@ -482,12 +547,12 @@ func formatOracleColumnType(row map[string]interface{}) string {
}
func (o *OracleDB) appendOracleCommentDDL(baseDDL string, dbName string, tableName string) string {
table := strings.ToUpper(strings.TrimSpace(tableName))
table := strings.TrimSpace(tableName)
if strings.TrimSpace(baseDDL) == "" || table == "" {
return baseDDL
}
schema := strings.ToUpper(strings.TrimSpace(dbName))
schema := strings.TrimSpace(dbName)
tableRef := quoteOracleDDLIdentifier(table)
if schema != "" {
tableRef = quoteOracleDDLIdentifier(schema) + "." + tableRef
@@ -523,10 +588,10 @@ func (o *OracleDB) appendOracleCommentDDL(baseDDL string, dbName string, tableNa
}
func (o *OracleDB) fetchOracleTableComment(schema string, table string) string {
escapedTable := escapeOracleMetadataLiteral(table)
escapedTable := escapeOracleMetadataLiteralExact(table)
var query string
if strings.TrimSpace(schema) != "" {
query = fmt.Sprintf(`SELECT comments AS "COMMENT" FROM all_tab_comments WHERE owner = '%s' AND table_name = '%s' AND comments IS NOT NULL`, escapeOracleMetadataLiteral(schema), escapedTable)
query = fmt.Sprintf(`SELECT comments AS "COMMENT" FROM all_tab_comments WHERE owner = '%s' AND table_name = '%s' AND comments IS NOT NULL`, escapeOracleMetadataLiteralExact(schema), escapedTable)
} else {
query = fmt.Sprintf(`SELECT comments AS "COMMENT" FROM user_tab_comments WHERE table_name = '%s' AND comments IS NOT NULL`, escapedTable)
}
@@ -543,7 +608,7 @@ type oracleColumnComment struct {
}
func (o *OracleDB) fetchOracleColumnComments(schema string, table string) []oracleColumnComment {
escapedTable := escapeOracleMetadataLiteral(table)
escapedTable := escapeOracleMetadataLiteralExact(table)
var query string
if strings.TrimSpace(schema) != "" {
query = fmt.Sprintf(`SELECT c.column_name AS "COLUMN_NAME", cc.comments AS "COMMENT"
@@ -551,7 +616,7 @@ FROM all_tab_columns c
JOIN all_col_comments cc
ON cc.owner = c.owner AND cc.table_name = c.table_name AND cc.column_name = c.column_name
WHERE c.owner = '%s' AND c.table_name = '%s' AND cc.comments IS NOT NULL
ORDER BY c.column_id`, escapeOracleMetadataLiteral(schema), escapedTable)
ORDER BY c.column_id`, escapeOracleMetadataLiteralExact(schema), escapedTable)
} else {
query = fmt.Sprintf(`SELECT c.column_name AS "COLUMN_NAME", cc.comments AS "COMMENT"
FROM user_tab_columns c
@@ -579,6 +644,14 @@ func quoteOracleDDLIdentifier(ident string) string {
return `"` + strings.ReplaceAll(strings.TrimSpace(ident), `"`, `""`) + `"`
}
func quoteOracleTableRef(schema string, table string) string {
tableRef := quoteOracleDDLIdentifier(table)
if strings.TrimSpace(schema) != "" {
return quoteOracleDDLIdentifier(schema) + "." + tableRef
}
return tableRef
}
func escapeOracleCommentLiteral(text string) string {
return strings.ReplaceAll(text, "'", "''")
}
@@ -587,14 +660,132 @@ func escapeOracleMetadataLiteral(text string) string {
return strings.ReplaceAll(strings.ToUpper(strings.TrimSpace(text)), "'", "''")
}
func escapeOracleMetadataLiteralExact(text string) string {
return strings.ReplaceAll(strings.TrimSpace(text), "'", "''")
}
type oracleMetadataNamePair struct {
schema string
table string
}
func oracleMetadataNamePairs(dbName string, tableName string) []oracleMetadataNamePair {
rawSchema := strings.TrimSpace(dbName)
rawTable := strings.TrimSpace(tableName)
if rawTable == "" {
return nil
}
upperSchema := strings.ToUpper(rawSchema)
upperTable := strings.ToUpper(rawTable)
pairs := make([]oracleMetadataNamePair, 0, 4)
seen := map[string]struct{}{}
add := func(schema string, table string) {
key := schema + "\x00" + table
if _, ok := seen[key]; ok {
return
}
seen[key] = struct{}{}
pairs = append(pairs, oracleMetadataNamePair{schema: schema, table: table})
}
add(rawSchema, rawTable)
add(upperSchema, upperTable)
add(rawSchema, upperTable)
add(upperSchema, rawTable)
return pairs
}
func buildOracleColumnsQuery(schema string, table string) string {
metadataTableName := escapeOracleMetadataLiteralExact(table)
metadataSchemaName := escapeOracleMetadataLiteralExact(schema)
if strings.TrimSpace(schema) == "" {
return fmt.Sprintf(`SELECT c.column_name AS "COLUMN_NAME", c.data_type AS "DATA_TYPE", c.data_length AS "DATA_LENGTH", c.char_length AS "CHAR_LENGTH", c.data_precision AS "DATA_PRECISION", c.data_scale AS "DATA_SCALE", c.nullable AS "NULLABLE", c.data_default AS "DATA_DEFAULT",
CASE WHEN pk.column_name IS NOT NULL THEN 'PRI' ELSE '' END AS "COLUMN_KEY",
cc.comments AS "COMMENT"
FROM user_tab_columns c
LEFT JOIN user_col_comments cc
ON cc.table_name = c.table_name AND cc.column_name = c.column_name
LEFT JOIN (
SELECT cols.table_name, cols.column_name
FROM user_constraints cons
JOIN user_cons_columns cols USING (constraint_name)
WHERE cons.constraint_type = 'P'
) pk ON c.table_name = pk.table_name AND c.column_name = pk.column_name
WHERE c.table_name = '%s'
ORDER BY c.column_id`, metadataTableName)
}
return fmt.Sprintf(`SELECT c.column_name AS "COLUMN_NAME", c.data_type AS "DATA_TYPE", c.data_length AS "DATA_LENGTH", c.char_length AS "CHAR_LENGTH", c.data_precision AS "DATA_PRECISION", c.data_scale AS "DATA_SCALE", c.nullable AS "NULLABLE", c.data_default AS "DATA_DEFAULT",
CASE WHEN pk.column_name IS NOT NULL THEN 'PRI' ELSE '' END AS "COLUMN_KEY",
cc.comments AS "COMMENT"
FROM all_tab_columns c
LEFT JOIN all_col_comments cc
ON cc.owner = c.owner AND cc.table_name = c.table_name AND cc.column_name = c.column_name
LEFT JOIN (
SELECT cols.owner, cols.table_name, cols.column_name
FROM all_constraints cons
JOIN all_cons_columns cols
ON cons.owner = cols.owner AND cons.constraint_name = cols.constraint_name
WHERE cons.constraint_type = 'P'
) pk ON c.owner = pk.owner AND c.table_name = pk.table_name AND c.column_name = pk.column_name
WHERE c.owner = '%s' AND c.table_name = '%s'
ORDER BY c.column_id`, metadataSchemaName, metadataTableName)
}
func parseOracleColumns(data []map[string]interface{}) []connection.ColumnDefinition {
columns := make([]connection.ColumnDefinition, 0, len(data))
for _, row := range data {
col := connection.ColumnDefinition{
Name: oracleRowString(row, "COLUMN_NAME"),
Type: formatOracleColumnType(row),
Nullable: oracleRowString(row, "NULLABLE"),
Key: oracleRowString(row, "COLUMN_KEY"),
Comment: oracleRowString(row, "COMMENT"),
}
if defaultValue := oracleRowValue(row, "DATA_DEFAULT"); defaultValue != nil {
d := fmt.Sprintf("%v", defaultValue)
col.Default = &d
}
columns = append(columns, col)
}
return columns
}
func (o *OracleDB) GetIndexes(dbName, tableName string) ([]connection.IndexDefinition, error) {
esc := func(s string) string { return strings.ReplaceAll(strings.ToUpper(strings.TrimSpace(s)), "'", "''") }
table := esc(tableName)
if table == "" {
if strings.TrimSpace(tableName) == "" {
return nil, fmt.Errorf("表名不能为空")
}
query := fmt.Sprintf(`SELECT c.index_name, c.column_name, i.uniqueness, c.column_position, i.index_type
for _, candidate := range oracleMetadataNamePairs(dbName, tableName) {
data, _, err := o.Query(buildOracleIndexesQuery(candidate.schema, candidate.table))
if err != nil {
return nil, err
}
if len(data) == 0 {
continue
}
return parseOracleIndexes(data), nil
}
return []connection.IndexDefinition{}, nil
}
func buildOracleIndexesQuery(schema string, table string) string {
metadataTableName := escapeOracleMetadataLiteralExact(table)
metadataSchemaName := escapeOracleMetadataLiteralExact(schema)
if strings.TrimSpace(schema) == "" {
return fmt.Sprintf(`SELECT c.index_name, c.column_name, i.uniqueness, c.column_position, i.index_type
FROM user_ind_columns c
JOIN user_indexes i ON i.index_name = c.index_name
WHERE c.table_name = '%s'
AND c.column_name IS NOT NULL
AND c.column_name NOT LIKE 'SYS_NC%%$'
AND i.index_type NOT LIKE 'FUNCTION-BASED%%'
ORDER BY c.index_name, c.column_position`, metadataTableName)
}
return fmt.Sprintf(`SELECT c.index_name, c.column_name, i.uniqueness, c.column_position, i.index_type
FROM all_ind_columns c
JOIN all_indexes i ON i.owner = c.index_owner AND i.index_name = c.index_name
WHERE c.table_owner = '%s'
@@ -602,24 +793,10 @@ func (o *OracleDB) GetIndexes(dbName, tableName string) ([]connection.IndexDefin
AND c.column_name IS NOT NULL
AND c.column_name NOT LIKE 'SYS_NC%%$'
AND i.index_type NOT LIKE 'FUNCTION-BASED%%'
ORDER BY c.index_name, c.column_position`, esc(dbName), table)
if strings.TrimSpace(dbName) == "" {
query = fmt.Sprintf(`SELECT c.index_name, c.column_name, i.uniqueness, c.column_position, i.index_type
FROM user_ind_columns c
JOIN user_indexes i ON i.index_name = c.index_name
WHERE c.table_name = '%s'
AND c.column_name IS NOT NULL
AND c.column_name NOT LIKE 'SYS_NC%%$'
AND i.index_type NOT LIKE 'FUNCTION-BASED%%'
ORDER BY c.index_name, c.column_position`, table)
}
data, _, err := o.Query(query)
if err != nil {
return nil, err
}
ORDER BY c.index_name, c.column_position`, metadataSchemaName, metadataTableName)
}
func parseOracleIndexes(data []map[string]interface{}) []connection.IndexDefinition {
getValue := func(row map[string]interface{}, names ...string) interface{} {
for _, name := range names {
if value, ok := row[name]; ok {
@@ -663,24 +840,44 @@ func (o *OracleDB) GetIndexes(dbName, tableName string) ([]connection.IndexDefin
}
indexes = append(indexes, idx)
}
return indexes, nil
return indexes
}
func (o *OracleDB) GetForeignKeys(dbName, tableName string) ([]connection.ForeignKeyDefinition, error) {
// Simplified query for FKs
query := fmt.Sprintf(`SELECT a.constraint_name, a.column_name, c_pk.table_name r_table_name, b.column_name r_column_name
for _, candidate := range oracleMetadataNamePairs(dbName, tableName) {
data, _, err := o.Query(buildOracleForeignKeysQuery(candidate.schema, candidate.table))
if err != nil {
return nil, err
}
if len(data) == 0 {
continue
}
return parseOracleForeignKeys(data), nil
}
return []connection.ForeignKeyDefinition{}, nil
}
func buildOracleForeignKeysQuery(schema string, table string) string {
metadataTableName := escapeOracleMetadataLiteralExact(table)
metadataSchemaName := escapeOracleMetadataLiteralExact(schema)
if strings.TrimSpace(schema) == "" {
return fmt.Sprintf(`SELECT a.constraint_name, a.column_name, c_pk.table_name r_table_name, b.column_name r_column_name
FROM user_cons_columns a
JOIN user_constraints c ON a.constraint_name = c.constraint_name
JOIN user_constraints c_pk ON c.r_constraint_name = c_pk.constraint_name
JOIN user_cons_columns b ON c_pk.constraint_name = b.constraint_name AND a.position = b.position
WHERE c.constraint_type = 'R' AND a.table_name = '%s'`, metadataTableName)
}
return fmt.Sprintf(`SELECT a.constraint_name, a.column_name, c_pk.table_name r_table_name, b.column_name r_column_name
FROM all_cons_columns a
JOIN all_constraints c ON a.owner = c.owner AND a.constraint_name = c.constraint_name
JOIN all_constraints c_pk ON c.r_owner = c_pk.owner AND c.r_constraint_name = c_pk.constraint_name
JOIN all_cons_columns b ON c_pk.owner = b.owner AND c_pk.constraint_name = b.constraint_name AND a.position = b.position
WHERE c.constraint_type = 'R' AND a.owner = '%s' AND a.table_name = '%s'`,
strings.ToUpper(dbName), strings.ToUpper(tableName))
data, _, err := o.Query(query)
if err != nil {
return nil, err
}
metadataSchemaName, metadataTableName)
}
func parseOracleForeignKeys(data []map[string]interface{}) []connection.ForeignKeyDefinition {
var fks []connection.ForeignKeyDefinition
for _, row := range data {
fk := connection.ForeignKeyDefinition{
@@ -692,20 +889,38 @@ func (o *OracleDB) GetForeignKeys(dbName, tableName string) ([]connection.Foreig
}
fks = append(fks, fk)
}
return fks, nil
return fks
}
func (o *OracleDB) GetTriggers(dbName, tableName string) ([]connection.TriggerDefinition, error) {
query := fmt.Sprintf(`SELECT trigger_name, trigger_type, triggering_event
FROM all_triggers
WHERE table_owner = '%s' AND table_name = '%s'`,
strings.ToUpper(dbName), strings.ToUpper(tableName))
data, _, err := o.Query(query)
if err != nil {
return nil, err
for _, candidate := range oracleMetadataNamePairs(dbName, tableName) {
data, _, err := o.Query(buildOracleTriggersQuery(candidate.schema, candidate.table))
if err != nil {
return nil, err
}
if len(data) == 0 {
continue
}
return parseOracleTriggers(data), nil
}
return []connection.TriggerDefinition{}, nil
}
func buildOracleTriggersQuery(schema string, table string) string {
metadataTableName := escapeOracleMetadataLiteralExact(table)
metadataSchemaName := escapeOracleMetadataLiteralExact(schema)
if strings.TrimSpace(schema) == "" {
return fmt.Sprintf(`SELECT trigger_name, trigger_type, triggering_event
FROM user_triggers
WHERE table_name = '%s'`, metadataTableName)
}
return fmt.Sprintf(`SELECT trigger_name, trigger_type, triggering_event
FROM all_triggers
WHERE table_owner = '%s' AND table_name = '%s'`,
metadataSchemaName, metadataTableName)
}
func parseOracleTriggers(data []map[string]interface{}) []connection.TriggerDefinition {
var triggers []connection.TriggerDefinition
for _, row := range data {
trig := connection.TriggerDefinition{
@@ -716,7 +931,7 @@ func (o *OracleDB) GetTriggers(dbName, tableName string) ([]connection.TriggerDe
}
triggers = append(triggers, trig)
}
return triggers, nil
return triggers
}
func splitOracleQualifiedTableName(raw string) (string, string) {

View File

@@ -1,6 +1,7 @@
package db
import (
"encoding/binary"
"encoding/hex"
"encoding/json"
"fmt"
@@ -16,9 +17,10 @@ import (
)
const (
jsMaxSafeInteger int64 = 9007199254740991
jsMinSafeInteger int64 = -9007199254740991
jsMaxSafeUint uint64 = 9007199254740991
jsMaxSafeInteger int64 = 9007199254740991
jsMinSafeInteger int64 = -9007199254740991
jsMaxSafeUint uint64 = 9007199254740991
oceanBaseOracleScanDialect = "oceanbase-oracle"
)
var (
@@ -40,7 +42,15 @@ func normalizeQueryValueWithDBTypeAndDialect(v interface{}, databaseTypeName, di
if tm, ok := v.(time.Time); ok {
return normalizeTemporalValueForDisplay(tm, databaseTypeName, dialect)
}
if s, ok := v.(string); ok {
if tm, normalizedType, ok := decodeOceanBaseOracleTemporalString(s, databaseTypeName, dialect); ok {
return normalizeTemporalValueForDisplay(tm, normalizedType, dialect)
}
}
if b, ok := v.([]byte); ok {
if tm, normalizedType, ok := decodeOceanBaseOracleTemporalBytes(b, databaseTypeName, dialect); ok {
return normalizeTemporalValueForDisplay(tm, normalizedType, dialect)
}
return bytesToDisplayValue(b, databaseTypeName)
}
return normalizeCompositeQueryValue(v)
@@ -52,15 +62,19 @@ func normalizeTemporalValueForDisplay(value time.Time, databaseTypeName, dialect
return zeroValue
}
}
if shouldDisplayTemporalValueAsDateOnly(databaseTypeName, dialect) {
if shouldDisplayTemporalValueAsDateOnly(databaseTypeName, dialect) || shouldDisplayOceanBaseOracleDateAsDateOnly(value, databaseTypeName, dialect) {
return value.Format("2006-01-02")
}
return value.Format(time.RFC3339Nano)
}
func shouldDisplayTemporalValueAsDateOnly(databaseTypeName, dialect string) bool {
func isDateOnlyDatabaseTypeName(databaseTypeName string) bool {
typeName := strings.ToUpper(strings.TrimSpace(databaseTypeName))
if typeName != "DATE" && typeName != "NEWDATE" {
return typeName == "DATE" || typeName == "NEWDATE"
}
func shouldDisplayTemporalValueAsDateOnly(databaseTypeName, dialect string) bool {
if !isDateOnlyDatabaseTypeName(databaseTypeName) {
return false
}
switch strings.ToLower(strings.TrimSpace(dialect)) {
@@ -71,6 +85,317 @@ func shouldDisplayTemporalValueAsDateOnly(databaseTypeName, dialect string) bool
}
}
func shouldDisplayOceanBaseOracleDateAsDateOnly(value time.Time, databaseTypeName, dialect string) bool {
if !isDateOnlyDatabaseTypeName(databaseTypeName) {
return false
}
if strings.ToLower(strings.TrimSpace(dialect)) != oceanBaseOracleScanDialect {
return false
}
return value.Hour() == 0 && value.Minute() == 0 && value.Second() == 0 && value.Nanosecond() == 0
}
func decodeOceanBaseOracleTemporalString(value string, databaseTypeName, dialect string) (time.Time, string, bool) {
return decodeOceanBaseOracleTemporalBytes([]byte(value), databaseTypeName, dialect)
}
func decodeOceanBaseOracleTemporalBytes(value []byte, databaseTypeName, dialect string) (time.Time, string, bool) {
if strings.ToLower(strings.TrimSpace(dialect)) != oceanBaseOracleScanDialect {
return time.Time{}, "", false
}
if !shouldAttemptOceanBaseOracleTemporalDecode(databaseTypeName, value) {
return time.Time{}, "", false
}
return parseOceanBaseOracleTemporal(value, databaseTypeName)
}
func shouldAttemptOceanBaseOracleTemporalDecode(databaseTypeName string, value []byte) bool {
if isOceanBaseOracleTemporalDatabaseTypeName(databaseTypeName) {
return true
}
if !hasOceanBaseOracleTemporalEncodedLength(value) {
return false
}
typeName := strings.ToUpper(strings.TrimSpace(databaseTypeName))
if typeName == "" {
return true
}
if isLikelyOceanBaseOracleTemporalCarrierType(typeName) {
return true
}
return false
}
func hasOceanBaseOracleTemporalEncodedLength(value []byte) bool {
switch len(value) {
case 5, 7, 8, 11, 12, 13:
return true
default:
return false
}
}
func isLikelyOceanBaseOracleTemporalCarrierType(typeName string) bool {
if typeName == "" {
return false
}
switch {
case strings.Contains(typeName, "CHAR"),
strings.Contains(typeName, "TEXT"),
strings.Contains(typeName, "STRING"),
strings.Contains(typeName, "BINARY"),
strings.Contains(typeName, "VARBINARY"),
strings.Contains(typeName, "RAW"),
strings.Contains(typeName, "BLOB"),
strings.Contains(typeName, "LOB"):
return true
default:
return false
}
}
func isOceanBaseOracleTemporalDatabaseTypeName(databaseTypeName string) bool {
typeName := strings.ToUpper(strings.TrimSpace(databaseTypeName))
if typeName == "DATE" || typeName == "TYPE_CA" {
return true
}
return strings.HasPrefix(typeName, "TIMESTAMP")
}
func parseOceanBaseOracleTemporal(value []byte, databaseTypeName string) (time.Time, string, bool) {
if tm, normalizedType, ok := parseOracleBinaryTemporal(value, databaseTypeName); ok {
return tm, normalizedType, true
}
if tm, normalizedType, ok := parseOceanBaseOracleTypeCATemporal(value, databaseTypeName); ok {
return tm, normalizedType, true
}
return parseMySQLLengthEncodedTemporal(value, databaseTypeName)
}
func parseOceanBaseOracleTypeCATemporal(value []byte, databaseTypeName string) (time.Time, string, bool) {
if len(value) != 12 {
return time.Time{}, "", false
}
yearHigh := int(value[0])
yearLow := int(value[1])
month := int(value[2])
day := int(value[3])
hour := int(value[4])
minute := int(value[5])
second := int(value[6])
nsec := int(binary.LittleEndian.Uint32(value[7:11]))
scale := int(value[11])
if yearHigh < 0 || yearHigh > 99 || yearLow < 0 || yearLow > 99 {
return time.Time{}, "", false
}
if month < 1 || month > 12 || day < 1 || day > 31 || hour < 0 || hour > 23 || minute < 0 || minute > 59 || second < 0 || second > 59 {
return time.Time{}, "", false
}
if scale < 0 || scale > 9 || nsec < 0 || nsec >= 1_000_000_000 {
return time.Time{}, "", false
}
if !matchesTemporalScale(nsec, scale) {
return time.Time{}, "", false
}
year := yearHigh*100 + yearLow
parsed := time.Date(year, time.Month(month), day, hour, minute, second, nsec, time.UTC)
if parsed.Year() != year || int(parsed.Month()) != month || parsed.Day() != day ||
parsed.Hour() != hour || parsed.Minute() != minute || parsed.Second() != second || parsed.Nanosecond() != nsec {
return time.Time{}, "", false
}
return parsed, normalizeOracleTemporalDatabaseTypeName(databaseTypeName), true
}
func matchesTemporalScale(nsec, scale int) bool {
if scale >= 9 {
return true
}
step := 1
for i := 0; i < 9-scale; i++ {
step *= 10
}
return nsec%step == 0
}
func parseOracleBinaryTemporal(value []byte, databaseTypeName string) (time.Time, string, bool) {
switch len(value) {
case 7:
tm, ok := parseOracleBinaryDateTime(value[:7])
return tm, "DATE", ok
case 11:
tm, ok := parseOracleBinaryTimestamp(value)
return tm, normalizeOracleTemporalDatabaseTypeName(databaseTypeName), ok
case 13:
tm, ok := parseOracleBinaryTimestampWithTimezone(value)
return tm, normalizeOracleTemporalDatabaseTypeName(databaseTypeName), ok
default:
return time.Time{}, "", false
}
}
func parseMySQLLengthEncodedTemporal(value []byte, databaseTypeName string) (time.Time, string, bool) {
if len(value) == 0 {
return time.Time{}, "", false
}
payloadLength := int(value[0])
if payloadLength != len(value)-1 {
return time.Time{}, "", false
}
switch payloadLength {
case 4:
tm, ok := parseMySQLBinaryDateTimePayload(value[1:], false)
return tm, "DATE", ok
case 7:
tm, ok := parseMySQLBinaryDateTimePayload(value[1:], false)
return tm, normalizeOracleTemporalDatabaseTypeName(databaseTypeName), ok
case 11:
tm, ok := parseMySQLBinaryDateTimePayload(value[1:], true)
return tm, normalizeOracleTemporalDatabaseTypeName(databaseTypeName), ok
default:
return time.Time{}, "", false
}
}
func parseMySQLBinaryDateTimePayload(value []byte, withFraction bool) (time.Time, bool) {
expectedLength := 7
if !withFraction {
switch len(value) {
case 4:
year := int(binary.LittleEndian.Uint16(value[0:2]))
month := int(value[2])
day := int(value[3])
if year < 0 || month < 1 || month > 12 || day < 1 || day > 31 {
return time.Time{}, false
}
parsed := time.Date(year, time.Month(month), day, 0, 0, 0, 0, time.UTC)
if parsed.Year() != year || int(parsed.Month()) != month || parsed.Day() != day {
return time.Time{}, false
}
return parsed, true
case expectedLength:
default:
return time.Time{}, false
}
} else if len(value) != 11 {
return time.Time{}, false
}
year := int(binary.LittleEndian.Uint16(value[0:2]))
month := int(value[2])
day := int(value[3])
hour := int(value[4])
minute := int(value[5])
second := int(value[6])
nsec := 0
if withFraction {
usec := binary.LittleEndian.Uint32(value[7:11])
if usec >= 1_000_000 {
return time.Time{}, false
}
nsec = int(usec) * 1000
}
if year < 0 || month < 1 || month > 12 || day < 1 || day > 31 || hour < 0 || hour > 23 || minute < 0 || minute > 59 || second < 0 || second > 59 {
return time.Time{}, false
}
parsed := time.Date(year, time.Month(month), day, hour, minute, second, nsec, time.UTC)
if parsed.Year() != year || int(parsed.Month()) != month || parsed.Day() != day ||
parsed.Hour() != hour || parsed.Minute() != minute || parsed.Second() != second || parsed.Nanosecond() != nsec {
return time.Time{}, false
}
return parsed, true
}
func normalizeOracleTemporalDatabaseTypeName(databaseTypeName string) string {
typeName := strings.ToUpper(strings.TrimSpace(databaseTypeName))
switch typeName {
case "TYPE_CA":
return "TIMESTAMP"
default:
if typeName == "" {
return "TIMESTAMP"
}
return typeName
}
}
func parseOracleBinaryTimestamp(value []byte) (time.Time, bool) {
if len(value) != 11 {
return time.Time{}, false
}
baseTime, ok := parseOracleBinaryDateTime(value[:7])
if !ok {
return time.Time{}, false
}
nsec := binary.BigEndian.Uint32(value[7:11])
if nsec >= 1_000_000_000 {
return time.Time{}, false
}
return time.Date(
baseTime.Year(),
baseTime.Month(),
baseTime.Day(),
baseTime.Hour(),
baseTime.Minute(),
baseTime.Second(),
int(nsec),
time.UTC,
), true
}
func parseOracleBinaryTimestampWithTimezone(value []byte) (time.Time, bool) {
if len(value) != 13 {
return time.Time{}, false
}
baseTime, ok := parseOracleBinaryTimestamp(value[:11])
if !ok {
return time.Time{}, false
}
tzHour := int(value[11]) - 20
tzMinute := int(value[12]) - 60
if tzHour < -12 || tzHour > 14 || tzMinute < 0 || tzMinute >= 60 {
return time.Time{}, false
}
location := time.FixedZone("", tzHour*3600+tzMinute*60)
return time.Date(
baseTime.Year(),
baseTime.Month(),
baseTime.Day(),
baseTime.Hour(),
baseTime.Minute(),
baseTime.Second(),
baseTime.Nanosecond(),
location,
), true
}
func parseOracleBinaryDateTime(value []byte) (time.Time, bool) {
if len(value) != 7 {
return time.Time{}, false
}
year := (int(value[0]) - 100) * 100
year += int(value[1]) - 100
month := int(value[2])
day := int(value[3])
hour := int(value[4]) - 1
minute := int(value[5]) - 1
second := int(value[6]) - 1
if year < 0 || month < 1 || month > 12 || day < 1 || day > 31 || hour < 0 || hour > 23 || minute < 0 || minute > 59 || second < 0 || second > 59 {
return time.Time{}, false
}
parsed := time.Date(year, time.Month(month), day, hour, minute, second, 0, time.UTC)
if parsed.Year() != year || int(parsed.Month()) != month || parsed.Day() != day ||
parsed.Hour() != hour || parsed.Minute() != minute || parsed.Second() != second {
return time.Time{}, false
}
return parsed, true
}
func zeroTemporalDisplayValue(databaseTypeName string) (string, bool) {
typeName := strings.ToUpper(strings.TrimSpace(databaseTypeName))
if typeName == "" {

View File

@@ -1,6 +1,7 @@
package db
import (
"encoding/binary"
"encoding/json"
"fmt"
"testing"
@@ -220,6 +221,154 @@ func TestNormalizeQueryValueWithDBType_TimeStructToRFC3339(t *testing.T) {
}
}
func TestNormalizeQueryValueWithDBTypeAndDialect_OceanBaseOracleBinaryTimestampString(t *testing.T) {
raw := buildOracleBinaryTimestamp(time.Date(2026, 6, 16, 12, 34, 56, 123456000, time.UTC))
got := normalizeQueryValueWithDBTypeAndDialect(string(raw), "TYPE_CA", oceanBaseOracleScanDialect)
if got != "2026-06-16T12:34:56.123456Z" {
t.Fatalf("OceanBase Oracle 二进制 TIMESTAMP 字符串应解码为 RFC3339实际=%v(%T)", got, got)
}
}
func TestNormalizeQueryValueWithDBTypeAndDialect_OceanBaseOracleBinaryTimestampStringWithPrecisionType(t *testing.T) {
raw := buildOracleBinaryTimestamp(time.Date(2026, 6, 16, 12, 34, 56, 123456000, time.UTC))
got := normalizeQueryValueWithDBTypeAndDialect(string(raw), "TIMESTAMP(6)", oceanBaseOracleScanDialect)
if got != "2026-06-16T12:34:56.123456Z" {
t.Fatalf("OceanBase Oracle TIMESTAMP(6) 字符串应解码为 RFC3339实际=%v(%T)", got, got)
}
}
func TestNormalizeQueryValueWithDBTypeAndDialect_OceanBaseOracleBinaryTimestampStringWithGenericCarrierType(t *testing.T) {
raw := buildOracleBinaryTimestamp(time.Date(2026, 6, 16, 12, 34, 56, 123456000, time.UTC))
got := normalizeQueryValueWithDBTypeAndDialect(string(raw), "VARCHAR2", oceanBaseOracleScanDialect)
if got != "2026-06-16T12:34:56.123456Z" {
t.Fatalf("OceanBase Oracle 泛型载体类型的 TIMESTAMP 字符串应解码为 RFC3339实际=%v(%T)", got, got)
}
}
func TestNormalizeQueryValueWithDBTypeAndDialect_OceanBaseOracleBinaryTimestampBytes(t *testing.T) {
raw := buildOracleBinaryTimestamp(time.Date(2026, 6, 16, 12, 34, 56, 123456000, time.UTC))
got := normalizeQueryValueWithDBTypeAndDialect(raw, "TYPE_CA", oceanBaseOracleScanDialect)
if got != "2026-06-16T12:34:56.123456Z" {
t.Fatalf("OceanBase Oracle 二进制 TIMESTAMP 字节值应解码为 RFC3339实际=%v(%T)", got, got)
}
}
func TestNormalizeQueryValueWithDBTypeAndDialect_OceanBaseOracleBinaryTimestampBytesWithGenericCarrierType(t *testing.T) {
raw := buildOracleBinaryTimestamp(time.Date(2026, 6, 16, 12, 34, 56, 123456000, time.UTC))
got := normalizeQueryValueWithDBTypeAndDialect(raw, "VARCHAR2", oceanBaseOracleScanDialect)
if got != "2026-06-16T12:34:56.123456Z" {
t.Fatalf("OceanBase Oracle 泛型载体类型的 TIMESTAMP 字节值应解码为 RFC3339实际=%v(%T)", got, got)
}
}
func TestNormalizeQueryValueWithDBTypeAndDialect_OceanBaseOracleBinaryTimestampBytesWithPrecisionType(t *testing.T) {
raw := buildOracleBinaryTimestamp(time.Date(2026, 6, 16, 12, 34, 56, 123456000, time.UTC))
got := normalizeQueryValueWithDBTypeAndDialect(raw, "TIMESTAMP(6)", oceanBaseOracleScanDialect)
if got != "2026-06-16T12:34:56.123456Z" {
t.Fatalf("OceanBase Oracle TIMESTAMP(6) 字节值应解码为 RFC3339实际=%v(%T)", got, got)
}
}
func TestNormalizeQueryValueWithDBTypeAndDialect_OceanBaseOracleBinaryTimestampStringWithoutTypeName(t *testing.T) {
raw := buildOracleBinaryTimestamp(time.Date(2026, 6, 16, 12, 34, 56, 123456000, time.UTC))
got := normalizeQueryValueWithDBTypeAndDialect(string(raw), "", oceanBaseOracleScanDialect)
if got != "2026-06-16T12:34:56.123456Z" {
t.Fatalf("OceanBase Oracle 空类型名的 TIMESTAMP 字符串应解码为 RFC3339实际=%v(%T)", got, got)
}
}
func TestNormalizeQueryValueWithDBTypeAndDialect_OceanBaseOracleBinaryTimestampBytesWithoutTypeName(t *testing.T) {
raw := buildOracleBinaryTimestamp(time.Date(2026, 6, 16, 12, 34, 56, 123456000, time.UTC))
got := normalizeQueryValueWithDBTypeAndDialect(raw, "", oceanBaseOracleScanDialect)
if got != "2026-06-16T12:34:56.123456Z" {
t.Fatalf("OceanBase Oracle 空类型名的 TIMESTAMP 字节值应解码为 RFC3339实际=%v(%T)", got, got)
}
}
func TestNormalizeQueryValueWithDBTypeAndDialect_OceanBaseOracleMySQLEncodedTimestampString(t *testing.T) {
raw := buildMySQLBinaryTimestamp(time.Date(2026, 6, 16, 12, 34, 56, 123456000, time.UTC))
got := normalizeQueryValueWithDBTypeAndDialect(string(raw), "TYPE_CA", oceanBaseOracleScanDialect)
if got != "2026-06-16T12:34:56.123456Z" {
t.Fatalf("OceanBase Oracle length-encoded TIMESTAMP 字符串应解码为 RFC3339实际=%v(%T)", got, got)
}
}
func TestNormalizeQueryValueWithDBTypeAndDialect_OceanBaseOracleMySQLEncodedTimestampBytes(t *testing.T) {
raw := buildMySQLBinaryTimestamp(time.Date(2026, 6, 16, 12, 34, 56, 123456000, time.UTC))
got := normalizeQueryValueWithDBTypeAndDialect(raw, "TYPE_CA", oceanBaseOracleScanDialect)
if got != "2026-06-16T12:34:56.123456Z" {
t.Fatalf("OceanBase Oracle length-encoded TIMESTAMP 字节值应解码为 RFC3339实际=%v(%T)", got, got)
}
}
func TestNormalizeQueryValueWithDBTypeAndDialect_OceanBaseOracleTypeCALiveTimestampString(t *testing.T) {
raw := []byte{20, 26, 6, 16, 16, 46, 23, 96, 196, 119, 9, 6}
got := normalizeQueryValueWithDBTypeAndDialect(string(raw), "TYPE_CA", oceanBaseOracleScanDialect)
if got != "2026-06-16T16:46:23.158844Z" {
t.Fatalf("OceanBase Oracle TYPE_CA live TIMESTAMP 字符串应解码为 RFC3339实际=%v(%T)", got, got)
}
}
func TestNormalizeQueryValueWithDBTypeAndDialect_OceanBaseOracleTypeCALiveTimestampBytes(t *testing.T) {
raw := []byte{20, 26, 6, 16, 16, 46, 23, 96, 196, 119, 9, 6}
got := normalizeQueryValueWithDBTypeAndDialect(raw, "TYPE_CA", oceanBaseOracleScanDialect)
if got != "2026-06-16T16:46:23.158844Z" {
t.Fatalf("OceanBase Oracle TYPE_CA live TIMESTAMP 字节值应解码为 RFC3339实际=%v(%T)", got, got)
}
}
func TestNormalizeQueryValueWithDBTypeAndDialect_OceanBaseOracleTypeCALiveTimestampWithoutFraction(t *testing.T) {
raw := []byte{20, 26, 6, 17, 5, 0, 0, 0, 0, 0, 0, 6}
got := normalizeQueryValueWithDBTypeAndDialect(raw, "TYPE_CA", oceanBaseOracleScanDialect)
if got != "2026-06-17T05:00:00Z" {
t.Fatalf("OceanBase Oracle TYPE_CA 零小数 TIMESTAMP 字节值应解码为 RFC3339实际=%v(%T)", got, got)
}
}
func buildOracleBinaryTimestamp(tm time.Time) []byte {
if tm.Location() != time.UTC {
tm = tm.In(time.UTC)
}
buf := []byte{
byte(tm.Year()/100 + 100),
byte(tm.Year()%100 + 100),
byte(tm.Month()),
byte(tm.Day()),
byte(tm.Hour() + 1),
byte(tm.Minute() + 1),
byte(tm.Second() + 1),
0,
0,
0,
0,
}
binary.BigEndian.PutUint32(buf[7:11], uint32(tm.Nanosecond()))
return buf
}
func buildMySQLBinaryTimestamp(tm time.Time) []byte {
if tm.Location() != time.UTC {
tm = tm.In(time.UTC)
}
buf := []byte{11, 0, 0, byte(tm.Month()), byte(tm.Day()), byte(tm.Hour()), byte(tm.Minute()), byte(tm.Second()), 0, 0, 0, 0}
binary.LittleEndian.PutUint16(buf[1:3], uint16(tm.Year()))
binary.LittleEndian.PutUint32(buf[8:12], uint32(tm.Nanosecond()/1000))
return buf
}
func TestNormalizeQueryValueWithDBTypeAndDialect_MySQLDateOnly(t *testing.T) {
input := time.Date(2025, 10, 1, 0, 0, 0, 0, time.Local)
@@ -248,6 +397,24 @@ func TestNormalizeQueryValueWithDBTypeAndDialect_DatetimeKeepsTime(t *testing.T)
}
}
func TestNormalizeQueryValueWithDBTypeAndDialect_OceanBaseOracleDateMidnightDisplaysDateOnly(t *testing.T) {
input := time.Date(2025, 10, 1, 0, 0, 0, 0, time.UTC)
got := normalizeQueryValueWithDBTypeAndDialect(input, "DATE", oceanBaseOracleScanDialect)
if got != "2025-10-01" {
t.Fatalf("OceanBase Oracle DATE 的午夜值应只展示日期,实际=%v(%T)", got, got)
}
}
func TestNormalizeQueryValueWithDBTypeAndDialect_OceanBaseOracleDateKeepsNonMidnightTime(t *testing.T) {
input := time.Date(2025, 10, 1, 13, 14, 15, 0, time.UTC)
got := normalizeQueryValueWithDBTypeAndDialect(input, "DATE", oceanBaseOracleScanDialect)
if got != "2025-10-01T13:14:15Z" {
t.Fatalf("OceanBase Oracle DATE 非午夜值应保留时间,实际=%v(%T)", got, got)
}
}
func TestNormalizeQueryValueWithDBType_ZeroTemporalValues(t *testing.T) {
zero := time.Time{}
cases := []struct {

View File

@@ -40,6 +40,78 @@ func (scanRowsDuplicateConn) QueryContext(_ context.Context, query string, args
},
}, nil
}
if query == "SELECT timestamp_columns" {
raw := buildOracleBinaryTimestamp(time.Date(2026, 6, 16, 12, 34, 56, 123456000, time.UTC))
return &scanRowsDuplicateRows{
columns: []string{"created_at"},
columnTypes: []string{"TYPE_CA"},
rows: [][]driver.Value{
{
string(raw),
},
},
}, nil
}
if query == "SELECT timestamp_precision_columns" {
raw := buildOracleBinaryTimestamp(time.Date(2026, 6, 16, 12, 34, 56, 123456000, time.UTC))
return &scanRowsDuplicateRows{
columns: []string{"created_at"},
columnTypes: []string{"TIMESTAMP(6)"},
rows: [][]driver.Value{
{
string(raw),
},
},
}, nil
}
if query == "SELECT timestamp_generic_carrier_columns" {
raw := buildOracleBinaryTimestamp(time.Date(2026, 6, 16, 12, 34, 56, 123456000, time.UTC))
return &scanRowsDuplicateRows{
columns: []string{"created_at"},
columnTypes: []string{"VARCHAR2"},
rows: [][]driver.Value{
{
string(raw),
},
},
}, nil
}
if query == "SELECT timestamp_unknown_type_columns" {
raw := buildOracleBinaryTimestamp(time.Date(2026, 6, 16, 12, 34, 56, 123456000, time.UTC))
return &scanRowsDuplicateRows{
columns: []string{"created_at"},
columnTypes: []string{""},
rows: [][]driver.Value{
{
string(raw),
},
},
}, nil
}
if query == "SELECT timestamp_mysql_encoded_columns" {
raw := buildMySQLBinaryTimestamp(time.Date(2026, 6, 16, 12, 34, 56, 123456000, time.UTC))
return &scanRowsDuplicateRows{
columns: []string{"created_at"},
columnTypes: []string{"TYPE_CA"},
rows: [][]driver.Value{
{
string(raw),
},
},
}, nil
}
if query == "SELECT timestamp_type_ca_live_columns" {
raw := []byte{20, 26, 6, 16, 16, 46, 23, 96, 196, 119, 9, 6}
return &scanRowsDuplicateRows{
columns: []string{"created_at"},
columnTypes: []string{"TYPE_CA"},
rows: [][]driver.Value{
{
string(raw),
},
},
}, nil
}
return &scanRowsDuplicateRows{
columns: []string{"id", "id", "name"},
rows: [][]driver.Value{
@@ -184,3 +256,267 @@ func TestScanRowsForOracleDialectKeepsDateTime(t *testing.T) {
t.Fatalf("Oracle DATE 应保留 datetime 语义,实际=%v(%T)", data[0]["ship_date"], data[0]["ship_date"])
}
}
func TestScanRowsForOceanBaseOracleDialectFormatsMidnightDateOnly(t *testing.T) {
t.Parallel()
registerScanRowsDuplicateDriverOnce.Do(func() {
sql.Register(scanRowsDuplicateDriverName, scanRowsDuplicateDriver{})
})
dbConn, err := sql.Open(scanRowsDuplicateDriverName, "")
if err != nil {
t.Fatalf("open date scan rows db failed: %v", err)
}
defer dbConn.Close()
rows, err := dbConn.QueryContext(context.Background(), "SELECT date_columns")
if err != nil {
t.Fatalf("query date scan rows db failed: %v", err)
}
defer rows.Close()
data, _, err := scanRowsForDialect(rows, oceanBaseOracleScanDialect)
if err != nil {
t.Fatalf("scanRowsForDialect returned error: %v", err)
}
if len(data) != 1 {
t.Fatalf("expected one row, got=%d", len(data))
}
if data[0]["ship_date"] != "2025-10-01" {
t.Fatalf("OceanBase Oracle DATE 的午夜值应展示为日期,实际=%v(%T)", data[0]["ship_date"], data[0]["ship_date"])
}
if data[0]["created_at"] != "2025-10-01T13:14:15Z" {
t.Fatalf("OceanBase Oracle DATETIME 应保留时间,实际=%v(%T)", data[0]["created_at"], data[0]["created_at"])
}
}
func TestOracleDBQueryUsesCustomScanDialect(t *testing.T) {
t.Parallel()
registerScanRowsDuplicateDriverOnce.Do(func() {
sql.Register(scanRowsDuplicateDriverName, scanRowsDuplicateDriver{})
})
dbConn, err := sql.Open(scanRowsDuplicateDriverName, "")
if err != nil {
t.Fatalf("open date scan rows db failed: %v", err)
}
defer dbConn.Close()
oracleDB := &OracleDB{conn: dbConn, scanDialect: oceanBaseOracleScanDialect}
data, _, err := oracleDB.Query("SELECT date_columns")
if err != nil {
t.Fatalf("OracleDB.Query returned error: %v", err)
}
if len(data) != 1 {
t.Fatalf("expected one row, got=%d", len(data))
}
if data[0]["ship_date"] != "2025-10-01" {
t.Fatalf("OracleDB 自定义扫描方言未生效,实际=%v(%T)", data[0]["ship_date"], data[0]["ship_date"])
}
}
func TestScanRowsForOceanBaseOracleDialectDecodesBinaryTimestampString(t *testing.T) {
t.Parallel()
registerScanRowsDuplicateDriverOnce.Do(func() {
sql.Register(scanRowsDuplicateDriverName, scanRowsDuplicateDriver{})
})
dbConn, err := sql.Open(scanRowsDuplicateDriverName, "")
if err != nil {
t.Fatalf("open timestamp scan rows db failed: %v", err)
}
defer dbConn.Close()
rows, err := dbConn.QueryContext(context.Background(), "SELECT timestamp_columns")
if err != nil {
t.Fatalf("query timestamp scan rows db failed: %v", err)
}
defer rows.Close()
data, columns, err := scanRowsForDialect(rows, oceanBaseOracleScanDialect)
if err != nil {
t.Fatalf("scanRowsForDialect returned error: %v", err)
}
if !reflect.DeepEqual(columns, []string{"created_at"}) {
t.Fatalf("unexpected columns: %v", columns)
}
if len(data) != 1 {
t.Fatalf("expected one row, got=%d", len(data))
}
if data[0]["created_at"] != "2026-06-16T12:34:56.123456Z" {
t.Fatalf("OceanBase Oracle 二进制 TIMESTAMP 应解码为 RFC3339实际=%v(%T)", data[0]["created_at"], data[0]["created_at"])
}
}
func TestScanRowsForOceanBaseOracleDialectDecodesBinaryTimestampStringWithPrecisionType(t *testing.T) {
t.Parallel()
registerScanRowsDuplicateDriverOnce.Do(func() {
sql.Register(scanRowsDuplicateDriverName, scanRowsDuplicateDriver{})
})
dbConn, err := sql.Open(scanRowsDuplicateDriverName, "")
if err != nil {
t.Fatalf("open timestamp precision scan rows db failed: %v", err)
}
defer dbConn.Close()
rows, err := dbConn.QueryContext(context.Background(), "SELECT timestamp_precision_columns")
if err != nil {
t.Fatalf("query timestamp precision scan rows db failed: %v", err)
}
defer rows.Close()
data, columns, err := scanRowsForDialect(rows, oceanBaseOracleScanDialect)
if err != nil {
t.Fatalf("scanRowsForDialect returned error: %v", err)
}
if !reflect.DeepEqual(columns, []string{"created_at"}) {
t.Fatalf("unexpected columns: %v", columns)
}
if len(data) != 1 {
t.Fatalf("expected one row, got=%d", len(data))
}
if data[0]["created_at"] != "2026-06-16T12:34:56.123456Z" {
t.Fatalf("OceanBase Oracle TIMESTAMP(6) 应解码为 RFC3339实际=%v(%T)", data[0]["created_at"], data[0]["created_at"])
}
}
func TestScanRowsForOceanBaseOracleDialectDecodesBinaryTimestampStringWithGenericCarrierType(t *testing.T) {
t.Parallel()
registerScanRowsDuplicateDriverOnce.Do(func() {
sql.Register(scanRowsDuplicateDriverName, scanRowsDuplicateDriver{})
})
dbConn, err := sql.Open(scanRowsDuplicateDriverName, "")
if err != nil {
t.Fatalf("open timestamp generic-carrier scan rows db failed: %v", err)
}
defer dbConn.Close()
rows, err := dbConn.QueryContext(context.Background(), "SELECT timestamp_generic_carrier_columns")
if err != nil {
t.Fatalf("query timestamp generic-carrier scan rows db failed: %v", err)
}
defer rows.Close()
data, columns, err := scanRowsForDialect(rows, oceanBaseOracleScanDialect)
if err != nil {
t.Fatalf("scanRowsForDialect returned error: %v", err)
}
if !reflect.DeepEqual(columns, []string{"created_at"}) {
t.Fatalf("unexpected columns: %v", columns)
}
if len(data) != 1 {
t.Fatalf("expected one row, got=%d", len(data))
}
if data[0]["created_at"] != "2026-06-16T12:34:56.123456Z" {
t.Fatalf("OceanBase Oracle 泛型载体类型的 TIMESTAMP 应解码为 RFC3339实际=%v(%T)", data[0]["created_at"], data[0]["created_at"])
}
}
func TestScanRowsForOceanBaseOracleDialectDecodesBinaryTimestampStringWithoutTypeName(t *testing.T) {
t.Parallel()
registerScanRowsDuplicateDriverOnce.Do(func() {
sql.Register(scanRowsDuplicateDriverName, scanRowsDuplicateDriver{})
})
dbConn, err := sql.Open(scanRowsDuplicateDriverName, "")
if err != nil {
t.Fatalf("open timestamp unknown-type scan rows db failed: %v", err)
}
defer dbConn.Close()
rows, err := dbConn.QueryContext(context.Background(), "SELECT timestamp_unknown_type_columns")
if err != nil {
t.Fatalf("query timestamp unknown-type scan rows db failed: %v", err)
}
defer rows.Close()
data, columns, err := scanRowsForDialect(rows, oceanBaseOracleScanDialect)
if err != nil {
t.Fatalf("scanRowsForDialect returned error: %v", err)
}
if !reflect.DeepEqual(columns, []string{"created_at"}) {
t.Fatalf("unexpected columns: %v", columns)
}
if len(data) != 1 {
t.Fatalf("expected one row, got=%d", len(data))
}
if data[0]["created_at"] != "2026-06-16T12:34:56.123456Z" {
t.Fatalf("OceanBase Oracle 空类型名的 TIMESTAMP 应解码为 RFC3339实际=%v(%T)", data[0]["created_at"], data[0]["created_at"])
}
}
func TestScanRowsForOceanBaseOracleDialectDecodesMySQLLengthEncodedTimestampString(t *testing.T) {
t.Parallel()
registerScanRowsDuplicateDriverOnce.Do(func() {
sql.Register(scanRowsDuplicateDriverName, scanRowsDuplicateDriver{})
})
dbConn, err := sql.Open(scanRowsDuplicateDriverName, "")
if err != nil {
t.Fatalf("open mysql-encoded timestamp scan rows db failed: %v", err)
}
defer dbConn.Close()
rows, err := dbConn.QueryContext(context.Background(), "SELECT timestamp_mysql_encoded_columns")
if err != nil {
t.Fatalf("query mysql-encoded timestamp scan rows db failed: %v", err)
}
defer rows.Close()
data, columns, err := scanRowsForDialect(rows, oceanBaseOracleScanDialect)
if err != nil {
t.Fatalf("scanRowsForDialect returned error: %v", err)
}
if !reflect.DeepEqual(columns, []string{"created_at"}) {
t.Fatalf("unexpected columns: %v", columns)
}
if len(data) != 1 {
t.Fatalf("expected one row, got=%d", len(data))
}
if data[0]["created_at"] != "2026-06-16T12:34:56.123456Z" {
t.Fatalf("OceanBase Oracle length-encoded TIMESTAMP 应解码为 RFC3339实际=%v(%T)", data[0]["created_at"], data[0]["created_at"])
}
}
func TestScanRowsForOceanBaseOracleDialectDecodesTypeCALiveTimestampString(t *testing.T) {
t.Parallel()
registerScanRowsDuplicateDriverOnce.Do(func() {
sql.Register(scanRowsDuplicateDriverName, scanRowsDuplicateDriver{})
})
dbConn, err := sql.Open(scanRowsDuplicateDriverName, "")
if err != nil {
t.Fatalf("open timestamp scan rows db failed: %v", err)
}
defer dbConn.Close()
rows, err := dbConn.QueryContext(context.Background(), "SELECT timestamp_type_ca_live_columns")
if err != nil {
t.Fatalf("query timestamp scan rows db failed: %v", err)
}
defer rows.Close()
data, columns, err := scanRowsForDialect(rows, oceanBaseOracleScanDialect)
if err != nil {
t.Fatalf("scanRowsForDialect returned error: %v", err)
}
if !reflect.DeepEqual(columns, []string{"created_at"}) {
t.Fatalf("unexpected columns: %v", columns)
}
if len(data) != 1 {
t.Fatalf("expected one row, got=%d", len(data))
}
if data[0]["created_at"] != "2026-06-16T16:46:23.158844Z" {
t.Fatalf("OceanBase Oracle TYPE_CA live TIMESTAMP 应解码为 RFC3339实际=%v(%T)", data[0]["created_at"], data[0]["created_at"])
}
}

View File

@@ -906,7 +906,7 @@ func isOperationAllowed(level ai.SQLPermissionLevel, opType ai.SQLOperationType)
case ai.PermissionReadWrite:
return opType == ai.SQLOpQuery || opType == ai.SQLOpDML
case ai.PermissionFull:
return opType == ai.SQLOpQuery || opType == ai.SQLOpDML || opType == ai.SQLOpDDL
return true
default:
return opType == ai.SQLOpQuery
}
@@ -945,7 +945,7 @@ func safetyLevelRuleText(level ai.SQLPermissionLevel) string {
case ai.PermissionReadWrite:
return "读写模式仅允许查询和 DML 语句。"
case ai.PermissionFull:
return "完全模式允许查询、DML 和 DDL未识别操作仍会被阻止。"
return "完全模式允许所有 SQL 操作;高风险或未识别语句仍会要求确认。"
default:
return "只读模式仅允许查询语句。"
}

View File

@@ -554,6 +554,46 @@ func TestExecuteSQLAllowsDDLWhenAISafetyIsFullAndAllowMutating(t *testing.T) {
}
}
func TestExecuteSQLAllowsOtherStatementsWhenAISafetyIsFullAndAllowMutating(t *testing.T) {
backend := &fakeBackend{
editableConnection: connection.SavedConnectionView{
ID: "oracle-main",
Config: connection.ConnectionConfig{
Type: "oracle",
Database: "app",
},
},
inspection: appcore.SQLInspection{
StatementCount: 1,
ReadOnly: false,
Statements: []appcore.SQLStatementInspection{
{Index: 1, Keyword: "call", ReadOnly: false},
},
},
safetyLevel: ai.PermissionFull,
queryResult: connection.QueryResult{
Success: true,
Data: []connection.ResultSetData{},
},
}
service := NewService(backend)
result, _, err := service.ExecuteSQL(context.Background(), nil, executeSQLArgs{
ConnectionID: "oracle-main",
SQL: "CALL bulk_insert_users(100000)",
AllowMutating: true,
})
if err != nil {
t.Fatalf("ExecuteSQL returned error: %v", err)
}
if result == nil || result.IsError {
t.Fatalf("expected success result, got %#v", result)
}
if !backend.queryCalled {
t.Fatalf("expected SQL to be executed")
}
}
func TestExecuteSQLNormalizesAndTruncatesResultSets(t *testing.T) {
backend := &fakeBackend{
editableConnection: connection.SavedConnectionView{