Compare commits

...

7 Commits

Author SHA1 Message Date
Syngnat
c1ebce4ef5 feat(query-editor): 放宽单表查询结果列级编辑边界
- 查询编辑:支持简单表列与表达式列混合展示
- 编辑安全:仅允许真实表列编辑,表达式列保持只读
- 提交流程:支持结果列别名映射回真实表字段
- 测试覆盖:补充聚合查询静默只读与列级提交用例
2026-04-29 20:07:22 +08:00
Syngnat
c927e33c8c feat(driver): 提醒重装旧版驱动代理
- optional-driver-agent 新增 metadata 方法返回 driverType、agentRevision 与协议版本
- 安装和本地导入驱动后记录 agentRevision,并在驱动状态中比对是否需要更新
- 驱动管理、连接表单和已有连接加载入口提示重装旧版 agent
- 补充旧 revision 检测和 custom 连接使用统计回归测试
2026-04-29 17:26:16 +08:00
Syngnat
824aafbdea 🔧 chore(driver): 自动生成驱动代理 revision
- 新增脚本按 optional driver-agent 源码依赖生成 revision 指纹
- 构建脚本与 dev/release workflow 在打包前自动刷新 revision
- 生成驱动 revision 映射并补充 optional driver 覆盖校验
2026-04-29 17:26:09 +08:00
Syngnat
0c1586d7a4 🐛 fix(clickhouse): 修复协议选择与连接错误提示
- 支持 ClickHouse 手动 HTTP/Native 协议优先级,避免 URI scheme 覆盖用户选择
- Auto 模式识别 Native/HTTP 协议误配错误并自动尝试备用协议
- 净化连接失败中的二进制乱码,补充测试连接参数校验和排查日志
- 前端表单增加 ClickHouse 协议选择并同步类型、缓存 key 与持久化兼容
Refs #425
2026-04-29 17:25:54 +08:00
Syngnat
b1ef52f62e feat(data-grid): 支持无主键表安全编辑
- 定位策略:新增主键、唯一索引和 Oracle ROWID 三类安全行定位能力
- 查询编辑器:简单单表 SELECT 自动补充隐藏定位列,复杂结果保持只读
- 表预览:无主键表可通过唯一索引或 Oracle ROWID 安全编辑
- 提交流程:移除无主键整行 WHERE fallback,隐藏定位列不参与展示和写入
- 后端保护:Oracle、MySQL、PostgreSQL 更新删除必须恰好影响 1 行
- 测试覆盖:补充 QueryEditor、DataViewer、DataGrid 和 ApplyChanges 相关用例
Refs #419
2026-04-29 12:33:35 +08:00
Syngnat
05a913ccb2 🐛 fix(query-editor): 修复多数据源大查询限流失效
- SQL限流:抽取查询自动限流工具,修复 SELECT 判断大小写不一致导致限制未生效
- 方言适配:按 Oracle/Dameng、SQL Server、MySQL/PostgreSQL 等方言分别注入行数限制
- 自定义驱动:支持 custom 连接根据 driver 解析 Oracle、PostgreSQL、SQL Server 等方言
- MongoDB修复:修正 db.collection.find() 解析边界,并对 find/只读 aggregate 下推 limit
- Oracle优化:DSN 增加 PREFETCH_ROWS 和 LOB FETCH 参数,减少大结果集拉取开销
- 测试覆盖:补充 SQL 方言矩阵、MongoDB 限流和 Oracle DSN 参数测试
Refs #424
2026-04-29 10:29:34 +08:00
Syngnat
f51dbcfb2c 🐛 fix(oracle): 修复查询结果编辑提交后数据还原
Oracle GetColumns 未返回主键列标记,前端 pkColumns 为空后退化为
全列 WHERE 条件,Oracle 空字符串即 NULL 语义导致 UPDATE 匹配 0 行。

LEFT JOIN all_constraints + all_cons_columns 检测主键列并赋值 Key="PRI",
与达梦驱动实现方式一致。
2026-04-29 09:41:25 +08:00
50 changed files with 4089 additions and 674 deletions

View File

@@ -246,6 +246,7 @@ jobs:
run: |
set -euo pipefail
DEV_VERSION="${{ steps.version.outputs.version }}"
./tools/generate-driver-agent-revisions.sh --platform "${{ matrix.platform }}"
if [ -n "${{ matrix.wails_tags }}" ]; then
wails build -platform ${{ matrix.platform }} -clean -o ${{ matrix.build_name }} -tags "${{ matrix.wails_tags }}" -ldflags "-s -w -X GoNavi-Wails/internal/app.AppVersion=${DEV_VERSION}"
else

View File

@@ -237,6 +237,7 @@ jobs:
shell: bash
run: |
set -euo pipefail
./tools/generate-driver-agent-revisions.sh --platform "${{ matrix.platform }}"
if [ -n "${{ matrix.wails_tags }}" ]; then
wails build -platform ${{ matrix.platform }} -clean -o ${{ matrix.build_name }} -tags "${{ matrix.wails_tags }}" -ldflags "-s -w -X GoNavi-Wails/internal/app.AppVersion=${{ github.ref_name }}"
else

View File

@@ -152,6 +152,8 @@ echo "🚀 开始构建 optional-driver-agent"
echo " 平台:$goos/$goarch"
echo " 输出目录:$output_dir_abs"
echo " 驱动列表:${drivers[*]}"
echo "🧭 生成 driver-agent revision 指纹"
"$SCRIPT_DIR/tools/generate-driver-agent-revisions.sh" --platform "$target_platform"
for driver in "${drivers[@]}"; do
if [[ "$driver" == "duckdb" && "$goos" == "windows" && "$goarch" != "amd64" ]]; then

View File

