feat(iris): 新增 InterSystems IRIS 数据源支持

- 后端新增 IRIS 连接、查询、DDL、索引元数据和 DataGrid 编辑能力
- 接入 optional driver-agent、构建标签、revision 生成和变更检测流程
- 前端新增 IRIS 连接入口、方言映射、能力配置和图标展示
- 修复 IRIS 主键识别、事务开启错误处理和驱动连接关闭问题
- 补充后端、前端和构建脚本相关回归测试
Refs #408
This commit is contained in:
Syngnat
2026-05-17 10:32:08 +08:00
parent 0cde96844d
commit 992d2dee45
57 changed files with 4391 additions and 16 deletions

21
third_party/go-irisnative/LICENSE vendored Normal file
View File

@@ -0,0 +1,21 @@
MIT License
Copyright (c) 2025 Dmitry Maslennikov
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

11
third_party/go-irisnative/PATCHES.md vendored Normal file
View File

@@ -0,0 +1,11 @@
Local patch against github.com/caretdev/go-irisnative v0.2.1:
- Added `//go:build !windows` to `src/connection/user_posix.go`.
Upstream ships `user_windows.go` with a Windows filename suffix, but
`user_posix.go` has no build constraint, so Windows builds compile both
files and fail with `userCurrent redeclared`.
- Made `Connection.Disconnect` close the underlying TCP connection after
sending the protocol disconnect message, so `database/sql` closes do not
leak sockets.
- Made `Connection.BeginTx` return the `START TRANSACTION` error instead of
marking the connection as in-transaction when the server rejected the begin.

285
third_party/go-irisnative/README.md vendored Normal file
View File

@@ -0,0 +1,285 @@
# go-irisnative
A Golang driver for InterSystems IRIS that implements `database/sql`.
> Project status: **alpha**. API may change. Feedback and PRs welcome.
---
## Installation
```bash
# replace the module path with the final repo path when published
go get github.com/caretdev/go-irisnative
```
Register the driver by importing it for sideeffects:
```go
import (
"database/sql"
_ "github.com/caretdev/go-irisnative" // registers driver as "iris"
)
```
## DSN formats
The driver accepts a URL-style DSN (recommended) or key=value pairs.
**URL style**
```
iris://user:password@host:1972/NAMESPACE?
```
* `host` — IRIS hostname or IP
* `1972` — superserver port (default)
* `Namespace` — IRIS namespace (e.g., `USER`)
---
## Quick start (database/sql)
```go
package main
import (
"context"
"database/sql"
"fmt"
"log"
"time"
_ "github.com/caretdev/go-irisnative"
)
func main() {
dsn := "iris://_SYSTEM:SYS@localhost:1972/USER"
db, err := sql.Open("iris", dsn)
if err != nil { log.Fatal(err) }
defer db.Close()
// Connection pool tuning
db.SetMaxOpenConns(10)
db.SetMaxIdleConns(5)
db.SetConnMaxLifetime(30 * time.Minute)
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
_, err = db.ExecContext(ctx, `DROP TABLE IF EXISTS demo_person`)
if err != nil { log.Fatal("drop table:", err) }
// 1) Create a table (id INT PRIMARY KEY, name VARCHAR(80))
_, err = db.ExecContext(ctx, `CREATE TABLE IF NOT EXISTS demo_person (
id INT PRIMARY KEY,
name VARCHAR(80) NOT NULL,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
)`)
if err != nil { log.Fatal("create table:", err) }
// 2) Insert with placeholders
res, err := db.ExecContext(ctx, `INSERT INTO demo_person(id, name) VALUES(?, ?)`, 1, "Alice")
if err != nil { log.Fatal("insert:", err) }
if n, _ := res.RowsAffected(); n > 0 { fmt.Println("inserted:", n) }
// 3) Query rows
rows, err := db.QueryContext(ctx, `SELECT id, name, created_at FROM demo_person ORDER BY id`)
if err != nil { log.Fatal("query:", err) }
defer rows.Close()
for rows.Next() {
var (
id int
name string
createdAt time.Time
)
if err := rows.Scan(&id, &name, &createdAt); err != nil { log.Fatal(err) }
fmt.Printf("row: id=%d name=%s created_at=%s\n", id, name, createdAt.Format(time.RFC3339))
}
if err := rows.Err(); err != nil { log.Fatal(err) }
// 4) Prepared statement
stmt, err := db.PrepareContext(ctx, `UPDATE demo_person SET name=? WHERE id=?`)
if err != nil { log.Fatal("prepare:", err) }
defer stmt.Close()
if _, err := stmt.ExecContext(ctx, "Alice Updated", 1); err != nil { log.Fatal("update:", err) }
// 5) Transaction example
tx, err := db.BeginTx(ctx, &sql.TxOptions{Isolation: sql.LevelReadCommitted})
if err != nil { log.Fatal("begin tx:", err) }
if _, err := tx.ExecContext(ctx, `INSERT INTO demo_person(id, name) VALUES(?, ?)`, 2, "Bob"); err != nil {
tx.Rollback()
log.Fatal("tx insert:", err)
}
if err := tx.Commit(); err != nil { log.Fatal("commit:", err) }
}
```
### Query single value helper
```go
var count int
if err := db.QueryRowContext(ctx, `SELECT COUNT(*) FROM demo_person`).Scan(&count); err != nil {
log.Fatal(err)
}
fmt.Println("count=", count)
```
---
## Using with `sqlx`
`sqlx` adds nice helpers over `database/sql` like struct scanning and named queries.
```bash
go get github.com/jmoiron/sqlx
```
```go
package main
import (
"context"
"fmt"
"log"
"time"
_ "github.com/caretdev/go-irisnative" // driver
"github.com/jmoiron/sqlx"
)
type Person struct {
ID int `db:"id"`
Name string `db:"name"`
CreatedAt time.Time `db:"created_at"`
}
func create(ctx context.Context, db *sqlx.DB) {
drop(ctx, db)
_, err := db.ExecContext(ctx, `CREATE TABLE IF NOT EXISTS demo_person (
id INT PRIMARY KEY,
name VARCHAR(80) NOT NULL,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
)`)
if err != nil {
panic(err)
}
}
func drop(ctx context.Context, db *sqlx.DB) {
_, err := db.ExecContext(ctx, `DROP TABLE IF EXISTS demo_person`)
if err != nil {
panic(err)
}
}
func main() {
ctx := context.Background()
dsn := "iris://_SYSTEM:SYS@localhost:1972/USER"
db := sqlx.MustConnect("iris", dsn)
defer db.Close()
create(ctx, db)
defer drop(ctx, db)
// Struct-based insert with NamedExec
p := Person{ID: 3, Name: "Carol"}
_, err := db.NamedExecContext(ctx,
`INSERT INTO demo_person(id, name) VALUES(:id, :name)`, p,
)
if err != nil {
log.Fatal("named insert:", err)
}
// Select into slice of structs
var people []Person
if err := db.SelectContext(ctx, &people, `SELECT id, name, created_at FROM demo_person ORDER BY id`); err != nil {
log.Fatal(err)
}
fmt.Printf("people: %#v\n", people)
// Get a single struct
var one Person
if err := db.GetContext(ctx, &one, `SELECT id, name, created_at FROM demo_person WHERE id=?`, people[0].ID); err != nil {
log.Fatal(err)
}
fmt.Printf("one: %+v\n", one)
// Named query with IN (sqlx.In)
ids := []int{1, 2, 3}
q, args, err := sqlx.In(`SELECT id, name FROM demo_person WHERE id IN (?)`, ids)
if err != nil {
log.Fatal(err)
}
q = db.Rebind(q) // ensure driver-specific bindvars
rows, err := db.QueryxContext(ctx, q, args...)
if err != nil {
log.Fatal(err)
}
defer rows.Close()
for rows.Next() {
var id int
var name string
if err := rows.Scan(&id, &name); err != nil {
log.Fatal(err)
}
fmt.Println(id, name)
}
}
```
---
## Placeholders & rebind
* The driver uses `?` positional placeholders.
* With `sqlx`, **always** call `db.Rebind(q)` after `sqlx.In(...)` to adapt placeholders.
---
## Context, timeouts & cancellations
All examples use `Context`. Set sensible timeouts to avoid runaway queries:
```go
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
```
---
## Error handling tips
* Check `rows.Err()` after iteration.
* Prefer `ExecContext`/`QueryContext` to ensure timeouts are respected.
* Wrap errors with operation context (e.g., `fmt.Errorf("create table: %w", err)`).
---
## Testing locally
1. Start IRIS and ensure SQL is enabled for your namespace (e.g., `USER`).
2. Create a SQL user with privileges to connect and create tables.
3. Verify connectivity using the DSN shown above.
---
## Compatibility
* Go: 1.21+
* InterSystems IRIS: 2025.1+
---
## License
MIT
---
## Contributing
* Run `go vet` and `go test ./...` before submitting PRs.
* Add tests for new behaviors.
* Document any DSN parameters you introduce.

85
third_party/go-irisnative/connector.go vendored Normal file
View File

