mirror of
https://github.com/Syngnat/GoNavi.git
synced 2026-05-12 12:19:47 +08:00
Compare commits
7 Commits
release/0.
...
release/0.
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c1ebce4ef5 | ||
|
|
c927e33c8c | ||
|
|
824aafbdea | ||
|
|
0c1586d7a4 | ||
|
|
b1ef52f62e | ||
|
|
05a913ccb2 | ||
|
|
f51dbcfb2c |
1
.github/workflows/dev-build.yml
vendored
1
.github/workflows/dev-build.yml
vendored
@@ -246,6 +246,7 @@ jobs:
|
||||
run: |
|
||||
set -euo pipefail
|
||||
DEV_VERSION="${{ steps.version.outputs.version }}"
|
||||
./tools/generate-driver-agent-revisions.sh --platform "${{ matrix.platform }}"
|
||||
if [ -n "${{ matrix.wails_tags }}" ]; then
|
||||
wails build -platform ${{ matrix.platform }} -clean -o ${{ matrix.build_name }} -tags "${{ matrix.wails_tags }}" -ldflags "-s -w -X GoNavi-Wails/internal/app.AppVersion=${DEV_VERSION}"
|
||||
else
|
||||
|
||||
1
.github/workflows/release.yml
vendored
1
.github/workflows/release.yml
vendored
@@ -237,6 +237,7 @@ jobs:
|
||||
shell: bash
|
||||
run: |
|
||||
set -euo pipefail
|
||||
./tools/generate-driver-agent-revisions.sh --platform "${{ matrix.platform }}"
|
||||
if [ -n "${{ matrix.wails_tags }}" ]; then
|
||||
wails build -platform ${{ matrix.platform }} -clean -o ${{ matrix.build_name }} -tags "${{ matrix.wails_tags }}" -ldflags "-s -w -X GoNavi-Wails/internal/app.AppVersion=${{ github.ref_name }}"
|
||||
else
|
||||
|
||||
@@ -152,6 +152,8 @@ echo "🚀 开始构建 optional-driver-agent"
|
||||
echo " 平台:$goos/$goarch"
|
||||
echo " 输出目录:$output_dir_abs"
|
||||
echo " 驱动列表:${drivers[*]}"
|
||||
echo "🧭 生成 driver-agent revision 指纹"
|
||||
"$SCRIPT_DIR/tools/generate-driver-agent-revisions.sh" --platform "$target_platform"
|
||||
|
||||
for driver in "${drivers[@]}"; do
|
||||
if [[ "$driver" == "duckdb" && "$goos" == "windows" && "$goarch" != "amd64" ]]; then
|
||||
|
||||
@@ -155,6 +155,7 @@ package_macos_release() {
|
||||
local archive_suffix="$2"
|
||||
|
||||
echo -e "${GREEN}🍎 正在构建 macOS (${platform})...${NC}"
|
||||
generate_driver_agent_revisions "darwin/${platform}"
|
||||
wails build -platform "darwin/${platform}" -clean -ldflags "$LDFLAGS"
|
||||
if [ $? -ne 0 ]; then
|
||||
echo -e "${RED} ❌ macOS ${platform} 构建失败。${NC}"
|
||||
@@ -185,6 +186,12 @@ package_macos_release() {
|
||||
echo " ✅ 已生成 $zip_name"
|
||||
}
|
||||
|
||||
generate_driver_agent_revisions() {
|
||||
local platform="$1"
|
||||
echo " 🧭 正在生成 driver-agent revision 指纹 (${platform})..."
|
||||
./tools/generate-driver-agent-revisions.sh --platform "$platform"
|
||||
}
|
||||
|
||||
echo -e "${GREEN}🚀 开始构建 $APP_NAME $VERSION...${NC}"
|
||||
|
||||
# 清理并创建输出目录
|
||||
@@ -197,6 +204,7 @@ package_macos_release "amd64" "mac-amd64"
|
||||
# --- Windows AMD64 构建 ---
|
||||
echo -e "${GREEN}🪟 正在构建 Windows (amd64)...${NC}"
|
||||
if command -v x86_64-w64-mingw32-gcc &> /dev/null; then
|
||||
generate_driver_agent_revisions "windows/amd64"
|
||||
wails build -platform windows/amd64 -clean -ldflags "$LDFLAGS"
|
||||
if [ $? -eq 0 ]; then
|
||||
TARGET_EXE="$DIST_DIR/${APP_NAME}-${VERSION}-windows-amd64.exe"
|
||||
@@ -213,6 +221,7 @@ fi
|
||||
# --- Windows ARM64 构建 ---
|
||||
echo -e "${GREEN}🪟 正在构建 Windows (arm64)...${NC}"
|
||||
if command -v aarch64-w64-mingw32-gcc &> /dev/null; then
|
||||
generate_driver_agent_revisions "windows/arm64"
|
||||
wails build -platform windows/arm64 -clean -ldflags "$LDFLAGS"
|
||||
if [ $? -eq 0 ]; then
|
||||
TARGET_EXE="$DIST_DIR/${APP_NAME}-${VERSION}-windows-arm64.exe"
|
||||
@@ -235,6 +244,7 @@ CURRENT_ARCH=$(uname -m)
|
||||
|
||||
if [ "$CURRENT_OS" = "Linux" ] && [ "$CURRENT_ARCH" = "x86_64" ]; then
|
||||
# 本机 Linux amd64,直接构建
|
||||
generate_driver_agent_revisions "linux/amd64"
|
||||
wails build -platform linux/amd64 -clean -ldflags "$LDFLAGS"
|
||||
if [ $? -eq 0 ]; then
|
||||
TARGET_LINUX_BIN="$DIST_DIR/${APP_NAME}-${VERSION}-linux-amd64"
|
||||
@@ -255,6 +265,7 @@ elif command -v x86_64-linux-gnu-gcc &> /dev/null; then
|
||||
export CC=x86_64-linux-gnu-gcc
|
||||
export CXX=x86_64-linux-gnu-g++
|
||||
export CGO_ENABLED=1
|
||||
generate_driver_agent_revisions "linux/amd64"
|
||||
wails build -platform linux/amd64 -clean -ldflags "$LDFLAGS"
|
||||
if [ $? -eq 0 ]; then
|
||||
TARGET_LINUX_BIN="$DIST_DIR/${APP_NAME}-${VERSION}-linux-amd64"
|
||||
@@ -279,6 +290,7 @@ fi
|
||||
echo -e "${GREEN}🐧 正在构建 Linux (arm64)...${NC}"
|
||||
if [ "$CURRENT_OS" = "Linux" ] && [ "$CURRENT_ARCH" = "aarch64" ]; then
|
||||
# 本机 Linux arm64,直接构建
|
||||
generate_driver_agent_revisions "linux/arm64"
|
||||
wails build -platform linux/arm64 -clean -ldflags "$LDFLAGS"
|
||||
if [ $? -eq 0 ]; then
|
||||
TARGET_LINUX_BIN="$DIST_DIR/${APP_NAME}-${VERSION}-linux-arm64"
|
||||
@@ -298,6 +310,7 @@ elif command -v aarch64-linux-gnu-gcc &> /dev/null; then
|
||||
export CC=aarch64-linux-gnu-gcc
|
||||
export CXX=aarch64-linux-gnu-g++
|
||||
export CGO_ENABLED=1
|
||||
generate_driver_agent_revisions "linux/arm64"
|
||||
wails build -platform linux/arm64 -clean -ldflags "$LDFLAGS"
|
||||
if [ $? -eq 0 ]; then
|
||||
TARGET_LINUX_BIN="$DIST_DIR/${APP_NAME}-${VERSION}-linux-arm64"
|
||||
|
||||
@@ -37,6 +37,7 @@ type agentResponse struct {
|
||||
const (
|
||||
agentMethodConnect = "connect"
|
||||
agentMethodClose = "close"
|
||||
agentMethodMetadata = "metadata"
|
||||
agentMethodPing = "ping"
|
||||
agentMethodQuery = "query"
|
||||
agentMethodExec = "exec"
|
||||
@@ -131,6 +132,13 @@ func handleRequest(inst *db.Database, req agentRequest) agentResponse {
|
||||
*inst = nil
|
||||
}
|
||||
return resp
|
||||
case agentMethodMetadata:
|
||||
resp.Data = map[string]string{
|
||||
"driverType": strings.TrimSpace(agentDriverType),
|
||||
"agentRevision": db.OptionalDriverAgentRevision(agentDriverType),
|
||||
"protocolSchema": "json-lines-v1",
|
||||
}
|
||||
return resp
|
||||
}
|
||||
|
||||
if *inst == nil {
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"time"
|
||||
|
||||
"GoNavi-Wails/internal/connection"
|
||||
"GoNavi-Wails/internal/db"
|
||||
)
|
||||
|
||||
type duckMapLike map[any]any
|
||||
@@ -66,6 +67,33 @@ func TestNormalizeAgentResponseData_KeepByteSlice(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleRequestMetadataReportsAgentRevision(t *testing.T) {
|
||||
previousDriverType := agentDriverType
|
||||
previousFactory := agentDatabaseFactory
|
||||
t.Cleanup(func() {
|
||||
agentDriverType = previousDriverType
|
||||
agentDatabaseFactory = previousFactory
|
||||
})
|
||||
agentDriverType = "clickhouse"
|
||||
agentDatabaseFactory = func() db.Database { return nil }
|
||||
|
||||
var inst db.Database
|
||||
resp := handleRequest(&inst, agentRequest{ID: 7, Method: agentMethodMetadata})
|
||||
if !resp.Success {
|
||||
t.Fatalf("metadata request failed: %s", resp.Error)
|
||||
}
|
||||
data, ok := resp.Data.(map[string]string)
|
||||
if !ok {
|
||||
t.Fatalf("metadata response data type = %T", resp.Data)
|
||||
}
|
||||
if data["driverType"] != "clickhouse" {
|
||||
t.Fatalf("unexpected driver type: %q", data["driverType"])
|
||||
}
|
||||
if data["agentRevision"] != db.OptionalDriverAgentRevision("clickhouse") {
|
||||
t.Fatalf("unexpected agent revision: %q", data["agentRevision"])
|
||||
}
|
||||
}
|
||||
|
||||
type fakeAgentTimeoutDB struct {
|
||||
queryCalled bool
|
||||
queryContextCalled bool
|
||||
|
||||
@@ -1 +1 @@
|
||||
0295a42fd931778d85157816d79d29e5
|
||||
d0464f9da25e9356e61652e638c99ffe
|
||||
@@ -95,6 +95,7 @@ type ChoiceCardOption = {
|
||||
label: string;
|
||||
description?: string;
|
||||
};
|
||||
type ClickHouseProtocolChoice = "auto" | "http" | "native";
|
||||
const MAX_URI_LENGTH = 4096;
|
||||
const MAX_URI_HOSTS = 32;
|
||||
const MAX_TIMEOUT_SECONDS = 3600;
|
||||
@@ -102,6 +103,25 @@ const CONNECTION_MODAL_WIDTH = 960;
|
||||
const CONNECTION_MODAL_BODY_HEIGHT = 620;
|
||||
const STEP1_SIDEBAR_DIVIDER_DARK = "rgba(255, 255, 255, 0.16)";
|
||||
const STEP1_SIDEBAR_DIVIDER_LIGHT = "rgba(0, 0, 0, 0.08)";
|
||||
const CLICKHOUSE_PROTOCOL_OPTIONS: Array<{
|
||||
value: ClickHouseProtocolChoice;
|
||||
label: string;
|
||||
}> = [
|
||||
{ value: "auto", label: "自动" },
|
||||
{ value: "http", label: "HTTP" },
|
||||
{ value: "native", label: "Native" },
|
||||
];
|
||||
|
||||
const normalizeClickHouseProtocolValue = (
|
||||
value: unknown,
|
||||
): ClickHouseProtocolChoice => {
|
||||
const text = String(value || "")
|
||||
.trim()
|
||||
.toLowerCase();
|
||||
if (text === "http" || text === "https") return "http";
|
||||
if (text === "native" || text === "tcp") return "native";
|
||||
return "auto";
|
||||
};
|
||||
type ConnectionSecretKey =
|
||||
| "primaryPassword"
|
||||
| "sshPassword"
|
||||
@@ -216,6 +236,10 @@ type DriverStatusSnapshot = {
|
||||
type: string;
|
||||
name: string;
|
||||
connectable: boolean;
|
||||
expectedRevision?: string;
|
||||
needsUpdate?: boolean;
|
||||
updateReason?: string;
|
||||
affectedConnections?: number;
|
||||
message?: string;
|
||||
};
|
||||
|
||||
@@ -228,6 +252,14 @@ const normalizeDriverType = (value: string): string => {
|
||||
return normalized;
|
||||
};
|
||||
|
||||
const resolveConnectionDriverType = (type: string, driver?: string): string => {
|
||||
const normalizedType = normalizeDriverType(type);
|
||||
if (normalizedType !== "custom") {
|
||||
return normalizedType;
|
||||
}
|
||||
return normalizeDriverType(driver || "");
|
||||
};
|
||||
|
||||
const ConnectionModal: React.FC<{
|
||||
open: boolean;
|
||||
onClose: () => void;
|
||||
@@ -300,6 +332,7 @@ const ConnectionModal: React.FC<{
|
||||
const redisTopology = Form.useWatch("redisTopology", form) || "single";
|
||||
const sslMode = Form.useWatch("sslMode", form) || "preferred";
|
||||
const proxyType = Form.useWatch("proxyType", form) || "socks5";
|
||||
const customDriver = Form.useWatch("driver", form) || "";
|
||||
const mongoReadPreference =
|
||||
Form.useWatch("mongoReadPreference", form) || "primary";
|
||||
const mongoAuthMechanism = Form.useWatch("mongoAuthMechanism", form) || "";
|
||||
@@ -831,6 +864,12 @@ const ConnectionModal: React.FC<{
|
||||
type,
|
||||
name: String(item.name || item.type || type).trim(),
|
||||
connectable: !!item.connectable,
|
||||
expectedRevision: String(item.expectedRevision || "").trim() || undefined,
|
||||
needsUpdate: !!item.needsUpdate,
|
||||
updateReason: String(item.updateReason || "").trim() || undefined,
|
||||
affectedConnections: Number.isFinite(Number(item.affectedConnections))
|
||||
? Number(item.affectedConnections)
|
||||
: undefined,
|
||||
message: String(item.message || "").trim() || undefined,
|
||||
};
|
||||
});
|
||||
@@ -850,8 +889,9 @@ const ConnectionModal: React.FC<{
|
||||
|
||||
const resolveDriverUnavailableReason = async (
|
||||
type: string,
|
||||
driver?: string,
|
||||
): Promise<string> => {
|
||||
const normalized = normalizeDriverType(type);
|
||||
const normalized = resolveConnectionDriverType(type, driver);
|
||||
if (!normalized || normalized === "custom") {
|
||||
return "";
|
||||
}
|
||||
@@ -1000,6 +1040,13 @@ const ConnectionModal: React.FC<{
|
||||
}
|
||||
};
|
||||
|
||||
const normalizeUriBool = (raw: unknown) => {
|
||||
const text = String(raw ?? "")
|
||||
.trim()
|
||||
.toLowerCase();
|
||||
return text === "1" || text === "true" || text === "yes" || text === "on";
|
||||
};
|
||||
|
||||
const normalizeFileDbPath = (rawPath: string): string => {
|
||||
let pathText = String(rawPath || "").trim();
|
||||
if (!pathText) {
|
||||
@@ -1117,6 +1164,44 @@ const ConnectionModal: React.FC<{
|
||||
};
|
||||
};
|
||||
|
||||
const parseClickHouseHTTPUriToValues = (
|
||||
uriText: string,
|
||||
fallbackPort?: number,
|
||||
): Record<string, any> | null => {
|
||||
const trimmed = String(uriText || "").trim();
|
||||
const lower = trimmed.toLowerCase();
|
||||
const isHttps = lower.startsWith("https://");
|
||||
const isHttp = lower.startsWith("http://");
|
||||
if (!isHttp && !isHttps) {
|
||||
return null;
|
||||
}
|
||||
const defaultPort =
|
||||
Number.isFinite(Number(fallbackPort)) && Number(fallbackPort) > 0
|
||||
? Number(fallbackPort)
|
||||
: isHttps
|
||||
? 8443
|
||||
: 8123;
|
||||
const parsed = parseSingleHostUri(
|
||||
trimmed,
|
||||
[isHttps ? "https" : "http"],
|
||||
defaultPort,
|
||||
);
|
||||
if (!parsed) {
|
||||
return null;
|
||||
}
|
||||
const skipVerify = normalizeUriBool(parsed.params.get("skip_verify"));
|
||||
return {
|
||||
host: parsed.host,
|
||||
port: parsed.port,
|
||||
user: parsed.username,
|
||||
password: parsed.password,
|
||||
database: parsed.database || "",
|
||||
clickHouseProtocol: "http",
|
||||
useSSL: isHttps,
|
||||
sslMode: isHttps ? (skipVerify ? "skip-verify" : "required") : "disable",
|
||||
};
|
||||
};
|
||||
|
||||
const parseUriToValues = (
|
||||
uriText: string,
|
||||
type: string,
|
||||
@@ -1337,6 +1422,13 @@ const ConnectionModal: React.FC<{
|
||||
};
|
||||
}
|
||||
|
||||
if (type === "clickhouse") {
|
||||
const httpValues = parseClickHouseHTTPUriToValues(trimmedUri);
|
||||
if (httpValues) {
|
||||
return httpValues;
|
||||
}
|
||||
}
|
||||
|
||||
const singleHostSchemes = singleHostUriSchemesByType[type];
|
||||
if (singleHostSchemes && singleHostSchemes.length > 0) {
|
||||
const parsed = parseSingleHostUri(
|
||||
@@ -1412,6 +1504,9 @@ const ConnectionModal: React.FC<{
|
||||
parsedValues.sslMode = "disable";
|
||||
}
|
||||
} else if (type === "clickhouse") {
|
||||
parsedValues.clickHouseProtocol = normalizeClickHouseProtocolValue(
|
||||
parsed.params.get("protocol"),
|
||||
);
|
||||
const secure = String(
|
||||
parsed.params.get("secure") || parsed.params.get("tls") || "",
|
||||
)
|
||||
@@ -1707,7 +1802,18 @@ const ConnectionModal: React.FC<{
|
||||
return `${scheme}://${encodedAuth}${hosts.join(",")}${dbPath}${query ? `?${query}` : ""}`;
|
||||
}
|
||||
|
||||
const scheme = type === "postgres" ? "postgresql" : type;
|
||||
const clickHouseProtocol =
|
||||
type === "clickhouse"
|
||||
? normalizeClickHouseProtocolValue(values.clickHouseProtocol)
|
||||
: "auto";
|
||||
const scheme =
|
||||
type === "postgres"
|
||||
? "postgresql"
|
||||
: type === "clickhouse" && clickHouseProtocol === "http"
|
||||
? values.useSSL
|
||||
? "https"
|
||||
: "http"
|
||||
: type;
|
||||
const dbPath = database ? `/${encodeURIComponent(database)}` : "";
|
||||
const params = new URLSearchParams();
|
||||
if (supportsSSLForType(type) && values.useSSL) {
|
||||
@@ -1728,9 +1834,15 @@ const ConnectionModal: React.FC<{
|
||||
mode === "skip-verify" || mode === "preferred" ? "true" : "false",
|
||||
);
|
||||
} else if (type === "clickhouse") {
|
||||
params.set("secure", "true");
|
||||
if (mode === "skip-verify" || mode === "preferred") {
|
||||
params.set("skip_verify", "true");
|
||||
if (clickHouseProtocol === "http") {
|
||||
if (mode === "skip-verify" || mode === "preferred") {
|
||||
params.set("skip_verify", "true");
|
||||
}
|
||||
} else {
|
||||
params.set("secure", "true");
|
||||
if (mode === "skip-verify" || mode === "preferred") {
|
||||
params.set("skip_verify", "true");
|
||||
}
|
||||
}
|
||||
} else if (type === "dameng") {
|
||||
const certPath = String(values.sslCertPath || "").trim();
|
||||
@@ -1761,6 +1873,9 @@ const ConnectionModal: React.FC<{
|
||||
params.set("protocol", "ws");
|
||||
}
|
||||
}
|
||||
if (type === "clickhouse" && clickHouseProtocol !== "auto") {
|
||||
params.set("protocol", clickHouseProtocol);
|
||||
}
|
||||
const query = params.toString();
|
||||
return `${scheme}://${encodedAuth}${toAddress(host, port, defaultPort)}${dbPath}${query ? `?${query}` : ""}`;
|
||||
};
|
||||
@@ -1967,6 +2082,10 @@ const ConnectionModal: React.FC<{
|
||||
password: config.password,
|
||||
database: config.database,
|
||||
uri: config.uri || "",
|
||||
clickHouseProtocol:
|
||||
configType === "clickhouse"
|
||||
? normalizeClickHouseProtocolValue(config.clickHouseProtocol)
|
||||
: "auto",
|
||||
includeDatabases: initialValues.includeDatabases,
|
||||
includeRedisDatabases: initialValues.includeRedisDatabases,
|
||||
useSSL: !!config.useSSL,
|
||||
@@ -2287,10 +2406,14 @@ const ConnectionModal: React.FC<{
|
||||
const values = form.getFieldsValue(true);
|
||||
const unavailableReason = await resolveDriverUnavailableReason(
|
||||
values.type,
|
||||
values.driver,
|
||||
);
|
||||
if (unavailableReason) {
|
||||
message.warning(unavailableReason);
|
||||
promptInstallDriver(values.type, unavailableReason);
|
||||
promptInstallDriver(
|
||||
resolveConnectionDriverType(values.type, values.driver) || values.type,
|
||||
unavailableReason,
|
||||
);
|
||||
return;
|
||||
}
|
||||
setLoading(true);
|
||||
@@ -2445,6 +2568,7 @@ const ConnectionModal: React.FC<{
|
||||
const values = form.getFieldsValue(true);
|
||||
const unavailableReason = await resolveDriverUnavailableReason(
|
||||
values.type,
|
||||
values.driver,
|
||||
);
|
||||
if (unavailableReason) {
|
||||
applyTestFailureFeedback(
|
||||
@@ -2454,7 +2578,10 @@ const ConnectionModal: React.FC<{
|
||||
fallback: "驱动未安装启用",
|
||||
}),
|
||||
);
|
||||
promptInstallDriver(values.type, unavailableReason);
|
||||
promptInstallDriver(
|
||||
resolveConnectionDriverType(values.type, values.driver) || values.type,
|
||||
unavailableReason,
|
||||
);
|
||||
return;
|
||||
}
|
||||
const blockingSecretClearMessage = getBlockingSecretClearMessage(values);
|
||||
@@ -2740,6 +2867,15 @@ const ConnectionModal: React.FC<{
|
||||
(Array.isArray(value) && value.length === 0);
|
||||
if (parsedUriValues) {
|
||||
Object.entries(parsedUriValues).forEach(([key, value]) => {
|
||||
if (
|
||||
key === "clickHouseProtocol" &&
|
||||
normalizeClickHouseProtocolValue((mergedValues as any)[key]) ===
|
||||
"auto" &&
|
||||
normalizeClickHouseProtocolValue(value) !== "auto"
|
||||
) {
|
||||
(mergedValues as any)[key] = value;
|
||||
return;
|
||||
}
|
||||
if (isEmptyField((mergedValues as any)[key])) {
|
||||
(mergedValues as any)[key] = value;
|
||||
}
|
||||
@@ -2748,6 +2884,35 @@ const ConnectionModal: React.FC<{
|
||||
|
||||
const type = String(mergedValues.type || "").toLowerCase();
|
||||
const defaultPort = getDefaultPortByType(type);
|
||||
if (type === "clickhouse") {
|
||||
const requestedProtocol = normalizeClickHouseProtocolValue(
|
||||
mergedValues.clickHouseProtocol,
|
||||
);
|
||||
const hostSchemeValues = parseClickHouseHTTPUriToValues(
|
||||
mergedValues.host,
|
||||
Number(mergedValues.port || defaultPort),
|
||||
);
|
||||
if (hostSchemeValues) {
|
||||
mergedValues.host = hostSchemeValues.host;
|
||||
mergedValues.port = hostSchemeValues.port;
|
||||
if (requestedProtocol !== "native") {
|
||||
mergedValues.clickHouseProtocol = "http";
|
||||
mergedValues.useSSL = hostSchemeValues.useSSL;
|
||||
mergedValues.sslMode = hostSchemeValues.sslMode;
|
||||
} else {
|
||||
mergedValues.clickHouseProtocol = "native";
|
||||
}
|
||||
if (isEmptyField(mergedValues.user)) {
|
||||
mergedValues.user = hostSchemeValues.user;
|
||||
}
|
||||
if (isEmptyField(mergedValues.password)) {
|
||||
mergedValues.password = hostSchemeValues.password;
|
||||
}
|
||||
if (isEmptyField(mergedValues.database)) {
|
||||
mergedValues.database = hostSchemeValues.database;
|
||||
}
|
||||
}
|
||||
}
|
||||
const isFileDbType = isFileDatabaseType(type);
|
||||
const sslCapableType = supportsSSLForType(type);
|
||||
|
||||
@@ -2990,6 +3155,10 @@ const ConnectionModal: React.FC<{
|
||||
? Math.max(0, Math.min(15, Math.trunc(Number(mergedValues.redisDB))))
|
||||
: 0,
|
||||
uri: String(mergedValues.uri || "").trim(),
|
||||
clickHouseProtocol:
|
||||
type === "clickhouse"
|
||||
? normalizeClickHouseProtocolValue(mergedValues.clickHouseProtocol)
|
||||
: undefined,
|
||||
hosts: hosts,
|
||||
topology: topology,
|
||||
mysqlReplicaUser: mysqlReplicaUser,
|
||||
@@ -3017,7 +3186,10 @@ const ConnectionModal: React.FC<{
|
||||
}
|
||||
setTypeSelectWarning(null);
|
||||
setDbType(type);
|
||||
form.setFieldsValue({ type: type });
|
||||
form.setFieldsValue({
|
||||
type: type,
|
||||
clickHouseProtocol: type === "clickhouse" ? "auto" : undefined,
|
||||
});
|
||||
|
||||
const defaultPort = getDefaultPortByType(type);
|
||||
if (type === "jvm") {
|
||||
@@ -3188,17 +3360,27 @@ const ConnectionModal: React.FC<{
|
||||
isJVM && hasUnsupportedJvmModeSelection
|
||||
? "当前连接包含未支持的 JVM 模式。此版本只支持 JMX / Endpoint / Agent,请先调整允许模式和首选模式后再继续。"
|
||||
: "";
|
||||
const currentDriverType = normalizeDriverType(dbType);
|
||||
const currentDriverType = resolveConnectionDriverType(dbType, customDriver);
|
||||
const hasCurrentDriverType =
|
||||
currentDriverType !== "" && currentDriverType !== "custom";
|
||||
const currentDriverSnapshot = driverStatusMap[currentDriverType];
|
||||
const currentDriverUnavailableReason =
|
||||
currentDriverType !== "custom" &&
|
||||
hasCurrentDriverType &&
|
||||
currentDriverSnapshot &&
|
||||
!currentDriverSnapshot.connectable
|
||||
? currentDriverSnapshot.message ||
|
||||
`${currentDriverSnapshot.name || dbType} 驱动未安装启用`
|
||||
: "";
|
||||
const currentDriverUpdateReason =
|
||||
hasCurrentDriverType &&
|
||||
currentDriverSnapshot?.connectable &&
|
||||
currentDriverSnapshot.needsUpdate
|
||||
? currentDriverSnapshot.message ||
|
||||
currentDriverSnapshot.updateReason ||
|
||||
`${currentDriverSnapshot.name || dbType} 驱动代理需要重装后才能应用当前版本的驱动侧更新`
|
||||
: "";
|
||||
const driverStatusChecking =
|
||||
currentDriverType !== "custom" && !driverStatusLoaded && step === 2;
|
||||
hasCurrentDriverType && !driverStatusLoaded && step === 2;
|
||||
|
||||
const dbTypeGroups = [
|
||||
{
|
||||
@@ -4294,6 +4476,25 @@ const ConnectionModal: React.FC<{
|
||||
),
|
||||
})}
|
||||
|
||||
{dbType === "clickhouse" &&
|
||||
renderConfigSectionCard({
|
||||
sectionKey: "connectionMode",
|
||||
icon: <ClusterOutlined />,
|
||||
children: (
|
||||
<Form.Item
|
||||
name="clickHouseProtocol"
|
||||
label="连接协议"
|
||||
help="自动模式按 URI scheme 和常见端口判断;非标 HTTP/Native 端口可手动指定。"
|
||||
style={{ marginBottom: 0 }}
|
||||
>
|
||||
<Select
|
||||
options={CLICKHOUSE_PROTOCOL_OPTIONS}
|
||||
onChange={() => clearConnectionTestResultForChoice()}
|
||||
/>
|
||||
</Form.Item>
|
||||
),
|
||||
})}
|
||||
|
||||
{(dbType === "postgres" ||
|
||||
dbType === "kingbase" ||
|
||||
dbType === "highgo" ||
|
||||
@@ -5898,6 +6099,26 @@ const ConnectionModal: React.FC<{
|
||||
}
|
||||
/>
|
||||
)}
|
||||
{currentDriverUpdateReason && (
|
||||
<Alert
|
||||
showIcon
|
||||
type="warning"
|
||||
style={{ marginBottom: 12 }}
|
||||
message="当前数据源驱动代理建议重装"
|
||||
description={
|
||||
<Space size={8}>
|
||||
<span>{currentDriverUpdateReason}</span>
|
||||
<Button
|
||||
type="link"
|
||||
size="small"
|
||||
onClick={() => onOpenDriverManager?.()}
|
||||
>
|
||||
去驱动管理重装
|
||||
</Button>
|
||||
</Space>
|
||||
}
|
||||
/>
|
||||
)}
|
||||
{(() => {
|
||||
const sectionItems: Array<{
|
||||
key: "basic" | "network" | "appearance";
|
||||
|
||||
@@ -2,7 +2,8 @@ import React from 'react';
|
||||
import { act, create, type ReactTestRenderer } from 'react-test-renderer';
|
||||
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest';
|
||||
|
||||
import DataGrid from './DataGrid';
|
||||
import DataGrid, { buildDataGridCommitChangeSet, GONAVI_ROW_KEY } from './DataGrid';
|
||||
import { ORACLE_ROWID_LOCATOR_COLUMN } from '../utils/rowLocator';
|
||||
|
||||
const storeState = vi.hoisted(() => ({
|
||||
connections: [
|
||||
@@ -216,6 +217,157 @@ const waitForEffects = async () => {
|
||||
});
|
||||
};
|
||||
|
||||
const normalizeValue = (_columnName: string, value: any) => value;
|
||||
const rowKeyToString = (key: any) => String(key);
|
||||
|
||||
const commitColumnGuard = (columnName: string) => (
|
||||
columnName !== GONAVI_ROW_KEY && columnName !== ORACLE_ROWID_LOCATOR_COLUMN
|
||||
);
|
||||
|
||||
describe('DataGrid commit change set', () => {
|
||||
it('uses unique locator values instead of falling back to the whole row', () => {
|
||||
const result = buildDataGridCommitChangeSet({
|
||||
addedRows: [],
|
||||
modifiedRows: {
|
||||
'row-1': { [GONAVI_ROW_KEY]: 'row-1', EMAIL: 'a@example.com', NAME: 'new-name', AGE: 42 },
|
||||
},
|
||||
deletedRowKeys: new Set(),
|
||||
data: [{ [GONAVI_ROW_KEY]: 'row-1', EMAIL: 'a@example.com', NAME: 'old-name', AGE: 42 }],
|
||||
editLocator: {
|
||||
strategy: 'unique-key',
|
||||
columns: ['EMAIL'],
|
||||
valueColumns: ['EMAIL'],
|
||||
readOnly: false,
|
||||
},
|
||||
visibleColumnNames: ['EMAIL', 'NAME', 'AGE'],
|
||||
rowKeyToString,
|
||||
normalizeCommitCellValue: normalizeValue,
|
||||
shouldCommitColumn: commitColumnGuard,
|
||||
});
|
||||
|
||||
expect(result).toEqual({
|
||||
ok: true,
|
||||
changes: {
|
||||
inserts: [],
|
||||
updates: [{ keys: { EMAIL: 'a@example.com' }, values: { NAME: 'new-name' } }],
|
||||
deletes: [],
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
it('uses hidden Oracle ROWID only as locator and excludes it from update values', () => {
|
||||
const result = buildDataGridCommitChangeSet({
|
||||
addedRows: [],
|
||||
modifiedRows: {
|
||||
'row-1': { [GONAVI_ROW_KEY]: 'row-1', NAME: 'new-name', [ORACLE_ROWID_LOCATOR_COLUMN]: 'BBBB' },
|
||||
},
|
||||
deletedRowKeys: new Set(),
|
||||
data: [{ [GONAVI_ROW_KEY]: 'row-1', NAME: 'old-name', [ORACLE_ROWID_LOCATOR_COLUMN]: 'AAAA' }],
|
||||
editLocator: {
|
||||
strategy: 'oracle-rowid',
|
||||
columns: ['ROWID'],
|
||||
valueColumns: [ORACLE_ROWID_LOCATOR_COLUMN],
|
||||
hiddenColumns: [ORACLE_ROWID_LOCATOR_COLUMN],
|
||||
readOnly: false,
|
||||
},
|
||||
visibleColumnNames: ['NAME'],
|
||||
rowKeyToString,
|
||||
normalizeCommitCellValue: normalizeValue,
|
||||
shouldCommitColumn: commitColumnGuard,
|
||||
});
|
||||
|
||||
expect(result).toEqual({
|
||||
ok: true,
|
||||
changes: {
|
||||
inserts: [],
|
||||
updates: [{ keys: { ROWID: 'AAAA' }, values: { NAME: 'new-name' } }],
|
||||
deletes: [],
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
it('commits only writable result columns and maps aliases back to table columns', () => {
|
||||
const result = buildDataGridCommitChangeSet({
|
||||
addedRows: [],
|
||||
modifiedRows: {
|
||||
'row-1': {
|
||||
[GONAVI_ROW_KEY]: 'row-1',
|
||||
DISPLAY_NAME: 'new-name',
|
||||
NAME_UPPER: 'NEW-NAME',
|
||||
},
|
||||
},
|
||||
deletedRowKeys: new Set(),
|
||||
data: [{
|
||||
[GONAVI_ROW_KEY]: 'row-1',
|
||||
ID: 7,
|
||||
DISPLAY_NAME: 'old-name',
|
||||
NAME_UPPER: 'OLD-NAME',
|
||||
}],
|
||||
editLocator: {
|
||||
strategy: 'primary-key',
|
||||
columns: ['ID'],
|
||||
valueColumns: ['ID'],
|
||||
writableColumns: {
|
||||
DISPLAY_NAME: 'NAME',
|
||||
},
|
||||
readOnly: false,
|
||||
},
|
||||
visibleColumnNames: ['DISPLAY_NAME', 'NAME_UPPER'],
|
||||
rowKeyToString,
|
||||
normalizeCommitCellValue: normalizeValue,
|
||||
shouldCommitColumn: commitColumnGuard,
|
||||
});
|
||||
|
||||
expect(result).toEqual({
|
||||
ok: true,
|
||||
changes: {
|
||||
inserts: [],
|
||||
updates: [{ keys: { ID: 7 }, values: { NAME: 'new-name' } }],
|
||||
deletes: [],
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
it('fails closed when no safe locator is available', () => {
|
||||
const result = buildDataGridCommitChangeSet({
|
||||
addedRows: [],
|
||||
modifiedRows: {
|
||||
'row-1': { [GONAVI_ROW_KEY]: 'row-1', NAME: 'new-name' },
|
||||
},
|
||||
deletedRowKeys: new Set(),
|
||||
data: [{ [GONAVI_ROW_KEY]: 'row-1', NAME: 'old-name' }],
|
||||
editLocator: undefined,
|
||||
visibleColumnNames: ['NAME'],
|
||||
rowKeyToString,
|
||||
normalizeCommitCellValue: normalizeValue,
|
||||
shouldCommitColumn: commitColumnGuard,
|
||||
});
|
||||
|
||||
expect(result).toEqual({ ok: false, error: '当前结果没有可用的安全行定位方式,无法提交修改。' });
|
||||
});
|
||||
|
||||
it('rejects delete rows when unique locator value is null', () => {
|
||||
const result = buildDataGridCommitChangeSet({
|
||||
addedRows: [],
|
||||
modifiedRows: {},
|
||||
deletedRowKeys: new Set(['row-1']),
|
||||
data: [{ [GONAVI_ROW_KEY]: 'row-1', EMAIL: null, NAME: 'old-name' }],
|
||||
editLocator: {
|
||||
strategy: 'unique-key',
|
||||
columns: ['EMAIL'],
|
||||
valueColumns: ['EMAIL'],
|
||||
readOnly: false,
|
||||
},
|
||||
visibleColumnNames: ['EMAIL', 'NAME'],
|
||||
rowKeyToString,
|
||||
normalizeCommitCellValue: normalizeValue,
|
||||
shouldCommitColumn: commitColumnGuard,
|
||||
});
|
||||
|
||||
expect(result).toEqual({ ok: false, error: '定位列 EMAIL 的值为空,无法安全提交修改。' });
|
||||
});
|
||||
});
|
||||
|
||||
describe('DataGrid DDL interactions', () => {
|
||||
beforeEach(() => {
|
||||
backendApp.DBGetColumns.mockResolvedValue({ success: true, data: [] });
|
||||
|
||||
@@ -159,6 +159,7 @@ describe('DataGrid layout', () => {
|
||||
columnNames={['id', 'name']}
|
||||
loading={false}
|
||||
tableName="users"
|
||||
pkColumns={['id']}
|
||||
/>,
|
||||
);
|
||||
|
||||
|
||||
@@ -79,6 +79,13 @@ import {
|
||||
type DataGridFindMatch,
|
||||
type DataGridFindNavigationDirection,
|
||||
} from '../utils/dataGridFind';
|
||||
import {
|
||||
filterHiddenLocatorColumns,
|
||||
isWritableResultColumn,
|
||||
resolveWritableColumnName,
|
||||
resolveRowLocatorValues,
|
||||
type EditRowLocator,
|
||||
} from '../utils/rowLocator';
|
||||
|
||||
// --- Error Boundary ---
|
||||
interface DataGridErrorBoundaryState {
|
||||
@@ -916,6 +923,7 @@ interface DataGridProps {
|
||||
dbName?: string;
|
||||
connectionId?: string;
|
||||
pkColumns?: string[];
|
||||
editLocator?: EditRowLocator;
|
||||
readOnly?: boolean;
|
||||
onReload?: () => void;
|
||||
onSort?: (field: string, order: string) => void;
|
||||
@@ -960,12 +968,112 @@ type ColumnMeta = {
|
||||
comment: string;
|
||||
};
|
||||
|
||||
type NormalizeCommitCellValue = (columnName: string, value: any, mode: 'insert' | 'update') => any;
|
||||
|
||||
type DataGridCommitChangeSet = {
|
||||
inserts: any[];
|
||||
updates: any[];
|
||||
deletes: any[];
|
||||
};
|
||||
|
||||
export const buildDataGridCommitChangeSet = ({
|
||||
addedRows,
|
||||
modifiedRows,
|
||||
deletedRowKeys,
|
||||
data,
|
||||
editLocator,
|
||||
visibleColumnNames,
|
||||
rowKeyToString,
|
||||
normalizeCommitCellValue,
|
||||
shouldCommitColumn,
|
||||
}: {
|
||||
addedRows: any[];
|
||||
modifiedRows: Record<string, any>;
|
||||
deletedRowKeys: Set<string>;
|
||||
data: any[];
|
||||
editLocator?: EditRowLocator;
|
||||
visibleColumnNames: string[];
|
||||
rowKeyToString: (key: any) => string;
|
||||
normalizeCommitCellValue: NormalizeCommitCellValue;
|
||||
shouldCommitColumn: (columnName: string) => boolean;
|
||||
}): { ok: true; changes: DataGridCommitChangeSet } | { ok: false; error: string } => {
|
||||
if (!editLocator || editLocator.readOnly || editLocator.strategy === 'none') {
|
||||
return { ok: false, error: editLocator?.reason || '当前结果没有可用的安全行定位方式,无法提交修改。' };
|
||||
}
|
||||
|
||||
const normalizeValues = (values: Record<string, any>, mode: 'insert' | 'update') => {
|
||||
const normalizedValues: Record<string, any> = {};
|
||||
Object.entries(values).forEach(([col, val]) => {
|
||||
if (!shouldCommitColumn(col)) return;
|
||||
const commitColumnName = resolveWritableColumnName(col, editLocator);
|
||||
if (!commitColumnName) return;
|
||||
const normalizedVal = normalizeCommitCellValue(col, val, mode);
|
||||
if (normalizedVal !== undefined) {
|
||||
normalizedValues[commitColumnName] = normalizedVal;
|
||||
}
|
||||
});
|
||||
return normalizedValues;
|
||||
};
|
||||
|
||||
const originalRowsByKey = new Map<string, any>();
|
||||
data.forEach((row) => {
|
||||
const key = row?.[GONAVI_ROW_KEY];
|
||||
if (key === undefined || key === null) return;
|
||||
originalRowsByKey.set(rowKeyToString(key), row);
|
||||
});
|
||||
|
||||
const inserts: any[] = [];
|
||||
const updates: any[] = [];
|
||||
const deletes: any[] = [];
|
||||
|
||||
addedRows.forEach(row => {
|
||||
const key = row?.[GONAVI_ROW_KEY];
|
||||
if (key !== undefined && key !== null && deletedRowKeys.has(rowKeyToString(key))) return;
|
||||
inserts.push(normalizeValues(row, 'insert'));
|
||||
});
|
||||
|
||||
for (const keyStr of deletedRowKeys) {
|
||||
const originalRow = originalRowsByKey.get(keyStr);
|
||||
if (!originalRow) continue;
|
||||
const locatorValues = resolveRowLocatorValues(editLocator, originalRow);
|
||||
if (!locatorValues.ok) return { ok: false, error: locatorValues.error };
|
||||
deletes.push(locatorValues.values);
|
||||
}
|
||||
|
||||
for (const [keyStr, newRow] of Object.entries(modifiedRows)) {
|
||||
if (deletedRowKeys.has(keyStr)) continue;
|
||||
const originalRow = originalRowsByKey.get(keyStr);
|
||||
if (!originalRow) continue;
|
||||
|
||||
const locatorValues = resolveRowLocatorValues(editLocator, originalRow);
|
||||
if (!locatorValues.ok) return { ok: false, error: locatorValues.error };
|
||||
|
||||
const hasRowKey = Object.prototype.hasOwnProperty.call(newRow as any, GONAVI_ROW_KEY);
|
||||
let values: Record<string, any> = {};
|
||||
if (!hasRowKey) {
|
||||
values = { ...(newRow as any) };
|
||||
} else {
|
||||
visibleColumnNames.forEach((col) => {
|
||||
const nextVal = (newRow as any)?.[col];
|
||||
const prevVal = (originalRow as any)?.[col];
|
||||
if (!isCellValueEqualForDiff(prevVal, nextVal)) values[col] = nextVal;
|
||||
});
|
||||
}
|
||||
|
||||
const normalizedValues = normalizeValues(values, 'update');
|
||||
if (Object.keys(normalizedValues).length === 0) continue;
|
||||
updates.push({ keys: locatorValues.values, values: normalizedValues });
|
||||
}
|
||||
|
||||
return { ok: true, changes: { inserts, updates, deletes } };
|
||||
};
|
||||
|
||||
// P2 性能优化:提取内联 style 对象为模块级常量,避免每次 render 创建新对象
|
||||
const CELL_ELLIPSIS_STYLE: React.CSSProperties = { overflow: 'hidden', textOverflow: 'ellipsis', whiteSpace: 'nowrap' };
|
||||
const VIRTUAL_CELL_WRAPPER_STYLE: React.CSSProperties = { margin: -8, padding: '8px 8px 8px 8px' };
|
||||
|
||||
const DataGrid: React.FC<DataGridProps> = ({
|
||||
data, columnNames, loading, tableName, exportScope = 'table', resultSql, dbName, connectionId, pkColumns = [], readOnly = false,
|
||||
data, columnNames, loading, tableName, exportScope = 'table', resultSql, dbName, connectionId, pkColumns = [], editLocator, readOnly = false,
|
||||
onReload, onSort, onPageChange, pagination, onRequestTotalCount, onCancelTotalCount, sortInfoExternal, showFilter, onToggleFilter, exportSqlWithFilter, onApplyFilter, appliedFilterConditions, quickWhereCondition,
|
||||
onApplyQuickWhereCondition,
|
||||
scrollSnapshot, onScrollSnapshotChange
|
||||
@@ -999,7 +1107,25 @@ const DataGrid: React.FC<DataGridProps> = ({
|
||||
darkMode,
|
||||
visible: showDataTableVerticalBorders,
|
||||
});
|
||||
const canModifyData = !readOnly && !!tableName;
|
||||
const effectiveEditLocator = useMemo<EditRowLocator | undefined>(() => {
|
||||
if (editLocator) return editLocator;
|
||||
if (pkColumns.length === 0) return undefined;
|
||||
return {
|
||||
strategy: 'primary-key',
|
||||
columns: pkColumns,
|
||||
valueColumns: pkColumns,
|
||||
readOnly: false,
|
||||
};
|
||||
}, [editLocator, pkColumns]);
|
||||
const visibleColumnNames = useMemo(
|
||||
() => filterHiddenLocatorColumns(columnNames, effectiveEditLocator),
|
||||
[columnNames, effectiveEditLocator]
|
||||
);
|
||||
const shouldCommitColumn = useCallback((columnName: string): boolean => {
|
||||
const normalized = String(columnName || '').trim();
|
||||
return normalized !== GONAVI_ROW_KEY && isWritableResultColumn(normalized, effectiveEditLocator);
|
||||
}, [effectiveEditLocator]);
|
||||
const canModifyData = !readOnly && !!tableName && !!effectiveEditLocator && !effectiveEditLocator.readOnly && effectiveEditLocator.strategy !== 'none';
|
||||
const showColumnComment = queryOptions?.showColumnComment ?? true;
|
||||
const showColumnType = queryOptions?.showColumnType ?? true;
|
||||
|
||||
@@ -1053,7 +1179,7 @@ const DataGrid: React.FC<DataGridProps> = ({
|
||||
|
||||
// Sync display order from incoming prop and store memory
|
||||
useEffect(() => {
|
||||
let nextOrder = [...columnNames];
|
||||
let nextOrder = [...visibleColumnNames];
|
||||
if (enableColumnOrderMemory && connectionId && dbName && tableName) {
|
||||
const storedOrder = tableColumnOrders[`${connectionId}-${dbName}-${tableName}`];
|
||||
if (Array.isArray(storedOrder) && storedOrder.length > 0) {
|
||||
@@ -1066,7 +1192,7 @@ const DataGrid: React.FC<DataGridProps> = ({
|
||||
}
|
||||
}
|
||||
setAllOrderedColumnNames(nextOrder);
|
||||
}, [columnNames, tableColumnOrders, enableColumnOrderMemory, connectionId, dbName, tableName]);
|
||||
}, [visibleColumnNames, tableColumnOrders, enableColumnOrderMemory, connectionId, dbName, tableName]);
|
||||
|
||||
// Compute final display columns
|
||||
useEffect(() => {
|
||||
@@ -1378,7 +1504,13 @@ const DataGrid: React.FC<DataGridProps> = ({
|
||||
const exportData = async (rows: any[], format: string) => {
|
||||
const hide = message.loading(`正在导出 ${rows.length} 条数据...`, 0);
|
||||
try {
|
||||
const cleanRows = rows.map(({ [GONAVI_ROW_KEY]: _rowKey, ...rest }) => rest);
|
||||
const cleanRows = rows.map((row) => {
|
||||
const next: Record<string, any> = {};
|
||||
displayColumnNames.forEach((columnName) => {
|
||||
next[columnName] = row?.[columnName];
|
||||
});
|
||||
return next;
|
||||
});
|
||||
// Pass tableName (or 'export') as default filename
|
||||
const res = await ExportData(cleanRows, displayColumnNames, tableName || 'export', format);
|
||||
if (res.success) {
|
||||
@@ -1538,10 +1670,10 @@ const DataGrid: React.FC<DataGridProps> = ({
|
||||
return metaColumns;
|
||||
}
|
||||
if (exportScope === 'table') {
|
||||
return columnNames.filter((columnName) => columnName !== GONAVI_ROW_KEY);
|
||||
return visibleColumnNames.filter((columnName) => columnName !== GONAVI_ROW_KEY);
|
||||
}
|
||||
return [];
|
||||
}, [columnMetaMap, exportScope, columnNames]);
|
||||
}, [columnMetaMap, exportScope, visibleColumnNames]);
|
||||
|
||||
const normalizeCommitCellValue = useCallback(
|
||||
(columnName: string, value: any, mode: 'insert' | 'update') => {
|
||||
@@ -3298,19 +3430,25 @@ const DataGrid: React.FC<DataGridProps> = ({
|
||||
const jsonViewText = useMemo(() => {
|
||||
if (viewMode !== 'json') return '';
|
||||
const cleanRows = mergedDisplayData.map((row) => {
|
||||
const { [GONAVI_ROW_KEY]: _rowKey, ...rest } = row || {};
|
||||
return normalizeValueForJsonView(rest);
|
||||
const next: Record<string, any> = {};
|
||||
visibleColumnNames.forEach((columnName) => {
|
||||
next[columnName] = row?.[columnName];
|
||||
});
|
||||
return normalizeValueForJsonView(next);
|
||||
});
|
||||
return JSON.stringify(cleanRows, null, 2);
|
||||
}, [viewMode, mergedDisplayData]);
|
||||
}, [viewMode, mergedDisplayData, visibleColumnNames]);
|
||||
|
||||
const textViewRows = useMemo(() => {
|
||||
if (viewMode !== 'text') return [];
|
||||
return mergedDisplayData.map((row) => {
|
||||
const { [GONAVI_ROW_KEY]: _rowKey, ...rest } = row || {};
|
||||
return rest;
|
||||
const next: Record<string, any> = {};
|
||||
visibleColumnNames.forEach((columnName) => {
|
||||
next[columnName] = row?.[columnName];
|
||||
});
|
||||
return next;
|
||||
});
|
||||
}, [viewMode, mergedDisplayData]);
|
||||
}, [viewMode, mergedDisplayData, visibleColumnNames]);
|
||||
|
||||
const currentTextRow = useMemo(() => {
|
||||
if (viewMode !== 'text') return null;
|
||||
@@ -3363,7 +3501,7 @@ const DataGrid: React.FC<DataGridProps> = ({
|
||||
const formMap: Record<string, any> = {};
|
||||
const nullCols = new Set<string>();
|
||||
|
||||
columnNames.forEach((col) => {
|
||||
visibleColumnNames.forEach((col) => {
|
||||
const baseVal = (baseRow as any)?.[col];
|
||||
const displayVal = (displayRow as any)?.[col];
|
||||
baseRawMap[col] = baseVal;
|
||||
@@ -3511,7 +3649,7 @@ const DataGrid: React.FC<DataGridProps> = ({
|
||||
const keyStr = rowKeyStr(rowKey);
|
||||
const normalizedNext: Record<string, any> = {};
|
||||
let hasAnyVisibleChange = false;
|
||||
columnNames.forEach((col) => {
|
||||
visibleColumnNames.forEach((col) => {
|
||||
const currentVal = (currentRow as any)?.[col];
|
||||
const editedVal = Object.prototype.hasOwnProperty.call(nextItem, col) ? (nextItem as any)[col] : currentVal;
|
||||
if (!isJsonViewValueEqual(currentVal, editedVal)) hasAnyVisibleChange = true;
|
||||
@@ -3530,7 +3668,7 @@ const DataGrid: React.FC<DataGridProps> = ({
|
||||
const originalRow = originalMap.get(keyStr);
|
||||
if (!originalRow) continue;
|
||||
const patch: Record<string, any> = {};
|
||||
columnNames.forEach((col) => {
|
||||
visibleColumnNames.forEach((col) => {
|
||||
const prevVal = (originalRow as any)?.[col];
|
||||
const nextVal = normalizedNext[col];
|
||||
if (!isCellValueEqualForDiff(prevVal, nextVal)) patch[col] = nextVal;
|
||||
@@ -3595,7 +3733,7 @@ const DataGrid: React.FC<DataGridProps> = ({
|
||||
|
||||
const baseRawMap = rowEditorBaseRawRef.current || {};
|
||||
const patch: Record<string, any> = {};
|
||||
columnNames.forEach((col) => {
|
||||
visibleColumnNames.forEach((col) => {
|
||||
let nextVal = values[col];
|
||||
// 日期时间类型: 将 dayjs 对象转回格式化字符串
|
||||
if (nextVal && dayjs.isDayjs(nextVal)) {
|
||||
@@ -3615,7 +3753,7 @@ const DataGrid: React.FC<DataGridProps> = ({
|
||||
});
|
||||
|
||||
closeRowEditor();
|
||||
}, [rowEditorRowKey, rowEditorForm, addedRows, columnNames, rowKeyStr, closeRowEditor]);
|
||||
}, [rowEditorRowKey, rowEditorForm, addedRows, visibleColumnNames, rowKeyStr, closeRowEditor]);
|
||||
|
||||
|
||||
const enableVirtual = viewMode === 'table';
|
||||
@@ -3633,7 +3771,7 @@ const DataGrid: React.FC<DataGridProps> = ({
|
||||
}),
|
||||
sorter: onSort ? { multiple: displayColumnNames.indexOf(key) + 1 } : false,
|
||||
sortOrder: (sortInfo.find(s => s.columnKey === key && s.enabled !== false)?.order || null) as SortOrder | undefined,
|
||||
editable: canModifyData, // Only editable if table name known and not readonly
|
||||
editable: canModifyData && isWritableResultColumn(key, effectiveEditLocator),
|
||||
render: (text: any) => (
|
||||
<div style={CELL_ELLIPSIS_STYLE}>
|
||||
{renderCellDisplayValue(text, normalizedPageFindText)}
|
||||
@@ -3761,7 +3899,7 @@ const DataGrid: React.FC<DataGridProps> = ({
|
||||
const handleAddRow = () => {
|
||||
const newKey = `new-${Date.now()}`;
|
||||
const newRow: any = { [GONAVI_ROW_KEY]: newKey };
|
||||
columnNames.forEach(col => newRow[col] = '');
|
||||
visibleColumnNames.forEach(col => newRow[col] = '');
|
||||
pendingScrollToBottomRef.current = true;
|
||||
setAddedRows(prev => [...prev, newRow]);
|
||||
};
|
||||
@@ -3775,7 +3913,7 @@ const DataGrid: React.FC<DataGridProps> = ({
|
||||
const copiedRows = buildCopiedRowsForPaste({
|
||||
rows: mergedDisplayData as Array<Record<string, any>>,
|
||||
selectedRowKeys,
|
||||
columnNames,
|
||||
columnNames: visibleColumnNames,
|
||||
rowKeyField: GONAVI_ROW_KEY,
|
||||
rowKeyToString: rowKeyStr,
|
||||
});
|
||||
@@ -3786,7 +3924,7 @@ const DataGrid: React.FC<DataGridProps> = ({
|
||||
|
||||
setCopiedRowsForPaste(copiedRows);
|
||||
void message.success(`已复制 ${copiedRows.length} 行,可粘贴为新增行`);
|
||||
}, [selectedRowKeys, mergedDisplayData, columnNames, rowKeyStr]);
|
||||
}, [selectedRowKeys, mergedDisplayData, visibleColumnNames, rowKeyStr]);
|
||||
|
||||
const handlePasteCopiedRowsAsNew = useCallback(() => {
|
||||
if (copiedRowsForPaste.length === 0) {
|
||||
@@ -3796,7 +3934,7 @@ const DataGrid: React.FC<DataGridProps> = ({
|
||||
|
||||
const nextRows = buildPastedRowsFromCopiedRows({
|
||||
rows: copiedRowsForPaste,
|
||||
columnNames,
|
||||
columnNames: visibleColumnNames,
|
||||
rowKeyField: GONAVI_ROW_KEY,
|
||||
createRowKey: (index) => {
|
||||
pastedRowSequenceRef.current += 1;
|
||||
@@ -3812,7 +3950,7 @@ const DataGrid: React.FC<DataGridProps> = ({
|
||||
setAddedRows(prev => [...prev, ...nextRows]);
|
||||
setSelectedRowKeys(nextRows.map(row => row[GONAVI_ROW_KEY]));
|
||||
void message.success(`已粘贴 ${nextRows.length} 行为新增行,请检查后提交事务`);
|
||||
}, [copiedRowsForPaste, columnNames]);
|
||||
}, [copiedRowsForPaste, visibleColumnNames]);
|
||||
|
||||
const handleDeleteSelected = () => {
|
||||
setDeletedRowKeys(prev => {
|
||||
@@ -3827,66 +3965,23 @@ const DataGrid: React.FC<DataGridProps> = ({
|
||||
if (!connectionId || !tableName) return;
|
||||
const conn = connections.find(c => c.id === connectionId);
|
||||
if (!conn) return;
|
||||
|
||||
const inserts: any[] = [];
|
||||
const updates: any[] = [];
|
||||
const deletes: any[] = [];
|
||||
|
||||
addedRows.forEach(row => {
|
||||
const { [GONAVI_ROW_KEY]: _rowKey, ...vals } = row;
|
||||
const normalizedValues: Record<string, any> = {};
|
||||
Object.entries(vals).forEach(([col, val]) => {
|
||||
const normalizedVal = normalizeCommitCellValue(col, val, 'insert');
|
||||
if (normalizedVal !== undefined) {
|
||||
normalizedValues[col] = normalizedVal;
|
||||
}
|
||||
});
|
||||
inserts.push(normalizedValues);
|
||||
});
|
||||
deletedRowKeys.forEach(keyStr => {
|
||||
// Find original data
|
||||
const originalRow = data.find(d => rowKeyStr(d?.[GONAVI_ROW_KEY]) === keyStr) || addedRows.find(d => rowKeyStr(d?.[GONAVI_ROW_KEY]) === keyStr);
|
||||
if (originalRow) {
|
||||
const pkData: any = {};
|
||||
if (pkColumns.length > 0) pkColumns.forEach(k => pkData[k] = originalRow[k]);
|
||||
else { const { [GONAVI_ROW_KEY]: _rowKey, ...rest } = originalRow; Object.assign(pkData, rest); }
|
||||
deletes.push(pkData);
|
||||
}
|
||||
});
|
||||
Object.entries(modifiedRows).forEach(([keyStr, newRow]) => {
|
||||
if (deletedRowKeys.has(keyStr)) return;
|
||||
const originalRow = data.find(d => rowKeyStr(d?.[GONAVI_ROW_KEY]) === keyStr);
|
||||
if (!originalRow) return; // Should not happen for modified rows unless deleted
|
||||
|
||||
const pkData: any = {};
|
||||
if (pkColumns.length > 0) pkColumns.forEach(k => pkData[k] = originalRow[k]);
|
||||
else { const { [GONAVI_ROW_KEY]: _rowKey, ...rest } = originalRow; Object.assign(pkData, rest); }
|
||||
|
||||
const hasRowKey = Object.prototype.hasOwnProperty.call(newRow as any, GONAVI_ROW_KEY);
|
||||
let values: any = {};
|
||||
|
||||
if (!hasRowKey) {
|
||||
values = { ...(newRow as any) };
|
||||
} else {
|
||||
columnNames.forEach((col) => {
|
||||
const nextVal = (newRow as any)?.[col];
|
||||
const prevVal = (originalRow as any)?.[col];
|
||||
if (!isCellValueEqualForDiff(prevVal, nextVal)) values[col] = nextVal;
|
||||
});
|
||||
}
|
||||
|
||||
const normalizedValues: Record<string, any> = {};
|
||||
Object.entries(values).forEach(([col, val]) => {
|
||||
const normalizedVal = normalizeCommitCellValue(col, val, 'update');
|
||||
if (normalizedVal !== undefined) {
|
||||
normalizedValues[col] = normalizedVal;
|
||||
}
|
||||
});
|
||||
|
||||
if (Object.keys(normalizedValues).length === 0) return;
|
||||
updates.push({ keys: pkData, values: normalizedValues });
|
||||
const changeSetResult = buildDataGridCommitChangeSet({
|
||||
addedRows,
|
||||
modifiedRows,
|
||||
deletedRowKeys,
|
||||
data,
|
||||
editLocator: effectiveEditLocator,
|
||||
visibleColumnNames,
|
||||
rowKeyToString: rowKeyStr,
|
||||
normalizeCommitCellValue,
|
||||
shouldCommitColumn,
|
||||
});
|
||||
if (!changeSetResult.ok) {
|
||||
void message.error(changeSetResult.error);
|
||||
return;
|
||||
}
|
||||
|
||||
const { inserts, updates, deletes } = changeSetResult.changes;
|
||||
if (inserts.length === 0 && updates.length === 0 && deletes.length === 0) {
|
||||
void message.info("没有可提交的变更");
|
||||
return;
|
||||
@@ -3902,7 +3997,7 @@ const DataGrid: React.FC<DataGridProps> = ({
|
||||
};
|
||||
|
||||
const startTime = Date.now();
|
||||
const res = await ApplyChanges(buildRpcConnectionConfig(config) as any, dbName || '', tableName, { inserts, updates, deletes } as any);
|
||||
const res = await ApplyChanges(buildRpcConnectionConfig(config) as any, dbName || '', tableName, { inserts, updates, deletes, locatorStrategy: effectiveEditLocator?.strategy } as any);
|
||||
const duration = Date.now() - startTime;
|
||||
|
||||
// Construct a pseudo-SQL representation for the log
|
||||
@@ -4051,7 +4146,7 @@ const DataGrid: React.FC<DataGridProps> = ({
|
||||
return null;
|
||||
}
|
||||
const records = getTargets(record);
|
||||
const orderedCols = columnNames.filter(c => c !== GONAVI_ROW_KEY);
|
||||
const orderedCols = visibleColumnNames.filter(c => c !== GONAVI_ROW_KEY);
|
||||
if (mode === 'insert') {
|
||||
return records.map((row: any) => buildCopyInsertSQL({
|
||||
dbType,
|
||||
@@ -4100,7 +4195,7 @@ const DataGrid: React.FC<DataGridProps> = ({
|
||||
}, [
|
||||
supportsCopyInsert,
|
||||
getTargets,
|
||||
columnNames,
|
||||
visibleColumnNames,
|
||||
dbType,
|
||||
tableName,
|
||||
columnTypeMapByLowerName,
|
||||
@@ -4130,16 +4225,18 @@ const DataGrid: React.FC<DataGridProps> = ({
|
||||
const handleCopyJson = useCallback((record: any) => {
|
||||
const records = getTargets(record);
|
||||
const cleanRecords = records.map((r: any) => {
|
||||
const { [GONAVI_ROW_KEY]: _rowKey, ...rest } = r;
|
||||
return rest;
|
||||
const next: Record<string, any> = {};
|
||||
visibleColumnNames.forEach((columnName) => {
|
||||
next[columnName] = r?.[columnName];
|
||||
});
|
||||
return next;
|
||||
});
|
||||
copyToClipboard(JSON.stringify(cleanRecords, null, 2));
|
||||
}, [getTargets, copyToClipboard]);
|
||||
}, [getTargets, visibleColumnNames, copyToClipboard]);
|
||||
|
||||
const handleCopyCsv = useCallback((record: any) => {
|
||||
const records = getTargets(record);
|
||||
// 使用 columnNames 保持表定义的字段顺序
|
||||
const orderedCols = columnNames.filter(c => c !== GONAVI_ROW_KEY);
|
||||
const orderedCols = visibleColumnNames.filter(c => c !== GONAVI_ROW_KEY);
|
||||
const header = orderedCols.map(c => `"${c}"`).join(',');
|
||||
const lines = records.map((r: any) => {
|
||||
const values = orderedCols.map(c => {
|
||||
@@ -4152,7 +4249,7 @@ const DataGrid: React.FC<DataGridProps> = ({
|
||||
return values.join(',');
|
||||
});
|
||||
copyToClipboard([header, ...lines].join('\n'));
|
||||
}, [getTargets, columnNames, copyToClipboard]);
|
||||
}, [getTargets, visibleColumnNames, copyToClipboard]);
|
||||
|
||||
const buildConnConfig = useCallback(() => {
|
||||
if (!connectionId) return null;
|
||||
|
||||
199
frontend/src/components/DataViewer.primary-key.test.tsx
Normal file
199
frontend/src/components/DataViewer.primary-key.test.tsx
Normal file
@@ -0,0 +1,199 @@
|
||||
import React from 'react';
|
||||
import { act, create, type ReactTestRenderer } from 'react-test-renderer';
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest';
|
||||
|
||||
import type { TabData } from '../types';
|
||||
import { ORACLE_ROWID_LOCATOR_COLUMN } from '../utils/rowLocator';
|
||||
import DataViewer from './DataViewer';
|
||||
|
||||
const storeState = vi.hoisted(() => ({
|
||||
connections: [
|
||||
{
|
||||
id: 'conn-1',
|
||||
name: 'oracle',
|
||||
config: {
|
||||
type: 'oracle',
|
||||
host: '127.0.0.1',
|
||||
port: 1521,
|
||||
user: 'scott',
|
||||
password: '',
|
||||
database: 'ORCLPDB1',
|
||||
},
|
||||
},
|
||||
],
|
||||
addSqlLog: vi.fn(),
|
||||
}));
|
||||
|
||||
const backendApp = vi.hoisted(() => ({
|
||||
DBQuery: vi.fn(),
|
||||
DBGetColumns: vi.fn(),
|
||||
DBGetIndexes: vi.fn(),
|
||||
}));
|
||||
|
||||
const messageApi = vi.hoisted(() => ({
|
||||
error: vi.fn(),
|
||||
warning: vi.fn(),
|
||||
}));
|
||||
|
||||
const dataGridState = vi.hoisted(() => ({
|
||||
latestProps: null as any,
|
||||
}));
|
||||
|
||||
vi.mock('../store', () => {
|
||||
const useStore = Object.assign(
|
||||
(selector: (state: typeof storeState) => any) => selector(storeState),
|
||||
{ getState: () => storeState },
|
||||
);
|
||||
return { useStore };
|
||||
});
|
||||
|
||||
vi.mock('../../wailsjs/go/app/App', () => backendApp);
|
||||
|
||||
vi.mock('antd', () => ({
|
||||
message: messageApi,
|
||||
}));
|
||||
|
||||
vi.mock('./DataGrid', () => ({
|
||||
default: (props: any) => {
|
||||
dataGridState.latestProps = props;
|
||||
return <div data-grid="true" />;
|
||||
},
|
||||
GONAVI_ROW_KEY: '__gonavi_row_key__',
|
||||
}));
|
||||
|
||||
const createTab = (overrides: Partial<TabData> = {}): TabData => ({
|
||||
id: 'tab-1',
|
||||
title: 'EDC_LOG',
|
||||
type: 'table',
|
||||
connectionId: 'conn-1',
|
||||
dbName: 'MYCIMLED',
|
||||
tableName: 'EDC_LOG',
|
||||
...overrides,
|
||||
});
|
||||
|
||||
const flushPromises = async () => {
|
||||
await act(async () => {
|
||||
await Promise.resolve();
|
||||
await Promise.resolve();
|
||||
});
|
||||
};
|
||||
|
||||
describe('DataViewer safe editing locator', () => {
|
||||
const renderAndReload = async (tab: TabData = createTab()) => {
|
||||
let renderer: ReactTestRenderer;
|
||||
await act(async () => {
|
||||
renderer = create(<DataViewer tab={tab} />);
|
||||
});
|
||||
|
||||
await act(async () => {
|
||||
await dataGridState.latestProps.onReload();
|
||||
});
|
||||
await flushPromises();
|
||||
return renderer!;
|
||||
};
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
dataGridState.latestProps = null;
|
||||
storeState.connections[0].config.type = 'oracle';
|
||||
storeState.connections[0].config.database = 'ORCLPDB1';
|
||||
backendApp.DBQuery.mockResolvedValue({
|
||||
success: true,
|
||||
fields: ['ID', 'NAME'],
|
||||
data: [{ ID: 7, NAME: 'old-name' }],
|
||||
});
|
||||
backendApp.DBGetIndexes.mockResolvedValue({ success: true, data: [] });
|
||||
});
|
||||
|
||||
it('enables table preview editing after primary keys are loaded', async () => {
|
||||
backendApp.DBGetColumns.mockResolvedValue({
|
||||
success: true,
|
||||
data: [{ name: 'ID', key: 'PRI' }, { name: 'NAME', key: '' }],
|
||||
});
|
||||
|
||||
const renderer = await renderAndReload();
|
||||
|
||||
expect(dataGridState.latestProps?.pkColumns).toEqual(['ID']);
|
||||
expect(dataGridState.latestProps?.editLocator).toMatchObject({
|
||||
strategy: 'primary-key',
|
||||
columns: ['ID'],
|
||||
valueColumns: ['ID'],
|
||||
readOnly: false,
|
||||
});
|
||||
expect(dataGridState.latestProps?.readOnly).toBe(false);
|
||||
expect(messageApi.warning).not.toHaveBeenCalled();
|
||||
renderer.unmount();
|
||||
});
|
||||
|
||||
it('uses a unique index when the table has no primary key', async () => {
|
||||
backendApp.DBGetColumns.mockResolvedValue({
|
||||
success: true,
|
||||
data: [{ name: 'EMAIL', key: '' }, { name: 'NAME', key: '' }],
|
||||
});
|
||||
backendApp.DBGetIndexes.mockResolvedValue({
|
||||
success: true,
|
||||
data: [{ name: 'UK_EMAIL', columnName: 'EMAIL', nonUnique: 0, seqInIndex: 1, indexType: 'BTREE' }],
|
||||
});
|
||||
|
||||
const renderer = await renderAndReload();
|
||||
|
||||
expect(dataGridState.latestProps?.pkColumns).toEqual([]);
|
||||
expect(dataGridState.latestProps?.editLocator).toMatchObject({
|
||||
strategy: 'unique-key',
|
||||
columns: ['EMAIL'],
|
||||
valueColumns: ['EMAIL'],
|
||||
readOnly: false,
|
||||
});
|
||||
expect(dataGridState.latestProps?.readOnly).toBe(false);
|
||||
expect(messageApi.warning).not.toHaveBeenCalled();
|
||||
renderer.unmount();
|
||||
});
|
||||
|
||||
it('uses hidden Oracle ROWID when no primary or unique key is available', async () => {
|
||||
backendApp.DBGetColumns.mockResolvedValue({
|
||||
success: true,
|
||||
data: [{ name: 'ID', key: '' }, { name: 'NAME', key: '' }],
|
||||
});
|
||||
backendApp.DBQuery.mockResolvedValue({
|
||||
success: true,
|
||||
fields: ['ID', 'NAME', ORACLE_ROWID_LOCATOR_COLUMN],
|
||||
data: [{ ID: 7, NAME: 'old-name', [ORACLE_ROWID_LOCATOR_COLUMN]: 'AAAA' }],
|
||||
});
|
||||
|
||||
const renderer = await renderAndReload();
|
||||
|
||||
expect(dataGridState.latestProps?.pkColumns).toEqual([]);
|
||||
expect(dataGridState.latestProps?.editLocator).toMatchObject({
|
||||
strategy: 'oracle-rowid',
|
||||
columns: ['ROWID'],
|
||||
valueColumns: [ORACLE_ROWID_LOCATOR_COLUMN],
|
||||
hiddenColumns: [ORACLE_ROWID_LOCATOR_COLUMN],
|
||||
readOnly: false,
|
||||
});
|
||||
expect(dataGridState.latestProps?.readOnly).toBe(false);
|
||||
expect(messageApi.warning).not.toHaveBeenCalled();
|
||||
expect(backendApp.DBQuery.mock.calls.some((call: any[]) => String(call[2]).includes(`ROWID AS "${ORACLE_ROWID_LOCATOR_COLUMN}"`))).toBe(true);
|
||||
renderer.unmount();
|
||||
});
|
||||
|
||||
it('keeps non-Oracle table preview read-only when no safe locator exists', async () => {
|
||||
storeState.connections[0].config.type = 'mysql';
|
||||
storeState.connections[0].config.database = 'main';
|
||||
backendApp.DBGetColumns.mockResolvedValue({
|
||||
success: true,
|
||||
data: [{ name: 'ID', key: '' }, { name: 'NAME', key: '' }],
|
||||
});
|
||||
|
||||
const renderer = await renderAndReload(createTab({ dbName: 'main', tableName: 'users', title: 'users' }));
|
||||
|
||||
expect(dataGridState.latestProps?.pkColumns).toEqual([]);
|
||||
expect(dataGridState.latestProps?.editLocator).toMatchObject({
|
||||
strategy: 'none',
|
||||
readOnly: true,
|
||||
reason: '未检测到主键或可用唯一索引,无法安全提交修改。',
|
||||
});
|
||||
expect(dataGridState.latestProps?.readOnly).toBe(true);
|
||||
expect(messageApi.warning).toHaveBeenCalledWith('表 main.users 保持只读:未检测到主键或可用唯一索引,无法安全提交修改。');
|
||||
renderer.unmount();
|
||||
});
|
||||
});
|
||||
@@ -1,8 +1,8 @@
|
||||
import React, { useEffect, useState, useCallback, useRef, useMemo } from 'react';
|
||||
import { message } from 'antd';
|
||||
import { TabData, ColumnDefinition } from '../types';
|
||||
import { TabData, ColumnDefinition, IndexDefinition } from '../types';
|
||||
import { useStore } from '../store';
|
||||
import { DBQuery, DBGetColumns } from '../../wailsjs/go/app/App';
|
||||
import { DBQuery, DBGetColumns, DBGetIndexes } from '../../wailsjs/go/app/App';
|
||||
import DataGrid, { GONAVI_ROW_KEY } from './DataGrid';
|
||||
import { buildOrderBySQL, buildPaginatedSelectSQL, buildWhereSQL, hasExplicitSort, quoteIdentPart, quoteQualifiedIdent, withSortBufferTuningSQL, type FilterCondition } from '../utils/sql';
|
||||
import { buildMongoCountCommand, buildMongoFilter, buildMongoFindCommand, buildMongoSort } from '../utils/mongodb';
|
||||
@@ -15,6 +15,12 @@ import {
|
||||
normalizeQuickWhereCondition,
|
||||
validateQuickWhereCondition,
|
||||
} from '../utils/dataGridWhereFilter';
|
||||
import {
|
||||
ORACLE_ROWID_LOCATOR_COLUMN,
|
||||
resolveEditRowLocator,
|
||||
type EditRowLocator,
|
||||
} from '../utils/rowLocator';
|
||||
import { isOracleLikeDialect } from '../utils/sqlDialect';
|
||||
|
||||
type ViewerPaginationState = {
|
||||
current: number;
|
||||
@@ -79,6 +85,47 @@ const parseTotalFromCountRow = (row: any): number | null => {
|
||||
return null;
|
||||
};
|
||||
|
||||
const buildDataViewerReadOnlyLocator = (reason: string): EditRowLocator => ({
|
||||
strategy: 'none',
|
||||
columns: [],
|
||||
valueColumns: [],
|
||||
readOnly: true,
|
||||
reason,
|
||||
});
|
||||
|
||||
const formatDataViewerTableName = (dbName: string, tableName: string): string => (
|
||||
dbName ? `${dbName}.${tableName}` : tableName
|
||||
);
|
||||
|
||||
const getTableColumnNames = (columns: ColumnDefinition[] | undefined): string[] => (
|
||||
(columns || [])
|
||||
.map((column) => String(column?.name || '').trim())
|
||||
.filter(Boolean)
|
||||
);
|
||||
|
||||
const resolveDataViewerOrderFallbackColumns = (locator: EditRowLocator | undefined, pkColumns: string[]): string[] => {
|
||||
if (locator && !locator.readOnly && locator.strategy !== 'oracle-rowid') {
|
||||
return locator.valueColumns.length > 0 ? locator.valueColumns : locator.columns;
|
||||
}
|
||||
return pkColumns;
|
||||
};
|
||||
|
||||
const buildDataViewerBaseSelectSQL = (
|
||||
dbType: string,
|
||||
tableName: string,
|
||||
whereSQL: string,
|
||||
locator?: EditRowLocator,
|
||||
): string => {
|
||||
const quotedTableName = quoteQualifiedIdent(dbType, tableName);
|
||||
if (locator?.strategy !== 'oracle-rowid') {
|
||||
return `SELECT * FROM ${quotedTableName} ${whereSQL}`;
|
||||
}
|
||||
|
||||
const alias = 'gonavi_row_source';
|
||||
const rowIDAlias = quoteIdentPart(dbType, ORACLE_ROWID_LOCATOR_COLUMN);
|
||||
return `SELECT ${alias}.*, ${alias}.ROWID AS ${rowIDAlias} FROM ${quotedTableName} ${alias} ${whereSQL}`;
|
||||
};
|
||||
|
||||
const normalizeDuckDBIdentifier = (raw: string): string => {
|
||||
const text = String(raw || '').trim();
|
||||
if (text.length >= 2) {
|
||||
@@ -193,6 +240,7 @@ const DataViewer: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAct
|
||||
const [data, setData] = useState<any[]>([]);
|
||||
const [columnNames, setColumnNames] = useState<string[]>([]);
|
||||
const [pkColumns, setPkColumns] = useState<string[]>([]);
|
||||
const [editLocator, setEditLocator] = useState<EditRowLocator | undefined>(undefined);
|
||||
const [loading, setLoading] = useState(false);
|
||||
const connections = useStore(state => state.connections);
|
||||
const addSqlLog = useStore(state => state.addSqlLog);
|
||||
@@ -280,6 +328,7 @@ const DataViewer: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAct
|
||||
useEffect(() => {
|
||||
const snapshot = getViewerFilterSnapshot(tab.id);
|
||||
setPkColumns([]);
|
||||
setEditLocator(undefined);
|
||||
pkKeyRef.current = '';
|
||||
countKeyRef.current = '';
|
||||
duckdbApproxKeyRef.current = '';
|
||||
@@ -435,10 +484,84 @@ const DataViewer: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAct
|
||||
const whereSQL = isMongoDB
|
||||
? JSON.stringify(mongoFilter || {})
|
||||
: buildWhereSQL(dbType, effectiveFilterConditions);
|
||||
|
||||
let pkColumnsForQuery = pkColumns;
|
||||
let editLocatorForQuery = editLocator;
|
||||
if (!isMongoDB && !forceReadOnly && tableName) {
|
||||
const locatorKey = `${tab.connectionId}|${dbTypeLower}|${dbName}|${tableName}`;
|
||||
if (pkKeyRef.current !== locatorKey || !editLocatorForQuery) {
|
||||
pkKeyRef.current = locatorKey;
|
||||
const locatorSeq = ++pkSeqRef.current;
|
||||
try {
|
||||
const [resCols, resIndexes] = await Promise.all([
|
||||
DBGetColumns(buildRpcConnectionConfig(config) as any, dbName, tableName),
|
||||
DBGetIndexes(buildRpcConnectionConfig(config) as any, dbName, tableName)
|
||||
.catch((error: any) => ({ success: false, message: String(error?.message || error || '加载索引失败'), data: [] })),
|
||||
]);
|
||||
if (fetchSeqRef.current !== seq) return;
|
||||
if (pkSeqRef.current !== locatorSeq) return;
|
||||
if (pkKeyRef.current !== locatorKey) return;
|
||||
|
||||
if (!resCols?.success || !Array.isArray(resCols.data)) {
|
||||
const nextLocator = buildDataViewerReadOnlyLocator('无法加载主键/唯一索引元数据,无法安全提交修改。');
|
||||
pkColumnsForQuery = [];
|
||||
editLocatorForQuery = nextLocator;
|
||||
setPkColumns([]);
|
||||
setEditLocator(nextLocator);
|
||||
message.warning(`表 ${formatDataViewerTableName(dbName, tableName)} 保持只读:${nextLocator.reason}`);
|
||||
} else {
|
||||
const columnDefs = resCols.data as ColumnDefinition[];
|
||||
const primaryKeys = columnDefs
|
||||
.filter((column: any) => column?.key === 'PRI')
|
||||
.map((column: any) => String(column?.name || '').trim())
|
||||
.filter(Boolean);
|
||||
const indexes = resIndexes?.success && Array.isArray(resIndexes.data)
|
||||
? resIndexes.data as IndexDefinition[]
|
||||
: [];
|
||||
const resultColumns = getTableColumnNames(columnDefs);
|
||||
const locatorColumns = isOracleLikeDialect(dbType)
|
||||
? [...resultColumns, ORACLE_ROWID_LOCATOR_COLUMN]
|
||||
: resultColumns;
|
||||
let nextLocator = resolveEditRowLocator({
|
||||
dbType,
|
||||
resultColumns: locatorColumns,
|
||||
primaryKeys,
|
||||
indexes,
|
||||
allowOracleRowID: true,
|
||||
});
|
||||
|
||||
if (nextLocator.readOnly && primaryKeys.length === 0 && !resIndexes?.success && !isOracleLikeDialect(dbType)) {
|
||||
nextLocator = buildDataViewerReadOnlyLocator('无法加载唯一索引元数据,无法安全提交修改。');
|
||||
}
|
||||
|
||||
pkColumnsForQuery = primaryKeys;
|
||||
editLocatorForQuery = nextLocator;
|
||||
setPkColumns(primaryKeys);
|
||||
setEditLocator(nextLocator);
|
||||
if (nextLocator.readOnly) {
|
||||
message.warning(`表 ${formatDataViewerTableName(dbName, tableName)} 保持只读:${nextLocator.reason || '当前结果没有可用的安全行定位方式,无法提交修改。'}`);
|
||||
}
|
||||
}
|
||||
} catch {
|
||||
if (fetchSeqRef.current !== seq) return;
|
||||
if (pkSeqRef.current !== locatorSeq) return;
|
||||
if (pkKeyRef.current !== locatorKey) return;
|
||||
const nextLocator = buildDataViewerReadOnlyLocator('无法加载主键/唯一索引元数据,无法安全提交修改。');
|
||||
pkColumnsForQuery = [];
|
||||
editLocatorForQuery = nextLocator;
|
||||
setPkColumns([]);
|
||||
setEditLocator(nextLocator);
|
||||
message.warning(`表 ${formatDataViewerTableName(dbName, tableName)} 保持只读:${nextLocator.reason}`);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const countSql = isMongoDB
|
||||
? buildMongoCountCommand(tableName, mongoFilter || {})
|
||||
: `SELECT COUNT(*) as total FROM ${quoteQualifiedIdent(dbType, tableName)} ${whereSQL}`;
|
||||
const orderBySQL = isMongoDB ? '' : buildOrderBySQL(dbType, sortInfo, pkColumns);
|
||||
const orderBySQL = isMongoDB
|
||||
? ''
|
||||
: buildOrderBySQL(dbType, sortInfo, resolveDataViewerOrderFallbackColumns(editLocatorForQuery, pkColumnsForQuery));
|
||||
const totalRows = Number(pagination.total);
|
||||
const hasFiniteTotal = Number.isFinite(totalRows) && totalRows >= 0;
|
||||
const totalKnown = pagination.totalKnown && hasFiniteTotal;
|
||||
@@ -469,7 +592,7 @@ const DataViewer: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAct
|
||||
skip: offset,
|
||||
});
|
||||
} else {
|
||||
const baseSql = `SELECT * FROM ${quoteQualifiedIdent(dbType, tableName)} ${whereSQL}`;
|
||||
const baseSql = buildDataViewerBaseSelectSQL(dbType, tableName, whereSQL, editLocatorForQuery);
|
||||
sql = `${baseSql}${orderBySQL}`;
|
||||
// ClickHouse 深分页在超大 OFFSET 下容易超时。对于总数已知且存在 ORDER BY 的场景,
|
||||
// 当“尾部偏移”小于“头部偏移”时,改为反向 ORDER BY + 小 OFFSET,并在前端翻转结果。
|
||||
@@ -557,7 +680,7 @@ const DataViewer: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAct
|
||||
|
||||
if (safeSelect) {
|
||||
let fallbackSql = `SELECT ${safeSelect} FROM ${quoteQualifiedIdent(dbType, tableName)} ${whereSQL}`;
|
||||
fallbackSql = buildPaginatedSelectSQL(dbType, fallbackSql, buildOrderBySQL(dbType, sortInfo, pkColumns), size + 1, offset);
|
||||
fallbackSql = buildPaginatedSelectSQL(dbType, fallbackSql, buildOrderBySQL(dbType, sortInfo, resolveDataViewerOrderFallbackColumns(editLocatorForQuery, pkColumnsForQuery)), size + 1, offset);
|
||||
executedSql = fallbackSql;
|
||||
resData = await executeDataQuery(fallbackSql, '复杂类型降级重试');
|
||||
}
|
||||
@@ -580,26 +703,6 @@ const DataViewer: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAct
|
||||
message.warning('已自动提升排序缓冲并重试成功。');
|
||||
}
|
||||
}
|
||||
|
||||
if (pkColumns.length === 0) {
|
||||
const pkKey = `${tab.connectionId}|${dbName}|${tableName}`;
|
||||
if (pkKeyRef.current !== pkKey) {
|
||||
pkKeyRef.current = pkKey;
|
||||
const pkSeq = ++pkSeqRef.current;
|
||||
DBGetColumns(buildRpcConnectionConfig(config) as any, dbName, tableName)
|
||||
.then((resCols: any) => {
|
||||
if (pkSeqRef.current !== pkSeq) return;
|
||||
if (pkKeyRef.current !== pkKey) return;
|
||||
if (!resCols?.success) return;
|
||||
const pks = (resCols.data as ColumnDefinition[]).filter((c: any) => c.key === 'PRI').map((c: any) => c.name);
|
||||
setPkColumns(pks);
|
||||
})
|
||||
.catch(() => {
|
||||
if (pkSeqRef.current !== pkSeq) return;
|
||||
if (pkKeyRef.current !== pkKey) return;
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
if (resData.success) {
|
||||
let resultData = resData.data as any[];
|
||||
@@ -842,9 +945,9 @@ const DataViewer: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAct
|
||||
});
|
||||
}
|
||||
if (fetchSeqRef.current === seq) setLoading(false);
|
||||
}, [connections, tab, sortInfo, filterConditions, quickWhereCondition, pkColumns, pagination.total, pagination.totalKnown, pagination.totalApprox, pagination.approximateTotal, preferManualTotalCount, supportsApproximateTableCount, supportsApproximateTotalPages]);
|
||||
// 依赖 pkColumns:在无手动排序时可回退到主键稳定排序。
|
||||
// 主键信息只会在首次加载后更新一次,避免循环查询。
|
||||
}, [connections, tab, sortInfo, filterConditions, quickWhereCondition, pkColumns, editLocator, forceReadOnly, pagination.total, pagination.totalKnown, pagination.totalApprox, pagination.approximateTotal, preferManualTotalCount, supportsApproximateTableCount, supportsApproximateTotalPages]);
|
||||
// 依赖定位列:在无手动排序时可回退到安全定位列稳定排序。
|
||||
// 定位信息只会在表上下文变化后重新加载,避免循环查询。
|
||||
|
||||
// Handlers memoized
|
||||
const handleReload = useCallback(() => {
|
||||
@@ -890,14 +993,14 @@ const DataViewer: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAct
|
||||
if (!whereSQL) return '';
|
||||
|
||||
let sql = `SELECT * FROM ${quoteQualifiedIdent(dbType, tableName)} ${whereSQL}`;
|
||||
sql += buildOrderBySQL(dbType, sortInfo, pkColumns);
|
||||
sql += buildOrderBySQL(dbType, sortInfo, resolveDataViewerOrderFallbackColumns(editLocator, pkColumns));
|
||||
const normalizedType = dbType.toLowerCase();
|
||||
const hasSortForBuffer = hasExplicitSort(sortInfo);
|
||||
if (hasSortForBuffer && (normalizedType === 'mysql' || normalizedType === 'mariadb')) {
|
||||
sql = withSortBufferTuningSQL(normalizedType, sql, 32 * 1024 * 1024);
|
||||
}
|
||||
return sql;
|
||||
}, [tab.tableName, currentConnConfig?.type, currentConnConfig?.driver, filterConditions, quickWhereCondition, sortInfo, pkColumns]);
|
||||
}, [tab.tableName, currentConnConfig?.type, currentConnConfig?.driver, filterConditions, quickWhereCondition, sortInfo, editLocator, pkColumns]);
|
||||
|
||||
useEffect(() => {
|
||||
const action = resolveDataViewerAutoFetchAction({
|
||||
@@ -927,6 +1030,7 @@ const DataViewer: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAct
|
||||
dbName={tab.dbName}
|
||||
connectionId={tab.connectionId}
|
||||
pkColumns={pkColumns}
|
||||
editLocator={editLocator}
|
||||
onReload={handleReload}
|
||||
onSort={handleSort}
|
||||
onPageChange={handlePageChange}
|
||||
@@ -939,7 +1043,7 @@ const DataViewer: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAct
|
||||
appliedFilterConditions={filterConditions}
|
||||
quickWhereCondition={quickWhereCondition}
|
||||
onApplyQuickWhereCondition={handleApplyQuickWhereCondition}
|
||||
readOnly={forceReadOnly}
|
||||
readOnly={forceReadOnly || !editLocator || editLocator.readOnly}
|
||||
sortInfoExternal={sortInfo}
|
||||
exportSqlWithFilter={exportSqlWithFilter || undefined}
|
||||
scrollSnapshot={scrollSnapshotRef.current}
|
||||
|
||||
@@ -39,6 +39,11 @@ type DriverStatusRow = {
|
||||
packagePath?: string;
|
||||
executablePath?: string;
|
||||
downloadedAt?: string;
|
||||
agentRevision?: string;
|
||||
expectedRevision?: string;
|
||||
needsUpdate?: boolean;
|
||||
updateReason?: string;
|
||||
affectedConnections?: number;
|
||||
message?: string;
|
||||
};
|
||||
|
||||
@@ -360,6 +365,13 @@ const DriverManagerModal: React.FC<{ open: boolean; onClose: () => void; onOpenG
|
||||
packagePath: String(item.packagePath || '').trim() || undefined,
|
||||
executablePath: String(item.executablePath || '').trim() || undefined,
|
||||
downloadedAt: String(item.downloadedAt || '').trim() || undefined,
|
||||
agentRevision: String(item.agentRevision || '').trim() || undefined,
|
||||
expectedRevision: String(item.expectedRevision || '').trim() || undefined,
|
||||
needsUpdate: !!item.needsUpdate,
|
||||
updateReason: String(item.updateReason || '').trim() || undefined,
|
||||
affectedConnections: Number.isFinite(Number(item.affectedConnections))
|
||||
? Number(item.affectedConnections)
|
||||
: undefined,
|
||||
message: String(item.message || '').trim() || undefined,
|
||||
}));
|
||||
setRows(nextRows);
|
||||
@@ -1005,7 +1017,17 @@ const DriverManagerModal: React.FC<{ open: boolean; onClose: () => void; onOpenG
|
||||
title: '数据源',
|
||||
dataIndex: 'name',
|
||||
key: 'name',
|
||||
width: 150,
|
||||
width: 220,
|
||||
render: (_: string, row: DriverStatusRow) => (
|
||||
<div style={{ display: 'grid', gap: 4 }}>
|
||||
<Text strong>{row.name}</Text>
|
||||
{row.message ? (
|
||||
<Text type={row.needsUpdate ? 'warning' : 'secondary'} style={{ fontSize: 12 }}>
|
||||
{row.message}
|
||||
</Text>
|
||||
) : null}
|
||||
</div>
|
||||
),
|
||||
},
|
||||
{
|
||||
title: '安装包大小',
|
||||
@@ -1042,6 +1064,9 @@ const DriverManagerModal: React.FC<{ open: boolean; onClose: () => void; onOpenG
|
||||
if (progress && (progress.status === 'start' || progress.status === 'downloading')) {
|
||||
return <Tag color="processing">安装中 {Math.round(progress.percent)}%</Tag>;
|
||||
}
|
||||
if (row.needsUpdate) {
|
||||
return <Tag color="warning">强烈建议重装</Tag>;
|
||||
}
|
||||
if (row.connectable) {
|
||||
return <Tag color="success">已启用</Tag>;
|
||||
}
|
||||
@@ -1089,10 +1114,11 @@ const DriverManagerModal: React.FC<{ open: boolean; onClose: () => void; onOpenG
|
||||
const versionLocked = row.packageInstalled || row.connectable;
|
||||
if (versionLocked) {
|
||||
const installedVersion = String(row.installedVersion || '').trim();
|
||||
const revisionHint = row.needsUpdate ? ',需重装' : '';
|
||||
if (installedVersion) {
|
||||
return <Text type="secondary">{installedVersion}(已安装,移除后可更换)</Text>;
|
||||
return <Text type="secondary">{installedVersion}(已安装{revisionHint},移除后可更换)</Text>;
|
||||
}
|
||||
return <Text type="secondary">已安装(移除后可更换)</Text>;
|
||||
return <Text type="secondary">已安装({row.needsUpdate ? '需重装,' : ''}移除后可更换)</Text>;
|
||||
}
|
||||
const options = versionMap[row.type] || [];
|
||||
const selectedKey = selectedVersionMap[row.type];
|
||||
@@ -1148,7 +1174,16 @@ const DriverManagerModal: React.FC<{ open: boolean; onClose: () => void; onOpenG
|
||||
const logs = operationLogMap[row.type] || [];
|
||||
const hasLogs = logs.length > 0;
|
||||
|
||||
const mainAction = row.connectable ? (
|
||||
const mainAction = row.needsUpdate ? (
|
||||
<Button
|
||||
type="primary"
|
||||
icon={<DownloadOutlined />}
|
||||
loading={loadingInstallOrRemove}
|
||||
onClick={() => installDriver(row)}
|
||||
>
|
||||
重装驱动
|
||||
</Button>
|
||||
) : row.connectable ? (
|
||||
<Button
|
||||
danger
|
||||
icon={<DeleteOutlined />}
|
||||
@@ -1209,9 +1244,10 @@ const DriverManagerModal: React.FC<{ open: boolean; onClose: () => void; onOpenG
|
||||
row.type,
|
||||
row.pinnedVersion,
|
||||
row.installedVersion,
|
||||
row.updateReason,
|
||||
row.message,
|
||||
row.builtIn ? '内置' : '外置',
|
||||
row.connectable ? '已启用' : row.packageInstalled ? '已安装' : '未启用',
|
||||
row.needsUpdate ? '强烈建议重装' : row.connectable ? '已启用' : row.packageInstalled ? '已安装' : '未启用',
|
||||
];
|
||||
const searchableText = normalizeDriverSearchText(searchableParts.filter(Boolean).join(' '));
|
||||
return searchableText.includes(normalizedSearchKeyword);
|
||||
|
||||
@@ -3,6 +3,7 @@ import { act, create, type ReactTestRenderer } from 'react-test-renderer';
|
||||
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest';
|
||||
|
||||
import type { SavedQuery, TabData } from '../types';
|
||||
import { ORACLE_ROWID_LOCATOR_COLUMN } from '../utils/rowLocator';
|
||||
import QueryEditor from './QueryEditor';
|
||||
|
||||
const storeState = vi.hoisted(() => ({
|
||||
@@ -44,6 +45,7 @@ const backendApp = vi.hoisted(() => ({
|
||||
DBGetAllColumns: vi.fn(),
|
||||
DBGetDatabases: vi.fn(),
|
||||
DBGetColumns: vi.fn(),
|
||||
DBGetIndexes: vi.fn(),
|
||||
CancelQuery: vi.fn(),
|
||||
GenerateQueryID: vi.fn(),
|
||||
WriteSQLFile: vi.fn(),
|
||||
@@ -56,6 +58,10 @@ const messageApi = vi.hoisted(() => ({
|
||||
warning: vi.fn(),
|
||||
}));
|
||||
|
||||
const dataGridState = vi.hoisted(() => ({
|
||||
latestProps: null as any,
|
||||
}));
|
||||
|
||||
const editorState = vi.hoisted(() => {
|
||||
const state = {
|
||||
value: '',
|
||||
@@ -114,7 +120,10 @@ vi.mock('@monaco-editor/react', () => ({
|
||||
}));
|
||||
|
||||
vi.mock('./DataGrid', () => ({
|
||||
default: () => null,
|
||||
default: (props: any) => {
|
||||
dataGridState.latestProps = props;
|
||||
return <div data-grid="true" />;
|
||||
},
|
||||
GONAVI_ROW_KEY: '__gonavi_row_key__',
|
||||
}));
|
||||
|
||||
@@ -152,7 +161,7 @@ vi.mock('antd', () => {
|
||||
Dropdown: ({ children }: any) => <>{children}</>,
|
||||
Tooltip: ({ children }: any) => <>{children}</>,
|
||||
Select: () => null,
|
||||
Tabs: () => null,
|
||||
Tabs: ({ items }: any) => <div>{items?.[0]?.children}</div>,
|
||||
};
|
||||
});
|
||||
|
||||
@@ -187,7 +196,15 @@ describe('QueryEditor external SQL save', () => {
|
||||
storeState.activeTabId = 'tab-1';
|
||||
messageApi.success.mockReset();
|
||||
messageApi.error.mockReset();
|
||||
messageApi.warning.mockReset();
|
||||
backendApp.WriteSQLFile.mockResolvedValue({ success: true });
|
||||
backendApp.DBQueryMulti.mockResolvedValue({ success: true, data: [] });
|
||||
backendApp.DBGetColumns.mockResolvedValue({ success: true, data: [] });
|
||||
backendApp.DBGetIndexes.mockResolvedValue({ success: true, data: [] });
|
||||
backendApp.GenerateQueryID.mockResolvedValue('query-1');
|
||||
storeState.connections[0].config.type = 'mysql';
|
||||
storeState.connections[0].config.database = 'main';
|
||||
dataGridState.latestProps = null;
|
||||
editorState.value = '';
|
||||
editorState.editor.getValue.mockClear();
|
||||
editorState.editor.setValue.mockClear();
|
||||
@@ -276,4 +293,253 @@ describe('QueryEditor external SQL save', () => {
|
||||
createdAt: 100,
|
||||
}));
|
||||
});
|
||||
|
||||
it('automatically appends hidden primary key locator columns for editable query results', async () => {
|
||||
storeState.connections[0].config.type = 'oracle';
|
||||
storeState.connections[0].config.database = 'ORCLPDB1';
|
||||
backendApp.DBQueryMulti.mockResolvedValueOnce({
|
||||
success: true,
|
||||
data: [{ columns: ['NAME', '__gonavi_locator_1_ID'], rows: [{ NAME: 'old-name', __gonavi_locator_1_ID: 7 }] }],
|
||||
});
|
||||
backendApp.DBGetColumns.mockResolvedValueOnce({
|
||||
success: true,
|
||||
data: [{ name: 'ID', key: 'PRI' }, { name: 'NAME', key: '' }],
|
||||
});
|
||||
|
||||
let renderer: ReactTestRenderer;
|
||||
await act(async () => {
|
||||
renderer = create(<QueryEditor tab={createTab({ dbName: 'ANONYMOUS', query: 'SELECT NAME FROM MYCIMLED.EDC_LOG' })} />);
|
||||
});
|
||||
|
||||
await act(async () => {
|
||||
await findButton(renderer!, '运行').props.onClick();
|
||||
});
|
||||
await act(async () => {
|
||||
await Promise.resolve();
|
||||
await Promise.resolve();
|
||||
});
|
||||
|
||||
expect(dataGridState.latestProps?.tableName).toBe('MYCIMLED.EDC_LOG');
|
||||
expect(dataGridState.latestProps?.pkColumns).toEqual(['ID']);
|
||||
expect(dataGridState.latestProps?.editLocator).toMatchObject({
|
||||
strategy: 'primary-key',
|
||||
columns: ['ID'],
|
||||
valueColumns: ['__gonavi_locator_1_ID'],
|
||||
hiddenColumns: ['__gonavi_locator_1_ID'],
|
||||
readOnly: false,
|
||||
});
|
||||
expect(dataGridState.latestProps?.readOnly).toBe(false);
|
||||
expect(dataGridState.latestProps?.resultSql).toBe('SELECT NAME FROM MYCIMLED.EDC_LOG');
|
||||
expect(String(backendApp.DBQueryMulti.mock.calls[0][2])).toContain('"ID" AS "__gonavi_locator_1_ID"');
|
||||
expect(messageApi.warning).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('uses a unique index locator for query results without primary keys', async () => {
|
||||
storeState.connections[0].config.type = 'oracle';
|
||||
storeState.connections[0].config.database = 'ORCLPDB1';
|
||||
backendApp.DBQueryMulti.mockResolvedValueOnce({
|
||||
success: true,
|
||||
data: [{ columns: ['NAME', '__gonavi_locator_1_EMAIL'], rows: [{ NAME: 'old-name', __gonavi_locator_1_EMAIL: 'a@example.com' }] }],
|
||||
});
|
||||
backendApp.DBGetColumns.mockResolvedValueOnce({
|
||||
success: true,
|
||||
data: [{ name: 'EMAIL', key: '' }, { name: 'NAME', key: '' }],
|
||||
});
|
||||
backendApp.DBGetIndexes.mockResolvedValueOnce({
|
||||
success: true,
|
||||
data: [{ name: 'UK_EMAIL', columnName: 'EMAIL', nonUnique: 0, seqInIndex: 1, indexType: 'BTREE' }],
|
||||
});
|
||||
|
||||
let renderer: ReactTestRenderer;
|
||||
await act(async () => {
|
||||
renderer = create(<QueryEditor tab={createTab({ dbName: 'ANONYMOUS', query: 'SELECT NAME FROM MYCIMLED.EDC_LOG' })} />);
|
||||
});
|
||||
|
||||
await act(async () => {
|
||||
await findButton(renderer!, '运行').props.onClick();
|
||||
});
|
||||
await act(async () => {
|
||||
await Promise.resolve();
|
||||
await Promise.resolve();
|
||||
});
|
||||
|
||||
expect(dataGridState.latestProps?.editLocator).toMatchObject({
|
||||
strategy: 'unique-key',
|
||||
columns: ['EMAIL'],
|
||||
valueColumns: ['__gonavi_locator_1_EMAIL'],
|
||||
hiddenColumns: ['__gonavi_locator_1_EMAIL'],
|
||||
readOnly: false,
|
||||
});
|
||||
expect(dataGridState.latestProps?.readOnly).toBe(false);
|
||||
expect(String(backendApp.DBQueryMulti.mock.calls[0][2])).toContain('"EMAIL" AS "__gonavi_locator_1_EMAIL"');
|
||||
expect(messageApi.warning).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('uses hidden Oracle ROWID for query results without primary or unique keys', async () => {
|
||||
storeState.connections[0].config.type = 'oracle';
|
||||
storeState.connections[0].config.database = 'ORCLPDB1';
|
||||
backendApp.DBQueryMulti.mockResolvedValueOnce({
|
||||
success: true,
|
||||
data: [{ columns: ['NAME', ORACLE_ROWID_LOCATOR_COLUMN], rows: [{ NAME: 'old-name', [ORACLE_ROWID_LOCATOR_COLUMN]: 'AAAA' }] }],
|
||||
});
|
||||
backendApp.DBGetColumns.mockResolvedValueOnce({
|
||||
success: true,
|
||||
data: [{ name: 'NAME', key: '' }],
|
||||
});
|
||||
|
||||
let renderer: ReactTestRenderer;
|
||||
await act(async () => {
|
||||
renderer = create(<QueryEditor tab={createTab({ dbName: 'ANONYMOUS', query: 'SELECT NAME FROM MYCIMLED.EDC_LOG' })} />);
|
||||
});
|
||||
|
||||
await act(async () => {
|
||||
await findButton(renderer!, '运行').props.onClick();
|
||||
});
|
||||
await act(async () => {
|
||||
await Promise.resolve();
|
||||
await Promise.resolve();
|
||||
});
|
||||
|
||||
expect(dataGridState.latestProps?.editLocator).toMatchObject({
|
||||
strategy: 'oracle-rowid',
|
||||
columns: ['ROWID'],
|
||||
valueColumns: [ORACLE_ROWID_LOCATOR_COLUMN],
|
||||
hiddenColumns: [ORACLE_ROWID_LOCATOR_COLUMN],
|
||||
readOnly: false,
|
||||
});
|
||||
expect(dataGridState.latestProps?.readOnly).toBe(false);
|
||||
expect(String(backendApp.DBQueryMulti.mock.calls[0][2])).toContain(`ROWID AS "${ORACLE_ROWID_LOCATOR_COLUMN}"`);
|
||||
expect(messageApi.warning).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('keeps non-Oracle query results read-only when no safe locator exists', async () => {
|
||||
backendApp.DBQueryMulti.mockResolvedValueOnce({
|
||||
success: true,
|
||||
data: [{ columns: ['NAME'], rows: [{ NAME: 'old-name' }] }],
|
||||
});
|
||||
backendApp.DBGetColumns.mockResolvedValueOnce({
|
||||
success: true,
|
||||
data: [{ name: 'NAME', key: '' }],
|
||||
});
|
||||
|
||||
let renderer: ReactTestRenderer;
|
||||
await act(async () => {
|
||||
renderer = create(<QueryEditor tab={createTab({ dbName: 'main', query: 'SELECT NAME FROM users' })} />);
|
||||
});
|
||||
|
||||
await act(async () => {
|
||||
await findButton(renderer!, '运行').props.onClick();
|
||||
});
|
||||
await act(async () => {
|
||||
await Promise.resolve();
|
||||
await Promise.resolve();
|
||||
});
|
||||
|
||||
expect(dataGridState.latestProps?.tableName).toBe('users');
|
||||
expect(dataGridState.latestProps?.pkColumns).toEqual([]);
|
||||
expect(dataGridState.latestProps?.editLocator).toMatchObject({
|
||||
strategy: 'none',
|
||||
readOnly: true,
|
||||
reason: '未检测到主键或可用唯一索引,无法安全提交修改。',
|
||||
});
|
||||
expect(dataGridState.latestProps?.readOnly).toBe(true);
|
||||
expect(messageApi.warning).toHaveBeenCalledWith('查询结果保持只读:main.users 未检测到主键或可用唯一索引,无法安全提交修改。');
|
||||
});
|
||||
|
||||
it('allows editable table columns while leaving expression columns out of commits', async () => {
|
||||
backendApp.DBQueryMulti.mockResolvedValueOnce({
|
||||
success: true,
|
||||
data: [{
|
||||
columns: ['DISPLAY_NAME', 'NAME_UPPER', '__gonavi_locator_1_ID'],
|
||||
rows: [{ DISPLAY_NAME: 'old-name', NAME_UPPER: 'OLD-NAME', __gonavi_locator_1_ID: 7 }],
|
||||
}],
|
||||
});
|
||||
backendApp.DBGetColumns.mockResolvedValueOnce({
|
||||
success: true,
|
||||
data: [{ name: 'ID', key: 'PRI' }, { name: 'NAME', key: '' }],
|
||||
});
|
||||
|
||||
let renderer: ReactTestRenderer;
|
||||
await act(async () => {
|
||||
renderer = create(<QueryEditor tab={createTab({
|
||||
dbName: 'main',
|
||||
query: 'SELECT NAME AS DISPLAY_NAME, UPPER(NAME) AS NAME_UPPER FROM users',
|
||||
})} />);
|
||||
});
|
||||
|
||||
await act(async () => {
|
||||
await findButton(renderer!, '运行').props.onClick();
|
||||
});
|
||||
await act(async () => {
|
||||
await Promise.resolve();
|
||||
await Promise.resolve();
|
||||
});
|
||||
|
||||
expect(dataGridState.latestProps?.tableName).toBe('users');
|
||||
expect(dataGridState.latestProps?.editLocator).toMatchObject({
|
||||
strategy: 'primary-key',
|
||||
columns: ['ID'],
|
||||
valueColumns: ['__gonavi_locator_1_ID'],
|
||||
hiddenColumns: ['__gonavi_locator_1_ID'],
|
||||
writableColumns: {
|
||||
DISPLAY_NAME: 'NAME',
|
||||
},
|
||||
readOnly: false,
|
||||
});
|
||||
expect(dataGridState.latestProps?.readOnly).toBe(false);
|
||||
expect(String(backendApp.DBQueryMulti.mock.calls[0][2])).toContain('`ID` AS `__gonavi_locator_1_ID`');
|
||||
expect(messageApi.warning).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it.each([
|
||||
'mysql',
|
||||
'mariadb',
|
||||
'diros',
|
||||
'sphinx',
|
||||
'postgres',
|
||||
'kingbase',
|
||||
'highgo',
|
||||
'vastbase',
|
||||
'sqlserver',
|
||||
'sqlite',
|
||||
'duckdb',
|
||||
'oracle',
|
||||
'dameng',
|
||||
'tdengine',
|
||||
'clickhouse',
|
||||
])(
|
||||
'keeps aggregate query results silently read-only for %s',
|
||||
async (dbType) => {
|
||||
storeState.connections[0].config.type = dbType;
|
||||
storeState.connections[0].config.database = dbType === 'oracle' || dbType === 'dameng' ? 'APP' : 'main';
|
||||
const forceReadOnlyQueryResult = dbType === 'tdengine' || dbType === 'clickhouse';
|
||||
backendApp.DBQueryMulti.mockResolvedValueOnce({
|
||||
success: true,
|
||||
data: [{ columns: ['COUNT'], rows: [{ COUNT: 1 }] }],
|
||||
});
|
||||
|
||||
let renderer: ReactTestRenderer;
|
||||
await act(async () => {
|
||||
renderer = create(<QueryEditor tab={createTab({
|
||||
dbName: storeState.connections[0].config.database,
|
||||
query: 'SELECT count(1) FROM users',
|
||||
})} />);
|
||||
});
|
||||
|
||||
await act(async () => {
|
||||
await findButton(renderer!, '运行').props.onClick();
|
||||
});
|
||||
await act(async () => {
|
||||
await Promise.resolve();
|
||||
await Promise.resolve();
|
||||
});
|
||||
|
||||
expect(dataGridState.latestProps?.tableName).toBe(forceReadOnlyQueryResult ? undefined : 'users');
|
||||
expect(dataGridState.latestProps?.editLocator).toBeUndefined();
|
||||
expect(dataGridState.latestProps?.readOnly).toBe(true);
|
||||
expect(backendApp.DBGetColumns).not.toHaveBeenCalled();
|
||||
expect(backendApp.DBGetIndexes).not.toHaveBeenCalled();
|
||||
expect(messageApi.warning).not.toHaveBeenCalled();
|
||||
},
|
||||
);
|
||||
});
|
||||
|
||||
@@ -4,16 +4,21 @@ import { Button, message, Modal, Input, Form, Dropdown, MenuProps, Tooltip, Sele
|
||||
import { PlayCircleOutlined, SaveOutlined, FormatPainterOutlined, SettingOutlined, CloseOutlined, StopOutlined, RobotOutlined } from '@ant-design/icons';
|
||||
import { format } from 'sql-formatter';
|
||||
import { v4 as uuidv4 } from 'uuid';
|
||||
import { TabData, ColumnDefinition } from '../types';
|
||||
import { TabData, ColumnDefinition, IndexDefinition } from '../types';
|
||||
import { useStore } from '../store';
|
||||
import { DBQueryWithCancel, DBQueryMulti, DBGetTables, DBGetAllColumns, DBGetDatabases, DBGetColumns, CancelQuery, GenerateQueryID, WriteSQLFile } from '../../wailsjs/go/app/App';
|
||||
import { DBQueryWithCancel, DBQueryMulti, DBGetTables, DBGetAllColumns, DBGetDatabases, DBGetColumns, DBGetIndexes, CancelQuery, GenerateQueryID, WriteSQLFile } from '../../wailsjs/go/app/App';
|
||||
import DataGrid, { GONAVI_ROW_KEY } from './DataGrid';
|
||||
import { getDataSourceCapabilities } from '../utils/dataSourceCapabilities';
|
||||
import { convertMongoShellToJsonCommand } from '../utils/mongodb';
|
||||
import { applyMongoQueryAutoLimit, convertMongoShellToJsonCommand } from '../utils/mongodb';
|
||||
import { getShortcutDisplay, isEditableElement, isShortcutMatch } from '../utils/shortcuts';
|
||||
import { useAutoFetchVisibility } from '../utils/autoFetchVisibility';
|
||||
import { buildRpcConnectionConfig } from '../utils/connectionRpcConfig';
|
||||
import { resolveSqlDialect, resolveSqlFunctions, resolveSqlKeywords } from '../utils/sqlDialect';
|
||||
import { isOracleLikeDialect, resolveSqlDialect, resolveSqlFunctions, resolveSqlKeywords } from '../utils/sqlDialect';
|
||||
import { applyQueryAutoLimit } from '../utils/queryAutoLimit';
|
||||
import { extractQueryResultTableRef, type QueryResultTableRef } from '../utils/queryResultTable';
|
||||
import { quoteIdentPart } from '../utils/sql';
|
||||
import { resolveUniqueKeyGroupsFromIndexes } from './dataGridCopyInsert';
|
||||
import { ORACLE_ROWID_LOCATOR_COLUMN, type EditRowLocator } from '../utils/rowLocator';
|
||||
|
||||
const SQL_KEYWORDS = [
|
||||
'SELECT', 'FROM', 'WHERE', 'LIMIT', 'INSERT', 'UPDATE', 'DELETE', 'JOIN', 'LEFT', 'RIGHT',
|
||||
@@ -186,6 +191,300 @@ let sharedAllColumnsData: {dbName: string, tableName: string, name: string, type
|
||||
let sharedVisibleDbs: string[] = [];
|
||||
let sharedColumnsCacheData: Record<string, any[]> = {};
|
||||
|
||||
const QUERY_LOCATOR_ALIAS_PREFIX = '__gonavi_locator_';
|
||||
|
||||
const buildQueryReadOnlyLocator = (reason: string): EditRowLocator => ({
|
||||
strategy: 'none',
|
||||
columns: [],
|
||||
valueColumns: [],
|
||||
readOnly: true,
|
||||
reason,
|
||||
});
|
||||
|
||||
type SimpleSelectInfo = {
|
||||
selectsAll: boolean;
|
||||
writableColumns: Record<string, string>;
|
||||
};
|
||||
|
||||
type QueryStatementPlan = {
|
||||
originalSql: string;
|
||||
executedSql: string;
|
||||
tableRef?: QueryResultTableRef;
|
||||
pkColumns: string[];
|
||||
editLocator?: EditRowLocator;
|
||||
warning?: string;
|
||||
};
|
||||
|
||||
const stripQueryIdentifierQuotes = (part: string): string => {
|
||||
const text = String(part || '').trim();
|
||||
if (!text) return '';
|
||||
if ((text.startsWith('`') && text.endsWith('`')) || (text.startsWith('"') && text.endsWith('"'))) {
|
||||
return text.slice(1, -1).trim();
|
||||
}
|
||||
if (text.startsWith('[') && text.endsWith(']')) {
|
||||
return text.slice(1, -1).trim();
|
||||
}
|
||||
return text;
|
||||
};
|
||||
|
||||
const splitTopLevelComma = (text: string): string[] => {
|
||||
const parts: string[] = [];
|
||||
let current = '';
|
||||
let parenDepth = 0;
|
||||
let inSingle = false;
|
||||
let inDouble = false;
|
||||
let inBacktick = false;
|
||||
let escaped = false;
|
||||
|
||||
for (let index = 0; index < text.length; index++) {
|
||||
const ch = text[index];
|
||||
if (escaped) {
|
||||
current += ch;
|
||||
escaped = false;
|
||||
continue;
|
||||
}
|
||||
if ((inSingle || inDouble) && ch === '\\') {
|
||||
current += ch;
|
||||
escaped = true;
|
||||
continue;
|
||||
}
|
||||
if (!inDouble && !inBacktick && ch === "'") {
|
||||
inSingle = !inSingle;
|
||||
current += ch;
|
||||
continue;
|
||||
}
|
||||
if (!inSingle && !inBacktick && ch === '"') {
|
||||
inDouble = !inDouble;
|
||||
current += ch;
|
||||
continue;
|
||||
}
|
||||
if (!inSingle && !inDouble && ch === '`') {
|
||||
inBacktick = !inBacktick;
|
||||
current += ch;
|
||||
continue;
|
||||
}
|
||||
if (!inSingle && !inDouble && !inBacktick) {
|
||||
if (ch === '(') parenDepth++;
|
||||
if (ch === ')' && parenDepth > 0) parenDepth--;
|
||||
if (ch === ',' && parenDepth === 0) {
|
||||
parts.push(current.trim());
|
||||
current = '';
|
||||
continue;
|
||||
}
|
||||
}
|
||||
current += ch;
|
||||
}
|
||||
|
||||
if (current.trim()) parts.push(current.trim());
|
||||
return parts;
|
||||
};
|
||||
|
||||
const SIMPLE_IDENTIFIER_PATH_RE = /^(?:[`"\[]?[A-Za-z_][\w$]*[`"\]]?\s*\.\s*){0,2}[`"\[]?[A-Za-z_][\w$]*[`"\]]?$/;
|
||||
const QUERY_ALIAS_RESERVED = new Set([
|
||||
'where', 'group', 'order', 'having', 'limit', 'fetch', 'offset', 'join', 'left', 'right', 'inner', 'outer', 'on', 'union',
|
||||
]);
|
||||
|
||||
const getLastIdentifierPart = (path: string): string => {
|
||||
const parts = String(path || '').split('.').map((part) => stripQueryIdentifierQuotes(part.trim())).filter(Boolean);
|
||||
return parts[parts.length - 1] || '';
|
||||
};
|
||||
|
||||
const resolveSimpleSelectItemColumn = (item: string): { resultName: string; sourceName: string } | 'all' | undefined => {
|
||||
const text = String(item || '').trim();
|
||||
if (!text) return undefined;
|
||||
if (text === '*' || /\.\s*\*$/.test(text)) return 'all';
|
||||
|
||||
let expr = text;
|
||||
let alias = '';
|
||||
const asMatch = text.match(/^(.*?)\s+AS\s+([`"\[]?[A-Za-z_][\w$]*[`"\]]?)$/i);
|
||||
if (asMatch) {
|
||||
expr = asMatch[1].trim();
|
||||
alias = stripQueryIdentifierQuotes(asMatch[2]);
|
||||
} else {
|
||||
const bareAliasMatch = text.match(/^(.*?)\s+([`"\[]?[A-Za-z_][\w$]*[`"\]]?)$/);
|
||||
if (bareAliasMatch && SIMPLE_IDENTIFIER_PATH_RE.test(bareAliasMatch[1].trim())) {
|
||||
const candidateAlias = stripQueryIdentifierQuotes(bareAliasMatch[2]);
|
||||
if (candidateAlias && !QUERY_ALIAS_RESERVED.has(candidateAlias.toLowerCase())) {
|
||||
expr = bareAliasMatch[1].trim();
|
||||
alias = candidateAlias;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (!SIMPLE_IDENTIFIER_PATH_RE.test(expr)) return undefined;
|
||||
const sourceName = getLastIdentifierPart(expr);
|
||||
const resultName = alias || sourceName;
|
||||
return sourceName && resultName ? { resultName, sourceName } : undefined;
|
||||
};
|
||||
|
||||
const parseSimpleSelectInfo = (sql: string): SimpleSelectInfo | undefined => {
|
||||
const match = String(sql || '').match(/^\s*SELECT\s+([\s\S]+?)\s+FROM\s+/i);
|
||||
if (!match) return undefined;
|
||||
const selectList = match[1].trim();
|
||||
if (!selectList || /^DISTINCT\b/i.test(selectList)) return undefined;
|
||||
|
||||
const writableColumns: Record<string, string> = {};
|
||||
let selectsAll = false;
|
||||
for (const item of splitTopLevelComma(selectList)) {
|
||||
const resolved = resolveSimpleSelectItemColumn(item);
|
||||
if (!resolved) continue;
|
||||
if (resolved === 'all') {
|
||||
selectsAll = true;
|
||||
continue;
|
||||
}
|
||||
writableColumns[resolved.resultName] = resolved.sourceName;
|
||||
}
|
||||
return { selectsAll, writableColumns };
|
||||
};
|
||||
|
||||
const appendQuerySelectExpressions = (sql: string, expressions: string[]): string => {
|
||||
if (expressions.length === 0) return sql;
|
||||
return String(sql || '').replace(
|
||||
/^(\s*SELECT\s+)([\s\S]+?)(\s+FROM\s+[\s\S]*)$/i,
|
||||
(_match, prefix, selectList, rest) => `${prefix}${String(selectList).trimEnd()}, ${expressions.join(', ')}${rest}`,
|
||||
);
|
||||
};
|
||||
|
||||
const findWritableResultColumnForSource = (writableColumns: Record<string, string>, target: string): string | undefined => {
|
||||
const normalizedTarget = String(target || '').trim().toLowerCase();
|
||||
return Object.entries(writableColumns || {}).find(([, sourceColumn]) => (
|
||||
String(sourceColumn || '').trim().toLowerCase() === normalizedTarget
|
||||
))?.[0];
|
||||
};
|
||||
|
||||
const buildQueryLocatorAlias = (column: string, index: number): string => {
|
||||
const normalized = String(column || '').trim().replace(/[^A-Za-z0-9_]/g, '_').slice(0, 48) || 'column';
|
||||
return `${QUERY_LOCATOR_ALIAS_PREFIX}${index}_${normalized}`;
|
||||
};
|
||||
|
||||
const buildQueryLocatorColumnExpression = (dbType: string, column: string, alias: string): string => (
|
||||
`${quoteIdentPart(dbType, column)} AS ${quoteIdentPart(dbType, alias)}`
|
||||
);
|
||||
|
||||
const buildQueryRowIDExpression = (dbType: string): string => (
|
||||
`ROWID AS ${quoteIdentPart(dbType, ORACLE_ROWID_LOCATOR_COLUMN)}`
|
||||
);
|
||||
|
||||
const resolveQueryLocatorPlan = async ({
|
||||
statement,
|
||||
dbType,
|
||||
currentDb,
|
||||
config,
|
||||
forceReadOnly,
|
||||
}: {
|
||||
statement: string;
|
||||
dbType: string;
|
||||
currentDb: string;
|
||||
config: any;
|
||||
forceReadOnly: boolean;
|
||||
}): Promise<QueryStatementPlan> => {
|
||||
const plan: QueryStatementPlan = {
|
||||
originalSql: statement,
|
||||
executedSql: statement,
|
||||
pkColumns: [],
|
||||
};
|
||||
if (forceReadOnly) return plan;
|
||||
|
||||
const tableRef = extractQueryResultTableRef(statement, dbType, currentDb);
|
||||
if (!tableRef) return plan;
|
||||
plan.tableRef = tableRef;
|
||||
|
||||
const selectInfo = parseSimpleSelectInfo(statement);
|
||||
if (!selectInfo) {
|
||||
// 聚合、函数和表达式结果天然无法安全回写到单行,静默保持只读即可。
|
||||
return plan;
|
||||
}
|
||||
if (!selectInfo.selectsAll && Object.keys(selectInfo.writableColumns).length === 0) {
|
||||
return plan;
|
||||
}
|
||||
|
||||
try {
|
||||
const [resCols, resIndexes] = await Promise.all([
|
||||
DBGetColumns(buildRpcConnectionConfig(config) as any, tableRef.metadataDbName, tableRef.metadataTableName),
|
||||
DBGetIndexes(buildRpcConnectionConfig(config) as any, tableRef.metadataDbName, tableRef.metadataTableName)
|
||||
.catch((error: any) => ({ success: false, message: String(error?.message || error || '加载索引失败'), data: [] })),
|
||||
]);
|
||||
if (!resCols?.success || !Array.isArray(resCols.data)) {
|
||||
const reason = `无法加载 ${tableRef.metadataDbName}.${tableRef.metadataTableName} 的主键/唯一索引元数据,无法安全提交修改。`;
|
||||
plan.editLocator = buildQueryReadOnlyLocator(reason);
|
||||
plan.warning = `查询结果保持只读:${reason}`;
|
||||
return plan;
|
||||
}
|
||||
|
||||
const tableColumns = resCols.data as ColumnDefinition[];
|
||||
const tableColumnNames = tableColumns.map((column) => String(column?.name || '').trim()).filter(Boolean);
|
||||
const primaryKeys = tableColumns
|
||||
.filter((column: any) => column?.key === 'PRI')
|
||||
.map((column: any) => String(column?.name || '').trim())
|
||||
.filter(Boolean);
|
||||
const indexes = resIndexes?.success && Array.isArray(resIndexes.data)
|
||||
? resIndexes.data as IndexDefinition[]
|
||||
: [];
|
||||
const writableColumns: Record<string, string> = selectInfo.selectsAll
|
||||
? Object.fromEntries(tableColumnNames.map((column) => [column, column]))
|
||||
: {};
|
||||
Object.entries(selectInfo.writableColumns).forEach(([resultColumn, sourceColumn]) => {
|
||||
writableColumns[resultColumn] = sourceColumn;
|
||||
});
|
||||
const appendExpressions: string[] = [];
|
||||
const hiddenColumns: string[] = [];
|
||||
|
||||
const buildColumnLocator = (strategy: 'primary-key' | 'unique-key', locatorColumns: string[]): EditRowLocator => {
|
||||
const valueColumns = locatorColumns.map((column, index) => {
|
||||
const selectedColumn = findWritableResultColumnForSource(writableColumns, column);
|
||||
if (selectedColumn) return selectedColumn;
|
||||
const alias = buildQueryLocatorAlias(column, index + 1);
|
||||
appendExpressions.push(buildQueryLocatorColumnExpression(dbType, column, alias));
|
||||
hiddenColumns.push(alias);
|
||||
return alias;
|
||||
});
|
||||
return {
|
||||
strategy,
|
||||
columns: locatorColumns,
|
||||
valueColumns,
|
||||
hiddenColumns: hiddenColumns.length > 0 ? [...hiddenColumns] : undefined,
|
||||
writableColumns,
|
||||
readOnly: false,
|
||||
};
|
||||
};
|
||||
|
||||
if (primaryKeys.length > 0) {
|
||||
plan.pkColumns = primaryKeys;
|
||||
plan.editLocator = buildColumnLocator('primary-key', primaryKeys);
|
||||
} else {
|
||||
const uniqueKeyGroups = resolveUniqueKeyGroupsFromIndexes(indexes);
|
||||
const uniqueKeyGroup = uniqueKeyGroups.find((group) => group.length > 0);
|
||||
if (uniqueKeyGroup) {
|
||||
plan.editLocator = buildColumnLocator('unique-key', uniqueKeyGroup);
|
||||
} else if (isOracleLikeDialect(dbType)) {
|
||||
appendExpressions.push(buildQueryRowIDExpression(dbType));
|
||||
plan.editLocator = {
|
||||
strategy: 'oracle-rowid',
|
||||
columns: ['ROWID'],
|
||||
valueColumns: [ORACLE_ROWID_LOCATOR_COLUMN],
|
||||
hiddenColumns: [ORACLE_ROWID_LOCATOR_COLUMN],
|
||||
writableColumns,
|
||||
readOnly: false,
|
||||
};
|
||||
} else {
|
||||
const reason = !resIndexes?.success
|
||||
? '无法加载唯一索引元数据,无法安全提交修改。'
|
||||
: '未检测到主键或可用唯一索引,无法安全提交修改。';
|
||||
plan.editLocator = buildQueryReadOnlyLocator(reason);
|
||||
plan.warning = `查询结果保持只读:${tableRef.metadataDbName}.${tableRef.metadataTableName} ${reason}`;
|
||||
}
|
||||
}
|
||||
|
||||
plan.executedSql = appendQuerySelectExpressions(statement, appendExpressions);
|
||||
return plan;
|
||||
} catch {
|
||||
const reason = `无法加载 ${tableRef.metadataDbName}.${tableRef.metadataTableName} 的主键/唯一索引元数据,无法安全提交修改。`;
|
||||
plan.editLocator = buildQueryReadOnlyLocator(reason);
|
||||
plan.warning = `查询结果保持只读:${reason}`;
|
||||
return plan;
|
||||
}
|
||||
};
|
||||
|
||||
const QueryEditor: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isActive = true }) => {
|
||||
const [query, setQuery] = useState(tab.query || 'SELECT * FROM ');
|
||||
|
||||
@@ -197,6 +496,7 @@ const QueryEditor: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAc
|
||||
columns: string[];
|
||||
tableName?: string;
|
||||
pkColumns: string[];
|
||||
editLocator?: EditRowLocator;
|
||||
readOnly: boolean;
|
||||
truncated?: boolean;
|
||||
pkLoading?: boolean;
|
||||
@@ -1184,359 +1484,6 @@ const QueryEditor: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAc
|
||||
return statements;
|
||||
};
|
||||
|
||||
const getLeadingKeyword = (sql: string): string => {
|
||||
const text = (sql || '').replace(/\r\n/g, '\n');
|
||||
const isWS = (ch: string) => ch === ' ' || ch === '\t' || ch === '\n' || ch === '\r';
|
||||
const isWord = (ch: string) => /[A-Za-z0-9_]/.test(ch);
|
||||
|
||||
let inSingle = false;
|
||||
let inDouble = false;
|
||||
let inBacktick = false;
|
||||
let escaped = false;
|
||||
let inLineComment = false;
|
||||
let inBlockComment = false;
|
||||
let dollarTag: string | null = null;
|
||||
|
||||
for (let i = 0; i < text.length; i++) {
|
||||
const ch = text[i];
|
||||
const next = i + 1 < text.length ? text[i + 1] : '';
|
||||
const prev = i > 0 ? text[i - 1] : '';
|
||||
const next2 = i + 2 < text.length ? text[i + 2] : '';
|
||||
|
||||
if (!inSingle && !inDouble && !inBacktick) {
|
||||
if (inLineComment) {
|
||||
if (ch === '\n') inLineComment = false;
|
||||
continue;
|
||||
}
|
||||
if (inBlockComment) {
|
||||
if (ch === '*' && next === '/') {
|
||||
i++;
|
||||
inBlockComment = false;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
if (ch === '/' && next === '*') {
|
||||
i++;
|
||||
inBlockComment = true;
|
||||
continue;
|
||||
}
|
||||
if (ch === '#') {
|
||||
inLineComment = true;
|
||||
continue;
|
||||
}
|
||||
if (ch === '-' && next === '-' && (i === 0 || isWS(prev)) && (next2 === '' || isWS(next2))) {
|
||||
i++;
|
||||
inLineComment = true;
|
||||
continue;
|
||||
}
|
||||
|
||||
if (dollarTag) {
|
||||
if (text.startsWith(dollarTag, i)) {
|
||||
i += dollarTag.length - 1;
|
||||
dollarTag = null;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
if (ch === '$') {
|
||||
const m = text.slice(i).match(/^\$[A-Za-z0-9_]*\$/);
|
||||
if (m && m[0]) {
|
||||
dollarTag = m[0];
|
||||
i += dollarTag.length - 1;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (escaped) {
|
||||
escaped = false;
|
||||
continue;
|
||||
}
|
||||
if ((inSingle || inDouble) && ch === '\\') {
|
||||
escaped = true;
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!inDouble && !inBacktick && ch === '\'') {
|
||||
inSingle = !inSingle;
|
||||
continue;
|
||||
}
|
||||
if (!inSingle && !inBacktick && ch === '"') {
|
||||
inDouble = !inDouble;
|
||||
continue;
|
||||
}
|
||||
if (!inSingle && !inDouble && ch === '`') {
|
||||
inBacktick = !inBacktick;
|
||||
continue;
|
||||
}
|
||||
|
||||
if (inSingle || inDouble || inBacktick || dollarTag) continue;
|
||||
if (isWS(ch)) continue;
|
||||
|
||||
if (isWord(ch)) {
|
||||
let j = i;
|
||||
while (j < text.length && isWord(text[j])) j++;
|
||||
return text.slice(i, j).toLowerCase();
|
||||
}
|
||||
return '';
|
||||
}
|
||||
return '';
|
||||
};
|
||||
|
||||
const splitSqlTail = (sql: string): { main: string; tail: string } => {
|
||||
const text = (sql || '').replace(/\r\n/g, '\n');
|
||||
const isWS = (ch: string) => ch === ' ' || ch === '\t' || ch === '\n' || ch === '\r';
|
||||
|
||||
let inSingle = false;
|
||||
let inDouble = false;
|
||||
let inBacktick = false;
|
||||
let escaped = false;
|
||||
let inLineComment = false;
|
||||
let inBlockComment = false;
|
||||
let dollarTag: string | null = null;
|
||||
let lastMeaningful = -1;
|
||||
|
||||
for (let i = 0; i < text.length; i++) {
|
||||
const ch = text[i];
|
||||
const next = i + 1 < text.length ? text[i + 1] : '';
|
||||
const prev = i > 0 ? text[i - 1] : '';
|
||||
const next2 = i + 2 < text.length ? text[i + 2] : '';
|
||||
|
||||
if (!inSingle && !inDouble && !inBacktick) {
|
||||
if (dollarTag) {
|
||||
if (text.startsWith(dollarTag, i)) {
|
||||
lastMeaningful = i + dollarTag.length - 1;
|
||||
i += dollarTag.length - 1;
|
||||
dollarTag = null;
|
||||
} else if (!isWS(ch)) {
|
||||
lastMeaningful = i;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
if (inLineComment) {
|
||||
if (ch === '\n') inLineComment = false;
|
||||
continue;
|
||||
}
|
||||
if (inBlockComment) {
|
||||
if (ch === '*' && next === '/') {
|
||||
i++;
|
||||
inBlockComment = false;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
// Start comments
|
||||
if (ch === '/' && next === '*') {
|
||||
i++;
|
||||
inBlockComment = true;
|
||||
continue;
|
||||
}
|
||||
if (ch === '#') {
|
||||
inLineComment = true;
|
||||
continue;
|
||||
}
|
||||
if (ch === '-' && next === '-' && (i === 0 || isWS(prev)) && (next2 === '' || isWS(next2))) {
|
||||
i++;
|
||||
inLineComment = true;
|
||||
continue;
|
||||
}
|
||||
|
||||
if (ch === '$') {
|
||||
const m = text.slice(i).match(/^\$[A-Za-z0-9_]*\$/);
|
||||
if (m && m[0]) {
|
||||
dollarTag = m[0];
|
||||
lastMeaningful = i + dollarTag.length - 1;
|
||||
i += dollarTag.length - 1;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (escaped) {
|
||||
escaped = false;
|
||||
} else if ((inSingle || inDouble) && ch === '\\') {
|
||||
escaped = true;
|
||||
} else {
|
||||
if (!inDouble && !inBacktick && ch === '\'') inSingle = !inSingle;
|
||||
else if (!inSingle && !inBacktick && ch === '"') inDouble = !inDouble;
|
||||
else if (!inSingle && !inDouble && ch === '`') inBacktick = !inBacktick;
|
||||
}
|
||||
|
||||
if (!inLineComment && !inBlockComment && !isWS(ch)) {
|
||||
lastMeaningful = i;
|
||||
}
|
||||
}
|
||||
|
||||
if (lastMeaningful < 0) return { main: '', tail: text };
|
||||
return { main: text.slice(0, lastMeaningful + 1), tail: text.slice(lastMeaningful + 1) };
|
||||
};
|
||||
|
||||
const findTopLevelKeyword = (sql: string, keyword: string): number => {
|
||||
const text = sql;
|
||||
const kw = keyword.toLowerCase();
|
||||
const isWS = (ch: string) => ch === ' ' || ch === '\t' || ch === '\n' || ch === '\r';
|
||||
const isWord = (ch: string) => /[A-Za-z0-9_]/.test(ch);
|
||||
|
||||
let inSingle = false;
|
||||
let inDouble = false;
|
||||
let inBacktick = false;
|
||||
let escaped = false;
|
||||
let inLineComment = false;
|
||||
let inBlockComment = false;
|
||||
let dollarTag: string | null = null;
|
||||
let parenDepth = 0;
|
||||
|
||||
for (let i = 0; i < text.length; i++) {
|
||||
const ch = text[i];
|
||||
const next = i + 1 < text.length ? text[i + 1] : '';
|
||||
const prev = i > 0 ? text[i - 1] : '';
|
||||
const next2 = i + 2 < text.length ? text[i + 2] : '';
|
||||
|
||||
if (!inSingle && !inDouble && !inBacktick) {
|
||||
if (inLineComment) {
|
||||
if (ch === '\n') inLineComment = false;
|
||||
continue;
|
||||
}
|
||||
if (inBlockComment) {
|
||||
if (ch === '*' && next === '/') {
|
||||
i++;
|
||||
inBlockComment = false;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
if (ch === '/' && next === '*') {
|
||||
i++;
|
||||
inBlockComment = true;
|
||||
continue;
|
||||
}
|
||||
if (ch === '#') {
|
||||
inLineComment = true;
|
||||
continue;
|
||||
}
|
||||
if (ch === '-' && next === '-' && (i === 0 || isWS(prev)) && (next2 === '' || isWS(next2))) {
|
||||
i++;
|
||||
inLineComment = true;
|
||||
continue;
|
||||
}
|
||||
|
||||
if (dollarTag) {
|
||||
if (text.startsWith(dollarTag, i)) {
|
||||
i += dollarTag.length - 1;
|
||||
dollarTag = null;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
if (ch === '$') {
|
||||
const m = text.slice(i).match(/^\$[A-Za-z0-9_]*\$/);
|
||||
if (m && m[0]) {
|
||||
dollarTag = m[0];
|
||||
i += dollarTag.length - 1;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (escaped) {
|
||||
escaped = false;
|
||||
continue;
|
||||
}
|
||||
if ((inSingle || inDouble) && ch === '\\') {
|
||||
escaped = true;
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!inDouble && !inBacktick && ch === '\'') {
|
||||
inSingle = !inSingle;
|
||||
continue;
|
||||
}
|
||||
if (!inSingle && !inBacktick && ch === '"') {
|
||||
inDouble = !inDouble;
|
||||
continue;
|
||||
}
|
||||
if (!inSingle && !inDouble && ch === '`') {
|
||||
inBacktick = !inBacktick;
|
||||
continue;
|
||||
}
|
||||
|
||||
if (inSingle || inDouble || inBacktick || dollarTag) continue;
|
||||
|
||||
if (ch === '(') { parenDepth++; continue; }
|
||||
if (ch === ')') { if (parenDepth > 0) parenDepth--; continue; }
|
||||
if (parenDepth !== 0) continue;
|
||||
|
||||
if (!isWord(ch)) continue;
|
||||
|
||||
if (text.slice(i, i + kw.length).toLowerCase() !== kw) continue;
|
||||
const before = i - 1 >= 0 ? text[i - 1] : '';
|
||||
const after = i + kw.length < text.length ? text[i + kw.length] : '';
|
||||
if ((before && isWord(before)) || (after && isWord(after))) continue;
|
||||
return i;
|
||||
}
|
||||
return -1;
|
||||
};
|
||||
|
||||
const applyAutoLimit = (sql: string, dbType: string, maxRows: number): { sql: string; applied: boolean; maxRows: number } => {
|
||||
if (!Number.isFinite(maxRows) || maxRows <= 0) return { sql, applied: false, maxRows };
|
||||
const normalizedType = (dbType || 'mysql').toLowerCase();
|
||||
|
||||
// 只对 SELECT 语句自动加限制
|
||||
const keyword = getLeadingKeyword(sql);
|
||||
if (keyword !== 'SELECT') return { sql, applied: false, maxRows };
|
||||
|
||||
const { main, tail } = splitSqlTail(sql);
|
||||
if (!main.trim()) return { sql, applied: false, maxRows };
|
||||
|
||||
const fromPos = findTopLevelKeyword(main, 'from');
|
||||
const limitPos = findTopLevelKeyword(main, 'limit');
|
||||
// 已有 LIMIT → 不注入
|
||||
if (limitPos >= 0 && (fromPos < 0 || limitPos > fromPos)) return { sql, applied: false, maxRows };
|
||||
const fetchPos = findTopLevelKeyword(main, 'fetch');
|
||||
// 已有 FETCH → 不注入
|
||||
if (fetchPos >= 0 && (fromPos < 0 || fetchPos > fromPos)) return { sql, applied: false, maxRows };
|
||||
|
||||
// SQL Server / mssql: 检查是否已有 TOP,未有则注入 SELECT TOP N
|
||||
if (normalizedType === 'sqlserver' || normalizedType === 'mssql') {
|
||||
const topPos = findTopLevelKeyword(main, 'top');
|
||||
if (topPos >= 0) return { sql, applied: false, maxRows }; // 已有 TOP
|
||||
// 在 SELECT 关键字之后插入 TOP N
|
||||
const selectPos = findTopLevelKeyword(main, 'select');
|
||||
if (selectPos < 0) return { sql, applied: false, maxRows };
|
||||
const afterSelect = selectPos + 'SELECT'.length;
|
||||
// 处理 SELECT DISTINCT 的情况
|
||||
const restAfterSelect = main.slice(afterSelect);
|
||||
const distinctMatch = restAfterSelect.match(/^(\s+DISTINCT\b)/i);
|
||||
const insertOffset = distinctMatch ? afterSelect + distinctMatch[1].length : afterSelect;
|
||||
const nextMain = main.slice(0, insertOffset) + ` TOP ${maxRows}` + main.slice(insertOffset);
|
||||
return { sql: nextMain + tail, applied: true, maxRows };
|
||||
}
|
||||
|
||||
// Oracle / Dameng: 使用 FETCH FIRST N ROWS ONLY(Oracle 12c+ 标准语法)
|
||||
if (normalizedType === 'oracle' || normalizedType === 'dameng') {
|
||||
// 检查是否已有 ROWNUM 限制
|
||||
const rownumPos = findTopLevelKeyword(main, 'rownum');
|
||||
if (rownumPos >= 0) return { sql, applied: false, maxRows };
|
||||
const offsetPos = findTopLevelKeyword(main, 'offset');
|
||||
if (offsetPos >= 0 && (fromPos < 0 || offsetPos > fromPos)) return { sql, applied: false, maxRows };
|
||||
const nextMain = main.trimEnd() + ` FETCH FIRST ${maxRows} ROWS ONLY`;
|
||||
return { sql: nextMain + tail, applied: true, maxRows };
|
||||
}
|
||||
|
||||
// 通用 LIMIT 语法(MySQL, PostgreSQL, SQLite, ClickHouse, DuckDB 等)
|
||||
const offsetPos = findTopLevelKeyword(main, 'offset');
|
||||
const forPos = findTopLevelKeyword(main, 'for');
|
||||
const lockPos = findTopLevelKeyword(main, 'lock');
|
||||
|
||||
const candidates = [offsetPos, forPos, lockPos]
|
||||
.filter(pos => pos >= 0 && (fromPos < 0 || pos > fromPos));
|
||||
|
||||
const insertAt = candidates.length > 0 ? Math.min(...candidates) : main.length;
|
||||
const before = main.slice(0, insertAt).trimEnd();
|
||||
const after = main.slice(insertAt).trimStart();
|
||||
const nextMain = [before, `LIMIT ${maxRows}`, after].filter(Boolean).join(' ').trim();
|
||||
return { sql: nextMain + tail, applied: true, maxRows };
|
||||
};
|
||||
|
||||
const getSelectedSQL = (): string => {
|
||||
const editor = editorRef.current;
|
||||
if (!editor) return '';
|
||||
@@ -1662,8 +1609,10 @@ const QueryEditor: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAc
|
||||
|
||||
try {
|
||||
const rawSQL = getSelectedSQL() || currentQuery;
|
||||
const dbType = String((buildRpcConnectionConfig(config) as any).type || 'mysql');
|
||||
const normalizedDbType = dbType.trim().toLowerCase();
|
||||
const rpcConfig = buildRpcConnectionConfig(config) as any;
|
||||
const dbType = String(rpcConfig.type || 'mysql');
|
||||
const driver = String((config as any).driver || '');
|
||||
const normalizedDbType = String(resolveSqlDialect(dbType, driver)).trim().toLowerCase();
|
||||
const normalizedRawSQL = String(rawSQL || '').replace(/;/g, ';');
|
||||
|
||||
// MongoDB 仍走逐条执行的旧路径
|
||||
@@ -1703,6 +1652,12 @@ const QueryEditor: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAc
|
||||
executedSql = shellConvert.command;
|
||||
}
|
||||
}
|
||||
if (wantsLimitProbe) {
|
||||
const limitResult = applyMongoQueryAutoLimit(executedSql, maxRows);
|
||||
if (limitResult.applied) {
|
||||
executedSql = limitResult.command;
|
||||
}
|
||||
}
|
||||
const startTime = Date.now();
|
||||
let queryId: string;
|
||||
try {
|
||||
@@ -1783,26 +1738,36 @@ const QueryEditor: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAc
|
||||
|
||||
} else {
|
||||
// 非 MongoDB:使用 DBQueryMulti 一次性执行多条 SQL,后端返回多结果集
|
||||
let fullSQL = normalizedRawSQL;
|
||||
if (!fullSQL.trim()) {
|
||||
const sourceStatements = splitSQLStatements(normalizedRawSQL);
|
||||
if (sourceStatements.length === 0) {
|
||||
message.info('没有可执行的 SQL。');
|
||||
setResultSets([]);
|
||||
setActiveResultKey('');
|
||||
return;
|
||||
}
|
||||
|
||||
const forceReadOnlyResult = connCaps.forceReadOnlyQueryResult;
|
||||
const statementPlans: QueryStatementPlan[] = [];
|
||||
for (const statement of sourceStatements) {
|
||||
statementPlans.push(await resolveQueryLocatorPlan({
|
||||
statement,
|
||||
dbType: normalizedDbType,
|
||||
currentDb,
|
||||
config,
|
||||
forceReadOnly: forceReadOnlyResult,
|
||||
}));
|
||||
}
|
||||
|
||||
// 自动给 SELECT 语句注入行数限制(防止大结果集卡死)
|
||||
const maxRowsForLimit = Number(queryOptions?.maxRows) || 0;
|
||||
let anyLimitApplied = false;
|
||||
if (Number.isFinite(maxRowsForLimit) && maxRowsForLimit > 0) {
|
||||
const stmts = splitSQLStatements(fullSQL);
|
||||
const limitedStmts = stmts.map(s => {
|
||||
const result = applyAutoLimit(s, normalizedDbType, maxRowsForLimit);
|
||||
if (result.applied) anyLimitApplied = true;
|
||||
return result.sql;
|
||||
});
|
||||
fullSQL = limitedStmts.join(';\n');
|
||||
}
|
||||
const executablePlans = statementPlans.map((plan) => {
|
||||
if (!Number.isFinite(maxRowsForLimit) || maxRowsForLimit <= 0) return plan;
|
||||
const result = applyQueryAutoLimit(plan.executedSql, normalizedDbType, maxRowsForLimit, driver);
|
||||
if (result.applied) anyLimitApplied = true;
|
||||
return { ...plan, executedSql: result.sql };
|
||||
});
|
||||
const fullSQL = executablePlans.map((plan) => plan.executedSql).join(';\n');
|
||||
|
||||
const startTime = Date.now();
|
||||
let queryId: string;
|
||||
@@ -1859,16 +1824,13 @@ const QueryEditor: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAc
|
||||
const resultSetDataArray = Array.isArray(res.data) ? (res.data as any[]) : [];
|
||||
const nextResultSets: ResultSet[] = [];
|
||||
const maxRows = Number(queryOptions?.maxRows) || 0;
|
||||
const forceReadOnlyResult = connCaps.forceReadOnlyQueryResult;
|
||||
let anyTruncated = false;
|
||||
const pendingPk: Array<{ resultKey: string; tableName: string }> = [];
|
||||
|
||||
// 前端也拆分语句用于匹配原始 SQL(展示和表名检测)
|
||||
const statements = splitSQLStatements(fullSQL);
|
||||
|
||||
for (let idx = 0; idx < resultSetDataArray.length; idx++) {
|
||||
const rsData = resultSetDataArray[idx];
|
||||
const rawStatement = (idx < statements.length) ? statements[idx] : '';
|
||||
const plan = executablePlans[idx];
|
||||
const originalSql = plan?.originalSql || '';
|
||||
const executedSql = plan?.executedSql || originalSql;
|
||||
|
||||
// 检查是否为 affectedRows 类结果集
|
||||
const isAffectedResult = Array.isArray(rsData.rows) && rsData.rows.length === 1
|
||||
@@ -1881,8 +1843,8 @@ const QueryEditor: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAc
|
||||
(row as any)[GONAVI_ROW_KEY] = 0;
|
||||
nextResultSets.push({
|
||||
key: `result-${idx + 1}`,
|
||||
sql: rawStatement,
|
||||
exportSql: rawStatement,
|
||||
sql: executedSql,
|
||||
exportSql: originalSql,
|
||||
rows: [row],
|
||||
columns: ['affectedRows'],
|
||||
pkColumns: [],
|
||||
@@ -1905,32 +1867,18 @@ const QueryEditor: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAc
|
||||
if (row && typeof row === 'object') row[GONAVI_ROW_KEY] = i;
|
||||
});
|
||||
|
||||
let simpleTableName: string | undefined = undefined;
|
||||
if (rawStatement) {
|
||||
// 支持多行 SQL:SELECT [cols] FROM [schema.]table [WHERE...] [ORDER BY...] [LIMIT...] 等
|
||||
// JOIN 查询表名歧义,不提取
|
||||
const hasJoin = /\bJOIN\b/i.test(rawStatement);
|
||||
const tableMatch = !hasJoin
|
||||
? rawStatement.match(/^\s*SELECT\s+.+?\s+FROM\s+(?:[\w`"\[\].]+\.)?[`"\[]?(\w+)[`"\]]?\s*(?:$|[\s;])/im)
|
||||
: null;
|
||||
if (tableMatch) {
|
||||
simpleTableName = tableMatch[1];
|
||||
if (!forceReadOnlyResult) {
|
||||
pendingPk.push({ resultKey: `result-${idx + 1}`, tableName: simpleTableName });
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const tableRef = plan?.tableRef;
|
||||
const editLocator = plan?.editLocator;
|
||||
nextResultSets.push({
|
||||
key: `result-${idx + 1}`,
|
||||
sql: rawStatement,
|
||||
exportSql: rawStatement,
|
||||
sql: executedSql,
|
||||
exportSql: originalSql,
|
||||
rows,
|
||||
columns: cols,
|
||||
tableName: simpleTableName,
|
||||
pkColumns: [],
|
||||
readOnly: true,
|
||||
pkLoading: !!simpleTableName,
|
||||
tableName: tableRef?.tableName,
|
||||
pkColumns: plan?.pkColumns || [],
|
||||
editLocator,
|
||||
readOnly: forceReadOnlyResult || !editLocator || editLocator.readOnly,
|
||||
truncated
|
||||
});
|
||||
}
|
||||
@@ -1939,21 +1887,8 @@ const QueryEditor: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAc
|
||||
setResultSets(nextResultSets);
|
||||
setActiveResultKey(nextResultSets[0]?.key || '');
|
||||
|
||||
pendingPk.forEach(({ resultKey, tableName }) => {
|
||||
DBGetColumns(buildRpcConnectionConfig(config) as any, currentDb, tableName)
|
||||
.then((resCols: any) => {
|
||||
if (runSeqRef.current !== runSeq) return;
|
||||
if (!resCols?.success) {
|
||||
setResultSets(prev => prev.map(rs => rs.key === resultKey ? { ...rs, pkLoading: false, readOnly: false } : rs));
|
||||
return;
|
||||
}
|
||||
const primaryKeys = (resCols.data as ColumnDefinition[]).filter(c => c.key === 'PRI').map(c => c.name);
|
||||
setResultSets(prev => prev.map(rs => rs.key === resultKey ? { ...rs, pkColumns: primaryKeys, pkLoading: false, readOnly: false } : rs));
|
||||
})
|
||||
.catch(() => {
|
||||
if (runSeqRef.current !== runSeq) return;
|
||||
setResultSets(prev => prev.map(rs => rs.key === resultKey ? { ...rs, pkLoading: false, readOnly: false } : rs));
|
||||
});
|
||||
executablePlans.forEach((plan) => {
|
||||
if (plan.warning) message.warning(plan.warning);
|
||||
});
|
||||
|
||||
// 后端附带的提示信息(如数据源不支持原生多语句执行的回退提示)
|
||||
@@ -2486,6 +2421,7 @@ const QueryEditor: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAc
|
||||
dbName={currentDb}
|
||||
connectionId={currentConnectionId}
|
||||
pkColumns={rs.pkColumns}
|
||||
editLocator={rs.editLocator}
|
||||
onReload={() => handleReloadResult(rs.key, rs.sql)}
|
||||
readOnly={rs.readOnly}
|
||||
/>
|
||||
|
||||
@@ -36,9 +36,9 @@ import { Tree, message, Dropdown, MenuProps, Input, Button, Modal, Form, Badge,
|
||||
} from '@ant-design/icons';
|
||||
import { useStore } from '../store';
|
||||
import { buildOverlayWorkbenchTheme } from '../utils/overlayWorkbenchTheme';
|
||||
import { SavedConnection, ExternalSQLTreeEntry, JVMCapability, JVMResourceSummary } from '../types';
|
||||
import { SavedConnection, ExternalSQLTreeEntry, JVMCapability, JVMResourceSummary } from '../types';
|
||||
import { getDbIcon } from './DatabaseIcons';
|
||||
import { DBGetDatabases, DBGetTables, DBQuery, DBShowCreateTable, ExportTable, OpenSQLFile, ExecuteSQLFile, CancelSQLFileExecution, CreateDatabase, RenameDatabase, DropDatabase, RenameTable, DropTable, DropView, DropFunction, RenameView, SelectSQLDirectory, ListSQLDirectory, ReadSQLFile, JVMProbeCapabilities } from '../../wailsjs/go/app/App';
|
||||
import { DBGetDatabases, DBGetTables, DBQuery, DBShowCreateTable, ExportTable, OpenSQLFile, ExecuteSQLFile, CancelSQLFileExecution, CreateDatabase, RenameDatabase, DropDatabase, RenameTable, DropTable, DropView, DropFunction, RenameView, SelectSQLDirectory, ListSQLDirectory, ReadSQLFile, JVMProbeCapabilities, GetDriverStatusList } from '../../wailsjs/go/app/App';
|
||||
import { getTableDataDangerActionMeta, supportsTableTruncateAction, type TableDataDangerActionKind } from './tableDataDangerActions';
|
||||
import { EventsOn } from '../../wailsjs/runtime/runtime';
|
||||
import { isMacLikePlatform, normalizeOpacityForPlatform, resolveAppearanceValues } from '../utils/appearance';
|
||||
@@ -81,6 +81,33 @@ interface BatchObjectItem {
|
||||
dataRef: any;
|
||||
}
|
||||
|
||||
type DriverStatusSnapshot = {
|
||||
type: string;
|
||||
name: string;
|
||||
connectable: boolean;
|
||||
expectedRevision?: string;
|
||||
needsUpdate?: boolean;
|
||||
updateReason?: string;
|
||||
message?: string;
|
||||
};
|
||||
|
||||
const DRIVER_STATUS_CACHE_TTL_MS = 30_000;
|
||||
|
||||
const normalizeDriverType = (value: string): string => {
|
||||
const normalized = String(value || '').trim().toLowerCase();
|
||||
if (normalized === 'postgresql') return 'postgres';
|
||||
if (normalized === 'doris') return 'diros';
|
||||
return normalized;
|
||||
};
|
||||
|
||||
const resolveSavedConnectionDriverType = (conn: SavedConnection | undefined): string => {
|
||||
const type = normalizeDriverType(conn?.config?.type || '');
|
||||
if (type !== 'custom') {
|
||||
return type;
|
||||
}
|
||||
return normalizeDriverType(conn?.config?.driver || '');
|
||||
};
|
||||
|
||||
const SEARCH_SCOPE_OPTIONS: Array<{ value: SearchScope; label: string }> = [
|
||||
{ value: 'smart', label: '智能' },
|
||||
{ value: 'object', label: '表对象' },
|
||||
@@ -211,10 +238,12 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
|
||||
const [autoExpandParent, setAutoExpandParent] = useState(true);
|
||||
const [loadedKeys, setLoadedKeys] = useState<React.Key[]>([]);
|
||||
const [selectedKeys, setSelectedKeys] = useState<React.Key[]>([]);
|
||||
const selectedNodesRef = useRef<any[]>([]);
|
||||
const loadingNodesRef = useRef<Set<string>>(new Set());
|
||||
const clickTimerRef = useRef<ReturnType<typeof setTimeout> | null>(null);
|
||||
const [contextMenu, setContextMenu] = useState<{ x: number, y: number, items: MenuProps['items'] } | null>(null);
|
||||
const selectedNodesRef = useRef<any[]>([]);
|
||||
const loadingNodesRef = useRef<Set<string>>(new Set());
|
||||
const clickTimerRef = useRef<ReturnType<typeof setTimeout> | null>(null);
|
||||
const driverStatusCacheRef = useRef<{ fetchedAt: number; items: Record<string, DriverStatusSnapshot> } | null>(null);
|
||||
const driverUpdateWarningKeysRef = useRef<Set<string>>(new Set());
|
||||
const [contextMenu, setContextMenu] = useState<{ x: number, y: number, items: MenuProps['items'] } | null>(null);
|
||||
|
||||
// Virtual Scroll State
|
||||
const [treeHeight, setTreeHeight] = useState(500);
|
||||
@@ -956,13 +985,72 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
|
||||
const typeLabel = normalizedType === 'PROCEDURE' ? 'P' : 'F';
|
||||
routines.push({ displayName: `${fullName} [${typeLabel}]`, routineName: fullName, routineType: normalizedType });
|
||||
});
|
||||
});
|
||||
return { routines, supported: hasSuccessfulQuery };
|
||||
};
|
||||
});
|
||||
return { routines, supported: hasSuccessfulQuery };
|
||||
};
|
||||
|
||||
const loadDatabases = async (node: any) => {
|
||||
const conn = node.dataRef as SavedConnection;
|
||||
const loadKey = `dbs-${conn.id}`;
|
||||
const fetchDriverStatusMap = async (): Promise<Record<string, DriverStatusSnapshot>> => {
|
||||
const cached = driverStatusCacheRef.current;
|
||||
if (cached && Date.now() - cached.fetchedAt < DRIVER_STATUS_CACHE_TTL_MS) {
|
||||
return cached.items;
|
||||
}
|
||||
const result: Record<string, DriverStatusSnapshot> = {};
|
||||
const res = await GetDriverStatusList('', '');
|
||||
if (!res?.success) {
|
||||
return result;
|
||||
}
|
||||
const data = (res.data || {}) as any;
|
||||
const drivers = Array.isArray(data.drivers) ? data.drivers : [];
|
||||
drivers.forEach((item: any) => {
|
||||
const type = normalizeDriverType(String(item.type || '').trim());
|
||||
if (!type) return;
|
||||
result[type] = {
|
||||
type,
|
||||
name: String(item.name || item.type || type).trim(),
|
||||
connectable: !!item.connectable,
|
||||
expectedRevision: String(item.expectedRevision || '').trim() || undefined,
|
||||
needsUpdate: !!item.needsUpdate,
|
||||
updateReason: String(item.updateReason || '').trim() || undefined,
|
||||
message: String(item.message || '').trim() || undefined,
|
||||
};
|
||||
});
|
||||
driverStatusCacheRef.current = { fetchedAt: Date.now(), items: result };
|
||||
return result;
|
||||
};
|
||||
|
||||
const warnIfConnectionDriverAgentNeedsUpdate = async (conn: SavedConnection) => {
|
||||
try {
|
||||
const driverType = resolveSavedConnectionDriverType(conn);
|
||||
if (!driverType || driverType === 'custom') {
|
||||
return;
|
||||
}
|
||||
const statusMap = await fetchDriverStatusMap();
|
||||
const status = statusMap[driverType];
|
||||
if (!status?.connectable || !status.needsUpdate) {
|
||||
return;
|
||||
}
|
||||
const revisionKey = status.expectedRevision || status.updateReason || status.message || 'unknown';
|
||||
const warningKey = `${conn.id}:${driverType}:${revisionKey}`;
|
||||
if (driverUpdateWarningKeysRef.current.has(warningKey)) {
|
||||
return;
|
||||
}
|
||||
driverUpdateWarningKeysRef.current.add(warningKey);
|
||||
const driverName = status.name || driverType;
|
||||
const reason = status.message || status.updateReason || `${driverName} driver-agent 与当前 GoNavi 版本要求不一致`;
|
||||
message.warning({
|
||||
content: `${driverName} 驱动代理需要重装:${reason}`,
|
||||
key: `driver-agent-update-${conn.id}`,
|
||||
duration: 10,
|
||||
});
|
||||
} catch (error) {
|
||||
console.warn('检查驱动代理更新状态失败', error);
|
||||
}
|
||||
};
|
||||
|
||||
const loadDatabases = async (node: any) => {
|
||||
const conn = node.dataRef as SavedConnection;
|
||||
void warnIfConnectionDriverAgentNeedsUpdate(conn);
|
||||
const loadKey = `dbs-${conn.id}`;
|
||||
if (loadingNodesRef.current.has(loadKey)) return;
|
||||
loadingNodesRef.current.add(loadKey);
|
||||
const config = {
|
||||
@@ -1845,8 +1933,9 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
|
||||
setIsBatchModalOpen(true);
|
||||
};
|
||||
|
||||
const loadDatabasesForBatch = async (conn: SavedConnection) => {
|
||||
const config = {
|
||||
const loadDatabasesForBatch = async (conn: SavedConnection) => {
|
||||
void warnIfConnectionDriverAgentNeedsUpdate(conn);
|
||||
const config = {
|
||||
...conn.config,
|
||||
port: Number(conn.config.port),
|
||||
password: conn.config.password || "",
|
||||
@@ -2154,10 +2243,11 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
|
||||
setIsBatchDbModalOpen(true);
|
||||
};
|
||||
|
||||
const loadDatabasesForDbBatch = async (conn: SavedConnection) => {
|
||||
setBatchConnContext(conn);
|
||||
const loadDatabasesForDbBatch = async (conn: SavedConnection) => {
|
||||
setBatchConnContext(conn);
|
||||
void warnIfConnectionDriverAgentNeedsUpdate(conn);
|
||||
|
||||
const config = {
|
||||
const config = {
|
||||
...conn.config,
|
||||
port: Number(conn.config.port),
|
||||
password: conn.config.password || "",
|
||||
|
||||
@@ -187,6 +187,29 @@ describe('store appearance persistence', () => {
|
||||
expect(useStore.getState().connections[0]?.iconColor).toBe('#2f855a');
|
||||
});
|
||||
|
||||
it('normalizes ClickHouse protocol override when replacing saved connections', async () => {
|
||||
const { useStore } = await importStore();
|
||||
|
||||
useStore.getState().replaceConnections([
|
||||
{
|
||||
id: 'clickhouse-http',
|
||||
name: 'ClickHouse HTTP',
|
||||
config: {
|
||||
id: 'clickhouse-http',
|
||||
type: 'clickhouse',
|
||||
host: 'clickhouse.local',
|
||||
port: 8125,
|
||||
user: 'default',
|
||||
clickHouseProtocol: 'https' as any,
|
||||
},
|
||||
},
|
||||
]);
|
||||
|
||||
expect(useStore.getState().connections[0]?.config.clickHouseProtocol).toBe(
|
||||
'http',
|
||||
);
|
||||
});
|
||||
|
||||
it('keeps legacy global proxy password during hydration until explicit cleanup', async () => {
|
||||
storage.setItem('lite-db-storage', JSON.stringify({
|
||||
state: {
|
||||
|
||||
@@ -163,6 +163,15 @@ const toTrimmedString = (value: unknown, fallback = ""): string => {
|
||||
return fallback;
|
||||
};
|
||||
|
||||
const normalizeClickHouseProtocol = (
|
||||
value: unknown,
|
||||
): "auto" | "http" | "native" => {
|
||||
const text = toTrimmedString(value).toLowerCase();
|
||||
if (text === "http" || text === "https") return "http";
|
||||
if (text === "native" || text === "tcp") return "native";
|
||||
return "auto";
|
||||
};
|
||||
|
||||
const normalizePort = (value: unknown, fallbackPort: number): number => {
|
||||
const parsed = Number(value);
|
||||
if (!Number.isFinite(parsed)) return fallbackPort;
|
||||
@@ -513,6 +522,12 @@ const sanitizeConnectionConfig = (value: unknown): ConnectionConfig => {
|
||||
safeConfig.redisDB = normalizeIntegerInRange(raw.redisDB, 0, 0, 15);
|
||||
}
|
||||
|
||||
if (type === "clickhouse") {
|
||||
safeConfig.clickHouseProtocol = normalizeClickHouseProtocol(
|
||||
raw.clickHouseProtocol,
|
||||
);
|
||||
}
|
||||
|
||||
if (type === "custom") {
|
||||
safeConfig.driver = toTrimmedString(raw.driver);
|
||||
safeConfig.dsn = toTrimmedString(raw.dsn).slice(0, MAX_URI_LENGTH);
|
||||
|
||||
@@ -297,6 +297,7 @@ export interface ConnectionConfig {
|
||||
timeout?: number;
|
||||
redisDB?: number; // Redis database index (0-15)
|
||||
uri?: string; // Connection URI for copy/paste
|
||||
clickHouseProtocol?: "auto" | "http" | "native"; // ClickHouse connection protocol override
|
||||
hosts?: string[]; // Multi-host addresses: host:port
|
||||
topology?: "single" | "replica" | "cluster";
|
||||
mysqlReplicaUser?: string;
|
||||
|
||||
@@ -39,6 +39,19 @@ describe('buildRpcConnectionConfig', () => {
|
||||
expect(result.database).toBe('app');
|
||||
});
|
||||
|
||||
it('preserves ClickHouse protocol override for RPC calls', () => {
|
||||
const result = buildRpcConnectionConfig({
|
||||
id: 'conn-clickhouse',
|
||||
type: 'clickhouse',
|
||||
host: 'clickhouse.local',
|
||||
port: 8125,
|
||||
user: 'default',
|
||||
clickHouseProtocol: 'http',
|
||||
} as any);
|
||||
|
||||
expect(result.clickHouseProtocol).toBe('http');
|
||||
});
|
||||
|
||||
it('fills default nested config blocks needed by RPC calls', () => {
|
||||
const result = buildRpcConnectionConfig({
|
||||
id: 'conn-redis',
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
import { describe, expect, it } from 'vitest';
|
||||
|
||||
import { convertMongoShellToJsonCommand } from './mongodb';
|
||||
import { applyMongoQueryAutoLimit, convertMongoShellToJsonCommand } from './mongodb';
|
||||
|
||||
const parseCommand = (command: string | undefined) => JSON.parse(command || '{}');
|
||||
|
||||
describe('convertMongoShellToJsonCommand', () => {
|
||||
it('converts show dbs shell shortcut to listDatabases command', () => {
|
||||
@@ -16,4 +18,105 @@ describe('convertMongoShellToJsonCommand', () => {
|
||||
command: JSON.stringify({ listCollections: 1, filter: {}, nameOnly: true }),
|
||||
});
|
||||
});
|
||||
|
||||
it('converts find shell commands without adding implicit limit', () => {
|
||||
const result = convertMongoShellToJsonCommand('db.users.find({ active: true })');
|
||||
|
||||
expect(result.recognized).toBe(true);
|
||||
expect(parseCommand(result.command)).toEqual({
|
||||
find: 'users',
|
||||
filter: { active: true },
|
||||
});
|
||||
});
|
||||
|
||||
it('keeps explicit find limit values from shell commands', () => {
|
||||
const result = convertMongoShellToJsonCommand('db.users.find({}).limit(10)');
|
||||
|
||||
expect(parseCommand(result.command)).toEqual({
|
||||
find: 'users',
|
||||
filter: {},
|
||||
limit: 10,
|
||||
});
|
||||
});
|
||||
|
||||
it('keeps explicit zero limit values from shell commands', () => {
|
||||
const result = convertMongoShellToJsonCommand('db.users.find({}).limit(0)');
|
||||
|
||||
expect(parseCommand(result.command)).toEqual({
|
||||
find: 'users',
|
||||
filter: {},
|
||||
limit: 0,
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe('applyMongoQueryAutoLimit', () => {
|
||||
it('adds limit to raw Mongo find commands', () => {
|
||||
const result = applyMongoQueryAutoLimit('{"find":"users","filter":{}}', 500);
|
||||
|
||||
expect(result.applied).toBe(true);
|
||||
expect(parseCommand(result.command)).toEqual({
|
||||
find: 'users',
|
||||
filter: {},
|
||||
limit: 500,
|
||||
});
|
||||
});
|
||||
|
||||
it('adds limit after shell find conversion', () => {
|
||||
const shell = convertMongoShellToJsonCommand('db.users.find({ active: true })');
|
||||
const result = applyMongoQueryAutoLimit(shell.command || '', 500);
|
||||
|
||||
expect(result.applied).toBe(true);
|
||||
expect(parseCommand(result.command)).toEqual({
|
||||
find: 'users',
|
||||
filter: { active: true },
|
||||
limit: 500,
|
||||
});
|
||||
});
|
||||
|
||||
it('does not replace explicit find limits', () => {
|
||||
const result = applyMongoQueryAutoLimit('{"find":"users","filter":{},"limit":10}', 500);
|
||||
|
||||
expect(result.applied).toBe(false);
|
||||
expect(parseCommand(result.command)).toEqual({
|
||||
find: 'users',
|
||||
filter: {},
|
||||
limit: 10,
|
||||
});
|
||||
});
|
||||
|
||||
it('adds $limit to read-only aggregate pipelines', () => {
|
||||
const result = applyMongoQueryAutoLimit('{"aggregate":"users","pipeline":[{"$match":{"active":true}}],"cursor":{}}', 500);
|
||||
|
||||
expect(result.applied).toBe(true);
|
||||
expect(parseCommand(result.command)).toEqual({
|
||||
aggregate: 'users',
|
||||
pipeline: [
|
||||
{ $match: { active: true } },
|
||||
{ $limit: 500 },
|
||||
],
|
||||
cursor: {},
|
||||
});
|
||||
});
|
||||
|
||||
it('does not add another aggregate $limit', () => {
|
||||
const command = '{"aggregate":"users","pipeline":[{"$limit":10}],"cursor":{}}';
|
||||
const result = applyMongoQueryAutoLimit(command, 500);
|
||||
|
||||
expect(result.applied).toBe(false);
|
||||
expect(result.command).toBe(command);
|
||||
});
|
||||
|
||||
it('does not alter aggregate write pipelines', () => {
|
||||
const command = '{"aggregate":"users","pipeline":[{"$match":{}},{"$out":"tmp_users"}],"cursor":{}}';
|
||||
const result = applyMongoQueryAutoLimit(command, 500);
|
||||
|
||||
expect(result.applied).toBe(false);
|
||||
expect(result.command).toBe(command);
|
||||
});
|
||||
|
||||
it('does not limit non-read or invalid commands', () => {
|
||||
expect(applyMongoQueryAutoLimit('{"count":"users","query":{}}', 500).applied).toBe(false);
|
||||
expect(applyMongoQueryAutoLimit('db.users.find({})', 500).applied).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
@@ -321,7 +321,7 @@ const parseCollectionAndMethod = (raw: string): {
|
||||
pos = nextPos;
|
||||
} else {
|
||||
let end = pos;
|
||||
while (end < input.length && /[A-Za-z0-9_$.-]/.test(input[end])) end++;
|
||||
while (end < input.length && /[A-Za-z0-9_$-]/.test(input[end])) end++;
|
||||
collection = input.slice(pos, end).trim();
|
||||
pos = end;
|
||||
}
|
||||
@@ -662,7 +662,7 @@ export const buildMongoFindCommand = (params: {
|
||||
if (params.sort && Object.keys(params.sort).length > 0) {
|
||||
command.sort = params.sort;
|
||||
}
|
||||
if (Number.isFinite(params.limit) && Number(params.limit) > 0) {
|
||||
if (Number.isFinite(params.limit) && Number(params.limit) >= 0) {
|
||||
command.limit = Math.floor(Number(params.limit));
|
||||
}
|
||||
if (Number.isFinite(params.skip) && Number(params.skip) > 0) {
|
||||
@@ -678,6 +678,45 @@ export const buildMongoCountCommand = (collection: string, filter: Record<string
|
||||
});
|
||||
};
|
||||
|
||||
const hasOwn = (obj: Record<string, unknown>, key: string) => Object.prototype.hasOwnProperty.call(obj, key);
|
||||
|
||||
const isMongoCommandObject = (value: unknown): value is Record<string, unknown> => (
|
||||
!!value && typeof value === 'object' && !Array.isArray(value)
|
||||
);
|
||||
|
||||
export const applyMongoQueryAutoLimit = (
|
||||
command: string,
|
||||
maxRows: number,
|
||||
): { command: string; applied: boolean; maxRows: number } => {
|
||||
if (!Number.isFinite(maxRows) || maxRows <= 0) return { command, applied: false, maxRows };
|
||||
|
||||
let parsed: unknown;
|
||||
try {
|
||||
parsed = JSON.parse(String(command || '').trim());
|
||||
} catch {
|
||||
return { command, applied: false, maxRows };
|
||||
}
|
||||
if (!isMongoCommandObject(parsed)) return { command, applied: false, maxRows };
|
||||
|
||||
const nextMaxRows = Math.floor(Number(maxRows));
|
||||
if (hasOwn(parsed, 'find')) {
|
||||
if (hasOwn(parsed, 'limit')) return { command, applied: false, maxRows };
|
||||
parsed.limit = nextMaxRows;
|
||||
return { command: JSON.stringify(parsed), applied: true, maxRows };
|
||||
}
|
||||
|
||||
if (hasOwn(parsed, 'aggregate') && Array.isArray(parsed.pipeline)) {
|
||||
const pipeline = parsed.pipeline as unknown[];
|
||||
const hasExplicitLimit = pipeline.some((stage) => isMongoCommandObject(stage) && hasOwn(stage, '$limit'));
|
||||
const hasWriteStage = pipeline.some((stage) => isMongoCommandObject(stage) && (hasOwn(stage, '$out') || hasOwn(stage, '$merge')));
|
||||
if (hasExplicitLimit || hasWriteStage) return { command, applied: false, maxRows };
|
||||
pipeline.push({ $limit: nextMaxRows });
|
||||
return { command: JSON.stringify(parsed), applied: true, maxRows };
|
||||
}
|
||||
|
||||
return { command, applied: false, maxRows };
|
||||
};
|
||||
|
||||
const buildMongoInsertCommand = (
|
||||
collection: string,
|
||||
documents: Record<string, unknown>[],
|
||||
|
||||
110
frontend/src/utils/queryAutoLimit.test.ts
Normal file
110
frontend/src/utils/queryAutoLimit.test.ts
Normal file
@@ -0,0 +1,110 @@
|
||||
import { describe, expect, it } from 'vitest';
|
||||
|
||||
import { applyQueryAutoLimit } from './queryAutoLimit';
|
||||
|
||||
describe('applyQueryAutoLimit', () => {
|
||||
const limitDialects = [
|
||||
'mysql',
|
||||
'mariadb',
|
||||
'diros',
|
||||
'doris',
|
||||
'sphinx',
|
||||
'postgres',
|
||||
'postgresql',
|
||||
'kingbase',
|
||||
'kingbase8',
|
||||
'highgo',
|
||||
'vastbase',
|
||||
'sqlite',
|
||||
'sqlite3',
|
||||
'duckdb',
|
||||
'clickhouse',
|
||||
'tdengine',
|
||||
];
|
||||
|
||||
it.each(limitDialects)('adds generic LIMIT for %s connections', (dbType) => {
|
||||
expect(applyQueryAutoLimit('SELECT * FROM users', dbType, 500).sql)
|
||||
.toBe('SELECT * FROM users LIMIT 500');
|
||||
});
|
||||
|
||||
it.each([
|
||||
['oracle'],
|
||||
['dameng'],
|
||||
['dm'],
|
||||
['dm8'],
|
||||
])('adds FETCH FIRST limit for %s connections', (dbType) => {
|
||||
expect(applyQueryAutoLimit('SELECT * FROM MYCIMLED.EDC_LOG', dbType, 500).sql)
|
||||
.toBe('SELECT * FROM MYCIMLED.EDC_LOG FETCH FIRST 500 ROWS ONLY');
|
||||
});
|
||||
|
||||
it.each([
|
||||
['sqlserver'],
|
||||
['mssql'],
|
||||
['sql_server'],
|
||||
['sql-server'],
|
||||
])('adds TOP limit for %s connections', (dbType) => {
|
||||
expect(applyQueryAutoLimit('SELECT * FROM users', dbType, 500).sql)
|
||||
.toBe('SELECT TOP 500 * FROM users');
|
||||
});
|
||||
|
||||
it('adds SQL Server TOP after DISTINCT', () => {
|
||||
expect(applyQueryAutoLimit('SELECT DISTINCT name FROM users', 'sqlserver', 500).sql)
|
||||
.toBe('SELECT DISTINCT TOP 500 name FROM users');
|
||||
});
|
||||
|
||||
it.each([
|
||||
['oracle', 'SELECT * FROM users FETCH FIRST 500 ROWS ONLY'],
|
||||
['dm8', 'SELECT * FROM users FETCH FIRST 500 ROWS ONLY'],
|
||||
['mssql', 'SELECT TOP 500 * FROM users'],
|
||||
['postgresql', 'SELECT * FROM users LIMIT 500'],
|
||||
['doris', 'SELECT * FROM users LIMIT 500'],
|
||||
['sqlite3', 'SELECT * FROM users LIMIT 500'],
|
||||
])('uses custom driver dialect %s', (driver, expected) => {
|
||||
expect(applyQueryAutoLimit('SELECT * FROM users', 'custom', 500, driver).sql)
|
||||
.toBe(expected);
|
||||
});
|
||||
|
||||
it('keeps trailing semicolon and comments after injected Oracle limit', () => {
|
||||
expect(applyQueryAutoLimit('SELECT * FROM MYCIMLED.EDC_LOG; -- preview', 'oracle', 500).sql)
|
||||
.toBe('SELECT * FROM MYCIMLED.EDC_LOG FETCH FIRST 500 ROWS ONLY; -- preview');
|
||||
});
|
||||
|
||||
it('does not add another generic limit when SQL already limits rows', () => {
|
||||
expect(applyQueryAutoLimit('SELECT * FROM users LIMIT 10', 'mysql', 500).applied)
|
||||
.toBe(false);
|
||||
expect(applyQueryAutoLimit('SELECT * FROM users OFFSET 10 LIMIT 10', 'postgres', 500).applied)
|
||||
.toBe(false);
|
||||
});
|
||||
|
||||
it('does not treat nested LIMIT as the outer query limit', () => {
|
||||
expect(applyQueryAutoLimit('SELECT * FROM (SELECT * FROM users LIMIT 10) t', 'postgres', 500).sql)
|
||||
.toBe('SELECT * FROM (SELECT * FROM users LIMIT 10) t LIMIT 500');
|
||||
});
|
||||
|
||||
it('does not add another Oracle limit when Oracle SQL already limits rows', () => {
|
||||
expect(applyQueryAutoLimit('SELECT * FROM users WHERE ROWNUM <= 10', 'oracle', 500).applied)
|
||||
.toBe(false);
|
||||
expect(applyQueryAutoLimit('SELECT * FROM users FETCH FIRST 10 ROWS ONLY', 'oracle', 500).applied)
|
||||
.toBe(false);
|
||||
});
|
||||
|
||||
it('does not add another SQL Server limit when SQL already uses TOP', () => {
|
||||
expect(applyQueryAutoLimit('SELECT TOP 10 * FROM users', 'sqlserver', 500).applied)
|
||||
.toBe(false);
|
||||
});
|
||||
|
||||
it('adds generic LIMIT before locking clauses', () => {
|
||||
expect(applyQueryAutoLimit('SELECT * FROM users FOR UPDATE', 'mysql', 500).sql)
|
||||
.toBe('SELECT * FROM users LIMIT 500 FOR UPDATE');
|
||||
});
|
||||
|
||||
it('adds generic LIMIT before OFFSET clauses', () => {
|
||||
expect(applyQueryAutoLimit('SELECT * FROM users OFFSET 10', 'postgres', 500).sql)
|
||||
.toBe('SELECT * FROM users LIMIT 500 OFFSET 10');
|
||||
});
|
||||
|
||||
it('does not limit non-select statements', () => {
|
||||
expect(applyQueryAutoLimit('UPDATE users SET name = \'a\'', 'mysql', 500).applied)
|
||||
.toBe(false);
|
||||
});
|
||||
});
|
||||
336
frontend/src/utils/queryAutoLimit.ts
Normal file
336
frontend/src/utils/queryAutoLimit.ts
Normal file
@@ -0,0 +1,336 @@
|
||||
import { resolveSqlDialect } from './sqlDialect';
|
||||
|
||||
const isWS = (ch: string) => ch === ' ' || ch === '\t' || ch === '\n' || ch === '\r';
|
||||
const isWord = (ch: string) => /[A-Za-z0-9_]/.test(ch);
|
||||
|
||||
const getLeadingKeyword = (sql: string): string => {
|
||||
const text = (sql || '').replace(/\r\n/g, '\n');
|
||||
let inSingle = false;
|
||||
let inDouble = false;
|
||||
let inBacktick = false;
|
||||
let escaped = false;
|
||||
let inLineComment = false;
|
||||
let inBlockComment = false;
|
||||
let dollarTag: string | null = null;
|
||||
|
||||
for (let i = 0; i < text.length; i++) {
|
||||
const ch = text[i];
|
||||
const next = i + 1 < text.length ? text[i + 1] : '';
|
||||
const prev = i > 0 ? text[i - 1] : '';
|
||||
const next2 = i + 2 < text.length ? text[i + 2] : '';
|
||||
|
||||
if (!inSingle && !inDouble && !inBacktick) {
|
||||
if (inLineComment) {
|
||||
if (ch === '\n') inLineComment = false;
|
||||
continue;
|
||||
}
|
||||
if (inBlockComment) {
|
||||
if (ch === '*' && next === '/') {
|
||||
i++;
|
||||
inBlockComment = false;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
if (ch === '/' && next === '*') {
|
||||
i++;
|
||||
inBlockComment = true;
|
||||
continue;
|
||||
}
|
||||
if (ch === '#') {
|
||||
inLineComment = true;
|
||||
continue;
|
||||
}
|
||||
if (ch === '-' && next === '-' && (i === 0 || isWS(prev)) && (next2 === '' || isWS(next2))) {
|
||||
i++;
|
||||
inLineComment = true;
|
||||
continue;
|
||||
}
|
||||
if (dollarTag) {
|
||||
if (text.startsWith(dollarTag, i)) {
|
||||
i += dollarTag.length - 1;
|
||||
dollarTag = null;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
if (ch === '$') {
|
||||
const m = text.slice(i).match(/^\$[A-Za-z0-9_]*\$/);
|
||||
if (m && m[0]) {
|
||||
dollarTag = m[0];
|
||||
i += dollarTag.length - 1;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (escaped) {
|
||||
escaped = false;
|
||||
continue;
|
||||
}
|
||||
if ((inSingle || inDouble) && ch === '\\') {
|
||||
escaped = true;
|
||||
continue;
|
||||
}
|
||||
if (!inDouble && !inBacktick && ch === "'") {
|
||||
inSingle = !inSingle;
|
||||
continue;
|
||||
}
|
||||
if (!inSingle && !inBacktick && ch === '"') {
|
||||
inDouble = !inDouble;
|
||||
continue;
|
||||
}
|
||||
if (!inSingle && !inDouble && ch === '`') {
|
||||
inBacktick = !inBacktick;
|
||||
continue;
|
||||
}
|
||||
if (inSingle || inDouble || inBacktick || dollarTag) continue;
|
||||
if (isWS(ch)) continue;
|
||||
if (isWord(ch)) {
|
||||
let j = i;
|
||||
while (j < text.length && isWord(text[j])) j++;
|
||||
return text.slice(i, j).toLowerCase();
|
||||
}
|
||||
return '';
|
||||
}
|
||||
return '';
|
||||
};
|
||||
|
||||
const splitSqlTail = (sql: string): { main: string; tail: string } => {
|
||||
const text = (sql || '').replace(/\r\n/g, '\n');
|
||||
let inSingle = false;
|
||||
let inDouble = false;
|
||||
let inBacktick = false;
|
||||
let escaped = false;
|
||||
let inLineComment = false;
|
||||
let inBlockComment = false;
|
||||
let dollarTag: string | null = null;
|
||||
let lastMeaningful = -1;
|
||||
|
||||
for (let i = 0; i < text.length; i++) {
|
||||
const ch = text[i];
|
||||
const next = i + 1 < text.length ? text[i + 1] : '';
|
||||
const prev = i > 0 ? text[i - 1] : '';
|
||||
const next2 = i + 2 < text.length ? text[i + 2] : '';
|
||||
|
||||
if (!inSingle && !inDouble && !inBacktick) {
|
||||
if (dollarTag) {
|
||||
if (text.startsWith(dollarTag, i)) {
|
||||
lastMeaningful = i + dollarTag.length - 1;
|
||||
i += dollarTag.length - 1;
|
||||
dollarTag = null;
|
||||
} else if (!isWS(ch)) {
|
||||
lastMeaningful = i;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
if (inLineComment) {
|
||||
if (ch === '\n') inLineComment = false;
|
||||
continue;
|
||||
}
|
||||
if (inBlockComment) {
|
||||
if (ch === '*' && next === '/') {
|
||||
i++;
|
||||
inBlockComment = false;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
if (ch === '/' && next === '*') {
|
||||
i++;
|
||||
inBlockComment = true;
|
||||
continue;
|
||||
}
|
||||
if (ch === '#') {
|
||||
inLineComment = true;
|
||||
continue;
|
||||
}
|
||||
if (ch === '-' && next === '-' && (i === 0 || isWS(prev)) && (next2 === '' || isWS(next2))) {
|
||||
i++;
|
||||
inLineComment = true;
|
||||
continue;
|
||||
}
|
||||
if (ch === '$') {
|
||||
const m = text.slice(i).match(/^\$[A-Za-z0-9_]*\$/);
|
||||
if (m && m[0]) {
|
||||
dollarTag = m[0];
|
||||
lastMeaningful = i + dollarTag.length - 1;
|
||||
i += dollarTag.length - 1;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (escaped) {
|
||||
escaped = false;
|
||||
} else if ((inSingle || inDouble) && ch === '\\') {
|
||||
escaped = true;
|
||||
} else {
|
||||
if (!inDouble && !inBacktick && ch === "'") inSingle = !inSingle;
|
||||
else if (!inSingle && !inBacktick && ch === '"') inDouble = !inDouble;
|
||||
else if (!inSingle && !inDouble && ch === '`') inBacktick = !inBacktick;
|
||||
}
|
||||
|
||||
if (!inLineComment && !inBlockComment && !isWS(ch)) {
|
||||
lastMeaningful = i;
|
||||
}
|
||||
}
|
||||
|
||||
if (lastMeaningful < 0) return { main: '', tail: text };
|
||||
let mainEnd = lastMeaningful + 1;
|
||||
while (mainEnd > 0 && (isWS(text[mainEnd - 1]) || text[mainEnd - 1] === ';' || text[mainEnd - 1] === ';')) {
|
||||
mainEnd--;
|
||||
}
|
||||
return { main: text.slice(0, mainEnd), tail: text.slice(mainEnd) };
|
||||
};
|
||||
|
||||
const findTopLevelKeyword = (sql: string, keyword: string): number => {
|
||||
const text = sql;
|
||||
const kw = keyword.toLowerCase();
|
||||
let inSingle = false;
|
||||
let inDouble = false;
|
||||
let inBacktick = false;
|
||||
let escaped = false;
|
||||
let inLineComment = false;
|
||||
let inBlockComment = false;
|
||||
let dollarTag: string | null = null;
|
||||
let parenDepth = 0;
|
||||
|
||||
for (let i = 0; i < text.length; i++) {
|
||||
const ch = text[i];
|
||||
const next = i + 1 < text.length ? text[i + 1] : '';
|
||||
const prev = i > 0 ? text[i - 1] : '';
|
||||
const next2 = i + 2 < text.length ? text[i + 2] : '';
|
||||
|
||||
if (!inSingle && !inDouble && !inBacktick) {
|
||||
if (inLineComment) {
|
||||
if (ch === '\n') inLineComment = false;
|
||||
continue;
|
||||
}
|
||||
if (inBlockComment) {
|
||||
if (ch === '*' && next === '/') {
|
||||
i++;
|
||||
inBlockComment = false;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
if (ch === '/' && next === '*') {
|
||||
i++;
|
||||
inBlockComment = true;
|
||||
continue;
|
||||
}
|
||||
if (ch === '#') {
|
||||
inLineComment = true;
|
||||
continue;
|
||||
}
|
||||
if (ch === '-' && next === '-' && (i === 0 || isWS(prev)) && (next2 === '' || isWS(next2))) {
|
||||
i++;
|
||||
inLineComment = true;
|
||||
continue;
|
||||
}
|
||||
if (dollarTag) {
|
||||
if (text.startsWith(dollarTag, i)) {
|
||||
i += dollarTag.length - 1;
|
||||
dollarTag = null;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
if (ch === '$') {
|
||||
const m = text.slice(i).match(/^\$[A-Za-z0-9_]*\$/);
|
||||
if (m && m[0]) {
|
||||
dollarTag = m[0];
|
||||
i += dollarTag.length - 1;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (escaped) {
|
||||
escaped = false;
|
||||
continue;
|
||||
}
|
||||
if ((inSingle || inDouble) && ch === '\\') {
|
||||
escaped = true;
|
||||
continue;
|
||||
}
|
||||
if (!inDouble && !inBacktick && ch === "'") {
|
||||
inSingle = !inSingle;
|
||||
continue;
|
||||
}
|
||||
if (!inSingle && !inBacktick && ch === '"') {
|
||||
inDouble = !inDouble;
|
||||
continue;
|
||||
}
|
||||
if (!inSingle && !inDouble && ch === '`') {
|
||||
inBacktick = !inBacktick;
|
||||
continue;
|
||||
}
|
||||
if (inSingle || inDouble || inBacktick || dollarTag) continue;
|
||||
if (ch === '(') {
|
||||
parenDepth++;
|
||||
continue;
|
||||
}
|
||||
if (ch === ')') {
|
||||
if (parenDepth > 0) parenDepth--;
|
||||
continue;
|
||||
}
|
||||
if (parenDepth !== 0) continue;
|
||||
if (!isWord(ch)) continue;
|
||||
if (text.slice(i, i + kw.length).toLowerCase() !== kw) continue;
|
||||
const before = i - 1 >= 0 ? text[i - 1] : '';
|
||||
const after = i + kw.length < text.length ? text[i + kw.length] : '';
|
||||
if ((before && isWord(before)) || (after && isWord(after))) continue;
|
||||
return i;
|
||||
}
|
||||
return -1;
|
||||
};
|
||||
|
||||
export const applyQueryAutoLimit = (
|
||||
sql: string,
|
||||
dbType: string,
|
||||
maxRows: number,
|
||||
driver = '',
|
||||
): { sql: string; applied: boolean; maxRows: number } => {
|
||||
if (!Number.isFinite(maxRows) || maxRows <= 0) return { sql, applied: false, maxRows };
|
||||
const normalizedType = String(resolveSqlDialect(dbType || 'mysql', driver)).toLowerCase();
|
||||
const keyword = getLeadingKeyword(sql);
|
||||
if (keyword !== 'select') return { sql, applied: false, maxRows };
|
||||
|
||||
const { main, tail } = splitSqlTail(sql);
|
||||
if (!main.trim()) return { sql, applied: false, maxRows };
|
||||
|
||||
const fromPos = findTopLevelKeyword(main, 'from');
|
||||
const limitPos = findTopLevelKeyword(main, 'limit');
|
||||
if (limitPos >= 0 && (fromPos < 0 || limitPos > fromPos)) return { sql, applied: false, maxRows };
|
||||
const fetchPos = findTopLevelKeyword(main, 'fetch');
|
||||
if (fetchPos >= 0 && (fromPos < 0 || fetchPos > fromPos)) return { sql, applied: false, maxRows };
|
||||
|
||||
if (normalizedType === 'sqlserver' || normalizedType === 'mssql') {
|
||||
const topPos = findTopLevelKeyword(main, 'top');
|
||||
if (topPos >= 0) return { sql, applied: false, maxRows };
|
||||
const selectPos = findTopLevelKeyword(main, 'select');
|
||||
if (selectPos < 0) return { sql, applied: false, maxRows };
|
||||
const afterSelect = selectPos + 'SELECT'.length;
|
||||
const restAfterSelect = main.slice(afterSelect);
|
||||
const distinctMatch = restAfterSelect.match(/^(\s+DISTINCT\b)/i);
|
||||
const insertOffset = distinctMatch ? afterSelect + distinctMatch[1].length : afterSelect;
|
||||
const nextMain = main.slice(0, insertOffset) + ` TOP ${maxRows}` + main.slice(insertOffset);
|
||||
return { sql: nextMain + tail, applied: true, maxRows };
|
||||
}
|
||||
|
||||
if (normalizedType === 'oracle' || normalizedType === 'dameng') {
|
||||
const rownumPos = findTopLevelKeyword(main, 'rownum');
|
||||
if (rownumPos >= 0) return { sql, applied: false, maxRows };
|
||||
const offsetPos = findTopLevelKeyword(main, 'offset');
|
||||
if (offsetPos >= 0 && (fromPos < 0 || offsetPos > fromPos)) return { sql, applied: false, maxRows };
|
||||
return { sql: `${main.trimEnd()} FETCH FIRST ${maxRows} ROWS ONLY${tail}`, applied: true, maxRows };
|
||||
}
|
||||
|
||||
const offsetPos = findTopLevelKeyword(main, 'offset');
|
||||
const forPos = findTopLevelKeyword(main, 'for');
|
||||
const lockPos = findTopLevelKeyword(main, 'lock');
|
||||
const candidates = [offsetPos, forPos, lockPos]
|
||||
.filter(pos => pos >= 0 && (fromPos < 0 || pos > fromPos));
|
||||
const insertAt = candidates.length > 0 ? Math.min(...candidates) : main.length;
|
||||
const before = main.slice(0, insertAt).trimEnd();
|
||||
const after = main.slice(insertAt).trimStart();
|
||||
const nextMain = [before, `LIMIT ${maxRows}`, after].filter(Boolean).join(' ').trim();
|
||||
return { sql: nextMain + tail, applied: true, maxRows };
|
||||
};
|
||||
44
frontend/src/utils/queryResultTable.test.ts
Normal file
44
frontend/src/utils/queryResultTable.test.ts
Normal file
@@ -0,0 +1,44 @@
|
||||
import { describe, expect, it } from 'vitest';
|
||||
|
||||
import { extractQueryResultTableRef } from './queryResultTable';
|
||||
|
||||
describe('extractQueryResultTableRef', () => {
|
||||
it('preserves Oracle schema-qualified table names for editing', () => {
|
||||
expect(extractQueryResultTableRef('SELECT * FROM MYCIMLED.EDC_LOG FETCH FIRST 500 ROWS ONLY', 'oracle', 'ANONYMOUS'))
|
||||
.toEqual({
|
||||
tableName: 'MYCIMLED.EDC_LOG',
|
||||
metadataDbName: 'MYCIMLED',
|
||||
metadataTableName: 'EDC_LOG',
|
||||
});
|
||||
});
|
||||
|
||||
it('uses current schema for unqualified Oracle tables', () => {
|
||||
expect(extractQueryResultTableRef('SELECT * FROM EDC_LOG', 'oracle', 'MYCIMLED'))
|
||||
.toEqual({
|
||||
tableName: 'EDC_LOG',
|
||||
metadataDbName: 'MYCIMLED',
|
||||
metadataTableName: 'EDC_LOG',
|
||||
});
|
||||
});
|
||||
|
||||
it('keeps existing simple table behavior for MySQL-style qualified names', () => {
|
||||
expect(extractQueryResultTableRef('SELECT * FROM app.users LIMIT 500', 'mysql', 'app'))
|
||||
.toEqual({
|
||||
tableName: 'users',
|
||||
metadataDbName: 'app',
|
||||
metadataTableName: 'users',
|
||||
});
|
||||
});
|
||||
|
||||
it('does not mark join results as editable table refs', () => {
|
||||
expect(extractQueryResultTableRef('SELECT * FROM users u JOIN orders o ON u.id = o.user_id', 'oracle', 'APP'))
|
||||
.toBeUndefined();
|
||||
});
|
||||
|
||||
it('does not mark grouped or distinct results as editable table refs', () => {
|
||||
expect(extractQueryResultTableRef('SELECT ID FROM users GROUP BY ID', 'mysql', 'app'))
|
||||
.toBeUndefined();
|
||||
expect(extractQueryResultTableRef('SELECT DISTINCT ID FROM users', 'mysql', 'app'))
|
||||
.toBeUndefined();
|
||||
});
|
||||
});
|
||||
64
frontend/src/utils/queryResultTable.ts
Normal file
64
frontend/src/utils/queryResultTable.ts
Normal file
@@ -0,0 +1,64 @@
|
||||
export type QueryResultTableRef = {
|
||||
tableName: string;
|
||||
metadataDbName: string;
|
||||
metadataTableName: string;
|
||||
};
|
||||
|
||||
const stripIdentifierQuotes = (part: string): string => {
|
||||
const text = String(part || '').trim();
|
||||
if (!text) return '';
|
||||
if ((text.startsWith('`') && text.endsWith('`')) || (text.startsWith('"') && text.endsWith('"'))) {
|
||||
return text.slice(1, -1).trim();
|
||||
}
|
||||
if (text.startsWith('[') && text.endsWith(']')) {
|
||||
return text.slice(1, -1).trim();
|
||||
}
|
||||
return text;
|
||||
};
|
||||
|
||||
const normalizeQualifiedName = (raw: string): string => (
|
||||
String(raw || '')
|
||||
.split('.')
|
||||
.map((part) => stripIdentifierQuotes(part.trim()))
|
||||
.filter(Boolean)
|
||||
.join('.')
|
||||
);
|
||||
|
||||
const isOracleLikeDialect = (dialect: string): boolean => {
|
||||
const normalized = String(dialect || '').trim().toLowerCase();
|
||||
return normalized === 'oracle' || normalized === 'dameng' || normalized === 'dm' || normalized === 'dm8';
|
||||
};
|
||||
|
||||
export const extractQueryResultTableRef = (
|
||||
sql: string,
|
||||
dialect: string,
|
||||
currentDb: string,
|
||||
): QueryResultTableRef | undefined => {
|
||||
const text = String(sql || '').trim();
|
||||
if (!text) return undefined;
|
||||
if (/\b(JOIN|UNION|INTERSECT|EXCEPT|MINUS)\b/i.test(text)) return undefined;
|
||||
if (/^\s*SELECT\s+DISTINCT\b/i.test(text)) return undefined;
|
||||
if (/\bGROUP\s+BY\b|\bHAVING\b/i.test(text)) return undefined;
|
||||
|
||||
const tableMatch = text.match(/^\s*SELECT\s+.+?\s+FROM\s+((?:[`"\[]?\w+[`"\]]?)(?:\s*\.\s*(?:[`"\[]?\w+[`"\]]?)){0,2})\s*(?:$|[\s;])/im);
|
||||
if (!tableMatch) return undefined;
|
||||
|
||||
const qualifiedName = normalizeQualifiedName(tableMatch[1]);
|
||||
if (!qualifiedName) return undefined;
|
||||
|
||||
const parts = qualifiedName.split('.').filter(Boolean);
|
||||
const metadataTableName = parts[parts.length - 1] || '';
|
||||
if (!metadataTableName) return undefined;
|
||||
|
||||
const owner = parts.length >= 2 ? parts[parts.length - 2] : '';
|
||||
const metadataDbName = owner || currentDb || '';
|
||||
const tableName = isOracleLikeDialect(dialect) && owner
|
||||
? `${owner}.${metadataTableName}`
|
||||
: metadataTableName;
|
||||
|
||||
return {
|
||||
tableName,
|
||||
metadataDbName,
|
||||
metadataTableName,
|
||||
};
|
||||
};
|
||||
146
frontend/src/utils/rowLocator.test.ts
Normal file
146
frontend/src/utils/rowLocator.test.ts
Normal file
@@ -0,0 +1,146 @@
|
||||
import { describe, expect, it } from 'vitest';
|
||||
|
||||
import {
|
||||
ORACLE_ROWID_LOCATOR_COLUMN,
|
||||
filterHiddenLocatorColumns,
|
||||
resolveEditRowLocator,
|
||||
resolveRowLocatorValues,
|
||||
} from './rowLocator';
|
||||
|
||||
const uniqueIndex = (name: string, columnName: string, seqInIndex = 1) => ({
|
||||
name,
|
||||
columnName,
|
||||
seqInIndex,
|
||||
nonUnique: 0,
|
||||
indexType: 'BTREE',
|
||||
});
|
||||
|
||||
const normalIndex = (name: string, columnName: string, seqInIndex = 1) => ({
|
||||
name,
|
||||
columnName,
|
||||
seqInIndex,
|
||||
nonUnique: 1,
|
||||
indexType: 'BTREE',
|
||||
});
|
||||
|
||||
describe('resolveEditRowLocator', () => {
|
||||
it('prefers primary keys over unique indexes', () => {
|
||||
expect(resolveEditRowLocator({
|
||||
dbType: 'mysql',
|
||||
resultColumns: ['ID', 'EMAIL'],
|
||||
primaryKeys: ['ID'],
|
||||
indexes: [uniqueIndex('uk_email', 'EMAIL')],
|
||||
})).toEqual({
|
||||
strategy: 'primary-key',
|
||||
columns: ['ID'],
|
||||
valueColumns: ['ID'],
|
||||
readOnly: false,
|
||||
});
|
||||
});
|
||||
|
||||
it('uses a unique index when there is no primary key', () => {
|
||||
expect(resolveEditRowLocator({
|
||||
dbType: 'mysql',
|
||||
resultColumns: ['EMAIL', 'NAME'],
|
||||
indexes: [uniqueIndex('uk_email', 'EMAIL')],
|
||||
})).toEqual({
|
||||
strategy: 'unique-key',
|
||||
columns: ['EMAIL'],
|
||||
valueColumns: ['EMAIL'],
|
||||
readOnly: false,
|
||||
});
|
||||
});
|
||||
|
||||
it('sorts composite unique index columns by sequence', () => {
|
||||
expect(resolveEditRowLocator({
|
||||
dbType: 'postgres',
|
||||
resultColumns: ['TENANT_ID', 'CODE', 'NAME'],
|
||||
indexes: [
|
||||
uniqueIndex('uk_tenant_code', 'CODE', 2),
|
||||
uniqueIndex('uk_tenant_code', 'TENANT_ID', 1),
|
||||
],
|
||||
})).toMatchObject({
|
||||
strategy: 'unique-key',
|
||||
columns: ['TENANT_ID', 'CODE'],
|
||||
valueColumns: ['TENANT_ID', 'CODE'],
|
||||
readOnly: false,
|
||||
});
|
||||
});
|
||||
|
||||
it('ignores non-unique indexes', () => {
|
||||
expect(resolveEditRowLocator({
|
||||
dbType: 'mysql',
|
||||
resultColumns: ['NAME'],
|
||||
indexes: [normalIndex('idx_name', 'NAME')],
|
||||
})).toMatchObject({
|
||||
strategy: 'none',
|
||||
readOnly: true,
|
||||
});
|
||||
});
|
||||
|
||||
it('keeps results read-only when primary key columns are missing from result columns', () => {
|
||||
expect(resolveEditRowLocator({
|
||||
dbType: 'oracle',
|
||||
resultColumns: ['NAME'],
|
||||
primaryKeys: ['ID'],
|
||||
})).toMatchObject({
|
||||
strategy: 'none',
|
||||
readOnly: true,
|
||||
reason: '结果集中缺少主键列 ID,无法安全提交修改。',
|
||||
});
|
||||
});
|
||||
|
||||
it('uses Oracle ROWID when no primary or unique key is available', () => {
|
||||
expect(resolveEditRowLocator({
|
||||
dbType: 'oracle',
|
||||
resultColumns: ['NAME', ORACLE_ROWID_LOCATOR_COLUMN],
|
||||
allowOracleRowID: true,
|
||||
})).toEqual({
|
||||
strategy: 'oracle-rowid',
|
||||
columns: ['ROWID'],
|
||||
valueColumns: [ORACLE_ROWID_LOCATOR_COLUMN],
|
||||
hiddenColumns: [ORACLE_ROWID_LOCATOR_COLUMN],
|
||||
readOnly: false,
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe('resolveRowLocatorValues', () => {
|
||||
it('extracts locator values from the original row', () => {
|
||||
const locator = resolveEditRowLocator({
|
||||
dbType: 'mysql',
|
||||
resultColumns: ['EMAIL', 'NAME'],
|
||||
indexes: [uniqueIndex('uk_email', 'EMAIL')],
|
||||
});
|
||||
|
||||
expect(resolveRowLocatorValues(locator, { EMAIL: 'a@example.com', NAME: 'A' })).toEqual({
|
||||
ok: true,
|
||||
values: { EMAIL: 'a@example.com' },
|
||||
});
|
||||
});
|
||||
|
||||
it('rejects nullable unique locator values', () => {
|
||||
const locator = resolveEditRowLocator({
|
||||
dbType: 'mysql',
|
||||
resultColumns: ['EMAIL', 'NAME'],
|
||||
indexes: [uniqueIndex('uk_email', 'EMAIL')],
|
||||
});
|
||||
|
||||
expect(resolveRowLocatorValues(locator, { EMAIL: null, NAME: 'A' })).toEqual({
|
||||
ok: false,
|
||||
error: '定位列 EMAIL 的值为空,无法安全提交修改。',
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe('filterHiddenLocatorColumns', () => {
|
||||
it('removes hidden Oracle ROWID columns from displayed columns', () => {
|
||||
const locator = resolveEditRowLocator({
|
||||
dbType: 'oracle',
|
||||
resultColumns: ['NAME', ORACLE_ROWID_LOCATOR_COLUMN],
|
||||
allowOracleRowID: true,
|
||||
});
|
||||
|
||||
expect(filterHiddenLocatorColumns(['NAME', ORACLE_ROWID_LOCATOR_COLUMN], locator)).toEqual(['NAME']);
|
||||
});
|
||||
});
|
||||
152
frontend/src/utils/rowLocator.ts
Normal file
152
frontend/src/utils/rowLocator.ts
Normal file
@@ -0,0 +1,152 @@
|
||||
import type { IndexDefinition } from '../types';
|
||||
import { resolveUniqueKeyGroupsFromIndexes } from '../components/dataGridCopyInsert';
|
||||
import { isOracleLikeDialect } from './sqlDialect';
|
||||
|
||||
export const ORACLE_ROWID_LOCATOR_COLUMN = '__gonavi_oracle_rowid__';
|
||||
|
||||
export type RowLocatorStrategy = 'primary-key' | 'unique-key' | 'oracle-rowid' | 'none';
|
||||
|
||||
export type EditRowLocator = {
|
||||
strategy: RowLocatorStrategy;
|
||||
columns: string[];
|
||||
valueColumns: string[];
|
||||
hiddenColumns?: string[];
|
||||
writableColumns?: Record<string, string>;
|
||||
readOnly: boolean;
|
||||
reason?: string;
|
||||
};
|
||||
|
||||
export type ResolveEditRowLocatorParams = {
|
||||
dbType: string;
|
||||
resultColumns: string[];
|
||||
primaryKeys?: string[];
|
||||
indexes?: IndexDefinition[];
|
||||
allowOracleRowID?: boolean;
|
||||
};
|
||||
|
||||
export type ResolveRowLocatorValuesResult =
|
||||
| { ok: true; values: Record<string, any> }
|
||||
| { ok: false; error: string };
|
||||
|
||||
const normalizeColumnName = (value: string): string => String(value || '').trim();
|
||||
|
||||
const hasColumn = (columns: string[], target: string): boolean => {
|
||||
const normalizedTarget = normalizeColumnName(target).toLowerCase();
|
||||
return columns.some((column) => normalizeColumnName(column).toLowerCase() === normalizedTarget);
|
||||
};
|
||||
|
||||
const findColumn = (columns: string[], target: string): string => {
|
||||
const normalizedTarget = normalizeColumnName(target).toLowerCase();
|
||||
return columns.find((column) => normalizeColumnName(column).toLowerCase() === normalizedTarget) || target;
|
||||
};
|
||||
|
||||
const buildReadOnlyLocator = (reason: string): EditRowLocator => ({
|
||||
strategy: 'none',
|
||||
columns: [],
|
||||
valueColumns: [],
|
||||
readOnly: true,
|
||||
reason,
|
||||
});
|
||||
|
||||
export const resolveEditRowLocator = ({
|
||||
dbType,
|
||||
resultColumns,
|
||||
primaryKeys = [],
|
||||
indexes,
|
||||
allowOracleRowID = false,
|
||||
}: ResolveEditRowLocatorParams): EditRowLocator => {
|
||||
const columns = (resultColumns || []).map(normalizeColumnName).filter(Boolean);
|
||||
const primaryKeyColumns = (primaryKeys || []).map(normalizeColumnName).filter(Boolean);
|
||||
|
||||
if (primaryKeyColumns.length > 0) {
|
||||
const missing = primaryKeyColumns.filter((column) => !hasColumn(columns, column));
|
||||
if (missing.length === 0) {
|
||||
return {
|
||||
strategy: 'primary-key',
|
||||
columns: primaryKeyColumns,
|
||||
valueColumns: primaryKeyColumns.map((column) => findColumn(columns, column)),
|
||||
readOnly: false,
|
||||
};
|
||||
}
|
||||
return buildReadOnlyLocator(`结果集中缺少主键列 ${missing.join(', ')},无法安全提交修改。`);
|
||||
}
|
||||
|
||||
const uniqueKeyGroups = resolveUniqueKeyGroupsFromIndexes(indexes);
|
||||
const uniqueKeyGroup = uniqueKeyGroups.find((group) => group.length > 0 && group.every((column) => hasColumn(columns, column)));
|
||||
if (uniqueKeyGroup) {
|
||||
return {
|
||||
strategy: 'unique-key',
|
||||
columns: uniqueKeyGroup,
|
||||
valueColumns: uniqueKeyGroup.map((column) => findColumn(columns, column)),
|
||||
readOnly: false,
|
||||
};
|
||||
}
|
||||
|
||||
if (allowOracleRowID && isOracleLikeDialect(dbType) && hasColumn(columns, ORACLE_ROWID_LOCATOR_COLUMN)) {
|
||||
const rowIDColumn = findColumn(columns, ORACLE_ROWID_LOCATOR_COLUMN);
|
||||
return {
|
||||
strategy: 'oracle-rowid',
|
||||
columns: ['ROWID'],
|
||||
valueColumns: [rowIDColumn],
|
||||
hiddenColumns: [rowIDColumn],
|
||||
readOnly: false,
|
||||
};
|
||||
}
|
||||
|
||||
if (allowOracleRowID && isOracleLikeDialect(dbType)) {
|
||||
return buildReadOnlyLocator('未检测到主键或可用唯一索引,且结果中缺少 Oracle ROWID,无法安全提交修改。');
|
||||
}
|
||||
|
||||
return buildReadOnlyLocator('未检测到主键或可用唯一索引,无法安全提交修改。');
|
||||
};
|
||||
|
||||
export const resolveRowLocatorValues = (
|
||||
locator: EditRowLocator | undefined,
|
||||
row: Record<string, any>,
|
||||
): ResolveRowLocatorValuesResult => {
|
||||
if (!locator || locator.readOnly || locator.strategy === 'none') {
|
||||
return { ok: false, error: '当前结果没有可用的安全行定位方式,无法提交修改。' };
|
||||
}
|
||||
|
||||
const values: Record<string, any> = {};
|
||||
for (let index = 0; index < locator.columns.length; index++) {
|
||||
const column = locator.columns[index];
|
||||
const valueColumn = locator.valueColumns[index] || column;
|
||||
const value = row?.[valueColumn];
|
||||
if (value === null || value === undefined || value === '') {
|
||||
return { ok: false, error: `定位列 ${column} 的值为空,无法安全提交修改。` };
|
||||
}
|
||||
values[column] = value;
|
||||
}
|
||||
|
||||
return { ok: true, values };
|
||||
};
|
||||
|
||||
export const filterHiddenLocatorColumns = (columns: string[], locator?: EditRowLocator): string[] => {
|
||||
const hidden = new Set((locator?.hiddenColumns || []).map((column) => normalizeColumnName(column).toLowerCase()));
|
||||
if (hidden.size === 0) return columns;
|
||||
return (columns || []).filter((column) => !hidden.has(normalizeColumnName(column).toLowerCase()));
|
||||
};
|
||||
|
||||
export const isHiddenLocatorColumn = (column: string, locator?: EditRowLocator): boolean => {
|
||||
const normalized = normalizeColumnName(column).toLowerCase();
|
||||
return (locator?.hiddenColumns || []).some((hidden) => normalizeColumnName(hidden).toLowerCase() === normalized);
|
||||
};
|
||||
|
||||
export const resolveWritableColumnName = (column: string, locator?: EditRowLocator): string | undefined => {
|
||||
const normalized = normalizeColumnName(column);
|
||||
if (!normalized || isHiddenLocatorColumn(normalized, locator)) return undefined;
|
||||
const writableColumns = locator?.writableColumns;
|
||||
if (!writableColumns) return normalized;
|
||||
|
||||
const normalizedTarget = normalized.toLowerCase();
|
||||
const matchedEntry = Object.entries(writableColumns).find(([resultColumn]) => (
|
||||
normalizeColumnName(resultColumn).toLowerCase() === normalizedTarget
|
||||
));
|
||||
const tableColumnName = normalizeColumnName(matchedEntry?.[1] || '');
|
||||
return tableColumnName || undefined;
|
||||
};
|
||||
|
||||
export const isWritableResultColumn = (column: string, locator?: EditRowLocator): boolean => (
|
||||
resolveWritableColumnName(column, locator) !== undefined
|
||||
);
|
||||
@@ -426,6 +426,7 @@ export namespace connection {
|
||||
inserts: any[];
|
||||
updates: UpdateRow[];
|
||||
deletes: any[];
|
||||
locatorStrategy?: string;
|
||||
|
||||
static createFrom(source: any = {}) {
|
||||
return new ChangeSet(source);
|
||||
@@ -436,6 +437,7 @@ export namespace connection {
|
||||
this.inserts = source["inserts"];
|
||||
this.updates = this.convertValues(source["updates"], UpdateRow);
|
||||
this.deletes = source["deletes"];
|
||||
this.locatorStrategy = source["locatorStrategy"];
|
||||
}
|
||||
|
||||
convertValues(a: any, classs: any, asMap: boolean = false): any {
|
||||
@@ -668,6 +670,7 @@ export namespace connection {
|
||||
timeout?: number;
|
||||
redisDB?: number;
|
||||
uri?: string;
|
||||
clickHouseProtocol?: string;
|
||||
hosts?: string[];
|
||||
topology?: string;
|
||||
mysqlReplicaUser?: string;
|
||||
@@ -710,6 +713,7 @@ export namespace connection {
|
||||
this.timeout = source["timeout"];
|
||||
this.redisDB = source["redisDB"];
|
||||
this.uri = source["uri"];
|
||||
this.clickHouseProtocol = source["clickHouseProtocol"];
|
||||
this.hosts = source["hosts"];
|
||||
this.topology = source["topology"];
|
||||
this.mysqlReplicaUser = source["mysqlReplicaUser"];
|
||||
|
||||
@@ -467,6 +467,13 @@ func formatConnSummary(config connection.ConnectionConfig) string {
|
||||
b.WriteString(fmt.Sprintf(" 认证库=%s", strings.TrimSpace(config.AuthSource)))
|
||||
}
|
||||
}
|
||||
if strings.EqualFold(strings.TrimSpace(config.Type), "clickhouse") {
|
||||
protocol := strings.ToLower(strings.TrimSpace(config.ClickHouseProtocol))
|
||||
if protocol == "" {
|
||||
protocol = "auto"
|
||||
}
|
||||
b.WriteString(fmt.Sprintf(" ClickHouse协议=%s", protocol))
|
||||
}
|
||||
|
||||
if config.UseSSH {
|
||||
b.WriteString(fmt.Sprintf(" SSH=%s:%d 用户=%s", config.SSH.Host, config.SSH.Port, config.SSH.User))
|
||||
|
||||
@@ -80,3 +80,22 @@ func TestGetCacheKey_KeepDatabaseIsolation(t *testing.T) {
|
||||
t.Fatalf("expected different cache key for different database targets")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetCacheKey_KeepClickHouseProtocolIsolation(t *testing.T) {
|
||||
base := connection.ConnectionConfig{
|
||||
Type: "clickhouse",
|
||||
Host: "clickhouse.local",
|
||||
Port: 8125,
|
||||
User: "default",
|
||||
Database: "default",
|
||||
ClickHouseProtocol: "native",
|
||||
}
|
||||
modified := base
|
||||
modified.ClickHouseProtocol = "http"
|
||||
|
||||
left := getCacheKey(base)
|
||||
right := getCacheKey(modified)
|
||||
if left == right {
|
||||
t.Fatalf("expected different cache key for different ClickHouse protocols")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -23,9 +23,24 @@ func normalizeTestConnectionConfig(config connection.ConnectionConfig) connectio
|
||||
return normalized
|
||||
}
|
||||
|
||||
func validateTestConnectionInput(config connection.ConnectionConfig) error {
|
||||
dbType := strings.ToLower(strings.TrimSpace(config.Type))
|
||||
if dbType == "" {
|
||||
return fmt.Errorf("请先选择数据源类型")
|
||||
}
|
||||
if dbType == "clickhouse" && strings.TrimSpace(config.Host) == "" && strings.TrimSpace(config.URI) == "" {
|
||||
return fmt.Errorf("请填写 ClickHouse 主机地址或连接 URI")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Generic DB Methods
|
||||
|
||||
func (a *App) DBConnect(config connection.ConnectionConfig) connection.QueryResult {
|
||||
if err := validateTestConnectionInput(config); err != nil {
|
||||
logger.Warnf("DBConnect 参数校验失败:%s %s", err.Error(), formatConnSummary(config))
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
// 连接测试需要强制 ping,避免缓存命中但连接已失效时误判成功。
|
||||
_, err := a.getDatabaseForcePing(config)
|
||||
if err != nil {
|
||||
@@ -41,6 +56,10 @@ func (a *App) TestConnection(config connection.ConnectionConfig) connection.Quer
|
||||
testConfig := normalizeTestConnectionConfig(config)
|
||||
started := time.Now()
|
||||
logger.Infof("TestConnection 开始:%s", formatConnSummary(testConfig))
|
||||
if err := validateTestConnectionInput(testConfig); err != nil {
|
||||
logger.Warnf("TestConnection 参数校验失败:耗时=%s %s 原因=%s", time.Since(started).Round(time.Millisecond), formatConnSummary(testConfig), err.Error())
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
_, err := a.getDatabaseForcePing(testConfig)
|
||||
if err != nil {
|
||||
logger.Error(err, "TestConnection 连接测试失败:耗时=%s %s", time.Since(started).Round(time.Millisecond), formatConnSummary(testConfig))
|
||||
|
||||
@@ -31,6 +31,26 @@ func TestNormalizeTestConnectionConfig_ZeroTimeout(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateTestConnectionInput_ClickHouseRequiresTarget(t *testing.T) {
|
||||
err := validateTestConnectionInput(connection.ConnectionConfig{Type: "clickhouse"})
|
||||
if err == nil {
|
||||
t.Fatal("expected ClickHouse target validation error")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "ClickHouse 主机地址") {
|
||||
t.Fatalf("unexpected validation error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateTestConnectionInput_ClickHouseAllowsURI(t *testing.T) {
|
||||
err := validateTestConnectionInput(connection.ConnectionConfig{
|
||||
Type: "clickhouse",
|
||||
URI: "http://clickhouse.example.com:8125/default",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("expected ClickHouse URI to satisfy target validation, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatConnSummary_BasicMySQL(t *testing.T) {
|
||||
cfg := connection.ConnectionConfig{
|
||||
Type: "mysql",
|
||||
|
||||
@@ -127,6 +127,7 @@ type driverDefinition struct {
|
||||
type installedDriverPackage struct {
|
||||
DriverType string `json:"driverType"`
|
||||
Version string `json:"version,omitempty"`
|
||||
AgentRevision string `json:"agentRevision,omitempty"`
|
||||
FilePath string `json:"filePath"`
|
||||
FileName string `json:"fileName"`
|
||||
ExecutablePath string `json:"executablePath,omitempty"`
|
||||
@@ -136,23 +137,28 @@ type installedDriverPackage struct {
|
||||
}
|
||||
|
||||
type driverStatusItem struct {
|
||||
Type string `json:"type"`
|
||||
Name string `json:"name"`
|
||||
Engine string `json:"engine,omitempty"`
|
||||
BuiltIn bool `json:"builtIn"`
|
||||
PinnedVersion string `json:"pinnedVersion,omitempty"`
|
||||
InstalledVersion string `json:"installedVersion,omitempty"`
|
||||
PackageSizeText string `json:"packageSizeText,omitempty"`
|
||||
RuntimeAvailable bool `json:"runtimeAvailable"`
|
||||
PackageInstalled bool `json:"packageInstalled"`
|
||||
Connectable bool `json:"connectable"`
|
||||
DefaultDownloadURL string `json:"defaultDownloadUrl,omitempty"`
|
||||
InstallDir string `json:"installDir,omitempty"`
|
||||
PackagePath string `json:"packagePath,omitempty"`
|
||||
PackageFileName string `json:"packageFileName,omitempty"`
|
||||
ExecutablePath string `json:"executablePath,omitempty"`
|
||||
DownloadedAt string `json:"downloadedAt,omitempty"`
|
||||
Message string `json:"message,omitempty"`
|
||||
Type string `json:"type"`
|
||||
Name string `json:"name"`
|
||||
Engine string `json:"engine,omitempty"`
|
||||
BuiltIn bool `json:"builtIn"`
|
||||
PinnedVersion string `json:"pinnedVersion,omitempty"`
|
||||
InstalledVersion string `json:"installedVersion,omitempty"`
|
||||
PackageSizeText string `json:"packageSizeText,omitempty"`
|
||||
RuntimeAvailable bool `json:"runtimeAvailable"`
|
||||
PackageInstalled bool `json:"packageInstalled"`
|
||||
Connectable bool `json:"connectable"`
|
||||
DefaultDownloadURL string `json:"defaultDownloadUrl,omitempty"`
|
||||
InstallDir string `json:"installDir,omitempty"`
|
||||
PackagePath string `json:"packagePath,omitempty"`
|
||||
PackageFileName string `json:"packageFileName,omitempty"`
|
||||
ExecutablePath string `json:"executablePath,omitempty"`
|
||||
DownloadedAt string `json:"downloadedAt,omitempty"`
|
||||
AgentRevision string `json:"agentRevision,omitempty"`
|
||||
ExpectedRevision string `json:"expectedRevision,omitempty"`
|
||||
NeedsUpdate bool `json:"needsUpdate,omitempty"`
|
||||
UpdateReason string `json:"updateReason,omitempty"`
|
||||
AffectedConnections int `json:"affectedConnections,omitempty"`
|
||||
Message string `json:"message,omitempty"`
|
||||
}
|
||||
|
||||
const driverDownloadProgressEvent = "driver:download-progress"
|
||||
@@ -758,29 +764,36 @@ func (a *App) GetDriverStatusList(downloadDir string, manifestURL string) connec
|
||||
definitions := allDriverDefinitionsWithPackages(effectivePackages)
|
||||
triggerDriverVersionMetadataWarmup(definitions)
|
||||
packageSizeBytesMap := preloadOptionalDriverPackageSizes(definitions)
|
||||
usageCounts := a.savedConnectionDriverUsageCounts()
|
||||
items := make([]driverStatusItem, 0, len(definitions))
|
||||
for _, definition := range definitions {
|
||||
engine := effectiveDriverEngine(definition)
|
||||
runtimeAvailable, runtimeReason := db.DriverRuntimeSupportStatus(definition.Type)
|
||||
pkg, packageMetaExists := readInstalledDriverPackage(resolvedDir, definition.Type)
|
||||
needsUpdate, updateReason, expectedRevision := optionalDriverAgentRevisionStatus(definition.Type, pkg, packageMetaExists)
|
||||
packageInstalled := definition.BuiltIn || packageMetaExists
|
||||
if runtimeAvailable && db.IsOptionalGoDriver(definition.Type) {
|
||||
packageInstalled = true
|
||||
}
|
||||
|
||||
item := driverStatusItem{
|
||||
Type: definition.Type,
|
||||
Name: definition.Name,
|
||||
Engine: engine,
|
||||
BuiltIn: definition.BuiltIn,
|
||||
PinnedVersion: definition.PinnedVersion,
|
||||
InstalledVersion: strings.TrimSpace(pkg.Version),
|
||||
PackageSizeText: resolveDriverPackageSizeText(definition, pkg, packageMetaExists, packageSizeBytesMap),
|
||||
RuntimeAvailable: runtimeAvailable,
|
||||
PackageInstalled: packageInstalled,
|
||||
Connectable: runtimeAvailable,
|
||||
DefaultDownloadURL: definition.DefaultDownloadURL,
|
||||
InstallDir: driverInstallDir(resolvedDir, definition.Type),
|
||||
Type: definition.Type,
|
||||
Name: definition.Name,
|
||||
Engine: engine,
|
||||
BuiltIn: definition.BuiltIn,
|
||||
PinnedVersion: definition.PinnedVersion,
|
||||
InstalledVersion: strings.TrimSpace(pkg.Version),
|
||||
PackageSizeText: resolveDriverPackageSizeText(definition, pkg, packageMetaExists, packageSizeBytesMap),
|
||||
RuntimeAvailable: runtimeAvailable,
|
||||
PackageInstalled: packageInstalled,
|
||||
Connectable: runtimeAvailable,
|
||||
DefaultDownloadURL: definition.DefaultDownloadURL,
|
||||
InstallDir: driverInstallDir(resolvedDir, definition.Type),
|
||||
AgentRevision: strings.TrimSpace(pkg.AgentRevision),
|
||||
ExpectedRevision: expectedRevision,
|
||||
NeedsUpdate: needsUpdate,
|
||||
UpdateReason: updateReason,
|
||||
AffectedConnections: usageCounts[normalizeDriverType(definition.Type)],
|
||||
}
|
||||
if packageMetaExists {
|
||||
item.PackagePath = pkg.FilePath
|
||||
@@ -792,6 +805,12 @@ func (a *App) GetDriverStatusList(downloadDir string, manifestURL string) connec
|
||||
switch {
|
||||
case definition.BuiltIn:
|
||||
item.Message = "内置驱动,可直接连接"
|
||||
case needsUpdate:
|
||||
if item.AffectedConnections > 0 {
|
||||
item.Message = fmt.Sprintf("%s;检测到 %d 个已保存连接正在使用该驱动,请在工具-驱动管理中重装", updateReason, item.AffectedConnections)
|
||||
} else {
|
||||
item.Message = updateReason + ",请在工具-驱动管理中重装"
|
||||
}
|
||||
case runtimeAvailable:
|
||||
item.Message = "纯 Go 驱动已启用,可直接连接"
|
||||
case packageInstalled && strings.TrimSpace(runtimeReason) != "":
|
||||
@@ -2702,6 +2721,47 @@ func readInstalledDriverPackage(downloadDir string, driverType string) (installe
|
||||
return meta, true
|
||||
}
|
||||
|
||||
func optionalDriverAgentRevisionStatus(driverType string, pkg installedDriverPackage, packageMetaExists bool) (bool, string, string) {
|
||||
expected := db.OptionalDriverAgentRevision(driverType)
|
||||
if strings.TrimSpace(expected) == "" || !packageMetaExists || !db.IsOptionalGoDriver(driverType) {
|
||||
return false, "", expected
|
||||
}
|
||||
actual := strings.TrimSpace(pkg.AgentRevision)
|
||||
if actual == expected {
|
||||
return false, "", expected
|
||||
}
|
||||
displayName := resolveDriverDisplayName(driverDefinition{Type: driverType})
|
||||
updateReason := fmt.Sprintf("当前 GoNavi 版本要求更新后的 %s driver-agent(revision: %s)", displayName, expected)
|
||||
impact := "driver-agent 是独立二进制,不会随主程序自动更新;如果不重装,会继续使用旧 agent 逻辑,驱动侧已修复或优化的行为不会生效,可能继续出现旧版本问题。强烈建议重装对应驱动代理"
|
||||
if actual == "" {
|
||||
return true, fmt.Sprintf("原因:%s。影响:%s", updateReason, impact), expected
|
||||
}
|
||||
return true, fmt.Sprintf("原因:%s。影响:%s(已安装标记:%s,当前需要:%s)", updateReason, impact, actual, expected), expected
|
||||
}
|
||||
|
||||
func (a *App) savedConnectionDriverUsageCounts() map[string]int {
|
||||
counts := map[string]int{}
|
||||
if a == nil || strings.TrimSpace(a.configDir) == "" {
|
||||
return counts
|
||||
}
|
||||
items, err := a.savedConnectionRepository().List()
|
||||
if err != nil {
|
||||
logger.Warnf("统计驱动连接使用数失败:%v", err)
|
||||
return counts
|
||||
}
|
||||
for _, item := range items {
|
||||
driverType := normalizeDriverType(item.Config.Type)
|
||||
if driverType == "custom" {
|
||||
driverType = normalizeDriverType(item.Config.Driver)
|
||||
}
|
||||
if driverType == "" || !db.IsOptionalGoDriver(driverType) {
|
||||
continue
|
||||
}
|
||||
counts[driverType]++
|
||||
}
|
||||
return counts
|
||||
}
|
||||
|
||||
func writeInstalledDriverPackage(downloadDir string, driverType string, meta installedDriverPackage) error {
|
||||
driverDir := driverInstallDir(downloadDir, driverType)
|
||||
if err := os.MkdirAll(driverDir, 0o755); err != nil {
|
||||
@@ -2765,9 +2825,11 @@ func installOptionalDriverAgentPackage(a *App, definition driverDefinition, sele
|
||||
if strings.TrimSpace(downloadSource) == "" {
|
||||
downloadSource = strings.TrimSpace(downloadURL)
|
||||
}
|
||||
agentRevision := probeInstalledOptionalDriverAgentRevision(driverType, runtimePath)
|
||||
return installedDriverPackage{
|
||||
DriverType: driverType,
|
||||
Version: strings.TrimSpace(selectedVersion),
|
||||
AgentRevision: agentRevision,
|
||||
FilePath: installPath,
|
||||
FileName: filepath.Base(installPath),
|
||||
ExecutablePath: runtimePath,
|
||||
@@ -2837,9 +2899,11 @@ func installOptionalDriverAgentFromLocalPath(definition driverDefinition, filePa
|
||||
if hashErr != nil {
|
||||
return installedDriverPackage{}, fmt.Errorf("计算 %s 驱动代理摘要失败:%w", displayName, hashErr)
|
||||
}
|
||||
agentRevision := probeInstalledOptionalDriverAgentRevision(driverType, executablePath)
|
||||
return installedDriverPackage{
|
||||
DriverType: driverType,
|
||||
Version: strings.TrimSpace(selectedVersion),
|
||||
AgentRevision: agentRevision,
|
||||
FilePath: sourcePath,
|
||||
FileName: sourceName,
|
||||
ExecutablePath: executablePath,
|
||||
@@ -2849,6 +2913,19 @@ func installOptionalDriverAgentFromLocalPath(definition driverDefinition, filePa
|
||||
}, nil
|
||||
}
|
||||
|
||||
func probeInstalledOptionalDriverAgentRevision(driverType string, executablePath string) string {
|
||||
expectedRevision := db.OptionalDriverAgentRevision(driverType)
|
||||
if strings.TrimSpace(expectedRevision) == "" {
|
||||
return ""
|
||||
}
|
||||
metadata, err := db.ProbeOptionalDriverAgentMetadata(driverType, executablePath)
|
||||
if err != nil {
|
||||
logger.Warnf("%s 驱动代理未返回版本元数据:%v", resolveDriverDisplayName(driverDefinition{Type: driverType}), err)
|
||||
return ""
|
||||
}
|
||||
return strings.TrimSpace(metadata.AgentRevision)
|
||||
}
|
||||
|
||||
type localDriverCandidate struct {
|
||||
absPath string
|
||||
relativePath string
|
||||
|
||||
72
internal/app/methods_driver_agent_revision_test.go
Normal file
72
internal/app/methods_driver_agent_revision_test.go
Normal file
@@ -0,0 +1,72 @@
|
||||
package app
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"GoNavi-Wails/internal/connection"
|
||||
)
|
||||
|
||||
func TestOptionalDriverAgentRevisionStatusDetectsStaleClickHouseAgent(t *testing.T) {
|
||||
needsUpdate, reason, expected := optionalDriverAgentRevisionStatus("clickhouse", installedDriverPackage{}, true)
|
||||
if !needsUpdate {
|
||||
t.Fatal("expected missing ClickHouse agent revision to require update")
|
||||
}
|
||||
if expected == "" {
|
||||
t.Fatal("expected ClickHouse to define an agent revision")
|
||||
}
|
||||
if reason == "" {
|
||||
t.Fatal("expected update reason")
|
||||
}
|
||||
if !strings.Contains(reason, "原因:") || !strings.Contains(reason, "影响:") {
|
||||
t.Fatalf("expected reason to explain cause and impact, got %q", reason)
|
||||
}
|
||||
if !strings.Contains(reason, "强烈建议重装") {
|
||||
t.Fatalf("expected reason to strongly recommend reinstall, got %q", reason)
|
||||
}
|
||||
|
||||
current := installedDriverPackage{AgentRevision: expected}
|
||||
needsUpdate, reason, _ = optionalDriverAgentRevisionStatus("clickhouse", current, true)
|
||||
if needsUpdate {
|
||||
t.Fatalf("expected current ClickHouse agent revision to be accepted, reason=%q", reason)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSavedConnectionDriverUsageCountsIncludesOptionalAndCustomDrivers(t *testing.T) {
|
||||
app := &App{configDir: t.TempDir()}
|
||||
repo := app.savedConnectionRepository()
|
||||
if err := repo.saveAll([]connection.SavedConnectionView{
|
||||
{
|
||||
ID: "conn-clickhouse",
|
||||
Name: "ClickHouse",
|
||||
Config: connection.ConnectionConfig{
|
||||
Type: "clickhouse",
|
||||
},
|
||||
},
|
||||
{
|
||||
ID: "conn-custom-clickhouse",
|
||||
Name: "Custom ClickHouse",
|
||||
Config: connection.ConnectionConfig{
|
||||
Type: "custom",
|
||||
Driver: "clickhouse",
|
||||
},
|
||||
},
|
||||
{
|
||||
ID: "conn-mysql",
|
||||
Name: "MySQL",
|
||||
Config: connection.ConnectionConfig{
|
||||
Type: "mysql",
|
||||
},
|
||||
},
|
||||
}); err != nil {
|
||||
t.Fatalf("save connections failed: %v", err)
|
||||
}
|
||||
|
||||
counts := app.savedConnectionDriverUsageCounts()
|
||||
if got := counts["clickhouse"]; got != 2 {
|
||||
t.Fatalf("expected two ClickHouse usages, got %d", got)
|
||||
}
|
||||
if got := counts["mysql"]; got != 0 {
|
||||
t.Fatalf("expected built-in MySQL to be ignored, got %d", got)
|
||||
}
|
||||
}
|
||||
@@ -102,6 +102,7 @@ type ConnectionConfig struct {
|
||||
Timeout int `json:"timeout,omitempty"` // Connection timeout in seconds (default: 30)
|
||||
RedisDB int `json:"redisDB,omitempty"` // Redis database index (0-15)
|
||||
URI string `json:"uri,omitempty"` // Connection URI for copy/paste
|
||||
ClickHouseProtocol string `json:"clickHouseProtocol,omitempty"` // auto | http | native
|
||||
Hosts []string `json:"hosts,omitempty"` // Multi-host addresses: host:port
|
||||
Topology string `json:"topology,omitempty"` // single | replica | cluster
|
||||
MySQLReplicaUser string `json:"mysqlReplicaUser,omitempty"` // MySQL replica auth user
|
||||
@@ -184,9 +185,10 @@ type UpdateRow struct {
|
||||
|
||||
// ChangeSet 表示一组批量变更,包含新增、修改和删除操作。
|
||||
type ChangeSet struct {
|
||||
Inserts []map[string]interface{} `json:"inserts"`
|
||||
Updates []UpdateRow `json:"updates"`
|
||||
Deletes []map[string]interface{} `json:"deletes"`
|
||||
Inserts []map[string]interface{} `json:"inserts"`
|
||||
Updates []UpdateRow `json:"updates"`
|
||||
Deletes []map[string]interface{} `json:"deletes"`
|
||||
LocatorStrategy string `json:"locatorStrategy,omitempty"`
|
||||
}
|
||||
|
||||
// MongoMemberInfo 描述 MongoDB 副本集成员的信息。
|
||||
|
||||
@@ -12,6 +12,8 @@ import (
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
"unicode"
|
||||
"unicode/utf8"
|
||||
|
||||
"GoNavi-Wails/internal/connection"
|
||||
"GoNavi-Wails/internal/logger"
|
||||
@@ -26,7 +28,11 @@ const (
|
||||
defaultClickHouseUser = "default"
|
||||
defaultClickHouseDatabase = "default"
|
||||
minClickHouseReadTimeout = 5 * time.Minute
|
||||
clickHouseHTTPPortHint = "8123/8132/8443"
|
||||
clickHouseHTTPPortHint = "8123/8125/8132/8443"
|
||||
|
||||
clickHouseProtocolAuto = "auto"
|
||||
clickHouseProtocolHTTP = "http"
|
||||
clickHouseProtocolNative = "native"
|
||||
)
|
||||
|
||||
type ClickHouseDB struct {
|
||||
@@ -38,6 +44,7 @@ type ClickHouseDB struct {
|
||||
|
||||
func normalizeClickHouseConfig(config connection.ConnectionConfig) connection.ConnectionConfig {
|
||||
normalized := applyClickHouseURI(config)
|
||||
normalized = applyClickHouseHostURI(normalized)
|
||||
if strings.TrimSpace(normalized.Host) == "" {
|
||||
normalized.Host = "localhost"
|
||||
}
|
||||
@@ -58,15 +65,26 @@ func applyClickHouseURI(config connection.ConnectionConfig) connection.Connectio
|
||||
if uriText == "" {
|
||||
return config
|
||||
}
|
||||
lowerURI := strings.ToLower(uriText)
|
||||
if !strings.HasPrefix(lowerURI, "clickhouse://") {
|
||||
return applyClickHouseEndpointURI(config, uriText, false)
|
||||
}
|
||||
|
||||
func applyClickHouseHostURI(config connection.ConnectionConfig) connection.ConnectionConfig {
|
||||
hostText := strings.TrimSpace(config.Host)
|
||||
if hostText == "" {
|
||||
return config
|
||||
}
|
||||
return applyClickHouseEndpointURI(config, hostText, true)
|
||||
}
|
||||
|
||||
func applyClickHouseEndpointURI(config connection.ConnectionConfig, uriText string, fromHostField bool) connection.ConnectionConfig {
|
||||
parsed, err := url.Parse(uriText)
|
||||
if err != nil {
|
||||
return config
|
||||
}
|
||||
scheme := strings.ToLower(strings.TrimSpace(parsed.Scheme))
|
||||
if !isClickHouseSupportedEndpointScheme(scheme) || strings.TrimSpace(parsed.Host) == "" {
|
||||
return config
|
||||
}
|
||||
|
||||
if parsed.User != nil {
|
||||
if strings.TrimSpace(config.User) == "" {
|
||||
@@ -85,12 +103,28 @@ func applyClickHouseURI(config connection.ConnectionConfig) connection.Connectio
|
||||
config.Database = dbName
|
||||
}
|
||||
}
|
||||
if queryProtocol := normalizeClickHouseProtocol(parsed.Query().Get("protocol")); queryProtocol != clickHouseProtocolAuto {
|
||||
config.ClickHouseProtocol = queryProtocol
|
||||
}
|
||||
endpointProtocol := normalizeClickHouseProtocol(config.ClickHouseProtocol)
|
||||
if isClickHouseHTTPURLScheme(scheme) && endpointProtocol != clickHouseProtocolNative {
|
||||
config.ClickHouseProtocol = clickHouseProtocolHTTP
|
||||
if scheme == "https" {
|
||||
config.UseSSL = true
|
||||
if normalizeSSLModeValue(config.SSLMode) == sslModeDisable || strings.TrimSpace(config.SSLMode) == "" {
|
||||
config.SSLMode = sslModeRequired
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
defaultPort := config.Port
|
||||
if defaultPort <= 0 {
|
||||
defaultPort = defaultClickHousePort
|
||||
}
|
||||
if strings.TrimSpace(config.Host) == "" {
|
||||
if isClickHouseHTTPURLScheme(scheme) && endpointProtocol != clickHouseProtocolNative && defaultPort == defaultClickHousePort {
|
||||
defaultPort = defaultClickHousePortForScheme(scheme)
|
||||
}
|
||||
if fromHostField || strings.TrimSpace(config.Host) == "" {
|
||||
host, port, ok := parseHostPortWithDefault(parsed.Host, defaultPort)
|
||||
if ok {
|
||||
config.Host = host
|
||||
@@ -103,6 +137,30 @@ func applyClickHouseURI(config connection.ConnectionConfig) connection.Connectio
|
||||
return config
|
||||
}
|
||||
|
||||
func isClickHouseSupportedEndpointScheme(scheme string) bool {
|
||||
switch scheme {
|
||||
case "clickhouse", "http", "https":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func isClickHouseHTTPURLScheme(scheme string) bool {
|
||||
return scheme == "http" || scheme == "https"
|
||||
}
|
||||
|
||||
func defaultClickHousePortForScheme(scheme string) int {
|
||||
switch scheme {
|
||||
case "http":
|
||||
return 8123
|
||||
case "https":
|
||||
return 8443
|
||||
default:
|
||||
return defaultClickHousePort
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ClickHouseDB) buildClickHouseOptions(config connection.ConnectionConfig) *clickhouse.Options {
|
||||
connectTimeout := getConnectTimeout(config)
|
||||
readTimeout := connectTimeout
|
||||
@@ -130,6 +188,15 @@ func (c *ClickHouseDB) buildClickHouseOptions(config connection.ConnectionConfig
|
||||
}
|
||||
|
||||
func detectClickHouseProtocol(config connection.ConnectionConfig) clickhouse.Protocol {
|
||||
switch normalizeClickHouseProtocol(config.ClickHouseProtocol) {
|
||||
case clickHouseProtocolHTTP:
|
||||
return clickhouse.HTTP
|
||||
case clickHouseProtocolNative:
|
||||
return clickhouse.Native
|
||||
}
|
||||
if hasClickHouseHTTPScheme(config.URI) || hasClickHouseHTTPScheme(config.Host) {
|
||||
return clickhouse.HTTP
|
||||
}
|
||||
uriText := strings.ToLower(strings.TrimSpace(config.URI))
|
||||
if strings.HasPrefix(uriText, "http://") || strings.HasPrefix(uriText, "https://") {
|
||||
return clickhouse.HTTP
|
||||
@@ -140,9 +207,25 @@ func detectClickHouseProtocol(config connection.ConnectionConfig) clickhouse.Pro
|
||||
return clickhouse.Native
|
||||
}
|
||||
|
||||
func normalizeClickHouseProtocol(raw string) string {
|
||||
switch strings.ToLower(strings.TrimSpace(raw)) {
|
||||
case clickHouseProtocolHTTP, "https":
|
||||
return clickHouseProtocolHTTP
|
||||
case clickHouseProtocolNative, "tcp":
|
||||
return clickHouseProtocolNative
|
||||
default:
|
||||
return clickHouseProtocolAuto
|
||||
}
|
||||
}
|
||||
|
||||
func hasClickHouseHTTPScheme(raw string) bool {
|
||||
text := strings.ToLower(strings.TrimSpace(raw))
|
||||
return strings.HasPrefix(text, "http://") || strings.HasPrefix(text, "https://")
|
||||
}
|
||||
|
||||
func isClickHouseHTTPPort(port int) bool {
|
||||
switch port {
|
||||
case 8123, 8132, 8443:
|
||||
case 8123, 8125, 8132, 8443:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
@@ -159,18 +242,88 @@ func isClickHouseProtocolMismatch(err error) bool {
|
||||
}
|
||||
return strings.Contains(text, "unexpected packet [72]") ||
|
||||
(strings.Contains(text, "unexpected packet") && strings.Contains(text, "handshake")) ||
|
||||
(strings.Contains(text, "cannot parse input") && strings.Contains(text, "expected '('")) ||
|
||||
strings.Contains(text, "http response to https client") ||
|
||||
strings.Contains(text, "malformed http response")
|
||||
}
|
||||
|
||||
func clickHouseProtocolName(protocol clickhouse.Protocol) string {
|
||||
if protocol == clickhouse.HTTP {
|
||||
return "HTTP"
|
||||
}
|
||||
return "Native"
|
||||
}
|
||||
|
||||
func sanitizeClickHouseErrorMessage(err error) string {
|
||||
if err == nil {
|
||||
return ""
|
||||
}
|
||||
text := strings.ToValidUTF8(err.Error(), "<22>")
|
||||
var b strings.Builder
|
||||
lastSpace := false
|
||||
for _, r := range text {
|
||||
if r == utf8.RuneError || r == '<27>' {
|
||||
if !lastSpace {
|
||||
b.WriteByte(' ')
|
||||
lastSpace = true
|
||||
}
|
||||
continue
|
||||
}
|
||||
if unicode.IsControl(r) {
|
||||
if !lastSpace {
|
||||
b.WriteByte(' ')
|
||||
lastSpace = true
|
||||
}
|
||||
continue
|
||||
}
|
||||
b.WriteRune(r)
|
||||
lastSpace = unicode.IsSpace(r)
|
||||
}
|
||||
sanitized := strings.Join(strings.Fields(b.String()), " ")
|
||||
if len(sanitized) > 320 {
|
||||
return sanitized[:320] + "..."
|
||||
}
|
||||
return sanitized
|
||||
}
|
||||
|
||||
func clickHouseAttemptFailureMessage(protocol clickhouse.Protocol, err error) string {
|
||||
if isClickHouseProtocolMismatch(err) {
|
||||
if protocol == clickhouse.Native {
|
||||
return "服务端响应不像 Native 握手,当前端口更像 HTTP/HTTPS 端口;请选择 HTTP 协议,或确认 ClickHouse Native 端口"
|
||||
}
|
||||
return "服务端响应不像 HTTP 响应,当前端口更像 Native 端口;请选择 Native 协议,或确认 ClickHouse HTTP 端口"
|
||||
}
|
||||
message := sanitizeClickHouseErrorMessage(err)
|
||||
if message == "" {
|
||||
return "未知错误"
|
||||
}
|
||||
return message
|
||||
}
|
||||
|
||||
func clickHouseConnectFailureSummary(config connection.ConnectionConfig, failures []string) string {
|
||||
protocolMode := normalizeClickHouseProtocol(config.ClickHouseProtocol)
|
||||
detail := strings.Join(failures, ";")
|
||||
if strings.TrimSpace(detail) == "" {
|
||||
detail = "未获取到驱动返回的错误详情"
|
||||
}
|
||||
if protocolMode != clickHouseProtocolAuto {
|
||||
return fmt.Sprintf("ClickHouse 连接验证失败:已按用户选择使用 %s 协议连接 %s:%d。%s",
|
||||
strings.ToUpper(protocolMode), config.Host, config.Port, detail)
|
||||
}
|
||||
return fmt.Sprintf("ClickHouse 连接验证失败:自动协议探测未成功(Native 常见端口 9000/9440,HTTP 常见端口 %s;非标端口建议在连接协议中手动指定)。%s",
|
||||
clickHouseHTTPPortHint, detail)
|
||||
}
|
||||
|
||||
func withClickHouseProtocol(config connection.ConnectionConfig, protocol clickhouse.Protocol) connection.ConnectionConfig {
|
||||
next := config
|
||||
switch protocol {
|
||||
case clickhouse.HTTP:
|
||||
next.ClickHouseProtocol = clickHouseProtocolHTTP
|
||||
if next.Port == 0 {
|
||||
next.Port = 8123
|
||||
}
|
||||
default:
|
||||
next.ClickHouseProtocol = clickHouseProtocolNative
|
||||
if next.Port == 0 {
|
||||
next.Port = defaultClickHousePort
|
||||
}
|
||||
@@ -178,6 +331,17 @@ func withClickHouseProtocol(config connection.ConnectionConfig, protocol clickho
|
||||
return next
|
||||
}
|
||||
|
||||
func clickHouseProtocolsForAttempt(config connection.ConnectionConfig) []clickhouse.Protocol {
|
||||
primaryProtocol := detectClickHouseProtocol(config)
|
||||
if normalizeClickHouseProtocol(config.ClickHouseProtocol) != clickHouseProtocolAuto {
|
||||
return []clickhouse.Protocol{primaryProtocol}
|
||||
}
|
||||
if primaryProtocol == clickhouse.Native {
|
||||
return []clickhouse.Protocol{primaryProtocol, clickhouse.HTTP}
|
||||
}
|
||||
return []clickhouse.Protocol{primaryProtocol, clickhouse.Native}
|
||||
}
|
||||
|
||||
func (c *ClickHouseDB) Connect(config connection.ConnectionConfig) error {
|
||||
if supported, reason := DriverRuntimeSupportStatus("clickhouse"); !supported {
|
||||
if strings.TrimSpace(reason) == "" {
|
||||
@@ -198,8 +362,14 @@ func (c *ClickHouseDB) Connect(config connection.ConnectionConfig) error {
|
||||
runConfig := normalizeClickHouseConfig(config)
|
||||
c.pingTimeout = getConnectTimeout(runConfig)
|
||||
c.database = runConfig.Database
|
||||
logger.Infof("ClickHouse 连接准备:地址=%s:%d 数据库=%s 用户=%s 协议选择=%s SSL=%t SSH=%t 超时=%s",
|
||||
runConfig.Host, runConfig.Port, runConfig.Database, runConfig.User,
|
||||
normalizeClickHouseProtocol(runConfig.ClickHouseProtocol), runConfig.UseSSL, runConfig.UseSSH, c.pingTimeout)
|
||||
|
||||
if runConfig.UseSSH {
|
||||
if normalizeClickHouseProtocol(runConfig.ClickHouseProtocol) == clickHouseProtocolAuto && detectClickHouseProtocol(runConfig) == clickhouse.HTTP {
|
||||
runConfig.ClickHouseProtocol = clickHouseProtocolHTTP
|
||||
}
|
||||
logger.Infof("ClickHouse 使用 SSH 连接:地址=%s:%d 用户=%s", runConfig.Host, runConfig.Port, runConfig.User)
|
||||
forwarder, err := ssh.GetOrCreateLocalForwarder(runConfig.SSH, runConfig.Host, runConfig.Port)
|
||||
if err != nil {
|
||||
@@ -229,19 +399,17 @@ func (c *ClickHouseDB) Connect(config connection.ConnectionConfig) error {
|
||||
|
||||
var failures []string
|
||||
for idx, attempt := range attempts {
|
||||
primaryProtocol := detectClickHouseProtocol(attempt)
|
||||
protocols := []clickhouse.Protocol{primaryProtocol}
|
||||
if primaryProtocol == clickhouse.Native {
|
||||
protocols = append(protocols, clickhouse.HTTP)
|
||||
} else {
|
||||
protocols = append(protocols, clickhouse.Native)
|
||||
}
|
||||
|
||||
protocols := clickHouseProtocolsForAttempt(attempt)
|
||||
for pIdx, protocol := range protocols {
|
||||
protocolConfig := withClickHouseProtocol(attempt, protocol)
|
||||
logger.Infof("ClickHouse 连接尝试:第%d组/%d 协议=%s 地址=%s:%d SSL=%t",
|
||||
idx+1, len(attempts), clickHouseProtocolName(protocol), protocolConfig.Host, protocolConfig.Port, protocolConfig.UseSSL)
|
||||
c.conn = clickhouse.OpenDB(c.buildClickHouseOptions(protocolConfig))
|
||||
if err := c.Ping(); err != nil {
|
||||
failures = append(failures, fmt.Sprintf("第%d次连接验证失败(protocol=%s): %v", idx+1, protocol.String(), err))
|
||||
failureMessage := clickHouseAttemptFailureMessage(protocol, err)
|
||||
failures = append(failures, fmt.Sprintf("第%d次连接验证失败(protocol=%s): %s", idx+1, protocol.String(), failureMessage))
|
||||
logger.Warnf("ClickHouse 连接尝试失败:第%d组/%d 协议=%s 地址=%s:%d SSL=%t 原因=%s",
|
||||
idx+1, len(attempts), clickHouseProtocolName(protocol), protocolConfig.Host, protocolConfig.Port, protocolConfig.UseSSL, failureMessage)
|
||||
if c.conn != nil {
|
||||
_ = c.conn.Close()
|
||||
c.conn = nil
|
||||
@@ -258,12 +426,13 @@ func (c *ClickHouseDB) Connect(config connection.ConnectionConfig) error {
|
||||
if pIdx > 0 {
|
||||
logger.Warnf("ClickHouse 已自动切换连接协议为 %s(常见于 %s HTTP 端口)", protocol.String(), clickHouseHTTPPortHint)
|
||||
}
|
||||
logger.Infof("ClickHouse 连接验证成功:协议=%s 地址=%s:%d 数据库=%s", clickHouseProtocolName(protocol), protocolConfig.Host, protocolConfig.Port, protocolConfig.Database)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
_ = c.Close()
|
||||
return fmt.Errorf("连接建立后验证失败(可检查 ClickHouse 端口与协议是否匹配:Native=9000/9440,HTTP=%s):%s", clickHouseHTTPPortHint, strings.Join(failures, ";"))
|
||||
return fmt.Errorf("%s", clickHouseConnectFailureSummary(runConfig, failures))
|
||||
}
|
||||
|
||||
func (c *ClickHouseDB) Close() error {
|
||||
|
||||
@@ -160,6 +160,13 @@ func TestDetectClickHouseProtocolTreatsHTTPPortsAsHTTP(t *testing.T) {
|
||||
},
|
||||
expected: clickhouse.HTTP,
|
||||
},
|
||||
{
|
||||
name: "custom http port 8125",
|
||||
config: connection.ConnectionConfig{
|
||||
Port: 8125,
|
||||
},
|
||||
expected: clickhouse.HTTP,
|
||||
},
|
||||
{
|
||||
name: "https port",
|
||||
config: connection.ConnectionConfig{
|
||||
@@ -181,6 +188,30 @@ func TestDetectClickHouseProtocolTreatsHTTPPortsAsHTTP(t *testing.T) {
|
||||
},
|
||||
expected: clickhouse.Native,
|
||||
},
|
||||
{
|
||||
name: "host http scheme",
|
||||
config: connection.ConnectionConfig{
|
||||
Host: "http://clickhouse.example.com",
|
||||
Port: 8125,
|
||||
},
|
||||
expected: clickhouse.HTTP,
|
||||
},
|
||||
{
|
||||
name: "manual http overrides native port",
|
||||
config: connection.ConnectionConfig{
|
||||
ClickHouseProtocol: "http",
|
||||
Port: 9000,
|
||||
},
|
||||
expected: clickhouse.HTTP,
|
||||
},
|
||||
{
|
||||
name: "manual native overrides http port",
|
||||
config: connection.ConnectionConfig{
|
||||
ClickHouseProtocol: "native",
|
||||
Port: 8123,
|
||||
},
|
||||
expected: clickhouse.Native,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
@@ -192,6 +223,172 @@ func TestDetectClickHouseProtocolTreatsHTTPPortsAsHTTP(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeClickHouseConfigParsesHTTPHostScheme(t *testing.T) {
|
||||
config := normalizeClickHouseConfig(connection.ConnectionConfig{
|
||||
Type: "clickhouse",
|
||||
Host: "https://clickhouse.example.com:8125/analytics",
|
||||
User: "alice",
|
||||
Password: "secret",
|
||||
})
|
||||
|
||||
if config.Host != "clickhouse.example.com" {
|
||||
t.Fatalf("expected host without scheme, got %q", config.Host)
|
||||
}
|
||||
if config.Port != 8125 {
|
||||
t.Fatalf("expected port 8125, got %d", config.Port)
|
||||
}
|
||||
if config.Database != "analytics" {
|
||||
t.Fatalf("expected database analytics, got %q", config.Database)
|
||||
}
|
||||
if config.ClickHouseProtocol != "http" {
|
||||
t.Fatalf("expected http protocol hint, got %q", config.ClickHouseProtocol)
|
||||
}
|
||||
if !config.UseSSL || config.SSLMode != sslModeRequired {
|
||||
t.Fatalf("expected https host to enable required SSL, got useSSL=%v sslMode=%q", config.UseSSL, config.SSLMode)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeClickHouseConfigKeepsManualNativeWhenHostHasHTTPScheme(t *testing.T) {
|
||||
config := normalizeClickHouseConfig(connection.ConnectionConfig{
|
||||
Type: "clickhouse",
|
||||
Host: "http://clickhouse.example.com:9001/analytics",
|
||||
ClickHouseProtocol: "native",
|
||||
User: "alice",
|
||||
Password: "secret",
|
||||
})
|
||||
|
||||
if config.Host != "clickhouse.example.com" {
|
||||
t.Fatalf("expected host without scheme, got %q", config.Host)
|
||||
}
|
||||
if config.Port != 9001 {
|
||||
t.Fatalf("expected user-provided native port 9001, got %d", config.Port)
|
||||
}
|
||||
if config.Database != "analytics" {
|
||||
t.Fatalf("expected database analytics, got %q", config.Database)
|
||||
}
|
||||
if config.ClickHouseProtocol != "native" {
|
||||
t.Fatalf("expected manual native protocol to be preserved, got %q", config.ClickHouseProtocol)
|
||||
}
|
||||
if config.UseSSL {
|
||||
t.Fatalf("manual native protocol should not be forced to HTTP TLS by http scheme")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeClickHouseConfigUsesNativeDefaultPortForManualNativeHTTPScheme(t *testing.T) {
|
||||
config := normalizeClickHouseConfig(connection.ConnectionConfig{
|
||||
Type: "clickhouse",
|
||||
Host: "https://clickhouse.example.com/analytics",
|
||||
ClickHouseProtocol: "native",
|
||||
})
|
||||
|
||||
if config.Host != "clickhouse.example.com" {
|
||||
t.Fatalf("expected host without scheme, got %q", config.Host)
|
||||
}
|
||||
if config.Port != defaultClickHousePort {
|
||||
t.Fatalf("expected native default port %d, got %d", defaultClickHousePort, config.Port)
|
||||
}
|
||||
if config.ClickHouseProtocol != "native" {
|
||||
t.Fatalf("expected manual native protocol to be preserved, got %q", config.ClickHouseProtocol)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClickHouseProtocolMismatchIncludesHTTPParseBinaryResponse(t *testing.T) {
|
||||
err := errors.New("code: 27, message: Cannot parse input: expected '(' before: '\x02\x00\x01\x00'")
|
||||
if !isClickHouseProtocolMismatch(err) {
|
||||
t.Fatalf("expected binary parse response to be treated as protocol mismatch")
|
||||
}
|
||||
|
||||
message := clickHouseAttemptFailureMessage(clickhouse.Native, err)
|
||||
if !strings.Contains(message, "不像 Native") || strings.Contains(message, "\x00") {
|
||||
t.Fatalf("expected user-facing native mismatch message without binary bytes, got %q", message)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithClickHouseProtocolForcesProtocolSelection(t *testing.T) {
|
||||
httpConfig := withClickHouseProtocol(connection.ConnectionConfig{
|
||||
Type: "clickhouse",
|
||||
Host: "clickhouse.example.com",
|
||||
Port: 8125,
|
||||
}, clickhouse.HTTP)
|
||||
if protocol := detectClickHouseProtocol(httpConfig); protocol != clickhouse.HTTP {
|
||||
t.Fatalf("expected forced HTTP protocol, got %s", protocol.String())
|
||||
}
|
||||
|
||||
nativeConfig := withClickHouseProtocol(connection.ConnectionConfig{
|
||||
Type: "clickhouse",
|
||||
Host: "http://clickhouse.example.com",
|
||||
Port: 8125,
|
||||
}, clickhouse.Native)
|
||||
if protocol := detectClickHouseProtocol(nativeConfig); protocol != clickhouse.Native {
|
||||
t.Fatalf("expected forced Native protocol, got %s", protocol.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestClickHouseProtocolsForAttemptOnlyFallsBackInAutoMode(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
config connection.ConnectionConfig
|
||||
expected []clickhouse.Protocol
|
||||
}{
|
||||
{
|
||||
name: "auto native falls back to http",
|
||||
config: connection.ConnectionConfig{
|
||||
Type: "clickhouse",
|
||||
Port: 9000,
|
||||
},
|
||||
expected: []clickhouse.Protocol{clickhouse.Native, clickhouse.HTTP},
|
||||
},
|
||||
{
|
||||
name: "auto http falls back to native",
|
||||
config: connection.ConnectionConfig{
|
||||
Type: "clickhouse",
|
||||
Port: 8125,
|
||||
},
|
||||
expected: []clickhouse.Protocol{clickhouse.HTTP, clickhouse.Native},
|
||||
},
|
||||
{
|
||||
name: "manual http does not try native",
|
||||
config: connection.ConnectionConfig{
|
||||
Type: "clickhouse",
|
||||
Port: 9000,
|
||||
ClickHouseProtocol: "http",
|
||||
},
|
||||
expected: []clickhouse.Protocol{clickhouse.HTTP},
|
||||
},
|
||||
{
|
||||
name: "manual native does not try http",
|
||||
config: connection.ConnectionConfig{
|
||||
Type: "clickhouse",
|
||||
Port: 8125,
|
||||
ClickHouseProtocol: "native",
|
||||
},
|
||||
expected: []clickhouse.Protocol{clickhouse.Native},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := clickHouseProtocolsForAttempt(tt.config)
|
||||
if len(got) != len(tt.expected) {
|
||||
t.Fatalf("expected protocols %v, got %v", protocolNames(tt.expected), protocolNames(got))
|
||||
}
|
||||
for idx := range got {
|
||||
if got[idx] != tt.expected[idx] {
|
||||
t.Fatalf("expected protocols %v, got %v", protocolNames(tt.expected), protocolNames(got))
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func protocolNames(protocols []clickhouse.Protocol) []string {
|
||||
names := make([]string, 0, len(protocols))
|
||||
for _, protocol := range protocols {
|
||||
names = append(names, protocol.String())
|
||||
}
|
||||
return names
|
||||
}
|
||||
|
||||
type fakeClickHouseDriver struct{}
|
||||
|
||||
func (fakeClickHouseDriver) Open(name string) (driver.Conn, error) {
|
||||
|
||||
@@ -3,6 +3,7 @@ package db
|
||||
import (
|
||||
"GoNavi-Wails/internal/connection"
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
@@ -64,6 +65,20 @@ type BatchApplier interface {
|
||||
ApplyChanges(tableName string, changes connection.ChangeSet) error
|
||||
}
|
||||
|
||||
func requireSingleRowAffected(result sql.Result, action string) error {
|
||||
affected, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return fmt.Errorf("%s未生效:无法确认影响行数:%v", action, err)
|
||||
}
|
||||
if affected == 0 {
|
||||
return fmt.Errorf("%s未生效:未匹配到任何行", action)
|
||||
}
|
||||
if affected != 1 {
|
||||
return fmt.Errorf("%s未生效:影响了 %d 行,期望只影响 1 行", action, affected)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type databaseFactory func() Database
|
||||
|
||||
var databaseFactories = map[string]databaseFactory{
|
||||
|
||||
21
internal/db/driver_agent_revisions_gen.go
Normal file
21
internal/db/driver_agent_revisions_gen.go
Normal file
@@ -0,0 +1,21 @@
|
||||
// Code generated by tools/generate-driver-agent-revisions.sh; DO NOT EDIT.
|
||||
|
||||
package db
|
||||
|
||||
func init() {
|
||||
optionalDriverAgentRevisions = map[string]string{
|
||||
"mariadb": "src-d6c5c6717338834c",
|
||||
"diros": "src-ed4f0f64ed28d3fa",
|
||||
"sphinx": "src-f52324f0a812d7c8",
|
||||
"sqlserver": "src-ec165f18de9cd8b3",
|
||||
"sqlite": "src-9dea6c76bc931114",
|
||||
"duckdb": "src-14027ac1de3c50c7",
|
||||
"dameng": "src-1a08880ff5bbcf31",
|
||||
"kingbase": "src-28eed0e4d942b724",
|
||||
"highgo": "src-76146bf97f07f25c",
|
||||
"vastbase": "src-555b60c4863542b6",
|
||||
"mongodb": "src-2540a7350c4243aa",
|
||||
"tdengine": "src-ce3e4a9c46f6b92d",
|
||||
"clickhouse": "src-78e5ada4da56704d",
|
||||
}
|
||||
}
|
||||
@@ -36,6 +36,11 @@ var optionalGoDrivers = map[string]struct{}{
|
||||
"clickhouse": {},
|
||||
}
|
||||
|
||||
// optionalDriverAgentRevisions 记录 GoNavi 对各可选 driver-agent 包装逻辑的兼容版本。
|
||||
// 该 map 由 tools/generate-driver-agent-revisions.sh 按 driver-agent 源码依赖自动生成,
|
||||
// 避免人工判断需要 bump 哪个驱动 revision。
|
||||
var optionalDriverAgentRevisions = map[string]string{}
|
||||
|
||||
var (
|
||||
externalDriverDirMu sync.RWMutex
|
||||
externalDriverDir string
|
||||
@@ -105,6 +110,10 @@ func IsOptionalGoDriverBuildIncluded(driverType string) bool {
|
||||
return optionalGoDriverBuildIncluded(normalizeRuntimeDriverType(driverType))
|
||||
}
|
||||
|
||||
func OptionalDriverAgentRevision(driverType string) string {
|
||||
return strings.TrimSpace(optionalDriverAgentRevisions[normalizeRuntimeDriverType(driverType)])
|
||||
}
|
||||
|
||||
// IsBuiltinDriver 返回指定驱动类型是否为核心内置驱动(始终可用,无需安装)。
|
||||
func IsBuiltinDriver(driverType string) bool {
|
||||
_, ok := coreBuiltinDrivers[normalizeRuntimeDriverType(driverType)]
|
||||
|
||||
@@ -31,6 +31,17 @@ func TestBuiltinLikeDriversRemainAvailable(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestOptionalDriverAgentRevisionsGeneratedForOptionalDrivers(t *testing.T) {
|
||||
for driverType := range optionalGoDrivers {
|
||||
if revision := OptionalDriverAgentRevision(driverType); revision == "" {
|
||||
t.Fatalf("%s 缺少自动生成的 driver-agent revision", driverType)
|
||||
}
|
||||
}
|
||||
if OptionalDriverAgentRevision("doris") != OptionalDriverAgentRevision("diros") {
|
||||
t.Fatalf("doris/diros revision 应归一一致")
|
||||
}
|
||||
}
|
||||
|
||||
func TestManagedDriverRequiresInstallMarker(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
SetExternalDriverDownloadDirectory(tmpDir)
|
||||
|
||||
@@ -624,8 +624,8 @@ func (m *MySQLDB) ApplyChanges(tableName string, changes connection.ChangeSet) e
|
||||
if err != nil {
|
||||
return fmt.Errorf("删除失败:%v", err)
|
||||
}
|
||||
if affected, err := res.RowsAffected(); err == nil && affected == 0 {
|
||||
return fmt.Errorf("删除未生效:未匹配到任何行")
|
||||
if err := requireSingleRowAffected(res, "删除"); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
@@ -658,8 +658,8 @@ func (m *MySQLDB) ApplyChanges(tableName string, changes connection.ChangeSet) e
|
||||
if err != nil {
|
||||
return fmt.Errorf("更新失败:%v", err)
|
||||
}
|
||||
if affected, err := res.RowsAffected(); err == nil && affected == 0 {
|
||||
return fmt.Errorf("更新未生效:未匹配到任何行")
|
||||
if err := requireSingleRowAffected(res, "更新"); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -23,6 +23,7 @@ import (
|
||||
const (
|
||||
optionalAgentMethodConnect = "connect"
|
||||
optionalAgentMethodClose = "close"
|
||||
optionalAgentMethodMetadata = "metadata"
|
||||
optionalAgentMethodPing = "ping"
|
||||
optionalAgentMethodQuery = "query"
|
||||
optionalAgentMethodExec = "exec"
|
||||
@@ -58,6 +59,12 @@ type optionalAgentResponse struct {
|
||||
RowsAffected int64 `json:"rowsAffected,omitempty"`
|
||||
}
|
||||
|
||||
type OptionalDriverAgentMetadata struct {
|
||||
DriverType string `json:"driverType,omitempty"`
|
||||
AgentRevision string `json:"agentRevision,omitempty"`
|
||||
ProtocolSchema string `json:"protocolSchema,omitempty"`
|
||||
}
|
||||
|
||||
type optionalDriverAgentClient struct {
|
||||
cmd *exec.Cmd
|
||||
stdin io.WriteCloser
|
||||
@@ -69,6 +76,25 @@ type optionalDriverAgentClient struct {
|
||||
driver string
|
||||
}
|
||||
|
||||
func ProbeOptionalDriverAgentMetadata(driverType string, executablePath string) (OptionalDriverAgentMetadata, error) {
|
||||
client, err := newOptionalDriverAgentClient(driverType, executablePath)
|
||||
if err != nil {
|
||||
return OptionalDriverAgentMetadata{}, err
|
||||
}
|
||||
defer func() {
|
||||
_ = client.close()
|
||||
}()
|
||||
|
||||
var metadata OptionalDriverAgentMetadata
|
||||
if err := client.call(optionalAgentRequest{Method: optionalAgentMethodMetadata}, &metadata, nil, nil); err != nil {
|
||||
return OptionalDriverAgentMetadata{}, err
|
||||
}
|
||||
metadata.DriverType = normalizeRuntimeDriverType(metadata.DriverType)
|
||||
metadata.AgentRevision = strings.TrimSpace(metadata.AgentRevision)
|
||||
metadata.ProtocolSchema = strings.TrimSpace(metadata.ProtocolSchema)
|
||||
return metadata, nil
|
||||
}
|
||||
|
||||
func newOptionalDriverAgentClient(driverType string, executablePath string) (*optionalDriverAgentClient, error) {
|
||||
pathText := strings.TrimSpace(executablePath)
|
||||
if pathText == "" {
|
||||
|
||||
@@ -24,8 +24,16 @@ var (
|
||||
)
|
||||
|
||||
type oracleRecordingState struct {
|
||||
mu sync.Mutex
|
||||
execArgs [][]driver.NamedValue
|
||||
mu sync.Mutex
|
||||
execQueries []string
|
||||
execArgs [][]driver.NamedValue
|
||||
rowsAffected int64
|
||||
}
|
||||
|
||||
func (s *oracleRecordingState) snapshotExecQueries() []string {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
return append([]string(nil), s.execQueries...)
|
||||
}
|
||||
|
||||
func (s *oracleRecordingState) snapshotExecArgs() [][]driver.NamedValue {
|
||||
@@ -63,11 +71,12 @@ func (c *oracleRecordingConn) Close() error { return nil }
|
||||
|
||||
func (c *oracleRecordingConn) Begin() (driver.Tx, error) { return oracleRecordingTx{}, nil }
|
||||
|
||||
func (c *oracleRecordingConn) ExecContext(_ context.Context, _ string, args []driver.NamedValue) (driver.Result, error) {
|
||||
func (c *oracleRecordingConn) ExecContext(_ context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
|
||||
c.state.mu.Lock()
|
||||
defer c.state.mu.Unlock()
|
||||
c.state.execQueries = append(c.state.execQueries, query)
|
||||
c.state.execArgs = append(c.state.execArgs, append([]driver.NamedValue(nil), args...))
|
||||
return driver.RowsAffected(1), nil
|
||||
return driver.RowsAffected(c.state.rowsAffected), nil
|
||||
}
|
||||
|
||||
func (c *oracleRecordingConn) QueryContext(_ context.Context, query string, _ []driver.NamedValue) (driver.Rows, error) {
|
||||
@@ -126,7 +135,7 @@ func openOracleRecordingDB(t *testing.T) (*sql.DB, *oracleRecordingState) {
|
||||
oracleRecordingDriverMu.Lock()
|
||||
oracleRecordingDriverSeq++
|
||||
dsn := fmt.Sprintf("oracle-recording-%d", oracleRecordingDriverSeq)
|
||||
state := &oracleRecordingState{}
|
||||
state := &oracleRecordingState{rowsAffected: 1}
|
||||
oracleRecordingDriverStates[dsn] = state
|
||||
oracleRecordingDriverMu.Unlock()
|
||||
|
||||
@@ -145,6 +154,82 @@ func openOracleRecordingDB(t *testing.T) (*sql.DB, *oracleRecordingState) {
|
||||
return dbConn, state
|
||||
}
|
||||
|
||||
func TestOracleApplyChangesReturnsErrorWhenUpdateMatchesNoRows(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dbConn, state := openOracleRecordingDB(t)
|
||||
state.rowsAffected = 0
|
||||
oracleDB := &OracleDB{conn: dbConn}
|
||||
|
||||
changes := connection.ChangeSet{
|
||||
Updates: []connection.UpdateRow{{
|
||||
Keys: map[string]interface{}{
|
||||
"ID": 7,
|
||||
},
|
||||
Values: map[string]interface{}{
|
||||
"NAME": "new-name",
|
||||
},
|
||||
}},
|
||||
}
|
||||
|
||||
err := oracleDB.ApplyChanges("MYCIMLED.EDC_LOG", changes)
|
||||
if err == nil {
|
||||
t.Fatal("期望更新未匹配到行时返回错误,实际为 nil")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "更新未生效") {
|
||||
t.Fatalf("错误信息应提示更新未生效,实际=%v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOracleApplyChangesReturnsErrorWhenUpdateAffectsMultipleRows(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dbConn, state := openOracleRecordingDB(t)
|
||||
state.rowsAffected = 2
|
||||
oracleDB := &OracleDB{conn: dbConn}
|
||||
|
||||
changes := connection.ChangeSet{
|
||||
Updates: []connection.UpdateRow{{
|
||||
Keys: map[string]interface{}{
|
||||
"ID": 7,
|
||||
},
|
||||
Values: map[string]interface{}{
|
||||
"NAME": "new-name",
|
||||
},
|
||||
}},
|
||||
}
|
||||
|
||||
err := oracleDB.ApplyChanges("MYCIMLED.EDC_LOG", changes)
|
||||
if err == nil {
|
||||
t.Fatal("期望更新影响多行时返回错误,实际为 nil")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "影响了 2 行") {
|
||||
t.Fatalf("错误信息应提示影响多行,实际=%v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOracleApplyChangesReturnsErrorWhenDeleteAffectsMultipleRows(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dbConn, state := openOracleRecordingDB(t)
|
||||
state.rowsAffected = 2
|
||||
oracleDB := &OracleDB{conn: dbConn}
|
||||
|
||||
changes := connection.ChangeSet{
|
||||
Deletes: []map[string]interface{}{{
|
||||
"STATUS": "stale",
|
||||
}},
|
||||
}
|
||||
|
||||
err := oracleDB.ApplyChanges("MYCIMLED.EDC_LOG", changes)
|
||||
if err == nil {
|
||||
t.Fatal("期望删除影响多行时返回错误,实际为 nil")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "影响了 2 行") {
|
||||
t.Fatalf("错误信息应提示影响多行,实际=%v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOracleApplyChangesNormalizesTemporalStringsForUpdate(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -181,3 +266,87 @@ func TestOracleApplyChangesNormalizesTemporalStringsForUpdate(t *testing.T) {
|
||||
t.Fatalf("日期主键字段应绑定为 time.Time,实际=%#v(%T)", args[1].Value, args[1].Value)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOracleApplyChangesUsesUnquotedRowIDLocator(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dbConn, state := openOracleRecordingDB(t)
|
||||
oracleDB := &OracleDB{conn: dbConn}
|
||||
|
||||
changes := connection.ChangeSet{
|
||||
LocatorStrategy: "oracle-rowid",
|
||||
Updates: []connection.UpdateRow{{
|
||||
Keys: map[string]interface{}{
|
||||
"ROWID": "AAAA",
|
||||
},
|
||||
Values: map[string]interface{}{
|
||||
"NAME": "new-name",
|
||||
},
|
||||
}},
|
||||
}
|
||||
|
||||
if err := oracleDB.ApplyChanges("MYCIMLED.EDC_LOG", changes); err != nil {
|
||||
t.Fatalf("ApplyChanges 返回错误: %v", err)
|
||||
}
|
||||
|
||||
executions := state.snapshotExecQueries()
|
||||
if len(executions) != 1 {
|
||||
t.Fatalf("期望执行 1 条更新,实际 %d 条", len(executions))
|
||||
}
|
||||
query := executions[0]
|
||||
if !strings.Contains(query, "ROWID = :2") {
|
||||
t.Fatalf("ROWID 定位条件不正确: %s", query)
|
||||
}
|
||||
if strings.Contains(query, "\"ROWID\" =") {
|
||||
t.Fatalf("ROWID 不应被当作普通列引用: %s", query)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMySQLApplyChangesReturnsErrorWhenUpdateAffectsMultipleRows(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dbConn, state := openOracleRecordingDB(t)
|
||||
state.rowsAffected = 2
|
||||
mysqlDB := &MySQLDB{conn: dbConn}
|
||||
|
||||
changes := connection.ChangeSet{
|
||||
Updates: []connection.UpdateRow{{
|
||||
Keys: map[string]interface{}{
|
||||
"id": 7,
|
||||
},
|
||||
Values: map[string]interface{}{
|
||||
"name": "new-name",
|
||||
},
|
||||
}},
|
||||
}
|
||||
|
||||
err := mysqlDB.ApplyChanges("users", changes)
|
||||
if err == nil {
|
||||
t.Fatal("期望 MySQL 更新影响多行时返回错误,实际为 nil")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "影响了 2 行") {
|
||||
t.Fatalf("错误信息应提示影响多行,实际=%v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPostgresApplyChangesReturnsErrorWhenDeleteAffectsMultipleRows(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dbConn, state := openOracleRecordingDB(t)
|
||||
state.rowsAffected = 2
|
||||
postgresDB := &PostgresDB{conn: dbConn}
|
||||
|
||||
changes := connection.ChangeSet{
|
||||
Deletes: []map[string]interface{}{{
|
||||
"id": 7,
|
||||
}},
|
||||
}
|
||||
|
||||
err := postgresDB.ApplyChanges("public.users", changes)
|
||||
if err == nil {
|
||||
t.Fatal("期望 PostgreSQL 删除影响多行时返回错误,实际为 nil")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "影响了 2 行") {
|
||||
t.Fatalf("错误信息应提示影响多行,实际=%v", err)
|
||||
}
|
||||
}
|
||||
|
||||
32
internal/db/oracle_dsn_test.go
Normal file
32
internal/db/oracle_dsn_test.go
Normal file
@@ -0,0 +1,32 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"GoNavi-Wails/internal/connection"
|
||||
)
|
||||
|
||||
func TestOracleGetDSNIncludesQueryPerformanceOptions(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dsn := (&OracleDB{}).getDSN(connection.ConnectionConfig{
|
||||
Host: "db.example.com",
|
||||
Port: 1521,
|
||||
User: "scott",
|
||||
Password: "tiger",
|
||||
Database: "ORCLPDB1",
|
||||
})
|
||||
|
||||
parsed, err := url.Parse(dsn)
|
||||
if err != nil {
|
||||
t.Fatalf("解析 Oracle DSN 失败: %v", err)
|
||||
}
|
||||
query := parsed.Query()
|
||||
if got := query.Get("PREFETCH_ROWS"); got != "10000" {
|
||||
t.Fatalf("PREFETCH_ROWS = %q, want 10000", got)
|
||||
}
|
||||
if got := query.Get("LOB FETCH"); got != "POST" {
|
||||
t.Fatalf("LOB FETCH = %q, want POST", got)
|
||||
}
|
||||
}
|
||||
@@ -44,6 +44,10 @@ func (o *OracleDB) getDSN(config connection.ConnectionConfig) string {
|
||||
q.Set("SSL", "TRUE")
|
||||
q.Set("SSL VERIFY", "FALSE")
|
||||
}
|
||||
// 提高 prefetch 行数,减少大结果集的网络往返次数(默认仅 25 行/次)
|
||||
q.Set("PREFETCH_ROWS", "10000")
|
||||
// LOB 数据延迟加载,避免大 LOB 列影响普通查询性能
|
||||
q.Set("LOB FETCH", "POST")
|
||||
if encoded := q.Encode(); encoded != "" {
|
||||
u.RawQuery = encoded
|
||||
}
|
||||
@@ -263,16 +267,31 @@ func (o *OracleDB) GetCreateStatement(dbName, tableName string) (string, error)
|
||||
}
|
||||
|
||||
func (o *OracleDB) GetColumns(dbName, tableName string) ([]connection.ColumnDefinition, error) {
|
||||
query := fmt.Sprintf(`SELECT column_name, data_type, nullable, data_default
|
||||
FROM all_tab_columns
|
||||
WHERE owner = '%s' AND table_name = '%s'
|
||||
ORDER BY column_id`, strings.ToUpper(dbName), strings.ToUpper(tableName))
|
||||
query := fmt.Sprintf(`SELECT c.column_name, c.data_type, c.nullable, c.data_default,
|
||||
CASE WHEN pk.column_name IS NOT NULL THEN 'PRI' ELSE '' END AS column_key
|
||||
FROM all_tab_columns c
|
||||
LEFT JOIN (
|
||||
SELECT cols.owner, cols.table_name, cols.column_name
|
||||
FROM all_constraints cons
|
||||
JOIN all_cons_columns cols
|
||||
ON cons.owner = cols.owner AND cons.constraint_name = cols.constraint_name
|
||||
WHERE cons.constraint_type = 'P'
|
||||
) pk ON c.owner = pk.owner AND c.table_name = pk.table_name AND c.column_name = pk.column_name
|
||||
WHERE c.owner = '%s' AND c.table_name = '%s'
|
||||
ORDER BY c.column_id`, strings.ToUpper(dbName), strings.ToUpper(tableName))
|
||||
|
||||
if dbName == "" {
|
||||
query = fmt.Sprintf(`SELECT column_name, data_type, nullable, data_default
|
||||
FROM user_tab_columns
|
||||
WHERE table_name = '%s'
|
||||
ORDER BY column_id`, strings.ToUpper(tableName))
|
||||
query = fmt.Sprintf(`SELECT c.column_name, c.data_type, c.nullable, c.data_default,
|
||||
CASE WHEN pk.column_name IS NOT NULL THEN 'PRI' ELSE '' END AS column_key
|
||||
FROM user_tab_columns c
|
||||
LEFT JOIN (
|
||||
SELECT cols.table_name, cols.column_name
|
||||
FROM user_constraints cons
|
||||
JOIN user_cons_columns cols USING (constraint_name)
|
||||
WHERE cons.constraint_type = 'P'
|
||||
) pk ON c.table_name = pk.table_name AND c.column_name = pk.column_name
|
||||
WHERE c.table_name = '%s'
|
||||
ORDER BY c.column_id`, strings.ToUpper(tableName))
|
||||
}
|
||||
|
||||
data, _, err := o.Query(query)
|
||||
@@ -286,6 +305,7 @@ func (o *OracleDB) GetColumns(dbName, tableName string) ([]connection.ColumnDefi
|
||||
Name: fmt.Sprintf("%v", row["COLUMN_NAME"]),
|
||||
Type: fmt.Sprintf("%v", row["DATA_TYPE"]),
|
||||
Nullable: fmt.Sprintf("%v", row["NULLABLE"]),
|
||||
Key: fmt.Sprintf("%v", row["COLUMN_KEY"]),
|
||||
}
|
||||
|
||||
if row["DATA_DEFAULT"] != nil {
|
||||
@@ -299,17 +319,31 @@ func (o *OracleDB) GetColumns(dbName, tableName string) ([]connection.ColumnDefi
|
||||
}
|
||||
|
||||
func (o *OracleDB) GetIndexes(dbName, tableName string) ([]connection.IndexDefinition, error) {
|
||||
query := fmt.Sprintf(`SELECT index_name, column_name, uniqueness
|
||||
FROM all_ind_columns
|
||||
JOIN all_indexes USING (index_name, owner)
|
||||
WHERE table_owner = '%s' AND table_name = '%s'`,
|
||||
strings.ToUpper(dbName), strings.ToUpper(tableName))
|
||||
esc := func(s string) string { return strings.ReplaceAll(strings.ToUpper(strings.TrimSpace(s)), "'", "''") }
|
||||
table := esc(tableName)
|
||||
if table == "" {
|
||||
return nil, fmt.Errorf("表名不能为空")
|
||||
}
|
||||
|
||||
if dbName == "" {
|
||||
query = fmt.Sprintf(`SELECT index_name, column_name, uniqueness
|
||||
FROM user_ind_columns
|
||||
JOIN user_indexes USING (index_name)
|
||||
WHERE table_name = '%s'`, strings.ToUpper(tableName))
|
||||
query := fmt.Sprintf(`SELECT c.index_name, c.column_name, i.uniqueness, c.column_position, i.index_type
|
||||
FROM all_ind_columns c
|
||||
JOIN all_indexes i ON i.owner = c.index_owner AND i.index_name = c.index_name
|
||||
WHERE c.table_owner = '%s'
|
||||
AND c.table_name = '%s'
|
||||
AND c.column_name IS NOT NULL
|
||||
AND c.column_name NOT LIKE 'SYS_NC%%$'
|
||||
AND i.index_type NOT LIKE 'FUNCTION-BASED%%'
|
||||
ORDER BY c.index_name, c.column_position`, esc(dbName), table)
|
||||
|
||||
if strings.TrimSpace(dbName) == "" {
|
||||
query = fmt.Sprintf(`SELECT c.index_name, c.column_name, i.uniqueness, c.column_position, i.index_type
|
||||
FROM user_ind_columns c
|
||||
JOIN user_indexes i ON i.index_name = c.index_name
|
||||
WHERE c.table_name = '%s'
|
||||
AND c.column_name IS NOT NULL
|
||||
AND c.column_name NOT LIKE 'SYS_NC%%$'
|
||||
AND i.index_type NOT LIKE 'FUNCTION-BASED%%'
|
||||
ORDER BY c.index_name, c.column_position`, table)
|
||||
}
|
||||
|
||||
data, _, err := o.Query(query)
|
||||
@@ -317,19 +351,46 @@ func (o *OracleDB) GetIndexes(dbName, tableName string) ([]connection.IndexDefin
|
||||
return nil, err
|
||||
}
|
||||
|
||||
getValue := func(row map[string]interface{}, names ...string) interface{} {
|
||||
for _, name := range names {
|
||||
if value, ok := row[name]; ok {
|
||||
return value
|
||||
}
|
||||
for key, value := range row {
|
||||
if strings.EqualFold(key, name) {
|
||||
return value
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
parseInt := func(value interface{}) int {
|
||||
var n int
|
||||
_, _ = fmt.Sscanf(strings.TrimSpace(fmt.Sprintf("%v", value)), "%d", &n)
|
||||
return n
|
||||
}
|
||||
|
||||
var indexes []connection.IndexDefinition
|
||||
for _, row := range data {
|
||||
unique := 1
|
||||
if val, ok := row["UNIQUENESS"]; ok && val == "UNIQUE" {
|
||||
unique = 0
|
||||
uniqueness := strings.ToUpper(strings.TrimSpace(fmt.Sprintf("%v", getValue(row, "UNIQUENESS"))))
|
||||
nonUnique := 1
|
||||
if uniqueness == "UNIQUE" {
|
||||
nonUnique = 0
|
||||
}
|
||||
indexType := strings.ToUpper(strings.TrimSpace(fmt.Sprintf("%v", getValue(row, "INDEX_TYPE"))))
|
||||
if indexType == "" || indexType == "<NIL>" {
|
||||
indexType = "BTREE"
|
||||
}
|
||||
|
||||
idx := connection.IndexDefinition{
|
||||
Name: fmt.Sprintf("%v", row["INDEX_NAME"]),
|
||||
ColumnName: fmt.Sprintf("%v", row["COLUMN_NAME"]),
|
||||
NonUnique: unique,
|
||||
// SeqInIndex is harder to get in simple join, omitting or estimating
|
||||
IndexType: "BTREE", // Default assumption
|
||||
Name: strings.TrimSpace(fmt.Sprintf("%v", getValue(row, "INDEX_NAME"))),
|
||||
ColumnName: strings.TrimSpace(fmt.Sprintf("%v", getValue(row, "COLUMN_NAME"))),
|
||||
NonUnique: nonUnique,
|
||||
SeqInIndex: parseInt(getValue(row, "COLUMN_POSITION")),
|
||||
IndexType: indexType,
|
||||
}
|
||||
if idx.Name == "" || idx.ColumnName == "" || strings.EqualFold(idx.ColumnName, "<nil>") {
|
||||
continue
|
||||
}
|
||||
indexes = append(indexes, idx)
|
||||
}
|
||||
@@ -531,23 +592,38 @@ func (o *OracleDB) ApplyChanges(tableName string, changes connection.ChangeSet)
|
||||
qualifiedTable = quoteIdent(table)
|
||||
}
|
||||
|
||||
// 1. Deletes
|
||||
for _, pk := range changes.Deletes {
|
||||
isOracleRowIDLocator := strings.EqualFold(strings.TrimSpace(changes.LocatorStrategy), "oracle-rowid")
|
||||
buildWhere := func(keys map[string]interface{}, startIndex int) ([]string, []interface{}, int) {
|
||||
var wheres []string
|
||||
var args []interface{}
|
||||
idx := 0
|
||||
for k, v := range pk {
|
||||
idx := startIndex
|
||||
for k, v := range keys {
|
||||
idx++
|
||||
if isOracleRowIDLocator && strings.EqualFold(strings.TrimSpace(k), "ROWID") {
|
||||
wheres = append(wheres, fmt.Sprintf("ROWID = :%d", idx))
|
||||
args = append(args, v)
|
||||
continue
|
||||
}
|
||||
wheres = append(wheres, fmt.Sprintf("%s = :%d", quoteIdent(k), idx))
|
||||
args = append(args, normalizeOracleValueForWrite(k, v, columnTypeMap))
|
||||
}
|
||||
return wheres, args, idx
|
||||
}
|
||||
|
||||
// 1. Deletes
|
||||
for _, pk := range changes.Deletes {
|
||||
wheres, args, _ := buildWhere(pk, 0)
|
||||
if len(wheres) == 0 {
|
||||
continue
|
||||
}
|
||||
query := fmt.Sprintf("DELETE FROM %s WHERE %s", qualifiedTable, strings.Join(wheres, " AND "))
|
||||
if _, err := tx.Exec(query, args...); err != nil {
|
||||
res, err := tx.Exec(query, args...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("删除失败:%v", err)
|
||||
}
|
||||
if err := requireSingleRowAffected(res, "删除"); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// 2. Updates
|
||||
@@ -566,21 +642,21 @@ func (o *OracleDB) ApplyChanges(tableName string, changes connection.ChangeSet)
|
||||
continue
|
||||
}
|
||||
|
||||
var wheres []string
|
||||
for k, v := range update.Keys {
|
||||
idx++
|
||||
wheres = append(wheres, fmt.Sprintf("%s = :%d", quoteIdent(k), idx))
|
||||
args = append(args, normalizeOracleValueForWrite(k, v, columnTypeMap))
|
||||
}
|
||||
wheres, whereArgs, _ := buildWhere(update.Keys, idx)
|
||||
args = append(args, whereArgs...)
|
||||
|
||||
if len(wheres) == 0 {
|
||||
return fmt.Errorf("更新操作需要主键条件")
|
||||
}
|
||||
|
||||
query := fmt.Sprintf("UPDATE %s SET %s WHERE %s", qualifiedTable, strings.Join(sets, ", "), strings.Join(wheres, " AND "))
|
||||
if _, err := tx.Exec(query, args...); err != nil {
|
||||
res, err := tx.Exec(query, args...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("更新失败:%v", err)
|
||||
}
|
||||
if err := requireSingleRowAffected(res, "更新"); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// 3. Inserts
|
||||
@@ -602,9 +678,13 @@ func (o *OracleDB) ApplyChanges(tableName string, changes connection.ChangeSet)
|
||||
}
|
||||
|
||||
query := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", qualifiedTable, strings.Join(cols, ", "), strings.Join(placeholders, ", "))
|
||||
if _, err := tx.Exec(query, args...); err != nil {
|
||||
res, err := tx.Exec(query, args...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("插入失败:%v", err)
|
||||
}
|
||||
if affected, err := res.RowsAffected(); err == nil && affected == 0 {
|
||||
return fmt.Errorf("插入未生效:未影响任何行")
|
||||
}
|
||||
}
|
||||
|
||||
return tx.Commit()
|
||||
|
||||
@@ -408,6 +408,12 @@ JOIN pg_attribute a ON a.attrelid = t.oid AND a.attnum = x.attnum
|
||||
WHERE t.relkind IN ('r', 'p')
|
||||
AND t.relname = '%s'
|
||||
AND n.nspname = '%s'
|
||||
AND ix.indisvalid
|
||||
AND ix.indpred IS NULL
|
||||
AND x.ordinality <= ix.indnkeyatts
|
||||
AND NOT EXISTS (
|
||||
SELECT 1 FROM unnest(ix.indkey) AS expr_key(attnum) WHERE expr_key.attnum <= 0
|
||||
)
|
||||
ORDER BY i.relname, x.ordinality`, esc(table), esc(schema))
|
||||
|
||||
data, _, err := p.Query(query)
|
||||
@@ -758,9 +764,13 @@ func (p *PostgresDB) ApplyChanges(tableName string, changes connection.ChangeSet
|
||||
continue
|
||||
}
|
||||
query := fmt.Sprintf("DELETE FROM %s WHERE %s", qualifiedTable, strings.Join(wheres, " AND "))
|
||||
if _, err := tx.Exec(query, args...); err != nil {
|
||||
res, err := tx.Exec(query, args...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("删除失败:%v", err)
|
||||
}
|
||||
if err := requireSingleRowAffected(res, "删除"); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// 2. Updates
|
||||
@@ -791,9 +801,13 @@ func (p *PostgresDB) ApplyChanges(tableName string, changes connection.ChangeSet
|
||||
}
|
||||
|
||||
query := fmt.Sprintf("UPDATE %s SET %s WHERE %s", qualifiedTable, strings.Join(sets, ", "), strings.Join(wheres, " AND "))
|
||||
if _, err := tx.Exec(query, args...); err != nil {
|
||||
res, err := tx.Exec(query, args...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("更新失败:%v", err)
|
||||
}
|
||||
if err := requireSingleRowAffected(res, "更新"); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// 3. Inserts
|
||||
|
||||
251
tools/generate-driver-agent-revisions.sh
Executable file
251
tools/generate-driver-agent-revisions.sh
Executable file
@@ -0,0 +1,251 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"
|
||||
cd "$SCRIPT_DIR"
|
||||
|
||||
DEFAULT_DRIVERS=(mariadb diros sphinx sqlserver sqlite duckdb dameng kingbase highgo vastbase mongodb tdengine clickhouse)
|
||||
OUTPUT_FILE="internal/db/driver_agent_revisions_gen.go"
|
||||
|
||||
usage() {
|
||||
cat <<'EOF'
|
||||
用法:
|
||||
./tools/generate-driver-agent-revisions.sh [选项]
|
||||
|
||||
选项:
|
||||
--platform <GOOS/GOARCH> 按目标平台解析 Go build tags,默认使用当前 Go 环境
|
||||
--drivers <列表> 指定驱动列表(逗号分隔),默认生成所有 optional driver
|
||||
-h, --help 显示帮助
|
||||
EOF
|
||||
}
|
||||
|
||||
normalize_driver() {
|
||||
local value
|
||||
value="$(printf '%s' "$1" | tr '[:upper:]' '[:lower:]' | tr -d '[:space:]')"
|
||||
case "$value" in
|
||||
doris|diros) echo "diros" ;;
|
||||
mariadb|diros|sphinx|sqlserver|sqlite|duckdb|dameng|kingbase|highgo|vastbase|mongodb|tdengine|clickhouse)
|
||||
echo "$value"
|
||||
;;
|
||||
*)
|
||||
return 1
|
||||
;;
|
||||
esac
|
||||
}
|
||||
|
||||
build_driver_name() {
|
||||
echo "$1"
|
||||
}
|
||||
|
||||
hash_file() {
|
||||
local target="$1"
|
||||
if command -v sha256sum >/dev/null 2>&1; then
|
||||
sha256sum "$target" | awk '{print $1}'
|
||||
return
|
||||
fi
|
||||
if command -v shasum >/dev/null 2>&1; then
|
||||
shasum -a 256 "$target" | awk '{print $1}'
|
||||
return
|
||||
fi
|
||||
echo "未找到 sha256sum 或 shasum" >&2
|
||||
exit 1
|
||||
}
|
||||
|
||||
should_include_internal_db_file() {
|
||||
local driver="$1"
|
||||
local identity="$2"
|
||||
|
||||
case "$identity" in
|
||||
internal/db/agent_process_stub.go|\
|
||||
internal/db/agent_process_windows.go|\
|
||||
internal/db/database.go|\
|
||||
internal/db/database_optional_factories_full.go|\
|
||||
internal/db/database_optional_factories_lite.go|\
|
||||
internal/db/driver_agent_binary_check.go|\
|
||||
internal/db/driver_support.go|\
|
||||
internal/db/json_decode.go|\
|
||||
internal/db/mysql_agent_path.go|\
|
||||
internal/db/optional_driver_agent_impl.go|\
|
||||
internal/db/optional_driver_build_full.go|\
|
||||
internal/db/optional_driver_build_lite.go|\
|
||||
internal/db/query_value.go|\
|
||||
internal/db/scan_rows.go|\
|
||||
internal/db/ssl_mode.go|\
|
||||
internal/db/timeout.go)
|
||||
return 0
|
||||
;;
|
||||
esac
|
||||
|
||||
case "$driver:$identity" in
|
||||
mariadb:internal/db/mariadb_impl.go|\
|
||||
diros:internal/db/diros_impl.go|\
|
||||
diros:internal/db/mysql_impl.go|\
|
||||
sphinx:internal/db/sphinx_impl.go|\
|
||||
sphinx:internal/db/mysql_impl.go|\
|
||||
sqlserver:internal/db/sqlserver_impl.go|\
|
||||
sqlite:internal/db/sqlite_impl.go|\
|
||||
duckdb:internal/db/duckdb_impl.go|\
|
||||
duckdb:internal/db/duckdb_driver_import.go|\
|
||||
duckdb:internal/db/duckdb_platform_supported.go|\
|
||||
duckdb:internal/db/duckdb_platform_unsupported.go|\
|
||||
dameng:internal/db/dameng_impl.go|\
|
||||
dameng:internal/db/dameng_metadata.go|\
|
||||
kingbase:internal/db/kingbase_impl.go|\
|
||||
kingbase:internal/db/kingbase_identifier_utils.go|\
|
||||
highgo:internal/db/highgo_impl.go|\
|
||||
vastbase:internal/db/vastbase_impl.go|\
|
||||
mongodb:internal/db/mongodb_impl.go|\
|
||||
mongodb:internal/db/mongodb_impl_v1.go|\
|
||||
tdengine:internal/db/tdengine_impl.go|\
|
||||
clickhouse:internal/db/clickhouse_impl.go)
|
||||
return 0
|
||||
;;
|
||||
esac
|
||||
|
||||
return 1
|
||||
}
|
||||
|
||||
should_include_source_file() {
|
||||
local driver="$1"
|
||||
local identity="$2"
|
||||
if [[ "$identity" == internal/db/* ]]; then
|
||||
should_include_internal_db_file "$driver" "$identity"
|
||||
return
|
||||
fi
|
||||
return 0
|
||||
}
|
||||
|
||||
target_platform=""
|
||||
driver_csv=""
|
||||
|
||||
while [[ $# -gt 0 ]]; do
|
||||
case "$1" in
|
||||
--platform)
|
||||
target_platform="${2:-}"
|
||||
shift 2
|
||||
;;
|
||||
--drivers)
|
||||
driver_csv="${2:-}"
|
||||
shift 2
|
||||
;;
|
||||
-h|--help)
|
||||
usage
|
||||
exit 0
|
||||
;;
|
||||
*)
|
||||
echo "未知参数:$1" >&2
|
||||
usage
|
||||
exit 1
|
||||
;;
|
||||
esac
|
||||
done
|
||||
|
||||
if ! command -v go >/dev/null 2>&1; then
|
||||
echo "未找到 Go,请先安装 Go 并确保 go 在 PATH 中。" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [[ -z "$target_platform" ]]; then
|
||||
target_platform="$(go env GOOS)/$(go env GOARCH)"
|
||||
fi
|
||||
if [[ "$target_platform" != */* ]]; then
|
||||
echo "--platform 参数格式错误,应为 GOOS/GOARCH,例如 darwin/arm64" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
goos="${target_platform%%/*}"
|
||||
goarch="${target_platform##*/}"
|
||||
gomodcache="$(go env GOMODCACHE)"
|
||||
|
||||
declare -a drivers=()
|
||||
if [[ -n "$driver_csv" ]]; then
|
||||
IFS=',' read -r -a raw_drivers <<<"$driver_csv"
|
||||
for item in "${raw_drivers[@]}"; do
|
||||
drivers+=("$(normalize_driver "$item")")
|
||||
done
|
||||
else
|
||||
drivers=("${DEFAULT_DRIVERS[@]}")
|
||||
fi
|
||||
|
||||
fingerprint_driver() {
|
||||
local driver="$1"
|
||||
local build_driver tag cgo_enabled tmp file identity file_hash revision
|
||||
build_driver="$(build_driver_name "$driver")"
|
||||
tag="gonavi_${build_driver}_driver"
|
||||
cgo_enabled=0
|
||||
if [[ "$driver" == "duckdb" ]]; then
|
||||
cgo_enabled=1
|
||||
fi
|
||||
|
||||
tmp="$(mktemp "${TMPDIR:-/tmp}/gonavi-agent-revision.XXXXXX")"
|
||||
{
|
||||
printf 'driver=%s\n' "$driver"
|
||||
printf 'build_tag=%s\n' "$tag"
|
||||
printf 'goos=%s\n' "$goos"
|
||||
printf 'goarch=%s\n' "$goarch"
|
||||
} >"$tmp"
|
||||
|
||||
while IFS= read -r file; do
|
||||
[[ -n "$file" && -f "$file" ]] || continue
|
||||
case "$file" in
|
||||
"$SCRIPT_DIR"/*)
|
||||
identity="${file#$SCRIPT_DIR/}"
|
||||
;;
|
||||
"$gomodcache"/*)
|
||||
identity="gomod/${file#$gomodcache/}"
|
||||
;;
|
||||
*)
|
||||
identity="$file"
|
||||
;;
|
||||
esac
|
||||
if [[ "$identity" == "$OUTPUT_FILE" ]]; then
|
||||
continue
|
||||
fi
|
||||
if ! should_include_source_file "$driver" "$identity"; then
|
||||
continue
|
||||
fi
|
||||
file_hash="$(hash_file "$file")"
|
||||
printf '%s %s\n' "$file_hash" "$identity"
|
||||
done < <(
|
||||
CGO_ENABLED="$cgo_enabled" GOOS="$goos" GOARCH="$goarch" GOTOOLCHAIN=auto \
|
||||
go list -deps \
|
||||
-tags "$tag" \
|
||||
-f '{{if not .Standard}}{{range .GoFiles}}{{$.Dir}}/{{.}}{{"\n"}}{{end}}{{range .CgoFiles}}{{$.Dir}}/{{.}}{{"\n"}}{{end}}{{end}}' \
|
||||
./cmd/optional-driver-agent | sort -u
|
||||
) >>"$tmp"
|
||||
|
||||
revision="$(hash_file "$tmp" | cut -c1-16)"
|
||||
rm -f "$tmp"
|
||||
printf 'src-%s' "$revision"
|
||||
}
|
||||
|
||||
tmp_output="$(mktemp "${TMPDIR:-/tmp}/gonavi-agent-revisions-go.XXXXXX")"
|
||||
{
|
||||
cat <<'EOF'
|
||||
// Code generated by tools/generate-driver-agent-revisions.sh; DO NOT EDIT.
|
||||
|
||||
package db
|
||||
|
||||
func init() {
|
||||
optionalDriverAgentRevisions = map[string]string{
|
||||
EOF
|
||||
for driver in "${drivers[@]}"; do
|
||||
revision="$(fingerprint_driver "$driver")"
|
||||
printf '\t\t"%s": "%s",\n' "$driver" "$revision"
|
||||
done
|
||||
cat <<'EOF'
|
||||
}
|
||||
}
|
||||
EOF
|
||||
} >"$tmp_output"
|
||||
|
||||
gofmt -w "$tmp_output"
|
||||
|
||||
if [[ -f "$OUTPUT_FILE" ]] && cmp -s "$tmp_output" "$OUTPUT_FILE"; then
|
||||
rm -f "$tmp_output"
|
||||
else
|
||||
mv "$tmp_output" "$OUTPUT_FILE"
|
||||
fi
|
||||
|
||||
echo "已生成 driver-agent revisions: $OUTPUT_FILE ($goos/$goarch)"
|
||||
Reference in New Issue
Block a user