@@ -155,6 +155,7 @@ package_macos_release() {
local archive_suffix="$2"
echo -e "${GREEN}🍎 正在构建 macOS (${platform})...${NC}"
generate_driver_agent_revisions "darwin/${platform}"
wails build -platform "darwin/${platform}" -clean -ldflags "$LDFLAGS"
if [ $? -ne 0 ]; then
echo -e "${RED} ❌ macOS ${platform} 构建失败。${NC}"
@@ -185,6 +186,12 @@ package_macos_release() {
echo " ✅ 已生成 $zip_name"
}
generate_driver_agent_revisions() {
local platform="$1"
echo " 🧭 正在生成 driver-agent revision 指纹 (${platform})..."
./tools/generate-driver-agent-revisions.sh --platform "$platform"
}
echo -e "${GREEN}🚀 开始构建 $APP_NAME $VERSION...${NC}"
# 清理并创建输出目录
@@ -197,6 +204,7 @@ package_macos_release "amd64" "mac-amd64"
# --- Windows AMD64 构建 ---
echo -e "${GREEN}🪟 正在构建 Windows (amd64)...${NC}"
if command -v x86_64-w64-mingw32-gcc &> /dev/null; then
generate_driver_agent_revisions "windows/amd64"
wails build -platform windows/amd64 -clean -ldflags "$LDFLAGS"
if [ $? -eq 0 ]; then
TARGET_EXE="$DIST_DIR/${APP_NAME}-${VERSION}-windows-amd64.exe"
@@ -213,6 +221,7 @@ fi
# --- Windows ARM64 构建 ---
echo -e "${GREEN}🪟 正在构建 Windows (arm64)...${NC}"
if command -v aarch64-w64-mingw32-gcc &> /dev/null; then
generate_driver_agent_revisions "windows/arm64"
wails build -platform windows/arm64 -clean -ldflags "$LDFLAGS"
if [ $? -eq 0 ]; then
TARGET_EXE="$DIST_DIR/${APP_NAME}-${VERSION}-windows-arm64.exe"
@@ -235,6 +244,7 @@ CURRENT_ARCH=$(uname -m)
if [ "$CURRENT_OS" = "Linux" ] && [ "$CURRENT_ARCH" = "x86_64" ]; then
# 本机 Linux amd64直接构建
generate_driver_agent_revisions "linux/amd64"
wails build -platform linux/amd64 -clean -ldflags "$LDFLAGS"
if [ $? -eq 0 ]; then
TARGET_LINUX_BIN="$DIST_DIR/${APP_NAME}-${VERSION}-linux-amd64"
@@ -255,6 +265,7 @@ elif command -v x86_64-linux-gnu-gcc &> /dev/null; then
export CC=x86_64-linux-gnu-gcc
export CXX=x86_64-linux-gnu-g++
export CGO_ENABLED=1
generate_driver_agent_revisions "linux/amd64"
wails build -platform linux/amd64 -clean -ldflags "$LDFLAGS"
if [ $? -eq 0 ]; then
TARGET_LINUX_BIN="$DIST_DIR/${APP_NAME}-${VERSION}-linux-amd64"
@@ -279,6 +290,7 @@ fi
echo -e "${GREEN}🐧 正在构建 Linux (arm64)...${NC}"
if [ "$CURRENT_OS" = "Linux" ] && [ "$CURRENT_ARCH" = "aarch64" ]; then
# 本机 Linux arm64直接构建
generate_driver_agent_revisions "linux/arm64"
wails build -platform linux/arm64 -clean -ldflags "$LDFLAGS"
if [ $? -eq 0 ]; then
TARGET_LINUX_BIN="$DIST_DIR/${APP_NAME}-${VERSION}-linux-arm64"
@@ -298,6 +310,7 @@ elif command -v aarch64-linux-gnu-gcc &> /dev/null; then
export CC=aarch64-linux-gnu-gcc
export CXX=aarch64-linux-gnu-g++
export CGO_ENABLED=1
generate_driver_agent_revisions "linux/arm64"
wails build -platform linux/arm64 -clean -ldflags "$LDFLAGS"
if [ $? -eq 0 ]; then
TARGET_LINUX_BIN="$DIST_DIR/${APP_NAME}-${VERSION}-linux-arm64"

View File

@@ -37,6 +37,7 @@ type agentResponse struct {
const (
agentMethodConnect = "connect"
agentMethodClose = "close"
agentMethodMetadata = "metadata"
agentMethodPing = "ping"
agentMethodQuery = "query"
agentMethodExec = "exec"
@@ -131,6 +132,13 @@ func handleRequest(inst *db.Database, req agentRequest) agentResponse {
*inst = nil
}
return resp
case agentMethodMetadata:
resp.Data = map[string]string{
"driverType": strings.TrimSpace(agentDriverType),
"agentRevision": db.OptionalDriverAgentRevision(agentDriverType),
"protocolSchema": "json-lines-v1",
}
return resp
}
if *inst == nil {

View File

@@ -10,6 +10,7 @@ import (
"time"
"GoNavi-Wails/internal/connection"
"GoNavi-Wails/internal/db"
)
type duckMapLike map[any]any
@@ -66,6 +67,33 @@ func TestNormalizeAgentResponseData_KeepByteSlice(t *testing.T) {
}
}
func TestHandleRequestMetadataReportsAgentRevision(t *testing.T) {
previousDriverType := agentDriverType
previousFactory := agentDatabaseFactory
t.Cleanup(func() {
agentDriverType = previousDriverType
agentDatabaseFactory = previousFactory
})
agentDriverType = "clickhouse"
agentDatabaseFactory = func() db.Database { return nil }
var inst db.Database
resp := handleRequest(&inst, agentRequest{ID: 7, Method: agentMethodMetadata})
if !resp.Success {
t.Fatalf("metadata request failed: %s", resp.Error)
}
data, ok := resp.Data.(map[string]string)
if !ok {
t.Fatalf("metadata response data type = %T", resp.Data)
}
if data["driverType"] != "clickhouse" {
t.Fatalf("unexpected driver type: %q", data["driverType"])
}
if data["agentRevision"] != db.OptionalDriverAgentRevision("clickhouse") {
t.Fatalf("unexpected agent revision: %q", data["agentRevision"])
}
}
type fakeAgentTimeoutDB struct {
queryCalled bool
queryContextCalled bool

View File

@@ -1 +1 @@
0295a42fd931778d85157816d79d29e5
d0464f9da25e9356e61652e638c99ffe

View File

@@ -95,6 +95,7 @@ type ChoiceCardOption = {
label: string;
description?: string;
};
type ClickHouseProtocolChoice = "auto" | "http" | "native";
const MAX_URI_LENGTH = 4096;
const MAX_URI_HOSTS = 32;
const MAX_TIMEOUT_SECONDS = 3600;
@@ -102,6 +103,25 @@ const CONNECTION_MODAL_WIDTH = 960;
const CONNECTION_MODAL_BODY_HEIGHT = 620;
const STEP1_SIDEBAR_DIVIDER_DARK = "rgba(255, 255, 255, 0.16)";
const STEP1_SIDEBAR_DIVIDER_LIGHT = "rgba(0, 0, 0, 0.08)";
const CLICKHOUSE_PROTOCOL_OPTIONS: Array<{
value: ClickHouseProtocolChoice;
label: string;
}> = [
{ value: "auto", label: "自动" },
{ value: "http", label: "HTTP" },
{ value: "native", label: "Native" },
];
const normalizeClickHouseProtocolValue = (
value: unknown,
): ClickHouseProtocolChoice => {
const text = String(value || "")
.trim()
.toLowerCase();
if (text === "http" || text === "https") return "http";
if (text === "native" || text === "tcp") return "native";
return "auto";
};
type ConnectionSecretKey =
| "primaryPassword"
| "sshPassword"
@@ -216,6 +236,10 @@ type DriverStatusSnapshot = {
type: string;
name: string;
connectable: boolean;
expectedRevision?: string;
needsUpdate?: boolean;
updateReason?: string;
affectedConnections?: number;
message?: string;
};
@@ -228,6 +252,14 @@ const normalizeDriverType = (value: string): string => {
return normalized;
};
const resolveConnectionDriverType = (type: string, driver?: string): string => {
const normalizedType = normalizeDriverType(type);
if (normalizedType !== "custom") {
return normalizedType;
}
return normalizeDriverType(driver || "");
};
const ConnectionModal: React.FC<{
open: boolean;
onClose: () => void;
@@ -300,6 +332,7 @@ const ConnectionModal: React.FC<{
const redisTopology = Form.useWatch("redisTopology", form) || "single";
const sslMode = Form.useWatch("sslMode", form) || "preferred";
const proxyType = Form.useWatch("proxyType", form) || "socks5";
const customDriver = Form.useWatch("driver", form) || "";
const mongoReadPreference =
Form.useWatch("mongoReadPreference", form) || "primary";
const mongoAuthMechanism = Form.useWatch("mongoAuthMechanism", form) || "";
@@ -831,6 +864,12 @@ const ConnectionModal: React.FC<{
type,
name: String(item.name || item.type || type).trim(),
connectable: !!item.connectable,
expectedRevision: String(item.expectedRevision || "").trim() || undefined,
needsUpdate: !!item.needsUpdate,
updateReason: String(item.updateReason || "").trim() || undefined,
affectedConnections: Number.isFinite(Number(item.affectedConnections))
? Number(item.affectedConnections)
: undefined,
message: String(item.message || "").trim() || undefined,
};
});
@@ -850,8 +889,9 @@ const ConnectionModal: React.FC<{
const resolveDriverUnavailableReason = async (
type: string,
driver?: string,
): Promise<string> => {
const normalized = normalizeDriverType(type);
const normalized = resolveConnectionDriverType(type, driver);
if (!normalized || normalized === "custom") {
return "";
}
@@ -1000,6 +1040,13 @@ const ConnectionModal: React.FC<{
}
};
const normalizeUriBool = (raw: unknown) => {
const text = String(raw ?? "")
.trim()
.toLowerCase();
return text === "1" || text === "true" || text === "yes" || text === "on";
};
const normalizeFileDbPath = (rawPath: string): string => {
let pathText = String(rawPath || "").trim();
if (!pathText) {
@@ -1117,6 +1164,44 @@ const ConnectionModal: React.FC<{
};
};
const parseClickHouseHTTPUriToValues = (
uriText: string,
fallbackPort?: number,
): Record<string, any> | null => {
const trimmed = String(uriText || "").trim();
const lower = trimmed.toLowerCase();
const isHttps = lower.startsWith("https://");
const isHttp = lower.startsWith("http://");
if (!isHttp && !isHttps) {
return null;
}
const defaultPort =
Number.isFinite(Number(fallbackPort)) && Number(fallbackPort) > 0
? Number(fallbackPort)
: isHttps
? 8443
: 8123;
const parsed = parseSingleHostUri(
trimmed,
[isHttps ? "https" : "http"],
defaultPort,
);
if (!parsed) {
return null;
}
const skipVerify = normalizeUriBool(parsed.params.get("skip_verify"));
return {
host: parsed.host,
port: parsed.port,
user: parsed.username,
password: parsed.password,
database: parsed.database || "",
clickHouseProtocol: "http",
useSSL: isHttps,
sslMode: isHttps ? (skipVerify ? "skip-verify" : "required") : "disable",
};
};
const parseUriToValues = (
uriText: string,
type: string,
@@ -1337,6 +1422,13 @@ const ConnectionModal: React.FC<{
};
}
if (type === "clickhouse") {
const httpValues = parseClickHouseHTTPUriToValues(trimmedUri);
if (httpValues) {
return httpValues;
}
}
const singleHostSchemes = singleHostUriSchemesByType[type];
if (singleHostSchemes && singleHostSchemes.length > 0) {
const parsed = parseSingleHostUri(
@@ -1412,6 +1504,9 @@ const ConnectionModal: React.FC<{
parsedValues.sslMode = "disable";
}
} else if (type === "clickhouse") {
parsedValues.clickHouseProtocol = normalizeClickHouseProtocolValue(
parsed.params.get("protocol"),
);
const secure = String(
parsed.params.get("secure") || parsed.params.get("tls") || "",
)
@@ -1707,7 +1802,18 @@ const ConnectionModal: React.FC<{
return `${scheme}://${encodedAuth}${hosts.join(",")}${dbPath}${query ? `?${query}` : ""}`;
}
const scheme = type === "postgres" ? "postgresql" : type;
const clickHouseProtocol =
type === "clickhouse"
? normalizeClickHouseProtocolValue(values.clickHouseProtocol)
: "auto";
const scheme =
type === "postgres"
? "postgresql"
: type === "clickhouse" && clickHouseProtocol === "http"
? values.useSSL
? "https"
: "http"
: type;
const dbPath = database ? `/${encodeURIComponent(database)}` : "";
const params = new URLSearchParams();
if (supportsSSLForType(type) && values.useSSL) {
@@ -1728,9 +1834,15 @@ const ConnectionModal: React.FC<{
mode === "skip-verify" || mode === "preferred" ? "true" : "false",
);
} else if (type === "clickhouse") {
params.set("secure", "true");
if (mode === "skip-verify" || mode === "preferred") {
params.set("skip_verify", "true");
if (clickHouseProtocol === "http") {
if (mode === "skip-verify" || mode === "preferred") {
params.set("skip_verify", "true");
}
} else {
params.set("secure", "true");
if (mode === "skip-verify" || mode === "preferred") {
params.set("skip_verify", "true");
}
}
} else if (type === "dameng") {
const certPath = String(values.sslCertPath || "").trim();
@@ -1761,6 +1873,9 @@ const ConnectionModal: React.FC<{
params.set("protocol", "ws");
}
}
if (type === "clickhouse" && clickHouseProtocol !== "auto") {
params.set("protocol", clickHouseProtocol);
}
const query = params.toString();
return `${scheme}://${encodedAuth}${toAddress(host, port, defaultPort)}${dbPath}${query ? `?${query}` : ""}`;
};
@@ -1967,6 +2082,10 @@ const ConnectionModal: React.FC<{
password: config.password,
database: config.database,
uri: config.uri || "",
clickHouseProtocol:
configType === "clickhouse"
? normalizeClickHouseProtocolValue(config.clickHouseProtocol)
: "auto",
includeDatabases: initialValues.includeDatabases,
includeRedisDatabases: initialValues.includeRedisDatabases,
useSSL: !!config.useSSL,
@@ -2287,10 +2406,14 @@ const ConnectionModal: React.FC<{
const values = form.getFieldsValue(true);
const unavailableReason = await resolveDriverUnavailableReason(
values.type,
values.driver,
);
if (unavailableReason) {
message.warning(unavailableReason);
promptInstallDriver(values.type, unavailableReason);
promptInstallDriver(
resolveConnectionDriverType(values.type, values.driver) || values.type,
unavailableReason,
);
return;
}
setLoading(true);
@@ -2445,6 +2568,7 @@ const ConnectionModal: React.FC<{
const values = form.getFieldsValue(true);
const unavailableReason = await resolveDriverUnavailableReason(
values.type,
values.driver,
);
if (unavailableReason) {
applyTestFailureFeedback(
@@ -2454,7 +2578,10 @@ const ConnectionModal: React.FC<{
fallback: "驱动未安装启用",
}),
);
promptInstallDriver(values.type, unavailableReason);
promptInstallDriver(
resolveConnectionDriverType(values.type, values.driver) || values.type,
unavailableReason,
);
return;
}
const blockingSecretClearMessage = getBlockingSecretClearMessage(values);
@@ -2740,6 +2867,15 @@ const ConnectionModal: React.FC<{
(Array.isArray(value) && value.length === 0);
if (parsedUriValues) {
Object.entries(parsedUriValues).forEach(([key, value]) => {
if (
key === "clickHouseProtocol" &&
normalizeClickHouseProtocolValue((mergedValues as any)[key]) ===
"auto" &&
normalizeClickHouseProtocolValue(value) !== "auto"
) {
(mergedValues as any)[key] = value;
return;
}
if (isEmptyField((mergedValues as any)[key])) {
(mergedValues as any)[key] = value;
}
@@ -2748,6 +2884,35 @@ const ConnectionModal: React.FC<{
const type = String(mergedValues.type || "").toLowerCase();
const defaultPort = getDefaultPortByType(type);
if (type === "clickhouse") {
const requestedProtocol = normalizeClickHouseProtocolValue(
mergedValues.clickHouseProtocol,
);
const hostSchemeValues = parseClickHouseHTTPUriToValues(
mergedValues.host,
Number(mergedValues.port || defaultPort),
);
if (hostSchemeValues) {
mergedValues.host = hostSchemeValues.host;
mergedValues.port = hostSchemeValues.port;
if (requestedProtocol !== "native") {
mergedValues.clickHouseProtocol = "http";
mergedValues.useSSL = hostSchemeValues.useSSL;
mergedValues.sslMode = hostSchemeValues.sslMode;
} else {
mergedValues.clickHouseProtocol = "native";
}
if (isEmptyField(mergedValues.user)) {
mergedValues.user = hostSchemeValues.user;
}
if (isEmptyField(mergedValues.password)) {
mergedValues.password = hostSchemeValues.password;
}
if (isEmptyField(mergedValues.database)) {
mergedValues.database = hostSchemeValues.database;
}
}
}
const isFileDbType = isFileDatabaseType(type);
const sslCapableType = supportsSSLForType(type);
@@ -2990,6 +3155,10 @@ const ConnectionModal: React.FC<{
? Math.max(0, Math.min(15, Math.trunc(Number(mergedValues.redisDB))))
: 0,
uri: String(mergedValues.uri || "").trim(),
clickHouseProtocol:
type === "clickhouse"
? normalizeClickHouseProtocolValue(mergedValues.clickHouseProtocol)
: undefined,
hosts: hosts,
topology: topology,
mysqlReplicaUser: mysqlReplicaUser,
@@ -3017,7 +3186,10 @@ const ConnectionModal: React.FC<{
}
setTypeSelectWarning(null);
setDbType(type);
form.setFieldsValue({ type: type });
form.setFieldsValue({
type: type,
clickHouseProtocol: type === "clickhouse" ? "auto" : undefined,
});
const defaultPort = getDefaultPortByType(type);
if (type === "jvm") {
@@ -3188,17 +3360,27 @@ const ConnectionModal: React.FC<{
isJVM && hasUnsupportedJvmModeSelection
? "当前连接包含未支持的 JVM 模式。此版本只支持 JMX / Endpoint / Agent请先调整允许模式和首选模式后再继续。"
: "";
const currentDriverType = normalizeDriverType(dbType);
const currentDriverType = resolveConnectionDriverType(dbType, customDriver);
const hasCurrentDriverType =
currentDriverType !== "" && currentDriverType !== "custom";
const currentDriverSnapshot = driverStatusMap[currentDriverType];
const currentDriverUnavailableReason =
currentDriverType !== "custom" &&
hasCurrentDriverType &&
currentDriverSnapshot &&
!currentDriverSnapshot.connectable
? currentDriverSnapshot.message ||
`${currentDriverSnapshot.name || dbType} 驱动未安装启用`
: "";
const currentDriverUpdateReason =
hasCurrentDriverType &&
currentDriverSnapshot?.connectable &&
currentDriverSnapshot.needsUpdate
? currentDriverSnapshot.message ||
currentDriverSnapshot.updateReason ||
`${currentDriverSnapshot.name || dbType} 驱动代理需要重装后才能应用当前版本的驱动侧更新`
: "";
const driverStatusChecking =
currentDriverType !== "custom" && !driverStatusLoaded && step === 2;
hasCurrentDriverType && !driverStatusLoaded && step === 2;
const dbTypeGroups = [
{
@@ -4294,6 +4476,25 @@ const ConnectionModal: React.FC<{
),
})}
{dbType === "clickhouse" &&
renderConfigSectionCard({
sectionKey: "connectionMode",
icon: <ClusterOutlined />,
children: (
<Form.Item
name="clickHouseProtocol"
label="连接协议"
help="自动模式按 URI scheme 和常见端口判断;非标 HTTP/Native 端口可手动指定。"
style={{ marginBottom: 0 }}
>
<Select
options={CLICKHOUSE_PROTOCOL_OPTIONS}
onChange={() => clearConnectionTestResultForChoice()}
/>
</Form.Item>
),
})}
{(dbType === "postgres" ||
dbType === "kingbase" ||
dbType === "highgo" ||
@@ -5898,6 +6099,26 @@ const ConnectionModal: React.FC<{
}
/>
)}
{currentDriverUpdateReason && (
<Alert
showIcon
type="warning"
style={{ marginBottom: 12 }}
message="当前数据源驱动代理建议重装"
description={
<Space size={8}>
<span>{currentDriverUpdateReason}</span>
<Button
type="link"
size="small"
onClick={() => onOpenDriverManager?.()}
>
</Button>
</Space>
}
/>
)}
{(() => {
const sectionItems: Array<{
key: "basic" | "network" | "appearance";

View File

@@ -2,7 +2,8 @@ import React from 'react';
import { act, create, type ReactTestRenderer } from 'react-test-renderer';
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest';
import DataGrid from './DataGrid';
import DataGrid, { buildDataGridCommitChangeSet, GONAVI_ROW_KEY } from './DataGrid';
import { ORACLE_ROWID_LOCATOR_COLUMN } from '../utils/rowLocator';
const storeState = vi.hoisted(() => ({
connections: [
@@ -216,6 +217,157 @@ const waitForEffects = async () => {
});
};
const normalizeValue = (_columnName: string, value: any) => value;
const rowKeyToString = (key: any) => String(key);
const commitColumnGuard = (columnName: string) => (
columnName !== GONAVI_ROW_KEY && columnName !== ORACLE_ROWID_LOCATOR_COLUMN
);
describe('DataGrid commit change set', () => {
it('uses unique locator values instead of falling back to the whole row', () => {
const result = buildDataGridCommitChangeSet({
addedRows: [],
modifiedRows: {
'row-1': { [GONAVI_ROW_KEY]: 'row-1', EMAIL: 'a@example.com', NAME: 'new-name', AGE: 42 },
},
deletedRowKeys: new Set(),
data: [{ [GONAVI_ROW_KEY]: 'row-1', EMAIL: 'a@example.com', NAME: 'old-name', AGE: 42 }],
editLocator: {
strategy: 'unique-key',
columns: ['EMAIL'],
valueColumns: ['EMAIL'],
readOnly: false,
},
visibleColumnNames: ['EMAIL', 'NAME', 'AGE'],
rowKeyToString,
normalizeCommitCellValue: normalizeValue,
shouldCommitColumn: commitColumnGuard,
});
expect(result).toEqual({
ok: true,
changes: {
inserts: [],
updates: [{ keys: { EMAIL: 'a@example.com' }, values: { NAME: 'new-name' } }],
deletes: [],
},
});
});
it('uses hidden Oracle ROWID only as locator and excludes it from update values', () => {
const result = buildDataGridCommitChangeSet({
addedRows: [],
modifiedRows: {
'row-1': { [GONAVI_ROW_KEY]: 'row-1', NAME: 'new-name', [ORACLE_ROWID_LOCATOR_COLUMN]: 'BBBB' },
},
deletedRowKeys: new Set(),
data: [{ [GONAVI_ROW_KEY]: 'row-1', NAME: 'old-name', [ORACLE_ROWID_LOCATOR_COLUMN]: 'AAAA' }],
editLocator: {
strategy: 'oracle-rowid',
columns: ['ROWID'],
valueColumns: [ORACLE_ROWID_LOCATOR_COLUMN],
hiddenColumns: [ORACLE_ROWID_LOCATOR_COLUMN],
readOnly: false,
},
visibleColumnNames: ['NAME'],
rowKeyToString,
normalizeCommitCellValue: normalizeValue,
shouldCommitColumn: commitColumnGuard,
});
expect(result).toEqual({
ok: true,
changes: {
inserts: [],
updates: [{ keys: { ROWID: 'AAAA' }, values: { NAME: 'new-name' } }],
deletes: [],
},
});
});
it('commits only writable result columns and maps aliases back to table columns', () => {
const result = buildDataGridCommitChangeSet({
addedRows: [],
modifiedRows: {
'row-1': {
[GONAVI_ROW_KEY]: 'row-1',
DISPLAY_NAME: 'new-name',
NAME_UPPER: 'NEW-NAME',
},
},
deletedRowKeys: new Set(),
data: [{
[GONAVI_ROW_KEY]: 'row-1',
ID: 7,
DISPLAY_NAME: 'old-name',
NAME_UPPER: 'OLD-NAME',
}],
editLocator: {
strategy: 'primary-key',
columns: ['ID'],
valueColumns: ['ID'],
writableColumns: {
DISPLAY_NAME: 'NAME',
},
readOnly: false,
},
visibleColumnNames: ['DISPLAY_NAME', 'NAME_UPPER'],
rowKeyToString,
normalizeCommitCellValue: normalizeValue,
shouldCommitColumn: commitColumnGuard,
});
expect(result).toEqual({
ok: true,
changes: {
inserts: [],
updates: [{ keys: { ID: 7 }, values: { NAME: 'new-name' } }],
deletes: [],
},
});
});
it('fails closed when no safe locator is available', () => {
const result = buildDataGridCommitChangeSet({
addedRows: [],
modifiedRows: {
'row-1': { [GONAVI_ROW_KEY]: 'row-1', NAME: 'new-name' },
},
deletedRowKeys: new Set(),
data: [{ [GONAVI_ROW_KEY]: 'row-1', NAME: 'old-name' }],
editLocator: undefined,
visibleColumnNames: ['NAME'],
rowKeyToString,
normalizeCommitCellValue: normalizeValue,
shouldCommitColumn: commitColumnGuard,
});
expect(result).toEqual({ ok: false, error: '当前结果没有可用的安全行定位方式,无法提交修改。' });
});
it('rejects delete rows when unique locator value is null', () => {
const result = buildDataGridCommitChangeSet({
addedRows: [],
modifiedRows: {},
deletedRowKeys: new Set(['row-1']),
data: [{ [GONAVI_ROW_KEY]: 'row-1', EMAIL: null, NAME: 'old-name' }],
editLocator: {
strategy: 'unique-key',
columns: ['EMAIL'],
valueColumns: ['EMAIL'],
readOnly: false,
},
visibleColumnNames: ['EMAIL', 'NAME'],
rowKeyToString,
normalizeCommitCellValue: normalizeValue,
shouldCommitColumn: commitColumnGuard,
});
expect(result).toEqual({ ok: false, error: '定位列 EMAIL 的值为空,无法安全提交修改。' });
});
});
describe('DataGrid DDL interactions', () => {
beforeEach(() => {
backendApp.DBGetColumns.mockResolvedValue({ success: true, data: [] });

View File

@@ -159,6 +159,7 @@ describe('DataGrid layout', () => {
columnNames={['id', 'name']}
loading={false}
tableName="users"
pkColumns={['id']}
/>,
);

View File

@@ -79,6 +79,13 @@ import {
type DataGridFindMatch,
type DataGridFindNavigationDirection,
} from '../utils/dataGridFind';
import {
filterHiddenLocatorColumns,
isWritableResultColumn,
resolveWritableColumnName,
resolveRowLocatorValues,
type EditRowLocator,
} from '../utils/rowLocator';
// --- Error Boundary ---
interface DataGridErrorBoundaryState {
@@ -916,6 +923,7 @@ interface DataGridProps {
dbName?: string;
connectionId?: string;
pkColumns?: string[];
editLocator?: EditRowLocator;
readOnly?: boolean;
onReload?: () => void;
onSort?: (field: string, order: string) => void;
@@ -960,12 +968,112 @@ type ColumnMeta = {
comment: string;
};
type NormalizeCommitCellValue = (columnName: string, value: any, mode: 'insert' | 'update') => any;
type DataGridCommitChangeSet = {
inserts: any[];
updates: any[];
deletes: any[];
};
export const buildDataGridCommitChangeSet = ({
addedRows,
modifiedRows,
deletedRowKeys,
data,
editLocator,
visibleColumnNames,
rowKeyToString,
normalizeCommitCellValue,
shouldCommitColumn,
}: {
addedRows: any[];
modifiedRows: Record<string, any>;
deletedRowKeys: Set<string>;
data: any[];
editLocator?: EditRowLocator;
visibleColumnNames: string[];
rowKeyToString: (key: any) => string;
normalizeCommitCellValue: NormalizeCommitCellValue;
shouldCommitColumn: (columnName: string) => boolean;
}): { ok: true; changes: DataGridCommitChangeSet } | { ok: false; error: string } => {
if (!editLocator || editLocator.readOnly || editLocator.strategy === 'none') {
return { ok: false, error: editLocator?.reason || '当前结果没有可用的安全行定位方式,无法提交修改。' };
}
const normalizeValues = (values: Record<string, any>, mode: 'insert' | 'update') => {
const normalizedValues: Record<string, any> = {};
Object.entries(values).forEach(([col, val]) => {
if (!shouldCommitColumn(col)) return;
const commitColumnName = resolveWritableColumnName(col, editLocator);
if (!commitColumnName) return;
const normalizedVal = normalizeCommitCellValue(col, val, mode);
if (normalizedVal !== undefined) {
normalizedValues[commitColumnName] = normalizedVal;
}
});
return normalizedValues;
};
const originalRowsByKey = new Map<string, any>();
data.forEach((row) => {
const key = row?.[GONAVI_ROW_KEY];
if (key === undefined || key === null) return;
originalRowsByKey.set(rowKeyToString(key), row);
});
const inserts: any[] = [];
const updates: any[] = [];
const deletes: any[] = [];
addedRows.forEach(row => {
const key = row?.[GONAVI_ROW_KEY];
if (key !== undefined && key !== null && deletedRowKeys.has(rowKeyToString(key))) return;
inserts.push(normalizeValues(row, 'insert'));
});
for (const keyStr of deletedRowKeys) {
const originalRow = originalRowsByKey.get(keyStr);
if (!originalRow) continue;
const locatorValues = resolveRowLocatorValues(editLocator, originalRow);
if (!locatorValues.ok) return { ok: false, error: locatorValues.error };
deletes.push(locatorValues.values);
}
for (const [keyStr, newRow] of Object.entries(modifiedRows)) {
if (deletedRowKeys.has(keyStr)) continue;
const originalRow = originalRowsByKey.get(keyStr);
if (!originalRow) continue;
const locatorValues = resolveRowLocatorValues(editLocator, originalRow);
if (!locatorValues.ok) return { ok: false, error: locatorValues.error };
const hasRowKey = Object.prototype.hasOwnProperty.call(newRow as any, GONAVI_ROW_KEY);
let values: Record<string, any> = {};
if (!hasRowKey) {
values = { ...(newRow as any) };
} else {
visibleColumnNames.forEach((col) => {
const nextVal = (newRow as any)?.[col];
const prevVal = (originalRow as any)?.[col];
if (!isCellValueEqualForDiff(prevVal, nextVal)) values[col] = nextVal;
});
}
const normalizedValues = normalizeValues(values, 'update');
if (Object.keys(normalizedValues).length === 0) continue;
updates.push({ keys: locatorValues.values, values: normalizedValues });
}
return { ok: true, changes: { inserts, updates, deletes } };
};
// P2 性能优化:提取内联 style 对象为模块级常量,避免每次 render 创建新对象
const CELL_ELLIPSIS_STYLE: React.CSSProperties = { overflow: 'hidden', textOverflow: 'ellipsis', whiteSpace: 'nowrap' };
const VIRTUAL_CELL_WRAPPER_STYLE: React.CSSProperties = { margin: -8, padding: '8px 8px 8px 8px' };
const DataGrid: React.FC<DataGridProps> = ({
data, columnNames, loading, tableName, exportScope = 'table', resultSql, dbName, connectionId, pkColumns = [], readOnly = false,
data, columnNames, loading, tableName, exportScope = 'table', resultSql, dbName, connectionId, pkColumns = [], editLocator, readOnly = false,
onReload, onSort, onPageChange, pagination, onRequestTotalCount, onCancelTotalCount, sortInfoExternal, showFilter, onToggleFilter, exportSqlWithFilter, onApplyFilter, appliedFilterConditions, quickWhereCondition,
onApplyQuickWhereCondition,
scrollSnapshot, onScrollSnapshotChange
@@ -999,7 +1107,25 @@ const DataGrid: React.FC<DataGridProps> = ({
darkMode,
visible: showDataTableVerticalBorders,
});
const canModifyData = !readOnly && !!tableName;
const effectiveEditLocator = useMemo<EditRowLocator | undefined>(() => {
if (editLocator) return editLocator;
if (pkColumns.length === 0) return undefined;
return {
strategy: 'primary-key',
columns: pkColumns,
valueColumns: pkColumns,
readOnly: false,
};
}, [editLocator, pkColumns]);
const visibleColumnNames = useMemo(
() => filterHiddenLocatorColumns(columnNames, effectiveEditLocator),
[columnNames, effectiveEditLocator]
);
const shouldCommitColumn = useCallback((columnName: string): boolean => {
const normalized = String(columnName || '').trim();
return normalized !== GONAVI_ROW_KEY && isWritableResultColumn(normalized, effectiveEditLocator);
}, [effectiveEditLocator]);
const canModifyData = !readOnly && !!tableName && !!effectiveEditLocator && !effectiveEditLocator.readOnly && effectiveEditLocator.strategy !== 'none';
const showColumnComment = queryOptions?.showColumnComment ?? true;
const showColumnType = queryOptions?.showColumnType ?? true;
@@ -1053,7 +1179,7 @@ const DataGrid: React.FC<DataGridProps> = ({
// Sync display order from incoming prop and store memory
useEffect(() => {
let nextOrder = [...columnNames];
let nextOrder = [...visibleColumnNames];
if (enableColumnOrderMemory && connectionId && dbName && tableName) {
const storedOrder = tableColumnOrders[`${connectionId}-${dbName}-${tableName}`];
if (Array.isArray(storedOrder) && storedOrder.length > 0) {
@@ -1066,7 +1192,7 @@ const DataGrid: React.FC<DataGridProps> = ({
}
}
setAllOrderedColumnNames(nextOrder);
}, [columnNames, tableColumnOrders, enableColumnOrderMemory, connectionId, dbName, tableName]);
}, [visibleColumnNames, tableColumnOrders, enableColumnOrderMemory, connectionId, dbName, tableName]);
// Compute final display columns
useEffect(() => {
@@ -1378,7 +1504,13 @@ const DataGrid: React.FC<DataGridProps> = ({
const exportData = async (rows: any[], format: string) => {
const hide = message.loading(`正在导出 ${rows.length} 条数据...`, 0);
try {
const cleanRows = rows.map(({ [GONAVI_ROW_KEY]: _rowKey, ...rest }) => rest);
const cleanRows = rows.map((row) => {
const next: Record<string, any> = {};
displayColumnNames.forEach((columnName) => {
next[columnName] = row?.[columnName];
});
return next;
});
// Pass tableName (or 'export') as default filename
const res = await ExportData(cleanRows, displayColumnNames, tableName || 'export', format);
if (res.success) {
@@ -1538,10 +1670,10 @@ const DataGrid: React.FC<DataGridProps> = ({
return metaColumns;
}
if (exportScope === 'table') {
return columnNames.filter((columnName) => columnName !== GONAVI_ROW_KEY);
return visibleColumnNames.filter((columnName) => columnName !== GONAVI_ROW_KEY);
}
return [];
}, [columnMetaMap, exportScope, columnNames]);
}, [columnMetaMap, exportScope, visibleColumnNames]);
const normalizeCommitCellValue = useCallback(
(columnName: string, value: any, mode: 'insert' | 'update') => {
@@ -3298,19 +3430,25 @@ const DataGrid: React.FC<DataGridProps> = ({
const jsonViewText = useMemo(() => {
if (viewMode !== 'json') return '';
const cleanRows = mergedDisplayData.map((row) => {
const { [GONAVI_ROW_KEY]: _rowKey, ...rest } = row || {};
return normalizeValueForJsonView(rest);
const next: Record<string, any> = {};
visibleColumnNames.forEach((columnName) => {
next[columnName] = row?.[columnName];
});
return normalizeValueForJsonView(next);
});
return JSON.stringify(cleanRows, null, 2);
}, [viewMode, mergedDisplayData]);
}, [viewMode, mergedDisplayData, visibleColumnNames]);
const textViewRows = useMemo(() => {
if (viewMode !== 'text') return [];
return mergedDisplayData.map((row) => {
const { [GONAVI_ROW_KEY]: _rowKey, ...rest } = row || {};
return rest;
const next: Record<string, any> = {};
visibleColumnNames.forEach((columnName) => {
next[columnName] = row?.[columnName];
});
return next;
});
}, [viewMode, mergedDisplayData]);
}, [viewMode, mergedDisplayData, visibleColumnNames]);
const currentTextRow = useMemo(() => {
if (viewMode !== 'text') return null;
@@ -3363,7 +3501,7 @@ const DataGrid: React.FC<DataGridProps> = ({
const formMap: Record<string, any> = {};
const nullCols = new Set<string>();
columnNames.forEach((col) => {
visibleColumnNames.forEach((col) => {
const baseVal = (baseRow as any)?.[col];
const displayVal = (displayRow as any)?.[col];
baseRawMap[col] = baseVal;
@@ -3511,7 +3649,7 @@ const DataGrid: React.FC<DataGridProps> = ({
const keyStr = rowKeyStr(rowKey);
const normalizedNext: Record<string, any> = {};
let hasAnyVisibleChange = false;
columnNames.forEach((col) => {
visibleColumnNames.forEach((col) => {
const currentVal = (currentRow as any)?.[col];
const editedVal = Object.prototype.hasOwnProperty.call(nextItem, col) ? (nextItem as any)[col] : currentVal;
if (!isJsonViewValueEqual(currentVal, editedVal)) hasAnyVisibleChange = true;
@@ -3530,7 +3668,7 @@ const DataGrid: React.FC<DataGridProps> = ({
const originalRow = originalMap.get(keyStr);
if (!originalRow) continue;
const patch: Record<string, any> = {};
columnNames.forEach((col) => {
visibleColumnNames.forEach((col) => {
const prevVal = (originalRow as any)?.[col];
const nextVal = normalizedNext[col];
if (!isCellValueEqualForDiff(prevVal, nextVal)) patch[col] = nextVal;
@@ -3595,7 +3733,7 @@ const DataGrid: React.FC<DataGridProps> = ({
const baseRawMap = rowEditorBaseRawRef.current || {};
const patch: Record<string, any> = {};
columnNames.forEach((col) => {
visibleColumnNames.forEach((col) => {
let nextVal = values[col];
// 日期时间类型: 将 dayjs 对象转回格式化字符串
if (nextVal && dayjs.isDayjs(nextVal)) {
@@ -3615,7 +3753,7 @@ const DataGrid: React.FC<DataGridProps> = ({
});
closeRowEditor();
}, [rowEditorRowKey, rowEditorForm, addedRows, columnNames, rowKeyStr, closeRowEditor]);
}, [rowEditorRowKey, rowEditorForm, addedRows, visibleColumnNames, rowKeyStr, closeRowEditor]);
const enableVirtual = viewMode === 'table';
@@ -3633,7 +3771,7 @@ const DataGrid: React.FC<DataGridProps> = ({
}),
sorter: onSort ? { multiple: displayColumnNames.indexOf(key) + 1 } : false,
sortOrder: (sortInfo.find(s => s.columnKey === key && s.enabled !== false)?.order || null) as SortOrder | undefined,
editable: canModifyData, // Only editable if table name known and not readonly
editable: canModifyData && isWritableResultColumn(key, effectiveEditLocator),
render: (text: any) => (
<div style={CELL_ELLIPSIS_STYLE}>
{renderCellDisplayValue(text, normalizedPageFindText)}
@@ -3761,7 +3899,7 @@ const DataGrid: React.FC<DataGridProps> = ({
const handleAddRow = () => {
const newKey = `new-${Date.now()}`;
const newRow: any = { [GONAVI_ROW_KEY]: newKey };
columnNames.forEach(col => newRow[col] = '');
visibleColumnNames.forEach(col => newRow[col] = '');
pendingScrollToBottomRef.current = true;
setAddedRows(prev => [...prev, newRow]);
};
@@ -3775,7 +3913,7 @@ const DataGrid: React.FC<DataGridProps> = ({
const copiedRows = buildCopiedRowsForPaste({
rows: mergedDisplayData as Array<Record<string, any>>,
selectedRowKeys,
columnNames,
columnNames: visibleColumnNames,
rowKeyField: GONAVI_ROW_KEY,
rowKeyToString: rowKeyStr,
});
@@ -3786,7 +3924,7 @@ const DataGrid: React.FC<DataGridProps> = ({
setCopiedRowsForPaste(copiedRows);
void message.success(`已复制 ${copiedRows.length} 行,可粘贴为新增行`);
}, [selectedRowKeys, mergedDisplayData, columnNames, rowKeyStr]);
}, [selectedRowKeys, mergedDisplayData, visibleColumnNames, rowKeyStr]);
const handlePasteCopiedRowsAsNew = useCallback(() => {
if (copiedRowsForPaste.length === 0) {
@@ -3796,7 +3934,7 @@ const DataGrid: React.FC<DataGridProps> = ({
const nextRows = buildPastedRowsFromCopiedRows({
rows: copiedRowsForPaste,
columnNames,
columnNames: visibleColumnNames,
rowKeyField: GONAVI_ROW_KEY,
createRowKey: (index) => {
pastedRowSequenceRef.current += 1;
@@ -3812,7 +3950,7 @@ const DataGrid: React.FC<DataGridProps> = ({
setAddedRows(prev => [...prev, ...nextRows]);
setSelectedRowKeys(nextRows.map(row => row[GONAVI_ROW_KEY]));
void message.success(`已粘贴 ${nextRows.length} 行为新增行,请检查后提交事务`);
}, [copiedRowsForPaste, columnNames]);
}, [copiedRowsForPaste, visibleColumnNames]);
const handleDeleteSelected = () => {
setDeletedRowKeys(prev => {
@@ -3827,66 +3965,23 @@ const DataGrid: React.FC<DataGridProps> = ({
if (!connectionId || !tableName) return;
const conn = connections.find(c => c.id === connectionId);
if (!conn) return;
const inserts: any[] = [];
const updates: any[] = [];
const deletes: any[] = [];
addedRows.forEach(row => {
const { [GONAVI_ROW_KEY]: _rowKey, ...vals } = row;
const normalizedValues: Record<string, any> = {};
Object.entries(vals).forEach(([col, val]) => {
const normalizedVal = normalizeCommitCellValue(col, val, 'insert');
if (normalizedVal !== undefined) {
normalizedValues[col] = normalizedVal;
}
});
inserts.push(normalizedValues);
});
deletedRowKeys.forEach(keyStr => {
// Find original data
const originalRow = data.find(d => rowKeyStr(d?.[GONAVI_ROW_KEY]) === keyStr) || addedRows.find(d => rowKeyStr(d?.[GONAVI_ROW_KEY]) === keyStr);
if (originalRow) {
const pkData: any = {};
if (pkColumns.length > 0) pkColumns.forEach(k => pkData[k] = originalRow[k]);
else { const { [GONAVI_ROW_KEY]: _rowKey, ...rest } = originalRow; Object.assign(pkData, rest); }
deletes.push(pkData);
}
});
Object.entries(modifiedRows).forEach(([keyStr, newRow]) => {
if (deletedRowKeys.has(keyStr)) return;
const originalRow = data.find(d => rowKeyStr(d?.[GONAVI_ROW_KEY]) === keyStr);
if (!originalRow) return; // Should not happen for modified rows unless deleted
const pkData: any = {};
if (pkColumns.length > 0) pkColumns.forEach(k => pkData[k] = originalRow[k]);
else { const { [GONAVI_ROW_KEY]: _rowKey, ...rest } = originalRow; Object.assign(pkData, rest); }
const hasRowKey = Object.prototype.hasOwnProperty.call(newRow as any, GONAVI_ROW_KEY);
let values: any = {};
if (!hasRowKey) {
values = { ...(newRow as any) };
} else {
columnNames.forEach((col) => {
const nextVal = (newRow as any)?.[col];
const prevVal = (originalRow as any)?.[col];
if (!isCellValueEqualForDiff(prevVal, nextVal)) values[col] = nextVal;
});
}
const normalizedValues: Record<string, any> = {};
Object.entries(values).forEach(([col, val]) => {
const normalizedVal = normalizeCommitCellValue(col, val, 'update');
if (normalizedVal !== undefined) {
normalizedValues[col] = normalizedVal;
}
});
if (Object.keys(normalizedValues).length === 0) return;
updates.push({ keys: pkData, values: normalizedValues });
const changeSetResult = buildDataGridCommitChangeSet({
addedRows,
modifiedRows,
deletedRowKeys,
data,
editLocator: effectiveEditLocator,
visibleColumnNames,
rowKeyToString: rowKeyStr,
normalizeCommitCellValue,
shouldCommitColumn,
});
if (!changeSetResult.ok) {
void message.error(changeSetResult.error);
return;
}
const { inserts, updates, deletes } = changeSetResult.changes;
if (inserts.length === 0 && updates.length === 0 && deletes.length === 0) {
void message.info("没有可提交的变更");
return;
@@ -3902,7 +3997,7 @@ const DataGrid: React.FC<DataGridProps> = ({
};
const startTime = Date.now();
const res = await ApplyChanges(buildRpcConnectionConfig(config) as any, dbName || '', tableName, { inserts, updates, deletes } as any);
const res = await ApplyChanges(buildRpcConnectionConfig(config) as any, dbName || '', tableName, { inserts, updates, deletes, locatorStrategy: effectiveEditLocator?.strategy } as any);
const duration = Date.now() - startTime;
// Construct a pseudo-SQL representation for the log
@@ -4051,7 +4146,7 @@ const DataGrid: React.FC<DataGridProps> = ({
return null;
}
const records = getTargets(record);
const orderedCols = columnNames.filter(c => c !== GONAVI_ROW_KEY);
const orderedCols = visibleColumnNames.filter(c => c !== GONAVI_ROW_KEY);
if (mode === 'insert') {
return records.map((row: any) => buildCopyInsertSQL({
dbType,
@@ -4100,7 +4195,7 @@ const DataGrid: React.FC<DataGridProps> = ({
}, [
supportsCopyInsert,
getTargets,
columnNames,
visibleColumnNames,
dbType,
tableName,
columnTypeMapByLowerName,
@@ -4130,16 +4225,18 @@ const DataGrid: React.FC<DataGridProps> = ({
const handleCopyJson = useCallback((record: any) => {
const records = getTargets(record);
const cleanRecords = records.map((r: any) => {
const { [GONAVI_ROW_KEY]: _rowKey, ...rest } = r;
return rest;
const next: Record<string, any> = {};
visibleColumnNames.forEach((columnName) => {
next[columnName] = r?.[columnName];
});
return next;
});
copyToClipboard(JSON.stringify(cleanRecords, null, 2));
}, [getTargets, copyToClipboard]);
}, [getTargets, visibleColumnNames, copyToClipboard]);
const handleCopyCsv = useCallback((record: any) => {
const records = getTargets(record);
// 使用 columnNames 保持表定义的字段顺序
const orderedCols = columnNames.filter(c => c !== GONAVI_ROW_KEY);
const orderedCols = visibleColumnNames.filter(c => c !== GONAVI_ROW_KEY);
const header = orderedCols.map(c => `"${c}"`).join(',');
const lines = records.map((r: any) => {
const values = orderedCols.map(c => {
@@ -4152,7 +4249,7 @@ const DataGrid: React.FC<DataGridProps> = ({
return values.join(',');
});
copyToClipboard([header, ...lines].join('\n'));
}, [getTargets, columnNames, copyToClipboard]);
}, [getTargets, visibleColumnNames, copyToClipboard]);
const buildConnConfig = useCallback(() => {
if (!connectionId) return null;

View File

@@ -0,0 +1,199 @@
import React from 'react';
import { act, create, type ReactTestRenderer } from 'react-test-renderer';
import { beforeEach, describe, expect, it, vi } from 'vitest';
import type { TabData } from '../types';
import { ORACLE_ROWID_LOCATOR_COLUMN } from '../utils/rowLocator';
import DataViewer from './DataViewer';
const storeState = vi.hoisted(() => ({
connections: [
{
id: 'conn-1',
name: 'oracle',
config: {
type: 'oracle',
host: '127.0.0.1',
port: 1521,
user: 'scott',
password: '',
database: 'ORCLPDB1',
},
},
],
addSqlLog: vi.fn(),
}));
const backendApp = vi.hoisted(() => ({
DBQuery: vi.fn(),
DBGetColumns: vi.fn(),
DBGetIndexes: vi.fn(),
}));
const messageApi = vi.hoisted(() => ({
error: vi.fn(),
warning: vi.fn(),
}));
const dataGridState = vi.hoisted(() => ({
latestProps: null as any,
}));
vi.mock('../store', () => {
const useStore = Object.assign(
(selector: (state: typeof storeState) => any) => selector(storeState),
{ getState: () => storeState },
);
return { useStore };
});
vi.mock('../../wailsjs/go/app/App', () => backendApp);
vi.mock('antd', () => ({
message: messageApi,
}));
vi.mock('./DataGrid', () => ({
default: (props: any) => {
dataGridState.latestProps = props;
return <div data-grid="true" />;
},
GONAVI_ROW_KEY: '__gonavi_row_key__',
}));
const createTab = (overrides: Partial<TabData> = {}): TabData => ({
id: 'tab-1',
title: 'EDC_LOG',
type: 'table',
connectionId: 'conn-1',
dbName: 'MYCIMLED',
tableName: 'EDC_LOG',
...overrides,
});
const flushPromises = async () => {
await act(async () => {
await Promise.resolve();
await Promise.resolve();
});
};
describe('DataViewer safe editing locator', () => {
const renderAndReload = async (tab: TabData = createTab()) => {
let renderer: ReactTestRenderer;
await act(async () => {
renderer = create(<DataViewer tab={tab} />);
});
await act(async () => {
await dataGridState.latestProps.onReload();
});
await flushPromises();
return renderer!;
};
beforeEach(() => {
vi.clearAllMocks();
dataGridState.latestProps = null;
storeState.connections[0].config.type = 'oracle';
storeState.connections[0].config.database = 'ORCLPDB1';
backendApp.DBQuery.mockResolvedValue({
success: true,
fields: ['ID', 'NAME'],
data: [{ ID: 7, NAME: 'old-name' }],
});
backendApp.DBGetIndexes.mockResolvedValue({ success: true, data: [] });
});
it('enables table preview editing after primary keys are loaded', async () => {
backendApp.DBGetColumns.mockResolvedValue({
success: true,
data: [{ name: 'ID', key: 'PRI' }, { name: 'NAME', key: '' }],
});
const renderer = await renderAndReload();
expect(dataGridState.latestProps?.pkColumns).toEqual(['ID']);
expect(dataGridState.latestProps?.editLocator).toMatchObject({
strategy: 'primary-key',
columns: ['ID'],
valueColumns: ['ID'],
readOnly: false,
});
expect(dataGridState.latestProps?.readOnly).toBe(false);
expect(messageApi.warning).not.toHaveBeenCalled();
renderer.unmount();
});
it('uses a unique index when the table has no primary key', async () => {
backendApp.DBGetColumns.mockResolvedValue({
success: true,
data: [{ name: 'EMAIL', key: '' }, { name: 'NAME', key: '' }],
});
backendApp.DBGetIndexes.mockResolvedValue({
success: true,
data: [{ name: 'UK_EMAIL', columnName: 'EMAIL', nonUnique: 0, seqInIndex: 1, indexType: 'BTREE' }],
});
const renderer = await renderAndReload();
expect(dataGridState.latestProps?.pkColumns).toEqual([]);
expect(dataGridState.latestProps?.editLocator).toMatchObject({
strategy: 'unique-key',
columns: ['EMAIL'],
valueColumns: ['EMAIL'],
readOnly: false,
});
expect(dataGridState.latestProps?.readOnly).toBe(false);
expect(messageApi.warning).not.toHaveBeenCalled();
renderer.unmount();
});
it('uses hidden Oracle ROWID when no primary or unique key is available', async () => {
backendApp.DBGetColumns.mockResolvedValue({
success: true,
data: [{ name: 'ID', key: '' }, { name: 'NAME', key: '' }],
});
backendApp.DBQuery.mockResolvedValue({
success: true,
fields: ['ID', 'NAME', ORACLE_ROWID_LOCATOR_COLUMN],
data: [{ ID: 7, NAME: 'old-name', [ORACLE_ROWID_LOCATOR_COLUMN]: 'AAAA' }],
});
const renderer = await renderAndReload();
expect(dataGridState.latestProps?.pkColumns).toEqual([]);
expect(dataGridState.latestProps?.editLocator).toMatchObject({
strategy: 'oracle-rowid',
columns: ['ROWID'],
valueColumns: [ORACLE_ROWID_LOCATOR_COLUMN],
hiddenColumns: [ORACLE_ROWID_LOCATOR_COLUMN],
readOnly: false,
});
expect(dataGridState.latestProps?.readOnly).toBe(false);
expect(messageApi.warning).not.toHaveBeenCalled();
expect(backendApp.DBQuery.mock.calls.some((call: any[]) => String(call[2]).includes(`ROWID AS "${ORACLE_ROWID_LOCATOR_COLUMN}"`))).toBe(true);
renderer.unmount();
});
it('keeps non-Oracle table preview read-only when no safe locator exists', async () => {
storeState.connections[0].config.type = 'mysql';
storeState.connections[0].config.database = 'main';
backendApp.DBGetColumns.mockResolvedValue({
success: true,
data: [{ name: 'ID', key: '' }, { name: 'NAME', key: '' }],
});
const renderer = await renderAndReload(createTab({ dbName: 'main', tableName: 'users', title: 'users' }));
expect(dataGridState.latestProps?.pkColumns).toEqual([]);
expect(dataGridState.latestProps?.editLocator).toMatchObject({
strategy: 'none',
readOnly: true,
reason: '未检测到主键或可用唯一索引,无法安全提交修改。',
});
expect(dataGridState.latestProps?.readOnly).toBe(true);
expect(messageApi.warning).toHaveBeenCalledWith('表 main.users 保持只读:未检测到主键或可用唯一索引,无法安全提交修改。');
renderer.unmount();
});
});

View File

@@ -1,8 +1,8 @@
import React, { useEffect, useState, useCallback, useRef, useMemo } from 'react';
import { message } from 'antd';
import { TabData, ColumnDefinition } from '../types';
import { TabData, ColumnDefinition, IndexDefinition } from '../types';
import { useStore } from '../store';
import { DBQuery, DBGetColumns } from '../../wailsjs/go/app/App';
import { DBQuery, DBGetColumns, DBGetIndexes } from '../../wailsjs/go/app/App';
import DataGrid, { GONAVI_ROW_KEY } from './DataGrid';
import { buildOrderBySQL, buildPaginatedSelectSQL, buildWhereSQL, hasExplicitSort, quoteIdentPart, quoteQualifiedIdent, withSortBufferTuningSQL, type FilterCondition } from '../utils/sql';
import { buildMongoCountCommand, buildMongoFilter, buildMongoFindCommand, buildMongoSort } from '../utils/mongodb';
@@ -15,6 +15,12 @@ import {
normalizeQuickWhereCondition,
validateQuickWhereCondition,
} from '../utils/dataGridWhereFilter';
import {
ORACLE_ROWID_LOCATOR_COLUMN,
resolveEditRowLocator,
type EditRowLocator,
} from '../utils/rowLocator';
import { isOracleLikeDialect } from '../utils/sqlDialect';
type ViewerPaginationState = {
current: number;
@@ -79,6 +85,47 @@ const parseTotalFromCountRow = (row: any): number | null => {
return null;
};
const buildDataViewerReadOnlyLocator = (reason: string): EditRowLocator => ({
strategy: 'none',
columns: [],
valueColumns: [],
readOnly: true,
reason,
});
const formatDataViewerTableName = (dbName: string, tableName: string): string => (
dbName ? `${dbName}.${tableName}` : tableName
);
const getTableColumnNames = (columns: ColumnDefinition[] | undefined): string[] => (
(columns || [])
.map((column) => String(column?.name || '').trim())
.filter(Boolean)
);
const resolveDataViewerOrderFallbackColumns = (locator: EditRowLocator | undefined, pkColumns: string[]): string[] => {
if (locator && !locator.readOnly && locator.strategy !== 'oracle-rowid') {
return locator.valueColumns.length > 0 ? locator.valueColumns : locator.columns;
}
return pkColumns;
};
const buildDataViewerBaseSelectSQL = (
dbType: string,
tableName: string,
whereSQL: string,
locator?: EditRowLocator,
): string => {
const quotedTableName = quoteQualifiedIdent(dbType, tableName);
if (locator?.strategy !== 'oracle-rowid') {
return `SELECT * FROM ${quotedTableName} ${whereSQL}`;
}
const alias = 'gonavi_row_source';
const rowIDAlias = quoteIdentPart(dbType, ORACLE_ROWID_LOCATOR_COLUMN);
return `SELECT ${alias}.*, ${alias}.ROWID AS ${rowIDAlias} FROM ${quotedTableName} ${alias} ${whereSQL}`;
};
const normalizeDuckDBIdentifier = (raw: string): string => {
const text = String(raw || '').trim();
if (text.length >= 2) {
@@ -193,6 +240,7 @@ const DataViewer: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAct
const [data, setData] = useState<any[]>([]);
const [columnNames, setColumnNames] = useState<string[]>([]);
const [pkColumns, setPkColumns] = useState<string[]>([]);
const [editLocator, setEditLocator] = useState<EditRowLocator | undefined>(undefined);
const [loading, setLoading] = useState(false);
const connections = useStore(state => state.connections);
const addSqlLog = useStore(state => state.addSqlLog);
@@ -280,6 +328,7 @@ const DataViewer: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAct
useEffect(() => {
const snapshot = getViewerFilterSnapshot(tab.id);
setPkColumns([]);
setEditLocator(undefined);
pkKeyRef.current = '';
countKeyRef.current = '';
duckdbApproxKeyRef.current = '';
@@ -435,10 +484,84 @@ const DataViewer: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAct
const whereSQL = isMongoDB
? JSON.stringify(mongoFilter || {})
: buildWhereSQL(dbType, effectiveFilterConditions);
let pkColumnsForQuery = pkColumns;
let editLocatorForQuery = editLocator;
if (!isMongoDB && !forceReadOnly && tableName) {
const locatorKey = `${tab.connectionId}|${dbTypeLower}|${dbName}|${tableName}`;
if (pkKeyRef.current !== locatorKey || !editLocatorForQuery) {
pkKeyRef.current = locatorKey;
const locatorSeq = ++pkSeqRef.current;
try {
const [resCols, resIndexes] = await Promise.all([
DBGetColumns(buildRpcConnectionConfig(config) as any, dbName, tableName),
DBGetIndexes(buildRpcConnectionConfig(config) as any, dbName, tableName)
.catch((error: any) => ({ success: false, message: String(error?.message || error || '加载索引失败'), data: [] })),
]);
if (fetchSeqRef.current !== seq) return;
if (pkSeqRef.current !== locatorSeq) return;
if (pkKeyRef.current !== locatorKey) return;
if (!resCols?.success || !Array.isArray(resCols.data)) {
const nextLocator = buildDataViewerReadOnlyLocator('无法加载主键/唯一索引元数据,无法安全提交修改。');
pkColumnsForQuery = [];
editLocatorForQuery = nextLocator;
setPkColumns([]);
setEditLocator(nextLocator);
message.warning(`${formatDataViewerTableName(dbName, tableName)} 保持只读:${nextLocator.reason}`);
} else {
const columnDefs = resCols.data as ColumnDefinition[];
const primaryKeys = columnDefs
.filter((column: any) => column?.key === 'PRI')
.map((column: any) => String(column?.name || '').trim())
.filter(Boolean);
const indexes = resIndexes?.success && Array.isArray(resIndexes.data)
? resIndexes.data as IndexDefinition[]
: [];
const resultColumns = getTableColumnNames(columnDefs);
const locatorColumns = isOracleLikeDialect(dbType)
? [...resultColumns, ORACLE_ROWID_LOCATOR_COLUMN]
: resultColumns;
let nextLocator = resolveEditRowLocator({
dbType,
resultColumns: locatorColumns,
primaryKeys,
indexes,
allowOracleRowID: true,
});
if (nextLocator.readOnly && primaryKeys.length === 0 && !resIndexes?.success && !isOracleLikeDialect(dbType)) {
nextLocator = buildDataViewerReadOnlyLocator('无法加载唯一索引元数据,无法安全提交修改。');
}
pkColumnsForQuery = primaryKeys;
editLocatorForQuery = nextLocator;
setPkColumns(primaryKeys);
setEditLocator(nextLocator);
if (nextLocator.readOnly) {
message.warning(`${formatDataViewerTableName(dbName, tableName)} 保持只读:${nextLocator.reason || '当前结果没有可用的安全行定位方式,无法提交修改。'}`);
}
}
} catch {
if (fetchSeqRef.current !== seq) return;
if (pkSeqRef.current !== locatorSeq) return;
if (pkKeyRef.current !== locatorKey) return;
const nextLocator = buildDataViewerReadOnlyLocator('无法加载主键/唯一索引元数据,无法安全提交修改。');
pkColumnsForQuery = [];
editLocatorForQuery = nextLocator;
setPkColumns([]);
setEditLocator(nextLocator);
message.warning(`${formatDataViewerTableName(dbName, tableName)} 保持只读:${nextLocator.reason}`);
}
}
}
const countSql = isMongoDB
? buildMongoCountCommand(tableName, mongoFilter || {})
: `SELECT COUNT(*) as total FROM ${quoteQualifiedIdent(dbType, tableName)} ${whereSQL}`;
const orderBySQL = isMongoDB ? '' : buildOrderBySQL(dbType, sortInfo, pkColumns);
const orderBySQL = isMongoDB
? ''
: buildOrderBySQL(dbType, sortInfo, resolveDataViewerOrderFallbackColumns(editLocatorForQuery, pkColumnsForQuery));
const totalRows = Number(pagination.total);
const hasFiniteTotal = Number.isFinite(totalRows) && totalRows >= 0;
const totalKnown = pagination.totalKnown && hasFiniteTotal;
@@ -469,7 +592,7 @@ const DataViewer: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAct
skip: offset,
});
} else {
const baseSql = `SELECT * FROM ${quoteQualifiedIdent(dbType, tableName)} ${whereSQL}`;
const baseSql = buildDataViewerBaseSelectSQL(dbType, tableName, whereSQL, editLocatorForQuery);
sql = `${baseSql}${orderBySQL}`;
// ClickHouse 深分页在超大 OFFSET 下容易超时。对于总数已知且存在 ORDER BY 的场景,
// 当“尾部偏移”小于“头部偏移”时,改为反向 ORDER BY + 小 OFFSET并在前端翻转结果。
@@ -557,7 +680,7 @@ const DataViewer: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAct
if (safeSelect) {
let fallbackSql = `SELECT ${safeSelect} FROM ${quoteQualifiedIdent(dbType, tableName)} ${whereSQL}`;
fallbackSql = buildPaginatedSelectSQL(dbType, fallbackSql, buildOrderBySQL(dbType, sortInfo, pkColumns), size + 1, offset);
fallbackSql = buildPaginatedSelectSQL(dbType, fallbackSql, buildOrderBySQL(dbType, sortInfo, resolveDataViewerOrderFallbackColumns(editLocatorForQuery, pkColumnsForQuery)), size + 1, offset);
executedSql = fallbackSql;
resData = await executeDataQuery(fallbackSql, '复杂类型降级重试');
}
@@ -580,26 +703,6 @@ const DataViewer: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAct
message.warning('已自动提升排序缓冲并重试成功。');
}
}
if (pkColumns.length === 0) {
const pkKey = `${tab.connectionId}|${dbName}|${tableName}`;
if (pkKeyRef.current !== pkKey) {
pkKeyRef.current = pkKey;
const pkSeq = ++pkSeqRef.current;
DBGetColumns(buildRpcConnectionConfig(config) as any, dbName, tableName)
.then((resCols: any) => {
if (pkSeqRef.current !== pkSeq) return;
if (pkKeyRef.current !== pkKey) return;
if (!resCols?.success) return;
const pks = (resCols.data as ColumnDefinition[]).filter((c: any) => c.key === 'PRI').map((c: any) => c.name);
setPkColumns(pks);
})
.catch(() => {
if (pkSeqRef.current !== pkSeq) return;
if (pkKeyRef.current !== pkKey) return;
});
}
}
if (resData.success) {
let resultData = resData.data as any[];
@@ -842,9 +945,9 @@ const DataViewer: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAct
});
}
if (fetchSeqRef.current === seq) setLoading(false);
}, [connections, tab, sortInfo, filterConditions, quickWhereCondition, pkColumns, pagination.total, pagination.totalKnown, pagination.totalApprox, pagination.approximateTotal, preferManualTotalCount, supportsApproximateTableCount, supportsApproximateTotalPages]);
// 依赖 pkColumns:在无手动排序时可回退到主键稳定排序。
// 主键信息只会在首次加载后更新一次,避免循环查询。
}, [connections, tab, sortInfo, filterConditions, quickWhereCondition, pkColumns, editLocator, forceReadOnly, pagination.total, pagination.totalKnown, pagination.totalApprox, pagination.approximateTotal, preferManualTotalCount, supportsApproximateTableCount, supportsApproximateTotalPages]);
// 依赖定位列:在无手动排序时可回退到安全定位列稳定排序。
// 定位信息只会在表上下文变化后重新加载,避免循环查询。
// Handlers memoized
const handleReload = useCallback(() => {
@@ -890,14 +993,14 @@ const DataViewer: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAct
if (!whereSQL) return '';
let sql = `SELECT * FROM ${quoteQualifiedIdent(dbType, tableName)} ${whereSQL}`;
sql += buildOrderBySQL(dbType, sortInfo, pkColumns);
sql += buildOrderBySQL(dbType, sortInfo, resolveDataViewerOrderFallbackColumns(editLocator, pkColumns));
const normalizedType = dbType.toLowerCase();
const hasSortForBuffer = hasExplicitSort(sortInfo);
if (hasSortForBuffer && (normalizedType === 'mysql' || normalizedType === 'mariadb')) {
sql = withSortBufferTuningSQL(normalizedType, sql, 32 * 1024 * 1024);
}
return sql;
}, [tab.tableName, currentConnConfig?.type, currentConnConfig?.driver, filterConditions, quickWhereCondition, sortInfo, pkColumns]);
}, [tab.tableName, currentConnConfig?.type, currentConnConfig?.driver, filterConditions, quickWhereCondition, sortInfo, editLocator, pkColumns]);
useEffect(() => {
const action = resolveDataViewerAutoFetchAction({
@@ -927,6 +1030,7 @@ const DataViewer: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAct
dbName={tab.dbName}
connectionId={tab.connectionId}
pkColumns={pkColumns}
editLocator={editLocator}
onReload={handleReload}
onSort={handleSort}
onPageChange={handlePageChange}
@@ -939,7 +1043,7 @@ const DataViewer: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAct
appliedFilterConditions={filterConditions}
quickWhereCondition={quickWhereCondition}
onApplyQuickWhereCondition={handleApplyQuickWhereCondition}
readOnly={forceReadOnly}
readOnly={forceReadOnly || !editLocator || editLocator.readOnly}
sortInfoExternal={sortInfo}
exportSqlWithFilter={exportSqlWithFilter || undefined}
scrollSnapshot={scrollSnapshotRef.current}

View File

@@ -39,6 +39,11 @@ type DriverStatusRow = {
packagePath?: string;
executablePath?: string;
downloadedAt?: string;
agentRevision?: string;
expectedRevision?: string;
needsUpdate?: boolean;
updateReason?: string;
affectedConnections?: number;
message?: string;
};
@@ -360,6 +365,13 @@ const DriverManagerModal: React.FC<{ open: boolean; onClose: () => void; onOpenG
packagePath: String(item.packagePath || '').trim() || undefined,
executablePath: String(item.executablePath || '').trim() || undefined,
downloadedAt: String(item.downloadedAt || '').trim() || undefined,
agentRevision: String(item.agentRevision || '').trim() || undefined,
expectedRevision: String(item.expectedRevision || '').trim() || undefined,
needsUpdate: !!item.needsUpdate,
updateReason: String(item.updateReason || '').trim() || undefined,
affectedConnections: Number.isFinite(Number(item.affectedConnections))
? Number(item.affectedConnections)
: undefined,
message: String(item.message || '').trim() || undefined,
}));
setRows(nextRows);
@@ -1005,7 +1017,17 @@ const DriverManagerModal: React.FC<{ open: boolean; onClose: () => void; onOpenG
title: '数据源',
dataIndex: 'name',
key: 'name',
width: 150,
width: 220,
render: (_: string, row: DriverStatusRow) => (
<div style={{ display: 'grid', gap: 4 }}>
<Text strong>{row.name}</Text>
{row.message ? (
<Text type={row.needsUpdate ? 'warning' : 'secondary'} style={{ fontSize: 12 }}>
{row.message}
</Text>
) : null}
</div>
),
},
{
title: '安装包大小',
@@ -1042,6 +1064,9 @@ const DriverManagerModal: React.FC<{ open: boolean; onClose: () => void; onOpenG
if (progress && (progress.status === 'start' || progress.status === 'downloading')) {
return <Tag color="processing"> {Math.round(progress.percent)}%</Tag>;
}
if (row.needsUpdate) {
return <Tag color="warning"></Tag>;
}
if (row.connectable) {
return <Tag color="success"></Tag>;
}
@@ -1089,10 +1114,11 @@ const DriverManagerModal: React.FC<{ open: boolean; onClose: () => void; onOpenG
const versionLocked = row.packageInstalled || row.connectable;
if (versionLocked) {
const installedVersion = String(row.installedVersion || '').trim();
const revisionHint = row.needsUpdate ? ',需重装' : '';
if (installedVersion) {
return <Text type="secondary">{installedVersion}</Text>;
return <Text type="secondary">{installedVersion}{revisionHint}</Text>;
}
return <Text type="secondary"></Text>;
return <Text type="secondary">{row.needsUpdate ? '需重装,' : ''}</Text>;
}
const options = versionMap[row.type] || [];
const selectedKey = selectedVersionMap[row.type];
@@ -1148,7 +1174,16 @@ const DriverManagerModal: React.FC<{ open: boolean; onClose: () => void; onOpenG
const logs = operationLogMap[row.type] || [];
const hasLogs = logs.length > 0;
const mainAction = row.connectable ? (
const mainAction = row.needsUpdate ? (
<Button
type="primary"
icon={<DownloadOutlined />}
loading={loadingInstallOrRemove}
onClick={() => installDriver(row)}
>
</Button>
) : row.connectable ? (
<Button
danger
icon={<DeleteOutlined />}
@@ -1209,9 +1244,10 @@ const DriverManagerModal: React.FC<{ open: boolean; onClose: () => void; onOpenG
row.type,
row.pinnedVersion,
row.installedVersion,
row.updateReason,
row.message,
row.builtIn ? '内置' : '外置',
row.connectable ? '已启用' : row.packageInstalled ? '已安装' : '未启用',
row.needsUpdate ? '强烈建议重装' : row.connectable ? '已启用' : row.packageInstalled ? '已安装' : '未启用',
];
const searchableText = normalizeDriverSearchText(searchableParts.filter(Boolean).join(' '));
return searchableText.includes(normalizedSearchKeyword);

View File

@@ -3,6 +3,7 @@ import { act, create, type ReactTestRenderer } from 'react-test-renderer';
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest';
import type { SavedQuery, TabData } from '../types';
import { ORACLE_ROWID_LOCATOR_COLUMN } from '../utils/rowLocator';
import QueryEditor from './QueryEditor';
const storeState = vi.hoisted(() => ({
@@ -44,6 +45,7 @@ const backendApp = vi.hoisted(() => ({
DBGetAllColumns: vi.fn(),
DBGetDatabases: vi.fn(),
DBGetColumns: vi.fn(),
DBGetIndexes: vi.fn(),
CancelQuery: vi.fn(),
GenerateQueryID: vi.fn(),
WriteSQLFile: vi.fn(),
@@ -56,6 +58,10 @@ const messageApi = vi.hoisted(() => ({
warning: vi.fn(),
}));
const dataGridState = vi.hoisted(() => ({
latestProps: null as any,
}));
const editorState = vi.hoisted(() => {
const state = {
value: '',
@@ -114,7 +120,10 @@ vi.mock('@monaco-editor/react', () => ({
}));
vi.mock('./DataGrid', () => ({
default: () => null,
default: (props: any) => {
dataGridState.latestProps = props;
return <div data-grid="true" />;
},
GONAVI_ROW_KEY: '__gonavi_row_key__',
}));
@@ -152,7 +161,7 @@ vi.mock('antd', () => {
Dropdown: ({ children }: any) => <>{children}</>,
Tooltip: ({ children }: any) => <>{children}</>,
Select: () => null,
Tabs: () => null,
Tabs: ({ items }: any) => <div>{items?.[0]?.children}</div>,
};
});
@@ -187,7 +196,15 @@ describe('QueryEditor external SQL save', () => {
storeState.activeTabId = 'tab-1';
messageApi.success.mockReset();
messageApi.error.mockReset();
messageApi.warning.mockReset();
backendApp.WriteSQLFile.mockResolvedValue({ success: true });
backendApp.DBQueryMulti.mockResolvedValue({ success: true, data: [] });
backendApp.DBGetColumns.mockResolvedValue({ success: true, data: [] });
backendApp.DBGetIndexes.mockResolvedValue({ success: true, data: [] });
backendApp.GenerateQueryID.mockResolvedValue('query-1');
storeState.connections[0].config.type = 'mysql';
storeState.connections[0].config.database = 'main';
dataGridState.latestProps = null;
editorState.value = '';
editorState.editor.getValue.mockClear();
editorState.editor.setValue.mockClear();
@@ -276,4 +293,253 @@ describe('QueryEditor external SQL save', () => {
createdAt: 100,
}));
});
it('automatically appends hidden primary key locator columns for editable query results', async () => {
storeState.connections[0].config.type = 'oracle';
storeState.connections[0].config.database = 'ORCLPDB1';
backendApp.DBQueryMulti.mockResolvedValueOnce({
success: true,
data: [{ columns: ['NAME', '__gonavi_locator_1_ID'], rows: [{ NAME: 'old-name', __gonavi_locator_1_ID: 7 }] }],
});
backendApp.DBGetColumns.mockResolvedValueOnce({
success: true,
data: [{ name: 'ID', key: 'PRI' }, { name: 'NAME', key: '' }],
});
let renderer: ReactTestRenderer;
await act(async () => {
renderer = create(<QueryEditor tab={createTab({ dbName: 'ANONYMOUS', query: 'SELECT NAME FROM MYCIMLED.EDC_LOG' })} />);
});
await act(async () => {
await findButton(renderer!, '运行').props.onClick();
});
await act(async () => {
await Promise.resolve();
await Promise.resolve();
});
expect(dataGridState.latestProps?.tableName).toBe('MYCIMLED.EDC_LOG');
expect(dataGridState.latestProps?.pkColumns).toEqual(['ID']);
expect(dataGridState.latestProps?.editLocator).toMatchObject({
strategy: 'primary-key',
columns: ['ID'],
valueColumns: ['__gonavi_locator_1_ID'],
hiddenColumns: ['__gonavi_locator_1_ID'],
readOnly: false,
});
expect(dataGridState.latestProps?.readOnly).toBe(false);
expect(dataGridState.latestProps?.resultSql).toBe('SELECT NAME FROM MYCIMLED.EDC_LOG');
expect(String(backendApp.DBQueryMulti.mock.calls[0][2])).toContain('"ID" AS "__gonavi_locator_1_ID"');
expect(messageApi.warning).not.toHaveBeenCalled();
});
it('uses a unique index locator for query results without primary keys', async () => {
storeState.connections[0].config.type = 'oracle';
storeState.connections[0].config.database = 'ORCLPDB1';
backendApp.DBQueryMulti.mockResolvedValueOnce({
success: true,
data: [{ columns: ['NAME', '__gonavi_locator_1_EMAIL'], rows: [{ NAME: 'old-name', __gonavi_locator_1_EMAIL: 'a@example.com' }] }],
});
backendApp.DBGetColumns.mockResolvedValueOnce({
success: true,
data: [{ name: 'EMAIL', key: '' }, { name: 'NAME', key: '' }],
});
backendApp.DBGetIndexes.mockResolvedValueOnce({
success: true,
data: [{ name: 'UK_EMAIL', columnName: 'EMAIL', nonUnique: 0, seqInIndex: 1, indexType: 'BTREE' }],
});
let renderer: ReactTestRenderer;
await act(async () => {
renderer = create(<QueryEditor tab={createTab({ dbName: 'ANONYMOUS', query: 'SELECT NAME FROM MYCIMLED.EDC_LOG' })} />);
});
await act(async () => {
await findButton(renderer!, '运行').props.onClick();
});
await act(async () => {
await Promise.resolve();
await Promise.resolve();
});
expect(dataGridState.latestProps?.editLocator).toMatchObject({
strategy: 'unique-key',
columns: ['EMAIL'],
valueColumns: ['__gonavi_locator_1_EMAIL'],
hiddenColumns: ['__gonavi_locator_1_EMAIL'],
readOnly: false,
});
expect(dataGridState.latestProps?.readOnly).toBe(false);
expect(String(backendApp.DBQueryMulti.mock.calls[0][2])).toContain('"EMAIL" AS "__gonavi_locator_1_EMAIL"');
expect(messageApi.warning).not.toHaveBeenCalled();
});
it('uses hidden Oracle ROWID for query results without primary or unique keys', async () => {
storeState.connections[0].config.type = 'oracle';
storeState.connections[0].config.database = 'ORCLPDB1';
backendApp.DBQueryMulti.mockResolvedValueOnce({
success: true,
data: [{ columns: ['NAME', ORACLE_ROWID_LOCATOR_COLUMN], rows: [{ NAME: 'old-name', [ORACLE_ROWID_LOCATOR_COLUMN]: 'AAAA' }] }],
});
backendApp.DBGetColumns.mockResolvedValueOnce({
success: true,
data: [{ name: 'NAME', key: '' }],
});
let renderer: ReactTestRenderer;
await act(async () => {
renderer = create(<QueryEditor tab={createTab({ dbName: 'ANONYMOUS', query: 'SELECT NAME FROM MYCIMLED.EDC_LOG' })} />);
});
await act(async () => {
await findButton(renderer!, '运行').props.onClick();
});
await act(async () => {
await Promise.resolve();
await Promise.resolve();
});
expect(dataGridState.latestProps?.editLocator).toMatchObject({
strategy: 'oracle-rowid',
columns: ['ROWID'],
valueColumns: [ORACLE_ROWID_LOCATOR_COLUMN],
hiddenColumns: [ORACLE_ROWID_LOCATOR_COLUMN],
readOnly: false,
});
expect(dataGridState.latestProps?.readOnly).toBe(false);
expect(String(backendApp.DBQueryMulti.mock.calls[0][2])).toContain(`ROWID AS "${ORACLE_ROWID_LOCATOR_COLUMN}"`);
expect(messageApi.warning).not.toHaveBeenCalled();
});
it('keeps non-Oracle query results read-only when no safe locator exists', async () => {
backendApp.DBQueryMulti.mockResolvedValueOnce({
success: true,
data: [{ columns: ['NAME'], rows: [{ NAME: 'old-name' }] }],
});
backendApp.DBGetColumns.mockResolvedValueOnce({
success: true,
data: [{ name: 'NAME', key: '' }],
});
let renderer: ReactTestRenderer;
await act(async () => {
renderer = create(<QueryEditor tab={createTab({ dbName: 'main', query: 'SELECT NAME FROM users' })} />);
});
await act(async () => {
await findButton(renderer!, '运行').props.onClick();
});
await act(async () => {
await Promise.resolve();
await Promise.resolve();
});
expect(dataGridState.latestProps?.tableName).toBe('users');
expect(dataGridState.latestProps?.pkColumns).toEqual([]);
expect(dataGridState.latestProps?.editLocator).toMatchObject({
strategy: 'none',
readOnly: true,
reason: '未检测到主键或可用唯一索引,无法安全提交修改。',
});
expect(dataGridState.latestProps?.readOnly).toBe(true);
expect(messageApi.warning).toHaveBeenCalledWith('查询结果保持只读main.users 未检测到主键或可用唯一索引,无法安全提交修改。');
});
it('allows editable table columns while leaving expression columns out of commits', async () => {
backendApp.DBQueryMulti.mockResolvedValueOnce({
success: true,
data: [{
columns: ['DISPLAY_NAME', 'NAME_UPPER', '__gonavi_locator_1_ID'],
rows: [{ DISPLAY_NAME: 'old-name', NAME_UPPER: 'OLD-NAME', __gonavi_locator_1_ID: 7 }],
}],
});
backendApp.DBGetColumns.mockResolvedValueOnce({
success: true,
data: [{ name: 'ID', key: 'PRI' }, { name: 'NAME', key: '' }],
});
let renderer: ReactTestRenderer;
await act(async () => {
renderer = create(<QueryEditor tab={createTab({
dbName: 'main',
query: 'SELECT NAME AS DISPLAY_NAME, UPPER(NAME) AS NAME_UPPER FROM users',
})} />);
});
await act(async () => {
await findButton(renderer!, '运行').props.onClick();
});
await act(async () => {
await Promise.resolve();
await Promise.resolve();
});
expect(dataGridState.latestProps?.tableName).toBe('users');
expect(dataGridState.latestProps?.editLocator).toMatchObject({
strategy: 'primary-key',
columns: ['ID'],
valueColumns: ['__gonavi_locator_1_ID'],
hiddenColumns: ['__gonavi_locator_1_ID'],
writableColumns: {
DISPLAY_NAME: 'NAME',
},
readOnly: false,
});
expect(dataGridState.latestProps?.readOnly).toBe(false);
expect(String(backendApp.DBQueryMulti.mock.calls[0][2])).toContain('`ID` AS `__gonavi_locator_1_ID`');
expect(messageApi.warning).not.toHaveBeenCalled();
});
it.each([
'mysql',
'mariadb',
'diros',
'sphinx',
'postgres',
'kingbase',
'highgo',
'vastbase',
'sqlserver',
'sqlite',
'duckdb',
'oracle',
'dameng',
'tdengine',
'clickhouse',
])(
'keeps aggregate query results silently read-only for %s',
async (dbType) => {
storeState.connections[0].config.type = dbType;
storeState.connections[0].config.database = dbType === 'oracle' || dbType === 'dameng' ? 'APP' : 'main';
const forceReadOnlyQueryResult = dbType === 'tdengine' || dbType === 'clickhouse';
backendApp.DBQueryMulti.mockResolvedValueOnce({
success: true,
data: [{ columns: ['COUNT'], rows: [{ COUNT: 1 }] }],
});
let renderer: ReactTestRenderer;
await act(async () => {
renderer = create(<QueryEditor tab={createTab({
dbName: storeState.connections[0].config.database,
query: 'SELECT count(1) FROM users',
})} />);
});
await act(async () => {
await findButton(renderer!, '运行').props.onClick();
});
await act(async () => {
await Promise.resolve();
await Promise.resolve();
});
expect(dataGridState.latestProps?.tableName).toBe(forceReadOnlyQueryResult ? undefined : 'users');
expect(dataGridState.latestProps?.editLocator).toBeUndefined();
expect(dataGridState.latestProps?.readOnly).toBe(true);
expect(backendApp.DBGetColumns).not.toHaveBeenCalled();
expect(backendApp.DBGetIndexes).not.toHaveBeenCalled();
expect(messageApi.warning).not.toHaveBeenCalled();
},
);
});

View File

@@ -4,16 +4,21 @@ import { Button, message, Modal, Input, Form, Dropdown, MenuProps, Tooltip, Sele
import { PlayCircleOutlined, SaveOutlined, FormatPainterOutlined, SettingOutlined, CloseOutlined, StopOutlined, RobotOutlined } from '@ant-design/icons';
import { format } from 'sql-formatter';
import { v4 as uuidv4 } from 'uuid';
import { TabData, ColumnDefinition } from '../types';
import { TabData, ColumnDefinition, IndexDefinition } from '../types';
import { useStore } from '../store';
import { DBQueryWithCancel, DBQueryMulti, DBGetTables, DBGetAllColumns, DBGetDatabases, DBGetColumns, CancelQuery, GenerateQueryID, WriteSQLFile } from '../../wailsjs/go/app/App';
import { DBQueryWithCancel, DBQueryMulti, DBGetTables, DBGetAllColumns, DBGetDatabases, DBGetColumns, DBGetIndexes, CancelQuery, GenerateQueryID, WriteSQLFile } from '../../wailsjs/go/app/App';
import DataGrid, { GONAVI_ROW_KEY } from './DataGrid';
import { getDataSourceCapabilities } from '../utils/dataSourceCapabilities';
import { convertMongoShellToJsonCommand } from '../utils/mongodb';
import { applyMongoQueryAutoLimit, convertMongoShellToJsonCommand } from '../utils/mongodb';
import { getShortcutDisplay, isEditableElement, isShortcutMatch } from '../utils/shortcuts';
import { useAutoFetchVisibility } from '../utils/autoFetchVisibility';
import { buildRpcConnectionConfig } from '../utils/connectionRpcConfig';
import { resolveSqlDialect, resolveSqlFunctions, resolveSqlKeywords } from '../utils/sqlDialect';
import { isOracleLikeDialect, resolveSqlDialect, resolveSqlFunctions, resolveSqlKeywords } from '../utils/sqlDialect';
import { applyQueryAutoLimit } from '../utils/queryAutoLimit';
import { extractQueryResultTableRef, type QueryResultTableRef } from '../utils/queryResultTable';
import { quoteIdentPart } from '../utils/sql';
import { resolveUniqueKeyGroupsFromIndexes } from './dataGridCopyInsert';
import { ORACLE_ROWID_LOCATOR_COLUMN, type EditRowLocator } from '../utils/rowLocator';
const SQL_KEYWORDS = [
'SELECT', 'FROM', 'WHERE', 'LIMIT', 'INSERT', 'UPDATE', 'DELETE', 'JOIN', 'LEFT', 'RIGHT',
@@ -186,6 +191,300 @@ let sharedAllColumnsData: {dbName: string, tableName: string, name: string, type
let sharedVisibleDbs: string[] = [];
let sharedColumnsCacheData: Record<string, any[]> = {};
const QUERY_LOCATOR_ALIAS_PREFIX = '__gonavi_locator_';
const buildQueryReadOnlyLocator = (reason: string): EditRowLocator => ({
strategy: 'none',
columns: [],
valueColumns: [],
readOnly: true,
reason,
});
type SimpleSelectInfo = {
selectsAll: boolean;
writableColumns: Record<string, string>;
};
type QueryStatementPlan = {
originalSql: string;
executedSql: string;
tableRef?: QueryResultTableRef;
pkColumns: string[];
editLocator?: EditRowLocator;
warning?: string;
};
const stripQueryIdentifierQuotes = (part: string): string => {
const text = String(part || '').trim();
if (!text) return '';
if ((text.startsWith('`') && text.endsWith('`')) || (text.startsWith('"') && text.endsWith('"'))) {
return text.slice(1, -1).trim();
}
if (text.startsWith('[') && text.endsWith(']')) {
return text.slice(1, -1).trim();
}
return text;
};
const splitTopLevelComma = (text: string): string[] => {
const parts: string[] = [];
let current = '';
let parenDepth = 0;
let inSingle = false;
let inDouble = false;
let inBacktick = false;
let escaped = false;
for (let index = 0; index < text.length; index++) {
const ch = text[index];
if (escaped) {
current += ch;
escaped = false;
continue;
}
if ((inSingle || inDouble) && ch === '\\') {
current += ch;
escaped = true;
continue;
}
if (!inDouble && !inBacktick && ch === "'") {
inSingle = !inSingle;
current += ch;
continue;
}
if (!inSingle && !inBacktick && ch === '"') {
inDouble = !inDouble;
current += ch;
continue;
}
if (!inSingle && !inDouble && ch === '`') {
inBacktick = !inBacktick;
current += ch;
continue;
}
if (!inSingle && !inDouble && !inBacktick) {
if (ch === '(') parenDepth++;
if (ch === ')' && parenDepth > 0) parenDepth--;
if (ch === ',' && parenDepth === 0) {
parts.push(current.trim());
current = '';
continue;
}
}
current += ch;
}
if (current.trim()) parts.push(current.trim());
return parts;
};
const SIMPLE_IDENTIFIER_PATH_RE = /^(?:[`"\[]?[A-Za-z_][\w$]*[`"\]]?\s*\.\s*){0,2}[`"\[]?[A-Za-z_][\w$]*[`"\]]?$/;
const QUERY_ALIAS_RESERVED = new Set([
'where', 'group', 'order', 'having', 'limit', 'fetch', 'offset', 'join', 'left', 'right', 'inner', 'outer', 'on', 'union',
]);
const getLastIdentifierPart = (path: string): string => {
const parts = String(path || '').split('.').map((part) => stripQueryIdentifierQuotes(part.trim())).filter(Boolean);
return parts[parts.length - 1] || '';
};
const resolveSimpleSelectItemColumn = (item: string): { resultName: string; sourceName: string } | 'all' | undefined => {
const text = String(item || '').trim();
if (!text) return undefined;
if (text === '*' || /\.\s*\*$/.test(text)) return 'all';
let expr = text;
let alias = '';
const asMatch = text.match(/^(.*?)\s+AS\s+([`"\[]?[A-Za-z_][\w$]*[`"\]]?)$/i);
if (asMatch) {
expr = asMatch[1].trim();
alias = stripQueryIdentifierQuotes(asMatch[2]);
} else {
const bareAliasMatch = text.match(/^(.*?)\s+([`"\[]?[A-Za-z_][\w$]*[`"\]]?)$/);
if (bareAliasMatch && SIMPLE_IDENTIFIER_PATH_RE.test(bareAliasMatch[1].trim())) {
const candidateAlias = stripQueryIdentifierQuotes(bareAliasMatch[2]);
if (candidateAlias && !QUERY_ALIAS_RESERVED.has(candidateAlias.toLowerCase())) {
expr = bareAliasMatch[1].trim();
alias = candidateAlias;
}
}
}
if (!SIMPLE_IDENTIFIER_PATH_RE.test(expr)) return undefined;
const sourceName = getLastIdentifierPart(expr);
const resultName = alias || sourceName;
return sourceName && resultName ? { resultName, sourceName } : undefined;
};
const parseSimpleSelectInfo = (sql: string): SimpleSelectInfo | undefined => {
const match = String(sql || '').match(/^\s*SELECT\s+([\s\S]+?)\s+FROM\s+/i);
if (!match) return undefined;
const selectList = match[1].trim();
if (!selectList || /^DISTINCT\b/i.test(selectList)) return undefined;
const writableColumns: Record<string, string> = {};
let selectsAll = false;
for (const item of splitTopLevelComma(selectList)) {
const resolved = resolveSimpleSelectItemColumn(item);
if (!resolved) continue;
if (resolved === 'all') {
selectsAll = true;
continue;
}
writableColumns[resolved.resultName] = resolved.sourceName;
}
return { selectsAll, writableColumns };
};
const appendQuerySelectExpressions = (sql: string, expressions: string[]): string => {
if (expressions.length === 0) return sql;
return String(sql || '').replace(
/^(\s*SELECT\s+)([\s\S]+?)(\s+FROM\s+[\s\S]*)$/i,
(_match, prefix, selectList, rest) => `${prefix}${String(selectList).trimEnd()}, ${expressions.join(', ')}${rest}`,
);
};
const findWritableResultColumnForSource = (writableColumns: Record<string, string>, target: string): string | undefined => {
const normalizedTarget = String(target || '').trim().toLowerCase();
return Object.entries(writableColumns || {}).find(([, sourceColumn]) => (
String(sourceColumn || '').trim().toLowerCase() === normalizedTarget
))?.[0];
};
const buildQueryLocatorAlias = (column: string, index: number): string => {
const normalized = String(column || '').trim().replace(/[^A-Za-z0-9_]/g, '_').slice(0, 48) || 'column';
return `${QUERY_LOCATOR_ALIAS_PREFIX}${index}_${normalized}`;
};
const buildQueryLocatorColumnExpression = (dbType: string, column: string, alias: string): string => (
`${quoteIdentPart(dbType, column)} AS ${quoteIdentPart(dbType, alias)}`
);
const buildQueryRowIDExpression = (dbType: string): string => (
`ROWID AS ${quoteIdentPart(dbType, ORACLE_ROWID_LOCATOR_COLUMN)}`
);
const resolveQueryLocatorPlan = async ({
statement,
dbType,
currentDb,
config,
forceReadOnly,
}: {
statement: string;
dbType: string;
currentDb: string;
config: any;
forceReadOnly: boolean;
}): Promise<QueryStatementPlan> => {
const plan: QueryStatementPlan = {
originalSql: statement,
executedSql: statement,
pkColumns: [],
};
if (forceReadOnly) return plan;
const tableRef = extractQueryResultTableRef(statement, dbType, currentDb);
if (!tableRef) return plan;
plan.tableRef = tableRef;
const selectInfo = parseSimpleSelectInfo(statement);
if (!selectInfo) {
// 聚合、函数和表达式结果天然无法安全回写到单行,静默保持只读即可。
return plan;
}
if (!selectInfo.selectsAll && Object.keys(selectInfo.writableColumns).length === 0) {
return plan;
}
try {
const [resCols, resIndexes] = await Promise.all([
DBGetColumns(buildRpcConnectionConfig(config) as any, tableRef.metadataDbName, tableRef.metadataTableName),
DBGetIndexes(buildRpcConnectionConfig(config) as any, tableRef.metadataDbName, tableRef.metadataTableName)
.catch((error: any) => ({ success: false, message: String(error?.message || error || '加载索引失败'), data: [] })),
]);
if (!resCols?.success || !Array.isArray(resCols.data)) {
const reason = `无法加载 ${tableRef.metadataDbName}.${tableRef.metadataTableName} 的主键/唯一索引元数据,无法安全提交修改。`;
plan.editLocator = buildQueryReadOnlyLocator(reason);
plan.warning = `查询结果保持只读:${reason}`;
return plan;
}
const tableColumns = resCols.data as ColumnDefinition[];
const tableColumnNames = tableColumns.map((column) => String(column?.name || '').trim()).filter(Boolean);
const primaryKeys = tableColumns
.filter((column: any) => column?.key === 'PRI')
.map((column: any) => String(column?.name || '').trim())
.filter(Boolean);
const indexes = resIndexes?.success && Array.isArray(resIndexes.data)
? resIndexes.data as IndexDefinition[]
: [];
const writableColumns: Record<string, string> = selectInfo.selectsAll
? Object.fromEntries(tableColumnNames.map((column) => [column, column]))
: {};
Object.entries(selectInfo.writableColumns).forEach(([resultColumn, sourceColumn]) => {
writableColumns[resultColumn] = sourceColumn;
});
const appendExpressions: string[] = [];
const hiddenColumns: string[] = [];
const buildColumnLocator = (strategy: 'primary-key' | 'unique-key', locatorColumns: string[]): EditRowLocator => {
const valueColumns = locatorColumns.map((column, index) => {
const selectedColumn = findWritableResultColumnForSource(writableColumns, column);
if (selectedColumn) return selectedColumn;
const alias = buildQueryLocatorAlias(column, index + 1);
appendExpressions.push(buildQueryLocatorColumnExpression(dbType, column, alias));
hiddenColumns.push(alias);
return alias;
});
return {
strategy,
columns: locatorColumns,
valueColumns,
hiddenColumns: hiddenColumns.length > 0 ? [...hiddenColumns] : undefined,
writableColumns,
readOnly: false,
};
};
if (primaryKeys.length > 0) {
plan.pkColumns = primaryKeys;
plan.editLocator = buildColumnLocator('primary-key', primaryKeys);
} else {
const uniqueKeyGroups = resolveUniqueKeyGroupsFromIndexes(indexes);
const uniqueKeyGroup = uniqueKeyGroups.find((group) => group.length > 0);
if (uniqueKeyGroup) {
plan.editLocator = buildColumnLocator('unique-key', uniqueKeyGroup);
} else if (isOracleLikeDialect(dbType)) {
appendExpressions.push(buildQueryRowIDExpression(dbType));
plan.editLocator = {
strategy: 'oracle-rowid',
columns: ['ROWID'],
valueColumns: [ORACLE_ROWID_LOCATOR_COLUMN],
hiddenColumns: [ORACLE_ROWID_LOCATOR_COLUMN],
writableColumns,
readOnly: false,
};
} else {
const reason = !resIndexes?.success
? '无法加载唯一索引元数据,无法安全提交修改。'
: '未检测到主键或可用唯一索引,无法安全提交修改。';
plan.editLocator = buildQueryReadOnlyLocator(reason);
plan.warning = `查询结果保持只读:${tableRef.metadataDbName}.${tableRef.metadataTableName} ${reason}`;
}
}
plan.executedSql = appendQuerySelectExpressions(statement, appendExpressions);
return plan;
} catch {
const reason = `无法加载 ${tableRef.metadataDbName}.${tableRef.metadataTableName} 的主键/唯一索引元数据,无法安全提交修改。`;
plan.editLocator = buildQueryReadOnlyLocator(reason);
plan.warning = `查询结果保持只读:${reason}`;
return plan;
}
};
const QueryEditor: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isActive = true }) => {
const [query, setQuery] = useState(tab.query || 'SELECT * FROM ');
@@ -197,6 +496,7 @@ const QueryEditor: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAc
columns: string[];
tableName?: string;
pkColumns: string[];
editLocator?: EditRowLocator;
readOnly: boolean;
truncated?: boolean;
pkLoading?: boolean;
@@ -1184,359 +1484,6 @@ const QueryEditor: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAc
return statements;
};
const getLeadingKeyword = (sql: string): string => {
const text = (sql || '').replace(/\r\n/g, '\n');
const isWS = (ch: string) => ch === ' ' || ch === '\t' || ch === '\n' || ch === '\r';
const isWord = (ch: string) => /[A-Za-z0-9_]/.test(ch);
let inSingle = false;
let inDouble = false;
let inBacktick = false;
let escaped = false;
let inLineComment = false;
let inBlockComment = false;
let dollarTag: string | null = null;
for (let i = 0; i < text.length; i++) {
const ch = text[i];
const next = i + 1 < text.length ? text[i + 1] : '';
const prev = i > 0 ? text[i - 1] : '';
const next2 = i + 2 < text.length ? text[i + 2] : '';
if (!inSingle && !inDouble && !inBacktick) {
if (inLineComment) {
if (ch === '\n') inLineComment = false;
continue;
}
if (inBlockComment) {
if (ch === '*' && next === '/') {
i++;
inBlockComment = false;
}
continue;
}
if (ch === '/' && next === '*') {
i++;
inBlockComment = true;
continue;
}
if (ch === '#') {
inLineComment = true;
continue;
}
if (ch === '-' && next === '-' && (i === 0 || isWS(prev)) && (next2 === '' || isWS(next2))) {
i++;
inLineComment = true;
continue;
}
if (dollarTag) {
if (text.startsWith(dollarTag, i)) {
i += dollarTag.length - 1;
dollarTag = null;
}
continue;
}
if (ch === '$') {
const m = text.slice(i).match(/^\$[A-Za-z0-9_]*\$/);
if (m && m[0]) {
dollarTag = m[0];
i += dollarTag.length - 1;
continue;
}
}
}
if (escaped) {
escaped = false;
continue;
}
if ((inSingle || inDouble) && ch === '\\') {
escaped = true;
continue;
}
if (!inDouble && !inBacktick && ch === '\'') {
inSingle = !inSingle;
continue;
}
if (!inSingle && !inBacktick && ch === '"') {
inDouble = !inDouble;
continue;
}
if (!inSingle && !inDouble && ch === '`') {
inBacktick = !inBacktick;
continue;
}
if (inSingle || inDouble || inBacktick || dollarTag) continue;
if (isWS(ch)) continue;
if (isWord(ch)) {
let j = i;
while (j < text.length && isWord(text[j])) j++;
return text.slice(i, j).toLowerCase();
}
return '';
}
return '';
};
const splitSqlTail = (sql: string): { main: string; tail: string } => {
const text = (sql || '').replace(/\r\n/g, '\n');
const isWS = (ch: string) => ch === ' ' || ch === '\t' || ch === '\n' || ch === '\r';
let inSingle = false;
let inDouble = false;
let inBacktick = false;
let escaped = false;
let inLineComment = false;
let inBlockComment = false;
let dollarTag: string | null = null;
let lastMeaningful = -1;
for (let i = 0; i < text.length; i++) {
const ch = text[i];
const next = i + 1 < text.length ? text[i + 1] : '';
const prev = i > 0 ? text[i - 1] : '';
const next2 = i + 2 < text.length ? text[i + 2] : '';
if (!inSingle && !inDouble && !inBacktick) {
if (dollarTag) {
if (text.startsWith(dollarTag, i)) {
lastMeaningful = i + dollarTag.length - 1;
i += dollarTag.length - 1;
dollarTag = null;
} else if (!isWS(ch)) {
lastMeaningful = i;
}
continue;
}
if (inLineComment) {
if (ch === '\n') inLineComment = false;
continue;
}
if (inBlockComment) {
if (ch === '*' && next === '/') {
i++;
inBlockComment = false;
}
continue;
}
// Start comments
if (ch === '/' && next === '*') {
i++;
inBlockComment = true;
continue;
}
if (ch === '#') {
inLineComment = true;
continue;
}
if (ch === '-' && next === '-' && (i === 0 || isWS(prev)) && (next2 === '' || isWS(next2))) {
i++;
inLineComment = true;
continue;
}
if (ch === '$') {
const m = text.slice(i).match(/^\$[A-Za-z0-9_]*\$/);
if (m && m[0]) {
dollarTag = m[0];
lastMeaningful = i + dollarTag.length - 1;
i += dollarTag.length - 1;
continue;
}
}
}
if (escaped) {
escaped = false;
} else if ((inSingle || inDouble) && ch === '\\') {
escaped = true;
} else {
if (!inDouble && !inBacktick && ch === '\'') inSingle = !inSingle;
else if (!inSingle && !inBacktick && ch === '"') inDouble = !inDouble;
else if (!inSingle && !inDouble && ch === '`') inBacktick = !inBacktick;
}
if (!inLineComment && !inBlockComment && !isWS(ch)) {
lastMeaningful = i;
}
}
if (lastMeaningful < 0) return { main: '', tail: text };
return { main: text.slice(0, lastMeaningful + 1), tail: text.slice(lastMeaningful + 1) };
};
const findTopLevelKeyword = (sql: string, keyword: string): number => {
const text = sql;
const kw = keyword.toLowerCase();
const isWS = (ch: string) => ch === ' ' || ch === '\t' || ch === '\n' || ch === '\r';
const isWord = (ch: string) => /[A-Za-z0-9_]/.test(ch);
let inSingle = false;
let inDouble = false;
let inBacktick = false;
let escaped = false;
let inLineComment = false;
let inBlockComment = false;
let dollarTag: string | null = null;
let parenDepth = 0;
for (let i = 0; i < text.length; i++) {
const ch = text[i];
const next = i + 1 < text.length ? text[i + 1] : '';
const prev = i > 0 ? text[i - 1] : '';
const next2 = i + 2 < text.length ? text[i + 2] : '';
if (!inSingle && !inDouble && !inBacktick) {
if (inLineComment) {
if (ch === '\n') inLineComment = false;
continue;
}
if (inBlockComment) {
if (ch === '*' && next === '/') {
i++;
inBlockComment = false;
}
continue;
}
if (ch === '/' && next === '*') {
i++;
inBlockComment = true;
continue;
}
if (ch === '#') {
inLineComment = true;
continue;
}
if (ch === '-' && next === '-' && (i === 0 || isWS(prev)) && (next2 === '' || isWS(next2))) {
i++;
inLineComment = true;
continue;
}
if (dollarTag) {
if (text.startsWith(dollarTag, i)) {
i += dollarTag.length - 1;
dollarTag = null;
}
continue;
}
if (ch === '$') {
const m = text.slice(i).match(/^\$[A-Za-z0-9_]*\$/);
if (m && m[0]) {
dollarTag = m[0];
i += dollarTag.length - 1;
continue;
}
}
}
if (escaped) {
escaped = false;
continue;
}
if ((inSingle || inDouble) && ch === '\\') {
escaped = true;
continue;
}
if (!inDouble && !inBacktick && ch === '\'') {
inSingle = !inSingle;
continue;
}
if (!inSingle && !inBacktick && ch === '"') {
inDouble = !inDouble;
continue;
}
if (!inSingle && !inDouble && ch === '`') {
inBacktick = !inBacktick;
continue;
}
if (inSingle || inDouble || inBacktick || dollarTag) continue;
if (ch === '(') { parenDepth++; continue; }
if (ch === ')') { if (parenDepth > 0) parenDepth--; continue; }
if (parenDepth !== 0) continue;
if (!isWord(ch)) continue;
if (text.slice(i, i + kw.length).toLowerCase() !== kw) continue;
const before = i - 1 >= 0 ? text[i - 1] : '';
const after = i + kw.length < text.length ? text[i + kw.length] : '';
if ((before && isWord(before)) || (after && isWord(after))) continue;
return i;
}
return -1;
};
const applyAutoLimit = (sql: string, dbType: string, maxRows: number): { sql: string; applied: boolean; maxRows: number } => {
if (!Number.isFinite(maxRows) || maxRows <= 0) return { sql, applied: false, maxRows };
const normalizedType = (dbType || 'mysql').toLowerCase();
// 只对 SELECT 语句自动加限制
const keyword = getLeadingKeyword(sql);
if (keyword !== 'SELECT') return { sql, applied: false, maxRows };
const { main, tail } = splitSqlTail(sql);
if (!main.trim()) return { sql, applied: false, maxRows };
const fromPos = findTopLevelKeyword(main, 'from');
const limitPos = findTopLevelKeyword(main, 'limit');
// 已有 LIMIT → 不注入
if (limitPos >= 0 && (fromPos < 0 || limitPos > fromPos)) return { sql, applied: false, maxRows };
const fetchPos = findTopLevelKeyword(main, 'fetch');
// 已有 FETCH → 不注入
if (fetchPos >= 0 && (fromPos < 0 || fetchPos > fromPos)) return { sql, applied: false, maxRows };
// SQL Server / mssql: 检查是否已有 TOP未有则注入 SELECT TOP N
if (normalizedType === 'sqlserver' || normalizedType === 'mssql') {
const topPos = findTopLevelKeyword(main, 'top');
if (topPos >= 0) return { sql, applied: false, maxRows }; // 已有 TOP
// 在 SELECT 关键字之后插入 TOP N
const selectPos = findTopLevelKeyword(main, 'select');
if (selectPos < 0) return { sql, applied: false, maxRows };
const afterSelect = selectPos + 'SELECT'.length;
// 处理 SELECT DISTINCT 的情况
const restAfterSelect = main.slice(afterSelect);
const distinctMatch = restAfterSelect.match(/^(\s+DISTINCT\b)/i);
const insertOffset = distinctMatch ? afterSelect + distinctMatch[1].length : afterSelect;
const nextMain = main.slice(0, insertOffset) + ` TOP ${maxRows}` + main.slice(insertOffset);
return { sql: nextMain + tail, applied: true, maxRows };
}
// Oracle / Dameng: 使用 FETCH FIRST N ROWS ONLYOracle 12c+ 标准语法)
if (normalizedType === 'oracle' || normalizedType === 'dameng') {
// 检查是否已有 ROWNUM 限制
const rownumPos = findTopLevelKeyword(main, 'rownum');
if (rownumPos >= 0) return { sql, applied: false, maxRows };
const offsetPos = findTopLevelKeyword(main, 'offset');
if (offsetPos >= 0 && (fromPos < 0 || offsetPos > fromPos)) return { sql, applied: false, maxRows };
const nextMain = main.trimEnd() + ` FETCH FIRST ${maxRows} ROWS ONLY`;
return { sql: nextMain + tail, applied: true, maxRows };
}
// 通用 LIMIT 语法MySQL, PostgreSQL, SQLite, ClickHouse, DuckDB 等)
const offsetPos = findTopLevelKeyword(main, 'offset');
const forPos = findTopLevelKeyword(main, 'for');
const lockPos = findTopLevelKeyword(main, 'lock');
const candidates = [offsetPos, forPos, lockPos]
.filter(pos => pos >= 0 && (fromPos < 0 || pos > fromPos));
const insertAt = candidates.length > 0 ? Math.min(...candidates) : main.length;
const before = main.slice(0, insertAt).trimEnd();
const after = main.slice(insertAt).trimStart();
const nextMain = [before, `LIMIT ${maxRows}`, after].filter(Boolean).join(' ').trim();
return { sql: nextMain + tail, applied: true, maxRows };
};
const getSelectedSQL = (): string => {
const editor = editorRef.current;
if (!editor) return '';
@@ -1662,8 +1609,10 @@ const QueryEditor: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAc
try {
const rawSQL = getSelectedSQL() || currentQuery;
const dbType = String((buildRpcConnectionConfig(config) as any).type || 'mysql');
const normalizedDbType = dbType.trim().toLowerCase();
const rpcConfig = buildRpcConnectionConfig(config) as any;
const dbType = String(rpcConfig.type || 'mysql');
const driver = String((config as any).driver || '');
const normalizedDbType = String(resolveSqlDialect(dbType, driver)).trim().toLowerCase();
const normalizedRawSQL = String(rawSQL || '').replace(//g, ';');
// MongoDB 仍走逐条执行的旧路径
@@ -1703,6 +1652,12 @@ const QueryEditor: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAc
executedSql = shellConvert.command;
}
}
if (wantsLimitProbe) {
const limitResult = applyMongoQueryAutoLimit(executedSql, maxRows);
if (limitResult.applied) {
executedSql = limitResult.command;
}
}
const startTime = Date.now();
let queryId: string;
try {
@@ -1783,26 +1738,36 @@ const QueryEditor: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAc
} else {
// 非 MongoDB使用 DBQueryMulti 一次性执行多条 SQL后端返回多结果集
let fullSQL = normalizedRawSQL;
if (!fullSQL.trim()) {
const sourceStatements = splitSQLStatements(normalizedRawSQL);
if (sourceStatements.length === 0) {
message.info('没有可执行的 SQL。');
setResultSets([]);
setActiveResultKey('');
return;
}
const forceReadOnlyResult = connCaps.forceReadOnlyQueryResult;
const statementPlans: QueryStatementPlan[] = [];
for (const statement of sourceStatements) {
statementPlans.push(await resolveQueryLocatorPlan({
statement,
dbType: normalizedDbType,
currentDb,
config,
forceReadOnly: forceReadOnlyResult,
}));
}
// 自动给 SELECT 语句注入行数限制(防止大结果集卡死)
const maxRowsForLimit = Number(queryOptions?.maxRows) || 0;
let anyLimitApplied = false;
if (Number.isFinite(maxRowsForLimit) && maxRowsForLimit > 0) {
const stmts = splitSQLStatements(fullSQL);
const limitedStmts = stmts.map(s => {
const result = applyAutoLimit(s, normalizedDbType, maxRowsForLimit);
if (result.applied) anyLimitApplied = true;
return result.sql;
});
fullSQL = limitedStmts.join(';\n');
}
const executablePlans = statementPlans.map((plan) => {
if (!Number.isFinite(maxRowsForLimit) || maxRowsForLimit <= 0) return plan;
const result = applyQueryAutoLimit(plan.executedSql, normalizedDbType, maxRowsForLimit, driver);
if (result.applied) anyLimitApplied = true;
return { ...plan, executedSql: result.sql };
});
const fullSQL = executablePlans.map((plan) => plan.executedSql).join(';\n');
const startTime = Date.now();
let queryId: string;
@@ -1859,16 +1824,13 @@ const QueryEditor: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAc
const resultSetDataArray = Array.isArray(res.data) ? (res.data as any[]) : [];
const nextResultSets: ResultSet[] = [];
const maxRows = Number(queryOptions?.maxRows) || 0;
const forceReadOnlyResult = connCaps.forceReadOnlyQueryResult;
let anyTruncated = false;
const pendingPk: Array<{ resultKey: string; tableName: string }> = [];
// 前端也拆分语句用于匹配原始 SQL展示和表名检测
const statements = splitSQLStatements(fullSQL);
for (let idx = 0; idx < resultSetDataArray.length; idx++) {
const rsData = resultSetDataArray[idx];
const rawStatement = (idx < statements.length) ? statements[idx] : '';
const plan = executablePlans[idx];
const originalSql = plan?.originalSql || '';
const executedSql = plan?.executedSql || originalSql;
// 检查是否为 affectedRows 类结果集
const isAffectedResult = Array.isArray(rsData.rows) && rsData.rows.length === 1
@@ -1881,8 +1843,8 @@ const QueryEditor: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAc
(row as any)[GONAVI_ROW_KEY] = 0;
nextResultSets.push({
key: `result-${idx + 1}`,
sql: rawStatement,
exportSql: rawStatement,
sql: executedSql,
exportSql: originalSql,
rows: [row],
columns: ['affectedRows'],
pkColumns: [],
@@ -1905,32 +1867,18 @@ const QueryEditor: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAc
if (row && typeof row === 'object') row[GONAVI_ROW_KEY] = i;
});
let simpleTableName: string | undefined = undefined;
if (rawStatement) {
// 支持多行 SQLSELECT [cols] FROM [schema.]table [WHERE...] [ORDER BY...] [LIMIT...] 等
// JOIN 查询表名歧义,不提取
const hasJoin = /\bJOIN\b/i.test(rawStatement);
const tableMatch = !hasJoin
? rawStatement.match(/^\s*SELECT\s+.+?\s+FROM\s+(?:[\w`"\[\].]+\.)?[`"\[]?(\w+)[`"\]]?\s*(?:$|[\s;])/im)
: null;
if (tableMatch) {
simpleTableName = tableMatch[1];
if (!forceReadOnlyResult) {
pendingPk.push({ resultKey: `result-${idx + 1}`, tableName: simpleTableName });
}
}
}
const tableRef = plan?.tableRef;
const editLocator = plan?.editLocator;
nextResultSets.push({
key: `result-${idx + 1}`,
sql: rawStatement,
exportSql: rawStatement,
sql: executedSql,
exportSql: originalSql,
rows,
columns: cols,
tableName: simpleTableName,
pkColumns: [],
readOnly: true,
pkLoading: !!simpleTableName,
tableName: tableRef?.tableName,
pkColumns: plan?.pkColumns || [],
editLocator,
readOnly: forceReadOnlyResult || !editLocator || editLocator.readOnly,
truncated
});
}
@@ -1939,21 +1887,8 @@ const QueryEditor: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAc
setResultSets(nextResultSets);
setActiveResultKey(nextResultSets[0]?.key || '');
pendingPk.forEach(({ resultKey, tableName }) => {
DBGetColumns(buildRpcConnectionConfig(config) as any, currentDb, tableName)
.then((resCols: any) => {
if (runSeqRef.current !== runSeq) return;
if (!resCols?.success) {
setResultSets(prev => prev.map(rs => rs.key === resultKey ? { ...rs, pkLoading: false, readOnly: false } : rs));
return;
}
const primaryKeys = (resCols.data as ColumnDefinition[]).filter(c => c.key === 'PRI').map(c => c.name);
setResultSets(prev => prev.map(rs => rs.key === resultKey ? { ...rs, pkColumns: primaryKeys, pkLoading: false, readOnly: false } : rs));
})
.catch(() => {
if (runSeqRef.current !== runSeq) return;
setResultSets(prev => prev.map(rs => rs.key === resultKey ? { ...rs, pkLoading: false, readOnly: false } : rs));
});
executablePlans.forEach((plan) => {
if (plan.warning) message.warning(plan.warning);
});
// 后端附带的提示信息(如数据源不支持原生多语句执行的回退提示)
@@ -2486,6 +2421,7 @@ const QueryEditor: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAc
dbName={currentDb}
connectionId={currentConnectionId}
pkColumns={rs.pkColumns}
editLocator={rs.editLocator}
onReload={() => handleReloadResult(rs.key, rs.sql)}
readOnly={rs.readOnly}
/>

View File

@@ -36,9 +36,9 @@ import { Tree, message, Dropdown, MenuProps, Input, Button, Modal, Form, Badge,
} from '@ant-design/icons';
import { useStore } from '../store';
import { buildOverlayWorkbenchTheme } from '../utils/overlayWorkbenchTheme';
import { SavedConnection, ExternalSQLTreeEntry, JVMCapability, JVMResourceSummary } from '../types';
import { SavedConnection, ExternalSQLTreeEntry, JVMCapability, JVMResourceSummary } from '../types';
import { getDbIcon } from './DatabaseIcons';
import { DBGetDatabases, DBGetTables, DBQuery, DBShowCreateTable, ExportTable, OpenSQLFile, ExecuteSQLFile, CancelSQLFileExecution, CreateDatabase, RenameDatabase, DropDatabase, RenameTable, DropTable, DropView, DropFunction, RenameView, SelectSQLDirectory, ListSQLDirectory, ReadSQLFile, JVMProbeCapabilities } from '../../wailsjs/go/app/App';
import { DBGetDatabases, DBGetTables, DBQuery, DBShowCreateTable, ExportTable, OpenSQLFile, ExecuteSQLFile, CancelSQLFileExecution, CreateDatabase, RenameDatabase, DropDatabase, RenameTable, DropTable, DropView, DropFunction, RenameView, SelectSQLDirectory, ListSQLDirectory, ReadSQLFile, JVMProbeCapabilities, GetDriverStatusList } from '../../wailsjs/go/app/App';
import { getTableDataDangerActionMeta, supportsTableTruncateAction, type TableDataDangerActionKind } from './tableDataDangerActions';
import { EventsOn } from '../../wailsjs/runtime/runtime';
import { isMacLikePlatform, normalizeOpacityForPlatform, resolveAppearanceValues } from '../utils/appearance';
@@ -81,6 +81,33 @@ interface BatchObjectItem {
dataRef: any;
}
type DriverStatusSnapshot = {
type: string;
name: string;
connectable: boolean;
expectedRevision?: string;
needsUpdate?: boolean;
updateReason?: string;
message?: string;
};
const DRIVER_STATUS_CACHE_TTL_MS = 30_000;
const normalizeDriverType = (value: string): string => {
const normalized = String(value || '').trim().toLowerCase();
if (normalized === 'postgresql') return 'postgres';
if (normalized === 'doris') return 'diros';
return normalized;
};
const resolveSavedConnectionDriverType = (conn: SavedConnection | undefined): string => {
const type = normalizeDriverType(conn?.config?.type || '');
if (type !== 'custom') {
return type;
}
return normalizeDriverType(conn?.config?.driver || '');
};
const SEARCH_SCOPE_OPTIONS: Array<{ value: SearchScope; label: string }> = [
{ value: 'smart', label: '智能' },
{ value: 'object', label: '表对象' },
@@ -211,10 +238,12 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
const [autoExpandParent, setAutoExpandParent] = useState(true);
const [loadedKeys, setLoadedKeys] = useState<React.Key[]>([]);
const [selectedKeys, setSelectedKeys] = useState<React.Key[]>([]);
const selectedNodesRef = useRef<any[]>([]);
const loadingNodesRef = useRef<Set<string>>(new Set());
const clickTimerRef = useRef<ReturnType<typeof setTimeout> | null>(null);
const [contextMenu, setContextMenu] = useState<{ x: number, y: number, items: MenuProps['items'] } | null>(null);
const selectedNodesRef = useRef<any[]>([]);
const loadingNodesRef = useRef<Set<string>>(new Set());
const clickTimerRef = useRef<ReturnType<typeof setTimeout> | null>(null);
const driverStatusCacheRef = useRef<{ fetchedAt: number; items: Record<string, DriverStatusSnapshot> } | null>(null);
const driverUpdateWarningKeysRef = useRef<Set<string>>(new Set());
const [contextMenu, setContextMenu] = useState<{ x: number, y: number, items: MenuProps['items'] } | null>(null);
// Virtual Scroll State
const [treeHeight, setTreeHeight] = useState(500);
@@ -956,13 +985,72 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
const typeLabel = normalizedType === 'PROCEDURE' ? 'P' : 'F';
routines.push({ displayName: `${fullName} [${typeLabel}]`, routineName: fullName, routineType: normalizedType });
});
});
return { routines, supported: hasSuccessfulQuery };
};
});
return { routines, supported: hasSuccessfulQuery };
};
const loadDatabases = async (node: any) => {
const conn = node.dataRef as SavedConnection;
const loadKey = `dbs-${conn.id}`;
const fetchDriverStatusMap = async (): Promise<Record<string, DriverStatusSnapshot>> => {
const cached = driverStatusCacheRef.current;
if (cached && Date.now() - cached.fetchedAt < DRIVER_STATUS_CACHE_TTL_MS) {
return cached.items;
}
const result: Record<string, DriverStatusSnapshot> = {};
const res = await GetDriverStatusList('', '');
if (!res?.success) {
return result;
}
const data = (res.data || {}) as any;
const drivers = Array.isArray(data.drivers) ? data.drivers : [];
drivers.forEach((item: any) => {
const type = normalizeDriverType(String(item.type || '').trim());
if (!type) return;
result[type] = {
type,
name: String(item.name || item.type || type).trim(),
connectable: !!item.connectable,
expectedRevision: String(item.expectedRevision || '').trim() || undefined,
needsUpdate: !!item.needsUpdate,
updateReason: String(item.updateReason || '').trim() || undefined,
message: String(item.message || '').trim() || undefined,
};
});
driverStatusCacheRef.current = { fetchedAt: Date.now(), items: result };
return result;
};
const warnIfConnectionDriverAgentNeedsUpdate = async (conn: SavedConnection) => {
try {
const driverType = resolveSavedConnectionDriverType(conn);
if (!driverType || driverType === 'custom') {
return;
}
const statusMap = await fetchDriverStatusMap();
const status = statusMap[driverType];
if (!status?.connectable || !status.needsUpdate) {
return;
}
const revisionKey = status.expectedRevision || status.updateReason || status.message || 'unknown';
const warningKey = `${conn.id}:${driverType}:${revisionKey}`;
if (driverUpdateWarningKeysRef.current.has(warningKey)) {
return;
}
driverUpdateWarningKeysRef.current.add(warningKey);
const driverName = status.name || driverType;
const reason = status.message || status.updateReason || `${driverName} driver-agent 与当前 GoNavi 版本要求不一致`;
message.warning({
content: `${driverName} 驱动代理需要重装:${reason}`,
key: `driver-agent-update-${conn.id}`,
duration: 10,
});
} catch (error) {
console.warn('检查驱动代理更新状态失败', error);
}
};
const loadDatabases = async (node: any) => {
const conn = node.dataRef as SavedConnection;
void warnIfConnectionDriverAgentNeedsUpdate(conn);
const loadKey = `dbs-${conn.id}`;
if (loadingNodesRef.current.has(loadKey)) return;
loadingNodesRef.current.add(loadKey);
const config = {
@@ -1845,8 +1933,9 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
setIsBatchModalOpen(true);
};
const loadDatabasesForBatch = async (conn: SavedConnection) => {
const config = {
const loadDatabasesForBatch = async (conn: SavedConnection) => {
void warnIfConnectionDriverAgentNeedsUpdate(conn);
const config = {
...conn.config,
port: Number(conn.config.port),
password: conn.config.password || "",
@@ -2154,10 +2243,11 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
setIsBatchDbModalOpen(true);
};
const loadDatabasesForDbBatch = async (conn: SavedConnection) => {
setBatchConnContext(conn);
const loadDatabasesForDbBatch = async (conn: SavedConnection) => {
setBatchConnContext(conn);
void warnIfConnectionDriverAgentNeedsUpdate(conn);
const config = {
const config = {
...conn.config,
port: Number(conn.config.port),
password: conn.config.password || "",

View File

@@ -187,6 +187,29 @@ describe('store appearance persistence', () => {
expect(useStore.getState().connections[0]?.iconColor).toBe('#2f855a');
});
it('normalizes ClickHouse protocol override when replacing saved connections', async () => {
const { useStore } = await importStore();
useStore.getState().replaceConnections([
{
id: 'clickhouse-http',
name: 'ClickHouse HTTP',
config: {
id: 'clickhouse-http',
type: 'clickhouse',
host: 'clickhouse.local',
port: 8125,
user: 'default',
clickHouseProtocol: 'https' as any,
},
},
]);
expect(useStore.getState().connections[0]?.config.clickHouseProtocol).toBe(
'http',
);
});
it('keeps legacy global proxy password during hydration until explicit cleanup', async () => {
storage.setItem('lite-db-storage', JSON.stringify({
state: {

View File

@@ -163,6 +163,15 @@ const toTrimmedString = (value: unknown, fallback = ""): string => {
return fallback;
};
const normalizeClickHouseProtocol = (
value: unknown,
): "auto" | "http" | "native" => {
const text = toTrimmedString(value).toLowerCase();
if (text === "http" || text === "https") return "http";
if (text === "native" || text === "tcp") return "native";
return "auto";
};
const normalizePort = (value: unknown, fallbackPort: number): number => {
const parsed = Number(value);
if (!Number.isFinite(parsed)) return fallbackPort;
@@ -513,6 +522,12 @@ const sanitizeConnectionConfig = (value: unknown): ConnectionConfig => {
safeConfig.redisDB = normalizeIntegerInRange(raw.redisDB, 0, 0, 15);
}
if (type === "clickhouse") {
safeConfig.clickHouseProtocol = normalizeClickHouseProtocol(
raw.clickHouseProtocol,
);
}
if (type === "custom") {
safeConfig.driver = toTrimmedString(raw.driver);
safeConfig.dsn = toTrimmedString(raw.dsn).slice(0, MAX_URI_LENGTH);

View File

@@ -297,6 +297,7 @@ export interface ConnectionConfig {
timeout?: number;
redisDB?: number; // Redis database index (0-15)
uri?: string; // Connection URI for copy/paste
clickHouseProtocol?: "auto" | "http" | "native"; // ClickHouse connection protocol override
hosts?: string[]; // Multi-host addresses: host:port
topology?: "single" | "replica" | "cluster";
mysqlReplicaUser?: string;

View File

@@ -39,6 +39,19 @@ describe('buildRpcConnectionConfig', () => {
expect(result.database).toBe('app');
});
it('preserves ClickHouse protocol override for RPC calls', () => {
const result = buildRpcConnectionConfig({
id: 'conn-clickhouse',
type: 'clickhouse',
host: 'clickhouse.local',
port: 8125,
user: 'default',
clickHouseProtocol: 'http',
} as any);
expect(result.clickHouseProtocol).toBe('http');
});
it('fills default nested config blocks needed by RPC calls', () => {
const result = buildRpcConnectionConfig({
id: 'conn-redis',

View File

@@ -1,6 +1,8 @@
import { describe, expect, it } from 'vitest';
import { convertMongoShellToJsonCommand } from './mongodb';
import { applyMongoQueryAutoLimit, convertMongoShellToJsonCommand } from './mongodb';
const parseCommand = (command: string | undefined) => JSON.parse(command || '{}');
describe('convertMongoShellToJsonCommand', () => {
it('converts show dbs shell shortcut to listDatabases command', () => {
@@ -16,4 +18,105 @@ describe('convertMongoShellToJsonCommand', () => {
command: JSON.stringify({ listCollections: 1, filter: {}, nameOnly: true }),
});
});
it('converts find shell commands without adding implicit limit', () => {
const result = convertMongoShellToJsonCommand('db.users.find({ active: true })');
expect(result.recognized).toBe(true);
expect(parseCommand(result.command)).toEqual({
find: 'users',
filter: { active: true },
});
});
it('keeps explicit find limit values from shell commands', () => {
const result = convertMongoShellToJsonCommand('db.users.find({}).limit(10)');
expect(parseCommand(result.command)).toEqual({
find: 'users',
filter: {},
limit: 10,
});
});
it('keeps explicit zero limit values from shell commands', () => {
const result = convertMongoShellToJsonCommand('db.users.find({}).limit(0)');
expect(parseCommand(result.command)).toEqual({
find: 'users',
filter: {},
limit: 0,
});
});
});
describe('applyMongoQueryAutoLimit', () => {
it('adds limit to raw Mongo find commands', () => {
const result = applyMongoQueryAutoLimit('{"find":"users","filter":{}}', 500);
expect(result.applied).toBe(true);
expect(parseCommand(result.command)).toEqual({
find: 'users',
filter: {},
limit: 500,
});
});
it('adds limit after shell find conversion', () => {
const shell = convertMongoShellToJsonCommand('db.users.find({ active: true })');
const result = applyMongoQueryAutoLimit(shell.command || '', 500);
expect(result.applied).toBe(true);
expect(parseCommand(result.command)).toEqual({
find: 'users',
filter: { active: true },
limit: 500,
});
});
it('does not replace explicit find limits', () => {
const result = applyMongoQueryAutoLimit('{"find":"users","filter":{},"limit":10}', 500);
expect(result.applied).toBe(false);
expect(parseCommand(result.command)).toEqual({
find: 'users',
filter: {},
limit: 10,
});
});
it('adds $limit to read-only aggregate pipelines', () => {
const result = applyMongoQueryAutoLimit('{"aggregate":"users","pipeline":[{"$match":{"active":true}}],"cursor":{}}', 500);
expect(result.applied).toBe(true);
expect(parseCommand(result.command)).toEqual({
aggregate: 'users',
pipeline: [
{ $match: { active: true } },
{ $limit: 500 },
],
cursor: {},
});
});
it('does not add another aggregate $limit', () => {
const command = '{"aggregate":"users","pipeline":[{"$limit":10}],"cursor":{}}';
const result = applyMongoQueryAutoLimit(command, 500);
expect(result.applied).toBe(false);
expect(result.command).toBe(command);
});
it('does not alter aggregate write pipelines', () => {
const command = '{"aggregate":"users","pipeline":[{"$match":{}},{"$out":"tmp_users"}],"cursor":{}}';
const result = applyMongoQueryAutoLimit(command, 500);
expect(result.applied).toBe(false);
expect(result.command).toBe(command);
});
it('does not limit non-read or invalid commands', () => {
expect(applyMongoQueryAutoLimit('{"count":"users","query":{}}', 500).applied).toBe(false);
expect(applyMongoQueryAutoLimit('db.users.find({})', 500).applied).toBe(false);
});
});

View File

@@ -321,7 +321,7 @@ const parseCollectionAndMethod = (raw: string): {
pos = nextPos;
} else {
let end = pos;
while (end < input.length && /[A-Za-z0-9_$.-]/.test(input[end])) end++;
while (end < input.length && /[A-Za-z0-9_$-]/.test(input[end])) end++;
collection = input.slice(pos, end).trim();
pos = end;
}
@@ -662,7 +662,7 @@ export const buildMongoFindCommand = (params: {
if (params.sort && Object.keys(params.sort).length > 0) {
command.sort = params.sort;
}
if (Number.isFinite(params.limit) && Number(params.limit) > 0) {
if (Number.isFinite(params.limit) && Number(params.limit) >= 0) {
command.limit = Math.floor(Number(params.limit));
}
if (Number.isFinite(params.skip) && Number(params.skip) > 0) {
@@ -678,6 +678,45 @@ export const buildMongoCountCommand = (collection: string, filter: Record<string
});
};
const hasOwn = (obj: Record<string, unknown>, key: string) => Object.prototype.hasOwnProperty.call(obj, key);
const isMongoCommandObject = (value: unknown): value is Record<string, unknown> => (
!!value && typeof value === 'object' && !Array.isArray(value)
);
export const applyMongoQueryAutoLimit = (
command: string,
maxRows: number,
): { command: string; applied: boolean; maxRows: number } => {
if (!Number.isFinite(maxRows) || maxRows <= 0) return { command, applied: false, maxRows };
let parsed: unknown;
try {
parsed = JSON.parse(String(command || '').trim());
} catch {
return { command, applied: false, maxRows };
}
if (!isMongoCommandObject(parsed)) return { command, applied: false, maxRows };
const nextMaxRows = Math.floor(Number(maxRows));
if (hasOwn(parsed, 'find')) {
if (hasOwn(parsed, 'limit')) return { command, applied: false, maxRows };
parsed.limit = nextMaxRows;
return { command: JSON.stringify(parsed), applied: true, maxRows };
}
if (hasOwn(parsed, 'aggregate') && Array.isArray(parsed.pipeline)) {
const pipeline = parsed.pipeline as unknown[];
const hasExplicitLimit = pipeline.some((stage) => isMongoCommandObject(stage) && hasOwn(stage, '$limit'));
const hasWriteStage = pipeline.some((stage) => isMongoCommandObject(stage) && (hasOwn(stage, '$out') || hasOwn(stage, '$merge')));
if (hasExplicitLimit || hasWriteStage) return { command, applied: false, maxRows };
pipeline.push({ $limit: nextMaxRows });
return { command: JSON.stringify(parsed), applied: true, maxRows };
}
return { command, applied: false, maxRows };
};
const buildMongoInsertCommand = (
collection: string,
documents: Record<string, unknown>[],

View File

@@ -0,0 +1,110 @@
import { describe, expect, it } from 'vitest';
import { applyQueryAutoLimit } from './queryAutoLimit';
describe('applyQueryAutoLimit', () => {
const limitDialects = [
'mysql',
'mariadb',
'diros',
'doris',
'sphinx',
'postgres',
'postgresql',
'kingbase',
'kingbase8',
'highgo',
'vastbase',
'sqlite',
'sqlite3',
'duckdb',
'clickhouse',
'tdengine',
];
it.each(limitDialects)('adds generic LIMIT for %s connections', (dbType) => {
expect(applyQueryAutoLimit('SELECT * FROM users', dbType, 500).sql)
.toBe('SELECT * FROM users LIMIT 500');
});
it.each([
['oracle'],
['dameng'],
['dm'],
['dm8'],
])('adds FETCH FIRST limit for %s connections', (dbType) => {
expect(applyQueryAutoLimit('SELECT * FROM MYCIMLED.EDC_LOG', dbType, 500).sql)
.toBe('SELECT * FROM MYCIMLED.EDC_LOG FETCH FIRST 500 ROWS ONLY');
});
it.each([
['sqlserver'],
['mssql'],
['sql_server'],
['sql-server'],
])('adds TOP limit for %s connections', (dbType) => {
expect(applyQueryAutoLimit('SELECT * FROM users', dbType, 500).sql)
.toBe('SELECT TOP 500 * FROM users');
});
it('adds SQL Server TOP after DISTINCT', () => {
expect(applyQueryAutoLimit('SELECT DISTINCT name FROM users', 'sqlserver', 500).sql)
.toBe('SELECT DISTINCT TOP 500 name FROM users');
});
it.each([
['oracle', 'SELECT * FROM users FETCH FIRST 500 ROWS ONLY'],
['dm8', 'SELECT * FROM users FETCH FIRST 500 ROWS ONLY'],
['mssql', 'SELECT TOP 500 * FROM users'],
['postgresql', 'SELECT * FROM users LIMIT 500'],
['doris', 'SELECT * FROM users LIMIT 500'],
['sqlite3', 'SELECT * FROM users LIMIT 500'],
])('uses custom driver dialect %s', (driver, expected) => {
expect(applyQueryAutoLimit('SELECT * FROM users', 'custom', 500, driver).sql)
.toBe(expected);
});
it('keeps trailing semicolon and comments after injected Oracle limit', () => {
expect(applyQueryAutoLimit('SELECT * FROM MYCIMLED.EDC_LOG; -- preview', 'oracle', 500).sql)
.toBe('SELECT * FROM MYCIMLED.EDC_LOG FETCH FIRST 500 ROWS ONLY; -- preview');
});
it('does not add another generic limit when SQL already limits rows', () => {
expect(applyQueryAutoLimit('SELECT * FROM users LIMIT 10', 'mysql', 500).applied)
.toBe(false);
expect(applyQueryAutoLimit('SELECT * FROM users OFFSET 10 LIMIT 10', 'postgres', 500).applied)
.toBe(false);
});
it('does not treat nested LIMIT as the outer query limit', () => {
expect(applyQueryAutoLimit('SELECT * FROM (SELECT * FROM users LIMIT 10) t', 'postgres', 500).sql)
.toBe('SELECT * FROM (SELECT * FROM users LIMIT 10) t LIMIT 500');
});
it('does not add another Oracle limit when Oracle SQL already limits rows', () => {
expect(applyQueryAutoLimit('SELECT * FROM users WHERE ROWNUM <= 10', 'oracle', 500).applied)
.toBe(false);
expect(applyQueryAutoLimit('SELECT * FROM users FETCH FIRST 10 ROWS ONLY', 'oracle', 500).applied)
.toBe(false);
});
it('does not add another SQL Server limit when SQL already uses TOP', () => {
expect(applyQueryAutoLimit('SELECT TOP 10 * FROM users', 'sqlserver', 500).applied)
.toBe(false);
});
it('adds generic LIMIT before locking clauses', () => {
expect(applyQueryAutoLimit('SELECT * FROM users FOR UPDATE', 'mysql', 500).sql)
.toBe('SELECT * FROM users LIMIT 500 FOR UPDATE');
});
it('adds generic LIMIT before OFFSET clauses', () => {
expect(applyQueryAutoLimit('SELECT * FROM users OFFSET 10', 'postgres', 500).sql)
.toBe('SELECT * FROM users LIMIT 500 OFFSET 10');
});
it('does not limit non-select statements', () => {
expect(applyQueryAutoLimit('UPDATE users SET name = \'a\'', 'mysql', 500).applied)
.toBe(false);
});
});

View File

@@ -0,0 +1,336 @@
import { resolveSqlDialect } from './sqlDialect';
const isWS = (ch: string) => ch === ' ' || ch === '\t' || ch === '\n' || ch === '\r';
const isWord = (ch: string) => /[A-Za-z0-9_]/.test(ch);
const getLeadingKeyword = (sql: string): string => {
const text = (sql || '').replace(/\r\n/g, '\n');
let inSingle = false;
let inDouble = false;
let inBacktick = false;
let escaped = false;
let inLineComment = false;
let inBlockComment = false;
let dollarTag: string | null = null;
for (let i = 0; i < text.length; i++) {
const ch = text[i];
const next = i + 1 < text.length ? text[i + 1] : '';
const prev = i > 0 ? text[i - 1] : '';
const next2 = i + 2 < text.length ? text[i + 2] : '';
if (!inSingle && !inDouble && !inBacktick) {
if (inLineComment) {
if (ch === '\n') inLineComment = false;
continue;
}
if (inBlockComment) {
if (ch === '*' && next === '/') {
i++;
inBlockComment = false;
}
continue;
}
if (ch === '/' && next === '*') {
i++;
inBlockComment = true;
continue;
}
if (ch === '#') {
inLineComment = true;
continue;
}
if (ch === '-' && next === '-' && (i === 0 || isWS(prev)) && (next2 === '' || isWS(next2))) {
i++;
inLineComment = true;
continue;
}
if (dollarTag) {
if (text.startsWith(dollarTag, i)) {
i += dollarTag.length - 1;
dollarTag = null;
}
continue;
}
if (ch === '$') {
const m = text.slice(i).match(/^\$[A-Za-z0-9_]*\$/);
if (m && m[0]) {
dollarTag = m[0];
i += dollarTag.length - 1;
continue;
}
}
}
if (escaped) {
escaped = false;
continue;
}
if ((inSingle || inDouble) && ch === '\\') {
escaped = true;
continue;
}
if (!inDouble && !inBacktick && ch === "'") {
inSingle = !inSingle;
continue;
}
if (!inSingle && !inBacktick && ch === '"') {
inDouble = !inDouble;
continue;
}
if (!inSingle && !inDouble && ch === '`') {
inBacktick = !inBacktick;
continue;
}
if (inSingle || inDouble || inBacktick || dollarTag) continue;
if (isWS(ch)) continue;
if (isWord(ch)) {
let j = i;
while (j < text.length && isWord(text[j])) j++;
return text.slice(i, j).toLowerCase();
}
return '';
}
return '';
};
const splitSqlTail = (sql: string): { main: string; tail: string } => {
const text = (sql || '').replace(/\r\n/g, '\n');
let inSingle = false;
let inDouble = false;
let inBacktick = false;
let escaped = false;
let inLineComment = false;
let inBlockComment = false;
let dollarTag: string | null = null;
let lastMeaningful = -1;
for (let i = 0; i < text.length; i++) {
const ch = text[i];
const next = i + 1 < text.length ? text[i + 1] : '';
const prev = i > 0 ? text[i - 1] : '';
const next2 = i + 2 < text.length ? text[i + 2] : '';
if (!inSingle && !inDouble && !inBacktick) {
if (dollarTag) {
if (text.startsWith(dollarTag, i)) {
lastMeaningful = i + dollarTag.length - 1;
i += dollarTag.length - 1;
dollarTag = null;
} else if (!isWS(ch)) {
lastMeaningful = i;
}
continue;
}
if (inLineComment) {
if (ch === '\n') inLineComment = false;
continue;
}
if (inBlockComment) {
if (ch === '*' && next === '/') {
i++;
inBlockComment = false;
}
continue;
}
if (ch === '/' && next === '*') {
i++;
inBlockComment = true;
continue;
}
if (ch === '#') {
inLineComment = true;
continue;
}
if (ch === '-' && next === '-' && (i === 0 || isWS(prev)) && (next2 === '' || isWS(next2))) {
i++;
inLineComment = true;
continue;
}
if (ch === '$') {
const m = text.slice(i).match(/^\$[A-Za-z0-9_]*\$/);
if (m && m[0]) {
dollarTag = m[0];
lastMeaningful = i + dollarTag.length - 1;
i += dollarTag.length - 1;
continue;
}
}
}
if (escaped) {
escaped = false;
} else if ((inSingle || inDouble) && ch === '\\') {
escaped = true;
} else {
if (!inDouble && !inBacktick && ch === "'") inSingle = !inSingle;
else if (!inSingle && !inBacktick && ch === '"') inDouble = !inDouble;
else if (!inSingle && !inDouble && ch === '`') inBacktick = !inBacktick;
}
if (!inLineComment && !inBlockComment && !isWS(ch)) {
lastMeaningful = i;
}
}
if (lastMeaningful < 0) return { main: '', tail: text };
let mainEnd = lastMeaningful + 1;
while (mainEnd > 0 && (isWS(text[mainEnd - 1]) || text[mainEnd - 1] === ';' || text[mainEnd - 1] === '')) {
mainEnd--;
}
return { main: text.slice(0, mainEnd), tail: text.slice(mainEnd) };
};
const findTopLevelKeyword = (sql: string, keyword: string): number => {
const text = sql;
const kw = keyword.toLowerCase();
let inSingle = false;
let inDouble = false;
let inBacktick = false;
let escaped = false;
let inLineComment = false;
let inBlockComment = false;
let dollarTag: string | null = null;
let parenDepth = 0;
for (let i = 0; i < text.length; i++) {
const ch = text[i];
const next = i + 1 < text.length ? text[i + 1] : '';
const prev = i > 0 ? text[i - 1] : '';
const next2 = i + 2 < text.length ? text[i + 2] : '';
if (!inSingle && !inDouble && !inBacktick) {
if (inLineComment) {
if (ch === '\n') inLineComment = false;
continue;
}
if (inBlockComment) {
if (ch === '*' && next === '/') {
i++;
inBlockComment = false;
}
continue;
}
if (ch === '/' && next === '*') {
i++;
inBlockComment = true;
continue;
}
if (ch === '#') {
inLineComment = true;
continue;
}
if (ch === '-' && next === '-' && (i === 0 || isWS(prev)) && (next2 === '' || isWS(next2))) {
i++;
inLineComment = true;
continue;
}
if (dollarTag) {
if (text.startsWith(dollarTag, i)) {
i += dollarTag.length - 1;
dollarTag = null;
}
continue;
}
if (ch === '$') {
const m = text.slice(i).match(/^\$[A-Za-z0-9_]*\$/);
if (m && m[0]) {
dollarTag = m[0];
i += dollarTag.length - 1;
continue;
}
}
}
if (escaped) {
escaped = false;
continue;
}
if ((inSingle || inDouble) && ch === '\\') {
escaped = true;
continue;
}
if (!inDouble && !inBacktick && ch === "'") {
inSingle = !inSingle;
continue;
}
if (!inSingle && !inBacktick && ch === '"') {
inDouble = !inDouble;
continue;
}
if (!inSingle && !inDouble && ch === '`') {
inBacktick = !inBacktick;
continue;
}
if (inSingle || inDouble || inBacktick || dollarTag) continue;
if (ch === '(') {
parenDepth++;
continue;
}
if (ch === ')') {
if (parenDepth > 0) parenDepth--;
continue;
}
if (parenDepth !== 0) continue;
if (!isWord(ch)) continue;
if (text.slice(i, i + kw.length).toLowerCase() !== kw) continue;
const before = i - 1 >= 0 ? text[i - 1] : '';
const after = i + kw.length < text.length ? text[i + kw.length] : '';
if ((before && isWord(before)) || (after && isWord(after))) continue;
return i;
}
return -1;
};
export const applyQueryAutoLimit = (
sql: string,
dbType: string,
maxRows: number,
driver = '',
): { sql: string; applied: boolean; maxRows: number } => {
if (!Number.isFinite(maxRows) || maxRows <= 0) return { sql, applied: false, maxRows };
const normalizedType = String(resolveSqlDialect(dbType || 'mysql', driver)).toLowerCase();
const keyword = getLeadingKeyword(sql);
if (keyword !== 'select') return { sql, applied: false, maxRows };
const { main, tail } = splitSqlTail(sql);
if (!main.trim()) return { sql, applied: false, maxRows };
const fromPos = findTopLevelKeyword(main, 'from');
const limitPos = findTopLevelKeyword(main, 'limit');
if (limitPos >= 0 && (fromPos < 0 || limitPos > fromPos)) return { sql, applied: false, maxRows };
const fetchPos = findTopLevelKeyword(main, 'fetch');
if (fetchPos >= 0 && (fromPos < 0 || fetchPos > fromPos)) return { sql, applied: false, maxRows };
if (normalizedType === 'sqlserver' || normalizedType === 'mssql') {
const topPos = findTopLevelKeyword(main, 'top');
if (topPos >= 0) return { sql, applied: false, maxRows };
const selectPos = findTopLevelKeyword(main, 'select');
if (selectPos < 0) return { sql, applied: false, maxRows };
const afterSelect = selectPos + 'SELECT'.length;
const restAfterSelect = main.slice(afterSelect);
const distinctMatch = restAfterSelect.match(/^(\s+DISTINCT\b)/i);
const insertOffset = distinctMatch ? afterSelect + distinctMatch[1].length : afterSelect;
const nextMain = main.slice(0, insertOffset) + ` TOP ${maxRows}` + main.slice(insertOffset);
return { sql: nextMain + tail, applied: true, maxRows };
}
if (normalizedType === 'oracle' || normalizedType === 'dameng') {
const rownumPos = findTopLevelKeyword(main, 'rownum');
if (rownumPos >= 0) return { sql, applied: false, maxRows };
const offsetPos = findTopLevelKeyword(main, 'offset');
if (offsetPos >= 0 && (fromPos < 0 || offsetPos > fromPos)) return { sql, applied: false, maxRows };
return { sql: `${main.trimEnd()} FETCH FIRST ${maxRows} ROWS ONLY${tail}`, applied: true, maxRows };
}
const offsetPos = findTopLevelKeyword(main, 'offset');
const forPos = findTopLevelKeyword(main, 'for');
const lockPos = findTopLevelKeyword(main, 'lock');
const candidates = [offsetPos, forPos, lockPos]
.filter(pos => pos >= 0 && (fromPos < 0 || pos > fromPos));
const insertAt = candidates.length > 0 ? Math.min(...candidates) : main.length;
const before = main.slice(0, insertAt).trimEnd();
const after = main.slice(insertAt).trimStart();
const nextMain = [before, `LIMIT ${maxRows}`, after].filter(Boolean).join(' ').trim();
return { sql: nextMain + tail, applied: true, maxRows };
};

View File

@@ -0,0 +1,44 @@
import { describe, expect, it } from 'vitest';
import { extractQueryResultTableRef } from './queryResultTable';
describe('extractQueryResultTableRef', () => {
it('preserves Oracle schema-qualified table names for editing', () => {
expect(extractQueryResultTableRef('SELECT * FROM MYCIMLED.EDC_LOG FETCH FIRST 500 ROWS ONLY', 'oracle', 'ANONYMOUS'))
.toEqual({
tableName: 'MYCIMLED.EDC_LOG',
metadataDbName: 'MYCIMLED',
metadataTableName: 'EDC_LOG',
});
});
it('uses current schema for unqualified Oracle tables', () => {
expect(extractQueryResultTableRef('SELECT * FROM EDC_LOG', 'oracle', 'MYCIMLED'))
.toEqual({
tableName: 'EDC_LOG',
metadataDbName: 'MYCIMLED',
metadataTableName: 'EDC_LOG',
});
});
it('keeps existing simple table behavior for MySQL-style qualified names', () => {
expect(extractQueryResultTableRef('SELECT * FROM app.users LIMIT 500', 'mysql', 'app'))
.toEqual({
tableName: 'users',
metadataDbName: 'app',
metadataTableName: 'users',
});
});
it('does not mark join results as editable table refs', () => {
expect(extractQueryResultTableRef('SELECT * FROM users u JOIN orders o ON u.id = o.user_id', 'oracle', 'APP'))
.toBeUndefined();
});
it('does not mark grouped or distinct results as editable table refs', () => {
expect(extractQueryResultTableRef('SELECT ID FROM users GROUP BY ID', 'mysql', 'app'))
.toBeUndefined();
expect(extractQueryResultTableRef('SELECT DISTINCT ID FROM users', 'mysql', 'app'))
.toBeUndefined();
});
});

View File

@@ -0,0 +1,64 @@
export type QueryResultTableRef = {
tableName: string;
metadataDbName: string;
metadataTableName: string;
};
const stripIdentifierQuotes = (part: string): string => {
const text = String(part || '').trim();
if (!text) return '';
if ((text.startsWith('`') && text.endsWith('`')) || (text.startsWith('"') && text.endsWith('"'))) {
return text.slice(1, -1).trim();
}
if (text.startsWith('[') && text.endsWith(']')) {
return text.slice(1, -1).trim();
}
return text;
};
const normalizeQualifiedName = (raw: string): string => (
String(raw || '')
.split('.')
.map((part) => stripIdentifierQuotes(part.trim()))
.filter(Boolean)
.join('.')
);
const isOracleLikeDialect = (dialect: string): boolean => {
const normalized = String(dialect || '').trim().toLowerCase();
return normalized === 'oracle' || normalized === 'dameng' || normalized === 'dm' || normalized === 'dm8';
};
export const extractQueryResultTableRef = (
sql: string,
dialect: string,
currentDb: string,
): QueryResultTableRef | undefined => {
const text = String(sql || '').trim();
if (!text) return undefined;
if (/\b(JOIN|UNION|INTERSECT|EXCEPT|MINUS)\b/i.test(text)) return undefined;
if (/^\s*SELECT\s+DISTINCT\b/i.test(text)) return undefined;
if (/\bGROUP\s+BY\b|\bHAVING\b/i.test(text)) return undefined;
const tableMatch = text.match(/^\s*SELECT\s+.+?\s+FROM\s+((?:[`"\[]?\w+[`"\]]?)(?:\s*\.\s*(?:[`"\[]?\w+[`"\]]?)){0,2})\s*(?:$|[\s;])/im);
if (!tableMatch) return undefined;
const qualifiedName = normalizeQualifiedName(tableMatch[1]);
if (!qualifiedName) return undefined;
const parts = qualifiedName.split('.').filter(Boolean);
const metadataTableName = parts[parts.length - 1] || '';
if (!metadataTableName) return undefined;
const owner = parts.length >= 2 ? parts[parts.length - 2] : '';
const metadataDbName = owner || currentDb || '';
const tableName = isOracleLikeDialect(dialect) && owner
? `${owner}.${metadataTableName}`
: metadataTableName;
return {
tableName,
metadataDbName,
metadataTableName,
};
};

View File

@@ -0,0 +1,146 @@
import { describe, expect, it } from 'vitest';
import {
ORACLE_ROWID_LOCATOR_COLUMN,
filterHiddenLocatorColumns,
resolveEditRowLocator,
resolveRowLocatorValues,
} from './rowLocator';
const uniqueIndex = (name: string, columnName: string, seqInIndex = 1) => ({
name,
columnName,
seqInIndex,
nonUnique: 0,
indexType: 'BTREE',
});
const normalIndex = (name: string, columnName: string, seqInIndex = 1) => ({
name,
columnName,
seqInIndex,
nonUnique: 1,
indexType: 'BTREE',
});
describe('resolveEditRowLocator', () => {
it('prefers primary keys over unique indexes', () => {
expect(resolveEditRowLocator({
dbType: 'mysql',
resultColumns: ['ID', 'EMAIL'],
primaryKeys: ['ID'],
indexes: [uniqueIndex('uk_email', 'EMAIL')],
})).toEqual({
strategy: 'primary-key',
columns: ['ID'],
valueColumns: ['ID'],
readOnly: false,
});
});
it('uses a unique index when there is no primary key', () => {
expect(resolveEditRowLocator({
dbType: 'mysql',
resultColumns: ['EMAIL', 'NAME'],
indexes: [uniqueIndex('uk_email', 'EMAIL')],
})).toEqual({
strategy: 'unique-key',
columns: ['EMAIL'],
valueColumns: ['EMAIL'],
readOnly: false,
});
});
it('sorts composite unique index columns by sequence', () => {
expect(resolveEditRowLocator({
dbType: 'postgres',
resultColumns: ['TENANT_ID', 'CODE', 'NAME'],
indexes: [
uniqueIndex('uk_tenant_code', 'CODE', 2),
uniqueIndex('uk_tenant_code', 'TENANT_ID', 1),
],
})).toMatchObject({
strategy: 'unique-key',
columns: ['TENANT_ID', 'CODE'],
valueColumns: ['TENANT_ID', 'CODE'],
readOnly: false,
});
});
it('ignores non-unique indexes', () => {
expect(resolveEditRowLocator({
dbType: 'mysql',
resultColumns: ['NAME'],
indexes: [normalIndex('idx_name', 'NAME')],
})).toMatchObject({
strategy: 'none',
readOnly: true,
});
});
it('keeps results read-only when primary key columns are missing from result columns', () => {
expect(resolveEditRowLocator({
dbType: 'oracle',
resultColumns: ['NAME'],
primaryKeys: ['ID'],
})).toMatchObject({
strategy: 'none',
readOnly: true,
reason: '结果集中缺少主键列 ID无法安全提交修改。',
});
});
it('uses Oracle ROWID when no primary or unique key is available', () => {
expect(resolveEditRowLocator({
dbType: 'oracle',
resultColumns: ['NAME', ORACLE_ROWID_LOCATOR_COLUMN],
allowOracleRowID: true,
})).toEqual({
strategy: 'oracle-rowid',
columns: ['ROWID'],
valueColumns: [ORACLE_ROWID_LOCATOR_COLUMN],
hiddenColumns: [ORACLE_ROWID_LOCATOR_COLUMN],
readOnly: false,
});
});
});
describe('resolveRowLocatorValues', () => {
it('extracts locator values from the original row', () => {
const locator = resolveEditRowLocator({
dbType: 'mysql',
resultColumns: ['EMAIL', 'NAME'],
indexes: [uniqueIndex('uk_email', 'EMAIL')],
});
expect(resolveRowLocatorValues(locator, { EMAIL: 'a@example.com', NAME: 'A' })).toEqual({
ok: true,
values: { EMAIL: 'a@example.com' },
});
});
it('rejects nullable unique locator values', () => {
const locator = resolveEditRowLocator({
dbType: 'mysql',
resultColumns: ['EMAIL', 'NAME'],
indexes: [uniqueIndex('uk_email', 'EMAIL')],
});
expect(resolveRowLocatorValues(locator, { EMAIL: null, NAME: 'A' })).toEqual({
ok: false,
error: '定位列 EMAIL 的值为空,无法安全提交修改。',
});
});
});
describe('filterHiddenLocatorColumns', () => {
it('removes hidden Oracle ROWID columns from displayed columns', () => {
const locator = resolveEditRowLocator({
dbType: 'oracle',
resultColumns: ['NAME', ORACLE_ROWID_LOCATOR_COLUMN],
allowOracleRowID: true,
});
expect(filterHiddenLocatorColumns(['NAME', ORACLE_ROWID_LOCATOR_COLUMN], locator)).toEqual(['NAME']);
});
});

View File

@@ -0,0 +1,152 @@
import type { IndexDefinition } from '../types';
import { resolveUniqueKeyGroupsFromIndexes } from '../components/dataGridCopyInsert';
import { isOracleLikeDialect } from './sqlDialect';
export const ORACLE_ROWID_LOCATOR_COLUMN = '__gonavi_oracle_rowid__';
export type RowLocatorStrategy = 'primary-key' | 'unique-key' | 'oracle-rowid' | 'none';
export type EditRowLocator = {
strategy: RowLocatorStrategy;
columns: string[];
valueColumns: string[];
hiddenColumns?: string[];
writableColumns?: Record<string, string>;
readOnly: boolean;
reason?: string;
};
export type ResolveEditRowLocatorParams = {
dbType: string;
resultColumns: string[];
primaryKeys?: string[];
indexes?: IndexDefinition[];
allowOracleRowID?: boolean;
};
export type ResolveRowLocatorValuesResult =
| { ok: true; values: Record<string, any> }
| { ok: false; error: string };
const normalizeColumnName = (value: string): string => String(value || '').trim();
const hasColumn = (columns: string[], target: string): boolean => {
const normalizedTarget = normalizeColumnName(target).toLowerCase();
return columns.some((column) => normalizeColumnName(column).toLowerCase() === normalizedTarget);
};
const findColumn = (columns: string[], target: string): string => {
const normalizedTarget = normalizeColumnName(target).toLowerCase();
return columns.find((column) => normalizeColumnName(column).toLowerCase() === normalizedTarget) || target;
};
const buildReadOnlyLocator = (reason: string): EditRowLocator => ({
strategy: 'none',
columns: [],
valueColumns: [],
readOnly: true,
reason,
});
export const resolveEditRowLocator = ({
dbType,
resultColumns,
primaryKeys = [],
indexes,
allowOracleRowID = false,
}: ResolveEditRowLocatorParams): EditRowLocator => {
const columns = (resultColumns || []).map(normalizeColumnName).filter(Boolean);
const primaryKeyColumns = (primaryKeys || []).map(normalizeColumnName).filter(Boolean);
if (primaryKeyColumns.length > 0) {
const missing = primaryKeyColumns.filter((column) => !hasColumn(columns, column));
if (missing.length === 0) {
return {
strategy: 'primary-key',
columns: primaryKeyColumns,
valueColumns: primaryKeyColumns.map((column) => findColumn(columns, column)),
readOnly: false,
};
}
return buildReadOnlyLocator(`结果集中缺少主键列 ${missing.join(', ')},无法安全提交修改。`);
}
const uniqueKeyGroups = resolveUniqueKeyGroupsFromIndexes(indexes);
const uniqueKeyGroup = uniqueKeyGroups.find((group) => group.length > 0 && group.every((column) => hasColumn(columns, column)));
if (uniqueKeyGroup) {
return {
strategy: 'unique-key',
columns: uniqueKeyGroup,
valueColumns: uniqueKeyGroup.map((column) => findColumn(columns, column)),
readOnly: false,
};
}
if (allowOracleRowID && isOracleLikeDialect(dbType) && hasColumn(columns, ORACLE_ROWID_LOCATOR_COLUMN)) {
const rowIDColumn = findColumn(columns, ORACLE_ROWID_LOCATOR_COLUMN);
return {
strategy: 'oracle-rowid',
columns: ['ROWID'],
valueColumns: [rowIDColumn],
hiddenColumns: [rowIDColumn],
readOnly: false,
};
}
if (allowOracleRowID && isOracleLikeDialect(dbType)) {
return buildReadOnlyLocator('未检测到主键或可用唯一索引,且结果中缺少 Oracle ROWID无法安全提交修改。');
}
return buildReadOnlyLocator('未检测到主键或可用唯一索引,无法安全提交修改。');
};
export const resolveRowLocatorValues = (
locator: EditRowLocator | undefined,
row: Record<string, any>,
): ResolveRowLocatorValuesResult => {
if (!locator || locator.readOnly || locator.strategy === 'none') {
return { ok: false, error: '当前结果没有可用的安全行定位方式,无法提交修改。' };
}
const values: Record<string, any> = {};
for (let index = 0; index < locator.columns.length; index++) {
const column = locator.columns[index];
const valueColumn = locator.valueColumns[index] || column;
const value = row?.[valueColumn];
if (value === null || value === undefined || value === '') {
return { ok: false, error: `定位列 ${column} 的值为空,无法安全提交修改。` };
}
values[column] = value;
}
return { ok: true, values };
};
export const filterHiddenLocatorColumns = (columns: string[], locator?: EditRowLocator): string[] => {
const hidden = new Set((locator?.hiddenColumns || []).map((column) => normalizeColumnName(column).toLowerCase()));
if (hidden.size === 0) return columns;
return (columns || []).filter((column) => !hidden.has(normalizeColumnName(column).toLowerCase()));
};
export const isHiddenLocatorColumn = (column: string, locator?: EditRowLocator): boolean => {
const normalized = normalizeColumnName(column).toLowerCase();
return (locator?.hiddenColumns || []).some((hidden) => normalizeColumnName(hidden).toLowerCase() === normalized);
};
export const resolveWritableColumnName = (column: string, locator?: EditRowLocator): string | undefined => {
const normalized = normalizeColumnName(column);
if (!normalized || isHiddenLocatorColumn(normalized, locator)) return undefined;
const writableColumns = locator?.writableColumns;
if (!writableColumns) return normalized;
const normalizedTarget = normalized.toLowerCase();
const matchedEntry = Object.entries(writableColumns).find(([resultColumn]) => (
normalizeColumnName(resultColumn).toLowerCase() === normalizedTarget
));
const tableColumnName = normalizeColumnName(matchedEntry?.[1] || '');
return tableColumnName || undefined;
};
export const isWritableResultColumn = (column: string, locator?: EditRowLocator): boolean => (
resolveWritableColumnName(column, locator) !== undefined
);

View File

@@ -426,6 +426,7 @@ export namespace connection {
inserts: any[];
updates: UpdateRow[];
deletes: any[];
locatorStrategy?: string;
static createFrom(source: any = {}) {
return new ChangeSet(source);
@@ -436,6 +437,7 @@ export namespace connection {
this.inserts = source["inserts"];
this.updates = this.convertValues(source["updates"], UpdateRow);
this.deletes = source["deletes"];
this.locatorStrategy = source["locatorStrategy"];
}
convertValues(a: any, classs: any, asMap: boolean = false): any {
@@ -668,6 +670,7 @@ export namespace connection {
timeout?: number;
redisDB?: number;
uri?: string;
clickHouseProtocol?: string;
hosts?: string[];
topology?: string;
mysqlReplicaUser?: string;
@@ -710,6 +713,7 @@ export namespace connection {
this.timeout = source["timeout"];
this.redisDB = source["redisDB"];
this.uri = source["uri"];
this.clickHouseProtocol = source["clickHouseProtocol"];
this.hosts = source["hosts"];
this.topology = source["topology"];
this.mysqlReplicaUser = source["mysqlReplicaUser"];

View File

@@ -467,6 +467,13 @@ func formatConnSummary(config connection.ConnectionConfig) string {
b.WriteString(fmt.Sprintf(" 认证库=%s", strings.TrimSpace(config.AuthSource)))
}
}
if strings.EqualFold(strings.TrimSpace(config.Type), "clickhouse") {
protocol := strings.ToLower(strings.TrimSpace(config.ClickHouseProtocol))
if protocol == "" {
protocol = "auto"
}
b.WriteString(fmt.Sprintf(" ClickHouse协议=%s", protocol))
}
if config.UseSSH {
b.WriteString(fmt.Sprintf(" SSH=%s:%d 用户=%s", config.SSH.Host, config.SSH.Port, config.SSH.User))

View File

@@ -80,3 +80,22 @@ func TestGetCacheKey_KeepDatabaseIsolation(t *testing.T) {
t.Fatalf("expected different cache key for different database targets")
}
}
func TestGetCacheKey_KeepClickHouseProtocolIsolation(t *testing.T) {
base := connection.ConnectionConfig{
Type: "clickhouse",
Host: "clickhouse.local",
Port: 8125,
User: "default",
Database: "default",
ClickHouseProtocol: "native",
}
modified := base
modified.ClickHouseProtocol = "http"
left := getCacheKey(base)
right := getCacheKey(modified)
if left == right {
t.Fatalf("expected different cache key for different ClickHouse protocols")
}
}

View File

@@ -23,9 +23,24 @@ func normalizeTestConnectionConfig(config connection.ConnectionConfig) connectio
return normalized
}
func validateTestConnectionInput(config connection.ConnectionConfig) error {
dbType := strings.ToLower(strings.TrimSpace(config.Type))
if dbType == "" {
return fmt.Errorf("请先选择数据源类型")
}
if dbType == "clickhouse" && strings.TrimSpace(config.Host) == "" && strings.TrimSpace(config.URI) == "" {
return fmt.Errorf("请填写 ClickHouse 主机地址或连接 URI")
}
return nil
}
// Generic DB Methods
func (a *App) DBConnect(config connection.ConnectionConfig) connection.QueryResult {
if err := validateTestConnectionInput(config); err != nil {
logger.Warnf("DBConnect 参数校验失败:%s %s", err.Error(), formatConnSummary(config))
return connection.QueryResult{Success: false, Message: err.Error()}
}
// 连接测试需要强制 ping避免缓存命中但连接已失效时误判成功。
_, err := a.getDatabaseForcePing(config)
if err != nil {
@@ -41,6 +56,10 @@ func (a *App) TestConnection(config connection.ConnectionConfig) connection.Quer
testConfig := normalizeTestConnectionConfig(config)
started := time.Now()
logger.Infof("TestConnection 开始:%s", formatConnSummary(testConfig))
if err := validateTestConnectionInput(testConfig); err != nil {
logger.Warnf("TestConnection 参数校验失败:耗时=%s %s 原因=%s", time.Since(started).Round(time.Millisecond), formatConnSummary(testConfig), err.Error())
return connection.QueryResult{Success: false, Message: err.Error()}
}
_, err := a.getDatabaseForcePing(testConfig)
if err != nil {
logger.Error(err, "TestConnection 连接测试失败:耗时=%s %s", time.Since(started).Round(time.Millisecond), formatConnSummary(testConfig))

View File

@@ -31,6 +31,26 @@ func TestNormalizeTestConnectionConfig_ZeroTimeout(t *testing.T) {
}
}
func TestValidateTestConnectionInput_ClickHouseRequiresTarget(t *testing.T) {
err := validateTestConnectionInput(connection.ConnectionConfig{Type: "clickhouse"})
if err == nil {
t.Fatal("expected ClickHouse target validation error")
}
if !strings.Contains(err.Error(), "ClickHouse 主机地址") {
t.Fatalf("unexpected validation error: %v", err)
}
}
func TestValidateTestConnectionInput_ClickHouseAllowsURI(t *testing.T) {
err := validateTestConnectionInput(connection.ConnectionConfig{
Type: "clickhouse",
URI: "http://clickhouse.example.com:8125/default",
})
if err != nil {
t.Fatalf("expected ClickHouse URI to satisfy target validation, got %v", err)
}
}
func TestFormatConnSummary_BasicMySQL(t *testing.T) {
cfg := connection.ConnectionConfig{
Type: "mysql",

View File

@@ -127,6 +127,7 @@ type driverDefinition struct {
type installedDriverPackage struct {
DriverType string `json:"driverType"`
Version string `json:"version,omitempty"`
AgentRevision string `json:"agentRevision,omitempty"`
FilePath string `json:"filePath"`
FileName string `json:"fileName"`
ExecutablePath string `json:"executablePath,omitempty"`
@@ -136,23 +137,28 @@ type installedDriverPackage struct {
}
type driverStatusItem struct {
Type string `json:"type"`
Name string `json:"name"`
Engine string `json:"engine,omitempty"`
BuiltIn bool `json:"builtIn"`
PinnedVersion string `json:"pinnedVersion,omitempty"`
InstalledVersion string `json:"installedVersion,omitempty"`
PackageSizeText string `json:"packageSizeText,omitempty"`
RuntimeAvailable bool `json:"runtimeAvailable"`
PackageInstalled bool `json:"packageInstalled"`
Connectable bool `json:"connectable"`
DefaultDownloadURL string `json:"defaultDownloadUrl,omitempty"`
InstallDir string `json:"installDir,omitempty"`
PackagePath string `json:"packagePath,omitempty"`
PackageFileName string `json:"packageFileName,omitempty"`
ExecutablePath string `json:"executablePath,omitempty"`
DownloadedAt string `json:"downloadedAt,omitempty"`
Message string `json:"message,omitempty"`
Type string `json:"type"`
Name string `json:"name"`
Engine string `json:"engine,omitempty"`
BuiltIn bool `json:"builtIn"`
PinnedVersion string `json:"pinnedVersion,omitempty"`
InstalledVersion string `json:"installedVersion,omitempty"`
PackageSizeText string `json:"packageSizeText,omitempty"`
RuntimeAvailable bool `json:"runtimeAvailable"`
PackageInstalled bool `json:"packageInstalled"`
Connectable bool `json:"connectable"`
DefaultDownloadURL string `json:"defaultDownloadUrl,omitempty"`
InstallDir string `json:"installDir,omitempty"`
PackagePath string `json:"packagePath,omitempty"`
PackageFileName string `json:"packageFileName,omitempty"`
ExecutablePath string `json:"executablePath,omitempty"`
DownloadedAt string `json:"downloadedAt,omitempty"`
AgentRevision string `json:"agentRevision,omitempty"`
ExpectedRevision string `json:"expectedRevision,omitempty"`
NeedsUpdate bool `json:"needsUpdate,omitempty"`
UpdateReason string `json:"updateReason,omitempty"`
AffectedConnections int `json:"affectedConnections,omitempty"`
Message string `json:"message,omitempty"`
}
const driverDownloadProgressEvent = "driver:download-progress"
@@ -758,29 +764,36 @@ func (a *App) GetDriverStatusList(downloadDir string, manifestURL string) connec
definitions := allDriverDefinitionsWithPackages(effectivePackages)
triggerDriverVersionMetadataWarmup(definitions)
packageSizeBytesMap := preloadOptionalDriverPackageSizes(definitions)
usageCounts := a.savedConnectionDriverUsageCounts()
items := make([]driverStatusItem, 0, len(definitions))
for _, definition := range definitions {
engine := effectiveDriverEngine(definition)
runtimeAvailable, runtimeReason := db.DriverRuntimeSupportStatus(definition.Type)
pkg, packageMetaExists := readInstalledDriverPackage(resolvedDir, definition.Type)
needsUpdate, updateReason, expectedRevision := optionalDriverAgentRevisionStatus(definition.Type, pkg, packageMetaExists)
packageInstalled := definition.BuiltIn || packageMetaExists
if runtimeAvailable && db.IsOptionalGoDriver(definition.Type) {
packageInstalled = true
}
item := driverStatusItem{
Type: definition.Type,
Name: definition.Name,
Engine: engine,
BuiltIn: definition.BuiltIn,
PinnedVersion: definition.PinnedVersion,
InstalledVersion: strings.TrimSpace(pkg.Version),
PackageSizeText: resolveDriverPackageSizeText(definition, pkg, packageMetaExists, packageSizeBytesMap),
RuntimeAvailable: runtimeAvailable,
PackageInstalled: packageInstalled,
Connectable: runtimeAvailable,
DefaultDownloadURL: definition.DefaultDownloadURL,
InstallDir: driverInstallDir(resolvedDir, definition.Type),
Type: definition.Type,
Name: definition.Name,
Engine: engine,
BuiltIn: definition.BuiltIn,
PinnedVersion: definition.PinnedVersion,
InstalledVersion: strings.TrimSpace(pkg.Version),
PackageSizeText: resolveDriverPackageSizeText(definition, pkg, packageMetaExists, packageSizeBytesMap),
RuntimeAvailable: runtimeAvailable,
PackageInstalled: packageInstalled,
Connectable: runtimeAvailable,
DefaultDownloadURL: definition.DefaultDownloadURL,
InstallDir: driverInstallDir(resolvedDir, definition.Type),
AgentRevision: strings.TrimSpace(pkg.AgentRevision),
ExpectedRevision: expectedRevision,
NeedsUpdate: needsUpdate,
UpdateReason: updateReason,
AffectedConnections: usageCounts[normalizeDriverType(definition.Type)],
}
if packageMetaExists {
item.PackagePath = pkg.FilePath
@@ -792,6 +805,12 @@ func (a *App) GetDriverStatusList(downloadDir string, manifestURL string) connec
switch {
case definition.BuiltIn:
item.Message = "内置驱动,可直接连接"
case needsUpdate:
if item.AffectedConnections > 0 {
item.Message = fmt.Sprintf("%s检测到 %d 个已保存连接正在使用该驱动,请在工具-驱动管理中重装", updateReason, item.AffectedConnections)
} else {
item.Message = updateReason + ",请在工具-驱动管理中重装"
}
case runtimeAvailable:
item.Message = "纯 Go 驱动已启用,可直接连接"
case packageInstalled && strings.TrimSpace(runtimeReason) != "":
@@ -2702,6 +2721,47 @@ func readInstalledDriverPackage(downloadDir string, driverType string) (installe
return meta, true
}
func optionalDriverAgentRevisionStatus(driverType string, pkg installedDriverPackage, packageMetaExists bool) (bool, string, string) {
expected := db.OptionalDriverAgentRevision(driverType)
if strings.TrimSpace(expected) == "" || !packageMetaExists || !db.IsOptionalGoDriver(driverType) {
return false, "", expected
}
actual := strings.TrimSpace(pkg.AgentRevision)
if actual == expected {
return false, "", expected
}
displayName := resolveDriverDisplayName(driverDefinition{Type: driverType})
updateReason := fmt.Sprintf("当前 GoNavi 版本要求更新后的 %s driver-agentrevision: %s", displayName, expected)
impact := "driver-agent 是独立二进制,不会随主程序自动更新;如果不重装,会继续使用旧 agent 逻辑,驱动侧已修复或优化的行为不会生效,可能继续出现旧版本问题。强烈建议重装对应驱动代理"
if actual == "" {
return true, fmt.Sprintf("原因:%s。影响%s", updateReason, impact), expected
}
return true, fmt.Sprintf("原因:%s。影响%s已安装标记%s当前需要%s", updateReason, impact, actual, expected), expected
}
func (a *App) savedConnectionDriverUsageCounts() map[string]int {
counts := map[string]int{}
if a == nil || strings.TrimSpace(a.configDir) == "" {
return counts
}
items, err := a.savedConnectionRepository().List()
if err != nil {
logger.Warnf("统计驱动连接使用数失败:%v", err)
return counts
}
for _, item := range items {
driverType := normalizeDriverType(item.Config.Type)
if driverType == "custom" {
driverType = normalizeDriverType(item.Config.Driver)
}
if driverType == "" || !db.IsOptionalGoDriver(driverType) {
continue
}
counts[driverType]++
}
return counts
}
func writeInstalledDriverPackage(downloadDir string, driverType string, meta installedDriverPackage) error {
driverDir := driverInstallDir(downloadDir, driverType)
if err := os.MkdirAll(driverDir, 0o755); err != nil {
@@ -2765,9 +2825,11 @@ func installOptionalDriverAgentPackage(a *App, definition driverDefinition, sele
if strings.TrimSpace(downloadSource) == "" {
downloadSource = strings.TrimSpace(downloadURL)
}
agentRevision := probeInstalledOptionalDriverAgentRevision(driverType, runtimePath)
return installedDriverPackage{
DriverType: driverType,
Version: strings.TrimSpace(selectedVersion),
AgentRevision: agentRevision,
FilePath: installPath,
FileName: filepath.Base(installPath),
ExecutablePath: runtimePath,
@@ -2837,9 +2899,11 @@ func installOptionalDriverAgentFromLocalPath(definition driverDefinition, filePa
if hashErr != nil {
return installedDriverPackage{}, fmt.Errorf("计算 %s 驱动代理摘要失败:%w", displayName, hashErr)
}
agentRevision := probeInstalledOptionalDriverAgentRevision(driverType, executablePath)
return installedDriverPackage{
DriverType: driverType,
Version: strings.TrimSpace(selectedVersion),
AgentRevision: agentRevision,
FilePath: sourcePath,
FileName: sourceName,
ExecutablePath: executablePath,
@@ -2849,6 +2913,19 @@ func installOptionalDriverAgentFromLocalPath(definition driverDefinition, filePa
}, nil
}
func probeInstalledOptionalDriverAgentRevision(driverType string, executablePath string) string {
expectedRevision := db.OptionalDriverAgentRevision(driverType)
if strings.TrimSpace(expectedRevision) == "" {
return ""
}
metadata, err := db.ProbeOptionalDriverAgentMetadata(driverType, executablePath)
if err != nil {
logger.Warnf("%s 驱动代理未返回版本元数据:%v", resolveDriverDisplayName(driverDefinition{Type: driverType}), err)
return ""
}
return strings.TrimSpace(metadata.AgentRevision)
}
type localDriverCandidate struct {
absPath string
relativePath string

View File

@@ -0,0 +1,72 @@
package app
import (
"strings"
"testing"
"GoNavi-Wails/internal/connection"
)
func TestOptionalDriverAgentRevisionStatusDetectsStaleClickHouseAgent(t *testing.T) {
needsUpdate, reason, expected := optionalDriverAgentRevisionStatus("clickhouse", installedDriverPackage{}, true)
if !needsUpdate {
t.Fatal("expected missing ClickHouse agent revision to require update")
}
if expected == "" {
t.Fatal("expected ClickHouse to define an agent revision")
}
if reason == "" {
t.Fatal("expected update reason")
}
if !strings.Contains(reason, "原因:") || !strings.Contains(reason, "影响:") {
t.Fatalf("expected reason to explain cause and impact, got %q", reason)
}
if !strings.Contains(reason, "强烈建议重装") {
t.Fatalf("expected reason to strongly recommend reinstall, got %q", reason)
}
current := installedDriverPackage{AgentRevision: expected}
needsUpdate, reason, _ = optionalDriverAgentRevisionStatus("clickhouse", current, true)
if needsUpdate {
t.Fatalf("expected current ClickHouse agent revision to be accepted, reason=%q", reason)
}
}
func TestSavedConnectionDriverUsageCountsIncludesOptionalAndCustomDrivers(t *testing.T) {
app := &App{configDir: t.TempDir()}
repo := app.savedConnectionRepository()
if err := repo.saveAll([]connection.SavedConnectionView{
{
ID: "conn-clickhouse",
Name: "ClickHouse",
Config: connection.ConnectionConfig{
Type: "clickhouse",
},
},
{
ID: "conn-custom-clickhouse",
Name: "Custom ClickHouse",
Config: connection.ConnectionConfig{
Type: "custom",
Driver: "clickhouse",
},
},
{
ID: "conn-mysql",
Name: "MySQL",
Config: connection.ConnectionConfig{
Type: "mysql",
},
},
}); err != nil {
t.Fatalf("save connections failed: %v", err)
}
counts := app.savedConnectionDriverUsageCounts()
if got := counts["clickhouse"]; got != 2 {
t.Fatalf("expected two ClickHouse usages, got %d", got)
}
if got := counts["mysql"]; got != 0 {
t.Fatalf("expected built-in MySQL to be ignored, got %d", got)
}
}

View File

@@ -102,6 +102,7 @@ type ConnectionConfig struct {
Timeout int `json:"timeout,omitempty"` // Connection timeout in seconds (default: 30)
RedisDB int `json:"redisDB,omitempty"` // Redis database index (0-15)
URI string `json:"uri,omitempty"` // Connection URI for copy/paste
ClickHouseProtocol string `json:"clickHouseProtocol,omitempty"` // auto | http | native
Hosts []string `json:"hosts,omitempty"` // Multi-host addresses: host:port
Topology string `json:"topology,omitempty"` // single | replica | cluster
MySQLReplicaUser string `json:"mysqlReplicaUser,omitempty"` // MySQL replica auth user
@@ -184,9 +185,10 @@ type UpdateRow struct {
// ChangeSet 表示一组批量变更,包含新增、修改和删除操作。
type ChangeSet struct {
Inserts []map[string]interface{} `json:"inserts"`
Updates []UpdateRow `json:"updates"`
Deletes []map[string]interface{} `json:"deletes"`
Inserts []map[string]interface{} `json:"inserts"`
Updates []UpdateRow `json:"updates"`
Deletes []map[string]interface{} `json:"deletes"`
LocatorStrategy string `json:"locatorStrategy,omitempty"`
}
// MongoMemberInfo 描述 MongoDB 副本集成员的信息。

View File

@@ -12,6 +12,8 @@ import (
"strconv"
"strings"
"time"
"unicode"
"unicode/utf8"
"GoNavi-Wails/internal/connection"
"GoNavi-Wails/internal/logger"
@@ -26,7 +28,11 @@ const (
defaultClickHouseUser = "default"
defaultClickHouseDatabase = "default"
minClickHouseReadTimeout = 5 * time.Minute
clickHouseHTTPPortHint = "8123/8132/8443"
clickHouseHTTPPortHint = "8123/8125/8132/8443"
clickHouseProtocolAuto = "auto"
clickHouseProtocolHTTP = "http"
clickHouseProtocolNative = "native"
)
type ClickHouseDB struct {
@@ -38,6 +44,7 @@ type ClickHouseDB struct {
func normalizeClickHouseConfig(config connection.ConnectionConfig) connection.ConnectionConfig {
normalized := applyClickHouseURI(config)
normalized = applyClickHouseHostURI(normalized)
if strings.TrimSpace(normalized.Host) == "" {
normalized.Host = "localhost"
}
@@ -58,15 +65,26 @@ func applyClickHouseURI(config connection.ConnectionConfig) connection.Connectio
if uriText == "" {
return config
}
lowerURI := strings.ToLower(uriText)
if !strings.HasPrefix(lowerURI, "clickhouse://") {
return applyClickHouseEndpointURI(config, uriText, false)
}
func applyClickHouseHostURI(config connection.ConnectionConfig) connection.ConnectionConfig {
hostText := strings.TrimSpace(config.Host)
if hostText == "" {
return config
}
return applyClickHouseEndpointURI(config, hostText, true)
}
func applyClickHouseEndpointURI(config connection.ConnectionConfig, uriText string, fromHostField bool) connection.ConnectionConfig {
parsed, err := url.Parse(uriText)
if err != nil {
return config
}
scheme := strings.ToLower(strings.TrimSpace(parsed.Scheme))
if !isClickHouseSupportedEndpointScheme(scheme) || strings.TrimSpace(parsed.Host) == "" {
return config
}
if parsed.User != nil {
if strings.TrimSpace(config.User) == "" {
@@ -85,12 +103,28 @@ func applyClickHouseURI(config connection.ConnectionConfig) connection.Connectio
config.Database = dbName
}
}
if queryProtocol := normalizeClickHouseProtocol(parsed.Query().Get("protocol")); queryProtocol != clickHouseProtocolAuto {
config.ClickHouseProtocol = queryProtocol
}
endpointProtocol := normalizeClickHouseProtocol(config.ClickHouseProtocol)
if isClickHouseHTTPURLScheme(scheme) && endpointProtocol != clickHouseProtocolNative {
config.ClickHouseProtocol = clickHouseProtocolHTTP
if scheme == "https" {
config.UseSSL = true
if normalizeSSLModeValue(config.SSLMode) == sslModeDisable || strings.TrimSpace(config.SSLMode) == "" {
config.SSLMode = sslModeRequired
}
}
}
defaultPort := config.Port
if defaultPort <= 0 {
defaultPort = defaultClickHousePort
}
if strings.TrimSpace(config.Host) == "" {
if isClickHouseHTTPURLScheme(scheme) && endpointProtocol != clickHouseProtocolNative && defaultPort == defaultClickHousePort {
defaultPort = defaultClickHousePortForScheme(scheme)
}
if fromHostField || strings.TrimSpace(config.Host) == "" {
host, port, ok := parseHostPortWithDefault(parsed.Host, defaultPort)
if ok {
config.Host = host
@@ -103,6 +137,30 @@ func applyClickHouseURI(config connection.ConnectionConfig) connection.Connectio
return config
}
func isClickHouseSupportedEndpointScheme(scheme string) bool {
switch scheme {
case "clickhouse", "http", "https":
return true
default:
return false
}
}
func isClickHouseHTTPURLScheme(scheme string) bool {
return scheme == "http" || scheme == "https"
}
func defaultClickHousePortForScheme(scheme string) int {
switch scheme {
case "http":
return 8123
case "https":
return 8443
default:
return defaultClickHousePort
}
}
func (c *ClickHouseDB) buildClickHouseOptions(config connection.ConnectionConfig) *clickhouse.Options {
connectTimeout := getConnectTimeout(config)
readTimeout := connectTimeout
@@ -130,6 +188,15 @@ func (c *ClickHouseDB) buildClickHouseOptions(config connection.ConnectionConfig
}
func detectClickHouseProtocol(config connection.ConnectionConfig) clickhouse.Protocol {
switch normalizeClickHouseProtocol(config.ClickHouseProtocol) {
case clickHouseProtocolHTTP:
return clickhouse.HTTP
case clickHouseProtocolNative:
return clickhouse.Native
}
if hasClickHouseHTTPScheme(config.URI) || hasClickHouseHTTPScheme(config.Host) {
return clickhouse.HTTP
}
uriText := strings.ToLower(strings.TrimSpace(config.URI))
if strings.HasPrefix(uriText, "http://") || strings.HasPrefix(uriText, "https://") {
return clickhouse.HTTP
@@ -140,9 +207,25 @@ func detectClickHouseProtocol(config connection.ConnectionConfig) clickhouse.Pro
return clickhouse.Native
}
func normalizeClickHouseProtocol(raw string) string {
switch strings.ToLower(strings.TrimSpace(raw)) {
case clickHouseProtocolHTTP, "https":
return clickHouseProtocolHTTP
case clickHouseProtocolNative, "tcp":
return clickHouseProtocolNative
default:
return clickHouseProtocolAuto
}
}
func hasClickHouseHTTPScheme(raw string) bool {
text := strings.ToLower(strings.TrimSpace(raw))
return strings.HasPrefix(text, "http://") || strings.HasPrefix(text, "https://")
}
func isClickHouseHTTPPort(port int) bool {
switch port {
case 8123, 8132, 8443:
case 8123, 8125, 8132, 8443:
return true
default:
return false
@@ -159,18 +242,88 @@ func isClickHouseProtocolMismatch(err error) bool {
}
return strings.Contains(text, "unexpected packet [72]") ||
(strings.Contains(text, "unexpected packet") && strings.Contains(text, "handshake")) ||
(strings.Contains(text, "cannot parse input") && strings.Contains(text, "expected '('")) ||
strings.Contains(text, "http response to https client") ||
strings.Contains(text, "malformed http response")
}
func clickHouseProtocolName(protocol clickhouse.Protocol) string {
if protocol == clickhouse.HTTP {
return "HTTP"
}
return "Native"
}
func sanitizeClickHouseErrorMessage(err error) string {
if err == nil {
return ""
}
text := strings.ToValidUTF8(err.Error(), "<22>")
var b strings.Builder
lastSpace := false
for _, r := range text {
if r == utf8.RuneError || r == '<27>' {
if !lastSpace {
b.WriteByte(' ')
lastSpace = true
}
continue
}
if unicode.IsControl(r) {
if !lastSpace {
b.WriteByte(' ')
lastSpace = true
}
continue
}
b.WriteRune(r)
lastSpace = unicode.IsSpace(r)
}
sanitized := strings.Join(strings.Fields(b.String()), " ")
if len(sanitized) > 320 {
return sanitized[:320] + "..."
}
return sanitized
}
func clickHouseAttemptFailureMessage(protocol clickhouse.Protocol, err error) string {
if isClickHouseProtocolMismatch(err) {
if protocol == clickhouse.Native {
return "服务端响应不像 Native 握手,当前端口更像 HTTP/HTTPS 端口;请选择 HTTP 协议,或确认 ClickHouse Native 端口"
}
return "服务端响应不像 HTTP 响应,当前端口更像 Native 端口;请选择 Native 协议,或确认 ClickHouse HTTP 端口"
}
message := sanitizeClickHouseErrorMessage(err)
if message == "" {
return "未知错误"
}
return message
}
func clickHouseConnectFailureSummary(config connection.ConnectionConfig, failures []string) string {
protocolMode := normalizeClickHouseProtocol(config.ClickHouseProtocol)
detail := strings.Join(failures, "")
if strings.TrimSpace(detail) == "" {
detail = "未获取到驱动返回的错误详情"
}
if protocolMode != clickHouseProtocolAuto {
return fmt.Sprintf("ClickHouse 连接验证失败:已按用户选择使用 %s 协议连接 %s:%d。%s",
strings.ToUpper(protocolMode), config.Host, config.Port, detail)
}
return fmt.Sprintf("ClickHouse 连接验证失败自动协议探测未成功Native 常见端口 9000/9440HTTP 常见端口 %s非标端口建议在连接协议中手动指定。%s",
clickHouseHTTPPortHint, detail)
}
func withClickHouseProtocol(config connection.ConnectionConfig, protocol clickhouse.Protocol) connection.ConnectionConfig {
next := config
switch protocol {
case clickhouse.HTTP:
next.ClickHouseProtocol = clickHouseProtocolHTTP
if next.Port == 0 {
next.Port = 8123
}
default:
next.ClickHouseProtocol = clickHouseProtocolNative
if next.Port == 0 {
next.Port = defaultClickHousePort
}
@@ -178,6 +331,17 @@ func withClickHouseProtocol(config connection.ConnectionConfig, protocol clickho
return next
}
func clickHouseProtocolsForAttempt(config connection.ConnectionConfig) []clickhouse.Protocol {
primaryProtocol := detectClickHouseProtocol(config)
if normalizeClickHouseProtocol(config.ClickHouseProtocol) != clickHouseProtocolAuto {
return []clickhouse.Protocol{primaryProtocol}
}
if primaryProtocol == clickhouse.Native {
return []clickhouse.Protocol{primaryProtocol, clickhouse.HTTP}
}
return []clickhouse.Protocol{primaryProtocol, clickhouse.Native}
}
func (c *ClickHouseDB) Connect(config connection.ConnectionConfig) error {
if supported, reason := DriverRuntimeSupportStatus("clickhouse"); !supported {
if strings.TrimSpace(reason) == "" {
@@ -198,8 +362,14 @@ func (c *ClickHouseDB) Connect(config connection.ConnectionConfig) error {
runConfig := normalizeClickHouseConfig(config)
c.pingTimeout = getConnectTimeout(runConfig)
c.database = runConfig.Database
logger.Infof("ClickHouse 连接准备:地址=%s:%d 数据库=%s 用户=%s 协议选择=%s SSL=%t SSH=%t 超时=%s",
runConfig.Host, runConfig.Port, runConfig.Database, runConfig.User,
normalizeClickHouseProtocol(runConfig.ClickHouseProtocol), runConfig.UseSSL, runConfig.UseSSH, c.pingTimeout)
if runConfig.UseSSH {
if normalizeClickHouseProtocol(runConfig.ClickHouseProtocol) == clickHouseProtocolAuto && detectClickHouseProtocol(runConfig) == clickhouse.HTTP {
runConfig.ClickHouseProtocol = clickHouseProtocolHTTP
}
logger.Infof("ClickHouse 使用 SSH 连接:地址=%s:%d 用户=%s", runConfig.Host, runConfig.Port, runConfig.User)
forwarder, err := ssh.GetOrCreateLocalForwarder(runConfig.SSH, runConfig.Host, runConfig.Port)
if err != nil {
@@ -229,19 +399,17 @@ func (c *ClickHouseDB) Connect(config connection.ConnectionConfig) error {
var failures []string
for idx, attempt := range attempts {
primaryProtocol := detectClickHouseProtocol(attempt)
protocols := []clickhouse.Protocol{primaryProtocol}
if primaryProtocol == clickhouse.Native {
protocols = append(protocols, clickhouse.HTTP)
} else {
protocols = append(protocols, clickhouse.Native)
}
protocols := clickHouseProtocolsForAttempt(attempt)
for pIdx, protocol := range protocols {
protocolConfig := withClickHouseProtocol(attempt, protocol)
logger.Infof("ClickHouse 连接尝试:第%d组/%d 协议=%s 地址=%s:%d SSL=%t",
idx+1, len(attempts), clickHouseProtocolName(protocol), protocolConfig.Host, protocolConfig.Port, protocolConfig.UseSSL)
c.conn = clickhouse.OpenDB(c.buildClickHouseOptions(protocolConfig))
if err := c.Ping(); err != nil {
failures = append(failures, fmt.Sprintf("第%d次连接验证失败(protocol=%s): %v", idx+1, protocol.String(), err))
failureMessage := clickHouseAttemptFailureMessage(protocol, err)
failures = append(failures, fmt.Sprintf("第%d次连接验证失败(protocol=%s): %s", idx+1, protocol.String(), failureMessage))
logger.Warnf("ClickHouse 连接尝试失败:第%d组/%d 协议=%s 地址=%s:%d SSL=%t 原因=%s",
idx+1, len(attempts), clickHouseProtocolName(protocol), protocolConfig.Host, protocolConfig.Port, protocolConfig.UseSSL, failureMessage)
if c.conn != nil {
_ = c.conn.Close()
c.conn = nil
@@ -258,12 +426,13 @@ func (c *ClickHouseDB) Connect(config connection.ConnectionConfig) error {
if pIdx > 0 {
logger.Warnf("ClickHouse 已自动切换连接协议为 %s常见于 %s HTTP 端口)", protocol.String(), clickHouseHTTPPortHint)
}
logger.Infof("ClickHouse 连接验证成功:协议=%s 地址=%s:%d 数据库=%s", clickHouseProtocolName(protocol), protocolConfig.Host, protocolConfig.Port, protocolConfig.Database)
return nil
}
}
_ = c.Close()
return fmt.Errorf("连接建立后验证失败(可检查 ClickHouse 端口与协议是否匹配Native=9000/9440HTTP=%s%s", clickHouseHTTPPortHint, strings.Join(failures, ""))
return fmt.Errorf("%s", clickHouseConnectFailureSummary(runConfig, failures))
}
func (c *ClickHouseDB) Close() error {

View File

@@ -160,6 +160,13 @@ func TestDetectClickHouseProtocolTreatsHTTPPortsAsHTTP(t *testing.T) {
},
expected: clickhouse.HTTP,
},
{
name: "custom http port 8125",
config: connection.ConnectionConfig{
Port: 8125,
},
expected: clickhouse.HTTP,
},
{
name: "https port",
config: connection.ConnectionConfig{
@@ -181,6 +188,30 @@ func TestDetectClickHouseProtocolTreatsHTTPPortsAsHTTP(t *testing.T) {
},
expected: clickhouse.Native,
},
{
name: "host http scheme",
config: connection.ConnectionConfig{
Host: "http://clickhouse.example.com",
Port: 8125,
},
expected: clickhouse.HTTP,
},
{
name: "manual http overrides native port",
config: connection.ConnectionConfig{
ClickHouseProtocol: "http",
Port: 9000,
},
expected: clickhouse.HTTP,
},
{
name: "manual native overrides http port",
config: connection.ConnectionConfig{
ClickHouseProtocol: "native",
Port: 8123,
},
expected: clickhouse.Native,
},
}
for _, tt := range tests {
@@ -192,6 +223,172 @@ func TestDetectClickHouseProtocolTreatsHTTPPortsAsHTTP(t *testing.T) {
}
}
func TestNormalizeClickHouseConfigParsesHTTPHostScheme(t *testing.T) {
config := normalizeClickHouseConfig(connection.ConnectionConfig{
Type: "clickhouse",
Host: "https://clickhouse.example.com:8125/analytics",
User: "alice",
Password: "secret",
})
if config.Host != "clickhouse.example.com" {
t.Fatalf("expected host without scheme, got %q", config.Host)
}
if config.Port != 8125 {
t.Fatalf("expected port 8125, got %d", config.Port)
}
if config.Database != "analytics" {
t.Fatalf("expected database analytics, got %q", config.Database)
}
if config.ClickHouseProtocol != "http" {
t.Fatalf("expected http protocol hint, got %q", config.ClickHouseProtocol)
}
if !config.UseSSL || config.SSLMode != sslModeRequired {
t.Fatalf("expected https host to enable required SSL, got useSSL=%v sslMode=%q", config.UseSSL, config.SSLMode)
}
}
func TestNormalizeClickHouseConfigKeepsManualNativeWhenHostHasHTTPScheme(t *testing.T) {
config := normalizeClickHouseConfig(connection.ConnectionConfig{
Type: "clickhouse",
Host: "http://clickhouse.example.com:9001/analytics",
ClickHouseProtocol: "native",
User: "alice",
Password: "secret",
})
if config.Host != "clickhouse.example.com" {
t.Fatalf("expected host without scheme, got %q", config.Host)
}
if config.Port != 9001 {
t.Fatalf("expected user-provided native port 9001, got %d", config.Port)
}
if config.Database != "analytics" {
t.Fatalf("expected database analytics, got %q", config.Database)
}
if config.ClickHouseProtocol != "native" {
t.Fatalf("expected manual native protocol to be preserved, got %q", config.ClickHouseProtocol)
}
if config.UseSSL {
t.Fatalf("manual native protocol should not be forced to HTTP TLS by http scheme")
}
}
func TestNormalizeClickHouseConfigUsesNativeDefaultPortForManualNativeHTTPScheme(t *testing.T) {
config := normalizeClickHouseConfig(connection.ConnectionConfig{
Type: "clickhouse",
Host: "https://clickhouse.example.com/analytics",
ClickHouseProtocol: "native",
})
if config.Host != "clickhouse.example.com" {
t.Fatalf("expected host without scheme, got %q", config.Host)
}
if config.Port != defaultClickHousePort {
t.Fatalf("expected native default port %d, got %d", defaultClickHousePort, config.Port)
}
if config.ClickHouseProtocol != "native" {
t.Fatalf("expected manual native protocol to be preserved, got %q", config.ClickHouseProtocol)
}
}
func TestClickHouseProtocolMismatchIncludesHTTPParseBinaryResponse(t *testing.T) {
err := errors.New("code: 27, message: Cannot parse input: expected '(' before: '\x02\x00\x01\x00'")
if !isClickHouseProtocolMismatch(err) {
t.Fatalf("expected binary parse response to be treated as protocol mismatch")
}
message := clickHouseAttemptFailureMessage(clickhouse.Native, err)
if !strings.Contains(message, "不像 Native") || strings.Contains(message, "\x00") {
t.Fatalf("expected user-facing native mismatch message without binary bytes, got %q", message)
}
}
func TestWithClickHouseProtocolForcesProtocolSelection(t *testing.T) {
httpConfig := withClickHouseProtocol(connection.ConnectionConfig{
Type: "clickhouse",
Host: "clickhouse.example.com",
Port: 8125,
}, clickhouse.HTTP)
if protocol := detectClickHouseProtocol(httpConfig); protocol != clickhouse.HTTP {
t.Fatalf("expected forced HTTP protocol, got %s", protocol.String())
}
nativeConfig := withClickHouseProtocol(connection.ConnectionConfig{
Type: "clickhouse",
Host: "http://clickhouse.example.com",
Port: 8125,
}, clickhouse.Native)
if protocol := detectClickHouseProtocol(nativeConfig); protocol != clickhouse.Native {
t.Fatalf("expected forced Native protocol, got %s", protocol.String())
}
}
func TestClickHouseProtocolsForAttemptOnlyFallsBackInAutoMode(t *testing.T) {
tests := []struct {
name string
config connection.ConnectionConfig
expected []clickhouse.Protocol
}{
{
name: "auto native falls back to http",
config: connection.ConnectionConfig{
Type: "clickhouse",
Port: 9000,
},
expected: []clickhouse.Protocol{clickhouse.Native, clickhouse.HTTP},
},
{
name: "auto http falls back to native",
config: connection.ConnectionConfig{
Type: "clickhouse",
Port: 8125,
},
expected: []clickhouse.Protocol{clickhouse.HTTP, clickhouse.Native},
},
{
name: "manual http does not try native",
config: connection.ConnectionConfig{
Type: "clickhouse",
Port: 9000,
ClickHouseProtocol: "http",
},
expected: []clickhouse.Protocol{clickhouse.HTTP},
},
{
name: "manual native does not try http",
config: connection.ConnectionConfig{
Type: "clickhouse",
Port: 8125,
ClickHouseProtocol: "native",
},
expected: []clickhouse.Protocol{clickhouse.Native},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := clickHouseProtocolsForAttempt(tt.config)
if len(got) != len(tt.expected) {
t.Fatalf("expected protocols %v, got %v", protocolNames(tt.expected), protocolNames(got))
}
for idx := range got {
if got[idx] != tt.expected[idx] {
t.Fatalf("expected protocols %v, got %v", protocolNames(tt.expected), protocolNames(got))
}
}
})
}
}
func protocolNames(protocols []clickhouse.Protocol) []string {
names := make([]string, 0, len(protocols))
for _, protocol := range protocols {
names = append(names, protocol.String())
}
return names
}
type fakeClickHouseDriver struct{}
func (fakeClickHouseDriver) Open(name string) (driver.Conn, error) {

View File

@@ -3,6 +3,7 @@ package db
import (
"GoNavi-Wails/internal/connection"
"context"
"database/sql"
"fmt"
"strings"
)
@@ -64,6 +65,20 @@ type BatchApplier interface {
ApplyChanges(tableName string, changes connection.ChangeSet) error
}
func requireSingleRowAffected(result sql.Result, action string) error {
affected, err := result.RowsAffected()
if err != nil {
return fmt.Errorf("%s未生效无法确认影响行数%v", action, err)
}
if affected == 0 {
return fmt.Errorf("%s未生效未匹配到任何行", action)
}
if affected != 1 {
return fmt.Errorf("%s未生效影响了 %d 行,期望只影响 1 行", action, affected)
}
return nil
}
type databaseFactory func() Database
var databaseFactories = map[string]databaseFactory{

View File

@@ -0,0 +1,21 @@
// Code generated by tools/generate-driver-agent-revisions.sh; DO NOT EDIT.
package db
func init() {
optionalDriverAgentRevisions = map[string]string{
"mariadb": "src-d6c5c6717338834c",
"diros": "src-ed4f0f64ed28d3fa",
"sphinx": "src-f52324f0a812d7c8",
"sqlserver": "src-ec165f18de9cd8b3",
"sqlite": "src-9dea6c76bc931114",
"duckdb": "src-14027ac1de3c50c7",
"dameng": "src-1a08880ff5bbcf31",
"kingbase": "src-28eed0e4d942b724",
"highgo": "src-76146bf97f07f25c",
"vastbase": "src-555b60c4863542b6",
"mongodb": "src-2540a7350c4243aa",
"tdengine": "src-ce3e4a9c46f6b92d",
"clickhouse": "src-78e5ada4da56704d",
}
}

View File

@@ -36,6 +36,11 @@ var optionalGoDrivers = map[string]struct{}{
"clickhouse": {},
}
// optionalDriverAgentRevisions 记录 GoNavi 对各可选 driver-agent 包装逻辑的兼容版本。
// 该 map 由 tools/generate-driver-agent-revisions.sh 按 driver-agent 源码依赖自动生成,
// 避免人工判断需要 bump 哪个驱动 revision。
var optionalDriverAgentRevisions = map[string]string{}
var (
externalDriverDirMu sync.RWMutex
externalDriverDir string
@@ -105,6 +110,10 @@ func IsOptionalGoDriverBuildIncluded(driverType string) bool {
return optionalGoDriverBuildIncluded(normalizeRuntimeDriverType(driverType))
}
func OptionalDriverAgentRevision(driverType string) string {
return strings.TrimSpace(optionalDriverAgentRevisions[normalizeRuntimeDriverType(driverType)])
}
// IsBuiltinDriver 返回指定驱动类型是否为核心内置驱动(始终可用,无需安装)。
func IsBuiltinDriver(driverType string) bool {
_, ok := coreBuiltinDrivers[normalizeRuntimeDriverType(driverType)]

View File

@@ -31,6 +31,17 @@ func TestBuiltinLikeDriversRemainAvailable(t *testing.T) {
}
}
func TestOptionalDriverAgentRevisionsGeneratedForOptionalDrivers(t *testing.T) {
for driverType := range optionalGoDrivers {
if revision := OptionalDriverAgentRevision(driverType); revision == "" {
t.Fatalf("%s 缺少自动生成的 driver-agent revision", driverType)
}
}
if OptionalDriverAgentRevision("doris") != OptionalDriverAgentRevision("diros") {
t.Fatalf("doris/diros revision 应归一一致")
}
}
func TestManagedDriverRequiresInstallMarker(t *testing.T) {
tmpDir := t.TempDir()
SetExternalDriverDownloadDirectory(tmpDir)

View File

@@ -624,8 +624,8 @@ func (m *MySQLDB) ApplyChanges(tableName string, changes connection.ChangeSet) e
if err != nil {
return fmt.Errorf("删除失败:%v", err)
}
if affected, err := res.RowsAffected(); err == nil && affected == 0 {
return fmt.Errorf("删除未生效:未匹配到任何行")
if err := requireSingleRowAffected(res, "删除"); err != nil {
return err
}
}
@@ -658,8 +658,8 @@ func (m *MySQLDB) ApplyChanges(tableName string, changes connection.ChangeSet) e
if err != nil {
return fmt.Errorf("更新失败:%v", err)
}
if affected, err := res.RowsAffected(); err == nil && affected == 0 {
return fmt.Errorf("更新未生效:未匹配到任何行")
if err := requireSingleRowAffected(res, "更新"); err != nil {
return err
}
}

View File

@@ -23,6 +23,7 @@ import (
const (
optionalAgentMethodConnect = "connect"
optionalAgentMethodClose = "close"
optionalAgentMethodMetadata = "metadata"
optionalAgentMethodPing = "ping"
optionalAgentMethodQuery = "query"
optionalAgentMethodExec = "exec"
@@ -58,6 +59,12 @@ type optionalAgentResponse struct {
RowsAffected int64 `json:"rowsAffected,omitempty"`
}
type OptionalDriverAgentMetadata struct {
DriverType string `json:"driverType,omitempty"`
AgentRevision string `json:"agentRevision,omitempty"`
ProtocolSchema string `json:"protocolSchema,omitempty"`
}
type optionalDriverAgentClient struct {
cmd *exec.Cmd
stdin io.WriteCloser
@@ -69,6 +76,25 @@ type optionalDriverAgentClient struct {
driver string
}
func ProbeOptionalDriverAgentMetadata(driverType string, executablePath string) (OptionalDriverAgentMetadata, error) {
client, err := newOptionalDriverAgentClient(driverType, executablePath)
if err != nil {
return OptionalDriverAgentMetadata{}, err
}
defer func() {
_ = client.close()
}()
var metadata OptionalDriverAgentMetadata
if err := client.call(optionalAgentRequest{Method: optionalAgentMethodMetadata}, &metadata, nil, nil); err != nil {
return OptionalDriverAgentMetadata{}, err
}
metadata.DriverType = normalizeRuntimeDriverType(metadata.DriverType)
metadata.AgentRevision = strings.TrimSpace(metadata.AgentRevision)
metadata.ProtocolSchema = strings.TrimSpace(metadata.ProtocolSchema)
return metadata, nil
}
func newOptionalDriverAgentClient(driverType string, executablePath string) (*optionalDriverAgentClient, error) {
pathText := strings.TrimSpace(executablePath)
if pathText == "" {

View File

@@ -24,8 +24,16 @@ var (
)
type oracleRecordingState struct {
mu sync.Mutex
execArgs [][]driver.NamedValue
mu sync.Mutex
execQueries []string
execArgs [][]driver.NamedValue
rowsAffected int64
}
func (s *oracleRecordingState) snapshotExecQueries() []string {
s.mu.Lock()
defer s.mu.Unlock()
return append([]string(nil), s.execQueries...)
}
func (s *oracleRecordingState) snapshotExecArgs() [][]driver.NamedValue {
@@ -63,11 +71,12 @@ func (c *oracleRecordingConn) Close() error { return nil }
func (c *oracleRecordingConn) Begin() (driver.Tx, error) { return oracleRecordingTx{}, nil }
func (c *oracleRecordingConn) ExecContext(_ context.Context, _ string, args []driver.NamedValue) (driver.Result, error) {
func (c *oracleRecordingConn) ExecContext(_ context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
c.state.mu.Lock()
defer c.state.mu.Unlock()
c.state.execQueries = append(c.state.execQueries, query)
c.state.execArgs = append(c.state.execArgs, append([]driver.NamedValue(nil), args...))
return driver.RowsAffected(1), nil
return driver.RowsAffected(c.state.rowsAffected), nil
}
func (c *oracleRecordingConn) QueryContext(_ context.Context, query string, _ []driver.NamedValue) (driver.Rows, error) {
@@ -126,7 +135,7 @@ func openOracleRecordingDB(t *testing.T) (*sql.DB, *oracleRecordingState) {
oracleRecordingDriverMu.Lock()
oracleRecordingDriverSeq++
dsn := fmt.Sprintf("oracle-recording-%d", oracleRecordingDriverSeq)
state := &oracleRecordingState{}
state := &oracleRecordingState{rowsAffected: 1}
oracleRecordingDriverStates[dsn] = state
oracleRecordingDriverMu.Unlock()
@@ -145,6 +154,82 @@ func openOracleRecordingDB(t *testing.T) (*sql.DB, *oracleRecordingState) {
return dbConn, state
}
func TestOracleApplyChangesReturnsErrorWhenUpdateMatchesNoRows(t *testing.T) {
t.Parallel()
dbConn, state := openOracleRecordingDB(t)
state.rowsAffected = 0
oracleDB := &OracleDB{conn: dbConn}
changes := connection.ChangeSet{
Updates: []connection.UpdateRow{{
Keys: map[string]interface{}{
"ID": 7,
},
Values: map[string]interface{}{
"NAME": "new-name",
},
}},
}
err := oracleDB.ApplyChanges("MYCIMLED.EDC_LOG", changes)
if err == nil {
t.Fatal("期望更新未匹配到行时返回错误,实际为 nil")
}
if !strings.Contains(err.Error(), "更新未生效") {
t.Fatalf("错误信息应提示更新未生效,实际=%v", err)
}
}
func TestOracleApplyChangesReturnsErrorWhenUpdateAffectsMultipleRows(t *testing.T) {
t.Parallel()
dbConn, state := openOracleRecordingDB(t)
state.rowsAffected = 2
oracleDB := &OracleDB{conn: dbConn}
changes := connection.ChangeSet{
Updates: []connection.UpdateRow{{
Keys: map[string]interface{}{
"ID": 7,
},
Values: map[string]interface{}{
"NAME": "new-name",
},
}},
}
err := oracleDB.ApplyChanges("MYCIMLED.EDC_LOG", changes)
if err == nil {
t.Fatal("期望更新影响多行时返回错误,实际为 nil")
}
if !strings.Contains(err.Error(), "影响了 2 行") {
t.Fatalf("错误信息应提示影响多行,实际=%v", err)
}
}
func TestOracleApplyChangesReturnsErrorWhenDeleteAffectsMultipleRows(t *testing.T) {
t.Parallel()
dbConn, state := openOracleRecordingDB(t)
state.rowsAffected = 2
oracleDB := &OracleDB{conn: dbConn}
changes := connection.ChangeSet{
Deletes: []map[string]interface{}{{
"STATUS": "stale",
}},
}
err := oracleDB.ApplyChanges("MYCIMLED.EDC_LOG", changes)
if err == nil {
t.Fatal("期望删除影响多行时返回错误,实际为 nil")
}
if !strings.Contains(err.Error(), "影响了 2 行") {
t.Fatalf("错误信息应提示影响多行,实际=%v", err)
}
}
func TestOracleApplyChangesNormalizesTemporalStringsForUpdate(t *testing.T) {
t.Parallel()
@@ -181,3 +266,87 @@ func TestOracleApplyChangesNormalizesTemporalStringsForUpdate(t *testing.T) {
t.Fatalf("日期主键字段应绑定为 time.Time实际=%#v(%T)", args[1].Value, args[1].Value)
}
}
func TestOracleApplyChangesUsesUnquotedRowIDLocator(t *testing.T) {
t.Parallel()
dbConn, state := openOracleRecordingDB(t)
oracleDB := &OracleDB{conn: dbConn}
changes := connection.ChangeSet{
LocatorStrategy: "oracle-rowid",
Updates: []connection.UpdateRow{{
Keys: map[string]interface{}{
"ROWID": "AAAA",
},
Values: map[string]interface{}{
"NAME": "new-name",
},
}},
}
if err := oracleDB.ApplyChanges("MYCIMLED.EDC_LOG", changes); err != nil {
t.Fatalf("ApplyChanges 返回错误: %v", err)
}
executions := state.snapshotExecQueries()
if len(executions) != 1 {
t.Fatalf("期望执行 1 条更新,实际 %d 条", len(executions))
}
query := executions[0]
if !strings.Contains(query, "ROWID = :2") {
t.Fatalf("ROWID 定位条件不正确: %s", query)
}
if strings.Contains(query, "\"ROWID\" =") {
t.Fatalf("ROWID 不应被当作普通列引用: %s", query)
}
}
func TestMySQLApplyChangesReturnsErrorWhenUpdateAffectsMultipleRows(t *testing.T) {
t.Parallel()
dbConn, state := openOracleRecordingDB(t)
state.rowsAffected = 2
mysqlDB := &MySQLDB{conn: dbConn}
changes := connection.ChangeSet{
Updates: []connection.UpdateRow{{
Keys: map[string]interface{}{
"id": 7,
},
Values: map[string]interface{}{
"name": "new-name",
},
}},
}
err := mysqlDB.ApplyChanges("users", changes)
if err == nil {
t.Fatal("期望 MySQL 更新影响多行时返回错误,实际为 nil")
}
if !strings.Contains(err.Error(), "影响了 2 行") {
t.Fatalf("错误信息应提示影响多行,实际=%v", err)
}
}
func TestPostgresApplyChangesReturnsErrorWhenDeleteAffectsMultipleRows(t *testing.T) {
t.Parallel()
dbConn, state := openOracleRecordingDB(t)
state.rowsAffected = 2
postgresDB := &PostgresDB{conn: dbConn}
changes := connection.ChangeSet{
Deletes: []map[string]interface{}{{
"id": 7,
}},
}
err := postgresDB.ApplyChanges("public.users", changes)
if err == nil {
t.Fatal("期望 PostgreSQL 删除影响多行时返回错误,实际为 nil")
}
if !strings.Contains(err.Error(), "影响了 2 行") {
t.Fatalf("错误信息应提示影响多行,实际=%v", err)
}
}

View File

@@ -0,0 +1,32 @@
package db
import (
"net/url"
"testing"
"GoNavi-Wails/internal/connection"
)
func TestOracleGetDSNIncludesQueryPerformanceOptions(t *testing.T) {
t.Parallel()
dsn := (&OracleDB{}).getDSN(connection.ConnectionConfig{
Host: "db.example.com",
Port: 1521,
User: "scott",
Password: "tiger",
Database: "ORCLPDB1",
})
parsed, err := url.Parse(dsn)
if err != nil {
t.Fatalf("解析 Oracle DSN 失败: %v", err)
}
query := parsed.Query()
if got := query.Get("PREFETCH_ROWS"); got != "10000" {
t.Fatalf("PREFETCH_ROWS = %q, want 10000", got)
}
if got := query.Get("LOB FETCH"); got != "POST" {
t.Fatalf("LOB FETCH = %q, want POST", got)
}
}

View File

@@ -44,6 +44,10 @@ func (o *OracleDB) getDSN(config connection.ConnectionConfig) string {
q.Set("SSL", "TRUE")
q.Set("SSL VERIFY", "FALSE")
}
// 提高 prefetch 行数,减少大结果集的网络往返次数(默认仅 25 行/次)
q.Set("PREFETCH_ROWS", "10000")
// LOB 数据延迟加载,避免大 LOB 列影响普通查询性能
q.Set("LOB FETCH", "POST")
if encoded := q.Encode(); encoded != "" {
u.RawQuery = encoded
}
@@ -263,16 +267,31 @@ func (o *OracleDB) GetCreateStatement(dbName, tableName string) (string, error)
}
func (o *OracleDB) GetColumns(dbName, tableName string) ([]connection.ColumnDefinition, error) {
query := fmt.Sprintf(`SELECT column_name, data_type, nullable, data_default
FROM all_tab_columns
WHERE owner = '%s' AND table_name = '%s'
ORDER BY column_id`, strings.ToUpper(dbName), strings.ToUpper(tableName))
query := fmt.Sprintf(`SELECT c.column_name, c.data_type, c.nullable, c.data_default,
CASE WHEN pk.column_name IS NOT NULL THEN 'PRI' ELSE '' END AS column_key
FROM all_tab_columns c
LEFT JOIN (
SELECT cols.owner, cols.table_name, cols.column_name
FROM all_constraints cons
JOIN all_cons_columns cols
ON cons.owner = cols.owner AND cons.constraint_name = cols.constraint_name
WHERE cons.constraint_type = 'P'
) pk ON c.owner = pk.owner AND c.table_name = pk.table_name AND c.column_name = pk.column_name
WHERE c.owner = '%s' AND c.table_name = '%s'
ORDER BY c.column_id`, strings.ToUpper(dbName), strings.ToUpper(tableName))
if dbName == "" {
query = fmt.Sprintf(`SELECT column_name, data_type, nullable, data_default
FROM user_tab_columns
WHERE table_name = '%s'
ORDER BY column_id`, strings.ToUpper(tableName))
query = fmt.Sprintf(`SELECT c.column_name, c.data_type, c.nullable, c.data_default,
CASE WHEN pk.column_name IS NOT NULL THEN 'PRI' ELSE '' END AS column_key
FROM user_tab_columns c
LEFT JOIN (
SELECT cols.table_name, cols.column_name
FROM user_constraints cons
JOIN user_cons_columns cols USING (constraint_name)
WHERE cons.constraint_type = 'P'
) pk ON c.table_name = pk.table_name AND c.column_name = pk.column_name
WHERE c.table_name = '%s'
ORDER BY c.column_id`, strings.ToUpper(tableName))
}
data, _, err := o.Query(query)
@@ -286,6 +305,7 @@ func (o *OracleDB) GetColumns(dbName, tableName string) ([]connection.ColumnDefi
Name: fmt.Sprintf("%v", row["COLUMN_NAME"]),
Type: fmt.Sprintf("%v", row["DATA_TYPE"]),
Nullable: fmt.Sprintf("%v", row["NULLABLE"]),
Key: fmt.Sprintf("%v", row["COLUMN_KEY"]),
}
if row["DATA_DEFAULT"] != nil {
@@ -299,17 +319,31 @@ func (o *OracleDB) GetColumns(dbName, tableName string) ([]connection.ColumnDefi
}
func (o *OracleDB) GetIndexes(dbName, tableName string) ([]connection.IndexDefinition, error) {
query := fmt.Sprintf(`SELECT index_name, column_name, uniqueness
FROM all_ind_columns
JOIN all_indexes USING (index_name, owner)
WHERE table_owner = '%s' AND table_name = '%s'`,
strings.ToUpper(dbName), strings.ToUpper(tableName))
esc := func(s string) string { return strings.ReplaceAll(strings.ToUpper(strings.TrimSpace(s)), "'", "''") }
table := esc(tableName)
if table == "" {
return nil, fmt.Errorf("表名不能为空")
}
if dbName == "" {
query = fmt.Sprintf(`SELECT index_name, column_name, uniqueness
FROM user_ind_columns
JOIN user_indexes USING (index_name)
WHERE table_name = '%s'`, strings.ToUpper(tableName))
query := fmt.Sprintf(`SELECT c.index_name, c.column_name, i.uniqueness, c.column_position, i.index_type
FROM all_ind_columns c
JOIN all_indexes i ON i.owner = c.index_owner AND i.index_name = c.index_name
WHERE c.table_owner = '%s'
AND c.table_name = '%s'
AND c.column_name IS NOT NULL
AND c.column_name NOT LIKE 'SYS_NC%%$'
AND i.index_type NOT LIKE 'FUNCTION-BASED%%'
ORDER BY c.index_name, c.column_position`, esc(dbName), table)
if strings.TrimSpace(dbName) == "" {
query = fmt.Sprintf(`SELECT c.index_name, c.column_name, i.uniqueness, c.column_position, i.index_type
FROM user_ind_columns c
JOIN user_indexes i ON i.index_name = c.index_name
WHERE c.table_name = '%s'
AND c.column_name IS NOT NULL
AND c.column_name NOT LIKE 'SYS_NC%%$'
AND i.index_type NOT LIKE 'FUNCTION-BASED%%'
ORDER BY c.index_name, c.column_position`, table)
}
data, _, err := o.Query(query)
@@ -317,19 +351,46 @@ func (o *OracleDB) GetIndexes(dbName, tableName string) ([]connection.IndexDefin
return nil, err
}
getValue := func(row map[string]interface{}, names ...string) interface{} {
for _, name := range names {
if value, ok := row[name]; ok {
return value
}
for key, value := range row {
if strings.EqualFold(key, name) {
return value
}
}
}
return nil
}
parseInt := func(value interface{}) int {
var n int
_, _ = fmt.Sscanf(strings.TrimSpace(fmt.Sprintf("%v", value)), "%d", &n)
return n
}
var indexes []connection.IndexDefinition
for _, row := range data {
unique := 1
if val, ok := row["UNIQUENESS"]; ok && val == "UNIQUE" {
unique = 0
uniqueness := strings.ToUpper(strings.TrimSpace(fmt.Sprintf("%v", getValue(row, "UNIQUENESS"))))
nonUnique := 1
if uniqueness == "UNIQUE" {
nonUnique = 0
}
indexType := strings.ToUpper(strings.TrimSpace(fmt.Sprintf("%v", getValue(row, "INDEX_TYPE"))))
if indexType == "" || indexType == "<NIL>" {
indexType = "BTREE"
}
idx := connection.IndexDefinition{
Name: fmt.Sprintf("%v", row["INDEX_NAME"]),
ColumnName: fmt.Sprintf("%v", row["COLUMN_NAME"]),
NonUnique: unique,
// SeqInIndex is harder to get in simple join, omitting or estimating
IndexType: "BTREE", // Default assumption
Name: strings.TrimSpace(fmt.Sprintf("%v", getValue(row, "INDEX_NAME"))),
ColumnName: strings.TrimSpace(fmt.Sprintf("%v", getValue(row, "COLUMN_NAME"))),
NonUnique: nonUnique,
SeqInIndex: parseInt(getValue(row, "COLUMN_POSITION")),
IndexType: indexType,
}
if idx.Name == "" || idx.ColumnName == "" || strings.EqualFold(idx.ColumnName, "<nil>") {
continue
}
indexes = append(indexes, idx)
}
@@ -531,23 +592,38 @@ func (o *OracleDB) ApplyChanges(tableName string, changes connection.ChangeSet)
qualifiedTable = quoteIdent(table)
}
// 1. Deletes
for _, pk := range changes.Deletes {
isOracleRowIDLocator := strings.EqualFold(strings.TrimSpace(changes.LocatorStrategy), "oracle-rowid")
buildWhere := func(keys map[string]interface{}, startIndex int) ([]string, []interface{}, int) {
var wheres []string
var args []interface{}
idx := 0
for k, v := range pk {
idx := startIndex
for k, v := range keys {
idx++
if isOracleRowIDLocator && strings.EqualFold(strings.TrimSpace(k), "ROWID") {
wheres = append(wheres, fmt.Sprintf("ROWID = :%d", idx))
args = append(args, v)
continue
}
wheres = append(wheres, fmt.Sprintf("%s = :%d", quoteIdent(k), idx))
args = append(args, normalizeOracleValueForWrite(k, v, columnTypeMap))
}
return wheres, args, idx
}
// 1. Deletes
for _, pk := range changes.Deletes {
wheres, args, _ := buildWhere(pk, 0)
if len(wheres) == 0 {
continue
}
query := fmt.Sprintf("DELETE FROM %s WHERE %s", qualifiedTable, strings.Join(wheres, " AND "))
if _, err := tx.Exec(query, args...); err != nil {
res, err := tx.Exec(query, args...)
if err != nil {
return fmt.Errorf("删除失败:%v", err)
}
if err := requireSingleRowAffected(res, "删除"); err != nil {
return err
}
}
// 2. Updates
@@ -566,21 +642,21 @@ func (o *OracleDB) ApplyChanges(tableName string, changes connection.ChangeSet)
continue
}
var wheres []string
for k, v := range update.Keys {
idx++
wheres = append(wheres, fmt.Sprintf("%s = :%d", quoteIdent(k), idx))
args = append(args, normalizeOracleValueForWrite(k, v, columnTypeMap))
}
wheres, whereArgs, _ := buildWhere(update.Keys, idx)
args = append(args, whereArgs...)
if len(wheres) == 0 {
return fmt.Errorf("更新操作需要主键条件")
}
query := fmt.Sprintf("UPDATE %s SET %s WHERE %s", qualifiedTable, strings.Join(sets, ", "), strings.Join(wheres, " AND "))
if _, err := tx.Exec(query, args...); err != nil {
res, err := tx.Exec(query, args...)
if err != nil {
return fmt.Errorf("更新失败:%v", err)
}
if err := requireSingleRowAffected(res, "更新"); err != nil {
return err
}
}
// 3. Inserts
@@ -602,9 +678,13 @@ func (o *OracleDB) ApplyChanges(tableName string, changes connection.ChangeSet)
}
query := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", qualifiedTable, strings.Join(cols, ", "), strings.Join(placeholders, ", "))
if _, err := tx.Exec(query, args...); err != nil {
res, err := tx.Exec(query, args...)
if err != nil {
return fmt.Errorf("插入失败:%v", err)
}
if affected, err := res.RowsAffected(); err == nil && affected == 0 {
return fmt.Errorf("插入未生效:未影响任何行")
}
}
return tx.Commit()

View File

@@ -408,6 +408,12 @@ JOIN pg_attribute a ON a.attrelid = t.oid AND a.attnum = x.attnum
WHERE t.relkind IN ('r', 'p')
AND t.relname = '%s'
AND n.nspname = '%s'
AND ix.indisvalid
AND ix.indpred IS NULL
AND x.ordinality <= ix.indnkeyatts
AND NOT EXISTS (
SELECT 1 FROM unnest(ix.indkey) AS expr_key(attnum) WHERE expr_key.attnum <= 0
)
ORDER BY i.relname, x.ordinality`, esc(table), esc(schema))
data, _, err := p.Query(query)
@@ -758,9 +764,13 @@ func (p *PostgresDB) ApplyChanges(tableName string, changes connection.ChangeSet
continue
}
query := fmt.Sprintf("DELETE FROM %s WHERE %s", qualifiedTable, strings.Join(wheres, " AND "))
if _, err := tx.Exec(query, args...); err != nil {
res, err := tx.Exec(query, args...)
if err != nil {
return fmt.Errorf("删除失败:%v", err)
}
if err := requireSingleRowAffected(res, "删除"); err != nil {
return err
}
}
// 2. Updates
@@ -791,9 +801,13 @@ func (p *PostgresDB) ApplyChanges(tableName string, changes connection.ChangeSet
}
query := fmt.Sprintf("UPDATE %s SET %s WHERE %s", qualifiedTable, strings.Join(sets, ", "), strings.Join(wheres, " AND "))
if _, err := tx.Exec(query, args...); err != nil {
res, err := tx.Exec(query, args...)
if err != nil {
return fmt.Errorf("更新失败:%v", err)
}
if err := requireSingleRowAffected(res, "更新"); err != nil {
return err
}
}
// 3. Inserts

View File

@@ -0,0 +1,251 @@
#!/usr/bin/env bash
set -euo pipefail
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"
cd "$SCRIPT_DIR"
DEFAULT_DRIVERS=(mariadb diros sphinx sqlserver sqlite duckdb dameng kingbase highgo vastbase mongodb tdengine clickhouse)
OUTPUT_FILE="internal/db/driver_agent_revisions_gen.go"
usage() {
cat <<'EOF'
用法:
./tools/generate-driver-agent-revisions.sh [选项]
选项:
--platform <GOOS/GOARCH> 按目标平台解析 Go build tags默认使用当前 Go 环境
--drivers <列表> 指定驱动列表(逗号分隔),默认生成所有 optional driver
-h, --help 显示帮助
EOF
}
normalize_driver() {
local value
value="$(printf '%s' "$1" | tr '[:upper:]' '[:lower:]' | tr -d '[:space:]')"
case "$value" in
doris|diros) echo "diros" ;;
mariadb|diros|sphinx|sqlserver|sqlite|duckdb|dameng|kingbase|highgo|vastbase|mongodb|tdengine|clickhouse)
echo "$value"
;;
*)
return 1
;;
esac
}
build_driver_name() {
echo "$1"
}
hash_file() {
local target="$1"
if command -v sha256sum >/dev/null 2>&1; then
sha256sum "$target" | awk '{print $1}'
return
fi
if command -v shasum >/dev/null 2>&1; then
shasum -a 256 "$target" | awk '{print $1}'
return
fi
echo "未找到 sha256sum 或 shasum" >&2
exit 1
}
should_include_internal_db_file() {
local driver="$1"
local identity="$2"
case "$identity" in
internal/db/agent_process_stub.go|\
internal/db/agent_process_windows.go|\
internal/db/database.go|\
internal/db/database_optional_factories_full.go|\
internal/db/database_optional_factories_lite.go|\
internal/db/driver_agent_binary_check.go|\
internal/db/driver_support.go|\
internal/db/json_decode.go|\
internal/db/mysql_agent_path.go|\
internal/db/optional_driver_agent_impl.go|\
internal/db/optional_driver_build_full.go|\
internal/db/optional_driver_build_lite.go|\
internal/db/query_value.go|\
internal/db/scan_rows.go|\
internal/db/ssl_mode.go|\
internal/db/timeout.go)
return 0
;;
esac
case "$driver:$identity" in
mariadb:internal/db/mariadb_impl.go|\
diros:internal/db/diros_impl.go|\
diros:internal/db/mysql_impl.go|\
sphinx:internal/db/sphinx_impl.go|\
sphinx:internal/db/mysql_impl.go|\
sqlserver:internal/db/sqlserver_impl.go|\
sqlite:internal/db/sqlite_impl.go|\
duckdb:internal/db/duckdb_impl.go|\
duckdb:internal/db/duckdb_driver_import.go|\
duckdb:internal/db/duckdb_platform_supported.go|\
duckdb:internal/db/duckdb_platform_unsupported.go|\
dameng:internal/db/dameng_impl.go|\
dameng:internal/db/dameng_metadata.go|\
kingbase:internal/db/kingbase_impl.go|\
kingbase:internal/db/kingbase_identifier_utils.go|\
highgo:internal/db/highgo_impl.go|\
vastbase:internal/db/vastbase_impl.go|\
mongodb:internal/db/mongodb_impl.go|\
mongodb:internal/db/mongodb_impl_v1.go|\
tdengine:internal/db/tdengine_impl.go|\
clickhouse:internal/db/clickhouse_impl.go)
return 0
;;
esac
return 1
}
should_include_source_file() {
local driver="$1"
local identity="$2"
if [[ "$identity" == internal/db/* ]]; then
should_include_internal_db_file "$driver" "$identity"
return
fi
return 0
}
target_platform=""
driver_csv=""
while [[ $# -gt 0 ]]; do
case "$1" in
--platform)
target_platform="${2:-}"
shift 2
;;
--drivers)
driver_csv="${2:-}"
shift 2
;;
-h|--help)
usage
exit 0
;;
*)
echo "未知参数:$1" >&2
usage
exit 1
;;
esac
done
if ! command -v go >/dev/null 2>&1; then
echo "未找到 Go请先安装 Go 并确保 go 在 PATH 中。" >&2
exit 1
fi
if [[ -z "$target_platform" ]]; then
target_platform="$(go env GOOS)/$(go env GOARCH)"
fi
if [[ "$target_platform" != */* ]]; then
echo "--platform 参数格式错误,应为 GOOS/GOARCH例如 darwin/arm64" >&2
exit 1
fi
goos="${target_platform%%/*}"
goarch="${target_platform##*/}"
gomodcache="$(go env GOMODCACHE)"
declare -a drivers=()
if [[ -n "$driver_csv" ]]; then
IFS=',' read -r -a raw_drivers <<<"$driver_csv"
for item in "${raw_drivers[@]}"; do
drivers+=("$(normalize_driver "$item")")
done
else
drivers=("${DEFAULT_DRIVERS[@]}")
fi
fingerprint_driver() {
local driver="$1"
local build_driver tag cgo_enabled tmp file identity file_hash revision
build_driver="$(build_driver_name "$driver")"
tag="gonavi_${build_driver}_driver"
cgo_enabled=0
if [[ "$driver" == "duckdb" ]]; then
cgo_enabled=1
fi
tmp="$(mktemp "${TMPDIR:-/tmp}/gonavi-agent-revision.XXXXXX")"
{
printf 'driver=%s\n' "$driver"
printf 'build_tag=%s\n' "$tag"
printf 'goos=%s\n' "$goos"
printf 'goarch=%s\n' "$goarch"
} >"$tmp"
while IFS= read -r file; do
[[ -n "$file" && -f "$file" ]] || continue
case "$file" in
"$SCRIPT_DIR"/*)
identity="${file#$SCRIPT_DIR/}"
;;
"$gomodcache"/*)
identity="gomod/${file#$gomodcache/}"
;;
*)
identity="$file"
;;
esac
if [[ "$identity" == "$OUTPUT_FILE" ]]; then
continue
fi
if ! should_include_source_file "$driver" "$identity"; then
continue
fi
file_hash="$(hash_file "$file")"
printf '%s %s\n' "$file_hash" "$identity"
done < <(
CGO_ENABLED="$cgo_enabled" GOOS="$goos" GOARCH="$goarch" GOTOOLCHAIN=auto \
go list -deps \
-tags "$tag" \
-f '{{if not .Standard}}{{range .GoFiles}}{{$.Dir}}/{{.}}{{"\n"}}{{end}}{{range .CgoFiles}}{{$.Dir}}/{{.}}{{"\n"}}{{end}}{{end}}' \
./cmd/optional-driver-agent | sort -u
) >>"$tmp"
revision="$(hash_file "$tmp" | cut -c1-16)"
rm -f "$tmp"
printf 'src-%s' "$revision"
}
tmp_output="$(mktemp "${TMPDIR:-/tmp}/gonavi-agent-revisions-go.XXXXXX")"
{
cat <<'EOF'
// Code generated by tools/generate-driver-agent-revisions.sh; DO NOT EDIT.
package db
func init() {
optionalDriverAgentRevisions = map[string]string{
EOF
for driver in "${drivers[@]}"; do
revision="$(fingerprint_driver "$driver")"
printf '\t\t"%s": "%s",\n' "$driver" "$revision"
done
cat <<'EOF'
}
}
EOF
} >"$tmp_output"
gofmt -w "$tmp_output"
if [[ -f "$OUTPUT_FILE" ]] && cmp -s "$tmp_output" "$OUTPUT_FILE"; then
rm -f "$tmp_output"
else
mv "$tmp_output" "$OUTPUT_FILE"
fi
echo "已生成 driver-agent revisions: $OUTPUT_FILE ($goos/$goarch)"