@@ -0,0 +1,85 @@
package intersystems
import (
"context"
"database/sql/driver"
"errors"
"strings"
)
// Connector represents a fixed configuration for the pq driver with a given
// name. Connector satisfies the database/sql/driver Connector interface and
// can be used to create any number of DB Conn's via the database/sql OpenDB
// function.
//
// See https://golang.org/pkg/database/sql/driver/#Connector.
// See https://golang.org/pkg/database/sql/#OpenDB.
type Connector struct {
opts values
// dialer Dialer
}
// Connect returns a connection to the database using the fixed configuration
// of this Connector. Context is not used.
func (c *Connector) Connect(ctx context.Context) (driver.Conn, error) {
return c.open(ctx)
}
// Driver returns the underlying driver of this Connector.
func (c *Connector) Driver() driver.Driver {
return &Driver{}
}
// NewConnector returns a connector for the pq driver in a fixed configuration
// with the given dsn. The returned connector can be used to create any number
// of equivalent Conn's. The returned connector is intended to be used with
// database/sql.OpenDB.
//
// See https://golang.org/pkg/database/sql/driver/#Connector.
// See https://golang.org/pkg/database/sql/#OpenDB.
func NewConnector(dsn string) (*Connector, error) {
var err error
o := make(values)
// A number of defaults are applied here, in this order:
//
// * Very low precedence defaults applied in every situation
// * Environment variables
// * Explicitly passed connection information
o["host"] = "localhost"
o["port"] = "1972"
if strings.HasPrefix(dsn, "iris://") || strings.HasPrefix(dsn, "IRIS://") {
dsn, err = ParseURL(dsn)
if err != nil {
return nil, err
}
}
if err := parseOpts(dsn, o); err != nil {
return nil, err
}
if enc, ok := o["client_encoding"]; ok && !isUTF8(enc) {
return nil, errors.New("client_encoding must be absent or 'UTF8'")
}
o["client_encoding"] = "UTF8"
return &Connector{opts: o, /*dialer: defaultDialer{}*/}, nil
}
// isUTF8 returns whether name is a fuzzy variation of the string "UTF-8".
func isUTF8(name string) bool {
s := strings.Map(alnumLowerASCII, name)
return s == "utf8" || s == "unicode"
}
func alnumLowerASCII(ch rune) rune {
if 'A' <= ch && ch <= 'Z' {
return ch + ('a' - 'A')
}
if 'a' <= ch && ch <= 'z' || '0' <= ch && ch <= '9' {
return ch
}
return -1 // discard
}

216
third_party/go-irisnative/driver.go vendored Normal file
View File

@@ -0,0 +1,216 @@
package intersystems
import (
"context"
"database/sql"
"database/sql/driver"
"errors"
"fmt"
"net"
"unicode"
_ "io"
_ "math"
_ "reflect"
_ "strconv"
_ "strings"
_ "time"
_ "unsafe"
"github.com/caretdev/go-irisnative/src/connection"
)
var (
ErrCouldNotDetectUsername = errors.New("intersystems: Could not detect default username. Please provide one explicitly")
)
var (
_ driver.Driver = Driver{}
)
type values map[string]string
// Driver implements database/sql/driver.Driver.
type Driver struct{}
func (d Driver) Open(name string) (driver.Conn, error) {
return Open(name)
}
func init() {
sql.Register("intersystems", &Driver{})
sql.Register("iris", &Driver{})
}
func Open(dsn string) (_ driver.Conn, err error) {
c, err := NewConnector(dsn)
if err != nil {
return nil, err
}
return c.open(context.Background())
}
type conn struct {
c connection.Connection
tx bool
}
func (c *Connector) open(ctx context.Context) (cn *conn, err error) {
o := make(values)
for k, v := range c.opts {
o[k] = v
}
host := o["host"]
addr := net.JoinHostPort(host, o["port"])
namespace := o["namespace"]
login := o["user"]
password := o["password"]
cn = &conn{}
cn.c, err = connection.Connect(addr, namespace, login, password)
if err != nil {
return nil, err
}
return cn, nil
}
func (cn *conn) Begin() (driver.Tx, error) {
return cn.c.BeginTx(driver.TxOptions{})
}
func (cn *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
return cn.c.BeginTx(opts)
}
func (cn *conn) Close() (err error) {
cn.c.Disconnect()
return nil
}
func (cn *conn) Prepare(q string) (st driver.Stmt, err error) {
return cn.c.Prepare(q)
}
func (cn *conn) Commit() error {
if !cn.tx {
panic("transaction already closed")
}
cn.tx = false
cn.c.Commit()
return nil
}
func (cn *conn) Rollback() error {
if !cn.tx {
panic("transaction already closed")
}
cn.tx = false
cn.c.Rollback()
return nil
}
func (cn *conn) Exec(query string, args []driver.NamedValue) (res driver.Result, err error) {
parameters := make([]interface{}, len(args))
for i, a := range args {
parameters[i] = a
}
_, err = cn.c.DirectUpdate(query, parameters...)
if err != nil {
return nil, err
}
return res, nil
}
func (cn *conn) Query(query string, args []driver.NamedValue) (rows driver.Rows, err error) {
parameters := make([]interface{}, len(args))
for i, a := range args {
parameters[i] = a
}
// var rs *connection.ResultSet
_, err = cn.c.Query(query, parameters...)
if err != nil {
return nil, err
}
// rows = &connection.Rows{
// cn: cn.c,
// rs: rs,
// }
return
}
func parseOpts(name string, o values) error {
s := newScanner(name)
for {
var (
keyRunes, valRunes []rune
r rune
ok bool
)
if r, ok = s.SkipSpaces(); !ok {
break
}
// Scan the key
for !unicode.IsSpace(r) && r != '=' {
keyRunes = append(keyRunes, r)
if r, ok = s.Next(); !ok {
break
}
}
// Skip any whitespace if we're not at the = yet
if r != '=' {
r, ok = s.SkipSpaces()
}
// The current character should be =
if r != '=' || !ok {
return fmt.Errorf(`missing "=" after %q in connection info string"`, string(keyRunes))
}
// Skip any whitespace after the =
if r, ok = s.SkipSpaces(); !ok {
// If we reach the end here, the last value is just an empty string as per libpq.
o[string(keyRunes)] = ""
break
}
if r != '\'' {
for !unicode.IsSpace(r) {
if r == '\\' {
if r, ok = s.Next(); !ok {
return fmt.Errorf(`missing character after backslash`)
}
}
valRunes = append(valRunes, r)
if r, ok = s.Next(); !ok {
break
}
}
} else {
quote:
for {
if r, ok = s.Next(); !ok {
return fmt.Errorf(`unterminated quoted string literal in connection string`)
}
switch r {
case '\'':
break quote
case '\\':
r, _ = s.Next()
fallthrough
default:
valRunes = append(valRunes, r)
}
}
}
o[string(keyRunes)] = string(valRunes)
}
return nil
}

5
third_party/go-irisnative/go.mod vendored Normal file
View File

@@ -0,0 +1,5 @@
module github.com/caretdev/go-irisnative
go 1.24.3
require github.com/shopspring/decimal v1.4.0

1
third_party/go-irisnative/go.sum vendored Normal file
View File

@@ -0,0 +1 @@
github.com/shopspring/decimal v1.4.0/go.mod h1:gawqmDU56v4yIKSwfBSFip1HdCCXN8/+DMd9qYNcwME=

34
third_party/go-irisnative/scanner.go vendored Normal file
View File

@@ -0,0 +1,34 @@
package intersystems
import "unicode"
type scanner struct {
s []rune
i int
}
// newScanner returns a new scanner initialized with the option string s.
func newScanner(s string) *scanner {
return &scanner{[]rune(s), 0}
}
// Next returns the next rune.
// It returns 0, false if the end of the text has been reached.
func (s *scanner) Next() (rune, bool) {
if s.i >= len(s.s) {
return 0, false
}
r := s.s[s.i]
s.i++
return r, true
}
// SkipSpaces returns the next non-whitespace rune.
// It returns 0, false if the end of the text has been reached.
func (s *scanner) SkipSpaces() (rune, bool) {
r, ok := s.Next()
for unicode.IsSpace(r) && ok {
r, ok = s.Next()
}
return r, ok
}

View File

