//go:build gonavi_full_drivers || gonavi_sqlite_driver package db import ( "context" "database/sql" "fmt" "os" "path/filepath" "strconv" "strings" "time" "GoNavi-Wails/internal/connection" "GoNavi-Wails/internal/utils" _ "modernc.org/sqlite" ) type SQLiteDB struct { conn *sql.DB pingTimeout time.Duration } func (s *SQLiteDB) Connect(config connection.ConnectionConfig) error { dsn, err := resolveSQLiteDSN(config) if err != nil { return err } if err := ensureSQLiteParentDir(dsn); err != nil { return err } db, err := sql.Open("sqlite", dsn) if err != nil { return fmt.Errorf("打开数据库连接失败:%w", err) } s.conn = db s.pingTimeout = getConnectTimeout(config) // Force verification if err := s.Ping(); err != nil { _ = db.Close() s.conn = nil return fmt.Errorf("连接建立后验证失败:%w", err) } return nil } func resolveSQLiteDSN(config connection.ConnectionConfig) (string, error) { dsn := strings.TrimSpace(config.Host) if dsn == "" { dsn = strings.TrimSpace(config.Database) } dsn = normalizeSQLitePath(dsn) if dsn == "" { return "", fmt.Errorf("SQLite 需要本地数据库文件路径(例如 /path/to/demo.sqlite)") } if strings.EqualFold(dsn, ":memory:") { return dsn, nil } if looksLikeHostPort(dsn) { return "", fmt.Errorf("SQLite 需要本地数据库文件路径,当前输入看起来是主机地址:%s", dsn) } return dsn, nil } func normalizeSQLitePath(raw string) string { text := strings.TrimSpace(raw) if strings.HasPrefix(text, "/") && len(text) > 3 && isWindowsDrivePath(text[1:]) { text = text[1:] } if isWindowsDrivePath(text) { text = trimLegacyPortSuffix(text) } return text } func isWindowsDrivePath(path string) bool { if len(path) < 3 { return false } drive := path[0] if !((drive >= 'a' && drive <= 'z') || (drive >= 'A' && drive <= 'Z')) { return false } if path[1] != ':' { return false } sep := path[2] return sep == '\\' || sep == '/' } func trimLegacyPortSuffix(path string) string { normalized := path for { idx := strings.LastIndex(normalized, ":") if idx <= 1 || idx+1 >= len(normalized) { return normalized } suffix := normalized[idx+1:] validDigits := true for _, ch := range suffix { if ch < '0' || ch > '9' { validDigits = false break } } if !validDigits { return normalized } normalized = normalized[:idx] } } func looksLikeHostPort(raw string) bool { text := strings.TrimSpace(raw) if text == "" { return false } if strings.ContainsAny(text, `/\`) { return false } if strings.HasPrefix(strings.ToLower(text), "file:") { return false } if strings.HasPrefix(text, "[") { closing := strings.LastIndex(text, "]") if closing <= 0 || closing+1 >= len(text) { return false } portText := strings.TrimSpace(strings.TrimPrefix(text[closing+1:], ":")) return isValidPortText(portText) } if strings.Count(text, ":") != 1 { return false } split := strings.LastIndex(text, ":") if split <= 0 || split+1 >= len(text) { return false } return isValidPortText(strings.TrimSpace(text[split+1:])) } func isValidPortText(text string) bool { port, err := strconv.Atoi(text) return err == nil && port > 0 && port <= 65535 } func ensureSQLiteParentDir(dsn string) error { text := strings.TrimSpace(dsn) if text == "" || strings.EqualFold(text, ":memory:") { return nil } // file: URI 由驱动处理,避免在这里误判路径格式。 if strings.HasPrefix(strings.ToLower(text), "file:") { return nil } path := text if idx := strings.Index(path, "?"); idx >= 0 { path = path[:idx] } path = strings.TrimSpace(path) if path == "" { return nil } dir := filepath.Dir(path) if dir == "." || dir == "" { return nil } if err := os.MkdirAll(dir, 0o755); err != nil { return fmt.Errorf("创建 SQLite 数据文件目录失败:%w", err) } return nil } func (s *SQLiteDB) Close() error { if s.conn != nil { return s.conn.Close() } return nil } func (s *SQLiteDB) Ping() error { if s.conn == nil { return fmt.Errorf("连接未打开") } timeout := s.pingTimeout if timeout <= 0 { timeout = 5 * time.Second } ctx, cancel := utils.ContextWithTimeout(timeout) defer cancel() return s.conn.PingContext(ctx) } func (s *SQLiteDB) QueryContext(ctx context.Context, query string) ([]map[string]interface{}, []string, error) { if s.conn == nil { return nil, nil, fmt.Errorf("连接未打开") } rows, err := s.conn.QueryContext(ctx, query) if err != nil { return nil, nil, err } defer rows.Close() return scanRows(rows) } func (s *SQLiteDB) Query(query string) ([]map[string]interface{}, []string, error) { if s.conn == nil { return nil, nil, fmt.Errorf("连接未打开") } rows, err := s.conn.Query(query) if err != nil { return nil, nil, err } defer rows.Close() return scanRows(rows) } func (s *SQLiteDB) ExecContext(ctx context.Context, query string) (int64, error) { if s.conn == nil { return 0, fmt.Errorf("连接未打开") } res, err := s.conn.ExecContext(ctx, query) if err != nil { return 0, err } return res.RowsAffected() } func (s *SQLiteDB) Exec(query string) (int64, error) { if s.conn == nil { return 0, fmt.Errorf("连接未打开") } res, err := s.conn.Exec(query) if err != nil { return 0, err } return res.RowsAffected() } func (s *SQLiteDB) GetDatabases() ([]string, error) { return []string{"main"}, nil } func (s *SQLiteDB) GetTables(dbName string) ([]string, error) { query := "SELECT name FROM sqlite_master WHERE type='table'" data, _, err := s.Query(query) if err != nil { return nil, err } var tables []string for _, row := range data { if val, ok := row["name"]; ok { tables = append(tables, fmt.Sprintf("%v", val)) } } return tables, nil } func (s *SQLiteDB) GetCreateStatement(dbName, tableName string) (string, error) { query := fmt.Sprintf("SELECT sql FROM sqlite_master WHERE type='table' AND name='%s'", tableName) data, _, err := s.Query(query) if err != nil { return "", err } if len(data) > 0 { if val, ok := data[0]["sql"]; ok { return fmt.Sprintf("%v", val), nil } } return "", fmt.Errorf("未找到建表语句") } func (s *SQLiteDB) GetColumns(dbName, tableName string) ([]connection.ColumnDefinition, error) { table := strings.TrimSpace(tableName) if table == "" { return nil, fmt.Errorf("表名不能为空") } esc := func(v string) string { return strings.ReplaceAll(v, "'", "''") } // cid, name, type, notnull, dflt_value, pk data, _, err := s.Query(fmt.Sprintf("PRAGMA table_info('%s')", esc(table))) if err != nil { return nil, err } parseInt := func(v interface{}) int { switch val := v.(type) { case int: return val case int64: return int(val) case float64: return int(val) case string: var n int _, _ = fmt.Sscanf(strings.TrimSpace(val), "%d", &n) return n default: var n int _, _ = fmt.Sscanf(strings.TrimSpace(fmt.Sprintf("%v", v)), "%d", &n) return n } } getStr := func(row map[string]interface{}, key string) string { if v, ok := row[key]; ok && v != nil { return fmt.Sprintf("%v", v) } if v, ok := row[strings.ToUpper(key)]; ok && v != nil { return fmt.Sprintf("%v", v) } return "" } var columns []connection.ColumnDefinition for _, row := range data { notnull := 0 if v, ok := row["notnull"]; ok && v != nil { notnull = parseInt(v) } else if v, ok := row["NOTNULL"]; ok && v != nil { notnull = parseInt(v) } pk := 0 if v, ok := row["pk"]; ok && v != nil { pk = parseInt(v) } else if v, ok := row["PK"]; ok && v != nil { pk = parseInt(v) } nullable := "YES" if notnull == 1 { nullable = "NO" } key := "" if pk == 1 { key = "PRI" } col := connection.ColumnDefinition{ Name: getStr(row, "name"), Type: getStr(row, "type"), Nullable: nullable, Key: key, Extra: "", Comment: "", } if v, ok := row["dflt_value"]; ok && v != nil { def := fmt.Sprintf("%v", v) col.Default = &def } else if v, ok := row["DFLT_VALUE"]; ok && v != nil { def := fmt.Sprintf("%v", v) col.Default = &def } columns = append(columns, col) } return columns, nil } func (s *SQLiteDB) GetIndexes(dbName, tableName string) ([]connection.IndexDefinition, error) { table := strings.TrimSpace(tableName) if table == "" { return nil, fmt.Errorf("表名不能为空") } esc := func(v string) string { return strings.ReplaceAll(v, "'", "''") } parseInt := func(v interface{}) int { switch val := v.(type) { case int: return val case int64: return int(val) case float64: return int(val) case string: var n int _, _ = fmt.Sscanf(strings.TrimSpace(val), "%d", &n) return n default: var n int _, _ = fmt.Sscanf(strings.TrimSpace(fmt.Sprintf("%v", v)), "%d", &n) return n } } data, _, err := s.Query(fmt.Sprintf("PRAGMA index_list('%s')", esc(table))) if err != nil { return nil, err } var indexes []connection.IndexDefinition for _, row := range data { indexName := "" if v, ok := row["name"]; ok && v != nil { indexName = fmt.Sprintf("%v", v) } else if v, ok := row["NAME"]; ok && v != nil { indexName = fmt.Sprintf("%v", v) } if strings.TrimSpace(indexName) == "" { continue } unique := 0 if v, ok := row["unique"]; ok && v != nil { unique = parseInt(v) } else if v, ok := row["UNIQUE"]; ok && v != nil { unique = parseInt(v) } nonUnique := 1 if unique == 1 { nonUnique = 0 } cols, _, err := s.Query(fmt.Sprintf("PRAGMA index_info('%s')", esc(indexName))) if err != nil { // skip broken index continue } for _, c := range cols { colName := "" if v, ok := c["name"]; ok && v != nil { colName = fmt.Sprintf("%v", v) } else if v, ok := c["NAME"]; ok && v != nil { colName = fmt.Sprintf("%v", v) } if strings.TrimSpace(colName) == "" { continue } seq := 0 if v, ok := c["seqno"]; ok && v != nil { seq = parseInt(v) + 1 } else if v, ok := c["SEQNO"]; ok && v != nil { seq = parseInt(v) + 1 } indexes = append(indexes, connection.IndexDefinition{ Name: indexName, ColumnName: colName, NonUnique: nonUnique, SeqInIndex: seq, IndexType: "BTREE", }) } } return indexes, nil } func (s *SQLiteDB) GetForeignKeys(dbName, tableName string) ([]connection.ForeignKeyDefinition, error) { table := strings.TrimSpace(tableName) if table == "" { return nil, fmt.Errorf("表名不能为空") } esc := func(v string) string { return strings.ReplaceAll(v, "'", "''") } data, _, err := s.Query(fmt.Sprintf("PRAGMA foreign_key_list('%s')", esc(table))) if err != nil { return nil, err } parseInt := func(v interface{}) int { switch val := v.(type) { case int: return val case int64: return int(val) case float64: return int(val) case string: var n int _, _ = fmt.Sscanf(strings.TrimSpace(val), "%d", &n) return n default: var n int _, _ = fmt.Sscanf(strings.TrimSpace(fmt.Sprintf("%v", v)), "%d", &n) return n } } var fks []connection.ForeignKeyDefinition for _, row := range data { id := 0 if v, ok := row["id"]; ok && v != nil { id = parseInt(v) } else if v, ok := row["ID"]; ok && v != nil { id = parseInt(v) } refTable := "" if v, ok := row["table"]; ok && v != nil { refTable = fmt.Sprintf("%v", v) } else if v, ok := row["TABLE"]; ok && v != nil { refTable = fmt.Sprintf("%v", v) } fromCol := "" if v, ok := row["from"]; ok && v != nil { fromCol = fmt.Sprintf("%v", v) } else if v, ok := row["FROM"]; ok && v != nil { fromCol = fmt.Sprintf("%v", v) } toCol := "" if v, ok := row["to"]; ok && v != nil { toCol = fmt.Sprintf("%v", v) } else if v, ok := row["TO"]; ok && v != nil { toCol = fmt.Sprintf("%v", v) } name := fmt.Sprintf("fk_%s_%d", table, id) fks = append(fks, connection.ForeignKeyDefinition{ Name: name, ColumnName: fromCol, RefTableName: refTable, RefColumnName: toCol, ConstraintName: name, }) } return fks, nil } func (s *SQLiteDB) GetTriggers(dbName, tableName string) ([]connection.TriggerDefinition, error) { table := strings.TrimSpace(tableName) if table == "" { return nil, fmt.Errorf("表名不能为空") } esc := func(v string) string { return strings.ReplaceAll(v, "'", "''") } data, _, err := s.Query(fmt.Sprintf("SELECT name AS trigger_name, sql AS statement FROM sqlite_master WHERE type='trigger' AND tbl_name='%s' ORDER BY name", esc(table))) if err != nil { return nil, err } var triggers []connection.TriggerDefinition for _, row := range data { name := fmt.Sprintf("%v", row["trigger_name"]) stmt := "" if v, ok := row["statement"]; ok && v != nil { stmt = fmt.Sprintf("%v", v) } upper := strings.ToUpper(stmt) timing := "" switch { case strings.Contains(upper, " BEFORE "): timing = "BEFORE" case strings.Contains(upper, " AFTER "): timing = "AFTER" case strings.Contains(upper, " INSTEAD OF "): timing = "INSTEAD OF" } event := "" switch { case strings.Contains(upper, " INSERT "): event = "INSERT" case strings.Contains(upper, " UPDATE "): event = "UPDATE" case strings.Contains(upper, " DELETE "): event = "DELETE" } triggers = append(triggers, connection.TriggerDefinition{ Name: name, Timing: timing, Event: event, Statement: stmt, }) } return triggers, nil } func (s *SQLiteDB) ApplyChanges(tableName string, changes connection.ChangeSet) error { if s.conn == nil { return fmt.Errorf("连接未打开") } tx, err := s.conn.Begin() if err != nil { return err } defer tx.Rollback() quoteIdent := func(name string) string { n := strings.TrimSpace(name) n = strings.Trim(n, "\"") n = strings.ReplaceAll(n, "\"", "\"\"") if n == "" { return "\"\"" } return `"` + n + `"` } schema := "" table := strings.TrimSpace(tableName) if parts := strings.SplitN(table, ".", 2); len(parts) == 2 { schema = strings.TrimSpace(parts[0]) table = strings.TrimSpace(parts[1]) } qualifiedTable := "" if schema != "" { qualifiedTable = fmt.Sprintf("%s.%s", quoteIdent(schema), quoteIdent(table)) } else { qualifiedTable = quoteIdent(table) } // 1. Deletes for _, pk := range changes.Deletes { var wheres []string var args []interface{} for k, v := range pk { wheres = append(wheres, fmt.Sprintf("%s = ?", quoteIdent(k))) args = append(args, v) } if len(wheres) == 0 { continue } query := fmt.Sprintf("DELETE FROM %s WHERE %s", qualifiedTable, strings.Join(wheres, " AND ")) if _, err := tx.Exec(query, args...); err != nil { return fmt.Errorf("删除失败:%v", err) } } // 2. Updates for _, update := range changes.Updates { var sets []string var args []interface{} for k, v := range update.Values { sets = append(sets, fmt.Sprintf("%s = ?", quoteIdent(k))) args = append(args, v) } if len(sets) == 0 { continue } var wheres []string for k, v := range update.Keys { wheres = append(wheres, fmt.Sprintf("%s = ?", quoteIdent(k))) args = append(args, v) } if len(wheres) == 0 { return fmt.Errorf("更新操作需要主键条件") } query := fmt.Sprintf("UPDATE %s SET %s WHERE %s", qualifiedTable, strings.Join(sets, ", "), strings.Join(wheres, " AND ")) if _, err := tx.Exec(query, args...); err != nil { return fmt.Errorf("更新失败:%v", err) } } // 3. Inserts for _, row := range changes.Inserts { var cols []string var placeholders []string var args []interface{} for k, v := range row { cols = append(cols, quoteIdent(k)) placeholders = append(placeholders, "?") args = append(args, v) } if len(cols) == 0 { continue } query := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", qualifiedTable, strings.Join(cols, ", "), strings.Join(placeholders, ", ")) if _, err := tx.Exec(query, args...); err != nil { return fmt.Errorf("插入失败:%v", err) } } return tx.Commit() } func (s *SQLiteDB) GetAllColumns(dbName string) ([]connection.ColumnDefinitionWithTable, error) { tables, err := s.GetTables(dbName) if err != nil { return nil, err } var cols []connection.ColumnDefinitionWithTable for _, table := range tables { // Skip internal tables if strings.HasPrefix(strings.ToLower(table), "sqlite_") { continue } columns, err := s.GetColumns("", table) if err != nil { continue } for _, col := range columns { cols = append(cols, connection.ColumnDefinitionWithTable{ TableName: table, Name: col.Name, Type: col.Type, }) } } return cols, nil }