Files
MyGoNavi/internal/db/tdengine_impl.go
Syngnat f3e11961dc 🐛 fix(tdengine): 修复旧版 TDengine 元数据查询与驱动版本选择异常
- 放开 TDengine 已安装驱动的历史版本切换入口
- 兼容低版本 SHOW TABLES FROM 语法差异
- 修复表概览加载时报 [0x2600] syntax error near
- 新增后端兼容与前端交互回归测试
- Close #531
2026-06-14 17:22:02 +08:00

638 lines
16 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
//go:build gonavi_full_drivers || gonavi_tdengine_driver
package db
import (
"context"
"database/sql"
"errors"
"fmt"
"net"
"net/url"
"sort"
"strconv"
"strings"
"time"
"GoNavi-Wails/internal/connection"
"GoNavi-Wails/internal/logger"
"GoNavi-Wails/internal/ssh"
"GoNavi-Wails/internal/utils"
_ "github.com/taosdata/driver-go/v3/taosWS"
)
// TDengineDB implements Database interface for TDengine.
// Uses taosWS driver via WebSocket (通常通过 taosAdapter 提供服务)。
type TDengineDB struct {
conn *sql.DB
pingTimeout time.Duration
forwarder *ssh.LocalForwarder
}
func (t *TDengineDB) getDSN(config connection.ConnectionConfig) string {
user := strings.TrimSpace(config.User)
if user == "" {
user = "root"
}
pass := config.Password
dbName := strings.TrimSpace(config.Database)
path := "/"
if dbName != "" {
path = "/" + dbName
}
netType := resolveTDengineNet(config)
params := url.Values{}
mergeConnectionParamsFromConfigWithAllowlist(params, config, tdengineConnectionParamNames, "taos", "taosws", "tdengine")
query := params.Encode()
dsn := fmt.Sprintf("%s:%s@%s(%s)%s", user, pass, netType, net.JoinHostPort(config.Host, strconv.Itoa(config.Port)), path)
if query == "" {
return dsn
}
return dsn + "?" + query
}
func (t *TDengineDB) Connect(config connection.ConnectionConfig) error {
runConfig := config
if config.UseSSH {
logger.Infof("TDengine 使用 SSH 连接:地址=%s:%d 用户=%s", config.Host, config.Port, config.User)
forwarder, err := ssh.GetOrCreateLocalForwarder(config.SSH, config.Host, config.Port)
if err != nil {
return fmt.Errorf("创建 SSH 隧道失败:%w", err)
}
t.forwarder = forwarder
host, portStr, err := net.SplitHostPort(forwarder.LocalAddr)
if err != nil {
return fmt.Errorf("解析本地转发地址失败:%w", err)
}
port, err := strconv.Atoi(portStr)
if err != nil {
return fmt.Errorf("解析本地端口失败:%w", err)
}
localConfig := config
localConfig.Host = host
localConfig.Port = port
localConfig.UseSSH = false
runConfig = localConfig
logger.Infof("TDengine 通过本地端口转发连接:%s -> %s:%d", forwarder.LocalAddr, config.Host, config.Port)
}
attempts := []connection.ConnectionConfig{runConfig}
if shouldTrySSLPreferredFallback(runConfig) {
attempts = append(attempts, withSSLDisabled(runConfig))
}
var failures []string
for idx, attempt := range attempts {
dsn := t.getDSN(attempt)
db, err := sql.Open("taosWS", dsn)
if err != nil {
failures = append(failures, fmt.Sprintf("第%d次连接打开失败: %v", idx+1, err))
continue
}
t.conn = db
t.pingTimeout = getConnectTimeout(attempt)
if err := t.Ping(); err != nil {
_ = db.Close()
t.conn = nil
failures = append(failures, fmt.Sprintf("第%d次连接验证失败: %v", idx+1, err))
continue
}
if idx > 0 {
logger.Warnf("TDengine SSL 优先连接失败,已回退至明文连接")
}
return nil
}
return fmt.Errorf("连接建立后验证失败:%s", strings.Join(failures, ""))
}
func (t *TDengineDB) Close() error {
if t.forwarder != nil {
if err := t.forwarder.Close(); err != nil {
logger.Warnf("关闭 TDengine SSH 端口转发失败:%v", err)
}
t.forwarder = nil
}
if t.conn != nil {
return t.conn.Close()
}
return nil
}
func (t *TDengineDB) Ping() error {
if t.conn == nil {
return fmt.Errorf("连接未打开")
}
timeout := t.pingTimeout
if timeout <= 0 {
timeout = 5 * time.Second
}
ctx, cancel := utils.ContextWithTimeout(timeout)
defer cancel()
return t.conn.PingContext(ctx)
}
func (t *TDengineDB) QueryContext(ctx context.Context, query string) ([]map[string]interface{}, []string, error) {
if t.conn == nil {
return nil, nil, fmt.Errorf("连接未打开")
}
rows, err := t.conn.QueryContext(ctx, query)
if err != nil {
return nil, nil, err
}
defer rows.Close()
return scanRows(rows)
}
func (t *TDengineDB) Query(query string) ([]map[string]interface{}, []string, error) {
if t.conn == nil {
return nil, nil, fmt.Errorf("连接未打开")
}
rows, err := t.conn.Query(query)
if err != nil {
return nil, nil, err
}
defer rows.Close()
return scanRows(rows)
}
func (t *TDengineDB) ExecContext(ctx context.Context, query string) (int64, error) {
if t.conn == nil {
return 0, fmt.Errorf("连接未打开")
}
res, err := t.conn.ExecContext(ctx, query)
if err != nil {
return 0, err
}
return res.RowsAffected()
}
func (t *TDengineDB) Exec(query string) (int64, error) {
if t.conn == nil {
return 0, fmt.Errorf("连接未打开")
}
res, err := t.conn.Exec(query)
if err != nil {
return 0, err
}
return res.RowsAffected()
}
func (t *TDengineDB) GetDatabases() ([]string, error) {
data, _, err := t.Query("SHOW DATABASES")
if err != nil {
return nil, err
}
var dbs []string
for _, row := range data {
if val, ok := getValueFromRow(row, "name", "database", "Database", "db_name"); ok {
dbs = append(dbs, fmt.Sprintf("%v", val))
continue
}
for _, val := range row {
dbs = append(dbs, fmt.Sprintf("%v", val))
break
}
}
return dbs, nil
}
func (t *TDengineDB) GetTables(dbName string) ([]string, error) {
queries := tdengineShowTablesQueries(dbName)
var lastErr error
tableSet := make(map[string]struct{})
tables := make([]string, 0)
for _, query := range queries {
data, _, err := t.Query(query)
if err != nil {
lastErr = err
continue
}
for _, row := range data {
if val, ok := getValueFromRow(row, "table_name", "tablename", "name", "Table", "table"); ok {
tableName := strings.TrimSpace(fmt.Sprintf("%v", val))
if tableName == "" {
continue
}
if _, exists := tableSet[tableName]; exists {
continue
}
tableSet[tableName] = struct{}{}
tables = append(tables, tableName)
continue
}
for _, val := range row {
tableName := strings.TrimSpace(fmt.Sprintf("%v", val))
if tableName == "" {
break
}
if _, exists := tableSet[tableName]; exists {
break
}
tableSet[tableName] = struct{}{}
tables = append(tables, tableName)
break
}
}
}
if len(tables) > 0 {
sort.Strings(tables)
return tables, nil
}
if lastErr != nil {
return nil, lastErr
}
return []string{}, nil
}
func (t *TDengineDB) GetCreateStatement(dbName, tableName string) (string, error) {
queries := tdengineCreateStatementQueries(dbName, tableName)
var lastErr error
for _, query := range queries {
data, _, err := t.Query(query)
if err != nil {
lastErr = err
continue
}
if len(data) == 0 {
continue
}
row := data[0]
if val, ok := getValueFromRow(row, "Create Table", "create table", "Create Stable", "create stable", "SQL", "sql"); ok {
return fmt.Sprintf("%v", val), nil
}
longest := ""
for _, val := range row {
text := fmt.Sprintf("%v", val)
if strings.Contains(strings.ToUpper(text), "CREATE ") && len(text) > len(longest) {
longest = text
}
}
if longest != "" {
return longest, nil
}
}
if lastErr != nil {
return "", lastErr
}
return "", fmt.Errorf("未找到建表语句")
}
func (t *TDengineDB) GetColumns(dbName, tableName string) ([]connection.ColumnDefinition, error) {
var (
data []map[string]interface{}
err error
lastErr error
)
for _, query := range tdengineDescribeQueries(dbName, tableName) {
data, _, err = t.Query(query)
if err == nil {
break
}
lastErr = err
if !isTDengineSyntaxCompatibilityError(err) {
return nil, err
}
}
if err != nil {
if lastErr != nil {
return nil, lastErr
}
return nil, err
}
columns := make([]connection.ColumnDefinition, 0, len(data))
for _, row := range data {
name, _ := getValueFromRow(row, "Field", "field", "col_name", "column_name", "name")
colType, _ := getValueFromRow(row, "Type", "type", "data_type")
note, _ := getValueFromRow(row, "Note", "note", "Extra", "extra")
nullable, okNull := getValueFromRow(row, "Null", "null", "nullable")
comment, _ := getValueFromRow(row, "Comment", "comment")
defaultVal, hasDefault := getValueFromRow(row, "Default", "default")
col := connection.ColumnDefinition{
Name: fmt.Sprintf("%v", name),
Type: fmt.Sprintf("%v", colType),
Nullable: "YES",
Key: "",
Extra: fmt.Sprintf("%v", note),
Comment: fmt.Sprintf("%v", comment),
}
if okNull {
col.Nullable = strings.ToUpper(fmt.Sprintf("%v", nullable))
}
noteUpper := strings.ToUpper(fmt.Sprintf("%v", note))
if strings.Contains(noteUpper, "TAG") {
col.Key = "TAG"
}
if hasDefault && defaultVal != nil {
def := fmt.Sprintf("%v", defaultVal)
if def != "<nil>" {
col.Default = &def
}
}
columns = append(columns, col)
}
return columns, nil
}
func (t *TDengineDB) GetAllColumns(dbName string) ([]connection.ColumnDefinitionWithTable, error) {
if strings.TrimSpace(dbName) == "" {
return nil, fmt.Errorf("获取全部列信息需要指定数据库名称")
}
tables, err := t.GetTables(dbName)
if err != nil {
return nil, err
}
cols := make([]connection.ColumnDefinitionWithTable, 0)
for _, table := range tables {
tableCols, err := t.GetColumns(dbName, table)
if err != nil {
continue
}
for _, col := range tableCols {
cols = append(cols, connection.ColumnDefinitionWithTable{
TableName: table,
Name: col.Name,
Type: col.Type,
Comment: col.Comment,
})
}
}
return cols, nil
}
func (t *TDengineDB) GetIndexes(dbName, tableName string) ([]connection.IndexDefinition, error) {
return []connection.IndexDefinition{}, nil
}
func (t *TDengineDB) GetForeignKeys(dbName, tableName string) ([]connection.ForeignKeyDefinition, error) {
return []connection.ForeignKeyDefinition{}, nil
}
func (t *TDengineDB) GetTriggers(dbName, tableName string) ([]connection.TriggerDefinition, error) {
return []connection.TriggerDefinition{}, nil
}
func (t *TDengineDB) ApplyChanges(tableName string, changes connection.ChangeSet) error {
if t.conn == nil {
return fmt.Errorf("连接未打开")
}
if strings.TrimSpace(tableName) == "" {
return fmt.Errorf("表名不能为空")
}
if len(changes.Updates) > 0 || len(changes.Deletes) > 0 {
return fmt.Errorf("TDengine 目标端当前仅支持 INSERT 写入,暂不支持 UPDATE/DELETE 差异同步,请改用仅插入或全量覆盖模式")
}
qualifiedTable := quoteTDengineTable("", tableName)
return execTDengineInsertBatches(t.conn, qualifiedTable, changes.Inserts)
}
func execTDengineInsertBatches(conn *sql.DB, qualifiedTable string, rows []map[string]interface{}) error {
if conn == nil {
return fmt.Errorf("连接未打开")
}
return execLiteralInsertBatches(literalInsertConfig{
Table: qualifiedTable,
Rows: rows,
QuoteColumn: func(column string) string {
return fmt.Sprintf("`%s`", escapeBacktickIdent(column))
},
Literal: tdengineLiteral,
Exec: func(query string) (sql.Result, error) {
return conn.Exec(query)
},
})
}
func buildTDengineInsertSQL(qualifiedTable string, row map[string]interface{}) (string, error) {
if strings.TrimSpace(qualifiedTable) == "" {
return "", fmt.Errorf("需要指定完整的表名")
}
if len(row) == 0 {
return "", nil
}
cols := make([]string, 0, len(row))
for key := range row {
if strings.TrimSpace(key) == "" {
continue
}
cols = append(cols, key)
}
if len(cols) == 0 {
return "", nil
}
sort.Strings(cols)
quotedCols := make([]string, 0, len(cols))
values := make([]string, 0, len(cols))
for _, col := range cols {
quotedCols = append(quotedCols, fmt.Sprintf("`%s`", escapeBacktickIdent(col)))
values = append(values, tdengineLiteral(row[col]))
}
return fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", qualifiedTable, strings.Join(quotedCols, ", "), strings.Join(values, ", ")), nil
}
func tdengineLiteral(value interface{}) string {
switch val := value.(type) {
case nil:
return "NULL"
case bool:
if val {
return "1"
}
return "0"
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, float32, float64:
return fmt.Sprintf("%v", val)
case time.Time:
return fmt.Sprintf("'%s'", val.Format("2006-01-02 15:04:05"))
case []byte:
return fmt.Sprintf("'%s'", strings.ReplaceAll(string(val), "'", "''"))
default:
return fmt.Sprintf("'%s'", strings.ReplaceAll(fmt.Sprintf("%v", val), "'", "''"))
}
}
func getValueFromRow(row map[string]interface{}, keys ...string) (interface{}, bool) {
if len(row) == 0 {
return nil, false
}
for _, key := range keys {
if val, ok := row[key]; ok {
return val, true
}
}
for existingKey, val := range row {
for _, key := range keys {
if strings.EqualFold(existingKey, key) {
return val, true
}
}
}
return nil, false
}
func escapeBacktickIdent(ident string) string {
return strings.ReplaceAll(strings.TrimSpace(ident), "`", "``")
}
func tdengineShowTablesQueries(dbName string) []string {
queries := make([]string, 0, 6)
appendQuery := func(query string) {
query = strings.TrimSpace(query)
if query == "" {
return
}
for _, existing := range queries {
if existing == query {
return
}
}
queries = append(queries, query)
}
db := strings.TrimSpace(dbName)
if db != "" {
escaped := escapeBacktickIdent(db)
appendQuery(fmt.Sprintf("SHOW TABLES FROM `%s`", escaped))
appendQuery(fmt.Sprintf("SHOW STABLES FROM `%s`", escaped))
appendQuery(fmt.Sprintf("SHOW TABLES FROM %s", db))
appendQuery(fmt.Sprintf("SHOW STABLES FROM %s", db))
}
appendQuery("SHOW TABLES")
appendQuery("SHOW STABLES")
return queries
}
func tdengineDescribeQueries(dbName, tableName string) []string {
qualified := quoteTDengineTable(dbName, tableName)
legacyQualified := quoteTDengineTableLegacy(dbName, tableName)
queries := []string{fmt.Sprintf("DESCRIBE %s", qualified)}
if legacyQualified != qualified {
queries = append(queries, fmt.Sprintf("DESCRIBE %s", legacyQualified))
}
return queries
}
func tdengineCreateStatementQueries(dbName, tableName string) []string {
queries := make([]string, 0, 4)
appendQualifiedQueries := func(qualified string) {
if strings.TrimSpace(qualified) == "" {
return
}
queries = append(queries,
fmt.Sprintf("SHOW CREATE TABLE %s", qualified),
fmt.Sprintf("SHOW CREATE STABLE %s", qualified),
)
}
qualified := quoteTDengineTable(dbName, tableName)
appendQualifiedQueries(qualified)
legacyQualified := quoteTDengineTableLegacy(dbName, tableName)
if legacyQualified != qualified {
appendQualifiedQueries(legacyQualified)
}
return queries
}
func quoteTDengineTableLegacy(dbName, tableName string) string {
table := strings.TrimSpace(tableName)
if table == "" {
return ""
}
if strings.Contains(table, ".") {
return strings.Join(splitTDengineIdentifierParts(table), ".")
}
db := strings.TrimSpace(dbName)
if db == "" {
return table
}
return db + "." + table
}
func splitTDengineIdentifierParts(path string) []string {
parts := strings.Split(strings.TrimSpace(path), ".")
result := make([]string, 0, len(parts))
for _, part := range parts {
trimmed := strings.Trim(strings.TrimSpace(part), "`")
if trimmed == "" {
continue
}
result = append(result, trimmed)
}
return result
}
func isTDengineSyntaxCompatibilityError(err error) bool {
if err == nil {
return false
}
text := strings.ToLower(strings.TrimSpace(err.Error()))
if text == "" {
return false
}
return strings.Contains(text, "syntax error near") ||
strings.Contains(text, "[0x2600]") ||
errors.Is(err, sql.ErrNoRows)
}
func quoteTDengineTable(dbName, tableName string) string {
t := escapeBacktickIdent(tableName)
if t == "" {
return "``"
}
if strings.Contains(t, ".") {
parts := strings.Split(t, ".")
quoted := make([]string, 0, len(parts))
for _, part := range parts {
part = strings.TrimSpace(part)
if part == "" {
continue
}
quoted = append(quoted, fmt.Sprintf("`%s`", escapeBacktickIdent(part)))
}
if len(quoted) > 0 {
return strings.Join(quoted, ".")
}
}
db := escapeBacktickIdent(dbName)
if db == "" {
return fmt.Sprintf("`%s`", t)
}
return fmt.Sprintf("`%s`.`%s`", db, t)
}