Files
GoProxy/storage/storage.go
isboyjc a06be637e7 feat: implement geo-filtering with whitelist and blacklist support
- Added support for geo-filtering in the proxy pool, allowing configuration of allowed and blocked countries via environment variables.
- Updated `.env.example` and `docker-compose.yml` to include `ALLOWED_COUNTRIES` for whitelist functionality.
- Enhanced `CLAUDE.md`, `GEO_FILTER.md`, and `README.md` to document the new geo-filtering features and usage instructions.
- Modified proxy validation logic to prioritize whitelist over blacklist during admission checks.
- Improved WebUI to allow dynamic configuration of geo-filter settings.
2026-04-01 21:45:09 +08:00

734 lines
20 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package storage
import (
"database/sql"
"fmt"
"log"
"math/rand"
"strings"
"time"
_ "github.com/mattn/go-sqlite3"
)
type Proxy struct {
ID int64 `json:"id"`
Address string `json:"address"`
Protocol string `json:"protocol"`
ExitIP string `json:"exit_ip"`
ExitLocation string `json:"exit_location"`
Latency int `json:"latency"`
QualityGrade string `json:"quality_grade"`
UseCount int `json:"use_count"`
SuccessCount int `json:"success_count"`
FailCount int `json:"fail_count"`
LastUsed time.Time `json:"last_used"`
LastCheck time.Time `json:"last_check"`
CreatedAt time.Time `json:"created_at"`
Status string `json:"status"`
}
// SourceStatus 代理源状态
type SourceStatus struct {
ID int64
URL string
SuccessCount int
FailCount int
ConsecutiveFails int
LastSuccess time.Time
LastFail time.Time
Status string // active/degraded/disabled
DisabledUntil time.Time
}
type Storage struct {
db *sql.DB
}
func New(dbPath string) (*Storage, error) {
db, err := sql.Open("sqlite3", dbPath)
if err != nil {
return nil, fmt.Errorf("open db: %w", err)
}
db.SetMaxOpenConns(1) // SQLite 单写
s := &Storage{db: db}
if err := s.initSchema(); err != nil {
return nil, err
}
return s, nil
}
func (s *Storage) initSchema() error {
// 创建代理表
_, err := s.db.Exec(`
CREATE TABLE IF NOT EXISTS proxies (
id INTEGER PRIMARY KEY AUTOINCREMENT,
address TEXT NOT NULL UNIQUE,
protocol TEXT NOT NULL,
exit_ip TEXT NOT NULL DEFAULT '',
exit_location TEXT NOT NULL DEFAULT '',
latency INTEGER NOT NULL DEFAULT 0,
quality_grade TEXT NOT NULL DEFAULT 'C',
use_count INTEGER NOT NULL DEFAULT 0,
success_count INTEGER NOT NULL DEFAULT 0,
fail_count INTEGER NOT NULL DEFAULT 0,
last_used DATETIME,
last_check DATETIME,
created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
status TEXT NOT NULL DEFAULT 'active'
)
`)
if err != nil {
return err
}
// 创建索引
s.db.Exec(`CREATE INDEX IF NOT EXISTS idx_protocol_latency ON proxies(protocol, latency)`)
s.db.Exec(`CREATE INDEX IF NOT EXISTS idx_quality_grade ON proxies(quality_grade, latency)`)
s.db.Exec(`CREATE INDEX IF NOT EXISTS idx_status ON proxies(status)`)
// 创建源状态表
_, err = s.db.Exec(`
CREATE TABLE IF NOT EXISTS source_status (
id INTEGER PRIMARY KEY AUTOINCREMENT,
url TEXT NOT NULL UNIQUE,
success_count INTEGER NOT NULL DEFAULT 0,
fail_count INTEGER NOT NULL DEFAULT 0,
consecutive_fails INTEGER NOT NULL DEFAULT 0,
last_success DATETIME,
last_fail DATETIME,
status TEXT NOT NULL DEFAULT 'active',
disabled_until DATETIME
)
`)
if err != nil {
return err
}
// 迁移:处理旧的 location 字段(如果存在)
var hasOldLocation int
err = s.db.QueryRow(`SELECT COUNT(*) FROM pragma_table_info('proxies') WHERE name='location'`).Scan(&hasOldLocation)
if err == nil && hasOldLocation > 0 {
log.Println("[storage] migrating: renaming location to exit_location")
// 如果有旧的 location 字段,先添加新字段再复制数据
s.db.Exec(`ALTER TABLE proxies ADD COLUMN exit_location TEXT NOT NULL DEFAULT ''`)
s.db.Exec(`UPDATE proxies SET exit_location = location WHERE location != ''`)
}
// 迁移:添加 exit_ip 字段
var hasExitIP int
err = s.db.QueryRow(`SELECT COUNT(*) FROM pragma_table_info('proxies') WHERE name='exit_ip'`).Scan(&hasExitIP)
if err == nil && hasExitIP == 0 {
log.Println("[storage] migrating: adding exit_ip column")
_, err = s.db.Exec(`ALTER TABLE proxies ADD COLUMN exit_ip TEXT NOT NULL DEFAULT ''`)
if err != nil {
return fmt.Errorf("migrate exit_ip column: %w", err)
}
}
// 迁移:添加 exit_location 字段
var hasExitLocation int
err = s.db.QueryRow(`SELECT COUNT(*) FROM pragma_table_info('proxies') WHERE name='exit_location'`).Scan(&hasExitLocation)
if err == nil && hasExitLocation == 0 {
log.Println("[storage] migrating: adding exit_location column")
_, err = s.db.Exec(`ALTER TABLE proxies ADD COLUMN exit_location TEXT NOT NULL DEFAULT ''`)
if err != nil {
return fmt.Errorf("migrate exit_location column: %w", err)
}
}
// 迁移:添加 latency 字段
var hasLatency int
err = s.db.QueryRow(`SELECT COUNT(*) FROM pragma_table_info('proxies') WHERE name='latency'`).Scan(&hasLatency)
if err == nil && hasLatency == 0 {
log.Println("[storage] migrating: adding latency column")
s.db.Exec(`ALTER TABLE proxies ADD COLUMN latency INTEGER NOT NULL DEFAULT 0`)
}
// 迁移:添加质量等级字段
var hasQuality int
s.db.QueryRow(`SELECT COUNT(*) FROM pragma_table_info('proxies') WHERE name='quality_grade'`).Scan(&hasQuality)
if hasQuality == 0 {
log.Println("[storage] migrating: adding quality_grade column")
s.db.Exec(`ALTER TABLE proxies ADD COLUMN quality_grade TEXT NOT NULL DEFAULT 'C'`)
}
// 迁移:添加使用统计字段
var hasUseCount int
s.db.QueryRow(`SELECT COUNT(*) FROM pragma_table_info('proxies') WHERE name='use_count'`).Scan(&hasUseCount)
if hasUseCount == 0 {
log.Println("[storage] migrating: adding usage tracking columns")
s.db.Exec(`ALTER TABLE proxies ADD COLUMN use_count INTEGER NOT NULL DEFAULT 0`)
s.db.Exec(`ALTER TABLE proxies ADD COLUMN success_count INTEGER NOT NULL DEFAULT 0`)
s.db.Exec(`ALTER TABLE proxies ADD COLUMN last_used DATETIME`)
}
// 迁移:添加状态字段
var hasStatus int
s.db.QueryRow(`SELECT COUNT(*) FROM pragma_table_info('proxies') WHERE name='status'`).Scan(&hasStatus)
if hasStatus == 0 {
log.Println("[storage] migrating: adding status column")
s.db.Exec(`ALTER TABLE proxies ADD COLUMN status TEXT NOT NULL DEFAULT 'active'`)
}
return nil
}
// AddProxy 新增代理,已存在则忽略
func (s *Storage) AddProxy(address, protocol string) error {
result, err := s.db.Exec(
`INSERT OR IGNORE INTO proxies (address, protocol) VALUES (?, ?)`,
address, protocol,
)
if err != nil {
log.Printf("[storage] AddProxy %s error: %v", address, err)
return err
}
// 检查是否真的插入了
affected, _ := result.RowsAffected()
if affected == 0 {
log.Printf("[storage] AddProxy %s ignored (already exists or constraint)", address)
}
return nil
}
// AddProxies 批量新增
func (s *Storage) AddProxies(proxies []Proxy) error {
tx, err := s.db.Begin()
if err != nil {
return err
}
stmt, err := tx.Prepare(`INSERT OR IGNORE INTO proxies (address, protocol) VALUES (?, ?)`)
if err != nil {
tx.Rollback()
return err
}
defer stmt.Close()
for _, p := range proxies {
if _, err := stmt.Exec(p.Address, p.Protocol); err != nil {
log.Printf("insert proxy %s error: %v", p.Address, err)
}
}
return tx.Commit()
}
// GetRandom 随机取一个可用代理(优先选择质量高的)
func (s *Storage) GetRandom() (*Proxy, error) {
// 优先从 S/A 级代理中随机选择
rows, err := s.db.Query(
`SELECT id, address, protocol, exit_ip, exit_location, latency, quality_grade,
use_count, success_count, fail_count, last_used, last_check, created_at, status
FROM proxies
WHERE status = 'active' AND fail_count < 3
ORDER BY
CASE quality_grade
WHEN 'S' THEN 1
WHEN 'A' THEN 2
WHEN 'B' THEN 3
ELSE 4
END,
RANDOM()
LIMIT 1`,
)
if err != nil {
return nil, err
}
defer rows.Close()
if rows.Next() {
return scanProxy(rows)
}
return nil, fmt.Errorf("no available proxy")
}
// scanProxy 扫描代理行数据
func scanProxy(rows *sql.Rows) (*Proxy, error) {
p := &Proxy{}
var lastUsed, lastCheck sql.NullTime
if err := rows.Scan(&p.ID, &p.Address, &p.Protocol, &p.ExitIP, &p.ExitLocation,
&p.Latency, &p.QualityGrade, &p.UseCount, &p.SuccessCount, &p.FailCount,
&lastUsed, &lastCheck, &p.CreatedAt, &p.Status); err != nil {
return nil, err
}
if lastUsed.Valid {
p.LastUsed = lastUsed.Time
}
if lastCheck.Valid {
p.LastCheck = lastCheck.Time
}
return p, nil
}
// GetAll 获取所有可用代理
func (s *Storage) GetAll() ([]Proxy, error) {
rows, err := s.db.Query(
`SELECT id, address, protocol, exit_ip, exit_location, latency, quality_grade,
use_count, success_count, fail_count, last_used, last_check, created_at, status
FROM proxies
WHERE status IN ('active', 'degraded') AND fail_count < 3
ORDER BY latency ASC`,
)
if err != nil {
return nil, err
}
defer rows.Close()
var proxies []Proxy
for rows.Next() {
p, err := scanProxy(rows)
if err != nil {
return nil, err
}
proxies = append(proxies, *p)
}
return proxies, nil
}
// GetRandomExclude 排除指定地址随机取一个
func (s *Storage) GetRandomExclude(excludes []string) (*Proxy, error) {
proxies, err := s.GetAll()
if err != nil {
return nil, err
}
excludeMap := make(map[string]bool)
for _, e := range excludes {
excludeMap[e] = true
}
var available []Proxy
for _, p := range proxies {
if !excludeMap[p.Address] {
available = append(available, p)
}
}
if len(available) == 0 {
// 没有可排除的了,随机取任意一个
return s.GetRandom()
}
p := available[rand.Intn(len(available))]
return &p, nil
}
// GetLowestLatencyExclude 排除指定地址后获取延迟最低的代理
func (s *Storage) GetLowestLatencyExclude(excludes []string) (*Proxy, error) {
proxies, err := s.GetAll()
if err != nil {
return nil, err
}
excludeMap := make(map[string]bool)
for _, e := range excludes {
excludeMap[e] = true
}
// GetAll() 已经按 latency ASC 排序,找到第一个不在排除列表中的
for _, p := range proxies {
if !excludeMap[p.Address] {
proxy := p
return &proxy, nil
}
}
return nil, fmt.Errorf("no available proxy")
}
// GetRandomByProtocolExclude 按协议获取随机代理(排除已尝试的)
func (s *Storage) GetRandomByProtocolExclude(protocol string, excludes []string) (*Proxy, error) {
proxies, err := s.GetAll()
if err != nil {
return nil, err
}
excludeMap := make(map[string]bool)
for _, e := range excludes {
excludeMap[e] = true
}
var available []Proxy
for _, p := range proxies {
if p.Protocol == protocol && !excludeMap[p.Address] {
available = append(available, p)
}
}
if len(available) == 0 {
return nil, fmt.Errorf("no %s proxy available", protocol)
}
proxy := available[time.Now().UnixNano()%int64(len(available))]
return &proxy, nil
}
// GetLowestLatencyByProtocolExclude 按协议获取最低延迟代理(排除已尝试的)
func (s *Storage) GetLowestLatencyByProtocolExclude(protocol string, excludes []string) (*Proxy, error) {
proxies, err := s.GetAll()
if err != nil {
return nil, err
}
excludeMap := make(map[string]bool)
for _, e := range excludes {
excludeMap[e] = true
}
// GetAll() 已经按 latency ASC 排序,找到第一个匹配协议且不在排除列表中的
for _, p := range proxies {
if p.Protocol == protocol && !excludeMap[p.Address] {
proxy := p
return &proxy, nil
}
}
return nil, fmt.Errorf("no %s proxy available", protocol)
}
// Delete 立即删除指定代理
func (s *Storage) Delete(address string) error {
_, err := s.db.Exec(`DELETE FROM proxies WHERE address = ?`, address)
return err
}
// IncrFail 增加失败次数
func (s *Storage) IncrFail(address string) error {
_, err := s.db.Exec(
`UPDATE proxies SET fail_count = fail_count + 1, last_check = CURRENT_TIMESTAMP WHERE address = ?`,
address,
)
return err
}
// ResetFail 重置失败次数(验证通过)
func (s *Storage) ResetFail(address string) error {
_, err := s.db.Exec(
`UPDATE proxies SET fail_count = 0, last_check = CURRENT_TIMESTAMP WHERE address = ?`,
address,
)
return err
}
// UpdateLatency 更新代理的延迟信息(毫秒)
func (s *Storage) UpdateLatency(address string, latencyMs int) error {
_, err := s.db.Exec(
`UPDATE proxies SET latency = ? WHERE address = ?`,
latencyMs, address,
)
return err
}
// UpdateExitInfo 更新代理的出口 IP、位置和质量等级
func (s *Storage) UpdateExitInfo(address, exitIP, exitLocation string, latencyMs int) error {
grade := CalculateQualityGrade(latencyMs)
_, err := s.db.Exec(
`UPDATE proxies SET exit_ip = ?, exit_location = ?, latency = ?, quality_grade = ? WHERE address = ?`,
exitIP, exitLocation, latencyMs, grade, address,
)
return err
}
// RecordProxyUse 记录代理使用(成功)
func (s *Storage) RecordProxyUse(address string, success bool) error {
if success {
_, err := s.db.Exec(
`UPDATE proxies SET use_count = use_count + 1, success_count = success_count + 1,
last_used = CURRENT_TIMESTAMP WHERE address = ?`,
address,
)
return err
}
_, err := s.db.Exec(
`UPDATE proxies SET use_count = use_count + 1, fail_count = fail_count + 1,
last_used = CURRENT_TIMESTAMP WHERE address = ?`,
address,
)
return err
}
// GetWorstProxies 获取指定协议中延迟最高的N个代理
func (s *Storage) GetWorstProxies(protocol string, limit int) ([]Proxy, error) {
rows, err := s.db.Query(
`SELECT id, address, protocol, exit_ip, exit_location, latency, quality_grade,
use_count, success_count, fail_count, last_used, last_check, created_at, status
FROM proxies
WHERE protocol = ? AND status = 'active'
AND quality_grade != 'S'
AND (JULIANDAY('now') - JULIANDAY(created_at)) * 1440 > 60
ORDER BY latency DESC, fail_count DESC
LIMIT ?`, protocol, limit,
)
if err != nil {
return nil, err
}
defer rows.Close()
var proxies []Proxy
for rows.Next() {
p, err := scanProxy(rows)
if err != nil {
return nil, err
}
proxies = append(proxies, *p)
}
return proxies, nil
}
// ReplaceProxy 替换代理(删除旧的,添加新的)
func (s *Storage) ReplaceProxy(oldAddress string, newProxy Proxy) error {
tx, err := s.db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
// 删除旧代理
_, err = tx.Exec(`DELETE FROM proxies WHERE address = ?`, oldAddress)
if err != nil {
return err
}
// 添加新代理(带完整信息)
grade := CalculateQualityGrade(newProxy.Latency)
_, err = tx.Exec(
`INSERT INTO proxies (address, protocol, exit_ip, exit_location, latency, quality_grade, status)
VALUES (?, ?, ?, ?, ?, ?, 'active')`,
newProxy.Address, newProxy.Protocol, newProxy.ExitIP, newProxy.ExitLocation, newProxy.Latency, grade,
)
if err != nil {
return err
}
return tx.Commit()
}
// MarkAsReplacementCandidate 标记代理为替换候选
func (s *Storage) MarkAsReplacementCandidate(addresses []string) error {
if len(addresses) == 0 {
return nil
}
placeholders := make([]string, len(addresses))
args := make([]interface{}, len(addresses))
for i, addr := range addresses {
placeholders[i] = "?"
args[i] = addr
}
query := fmt.Sprintf(`UPDATE proxies SET status = 'candidate_replace' WHERE address IN (%s)`,
fmt.Sprintf("%s", placeholders))
_, err := s.db.Exec(query, args...)
return err
}
// GetAverageLatency 获取指定协议的平均延迟
func (s *Storage) GetAverageLatency(protocol string) (int, error) {
var avg sql.NullFloat64
err := s.db.QueryRow(
`SELECT AVG(latency) FROM proxies WHERE protocol = ? AND status = 'active' AND latency > 0`,
protocol,
).Scan(&avg)
if err != nil || !avg.Valid {
return 0, err
}
return int(avg.Float64), nil
}
// GetQualityDistribution 获取质量分布统计
func (s *Storage) GetQualityDistribution() (map[string]int, error) {
rows, err := s.db.Query(
`SELECT quality_grade, COUNT(*) as count
FROM proxies
WHERE status = 'active' AND fail_count < 3
GROUP BY quality_grade`,
)
if err != nil {
return nil, err
}
defer rows.Close()
dist := make(map[string]int)
for rows.Next() {
var grade string
var count int
if err := rows.Scan(&grade, &count); err != nil {
return nil, err
}
dist[grade] = count
}
return dist, nil
}
// GetBatchForHealthCheck 获取一批需要健康检查的代理
func (s *Storage) GetBatchForHealthCheck(batchSize int, skipSGrade bool) ([]Proxy, error) {
query := `SELECT id, address, protocol, exit_ip, exit_location, latency, quality_grade,
use_count, success_count, fail_count, last_used, last_check, created_at, status
FROM proxies
WHERE status IN ('active', 'degraded') AND fail_count < 3`
if skipSGrade {
query += ` AND quality_grade != 'S'`
}
query += ` ORDER BY
COALESCE(last_check, '1970-01-01') ASC,
quality_grade DESC
LIMIT ?`
rows, err := s.db.Query(query, batchSize)
if err != nil {
return nil, err
}
defer rows.Close()
var proxies []Proxy
for rows.Next() {
p, err := scanProxy(rows)
if err != nil {
return nil, err
}
proxies = append(proxies, *p)
}
return proxies, nil
}
// CalculateQualityGrade 根据延迟计算质量等级
func CalculateQualityGrade(latencyMs int) string {
switch {
case latencyMs <= 500:
return "S" // 超快
case latencyMs <= 1000:
return "A" // 良好
case latencyMs <= 2000:
return "B" // 可用
default:
return "C" // 淘汰候选
}
}
// DeleteInvalid 删除失败次数超过阈值的代理
func (s *Storage) DeleteInvalid(maxFailCount int) (int64, error) {
res, err := s.db.Exec(`DELETE FROM proxies WHERE fail_count >= ?`, maxFailCount)
if err != nil {
return 0, err
}
return res.RowsAffected()
}
// DeleteBlockedCountries 删除指定国家代码出口的代理
func (s *Storage) DeleteBlockedCountries(countryCodes []string) (int64, error) {
if len(countryCodes) == 0 {
return 0, nil
}
var totalDeleted int64
for _, code := range countryCodes {
// exit_location 格式:如 "CN Beijing" 或 "CN"(仅国家代码)
// 同时匹配 "CODE" 和 "CODE ..." 两种情况
res, err := s.db.Exec(`DELETE FROM proxies WHERE exit_location = ? OR exit_location LIKE ?`, code, code+" %")
if err != nil {
return totalDeleted, err
}
affected, _ := res.RowsAffected()
totalDeleted += affected
}
return totalDeleted, nil
}
// DeleteNotAllowedCountries 删除不在白名单中的代理
func (s *Storage) DeleteNotAllowedCountries(allowedCodes []string) (int64, error) {
if len(allowedCodes) == 0 {
return 0, nil
}
// 构建 WHERE 条件exit_location 不以任何白名单国家代码开头
// 即NOT (exit_location = 'US' OR exit_location LIKE 'US %' OR ...)
conditions := make([]string, 0, len(allowedCodes)*2)
args := make([]interface{}, 0, len(allowedCodes)*2)
for _, code := range allowedCodes {
conditions = append(conditions, "exit_location = ?", "exit_location LIKE ?")
args = append(args, code, code+" %")
}
query := `DELETE FROM proxies WHERE exit_location != '' AND NOT (` + strings.Join(conditions, " OR ") + `)`
res, err := s.db.Exec(query, args...)
if err != nil {
return 0, err
}
return res.RowsAffected()
}
// DeleteWithoutExitInfo 删除没有出口信息的代理
func (s *Storage) DeleteWithoutExitInfo() (int64, error) {
res, err := s.db.Exec(`DELETE FROM proxies WHERE exit_ip = '' OR exit_location = ''`)
if err != nil {
return 0, err
}
return res.RowsAffected()
}
// Count 返回可用代理数量
func (s *Storage) Count() (int, error) {
var count int
err := s.db.QueryRow(
`SELECT COUNT(*) FROM proxies WHERE status IN ('active', 'degraded') AND fail_count < 3`,
).Scan(&count)
return count, err
}
// CountByProtocol 按协议统计数量
func (s *Storage) CountByProtocol(protocol string) (int, error) {
var count int
err := s.db.QueryRow(
`SELECT COUNT(*) FROM proxies WHERE status IN ('active', 'degraded') AND fail_count < 3 AND protocol = ?`,
protocol,
).Scan(&count)
return count, err
}
// IncrementFailCount 增加失败次数
func (s *Storage) IncrementFailCount(address string) error {
_, err := s.db.Exec(
`UPDATE proxies SET fail_count = fail_count + 1 WHERE address = ?`,
address,
)
return err
}
// GetByProtocol 按协议获取代理列表
func (s *Storage) GetByProtocol(protocol string) ([]Proxy, error) {
rows, err := s.db.Query(
`SELECT id, address, protocol, exit_ip, exit_location, latency, quality_grade,
use_count, success_count, fail_count, last_used, last_check, created_at, status
FROM proxies
WHERE status IN ('active', 'degraded') AND fail_count < 3 AND protocol = ?
ORDER BY latency ASC`, protocol,
)
if err != nil {
return nil, err
}
defer rows.Close()
var proxies []Proxy
for rows.Next() {
p, err := scanProxy(rows)
if err != nil {
return nil, err
}
proxies = append(proxies, *p)
}
return proxies, nil
}
// Close 关闭数据库
func (s *Storage) Close() error {
return s.db.Close()
}
// GetDB 获取数据库实例(供其他模块使用)
func (s *Storage) GetDB() *sql.DB {
return s.db
}