feat(iris): 新增 InterSystems IRIS 数据源支持

- 后端新增 IRIS 连接、查询、DDL、索引元数据和 DataGrid 编辑能力
- 接入 optional driver-agent、构建标签、revision 生成和变更检测流程
- 前端新增 IRIS 连接入口、方言映射、能力配置和图标展示
- 修复 IRIS 主键识别、事务开启错误处理和驱动连接关闭问题
- 补充后端、前端和构建脚本相关回归测试
Refs #408
This commit is contained in:
Syngnat
2026-05-17 10:32:08 +08:00
parent 0cde96844d
commit 992d2dee45
57 changed files with 4391 additions and 16 deletions

View File

@@ -5,7 +5,7 @@ set -euo pipefail
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
cd "$SCRIPT_DIR"
DEFAULT_DRIVERS=(mariadb oceanbase doris starrocks sphinx sqlserver sqlite duckdb dameng kingbase highgo vastbase opengauss mongodb tdengine clickhouse)
DEFAULT_DRIVERS=(mariadb oceanbase doris starrocks sphinx sqlserver sqlite duckdb dameng kingbase highgo vastbase opengauss iris mongodb tdengine clickhouse)
DEFAULT_PLATFORMS=(darwin/amd64 darwin/arm64 windows/amd64 windows/arm64 linux/amd64 linux/arm64)
DUCKDB_WINDOWS_LIBRARY_VERSION="v1.4.4"
DUCKDB_WINDOWS_LIBRARY_URL="https://github.com/duckdb/duckdb/releases/download/${DUCKDB_WINDOWS_LIBRARY_VERSION}/libduckdb-windows-amd64.zip"
@@ -42,7 +42,7 @@ normalize_driver() {
case "$name" in
doris|diros) echo "doris" ;;
open_gauss|open-gauss) echo "opengauss" ;;
mariadb|oceanbase|starrocks|sphinx|sqlserver|sqlite|duckdb|dameng|kingbase|highgo|vastbase|opengauss|mongodb|tdengine|clickhouse)
mariadb|oceanbase|starrocks|sphinx|sqlserver|sqlite|duckdb|dameng|kingbase|highgo|vastbase|opengauss|iris|mongodb|tdengine|clickhouse)
echo "$name"
;;
*)

View File

@@ -0,0 +1,12 @@
//go:build gonavi_iris_driver
package main
import "GoNavi-Wails/internal/db"
func init() {
agentDriverType = "iris"
agentDatabaseFactory = func() db.Database {
return &db.IrisDB{}
}
}

View File

