mirror of
https://github.com/isboyjc/GoProxy.git
synced 2026-05-07 04:12:45 +08:00
641 lines
18 KiB
Go
641 lines
18 KiB
Go
package storage
|
||
|
||
import (
	"database/sql"
	"fmt"
	"log"
	"math/rand"
	"strings"
	"time"

	_ "github.com/mattn/go-sqlite3"
)
|
||
|
||
// Proxy is one row of the proxies table. The JSON tags fix the key names
// used when a Proxy is serialized.
type Proxy struct {
	// Identity of the proxy.
	ID       int64  `json:"id"`
	Address  string `json:"address"`
	Protocol string `json:"protocol"`

	// Exit details and measured quality. Latency is in milliseconds and
	// QualityGrade is the letter produced by CalculateQualityGrade.
	ExitIP       string `json:"exit_ip"`
	ExitLocation string `json:"exit_location"`
	Latency      int    `json:"latency"`
	QualityGrade string `json:"quality_grade"`

	// Usage and failure counters.
	UseCount     int `json:"use_count"`
	SuccessCount int `json:"success_count"`
	FailCount    int `json:"fail_count"`

	// Timestamps. LastUsed and LastCheck stay at the zero time when the
	// corresponding database column is NULL.
	LastUsed  time.Time `json:"last_used"`
	LastCheck time.Time `json:"last_check"`
	CreatedAt time.Time `json:"created_at"`

	// Status is the lifecycle state stored in the status column; the
	// queries in this file use 'active', 'degraded' and 'candidate_replace'.
	Status string `json:"status"`
}
|
||
|
||
// SourceStatus describes the health of one proxy source; its fields mirror
// the columns of the source_status table.
type SourceStatus struct {
	ID  int64
	URL string

	// Outcome counters for fetches from this source.
	SuccessCount, FailCount, ConsecutiveFails int

	// Times of the most recent successful and failed attempts.
	LastSuccess, LastFail time.Time

	// Status is one of active, degraded or disabled.
	Status        string
	DisabledUntil time.Time
}
|
||
|
||
// Storage wraps the database handle that every query in this package
// goes through.
type Storage struct {
	db *sql.DB // underlying connection pool, set by New
}
|
||
|
||
func New(dbPath string) (*Storage, error) {
|
||
db, err := sql.Open("sqlite3", dbPath)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("open db: %w", err)
|
||
}
|
||
|
||
db.SetMaxOpenConns(1) // SQLite 单写
|
||
|
||
s := &Storage{db: db}
|
||
if err := s.initSchema(); err != nil {
|
||
return nil, err
|
||
}
|
||
return s, nil
|
||
}
|
||
|
||
func (s *Storage) initSchema() error {
|
||
// 创建代理表
|
||
_, err := s.db.Exec(`
|
||
CREATE TABLE IF NOT EXISTS proxies (
|
||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||
address TEXT NOT NULL UNIQUE,
|
||
protocol TEXT NOT NULL,
|
||
exit_ip TEXT NOT NULL DEFAULT '',
|
||
exit_location TEXT NOT NULL DEFAULT '',
|
||
latency INTEGER NOT NULL DEFAULT 0,
|
||
quality_grade TEXT NOT NULL DEFAULT 'C',
|
||
use_count INTEGER NOT NULL DEFAULT 0,
|
||
success_count INTEGER NOT NULL DEFAULT 0,
|
||
fail_count INTEGER NOT NULL DEFAULT 0,
|
||
last_used DATETIME,
|
||
last_check DATETIME,
|
||
created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||
status TEXT NOT NULL DEFAULT 'active'
|
||
)
|
||
`)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
// 创建索引
|
||
s.db.Exec(`CREATE INDEX IF NOT EXISTS idx_protocol_latency ON proxies(protocol, latency)`)
|
||
s.db.Exec(`CREATE INDEX IF NOT EXISTS idx_quality_grade ON proxies(quality_grade, latency)`)
|
||
s.db.Exec(`CREATE INDEX IF NOT EXISTS idx_status ON proxies(status)`)
|
||
|
||
// 创建源状态表
|
||
_, err = s.db.Exec(`
|
||
CREATE TABLE IF NOT EXISTS source_status (
|
||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||
url TEXT NOT NULL UNIQUE,
|
||
success_count INTEGER NOT NULL DEFAULT 0,
|
||
fail_count INTEGER NOT NULL DEFAULT 0,
|
||
consecutive_fails INTEGER NOT NULL DEFAULT 0,
|
||
last_success DATETIME,
|
||
last_fail DATETIME,
|
||
status TEXT NOT NULL DEFAULT 'active',
|
||
disabled_until DATETIME
|
||
)
|
||
`)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
// 迁移:处理旧的 location 字段(如果存在)
|
||
var hasOldLocation int
|
||
err = s.db.QueryRow(`SELECT COUNT(*) FROM pragma_table_info('proxies') WHERE name='location'`).Scan(&hasOldLocation)
|
||
if err == nil && hasOldLocation > 0 {
|
||
log.Println("[storage] migrating: renaming location to exit_location")
|
||
// 如果有旧的 location 字段,先添加新字段再复制数据
|
||
s.db.Exec(`ALTER TABLE proxies ADD COLUMN exit_location TEXT NOT NULL DEFAULT ''`)
|
||
s.db.Exec(`UPDATE proxies SET exit_location = location WHERE location != ''`)
|
||
}
|
||
|
||
// 迁移:添加 exit_ip 字段
|
||
var hasExitIP int
|
||
err = s.db.QueryRow(`SELECT COUNT(*) FROM pragma_table_info('proxies') WHERE name='exit_ip'`).Scan(&hasExitIP)
|
||
if err == nil && hasExitIP == 0 {
|
||
log.Println("[storage] migrating: adding exit_ip column")
|
||
_, err = s.db.Exec(`ALTER TABLE proxies ADD COLUMN exit_ip TEXT NOT NULL DEFAULT ''`)
|
||
if err != nil {
|
||
return fmt.Errorf("migrate exit_ip column: %w", err)
|
||
}
|
||
}
|
||
|
||
// 迁移:添加 exit_location 字段
|
||
var hasExitLocation int
|
||
err = s.db.QueryRow(`SELECT COUNT(*) FROM pragma_table_info('proxies') WHERE name='exit_location'`).Scan(&hasExitLocation)
|
||
if err == nil && hasExitLocation == 0 {
|
||
log.Println("[storage] migrating: adding exit_location column")
|
||
_, err = s.db.Exec(`ALTER TABLE proxies ADD COLUMN exit_location TEXT NOT NULL DEFAULT ''`)
|
||
if err != nil {
|
||
return fmt.Errorf("migrate exit_location column: %w", err)
|
||
}
|
||
}
|
||
|
||
// 迁移:添加 latency 字段
|
||
var hasLatency int
|
||
err = s.db.QueryRow(`SELECT COUNT(*) FROM pragma_table_info('proxies') WHERE name='latency'`).Scan(&hasLatency)
|
||
if err == nil && hasLatency == 0 {
|
||
log.Println("[storage] migrating: adding latency column")
|
||
s.db.Exec(`ALTER TABLE proxies ADD COLUMN latency INTEGER NOT NULL DEFAULT 0`)
|
||
}
|
||
|
||
// 迁移:添加质量等级字段
|
||
var hasQuality int
|
||
s.db.QueryRow(`SELECT COUNT(*) FROM pragma_table_info('proxies') WHERE name='quality_grade'`).Scan(&hasQuality)
|
||
if hasQuality == 0 {
|
||
log.Println("[storage] migrating: adding quality_grade column")
|
||
s.db.Exec(`ALTER TABLE proxies ADD COLUMN quality_grade TEXT NOT NULL DEFAULT 'C'`)
|
||
}
|
||
|
||
// 迁移:添加使用统计字段
|
||
var hasUseCount int
|
||
s.db.QueryRow(`SELECT COUNT(*) FROM pragma_table_info('proxies') WHERE name='use_count'`).Scan(&hasUseCount)
|
||
if hasUseCount == 0 {
|
||
log.Println("[storage] migrating: adding usage tracking columns")
|
||
s.db.Exec(`ALTER TABLE proxies ADD COLUMN use_count INTEGER NOT NULL DEFAULT 0`)
|
||
s.db.Exec(`ALTER TABLE proxies ADD COLUMN success_count INTEGER NOT NULL DEFAULT 0`)
|
||
s.db.Exec(`ALTER TABLE proxies ADD COLUMN last_used DATETIME`)
|
||
}
|
||
|
||
// 迁移:添加状态字段
|
||
var hasStatus int
|
||
s.db.QueryRow(`SELECT COUNT(*) FROM pragma_table_info('proxies') WHERE name='status'`).Scan(&hasStatus)
|
||
if hasStatus == 0 {
|
||
log.Println("[storage] migrating: adding status column")
|
||
s.db.Exec(`ALTER TABLE proxies ADD COLUMN status TEXT NOT NULL DEFAULT 'active'`)
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// AddProxy 新增代理,已存在则忽略
|
||
func (s *Storage) AddProxy(address, protocol string) error {
|
||
_, err := s.db.Exec(
|
||
`INSERT OR IGNORE INTO proxies (address, protocol) VALUES (?, ?)`,
|
||
address, protocol,
|
||
)
|
||
return err
|
||
}
|
||
|
||
// AddProxies 批量新增
|
||
func (s *Storage) AddProxies(proxies []Proxy) error {
|
||
tx, err := s.db.Begin()
|
||
if err != nil {
|
||
return err
|
||
}
|
||
stmt, err := tx.Prepare(`INSERT OR IGNORE INTO proxies (address, protocol) VALUES (?, ?)`)
|
||
if err != nil {
|
||
tx.Rollback()
|
||
return err
|
||
}
|
||
defer stmt.Close()
|
||
|
||
for _, p := range proxies {
|
||
if _, err := stmt.Exec(p.Address, p.Protocol); err != nil {
|
||
log.Printf("insert proxy %s error: %v", p.Address, err)
|
||
}
|
||
}
|
||
return tx.Commit()
|
||
}
|
||
|
||
// GetRandom 随机取一个可用代理(优先选择质量高的)
|
||
func (s *Storage) GetRandom() (*Proxy, error) {
|
||
// 优先从 S/A 级代理中随机选择
|
||
rows, err := s.db.Query(
|
||
`SELECT id, address, protocol, exit_ip, exit_location, latency, quality_grade,
|
||
use_count, success_count, fail_count, last_used, last_check, created_at, status
|
||
FROM proxies
|
||
WHERE status = 'active' AND fail_count < 3
|
||
ORDER BY
|
||
CASE quality_grade
|
||
WHEN 'S' THEN 1
|
||
WHEN 'A' THEN 2
|
||
WHEN 'B' THEN 3
|
||
ELSE 4
|
||
END,
|
||
RANDOM()
|
||
LIMIT 1`,
|
||
)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
defer rows.Close()
|
||
|
||
if rows.Next() {
|
||
return scanProxy(rows)
|
||
}
|
||
return nil, fmt.Errorf("no available proxy")
|
||
}
|
||
|
||
// scanProxy 扫描代理行数据
|
||
func scanProxy(rows *sql.Rows) (*Proxy, error) {
|
||
p := &Proxy{}
|
||
var lastUsed, lastCheck sql.NullTime
|
||
if err := rows.Scan(&p.ID, &p.Address, &p.Protocol, &p.ExitIP, &p.ExitLocation,
|
||
&p.Latency, &p.QualityGrade, &p.UseCount, &p.SuccessCount, &p.FailCount,
|
||
&lastUsed, &lastCheck, &p.CreatedAt, &p.Status); err != nil {
|
||
return nil, err
|
||
}
|
||
if lastUsed.Valid {
|
||
p.LastUsed = lastUsed.Time
|
||
}
|
||
if lastCheck.Valid {
|
||
p.LastCheck = lastCheck.Time
|
||
}
|
||
return p, nil
|
||
}
|
||
|
||
// GetAll 获取所有可用代理
|
||
func (s *Storage) GetAll() ([]Proxy, error) {
|
||
rows, err := s.db.Query(
|
||
`SELECT id, address, protocol, exit_ip, exit_location, latency, quality_grade,
|
||
use_count, success_count, fail_count, last_used, last_check, created_at, status
|
||
FROM proxies
|
||
WHERE status IN ('active', 'degraded') AND fail_count < 3
|
||
ORDER BY latency ASC`,
|
||
)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
defer rows.Close()
|
||
|
||
var proxies []Proxy
|
||
for rows.Next() {
|
||
p, err := scanProxy(rows)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
proxies = append(proxies, *p)
|
||
}
|
||
return proxies, nil
|
||
}
|
||
|
||
// GetRandomExclude 排除指定地址随机取一个
|
||
func (s *Storage) GetRandomExclude(excludes []string) (*Proxy, error) {
|
||
proxies, err := s.GetAll()
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
excludeMap := make(map[string]bool)
|
||
for _, e := range excludes {
|
||
excludeMap[e] = true
|
||
}
|
||
|
||
var available []Proxy
|
||
for _, p := range proxies {
|
||
if !excludeMap[p.Address] {
|
||
available = append(available, p)
|
||
}
|
||
}
|
||
|
||
if len(available) == 0 {
|
||
// 没有可排除的了,随机取任意一个
|
||
return s.GetRandom()
|
||
}
|
||
|
||
p := available[rand.Intn(len(available))]
|
||
return &p, nil
|
||
}
|
||
|
||
// GetLowestLatencyExclude 排除指定地址后获取延迟最低的代理
|
||
func (s *Storage) GetLowestLatencyExclude(excludes []string) (*Proxy, error) {
|
||
proxies, err := s.GetAll()
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
excludeMap := make(map[string]bool)
|
||
for _, e := range excludes {
|
||
excludeMap[e] = true
|
||
}
|
||
|
||
// GetAll() 已经按 latency ASC 排序,找到第一个不在排除列表中的
|
||
for _, p := range proxies {
|
||
if !excludeMap[p.Address] {
|
||
proxy := p
|
||
return &proxy, nil
|
||
}
|
||
}
|
||
|
||
return nil, fmt.Errorf("no available proxy")
|
||
}
|
||
|
||
// Delete 立即删除指定代理
|
||
func (s *Storage) Delete(address string) error {
|
||
_, err := s.db.Exec(`DELETE FROM proxies WHERE address = ?`, address)
|
||
return err
|
||
}
|
||
|
||
// IncrFail 增加失败次数
|
||
func (s *Storage) IncrFail(address string) error {
|
||
_, err := s.db.Exec(
|
||
`UPDATE proxies SET fail_count = fail_count + 1, last_check = CURRENT_TIMESTAMP WHERE address = ?`,
|
||
address,
|
||
)
|
||
return err
|
||
}
|
||
|
||
// ResetFail 重置失败次数(验证通过)
|
||
func (s *Storage) ResetFail(address string) error {
|
||
_, err := s.db.Exec(
|
||
`UPDATE proxies SET fail_count = 0, last_check = CURRENT_TIMESTAMP WHERE address = ?`,
|
||
address,
|
||
)
|
||
return err
|
||
}
|
||
|
||
// UpdateLatency 更新代理的延迟信息(毫秒)
|
||
func (s *Storage) UpdateLatency(address string, latencyMs int) error {
|
||
_, err := s.db.Exec(
|
||
`UPDATE proxies SET latency = ? WHERE address = ?`,
|
||
latencyMs, address,
|
||
)
|
||
return err
|
||
}
|
||
|
||
// UpdateExitInfo 更新代理的出口 IP、位置和质量等级
|
||
func (s *Storage) UpdateExitInfo(address, exitIP, exitLocation string, latencyMs int) error {
|
||
grade := CalculateQualityGrade(latencyMs)
|
||
_, err := s.db.Exec(
|
||
`UPDATE proxies SET exit_ip = ?, exit_location = ?, latency = ?, quality_grade = ? WHERE address = ?`,
|
||
exitIP, exitLocation, latencyMs, grade, address,
|
||
)
|
||
return err
|
||
}
|
||
|
||
// RecordProxyUse 记录代理使用(成功)
|
||
func (s *Storage) RecordProxyUse(address string, success bool) error {
|
||
if success {
|
||
_, err := s.db.Exec(
|
||
`UPDATE proxies SET use_count = use_count + 1, success_count = success_count + 1,
|
||
last_used = CURRENT_TIMESTAMP WHERE address = ?`,
|
||
address,
|
||
)
|
||
return err
|
||
}
|
||
_, err := s.db.Exec(
|
||
`UPDATE proxies SET use_count = use_count + 1, fail_count = fail_count + 1,
|
||
last_used = CURRENT_TIMESTAMP WHERE address = ?`,
|
||
address,
|
||
)
|
||
return err
|
||
}
|
||
|
||
// GetWorstProxies 获取指定协议中延迟最高的N个代理
|
||
func (s *Storage) GetWorstProxies(protocol string, limit int) ([]Proxy, error) {
|
||
rows, err := s.db.Query(
|
||
`SELECT id, address, protocol, exit_ip, exit_location, latency, quality_grade,
|
||
use_count, success_count, fail_count, last_used, last_check, created_at, status
|
||
FROM proxies
|
||
WHERE protocol = ? AND status = 'active'
|
||
AND quality_grade != 'S'
|
||
AND (JULIANDAY('now') - JULIANDAY(created_at)) * 1440 > 60
|
||
ORDER BY latency DESC, fail_count DESC
|
||
LIMIT ?`, protocol, limit,
|
||
)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
defer rows.Close()
|
||
|
||
var proxies []Proxy
|
||
for rows.Next() {
|
||
p, err := scanProxy(rows)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
proxies = append(proxies, *p)
|
||
}
|
||
return proxies, nil
|
||
}
|
||
|
||
// ReplaceProxy 替换代理(删除旧的,添加新的)
|
||
func (s *Storage) ReplaceProxy(oldAddress string, newProxy Proxy) error {
|
||
tx, err := s.db.Begin()
|
||
if err != nil {
|
||
return err
|
||
}
|
||
defer tx.Rollback()
|
||
|
||
// 删除旧代理
|
||
_, err = tx.Exec(`DELETE FROM proxies WHERE address = ?`, oldAddress)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
// 添加新代理(带完整信息)
|
||
grade := CalculateQualityGrade(newProxy.Latency)
|
||
_, err = tx.Exec(
|
||
`INSERT INTO proxies (address, protocol, exit_ip, exit_location, latency, quality_grade, status)
|
||
VALUES (?, ?, ?, ?, ?, ?, 'active')`,
|
||
newProxy.Address, newProxy.Protocol, newProxy.ExitIP, newProxy.ExitLocation, newProxy.Latency, grade,
|
||
)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
return tx.Commit()
|
||
}
|
||
|
||
// MarkAsReplacementCandidate 标记代理为替换候选
|
||
func (s *Storage) MarkAsReplacementCandidate(addresses []string) error {
|
||
if len(addresses) == 0 {
|
||
return nil
|
||
}
|
||
placeholders := make([]string, len(addresses))
|
||
args := make([]interface{}, len(addresses))
|
||
for i, addr := range addresses {
|
||
placeholders[i] = "?"
|
||
args[i] = addr
|
||
}
|
||
query := fmt.Sprintf(`UPDATE proxies SET status = 'candidate_replace' WHERE address IN (%s)`,
|
||
fmt.Sprintf("%s", placeholders))
|
||
_, err := s.db.Exec(query, args...)
|
||
return err
|
||
}
|
||
|
||
// GetAverageLatency 获取指定协议的平均延迟
|
||
func (s *Storage) GetAverageLatency(protocol string) (int, error) {
|
||
var avg sql.NullFloat64
|
||
err := s.db.QueryRow(
|
||
`SELECT AVG(latency) FROM proxies WHERE protocol = ? AND status = 'active' AND latency > 0`,
|
||
protocol,
|
||
).Scan(&avg)
|
||
if err != nil || !avg.Valid {
|
||
return 0, err
|
||
}
|
||
return int(avg.Float64), nil
|
||
}
|
||
|
||
// GetQualityDistribution 获取质量分布统计
|
||
func (s *Storage) GetQualityDistribution() (map[string]int, error) {
|
||
rows, err := s.db.Query(
|
||
`SELECT quality_grade, COUNT(*) as count
|
||
FROM proxies
|
||
WHERE status = 'active' AND fail_count < 3
|
||
GROUP BY quality_grade`,
|
||
)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
defer rows.Close()
|
||
|
||
dist := make(map[string]int)
|
||
for rows.Next() {
|
||
var grade string
|
||
var count int
|
||
if err := rows.Scan(&grade, &count); err != nil {
|
||
return nil, err
|
||
}
|
||
dist[grade] = count
|
||
}
|
||
return dist, nil
|
||
}
|
||
|
||
// GetBatchForHealthCheck 获取一批需要健康检查的代理
|
||
func (s *Storage) GetBatchForHealthCheck(batchSize int, skipSGrade bool) ([]Proxy, error) {
|
||
query := `SELECT id, address, protocol, exit_ip, exit_location, latency, quality_grade,
|
||
use_count, success_count, fail_count, last_used, last_check, created_at, status
|
||
FROM proxies
|
||
WHERE status IN ('active', 'degraded') AND fail_count < 3`
|
||
|
||
if skipSGrade {
|
||
query += ` AND quality_grade != 'S'`
|
||
}
|
||
|
||
query += ` ORDER BY
|
||
COALESCE(last_check, '1970-01-01') ASC,
|
||
quality_grade DESC
|
||
LIMIT ?`
|
||
|
||
rows, err := s.db.Query(query, batchSize)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
defer rows.Close()
|
||
|
||
var proxies []Proxy
|
||
for rows.Next() {
|
||
p, err := scanProxy(rows)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
proxies = append(proxies, *p)
|
||
}
|
||
return proxies, nil
|
||
}
|
||
|
||
// CalculateQualityGrade maps a latency in milliseconds to a quality grade:
//
//	<= 500  -> "S" (very fast)
//	<= 1000 -> "A" (good)
//	<= 2000 -> "B" (usable)
//	else    -> "C" (replacement candidate)
//
// Zero and negative values satisfy the first bound and are graded "S".
func CalculateQualityGrade(latencyMs int) string {
	// Upper bounds in ascending order; the first bound that fits wins.
	tiers := [...]struct {
		ceiling int
		grade   string
	}{
		{500, "S"},
		{1000, "A"},
		{2000, "B"},
	}
	for _, tier := range tiers {
		if latencyMs <= tier.ceiling {
			return tier.grade
		}
	}
	return "C"
}
|
||
|
||
// DeleteInvalid 删除失败次数超过阈值的代理
|
||
func (s *Storage) DeleteInvalid(maxFailCount int) (int64, error) {
|
||
res, err := s.db.Exec(`DELETE FROM proxies WHERE fail_count >= ?`, maxFailCount)
|
||
if err != nil {
|
||
return 0, err
|
||
}
|
||
return res.RowsAffected()
|
||
}
|
||
|
||
// DeleteChinaMainland 删除中国大陆出口的代理(保留香港)
|
||
func (s *Storage) DeleteChinaMainland() (int64, error) {
|
||
// 删除 exit_location 以 "CN " 开头的代理(CN后面有空格表示是城市)
|
||
// 香港是 "HK Hong Kong" 或 "HK" 开头,不会被删除
|
||
res, err := s.db.Exec(`DELETE FROM proxies WHERE exit_location LIKE 'CN %'`)
|
||
if err != nil {
|
||
return 0, err
|
||
}
|
||
return res.RowsAffected()
|
||
}
|
||
|
||
// DeleteWithoutExitInfo 删除没有出口信息的代理
|
||
func (s *Storage) DeleteWithoutExitInfo() (int64, error) {
|
||
res, err := s.db.Exec(`DELETE FROM proxies WHERE exit_ip = '' OR exit_location = ''`)
|
||
if err != nil {
|
||
return 0, err
|
||
}
|
||
return res.RowsAffected()
|
||
}
|
||
|
||
// Count 返回可用代理数量
|
||
func (s *Storage) Count() (int, error) {
|
||
var count int
|
||
err := s.db.QueryRow(
|
||
`SELECT COUNT(*) FROM proxies WHERE status IN ('active', 'degraded') AND fail_count < 3`,
|
||
).Scan(&count)
|
||
return count, err
|
||
}
|
||
|
||
// CountByProtocol 按协议统计数量
|
||
func (s *Storage) CountByProtocol(protocol string) (int, error) {
|
||
var count int
|
||
err := s.db.QueryRow(
|
||
`SELECT COUNT(*) FROM proxies WHERE status IN ('active', 'degraded') AND fail_count < 3 AND protocol = ?`,
|
||
protocol,
|
||
).Scan(&count)
|
||
return count, err
|
||
}
|
||
|
||
// IncrementFailCount 增加失败次数
|
||
func (s *Storage) IncrementFailCount(address string) error {
|
||
_, err := s.db.Exec(
|
||
`UPDATE proxies SET fail_count = fail_count + 1 WHERE address = ?`,
|
||
address,
|
||
)
|
||
return err
|
||
}
|
||
|
||
// GetByProtocol 按协议获取代理列表
|
||
func (s *Storage) GetByProtocol(protocol string) ([]Proxy, error) {
|
||
rows, err := s.db.Query(
|
||
`SELECT id, address, protocol, exit_ip, exit_location, latency, quality_grade,
|
||
use_count, success_count, fail_count, last_used, last_check, created_at, status
|
||
FROM proxies
|
||
WHERE status IN ('active', 'degraded') AND fail_count < 3 AND protocol = ?
|
||
ORDER BY latency ASC`, protocol,
|
||
)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
defer rows.Close()
|
||
|
||
var proxies []Proxy
|
||
for rows.Next() {
|
||
p, err := scanProxy(rows)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
proxies = append(proxies, *p)
|
||
}
|
||
return proxies, nil
|
||
}
|
||
|
||
// Close 关闭数据库
|
||
func (s *Storage) Close() error {
|
||
return s.db.Close()
|
||
}
|
||
|
||
// GetDB 获取数据库实例(供其他模块使用)
|
||
func (s *Storage) GetDB() *sql.DB {
|
||
return s.db
|
||
}
|