@@ -0,0 +1,113 @@
package connection
import "github.com/caretdev/go-irisnative/src/iris"
func (c *Connection) ServerVersion() (result string, err error) {
err = c.ClassMethod("%SYSTEM.Version", "GetVersion", &result)
return
}
func (c *Connection) ClassMethod(class, method string, result interface{}, args ...interface{}) (err error) {
msg := NewMessage(CLASSMETHOD_VALUE)
msg.Set(class)
msg.Set(method)
msg.Set(len(args))
for _, arg := range args {
msg.Set(arg)
}
_, err = c.conn.Write(msg.Dump(c.count()))
if err != nil {
return
}
msg, err = ReadMessage(c.conn)
if err != nil {
return
}
msg.Get(result)
return
}
func (c *Connection) ClassMethodVoid(class, method string, args ...interface{}) (err error) {
msg := NewMessage(CLASSMETHOD_VOID)
msg.Set(class)
msg.Set(method)
msg.Set(len(args))
for _, arg := range args {
msg.Set(arg)
}
_, err = c.conn.Write(msg.Dump(c.count()))
if err != nil {
return
}
msg, err = ReadMessage(c.conn)
if err != nil {
return
}
return
}
func (c *Connection) Method(obj iris.Oref, method string, result interface{}, args ...interface{}) (err error) {
msg := NewMessage(METHOD_VALUE)
msg.Set(obj)
msg.Set(method)
msg.Set(len(args))
for _, arg := range args {
msg.Set(arg)
}
_, err = c.conn.Write(msg.Dump(c.count()))
if err != nil {
return
}
msg, err = ReadMessage(c.conn)
if err != nil {
return
}
msg.Get(result)
return
}
func (c *Connection) MethodVoid(obj, method string, args ...interface{}) (err error) {
msg := NewMessage(METHOD_VOID)
msg.Set(obj)
msg.Set(method)
msg.Set(len(args))
for _, arg := range args {
msg.Set(arg)
}
_, err = c.conn.Write(msg.Dump(c.count()))
if err != nil {
return
}
msg, err = ReadMessage(c.conn)
if err != nil {
return
}
return
}
func (c *Connection) PropertyGet(obj iris.Oref, property string, result interface{}) (err error) {
msg := NewMessage(PROPERTY_GET)
msg.Set(obj)
msg.Set(property)
// msg.Set(0)
_, err = c.conn.Write(msg.Dump(c.count()))
if err != nil {
return
}
msg, err = ReadMessage(c.conn)
if err != nil {
return
}
msg.Get(result)
return
}

View File

@@ -0,0 +1,141 @@
package connection
func (c *Connection) GlobalIsDefined(global string, subs ...interface{}) (bool, bool) {
msg := NewMessage(GLOBAL_DATA)
msg.Set(global)
msg.Set(len(subs))
for _, sub := range subs {
msg.Set(sub)
}
msg.Set(0)
_, err := c.conn.Write(msg.Dump(c.count()))
if err != nil {
return false, false
}
msg, err = ReadMessage(c.conn)
if err != nil {
return false, false
}
var result uint8
msg.Get(&result)
return result%10 == 1, result >= 10
}
func (c *Connection) GlobalSet(global string, value interface{}, subs ...interface{}) (err error) {
msg := NewMessage(GLOBAL_SET)
msg.Set(global)
msg.Set(len(subs))
for _, sub := range subs {
msg.Set(sub)
}
msg.Set(value)
_, err = c.conn.Write(msg.Dump(c.count()))
if err != nil {
return
}
_, err = ReadMessage(c.conn)
if err != nil {
return
}
return
}
func (c *Connection) GlobalKill(global string, subs ...interface{}) (err error) {
msg := NewMessage(GLOBAL_KILL)
msg.Set(global)
msg.Set(len(subs))
for _, sub := range subs {
msg.Set(sub)
}
_, err = c.conn.Write(msg.Dump(c.count()))
if err != nil {
return
}
_, err = ReadMessage(c.conn)
if err != nil {
return
}
return
}
func (c *Connection) GlobalGet(global string, result interface{}, subs ...interface{}) (err error) {
msg := NewMessage(GLOBAL_GET)
msg.Set(global)
msg.Set(len(subs))
for _, sub := range subs {
msg.Set(sub)
}
_, err = c.conn.Write(msg.Dump(c.count()))
if err != nil {
return
}
msg, err = ReadMessage(c.conn)
if err != nil {
return
}
msg.Get(result)
return
}
func (c *Connection) GlobalNext(global string, ind *string, subs ...interface{}) (hasNext bool, err error) {
msg := NewMessage(GLOBAL_ORDER)
msg.Set(global)
msg.Set(len(subs) + 1)
for _, sub := range subs {
msg.Set(sub)
}
msg.Set(*ind)
msg.Set(3)
if _, err = c.conn.Write(msg.Dump(c.count())); err != nil {
return
}
if msg, err = ReadMessage(c.conn); err != nil {
return
}
var result string
msg.Get(&result)
*ind = result
hasNext = result != ""
return
}
func (c *Connection) GlobalPrev(global string, ind *string, subs ...interface{}) (hasNext bool, err error) {
msg := NewMessage(GLOBAL_ORDER)
msg.Set(global)
msg.Set(len(subs) + 1)
for _, sub := range subs {
msg.Set(sub)
}
msg.Set(*ind)
msg.Set(7)
if _, err = c.conn.Write(msg.Dump(c.count())); err != nil {
return
}
if msg, err = ReadMessage(c.conn); err != nil {
return
}
var result string
msg.Get(&result)
*ind = result
hasNext = result != ""
return
}

View File

@@ -0,0 +1,148 @@
package connection
import (
"fmt"
"io"
"net"
"github.com/caretdev/go-irisnative/src/list"
)
type Message struct {
header MessageHeader
data []byte
offset uint
}
func NewMessage(messageType MessageType) Message {
return Message{
NewMessageHeader(messageType),
[]byte{},
0,
}
}
func ReadMessage(conn *net.TCPConn) (msg Message, err error) {
buffer := make([]byte, 14)
_, err = conn.Read(buffer)
if err != nil {
return
}
var header [14]byte
copy(header[:], buffer[:14])
var msgHeader = MessageHeader{header}
length := msgHeader.GetLength()
data := make([]byte, length)
var offset int = 0
var size int
for {
size, err = conn.Read(data[offset:])
if err != nil {
if err != io.EOF {
return
}
break
}
offset += size
if offset >= int(length) {
break
}
}
msg = Message{msgHeader, data, 0}
return
}
func (m *Message) AddRaw(value interface{}) {
switch v := value.(type) {
case uint16:
m.data = append(m.data, byte(v&0xff))
m.data = append(m.data, byte(v>>8&0xff))
m.offset += 2
case []byte:
m.data = append(m.data, v...)
m.offset += uint(len(v))
}
}
func (m *Message) GetRaw(value interface{}) error {
switch v := value.(type) {
case *uint16:
*v = uint16(m.data[m.offset]) | (uint16(m.data[m.offset+1]) << 8)
m.offset += 2
case *bool:
*v = (uint16(m.data[m.offset]) | (uint16(m.data[m.offset+1]) << 8)) == 1
m.offset += 2
case *[]byte:
*v = m.data[m.offset:]
m.offset = uint(len(m.data))
default:
return fmt.Errorf("unknown type: %T", v)
}
return nil
}
func (m *Message) Set(value interface{}) error {
listItem := list.NewListItem(value)
m.AddRaw(listItem.Dump())
return nil
}
func (m *Message) SetSQLText(sqlText string) error {
len := len(sqlText)
if len == 0 {
m.Set(sqlText)
return nil
}
const chunksize = 31904
chunks := len / chunksize
if len%chunksize != 0 {
chunks += 1
}
m.Set(chunks)
for i := 0; i < chunks; i++ {
begin := i * chunksize
end := (i + 1) * chunksize
if end > len {
end = len
}
m.Set(sqlText[begin:end])
}
return nil
}
func (m *Message) GetStatus() uint16 {
return m.header.GetStatus()
}
func (m *Message) Get(value interface{}) error {
listItem := list.GetListItem(m.data, &m.offset)
listItem.Get(value)
return nil
}
type AnyType struct {
listItem list.ListItem
}
func (v *AnyType) Int() int {
var value int
v.listItem.Get(&value)
return value
}
func (m *Message) GetAny() AnyType {
listItem := list.GetListItem(m.data, &m.offset)
return AnyType{listItem}
}
func (m *Message) Dump(count uint32) []byte {
m.header.SetCount(count)
m.header.SetLength(uint32(len(m.data)))
return append(m.header.header[:], m.data...)
}

View File

