🐛 fix(oceanbase): 修复 OceanBase 协议模式识别与缓存隔离

- 支持 MySQL/Oracle 租户协议在前后端统一解析
- 拒绝 Native 协议并避免误回退为 MySQL
- 修复 Oracle 模式下元数据、DDL、SQL 方言识别
- 修复连接缓存键与实际协议解析优先级不一致问题
- 补充前后端协议解析与缓存隔离回归测试
This commit is contained in:
Syngnat
2026-05-13 22:51:01 +08:00
parent 01eb2c25e0
commit f8abe60dc2
22 changed files with 454 additions and 192 deletions

View File

@@ -211,3 +211,50 @@ func TestGetCacheKey_OceanBaseProtocolParamWinsOverAliases(t *testing.T) {
t.Fatalf("expected explicit protocol=mysql to win over alias, got %s vs %s", left, right)
}
}
func TestGetCacheKey_OceanBaseExplicitProtocolOverridesConnectionParams(t *testing.T) {
base := connection.ConnectionConfig{
Type: "oceanbase",
Host: "ob.local",
Port: 2881,
User: "root@test",
Database: "app",
ConnectionParams: "connectTimeout=10",
}
modified := base
modified.OceanBaseProtocol = "mysql"
modified.ConnectionParams = "protocol=oracle&connectTimeout=10"
left := getCacheKey(base)
right := getCacheKey(modified)
if left != right {
t.Fatalf("expected explicit OceanBase protocol=mysql to override params protocol=oracle, got %s vs %s", left, right)
}
}
func TestGetCacheKey_KeepOceanBaseUnsupportedProtocolIsolation(t *testing.T) {
base := connection.ConnectionConfig{
Type: "oceanbase",
Host: "ob.local",
Port: 2881,
User: "root@test",
Database: "app",
ConnectionParams: "protocol=mysql",
}
modified := base
modified.ConnectionParams = "protocol=native"
left := getCacheKey(base)
right := getCacheKey(modified)
if left == right {
t.Fatalf("expected unsupported OceanBase protocol to stay isolated from MySQL cache key")
}
masked := base
masked.OceanBaseProtocol = "mysql"
masked.ConnectionParams = "protocol=native"
if left == getCacheKey(masked) {
t.Fatalf("expected unsupported OceanBase params protocol to stay isolated even with explicit mysql")
}
}

View File

