mirror of
https://github.com/Syngnat/GoNavi.git
synced 2026-06-30 21:41:34 +08:00
✨ feat(trino): 新增 Trino 可选驱动接入并补齐查询支持
- 后端新增 Trino 数据库实现与 optional driver-agent provider - 前端补齐 catalog.schema 连接配置、URI 解析与能力开关 - SQL 编辑器对 Trino 禁用托管事务并补充前后端测试
This commit is contained in:
@@ -26,7 +26,7 @@ func normalizeRunConfig(config connection.ConnectionConfig, dbName string) conne
|
||||
if !isOceanBaseOracleProtocol(config) {
|
||||
runConfig.Database = name
|
||||
}
|
||||
case "mysql", "mariadb", "goldendb", "greatdb", "gdb", "diros", "starrocks", "sphinx", "postgres", "kingbase", "highgo", "vastbase", "opengauss", "gaussdb", "sqlserver", "iris", "intersystems", "intersystemsiris", "inter-systems", "inter-systems-iris", "mongodb", "tdengine", "iotdb", "clickhouse", "rabbitmq", "rabbit-mq", "rabbit_mq":
|
||||
case "mysql", "mariadb", "goldendb", "greatdb", "gdb", "diros", "starrocks", "sphinx", "postgres", "kingbase", "highgo", "vastbase", "opengauss", "gaussdb", "sqlserver", "iris", "intersystems", "intersystemsiris", "inter-systems", "inter-systems-iris", "mongodb", "tdengine", "iotdb", "clickhouse", "trino", "rabbitmq", "rabbit-mq", "rabbit_mq":
|
||||
// 这些类型的 dbName 表示"数据库",需要写入连接配置以选择目标库。
|
||||
runConfig.Database = name
|
||||
case "dameng":
|
||||
@@ -57,7 +57,7 @@ func normalizeSchemaAndTable(config connection.ConnectionConfig, dbName string,
|
||||
|
||||
// Elasticsearch:索引名可能含多个点(如 iot_pro_biz_operate_log.index.20240626),
|
||||
// 不能按点分割,直接返回原始数据库名和完整表名。
|
||||
if dbType == "elasticsearch" || dbType == "iotdb" || dbType == "rocketmq" || dbType == "mqtt" || dbType == "kafka" || dbType == "rabbitmq" {
|
||||
if dbType == "elasticsearch" || dbType == "iotdb" || dbType == "rocketmq" || dbType == "mqtt" || dbType == "kafka" || dbType == "rabbitmq" || dbType == "trino" {
|
||||
return rawDB, rawTable
|
||||
}
|
||||
|
||||
@@ -116,7 +116,7 @@ func normalizeSchemaAndTable(config connection.ConnectionConfig, dbName string,
|
||||
func normalizeMetadataSchemaAndTable(config connection.ConnectionConfig, dbName string, tableName string) (string, string) {
|
||||
schema, table := normalizeSchemaAndTable(config, dbName, tableName)
|
||||
switch resolveDDLDBType(config) {
|
||||
case "rocketmq", "mqtt", "kafka", "rabbitmq":
|
||||
case "rocketmq", "mqtt", "kafka", "rabbitmq", "trino":
|
||||
return schema, table
|
||||
case "postgres", "kingbase", "highgo", "vastbase", "opengauss", "gaussdb":
|
||||
rawTable := strings.TrimSpace(tableName)
|
||||
|
||||
@@ -50,6 +50,18 @@ func TestNormalizeSchemaAndTable_PostgresStillSplitsQualifiedName(t *testing.T)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeSchemaAndTable_TrinoPreservesDottedTableName(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
schema, table := normalizeSchemaAndTable(connection.ConnectionConfig{
|
||||
Type: "trino",
|
||||
}, "hive.default", "daily.events.v1")
|
||||
|
||||
if schema != "hive.default" || table != "daily.events.v1" {
|
||||
t.Fatalf("expected trino table name to stay intact, got %q.%q", schema, table)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeSchemaAndTable_KingbaseNormalizesEscapedQualifiedName(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -158,6 +170,18 @@ func TestNormalizeMetadataSchemaAndTable_NonPGLikeKeepsNormalBehavior(t *testing
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeMetadataSchemaAndTable_TrinoPreservesDottedTableName(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
schema, table := normalizeMetadataSchemaAndTable(connection.ConnectionConfig{
|
||||
Type: "trino",
|
||||
}, "iceberg.analytics", "ods.orders.v1")
|
||||
|
||||
if schema != "iceberg.analytics" || table != "ods.orders.v1" {
|
||||
t.Fatalf("expected trino metadata table to stay intact, got %q.%q", schema, table)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeSchemaAndTable_PGLikePureTableStillSplitsKingbaseSearchPathOnlyInMetadata(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -216,6 +240,19 @@ func TestNormalizeRunConfig_StarRocksUsesDatabaseFromTree(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeRunConfig_TrinoUsesNamespaceFromTree(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
runConfig := normalizeRunConfig(connection.ConnectionConfig{
|
||||
Type: "trino",
|
||||
Database: "hive.default",
|
||||
}, "iceberg.analytics")
|
||||
|
||||
if runConfig.Database != "iceberg.analytics" {
|
||||
t.Fatalf("expected trino namespace from tree, got %q", runConfig.Database)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeRunConfig_GoldenDBUsesDatabaseFromTree(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
||||
@@ -233,6 +233,8 @@ func defaultPortByType(driverType string) int {
|
||||
return 27017
|
||||
case "clickhouse":
|
||||
return 9000
|
||||
case "trino":
|
||||
return 8080
|
||||
case "highgo":
|
||||
return 5866
|
||||
case "iris":
|
||||
|
||||
@@ -496,8 +496,8 @@ func normalizeSchemaAndTableByType(dbType string, dbName string, tableName strin
|
||||
return rawDB, rawTable
|
||||
}
|
||||
|
||||
// Elasticsearch / RocketMQ / MQTT / RabbitMQ / Kafka:对象名可能含多个点或路径,不能按点分割
|
||||
if dbType == "elasticsearch" || dbType == "rocketmq" || dbType == "mqtt" || dbType == "kafka" || dbType == "rabbitmq" {
|
||||
// Elasticsearch / RocketMQ / MQTT / RabbitMQ / Kafka / Trino:对象名可能含多个点或路径,不能按点分割
|
||||
if dbType == "elasticsearch" || dbType == "rocketmq" || dbType == "mqtt" || dbType == "kafka" || dbType == "rabbitmq" || dbType == "trino" {
|
||||
return rawDB, rawTable
|
||||
}
|
||||
|
||||
@@ -575,12 +575,35 @@ func resolveCreateStatementTargets(config connection.ConnectionConfig, dbType st
|
||||
func quoteTableIdentByType(dbType string, schema string, table string) string {
|
||||
s := strings.TrimSpace(schema)
|
||||
t := strings.TrimSpace(table)
|
||||
if dbType == "trino" {
|
||||
catalog, namespace := splitTrinoNamespace(s)
|
||||
switch {
|
||||
case catalog == "" && namespace == "":
|
||||
return quoteIdentByType(dbType, t)
|
||||
case namespace == "":
|
||||
return fmt.Sprintf("%s.%s", quoteIdentByType(dbType, catalog), quoteIdentByType(dbType, t))
|
||||
default:
|
||||
return fmt.Sprintf("%s.%s.%s", quoteIdentByType(dbType, catalog), quoteIdentByType(dbType, namespace), quoteIdentByType(dbType, t))
|
||||
}
|
||||
}
|
||||
if s == "" {
|
||||
return quoteIdentByType(dbType, t)
|
||||
}
|
||||
return fmt.Sprintf("%s.%s", quoteIdentByType(dbType, s), quoteIdentByType(dbType, t))
|
||||
}
|
||||
|
||||
func splitTrinoNamespace(raw string) (string, string) {
|
||||
text := strings.TrimSpace(raw)
|
||||
if text == "" {
|
||||
return "", ""
|
||||
}
|
||||
parts := strings.SplitN(text, ".", 2)
|
||||
if len(parts) == 1 {
|
||||
return strings.TrimSpace(parts[0]), ""
|
||||
}
|
||||
return strings.TrimSpace(parts[0]), strings.TrimSpace(parts[1])
|
||||
}
|
||||
|
||||
func buildRunConfigForDDL(config connection.ConnectionConfig, dbType string, dbName string) connection.ConnectionConfig {
|
||||
runConfig := normalizeRunConfig(config, dbName)
|
||||
if strings.EqualFold(strings.TrimSpace(config.Type), "custom") {
|
||||
|
||||
@@ -193,6 +193,24 @@ func TestNormalizeSchemaAndTableByType_RabbitMQPreservesDottedQueueName(t *testi
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeSchemaAndTableByType_TrinoPreservesDottedTableName(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
schema, table := normalizeSchemaAndTableByType("trino", "hive.default", "orders.events.v1")
|
||||
if schema != "hive.default" || table != "orders.events.v1" {
|
||||
t.Fatalf("expected trino table name to stay intact, got %q.%q", schema, table)
|
||||
}
|
||||
}
|
||||
|
||||
func TestQuoteTableIdentByType_TrinoKeepsCatalogSchemaAndDottedTable(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
got := quoteTableIdentByType("trino", "hive.default", "orders.events.v1")
|
||||
if got != `"hive"."default"."orders.events.v1"` {
|
||||
t.Fatalf("unexpected trino quoted table: %s", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildRunConfigForDDL_CustomHighGoUsesDatabase(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
||||
@@ -278,6 +278,9 @@ func executeManagedSQLTransactionStatements(ctx context.Context, session db.Stat
|
||||
}
|
||||
|
||||
func shouldUseManagedSQLTransaction(dbType string, query string) bool {
|
||||
if strings.EqualFold(strings.TrimSpace(dbType), "trino") {
|
||||
return false
|
||||
}
|
||||
statements := splitSQLStatements(query)
|
||||
hasManagedWrite := false
|
||||
for _, stmt := range statements {
|
||||
|
||||
14
internal/app/methods_db_transaction_test.go
Normal file
14
internal/app/methods_db_transaction_test.go
Normal file
@@ -0,0 +1,14 @@
|
||||
package app
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestShouldUseManagedSQLTransaction_TrinoAlwaysUsesPlainExecution(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
if shouldUseManagedSQLTransaction("trino", "UPDATE hive.default.orders SET status = 'done'") {
|
||||
t.Fatal("expected trino DML to skip SQL editor managed transactions")
|
||||
}
|
||||
if shouldUseManagedSQLTransaction("trino", "BEGIN; UPDATE hive.default.orders SET status = 'done'; COMMIT;") {
|
||||
t.Fatal("expected trino explicit transactions to stay unmanaged")
|
||||
}
|
||||
}
|
||||
@@ -395,7 +395,8 @@ const builtinDriverManifestJSON = `{
|
||||
"tdengine": { "engine": "go", "version": "3.7.8", "checksumPolicy": "off", "downloadUrl": "builtin://activate/tdengine" },
|
||||
"iotdb": { "engine": "go", "version": "1.3.7", "checksumPolicy": "off", "downloadUrl": "builtin://activate/iotdb" },
|
||||
"clickhouse": { "engine": "go", "version": "2.43.1", "checksumPolicy": "off", "downloadUrl": "builtin://activate/clickhouse" },
|
||||
"elasticsearch": { "engine": "go", "version": "8.19.6", "checksumPolicy": "off", "downloadUrl": "builtin://activate/elasticsearch" }
|
||||
"elasticsearch": { "engine": "go", "version": "8.19.6", "checksumPolicy": "off", "downloadUrl": "builtin://activate/elasticsearch" },
|
||||
"trino": { "engine": "go", "version": "0.333.0", "checksumPolicy": "off", "downloadUrl": "builtin://activate/trino" }
|
||||
}
|
||||
}`
|
||||
|
||||
@@ -462,6 +463,7 @@ var latestDriverVersionMap = map[string]string{
|
||||
"iotdb": "1.3.7",
|
||||
"clickhouse": "2.43.1",
|
||||
"elasticsearch": "8.19.6",
|
||||
"trino": "0.333.0",
|
||||
"oracle": "2.9.0",
|
||||
"postgres": "1.11.2",
|
||||
"redis": "9.17.3",
|
||||
@@ -489,6 +491,7 @@ var driverGoModulePathMap = map[string]string{
|
||||
"iotdb": "github.com/apache/iotdb-client-go",
|
||||
"clickhouse": "github.com/ClickHouse/clickhouse-go/v2",
|
||||
"elasticsearch": "github.com/elastic/go-elasticsearch/v8",
|
||||
"trino": "github.com/trinodb/trino-go-client",
|
||||
}
|
||||
|
||||
var driverGoModuleAliasPathMap = map[string][]string{
|
||||
@@ -1745,6 +1748,7 @@ func allDriverDefinitionsWithPackages(packages map[string]pinnedDriverPackage) [
|
||||
buildOptionalGoDriverDefinition("iotdb", "Apache IoTDB", packages),
|
||||
buildOptionalGoDriverDefinition("clickhouse", "ClickHouse", packages),
|
||||
buildOptionalGoDriverDefinition("elasticsearch", "Elasticsearch", packages),
|
||||
buildOptionalGoDriverDefinition("trino", "Trino", packages),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4330,6 +4334,8 @@ func optionalDriverBuildTag(driverType string, selectedVersion string) (string,
|
||||
return "gonavi_clickhouse_driver", nil
|
||||
case "elasticsearch":
|
||||
return "gonavi_elasticsearch_driver", nil
|
||||
case "trino":
|
||||
return "gonavi_trino_driver", nil
|
||||
default:
|
||||
return "", fmt.Errorf("未配置驱动构建标签:%s", driverType)
|
||||
}
|
||||
|
||||
@@ -231,6 +231,7 @@ func optionalDriverAgentRevisionTestDrivers(t *testing.T) []string {
|
||||
"iotdb",
|
||||
"clickhouse",
|
||||
"elasticsearch",
|
||||
"trino",
|
||||
}
|
||||
for _, driverType := range drivers {
|
||||
if db.OptionalDriverAgentRevision(driverType) == "" {
|
||||
|
||||
@@ -504,6 +504,39 @@ func TestElasticsearchDriverDefinitionUsesOptionalAgent(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestTrinoDriverDefinitionUsesOptionalAgent(t *testing.T) {
|
||||
definition, ok := resolveDriverDefinition("trino")
|
||||
if !ok {
|
||||
t.Fatal("expected trino driver definition")
|
||||
}
|
||||
if definition.Name != "Trino" {
|
||||
t.Fatalf("unexpected trino driver name: %q", definition.Name)
|
||||
}
|
||||
if definition.BuiltIn {
|
||||
t.Fatal("expected trino to be an optional driver agent")
|
||||
}
|
||||
if driverGoModulePathMap["trino"] != "github.com/trinodb/trino-go-client" {
|
||||
t.Fatalf("unexpected trino go module path: %q", driverGoModulePathMap["trino"])
|
||||
}
|
||||
if definition.PinnedVersion != "0.333.0" {
|
||||
t.Fatalf("unexpected trino definition pinned version: %q", definition.PinnedVersion)
|
||||
}
|
||||
if definition.DefaultDownloadURL != "builtin://activate/trino" {
|
||||
t.Fatalf("unexpected trino default download URL: %q", definition.DefaultDownloadURL)
|
||||
}
|
||||
if latestDriverVersionMap["trino"] != "0.333.0" {
|
||||
t.Fatalf("unexpected trino pinned version: %q", latestDriverVersionMap["trino"])
|
||||
}
|
||||
|
||||
tags, err := optionalDriverBuildTags("trino", "")
|
||||
if err != nil {
|
||||
t.Fatalf("resolve trino build tags failed: %v", err)
|
||||
}
|
||||
if tags != "gonavi_trino_driver" {
|
||||
t.Fatalf("unexpected trino build tag: %q", tags)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIoTDBDriverDefinitionUsesOptionalAgent(t *testing.T) {
|
||||
definition, ok := resolveDriverDefinition("iotdb")
|
||||
if !ok {
|
||||
|
||||
@@ -2850,6 +2850,20 @@ func quoteQualifiedIdentByType(dbType string, ident string) string {
|
||||
}
|
||||
|
||||
dbType = resolveDDLDBType(connection.ConnectionConfig{Type: dbType})
|
||||
if dbType == "trino" {
|
||||
parts := strings.Split(raw, ".")
|
||||
switch {
|
||||
case len(parts) >= 3:
|
||||
catalog := strings.TrimSpace(parts[0])
|
||||
schema := strings.TrimSpace(parts[1])
|
||||
table := strings.TrimSpace(strings.Join(parts[2:], "."))
|
||||
if catalog != "" && schema != "" && table != "" {
|
||||
return quoteIdentByType(dbType, catalog) + "." + quoteIdentByType(dbType, schema) + "." + quoteIdentByType(dbType, table)
|
||||
}
|
||||
case len(parts) <= 2:
|
||||
return quoteIdentByType(dbType, raw)
|
||||
}
|
||||
}
|
||||
if dbType == "kingbase" {
|
||||
schema, table := db.SplitKingbaseQualifiedName(raw)
|
||||
if table == "" {
|
||||
|
||||
@@ -23,4 +23,5 @@ func registerOptionalDatabaseFactories() {
|
||||
registerDatabaseFactory(newOptionalDriverAgentDatabase("iotdb"), "iotdb", "apache-iotdb", "apache_iotdb")
|
||||
registerDatabaseFactory(newOptionalDriverAgentDatabase("clickhouse"), "clickhouse")
|
||||
registerDatabaseFactory(newOptionalDriverAgentDatabase("elasticsearch"), "elasticsearch", "elastic")
|
||||
registerDatabaseFactory(newOptionalDriverAgentDatabase("trino"), "trino")
|
||||
}
|
||||
|
||||
@@ -23,4 +23,5 @@ func registerOptionalDatabaseFactories() {
|
||||
registerDatabaseFactory(newOptionalDriverAgentDatabase("iotdb"), "iotdb", "apache-iotdb", "apache_iotdb")
|
||||
registerDatabaseFactory(newOptionalDriverAgentDatabase("clickhouse"), "clickhouse")
|
||||
registerDatabaseFactory(newOptionalDriverAgentDatabase("elasticsearch"), "elasticsearch", "elastic")
|
||||
registerDatabaseFactory(newOptionalDriverAgentDatabase("trino"), "trino")
|
||||
}
|
||||
|
||||
@@ -24,5 +24,6 @@ func init() {
|
||||
"iotdb": "src-5ba9da13c6a272f9",
|
||||
"clickhouse": "src-99c8babfefdf142c",
|
||||
"elasticsearch": "src-36b2e2b5f49db9d1",
|
||||
"trino": "src-d264ceca132c185c",
|
||||
}
|
||||
}
|
||||
|
||||
@@ -48,6 +48,7 @@ var optionalGoDrivers = map[string]struct{}{
|
||||
"iotdb": {},
|
||||
"clickhouse": {},
|
||||
"elasticsearch": {},
|
||||
"trino": {},
|
||||
}
|
||||
|
||||
// optionalDriverAgentRevisions 记录 GoNavi 对各可选 driver-agent 包装逻辑的兼容版本。
|
||||
@@ -150,6 +151,8 @@ func driverDisplayName(driverType string) string {
|
||||
return "ClickHouse"
|
||||
case "elasticsearch":
|
||||
return "Elasticsearch"
|
||||
case "trino":
|
||||
return "Trino"
|
||||
case "chroma":
|
||||
return "Chroma"
|
||||
case "qdrant":
|
||||
|
||||
661
internal/db/trino_impl.go
Normal file
661
internal/db/trino_impl.go
Normal file
@@ -0,0 +1,661 @@
|
||||
//go:build gonavi_full_drivers || gonavi_trino_driver
|
||||
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"GoNavi-Wails/internal/connection"
|
||||
"GoNavi-Wails/internal/logger"
|
||||
"GoNavi-Wails/internal/ssh"
|
||||
|
||||
trinodriver "github.com/trinodb/trino-go-client/trino"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultTrinoPort = 8080
|
||||
defaultTrinoSource = "GoNavi"
|
||||
)
|
||||
|
||||
type TrinoDB struct {
|
||||
conn *sql.DB
|
||||
pingTimeout time.Duration
|
||||
forwarder *ssh.LocalForwarder
|
||||
namespace string
|
||||
customClientName string
|
||||
}
|
||||
|
||||
func normalizeTrinoConfig(config connection.ConnectionConfig) connection.ConnectionConfig {
|
||||
normalized := applyTrinoURI(config)
|
||||
normalized = applyTrinoHostURI(normalized)
|
||||
if strings.TrimSpace(normalized.Host) == "" {
|
||||
normalized.Host = "localhost"
|
||||
}
|
||||
if normalized.Port <= 0 {
|
||||
normalized.Port = defaultTrinoPort
|
||||
}
|
||||
return normalized
|
||||
}
|
||||
|
||||
func applyTrinoURI(config connection.ConnectionConfig) connection.ConnectionConfig {
|
||||
return applyTrinoEndpointURI(config, config.URI, false)
|
||||
}
|
||||
|
||||
func applyTrinoHostURI(config connection.ConnectionConfig) connection.ConnectionConfig {
|
||||
return applyTrinoEndpointURI(config, config.Host, true)
|
||||
}
|
||||
|
||||
func applyTrinoEndpointURI(config connection.ConnectionConfig, raw string, fromHostField bool) connection.ConnectionConfig {
|
||||
uriText := strings.TrimSpace(raw)
|
||||
if uriText == "" {
|
||||
return config
|
||||
}
|
||||
parsed, err := url.Parse(uriText)
|
||||
if err != nil {
|
||||
return config
|
||||
}
|
||||
scheme := strings.ToLower(strings.TrimSpace(parsed.Scheme))
|
||||
if scheme != "trino" && scheme != "http" && scheme != "https" {
|
||||
return config
|
||||
}
|
||||
if strings.TrimSpace(parsed.Host) == "" {
|
||||
return config
|
||||
}
|
||||
|
||||
if parsed.User != nil {
|
||||
if strings.TrimSpace(config.User) == "" {
|
||||
config.User = parsed.User.Username()
|
||||
}
|
||||
if pass, ok := parsed.User.Password(); ok && config.Password == "" {
|
||||
config.Password = pass
|
||||
}
|
||||
}
|
||||
|
||||
params := url.Values{}
|
||||
mergeConnectionParamValues(params, parsed.Query())
|
||||
mergeConnectionParamValues(params, connectionParamsFromText(config.ConnectionParams))
|
||||
|
||||
catalog := strings.TrimSpace(parsed.Query().Get("catalog"))
|
||||
schema := strings.TrimSpace(parsed.Query().Get("schema"))
|
||||
if strings.TrimSpace(config.Database) == "" {
|
||||
config.Database = joinTrinoNamespace(catalog, schema)
|
||||
}
|
||||
|
||||
if scheme == "https" {
|
||||
config.UseSSL = true
|
||||
if normalizeSSLModeValue(config.SSLMode) == sslModeDisable || strings.TrimSpace(config.SSLMode) == "" {
|
||||
config.SSLMode = sslModeRequired
|
||||
}
|
||||
}
|
||||
|
||||
defaultPort := config.Port
|
||||
if defaultPort <= 0 {
|
||||
defaultPort = defaultTrinoPort
|
||||
}
|
||||
if fromHostField || strings.TrimSpace(config.Host) == "" {
|
||||
host, port, ok := parseHostPortWithDefault(parsed.Host, defaultPort)
|
||||
if ok {
|
||||
config.Host = host
|
||||
config.Port = port
|
||||
}
|
||||
}
|
||||
if config.Port <= 0 {
|
||||
config.Port = defaultPort
|
||||
}
|
||||
config.ConnectionParams = params.Encode()
|
||||
return config
|
||||
}
|
||||
|
||||
func splitTrinoNamespace(raw string) (string, string) {
|
||||
text := strings.TrimSpace(raw)
|
||||
if text == "" {
|
||||
return "", ""
|
||||
}
|
||||
parts := strings.SplitN(text, ".", 2)
|
||||
catalog := strings.TrimSpace(parts[0])
|
||||
if len(parts) == 1 {
|
||||
return catalog, ""
|
||||
}
|
||||
return catalog, strings.TrimSpace(parts[1])
|
||||
}
|
||||
|
||||
func joinTrinoNamespace(catalog, schema string) string {
|
||||
c := strings.TrimSpace(catalog)
|
||||
s := strings.TrimSpace(schema)
|
||||
switch {
|
||||
case c == "":
|
||||
return s
|
||||
case s == "":
|
||||
return c
|
||||
default:
|
||||
return c + "." + s
|
||||
}
|
||||
}
|
||||
|
||||
func resolveTrinoNamespace(raw string, fallback string) (string, string) {
|
||||
catalog, schema := splitTrinoNamespace(raw)
|
||||
if catalog != "" || schema != "" {
|
||||
return catalog, schema
|
||||
}
|
||||
return splitTrinoNamespace(fallback)
|
||||
}
|
||||
|
||||
func quoteTrinoIdentifier(ident string) string {
|
||||
return `"` + strings.ReplaceAll(strings.TrimSpace(ident), `"`, `""`) + `"`
|
||||
}
|
||||
|
||||
func quoteTrinoQualifiedTable(catalog, schema, table string) string {
|
||||
quoted := make([]string, 0, 3)
|
||||
if trimmed := strings.TrimSpace(catalog); trimmed != "" {
|
||||
quoted = append(quoted, quoteTrinoIdentifier(trimmed))
|
||||
}
|
||||
if trimmed := strings.TrimSpace(schema); trimmed != "" {
|
||||
quoted = append(quoted, quoteTrinoIdentifier(trimmed))
|
||||
}
|
||||
quoted = append(quoted, quoteTrinoIdentifier(table))
|
||||
return strings.Join(quoted, ".")
|
||||
}
|
||||
|
||||
func escapeTrinoSQLLiteral(value string) string {
|
||||
return "'" + strings.ReplaceAll(strings.TrimSpace(value), "'", "''") + "'"
|
||||
}
|
||||
|
||||
func trinoRowValue(row map[string]interface{}, keys ...string) (interface{}, bool) {
|
||||
if len(row) == 0 {
|
||||
return nil, false
|
||||
}
|
||||
for _, key := range keys {
|
||||
for current, value := range row {
|
||||
if strings.EqualFold(strings.TrimSpace(current), strings.TrimSpace(key)) {
|
||||
return value, true
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func trinoRowString(row map[string]interface{}, keys ...string) string {
|
||||
value, ok := trinoRowValue(row, keys...)
|
||||
if !ok || value == nil {
|
||||
return ""
|
||||
}
|
||||
text := strings.TrimSpace(fmt.Sprintf("%v", value))
|
||||
if strings.EqualFold(text, "<nil>") {
|
||||
return ""
|
||||
}
|
||||
return text
|
||||
}
|
||||
|
||||
func firstTrinoMapValueAsString(row map[string]interface{}) string {
|
||||
for _, value := range row {
|
||||
text := strings.TrimSpace(fmt.Sprintf("%v", value))
|
||||
if !strings.EqualFold(text, "<nil>") {
|
||||
return text
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func firstTrinoRowValueAsString(data []map[string]interface{}) string {
|
||||
if len(data) == 0 {
|
||||
return ""
|
||||
}
|
||||
return firstTrinoMapValueAsString(data[0])
|
||||
}
|
||||
|
||||
func (t *TrinoDB) buildTrinoHTTPClient(config connection.ConnectionConfig) (*http.Client, error) {
|
||||
transport := &http.Transport{
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
DialContext: (&net.Dialer{
|
||||
Timeout: getConnectTimeout(config),
|
||||
KeepAlive: 30 * time.Second,
|
||||
}).DialContext,
|
||||
MaxIdleConns: 32,
|
||||
IdleConnTimeout: 90 * time.Second,
|
||||
TLSHandshakeTimeout: 10 * time.Second,
|
||||
ExpectContinueTimeout: time.Second,
|
||||
}
|
||||
tlsConfig, err := resolveGenericTLSConfig(config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if tlsConfig != nil {
|
||||
transport.TLSClientConfig = tlsConfig
|
||||
}
|
||||
return &http.Client{Transport: transport}, nil
|
||||
}
|
||||
|
||||
func (t *TrinoDB) registerTrinoCustomClient(config connection.ConnectionConfig) (string, error) {
|
||||
client, err := t.buildTrinoHTTPClient(config)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
name := fmt.Sprintf("gonavi-trino-%d", time.Now().UnixNano())
|
||||
if err := trinodriver.RegisterCustomClient(name, client); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return name, nil
|
||||
}
|
||||
|
||||
func buildTrinoDSN(config connection.ConnectionConfig, customClientName string) (string, error) {
|
||||
user := strings.TrimSpace(config.User)
|
||||
if user == "" {
|
||||
return "", fmt.Errorf("Trino 用户名不能为空")
|
||||
}
|
||||
|
||||
scheme := "http"
|
||||
if config.UseSSL {
|
||||
scheme = "https"
|
||||
}
|
||||
if config.Password != "" && scheme != "https" {
|
||||
return "", fmt.Errorf("Trino 启用密码认证时必须使用 HTTPS")
|
||||
}
|
||||
|
||||
params := connectionParamsFromText(config.ConnectionParams)
|
||||
catalog, schema := resolveTrinoNamespace(config.Database, "")
|
||||
if catalog != "" {
|
||||
params.Set("catalog", catalog)
|
||||
}
|
||||
if schema != "" {
|
||||
params.Set("schema", schema)
|
||||
}
|
||||
if strings.TrimSpace(params.Get("source")) == "" {
|
||||
params.Set("source", defaultTrinoSource)
|
||||
}
|
||||
if strings.TrimSpace(params.Get("explicitPrepare")) == "" {
|
||||
params.Set("explicitPrepare", "false")
|
||||
}
|
||||
if strings.TrimSpace(params.Get("query_timeout")) == "" {
|
||||
params.Set("query_timeout", fmt.Sprintf("%ds", getConnectTimeoutSeconds(config)))
|
||||
}
|
||||
if strings.TrimSpace(customClientName) != "" {
|
||||
params.Set("custom_client", strings.TrimSpace(customClientName))
|
||||
}
|
||||
|
||||
endpoint := &url.URL{
|
||||
Scheme: scheme,
|
||||
Host: net.JoinHostPort(strings.TrimSpace(config.Host), strconv.Itoa(config.Port)),
|
||||
RawQuery: params.Encode(),
|
||||
}
|
||||
if config.Password != "" {
|
||||
endpoint.User = url.UserPassword(user, config.Password)
|
||||
} else {
|
||||
endpoint.User = url.User(user)
|
||||
}
|
||||
return endpoint.String(), nil
|
||||
}
|
||||
|
||||
func (t *TrinoDB) Close() error {
|
||||
if t.conn != nil {
|
||||
if err := t.conn.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
t.conn = nil
|
||||
}
|
||||
if t.forwarder != nil {
|
||||
if err := t.forwarder.Close(); err != nil {
|
||||
logger.Warnf("关闭 Trino SSH 端口转发失败:%v", err)
|
||||
}
|
||||
t.forwarder = nil
|
||||
}
|
||||
if t.customClientName != "" {
|
||||
trinodriver.DeregisterCustomClient(t.customClientName)
|
||||
t.customClientName = ""
|
||||
}
|
||||
t.namespace = ""
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *TrinoDB) Connect(config connection.ConnectionConfig) error {
|
||||
_ = t.Close()
|
||||
|
||||
runConfig := normalizeTrinoConfig(config)
|
||||
t.pingTimeout = getConnectTimeout(runConfig)
|
||||
|
||||
if runConfig.UseSSH {
|
||||
forwarder, err := ssh.GetOrCreateLocalForwarder(runConfig.SSH, runConfig.Host, runConfig.Port)
|
||||
if err != nil {
|
||||
return fmt.Errorf("创建 SSH 隧道失败:%w", err)
|
||||
}
|
||||
t.forwarder = forwarder
|
||||
|
||||
host, portText, err := net.SplitHostPort(forwarder.LocalAddr)
|
||||
if err != nil {
|
||||
_ = t.Close()
|
||||
return fmt.Errorf("解析本地转发地址失败:%w", err)
|
||||
}
|
||||
port, err := strconv.Atoi(portText)
|
||||
if err != nil {
|
||||
_ = t.Close()
|
||||
return fmt.Errorf("解析本地端口失败:%w", err)
|
||||
}
|
||||
runConfig.Host = host
|
||||
runConfig.Port = port
|
||||
runConfig.UseSSH = false
|
||||
logger.Infof("Trino 通过本地端口转发连接:%s -> %s:%d", forwarder.LocalAddr, config.Host, config.Port)
|
||||
}
|
||||
|
||||
customClientName, err := t.registerTrinoCustomClient(runConfig)
|
||||
if err != nil {
|
||||
_ = t.Close()
|
||||
return fmt.Errorf("注册 Trino 自定义 HTTP 客户端失败:%w", err)
|
||||
}
|
||||
t.customClientName = customClientName
|
||||
|
||||
dsn, err := buildTrinoDSN(runConfig, customClientName)
|
||||
if err != nil {
|
||||
_ = t.Close()
|
||||
return err
|
||||
}
|
||||
conn, err := sql.Open("trino", dsn)
|
||||
if err != nil {
|
||||
_ = t.Close()
|
||||
return err
|
||||
}
|
||||
t.conn = conn
|
||||
t.namespace = strings.TrimSpace(runConfig.Database)
|
||||
if err := t.Ping(); err != nil {
|
||||
_ = t.Close()
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *TrinoDB) Ping() error {
|
||||
if t.conn == nil {
|
||||
return fmt.Errorf("连接未打开")
|
||||
}
|
||||
timeout := t.pingTimeout
|
||||
if timeout <= 0 {
|
||||
timeout = 10 * time.Second
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||
defer cancel()
|
||||
rows, err := t.conn.QueryContext(ctx, "SELECT 1")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer rows.Close()
|
||||
if !rows.Next() {
|
||||
if err := rows.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
return fmt.Errorf("连接查询验证未返回结果")
|
||||
}
|
||||
var value sql.NullInt64
|
||||
if err := rows.Scan(&value); err != nil {
|
||||
return err
|
||||
}
|
||||
return rows.Err()
|
||||
}
|
||||
|
||||
func (t *TrinoDB) QueryContext(ctx context.Context, query string) ([]map[string]interface{}, []string, error) {
|
||||
if t.conn == nil {
|
||||
return nil, nil, fmt.Errorf("连接未打开")
|
||||
}
|
||||
rows, err := t.conn.QueryContext(ctx, query)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
return scanRows(rows)
|
||||
}
|
||||
|
||||
func (t *TrinoDB) Query(query string) ([]map[string]interface{}, []string, error) {
|
||||
return t.QueryContext(context.Background(), query)
|
||||
}
|
||||
|
||||
func (t *TrinoDB) StreamQueryContext(ctx context.Context, query string, consumer QueryStreamConsumer) error {
|
||||
if t.conn == nil {
|
||||
return fmt.Errorf("连接未打开")
|
||||
}
|
||||
rows, err := t.conn.QueryContext(ctx, query)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer rows.Close()
|
||||
return streamRows(rows, consumer)
|
||||
}
|
||||
|
||||
func (t *TrinoDB) StreamQuery(query string, consumer QueryStreamConsumer) error {
|
||||
return t.StreamQueryContext(context.Background(), query, consumer)
|
||||
}
|
||||
|
||||
func (t *TrinoDB) ExecContext(ctx context.Context, query string) (int64, error) {
|
||||
if t.conn == nil {
|
||||
return 0, fmt.Errorf("连接未打开")
|
||||
}
|
||||
res, err := t.conn.ExecContext(ctx, query)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
affected, err := res.RowsAffected()
|
||||
if err != nil {
|
||||
return 0, nil
|
||||
}
|
||||
return affected, nil
|
||||
}
|
||||
|
||||
func (t *TrinoDB) Exec(query string) (int64, error) {
|
||||
return t.ExecContext(context.Background(), query)
|
||||
}
|
||||
|
||||
func (t *TrinoDB) queryTrinoSingleColumnStrings(query string) ([]string, error) {
|
||||
data, _, err := t.Query(query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result := make([]string, 0, len(data))
|
||||
for _, row := range data {
|
||||
text := firstTrinoMapValueAsString(row)
|
||||
if text != "" {
|
||||
result = append(result, text)
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (t *TrinoDB) GetDatabases() ([]string, error) {
|
||||
catalogs, err := t.queryTrinoSingleColumnStrings("SHOW CATALOGS")
|
||||
if err != nil {
|
||||
if strings.TrimSpace(t.namespace) != "" {
|
||||
return []string{t.namespace}, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
namespaces := make([]string, 0, len(catalogs)*2)
|
||||
seen := make(map[string]struct{}, len(catalogs)*4)
|
||||
var lastErr error
|
||||
for _, catalog := range catalogs {
|
||||
query := fmt.Sprintf("SHOW SCHEMAS FROM %s", quoteTrinoIdentifier(catalog))
|
||||
schemas, schemaErr := t.queryTrinoSingleColumnStrings(query)
|
||||
if schemaErr != nil {
|
||||
lastErr = schemaErr
|
||||
continue
|
||||
}
|
||||
for _, schema := range schemas {
|
||||
namespace := joinTrinoNamespace(catalog, schema)
|
||||
if namespace == "" {
|
||||
continue
|
||||
}
|
||||
key := strings.ToLower(namespace)
|
||||
if _, ok := seen[key]; ok {
|
||||
continue
|
||||
}
|
||||
seen[key] = struct{}{}
|
||||
namespaces = append(namespaces, namespace)
|
||||
}
|
||||
}
|
||||
|
||||
if len(namespaces) == 0 {
|
||||
if strings.TrimSpace(t.namespace) != "" {
|
||||
return []string{t.namespace}, nil
|
||||
}
|
||||
if lastErr != nil {
|
||||
return nil, lastErr
|
||||
}
|
||||
}
|
||||
sort.Strings(namespaces)
|
||||
return namespaces, nil
|
||||
}
|
||||
|
||||
func (t *TrinoDB) GetTables(dbName string) ([]string, error) {
|
||||
catalog, schema := resolveTrinoNamespace(dbName, t.namespace)
|
||||
if catalog == "" || schema == "" {
|
||||
return nil, fmt.Errorf("Trino 默认命名空间必须使用 catalog.schema")
|
||||
}
|
||||
query := fmt.Sprintf("SHOW TABLES FROM %s.%s", quoteTrinoIdentifier(catalog), quoteTrinoIdentifier(schema))
|
||||
tables, err := t.queryTrinoSingleColumnStrings(query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
sort.Strings(tables)
|
||||
return tables, nil
|
||||
}
|
||||
|
||||
func (t *TrinoDB) GetCreateStatement(dbName, tableName string) (string, error) {
|
||||
catalog, schema := resolveTrinoNamespace(dbName, t.namespace)
|
||||
if catalog == "" || schema == "" {
|
||||
return "", fmt.Errorf("Trino 默认命名空间必须使用 catalog.schema")
|
||||
}
|
||||
query := fmt.Sprintf("SHOW CREATE TABLE %s", quoteTrinoQualifiedTable(catalog, schema, tableName))
|
||||
data, _, err := t.Query(query)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
ddl := firstTrinoRowValueAsString(data)
|
||||
if ddl == "" {
|
||||
return "", fmt.Errorf("未返回建表语句")
|
||||
}
|
||||
return ddl, nil
|
||||
}
|
||||
|
||||
func buildTrinoColumnsQuery(catalog, schema, tableName string) string {
|
||||
return fmt.Sprintf(`SELECT
|
||||
column_name,
|
||||
data_type,
|
||||
is_nullable,
|
||||
column_default
|
||||
FROM %s.information_schema.columns
|
||||
WHERE table_schema = %s AND table_name = %s
|
||||
ORDER BY ordinal_position`,
|
||||
quoteTrinoIdentifier(catalog),
|
||||
escapeTrinoSQLLiteral(schema),
|
||||
escapeTrinoSQLLiteral(tableName),
|
||||
)
|
||||
}
|
||||
|
||||
func buildTrinoColumnDefinitions(data []map[string]interface{}) []connection.ColumnDefinition {
|
||||
result := make([]connection.ColumnDefinition, 0, len(data))
|
||||
for _, row := range data {
|
||||
column := connection.ColumnDefinition{
|
||||
Name: trinoRowString(row, "column_name", "Column", "Field"),
|
||||
Type: trinoRowString(row, "data_type", "Type"),
|
||||
Nullable: strings.ToUpper(trinoRowString(row, "is_nullable", "Null")),
|
||||
}
|
||||
if rawDefault, ok := trinoRowValue(row, "column_default", "Default"); ok && rawDefault != nil {
|
||||
def := strings.TrimSpace(fmt.Sprintf("%v", rawDefault))
|
||||
if !strings.EqualFold(def, "<nil>") && def != "" {
|
||||
column.Default = &def
|
||||
}
|
||||
}
|
||||
if column.Nullable == "" {
|
||||
column.Nullable = "YES"
|
||||
}
|
||||
result = append(result, column)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func (t *TrinoDB) GetColumns(dbName, tableName string) ([]connection.ColumnDefinition, error) {
|
||||
catalog, schema := resolveTrinoNamespace(dbName, t.namespace)
|
||||
if catalog == "" || schema == "" {
|
||||
return nil, fmt.Errorf("Trino 默认命名空间必须使用 catalog.schema")
|
||||
}
|
||||
data, _, err := t.Query(buildTrinoColumnsQuery(catalog, schema, tableName))
|
||||
if err == nil {
|
||||
return buildTrinoColumnDefinitions(data), nil
|
||||
}
|
||||
|
||||
describeQuery := fmt.Sprintf("DESCRIBE %s", quoteTrinoQualifiedTable(catalog, schema, tableName))
|
||||
describeRows, _, describeErr := t.Query(describeQuery)
|
||||
if describeErr != nil {
|
||||
return nil, err
|
||||
}
|
||||
columns := make([]connection.ColumnDefinition, 0, len(describeRows))
|
||||
for _, row := range describeRows {
|
||||
name := trinoRowString(row, "Column", "column_name", "Field")
|
||||
if name == "" || strings.HasPrefix(name, "#") {
|
||||
continue
|
||||
}
|
||||
columns = append(columns, connection.ColumnDefinition{
|
||||
Name: name,
|
||||
Type: trinoRowString(row, "Type", "data_type"),
|
||||
Nullable: "YES",
|
||||
})
|
||||
}
|
||||
return columns, nil
|
||||
}
|
||||
|
||||
func (t *TrinoDB) GetAllColumns(dbName string) ([]connection.ColumnDefinitionWithTable, error) {
|
||||
catalog, schema := resolveTrinoNamespace(dbName, t.namespace)
|
||||
if catalog == "" || schema == "" {
|
||||
return nil, fmt.Errorf("Trino 默认命名空间必须使用 catalog.schema")
|
||||
}
|
||||
query := fmt.Sprintf(`SELECT
|
||||
table_name,
|
||||
column_name,
|
||||
data_type
|
||||
FROM %s.information_schema.columns
|
||||
WHERE table_schema = %s
|
||||
ORDER BY table_name, ordinal_position`,
|
||||
quoteTrinoIdentifier(catalog),
|
||||
escapeTrinoSQLLiteral(schema),
|
||||
)
|
||||
data, _, err := t.Query(query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result := make([]connection.ColumnDefinitionWithTable, 0, len(data))
|
||||
for _, row := range data {
|
||||
result = append(result, connection.ColumnDefinitionWithTable{
|
||||
TableName: trinoRowString(row, "table_name", "TABLE_NAME"),
|
||||
Name: trinoRowString(row, "column_name", "COLUMN_NAME"),
|
||||
Type: trinoRowString(row, "data_type", "DATA_TYPE"),
|
||||
})
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (t *TrinoDB) GetIndexes(dbName, tableName string) ([]connection.IndexDefinition, error) {
|
||||
return []connection.IndexDefinition{}, nil
|
||||
}
|
||||
|
||||
func (t *TrinoDB) GetForeignKeys(dbName, tableName string) ([]connection.ForeignKeyDefinition, error) {
|
||||
return []connection.ForeignKeyDefinition{}, nil
|
||||
}
|
||||
|
||||
func (t *TrinoDB) GetTriggers(dbName, tableName string) ([]connection.TriggerDefinition, error) {
|
||||
return []connection.TriggerDefinition{}, nil
|
||||
}
|
||||
|
||||
func (t *TrinoDB) OpenSessionExecer(ctx context.Context) (StatementExecer, error) {
|
||||
if t.conn == nil {
|
||||
return nil, fmt.Errorf("连接未打开")
|
||||
}
|
||||
conn, err := t.conn.Conn(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return NewSQLConnStatementExecer(conn), nil
|
||||
}
|
||||
Reference in New Issue
Block a user