mirror of
https://github.com/Syngnat/GoNavi.git
synced 2026-06-07 23:19:35 +08:00
🐛 fix(oceanbase): 修复 OceanBase 协议模式识别与缓存隔离
- 支持 MySQL/Oracle 租户协议在前后端统一解析 - 拒绝 Native 协议并避免误回退为 MySQL - 修复 Oracle 模式下元数据、DDL、SQL 方言识别 - 修复连接缓存键与实际协议解析优先级不一致问题 - 补充前后端协议解析与缓存隔离回归测试
This commit is contained in:
@@ -211,3 +211,50 @@ func TestGetCacheKey_OceanBaseProtocolParamWinsOverAliases(t *testing.T) {
|
||||
t.Fatalf("expected explicit protocol=mysql to win over alias, got %s vs %s", left, right)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetCacheKey_OceanBaseExplicitProtocolOverridesConnectionParams(t *testing.T) {
|
||||
base := connection.ConnectionConfig{
|
||||
Type: "oceanbase",
|
||||
Host: "ob.local",
|
||||
Port: 2881,
|
||||
User: "root@test",
|
||||
Database: "app",
|
||||
ConnectionParams: "connectTimeout=10",
|
||||
}
|
||||
modified := base
|
||||
modified.OceanBaseProtocol = "mysql"
|
||||
modified.ConnectionParams = "protocol=oracle&connectTimeout=10"
|
||||
|
||||
left := getCacheKey(base)
|
||||
right := getCacheKey(modified)
|
||||
if left != right {
|
||||
t.Fatalf("expected explicit OceanBase protocol=mysql to override params protocol=oracle, got %s vs %s", left, right)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetCacheKey_KeepOceanBaseUnsupportedProtocolIsolation(t *testing.T) {
|
||||
base := connection.ConnectionConfig{
|
||||
Type: "oceanbase",
|
||||
Host: "ob.local",
|
||||
Port: 2881,
|
||||
User: "root@test",
|
||||
Database: "app",
|
||||
ConnectionParams: "protocol=mysql",
|
||||
}
|
||||
modified := base
|
||||
modified.ConnectionParams = "protocol=native"
|
||||
|
||||
left := getCacheKey(base)
|
||||
right := getCacheKey(modified)
|
||||
if left == right {
|
||||
t.Fatalf("expected unsupported OceanBase protocol to stay isolated from MySQL cache key")
|
||||
}
|
||||
|
||||
masked := base
|
||||
masked.OceanBaseProtocol = "mysql"
|
||||
masked.ConnectionParams = "protocol=native"
|
||||
|
||||
if left == getCacheKey(masked) {
|
||||
t.Fatalf("expected unsupported OceanBase params protocol to stay isolated even with explicit mysql")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,29 +8,53 @@ import (
|
||||
)
|
||||
|
||||
func normalizeOceanBaseProtocolForApp(raw string) string {
|
||||
switch strings.ToLower(strings.TrimSpace(raw)) {
|
||||
normalized := strings.ToLower(strings.TrimSpace(raw))
|
||||
switch normalized {
|
||||
case "oracle", "oracle-mode", "oracle_mode", "oboracle":
|
||||
return "oracle"
|
||||
case "mysql", "mysql-compatible", "mysql_compatible", "mysql-mode", "mysql_mode":
|
||||
case "mysql", "mysql-compatible", "mysql_compatible", "mysql-mode", "mysql_mode", "obmysql":
|
||||
return "mysql"
|
||||
default:
|
||||
return "mysql"
|
||||
return normalized
|
||||
}
|
||||
}
|
||||
|
||||
func isSupportedOceanBaseProtocolForApp(protocol string) bool {
|
||||
return protocol == "mysql" || protocol == "oracle"
|
||||
}
|
||||
|
||||
func resolveOceanBaseProtocolForApp(config connection.ConnectionConfig) string {
|
||||
if !strings.EqualFold(strings.TrimSpace(config.Type), "oceanbase") {
|
||||
return ""
|
||||
}
|
||||
explicitProtocol := ""
|
||||
if explicit := strings.TrimSpace(config.OceanBaseProtocol); explicit != "" {
|
||||
return normalizeOceanBaseProtocolForApp(explicit)
|
||||
explicitProtocol = normalizeOceanBaseProtocolForApp(explicit)
|
||||
if !isSupportedOceanBaseProtocolForApp(explicitProtocol) {
|
||||
return explicitProtocol
|
||||
}
|
||||
}
|
||||
if protocol := resolveOceanBaseProtocolParam(config.ConnectionParams); protocol != "" {
|
||||
if !isSupportedOceanBaseProtocolForApp(protocol) {
|
||||
return protocol
|
||||
}
|
||||
if explicitProtocol != "" {
|
||||
return explicitProtocol
|
||||
}
|
||||
return protocol
|
||||
}
|
||||
if protocol := resolveOceanBaseProtocolParam(config.URI); protocol != "" {
|
||||
if !isSupportedOceanBaseProtocolForApp(protocol) {
|
||||
return protocol
|
||||
}
|
||||
if explicitProtocol != "" {
|
||||
return explicitProtocol
|
||||
}
|
||||
return protocol
|
||||
}
|
||||
if explicitProtocol != "" {
|
||||
return explicitProtocol
|
||||
}
|
||||
return "mysql"
|
||||
}
|
||||
|
||||
@@ -57,7 +81,7 @@ func resolveOceanBaseProtocolParam(raw string) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func normalizeOceanBaseConnectionParamsForCache(raw string) string {
|
||||
func stripOceanBaseConnectionParamsForCache(raw string) string {
|
||||
text := strings.TrimSpace(raw)
|
||||
if text == "" {
|
||||
return ""
|
||||
@@ -69,26 +93,40 @@ func normalizeOceanBaseConnectionParamsForCache(raw string) string {
|
||||
if len(values) == 0 {
|
||||
return ""
|
||||
}
|
||||
protocol := resolveOceanBaseProtocolParam(raw)
|
||||
for _, key := range []string{"protocol", "oceanBaseProtocol", "oceanbaseProtocol", "tenantMode", "compatMode", "mode"} {
|
||||
values.Del(key)
|
||||
}
|
||||
if strings.EqualFold(protocol, "oracle") {
|
||||
values.Set("protocol", "oracle")
|
||||
}
|
||||
return values.Encode()
|
||||
}
|
||||
|
||||
func normalizeOceanBaseConnectionParamsForCache(raw string) string {
|
||||
normalized := stripOceanBaseConnectionParamsForCache(raw)
|
||||
protocol := resolveOceanBaseProtocolParam(raw)
|
||||
if protocol != "" && !strings.EqualFold(protocol, "mysql") {
|
||||
values, err := url.ParseQuery(strings.TrimLeft(strings.TrimSpace(normalized), "?&"))
|
||||
if err != nil {
|
||||
values = url.Values{}
|
||||
}
|
||||
values.Set("protocol", protocol)
|
||||
return values.Encode()
|
||||
}
|
||||
return normalized
|
||||
}
|
||||
|
||||
func normalizeOceanBaseConnectionParamsForCacheWithProtocol(raw string, protocol string) string {
|
||||
normalized := normalizeOceanBaseConnectionParamsForCache(raw)
|
||||
if !strings.EqualFold(protocol, "oracle") {
|
||||
resolvedProtocol := normalizeOceanBaseProtocolForApp(protocol)
|
||||
if resolvedProtocol == "" {
|
||||
return normalizeOceanBaseConnectionParamsForCache(raw)
|
||||
}
|
||||
normalized := stripOceanBaseConnectionParamsForCache(raw)
|
||||
if strings.EqualFold(resolvedProtocol, "mysql") {
|
||||
return normalized
|
||||
}
|
||||
values, err := url.ParseQuery(strings.TrimLeft(strings.TrimSpace(normalized), "?&"))
|
||||
if err != nil {
|
||||
values = url.Values{}
|
||||
}
|
||||
values.Set("protocol", "oracle")
|
||||
values.Set("protocol", resolvedProtocol)
|
||||
return values.Encode()
|
||||
}
|
||||
|
||||
|
||||
@@ -104,7 +104,7 @@ type ConnectionConfig struct {
|
||||
RedisDB int `json:"redisDB,omitempty"` // Redis database index (0-15)
|
||||
URI string `json:"uri,omitempty"` // Connection URI for copy/paste
|
||||
ClickHouseProtocol string `json:"clickHouseProtocol,omitempty"` // auto | http | native
|
||||
OceanBaseProtocol string `json:"oceanBaseProtocol,omitempty"` // mysql | oracle
|
||||
OceanBaseProtocol string `json:"oceanBaseProtocol,omitempty"` // OceanBase tenant compatibility protocol: mysql | oracle
|
||||
Hosts []string `json:"hosts,omitempty"` // Multi-host addresses: host:port
|
||||
Topology string `json:"topology,omitempty"` // single | replica | cluster
|
||||
MySQLReplicaUser string `json:"mysqlReplicaUser,omitempty"` // MySQL replica auth user
|
||||
|
||||
@@ -151,36 +151,62 @@ func normalizeOceanBaseProtocol(raw string) string {
|
||||
switch strings.ToLower(strings.TrimSpace(raw)) {
|
||||
case oceanBaseProtocolOracle, "oracle-mode", "oracle_mode", "oboracle":
|
||||
return oceanBaseProtocolOracle
|
||||
case oceanBaseProtocolMySQL, "mysql-compatible", "mysql_compatible", "mysql-mode", "mysql_mode", "":
|
||||
case oceanBaseProtocolMySQL, "mysql-compatible", "mysql_compatible", "mysql-mode", "mysql_mode", "obmysql", "":
|
||||
return oceanBaseProtocolMySQL
|
||||
default:
|
||||
return oceanBaseProtocolMySQL
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
func resolveOceanBaseProtocolFromValues(values url.Values) string {
|
||||
func unsupportedOceanBaseProtocolError(raw string) error {
|
||||
return fmt.Errorf("OceanBase 当前仅支持 MySQL/Oracle 租户协议,不支持 %q;请改为 MySQL 或 Oracle", strings.TrimSpace(raw))
|
||||
}
|
||||
|
||||
func resolveOceanBaseProtocolFromValues(values url.Values) (string, error) {
|
||||
if len(values) == 0 {
|
||||
return ""
|
||||
return "", nil
|
||||
}
|
||||
for _, key := range []string{"protocol", "oceanBaseProtocol", "oceanbaseProtocol", "tenantMode", "compatMode", "mode"} {
|
||||
if value := strings.TrimSpace(values.Get(key)); value != "" {
|
||||
return normalizeOceanBaseProtocol(value)
|
||||
protocol := normalizeOceanBaseProtocol(value)
|
||||
if protocol == "" {
|
||||
return "", unsupportedOceanBaseProtocolError(value)
|
||||
}
|
||||
return protocol, nil
|
||||
}
|
||||
}
|
||||
return ""
|
||||
return "", nil
|
||||
}
|
||||
|
||||
func resolveOceanBaseProtocol(config connection.ConnectionConfig) string {
|
||||
func resolveOceanBaseProtocol(config connection.ConnectionConfig) (string, error) {
|
||||
explicitProtocol := ""
|
||||
if explicit := strings.TrimSpace(config.OceanBaseProtocol); explicit != "" {
|
||||
return normalizeOceanBaseProtocol(explicit)
|
||||
protocol := normalizeOceanBaseProtocol(explicit)
|
||||
if protocol == "" {
|
||||
return "", unsupportedOceanBaseProtocolError(explicit)
|
||||
}
|
||||
explicitProtocol = protocol
|
||||
}
|
||||
if protocol := resolveOceanBaseProtocolFromValues(connectionParamsFromText(config.ConnectionParams)); protocol != "" {
|
||||
return protocol
|
||||
if protocol, err := resolveOceanBaseProtocolFromValues(connectionParamsFromText(config.ConnectionParams)); err != nil {
|
||||
return "", err
|
||||
} else if protocol != "" {
|
||||
if explicitProtocol != "" {
|
||||
return explicitProtocol, nil
|
||||
}
|
||||
return protocol, nil
|
||||
}
|
||||
if protocol := resolveOceanBaseProtocolFromValues(connectionParamsFromURI(config.URI, "oceanbase", "mysql")); protocol != "" {
|
||||
return protocol
|
||||
if protocol, err := resolveOceanBaseProtocolFromValues(connectionParamsFromURI(config.URI, "oceanbase", "mysql")); err != nil {
|
||||
return "", err
|
||||
} else if protocol != "" {
|
||||
if explicitProtocol != "" {
|
||||
return explicitProtocol, nil
|
||||
}
|
||||
return protocol, nil
|
||||
}
|
||||
return oceanBaseProtocolMySQL
|
||||
if explicitProtocol != "" {
|
||||
return explicitProtocol, nil
|
||||
}
|
||||
return oceanBaseProtocolMySQL, nil
|
||||
}
|
||||
|
||||
func stripOceanBaseProtocolParams(raw string) string {
|
||||
@@ -256,7 +282,10 @@ func (o *OceanBaseDB) Connect(config connection.ConnectionConfig) error {
|
||||
o.oracle = nil
|
||||
o.protocol = oceanBaseProtocolMySQL
|
||||
appliedConfig := applyOceanBaseURI(config)
|
||||
protocol := resolveOceanBaseProtocol(appliedConfig)
|
||||
protocol, err := resolveOceanBaseProtocol(appliedConfig)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
runConfig := withoutOceanBaseProtocolParams(appliedConfig)
|
||||
if protocol == oceanBaseProtocolOracle {
|
||||
logger.Infof("OceanBase 使用 Oracle 协议连接:地址=%s:%d 用户=%s", runConfig.Host, runConfig.Port, runConfig.User)
|
||||
|
||||
@@ -79,13 +79,52 @@ func TestResolveOceanBaseProtocol(t *testing.T) {
|
||||
tt := tt
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
if got := resolveOceanBaseProtocol(tt.config); got != tt.want {
|
||||
got, err := resolveOceanBaseProtocol(tt.config)
|
||||
if err != nil {
|
||||
t.Fatalf("resolveOceanBaseProtocol() unexpected error: %v", err)
|
||||
}
|
||||
if got != tt.want {
|
||||
t.Fatalf("resolveOceanBaseProtocol() = %q, want %q", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveOceanBaseProtocolRejectsUnsupportedNative(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
config connection.ConnectionConfig
|
||||
}{
|
||||
{
|
||||
name: "params native",
|
||||
config: connection.ConnectionConfig{
|
||||
Type: "oceanbase",
|
||||
ConnectionParams: "protocol=native",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "explicit mysql does not mask params native",
|
||||
config: connection.ConnectionConfig{
|
||||
Type: "oceanbase",
|
||||
OceanBaseProtocol: "mysql",
|
||||
ConnectionParams: "protocol=native",
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
_, err := resolveOceanBaseProtocol(tt.config)
|
||||
if err == nil || !strings.Contains(err.Error(), "不支持") {
|
||||
t.Fatalf("expected unsupported protocol error, got %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithoutOceanBaseProtocolParamsStripsDriverMeta(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user