@@ -0,0 +1,84 @@
package connection
type MessageType string
func setUint32(buffer []byte, value uint32) {
buffer[0] = byte(value & 0xff)
buffer[1] = byte(value >> 8 & 0xff)
buffer[2] = byte(value >> 16 & 0xff)
buffer[3] = byte(value >> 24 & 0xff)
}
func getUint32(buffer []byte) uint32 {
return uint32(buffer[0]) |
uint32(buffer[1])<<8 |
uint32(buffer[2])<<16 |
uint32(buffer[3])<<24
}
const (
CONNECT MessageType = "\x43\x4e"
HANDSHAKE MessageType = "\x48\x53"
DISCONNECT MessageType = "\x44\x43"
GLOBAL_GET MessageType = "\x41\xc2"
GLOBAL_SET MessageType = "\x42\xc2"
GLOBAL_KILL MessageType = "\x43\xc2"
GLOBAL_ORDER MessageType = "\x45\xc2"
GLOBAL_DATA MessageType = "\x49\xc2"
CLASSMETHOD_VALUE MessageType = "\x4b\xc2"
CLASSMETHOD_VOID MessageType = "\x4c\xc2"
METHOD_VALUE MessageType = "\x5b\xc2"
METHOD_VOID MessageType = "\x5c\xc2"
PROPERTY_GET MessageType = "\x5d\xc2"
PROPERTY_SET MessageType = "\x5e\xc2"
DIRECT_QUERY MessageType = "DQ"
PREPARED_QUERY MessageType = "PQ"
DIRECT_UPDATE MessageType = "DU"
PREPARED_UPDATE MessageType = "PU"
PREPARE MessageType = "PP"
GET_AUTO_GENERATED_KEYS MessageType = "GG"
COMMIT MessageType = "TC"
ROLLBACK MessageType = "TR"
MULTIPLE_RESULT_SETS_FETCH_DATA MessageType = "MD"
GET_MORE_RESULTS MessageType = "MR"
FETCH_DATA MessageType = "FD"
GET_SERVER_ERROR MessageType = "OE"
)
type MessageHeader struct {
header [14]byte
}
func NewMessageHeader(messageType MessageType) MessageHeader {
header := [14]byte{}
header[12] = messageType[0]
header[13] = messageType[1]
return MessageHeader{header}
}
func (mh *MessageHeader) GetStatus() uint16 {
return uint16(mh.header[12]) | (uint16(mh.header[13]) << 8)
}
func (mh *MessageHeader) SetLength(length uint32) {
setUint32(mh.header[0:], length)
}
func (mh MessageHeader) GetLength() uint32 {
return getUint32(mh.header[0:])
}
func (mh *MessageHeader) SetCount(cnt uint32) {
setUint32(mh.header[4:], cnt)
}
func (mh *MessageHeader) SetStatementId(statementId uint32) {
setUint32(mh.header[8:], statementId)
}

View File

@@ -0,0 +1,251 @@
package connection
import (
"database/sql/driver"
"errors"
"net"
)
const VERSION_PROTOCOL uint16 = 69
type Connection struct {
conn *net.TCPConn
messageCount uint32
statement uint32
unicode bool
locale string
version uint16
info string
featureOptions uint
tx bool
}
var (
ErrCouldNotDetectUsername = errors.New("intersystems: Could not detect default username. Please provide one explicitly")
errBeginTx = errors.New("could not begin transaction")
errMultipleTx = errors.New("multiple transactions")
errReadOnlyTxNotSupported = errors.New("read-only transactions are not supported")
)
func Connect(addr string, namespace, login, password string) (connection Connection, err error) {
tcpAddr, err := net.ResolveTCPAddr("tcp", addr)
if err != nil {
return
}
conn, err := net.DialTCP("tcp", nil, tcpAddr)
if err != nil {
return
}
connection = Connection{
conn: conn,
}
if err = connection.handshake(); err != nil {
return
}
if err = connection.connect(namespace, login, password); err != nil {
return
}
// fmt.Println(connection.version, connection.info)
return
}
func (c *Connection) Disconnect() {
if c.conn == nil {
return
}
msg := NewMessage(DISCONNECT)
_, _ = c.conn.Write(msg.Dump(c.count()))
_ = c.conn.Close()
c.conn = nil
}
func (c *Connection) count() uint32 {
count := c.messageCount
c.messageCount += 1
return count
}
func (c *Connection) statementId() uint32 {
statement := c.statement
c.statement += 1
return statement
}
func (c *Connection) handshake() (err error) {
var message = NewMessage(HANDSHAKE)
message.AddRaw(VERSION_PROTOCOL)
_, err = c.conn.Write(message.Dump(c.count()))
if err != nil {
return
}
msg, err := ReadMessage(c.conn)
if err != nil {
return
}
var version uint16
msg.GetRaw(&version)
c.version = version
var unicode uint16
msg.GetRaw(&unicode)
c.unicode = unicode == 1
var locale string
msg.Get(&locale)
c.locale = locale
return
}
func encode(value string) []byte {
in := []byte(value)
length := len(in)
out := make([]byte, length)
for i := range in {
length--
temp := ((int(in[i])^0xa7)&0xff + length) & 0xff
out[length] = byte(temp<<5 | temp>>3)
}
return out
}
type FeatureOption uint
const (
OptionNone FeatureOption = 0
OptionFastSelect FeatureOption = 1
OptionFastInsert FeatureOption = 2
OptionFastSelectAndInsert FeatureOption = 3
OptionDurableTransactions FeatureOption = 4
OptionNotNullable FeatureOption = 8
OptionRedirectOutput FeatureOption = 32
)
func (c *Connection) IsOptionFastInsert() bool {
return c.featureOptions&uint(OptionFastInsert) == uint(OptionFastInsert)
}
func (c *Connection) IsOptionFastSelect() bool {
return c.featureOptions&uint(OptionFastSelect) == uint(OptionFastSelect)
}
func (c *Connection) connect(namespace, login, password string) (err error) {
msg := NewMessage(CONNECT)
msg.Set(namespace)
msg.Set(encode(login))
msg.Set(encode(password))
var user = "go"
if user, err = systemUser(); err != nil {
user = "go"
}
msg.Set(user) // machine user name
msg.Set("go-machine") // machine name
msg.Set("libirisnative") // application name
msg.Set("") // ?
msg.Set("go") // SharedMemoryFlag?
msg.Set("") // EventClass
msg.Set(1) // AutoCommit ? 1 : 2
msg.Set(0) // IsolationLevel
var featureOptions = OptionNone
featureOptions += OptionFastSelect
// Tricky to make it fully working yet
// featureOptions += OptionFastInsert
featureOptions += OptionDurableTransactions
featureOptions += OptionRedirectOutput
msg.Set(int(featureOptions)) // FeatureOption
_, err = c.conn.Write(msg.Dump(c.count()))
if err != nil {
return
}
msg, err = ReadMessage(c.conn)
if err != nil {
return
}
if status := msg.GetStatus(); status == 417 {
var errorMsg string
msg.Get(&errorMsg)
err = errors.New(errorMsg)
return
}
var info string
msg.Get(&info)
c.info = info
var (
delimited_ids bool
ignored int
isolationLevel int
serverJobNumber string
sqlEmptyString int
serverFeatureOptions uint
)
msg.Get(&delimited_ids)
msg.Get(&ignored)
msg.Get(&isolationLevel)
msg.Get(&serverJobNumber)
msg.Get(&sqlEmptyString)
msg.Get(&serverFeatureOptions)
c.featureOptions = serverFeatureOptions
return
}
func systemUser() (string, error) {
u, err := userCurrent()
if err != nil {
return "", err
}
return u, nil
}
func (c *Connection) Commit() (err error) {
msg := NewMessage(COMMIT)
_, err = c.conn.Write(msg.Dump(c.count()))
if err != nil {
return
}
_, err = ReadMessage(c.conn)
if err != nil {
return
}
return
}
func (c *Connection) Rollback() (err error) {
msg := NewMessage(ROLLBACK)
_, err = c.conn.Write(msg.Dump(c.count()))
if err != nil {
return
}
_, err = ReadMessage(c.conn)
if err != nil {
return
}
return
}
func (c *Connection) BeginTx(opts driver.TxOptions) (driver.Tx, error) {
if c.tx {
return nil, errors.Join(errBeginTx, errMultipleTx)
}
if opts.ReadOnly {
return nil, errors.Join(errBeginTx, errReadOnlyTxNotSupported)
}
if _, err := c.DirectUpdate("START TRANSACTION"); err != nil {
return nil, errors.Join(errBeginTx, err)
}
c.tx = true
return &tx{c}, nil
}

View File

@@ -0,0 +1,101 @@
package connection
import (
"database/sql/driver"
"errors"
"strings"
)
var (
errNoRowsAffected = errors.New("no RowsAffected available after the empty statement")
errNoLastInsertID = errors.New("no LastInsertId available after the empty statement")
)
type Result struct {
cn *Connection
affected int64
}
func (r Result) LastInsertId() (lastId int64, err error) {
// var msg Message
// msg = NewMessage(GET_AUTO_GENERATED_KEYS)
// msg.header.SetStatementId(r.cn.statementId())
// _, err = r.cn.conn.Write(msg.Dump(r.cn.count()))
// if err != nil {
// return
// }
// msg, err = ReadMessage(r.cn.conn)
// if err != nil {
// return
// }
// msg.Get(&lastId)
// return
var rs *ResultSet
rs, err = r.cn.DirectQuery("SELECT LAST_IDENTITY()")
if err != nil {
return
}
row, err := rs.Next()
if err != nil {
return
}
lastId = int64(row[0].(int))
return
}
func (r Result) RowsAffected() (int64, error) {
return r.affected, nil
}
type Rows struct {
cn *Connection
rs *ResultSet
}
type noRows struct{}
var emptyRows noRows
var _ driver.Result = noRows{}
func (noRows) LastInsertId() (int64, error) {
return 0, errNoLastInsertID
}
func (noRows) RowsAffected() (int64, error) {
return 0, errNoRowsAffected
}
func (r *Rows) Close() error {
return nil
}
func (r *Rows) Columns() []string {
if r.rs == nil {
return []string{}
}
columns := r.rs.Columns()
colNames := make([]string, len(columns))
for k, c := range columns {
colname := c.Name()
// tricking IRIS
colname = strings.ReplaceAll(colname, "﹒", ".")
colNames[k] = colname
}
// fmt.Printf("Columns: %#v\n", colNames)
return colNames
}
func (r *Rows) Next(dest []driver.Value) (err error) {
row, err := r.rs.Next()
if err != nil {
return err
}
for i := range dest {
dest[i] = row[i]
}
// fmt.Printf("RowsNext: %#v\n", dest)
return nil
}

