From 7d5592d8d97040529219332e7e9baaa040c8026d Mon Sep 17 00:00:00 2001 From: Syngnat Date: Fri, 27 Feb 2026 09:31:24 +0800 Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20feat(db):=20=E6=95=B0=E6=8D=AE?= =?UTF-8?q?=E5=BA=93=E8=BF=9E=E6=8E=A5=E6=96=B0=E5=A2=9E=20SOCKS5/HTTP=20?= =?UTF-8?q?=E4=BB=A3=E7=90=86=E8=83=BD=E5=8A=9B=E5=B9=B6=E5=85=BC=E5=AE=B9?= =?UTF-8?q?=20SRV/SSH=20=E5=9C=BA=E6=99=AF?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 后端 ConnectionConfig 增加代理配置并完成规范化处理 - 普通 TCP 数据源通过本地转发接入代理 - MongoDB 使用 Dialer 支持代理连接(含 SRV) - 前端连接配置新增代理 UI、字段清洗与数据回填 - refs #122 --- frontend/src/components/ConnectionModal.tsx | 89 ++++- frontend/src/store.ts | 12 + frontend/src/types.ts | 10 + frontend/wailsjs/go/models.ts | 25 ++ go.mod | 2 +- internal/app/app.go | 20 +- internal/app/db_proxy.go | 202 ++++++++++++ internal/connection/types.go | 61 ++-- internal/db/mongodb_impl.go | 12 + internal/proxy/proxy.go | 344 ++++++++++++++++++++ internal/proxy/proxy_test.go | 44 +++ 11 files changed, 792 insertions(+), 29 deletions(-) create mode 100644 internal/app/db_proxy.go create mode 100644 internal/proxy/proxy.go create mode 100644 internal/proxy/proxy_test.go diff --git a/frontend/src/components/ConnectionModal.tsx b/frontend/src/components/ConnectionModal.tsx index 0224149..cfc3b71 100644 --- a/frontend/src/components/ConnectionModal.tsx +++ b/frontend/src/components/ConnectionModal.tsx @@ -3,7 +3,7 @@ import { Modal, Form, Input, InputNumber, Button, message, Checkbox, Divider, Se import { DatabaseOutlined, ConsoleSqlOutlined, FileTextOutlined, CloudServerOutlined, AppstoreAddOutlined, CloudOutlined, CheckCircleFilled, CloseCircleFilled } from '@ant-design/icons'; import { useStore } from '../store'; import { DBGetDatabases, GetDriverStatusList, MongoDiscoverMembers, TestConnection, RedisConnect, SelectSSHKeyFile } from '../../wailsjs/go/app/App'; -import { MongoMemberInfo, SavedConnection } from '../types'; +import { ConnectionConfig, MongoMemberInfo, SavedConnection } from '../types'; const { Meta } = Card; const { Text } = Typography; @@ -58,6 +58,7 @@ const ConnectionModal: React.FC<{ const [form] = Form.useForm(); const [loading, setLoading] = useState(false); const [useSSH, setUseSSH] = useState(false); + const [useProxy, setUseProxy] = useState(false); const [dbType, setDbType] = useState('mysql'); const [step, setStep] = useState(1); // 1: Select Type, 2: Configure const [activeGroup, setActiveGroup] = useState(0); // Active category index in step 1 @@ -655,6 +656,12 @@ const ConnectionModal: React.FC<{ sshUser: config.ssh?.user, sshPassword: config.ssh?.password, sshKeyPath: config.ssh?.keyPath, + useProxy: config.useProxy, + proxyType: config.proxy?.type || 'socks5', + proxyHost: config.proxy?.host, + proxyPort: config.proxy?.port, + proxyUser: config.proxy?.user, + proxyPassword: config.proxy?.password, driver: config.driver, dsn: config.dsn, timeout: config.timeout || 30, @@ -674,6 +681,7 @@ const ConnectionModal: React.FC<{ mongoReplicaPassword: config.mongoReplicaPassword || '' }); setUseSSH(config.useSSH || false); + setUseProxy(config.useProxy || false); setDbType(configType); // 如果是 Redis 编辑模式,设置已保存的 Redis 数据库列表 if (configType === 'redis') { @@ -684,6 +692,7 @@ const ConnectionModal: React.FC<{ setStep(1); form.resetFields(); setUseSSH(false); + setUseProxy(false); setDbType('mysql'); setActiveGroup(0); } @@ -733,6 +742,7 @@ const ConnectionModal: React.FC<{ setLoading(false); form.resetFields(); setUseSSH(false); + setUseProxy(false); setDbType('mysql'); setStep(1); onClose(); @@ -852,7 +862,7 @@ const ConnectionModal: React.FC<{ } }; - const buildConfig = async (values: any, forPersist: boolean) => { + const buildConfig = async (values: any, forPersist: boolean): Promise => { const mergedValues = { ...values }; const parsedUriValues = parseUriToValues(mergedValues.uri, mergedValues.type); const isEmptyField = (value: unknown) => ( @@ -951,6 +961,22 @@ const ConnectionModal: React.FC<{ password: mergedValues.sshPassword || "", keyPath: mergedValues.sshKeyPath || "" } : { host: "", port: 22, user: "", password: "", keyPath: "" }; + const effectiveUseProxy = !isFileDbType && !!mergedValues.useProxy; + const proxyTypeRaw = String(mergedValues.proxyType || 'socks5').toLowerCase(); + const proxyType: 'socks5' | 'http' = proxyTypeRaw === 'http' ? 'http' : 'socks5'; + const proxyConfig: NonNullable = effectiveUseProxy ? { + type: proxyType, + host: String(mergedValues.proxyHost || '').trim(), + port: Number(mergedValues.proxyPort || (proxyTypeRaw === 'http' ? 8080 : 1080)), + user: String(mergedValues.proxyUser || '').trim(), + password: mergedValues.proxyPassword || "", + } : { + type: 'socks5', + host: '', + port: 1080, + user: '', + password: '', + }; const keepPassword = !forPersist || savePassword; @@ -964,6 +990,8 @@ const ConnectionModal: React.FC<{ database: mergedValues.database || "", useSSH: !!mergedValues.useSSH, ssh: sshConfig, + useProxy: effectiveUseProxy, + proxy: proxyConfig, driver: mergedValues.driver, dsn: mergedValues.dsn, timeout: Number(mergedValues.timeout || 30), @@ -997,6 +1025,7 @@ const ConnectionModal: React.FC<{ const defaultPort = getDefaultPortByType(type); if (isFileDatabaseType(type)) { setUseSSH(false); + setUseProxy(false); form.setFieldsValue({ host: '', port: 0, @@ -1009,6 +1038,12 @@ const ConnectionModal: React.FC<{ sshUser: '', sshPassword: '', sshKeyPath: '', + useProxy: false, + proxyType: 'socks5', + proxyHost: '', + proxyPort: 1080, + proxyUser: '', + proxyPassword: '', mysqlTopology: 'single', mongoTopology: 'single', mongoSrv: false, @@ -1167,6 +1202,9 @@ const ConnectionModal: React.FC<{ user: 'root', useSSH: false, sshPort: 22, + useProxy: false, + proxyType: 'socks5', + proxyPort: 1080, timeout: 30, uri: '', mysqlTopology: 'single', @@ -1191,6 +1229,21 @@ const ConnectionModal: React.FC<{ setUriFeedback(null); } if (changed.useSSH !== undefined) setUseSSH(changed.useSSH); + if (changed.useProxy !== undefined) setUseProxy(changed.useProxy); + if (changed.proxyType !== undefined) { + const nextType = String(changed.proxyType || 'socks5').toLowerCase(); + if (nextType === 'http') { + const currentPort = Number(form.getFieldValue('proxyPort') || 0); + if (!currentPort || currentPort === 1080) { + form.setFieldValue('proxyPort', 8080); + } + } else { + const currentPort = Number(form.getFieldValue('proxyPort') || 0); + if (!currentPort || currentPort === 8080) { + form.setFieldValue('proxyPort', 1080); + } + } + } // Type change handled by step 1, but keep sync if select changes (hidden now) if (changed.type !== undefined) setDbType(changed.type); if ( @@ -1531,6 +1584,38 @@ const ConnectionModal: React.FC<{ )} + + + 使用代理 (SOCKS5 / HTTP CONNECT) + + + {useProxy && ( +
+
+ + + + + + +
+
+ + + + + + +
+
+ )} + { password: toTrimmedString(sshRaw.password), keyPath: toTrimmedString(sshRaw.keyPath), }; + const proxyRaw = (raw.proxy && typeof raw.proxy === 'object') ? raw.proxy as Record : {}; + const proxyTypeRaw = toTrimmedString(proxyRaw.type, 'socks5').toLowerCase(); + const proxyType: 'socks5' | 'http' = proxyTypeRaw === 'http' ? 'http' : 'socks5'; + const proxy = { + type: proxyType, + host: toTrimmedString(proxyRaw.host), + port: normalizePort(proxyRaw.port, proxyTypeRaw === 'http' ? 8080 : 1080), + user: toTrimmedString(proxyRaw.user), + password: toTrimmedString(proxyRaw.password), + }; const safeConfig: ConnectionConfig & Record = { ...raw, @@ -167,6 +177,8 @@ const sanitizeConnectionConfig = (value: unknown): ConnectionConfig => { database: toTrimmedString(raw.database), useSSH: !!raw.useSSH, ssh, + useProxy: !!raw.useProxy, + proxy, uri: toTrimmedString(raw.uri).slice(0, MAX_URI_LENGTH), hosts: sanitizeAddressList(raw.hosts), topology: raw.topology === 'replica' ? 'replica' : 'single', diff --git a/frontend/src/types.ts b/frontend/src/types.ts index 7fa3b54..f700677 100644 --- a/frontend/src/types.ts +++ b/frontend/src/types.ts @@ -6,6 +6,14 @@ export interface SSHConfig { keyPath?: string; } +export interface ProxyConfig { + type: 'socks5' | 'http'; + host: string; + port: number; + user?: string; + password?: string; +} + export interface ConnectionConfig { type: string; host: string; @@ -16,6 +24,8 @@ export interface ConnectionConfig { database?: string; useSSH?: boolean; ssh?: SSHConfig; + useProxy?: boolean; + proxy?: ProxyConfig; driver?: string; dsn?: string; timeout?: number; diff --git a/frontend/wailsjs/go/models.ts b/frontend/wailsjs/go/models.ts index eaf39e7..d9de709 100755 --- a/frontend/wailsjs/go/models.ts +++ b/frontend/wailsjs/go/models.ts @@ -48,6 +48,26 @@ export namespace connection { return a; } } + export class ProxyConfig { + type: string; + host: string; + port: number; + user?: string; + password?: string; + + static createFrom(source: any = {}) { + return new ProxyConfig(source); + } + + constructor(source: any = {}) { + if ('string' === typeof source) source = JSON.parse(source); + this.type = source["type"]; + this.host = source["host"]; + this.port = source["port"]; + this.user = source["user"]; + this.password = source["password"]; + } + } export class SSHConfig { host: string; port: number; @@ -78,6 +98,8 @@ export namespace connection { database: string; useSSH: boolean; ssh: SSHConfig; + useProxy?: boolean; + proxy?: ProxyConfig; driver?: string; dsn?: string; timeout?: number; @@ -110,6 +132,8 @@ export namespace connection { this.database = source["database"]; this.useSSH = source["useSSH"]; this.ssh = this.convertValues(source["ssh"], SSHConfig); + this.useProxy = source["useProxy"]; + this.proxy = this.convertValues(source["proxy"], ProxyConfig); this.driver = source["driver"]; this.dsn = source["dsn"]; this.timeout = source["timeout"]; @@ -146,6 +170,7 @@ export namespace connection { return a; } } + export class QueryResult { success: boolean; message: string; diff --git a/go.mod b/go.mod index ea50c60..9203b37 100644 --- a/go.mod +++ b/go.mod @@ -18,6 +18,7 @@ require ( go.mongodb.org/mongo-driver/v2 v2.5.0 golang.org/x/crypto v0.47.0 golang.org/x/mod v0.32.0 + golang.org/x/net v0.49.0 golang.org/x/text v0.33.0 modernc.org/sqlite v1.44.3 ) @@ -84,7 +85,6 @@ require ( github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78 // indirect github.com/zeebo/xxh3 v1.1.0 // indirect golang.org/x/exp v0.0.0-20260112195511-716be5621a96 // indirect - golang.org/x/net v0.49.0 // indirect golang.org/x/sync v0.19.0 // indirect golang.org/x/sys v0.40.0 // indirect golang.org/x/telemetry v0.0.0-20260116145544-c6413dc483f5 // indirect diff --git a/internal/app/app.go b/internal/app/app.go index 248093c..124b9d2 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -15,6 +15,7 @@ import ( "GoNavi-Wails/internal/connection" "GoNavi-Wails/internal/db" "GoNavi-Wails/internal/logger" + proxytunnel "GoNavi-Wails/internal/proxy" ) const dbCachePingInterval = 30 * time.Second @@ -66,6 +67,7 @@ func (a *App) Shutdown(ctx context.Context) { logger.Error(err, "关闭数据库连接失败") } } + proxytunnel.CloseAllForwarders() // Close all Redis connections CloseAllRedisClients() logger.Infof("资源释放完成,应用已关闭") @@ -77,6 +79,9 @@ func getCacheKey(config connection.ConnectionConfig) string { if !config.UseSSH { config.SSH = connection.SSHConfig{} } + if !config.UseProxy { + config.Proxy = connection.ProxyConfig{} + } // 保持与驱动默认一致,避免同一连接被重复缓存 if config.Type == "postgres" && config.Database == "" { config.Database = "postgres" @@ -175,6 +180,12 @@ func formatConnSummary(config connection.ConnectionConfig) string { if config.UseSSH { b.WriteString(fmt.Sprintf(" SSH=%s:%d 用户=%s", config.SSH.Host, config.SSH.Port, config.SSH.User)) } + if config.UseProxy { + b.WriteString(fmt.Sprintf(" 代理=%s://%s:%d", strings.ToLower(strings.TrimSpace(config.Proxy.Type)), config.Proxy.Host, config.Proxy.Port)) + if strings.TrimSpace(config.Proxy.User) != "" { + b.WriteString(" 代理认证=已配置") + } + } if config.Type == "custom" { driver := strings.TrimSpace(config.Driver) @@ -269,7 +280,14 @@ func (a *App) getDatabaseWithPing(config connection.ConnectionConfig, forcePing return nil, err } - if err := dbInst.Connect(config); err != nil { + connectConfig, proxyErr := resolveDialConfigWithProxy(config) + if proxyErr != nil { + wrapped := wrapConnectError(config, proxyErr) + logger.Error(wrapped, "连接代理准备失败:%s 缓存Key=%s", formatConnSummary(config), shortKey) + return nil, wrapped + } + + if err := dbInst.Connect(connectConfig); err != nil { wrapped := wrapConnectError(config, err) logger.Error(wrapped, "建立数据库连接失败:%s 缓存Key=%s", formatConnSummary(config), shortKey) return nil, wrapped diff --git a/internal/app/db_proxy.go b/internal/app/db_proxy.go new file mode 100644 index 0000000..6adf0c2 --- /dev/null +++ b/internal/app/db_proxy.go @@ -0,0 +1,202 @@ +package app + +import ( + "fmt" + "net" + "strconv" + "strings" + + "GoNavi-Wails/internal/connection" + proxytunnel "GoNavi-Wails/internal/proxy" +) + +func resolveDialConfigWithProxy(raw connection.ConnectionConfig) (connection.ConnectionConfig, error) { + config := raw + if !config.UseProxy { + config.Proxy = connection.ProxyConfig{} + return config, nil + } + + normalizedProxy, err := proxytunnel.NormalizeConfig(config.Proxy) + if err != nil { + return connection.ConnectionConfig{}, err + } + config.Proxy = normalizedProxy + + if config.UseSSH { + sshPort := config.SSH.Port + if sshPort <= 0 { + sshPort = 22 + } + forwardedSSH, err := buildProxyForwardAddress(normalizedProxy, strings.TrimSpace(config.SSH.Host), sshPort) + if err != nil { + return connection.ConnectionConfig{}, fmt.Errorf("代理连接 SSH 网关失败:%w", err) + } + config.SSH.Host = forwardedSSH.host + config.SSH.Port = forwardedSSH.port + config.UseProxy = false + config.Proxy = connection.ProxyConfig{} + return config, nil + } + + normalizedType := strings.ToLower(strings.TrimSpace(config.Type)) + if normalizedType == "sqlite" || normalizedType == "duckdb" || normalizedType == "custom" { + // 文件型/自定义 DSN 类型不走标准 host:port,不在此层改写。 + return config, nil + } + if normalizedType == "mongodb" && config.MongoSRV { + // Mongo SRV 由驱动侧 Dialer 处理代理,避免破坏 DNS SRV 拓扑发现。 + return config, nil + } + + targetPort := config.Port + if targetPort <= 0 { + targetPort = defaultPortByType(normalizedType) + } + forwardedPrimary, err := buildProxyForwardAddress(normalizedProxy, strings.TrimSpace(config.Host), targetPort) + if err != nil { + return connection.ConnectionConfig{}, err + } + config.Host = forwardedPrimary.host + config.Port = forwardedPrimary.port + + if len(config.Hosts) > 0 { + rewritten := make([]string, 0, len(config.Hosts)) + seen := make(map[string]struct{}, len(config.Hosts)) + for _, rawEntry := range config.Hosts { + targetHost, targetPort, ok := parseAddressWithDefaultPort(rawEntry, defaultPortByType(normalizedType)) + if !ok { + continue + } + forwarded, forwardErr := buildProxyForwardAddress(normalizedProxy, targetHost, targetPort) + if forwardErr != nil { + return connection.ConnectionConfig{}, forwardErr + } + rewrittenAddress := formatHostPort(forwarded.host, forwarded.port) + if _, exists := seen[rewrittenAddress]; exists { + continue + } + seen[rewrittenAddress] = struct{}{} + rewritten = append(rewritten, rewrittenAddress) + } + config.Hosts = rewritten + } + + config.UseProxy = false + config.Proxy = connection.ProxyConfig{} + return config, nil +} + +type hostPort struct { + host string + port int +} + +func buildProxyForwardAddress(proxyConfig connection.ProxyConfig, targetHost string, targetPort int) (hostPort, error) { + host := strings.TrimSpace(targetHost) + if host == "" { + host = "localhost" + } + port := targetPort + if port <= 0 { + return hostPort{}, fmt.Errorf("目标端口无效:%d", targetPort) + } + + forwarder, err := proxytunnel.GetOrCreateLocalForwarder(proxyConfig, host, port) + if err != nil { + return hostPort{}, err + } + localHost, localPort, splitOK := parseAddressWithDefaultPort(forwarder.LocalAddr, 0) + if !splitOK || localPort <= 0 { + return hostPort{}, fmt.Errorf("解析代理本地转发地址失败:%s", forwarder.LocalAddr) + } + return hostPort{host: localHost, port: localPort}, nil +} + +func parseAddressWithDefaultPort(raw string, defaultPort int) (string, int, bool) { + text := strings.TrimSpace(raw) + if text == "" { + return "", 0, false + } + + if strings.HasPrefix(text, "[") { + if host, portText, err := net.SplitHostPort(text); err == nil { + if port, convErr := strconv.Atoi(portText); convErr == nil && port > 0 && port <= 65535 { + return strings.TrimSpace(host), port, true + } + return "", 0, false + } + trimmed := strings.Trim(strings.TrimPrefix(text, "["), "]") + if trimmed != "" && defaultPort > 0 { + return trimmed, defaultPort, true + } + return "", 0, false + } + + if strings.Count(text, ":") == 0 { + if defaultPort <= 0 { + return "", 0, false + } + return text, defaultPort, true + } + + if strings.Count(text, ":") == 1 { + host, portText, err := net.SplitHostPort(text) + if err == nil { + port, convErr := strconv.Atoi(portText) + if convErr == nil && port > 0 && port <= 65535 { + return strings.TrimSpace(host), port, true + } + return "", 0, false + } + if defaultPort > 0 { + return strings.TrimSpace(text), defaultPort, true + } + return "", 0, false + } + + // IPv6 地址未带端口,使用默认端口。 + if defaultPort > 0 { + return text, defaultPort, true + } + return "", 0, false +} + +func formatHostPort(host string, port int) string { + h := strings.TrimSpace(host) + if strings.Contains(h, ":") && !strings.HasPrefix(h, "[") { + return fmt.Sprintf("[%s]:%d", h, port) + } + return fmt.Sprintf("%s:%d", h, port) +} + +func defaultPortByType(driverType string) int { + switch strings.ToLower(strings.TrimSpace(driverType)) { + case "mysql", "mariadb": + return 3306 + case "diros": + return 9030 + case "sphinx": + return 9306 + case "postgres", "vastbase": + return 5432 + case "redis": + return 6379 + case "tdengine": + return 6041 + case "oracle": + return 1521 + case "dameng": + return 5236 + case "kingbase": + return 54321 + case "sqlserver": + return 1433 + case "mongodb": + return 27017 + case "highgo": + return 5866 + default: + return 0 + } +} diff --git a/internal/connection/types.go b/internal/connection/types.go index cbc479d..cfc0253 100644 --- a/internal/connection/types.go +++ b/internal/connection/types.go @@ -9,33 +9,44 @@ type SSHConfig struct { KeyPath string `json:"keyPath"` } +// ProxyConfig holds proxy connection details +type ProxyConfig struct { + Type string `json:"type"` // socks5 | http + Host string `json:"host"` + Port int `json:"port"` + User string `json:"user,omitempty"` + Password string `json:"password,omitempty"` +} + // ConnectionConfig holds database connection details including SSH type ConnectionConfig struct { - Type string `json:"type"` - Host string `json:"host"` - Port int `json:"port"` - User string `json:"user"` - Password string `json:"password"` - SavePassword bool `json:"savePassword,omitempty"` // Persist password in saved connection - Database string `json:"database"` - UseSSH bool `json:"useSSH"` - SSH SSHConfig `json:"ssh"` - Driver string `json:"driver,omitempty"` // For custom connection - DSN string `json:"dsn,omitempty"` // For custom connection - Timeout int `json:"timeout,omitempty"` // Connection timeout in seconds (default: 30) - RedisDB int `json:"redisDB,omitempty"` // Redis database index (0-15) - URI string `json:"uri,omitempty"` // Connection URI for copy/paste - Hosts []string `json:"hosts,omitempty"` // Multi-host addresses: host:port - Topology string `json:"topology,omitempty"` // single | replica - MySQLReplicaUser string `json:"mysqlReplicaUser,omitempty"` // MySQL replica auth user - MySQLReplicaPassword string `json:"mysqlReplicaPassword,omitempty"` // MySQL replica auth password - ReplicaSet string `json:"replicaSet,omitempty"` // MongoDB replica set name - AuthSource string `json:"authSource,omitempty"` // MongoDB authSource - ReadPreference string `json:"readPreference,omitempty"` // MongoDB readPreference - MongoSRV bool `json:"mongoSrv,omitempty"` // MongoDB use mongodb+srv URI scheme - MongoAuthMechanism string `json:"mongoAuthMechanism,omitempty"` // MongoDB authMechanism - MongoReplicaUser string `json:"mongoReplicaUser,omitempty"` // MongoDB replica auth user - MongoReplicaPassword string `json:"mongoReplicaPassword,omitempty"` // MongoDB replica auth password + Type string `json:"type"` + Host string `json:"host"` + Port int `json:"port"` + User string `json:"user"` + Password string `json:"password"` + SavePassword bool `json:"savePassword,omitempty"` // Persist password in saved connection + Database string `json:"database"` + UseSSH bool `json:"useSSH"` + SSH SSHConfig `json:"ssh"` + UseProxy bool `json:"useProxy,omitempty"` + Proxy ProxyConfig `json:"proxy,omitempty"` + Driver string `json:"driver,omitempty"` // For custom connection + DSN string `json:"dsn,omitempty"` // For custom connection + Timeout int `json:"timeout,omitempty"` // Connection timeout in seconds (default: 30) + RedisDB int `json:"redisDB,omitempty"` // Redis database index (0-15) + URI string `json:"uri,omitempty"` // Connection URI for copy/paste + Hosts []string `json:"hosts,omitempty"` // Multi-host addresses: host:port + Topology string `json:"topology,omitempty"` // single | replica + MySQLReplicaUser string `json:"mysqlReplicaUser,omitempty"` // MySQL replica auth user + MySQLReplicaPassword string `json:"mysqlReplicaPassword,omitempty"` // MySQL replica auth password + ReplicaSet string `json:"replicaSet,omitempty"` // MongoDB replica set name + AuthSource string `json:"authSource,omitempty"` // MongoDB authSource + ReadPreference string `json:"readPreference,omitempty"` // MongoDB readPreference + MongoSRV bool `json:"mongoSrv,omitempty"` // MongoDB use mongodb+srv URI scheme + MongoAuthMechanism string `json:"mongoAuthMechanism,omitempty"` // MongoDB authMechanism + MongoReplicaUser string `json:"mongoReplicaUser,omitempty"` // MongoDB replica auth user + MongoReplicaPassword string `json:"mongoReplicaPassword,omitempty"` // MongoDB replica auth password } // QueryResult is the standard response format for Wails methods diff --git a/internal/db/mongodb_impl.go b/internal/db/mongodb_impl.go index 35fc9f3..4b2f630 100644 --- a/internal/db/mongodb_impl.go +++ b/internal/db/mongodb_impl.go @@ -14,6 +14,7 @@ import ( "GoNavi-Wails/internal/connection" "GoNavi-Wails/internal/logger" + proxytunnel "GoNavi-Wails/internal/proxy" "GoNavi-Wails/internal/ssh" "go.mongodb.org/mongo-driver/v2/bson" @@ -29,6 +30,14 @@ type MongoDB struct { forwarder *ssh.LocalForwarder } +type mongoProxyDialer struct { + proxyConfig connection.ProxyConfig +} + +func (d *mongoProxyDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { + return proxytunnel.DialContext(ctx, d.proxyConfig, network, address) +} + const defaultMongoPort = 27017 func normalizeMongoAddress(host string, port int) string { @@ -328,6 +337,9 @@ func (m *MongoDB) Connect(config connection.ConnectionConfig) error { uri := m.getURI(attemptConfig) clientOpts := options.Client().ApplyURI(uri) + if attemptConfig.UseProxy { + clientOpts.SetDialer(&mongoProxyDialer{proxyConfig: attemptConfig.Proxy}) + } client, err := mongo.Connect(clientOpts) if err != nil { errorDetails = append(errorDetails, fmt.Sprintf("%s连接失败: %v", authLabel, err)) diff --git a/internal/proxy/proxy.go b/internal/proxy/proxy.go new file mode 100644 index 0000000..f378858 --- /dev/null +++ b/internal/proxy/proxy.go @@ -0,0 +1,344 @@ +package proxy + +import ( + "bufio" + "context" + "crypto/sha256" + "encoding/base64" + "encoding/hex" + "fmt" + "io" + "net" + "net/http" + "net/url" + "strings" + "sync" + "time" + + "GoNavi-Wails/internal/connection" + "GoNavi-Wails/internal/logger" + + xproxy "golang.org/x/net/proxy" +) + +const ( + defaultDialTimeout = 8 * time.Second +) + +type LocalForwarder struct { + LocalAddr string + RemoteAddr string + ProxyAddr string + ProxyType string + + cfg connection.ProxyConfig + listener net.Listener + closeChan chan struct{} + closeOnce sync.Once + + closed bool + closedMu sync.RWMutex +} + +var ( + forwarderMu sync.RWMutex + localForwarders = make(map[string]*LocalForwarder) +) + +func NormalizeConfig(config connection.ProxyConfig) (connection.ProxyConfig, error) { + result := connection.ProxyConfig{ + Type: strings.ToLower(strings.TrimSpace(config.Type)), + Host: strings.TrimSpace(config.Host), + Port: config.Port, + User: strings.TrimSpace(config.User), + Password: config.Password, + } + + switch result.Type { + case "socks5", "socks5h", "http": + default: + return result, fmt.Errorf("不支持的代理类型:%s", config.Type) + } + if result.Type == "socks5h" { + result.Type = "socks5" + } + if result.Host == "" { + return result, fmt.Errorf("代理主机为空") + } + if result.Port <= 0 || result.Port > 65535 { + return result, fmt.Errorf("代理端口无效:%d", result.Port) + } + return result, nil +} + +func GetOrCreateLocalForwarder(proxyConfig connection.ProxyConfig, remoteHost string, remotePort int) (*LocalForwarder, error) { + cfg, err := NormalizeConfig(proxyConfig) + if err != nil { + return nil, err + } + if strings.TrimSpace(remoteHost) == "" || remotePort <= 0 { + return nil, fmt.Errorf("无效的远端地址:%s:%d", remoteHost, remotePort) + } + + key := forwarderCacheKey(cfg, remoteHost, remotePort) + forwarderMu.RLock() + forwarder, exists := localForwarders[key] + forwarderMu.RUnlock() + if exists && forwarder != nil && !forwarder.IsClosed() { + return forwarder, nil + } + + if exists { + forwarderMu.Lock() + delete(localForwarders, key) + forwarderMu.Unlock() + } + + next, err := NewLocalForwarder(cfg, remoteHost, remotePort) + if err != nil { + return nil, err + } + + forwarderMu.Lock() + localForwarders[key] = next + forwarderMu.Unlock() + return next, nil +} + +func forwarderCacheKey(cfg connection.ProxyConfig, remoteHost string, remotePort int) string { + trimmedHost := strings.TrimSpace(remoteHost) + credential := cfg.User + "\x00" + cfg.Password + credentialHash := sha256.Sum256([]byte(credential)) + // 仅保留短指纹用于区分不同认证信息,避免在 key 日志中泄露明文口令。 + fingerprint := hex.EncodeToString(credentialHash[:8]) + return fmt.Sprintf("%s://%s:%d@%s:%d#%s", cfg.Type, cfg.Host, cfg.Port, trimmedHost, remotePort, fingerprint) +} + +func NewLocalForwarder(proxyConfig connection.ProxyConfig, remoteHost string, remotePort int) (*LocalForwarder, error) { + cfg, err := NormalizeConfig(proxyConfig) + if err != nil { + return nil, err + } + + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + return nil, fmt.Errorf("创建本地代理监听失败:%w", err) + } + + localAddr := listener.Addr().String() + remoteAddr := net.JoinHostPort(strings.TrimSpace(remoteHost), fmt.Sprintf("%d", remotePort)) + proxyAddr := net.JoinHostPort(cfg.Host, fmt.Sprintf("%d", cfg.Port)) + forwarder := &LocalForwarder{ + LocalAddr: localAddr, + RemoteAddr: remoteAddr, + ProxyAddr: proxyAddr, + ProxyType: cfg.Type, + cfg: cfg, + listener: listener, + closeChan: make(chan struct{}), + } + + go forwarder.forward() + logger.Infof("已创建代理端口转发:本地 %s -> 远端 %s(代理 %s://%s)", localAddr, remoteAddr, cfg.Type, proxyAddr) + return forwarder, nil +} + +func (f *LocalForwarder) forward() { + for { + localConn, err := f.listener.Accept() + if err != nil { + select { + case <-f.closeChan: + return + default: + logger.Warnf("接受本地代理连接失败:%v", err) + return + } + } + go f.handleConnection(localConn) + } +} + +func (f *LocalForwarder) handleConnection(localConn net.Conn) { + defer localConn.Close() + + ctx, cancel := context.WithTimeout(context.Background(), defaultDialTimeout) + remoteConn, err := dialThroughProxy(ctx, f.cfg, "tcp", f.RemoteAddr) + cancel() + if err != nil { + logger.Warnf("通过代理连接远端失败:远端=%s 代理=%s://%s 错误=%v", f.RemoteAddr, f.ProxyType, f.ProxyAddr, err) + return + } + defer remoteConn.Close() + + errc := make(chan error, 2) + var closeOnce sync.Once + closeBoth := func() { + _ = localConn.Close() + _ = remoteConn.Close() + } + go func() { + _, copyErr := io.Copy(remoteConn, localConn) + closeOnce.Do(closeBoth) + errc <- copyErr + }() + go func() { + _, copyErr := io.Copy(localConn, remoteConn) + closeOnce.Do(closeBoth) + errc <- copyErr + }() + <-errc + <-errc +} + +func (f *LocalForwarder) Close() error { + var err error + f.closeOnce.Do(func() { + f.closedMu.Lock() + f.closed = true + f.closedMu.Unlock() + close(f.closeChan) + err = f.listener.Close() + if err != nil { + logger.Warnf("关闭代理端口转发失败:%v", err) + } + }) + return err +} + +func (f *LocalForwarder) IsClosed() bool { + f.closedMu.RLock() + defer f.closedMu.RUnlock() + return f.closed +} + +func CloseAllForwarders() { + forwarderMu.Lock() + defer forwarderMu.Unlock() + + for key, forwarder := range localForwarders { + if forwarder == nil { + continue + } + _ = forwarder.Close() + logger.Infof("已关闭代理端口转发:%s", key) + } + localForwarders = make(map[string]*LocalForwarder) +} + +func DialContext(ctx context.Context, proxyConfig connection.ProxyConfig, network, address string) (net.Conn, error) { + cfg, err := NormalizeConfig(proxyConfig) + if err != nil { + return nil, err + } + return dialThroughProxy(ctx, cfg, network, address) +} + +func dialThroughProxy(ctx context.Context, cfg connection.ProxyConfig, network, address string) (net.Conn, error) { + switch cfg.Type { + case "socks5": + return dialSOCKS5(ctx, cfg, network, address) + case "http": + return dialHTTPConnect(ctx, cfg, address) + default: + return nil, fmt.Errorf("不支持的代理类型:%s", cfg.Type) + } +} + +func dialSOCKS5(ctx context.Context, cfg connection.ProxyConfig, network, address string) (net.Conn, error) { + proxyAddr := net.JoinHostPort(cfg.Host, fmt.Sprintf("%d", cfg.Port)) + var auth *xproxy.Auth + if cfg.User != "" || cfg.Password != "" { + auth = &xproxy.Auth{ + User: cfg.User, + Password: cfg.Password, + } + } + dialer, err := xproxy.SOCKS5("tcp", proxyAddr, auth, &net.Dialer{Timeout: defaultDialTimeout}) + if err != nil { + return nil, fmt.Errorf("创建 SOCKS5 代理拨号器失败:%w", err) + } + + type result struct { + conn net.Conn + err error + } + ch := make(chan result, 1) + go func() { + conn, dialErr := dialer.Dial(network, address) + ch <- result{conn: conn, err: dialErr} + }() + + select { + case <-ctx.Done(): + go func() { + r := <-ch + if r.conn != nil { + _ = r.conn.Close() + } + }() + return nil, ctx.Err() + case r := <-ch: + if r.err != nil { + return nil, fmt.Errorf("SOCKS5 代理连接失败:%w", r.err) + } + return r.conn, nil + } +} + +func dialHTTPConnect(ctx context.Context, cfg connection.ProxyConfig, address string) (net.Conn, error) { + proxyAddr := net.JoinHostPort(cfg.Host, fmt.Sprintf("%d", cfg.Port)) + dialer := &net.Dialer{Timeout: defaultDialTimeout} + conn, err := dialer.DialContext(ctx, "tcp", proxyAddr) + if err != nil { + return nil, fmt.Errorf("连接 HTTP 代理失败:%w", err) + } + + connectReq := &http.Request{ + Method: http.MethodConnect, + URL: &url.URL{Opaque: address}, + Host: address, + Header: make(http.Header), + } + if cfg.User != "" || cfg.Password != "" { + raw := cfg.User + ":" + cfg.Password + connectReq.Header.Set("Proxy-Authorization", "Basic "+base64.StdEncoding.EncodeToString([]byte(raw))) + } + if err := connectReq.Write(conn); err != nil { + _ = conn.Close() + return nil, fmt.Errorf("发送 HTTP CONNECT 请求失败:%w", err) + } + + reader := bufio.NewReader(conn) + resp, err := http.ReadResponse(reader, connectReq) + if err != nil { + _ = conn.Close() + return nil, fmt.Errorf("读取 HTTP CONNECT 响应失败:%w", err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + _ = conn.Close() + return nil, fmt.Errorf("HTTP 代理 CONNECT 失败:%s", strings.TrimSpace(resp.Status)) + } + + if reader.Buffered() == 0 { + return conn, nil + } + return &bufferedConn{Conn: conn, reader: reader}, nil +} + +type bufferedConn struct { + net.Conn + reader *bufio.Reader +} + +func (c *bufferedConn) Read(p []byte) (int, error) { + if c.reader == nil { + return c.Conn.Read(p) + } + if c.reader.Buffered() == 0 { + c.reader = nil + return c.Conn.Read(p) + } + return c.reader.Read(p) +} diff --git a/internal/proxy/proxy_test.go b/internal/proxy/proxy_test.go new file mode 100644 index 0000000..4395bc1 --- /dev/null +++ b/internal/proxy/proxy_test.go @@ -0,0 +1,44 @@ +package proxy + +import ( + "strings" + "testing" + + "GoNavi-Wails/internal/connection" +) + +func TestNormalizeConfigSupportsSocks5hAlias(t *testing.T) { + cfg, err := NormalizeConfig(connection.ProxyConfig{ + Type: "SOCKS5H", + Host: "127.0.0.1", + Port: 1080, + }) + if err != nil { + t.Fatalf("NormalizeConfig returned error: %v", err) + } + if cfg.Type != "socks5" { + t.Fatalf("expected normalized proxy type socks5, got %s", cfg.Type) + } +} + +func TestForwarderCacheKeyIncludesCredentialFingerprint(t *testing.T) { + base := connection.ProxyConfig{ + Type: "socks5", + Host: "127.0.0.1", + Port: 1080, + User: "tester", + Password: "first-password", + } + other := base + other.Password = "second-password" + + keyA := forwarderCacheKey(base, "db.internal", 3306) + keyB := forwarderCacheKey(other, "db.internal", 3306) + + if keyA == keyB { + t.Fatalf("expected different cache key for different credentials") + } + if strings.Contains(keyA, base.Password) || strings.Contains(keyB, other.Password) { + t.Fatalf("cache key should not contain raw password") + } +}