Merge pull request #93 from Syngnat/release/0.3.8

Release/0.3.8
This commit is contained in:
Syngnat
2026-02-09 21:56:38 +08:00
committed by GitHub
20 changed files with 2346 additions and 562 deletions

View File

@@ -67,6 +67,11 @@ body[data-theme='dark'] {
text-shadow: 0 1px 2px rgba(0, 0, 0, 0.8);
}
/* 连接配置弹窗:滚动仅在弹窗 body 内部,不使用外层 wrap 滚动条 */
.connection-modal-wrap {
overflow: hidden !important;
}
/* Custom Title Bar Close Button Hover */
.titlebar-close-btn:hover {
background-color: #ff4d4f !important;

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -4,7 +4,7 @@ import { TabData, ColumnDefinition } from '../types';
import { useStore } from '../store';
import { DBQuery, DBGetColumns } from '../../wailsjs/go/app/App';
import DataGrid, { GONAVI_ROW_KEY } from './DataGrid';
import { buildWhereSQL, quoteIdentPart, quoteQualifiedIdent } from '../utils/sql';
import { buildWhereSQL, quoteIdentPart, quoteQualifiedIdent, type FilterCondition } from '../utils/sql';
const DataViewer: React.FC<{ tab: TabData }> = ({ tab }) => {
const [data, setData] = useState<any[]>([]);
@@ -29,7 +29,7 @@ const DataViewer: React.FC<{ tab: TabData }> = ({ tab }) => {
const [sortInfo, setSortInfo] = useState<{ columnKey: string, order: string } | null>(null);
const [showFilter, setShowFilter] = useState(false);
const [filterConditions, setFilterConditions] = useState<any[]>([]);
const [filterConditions, setFilterConditions] = useState<FilterCondition[]>([]);
const currentConnType = (connections.find(c => c.id === tab.connectionId)?.config?.type || '').toLowerCase();
const forceReadOnly = currentConnType === 'tdengine';
@@ -220,7 +220,7 @@ const DataViewer: React.FC<{ tab: TabData }> = ({ tab }) => {
const handleSort = useCallback((field: string, order: string) => setSortInfo({ columnKey: field, order }), []);
const handlePageChange = useCallback((page: number, size: number) => fetchData(page, size), [fetchData]);
const handleToggleFilter = useCallback(() => setShowFilter(prev => !prev), []);
const handleApplyFilter = useCallback((conditions: any[]) => setFilterConditions(conditions), []);
const handleApplyFilter = useCallback((conditions: FilterCondition[]) => setFilterConditions(conditions), []);
useEffect(() => {
fetchData(1, pagination.pageSize);

View File

@@ -148,14 +148,25 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
}, [savedQueries]);
useEffect(() => {
setTreeData(connections.map(conn => ({
title: conn.name,
key: conn.id,
icon: conn.config.type === 'redis' ? <CloudOutlined style={{ color: '#DC382D' }} /> : <HddOutlined />,
type: 'connection',
dataRef: conn,
isLeaf: false,
})));
setTreeData((prev) => {
const prevMap = new Map<string, TreeNode>();
prev.forEach((node) => {
prevMap.set(String(node.key), node);
});
return connections.map((conn) => {
const existing = prevMap.get(conn.id);
return {
title: conn.name,
key: conn.id,
icon: conn.config.type === 'redis' ? <CloudOutlined style={{ color: '#DC382D' }} /> : <HddOutlined />,
type: 'connection',
dataRef: conn,
isLeaf: false,
children: existing?.children,
} as TreeNode;
});
});
}, [connections]);
const updateTreeData = (list: TreeNode[], key: React.Key, children: TreeNode[] | undefined): TreeNode[] => {

View File

@@ -8,9 +8,29 @@ import TableDesigner from './TableDesigner';
import RedisViewer from './RedisViewer';
import RedisCommandEditor from './RedisCommandEditor';
import TriggerViewer from './TriggerViewer';
import type { TabData } from '../types';
const detectConnectionEnvLabel = (connectionName: string): string | null => {
const tokens = connectionName.toLowerCase().split(/[^a-z0-9]+/).filter(Boolean);
if (tokens.includes('prod') || tokens.includes('production')) return 'PROD';
if (tokens.includes('uat')) return 'UAT';
if (tokens.includes('dev') || tokens.includes('development')) return 'DEV';
if (tokens.includes('sit')) return 'SIT';
if (tokens.includes('stg') || tokens.includes('stage') || tokens.includes('staging') || tokens.includes('pre')) return 'STG';
if (tokens.includes('test') || tokens.includes('qa')) return 'TEST';
return null;
};
const buildTabDisplayTitle = (tab: TabData, connectionName: string | undefined): string => {
if (tab.type !== 'table' && tab.type !== 'design') return tab.title;
if (!connectionName) return tab.title;
const prefix = detectConnectionEnvLabel(connectionName) || connectionName;
return `[${prefix}] ${tab.title}`;
};
const TabManager: React.FC = () => {
const tabs = useStore(state => state.tabs);
const connections = useStore(state => state.connections);
const activeTabId = useStore(state => state.activeTabId);
const setActiveTab = useStore(state => state.setActiveTab);
const closeTab = useStore(state => state.closeTab);
@@ -30,6 +50,8 @@ const TabManager: React.FC = () => {
};
const items = useMemo(() => tabs.map((tab, index) => {
const connectionName = connections.find((conn) => conn.id === tab.connectionId)?.name;
const displayTitle = buildTabDisplayTitle(tab, connectionName);
let content;
if (tab.type === 'query') {
content = <QueryEditor tab={tab} />;
@@ -76,13 +98,13 @@ const TabManager: React.FC = () => {
return {
label: (
<Dropdown menu={{ items: menuItems }} trigger={['contextMenu']}>
<span onContextMenu={(e) => e.preventDefault()}>{tab.title}</span>
<span onContextMenu={(e) => e.preventDefault()}>{displayTitle}</span>
</Dropdown>
),
key: tab.id,
children: content,
};
}), [tabs, closeOtherTabs, closeTabsToLeft, closeTabsToRight, closeAllTabs]);
}), [tabs, connections, closeOtherTabs, closeTabsToLeft, closeTabsToRight, closeAllTabs]);
return (
<>

View File

@@ -12,10 +12,32 @@ export interface ConnectionConfig {
port: number;
user: string;
password?: string;
savePassword?: boolean;
database?: string;
useSSH?: boolean;
ssh?: SSHConfig;
redisDB?: number; // Redis database index (0-15)
uri?: string; // Connection URI for copy/paste
hosts?: string[]; // Multi-host addresses: host:port
topology?: 'single' | 'replica';
mysqlReplicaUser?: string;
mysqlReplicaPassword?: string;
replicaSet?: string;
authSource?: string;
readPreference?: string;
mongoSrv?: boolean;
mongoAuthMechanism?: string;
mongoReplicaUser?: string;
mongoReplicaPassword?: string;
}
export interface MongoMemberInfo {
host: string;
role: string;
state: string;
stateCode?: number;
healthy: boolean;
isSelf?: boolean;
}
export interface SavedConnection {

View File

@@ -1,5 +1,6 @@
export type FilterCondition = {
id?: number;
enabled?: boolean;
column?: string;
op?: string;
value?: string;
@@ -75,6 +76,8 @@ export const buildWhereSQL = (dbType: string, conditions: FilterCondition[]) =>
const whereParts: string[] = [];
(conditions || []).forEach((cond) => {
if (cond?.enabled === false) return;
const op = (cond?.op || '').trim();
const column = (cond?.column || '').trim();
const value = (cond?.value ?? '').toString();

View File

@@ -62,6 +62,8 @@ export function ImportData(arg1:connection.ConnectionConfig,arg2:string,arg3:str
export function InstallUpdateAndRestart():Promise<connection.QueryResult>;
export function MongoDiscoverMembers(arg1:connection.ConnectionConfig):Promise<connection.QueryResult>;
export function MySQLConnect(arg1:connection.ConnectionConfig):Promise<connection.QueryResult>;
export function MySQLGetDatabases(arg1:connection.ConnectionConfig):Promise<connection.QueryResult>;

View File

@@ -118,6 +118,10 @@ export function InstallUpdateAndRestart() {
return window['go']['app']['App']['InstallUpdateAndRestart']();
}
export function MongoDiscoverMembers(arg1) {
return window['go']['app']['App']['MongoDiscoverMembers'](arg1);
}
export function MySQLConnect(arg1) {
return window['go']['app']['App']['MySQLConnect'](arg1);
}

View File

@@ -74,6 +74,7 @@ export namespace connection {
port: number;
user: string;
password: string;
savePassword?: boolean;
database: string;
useSSH: boolean;
ssh: SSHConfig;
@@ -81,6 +82,18 @@ export namespace connection {
dsn?: string;
timeout?: number;
redisDB?: number;
uri?: string;
hosts?: string[];
topology?: string;
mysqlReplicaUser?: string;
mysqlReplicaPassword?: string;
replicaSet?: string;
authSource?: string;
readPreference?: string;
mongoSrv?: boolean;
mongoAuthMechanism?: string;
mongoReplicaUser?: string;
mongoReplicaPassword?: string;
static createFrom(source: any = {}) {
return new ConnectionConfig(source);
@@ -93,6 +106,7 @@ export namespace connection {
this.port = source["port"];
this.user = source["user"];
this.password = source["password"];
this.savePassword = source["savePassword"];
this.database = source["database"];
this.useSSH = source["useSSH"];
this.ssh = this.convertValues(source["ssh"], SSHConfig);
@@ -100,6 +114,18 @@ export namespace connection {
this.dsn = source["dsn"];
this.timeout = source["timeout"];
this.redisDB = source["redisDB"];
this.uri = source["uri"];
this.hosts = source["hosts"];
this.topology = source["topology"];
this.mysqlReplicaUser = source["mysqlReplicaUser"];
this.mysqlReplicaPassword = source["mysqlReplicaPassword"];
this.replicaSet = source["replicaSet"];
this.authSource = source["authSource"];
this.readPreference = source["readPreference"];
this.mongoSrv = source["mongoSrv"];
this.mongoAuthMechanism = source["mongoAuthMechanism"];
this.mongoReplicaUser = source["mongoReplicaUser"];
this.mongoReplicaPassword = source["mongoReplicaPassword"];
}
convertValues(a: any, classs: any, asMap: boolean = false): any {

2
go.mod
View File

@@ -14,6 +14,7 @@ require (
github.com/wailsapp/wails/v2 v2.11.0
go.mongodb.org/mongo-driver/v2 v2.5.0
golang.org/x/crypto v0.47.0
golang.org/x/text v0.33.0
modernc.org/sqlite v1.44.3
)
@@ -64,7 +65,6 @@ require (
golang.org/x/net v0.48.0 // indirect
golang.org/x/sync v0.19.0 // indirect
golang.org/x/sys v0.40.0 // indirect
golang.org/x/text v0.33.0 // indirect
modernc.org/libc v1.67.6 // indirect
modernc.org/mathutil v1.7.1 // indirect
modernc.org/memory v1.11.0 // indirect

View File

@@ -103,10 +103,11 @@ type withLogHint struct {
}
func (e withLogHint) Error() string {
message := normalizeErrorMessage(e.err)
if strings.TrimSpace(e.logPath) == "" {
return e.err.Error()
return message
}
return fmt.Sprintf("%s详细日志%s", e.err.Error(), e.logPath)
return fmt.Sprintf("%s详细日志%s", message, e.logPath)
}
func (e withLogHint) Unwrap() error {
@@ -128,6 +129,33 @@ func formatConnSummary(config connection.ConnectionConfig) string {
b.WriteString(fmt.Sprintf("类型=%s 地址=%s:%d 数据库=%s 用户=%s 超时=%ds",
config.Type, config.Host, config.Port, dbName, config.User, timeoutSeconds))
if len(config.Hosts) > 0 {
b.WriteString(fmt.Sprintf(" 节点数=%d", len(config.Hosts)))
}
if strings.TrimSpace(config.Topology) != "" {
b.WriteString(fmt.Sprintf(" 拓扑=%s", strings.TrimSpace(config.Topology)))
}
if strings.TrimSpace(config.URI) != "" {
b.WriteString(fmt.Sprintf(" URI=已配置(长度=%d)", len(config.URI)))
}
if strings.TrimSpace(config.MySQLReplicaUser) != "" {
b.WriteString(" MySQL从库凭据=已配置")
}
if strings.EqualFold(strings.TrimSpace(config.Type), "mongodb") {
if strings.TrimSpace(config.MongoReplicaUser) != "" {
b.WriteString(" Mongo从库凭据=已配置")
}
if strings.TrimSpace(config.ReplicaSet) != "" {
b.WriteString(fmt.Sprintf(" 副本集=%s", strings.TrimSpace(config.ReplicaSet)))
}
if strings.TrimSpace(config.ReadPreference) != "" {
b.WriteString(fmt.Sprintf(" 读偏好=%s", strings.TrimSpace(config.ReadPreference)))
}
if strings.TrimSpace(config.AuthSource) != "" {
b.WriteString(fmt.Sprintf(" 认证库=%s", strings.TrimSpace(config.AuthSource)))
}
}
if config.UseSSH {
b.WriteString(fmt.Sprintf(" SSH=%s:%d 用户=%s", config.SSH.Host, config.SSH.Port, config.SSH.User))
}

100
internal/app/error_text.go Normal file
View File

@@ -0,0 +1,100 @@
package app
import (
"strings"
"unicode"
"unicode/utf8"
"golang.org/x/text/encoding/simplifiedchinese"
"golang.org/x/text/transform"
)
func normalizeErrorMessage(err error) string {
if err == nil {
return ""
}
return normalizeMixedEncodingText(err.Error())
}
func normalizeMixedEncodingText(text string) string {
if text == "" {
return text
}
raw := []byte(text)
output := make([]byte, 0, len(raw)+16)
suspect := make([]byte, 0, 16)
flushSuspect := func() {
if len(suspect) == 0 {
return
}
fallback := strings.ToValidUTF8(string(suspect), "<22>")
decoded, _, err := transform.Bytes(simplifiedchinese.GB18030.NewDecoder(), suspect)
if err == nil && utf8.Valid(decoded) {
candidate := string(decoded)
if scoreDecodedText(candidate) > scoreDecodedText(fallback) {
output = append(output, []byte(candidate)...)
} else {
output = append(output, []byte(fallback)...)
}
} else {
output = append(output, []byte(fallback)...)
}
suspect = suspect[:0]
}
for len(raw) > 0 {
r, size := utf8.DecodeRune(raw)
if r == utf8.RuneError && size == 1 {
suspect = append(suspect, raw[0])
raw = raw[1:]
continue
}
if isLikelyMojibakeRune(r) {
suspect = append(suspect, raw[:size]...)
} else {
flushSuspect()
output = append(output, raw[:size]...)
}
raw = raw[size:]
}
flushSuspect()
return string(output)
}
func isLikelyMojibakeRune(r rune) bool {
if r == utf8.RuneError {
return true
}
if r >= 0x00C0 && r <= 0x02FF {
return true
}
if unicode.In(r, unicode.Hebrew, unicode.Arabic, unicode.Cyrillic, unicode.Greek) {
return true
}
return false
}
func scoreDecodedText(text string) int {
score := 0
for _, r := range text {
switch {
case r == '<27>':
score -= 6
case unicode.Is(unicode.Han, r):
score += 4
case isLikelyMojibakeRune(r):
score -= 3
case unicode.IsPrint(r):
score += 1
default:
score -= 2
}
}
return score
}

View File

@@ -0,0 +1,25 @@
package app
import "testing"
func TestNormalizeMixedEncodingText_GBKErrorMessage(t *testing.T) {
raw := []byte("pq: ")
raw = append(raw, 0xD3, 0xC3, 0xBB, 0xA7) // 用户
raw = append(raw, []byte(` "root" Password `)...)
raw = append(raw, 0xC8, 0xCF, 0xD6, 0xA4, 0xCA, 0xA7, 0xB0, 0xDC) // 认证失败
raw = append(raw, []byte(" (28P01)")...)
got := normalizeMixedEncodingText(string(raw))
want := `pq: 用户 "root" Password 认证失败 (28P01)`
if got != want {
t.Fatalf("normalizeMixedEncodingText() mismatch\nwant: %q\ngot: %q", want, got)
}
}
func TestNormalizeMixedEncodingText_KeepUTF8(t *testing.T) {
input := `连接建立后验证失败pq: password authentication failed for user "root"`
got := normalizeMixedEncodingText(input)
if got != input {
t.Fatalf("expected unchanged utf8 text, got: %q", got)
}
}

View File

@@ -36,6 +36,41 @@ func (a *App) TestConnection(config connection.ConnectionConfig) connection.Quer
return connection.QueryResult{Success: true, Message: "连接成功"}
}
func (a *App) MongoDiscoverMembers(config connection.ConnectionConfig) connection.QueryResult {
config.Type = "mongodb"
dbInst, err := a.getDatabaseForcePing(config)
if err != nil {
logger.Error(err, "MongoDiscoverMembers 获取连接失败:%s", formatConnSummary(config))
return connection.QueryResult{Success: false, Message: err.Error()}
}
discoverable, ok := dbInst.(interface {
DiscoverMembers() (string, []connection.MongoMemberInfo, error)
})
if !ok {
return connection.QueryResult{Success: false, Message: "当前 MongoDB 驱动不支持成员发现"}
}
replicaSet, members, err := discoverable.DiscoverMembers()
if err != nil {
logger.Error(err, "MongoDiscoverMembers 执行失败:%s", formatConnSummary(config))
return connection.QueryResult{Success: false, Message: err.Error()}
}
data := map[string]interface{}{
"replicaSet": replicaSet,
"members": members,
}
logger.Infof("MongoDiscoverMembers 成功:%s 成员数=%d 副本集=%s", formatConnSummary(config), len(members), replicaSet)
return connection.QueryResult{
Success: true,
Message: fmt.Sprintf("发现 %d 个成员", len(members)),
Data: data,
}
}
func (a *App) CreateDatabase(config connection.ConnectionConfig, dbName string) connection.QueryResult {
runConfig := config
runConfig.Database = ""

View File

@@ -857,7 +857,12 @@ func detectMacAppPath(exePath string) string {
parts := strings.Split(exePath, string(filepath.Separator))
for i := len(parts) - 1; i >= 0; i-- {
if strings.HasSuffix(parts[i], ".app") {
return filepath.Join(parts[:i+1]...)
appPath := filepath.Join(parts[:i+1]...)
// 确保返回绝对路径
if !filepath.IsAbs(appPath) {
appPath = string(filepath.Separator) + appPath
}
return appPath
}
}
return ""

View File

@@ -11,18 +11,31 @@ type SSHConfig struct {
// ConnectionConfig holds database connection details including SSH
type ConnectionConfig struct {
Type string `json:"type"`
Host string `json:"host"`
Port int `json:"port"`
User string `json:"user"`
Password string `json:"password"`
Database string `json:"database"`
UseSSH bool `json:"useSSH"`
SSH SSHConfig `json:"ssh"`
Driver string `json:"driver,omitempty"` // For custom connection
DSN string `json:"dsn,omitempty"` // For custom connection
Timeout int `json:"timeout,omitempty"` // Connection timeout in seconds (default: 30)
RedisDB int `json:"redisDB,omitempty"` // Redis database index (0-15)
Type string `json:"type"`
Host string `json:"host"`
Port int `json:"port"`
User string `json:"user"`
Password string `json:"password"`
SavePassword bool `json:"savePassword,omitempty"` // Persist password in saved connection
Database string `json:"database"`
UseSSH bool `json:"useSSH"`
SSH SSHConfig `json:"ssh"`
Driver string `json:"driver,omitempty"` // For custom connection
DSN string `json:"dsn,omitempty"` // For custom connection
Timeout int `json:"timeout,omitempty"` // Connection timeout in seconds (default: 30)
RedisDB int `json:"redisDB,omitempty"` // Redis database index (0-15)
URI string `json:"uri,omitempty"` // Connection URI for copy/paste
Hosts []string `json:"hosts,omitempty"` // Multi-host addresses: host:port
Topology string `json:"topology,omitempty"` // single | replica
MySQLReplicaUser string `json:"mysqlReplicaUser,omitempty"` // MySQL replica auth user
MySQLReplicaPassword string `json:"mysqlReplicaPassword,omitempty"` // MySQL replica auth password
ReplicaSet string `json:"replicaSet,omitempty"` // MongoDB replica set name
AuthSource string `json:"authSource,omitempty"` // MongoDB authSource
ReadPreference string `json:"readPreference,omitempty"` // MongoDB readPreference
MongoSRV bool `json:"mongoSrv,omitempty"` // MongoDB use mongodb+srv URI scheme
MongoAuthMechanism string `json:"mongoAuthMechanism,omitempty"` // MongoDB authMechanism
MongoReplicaUser string `json:"mongoReplicaUser,omitempty"` // MongoDB replica auth user
MongoReplicaPassword string `json:"mongoReplicaPassword,omitempty"` // MongoDB replica auth password
}
// QueryResult is the standard response format for Wails methods
@@ -89,3 +102,12 @@ type ChangeSet struct {
Updates []UpdateRow `json:"updates"`
Deletes []map[string]interface{} `json:"deletes"`
}
type MongoMemberInfo struct {
Host string `json:"host"`
Role string `json:"role"`
State string `json:"state"`
StateCode int `json:"stateCode,omitempty"`
Healthy bool `json:"healthy"`
IsSelf bool `json:"isSelf,omitempty"`
}

View File

@@ -5,6 +5,7 @@ import (
"fmt"
"net"
"net/url"
"sort"
"strconv"
"strings"
"time"
@@ -26,53 +27,264 @@ type MongoDB struct {
forwarder *ssh.LocalForwarder
}
func (m *MongoDB) getURI(config connection.ConnectionConfig) string {
// mongodb://user:password@host:port/database?authSource=admin
host := config.Host
port := config.Port
if port == 0 {
port = 27017
const defaultMongoPort = 27017
func normalizeMongoAddress(host string, port int) string {
h := strings.TrimSpace(host)
if h == "" {
h = "localhost"
}
p := port
if p <= 0 {
p = defaultMongoPort
}
return fmt.Sprintf("%s:%d", h, p)
}
func normalizeMongoSeed(raw string, defaultPort int, useSRV bool) (string, bool) {
host, port, ok := parseHostPortWithDefault(raw, defaultPort)
if !ok {
return "", false
}
uri := fmt.Sprintf("mongodb://%s:%d", host, port)
if useSRV {
normalized := strings.TrimSpace(host)
if normalized == "" {
return "", false
}
return normalized, true
}
if config.User != "" {
encodedUser := url.QueryEscape(config.User)
if config.Password != "" {
encodedPass := url.QueryEscape(config.Password)
uri = fmt.Sprintf("mongodb://%s:%s@%s:%d", encodedUser, encodedPass, host, port)
return normalizeMongoAddress(host, port), true
}
func collectMongoSeeds(config connection.ConnectionConfig) []string {
defaultPort := config.Port
if defaultPort <= 0 {
defaultPort = defaultMongoPort
}
useSRV := config.MongoSRV
candidates := make([]string, 0, len(config.Hosts)+1)
if len(config.Hosts) > 0 {
candidates = append(candidates, config.Hosts...)
} else {
if useSRV {
candidates = append(candidates, strings.TrimSpace(config.Host))
} else {
uri = fmt.Sprintf("mongodb://%s@%s:%d", encodedUser, host, port)
candidates = append(candidates, normalizeMongoAddress(config.Host, defaultPort))
}
}
// Add connection options
params := []string{}
timeout := getConnectTimeoutSeconds(config)
params = append(params, fmt.Sprintf("connectTimeoutMS=%d", timeout*1000))
params = append(params, fmt.Sprintf("serverSelectionTimeoutMS=%d", timeout*1000))
// authSource: 优先使用 config.Database为空时默认 admin
authSource := "admin"
if config.Database != "" {
authSource = config.Database
result := make([]string, 0, len(candidates))
seen := make(map[string]struct{}, len(candidates))
for _, entry := range candidates {
normalized, ok := normalizeMongoSeed(entry, defaultPort, useSRV)
if !ok {
continue
}
if _, exists := seen[normalized]; exists {
continue
}
seen[normalized] = struct{}{}
result = append(result, normalized)
}
params = append(params, fmt.Sprintf("authSource=%s", authSource))
if len(params) > 0 {
uri = uri + "/?" + strings.Join(params, "&")
return result
}
func applyMongoURI(config connection.ConnectionConfig) connection.ConnectionConfig {
uriText := strings.TrimSpace(config.URI)
if uriText == "" {
return config
}
lowerURI := strings.ToLower(uriText)
if strings.HasPrefix(lowerURI, "mongodb+srv://") {
config.MongoSRV = true
}
if !strings.HasPrefix(lowerURI, "mongodb://") && !strings.HasPrefix(lowerURI, "mongodb+srv://") {
return config
}
parsed, err := url.Parse(uriText)
if err != nil {
return config
}
if parsed.User != nil {
if config.User == "" {
config.User = parsed.User.Username()
}
if pass, ok := parsed.User.Password(); ok && config.Password == "" {
config.Password = pass
}
}
if dbName := strings.TrimPrefix(parsed.Path, "/"); dbName != "" && config.Database == "" {
config.Database = dbName
}
defaultPort := config.Port
if defaultPort <= 0 {
defaultPort = defaultMongoPort
}
hostsFromURI := make([]string, 0, 4)
hostText := strings.TrimSpace(parsed.Host)
if hostText != "" {
for _, entry := range strings.Split(hostText, ",") {
normalized, ok := normalizeMongoSeed(entry, defaultPort, config.MongoSRV)
if ok {
hostsFromURI = append(hostsFromURI, normalized)
}
}
}
if len(config.Hosts) == 0 && len(hostsFromURI) > 0 {
config.Hosts = hostsFromURI
}
if strings.TrimSpace(config.Host) == "" && len(hostsFromURI) > 0 {
host, port, ok := parseHostPortWithDefault(hostsFromURI[0], defaultPort)
if ok {
config.Host = host
config.Port = port
}
}
query := parsed.Query()
if config.AuthSource == "" {
config.AuthSource = strings.TrimSpace(query.Get("authSource"))
}
if config.ReadPreference == "" {
config.ReadPreference = strings.TrimSpace(query.Get("readPreference"))
}
if config.ReplicaSet == "" {
config.ReplicaSet = strings.TrimSpace(query.Get("replicaSet"))
}
if config.MongoAuthMechanism == "" {
config.MongoAuthMechanism = strings.TrimSpace(query.Get("authMechanism"))
}
if config.Topology == "" {
if len(config.Hosts) > 1 || strings.TrimSpace(config.ReplicaSet) != "" {
config.Topology = "replica"
} else {
config.Topology = "single"
}
}
return config
}
func (m *MongoDB) getURI(config connection.ConnectionConfig) string {
if strings.TrimSpace(config.URI) != "" {
return strings.TrimSpace(config.URI)
}
seeds := collectMongoSeeds(config)
if len(seeds) == 0 {
if config.MongoSRV {
seed := strings.TrimSpace(config.Host)
if seed == "" {
seed = "localhost"
}
seeds = append(seeds, seed)
} else {
seeds = append(seeds, normalizeMongoAddress(config.Host, config.Port))
}
}
scheme := "mongodb"
if config.MongoSRV {
scheme = "mongodb+srv"
}
hostText := strings.Join(seeds, ",")
uri := fmt.Sprintf("%s://%s", scheme, hostText)
if config.User != "" {
encodedUser := url.PathEscape(config.User)
if config.Password != "" {
encodedPass := url.PathEscape(config.Password)
uri = fmt.Sprintf("%s://%s:%s@%s", scheme, encodedUser, encodedPass, hostText)
} else {
uri = fmt.Sprintf("%s://%s@%s", scheme, encodedUser, hostText)
}
}
path := "/"
if strings.TrimSpace(config.Database) != "" {
path = "/" + url.PathEscape(strings.TrimSpace(config.Database))
}
uri += path
params := url.Values{}
timeout := getConnectTimeoutSeconds(config)
params.Set("connectTimeoutMS", strconv.Itoa(timeout*1000))
params.Set("serverSelectionTimeoutMS", strconv.Itoa(timeout*1000))
authSource := strings.TrimSpace(config.AuthSource)
if authSource == "" && strings.TrimSpace(config.Database) != "" {
authSource = strings.TrimSpace(config.Database)
}
if authSource == "" {
authSource = "admin"
}
params.Set("authSource", authSource)
if replicaSet := strings.TrimSpace(config.ReplicaSet); replicaSet != "" {
params.Set("replicaSet", replicaSet)
}
if readPreference := strings.TrimSpace(config.ReadPreference); readPreference != "" {
params.Set("readPreference", readPreference)
}
if authMechanism := strings.TrimSpace(config.MongoAuthMechanism); authMechanism != "" {
params.Set("authMechanism", authMechanism)
}
if encoded := params.Encode(); encoded != "" {
uri += "?" + encoded
}
return uri
}
func buildMongoAuthAttempts(config connection.ConnectionConfig) []connection.ConnectionConfig {
attempts := []connection.ConnectionConfig{config}
replicaUser := strings.TrimSpace(config.MongoReplicaUser)
if replicaUser == "" {
return attempts
}
if replicaUser == strings.TrimSpace(config.User) && config.MongoReplicaPassword == config.Password {
return attempts
}
replicaConfig := config
replicaConfig.URI = ""
replicaConfig.User = replicaUser
replicaConfig.Password = config.MongoReplicaPassword
attempts = append(attempts, replicaConfig)
return attempts
}
func (m *MongoDB) Connect(config connection.ConnectionConfig) error {
var uri string
runConfig := applyMongoURI(config)
connectConfig := runConfig
if config.UseSSH {
logger.Infof("MongoDB 使用 SSH 连接:地址=%s:%d", config.Host, config.Port)
if runConfig.UseSSH && runConfig.MongoSRV {
return fmt.Errorf("MongoDB SRV 记录模式暂不支持 SSH 隧道")
}
forwarder, err := ssh.GetOrCreateLocalForwarder(config.SSH, config.Host, config.Port)
if runConfig.UseSSH {
seeds := collectMongoSeeds(runConfig)
if len(seeds) == 0 {
seeds = append(seeds, normalizeMongoAddress(runConfig.Host, runConfig.Port))
}
targetHost, targetPort, ok := parseHostPortWithDefault(seeds[0], defaultMongoPort)
if !ok {
return fmt.Errorf("MongoDB 连接失败:无效地址 %s", seeds[0])
}
logger.Infof("MongoDB 使用 SSH 连接:地址=%s:%d", targetHost, targetPort)
forwarder, err := ssh.GetOrCreateLocalForwarder(runConfig.SSH, targetHost, targetPort)
if err != nil {
return fmt.Errorf("创建 SSH 隧道失败:%w", err)
}
@@ -88,35 +300,55 @@ func (m *MongoDB) Connect(config connection.ConnectionConfig) error {
return fmt.Errorf("解析本地端口失败:%w", err)
}
localConfig := config
localConfig := runConfig
localConfig.Host = host
localConfig.Port = port
localConfig.UseSSH = false
uri = m.getURI(localConfig)
logger.Infof("MongoDB 通过本地端口转发连接:%s -> %s:%d", forwarder.LocalAddr, config.Host, config.Port)
} else {
uri = m.getURI(config)
localConfig.URI = ""
localConfig.Hosts = []string{normalizeMongoAddress(host, port)}
connectConfig = localConfig
logger.Infof("MongoDB 通过本地端口转发连接:%s -> %s:%d", forwarder.LocalAddr, targetHost, targetPort)
}
m.pingTimeout = getConnectTimeout(config)
m.database = config.Database
m.pingTimeout = getConnectTimeout(connectConfig)
m.database = connectConfig.Database
if m.database == "" {
m.database = "admin"
}
clientOpts := options.Client().ApplyURI(uri)
client, err := mongo.Connect(clientOpts)
if err != nil {
return fmt.Errorf("MongoDB 连接失败:%w", err)
}
m.client = client
attemptConfigs := buildMongoAuthAttempts(connectConfig)
var errorDetails []string
for index, attemptConfig := range attemptConfigs {
authLabel := "主库凭据"
if index > 0 {
authLabel = "从库凭据"
}
if err := m.Ping(); err != nil {
return fmt.Errorf("MongoDB 连接验证失败:%w", err)
uri := m.getURI(attemptConfig)
clientOpts := options.Client().ApplyURI(uri)
client, err := mongo.Connect(clientOpts)
if err != nil {
errorDetails = append(errorDetails, fmt.Sprintf("%s连接失败: %v", authLabel, err))
continue
}
m.client = client
if err := m.Ping(); err != nil {
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
_ = client.Disconnect(ctx)
cancel()
m.client = nil
errorDetails = append(errorDetails, fmt.Sprintf("%s验证失败: %v", authLabel, err))
continue
}
return nil
}
return nil
if len(errorDetails) > 0 {
return fmt.Errorf("MongoDB 连接失败:%s", strings.Join(errorDetails, ""))
}
return fmt.Errorf("MongoDB 连接失败:无可用连接方案")
}
func (m *MongoDB) Close() error {
@@ -148,6 +380,226 @@ func (m *MongoDB) Ping() error {
return m.client.Ping(ctx, readpref.Primary())
}
func asMongoStringList(raw interface{}) []string {
values, ok := raw.(bson.A)
if !ok {
return nil
}
result := make([]string, 0, len(values))
for _, entry := range values {
text := strings.TrimSpace(fmt.Sprintf("%v", entry))
if text != "" {
result = append(result, text)
}
}
return result
}
func asMongoString(raw interface{}) string {
if raw == nil {
return ""
}
if value, ok := raw.(string); ok {
return strings.TrimSpace(value)
}
return strings.TrimSpace(fmt.Sprintf("%v", raw))
}
func asMongoInt(raw interface{}) int {
switch value := raw.(type) {
case int:
return value
case int32:
return int(value)
case int64:
return int(value)
case float32:
return int(value)
case float64:
return int(value)
default:
return 0
}
}
func asMongoBool(raw interface{}) bool {
switch value := raw.(type) {
case bool:
return value
case int:
return value != 0
case int32:
return value != 0
case int64:
return value != 0
case float32:
return value != 0
case float64:
return value != 0
default:
return false
}
}
func mongoStateByCode(code int) string {
switch code {
case 1:
return "PRIMARY"
case 2:
return "SECONDARY"
case 3:
return "RECOVERING"
case 5:
return "STARTUP2"
case 6:
return "UNKNOWN"
case 7:
return "ARBITER"
case 8:
return "DOWN"
case 9:
return "ROLLBACK"
case 10:
return "REMOVED"
default:
return "UNKNOWN"
}
}
func normalizeMongoStateLabel(state string, stateCode int) string {
normalized := strings.ToUpper(strings.TrimSpace(state))
if normalized != "" {
return normalized
}
return mongoStateByCode(stateCode)
}
func buildMembersFromReplStatus(raw bson.M) []connection.MongoMemberInfo {
items, ok := raw["members"].(bson.A)
if !ok {
return nil
}
members := make([]connection.MongoMemberInfo, 0, len(items))
for _, entry := range items {
member, ok := entry.(bson.M)
if !ok {
continue
}
host := asMongoString(member["name"])
if host == "" {
continue
}
stateCode := asMongoInt(member["state"])
state := normalizeMongoStateLabel(asMongoString(member["stateStr"]), stateCode)
members = append(members, connection.MongoMemberInfo{
Host: host,
Role: state,
State: state,
StateCode: stateCode,
Healthy: asMongoInt(member["health"]) > 0 || asMongoBool(member["health"]),
IsSelf: asMongoBool(member["self"]),
})
}
sort.Slice(members, func(i, j int) bool {
return members[i].Host < members[j].Host
})
return members
}
func buildMembersFromHello(raw bson.M) []connection.MongoMemberInfo {
hosts := asMongoStringList(raw["hosts"])
if len(hosts) == 0 {
return nil
}
primary := asMongoString(raw["primary"])
selfHost := asMongoString(raw["me"])
passiveSet := make(map[string]struct{})
for _, host := range asMongoStringList(raw["passives"]) {
passiveSet[host] = struct{}{}
}
arbiterSet := make(map[string]struct{})
for _, host := range asMongoStringList(raw["arbiters"]) {
arbiterSet[host] = struct{}{}
}
members := make([]connection.MongoMemberInfo, 0, len(hosts))
for _, host := range hosts {
state := "SECONDARY"
stateCode := 2
if host == primary {
state = "PRIMARY"
stateCode = 1
} else if _, ok := arbiterSet[host]; ok {
state = "ARBITER"
stateCode = 7
} else if _, ok := passiveSet[host]; ok {
state = "PASSIVE"
stateCode = 6
}
members = append(members, connection.MongoMemberInfo{
Host: host,
Role: state,
State: state,
StateCode: stateCode,
Healthy: true,
IsSelf: host == selfHost,
})
}
sort.Slice(members, func(i, j int) bool {
return members[i].Host < members[j].Host
})
return members
}
func (m *MongoDB) DiscoverMembers() (string, []connection.MongoMemberInfo, error) {
if m.client == nil {
return "", nil, fmt.Errorf("connection not open")
}
timeout := m.pingTimeout
if timeout <= 0 {
timeout = 10 * time.Second
}
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
adminDB := m.client.Database("admin")
var replStatus bson.M
replErr := adminDB.RunCommand(ctx, bson.D{{Key: "replSetGetStatus", Value: 1}}).Decode(&replStatus)
if replErr == nil {
replicaSet := asMongoString(replStatus["set"])
members := buildMembersFromReplStatus(replStatus)
if len(members) > 0 {
return replicaSet, members, nil
}
}
var helloResult bson.M
helloErr := adminDB.RunCommand(ctx, bson.D{{Key: "hello", Value: 1}}).Decode(&helloResult)
if helloErr != nil {
if err := adminDB.RunCommand(ctx, bson.D{{Key: "isMaster", Value: 1}}).Decode(&helloResult); err != nil {
if replErr != nil {
return "", nil, fmt.Errorf("成员发现失败replSetGetStatus=%vhello=%v", replErr, err)
}
return "", nil, fmt.Errorf("成员发现失败hello=%w", err)
}
}
replicaSet := asMongoString(helloResult["setName"])
members := buildMembersFromHello(helloResult)
if len(members) == 0 {
if replErr != nil {
return replicaSet, nil, fmt.Errorf("未获取到成员信息replSetGetStatus=%v", replErr)
}
return replicaSet, nil, fmt.Errorf("未获取到成员信息")
}
return replicaSet, members, nil
}
// Query executes a MongoDB command and returns results
// Supports JSON format commands like: {"find": "collection", "filter": {}}
func (m *MongoDB) Query(query string) ([]map[string]interface{}, []string, error) {

View File

@@ -4,6 +4,8 @@ import (
"context"
"database/sql"
"fmt"
"net/url"
"strconv"
"strings"
"time"
@@ -20,16 +22,161 @@ type MySQLDB struct {
pingTimeout time.Duration
}
const defaultMySQLPort = 3306
func parseHostPortWithDefault(raw string, defaultPort int) (string, int, bool) {
text := strings.TrimSpace(raw)
if text == "" {
return "", 0, false
}
if strings.HasPrefix(text, "[") {
end := strings.Index(text, "]")
if end < 0 {
return text, defaultPort, true
}
host := text[1:end]
portText := strings.TrimSpace(text[end+1:])
if strings.HasPrefix(portText, ":") {
if p, err := strconv.Atoi(strings.TrimSpace(strings.TrimPrefix(portText, ":"))); err == nil && p > 0 {
return host, p, true
}
}
return host, defaultPort, true
}
lastColon := strings.LastIndex(text, ":")
if lastColon > 0 && strings.Count(text, ":") == 1 {
host := strings.TrimSpace(text[:lastColon])
portText := strings.TrimSpace(text[lastColon+1:])
if host != "" {
if p, err := strconv.Atoi(portText); err == nil && p > 0 {
return host, p, true
}
return host, defaultPort, true
}
}
return text, defaultPort, true
}
func normalizeMySQLAddress(host string, port int) string {
h := strings.TrimSpace(host)
if h == "" {
h = "localhost"
}
p := port
if p <= 0 {
p = defaultMySQLPort
}
return fmt.Sprintf("%s:%d", h, p)
}
func applyMySQLURI(config connection.ConnectionConfig) connection.ConnectionConfig {
uriText := strings.TrimSpace(config.URI)
if uriText == "" {
return config
}
if !strings.HasPrefix(strings.ToLower(uriText), "mysql://") {
return config
}
parsed, err := url.Parse(uriText)
if err != nil {
return config
}
if parsed.User != nil {
if config.User == "" {
config.User = parsed.User.Username()
}
if pass, ok := parsed.User.Password(); ok && config.Password == "" {
config.Password = pass
}
}
if dbName := strings.TrimPrefix(parsed.Path, "/"); dbName != "" && config.Database == "" {
config.Database = dbName
}
defaultPort := config.Port
if defaultPort <= 0 {
defaultPort = defaultMySQLPort
}
hostsFromURI := make([]string, 0, 4)
hostText := strings.TrimSpace(parsed.Host)
if hostText != "" {
for _, entry := range strings.Split(hostText, ",") {
host, port, ok := parseHostPortWithDefault(entry, defaultPort)
if !ok {
continue
}
hostsFromURI = append(hostsFromURI, normalizeMySQLAddress(host, port))
}
}
if len(config.Hosts) == 0 && len(hostsFromURI) > 0 {
config.Hosts = hostsFromURI
}
if strings.TrimSpace(config.Host) == "" && len(hostsFromURI) > 0 {
host, port, ok := parseHostPortWithDefault(hostsFromURI[0], defaultPort)
if ok {
config.Host = host
config.Port = port
}
}
if config.Topology == "" {
topology := strings.TrimSpace(parsed.Query().Get("topology"))
if topology != "" {
config.Topology = strings.ToLower(topology)
}
}
return config
}
func collectMySQLAddresses(config connection.ConnectionConfig) []string {
defaultPort := config.Port
if defaultPort <= 0 {
defaultPort = defaultMySQLPort
}
candidates := make([]string, 0, len(config.Hosts)+1)
if len(config.Hosts) > 0 {
candidates = append(candidates, config.Hosts...)
} else {
candidates = append(candidates, normalizeMySQLAddress(config.Host, defaultPort))
}
result := make([]string, 0, len(candidates))
seen := make(map[string]struct{}, len(candidates))
for _, entry := range candidates {
host, port, ok := parseHostPortWithDefault(entry, defaultPort)
if !ok {
continue
}
normalized := normalizeMySQLAddress(host, port)
if _, exists := seen[normalized]; exists {
continue
}
seen[normalized] = struct{}{}
result = append(result, normalized)
}
return result
}
func (m *MySQLDB) getDSN(config connection.ConnectionConfig) string {
database := config.Database
protocol := "tcp"
address := fmt.Sprintf("%s:%d", config.Host, config.Port)
address := normalizeMySQLAddress(config.Host, config.Port)
if config.UseSSH {
netName, err := ssh.RegisterSSHNetwork(config.SSH)
if err == nil {
protocol = netName
address = fmt.Sprintf("%s:%d", config.Host, config.Port)
address = normalizeMySQLAddress(config.Host, config.Port)
} else {
logger.Warnf("注册 SSH 网络失败,将尝试直连:地址=%s:%d 用户=%s原因%v", config.Host, config.Port, config.User, err)
}
@@ -41,20 +188,67 @@ func (m *MySQLDB) getDSN(config connection.ConnectionConfig) string {
config.User, config.Password, protocol, address, database, timeout)
}
func (m *MySQLDB) Connect(config connection.ConnectionConfig) error {
dsn := m.getDSN(config)
db, err := sql.Open("mysql", dsn)
if err != nil {
return fmt.Errorf("打开数据库连接失败:%w", err)
}
m.conn = db
m.pingTimeout = getConnectTimeout(config)
func resolveMySQLCredential(config connection.ConnectionConfig, addressIndex int) (string, string) {
primaryUser := strings.TrimSpace(config.User)
primaryPassword := config.Password
replicaUser := strings.TrimSpace(config.MySQLReplicaUser)
replicaPassword := config.MySQLReplicaPassword
// Force verification
if err := m.Ping(); err != nil {
return fmt.Errorf("连接建立后验证失败:%w", err)
if addressIndex > 0 && replicaUser != "" {
return replicaUser, replicaPassword
}
return nil
if primaryUser == "" && replicaUser != "" {
return replicaUser, replicaPassword
}
return config.User, primaryPassword
}
func (m *MySQLDB) Connect(config connection.ConnectionConfig) error {
runConfig := applyMySQLURI(config)
addresses := collectMySQLAddresses(runConfig)
if len(addresses) == 0 {
return fmt.Errorf("连接建立后验证失败:未找到可用的 MySQL 地址")
}
var errorDetails []string
for index, address := range addresses {
candidateConfig := runConfig
host, port, ok := parseHostPortWithDefault(address, defaultMySQLPort)
if !ok {
continue
}
candidateConfig.Host = host
candidateConfig.Port = port
candidateConfig.User, candidateConfig.Password = resolveMySQLCredential(runConfig, index)
dsn := m.getDSN(candidateConfig)
db, err := sql.Open("mysql", dsn)
if err != nil {
errorDetails = append(errorDetails, fmt.Sprintf("%s 打开失败: %v", address, err))
continue
}
timeout := getConnectTimeout(candidateConfig)
ctx, cancel := utils.ContextWithTimeout(timeout)
pingErr := db.PingContext(ctx)
cancel()
if pingErr != nil {
_ = db.Close()
errorDetails = append(errorDetails, fmt.Sprintf("%s 验证失败: %v", address, pingErr))
continue
}
m.conn = db
m.pingTimeout = timeout
return nil
}
if len(errorDetails) == 0 {
return fmt.Errorf("连接建立后验证失败:未找到可用的 MySQL 地址")
}
return fmt.Errorf("连接建立后验证失败:%s", strings.Join(errorDetails, ""))
}
func (m *MySQLDB) Close() error {