Files
MyGoNavi/internal/db/oceanbase_impl.go
Syngnat 527ecd37e1 🐛 fix(oceanbase): 增强 Oracle 协议连接校验与诊断
- 运行时校验可选 driver-agent revision,避免旧代理继续被复用
- OceanBase agent revision 纳入 oracle_impl.go 指纹并重新生成
- OceanBase Oracle 保留 URI 中的 Oracle 连接参数
- Oracle DSN 默认写入连接和读取超时,并输出脱敏诊断摘要
- 补充 revision、Oracle DSN、OceanBase Oracle 参数提升测试
2026-05-14 10:30:17 +08:00

483 lines
14 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
//go:build gonavi_full_drivers || gonavi_oceanbase_driver
package db
import (
"context"
"database/sql"
"fmt"
"net/url"
"strings"
"GoNavi-Wails/internal/connection"
"GoNavi-Wails/internal/logger"
"GoNavi-Wails/internal/ssh"
"GoNavi-Wails/internal/utils"
mysqlDriver "github.com/go-sql-driver/mysql"
)
const (
oceanbaseDriverName = "oceanbase"
defaultOceanBasePort = 2881
oceanBaseProtocolMySQL = "mysql"
oceanBaseProtocolOracle = "oracle"
)
// OceanBaseDB 支持 OceanBase MySQL/Oracle 两种租户协议。
type OceanBaseDB struct {
MySQLDB
oracle *OracleDB
protocol string
}
func init() {
for _, name := range sql.Drivers() {
if name == oceanbaseDriverName {
return
}
}
sql.Register(oceanbaseDriverName, &mysqlDriver.MySQLDriver{})
}
func applyOceanBaseURI(config connection.ConnectionConfig) connection.ConnectionConfig {
uriText := strings.TrimSpace(config.URI)
if uriText == "" {
return config
}
parsed, ok := parseMySQLCompatibleURI(uriText, "oceanbase", "mysql")
if !ok {
return config
}
if parsed.User != nil {
if config.User == "" {
config.User = parsed.User.Username()
}
if pass, ok := parsed.User.Password(); ok && config.Password == "" {
config.Password = pass
}
}
if dbName := strings.TrimPrefix(parsed.Path, "/"); dbName != "" && config.Database == "" {
config.Database = dbName
}
defaultPort := config.Port
if defaultPort <= 0 {
defaultPort = defaultOceanBasePort
}
hostsFromURI := make([]string, 0, 4)
hostText := strings.TrimSpace(parsed.Host)
if hostText != "" {
for _, entry := range strings.Split(hostText, ",") {
host, port, ok := parseHostPortWithDefault(entry, defaultPort)
if !ok {
continue
}
hostsFromURI = append(hostsFromURI, normalizeMySQLAddress(host, port))
}
}
if len(config.Hosts) == 0 && len(hostsFromURI) > 0 {
config.Hosts = hostsFromURI
}
if strings.TrimSpace(config.Host) == "" && len(hostsFromURI) > 0 {
host, port, ok := parseHostPortWithDefault(hostsFromURI[0], defaultPort)
if ok {
config.Host = host
config.Port = port
}
}
if config.Topology == "" {
topology := strings.TrimSpace(parsed.Query().Get("topology"))
if topology != "" {
config.Topology = strings.ToLower(topology)
}
}
return config
}
func collectOceanBaseAddresses(config connection.ConnectionConfig) []string {
defaultPort := config.Port
if defaultPort <= 0 {
defaultPort = defaultOceanBasePort
}
candidates := make([]string, 0, len(config.Hosts)+1)
if len(config.Hosts) > 0 {
candidates = append(candidates, config.Hosts...)
} else {
candidates = append(candidates, normalizeMySQLAddress(config.Host, defaultPort))
}
result := make([]string, 0, len(candidates))
seen := make(map[string]struct{}, len(candidates))
for _, entry := range candidates {
host, port, ok := parseHostPortWithDefault(entry, defaultPort)
if !ok {
continue
}
normalized := normalizeMySQLAddress(host, port)
if _, exists := seen[normalized]; exists {
continue
}
seen[normalized] = struct{}{}
result = append(result, normalized)
}
return result
}
func (o *OceanBaseDB) getDSN(config connection.ConnectionConfig) (string, error) {
database := config.Database
protocol := "tcp"
address := normalizeMySQLAddress(config.Host, config.Port)
if config.UseSSH {
netName, err := ssh.RegisterSSHNetwork(config.SSH)
if err != nil {
return "", fmt.Errorf("创建 SSH 隧道失败:%w", err)
}
protocol = netName
}
return buildMySQLCompatibleDSN(config, protocol, address, database), nil
}
func normalizeOceanBaseProtocol(raw string) string {
switch strings.ToLower(strings.TrimSpace(raw)) {
case oceanBaseProtocolOracle, "oracle-mode", "oracle_mode", "oboracle":
return oceanBaseProtocolOracle
case oceanBaseProtocolMySQL, "mysql-compatible", "mysql_compatible", "mysql-mode", "mysql_mode", "obmysql", "":
return oceanBaseProtocolMySQL
default:
return ""
}
}
func unsupportedOceanBaseProtocolError(raw string) error {
return fmt.Errorf("OceanBase 当前仅支持 MySQL/Oracle 租户协议,不支持 %q请改为 MySQL 或 Oracle", strings.TrimSpace(raw))
}
func resolveOceanBaseProtocolFromValues(values url.Values) (string, error) {
if len(values) == 0 {
return "", nil
}
for _, key := range []string{"protocol", "oceanBaseProtocol", "oceanbaseProtocol", "tenantMode", "compatMode", "mode"} {
if value := strings.TrimSpace(values.Get(key)); value != "" {
protocol := normalizeOceanBaseProtocol(value)
if protocol == "" {
return "", unsupportedOceanBaseProtocolError(value)
}
return protocol, nil
}
}
return "", nil
}
func resolveOceanBaseProtocol(config connection.ConnectionConfig) (string, error) {
explicitProtocol := ""
if explicit := strings.TrimSpace(config.OceanBaseProtocol); explicit != "" {
protocol := normalizeOceanBaseProtocol(explicit)
if protocol == "" {
return "", unsupportedOceanBaseProtocolError(explicit)
}
explicitProtocol = protocol
}
if protocol, err := resolveOceanBaseProtocolFromValues(connectionParamsFromText(config.ConnectionParams)); err != nil {
return "", err
} else if protocol != "" {
if explicitProtocol != "" {
return explicitProtocol, nil
}
return protocol, nil
}
if protocol, err := resolveOceanBaseProtocolFromValues(connectionParamsFromURI(config.URI, "oceanbase", "mysql")); err != nil {
return "", err
} else if protocol != "" {
if explicitProtocol != "" {
return explicitProtocol, nil
}
return protocol, nil
}
if explicitProtocol != "" {
return explicitProtocol, nil
}
return oceanBaseProtocolMySQL, nil
}
func stripOceanBaseProtocolParams(raw string) string {
values := connectionParamsFromText(raw)
if len(values) == 0 {
return strings.TrimSpace(raw)
}
for _, key := range []string{"protocol", "oceanBaseProtocol", "oceanbaseProtocol", "tenantMode", "compatMode", "mode"} {
values.Del(key)
}
return values.Encode()
}
func stripOceanBaseProtocolURI(raw string) string {
text := strings.TrimSpace(raw)
if text == "" {
return text
}
parsed, ok := parseConnectionURI(text, "oceanbase", "mysql")
if !ok {
return text
}
values := parsed.Query()
if len(values) == 0 {
return text
}
for _, key := range []string{"protocol", "oceanBaseProtocol", "oceanbaseProtocol", "tenantMode", "compatMode", "mode"} {
values.Del(key)
}
parsed.RawQuery = values.Encode()
return parsed.String()
}
func withoutOceanBaseProtocolParams(config connection.ConnectionConfig) connection.ConnectionConfig {
next := config
next.OceanBaseProtocol = ""
next.ConnectionParams = stripOceanBaseProtocolParams(config.ConnectionParams)
next.URI = stripOceanBaseProtocolURI(config.URI)
return next
}
func promoteOceanBaseOracleURIParams(config connection.ConnectionConfig) connection.ConnectionConfig {
uriParams := connectionParamsFromURI(config.URI, "oceanbase", "mysql")
if len(uriParams) == 0 {
return config
}
for _, key := range []string{"protocol", "oceanBaseProtocol", "oceanbaseProtocol", "tenantMode", "compatMode", "mode"} {
uriParams.Del(key)
}
if len(uriParams) == 0 {
return config
}
merged := url.Values{}
mergeConnectionParamValuesWithAllowlist(merged, uriParams, oracleConnectionParamNames)
mergeConnectionParamValuesWithAllowlist(merged, connectionParamsFromText(config.ConnectionParams), oracleConnectionParamNames)
config.ConnectionParams = merged.Encode()
return config
}
func prepareOceanBaseOracleConfig(config connection.ConnectionConfig) connection.ConnectionConfig {
runConfig := withoutOceanBaseProtocolParams(applyOceanBaseURI(config))
runConfig = promoteOceanBaseOracleURIParams(runConfig)
runConfig.Type = "oracle"
// OracleDB 不解析 oceanbase:// URI。连接要素已落到结构化字段和 ConnectionParams。
runConfig.URI = ""
return runConfig
}
func isOceanBaseOracleTenantMySQLDriverError(err error) bool {
if err == nil {
return false
}
text := strings.ToLower(err.Error())
return strings.Contains(text, "oracle tenant") && strings.Contains(text, "not supported")
}
func formatOceanBaseMySQLAttemptError(address string, err error) string {
if isOceanBaseOracleTenantMySQLDriverError(err) {
return fmt.Sprintf("%s 验证失败: 当前选择的是 OceanBase MySQL 协议,但服务端返回 Oracle 租户不支持 MySQL 客户端驱动;请在连接配置中将 OceanBase 协议切换为 Oracle并填写服务名 (Service Name)", address)
}
return fmt.Sprintf("%s 验证失败: %v", address, err)
}
func (o *OceanBaseDB) connectOracle(config connection.ConnectionConfig) error {
runConfig := prepareOceanBaseOracleConfig(config)
if strings.TrimSpace(runConfig.Database) == "" {
return fmt.Errorf("OceanBase Oracle 协议需要填写服务名Service Name请在连接配置中填写租户监听的服务名")
}
oracleDB := &OracleDB{}
if err := oracleDB.Connect(runConfig); err != nil {
return fmt.Errorf("OceanBase Oracle 协议连接失败:%w", err)
}
o.oracle = oracleDB
o.protocol = oceanBaseProtocolOracle
return nil
}
func (o *OceanBaseDB) Connect(config connection.ConnectionConfig) error {
o.oracle = nil
o.protocol = oceanBaseProtocolMySQL
appliedConfig := applyOceanBaseURI(config)
protocol, err := resolveOceanBaseProtocol(appliedConfig)
if err != nil {
return err
}
runConfig := withoutOceanBaseProtocolParams(appliedConfig)
if protocol == oceanBaseProtocolOracle {
logger.Infof("OceanBase 使用 Oracle 协议连接:地址=%s:%d 用户=%s", runConfig.Host, runConfig.Port, runConfig.User)
return o.connectOracle(runConfig)
}
addresses := collectOceanBaseAddresses(runConfig)
if len(addresses) == 0 {
return fmt.Errorf("连接建立后验证失败:未找到可用的 OceanBase 地址")
}
var errorDetails []string
for index, address := range addresses {
candidateConfig := runConfig
host, port, ok := parseHostPortWithDefault(address, defaultOceanBasePort)
if !ok {
continue
}
candidateConfig.Host = host
candidateConfig.Port = port
candidateConfig.User, candidateConfig.Password = resolveMySQLCredential(runConfig, index)
dsn, err := o.getDSN(candidateConfig)
if err != nil {
errorDetails = append(errorDetails, fmt.Sprintf("%s 生成连接串失败: %v", address, err))
continue
}
db, err := sql.Open(oceanbaseDriverName, dsn)
if err != nil {
errorDetails = append(errorDetails, fmt.Sprintf("%s 打开失败: %v", address, err))
continue
}
timeout := getConnectTimeout(candidateConfig)
ctx, cancel := utils.ContextWithTimeout(timeout)
pingErr := db.PingContext(ctx)
cancel()
if pingErr != nil {
_ = db.Close()
errorDetails = append(errorDetails, formatOceanBaseMySQLAttemptError(address, pingErr))
continue
}
o.conn = db
o.pingTimeout = timeout
o.protocol = oceanBaseProtocolMySQL
return nil
}
if len(errorDetails) == 0 {
return fmt.Errorf("连接建立后验证失败:未找到可用的 OceanBase 地址")
}
return fmt.Errorf("连接建立后验证失败:%s", strings.Join(errorDetails, ""))
}
func (o *OceanBaseDB) activeDatabase() Database {
if o.oracle != nil {
return o.oracle
}
return &o.MySQLDB
}
func (o *OceanBaseDB) Close() error {
if o.oracle != nil {
err := o.oracle.Close()
o.oracle = nil
return err
}
return o.MySQLDB.Close()
}
func (o *OceanBaseDB) Ping() error {
return o.activeDatabase().Ping()
}
func (o *OceanBaseDB) QueryContext(ctx context.Context, query string) ([]map[string]interface{}, []string, error) {
if q, ok := o.activeDatabase().(interface {
QueryContext(context.Context, string) ([]map[string]interface{}, []string, error)
}); ok {
return q.QueryContext(ctx, query)
}
return o.activeDatabase().Query(query)
}
func (o *OceanBaseDB) Query(query string) ([]map[string]interface{}, []string, error) {
return o.activeDatabase().Query(query)
}
func (o *OceanBaseDB) ExecContext(ctx context.Context, query string) (int64, error) {
if e, ok := o.activeDatabase().(interface {
ExecContext(context.Context, string) (int64, error)
}); ok {
return e.ExecContext(ctx, query)
}
return o.activeDatabase().Exec(query)
}
func (o *OceanBaseDB) Exec(query string) (int64, error) {
return o.activeDatabase().Exec(query)
}
func (o *OceanBaseDB) QueryMulti(query string) ([]connection.ResultSetData, error) {
if q, ok := o.activeDatabase().(MultiResultQuerier); ok {
return q.QueryMulti(query)
}
data, columns, err := o.Query(query)
if err != nil {
return nil, err
}
return []connection.ResultSetData{{Rows: data, Columns: columns}}, nil
}
func (o *OceanBaseDB) QueryMultiContext(ctx context.Context, query string) ([]connection.ResultSetData, error) {
if q, ok := o.activeDatabase().(MultiResultQuerierContext); ok {
return q.QueryMultiContext(ctx, query)
}
data, columns, err := o.QueryContext(ctx, query)
if err != nil {
return nil, err
}
return []connection.ResultSetData{{Rows: data, Columns: columns}}, nil
}
func (o *OceanBaseDB) ExecBatchContext(ctx context.Context, query string) (int64, error) {
if e, ok := o.activeDatabase().(BatchWriteExecer); ok {
return e.ExecBatchContext(ctx, query)
}
return o.ExecContext(ctx, query)
}
func (o *OceanBaseDB) GetDatabases() ([]string, error) {
return o.activeDatabase().GetDatabases()
}
func (o *OceanBaseDB) GetTables(dbName string) ([]string, error) {
return o.activeDatabase().GetTables(dbName)
}
func (o *OceanBaseDB) GetCreateStatement(dbName, tableName string) (string, error) {
return o.activeDatabase().GetCreateStatement(dbName, tableName)
}
func (o *OceanBaseDB) GetColumns(dbName, tableName string) ([]connection.ColumnDefinition, error) {
return o.activeDatabase().GetColumns(dbName, tableName)
}
func (o *OceanBaseDB) GetAllColumns(dbName string) ([]connection.ColumnDefinitionWithTable, error) {
return o.activeDatabase().GetAllColumns(dbName)
}
func (o *OceanBaseDB) GetIndexes(dbName, tableName string) ([]connection.IndexDefinition, error) {
return o.activeDatabase().GetIndexes(dbName, tableName)
}
func (o *OceanBaseDB) GetForeignKeys(dbName, tableName string) ([]connection.ForeignKeyDefinition, error) {
return o.activeDatabase().GetForeignKeys(dbName, tableName)
}
func (o *OceanBaseDB) GetTriggers(dbName, tableName string) ([]connection.TriggerDefinition, error) {
return o.activeDatabase().GetTriggers(dbName, tableName)
}
func (o *OceanBaseDB) ApplyChanges(tableName string, changes connection.ChangeSet) error {
if applier, ok := o.activeDatabase().(BatchApplier); ok {
return applier.ApplyChanges(tableName, changes)
}
return fmt.Errorf("当前 OceanBase %s 协议不支持 ApplyChanges", o.protocol)
}