@@ -8,29 +8,53 @@ import (
)
func normalizeOceanBaseProtocolForApp(raw string) string {
switch strings.ToLower(strings.TrimSpace(raw)) {
normalized := strings.ToLower(strings.TrimSpace(raw))
switch normalized {
case "oracle", "oracle-mode", "oracle_mode", "oboracle":
return "oracle"
case "mysql", "mysql-compatible", "mysql_compatible", "mysql-mode", "mysql_mode":
case "mysql", "mysql-compatible", "mysql_compatible", "mysql-mode", "mysql_mode", "obmysql":
return "mysql"
default:
return "mysql"
return normalized
}
}
func isSupportedOceanBaseProtocolForApp(protocol string) bool {
return protocol == "mysql" || protocol == "oracle"
}
func resolveOceanBaseProtocolForApp(config connection.ConnectionConfig) string {
if !strings.EqualFold(strings.TrimSpace(config.Type), "oceanbase") {
return ""
}
explicitProtocol := ""
if explicit := strings.TrimSpace(config.OceanBaseProtocol); explicit != "" {
return normalizeOceanBaseProtocolForApp(explicit)
explicitProtocol = normalizeOceanBaseProtocolForApp(explicit)
if !isSupportedOceanBaseProtocolForApp(explicitProtocol) {
return explicitProtocol
}
}
if protocol := resolveOceanBaseProtocolParam(config.ConnectionParams); protocol != "" {
if !isSupportedOceanBaseProtocolForApp(protocol) {
return protocol
}
if explicitProtocol != "" {
return explicitProtocol
}
return protocol
}
if protocol := resolveOceanBaseProtocolParam(config.URI); protocol != "" {
if !isSupportedOceanBaseProtocolForApp(protocol) {
return protocol
}
if explicitProtocol != "" {
return explicitProtocol
}
return protocol
}
if explicitProtocol != "" {
return explicitProtocol
}
return "mysql"
}
@@ -57,7 +81,7 @@ func resolveOceanBaseProtocolParam(raw string) string {
return ""
}
func normalizeOceanBaseConnectionParamsForCache(raw string) string {
func stripOceanBaseConnectionParamsForCache(raw string) string {
text := strings.TrimSpace(raw)
if text == "" {
return ""
@@ -69,26 +93,40 @@ func normalizeOceanBaseConnectionParamsForCache(raw string) string {
if len(values) == 0 {
return ""
}
protocol := resolveOceanBaseProtocolParam(raw)
for _, key := range []string{"protocol", "oceanBaseProtocol", "oceanbaseProtocol", "tenantMode", "compatMode", "mode"} {
values.Del(key)
}
if strings.EqualFold(protocol, "oracle") {
values.Set("protocol", "oracle")
}
return values.Encode()
}
func normalizeOceanBaseConnectionParamsForCache(raw string) string {
normalized := stripOceanBaseConnectionParamsForCache(raw)
protocol := resolveOceanBaseProtocolParam(raw)
if protocol != "" && !strings.EqualFold(protocol, "mysql") {
values, err := url.ParseQuery(strings.TrimLeft(strings.TrimSpace(normalized), "?&"))
if err != nil {
values = url.Values{}
}
values.Set("protocol", protocol)
return values.Encode()
}
return normalized
}
func normalizeOceanBaseConnectionParamsForCacheWithProtocol(raw string, protocol string) string {
normalized := normalizeOceanBaseConnectionParamsForCache(raw)
if !strings.EqualFold(protocol, "oracle") {
resolvedProtocol := normalizeOceanBaseProtocolForApp(protocol)
if resolvedProtocol == "" {
return normalizeOceanBaseConnectionParamsForCache(raw)
}
normalized := stripOceanBaseConnectionParamsForCache(raw)
if strings.EqualFold(resolvedProtocol, "mysql") {
return normalized
}
values, err := url.ParseQuery(strings.TrimLeft(strings.TrimSpace(normalized), "?&"))
if err != nil {
values = url.Values{}
}
values.Set("protocol", "oracle")
values.Set("protocol", resolvedProtocol)
return values.Encode()
}

View File

@@ -104,7 +104,7 @@ type ConnectionConfig struct {
RedisDB int `json:"redisDB,omitempty"` // Redis database index (0-15)
URI string `json:"uri,omitempty"` // Connection URI for copy/paste
ClickHouseProtocol string `json:"clickHouseProtocol,omitempty"` // auto | http | native
OceanBaseProtocol string `json:"oceanBaseProtocol,omitempty"` // mysql | oracle
OceanBaseProtocol string `json:"oceanBaseProtocol,omitempty"` // OceanBase tenant compatibility protocol: mysql | oracle
Hosts []string `json:"hosts,omitempty"` // Multi-host addresses: host:port
Topology string `json:"topology,omitempty"` // single | replica | cluster
MySQLReplicaUser string `json:"mysqlReplicaUser,omitempty"` // MySQL replica auth user

View File

@@ -151,36 +151,62 @@ func normalizeOceanBaseProtocol(raw string) string {
switch strings.ToLower(strings.TrimSpace(raw)) {
case oceanBaseProtocolOracle, "oracle-mode", "oracle_mode", "oboracle":
return oceanBaseProtocolOracle
case oceanBaseProtocolMySQL, "mysql-compatible", "mysql_compatible", "mysql-mode", "mysql_mode", "":
case oceanBaseProtocolMySQL, "mysql-compatible", "mysql_compatible", "mysql-mode", "mysql_mode", "obmysql", "":
return oceanBaseProtocolMySQL
default:
return oceanBaseProtocolMySQL
return ""
}
}
func resolveOceanBaseProtocolFromValues(values url.Values) string {
func unsupportedOceanBaseProtocolError(raw string) error {
return fmt.Errorf("OceanBase 当前仅支持 MySQL/Oracle 租户协议,不支持 %q请改为 MySQL 或 Oracle", strings.TrimSpace(raw))
}
func resolveOceanBaseProtocolFromValues(values url.Values) (string, error) {
if len(values) == 0 {
return ""
return "", nil
}
for _, key := range []string{"protocol", "oceanBaseProtocol", "oceanbaseProtocol", "tenantMode", "compatMode", "mode"} {
if value := strings.TrimSpace(values.Get(key)); value != "" {
return normalizeOceanBaseProtocol(value)
protocol := normalizeOceanBaseProtocol(value)
if protocol == "" {
return "", unsupportedOceanBaseProtocolError(value)
}
return protocol, nil
}
}
return ""
return "", nil
}
func resolveOceanBaseProtocol(config connection.ConnectionConfig) string {
func resolveOceanBaseProtocol(config connection.ConnectionConfig) (string, error) {
explicitProtocol := ""
if explicit := strings.TrimSpace(config.OceanBaseProtocol); explicit != "" {
return normalizeOceanBaseProtocol(explicit)
protocol := normalizeOceanBaseProtocol(explicit)
if protocol == "" {
return "", unsupportedOceanBaseProtocolError(explicit)
}
explicitProtocol = protocol
}
if protocol := resolveOceanBaseProtocolFromValues(connectionParamsFromText(config.ConnectionParams)); protocol != "" {
return protocol
if protocol, err := resolveOceanBaseProtocolFromValues(connectionParamsFromText(config.ConnectionParams)); err != nil {
return "", err
} else if protocol != "" {
if explicitProtocol != "" {
return explicitProtocol, nil
}
return protocol, nil
}
if protocol := resolveOceanBaseProtocolFromValues(connectionParamsFromURI(config.URI, "oceanbase", "mysql")); protocol != "" {
return protocol
if protocol, err := resolveOceanBaseProtocolFromValues(connectionParamsFromURI(config.URI, "oceanbase", "mysql")); err != nil {
return "", err
} else if protocol != "" {
if explicitProtocol != "" {
return explicitProtocol, nil
}
return protocol, nil
}
return oceanBaseProtocolMySQL
if explicitProtocol != "" {
return explicitProtocol, nil
}
return oceanBaseProtocolMySQL, nil
}
func stripOceanBaseProtocolParams(raw string) string {
@@ -256,7 +282,10 @@ func (o *OceanBaseDB) Connect(config connection.ConnectionConfig) error {
o.oracle = nil
o.protocol = oceanBaseProtocolMySQL
appliedConfig := applyOceanBaseURI(config)
protocol := resolveOceanBaseProtocol(appliedConfig)
protocol, err := resolveOceanBaseProtocol(appliedConfig)
if err != nil {
return err
}
runConfig := withoutOceanBaseProtocolParams(appliedConfig)
if protocol == oceanBaseProtocolOracle {
logger.Infof("OceanBase 使用 Oracle 协议连接:地址=%s:%d 用户=%s", runConfig.Host, runConfig.Port, runConfig.User)

View File

@@ -79,13 +79,52 @@ func TestResolveOceanBaseProtocol(t *testing.T) {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
if got := resolveOceanBaseProtocol(tt.config); got != tt.want {
got, err := resolveOceanBaseProtocol(tt.config)
if err != nil {
t.Fatalf("resolveOceanBaseProtocol() unexpected error: %v", err)
}
if got != tt.want {
t.Fatalf("resolveOceanBaseProtocol() = %q, want %q", got, tt.want)
}
})
}
}
func TestResolveOceanBaseProtocolRejectsUnsupportedNative(t *testing.T) {
t.Parallel()
tests := []struct {
name string
config connection.ConnectionConfig
}{
{
name: "params native",
config: connection.ConnectionConfig{
Type: "oceanbase",
ConnectionParams: "protocol=native",
},
},
{
name: "explicit mysql does not mask params native",
config: connection.ConnectionConfig{
Type: "oceanbase",
OceanBaseProtocol: "mysql",
ConnectionParams: "protocol=native",
},
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
_, err := resolveOceanBaseProtocol(tt.config)
if err == nil || !strings.Contains(err.Error(), "不支持") {
t.Fatalf("expected unsupported protocol error, got %v", err)
}
})
}
}
func TestWithoutOceanBaseProtocolParamsStripsDriverMeta(t *testing.T) {
t.Parallel()