mirror of
https://github.com/Syngnat/GoNavi.git
synced 2026-06-01 02:19:35 +08:00
- 新增 CA 证书、客户端证书和私钥路径配置 - 为 MySQL、PostgreSQL、ClickHouse、MongoDB、Redis 等连接接入 TLS 证书 - 修正 SSL 模式下证书校验、明文回退和 DER 证书兼容问题 - 补充证书路径保存、RPC 传递和 DSN 生成回归测试 Refs #463
294 lines
7.7 KiB
Go
294 lines
7.7 KiB
Go
//go:build gonavi_full_drivers || gonavi_starrocks_driver
|
||
|
||
package db
|
||
|
||
import (
|
||
"database/sql"
|
||
"fmt"
|
||
"strings"
|
||
|
||
"GoNavi-Wails/internal/connection"
|
||
"GoNavi-Wails/internal/ssh"
|
||
"GoNavi-Wails/internal/utils"
|
||
|
||
mysqlDriver "github.com/go-sql-driver/mysql"
|
||
)
|
||
|
||
// Identifiers and defaults used when wiring StarRocks into database/sql.
const (
	// starRocksDriverName is the name under which the StarRocks driver is
	// registered with, and opened through, database/sql.
	starRocksDriverName = "starrocks"

	// defaultStarRocksPort is the port assumed when neither the connection
	// config nor an individual address supplies one.
	defaultStarRocksPort = 9030
)
|
||
|
||
// StarRocksDB 使用独立 driver 名称接入,底层协议兼容 MySQL。
|
||
type StarRocksDB struct {
|
||
MySQLDB
|
||
}
|
||
|
||
func init() {
|
||
for _, name := range sql.Drivers() {
|
||
if name == starRocksDriverName {
|
||
return
|
||
}
|
||
}
|
||
sql.Register(starRocksDriverName, &mysqlDriver.MySQLDriver{})
|
||
}
|
||
|
||
func applyStarRocksURI(config connection.ConnectionConfig) connection.ConnectionConfig {
|
||
uriText := strings.TrimSpace(config.URI)
|
||
if uriText == "" {
|
||
return config
|
||
}
|
||
|
||
parsed, ok := parseMySQLCompatibleURI(uriText, "starrocks", "mysql")
|
||
if !ok {
|
||
return config
|
||
}
|
||
|
||
if parsed.User != nil {
|
||
if config.User == "" {
|
||
config.User = parsed.User.Username()
|
||
}
|
||
if pass, ok := parsed.User.Password(); ok && config.Password == "" {
|
||
config.Password = pass
|
||
}
|
||
}
|
||
|
||
if dbName := strings.TrimPrefix(parsed.Path, "/"); dbName != "" && config.Database == "" {
|
||
config.Database = dbName
|
||
}
|
||
|
||
defaultPort := config.Port
|
||
if defaultPort <= 0 {
|
||
defaultPort = defaultStarRocksPort
|
||
}
|
||
|
||
hostsFromURI := make([]string, 0, 4)
|
||
hostText := strings.TrimSpace(parsed.Host)
|
||
if hostText != "" {
|
||
for _, entry := range strings.Split(hostText, ",") {
|
||
host, port, ok := parseHostPortWithDefault(entry, defaultPort)
|
||
if !ok {
|
||
continue
|
||
}
|
||
hostsFromURI = append(hostsFromURI, normalizeMySQLAddress(host, port))
|
||
}
|
||
}
|
||
|
||
if len(config.Hosts) == 0 && len(hostsFromURI) > 0 {
|
||
config.Hosts = hostsFromURI
|
||
}
|
||
if strings.TrimSpace(config.Host) == "" && len(hostsFromURI) > 0 {
|
||
host, port, ok := parseHostPortWithDefault(hostsFromURI[0], defaultPort)
|
||
if ok {
|
||
config.Host = host
|
||
config.Port = port
|
||
}
|
||
}
|
||
|
||
if config.Topology == "" {
|
||
topology := strings.TrimSpace(parsed.Query().Get("topology"))
|
||
if topology != "" {
|
||
config.Topology = strings.ToLower(topology)
|
||
}
|
||
}
|
||
|
||
return config
|
||
}
|
||
|
||
func collectStarRocksAddresses(config connection.ConnectionConfig) []string {
|
||
defaultPort := config.Port
|
||
if defaultPort <= 0 {
|
||
defaultPort = defaultStarRocksPort
|
||
}
|
||
|
||
candidates := make([]string, 0, len(config.Hosts)+1)
|
||
if len(config.Hosts) > 0 {
|
||
candidates = append(candidates, config.Hosts...)
|
||
} else {
|
||
candidates = append(candidates, normalizeMySQLAddress(config.Host, defaultPort))
|
||
}
|
||
|
||
result := make([]string, 0, len(candidates))
|
||
seen := make(map[string]struct{}, len(candidates))
|
||
for _, entry := range candidates {
|
||
host, port, ok := parseHostPortWithDefault(entry, defaultPort)
|
||
if !ok {
|
||
continue
|
||
}
|
||
normalized := normalizeMySQLAddress(host, port)
|
||
if _, exists := seen[normalized]; exists {
|
||
continue
|
||
}
|
||
seen[normalized] = struct{}{}
|
||
result = append(result, normalized)
|
||
}
|
||
|
||
return result
|
||
}
|
||
|
||
// starRocksMetadataLiteral renders value as a single-quoted SQL string literal
// suitable for embedding in a metadata query.
//
// Backslashes are doubled and single quotes are doubled. StarRocks, like
// MySQL, treats a backslash inside a string literal as an escape character, so
// escaping only the quotes would let an input such as `\'` terminate the
// literal early and would silently alter inputs containing sequences like `\n`.
func starRocksMetadataLiteral(value string) string {
	// Order matters: escape backslashes first so the quotes doubled below are
	// not preceded by a freshly inserted escape character.
	escaped := strings.ReplaceAll(value, `\`, `\\`)
	escaped = strings.ReplaceAll(escaped, "'", "''")
	return "'" + escaped + "'"
}
|
||
|
||
func buildStarRocksColumnsQuery(dbName, tableName string) string {
|
||
schemaPredicate := "TABLE_SCHEMA = DATABASE()"
|
||
if strings.TrimSpace(dbName) != "" {
|
||
schemaPredicate = fmt.Sprintf("TABLE_SCHEMA = %s", starRocksMetadataLiteral(strings.TrimSpace(dbName)))
|
||
}
|
||
|
||
return fmt.Sprintf(`SELECT
|
||
COLUMN_NAME,
|
||
COLUMN_TYPE,
|
||
IS_NULLABLE,
|
||
COLUMN_KEY,
|
||
COLUMN_DEFAULT,
|
||
EXTRA,
|
||
COLUMN_COMMENT
|
||
FROM information_schema.columns
|
||
WHERE %s AND TABLE_NAME = %s
|
||
ORDER BY ORDINAL_POSITION`, schemaPredicate, starRocksMetadataLiteral(strings.TrimSpace(tableName)))
|
||
}
|
||
|
||
// getStarRocksRowValue looks up the first of keys that is present in row and
// returns its value together with true. Column names are compared
// case-insensitively and with surrounding whitespace ignored on both sides.
// Keys are tried in the order given, so earlier keys take priority. It returns
// (nil, false) when row is empty or no key matches.
func getStarRocksRowValue(row map[string]interface{}, keys ...string) (interface{}, bool) {
	if len(row) == 0 {
		return nil, false
	}
	for _, wanted := range keys {
		target := strings.TrimSpace(wanted)
		for column, value := range row {
			if strings.EqualFold(strings.TrimSpace(column), target) {
				return value, true
			}
		}
	}
	return nil, false
}
|
||
|
||
func getStarRocksRowString(row map[string]interface{}, keys ...string) string {
|
||
v, ok := getStarRocksRowValue(row, keys...)
|
||
if !ok || v == nil {
|
||
return ""
|
||
}
|
||
text := strings.TrimSpace(fmt.Sprintf("%v", v))
|
||
if text == "" || strings.EqualFold(text, "<nil>") {
|
||
return ""
|
||
}
|
||
return text
|
||
}
|
||
|
||
func buildStarRocksColumnDefinitions(data []map[string]interface{}) []connection.ColumnDefinition {
|
||
columns := make([]connection.ColumnDefinition, 0, len(data))
|
||
for _, row := range data {
|
||
col := connection.ColumnDefinition{
|
||
Name: getStarRocksRowString(row, "Field", "COLUMN_NAME"),
|
||
Type: getStarRocksRowString(row, "Type", "COLUMN_TYPE"),
|
||
Nullable: getStarRocksRowString(row, "Null", "IS_NULLABLE"),
|
||
Key: strings.ToUpper(getStarRocksRowString(row, "Key", "COLUMN_KEY")),
|
||
Extra: getStarRocksRowString(row, "Extra", "EXTRA"),
|
||
Comment: getStarRocksRowString(row, "Comment", "COLUMN_COMMENT"),
|
||
}
|
||
|
||
if rawDefault, ok := getStarRocksRowValue(row, "Default", "COLUMN_DEFAULT"); ok && rawDefault != nil {
|
||
def := fmt.Sprintf("%v", rawDefault)
|
||
if strings.EqualFold(def, "<nil>") {
|
||
def = ""
|
||
}
|
||
col.Default = &def
|
||
}
|
||
|
||
columns = append(columns, col)
|
||
}
|
||
return columns
|
||
}
|
||
|
||
func (s *StarRocksDB) getDSN(config connection.ConnectionConfig) (string, error) {
|
||
database := config.Database
|
||
protocol := "tcp"
|
||
address := normalizeMySQLAddress(config.Host, config.Port)
|
||
|
||
if config.UseSSH {
|
||
netName, err := ssh.RegisterSSHNetwork(config.SSH)
|
||
if err != nil {
|
||
return "", fmt.Errorf("创建 SSH 隧道失败:%w", err)
|
||
}
|
||
protocol = netName
|
||
}
|
||
|
||
return buildMySQLCompatibleDSN(config, protocol, address, database)
|
||
}
|
||
|
||
func resolveStarRocksCredential(config connection.ConnectionConfig, addressIndex int) (string, string) {
|
||
primaryUser := strings.TrimSpace(config.User)
|
||
primaryPassword := config.Password
|
||
replicaUser := strings.TrimSpace(config.MySQLReplicaUser)
|
||
replicaPassword := config.MySQLReplicaPassword
|
||
|
||
if addressIndex > 0 && replicaUser != "" {
|
||
return replicaUser, replicaPassword
|
||
}
|
||
|
||
if primaryUser == "" && replicaUser != "" {
|
||
return replicaUser, replicaPassword
|
||
}
|
||
|
||
return config.User, primaryPassword
|
||
}
|
||
|
||
func (s *StarRocksDB) GetColumns(dbName, tableName string) ([]connection.ColumnDefinition, error) {
|
||
data, _, err := s.Query(buildStarRocksColumnsQuery(dbName, tableName))
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
return buildStarRocksColumnDefinitions(data), nil
|
||
}
|
||
|
||
func (s *StarRocksDB) Connect(config connection.ConnectionConfig) error {
|
||
runConfig := applyStarRocksURI(config)
|
||
addresses := collectStarRocksAddresses(runConfig)
|
||
if len(addresses) == 0 {
|
||
return fmt.Errorf("连接建立后验证失败:未找到可用的 StarRocks 地址")
|
||
}
|
||
|
||
var errorDetails []string
|
||
for index, address := range addresses {
|
||
candidateConfig := runConfig
|
||
host, port, ok := parseHostPortWithDefault(address, defaultStarRocksPort)
|
||
if !ok {
|
||
continue
|
||
}
|
||
candidateConfig.Host = host
|
||
candidateConfig.Port = port
|
||
candidateConfig.User, candidateConfig.Password = resolveStarRocksCredential(runConfig, index)
|
||
|
||
dsn, err := s.getDSN(candidateConfig)
|
||
if err != nil {
|
||
errorDetails = append(errorDetails, fmt.Sprintf("%s 生成连接串失败: %v", address, err))
|
||
continue
|
||
}
|
||
db, err := sql.Open(starRocksDriverName, dsn)
|
||
if err != nil {
|
||
errorDetails = append(errorDetails, fmt.Sprintf("%s 打开失败: %v", address, err))
|
||
continue
|
||
}
|
||
|
||
timeout := getConnectTimeout(candidateConfig)
|
||
ctx, cancel := utils.ContextWithTimeout(timeout)
|
||
pingErr := db.PingContext(ctx)
|
||
cancel()
|
||
if pingErr != nil {
|
||
_ = db.Close()
|
||
errorDetails = append(errorDetails, fmt.Sprintf("%s 验证失败: %v", address, pingErr))
|
||
continue
|
||
}
|
||
|
||
s.conn = db
|
||
s.pingTimeout = timeout
|
||
return nil
|
||
}
|
||
|
||
if len(errorDetails) == 0 {
|
||
return fmt.Errorf("连接建立后验证失败:未找到可用的 StarRocks 地址")
|
||
}
|
||
return fmt.Errorf("连接建立后验证失败:%s", strings.Join(errorDetails, ";"))
|
||
}
|