mirror of
https://github.com/Syngnat/GoNavi.git
synced 2026-06-09 16:09:41 +08:00
✨ feat(iris): 新增 InterSystems IRIS 数据源支持
- 后端新增 IRIS 连接、查询、DDL、索引元数据和 DataGrid 编辑能力 - 接入 optional driver-agent、构建标签、revision 生成和变更检测流程 - 前端新增 IRIS 连接入口、方言映射、能力配置和图标展示 - 修复 IRIS 主键识别、事务开启错误处理和驱动连接关闭问题 - 补充后端、前端和构建脚本相关回归测试 Refs #408
This commit is contained in:
113
third_party/go-irisnative/src/connection/classes.go
vendored
Normal file
113
third_party/go-irisnative/src/connection/classes.go
vendored
Normal file
@@ -0,0 +1,113 @@
|
||||
package connection
|
||||
|
||||
import "github.com/caretdev/go-irisnative/src/iris"
|
||||
|
||||
func (c *Connection) ServerVersion() (result string, err error) {
|
||||
err = c.ClassMethod("%SYSTEM.Version", "GetVersion", &result)
|
||||
return
|
||||
}
|
||||
|
||||
func (c *Connection) ClassMethod(class, method string, result interface{}, args ...interface{}) (err error) {
|
||||
msg := NewMessage(CLASSMETHOD_VALUE)
|
||||
msg.Set(class)
|
||||
msg.Set(method)
|
||||
msg.Set(len(args))
|
||||
for _, arg := range args {
|
||||
msg.Set(arg)
|
||||
}
|
||||
|
||||
_, err = c.conn.Write(msg.Dump(c.count()))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
msg, err = ReadMessage(c.conn)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
msg.Get(result)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (c *Connection) ClassMethodVoid(class, method string, args ...interface{}) (err error) {
|
||||
msg := NewMessage(CLASSMETHOD_VOID)
|
||||
msg.Set(class)
|
||||
msg.Set(method)
|
||||
msg.Set(len(args))
|
||||
for _, arg := range args {
|
||||
msg.Set(arg)
|
||||
}
|
||||
|
||||
_, err = c.conn.Write(msg.Dump(c.count()))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
msg, err = ReadMessage(c.conn)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (c *Connection) Method(obj iris.Oref, method string, result interface{}, args ...interface{}) (err error) {
|
||||
msg := NewMessage(METHOD_VALUE)
|
||||
msg.Set(obj)
|
||||
msg.Set(method)
|
||||
msg.Set(len(args))
|
||||
for _, arg := range args {
|
||||
msg.Set(arg)
|
||||
}
|
||||
|
||||
_, err = c.conn.Write(msg.Dump(c.count()))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
msg, err = ReadMessage(c.conn)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
msg.Get(result)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (c *Connection) MethodVoid(obj, method string, args ...interface{}) (err error) {
|
||||
msg := NewMessage(METHOD_VOID)
|
||||
msg.Set(obj)
|
||||
msg.Set(method)
|
||||
msg.Set(len(args))
|
||||
for _, arg := range args {
|
||||
msg.Set(arg)
|
||||
}
|
||||
|
||||
_, err = c.conn.Write(msg.Dump(c.count()))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
msg, err = ReadMessage(c.conn)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
func (c *Connection) PropertyGet(obj iris.Oref, property string, result interface{}) (err error) {
|
||||
msg := NewMessage(PROPERTY_GET)
|
||||
msg.Set(obj)
|
||||
msg.Set(property)
|
||||
// msg.Set(0)
|
||||
|
||||
_, err = c.conn.Write(msg.Dump(c.count()))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
msg, err = ReadMessage(c.conn)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
msg.Get(result)
|
||||
|
||||
return
|
||||
}
|
||||
141
third_party/go-irisnative/src/connection/globals.go
vendored
Normal file
141
third_party/go-irisnative/src/connection/globals.go
vendored
Normal file
@@ -0,0 +1,141 @@
|
||||
package connection
|
||||
|
||||
func (c *Connection) GlobalIsDefined(global string, subs ...interface{}) (bool, bool) {
|
||||
msg := NewMessage(GLOBAL_DATA)
|
||||
msg.Set(global)
|
||||
msg.Set(len(subs))
|
||||
for _, sub := range subs {
|
||||
msg.Set(sub)
|
||||
}
|
||||
msg.Set(0)
|
||||
_, err := c.conn.Write(msg.Dump(c.count()))
|
||||
if err != nil {
|
||||
return false, false
|
||||
}
|
||||
|
||||
msg, err = ReadMessage(c.conn)
|
||||
if err != nil {
|
||||
return false, false
|
||||
}
|
||||
|
||||
var result uint8
|
||||
msg.Get(&result)
|
||||
return result%10 == 1, result >= 10
|
||||
}
|
||||
|
||||
func (c *Connection) GlobalSet(global string, value interface{}, subs ...interface{}) (err error) {
|
||||
msg := NewMessage(GLOBAL_SET)
|
||||
msg.Set(global)
|
||||
msg.Set(len(subs))
|
||||
for _, sub := range subs {
|
||||
msg.Set(sub)
|
||||
}
|
||||
msg.Set(value)
|
||||
|
||||
_, err = c.conn.Write(msg.Dump(c.count()))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
_, err = ReadMessage(c.conn)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (c *Connection) GlobalKill(global string, subs ...interface{}) (err error) {
|
||||
msg := NewMessage(GLOBAL_KILL)
|
||||
msg.Set(global)
|
||||
msg.Set(len(subs))
|
||||
for _, sub := range subs {
|
||||
msg.Set(sub)
|
||||
}
|
||||
|
||||
_, err = c.conn.Write(msg.Dump(c.count()))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
_, err = ReadMessage(c.conn)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (c *Connection) GlobalGet(global string, result interface{}, subs ...interface{}) (err error) {
|
||||
msg := NewMessage(GLOBAL_GET)
|
||||
msg.Set(global)
|
||||
msg.Set(len(subs))
|
||||
for _, sub := range subs {
|
||||
msg.Set(sub)
|
||||
}
|
||||
|
||||
_, err = c.conn.Write(msg.Dump(c.count()))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
msg, err = ReadMessage(c.conn)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
msg.Get(result)
|
||||
return
|
||||
}
|
||||
|
||||
func (c *Connection) GlobalNext(global string, ind *string, subs ...interface{}) (hasNext bool, err error) {
|
||||
msg := NewMessage(GLOBAL_ORDER)
|
||||
msg.Set(global)
|
||||
msg.Set(len(subs) + 1)
|
||||
for _, sub := range subs {
|
||||
msg.Set(sub)
|
||||
}
|
||||
msg.Set(*ind)
|
||||
msg.Set(3)
|
||||
|
||||
if _, err = c.conn.Write(msg.Dump(c.count())); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if msg, err = ReadMessage(c.conn); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
var result string
|
||||
msg.Get(&result)
|
||||
*ind = result
|
||||
hasNext = result != ""
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (c *Connection) GlobalPrev(global string, ind *string, subs ...interface{}) (hasNext bool, err error) {
|
||||
msg := NewMessage(GLOBAL_ORDER)
|
||||
msg.Set(global)
|
||||
msg.Set(len(subs) + 1)
|
||||
for _, sub := range subs {
|
||||
msg.Set(sub)
|
||||
}
|
||||
msg.Set(*ind)
|
||||
msg.Set(7)
|
||||
|
||||
if _, err = c.conn.Write(msg.Dump(c.count())); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if msg, err = ReadMessage(c.conn); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
var result string
|
||||
msg.Get(&result)
|
||||
*ind = result
|
||||
hasNext = result != ""
|
||||
|
||||
return
|
||||
}
|
||||
148
third_party/go-irisnative/src/connection/message.go
vendored
Normal file
148
third_party/go-irisnative/src/connection/message.go
vendored
Normal file
@@ -0,0 +1,148 @@
|
||||
package connection
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
|
||||
"github.com/caretdev/go-irisnative/src/list"
|
||||
)
|
||||
|
||||
type Message struct {
|
||||
header MessageHeader
|
||||
data []byte
|
||||
offset uint
|
||||
}
|
||||
|
||||
func NewMessage(messageType MessageType) Message {
|
||||
return Message{
|
||||
NewMessageHeader(messageType),
|
||||
[]byte{},
|
||||
0,
|
||||
}
|
||||
}
|
||||
|
||||
func ReadMessage(conn *net.TCPConn) (msg Message, err error) {
|
||||
buffer := make([]byte, 14)
|
||||
|
||||
_, err = conn.Read(buffer)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
var header [14]byte
|
||||
copy(header[:], buffer[:14])
|
||||
var msgHeader = MessageHeader{header}
|
||||
|
||||
length := msgHeader.GetLength()
|
||||
data := make([]byte, length)
|
||||
var offset int = 0
|
||||
var size int
|
||||
for {
|
||||
size, err = conn.Read(data[offset:])
|
||||
if err != nil {
|
||||
if err != io.EOF {
|
||||
return
|
||||
}
|
||||
break
|
||||
}
|
||||
offset += size
|
||||
if offset >= int(length) {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
msg = Message{msgHeader, data, 0}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (m *Message) AddRaw(value interface{}) {
|
||||
switch v := value.(type) {
|
||||
case uint16:
|
||||
m.data = append(m.data, byte(v&0xff))
|
||||
m.data = append(m.data, byte(v>>8&0xff))
|
||||
m.offset += 2
|
||||
case []byte:
|
||||
m.data = append(m.data, v...)
|
||||
m.offset += uint(len(v))
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Message) GetRaw(value interface{}) error {
|
||||
switch v := value.(type) {
|
||||
case *uint16:
|
||||
*v = uint16(m.data[m.offset]) | (uint16(m.data[m.offset+1]) << 8)
|
||||
m.offset += 2
|
||||
case *bool:
|
||||
*v = (uint16(m.data[m.offset]) | (uint16(m.data[m.offset+1]) << 8)) == 1
|
||||
m.offset += 2
|
||||
case *[]byte:
|
||||
*v = m.data[m.offset:]
|
||||
m.offset = uint(len(m.data))
|
||||
default:
|
||||
return fmt.Errorf("unknown type: %T", v)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Message) Set(value interface{}) error {
|
||||
listItem := list.NewListItem(value)
|
||||
m.AddRaw(listItem.Dump())
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Message) SetSQLText(sqlText string) error {
|
||||
len := len(sqlText)
|
||||
if len == 0 {
|
||||
m.Set(sqlText)
|
||||
return nil
|
||||
}
|
||||
const chunksize = 31904
|
||||
chunks := len / chunksize
|
||||
if len%chunksize != 0 {
|
||||
chunks += 1
|
||||
}
|
||||
m.Set(chunks)
|
||||
for i := 0; i < chunks; i++ {
|
||||
begin := i * chunksize
|
||||
end := (i + 1) * chunksize
|
||||
if end > len {
|
||||
end = len
|
||||
}
|
||||
m.Set(sqlText[begin:end])
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Message) GetStatus() uint16 {
|
||||
return m.header.GetStatus()
|
||||
}
|
||||
|
||||
func (m *Message) Get(value interface{}) error {
|
||||
listItem := list.GetListItem(m.data, &m.offset)
|
||||
listItem.Get(value)
|
||||
return nil
|
||||
}
|
||||
|
||||
type AnyType struct {
|
||||
listItem list.ListItem
|
||||
}
|
||||
|
||||
func (v *AnyType) Int() int {
|
||||
var value int
|
||||
v.listItem.Get(&value)
|
||||
return value
|
||||
}
|
||||
|
||||
func (m *Message) GetAny() AnyType {
|
||||
listItem := list.GetListItem(m.data, &m.offset)
|
||||
return AnyType{listItem}
|
||||
}
|
||||
|
||||
func (m *Message) Dump(count uint32) []byte {
|
||||
m.header.SetCount(count)
|
||||
m.header.SetLength(uint32(len(m.data)))
|
||||
|
||||
return append(m.header.header[:], m.data...)
|
||||
}
|
||||
84
third_party/go-irisnative/src/connection/message_header.go
vendored
Normal file
84
third_party/go-irisnative/src/connection/message_header.go
vendored
Normal file
@@ -0,0 +1,84 @@
|
||||
package connection
|
||||
|
||||
type MessageType string
|
||||
|
||||
func setUint32(buffer []byte, value uint32) {
|
||||
buffer[0] = byte(value & 0xff)
|
||||
buffer[1] = byte(value >> 8 & 0xff)
|
||||
buffer[2] = byte(value >> 16 & 0xff)
|
||||
buffer[3] = byte(value >> 24 & 0xff)
|
||||
}
|
||||
|
||||
func getUint32(buffer []byte) uint32 {
|
||||
return uint32(buffer[0]) |
|
||||
uint32(buffer[1])<<8 |
|
||||
uint32(buffer[2])<<16 |
|
||||
uint32(buffer[3])<<24
|
||||
}
|
||||
|
||||
const (
|
||||
CONNECT MessageType = "\x43\x4e"
|
||||
HANDSHAKE MessageType = "\x48\x53"
|
||||
DISCONNECT MessageType = "\x44\x43"
|
||||
|
||||
GLOBAL_GET MessageType = "\x41\xc2"
|
||||
GLOBAL_SET MessageType = "\x42\xc2"
|
||||
GLOBAL_KILL MessageType = "\x43\xc2"
|
||||
GLOBAL_ORDER MessageType = "\x45\xc2"
|
||||
GLOBAL_DATA MessageType = "\x49\xc2"
|
||||
|
||||
CLASSMETHOD_VALUE MessageType = "\x4b\xc2"
|
||||
CLASSMETHOD_VOID MessageType = "\x4c\xc2"
|
||||
|
||||
METHOD_VALUE MessageType = "\x5b\xc2"
|
||||
METHOD_VOID MessageType = "\x5c\xc2"
|
||||
|
||||
PROPERTY_GET MessageType = "\x5d\xc2"
|
||||
PROPERTY_SET MessageType = "\x5e\xc2"
|
||||
|
||||
DIRECT_QUERY MessageType = "DQ"
|
||||
PREPARED_QUERY MessageType = "PQ"
|
||||
DIRECT_UPDATE MessageType = "DU"
|
||||
PREPARED_UPDATE MessageType = "PU"
|
||||
PREPARE MessageType = "PP"
|
||||
GET_AUTO_GENERATED_KEYS MessageType = "GG"
|
||||
|
||||
COMMIT MessageType = "TC"
|
||||
ROLLBACK MessageType = "TR"
|
||||
|
||||
MULTIPLE_RESULT_SETS_FETCH_DATA MessageType = "MD"
|
||||
GET_MORE_RESULTS MessageType = "MR"
|
||||
FETCH_DATA MessageType = "FD"
|
||||
GET_SERVER_ERROR MessageType = "OE"
|
||||
)
|
||||
|
||||
type MessageHeader struct {
|
||||
header [14]byte
|
||||
}
|
||||
|
||||
func NewMessageHeader(messageType MessageType) MessageHeader {
|
||||
header := [14]byte{}
|
||||
header[12] = messageType[0]
|
||||
header[13] = messageType[1]
|
||||
return MessageHeader{header}
|
||||
}
|
||||
|
||||
func (mh *MessageHeader) GetStatus() uint16 {
|
||||
return uint16(mh.header[12]) | (uint16(mh.header[13]) << 8)
|
||||
}
|
||||
|
||||
func (mh *MessageHeader) SetLength(length uint32) {
|
||||
setUint32(mh.header[0:], length)
|
||||
}
|
||||
|
||||
func (mh MessageHeader) GetLength() uint32 {
|
||||
return getUint32(mh.header[0:])
|
||||
}
|
||||
|
||||
func (mh *MessageHeader) SetCount(cnt uint32) {
|
||||
setUint32(mh.header[4:], cnt)
|
||||
}
|
||||
|
||||
func (mh *MessageHeader) SetStatementId(statementId uint32) {
|
||||
setUint32(mh.header[8:], statementId)
|
||||
}
|
||||
251
third_party/go-irisnative/src/connection/mod.go
vendored
Normal file
251
third_party/go-irisnative/src/connection/mod.go
vendored
Normal file
@@ -0,0 +1,251 @@
|
||||
package connection
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"errors"
|
||||
"net"
|
||||
)
|
||||
|
||||
const VERSION_PROTOCOL uint16 = 69
|
||||
|
||||
type Connection struct {
|
||||
conn *net.TCPConn
|
||||
messageCount uint32
|
||||
statement uint32
|
||||
unicode bool
|
||||
locale string
|
||||
version uint16
|
||||
info string
|
||||
featureOptions uint
|
||||
tx bool
|
||||
}
|
||||
|
||||
var (
|
||||
ErrCouldNotDetectUsername = errors.New("intersystems: Could not detect default username. Please provide one explicitly")
|
||||
errBeginTx = errors.New("could not begin transaction")
|
||||
errMultipleTx = errors.New("multiple transactions")
|
||||
errReadOnlyTxNotSupported = errors.New("read-only transactions are not supported")
|
||||
)
|
||||
|
||||
func Connect(addr string, namespace, login, password string) (connection Connection, err error) {
|
||||
|
||||
tcpAddr, err := net.ResolveTCPAddr("tcp", addr)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
conn, err := net.DialTCP("tcp", nil, tcpAddr)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
connection = Connection{
|
||||
conn: conn,
|
||||
}
|
||||
|
||||
if err = connection.handshake(); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if err = connection.connect(namespace, login, password); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// fmt.Println(connection.version, connection.info)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (c *Connection) Disconnect() {
|
||||
if c.conn == nil {
|
||||
return
|
||||
}
|
||||
msg := NewMessage(DISCONNECT)
|
||||
_, _ = c.conn.Write(msg.Dump(c.count()))
|
||||
_ = c.conn.Close()
|
||||
c.conn = nil
|
||||
}
|
||||
|
||||
func (c *Connection) count() uint32 {
|
||||
count := c.messageCount
|
||||
c.messageCount += 1
|
||||
return count
|
||||
}
|
||||
|
||||
func (c *Connection) statementId() uint32 {
|
||||
statement := c.statement
|
||||
c.statement += 1
|
||||
return statement
|
||||
}
|
||||
|
||||
func (c *Connection) handshake() (err error) {
|
||||
var message = NewMessage(HANDSHAKE)
|
||||
message.AddRaw(VERSION_PROTOCOL)
|
||||
|
||||
_, err = c.conn.Write(message.Dump(c.count()))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
msg, err := ReadMessage(c.conn)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
var version uint16
|
||||
msg.GetRaw(&version)
|
||||
c.version = version
|
||||
|
||||
var unicode uint16
|
||||
msg.GetRaw(&unicode)
|
||||
c.unicode = unicode == 1
|
||||
|
||||
var locale string
|
||||
msg.Get(&locale)
|
||||
c.locale = locale
|
||||
return
|
||||
}
|
||||
|
||||
func encode(value string) []byte {
|
||||
in := []byte(value)
|
||||
length := len(in)
|
||||
out := make([]byte, length)
|
||||
for i := range in {
|
||||
length--
|
||||
temp := ((int(in[i])^0xa7)&0xff + length) & 0xff
|
||||
out[length] = byte(temp<<5 | temp>>3)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
type FeatureOption uint
|
||||
|
||||
const (
|
||||
OptionNone FeatureOption = 0
|
||||
OptionFastSelect FeatureOption = 1
|
||||
OptionFastInsert FeatureOption = 2
|
||||
OptionFastSelectAndInsert FeatureOption = 3
|
||||
OptionDurableTransactions FeatureOption = 4
|
||||
OptionNotNullable FeatureOption = 8
|
||||
OptionRedirectOutput FeatureOption = 32
|
||||
)
|
||||
|
||||
func (c *Connection) IsOptionFastInsert() bool {
|
||||
return c.featureOptions&uint(OptionFastInsert) == uint(OptionFastInsert)
|
||||
}
|
||||
|
||||
func (c *Connection) IsOptionFastSelect() bool {
|
||||
return c.featureOptions&uint(OptionFastSelect) == uint(OptionFastSelect)
|
||||
}
|
||||
|
||||
func (c *Connection) connect(namespace, login, password string) (err error) {
|
||||
msg := NewMessage(CONNECT)
|
||||
msg.Set(namespace)
|
||||
msg.Set(encode(login))
|
||||
msg.Set(encode(password))
|
||||
var user = "go"
|
||||
if user, err = systemUser(); err != nil {
|
||||
user = "go"
|
||||
}
|
||||
msg.Set(user) // machine user name
|
||||
msg.Set("go-machine") // machine name
|
||||
msg.Set("libirisnative") // application name
|
||||
msg.Set("") // ?
|
||||
msg.Set("go") // SharedMemoryFlag?
|
||||
msg.Set("") // EventClass
|
||||
msg.Set(1) // AutoCommit ? 1 : 2
|
||||
msg.Set(0) // IsolationLevel
|
||||
var featureOptions = OptionNone
|
||||
featureOptions += OptionFastSelect
|
||||
// Tricky to make it fully working yet
|
||||
// featureOptions += OptionFastInsert
|
||||
featureOptions += OptionDurableTransactions
|
||||
featureOptions += OptionRedirectOutput
|
||||
msg.Set(int(featureOptions)) // FeatureOption
|
||||
|
||||
_, err = c.conn.Write(msg.Dump(c.count()))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
msg, err = ReadMessage(c.conn)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if status := msg.GetStatus(); status == 417 {
|
||||
var errorMsg string
|
||||
msg.Get(&errorMsg)
|
||||
err = errors.New(errorMsg)
|
||||
return
|
||||
}
|
||||
|
||||
var info string
|
||||
msg.Get(&info)
|
||||
c.info = info
|
||||
var (
|
||||
delimited_ids bool
|
||||
ignored int
|
||||
isolationLevel int
|
||||
serverJobNumber string
|
||||
sqlEmptyString int
|
||||
serverFeatureOptions uint
|
||||
)
|
||||
msg.Get(&delimited_ids)
|
||||
msg.Get(&ignored)
|
||||
msg.Get(&isolationLevel)
|
||||
msg.Get(&serverJobNumber)
|
||||
msg.Get(&sqlEmptyString)
|
||||
msg.Get(&serverFeatureOptions)
|
||||
c.featureOptions = serverFeatureOptions
|
||||
return
|
||||
}
|
||||
|
||||
func systemUser() (string, error) {
|
||||
u, err := userCurrent()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return u, nil
|
||||
}
|
||||
|
||||
func (c *Connection) Commit() (err error) {
|
||||
msg := NewMessage(COMMIT)
|
||||
_, err = c.conn.Write(msg.Dump(c.count()))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_, err = ReadMessage(c.conn)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (c *Connection) Rollback() (err error) {
|
||||
msg := NewMessage(ROLLBACK)
|
||||
_, err = c.conn.Write(msg.Dump(c.count()))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_, err = ReadMessage(c.conn)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (c *Connection) BeginTx(opts driver.TxOptions) (driver.Tx, error) {
|
||||
if c.tx {
|
||||
return nil, errors.Join(errBeginTx, errMultipleTx)
|
||||
}
|
||||
|
||||
if opts.ReadOnly {
|
||||
return nil, errors.Join(errBeginTx, errReadOnlyTxNotSupported)
|
||||
}
|
||||
|
||||
if _, err := c.DirectUpdate("START TRANSACTION"); err != nil {
|
||||
return nil, errors.Join(errBeginTx, err)
|
||||
}
|
||||
c.tx = true
|
||||
return &tx{c}, nil
|
||||
}
|
||||
101
third_party/go-irisnative/src/connection/rows.go
vendored
Normal file
101
third_party/go-irisnative/src/connection/rows.go
vendored
Normal file
@@ -0,0 +1,101 @@
|
||||
package connection
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"errors"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var (
|
||||
errNoRowsAffected = errors.New("no RowsAffected available after the empty statement")
|
||||
errNoLastInsertID = errors.New("no LastInsertId available after the empty statement")
|
||||
)
|
||||
|
||||
type Result struct {
|
||||
cn *Connection
|
||||
affected int64
|
||||
}
|
||||
|
||||
func (r Result) LastInsertId() (lastId int64, err error) {
|
||||
// var msg Message
|
||||
// msg = NewMessage(GET_AUTO_GENERATED_KEYS)
|
||||
// msg.header.SetStatementId(r.cn.statementId())
|
||||
// _, err = r.cn.conn.Write(msg.Dump(r.cn.count()))
|
||||
// if err != nil {
|
||||
// return
|
||||
// }
|
||||
// msg, err = ReadMessage(r.cn.conn)
|
||||
// if err != nil {
|
||||
// return
|
||||
// }
|
||||
// msg.Get(&lastId)
|
||||
// return
|
||||
var rs *ResultSet
|
||||
rs, err = r.cn.DirectQuery("SELECT LAST_IDENTITY()")
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
row, err := rs.Next()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
lastId = int64(row[0].(int))
|
||||
return
|
||||
}
|
||||
|
||||
func (r Result) RowsAffected() (int64, error) {
|
||||
return r.affected, nil
|
||||
}
|
||||
|
||||
type Rows struct {
|
||||
cn *Connection
|
||||
rs *ResultSet
|
||||
}
|
||||
|
||||
type noRows struct{}
|
||||
|
||||
var emptyRows noRows
|
||||
|
||||
var _ driver.Result = noRows{}
|
||||
|
||||
func (noRows) LastInsertId() (int64, error) {
|
||||
return 0, errNoLastInsertID
|
||||
}
|
||||
|
||||
func (noRows) RowsAffected() (int64, error) {
|
||||
return 0, errNoRowsAffected
|
||||
}
|
||||
|
||||
|
||||
func (r *Rows) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *Rows) Columns() []string {
|
||||
if r.rs == nil {
|
||||
return []string{}
|
||||
}
|
||||
columns := r.rs.Columns()
|
||||
colNames := make([]string, len(columns))
|
||||
for k, c := range columns {
|
||||
colname := c.Name()
|
||||
// tricking IRIS
|
||||
colname = strings.ReplaceAll(colname, "﹒", ".")
|
||||
colNames[k] = colname
|
||||
}
|
||||
// fmt.Printf("Columns: %#v\n", colNames)
|
||||
return colNames
|
||||
}
|
||||
|
||||
func (r *Rows) Next(dest []driver.Value) (err error) {
|
||||
row, err := r.rs.Next()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for i := range dest {
|
||||
dest[i] = row[i]
|
||||
}
|
||||
// fmt.Printf("RowsNext: %#v\n", dest)
|
||||
return nil
|
||||
}
|
||||
|
||||
729
third_party/go-irisnative/src/connection/sql.go
vendored
Normal file
729
third_party/go-irisnative/src/connection/sql.go
vendored
Normal file
@@ -0,0 +1,729 @@
|
||||
package connection
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"fmt"
|
||||
"io"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/caretdev/go-irisnative/src/list"
|
||||
)
|
||||
|
||||
const timeLaylout = "2006-01-02 15:04:05.000000000"
|
||||
const timeLayloutShort = "2006-01-02 15:04:05"
|
||||
|
||||
type StatementFeature struct {
|
||||
featureOption int
|
||||
msgCount int
|
||||
maxRowItemCount int
|
||||
}
|
||||
|
||||
type Column struct {
|
||||
name string
|
||||
column_type int
|
||||
precision int
|
||||
scale int
|
||||
nullable int
|
||||
slot_position int
|
||||
label string
|
||||
table_name string
|
||||
schema string
|
||||
catalog string
|
||||
is_auto_increment bool
|
||||
is_case_sensitive bool
|
||||
is_currency bool
|
||||
is_read_only bool
|
||||
is_row_id bool
|
||||
}
|
||||
|
||||
type SQLTYPE int16
|
||||
|
||||
const (
|
||||
GUID SQLTYPE = -11
|
||||
WLONGVARCHAR SQLTYPE = -10
|
||||
WVARCHAR SQLTYPE = -9
|
||||
WCHAR SQLTYPE = -8
|
||||
BIT SQLTYPE = -7
|
||||
TINYINT SQLTYPE = -6
|
||||
BIGINT SQLTYPE = -5
|
||||
LONGVARBINARY SQLTYPE = -4
|
||||
VARBINARY SQLTYPE = -3
|
||||
BINARY SQLTYPE = -2
|
||||
LONGVARCHAR SQLTYPE = -1
|
||||
CHAR SQLTYPE = 1
|
||||
NUMERIC SQLTYPE = 2
|
||||
DECIMAL SQLTYPE = 3
|
||||
INTEGER SQLTYPE = 4
|
||||
SMALLINT SQLTYPE = 5
|
||||
FLOAT SQLTYPE = 6
|
||||
REAL SQLTYPE = 7
|
||||
DOUBLE SQLTYPE = 8
|
||||
DATE SQLTYPE = 9
|
||||
TIME SQLTYPE = 10
|
||||
TIMESTAMP SQLTYPE = 11
|
||||
VARCHAR SQLTYPE = 12
|
||||
TYPE_DATE SQLTYPE = 91
|
||||
TYPE_TIME SQLTYPE = 92
|
||||
TYPE_TIMESTAMP SQLTYPE = 93
|
||||
DATE_HOROLOG SQLTYPE = 1091
|
||||
TIME_HOROLOG SQLTYPE = 1092
|
||||
TIMESTAMP_POSIX SQLTYPE = 1093
|
||||
)
|
||||
|
||||
func (c Column) Name() string {
|
||||
return c.name
|
||||
}
|
||||
|
||||
type ResultSet struct {
|
||||
c *Connection
|
||||
columns []Column
|
||||
sf StatementFeature
|
||||
count int
|
||||
data []byte
|
||||
offset uint
|
||||
sqlCode int16
|
||||
}
|
||||
|
||||
type SQLError struct {
|
||||
SQLCode int16
|
||||
Message string
|
||||
}
|
||||
|
||||
func (e *SQLError) Error() string {
|
||||
return fmt.Sprintf("Error Code: %d, Message: %s", e.SQLCode, e.Message)
|
||||
}
|
||||
|
||||
// func SQLError(code int) error {
|
||||
// return &SQLError{SQLCode: code}
|
||||
// }
|
||||
|
||||
func (rs ResultSet) Columns() []Column {
|
||||
return rs.columns
|
||||
}
|
||||
|
||||
func statementFeature(msg *Message) StatementFeature {
|
||||
featureOption := 0
|
||||
msgCount := 0
|
||||
maxRowItemCount := 0
|
||||
msg.Get(&featureOption)
|
||||
if featureOption == 2 {
|
||||
msg.Get(&msgCount)
|
||||
}
|
||||
if featureOption == 1 || featureOption == 2 {
|
||||
msg.Get(&maxRowItemCount)
|
||||
}
|
||||
return StatementFeature{
|
||||
featureOption,
|
||||
msgCount,
|
||||
maxRowItemCount,
|
||||
}
|
||||
}
|
||||
|
||||
type Value interface{}
|
||||
|
||||
// type ResultSetRow struct{}
|
||||
|
||||
func (rs *ResultSet) fetchMoreData() bool {
|
||||
msg := NewMessage(FETCH_DATA)
|
||||
_, err := rs.c.conn.Write(msg.Dump(rs.c.count()))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
msg, err = ReadMessage(rs.c.conn)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
rs.data = msg.data
|
||||
rs.offset = 0
|
||||
return len(msg.data) > 0
|
||||
}
|
||||
|
||||
func fromODBC(coltype SQLTYPE, li list.ListItem) (result interface{}, err error) {
|
||||
result = nil
|
||||
if li.IsNull() || li.IsEmpty() {
|
||||
return
|
||||
}
|
||||
switch coltype {
|
||||
case VARCHAR:
|
||||
if li.DataLength() == 0 {
|
||||
return
|
||||
}
|
||||
var value string
|
||||
li.Get(&value)
|
||||
if value == "\x00" {
|
||||
value = ""
|
||||
}
|
||||
result = value
|
||||
case INTEGER, TINYINT, SMALLINT:
|
||||
var value int
|
||||
li.Get(&value)
|
||||
result = value
|
||||
case BIGINT:
|
||||
var value int64
|
||||
li.Get(&value)
|
||||
result = value
|
||||
case BIT:
|
||||
var value bool
|
||||
li.Get(&value)
|
||||
result = value
|
||||
case FLOAT:
|
||||
var value float32
|
||||
li.Get(&value)
|
||||
result = value
|
||||
case DOUBLE:
|
||||
var value float64
|
||||
li.Get(&value)
|
||||
result = value
|
||||
case TIMESTAMP_POSIX:
|
||||
if li.DataLength() == 0 {
|
||||
return
|
||||
}
|
||||
if li.Type() == list.LISTITEM_STRING {
|
||||
var strval string
|
||||
li.Get(&strval)
|
||||
result, err = time.Parse(timeLaylout, strval)
|
||||
if err == nil {
|
||||
return
|
||||
}
|
||||
err = nil
|
||||
}
|
||||
var value int64
|
||||
li.Get(&value)
|
||||
if value > 0 {
|
||||
value ^= 0x1000000000000000
|
||||
} else {
|
||||
value |= 0x6000000000000000
|
||||
}
|
||||
seconds := value / 1000000
|
||||
nano := value % 1000000 * 1000
|
||||
result = time.Unix(seconds, nano).In(time.Local)
|
||||
case VARBINARY:
|
||||
// var value []uint8
|
||||
var value string
|
||||
li.Get(&value)
|
||||
case TYPE_TIMESTAMP:
|
||||
var strval string
|
||||
li.Get(&strval)
|
||||
result, err = time.Parse(timeLayloutShort, strval)
|
||||
default:
|
||||
var value string
|
||||
li.Get(&value)
|
||||
fmt.Printf("fromODBC: invalid type: %v - %#v - %#v", coltype, li, value)
|
||||
result = value
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (rs *ResultSet) Next() ([]Value, error) {
|
||||
if rs == nil || (rs.sqlCode != 0 && rs.sqlCode != 100) {
|
||||
return nil, io.EOF
|
||||
}
|
||||
if rs.offset >= uint(len(rs.data)) && (rs.sqlCode == 100 || !rs.fetchMoreData()) {
|
||||
return nil, io.EOF
|
||||
}
|
||||
row := make([]Value, rs.count)
|
||||
data := rs.data
|
||||
count := rs.count
|
||||
var offset uint = rs.offset
|
||||
if rs.sf.featureOption == 1 {
|
||||
li := list.GetListItem(data, &rs.offset)
|
||||
li.Get(&data)
|
||||
offset = 0
|
||||
count = rs.sf.maxRowItemCount
|
||||
}
|
||||
vals := make([]list.ListItem, count)
|
||||
for i := 0; i < count; i++ {
|
||||
li := list.GetListItem(data, &offset)
|
||||
vals[i] = li
|
||||
}
|
||||
if rs.sf.featureOption != 1 {
|
||||
rs.offset = offset
|
||||
}
|
||||
var err error
|
||||
for i, c := range rs.columns {
|
||||
li := vals[c.slot_position]
|
||||
row[i], err = fromODBC(SQLTYPE(c.column_type), li)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// fmt.Printf("col: %s: %d; %#v - %#v\n", c.name, c.column_type, row[i], li)
|
||||
}
|
||||
// fmt.Printf("row: %#v\n", row)
|
||||
return row, nil
|
||||
}
|
||||
|
||||
func (c *Connection) getErrorInfo(sqlCode int16) string {
|
||||
msg := NewMessage(GET_SERVER_ERROR)
|
||||
msg.Set(sqlCode)
|
||||
_, err := c.conn.Write(msg.Dump(c.count()))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
msg, err = ReadMessage(c.conn)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
var sqlMessage string
|
||||
msg.Get(&sqlMessage)
|
||||
return sqlMessage
|
||||
}
|
||||
|
||||
func getColumns(msg *Message, statementFeature StatementFeature) []Column {
|
||||
cnt := 0
|
||||
msg.Get(&cnt)
|
||||
columns := make([]Column, cnt)
|
||||
for i := 0; i < cnt; i++ {
|
||||
column := Column{}
|
||||
msg.Get(&column.name)
|
||||
msg.Get(&column.column_type)
|
||||
switch column.column_type {
|
||||
case 9:
|
||||
column.column_type = 91
|
||||
case 10:
|
||||
column.column_type = 92
|
||||
case 11:
|
||||
column.column_type = 93
|
||||
}
|
||||
msg.Get(&column.precision)
|
||||
msg.Get(&column.scale)
|
||||
msg.Get(&column.nullable)
|
||||
msg.Get(&column.label)
|
||||
msg.Get(&column.table_name)
|
||||
msg.Get(&column.schema)
|
||||
msg.Get(&column.catalog)
|
||||
additional := ""
|
||||
msg.Get(&additional)
|
||||
if statementFeature.featureOption&0x01 == 1 {
|
||||
msg.Get(&column.slot_position)
|
||||
column.slot_position -= 1
|
||||
} else {
|
||||
column.slot_position = i
|
||||
}
|
||||
column.is_auto_increment = additional[0] == 0x01
|
||||
column.is_case_sensitive = additional[1] == 0x01
|
||||
column.is_currency = additional[2] == 0x01
|
||||
column.is_read_only = additional[3] == 0x01
|
||||
if len(additional) >= 12 {
|
||||
column.is_row_id = additional[11] == 0x01
|
||||
}
|
||||
columns[i] = column
|
||||
}
|
||||
return columns
|
||||
}
|
||||
|
||||
func parameterInfo(msg *Message) {
|
||||
cnt := 0
|
||||
msg.Get(&cnt)
|
||||
flag := 0
|
||||
msg.Get(&flag)
|
||||
}
|
||||
|
||||
func toODBC(value interface{}) interface{} {
|
||||
var val interface{}
|
||||
switch v := value.(type) {
|
||||
case *string:
|
||||
val = *v
|
||||
case string:
|
||||
val = v
|
||||
if v == "" {
|
||||
val = "\x00"
|
||||
}
|
||||
case nil:
|
||||
val = ""
|
||||
case bool:
|
||||
if v {
|
||||
val = 1
|
||||
} else {
|
||||
val = 0
|
||||
}
|
||||
case time.Time:
|
||||
val = v.UTC().Format(timeLaylout)
|
||||
case int, int8, int16, int32, int64:
|
||||
val = v
|
||||
case float32, float64:
|
||||
val = v
|
||||
case []uint8:
|
||||
val = v
|
||||
default:
|
||||
fmt.Printf("unsupported type: %T\n", v)
|
||||
val = fmt.Sprintf("%v", v)
|
||||
}
|
||||
return val
|
||||
}
|
||||
|
||||
func writeParameters(msg *Message, args ...interface{}) {
|
||||
msg.Set(len(args))
|
||||
for range args {
|
||||
msg.Set(99)
|
||||
msg.Set(4)
|
||||
}
|
||||
|
||||
msg.Set(1) // parameterSets
|
||||
msg.Set(len(args))
|
||||
for _, arg := range args {
|
||||
msg.Set(toODBC(arg))
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Connection) Query(sqlText string, args ...interface{}) (rs *ResultSet, err error) {
|
||||
queries := strings.Split(sqlText, ";\n")
|
||||
if len(queries) == 2 {
|
||||
sqlText = queries[0]
|
||||
_, err = c.DirectUpdate(sqlText, args...)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
sqlText = queries[1]
|
||||
args = []interface{}{}
|
||||
}
|
||||
rs, err = c.DirectQuery(sqlText, args...)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (c *Connection) DirectQuery(sqlText string, args ...interface{}) (*ResultSet, error) {
|
||||
sqlText, _, args = FormatQuery(sqlText, args...)
|
||||
// fmt.Printf("DirectQuery: %s; %#v\n", sqlText, args)
|
||||
|
||||
var statementId = c.statementId()
|
||||
msg := NewMessage(DIRECT_QUERY)
|
||||
msg.header.SetStatementId(statementId)
|
||||
msg.SetSQLText(sqlText)
|
||||
writeParameters(&msg, args...)
|
||||
msg.Set(10) // Query timeout
|
||||
msg.Set(200) // Max rows
|
||||
|
||||
_, err := c.conn.Write(msg.Dump(c.count()))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
msg, err = ReadMessage(c.conn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
sqlCode := int16(msg.GetStatus())
|
||||
if sqlCode != 0 && sqlCode != 100 {
|
||||
return nil, &SQLError{SQLCode: sqlCode, Message: c.getErrorInfo(sqlCode)}
|
||||
}
|
||||
statementFeature := statementFeature(&msg)
|
||||
columns := getColumns(&msg, statementFeature)
|
||||
parameterInfo((&msg))
|
||||
rs := &ResultSet{
|
||||
c: c,
|
||||
sf: statementFeature,
|
||||
columns: columns,
|
||||
count: len(columns),
|
||||
}
|
||||
|
||||
msg, err = ReadMessage(c.conn)
|
||||
rs.sqlCode = int16(msg.GetStatus())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
msg.GetRaw(&rs.data)
|
||||
|
||||
return rs, nil
|
||||
}
|
||||
|
||||
func (m Message) debug() string {
|
||||
var sb strings.Builder
|
||||
for i, b := range m.data {
|
||||
if i > 0 {
|
||||
sb.WriteString(", ")
|
||||
}
|
||||
sb.WriteString(strconv.Itoa(int(b)))
|
||||
}
|
||||
return fmt.Sprintf("$char(%s)", sb.String())
|
||||
}
|
||||
|
||||
func FormatQuery(sqlText string, args ...interface{}) (string, int, []interface{}) {
|
||||
var count int
|
||||
for i := range args {
|
||||
count++
|
||||
sqlText = strings.Replace(sqlText, "?", fmt.Sprintf(" :%%qpar(%d) ", i+1), 1)
|
||||
if !strings.Contains(sqlText, "?") {
|
||||
break
|
||||
}
|
||||
}
|
||||
return sqlText, count, args
|
||||
}
|
||||
|
||||
func (c *Connection) Exec(sqlText string, args ...interface{}) (res *Result, err error) {
|
||||
queries := strings.Split(sqlText, ";\n")
|
||||
var onConflict = ""
|
||||
if len(queries) == 2 {
|
||||
sqlText = queries[0]
|
||||
onConflict = strings.Split(queries[1], "-- ")[1]
|
||||
if strings.Contains(onConflict, "ON CONFLICT UPDATE") {
|
||||
// fmt.Printf("------\n%s\n%#v\n------\n", sqlText, args)
|
||||
sqlText = strings.Replace(sqlText, "INSERT INTO", "INSERT OR UPDATE", 1)
|
||||
onConflict = ""
|
||||
}
|
||||
}
|
||||
res, err = c.DirectUpdate(sqlText, args...)
|
||||
if err != nil {
|
||||
if strings.Contains(onConflict, "ON CONFLICT DO NOTHING") {
|
||||
res = &Result{cn: c, affected: 0}
|
||||
err = nil
|
||||
return
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (c *Connection) DirectUpdate(sqlText string, args ...interface{}) (*Result, error) {
|
||||
var batchSize int
|
||||
sqlText, batchSize, args = FormatQuery(sqlText, args...)
|
||||
// fmt.Printf("DirectUpdate: %s; %#v\n", sqlText, args)
|
||||
var batches = 1
|
||||
if batchSize > 0 {
|
||||
batches = len(args) / batchSize
|
||||
}
|
||||
var addToCache = false
|
||||
var statementId = c.statementId()
|
||||
var executeMany = false
|
||||
var optFastInsert = false
|
||||
var rowsAffected int64 = 0
|
||||
var identityColumn = false
|
||||
var defaults = []interface{}{}
|
||||
for i := 1; i <= batches; i++ {
|
||||
if i > 1 && executeMany {
|
||||
break
|
||||
}
|
||||
var msg Message
|
||||
if !addToCache {
|
||||
msg = NewMessage(DIRECT_UPDATE)
|
||||
msg.SetSQLText(sqlText)
|
||||
msg.Set(batchSize)
|
||||
for j := 0; j < batchSize; j++ {
|
||||
msg.Set(99)
|
||||
msg.Set(1)
|
||||
}
|
||||
// msg.Set(len(args))
|
||||
// for range args {
|
||||
// msg.Set(99)
|
||||
// msg.Set(1)
|
||||
// }
|
||||
} else {
|
||||
msg = NewMessage(PREPARED_UPDATE)
|
||||
}
|
||||
if addToCache && !executeMany && optFastInsert {
|
||||
msg.AddRaw([]byte{1, 0, 0, 0})
|
||||
msg.Set("")
|
||||
msg.Set(0)
|
||||
if identityColumn {
|
||||
msg.Set(2)
|
||||
msg.Set("")
|
||||
} else {
|
||||
msg.Set(1)
|
||||
}
|
||||
var batch []interface{} = make([]interface{}, batchSize)
|
||||
copy(batch, args)
|
||||
args = slices.Delete(args, 0, batchSize)
|
||||
var params []byte
|
||||
var item list.ListItem
|
||||
for _, arg := range batch {
|
||||
item = list.NewListItem(toODBC(arg))
|
||||
params = append(params, item.Dump()...)
|
||||
}
|
||||
for _, arg := range defaults {
|
||||
item = list.NewListItem(toODBC(arg))
|
||||
params = append(params, item.Dump()...)
|
||||
}
|
||||
msg.Set(params)
|
||||
} else {
|
||||
msg.Set("")
|
||||
msg.Set(0)
|
||||
if executeMany {
|
||||
msg.Set(batches)
|
||||
for k := 0; k < batches; k++ {
|
||||
msg.Set(batchSize)
|
||||
for j := 0; j < batchSize; j++ {
|
||||
var idx = (k * batchSize) + j
|
||||
msg.Set(toODBC(args[idx]))
|
||||
}
|
||||
}
|
||||
} else {
|
||||
var batch []interface{} = make([]interface{}, batchSize)
|
||||
copy(batch, args)
|
||||
args = slices.Delete(args, 0, batchSize)
|
||||
msg.Set(1)
|
||||
msg.Set(len(batch))
|
||||
for _, arg := range batch {
|
||||
msg.Set(toODBC(arg))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
msg.header.SetStatementId(statementId)
|
||||
_, err := c.conn.Write(msg.Dump(c.count()))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
msg, err = ReadMessage(c.conn)
|
||||
if err != nil {
|
||||
// fmt.Println("DirectUpdate:Readmessage: ", err)
|
||||
return nil, err
|
||||
}
|
||||
sqlCode := int16(msg.GetStatus())
|
||||
if sqlCode != 0 && sqlCode != 100 {
|
||||
return nil, &SQLError{SQLCode: sqlCode, Message: c.getErrorInfo(sqlCode)}
|
||||
}
|
||||
if i == 1 {
|
||||
if c.IsOptionFastInsert() {
|
||||
stmtFeatureOption, _ := c.checkStatementFeature(&msg)
|
||||
optFastInsert = stmtFeatureOption&uint(OptionFastInsert) == uint(OptionFastInsert)
|
||||
}
|
||||
addToCache, identityColumn, defaults = c.getParameterInfo(&msg, optFastInsert)
|
||||
}
|
||||
var batchRows int64
|
||||
msg.Get(&batchRows)
|
||||
rowsAffected += batchRows
|
||||
}
|
||||
result := &Result{cn: c, affected: rowsAffected}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (c *Connection) checkStatementFeature(msg *Message) (featureOption uint, count uint) {
|
||||
count = 0
|
||||
var keyCount int
|
||||
msg.Get(&featureOption)
|
||||
if featureOption == uint(OptionFastSelect) || featureOption == uint(OptionFastInsert) {
|
||||
if featureOption == uint(OptionFastInsert) {
|
||||
msg.Get(&keyCount)
|
||||
}
|
||||
msg.Get(&count)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (c *Connection) getParameterInfo(msg *Message, optFastInsert bool) (addToCache bool, identityColumn bool, defaults []interface{}) {
|
||||
var paramscnt int
|
||||
msg.Get(¶mscnt)
|
||||
var tablename string
|
||||
for i := 0; i < paramscnt; i++ {
|
||||
var (
|
||||
paramtype int
|
||||
precision int
|
||||
scale int
|
||||
nullable bool
|
||||
position int
|
||||
someval1 string
|
||||
someval2 string
|
||||
colname string
|
||||
)
|
||||
msg.Get(¶mtype)
|
||||
msg.Get(&precision)
|
||||
msg.Get(&scale)
|
||||
msg.GetAny()
|
||||
if optFastInsert {
|
||||
msg.Get(&nullable)
|
||||
msg.Get(&position)
|
||||
msg.Get(&someval1)
|
||||
msg.Get(&someval2)
|
||||
if i == 0 {
|
||||
msg.Get(&tablename)
|
||||
}
|
||||
msg.Get(&colname)
|
||||
}
|
||||
}
|
||||
var flag int
|
||||
defaults = []interface{}{}
|
||||
identityColumn = false
|
||||
msg.Get(&flag)
|
||||
addToCache = flag&0x1 == 0x1
|
||||
if optFastInsert {
|
||||
var paramsDefault []byte
|
||||
msg.Get(¶msDefault)
|
||||
var offset uint = 0
|
||||
var li list.ListItem
|
||||
li = list.GetListItem(paramsDefault, &offset)
|
||||
identityColumn = li.IsEmpty()
|
||||
for {
|
||||
if uint(len(paramsDefault)) == offset {
|
||||
break
|
||||
}
|
||||
li = list.GetListItem(paramsDefault, &offset)
|
||||
if li.IsNull() {
|
||||
continue
|
||||
}
|
||||
var val string
|
||||
li.Get(&val)
|
||||
defaults = append(defaults, val)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
type Stmt struct {
|
||||
cn *Connection
|
||||
sql string
|
||||
closed bool
|
||||
statementId int32
|
||||
}
|
||||
|
||||
func (c *Connection) Prepare(query string) (*Stmt, error) {
|
||||
// msg := NewMessage(PREPARE)
|
||||
// msg.SetSQLText(query)
|
||||
// msg.Set(0)
|
||||
|
||||
// _, err := c.conn.Write(msg.Dump(c.count()))
|
||||
// if err != nil {
|
||||
// return nil, err
|
||||
// }
|
||||
// msg, err = ReadMessage(c.conn)
|
||||
// if err != nil {
|
||||
// return nil, err
|
||||
// }
|
||||
// sqlCode := int16(msg.GetStatus())
|
||||
// if sqlCode != 0 && sqlCode != 100 {
|
||||
// return nil, &SQLError{SQLCode: sqlCode, Message: c.getErrorInfo(sqlCode)}
|
||||
// }
|
||||
|
||||
st := &Stmt{cn: c, sql: query}
|
||||
return st, nil
|
||||
}
|
||||
|
||||
func (st *Stmt) Exec(args []driver.Value) (res driver.Result, err error) {
|
||||
parameters := make([]interface{}, len(args))
|
||||
for i, a := range args {
|
||||
parameters[i] = a
|
||||
}
|
||||
res, err = st.cn.Exec(st.sql, parameters...)
|
||||
return
|
||||
}
|
||||
|
||||
func (st *Stmt) Query(args []driver.Value) (rows driver.Rows, err error) {
|
||||
parameters := make([]interface{}, len(args))
|
||||
for i, a := range args {
|
||||
parameters[i] = a
|
||||
}
|
||||
var rs *ResultSet
|
||||
rs, err = st.cn.Query(st.sql, parameters...)
|
||||
// st.statementId = int32(st.cn.statementId())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
rows = &Rows{
|
||||
cn: st.cn,
|
||||
rs: rs,
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (st *Stmt) Close() (err error) {
|
||||
st.closed = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func (st *Stmt) NumInput() int {
|
||||
return -1
|
||||
}
|
||||
29
third_party/go-irisnative/src/connection/transaction.go
vendored
Normal file
29
third_party/go-irisnative/src/connection/transaction.go
vendored
Normal file
@@ -0,0 +1,29 @@
|
||||
package connection
|
||||
|
||||
type tx struct {
|
||||
c *Connection
|
||||
}
|
||||
|
||||
func (t *tx) Commit() error {
|
||||
if t.c == nil || !t.c.tx {
|
||||
panic("database/sql/driver: misuse of driver: extra Commit")
|
||||
}
|
||||
|
||||
t.c.tx = false
|
||||
err := t.c.Commit()
|
||||
t.c = nil
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (t *tx) Rollback() error {
|
||||
if t.c == nil || !t.c.tx {
|
||||
panic("database/sql/driver: misuse of driver: extra Rollback")
|
||||
}
|
||||
|
||||
t.c.tx = false
|
||||
err := t.c.Rollback()
|
||||
t.c = nil
|
||||
|
||||
return err
|
||||
}
|
||||
22
third_party/go-irisnative/src/connection/user_posix.go
vendored
Normal file
22
third_party/go-irisnative/src/connection/user_posix.go
vendored
Normal file
@@ -0,0 +1,22 @@
|
||||
//go:build !windows
|
||||
|
||||
package connection
|
||||
|
||||
import (
|
||||
"os"
|
||||
"os/user"
|
||||
)
|
||||
|
||||
func userCurrent() (string, error) {
|
||||
u, err := user.Current()
|
||||
if err == nil {
|
||||
return u.Username, nil
|
||||
}
|
||||
|
||||
name := os.Getenv("USER")
|
||||
if name != "" {
|
||||
return name, nil
|
||||
}
|
||||
|
||||
return "", ErrCouldNotDetectUsername
|
||||
}
|
||||
19
third_party/go-irisnative/src/connection/user_windows.go
vendored
Normal file
19
third_party/go-irisnative/src/connection/user_windows.go
vendored
Normal file
@@ -0,0 +1,19 @@
|
||||
package connection
|
||||
|
||||
import (
|
||||
"path/filepath"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
// Perform Windows user name.
|
||||
func userCurrent() (string, error) {
|
||||
pw_name := make([]uint16, 128)
|
||||
pwname_size := uint32(len(pw_name)) - 1
|
||||
err := syscall.GetUserNameEx(syscall.NameSamCompatible, &pw_name[0], &pwname_size)
|
||||
if err != nil {
|
||||
return "", ErrCouldNotDetectUsername
|
||||
}
|
||||
s := syscall.UTF16ToString(pw_name)
|
||||
u := filepath.Base(s)
|
||||
return u, nil
|
||||
}
|
||||
Reference in New Issue
Block a user