Files
MyGoNavi/internal/db/mongodb_impl.go
杨国锋 78e35a5be8 perf(data-grid): 重构批量编辑链路并优化表格渲染性能
- 重构批量改单元格的状态流,减少高频交互时的无效重渲染
- 优化大数据量场景下的表格交互流畅度与响应延迟
- 调整单元格编辑细节,增强与 Navicat 编辑习惯的一致性

🔧 fix(sidebar-connection): 修复多数据源切换后旧连接节点无响应问题

- 修复新建并连接新数据源后,旧数据源点击无响应的问题

 feat(tab-manager): 表与设计标签支持环境前缀显示

- 基于连接名识别 DEV/UAT/PROD/SIT/STG/TEST 环境标记
- 仅对 table/design 标签添加环境前缀,查询等标签保持原样
- 无法识别标准环境时回退显示连接名,提升多环境可辨识性

 feat(connection-config): 新增连接URI复制解析并支持MySQL/Mongo主从配置

- 连接弹窗新增 URI 生成、解析、复制能力,支持参数回填
- MySQL 支持多地址主从拓扑、从库地址列表与从库独立凭据
- Mongo 支持多节点配置、replicaSet、authSource、readPreference
- 扩展前后端连接配置模型并同步 Wails 生成类型文件
- 后端接入主从凭据回退策略,保持旧配置兼容

 feat(mongodb-replica): 对齐Navicat主从配置并补齐成员发现能力

- 新增 mongoSrv、mongoAuthMechanism、savePassword 配置项
- 支持 mongodb+srv URI 构建与解析,并透传 authMechanism
- 新增 MongoDiscoverMembers 接口,返回成员与状态信息
- 驱动侧实现 replSetGetStatus -> hello/isMaster 回退发现链路
- 前端弹窗新增 SRV 开关、验证方式、成员发现按钮与状态表
- 增加 SRV+SSH 冲突提示与后端保护,避免无效连接路径

🔧 fix(app-error-text): 修复连接测试错误信息乱码并完善日志提示

- 新增错误文本编码纠正能力,处理混合编码导致的中文乱码
- 连接错误提示统一走 normalizeErrorMessage 输出
- 增加 GB18030 纠正相关单元测试覆盖 PostgreSQL 认证失败场景
- go.mod 显式引入 golang.org/x/text 依赖

 feat(filter-panel): 筛选条件支持启用停用与批量开关

- 筛选条件新增 enabled 状态,支持按条件勾选启用/停用
- 筛选面板新增“全启用”“全停用”快捷操作
- SQL 组装时自动跳过已停用条件,保留条件内容便于复用
- 同步 DataViewer 与 SQL 工具层类型,确保筛选链路一致性

🔧 fix(connection-modal-scroll): 修复连接弹窗滚动行为并去除外层滚动条

- 连接配置步骤设置弹窗 body 最大高度与内部滚动
- 为连接弹窗增加专用 wrapClassName 并禁用外层滚动
- 修复出现双滚动条的问题,确保仅保留弹窗内部滚动条
2026-02-09 21:54:11 +08:00

860 lines
22 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package db
import (
"context"
"fmt"
"net"
"net/url"
"sort"
"strconv"
"strings"
"time"
"GoNavi-Wails/internal/connection"
"GoNavi-Wails/internal/logger"
"GoNavi-Wails/internal/ssh"
"go.mongodb.org/mongo-driver/v2/bson"
"go.mongodb.org/mongo-driver/v2/mongo"
"go.mongodb.org/mongo-driver/v2/mongo/options"
"go.mongodb.org/mongo-driver/v2/mongo/readpref"
)
// MongoDB holds the state of one MongoDB connection: the driver client, the
// default database used for queries and writes, the ping timeout, and the
// optional SSH port forward the client dials through.
type MongoDB struct {
	// client is the driver client. Connect sets it and resets it to nil when
	// a connection attempt fails verification.
	client *mongo.Client
	// database is the default database for Query/ApplyChanges and the
	// fallback for GetTables/GetIndexes; Connect defaults it to "admin".
	database string
	// pingTimeout bounds Ping (and DiscoverMembers); Connect derives it from
	// the configured connect timeout.
	pingTimeout time.Duration
	// forwarder is the SSH local port forward created when UseSSH is set;
	// nil otherwise.
	forwarder *ssh.LocalForwarder
}
// defaultMongoPort is the port assumed when a MongoDB address omits one.
const defaultMongoPort = 27017