View File

@@ -0,0 +1,729 @@
package connection
import (
"database/sql/driver"
"fmt"
"io"
"slices"
"strconv"
"strings"
"time"
"github.com/caretdev/go-irisnative/src/list"
)
const timeLaylout = "2006-01-02 15:04:05.000000000"
const timeLayloutShort = "2006-01-02 15:04:05"
type StatementFeature struct {
featureOption int
msgCount int
maxRowItemCount int
}
type Column struct {
name string
column_type int
precision int
scale int
nullable int
slot_position int
label string
table_name string
schema string
catalog string
is_auto_increment bool
is_case_sensitive bool
is_currency bool
is_read_only bool
is_row_id bool
}
type SQLTYPE int16
const (
GUID SQLTYPE = -11
WLONGVARCHAR SQLTYPE = -10
WVARCHAR SQLTYPE = -9
WCHAR SQLTYPE = -8
BIT SQLTYPE = -7
TINYINT SQLTYPE = -6
BIGINT SQLTYPE = -5
LONGVARBINARY SQLTYPE = -4
VARBINARY SQLTYPE = -3
BINARY SQLTYPE = -2
LONGVARCHAR SQLTYPE = -1
CHAR SQLTYPE = 1
NUMERIC SQLTYPE = 2
DECIMAL SQLTYPE = 3
INTEGER SQLTYPE = 4
SMALLINT SQLTYPE = 5
FLOAT SQLTYPE = 6
REAL SQLTYPE = 7
DOUBLE SQLTYPE = 8
DATE SQLTYPE = 9
TIME SQLTYPE = 10
TIMESTAMP SQLTYPE = 11
VARCHAR SQLTYPE = 12
TYPE_DATE SQLTYPE = 91
TYPE_TIME SQLTYPE = 92
TYPE_TIMESTAMP SQLTYPE = 93
DATE_HOROLOG SQLTYPE = 1091
TIME_HOROLOG SQLTYPE = 1092
TIMESTAMP_POSIX SQLTYPE = 1093
)
func (c Column) Name() string {
return c.name
}
type ResultSet struct {
c *Connection
columns []Column
sf StatementFeature
count int
data []byte
offset uint
sqlCode int16
}
type SQLError struct {
SQLCode int16
Message string
}
func (e *SQLError) Error() string {
return fmt.Sprintf("Error Code: %d, Message: %s", e.SQLCode, e.Message)
}
// func SQLError(code int) error {
// return &SQLError{SQLCode: code}
// }
func (rs ResultSet) Columns() []Column {
return rs.columns
}
func statementFeature(msg *Message) StatementFeature {
featureOption := 0
msgCount := 0
maxRowItemCount := 0
msg.Get(&featureOption)
if featureOption == 2 {
msg.Get(&msgCount)
}
if featureOption == 1 || featureOption == 2 {
msg.Get(&maxRowItemCount)
}
return StatementFeature{
featureOption,
msgCount,
maxRowItemCount,
}
}
type Value interface{}
// type ResultSetRow struct{}
func (rs *ResultSet) fetchMoreData() bool {
msg := NewMessage(FETCH_DATA)
_, err := rs.c.conn.Write(msg.Dump(rs.c.count()))
if err != nil {
panic(err)
}
msg, err = ReadMessage(rs.c.conn)
if err != nil {
panic(err)
}
rs.data = msg.data
rs.offset = 0
return len(msg.data) > 0
}
func fromODBC(coltype SQLTYPE, li list.ListItem) (result interface{}, err error) {
result = nil
if li.IsNull() || li.IsEmpty() {
return
}
switch coltype {
case VARCHAR:
if li.DataLength() == 0 {
return
}
var value string
li.Get(&value)
if value == "\x00" {
value = ""
}
result = value
case INTEGER, TINYINT, SMALLINT:
var value int
li.Get(&value)
result = value
case BIGINT:
var value int64
li.Get(&value)
result = value
case BIT:
var value bool
li.Get(&value)
result = value
case FLOAT:
var value float32
li.Get(&value)
result = value
case DOUBLE:
var value float64
li.Get(&value)
result = value
case TIMESTAMP_POSIX:
if li.DataLength() == 0 {
return
}
if li.Type() == list.LISTITEM_STRING {
var strval string
li.Get(&strval)
result, err = time.Parse(timeLaylout, strval)
if err == nil {
return
}
err = nil
}
var value int64
li.Get(&value)
if value > 0 {
value ^= 0x1000000000000000
} else {
value |= 0x6000000000000000
}
seconds := value / 1000000
nano := value % 1000000 * 1000
result = time.Unix(seconds, nano).In(time.Local)
case VARBINARY:
// var value []uint8
var value string
li.Get(&value)
case TYPE_TIMESTAMP:
var strval string
li.Get(&strval)
result, err = time.Parse(timeLayloutShort, strval)
default:
var value string
li.Get(&value)
fmt.Printf("fromODBC: invalid type: %v - %#v - %#v", coltype, li, value)
result = value
}
return
}
func (rs *ResultSet) Next() ([]Value, error) {
if rs == nil || (rs.sqlCode != 0 && rs.sqlCode != 100) {
return nil, io.EOF
}
if rs.offset >= uint(len(rs.data)) && (rs.sqlCode == 100 || !rs.fetchMoreData()) {
return nil, io.EOF
}
row := make([]Value, rs.count)
data := rs.data
count := rs.count
var offset uint = rs.offset
if rs.sf.featureOption == 1 {
li := list.GetListItem(data, &rs.offset)
li.Get(&data)
offset = 0
count = rs.sf.maxRowItemCount
}
vals := make([]list.ListItem, count)
for i := 0; i < count; i++ {
li := list.GetListItem(data, &offset)
vals[i] = li
}
if rs.sf.featureOption != 1 {
rs.offset = offset
}
var err error
for i, c := range rs.columns {
li := vals[c.slot_position]
row[i], err = fromODBC(SQLTYPE(c.column_type), li)
if err != nil {
return nil, err
}
// fmt.Printf("col: %s: %d; %#v - %#v\n", c.name, c.column_type, row[i], li)
}
// fmt.Printf("row: %#v\n", row)
return row, nil
}
func (c *Connection) getErrorInfo(sqlCode int16) string {
msg := NewMessage(GET_SERVER_ERROR)
msg.Set(sqlCode)
_, err := c.conn.Write(msg.Dump(c.count()))
if err != nil {
panic(err)
}
msg, err = ReadMessage(c.conn)
if err != nil {
panic(err)
}
var sqlMessage string
msg.Get(&sqlMessage)
return sqlMessage
}
func getColumns(msg *Message, statementFeature StatementFeature) []Column {
cnt := 0
msg.Get(&cnt)
columns := make([]Column, cnt)
for i := 0; i < cnt; i++ {
column := Column{}
msg.Get(&column.name)
msg.Get(&column.column_type)
switch column.column_type {
case 9:
column.column_type = 91
case 10:
column.column_type = 92
case 11:
column.column_type = 93
}
msg.Get(&column.precision)
msg.Get(&column.scale)
msg.Get(&column.nullable)
msg.Get(&column.label)
msg.Get(&column.table_name)
msg.Get(&column.schema)
msg.Get(&column.catalog)
additional := ""
msg.Get(&additional)
if statementFeature.featureOption&0x01 == 1 {
msg.Get(&column.slot_position)
column.slot_position -= 1
} else {
column.slot_position = i
}
column.is_auto_increment = additional[0] == 0x01
column.is_case_sensitive = additional[1] == 0x01
column.is_currency = additional[2] == 0x01
column.is_read_only = additional[3] == 0x01
if len(additional) >= 12 {
column.is_row_id = additional[11] == 0x01
}
columns[i] = column
}
return columns
}
func parameterInfo(msg *Message) {
cnt := 0
msg.Get(&cnt)
flag := 0
msg.Get(&flag)
}
func toODBC(value interface{}) interface{} {
var val interface{}
switch v := value.(type) {
case *string:
val = *v
case string:
val = v
if v == "" {
val = "\x00"
}
case nil:
val = ""
case bool:
if v {
val = 1
} else {
val = 0
}
case time.Time:
val = v.UTC().Format(timeLaylout)
case int, int8, int16, int32, int64:
val = v
case float32, float64:
val = v
case []uint8:
val = v
default:
fmt.Printf("unsupported type: %T\n", v)
val = fmt.Sprintf("%v", v)
}
return val
}
func writeParameters(msg *Message, args ...interface{}) {
msg.Set(len(args))
for range args {
msg.Set(99)
msg.Set(4)
}
msg.Set(1) // parameterSets
msg.Set(len(args))
for _, arg := range args {
msg.Set(toODBC(arg))
}
}
func (c *Connection) Query(sqlText string, args ...interface{}) (rs *ResultSet, err error) {
queries := strings.Split(sqlText, ";\n")
if len(queries) == 2 {
sqlText = queries[0]
_, err = c.DirectUpdate(sqlText, args...)
if err != nil {
return
}
sqlText = queries[1]
args = []interface{}{}
}
rs, err = c.DirectQuery(sqlText, args...)
if err != nil {
return
}
return
}
func (c *Connection) DirectQuery(sqlText string, args ...interface{}) (*ResultSet, error) {
sqlText, _, args = FormatQuery(sqlText, args...)
// fmt.Printf("DirectQuery: %s; %#v\n", sqlText, args)
var statementId = c.statementId()
msg := NewMessage(DIRECT_QUERY)
msg.header.SetStatementId(statementId)
msg.SetSQLText(sqlText)
writeParameters(&msg, args...)
msg.Set(10) // Query timeout
msg.Set(200) // Max rows
_, err := c.conn.Write(msg.Dump(c.count()))
if err != nil {
return nil, err
}
msg, err = ReadMessage(c.conn)
if err != nil {
return nil, err
}
sqlCode := int16(msg.GetStatus())
if sqlCode != 0 && sqlCode != 100 {
return nil, &SQLError{SQLCode: sqlCode, Message: c.getErrorInfo(sqlCode)}
}
statementFeature := statementFeature(&msg)
columns := getColumns(&msg, statementFeature)
parameterInfo((&msg))
rs := &ResultSet{
c: c,
sf: statementFeature,
columns: columns,
count: len(columns),
}
msg, err = ReadMessage(c.conn)
rs.sqlCode = int16(msg.GetStatus())
if err != nil {
return nil, err
}
msg.GetRaw(&rs.data)
return rs, nil
}
func (m Message) debug() string {
var sb strings.Builder
for i, b := range m.data {
if i > 0 {
sb.WriteString(", ")
}
sb.WriteString(strconv.Itoa(int(b)))
}
return fmt.Sprintf("$char(%s)", sb.String())
}
func FormatQuery(sqlText string, args ...interface{}) (string, int, []interface{}) {
var count int
for i := range args {
count++
sqlText = strings.Replace(sqlText, "?", fmt.Sprintf(" :%%qpar(%d) ", i+1), 1)
if !strings.Contains(sqlText, "?") {
break
}
}
return sqlText, count, args
}
func (c *Connection) Exec(sqlText string, args ...interface{}) (res *Result, err error) {
queries := strings.Split(sqlText, ";\n")
var onConflict = ""
if len(queries) == 2 {
sqlText = queries[0]
onConflict = strings.Split(queries[1], "-- ")[1]
if strings.Contains(onConflict, "ON CONFLICT UPDATE") {
// fmt.Printf("------\n%s\n%#v\n------\n", sqlText, args)
sqlText = strings.Replace(sqlText, "INSERT INTO", "INSERT OR UPDATE", 1)
onConflict = ""
}
}
res, err = c.DirectUpdate(sqlText, args...)
if err != nil {
if strings.Contains(onConflict, "ON CONFLICT DO NOTHING") {
res = &Result{cn: c, affected: 0}
err = nil
return
}
}
return
}
func (c *Connection) DirectUpdate(sqlText string, args ...interface{}) (*Result, error) {
var batchSize int
sqlText, batchSize, args = FormatQuery(sqlText, args...)
// fmt.Printf("DirectUpdate: %s; %#v\n", sqlText, args)
var batches = 1
if batchSize > 0 {
batches = len(args) / batchSize
}
var addToCache = false
var statementId = c.statementId()
var executeMany = false
var optFastInsert = false
var rowsAffected int64 = 0
var identityColumn = false
var defaults = []interface{}{}
for i := 1; i <= batches; i++ {
if i > 1 && executeMany {
break
}
var msg Message
if !addToCache {
msg = NewMessage(DIRECT_UPDATE)
msg.SetSQLText(sqlText)
msg.Set(batchSize)
for j := 0; j < batchSize; j++ {
msg.Set(99)
msg.Set(1)
}
// msg.Set(len(args))
// for range args {
// msg.Set(99)
// msg.Set(1)
// }
} else {
msg = NewMessage(PREPARED_UPDATE)
}
if addToCache && !executeMany && optFastInsert {
msg.AddRaw([]byte{1, 0, 0, 0})
msg.Set("")
msg.Set(0)
if identityColumn {
msg.Set(2)
msg.Set("")
} else {
msg.Set(1)
}
var batch []interface{} = make([]interface{}, batchSize)
copy(batch, args)
args = slices.Delete(args, 0, batchSize)
var params []byte
var item list.ListItem
for _, arg := range batch {
item = list.NewListItem(toODBC(arg))
params = append(params, item.Dump()...)
}
for _, arg := range defaults {
item = list.NewListItem(toODBC(arg))
params = append(params, item.Dump()...)
}
msg.Set(params)
} else {
msg.Set("")
msg.Set(0)
if executeMany {
msg.Set(batches)
for k := 0; k < batches; k++ {
msg.Set(batchSize)
for j := 0; j < batchSize; j++ {
var idx = (k * batchSize) + j
msg.Set(toODBC(args[idx]))
}
}
} else {
var batch []interface{} = make([]interface{}, batchSize)
copy(batch, args)
args = slices.Delete(args, 0, batchSize)
msg.Set(1)
msg.Set(len(batch))
for _, arg := range batch {
msg.Set(toODBC(arg))
}
}
}
msg.header.SetStatementId(statementId)
_, err := c.conn.Write(msg.Dump(c.count()))
if err != nil {
return nil, err
}
msg, err = ReadMessage(c.conn)
if err != nil {
// fmt.Println("DirectUpdate:Readmessage: ", err)
return nil, err
}
sqlCode := int16(msg.GetStatus())
if sqlCode != 0 && sqlCode != 100 {
return nil, &SQLError{SQLCode: sqlCode, Message: c.getErrorInfo(sqlCode)}
}
if i == 1 {
if c.IsOptionFastInsert() {
stmtFeatureOption, _ := c.checkStatementFeature(&msg)
optFastInsert = stmtFeatureOption&uint(OptionFastInsert) == uint(OptionFastInsert)
}
addToCache, identityColumn, defaults = c.getParameterInfo(&msg, optFastInsert)
}
var batchRows int64
msg.Get(&batchRows)
rowsAffected += batchRows
}
result := &Result{cn: c, affected: rowsAffected}
return result, nil
}
func (c *Connection) checkStatementFeature(msg *Message) (featureOption uint, count uint) {
count = 0
var keyCount int
msg.Get(&featureOption)
if featureOption == uint(OptionFastSelect) || featureOption == uint(OptionFastInsert) {
if featureOption == uint(OptionFastInsert) {
msg.Get(&keyCount)
}
msg.Get(&count)
}
return
}
func (c *Connection) getParameterInfo(msg *Message, optFastInsert bool) (addToCache bool, identityColumn bool, defaults []interface{}) {
var paramscnt int
msg.Get(&paramscnt)
var tablename string
for i := 0; i < paramscnt; i++ {
var (
paramtype int
precision int
scale int
nullable bool
position int
someval1 string
someval2 string
colname string
)
msg.Get(&paramtype)
msg.Get(&precision)
msg.Get(&scale)
msg.GetAny()
if optFastInsert {
msg.Get(&nullable)
msg.Get(&position)
msg.Get(&someval1)
msg.Get(&someval2)
if i == 0 {
msg.Get(&tablename)
}
msg.Get(&colname)
}
}
var flag int
defaults = []interface{}{}
identityColumn = false
msg.Get(&flag)
addToCache = flag&0x1 == 0x1
if optFastInsert {
var paramsDefault []byte
msg.Get(&paramsDefault)
var offset uint = 0
var li list.ListItem
li = list.GetListItem(paramsDefault, &offset)
identityColumn = li.IsEmpty()
for {
if uint(len(paramsDefault)) == offset {
break
}
li = list.GetListItem(paramsDefault, &offset)
if li.IsNull() {
continue
}
var val string
li.Get(&val)
defaults = append(defaults, val)
}
}
return
}
type Stmt struct {
cn *Connection
sql string
closed bool
statementId int32
}
func (c *Connection) Prepare(query string) (*Stmt, error) {
// msg := NewMessage(PREPARE)
// msg.SetSQLText(query)
// msg.Set(0)
// _, err := c.conn.Write(msg.Dump(c.count()))
// if err != nil {
// return nil, err
// }
// msg, err = ReadMessage(c.conn)
// if err != nil {
// return nil, err
// }
// sqlCode := int16(msg.GetStatus())
// if sqlCode != 0 && sqlCode != 100 {
// return nil, &SQLError{SQLCode: sqlCode, Message: c.getErrorInfo(sqlCode)}
// }
st := &Stmt{cn: c, sql: query}
return st, nil
}
func (st *Stmt) Exec(args []driver.Value) (res driver.Result, err error) {
parameters := make([]interface{}, len(args))
for i, a := range args {
parameters[i] = a
}
res, err = st.cn.Exec(st.sql, parameters...)
return
}
func (st *Stmt) Query(args []driver.Value) (rows driver.Rows, err error) {
parameters := make([]interface{}, len(args))
for i, a := range args {
parameters[i] = a
}
var rs *ResultSet
rs, err = st.cn.Query(st.sql, parameters...)
// st.statementId = int32(st.cn.statementId())
if err != nil {
return nil, err
}
rows = &Rows{
cn: st.cn,
rs: rs,
}
return
}
func (st *Stmt) Close() (err error) {
st.closed = true
return nil
}
func (st *Stmt) NumInput() int {
return -1
}

