Files
MyGoNavi/internal/db/trino_impl.go
Syngnat 8ea7ecc477 feat(trino): 新增 Trino 可选驱动接入并补齐查询支持
- 后端新增 Trino 数据库实现与 optional driver-agent provider
- 前端补齐 catalog.schema 连接配置、URI 解析与能力开关
- SQL 编辑器对 Trino 禁用托管事务并补充前后端测试
2026-06-21 13:54:42 +08:00

662 lines
18 KiB
Go

//go:build gonavi_full_drivers || gonavi_trino_driver
package db
import (
"context"
"database/sql"
"fmt"
"net"
"net/http"
"net/url"
"sort"
"strconv"
"strings"
"time"
"GoNavi-Wails/internal/connection"
"GoNavi-Wails/internal/logger"
"GoNavi-Wails/internal/ssh"
trinodriver "github.com/trinodb/trino-go-client/trino"
)
// Defaults applied when the connection settings leave the matching value unset.
const (
// defaultTrinoPort is the port used when the configuration carries no positive port.
defaultTrinoPort = 8080
// defaultTrinoSource is written to the DSN "source" parameter when none is configured.
defaultTrinoSource = "GoNavi"
)
// TrinoDB holds the state of one Trino connection opened through database/sql.
type TrinoDB struct {
// conn is the database/sql handle; Connect assigns it and Close resets it to nil.
conn *sql.DB
// pingTimeout bounds the validation query issued by Ping; Ping falls back to
// 10 seconds when it is not positive.
pingTimeout time.Duration
// forwarder is the SSH local port forwarder, set only when the connection uses SSH.
forwarder *ssh.LocalForwarder
// namespace is the configured default database text; other methods split it
// into catalog and schema on the first dot.
namespace string
// customClientName is the key under which the HTTP client was registered with
// the Trino driver; Close deregisters it.
customClientName string
}
// normalizeTrinoConfig resolves URI-style input (the URI field first, then the
// Host field) and afterwards fills in "localhost" and the default port when the
// host is blank or the port is not positive.
func normalizeTrinoConfig(config connection.ConnectionConfig) connection.ConnectionConfig {
	cfg := applyTrinoHostURI(applyTrinoURI(config))
	if host := strings.TrimSpace(cfg.Host); len(host) == 0 {
		cfg.Host = "localhost"
	}
	if cfg.Port < 1 {
		cfg.Port = defaultTrinoPort
	}
	return cfg
}
// applyTrinoURI merges the endpoint described by config.URI into config.
// The host and port from the URI are only used when config has no host yet.
func applyTrinoURI(config connection.ConnectionConfig) connection.ConnectionConfig {
	const fromHostField = false
	return applyTrinoEndpointURI(config, config.URI, fromHostField)
}
// applyTrinoHostURI handles a full URI typed into the Host field: when
// config.Host parses as a supported URI, its host and port replace the field.
func applyTrinoHostURI(config connection.ConnectionConfig) connection.ConnectionConfig {
	const fromHostField = true
	return applyTrinoEndpointURI(config, config.Host, fromHostField)
}
// applyTrinoEndpointURI folds the information carried by a trino://, http:// or
// https:// URI into config and returns the updated copy.
//
// config is returned unchanged when raw is blank, cannot be parsed, uses another
// scheme, or has no host. Otherwise:
//   - user and password are taken from the URI only when config has none;
//   - the "catalog" and "schema" query values become config.Database when that
//     field is blank;
//   - an https scheme turns on SSL and raises a blank/disabled SSL mode to required;
//   - host and port are taken from the URI when fromHostField is true or
//     config.Host is blank;
//   - the URI query and the existing connection parameters are merged (URI
//     first, then the configured ones) and re-encoded into ConnectionParams.
func applyTrinoEndpointURI(config connection.ConnectionConfig, raw string, fromHostField bool) connection.ConnectionConfig {
	text := strings.TrimSpace(raw)
	if len(text) == 0 {
		return config
	}
	endpoint, parseErr := url.Parse(text)
	if parseErr != nil {
		return config
	}

	scheme := strings.ToLower(strings.TrimSpace(endpoint.Scheme))
	switch scheme {
	case "trino", "http", "https":
		// Supported endpoint schemes.
	default:
		return config
	}
	if strings.TrimSpace(endpoint.Host) == "" {
		return config
	}

	// Credentials embedded in the URI never override explicitly configured ones.
	if info := endpoint.User; info != nil {
		if strings.TrimSpace(config.User) == "" {
			config.User = info.Username()
		}
		if secret, hasSecret := info.Password(); hasSecret && config.Password == "" {
			config.Password = secret
		}
	}

	query := endpoint.Query()
	catalog := strings.TrimSpace(query.Get("catalog"))
	schema := strings.TrimSpace(query.Get("schema"))

	merged := url.Values{}
	mergeConnectionParamValues(merged, query)
	mergeConnectionParamValues(merged, connectionParamsFromText(config.ConnectionParams))

	if strings.TrimSpace(config.Database) == "" {
		config.Database = joinTrinoNamespace(catalog, schema)
	}

	if scheme == "https" {
		config.UseSSL = true
		if normalizeSSLModeValue(config.SSLMode) == sslModeDisable || strings.TrimSpace(config.SSLMode) == "" {
			config.SSLMode = sslModeRequired
		}
	}

	fallbackPort := config.Port
	if fallbackPort <= 0 {
		fallbackPort = defaultTrinoPort
	}
	if fromHostField || strings.TrimSpace(config.Host) == "" {
		if host, port, ok := parseHostPortWithDefault(endpoint.Host, fallbackPort); ok {
			config.Host, config.Port = host, port
		}
	}
	if config.Port <= 0 {
		config.Port = fallbackPort
	}

	config.ConnectionParams = merged.Encode()
	return config
}
// splitTrinoNamespace breaks "catalog.schema" text at its first dot and returns
// both halves with surrounding whitespace removed. Text without a dot is
// returned as the catalog with an empty schema; blank text yields two empty strings.
func splitTrinoNamespace(raw string) (string, string) {
	trimmed := strings.TrimSpace(raw)
	if trimmed == "" {
		return "", ""
	}
	dot := strings.Index(trimmed, ".")
	if dot < 0 {
		return trimmed, ""
	}
	return strings.TrimSpace(trimmed[:dot]), strings.TrimSpace(trimmed[dot+1:])
}
// joinTrinoNamespace combines a catalog and a schema into "catalog.schema".
// Both parts are trimmed first; when one of them is empty the other is returned
// alone, without a dot.
func joinTrinoNamespace(catalog, schema string) string {
	left, right := strings.TrimSpace(catalog), strings.TrimSpace(schema)
	if left == "" {
		return right
	}
	if right == "" {
		return left
	}
	return left + "." + right
}
// resolveTrinoNamespace splits raw into catalog and schema, and falls back to
// splitting fallback when raw yields neither part.
func resolveTrinoNamespace(raw string, fallback string) (string, string) {
	if catalog, schema := splitTrinoNamespace(raw); !(catalog == "" && schema == "") {
		return catalog, schema
	}
	return splitTrinoNamespace(fallback)
}
// quoteTrinoIdentifier trims ident and wraps it in double quotes, doubling any
// double quote it contains.
func quoteTrinoIdentifier(ident string) string {
	const quote = `"`
	escaped := strings.ReplaceAll(strings.TrimSpace(ident), quote, quote+quote)
	return quote + escaped + quote
}
// quoteTrinoQualifiedTable builds a dot-separated, quoted table reference.
// Blank catalog or schema parts are left out; the table part is always emitted.
func quoteTrinoQualifiedTable(catalog, schema, table string) string {
	var b strings.Builder
	for _, part := range []string{catalog, schema} {
		if name := strings.TrimSpace(part); name != "" {
			b.WriteString(quoteTrinoIdentifier(name))
			b.WriteByte('.')
		}
	}
	b.WriteString(quoteTrinoIdentifier(table))
	return b.String()
}
// escapeTrinoSQLLiteral trims value and renders it as a single-quoted SQL string
// literal, doubling any single quote it contains.
func escapeTrinoSQLLiteral(value string) string {
	const quote = "'"
	escaped := strings.ReplaceAll(strings.TrimSpace(value), quote, quote+quote)
	return quote + escaped + quote
}
// trinoRowValue looks up the first of keys that is present in row, comparing
// column names case-insensitively and ignoring surrounding whitespace. Keys are
// tried in the order given. The boolean reports whether any key matched.
func trinoRowValue(row map[string]interface{}, keys ...string) (interface{}, bool) {
	if len(row) == 0 {
		return nil, false
	}
	for _, wanted := range keys {
		target := strings.TrimSpace(wanted)
		for name, cell := range row {
			if strings.EqualFold(strings.TrimSpace(name), target) {
				return cell, true
			}
		}
	}
	return nil, false
}
// trinoRowString returns the trimmed text form of the first matching column in
// row. It yields "" when no key matches, when the value is nil, or when the
// value formats as "<nil>".
func trinoRowString(row map[string]interface{}, keys ...string) string {
	value, found := trinoRowValue(row, keys...)
	if value == nil || !found {
		return ""
	}
	if text := strings.TrimSpace(fmt.Sprintf("%v", value)); !strings.EqualFold(text, "<nil>") {
		return text
	}
	return ""
}
// firstTrinoMapValueAsString returns the trimmed text of the first column in row
// whose value does not format as "<nil>", or "" when there is none.
//
// Columns are visited in ascending key order. Go randomizes map iteration, so
// ranging over row directly would make the result vary between calls whenever a
// row carries more than one usable column. Single-column rows are unaffected.
func firstTrinoMapValueAsString(row map[string]interface{}) string {
	keys := make([]string, 0, len(row))
	for key := range row {
		keys = append(keys, key)
	}
	sort.Strings(keys)

	for _, key := range keys {
		text := strings.TrimSpace(fmt.Sprintf("%v", row[key]))
		if !strings.EqualFold(text, "<nil>") {
			return text
		}
	}
	return ""
}
// firstTrinoRowValueAsString applies firstTrinoMapValueAsString to the first row
// of data and returns "" when data has no rows.
func firstTrinoRowValueAsString(data []map[string]interface{}) string {
	if len(data) > 0 {
		return firstTrinoMapValueAsString(data[0])
	}
	return ""
}
// buildTrinoHTTPClient creates the HTTP client handed to the Trino driver.
// The dial timeout comes from the connection settings, proxies are taken from
// the environment, and a TLS configuration is attached when
// resolveGenericTLSConfig returns one. No overall client timeout is set.
func (t *TrinoDB) buildTrinoHTTPClient(config connection.ConnectionConfig) (*http.Client, error) {
	dialer := &net.Dialer{
		Timeout:   getConnectTimeout(config),
		KeepAlive: 30 * time.Second,
	}
	transport := &http.Transport{
		Proxy:                 http.ProxyFromEnvironment,
		DialContext:           dialer.DialContext,
		MaxIdleConns:          32,
		IdleConnTimeout:       90 * time.Second,
		TLSHandshakeTimeout:   10 * time.Second,
		ExpectContinueTimeout: time.Second,
	}

	tlsConfig, err := resolveGenericTLSConfig(config)
	switch {
	case err != nil:
		return nil, err
	case tlsConfig != nil:
		transport.TLSClientConfig = tlsConfig
	}
	return &http.Client{Transport: transport}, nil
}
// registerTrinoCustomClient builds the HTTP client for config and registers it
// with the Trino driver, returning the registry key to reference from the DSN.
//
// The key combines the receiver's address with the current time. A timestamp
// alone is not guaranteed to be unique: two connections opened within the same
// clock tick would get the same key, and closing one would then deregister the
// client of the other. The address keeps keys of distinct live TrinoDB values
// apart.
func (t *TrinoDB) registerTrinoCustomClient(config connection.ConnectionConfig) (string, error) {
	client, err := t.buildTrinoHTTPClient(config)
	if err != nil {
		return "", err
	}
	name := fmt.Sprintf("gonavi-trino-%p-%d", t, time.Now().UnixNano())
	if err := trinodriver.RegisterCustomClient(name, client); err != nil {
		return "", err
	}
	return name, nil
}
// buildTrinoDSN renders the connection settings as a Trino DSN URL.
//
// A user name is mandatory, and a password is only accepted together with SSL.
// The configured connection parameters are kept; catalog and schema from
// config.Database override them, and "source", "explicitPrepare" and
// "query_timeout" receive defaults when blank. A non-blank customClientName is
// passed on as the "custom_client" parameter.
func buildTrinoDSN(config connection.ConnectionConfig, customClientName string) (string, error) {
	user := strings.TrimSpace(config.User)
	if user == "" {
		return "", fmt.Errorf("Trino 用户名不能为空")
	}

	hasPassword := config.Password != ""
	if hasPassword && !config.UseSSL {
		return "", fmt.Errorf("Trino 启用密码认证时必须使用 HTTPS")
	}
	scheme := "http"
	if config.UseSSL {
		scheme = "https"
	}

	params := connectionParamsFromText(config.ConnectionParams)
	catalog, schema := resolveTrinoNamespace(config.Database, "")
	if catalog != "" {
		params.Set("catalog", catalog)
	}
	if schema != "" {
		params.Set("schema", schema)
	}

	isBlank := func(key string) bool {
		return strings.TrimSpace(params.Get(key)) == ""
	}
	if isBlank("source") {
		params.Set("source", defaultTrinoSource)
	}
	if isBlank("explicitPrepare") {
		params.Set("explicitPrepare", "false")
	}
	if isBlank("query_timeout") {
		params.Set("query_timeout", fmt.Sprintf("%ds", getConnectTimeoutSeconds(config)))
	}
	if clientName := strings.TrimSpace(customClientName); clientName != "" {
		params.Set("custom_client", clientName)
	}

	credentials := url.User(user)
	if hasPassword {
		credentials = url.UserPassword(user, config.Password)
	}
	endpoint := url.URL{
		Scheme:   scheme,
		User:     credentials,
		Host:     net.JoinHostPort(strings.TrimSpace(config.Host), strconv.Itoa(config.Port)),
		RawQuery: params.Encode(),
	}
	return endpoint.String(), nil
}
// Close releases everything Connect acquired: the database handle, the SSH
// port forwarder and the custom HTTP client registered with the Trino driver.
//
// All resources are released even when closing the database handle fails, so a
// failing handle can no longer leave the forwarder or the registered client
// behind. The error from closing the handle, if any, is returned after the
// cleanup. A forwarder close failure is only logged. Calling Close on an
// unopened or already closed connection is a no-op that returns nil.
func (t *TrinoDB) Close() error {
	var closeErr error
	if t.conn != nil {
		closeErr = t.conn.Close()
		// Drop the handle regardless of the outcome so stale state is not reused.
		t.conn = nil
	}
	if t.forwarder != nil {
		if err := t.forwarder.Close(); err != nil {
			logger.Warnf("关闭 Trino SSH 端口转发失败：%v", err)
		}
		t.forwarder = nil
	}
	if t.customClientName != "" {
		trinodriver.DeregisterCustomClient(t.customClientName)
		t.customClientName = ""
	}
	t.namespace = ""
	return closeErr
}
// Connect opens a Trino connection for config, replacing any previous one.
//
// The configuration is normalized first. When SSH is enabled, a local port
// forwarder to the normalized host and port is obtained and the connection is
// pointed at its local address. A custom HTTP client is then registered with
// the driver, the DSN is built, and the connection is validated with Ping.
// Every failure path calls Close so that no partially acquired resource is kept.
func (t *TrinoDB) Connect(config connection.ConnectionConfig) error {
	// Discard previous connection state; a close failure must not block reconnecting.
	_ = t.Close()

	runConfig := normalizeTrinoConfig(config)
	t.pingTimeout = getConnectTimeout(runConfig)

	if runConfig.UseSSH {
		// Remember the real endpoint before it is replaced by the forwarder's
		// local address; the log line below must show the tunneled target, which
		// may differ from the raw (un-normalized) config host and port.
		targetHost, targetPort := runConfig.Host, runConfig.Port

		forwarder, err := ssh.GetOrCreateLocalForwarder(runConfig.SSH, targetHost, targetPort)
		if err != nil {
			return fmt.Errorf("创建 SSH 隧道失败：%w", err)
		}
		t.forwarder = forwarder

		host, portText, err := net.SplitHostPort(forwarder.LocalAddr)
		if err != nil {
			_ = t.Close()
			return fmt.Errorf("解析本地转发地址失败：%w", err)
		}
		port, err := strconv.Atoi(portText)
		if err != nil {
			_ = t.Close()
			return fmt.Errorf("解析本地端口失败：%w", err)
		}
		runConfig.Host = host
		runConfig.Port = port
		runConfig.UseSSH = false
		logger.Infof("Trino 通过本地端口转发连接：%s -> %s:%d", forwarder.LocalAddr, targetHost, targetPort)
	}

	customClientName, err := t.registerTrinoCustomClient(runConfig)
	if err != nil {
		_ = t.Close()
		return fmt.Errorf("注册 Trino 自定义 HTTP 客户端失败：%w", err)
	}
	t.customClientName = customClientName

	dsn, err := buildTrinoDSN(runConfig, customClientName)
	if err != nil {
		_ = t.Close()
		return err
	}
	conn, err := sql.Open("trino", dsn)
	if err != nil {
		_ = t.Close()
		return err
	}
	t.conn = conn
	t.namespace = strings.TrimSpace(runConfig.Database)

	// sql.Open does not contact the server; Ping verifies the endpoint is usable.
	if err := t.Ping(); err != nil {
		_ = t.Close()
		return err
	}
	return nil
}
// Ping verifies the connection by running "SELECT 1" and reading its single
// row. The query is bounded by pingTimeout, or by 10 seconds when that is not
// positive. An error is returned when the connection is not open, the query
// fails, or no row comes back.
func (t *TrinoDB) Ping() error {
	if t.conn == nil {
		return fmt.Errorf("连接未打开")
	}

	limit := 10 * time.Second
	if t.pingTimeout > 0 {
		limit = t.pingTimeout
	}
	ctx, cancel := context.WithTimeout(context.Background(), limit)
	defer cancel()

	rows, err := t.conn.QueryContext(ctx, "SELECT 1")
	if err != nil {
		return err
	}
	defer rows.Close()

	if rows.Next() {
		var probe sql.NullInt64
		if scanErr := rows.Scan(&probe); scanErr != nil {
			return scanErr
		}
		return rows.Err()
	}
	if iterErr := rows.Err(); iterErr != nil {
		return iterErr
	}
	return fmt.Errorf("连接查询验证未返回结果")
}
// QueryContext runs query under ctx and returns what scanRows produces for the
// result set. It fails when the connection is not open.
func (t *TrinoDB) QueryContext(ctx context.Context, query string) ([]map[string]interface{}, []string, error) {
	if t.conn == nil {
		return nil, nil, fmt.Errorf("连接未打开")
	}
	rows, queryErr := t.conn.QueryContext(ctx, query)
	if queryErr != nil {
		return nil, nil, queryErr
	}
	defer rows.Close()

	data, columns, scanErr := scanRows(rows)
	return data, columns, scanErr
}
// Query is QueryContext with a background context.
func (t *TrinoDB) Query(query string) ([]map[string]interface{}, []string, error) {
	ctx := context.Background()
	return t.QueryContext(ctx, query)
}
// StreamQueryContext runs query under ctx and passes the result set to
// streamRows together with consumer. It fails when the connection is not open.
func (t *TrinoDB) StreamQueryContext(ctx context.Context, query string, consumer QueryStreamConsumer) error {
	if t.conn == nil {
		return fmt.Errorf("连接未打开")
	}
	rows, queryErr := t.conn.QueryContext(ctx, query)
	if queryErr != nil {
		return queryErr
	}
	defer rows.Close()

	streamErr := streamRows(rows, consumer)
	return streamErr
}
// StreamQuery is StreamQueryContext with a background context.
func (t *TrinoDB) StreamQuery(query string, consumer QueryStreamConsumer) error {
	ctx := context.Background()
	return t.StreamQueryContext(ctx, query, consumer)
}
// ExecContext executes a statement under ctx and returns the number of affected
// rows. When the driver cannot report that count, 0 is returned without an
// error, since the statement itself has already succeeded.
func (t *TrinoDB) ExecContext(ctx context.Context, query string) (int64, error) {
	if t.conn == nil {
		return 0, fmt.Errorf("连接未打开")
	}
	result, execErr := t.conn.ExecContext(ctx, query)
	if execErr != nil {
		return 0, execErr
	}
	if affected, countErr := result.RowsAffected(); countErr == nil {
		return affected, nil
	}
	return 0, nil
}
// Exec is ExecContext with a background context.
func (t *TrinoDB) Exec(query string) (int64, error) {
	ctx := context.Background()
	return t.ExecContext(ctx, query)
}
// queryTrinoSingleColumnStrings runs query and collects one text value per row
// (via firstTrinoMapValueAsString), skipping rows that yield an empty string.
func (t *TrinoDB) queryTrinoSingleColumnStrings(query string) ([]string, error) {
	rows, _, err := t.Query(query)
	if err != nil {
		return nil, err
	}
	values := make([]string, 0, len(rows))
	for i := range rows {
		if text := firstTrinoMapValueAsString(rows[i]); text != "" {
			values = append(values, text)
		}
	}
	return values, nil
}
// GetDatabases lists every "catalog.schema" namespace, sorted.
//
// It enumerates the catalogs and then the schemas of each catalog. Catalogs
// whose schemas cannot be listed are skipped. When nothing can be listed, the
// connection's own namespace is returned if one is set; otherwise the
// catalog-listing error or the last schema-listing error is reported.
func (t *TrinoDB) GetDatabases() ([]string, error) {
	hasDefault := strings.TrimSpace(t.namespace) != ""

	catalogs, err := t.queryTrinoSingleColumnStrings("SHOW CATALOGS")
	if err != nil {
		if hasDefault {
			return []string{t.namespace}, nil
		}
		return nil, err
	}

	var lastErr error
	found := make([]string, 0, len(catalogs)*2)
	// Namespaces are de-duplicated case-insensitively.
	known := make(map[string]struct{}, len(catalogs)*4)
	for _, catalog := range catalogs {
		schemas, listErr := t.queryTrinoSingleColumnStrings(
			fmt.Sprintf("SHOW SCHEMAS FROM %s", quoteTrinoIdentifier(catalog)),
		)
		if listErr != nil {
			lastErr = listErr
			continue
		}
		for _, schema := range schemas {
			full := joinTrinoNamespace(catalog, schema)
			if full == "" {
				continue
			}
			lowered := strings.ToLower(full)
			if _, dup := known[lowered]; dup {
				continue
			}
			known[lowered] = struct{}{}
			found = append(found, full)
		}
	}

	if len(found) == 0 {
		switch {
		case hasDefault:
			return []string{t.namespace}, nil
		case lastErr != nil:
			return nil, lastErr
		}
	}
	sort.Strings(found)
	return found, nil
}
// GetTables returns the sorted table names of the namespace given by dbName,
// or of the connection's namespace when dbName is blank. Both a catalog and a
// schema are required.
func (t *TrinoDB) GetTables(dbName string) ([]string, error) {
	catalog, schema := resolveTrinoNamespace(dbName, t.namespace)
	if !(catalog != "" && schema != "") {
		return nil, fmt.Errorf("Trino 默认命名空间必须使用 catalog.schema")
	}
	statement := "SHOW TABLES FROM " + quoteTrinoIdentifier(catalog) + "." + quoteTrinoIdentifier(schema)
	names, err := t.queryTrinoSingleColumnStrings(statement)
	if err != nil {
		return nil, err
	}
	sort.Strings(names)
	return names, nil
}
// GetCreateStatement returns the output of SHOW CREATE TABLE for tableName in
// the namespace given by dbName (or the connection's namespace when dbName is
// blank). An error is returned when the namespace is incomplete, the query
// fails, or no statement text comes back.
func (t *TrinoDB) GetCreateStatement(dbName, tableName string) (string, error) {
	catalog, schema := resolveTrinoNamespace(dbName, t.namespace)
	if !(catalog != "" && schema != "") {
		return "", fmt.Errorf("Trino 默认命名空间必须使用 catalog.schema")
	}
	statement := "SHOW CREATE TABLE " + quoteTrinoQualifiedTable(catalog, schema, tableName)
	rows, _, err := t.Query(statement)
	if err != nil {
		return "", err
	}
	if ddl := firstTrinoRowValueAsString(rows); ddl != "" {
		return ddl, nil
	}
	return "", fmt.Errorf("未返回建表语句")
}
// buildTrinoColumnsQuery returns the information_schema query that lists name,
// type, nullability and default of every column of one table, in ordinal
// order. The catalog is quoted as an identifier; schema and table name are
// embedded as escaped string literals.
func buildTrinoColumnsQuery(catalog, schema, tableName string) string {
	const template = `SELECT
column_name,
data_type,
is_nullable,
column_default
FROM %s.information_schema.columns
WHERE table_schema = %s AND table_name = %s
ORDER BY ordinal_position`

	catalogIdent := quoteTrinoIdentifier(catalog)
	schemaLiteral := escapeTrinoSQLLiteral(schema)
	tableLiteral := escapeTrinoSQLLiteral(tableName)
	return fmt.Sprintf(template, catalogIdent, schemaLiteral, tableLiteral)
}
// buildTrinoColumnDefinitions converts column metadata rows into column
// definitions. Nullability is upper-cased and defaults to "YES" when absent; a
// default value is only recorded when it is present and non-blank.
func buildTrinoColumnDefinitions(data []map[string]interface{}) []connection.ColumnDefinition {
	columns := make([]connection.ColumnDefinition, 0, len(data))
	for _, row := range data {
		name := trinoRowString(row, "column_name", "Column", "Field")
		dataType := trinoRowString(row, "data_type", "Type")
		nullable := strings.ToUpper(trinoRowString(row, "is_nullable", "Null"))

		definition := connection.ColumnDefinition{Name: name, Type: dataType, Nullable: nullable}
		if raw, present := trinoRowValue(row, "column_default", "Default"); present && raw != nil {
			// text is a fresh variable per row, so taking its address is safe.
			if text := strings.TrimSpace(fmt.Sprintf("%v", raw)); text != "" && !strings.EqualFold(text, "<nil>") {
				definition.Default = &text
			}
		}
		if definition.Nullable == "" {
			definition.Nullable = "YES"
		}
		columns = append(columns, definition)
	}
	return columns
}
// GetColumns returns the column definitions of tableName.
//
// The information_schema query is tried first. If it fails, DESCRIBE is used
// as a fallback; that path reports every column as nullable and skips rows
// whose name is empty or starts with "#". When both queries fail, the error of
// the information_schema query is returned.
func (t *TrinoDB) GetColumns(dbName, tableName string) ([]connection.ColumnDefinition, error) {
	catalog, schema := resolveTrinoNamespace(dbName, t.namespace)
	if !(catalog != "" && schema != "") {
		return nil, fmt.Errorf("Trino 默认命名空间必须使用 catalog.schema")
	}

	infoRows, _, infoErr := t.Query(buildTrinoColumnsQuery(catalog, schema, tableName))
	if infoErr == nil {
		return buildTrinoColumnDefinitions(infoRows), nil
	}

	described, _, describeErr := t.Query(
		fmt.Sprintf("DESCRIBE %s", quoteTrinoQualifiedTable(catalog, schema, tableName)),
	)
	if describeErr != nil {
		return nil, infoErr
	}

	result := make([]connection.ColumnDefinition, 0, len(described))
	for _, row := range described {
		columnName := trinoRowString(row, "Column", "column_name", "Field")
		if columnName == "" || strings.HasPrefix(columnName, "#") {
			continue
		}
		definition := connection.ColumnDefinition{
			Name:     columnName,
			Type:     trinoRowString(row, "Type", "data_type"),
			Nullable: "YES",
		}
		result = append(result, definition)
	}
	return result, nil
}
// GetAllColumns returns table name, column name and type for every column of
// the schema given by dbName (or the connection's namespace when dbName is
// blank), ordered by table and ordinal position.
func (t *TrinoDB) GetAllColumns(dbName string) ([]connection.ColumnDefinitionWithTable, error) {
	catalog, schema := resolveTrinoNamespace(dbName, t.namespace)
	if !(catalog != "" && schema != "") {
		return nil, fmt.Errorf("Trino 默认命名空间必须使用 catalog.schema")
	}

	const template = `SELECT
table_name,
column_name,
data_type
FROM %s.information_schema.columns
WHERE table_schema = %s
ORDER BY table_name, ordinal_position`
	statement := fmt.Sprintf(template, quoteTrinoIdentifier(catalog), escapeTrinoSQLLiteral(schema))

	rows, _, err := t.Query(statement)
	if err != nil {
		return nil, err
	}
	columns := make([]connection.ColumnDefinitionWithTable, 0, len(rows))
	for _, row := range rows {
		entry := connection.ColumnDefinitionWithTable{
			TableName: trinoRowString(row, "table_name", "TABLE_NAME"),
			Name:      trinoRowString(row, "column_name", "COLUMN_NAME"),
			Type:      trinoRowString(row, "data_type", "DATA_TYPE"),
		}
		columns = append(columns, entry)
	}
	return columns, nil
}
// GetIndexes always returns an empty, non-nil list; this implementation does
// not query Trino for indexes.
func (t *TrinoDB) GetIndexes(dbName, tableName string) ([]connection.IndexDefinition, error) {
	none := make([]connection.IndexDefinition, 0)
	return none, nil
}
// GetForeignKeys always returns an empty, non-nil list; this implementation
// does not query Trino for foreign keys.
func (t *TrinoDB) GetForeignKeys(dbName, tableName string) ([]connection.ForeignKeyDefinition, error) {
	none := make([]connection.ForeignKeyDefinition, 0)
	return none, nil
}
// GetTriggers always returns an empty, non-nil list; this implementation does
// not query Trino for triggers.
func (t *TrinoDB) GetTriggers(dbName, tableName string) ([]connection.TriggerDefinition, error) {
	none := make([]connection.TriggerDefinition, 0)
	return none, nil
}
// OpenSessionExecer reserves a single connection from the pool and wraps it in
// a StatementExecer, so that successive statements run on the same session.
// It fails when the connection is not open or no connection can be obtained.
func (t *TrinoDB) OpenSessionExecer(ctx context.Context) (StatementExecer, error) {
	if t.conn == nil {
		return nil, fmt.Errorf("连接未打开")
	}
	session, err := t.conn.Conn(ctx)
	if err != nil {
		return nil, err
	}
	execer := NewSQLConnStatementExecer(session)
	return execer, nil
}