package storage import ( "database/sql" "fmt" "log" "math/rand" "time" _ "github.com/mattn/go-sqlite3" ) type Proxy struct { ID int64 `json:"id"` Address string `json:"address"` Protocol string `json:"protocol"` ExitIP string `json:"exit_ip"` ExitLocation string `json:"exit_location"` Latency int `json:"latency"` QualityGrade string `json:"quality_grade"` UseCount int `json:"use_count"` SuccessCount int `json:"success_count"` FailCount int `json:"fail_count"` LastUsed time.Time `json:"last_used"` LastCheck time.Time `json:"last_check"` CreatedAt time.Time `json:"created_at"` Status string `json:"status"` } // SourceStatus 代理源状态 type SourceStatus struct { ID int64 URL string SuccessCount int FailCount int ConsecutiveFails int LastSuccess time.Time LastFail time.Time Status string // active/degraded/disabled DisabledUntil time.Time } type Storage struct { db *sql.DB } func New(dbPath string) (*Storage, error) { db, err := sql.Open("sqlite3", dbPath) if err != nil { return nil, fmt.Errorf("open db: %w", err) } db.SetMaxOpenConns(1) // SQLite 单写 s := &Storage{db: db} if err := s.initSchema(); err != nil { return nil, err } return s, nil } func (s *Storage) initSchema() error { // 创建代理表 _, err := s.db.Exec(` CREATE TABLE IF NOT EXISTS proxies ( id INTEGER PRIMARY KEY AUTOINCREMENT, address TEXT NOT NULL UNIQUE, protocol TEXT NOT NULL, exit_ip TEXT NOT NULL DEFAULT '', exit_location TEXT NOT NULL DEFAULT '', latency INTEGER NOT NULL DEFAULT 0, quality_grade TEXT NOT NULL DEFAULT 'C', use_count INTEGER NOT NULL DEFAULT 0, success_count INTEGER NOT NULL DEFAULT 0, fail_count INTEGER NOT NULL DEFAULT 0, last_used DATETIME, last_check DATETIME, created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, status TEXT NOT NULL DEFAULT 'active' ) `) if err != nil { return err } // 创建索引 s.db.Exec(`CREATE INDEX IF NOT EXISTS idx_protocol_latency ON proxies(protocol, latency)`) s.db.Exec(`CREATE INDEX IF NOT EXISTS idx_quality_grade ON proxies(quality_grade, latency)`) s.db.Exec(`CREATE INDEX IF NOT EXISTS idx_status ON proxies(status)`) // 创建源状态表 _, err = s.db.Exec(` CREATE TABLE IF NOT EXISTS source_status ( id INTEGER PRIMARY KEY AUTOINCREMENT, url TEXT NOT NULL UNIQUE, success_count INTEGER NOT NULL DEFAULT 0, fail_count INTEGER NOT NULL DEFAULT 0, consecutive_fails INTEGER NOT NULL DEFAULT 0, last_success DATETIME, last_fail DATETIME, status TEXT NOT NULL DEFAULT 'active', disabled_until DATETIME ) `) if err != nil { return err } // 迁移:处理旧的 location 字段(如果存在) var hasOldLocation int err = s.db.QueryRow(`SELECT COUNT(*) FROM pragma_table_info('proxies') WHERE name='location'`).Scan(&hasOldLocation) if err == nil && hasOldLocation > 0 { log.Println("[storage] migrating: renaming location to exit_location") // 如果有旧的 location 字段,先添加新字段再复制数据 s.db.Exec(`ALTER TABLE proxies ADD COLUMN exit_location TEXT NOT NULL DEFAULT ''`) s.db.Exec(`UPDATE proxies SET exit_location = location WHERE location != ''`) } // 迁移:添加 exit_ip 字段 var hasExitIP int err = s.db.QueryRow(`SELECT COUNT(*) FROM pragma_table_info('proxies') WHERE name='exit_ip'`).Scan(&hasExitIP) if err == nil && hasExitIP == 0 { log.Println("[storage] migrating: adding exit_ip column") _, err = s.db.Exec(`ALTER TABLE proxies ADD COLUMN exit_ip TEXT NOT NULL DEFAULT ''`) if err != nil { return fmt.Errorf("migrate exit_ip column: %w", err) } } // 迁移:添加 exit_location 字段 var hasExitLocation int err = s.db.QueryRow(`SELECT COUNT(*) FROM pragma_table_info('proxies') WHERE name='exit_location'`).Scan(&hasExitLocation) if err == nil && hasExitLocation == 0 { log.Println("[storage] migrating: adding exit_location column") _, err = s.db.Exec(`ALTER TABLE proxies ADD COLUMN exit_location TEXT NOT NULL DEFAULT ''`) if err != nil { return fmt.Errorf("migrate exit_location column: %w", err) } } // 迁移:添加 latency 字段 var hasLatency int err = s.db.QueryRow(`SELECT COUNT(*) FROM pragma_table_info('proxies') WHERE name='latency'`).Scan(&hasLatency) if err == nil && hasLatency == 0 { log.Println("[storage] migrating: adding latency column") s.db.Exec(`ALTER TABLE proxies ADD COLUMN latency INTEGER NOT NULL DEFAULT 0`) } // 迁移:添加质量等级字段 var hasQuality int s.db.QueryRow(`SELECT COUNT(*) FROM pragma_table_info('proxies') WHERE name='quality_grade'`).Scan(&hasQuality) if hasQuality == 0 { log.Println("[storage] migrating: adding quality_grade column") s.db.Exec(`ALTER TABLE proxies ADD COLUMN quality_grade TEXT NOT NULL DEFAULT 'C'`) } // 迁移:添加使用统计字段 var hasUseCount int s.db.QueryRow(`SELECT COUNT(*) FROM pragma_table_info('proxies') WHERE name='use_count'`).Scan(&hasUseCount) if hasUseCount == 0 { log.Println("[storage] migrating: adding usage tracking columns") s.db.Exec(`ALTER TABLE proxies ADD COLUMN use_count INTEGER NOT NULL DEFAULT 0`) s.db.Exec(`ALTER TABLE proxies ADD COLUMN success_count INTEGER NOT NULL DEFAULT 0`) s.db.Exec(`ALTER TABLE proxies ADD COLUMN last_used DATETIME`) } // 迁移:添加状态字段 var hasStatus int s.db.QueryRow(`SELECT COUNT(*) FROM pragma_table_info('proxies') WHERE name='status'`).Scan(&hasStatus) if hasStatus == 0 { log.Println("[storage] migrating: adding status column") s.db.Exec(`ALTER TABLE proxies ADD COLUMN status TEXT NOT NULL DEFAULT 'active'`) } return nil } // AddProxy 新增代理,已存在则忽略 func (s *Storage) AddProxy(address, protocol string) error { _, err := s.db.Exec( `INSERT OR IGNORE INTO proxies (address, protocol) VALUES (?, ?)`, address, protocol, ) return err } // AddProxies 批量新增 func (s *Storage) AddProxies(proxies []Proxy) error { tx, err := s.db.Begin() if err != nil { return err } stmt, err := tx.Prepare(`INSERT OR IGNORE INTO proxies (address, protocol) VALUES (?, ?)`) if err != nil { tx.Rollback() return err } defer stmt.Close() for _, p := range proxies { if _, err := stmt.Exec(p.Address, p.Protocol); err != nil { log.Printf("insert proxy %s error: %v", p.Address, err) } } return tx.Commit() } // GetRandom 随机取一个可用代理(优先选择质量高的) func (s *Storage) GetRandom() (*Proxy, error) { // 优先从 S/A 级代理中随机选择 rows, err := s.db.Query( `SELECT id, address, protocol, exit_ip, exit_location, latency, quality_grade, use_count, success_count, fail_count, last_used, last_check, created_at, status FROM proxies WHERE status = 'active' AND fail_count < 3 ORDER BY CASE quality_grade WHEN 'S' THEN 1 WHEN 'A' THEN 2 WHEN 'B' THEN 3 ELSE 4 END, RANDOM() LIMIT 1`, ) if err != nil { return nil, err } defer rows.Close() if rows.Next() { return scanProxy(rows) } return nil, fmt.Errorf("no available proxy") } // scanProxy 扫描代理行数据 func scanProxy(rows *sql.Rows) (*Proxy, error) { p := &Proxy{} var lastUsed, lastCheck sql.NullTime if err := rows.Scan(&p.ID, &p.Address, &p.Protocol, &p.ExitIP, &p.ExitLocation, &p.Latency, &p.QualityGrade, &p.UseCount, &p.SuccessCount, &p.FailCount, &lastUsed, &lastCheck, &p.CreatedAt, &p.Status); err != nil { return nil, err } if lastUsed.Valid { p.LastUsed = lastUsed.Time } if lastCheck.Valid { p.LastCheck = lastCheck.Time } return p, nil } // GetAll 获取所有可用代理 func (s *Storage) GetAll() ([]Proxy, error) { rows, err := s.db.Query( `SELECT id, address, protocol, exit_ip, exit_location, latency, quality_grade, use_count, success_count, fail_count, last_used, last_check, created_at, status FROM proxies WHERE status IN ('active', 'degraded') AND fail_count < 3 ORDER BY latency ASC`, ) if err != nil { return nil, err } defer rows.Close() var proxies []Proxy for rows.Next() { p, err := scanProxy(rows) if err != nil { return nil, err } proxies = append(proxies, *p) } return proxies, nil } // GetRandomExclude 排除指定地址随机取一个 func (s *Storage) GetRandomExclude(excludes []string) (*Proxy, error) { proxies, err := s.GetAll() if err != nil { return nil, err } excludeMap := make(map[string]bool) for _, e := range excludes { excludeMap[e] = true } var available []Proxy for _, p := range proxies { if !excludeMap[p.Address] { available = append(available, p) } } if len(available) == 0 { // 没有可排除的了,随机取任意一个 return s.GetRandom() } p := available[rand.Intn(len(available))] return &p, nil } // GetLowestLatencyExclude 排除指定地址后获取延迟最低的代理 func (s *Storage) GetLowestLatencyExclude(excludes []string) (*Proxy, error) { proxies, err := s.GetAll() if err != nil { return nil, err } excludeMap := make(map[string]bool) for _, e := range excludes { excludeMap[e] = true } // GetAll() 已经按 latency ASC 排序,找到第一个不在排除列表中的 for _, p := range proxies { if !excludeMap[p.Address] { proxy := p return &proxy, nil } } return nil, fmt.Errorf("no available proxy") } // Delete 立即删除指定代理 func (s *Storage) Delete(address string) error { _, err := s.db.Exec(`DELETE FROM proxies WHERE address = ?`, address) return err } // IncrFail 增加失败次数 func (s *Storage) IncrFail(address string) error { _, err := s.db.Exec( `UPDATE proxies SET fail_count = fail_count + 1, last_check = CURRENT_TIMESTAMP WHERE address = ?`, address, ) return err } // ResetFail 重置失败次数(验证通过) func (s *Storage) ResetFail(address string) error { _, err := s.db.Exec( `UPDATE proxies SET fail_count = 0, last_check = CURRENT_TIMESTAMP WHERE address = ?`, address, ) return err } // UpdateLatency 更新代理的延迟信息(毫秒) func (s *Storage) UpdateLatency(address string, latencyMs int) error { _, err := s.db.Exec( `UPDATE proxies SET latency = ? WHERE address = ?`, latencyMs, address, ) return err } // UpdateExitInfo 更新代理的出口 IP、位置和质量等级 func (s *Storage) UpdateExitInfo(address, exitIP, exitLocation string, latencyMs int) error { grade := CalculateQualityGrade(latencyMs) _, err := s.db.Exec( `UPDATE proxies SET exit_ip = ?, exit_location = ?, latency = ?, quality_grade = ? WHERE address = ?`, exitIP, exitLocation, latencyMs, grade, address, ) return err } // RecordProxyUse 记录代理使用(成功) func (s *Storage) RecordProxyUse(address string, success bool) error { if success { _, err := s.db.Exec( `UPDATE proxies SET use_count = use_count + 1, success_count = success_count + 1, last_used = CURRENT_TIMESTAMP WHERE address = ?`, address, ) return err } _, err := s.db.Exec( `UPDATE proxies SET use_count = use_count + 1, fail_count = fail_count + 1, last_used = CURRENT_TIMESTAMP WHERE address = ?`, address, ) return err } // GetWorstProxies 获取指定协议中延迟最高的N个代理 func (s *Storage) GetWorstProxies(protocol string, limit int) ([]Proxy, error) { rows, err := s.db.Query( `SELECT id, address, protocol, exit_ip, exit_location, latency, quality_grade, use_count, success_count, fail_count, last_used, last_check, created_at, status FROM proxies WHERE protocol = ? AND status = 'active' AND quality_grade != 'S' AND (JULIANDAY('now') - JULIANDAY(created_at)) * 1440 > 60 ORDER BY latency DESC, fail_count DESC LIMIT ?`, protocol, limit, ) if err != nil { return nil, err } defer rows.Close() var proxies []Proxy for rows.Next() { p, err := scanProxy(rows) if err != nil { return nil, err } proxies = append(proxies, *p) } return proxies, nil } // ReplaceProxy 替换代理(删除旧的,添加新的) func (s *Storage) ReplaceProxy(oldAddress string, newProxy Proxy) error { tx, err := s.db.Begin() if err != nil { return err } defer tx.Rollback() // 删除旧代理 _, err = tx.Exec(`DELETE FROM proxies WHERE address = ?`, oldAddress) if err != nil { return err } // 添加新代理(带完整信息) grade := CalculateQualityGrade(newProxy.Latency) _, err = tx.Exec( `INSERT INTO proxies (address, protocol, exit_ip, exit_location, latency, quality_grade, status) VALUES (?, ?, ?, ?, ?, ?, 'active')`, newProxy.Address, newProxy.Protocol, newProxy.ExitIP, newProxy.ExitLocation, newProxy.Latency, grade, ) if err != nil { return err } return tx.Commit() } // MarkAsReplacementCandidate 标记代理为替换候选 func (s *Storage) MarkAsReplacementCandidate(addresses []string) error { if len(addresses) == 0 { return nil } placeholders := make([]string, len(addresses)) args := make([]interface{}, len(addresses)) for i, addr := range addresses { placeholders[i] = "?" args[i] = addr } query := fmt.Sprintf(`UPDATE proxies SET status = 'candidate_replace' WHERE address IN (%s)`, fmt.Sprintf("%s", placeholders)) _, err := s.db.Exec(query, args...) return err } // GetAverageLatency 获取指定协议的平均延迟 func (s *Storage) GetAverageLatency(protocol string) (int, error) { var avg sql.NullFloat64 err := s.db.QueryRow( `SELECT AVG(latency) FROM proxies WHERE protocol = ? AND status = 'active' AND latency > 0`, protocol, ).Scan(&avg) if err != nil || !avg.Valid { return 0, err } return int(avg.Float64), nil } // GetQualityDistribution 获取质量分布统计 func (s *Storage) GetQualityDistribution() (map[string]int, error) { rows, err := s.db.Query( `SELECT quality_grade, COUNT(*) as count FROM proxies WHERE status = 'active' AND fail_count < 3 GROUP BY quality_grade`, ) if err != nil { return nil, err } defer rows.Close() dist := make(map[string]int) for rows.Next() { var grade string var count int if err := rows.Scan(&grade, &count); err != nil { return nil, err } dist[grade] = count } return dist, nil } // GetBatchForHealthCheck 获取一批需要健康检查的代理 func (s *Storage) GetBatchForHealthCheck(batchSize int, skipSGrade bool) ([]Proxy, error) { query := `SELECT id, address, protocol, exit_ip, exit_location, latency, quality_grade, use_count, success_count, fail_count, last_used, last_check, created_at, status FROM proxies WHERE status IN ('active', 'degraded') AND fail_count < 3` if skipSGrade { query += ` AND quality_grade != 'S'` } query += ` ORDER BY COALESCE(last_check, '1970-01-01') ASC, quality_grade DESC LIMIT ?` rows, err := s.db.Query(query, batchSize) if err != nil { return nil, err } defer rows.Close() var proxies []Proxy for rows.Next() { p, err := scanProxy(rows) if err != nil { return nil, err } proxies = append(proxies, *p) } return proxies, nil } // CalculateQualityGrade 根据延迟计算质量等级 func CalculateQualityGrade(latencyMs int) string { switch { case latencyMs <= 500: return "S" // 超快 case latencyMs <= 1000: return "A" // 良好 case latencyMs <= 2000: return "B" // 可用 default: return "C" // 淘汰候选 } } // DeleteInvalid 删除失败次数超过阈值的代理 func (s *Storage) DeleteInvalid(maxFailCount int) (int64, error) { res, err := s.db.Exec(`DELETE FROM proxies WHERE fail_count >= ?`, maxFailCount) if err != nil { return 0, err } return res.RowsAffected() } // DeleteBlockedCountries 删除指定国家代码出口的代理 func (s *Storage) DeleteBlockedCountries(countryCodes []string) (int64, error) { if len(countryCodes) == 0 { return 0, nil } var totalDeleted int64 for _, code := range countryCodes { // exit_location 格式:如 "CN Beijing" 或 "HK Hong Kong" // 使用 LIKE 'CODE %' 来匹配国家代码(后面有空格表示有城市信息) res, err := s.db.Exec(`DELETE FROM proxies WHERE exit_location LIKE ?`, code+" %") if err != nil { return totalDeleted, err } affected, _ := res.RowsAffected() totalDeleted += affected } return totalDeleted, nil } // DeleteWithoutExitInfo 删除没有出口信息的代理 func (s *Storage) DeleteWithoutExitInfo() (int64, error) { res, err := s.db.Exec(`DELETE FROM proxies WHERE exit_ip = '' OR exit_location = ''`) if err != nil { return 0, err } return res.RowsAffected() } // Count 返回可用代理数量 func (s *Storage) Count() (int, error) { var count int err := s.db.QueryRow( `SELECT COUNT(*) FROM proxies WHERE status IN ('active', 'degraded') AND fail_count < 3`, ).Scan(&count) return count, err } // CountByProtocol 按协议统计数量 func (s *Storage) CountByProtocol(protocol string) (int, error) { var count int err := s.db.QueryRow( `SELECT COUNT(*) FROM proxies WHERE status IN ('active', 'degraded') AND fail_count < 3 AND protocol = ?`, protocol, ).Scan(&count) return count, err } // IncrementFailCount 增加失败次数 func (s *Storage) IncrementFailCount(address string) error { _, err := s.db.Exec( `UPDATE proxies SET fail_count = fail_count + 1 WHERE address = ?`, address, ) return err } // GetByProtocol 按协议获取代理列表 func (s *Storage) GetByProtocol(protocol string) ([]Proxy, error) { rows, err := s.db.Query( `SELECT id, address, protocol, exit_ip, exit_location, latency, quality_grade, use_count, success_count, fail_count, last_used, last_check, created_at, status FROM proxies WHERE status IN ('active', 'degraded') AND fail_count < 3 AND protocol = ? ORDER BY latency ASC`, protocol, ) if err != nil { return nil, err } defer rows.Close() var proxies []Proxy for rows.Next() { p, err := scanProxy(rows) if err != nil { return nil, err } proxies = append(proxies, *p) } return proxies, nil } // Close 关闭数据库 func (s *Storage) Close() error { return s.db.Close() } // GetDB 获取数据库实例(供其他模块使用) func (s *Storage) GetDB() *sql.DB { return s.db }