mirror of
https://github.com/Syngnat/GoNavi.git
synced 2026-05-06 20:03:05 +08:00
- 前端连接表单新增额外连接参数入口,支持 URI query 格式录入与解析回填 - MySQL 兼容驱动支持 JDBC 常见参数映射,修复 UTF-8 字符集与 serverTimezone 兼容问题 - 扩展 Oracle、PostgreSQL 兼容、SQL Server、ClickHouse、MongoDB、达梦、TDengine 参数合并 - 按不同驱动通道处理 DSN、URI、Options 与 Settings,避免统一透传导致连接异常 - 修复编辑已保存连接时解析无认证 URI 会清空已有账号密码的问题 - 补充连接参数透传、缓存隔离、DSN 合并与 URI 回填回归测试
527 lines
13 KiB
Go
527 lines
13 KiB
Go
//go:build gonavi_full_drivers || gonavi_tdengine_driver
|
||
|
||
package db
|
||
|
||
import (
|
||
"context"
|
||
"database/sql"
|
||
"fmt"
|
||
"net"
|
||
"net/url"
|
||
"sort"
|
||
"strconv"
|
||
"strings"
|
||
"time"
|
||
|
||
"GoNavi-Wails/internal/connection"
|
||
"GoNavi-Wails/internal/logger"
|
||
"GoNavi-Wails/internal/ssh"
|
||
"GoNavi-Wails/internal/utils"
|
||
|
||
_ "github.com/taosdata/driver-go/v3/taosWS"
|
||
)
|
||
|
||
// TDengineDB implements Database interface for TDengine.
|
||
// Uses taosWS driver via WebSocket (通常通过 taosAdapter 提供服务)。
|
||
type TDengineDB struct {
|
||
conn *sql.DB
|
||
pingTimeout time.Duration
|
||
forwarder *ssh.LocalForwarder
|
||
}
|
||
|
||
func (t *TDengineDB) getDSN(config connection.ConnectionConfig) string {
|
||
user := strings.TrimSpace(config.User)
|
||
if user == "" {
|
||
user = "root"
|
||
}
|
||
|
||
pass := config.Password
|
||
dbName := strings.TrimSpace(config.Database)
|
||
path := "/"
|
||
if dbName != "" {
|
||
path = "/" + dbName
|
||
}
|
||
|
||
netType := resolveTDengineNet(config)
|
||
params := url.Values{}
|
||
mergeConnectionParamsFromConfig(params, config, "taos", "taosws", "tdengine")
|
||
params.Del("protocol")
|
||
params.Del("skip_verify")
|
||
query := params.Encode()
|
||
dsn := fmt.Sprintf("%s:%s@%s(%s)%s", user, pass, netType, net.JoinHostPort(config.Host, strconv.Itoa(config.Port)), path)
|
||
if query == "" {
|
||
return dsn
|
||
}
|
||
return dsn + "?" + query
|
||
}
|
||
|
||
func (t *TDengineDB) Connect(config connection.ConnectionConfig) error {
|
||
runConfig := config
|
||
|
||
if config.UseSSH {
|
||
logger.Infof("TDengine 使用 SSH 连接:地址=%s:%d 用户=%s", config.Host, config.Port, config.User)
|
||
|
||
forwarder, err := ssh.GetOrCreateLocalForwarder(config.SSH, config.Host, config.Port)
|
||
if err != nil {
|
||
return fmt.Errorf("创建 SSH 隧道失败:%w", err)
|
||
}
|
||
t.forwarder = forwarder
|
||
|
||
host, portStr, err := net.SplitHostPort(forwarder.LocalAddr)
|
||
if err != nil {
|
||
return fmt.Errorf("解析本地转发地址失败:%w", err)
|
||
}
|
||
port, err := strconv.Atoi(portStr)
|
||
if err != nil {
|
||
return fmt.Errorf("解析本地端口失败:%w", err)
|
||
}
|
||
|
||
localConfig := config
|
||
localConfig.Host = host
|
||
localConfig.Port = port
|
||
localConfig.UseSSH = false
|
||
runConfig = localConfig
|
||
logger.Infof("TDengine 通过本地端口转发连接:%s -> %s:%d", forwarder.LocalAddr, config.Host, config.Port)
|
||
}
|
||
|
||
attempts := []connection.ConnectionConfig{runConfig}
|
||
if shouldTrySSLPreferredFallback(runConfig) {
|
||
attempts = append(attempts, withSSLDisabled(runConfig))
|
||
}
|
||
|
||
var failures []string
|
||
for idx, attempt := range attempts {
|
||
dsn := t.getDSN(attempt)
|
||
db, err := sql.Open("taosWS", dsn)
|
||
if err != nil {
|
||
failures = append(failures, fmt.Sprintf("第%d次连接打开失败: %v", idx+1, err))
|
||
continue
|
||
}
|
||
t.conn = db
|
||
t.pingTimeout = getConnectTimeout(attempt)
|
||
|
||
if err := t.Ping(); err != nil {
|
||
_ = db.Close()
|
||
t.conn = nil
|
||
failures = append(failures, fmt.Sprintf("第%d次连接验证失败: %v", idx+1, err))
|
||
continue
|
||
}
|
||
if idx > 0 {
|
||
logger.Warnf("TDengine SSL 优先连接失败,已回退至明文连接")
|
||
}
|
||
return nil
|
||
}
|
||
return fmt.Errorf("连接建立后验证失败:%s", strings.Join(failures, ";"))
|
||
}
|
||
|
||
func (t *TDengineDB) Close() error {
|
||
if t.forwarder != nil {
|
||
if err := t.forwarder.Close(); err != nil {
|
||
logger.Warnf("关闭 TDengine SSH 端口转发失败:%v", err)
|
||
}
|
||
t.forwarder = nil
|
||
}
|
||
|
||
if t.conn != nil {
|
||
return t.conn.Close()
|
||
}
|
||
return nil
|
||
}
|
||
|
||
func (t *TDengineDB) Ping() error {
|
||
if t.conn == nil {
|
||
return fmt.Errorf("连接未打开")
|
||
}
|
||
timeout := t.pingTimeout
|
||
if timeout <= 0 {
|
||
timeout = 5 * time.Second
|
||
}
|
||
ctx, cancel := utils.ContextWithTimeout(timeout)
|
||
defer cancel()
|
||
return t.conn.PingContext(ctx)
|
||
}
|
||
|
||
func (t *TDengineDB) QueryContext(ctx context.Context, query string) ([]map[string]interface{}, []string, error) {
|
||
if t.conn == nil {
|
||
return nil, nil, fmt.Errorf("连接未打开")
|
||
}
|
||
|
||
rows, err := t.conn.QueryContext(ctx, query)
|
||
if err != nil {
|
||
return nil, nil, err
|
||
}
|
||
defer rows.Close()
|
||
|
||
return scanRows(rows)
|
||
}
|
||
|
||
func (t *TDengineDB) Query(query string) ([]map[string]interface{}, []string, error) {
|
||
if t.conn == nil {
|
||
return nil, nil, fmt.Errorf("连接未打开")
|
||
}
|
||
|
||
rows, err := t.conn.Query(query)
|
||
if err != nil {
|
||
return nil, nil, err
|
||
}
|
||
defer rows.Close()
|
||
|
||
return scanRows(rows)
|
||
}
|
||
|
||
func (t *TDengineDB) ExecContext(ctx context.Context, query string) (int64, error) {
|
||
if t.conn == nil {
|
||
return 0, fmt.Errorf("连接未打开")
|
||
}
|
||
res, err := t.conn.ExecContext(ctx, query)
|
||
if err != nil {
|
||
return 0, err
|
||
}
|
||
return res.RowsAffected()
|
||
}
|
||
|
||
func (t *TDengineDB) Exec(query string) (int64, error) {
|
||
if t.conn == nil {
|
||
return 0, fmt.Errorf("连接未打开")
|
||
}
|
||
res, err := t.conn.Exec(query)
|
||
if err != nil {
|
||
return 0, err
|
||
}
|
||
return res.RowsAffected()
|
||
}
|
||
|
||
func (t *TDengineDB) GetDatabases() ([]string, error) {
|
||
data, _, err := t.Query("SHOW DATABASES")
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
var dbs []string
|
||
for _, row := range data {
|
||
if val, ok := getValueFromRow(row, "name", "database", "Database", "db_name"); ok {
|
||
dbs = append(dbs, fmt.Sprintf("%v", val))
|
||
continue
|
||
}
|
||
for _, val := range row {
|
||
dbs = append(dbs, fmt.Sprintf("%v", val))
|
||
break
|
||
}
|
||
}
|
||
return dbs, nil
|
||
}
|
||
|
||
func (t *TDengineDB) GetTables(dbName string) ([]string, error) {
|
||
queries := make([]string, 0, 4)
|
||
if strings.TrimSpace(dbName) != "" {
|
||
queries = append(queries, fmt.Sprintf("SHOW TABLES FROM `%s`", escapeBacktickIdent(dbName)))
|
||
queries = append(queries, fmt.Sprintf("SHOW STABLES FROM `%s`", escapeBacktickIdent(dbName)))
|
||
}
|
||
queries = append(queries, "SHOW TABLES")
|
||
queries = append(queries, "SHOW STABLES")
|
||
|
||
var lastErr error
|
||
tableSet := make(map[string]struct{})
|
||
tables := make([]string, 0)
|
||
for _, query := range queries {
|
||
data, _, err := t.Query(query)
|
||
if err != nil {
|
||
lastErr = err
|
||
continue
|
||
}
|
||
|
||
for _, row := range data {
|
||
if val, ok := getValueFromRow(row, "table_name", "tablename", "name", "Table", "table"); ok {
|
||
tableName := strings.TrimSpace(fmt.Sprintf("%v", val))
|
||
if tableName == "" {
|
||
continue
|
||
}
|
||
if _, exists := tableSet[tableName]; exists {
|
||
continue
|
||
}
|
||
tableSet[tableName] = struct{}{}
|
||
tables = append(tables, tableName)
|
||
continue
|
||
}
|
||
for _, val := range row {
|
||
tableName := strings.TrimSpace(fmt.Sprintf("%v", val))
|
||
if tableName == "" {
|
||
break
|
||
}
|
||
if _, exists := tableSet[tableName]; exists {
|
||
break
|
||
}
|
||
tableSet[tableName] = struct{}{}
|
||
tables = append(tables, tableName)
|
||
break
|
||
}
|
||
}
|
||
}
|
||
if len(tables) > 0 {
|
||
sort.Strings(tables)
|
||
return tables, nil
|
||
}
|
||
|
||
if lastErr != nil {
|
||
return nil, lastErr
|
||
}
|
||
return []string{}, nil
|
||
}
|
||
|
||
func (t *TDengineDB) GetCreateStatement(dbName, tableName string) (string, error) {
|
||
qualified := quoteTDengineTable(dbName, tableName)
|
||
queries := []string{
|
||
fmt.Sprintf("SHOW CREATE TABLE %s", qualified),
|
||
fmt.Sprintf("SHOW CREATE STABLE %s", qualified),
|
||
}
|
||
|
||
var lastErr error
|
||
for _, query := range queries {
|
||
data, _, err := t.Query(query)
|
||
if err != nil {
|
||
lastErr = err
|
||
continue
|
||
}
|
||
if len(data) == 0 {
|
||
continue
|
||
}
|
||
|
||
row := data[0]
|
||
if val, ok := getValueFromRow(row, "Create Table", "create table", "Create Stable", "create stable", "SQL", "sql"); ok {
|
||
return fmt.Sprintf("%v", val), nil
|
||
}
|
||
|
||
longest := ""
|
||
for _, val := range row {
|
||
text := fmt.Sprintf("%v", val)
|
||
if strings.Contains(strings.ToUpper(text), "CREATE ") && len(text) > len(longest) {
|
||
longest = text
|
||
}
|
||
}
|
||
if longest != "" {
|
||
return longest, nil
|
||
}
|
||
}
|
||
|
||
if lastErr != nil {
|
||
return "", lastErr
|
||
}
|
||
return "", fmt.Errorf("未找到建表语句")
|
||
}
|
||
|
||
func (t *TDengineDB) GetColumns(dbName, tableName string) ([]connection.ColumnDefinition, error) {
|
||
query := fmt.Sprintf("DESCRIBE %s", quoteTDengineTable(dbName, tableName))
|
||
data, _, err := t.Query(query)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
columns := make([]connection.ColumnDefinition, 0, len(data))
|
||
for _, row := range data {
|
||
name, _ := getValueFromRow(row, "Field", "field", "col_name", "column_name", "name")
|
||
colType, _ := getValueFromRow(row, "Type", "type", "data_type")
|
||
note, _ := getValueFromRow(row, "Note", "note", "Extra", "extra")
|
||
nullable, okNull := getValueFromRow(row, "Null", "null", "nullable")
|
||
comment, _ := getValueFromRow(row, "Comment", "comment")
|
||
defaultVal, hasDefault := getValueFromRow(row, "Default", "default")
|
||
|
||
col := connection.ColumnDefinition{
|
||
Name: fmt.Sprintf("%v", name),
|
||
Type: fmt.Sprintf("%v", colType),
|
||
Nullable: "YES",
|
||
Key: "",
|
||
Extra: fmt.Sprintf("%v", note),
|
||
Comment: fmt.Sprintf("%v", comment),
|
||
}
|
||
|
||
if okNull {
|
||
col.Nullable = strings.ToUpper(fmt.Sprintf("%v", nullable))
|
||
}
|
||
|
||
noteUpper := strings.ToUpper(fmt.Sprintf("%v", note))
|
||
if strings.Contains(noteUpper, "TAG") {
|
||
col.Key = "TAG"
|
||
}
|
||
|
||
if hasDefault && defaultVal != nil {
|
||
def := fmt.Sprintf("%v", defaultVal)
|
||
if def != "<nil>" {
|
||
col.Default = &def
|
||
}
|
||
}
|
||
|
||
columns = append(columns, col)
|
||
}
|
||
return columns, nil
|
||
}
|
||
|
||
func (t *TDengineDB) GetAllColumns(dbName string) ([]connection.ColumnDefinitionWithTable, error) {
|
||
if strings.TrimSpace(dbName) == "" {
|
||
return nil, fmt.Errorf("获取全部列信息需要指定数据库名称")
|
||
}
|
||
|
||
tables, err := t.GetTables(dbName)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
cols := make([]connection.ColumnDefinitionWithTable, 0)
|
||
for _, table := range tables {
|
||
tableCols, err := t.GetColumns(dbName, table)
|
||
if err != nil {
|
||
continue
|
||
}
|
||
for _, col := range tableCols {
|
||
cols = append(cols, connection.ColumnDefinitionWithTable{
|
||
TableName: table,
|
||
Name: col.Name,
|
||
Type: col.Type,
|
||
})
|
||
}
|
||
}
|
||
|
||
return cols, nil
|
||
}
|
||
|
||
func (t *TDengineDB) GetIndexes(dbName, tableName string) ([]connection.IndexDefinition, error) {
|
||
return []connection.IndexDefinition{}, nil
|
||
}
|
||
|
||
func (t *TDengineDB) GetForeignKeys(dbName, tableName string) ([]connection.ForeignKeyDefinition, error) {
|
||
return []connection.ForeignKeyDefinition{}, nil
|
||
}
|
||
|
||
func (t *TDengineDB) GetTriggers(dbName, tableName string) ([]connection.TriggerDefinition, error) {
|
||
return []connection.TriggerDefinition{}, nil
|
||
}
|
||
|
||
func (t *TDengineDB) ApplyChanges(tableName string, changes connection.ChangeSet) error {
|
||
if t.conn == nil {
|
||
return fmt.Errorf("连接未打开")
|
||
}
|
||
if strings.TrimSpace(tableName) == "" {
|
||
return fmt.Errorf("表名不能为空")
|
||
}
|
||
if len(changes.Updates) > 0 || len(changes.Deletes) > 0 {
|
||
return fmt.Errorf("TDengine 目标端当前仅支持 INSERT 写入,暂不支持 UPDATE/DELETE 差异同步,请改用仅插入或全量覆盖模式")
|
||
}
|
||
|
||
qualifiedTable := quoteTDengineTable("", tableName)
|
||
for _, row := range changes.Inserts {
|
||
query, err := buildTDengineInsertSQL(qualifiedTable, row)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
if query == "" {
|
||
continue
|
||
}
|
||
if _, err := t.conn.Exec(query); err != nil {
|
||
return fmt.Errorf("插入失败:%v; sql=%s", err, query)
|
||
}
|
||
}
|
||
return nil
|
||
}
|
||
|
||
func buildTDengineInsertSQL(qualifiedTable string, row map[string]interface{}) (string, error) {
|
||
if strings.TrimSpace(qualifiedTable) == "" {
|
||
return "", fmt.Errorf("需要指定完整的表名")
|
||
}
|
||
if len(row) == 0 {
|
||
return "", nil
|
||
}
|
||
|
||
cols := make([]string, 0, len(row))
|
||
for key := range row {
|
||
if strings.TrimSpace(key) == "" {
|
||
continue
|
||
}
|
||
cols = append(cols, key)
|
||
}
|
||
if len(cols) == 0 {
|
||
return "", nil
|
||
}
|
||
sort.Strings(cols)
|
||
|
||
quotedCols := make([]string, 0, len(cols))
|
||
values := make([]string, 0, len(cols))
|
||
for _, col := range cols {
|
||
quotedCols = append(quotedCols, fmt.Sprintf("`%s`", escapeBacktickIdent(col)))
|
||
values = append(values, tdengineLiteral(row[col]))
|
||
}
|
||
|
||
return fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", qualifiedTable, strings.Join(quotedCols, ", "), strings.Join(values, ", ")), nil
|
||
}
|
||
|
||
func tdengineLiteral(value interface{}) string {
|
||
switch val := value.(type) {
|
||
case nil:
|
||
return "NULL"
|
||
case bool:
|
||
if val {
|
||
return "1"
|
||
}
|
||
return "0"
|
||
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, float32, float64:
|
||
return fmt.Sprintf("%v", val)
|
||
case time.Time:
|
||
return fmt.Sprintf("'%s'", val.Format("2006-01-02 15:04:05"))
|
||
case []byte:
|
||
return fmt.Sprintf("'%s'", strings.ReplaceAll(string(val), "'", "''"))
|
||
default:
|
||
return fmt.Sprintf("'%s'", strings.ReplaceAll(fmt.Sprintf("%v", val), "'", "''"))
|
||
}
|
||
}
|
||
|
||
func getValueFromRow(row map[string]interface{}, keys ...string) (interface{}, bool) {
|
||
if len(row) == 0 {
|
||
return nil, false
|
||
}
|
||
|
||
for _, key := range keys {
|
||
if val, ok := row[key]; ok {
|
||
return val, true
|
||
}
|
||
}
|
||
|
||
for existingKey, val := range row {
|
||
for _, key := range keys {
|
||
if strings.EqualFold(existingKey, key) {
|
||
return val, true
|
||
}
|
||
}
|
||
}
|
||
|
||
return nil, false
|
||
}
|
||
|
||
func escapeBacktickIdent(ident string) string {
|
||
return strings.ReplaceAll(strings.TrimSpace(ident), "`", "``")
|
||
}
|
||
|
||
func quoteTDengineTable(dbName, tableName string) string {
|
||
t := escapeBacktickIdent(tableName)
|
||
if t == "" {
|
||
return "``"
|
||
}
|
||
if strings.Contains(t, ".") {
|
||
parts := strings.Split(t, ".")
|
||
quoted := make([]string, 0, len(parts))
|
||
for _, part := range parts {
|
||
part = strings.TrimSpace(part)
|
||
if part == "" {
|
||
continue
|
||
}
|
||
quoted = append(quoted, fmt.Sprintf("`%s`", escapeBacktickIdent(part)))
|
||
}
|
||
if len(quoted) > 0 {
|
||
return strings.Join(quoted, ".")
|
||
}
|
||
}
|
||
|
||
db := escapeBacktickIdent(dbName)
|
||
if db == "" {
|
||
return fmt.Sprintf("`%s`", t)
|
||
}
|
||
return fmt.Sprintf("`%s`.`%s`", db, t)
|
||
}
|