@@ -224,6 +224,8 @@ const getDefaultPortByType = (type: string) => {
return 54321;
case "sqlserver":
return 1433;
case "iris":
return 1972;
case "mongodb":
return 27017;
case "highgo":
@@ -247,6 +249,7 @@ const singleHostUriSchemesByType: Record<string, string[]> = {
clickhouse: ["clickhouse"],
oracle: ["oracle"],
sqlserver: ["sqlserver"],
iris: ["iris", "intersystems"],
redis: ["redis"],
tdengine: ["tdengine"],
dameng: ["dameng", "dm"],
@@ -368,6 +371,7 @@ const supportsConnectionParamsForType = (type: string) =>
type === "opengauss" ||
type === "oracle" ||
type === "sqlserver" ||
type === "iris" ||
type === "clickhouse" ||
type === "mongodb" ||
type === "dameng" ||
@@ -390,6 +394,13 @@ const normalizeDriverType = (value: string): string => {
.toLowerCase();
if (normalized === "postgresql") return "postgres";
if (normalized === "doris") return "diros";
if (
normalized === "intersystems" ||
normalized === "intersystemsiris" ||
normalized === "inter-systems-iris" ||
normalized === "inter-systems"
)
return "iris";
if (
normalized === "open_gauss" ||
normalized === "open-gauss" ||
@@ -1980,6 +1991,9 @@ const ConnectionModal: React.FC<{
if (dbType === "oracle") {
return "oracle://user:pass@127.0.0.1:1521/ORCLPDB1";
}
if (dbType === "iris") {
return "iris://user:pass@127.0.0.1:1972/USER";
}
if (dbType === "opengauss") {
return "opengauss://user:pass@127.0.0.1:5432/db_name";
}
@@ -2006,6 +2020,8 @@ const ConnectionModal: React.FC<{
return "PREFETCH_ROWS=5000&TRACE FILE=/tmp/go-ora.trc";
case "sqlserver":
return "app name=GoNavi&packet size=32767";
case "iris":
return "timeout=30";
case "clickhouse":
return "max_execution_time=60&compress=lz4";
case "mongodb":
@@ -3869,6 +3885,11 @@ const ConnectionModal: React.FC<{
name: "SQL Server",
icon: getDbIcon("sqlserver", undefined, 36),
},
{
key: "iris",
name: "InterSystems IRIS",
icon: getDbIcon("iris", undefined, 36),
},
{
key: "sqlite",
name: "SQLite",

View File

@@ -0,0 +1,10 @@
import { describe, expect, it } from 'vitest';
import { DB_ICON_TYPES, getDbIconLabel } from './DatabaseIcons';
describe('DatabaseIcons', () => {
it('includes InterSystems IRIS in the selectable database icons', () => {
expect(DB_ICON_TYPES).toContain('iris');
expect(getDbIconLabel('iris')).toBe('InterSystems IRIS');
});
});

View File

@@ -27,6 +27,7 @@ const DB_DEFAULT_COLORS: Record<string, string> = {
vastbase: '#0066CC',
opengauss: '#2446A8',
highgo: '#00A86B',
iris: '#1F6FEB',
tdengine: '#2962FF',
diros: '#0050B3',
starrocks: '#00A6A6',
@@ -146,6 +147,9 @@ const OpenGaussIcon: React.FC<DbIconProps> = ({ size = 16, color }) => (
const HighGoIcon: React.FC<DbIconProps> = ({ size = 16, color }) => (
<ColorBadge size={size} color={color || DB_DEFAULT_COLORS.highgo} label="HG" />
);
const IrisIcon: React.FC<DbIconProps> = ({ size = 16, color }) => (
<ColorBadge size={size} color={color || DB_DEFAULT_COLORS.iris} label="IR" />
);
const TDengineIcon: React.FC<DbIconProps> = ({ size = 16, color }) => (
<ColorBadge size={size} color={color || DB_DEFAULT_COLORS.tdengine} label="TD" />
);
@@ -195,6 +199,7 @@ const DB_ICON_MAP: Record<string, React.FC<DbIconProps>> = {
vastbase: VastBaseIcon,
opengauss: OpenGaussIcon,
highgo: HighGoIcon,
iris: IrisIcon,
tdengine: TDengineIcon,
custom: CustomIcon,
};
@@ -203,7 +208,7 @@ const DB_ICON_MAP: Record<string, React.FC<DbIconProps>> = {
export const DB_ICON_TYPES: string[] = [
'mysql', 'mariadb', 'oceanbase', 'postgres', 'redis', 'mongodb', 'jvm',
'oracle', 'sqlserver', 'sqlite', 'duckdb', 'clickhouse', 'starrocks',
'kingbase', 'dameng', 'vastbase', 'opengauss', 'highgo', 'tdengine', 'custom',
'kingbase', 'dameng', 'vastbase', 'opengauss', 'highgo', 'iris', 'tdengine', 'custom',
];
/** 该类型是否有品牌 SVG 文件 */
@@ -225,7 +230,7 @@ export const getDbIconLabel = (type: string): string => {
sqlserver: 'SQL Server', clickhouse: 'ClickHouse', sqlite: 'SQLite',
starrocks: 'StarRocks',
duckdb: 'DuckDB', kingbase: '金仓', dameng: '达梦',
vastbase: 'VastBase', opengauss: 'OpenGauss', highgo: '瀚高', tdengine: 'TDengine',
vastbase: 'VastBase', opengauss: 'OpenGauss', highgo: '瀚高', iris: 'InterSystems IRIS', tdengine: 'TDengine',
custom: '自定义',
};
return labels[type?.toLowerCase()] || type;

View File

@@ -131,6 +131,12 @@ const normalizeDriverType = (value: string): string => {
normalized === 'open-gauss' ||
normalized === 'opengauss'
) return 'opengauss';
if (
normalized === 'intersystems' ||
normalized === 'intersystemsiris' ||
normalized === 'inter-systems' ||
normalized === 'inter-systems-iris'
) return 'iris';
return normalized;
};
@@ -648,6 +654,7 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
'open_gauss',
'open-gauss',
'sqlserver',
'iris',
'oracle',
'dameng',
]);
@@ -661,6 +668,7 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
'open_gauss',
'open-gauss',
'sqlserver',
'iris',
'oracle',
'dm',
]);

View File

@@ -38,6 +38,10 @@ const resolveCustomDriverDialect = (driver: string): string => {
return 'highgo';
case 'vastbase':
return 'vastbase';
case 'iris':
case 'intersystems':
case 'intersystemsiris':
return 'iris';
default:
break;
}

View File

@@ -84,6 +84,7 @@ describe('connectionModalPresentation', () => {
'highgo',
'vastbase',
'opengauss',
'iris',
'mongodb',
'redis',
'tdengine',
@@ -139,6 +140,13 @@ describe('connectionModalPresentation', () => {
'customDriver',
'customDsn',
]);
expect(resolveConnectionConfigLayout('iris').sections).toEqual([
'identity',
'uri',
'target',
'credentials',
'databaseScope',
]);
});
it('uses localized labels for layout kinds shown in the modal', () => {

View File

@@ -40,6 +40,20 @@ describe('dataSourceCapabilities', () => {
});
});
it('keeps InterSystems IRIS as an editable SQL datasource capability', () => {
expect(getDataSourceCapabilities({ type: 'iris' })).toMatchObject({
type: 'iris',
supportsQueryEditor: true,
supportsSqlQueryExport: true,
supportsCopyInsert: true,
forceReadOnlyQueryResult: false,
});
expect(getDataSourceCapabilities({ type: 'custom', driver: 'intersystemsiris' })).toMatchObject({
type: 'iris',
supportsQueryEditor: true,
});
});
it('treats OceanBase Oracle protocol as Oracle capabilities', () => {
expect(getDataSourceCapabilities({
type: 'oceanbase',

View File

@@ -18,6 +18,11 @@ const normalizeDataSourceToken = (raw: string): string => {
return 'opengauss';
case 'dm':
return 'dameng';
case 'intersystems':
case 'intersystemsiris':
case 'inter-systems':
case 'inter-systems-iris':
return 'iris';
default:
return normalized;
}
@@ -52,6 +57,7 @@ const SQL_QUERY_EXPORT_TYPES = new Set([
'vastbase',
'opengauss',
'sqlserver',
'iris',
'sqlite',
'duckdb',
'oracle',
@@ -73,6 +79,7 @@ const COPY_INSERT_TYPES = new Set([
'vastbase',
'opengauss',
'sqlserver',
'iris',
'sqlite',
'duckdb',
'oracle',

View File

@@ -19,6 +19,8 @@ describe('driver import guidance', () => {
expect(CUSTOM_CONNECTION_DRIVER_HELP).toContain('pgx');
expect(CUSTOM_CONNECTION_DRIVER_HELP).toContain('open_gauss');
expect(CUSTOM_CONNECTION_DRIVER_HELP).toContain('oceanbase');
expect(CUSTOM_CONNECTION_DRIVER_HELP).toContain('Go database/sql');
expect(CUSTOM_CONNECTION_DRIVER_HELP).toContain('ODBC/JDBC');
expect(CUSTOM_CONNECTION_DRIVER_HELP).toContain('JDBC Jar');
});
});

View File

@@ -7,4 +7,4 @@ export const DRIVER_LOCAL_IMPORT_SINGLE_FILE_HELP =
'行内“导入驱动包”仅用于单个驱动文件/总包(如 `mariadb-driver-agent`、`mariadb-driver-agent.exe`、`GoNavi-DriverAgents.zip`),不支持直接导入 JDBC Jar批量导入请使用上方“导入驱动目录”。';
export const CUSTOM_CONNECTION_DRIVER_HELP =
'已支持: mysql, starrocks, oceanbase, postgres, opengauss, sqlite, oracle, dm, kingbase别名支持 postgresql/pgx、open_gauss/open-gauss、dm8、kingbase8/kingbasees/kingbasev8。当前不支持通过 JDBC Jar 扩展驱动。';
'已支持: mysql, starrocks, oceanbase, postgres, opengauss, sqlite, oracle, dm, kingbase别名支持 postgresql/pgx、open_gauss/open-gauss、dm8、kingbase8/kingbasees/kingbasev8。请填写 GoNavi 已注册的 Go database/sql 驱动名,不能直接填写系统 ODBC/JDBC 驱动名或导入 JDBC Jar。';

View File

@@ -19,6 +19,8 @@ describe('sqlDialect', () => {
expect(resolveSqlDialect('doris')).toBe('diros');
expect(resolveSqlDialect('StarRocks')).toBe('starrocks');
expect(resolveSqlDialect('dameng')).toBe('dameng');
expect(resolveSqlDialect('InterSystems IRIS')).toBe('iris');
expect(resolveSqlDialect('custom', 'intersystemsiris')).toBe('iris');
expect(resolveSqlDialect('custom', 'kingbase8')).toBe('kingbase');
expect(resolveSqlDialect('custom', 'dm8')).toBe('dameng');
expect(resolveSqlDialect('custom', 'mariadb')).toBe('mariadb');
@@ -43,6 +45,7 @@ describe('sqlDialect', () => {
expect(values(resolveColumnTypeOptions('starrocks'))).toContain('PERCENTILE');
expect(values(resolveColumnTypeOptions('sphinx'))).toContain('text');
expect(values(resolveColumnTypeOptions('clickhouse'))).toContain('DateTime64(3)');
expect(values(resolveColumnTypeOptions('iris'))).toContain('varchar(255)');
expect(values(resolveColumnTypeOptions('tdengine'))).toContain('TIMESTAMP');
expect(values(resolveColumnTypeOptions('duckdb'))).toContain('STRUCT');
});

View File

@@ -22,6 +22,7 @@ export type SqlDialect =
| 'oracle'
| 'dameng'
| 'sqlserver'
| 'iris'
| 'sqlite'
| 'duckdb'
| 'clickhouse'
@@ -68,6 +69,12 @@ export const resolveSqlDialect = (
case 'sql_server':
case 'sql-server':
return 'sqlserver';
case 'intersystems':
case 'intersystemsiris':
case 'inter-systems':
case 'inter-systems-iris':
case 'iris':
return 'iris';
case 'doris':
case 'diros':
return 'diros';
@@ -122,6 +129,7 @@ export const resolveSqlDialect = (
if (source.includes('clickhouse')) return 'clickhouse';
if (source.includes('tdengine')) return 'tdengine';
if (source.includes('sqlserver') || source.includes('mssql')) return 'sqlserver';
if (source.includes('iris') || source.includes('intersystems')) return 'iris';
return source;
};
@@ -479,6 +487,7 @@ export const resolveColumnTypeOptions = (dbType: string): ColumnTypeOption[] =>
if (dialect === 'oracle') return ORACLE_TYPES;
if (dialect === 'dameng') return DAMENG_TYPES;
if (dialect === 'sqlserver') return SQLSERVER_TYPES;
if (dialect === 'iris') return COMMON_TYPES;
if (dialect === 'sqlite') return SQLITE_TYPES;
if (dialect === 'duckdb') return DUCKDB_TYPES;
if (dialect === 'clickhouse') return CLICKHOUSE_TYPES;

3
go.mod
View File

@@ -6,6 +6,7 @@ require (
gitea.com/kingbase/gokb v0.0.0-20201021123113-29bd62a876c3
gitee.com/chunanyong/dm v1.8.22
github.com/ClickHouse/clickhouse-go/v2 v2.43.0
github.com/caretdev/go-irisnative v0.2.1
github.com/duckdb/duckdb-go/v2 v2.5.5
github.com/go-sql-driver/mysql v1.9.3
github.com/google/uuid v1.6.0
@@ -122,3 +123,5 @@ require (
)
replace github.com/highgo/pq-sm3 => ./third_party/highgo-pq
replace github.com/caretdev/go-irisnative => ./third_party/go-irisnative

View File

@@ -20,7 +20,7 @@ func normalizeRunConfig(config connection.ConnectionConfig, dbName string) conne
if !isOceanBaseOracleProtocol(config) {
runConfig.Database = name
}
case "mysql", "mariadb", "diros", "starrocks", "sphinx", "postgres", "kingbase", "highgo", "vastbase", "opengauss", "sqlserver", "mongodb", "tdengine", "clickhouse":
case "mysql", "mariadb", "diros", "starrocks", "sphinx", "postgres", "kingbase", "highgo", "vastbase", "opengauss", "sqlserver", "iris", "intersystems", "intersystemsiris", "inter-systems", "inter-systems-iris", "mongodb", "tdengine", "clickhouse":
// 这些类型的 dbName 表示"数据库",需要写入连接配置以选择目标库。
runConfig.Database = name
case "dameng":
@@ -68,6 +68,16 @@ func normalizeSchemaAndTable(config connection.ConnectionConfig, dbName string,
}
}
if dbType == "iris" {
schema, table := db.SplitSQLQualifiedName(rawTable)
if schema != "" && table != "" {
return schema, table
}
if table != "" {
return "", table
}
}
if parts := strings.SplitN(rawTable, ".", 2); len(parts) == 2 {
schema := strings.TrimSpace(parts[0])
table := strings.TrimSpace(parts[1])

View File

@@ -90,6 +90,43 @@ func TestNormalizeRunConfig_StarRocksUsesDatabaseFromTree(t *testing.T) {
}
}
func TestNormalizeRunConfig_IRISUsesNamespaceFromTree(t *testing.T) {
t.Parallel()
runConfig := normalizeRunConfig(connection.ConnectionConfig{
Type: "iris",
Database: "USER",
}, "APP")
if runConfig.Database != "APP" {
t.Fatalf("expected IRIS namespace from tree, got %q", runConfig.Database)
}
}
func TestNormalizeSchemaAndTable_IRISDoesNotTreatNamespaceAsSchema(t *testing.T) {
t.Parallel()
schema, table := normalizeSchemaAndTable(connection.ConnectionConfig{
Type: "iris",
}, "USER", "Person")
if schema != "" || table != "Person" {
t.Fatalf("expected IRIS pure table to omit schema, got %q.%q", schema, table)
}
}
func TestNormalizeSchemaAndTable_IRISSplitsQualifiedTable(t *testing.T) {
t.Parallel()
schema, table := normalizeSchemaAndTable(connection.ConnectionConfig{
Type: "iris",
}, "USER", `"Sample.Schema"."Person.Table"`)
if schema != "Sample.Schema" || table != "Person.Table" {
t.Fatalf("expected IRIS qualified table split, got %q.%q", schema, table)
}
}
func TestNormalizeSchemaAndTable_OceanBaseOracleUsesSchemaFromDatabaseTree(t *testing.T) {
t.Parallel()

View File

@@ -231,6 +231,8 @@ func defaultPortByType(driverType string) int {
return 9000
case "highgo":
return 5866
case "iris":
return 1972
default:
return 0
}

View File

@@ -163,6 +163,9 @@ func resolveDDLDBType(config connection.ConnectionConfig) string {
if dbType == "kingbase8" || dbType == "kingbasees" || dbType == "kingbasev8" {
return "kingbase"
}
if dbType == "intersystems" || dbType == "intersystemsiris" || dbType == "inter-systems" || dbType == "inter-systems-iris" {
return "iris"
}
if dbType == "oceanbase" && isOceanBaseOracleProtocol(config) {
return "oracle"
}
@@ -194,6 +197,8 @@ func resolveDDLDBType(config connection.ConnectionConfig) string {
return "highgo"
case "vastbase":
return "vastbase"
case "iris", "intersystems", "intersystemsiris", "inter-systems", "inter-systems-iris":
return "iris"
case "oceanbase":
return "oceanbase"
}
@@ -209,6 +214,8 @@ func resolveDDLDBType(config connection.ConnectionConfig) string {
return "highgo"
case strings.Contains(driver, "vastbase"):
return "vastbase"
case strings.Contains(driver, "iris"), strings.Contains(driver, "intersystems"):
return "iris"
case strings.Contains(driver, "sqlite"):
return "sqlite"
case strings.Contains(driver, "sphinx"):
@@ -253,6 +260,16 @@ func normalizeSchemaAndTableByType(dbType string, dbName string, tableName strin
}
}
if dbType == "iris" {
schema, table := db.SplitSQLQualifiedName(rawTable)
if schema != "" && table != "" {
return schema, table
}
if table != "" {
return "", table
}
}
if parts := strings.SplitN(rawTable, ".", 2); len(parts) == 2 {
schema := strings.TrimSpace(parts[0])
table := strings.TrimSpace(parts[1])

View File

@@ -72,6 +72,7 @@ func TestResolveDDLDBType_CustomDriverAlias(t *testing.T) {
{name: "kingbase contains alias", driver: "kingbasees", want: "kingbase"},
{name: "dm alias", driver: "dm8", want: "dameng"},
{name: "sqlite alias", driver: "sqlite3", want: "sqlite"},
{name: "iris alias", driver: "InterSystems IRIS", want: "iris"},
}
for _, tc := range testCases {
@@ -106,6 +107,14 @@ func TestResolveDDLDBType_KingbaseTypeAlias(t *testing.T) {
}
}
func TestResolveDDLDBType_IRISTypeAlias(t *testing.T) {
t.Parallel()
if got := resolveDDLDBType(connection.ConnectionConfig{Type: "InterSystemsIRIS"}); got != "iris" {
t.Fatalf("expected InterSystemsIRIS type alias to resolve to iris, got %q", got)
}
}
func TestNormalizeSchemaAndTableByType_PGLikeQuotedQualifiedName(t *testing.T) {
t.Parallel()

View File

@@ -346,6 +346,7 @@ const builtinDriverManifestJSON = `{
"highgo": { "engine": "go", "version": "0.0.0-local", "checksumPolicy": "off", "downloadUrl": "builtin://activate/highgo" },
"vastbase": { "engine": "go", "version": "1.11.1", "checksumPolicy": "off", "downloadUrl": "builtin://activate/vastbase" },
"opengauss": { "engine": "go", "version": "1.11.1", "checksumPolicy": "off", "downloadUrl": "builtin://activate/opengauss" },
"iris": { "engine": "go", "version": "0.2.1", "checksumPolicy": "off", "downloadUrl": "builtin://activate/iris" },
"mongodb": { "engine": "go", "version": "2.5.0", "checksumPolicy": "off", "downloadUrl": "builtin://activate/mongodb" },
"tdengine": { "engine": "go", "version": "3.7.8", "checksumPolicy": "off", "downloadUrl": "builtin://activate/tdengine" },
"clickhouse": { "engine": "go", "version": "2.43.1", "checksumPolicy": "off", "downloadUrl": "builtin://activate/clickhouse" }
@@ -402,6 +403,7 @@ var latestDriverVersionMap = map[string]string{
"highgo": "0.0.0-local",
"vastbase": "1.11.2",
"opengauss": "1.11.1",
"iris": "0.2.1",
"mongodb": "2.5.0",
"tdengine": "3.7.8",
"clickhouse": "2.43.1",
@@ -424,6 +426,7 @@ var driverGoModulePathMap = map[string]string{
"highgo": "github.com/highgo/pq-sm3",
"vastbase": "github.com/lib/pq",
"opengauss": "github.com/lib/pq",
"iris": "github.com/caretdev/go-irisnative",
"mongodb": "go.mongodb.org/mongo-driver/v2",
"tdengine": "github.com/taosdata/driver-go/v3",
"clickhouse": "github.com/ClickHouse/clickhouse-go/v2",
@@ -1404,6 +1407,8 @@ func normalizeDriverType(driverType string) string {
return "postgres"
case "opengauss", "open_gauss", "open-gauss":
return "opengauss"
case "intersystems", "intersystemsiris", "inter-systems-iris", "inter-systems":
return "iris"
default:
return normalized
}
@@ -1485,6 +1490,7 @@ func allDriverDefinitionsWithPackages(packages map[string]pinnedDriverPackage) [
buildOptionalGoDriverDefinition("highgo", "HighGo", packages),
buildOptionalGoDriverDefinition("vastbase", "Vastbase", packages),
buildOptionalGoDriverDefinition("opengauss", "OpenGauss", packages),
buildOptionalGoDriverDefinition("iris", "InterSystems IRIS", packages),
buildOptionalGoDriverDefinition("mongodb", "MongoDB", packages),
buildOptionalGoDriverDefinition("tdengine", "TDengine", packages),
buildOptionalGoDriverDefinition("clickhouse", "ClickHouse", packages),
@@ -3804,6 +3810,8 @@ func optionalDriverBuildTag(driverType string, selectedVersion string) (string,
return "gonavi_vastbase_driver", nil
case "opengauss":
return "gonavi_opengauss_driver", nil
case "iris":
return "gonavi_iris_driver", nil
case "mongodb":
if resolveMongoDriverMajorFromVersion(selectedVersion) == 1 {
return "gonavi_mongodb_driver_v1", nil

View File

@@ -138,6 +138,7 @@ func optionalDriverAgentRevisionTestDrivers(t *testing.T) []string {
"highgo",
"vastbase",
"opengauss",
"iris",
"mongodb",
"tdengine",
"clickhouse",

View File

@@ -212,6 +212,36 @@ func TestBuiltinActivatePinnedVersionDoesNotRestrictBundleFallback(t *testing.T)
}
}
func TestIRISDriverDefinitionUsesOptionalAgent(t *testing.T) {
definition, ok := resolveDriverDefinition("iris")
if !ok {
t.Fatal("expected iris driver definition")
}
if definition.Name != "InterSystems IRIS" {
t.Fatalf("unexpected iris driver name: %q", definition.Name)
}
if driverGoModulePathMap["iris"] != "github.com/caretdev/go-irisnative" {
t.Fatalf("unexpected iris go module path: %q", driverGoModulePathMap["iris"])
}
if definition.PinnedVersion != "0.2.1" {
t.Fatalf("unexpected iris definition pinned version: %q", definition.PinnedVersion)
}
if definition.DefaultDownloadURL != "builtin://activate/iris" {
t.Fatalf("unexpected iris default download URL: %q", definition.DefaultDownloadURL)
}
if latestDriverVersionMap["iris"] != "0.2.1" {
t.Fatalf("unexpected iris pinned version: %q", latestDriverVersionMap["iris"])
}
tags, err := optionalDriverBuildTags("iris", "")
if err != nil {
t.Fatalf("resolve iris build tags failed: %v", err)
}
if tags != "gonavi_iris_driver" {
t.Fatalf("unexpected iris build tag: %q", tags)
}
}
func TestBuildOptionalDriverInstallPlanMessagePrefersDirectThenBundle(t *testing.T) {
message := buildOptionalDriverInstallPlanMessage("SQL Server", "1.9.6", false, false, false, false, 1, 2)
if !strings.Contains(message, "先尝试 1 个预编译直链") {

View File

@@ -1602,7 +1602,8 @@ func TestJVMApplyChangeFailedAuditFailureMessageIncludesUnderlyingError(t *testi
if !strings.Contains(res.Message, "失败审计写入失败") {
t.Fatalf("expected failed audit failure marker, got %q", res.Message)
}
if !strings.Contains(strings.ToLower(res.Message), "not a directory") {
lowerMessage := strings.ToLower(res.Message)
if !strings.Contains(lowerMessage, "not a directory") && !strings.Contains(lowerMessage, "system cannot find the path specified") {
t.Fatalf("expected underlying audit failure detail in message, got %q", res.Message)
}
}

View File

@@ -18,19 +18,21 @@ type CustomDB struct {
}
func (c *CustomDB) Connect(config connection.ConnectionConfig) error {
if config.Driver == "" || config.DSN == "" {
driver := strings.TrimSpace(config.Driver)
dsn := strings.TrimSpace(config.DSN)
if driver == "" || dsn == "" {
return fmt.Errorf("driver and dsn are required for custom connection")
}
// Verify driver is registered (implicit check by sql.Open)
// We might not need explicit check, sql.Open will fail or Ping will fail if driver not found.
db, err := sql.Open(config.Driver, config.DSN)
db, err := sql.Open(driver, dsn)
if err != nil {
return fmt.Errorf("打开数据库连接失败:%w", err)
return formatCustomDriverOpenError(driver, err)
}
c.conn = db
c.driver = config.Driver
c.driver = driver
c.pingTimeout = getConnectTimeout(config)
if err := c.Ping(); err != nil {
return fmt.Errorf("连接建立后验证失败:%w", err)
@@ -38,6 +40,27 @@ func (c *CustomDB) Connect(config connection.ConnectionConfig) error {
return nil
}
func formatCustomDriverOpenError(driver string, err error) error {
if err == nil {
return nil
}
if strings.Contains(strings.ToLower(err.Error()), "unknown driver") {
if isLikelySystemODBCDriverName(driver) {
return fmt.Errorf("打开数据库连接失败:自定义连接不支持直接填写系统 ODBC/JDBC 驱动名 %q请填写 GoNavi 已注册的 Go database/sql 驱动名。当前版本未注册通用 ODBC 驱动,因此暂不支持通过 %q 连接 InterSystems IRIS%w", driver, driver, err)
}
return fmt.Errorf("打开数据库连接失败:自定义连接驱动 %q 未在 GoNavi 中注册;请填写已注册的 Go database/sql 驱动名,不能填写系统 ODBC/JDBC 驱动名:%w", driver, err)
}
return fmt.Errorf("打开数据库连接失败:%w", err)
}
func isLikelySystemODBCDriverName(driver string) bool {
normalized := strings.ToLower(strings.TrimSpace(driver))
return strings.Contains(normalized, "odbc") ||
strings.Contains(normalized, "jdbc") ||
strings.Contains(normalized, "intersystems") ||
strings.Contains(normalized, "iris")
}
func (c *CustomDB) Close() error {
if c.conn != nil {
return c.conn.Close()

View File

@@ -0,0 +1,54 @@
package db
import (
"strings"
"testing"
"GoNavi-Wails/internal/connection"
)
func TestCustomDBConnectReportsUnsupportedODBCDriverName(t *testing.T) {
db := &CustomDB{}
err := db.Connect(connection.ConnectionConfig{
Driver: "InterSystems IRIS ODBC35",
DSN: "Driver={InterSystems IRIS ODBC35};Server=127.0.0.1;Port=1972;Database=USER;",
})
if err == nil {
t.Fatal("expected unsupported ODBC driver error, got nil")
}
message := err.Error()
for _, want := range []string{
"ODBC/JDBC",
"Go database/sql",
"暂不支持",
"InterSystems IRIS",
} {
if !strings.Contains(message, want) {
t.Fatalf("expected error to contain %q, got %q", want, message)
}
}
}
func TestCustomDBConnectReportsUnregisteredGoDriver(t *testing.T) {
db := &CustomDB{}
err := db.Connect(connection.ConnectionConfig{
Driver: "not-a-registered-go-driver",
DSN: "demo",
})
if err == nil {
t.Fatal("expected unregistered Go driver error, got nil")
}
message := err.Error()
for _, want := range []string{
"未在 GoNavi 中注册",
"Go database/sql",
} {
if !strings.Contains(message, want) {
t.Fatalf("expected error to contain %q, got %q", want, message)
}
}
}

View File

@@ -130,6 +130,8 @@ func normalizeDatabaseType(dbType string) string {
return "kingbase"
case "opengauss", "open_gauss", "open-gauss":
return "opengauss"
case "intersystems", "intersystemsiris", "inter-systems-iris", "inter-systems":
return "iris"
default:
return normalized
}

View File

@@ -16,6 +16,7 @@ func registerOptionalDatabaseFactories() {
registerDatabaseFactory(newOptionalDriverAgentDatabase("highgo"), "highgo")
registerDatabaseFactory(newOptionalDriverAgentDatabase("vastbase"), "vastbase")
registerDatabaseFactory(newOptionalDriverAgentDatabase("opengauss"), "opengauss", "open_gauss", "open-gauss")
registerDatabaseFactory(newOptionalDriverAgentDatabase("iris"), "iris", "intersystems")
registerDatabaseFactory(newOptionalDriverAgentDatabase("mongodb"), "mongodb")
registerDatabaseFactory(newOptionalDriverAgentDatabase("tdengine"), "tdengine")
registerDatabaseFactory(newOptionalDriverAgentDatabase("clickhouse"), "clickhouse")

View File

@@ -16,6 +16,7 @@ func registerOptionalDatabaseFactories() {
registerDatabaseFactory(newOptionalDriverAgentDatabase("highgo"), "highgo")
registerDatabaseFactory(newOptionalDriverAgentDatabase("vastbase"), "vastbase")
registerDatabaseFactory(newOptionalDriverAgentDatabase("opengauss"), "opengauss", "open_gauss", "open-gauss")
registerDatabaseFactory(newOptionalDriverAgentDatabase("iris"), "iris", "intersystems")
registerDatabaseFactory(newOptionalDriverAgentDatabase("mongodb"), "mongodb")
registerDatabaseFactory(newOptionalDriverAgentDatabase("tdengine"), "tdengine")
registerDatabaseFactory(newOptionalDriverAgentDatabase("clickhouse"), "clickhouse")

View File

@@ -17,6 +17,7 @@ func init() {
"highgo": "src-5a29a1d3685eb6b4",
"vastbase": "src-e3cfef65512feb23",
"opengauss": "src-58227ba3bc1ec894",
"iris": "src-1b072c57af08bec4",
"mongodb": "src-57fdd8bfebdcd46e",
"tdengine": "src-939715f94df1ec9c",
"clickhouse": "src-482d62ed565b3e69",

View File

@@ -34,6 +34,7 @@ var optionalGoDrivers = map[string]struct{}{
"highgo": {},
"vastbase": {},
"opengauss": {},
"iris": {},
"mongodb": {},
"tdengine": {},
"clickhouse": {},
@@ -60,6 +61,8 @@ func normalizeRuntimeDriverType(driverType string) string {
return "kingbase"
case "opengauss", "open_gauss", "open-gauss":
return "opengauss"
case "intersystems", "intersystemsiris", "inter-systems-iris", "inter-systems":
return "iris"
default:
return normalized
}
@@ -101,6 +104,8 @@ func driverDisplayName(driverType string) string {
return "Vastbase"
case "opengauss":
return "OpenGauss"
case "iris":
return "InterSystems IRIS"
case "mongodb":
return "MongoDB"
case "tdengine":

View File

@@ -113,7 +113,7 @@ func TestNewCompatibleDriversAreOptionalAgentDrivers(t *testing.T) {
tmpDir := t.TempDir()
SetExternalDriverDownloadDirectory(tmpDir)
for _, driverType := range []string{"oceanbase", "opengauss", "open_gauss", "starrocks"} {
for _, driverType := range []string{"oceanbase", "opengauss", "open_gauss", "starrocks", "iris", "intersystems"} {
if IsBuiltinDriver(driverType) {
t.Fatalf("%s 不应是免安装内置驱动", driverType)
}

960
internal/db/iris_impl.go Normal file
View File

@@ -0,0 +1,960 @@
//go:build gonavi_full_drivers || gonavi_iris_driver
package db
import (
"context"
"database/sql"
"fmt"
"net"
"net/url"
"sort"
"strconv"
"strings"
"time"
"GoNavi-Wails/internal/connection"
"GoNavi-Wails/internal/logger"
"GoNavi-Wails/internal/ssh"
"GoNavi-Wails/internal/utils"
_ "github.com/caretdev/go-irisnative"
)
const (
defaultIRISPort = 1972
defaultIRISNamespace = "USER"
)
type IrisDB struct {
conn *sql.DB
pingTimeout time.Duration
namespace string
forwarder *ssh.LocalForwarder
}
type irisTableRef struct {
Schema string
Table string
}
func normalizeIRISNamespace(namespace string) string {
trimmed := strings.Trim(strings.TrimSpace(namespace), "/")
if trimmed == "" {
return defaultIRISNamespace
}
return trimmed
}
func applyIRISURI(config connection.ConnectionConfig) connection.ConnectionConfig {
parsed, ok := parseConnectionURI(config.URI, "iris", "intersystems")
if !ok || parsed == nil {
return config
}
next := config
if host := strings.TrimSpace(parsed.Hostname()); host != "" {
next.Host = host
}
if portText := strings.TrimSpace(parsed.Port()); portText != "" {
if port, err := strconv.Atoi(portText); err == nil && port > 0 {
next.Port = port
}
}
if parsed.User != nil {
next.User = parsed.User.Username()
if password, ok := parsed.User.Password(); ok {
next.Password = password
}
}
if namespace := strings.Trim(strings.TrimSpace(parsed.Path), "/"); namespace != "" {
next.Database = namespace
}
return next
}
func (i *IrisDB) getDSN(config connection.ConnectionConfig) string {
namespace := normalizeIRISNamespace(config.Database)
port := config.Port
if port <= 0 {
port = defaultIRISPort
}
u := &url.URL{
Scheme: "iris",
Host: net.JoinHostPort(config.Host, strconv.Itoa(port)),
Path: "/" + namespace,
}
u.User = url.UserPassword(config.User, config.Password)
q := url.Values{}
mergeConnectionParamsFromConfig(q, config, "iris", "intersystems")
u.RawQuery = q.Encode()
return u.String()
}
func (i *IrisDB) Connect(config connection.ConnectionConfig) error {
runConfig := applyIRISURI(config)
if runConfig.Port <= 0 {
runConfig.Port = defaultIRISPort
}
i.namespace = normalizeIRISNamespace(runConfig.Database)
cleanupOnFailure := true
defer func() {
if !cleanupOnFailure {
return
}
if i.conn != nil {
_ = i.conn.Close()
i.conn = nil
}
if i.forwarder != nil {
_ = i.forwarder.Close()
i.forwarder = nil
}
}()
if runConfig.UseSSH {
logger.Infof("InterSystems IRIS 使用 SSH 连接:地址=%s:%d 用户=%s", runConfig.Host, runConfig.Port, runConfig.User)
forwarder, err := ssh.GetOrCreateLocalForwarder(runConfig.SSH, runConfig.Host, runConfig.Port)
if err != nil {
return fmt.Errorf("创建 SSH 隧道失败:%w", err)
}
i.forwarder = forwarder
host, portStr, err := net.SplitHostPort(forwarder.LocalAddr)
if err != nil {
return fmt.Errorf("解析本地转发地址失败:%w", err)
}
port, err := strconv.Atoi(portStr)
if err != nil {
return fmt.Errorf("解析本地端口失败:%w", err)
}
runConfig.Host = host
runConfig.Port = port
runConfig.UseSSH = false
logger.Infof("InterSystems IRIS 通过本地端口转发连接:%s -> %s:%d", forwarder.LocalAddr, config.Host, config.Port)
}
db, err := sql.Open("iris", i.getDSN(runConfig))
if err != nil {
return fmt.Errorf("打开数据库连接失败:%w", err)
}
i.conn = db
i.pingTimeout = getConnectTimeout(runConfig)
if err := i.Ping(); err != nil {
return fmt.Errorf("连接建立后验证失败:%w", err)
}
cleanupOnFailure = false
return nil
}
func (i *IrisDB) Close() error {
if i.forwarder != nil {
if err := i.forwarder.Close(); err != nil {
logger.Warnf("关闭 InterSystems IRIS SSH 端口转发失败:%v", err)
}
i.forwarder = nil
}
if i.conn != nil {
return i.conn.Close()
}
return nil
}
func (i *IrisDB) Ping() error {
if i.conn == nil {
return fmt.Errorf("连接未打开")
}
timeout := i.pingTimeout
if timeout <= 0 {
timeout = 5 * time.Second
}
ctx, cancel := utils.ContextWithTimeout(timeout)
defer cancel()
return i.conn.PingContext(ctx)
}
func (i *IrisDB) QueryMulti(query string) ([]connection.ResultSetData, error) {
if i.conn == nil {
return nil, fmt.Errorf("连接未打开")
}
rows, err := i.conn.Query(query)
if err != nil {
return nil, err
}
defer rows.Close()
return scanMultiRows(rows)
}
func (i *IrisDB) QueryMultiContext(ctx context.Context, query string) ([]connection.ResultSetData, error) {
if i.conn == nil {
return nil, fmt.Errorf("连接未打开")
}
rows, err := i.conn.QueryContext(ctx, query)
if err != nil {
return nil, err
}
defer rows.Close()
return scanMultiRows(rows)
}
func (i *IrisDB) QueryContext(ctx context.Context, query string) ([]map[string]interface{}, []string, error) {
if i.conn == nil {
return nil, nil, fmt.Errorf("连接未打开")
}
rows, err := i.conn.QueryContext(ctx, query)
if err != nil {
return nil, nil, err
}
defer rows.Close()
return scanRows(rows)
}
func (i *IrisDB) Query(query string) ([]map[string]interface{}, []string, error) {
if i.conn == nil {
return nil, nil, fmt.Errorf("连接未打开")
}
rows, err := i.conn.Query(query)
if err != nil {
return nil, nil, err
}
defer rows.Close()
return scanRows(rows)
}
func (i *IrisDB) ExecContext(ctx context.Context, query string) (int64, error) {
if i.conn == nil {
return 0, fmt.Errorf("连接未打开")
}
res, err := i.conn.ExecContext(ctx, query)
if err != nil {
return 0, err
}
return res.RowsAffected()
}
func (i *IrisDB) ExecBatchContext(ctx context.Context, query string) (int64, error) {
return i.ExecContext(ctx, query)
}
func (i *IrisDB) Exec(query string) (int64, error) {
if i.conn == nil {
return 0, fmt.Errorf("连接未打开")
}
res, err := i.conn.Exec(query)
if err != nil {
return 0, err
}
return res.RowsAffected()
}
func (i *IrisDB) GetDatabases() ([]string, error) {
namespace := strings.TrimSpace(i.namespace)
if namespace != "" {
return []string{namespace}, nil
}
data, _, err := i.Query(`SELECT DISTINCT TABLE_CATALOG FROM INFORMATION_SCHEMA.TABLES`)
if err != nil {
return nil, err
}
var namespaces []string
seen := map[string]struct{}{}
for _, row := range data {
name := strings.TrimSpace(rowString(row, "TABLE_CATALOG", "table_catalog"))
if name == "" {
continue
}
if _, ok := seen[name]; ok {
continue
}
seen[name] = struct{}{}
namespaces = append(namespaces, name)
}
sort.Strings(namespaces)
return namespaces, nil
}
func (i *IrisDB) GetTables(dbName string) ([]string, error) {
data, _, err := i.Query(`SELECT * FROM INFORMATION_SCHEMA.TABLES`)
if err != nil {
return nil, err
}
var tables []string
seen := map[string]struct{}{}
for _, row := range data {
tableType := strings.ToUpper(strings.TrimSpace(rowString(row, "TABLE_TYPE", "table_type")))
if tableType != "" && tableType != "TABLE" && tableType != "BASE TABLE" {
continue
}
schema := strings.TrimSpace(rowString(row, "TABLE_SCHEMA", "table_schema", "SCHEMA_NAME", "schema_name"))
table := strings.TrimSpace(rowString(row, "TABLE_NAME", "table_name"))
if table == "" || isIRISSystemSchema(schema) {
continue
}
name := table
if schema != "" {
name = schema + "." + table
}
if _, ok := seen[name]; ok {
continue
}
seen[name] = struct{}{}
tables = append(tables, name)
}
sort.Strings(tables)
return tables, nil
}
func (i *IrisDB) GetCreateStatement(dbName, tableName string) (string, error) {
ref, err := parseIRISTableRef(dbName, tableName)
if err != nil {
return "", err
}
columns, err := i.GetColumns(dbName, tableName)
if err != nil {
return "", err
}
if len(columns) == 0 {
return "", fmt.Errorf("未找到表字段:%s", tableName)
}
indexes, idxErr := i.GetIndexes(dbName, tableName)
if idxErr != nil {
indexes = nil
}
return buildIRISCreateTableDDL(ref, columns, indexes), nil
}
func (i *IrisDB) GetColumns(dbName, tableName string) ([]connection.ColumnDefinition, error) {
ref, err := parseIRISTableRef(dbName, tableName)
if err != nil {
return nil, err
}
data, _, err := i.Query(buildIRISInfoSchemaWhereQuery("INFORMATION_SCHEMA.COLUMNS", ref))
if err != nil {
return nil, err
}
indexes, _ := i.GetIndexes(dbName, tableName)
keyByColumn := irisColumnKeyMap(indexes)
columns := make([]connection.ColumnDefinition, 0, len(data))
for _, row := range data {
name := strings.TrimSpace(rowString(row, "COLUMN_NAME", "column_name"))
if name == "" {
continue
}
key := keyByColumn[name]
if primary, ok := irisBoolFromRow(row, "PRIMARY_KEY", "primary_key"); ok && primary {
key = "PRI"
} else if key == "" {
if unique, ok := irisBoolFromRow(row, "UNIQUE_COLUMN", "unique_column", "IS_UNIQUE", "is_unique", "UNIQUE", "unique"); ok && unique {
key = "UNI"
}
}
col := connection.ColumnDefinition{
Name: name,
Type: buildIRISColumnType(row),
Nullable: normalizeIRISNullable(rowString(row, "IS_NULLABLE", "is_nullable")),
Key: key,
Extra: "",
Comment: rowString(row, "DESCRIPTION", "description", "COMMENT", "comment"),
}
if rawDefault, ok := rowValue(row, "COLUMN_DEFAULT", "column_default"); ok && rawDefault != nil {
def := strings.TrimSpace(fmt.Sprintf("%v", rawDefault))
if def != "" {
col.Default = &def
}
}
columns = append(columns, col)
}
sort.SliceStable(columns, func(a, b int) bool {
return rowOrdinal(data, columns[a].Name) < rowOrdinal(data, columns[b].Name)
})
return columns, nil
}
func (i *IrisDB) GetAllColumns(dbName string) ([]connection.ColumnDefinitionWithTable, error) {
data, _, err := i.Query(`SELECT * FROM INFORMATION_SCHEMA.COLUMNS`)
if err != nil {
return nil, err
}
cols := make([]connection.ColumnDefinitionWithTable, 0, len(data))
for _, row := range data {
schema := strings.TrimSpace(rowString(row, "TABLE_SCHEMA", "table_schema"))
table := strings.TrimSpace(rowString(row, "TABLE_NAME", "table_name"))
name := strings.TrimSpace(rowString(row, "COLUMN_NAME", "column_name"))
if table == "" || name == "" || isIRISSystemSchema(schema) {
continue
}
tableName := table
if schema != "" {
tableName = schema + "." + table
}
cols = append(cols, connection.ColumnDefinitionWithTable{
TableName: tableName,
Name: name,
Type: buildIRISColumnType(row),
})
}
sort.SliceStable(cols, func(a, b int) bool {
if cols[a].TableName == cols[b].TableName {
return cols[a].Name < cols[b].Name
}
return cols[a].TableName < cols[b].TableName
})
return cols, nil
}
func (i *IrisDB) GetIndexes(dbName, tableName string) ([]connection.IndexDefinition, error) {
ref, err := parseIRISTableRef(dbName, tableName)
if err != nil {
return nil, err
}
data, _, err := i.Query(buildIRISInfoSchemaWhereQuery("INFORMATION_SCHEMA.INDEXES", ref))
if err != nil {
return nil, err
}
indexes := make([]connection.IndexDefinition, 0, len(data))
for _, row := range data {
name := strings.TrimSpace(rowString(row, "INDEX_NAME", "index_name", "KEY_NAME", "key_name", "CONSTRAINT_NAME", "constraint_name"))
column := strings.TrimSpace(rowString(row, "COLUMN_NAME", "column_name"))
primary, hasPrimaryFlag := irisBoolFromRow(row, "PRIMARY_KEY", "primary_key")
if name == "" && hasPrimaryFlag && primary {
name = "PRIMARY"
}
if name == "" || column == "" {
continue
}
indexType := normalizeIRISIndexType(rowString(row, "INDEX_TYPE", "index_type", "TYPE", "type"))
if hasPrimaryFlag && primary {
indexType = "PRIMARY"
}
nonUnique := parseIRISNonUnique(row)
indexes = append(indexes, connection.IndexDefinition{
Name: name,
ColumnName: column,
NonUnique: nonUnique,
SeqInIndex: parseIRISInt(rowValueAny(row, "ORDINAL_POSITION", "ordinal_position", "SEQ_IN_INDEX", "seq_in_index", "KEY_SEQ", "key_seq")),
IndexType: indexType,
})
}
sort.SliceStable(indexes, func(a, b int) bool {
if indexes[a].Name == indexes[b].Name {
return indexes[a].SeqInIndex < indexes[b].SeqInIndex
}
return indexes[a].Name < indexes[b].Name
})
return indexes, nil
}
func (i *IrisDB) GetForeignKeys(dbName, tableName string) ([]connection.ForeignKeyDefinition, error) {
return []connection.ForeignKeyDefinition{}, nil
}
func (i *IrisDB) GetTriggers(dbName, tableName string) ([]connection.TriggerDefinition, error) {
return []connection.TriggerDefinition{}, nil
}
func (i *IrisDB) ApplyChanges(tableName string, changes connection.ChangeSet) error {
if i.conn == nil {
return fmt.Errorf("连接未打开")
}
tx, err := i.conn.Begin()
if err != nil {
return err
}
defer tx.Rollback()
for _, keys := range changes.Deletes {
query, args, ok := buildIRISDeleteSQL(tableName, keys)
if !ok {
continue
}
res, err := tx.Exec(query, args...)
if err != nil {
return fmt.Errorf("删除失败:%v", err)
}
if err := requireSingleRowAffected(res, "删除"); err != nil {
return err
}
}
for _, update := range changes.Updates {
query, args, ok, err := buildIRISUpdateSQL(tableName, update)
if err != nil {
return err
}
if !ok {
continue
}
res, err := tx.Exec(query, args...)
if err != nil {
return fmt.Errorf("更新失败:%v", err)
}
if err := requireSingleRowAffected(res, "更新"); err != nil {
return err
}
}
for _, row := range changes.Inserts {
query, args, ok := buildIRISInsertSQL(tableName, row)
if !ok {
continue
}
res, err := tx.Exec(query, args...)
if err != nil {
return fmt.Errorf("插入失败:%v", err)
}
if affected, err := res.RowsAffected(); err == nil && affected == 0 {
return fmt.Errorf("插入未生效:未影响任何行")
}
}
return tx.Commit()
}
func buildIRISInfoSchemaWhereQuery(table string, ref irisTableRef) string {
conditions := []string{fmt.Sprintf("TABLE_NAME = '%s'", irisSQLLiteral(ref.Table))}
if ref.Schema != "" {
conditions = append(conditions, fmt.Sprintf("TABLE_SCHEMA = '%s'", irisSQLLiteral(ref.Schema)))
}
orderBy := ""
switch strings.ToUpper(strings.TrimSpace(table)) {
case "INFORMATION_SCHEMA.COLUMNS":
orderBy = " ORDER BY ORDINAL_POSITION"
case "INFORMATION_SCHEMA.INDEXES":
orderBy = " ORDER BY INDEX_NAME, ORDINAL_POSITION"
}
return fmt.Sprintf("SELECT * FROM %s WHERE %s%s", table, strings.Join(conditions, " AND "), orderBy)
}
func parseIRISTableRef(defaultSchema, raw string) (irisTableRef, error) {
text := strings.TrimSpace(raw)
if text == "" {
return irisTableRef{}, fmt.Errorf("表名不能为空")
}
if schemaPart, tablePart, ok := splitIRISTablePath(text); ok {
schema := cleanIRISIdentifier(schemaPart)
table := cleanIRISIdentifier(tablePart)
if table == "" {
return irisTableRef{}, fmt.Errorf("表名不能为空")
}
return irisTableRef{Schema: schema, Table: table}, nil
}
return irisTableRef{Schema: cleanIRISIdentifier(defaultSchema), Table: cleanIRISIdentifier(text)}, nil
}
func splitIRISTablePath(raw string) (schemaPart, tablePart string, ok bool) {
inQuote := false
for idx := 0; idx < len(raw); idx++ {
switch raw[idx] {
case '"':
if inQuote && idx+1 < len(raw) && raw[idx+1] == '"' {
idx++
continue
}
inQuote = !inQuote
case '.':
if !inQuote {
return raw[:idx], raw[idx+1:], true
}
}
}
return "", raw, false
}
func cleanIRISIdentifier(raw string) string {
text := strings.TrimSpace(raw)
text = strings.Trim(text, `"`)
return strings.ReplaceAll(text, `""`, `"`)
}
func irisSQLLiteral(raw string) string {
return strings.ReplaceAll(raw, "'", "''")
}
func irisQuoteIdent(name string) string {
text := cleanIRISIdentifier(name)
text = strings.ReplaceAll(text, `"`, `""`)
return `"` + text + `"`
}
func irisQuoteTable(raw string) string {
ref, err := parseIRISTableRef("", raw)
if err != nil {
return irisQuoteIdent(raw)
}
if ref.Schema != "" {
return irisQuoteIdent(ref.Schema) + "." + irisQuoteIdent(ref.Table)
}
return irisQuoteIdent(ref.Table)
}
func isIRISSystemSchema(schema string) bool {
normalized := strings.ToUpper(strings.TrimSpace(schema))
return normalized == "INFORMATION_SCHEMA" ||
strings.HasPrefix(normalized, "%") ||
strings.HasPrefix(normalized, "SYS")
}
func rowValue(row map[string]interface{}, keys ...string) (interface{}, bool) {
for _, key := range keys {
if value, ok := row[key]; ok {
return value, true
}
for existing, value := range row {
if strings.EqualFold(existing, key) {
return value, true
}
}
}
return nil, false
}
func rowValueAny(row map[string]interface{}, keys ...string) interface{} {
value, _ := rowValue(row, keys...)
return value
}
func rowString(row map[string]interface{}, keys ...string) string {
value, ok := rowValue(row, keys...)
if !ok || value == nil {
return ""
}
return fmt.Sprintf("%v", value)
}
func parseIRISInt(value interface{}) int {
switch v := value.(type) {
case int:
return v
case int32:
return int(v)
case int64:
return int(v)
case float64:
return int(v)
case string:
n, _ := strconv.Atoi(strings.TrimSpace(v))
return n
default:
n, _ := strconv.Atoi(strings.TrimSpace(fmt.Sprintf("%v", value)))
return n
}
}
func parseIRISBool(value interface{}) (bool, bool) {
switch v := value.(type) {
case bool:
return v, true
case int:
return v != 0, true
case int64:
return v != 0, true
case float64:
return v != 0, true
case string:
switch strings.ToLower(strings.TrimSpace(v)) {
case "1", "true", "t", "yes", "y":
return true, true
case "0", "false", "f", "no", "n":
return false, true
}
}
return false, false
}
func irisBoolFromRow(row map[string]interface{}, keys ...string) (bool, bool) {
value, ok := rowValue(row, keys...)
if !ok {
return false, false
}
return parseIRISBool(value)
}
func parseIRISNonUnique(row map[string]interface{}) int {
if primary, ok := irisBoolFromRow(row, "PRIMARY_KEY", "primary_key"); ok && primary {
return 0
}
if value, ok := rowValue(row, "NON_UNIQUE", "non_unique"); ok {
if enabled, ok := parseIRISBool(value); ok {
if enabled {
return 1
}
return 0
}
n := parseIRISInt(value)
if n != 0 {
return 1
}
return 0
}
if value, ok := rowValue(row, "IS_UNIQUE", "is_unique", "UNIQUE", "unique"); ok {
if unique, ok := parseIRISBool(value); ok && unique {
return 0
}
}
if unique, ok := irisBoolFromRow(row, "UNIQUE_COLUMN", "unique_column"); ok && unique {
return 0
}
return 1
}
func normalizeIRISIndexType(raw string) string {
text := strings.ToUpper(strings.TrimSpace(raw))
if text == "" {
return "BTREE"
}
return text
}
func normalizeIRISNullable(raw string) string {
switch strings.ToUpper(strings.TrimSpace(raw)) {
case "NO", "N", "FALSE", "0":
return "NO"
default:
return "YES"
}
}
func buildIRISColumnType(row map[string]interface{}) string {
dataType := strings.TrimSpace(rowString(row, "DATA_TYPE", "data_type", "TYPE_NAME", "type_name"))
if dataType == "" {
dataType = "VARCHAR"
}
upper := strings.ToUpper(dataType)
charLength := parseIRISInt(rowValueAny(row, "CHARACTER_MAXIMUM_LENGTH", "character_maximum_length", "CHARACTER_MAX_LENGTH", "character_max_length"))
precision := parseIRISInt(rowValueAny(row, "NUMERIC_PRECISION", "numeric_precision"))
scale := parseIRISInt(rowValueAny(row, "NUMERIC_SCALE", "numeric_scale"))
if charLength > 0 && (strings.Contains(upper, "CHAR") || strings.Contains(upper, "VARCHAR")) && !strings.Contains(dataType, "(") {
return fmt.Sprintf("%s(%d)", dataType, charLength)
}
if precision > 0 && (strings.Contains(upper, "NUMERIC") || strings.Contains(upper, "DECIMAL") || strings.Contains(upper, "NUMBER")) && !strings.Contains(dataType, "(") {
if scale > 0 {
return fmt.Sprintf("%s(%d,%d)", dataType, precision, scale)
}
return fmt.Sprintf("%s(%d)", dataType, precision)
}
return dataType
}
func rowOrdinal(rows []map[string]interface{}, columnName string) int {
for idx, row := range rows {
if strings.EqualFold(rowString(row, "COLUMN_NAME", "column_name"), columnName) {
ordinal := parseIRISInt(rowValueAny(row, "ORDINAL_POSITION", "ordinal_position"))
if ordinal > 0 {
return ordinal
}
return idx + 1
}
}
return len(rows) + 1
}
func irisColumnKeyMap(indexes []connection.IndexDefinition) map[string]string {
result := map[string]string{}
for _, idx := range indexes {
column := strings.TrimSpace(idx.ColumnName)
if column == "" {
continue
}
if isIRISPrimaryIndex(idx) {
result[column] = "PRI"
continue
}
if idx.NonUnique == 0 && result[column] == "" {
result[column] = "UNI"
}
}
return result
}
func isIRISPrimaryIndexName(name string) bool {
normalized := strings.ToUpper(strings.TrimSpace(name))
return normalized == "PRIMARY" || normalized == "PRIMARYKEY" || normalized == "IDKEY"
}
func isIRISPrimaryIndex(idx connection.IndexDefinition) bool {
return isIRISPrimaryIndexName(idx.Name) || strings.EqualFold(strings.TrimSpace(idx.IndexType), "PRIMARY")
}
func buildIRISCreateTableDDL(ref irisTableRef, columns []connection.ColumnDefinition, indexes []connection.IndexDefinition) string {
qualified := irisQuoteIdent(ref.Table)
if strings.TrimSpace(ref.Schema) != "" {
qualified = irisQuoteIdent(ref.Schema) + "." + qualified
}
lines := make([]string, 0, len(columns)+1)
primaryColumns := irisPrimaryColumns(indexes)
if len(primaryColumns) == 0 {
primaryColumns = irisPrimaryColumnsFromColumns(columns)
}
for _, col := range columns {
line := fmt.Sprintf(" %s %s", irisQuoteIdent(col.Name), strings.TrimSpace(col.Type))
if col.Default != nil && strings.TrimSpace(*col.Default) != "" {
line += " DEFAULT " + strings.TrimSpace(*col.Default)
}
if strings.EqualFold(strings.TrimSpace(col.Nullable), "NO") {
line += " NOT NULL"
}
lines = append(lines, line)
}
if len(primaryColumns) > 0 {
lines = append(lines, fmt.Sprintf(" PRIMARY KEY (%s)", irisQuoteIdentList(primaryColumns)))
}
var b strings.Builder
b.WriteString(fmt.Sprintf("CREATE TABLE %s (\n%s\n);", qualified, strings.Join(lines, ",\n")))
for _, stmt := range buildIRISCreateIndexStatements(ref, indexes) {
b.WriteString("\n\n")
b.WriteString(stmt)
}
return b.String()
}
func irisPrimaryColumns(indexes []connection.IndexDefinition) []string {
for _, group := range groupIRISIndexes(indexes) {
if group.Primary {
return group.Columns
}
}
return nil
}
func irisPrimaryColumnsFromColumns(columns []connection.ColumnDefinition) []string {
primaryColumns := make([]string, 0)
for _, column := range columns {
if strings.EqualFold(strings.TrimSpace(column.Key), "PRI") && strings.TrimSpace(column.Name) != "" {
primaryColumns = append(primaryColumns, column.Name)
}
}
return primaryColumns
}
type irisIndexGroup struct {
Name string
Columns []string
NonUnique int
IndexType string
Primary bool
}
func groupIRISIndexes(indexes []connection.IndexDefinition) []irisIndexGroup {
groupsByName := map[string]*irisIndexGroup{}
order := make([]string, 0)
for _, idx := range indexes {
name := strings.TrimSpace(idx.Name)
column := strings.TrimSpace(idx.ColumnName)
if name == "" || column == "" {
continue
}
group, ok := groupsByName[name]
if !ok {
group = &irisIndexGroup{Name: name, NonUnique: idx.NonUnique, IndexType: idx.IndexType}
groupsByName[name] = group
order = append(order, name)
}
group.Columns = append(group.Columns, column)
if idx.NonUnique == 0 {
group.NonUnique = 0
}
if isIRISPrimaryIndex(idx) {
group.Primary = true
}
}
sort.Strings(order)
groups := make([]irisIndexGroup, 0, len(order))
for _, name := range order {
group := groupsByName[name]
groups = append(groups, *group)
}
return groups
}
func buildIRISCreateIndexStatements(ref irisTableRef, indexes []connection.IndexDefinition) []string {
qualified := irisQuoteIdent(ref.Table)
if strings.TrimSpace(ref.Schema) != "" {
qualified = irisQuoteIdent(ref.Schema) + "." + qualified
}
var statements []string
for _, group := range groupIRISIndexes(indexes) {
if len(group.Columns) == 0 || group.Primary {
continue
}
unique := ""
if group.NonUnique == 0 {
unique = "UNIQUE "
}
statements = append(statements, fmt.Sprintf("CREATE %sINDEX %s ON %s (%s);", unique, irisQuoteIdent(group.Name), qualified, irisQuoteIdentList(group.Columns)))
}
return statements
}
func irisQuoteIdentList(columns []string) string {
quoted := make([]string, 0, len(columns))
for _, column := range columns {
quoted = append(quoted, irisQuoteIdent(column))
}
return strings.Join(quoted, ", ")
}
func buildIRISDeleteSQL(tableName string, keys map[string]interface{}) (string, []interface{}, bool) {
wheres, args := irisAssignments(keys, " = ?")
if len(wheres) == 0 {
return "", nil, false
}
return fmt.Sprintf("DELETE FROM %s WHERE %s", irisQuoteTable(tableName), strings.Join(wheres, " AND ")), args, true
}
func buildIRISUpdateSQL(tableName string, update connection.UpdateRow) (string, []interface{}, bool, error) {
sets, args := irisAssignments(update.Values, " = ?")
if len(sets) == 0 {
return "", nil, false, nil
}
wheres, whereArgs := irisAssignments(update.Keys, " = ?")
if len(wheres) == 0 {
return "", nil, false, fmt.Errorf("更新操作需要主键条件")
}
args = append(args, whereArgs...)
return fmt.Sprintf("UPDATE %s SET %s WHERE %s", irisQuoteTable(tableName), strings.Join(sets, ", "), strings.Join(wheres, " AND ")), args, true, nil
}
func buildIRISInsertSQL(tableName string, row map[string]interface{}) (string, []interface{}, bool) {
if len(row) == 0 {
return "", nil, false
}
keys := sortedMapKeys(row)
cols := make([]string, 0, len(keys))
placeholders := make([]string, 0, len(keys))
args := make([]interface{}, 0, len(keys))
for _, key := range keys {
cols = append(cols, irisQuoteIdent(key))
placeholders = append(placeholders, "?")
args = append(args, row[key])
}
return fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", irisQuoteTable(tableName), strings.Join(cols, ", "), strings.Join(placeholders, ", ")), args, true
}
func irisAssignments(values map[string]interface{}, suffix string) ([]string, []interface{}) {
keys := sortedMapKeys(values)
parts := make([]string, 0, len(keys))
args := make([]interface{}, 0, len(keys))
for _, key := range keys {
parts = append(parts, irisQuoteIdent(key)+suffix)
args = append(args, values[key])
}
return parts, args
}
func sortedMapKeys(values map[string]interface{}) []string {
keys := make([]string, 0, len(values))
for key := range values {
if strings.TrimSpace(key) != "" {
keys = append(keys, key)
}
}
sort.Strings(keys)
return keys
}

View File

@@ -0,0 +1,275 @@
//go:build gonavi_full_drivers || gonavi_iris_driver
package db
import (
"database/sql/driver"
"net/url"
"reflect"
"strings"
"testing"
"GoNavi-Wails/internal/connection"
)
func TestIrisDSNUsesNamespaceDefaultPortAndConnectionParams(t *testing.T) {
iris := &IrisDB{}
dsn := iris.getDSN(connection.ConnectionConfig{
Host: "db.example.com",
User: "_SYSTEM",
Password: "p@ss",
ConnectionParams: "timeout=30&ssl=1",
})
parsed, err := url.Parse(dsn)
if err != nil {
t.Fatalf("parse dsn: %v", err)
}
if parsed.Scheme != "iris" {
t.Fatalf("scheme = %q", parsed.Scheme)
}
if parsed.Host != "db.example.com:1972" {
t.Fatalf("host = %q", parsed.Host)
}
if parsed.Path != "/USER" {
t.Fatalf("namespace path = %q", parsed.Path)
}
if parsed.User.Username() != "_SYSTEM" {
t.Fatalf("user = %q", parsed.User.Username())
}
password, _ := parsed.User.Password()
if password != "p@ss" {
t.Fatalf("password = %q", password)
}
if got := parsed.Query().Get("timeout"); got != "30" {
t.Fatalf("timeout param = %q", got)
}
if got := parsed.Query().Get("ssl"); got != "1" {
t.Fatalf("ssl param = %q", got)
}
}
func TestApplyIRISURIExtractsConnectionFields(t *testing.T) {
config := applyIRISURI(connection.ConnectionConfig{
URI: "iris://user:secret@iris.local:1973/APP?timeout=30",
Database: "SHOULD_BE_REPLACED",
})
if config.Host != "iris.local" || config.Port != 1973 || config.User != "user" || config.Password != "secret" {
t.Fatalf("unexpected parsed config: %#v", config)
}
if config.Database != "APP" {
t.Fatalf("database namespace = %q", config.Database)
}
}
func TestIRISTableRefAndIdentifierQuoting(t *testing.T) {
ref, err := parseIRISTableRef("Sample", `"Person.Table"`)
if err != nil {
t.Fatalf("parse table ref: %v", err)
}
if ref.Schema != "Sample" || ref.Table != "Person.Table" {
t.Fatalf("unexpected ref: %#v", ref)
}
ref, err = parseIRISTableRef("", `"Sample"."Person""Archive"`)
if err != nil {
t.Fatalf("parse qualified table ref: %v", err)
}
if ref.Schema != "Sample" || ref.Table != `Person"Archive` {
t.Fatalf("unexpected qualified ref: %#v", ref)
}
if got := irisQuoteTable(`"Sample"."Person""Archive"`); got != `"Sample"."Person""Archive"` {
t.Fatalf("quoted table = %s", got)
}
}
func TestIRISColumnKeyMapPrefersPrimaryThenUnique(t *testing.T) {
keys := irisColumnKeyMap([]connection.IndexDefinition{
{Name: "idx_id", ColumnName: "id", NonUnique: 0},
{Name: "IDKEY", ColumnName: "id", NonUnique: 0},
{Name: "idx_email", ColumnName: "email", NonUnique: 0},
{Name: "idx_name", ColumnName: "name", NonUnique: 1},
})
if keys["id"] != "PRI" {
t.Fatalf("id key = %q", keys["id"])
}
if keys["email"] != "UNI" {
t.Fatalf("email key = %q", keys["email"])
}
if keys["name"] != "" {
t.Fatalf("name key = %q", keys["name"])
}
}
func TestBuildIRISCreateTableDDLIncludesPrimaryAndIndexes(t *testing.T) {
defaultValue := "CURRENT_TIMESTAMP"
ddl := buildIRISCreateTableDDL(
irisTableRef{Schema: "Sample", Table: "Person"},
[]connection.ColumnDefinition{
{Name: "id", Type: "INTEGER", Nullable: "NO"},
{Name: "name", Type: "VARCHAR(80)", Nullable: "NO"},
{Name: "created_at", Type: "TIMESTAMP", Nullable: "YES", Default: &defaultValue},
},
[]connection.IndexDefinition{
{Name: "app_person_pk", ColumnName: "id", NonUnique: 0, SeqInIndex: 1, IndexType: "PRIMARY"},
{Name: "idx_person_name", ColumnName: "name", NonUnique: 0, SeqInIndex: 1},
{Name: "idx_person_created_at", ColumnName: "created_at", NonUnique: 1, SeqInIndex: 1},
},
)
for _, want := range []string{
`CREATE TABLE "Sample"."Person"`,
`"id" INTEGER NOT NULL`,
`"name" VARCHAR(80) NOT NULL`,
`"created_at" TIMESTAMP DEFAULT CURRENT_TIMESTAMP`,
`PRIMARY KEY ("id")`,
`CREATE UNIQUE INDEX "idx_person_name" ON "Sample"."Person" ("name");`,
`CREATE INDEX "idx_person_created_at" ON "Sample"."Person" ("created_at");`,
} {
if !strings.Contains(ddl, want) {
t.Fatalf("ddl missing %q:\n%s", want, ddl)
}
}
if strings.Contains(ddl, `CREATE UNIQUE INDEX "app_person_pk"`) {
t.Fatalf("primary key index should not be emitted as a standalone index:\n%s", ddl)
}
}
func TestBuildIRISCreateTableDDLFallsBackToColumnPrimaryKey(t *testing.T) {
ddl := buildIRISCreateTableDDL(
irisTableRef{Schema: "Sample", Table: "Person"},
[]connection.ColumnDefinition{
{Name: "id", Type: "INTEGER", Nullable: "NO", Key: "PRI"},
{Name: "name", Type: "VARCHAR(80)", Nullable: "YES"},
},
nil,
)
if !strings.Contains(ddl, `PRIMARY KEY ("id")`) {
t.Fatalf("ddl missing primary key from column metadata:\n%s", ddl)
}
}
func TestIrisMetadataMapsColumnsAndIndexes(t *testing.T) {
dbConn, state := openOracleRecordingDB(t)
iris := &IrisDB{conn: dbConn}
columnsQuery := buildIRISInfoSchemaWhereQuery("INFORMATION_SCHEMA.COLUMNS", irisTableRef{Schema: "Sample", Table: "Person"})
indexesQuery := buildIRISInfoSchemaWhereQuery("INFORMATION_SCHEMA.INDEXES", irisTableRef{Schema: "Sample", Table: "Person"})
state.mu.Lock()
state.queryResults[columnsQuery] = oracleRecordingQueryResult{
columns: []string{"TABLE_SCHEMA", "TABLE_NAME", "COLUMN_NAME", "DATA_TYPE", "CHARACTER_MAXIMUM_LENGTH", "IS_NULLABLE", "COLUMN_DEFAULT", "ORDINAL_POSITION", "DESCRIPTION", "PRIMARY_KEY", "UNIQUE_COLUMN"},
rows: [][]driver.Value{
{"Sample", "Person", "id", "INTEGER", nil, "NO", nil, int64(1), "identifier", true, false},
{"Sample", "Person", "name", "VARCHAR", int64(80), "YES", "'anonymous'", int64(2), "display name", false, true},
},
}
state.queryResults[indexesQuery] = oracleRecordingQueryResult{
columns: []string{"INDEX_NAME", "COLUMN_NAME", "NON_UNIQUE", "ORDINAL_POSITION", "INDEX_TYPE", "PRIMARY_KEY"},
rows: [][]driver.Value{
{"app_person_pk", "id", int64(1), int64(1), "bitmap", true},
{"idx_person_name", "name", int64(0), int64(1), "", false},
},
}
state.mu.Unlock()
columns, err := iris.GetColumns("Sample", "Person")
if err != nil {
t.Fatalf("GetColumns returned error: %v", err)
}
if len(columns) != 2 {
t.Fatalf("columns len = %d", len(columns))
}
if columns[0].Name != "id" || columns[0].Key != "PRI" || columns[0].Nullable != "NO" {
t.Fatalf("unexpected id column: %#v", columns[0])
}
if columns[1].Type != "VARCHAR(80)" || columns[1].Key != "UNI" {
t.Fatalf("unexpected name column: %#v", columns[1])
}
indexes, err := iris.GetIndexes("Sample", "Person")
if err != nil {
t.Fatalf("GetIndexes returned error: %v", err)
}
if len(indexes) != 2 || indexes[0].Name != "app_person_pk" || indexes[0].IndexType != "PRIMARY" || indexes[0].NonUnique != 0 {
t.Fatalf("unexpected indexes: %#v", indexes)
}
}
func TestBuildIRISApplyChangesSQL(t *testing.T) {
deleteSQL, deleteArgs, ok := buildIRISDeleteSQL("Sample.Person", map[string]interface{}{"id": 1})
if !ok {
t.Fatal("expected delete SQL")
}
if deleteSQL != `DELETE FROM "Sample"."Person" WHERE "id" = ?` || !reflect.DeepEqual(deleteArgs, []interface{}{1}) {
t.Fatalf("unexpected delete SQL/args: %s %#v", deleteSQL, deleteArgs)
}
updateSQL, updateArgs, ok, err := buildIRISUpdateSQL("Sample.Person", connection.UpdateRow{
Keys: map[string]interface{}{"id": 1},
Values: map[string]interface{}{"name": "Alice", "updated_at": "2026-05-16"},
})
if err != nil || !ok {
t.Fatalf("expected update SQL, ok=%v err=%v", ok, err)
}
if updateSQL != `UPDATE "Sample"."Person" SET "name" = ?, "updated_at" = ? WHERE "id" = ?` {
t.Fatalf("unexpected update SQL: %s", updateSQL)
}
if !reflect.DeepEqual(updateArgs, []interface{}{"Alice", "2026-05-16", 1}) {
t.Fatalf("unexpected update args: %#v", updateArgs)
}
insertSQL, insertArgs, ok := buildIRISInsertSQL("Sample.Person", map[string]interface{}{"name": "Alice", "id": 1})
if !ok {
t.Fatal("expected insert SQL")
}
if insertSQL != `INSERT INTO "Sample"."Person" ("id", "name") VALUES (?, ?)` {
t.Fatalf("unexpected insert SQL: %s", insertSQL)
}
if !reflect.DeepEqual(insertArgs, []interface{}{1, "Alice"}) {
t.Fatalf("unexpected insert args: %#v", insertArgs)
}
}
func TestIrisApplyChangesExecutesInDeleteUpdateInsertOrder(t *testing.T) {
dbConn, state := openOracleRecordingDB(t)
iris := &IrisDB{conn: dbConn}
err := iris.ApplyChanges("Sample.Person", connection.ChangeSet{
Deletes: []map[string]interface{}{
{"id": 3},
},
Updates: []connection.UpdateRow{
{Keys: map[string]interface{}{"id": 2}, Values: map[string]interface{}{"name": "Bob"}},
},
Inserts: []map[string]interface{}{
{"id": 1, "name": "Alice"},
},
})
if err != nil {
t.Fatalf("ApplyChanges returned error: %v", err)
}
got := state.snapshotExecQueries()
want := []string{
`DELETE FROM "Sample"."Person" WHERE "id" = ?`,
`UPDATE "Sample"."Person" SET "name" = ? WHERE "id" = ?`,
`INSERT INTO "Sample"."Person" ("id", "name") VALUES (?, ?)`,
}
if !reflect.DeepEqual(got, want) {
t.Fatalf("unexpected exec queries:\nwant=%#v\ngot=%#v", want, got)
}
}
func TestBuildIRISUpdateSQLRequiresLocatorKeys(t *testing.T) {
_, _, ok, err := buildIRISUpdateSQL("Person", connection.UpdateRow{
Values: map[string]interface{}{"name": "Alice"},
})
if err == nil || ok {
t.Fatalf("expected missing keys to be rejected, ok=%v err=%v", ok, err)
}
}

21
third_party/go-irisnative/LICENSE vendored Normal file
View File

@@ -0,0 +1,21 @@
MIT License
Copyright (c) 2025 Dmitry Maslennikov
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

11
third_party/go-irisnative/PATCHES.md vendored Normal file
View File

@@ -0,0 +1,11 @@
Local patch against github.com/caretdev/go-irisnative v0.2.1:
- Added `//go:build !windows` to `src/connection/user_posix.go`.
Upstream ships `user_windows.go` with a Windows filename suffix, but
`user_posix.go` has no build constraint, so Windows builds compile both
files and fail with `userCurrent redeclared`.
- Made `Connection.Disconnect` close the underlying TCP connection after
sending the protocol disconnect message, so `database/sql` closes do not
leak sockets.
- Made `Connection.BeginTx` return the `START TRANSACTION` error instead of
marking the connection as in-transaction when the server rejected the begin.

285
third_party/go-irisnative/README.md vendored Normal file
View File

@@ -0,0 +1,285 @@
# go-irisnative
A Golang driver for InterSystems IRIS that implements `database/sql`.
> Project status: **alpha**. API may change. Feedback and PRs welcome.
---
## Installation
```bash
# replace the module path with the final repo path when published
go get github.com/caretdev/go-irisnative
```
Register the driver by importing it for sideeffects:
```go
import (
"database/sql"
_ "github.com/caretdev/go-irisnative" // registers driver as "iris"
)
```
## DSN formats
The driver accepts a URL-style DSN (recommended) or key=value pairs.
**URL style**
```
iris://user:password@host:1972/NAMESPACE?
```
* `host` — IRIS hostname or IP
* `1972` — superserver port (default)
* `Namespace` — IRIS namespace (e.g., `USER`)
---
## Quick start (database/sql)
```go
package main
import (
"context"
"database/sql"
"fmt"
"log"
"time"
_ "github.com/caretdev/go-irisnative"
)
func main() {
dsn := "iris://_SYSTEM:SYS@localhost:1972/USER"
db, err := sql.Open("iris", dsn)
if err != nil { log.Fatal(err) }
defer db.Close()
// Connection pool tuning
db.SetMaxOpenConns(10)
db.SetMaxIdleConns(5)
db.SetConnMaxLifetime(30 * time.Minute)
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
_, err = db.ExecContext(ctx, `DROP TABLE IF EXISTS demo_person`)
if err != nil { log.Fatal("drop table:", err) }
// 1) Create a table (id INT PRIMARY KEY, name VARCHAR(80))
_, err = db.ExecContext(ctx, `CREATE TABLE IF NOT EXISTS demo_person (
id INT PRIMARY KEY,
name VARCHAR(80) NOT NULL,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
)`)
if err != nil { log.Fatal("create table:", err) }
// 2) Insert with placeholders
res, err := db.ExecContext(ctx, `INSERT INTO demo_person(id, name) VALUES(?, ?)`, 1, "Alice")
if err != nil { log.Fatal("insert:", err) }
if n, _ := res.RowsAffected(); n > 0 { fmt.Println("inserted:", n) }
// 3) Query rows
rows, err := db.QueryContext(ctx, `SELECT id, name, created_at FROM demo_person ORDER BY id`)
if err != nil { log.Fatal("query:", err) }
defer rows.Close()
for rows.Next() {
var (
id int
name string
createdAt time.Time
)
if err := rows.Scan(&id, &name, &createdAt); err != nil { log.Fatal(err) }
fmt.Printf("row: id=%d name=%s created_at=%s\n", id, name, createdAt.Format(time.RFC3339))
}
if err := rows.Err(); err != nil { log.Fatal(err) }
// 4) Prepared statement
stmt, err := db.PrepareContext(ctx, `UPDATE demo_person SET name=? WHERE id=?`)
if err != nil { log.Fatal("prepare:", err) }
defer stmt.Close()
if _, err := stmt.ExecContext(ctx, "Alice Updated", 1); err != nil { log.Fatal("update:", err) }
// 5) Transaction example
tx, err := db.BeginTx(ctx, &sql.TxOptions{Isolation: sql.LevelReadCommitted})
if err != nil { log.Fatal("begin tx:", err) }
if _, err := tx.ExecContext(ctx, `INSERT INTO demo_person(id, name) VALUES(?, ?)`, 2, "Bob"); err != nil {
tx.Rollback()
log.Fatal("tx insert:", err)
}
if err := tx.Commit(); err != nil { log.Fatal("commit:", err) }
}
```
### Query single value helper
```go
var count int
if err := db.QueryRowContext(ctx, `SELECT COUNT(*) FROM demo_person`).Scan(&count); err != nil {
log.Fatal(err)
}
fmt.Println("count=", count)
```
---
## Using with `sqlx`
`sqlx` adds nice helpers over `database/sql` like struct scanning and named queries.
```bash
go get github.com/jmoiron/sqlx
```
```go
package main
import (
"context"
"fmt"
"log"
"time"
_ "github.com/caretdev/go-irisnative" // driver
"github.com/jmoiron/sqlx"
)
type Person struct {
ID int `db:"id"`
Name string `db:"name"`
CreatedAt time.Time `db:"created_at"`
}
func create(ctx context.Context, db *sqlx.DB) {
drop(ctx, db)
_, err := db.ExecContext(ctx, `CREATE TABLE IF NOT EXISTS demo_person (
id INT PRIMARY KEY,
name VARCHAR(80) NOT NULL,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
)`)
if err != nil {
panic(err)
}
}
func drop(ctx context.Context, db *sqlx.DB) {
_, err := db.ExecContext(ctx, `DROP TABLE IF EXISTS demo_person`)
if err != nil {
panic(err)
}
}
func main() {
ctx := context.Background()
dsn := "iris://_SYSTEM:SYS@localhost:1972/USER"
db := sqlx.MustConnect("iris", dsn)
defer db.Close()
create(ctx, db)
defer drop(ctx, db)
// Struct-based insert with NamedExec
p := Person{ID: 3, Name: "Carol"}
_, err := db.NamedExecContext(ctx,
`INSERT INTO demo_person(id, name) VALUES(:id, :name)`, p,
)
if err != nil {
log.Fatal("named insert:", err)
}
// Select into slice of structs
var people []Person
if err := db.SelectContext(ctx, &people, `SELECT id, name, created_at FROM demo_person ORDER BY id`); err != nil {
log.Fatal(err)
}
fmt.Printf("people: %#v\n", people)
// Get a single struct
var one Person
if err := db.GetContext(ctx, &one, `SELECT id, name, created_at FROM demo_person WHERE id=?`, people[0].ID); err != nil {
log.Fatal(err)
}
fmt.Printf("one: %+v\n", one)
// Named query with IN (sqlx.In)
ids := []int{1, 2, 3}
q, args, err := sqlx.In(`SELECT id, name FROM demo_person WHERE id IN (?)`, ids)
if err != nil {
log.Fatal(err)
}
q = db.Rebind(q) // ensure driver-specific bindvars
rows, err := db.QueryxContext(ctx, q, args...)
if err != nil {
log.Fatal(err)
}
defer rows.Close()
for rows.Next() {
var id int
var name string
if err := rows.Scan(&id, &name); err != nil {
log.Fatal(err)
}
fmt.Println(id, name)
}
}
```
---
## Placeholders & rebind
* The driver uses `?` positional placeholders.
* With `sqlx`, **always** call `db.Rebind(q)` after `sqlx.In(...)` to adapt placeholders.
---
## Context, timeouts & cancellations
All examples use `Context`. Set sensible timeouts to avoid runaway queries:
```go
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
```
---
## Error handling tips
* Check `rows.Err()` after iteration.
* Prefer `ExecContext`/`QueryContext` to ensure timeouts are respected.
* Wrap errors with operation context (e.g., `fmt.Errorf("create table: %w", err)`).
---
## Testing locally
1. Start IRIS and ensure SQL is enabled for your namespace (e.g., `USER`).
2. Create a SQL user with privileges to connect and create tables.
3. Verify connectivity using the DSN shown above.
---
## Compatibility
* Go: 1.21+
* InterSystems IRIS: 2025.1+
---
## License
MIT
---
## Contributing
* Run `go vet` and `go test ./...` before submitting PRs.
* Add tests for new behaviors.
* Document any DSN parameters you introduce.

85
third_party/go-irisnative/connector.go vendored Normal file
View File

@@ -0,0 +1,85 @@
package intersystems
import (
"context"
"database/sql/driver"
"errors"
"strings"
)
// Connector represents a fixed configuration for the pq driver with a given
// name. Connector satisfies the database/sql/driver Connector interface and
// can be used to create any number of DB Conn's via the database/sql OpenDB
// function.
//
// See https://golang.org/pkg/database/sql/driver/#Connector.
// See https://golang.org/pkg/database/sql/#OpenDB.
type Connector struct {
opts values
// dialer Dialer
}
// Connect returns a connection to the database using the fixed configuration
// of this Connector. Context is not used.
func (c *Connector) Connect(ctx context.Context) (driver.Conn, error) {
return c.open(ctx)
}
// Driver returns the underlying driver of this Connector.
func (c *Connector) Driver() driver.Driver {
return &Driver{}
}
// NewConnector returns a connector for the pq driver in a fixed configuration
// with the given dsn. The returned connector can be used to create any number
// of equivalent Conn's. The returned connector is intended to be used with
// database/sql.OpenDB.
//
// See https://golang.org/pkg/database/sql/driver/#Connector.
// See https://golang.org/pkg/database/sql/#OpenDB.
func NewConnector(dsn string) (*Connector, error) {
var err error
o := make(values)
// A number of defaults are applied here, in this order:
//
// * Very low precedence defaults applied in every situation
// * Environment variables
// * Explicitly passed connection information
o["host"] = "localhost"
o["port"] = "1972"
if strings.HasPrefix(dsn, "iris://") || strings.HasPrefix(dsn, "IRIS://") {
dsn, err = ParseURL(dsn)
if err != nil {
return nil, err
}
}
if err := parseOpts(dsn, o); err != nil {
return nil, err
}
if enc, ok := o["client_encoding"]; ok && !isUTF8(enc) {
return nil, errors.New("client_encoding must be absent or 'UTF8'")
}
o["client_encoding"] = "UTF8"
return &Connector{opts: o, /*dialer: defaultDialer{}*/}, nil
}
// isUTF8 returns whether name is a fuzzy variation of the string "UTF-8".
func isUTF8(name string) bool {
s := strings.Map(alnumLowerASCII, name)
return s == "utf8" || s == "unicode"
}
func alnumLowerASCII(ch rune) rune {
if 'A' <= ch && ch <= 'Z' {
return ch + ('a' - 'A')
}
if 'a' <= ch && ch <= 'z' || '0' <= ch && ch <= '9' {
return ch
}
return -1 // discard
}

216
third_party/go-irisnative/driver.go vendored Normal file
View File

@@ -0,0 +1,216 @@
package intersystems
import (
"context"
"database/sql"
"database/sql/driver"
"errors"
"fmt"
"net"
"unicode"
_ "io"
_ "math"
_ "reflect"
_ "strconv"
_ "strings"
_ "time"
_ "unsafe"
"github.com/caretdev/go-irisnative/src/connection"
)
var (
ErrCouldNotDetectUsername = errors.New("intersystems: Could not detect default username. Please provide one explicitly")
)
var (
_ driver.Driver = Driver{}
)
type values map[string]string
// Driver implements database/sql/driver.Driver.
type Driver struct{}
func (d Driver) Open(name string) (driver.Conn, error) {
return Open(name)
}
func init() {
sql.Register("intersystems", &Driver{})
sql.Register("iris", &Driver{})
}
func Open(dsn string) (_ driver.Conn, err error) {
c, err := NewConnector(dsn)
if err != nil {
return nil, err
}
return c.open(context.Background())
}
type conn struct {
c connection.Connection
tx bool
}
func (c *Connector) open(ctx context.Context) (cn *conn, err error) {
o := make(values)
for k, v := range c.opts {
o[k] = v
}
host := o["host"]
addr := net.JoinHostPort(host, o["port"])
namespace := o["namespace"]
login := o["user"]
password := o["password"]
cn = &conn{}
cn.c, err = connection.Connect(addr, namespace, login, password)
if err != nil {
return nil, err
}
return cn, nil
}
func (cn *conn) Begin() (driver.Tx, error) {
return cn.c.BeginTx(driver.TxOptions{})
}
func (cn *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
return cn.c.BeginTx(opts)
}
func (cn *conn) Close() (err error) {
cn.c.Disconnect()
return nil
}
func (cn *conn) Prepare(q string) (st driver.Stmt, err error) {
return cn.c.Prepare(q)
}
func (cn *conn) Commit() error {
if !cn.tx {
panic("transaction already closed")
}
cn.tx = false
cn.c.Commit()
return nil
}
func (cn *conn) Rollback() error {
if !cn.tx {
panic("transaction already closed")
}
cn.tx = false
cn.c.Rollback()
return nil
}
func (cn *conn) Exec(query string, args []driver.NamedValue) (res driver.Result, err error) {
parameters := make([]interface{}, len(args))
for i, a := range args {
parameters[i] = a
}
_, err = cn.c.DirectUpdate(query, parameters...)
if err != nil {
return nil, err
}
return res, nil
}
func (cn *conn) Query(query string, args []driver.NamedValue) (rows driver.Rows, err error) {
parameters := make([]interface{}, len(args))
for i, a := range args {
parameters[i] = a
}
// var rs *connection.ResultSet
_, err = cn.c.Query(query, parameters...)
if err != nil {
return nil, err
}
// rows = &connection.Rows{
// cn: cn.c,
// rs: rs,
// }
return
}
func parseOpts(name string, o values) error {
s := newScanner(name)
for {
var (
keyRunes, valRunes []rune
r rune
ok bool
)
if r, ok = s.SkipSpaces(); !ok {
break
}
// Scan the key
for !unicode.IsSpace(r) && r != '=' {
keyRunes = append(keyRunes, r)
if r, ok = s.Next(); !ok {
break
}
}
// Skip any whitespace if we're not at the = yet
if r != '=' {
r, ok = s.SkipSpaces()
}
// The current character should be =
if r != '=' || !ok {
return fmt.Errorf(`missing "=" after %q in connection info string"`, string(keyRunes))
}
// Skip any whitespace after the =
if r, ok = s.SkipSpaces(); !ok {
// If we reach the end here, the last value is just an empty string as per libpq.
o[string(keyRunes)] = ""
break
}
if r != '\'' {
for !unicode.IsSpace(r) {
if r == '\\' {
if r, ok = s.Next(); !ok {
return fmt.Errorf(`missing character after backslash`)
}
}
valRunes = append(valRunes, r)
if r, ok = s.Next(); !ok {
break
}
}
} else {
quote:
for {
if r, ok = s.Next(); !ok {
return fmt.Errorf(`unterminated quoted string literal in connection string`)
}
switch r {
case '\'':
break quote
case '\\':
r, _ = s.Next()
fallthrough
default:
valRunes = append(valRunes, r)
}
}
}
o[string(keyRunes)] = string(valRunes)
}
return nil
}

5
third_party/go-irisnative/go.mod vendored Normal file
View File

@@ -0,0 +1,5 @@
module github.com/caretdev/go-irisnative
go 1.24.3
require github.com/shopspring/decimal v1.4.0

1
third_party/go-irisnative/go.sum vendored Normal file
View File

@@ -0,0 +1 @@
github.com/shopspring/decimal v1.4.0/go.mod h1:gawqmDU56v4yIKSwfBSFip1HdCCXN8/+DMd9qYNcwME=

34
third_party/go-irisnative/scanner.go vendored Normal file
View File

@@ -0,0 +1,34 @@
package intersystems
import "unicode"
type scanner struct {
s []rune
i int
}
// newScanner returns a new scanner initialized with the option string s.
func newScanner(s string) *scanner {
return &scanner{[]rune(s), 0}
}
// Next returns the next rune.
// It returns 0, false if the end of the text has been reached.
func (s *scanner) Next() (rune, bool) {
if s.i >= len(s.s) {
return 0, false
}
r := s.s[s.i]
s.i++
return r, true
}
// SkipSpaces returns the next non-whitespace rune.
// It returns 0, false if the end of the text has been reached.
func (s *scanner) SkipSpaces() (rune, bool) {
r, ok := s.Next()
for unicode.IsSpace(r) && ok {
r, ok = s.Next()
}
return r, ok
}

View File

@@ -0,0 +1,113 @@
package connection
import "github.com/caretdev/go-irisnative/src/iris"
func (c *Connection) ServerVersion() (result string, err error) {
err = c.ClassMethod("%SYSTEM.Version", "GetVersion", &result)
return
}
func (c *Connection) ClassMethod(class, method string, result interface{}, args ...interface{}) (err error) {
msg := NewMessage(CLASSMETHOD_VALUE)
msg.Set(class)
msg.Set(method)
msg.Set(len(args))
for _, arg := range args {
msg.Set(arg)
}
_, err = c.conn.Write(msg.Dump(c.count()))
if err != nil {
return
}
msg, err = ReadMessage(c.conn)
if err != nil {
return
}
msg.Get(result)
return
}
func (c *Connection) ClassMethodVoid(class, method string, args ...interface{}) (err error) {
msg := NewMessage(CLASSMETHOD_VOID)
msg.Set(class)
msg.Set(method)
msg.Set(len(args))
for _, arg := range args {
msg.Set(arg)
}
_, err = c.conn.Write(msg.Dump(c.count()))
if err != nil {
return
}
msg, err = ReadMessage(c.conn)
if err != nil {
return
}
return
}
func (c *Connection) Method(obj iris.Oref, method string, result interface{}, args ...interface{}) (err error) {
msg := NewMessage(METHOD_VALUE)
msg.Set(obj)
msg.Set(method)
msg.Set(len(args))
for _, arg := range args {
msg.Set(arg)
}
_, err = c.conn.Write(msg.Dump(c.count()))
if err != nil {
return
}
msg, err = ReadMessage(c.conn)
if err != nil {
return
}
msg.Get(result)
return
}
func (c *Connection) MethodVoid(obj, method string, args ...interface{}) (err error) {
msg := NewMessage(METHOD_VOID)
msg.Set(obj)
msg.Set(method)
msg.Set(len(args))
for _, arg := range args {
msg.Set(arg)
}
_, err = c.conn.Write(msg.Dump(c.count()))
if err != nil {
return
}
msg, err = ReadMessage(c.conn)
if err != nil {
return
}
return
}
func (c *Connection) PropertyGet(obj iris.Oref, property string, result interface{}) (err error) {
msg := NewMessage(PROPERTY_GET)
msg.Set(obj)
msg.Set(property)
// msg.Set(0)
_, err = c.conn.Write(msg.Dump(c.count()))
if err != nil {
return
}
msg, err = ReadMessage(c.conn)
if err != nil {
return
}
msg.Get(result)
return
}

View File

@@ -0,0 +1,141 @@
package connection
func (c *Connection) GlobalIsDefined(global string, subs ...interface{}) (bool, bool) {
msg := NewMessage(GLOBAL_DATA)
msg.Set(global)
msg.Set(len(subs))
for _, sub := range subs {
msg.Set(sub)
}
msg.Set(0)
_, err := c.conn.Write(msg.Dump(c.count()))
if err != nil {
return false, false
}
msg, err = ReadMessage(c.conn)
if err != nil {
return false, false
}
var result uint8
msg.Get(&result)
return result%10 == 1, result >= 10
}
func (c *Connection) GlobalSet(global string, value interface{}, subs ...interface{}) (err error) {
msg := NewMessage(GLOBAL_SET)
msg.Set(global)
msg.Set(len(subs))
for _, sub := range subs {
msg.Set(sub)
}
msg.Set(value)
_, err = c.conn.Write(msg.Dump(c.count()))
if err != nil {
return
}
_, err = ReadMessage(c.conn)
if err != nil {
return
}
return
}
func (c *Connection) GlobalKill(global string, subs ...interface{}) (err error) {
msg := NewMessage(GLOBAL_KILL)
msg.Set(global)
msg.Set(len(subs))
for _, sub := range subs {
msg.Set(sub)
}
_, err = c.conn.Write(msg.Dump(c.count()))
if err != nil {
return
}
_, err = ReadMessage(c.conn)
if err != nil {
return
}
return
}
func (c *Connection) GlobalGet(global string, result interface{}, subs ...interface{}) (err error) {
msg := NewMessage(GLOBAL_GET)
msg.Set(global)
msg.Set(len(subs))
for _, sub := range subs {
msg.Set(sub)
}
_, err = c.conn.Write(msg.Dump(c.count()))
if err != nil {
return
}
msg, err = ReadMessage(c.conn)
if err != nil {
return
}
msg.Get(result)
return
}
func (c *Connection) GlobalNext(global string, ind *string, subs ...interface{}) (hasNext bool, err error) {
msg := NewMessage(GLOBAL_ORDER)
msg.Set(global)
msg.Set(len(subs) + 1)
for _, sub := range subs {
msg.Set(sub)
}
msg.Set(*ind)
msg.Set(3)
if _, err = c.conn.Write(msg.Dump(c.count())); err != nil {
return
}
if msg, err = ReadMessage(c.conn); err != nil {
return
}
var result string
msg.Get(&result)
*ind = result
hasNext = result != ""
return
}
func (c *Connection) GlobalPrev(global string, ind *string, subs ...interface{}) (hasNext bool, err error) {
msg := NewMessage(GLOBAL_ORDER)
msg.Set(global)
msg.Set(len(subs) + 1)
for _, sub := range subs {
msg.Set(sub)
}
msg.Set(*ind)
msg.Set(7)
if _, err = c.conn.Write(msg.Dump(c.count())); err != nil {
return
}
if msg, err = ReadMessage(c.conn); err != nil {
return
}
var result string
msg.Get(&result)
*ind = result
hasNext = result != ""
return
}

View File

@@ -0,0 +1,148 @@
package connection
import (
"fmt"
"io"
"net"
"github.com/caretdev/go-irisnative/src/list"
)
type Message struct {
header MessageHeader
data []byte
offset uint
}
func NewMessage(messageType MessageType) Message {
return Message{
NewMessageHeader(messageType),
[]byte{},
0,
}
}
func ReadMessage(conn *net.TCPConn) (msg Message, err error) {
buffer := make([]byte, 14)
_, err = conn.Read(buffer)
if err != nil {
return
}
var header [14]byte
copy(header[:], buffer[:14])
var msgHeader = MessageHeader{header}
length := msgHeader.GetLength()
data := make([]byte, length)
var offset int = 0
var size int
for {
size, err = conn.Read(data[offset:])
if err != nil {
if err != io.EOF {
return
}
break
}
offset += size
if offset >= int(length) {
break
}
}
msg = Message{msgHeader, data, 0}
return
}
func (m *Message) AddRaw(value interface{}) {
switch v := value.(type) {
case uint16:
m.data = append(m.data, byte(v&0xff))
m.data = append(m.data, byte(v>>8&0xff))
m.offset += 2
case []byte:
m.data = append(m.data, v...)
m.offset += uint(len(v))
}
}
func (m *Message) GetRaw(value interface{}) error {
switch v := value.(type) {
case *uint16:
*v = uint16(m.data[m.offset]) | (uint16(m.data[m.offset+1]) << 8)
m.offset += 2
case *bool:
*v = (uint16(m.data[m.offset]) | (uint16(m.data[m.offset+1]) << 8)) == 1
m.offset += 2
case *[]byte:
*v = m.data[m.offset:]
m.offset = uint(len(m.data))
default:
return fmt.Errorf("unknown type: %T", v)
}
return nil
}
func (m *Message) Set(value interface{}) error {
listItem := list.NewListItem(value)
m.AddRaw(listItem.Dump())
return nil
}
func (m *Message) SetSQLText(sqlText string) error {
len := len(sqlText)
if len == 0 {
m.Set(sqlText)
return nil
}
const chunksize = 31904
chunks := len / chunksize
if len%chunksize != 0 {
chunks += 1
}
m.Set(chunks)
for i := 0; i < chunks; i++ {
begin := i * chunksize
end := (i + 1) * chunksize
if end > len {
end = len
}
m.Set(sqlText[begin:end])
}
return nil
}
func (m *Message) GetStatus() uint16 {
return m.header.GetStatus()
}
func (m *Message) Get(value interface{}) error {
listItem := list.GetListItem(m.data, &m.offset)
listItem.Get(value)
return nil
}
type AnyType struct {
listItem list.ListItem
}
func (v *AnyType) Int() int {
var value int
v.listItem.Get(&value)
return value
}
func (m *Message) GetAny() AnyType {
listItem := list.GetListItem(m.data, &m.offset)
return AnyType{listItem}
}
func (m *Message) Dump(count uint32) []byte {
m.header.SetCount(count)
m.header.SetLength(uint32(len(m.data)))
return append(m.header.header[:], m.data...)
}

View File

@@ -0,0 +1,84 @@
package connection
type MessageType string
func setUint32(buffer []byte, value uint32) {
buffer[0] = byte(value & 0xff)
buffer[1] = byte(value >> 8 & 0xff)
buffer[2] = byte(value >> 16 & 0xff)
buffer[3] = byte(value >> 24 & 0xff)
}
func getUint32(buffer []byte) uint32 {
return uint32(buffer[0]) |
uint32(buffer[1])<<8 |
uint32(buffer[2])<<16 |
uint32(buffer[3])<<24
}
const (
CONNECT MessageType = "\x43\x4e"
HANDSHAKE MessageType = "\x48\x53"
DISCONNECT MessageType = "\x44\x43"
GLOBAL_GET MessageType = "\x41\xc2"
GLOBAL_SET MessageType = "\x42\xc2"
GLOBAL_KILL MessageType = "\x43\xc2"
GLOBAL_ORDER MessageType = "\x45\xc2"
GLOBAL_DATA MessageType = "\x49\xc2"
CLASSMETHOD_VALUE MessageType = "\x4b\xc2"
CLASSMETHOD_VOID MessageType = "\x4c\xc2"
METHOD_VALUE MessageType = "\x5b\xc2"
METHOD_VOID MessageType = "\x5c\xc2"
PROPERTY_GET MessageType = "\x5d\xc2"
PROPERTY_SET MessageType = "\x5e\xc2"
DIRECT_QUERY MessageType = "DQ"
PREPARED_QUERY MessageType = "PQ"
DIRECT_UPDATE MessageType = "DU"
PREPARED_UPDATE MessageType = "PU"
PREPARE MessageType = "PP"
GET_AUTO_GENERATED_KEYS MessageType = "GG"
COMMIT MessageType = "TC"
ROLLBACK MessageType = "TR"
MULTIPLE_RESULT_SETS_FETCH_DATA MessageType = "MD"
GET_MORE_RESULTS MessageType = "MR"
FETCH_DATA MessageType = "FD"
GET_SERVER_ERROR MessageType = "OE"
)
type MessageHeader struct {
header [14]byte
}
func NewMessageHeader(messageType MessageType) MessageHeader {
header := [14]byte{}
header[12] = messageType[0]
header[13] = messageType[1]
return MessageHeader{header}
}
func (mh *MessageHeader) GetStatus() uint16 {
return uint16(mh.header[12]) | (uint16(mh.header[13]) << 8)
}
func (mh *MessageHeader) SetLength(length uint32) {
setUint32(mh.header[0:], length)
}
func (mh MessageHeader) GetLength() uint32 {
return getUint32(mh.header[0:])
}
func (mh *MessageHeader) SetCount(cnt uint32) {
setUint32(mh.header[4:], cnt)
}
func (mh *MessageHeader) SetStatementId(statementId uint32) {
setUint32(mh.header[8:], statementId)
}

View File

@@ -0,0 +1,251 @@
package connection
import (
"database/sql/driver"
"errors"
"net"
)
const VERSION_PROTOCOL uint16 = 69
type Connection struct {
conn *net.TCPConn
messageCount uint32
statement uint32
unicode bool
locale string
version uint16
info string
featureOptions uint
tx bool
}
var (
ErrCouldNotDetectUsername = errors.New("intersystems: Could not detect default username. Please provide one explicitly")
errBeginTx = errors.New("could not begin transaction")
errMultipleTx = errors.New("multiple transactions")
errReadOnlyTxNotSupported = errors.New("read-only transactions are not supported")
)
func Connect(addr string, namespace, login, password string) (connection Connection, err error) {
tcpAddr, err := net.ResolveTCPAddr("tcp", addr)
if err != nil {
return
}
conn, err := net.DialTCP("tcp", nil, tcpAddr)
if err != nil {
return
}
connection = Connection{
conn: conn,
}
if err = connection.handshake(); err != nil {
return
}
if err = connection.connect(namespace, login, password); err != nil {
return
}
// fmt.Println(connection.version, connection.info)
return
}
func (c *Connection) Disconnect() {
if c.conn == nil {
return
}
msg := NewMessage(DISCONNECT)
_, _ = c.conn.Write(msg.Dump(c.count()))
_ = c.conn.Close()
c.conn = nil
}
func (c *Connection) count() uint32 {
count := c.messageCount
c.messageCount += 1
return count
}
func (c *Connection) statementId() uint32 {
statement := c.statement
c.statement += 1
return statement
}
func (c *Connection) handshake() (err error) {
var message = NewMessage(HANDSHAKE)
message.AddRaw(VERSION_PROTOCOL)
_, err = c.conn.Write(message.Dump(c.count()))
if err != nil {
return
}
msg, err := ReadMessage(c.conn)
if err != nil {
return
}
var version uint16
msg.GetRaw(&version)
c.version = version
var unicode uint16
msg.GetRaw(&unicode)
c.unicode = unicode == 1
var locale string
msg.Get(&locale)
c.locale = locale
return
}
func encode(value string) []byte {
in := []byte(value)
length := len(in)
out := make([]byte, length)
for i := range in {
length--
temp := ((int(in[i])^0xa7)&0xff + length) & 0xff
out[length] = byte(temp<<5 | temp>>3)
}
return out
}
type FeatureOption uint
const (
OptionNone FeatureOption = 0
OptionFastSelect FeatureOption = 1
OptionFastInsert FeatureOption = 2
OptionFastSelectAndInsert FeatureOption = 3
OptionDurableTransactions FeatureOption = 4
OptionNotNullable FeatureOption = 8
OptionRedirectOutput FeatureOption = 32
)
func (c *Connection) IsOptionFastInsert() bool {
return c.featureOptions&uint(OptionFastInsert) == uint(OptionFastInsert)
}
func (c *Connection) IsOptionFastSelect() bool {
return c.featureOptions&uint(OptionFastSelect) == uint(OptionFastSelect)
}
func (c *Connection) connect(namespace, login, password string) (err error) {
msg := NewMessage(CONNECT)
msg.Set(namespace)
msg.Set(encode(login))
msg.Set(encode(password))
var user = "go"
if user, err = systemUser(); err != nil {
user = "go"
}
msg.Set(user) // machine user name
msg.Set("go-machine") // machine name
msg.Set("libirisnative") // application name
msg.Set("") // ?
msg.Set("go") // SharedMemoryFlag?
msg.Set("") // EventClass
msg.Set(1) // AutoCommit ? 1 : 2
msg.Set(0) // IsolationLevel
var featureOptions = OptionNone
featureOptions += OptionFastSelect
// Tricky to make it fully working yet
// featureOptions += OptionFastInsert
featureOptions += OptionDurableTransactions
featureOptions += OptionRedirectOutput
msg.Set(int(featureOptions)) // FeatureOption
_, err = c.conn.Write(msg.Dump(c.count()))
if err != nil {
return
}
msg, err = ReadMessage(c.conn)
if err != nil {
return
}
if status := msg.GetStatus(); status == 417 {
var errorMsg string
msg.Get(&errorMsg)
err = errors.New(errorMsg)
return
}
var info string
msg.Get(&info)
c.info = info
var (
delimited_ids bool
ignored int
isolationLevel int
serverJobNumber string
sqlEmptyString int
serverFeatureOptions uint
)
msg.Get(&delimited_ids)
msg.Get(&ignored)
msg.Get(&isolationLevel)
msg.Get(&serverJobNumber)
msg.Get(&sqlEmptyString)
msg.Get(&serverFeatureOptions)
c.featureOptions = serverFeatureOptions
return
}
func systemUser() (string, error) {
u, err := userCurrent()
if err != nil {
return "", err
}
return u, nil
}
func (c *Connection) Commit() (err error) {
msg := NewMessage(COMMIT)
_, err = c.conn.Write(msg.Dump(c.count()))
if err != nil {
return
}
_, err = ReadMessage(c.conn)
if err != nil {
return
}
return
}
func (c *Connection) Rollback() (err error) {
msg := NewMessage(ROLLBACK)
_, err = c.conn.Write(msg.Dump(c.count()))
if err != nil {
return
}
_, err = ReadMessage(c.conn)
if err != nil {
return
}
return
}
func (c *Connection) BeginTx(opts driver.TxOptions) (driver.Tx, error) {
if c.tx {
return nil, errors.Join(errBeginTx, errMultipleTx)
}
if opts.ReadOnly {
return nil, errors.Join(errBeginTx, errReadOnlyTxNotSupported)
}
if _, err := c.DirectUpdate("START TRANSACTION"); err != nil {
return nil, errors.Join(errBeginTx, err)
}
c.tx = true
return &tx{c}, nil
}

View File

@@ -0,0 +1,101 @@
package connection
import (
"database/sql/driver"
"errors"
"strings"
)
var (
errNoRowsAffected = errors.New("no RowsAffected available after the empty statement")
errNoLastInsertID = errors.New("no LastInsertId available after the empty statement")
)
type Result struct {
cn *Connection
affected int64
}
func (r Result) LastInsertId() (lastId int64, err error) {
// var msg Message
// msg = NewMessage(GET_AUTO_GENERATED_KEYS)
// msg.header.SetStatementId(r.cn.statementId())
// _, err = r.cn.conn.Write(msg.Dump(r.cn.count()))
// if err != nil {
// return
// }
// msg, err = ReadMessage(r.cn.conn)
// if err != nil {
// return
// }
// msg.Get(&lastId)
// return
var rs *ResultSet
rs, err = r.cn.DirectQuery("SELECT LAST_IDENTITY()")
if err != nil {
return
}
row, err := rs.Next()
if err != nil {
return
}
lastId = int64(row[0].(int))
return
}
func (r Result) RowsAffected() (int64, error) {
return r.affected, nil
}
type Rows struct {
cn *Connection
rs *ResultSet
}
type noRows struct{}
var emptyRows noRows
var _ driver.Result = noRows{}
func (noRows) LastInsertId() (int64, error) {
return 0, errNoLastInsertID
}
func (noRows) RowsAffected() (int64, error) {
return 0, errNoRowsAffected
}
func (r *Rows) Close() error {
return nil
}
func (r *Rows) Columns() []string {
if r.rs == nil {
return []string{}
}
columns := r.rs.Columns()
colNames := make([]string, len(columns))
for k, c := range columns {
colname := c.Name()
// tricking IRIS
colname = strings.ReplaceAll(colname, "﹒", ".")
colNames[k] = colname
}
// fmt.Printf("Columns: %#v\n", colNames)
return colNames
}
func (r *Rows) Next(dest []driver.Value) (err error) {
row, err := r.rs.Next()
if err != nil {
return err
}
for i := range dest {
dest[i] = row[i]
}
// fmt.Printf("RowsNext: %#v\n", dest)
return nil
}

View File

@@ -0,0 +1,729 @@
package connection
import (
"database/sql/driver"
"fmt"
"io"
"slices"
"strconv"
"strings"
"time"
"github.com/caretdev/go-irisnative/src/list"
)
const timeLaylout = "2006-01-02 15:04:05.000000000"
const timeLayloutShort = "2006-01-02 15:04:05"
type StatementFeature struct {
featureOption int
msgCount int
maxRowItemCount int
}
type Column struct {
name string
column_type int
precision int
scale int
nullable int
slot_position int
label string
table_name string
schema string
catalog string
is_auto_increment bool
is_case_sensitive bool
is_currency bool
is_read_only bool
is_row_id bool
}
type SQLTYPE int16
const (
GUID SQLTYPE = -11
WLONGVARCHAR SQLTYPE = -10
WVARCHAR SQLTYPE = -9
WCHAR SQLTYPE = -8
BIT SQLTYPE = -7
TINYINT SQLTYPE = -6
BIGINT SQLTYPE = -5
LONGVARBINARY SQLTYPE = -4
VARBINARY SQLTYPE = -3
BINARY SQLTYPE = -2
LONGVARCHAR SQLTYPE = -1
CHAR SQLTYPE = 1
NUMERIC SQLTYPE = 2
DECIMAL SQLTYPE = 3
INTEGER SQLTYPE = 4
SMALLINT SQLTYPE = 5
FLOAT SQLTYPE = 6
REAL SQLTYPE = 7
DOUBLE SQLTYPE = 8
DATE SQLTYPE = 9
TIME SQLTYPE = 10
TIMESTAMP SQLTYPE = 11
VARCHAR SQLTYPE = 12
TYPE_DATE SQLTYPE = 91
TYPE_TIME SQLTYPE = 92
TYPE_TIMESTAMP SQLTYPE = 93
DATE_HOROLOG SQLTYPE = 1091
TIME_HOROLOG SQLTYPE = 1092
TIMESTAMP_POSIX SQLTYPE = 1093
)
func (c Column) Name() string {
return c.name
}
type ResultSet struct {
c *Connection
columns []Column
sf StatementFeature
count int
data []byte
offset uint
sqlCode int16
}
type SQLError struct {
SQLCode int16
Message string
}
func (e *SQLError) Error() string {
return fmt.Sprintf("Error Code: %d, Message: %s", e.SQLCode, e.Message)
}
// func SQLError(code int) error {
// return &SQLError{SQLCode: code}
// }
func (rs ResultSet) Columns() []Column {
return rs.columns
}
func statementFeature(msg *Message) StatementFeature {
featureOption := 0
msgCount := 0
maxRowItemCount := 0
msg.Get(&featureOption)
if featureOption == 2 {
msg.Get(&msgCount)
}
if featureOption == 1 || featureOption == 2 {
msg.Get(&maxRowItemCount)
}
return StatementFeature{
featureOption,
msgCount,
maxRowItemCount,
}
}
type Value interface{}
// type ResultSetRow struct{}
func (rs *ResultSet) fetchMoreData() bool {
msg := NewMessage(FETCH_DATA)
_, err := rs.c.conn.Write(msg.Dump(rs.c.count()))
if err != nil {
panic(err)
}
msg, err = ReadMessage(rs.c.conn)
if err != nil {
panic(err)
}
rs.data = msg.data
rs.offset = 0
return len(msg.data) > 0
}
func fromODBC(coltype SQLTYPE, li list.ListItem) (result interface{}, err error) {
result = nil
if li.IsNull() || li.IsEmpty() {
return
}
switch coltype {
case VARCHAR:
if li.DataLength() == 0 {
return
}
var value string
li.Get(&value)
if value == "\x00" {
value = ""
}
result = value
case INTEGER, TINYINT, SMALLINT:
var value int
li.Get(&value)
result = value
case BIGINT:
var value int64
li.Get(&value)
result = value
case BIT:
var value bool
li.Get(&value)
result = value
case FLOAT:
var value float32
li.Get(&value)
result = value
case DOUBLE:
var value float64
li.Get(&value)
result = value
case TIMESTAMP_POSIX:
if li.DataLength() == 0 {
return
}
if li.Type() == list.LISTITEM_STRING {
var strval string
li.Get(&strval)
result, err = time.Parse(timeLaylout, strval)
if err == nil {
return
}
err = nil
}
var value int64
li.Get(&value)
if value > 0 {
value ^= 0x1000000000000000
} else {
value |= 0x6000000000000000
}
seconds := value / 1000000
nano := value % 1000000 * 1000
result = time.Unix(seconds, nano).In(time.Local)
case VARBINARY:
// var value []uint8
var value string
li.Get(&value)
case TYPE_TIMESTAMP:
var strval string
li.Get(&strval)
result, err = time.Parse(timeLayloutShort, strval)
default:
var value string
li.Get(&value)
fmt.Printf("fromODBC: invalid type: %v - %#v - %#v", coltype, li, value)
result = value
}
return
}
func (rs *ResultSet) Next() ([]Value, error) {
if rs == nil || (rs.sqlCode != 0 && rs.sqlCode != 100) {
return nil, io.EOF
}
if rs.offset >= uint(len(rs.data)) && (rs.sqlCode == 100 || !rs.fetchMoreData()) {
return nil, io.EOF
}
row := make([]Value, rs.count)
data := rs.data
count := rs.count
var offset uint = rs.offset
if rs.sf.featureOption == 1 {
li := list.GetListItem(data, &rs.offset)
li.Get(&data)
offset = 0
count = rs.sf.maxRowItemCount
}
vals := make([]list.ListItem, count)
for i := 0; i < count; i++ {
li := list.GetListItem(data, &offset)
vals[i] = li
}
if rs.sf.featureOption != 1 {
rs.offset = offset
}
var err error
for i, c := range rs.columns {
li := vals[c.slot_position]
row[i], err = fromODBC(SQLTYPE(c.column_type), li)
if err != nil {
return nil, err
}
// fmt.Printf("col: %s: %d; %#v - %#v\n", c.name, c.column_type, row[i], li)
}
// fmt.Printf("row: %#v\n", row)
return row, nil
}
func (c *Connection) getErrorInfo(sqlCode int16) string {
msg := NewMessage(GET_SERVER_ERROR)
msg.Set(sqlCode)
_, err := c.conn.Write(msg.Dump(c.count()))
if err != nil {
panic(err)
}
msg, err = ReadMessage(c.conn)
if err != nil {
panic(err)
}
var sqlMessage string
msg.Get(&sqlMessage)
return sqlMessage
}
func getColumns(msg *Message, statementFeature StatementFeature) []Column {
cnt := 0
msg.Get(&cnt)
columns := make([]Column, cnt)
for i := 0; i < cnt; i++ {
column := Column{}
msg.Get(&column.name)
msg.Get(&column.column_type)
switch column.column_type {
case 9:
column.column_type = 91
case 10:
column.column_type = 92
case 11:
column.column_type = 93
}
msg.Get(&column.precision)
msg.Get(&column.scale)
msg.Get(&column.nullable)
msg.Get(&column.label)
msg.Get(&column.table_name)
msg.Get(&column.schema)
msg.Get(&column.catalog)
additional := ""
msg.Get(&additional)
if statementFeature.featureOption&0x01 == 1 {
msg.Get(&column.slot_position)
column.slot_position -= 1
} else {
column.slot_position = i
}
column.is_auto_increment = additional[0] == 0x01
column.is_case_sensitive = additional[1] == 0x01
column.is_currency = additional[2] == 0x01
column.is_read_only = additional[3] == 0x01
if len(additional) >= 12 {
column.is_row_id = additional[11] == 0x01
}
columns[i] = column
}
return columns
}
func parameterInfo(msg *Message) {
cnt := 0
msg.Get(&cnt)
flag := 0
msg.Get(&flag)
}
func toODBC(value interface{}) interface{} {
var val interface{}
switch v := value.(type) {
case *string:
val = *v
case string:
val = v
if v == "" {
val = "\x00"
}
case nil:
val = ""
case bool:
if v {
val = 1
} else {
val = 0
}
case time.Time:
val = v.UTC().Format(timeLaylout)
case int, int8, int16, int32, int64:
val = v
case float32, float64:
val = v
case []uint8:
val = v
default:
fmt.Printf("unsupported type: %T\n", v)
val = fmt.Sprintf("%v", v)
}
return val
}
func writeParameters(msg *Message, args ...interface{}) {
msg.Set(len(args))
for range args {
msg.Set(99)
msg.Set(4)
}
msg.Set(1) // parameterSets
msg.Set(len(args))
for _, arg := range args {
msg.Set(toODBC(arg))
}
}
func (c *Connection) Query(sqlText string, args ...interface{}) (rs *ResultSet, err error) {
queries := strings.Split(sqlText, ";\n")
if len(queries) == 2 {
sqlText = queries[0]
_, err = c.DirectUpdate(sqlText, args...)
if err != nil {
return
}
sqlText = queries[1]
args = []interface{}{}
}
rs, err = c.DirectQuery(sqlText, args...)
if err != nil {
return
}
return
}
func (c *Connection) DirectQuery(sqlText string, args ...interface{}) (*ResultSet, error) {
sqlText, _, args = FormatQuery(sqlText, args...)
// fmt.Printf("DirectQuery: %s; %#v\n", sqlText, args)
var statementId = c.statementId()
msg := NewMessage(DIRECT_QUERY)
msg.header.SetStatementId(statementId)
msg.SetSQLText(sqlText)
writeParameters(&msg, args...)
msg.Set(10) // Query timeout
msg.Set(200) // Max rows
_, err := c.conn.Write(msg.Dump(c.count()))
if err != nil {
return nil, err
}
msg, err = ReadMessage(c.conn)
if err != nil {
return nil, err
}
sqlCode := int16(msg.GetStatus())
if sqlCode != 0 && sqlCode != 100 {
return nil, &SQLError{SQLCode: sqlCode, Message: c.getErrorInfo(sqlCode)}
}
statementFeature := statementFeature(&msg)
columns := getColumns(&msg, statementFeature)
parameterInfo((&msg))
rs := &ResultSet{
c: c,
sf: statementFeature,
columns: columns,
count: len(columns),
}
msg, err = ReadMessage(c.conn)
rs.sqlCode = int16(msg.GetStatus())
if err != nil {
return nil, err
}
msg.GetRaw(&rs.data)
return rs, nil
}
func (m Message) debug() string {
var sb strings.Builder
for i, b := range m.data {
if i > 0 {
sb.WriteString(", ")
}
sb.WriteString(strconv.Itoa(int(b)))
}
return fmt.Sprintf("$char(%s)", sb.String())
}
func FormatQuery(sqlText string, args ...interface{}) (string, int, []interface{}) {
var count int
for i := range args {
count++
sqlText = strings.Replace(sqlText, "?", fmt.Sprintf(" :%%qpar(%d) ", i+1), 1)
if !strings.Contains(sqlText, "?") {
break
}
}
return sqlText, count, args
}
func (c *Connection) Exec(sqlText string, args ...interface{}) (res *Result, err error) {
queries := strings.Split(sqlText, ";\n")
var onConflict = ""
if len(queries) == 2 {
sqlText = queries[0]
onConflict = strings.Split(queries[1], "-- ")[1]
if strings.Contains(onConflict, "ON CONFLICT UPDATE") {
// fmt.Printf("------\n%s\n%#v\n------\n", sqlText, args)
sqlText = strings.Replace(sqlText, "INSERT INTO", "INSERT OR UPDATE", 1)
onConflict = ""
}
}
res, err = c.DirectUpdate(sqlText, args...)
if err != nil {
if strings.Contains(onConflict, "ON CONFLICT DO NOTHING") {
res = &Result{cn: c, affected: 0}
err = nil
return
}
}
return
}
func (c *Connection) DirectUpdate(sqlText string, args ...interface{}) (*Result, error) {
var batchSize int
sqlText, batchSize, args = FormatQuery(sqlText, args...)
// fmt.Printf("DirectUpdate: %s; %#v\n", sqlText, args)
var batches = 1
if batchSize > 0 {
batches = len(args) / batchSize
}
var addToCache = false
var statementId = c.statementId()
var executeMany = false
var optFastInsert = false
var rowsAffected int64 = 0
var identityColumn = false
var defaults = []interface{}{}
for i := 1; i <= batches; i++ {
if i > 1 && executeMany {
break
}
var msg Message
if !addToCache {
msg = NewMessage(DIRECT_UPDATE)
msg.SetSQLText(sqlText)
msg.Set(batchSize)
for j := 0; j < batchSize; j++ {
msg.Set(99)
msg.Set(1)
}
// msg.Set(len(args))
// for range args {
// msg.Set(99)
// msg.Set(1)
// }
} else {
msg = NewMessage(PREPARED_UPDATE)
}
if addToCache && !executeMany && optFastInsert {
msg.AddRaw([]byte{1, 0, 0, 0})
msg.Set("")
msg.Set(0)
if identityColumn {
msg.Set(2)
msg.Set("")
} else {
msg.Set(1)
}
var batch []interface{} = make([]interface{}, batchSize)
copy(batch, args)
args = slices.Delete(args, 0, batchSize)
var params []byte
var item list.ListItem
for _, arg := range batch {
item = list.NewListItem(toODBC(arg))
params = append(params, item.Dump()...)
}
for _, arg := range defaults {
item = list.NewListItem(toODBC(arg))
params = append(params, item.Dump()...)
}
msg.Set(params)
} else {
msg.Set("")
msg.Set(0)
if executeMany {
msg.Set(batches)
for k := 0; k < batches; k++ {
msg.Set(batchSize)
for j := 0; j < batchSize; j++ {
var idx = (k * batchSize) + j
msg.Set(toODBC(args[idx]))
}
}
} else {
var batch []interface{} = make([]interface{}, batchSize)
copy(batch, args)
args = slices.Delete(args, 0, batchSize)
msg.Set(1)
msg.Set(len(batch))
for _, arg := range batch {
msg.Set(toODBC(arg))
}
}
}
msg.header.SetStatementId(statementId)
_, err := c.conn.Write(msg.Dump(c.count()))
if err != nil {
return nil, err
}
msg, err = ReadMessage(c.conn)
if err != nil {
// fmt.Println("DirectUpdate:Readmessage: ", err)
return nil, err
}
sqlCode := int16(msg.GetStatus())
if sqlCode != 0 && sqlCode != 100 {
return nil, &SQLError{SQLCode: sqlCode, Message: c.getErrorInfo(sqlCode)}
}
if i == 1 {
if c.IsOptionFastInsert() {
stmtFeatureOption, _ := c.checkStatementFeature(&msg)
optFastInsert = stmtFeatureOption&uint(OptionFastInsert) == uint(OptionFastInsert)
}
addToCache, identityColumn, defaults = c.getParameterInfo(&msg, optFastInsert)
}
var batchRows int64
msg.Get(&batchRows)
rowsAffected += batchRows
}
result := &Result{cn: c, affected: rowsAffected}
return result, nil
}
func (c *Connection) checkStatementFeature(msg *Message) (featureOption uint, count uint) {
count = 0
var keyCount int
msg.Get(&featureOption)
if featureOption == uint(OptionFastSelect) || featureOption == uint(OptionFastInsert) {
if featureOption == uint(OptionFastInsert) {
msg.Get(&keyCount)
}
msg.Get(&count)
}
return
}
func (c *Connection) getParameterInfo(msg *Message, optFastInsert bool) (addToCache bool, identityColumn bool, defaults []interface{}) {
var paramscnt int
msg.Get(&paramscnt)
var tablename string
for i := 0; i < paramscnt; i++ {
var (
paramtype int
precision int
scale int
nullable bool
position int
someval1 string
someval2 string
colname string
)
msg.Get(&paramtype)
msg.Get(&precision)
msg.Get(&scale)
msg.GetAny()
if optFastInsert {
msg.Get(&nullable)
msg.Get(&position)
msg.Get(&someval1)
msg.Get(&someval2)
if i == 0 {
msg.Get(&tablename)
}
msg.Get(&colname)
}
}
var flag int
defaults = []interface{}{}
identityColumn = false
msg.Get(&flag)
addToCache = flag&0x1 == 0x1
if optFastInsert {
var paramsDefault []byte
msg.Get(&paramsDefault)
var offset uint = 0
var li list.ListItem
li = list.GetListItem(paramsDefault, &offset)
identityColumn = li.IsEmpty()
for {
if uint(len(paramsDefault)) == offset {
break
}
li = list.GetListItem(paramsDefault, &offset)
if li.IsNull() {
continue
}
var val string
li.Get(&val)
defaults = append(defaults, val)
}
}
return
}
type Stmt struct {
cn *Connection
sql string
closed bool
statementId int32
}
func (c *Connection) Prepare(query string) (*Stmt, error) {
// msg := NewMessage(PREPARE)
// msg.SetSQLText(query)
// msg.Set(0)
// _, err := c.conn.Write(msg.Dump(c.count()))
// if err != nil {
// return nil, err
// }
// msg, err = ReadMessage(c.conn)
// if err != nil {
// return nil, err
// }
// sqlCode := int16(msg.GetStatus())
// if sqlCode != 0 && sqlCode != 100 {
// return nil, &SQLError{SQLCode: sqlCode, Message: c.getErrorInfo(sqlCode)}
// }
st := &Stmt{cn: c, sql: query}
return st, nil
}
func (st *Stmt) Exec(args []driver.Value) (res driver.Result, err error) {
parameters := make([]interface{}, len(args))
for i, a := range args {
parameters[i] = a
}
res, err = st.cn.Exec(st.sql, parameters...)
return
}
func (st *Stmt) Query(args []driver.Value) (rows driver.Rows, err error) {
parameters := make([]interface{}, len(args))
for i, a := range args {
parameters[i] = a
}
var rs *ResultSet
rs, err = st.cn.Query(st.sql, parameters...)
// st.statementId = int32(st.cn.statementId())
if err != nil {
return nil, err
}
rows = &Rows{
cn: st.cn,
rs: rs,
}
return
}
func (st *Stmt) Close() (err error) {
st.closed = true
return nil
}
func (st *Stmt) NumInput() int {
return -1
}

View File

@@ -0,0 +1,29 @@
package connection
type tx struct {
c *Connection
}
func (t *tx) Commit() error {
if t.c == nil || !t.c.tx {
panic("database/sql/driver: misuse of driver: extra Commit")
}
t.c.tx = false
err := t.c.Commit()
t.c = nil
return err
}
func (t *tx) Rollback() error {
if t.c == nil || !t.c.tx {
panic("database/sql/driver: misuse of driver: extra Rollback")
}
t.c.tx = false
err := t.c.Rollback()
t.c = nil
return err
}

View File

@@ -0,0 +1,22 @@
//go:build !windows
package connection
import (
"os"
"os/user"
)
func userCurrent() (string, error) {
u, err := user.Current()
if err == nil {
return u.Username, nil
}
name := os.Getenv("USER")
if name != "" {
return name, nil
}
return "", ErrCouldNotDetectUsername
}

View File

@@ -0,0 +1,19 @@
package connection
import (
"path/filepath"
"syscall"
)
// Perform Windows user name.
func userCurrent() (string, error) {
pw_name := make([]uint16, 128)
pwname_size := uint32(len(pw_name)) - 1
err := syscall.GetUserNameEx(syscall.NameSamCompatible, &pw_name[0], &pwname_size)
if err != nil {
return "", ErrCouldNotDetectUsername
}
s := syscall.UTF16ToString(pw_name)
u := filepath.Base(s)
return u, nil
}

View File

@@ -0,0 +1,4 @@
package iris
type Oref string

View File

@@ -0,0 +1,456 @@
package list
import (
"encoding/binary"
"errors"
"fmt"
"strconv"
"github.com/caretdev/go-irisnative/src/iris"
"github.com/shopspring/decimal"
)
type ListItemType byte
const (
LISTITEM_STRING ListItemType = 0x01
LISTITEM_UNICODE ListItemType = 0x02
LISTITEM_POSINT ListItemType = 0x04
LISTITEM_NEGINT ListItemType = 0x05
LISTITEM_POSFLOAT ListItemType = 0x06
LISTITEM_NEGFLOAT ListItemType = 0x07
LISTITEM_OREF ListItemType = 0x19
)
type ListItem struct {
size uint16
itemType ListItemType
data []byte
isNull bool
byRef bool
}
func (li *ListItem) IsNull() bool {
return li.isNull
}
func (li *ListItem) IsString() bool {
return li.itemType == LISTITEM_STRING || li.itemType == LISTITEM_UNICODE
}
func (li *ListItem) IsEmpty() bool {
return li.itemType == LISTITEM_STRING && len(li.data) == 0
}
func (li *ListItem) Type() ListItemType {
return li.itemType
}
var scale = []float64{
1.0, 10.0, 100.0, 1000.0, 10000.0, 100000.0, 1000000.0, 1.0e7, 1.0e8, 1.0e9,
1.0e10, 1.0e11, 1.0e12, 1.0e13, 1.0e14, 1.0e15, 1.0e16, 1.0e17, 1.0e18, 1.0e19,
1.0e20, 1.0e21, 1.0e22, 9.999999999999999e22, 1.0e24, 1.0e25, 1.0e26, 1.0e27, 1.0e28, 1.0e29,
1.0e30, 1.0e31, 1.0e32, 1.0e33, 1.0e34, 1.0e35, 1.0e36, 1.0e37, 1.0e38, 1.0e39,
1.0e40, 1.0e41, 1.0e42, 1.0e43, 1.0e44, 1.0e45, 1.0e46, 1.0e47, 1.0e48, 1.0e49,
1.0e50, 1.0e51, 1.0e52, 1.0e53, 1.0e54, 1.0e55, 1.0e56, 1.0e57, 1.0e58, 1.0e59,
1.0e60, 1.0e61, 1.0e62, 1.0e63, 1.0e64, 1.0e65, 1.0e66, 1.0e67, 1.0e68, 1.0e69,
1.0e70, 1.0e71, 1.0e72, 1.0e73, 1.0e74, 1.0e75, 1.0e76, 1.0e77, 1.0e78, 1.0e79,
1.0e80, 1.0e81, 1.0e82, 1.0e83, 1.0e84, 1.0e85, 1.0e86, 1.0e87, 1.0e88, 1.0e89,
1.0e90, 1.0e91, 1.0e92, 1.0e93, 1.0e94, 1.0e95, 1.0e96, 1.0e97, 1.0e98, 1.0e99,
1.0e100, 1.0e101, 1.0e102, 1.0e103, 1.0e104, 1.0e105, 1.0e106, 1.0e107, 1.0e108, 1.0e109,
1.0e110, 1.0e111, 1.0e112, 1.0e113, 1.0e114, 1.0e115, 1.0e116, 1.0e117, 1.0e118, 1.0e119,
1.0e120, 1.0e121, 1.0e122, 1.0e123, 1.0e124, 1.0e125, 1.0e126, 1.0e127, 1.0e-128, 1.0e-127,
1.0e-126, 1.0e-125, 1.0e-124, 1.0e-123, 1.0e-122, 1.0e-121, 1.0e-120, 1.0e-119, 1.0e-118, 1.0e-117,
1.0e-116, 1.0e-115, 1.0e-114, 1.0e-113, 1.0e-112, 1.0e-111, 1.0e-110, 1.0e-109, 1.0e-108, 1.0e-107,
1.0e-106, 1.0e-105, 1.0e-104, 1.0e-103, 1.0e-102, 1.0e-101, 1.0e-100, 1.0e-99, 1.0e-98, 1.0e-97,
1.0e-96, 1.0e-95, 1.0e-94, 1.0e-93, 1.0e-92, 1.0e-91, 1.0e-90, 1.0e-89, 1.0e-88, 1.0e-87,
1.0e-86, 1.0e-85, 1.0e-84, 1.0e-83, 1.0e-82, 1.0e-81, 1.0e-80, 1.0e-79, 1.0e-78, 1.0e-77,
1.0e-76, 1.0e-75, 1.0e-74, 1.0e-73, 1.0e-72, 1.0e-71, 1.0e-70, 1.0e-69, 1.0e-68, 1.0e-67,
1.0e-66, 1.0e-65, 1.0e-64, 1.0e-63, 1.0e-62, 1.0e-61, 1.0e-60, 1.0e-59, 1.0e-58, 1.0e-57,
1.0e-56, 1.0e-55, 1.0e-54, 1.0e-53, 1.0e-52, 1.0e-51, 1.0e-50, 1.0e-49, 1.0e-48, 1.0e-47,
1.0e-46, 1.0e-45, 1.0e-44, 1.0e-43, 1.0e-42, 1.0e-41, 1.0e-40, 1.0e-39, 1.0e-38, 1.0e-37,
1.0e-36, 1.0e-35, 1.0e-34, 1.0e-33, 1.0e-32, 1.0e-31, 1.0e-30, 1.0e-29, 1.0e-28, 1.0e-27,
1.0e-26, 1.0e-25, 1.0e-24, 1.0e-23, 1.0e-22, 1.0e-21, 1.0e-20, 1.0e-19, 1.0e-18, 1.0e-17,
1.0e-16, 1.0e-15, 1.0e-14, 1.0e-13, 1.0e-12, 1.0e-11, 1.0e-10, 1.0e-9, 1.0e-8, 1.0e-7,
1.0e-6, 1.0e-5, 1.0e-4, 0.001, 0.01, 0.1}
func (listItem *ListItem) Dump() []byte {
if listItem.isNull {
return []byte{1}
}
var dump = make([]byte, 0)
if listItem.size > 253 {
size := listItem.size + 1
dump = append(dump, 0)
dump = append(dump, byte((size)&0xff))
dump = append(dump, byte((size>>8)&0xff))
} else {
dump = append(dump, byte(listItem.size+2))
}
dump = append(dump, byte(listItem.itemType))
dump = append(dump, listItem.data...)
return dump
}
func GetListItem(buffer []byte, ooffset *uint) ListItem {
var byRef = false
var isNull = false
var size uint16 = 0
var itemType byte = 0
offset := *ooffset
switch buffer[offset] {
case 0:
size = uint16((buffer[offset+1] & 0xff))
size |= ((uint16(buffer[offset+2]) & 0xff) << 8)
size -= 1
offset += 3
itemType = buffer[offset]
offset += 1
case 1:
isNull = true
offset += 1
default:
size = uint16(buffer[offset]) - 2
offset += 1
itemType = buffer[offset]
offset += 1
if itemType >= 32 && itemType < 64 {
itemType = itemType - 32
byRef = true
}
}
var data = []byte{}
if size > 0 {
data = buffer[offset : offset+uint(size)]
}
offset += uint(size)
*ooffset = offset
return ListItem{size, ListItemType(itemType), data, isNull, byRef}
}
func NewListItem(value interface{}) ListItem {
var itemType ListItemType = 0
var size uint16 = 0
var data = make([]byte, 0)
var isNull = false
var byRef = false
switch v := value.(type) {
case *string:
var listItem = NewListItem(*v)
listItem.byRef = true
return listItem
case int, int8, int16, int32, int64:
var ival int64
switch i := v.(type) {
case int:
ival = int64(i)
case int8:
ival = int64(i)
case int16:
ival = int64(i)
case int32:
ival = int64(i)
case int64:
ival = i
}
itemType = 4
var base = 0
var temp = ival
if ival < 0 {
itemType = 5
base = 0xff
temp = ival*-1 - 1
}
for temp > 0 {
data = append(data, byte((temp^int64(base))&0xff))
temp = temp >> 8
}
case uint, uint8, uint16, uint32, uint64:
var uval uint64
switch u := v.(type) {
case uint:
uval = uint64(u)
case uint8:
uval = uint64(u)
case uint16:
uval = uint64(u)
case uint32:
uval = uint64(u)
case uint64:
uval = u
}
itemType = 4
temp := uval
for temp > 0 {
data = append(data, byte(temp&0xff))
temp = temp >> 8
}
case float64, float32:
var d decimal.Decimal
switch f := v.(type) {
case float32:
d = decimal.NewFromFloat32(f)
case float64:
d = decimal.NewFromFloat(f)
}
scaleSize := 256 - d.Exponent()*-1
ival := d.Coefficient().Int64()
itemType = 6
if ival < 0 {
itemType = 7
}
data = append(data, byte(scaleSize))
var base = 0
var temp = ival
if ival < 0 {
base = 0xff
temp = ival*-1 - 1
}
for temp > 0 {
data = append(data, byte((temp^int64(base))&0xff))
temp = temp >> 8
}
case bool:
itemType = 4
if v {
data = []byte{0x1}
} else {
data = []byte{0x0}
}
case string:
itemType = 1
var unicodeBytes []byte
for _, r := range(v) {
if r > 255 {
itemType = 2
var temp = r
// append(unicodeBytes)
for temp > 0 {
unicodeBytes = append(unicodeBytes, byte((temp)&0xff))
temp = temp >> 8
}
} else {
unicodeBytes = append(unicodeBytes, byte((r)&0xff))
unicodeBytes = append(unicodeBytes, byte(0))
}
}
if itemType == 2 {
data = unicodeBytes
} else {
data = []byte(v)
}
case []byte:
itemType = 1
data = v
case nil:
isNull = true
// itemType = 1
// data = []byte("")
case iris.Oref:
itemType = 25
byRef = true
data = []byte(v)
default:
fmt.Printf("unknown: %#v %T\n", v, v)
itemType = 1
data = []byte(fmt.Sprintf("%v", v))
}
size = uint16(len(data))
return ListItem{
size,
itemType,
data,
isNull,
byRef,
}
}
func (li *ListItem) getString() string {
if li.itemType == LISTITEM_UNICODE {
var val string = ""
for i := 0; i < len(li.data); i += 2 {
val += string(rune(getPosInt(li.data[i:i+2])))
}
return val
} else {
return string(li.data)
}
}
func getPosInt(data []byte) int {
temp := make([]byte, 8)
copy(temp, data)
return int(binary.LittleEndian.Uint64(temp[:8]))
}
func getNegInt(data []byte) int {
temp := make([]byte, 8)
copy(temp, data)
for i := range data {
temp[i] ^= 0xff
}
return -int(binary.LittleEndian.Uint64(temp[:8]) + 1)
}
func getPosFloat(data []byte) float64 {
d := scale[int(data[0])]
return float64(getPosInt(data[1:])) * d
}
func getNegFloat(data []byte) float64 {
d := scale[int(data[0])]
return float64(getNegInt(data[1:])) * d
}
func (li *ListItem) asString() (value string, err error) {
if li.isNull {
value = ""
return
}
switch li.itemType {
case 1, 2, 25:
value = li.getString()
case 4:
value = fmt.Sprint(getPosInt(li.data))
case 5:
value = fmt.Sprint(getNegInt(li.data))
case 6:
value = fmt.Sprint(getPosFloat(li.data))
case 7:
value = fmt.Sprint(getNegFloat(li.data))
default:
err = errors.New("not implemented")
}
return
}
func (li *ListItem) asInt() (value int, err error) {
if li.isNull {
value = 0
return
}
switch li.itemType {
case 1, 2:
value, err = strconv.Atoi(li.getString())
case 4:
value = getPosInt(li.data)
case 5:
value = getNegInt(li.data)
case 6:
value = int(getPosFloat(li.data))
case 7:
value = int(getNegFloat(li.data))
default:
err = errors.New("not implemented")
}
return
}
func (li *ListItem) asFloat64() (value float64, err error) {
if li.isNull {
value = 0
return
}
switch li.itemType {
case 1, 2:
var temp int
temp, err = strconv.Atoi(li.getString())
if err != nil {
return
}
value = float64(temp)
case 4:
value = float64(getPosInt(li.data))
case 5:
value = float64(getNegInt(li.data))
case 6:
value = getPosFloat(li.data)
case 7:
value = getNegFloat(li.data)
default:
err = errors.New("not implemented")
}
return
}
type AnyType ListItem
func (v *AnyType) Int() int {
var value int
// ListItem(*v)
return value
}
func (li *ListItem) GetAny() AnyType {
return AnyType(*li)
}
func (li *ListItem) DataLength() int {
return len(li.data)
}
func (li *ListItem) Get(value interface{}) (err error) {
switch v := value.(type) {
case *int:
*v, err = li.asInt()
case *bool:
var temp int
temp, err = li.asInt()
*v = temp != 0
case *int8:
var temp int
temp, err = li.asInt()
*v = int8(temp)
case *int16:
var temp int
temp, err = li.asInt()
*v = int16(temp)
case *int32:
var temp int
temp, err = li.asInt()
*v = int32(temp)
case *int64:
var temp int
temp, err = li.asInt()
*v = int64(temp)
case *uint:
var temp int
temp, err = li.asInt()
*v = uint(temp)
case *uint8:
var temp int
temp, err = li.asInt()
*v = uint8(temp)
case *uint16:
var temp int
temp, err = li.asInt()
*v = uint16(temp)
case *uint32:
var temp int
temp, err = li.asInt()
*v = uint32(temp)
case *uint64:
var temp int
temp, err = li.asInt()
*v = uint64(temp)
case *float64:
*v, err = li.asFloat64()
case *float32:
var temp float64
temp, err = li.asFloat64()
*v = float32(temp)
case *string:
*v, err = li.asString()
case *[]byte:
*v = li.data
case *iris.Oref:
var temp string
temp, err = li.asString()
*v = iris.Oref(temp)
default:
err = errors.New("not implemented")
}
return
}

76
third_party/go-irisnative/url.go vendored Normal file
View File

@@ -0,0 +1,76 @@
package intersystems
import (
"fmt"
"net"
nurl "net/url"
"sort"
"strings"
)
// ParseURL no longer needs to be used by clients of this library since supplying a URL as a
// connection string to sql.Open() is now supported:
//
// sql.Open("intersystems", "iris://_system:SYS@1.2.3.4:1972/USER")
//
// It remains exported here for backwards-compatibility.
//
// ParseURL converts a url to a connection string for driver.Open.
// Example:
//
// "iris://_system:SYS@1.2.3.4:1972/USER"
//
// converts to:
//
// "user=_system password=SYS host=1.2.3.4 port=1972 namespace=USER"
//
// A minimal example:
//
// "iris://"
//
// This will be blank, causing driver.Open to use all of the defaults
func ParseURL(url string) (string, error) {
u, err := nurl.Parse(url)
if err != nil {
return "", err
}
if u.Scheme != "iris" && u.Scheme != "IRIS" {
return "", fmt.Errorf("invalid connection protocol: %s", u.Scheme)
}
var kvs []string
escaper := strings.NewReplacer(`'`, `\'`, `\`, `\\`)
accrue := func(k, v string) {
if v != "" {
kvs = append(kvs, k+"='"+escaper.Replace(v)+"'")
}
}
if u.User != nil {
v := u.User.Username()
accrue("user", v)
v, _ = u.User.Password()
accrue("password", v)
}
if host, port, err := net.SplitHostPort(u.Host); err != nil {
accrue("host", u.Host)
} else {
accrue("host", host)
accrue("port", port)
}
if u.Path != "" {
accrue("namespace", strings.ToUpper(u.Path[1:]))
}
q := u.Query()
for k := range q {
accrue(k, q.Get(k))
}
sort.Strings(kvs) // Makes testing easier (not a performance concern)
return strings.Join(kvs, " "), nil
}

View File

@@ -7,7 +7,7 @@ cd "$SCRIPT_DIR"
SCRIPT_DIR_WINDOWS="$(pwd -W 2>/dev/null || true)"
SCRIPT_DIR_WINDOWS="${SCRIPT_DIR_WINDOWS//\\//}"
DEFAULT_DRIVERS=(mariadb oceanbase doris starrocks sphinx sqlserver sqlite duckdb dameng kingbase highgo vastbase opengauss mongodb tdengine clickhouse)
DEFAULT_DRIVERS=(mariadb oceanbase doris starrocks sphinx sqlserver sqlite duckdb dameng kingbase highgo vastbase opengauss iris mongodb tdengine clickhouse)
TARGET_PLATFORMS=(darwin/amd64 darwin/arm64 windows/amd64 windows/arm64 linux/amd64)
usage() {
@@ -50,7 +50,7 @@ normalize_driver() {
case "$value" in
doris|diros) echo "doris" ;;
open_gauss|open-gauss) echo "opengauss" ;;
mariadb|oceanbase|starrocks|sphinx|sqlserver|sqlite|duckdb|dameng|kingbase|highgo|vastbase|opengauss|mongodb|tdengine|clickhouse)
mariadb|oceanbase|starrocks|sphinx|sqlserver|sqlite|duckdb|dameng|kingbase|highgo|vastbase|opengauss|iris|mongodb|tdengine|clickhouse)
echo "$value"
;;
*)

View File

@@ -5,7 +5,7 @@ set -euo pipefail
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"
cd "$SCRIPT_DIR"
DEFAULT_DRIVERS=(mariadb oceanbase diros starrocks sphinx sqlserver sqlite duckdb dameng kingbase highgo vastbase opengauss mongodb tdengine clickhouse)
DEFAULT_DRIVERS=(mariadb oceanbase diros starrocks sphinx sqlserver sqlite duckdb dameng kingbase highgo vastbase opengauss iris mongodb tdengine clickhouse)
OUTPUT_FILE="internal/db/driver_agent_revisions_gen.go"
usage() {
@@ -27,7 +27,7 @@ normalize_driver() {
doris|diros) echo "diros" ;;
oceanbase) echo "oceanbase" ;;
opengauss|open_gauss|open-gauss) echo "opengauss" ;;
mariadb|diros|starrocks|sphinx|sqlserver|sqlite|duckdb|dameng|kingbase|highgo|vastbase|mongodb|tdengine|clickhouse)
mariadb|diros|starrocks|sphinx|sqlserver|sqlite|duckdb|dameng|kingbase|highgo|vastbase|iris|mongodb|tdengine|clickhouse)
echo "$value"
;;
*)
@@ -125,6 +125,7 @@ highgo:internal/db/highgo_impl.go|\
vastbase:internal/db/vastbase_impl.go|\
opengauss:internal/db/opengauss_impl.go|\
opengauss:internal/db/postgres_impl.go|\
iris:internal/db/iris_impl.go|\
mongodb:internal/db/mongodb_impl.go|\
mongodb:internal/db/mongodb_impl_v1.go|\
tdengine:internal/db/tdengine_impl.go|\