Files
MyGoNavi/internal/db/mysql_impl.go
2026-04-11 21:53:52 +08:00

1030 lines
25 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package db
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"math"
"net/url"
"strconv"
"strings"
"time"
"GoNavi-Wails/internal/connection"
"GoNavi-Wails/internal/logger"
"GoNavi-Wails/internal/ssh"
"GoNavi-Wails/internal/utils"
_ "github.com/go-sql-driver/mysql"
)
type MySQLDB struct {
conn *sql.DB
pingTimeout time.Duration
}
const defaultMySQLPort = 3306
func parseHostPortWithDefault(raw string, defaultPort int) (string, int, bool) {
text := strings.TrimSpace(raw)
if text == "" {
return "", 0, false
}
if strings.HasPrefix(text, "[") {
end := strings.Index(text, "]")
if end < 0 {
return text, defaultPort, true
}
host := text[1:end]
portText := strings.TrimSpace(text[end+1:])
if strings.HasPrefix(portText, ":") {
if p, err := strconv.Atoi(strings.TrimSpace(strings.TrimPrefix(portText, ":"))); err == nil && p > 0 {
return host, p, true
}
}
return host, defaultPort, true
}
lastColon := strings.LastIndex(text, ":")
if lastColon > 0 && strings.Count(text, ":") == 1 {
host := strings.TrimSpace(text[:lastColon])
portText := strings.TrimSpace(text[lastColon+1:])
if host != "" {
if p, err := strconv.Atoi(portText); err == nil && p > 0 {
return host, p, true
}
return host, defaultPort, true
}
}
return text, defaultPort, true
}
func normalizeMySQLAddress(host string, port int) string {
h := strings.TrimSpace(host)
if h == "" {
h = "localhost"
}
p := port
if p <= 0 {
p = defaultMySQLPort
}
return fmt.Sprintf("%s:%d", h, p)
}
var mysqlDatabaseQueries = []string{
"SHOW DATABASES",
"SELECT DATABASE() AS `Database`",
}
func collectMySQLDatabaseNames(queryFn func(string) ([]map[string]interface{}, []string, error)) ([]string, error) {
if queryFn == nil {
return nil, fmt.Errorf("查询函数为空")
}
names := make([]string, 0, 8)
seen := make(map[string]struct{}, 8)
var lastErr error
appendNames := func(rows []map[string]interface{}) {
for _, row := range rows {
for _, key := range []string{"Database", "database"} {
val, ok := row[key]
if !ok || val == nil {
continue
}
name := strings.TrimSpace(fmt.Sprintf("%v", val))
if name == "" || strings.EqualFold(name, "<nil>") {
continue
}
if _, exists := seen[name]; exists {
continue
}
seen[name] = struct{}{}
names = append(names, name)
break
}
}
}
for _, sqlText := range mysqlDatabaseQueries {
rows, _, err := queryFn(sqlText)
if err != nil {
lastErr = err
continue
}
appendNames(rows)
if len(names) > 0 {
return names, nil
}
}
if len(names) > 0 {
return names, nil
}
if lastErr != nil {
return nil, lastErr
}
return nil, fmt.Errorf("未获取到可用数据库")
}
func applyMySQLURI(config connection.ConnectionConfig) connection.ConnectionConfig {
uriText := strings.TrimSpace(config.URI)
if uriText == "" {
return config
}
lowerURI := strings.ToLower(uriText)
if !strings.HasPrefix(lowerURI, "mysql://") {
return config
}
parsed, err := url.Parse(uriText)
if err != nil {
return config
}
if parsed.User != nil {
if config.User == "" {
config.User = parsed.User.Username()
}
if pass, ok := parsed.User.Password(); ok && config.Password == "" {
config.Password = pass
}
}
if dbName := strings.TrimPrefix(parsed.Path, "/"); dbName != "" && config.Database == "" {
config.Database = dbName
}
defaultPort := config.Port
if defaultPort <= 0 {
defaultPort = defaultMySQLPort
}
hostsFromURI := make([]string, 0, 4)
hostText := strings.TrimSpace(parsed.Host)
if hostText != "" {
for _, entry := range strings.Split(hostText, ",") {
host, port, ok := parseHostPortWithDefault(entry, defaultPort)
if !ok {
continue
}
hostsFromURI = append(hostsFromURI, normalizeMySQLAddress(host, port))
}
}
if len(config.Hosts) == 0 && len(hostsFromURI) > 0 {
config.Hosts = hostsFromURI
}
if strings.TrimSpace(config.Host) == "" && len(hostsFromURI) > 0 {
host, port, ok := parseHostPortWithDefault(hostsFromURI[0], defaultPort)
if ok {
config.Host = host
config.Port = port
}
}
if config.Topology == "" {
topology := strings.TrimSpace(parsed.Query().Get("topology"))
if topology != "" {
config.Topology = strings.ToLower(topology)
}
}
return config
}
func collectMySQLAddresses(config connection.ConnectionConfig) []string {
defaultPort := config.Port
if defaultPort <= 0 {
defaultPort = defaultMySQLPort
}
candidates := make([]string, 0, len(config.Hosts)+1)
if len(config.Hosts) > 0 {
candidates = append(candidates, config.Hosts...)
} else {
candidates = append(candidates, normalizeMySQLAddress(config.Host, defaultPort))
}
result := make([]string, 0, len(candidates))
seen := make(map[string]struct{}, len(candidates))
for _, entry := range candidates {
host, port, ok := parseHostPortWithDefault(entry, defaultPort)
if !ok {
continue
}
normalized := normalizeMySQLAddress(host, port)
if _, exists := seen[normalized]; exists {
continue
}
seen[normalized] = struct{}{}
result = append(result, normalized)
}
return result
}
func (m *MySQLDB) getDSN(config connection.ConnectionConfig) (string, error) {
database := config.Database
protocol := "tcp"
address := normalizeMySQLAddress(config.Host, config.Port)
if config.UseSSH {
netName, err := ssh.RegisterSSHNetwork(config.SSH)
if err != nil {
return "", fmt.Errorf("创建 SSH 隧道失败:%w", err)
}
protocol = netName
}
timeout := getConnectTimeoutSeconds(config)
tlsMode := resolveMySQLTLSMode(config)
return fmt.Sprintf(
"%s:%s@%s(%s)/%s?charset=utf8mb4&parseTime=True&loc=Local&timeout=%ds&tls=%s&multiStatements=true",
config.User, config.Password, protocol, address, database, timeout, url.QueryEscape(tlsMode),
), nil
}
func resolveMySQLCredential(config connection.ConnectionConfig, addressIndex int) (string, string) {
primaryUser := strings.TrimSpace(config.User)
primaryPassword := config.Password
replicaUser := strings.TrimSpace(config.MySQLReplicaUser)
replicaPassword := config.MySQLReplicaPassword
if addressIndex > 0 && replicaUser != "" {
return replicaUser, replicaPassword
}
if primaryUser == "" && replicaUser != "" {
return replicaUser, replicaPassword
}
return config.User, primaryPassword
}
func (m *MySQLDB) Connect(config connection.ConnectionConfig) error {
runConfig := applyMySQLURI(config)
addresses := collectMySQLAddresses(runConfig)
if len(addresses) == 0 {
return fmt.Errorf("连接建立后验证失败:未找到可用的 MySQL 地址")
}
var errorDetails []string
for index, address := range addresses {
candidateConfig := runConfig
host, port, ok := parseHostPortWithDefault(address, defaultMySQLPort)
if !ok {
continue
}
candidateConfig.Host = host
candidateConfig.Port = port
candidateConfig.User, candidateConfig.Password = resolveMySQLCredential(runConfig, index)
dsn, err := m.getDSN(candidateConfig)
if err != nil {
errorDetails = append(errorDetails, fmt.Sprintf("%s 生成连接串失败: %v", address, err))
continue
}
db, err := sql.Open("mysql", dsn)
if err != nil {
errorDetails = append(errorDetails, fmt.Sprintf("%s 打开失败: %v", address, err))
continue
}
timeout := getConnectTimeout(candidateConfig)
ctx, cancel := utils.ContextWithTimeout(timeout)
pingErr := db.PingContext(ctx)
cancel()
if pingErr != nil {
_ = db.Close()
errorDetails = append(errorDetails, fmt.Sprintf("%s 验证失败: %v", address, pingErr))
continue
}
m.conn = db
m.pingTimeout = timeout
return nil
}
if len(errorDetails) == 0 {
return fmt.Errorf("连接建立后验证失败:未找到可用的 MySQL 地址")
}
return fmt.Errorf("连接建立后验证失败:%s", strings.Join(errorDetails, ""))
}
func (m *MySQLDB) Close() error {
if m.conn != nil {
return m.conn.Close()
}
return nil
}
func (m *MySQLDB) Ping() error {
if m.conn == nil {
return fmt.Errorf("连接未打开")
}
timeout := m.pingTimeout
if timeout <= 0 {
timeout = 5 * time.Second
}
ctx, cancel := utils.ContextWithTimeout(timeout)
defer cancel()
return m.conn.PingContext(ctx)
}
func (m *MySQLDB) QueryMulti(query string) ([]connection.ResultSetData, error) {
if m.conn == nil {
return nil, fmt.Errorf("连接未打开")
}
rows, err := m.conn.Query(query)
if err != nil {
return nil, err
}
defer rows.Close()
return scanMultiRows(rows)
}
func (m *MySQLDB) QueryMultiContext(ctx context.Context, query string) ([]connection.ResultSetData, error) {
if m.conn == nil {
return nil, fmt.Errorf("连接未打开")
}
rows, err := m.conn.QueryContext(ctx, query)
if err != nil {
return nil, err
}
defer rows.Close()
return scanMultiRows(rows)
}
func (m *MySQLDB) QueryContext(ctx context.Context, query string) ([]map[string]interface{}, []string, error) {
if m.conn == nil {
return nil, nil, fmt.Errorf("连接未打开")
}
rows, err := m.conn.QueryContext(ctx, query)
if err != nil {
return nil, nil, err
}
defer rows.Close()
return scanRows(rows)
}
func (m *MySQLDB) Query(query string) ([]map[string]interface{}, []string, error) {
if m.conn == nil {
return nil, nil, fmt.Errorf("连接未打开")
}
rows, err := m.conn.Query(query)
if err != nil {
return nil, nil, err
}
defer rows.Close()
return scanRows(rows)
}
func (m *MySQLDB) ExecBatchContext(ctx context.Context, query string) (int64, error) {
if m.conn == nil {
return 0, fmt.Errorf("连接未打开")
}
res, err := m.conn.ExecContext(ctx, query)
if err != nil {
return 0, err
}
return res.RowsAffected()
}
func (m *MySQLDB) ExecContext(ctx context.Context, query string) (int64, error) {
if m.conn == nil {
return 0, fmt.Errorf("连接未打开")
}
res, err := m.conn.ExecContext(ctx, query)
if err != nil {
return 0, err
}
return res.RowsAffected()
}
func (m *MySQLDB) Exec(query string) (int64, error) {
if m.conn == nil {
return 0, fmt.Errorf("连接未打开")
}
res, err := m.conn.Exec(query)
if err != nil {
return 0, err
}
return res.RowsAffected()
}
func (m *MySQLDB) GetDatabases() ([]string, error) {
return collectMySQLDatabaseNames(m.Query)
}
func (m *MySQLDB) GetTables(dbName string) ([]string, error) {
query := "SHOW TABLES"
if dbName != "" {
query = fmt.Sprintf("SHOW TABLES FROM `%s`", dbName)
}
data, _, err := m.Query(query)
if err != nil {
return nil, err
}
var tables []string
for _, row := range data {
for _, v := range row {
tables = append(tables, fmt.Sprintf("%v", v))
break
}
}
return tables, nil
}
func (m *MySQLDB) GetCreateStatement(dbName, tableName string) (string, error) {
query := fmt.Sprintf("SHOW CREATE TABLE `%s`.`%s`", dbName, tableName)
if dbName == "" {
query = fmt.Sprintf("SHOW CREATE TABLE `%s`", tableName)
}
data, _, err := m.Query(query)
if err != nil {
return "", err
}
if len(data) > 0 {
if val, ok := data[0]["Create Table"]; ok {
return fmt.Sprintf("%v", val), nil
}
}
return "", fmt.Errorf("未找到建表语句")
}
func (m *MySQLDB) GetColumns(dbName, tableName string) ([]connection.ColumnDefinition, error) {
query := fmt.Sprintf("SHOW FULL COLUMNS FROM `%s`.`%s`", dbName, tableName)
if dbName == "" {
query = fmt.Sprintf("SHOW FULL COLUMNS FROM `%s`", tableName)
}
data, _, err := m.Query(query)
if err != nil {
return nil, err
}
var columns []connection.ColumnDefinition
for _, row := range data {
col := connection.ColumnDefinition{
Name: fmt.Sprintf("%v", row["Field"]),
Type: fmt.Sprintf("%v", row["Type"]),
Nullable: fmt.Sprintf("%v", row["Null"]),
Key: fmt.Sprintf("%v", row["Key"]),
Extra: fmt.Sprintf("%v", row["Extra"]),
Comment: fmt.Sprintf("%v", row["Comment"]),
}
if row["Default"] != nil {
d := fmt.Sprintf("%v", row["Default"])
col.Default = &d
}
columns = append(columns, col)
}
return columns, nil
}
func (m *MySQLDB) GetIndexes(dbName, tableName string) ([]connection.IndexDefinition, error) {
query := fmt.Sprintf("SHOW INDEX FROM `%s`.`%s`", dbName, tableName)
if dbName == "" {
query = fmt.Sprintf("SHOW INDEX FROM `%s`", tableName)
}
data, _, err := m.Query(query)
if err != nil {
return nil, err
}
var indexes []connection.IndexDefinition
for _, row := range data {
nonUnique := 0
if val, ok := row["Non_unique"]; ok {
if f, ok := val.(float64); ok {
nonUnique = int(f)
} else if i, ok := val.(int64); ok {
nonUnique = int(i)
}
}
seq := 0
if val, ok := row["Seq_in_index"]; ok {
if f, ok := val.(float64); ok {
seq = int(f)
} else if i, ok := val.(int64); ok {
seq = int(i)
}
}
subPart := 0
if val, ok := row["Sub_part"]; ok && val != nil {
if f, ok := val.(float64); ok {
subPart = int(f)
} else if i, ok := val.(int64); ok {
subPart = int(i)
}
}
idx := connection.IndexDefinition{
Name: fmt.Sprintf("%v", row["Key_name"]),
ColumnName: fmt.Sprintf("%v", row["Column_name"]),
NonUnique: nonUnique,
SeqInIndex: seq,
IndexType: fmt.Sprintf("%v", row["Index_type"]),
SubPart: subPart,
}
indexes = append(indexes, idx)
}
return indexes, nil
}
func (m *MySQLDB) GetForeignKeys(dbName, tableName string) ([]connection.ForeignKeyDefinition, error) {
query := fmt.Sprintf(`SELECT CONSTRAINT_NAME, COLUMN_NAME, REFERENCED_TABLE_NAME, REFERENCED_COLUMN_NAME
FROM information_schema.KEY_COLUMN_USAGE
WHERE TABLE_SCHEMA = '%s' AND TABLE_NAME = '%s' AND REFERENCED_TABLE_NAME IS NOT NULL`, dbName, tableName)
data, _, err := m.Query(query)
if err != nil {
return nil, err
}
var fks []connection.ForeignKeyDefinition
for _, row := range data {
fk := connection.ForeignKeyDefinition{
Name: fmt.Sprintf("%v", row["CONSTRAINT_NAME"]),
ColumnName: fmt.Sprintf("%v", row["COLUMN_NAME"]),
RefTableName: fmt.Sprintf("%v", row["REFERENCED_TABLE_NAME"]),
RefColumnName: fmt.Sprintf("%v", row["REFERENCED_COLUMN_NAME"]),
ConstraintName: fmt.Sprintf("%v", row["CONSTRAINT_NAME"]),
}
fks = append(fks, fk)
}
return fks, nil
}
func (m *MySQLDB) GetTriggers(dbName, tableName string) ([]connection.TriggerDefinition, error) {
query := fmt.Sprintf("SHOW TRIGGERS FROM `%s` WHERE `Table` = '%s'", dbName, tableName)
data, _, err := m.Query(query)
if err != nil {
return nil, err
}
var triggers []connection.TriggerDefinition
for _, row := range data {
trig := connection.TriggerDefinition{
Name: fmt.Sprintf("%v", row["Trigger"]),
Timing: fmt.Sprintf("%v", row["Timing"]),
Event: fmt.Sprintf("%v", row["Event"]),
Statement: fmt.Sprintf("%v", row["Statement"]),
}
triggers = append(triggers, trig)
}
return triggers, nil
}
func (m *MySQLDB) ApplyChanges(tableName string, changes connection.ChangeSet) error {
if m.conn == nil {
return fmt.Errorf("连接未打开")
}
columnTypeMap := m.loadColumnTypeMap(tableName)
tx, err := m.conn.Begin()
if err != nil {
return err
}
defer tx.Rollback()
// 1. Deletes
for _, pk := range changes.Deletes {
var wheres []string
var args []interface{}
for k, v := range pk {
wheres = append(wheres, fmt.Sprintf("`%s` = ?", k))
args = append(args, normalizeMySQLValueForWrite(k, v, columnTypeMap))
}
if len(wheres) == 0 {
continue
}
query := fmt.Sprintf("DELETE FROM `%s` WHERE %s", tableName, strings.Join(wheres, " AND "))
res, err := tx.Exec(query, args...)
if err != nil {
return fmt.Errorf("删除失败:%v", err)
}
if affected, err := res.RowsAffected(); err == nil && affected == 0 {
return fmt.Errorf("删除未生效:未匹配到任何行")
}
}
// 2. Updates
for _, update := range changes.Updates {
var sets []string
var args []interface{}
for k, v := range update.Values {
sets = append(sets, fmt.Sprintf("`%s` = ?", k))
args = append(args, normalizeMySQLValueForWrite(k, v, columnTypeMap))
}
if len(sets) == 0 {
continue
}
var wheres []string
for k, v := range update.Keys {
wheres = append(wheres, fmt.Sprintf("`%s` = ?", k))
args = append(args, normalizeMySQLValueForWrite(k, v, columnTypeMap))
}
if len(wheres) == 0 {
return fmt.Errorf("更新操作需要主键条件")
}
query := fmt.Sprintf("UPDATE `%s` SET %s WHERE %s", tableName, strings.Join(sets, ", "), strings.Join(wheres, " AND "))
res, err := tx.Exec(query, args...)
if err != nil {
return fmt.Errorf("更新失败:%v", err)
}
if affected, err := res.RowsAffected(); err == nil && affected == 0 {
return fmt.Errorf("更新未生效:未匹配到任何行")
}
}
// 3. Inserts
for _, row := range changes.Inserts {
var cols []string
var placeholders []string
var args []interface{}
for k, v := range row {
normalizedValue, omit := normalizeMySQLValueForInsert(k, v, columnTypeMap)
if omit {
continue
}
cols = append(cols, fmt.Sprintf("`%s`", k))
placeholders = append(placeholders, "?")
args = append(args, normalizedValue)
}
if len(cols) == 0 {
query := fmt.Sprintf("INSERT INTO `%s` () VALUES ()", tableName)
res, err := tx.Exec(query)
if err != nil {
return fmt.Errorf("插入失败:%v", err)
}
if affected, err := res.RowsAffected(); err == nil && affected == 0 {
return fmt.Errorf("插入未生效:未影响任何行")
}
continue
}
query := fmt.Sprintf("INSERT INTO `%s` (%s) VALUES (%s)", tableName, strings.Join(cols, ", "), strings.Join(placeholders, ", "))
res, err := tx.Exec(query, args...)
if err != nil {
return fmt.Errorf("插入失败:%v", err)
}
if affected, err := res.RowsAffected(); err == nil && affected == 0 {
return fmt.Errorf("插入未生效:未影响任何行")
}
}
return tx.Commit()
}
func normalizeMySQLComplexValue(value interface{}) interface{} {
switch v := value.(type) {
case map[string]interface{}, []interface{}:
if data, err := json.Marshal(v); err == nil {
return string(data)
}
return fmt.Sprintf("%v", value)
default:
return value
}
}
func normalizeMySQLDateTimeValue(value interface{}) interface{} {
text, ok := value.(string)
if !ok {
return value
}
raw := strings.TrimSpace(text)
if raw == "" {
return value
}
cleaned := strings.ReplaceAll(raw, "+ ", "+")
cleaned = strings.ReplaceAll(cleaned, "- ", "-")
if len(cleaned) >= 19 && cleaned[10] == 'T' {
if strings.HasSuffix(cleaned, "Z") || hasTimezoneOffset(cleaned) {
if t, err := time.Parse(time.RFC3339Nano, cleaned); err == nil {
return formatMySQLDateTime(t)
}
if t, err := time.Parse(time.RFC3339, cleaned); err == nil {
return formatMySQLDateTime(t)
}
}
return strings.Replace(cleaned, "T", " ", 1)
}
if strings.Contains(cleaned, " ") && (strings.HasSuffix(cleaned, "Z") || hasTimezoneOffset(cleaned)) {
candidate := strings.Replace(cleaned, " ", "T", 1)
if t, err := time.Parse(time.RFC3339Nano, candidate); err == nil {
return formatMySQLDateTime(t)
}
if t, err := time.Parse(time.RFC3339, candidate); err == nil {
return formatMySQLDateTime(t)
}
}
return value
}
func (m *MySQLDB) loadColumnTypeMap(tableName string) map[string]string {
result := map[string]string{}
table := strings.TrimSpace(tableName)
if table == "" {
return result
}
columns, err := m.GetColumns("", table)
if err != nil {
logger.Warnf("加载列元数据失败(不影响提交):表=%s err=%v", table, err)
return result
}
for _, col := range columns {
name := strings.ToLower(strings.TrimSpace(col.Name))
if name == "" {
continue
}
result[name] = strings.TrimSpace(col.Type)
}
return result
}
func normalizeMySQLValueForInsert(columnName string, value interface{}, columnTypeMap map[string]string) (interface{}, bool) {
columnType := strings.ToLower(strings.TrimSpace(columnTypeMap[strings.ToLower(strings.TrimSpace(columnName))]))
if isMySQLBitColumnType(columnType) {
return normalizeMySQLBitValue(value), false
}
if !isMySQLTemporalColumnType(columnType) {
return normalizeMySQLComplexValue(value), false
}
text, ok := value.(string)
if ok && strings.TrimSpace(text) == "" {
// INSERT 空时间字段不写入,交给 DB 默认值处理(如 CURRENT_TIMESTAMP
return nil, true
}
return normalizeMySQLDateTimeValue(value), false
}
func normalizeMySQLValueForWrite(columnName string, value interface{}, columnTypeMap map[string]string) interface{} {
columnType := strings.ToLower(strings.TrimSpace(columnTypeMap[strings.ToLower(strings.TrimSpace(columnName))]))
if isMySQLBitColumnType(columnType) {
return normalizeMySQLBitValue(value)
}
if !isMySQLTemporalColumnType(columnType) {
return value
}
text, ok := value.(string)
if ok && strings.TrimSpace(text) == "" {
return nil
}
return normalizeMySQLDateTimeValue(value)
}
func isMySQLTemporalColumnType(columnType string) bool {
raw := strings.ToLower(strings.TrimSpace(columnType))
if raw == "" {
return false
}
if strings.Contains(raw, "datetime") || strings.Contains(raw, "timestamp") {
return true
}
base := raw
if idx := strings.IndexAny(base, "( "); idx >= 0 {
base = base[:idx]
}
return base == "date" || base == "time" || base == "year"
}
func isMySQLBitColumnType(columnType string) bool {
raw := strings.ToLower(strings.TrimSpace(columnType))
if raw == "" {
return false
}
base := raw
if idx := strings.IndexAny(base, "( "); idx >= 0 {
base = base[:idx]
}
return base == "bit"
}
func normalizeMySQLBitValue(value interface{}) interface{} {
switch v := value.(type) {
case nil:
return nil
case []byte:
return v
case bool:
if v {
return []byte{1}
}
return []byte{0}
case string:
if bitValue, ok := parseMySQLBitString(v); ok {
return bitValue
}
return value
case int:
if v >= 0 {
if bitValue, ok := mysqlBitBytesFromUint64(uint64(v)); ok {
return bitValue
}
}
case int8:
if v >= 0 {
if bitValue, ok := mysqlBitBytesFromUint64(uint64(v)); ok {
return bitValue
}
}
case int16:
if v >= 0 {
if bitValue, ok := mysqlBitBytesFromUint64(uint64(v)); ok {
return bitValue
}
}
case int32:
if v >= 0 {
if bitValue, ok := mysqlBitBytesFromUint64(uint64(v)); ok {
return bitValue
}
}
case int64:
if v >= 0 {
if bitValue, ok := mysqlBitBytesFromUint64(uint64(v)); ok {
return bitValue
}
}
case uint:
if bitValue, ok := mysqlBitBytesFromUint64(uint64(v)); ok {
return bitValue
}
case uint8:
if bitValue, ok := mysqlBitBytesFromUint64(uint64(v)); ok {
return bitValue
}
case uint16:
if bitValue, ok := mysqlBitBytesFromUint64(uint64(v)); ok {
return bitValue
}
case uint32:
if bitValue, ok := mysqlBitBytesFromUint64(uint64(v)); ok {
return bitValue
}
case uint64:
if bitValue, ok := mysqlBitBytesFromUint64(v); ok {
return bitValue
}
case float32:
if v >= 0 && math.Trunc(float64(v)) == float64(v) {
if bitValue, ok := mysqlBitBytesFromUint64(uint64(v)); ok {
return bitValue
}
}
case float64:
if v >= 0 && math.Trunc(v) == v {
if bitValue, ok := mysqlBitBytesFromUint64(uint64(v)); ok {
return bitValue
}
}
}
return value
}
func parseMySQLBitString(text string) ([]byte, bool) {
raw := strings.TrimSpace(text)
if raw == "" {
return nil, false
}
switch strings.ToLower(raw) {
case "true":
return []byte{1}, true
case "false":
return []byte{0}, true
}
if len(raw) > 3 && (raw[0] == 'b' || raw[0] == 'B') && raw[1] == '\'' && raw[len(raw)-1] == '\'' {
value, err := strconv.ParseUint(raw[2:len(raw)-1], 2, 64)
if err == nil {
return mysqlBitBytesFromUint64OrZero(value), true
}
return nil, false
}
if len(raw) > 2 && (strings.HasPrefix(raw, "0b") || strings.HasPrefix(raw, "0B")) {
value, err := strconv.ParseUint(raw[2:], 2, 64)
if err == nil {
return mysqlBitBytesFromUint64OrZero(value), true
}
return nil, false
}
value, err := strconv.ParseUint(raw, 10, 64)
if err != nil {
return nil, false
}
return mysqlBitBytesFromUint64OrZero(value), true
}
func mysqlBitBytesFromUint64(value uint64) ([]byte, bool) {
return mysqlBitBytesFromUint64OrZero(value), true
}
func mysqlBitBytesFromUint64OrZero(value uint64) []byte {
if value == 0 {
return []byte{0}
}
var buf [8]byte
index := len(buf)
for value > 0 {
index--
buf[index] = byte(value)
value >>= 8
}
return append([]byte(nil), buf[index:]...)
}
func hasTimezoneOffset(text string) bool {
pos := strings.LastIndexAny(text, "+-")
if pos < 0 || pos < 10 || pos+1 >= len(text) {
return false
}
offset := text[pos+1:]
if len(offset) == 5 && offset[2] == ':' {
return isAllDigits(offset[:2]) && isAllDigits(offset[3:])
}
if len(offset) == 4 {
return isAllDigits(offset)
}
return false
}
func isAllDigits(text string) bool {
if text == "" {
return false
}
for _, r := range text {
if r < '0' || r > '9' {
return false
}
}
return true
}
func formatMySQLDateTime(t time.Time) string {
base := t.Format("2006-01-02 15:04:05")
nanos := t.Nanosecond()
if nanos == 0 {
return base
}
micro := nanos / 1000
return fmt.Sprintf("%s.%06d", base, micro)
}
func (m *MySQLDB) GetAllColumns(dbName string) ([]connection.ColumnDefinitionWithTable, error) {
query := fmt.Sprintf("SELECT TABLE_NAME, COLUMN_NAME, COLUMN_TYPE FROM information_schema.COLUMNS WHERE TABLE_SCHEMA = '%s'", dbName)
if dbName == "" {
return nil, fmt.Errorf("获取全部列信息需要指定数据库名称")
}
data, _, err := m.Query(query)
if err != nil {
return nil, err
}
var cols []connection.ColumnDefinitionWithTable
for _, row := range data {
col := connection.ColumnDefinitionWithTable{
TableName: fmt.Sprintf("%v", row["TABLE_NAME"]),
Name: fmt.Sprintf("%v", row["COLUMN_NAME"]),
Type: fmt.Sprintf("%v", row["COLUMN_TYPE"]),
}
cols = append(cols, col)
}
return cols, nil
}