View File

@@ -0,0 +1,29 @@
package connection
type tx struct {
c *Connection
}
func (t *tx) Commit() error {
if t.c == nil || !t.c.tx {
panic("database/sql/driver: misuse of driver: extra Commit")
}
t.c.tx = false
err := t.c.Commit()
t.c = nil
return err
}
func (t *tx) Rollback() error {
if t.c == nil || !t.c.tx {
panic("database/sql/driver: misuse of driver: extra Rollback")
}
t.c.tx = false
err := t.c.Rollback()
t.c = nil
return err
}

View File

@@ -0,0 +1,22 @@
//go:build !windows
package connection
import (
"os"
"os/user"
)
func userCurrent() (string, error) {
u, err := user.Current()
if err == nil {
return u.Username, nil
}
name := os.Getenv("USER")
if name != "" {
return name, nil
}
return "", ErrCouldNotDetectUsername
}

View File

@@ -0,0 +1,19 @@
package connection
import (
"path/filepath"
"syscall"
)
// Perform Windows user name.
func userCurrent() (string, error) {
pw_name := make([]uint16, 128)
pwname_size := uint32(len(pw_name)) - 1
err := syscall.GetUserNameEx(syscall.NameSamCompatible, &pw_name[0], &pwname_size)
if err != nil {
return "", ErrCouldNotDetectUsername
}
s := syscall.UTF16ToString(pw_name)
u := filepath.Base(s)
return u, nil
}

