mirror of
https://github.com/Syngnat/GoNavi.git
synced 2026-06-22 14:33:44 +08:00
🐛 fix(oceanbase/data-grid): 修复 Oracle 时间字段显示编辑与结果视图异常
- 修复 OceanBase Oracle DATE 与 TIMESTAMP 的解码、展示和编辑精度丢失问题 - 修复查询结果与数据视图的行号显示、分页页数和日期列展示口径 - 打通 Oracle 与 OceanBase 会话执行链路的扫描方言透传 - 补齐 DBQuery、DataGrid temporal 和 OceanBase 结果链路回归测试
This commit is contained in:
@@ -222,7 +222,9 @@ describe('DataGrid layout', () => {
|
||||
paginationSummaryText="当前 24 条 / 共 24 条"
|
||||
paginationControlTotal={24}
|
||||
paginationTotalPages={1}
|
||||
paginationPageText="第 1 / 1 页"
|
||||
paginationPageSizeOptions={['100', '200']}
|
||||
showKnownPageCount
|
||||
onPageChange={() => {}}
|
||||
onPageSizeChange={() => {}}
|
||||
onV2PageStep={() => {}}
|
||||
@@ -233,6 +235,37 @@ describe('DataGrid layout', () => {
|
||||
expect(markup).not.toContain('第 1 / 1 页');
|
||||
});
|
||||
|
||||
it('keeps unknown-total pagination in sequential mode instead of pretending total pages are known', () => {
|
||||
const markup = renderToStaticMarkup(
|
||||
<DataGrid
|
||||
data={[
|
||||
{
|
||||
__gonavi_row_key__: 'row-1',
|
||||
id: 1,
|
||||
name: 'alpha',
|
||||
},
|
||||
]}
|
||||
columnNames={['id', 'name']}
|
||||
loading={false}
|
||||
tableName="users"
|
||||
dbName="main"
|
||||
connectionId="conn-1"
|
||||
readOnly
|
||||
pagination={{
|
||||
current: 3,
|
||||
pageSize: 100,
|
||||
total: 400,
|
||||
totalKnown: false,
|
||||
}}
|
||||
onPageChange={() => {}}
|
||||
/>,
|
||||
);
|
||||
|
||||
expect(markup).toContain('第 3 页');
|
||||
expect(markup).not.toContain('<strong>3</strong><span>/</span><span>4</span>');
|
||||
expect(markup).not.toContain('data-grid-pagination-jump="true"');
|
||||
});
|
||||
|
||||
it('renders the v2 DataGrid toolbar using the redesigned topbar hooks', () => {
|
||||
const markup = renderToStaticMarkup(
|
||||
<DataGrid
|
||||
@@ -281,12 +314,62 @@ describe('DataGrid layout', () => {
|
||||
expect(markup).toContain('AI 洞察');
|
||||
});
|
||||
|
||||
it('renders a non-data row number column when enabled', () => {
|
||||
const markup = renderToStaticMarkup(
|
||||
<DataGrid
|
||||
data={[
|
||||
{
|
||||
__gonavi_row_key__: 'row-1',
|
||||
id: 1,
|
||||
name: 'alpha',
|
||||
},
|
||||
]}
|
||||
columnNames={['id', 'name']}
|
||||
loading={false}
|
||||
tableName="events"
|
||||
dbName="main"
|
||||
connectionId="conn-1"
|
||||
readOnly
|
||||
showRowNumberColumn
|
||||
pagination={{
|
||||
current: 2,
|
||||
pageSize: 50,
|
||||
total: 51,
|
||||
}}
|
||||
onPageChange={() => {}}
|
||||
/>,
|
||||
);
|
||||
|
||||
expect(markup).toContain('aria-label="行号"');
|
||||
expect(markup).toContain('<span aria-label="行号">#</span>');
|
||||
expect(markup).not.toContain('>行号<');
|
||||
expect(markup).toContain('data-grid-row-number-title="true"');
|
||||
expect(markup).toContain('data-grid-column-title-single-line="true"');
|
||||
expect(markup).toContain('justify-content:center');
|
||||
expect(markup).toContain('align-items:center');
|
||||
expect(markup).toContain('min-height:var(--gonavi-header-min-height, 40px)');
|
||||
expect(markup).toContain('text-align:center');
|
||||
expect(markup).toContain('padding-inline:0');
|
||||
expect(markup).toContain('vertical-align:middle');
|
||||
expect(markup).toContain('data-grid-row-number="true"');
|
||||
expect(markup).toContain('51');
|
||||
});
|
||||
|
||||
it('clears modified cell markers when refreshing the grid', () => {
|
||||
const source = readFileSync(new URL('./DataGrid.tsx', import.meta.url), 'utf8');
|
||||
|
||||
expect(source).toMatch(/const handleRefreshGrid = useCallback\(\(\) => \{[\s\S]*setModifiedColumns\(\{\}\);[\s\S]*if \(onReload\) onReload\(\);[\s\S]*\}, \[clearAutoCommitTimer, onReload\]\);/);
|
||||
});
|
||||
|
||||
it('routes temporal inline editors through the current connection config', () => {
|
||||
const source = readFileSync(new URL('./DataGrid.tsx', import.meta.url), 'utf8');
|
||||
|
||||
expect(source).toContain('const pickerType = getTemporalPickerType(columnType, dbType, connectionConfig);');
|
||||
expect(source).toContain('const pickerType = getTemporalPickerType(columnType, dbType, currentConnConfig);');
|
||||
expect(source).toContain('cellProps.connectionConfig = currentConnConfig;');
|
||||
expect(source).toContain('format={getTemporalPickerFormat(pickerType)}');
|
||||
});
|
||||
|
||||
it('renders a cell-level undo action in the v2 context menu for modified cells', () => {
|
||||
const markup = renderToStaticMarkup(
|
||||
<V2CellContextMenuView
|
||||
@@ -305,6 +388,18 @@ describe('DataGrid layout', () => {
|
||||
expect(formatCellDisplayText('2026-05-10T09:12:33.456+08:00')).toBe('2026-05-10 09:12:33.456');
|
||||
});
|
||||
|
||||
it('collapses OceanBase Oracle DATE midnight values to date-only text', () => {
|
||||
const oceanBaseOracleConfig = {
|
||||
type: 'oceanbase',
|
||||
oceanBaseProtocol: 'oracle',
|
||||
} as any;
|
||||
|
||||
expect(formatCellDisplayText('2026-06-16T00:00:00Z', 'DATE', oceanBaseOracleConfig)).toBe('2026-06-16');
|
||||
expect(formatCellDisplayText('2026-06-16 00:00:00', 'DATE', oceanBaseOracleConfig)).toBe('2026-06-16');
|
||||
expect(formatCellDisplayText('2026-06-16T13:14:15Z', 'DATE', oceanBaseOracleConfig)).toBe('2026-06-16 13:14:15');
|
||||
expect(formatCellDisplayText('2026-06-16T00:00:00Z', 'DATE', { type: 'oracle' } as any)).toBe('2026-06-16 00:00:00');
|
||||
});
|
||||
|
||||
it('renders bit column hex values as decimal flags', () => {
|
||||
expect(formatCellDisplayText('0x00', 'bit(1)')).toBe('0');
|
||||
expect(formatCellDisplayText('0x01', 'bit(1)')).toBe('1');
|
||||
|
||||
@@ -31,12 +31,13 @@ import { buildOrderBySQL, buildPaginatedSelectSQL, buildWhereSQL, escapeLiteral,
|
||||
import { isMacLikePlatform, normalizeOpacityForPlatform, resolveAppearanceValues } from '../utils/appearance';
|
||||
import { getDataSourceCapabilities, resolveDataSourceType } from '../utils/dataSourceCapabilities';
|
||||
import { buildRpcConnectionConfig } from '../utils/connectionRpcConfig';
|
||||
import { normalizeOceanBaseProtocol } from '../utils/oceanBaseProtocol';
|
||||
import {
|
||||
getDensityParams,
|
||||
resolveDataTableColumnWidth,
|
||||
resolveDataTableVerticalBorderColor,
|
||||
} from '../utils/dataGridDisplay';
|
||||
import { resolvePaginationSummaryText, resolvePaginationTotalForControl } from '../utils/dataGridPagination';
|
||||
import { resolvePaginationPageText, resolvePaginationSummaryText, resolvePaginationTotalForControl } from '../utils/dataGridPagination';
|
||||
import { resolveGridSortInfoFromTableSorter } from '../utils/dataGridSort';
|
||||
import {
|
||||
calculateExternalHorizontalScrollInnerWidth,
|
||||
@@ -71,10 +72,12 @@ import { DEFAULT_SHORTCUT_OPTIONS, getShortcutPlatform, resolveShortcutDisplay }
|
||||
import {
|
||||
TEMPORAL_FORMATS,
|
||||
formatFromDayjs,
|
||||
getTemporalPickerFormat,
|
||||
getTemporalPickerType,
|
||||
isTemporalColumnType,
|
||||
parseToDayjs,
|
||||
resolveTemporalEditorSaveValue,
|
||||
type TemporalConnectionLike,
|
||||
type TemporalPickerType,
|
||||
} from './dataGridTemporal';
|
||||
import {
|
||||
@@ -182,6 +185,7 @@ class DataGridErrorBoundary extends React.Component<
|
||||
|
||||
// 内部行标识字段:避免与真实业务字段(如 `key` 列)冲突。
|
||||
export const GONAVI_ROW_KEY = '__gonavi_row_key__';
|
||||
export const GONAVI_ROW_NUMBER_COLUMN_KEY = '__gonavi_row_number__';
|
||||
|
||||
// Cell key helpers for batch selection/fill.
|
||||
// Use a control character separator to avoid collisions with rowKey/columnName contents (e.g. `new-123`).
|
||||
@@ -189,6 +193,7 @@ const CELL_KEY_SEP = '\u0001';
|
||||
const CELL_SELECTION_DRAG_THRESHOLD_PX = 4;
|
||||
const DATE_TIME_CACHE_LIMIT = 2000;
|
||||
const TABLE_CELL_PREVIEW_MAX_CHARS = 240;
|
||||
const ROW_NUMBER_COLUMN_WIDTH = 58;
|
||||
const DATA_EDIT_AUTO_COMMIT_DELAY_OPTIONS = [
|
||||
{ value: 3000, label: '3 秒' },
|
||||
{ value: 5000, label: '5 秒' },
|
||||
@@ -281,7 +286,46 @@ const normalizeBitHexDisplayText = (val: any, columnType?: string): string | nul
|
||||
}
|
||||
};
|
||||
|
||||
export const formatCellDisplayText = (val: any, columnType?: string): string => {
|
||||
type CellDisplayConnectionLike = TemporalConnectionLike;
|
||||
|
||||
const isDateOnlyColumnType = (columnType?: string): boolean => {
|
||||
const normalized = String(columnType || '').trim().toLowerCase();
|
||||
if (!normalized) return false;
|
||||
const base = normalized.split(/[ (]/)[0];
|
||||
return base === 'date' || base === 'newdate';
|
||||
};
|
||||
|
||||
const isOceanBaseOracleDisplayConnection = (connectionConfig?: CellDisplayConnectionLike): boolean => {
|
||||
if (!connectionConfig) return false;
|
||||
const type = String(connectionConfig.type || '').trim().toLowerCase();
|
||||
const driver = String(connectionConfig.driver || '').trim().toLowerCase();
|
||||
return (type === 'oceanbase' || driver === 'oceanbase')
|
||||
&& normalizeOceanBaseProtocol(connectionConfig.oceanBaseProtocol) === 'oracle';
|
||||
};
|
||||
|
||||
const normalizeOceanBaseOracleDateDisplayText = (
|
||||
val: string,
|
||||
columnType?: string,
|
||||
connectionConfig?: CellDisplayConnectionLike,
|
||||
): string | null => {
|
||||
if (!isDateOnlyColumnType(columnType) || !isOceanBaseOracleDisplayConnection(connectionConfig)) {
|
||||
return null;
|
||||
}
|
||||
const trimmed = String(val || '').trim();
|
||||
if (!trimmed) return trimmed;
|
||||
const match = trimmed.match(
|
||||
/^(\d{4}-\d{2}-\d{2})(?:[T ](\d{2}:\d{2}:\d{2})(\.\d+)?(?:\s*(?:Z|[+-]\d{2}:?\d{2})(?:\s+[A-Za-z_\/+-]+)?)?)?$/
|
||||
);
|
||||
if (!match) return null;
|
||||
const [, datePart, timePart, fractionPart] = match;
|
||||
if (!timePart) return datePart;
|
||||
if (timePart === '00:00:00' && (!fractionPart || /^\.0+$/.test(fractionPart))) {
|
||||
return datePart;
|
||||
}
|
||||
return null;
|
||||
};
|
||||
|
||||
export const formatCellDisplayText = (val: any, columnType?: string, connectionConfig?: CellDisplayConnectionLike): string => {
|
||||
try {
|
||||
if (val === null) return 'NULL';
|
||||
const bitText = normalizeBitHexDisplayText(val, columnType);
|
||||
@@ -310,6 +354,10 @@ export const formatCellDisplayText = (val: any, columnType?: string): string =>
|
||||
}
|
||||
}
|
||||
if (typeof val === 'string') {
|
||||
const oceanBaseDateOnly = normalizeOceanBaseOracleDateDisplayText(val, columnType, connectionConfig);
|
||||
if (oceanBaseDateOnly !== null) {
|
||||
return oceanBaseDateOnly.length > TABLE_CELL_PREVIEW_MAX_CHARS ? `${oceanBaseDateOnly.slice(0, TABLE_CELL_PREVIEW_MAX_CHARS)}…` : oceanBaseDateOnly;
|
||||
}
|
||||
const normalized = normalizeDateTimeString(val);
|
||||
return normalized.length > TABLE_CELL_PREVIEW_MAX_CHARS ? `${normalized.slice(0, TABLE_CELL_PREVIEW_MAX_CHARS)}…` : normalized;
|
||||
}
|
||||
@@ -320,12 +368,16 @@ export const formatCellDisplayText = (val: any, columnType?: string): string =>
|
||||
}
|
||||
};
|
||||
|
||||
const formatClipboardCellText = (val: any, columnType?: string): string => {
|
||||
const formatClipboardCellText = (val: any, columnType?: string, connectionConfig?: CellDisplayConnectionLike): string => {
|
||||
try {
|
||||
if (val === null || val === undefined) return 'NULL';
|
||||
const bitText = normalizeBitHexDisplayText(val, columnType);
|
||||
if (bitText !== null) return bitText;
|
||||
if (typeof val === 'string') return normalizeDateTimeString(val);
|
||||
if (typeof val === 'string') {
|
||||
const oceanBaseDateOnly = normalizeOceanBaseOracleDateDisplayText(val, columnType, connectionConfig);
|
||||
if (oceanBaseDateOnly !== null) return oceanBaseDateOnly;
|
||||
return normalizeDateTimeString(val);
|
||||
}
|
||||
if (typeof val === 'object') {
|
||||
try {
|
||||
return JSON.stringify(val);
|
||||
@@ -346,6 +398,7 @@ const buildClipboardTsv = (
|
||||
rows: Array<Record<string, any>>,
|
||||
columnNames: string[],
|
||||
getColumnType?: (columnName: string) => string | undefined,
|
||||
connectionConfig?: CellDisplayConnectionLike,
|
||||
): string => {
|
||||
if (!Array.isArray(rows) || rows.length === 0 || !Array.isArray(columnNames) || columnNames.length === 0) {
|
||||
return '';
|
||||
@@ -353,7 +406,7 @@ const buildClipboardTsv = (
|
||||
const header = columnNames.map(normalizeClipboardTsvCell).join('\t');
|
||||
const lines = rows.map((row) => (
|
||||
columnNames
|
||||
.map((columnName) => normalizeClipboardTsvCell(formatClipboardCellText(row?.[columnName], getColumnType?.(columnName))))
|
||||
.map((columnName) => normalizeClipboardTsvCell(formatClipboardCellText(row?.[columnName], getColumnType?.(columnName), connectionConfig)))
|
||||
.join('\t')
|
||||
));
|
||||
return [header, ...lines].join('\n');
|
||||
@@ -382,8 +435,8 @@ const renderHighlightedCellText = (text: string, query: string): React.ReactNode
|
||||
return <>{nodes}</>;
|
||||
};
|
||||
|
||||
const renderCellDisplayValue = (val: any, query: string, columnType?: string): React.ReactNode => {
|
||||
const text = formatCellDisplayText(val, columnType);
|
||||
const renderCellDisplayValue = (val: any, query: string, columnType?: string, connectionConfig?: CellDisplayConnectionLike): React.ReactNode => {
|
||||
const text = formatCellDisplayText(val, columnType, connectionConfig);
|
||||
const content = renderHighlightedCellText(text, query);
|
||||
if (val === null) return <span style={{ color: '#ccc' }}>{content}</span>;
|
||||
return content;
|
||||
@@ -778,6 +831,7 @@ interface EditableCellProps {
|
||||
focusCell?: (record: Item, dataIndex: string, title: React.ReactNode) => void;
|
||||
columnType?: string;
|
||||
dbType?: string;
|
||||
connectionConfig?: CellDisplayConnectionLike;
|
||||
inputCellPadding?: React.CSSProperties;
|
||||
as?: any;
|
||||
modifiedColumns?: Record<string, Set<string>>;
|
||||
@@ -828,6 +882,9 @@ const areEditableCellPropsEqual = (prevProps: EditableCellProps, nextProps: Edit
|
||||
if (prevProps.title !== nextProps.title) return false;
|
||||
if (prevProps.columnType !== nextProps.columnType) return false;
|
||||
if (prevProps.dbType !== nextProps.dbType) return false;
|
||||
if ((prevProps.connectionConfig?.type ?? null) !== (nextProps.connectionConfig?.type ?? null)) return false;
|
||||
if ((prevProps.connectionConfig?.driver ?? null) !== (nextProps.connectionConfig?.driver ?? null)) return false;
|
||||
if ((prevProps.connectionConfig?.oceanBaseProtocol ?? null) !== (nextProps.connectionConfig?.oceanBaseProtocol ?? null)) return false;
|
||||
if (prevProps.darkMode !== nextProps.darkMode) return false;
|
||||
if (prevProps.as !== nextProps.as) return false;
|
||||
if (prevProps.handleSave !== nextProps.handleSave) return false;
|
||||
@@ -866,6 +923,7 @@ const EditableCell: React.FC<EditableCellProps> = React.memo(({
|
||||
focusCell,
|
||||
columnType,
|
||||
dbType,
|
||||
connectionConfig,
|
||||
inputCellPadding,
|
||||
as: Component = 'td',
|
||||
modifiedColumns,
|
||||
@@ -953,7 +1011,7 @@ const EditableCell: React.FC<EditableCellProps> = React.memo(({
|
||||
|
||||
let childNode = children;
|
||||
|
||||
const pickerType = getTemporalPickerType(columnType, dbType);
|
||||
const pickerType = getTemporalPickerType(columnType, dbType, connectionConfig);
|
||||
const isDateTimeField = !!pickerType && !(/^0{4}-0{2}-0{2}/.test(String(record?.[dataIndex] || '')));
|
||||
|
||||
const isRowDeleted = deletedRowKeys && rowKeyStr && record?.[GONAVI_ROW_KEY] !== undefined
|
||||
@@ -988,7 +1046,7 @@ const EditableCell: React.FC<EditableCellProps> = React.memo(({
|
||||
style={{ width: '100%' }}
|
||||
showTime
|
||||
showNow={false}
|
||||
format={TEMPORAL_FORMATS[pickerType]}
|
||||
format={getTemporalPickerFormat(pickerType)}
|
||||
renderExtraFooter={() => (
|
||||
<a
|
||||
style={{ padding: '0 2px' }}
|
||||
@@ -1210,6 +1268,7 @@ interface DataGridProps {
|
||||
pkColumns?: string[];
|
||||
editLocator?: EditRowLocator;
|
||||
readOnly?: boolean;
|
||||
showRowNumberColumn?: boolean;
|
||||
onReload?: () => void;
|
||||
onSort?: (field: string, order: string) => void;
|
||||
onPageChange?: (page: number, size: number) => void;
|
||||
@@ -1499,7 +1558,7 @@ const DataGrid: React.FC<DataGridProps> = ({
|
||||
resultExportAllSql,
|
||||
onReload, onSort, onPageChange, pagination, onRequestTotalCount, onCancelTotalCount, sortInfoExternal, showFilter, onToggleFilter, exportSqlWithFilter, onApplyFilter, appliedFilterConditions, quickWhereCondition,
|
||||
onApplyQuickWhereCondition,
|
||||
scrollSnapshot, onScrollSnapshotChange, toolbarExtraActions
|
||||
scrollSnapshot, onScrollSnapshotChange, toolbarExtraActions, showRowNumberColumn = false
|
||||
}) => {
|
||||
const connections = useStore(state => state.connections);
|
||||
const addTab = useStore(state => state.addTab);
|
||||
@@ -4448,7 +4507,7 @@ const DataGrid: React.FC<DataGridProps> = ({
|
||||
}
|
||||
|
||||
const columnType = (columnMetaMap[dataIndex] || columnMetaMapByLowerName[dataIndex.toLowerCase()])?.type;
|
||||
const pickerType = getTemporalPickerType(columnType, dbType);
|
||||
const pickerType = getTemporalPickerType(columnType, dbType, currentConnConfig);
|
||||
const isDateTimeField = !!pickerType && !(/^0{4}-0{2}-0{2}/.test(String(raw || '')));
|
||||
const fieldName = getCellFieldName(record, dataIndex);
|
||||
if (isDateTimeField) {
|
||||
@@ -4463,7 +4522,7 @@ const DataGrid: React.FC<DataGridProps> = ({
|
||||
title,
|
||||
columnType,
|
||||
});
|
||||
}, [canModifyData, columnMetaMap, columnMetaMapByLowerName, dbType, form, openCellEditor, rowKeyStr]);
|
||||
}, [canModifyData, columnMetaMap, columnMetaMapByLowerName, currentConnConfig, dbType, form, openCellEditor, rowKeyStr]);
|
||||
|
||||
const handleVirtualCellActivate = useCallback((record: Item, dataIndex: string, title: React.ReactNode) => {
|
||||
if (!canModifyData) return;
|
||||
@@ -4593,7 +4652,7 @@ const DataGrid: React.FC<DataGridProps> = ({
|
||||
return;
|
||||
}
|
||||
|
||||
const pickerType = getTemporalPickerType(editingCell.columnType, dbType);
|
||||
const pickerType = getTemporalPickerType(editingCell.columnType, dbType, currentConnConfig);
|
||||
const isDateTimeField = !!pickerType && !(/^0{4}-0{2}-0{2}/.test(String(record?.[editingCell.dataIndex] || '')));
|
||||
const fieldName = getCellFieldName(record, editingCell.dataIndex);
|
||||
try {
|
||||
@@ -4612,22 +4671,30 @@ const DataGrid: React.FC<DataGridProps> = ({
|
||||
closeVirtualInlineEditor();
|
||||
}
|
||||
}
|
||||
}, [closeVirtualInlineEditor, dbType, form, handleCellSave, virtualEditingCell]);
|
||||
}, [closeVirtualInlineEditor, currentConnConfig, dbType, form, handleCellSave, virtualEditingCell]);
|
||||
|
||||
const pageFindMatches = useMemo(() => collectDataGridFindMatches(
|
||||
mergedDisplayData,
|
||||
displayColumnNames,
|
||||
normalizedPageFindText,
|
||||
(value, _row, columnName) => formatCellDisplayText(value, (columnMetaMap[columnName] || columnMetaMapByLowerName[columnName.toLowerCase()])?.type),
|
||||
(value, _row, columnName) => formatCellDisplayText(
|
||||
value,
|
||||
(columnMetaMap[columnName] || columnMetaMapByLowerName[columnName.toLowerCase()])?.type,
|
||||
currentConnConfig,
|
||||
),
|
||||
(row, rowIndex) => String(row?.[GONAVI_ROW_KEY] ?? `row-${rowIndex}`),
|
||||
), [mergedDisplayData, displayColumnNames, normalizedPageFindText, columnMetaMap, columnMetaMapByLowerName]);
|
||||
), [mergedDisplayData, displayColumnNames, normalizedPageFindText, columnMetaMap, columnMetaMapByLowerName, currentConnConfig]);
|
||||
|
||||
const pageFindSummary = useMemo(() => summarizeDataGridFindMatches(
|
||||
mergedDisplayData,
|
||||
displayColumnNames,
|
||||
normalizedPageFindText,
|
||||
(value, _row, columnName) => formatCellDisplayText(value, (columnMetaMap[columnName] || columnMetaMapByLowerName[columnName.toLowerCase()])?.type),
|
||||
), [mergedDisplayData, displayColumnNames, normalizedPageFindText, columnMetaMap, columnMetaMapByLowerName]);
|
||||
(value, _row, columnName) => formatCellDisplayText(
|
||||
value,
|
||||
(columnMetaMap[columnName] || columnMetaMapByLowerName[columnName.toLowerCase()])?.type,
|
||||
currentConnConfig,
|
||||
),
|
||||
), [mergedDisplayData, displayColumnNames, normalizedPageFindText, columnMetaMap, columnMetaMapByLowerName, currentConnConfig]);
|
||||
|
||||
useEffect(() => {
|
||||
setActivePageFindMatchIndex(-1);
|
||||
@@ -4734,7 +4801,7 @@ const DataGrid: React.FC<DataGridProps> = ({
|
||||
displayMap[col] = toFormText(displayVal);
|
||||
// 日期时间类型: 将字符串值转为 dayjs 对象供 DatePicker 使用
|
||||
const colMeta = columnMetaMap[col] || columnMetaMapByLowerName[col.toLowerCase()];
|
||||
const rowPickerType = getTemporalPickerType(colMeta?.type, dbType);
|
||||
const rowPickerType = getTemporalPickerType(colMeta?.type, dbType, currentConnConfig);
|
||||
if (rowPickerType && displayVal !== null && displayVal !== undefined) {
|
||||
const dVal = parseToDayjs(displayVal, rowPickerType);
|
||||
formMap[col] = dVal;
|
||||
@@ -4751,7 +4818,7 @@ const DataGrid: React.FC<DataGridProps> = ({
|
||||
nullCols,
|
||||
formValues: formMap,
|
||||
});
|
||||
}, [canModifyData, mergedDisplayData, data, addedRows, visibleColumnNames, rowKeyStr, columnMetaMap, columnMetaMapByLowerName, dbType, openRowEditor]);
|
||||
}, [addedRows, canModifyData, columnMetaMap, columnMetaMapByLowerName, currentConnConfig, data, dbType, mergedDisplayData, openRowEditor, rowKeyStr, visibleColumnNames]);
|
||||
|
||||
const openCurrentViewRowEditor = useCallback(() => {
|
||||
if (!canModifyData) return;
|
||||
@@ -4916,7 +4983,7 @@ const DataGrid: React.FC<DataGridProps> = ({
|
||||
if (!isWritableResultColumn(col, effectiveEditLocator)) return;
|
||||
if (val && dayjs.isDayjs(val)) {
|
||||
const colMeta = columnMetaMap[col] || columnMetaMapByLowerName[col.toLowerCase()];
|
||||
const rowPickerType = getTemporalPickerType(colMeta?.type, dbType);
|
||||
const rowPickerType = getTemporalPickerType(colMeta?.type, dbType, currentConnConfig);
|
||||
convertedValues[col] = formatFromDayjs(val as dayjs.Dayjs, rowPickerType);
|
||||
} else {
|
||||
convertedValues[col] = val;
|
||||
@@ -4935,7 +5002,7 @@ const DataGrid: React.FC<DataGridProps> = ({
|
||||
// 日期时间类型: 将 dayjs 对象转回格式化字符串
|
||||
if (nextVal && dayjs.isDayjs(nextVal)) {
|
||||
const colMeta = columnMetaMap[col] || columnMetaMapByLowerName[col.toLowerCase()];
|
||||
const rowPickerType = getTemporalPickerType(colMeta?.type, dbType);
|
||||
const rowPickerType = getTemporalPickerType(colMeta?.type, dbType, currentConnConfig);
|
||||
nextVal = formatFromDayjs(nextVal as dayjs.Dayjs, rowPickerType);
|
||||
}
|
||||
const baseVal = baseRawMap[col];
|
||||
@@ -4950,7 +5017,7 @@ const DataGrid: React.FC<DataGridProps> = ({
|
||||
});
|
||||
|
||||
closeRowEditor();
|
||||
}, [rowEditorRowKey, rowEditorForm, addedRows, visibleColumnNames, rowKeyStr, closeRowEditor, effectiveEditLocator, columnMetaMap, columnMetaMapByLowerName, dbType]);
|
||||
}, [addedRows, closeRowEditor, columnMetaMap, columnMetaMapByLowerName, currentConnConfig, dbType, effectiveEditLocator, rowEditorForm, rowEditorRowKey, rowKeyStr, visibleColumnNames]);
|
||||
|
||||
|
||||
const enableVirtual = isTableSurfaceActive;
|
||||
@@ -4985,7 +5052,7 @@ const DataGrid: React.FC<DataGridProps> = ({
|
||||
sortOrder: (sortInfo.find(s => s.columnKey === key && s.enabled !== false)?.order || null) as SortOrder | undefined,
|
||||
editable: canModifyData && isWritableResultColumn(key, effectiveEditLocator),
|
||||
render: (text: any) => {
|
||||
const renderedContent = renderCellDisplayValue(text, normalizedPageFindText, displayColumnTypeMap[key]);
|
||||
const renderedContent = renderCellDisplayValue(text, normalizedPageFindText, displayColumnTypeMap[key], currentConnConfig);
|
||||
if (enableVirtual) {
|
||||
return renderedContent;
|
||||
}
|
||||
@@ -5037,7 +5104,7 @@ const DataGrid: React.FC<DataGridProps> = ({
|
||||
},
|
||||
}),
|
||||
}));
|
||||
}, [displayColumnNames, columnWidths, sortInfo, handleResizeStart, handleResizeAutoFit, isV2Ui, showColumnHeaderContextMenu, canModifyData, onSort, renderColumnTitle, dataTableDensity, normalizedPageFindText, displayColumnTypeMap, enableVirtual, showColumnComment, showColumnType]);
|
||||
}, [canModifyData, columnWidths, currentConnConfig, dataTableDensity, displayColumnNames, displayColumnTypeMap, enableVirtual, handleResizeAutoFit, handleResizeStart, isV2Ui, normalizedPageFindText, onSort, renderColumnTitle, showColumnComment, showColumnHeaderContextMenu, showColumnType, sortInfo]);
|
||||
|
||||
const mergedColumns = useMemo(() => columns.map((col): ColumnType<any> => {
|
||||
const dataIndex = String(col.dataIndex);
|
||||
@@ -5067,6 +5134,7 @@ const DataGrid: React.FC<DataGridProps> = ({
|
||||
cellProps.focusCell = openCellEditor;
|
||||
cellProps.columnType = displayColumnTypeMap[dataIndex];
|
||||
cellProps.dbType = dbType;
|
||||
cellProps.connectionConfig = currentConnConfig;
|
||||
cellProps.inputCellPadding = inputCellPadding;
|
||||
cellProps.modifiedColumns = modifiedColumns;
|
||||
cellProps.rowKeyStr = rowKeyStr;
|
||||
@@ -5097,7 +5165,7 @@ const DataGrid: React.FC<DataGridProps> = ({
|
||||
: undefined;
|
||||
const shouldUsePlainVirtualContent = isV2Ui && !modifiedStyle;
|
||||
if (enableVirtual && enableInlineEditableCell) {
|
||||
const pickerType = getTemporalPickerType(columnType, dbType);
|
||||
const pickerType = getTemporalPickerType(columnType, dbType, currentConnConfig);
|
||||
const isDateTimeField = !!pickerType && !(/^0{4}-0{2}-0{2}/.test(String(record?.[dataIndex] || '')));
|
||||
const virtualCellStyle = modifiedStyle ? { ...virtualCellWrapperStyle, ...modifiedStyle } : virtualCellWrapperStyle;
|
||||
const virtualEditable = !!col.editable && !rowDeletedForRender;
|
||||
@@ -5126,7 +5194,7 @@ const DataGrid: React.FC<DataGridProps> = ({
|
||||
style={{ width: '100%' }}
|
||||
showTime
|
||||
showNow={false}
|
||||
format={TEMPORAL_FORMATS[pickerType]}
|
||||
format={getTemporalPickerFormat(pickerType)}
|
||||
renderExtraFooter={() => (
|
||||
<a
|
||||
style={{ padding: '0 2px' }}
|
||||
@@ -5211,7 +5279,58 @@ const DataGrid: React.FC<DataGridProps> = ({
|
||||
return originalRenderContent;
|
||||
}
|
||||
};
|
||||
}), [columns, useInlineEditableBodyCell, enableInlineEditableCell, enableVirtual, handleCellSave, openCellEditor, handleVirtualCellActivate, handleSharedCellContextMenu, displayColumnTypeMap, dbType, inputCellPadding, virtualCellWrapperStyle, modifiedColumns, rowKeyStr, deletedRowKeys, darkMode, virtualEditingCell, form, saveVirtualInlineEditor, lockVirtualInlineTableScroll, closeVirtualInlineEditor, updateFocusedCell]);
|
||||
}), [closeVirtualInlineEditor, columns, currentConnConfig, darkMode, dbType, deletedRowKeys, displayColumnTypeMap, enableInlineEditableCell, enableVirtual, form, handleCellSave, handleSharedCellContextMenu, handleVirtualCellActivate, inputCellPadding, lockVirtualInlineTableScroll, modifiedColumns, openCellEditor, rowKeyStr, saveVirtualInlineEditor, updateFocusedCell, useInlineEditableBodyCell, virtualCellWrapperStyle, virtualEditingCell]);
|
||||
|
||||
const rowNumberColumn = useMemo<ColumnType<any>>(() => ({
|
||||
title: (
|
||||
<div
|
||||
className="gn-v2-column-title is-single-line"
|
||||
data-grid-row-number-title="true"
|
||||
data-grid-column-title-single-line="true"
|
||||
style={{
|
||||
display: 'flex',
|
||||
flexDirection: 'column',
|
||||
alignItems: 'center',
|
||||
justifyContent: 'center',
|
||||
minWidth: 0,
|
||||
width: '100%',
|
||||
maxWidth: '100%',
|
||||
minHeight: 'var(--gonavi-header-min-height, 40px)',
|
||||
lineHeight: 1.2,
|
||||
textAlign: 'center',
|
||||
}}
|
||||
>
|
||||
<span aria-label="行号">#</span>
|
||||
</div>
|
||||
),
|
||||
key: GONAVI_ROW_NUMBER_COLUMN_KEY,
|
||||
dataIndex: GONAVI_ROW_NUMBER_COLUMN_KEY,
|
||||
width: ROW_NUMBER_COLUMN_WIDTH,
|
||||
className: 'data-grid-row-number-cell',
|
||||
align: 'center',
|
||||
onHeaderCell: () => ({
|
||||
style: {
|
||||
textAlign: 'center' as const,
|
||||
paddingInline: 0,
|
||||
verticalAlign: 'middle' as const,
|
||||
},
|
||||
}),
|
||||
render: (_value: unknown, _record: Item, index: number) => {
|
||||
const currentPage = Math.max(1, Number(pagination?.current) || 1);
|
||||
const pageSize = Math.max(1, Number(pagination?.pageSize) || 0);
|
||||
const offset = pageSize > 0 ? (currentPage - 1) * pageSize : 0;
|
||||
return (
|
||||
<span className="data-grid-row-number" data-grid-row-number="true">
|
||||
{offset + index + 1}
|
||||
</span>
|
||||
);
|
||||
},
|
||||
}), [pagination?.current, pagination?.pageSize]);
|
||||
|
||||
const tableColumns = useMemo(
|
||||
() => (showRowNumberColumn ? [rowNumberColumn, ...mergedColumns] : mergedColumns),
|
||||
[mergedColumns, rowNumberColumn, showRowNumberColumn]
|
||||
);
|
||||
|
||||
const handleAddRow = () => {
|
||||
const newKey = `new-${Date.now()}`;
|
||||
@@ -5529,10 +5648,10 @@ const DataGrid: React.FC<DataGridProps> = ({
|
||||
|
||||
const columnType = (columnMetaMap[normalizedColumnName] || columnMetaMapByLowerName[normalizedColumnName.toLowerCase()])?.type;
|
||||
const text = mergedDisplayData
|
||||
.map((row) => normalizeClipboardTsvCell(formatClipboardCellText(row?.[normalizedColumnName], columnType)))
|
||||
.map((row) => normalizeClipboardTsvCell(formatClipboardCellText(row?.[normalizedColumnName], columnType, currentConnConfig)))
|
||||
.join('\n');
|
||||
copyToClipboard(text);
|
||||
}, [columnMetaMap, columnMetaMapByLowerName, copyToClipboard, displayOutputColumnNames, mergedDisplayData]);
|
||||
}, [columnMetaMap, columnMetaMapByLowerName, copyToClipboard, currentConnConfig, displayOutputColumnNames, mergedDisplayData]);
|
||||
|
||||
const handleV2ColumnHeaderContextMenuAction = useCallback((action: V2ColumnHeaderContextMenuActionKey) => {
|
||||
const columnName = resolveContextMenuFieldName(cellContextMenu.dataIndex, cellContextMenu.title);
|
||||
@@ -5872,13 +5991,14 @@ const DataGrid: React.FC<DataGridProps> = ({
|
||||
rows,
|
||||
columns,
|
||||
(columnName) => (columnMetaMap[columnName] || columnMetaMapByLowerName[columnName.toLowerCase()])?.type,
|
||||
currentConnConfig,
|
||||
);
|
||||
if (!text) {
|
||||
void message.info('当前行没有可复制内容');
|
||||
return;
|
||||
}
|
||||
copyToClipboard(text);
|
||||
}, [columnMetaMap, columnMetaMapByLowerName, copyToClipboard, displayOutputColumnNames, getContextMenuTargetRows]);
|
||||
}, [columnMetaMap, columnMetaMapByLowerName, copyToClipboard, currentConnConfig, displayOutputColumnNames, getContextMenuTargetRows]);
|
||||
|
||||
const buildConnConfig = useCallback(() => {
|
||||
if (!connectionId) return null;
|
||||
@@ -6435,7 +6555,7 @@ const DataGrid: React.FC<DataGridProps> = ({
|
||||
|
||||
const rowPropsFactory = useCallback((record: any) => ({ record } as any), []);
|
||||
|
||||
const totalWidth = columns.reduce((sum: number, col: any) => sum + (Number(col.width) || densityParams.defaultColumnWidth), 0) + selectionColumnWidth;
|
||||
const totalWidth = tableColumns.reduce((sum: number, col: any) => sum + (Number(col.width) || densityParams.defaultColumnWidth), 0) + selectionColumnWidth;
|
||||
const useContextMenuRow = false;
|
||||
const tableScrollX = useMemo(() => {
|
||||
// rc-table 在 scroll.x 小于容器宽度时会把实际列宽按视口补齐。
|
||||
@@ -7330,6 +7450,14 @@ const DataGrid: React.FC<DataGridProps> = ({
|
||||
});
|
||||
}, [pagination, supportsApproximateTotalPages]);
|
||||
|
||||
const paginationHasKnownTotalPages = useMemo(() => {
|
||||
if (!pagination) return false;
|
||||
if (pagination.totalKnown !== false) return true;
|
||||
if (!supportsApproximateTotalPages || !pagination.totalApprox) return false;
|
||||
const approximateTotal = Number(pagination.approximateTotal);
|
||||
return Number.isFinite(approximateTotal) && approximateTotal > 0;
|
||||
}, [pagination, supportsApproximateTotalPages]);
|
||||
|
||||
const paginationTotalPages = useMemo(() => {
|
||||
if (!pagination) return 1;
|
||||
if (!Number.isFinite(paginationControlTotal) || paginationControlTotal <= 0) {
|
||||
@@ -7361,6 +7489,14 @@ const DataGrid: React.FC<DataGridProps> = ({
|
||||
supportsApproximateTotalPages,
|
||||
]);
|
||||
|
||||
const paginationPageText = useMemo(() => {
|
||||
if (!pagination) return '';
|
||||
return resolvePaginationPageText({
|
||||
pagination,
|
||||
supportsApproximateTotalPages,
|
||||
});
|
||||
}, [pagination, supportsApproximateTotalPages]);
|
||||
|
||||
const handlePageSizeChange = useCallback((value: string) => {
|
||||
if (!pagination || !onPageChange) return;
|
||||
const nextSize = Number(value);
|
||||
@@ -7412,7 +7548,7 @@ const DataGrid: React.FC<DataGridProps> = ({
|
||||
ref={tableRef}
|
||||
components={tableComponents}
|
||||
dataSource={tableRenderData}
|
||||
columns={mergedColumns}
|
||||
columns={tableColumns}
|
||||
{...(enableVirtual && typeof virtualListItemHeight === 'number'
|
||||
? { listItemHeight: virtualListItemHeight }
|
||||
: {})}
|
||||
@@ -7502,7 +7638,9 @@ const DataGrid: React.FC<DataGridProps> = ({
|
||||
paginationSummaryText={paginationSummaryText}
|
||||
paginationControlTotal={paginationControlTotal}
|
||||
paginationTotalPages={paginationTotalPages}
|
||||
paginationPageText={paginationPageText}
|
||||
paginationPageSizeOptions={paginationPageSizeOptions}
|
||||
showKnownPageCount={paginationHasKnownTotalPages}
|
||||
onPageChange={onPageChange}
|
||||
onPageSizeChange={handlePageSizeChange}
|
||||
onV2PageStep={handleV2PageStep}
|
||||
@@ -7516,7 +7654,7 @@ const DataGrid: React.FC<DataGridProps> = ({
|
||||
const isJson = looksLikeJsonText(sample);
|
||||
const useTextArea = isJson || sample.includes('\n') || sample.length >= 160;
|
||||
const colMeta = columnMetaMap[col] || columnMetaMapByLowerName[col.toLowerCase()];
|
||||
const pickerType = getTemporalPickerType(colMeta?.type, dbType);
|
||||
const pickerType = getTemporalPickerType(colMeta?.type, dbType, currentConnConfig);
|
||||
const isTemporalValue = !!pickerType && !(/^0{4}-0{2}-0{2}/.test(String(sample || '')));
|
||||
const isWritable = isWritableResultColumn(col, effectiveEditLocator);
|
||||
return {
|
||||
@@ -7530,7 +7668,7 @@ const DataGrid: React.FC<DataGridProps> = ({
|
||||
isWritable,
|
||||
};
|
||||
})
|
||||
), [displayColumnNames, columnMetaMap, columnMetaMapByLowerName, dbType, effectiveEditLocator, rowEditorOpen, rowEditorRowKey]);
|
||||
), [columnMetaMap, columnMetaMapByLowerName, currentConnConfig, dbType, displayColumnNames, effectiveEditLocator, rowEditorOpen, rowEditorRowKey]);
|
||||
|
||||
const handleRefreshGrid = useCallback(() => {
|
||||
clearAutoCommitTimer();
|
||||
|
||||
@@ -20,7 +20,9 @@ export interface DataGridPaginationBarProps {
|
||||
paginationSummaryText: string;
|
||||
paginationControlTotal: number;
|
||||
paginationTotalPages: number;
|
||||
paginationPageText: string;
|
||||
paginationPageSizeOptions: string[];
|
||||
showKnownPageCount: boolean;
|
||||
onPageChange?: (page: number, size: number) => void;
|
||||
onPageSizeChange: (value: string) => void;
|
||||
onV2PageStep: (direction: 'previous' | 'next') => void;
|
||||
@@ -33,7 +35,9 @@ const DataGridPaginationBar: React.FC<DataGridPaginationBarProps> = ({
|
||||
paginationSummaryText,
|
||||
paginationControlTotal,
|
||||
paginationTotalPages,
|
||||
paginationPageText,
|
||||
paginationPageSizeOptions,
|
||||
showKnownPageCount,
|
||||
onPageChange,
|
||||
onPageSizeChange,
|
||||
onV2PageStep,
|
||||
@@ -58,7 +62,7 @@ const DataGridPaginationBar: React.FC<DataGridPaginationBarProps> = ({
|
||||
if (normalizedJumpPage === pagination.current) return;
|
||||
onPageChange(normalizedJumpPage, pagination.pageSize);
|
||||
};
|
||||
const jumpPageControl = (
|
||||
const jumpPageControl = showKnownPageCount ? (
|
||||
<div className="data-grid-pagination-jump" data-grid-pagination-jump="true">
|
||||
<span className="data-grid-pagination-jump-label">跳页</span>
|
||||
<InputNumber
|
||||
@@ -83,7 +87,7 @@ const DataGridPaginationBar: React.FC<DataGridPaginationBarProps> = ({
|
||||
跳
|
||||
</Button>
|
||||
</div>
|
||||
);
|
||||
) : null;
|
||||
|
||||
return (
|
||||
<div
|
||||
@@ -103,9 +107,15 @@ const DataGridPaginationBar: React.FC<DataGridPaginationBarProps> = ({
|
||||
onClick={() => onV2PageStep('previous')}
|
||||
/>
|
||||
<div className="data-grid-pagination-page-chip" data-grid-v2-page-chip="true">
|
||||
<strong>{pagination.current}</strong>
|
||||
<span>/</span>
|
||||
<span>{paginationTotalPages}</span>
|
||||
{showKnownPageCount ? (
|
||||
<>
|
||||
<strong>{pagination.current}</strong>
|
||||
<span>/</span>
|
||||
<span>{paginationTotalPages}</span>
|
||||
</>
|
||||
) : (
|
||||
<span>{paginationPageText}</span>
|
||||
)}
|
||||
</div>
|
||||
<Button
|
||||
data-grid-v2-pagination-next="true"
|
||||
|
||||
@@ -331,6 +331,38 @@ describe('DataViewer safe editing locator', () => {
|
||||
renderer.unmount();
|
||||
});
|
||||
|
||||
it('uses hidden OceanBase Oracle ROWID when no primary or unique key is available', async () => {
|
||||
storeState.connections[0].config.type = 'oceanbase';
|
||||
(storeState.connections[0].config as any).oceanBaseProtocol = 'oracle';
|
||||
storeState.connections[0].config.user = 'dev';
|
||||
storeState.connections[0].config.database = 'ORCLPDB1';
|
||||
backendApp.DBGetColumns.mockResolvedValue({
|
||||
success: true,
|
||||
data: [{ name: 'ID', key: '' }, { name: 'NAME', key: '' }],
|
||||
});
|
||||
backendApp.DBQuery.mockResolvedValue({
|
||||
success: true,
|
||||
fields: ['ID', 'NAME', ORACLE_ROWID_LOCATOR_COLUMN],
|
||||
data: [{ ID: 7, NAME: 'old-name', [ORACLE_ROWID_LOCATOR_COLUMN]: 'AAAA' }],
|
||||
});
|
||||
|
||||
const renderer = await renderAndReload(createTab({ id: 'tab-ob-oracle-rowid', dbName: 'ORCLPDB1', tableName: 'EDC_LOG', title: 'EDC_LOG' }));
|
||||
|
||||
expect(dataGridState.latestProps?.pkColumns).toEqual([]);
|
||||
expect(dataGridState.latestProps?.editLocator).toMatchObject({
|
||||
strategy: 'oracle-rowid',
|
||||
columns: ['ROWID'],
|
||||
valueColumns: [ORACLE_ROWID_LOCATOR_COLUMN],
|
||||
hiddenColumns: [ORACLE_ROWID_LOCATOR_COLUMN],
|
||||
readOnly: false,
|
||||
});
|
||||
expect(dataGridState.latestProps?.readOnly).toBe(false);
|
||||
expect(dataGridState.latestProps?.showRowNumberColumn).toBe(true);
|
||||
expect(messageApi.warning).not.toHaveBeenCalled();
|
||||
expect(backendApp.DBQuery.mock.calls.some((call: any[]) => String(call[2]).includes(`ROWID AS "${ORACLE_ROWID_LOCATOR_COLUMN}"`))).toBe(true);
|
||||
renderer.unmount();
|
||||
});
|
||||
|
||||
it('does not add fallback ORDER BY for DuckDB table preview when a primary key is available', async () => {
|
||||
storeState.connections[0].config.type = 'duckdb';
|
||||
storeState.connections[0].config.database = 'main';
|
||||
|
||||
@@ -7,7 +7,7 @@ import DataGrid, { GONAVI_ROW_KEY } from './DataGrid';
|
||||
import { buildOrderBySQL, buildPaginatedSelectSQL, buildWhereSQL, hasExplicitSort, quoteIdentPart, quoteQualifiedIdent, reverseOrderBySQL, withSortBufferTuningSQL, type FilterCondition } from '../utils/sql';
|
||||
import { buildMongoCountCommand, buildMongoFilter, buildMongoFindCommand, buildMongoSort } from '../utils/mongodb';
|
||||
import { buildOracleApproximateTotalSql, parseApproximateTableCountRow, resolveApproximateTableCountStrategy } from '../utils/approximateTableCount';
|
||||
import { getDataSourceCapabilities, resolveDataSourceType } from '../utils/dataSourceCapabilities';
|
||||
import { getDataSourceCapabilities, resolveDataSourceType, shouldShowOceanBaseRowNumberColumn } from '../utils/dataSourceCapabilities';
|
||||
import { resolveDataViewerAutoFetchAction } from '../utils/dataViewerAutoFetch';
|
||||
import { buildRpcConnectionConfig } from '../utils/connectionRpcConfig';
|
||||
import {
|
||||
@@ -324,6 +324,7 @@ const DataViewer: React.FC<{ tab: TabData; isActive?: boolean }> = React.memo(({
|
||||
const [quickWhereCondition, setQuickWhereCondition] = useState<string>(initialViewerSnapshot.quickWhereCondition);
|
||||
const duckdbSafeSelectCacheRef = useRef<Record<string, string>>({});
|
||||
const currentConnConfig = connections.find(c => c.id === tab.connectionId)?.config;
|
||||
const showRowNumberColumn = shouldShowOceanBaseRowNumberColumn(currentConnConfig);
|
||||
const currentConnCaps = getDataSourceCapabilities(currentConnConfig);
|
||||
const forceReadOnly = currentConnCaps.forceReadOnlyQueryResult;
|
||||
const preferManualTotalCount = currentConnCaps.preferManualTotalCount;
|
||||
@@ -1110,6 +1111,7 @@ const DataViewer: React.FC<{ tab: TabData; isActive?: boolean }> = React.memo(({
|
||||
connectionId={tab.connectionId}
|
||||
pkColumns={pkColumns}
|
||||
editLocator={editLocator}
|
||||
showRowNumberColumn={showRowNumberColumn}
|
||||
onReload={handleReload}
|
||||
onSort={handleSort}
|
||||
onPageChange={handlePageChange}
|
||||
|
||||
@@ -1,7 +1,13 @@
|
||||
import dayjs from 'dayjs';
|
||||
import { describe, expect, it } from 'vitest';
|
||||
|
||||
import { getTemporalPickerType, resolveTemporalEditorSaveValue } from './dataGridTemporal';
|
||||
import {
|
||||
formatFromDayjs,
|
||||
getTemporalPickerFormat,
|
||||
getTemporalPickerType,
|
||||
parseToDayjs,
|
||||
resolveTemporalEditorSaveValue,
|
||||
} from './dataGridTemporal';
|
||||
|
||||
describe('dataGridTemporal helpers', () => {
|
||||
it('prefers the picker selected date when form store has not caught up yet', () => {
|
||||
@@ -23,4 +29,39 @@ describe('dataGridTemporal helpers', () => {
|
||||
expect(resolveTemporalEditorSaveValue(undefined, dayjs('2026-06-11 19:42:13'), pickerType))
|
||||
.toBe('2026-06-11');
|
||||
});
|
||||
|
||||
it('keeps OceanBase Oracle DATE columns as date-only editors', () => {
|
||||
const pickerType = getTemporalPickerType('DATE', 'oracle', {
|
||||
type: 'oceanbase',
|
||||
oceanBaseProtocol: 'oracle',
|
||||
} as any);
|
||||
|
||||
expect(pickerType).toBe('date');
|
||||
expect(resolveTemporalEditorSaveValue(undefined, dayjs('2026-06-11 19:42:13'), pickerType))
|
||||
.toBe('2026-06-11');
|
||||
});
|
||||
|
||||
it('preserves datetime fractional seconds when round-tripping through the editor helper', () => {
|
||||
const parsed = parseToDayjs('2026-06-16T16:46:23.158844Z', 'datetime');
|
||||
|
||||
expect(parsed?.isValid()).toBe(true);
|
||||
expect(formatFromDayjs(parsed, 'datetime')).toBe('2026-06-16 16:46:23.158844');
|
||||
expect(resolveTemporalEditorSaveValue(undefined, parsed, 'datetime')).toBe('2026-06-16 16:46:23.158844');
|
||||
});
|
||||
|
||||
it('keeps RFC3339 wall clock text instead of applying local timezone conversion in datetime editors', () => {
|
||||
const parsed = parseToDayjs('2026-06-17T05:00:00Z', 'datetime');
|
||||
|
||||
expect(parsed?.isValid()).toBe(true);
|
||||
expect(formatFromDayjs(parsed, 'datetime')).toBe('2026-06-17 05:00:00');
|
||||
});
|
||||
|
||||
it('uses a custom datetime picker format that can display preserved microseconds', () => {
|
||||
const parsed = parseToDayjs('2026-06-16T16:46:23.158844Z', 'datetime');
|
||||
const format = getTemporalPickerFormat('datetime');
|
||||
|
||||
expect(Array.isArray(format)).toBe(true);
|
||||
expect(typeof format[0]).toBe('function');
|
||||
expect((format[0] as (value: dayjs.Dayjs) => string)(parsed!)).toBe('2026-06-16 16:46:23.158844');
|
||||
});
|
||||
});
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
import dayjs from 'dayjs';
|
||||
import type { ConnectionConfig } from '../types';
|
||||
import { normalizeOceanBaseProtocol } from '../utils/oceanBaseProtocol';
|
||||
import { isOracleLikeDialect } from '../utils/sqlDialect';
|
||||
|
||||
export type TemporalPickerType = 'datetime' | 'date' | 'time' | 'year' | null;
|
||||
export type TemporalConnectionLike = Pick<ConnectionConfig, 'type' | 'driver' | 'oceanBaseProtocol'> | null | undefined;
|
||||
|
||||
export const TEMPORAL_FORMATS: Record<string, string> = {
|
||||
datetime: 'YYYY-MM-DD HH:mm:ss',
|
||||
@@ -10,16 +13,100 @@ export const TEMPORAL_FORMATS: Record<string, string> = {
|
||||
year: 'YYYY',
|
||||
};
|
||||
|
||||
export const isTemporalColumnType = (columnType?: string, dbType?: string): boolean => {
|
||||
return !!getTemporalPickerType(columnType, dbType);
|
||||
const TEMPORAL_DATE_TIME_RE =
|
||||
/^(\d{4}-\d{2}-\d{2})[T ](\d{2}:\d{2}:\d{2})(?:\.(\d{1,9}))?(?:\s*(?:Z|[+-]\d{2}:?\d{2})(?:\s+[A-Za-z_\/+-]+)?)?$/;
|
||||
const temporalFractionMetaKey = Symbol('temporalFractionMeta');
|
||||
|
||||
type DayjsWithTemporalFractionMeta = dayjs.Dayjs & {
|
||||
[temporalFractionMetaKey]?: string;
|
||||
};
|
||||
|
||||
export const getTemporalPickerType = (columnType?: string, dbType?: string): TemporalPickerType => {
|
||||
const parseTemporalDateTimeParts = (value: string): { datePart: string; timePart: string; fractionDigits: string } | null => {
|
||||
const match = String(value || '').trim().match(TEMPORAL_DATE_TIME_RE);
|
||||
if (!match) return null;
|
||||
return {
|
||||
datePart: match[1],
|
||||
timePart: match[2],
|
||||
fractionDigits: match[3] || '',
|
||||
};
|
||||
};
|
||||
|
||||
const attachTemporalFractionMeta = (value: dayjs.Dayjs, fractionDigits: string): dayjs.Dayjs => {
|
||||
if (!value?.isValid?.()) return value;
|
||||
const normalizedDigits = String(fractionDigits || '');
|
||||
if (!normalizedDigits) return value;
|
||||
(value as DayjsWithTemporalFractionMeta)[temporalFractionMetaKey] = normalizedDigits;
|
||||
return value;
|
||||
};
|
||||
|
||||
const getTemporalFractionMeta = (value: dayjs.Dayjs | null | undefined): string => {
|
||||
if (!value || !value.isValid()) return '';
|
||||
return String((value as DayjsWithTemporalFractionMeta)[temporalFractionMetaKey] || '');
|
||||
};
|
||||
|
||||
const buildDayjsParseTextForDateTime = (datePart: string, timePart: string, fractionDigits: string): string => {
|
||||
if (!fractionDigits) return `${datePart} ${timePart}`;
|
||||
const milliseconds = fractionDigits.slice(0, 3).padEnd(3, '0');
|
||||
return `${datePart} ${timePart}.${milliseconds}`;
|
||||
};
|
||||
|
||||
const formatDateTimeWithFractionMeta = (value: dayjs.Dayjs): string => {
|
||||
const base = value.format(TEMPORAL_FORMATS.datetime);
|
||||
const fractionDigits = getTemporalFractionMeta(value);
|
||||
if (!fractionDigits) return base;
|
||||
|
||||
const milliseconds = String(value.millisecond()).padStart(3, '0');
|
||||
if (fractionDigits.length <= 3) {
|
||||
return `${base}.${milliseconds.slice(0, fractionDigits.length)}`;
|
||||
}
|
||||
return `${base}.${milliseconds}${fractionDigits.slice(3)}`;
|
||||
};
|
||||
|
||||
export const getTemporalPickerFormat = (
|
||||
pickerType: TemporalPickerType,
|
||||
): string | ((value: dayjs.Dayjs) => string) | Array<string | ((value: dayjs.Dayjs) => string)> => {
|
||||
if (pickerType !== 'datetime') {
|
||||
return TEMPORAL_FORMATS[pickerType || 'datetime'];
|
||||
}
|
||||
return [
|
||||
(value: dayjs.Dayjs) => formatDateTimeWithFractionMeta(value),
|
||||
'YYYY-MM-DD HH:mm:ss.SSSSSS',
|
||||
'YYYY-MM-DD HH:mm:ss.SSS',
|
||||
'YYYY-MM-DD HH:mm:ss',
|
||||
];
|
||||
};
|
||||
|
||||
export const isTemporalColumnType = (
|
||||
columnType?: string,
|
||||
dbType?: string,
|
||||
connectionConfig?: TemporalConnectionLike,
|
||||
): boolean => {
|
||||
return !!getTemporalPickerType(columnType, dbType, connectionConfig);
|
||||
};
|
||||
|
||||
const isOceanBaseOracleDateOnlyConnection = (connectionConfig?: TemporalConnectionLike): boolean => {
|
||||
if (!connectionConfig) return false;
|
||||
const type = String(connectionConfig.type || '').trim().toLowerCase();
|
||||
const driver = String(connectionConfig.driver || '').trim().toLowerCase();
|
||||
return (type === 'oceanbase' || driver === 'oceanbase')
|
||||
&& normalizeOceanBaseProtocol(connectionConfig.oceanBaseProtocol) === 'oracle';
|
||||
};
|
||||
|
||||
export const getTemporalPickerType = (
|
||||
columnType?: string,
|
||||
dbType?: string,
|
||||
connectionConfig?: TemporalConnectionLike,
|
||||
): TemporalPickerType => {
|
||||
const raw = String(columnType || '').trim().toLowerCase();
|
||||
if (!raw) return null;
|
||||
if (raw.includes('datetime') || raw.includes('timestamp')) return 'datetime';
|
||||
const base = raw.split(/[ (]/)[0];
|
||||
if (base === 'date') return isOracleLikeDialect(String(dbType || '')) ? 'datetime' : 'date';
|
||||
if (base === 'date') {
|
||||
if (isOracleLikeDialect(String(dbType || '')) && !isOceanBaseOracleDateOnlyConnection(connectionConfig)) {
|
||||
return 'datetime';
|
||||
}
|
||||
return 'date';
|
||||
}
|
||||
if (base === 'time') return 'time';
|
||||
if (base === 'year') return 'year';
|
||||
return null;
|
||||
@@ -29,13 +116,32 @@ export const parseToDayjs = (val: any, pickerType: TemporalPickerType): dayjs.Da
|
||||
if (val === null || val === undefined || val === '') return null;
|
||||
const str = String(val).trim();
|
||||
if (!str || /^0{4}-0{2}-0{2}/.test(str)) return null;
|
||||
if (pickerType === 'datetime') {
|
||||
const parts = parseTemporalDateTimeParts(str);
|
||||
if (parts) {
|
||||
const parsed = dayjs(buildDayjsParseTextForDateTime(parts.datePart, parts.timePart, parts.fractionDigits));
|
||||
if (parsed.isValid()) {
|
||||
return attachTemporalFractionMeta(parsed, parts.fractionDigits);
|
||||
}
|
||||
}
|
||||
}
|
||||
const fmt = TEMPORAL_FORMATS[pickerType || 'datetime'];
|
||||
const d = dayjs(str, fmt);
|
||||
return d.isValid() ? d : dayjs(str).isValid() ? dayjs(str) : null;
|
||||
if (d.isValid()) {
|
||||
const parts = pickerType === 'datetime' ? parseTemporalDateTimeParts(str) : null;
|
||||
return parts ? attachTemporalFractionMeta(d, parts.fractionDigits) : d;
|
||||
}
|
||||
const fallback = dayjs(str);
|
||||
if (!fallback.isValid()) return null;
|
||||
const parts = pickerType === 'datetime' ? parseTemporalDateTimeParts(str) : null;
|
||||
return parts ? attachTemporalFractionMeta(fallback, parts.fractionDigits) : fallback;
|
||||
};
|
||||
|
||||
export const formatFromDayjs = (val: dayjs.Dayjs | null, pickerType: TemporalPickerType): string => {
|
||||
if (!val || !val.isValid()) return '';
|
||||
if (pickerType === 'datetime') {
|
||||
return formatDateTimeWithFractionMeta(val);
|
||||
}
|
||||
const fmt = TEMPORAL_FORMATS[pickerType || 'datetime'];
|
||||
return val.format(fmt);
|
||||
};
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import { describe, expect, it } from 'vitest';
|
||||
|
||||
import { getDataSourceCapabilities } from './dataSourceCapabilities';
|
||||
import { getDataSourceCapabilities, shouldShowOceanBaseRowNumberColumn } from './dataSourceCapabilities';
|
||||
|
||||
describe('dataSourceCapabilities', () => {
|
||||
it('treats Oracle table preview totals as manual exact count plus approximate metadata count', () => {
|
||||
@@ -258,6 +258,14 @@ describe('dataSourceCapabilities', () => {
|
||||
});
|
||||
});
|
||||
|
||||
it('shows row numbers for OceanBase datasources regardless of protocol normalization', () => {
|
||||
expect(shouldShowOceanBaseRowNumberColumn({ type: 'oceanbase' })).toBe(true);
|
||||
expect(shouldShowOceanBaseRowNumberColumn({ type: 'oceanbase', oceanBaseProtocol: 'oracle' })).toBe(true);
|
||||
expect(shouldShowOceanBaseRowNumberColumn({ type: 'custom', driver: 'oceanbase', oceanBaseProtocol: 'oracle' })).toBe(true);
|
||||
expect(shouldShowOceanBaseRowNumberColumn({ type: 'oracle' })).toBe(false);
|
||||
expect(shouldShowOceanBaseRowNumberColumn({ type: 'mysql' })).toBe(false);
|
||||
});
|
||||
|
||||
it('treats custom OceanBase Oracle driver as Oracle capabilities', () => {
|
||||
expect(getDataSourceCapabilities({
|
||||
type: 'custom',
|
||||
|
||||
@@ -82,6 +82,13 @@ export const resolveDataSourceType = (config: ConnectionLike): string => {
|
||||
return type;
|
||||
};
|
||||
|
||||
export const shouldShowOceanBaseRowNumberColumn = (config: ConnectionLike): boolean => {
|
||||
if (!config) return false;
|
||||
const type = normalizeDataSourceToken(String(config.type || ''));
|
||||
const driver = normalizeDataSourceToken(String(config.driver || ''));
|
||||
return type === 'oceanbase' || driver === 'oceanbase';
|
||||
};
|
||||
|
||||
const SQL_QUERY_EXPORT_TYPES = new Set([
|
||||
'mysql',
|
||||
'goldendb',
|
||||
|
||||
@@ -1511,9 +1511,23 @@ func resolveCreateStatementWithFallback(dbInst db.Database, config connection.Co
|
||||
|
||||
sqlStr, sourceErr := dbInst.GetCreateStatement(metadataSchemaName, metadataTableName)
|
||||
if sourceErr == nil && !shouldFallbackCreateStatement(dbType, sqlStr) {
|
||||
if strings.TrimSpace(sqlStr) != "" {
|
||||
return sqlStr, nil
|
||||
}
|
||||
if isOceanBaseOracleProtocol(config) {
|
||||
if showDDL, ok := tryGetOceanBaseOracleShowCreateStatement(dbInst, metadataSchemaName, metadataTableName); ok {
|
||||
return showDDL, nil
|
||||
}
|
||||
}
|
||||
return sqlStr, nil
|
||||
}
|
||||
|
||||
if isOceanBaseOracleProtocol(config) {
|
||||
if showDDL, ok := tryGetOceanBaseOracleShowCreateStatement(dbInst, metadataSchemaName, metadataTableName); ok {
|
||||
return showDDL, nil
|
||||
}
|
||||
}
|
||||
|
||||
if supportsViewCreateStatementLookup(dbType) {
|
||||
if viewDDL, ok := tryGetViewCreateStatement(dbInst, config, dbName, ddlSchemaName, ddlTableName); ok {
|
||||
return viewDDL, nil
|
||||
@@ -1545,6 +1559,42 @@ func resolveCreateStatementWithFallback(dbInst db.Database, config connection.Co
|
||||
return fallbackDDL, nil
|
||||
}
|
||||
|
||||
func tryGetOceanBaseOracleShowCreateStatement(dbInst db.Database, schemaName string, tableName string) (string, bool) {
|
||||
query := "SHOW CREATE TABLE " + quoteOracleMetadataTableRef(schemaName, tableName)
|
||||
data, _, err := dbInst.Query(query)
|
||||
if err != nil {
|
||||
return "", false
|
||||
}
|
||||
for _, row := range data {
|
||||
for _, key := range []string{"Create Table", "CREATE TABLE", "CREATE_TABLE", "DDL", "ddl"} {
|
||||
if val, ok := row[key]; ok {
|
||||
text := strings.TrimSpace(fmt.Sprintf("%v", val))
|
||||
if text != "" && !strings.EqualFold(text, "<nil>") {
|
||||
return text, true
|
||||
}
|
||||
}
|
||||
}
|
||||
for _, val := range row {
|
||||
text := strings.TrimSpace(fmt.Sprintf("%v", val))
|
||||
lower := strings.ToLower(text)
|
||||
if strings.HasPrefix(lower, "create table") ||
|
||||
strings.HasPrefix(lower, "create view") ||
|
||||
strings.HasPrefix(lower, "create or replace view") {
|
||||
return text, true
|
||||
}
|
||||
}
|
||||
if len(row) == 1 {
|
||||
for _, val := range row {
|
||||
text := strings.TrimSpace(fmt.Sprintf("%v", val))
|
||||
if text != "" && !strings.EqualFold(text, "<nil>") {
|
||||
return text, true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return "", false
|
||||
}
|
||||
|
||||
func supportsCreateStatementFallback(dbType string) bool {
|
||||
switch dbType {
|
||||
case "postgres", "kingbase", "highgo", "vastbase", "opengauss", "gaussdb", "sqlserver":
|
||||
@@ -1720,10 +1770,318 @@ func (a *App) DBGetColumns(config connection.ConnectionConfig, dbName string, ta
|
||||
logger.Error(err, "DBGetColumns 获取列定义失败:%s 表=%s.%s schema=%s pureTable=%s", formatConnSummary(runConfig), dbName, tableName, schemaName, pureTableName)
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
if len(columns) == 0 && resolveDDLDBType(config) == "oracle" {
|
||||
if inferred, inferErr := inferOracleColumnsFromDictionary(dbInst, schemaName, pureTableName); inferErr == nil && len(inferred) > 0 {
|
||||
columns = inferred
|
||||
}
|
||||
if len(columns) == 0 {
|
||||
if inferred, inferErr := inferOracleColumnsFromEmptySelect(dbInst, schemaName, pureTableName); inferErr == nil && len(inferred) > 0 {
|
||||
columns = inferred
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return connection.QueryResult{Success: true, Data: ensureNonNilSlice(columns)}
|
||||
}
|
||||
|
||||
func inferOracleColumnsFromDictionary(dbInst db.Database, schemaName string, tableName string) ([]connection.ColumnDefinition, error) {
|
||||
var lastErr error
|
||||
for _, candidate := range appOracleMetadataNamePairs(schemaName, tableName) {
|
||||
data, _, err := dbInst.Query(buildAppOracleColumnsQuery(candidate.schema, candidate.table))
|
||||
if err != nil {
|
||||
lastErr = err
|
||||
continue
|
||||
}
|
||||
columns := parseAppOracleColumns(data)
|
||||
if len(columns) > 0 {
|
||||
return columns, nil
|
||||
}
|
||||
}
|
||||
if lastErr != nil {
|
||||
return nil, lastErr
|
||||
}
|
||||
return nil, fmt.Errorf("未获取到字段定义")
|
||||
}
|
||||
|
||||
type appOracleMetadataNamePair struct {
|
||||
schema string
|
||||
table string
|
||||
}
|
||||
|
||||
func appOracleMetadataNamePairs(schemaName string, tableName string) []appOracleMetadataNamePair {
|
||||
rawSchema := strings.TrimSpace(schemaName)
|
||||
rawTable := strings.TrimSpace(tableName)
|
||||
if rawTable == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
upperSchema := strings.ToUpper(rawSchema)
|
||||
upperTable := strings.ToUpper(rawTable)
|
||||
pairs := make([]appOracleMetadataNamePair, 0, 4)
|
||||
seen := map[string]struct{}{}
|
||||
add := func(schema string, table string) {
|
||||
key := schema + "\x00" + table
|
||||
if _, exists := seen[key]; exists {
|
||||
return
|
||||
}
|
||||
seen[key] = struct{}{}
|
||||
pairs = append(pairs, appOracleMetadataNamePair{schema: schema, table: table})
|
||||
}
|
||||
|
||||
add(rawSchema, rawTable)
|
||||
add(upperSchema, upperTable)
|
||||
add(rawSchema, upperTable)
|
||||
add(upperSchema, rawTable)
|
||||
return pairs
|
||||
}
|
||||
|
||||
func buildAppOracleColumnsQuery(schema string, table string) string {
|
||||
metadataTableName := escapeAppOracleMetadataLiteral(table)
|
||||
metadataSchemaName := escapeAppOracleMetadataLiteral(schema)
|
||||
if strings.TrimSpace(schema) == "" {
|
||||
return fmt.Sprintf(`SELECT c.column_name AS "COLUMN_NAME", c.data_type AS "DATA_TYPE", c.data_length AS "DATA_LENGTH", c.char_length AS "CHAR_LENGTH", c.data_precision AS "DATA_PRECISION", c.data_scale AS "DATA_SCALE", c.nullable AS "NULLABLE", c.data_default AS "DATA_DEFAULT",
|
||||
CASE WHEN pk.column_name IS NOT NULL THEN 'PRI' ELSE '' END AS "COLUMN_KEY",
|
||||
cc.comments AS "COMMENT"
|
||||
FROM user_tab_columns c
|
||||
LEFT JOIN user_col_comments cc
|
||||
ON cc.table_name = c.table_name AND cc.column_name = c.column_name
|
||||
LEFT JOIN (
|
||||
SELECT cols.table_name, cols.column_name
|
||||
FROM user_constraints cons
|
||||
JOIN user_cons_columns cols
|
||||
ON cons.constraint_name = cols.constraint_name
|
||||
WHERE cons.constraint_type = 'P'
|
||||
) pk ON c.table_name = pk.table_name AND c.column_name = pk.column_name
|
||||
WHERE c.table_name = '%s'
|
||||
ORDER BY c.column_id`, metadataTableName)
|
||||
}
|
||||
|
||||
return fmt.Sprintf(`SELECT c.column_name AS "COLUMN_NAME", c.data_type AS "DATA_TYPE", c.data_length AS "DATA_LENGTH", c.char_length AS "CHAR_LENGTH", c.data_precision AS "DATA_PRECISION", c.data_scale AS "DATA_SCALE", c.nullable AS "NULLABLE", c.data_default AS "DATA_DEFAULT",
|
||||
CASE WHEN pk.column_name IS NOT NULL THEN 'PRI' ELSE '' END AS "COLUMN_KEY",
|
||||
cc.comments AS "COMMENT"
|
||||
FROM all_tab_columns c
|
||||
LEFT JOIN all_col_comments cc
|
||||
ON cc.owner = c.owner AND cc.table_name = c.table_name AND cc.column_name = c.column_name
|
||||
LEFT JOIN (
|
||||
SELECT cols.owner, cols.table_name, cols.column_name
|
||||
FROM all_constraints cons
|
||||
JOIN all_cons_columns cols
|
||||
ON cons.owner = cols.owner AND cons.constraint_name = cols.constraint_name
|
||||
WHERE cons.constraint_type = 'P'
|
||||
) pk ON c.owner = pk.owner AND c.table_name = pk.table_name AND c.column_name = pk.column_name
|
||||
WHERE c.owner = '%s' AND c.table_name = '%s'
|
||||
ORDER BY c.column_id`, metadataSchemaName, metadataTableName)
|
||||
}
|
||||
|
||||
func parseAppOracleColumns(data []map[string]interface{}) []connection.ColumnDefinition {
|
||||
columns := make([]connection.ColumnDefinition, 0, len(data))
|
||||
for _, row := range data {
|
||||
name := appOracleRowString(row, "COLUMN_NAME", "column_name")
|
||||
if strings.TrimSpace(name) == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
defaultValue := appOracleRowString(row, "DATA_DEFAULT", "COLUMN_DEFAULT", "data_default", "column_default")
|
||||
col := connection.ColumnDefinition{
|
||||
Name: name,
|
||||
Type: formatAppOracleColumnType(row),
|
||||
Nullable: normalizeAppOracleNullable(appOracleRowString(row, "NULLABLE", "nullable")),
|
||||
Key: appOracleRowString(row, "COLUMN_KEY", "column_key", "KEY", "key"),
|
||||
Extra: appOracleAutoIncrementExtra(defaultValue),
|
||||
Comment: appOracleRowString(row, "COMMENT", "COMMENTS", "comment", "comments"),
|
||||
}
|
||||
if defaultValue != "" {
|
||||
col.Default = &defaultValue
|
||||
}
|
||||
columns = append(columns, col)
|
||||
}
|
||||
return columns
|
||||
}
|
||||
|
||||
func formatAppOracleColumnType(row map[string]interface{}) string {
|
||||
dataType := appOracleRowString(row, "DATA_TYPE", "TYPE_NAME", "data_type", "type_name")
|
||||
if dataType == "" || strings.Contains(dataType, "(") {
|
||||
return dataType
|
||||
}
|
||||
|
||||
upperType := strings.ToUpper(dataType)
|
||||
if isAppOracleLengthQualifiedType(upperType) {
|
||||
if charLength, ok := appOracleRowInt(row, "CHAR_LENGTH", "CHAR_COL_DECL_LENGTH", "char_length", "char_col_decl_length"); ok && charLength > 0 {
|
||||
return fmt.Sprintf("%s(%d)", dataType, charLength)
|
||||
}
|
||||
if dataLength, ok := appOracleRowInt(row, "DATA_LENGTH", "data_length"); ok && dataLength > 0 {
|
||||
return fmt.Sprintf("%s(%d)", dataType, dataLength)
|
||||
}
|
||||
}
|
||||
|
||||
if strings.Contains(upperType, "NUMBER") || strings.Contains(upperType, "DECIMAL") || strings.Contains(upperType, "NUMERIC") {
|
||||
precision, hasPrecision := appOracleRowInt(row, "DATA_PRECISION", "NUMERIC_PRECISION", "data_precision", "numeric_precision")
|
||||
if hasPrecision && precision > 0 {
|
||||
scale, hasScale := appOracleRowInt(row, "DATA_SCALE", "NUMERIC_SCALE", "data_scale", "numeric_scale")
|
||||
if hasScale && scale > 0 {
|
||||
return fmt.Sprintf("%s(%d,%d)", dataType, precision, scale)
|
||||
}
|
||||
return fmt.Sprintf("%s(%d)", dataType, precision)
|
||||
}
|
||||
}
|
||||
|
||||
return dataType
|
||||
}
|
||||
|
||||
func isAppOracleLengthQualifiedType(upperType string) bool {
|
||||
switch strings.TrimSpace(upperType) {
|
||||
case "CHAR", "NCHAR", "VARCHAR", "VARCHAR2", "NVARCHAR", "NVARCHAR2", "RAW", "BINARY", "VARBINARY":
|
||||
return true
|
||||
default:
|
||||
return strings.Contains(upperType, "CHARACTER")
|
||||
}
|
||||
}
|
||||
|
||||
func normalizeAppOracleNullable(nullable string) string {
|
||||
switch strings.ToUpper(strings.TrimSpace(nullable)) {
|
||||
case "N", "NO":
|
||||
return "NO"
|
||||
case "Y", "YES":
|
||||
return "YES"
|
||||
default:
|
||||
return strings.TrimSpace(nullable)
|
||||
}
|
||||
}
|
||||
|
||||
func appOracleAutoIncrementExtra(defaultValue string) string {
|
||||
if strings.Contains(strings.ToUpper(strings.TrimSpace(defaultValue)), "NEXTVAL") {
|
||||
return "auto_increment"
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func appOracleRowValue(row map[string]interface{}, names ...string) interface{} {
|
||||
for _, name := range names {
|
||||
if value, ok := row[name]; ok {
|
||||
return value
|
||||
}
|
||||
}
|
||||
for key, value := range row {
|
||||
for _, name := range names {
|
||||
if strings.EqualFold(key, name) {
|
||||
return value
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func appOracleRowString(row map[string]interface{}, names ...string) string {
|
||||
return appOracleValueString(appOracleRowValue(row, names...))
|
||||
}
|
||||
|
||||
func appOracleValueString(value interface{}) string {
|
||||
if value == nil {
|
||||
return ""
|
||||
}
|
||||
switch typed := value.(type) {
|
||||
case []byte:
|
||||
return strings.TrimSpace(string(typed))
|
||||
case string:
|
||||
return strings.TrimSpace(typed)
|
||||
default:
|
||||
text := strings.TrimSpace(fmt.Sprintf("%v", typed))
|
||||
if strings.EqualFold(text, "<nil>") {
|
||||
return ""
|
||||
}
|
||||
return text
|
||||
}
|
||||
}
|
||||
|
||||
func appOracleRowInt(row map[string]interface{}, names ...string) (int, bool) {
|
||||
value := appOracleRowValue(row, names...)
|
||||
switch typed := value.(type) {
|
||||
case int:
|
||||
return typed, true
|
||||
case int8:
|
||||
return int(typed), true
|
||||
case int16:
|
||||
return int(typed), true
|
||||
case int32:
|
||||
return int(typed), true
|
||||
case int64:
|
||||
return int(typed), true
|
||||
case uint:
|
||||
return int(typed), true
|
||||
case uint8:
|
||||
return int(typed), true
|
||||
case uint16:
|
||||
return int(typed), true
|
||||
case uint32:
|
||||
return int(typed), true
|
||||
case uint64:
|
||||
return int(typed), true
|
||||
case float32:
|
||||
return int(typed), true
|
||||
case float64:
|
||||
return int(typed), true
|
||||
case []byte:
|
||||
parsed, err := strconv.Atoi(strings.TrimSpace(string(typed)))
|
||||
return parsed, err == nil
|
||||
case string:
|
||||
parsed, err := strconv.Atoi(strings.TrimSpace(typed))
|
||||
return parsed, err == nil
|
||||
default:
|
||||
return 0, false
|
||||
}
|
||||
}
|
||||
|
||||
func escapeAppOracleMetadataLiteral(text string) string {
|
||||
return strings.ReplaceAll(strings.TrimSpace(text), "'", "''")
|
||||
}
|
||||
|
||||
func inferOracleColumnsFromEmptySelect(dbInst db.Database, schemaName string, tableName string) ([]connection.ColumnDefinition, error) {
|
||||
table := strings.TrimSpace(tableName)
|
||||
if table == "" {
|
||||
return nil, fmt.Errorf("表名不能为空")
|
||||
}
|
||||
|
||||
query := "SELECT * FROM " + quoteOracleMetadataTableRef(schemaName, table) + " WHERE 1 = 0"
|
||||
_, fields, err := dbInst.Query(query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(fields) == 0 {
|
||||
return nil, fmt.Errorf("未获取到字段定义")
|
||||
}
|
||||
|
||||
columns := make([]connection.ColumnDefinition, 0, len(fields))
|
||||
for _, field := range fields {
|
||||
name := strings.TrimSpace(field)
|
||||
if name == "" {
|
||||
continue
|
||||
}
|
||||
columns = append(columns, connection.ColumnDefinition{
|
||||
Name: name,
|
||||
Nullable: "",
|
||||
Key: "",
|
||||
Extra: "",
|
||||
Comment: "",
|
||||
})
|
||||
}
|
||||
if len(columns) == 0 {
|
||||
return nil, fmt.Errorf("未获取到字段定义")
|
||||
}
|
||||
return columns, nil
|
||||
}
|
||||
|
||||
func quoteOracleMetadataIdentifier(ident string) string {
|
||||
return `"` + strings.ReplaceAll(strings.TrimSpace(ident), `"`, `""`) + `"`
|
||||
}
|
||||
|
||||
func quoteOracleMetadataTableRef(schemaName string, tableName string) string {
|
||||
tableRef := quoteOracleMetadataIdentifier(tableName)
|
||||
if strings.TrimSpace(schemaName) != "" {
|
||||
return quoteOracleMetadataIdentifier(schemaName) + "." + tableRef
|
||||
}
|
||||
return tableRef
|
||||
}
|
||||
|
||||
func (a *App) DBGetIndexes(config connection.ConnectionConfig, dbName string, tableName string) connection.QueryResult {
|
||||
runConfig := normalizeRunConfig(config, dbName)
|
||||
|
||||
|
||||
@@ -167,6 +167,14 @@ func TestFormatConnSummary_DefaultTimeout(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestDBReleaseConnectionClosesAllDatabaseCacheEntriesForSameInstance(t *testing.T) {
|
||||
proxySnapshot := currentGlobalProxyConfig()
|
||||
if _, err := setGlobalProxyConfig(false, proxySnapshot.Proxy); err != nil {
|
||||
t.Fatalf("disable global proxy failed: %v", err)
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
_, _ = setGlobalProxyConfig(proxySnapshot.Enabled, proxySnapshot.Proxy)
|
||||
})
|
||||
|
||||
app := NewApp()
|
||||
mainConfig := connection.ConnectionConfig{Type: "mysql", Host: "127.0.0.1", Port: 3306, User: "root", Database: "main"}
|
||||
analyticsConfig := mainConfig
|
||||
|
||||
@@ -332,6 +332,37 @@ func TestResolveCreateStatementWithFallback_ReturnsCreateViewDirectly(t *testing
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveCreateStatementWithFallback_OceanBaseOracleUsesShowCreateWhenAgentDDLIsEmpty(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dbInst := &fakeCreateStatementDB{
|
||||
createSQL: "",
|
||||
queryRows: []map[string]interface{}{
|
||||
{"Create Table": `CREATE TABLE "SYS"."test" ("id" NUMBER)`},
|
||||
},
|
||||
}
|
||||
|
||||
ddl, err := resolveCreateStatementWithFallback(dbInst, connection.ConnectionConfig{
|
||||
Type: "oceanbase",
|
||||
ConnectionParams: "protocol=oracle",
|
||||
}, "SYS", "SYS.test")
|
||||
if err != nil {
|
||||
t.Fatalf("resolveCreateStatementWithFallback() unexpected error: %v", err)
|
||||
}
|
||||
if ddl != `CREATE TABLE "SYS"."test" ("id" NUMBER)` {
|
||||
t.Fatalf("expected SHOW CREATE TABLE fallback DDL, got: %s", ddl)
|
||||
}
|
||||
if dbInst.createSchema != "SYS" || dbInst.createTable != "test" {
|
||||
t.Fatalf("expected metadata target SYS.test, got %q.%q", dbInst.createSchema, dbInst.createTable)
|
||||
}
|
||||
if len(dbInst.queries) != 1 || dbInst.queries[0] != `SHOW CREATE TABLE "SYS"."test"` {
|
||||
t.Fatalf("expected SHOW CREATE TABLE query, got: %v", dbInst.queries)
|
||||
}
|
||||
if dbInst.columnsCalls != 0 {
|
||||
t.Fatalf("OceanBase Oracle SHOW CREATE fallback should not call GetColumns, calls=%d", dbInst.columnsCalls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveCreateStatementWithFallback_PGLikeViewHelperBeforeColumnFallback(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
||||
@@ -28,6 +28,11 @@ type fakeMetadataRetryDB struct {
|
||||
indexes []connection.IndexDefinition
|
||||
columnsErr error
|
||||
indexesErr error
|
||||
queryResults []fakeMetadataQueryResult
|
||||
queryRows []map[string]interface{}
|
||||
queryFields []string
|
||||
queryErr error
|
||||
queries []string
|
||||
columnCalls int
|
||||
indexCalls int
|
||||
columnSchema string
|
||||
@@ -36,11 +41,27 @@ type fakeMetadataRetryDB struct {
|
||||
indexTable string
|
||||
}
|
||||
|
||||
type fakeMetadataQueryResult struct {
|
||||
match string
|
||||
rows []map[string]interface{}
|
||||
fields []string
|
||||
err error
|
||||
}
|
||||
|
||||
func (f *fakeMetadataRetryDB) Connect(config connection.ConnectionConfig) error { return nil }
|
||||
func (f *fakeMetadataRetryDB) Close() error { return nil }
|
||||
func (f *fakeMetadataRetryDB) Ping() error { return nil }
|
||||
func (f *fakeMetadataRetryDB) Query(query string) ([]map[string]interface{}, []string, error) {
|
||||
return nil, nil, nil
|
||||
f.queries = append(f.queries, query)
|
||||
for _, result := range f.queryResults {
|
||||
if result.match == "" || strings.Contains(query, result.match) {
|
||||
return result.rows, result.fields, result.err
|
||||
}
|
||||
}
|
||||
if f.queryErr != nil {
|
||||
return nil, nil, f.queryErr
|
||||
}
|
||||
return f.queryRows, f.queryFields, nil
|
||||
}
|
||||
func (f *fakeMetadataRetryDB) Exec(query string) (int64, error) { return 0, nil }
|
||||
func (f *fakeMetadataRetryDB) GetDatabases() ([]string, error) { return nil, nil }
|
||||
@@ -238,6 +259,132 @@ func TestDBGetColumnsKeepsDatabaseForMySQLMetadata(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestDBGetColumnsInfersOceanBaseOracleFieldsWhenAgentMetadataIsEmpty(t *testing.T) {
|
||||
originalNewDatabaseFunc := newDatabaseFunc
|
||||
originalResolveDialConfigWithProxyFunc := resolveDialConfigWithProxyFunc
|
||||
t.Cleanup(func() {
|
||||
newDatabaseFunc = originalNewDatabaseFunc
|
||||
resolveDialConfigWithProxyFunc = originalResolveDialConfigWithProxyFunc
|
||||
})
|
||||
|
||||
dbInst := &fakeMetadataRetryDB{
|
||||
columns: []connection.ColumnDefinition{},
|
||||
queryResults: []fakeMetadataQueryResult{
|
||||
{
|
||||
match: "FROM all_tab_columns c",
|
||||
rows: []map[string]interface{}{
|
||||
{
|
||||
"COLUMN_NAME": "id",
|
||||
"DATA_TYPE": "NUMBER",
|
||||
"NULLABLE": "N",
|
||||
"DATA_DEFAULT": "SEQUENCE.NEXTVAL",
|
||||
"COLUMN_KEY": "PRI",
|
||||
"COMMENT": "",
|
||||
"DATA_PRECISION": nil,
|
||||
"DATA_SCALE": nil,
|
||||
},
|
||||
{
|
||||
"COLUMN_NAME": "new_col_1",
|
||||
"DATA_TYPE": "VARCHAR2",
|
||||
"CHAR_LENGTH": 255,
|
||||
"NULLABLE": "Y",
|
||||
"COLUMN_KEY": "",
|
||||
"COMMENT": "",
|
||||
},
|
||||
},
|
||||
fields: []string{"COLUMN_NAME", "DATA_TYPE", "DATA_LENGTH", "CHAR_LENGTH", "DATA_PRECISION", "DATA_SCALE", "NULLABLE", "DATA_DEFAULT", "COLUMN_KEY", "COMMENT"},
|
||||
},
|
||||
},
|
||||
}
|
||||
newDatabaseFunc = func(dbType string) (db.Database, error) {
|
||||
return dbInst, nil
|
||||
}
|
||||
resolveDialConfigWithProxyFunc = func(raw connection.ConnectionConfig) (connection.ConnectionConfig, error) {
|
||||
return raw, nil
|
||||
}
|
||||
|
||||
app := NewAppWithSecretStore(secretstore.NewUnavailableStore("test"))
|
||||
result := app.DBGetColumns(connection.ConnectionConfig{
|
||||
Type: "oceanbase",
|
||||
Host: "127.0.0.1",
|
||||
Port: 12881,
|
||||
User: "SYS",
|
||||
ConnectionParams: "protocol=oracle",
|
||||
}, "SYS", "SYS.test")
|
||||
|
||||
if !result.Success {
|
||||
t.Fatalf("expected DBGetColumns success, got failure: %s", result.Message)
|
||||
}
|
||||
if dbInst.columnSchema != "SYS" || dbInst.columnTable != "test" {
|
||||
t.Fatalf("expected OceanBase Oracle metadata to split schema/table, got %q.%q", dbInst.columnSchema, dbInst.columnTable)
|
||||
}
|
||||
if len(dbInst.queries) != 1 || !strings.Contains(dbInst.queries[0], "FROM all_tab_columns c") {
|
||||
t.Fatalf("expected dictionary metadata fallback query, got %v", dbInst.queries)
|
||||
}
|
||||
columns, ok := result.Data.([]connection.ColumnDefinition)
|
||||
if !ok {
|
||||
t.Fatalf("expected []connection.ColumnDefinition, got %T", result.Data)
|
||||
}
|
||||
if len(columns) != 2 || columns[0].Name != "id" || columns[1].Name != "new_col_1" {
|
||||
t.Fatalf("unexpected inferred columns: %#v", columns)
|
||||
}
|
||||
if columns[0].Type != "NUMBER" || columns[0].Nullable != "NO" || columns[0].Key != "PRI" || columns[0].Extra != "auto_increment" {
|
||||
t.Fatalf("expected id to keep type/not-null/primary-key/auto-increment metadata, got %#v", columns[0])
|
||||
}
|
||||
if columns[0].Default == nil || *columns[0].Default != "SEQUENCE.NEXTVAL" {
|
||||
t.Fatalf("expected id default to keep sequence nextval, got %#v", columns[0].Default)
|
||||
}
|
||||
if columns[1].Type != "VARCHAR2(255)" || columns[1].Nullable != "YES" || columns[1].Key != "" {
|
||||
t.Fatalf("expected new_col_1 to keep varchar nullable metadata, got %#v", columns[1])
|
||||
}
|
||||
}
|
||||
|
||||
func TestDBGetColumnsFallsBackToEmptySelectWhenOceanBaseOracleDictionaryIsEmpty(t *testing.T) {
|
||||
originalNewDatabaseFunc := newDatabaseFunc
|
||||
originalResolveDialConfigWithProxyFunc := resolveDialConfigWithProxyFunc
|
||||
t.Cleanup(func() {
|
||||
newDatabaseFunc = originalNewDatabaseFunc
|
||||
resolveDialConfigWithProxyFunc = originalResolveDialConfigWithProxyFunc
|
||||
})
|
||||
|
||||
dbInst := &fakeMetadataRetryDB{
|
||||
columns: []connection.ColumnDefinition{},
|
||||
queryResults: []fakeMetadataQueryResult{
|
||||
{match: "FROM all_tab_columns c", rows: []map[string]interface{}{}},
|
||||
{match: `SELECT * FROM "SYS"."test" WHERE 1 = 0`, fields: []string{"id", "new_col_1"}},
|
||||
},
|
||||
}
|
||||
newDatabaseFunc = func(dbType string) (db.Database, error) {
|
||||
return dbInst, nil
|
||||
}
|
||||
resolveDialConfigWithProxyFunc = func(raw connection.ConnectionConfig) (connection.ConnectionConfig, error) {
|
||||
return raw, nil
|
||||
}
|
||||
|
||||
app := NewAppWithSecretStore(secretstore.NewUnavailableStore("test"))
|
||||
result := app.DBGetColumns(connection.ConnectionConfig{
|
||||
Type: "oceanbase",
|
||||
Host: "127.0.0.1",
|
||||
Port: 12881,
|
||||
User: "SYS",
|
||||
ConnectionParams: "protocol=oracle",
|
||||
}, "SYS", "SYS.test")
|
||||
|
||||
if !result.Success {
|
||||
t.Fatalf("expected DBGetColumns success, got failure: %s", result.Message)
|
||||
}
|
||||
if len(dbInst.queries) < 2 {
|
||||
t.Fatalf("expected dictionary and empty-select fallback queries, got %v", dbInst.queries)
|
||||
}
|
||||
columns, ok := result.Data.([]connection.ColumnDefinition)
|
||||
if !ok {
|
||||
t.Fatalf("expected []connection.ColumnDefinition, got %T", result.Data)
|
||||
}
|
||||
if len(columns) != 2 || columns[0].Name != "id" || columns[1].Name != "new_col_1" {
|
||||
t.Fatalf("unexpected inferred columns: %#v", columns)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDBGetColumnsKeepsDuckDBQualifiedTableMetadata(t *testing.T) {
|
||||
originalNewDatabaseFunc := newDatabaseFunc
|
||||
originalResolveDialConfigWithProxyFunc := resolveDialConfigWithProxyFunc
|
||||
|
||||
@@ -135,11 +135,16 @@ type TransactionExecerProvider interface {
|
||||
}
|
||||
|
||||
type sqlConnStatementExecer struct {
|
||||
conn *sql.Conn
|
||||
conn *sql.Conn
|
||||
scanDialect string
|
||||
}
|
||||
|
||||
func NewSQLConnStatementExecer(conn *sql.Conn) StatementExecer {
|
||||
return &sqlConnStatementExecer{conn: conn}
|
||||
return NewSQLConnStatementExecerWithDialect(conn, "")
|
||||
}
|
||||
|
||||
func NewSQLConnStatementExecerWithDialect(conn *sql.Conn, scanDialect string) StatementExecer {
|
||||
return &sqlConnStatementExecer{conn: conn, scanDialect: scanDialect}
|
||||
}
|
||||
|
||||
func (e *sqlConnStatementExecer) ExecContext(ctx context.Context, query string) (int64, error) {
|
||||
@@ -166,7 +171,7 @@ func (e *sqlConnStatementExecer) QueryContext(ctx context.Context, query string)
|
||||
return nil, nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
return scanRows(rows)
|
||||
return scanRowsForDialect(rows, e.scanDialect)
|
||||
}
|
||||
|
||||
func (e *sqlConnStatementExecer) Query(query string) ([]map[string]interface{}, []string, error) {
|
||||
@@ -182,7 +187,7 @@ func (e *sqlConnStatementExecer) QueryMultiContext(ctx context.Context, query st
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
return scanMultiRows(rows)
|
||||
return scanMultiRowsForDialect(rows, e.scanDialect)
|
||||
}
|
||||
|
||||
func (e *sqlConnStatementExecer) QueryMulti(query string) ([]connection.ResultSetData, error) {
|
||||
@@ -206,13 +211,19 @@ type sqlConnTransactionExecer struct {
|
||||
done bool
|
||||
commitSQL string
|
||||
rollbackSQL string
|
||||
scanDialect string
|
||||
}
|
||||
|
||||
func NewSQLConnTransactionExecer(conn *sql.Conn, commitSQL string, rollbackSQL string) TransactionExecer {
|
||||
return NewSQLConnTransactionExecerWithDialect(conn, commitSQL, rollbackSQL, "")
|
||||
}
|
||||
|
||||
func NewSQLConnTransactionExecerWithDialect(conn *sql.Conn, commitSQL string, rollbackSQL string, scanDialect string) TransactionExecer {
|
||||
return &sqlConnTransactionExecer{
|
||||
conn: conn,
|
||||
commitSQL: strings.TrimSpace(commitSQL),
|
||||
rollbackSQL: strings.TrimSpace(rollbackSQL),
|
||||
scanDialect: scanDialect,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -257,7 +268,7 @@ func (e *sqlConnTransactionExecer) QueryContext(ctx context.Context, query strin
|
||||
return nil, nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
return scanRows(rows)
|
||||
return scanRowsForDialect(rows, e.scanDialect)
|
||||
}
|
||||
|
||||
func (e *sqlConnTransactionExecer) Query(query string) ([]map[string]interface{}, []string, error) {
|
||||
@@ -274,7 +285,7 @@ func (e *sqlConnTransactionExecer) QueryMultiContext(ctx context.Context, query
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
return scanMultiRows(rows)
|
||||
return scanMultiRowsForDialect(rows, e.scanDialect)
|
||||
}
|
||||
|
||||
func (e *sqlConnTransactionExecer) QueryMulti(query string) ([]connection.ResultSetData, error) {
|
||||
|
||||
87
internal/db/database_session_test.go
Normal file
87
internal/db/database_session_test.go
Normal file
@@ -0,0 +1,87 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func openScanRowsDuplicateSQLConn(t *testing.T) *sql.Conn {
|
||||
t.Helper()
|
||||
|
||||
registerScanRowsDuplicateDriverOnce.Do(func() {
|
||||
sql.Register(scanRowsDuplicateDriverName, scanRowsDuplicateDriver{})
|
||||
})
|
||||
|
||||
dbConn, err := sql.Open(scanRowsDuplicateDriverName, "")
|
||||
if err != nil {
|
||||
t.Fatalf("open duplicate scan rows db failed: %v", err)
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
_ = dbConn.Close()
|
||||
})
|
||||
|
||||
conn, err := dbConn.Conn(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("open sql conn failed: %v", err)
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
_ = conn.Close()
|
||||
})
|
||||
return conn
|
||||
}
|
||||
|
||||
func TestSQLConnStatementExecerWithDialectDecodesOceanBaseOracleTimestamp(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := openScanRowsDuplicateSQLConn(t)
|
||||
execer, ok := NewSQLConnStatementExecerWithDialect(conn, oceanBaseOracleScanDialect).(StatementMultiResultQueryExecer)
|
||||
if !ok {
|
||||
t.Fatal("statement execer should support multi-result query")
|
||||
}
|
||||
|
||||
results, err := execer.QueryMultiContext(context.Background(), "SELECT timestamp_precision_columns")
|
||||
if err != nil {
|
||||
t.Fatalf("query multi failed: %v", err)
|
||||
}
|
||||
if len(results) != 1 {
|
||||
t.Fatalf("expected one result set, got=%d", len(results))
|
||||
}
|
||||
if !reflect.DeepEqual(results[0].Columns, []string{"created_at"}) {
|
||||
t.Fatalf("unexpected columns: %v", results[0].Columns)
|
||||
}
|
||||
if len(results[0].Rows) != 1 {
|
||||
t.Fatalf("expected one row, got=%d", len(results[0].Rows))
|
||||
}
|
||||
if got := results[0].Rows[0]["created_at"]; got != "2026-06-16T12:34:56.123456Z" {
|
||||
t.Fatalf("statement execer should decode OceanBase Oracle TIMESTAMP(6), got=%v(%T)", got, got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSQLConnTransactionExecerWithDialectDecodesOceanBaseOracleTimestamp(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := openScanRowsDuplicateSQLConn(t)
|
||||
execer, ok := NewSQLConnTransactionExecerWithDialect(conn, "COMMIT", "ROLLBACK", oceanBaseOracleScanDialect).(StatementMultiResultQueryExecer)
|
||||
if !ok {
|
||||
t.Fatal("transaction execer should support multi-result query")
|
||||
}
|
||||
|
||||
results, err := execer.QueryMultiContext(context.Background(), "SELECT timestamp_precision_columns")
|
||||
if err != nil {
|
||||
t.Fatalf("query multi failed: %v", err)
|
||||
}
|
||||
if len(results) != 1 {
|
||||
t.Fatalf("expected one result set, got=%d", len(results))
|
||||
}
|
||||
if !reflect.DeepEqual(results[0].Columns, []string{"created_at"}) {
|
||||
t.Fatalf("unexpected columns: %v", results[0].Columns)
|
||||
}
|
||||
if len(results[0].Rows) != 1 {
|
||||
t.Fatalf("expected one row, got=%d", len(results[0].Rows))
|
||||
}
|
||||
if got := results[0].Rows[0]["created_at"]; got != "2026-06-16T12:34:56.123456Z" {
|
||||
t.Fatalf("transaction execer should decode OceanBase Oracle TIMESTAMP(6), got=%v(%T)", got, got)
|
||||
}
|
||||
}
|
||||
@@ -566,7 +566,7 @@ func (o *OceanBaseDB) connectOracleViaTNS(config connection.ConnectionConfig) er
|
||||
if strings.TrimSpace(runConfig.Database) == "" {
|
||||
return fmt.Errorf("OceanBase Oracle 协议(TNS 路径)需要填写服务名(Service Name),请在连接配置中填写租户监听的服务名(例如 ORCL / tenant_oracle 等)")
|
||||
}
|
||||
oracleDB := &OracleDB{}
|
||||
oracleDB := &OracleDB{scanDialect: oceanBaseOracleScanDialect}
|
||||
if err := oracleDB.Connect(runConfig); err != nil {
|
||||
return annotateOceanBaseOracleConnectError(err)
|
||||
}
|
||||
@@ -660,7 +660,7 @@ func (o *OceanBaseDB) bindConnectedDatabase(db *sql.DB, timeout time.Duration, p
|
||||
o.conn = nil
|
||||
o.pingTimeout = 0
|
||||
if protocol == oceanBaseProtocolOracle {
|
||||
o.oracle = &OracleDB{conn: db, pingTimeout: timeout}
|
||||
o.oracle = &OracleDB{conn: db, pingTimeout: timeout, scanDialect: oceanBaseOracleScanDialect}
|
||||
o.protocol = oceanBaseProtocolOracle
|
||||
return
|
||||
}
|
||||
@@ -871,9 +871,79 @@ func (o *OceanBaseDB) GetTables(dbName string) ([]string, error) {
|
||||
}
|
||||
|
||||
func (o *OceanBaseDB) GetCreateStatement(dbName, tableName string) (string, error) {
|
||||
if o.protocol == oceanBaseProtocolOracle && o.oracle != nil {
|
||||
ddl, err := o.oracle.GetCreateStatement(dbName, tableName)
|
||||
if err == nil && strings.TrimSpace(ddl) != "" {
|
||||
return ddl, nil
|
||||
}
|
||||
showDDL, showErr := o.getOceanBaseOracleShowCreateStatement(dbName, tableName)
|
||||
if showErr == nil {
|
||||
return showDDL, nil
|
||||
}
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("%w;OceanBase Oracle SHOW CREATE TABLE 兜底失败:%v", err, showErr)
|
||||
}
|
||||
return "", showErr
|
||||
}
|
||||
return o.activeDatabase().GetCreateStatement(dbName, tableName)
|
||||
}
|
||||
|
||||
func (o *OceanBaseDB) getOceanBaseOracleShowCreateStatement(dbName string, tableName string) (string, error) {
|
||||
var firstErr error
|
||||
for _, candidate := range oracleMetadataNamePairs(dbName, tableName) {
|
||||
query := buildOceanBaseOracleShowCreateTableQuery(candidate.schema, candidate.table)
|
||||
data, _, err := o.oracle.Query(query)
|
||||
if err != nil {
|
||||
if firstErr == nil {
|
||||
firstErr = err
|
||||
}
|
||||
continue
|
||||
}
|
||||
if ddl := extractOceanBaseOracleCreateStatement(data); ddl != "" {
|
||||
return o.oracle.appendOracleCommentDDL(ddl, candidate.schema, candidate.table), nil
|
||||
}
|
||||
}
|
||||
if firstErr != nil {
|
||||
return "", firstErr
|
||||
}
|
||||
return "", fmt.Errorf("未找到建表语句")
|
||||
}
|
||||
|
||||
func buildOceanBaseOracleShowCreateTableQuery(schema string, table string) string {
|
||||
return "SHOW CREATE TABLE " + quoteOracleTableRef(schema, table)
|
||||
}
|
||||
|
||||
func extractOceanBaseOracleCreateStatement(data []map[string]interface{}) string {
|
||||
for _, row := range data {
|
||||
for _, key := range []string{"Create Table", "CREATE TABLE", "CREATE_TABLE", "DDL", "ddl"} {
|
||||
if val, ok := row[key]; ok {
|
||||
text := strings.TrimSpace(fmt.Sprintf("%v", val))
|
||||
if text != "" && !strings.EqualFold(text, "<nil>") {
|
||||
return text
|
||||
}
|
||||
}
|
||||
}
|
||||
for _, val := range row {
|
||||
text := strings.TrimSpace(fmt.Sprintf("%v", val))
|
||||
lower := strings.ToLower(text)
|
||||
if strings.HasPrefix(lower, "create table") ||
|
||||
strings.HasPrefix(lower, "create view") ||
|
||||
strings.HasPrefix(lower, "create or replace view") {
|
||||
return text
|
||||
}
|
||||
}
|
||||
if len(row) == 1 {
|
||||
for _, val := range row {
|
||||
text := strings.TrimSpace(fmt.Sprintf("%v", val))
|
||||
if text != "" && !strings.EqualFold(text, "<nil>") {
|
||||
return text
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (o *OceanBaseDB) GetColumns(dbName, tableName string) ([]connection.ColumnDefinition, error) {
|
||||
return o.activeDatabase().GetColumns(dbName, tableName)
|
||||
}
|
||||
@@ -907,6 +977,60 @@ func (o *OceanBaseDB) ApplyChanges(tableName string, changes connection.ChangeSe
|
||||
return fmt.Errorf("当前 OceanBase %s 协议不支持 ApplyChanges", o.protocol)
|
||||
}
|
||||
|
||||
func buildOceanBaseOracleTemporalBind(columnType string, value interface{}) (string, interface{}, bool) {
|
||||
if value == nil {
|
||||
return "?", nil, false
|
||||
}
|
||||
|
||||
rawType := strings.ToUpper(strings.TrimSpace(columnType))
|
||||
if !isOracleTemporalColumnType(rawType) {
|
||||
return "?", value, false
|
||||
}
|
||||
|
||||
var parsed time.Time
|
||||
switch typed := value.(type) {
|
||||
case time.Time:
|
||||
parsed = typed
|
||||
case string:
|
||||
text := strings.TrimSpace(typed)
|
||||
if text == "" {
|
||||
return "?", nil, false
|
||||
}
|
||||
var ok bool
|
||||
parsed, ok = parseOracleTemporalString(text)
|
||||
if !ok {
|
||||
return "?", value, false
|
||||
}
|
||||
default:
|
||||
return "?", value, false
|
||||
}
|
||||
|
||||
if strings.Contains(rawType, "TIMESTAMP") {
|
||||
text := parsed.Format("2006-01-02 15:04:05")
|
||||
format := "YYYY-MM-DD HH24:MI:SS"
|
||||
if parsed.Nanosecond() != 0 {
|
||||
text = parsed.Format("2006-01-02 15:04:05.999999999")
|
||||
text = strings.TrimRight(strings.TrimRight(text, "0"), ".")
|
||||
format = "YYYY-MM-DD HH24:MI:SS.FF"
|
||||
}
|
||||
return fmt.Sprintf("TO_TIMESTAMP(?, '%s')", format), text, true
|
||||
}
|
||||
|
||||
if parsed.Hour() == 0 && parsed.Minute() == 0 && parsed.Second() == 0 && parsed.Nanosecond() == 0 {
|
||||
return "TO_DATE(?, 'YYYY-MM-DD')", parsed.Format("2006-01-02"), true
|
||||
}
|
||||
return "TO_DATE(?, 'YYYY-MM-DD HH24:MI:SS')", parsed.Format("2006-01-02 15:04:05"), true
|
||||
}
|
||||
|
||||
func buildOceanBaseOracleAssignment(columnName string, value interface{}, columnTypeMap map[string]string) (string, []interface{}) {
|
||||
columnType := columnTypeMap[strings.ToLower(strings.TrimSpace(columnName))]
|
||||
normalized := normalizeOracleValueForWrite(columnName, value, columnTypeMap)
|
||||
if expr, bind, ok := buildOceanBaseOracleTemporalBind(columnType, normalized); ok {
|
||||
return expr, []interface{}{bind}
|
||||
}
|
||||
return "?", []interface{}{normalized}
|
||||
}
|
||||
|
||||
// applyOracleChangesMySQLWire 在 OceanBase Oracle 租户的 mysql wire 连接上执行
|
||||
// DELETE/UPDATE/INSERT,使用 Oracle 风格双引号引用标识符 + mysql wire 风格 "?" 占位符。
|
||||
func (o *OceanBaseDB) applyOracleChangesMySQLWire(tableName string, changes connection.ChangeSet) error {
|
||||
@@ -959,8 +1083,9 @@ func (o *OceanBaseDB) applyOracleChangesMySQLWire(tableName string, changes conn
|
||||
args = append(args, v)
|
||||
continue
|
||||
}
|
||||
wheres = append(wheres, fmt.Sprintf("%s = ?", quoteIdent(k)))
|
||||
args = append(args, normalizeOracleValueForWrite(k, v, columnTypeMap))
|
||||
valueExpr, valueArgs := buildOceanBaseOracleAssignment(k, v, columnTypeMap)
|
||||
wheres = append(wheres, fmt.Sprintf("%s = %s", quoteIdent(k), valueExpr))
|
||||
args = append(args, valueArgs...)
|
||||
}
|
||||
return wheres, args
|
||||
}
|
||||
@@ -985,8 +1110,9 @@ func (o *OceanBaseDB) applyOracleChangesMySQLWire(tableName string, changes conn
|
||||
var args []interface{}
|
||||
|
||||
for k, v := range update.Values {
|
||||
sets = append(sets, fmt.Sprintf("%s = ?", quoteIdent(k)))
|
||||
args = append(args, normalizeOracleValueForWrite(k, v, columnTypeMap))
|
||||
valueExpr, valueArgs := buildOceanBaseOracleAssignment(k, v, columnTypeMap)
|
||||
sets = append(sets, fmt.Sprintf("%s = %s", quoteIdent(k), valueExpr))
|
||||
args = append(args, valueArgs...)
|
||||
}
|
||||
|
||||
if len(sets) == 0 {
|
||||
@@ -1017,8 +1143,9 @@ func (o *OceanBaseDB) applyOracleChangesMySQLWire(tableName string, changes conn
|
||||
|
||||
for k, v := range row {
|
||||
cols = append(cols, quoteIdent(k))
|
||||
placeholders = append(placeholders, "?")
|
||||
args = append(args, normalizeOracleValueForWrite(k, v, columnTypeMap))
|
||||
valueExpr, valueArgs := buildOceanBaseOracleAssignment(k, v, columnTypeMap)
|
||||
placeholders = append(placeholders, valueExpr)
|
||||
args = append(args, valueArgs...)
|
||||
}
|
||||
|
||||
if len(cols) == 0 {
|
||||
|
||||
@@ -4,9 +4,11 @@ package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql/driver"
|
||||
"errors"
|
||||
"net"
|
||||
"net/url"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
@@ -782,6 +784,81 @@ func TestOceanBaseOracleOBClientApplyChangesUsesMySQLWirePlaceholders(t *testing
|
||||
}
|
||||
}
|
||||
|
||||
func TestOceanBaseOracleOBClientApplyChangesFormatsTemporalValuesExplicitly(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dbConn, state := openOracleRecordingDB(t)
|
||||
oceanbaseDB := &OceanBaseDB{}
|
||||
oceanbaseDB.bindConnectedDatabase(dbConn, 0, oceanBaseProtocolOracle)
|
||||
|
||||
changes := connection.ChangeSet{
|
||||
Updates: []connection.UpdateRow{{
|
||||
Keys: map[string]interface{}{
|
||||
"ID": int64(7),
|
||||
},
|
||||
Values: map[string]interface{}{
|
||||
"UPDATED_AT": "2026-06-16 17:37:08",
|
||||
},
|
||||
}},
|
||||
}
|
||||
|
||||
if err := oceanbaseDB.ApplyChanges("APP.USERS", changes); err != nil {
|
||||
t.Fatalf("ApplyChanges() unexpected error: %v", err)
|
||||
}
|
||||
|
||||
queries := state.snapshotExecQueries()
|
||||
if len(queries) != 1 {
|
||||
t.Fatalf("expected one exec query, got %#v", queries)
|
||||
}
|
||||
if !strings.Contains(queries[0], `"UPDATED_AT" = TO_TIMESTAMP(?, 'YYYY-MM-DD HH24:MI:SS')`) {
|
||||
t.Fatalf("expected explicit TO_TIMESTAMP binding for temporal update, got %q", queries[0])
|
||||
}
|
||||
|
||||
executions := state.snapshotExecArgs()
|
||||
if len(executions) != 1 || len(executions[0]) != 2 {
|
||||
t.Fatalf("unexpected exec args: %#v", executions)
|
||||
}
|
||||
if got, ok := executions[0][0].Value.(string); !ok || got != "2026-06-16 17:37:08" {
|
||||
t.Fatalf("expected temporal bind arg kept as canonical string, got %#v (%T)", executions[0][0].Value, executions[0][0].Value)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOceanBaseOracleGetCreateStatementFallsBackToShowCreateTable(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dbConn, state := openOracleRecordingDB(t)
|
||||
state.mu.Lock()
|
||||
state.queryResults[`SHOW CREATE TABLE "SYS"."test"`] = oracleRecordingQueryResult{
|
||||
columns: []string{"Create Table"},
|
||||
rows: [][]driver.Value{
|
||||
{`CREATE TABLE "SYS"."test" ("ID" NUMBER)`},
|
||||
},
|
||||
}
|
||||
state.mu.Unlock()
|
||||
|
||||
oceanbaseDB := &OceanBaseDB{}
|
||||
oceanbaseDB.bindConnectedDatabase(dbConn, 0, oceanBaseProtocolOracle)
|
||||
|
||||
ddl, err := oceanbaseDB.GetCreateStatement("SYS", "test")
|
||||
if err != nil {
|
||||
t.Fatalf("GetCreateStatement() unexpected error: %v", err)
|
||||
}
|
||||
if !strings.Contains(ddl, `CREATE TABLE "SYS"."test"`) {
|
||||
t.Fatalf("expected SHOW CREATE TABLE DDL, got: %s", ddl)
|
||||
}
|
||||
|
||||
queries := state.snapshotQueries()
|
||||
if len(queries) < 3 {
|
||||
t.Fatalf("expected DBMS_METADATA attempts followed by SHOW CREATE TABLE, got: %v", queries)
|
||||
}
|
||||
if queries[0] != `SELECT DBMS_METADATA.GET_DDL('TABLE', 'test', 'SYS') as ddl FROM DUAL` {
|
||||
t.Fatalf("expected original-case DBMS_METADATA first, got: %v", queries)
|
||||
}
|
||||
if !slices.Contains(queries, `SHOW CREATE TABLE "SYS"."test"`) {
|
||||
t.Fatalf("expected SHOW CREATE TABLE fallback, got: %v", queries)
|
||||
}
|
||||
}
|
||||
|
||||
// 用户通过 ConnectionParams 设置 connectionAttributes 时,OceanBase MySQL wire 路径必须把
|
||||
// 这些 attribute 透传到 go-sql-driver/mysql DSN,让 driver 在握手响应里发 CLIENT_CONNECT_ATTRS。
|
||||
// 这是 OBClient 协议握手探索的入口:高级用户/DBA 可以试错不同 attribute 组合而不需要改 GoNavi 代码。
|
||||
|
||||
@@ -25,19 +25,22 @@ var (
|
||||
)
|
||||
|
||||
type oracleRecordingState struct {
|
||||
mu sync.Mutex
|
||||
execQueries []string
|
||||
execArgs [][]driver.NamedValue
|
||||
queries []string
|
||||
beginCalls int
|
||||
rowsAffected int64
|
||||
queryResults map[string]oracleRecordingQueryResult
|
||||
queryError error
|
||||
mu sync.Mutex
|
||||
execQueries []string
|
||||
execArgs [][]driver.NamedValue
|
||||
queries []string
|
||||
beginCalls int
|
||||
rowsAffected int64
|
||||
queryResults map[string]oracleRecordingQueryResult
|
||||
queryError error
|
||||
disableDefaultTabColumns bool
|
||||
}
|
||||
|
||||
type oracleRecordingQueryResult struct {
|
||||
columns []string
|
||||
rows [][]driver.Value
|
||||
columns []string
|
||||
columnTypes []string
|
||||
nullable []bool
|
||||
rows [][]driver.Value
|
||||
}
|
||||
|
||||
func (s *oracleRecordingState) snapshotExecQueries() []string {
|
||||
@@ -109,6 +112,7 @@ func (c *oracleRecordingConn) ExecContext(_ context.Context, query string, args
|
||||
func (c *oracleRecordingConn) QueryContext(_ context.Context, query string, _ []driver.NamedValue) (driver.Rows, error) {
|
||||
c.state.mu.Lock()
|
||||
c.state.queries = append(c.state.queries, query)
|
||||
disableDefaultTabColumns := c.state.disableDefaultTabColumns
|
||||
if err := c.state.queryError; err != nil {
|
||||
c.state.mu.Unlock()
|
||||
return nil, err
|
||||
@@ -116,13 +120,15 @@ func (c *oracleRecordingConn) QueryContext(_ context.Context, query string, _ []
|
||||
if result, ok := c.state.queryResults[query]; ok {
|
||||
c.state.mu.Unlock()
|
||||
return &oracleRecordingRows{
|
||||
columns: append([]string(nil), result.columns...),
|
||||
rows: cloneOracleRecordingRows(result.rows),
|
||||
columns: append([]string(nil), result.columns...),
|
||||
columnTypes: append([]string(nil), result.columnTypes...),
|
||||
nullable: append([]bool(nil), result.nullable...),
|
||||
rows: cloneOracleRecordingRows(result.rows),
|
||||
}, nil
|
||||
}
|
||||
c.state.mu.Unlock()
|
||||
|
||||
if strings.Contains(strings.ToLower(query), "tab_columns") {
|
||||
if strings.Contains(strings.ToLower(query), "tab_columns") && !disableDefaultTabColumns {
|
||||
return &oracleRecordingRows{
|
||||
columns: []string{"COLUMN_NAME", "DATA_TYPE", "NULLABLE", "DATA_DEFAULT", "COLUMN_KEY", "COMMENT"},
|
||||
rows: [][]driver.Value{
|
||||
@@ -151,9 +157,11 @@ func (oracleRecordingTx) Commit() error { return nil }
|
||||
func (oracleRecordingTx) Rollback() error { return nil }
|
||||
|
||||
type oracleRecordingRows struct {
|
||||
columns []string
|
||||
rows [][]driver.Value
|
||||
index int
|
||||
columns []string
|
||||
columnTypes []string
|
||||
nullable []bool
|
||||
rows [][]driver.Value
|
||||
index int
|
||||
}
|
||||
|
||||
func (r *oracleRecordingRows) Columns() []string {
|
||||
@@ -162,6 +170,20 @@ func (r *oracleRecordingRows) Columns() []string {
|
||||
|
||||
func (r *oracleRecordingRows) Close() error { return nil }
|
||||
|
||||
func (r *oracleRecordingRows) ColumnTypeDatabaseTypeName(index int) string {
|
||||
if index < 0 || index >= len(r.columnTypes) {
|
||||
return ""
|
||||
}
|
||||
return r.columnTypes[index]
|
||||
}
|
||||
|
||||
func (r *oracleRecordingRows) ColumnTypeNullable(index int) (nullable, ok bool) {
|
||||
if index < 0 || index >= len(r.nullable) {
|
||||
return false, false
|
||||
}
|
||||
return r.nullable[index], true
|
||||
}
|
||||
|
||||
func (r *oracleRecordingRows) Next(dest []driver.Value) error {
|
||||
if r.index >= len(r.rows) {
|
||||
return io.EOF
|
||||
|
||||
@@ -3,6 +3,7 @@ package db
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"reflect"
|
||||
"slices"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
@@ -111,6 +112,59 @@ func TestOracleGetColumnsIncludesColumnComments(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestOracleGetColumnsPreservesMetadataNameCaseBeforeUppercaseFallback(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dbConn, state := openOracleRecordingDB(t)
|
||||
oracleDB := &OracleDB{conn: dbConn}
|
||||
if _, err := oracleDB.GetColumns("SYS", "test"); err != nil {
|
||||
t.Fatalf("GetColumns 返回错误: %v", err)
|
||||
}
|
||||
|
||||
queries := state.snapshotQueries()
|
||||
if len(queries) == 0 {
|
||||
t.Fatalf("expected metadata query")
|
||||
}
|
||||
if !strings.Contains(queries[0], `WHERE c.owner = 'SYS' AND c.table_name = 'test'`) {
|
||||
t.Fatalf("expected first metadata query to preserve table case, got: %s", queries[0])
|
||||
}
|
||||
}
|
||||
|
||||
func TestOracleGetColumnsFallsBackToSelectMetadataWhenDictionaryIsEmpty(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dbConn, state := openOracleRecordingDB(t)
|
||||
state.mu.Lock()
|
||||
state.disableDefaultTabColumns = true
|
||||
state.queryResults[`SELECT * FROM "SYS"."test" WHERE 1 = 0`] = oracleRecordingQueryResult{
|
||||
columns: []string{"id", "new_col_1"},
|
||||
columnTypes: []string{"NUMBER", "VARCHAR2"},
|
||||
nullable: []bool{false, true},
|
||||
rows: [][]driver.Value{},
|
||||
}
|
||||
state.mu.Unlock()
|
||||
|
||||
oracleDB := &OracleDB{conn: dbConn}
|
||||
columns, err := oracleDB.GetColumns("SYS", "test")
|
||||
if err != nil {
|
||||
t.Fatalf("GetColumns 返回错误: %v", err)
|
||||
}
|
||||
if len(columns) != 2 {
|
||||
t.Fatalf("expected fallback columns, got %#v", columns)
|
||||
}
|
||||
if columns[0].Name != "id" || columns[0].Type != "NUMBER" || columns[0].Nullable != "NO" {
|
||||
t.Fatalf("unexpected first fallback column: %#v", columns[0])
|
||||
}
|
||||
if columns[1].Name != "new_col_1" || columns[1].Type != "VARCHAR2" || columns[1].Nullable != "YES" {
|
||||
t.Fatalf("unexpected second fallback column: %#v", columns[1])
|
||||
}
|
||||
|
||||
queries := state.snapshotQueries()
|
||||
if !slices.Contains(queries, `SELECT * FROM "SYS"."test" WHERE 1 = 0`) {
|
||||
t.Fatalf("expected SELECT metadata fallback query, got: %v", queries)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatOracleColumnTypeIncludesLengthAndPrecision(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -157,6 +211,66 @@ func TestFormatOracleColumnTypeIncludesLengthAndPrecision(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestOracleGetCreateStatementPreservesMetadataNameCase(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dbConn, state := openOracleRecordingDB(t)
|
||||
state.mu.Lock()
|
||||
state.queryResults[`SELECT DBMS_METADATA.GET_DDL('TABLE', 'test', 'SYS') as ddl FROM DUAL`] = oracleRecordingQueryResult{
|
||||
columns: []string{"DDL"},
|
||||
rows: [][]driver.Value{
|
||||
{`CREATE TABLE "SYS"."test" ("ID" NUMBER)`},
|
||||
},
|
||||
}
|
||||
state.mu.Unlock()
|
||||
|
||||
oracleDB := &OracleDB{conn: dbConn}
|
||||
ddl, err := oracleDB.GetCreateStatement("SYS", "test")
|
||||
if err != nil {
|
||||
t.Fatalf("GetCreateStatement 返回错误: %v", err)
|
||||
}
|
||||
if !strings.Contains(ddl, `CREATE TABLE "SYS"."test"`) {
|
||||
t.Fatalf("expected lowercase metadata DDL, got: %s", ddl)
|
||||
}
|
||||
|
||||
queries := state.snapshotQueries()
|
||||
if len(queries) == 0 || queries[0] != `SELECT DBMS_METADATA.GET_DDL('TABLE', 'test', 'SYS') as ddl FROM DUAL` {
|
||||
t.Fatalf("expected first DDL query to preserve case, got: %v", queries)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOracleGetCreateStatementFallsBackToUppercaseMetadataName(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dbConn, state := openOracleRecordingDB(t)
|
||||
state.mu.Lock()
|
||||
state.queryResults[`SELECT DBMS_METADATA.GET_DDL('TABLE', 'TEST', 'SYS') as ddl FROM DUAL`] = oracleRecordingQueryResult{
|
||||
columns: []string{"DDL"},
|
||||
rows: [][]driver.Value{
|
||||
{`CREATE TABLE "SYS"."TEST" ("ID" NUMBER)`},
|
||||
},
|
||||
}
|
||||
state.mu.Unlock()
|
||||
|
||||
oracleDB := &OracleDB{conn: dbConn}
|
||||
ddl, err := oracleDB.GetCreateStatement("SYS", "test")
|
||||
if err != nil {
|
||||
t.Fatalf("GetCreateStatement 返回错误: %v", err)
|
||||
}
|
||||
if !strings.Contains(ddl, `CREATE TABLE "SYS"."TEST"`) {
|
||||
t.Fatalf("expected uppercase fallback DDL, got: %s", ddl)
|
||||
}
|
||||
|
||||
queries := state.snapshotQueries()
|
||||
if len(queries) < 2 {
|
||||
t.Fatalf("expected original-case query followed by uppercase fallback, got: %v", queries)
|
||||
}
|
||||
if queries[0] != `SELECT DBMS_METADATA.GET_DDL('TABLE', 'test', 'SYS') as ddl FROM DUAL` ||
|
||||
queries[1] != `SELECT DBMS_METADATA.GET_DDL('TABLE', 'TEST', 'SYS') as ddl FROM DUAL` {
|
||||
t.Fatalf("unexpected DDL fallback query order: %v", queries)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOracleGetCreateStatementAppendsTableAndColumnComments(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
||||
@@ -22,6 +22,7 @@ type OracleDB struct {
|
||||
conn *sql.DB
|
||||
pingTimeout time.Duration
|
||||
forwarder *ssh.LocalForwarder // Store SSH tunnel forwarder
|
||||
scanDialect string
|
||||
}
|
||||
|
||||
var _ SessionExecerProvider = (*OracleDB)(nil)
|
||||
@@ -216,7 +217,7 @@ func (o *OracleDB) QueryContext(ctx context.Context, query string) ([]map[string
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
return scanRows(rows)
|
||||
return scanRowsForDialect(rows, o.scanDialect)
|
||||
}
|
||||
|
||||
func (o *OracleDB) Query(query string) ([]map[string]interface{}, []string, error) {
|
||||
@@ -229,7 +230,7 @@ func (o *OracleDB) Query(query string) ([]map[string]interface{}, []string, erro
|
||||
return nil, nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
return scanRows(rows)
|
||||
return scanRowsForDialect(rows, o.scanDialect)
|
||||
}
|
||||
|
||||
func (o *OracleDB) ExecContext(ctx context.Context, query string) (int64, error) {
|
||||
@@ -262,7 +263,7 @@ func (o *OracleDB) OpenTransactionExecer(ctx context.Context) (TransactionExecer
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return NewSQLConnTransactionExecer(conn, "COMMIT", "ROLLBACK"), nil
|
||||
return NewSQLConnTransactionExecerWithDialect(conn, "COMMIT", "ROLLBACK", o.scanDialect), nil
|
||||
}
|
||||
|
||||
func (o *OracleDB) OpenSessionExecer(ctx context.Context) (StatementExecer, error) {
|
||||
@@ -273,7 +274,7 @@ func (o *OracleDB) OpenSessionExecer(ctx context.Context) (StatementExecer, erro
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return NewSQLConnStatementExecer(conn), nil
|
||||
return NewSQLConnStatementExecerWithDialect(conn, o.scanDialect), nil
|
||||
}
|
||||
|
||||
func (o *OracleDB) GetDatabases() ([]string, error) {
|
||||
@@ -325,89 +326,153 @@ func (o *OracleDB) GetTables(dbName string) ([]string, error) {
|
||||
func (o *OracleDB) GetCreateStatement(dbName, tableName string) (string, error) {
|
||||
// Oracle provides DBMS_METADATA.GET_DDL
|
||||
// Note: LONG type might be tricky, but basic string scan should work for smaller DDLs
|
||||
metadataTableName := escapeOracleMetadataLiteral(tableName)
|
||||
metadataSchemaName := escapeOracleMetadataLiteral(dbName)
|
||||
query := fmt.Sprintf("SELECT DBMS_METADATA.GET_DDL('TABLE', '%s', '%s') as ddl FROM DUAL",
|
||||
metadataTableName, metadataSchemaName)
|
||||
var firstErr error
|
||||
for _, candidate := range oracleMetadataNamePairs(dbName, tableName) {
|
||||
metadataTableName := escapeOracleMetadataLiteralExact(candidate.table)
|
||||
metadataSchemaName := escapeOracleMetadataLiteralExact(candidate.schema)
|
||||
query := fmt.Sprintf("SELECT DBMS_METADATA.GET_DDL('TABLE', '%s', '%s') as ddl FROM DUAL",
|
||||
metadataTableName, metadataSchemaName)
|
||||
|
||||
if dbName == "" {
|
||||
query = fmt.Sprintf("SELECT DBMS_METADATA.GET_DDL('TABLE', '%s') as ddl FROM DUAL", metadataTableName)
|
||||
}
|
||||
|
||||
data, _, err := o.Query(query)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if len(data) > 0 {
|
||||
if val, ok := data[0]["DDL"]; ok {
|
||||
return o.appendOracleCommentDDL(fmt.Sprintf("%v", val), dbName, tableName), nil
|
||||
if candidate.schema == "" {
|
||||
query = fmt.Sprintf("SELECT DBMS_METADATA.GET_DDL('TABLE', '%s') as ddl FROM DUAL", metadataTableName)
|
||||
}
|
||||
|
||||
data, _, err := o.Query(query)
|
||||
if err != nil {
|
||||
if firstErr == nil {
|
||||
firstErr = err
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if len(data) > 0 {
|
||||
if val, ok := data[0]["DDL"]; ok {
|
||||
ddl := strings.TrimSpace(fmt.Sprintf("%v", val))
|
||||
if ddl != "" {
|
||||
return o.appendOracleCommentDDL(ddl, candidate.schema, candidate.table), nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if firstErr != nil {
|
||||
return "", firstErr
|
||||
}
|
||||
return "", fmt.Errorf("未找到建表语句")
|
||||
}
|
||||
|
||||
func (o *OracleDB) GetColumns(dbName, tableName string) ([]connection.ColumnDefinition, error) {
|
||||
metadataTableName := escapeOracleMetadataLiteral(tableName)
|
||||
metadataSchemaName := escapeOracleMetadataLiteral(dbName)
|
||||
query := fmt.Sprintf(`SELECT c.column_name AS "COLUMN_NAME", c.data_type AS "DATA_TYPE", c.data_length AS "DATA_LENGTH", c.char_length AS "CHAR_LENGTH", c.data_precision AS "DATA_PRECISION", c.data_scale AS "DATA_SCALE", c.nullable AS "NULLABLE", c.data_default AS "DATA_DEFAULT",
|
||||
CASE WHEN pk.column_name IS NOT NULL THEN 'PRI' ELSE '' END AS "COLUMN_KEY",
|
||||
cc.comments AS "COMMENT"
|
||||
FROM all_tab_columns c
|
||||
LEFT JOIN all_col_comments cc
|
||||
ON cc.owner = c.owner AND cc.table_name = c.table_name AND cc.column_name = c.column_name
|
||||
LEFT JOIN (
|
||||
SELECT cols.owner, cols.table_name, cols.column_name
|
||||
FROM all_constraints cons
|
||||
JOIN all_cons_columns cols
|
||||
ON cons.owner = cols.owner AND cons.constraint_name = cols.constraint_name
|
||||
WHERE cons.constraint_type = 'P'
|
||||
) pk ON c.owner = pk.owner AND c.table_name = pk.table_name AND c.column_name = pk.column_name
|
||||
WHERE c.owner = '%s' AND c.table_name = '%s'
|
||||
ORDER BY c.column_id`, metadataSchemaName, metadataTableName)
|
||||
for _, candidate := range oracleMetadataNamePairs(dbName, tableName) {
|
||||
query := buildOracleColumnsQuery(candidate.schema, candidate.table)
|
||||
data, _, err := o.Query(query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(data) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
if dbName == "" {
|
||||
query = fmt.Sprintf(`SELECT c.column_name AS "COLUMN_NAME", c.data_type AS "DATA_TYPE", c.data_length AS "DATA_LENGTH", c.char_length AS "CHAR_LENGTH", c.data_precision AS "DATA_PRECISION", c.data_scale AS "DATA_SCALE", c.nullable AS "NULLABLE", c.data_default AS "DATA_DEFAULT",
|
||||
CASE WHEN pk.column_name IS NOT NULL THEN 'PRI' ELSE '' END AS "COLUMN_KEY",
|
||||
cc.comments AS "COMMENT"
|
||||
FROM user_tab_columns c
|
||||
LEFT JOIN user_col_comments cc
|
||||
ON cc.table_name = c.table_name AND cc.column_name = c.column_name
|
||||
LEFT JOIN (
|
||||
SELECT cols.table_name, cols.column_name
|
||||
FROM user_constraints cons
|
||||
JOIN user_cons_columns cols USING (constraint_name)
|
||||
WHERE cons.constraint_type = 'P'
|
||||
) pk ON c.table_name = pk.table_name AND c.column_name = pk.column_name
|
||||
WHERE c.table_name = '%s'
|
||||
ORDER BY c.column_id`, metadataTableName)
|
||||
return parseOracleColumns(data), nil
|
||||
}
|
||||
if columns, err := o.inferOracleColumnsFromSelect(dbName, tableName); err == nil && len(columns) > 0 {
|
||||
return columns, nil
|
||||
}
|
||||
return []connection.ColumnDefinition{}, nil
|
||||
}
|
||||
|
||||
func (o *OracleDB) inferOracleColumnsFromSelect(dbName string, tableName string) ([]connection.ColumnDefinition, error) {
|
||||
if o.conn == nil {
|
||||
return nil, fmt.Errorf("连接未打开")
|
||||
}
|
||||
|
||||
data, _, err := o.Query(query)
|
||||
var firstErr error
|
||||
for _, candidate := range oracleMetadataNamePairs(dbName, tableName) {
|
||||
query := "SELECT * FROM " + quoteOracleTableRef(candidate.schema, candidate.table) + " WHERE 1 = 0"
|
||||
rows, err := o.conn.Query(query)
|
||||
if err != nil {
|
||||
if firstErr == nil {
|
||||
firstErr = err
|
||||
}
|
||||
continue
|
||||
}
|
||||
columns, parseErr := oracleColumnsFromSQLRows(rows)
|
||||
closeErr := rows.Close()
|
||||
if parseErr != nil {
|
||||
if firstErr == nil {
|
||||
firstErr = parseErr
|
||||
}
|
||||
continue
|
||||
}
|
||||
if closeErr != nil {
|
||||
if firstErr == nil {
|
||||
firstErr = closeErr
|
||||
}
|
||||
continue
|
||||
}
|
||||
if len(columns) > 0 {
|
||||
return columns, nil
|
||||
}
|
||||
}
|
||||
if firstErr != nil {
|
||||
return nil, firstErr
|
||||
}
|
||||
return nil, fmt.Errorf("未获取到字段定义")
|
||||
}
|
||||
|
||||
func oracleColumnsFromSQLRows(rows *sql.Rows) ([]connection.ColumnDefinition, error) {
|
||||
names, err := rows.Columns()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
colTypes, err := rows.ColumnTypes()
|
||||
if err != nil || len(colTypes) != len(names) {
|
||||
colTypes = nil
|
||||
}
|
||||
|
||||
var columns []connection.ColumnDefinition
|
||||
for _, row := range data {
|
||||
columns := make([]connection.ColumnDefinition, 0, len(names))
|
||||
for idx, name := range names {
|
||||
col := connection.ColumnDefinition{
|
||||
Name: oracleRowString(row, "COLUMN_NAME"),
|
||||
Type: formatOracleColumnType(row),
|
||||
Nullable: oracleRowString(row, "NULLABLE"),
|
||||
Key: oracleRowString(row, "COLUMN_KEY"),
|
||||
Comment: oracleRowString(row, "COMMENT"),
|
||||
Name: strings.TrimSpace(name),
|
||||
Nullable: "",
|
||||
Key: "",
|
||||
Extra: "",
|
||||
Comment: "",
|
||||
}
|
||||
|
||||
if defaultValue := oracleRowValue(row, "DATA_DEFAULT"); defaultValue != nil {
|
||||
d := fmt.Sprintf("%v", defaultValue)
|
||||
col.Default = &d
|
||||
if colTypes != nil && idx < len(colTypes) && colTypes[idx] != nil {
|
||||
col.Type = formatOracleSQLColumnType(colTypes[idx])
|
||||
if nullable, ok := colTypes[idx].Nullable(); ok {
|
||||
if nullable {
|
||||
col.Nullable = "YES"
|
||||
} else {
|
||||
col.Nullable = "NO"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
columns = append(columns, col)
|
||||
}
|
||||
return columns, nil
|
||||
}
|
||||
|
||||
func formatOracleSQLColumnType(colType *sql.ColumnType) string {
|
||||
if colType == nil {
|
||||
return ""
|
||||
}
|
||||
typeName := strings.TrimSpace(colType.DatabaseTypeName())
|
||||
if typeName == "" {
|
||||
return ""
|
||||
}
|
||||
upperType := strings.ToUpper(typeName)
|
||||
if length, ok := colType.Length(); ok && length > 0 && strings.Contains(upperType, "CHAR") {
|
||||
return fmt.Sprintf("%s(%d)", typeName, length)
|
||||
}
|
||||
if precision, scale, ok := colType.DecimalSize(); ok && precision > 0 && (strings.Contains(upperType, "NUMBER") || strings.Contains(upperType, "DECIMAL") || strings.Contains(upperType, "NUMERIC")) {
|
||||
if scale > 0 {
|
||||
return fmt.Sprintf("%s(%d,%d)", typeName, precision, scale)
|
||||
}
|
||||
return fmt.Sprintf("%s(%d)", typeName, precision)
|
||||
}
|
||||
return typeName
|
||||
}
|
||||
|
||||
func oracleRowValue(row map[string]interface{}, names ...string) interface{} {
|
||||
for _, name := range names {
|
||||
if value, ok := row[name]; ok {
|
||||
@@ -482,12 +547,12 @@ func formatOracleColumnType(row map[string]interface{}) string {
|
||||
}
|
||||
|
||||
func (o *OracleDB) appendOracleCommentDDL(baseDDL string, dbName string, tableName string) string {
|
||||
table := strings.ToUpper(strings.TrimSpace(tableName))
|
||||
table := strings.TrimSpace(tableName)
|
||||
if strings.TrimSpace(baseDDL) == "" || table == "" {
|
||||
return baseDDL
|
||||
}
|
||||
|
||||
schema := strings.ToUpper(strings.TrimSpace(dbName))
|
||||
schema := strings.TrimSpace(dbName)
|
||||
tableRef := quoteOracleDDLIdentifier(table)
|
||||
if schema != "" {
|
||||
tableRef = quoteOracleDDLIdentifier(schema) + "." + tableRef
|
||||
@@ -523,10 +588,10 @@ func (o *OracleDB) appendOracleCommentDDL(baseDDL string, dbName string, tableNa
|
||||
}
|
||||
|
||||
func (o *OracleDB) fetchOracleTableComment(schema string, table string) string {
|
||||
escapedTable := escapeOracleMetadataLiteral(table)
|
||||
escapedTable := escapeOracleMetadataLiteralExact(table)
|
||||
var query string
|
||||
if strings.TrimSpace(schema) != "" {
|
||||
query = fmt.Sprintf(`SELECT comments AS "COMMENT" FROM all_tab_comments WHERE owner = '%s' AND table_name = '%s' AND comments IS NOT NULL`, escapeOracleMetadataLiteral(schema), escapedTable)
|
||||
query = fmt.Sprintf(`SELECT comments AS "COMMENT" FROM all_tab_comments WHERE owner = '%s' AND table_name = '%s' AND comments IS NOT NULL`, escapeOracleMetadataLiteralExact(schema), escapedTable)
|
||||
} else {
|
||||
query = fmt.Sprintf(`SELECT comments AS "COMMENT" FROM user_tab_comments WHERE table_name = '%s' AND comments IS NOT NULL`, escapedTable)
|
||||
}
|
||||
@@ -543,7 +608,7 @@ type oracleColumnComment struct {
|
||||
}
|
||||
|
||||
func (o *OracleDB) fetchOracleColumnComments(schema string, table string) []oracleColumnComment {
|
||||
escapedTable := escapeOracleMetadataLiteral(table)
|
||||
escapedTable := escapeOracleMetadataLiteralExact(table)
|
||||
var query string
|
||||
if strings.TrimSpace(schema) != "" {
|
||||
query = fmt.Sprintf(`SELECT c.column_name AS "COLUMN_NAME", cc.comments AS "COMMENT"
|
||||
@@ -551,7 +616,7 @@ FROM all_tab_columns c
|
||||
JOIN all_col_comments cc
|
||||
ON cc.owner = c.owner AND cc.table_name = c.table_name AND cc.column_name = c.column_name
|
||||
WHERE c.owner = '%s' AND c.table_name = '%s' AND cc.comments IS NOT NULL
|
||||
ORDER BY c.column_id`, escapeOracleMetadataLiteral(schema), escapedTable)
|
||||
ORDER BY c.column_id`, escapeOracleMetadataLiteralExact(schema), escapedTable)
|
||||
} else {
|
||||
query = fmt.Sprintf(`SELECT c.column_name AS "COLUMN_NAME", cc.comments AS "COMMENT"
|
||||
FROM user_tab_columns c
|
||||
@@ -579,6 +644,14 @@ func quoteOracleDDLIdentifier(ident string) string {
|
||||
return `"` + strings.ReplaceAll(strings.TrimSpace(ident), `"`, `""`) + `"`
|
||||
}
|
||||
|
||||
func quoteOracleTableRef(schema string, table string) string {
|
||||
tableRef := quoteOracleDDLIdentifier(table)
|
||||
if strings.TrimSpace(schema) != "" {
|
||||
return quoteOracleDDLIdentifier(schema) + "." + tableRef
|
||||
}
|
||||
return tableRef
|
||||
}
|
||||
|
||||
func escapeOracleCommentLiteral(text string) string {
|
||||
return strings.ReplaceAll(text, "'", "''")
|
||||
}
|
||||
@@ -587,14 +660,132 @@ func escapeOracleMetadataLiteral(text string) string {
|
||||
return strings.ReplaceAll(strings.ToUpper(strings.TrimSpace(text)), "'", "''")
|
||||
}
|
||||
|
||||
func escapeOracleMetadataLiteralExact(text string) string {
|
||||
return strings.ReplaceAll(strings.TrimSpace(text), "'", "''")
|
||||
}
|
||||
|
||||
type oracleMetadataNamePair struct {
|
||||
schema string
|
||||
table string
|
||||
}
|
||||
|
||||
func oracleMetadataNamePairs(dbName string, tableName string) []oracleMetadataNamePair {
|
||||
rawSchema := strings.TrimSpace(dbName)
|
||||
rawTable := strings.TrimSpace(tableName)
|
||||
if rawTable == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
upperSchema := strings.ToUpper(rawSchema)
|
||||
upperTable := strings.ToUpper(rawTable)
|
||||
pairs := make([]oracleMetadataNamePair, 0, 4)
|
||||
seen := map[string]struct{}{}
|
||||
add := func(schema string, table string) {
|
||||
key := schema + "\x00" + table
|
||||
if _, ok := seen[key]; ok {
|
||||
return
|
||||
}
|
||||
seen[key] = struct{}{}
|
||||
pairs = append(pairs, oracleMetadataNamePair{schema: schema, table: table})
|
||||
}
|
||||
|
||||
add(rawSchema, rawTable)
|
||||
add(upperSchema, upperTable)
|
||||
add(rawSchema, upperTable)
|
||||
add(upperSchema, rawTable)
|
||||
return pairs
|
||||
}
|
||||
|
||||
func buildOracleColumnsQuery(schema string, table string) string {
|
||||
metadataTableName := escapeOracleMetadataLiteralExact(table)
|
||||
metadataSchemaName := escapeOracleMetadataLiteralExact(schema)
|
||||
if strings.TrimSpace(schema) == "" {
|
||||
return fmt.Sprintf(`SELECT c.column_name AS "COLUMN_NAME", c.data_type AS "DATA_TYPE", c.data_length AS "DATA_LENGTH", c.char_length AS "CHAR_LENGTH", c.data_precision AS "DATA_PRECISION", c.data_scale AS "DATA_SCALE", c.nullable AS "NULLABLE", c.data_default AS "DATA_DEFAULT",
|
||||
CASE WHEN pk.column_name IS NOT NULL THEN 'PRI' ELSE '' END AS "COLUMN_KEY",
|
||||
cc.comments AS "COMMENT"
|
||||
FROM user_tab_columns c
|
||||
LEFT JOIN user_col_comments cc
|
||||
ON cc.table_name = c.table_name AND cc.column_name = c.column_name
|
||||
LEFT JOIN (
|
||||
SELECT cols.table_name, cols.column_name
|
||||
FROM user_constraints cons
|
||||
JOIN user_cons_columns cols USING (constraint_name)
|
||||
WHERE cons.constraint_type = 'P'
|
||||
) pk ON c.table_name = pk.table_name AND c.column_name = pk.column_name
|
||||
WHERE c.table_name = '%s'
|
||||
ORDER BY c.column_id`, metadataTableName)
|
||||
}
|
||||
|
||||
return fmt.Sprintf(`SELECT c.column_name AS "COLUMN_NAME", c.data_type AS "DATA_TYPE", c.data_length AS "DATA_LENGTH", c.char_length AS "CHAR_LENGTH", c.data_precision AS "DATA_PRECISION", c.data_scale AS "DATA_SCALE", c.nullable AS "NULLABLE", c.data_default AS "DATA_DEFAULT",
|
||||
CASE WHEN pk.column_name IS NOT NULL THEN 'PRI' ELSE '' END AS "COLUMN_KEY",
|
||||
cc.comments AS "COMMENT"
|
||||
FROM all_tab_columns c
|
||||
LEFT JOIN all_col_comments cc
|
||||
ON cc.owner = c.owner AND cc.table_name = c.table_name AND cc.column_name = c.column_name
|
||||
LEFT JOIN (
|
||||
SELECT cols.owner, cols.table_name, cols.column_name
|
||||
FROM all_constraints cons
|
||||
JOIN all_cons_columns cols
|
||||
ON cons.owner = cols.owner AND cons.constraint_name = cols.constraint_name
|
||||
WHERE cons.constraint_type = 'P'
|
||||
) pk ON c.owner = pk.owner AND c.table_name = pk.table_name AND c.column_name = pk.column_name
|
||||
WHERE c.owner = '%s' AND c.table_name = '%s'
|
||||
ORDER BY c.column_id`, metadataSchemaName, metadataTableName)
|
||||
}
|
||||
|
||||
func parseOracleColumns(data []map[string]interface{}) []connection.ColumnDefinition {
|
||||
columns := make([]connection.ColumnDefinition, 0, len(data))
|
||||
for _, row := range data {
|
||||
col := connection.ColumnDefinition{
|
||||
Name: oracleRowString(row, "COLUMN_NAME"),
|
||||
Type: formatOracleColumnType(row),
|
||||
Nullable: oracleRowString(row, "NULLABLE"),
|
||||
Key: oracleRowString(row, "COLUMN_KEY"),
|
||||
Comment: oracleRowString(row, "COMMENT"),
|
||||
}
|
||||
|
||||
if defaultValue := oracleRowValue(row, "DATA_DEFAULT"); defaultValue != nil {
|
||||
d := fmt.Sprintf("%v", defaultValue)
|
||||
col.Default = &d
|
||||
}
|
||||
|
||||
columns = append(columns, col)
|
||||
}
|
||||
return columns
|
||||
}
|
||||
|
||||
func (o *OracleDB) GetIndexes(dbName, tableName string) ([]connection.IndexDefinition, error) {
|
||||
esc := func(s string) string { return strings.ReplaceAll(strings.ToUpper(strings.TrimSpace(s)), "'", "''") }
|
||||
table := esc(tableName)
|
||||
if table == "" {
|
||||
if strings.TrimSpace(tableName) == "" {
|
||||
return nil, fmt.Errorf("表名不能为空")
|
||||
}
|
||||
|
||||
query := fmt.Sprintf(`SELECT c.index_name, c.column_name, i.uniqueness, c.column_position, i.index_type
|
||||
for _, candidate := range oracleMetadataNamePairs(dbName, tableName) {
|
||||
data, _, err := o.Query(buildOracleIndexesQuery(candidate.schema, candidate.table))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(data) == 0 {
|
||||
continue
|
||||
}
|
||||
return parseOracleIndexes(data), nil
|
||||
}
|
||||
return []connection.IndexDefinition{}, nil
|
||||
}
|
||||
|
||||
func buildOracleIndexesQuery(schema string, table string) string {
|
||||
metadataTableName := escapeOracleMetadataLiteralExact(table)
|
||||
metadataSchemaName := escapeOracleMetadataLiteralExact(schema)
|
||||
if strings.TrimSpace(schema) == "" {
|
||||
return fmt.Sprintf(`SELECT c.index_name, c.column_name, i.uniqueness, c.column_position, i.index_type
|
||||
FROM user_ind_columns c
|
||||
JOIN user_indexes i ON i.index_name = c.index_name
|
||||
WHERE c.table_name = '%s'
|
||||
AND c.column_name IS NOT NULL
|
||||
AND c.column_name NOT LIKE 'SYS_NC%%$'
|
||||
AND i.index_type NOT LIKE 'FUNCTION-BASED%%'
|
||||
ORDER BY c.index_name, c.column_position`, metadataTableName)
|
||||
}
|
||||
return fmt.Sprintf(`SELECT c.index_name, c.column_name, i.uniqueness, c.column_position, i.index_type
|
||||
FROM all_ind_columns c
|
||||
JOIN all_indexes i ON i.owner = c.index_owner AND i.index_name = c.index_name
|
||||
WHERE c.table_owner = '%s'
|
||||
@@ -602,24 +793,10 @@ func (o *OracleDB) GetIndexes(dbName, tableName string) ([]connection.IndexDefin
|
||||
AND c.column_name IS NOT NULL
|
||||
AND c.column_name NOT LIKE 'SYS_NC%%$'
|
||||
AND i.index_type NOT LIKE 'FUNCTION-BASED%%'
|
||||
ORDER BY c.index_name, c.column_position`, esc(dbName), table)
|
||||
|
||||
if strings.TrimSpace(dbName) == "" {
|
||||
query = fmt.Sprintf(`SELECT c.index_name, c.column_name, i.uniqueness, c.column_position, i.index_type
|
||||
FROM user_ind_columns c
|
||||
JOIN user_indexes i ON i.index_name = c.index_name
|
||||
WHERE c.table_name = '%s'
|
||||
AND c.column_name IS NOT NULL
|
||||
AND c.column_name NOT LIKE 'SYS_NC%%$'
|
||||
AND i.index_type NOT LIKE 'FUNCTION-BASED%%'
|
||||
ORDER BY c.index_name, c.column_position`, table)
|
||||
}
|
||||
|
||||
data, _, err := o.Query(query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ORDER BY c.index_name, c.column_position`, metadataSchemaName, metadataTableName)
|
||||
}
|
||||
|
||||
func parseOracleIndexes(data []map[string]interface{}) []connection.IndexDefinition {
|
||||
getValue := func(row map[string]interface{}, names ...string) interface{} {
|
||||
for _, name := range names {
|
||||
if value, ok := row[name]; ok {
|
||||
@@ -663,24 +840,44 @@ func (o *OracleDB) GetIndexes(dbName, tableName string) ([]connection.IndexDefin
|
||||
}
|
||||
indexes = append(indexes, idx)
|
||||
}
|
||||
return indexes, nil
|
||||
return indexes
|
||||
}
|
||||
|
||||
func (o *OracleDB) GetForeignKeys(dbName, tableName string) ([]connection.ForeignKeyDefinition, error) {
|
||||
// Simplified query for FKs
|
||||
query := fmt.Sprintf(`SELECT a.constraint_name, a.column_name, c_pk.table_name r_table_name, b.column_name r_column_name
|
||||
for _, candidate := range oracleMetadataNamePairs(dbName, tableName) {
|
||||
data, _, err := o.Query(buildOracleForeignKeysQuery(candidate.schema, candidate.table))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(data) == 0 {
|
||||
continue
|
||||
}
|
||||
return parseOracleForeignKeys(data), nil
|
||||
}
|
||||
return []connection.ForeignKeyDefinition{}, nil
|
||||
}
|
||||
|
||||
func buildOracleForeignKeysQuery(schema string, table string) string {
|
||||
metadataTableName := escapeOracleMetadataLiteralExact(table)
|
||||
metadataSchemaName := escapeOracleMetadataLiteralExact(schema)
|
||||
if strings.TrimSpace(schema) == "" {
|
||||
return fmt.Sprintf(`SELECT a.constraint_name, a.column_name, c_pk.table_name r_table_name, b.column_name r_column_name
|
||||
FROM user_cons_columns a
|
||||
JOIN user_constraints c ON a.constraint_name = c.constraint_name
|
||||
JOIN user_constraints c_pk ON c.r_constraint_name = c_pk.constraint_name
|
||||
JOIN user_cons_columns b ON c_pk.constraint_name = b.constraint_name AND a.position = b.position
|
||||
WHERE c.constraint_type = 'R' AND a.table_name = '%s'`, metadataTableName)
|
||||
}
|
||||
return fmt.Sprintf(`SELECT a.constraint_name, a.column_name, c_pk.table_name r_table_name, b.column_name r_column_name
|
||||
FROM all_cons_columns a
|
||||
JOIN all_constraints c ON a.owner = c.owner AND a.constraint_name = c.constraint_name
|
||||
JOIN all_constraints c_pk ON c.r_owner = c_pk.owner AND c.r_constraint_name = c_pk.constraint_name
|
||||
JOIN all_cons_columns b ON c_pk.owner = b.owner AND c_pk.constraint_name = b.constraint_name AND a.position = b.position
|
||||
WHERE c.constraint_type = 'R' AND a.owner = '%s' AND a.table_name = '%s'`,
|
||||
strings.ToUpper(dbName), strings.ToUpper(tableName))
|
||||
|
||||
data, _, err := o.Query(query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
metadataSchemaName, metadataTableName)
|
||||
}
|
||||
|
||||
func parseOracleForeignKeys(data []map[string]interface{}) []connection.ForeignKeyDefinition {
|
||||
var fks []connection.ForeignKeyDefinition
|
||||
for _, row := range data {
|
||||
fk := connection.ForeignKeyDefinition{
|
||||
@@ -692,20 +889,38 @@ func (o *OracleDB) GetForeignKeys(dbName, tableName string) ([]connection.Foreig
|
||||
}
|
||||
fks = append(fks, fk)
|
||||
}
|
||||
return fks, nil
|
||||
return fks
|
||||
}
|
||||
|
||||
func (o *OracleDB) GetTriggers(dbName, tableName string) ([]connection.TriggerDefinition, error) {
|
||||
query := fmt.Sprintf(`SELECT trigger_name, trigger_type, triggering_event
|
||||
FROM all_triggers
|
||||
WHERE table_owner = '%s' AND table_name = '%s'`,
|
||||
strings.ToUpper(dbName), strings.ToUpper(tableName))
|
||||
|
||||
data, _, err := o.Query(query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
for _, candidate := range oracleMetadataNamePairs(dbName, tableName) {
|
||||
data, _, err := o.Query(buildOracleTriggersQuery(candidate.schema, candidate.table))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(data) == 0 {
|
||||
continue
|
||||
}
|
||||
return parseOracleTriggers(data), nil
|
||||
}
|
||||
return []connection.TriggerDefinition{}, nil
|
||||
}
|
||||
|
||||
func buildOracleTriggersQuery(schema string, table string) string {
|
||||
metadataTableName := escapeOracleMetadataLiteralExact(table)
|
||||
metadataSchemaName := escapeOracleMetadataLiteralExact(schema)
|
||||
if strings.TrimSpace(schema) == "" {
|
||||
return fmt.Sprintf(`SELECT trigger_name, trigger_type, triggering_event
|
||||
FROM user_triggers
|
||||
WHERE table_name = '%s'`, metadataTableName)
|
||||
}
|
||||
return fmt.Sprintf(`SELECT trigger_name, trigger_type, triggering_event
|
||||
FROM all_triggers
|
||||
WHERE table_owner = '%s' AND table_name = '%s'`,
|
||||
metadataSchemaName, metadataTableName)
|
||||
}
|
||||
|
||||
func parseOracleTriggers(data []map[string]interface{}) []connection.TriggerDefinition {
|
||||
var triggers []connection.TriggerDefinition
|
||||
for _, row := range data {
|
||||
trig := connection.TriggerDefinition{
|
||||
@@ -716,7 +931,7 @@ func (o *OracleDB) GetTriggers(dbName, tableName string) ([]connection.TriggerDe
|
||||
}
|
||||
triggers = append(triggers, trig)
|
||||
}
|
||||
return triggers, nil
|
||||
return triggers
|
||||
}
|
||||
|
||||
func splitOracleQualifiedTableName(raw string) (string, string) {
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
@@ -16,9 +17,10 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
jsMaxSafeInteger int64 = 9007199254740991
|
||||
jsMinSafeInteger int64 = -9007199254740991
|
||||
jsMaxSafeUint uint64 = 9007199254740991
|
||||
jsMaxSafeInteger int64 = 9007199254740991
|
||||
jsMinSafeInteger int64 = -9007199254740991
|
||||
jsMaxSafeUint uint64 = 9007199254740991
|
||||
oceanBaseOracleScanDialect = "oceanbase-oracle"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -40,7 +42,15 @@ func normalizeQueryValueWithDBTypeAndDialect(v interface{}, databaseTypeName, di
|
||||
if tm, ok := v.(time.Time); ok {
|
||||
return normalizeTemporalValueForDisplay(tm, databaseTypeName, dialect)
|
||||
}
|
||||
if s, ok := v.(string); ok {
|
||||
if tm, normalizedType, ok := decodeOceanBaseOracleTemporalString(s, databaseTypeName, dialect); ok {
|
||||
return normalizeTemporalValueForDisplay(tm, normalizedType, dialect)
|
||||
}
|
||||
}
|
||||
if b, ok := v.([]byte); ok {
|
||||
if tm, normalizedType, ok := decodeOceanBaseOracleTemporalBytes(b, databaseTypeName, dialect); ok {
|
||||
return normalizeTemporalValueForDisplay(tm, normalizedType, dialect)
|
||||
}
|
||||
return bytesToDisplayValue(b, databaseTypeName)
|
||||
}
|
||||
return normalizeCompositeQueryValue(v)
|
||||
@@ -52,15 +62,19 @@ func normalizeTemporalValueForDisplay(value time.Time, databaseTypeName, dialect
|
||||
return zeroValue
|
||||
}
|
||||
}
|
||||
if shouldDisplayTemporalValueAsDateOnly(databaseTypeName, dialect) {
|
||||
if shouldDisplayTemporalValueAsDateOnly(databaseTypeName, dialect) || shouldDisplayOceanBaseOracleDateAsDateOnly(value, databaseTypeName, dialect) {
|
||||
return value.Format("2006-01-02")
|
||||
}
|
||||
return value.Format(time.RFC3339Nano)
|
||||
}
|
||||
|
||||
func shouldDisplayTemporalValueAsDateOnly(databaseTypeName, dialect string) bool {
|
||||
func isDateOnlyDatabaseTypeName(databaseTypeName string) bool {
|
||||
typeName := strings.ToUpper(strings.TrimSpace(databaseTypeName))
|
||||
if typeName != "DATE" && typeName != "NEWDATE" {
|
||||
return typeName == "DATE" || typeName == "NEWDATE"
|
||||
}
|
||||
|
||||
func shouldDisplayTemporalValueAsDateOnly(databaseTypeName, dialect string) bool {
|
||||
if !isDateOnlyDatabaseTypeName(databaseTypeName) {
|
||||
return false
|
||||
}
|
||||
switch strings.ToLower(strings.TrimSpace(dialect)) {
|
||||
@@ -71,6 +85,317 @@ func shouldDisplayTemporalValueAsDateOnly(databaseTypeName, dialect string) bool
|
||||
}
|
||||
}
|
||||
|
||||
func shouldDisplayOceanBaseOracleDateAsDateOnly(value time.Time, databaseTypeName, dialect string) bool {
|
||||
if !isDateOnlyDatabaseTypeName(databaseTypeName) {
|
||||
return false
|
||||
}
|
||||
if strings.ToLower(strings.TrimSpace(dialect)) != oceanBaseOracleScanDialect {
|
||||
return false
|
||||
}
|
||||
return value.Hour() == 0 && value.Minute() == 0 && value.Second() == 0 && value.Nanosecond() == 0
|
||||
}
|
||||
|
||||
func decodeOceanBaseOracleTemporalString(value string, databaseTypeName, dialect string) (time.Time, string, bool) {
|
||||
return decodeOceanBaseOracleTemporalBytes([]byte(value), databaseTypeName, dialect)
|
||||
}
|
||||
|
||||
func decodeOceanBaseOracleTemporalBytes(value []byte, databaseTypeName, dialect string) (time.Time, string, bool) {
|
||||
if strings.ToLower(strings.TrimSpace(dialect)) != oceanBaseOracleScanDialect {
|
||||
return time.Time{}, "", false
|
||||
}
|
||||
if !shouldAttemptOceanBaseOracleTemporalDecode(databaseTypeName, value) {
|
||||
return time.Time{}, "", false
|
||||
}
|
||||
return parseOceanBaseOracleTemporal(value, databaseTypeName)
|
||||
}
|
||||
|
||||
func shouldAttemptOceanBaseOracleTemporalDecode(databaseTypeName string, value []byte) bool {
|
||||
if isOceanBaseOracleTemporalDatabaseTypeName(databaseTypeName) {
|
||||
return true
|
||||
}
|
||||
if !hasOceanBaseOracleTemporalEncodedLength(value) {
|
||||
return false
|
||||
}
|
||||
typeName := strings.ToUpper(strings.TrimSpace(databaseTypeName))
|
||||
if typeName == "" {
|
||||
return true
|
||||
}
|
||||
if isLikelyOceanBaseOracleTemporalCarrierType(typeName) {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func hasOceanBaseOracleTemporalEncodedLength(value []byte) bool {
|
||||
switch len(value) {
|
||||
case 5, 7, 8, 11, 12, 13:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func isLikelyOceanBaseOracleTemporalCarrierType(typeName string) bool {
|
||||
if typeName == "" {
|
||||
return false
|
||||
}
|
||||
switch {
|
||||
case strings.Contains(typeName, "CHAR"),
|
||||
strings.Contains(typeName, "TEXT"),
|
||||
strings.Contains(typeName, "STRING"),
|
||||
strings.Contains(typeName, "BINARY"),
|
||||
strings.Contains(typeName, "VARBINARY"),
|
||||
strings.Contains(typeName, "RAW"),
|
||||
strings.Contains(typeName, "BLOB"),
|
||||
strings.Contains(typeName, "LOB"):
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func isOceanBaseOracleTemporalDatabaseTypeName(databaseTypeName string) bool {
|
||||
typeName := strings.ToUpper(strings.TrimSpace(databaseTypeName))
|
||||
if typeName == "DATE" || typeName == "TYPE_CA" {
|
||||
return true
|
||||
}
|
||||
return strings.HasPrefix(typeName, "TIMESTAMP")
|
||||
}
|
||||
|
||||
func parseOceanBaseOracleTemporal(value []byte, databaseTypeName string) (time.Time, string, bool) {
|
||||
if tm, normalizedType, ok := parseOracleBinaryTemporal(value, databaseTypeName); ok {
|
||||
return tm, normalizedType, true
|
||||
}
|
||||
if tm, normalizedType, ok := parseOceanBaseOracleTypeCATemporal(value, databaseTypeName); ok {
|
||||
return tm, normalizedType, true
|
||||
}
|
||||
return parseMySQLLengthEncodedTemporal(value, databaseTypeName)
|
||||
}
|
||||
|
||||
func parseOceanBaseOracleTypeCATemporal(value []byte, databaseTypeName string) (time.Time, string, bool) {
|
||||
if len(value) != 12 {
|
||||
return time.Time{}, "", false
|
||||
}
|
||||
|
||||
yearHigh := int(value[0])
|
||||
yearLow := int(value[1])
|
||||
month := int(value[2])
|
||||
day := int(value[3])
|
||||
hour := int(value[4])
|
||||
minute := int(value[5])
|
||||
second := int(value[6])
|
||||
nsec := int(binary.LittleEndian.Uint32(value[7:11]))
|
||||
scale := int(value[11])
|
||||
|
||||
if yearHigh < 0 || yearHigh > 99 || yearLow < 0 || yearLow > 99 {
|
||||
return time.Time{}, "", false
|
||||
}
|
||||
if month < 1 || month > 12 || day < 1 || day > 31 || hour < 0 || hour > 23 || minute < 0 || minute > 59 || second < 0 || second > 59 {
|
||||
return time.Time{}, "", false
|
||||
}
|
||||
if scale < 0 || scale > 9 || nsec < 0 || nsec >= 1_000_000_000 {
|
||||
return time.Time{}, "", false
|
||||
}
|
||||
if !matchesTemporalScale(nsec, scale) {
|
||||
return time.Time{}, "", false
|
||||
}
|
||||
|
||||
year := yearHigh*100 + yearLow
|
||||
parsed := time.Date(year, time.Month(month), day, hour, minute, second, nsec, time.UTC)
|
||||
if parsed.Year() != year || int(parsed.Month()) != month || parsed.Day() != day ||
|
||||
parsed.Hour() != hour || parsed.Minute() != minute || parsed.Second() != second || parsed.Nanosecond() != nsec {
|
||||
return time.Time{}, "", false
|
||||
}
|
||||
return parsed, normalizeOracleTemporalDatabaseTypeName(databaseTypeName), true
|
||||
}
|
||||
|
||||
func matchesTemporalScale(nsec, scale int) bool {
|
||||
if scale >= 9 {
|
||||
return true
|
||||
}
|
||||
step := 1
|
||||
for i := 0; i < 9-scale; i++ {
|
||||
step *= 10
|
||||
}
|
||||
return nsec%step == 0
|
||||
}
|
||||
|
||||
func parseOracleBinaryTemporal(value []byte, databaseTypeName string) (time.Time, string, bool) {
|
||||
switch len(value) {
|
||||
case 7:
|
||||
tm, ok := parseOracleBinaryDateTime(value[:7])
|
||||
return tm, "DATE", ok
|
||||
case 11:
|
||||
tm, ok := parseOracleBinaryTimestamp(value)
|
||||
return tm, normalizeOracleTemporalDatabaseTypeName(databaseTypeName), ok
|
||||
case 13:
|
||||
tm, ok := parseOracleBinaryTimestampWithTimezone(value)
|
||||
return tm, normalizeOracleTemporalDatabaseTypeName(databaseTypeName), ok
|
||||
default:
|
||||
return time.Time{}, "", false
|
||||
}
|
||||
}
|
||||
|
||||
func parseMySQLLengthEncodedTemporal(value []byte, databaseTypeName string) (time.Time, string, bool) {
|
||||
if len(value) == 0 {
|
||||
return time.Time{}, "", false
|
||||
}
|
||||
payloadLength := int(value[0])
|
||||
if payloadLength != len(value)-1 {
|
||||
return time.Time{}, "", false
|
||||
}
|
||||
|
||||
switch payloadLength {
|
||||
case 4:
|
||||
tm, ok := parseMySQLBinaryDateTimePayload(value[1:], false)
|
||||
return tm, "DATE", ok
|
||||
case 7:
|
||||
tm, ok := parseMySQLBinaryDateTimePayload(value[1:], false)
|
||||
return tm, normalizeOracleTemporalDatabaseTypeName(databaseTypeName), ok
|
||||
case 11:
|
||||
tm, ok := parseMySQLBinaryDateTimePayload(value[1:], true)
|
||||
return tm, normalizeOracleTemporalDatabaseTypeName(databaseTypeName), ok
|
||||
default:
|
||||
return time.Time{}, "", false
|
||||
}
|
||||
}
|
||||
|
||||
func parseMySQLBinaryDateTimePayload(value []byte, withFraction bool) (time.Time, bool) {
|
||||
expectedLength := 7
|
||||
if !withFraction {
|
||||
switch len(value) {
|
||||
case 4:
|
||||
year := int(binary.LittleEndian.Uint16(value[0:2]))
|
||||
month := int(value[2])
|
||||
day := int(value[3])
|
||||
if year < 0 || month < 1 || month > 12 || day < 1 || day > 31 {
|
||||
return time.Time{}, false
|
||||
}
|
||||
parsed := time.Date(year, time.Month(month), day, 0, 0, 0, 0, time.UTC)
|
||||
if parsed.Year() != year || int(parsed.Month()) != month || parsed.Day() != day {
|
||||
return time.Time{}, false
|
||||
}
|
||||
return parsed, true
|
||||
case expectedLength:
|
||||
default:
|
||||
return time.Time{}, false
|
||||
}
|
||||
} else if len(value) != 11 {
|
||||
return time.Time{}, false
|
||||
}
|
||||
|
||||
year := int(binary.LittleEndian.Uint16(value[0:2]))
|
||||
month := int(value[2])
|
||||
day := int(value[3])
|
||||
hour := int(value[4])
|
||||
minute := int(value[5])
|
||||
second := int(value[6])
|
||||
nsec := 0
|
||||
if withFraction {
|
||||
usec := binary.LittleEndian.Uint32(value[7:11])
|
||||
if usec >= 1_000_000 {
|
||||
return time.Time{}, false
|
||||
}
|
||||
nsec = int(usec) * 1000
|
||||
}
|
||||
|
||||
if year < 0 || month < 1 || month > 12 || day < 1 || day > 31 || hour < 0 || hour > 23 || minute < 0 || minute > 59 || second < 0 || second > 59 {
|
||||
return time.Time{}, false
|
||||
}
|
||||
|
||||
parsed := time.Date(year, time.Month(month), day, hour, minute, second, nsec, time.UTC)
|
||||
if parsed.Year() != year || int(parsed.Month()) != month || parsed.Day() != day ||
|
||||
parsed.Hour() != hour || parsed.Minute() != minute || parsed.Second() != second || parsed.Nanosecond() != nsec {
|
||||
return time.Time{}, false
|
||||
}
|
||||
return parsed, true
|
||||
}
|
||||
|
||||
func normalizeOracleTemporalDatabaseTypeName(databaseTypeName string) string {
|
||||
typeName := strings.ToUpper(strings.TrimSpace(databaseTypeName))
|
||||
switch typeName {
|
||||
case "TYPE_CA":
|
||||
return "TIMESTAMP"
|
||||
default:
|
||||
if typeName == "" {
|
||||
return "TIMESTAMP"
|
||||
}
|
||||
return typeName
|
||||
}
|
||||
}
|
||||
|
||||
func parseOracleBinaryTimestamp(value []byte) (time.Time, bool) {
|
||||
if len(value) != 11 {
|
||||
return time.Time{}, false
|
||||
}
|
||||
baseTime, ok := parseOracleBinaryDateTime(value[:7])
|
||||
if !ok {
|
||||
return time.Time{}, false
|
||||
}
|
||||
nsec := binary.BigEndian.Uint32(value[7:11])
|
||||
if nsec >= 1_000_000_000 {
|
||||
return time.Time{}, false
|
||||
}
|
||||
return time.Date(
|
||||
baseTime.Year(),
|
||||
baseTime.Month(),
|
||||
baseTime.Day(),
|
||||
baseTime.Hour(),
|
||||
baseTime.Minute(),
|
||||
baseTime.Second(),
|
||||
int(nsec),
|
||||
time.UTC,
|
||||
), true
|
||||
}
|
||||
|
||||
func parseOracleBinaryTimestampWithTimezone(value []byte) (time.Time, bool) {
|
||||
if len(value) != 13 {
|
||||
return time.Time{}, false
|
||||
}
|
||||
baseTime, ok := parseOracleBinaryTimestamp(value[:11])
|
||||
if !ok {
|
||||
return time.Time{}, false
|
||||
}
|
||||
tzHour := int(value[11]) - 20
|
||||
tzMinute := int(value[12]) - 60
|
||||
if tzHour < -12 || tzHour > 14 || tzMinute < 0 || tzMinute >= 60 {
|
||||
return time.Time{}, false
|
||||
}
|
||||
location := time.FixedZone("", tzHour*3600+tzMinute*60)
|
||||
return time.Date(
|
||||
baseTime.Year(),
|
||||
baseTime.Month(),
|
||||
baseTime.Day(),
|
||||
baseTime.Hour(),
|
||||
baseTime.Minute(),
|
||||
baseTime.Second(),
|
||||
baseTime.Nanosecond(),
|
||||
location,
|
||||
), true
|
||||
}
|
||||
|
||||
func parseOracleBinaryDateTime(value []byte) (time.Time, bool) {
|
||||
if len(value) != 7 {
|
||||
return time.Time{}, false
|
||||
}
|
||||
year := (int(value[0]) - 100) * 100
|
||||
year += int(value[1]) - 100
|
||||
month := int(value[2])
|
||||
day := int(value[3])
|
||||
hour := int(value[4]) - 1
|
||||
minute := int(value[5]) - 1
|
||||
second := int(value[6]) - 1
|
||||
if year < 0 || month < 1 || month > 12 || day < 1 || day > 31 || hour < 0 || hour > 23 || minute < 0 || minute > 59 || second < 0 || second > 59 {
|
||||
return time.Time{}, false
|
||||
}
|
||||
parsed := time.Date(year, time.Month(month), day, hour, minute, second, 0, time.UTC)
|
||||
if parsed.Year() != year || int(parsed.Month()) != month || parsed.Day() != day ||
|
||||
parsed.Hour() != hour || parsed.Minute() != minute || parsed.Second() != second {
|
||||
return time.Time{}, false
|
||||
}
|
||||
return parsed, true
|
||||
}
|
||||
|
||||
func zeroTemporalDisplayValue(databaseTypeName string) (string, bool) {
|
||||
typeName := strings.ToUpper(strings.TrimSpace(databaseTypeName))
|
||||
if typeName == "" {
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"testing"
|
||||
@@ -220,6 +221,154 @@ func TestNormalizeQueryValueWithDBType_TimeStructToRFC3339(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeQueryValueWithDBTypeAndDialect_OceanBaseOracleBinaryTimestampString(t *testing.T) {
|
||||
raw := buildOracleBinaryTimestamp(time.Date(2026, 6, 16, 12, 34, 56, 123456000, time.UTC))
|
||||
|
||||
got := normalizeQueryValueWithDBTypeAndDialect(string(raw), "TYPE_CA", oceanBaseOracleScanDialect)
|
||||
if got != "2026-06-16T12:34:56.123456Z" {
|
||||
t.Fatalf("OceanBase Oracle 二进制 TIMESTAMP 字符串应解码为 RFC3339,实际=%v(%T)", got, got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeQueryValueWithDBTypeAndDialect_OceanBaseOracleBinaryTimestampStringWithPrecisionType(t *testing.T) {
|
||||
raw := buildOracleBinaryTimestamp(time.Date(2026, 6, 16, 12, 34, 56, 123456000, time.UTC))
|
||||
|
||||
got := normalizeQueryValueWithDBTypeAndDialect(string(raw), "TIMESTAMP(6)", oceanBaseOracleScanDialect)
|
||||
if got != "2026-06-16T12:34:56.123456Z" {
|
||||
t.Fatalf("OceanBase Oracle TIMESTAMP(6) 字符串应解码为 RFC3339,实际=%v(%T)", got, got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeQueryValueWithDBTypeAndDialect_OceanBaseOracleBinaryTimestampStringWithGenericCarrierType(t *testing.T) {
|
||||
raw := buildOracleBinaryTimestamp(time.Date(2026, 6, 16, 12, 34, 56, 123456000, time.UTC))
|
||||
|
||||
got := normalizeQueryValueWithDBTypeAndDialect(string(raw), "VARCHAR2", oceanBaseOracleScanDialect)
|
||||
if got != "2026-06-16T12:34:56.123456Z" {
|
||||
t.Fatalf("OceanBase Oracle 泛型载体类型的 TIMESTAMP 字符串应解码为 RFC3339,实际=%v(%T)", got, got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeQueryValueWithDBTypeAndDialect_OceanBaseOracleBinaryTimestampBytes(t *testing.T) {
|
||||
raw := buildOracleBinaryTimestamp(time.Date(2026, 6, 16, 12, 34, 56, 123456000, time.UTC))
|
||||
|
||||
got := normalizeQueryValueWithDBTypeAndDialect(raw, "TYPE_CA", oceanBaseOracleScanDialect)
|
||||
if got != "2026-06-16T12:34:56.123456Z" {
|
||||
t.Fatalf("OceanBase Oracle 二进制 TIMESTAMP 字节值应解码为 RFC3339,实际=%v(%T)", got, got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeQueryValueWithDBTypeAndDialect_OceanBaseOracleBinaryTimestampBytesWithGenericCarrierType(t *testing.T) {
|
||||
raw := buildOracleBinaryTimestamp(time.Date(2026, 6, 16, 12, 34, 56, 123456000, time.UTC))
|
||||
|
||||
got := normalizeQueryValueWithDBTypeAndDialect(raw, "VARCHAR2", oceanBaseOracleScanDialect)
|
||||
if got != "2026-06-16T12:34:56.123456Z" {
|
||||
t.Fatalf("OceanBase Oracle 泛型载体类型的 TIMESTAMP 字节值应解码为 RFC3339,实际=%v(%T)", got, got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeQueryValueWithDBTypeAndDialect_OceanBaseOracleBinaryTimestampBytesWithPrecisionType(t *testing.T) {
|
||||
raw := buildOracleBinaryTimestamp(time.Date(2026, 6, 16, 12, 34, 56, 123456000, time.UTC))
|
||||
|
||||
got := normalizeQueryValueWithDBTypeAndDialect(raw, "TIMESTAMP(6)", oceanBaseOracleScanDialect)
|
||||
if got != "2026-06-16T12:34:56.123456Z" {
|
||||
t.Fatalf("OceanBase Oracle TIMESTAMP(6) 字节值应解码为 RFC3339,实际=%v(%T)", got, got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeQueryValueWithDBTypeAndDialect_OceanBaseOracleBinaryTimestampStringWithoutTypeName(t *testing.T) {
|
||||
raw := buildOracleBinaryTimestamp(time.Date(2026, 6, 16, 12, 34, 56, 123456000, time.UTC))
|
||||
|
||||
got := normalizeQueryValueWithDBTypeAndDialect(string(raw), "", oceanBaseOracleScanDialect)
|
||||
if got != "2026-06-16T12:34:56.123456Z" {
|
||||
t.Fatalf("OceanBase Oracle 空类型名的 TIMESTAMP 字符串应解码为 RFC3339,实际=%v(%T)", got, got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeQueryValueWithDBTypeAndDialect_OceanBaseOracleBinaryTimestampBytesWithoutTypeName(t *testing.T) {
|
||||
raw := buildOracleBinaryTimestamp(time.Date(2026, 6, 16, 12, 34, 56, 123456000, time.UTC))
|
||||
|
||||
got := normalizeQueryValueWithDBTypeAndDialect(raw, "", oceanBaseOracleScanDialect)
|
||||
if got != "2026-06-16T12:34:56.123456Z" {
|
||||
t.Fatalf("OceanBase Oracle 空类型名的 TIMESTAMP 字节值应解码为 RFC3339,实际=%v(%T)", got, got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeQueryValueWithDBTypeAndDialect_OceanBaseOracleMySQLEncodedTimestampString(t *testing.T) {
|
||||
raw := buildMySQLBinaryTimestamp(time.Date(2026, 6, 16, 12, 34, 56, 123456000, time.UTC))
|
||||
|
||||
got := normalizeQueryValueWithDBTypeAndDialect(string(raw), "TYPE_CA", oceanBaseOracleScanDialect)
|
||||
if got != "2026-06-16T12:34:56.123456Z" {
|
||||
t.Fatalf("OceanBase Oracle length-encoded TIMESTAMP 字符串应解码为 RFC3339,实际=%v(%T)", got, got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeQueryValueWithDBTypeAndDialect_OceanBaseOracleMySQLEncodedTimestampBytes(t *testing.T) {
|
||||
raw := buildMySQLBinaryTimestamp(time.Date(2026, 6, 16, 12, 34, 56, 123456000, time.UTC))
|
||||
|
||||
got := normalizeQueryValueWithDBTypeAndDialect(raw, "TYPE_CA", oceanBaseOracleScanDialect)
|
||||
if got != "2026-06-16T12:34:56.123456Z" {
|
||||
t.Fatalf("OceanBase Oracle length-encoded TIMESTAMP 字节值应解码为 RFC3339,实际=%v(%T)", got, got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeQueryValueWithDBTypeAndDialect_OceanBaseOracleTypeCALiveTimestampString(t *testing.T) {
|
||||
raw := []byte{20, 26, 6, 16, 16, 46, 23, 96, 196, 119, 9, 6}
|
||||
|
||||
got := normalizeQueryValueWithDBTypeAndDialect(string(raw), "TYPE_CA", oceanBaseOracleScanDialect)
|
||||
if got != "2026-06-16T16:46:23.158844Z" {
|
||||
t.Fatalf("OceanBase Oracle TYPE_CA live TIMESTAMP 字符串应解码为 RFC3339,实际=%v(%T)", got, got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeQueryValueWithDBTypeAndDialect_OceanBaseOracleTypeCALiveTimestampBytes(t *testing.T) {
|
||||
raw := []byte{20, 26, 6, 16, 16, 46, 23, 96, 196, 119, 9, 6}
|
||||
|
||||
got := normalizeQueryValueWithDBTypeAndDialect(raw, "TYPE_CA", oceanBaseOracleScanDialect)
|
||||
if got != "2026-06-16T16:46:23.158844Z" {
|
||||
t.Fatalf("OceanBase Oracle TYPE_CA live TIMESTAMP 字节值应解码为 RFC3339,实际=%v(%T)", got, got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeQueryValueWithDBTypeAndDialect_OceanBaseOracleTypeCALiveTimestampWithoutFraction(t *testing.T) {
|
||||
raw := []byte{20, 26, 6, 17, 5, 0, 0, 0, 0, 0, 0, 6}
|
||||
|
||||
got := normalizeQueryValueWithDBTypeAndDialect(raw, "TYPE_CA", oceanBaseOracleScanDialect)
|
||||
if got != "2026-06-17T05:00:00Z" {
|
||||
t.Fatalf("OceanBase Oracle TYPE_CA 零小数 TIMESTAMP 字节值应解码为 RFC3339,实际=%v(%T)", got, got)
|
||||
}
|
||||
}
|
||||
|
||||
func buildOracleBinaryTimestamp(tm time.Time) []byte {
|
||||
if tm.Location() != time.UTC {
|
||||
tm = tm.In(time.UTC)
|
||||
}
|
||||
buf := []byte{
|
||||
byte(tm.Year()/100 + 100),
|
||||
byte(tm.Year()%100 + 100),
|
||||
byte(tm.Month()),
|
||||
byte(tm.Day()),
|
||||
byte(tm.Hour() + 1),
|
||||
byte(tm.Minute() + 1),
|
||||
byte(tm.Second() + 1),
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
}
|
||||
binary.BigEndian.PutUint32(buf[7:11], uint32(tm.Nanosecond()))
|
||||
return buf
|
||||
}
|
||||
|
||||
func buildMySQLBinaryTimestamp(tm time.Time) []byte {
|
||||
if tm.Location() != time.UTC {
|
||||
tm = tm.In(time.UTC)
|
||||
}
|
||||
buf := []byte{11, 0, 0, byte(tm.Month()), byte(tm.Day()), byte(tm.Hour()), byte(tm.Minute()), byte(tm.Second()), 0, 0, 0, 0}
|
||||
binary.LittleEndian.PutUint16(buf[1:3], uint16(tm.Year()))
|
||||
binary.LittleEndian.PutUint32(buf[8:12], uint32(tm.Nanosecond()/1000))
|
||||
return buf
|
||||
}
|
||||
|
||||
func TestNormalizeQueryValueWithDBTypeAndDialect_MySQLDateOnly(t *testing.T) {
|
||||
input := time.Date(2025, 10, 1, 0, 0, 0, 0, time.Local)
|
||||
|
||||
@@ -248,6 +397,24 @@ func TestNormalizeQueryValueWithDBTypeAndDialect_DatetimeKeepsTime(t *testing.T)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeQueryValueWithDBTypeAndDialect_OceanBaseOracleDateMidnightDisplaysDateOnly(t *testing.T) {
|
||||
input := time.Date(2025, 10, 1, 0, 0, 0, 0, time.UTC)
|
||||
|
||||
got := normalizeQueryValueWithDBTypeAndDialect(input, "DATE", oceanBaseOracleScanDialect)
|
||||
if got != "2025-10-01" {
|
||||
t.Fatalf("OceanBase Oracle DATE 的午夜值应只展示日期,实际=%v(%T)", got, got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeQueryValueWithDBTypeAndDialect_OceanBaseOracleDateKeepsNonMidnightTime(t *testing.T) {
|
||||
input := time.Date(2025, 10, 1, 13, 14, 15, 0, time.UTC)
|
||||
|
||||
got := normalizeQueryValueWithDBTypeAndDialect(input, "DATE", oceanBaseOracleScanDialect)
|
||||
if got != "2025-10-01T13:14:15Z" {
|
||||
t.Fatalf("OceanBase Oracle DATE 非午夜值应保留时间,实际=%v(%T)", got, got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeQueryValueWithDBType_ZeroTemporalValues(t *testing.T) {
|
||||
zero := time.Time{}
|
||||
cases := []struct {
|
||||
|
||||
@@ -40,6 +40,78 @@ func (scanRowsDuplicateConn) QueryContext(_ context.Context, query string, args
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
if query == "SELECT timestamp_columns" {
|
||||
raw := buildOracleBinaryTimestamp(time.Date(2026, 6, 16, 12, 34, 56, 123456000, time.UTC))
|
||||
return &scanRowsDuplicateRows{
|
||||
columns: []string{"created_at"},
|
||||
columnTypes: []string{"TYPE_CA"},
|
||||
rows: [][]driver.Value{
|
||||
{
|
||||
string(raw),
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
if query == "SELECT timestamp_precision_columns" {
|
||||
raw := buildOracleBinaryTimestamp(time.Date(2026, 6, 16, 12, 34, 56, 123456000, time.UTC))
|
||||
return &scanRowsDuplicateRows{
|
||||
columns: []string{"created_at"},
|
||||
columnTypes: []string{"TIMESTAMP(6)"},
|
||||
rows: [][]driver.Value{
|
||||
{
|
||||
string(raw),
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
if query == "SELECT timestamp_generic_carrier_columns" {
|
||||
raw := buildOracleBinaryTimestamp(time.Date(2026, 6, 16, 12, 34, 56, 123456000, time.UTC))
|
||||
return &scanRowsDuplicateRows{
|
||||
columns: []string{"created_at"},
|
||||
columnTypes: []string{"VARCHAR2"},
|
||||
rows: [][]driver.Value{
|
||||
{
|
||||
string(raw),
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
if query == "SELECT timestamp_unknown_type_columns" {
|
||||
raw := buildOracleBinaryTimestamp(time.Date(2026, 6, 16, 12, 34, 56, 123456000, time.UTC))
|
||||
return &scanRowsDuplicateRows{
|
||||
columns: []string{"created_at"},
|
||||
columnTypes: []string{""},
|
||||
rows: [][]driver.Value{
|
||||
{
|
||||
string(raw),
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
if query == "SELECT timestamp_mysql_encoded_columns" {
|
||||
raw := buildMySQLBinaryTimestamp(time.Date(2026, 6, 16, 12, 34, 56, 123456000, time.UTC))
|
||||
return &scanRowsDuplicateRows{
|
||||
columns: []string{"created_at"},
|
||||
columnTypes: []string{"TYPE_CA"},
|
||||
rows: [][]driver.Value{
|
||||
{
|
||||
string(raw),
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
if query == "SELECT timestamp_type_ca_live_columns" {
|
||||
raw := []byte{20, 26, 6, 16, 16, 46, 23, 96, 196, 119, 9, 6}
|
||||
return &scanRowsDuplicateRows{
|
||||
columns: []string{"created_at"},
|
||||
columnTypes: []string{"TYPE_CA"},
|
||||
rows: [][]driver.Value{
|
||||
{
|
||||
string(raw),
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
return &scanRowsDuplicateRows{
|
||||
columns: []string{"id", "id", "name"},
|
||||
rows: [][]driver.Value{
|
||||
@@ -184,3 +256,267 @@ func TestScanRowsForOracleDialectKeepsDateTime(t *testing.T) {
|
||||
t.Fatalf("Oracle DATE 应保留 datetime 语义,实际=%v(%T)", data[0]["ship_date"], data[0]["ship_date"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestScanRowsForOceanBaseOracleDialectFormatsMidnightDateOnly(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
registerScanRowsDuplicateDriverOnce.Do(func() {
|
||||
sql.Register(scanRowsDuplicateDriverName, scanRowsDuplicateDriver{})
|
||||
})
|
||||
|
||||
dbConn, err := sql.Open(scanRowsDuplicateDriverName, "")
|
||||
if err != nil {
|
||||
t.Fatalf("open date scan rows db failed: %v", err)
|
||||
}
|
||||
defer dbConn.Close()
|
||||
|
||||
rows, err := dbConn.QueryContext(context.Background(), "SELECT date_columns")
|
||||
if err != nil {
|
||||
t.Fatalf("query date scan rows db failed: %v", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
data, _, err := scanRowsForDialect(rows, oceanBaseOracleScanDialect)
|
||||
if err != nil {
|
||||
t.Fatalf("scanRowsForDialect returned error: %v", err)
|
||||
}
|
||||
if len(data) != 1 {
|
||||
t.Fatalf("expected one row, got=%d", len(data))
|
||||
}
|
||||
if data[0]["ship_date"] != "2025-10-01" {
|
||||
t.Fatalf("OceanBase Oracle DATE 的午夜值应展示为日期,实际=%v(%T)", data[0]["ship_date"], data[0]["ship_date"])
|
||||
}
|
||||
if data[0]["created_at"] != "2025-10-01T13:14:15Z" {
|
||||
t.Fatalf("OceanBase Oracle DATETIME 应保留时间,实际=%v(%T)", data[0]["created_at"], data[0]["created_at"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestOracleDBQueryUsesCustomScanDialect(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
registerScanRowsDuplicateDriverOnce.Do(func() {
|
||||
sql.Register(scanRowsDuplicateDriverName, scanRowsDuplicateDriver{})
|
||||
})
|
||||
|
||||
dbConn, err := sql.Open(scanRowsDuplicateDriverName, "")
|
||||
if err != nil {
|
||||
t.Fatalf("open date scan rows db failed: %v", err)
|
||||
}
|
||||
defer dbConn.Close()
|
||||
|
||||
oracleDB := &OracleDB{conn: dbConn, scanDialect: oceanBaseOracleScanDialect}
|
||||
data, _, err := oracleDB.Query("SELECT date_columns")
|
||||
if err != nil {
|
||||
t.Fatalf("OracleDB.Query returned error: %v", err)
|
||||
}
|
||||
if len(data) != 1 {
|
||||
t.Fatalf("expected one row, got=%d", len(data))
|
||||
}
|
||||
if data[0]["ship_date"] != "2025-10-01" {
|
||||
t.Fatalf("OracleDB 自定义扫描方言未生效,实际=%v(%T)", data[0]["ship_date"], data[0]["ship_date"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestScanRowsForOceanBaseOracleDialectDecodesBinaryTimestampString(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
registerScanRowsDuplicateDriverOnce.Do(func() {
|
||||
sql.Register(scanRowsDuplicateDriverName, scanRowsDuplicateDriver{})
|
||||
})
|
||||
|
||||
dbConn, err := sql.Open(scanRowsDuplicateDriverName, "")
|
||||
if err != nil {
|
||||
t.Fatalf("open timestamp scan rows db failed: %v", err)
|
||||
}
|
||||
defer dbConn.Close()
|
||||
|
||||
rows, err := dbConn.QueryContext(context.Background(), "SELECT timestamp_columns")
|
||||
if err != nil {
|
||||
t.Fatalf("query timestamp scan rows db failed: %v", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
data, columns, err := scanRowsForDialect(rows, oceanBaseOracleScanDialect)
|
||||
if err != nil {
|
||||
t.Fatalf("scanRowsForDialect returned error: %v", err)
|
||||
}
|
||||
if !reflect.DeepEqual(columns, []string{"created_at"}) {
|
||||
t.Fatalf("unexpected columns: %v", columns)
|
||||
}
|
||||
if len(data) != 1 {
|
||||
t.Fatalf("expected one row, got=%d", len(data))
|
||||
}
|
||||
if data[0]["created_at"] != "2026-06-16T12:34:56.123456Z" {
|
||||
t.Fatalf("OceanBase Oracle 二进制 TIMESTAMP 应解码为 RFC3339,实际=%v(%T)", data[0]["created_at"], data[0]["created_at"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestScanRowsForOceanBaseOracleDialectDecodesBinaryTimestampStringWithPrecisionType(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
registerScanRowsDuplicateDriverOnce.Do(func() {
|
||||
sql.Register(scanRowsDuplicateDriverName, scanRowsDuplicateDriver{})
|
||||
})
|
||||
|
||||
dbConn, err := sql.Open(scanRowsDuplicateDriverName, "")
|
||||
if err != nil {
|
||||
t.Fatalf("open timestamp precision scan rows db failed: %v", err)
|
||||
}
|
||||
defer dbConn.Close()
|
||||
|
||||
rows, err := dbConn.QueryContext(context.Background(), "SELECT timestamp_precision_columns")
|
||||
if err != nil {
|
||||
t.Fatalf("query timestamp precision scan rows db failed: %v", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
data, columns, err := scanRowsForDialect(rows, oceanBaseOracleScanDialect)
|
||||
if err != nil {
|
||||
t.Fatalf("scanRowsForDialect returned error: %v", err)
|
||||
}
|
||||
if !reflect.DeepEqual(columns, []string{"created_at"}) {
|
||||
t.Fatalf("unexpected columns: %v", columns)
|
||||
}
|
||||
if len(data) != 1 {
|
||||
t.Fatalf("expected one row, got=%d", len(data))
|
||||
}
|
||||
if data[0]["created_at"] != "2026-06-16T12:34:56.123456Z" {
|
||||
t.Fatalf("OceanBase Oracle TIMESTAMP(6) 应解码为 RFC3339,实际=%v(%T)", data[0]["created_at"], data[0]["created_at"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestScanRowsForOceanBaseOracleDialectDecodesBinaryTimestampStringWithGenericCarrierType(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
registerScanRowsDuplicateDriverOnce.Do(func() {
|
||||
sql.Register(scanRowsDuplicateDriverName, scanRowsDuplicateDriver{})
|
||||
})
|
||||
|
||||
dbConn, err := sql.Open(scanRowsDuplicateDriverName, "")
|
||||
if err != nil {
|
||||
t.Fatalf("open timestamp generic-carrier scan rows db failed: %v", err)
|
||||
}
|
||||
defer dbConn.Close()
|
||||
|
||||
rows, err := dbConn.QueryContext(context.Background(), "SELECT timestamp_generic_carrier_columns")
|
||||
if err != nil {
|
||||
t.Fatalf("query timestamp generic-carrier scan rows db failed: %v", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
data, columns, err := scanRowsForDialect(rows, oceanBaseOracleScanDialect)
|
||||
if err != nil {
|
||||
t.Fatalf("scanRowsForDialect returned error: %v", err)
|
||||
}
|
||||
if !reflect.DeepEqual(columns, []string{"created_at"}) {
|
||||
t.Fatalf("unexpected columns: %v", columns)
|
||||
}
|
||||
if len(data) != 1 {
|
||||
t.Fatalf("expected one row, got=%d", len(data))
|
||||
}
|
||||
if data[0]["created_at"] != "2026-06-16T12:34:56.123456Z" {
|
||||
t.Fatalf("OceanBase Oracle 泛型载体类型的 TIMESTAMP 应解码为 RFC3339,实际=%v(%T)", data[0]["created_at"], data[0]["created_at"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestScanRowsForOceanBaseOracleDialectDecodesBinaryTimestampStringWithoutTypeName(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
registerScanRowsDuplicateDriverOnce.Do(func() {
|
||||
sql.Register(scanRowsDuplicateDriverName, scanRowsDuplicateDriver{})
|
||||
})
|
||||
|
||||
dbConn, err := sql.Open(scanRowsDuplicateDriverName, "")
|
||||
if err != nil {
|
||||
t.Fatalf("open timestamp unknown-type scan rows db failed: %v", err)
|
||||
}
|
||||
defer dbConn.Close()
|
||||
|
||||
rows, err := dbConn.QueryContext(context.Background(), "SELECT timestamp_unknown_type_columns")
|
||||
if err != nil {
|
||||
t.Fatalf("query timestamp unknown-type scan rows db failed: %v", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
data, columns, err := scanRowsForDialect(rows, oceanBaseOracleScanDialect)
|
||||
if err != nil {
|
||||
t.Fatalf("scanRowsForDialect returned error: %v", err)
|
||||
}
|
||||
if !reflect.DeepEqual(columns, []string{"created_at"}) {
|
||||
t.Fatalf("unexpected columns: %v", columns)
|
||||
}
|
||||
if len(data) != 1 {
|
||||
t.Fatalf("expected one row, got=%d", len(data))
|
||||
}
|
||||
if data[0]["created_at"] != "2026-06-16T12:34:56.123456Z" {
|
||||
t.Fatalf("OceanBase Oracle 空类型名的 TIMESTAMP 应解码为 RFC3339,实际=%v(%T)", data[0]["created_at"], data[0]["created_at"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestScanRowsForOceanBaseOracleDialectDecodesMySQLLengthEncodedTimestampString(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
registerScanRowsDuplicateDriverOnce.Do(func() {
|
||||
sql.Register(scanRowsDuplicateDriverName, scanRowsDuplicateDriver{})
|
||||
})
|
||||
|
||||
dbConn, err := sql.Open(scanRowsDuplicateDriverName, "")
|
||||
if err != nil {
|
||||
t.Fatalf("open mysql-encoded timestamp scan rows db failed: %v", err)
|
||||
}
|
||||
defer dbConn.Close()
|
||||
|
||||
rows, err := dbConn.QueryContext(context.Background(), "SELECT timestamp_mysql_encoded_columns")
|
||||
if err != nil {
|
||||
t.Fatalf("query mysql-encoded timestamp scan rows db failed: %v", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
data, columns, err := scanRowsForDialect(rows, oceanBaseOracleScanDialect)
|
||||
if err != nil {
|
||||
t.Fatalf("scanRowsForDialect returned error: %v", err)
|
||||
}
|
||||
if !reflect.DeepEqual(columns, []string{"created_at"}) {
|
||||
t.Fatalf("unexpected columns: %v", columns)
|
||||
}
|
||||
if len(data) != 1 {
|
||||
t.Fatalf("expected one row, got=%d", len(data))
|
||||
}
|
||||
if data[0]["created_at"] != "2026-06-16T12:34:56.123456Z" {
|
||||
t.Fatalf("OceanBase Oracle length-encoded TIMESTAMP 应解码为 RFC3339,实际=%v(%T)", data[0]["created_at"], data[0]["created_at"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestScanRowsForOceanBaseOracleDialectDecodesTypeCALiveTimestampString(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
registerScanRowsDuplicateDriverOnce.Do(func() {
|
||||
sql.Register(scanRowsDuplicateDriverName, scanRowsDuplicateDriver{})
|
||||
})
|
||||
|
||||
dbConn, err := sql.Open(scanRowsDuplicateDriverName, "")
|
||||
if err != nil {
|
||||
t.Fatalf("open timestamp scan rows db failed: %v", err)
|
||||
}
|
||||
defer dbConn.Close()
|
||||
|
||||
rows, err := dbConn.QueryContext(context.Background(), "SELECT timestamp_type_ca_live_columns")
|
||||
if err != nil {
|
||||
t.Fatalf("query timestamp scan rows db failed: %v", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
data, columns, err := scanRowsForDialect(rows, oceanBaseOracleScanDialect)
|
||||
if err != nil {
|
||||
t.Fatalf("scanRowsForDialect returned error: %v", err)
|
||||
}
|
||||
if !reflect.DeepEqual(columns, []string{"created_at"}) {
|
||||
t.Fatalf("unexpected columns: %v", columns)
|
||||
}
|
||||
if len(data) != 1 {
|
||||
t.Fatalf("expected one row, got=%d", len(data))
|
||||
}
|
||||
if data[0]["created_at"] != "2026-06-16T16:46:23.158844Z" {
|
||||
t.Fatalf("OceanBase Oracle TYPE_CA live TIMESTAMP 应解码为 RFC3339,实际=%v(%T)", data[0]["created_at"], data[0]["created_at"])
|
||||
}
|
||||
}
|
||||
|
||||
@@ -906,7 +906,7 @@ func isOperationAllowed(level ai.SQLPermissionLevel, opType ai.SQLOperationType)
|
||||
case ai.PermissionReadWrite:
|
||||
return opType == ai.SQLOpQuery || opType == ai.SQLOpDML
|
||||
case ai.PermissionFull:
|
||||
return opType == ai.SQLOpQuery || opType == ai.SQLOpDML || opType == ai.SQLOpDDL
|
||||
return true
|
||||
default:
|
||||
return opType == ai.SQLOpQuery
|
||||
}
|
||||
@@ -945,7 +945,7 @@ func safetyLevelRuleText(level ai.SQLPermissionLevel) string {
|
||||
case ai.PermissionReadWrite:
|
||||
return "读写模式仅允许查询和 DML 语句。"
|
||||
case ai.PermissionFull:
|
||||
return "完全模式仅允许查询、DML 和 DDL;未识别操作仍会被阻止。"
|
||||
return "完全模式允许所有 SQL 操作;高风险或未识别语句仍会要求确认。"
|
||||
default:
|
||||
return "只读模式仅允许查询语句。"
|
||||
}
|
||||
|
||||
@@ -554,6 +554,46 @@ func TestExecuteSQLAllowsDDLWhenAISafetyIsFullAndAllowMutating(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteSQLAllowsOtherStatementsWhenAISafetyIsFullAndAllowMutating(t *testing.T) {
|
||||
backend := &fakeBackend{
|
||||
editableConnection: connection.SavedConnectionView{
|
||||
ID: "oracle-main",
|
||||
Config: connection.ConnectionConfig{
|
||||
Type: "oracle",
|
||||
Database: "app",
|
||||
},
|
||||
},
|
||||
inspection: appcore.SQLInspection{
|
||||
StatementCount: 1,
|
||||
ReadOnly: false,
|
||||
Statements: []appcore.SQLStatementInspection{
|
||||
{Index: 1, Keyword: "call", ReadOnly: false},
|
||||
},
|
||||
},
|
||||
safetyLevel: ai.PermissionFull,
|
||||
queryResult: connection.QueryResult{
|
||||
Success: true,
|
||||
Data: []connection.ResultSetData{},
|
||||
},
|
||||
}
|
||||
|
||||
service := NewService(backend)
|
||||
result, _, err := service.ExecuteSQL(context.Background(), nil, executeSQLArgs{
|
||||
ConnectionID: "oracle-main",
|
||||
SQL: "CALL bulk_insert_users(100000)",
|
||||
AllowMutating: true,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("ExecuteSQL returned error: %v", err)
|
||||
}
|
||||
if result == nil || result.IsError {
|
||||
t.Fatalf("expected success result, got %#v", result)
|
||||
}
|
||||
if !backend.queryCalled {
|
||||
t.Fatalf("expected SQL to be executed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteSQLNormalizesAndTruncatesResultSets(t *testing.T) {
|
||||
backend := &fakeBackend{
|
||||
editableConnection: connection.SavedConnectionView{
|
||||
|
||||
Reference in New Issue
Block a user