mirror of
https://github.com/Syngnat/GoNavi.git
synced 2026-06-06 22:49:35 +08:00
✨feat(drivers): 支持按需启动数据源并通过外置驱动代理减少发行包体积
- MySQL/Redis/Oracle/PostgreSQL 内置可用,其余数据源改为“安装启用”后可用 - 新建连接对未安装驱动做弹窗内拦截提示,并支持一键跳转驱动管理安装 - 驱动管理展示安装包真实大小(从 Release 资产元数据读取)并优化加载性能 - Release 工作流发布各平台驱动代理资产,主程序构建启用 -s -w 精简
This commit is contained in:
@@ -198,6 +198,20 @@ func (a *App) getDatabaseWithPing(config connection.ConnectionConfig, forcePing
|
||||
shortKey = shortKey[:12]
|
||||
}
|
||||
|
||||
if supported, reason := db.DriverRuntimeSupportStatus(config.Type); !supported {
|
||||
if strings.TrimSpace(reason) == "" {
|
||||
reason = fmt.Sprintf("%s 驱动未启用,请先在驱动管理中安装启用", strings.TrimSpace(config.Type))
|
||||
}
|
||||
// Best-effort cleanup: if cached instance exists for this exact config, close it.
|
||||
a.mu.Lock()
|
||||
if cur, exists := a.dbCache[key]; exists && cur.inst != nil {
|
||||
_ = cur.inst.Close()
|
||||
delete(a.dbCache, key)
|
||||
}
|
||||
a.mu.Unlock()
|
||||
return nil, withLogHint{err: fmt.Errorf("%s", reason), logPath: logger.Path()}
|
||||
}
|
||||
|
||||
a.mu.RLock()
|
||||
entry, ok := a.dbCache[key]
|
||||
a.mu.RUnlock()
|
||||
|
||||
1657
internal/app/methods_driver.go
Normal file
1657
internal/app/methods_driver.go
Normal file
File diff suppressed because it is too large
Load Diff
@@ -1,3 +1,5 @@
|
||||
//go:build gonavi_full_drivers || gonavi_dameng_driver
|
||||
|
||||
package db
|
||||
|
||||
import (
|
||||
|
||||
@@ -3,6 +3,7 @@ package db
|
||||
import (
|
||||
"GoNavi-Wails/internal/connection"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type Database interface {
|
||||
@@ -25,46 +26,61 @@ type BatchApplier interface {
|
||||
ApplyChanges(tableName string, changes connection.ChangeSet) error
|
||||
}
|
||||
|
||||
// Factory
|
||||
func NewDatabase(dbType string) (Database, error) {
|
||||
switch dbType {
|
||||
case "mysql":
|
||||
return &MySQLDB{}, nil
|
||||
case "postgres":
|
||||
return &PostgresDB{}, nil
|
||||
case "sqlite":
|
||||
return &SQLiteDB{}, nil
|
||||
case "oracle":
|
||||
return &OracleDB{}, nil
|
||||
case "dameng":
|
||||
return &DamengDB{}, nil
|
||||
case "kingbase":
|
||||
return &KingbaseDB{}, nil
|
||||
case "mongodb":
|
||||
return &MongoDB{}, nil
|
||||
case "sqlserver":
|
||||
return &SqlServerDB{}, nil
|
||||
case "highgo":
|
||||
return &HighGoDB{}, nil
|
||||
case "mariadb":
|
||||
return &MariaDB{}, nil
|
||||
case "diros", "doris":
|
||||
return &DirosDB{}, nil
|
||||
case "sphinx":
|
||||
return &SphinxDB{}, nil
|
||||
case "vastbase":
|
||||
return &VastbaseDB{}, nil
|
||||
case "tdengine":
|
||||
return &TDengineDB{}, nil
|
||||
case "duckdb":
|
||||
return &DuckDB{}, nil
|
||||
case "custom":
|
||||
return &CustomDB{}, nil
|
||||
default:
|
||||
// Default to MySQL for backward compatibility if empty
|
||||
if dbType == "" {
|
||||
return &MySQLDB{}, nil
|
||||
type databaseFactory func() Database
|
||||
|
||||
var databaseFactories = map[string]databaseFactory{
|
||||
"mysql": func() Database {
|
||||
return &MySQLDB{}
|
||||
},
|
||||
"postgres": func() Database {
|
||||
return &PostgresDB{}
|
||||
},
|
||||
"oracle": func() Database {
|
||||
return &OracleDB{}
|
||||
},
|
||||
"custom": func() Database {
|
||||
return &CustomDB{}
|
||||
},
|
||||
}
|
||||
|
||||
func init() {
|
||||
registerOptionalDatabaseFactories()
|
||||
}
|
||||
|
||||
func registerDatabaseFactory(factory databaseFactory, dbTypes ...string) {
|
||||
if factory == nil || len(dbTypes) == 0 {
|
||||
return
|
||||
}
|
||||
for _, dbType := range dbTypes {
|
||||
normalized := normalizeDatabaseType(dbType)
|
||||
if normalized == "" {
|
||||
continue
|
||||
}
|
||||
return nil, fmt.Errorf("unsupported database type: %s", dbType)
|
||||
databaseFactories[normalized] = factory
|
||||
}
|
||||
}
|
||||
|
||||
func normalizeDatabaseType(dbType string) string {
|
||||
normalized := strings.ToLower(strings.TrimSpace(dbType))
|
||||
switch normalized {
|
||||
case "doris":
|
||||
return "diros"
|
||||
case "postgresql":
|
||||
return "postgres"
|
||||
default:
|
||||
return normalized
|
||||
}
|
||||
}
|
||||
|
||||
// Factory
|
||||
func NewDatabase(dbType string) (Database, error) {
|
||||
normalized := normalizeDatabaseType(dbType)
|
||||
if normalized == "" {
|
||||
normalized = "mysql"
|
||||
}
|
||||
factory, ok := databaseFactories[normalized]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unsupported database type: %s", dbType)
|
||||
}
|
||||
return factory(), nil
|
||||
}
|
||||
|
||||
18
internal/db/database_optional_factories_full.go
Normal file
18
internal/db/database_optional_factories_full.go
Normal file
@@ -0,0 +1,18 @@
|
||||
//go:build gonavi_full_drivers
|
||||
|
||||
package db
|
||||
|
||||
func registerOptionalDatabaseFactories() {
|
||||
registerDatabaseFactory(newOptionalDriverAgentDatabase("mariadb"), "mariadb")
|
||||
registerDatabaseFactory(newOptionalDriverAgentDatabase("diros"), "diros", "doris")
|
||||
registerDatabaseFactory(newOptionalDriverAgentDatabase("sphinx"), "sphinx")
|
||||
registerDatabaseFactory(newOptionalDriverAgentDatabase("sqlserver"), "sqlserver")
|
||||
registerDatabaseFactory(newOptionalDriverAgentDatabase("sqlite"), "sqlite")
|
||||
registerDatabaseFactory(newOptionalDriverAgentDatabase("duckdb"), "duckdb")
|
||||
registerDatabaseFactory(newOptionalDriverAgentDatabase("dameng"), "dameng")
|
||||
registerDatabaseFactory(newOptionalDriverAgentDatabase("kingbase"), "kingbase")
|
||||
registerDatabaseFactory(newOptionalDriverAgentDatabase("highgo"), "highgo")
|
||||
registerDatabaseFactory(newOptionalDriverAgentDatabase("vastbase"), "vastbase")
|
||||
registerDatabaseFactory(newOptionalDriverAgentDatabase("mongodb"), "mongodb")
|
||||
registerDatabaseFactory(newOptionalDriverAgentDatabase("tdengine"), "tdengine")
|
||||
}
|
||||
18
internal/db/database_optional_factories_lite.go
Normal file
18
internal/db/database_optional_factories_lite.go
Normal file
@@ -0,0 +1,18 @@
|
||||
//go:build !gonavi_full_drivers
|
||||
|
||||
package db
|
||||
|
||||
func registerOptionalDatabaseFactories() {
|
||||
registerDatabaseFactory(newOptionalDriverAgentDatabase("mariadb"), "mariadb")
|
||||
registerDatabaseFactory(newOptionalDriverAgentDatabase("diros"), "diros", "doris")
|
||||
registerDatabaseFactory(newOptionalDriverAgentDatabase("sphinx"), "sphinx")
|
||||
registerDatabaseFactory(newOptionalDriverAgentDatabase("sqlserver"), "sqlserver")
|
||||
registerDatabaseFactory(newOptionalDriverAgentDatabase("sqlite"), "sqlite")
|
||||
registerDatabaseFactory(newOptionalDriverAgentDatabase("duckdb"), "duckdb")
|
||||
registerDatabaseFactory(newOptionalDriverAgentDatabase("dameng"), "dameng")
|
||||
registerDatabaseFactory(newOptionalDriverAgentDatabase("kingbase"), "kingbase")
|
||||
registerDatabaseFactory(newOptionalDriverAgentDatabase("highgo"), "highgo")
|
||||
registerDatabaseFactory(newOptionalDriverAgentDatabase("vastbase"), "vastbase")
|
||||
registerDatabaseFactory(newOptionalDriverAgentDatabase("mongodb"), "mongodb")
|
||||
registerDatabaseFactory(newOptionalDriverAgentDatabase("tdengine"), "tdengine")
|
||||
}
|
||||
@@ -1,3 +1,5 @@
|
||||
//go:build gonavi_full_drivers || gonavi_diros_driver
|
||||
|
||||
package db
|
||||
|
||||
import (
|
||||
|
||||
222
internal/db/driver_support.go
Normal file
222
internal/db/driver_support.go
Normal file
@@ -0,0 +1,222 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
var coreBuiltinDrivers = map[string]struct{}{
|
||||
"mysql": {},
|
||||
"redis": {},
|
||||
"oracle": {},
|
||||
"postgres": {},
|
||||
}
|
||||
|
||||
// optionalGoDrivers 表示需要用户“安装启用”后才能使用的纯 Go 驱动。
|
||||
// 注意:这是一种运行时门控(installed.json 标记),并不减少主二进制体积。
|
||||
var optionalGoDrivers = map[string]struct{}{
|
||||
"mariadb": {},
|
||||
"diros": {},
|
||||
"sphinx": {},
|
||||
"sqlserver": {},
|
||||
"sqlite": {},
|
||||
"duckdb": {},
|
||||
"dameng": {},
|
||||
"kingbase": {},
|
||||
"highgo": {},
|
||||
"vastbase": {},
|
||||
"mongodb": {},
|
||||
"tdengine": {},
|
||||
}
|
||||
|
||||
var (
|
||||
externalDriverDirMu sync.RWMutex
|
||||
externalDriverDir string
|
||||
)
|
||||
|
||||
func normalizeRuntimeDriverType(driverType string) string {
|
||||
normalized := strings.ToLower(strings.TrimSpace(driverType))
|
||||
switch normalized {
|
||||
case "doris":
|
||||
return "diros"
|
||||
case "postgresql":
|
||||
return "postgres"
|
||||
default:
|
||||
return normalized
|
||||
}
|
||||
}
|
||||
|
||||
func driverDisplayName(driverType string) string {
|
||||
switch normalizeRuntimeDriverType(driverType) {
|
||||
case "mysql":
|
||||
return "MySQL"
|
||||
case "oracle":
|
||||
return "Oracle"
|
||||
case "redis":
|
||||
return "Redis"
|
||||
case "mariadb":
|
||||
return "MariaDB"
|
||||
case "diros":
|
||||
return "Diros"
|
||||
case "sphinx":
|
||||
return "Sphinx"
|
||||
case "postgres":
|
||||
return "PostgreSQL"
|
||||
case "sqlserver":
|
||||
return "SQL Server"
|
||||
case "sqlite":
|
||||
return "SQLite"
|
||||
case "duckdb":
|
||||
return "DuckDB"
|
||||
case "dameng":
|
||||
return "Dameng"
|
||||
case "kingbase":
|
||||
return "Kingbase"
|
||||
case "highgo":
|
||||
return "HighGo"
|
||||
case "vastbase":
|
||||
return "Vastbase"
|
||||
case "mongodb":
|
||||
return "MongoDB"
|
||||
case "tdengine":
|
||||
return "TDengine"
|
||||
default:
|
||||
return strings.ToUpper(strings.TrimSpace(driverType))
|
||||
}
|
||||
}
|
||||
|
||||
func IsOptionalGoDriver(driverType string) bool {
|
||||
_, ok := optionalGoDrivers[normalizeRuntimeDriverType(driverType)]
|
||||
return ok
|
||||
}
|
||||
|
||||
func IsOptionalGoDriverBuildIncluded(driverType string) bool {
|
||||
return optionalGoDriverBuildIncluded(normalizeRuntimeDriverType(driverType))
|
||||
}
|
||||
|
||||
func IsBuiltinDriver(driverType string) bool {
|
||||
_, ok := coreBuiltinDrivers[normalizeRuntimeDriverType(driverType)]
|
||||
return ok
|
||||
}
|
||||
|
||||
func defaultExternalDriverDownloadDirectory() string {
|
||||
if home, err := os.UserHomeDir(); err == nil && strings.TrimSpace(home) != "" {
|
||||
return filepath.Join(home, ".gonavi", "drivers")
|
||||
}
|
||||
if wd, err := os.Getwd(); err == nil && strings.TrimSpace(wd) != "" {
|
||||
return filepath.Join(wd, ".gonavi-drivers")
|
||||
}
|
||||
return ".gonavi-drivers"
|
||||
}
|
||||
|
||||
func resolveExternalDriverRoot(downloadDir string) (string, error) {
|
||||
root := strings.TrimSpace(downloadDir)
|
||||
if root == "" {
|
||||
root = currentExternalDriverDownloadDirectory()
|
||||
}
|
||||
if root == "" {
|
||||
root = defaultExternalDriverDownloadDirectory()
|
||||
}
|
||||
if !filepath.IsAbs(root) {
|
||||
abs, err := filepath.Abs(root)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
root = abs
|
||||
}
|
||||
if err := os.MkdirAll(root, 0o755); err != nil {
|
||||
return "", fmt.Errorf("创建驱动目录失败:%w", err)
|
||||
}
|
||||
return root, nil
|
||||
}
|
||||
|
||||
func currentExternalDriverDownloadDirectory() string {
|
||||
externalDriverDirMu.RLock()
|
||||
current := strings.TrimSpace(externalDriverDir)
|
||||
externalDriverDirMu.RUnlock()
|
||||
if current != "" {
|
||||
return current
|
||||
}
|
||||
return defaultExternalDriverDownloadDirectory()
|
||||
}
|
||||
|
||||
func SetExternalDriverDownloadDirectory(downloadDir string) {
|
||||
root, err := resolveExternalDriverRoot(downloadDir)
|
||||
if err != nil {
|
||||
root = defaultExternalDriverDownloadDirectory()
|
||||
}
|
||||
externalDriverDirMu.Lock()
|
||||
externalDriverDir = root
|
||||
externalDriverDirMu.Unlock()
|
||||
}
|
||||
|
||||
func ResolveExternalDriverRoot(downloadDir string) (string, error) {
|
||||
return resolveExternalDriverRoot(downloadDir)
|
||||
}
|
||||
|
||||
func ResolveOptionalGoDriverMarkerPath(downloadDir string, driverType string) (string, error) {
|
||||
normalized := normalizeRuntimeDriverType(driverType)
|
||||
if !IsOptionalGoDriver(normalized) {
|
||||
return "", fmt.Errorf("%s 不是可选 Go 驱动", driverDisplayName(normalized))
|
||||
}
|
||||
root, err := resolveExternalDriverRoot(downloadDir)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return filepath.Join(root, normalized, "installed.json"), nil
|
||||
}
|
||||
|
||||
func optionalGoDriverInstalled(driverType string) bool {
|
||||
markerPath, err := ResolveOptionalGoDriverMarkerPath("", driverType)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
info, statErr := os.Stat(markerPath)
|
||||
return statErr == nil && !info.IsDir()
|
||||
}
|
||||
|
||||
func optionalGoDriverRuntimeReady(driverType string) (bool, string) {
|
||||
normalized := normalizeRuntimeDriverType(driverType)
|
||||
if !IsOptionalGoDriver(normalized) {
|
||||
return true, ""
|
||||
}
|
||||
executablePath, err := ResolveOptionalDriverAgentExecutablePath("", normalized)
|
||||
if err != nil {
|
||||
return false, fmt.Sprintf("%s 驱动代理路径解析失败,请在驱动管理中重新安装启用", driverDisplayName(normalized))
|
||||
}
|
||||
info, statErr := os.Stat(executablePath)
|
||||
if statErr != nil || info.IsDir() {
|
||||
return false, fmt.Sprintf("%s 驱动代理缺失,请在驱动管理中重新安装启用", driverDisplayName(normalized))
|
||||
}
|
||||
return true, ""
|
||||
}
|
||||
|
||||
// DriverRuntimeSupportStatus 返回当前构建下驱动是否可用(可直接用于连接)。
|
||||
func DriverRuntimeSupportStatus(driverType string) (bool, string) {
|
||||
normalized := normalizeRuntimeDriverType(driverType)
|
||||
if normalized == "" {
|
||||
return false, "未识别的数据源类型"
|
||||
}
|
||||
if normalized == "custom" {
|
||||
return true, ""
|
||||
}
|
||||
if IsBuiltinDriver(normalized) {
|
||||
return true, ""
|
||||
}
|
||||
if IsOptionalGoDriver(normalized) {
|
||||
if !IsOptionalGoDriverBuildIncluded(normalized) {
|
||||
return false, fmt.Sprintf("%s 当前发行包为精简构建,未内置该驱动;如需使用请安装 Full 版", driverDisplayName(normalized))
|
||||
}
|
||||
if optionalGoDriverInstalled(normalized) {
|
||||
if ready, reason := optionalGoDriverRuntimeReady(normalized); !ready {
|
||||
return false, reason
|
||||
}
|
||||
return true, ""
|
||||
}
|
||||
return false, fmt.Sprintf("%s 纯 Go 驱动未启用,请先在驱动管理中点击“安装启用”", driverDisplayName(normalized))
|
||||
}
|
||||
return true, ""
|
||||
}
|
||||
89
internal/db/driver_support_test.go
Normal file
89
internal/db/driver_support_test.go
Normal file
@@ -0,0 +1,89 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestPostgresRuntimeSupportRequiresInstallMarker(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
SetExternalDriverDownloadDirectory(tmpDir)
|
||||
|
||||
supported, _ := DriverRuntimeSupportStatus("postgres")
|
||||
if !supported {
|
||||
t.Fatalf("postgres 属于免安装内置驱动,应可用")
|
||||
}
|
||||
supported, reason := DriverRuntimeSupportStatus("postgres")
|
||||
if !supported {
|
||||
t.Fatalf("postgres 应可用,reason=%s", reason)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuiltinLikeDriversRemainAvailable(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
SetExternalDriverDownloadDirectory(tmpDir)
|
||||
|
||||
supported, reason := DriverRuntimeSupportStatus("redis")
|
||||
if !supported {
|
||||
t.Fatalf("redis 应始终可用,reason=%s", reason)
|
||||
}
|
||||
}
|
||||
|
||||
func TestManagedDriverRequiresInstallMarker(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
SetExternalDriverDownloadDirectory(tmpDir)
|
||||
|
||||
supported, _ := DriverRuntimeSupportStatus("mariadb")
|
||||
if supported {
|
||||
t.Fatalf("mariadb 未安装时不应可用")
|
||||
}
|
||||
|
||||
if !IsOptionalGoDriverBuildIncluded("mariadb") {
|
||||
supported, reason := DriverRuntimeSupportStatus("mariadb")
|
||||
if supported {
|
||||
t.Fatalf("精简构建下 mariadb 不应可用")
|
||||
}
|
||||
if reason == "" {
|
||||
t.Fatalf("精简构建下 mariadb 应返回不可用原因")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
markerPath, err := ResolveOptionalGoDriverMarkerPath(tmpDir, "mariadb")
|
||||
if err != nil {
|
||||
t.Fatalf("解析 marker 路径失败: %v", err)
|
||||
}
|
||||
if err := os.MkdirAll(filepath.Dir(markerPath), 0o755); err != nil {
|
||||
t.Fatalf("创建 marker 目录失败: %v", err)
|
||||
}
|
||||
if err := os.WriteFile(markerPath, []byte("{}"), 0o644); err != nil {
|
||||
t.Fatalf("写入 marker 失败: %v", err)
|
||||
}
|
||||
executablePath, err := ResolveOptionalDriverAgentExecutablePath(tmpDir, "mariadb")
|
||||
if err != nil {
|
||||
t.Fatalf("解析 mariadb 代理路径失败: %v", err)
|
||||
}
|
||||
if err := os.WriteFile(executablePath, []byte("placeholder"), 0o755); err != nil {
|
||||
t.Fatalf("写入 mariadb 代理占位文件失败: %v", err)
|
||||
}
|
||||
if runtime.GOOS == "windows" {
|
||||
_ = os.Chmod(executablePath, 0o644)
|
||||
}
|
||||
|
||||
supported, reason := DriverRuntimeSupportStatus("mariadb")
|
||||
if !supported {
|
||||
t.Fatalf("mariadb 安装后应可用,reason=%s", reason)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMySQLBuiltinRuntimeSupportAvailable(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
SetExternalDriverDownloadDirectory(tmpDir)
|
||||
|
||||
supported, reason := DriverRuntimeSupportStatus("mysql")
|
||||
if !supported {
|
||||
t.Fatalf("mysql 属于免安装内置驱动,应可用,reason=%s", reason)
|
||||
}
|
||||
}
|
||||
@@ -1,3 +1,5 @@
|
||||
//go:build gonavi_full_drivers
|
||||
|
||||
package db
|
||||
|
||||
import (
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
//go:build cgo && (duckdb_use_lib || duckdb_use_static_lib || (darwin && (amd64 || arm64)) || (linux && (amd64 || arm64)) || (windows && amd64))
|
||||
//go:build (gonavi_full_drivers || gonavi_duckdb_driver) && cgo && (duckdb_use_lib || duckdb_use_static_lib || (darwin && (amd64 || arm64)) || (linux && (amd64 || arm64)) || (windows && amd64))
|
||||
|
||||
package db
|
||||
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
//go:build gonavi_full_drivers || gonavi_duckdb_driver
|
||||
|
||||
package db
|
||||
|
||||
import (
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
//go:build cgo && (duckdb_use_lib || duckdb_use_static_lib || (darwin && (amd64 || arm64)) || (linux && (amd64 || arm64)) || (windows && amd64))
|
||||
//go:build (gonavi_full_drivers || gonavi_duckdb_driver) && cgo && (duckdb_use_lib || duckdb_use_static_lib || (darwin && (amd64 || arm64)) || (linux && (amd64 || arm64)) || (windows && amd64))
|
||||
|
||||
package db
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
//go:build !(cgo && (duckdb_use_lib || duckdb_use_static_lib || (darwin && (amd64 || arm64)) || (linux && (amd64 || arm64)) || (windows && amd64)))
|
||||
//go:build (gonavi_full_drivers || gonavi_duckdb_driver) && !(cgo && (duckdb_use_lib || duckdb_use_static_lib || (darwin && (amd64 || arm64)) || (linux && (amd64 || arm64)) || (windows && amd64)))
|
||||
|
||||
package db
|
||||
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
//go:build gonavi_full_drivers || gonavi_highgo_driver
|
||||
|
||||
package db
|
||||
|
||||
import (
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
//go:build gonavi_full_drivers || gonavi_kingbase_driver
|
||||
|
||||
package db
|
||||
|
||||
import (
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
//go:build gonavi_full_drivers || gonavi_mariadb_driver
|
||||
|
||||
package db
|
||||
|
||||
import (
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
//go:build gonavi_full_drivers || gonavi_mongodb_driver
|
||||
|
||||
package db
|
||||
|
||||
import (
|
||||
|
||||
430
internal/db/mysql_agent_impl.go
Normal file
430
internal/db/mysql_agent_impl.go
Normal file
@@ -0,0 +1,430 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"GoNavi-Wails/internal/connection"
|
||||
)
|
||||
|
||||
const (
|
||||
mysqlAgentMethodConnect = "connect"
|
||||
mysqlAgentMethodClose = "close"
|
||||
mysqlAgentMethodPing = "ping"
|
||||
mysqlAgentMethodQuery = "query"
|
||||
mysqlAgentMethodExec = "exec"
|
||||
mysqlAgentMethodGetDatabases = "getDatabases"
|
||||
mysqlAgentMethodGetTables = "getTables"
|
||||
mysqlAgentMethodGetCreateStmt = "getCreateStatement"
|
||||
mysqlAgentMethodGetColumns = "getColumns"
|
||||
mysqlAgentMethodGetAllColumns = "getAllColumns"
|
||||
mysqlAgentMethodGetIndexes = "getIndexes"
|
||||
mysqlAgentMethodGetForeignKeys = "getForeignKeys"
|
||||
mysqlAgentMethodGetTriggers = "getTriggers"
|
||||
mysqlAgentMethodApplyChanges = "applyChanges"
|
||||
mysqlAgentDefaultScannerMaxBytes = 8 << 20
|
||||
)
|
||||
|
||||
type mysqlAgentRequest struct {
|
||||
ID int64 `json:"id"`
|
||||
Method string `json:"method"`
|
||||
Config *connection.ConnectionConfig `json:"config,omitempty"`
|
||||
Query string `json:"query,omitempty"`
|
||||
DBName string `json:"dbName,omitempty"`
|
||||
TableName string `json:"tableName,omitempty"`
|
||||
Changes *connection.ChangeSet `json:"changes,omitempty"`
|
||||
}
|
||||
|
||||
type mysqlAgentResponse struct {
|
||||
ID int64 `json:"id"`
|
||||
Success bool `json:"success"`
|
||||
Error string `json:"error,omitempty"`
|
||||
Data json.RawMessage `json:"data,omitempty"`
|
||||
Fields []string `json:"fields,omitempty"`
|
||||
RowsAffected int64 `json:"rowsAffected,omitempty"`
|
||||
}
|
||||
|
||||
type mysqlAgentClient struct {
|
||||
cmd *exec.Cmd
|
||||
stdin io.WriteCloser
|
||||
reader *bufio.Reader
|
||||
nextID int64
|
||||
mu sync.Mutex
|
||||
stderrMu sync.Mutex
|
||||
stderr strings.Builder
|
||||
}
|
||||
|
||||
func newMySQLAgentClient(executablePath string) (*mysqlAgentClient, error) {
|
||||
pathText := strings.TrimSpace(executablePath)
|
||||
if pathText == "" {
|
||||
return nil, fmt.Errorf("MySQL 驱动代理路径为空")
|
||||
}
|
||||
if info, err := os.Stat(pathText); err != nil || info.IsDir() {
|
||||
return nil, fmt.Errorf("MySQL 驱动代理不存在:%s", pathText)
|
||||
}
|
||||
|
||||
cmd := exec.Command(pathText)
|
||||
stdin, err := cmd.StdinPipe()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("创建 MySQL 驱动代理 stdin 失败:%w", err)
|
||||
}
|
||||
stdout, err := cmd.StdoutPipe()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("创建 MySQL 驱动代理 stdout 失败:%w", err)
|
||||
}
|
||||
stderr, err := cmd.StderrPipe()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("创建 MySQL 驱动代理 stderr 失败:%w", err)
|
||||
}
|
||||
if err := cmd.Start(); err != nil {
|
||||
return nil, fmt.Errorf("启动 MySQL 驱动代理失败:%w", err)
|
||||
}
|
||||
|
||||
client := &mysqlAgentClient{
|
||||
cmd: cmd,
|
||||
stdin: stdin,
|
||||
reader: bufio.NewReader(stdout),
|
||||
}
|
||||
go client.captureStderr(stderr)
|
||||
return client, nil
|
||||
}
|
||||
|
||||
func (c *mysqlAgentClient) captureStderr(stderr io.Reader) {
|
||||
scanner := bufio.NewScanner(stderr)
|
||||
buffer := make([]byte, 0, 8<<10)
|
||||
scanner.Buffer(buffer, mysqlAgentDefaultScannerMaxBytes)
|
||||
for scanner.Scan() {
|
||||
line := strings.TrimSpace(scanner.Text())
|
||||
if line == "" {
|
||||
continue
|
||||
}
|
||||
c.stderrMu.Lock()
|
||||
if c.stderr.Len() > 0 {
|
||||
c.stderr.WriteString(" | ")
|
||||
}
|
||||
c.stderr.WriteString(line)
|
||||
c.stderrMu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *mysqlAgentClient) stderrText() string {
|
||||
c.stderrMu.Lock()
|
||||
defer c.stderrMu.Unlock()
|
||||
return strings.TrimSpace(c.stderr.String())
|
||||
}
|
||||
|
||||
func (c *mysqlAgentClient) call(req mysqlAgentRequest, out interface{}, fields *[]string, rowsAffected *int64) error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
c.nextID++
|
||||
req.ID = c.nextID
|
||||
|
||||
payload, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
payload = append(payload, '\n')
|
||||
if _, err := c.stdin.Write(payload); err != nil {
|
||||
stderrText := c.stderrText()
|
||||
if stderrText == "" {
|
||||
return fmt.Errorf("调用 MySQL 驱动代理失败:%w", err)
|
||||
}
|
||||
return fmt.Errorf("调用 MySQL 驱动代理失败:%w(stderr: %s)", err, stderrText)
|
||||
}
|
||||
|
||||
line, err := c.reader.ReadBytes('\n')
|
||||
if err != nil {
|
||||
stderrText := c.stderrText()
|
||||
if stderrText == "" {
|
||||
return fmt.Errorf("读取 MySQL 驱动代理响应失败:%w", err)
|
||||
}
|
||||
return fmt.Errorf("读取 MySQL 驱动代理响应失败:%w(stderr: %s)", err, stderrText)
|
||||
}
|
||||
|
||||
var resp mysqlAgentResponse
|
||||
if err := json.Unmarshal(line, &resp); err != nil {
|
||||
return fmt.Errorf("解析 MySQL 驱动代理响应失败:%w", err)
|
||||
}
|
||||
if !resp.Success {
|
||||
errText := strings.TrimSpace(resp.Error)
|
||||
if errText == "" {
|
||||
errText = "MySQL 驱动代理返回失败"
|
||||
}
|
||||
return errors.New(errText)
|
||||
}
|
||||
|
||||
if fields != nil {
|
||||
*fields = resp.Fields
|
||||
}
|
||||
if rowsAffected != nil {
|
||||
*rowsAffected = resp.RowsAffected
|
||||
}
|
||||
if out != nil && len(resp.Data) > 0 {
|
||||
if err := json.Unmarshal(resp.Data, out); err != nil {
|
||||
return fmt.Errorf("解析 MySQL 驱动代理数据失败:%w", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *mysqlAgentClient) close() error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
var closeErr error
|
||||
if c.stdin != nil {
|
||||
_ = c.stdin.Close()
|
||||
}
|
||||
if c.cmd != nil && c.cmd.Process != nil {
|
||||
if err := c.cmd.Process.Kill(); err != nil {
|
||||
closeErr = err
|
||||
}
|
||||
}
|
||||
if c.cmd != nil {
|
||||
_ = c.cmd.Wait()
|
||||
}
|
||||
return closeErr
|
||||
}
|
||||
|
||||
type MySQLAgentDB struct {
|
||||
client *mysqlAgentClient
|
||||
}
|
||||
|
||||
func (m *MySQLAgentDB) Connect(config connection.ConnectionConfig) error {
|
||||
if m.client != nil {
|
||||
_ = m.client.close()
|
||||
m.client = nil
|
||||
}
|
||||
|
||||
executablePath, err := ResolveMySQLAgentExecutablePath("")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
client, err := newMySQLAgentClient(executablePath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := client.call(mysqlAgentRequest{
|
||||
Method: mysqlAgentMethodConnect,
|
||||
Config: &config,
|
||||
}, nil, nil, nil); err != nil {
|
||||
_ = client.close()
|
||||
return err
|
||||
}
|
||||
m.client = client
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MySQLAgentDB) Close() error {
|
||||
if m.client == nil {
|
||||
return nil
|
||||
}
|
||||
_ = m.client.call(mysqlAgentRequest{Method: mysqlAgentMethodClose}, nil, nil, nil)
|
||||
err := m.client.close()
|
||||
m.client = nil
|
||||
return err
|
||||
}
|
||||
|
||||
func (m *MySQLAgentDB) Ping() error {
|
||||
client, err := m.requireClient()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return client.call(mysqlAgentRequest{Method: mysqlAgentMethodPing}, nil, nil, nil)
|
||||
}
|
||||
|
||||
func (m *MySQLAgentDB) QueryContext(ctx context.Context, query string) ([]map[string]interface{}, []string, error) {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
return m.Query(query)
|
||||
}
|
||||
|
||||
func (m *MySQLAgentDB) Query(query string) ([]map[string]interface{}, []string, error) {
|
||||
client, err := m.requireClient()
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
var data []map[string]interface{}
|
||||
var fields []string
|
||||
if err := client.call(mysqlAgentRequest{
|
||||
Method: mysqlAgentMethodQuery,
|
||||
Query: query,
|
||||
}, &data, &fields, nil); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
return data, fields, nil
|
||||
}
|
||||
|
||||
func (m *MySQLAgentDB) ExecContext(ctx context.Context, query string) (int64, error) {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return m.Exec(query)
|
||||
}
|
||||
|
||||
func (m *MySQLAgentDB) Exec(query string) (int64, error) {
|
||||
client, err := m.requireClient()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
var affected int64
|
||||
if err := client.call(mysqlAgentRequest{
|
||||
Method: mysqlAgentMethodExec,
|
||||
Query: query,
|
||||
}, nil, nil, &affected); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return affected, nil
|
||||
}
|
||||
|
||||
func (m *MySQLAgentDB) GetDatabases() ([]string, error) {
|
||||
client, err := m.requireClient()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var dbs []string
|
||||
if err := client.call(mysqlAgentRequest{
|
||||
Method: mysqlAgentMethodGetDatabases,
|
||||
}, &dbs, nil, nil); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return dbs, nil
|
||||
}
|
||||
|
||||
func (m *MySQLAgentDB) GetTables(dbName string) ([]string, error) {
|
||||
client, err := m.requireClient()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var tables []string
|
||||
if err := client.call(mysqlAgentRequest{
|
||||
Method: mysqlAgentMethodGetTables,
|
||||
DBName: dbName,
|
||||
}, &tables, nil, nil); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return tables, nil
|
||||
}
|
||||
|
||||
func (m *MySQLAgentDB) GetCreateStatement(dbName, tableName string) (string, error) {
|
||||
client, err := m.requireClient()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
var sqlText string
|
||||
if err := client.call(mysqlAgentRequest{
|
||||
Method: mysqlAgentMethodGetCreateStmt,
|
||||
DBName: dbName,
|
||||
TableName: tableName,
|
||||
}, &sqlText, nil, nil); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return sqlText, nil
|
||||
}
|
||||
|
||||
func (m *MySQLAgentDB) GetColumns(dbName, tableName string) ([]connection.ColumnDefinition, error) {
|
||||
client, err := m.requireClient()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var columns []connection.ColumnDefinition
|
||||
if err := client.call(mysqlAgentRequest{
|
||||
Method: mysqlAgentMethodGetColumns,
|
||||
DBName: dbName,
|
||||
TableName: tableName,
|
||||
}, &columns, nil, nil); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return columns, nil
|
||||
}
|
||||
|
||||
func (m *MySQLAgentDB) GetAllColumns(dbName string) ([]connection.ColumnDefinitionWithTable, error) {
|
||||
client, err := m.requireClient()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var columns []connection.ColumnDefinitionWithTable
|
||||
if err := client.call(mysqlAgentRequest{
|
||||
Method: mysqlAgentMethodGetAllColumns,
|
||||
DBName: dbName,
|
||||
}, &columns, nil, nil); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return columns, nil
|
||||
}
|
||||
|
||||
func (m *MySQLAgentDB) GetIndexes(dbName, tableName string) ([]connection.IndexDefinition, error) {
|
||||
client, err := m.requireClient()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var indexes []connection.IndexDefinition
|
||||
if err := client.call(mysqlAgentRequest{
|
||||
Method: mysqlAgentMethodGetIndexes,
|
||||
DBName: dbName,
|
||||
TableName: tableName,
|
||||
}, &indexes, nil, nil); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return indexes, nil
|
||||
}
|
||||
|
||||
func (m *MySQLAgentDB) GetForeignKeys(dbName, tableName string) ([]connection.ForeignKeyDefinition, error) {
|
||||
client, err := m.requireClient()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var keys []connection.ForeignKeyDefinition
|
||||
if err := client.call(mysqlAgentRequest{
|
||||
Method: mysqlAgentMethodGetForeignKeys,
|
||||
DBName: dbName,
|
||||
TableName: tableName,
|
||||
}, &keys, nil, nil); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return keys, nil
|
||||
}
|
||||
|
||||
func (m *MySQLAgentDB) GetTriggers(dbName, tableName string) ([]connection.TriggerDefinition, error) {
|
||||
client, err := m.requireClient()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var triggers []connection.TriggerDefinition
|
||||
if err := client.call(mysqlAgentRequest{
|
||||
Method: mysqlAgentMethodGetTriggers,
|
||||
DBName: dbName,
|
||||
TableName: tableName,
|
||||
}, &triggers, nil, nil); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return triggers, nil
|
||||
}
|
||||
|
||||
func (m *MySQLAgentDB) ApplyChanges(tableName string, changes connection.ChangeSet) error {
|
||||
client, err := m.requireClient()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return client.call(mysqlAgentRequest{
|
||||
Method: mysqlAgentMethodApplyChanges,
|
||||
TableName: tableName,
|
||||
Changes: &changes,
|
||||
}, nil, nil, nil)
|
||||
}
|
||||
|
||||
func (m *MySQLAgentDB) requireClient() (*mysqlAgentClient, error) {
|
||||
if m.client == nil {
|
||||
return nil, fmt.Errorf("connection not open")
|
||||
}
|
||||
return m.client, nil
|
||||
}
|
||||
40
internal/db/mysql_agent_path.go
Normal file
40
internal/db/mysql_agent_path.go
Normal file
@@ -0,0 +1,40 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func mysqlAgentExecutableName() string {
|
||||
return optionalDriverAgentExecutableName("mysql")
|
||||
}
|
||||
|
||||
func optionalDriverAgentExecutableName(driverType string) string {
|
||||
normalized := normalizeRuntimeDriverType(driverType)
|
||||
if normalized == "" {
|
||||
normalized = "unknown"
|
||||
}
|
||||
name := fmt.Sprintf("%s-driver-agent", normalized)
|
||||
if runtime.GOOS == "windows" {
|
||||
return name + ".exe"
|
||||
}
|
||||
return name
|
||||
}
|
||||
|
||||
func ResolveOptionalDriverAgentExecutablePath(downloadDir string, driverType string) (string, error) {
|
||||
normalized := normalizeRuntimeDriverType(driverType)
|
||||
if strings.TrimSpace(normalized) == "" {
|
||||
return "", fmt.Errorf("驱动类型为空")
|
||||
}
|
||||
root, err := resolveExternalDriverRoot(downloadDir)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return filepath.Join(root, normalized, optionalDriverAgentExecutableName(normalized)), nil
|
||||
}
|
||||
|
||||
func ResolveMySQLAgentExecutablePath(downloadDir string) (string, error) {
|
||||
return ResolveOptionalDriverAgentExecutablePath(downloadDir, "mysql")
|
||||
}
|
||||
440
internal/db/optional_driver_agent_impl.go
Normal file
440
internal/db/optional_driver_agent_impl.go
Normal file
@@ -0,0 +1,440 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"GoNavi-Wails/internal/connection"
|
||||
)
|
||||
|
||||
const (
|
||||
optionalAgentMethodConnect = "connect"
|
||||
optionalAgentMethodClose = "close"
|
||||
optionalAgentMethodPing = "ping"
|
||||
optionalAgentMethodQuery = "query"
|
||||
optionalAgentMethodExec = "exec"
|
||||
optionalAgentMethodGetDatabases = "getDatabases"
|
||||
optionalAgentMethodGetTables = "getTables"
|
||||
optionalAgentMethodGetCreateStmt = "getCreateStatement"
|
||||
optionalAgentMethodGetColumns = "getColumns"
|
||||
optionalAgentMethodGetAllColumns = "getAllColumns"
|
||||
optionalAgentMethodGetIndexes = "getIndexes"
|
||||
optionalAgentMethodGetForeignKeys = "getForeignKeys"
|
||||
optionalAgentMethodGetTriggers = "getTriggers"
|
||||
optionalAgentMethodApplyChanges = "applyChanges"
|
||||
optionalAgentDefaultScannerMaxBytes = 8 << 20
|
||||
)
|
||||
|
||||
type optionalAgentRequest struct {
|
||||
ID int64 `json:"id"`
|
||||
Method string `json:"method"`
|
||||
Config *connection.ConnectionConfig `json:"config,omitempty"`
|
||||
Query string `json:"query,omitempty"`
|
||||
DBName string `json:"dbName,omitempty"`
|
||||
TableName string `json:"tableName,omitempty"`
|
||||
Changes *connection.ChangeSet `json:"changes,omitempty"`
|
||||
}
|
||||
|
||||
type optionalAgentResponse struct {
|
||||
ID int64 `json:"id"`
|
||||
Success bool `json:"success"`
|
||||
Error string `json:"error,omitempty"`
|
||||
Data json.RawMessage `json:"data,omitempty"`
|
||||
Fields []string `json:"fields,omitempty"`
|
||||
RowsAffected int64 `json:"rowsAffected,omitempty"`
|
||||
}
|
||||
|
||||
type optionalDriverAgentClient struct {
|
||||
cmd *exec.Cmd
|
||||
stdin io.WriteCloser
|
||||
reader *bufio.Reader
|
||||
nextID int64
|
||||
mu sync.Mutex
|
||||
stderrMu sync.Mutex
|
||||
stderr strings.Builder
|
||||
driver string
|
||||
}
|
||||
|
||||
func newOptionalDriverAgentClient(driverType string, executablePath string) (*optionalDriverAgentClient, error) {
|
||||
pathText := strings.TrimSpace(executablePath)
|
||||
if pathText == "" {
|
||||
return nil, fmt.Errorf("%s 驱动代理路径为空", driverDisplayName(driverType))
|
||||
}
|
||||
if info, err := os.Stat(pathText); err != nil || info.IsDir() {
|
||||
return nil, fmt.Errorf("%s 驱动代理不存在:%s", driverDisplayName(driverType), pathText)
|
||||
}
|
||||
|
||||
cmd := exec.Command(pathText)
|
||||
stdin, err := cmd.StdinPipe()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("创建 %s 驱动代理 stdin 失败:%w", driverDisplayName(driverType), err)
|
||||
}
|
||||
stdout, err := cmd.StdoutPipe()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("创建 %s 驱动代理 stdout 失败:%w", driverDisplayName(driverType), err)
|
||||
}
|
||||
stderr, err := cmd.StderrPipe()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("创建 %s 驱动代理 stderr 失败:%w", driverDisplayName(driverType), err)
|
||||
}
|
||||
if err := cmd.Start(); err != nil {
|
||||
return nil, fmt.Errorf("启动 %s 驱动代理失败:%w", driverDisplayName(driverType), err)
|
||||
}
|
||||
|
||||
client := &optionalDriverAgentClient{
|
||||
cmd: cmd,
|
||||
stdin: stdin,
|
||||
reader: bufio.NewReader(stdout),
|
||||
driver: normalizeRuntimeDriverType(driverType),
|
||||
}
|
||||
go client.captureStderr(stderr)
|
||||
return client, nil
|
||||
}
|
||||
|
||||
func (c *optionalDriverAgentClient) captureStderr(stderr io.Reader) {
|
||||
scanner := bufio.NewScanner(stderr)
|
||||
buffer := make([]byte, 0, 8<<10)
|
||||
scanner.Buffer(buffer, optionalAgentDefaultScannerMaxBytes)
|
||||
for scanner.Scan() {
|
||||
line := strings.TrimSpace(scanner.Text())
|
||||
if line == "" {
|
||||
continue
|
||||
}
|
||||
c.stderrMu.Lock()
|
||||
if c.stderr.Len() > 0 {
|
||||
c.stderr.WriteString(" | ")
|
||||
}
|
||||
c.stderr.WriteString(line)
|
||||
c.stderrMu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *optionalDriverAgentClient) stderrText() string {
|
||||
c.stderrMu.Lock()
|
||||
defer c.stderrMu.Unlock()
|
||||
return strings.TrimSpace(c.stderr.String())
|
||||
}
|
||||
|
||||
func (c *optionalDriverAgentClient) call(req optionalAgentRequest, out interface{}, fields *[]string, rowsAffected *int64) error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
c.nextID++
|
||||
req.ID = c.nextID
|
||||
|
||||
payload, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
payload = append(payload, '\n')
|
||||
if _, err := c.stdin.Write(payload); err != nil {
|
||||
stderrText := c.stderrText()
|
||||
if stderrText == "" {
|
||||
return fmt.Errorf("调用 %s 驱动代理失败:%w", driverDisplayName(c.driver), err)
|
||||
}
|
||||
return fmt.Errorf("调用 %s 驱动代理失败:%w(stderr: %s)", driverDisplayName(c.driver), err, stderrText)
|
||||
}
|
||||
|
||||
line, err := c.reader.ReadBytes('\n')
|
||||
if err != nil {
|
||||
stderrText := c.stderrText()
|
||||
if stderrText == "" {
|
||||
return fmt.Errorf("读取 %s 驱动代理响应失败:%w", driverDisplayName(c.driver), err)
|
||||
}
|
||||
return fmt.Errorf("读取 %s 驱动代理响应失败:%w(stderr: %s)", driverDisplayName(c.driver), err, stderrText)
|
||||
}
|
||||
|
||||
var resp optionalAgentResponse
|
||||
if err := json.Unmarshal(line, &resp); err != nil {
|
||||
return fmt.Errorf("解析 %s 驱动代理响应失败:%w", driverDisplayName(c.driver), err)
|
||||
}
|
||||
if !resp.Success {
|
||||
errText := strings.TrimSpace(resp.Error)
|
||||
if errText == "" {
|
||||
errText = fmt.Sprintf("%s 驱动代理返回失败", driverDisplayName(c.driver))
|
||||
}
|
||||
return errors.New(errText)
|
||||
}
|
||||
|
||||
if fields != nil {
|
||||
*fields = resp.Fields
|
||||
}
|
||||
if rowsAffected != nil {
|
||||
*rowsAffected = resp.RowsAffected
|
||||
}
|
||||
if out != nil && len(resp.Data) > 0 {
|
||||
if err := json.Unmarshal(resp.Data, out); err != nil {
|
||||
return fmt.Errorf("解析 %s 驱动代理数据失败:%w", driverDisplayName(c.driver), err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *optionalDriverAgentClient) close() error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
var closeErr error
|
||||
if c.stdin != nil {
|
||||
_ = c.stdin.Close()
|
||||
}
|
||||
if c.cmd != nil && c.cmd.Process != nil {
|
||||
if err := c.cmd.Process.Kill(); err != nil {
|
||||
closeErr = err
|
||||
}
|
||||
}
|
||||
if c.cmd != nil {
|
||||
_ = c.cmd.Wait()
|
||||
}
|
||||
return closeErr
|
||||
}
|
||||
|
||||
type OptionalDriverAgentDB struct {
|
||||
driverType string
|
||||
client *optionalDriverAgentClient
|
||||
}
|
||||
|
||||
func newOptionalDriverAgentDatabase(driverType string) databaseFactory {
|
||||
normalized := normalizeRuntimeDriverType(driverType)
|
||||
return func() Database {
|
||||
return &OptionalDriverAgentDB{driverType: normalized}
|
||||
}
|
||||
}
|
||||
|
||||
func (d *OptionalDriverAgentDB) Connect(config connection.ConnectionConfig) error {
|
||||
if d.client != nil {
|
||||
_ = d.client.close()
|
||||
d.client = nil
|
||||
}
|
||||
|
||||
executablePath, err := ResolveOptionalDriverAgentExecutablePath("", d.driverType)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
client, err := newOptionalDriverAgentClient(d.driverType, executablePath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := client.call(optionalAgentRequest{
|
||||
Method: optionalAgentMethodConnect,
|
||||
Config: &config,
|
||||
}, nil, nil, nil); err != nil {
|
||||
_ = client.close()
|
||||
return err
|
||||
}
|
||||
d.client = client
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *OptionalDriverAgentDB) Close() error {
|
||||
if d.client == nil {
|
||||
return nil
|
||||
}
|
||||
_ = d.client.call(optionalAgentRequest{Method: optionalAgentMethodClose}, nil, nil, nil)
|
||||
err := d.client.close()
|
||||
d.client = nil
|
||||
return err
|
||||
}
|
||||
|
||||
func (d *OptionalDriverAgentDB) Ping() error {
|
||||
client, err := d.requireClient()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return client.call(optionalAgentRequest{Method: optionalAgentMethodPing}, nil, nil, nil)
|
||||
}
|
||||
|
||||
func (d *OptionalDriverAgentDB) QueryContext(ctx context.Context, query string) ([]map[string]interface{}, []string, error) {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
return d.Query(query)
|
||||
}
|
||||
|
||||
func (d *OptionalDriverAgentDB) Query(query string) ([]map[string]interface{}, []string, error) {
|
||||
client, err := d.requireClient()
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
var data []map[string]interface{}
|
||||
var fields []string
|
||||
if err := client.call(optionalAgentRequest{
|
||||
Method: optionalAgentMethodQuery,
|
||||
Query: query,
|
||||
}, &data, &fields, nil); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
return data, fields, nil
|
||||
}
|
||||
|
||||
func (d *OptionalDriverAgentDB) ExecContext(ctx context.Context, query string) (int64, error) {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return d.Exec(query)
|
||||
}
|
||||
|
||||
func (d *OptionalDriverAgentDB) Exec(query string) (int64, error) {
|
||||
client, err := d.requireClient()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
var affected int64
|
||||
if err := client.call(optionalAgentRequest{
|
||||
Method: optionalAgentMethodExec,
|
||||
Query: query,
|
||||
}, nil, nil, &affected); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return affected, nil
|
||||
}
|
||||
|
||||
func (d *OptionalDriverAgentDB) GetDatabases() ([]string, error) {
|
||||
client, err := d.requireClient()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var dbs []string
|
||||
if err := client.call(optionalAgentRequest{
|
||||
Method: optionalAgentMethodGetDatabases,
|
||||
}, &dbs, nil, nil); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return dbs, nil
|
||||
}
|
||||
|
||||
func (d *OptionalDriverAgentDB) GetTables(dbName string) ([]string, error) {
|
||||
client, err := d.requireClient()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var tables []string
|
||||
if err := client.call(optionalAgentRequest{
|
||||
Method: optionalAgentMethodGetTables,
|
||||
DBName: dbName,
|
||||
}, &tables, nil, nil); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return tables, nil
|
||||
}
|
||||
|
||||
func (d *OptionalDriverAgentDB) GetCreateStatement(dbName, tableName string) (string, error) {
|
||||
client, err := d.requireClient()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
var sqlText string
|
||||
if err := client.call(optionalAgentRequest{
|
||||
Method: optionalAgentMethodGetCreateStmt,
|
||||
DBName: dbName,
|
||||
TableName: tableName,
|
||||
}, &sqlText, nil, nil); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return sqlText, nil
|
||||
}
|
||||
|
||||
func (d *OptionalDriverAgentDB) GetColumns(dbName, tableName string) ([]connection.ColumnDefinition, error) {
|
||||
client, err := d.requireClient()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var columns []connection.ColumnDefinition
|
||||
if err := client.call(optionalAgentRequest{
|
||||
Method: optionalAgentMethodGetColumns,
|
||||
DBName: dbName,
|
||||
TableName: tableName,
|
||||
}, &columns, nil, nil); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return columns, nil
|
||||
}
|
||||
|
||||
func (d *OptionalDriverAgentDB) GetAllColumns(dbName string) ([]connection.ColumnDefinitionWithTable, error) {
|
||||
client, err := d.requireClient()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var columns []connection.ColumnDefinitionWithTable
|
||||
if err := client.call(optionalAgentRequest{
|
||||
Method: optionalAgentMethodGetAllColumns,
|
||||
DBName: dbName,
|
||||
}, &columns, nil, nil); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return columns, nil
|
||||
}
|
||||
|
||||
func (d *OptionalDriverAgentDB) GetIndexes(dbName, tableName string) ([]connection.IndexDefinition, error) {
|
||||
client, err := d.requireClient()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var indexes []connection.IndexDefinition
|
||||
if err := client.call(optionalAgentRequest{
|
||||
Method: optionalAgentMethodGetIndexes,
|
||||
DBName: dbName,
|
||||
TableName: tableName,
|
||||
}, &indexes, nil, nil); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return indexes, nil
|
||||
}
|
||||
|
||||
func (d *OptionalDriverAgentDB) GetForeignKeys(dbName, tableName string) ([]connection.ForeignKeyDefinition, error) {
|
||||
client, err := d.requireClient()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var keys []connection.ForeignKeyDefinition
|
||||
if err := client.call(optionalAgentRequest{
|
||||
Method: optionalAgentMethodGetForeignKeys,
|
||||
DBName: dbName,
|
||||
TableName: tableName,
|
||||
}, &keys, nil, nil); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return keys, nil
|
||||
}
|
||||
|
||||
func (d *OptionalDriverAgentDB) GetTriggers(dbName, tableName string) ([]connection.TriggerDefinition, error) {
|
||||
client, err := d.requireClient()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var triggers []connection.TriggerDefinition
|
||||
if err := client.call(optionalAgentRequest{
|
||||
Method: optionalAgentMethodGetTriggers,
|
||||
DBName: dbName,
|
||||
TableName: tableName,
|
||||
}, &triggers, nil, nil); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return triggers, nil
|
||||
}
|
||||
|
||||
func (d *OptionalDriverAgentDB) ApplyChanges(tableName string, changes connection.ChangeSet) error {
|
||||
client, err := d.requireClient()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return client.call(optionalAgentRequest{
|
||||
Method: optionalAgentMethodApplyChanges,
|
||||
TableName: tableName,
|
||||
Changes: &changes,
|
||||
}, nil, nil, nil)
|
||||
}
|
||||
|
||||
func (d *OptionalDriverAgentDB) requireClient() (*optionalDriverAgentClient, error) {
|
||||
if d.client == nil {
|
||||
return nil, fmt.Errorf("connection not open")
|
||||
}
|
||||
return d.client, nil
|
||||
}
|
||||
9
internal/db/optional_driver_build_full.go
Normal file
9
internal/db/optional_driver_build_full.go
Normal file
@@ -0,0 +1,9 @@
|
||||
//go:build gonavi_full_drivers
|
||||
|
||||
package db
|
||||
|
||||
func optionalGoDriverBuildIncluded(driverType string) bool {
|
||||
_, ok := optionalGoDrivers[normalizeRuntimeDriverType(driverType)]
|
||||
return ok
|
||||
}
|
||||
|
||||
8
internal/db/optional_driver_build_lite.go
Normal file
8
internal/db/optional_driver_build_lite.go
Normal file
@@ -0,0 +1,8 @@
|
||||
//go:build !gonavi_full_drivers
|
||||
|
||||
package db
|
||||
|
||||
func optionalGoDriverBuildIncluded(driverType string) bool {
|
||||
_, ok := optionalGoDrivers[normalizeRuntimeDriverType(driverType)]
|
||||
return ok
|
||||
}
|
||||
@@ -18,14 +18,12 @@ import (
|
||||
_ "github.com/lib/pq"
|
||||
)
|
||||
|
||||
|
||||
type PostgresDB struct {
|
||||
conn *sql.DB
|
||||
pingTimeout time.Duration
|
||||
forwarder *ssh.LocalForwarder // Store SSH tunnel forwarder
|
||||
}
|
||||
|
||||
|
||||
func (p *PostgresDB) getDSN(config connection.ConnectionConfig) string {
|
||||
// postgres://user:password@host:port/dbname?sslmode=disable
|
||||
dbname := config.Database
|
||||
@@ -48,6 +46,13 @@ func (p *PostgresDB) getDSN(config connection.ConnectionConfig) string {
|
||||
}
|
||||
|
||||
func (p *PostgresDB) Connect(config connection.ConnectionConfig) error {
|
||||
if supported, reason := DriverRuntimeSupportStatus("postgres"); !supported {
|
||||
if strings.TrimSpace(reason) == "" {
|
||||
reason = "PostgreSQL 纯 Go 驱动未启用,请先在驱动管理中安装启用"
|
||||
}
|
||||
return fmt.Errorf("%s", reason)
|
||||
}
|
||||
|
||||
var dsn string
|
||||
var err error
|
||||
|
||||
@@ -98,7 +103,6 @@ func (p *PostgresDB) Connect(config connection.ConnectionConfig) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
func (p *PostgresDB) Close() error {
|
||||
// Close SSH forwarder first if exists
|
||||
if p.forwarder != nil {
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
//go:build gonavi_full_drivers || gonavi_sphinx_driver
|
||||
|
||||
package db
|
||||
|
||||
import (
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
//go:build gonavi_full_drivers || gonavi_sqlite_driver
|
||||
|
||||
package db
|
||||
|
||||
import (
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
//go:build gonavi_full_drivers || gonavi_sqlserver_driver
|
||||
|
||||
package db
|
||||
|
||||
import (
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
//go:build gonavi_full_drivers || gonavi_tdengine_driver
|
||||
|
||||
package db
|
||||
|
||||
import (
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
//go:build gonavi_full_drivers || gonavi_vastbase_driver
|
||||
|
||||
package db
|
||||
|
||||
import (
|
||||
|
||||
Reference in New Issue
Block a user