View File

@@ -0,0 +1,4 @@
package iris
type Oref string

View File

@@ -0,0 +1,456 @@
package list
import (
"encoding/binary"
"errors"
"fmt"
"strconv"
"github.com/caretdev/go-irisnative/src/iris"
"github.com/shopspring/decimal"
)
type ListItemType byte
const (
LISTITEM_STRING ListItemType = 0x01
LISTITEM_UNICODE ListItemType = 0x02
LISTITEM_POSINT ListItemType = 0x04
LISTITEM_NEGINT ListItemType = 0x05
LISTITEM_POSFLOAT ListItemType = 0x06
LISTITEM_NEGFLOAT ListItemType = 0x07
LISTITEM_OREF ListItemType = 0x19
)
type ListItem struct {
size uint16
itemType ListItemType
data []byte
isNull bool
byRef bool
}
func (li *ListItem) IsNull() bool {
return li.isNull
}
func (li *ListItem) IsString() bool {
return li.itemType == LISTITEM_STRING || li.itemType == LISTITEM_UNICODE
}
func (li *ListItem) IsEmpty() bool {
return li.itemType == LISTITEM_STRING && len(li.data) == 0
}
func (li *ListItem) Type() ListItemType {
return li.itemType
}
var scale = []float64{
1.0, 10.0, 100.0, 1000.0, 10000.0, 100000.0, 1000000.0, 1.0e7, 1.0e8, 1.0e9,
1.0e10, 1.0e11, 1.0e12, 1.0e13, 1.0e14, 1.0e15, 1.0e16, 1.0e17, 1.0e18, 1.0e19,
1.0e20, 1.0e21, 1.0e22, 9.999999999999999e22, 1.0e24, 1.0e25, 1.0e26, 1.0e27, 1.0e28, 1.0e29,
1.0e30, 1.0e31, 1.0e32, 1.0e33, 1.0e34, 1.0e35, 1.0e36, 1.0e37, 1.0e38, 1.0e39,
1.0e40, 1.0e41, 1.0e42, 1.0e43, 1.0e44, 1.0e45, 1.0e46, 1.0e47, 1.0e48, 1.0e49,
1.0e50, 1.0e51, 1.0e52, 1.0e53, 1.0e54, 1.0e55, 1.0e56, 1.0e57, 1.0e58, 1.0e59,
1.0e60, 1.0e61, 1.0e62, 1.0e63, 1.0e64, 1.0e65, 1.0e66, 1.0e67, 1.0e68, 1.0e69,
1.0e70, 1.0e71, 1.0e72, 1.0e73, 1.0e74, 1.0e75, 1.0e76, 1.0e77, 1.0e78, 1.0e79,
1.0e80, 1.0e81, 1.0e82, 1.0e83, 1.0e84, 1.0e85, 1.0e86, 1.0e87, 1.0e88, 1.0e89,
1.0e90, 1.0e91, 1.0e92, 1.0e93, 1.0e94, 1.0e95, 1.0e96, 1.0e97, 1.0e98, 1.0e99,
1.0e100, 1.0e101, 1.0e102, 1.0e103, 1.0e104, 1.0e105, 1.0e106, 1.0e107, 1.0e108, 1.0e109,
1.0e110, 1.0e111, 1.0e112, 1.0e113, 1.0e114, 1.0e115, 1.0e116, 1.0e117, 1.0e118, 1.0e119,
1.0e120, 1.0e121, 1.0e122, 1.0e123, 1.0e124, 1.0e125, 1.0e126, 1.0e127, 1.0e-128, 1.0e-127,
1.0e-126, 1.0e-125, 1.0e-124, 1.0e-123, 1.0e-122, 1.0e-121, 1.0e-120, 1.0e-119, 1.0e-118, 1.0e-117,
1.0e-116, 1.0e-115, 1.0e-114, 1.0e-113, 1.0e-112, 1.0e-111, 1.0e-110, 1.0e-109, 1.0e-108, 1.0e-107,
1.0e-106, 1.0e-105, 1.0e-104, 1.0e-103, 1.0e-102, 1.0e-101, 1.0e-100, 1.0e-99, 1.0e-98, 1.0e-97,
1.0e-96, 1.0e-95, 1.0e-94, 1.0e-93, 1.0e-92, 1.0e-91, 1.0e-90, 1.0e-89, 1.0e-88, 1.0e-87,
1.0e-86, 1.0e-85, 1.0e-84, 1.0e-83, 1.0e-82, 1.0e-81, 1.0e-80, 1.0e-79, 1.0e-78, 1.0e-77,
1.0e-76, 1.0e-75, 1.0e-74, 1.0e-73, 1.0e-72, 1.0e-71, 1.0e-70, 1.0e-69, 1.0e-68, 1.0e-67,
1.0e-66, 1.0e-65, 1.0e-64, 1.0e-63, 1.0e-62, 1.0e-61, 1.0e-60, 1.0e-59, 1.0e-58, 1.0e-57,
1.0e-56, 1.0e-55, 1.0e-54, 1.0e-53, 1.0e-52, 1.0e-51, 1.0e-50, 1.0e-49, 1.0e-48, 1.0e-47,
1.0e-46, 1.0e-45, 1.0e-44, 1.0e-43, 1.0e-42, 1.0e-41, 1.0e-40, 1.0e-39, 1.0e-38, 1.0e-37,
1.0e-36, 1.0e-35, 1.0e-34, 1.0e-33, 1.0e-32, 1.0e-31, 1.0e-30, 1.0e-29, 1.0e-28, 1.0e-27,
1.0e-26, 1.0e-25, 1.0e-24, 1.0e-23, 1.0e-22, 1.0e-21, 1.0e-20, 1.0e-19, 1.0e-18, 1.0e-17,
1.0e-16, 1.0e-15, 1.0e-14, 1.0e-13, 1.0e-12, 1.0e-11, 1.0e-10, 1.0e-9, 1.0e-8, 1.0e-7,
1.0e-6, 1.0e-5, 1.0e-4, 0.001, 0.01, 0.1}
func (listItem *ListItem) Dump() []byte {
if listItem.isNull {
return []byte{1}
}
var dump = make([]byte, 0)
if listItem.size > 253 {
size := listItem.size + 1
dump = append(dump, 0)
dump = append(dump, byte((size)&0xff))
dump = append(dump, byte((size>>8)&0xff))
} else {
dump = append(dump, byte(listItem.size+2))
}
dump = append(dump, byte(listItem.itemType))
dump = append(dump, listItem.data...)
return dump
}
func GetListItem(buffer []byte, ooffset *uint) ListItem {
var byRef = false
var isNull = false
var size uint16 = 0
var itemType byte = 0
offset := *ooffset
switch buffer[offset] {
case 0:
size = uint16((buffer[offset+1] & 0xff))
size |= ((uint16(buffer[offset+2]) & 0xff) << 8)
size -= 1
offset += 3
itemType = buffer[offset]
offset += 1
case 1:
isNull = true
offset += 1
default:
size = uint16(buffer[offset]) - 2
offset += 1
itemType = buffer[offset]
offset += 1
if itemType >= 32 && itemType < 64 {
itemType = itemType - 32
byRef = true
}
}
var data = []byte{}
if size > 0 {
data = buffer[offset : offset+uint(size)]
}
offset += uint(size)
*ooffset = offset
return ListItem{size, ListItemType(itemType), data, isNull, byRef}
}
func NewListItem(value interface{}) ListItem {
var itemType ListItemType = 0
var size uint16 = 0
var data = make([]byte, 0)
var isNull = false
var byRef = false
switch v := value.(type) {
case *string:
var listItem = NewListItem(*v)
listItem.byRef = true
return listItem
case int, int8, int16, int32, int64:
var ival int64
switch i := v.(type) {
case int:
ival = int64(i)
case int8:
ival = int64(i)
case int16:
ival = int64(i)
case int32:
ival = int64(i)
case int64:
ival = i
}
itemType = 4
var base = 0
var temp = ival
if ival < 0 {
itemType = 5
base = 0xff
temp = ival*-1 - 1
}
for temp > 0 {
data = append(data, byte((temp^int64(base))&0xff))
temp = temp >> 8
}
case uint, uint8, uint16, uint32, uint64:
var uval uint64
switch u := v.(type) {
case uint:
uval = uint64(u)
case uint8:
uval = uint64(u)
case uint16:
uval = uint64(u)
case uint32:
uval = uint64(u)
case uint64:
uval = u
}
itemType = 4
temp := uval
for temp > 0 {
data = append(data, byte(temp&0xff))
temp = temp >> 8
}
case float64, float32:
var d decimal.Decimal
switch f := v.(type) {
case float32:
d = decimal.NewFromFloat32(f)
case float64:
d = decimal.NewFromFloat(f)
}
scaleSize := 256 - d.Exponent()*-1
ival := d.Coefficient().Int64()
itemType = 6
if ival < 0 {
itemType = 7
}
data = append(data, byte(scaleSize))
var base = 0
var temp = ival
if ival < 0 {
base = 0xff
temp = ival*-1 - 1
}
for temp > 0 {
data = append(data, byte((temp^int64(base))&0xff))
temp = temp >> 8
}
case bool:
itemType = 4
if v {
data = []byte{0x1}
} else {
data = []byte{0x0}
}
case string:
itemType = 1
var unicodeBytes []byte
for _, r := range(v) {
if r > 255 {
itemType = 2
var temp = r
// append(unicodeBytes)
for temp > 0 {
unicodeBytes = append(unicodeBytes, byte((temp)&0xff))
temp = temp >> 8
}
} else {
unicodeBytes = append(unicodeBytes, byte((r)&0xff))
unicodeBytes = append(unicodeBytes, byte(0))
}
}
if itemType == 2 {
data = unicodeBytes
} else {
data = []byte(v)
}
case []byte:
itemType = 1
data = v
case nil:
isNull = true
// itemType = 1
// data = []byte("")
case iris.Oref:
itemType = 25
byRef = true
data = []byte(v)
default:
fmt.Printf("unknown: %#v %T\n", v, v)
itemType = 1
data = []byte(fmt.Sprintf("%v", v))
}
size = uint16(len(data))
return ListItem{
size,
itemType,
data,
isNull,
byRef,
}
}
func (li *ListItem) getString() string {
if li.itemType == LISTITEM_UNICODE {
var val string = ""
for i := 0; i < len(li.data); i += 2 {
val += string(rune(getPosInt(li.data[i:i+2])))
}
return val
} else {
return string(li.data)
}
}
func getPosInt(data []byte) int {
temp := make([]byte, 8)
copy(temp, data)
return int(binary.LittleEndian.Uint64(temp[:8]))
}
func getNegInt(data []byte) int {
temp := make([]byte, 8)
copy(temp, data)
for i := range data {
temp[i] ^= 0xff
}
return -int(binary.LittleEndian.Uint64(temp[:8]) + 1)
}
func getPosFloat(data []byte) float64 {
d := scale[int(data[0])]
return float64(getPosInt(data[1:])) * d
}
func getNegFloat(data []byte) float64 {
d := scale[int(data[0])]
return float64(getNegInt(data[1:])) * d
}
func (li *ListItem) asString() (value string, err error) {
if li.isNull {
value = ""
return
}
switch li.itemType {
case 1, 2, 25:
value = li.getString()
case 4:
value = fmt.Sprint(getPosInt(li.data))
case 5:
value = fmt.Sprint(getNegInt(li.data))
case 6:
value = fmt.Sprint(getPosFloat(li.data))
case 7:
value = fmt.Sprint(getNegFloat(li.data))
default:
err = errors.New("not implemented")
}
return
}
func (li *ListItem) asInt() (value int, err error) {
if li.isNull {
value = 0
return
}
switch li.itemType {
case 1, 2:
value, err = strconv.Atoi(li.getString())
case 4:
value = getPosInt(li.data)
case 5:
value = getNegInt(li.data)
case 6:
value = int(getPosFloat(li.data))
case 7:
value = int(getNegFloat(li.data))
default:
err = errors.New("not implemented")
}
return
}
func (li *ListItem) asFloat64() (value float64, err error) {
if li.isNull {
value = 0
return
}
switch li.itemType {
case 1, 2:
var temp int
temp, err = strconv.Atoi(li.getString())
if err != nil {
return
}
value = float64(temp)
case 4:
value = float64(getPosInt(li.data))
case 5:
value = float64(getNegInt(li.data))
case 6:
value = getPosFloat(li.data)
case 7:
value = getNegFloat(li.data)
default:
err = errors.New("not implemented")
}
return
}
type AnyType ListItem
func (v *AnyType) Int() int {
var value int
// ListItem(*v)
return value
}
func (li *ListItem) GetAny() AnyType {
return AnyType(*li)
}
func (li *ListItem) DataLength() int {
return len(li.data)
}
func (li *ListItem) Get(value interface{}) (err error) {
switch v := value.(type) {
case *int:
*v, err = li.asInt()
case *bool:
var temp int
temp, err = li.asInt()
*v = temp != 0
case *int8:
var temp int
temp, err = li.asInt()
*v = int8(temp)
case *int16:
var temp int
temp, err = li.asInt()
*v = int16(temp)
case *int32:
var temp int
temp, err = li.asInt()
*v = int32(temp)
case *int64:
var temp int
temp, err = li.asInt()
*v = int64(temp)
case *uint:
var temp int
temp, err = li.asInt()
*v = uint(temp)
case *uint8:
var temp int
temp, err = li.asInt()
*v = uint8(temp)
case *uint16:
var temp int
temp, err = li.asInt()
*v = uint16(temp)
case *uint32:
var temp int
temp, err = li.asInt()
*v = uint32(temp)
case *uint64:
var temp int
temp, err = li.asInt()
*v = uint64(temp)
case *float64:
*v, err = li.asFloat64()
case *float32:
var temp float64
temp, err = li.asFloat64()
*v = float32(temp)
case *string:
*v, err = li.asString()
case *[]byte:
*v = li.data
case *iris.Oref:
var temp string
temp, err = li.asString()
*v = iris.Oref(temp)
default:
err = errors.New("not implemented")
}
return
}