// normalizeMongoAddress renders host and port as a "host:port" seed suitable
// for a mongodb:// connection string.
//
// A blank host becomes "localhost" and a non-positive port becomes
// defaultMongoPort. IPv6 literals are bracketed ("[::1]:27017"), as a
// host:port pair requires; a host that already carries brackets is not
// wrapped twice.
func normalizeMongoAddress(host string, port int) string {
	h := strings.TrimSpace(host)
	// Drop existing brackets so net.JoinHostPort can add them exactly once.
	if len(h) >= 2 && h[0] == '[' && h[len(h)-1] == ']' {
		h = h[1 : len(h)-1]
	}
	if h == "" {
		h = "localhost"
	}
	p := port
	if p <= 0 {
		p = defaultMongoPort
	}
	// JoinHostPort brackets hosts containing ':' (IPv6); Sprintf("%s:%d")
	// produced an ambiguous "::1:27017".
	return net.JoinHostPort(h, strconv.Itoa(p))
}
// normalizeMongoSeed turns one user-supplied seed entry into the form used in
// the connection string, reporting false when the entry is unusable.
//
// With useSRV the result is the bare, trimmed host name and any parsed port is
// dropped; otherwise it is a normalized "host:port" pair, with defaultPort
// filling in a missing port.
func normalizeMongoSeed(raw string, defaultPort int, useSRV bool) (string, bool) {
	h, p, parsed := parseHostPortWithDefault(raw, defaultPort)
	switch {
	case !parsed:
		return "", false
	case !useSRV:
		return normalizeMongoAddress(h, p), true
	}
	// SRV seed: keep only the host name.
	if srvHost := strings.TrimSpace(h); srvHost != "" {
		return srvHost, true
	}
	return "", false
}
// collectMongoSeeds returns the de-duplicated, normalized seed list for config.
//
// config.Hosts is used when present; otherwise config.Host (bare in SRV mode,
// or combined with the port) is the single candidate. Entries that fail to
// normalize are skipped, and the first occurrence of each seed wins.
func collectMongoSeeds(config connection.ConnectionConfig) []string {
	port := config.Port
	if port <= 0 {
		port = defaultMongoPort
	}
	srv := config.MongoSRV

	var raw []string
	switch {
	case len(config.Hosts) > 0:
		raw = config.Hosts
	case srv:
		raw = []string{strings.TrimSpace(config.Host)}
	default:
		raw = []string{normalizeMongoAddress(config.Host, port)}
	}

	seeds := make([]string, 0, len(raw))
	dedup := make(map[string]bool, len(raw))
	for _, item := range raw {
		seed, valid := normalizeMongoSeed(item, port, srv)
		if !valid || dedup[seed] {
			continue
		}
		dedup[seed] = true
		seeds = append(seeds, seed)
	}
	return seeds
}
// applyMongoURI back-fills empty fields of config from config.URI when it is a
// mongodb:// or mongodb+srv:// connection string, and returns the updated copy.
//
// Values already present in config always take precedence over the URI, and
// config.URI itself is left untouched. A mongodb+srv:// prefix switches
// MongoSRV on. Parsing is best effort: a URI that url.Parse rejects is ignored.
// When the URI was parsed and no Topology was set, it is derived as "replica"
// (several hosts or a replica-set name) or "single".
func applyMongoURI(config connection.ConnectionConfig) connection.ConnectionConfig {
	uriText := strings.TrimSpace(config.URI)
	if uriText == "" {
		return config
	}
	lowerURI := strings.ToLower(uriText)
	isSRV := strings.HasPrefix(lowerURI, "mongodb+srv://")
	if isSRV {
		config.MongoSRV = true
	}
	if !isSRV && !strings.HasPrefix(lowerURI, "mongodb://") {
		return config
	}
	parsed, err := url.Parse(uriText)
	if err != nil {
		// NOTE: url.Parse validates the text after the last ':' of the host
		// list as a port, so some valid multi-host strings (e.g.
		// "h1:27017,h2") are rejected; the explicit fields are then used as-is.
		return config
	}

	if parsed.User != nil {
		if config.User == "" {
			config.User = parsed.User.Username()
		}
		if pass, ok := parsed.User.Password(); ok && config.Password == "" {
			config.Password = pass
		}
	}
	if dbName := strings.TrimPrefix(parsed.Path, "/"); dbName != "" && config.Database == "" {
		config.Database = dbName
	}

	defaultPort := config.Port
	if defaultPort <= 0 {
		defaultPort = defaultMongoPort
	}
	var hostsFromURI []string
	if hostText := strings.TrimSpace(parsed.Host); hostText != "" {
		for _, entry := range strings.Split(hostText, ",") {
			if normalized, ok := normalizeMongoSeed(entry, defaultPort, config.MongoSRV); ok {
				hostsFromURI = append(hostsFromURI, normalized)
			}
		}
	}
	if len(config.Hosts) == 0 && len(hostsFromURI) > 0 {
		config.Hosts = hostsFromURI
	}
	if strings.TrimSpace(config.Host) == "" && len(hostsFromURI) > 0 {
		if host, port, ok := parseHostPortWithDefault(hostsFromURI[0], defaultPort); ok {
			config.Host = host
			config.Port = port
		}
	}

	// Connection-string option names are case-insensitive, so "replicaset"
	// must be honoured like "replicaSet". Otherwise it is lost whenever the URI
	// is later rebuilt from these fields (SSH tunnel, replica credentials).
	query := parsed.Query()
	if config.AuthSource == "" {
		config.AuthSource = mongoURIOption(query, "authSource")
	}
	if config.ReadPreference == "" {
		config.ReadPreference = mongoURIOption(query, "readPreference")
	}
	if config.ReplicaSet == "" {
		config.ReplicaSet = mongoURIOption(query, "replicaSet")
	}
	if config.MongoAuthMechanism == "" {
		config.MongoAuthMechanism = mongoURIOption(query, "authMechanism")
	}
	if config.Topology == "" {
		if len(config.Hosts) > 1 || strings.TrimSpace(config.ReplicaSet) != "" {
			config.Topology = "replica"
		} else {
			config.Topology = "single"
		}
	}
	return config
}

