diff --git a/build-driver-agents.sh b/build-driver-agents.sh index 65a6c52..854348d 100755 --- a/build-driver-agents.sh +++ b/build-driver-agents.sh @@ -5,7 +5,7 @@ set -euo pipefail SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" cd "$SCRIPT_DIR" -DEFAULT_DRIVERS=(mariadb oceanbase doris starrocks sphinx sqlserver sqlite duckdb dameng kingbase highgo vastbase opengauss mongodb tdengine clickhouse) +DEFAULT_DRIVERS=(mariadb oceanbase doris starrocks sphinx sqlserver sqlite duckdb dameng kingbase highgo vastbase opengauss iris mongodb tdengine clickhouse) DEFAULT_PLATFORMS=(darwin/amd64 darwin/arm64 windows/amd64 windows/arm64 linux/amd64 linux/arm64) DUCKDB_WINDOWS_LIBRARY_VERSION="v1.4.4" DUCKDB_WINDOWS_LIBRARY_URL="https://github.com/duckdb/duckdb/releases/download/${DUCKDB_WINDOWS_LIBRARY_VERSION}/libduckdb-windows-amd64.zip" @@ -42,7 +42,7 @@ normalize_driver() { case "$name" in doris|diros) echo "doris" ;; open_gauss|open-gauss) echo "opengauss" ;; - mariadb|oceanbase|starrocks|sphinx|sqlserver|sqlite|duckdb|dameng|kingbase|highgo|vastbase|opengauss|mongodb|tdengine|clickhouse) + mariadb|oceanbase|starrocks|sphinx|sqlserver|sqlite|duckdb|dameng|kingbase|highgo|vastbase|opengauss|iris|mongodb|tdengine|clickhouse) echo "$name" ;; *) diff --git a/cmd/optional-driver-agent/provider_iris.go b/cmd/optional-driver-agent/provider_iris.go new file mode 100644 index 0000000..05c47ee --- /dev/null +++ b/cmd/optional-driver-agent/provider_iris.go @@ -0,0 +1,12 @@ +//go:build gonavi_iris_driver + +package main + +import "GoNavi-Wails/internal/db" + +func init() { + agentDriverType = "iris" + agentDatabaseFactory = func() db.Database { + return &db.IrisDB{} + } +} diff --git a/frontend/src/components/ConnectionModal.tsx b/frontend/src/components/ConnectionModal.tsx index 0f892c9..0188ef1 100644 --- a/frontend/src/components/ConnectionModal.tsx +++ b/frontend/src/components/ConnectionModal.tsx @@ -224,6 +224,8 @@ const getDefaultPortByType = (type: string) => { return 54321; case "sqlserver": return 1433; + case "iris": + return 1972; case "mongodb": return 27017; case "highgo": @@ -247,6 +249,7 @@ const singleHostUriSchemesByType: Record = { clickhouse: ["clickhouse"], oracle: ["oracle"], sqlserver: ["sqlserver"], + iris: ["iris", "intersystems"], redis: ["redis"], tdengine: ["tdengine"], dameng: ["dameng", "dm"], @@ -368,6 +371,7 @@ const supportsConnectionParamsForType = (type: string) => type === "opengauss" || type === "oracle" || type === "sqlserver" || + type === "iris" || type === "clickhouse" || type === "mongodb" || type === "dameng" || @@ -390,6 +394,13 @@ const normalizeDriverType = (value: string): string => { .toLowerCase(); if (normalized === "postgresql") return "postgres"; if (normalized === "doris") return "diros"; + if ( + normalized === "intersystems" || + normalized === "intersystemsiris" || + normalized === "inter-systems-iris" || + normalized === "inter-systems" + ) + return "iris"; if ( normalized === "open_gauss" || normalized === "open-gauss" || @@ -1980,6 +1991,9 @@ const ConnectionModal: React.FC<{ if (dbType === "oracle") { return "oracle://user:pass@127.0.0.1:1521/ORCLPDB1"; } + if (dbType === "iris") { + return "iris://user:pass@127.0.0.1:1972/USER"; + } if (dbType === "opengauss") { return "opengauss://user:pass@127.0.0.1:5432/db_name"; } @@ -2006,6 +2020,8 @@ const ConnectionModal: React.FC<{ return "PREFETCH_ROWS=5000&TRACE FILE=/tmp/go-ora.trc"; case "sqlserver": return "app name=GoNavi&packet size=32767"; + case "iris": + return "timeout=30"; case "clickhouse": return "max_execution_time=60&compress=lz4"; case "mongodb": @@ -3869,6 +3885,11 @@ const ConnectionModal: React.FC<{ name: "SQL Server", icon: getDbIcon("sqlserver", undefined, 36), }, + { + key: "iris", + name: "InterSystems IRIS", + icon: getDbIcon("iris", undefined, 36), + }, { key: "sqlite", name: "SQLite", diff --git a/frontend/src/components/DatabaseIcons.test.tsx b/frontend/src/components/DatabaseIcons.test.tsx new file mode 100644 index 0000000..e9b3a40 --- /dev/null +++ b/frontend/src/components/DatabaseIcons.test.tsx @@ -0,0 +1,10 @@ +import { describe, expect, it } from 'vitest'; + +import { DB_ICON_TYPES, getDbIconLabel } from './DatabaseIcons'; + +describe('DatabaseIcons', () => { + it('includes InterSystems IRIS in the selectable database icons', () => { + expect(DB_ICON_TYPES).toContain('iris'); + expect(getDbIconLabel('iris')).toBe('InterSystems IRIS'); + }); +}); diff --git a/frontend/src/components/DatabaseIcons.tsx b/frontend/src/components/DatabaseIcons.tsx index d50bca9..0def9b2 100644 --- a/frontend/src/components/DatabaseIcons.tsx +++ b/frontend/src/components/DatabaseIcons.tsx @@ -27,6 +27,7 @@ const DB_DEFAULT_COLORS: Record = { vastbase: '#0066CC', opengauss: '#2446A8', highgo: '#00A86B', + iris: '#1F6FEB', tdengine: '#2962FF', diros: '#0050B3', starrocks: '#00A6A6', @@ -146,6 +147,9 @@ const OpenGaussIcon: React.FC = ({ size = 16, color }) => ( const HighGoIcon: React.FC = ({ size = 16, color }) => ( ); +const IrisIcon: React.FC = ({ size = 16, color }) => ( + +); const TDengineIcon: React.FC = ({ size = 16, color }) => ( ); @@ -195,6 +199,7 @@ const DB_ICON_MAP: Record> = { vastbase: VastBaseIcon, opengauss: OpenGaussIcon, highgo: HighGoIcon, + iris: IrisIcon, tdengine: TDengineIcon, custom: CustomIcon, }; @@ -203,7 +208,7 @@ const DB_ICON_MAP: Record> = { export const DB_ICON_TYPES: string[] = [ 'mysql', 'mariadb', 'oceanbase', 'postgres', 'redis', 'mongodb', 'jvm', 'oracle', 'sqlserver', 'sqlite', 'duckdb', 'clickhouse', 'starrocks', - 'kingbase', 'dameng', 'vastbase', 'opengauss', 'highgo', 'tdengine', 'custom', + 'kingbase', 'dameng', 'vastbase', 'opengauss', 'highgo', 'iris', 'tdengine', 'custom', ]; /** 该类型是否有品牌 SVG 文件 */ @@ -225,7 +230,7 @@ export const getDbIconLabel = (type: string): string => { sqlserver: 'SQL Server', clickhouse: 'ClickHouse', sqlite: 'SQLite', starrocks: 'StarRocks', duckdb: 'DuckDB', kingbase: '金仓', dameng: '达梦', - vastbase: 'VastBase', opengauss: 'OpenGauss', highgo: '瀚高', tdengine: 'TDengine', + vastbase: 'VastBase', opengauss: 'OpenGauss', highgo: '瀚高', iris: 'InterSystems IRIS', tdengine: 'TDengine', custom: '自定义', }; return labels[type?.toLowerCase()] || type; diff --git a/frontend/src/components/Sidebar.tsx b/frontend/src/components/Sidebar.tsx index 61db2b7..b6d2453 100644 --- a/frontend/src/components/Sidebar.tsx +++ b/frontend/src/components/Sidebar.tsx @@ -131,6 +131,12 @@ const normalizeDriverType = (value: string): string => { normalized === 'open-gauss' || normalized === 'opengauss' ) return 'opengauss'; + if ( + normalized === 'intersystems' || + normalized === 'intersystemsiris' || + normalized === 'inter-systems' || + normalized === 'inter-systems-iris' + ) return 'iris'; return normalized; }; @@ -648,6 +654,7 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }> 'open_gauss', 'open-gauss', 'sqlserver', + 'iris', 'oracle', 'dameng', ]); @@ -661,6 +668,7 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }> 'open_gauss', 'open-gauss', 'sqlserver', + 'iris', 'oracle', 'dm', ]); diff --git a/frontend/src/components/tableDataDangerActions.ts b/frontend/src/components/tableDataDangerActions.ts index 5a1030d..7169f78 100644 --- a/frontend/src/components/tableDataDangerActions.ts +++ b/frontend/src/components/tableDataDangerActions.ts @@ -38,6 +38,10 @@ const resolveCustomDriverDialect = (driver: string): string => { return 'highgo'; case 'vastbase': return 'vastbase'; + case 'iris': + case 'intersystems': + case 'intersystemsiris': + return 'iris'; default: break; } diff --git a/frontend/src/utils/connectionModalPresentation.test.ts b/frontend/src/utils/connectionModalPresentation.test.ts index 9ef3195..f46b559 100644 --- a/frontend/src/utils/connectionModalPresentation.test.ts +++ b/frontend/src/utils/connectionModalPresentation.test.ts @@ -84,6 +84,7 @@ describe('connectionModalPresentation', () => { 'highgo', 'vastbase', 'opengauss', + 'iris', 'mongodb', 'redis', 'tdengine', @@ -139,6 +140,13 @@ describe('connectionModalPresentation', () => { 'customDriver', 'customDsn', ]); + expect(resolveConnectionConfigLayout('iris').sections).toEqual([ + 'identity', + 'uri', + 'target', + 'credentials', + 'databaseScope', + ]); }); it('uses localized labels for layout kinds shown in the modal', () => { diff --git a/frontend/src/utils/dataSourceCapabilities.test.ts b/frontend/src/utils/dataSourceCapabilities.test.ts index e02fe21..adea747 100644 --- a/frontend/src/utils/dataSourceCapabilities.test.ts +++ b/frontend/src/utils/dataSourceCapabilities.test.ts @@ -40,6 +40,20 @@ describe('dataSourceCapabilities', () => { }); }); + it('keeps InterSystems IRIS as an editable SQL datasource capability', () => { + expect(getDataSourceCapabilities({ type: 'iris' })).toMatchObject({ + type: 'iris', + supportsQueryEditor: true, + supportsSqlQueryExport: true, + supportsCopyInsert: true, + forceReadOnlyQueryResult: false, + }); + expect(getDataSourceCapabilities({ type: 'custom', driver: 'intersystemsiris' })).toMatchObject({ + type: 'iris', + supportsQueryEditor: true, + }); + }); + it('treats OceanBase Oracle protocol as Oracle capabilities', () => { expect(getDataSourceCapabilities({ type: 'oceanbase', diff --git a/frontend/src/utils/dataSourceCapabilities.ts b/frontend/src/utils/dataSourceCapabilities.ts index b55c59e..5de93ba 100644 --- a/frontend/src/utils/dataSourceCapabilities.ts +++ b/frontend/src/utils/dataSourceCapabilities.ts @@ -18,6 +18,11 @@ const normalizeDataSourceToken = (raw: string): string => { return 'opengauss'; case 'dm': return 'dameng'; + case 'intersystems': + case 'intersystemsiris': + case 'inter-systems': + case 'inter-systems-iris': + return 'iris'; default: return normalized; } @@ -52,6 +57,7 @@ const SQL_QUERY_EXPORT_TYPES = new Set([ 'vastbase', 'opengauss', 'sqlserver', + 'iris', 'sqlite', 'duckdb', 'oracle', @@ -73,6 +79,7 @@ const COPY_INSERT_TYPES = new Set([ 'vastbase', 'opengauss', 'sqlserver', + 'iris', 'sqlite', 'duckdb', 'oracle', diff --git a/frontend/src/utils/driverImportGuidance.test.ts b/frontend/src/utils/driverImportGuidance.test.ts index d382324..1d0041f 100644 --- a/frontend/src/utils/driverImportGuidance.test.ts +++ b/frontend/src/utils/driverImportGuidance.test.ts @@ -19,6 +19,8 @@ describe('driver import guidance', () => { expect(CUSTOM_CONNECTION_DRIVER_HELP).toContain('pgx'); expect(CUSTOM_CONNECTION_DRIVER_HELP).toContain('open_gauss'); expect(CUSTOM_CONNECTION_DRIVER_HELP).toContain('oceanbase'); + expect(CUSTOM_CONNECTION_DRIVER_HELP).toContain('Go database/sql'); + expect(CUSTOM_CONNECTION_DRIVER_HELP).toContain('ODBC/JDBC'); expect(CUSTOM_CONNECTION_DRIVER_HELP).toContain('JDBC Jar'); }); }); diff --git a/frontend/src/utils/driverImportGuidance.ts b/frontend/src/utils/driverImportGuidance.ts index 9a2ab78..ebd83af 100644 --- a/frontend/src/utils/driverImportGuidance.ts +++ b/frontend/src/utils/driverImportGuidance.ts @@ -7,4 +7,4 @@ export const DRIVER_LOCAL_IMPORT_SINGLE_FILE_HELP = '行内“导入驱动包”仅用于单个驱动文件/总包(如 `mariadb-driver-agent`、`mariadb-driver-agent.exe`、`GoNavi-DriverAgents.zip`),不支持直接导入 JDBC Jar;批量导入请使用上方“导入驱动目录”。'; export const CUSTOM_CONNECTION_DRIVER_HELP = - '已支持: mysql, starrocks, oceanbase, postgres, opengauss, sqlite, oracle, dm, kingbase;别名支持 postgresql/pgx、open_gauss/open-gauss、dm8、kingbase8/kingbasees/kingbasev8。当前不支持通过 JDBC Jar 扩展驱动。'; + '已支持: mysql, starrocks, oceanbase, postgres, opengauss, sqlite, oracle, dm, kingbase;别名支持 postgresql/pgx、open_gauss/open-gauss、dm8、kingbase8/kingbasees/kingbasev8。请填写 GoNavi 已注册的 Go database/sql 驱动名,不能直接填写系统 ODBC/JDBC 驱动名或导入 JDBC Jar。'; diff --git a/frontend/src/utils/sqlDialect.test.ts b/frontend/src/utils/sqlDialect.test.ts index 9dcfecb..3e05595 100644 --- a/frontend/src/utils/sqlDialect.test.ts +++ b/frontend/src/utils/sqlDialect.test.ts @@ -19,6 +19,8 @@ describe('sqlDialect', () => { expect(resolveSqlDialect('doris')).toBe('diros'); expect(resolveSqlDialect('StarRocks')).toBe('starrocks'); expect(resolveSqlDialect('dameng')).toBe('dameng'); + expect(resolveSqlDialect('InterSystems IRIS')).toBe('iris'); + expect(resolveSqlDialect('custom', 'intersystemsiris')).toBe('iris'); expect(resolveSqlDialect('custom', 'kingbase8')).toBe('kingbase'); expect(resolveSqlDialect('custom', 'dm8')).toBe('dameng'); expect(resolveSqlDialect('custom', 'mariadb')).toBe('mariadb'); @@ -43,6 +45,7 @@ describe('sqlDialect', () => { expect(values(resolveColumnTypeOptions('starrocks'))).toContain('PERCENTILE'); expect(values(resolveColumnTypeOptions('sphinx'))).toContain('text'); expect(values(resolveColumnTypeOptions('clickhouse'))).toContain('DateTime64(3)'); + expect(values(resolveColumnTypeOptions('iris'))).toContain('varchar(255)'); expect(values(resolveColumnTypeOptions('tdengine'))).toContain('TIMESTAMP'); expect(values(resolveColumnTypeOptions('duckdb'))).toContain('STRUCT'); }); diff --git a/frontend/src/utils/sqlDialect.ts b/frontend/src/utils/sqlDialect.ts index 4f3b862..8158167 100644 --- a/frontend/src/utils/sqlDialect.ts +++ b/frontend/src/utils/sqlDialect.ts @@ -22,6 +22,7 @@ export type SqlDialect = | 'oracle' | 'dameng' | 'sqlserver' + | 'iris' | 'sqlite' | 'duckdb' | 'clickhouse' @@ -68,6 +69,12 @@ export const resolveSqlDialect = ( case 'sql_server': case 'sql-server': return 'sqlserver'; + case 'intersystems': + case 'intersystemsiris': + case 'inter-systems': + case 'inter-systems-iris': + case 'iris': + return 'iris'; case 'doris': case 'diros': return 'diros'; @@ -122,6 +129,7 @@ export const resolveSqlDialect = ( if (source.includes('clickhouse')) return 'clickhouse'; if (source.includes('tdengine')) return 'tdengine'; if (source.includes('sqlserver') || source.includes('mssql')) return 'sqlserver'; + if (source.includes('iris') || source.includes('intersystems')) return 'iris'; return source; }; @@ -479,6 +487,7 @@ export const resolveColumnTypeOptions = (dbType: string): ColumnTypeOption[] => if (dialect === 'oracle') return ORACLE_TYPES; if (dialect === 'dameng') return DAMENG_TYPES; if (dialect === 'sqlserver') return SQLSERVER_TYPES; + if (dialect === 'iris') return COMMON_TYPES; if (dialect === 'sqlite') return SQLITE_TYPES; if (dialect === 'duckdb') return DUCKDB_TYPES; if (dialect === 'clickhouse') return CLICKHOUSE_TYPES; diff --git a/go.mod b/go.mod index 999561c..370036b 100644 --- a/go.mod +++ b/go.mod @@ -6,6 +6,7 @@ require ( gitea.com/kingbase/gokb v0.0.0-20201021123113-29bd62a876c3 gitee.com/chunanyong/dm v1.8.22 github.com/ClickHouse/clickhouse-go/v2 v2.43.0 + github.com/caretdev/go-irisnative v0.2.1 github.com/duckdb/duckdb-go/v2 v2.5.5 github.com/go-sql-driver/mysql v1.9.3 github.com/google/uuid v1.6.0 @@ -122,3 +123,5 @@ require ( ) replace github.com/highgo/pq-sm3 => ./third_party/highgo-pq + +replace github.com/caretdev/go-irisnative => ./third_party/go-irisnative diff --git a/internal/app/db_context.go b/internal/app/db_context.go index 6fe1907..3303e34 100644 --- a/internal/app/db_context.go +++ b/internal/app/db_context.go @@ -20,7 +20,7 @@ func normalizeRunConfig(config connection.ConnectionConfig, dbName string) conne if !isOceanBaseOracleProtocol(config) { runConfig.Database = name } - case "mysql", "mariadb", "diros", "starrocks", "sphinx", "postgres", "kingbase", "highgo", "vastbase", "opengauss", "sqlserver", "mongodb", "tdengine", "clickhouse": + case "mysql", "mariadb", "diros", "starrocks", "sphinx", "postgres", "kingbase", "highgo", "vastbase", "opengauss", "sqlserver", "iris", "intersystems", "intersystemsiris", "inter-systems", "inter-systems-iris", "mongodb", "tdengine", "clickhouse": // 这些类型的 dbName 表示"数据库",需要写入连接配置以选择目标库。 runConfig.Database = name case "dameng": @@ -68,6 +68,16 @@ func normalizeSchemaAndTable(config connection.ConnectionConfig, dbName string, } } + if dbType == "iris" { + schema, table := db.SplitSQLQualifiedName(rawTable) + if schema != "" && table != "" { + return schema, table + } + if table != "" { + return "", table + } + } + if parts := strings.SplitN(rawTable, ".", 2); len(parts) == 2 { schema := strings.TrimSpace(parts[0]) table := strings.TrimSpace(parts[1]) diff --git a/internal/app/db_context_test.go b/internal/app/db_context_test.go index 36a8f20..ee02ea7 100644 --- a/internal/app/db_context_test.go +++ b/internal/app/db_context_test.go @@ -90,6 +90,43 @@ func TestNormalizeRunConfig_StarRocksUsesDatabaseFromTree(t *testing.T) { } } +func TestNormalizeRunConfig_IRISUsesNamespaceFromTree(t *testing.T) { + t.Parallel() + + runConfig := normalizeRunConfig(connection.ConnectionConfig{ + Type: "iris", + Database: "USER", + }, "APP") + + if runConfig.Database != "APP" { + t.Fatalf("expected IRIS namespace from tree, got %q", runConfig.Database) + } +} + +func TestNormalizeSchemaAndTable_IRISDoesNotTreatNamespaceAsSchema(t *testing.T) { + t.Parallel() + + schema, table := normalizeSchemaAndTable(connection.ConnectionConfig{ + Type: "iris", + }, "USER", "Person") + + if schema != "" || table != "Person" { + t.Fatalf("expected IRIS pure table to omit schema, got %q.%q", schema, table) + } +} + +func TestNormalizeSchemaAndTable_IRISSplitsQualifiedTable(t *testing.T) { + t.Parallel() + + schema, table := normalizeSchemaAndTable(connection.ConnectionConfig{ + Type: "iris", + }, "USER", `"Sample.Schema"."Person.Table"`) + + if schema != "Sample.Schema" || table != "Person.Table" { + t.Fatalf("expected IRIS qualified table split, got %q.%q", schema, table) + } +} + func TestNormalizeSchemaAndTable_OceanBaseOracleUsesSchemaFromDatabaseTree(t *testing.T) { t.Parallel() diff --git a/internal/app/db_proxy.go b/internal/app/db_proxy.go index c214420..a5e847b 100644 --- a/internal/app/db_proxy.go +++ b/internal/app/db_proxy.go @@ -231,6 +231,8 @@ func defaultPortByType(driverType string) int { return 9000 case "highgo": return 5866 + case "iris": + return 1972 default: return 0 } diff --git a/internal/app/methods_db.go b/internal/app/methods_db.go index 7033b5a..4662b8c 100644 --- a/internal/app/methods_db.go +++ b/internal/app/methods_db.go @@ -163,6 +163,9 @@ func resolveDDLDBType(config connection.ConnectionConfig) string { if dbType == "kingbase8" || dbType == "kingbasees" || dbType == "kingbasev8" { return "kingbase" } + if dbType == "intersystems" || dbType == "intersystemsiris" || dbType == "inter-systems" || dbType == "inter-systems-iris" { + return "iris" + } if dbType == "oceanbase" && isOceanBaseOracleProtocol(config) { return "oracle" } @@ -194,6 +197,8 @@ func resolveDDLDBType(config connection.ConnectionConfig) string { return "highgo" case "vastbase": return "vastbase" + case "iris", "intersystems", "intersystemsiris", "inter-systems", "inter-systems-iris": + return "iris" case "oceanbase": return "oceanbase" } @@ -209,6 +214,8 @@ func resolveDDLDBType(config connection.ConnectionConfig) string { return "highgo" case strings.Contains(driver, "vastbase"): return "vastbase" + case strings.Contains(driver, "iris"), strings.Contains(driver, "intersystems"): + return "iris" case strings.Contains(driver, "sqlite"): return "sqlite" case strings.Contains(driver, "sphinx"): @@ -253,6 +260,16 @@ func normalizeSchemaAndTableByType(dbType string, dbName string, tableName strin } } + if dbType == "iris" { + schema, table := db.SplitSQLQualifiedName(rawTable) + if schema != "" && table != "" { + return schema, table + } + if table != "" { + return "", table + } + } + if parts := strings.SplitN(rawTable, ".", 2); len(parts) == 2 { schema := strings.TrimSpace(parts[0]) table := strings.TrimSpace(parts[1]) diff --git a/internal/app/methods_db_create_statement_test.go b/internal/app/methods_db_create_statement_test.go index 323655e..662561f 100644 --- a/internal/app/methods_db_create_statement_test.go +++ b/internal/app/methods_db_create_statement_test.go @@ -72,6 +72,7 @@ func TestResolveDDLDBType_CustomDriverAlias(t *testing.T) { {name: "kingbase contains alias", driver: "kingbasees", want: "kingbase"}, {name: "dm alias", driver: "dm8", want: "dameng"}, {name: "sqlite alias", driver: "sqlite3", want: "sqlite"}, + {name: "iris alias", driver: "InterSystems IRIS", want: "iris"}, } for _, tc := range testCases { @@ -106,6 +107,14 @@ func TestResolveDDLDBType_KingbaseTypeAlias(t *testing.T) { } } +func TestResolveDDLDBType_IRISTypeAlias(t *testing.T) { + t.Parallel() + + if got := resolveDDLDBType(connection.ConnectionConfig{Type: "InterSystemsIRIS"}); got != "iris" { + t.Fatalf("expected InterSystemsIRIS type alias to resolve to iris, got %q", got) + } +} + func TestNormalizeSchemaAndTableByType_PGLikeQuotedQualifiedName(t *testing.T) { t.Parallel() diff --git a/internal/app/methods_driver.go b/internal/app/methods_driver.go index d3df111..6f8e26e 100644 --- a/internal/app/methods_driver.go +++ b/internal/app/methods_driver.go @@ -346,6 +346,7 @@ const builtinDriverManifestJSON = `{ "highgo": { "engine": "go", "version": "0.0.0-local", "checksumPolicy": "off", "downloadUrl": "builtin://activate/highgo" }, "vastbase": { "engine": "go", "version": "1.11.1", "checksumPolicy": "off", "downloadUrl": "builtin://activate/vastbase" }, "opengauss": { "engine": "go", "version": "1.11.1", "checksumPolicy": "off", "downloadUrl": "builtin://activate/opengauss" }, + "iris": { "engine": "go", "version": "0.2.1", "checksumPolicy": "off", "downloadUrl": "builtin://activate/iris" }, "mongodb": { "engine": "go", "version": "2.5.0", "checksumPolicy": "off", "downloadUrl": "builtin://activate/mongodb" }, "tdengine": { "engine": "go", "version": "3.7.8", "checksumPolicy": "off", "downloadUrl": "builtin://activate/tdengine" }, "clickhouse": { "engine": "go", "version": "2.43.1", "checksumPolicy": "off", "downloadUrl": "builtin://activate/clickhouse" } @@ -402,6 +403,7 @@ var latestDriverVersionMap = map[string]string{ "highgo": "0.0.0-local", "vastbase": "1.11.2", "opengauss": "1.11.1", + "iris": "0.2.1", "mongodb": "2.5.0", "tdengine": "3.7.8", "clickhouse": "2.43.1", @@ -424,6 +426,7 @@ var driverGoModulePathMap = map[string]string{ "highgo": "github.com/highgo/pq-sm3", "vastbase": "github.com/lib/pq", "opengauss": "github.com/lib/pq", + "iris": "github.com/caretdev/go-irisnative", "mongodb": "go.mongodb.org/mongo-driver/v2", "tdengine": "github.com/taosdata/driver-go/v3", "clickhouse": "github.com/ClickHouse/clickhouse-go/v2", @@ -1404,6 +1407,8 @@ func normalizeDriverType(driverType string) string { return "postgres" case "opengauss", "open_gauss", "open-gauss": return "opengauss" + case "intersystems", "intersystemsiris", "inter-systems-iris", "inter-systems": + return "iris" default: return normalized } @@ -1485,6 +1490,7 @@ func allDriverDefinitionsWithPackages(packages map[string]pinnedDriverPackage) [ buildOptionalGoDriverDefinition("highgo", "HighGo", packages), buildOptionalGoDriverDefinition("vastbase", "Vastbase", packages), buildOptionalGoDriverDefinition("opengauss", "OpenGauss", packages), + buildOptionalGoDriverDefinition("iris", "InterSystems IRIS", packages), buildOptionalGoDriverDefinition("mongodb", "MongoDB", packages), buildOptionalGoDriverDefinition("tdengine", "TDengine", packages), buildOptionalGoDriverDefinition("clickhouse", "ClickHouse", packages), @@ -3804,6 +3810,8 @@ func optionalDriverBuildTag(driverType string, selectedVersion string) (string, return "gonavi_vastbase_driver", nil case "opengauss": return "gonavi_opengauss_driver", nil + case "iris": + return "gonavi_iris_driver", nil case "mongodb": if resolveMongoDriverMajorFromVersion(selectedVersion) == 1 { return "gonavi_mongodb_driver_v1", nil diff --git a/internal/app/methods_driver_agent_revision_test.go b/internal/app/methods_driver_agent_revision_test.go index e1d1b27..0a09c9c 100644 --- a/internal/app/methods_driver_agent_revision_test.go +++ b/internal/app/methods_driver_agent_revision_test.go @@ -138,6 +138,7 @@ func optionalDriverAgentRevisionTestDrivers(t *testing.T) []string { "highgo", "vastbase", "opengauss", + "iris", "mongodb", "tdengine", "clickhouse", diff --git a/internal/app/methods_driver_version_test.go b/internal/app/methods_driver_version_test.go index 8bc07a9..f2875d0 100644 --- a/internal/app/methods_driver_version_test.go +++ b/internal/app/methods_driver_version_test.go @@ -212,6 +212,36 @@ func TestBuiltinActivatePinnedVersionDoesNotRestrictBundleFallback(t *testing.T) } } +func TestIRISDriverDefinitionUsesOptionalAgent(t *testing.T) { + definition, ok := resolveDriverDefinition("iris") + if !ok { + t.Fatal("expected iris driver definition") + } + if definition.Name != "InterSystems IRIS" { + t.Fatalf("unexpected iris driver name: %q", definition.Name) + } + if driverGoModulePathMap["iris"] != "github.com/caretdev/go-irisnative" { + t.Fatalf("unexpected iris go module path: %q", driverGoModulePathMap["iris"]) + } + if definition.PinnedVersion != "0.2.1" { + t.Fatalf("unexpected iris definition pinned version: %q", definition.PinnedVersion) + } + if definition.DefaultDownloadURL != "builtin://activate/iris" { + t.Fatalf("unexpected iris default download URL: %q", definition.DefaultDownloadURL) + } + if latestDriverVersionMap["iris"] != "0.2.1" { + t.Fatalf("unexpected iris pinned version: %q", latestDriverVersionMap["iris"]) + } + + tags, err := optionalDriverBuildTags("iris", "") + if err != nil { + t.Fatalf("resolve iris build tags failed: %v", err) + } + if tags != "gonavi_iris_driver" { + t.Fatalf("unexpected iris build tag: %q", tags) + } +} + func TestBuildOptionalDriverInstallPlanMessagePrefersDirectThenBundle(t *testing.T) { message := buildOptionalDriverInstallPlanMessage("SQL Server", "1.9.6", false, false, false, false, 1, 2) if !strings.Contains(message, "先尝试 1 个预编译直链") { diff --git a/internal/app/methods_jvm_test.go b/internal/app/methods_jvm_test.go index b1c58d5..a35cb27 100644 --- a/internal/app/methods_jvm_test.go +++ b/internal/app/methods_jvm_test.go @@ -1602,7 +1602,8 @@ func TestJVMApplyChangeFailedAuditFailureMessageIncludesUnderlyingError(t *testi if !strings.Contains(res.Message, "失败审计写入失败") { t.Fatalf("expected failed audit failure marker, got %q", res.Message) } - if !strings.Contains(strings.ToLower(res.Message), "not a directory") { + lowerMessage := strings.ToLower(res.Message) + if !strings.Contains(lowerMessage, "not a directory") && !strings.Contains(lowerMessage, "system cannot find the path specified") { t.Fatalf("expected underlying audit failure detail in message, got %q", res.Message) } } diff --git a/internal/db/custom_impl.go b/internal/db/custom_impl.go index 006d3da..dc48398 100644 --- a/internal/db/custom_impl.go +++ b/internal/db/custom_impl.go @@ -18,19 +18,21 @@ type CustomDB struct { } func (c *CustomDB) Connect(config connection.ConnectionConfig) error { - if config.Driver == "" || config.DSN == "" { + driver := strings.TrimSpace(config.Driver) + dsn := strings.TrimSpace(config.DSN) + if driver == "" || dsn == "" { return fmt.Errorf("driver and dsn are required for custom connection") } // Verify driver is registered (implicit check by sql.Open) // We might not need explicit check, sql.Open will fail or Ping will fail if driver not found. - db, err := sql.Open(config.Driver, config.DSN) + db, err := sql.Open(driver, dsn) if err != nil { - return fmt.Errorf("打开数据库连接失败:%w", err) + return formatCustomDriverOpenError(driver, err) } c.conn = db - c.driver = config.Driver + c.driver = driver c.pingTimeout = getConnectTimeout(config) if err := c.Ping(); err != nil { return fmt.Errorf("连接建立后验证失败:%w", err) @@ -38,6 +40,27 @@ func (c *CustomDB) Connect(config connection.ConnectionConfig) error { return nil } +func formatCustomDriverOpenError(driver string, err error) error { + if err == nil { + return nil + } + if strings.Contains(strings.ToLower(err.Error()), "unknown driver") { + if isLikelySystemODBCDriverName(driver) { + return fmt.Errorf("打开数据库连接失败:自定义连接不支持直接填写系统 ODBC/JDBC 驱动名 %q;请填写 GoNavi 已注册的 Go database/sql 驱动名。当前版本未注册通用 ODBC 驱动,因此暂不支持通过 %q 连接 InterSystems IRIS:%w", driver, driver, err) + } + return fmt.Errorf("打开数据库连接失败:自定义连接驱动 %q 未在 GoNavi 中注册;请填写已注册的 Go database/sql 驱动名,不能填写系统 ODBC/JDBC 驱动名:%w", driver, err) + } + return fmt.Errorf("打开数据库连接失败:%w", err) +} + +func isLikelySystemODBCDriverName(driver string) bool { + normalized := strings.ToLower(strings.TrimSpace(driver)) + return strings.Contains(normalized, "odbc") || + strings.Contains(normalized, "jdbc") || + strings.Contains(normalized, "intersystems") || + strings.Contains(normalized, "iris") +} + func (c *CustomDB) Close() error { if c.conn != nil { return c.conn.Close() diff --git a/internal/db/custom_impl_test.go b/internal/db/custom_impl_test.go new file mode 100644 index 0000000..3634869 --- /dev/null +++ b/internal/db/custom_impl_test.go @@ -0,0 +1,54 @@ +package db + +import ( + "strings" + "testing" + + "GoNavi-Wails/internal/connection" +) + +func TestCustomDBConnectReportsUnsupportedODBCDriverName(t *testing.T) { + db := &CustomDB{} + + err := db.Connect(connection.ConnectionConfig{ + Driver: "InterSystems IRIS ODBC35", + DSN: "Driver={InterSystems IRIS ODBC35};Server=127.0.0.1;Port=1972;Database=USER;", + }) + if err == nil { + t.Fatal("expected unsupported ODBC driver error, got nil") + } + + message := err.Error() + for _, want := range []string{ + "ODBC/JDBC", + "Go database/sql", + "暂不支持", + "InterSystems IRIS", + } { + if !strings.Contains(message, want) { + t.Fatalf("expected error to contain %q, got %q", want, message) + } + } +} + +func TestCustomDBConnectReportsUnregisteredGoDriver(t *testing.T) { + db := &CustomDB{} + + err := db.Connect(connection.ConnectionConfig{ + Driver: "not-a-registered-go-driver", + DSN: "demo", + }) + if err == nil { + t.Fatal("expected unregistered Go driver error, got nil") + } + + message := err.Error() + for _, want := range []string{ + "未在 GoNavi 中注册", + "Go database/sql", + } { + if !strings.Contains(message, want) { + t.Fatalf("expected error to contain %q, got %q", want, message) + } + } +} diff --git a/internal/db/database.go b/internal/db/database.go index 07b2a92..e4494b7 100644 --- a/internal/db/database.go +++ b/internal/db/database.go @@ -130,6 +130,8 @@ func normalizeDatabaseType(dbType string) string { return "kingbase" case "opengauss", "open_gauss", "open-gauss": return "opengauss" + case "intersystems", "intersystemsiris", "inter-systems-iris", "inter-systems": + return "iris" default: return normalized } diff --git a/internal/db/database_optional_factories_full.go b/internal/db/database_optional_factories_full.go index 01e5425..3147992 100644 --- a/internal/db/database_optional_factories_full.go +++ b/internal/db/database_optional_factories_full.go @@ -16,6 +16,7 @@ func registerOptionalDatabaseFactories() { registerDatabaseFactory(newOptionalDriverAgentDatabase("highgo"), "highgo") registerDatabaseFactory(newOptionalDriverAgentDatabase("vastbase"), "vastbase") registerDatabaseFactory(newOptionalDriverAgentDatabase("opengauss"), "opengauss", "open_gauss", "open-gauss") + registerDatabaseFactory(newOptionalDriverAgentDatabase("iris"), "iris", "intersystems") registerDatabaseFactory(newOptionalDriverAgentDatabase("mongodb"), "mongodb") registerDatabaseFactory(newOptionalDriverAgentDatabase("tdengine"), "tdengine") registerDatabaseFactory(newOptionalDriverAgentDatabase("clickhouse"), "clickhouse") diff --git a/internal/db/database_optional_factories_lite.go b/internal/db/database_optional_factories_lite.go index 4f6db32..d726c0f 100644 --- a/internal/db/database_optional_factories_lite.go +++ b/internal/db/database_optional_factories_lite.go @@ -16,6 +16,7 @@ func registerOptionalDatabaseFactories() { registerDatabaseFactory(newOptionalDriverAgentDatabase("highgo"), "highgo") registerDatabaseFactory(newOptionalDriverAgentDatabase("vastbase"), "vastbase") registerDatabaseFactory(newOptionalDriverAgentDatabase("opengauss"), "opengauss", "open_gauss", "open-gauss") + registerDatabaseFactory(newOptionalDriverAgentDatabase("iris"), "iris", "intersystems") registerDatabaseFactory(newOptionalDriverAgentDatabase("mongodb"), "mongodb") registerDatabaseFactory(newOptionalDriverAgentDatabase("tdengine"), "tdengine") registerDatabaseFactory(newOptionalDriverAgentDatabase("clickhouse"), "clickhouse") diff --git a/internal/db/driver_agent_revisions_gen.go b/internal/db/driver_agent_revisions_gen.go index 3c89e7a..65360a8 100644 --- a/internal/db/driver_agent_revisions_gen.go +++ b/internal/db/driver_agent_revisions_gen.go @@ -17,6 +17,7 @@ func init() { "highgo": "src-5a29a1d3685eb6b4", "vastbase": "src-e3cfef65512feb23", "opengauss": "src-58227ba3bc1ec894", + "iris": "src-1b072c57af08bec4", "mongodb": "src-57fdd8bfebdcd46e", "tdengine": "src-939715f94df1ec9c", "clickhouse": "src-482d62ed565b3e69", diff --git a/internal/db/driver_support.go b/internal/db/driver_support.go index 647f4e2..05f942e 100644 --- a/internal/db/driver_support.go +++ b/internal/db/driver_support.go @@ -34,6 +34,7 @@ var optionalGoDrivers = map[string]struct{}{ "highgo": {}, "vastbase": {}, "opengauss": {}, + "iris": {}, "mongodb": {}, "tdengine": {}, "clickhouse": {}, @@ -60,6 +61,8 @@ func normalizeRuntimeDriverType(driverType string) string { return "kingbase" case "opengauss", "open_gauss", "open-gauss": return "opengauss" + case "intersystems", "intersystemsiris", "inter-systems-iris", "inter-systems": + return "iris" default: return normalized } @@ -101,6 +104,8 @@ func driverDisplayName(driverType string) string { return "Vastbase" case "opengauss": return "OpenGauss" + case "iris": + return "InterSystems IRIS" case "mongodb": return "MongoDB" case "tdengine": diff --git a/internal/db/driver_support_test.go b/internal/db/driver_support_test.go index d8ebc14..bcdedb0 100644 --- a/internal/db/driver_support_test.go +++ b/internal/db/driver_support_test.go @@ -113,7 +113,7 @@ func TestNewCompatibleDriversAreOptionalAgentDrivers(t *testing.T) { tmpDir := t.TempDir() SetExternalDriverDownloadDirectory(tmpDir) - for _, driverType := range []string{"oceanbase", "opengauss", "open_gauss", "starrocks"} { + for _, driverType := range []string{"oceanbase", "opengauss", "open_gauss", "starrocks", "iris", "intersystems"} { if IsBuiltinDriver(driverType) { t.Fatalf("%s 不应是免安装内置驱动", driverType) } diff --git a/internal/db/iris_impl.go b/internal/db/iris_impl.go new file mode 100644 index 0000000..74746f8 --- /dev/null +++ b/internal/db/iris_impl.go @@ -0,0 +1,960 @@ +//go:build gonavi_full_drivers || gonavi_iris_driver + +package db + +import ( + "context" + "database/sql" + "fmt" + "net" + "net/url" + "sort" + "strconv" + "strings" + "time" + + "GoNavi-Wails/internal/connection" + "GoNavi-Wails/internal/logger" + "GoNavi-Wails/internal/ssh" + "GoNavi-Wails/internal/utils" + + _ "github.com/caretdev/go-irisnative" +) + +const ( + defaultIRISPort = 1972 + defaultIRISNamespace = "USER" +) + +type IrisDB struct { + conn *sql.DB + pingTimeout time.Duration + namespace string + forwarder *ssh.LocalForwarder +} + +type irisTableRef struct { + Schema string + Table string +} + +func normalizeIRISNamespace(namespace string) string { + trimmed := strings.Trim(strings.TrimSpace(namespace), "/") + if trimmed == "" { + return defaultIRISNamespace + } + return trimmed +} + +func applyIRISURI(config connection.ConnectionConfig) connection.ConnectionConfig { + parsed, ok := parseConnectionURI(config.URI, "iris", "intersystems") + if !ok || parsed == nil { + return config + } + next := config + if host := strings.TrimSpace(parsed.Hostname()); host != "" { + next.Host = host + } + if portText := strings.TrimSpace(parsed.Port()); portText != "" { + if port, err := strconv.Atoi(portText); err == nil && port > 0 { + next.Port = port + } + } + if parsed.User != nil { + next.User = parsed.User.Username() + if password, ok := parsed.User.Password(); ok { + next.Password = password + } + } + if namespace := strings.Trim(strings.TrimSpace(parsed.Path), "/"); namespace != "" { + next.Database = namespace + } + return next +} + +func (i *IrisDB) getDSN(config connection.ConnectionConfig) string { + namespace := normalizeIRISNamespace(config.Database) + port := config.Port + if port <= 0 { + port = defaultIRISPort + } + + u := &url.URL{ + Scheme: "iris", + Host: net.JoinHostPort(config.Host, strconv.Itoa(port)), + Path: "/" + namespace, + } + u.User = url.UserPassword(config.User, config.Password) + + q := url.Values{} + mergeConnectionParamsFromConfig(q, config, "iris", "intersystems") + u.RawQuery = q.Encode() + return u.String() +} + +func (i *IrisDB) Connect(config connection.ConnectionConfig) error { + runConfig := applyIRISURI(config) + if runConfig.Port <= 0 { + runConfig.Port = defaultIRISPort + } + i.namespace = normalizeIRISNamespace(runConfig.Database) + + cleanupOnFailure := true + defer func() { + if !cleanupOnFailure { + return + } + if i.conn != nil { + _ = i.conn.Close() + i.conn = nil + } + if i.forwarder != nil { + _ = i.forwarder.Close() + i.forwarder = nil + } + }() + + if runConfig.UseSSH { + logger.Infof("InterSystems IRIS 使用 SSH 连接:地址=%s:%d 用户=%s", runConfig.Host, runConfig.Port, runConfig.User) + forwarder, err := ssh.GetOrCreateLocalForwarder(runConfig.SSH, runConfig.Host, runConfig.Port) + if err != nil { + return fmt.Errorf("创建 SSH 隧道失败:%w", err) + } + i.forwarder = forwarder + + host, portStr, err := net.SplitHostPort(forwarder.LocalAddr) + if err != nil { + return fmt.Errorf("解析本地转发地址失败:%w", err) + } + port, err := strconv.Atoi(portStr) + if err != nil { + return fmt.Errorf("解析本地端口失败:%w", err) + } + + runConfig.Host = host + runConfig.Port = port + runConfig.UseSSH = false + logger.Infof("InterSystems IRIS 通过本地端口转发连接:%s -> %s:%d", forwarder.LocalAddr, config.Host, config.Port) + } + + db, err := sql.Open("iris", i.getDSN(runConfig)) + if err != nil { + return fmt.Errorf("打开数据库连接失败:%w", err) + } + i.conn = db + i.pingTimeout = getConnectTimeout(runConfig) + if err := i.Ping(); err != nil { + return fmt.Errorf("连接建立后验证失败:%w", err) + } + cleanupOnFailure = false + return nil +} + +func (i *IrisDB) Close() error { + if i.forwarder != nil { + if err := i.forwarder.Close(); err != nil { + logger.Warnf("关闭 InterSystems IRIS SSH 端口转发失败:%v", err) + } + i.forwarder = nil + } + if i.conn != nil { + return i.conn.Close() + } + return nil +} + +func (i *IrisDB) Ping() error { + if i.conn == nil { + return fmt.Errorf("连接未打开") + } + timeout := i.pingTimeout + if timeout <= 0 { + timeout = 5 * time.Second + } + ctx, cancel := utils.ContextWithTimeout(timeout) + defer cancel() + return i.conn.PingContext(ctx) +} + +func (i *IrisDB) QueryMulti(query string) ([]connection.ResultSetData, error) { + if i.conn == nil { + return nil, fmt.Errorf("连接未打开") + } + rows, err := i.conn.Query(query) + if err != nil { + return nil, err + } + defer rows.Close() + return scanMultiRows(rows) +} + +func (i *IrisDB) QueryMultiContext(ctx context.Context, query string) ([]connection.ResultSetData, error) { + if i.conn == nil { + return nil, fmt.Errorf("连接未打开") + } + rows, err := i.conn.QueryContext(ctx, query) + if err != nil { + return nil, err + } + defer rows.Close() + return scanMultiRows(rows) +} + +func (i *IrisDB) QueryContext(ctx context.Context, query string) ([]map[string]interface{}, []string, error) { + if i.conn == nil { + return nil, nil, fmt.Errorf("连接未打开") + } + rows, err := i.conn.QueryContext(ctx, query) + if err != nil { + return nil, nil, err + } + defer rows.Close() + return scanRows(rows) +} + +func (i *IrisDB) Query(query string) ([]map[string]interface{}, []string, error) { + if i.conn == nil { + return nil, nil, fmt.Errorf("连接未打开") + } + rows, err := i.conn.Query(query) + if err != nil { + return nil, nil, err + } + defer rows.Close() + return scanRows(rows) +} + +func (i *IrisDB) ExecContext(ctx context.Context, query string) (int64, error) { + if i.conn == nil { + return 0, fmt.Errorf("连接未打开") + } + res, err := i.conn.ExecContext(ctx, query) + if err != nil { + return 0, err + } + return res.RowsAffected() +} + +func (i *IrisDB) ExecBatchContext(ctx context.Context, query string) (int64, error) { + return i.ExecContext(ctx, query) +} + +func (i *IrisDB) Exec(query string) (int64, error) { + if i.conn == nil { + return 0, fmt.Errorf("连接未打开") + } + res, err := i.conn.Exec(query) + if err != nil { + return 0, err + } + return res.RowsAffected() +} + +func (i *IrisDB) GetDatabases() ([]string, error) { + namespace := strings.TrimSpace(i.namespace) + if namespace != "" { + return []string{namespace}, nil + } + data, _, err := i.Query(`SELECT DISTINCT TABLE_CATALOG FROM INFORMATION_SCHEMA.TABLES`) + if err != nil { + return nil, err + } + var namespaces []string + seen := map[string]struct{}{} + for _, row := range data { + name := strings.TrimSpace(rowString(row, "TABLE_CATALOG", "table_catalog")) + if name == "" { + continue + } + if _, ok := seen[name]; ok { + continue + } + seen[name] = struct{}{} + namespaces = append(namespaces, name) + } + sort.Strings(namespaces) + return namespaces, nil +} + +func (i *IrisDB) GetTables(dbName string) ([]string, error) { + data, _, err := i.Query(`SELECT * FROM INFORMATION_SCHEMA.TABLES`) + if err != nil { + return nil, err + } + var tables []string + seen := map[string]struct{}{} + for _, row := range data { + tableType := strings.ToUpper(strings.TrimSpace(rowString(row, "TABLE_TYPE", "table_type"))) + if tableType != "" && tableType != "TABLE" && tableType != "BASE TABLE" { + continue + } + schema := strings.TrimSpace(rowString(row, "TABLE_SCHEMA", "table_schema", "SCHEMA_NAME", "schema_name")) + table := strings.TrimSpace(rowString(row, "TABLE_NAME", "table_name")) + if table == "" || isIRISSystemSchema(schema) { + continue + } + name := table + if schema != "" { + name = schema + "." + table + } + if _, ok := seen[name]; ok { + continue + } + seen[name] = struct{}{} + tables = append(tables, name) + } + sort.Strings(tables) + return tables, nil +} + +func (i *IrisDB) GetCreateStatement(dbName, tableName string) (string, error) { + ref, err := parseIRISTableRef(dbName, tableName) + if err != nil { + return "", err + } + columns, err := i.GetColumns(dbName, tableName) + if err != nil { + return "", err + } + if len(columns) == 0 { + return "", fmt.Errorf("未找到表字段:%s", tableName) + } + indexes, idxErr := i.GetIndexes(dbName, tableName) + if idxErr != nil { + indexes = nil + } + return buildIRISCreateTableDDL(ref, columns, indexes), nil +} + +func (i *IrisDB) GetColumns(dbName, tableName string) ([]connection.ColumnDefinition, error) { + ref, err := parseIRISTableRef(dbName, tableName) + if err != nil { + return nil, err + } + data, _, err := i.Query(buildIRISInfoSchemaWhereQuery("INFORMATION_SCHEMA.COLUMNS", ref)) + if err != nil { + return nil, err + } + indexes, _ := i.GetIndexes(dbName, tableName) + keyByColumn := irisColumnKeyMap(indexes) + + columns := make([]connection.ColumnDefinition, 0, len(data)) + for _, row := range data { + name := strings.TrimSpace(rowString(row, "COLUMN_NAME", "column_name")) + if name == "" { + continue + } + key := keyByColumn[name] + if primary, ok := irisBoolFromRow(row, "PRIMARY_KEY", "primary_key"); ok && primary { + key = "PRI" + } else if key == "" { + if unique, ok := irisBoolFromRow(row, "UNIQUE_COLUMN", "unique_column", "IS_UNIQUE", "is_unique", "UNIQUE", "unique"); ok && unique { + key = "UNI" + } + } + col := connection.ColumnDefinition{ + Name: name, + Type: buildIRISColumnType(row), + Nullable: normalizeIRISNullable(rowString(row, "IS_NULLABLE", "is_nullable")), + Key: key, + Extra: "", + Comment: rowString(row, "DESCRIPTION", "description", "COMMENT", "comment"), + } + if rawDefault, ok := rowValue(row, "COLUMN_DEFAULT", "column_default"); ok && rawDefault != nil { + def := strings.TrimSpace(fmt.Sprintf("%v", rawDefault)) + if def != "" { + col.Default = &def + } + } + columns = append(columns, col) + } + sort.SliceStable(columns, func(a, b int) bool { + return rowOrdinal(data, columns[a].Name) < rowOrdinal(data, columns[b].Name) + }) + return columns, nil +} + +func (i *IrisDB) GetAllColumns(dbName string) ([]connection.ColumnDefinitionWithTable, error) { + data, _, err := i.Query(`SELECT * FROM INFORMATION_SCHEMA.COLUMNS`) + if err != nil { + return nil, err + } + cols := make([]connection.ColumnDefinitionWithTable, 0, len(data)) + for _, row := range data { + schema := strings.TrimSpace(rowString(row, "TABLE_SCHEMA", "table_schema")) + table := strings.TrimSpace(rowString(row, "TABLE_NAME", "table_name")) + name := strings.TrimSpace(rowString(row, "COLUMN_NAME", "column_name")) + if table == "" || name == "" || isIRISSystemSchema(schema) { + continue + } + tableName := table + if schema != "" { + tableName = schema + "." + table + } + cols = append(cols, connection.ColumnDefinitionWithTable{ + TableName: tableName, + Name: name, + Type: buildIRISColumnType(row), + }) + } + sort.SliceStable(cols, func(a, b int) bool { + if cols[a].TableName == cols[b].TableName { + return cols[a].Name < cols[b].Name + } + return cols[a].TableName < cols[b].TableName + }) + return cols, nil +} + +func (i *IrisDB) GetIndexes(dbName, tableName string) ([]connection.IndexDefinition, error) { + ref, err := parseIRISTableRef(dbName, tableName) + if err != nil { + return nil, err + } + data, _, err := i.Query(buildIRISInfoSchemaWhereQuery("INFORMATION_SCHEMA.INDEXES", ref)) + if err != nil { + return nil, err + } + indexes := make([]connection.IndexDefinition, 0, len(data)) + for _, row := range data { + name := strings.TrimSpace(rowString(row, "INDEX_NAME", "index_name", "KEY_NAME", "key_name", "CONSTRAINT_NAME", "constraint_name")) + column := strings.TrimSpace(rowString(row, "COLUMN_NAME", "column_name")) + primary, hasPrimaryFlag := irisBoolFromRow(row, "PRIMARY_KEY", "primary_key") + if name == "" && hasPrimaryFlag && primary { + name = "PRIMARY" + } + if name == "" || column == "" { + continue + } + indexType := normalizeIRISIndexType(rowString(row, "INDEX_TYPE", "index_type", "TYPE", "type")) + if hasPrimaryFlag && primary { + indexType = "PRIMARY" + } + nonUnique := parseIRISNonUnique(row) + indexes = append(indexes, connection.IndexDefinition{ + Name: name, + ColumnName: column, + NonUnique: nonUnique, + SeqInIndex: parseIRISInt(rowValueAny(row, "ORDINAL_POSITION", "ordinal_position", "SEQ_IN_INDEX", "seq_in_index", "KEY_SEQ", "key_seq")), + IndexType: indexType, + }) + } + sort.SliceStable(indexes, func(a, b int) bool { + if indexes[a].Name == indexes[b].Name { + return indexes[a].SeqInIndex < indexes[b].SeqInIndex + } + return indexes[a].Name < indexes[b].Name + }) + return indexes, nil +} + +func (i *IrisDB) GetForeignKeys(dbName, tableName string) ([]connection.ForeignKeyDefinition, error) { + return []connection.ForeignKeyDefinition{}, nil +} + +func (i *IrisDB) GetTriggers(dbName, tableName string) ([]connection.TriggerDefinition, error) { + return []connection.TriggerDefinition{}, nil +} + +func (i *IrisDB) ApplyChanges(tableName string, changes connection.ChangeSet) error { + if i.conn == nil { + return fmt.Errorf("连接未打开") + } + tx, err := i.conn.Begin() + if err != nil { + return err + } + defer tx.Rollback() + + for _, keys := range changes.Deletes { + query, args, ok := buildIRISDeleteSQL(tableName, keys) + if !ok { + continue + } + res, err := tx.Exec(query, args...) + if err != nil { + return fmt.Errorf("删除失败:%v", err) + } + if err := requireSingleRowAffected(res, "删除"); err != nil { + return err + } + } + + for _, update := range changes.Updates { + query, args, ok, err := buildIRISUpdateSQL(tableName, update) + if err != nil { + return err + } + if !ok { + continue + } + res, err := tx.Exec(query, args...) + if err != nil { + return fmt.Errorf("更新失败:%v", err) + } + if err := requireSingleRowAffected(res, "更新"); err != nil { + return err + } + } + + for _, row := range changes.Inserts { + query, args, ok := buildIRISInsertSQL(tableName, row) + if !ok { + continue + } + res, err := tx.Exec(query, args...) + if err != nil { + return fmt.Errorf("插入失败:%v", err) + } + if affected, err := res.RowsAffected(); err == nil && affected == 0 { + return fmt.Errorf("插入未生效:未影响任何行") + } + } + + return tx.Commit() +} + +func buildIRISInfoSchemaWhereQuery(table string, ref irisTableRef) string { + conditions := []string{fmt.Sprintf("TABLE_NAME = '%s'", irisSQLLiteral(ref.Table))} + if ref.Schema != "" { + conditions = append(conditions, fmt.Sprintf("TABLE_SCHEMA = '%s'", irisSQLLiteral(ref.Schema))) + } + orderBy := "" + switch strings.ToUpper(strings.TrimSpace(table)) { + case "INFORMATION_SCHEMA.COLUMNS": + orderBy = " ORDER BY ORDINAL_POSITION" + case "INFORMATION_SCHEMA.INDEXES": + orderBy = " ORDER BY INDEX_NAME, ORDINAL_POSITION" + } + return fmt.Sprintf("SELECT * FROM %s WHERE %s%s", table, strings.Join(conditions, " AND "), orderBy) +} + +func parseIRISTableRef(defaultSchema, raw string) (irisTableRef, error) { + text := strings.TrimSpace(raw) + if text == "" { + return irisTableRef{}, fmt.Errorf("表名不能为空") + } + if schemaPart, tablePart, ok := splitIRISTablePath(text); ok { + schema := cleanIRISIdentifier(schemaPart) + table := cleanIRISIdentifier(tablePart) + if table == "" { + return irisTableRef{}, fmt.Errorf("表名不能为空") + } + return irisTableRef{Schema: schema, Table: table}, nil + } + return irisTableRef{Schema: cleanIRISIdentifier(defaultSchema), Table: cleanIRISIdentifier(text)}, nil +} + +func splitIRISTablePath(raw string) (schemaPart, tablePart string, ok bool) { + inQuote := false + for idx := 0; idx < len(raw); idx++ { + switch raw[idx] { + case '"': + if inQuote && idx+1 < len(raw) && raw[idx+1] == '"' { + idx++ + continue + } + inQuote = !inQuote + case '.': + if !inQuote { + return raw[:idx], raw[idx+1:], true + } + } + } + return "", raw, false +} + +func cleanIRISIdentifier(raw string) string { + text := strings.TrimSpace(raw) + text = strings.Trim(text, `"`) + return strings.ReplaceAll(text, `""`, `"`) +} + +func irisSQLLiteral(raw string) string { + return strings.ReplaceAll(raw, "'", "''") +} + +func irisQuoteIdent(name string) string { + text := cleanIRISIdentifier(name) + text = strings.ReplaceAll(text, `"`, `""`) + return `"` + text + `"` +} + +func irisQuoteTable(raw string) string { + ref, err := parseIRISTableRef("", raw) + if err != nil { + return irisQuoteIdent(raw) + } + if ref.Schema != "" { + return irisQuoteIdent(ref.Schema) + "." + irisQuoteIdent(ref.Table) + } + return irisQuoteIdent(ref.Table) +} + +func isIRISSystemSchema(schema string) bool { + normalized := strings.ToUpper(strings.TrimSpace(schema)) + return normalized == "INFORMATION_SCHEMA" || + strings.HasPrefix(normalized, "%") || + strings.HasPrefix(normalized, "SYS") +} + +func rowValue(row map[string]interface{}, keys ...string) (interface{}, bool) { + for _, key := range keys { + if value, ok := row[key]; ok { + return value, true + } + for existing, value := range row { + if strings.EqualFold(existing, key) { + return value, true + } + } + } + return nil, false +} + +func rowValueAny(row map[string]interface{}, keys ...string) interface{} { + value, _ := rowValue(row, keys...) + return value +} + +func rowString(row map[string]interface{}, keys ...string) string { + value, ok := rowValue(row, keys...) + if !ok || value == nil { + return "" + } + return fmt.Sprintf("%v", value) +} + +func parseIRISInt(value interface{}) int { + switch v := value.(type) { + case int: + return v + case int32: + return int(v) + case int64: + return int(v) + case float64: + return int(v) + case string: + n, _ := strconv.Atoi(strings.TrimSpace(v)) + return n + default: + n, _ := strconv.Atoi(strings.TrimSpace(fmt.Sprintf("%v", value))) + return n + } +} + +func parseIRISBool(value interface{}) (bool, bool) { + switch v := value.(type) { + case bool: + return v, true + case int: + return v != 0, true + case int64: + return v != 0, true + case float64: + return v != 0, true + case string: + switch strings.ToLower(strings.TrimSpace(v)) { + case "1", "true", "t", "yes", "y": + return true, true + case "0", "false", "f", "no", "n": + return false, true + } + } + return false, false +} + +func irisBoolFromRow(row map[string]interface{}, keys ...string) (bool, bool) { + value, ok := rowValue(row, keys...) + if !ok { + return false, false + } + return parseIRISBool(value) +} + +func parseIRISNonUnique(row map[string]interface{}) int { + if primary, ok := irisBoolFromRow(row, "PRIMARY_KEY", "primary_key"); ok && primary { + return 0 + } + if value, ok := rowValue(row, "NON_UNIQUE", "non_unique"); ok { + if enabled, ok := parseIRISBool(value); ok { + if enabled { + return 1 + } + return 0 + } + n := parseIRISInt(value) + if n != 0 { + return 1 + } + return 0 + } + if value, ok := rowValue(row, "IS_UNIQUE", "is_unique", "UNIQUE", "unique"); ok { + if unique, ok := parseIRISBool(value); ok && unique { + return 0 + } + } + if unique, ok := irisBoolFromRow(row, "UNIQUE_COLUMN", "unique_column"); ok && unique { + return 0 + } + return 1 +} + +func normalizeIRISIndexType(raw string) string { + text := strings.ToUpper(strings.TrimSpace(raw)) + if text == "" { + return "BTREE" + } + return text +} + +func normalizeIRISNullable(raw string) string { + switch strings.ToUpper(strings.TrimSpace(raw)) { + case "NO", "N", "FALSE", "0": + return "NO" + default: + return "YES" + } +} + +func buildIRISColumnType(row map[string]interface{}) string { + dataType := strings.TrimSpace(rowString(row, "DATA_TYPE", "data_type", "TYPE_NAME", "type_name")) + if dataType == "" { + dataType = "VARCHAR" + } + upper := strings.ToUpper(dataType) + charLength := parseIRISInt(rowValueAny(row, "CHARACTER_MAXIMUM_LENGTH", "character_maximum_length", "CHARACTER_MAX_LENGTH", "character_max_length")) + precision := parseIRISInt(rowValueAny(row, "NUMERIC_PRECISION", "numeric_precision")) + scale := parseIRISInt(rowValueAny(row, "NUMERIC_SCALE", "numeric_scale")) + if charLength > 0 && (strings.Contains(upper, "CHAR") || strings.Contains(upper, "VARCHAR")) && !strings.Contains(dataType, "(") { + return fmt.Sprintf("%s(%d)", dataType, charLength) + } + if precision > 0 && (strings.Contains(upper, "NUMERIC") || strings.Contains(upper, "DECIMAL") || strings.Contains(upper, "NUMBER")) && !strings.Contains(dataType, "(") { + if scale > 0 { + return fmt.Sprintf("%s(%d,%d)", dataType, precision, scale) + } + return fmt.Sprintf("%s(%d)", dataType, precision) + } + return dataType +} + +func rowOrdinal(rows []map[string]interface{}, columnName string) int { + for idx, row := range rows { + if strings.EqualFold(rowString(row, "COLUMN_NAME", "column_name"), columnName) { + ordinal := parseIRISInt(rowValueAny(row, "ORDINAL_POSITION", "ordinal_position")) + if ordinal > 0 { + return ordinal + } + return idx + 1 + } + } + return len(rows) + 1 +} + +func irisColumnKeyMap(indexes []connection.IndexDefinition) map[string]string { + result := map[string]string{} + for _, idx := range indexes { + column := strings.TrimSpace(idx.ColumnName) + if column == "" { + continue + } + if isIRISPrimaryIndex(idx) { + result[column] = "PRI" + continue + } + if idx.NonUnique == 0 && result[column] == "" { + result[column] = "UNI" + } + } + return result +} + +func isIRISPrimaryIndexName(name string) bool { + normalized := strings.ToUpper(strings.TrimSpace(name)) + return normalized == "PRIMARY" || normalized == "PRIMARYKEY" || normalized == "IDKEY" +} + +func isIRISPrimaryIndex(idx connection.IndexDefinition) bool { + return isIRISPrimaryIndexName(idx.Name) || strings.EqualFold(strings.TrimSpace(idx.IndexType), "PRIMARY") +} + +func buildIRISCreateTableDDL(ref irisTableRef, columns []connection.ColumnDefinition, indexes []connection.IndexDefinition) string { + qualified := irisQuoteIdent(ref.Table) + if strings.TrimSpace(ref.Schema) != "" { + qualified = irisQuoteIdent(ref.Schema) + "." + qualified + } + + lines := make([]string, 0, len(columns)+1) + primaryColumns := irisPrimaryColumns(indexes) + if len(primaryColumns) == 0 { + primaryColumns = irisPrimaryColumnsFromColumns(columns) + } + for _, col := range columns { + line := fmt.Sprintf(" %s %s", irisQuoteIdent(col.Name), strings.TrimSpace(col.Type)) + if col.Default != nil && strings.TrimSpace(*col.Default) != "" { + line += " DEFAULT " + strings.TrimSpace(*col.Default) + } + if strings.EqualFold(strings.TrimSpace(col.Nullable), "NO") { + line += " NOT NULL" + } + lines = append(lines, line) + } + if len(primaryColumns) > 0 { + lines = append(lines, fmt.Sprintf(" PRIMARY KEY (%s)", irisQuoteIdentList(primaryColumns))) + } + + var b strings.Builder + b.WriteString(fmt.Sprintf("CREATE TABLE %s (\n%s\n);", qualified, strings.Join(lines, ",\n"))) + + for _, stmt := range buildIRISCreateIndexStatements(ref, indexes) { + b.WriteString("\n\n") + b.WriteString(stmt) + } + return b.String() +} + +func irisPrimaryColumns(indexes []connection.IndexDefinition) []string { + for _, group := range groupIRISIndexes(indexes) { + if group.Primary { + return group.Columns + } + } + return nil +} + +func irisPrimaryColumnsFromColumns(columns []connection.ColumnDefinition) []string { + primaryColumns := make([]string, 0) + for _, column := range columns { + if strings.EqualFold(strings.TrimSpace(column.Key), "PRI") && strings.TrimSpace(column.Name) != "" { + primaryColumns = append(primaryColumns, column.Name) + } + } + return primaryColumns +} + +type irisIndexGroup struct { + Name string + Columns []string + NonUnique int + IndexType string + Primary bool +} + +func groupIRISIndexes(indexes []connection.IndexDefinition) []irisIndexGroup { + groupsByName := map[string]*irisIndexGroup{} + order := make([]string, 0) + for _, idx := range indexes { + name := strings.TrimSpace(idx.Name) + column := strings.TrimSpace(idx.ColumnName) + if name == "" || column == "" { + continue + } + group, ok := groupsByName[name] + if !ok { + group = &irisIndexGroup{Name: name, NonUnique: idx.NonUnique, IndexType: idx.IndexType} + groupsByName[name] = group + order = append(order, name) + } + group.Columns = append(group.Columns, column) + if idx.NonUnique == 0 { + group.NonUnique = 0 + } + if isIRISPrimaryIndex(idx) { + group.Primary = true + } + } + sort.Strings(order) + groups := make([]irisIndexGroup, 0, len(order)) + for _, name := range order { + group := groupsByName[name] + groups = append(groups, *group) + } + return groups +} + +func buildIRISCreateIndexStatements(ref irisTableRef, indexes []connection.IndexDefinition) []string { + qualified := irisQuoteIdent(ref.Table) + if strings.TrimSpace(ref.Schema) != "" { + qualified = irisQuoteIdent(ref.Schema) + "." + qualified + } + var statements []string + for _, group := range groupIRISIndexes(indexes) { + if len(group.Columns) == 0 || group.Primary { + continue + } + unique := "" + if group.NonUnique == 0 { + unique = "UNIQUE " + } + statements = append(statements, fmt.Sprintf("CREATE %sINDEX %s ON %s (%s);", unique, irisQuoteIdent(group.Name), qualified, irisQuoteIdentList(group.Columns))) + } + return statements +} + +func irisQuoteIdentList(columns []string) string { + quoted := make([]string, 0, len(columns)) + for _, column := range columns { + quoted = append(quoted, irisQuoteIdent(column)) + } + return strings.Join(quoted, ", ") +} + +func buildIRISDeleteSQL(tableName string, keys map[string]interface{}) (string, []interface{}, bool) { + wheres, args := irisAssignments(keys, " = ?") + if len(wheres) == 0 { + return "", nil, false + } + return fmt.Sprintf("DELETE FROM %s WHERE %s", irisQuoteTable(tableName), strings.Join(wheres, " AND ")), args, true +} + +func buildIRISUpdateSQL(tableName string, update connection.UpdateRow) (string, []interface{}, bool, error) { + sets, args := irisAssignments(update.Values, " = ?") + if len(sets) == 0 { + return "", nil, false, nil + } + wheres, whereArgs := irisAssignments(update.Keys, " = ?") + if len(wheres) == 0 { + return "", nil, false, fmt.Errorf("更新操作需要主键条件") + } + args = append(args, whereArgs...) + return fmt.Sprintf("UPDATE %s SET %s WHERE %s", irisQuoteTable(tableName), strings.Join(sets, ", "), strings.Join(wheres, " AND ")), args, true, nil +} + +func buildIRISInsertSQL(tableName string, row map[string]interface{}) (string, []interface{}, bool) { + if len(row) == 0 { + return "", nil, false + } + keys := sortedMapKeys(row) + cols := make([]string, 0, len(keys)) + placeholders := make([]string, 0, len(keys)) + args := make([]interface{}, 0, len(keys)) + for _, key := range keys { + cols = append(cols, irisQuoteIdent(key)) + placeholders = append(placeholders, "?") + args = append(args, row[key]) + } + return fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", irisQuoteTable(tableName), strings.Join(cols, ", "), strings.Join(placeholders, ", ")), args, true +} + +func irisAssignments(values map[string]interface{}, suffix string) ([]string, []interface{}) { + keys := sortedMapKeys(values) + parts := make([]string, 0, len(keys)) + args := make([]interface{}, 0, len(keys)) + for _, key := range keys { + parts = append(parts, irisQuoteIdent(key)+suffix) + args = append(args, values[key]) + } + return parts, args +} + +func sortedMapKeys(values map[string]interface{}) []string { + keys := make([]string, 0, len(values)) + for key := range values { + if strings.TrimSpace(key) != "" { + keys = append(keys, key) + } + } + sort.Strings(keys) + return keys +} diff --git a/internal/db/iris_impl_test.go b/internal/db/iris_impl_test.go new file mode 100644 index 0000000..1e7e9d1 --- /dev/null +++ b/internal/db/iris_impl_test.go @@ -0,0 +1,275 @@ +//go:build gonavi_full_drivers || gonavi_iris_driver + +package db + +import ( + "database/sql/driver" + "net/url" + "reflect" + "strings" + "testing" + + "GoNavi-Wails/internal/connection" +) + +func TestIrisDSNUsesNamespaceDefaultPortAndConnectionParams(t *testing.T) { + iris := &IrisDB{} + + dsn := iris.getDSN(connection.ConnectionConfig{ + Host: "db.example.com", + User: "_SYSTEM", + Password: "p@ss", + ConnectionParams: "timeout=30&ssl=1", + }) + + parsed, err := url.Parse(dsn) + if err != nil { + t.Fatalf("parse dsn: %v", err) + } + if parsed.Scheme != "iris" { + t.Fatalf("scheme = %q", parsed.Scheme) + } + if parsed.Host != "db.example.com:1972" { + t.Fatalf("host = %q", parsed.Host) + } + if parsed.Path != "/USER" { + t.Fatalf("namespace path = %q", parsed.Path) + } + if parsed.User.Username() != "_SYSTEM" { + t.Fatalf("user = %q", parsed.User.Username()) + } + password, _ := parsed.User.Password() + if password != "p@ss" { + t.Fatalf("password = %q", password) + } + if got := parsed.Query().Get("timeout"); got != "30" { + t.Fatalf("timeout param = %q", got) + } + if got := parsed.Query().Get("ssl"); got != "1" { + t.Fatalf("ssl param = %q", got) + } +} + +func TestApplyIRISURIExtractsConnectionFields(t *testing.T) { + config := applyIRISURI(connection.ConnectionConfig{ + URI: "iris://user:secret@iris.local:1973/APP?timeout=30", + Database: "SHOULD_BE_REPLACED", + }) + + if config.Host != "iris.local" || config.Port != 1973 || config.User != "user" || config.Password != "secret" { + t.Fatalf("unexpected parsed config: %#v", config) + } + if config.Database != "APP" { + t.Fatalf("database namespace = %q", config.Database) + } +} + +func TestIRISTableRefAndIdentifierQuoting(t *testing.T) { + ref, err := parseIRISTableRef("Sample", `"Person.Table"`) + if err != nil { + t.Fatalf("parse table ref: %v", err) + } + if ref.Schema != "Sample" || ref.Table != "Person.Table" { + t.Fatalf("unexpected ref: %#v", ref) + } + + ref, err = parseIRISTableRef("", `"Sample"."Person""Archive"`) + if err != nil { + t.Fatalf("parse qualified table ref: %v", err) + } + if ref.Schema != "Sample" || ref.Table != `Person"Archive` { + t.Fatalf("unexpected qualified ref: %#v", ref) + } + if got := irisQuoteTable(`"Sample"."Person""Archive"`); got != `"Sample"."Person""Archive"` { + t.Fatalf("quoted table = %s", got) + } +} + +func TestIRISColumnKeyMapPrefersPrimaryThenUnique(t *testing.T) { + keys := irisColumnKeyMap([]connection.IndexDefinition{ + {Name: "idx_id", ColumnName: "id", NonUnique: 0}, + {Name: "IDKEY", ColumnName: "id", NonUnique: 0}, + {Name: "idx_email", ColumnName: "email", NonUnique: 0}, + {Name: "idx_name", ColumnName: "name", NonUnique: 1}, + }) + + if keys["id"] != "PRI" { + t.Fatalf("id key = %q", keys["id"]) + } + if keys["email"] != "UNI" { + t.Fatalf("email key = %q", keys["email"]) + } + if keys["name"] != "" { + t.Fatalf("name key = %q", keys["name"]) + } +} + +func TestBuildIRISCreateTableDDLIncludesPrimaryAndIndexes(t *testing.T) { + defaultValue := "CURRENT_TIMESTAMP" + ddl := buildIRISCreateTableDDL( + irisTableRef{Schema: "Sample", Table: "Person"}, + []connection.ColumnDefinition{ + {Name: "id", Type: "INTEGER", Nullable: "NO"}, + {Name: "name", Type: "VARCHAR(80)", Nullable: "NO"}, + {Name: "created_at", Type: "TIMESTAMP", Nullable: "YES", Default: &defaultValue}, + }, + []connection.IndexDefinition{ + {Name: "app_person_pk", ColumnName: "id", NonUnique: 0, SeqInIndex: 1, IndexType: "PRIMARY"}, + {Name: "idx_person_name", ColumnName: "name", NonUnique: 0, SeqInIndex: 1}, + {Name: "idx_person_created_at", ColumnName: "created_at", NonUnique: 1, SeqInIndex: 1}, + }, + ) + + for _, want := range []string{ + `CREATE TABLE "Sample"."Person"`, + `"id" INTEGER NOT NULL`, + `"name" VARCHAR(80) NOT NULL`, + `"created_at" TIMESTAMP DEFAULT CURRENT_TIMESTAMP`, + `PRIMARY KEY ("id")`, + `CREATE UNIQUE INDEX "idx_person_name" ON "Sample"."Person" ("name");`, + `CREATE INDEX "idx_person_created_at" ON "Sample"."Person" ("created_at");`, + } { + if !strings.Contains(ddl, want) { + t.Fatalf("ddl missing %q:\n%s", want, ddl) + } + } + if strings.Contains(ddl, `CREATE UNIQUE INDEX "app_person_pk"`) { + t.Fatalf("primary key index should not be emitted as a standalone index:\n%s", ddl) + } +} + +func TestBuildIRISCreateTableDDLFallsBackToColumnPrimaryKey(t *testing.T) { + ddl := buildIRISCreateTableDDL( + irisTableRef{Schema: "Sample", Table: "Person"}, + []connection.ColumnDefinition{ + {Name: "id", Type: "INTEGER", Nullable: "NO", Key: "PRI"}, + {Name: "name", Type: "VARCHAR(80)", Nullable: "YES"}, + }, + nil, + ) + + if !strings.Contains(ddl, `PRIMARY KEY ("id")`) { + t.Fatalf("ddl missing primary key from column metadata:\n%s", ddl) + } +} + +func TestIrisMetadataMapsColumnsAndIndexes(t *testing.T) { + dbConn, state := openOracleRecordingDB(t) + iris := &IrisDB{conn: dbConn} + + columnsQuery := buildIRISInfoSchemaWhereQuery("INFORMATION_SCHEMA.COLUMNS", irisTableRef{Schema: "Sample", Table: "Person"}) + indexesQuery := buildIRISInfoSchemaWhereQuery("INFORMATION_SCHEMA.INDEXES", irisTableRef{Schema: "Sample", Table: "Person"}) + + state.mu.Lock() + state.queryResults[columnsQuery] = oracleRecordingQueryResult{ + columns: []string{"TABLE_SCHEMA", "TABLE_NAME", "COLUMN_NAME", "DATA_TYPE", "CHARACTER_MAXIMUM_LENGTH", "IS_NULLABLE", "COLUMN_DEFAULT", "ORDINAL_POSITION", "DESCRIPTION", "PRIMARY_KEY", "UNIQUE_COLUMN"}, + rows: [][]driver.Value{ + {"Sample", "Person", "id", "INTEGER", nil, "NO", nil, int64(1), "identifier", true, false}, + {"Sample", "Person", "name", "VARCHAR", int64(80), "YES", "'anonymous'", int64(2), "display name", false, true}, + }, + } + state.queryResults[indexesQuery] = oracleRecordingQueryResult{ + columns: []string{"INDEX_NAME", "COLUMN_NAME", "NON_UNIQUE", "ORDINAL_POSITION", "INDEX_TYPE", "PRIMARY_KEY"}, + rows: [][]driver.Value{ + {"app_person_pk", "id", int64(1), int64(1), "bitmap", true}, + {"idx_person_name", "name", int64(0), int64(1), "", false}, + }, + } + state.mu.Unlock() + + columns, err := iris.GetColumns("Sample", "Person") + if err != nil { + t.Fatalf("GetColumns returned error: %v", err) + } + if len(columns) != 2 { + t.Fatalf("columns len = %d", len(columns)) + } + if columns[0].Name != "id" || columns[0].Key != "PRI" || columns[0].Nullable != "NO" { + t.Fatalf("unexpected id column: %#v", columns[0]) + } + if columns[1].Type != "VARCHAR(80)" || columns[1].Key != "UNI" { + t.Fatalf("unexpected name column: %#v", columns[1]) + } + + indexes, err := iris.GetIndexes("Sample", "Person") + if err != nil { + t.Fatalf("GetIndexes returned error: %v", err) + } + if len(indexes) != 2 || indexes[0].Name != "app_person_pk" || indexes[0].IndexType != "PRIMARY" || indexes[0].NonUnique != 0 { + t.Fatalf("unexpected indexes: %#v", indexes) + } +} + +func TestBuildIRISApplyChangesSQL(t *testing.T) { + deleteSQL, deleteArgs, ok := buildIRISDeleteSQL("Sample.Person", map[string]interface{}{"id": 1}) + if !ok { + t.Fatal("expected delete SQL") + } + if deleteSQL != `DELETE FROM "Sample"."Person" WHERE "id" = ?` || !reflect.DeepEqual(deleteArgs, []interface{}{1}) { + t.Fatalf("unexpected delete SQL/args: %s %#v", deleteSQL, deleteArgs) + } + + updateSQL, updateArgs, ok, err := buildIRISUpdateSQL("Sample.Person", connection.UpdateRow{ + Keys: map[string]interface{}{"id": 1}, + Values: map[string]interface{}{"name": "Alice", "updated_at": "2026-05-16"}, + }) + if err != nil || !ok { + t.Fatalf("expected update SQL, ok=%v err=%v", ok, err) + } + if updateSQL != `UPDATE "Sample"."Person" SET "name" = ?, "updated_at" = ? WHERE "id" = ?` { + t.Fatalf("unexpected update SQL: %s", updateSQL) + } + if !reflect.DeepEqual(updateArgs, []interface{}{"Alice", "2026-05-16", 1}) { + t.Fatalf("unexpected update args: %#v", updateArgs) + } + + insertSQL, insertArgs, ok := buildIRISInsertSQL("Sample.Person", map[string]interface{}{"name": "Alice", "id": 1}) + if !ok { + t.Fatal("expected insert SQL") + } + if insertSQL != `INSERT INTO "Sample"."Person" ("id", "name") VALUES (?, ?)` { + t.Fatalf("unexpected insert SQL: %s", insertSQL) + } + if !reflect.DeepEqual(insertArgs, []interface{}{1, "Alice"}) { + t.Fatalf("unexpected insert args: %#v", insertArgs) + } +} + +func TestIrisApplyChangesExecutesInDeleteUpdateInsertOrder(t *testing.T) { + dbConn, state := openOracleRecordingDB(t) + iris := &IrisDB{conn: dbConn} + + err := iris.ApplyChanges("Sample.Person", connection.ChangeSet{ + Deletes: []map[string]interface{}{ + {"id": 3}, + }, + Updates: []connection.UpdateRow{ + {Keys: map[string]interface{}{"id": 2}, Values: map[string]interface{}{"name": "Bob"}}, + }, + Inserts: []map[string]interface{}{ + {"id": 1, "name": "Alice"}, + }, + }) + if err != nil { + t.Fatalf("ApplyChanges returned error: %v", err) + } + + got := state.snapshotExecQueries() + want := []string{ + `DELETE FROM "Sample"."Person" WHERE "id" = ?`, + `UPDATE "Sample"."Person" SET "name" = ? WHERE "id" = ?`, + `INSERT INTO "Sample"."Person" ("id", "name") VALUES (?, ?)`, + } + if !reflect.DeepEqual(got, want) { + t.Fatalf("unexpected exec queries:\nwant=%#v\ngot=%#v", want, got) + } +} + +func TestBuildIRISUpdateSQLRequiresLocatorKeys(t *testing.T) { + _, _, ok, err := buildIRISUpdateSQL("Person", connection.UpdateRow{ + Values: map[string]interface{}{"name": "Alice"}, + }) + if err == nil || ok { + t.Fatalf("expected missing keys to be rejected, ok=%v err=%v", ok, err) + } +} diff --git a/third_party/go-irisnative/LICENSE b/third_party/go-irisnative/LICENSE new file mode 100644 index 0000000..fa4395a --- /dev/null +++ b/third_party/go-irisnative/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2025 Dmitry Maslennikov + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. \ No newline at end of file diff --git a/third_party/go-irisnative/PATCHES.md b/third_party/go-irisnative/PATCHES.md new file mode 100644 index 0000000..5e10215 --- /dev/null +++ b/third_party/go-irisnative/PATCHES.md @@ -0,0 +1,11 @@ +Local patch against github.com/caretdev/go-irisnative v0.2.1: + +- Added `//go:build !windows` to `src/connection/user_posix.go`. + Upstream ships `user_windows.go` with a Windows filename suffix, but + `user_posix.go` has no build constraint, so Windows builds compile both + files and fail with `userCurrent redeclared`. +- Made `Connection.Disconnect` close the underlying TCP connection after + sending the protocol disconnect message, so `database/sql` closes do not + leak sockets. +- Made `Connection.BeginTx` return the `START TRANSACTION` error instead of + marking the connection as in-transaction when the server rejected the begin. diff --git a/third_party/go-irisnative/README.md b/third_party/go-irisnative/README.md new file mode 100644 index 0000000..27b8fae --- /dev/null +++ b/third_party/go-irisnative/README.md @@ -0,0 +1,285 @@ +# go-irisnative + +A Golang driver for InterSystems IRIS that implements `database/sql`. + +> Project status: **alpha**. API may change. Feedback and PRs welcome. + +--- + +## Installation + +```bash +# replace the module path with the final repo path when published +go get github.com/caretdev/go-irisnative +``` + +Register the driver by importing it for side‑effects: + +```go +import ( + "database/sql" + _ "github.com/caretdev/go-irisnative" // registers driver as "iris" +) +``` + +## DSN formats + +The driver accepts a URL-style DSN (recommended) or key=value pairs. + +**URL style** + +``` +iris://user:password@host:1972/NAMESPACE? +``` + +* `host` — IRIS hostname or IP +* `1972` — superserver port (default) +* `Namespace` — IRIS namespace (e.g., `USER`) + +--- + +## Quick start (database/sql) + +```go +package main + +import ( + "context" + "database/sql" + "fmt" + "log" + "time" + + _ "github.com/caretdev/go-irisnative" +) + +func main() { + dsn := "iris://_SYSTEM:SYS@localhost:1972/USER" + db, err := sql.Open("iris", dsn) + if err != nil { log.Fatal(err) } + defer db.Close() + + // Connection pool tuning + db.SetMaxOpenConns(10) + db.SetMaxIdleConns(5) + db.SetConnMaxLifetime(30 * time.Minute) + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + _, err = db.ExecContext(ctx, `DROP TABLE IF EXISTS demo_person`) + if err != nil { log.Fatal("drop table:", err) } + + // 1) Create a table (id INT PRIMARY KEY, name VARCHAR(80)) + _, err = db.ExecContext(ctx, `CREATE TABLE IF NOT EXISTS demo_person ( + id INT PRIMARY KEY, + name VARCHAR(80) NOT NULL, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP + )`) + if err != nil { log.Fatal("create table:", err) } + + // 2) Insert with placeholders + res, err := db.ExecContext(ctx, `INSERT INTO demo_person(id, name) VALUES(?, ?)`, 1, "Alice") + if err != nil { log.Fatal("insert:", err) } + if n, _ := res.RowsAffected(); n > 0 { fmt.Println("inserted:", n) } + + // 3) Query rows + rows, err := db.QueryContext(ctx, `SELECT id, name, created_at FROM demo_person ORDER BY id`) + if err != nil { log.Fatal("query:", err) } + defer rows.Close() + + for rows.Next() { + var ( + id int + name string + createdAt time.Time + ) + if err := rows.Scan(&id, &name, &createdAt); err != nil { log.Fatal(err) } + fmt.Printf("row: id=%d name=%s created_at=%s\n", id, name, createdAt.Format(time.RFC3339)) + } + if err := rows.Err(); err != nil { log.Fatal(err) } + + // 4) Prepared statement + stmt, err := db.PrepareContext(ctx, `UPDATE demo_person SET name=? WHERE id=?`) + if err != nil { log.Fatal("prepare:", err) } + defer stmt.Close() + if _, err := stmt.ExecContext(ctx, "Alice Updated", 1); err != nil { log.Fatal("update:", err) } + + // 5) Transaction example + tx, err := db.BeginTx(ctx, &sql.TxOptions{Isolation: sql.LevelReadCommitted}) + if err != nil { log.Fatal("begin tx:", err) } + if _, err := tx.ExecContext(ctx, `INSERT INTO demo_person(id, name) VALUES(?, ?)`, 2, "Bob"); err != nil { + tx.Rollback() + log.Fatal("tx insert:", err) + } + if err := tx.Commit(); err != nil { log.Fatal("commit:", err) } +} +``` + +### Query single value helper + +```go +var count int +if err := db.QueryRowContext(ctx, `SELECT COUNT(*) FROM demo_person`).Scan(&count); err != nil { + log.Fatal(err) +} +fmt.Println("count=", count) +``` + +--- + +## Using with `sqlx` + +`sqlx` adds nice helpers over `database/sql` like struct scanning and named queries. + +```bash +go get github.com/jmoiron/sqlx +``` + +```go +package main + +import ( + "context" + "fmt" + "log" + "time" + + _ "github.com/caretdev/go-irisnative" // driver + "github.com/jmoiron/sqlx" +) + +type Person struct { + ID int `db:"id"` + Name string `db:"name"` + CreatedAt time.Time `db:"created_at"` +} + +func create(ctx context.Context, db *sqlx.DB) { + drop(ctx, db) + _, err := db.ExecContext(ctx, `CREATE TABLE IF NOT EXISTS demo_person ( + id INT PRIMARY KEY, + name VARCHAR(80) NOT NULL, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP + )`) + if err != nil { + panic(err) + } +} + +func drop(ctx context.Context, db *sqlx.DB) { + _, err := db.ExecContext(ctx, `DROP TABLE IF EXISTS demo_person`) + if err != nil { + panic(err) + } +} + +func main() { + ctx := context.Background() + dsn := "iris://_SYSTEM:SYS@localhost:1972/USER" + db := sqlx.MustConnect("iris", dsn) + defer db.Close() + + create(ctx, db) + defer drop(ctx, db) + + // Struct-based insert with NamedExec + p := Person{ID: 3, Name: "Carol"} + _, err := db.NamedExecContext(ctx, + `INSERT INTO demo_person(id, name) VALUES(:id, :name)`, p, + ) + if err != nil { + log.Fatal("named insert:", err) + } + + // Select into slice of structs + var people []Person + if err := db.SelectContext(ctx, &people, `SELECT id, name, created_at FROM demo_person ORDER BY id`); err != nil { + log.Fatal(err) + } + fmt.Printf("people: %#v\n", people) + + // Get a single struct + var one Person + if err := db.GetContext(ctx, &one, `SELECT id, name, created_at FROM demo_person WHERE id=?`, people[0].ID); err != nil { + log.Fatal(err) + } + fmt.Printf("one: %+v\n", one) + + // Named query with IN (sqlx.In) + ids := []int{1, 2, 3} + q, args, err := sqlx.In(`SELECT id, name FROM demo_person WHERE id IN (?)`, ids) + if err != nil { + log.Fatal(err) + } + q = db.Rebind(q) // ensure driver-specific bindvars + rows, err := db.QueryxContext(ctx, q, args...) + if err != nil { + log.Fatal(err) + } + defer rows.Close() + for rows.Next() { + var id int + var name string + if err := rows.Scan(&id, &name); err != nil { + log.Fatal(err) + } + fmt.Println(id, name) + } +} +``` + +--- + +## Placeholders & rebind + +* The driver uses `?` positional placeholders. +* With `sqlx`, **always** call `db.Rebind(q)` after `sqlx.In(...)` to adapt placeholders. + +--- + +## Context, timeouts & cancellations + +All examples use `Context`. Set sensible timeouts to avoid runaway queries: + +```go +ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) +defer cancel() +``` + +--- + +## Error handling tips + +* Check `rows.Err()` after iteration. +* Prefer `ExecContext`/`QueryContext` to ensure timeouts are respected. +* Wrap errors with operation context (e.g., `fmt.Errorf("create table: %w", err)`). + +--- + +## Testing locally + +1. Start IRIS and ensure SQL is enabled for your namespace (e.g., `USER`). +2. Create a SQL user with privileges to connect and create tables. +3. Verify connectivity using the DSN shown above. + +--- + +## Compatibility + +* Go: 1.21+ +* InterSystems IRIS: 2025.1+ + +--- + +## License + +MIT + +--- + +## Contributing + +* Run `go vet` and `go test ./...` before submitting PRs. +* Add tests for new behaviors. +* Document any DSN parameters you introduce. diff --git a/third_party/go-irisnative/connector.go b/third_party/go-irisnative/connector.go new file mode 100644 index 0000000..f599ab6 --- /dev/null +++ b/third_party/go-irisnative/connector.go @@ -0,0 +1,85 @@ +package intersystems + +import ( + "context" + "database/sql/driver" + "errors" + "strings" +) + +// Connector represents a fixed configuration for the pq driver with a given +// name. Connector satisfies the database/sql/driver Connector interface and +// can be used to create any number of DB Conn's via the database/sql OpenDB +// function. +// +// See https://golang.org/pkg/database/sql/driver/#Connector. +// See https://golang.org/pkg/database/sql/#OpenDB. +type Connector struct { + opts values + // dialer Dialer +} + +// Connect returns a connection to the database using the fixed configuration +// of this Connector. Context is not used. +func (c *Connector) Connect(ctx context.Context) (driver.Conn, error) { + return c.open(ctx) +} + +// Driver returns the underlying driver of this Connector. +func (c *Connector) Driver() driver.Driver { + return &Driver{} +} + +// NewConnector returns a connector for the pq driver in a fixed configuration +// with the given dsn. The returned connector can be used to create any number +// of equivalent Conn's. The returned connector is intended to be used with +// database/sql.OpenDB. +// +// See https://golang.org/pkg/database/sql/driver/#Connector. +// See https://golang.org/pkg/database/sql/#OpenDB. +func NewConnector(dsn string) (*Connector, error) { + var err error + o := make(values) + + // A number of defaults are applied here, in this order: + // + // * Very low precedence defaults applied in every situation + // * Environment variables + // * Explicitly passed connection information + o["host"] = "localhost" + o["port"] = "1972" + + if strings.HasPrefix(dsn, "iris://") || strings.HasPrefix(dsn, "IRIS://") { + dsn, err = ParseURL(dsn) + if err != nil { + return nil, err + } + } + + if err := parseOpts(dsn, o); err != nil { + return nil, err + } + + if enc, ok := o["client_encoding"]; ok && !isUTF8(enc) { + return nil, errors.New("client_encoding must be absent or 'UTF8'") + } + o["client_encoding"] = "UTF8" + + return &Connector{opts: o, /*dialer: defaultDialer{}*/}, nil +} + +// isUTF8 returns whether name is a fuzzy variation of the string "UTF-8". +func isUTF8(name string) bool { + s := strings.Map(alnumLowerASCII, name) + return s == "utf8" || s == "unicode" +} + +func alnumLowerASCII(ch rune) rune { + if 'A' <= ch && ch <= 'Z' { + return ch + ('a' - 'A') + } + if 'a' <= ch && ch <= 'z' || '0' <= ch && ch <= '9' { + return ch + } + return -1 // discard +} diff --git a/third_party/go-irisnative/driver.go b/third_party/go-irisnative/driver.go new file mode 100644 index 0000000..4501d6c --- /dev/null +++ b/third_party/go-irisnative/driver.go @@ -0,0 +1,216 @@ +package intersystems + +import ( + "context" + "database/sql" + "database/sql/driver" + "errors" + "fmt" + "net" + "unicode" + + _ "io" + _ "math" + _ "reflect" + _ "strconv" + _ "strings" + _ "time" + _ "unsafe" + + "github.com/caretdev/go-irisnative/src/connection" +) + +var ( + ErrCouldNotDetectUsername = errors.New("intersystems: Could not detect default username. Please provide one explicitly") +) + +var ( + _ driver.Driver = Driver{} +) + +type values map[string]string + +// Driver implements database/sql/driver.Driver. +type Driver struct{} + +func (d Driver) Open(name string) (driver.Conn, error) { + return Open(name) +} + +func init() { + sql.Register("intersystems", &Driver{}) + sql.Register("iris", &Driver{}) +} + +func Open(dsn string) (_ driver.Conn, err error) { + c, err := NewConnector(dsn) + if err != nil { + return nil, err + } + return c.open(context.Background()) +} + +type conn struct { + c connection.Connection + tx bool +} + +func (c *Connector) open(ctx context.Context) (cn *conn, err error) { + o := make(values) + for k, v := range c.opts { + o[k] = v + } + host := o["host"] + addr := net.JoinHostPort(host, o["port"]) + namespace := o["namespace"] + login := o["user"] + password := o["password"] + + cn = &conn{} + + cn.c, err = connection.Connect(addr, namespace, login, password) + if err != nil { + return nil, err + } + return cn, nil +} + +func (cn *conn) Begin() (driver.Tx, error) { + return cn.c.BeginTx(driver.TxOptions{}) +} + +func (cn *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { + return cn.c.BeginTx(opts) +} + +func (cn *conn) Close() (err error) { + cn.c.Disconnect() + return nil +} + +func (cn *conn) Prepare(q string) (st driver.Stmt, err error) { + return cn.c.Prepare(q) +} + +func (cn *conn) Commit() error { + if !cn.tx { + panic("transaction already closed") + } + cn.tx = false + cn.c.Commit() + return nil +} + +func (cn *conn) Rollback() error { + if !cn.tx { + panic("transaction already closed") + } + cn.tx = false + cn.c.Rollback() + return nil +} + +func (cn *conn) Exec(query string, args []driver.NamedValue) (res driver.Result, err error) { + parameters := make([]interface{}, len(args)) + for i, a := range args { + parameters[i] = a + } + _, err = cn.c.DirectUpdate(query, parameters...) + if err != nil { + return nil, err + } + return res, nil +} + +func (cn *conn) Query(query string, args []driver.NamedValue) (rows driver.Rows, err error) { + parameters := make([]interface{}, len(args)) + for i, a := range args { + parameters[i] = a + } + // var rs *connection.ResultSet + _, err = cn.c.Query(query, parameters...) + if err != nil { + return nil, err + } + // rows = &connection.Rows{ + // cn: cn.c, + // rs: rs, + // } + return +} + +func parseOpts(name string, o values) error { + s := newScanner(name) + + for { + var ( + keyRunes, valRunes []rune + r rune + ok bool + ) + + if r, ok = s.SkipSpaces(); !ok { + break + } + + // Scan the key + for !unicode.IsSpace(r) && r != '=' { + keyRunes = append(keyRunes, r) + if r, ok = s.Next(); !ok { + break + } + } + + // Skip any whitespace if we're not at the = yet + if r != '=' { + r, ok = s.SkipSpaces() + } + + // The current character should be = + if r != '=' || !ok { + return fmt.Errorf(`missing "=" after %q in connection info string"`, string(keyRunes)) + } + + // Skip any whitespace after the = + if r, ok = s.SkipSpaces(); !ok { + // If we reach the end here, the last value is just an empty string as per libpq. + o[string(keyRunes)] = "" + break + } + + if r != '\'' { + for !unicode.IsSpace(r) { + if r == '\\' { + if r, ok = s.Next(); !ok { + return fmt.Errorf(`missing character after backslash`) + } + } + valRunes = append(valRunes, r) + + if r, ok = s.Next(); !ok { + break + } + } + } else { + quote: + for { + if r, ok = s.Next(); !ok { + return fmt.Errorf(`unterminated quoted string literal in connection string`) + } + switch r { + case '\'': + break quote + case '\\': + r, _ = s.Next() + fallthrough + default: + valRunes = append(valRunes, r) + } + } + } + + o[string(keyRunes)] = string(valRunes) + } + + return nil +} diff --git a/third_party/go-irisnative/go.mod b/third_party/go-irisnative/go.mod new file mode 100644 index 0000000..59781e7 --- /dev/null +++ b/third_party/go-irisnative/go.mod @@ -0,0 +1,5 @@ +module github.com/caretdev/go-irisnative + +go 1.24.3 + +require github.com/shopspring/decimal v1.4.0 diff --git a/third_party/go-irisnative/go.sum b/third_party/go-irisnative/go.sum new file mode 100644 index 0000000..2def83f --- /dev/null +++ b/third_party/go-irisnative/go.sum @@ -0,0 +1 @@ +github.com/shopspring/decimal v1.4.0/go.mod h1:gawqmDU56v4yIKSwfBSFip1HdCCXN8/+DMd9qYNcwME= diff --git a/third_party/go-irisnative/scanner.go b/third_party/go-irisnative/scanner.go new file mode 100644 index 0000000..4bf6e60 --- /dev/null +++ b/third_party/go-irisnative/scanner.go @@ -0,0 +1,34 @@ +package intersystems + +import "unicode" + +type scanner struct { + s []rune + i int +} + +// newScanner returns a new scanner initialized with the option string s. +func newScanner(s string) *scanner { + return &scanner{[]rune(s), 0} +} + +// Next returns the next rune. +// It returns 0, false if the end of the text has been reached. +func (s *scanner) Next() (rune, bool) { + if s.i >= len(s.s) { + return 0, false + } + r := s.s[s.i] + s.i++ + return r, true +} + +// SkipSpaces returns the next non-whitespace rune. +// It returns 0, false if the end of the text has been reached. +func (s *scanner) SkipSpaces() (rune, bool) { + r, ok := s.Next() + for unicode.IsSpace(r) && ok { + r, ok = s.Next() + } + return r, ok +} diff --git a/third_party/go-irisnative/src/connection/classes.go b/third_party/go-irisnative/src/connection/classes.go new file mode 100644 index 0000000..cff3237 --- /dev/null +++ b/third_party/go-irisnative/src/connection/classes.go @@ -0,0 +1,113 @@ +package connection + +import "github.com/caretdev/go-irisnative/src/iris" + +func (c *Connection) ServerVersion() (result string, err error) { + err = c.ClassMethod("%SYSTEM.Version", "GetVersion", &result) + return +} + +func (c *Connection) ClassMethod(class, method string, result interface{}, args ...interface{}) (err error) { + msg := NewMessage(CLASSMETHOD_VALUE) + msg.Set(class) + msg.Set(method) + msg.Set(len(args)) + for _, arg := range args { + msg.Set(arg) + } + + _, err = c.conn.Write(msg.Dump(c.count())) + if err != nil { + return + } + msg, err = ReadMessage(c.conn) + if err != nil { + return + } + + msg.Get(result) + + return +} + +func (c *Connection) ClassMethodVoid(class, method string, args ...interface{}) (err error) { + msg := NewMessage(CLASSMETHOD_VOID) + msg.Set(class) + msg.Set(method) + msg.Set(len(args)) + for _, arg := range args { + msg.Set(arg) + } + + _, err = c.conn.Write(msg.Dump(c.count())) + if err != nil { + return + } + msg, err = ReadMessage(c.conn) + if err != nil { + return + } + return +} + +func (c *Connection) Method(obj iris.Oref, method string, result interface{}, args ...interface{}) (err error) { + msg := NewMessage(METHOD_VALUE) + msg.Set(obj) + msg.Set(method) + msg.Set(len(args)) + for _, arg := range args { + msg.Set(arg) + } + + _, err = c.conn.Write(msg.Dump(c.count())) + if err != nil { + return + } + msg, err = ReadMessage(c.conn) + if err != nil { + return + } + + msg.Get(result) + + return +} + +func (c *Connection) MethodVoid(obj, method string, args ...interface{}) (err error) { + msg := NewMessage(METHOD_VOID) + msg.Set(obj) + msg.Set(method) + msg.Set(len(args)) + for _, arg := range args { + msg.Set(arg) + } + + _, err = c.conn.Write(msg.Dump(c.count())) + if err != nil { + return + } + msg, err = ReadMessage(c.conn) + if err != nil { + return + } + return +} +func (c *Connection) PropertyGet(obj iris.Oref, property string, result interface{}) (err error) { + msg := NewMessage(PROPERTY_GET) + msg.Set(obj) + msg.Set(property) + // msg.Set(0) + + _, err = c.conn.Write(msg.Dump(c.count())) + if err != nil { + return + } + msg, err = ReadMessage(c.conn) + if err != nil { + return + } + + msg.Get(result) + + return +} diff --git a/third_party/go-irisnative/src/connection/globals.go b/third_party/go-irisnative/src/connection/globals.go new file mode 100644 index 0000000..32bcf51 --- /dev/null +++ b/third_party/go-irisnative/src/connection/globals.go @@ -0,0 +1,141 @@ +package connection + +func (c *Connection) GlobalIsDefined(global string, subs ...interface{}) (bool, bool) { + msg := NewMessage(GLOBAL_DATA) + msg.Set(global) + msg.Set(len(subs)) + for _, sub := range subs { + msg.Set(sub) + } + msg.Set(0) + _, err := c.conn.Write(msg.Dump(c.count())) + if err != nil { + return false, false + } + + msg, err = ReadMessage(c.conn) + if err != nil { + return false, false + } + + var result uint8 + msg.Get(&result) + return result%10 == 1, result >= 10 +} + +func (c *Connection) GlobalSet(global string, value interface{}, subs ...interface{}) (err error) { + msg := NewMessage(GLOBAL_SET) + msg.Set(global) + msg.Set(len(subs)) + for _, sub := range subs { + msg.Set(sub) + } + msg.Set(value) + + _, err = c.conn.Write(msg.Dump(c.count())) + if err != nil { + return + } + + _, err = ReadMessage(c.conn) + if err != nil { + return + } + + return +} + +func (c *Connection) GlobalKill(global string, subs ...interface{}) (err error) { + msg := NewMessage(GLOBAL_KILL) + msg.Set(global) + msg.Set(len(subs)) + for _, sub := range subs { + msg.Set(sub) + } + + _, err = c.conn.Write(msg.Dump(c.count())) + if err != nil { + return + } + + _, err = ReadMessage(c.conn) + if err != nil { + return + } + + return +} + +func (c *Connection) GlobalGet(global string, result interface{}, subs ...interface{}) (err error) { + msg := NewMessage(GLOBAL_GET) + msg.Set(global) + msg.Set(len(subs)) + for _, sub := range subs { + msg.Set(sub) + } + + _, err = c.conn.Write(msg.Dump(c.count())) + if err != nil { + return + } + + msg, err = ReadMessage(c.conn) + if err != nil { + return + } + + msg.Get(result) + return +} + +func (c *Connection) GlobalNext(global string, ind *string, subs ...interface{}) (hasNext bool, err error) { + msg := NewMessage(GLOBAL_ORDER) + msg.Set(global) + msg.Set(len(subs) + 1) + for _, sub := range subs { + msg.Set(sub) + } + msg.Set(*ind) + msg.Set(3) + + if _, err = c.conn.Write(msg.Dump(c.count())); err != nil { + return + } + + if msg, err = ReadMessage(c.conn); err != nil { + return + } + + var result string + msg.Get(&result) + *ind = result + hasNext = result != "" + + return +} + +func (c *Connection) GlobalPrev(global string, ind *string, subs ...interface{}) (hasNext bool, err error) { + msg := NewMessage(GLOBAL_ORDER) + msg.Set(global) + msg.Set(len(subs) + 1) + for _, sub := range subs { + msg.Set(sub) + } + msg.Set(*ind) + msg.Set(7) + + if _, err = c.conn.Write(msg.Dump(c.count())); err != nil { + return + } + + if msg, err = ReadMessage(c.conn); err != nil { + return + } + + var result string + msg.Get(&result) + *ind = result + hasNext = result != "" + + return +} diff --git a/third_party/go-irisnative/src/connection/message.go b/third_party/go-irisnative/src/connection/message.go new file mode 100644 index 0000000..43fc9b2 --- /dev/null +++ b/third_party/go-irisnative/src/connection/message.go @@ -0,0 +1,148 @@ +package connection + +import ( + "fmt" + "io" + "net" + + "github.com/caretdev/go-irisnative/src/list" +) + +type Message struct { + header MessageHeader + data []byte + offset uint +} + +func NewMessage(messageType MessageType) Message { + return Message{ + NewMessageHeader(messageType), + []byte{}, + 0, + } +} + +func ReadMessage(conn *net.TCPConn) (msg Message, err error) { + buffer := make([]byte, 14) + + _, err = conn.Read(buffer) + if err != nil { + return + } + + var header [14]byte + copy(header[:], buffer[:14]) + var msgHeader = MessageHeader{header} + + length := msgHeader.GetLength() + data := make([]byte, length) + var offset int = 0 + var size int + for { + size, err = conn.Read(data[offset:]) + if err != nil { + if err != io.EOF { + return + } + break + } + offset += size + if offset >= int(length) { + break + } + } + + msg = Message{msgHeader, data, 0} + + return +} + +func (m *Message) AddRaw(value interface{}) { + switch v := value.(type) { + case uint16: + m.data = append(m.data, byte(v&0xff)) + m.data = append(m.data, byte(v>>8&0xff)) + m.offset += 2 + case []byte: + m.data = append(m.data, v...) + m.offset += uint(len(v)) + } +} + +func (m *Message) GetRaw(value interface{}) error { + switch v := value.(type) { + case *uint16: + *v = uint16(m.data[m.offset]) | (uint16(m.data[m.offset+1]) << 8) + m.offset += 2 + case *bool: + *v = (uint16(m.data[m.offset]) | (uint16(m.data[m.offset+1]) << 8)) == 1 + m.offset += 2 + case *[]byte: + *v = m.data[m.offset:] + m.offset = uint(len(m.data)) + default: + return fmt.Errorf("unknown type: %T", v) + } + return nil +} + +func (m *Message) Set(value interface{}) error { + listItem := list.NewListItem(value) + m.AddRaw(listItem.Dump()) + return nil +} + +func (m *Message) SetSQLText(sqlText string) error { + len := len(sqlText) + if len == 0 { + m.Set(sqlText) + return nil + } + const chunksize = 31904 + chunks := len / chunksize + if len%chunksize != 0 { + chunks += 1 + } + m.Set(chunks) + for i := 0; i < chunks; i++ { + begin := i * chunksize + end := (i + 1) * chunksize + if end > len { + end = len + } + m.Set(sqlText[begin:end]) + } + return nil +} + +func (m *Message) GetStatus() uint16 { + return m.header.GetStatus() +} + +func (m *Message) Get(value interface{}) error { + listItem := list.GetListItem(m.data, &m.offset) + listItem.Get(value) + return nil +} + +type AnyType struct { + listItem list.ListItem +} + +func (v *AnyType) Int() int { + var value int + v.listItem.Get(&value) + return value +} + +func (m *Message) GetAny() AnyType { + listItem := list.GetListItem(m.data, &m.offset) + return AnyType{listItem} +} + +func (m *Message) Dump(count uint32) []byte { + m.header.SetCount(count) + m.header.SetLength(uint32(len(m.data))) + + return append(m.header.header[:], m.data...) +} diff --git a/third_party/go-irisnative/src/connection/message_header.go b/third_party/go-irisnative/src/connection/message_header.go new file mode 100644 index 0000000..0f79a6a --- /dev/null +++ b/third_party/go-irisnative/src/connection/message_header.go @@ -0,0 +1,84 @@ +package connection + +type MessageType string + +func setUint32(buffer []byte, value uint32) { + buffer[0] = byte(value & 0xff) + buffer[1] = byte(value >> 8 & 0xff) + buffer[2] = byte(value >> 16 & 0xff) + buffer[3] = byte(value >> 24 & 0xff) +} + +func getUint32(buffer []byte) uint32 { + return uint32(buffer[0]) | + uint32(buffer[1])<<8 | + uint32(buffer[2])<<16 | + uint32(buffer[3])<<24 +} + +const ( + CONNECT MessageType = "\x43\x4e" + HANDSHAKE MessageType = "\x48\x53" + DISCONNECT MessageType = "\x44\x43" + + GLOBAL_GET MessageType = "\x41\xc2" + GLOBAL_SET MessageType = "\x42\xc2" + GLOBAL_KILL MessageType = "\x43\xc2" + GLOBAL_ORDER MessageType = "\x45\xc2" + GLOBAL_DATA MessageType = "\x49\xc2" + + CLASSMETHOD_VALUE MessageType = "\x4b\xc2" + CLASSMETHOD_VOID MessageType = "\x4c\xc2" + + METHOD_VALUE MessageType = "\x5b\xc2" + METHOD_VOID MessageType = "\x5c\xc2" + + PROPERTY_GET MessageType = "\x5d\xc2" + PROPERTY_SET MessageType = "\x5e\xc2" + + DIRECT_QUERY MessageType = "DQ" + PREPARED_QUERY MessageType = "PQ" + DIRECT_UPDATE MessageType = "DU" + PREPARED_UPDATE MessageType = "PU" + PREPARE MessageType = "PP" + GET_AUTO_GENERATED_KEYS MessageType = "GG" + + COMMIT MessageType = "TC" + ROLLBACK MessageType = "TR" + + MULTIPLE_RESULT_SETS_FETCH_DATA MessageType = "MD" + GET_MORE_RESULTS MessageType = "MR" + FETCH_DATA MessageType = "FD" + GET_SERVER_ERROR MessageType = "OE" +) + +type MessageHeader struct { + header [14]byte +} + +func NewMessageHeader(messageType MessageType) MessageHeader { + header := [14]byte{} + header[12] = messageType[0] + header[13] = messageType[1] + return MessageHeader{header} +} + +func (mh *MessageHeader) GetStatus() uint16 { + return uint16(mh.header[12]) | (uint16(mh.header[13]) << 8) +} + +func (mh *MessageHeader) SetLength(length uint32) { + setUint32(mh.header[0:], length) +} + +func (mh MessageHeader) GetLength() uint32 { + return getUint32(mh.header[0:]) +} + +func (mh *MessageHeader) SetCount(cnt uint32) { + setUint32(mh.header[4:], cnt) +} + +func (mh *MessageHeader) SetStatementId(statementId uint32) { + setUint32(mh.header[8:], statementId) +} diff --git a/third_party/go-irisnative/src/connection/mod.go b/third_party/go-irisnative/src/connection/mod.go new file mode 100644 index 0000000..6d04546 --- /dev/null +++ b/third_party/go-irisnative/src/connection/mod.go @@ -0,0 +1,251 @@ +package connection + +import ( + "database/sql/driver" + "errors" + "net" +) + +const VERSION_PROTOCOL uint16 = 69 + +type Connection struct { + conn *net.TCPConn + messageCount uint32 + statement uint32 + unicode bool + locale string + version uint16 + info string + featureOptions uint + tx bool +} + +var ( + ErrCouldNotDetectUsername = errors.New("intersystems: Could not detect default username. Please provide one explicitly") + errBeginTx = errors.New("could not begin transaction") + errMultipleTx = errors.New("multiple transactions") + errReadOnlyTxNotSupported = errors.New("read-only transactions are not supported") +) + +func Connect(addr string, namespace, login, password string) (connection Connection, err error) { + + tcpAddr, err := net.ResolveTCPAddr("tcp", addr) + if err != nil { + return + } + + conn, err := net.DialTCP("tcp", nil, tcpAddr) + if err != nil { + return + } + + connection = Connection{ + conn: conn, + } + + if err = connection.handshake(); err != nil { + return + } + + if err = connection.connect(namespace, login, password); err != nil { + return + } + + // fmt.Println(connection.version, connection.info) + + return +} + +func (c *Connection) Disconnect() { + if c.conn == nil { + return + } + msg := NewMessage(DISCONNECT) + _, _ = c.conn.Write(msg.Dump(c.count())) + _ = c.conn.Close() + c.conn = nil +} + +func (c *Connection) count() uint32 { + count := c.messageCount + c.messageCount += 1 + return count +} + +func (c *Connection) statementId() uint32 { + statement := c.statement + c.statement += 1 + return statement +} + +func (c *Connection) handshake() (err error) { + var message = NewMessage(HANDSHAKE) + message.AddRaw(VERSION_PROTOCOL) + + _, err = c.conn.Write(message.Dump(c.count())) + if err != nil { + return + } + + msg, err := ReadMessage(c.conn) + if err != nil { + return + } + + var version uint16 + msg.GetRaw(&version) + c.version = version + + var unicode uint16 + msg.GetRaw(&unicode) + c.unicode = unicode == 1 + + var locale string + msg.Get(&locale) + c.locale = locale + return +} + +func encode(value string) []byte { + in := []byte(value) + length := len(in) + out := make([]byte, length) + for i := range in { + length-- + temp := ((int(in[i])^0xa7)&0xff + length) & 0xff + out[length] = byte(temp<<5 | temp>>3) + } + return out +} + +type FeatureOption uint + +const ( + OptionNone FeatureOption = 0 + OptionFastSelect FeatureOption = 1 + OptionFastInsert FeatureOption = 2 + OptionFastSelectAndInsert FeatureOption = 3 + OptionDurableTransactions FeatureOption = 4 + OptionNotNullable FeatureOption = 8 + OptionRedirectOutput FeatureOption = 32 +) + +func (c *Connection) IsOptionFastInsert() bool { + return c.featureOptions&uint(OptionFastInsert) == uint(OptionFastInsert) +} + +func (c *Connection) IsOptionFastSelect() bool { + return c.featureOptions&uint(OptionFastSelect) == uint(OptionFastSelect) +} + +func (c *Connection) connect(namespace, login, password string) (err error) { + msg := NewMessage(CONNECT) + msg.Set(namespace) + msg.Set(encode(login)) + msg.Set(encode(password)) + var user = "go" + if user, err = systemUser(); err != nil { + user = "go" + } + msg.Set(user) // machine user name + msg.Set("go-machine") // machine name + msg.Set("libirisnative") // application name + msg.Set("") // ? + msg.Set("go") // SharedMemoryFlag? + msg.Set("") // EventClass + msg.Set(1) // AutoCommit ? 1 : 2 + msg.Set(0) // IsolationLevel + var featureOptions = OptionNone + featureOptions += OptionFastSelect + // Tricky to make it fully working yet + // featureOptions += OptionFastInsert + featureOptions += OptionDurableTransactions + featureOptions += OptionRedirectOutput + msg.Set(int(featureOptions)) // FeatureOption + + _, err = c.conn.Write(msg.Dump(c.count())) + if err != nil { + return + } + + msg, err = ReadMessage(c.conn) + if err != nil { + return + } + if status := msg.GetStatus(); status == 417 { + var errorMsg string + msg.Get(&errorMsg) + err = errors.New(errorMsg) + return + } + + var info string + msg.Get(&info) + c.info = info + var ( + delimited_ids bool + ignored int + isolationLevel int + serverJobNumber string + sqlEmptyString int + serverFeatureOptions uint + ) + msg.Get(&delimited_ids) + msg.Get(&ignored) + msg.Get(&isolationLevel) + msg.Get(&serverJobNumber) + msg.Get(&sqlEmptyString) + msg.Get(&serverFeatureOptions) + c.featureOptions = serverFeatureOptions + return +} + +func systemUser() (string, error) { + u, err := userCurrent() + if err != nil { + return "", err + } + return u, nil +} + +func (c *Connection) Commit() (err error) { + msg := NewMessage(COMMIT) + _, err = c.conn.Write(msg.Dump(c.count())) + if err != nil { + return + } + _, err = ReadMessage(c.conn) + if err != nil { + return + } + return +} + +func (c *Connection) Rollback() (err error) { + msg := NewMessage(ROLLBACK) + _, err = c.conn.Write(msg.Dump(c.count())) + if err != nil { + return + } + _, err = ReadMessage(c.conn) + if err != nil { + return + } + return +} + +func (c *Connection) BeginTx(opts driver.TxOptions) (driver.Tx, error) { + if c.tx { + return nil, errors.Join(errBeginTx, errMultipleTx) + } + + if opts.ReadOnly { + return nil, errors.Join(errBeginTx, errReadOnlyTxNotSupported) + } + + if _, err := c.DirectUpdate("START TRANSACTION"); err != nil { + return nil, errors.Join(errBeginTx, err) + } + c.tx = true + return &tx{c}, nil +} diff --git a/third_party/go-irisnative/src/connection/rows.go b/third_party/go-irisnative/src/connection/rows.go new file mode 100644 index 0000000..5adf797 --- /dev/null +++ b/third_party/go-irisnative/src/connection/rows.go @@ -0,0 +1,101 @@ +package connection + +import ( + "database/sql/driver" + "errors" + "strings" +) + +var ( + errNoRowsAffected = errors.New("no RowsAffected available after the empty statement") + errNoLastInsertID = errors.New("no LastInsertId available after the empty statement") +) + +type Result struct { + cn *Connection + affected int64 +} + +func (r Result) LastInsertId() (lastId int64, err error) { + // var msg Message + // msg = NewMessage(GET_AUTO_GENERATED_KEYS) + // msg.header.SetStatementId(r.cn.statementId()) + // _, err = r.cn.conn.Write(msg.Dump(r.cn.count())) + // if err != nil { + // return + // } + // msg, err = ReadMessage(r.cn.conn) + // if err != nil { + // return + // } + // msg.Get(&lastId) + // return + var rs *ResultSet + rs, err = r.cn.DirectQuery("SELECT LAST_IDENTITY()") + if err != nil { + return + } + row, err := rs.Next() + if err != nil { + return + } + lastId = int64(row[0].(int)) + return +} + +func (r Result) RowsAffected() (int64, error) { + return r.affected, nil +} + +type Rows struct { + cn *Connection + rs *ResultSet +} + +type noRows struct{} + +var emptyRows noRows + +var _ driver.Result = noRows{} + +func (noRows) LastInsertId() (int64, error) { + return 0, errNoLastInsertID +} + +func (noRows) RowsAffected() (int64, error) { + return 0, errNoRowsAffected +} + + +func (r *Rows) Close() error { + return nil +} + +func (r *Rows) Columns() []string { + if r.rs == nil { + return []string{} + } + columns := r.rs.Columns() + colNames := make([]string, len(columns)) + for k, c := range columns { + colname := c.Name() + // tricking IRIS + colname = strings.ReplaceAll(colname, "﹒", ".") + colNames[k] = colname + } + // fmt.Printf("Columns: %#v\n", colNames) + return colNames +} + +func (r *Rows) Next(dest []driver.Value) (err error) { + row, err := r.rs.Next() + if err != nil { + return err + } + for i := range dest { + dest[i] = row[i] + } + // fmt.Printf("RowsNext: %#v\n", dest) + return nil +} + diff --git a/third_party/go-irisnative/src/connection/sql.go b/third_party/go-irisnative/src/connection/sql.go new file mode 100644 index 0000000..701897e --- /dev/null +++ b/third_party/go-irisnative/src/connection/sql.go @@ -0,0 +1,729 @@ +package connection + +import ( + "database/sql/driver" + "fmt" + "io" + "slices" + "strconv" + "strings" + "time" + + "github.com/caretdev/go-irisnative/src/list" +) + +const timeLaylout = "2006-01-02 15:04:05.000000000" +const timeLayloutShort = "2006-01-02 15:04:05" + +type StatementFeature struct { + featureOption int + msgCount int + maxRowItemCount int +} + +type Column struct { + name string + column_type int + precision int + scale int + nullable int + slot_position int + label string + table_name string + schema string + catalog string + is_auto_increment bool + is_case_sensitive bool + is_currency bool + is_read_only bool + is_row_id bool +} + +type SQLTYPE int16 + +const ( + GUID SQLTYPE = -11 + WLONGVARCHAR SQLTYPE = -10 + WVARCHAR SQLTYPE = -9 + WCHAR SQLTYPE = -8 + BIT SQLTYPE = -7 + TINYINT SQLTYPE = -6 + BIGINT SQLTYPE = -5 + LONGVARBINARY SQLTYPE = -4 + VARBINARY SQLTYPE = -3 + BINARY SQLTYPE = -2 + LONGVARCHAR SQLTYPE = -1 + CHAR SQLTYPE = 1 + NUMERIC SQLTYPE = 2 + DECIMAL SQLTYPE = 3 + INTEGER SQLTYPE = 4 + SMALLINT SQLTYPE = 5 + FLOAT SQLTYPE = 6 + REAL SQLTYPE = 7 + DOUBLE SQLTYPE = 8 + DATE SQLTYPE = 9 + TIME SQLTYPE = 10 + TIMESTAMP SQLTYPE = 11 + VARCHAR SQLTYPE = 12 + TYPE_DATE SQLTYPE = 91 + TYPE_TIME SQLTYPE = 92 + TYPE_TIMESTAMP SQLTYPE = 93 + DATE_HOROLOG SQLTYPE = 1091 + TIME_HOROLOG SQLTYPE = 1092 + TIMESTAMP_POSIX SQLTYPE = 1093 +) + +func (c Column) Name() string { + return c.name +} + +type ResultSet struct { + c *Connection + columns []Column + sf StatementFeature + count int + data []byte + offset uint + sqlCode int16 +} + +type SQLError struct { + SQLCode int16 + Message string +} + +func (e *SQLError) Error() string { + return fmt.Sprintf("Error Code: %d, Message: %s", e.SQLCode, e.Message) +} + +// func SQLError(code int) error { +// return &SQLError{SQLCode: code} +// } + +func (rs ResultSet) Columns() []Column { + return rs.columns +} + +func statementFeature(msg *Message) StatementFeature { + featureOption := 0 + msgCount := 0 + maxRowItemCount := 0 + msg.Get(&featureOption) + if featureOption == 2 { + msg.Get(&msgCount) + } + if featureOption == 1 || featureOption == 2 { + msg.Get(&maxRowItemCount) + } + return StatementFeature{ + featureOption, + msgCount, + maxRowItemCount, + } +} + +type Value interface{} + +// type ResultSetRow struct{} + +func (rs *ResultSet) fetchMoreData() bool { + msg := NewMessage(FETCH_DATA) + _, err := rs.c.conn.Write(msg.Dump(rs.c.count())) + if err != nil { + panic(err) + } + msg, err = ReadMessage(rs.c.conn) + if err != nil { + panic(err) + } + + rs.data = msg.data + rs.offset = 0 + return len(msg.data) > 0 +} + +func fromODBC(coltype SQLTYPE, li list.ListItem) (result interface{}, err error) { + result = nil + if li.IsNull() || li.IsEmpty() { + return + } + switch coltype { + case VARCHAR: + if li.DataLength() == 0 { + return + } + var value string + li.Get(&value) + if value == "\x00" { + value = "" + } + result = value + case INTEGER, TINYINT, SMALLINT: + var value int + li.Get(&value) + result = value + case BIGINT: + var value int64 + li.Get(&value) + result = value + case BIT: + var value bool + li.Get(&value) + result = value + case FLOAT: + var value float32 + li.Get(&value) + result = value + case DOUBLE: + var value float64 + li.Get(&value) + result = value + case TIMESTAMP_POSIX: + if li.DataLength() == 0 { + return + } + if li.Type() == list.LISTITEM_STRING { + var strval string + li.Get(&strval) + result, err = time.Parse(timeLaylout, strval) + if err == nil { + return + } + err = nil + } + var value int64 + li.Get(&value) + if value > 0 { + value ^= 0x1000000000000000 + } else { + value |= 0x6000000000000000 + } + seconds := value / 1000000 + nano := value % 1000000 * 1000 + result = time.Unix(seconds, nano).In(time.Local) + case VARBINARY: + // var value []uint8 + var value string + li.Get(&value) + case TYPE_TIMESTAMP: + var strval string + li.Get(&strval) + result, err = time.Parse(timeLayloutShort, strval) + default: + var value string + li.Get(&value) + fmt.Printf("fromODBC: invalid type: %v - %#v - %#v", coltype, li, value) + result = value + } + return +} + +func (rs *ResultSet) Next() ([]Value, error) { + if rs == nil || (rs.sqlCode != 0 && rs.sqlCode != 100) { + return nil, io.EOF + } + if rs.offset >= uint(len(rs.data)) && (rs.sqlCode == 100 || !rs.fetchMoreData()) { + return nil, io.EOF + } + row := make([]Value, rs.count) + data := rs.data + count := rs.count + var offset uint = rs.offset + if rs.sf.featureOption == 1 { + li := list.GetListItem(data, &rs.offset) + li.Get(&data) + offset = 0 + count = rs.sf.maxRowItemCount + } + vals := make([]list.ListItem, count) + for i := 0; i < count; i++ { + li := list.GetListItem(data, &offset) + vals[i] = li + } + if rs.sf.featureOption != 1 { + rs.offset = offset + } + var err error + for i, c := range rs.columns { + li := vals[c.slot_position] + row[i], err = fromODBC(SQLTYPE(c.column_type), li) + if err != nil { + return nil, err + } + // fmt.Printf("col: %s: %d; %#v - %#v\n", c.name, c.column_type, row[i], li) + } + // fmt.Printf("row: %#v\n", row) + return row, nil +} + +func (c *Connection) getErrorInfo(sqlCode int16) string { + msg := NewMessage(GET_SERVER_ERROR) + msg.Set(sqlCode) + _, err := c.conn.Write(msg.Dump(c.count())) + if err != nil { + panic(err) + } + msg, err = ReadMessage(c.conn) + if err != nil { + panic(err) + } + var sqlMessage string + msg.Get(&sqlMessage) + return sqlMessage +} + +func getColumns(msg *Message, statementFeature StatementFeature) []Column { + cnt := 0 + msg.Get(&cnt) + columns := make([]Column, cnt) + for i := 0; i < cnt; i++ { + column := Column{} + msg.Get(&column.name) + msg.Get(&column.column_type) + switch column.column_type { + case 9: + column.column_type = 91 + case 10: + column.column_type = 92 + case 11: + column.column_type = 93 + } + msg.Get(&column.precision) + msg.Get(&column.scale) + msg.Get(&column.nullable) + msg.Get(&column.label) + msg.Get(&column.table_name) + msg.Get(&column.schema) + msg.Get(&column.catalog) + additional := "" + msg.Get(&additional) + if statementFeature.featureOption&0x01 == 1 { + msg.Get(&column.slot_position) + column.slot_position -= 1 + } else { + column.slot_position = i + } + column.is_auto_increment = additional[0] == 0x01 + column.is_case_sensitive = additional[1] == 0x01 + column.is_currency = additional[2] == 0x01 + column.is_read_only = additional[3] == 0x01 + if len(additional) >= 12 { + column.is_row_id = additional[11] == 0x01 + } + columns[i] = column + } + return columns +} + +func parameterInfo(msg *Message) { + cnt := 0 + msg.Get(&cnt) + flag := 0 + msg.Get(&flag) +} + +func toODBC(value interface{}) interface{} { + var val interface{} + switch v := value.(type) { + case *string: + val = *v + case string: + val = v + if v == "" { + val = "\x00" + } + case nil: + val = "" + case bool: + if v { + val = 1 + } else { + val = 0 + } + case time.Time: + val = v.UTC().Format(timeLaylout) + case int, int8, int16, int32, int64: + val = v + case float32, float64: + val = v + case []uint8: + val = v + default: + fmt.Printf("unsupported type: %T\n", v) + val = fmt.Sprintf("%v", v) + } + return val +} + +func writeParameters(msg *Message, args ...interface{}) { + msg.Set(len(args)) + for range args { + msg.Set(99) + msg.Set(4) + } + + msg.Set(1) // parameterSets + msg.Set(len(args)) + for _, arg := range args { + msg.Set(toODBC(arg)) + } +} + +func (c *Connection) Query(sqlText string, args ...interface{}) (rs *ResultSet, err error) { + queries := strings.Split(sqlText, ";\n") + if len(queries) == 2 { + sqlText = queries[0] + _, err = c.DirectUpdate(sqlText, args...) + if err != nil { + return + } + + sqlText = queries[1] + args = []interface{}{} + } + rs, err = c.DirectQuery(sqlText, args...) + if err != nil { + return + } + return +} + +func (c *Connection) DirectQuery(sqlText string, args ...interface{}) (*ResultSet, error) { + sqlText, _, args = FormatQuery(sqlText, args...) + // fmt.Printf("DirectQuery: %s; %#v\n", sqlText, args) + + var statementId = c.statementId() + msg := NewMessage(DIRECT_QUERY) + msg.header.SetStatementId(statementId) + msg.SetSQLText(sqlText) + writeParameters(&msg, args...) + msg.Set(10) // Query timeout + msg.Set(200) // Max rows + + _, err := c.conn.Write(msg.Dump(c.count())) + if err != nil { + return nil, err + } + msg, err = ReadMessage(c.conn) + if err != nil { + return nil, err + } + sqlCode := int16(msg.GetStatus()) + if sqlCode != 0 && sqlCode != 100 { + return nil, &SQLError{SQLCode: sqlCode, Message: c.getErrorInfo(sqlCode)} + } + statementFeature := statementFeature(&msg) + columns := getColumns(&msg, statementFeature) + parameterInfo((&msg)) + rs := &ResultSet{ + c: c, + sf: statementFeature, + columns: columns, + count: len(columns), + } + + msg, err = ReadMessage(c.conn) + rs.sqlCode = int16(msg.GetStatus()) + if err != nil { + return nil, err + } + + msg.GetRaw(&rs.data) + + return rs, nil +} + +func (m Message) debug() string { + var sb strings.Builder + for i, b := range m.data { + if i > 0 { + sb.WriteString(", ") + } + sb.WriteString(strconv.Itoa(int(b))) + } + return fmt.Sprintf("$char(%s)", sb.String()) +} + +func FormatQuery(sqlText string, args ...interface{}) (string, int, []interface{}) { + var count int + for i := range args { + count++ + sqlText = strings.Replace(sqlText, "?", fmt.Sprintf(" :%%qpar(%d) ", i+1), 1) + if !strings.Contains(sqlText, "?") { + break + } + } + return sqlText, count, args +} + +func (c *Connection) Exec(sqlText string, args ...interface{}) (res *Result, err error) { + queries := strings.Split(sqlText, ";\n") + var onConflict = "" + if len(queries) == 2 { + sqlText = queries[0] + onConflict = strings.Split(queries[1], "-- ")[1] + if strings.Contains(onConflict, "ON CONFLICT UPDATE") { + // fmt.Printf("------\n%s\n%#v\n------\n", sqlText, args) + sqlText = strings.Replace(sqlText, "INSERT INTO", "INSERT OR UPDATE", 1) + onConflict = "" + } + } + res, err = c.DirectUpdate(sqlText, args...) + if err != nil { + if strings.Contains(onConflict, "ON CONFLICT DO NOTHING") { + res = &Result{cn: c, affected: 0} + err = nil + return + } + } + return +} + +func (c *Connection) DirectUpdate(sqlText string, args ...interface{}) (*Result, error) { + var batchSize int + sqlText, batchSize, args = FormatQuery(sqlText, args...) + // fmt.Printf("DirectUpdate: %s; %#v\n", sqlText, args) + var batches = 1 + if batchSize > 0 { + batches = len(args) / batchSize + } + var addToCache = false + var statementId = c.statementId() + var executeMany = false + var optFastInsert = false + var rowsAffected int64 = 0 + var identityColumn = false + var defaults = []interface{}{} + for i := 1; i <= batches; i++ { + if i > 1 && executeMany { + break + } + var msg Message + if !addToCache { + msg = NewMessage(DIRECT_UPDATE) + msg.SetSQLText(sqlText) + msg.Set(batchSize) + for j := 0; j < batchSize; j++ { + msg.Set(99) + msg.Set(1) + } + // msg.Set(len(args)) + // for range args { + // msg.Set(99) + // msg.Set(1) + // } + } else { + msg = NewMessage(PREPARED_UPDATE) + } + if addToCache && !executeMany && optFastInsert { + msg.AddRaw([]byte{1, 0, 0, 0}) + msg.Set("") + msg.Set(0) + if identityColumn { + msg.Set(2) + msg.Set("") + } else { + msg.Set(1) + } + var batch []interface{} = make([]interface{}, batchSize) + copy(batch, args) + args = slices.Delete(args, 0, batchSize) + var params []byte + var item list.ListItem + for _, arg := range batch { + item = list.NewListItem(toODBC(arg)) + params = append(params, item.Dump()...) + } + for _, arg := range defaults { + item = list.NewListItem(toODBC(arg)) + params = append(params, item.Dump()...) + } + msg.Set(params) + } else { + msg.Set("") + msg.Set(0) + if executeMany { + msg.Set(batches) + for k := 0; k < batches; k++ { + msg.Set(batchSize) + for j := 0; j < batchSize; j++ { + var idx = (k * batchSize) + j + msg.Set(toODBC(args[idx])) + } + } + } else { + var batch []interface{} = make([]interface{}, batchSize) + copy(batch, args) + args = slices.Delete(args, 0, batchSize) + msg.Set(1) + msg.Set(len(batch)) + for _, arg := range batch { + msg.Set(toODBC(arg)) + } + } + } + + msg.header.SetStatementId(statementId) + _, err := c.conn.Write(msg.Dump(c.count())) + if err != nil { + return nil, err + } + msg, err = ReadMessage(c.conn) + if err != nil { + // fmt.Println("DirectUpdate:Readmessage: ", err) + return nil, err + } + sqlCode := int16(msg.GetStatus()) + if sqlCode != 0 && sqlCode != 100 { + return nil, &SQLError{SQLCode: sqlCode, Message: c.getErrorInfo(sqlCode)} + } + if i == 1 { + if c.IsOptionFastInsert() { + stmtFeatureOption, _ := c.checkStatementFeature(&msg) + optFastInsert = stmtFeatureOption&uint(OptionFastInsert) == uint(OptionFastInsert) + } + addToCache, identityColumn, defaults = c.getParameterInfo(&msg, optFastInsert) + } + var batchRows int64 + msg.Get(&batchRows) + rowsAffected += batchRows + } + result := &Result{cn: c, affected: rowsAffected} + return result, nil +} + +func (c *Connection) checkStatementFeature(msg *Message) (featureOption uint, count uint) { + count = 0 + var keyCount int + msg.Get(&featureOption) + if featureOption == uint(OptionFastSelect) || featureOption == uint(OptionFastInsert) { + if featureOption == uint(OptionFastInsert) { + msg.Get(&keyCount) + } + msg.Get(&count) + } + return +} + +func (c *Connection) getParameterInfo(msg *Message, optFastInsert bool) (addToCache bool, identityColumn bool, defaults []interface{}) { + var paramscnt int + msg.Get(¶mscnt) + var tablename string + for i := 0; i < paramscnt; i++ { + var ( + paramtype int + precision int + scale int + nullable bool + position int + someval1 string + someval2 string + colname string + ) + msg.Get(¶mtype) + msg.Get(&precision) + msg.Get(&scale) + msg.GetAny() + if optFastInsert { + msg.Get(&nullable) + msg.Get(&position) + msg.Get(&someval1) + msg.Get(&someval2) + if i == 0 { + msg.Get(&tablename) + } + msg.Get(&colname) + } + } + var flag int + defaults = []interface{}{} + identityColumn = false + msg.Get(&flag) + addToCache = flag&0x1 == 0x1 + if optFastInsert { + var paramsDefault []byte + msg.Get(¶msDefault) + var offset uint = 0 + var li list.ListItem + li = list.GetListItem(paramsDefault, &offset) + identityColumn = li.IsEmpty() + for { + if uint(len(paramsDefault)) == offset { + break + } + li = list.GetListItem(paramsDefault, &offset) + if li.IsNull() { + continue + } + var val string + li.Get(&val) + defaults = append(defaults, val) + } + } + return +} + +type Stmt struct { + cn *Connection + sql string + closed bool + statementId int32 +} + +func (c *Connection) Prepare(query string) (*Stmt, error) { + // msg := NewMessage(PREPARE) + // msg.SetSQLText(query) + // msg.Set(0) + + // _, err := c.conn.Write(msg.Dump(c.count())) + // if err != nil { + // return nil, err + // } + // msg, err = ReadMessage(c.conn) + // if err != nil { + // return nil, err + // } + // sqlCode := int16(msg.GetStatus()) + // if sqlCode != 0 && sqlCode != 100 { + // return nil, &SQLError{SQLCode: sqlCode, Message: c.getErrorInfo(sqlCode)} + // } + + st := &Stmt{cn: c, sql: query} + return st, nil +} + +func (st *Stmt) Exec(args []driver.Value) (res driver.Result, err error) { + parameters := make([]interface{}, len(args)) + for i, a := range args { + parameters[i] = a + } + res, err = st.cn.Exec(st.sql, parameters...) + return +} + +func (st *Stmt) Query(args []driver.Value) (rows driver.Rows, err error) { + parameters := make([]interface{}, len(args)) + for i, a := range args { + parameters[i] = a + } + var rs *ResultSet + rs, err = st.cn.Query(st.sql, parameters...) + // st.statementId = int32(st.cn.statementId()) + if err != nil { + return nil, err + } + rows = &Rows{ + cn: st.cn, + rs: rs, + } + return +} + +func (st *Stmt) Close() (err error) { + st.closed = true + return nil +} + +func (st *Stmt) NumInput() int { + return -1 +} diff --git a/third_party/go-irisnative/src/connection/transaction.go b/third_party/go-irisnative/src/connection/transaction.go new file mode 100644 index 0000000..9b2d6a6 --- /dev/null +++ b/third_party/go-irisnative/src/connection/transaction.go @@ -0,0 +1,29 @@ +package connection + +type tx struct { + c *Connection +} + +func (t *tx) Commit() error { + if t.c == nil || !t.c.tx { + panic("database/sql/driver: misuse of driver: extra Commit") + } + + t.c.tx = false + err := t.c.Commit() + t.c = nil + + return err +} + +func (t *tx) Rollback() error { + if t.c == nil || !t.c.tx { + panic("database/sql/driver: misuse of driver: extra Rollback") + } + + t.c.tx = false + err := t.c.Rollback() + t.c = nil + + return err +} diff --git a/third_party/go-irisnative/src/connection/user_posix.go b/third_party/go-irisnative/src/connection/user_posix.go new file mode 100644 index 0000000..32cf023 --- /dev/null +++ b/third_party/go-irisnative/src/connection/user_posix.go @@ -0,0 +1,22 @@ +//go:build !windows + +package connection + +import ( + "os" + "os/user" +) + +func userCurrent() (string, error) { + u, err := user.Current() + if err == nil { + return u.Username, nil + } + + name := os.Getenv("USER") + if name != "" { + return name, nil + } + + return "", ErrCouldNotDetectUsername +} diff --git a/third_party/go-irisnative/src/connection/user_windows.go b/third_party/go-irisnative/src/connection/user_windows.go new file mode 100644 index 0000000..3f3fde7 --- /dev/null +++ b/third_party/go-irisnative/src/connection/user_windows.go @@ -0,0 +1,19 @@ +package connection + +import ( + "path/filepath" + "syscall" +) + +// Perform Windows user name. +func userCurrent() (string, error) { + pw_name := make([]uint16, 128) + pwname_size := uint32(len(pw_name)) - 1 + err := syscall.GetUserNameEx(syscall.NameSamCompatible, &pw_name[0], &pwname_size) + if err != nil { + return "", ErrCouldNotDetectUsername + } + s := syscall.UTF16ToString(pw_name) + u := filepath.Base(s) + return u, nil +} diff --git a/third_party/go-irisnative/src/iris/oref.go b/third_party/go-irisnative/src/iris/oref.go new file mode 100644 index 0000000..0c2ea63 --- /dev/null +++ b/third_party/go-irisnative/src/iris/oref.go @@ -0,0 +1,4 @@ +package iris + +type Oref string + diff --git a/third_party/go-irisnative/src/list/listitem.go b/third_party/go-irisnative/src/list/listitem.go new file mode 100644 index 0000000..21b024e --- /dev/null +++ b/third_party/go-irisnative/src/list/listitem.go @@ -0,0 +1,456 @@ +package list + +import ( + "encoding/binary" + "errors" + "fmt" + "strconv" + + "github.com/caretdev/go-irisnative/src/iris" + "github.com/shopspring/decimal" +) + +type ListItemType byte + +const ( + LISTITEM_STRING ListItemType = 0x01 + LISTITEM_UNICODE ListItemType = 0x02 + LISTITEM_POSINT ListItemType = 0x04 + LISTITEM_NEGINT ListItemType = 0x05 + LISTITEM_POSFLOAT ListItemType = 0x06 + LISTITEM_NEGFLOAT ListItemType = 0x07 + LISTITEM_OREF ListItemType = 0x19 +) + +type ListItem struct { + size uint16 + itemType ListItemType + data []byte + isNull bool + byRef bool +} + +func (li *ListItem) IsNull() bool { + return li.isNull +} + +func (li *ListItem) IsString() bool { + return li.itemType == LISTITEM_STRING || li.itemType == LISTITEM_UNICODE +} + +func (li *ListItem) IsEmpty() bool { + return li.itemType == LISTITEM_STRING && len(li.data) == 0 +} + +func (li *ListItem) Type() ListItemType { + return li.itemType +} + +var scale = []float64{ + 1.0, 10.0, 100.0, 1000.0, 10000.0, 100000.0, 1000000.0, 1.0e7, 1.0e8, 1.0e9, + 1.0e10, 1.0e11, 1.0e12, 1.0e13, 1.0e14, 1.0e15, 1.0e16, 1.0e17, 1.0e18, 1.0e19, + 1.0e20, 1.0e21, 1.0e22, 9.999999999999999e22, 1.0e24, 1.0e25, 1.0e26, 1.0e27, 1.0e28, 1.0e29, + 1.0e30, 1.0e31, 1.0e32, 1.0e33, 1.0e34, 1.0e35, 1.0e36, 1.0e37, 1.0e38, 1.0e39, + 1.0e40, 1.0e41, 1.0e42, 1.0e43, 1.0e44, 1.0e45, 1.0e46, 1.0e47, 1.0e48, 1.0e49, + 1.0e50, 1.0e51, 1.0e52, 1.0e53, 1.0e54, 1.0e55, 1.0e56, 1.0e57, 1.0e58, 1.0e59, + 1.0e60, 1.0e61, 1.0e62, 1.0e63, 1.0e64, 1.0e65, 1.0e66, 1.0e67, 1.0e68, 1.0e69, + 1.0e70, 1.0e71, 1.0e72, 1.0e73, 1.0e74, 1.0e75, 1.0e76, 1.0e77, 1.0e78, 1.0e79, + 1.0e80, 1.0e81, 1.0e82, 1.0e83, 1.0e84, 1.0e85, 1.0e86, 1.0e87, 1.0e88, 1.0e89, + 1.0e90, 1.0e91, 1.0e92, 1.0e93, 1.0e94, 1.0e95, 1.0e96, 1.0e97, 1.0e98, 1.0e99, + 1.0e100, 1.0e101, 1.0e102, 1.0e103, 1.0e104, 1.0e105, 1.0e106, 1.0e107, 1.0e108, 1.0e109, + 1.0e110, 1.0e111, 1.0e112, 1.0e113, 1.0e114, 1.0e115, 1.0e116, 1.0e117, 1.0e118, 1.0e119, + 1.0e120, 1.0e121, 1.0e122, 1.0e123, 1.0e124, 1.0e125, 1.0e126, 1.0e127, 1.0e-128, 1.0e-127, + 1.0e-126, 1.0e-125, 1.0e-124, 1.0e-123, 1.0e-122, 1.0e-121, 1.0e-120, 1.0e-119, 1.0e-118, 1.0e-117, + 1.0e-116, 1.0e-115, 1.0e-114, 1.0e-113, 1.0e-112, 1.0e-111, 1.0e-110, 1.0e-109, 1.0e-108, 1.0e-107, + 1.0e-106, 1.0e-105, 1.0e-104, 1.0e-103, 1.0e-102, 1.0e-101, 1.0e-100, 1.0e-99, 1.0e-98, 1.0e-97, + 1.0e-96, 1.0e-95, 1.0e-94, 1.0e-93, 1.0e-92, 1.0e-91, 1.0e-90, 1.0e-89, 1.0e-88, 1.0e-87, + 1.0e-86, 1.0e-85, 1.0e-84, 1.0e-83, 1.0e-82, 1.0e-81, 1.0e-80, 1.0e-79, 1.0e-78, 1.0e-77, + 1.0e-76, 1.0e-75, 1.0e-74, 1.0e-73, 1.0e-72, 1.0e-71, 1.0e-70, 1.0e-69, 1.0e-68, 1.0e-67, + 1.0e-66, 1.0e-65, 1.0e-64, 1.0e-63, 1.0e-62, 1.0e-61, 1.0e-60, 1.0e-59, 1.0e-58, 1.0e-57, + 1.0e-56, 1.0e-55, 1.0e-54, 1.0e-53, 1.0e-52, 1.0e-51, 1.0e-50, 1.0e-49, 1.0e-48, 1.0e-47, + 1.0e-46, 1.0e-45, 1.0e-44, 1.0e-43, 1.0e-42, 1.0e-41, 1.0e-40, 1.0e-39, 1.0e-38, 1.0e-37, + 1.0e-36, 1.0e-35, 1.0e-34, 1.0e-33, 1.0e-32, 1.0e-31, 1.0e-30, 1.0e-29, 1.0e-28, 1.0e-27, + 1.0e-26, 1.0e-25, 1.0e-24, 1.0e-23, 1.0e-22, 1.0e-21, 1.0e-20, 1.0e-19, 1.0e-18, 1.0e-17, + 1.0e-16, 1.0e-15, 1.0e-14, 1.0e-13, 1.0e-12, 1.0e-11, 1.0e-10, 1.0e-9, 1.0e-8, 1.0e-7, + 1.0e-6, 1.0e-5, 1.0e-4, 0.001, 0.01, 0.1} + +func (listItem *ListItem) Dump() []byte { + if listItem.isNull { + return []byte{1} + } + var dump = make([]byte, 0) + if listItem.size > 253 { + size := listItem.size + 1 + dump = append(dump, 0) + dump = append(dump, byte((size)&0xff)) + dump = append(dump, byte((size>>8)&0xff)) + } else { + dump = append(dump, byte(listItem.size+2)) + } + dump = append(dump, byte(listItem.itemType)) + dump = append(dump, listItem.data...) + return dump +} + +func GetListItem(buffer []byte, ooffset *uint) ListItem { + var byRef = false + var isNull = false + var size uint16 = 0 + var itemType byte = 0 + offset := *ooffset + + switch buffer[offset] { + case 0: + size = uint16((buffer[offset+1] & 0xff)) + size |= ((uint16(buffer[offset+2]) & 0xff) << 8) + size -= 1 + offset += 3 + itemType = buffer[offset] + offset += 1 + case 1: + isNull = true + offset += 1 + default: + size = uint16(buffer[offset]) - 2 + offset += 1 + itemType = buffer[offset] + offset += 1 + if itemType >= 32 && itemType < 64 { + itemType = itemType - 32 + byRef = true + } + } + var data = []byte{} + if size > 0 { + data = buffer[offset : offset+uint(size)] + } + offset += uint(size) + *ooffset = offset + return ListItem{size, ListItemType(itemType), data, isNull, byRef} +} + +func NewListItem(value interface{}) ListItem { + var itemType ListItemType = 0 + var size uint16 = 0 + var data = make([]byte, 0) + var isNull = false + var byRef = false + + switch v := value.(type) { + case *string: + var listItem = NewListItem(*v) + listItem.byRef = true + return listItem + case int, int8, int16, int32, int64: + var ival int64 + switch i := v.(type) { + case int: + ival = int64(i) + case int8: + ival = int64(i) + case int16: + ival = int64(i) + case int32: + ival = int64(i) + case int64: + ival = i + } + itemType = 4 + var base = 0 + var temp = ival + if ival < 0 { + itemType = 5 + base = 0xff + temp = ival*-1 - 1 + } + for temp > 0 { + data = append(data, byte((temp^int64(base))&0xff)) + temp = temp >> 8 + } + case uint, uint8, uint16, uint32, uint64: + var uval uint64 + switch u := v.(type) { + case uint: + uval = uint64(u) + case uint8: + uval = uint64(u) + case uint16: + uval = uint64(u) + case uint32: + uval = uint64(u) + case uint64: + uval = u + } + itemType = 4 + temp := uval + for temp > 0 { + data = append(data, byte(temp&0xff)) + temp = temp >> 8 + } + case float64, float32: + var d decimal.Decimal + switch f := v.(type) { + case float32: + d = decimal.NewFromFloat32(f) + case float64: + d = decimal.NewFromFloat(f) + } + scaleSize := 256 - d.Exponent()*-1 + ival := d.Coefficient().Int64() + itemType = 6 + if ival < 0 { + itemType = 7 + } + data = append(data, byte(scaleSize)) + var base = 0 + var temp = ival + if ival < 0 { + base = 0xff + temp = ival*-1 - 1 + } + for temp > 0 { + data = append(data, byte((temp^int64(base))&0xff)) + temp = temp >> 8 + } + case bool: + itemType = 4 + if v { + data = []byte{0x1} + } else { + data = []byte{0x0} + } + case string: + itemType = 1 + var unicodeBytes []byte + for _, r := range(v) { + if r > 255 { + itemType = 2 + var temp = r + // append(unicodeBytes) + for temp > 0 { + unicodeBytes = append(unicodeBytes, byte((temp)&0xff)) + temp = temp >> 8 + } + } else { + unicodeBytes = append(unicodeBytes, byte((r)&0xff)) + unicodeBytes = append(unicodeBytes, byte(0)) + } + } + if itemType == 2 { + data = unicodeBytes + } else { + data = []byte(v) + } + case []byte: + itemType = 1 + data = v + case nil: + isNull = true + // itemType = 1 + // data = []byte("") + case iris.Oref: + itemType = 25 + byRef = true + data = []byte(v) + default: + fmt.Printf("unknown: %#v %T\n", v, v) + itemType = 1 + data = []byte(fmt.Sprintf("%v", v)) + } + size = uint16(len(data)) + return ListItem{ + size, + itemType, + data, + isNull, + byRef, + } +} + +func (li *ListItem) getString() string { + if li.itemType == LISTITEM_UNICODE { + var val string = "" + for i := 0; i < len(li.data); i += 2 { + val += string(rune(getPosInt(li.data[i:i+2]))) + } + return val + } else { + return string(li.data) + } +} + +func getPosInt(data []byte) int { + temp := make([]byte, 8) + copy(temp, data) + return int(binary.LittleEndian.Uint64(temp[:8])) +} + +func getNegInt(data []byte) int { + temp := make([]byte, 8) + copy(temp, data) + for i := range data { + temp[i] ^= 0xff + } + return -int(binary.LittleEndian.Uint64(temp[:8]) + 1) +} + +func getPosFloat(data []byte) float64 { + d := scale[int(data[0])] + return float64(getPosInt(data[1:])) * d +} + +func getNegFloat(data []byte) float64 { + d := scale[int(data[0])] + return float64(getNegInt(data[1:])) * d +} + +func (li *ListItem) asString() (value string, err error) { + if li.isNull { + value = "" + return + } + switch li.itemType { + case 1, 2, 25: + value = li.getString() + case 4: + value = fmt.Sprint(getPosInt(li.data)) + case 5: + value = fmt.Sprint(getNegInt(li.data)) + case 6: + value = fmt.Sprint(getPosFloat(li.data)) + case 7: + value = fmt.Sprint(getNegFloat(li.data)) + default: + err = errors.New("not implemented") + } + return +} + +func (li *ListItem) asInt() (value int, err error) { + if li.isNull { + value = 0 + return + } + switch li.itemType { + case 1, 2: + value, err = strconv.Atoi(li.getString()) + case 4: + value = getPosInt(li.data) + case 5: + value = getNegInt(li.data) + case 6: + value = int(getPosFloat(li.data)) + case 7: + value = int(getNegFloat(li.data)) + default: + err = errors.New("not implemented") + } + return +} + +func (li *ListItem) asFloat64() (value float64, err error) { + if li.isNull { + value = 0 + return + } + switch li.itemType { + case 1, 2: + var temp int + temp, err = strconv.Atoi(li.getString()) + if err != nil { + return + } + value = float64(temp) + case 4: + value = float64(getPosInt(li.data)) + case 5: + value = float64(getNegInt(li.data)) + case 6: + value = getPosFloat(li.data) + case 7: + value = getNegFloat(li.data) + default: + err = errors.New("not implemented") + } + return +} + +type AnyType ListItem + +func (v *AnyType) Int() int { + var value int + // ListItem(*v) + return value +} + +func (li *ListItem) GetAny() AnyType { + return AnyType(*li) +} + +func (li *ListItem) DataLength() int { + return len(li.data) +} + +func (li *ListItem) Get(value interface{}) (err error) { + switch v := value.(type) { + case *int: + *v, err = li.asInt() + case *bool: + var temp int + temp, err = li.asInt() + *v = temp != 0 + case *int8: + var temp int + temp, err = li.asInt() + *v = int8(temp) + case *int16: + var temp int + temp, err = li.asInt() + *v = int16(temp) + case *int32: + var temp int + temp, err = li.asInt() + *v = int32(temp) + case *int64: + var temp int + temp, err = li.asInt() + *v = int64(temp) + case *uint: + var temp int + temp, err = li.asInt() + *v = uint(temp) + case *uint8: + var temp int + temp, err = li.asInt() + *v = uint8(temp) + case *uint16: + var temp int + temp, err = li.asInt() + *v = uint16(temp) + case *uint32: + var temp int + temp, err = li.asInt() + *v = uint32(temp) + case *uint64: + var temp int + temp, err = li.asInt() + *v = uint64(temp) + case *float64: + *v, err = li.asFloat64() + case *float32: + var temp float64 + temp, err = li.asFloat64() + *v = float32(temp) + case *string: + *v, err = li.asString() + case *[]byte: + *v = li.data + case *iris.Oref: + var temp string + temp, err = li.asString() + *v = iris.Oref(temp) + default: + err = errors.New("not implemented") + } + return +} diff --git a/third_party/go-irisnative/url.go b/third_party/go-irisnative/url.go new file mode 100644 index 0000000..c8ad49e --- /dev/null +++ b/third_party/go-irisnative/url.go @@ -0,0 +1,76 @@ +package intersystems + +import ( + "fmt" + "net" + nurl "net/url" + "sort" + "strings" +) + +// ParseURL no longer needs to be used by clients of this library since supplying a URL as a +// connection string to sql.Open() is now supported: +// +// sql.Open("intersystems", "iris://_system:SYS@1.2.3.4:1972/USER") +// +// It remains exported here for backwards-compatibility. +// +// ParseURL converts a url to a connection string for driver.Open. +// Example: +// +// "iris://_system:SYS@1.2.3.4:1972/USER" +// +// converts to: +// +// "user=_system password=SYS host=1.2.3.4 port=1972 namespace=USER" +// +// A minimal example: +// +// "iris://" +// +// This will be blank, causing driver.Open to use all of the defaults +func ParseURL(url string) (string, error) { + u, err := nurl.Parse(url) + if err != nil { + return "", err + } + + if u.Scheme != "iris" && u.Scheme != "IRIS" { + return "", fmt.Errorf("invalid connection protocol: %s", u.Scheme) + } + + var kvs []string + escaper := strings.NewReplacer(`'`, `\'`, `\`, `\\`) + accrue := func(k, v string) { + if v != "" { + kvs = append(kvs, k+"='"+escaper.Replace(v)+"'") + } + } + + if u.User != nil { + v := u.User.Username() + accrue("user", v) + + v, _ = u.User.Password() + accrue("password", v) + } + + if host, port, err := net.SplitHostPort(u.Host); err != nil { + accrue("host", u.Host) + } else { + accrue("host", host) + accrue("port", port) + } + + if u.Path != "" { + accrue("namespace", strings.ToUpper(u.Path[1:])) + } + + q := u.Query() + for k := range q { + accrue(k, q.Get(k)) + } + + sort.Strings(kvs) // Makes testing easier (not a performance concern) + return strings.Join(kvs, " "), nil +} diff --git a/tools/detect-changed-driver-agents.sh b/tools/detect-changed-driver-agents.sh index b220c64..ece4bc6 100644 --- a/tools/detect-changed-driver-agents.sh +++ b/tools/detect-changed-driver-agents.sh @@ -7,7 +7,7 @@ cd "$SCRIPT_DIR" SCRIPT_DIR_WINDOWS="$(pwd -W 2>/dev/null || true)" SCRIPT_DIR_WINDOWS="${SCRIPT_DIR_WINDOWS//\\//}" -DEFAULT_DRIVERS=(mariadb oceanbase doris starrocks sphinx sqlserver sqlite duckdb dameng kingbase highgo vastbase opengauss mongodb tdengine clickhouse) +DEFAULT_DRIVERS=(mariadb oceanbase doris starrocks sphinx sqlserver sqlite duckdb dameng kingbase highgo vastbase opengauss iris mongodb tdengine clickhouse) TARGET_PLATFORMS=(darwin/amd64 darwin/arm64 windows/amd64 windows/arm64 linux/amd64) usage() { @@ -50,7 +50,7 @@ normalize_driver() { case "$value" in doris|diros) echo "doris" ;; open_gauss|open-gauss) echo "opengauss" ;; - mariadb|oceanbase|starrocks|sphinx|sqlserver|sqlite|duckdb|dameng|kingbase|highgo|vastbase|opengauss|mongodb|tdengine|clickhouse) + mariadb|oceanbase|starrocks|sphinx|sqlserver|sqlite|duckdb|dameng|kingbase|highgo|vastbase|opengauss|iris|mongodb|tdengine|clickhouse) echo "$value" ;; *) diff --git a/tools/generate-driver-agent-revisions.sh b/tools/generate-driver-agent-revisions.sh index fb48c9c..fda338b 100755 --- a/tools/generate-driver-agent-revisions.sh +++ b/tools/generate-driver-agent-revisions.sh @@ -5,7 +5,7 @@ set -euo pipefail SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)" cd "$SCRIPT_DIR" -DEFAULT_DRIVERS=(mariadb oceanbase diros starrocks sphinx sqlserver sqlite duckdb dameng kingbase highgo vastbase opengauss mongodb tdengine clickhouse) +DEFAULT_DRIVERS=(mariadb oceanbase diros starrocks sphinx sqlserver sqlite duckdb dameng kingbase highgo vastbase opengauss iris mongodb tdengine clickhouse) OUTPUT_FILE="internal/db/driver_agent_revisions_gen.go" usage() { @@ -27,7 +27,7 @@ normalize_driver() { doris|diros) echo "diros" ;; oceanbase) echo "oceanbase" ;; opengauss|open_gauss|open-gauss) echo "opengauss" ;; - mariadb|diros|starrocks|sphinx|sqlserver|sqlite|duckdb|dameng|kingbase|highgo|vastbase|mongodb|tdengine|clickhouse) + mariadb|diros|starrocks|sphinx|sqlserver|sqlite|duckdb|dameng|kingbase|highgo|vastbase|iris|mongodb|tdengine|clickhouse) echo "$value" ;; *) @@ -125,6 +125,7 @@ highgo:internal/db/highgo_impl.go|\ vastbase:internal/db/vastbase_impl.go|\ opengauss:internal/db/opengauss_impl.go|\ opengauss:internal/db/postgres_impl.go|\ +iris:internal/db/iris_impl.go|\ mongodb:internal/db/mongodb_impl.go|\ mongodb:internal/db/mongodb_impl_v1.go|\ tdengine:internal/db/tdengine_impl.go|\