package db import ( "context" "database/sql" "fmt" "strings" "time" "GoNavi-Wails/internal/connection" "GoNavi-Wails/internal/utils" _ "modernc.org/sqlite" ) type SQLiteDB struct { conn *sql.DB pingTimeout time.Duration } func (s *SQLiteDB) Connect(config connection.ConnectionConfig) error { dsn := config.Host db, err := sql.Open("sqlite", dsn) if err != nil { return fmt.Errorf("打开数据库连接失败:%w", err) } s.conn = db s.pingTimeout = getConnectTimeout(config) // Force verification if err := s.Ping(); err != nil { return fmt.Errorf("连接建立后验证失败:%w", err) } return nil } func (s *SQLiteDB) Close() error { if s.conn != nil { return s.conn.Close() } return nil } func (s *SQLiteDB) Ping() error { if s.conn == nil { return fmt.Errorf("connection not open") } timeout := s.pingTimeout if timeout <= 0 { timeout = 5 * time.Second } ctx, cancel := utils.ContextWithTimeout(timeout) defer cancel() return s.conn.PingContext(ctx) } func (s *SQLiteDB) QueryContext(ctx context.Context, query string) ([]map[string]interface{}, []string, error) { if s.conn == nil { return nil, nil, fmt.Errorf("connection not open") } rows, err := s.conn.QueryContext(ctx, query) if err != nil { return nil, nil, err } defer rows.Close() return scanRows(rows) } func (s *SQLiteDB) Query(query string) ([]map[string]interface{}, []string, error) { if s.conn == nil { return nil, nil, fmt.Errorf("connection not open") } rows, err := s.conn.Query(query) if err != nil { return nil, nil, err } defer rows.Close() columns, err := rows.Columns() if err != nil { return nil, nil, err } var resultData []map[string]interface{} for rows.Next() { values := make([]interface{}, len(columns)) valuePtrs := make([]interface{}, len(columns)) for i := range columns { valuePtrs[i] = &values[i] } if err := rows.Scan(valuePtrs...); err != nil { continue } entry := make(map[string]interface{}) for i, col := range columns { entry[col] = normalizeQueryValue(values[i]) } resultData = append(resultData, entry) } return resultData, columns, nil } func (s *SQLiteDB) ExecContext(ctx context.Context, query string) (int64, error) { if s.conn == nil { return 0, fmt.Errorf("connection not open") } res, err := s.conn.ExecContext(ctx, query) if err != nil { return 0, err } return res.RowsAffected() } func (s *SQLiteDB) Exec(query string) (int64, error) { if s.conn == nil { return 0, fmt.Errorf("connection not open") } res, err := s.conn.Exec(query) if err != nil { return 0, err } return res.RowsAffected() } func (s *SQLiteDB) GetDatabases() ([]string, error) { return []string{"main"}, nil } func (s *SQLiteDB) GetTables(dbName string) ([]string, error) { query := "SELECT name FROM sqlite_master WHERE type='table'" data, _, err := s.Query(query) if err != nil { return nil, err } var tables []string for _, row := range data { if val, ok := row["name"]; ok { tables = append(tables, fmt.Sprintf("%v", val)) } } return tables, nil } func (s *SQLiteDB) GetCreateStatement(dbName, tableName string) (string, error) { query := fmt.Sprintf("SELECT sql FROM sqlite_master WHERE type='table' AND name='%s'", tableName) data, _, err := s.Query(query) if err != nil { return "", err } if len(data) > 0 { if val, ok := data[0]["sql"]; ok { return fmt.Sprintf("%v", val), nil } } return "", fmt.Errorf("create statement not found") } func (s *SQLiteDB) GetColumns(dbName, tableName string) ([]connection.ColumnDefinition, error) { table := strings.TrimSpace(tableName) if table == "" { return nil, fmt.Errorf("table name required") } esc := func(v string) string { return strings.ReplaceAll(v, "'", "''") } // cid, name, type, notnull, dflt_value, pk data, _, err := s.Query(fmt.Sprintf("PRAGMA table_info('%s')", esc(table))) if err != nil { return nil, err } parseInt := func(v interface{}) int { switch val := v.(type) { case int: return val case int64: return int(val) case float64: return int(val) case string: var n int _, _ = fmt.Sscanf(strings.TrimSpace(val), "%d", &n) return n default: var n int _, _ = fmt.Sscanf(strings.TrimSpace(fmt.Sprintf("%v", v)), "%d", &n) return n } } getStr := func(row map[string]interface{}, key string) string { if v, ok := row[key]; ok && v != nil { return fmt.Sprintf("%v", v) } if v, ok := row[strings.ToUpper(key)]; ok && v != nil { return fmt.Sprintf("%v", v) } return "" } var columns []connection.ColumnDefinition for _, row := range data { notnull := 0 if v, ok := row["notnull"]; ok && v != nil { notnull = parseInt(v) } else if v, ok := row["NOTNULL"]; ok && v != nil { notnull = parseInt(v) } pk := 0 if v, ok := row["pk"]; ok && v != nil { pk = parseInt(v) } else if v, ok := row["PK"]; ok && v != nil { pk = parseInt(v) } nullable := "YES" if notnull == 1 { nullable = "NO" } key := "" if pk == 1 { key = "PRI" } col := connection.ColumnDefinition{ Name: getStr(row, "name"), Type: getStr(row, "type"), Nullable: nullable, Key: key, Extra: "", Comment: "", } if v, ok := row["dflt_value"]; ok && v != nil { def := fmt.Sprintf("%v", v) col.Default = &def } else if v, ok := row["DFLT_VALUE"]; ok && v != nil { def := fmt.Sprintf("%v", v) col.Default = &def } columns = append(columns, col) } return columns, nil } func (s *SQLiteDB) GetIndexes(dbName, tableName string) ([]connection.IndexDefinition, error) { table := strings.TrimSpace(tableName) if table == "" { return nil, fmt.Errorf("table name required") } esc := func(v string) string { return strings.ReplaceAll(v, "'", "''") } parseInt := func(v interface{}) int { switch val := v.(type) { case int: return val case int64: return int(val) case float64: return int(val) case string: var n int _, _ = fmt.Sscanf(strings.TrimSpace(val), "%d", &n) return n default: var n int _, _ = fmt.Sscanf(strings.TrimSpace(fmt.Sprintf("%v", v)), "%d", &n) return n } } data, _, err := s.Query(fmt.Sprintf("PRAGMA index_list('%s')", esc(table))) if err != nil { return nil, err } var indexes []connection.IndexDefinition for _, row := range data { indexName := "" if v, ok := row["name"]; ok && v != nil { indexName = fmt.Sprintf("%v", v) } else if v, ok := row["NAME"]; ok && v != nil { indexName = fmt.Sprintf("%v", v) } if strings.TrimSpace(indexName) == "" { continue } unique := 0 if v, ok := row["unique"]; ok && v != nil { unique = parseInt(v) } else if v, ok := row["UNIQUE"]; ok && v != nil { unique = parseInt(v) } nonUnique := 1 if unique == 1 { nonUnique = 0 } cols, _, err := s.Query(fmt.Sprintf("PRAGMA index_info('%s')", esc(indexName))) if err != nil { // skip broken index continue } for _, c := range cols { colName := "" if v, ok := c["name"]; ok && v != nil { colName = fmt.Sprintf("%v", v) } else if v, ok := c["NAME"]; ok && v != nil { colName = fmt.Sprintf("%v", v) } if strings.TrimSpace(colName) == "" { continue } seq := 0 if v, ok := c["seqno"]; ok && v != nil { seq = parseInt(v) + 1 } else if v, ok := c["SEQNO"]; ok && v != nil { seq = parseInt(v) + 1 } indexes = append(indexes, connection.IndexDefinition{ Name: indexName, ColumnName: colName, NonUnique: nonUnique, SeqInIndex: seq, IndexType: "BTREE", }) } } return indexes, nil } func (s *SQLiteDB) GetForeignKeys(dbName, tableName string) ([]connection.ForeignKeyDefinition, error) { table := strings.TrimSpace(tableName) if table == "" { return nil, fmt.Errorf("table name required") } esc := func(v string) string { return strings.ReplaceAll(v, "'", "''") } data, _, err := s.Query(fmt.Sprintf("PRAGMA foreign_key_list('%s')", esc(table))) if err != nil { return nil, err } parseInt := func(v interface{}) int { switch val := v.(type) { case int: return val case int64: return int(val) case float64: return int(val) case string: var n int _, _ = fmt.Sscanf(strings.TrimSpace(val), "%d", &n) return n default: var n int _, _ = fmt.Sscanf(strings.TrimSpace(fmt.Sprintf("%v", v)), "%d", &n) return n } } var fks []connection.ForeignKeyDefinition for _, row := range data { id := 0 if v, ok := row["id"]; ok && v != nil { id = parseInt(v) } else if v, ok := row["ID"]; ok && v != nil { id = parseInt(v) } refTable := "" if v, ok := row["table"]; ok && v != nil { refTable = fmt.Sprintf("%v", v) } else if v, ok := row["TABLE"]; ok && v != nil { refTable = fmt.Sprintf("%v", v) } fromCol := "" if v, ok := row["from"]; ok && v != nil { fromCol = fmt.Sprintf("%v", v) } else if v, ok := row["FROM"]; ok && v != nil { fromCol = fmt.Sprintf("%v", v) } toCol := "" if v, ok := row["to"]; ok && v != nil { toCol = fmt.Sprintf("%v", v) } else if v, ok := row["TO"]; ok && v != nil { toCol = fmt.Sprintf("%v", v) } name := fmt.Sprintf("fk_%s_%d", table, id) fks = append(fks, connection.ForeignKeyDefinition{ Name: name, ColumnName: fromCol, RefTableName: refTable, RefColumnName: toCol, ConstraintName: name, }) } return fks, nil } func (s *SQLiteDB) GetTriggers(dbName, tableName string) ([]connection.TriggerDefinition, error) { table := strings.TrimSpace(tableName) if table == "" { return nil, fmt.Errorf("table name required") } esc := func(v string) string { return strings.ReplaceAll(v, "'", "''") } data, _, err := s.Query(fmt.Sprintf("SELECT name AS trigger_name, sql AS statement FROM sqlite_master WHERE type='trigger' AND tbl_name='%s' ORDER BY name", esc(table))) if err != nil { return nil, err } var triggers []connection.TriggerDefinition for _, row := range data { name := fmt.Sprintf("%v", row["trigger_name"]) stmt := "" if v, ok := row["statement"]; ok && v != nil { stmt = fmt.Sprintf("%v", v) } upper := strings.ToUpper(stmt) timing := "" switch { case strings.Contains(upper, " BEFORE "): timing = "BEFORE" case strings.Contains(upper, " AFTER "): timing = "AFTER" case strings.Contains(upper, " INSTEAD OF "): timing = "INSTEAD OF" } event := "" switch { case strings.Contains(upper, " INSERT "): event = "INSERT" case strings.Contains(upper, " UPDATE "): event = "UPDATE" case strings.Contains(upper, " DELETE "): event = "DELETE" } triggers = append(triggers, connection.TriggerDefinition{ Name: name, Timing: timing, Event: event, Statement: stmt, }) } return triggers, nil } func (s *SQLiteDB) GetAllColumns(dbName string) ([]connection.ColumnDefinitionWithTable, error) { tables, err := s.GetTables(dbName) if err != nil { return nil, err } var cols []connection.ColumnDefinitionWithTable for _, table := range tables { // Skip internal tables if strings.HasPrefix(strings.ToLower(table), "sqlite_") { continue } columns, err := s.GetColumns("", table) if err != nil { continue } for _, col := range columns { cols = append(cols, connection.ColumnDefinitionWithTable{ TableName: table, Name: col.Name, Type: col.Type, }) } } return cols, nil }