// mongoURIOption returns the trimmed first value of connection-string option
// name. An exact key match wins. Otherwise the first key (in sorted order, so
// the result is deterministic) that equals name ignoring case is used. It
// returns "" when the option is absent.
func mongoURIOption(query url.Values, name string) string {
	if values := query[name]; len(values) > 0 {
		return strings.TrimSpace(values[0])
	}
	matches := make([]string, 0, 1)
	for key := range query {
		if strings.EqualFold(key, name) {
			matches = append(matches, key)
		}
	}
	sort.Strings(matches)
	for _, key := range matches {
		if values := query[key]; len(values) > 0 {
			return strings.TrimSpace(values[0])
		}
	}
	return ""
}
// getURI returns the connection string used to dial MongoDB.
//
// A non-empty config.URI is returned verbatim (whitespace-trimmed). Otherwise a
// mongodb:// URI (or mongodb+srv:// when MongoSRV is set) is assembled from the
// seeds, credentials, database and options in config. connectTimeoutMS and
// serverSelectionTimeoutMS are always included, derived from the configured
// connect timeout.
func (m *MongoDB) getURI(config connection.ConnectionConfig) string {
	if rawURI := strings.TrimSpace(config.URI); rawURI != "" {
		return rawURI
	}
	seeds := collectMongoSeeds(config)
	if len(seeds) == 0 {
		if config.MongoSRV {
			seed := strings.TrimSpace(config.Host)
			if seed == "" {
				seed = "localhost"
			}
			seeds = append(seeds, seed)
		} else {
			seeds = append(seeds, normalizeMongoAddress(config.Host, config.Port))
		}
	}
	scheme := "mongodb"
	if config.MongoSRV {
		scheme = "mongodb+srv"
	}

	var uri strings.Builder
	uri.WriteString(scheme)
	uri.WriteString("://")
	if config.User != "" {
		// Userinfo.String percent-escapes '@', ':', '/' and '?'. url.PathEscape
		// (used before) left ':' and '@' intact, which broke the URI for
		// credentials containing them.
		credentials := url.User(config.User)
		if config.Password != "" {
			credentials = url.UserPassword(config.User, config.Password)
		}
		uri.WriteString(credentials.String())
		uri.WriteByte('@')
	}
	uri.WriteString(strings.Join(seeds, ","))
	uri.WriteByte('/')
	database := strings.TrimSpace(config.Database)
	if database != "" {
		uri.WriteString(url.PathEscape(database))
	}

	params := url.Values{}
	timeoutMS := strconv.Itoa(getConnectTimeoutSeconds(config) * 1000)
	params.Set("connectTimeoutMS", timeoutMS)
	params.Set("serverSelectionTimeoutMS", timeoutMS)

	authMechanism := strings.TrimSpace(config.MongoAuthMechanism)
	authSource := strings.TrimSpace(config.AuthSource)
	if authSource == "" {
		switch strings.ToUpper(authMechanism) {
		case "MONGODB-X509", "MONGODB-AWS", "MONGODB-OIDC", "GSSAPI", "PLAIN":
			// Per the MongoDB authentication docs these mechanisms default to
			// "$external" (PLAIN: the URI database when one is given). Forcing
			// "admin" here made them fail, so leave the driver default.
		default:
			authSource = database
			if authSource == "" {
				authSource = "admin"
			}
		}
	}
	if authSource != "" {
		params.Set("authSource", authSource)
	}
	if replicaSet := strings.TrimSpace(config.ReplicaSet); replicaSet != "" {
		params.Set("replicaSet", replicaSet)
	}
	if readPreference := strings.TrimSpace(config.ReadPreference); readPreference != "" {
		params.Set("readPreference", readPreference)
	}
	if authMechanism != "" {
		params.Set("authMechanism", authMechanism)
	}
	if encoded := params.Encode(); encoded != "" {
		uri.WriteByte('?')
		uri.WriteString(encoded)
	}
	return uri.String()
}
// buildMongoAuthAttempts lists the credential sets Connect should try, in order.
//
// The result always starts with config itself. When a replica user is
// configured and the replica credentials differ from the primary ones, a copy
// follows that uses MongoReplicaUser / MongoReplicaPassword. In that copy URI
// is cleared so the connection string is rebuilt from the individual fields.
func buildMongoAuthAttempts(config connection.ConnectionConfig) []connection.ConnectionConfig {
	replicaUser := strings.TrimSpace(config.MongoReplicaUser)
	sameAsPrimary := replicaUser == strings.TrimSpace(config.User) &&
		config.MongoReplicaPassword == config.Password
	if replicaUser == "" || sameAsPrimary {
		return []connection.ConnectionConfig{config}
	}

	fallback := config
	fallback.URI = ""
	fallback.User, fallback.Password = replicaUser, config.MongoReplicaPassword
	return []connection.ConnectionConfig{config, fallback}
}
// Connect opens a MongoDB client for config and verifies it with a ping to the
// primary.
//
// Empty fields of config are first back-filled from config.URI. When UseSSH is
// set (SRV mode is rejected in that case), a local port forward to the first
// seed address is created and the client dials the forwarded local address
// instead. The primary credentials are tried first, then the replica
// credentials when they differ (see buildMongoAuthAttempts). The first attempt
// that both connects and pings is kept in m.client. If every attempt fails,
// the per-attempt errors are combined into one error.
func (m *MongoDB) Connect(config connection.ConnectionConfig) error {
	runConfig := applyMongoURI(config)
	connectConfig := runConfig
	if runConfig.UseSSH && runConfig.MongoSRV {
		return fmt.Errorf("MongoDB SRV 记录模式暂不支持 SSH 隧道")
	}
	if runConfig.UseSSH {
		seeds := collectMongoSeeds(runConfig)
		if len(seeds) == 0 {
			seeds = append(seeds, normalizeMongoAddress(runConfig.Host, runConfig.Port))
		}
		// Only the first seed is tunnelled.
		// NOTE(review): with ReplicaSet set, the driver may still try to reach
		// other members at their advertised addresses, which this tunnel does
		// not cover — confirm SSH + replica-set behaviour.
		targetHost, targetPort, ok := parseHostPortWithDefault(seeds[0], defaultMongoPort)
		if !ok {
			return fmt.Errorf("MongoDB 连接失败:无效地址 %s", seeds[0])
		}
		logger.Infof("MongoDB 使用 SSH 连接:地址=%s:%d", targetHost, targetPort)
		forwarder, err := ssh.GetOrCreateLocalForwarder(runConfig.SSH, targetHost, targetPort)
		if err != nil {
			return fmt.Errorf("创建 SSH 隧道失败:%w", err)
		}
		m.forwarder = forwarder
		host, portStr, err := net.SplitHostPort(forwarder.LocalAddr)
		if err != nil {
			return fmt.Errorf("解析本地转发地址失败:%w", err)
		}
		port, err := strconv.Atoi(portStr)
		if err != nil {
			return fmt.Errorf("解析本地端口失败:%w", err)
		}
		localConfig := runConfig
		localConfig.Host = host
		localConfig.Port = port
		localConfig.UseSSH = false
		localConfig.URI = "" // force getURI to rebuild the URI around the local address
		localConfig.Hosts = []string{normalizeMongoAddress(host, port)}
		connectConfig = localConfig
		logger.Infof("MongoDB 通过本地端口转发连接:%s -> %s:%d", forwarder.LocalAddr, targetHost, targetPort)
	}

	m.pingTimeout = getConnectTimeout(connectConfig)
	m.database = connectConfig.Database
	if m.database == "" {
		m.database = "admin"
	}

	attemptConfigs := buildMongoAuthAttempts(connectConfig)
	errorDetails := make([]string, 0, len(attemptConfigs))
	for index, attemptConfig := range attemptConfigs {
		authLabel := "主库凭据"
		if index > 0 {
			authLabel = "从库凭据"
		}
		clientOpts := options.Client().ApplyURI(m.getURI(attemptConfig))
		client, err := mongo.Connect(clientOpts)
		if err != nil {
			errorDetails = append(errorDetails, fmt.Sprintf("%s连接失败: %v", authLabel, err))
			continue
		}
		m.client = client
		if err := m.Ping(); err != nil {
			// Release the unusable client before trying the next credentials.
			ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
			_ = client.Disconnect(ctx) // best effort: the ping error is what gets reported
			cancel()
			m.client = nil
			errorDetails = append(errorDetails, fmt.Sprintf("%s验证失败: %v", authLabel, err))
			continue
		}
		return nil
	}
	if len(errorDetails) > 0 {
		// Separate the per-attempt messages; joining with "" ran them together.
		return fmt.Errorf("MongoDB 连接失败:%s", strings.Join(errorDetails, "; "))
	}
	return fmt.Errorf("MongoDB 连接失败:无可用连接方案")
}
// Close disconnects the MongoDB client and then shuts down the SSH port
// forward, if Connect created one.
//
// Resources are released in reverse order of Connect, so any final traffic the
// driver may send while disconnecting can still use the tunnel. Both fields are
// cleared, so a second Close is a no-op and Ping reports "connection not open".
// A failure to close the forwarder is only logged; the client's Disconnect
// error (if any) is returned.
func (m *MongoDB) Close() error {
	var disconnectErr error
	if m.client != nil {
		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
		disconnectErr = m.client.Disconnect(ctx)
		cancel()
		m.client = nil
	}
	if m.forwarder != nil {
		if err := m.forwarder.Close(); err != nil {
			logger.Warnf("关闭 MongoDB SSH 端口转发失败:%v", err)
		}
		m.forwarder = nil
	}
	return disconnectErr
}
// Ping checks that the primary is reachable. The wait is bounded by the
// timeout Connect recorded, or 5 seconds when none was recorded.
func (m *MongoDB) Ping() error {
	client := m.client
	if client == nil {
		return fmt.Errorf("connection not open")
	}
	wait := 5 * time.Second
	if m.pingTimeout > 0 {
		wait = m.pingTimeout
	}
	ctx, cancel := context.WithTimeout(context.Background(), wait)
	defer cancel()
	return client.Ping(ctx, readpref.Primary())
}
// asMongoStringList converts a decoded BSON array (bson.A) into trimmed,
// non-empty strings, formatting each element with %v. Any input that is not
// a bson.A yields nil.
func asMongoStringList(raw interface{}) []string {
	arr, isArray := raw.(bson.A)
	if !isArray {
		return nil
	}
	out := make([]string, 0, len(arr))
	for i := range arr {
		if s := strings.TrimSpace(fmt.Sprintf("%v", arr[i])); s != "" {
			out = append(out, s)
		}
	}
	return out
}
// asMongoString renders a decoded BSON value as trimmed text. A nil interface
// gives "", strings are used directly, and anything else is formatted with %v.
func asMongoString(raw interface{}) string {
	var text string
	switch v := raw.(type) {
	case nil:
		return ""
	case string:
		text = v
	default:
		text = fmt.Sprintf("%v", v)
	}
	return strings.TrimSpace(text)
}
// asMongoInt converts a decoded numeric BSON value (int, int32, int64,
// float32 or float64) to int, truncating floats toward zero. Every other type,
// including nil, yields 0.
func asMongoInt(raw interface{}) int {
	if v, ok := raw.(int); ok {
		return v
	}
	if v, ok := raw.(int32); ok {
		return int(v)
	}
	if v, ok := raw.(int64); ok {
		return int(v)
	}
	if v, ok := raw.(float32); ok {
		return int(v)
	}
	if v, ok := raw.(float64); ok {
		return int(v)
	}
	return 0
}
// asMongoBool interprets a decoded BSON value as a boolean. Bools are returned
// as-is, and numeric values (int, int32, int64, float32, float64) are true when
// non-zero. Every other type, including nil, is false.
func asMongoBool(raw interface{}) bool {
	if flag, isBool := raw.(bool); isBool {
		return flag
	}
	nonZero := false
	switch n := raw.(type) {
	case int:
		nonZero = n != 0
	case int32:
		nonZero = n != 0
	case int64:
		nonZero = n != 0
	case float32:
		nonZero = n != 0
	case float64:
		nonZero = n != 0
	}
	return nonZero
}
// mongoStateByCode maps a numeric replica-set member state (the "state" field
// reported by replSetGetStatus) to its label. Codes without a documented
// label, including negatives and 4, map to "UNKNOWN".
func mongoStateByCode(code int) string {
	switch code {
	case 0:
		// Previously missing: 0 is STARTUP in MongoDB's member-state table.
		return "STARTUP"
	case 1:
		return "PRIMARY"
	case 2:
		return "SECONDARY"
	case 3:
		return "RECOVERING"
	case 5:
		return "STARTUP2"
	case 6:
		return "UNKNOWN"
	case 7:
		return "ARBITER"
	case 8:
		return "DOWN"
	case 9:
		return "ROLLBACK"
	case 10:
		return "REMOVED"
	default:
		return "UNKNOWN"
	}
}
// normalizeMongoStateLabel prefers the textual state (trimmed, upper-cased).
// When that is blank it falls back to the label for the numeric state code.
func normalizeMongoStateLabel(state string, stateCode int) string {
	if label := strings.ToUpper(strings.TrimSpace(state)); label != "" {
		return label
	}
	return mongoStateByCode(stateCode)
}
// buildMembersFromReplStatus converts the "members" array of a replSetGetStatus
// reply into member records sorted by host.
//
// Entries that are not documents, or that have no "name", are skipped. The
// state label comes from "stateStr" and falls back to the numeric "state". A
// member counts as healthy when "health" is a positive number or true. nil is
// returned when "members" is missing or is not an array.
func buildMembersFromReplStatus(raw bson.M) []connection.MongoMemberInfo {
	rawMembers, isArray := raw["members"].(bson.A)
	if !isArray {
		return nil
	}
	result := make([]connection.MongoMemberInfo, 0, len(rawMembers))
	for _, item := range rawMembers {
		doc, isDoc := item.(bson.M)
		if !isDoc {
			continue
		}
		name := asMongoString(doc["name"])
		if name == "" {
			continue
		}
		code := asMongoInt(doc["state"])
		label := normalizeMongoStateLabel(asMongoString(doc["stateStr"]), code)
		health := doc["health"]
		result = append(result, connection.MongoMemberInfo{
			Host:      name,
			Role:      label,
			State:     label,
			StateCode: code,
			Healthy:   asMongoInt(health) > 0 || asMongoBool(health),
			IsSelf:    asMongoBool(doc["self"]),
		})
	}
	sort.Slice(result, func(a, b int) bool { return result[a].Host < result[b].Host })
	return result
}
// buildMembersFromHello lists replica-set members from a hello/isMaster reply,
// sorted by host. It returns nil when the reply lists no members.
//
// Per the MongoDB hello documentation, "hosts" holds only members that are
// neither passive, hidden nor arbiters. Priority-0 members are reported
// separately in "passives" and arbiters in "arbiters". The original code walked
// "hosts" only, so its passive/arbiter branches could never match and those
// members were silently dropped; all three lists are now merged. hello reports
// no per-member health or state, so every member is marked healthy and plain
// data-bearing members are assumed SECONDARY.
func buildMembersFromHello(raw bson.M) []connection.MongoMemberInfo {
	hosts := asMongoStringList(raw["hosts"])
	passives := asMongoStringList(raw["passives"])
	arbiters := asMongoStringList(raw["arbiters"])
	primary := asMongoString(raw["primary"])
	selfHost := asMongoString(raw["me"])

	passiveSet := make(map[string]struct{}, len(passives))
	for _, host := range passives {
		passiveSet[host] = struct{}{}
	}
	arbiterSet := make(map[string]struct{}, len(arbiters))
	for _, host := range arbiters {
		arbiterSet[host] = struct{}{}
	}

	allHosts := make([]string, 0, len(hosts)+len(passives)+len(arbiters))
	seen := make(map[string]struct{}, cap(allHosts))
	for _, list := range [][]string{hosts, passives, arbiters} {
		for _, host := range list {
			if _, dup := seen[host]; dup {
				continue
			}
			seen[host] = struct{}{}
			allHosts = append(allHosts, host)
		}
	}
	if len(allHosts) == 0 {
		return nil
	}

	members := make([]connection.MongoMemberInfo, 0, len(allHosts))
	for _, host := range allHosts {
		state, stateCode := "SECONDARY", 2
		if host == primary {
			state, stateCode = "PRIMARY", 1
		} else if _, ok := arbiterSet[host]; ok {
			state, stateCode = "ARBITER", 7
		} else if _, ok := passiveSet[host]; ok {
			// hello gives no state for passives; 6 is the UNKNOWN state code.
			state, stateCode = "PASSIVE", 6
		}
		members = append(members, connection.MongoMemberInfo{
			Host:      host,
			Role:      state,
			State:     state,
			StateCode: stateCode,
			Healthy:   true,
			IsSelf:    host == selfHost,
		})
	}
	sort.Slice(members, func(i, j int) bool {
		return members[i].Host < members[j].Host
	})
	return members
}
// DiscoverMembers reports the replica-set name and the members visible from
// the current connection.
//
// replSetGetStatus is tried first because it carries per-member state and
// health. When it fails or lists no members, the topology advertised by hello
// is used instead, falling back to the legacy isMaster command when hello
// fails. All commands share one deadline: the recorded connect timeout, or
// 10 seconds when none was recorded.
func (m *MongoDB) DiscoverMembers() (string, []connection.MongoMemberInfo, error) {
	if m.client == nil {
		return "", nil, fmt.Errorf("connection not open")
	}
	timeout := m.pingTimeout
	if timeout <= 0 {
		timeout = 10 * time.Second
	}
	ctx, cancel := context.WithTimeout(context.Background(), timeout)
	defer cancel()
	adminDB := m.client.Database("admin")

	var replStatus bson.M
	replErr := adminDB.RunCommand(ctx, bson.D{{Key: "replSetGetStatus", Value: 1}}).Decode(&replStatus)
	if replErr == nil {
		if members := buildMembersFromReplStatus(replStatus); len(members) > 0 {
			return asMongoString(replStatus["set"]), members, nil
		}
	}

	var helloResult bson.M
	if helloErr := adminDB.RunCommand(ctx, bson.D{{Key: "hello", Value: 1}}).Decode(&helloResult); helloErr != nil {
		if err := adminDB.RunCommand(ctx, bson.D{{Key: "isMaster", Value: 1}}).Decode(&helloResult); err != nil {
			// Report every failed step with separators. Previously the hello
			// error was dropped and the isMaster error was labelled "hello".
			if replErr != nil {
				return "", nil, fmt.Errorf("成员发现失败:replSetGetStatus=%v; hello=%v; isMaster=%w", replErr, helloErr, err)
			}
			return "", nil, fmt.Errorf("成员发现失败:hello=%v; isMaster=%w", helloErr, err)
		}
	}

	replicaSet := asMongoString(helloResult["setName"])
	members := buildMembersFromHello(helloResult)
	if len(members) == 0 {
		if replErr != nil {
			return replicaSet, nil, fmt.Errorf("未获取到成员信息:replSetGetStatus=%v", replErr)
		}
		return replicaSet, nil, fmt.Errorf("未获取到成员信息")
	}
	return replicaSet, members, nil
}
// Query runs a MongoDB database command written as an Extended JSON document,
// for example {"find": "collection", "filter": {}}, against the connection's
// default database. A fixed 30-second timeout applies.
func (m *MongoDB) Query(query string) ([]map[string]interface{}, []string, error) {
	const queryTimeout = 30 * time.Second
	ctx, cancel := context.WithTimeout(context.Background(), queryTimeout)
	defer cancel()
	rows, columns, err := m.queryWithContext(ctx, query)
	return rows, columns, err
}
// QueryContext behaves like Query, but the caller controls cancellation and
// deadline through ctx instead of the built-in 30-second timeout.
func (m *MongoDB) QueryContext(ctx context.Context, query string) ([]map[string]interface{}, []string, error) {
	rows, columns, err := m.queryWithContext(ctx, query)
	return rows, columns, err
}
// queryWithContext parses query as an Extended JSON command document and runs
// it against the connection's default database using ctx.
//
// When the reply contains cursor.firstBatch, each document in that batch
// becomes one row. The columns are the union of the documents' top-level
// field names, ordered "_id" first and then alphabetically. Any other reply is
// returned as a single row whose one "result" column holds the raw reply.
//
// NOTE: only the first batch is returned; the server-side cursor is neither
// iterated further (getMore) nor closed here.
func (m *MongoDB) queryWithContext(ctx context.Context, query string) ([]map[string]interface{}, []string, error) {
	if m.client == nil {
		return nil, nil, fmt.Errorf("connection not open")
	}
	query = strings.TrimSpace(query)
	if query == "" {
		return nil, nil, fmt.Errorf("empty query")
	}
	var cmd bson.D
	if err := bson.UnmarshalExtJSON([]byte(query), true, &cmd); err != nil {
		return nil, nil, fmt.Errorf("invalid JSON command: %w", err)
	}
	var result bson.M
	if err := m.client.Database(m.database).RunCommand(ctx, cmd).Decode(&result); err != nil {
		return nil, nil, err
	}

	data := []map[string]interface{}{{"result": result}}
	columns := []string{"result"}
	cursor, ok := result["cursor"].(bson.M)
	if !ok {
		return data, columns, nil
	}
	batch, ok := cursor["firstBatch"].(bson.A)
	if !ok {
		return data, columns, nil
	}

	data = make([]map[string]interface{}, 0, len(batch))
	columnSet := make(map[string]struct{})
	for _, doc := range batch {
		docMap, ok := doc.(bson.M)
		if !ok {
			continue
		}
		row := make(map[string]interface{}, len(docMap))
		for k, v := range docMap {
			row[k] = v
			columnSet[k] = struct{}{}
		}
		data = append(data, row)
	}
	columns = make([]string, 0, len(columnSet))
	for k := range columnSet {
		columns = append(columns, k)
	}
	// Map iteration order is random, so the column order used to change from
	// run to run. Sort for a stable grid layout, keeping the _id key first.
	sort.Slice(columns, func(i, j int) bool {
		iIsID, jIsID := columns[i] == "_id", columns[j] == "_id"
		if iIsID != jIsID {
			return iIsID
		}
		return columns[i] < columns[j]
	})
	return data, columns, nil
}
// Exec runs query through Query (30-second timeout) and returns 1 on success.
// The command reply is discarded, so the count is a success marker rather than
// the number of documents affected.
func (m *MongoDB) Exec(query string) (int64, error) {
	if _, _, err := m.Query(query); err != nil {
		return 0, err
	}
	return 1, nil
}
// ExecContext behaves like Exec, but ctx controls cancellation and deadline.
// It returns 1 on success; the command reply is not inspected.
func (m *MongoDB) ExecContext(ctx context.Context, query string) (int64, error) {
	if _, _, err := m.QueryContext(ctx, query); err != nil {
		return 0, err
	}
	return 1, nil
}
// GetDatabases lists the names of all databases visible to the connection,
// using a 10-second timeout.
func (m *MongoDB) GetDatabases() ([]string, error) {
	client := m.client
	if client == nil {
		return nil, fmt.Errorf("connection not open")
	}
	listCtx, release := context.WithTimeout(context.Background(), 10*time.Second)
	defer release()
	names, listErr := client.ListDatabaseNames(listCtx, bson.M{})
	if listErr != nil {
		return nil, listErr
	}
	return names, nil
}
// GetTables lists the collection names of dbName, or of the connection's
// default database when dbName is empty, using a 10-second timeout.
func (m *MongoDB) GetTables(dbName string) ([]string, error) {
	client := m.client
	if client == nil {
		return nil, fmt.Errorf("connection not open")
	}
	database := m.database
	if dbName != "" {
		database = dbName
	}
	listCtx, release := context.WithTimeout(context.Background(), 10*time.Second)
	defer release()
	names, listErr := client.Database(database).ListCollectionNames(listCtx, bson.M{})
	if listErr != nil {
		return nil, listErr
	}
	return names, nil
}
// GetCreateStatement returns an explanatory comment in place of DDL, since a
// MongoDB collection has no CREATE statement. It never fails.
func (m *MongoDB) GetCreateStatement(dbName, tableName string) (string, error) {
	const stub = "// MongoDB collection: %s.%s\n// MongoDB is schemaless - no CREATE statement available"
	return fmt.Sprintf(stub, dbName, tableName), nil
}
// GetColumns reports no column definitions because MongoDB collections are
// schemaless. The result is an empty, non-nil slice.
func (m *MongoDB) GetColumns(dbName, tableName string) ([]connection.ColumnDefinition, error) {
	return make([]connection.ColumnDefinition, 0), nil
}
// GetAllColumns reports no column definitions for any collection in dbName,
// since MongoDB is schemaless. The result is an empty, non-nil slice.
func (m *MongoDB) GetAllColumns(dbName string) ([]connection.ColumnDefinitionWithTable, error) {
	return make([]connection.ColumnDefinitionWithTable, 0), nil
}
// GetIndexes lists the indexes of collection tableName in dbName (or in the
// connection's default database when dbName is empty). Each index is
// flattened to one IndexDefinition per indexed field.
//
// Fields of a compound index keep the order of the index key document, and
// SeqInIndex reflects that order (1-based). Index specs that fail to decode
// are skipped. An error raised while iterating the cursor is returned.
func (m *MongoDB) GetIndexes(dbName, tableName string) ([]connection.IndexDefinition, error) {
	if m.client == nil {
		return nil, fmt.Errorf("connection not open")
	}
	targetDB := dbName
	if targetDB == "" {
		targetDB = m.database
	}
	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()
	collection := m.client.Database(targetDB).Collection(tableName)
	cursor, err := collection.Indexes().List(ctx)
	if err != nil {
		return nil, err
	}
	defer cursor.Close(ctx)

	var indexes []connection.IndexDefinition
	for cursor.Next(ctx) {
		// Decode into an ordered bson.D: a bson.M loses the field order of a
		// compound index key, which made SeqInIndex arbitrary.
		var spec bson.D
		if err := cursor.Decode(&spec); err != nil {
			continue
		}
		var (
			nameValue interface{}
			unique    bool
			keyFields bson.D
		)
		for _, elem := range spec {
			switch elem.Key {
			case "name":
				nameValue = elem.Value
			case "unique":
				if u, ok := elem.Value.(bool); ok {
					unique = u
				}
			case "key":
				switch key := elem.Value.(type) {
				case bson.D:
					keyFields = key
				case bson.M:
					// Unordered fallback: sort field names so output is at least stable.
					names := make([]string, 0, len(key))
					for field := range key {
						names = append(names, field)
					}
					sort.Strings(names)
					for _, field := range names {
						keyFields = append(keyFields, bson.E{Key: field, Value: key[field]})
					}
				}
			}
		}
		name := fmt.Sprintf("%v", nameValue)
		nonUnique := 1
		if unique {
			nonUnique = 0
		}
		for i, field := range keyFields {
			indexes = append(indexes, connection.IndexDefinition{
				Name:       name,
				ColumnName: field.Key,
				NonUnique:  nonUnique,
				SeqInIndex: i + 1,
				IndexType:  "BTREE",
			})
		}
	}
	if err := cursor.Err(); err != nil {
		return nil, err
	}
	return indexes, nil
}
// GetForeignKeys reports no foreign keys, as MongoDB has none. The result is
// an empty, non-nil slice.
func (m *MongoDB) GetForeignKeys(dbName, tableName string) ([]connection.ForeignKeyDefinition, error) {
	return make([]connection.ForeignKeyDefinition, 0), nil
}
// GetTriggers reports no triggers, since MongoDB has no table triggers in the
// SQL sense. The result is an empty, non-nil slice.
func (m *MongoDB) GetTriggers(dbName, tableName string) ([]connection.TriggerDefinition, error) {
	return make([]connection.TriggerDefinition, 0), nil
}
// ApplyChanges applies a ChangeSet to collection tableName in the connection's
// default database, processing deletes, then updates, then inserts.
//
// Each change is a separate single-document operation (DeleteOne, UpdateOne
// with $set, InsertOne), and all of them share one 30-second deadline. The
// operations are not transactional: the first failure stops processing and
// leaves earlier changes applied. Driver errors are wrapped with %w so callers
// can inspect them with errors.Is / errors.As.
func (m *MongoDB) ApplyChanges(tableName string, changes connection.ChangeSet) error {
	if m.client == nil {
		return fmt.Errorf("connection not open")
	}
	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel()
	collection := m.client.Database(m.database).Collection(tableName)

	// Deletes: a row with no key fields is skipped rather than sending an
	// empty filter, which would match an arbitrary document.
	for _, pk := range changes.Deletes {
		filter := bson.M{}
		for k, v := range pk {
			filter[k] = v
		}
		if len(filter) == 0 {
			continue
		}
		if _, err := collection.DeleteOne(ctx, filter); err != nil {
			return fmt.Errorf("delete error: %w", err)
		}
	}

	// Updates: keys select the document, values are applied with $set.
	for _, update := range changes.Updates {
		filter := bson.M{}
		for k, v := range update.Keys {
			filter[k] = v
		}
		if len(filter) == 0 {
			return fmt.Errorf("update requires keys")
		}
		setFields := bson.M{}
		for k, v := range update.Values {
			setFields[k] = v
		}
		if len(setFields) == 0 {
			// Nothing to change. Servers before MongoDB 5.0 reject an empty
			// $set, so skip the round trip.
			continue
		}
		if _, err := collection.UpdateOne(ctx, filter, bson.M{"$set": setFields}); err != nil {
			return fmt.Errorf("update error: %w", err)
		}
	}

	// Inserts: empty rows are skipped.
	for _, row := range changes.Inserts {
		doc := bson.M{}
		for k, v := range row {
			doc[k] = v
		}
		if len(doc) == 0 {
			continue
		}
		if _, err := collection.InsertOne(ctx, doc); err != nil {
			return fmt.Errorf("insert error: %w", err)
		}
	}
	return nil
}