76
third_party/go-irisnative/url.go vendored Normal file
View File

@@ -0,0 +1,76 @@
package intersystems
import (
"fmt"
"net"
nurl "net/url"
"sort"
"strings"
)
// ParseURL no longer needs to be used by clients of this library since supplying a URL as a
// connection string to sql.Open() is now supported:
//
// sql.Open("intersystems", "iris://_system:SYS@1.2.3.4:1972/USER")
//
// It remains exported here for backwards-compatibility.
//
// ParseURL converts a url to a connection string for driver.Open.
// Example:
//
// "iris://_system:SYS@1.2.3.4:1972/USER"
//
// converts to:
//
// "user=_system password=SYS host=1.2.3.4 port=1972 namespace=USER"
//
// A minimal example:
//
// "iris://"
//
// This will be blank, causing driver.Open to use all of the defaults
func ParseURL(url string) (string, error) {
u, err := nurl.Parse(url)
if err != nil {
return "", err
}
if u.Scheme != "iris" && u.Scheme != "IRIS" {
return "", fmt.Errorf("invalid connection protocol: %s", u.Scheme)
}
var kvs []string
escaper := strings.NewReplacer(`'`, `\'`, `\`, `\\`)
accrue := func(k, v string) {
if v != "" {
kvs = append(kvs, k+"='"+escaper.Replace(v)+"'")
}
}
if u.User != nil {
v := u.User.Username()
accrue("user", v)
v, _ = u.User.Password()
accrue("password", v)
}
if host, port, err := net.SplitHostPort(u.Host); err != nil {
accrue("host", u.Host)
} else {
accrue("host", host)
accrue("port", port)
}
if u.Path != "" {
accrue("namespace", strings.ToUpper(u.Path[1:]))
}
q := u.Query()
for k := range q {
accrue(k, q.Get(k))
}
sort.Strings(kvs) // Makes testing easier (not a performance concern)
return strings.Join(kvs, " "), nil
}