Files
MyGoNavi/internal/db/mongodb_impl.go
杨国锋 72811092b4 feat(datasource): 新增 MariaDB、Vastbase、HighGo、MongoDB、SQL Server 五种数据源支持
- MariaDB:MySQL驱动占位,默认端口3306,归类关系型数据库
- Vastbase(海量):PG驱动占位,默认端口5432,归类国产数据库
- HighGo(瀚高):PG驱动,支持SM3认证扩展,归类国产数据库
- MongoDB:官方驱动实现,归类NoSQL
- SQL Server:微软官方驱动实现,归类关系型数据库
- ConnectionModal 新增数据源选项卡与默认端口配置
- database.go 新增5种类型的实例化分支
- 同步更新 db_context、methods_db、sql_sanitize、methods_file、sql_helpers 类型判断
- 新增 HighGo SM3 驱动集成指南
2026-02-08 13:39:39 +08:00

408 lines
10 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package db
import (
"context"
"fmt"
"net"
"net/url"
"strconv"
"strings"
"time"
"GoNavi-Wails/internal/connection"
"GoNavi-Wails/internal/logger"
"GoNavi-Wails/internal/ssh"
"go.mongodb.org/mongo-driver/v2/bson"
"go.mongodb.org/mongo-driver/v2/mongo"
"go.mongodb.org/mongo-driver/v2/mongo/options"
"go.mongodb.org/mongo-driver/v2/mongo/readpref"
)
type MongoDB struct {
client *mongo.Client
database string
pingTimeout time.Duration
forwarder *ssh.LocalForwarder
}
func (m *MongoDB) getURI(config connection.ConnectionConfig) string {
// mongodb://user:password@host:port/database?authSource=admin
host := config.Host
port := config.Port
if port == 0 {
port = 27017
}
uri := fmt.Sprintf("mongodb://%s:%d", host, port)
if config.User != "" {
encodedUser := url.QueryEscape(config.User)
if config.Password != "" {
encodedPass := url.QueryEscape(config.Password)
uri = fmt.Sprintf("mongodb://%s:%s@%s:%d", encodedUser, encodedPass, host, port)
} else {
uri = fmt.Sprintf("mongodb://%s@%s:%d", encodedUser, host, port)
}
}
// Add connection options
params := []string{}
timeout := getConnectTimeoutSeconds(config)
params = append(params, fmt.Sprintf("connectTimeoutMS=%d", timeout*1000))
params = append(params, fmt.Sprintf("serverSelectionTimeoutMS=%d", timeout*1000))
// authSource: 优先使用 config.Database为空时默认 admin
authSource := "admin"
if config.Database != "" {
authSource = config.Database
}
params = append(params, fmt.Sprintf("authSource=%s", authSource))
if len(params) > 0 {
uri = uri + "/?" + strings.Join(params, "&")
}
return uri
}
func (m *MongoDB) Connect(config connection.ConnectionConfig) error {
var uri string
if config.UseSSH {
logger.Infof("MongoDB 使用 SSH 连接:地址=%s:%d", config.Host, config.Port)
forwarder, err := ssh.GetOrCreateLocalForwarder(config.SSH, config.Host, config.Port)
if err != nil {
return fmt.Errorf("创建 SSH 隧道失败:%w", err)
}
m.forwarder = forwarder
host, portStr, err := net.SplitHostPort(forwarder.LocalAddr)
if err != nil {
return fmt.Errorf("解析本地转发地址失败:%w", err)
}
port, err := strconv.Atoi(portStr)
if err != nil {
return fmt.Errorf("解析本地端口失败:%w", err)
}
localConfig := config
localConfig.Host = host
localConfig.Port = port
localConfig.UseSSH = false
uri = m.getURI(localConfig)
logger.Infof("MongoDB 通过本地端口转发连接:%s -> %s:%d", forwarder.LocalAddr, config.Host, config.Port)
} else {
uri = m.getURI(config)
}
m.pingTimeout = getConnectTimeout(config)
m.database = config.Database
if m.database == "" {
m.database = "admin"
}
clientOpts := options.Client().ApplyURI(uri)
client, err := mongo.Connect(clientOpts)
if err != nil {
return fmt.Errorf("MongoDB 连接失败:%w", err)
}
m.client = client
if err := m.Ping(); err != nil {
return fmt.Errorf("MongoDB 连接验证失败:%w", err)
}
return nil
}
func (m *MongoDB) Close() error {
if m.forwarder != nil {
if err := m.forwarder.Close(); err != nil {
logger.Warnf("关闭 MongoDB SSH 端口转发失败:%v", err)
}
m.forwarder = nil
}
if m.client != nil {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
return m.client.Disconnect(ctx)
}
return nil
}
func (m *MongoDB) Ping() error {
if m.client == nil {
return fmt.Errorf("connection not open")
}
timeout := m.pingTimeout
if timeout <= 0 {
timeout = 5 * time.Second
}
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
return m.client.Ping(ctx, readpref.Primary())
}
// Query executes a MongoDB command and returns results
// Supports JSON format commands like: {"find": "collection", "filter": {}}
func (m *MongoDB) Query(query string) ([]map[string]interface{}, []string, error) {
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
return m.queryWithContext(ctx, query)
}
// QueryContext executes a MongoDB command with the given context for timeout control
func (m *MongoDB) QueryContext(ctx context.Context, query string) ([]map[string]interface{}, []string, error) {
return m.queryWithContext(ctx, query)
}
func (m *MongoDB) queryWithContext(ctx context.Context, query string) ([]map[string]interface{}, []string, error) {
if m.client == nil {
return nil, nil, fmt.Errorf("connection not open")
}
query = strings.TrimSpace(query)
if query == "" {
return nil, nil, fmt.Errorf("empty query")
}
// Parse JSON command
var cmd bson.D
if err := bson.UnmarshalExtJSON([]byte(query), true, &cmd); err != nil {
return nil, nil, fmt.Errorf("invalid JSON command: %w", err)
}
db := m.client.Database(m.database)
var result bson.M
if err := db.RunCommand(ctx, cmd).Decode(&result); err != nil {
return nil, nil, err
}
// Convert result to standard format
data := []map[string]interface{}{{"result": result}}
columns := []string{"result"}
// If result contains cursor with documents, extract them
if cursor, ok := result["cursor"].(bson.M); ok {
if batch, ok := cursor["firstBatch"].(bson.A); ok {
data = make([]map[string]interface{}, 0, len(batch))
columnSet := make(map[string]bool)
for _, doc := range batch {
if docMap, ok := doc.(bson.M); ok {
row := make(map[string]interface{})
for k, v := range docMap {
row[k] = v
columnSet[k] = true
}
data = append(data, row)
}
}
columns = make([]string, 0, len(columnSet))
for k := range columnSet {
columns = append(columns, k)
}
}
}
return data, columns, nil
}
func (m *MongoDB) Exec(query string) (int64, error) {
_, _, err := m.Query(query)
if err != nil {
return 0, err
}
return 1, nil
}
// ExecContext executes a MongoDB command with the given context for timeout control
func (m *MongoDB) ExecContext(ctx context.Context, query string) (int64, error) {
_, _, err := m.QueryContext(ctx, query)
if err != nil {
return 0, err
}
return 1, nil
}
func (m *MongoDB) GetDatabases() ([]string, error) {
if m.client == nil {
return nil, fmt.Errorf("connection not open")
}
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
dbs, err := m.client.ListDatabaseNames(ctx, bson.M{})
if err != nil {
return nil, err
}
return dbs, nil
}
func (m *MongoDB) GetTables(dbName string) ([]string, error) {
if m.client == nil {
return nil, fmt.Errorf("connection not open")
}
targetDB := dbName
if targetDB == "" {
targetDB = m.database
}
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
collections, err := m.client.Database(targetDB).ListCollectionNames(ctx, bson.M{})
if err != nil {
return nil, err
}
return collections, nil
}
func (m *MongoDB) GetCreateStatement(dbName, tableName string) (string, error) {
return fmt.Sprintf("// MongoDB collection: %s.%s\n// MongoDB is schemaless - no CREATE statement available", dbName, tableName), nil
}
// GetColumns returns empty for MongoDB (schemaless)
func (m *MongoDB) GetColumns(dbName, tableName string) ([]connection.ColumnDefinition, error) {
// MongoDB is schemaless, return empty
return []connection.ColumnDefinition{}, nil
}
// GetAllColumns returns empty for MongoDB (schemaless)
func (m *MongoDB) GetAllColumns(dbName string) ([]connection.ColumnDefinitionWithTable, error) {
return []connection.ColumnDefinitionWithTable{}, nil
}
// GetIndexes returns indexes for a MongoDB collection
func (m *MongoDB) GetIndexes(dbName, tableName string) ([]connection.IndexDefinition, error) {
if m.client == nil {
return nil, fmt.Errorf("connection not open")
}
targetDB := dbName
if targetDB == "" {
targetDB = m.database
}
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
collection := m.client.Database(targetDB).Collection(tableName)
cursor, err := collection.Indexes().List(ctx)
if err != nil {
return nil, err
}
defer cursor.Close(ctx)
var indexes []connection.IndexDefinition
for cursor.Next(ctx) {
var idx bson.M
if err := cursor.Decode(&idx); err != nil {
continue
}
name := fmt.Sprintf("%v", idx["name"])
unique := false
if u, ok := idx["unique"].(bool); ok {
unique = u
}
// Extract key fields
if key, ok := idx["key"].(bson.M); ok {
seq := 1
for field := range key {
nonUnique := 1
if unique {
nonUnique = 0
}
indexes = append(indexes, connection.IndexDefinition{
Name: name,
ColumnName: field,
NonUnique: nonUnique,
SeqInIndex: seq,
IndexType: "BTREE",
})
seq++
}
}
}
return indexes, nil
}
func (m *MongoDB) GetForeignKeys(dbName, tableName string) ([]connection.ForeignKeyDefinition, error) {
// MongoDB doesn't have foreign keys
return []connection.ForeignKeyDefinition{}, nil
}
func (m *MongoDB) GetTriggers(dbName, tableName string) ([]connection.TriggerDefinition, error) {
// MongoDB doesn't have triggers in the traditional sense
return []connection.TriggerDefinition{}, nil
}
// ApplyChanges implements batch changes for MongoDB
func (m *MongoDB) ApplyChanges(tableName string, changes connection.ChangeSet) error {
if m.client == nil {
return fmt.Errorf("connection not open")
}
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
collection := m.client.Database(m.database).Collection(tableName)
// Process deletes
for _, pk := range changes.Deletes {
filter := bson.M{}
for k, v := range pk {
filter[k] = v
}
if len(filter) > 0 {
if _, err := collection.DeleteOne(ctx, filter); err != nil {
return fmt.Errorf("delete error: %v", err)
}
}
}
// Process updates
for _, update := range changes.Updates {
filter := bson.M{}
for k, v := range update.Keys {
filter[k] = v
}
if len(filter) == 0 {
return fmt.Errorf("update requires keys")
}
updateDoc := bson.M{"$set": bson.M{}}
for k, v := range update.Values {
updateDoc["$set"].(bson.M)[k] = v
}
if _, err := collection.UpdateOne(ctx, filter, updateDoc); err != nil {
return fmt.Errorf("update error: %v", err)
}
}
// Process inserts
for _, row := range changes.Inserts {
doc := bson.M{}
for k, v := range row {
doc[k] = v
}
if len(doc) > 0 {
if _, err := collection.InsertOne(ctx, doc); err != nil {
return fmt.Errorf("insert error: %v", err)
}
}
}
return nil
}