mirror of
https://github.com/isboyjc/GoProxy.git
synced 2026-05-10 17:43:11 +08:00
- Added support for importing Clash/V2ray subscriptions, including automatic format detection and integration with sing-box for protocol conversion. - Introduced five proxy usage modes in the configuration, allowing flexible selection between mixed, custom-only, and free-only modes. - Enhanced `.env.example` and `docker-compose.yml` to include new environment variables for custom proxy settings. - Updated `CHANGELOG.md` to document new features and improvements related to subscription management. - Improved WebUI for managing subscriptions and displaying proxy statistics. - Implemented a background process for refreshing subscriptions and probing disabled proxies for reactivation.
1196 lines
35 KiB
Go
1196 lines
35 KiB
Go
package storage
|
||
|
||
import (
|
||
"database/sql"
|
||
"fmt"
|
||
"log"
|
||
"math/rand"
|
||
"strings"
|
||
"time"
|
||
|
||
_ "github.com/mattn/go-sqlite3"
|
||
)
|
||
|
||
// Proxy mirrors one row of the proxies table: the proxy endpoint itself plus
// the health, quality and usage statistics tracked for it.
type Proxy struct {
	ID             int64     `json:"id"`
	Address        string    `json:"address"` // unique key of the row (UNIQUE in the schema)
	Protocol       string    `json:"protocol"`
	ExitIP         string    `json:"exit_ip"`
	ExitLocation   string    `json:"exit_location"` // e.g. "CN Beijing" or just "CN"
	Latency        int       `json:"latency"`       // milliseconds; 0 is the column default
	QualityGrade   string    `json:"quality_grade"` // "S", "A", "B" or "C" (see CalculateQualityGrade)
	UseCount       int       `json:"use_count"`
	SuccessCount   int       `json:"success_count"`
	FailCount      int       `json:"fail_count"`
	LastUsed       time.Time `json:"last_used"`  // zero when the column is NULL
	LastCheck      time.Time `json:"last_check"` // zero when the column is NULL
	CreatedAt      time.Time `json:"created_at"`
	Status         string    `json:"status"`
	Source         string    `json:"source"`          // "free" or "custom"
	SubscriptionID int64     `json:"subscription_id"` // owning subscription ID (0 = free proxy)
}
|
||
|
||
// Subscription describes one proxy subscription, i.e. a remote URL or a local
// file from which custom proxies are imported.
type Subscription struct {
	ID          int64     `json:"id"`
	Name        string    `json:"name"`
	URL         string    `json:"url"`
	FilePath    string    `json:"file_path"`
	Format      string    `json:"format"` // clash / plain / base64 / auto
	RefreshMin  int       `json:"refresh_min"`
	LastFetch   time.Time `json:"last_fetch"`
	LastSuccess time.Time `json:"last_success"` // last time the subscription yielded usable nodes
	Status      string    `json:"status"`       // active / paused
	ProxyCount  int       `json:"proxy_count"`
	CreatedAt   time.Time `json:"created_at"`
	Contributed bool      `json:"contributed"` // true when contributed by a visitor
}
|
||
|
||
// SourceStatus is the health record of one proxy source URL, mirroring a row
// of the source_status table.
type SourceStatus struct {
	ID               int64
	URL              string
	SuccessCount     int
	FailCount        int
	ConsecutiveFails int
	LastSuccess      time.Time
	LastFail         time.Time
	Status           string // active/degraded/disabled
	DisabledUntil    time.Time
}
|
||
|
||
// Storage wraps the SQLite database handle that backs the proxy pool.
// Create instances with New.
type Storage struct {
	db *sql.DB
}
|
||
|
||
func New(dbPath string) (*Storage, error) {
|
||
db, err := sql.Open("sqlite3", dbPath)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("open db: %w", err)
|
||
}
|
||
|
||
db.SetMaxOpenConns(1) // SQLite 单写
|
||
|
||
s := &Storage{db: db}
|
||
if err := s.initSchema(); err != nil {
|
||
return nil, err
|
||
}
|
||
return s, nil
|
||
}
|
||
|
||
func (s *Storage) initSchema() error {
|
||
// 创建代理表
|
||
_, err := s.db.Exec(`
|
||
CREATE TABLE IF NOT EXISTS proxies (
|
||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||
address TEXT NOT NULL UNIQUE,
|
||
protocol TEXT NOT NULL,
|
||
exit_ip TEXT NOT NULL DEFAULT '',
|
||
exit_location TEXT NOT NULL DEFAULT '',
|
||
latency INTEGER NOT NULL DEFAULT 0,
|
||
quality_grade TEXT NOT NULL DEFAULT 'C',
|
||
use_count INTEGER NOT NULL DEFAULT 0,
|
||
success_count INTEGER NOT NULL DEFAULT 0,
|
||
fail_count INTEGER NOT NULL DEFAULT 0,
|
||
last_used DATETIME,
|
||
last_check DATETIME,
|
||
created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||
status TEXT NOT NULL DEFAULT 'active'
|
||
)
|
||
`)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
// 创建索引
|
||
s.db.Exec(`CREATE INDEX IF NOT EXISTS idx_protocol_latency ON proxies(protocol, latency)`)
|
||
s.db.Exec(`CREATE INDEX IF NOT EXISTS idx_quality_grade ON proxies(quality_grade, latency)`)
|
||
s.db.Exec(`CREATE INDEX IF NOT EXISTS idx_status ON proxies(status)`)
|
||
|
||
// 创建源状态表
|
||
_, err = s.db.Exec(`
|
||
CREATE TABLE IF NOT EXISTS source_status (
|
||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||
url TEXT NOT NULL UNIQUE,
|
||
success_count INTEGER NOT NULL DEFAULT 0,
|
||
fail_count INTEGER NOT NULL DEFAULT 0,
|
||
consecutive_fails INTEGER NOT NULL DEFAULT 0,
|
||
last_success DATETIME,
|
||
last_fail DATETIME,
|
||
status TEXT NOT NULL DEFAULT 'active',
|
||
disabled_until DATETIME
|
||
)
|
||
`)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
// 迁移:处理旧的 location 字段(如果存在)
|
||
var hasOldLocation int
|
||
err = s.db.QueryRow(`SELECT COUNT(*) FROM pragma_table_info('proxies') WHERE name='location'`).Scan(&hasOldLocation)
|
||
if err == nil && hasOldLocation > 0 {
|
||
log.Println("[storage] migrating: renaming location to exit_location")
|
||
// 如果有旧的 location 字段,先添加新字段再复制数据
|
||
s.db.Exec(`ALTER TABLE proxies ADD COLUMN exit_location TEXT NOT NULL DEFAULT ''`)
|
||
s.db.Exec(`UPDATE proxies SET exit_location = location WHERE location != ''`)
|
||
}
|
||
|
||
// 迁移:添加 exit_ip 字段
|
||
var hasExitIP int
|
||
err = s.db.QueryRow(`SELECT COUNT(*) FROM pragma_table_info('proxies') WHERE name='exit_ip'`).Scan(&hasExitIP)
|
||
if err == nil && hasExitIP == 0 {
|
||
log.Println("[storage] migrating: adding exit_ip column")
|
||
_, err = s.db.Exec(`ALTER TABLE proxies ADD COLUMN exit_ip TEXT NOT NULL DEFAULT ''`)
|
||
if err != nil {
|
||
return fmt.Errorf("migrate exit_ip column: %w", err)
|
||
}
|
||
}
|
||
|
||
// 迁移:添加 exit_location 字段
|
||
var hasExitLocation int
|
||
err = s.db.QueryRow(`SELECT COUNT(*) FROM pragma_table_info('proxies') WHERE name='exit_location'`).Scan(&hasExitLocation)
|
||
if err == nil && hasExitLocation == 0 {
|
||
log.Println("[storage] migrating: adding exit_location column")
|
||
_, err = s.db.Exec(`ALTER TABLE proxies ADD COLUMN exit_location TEXT NOT NULL DEFAULT ''`)
|
||
if err != nil {
|
||
return fmt.Errorf("migrate exit_location column: %w", err)
|
||
}
|
||
}
|
||
|
||
// 迁移:添加 latency 字段
|
||
var hasLatency int
|
||
err = s.db.QueryRow(`SELECT COUNT(*) FROM pragma_table_info('proxies') WHERE name='latency'`).Scan(&hasLatency)
|
||
if err == nil && hasLatency == 0 {
|
||
log.Println("[storage] migrating: adding latency column")
|
||
s.db.Exec(`ALTER TABLE proxies ADD COLUMN latency INTEGER NOT NULL DEFAULT 0`)
|
||
}
|
||
|
||
// 迁移:添加质量等级字段
|
||
var hasQuality int
|
||
s.db.QueryRow(`SELECT COUNT(*) FROM pragma_table_info('proxies') WHERE name='quality_grade'`).Scan(&hasQuality)
|
||
if hasQuality == 0 {
|
||
log.Println("[storage] migrating: adding quality_grade column")
|
||
s.db.Exec(`ALTER TABLE proxies ADD COLUMN quality_grade TEXT NOT NULL DEFAULT 'C'`)
|
||
}
|
||
|
||
// 迁移:添加使用统计字段
|
||
var hasUseCount int
|
||
s.db.QueryRow(`SELECT COUNT(*) FROM pragma_table_info('proxies') WHERE name='use_count'`).Scan(&hasUseCount)
|
||
if hasUseCount == 0 {
|
||
log.Println("[storage] migrating: adding usage tracking columns")
|
||
s.db.Exec(`ALTER TABLE proxies ADD COLUMN use_count INTEGER NOT NULL DEFAULT 0`)
|
||
s.db.Exec(`ALTER TABLE proxies ADD COLUMN success_count INTEGER NOT NULL DEFAULT 0`)
|
||
s.db.Exec(`ALTER TABLE proxies ADD COLUMN last_used DATETIME`)
|
||
}
|
||
|
||
// 迁移:添加状态字段
|
||
var hasStatus int
|
||
s.db.QueryRow(`SELECT COUNT(*) FROM pragma_table_info('proxies') WHERE name='status'`).Scan(&hasStatus)
|
||
if hasStatus == 0 {
|
||
log.Println("[storage] migrating: adding status column")
|
||
s.db.Exec(`ALTER TABLE proxies ADD COLUMN status TEXT NOT NULL DEFAULT 'active'`)
|
||
}
|
||
|
||
// 迁移:添加 source 字段(区分免费代理和订阅代理)
|
||
var hasSource int
|
||
s.db.QueryRow(`SELECT COUNT(*) FROM pragma_table_info('proxies') WHERE name='source'`).Scan(&hasSource)
|
||
if hasSource == 0 {
|
||
log.Println("[storage] migrating: adding source column")
|
||
s.db.Exec(`ALTER TABLE proxies ADD COLUMN source TEXT NOT NULL DEFAULT 'free'`)
|
||
}
|
||
s.db.Exec(`CREATE INDEX IF NOT EXISTS idx_source ON proxies(source, status)`)
|
||
|
||
// 迁移:添加 subscription_id 字段
|
||
var hasSubID int
|
||
s.db.QueryRow(`SELECT COUNT(*) FROM pragma_table_info('proxies') WHERE name='subscription_id'`).Scan(&hasSubID)
|
||
if hasSubID == 0 {
|
||
log.Println("[storage] migrating: adding subscription_id column")
|
||
s.db.Exec(`ALTER TABLE proxies ADD COLUMN subscription_id INTEGER NOT NULL DEFAULT 0`)
|
||
}
|
||
|
||
// 创建订阅表
|
||
_, err = s.db.Exec(`
|
||
CREATE TABLE IF NOT EXISTS subscriptions (
|
||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||
name TEXT NOT NULL DEFAULT '',
|
||
url TEXT NOT NULL DEFAULT '',
|
||
file_path TEXT NOT NULL DEFAULT '',
|
||
format TEXT NOT NULL DEFAULT 'clash',
|
||
refresh_min INTEGER NOT NULL DEFAULT 60,
|
||
last_fetch DATETIME,
|
||
status TEXT NOT NULL DEFAULT 'active',
|
||
proxy_count INTEGER NOT NULL DEFAULT 0,
|
||
created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP
|
||
)
|
||
`)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
// 迁移:订阅表添加 contributed 和 last_success 字段
|
||
var hasContributed int
|
||
s.db.QueryRow(`SELECT COUNT(*) FROM pragma_table_info('subscriptions') WHERE name='contributed'`).Scan(&hasContributed)
|
||
if hasContributed == 0 {
|
||
s.db.Exec(`ALTER TABLE subscriptions ADD COLUMN contributed INTEGER NOT NULL DEFAULT 0`)
|
||
}
|
||
var hasLastSuccess int
|
||
s.db.QueryRow(`SELECT COUNT(*) FROM pragma_table_info('subscriptions') WHERE name='last_success'`).Scan(&hasLastSuccess)
|
||
if hasLastSuccess == 0 {
|
||
s.db.Exec(`ALTER TABLE subscriptions ADD COLUMN last_success DATETIME`)
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// AddProxy 新增免费代理,已存在则忽略
|
||
func (s *Storage) AddProxy(address, protocol string) error {
|
||
result, err := s.db.Exec(
|
||
`INSERT OR IGNORE INTO proxies (address, protocol, source) VALUES (?, ?, 'free')`,
|
||
address, protocol,
|
||
)
|
||
if err != nil {
|
||
log.Printf("[storage] AddProxy %s error: %v", address, err)
|
||
return err
|
||
}
|
||
|
||
// 检查是否真的插入了
|
||
affected, _ := result.RowsAffected()
|
||
if affected == 0 {
|
||
log.Printf("[storage] AddProxy %s ignored (already exists or constraint)", address)
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// AddProxies 批量新增
|
||
func (s *Storage) AddProxies(proxies []Proxy) error {
|
||
tx, err := s.db.Begin()
|
||
if err != nil {
|
||
return err
|
||
}
|
||
stmt, err := tx.Prepare(`INSERT OR IGNORE INTO proxies (address, protocol) VALUES (?, ?)`)
|
||
if err != nil {
|
||
tx.Rollback()
|
||
return err
|
||
}
|
||
defer stmt.Close()
|
||
|
||
for _, p := range proxies {
|
||
if _, err := stmt.Exec(p.Address, p.Protocol); err != nil {
|
||
log.Printf("insert proxy %s error: %v", p.Address, err)
|
||
}
|
||
}
|
||
return tx.Commit()
|
||
}
|
||
|
||
// GetRandom 随机取一个可用代理(优先选择质量高的)
|
||
func (s *Storage) GetRandom() (*Proxy, error) {
|
||
rows, err := s.db.Query(
|
||
`SELECT `+proxyColumns+`
|
||
FROM proxies
|
||
WHERE status = 'active' AND fail_count < 3
|
||
ORDER BY
|
||
CASE quality_grade
|
||
WHEN 'S' THEN 1
|
||
WHEN 'A' THEN 2
|
||
WHEN 'B' THEN 3
|
||
ELSE 4
|
||
END,
|
||
RANDOM()
|
||
LIMIT 1`,
|
||
)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
defer rows.Close()
|
||
|
||
if rows.Next() {
|
||
return scanProxy(rows)
|
||
}
|
||
return nil, fmt.Errorf("no available proxy")
|
||
}
|
||
|
||
// proxyColumns is the column list shared by every SELECT whose rows are read
// with scanProxy. The order of the columns must match the order of the Scan
// destinations in scanProxy.
const proxyColumns = `id, address, protocol, exit_ip, exit_location, latency, quality_grade,
	use_count, success_count, fail_count, last_used, last_check, created_at, status, source, subscription_id`
|
||
|
||
// scanProxy 扫描代理行数据
|
||
func scanProxy(rows *sql.Rows) (*Proxy, error) {
|
||
p := &Proxy{}
|
||
var lastUsed, lastCheck sql.NullTime
|
||
var source sql.NullString
|
||
var subID sql.NullInt64
|
||
if err := rows.Scan(&p.ID, &p.Address, &p.Protocol, &p.ExitIP, &p.ExitLocation,
|
||
&p.Latency, &p.QualityGrade, &p.UseCount, &p.SuccessCount, &p.FailCount,
|
||
&lastUsed, &lastCheck, &p.CreatedAt, &p.Status, &source, &subID); err != nil {
|
||
return nil, err
|
||
}
|
||
if lastUsed.Valid {
|
||
p.LastUsed = lastUsed.Time
|
||
}
|
||
if lastCheck.Valid {
|
||
p.LastCheck = lastCheck.Time
|
||
}
|
||
if source.Valid {
|
||
p.Source = source.String
|
||
} else {
|
||
p.Source = "free"
|
||
}
|
||
if subID.Valid {
|
||
p.SubscriptionID = subID.Int64
|
||
}
|
||
return p, nil
|
||
}
|
||
|
||
// GetAll returns every usable proxy regardless of its source. It is
// GetAllFiltered with an empty source filter.
func (s *Storage) GetAll() ([]Proxy, error) {
	return s.GetAllFiltered("")
}
|
||
|
||
// GetAllFiltered 获取可用代理(可按来源过滤)
|
||
// sourceFilter: "" = 全部, "free" = 仅免费, "custom" = 仅订阅
|
||
func (s *Storage) GetAllFiltered(sourceFilter string) ([]Proxy, error) {
|
||
query := `SELECT ` + proxyColumns + `
|
||
FROM proxies
|
||
WHERE status IN ('active', 'degraded') AND fail_count < 3`
|
||
var args []interface{}
|
||
if sourceFilter != "" {
|
||
query += ` AND source = ?`
|
||
args = append(args, sourceFilter)
|
||
}
|
||
query += ` ORDER BY latency ASC`
|
||
|
||
rows, err := s.db.Query(query, args...)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
defer rows.Close()
|
||
|
||
var proxies []Proxy
|
||
for rows.Next() {
|
||
p, err := scanProxy(rows)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
proxies = append(proxies, *p)
|
||
}
|
||
return proxies, nil
|
||
}
|
||
|
||
// GetRandomExclude returns a random usable proxy whose address is not listed
// in excludes, considering proxies from every source. It is
// GetRandomExcludeFiltered with an empty source filter.
func (s *Storage) GetRandomExclude(excludes []string) (*Proxy, error) {
	return s.GetRandomExcludeFiltered(excludes, "")
}
|
||
|
||
// GetRandomExcludeFiltered 排除指定地址随机取一个(可按来源过滤)
|
||
func (s *Storage) GetRandomExcludeFiltered(excludes []string, sourceFilter string) (*Proxy, error) {
|
||
proxies, err := s.GetAllFiltered(sourceFilter)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
excludeMap := make(map[string]bool)
|
||
for _, e := range excludes {
|
||
excludeMap[e] = true
|
||
}
|
||
|
||
var available []Proxy
|
||
for _, p := range proxies {
|
||
if !excludeMap[p.Address] {
|
||
available = append(available, p)
|
||
}
|
||
}
|
||
|
||
if len(available) == 0 {
|
||
if sourceFilter != "" {
|
||
return nil, fmt.Errorf("no available %s proxy", sourceFilter)
|
||
}
|
||
return s.GetRandom()
|
||
}
|
||
|
||
p := available[rand.Intn(len(available))]
|
||
return &p, nil
|
||
}
|
||
|
||
// GetLowestLatencyExclude returns the usable proxy with the lowest latency
// whose address is not listed in excludes, considering every source. It is
// GetLowestLatencyExcludeFiltered with an empty source filter.
func (s *Storage) GetLowestLatencyExclude(excludes []string) (*Proxy, error) {
	return s.GetLowestLatencyExcludeFiltered(excludes, "")
}
|
||
|
||
// GetLowestLatencyExcludeFiltered 排除指定地址后获取延迟最低的代理(可按来源过滤)
|
||
func (s *Storage) GetLowestLatencyExcludeFiltered(excludes []string, sourceFilter string) (*Proxy, error) {
|
||
proxies, err := s.GetAllFiltered(sourceFilter)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
excludeMap := make(map[string]bool)
|
||
for _, e := range excludes {
|
||
excludeMap[e] = true
|
||
}
|
||
|
||
for _, p := range proxies {
|
||
if !excludeMap[p.Address] {
|
||
proxy := p
|
||
return &proxy, nil
|
||
}
|
||
}
|
||
|
||
return nil, fmt.Errorf("no available proxy")
|
||
}
|
||
|
||
// GetRandomByProtocolExclude returns a random usable proxy of the given
// protocol whose address is not listed in excludes (addresses already tried),
// considering every source. It is GetRandomByProtocolExcludeFiltered with an
// empty source filter.
func (s *Storage) GetRandomByProtocolExclude(protocol string, excludes []string) (*Proxy, error) {
	return s.GetRandomByProtocolExcludeFiltered(protocol, excludes, "")
}
|
||
|
||
// GetRandomByProtocolExcludeFiltered 按协议获取随机代理(可按来源过滤)
|
||
func (s *Storage) GetRandomByProtocolExcludeFiltered(protocol string, excludes []string, sourceFilter string) (*Proxy, error) {
|
||
proxies, err := s.GetAllFiltered(sourceFilter)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
excludeMap := make(map[string]bool)
|
||
for _, e := range excludes {
|
||
excludeMap[e] = true
|
||
}
|
||
|
||
var available []Proxy
|
||
for _, p := range proxies {
|
||
if p.Protocol == protocol && !excludeMap[p.Address] {
|
||
available = append(available, p)
|
||
}
|
||
}
|
||
|
||
if len(available) == 0 {
|
||
return nil, fmt.Errorf("no %s proxy available", protocol)
|
||
}
|
||
|
||
proxy := available[time.Now().UnixNano()%int64(len(available))]
|
||
return &proxy, nil
|
||
}
|
||
|
||
// GetLowestLatencyByProtocolExclude returns the lowest-latency usable proxy of
// the given protocol whose address is not listed in excludes (addresses
// already tried), considering every source. It is
// GetLowestLatencyByProtocolExcludeFiltered with an empty source filter.
func (s *Storage) GetLowestLatencyByProtocolExclude(protocol string, excludes []string) (*Proxy, error) {
	return s.GetLowestLatencyByProtocolExcludeFiltered(protocol, excludes, "")
}
|
||
|
||
// GetLowestLatencyByProtocolExcludeFiltered 按协议获取最低延迟代理(可按来源过滤)
|
||
func (s *Storage) GetLowestLatencyByProtocolExcludeFiltered(protocol string, excludes []string, sourceFilter string) (*Proxy, error) {
|
||
proxies, err := s.GetAllFiltered(sourceFilter)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
excludeMap := make(map[string]bool)
|
||
for _, e := range excludes {
|
||
excludeMap[e] = true
|
||
}
|
||
|
||
for _, p := range proxies {
|
||
if p.Protocol == protocol && !excludeMap[p.Address] {
|
||
proxy := p
|
||
return &proxy, nil
|
||
}
|
||
}
|
||
|
||
return nil, fmt.Errorf("no %s proxy available", protocol)
|
||
}
|
||
|
||
// Delete 立即删除指定代理
|
||
func (s *Storage) Delete(address string) error {
|
||
_, err := s.db.Exec(`DELETE FROM proxies WHERE address = ?`, address)
|
||
return err
|
||
}
|
||
|
||
// IncrFail 增加失败次数
|
||
func (s *Storage) IncrFail(address string) error {
|
||
_, err := s.db.Exec(
|
||
`UPDATE proxies SET fail_count = fail_count + 1, last_check = CURRENT_TIMESTAMP WHERE address = ?`,
|
||
address,
|
||
)
|
||
return err
|
||
}
|
||
|
||
// ResetFail 重置失败次数(验证通过)
|
||
func (s *Storage) ResetFail(address string) error {
|
||
_, err := s.db.Exec(
|
||
`UPDATE proxies SET fail_count = 0, last_check = CURRENT_TIMESTAMP WHERE address = ?`,
|
||
address,
|
||
)
|
||
return err
|
||
}
|
||
|
||
// UpdateLatency 更新代理的延迟信息(毫秒)
|
||
func (s *Storage) UpdateLatency(address string, latencyMs int) error {
|
||
_, err := s.db.Exec(
|
||
`UPDATE proxies SET latency = ? WHERE address = ?`,
|
||
latencyMs, address,
|
||
)
|
||
return err
|
||
}
|
||
|
||
// UpdateExitInfo 更新代理的出口 IP、位置和质量等级
|
||
func (s *Storage) UpdateExitInfo(address, exitIP, exitLocation string, latencyMs int) error {
|
||
grade := CalculateQualityGrade(latencyMs)
|
||
_, err := s.db.Exec(
|
||
`UPDATE proxies SET exit_ip = ?, exit_location = ?, latency = ?, quality_grade = ? WHERE address = ?`,
|
||
exitIP, exitLocation, latencyMs, grade, address,
|
||
)
|
||
return err
|
||
}
|
||
|
||
// RecordProxyUse 记录代理使用(成功)
|
||
func (s *Storage) RecordProxyUse(address string, success bool) error {
|
||
if success {
|
||
_, err := s.db.Exec(
|
||
`UPDATE proxies SET use_count = use_count + 1, success_count = success_count + 1,
|
||
last_used = CURRENT_TIMESTAMP WHERE address = ?`,
|
||
address,
|
||
)
|
||
return err
|
||
}
|
||
_, err := s.db.Exec(
|
||
`UPDATE proxies SET use_count = use_count + 1, fail_count = fail_count + 1,
|
||
last_used = CURRENT_TIMESTAMP WHERE address = ?`,
|
||
address,
|
||
)
|
||
return err
|
||
}
|
||
|
||
// GetWorstProxies 获取指定协议中延迟最高的N个代理(仅免费代理)
|
||
func (s *Storage) GetWorstProxies(protocol string, limit int) ([]Proxy, error) {
|
||
rows, err := s.db.Query(
|
||
`SELECT `+proxyColumns+`
|
||
FROM proxies
|
||
WHERE protocol = ? AND status = 'active' AND source = 'free'
|
||
AND quality_grade != 'S'
|
||
AND (JULIANDAY('now') - JULIANDAY(created_at)) * 1440 > 60
|
||
ORDER BY latency DESC, fail_count DESC
|
||
LIMIT ?`, protocol, limit,
|
||
)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
defer rows.Close()
|
||
|
||
var proxies []Proxy
|
||
for rows.Next() {
|
||
p, err := scanProxy(rows)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
proxies = append(proxies, *p)
|
||
}
|
||
return proxies, nil
|
||
}
|
||
|
||
// ReplaceProxy 替换代理(删除旧的,添加新的)
|
||
func (s *Storage) ReplaceProxy(oldAddress string, newProxy Proxy) error {
|
||
tx, err := s.db.Begin()
|
||
if err != nil {
|
||
return err
|
||
}
|
||
defer tx.Rollback()
|
||
|
||
// 删除旧代理
|
||
_, err = tx.Exec(`DELETE FROM proxies WHERE address = ?`, oldAddress)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
// 添加新代理(带完整信息)
|
||
grade := CalculateQualityGrade(newProxy.Latency)
|
||
source := newProxy.Source
|
||
if source == "" {
|
||
source = "free"
|
||
}
|
||
_, err = tx.Exec(
|
||
`INSERT INTO proxies (address, protocol, exit_ip, exit_location, latency, quality_grade, status, source)
|
||
VALUES (?, ?, ?, ?, ?, ?, 'active', ?)`,
|
||
newProxy.Address, newProxy.Protocol, newProxy.ExitIP, newProxy.ExitLocation, newProxy.Latency, grade, source,
|
||
)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
return tx.Commit()
|
||
}
|
||
|
||
// MarkAsReplacementCandidate 标记代理为替换候选
|
||
func (s *Storage) MarkAsReplacementCandidate(addresses []string) error {
|
||
if len(addresses) == 0 {
|
||
return nil
|
||
}
|
||
placeholders := make([]string, len(addresses))
|
||
args := make([]interface{}, len(addresses))
|
||
for i, addr := range addresses {
|
||
placeholders[i] = "?"
|
||
args[i] = addr
|
||
}
|
||
query := fmt.Sprintf(`UPDATE proxies SET status = 'candidate_replace' WHERE address IN (%s)`,
|
||
fmt.Sprintf("%s", placeholders))
|
||
_, err := s.db.Exec(query, args...)
|
||
return err
|
||
}
|
||
|
||
// GetAverageLatency 获取指定协议的平均延迟
|
||
func (s *Storage) GetAverageLatency(protocol string) (int, error) {
|
||
var avg sql.NullFloat64
|
||
err := s.db.QueryRow(
|
||
`SELECT AVG(latency) FROM proxies WHERE protocol = ? AND status = 'active' AND latency > 0`,
|
||
protocol,
|
||
).Scan(&avg)
|
||
if err != nil || !avg.Valid {
|
||
return 0, err
|
||
}
|
||
return int(avg.Float64), nil
|
||
}
|
||
|
||
// GetQualityDistribution 获取质量分布统计
|
||
func (s *Storage) GetQualityDistribution() (map[string]int, error) {
|
||
rows, err := s.db.Query(
|
||
`SELECT quality_grade, COUNT(*) as count
|
||
FROM proxies
|
||
WHERE status = 'active' AND fail_count < 3
|
||
GROUP BY quality_grade`,
|
||
)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
defer rows.Close()
|
||
|
||
dist := make(map[string]int)
|
||
for rows.Next() {
|
||
var grade string
|
||
var count int
|
||
if err := rows.Scan(&grade, &count); err != nil {
|
||
return nil, err
|
||
}
|
||
dist[grade] = count
|
||
}
|
||
return dist, nil
|
||
}
|
||
|
||
// GetBatchForHealthCheck 获取一批需要健康检查的代理
|
||
func (s *Storage) GetBatchForHealthCheck(batchSize int, skipSGrade bool) ([]Proxy, error) {
|
||
query := `SELECT ` + proxyColumns + `
|
||
FROM proxies
|
||
WHERE status IN ('active', 'degraded') AND fail_count < 3`
|
||
|
||
if skipSGrade {
|
||
query += ` AND quality_grade != 'S'`
|
||
}
|
||
|
||
query += ` ORDER BY
|
||
COALESCE(last_check, '1970-01-01') ASC,
|
||
quality_grade DESC
|
||
LIMIT ?`
|
||
|
||
rows, err := s.db.Query(query, batchSize)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
defer rows.Close()
|
||
|
||
var proxies []Proxy
|
||
for rows.Next() {
|
||
p, err := scanProxy(rows)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
proxies = append(proxies, *p)
|
||
}
|
||
return proxies, nil
|
||
}
|
||
|
||
// CalculateQualityGrade maps a latency in milliseconds to a quality grade:
//
//	<= 500  -> "S" (very fast)
//	<= 1000 -> "A" (good)
//	<= 2000 -> "B" (usable)
//	else    -> "C" (candidate for removal)
func CalculateQualityGrade(latencyMs int) string {
	tiers := [...]struct {
		ceiling int
		grade   string
	}{
		{500, "S"},
		{1000, "A"},
		{2000, "B"},
	}
	for _, tier := range tiers {
		if latencyMs <= tier.ceiling {
			return tier.grade
		}
	}
	return "C"
}
|
||
|
||
// DeleteInvalid 删除失败次数超过阈值的代理(仅免费代理)
|
||
func (s *Storage) DeleteInvalid(maxFailCount int) (int64, error) {
|
||
res, err := s.db.Exec(`DELETE FROM proxies WHERE fail_count >= ? AND source = 'free'`, maxFailCount)
|
||
if err != nil {
|
||
return 0, err
|
||
}
|
||
return res.RowsAffected()
|
||
}
|
||
|
||
// DeleteBlockedCountries 删除指定国家代码出口的代理
|
||
func (s *Storage) DeleteBlockedCountries(countryCodes []string) (int64, error) {
|
||
if len(countryCodes) == 0 {
|
||
return 0, nil
|
||
}
|
||
|
||
var totalDeleted int64
|
||
for _, code := range countryCodes {
|
||
// exit_location 格式:如 "CN Beijing" 或 "CN"(仅国家代码)
|
||
// 同时匹配 "CODE" 和 "CODE ..." 两种情况(仅删除免费代理)
|
||
res, err := s.db.Exec(`DELETE FROM proxies WHERE source = 'free' AND (exit_location = ? OR exit_location LIKE ?)`, code, code+" %")
|
||
if err != nil {
|
||
return totalDeleted, err
|
||
}
|
||
affected, _ := res.RowsAffected()
|
||
totalDeleted += affected
|
||
}
|
||
return totalDeleted, nil
|
||
}
|
||
|
||
// DeleteNotAllowedCountries 删除不在白名单中的代理
|
||
func (s *Storage) DeleteNotAllowedCountries(allowedCodes []string) (int64, error) {
|
||
if len(allowedCodes) == 0 {
|
||
return 0, nil
|
||
}
|
||
|
||
// 构建 WHERE 条件:exit_location 不以任何白名单国家代码开头
|
||
// 即:NOT (exit_location = 'US' OR exit_location LIKE 'US %' OR ...)
|
||
conditions := make([]string, 0, len(allowedCodes)*2)
|
||
args := make([]interface{}, 0, len(allowedCodes)*2)
|
||
for _, code := range allowedCodes {
|
||
conditions = append(conditions, "exit_location = ?", "exit_location LIKE ?")
|
||
args = append(args, code, code+" %")
|
||
}
|
||
|
||
query := `DELETE FROM proxies WHERE source = 'free' AND exit_location != '' AND NOT (` + strings.Join(conditions, " OR ") + `)`
|
||
res, err := s.db.Exec(query, args...)
|
||
if err != nil {
|
||
return 0, err
|
||
}
|
||
return res.RowsAffected()
|
||
}
|
||
|
||
// DeleteWithoutExitInfo 删除没有出口信息的代理(仅免费代理)
|
||
func (s *Storage) DeleteWithoutExitInfo() (int64, error) {
|
||
res, err := s.db.Exec(`DELETE FROM proxies WHERE source = 'free' AND (exit_ip = '' OR exit_location = '')`)
|
||
if err != nil {
|
||
return 0, err
|
||
}
|
||
return res.RowsAffected()
|
||
}
|
||
|
||
// DisableBlockedCountries 禁用订阅代理中属于被屏蔽国家的(不删除)
|
||
func (s *Storage) DisableBlockedCountries(countryCodes []string) (int64, error) {
|
||
if len(countryCodes) == 0 {
|
||
return 0, nil
|
||
}
|
||
var total int64
|
||
for _, code := range countryCodes {
|
||
res, err := s.db.Exec(
|
||
`UPDATE proxies SET status = 'disabled' WHERE source = 'custom' AND status = 'active' AND (exit_location = ? OR exit_location LIKE ?)`,
|
||
code, code+" %",
|
||
)
|
||
if err != nil {
|
||
return total, err
|
||
}
|
||
affected, _ := res.RowsAffected()
|
||
total += affected
|
||
}
|
||
return total, nil
|
||
}
|
||
|
||
// DisableNotAllowedCountries 禁用订阅代理中不在白名单的(不删除)
|
||
func (s *Storage) DisableNotAllowedCountries(allowedCodes []string) (int64, error) {
|
||
if len(allowedCodes) == 0 {
|
||
return 0, nil
|
||
}
|
||
conditions := make([]string, 0, len(allowedCodes)*2)
|
||
args := make([]interface{}, 0, len(allowedCodes)*2)
|
||
for _, code := range allowedCodes {
|
||
conditions = append(conditions, "exit_location = ?", "exit_location LIKE ?")
|
||
args = append(args, code, code+" %")
|
||
}
|
||
query := `UPDATE proxies SET status = 'disabled' WHERE source = 'custom' AND status = 'active' AND exit_location != '' AND NOT (` + strings.Join(conditions, " OR ") + `)`
|
||
res, err := s.db.Exec(query, args...)
|
||
if err != nil {
|
||
return 0, err
|
||
}
|
||
return res.RowsAffected()
|
||
}
|
||
|
||
// Count 返回可用代理数量(仅免费代理,用于 slot 计算)
|
||
func (s *Storage) Count() (int, error) {
|
||
var count int
|
||
err := s.db.QueryRow(
|
||
`SELECT COUNT(*) FROM proxies WHERE status IN ('active', 'degraded') AND fail_count < 3 AND source = 'free'`,
|
||
).Scan(&count)
|
||
return count, err
|
||
}
|
||
|
||
// CountAll 返回所有可用代理数量(免费+订阅)
|
||
func (s *Storage) CountAll() (int, error) {
|
||
var count int
|
||
err := s.db.QueryRow(
|
||
`SELECT COUNT(*) FROM proxies WHERE status IN ('active', 'degraded') AND fail_count < 3`,
|
||
).Scan(&count)
|
||
return count, err
|
||
}
|
||
|
||
// CountByProtocol 按协议统计数量(仅免费代理,用于 slot 计算)
|
||
func (s *Storage) CountByProtocol(protocol string) (int, error) {
|
||
var count int
|
||
err := s.db.QueryRow(
|
||
`SELECT COUNT(*) FROM proxies WHERE status IN ('active', 'degraded') AND fail_count < 3 AND source = 'free' AND protocol = ?`,
|
||
protocol,
|
||
).Scan(&count)
|
||
return count, err
|
||
}
|
||
|
||
// IncrementFailCount 增加失败次数
|
||
func (s *Storage) IncrementFailCount(address string) error {
|
||
_, err := s.db.Exec(
|
||
`UPDATE proxies SET fail_count = fail_count + 1 WHERE address = ?`,
|
||
address,
|
||
)
|
||
return err
|
||
}
|
||
|
||
// GetByProtocol 按协议获取代理列表
|
||
func (s *Storage) GetByProtocol(protocol string) ([]Proxy, error) {
|
||
rows, err := s.db.Query(
|
||
`SELECT `+proxyColumns+`
|
||
FROM proxies
|
||
WHERE status IN ('active', 'degraded') AND fail_count < 3 AND protocol = ?
|
||
ORDER BY latency ASC`, protocol,
|
||
)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
defer rows.Close()
|
||
|
||
var proxies []Proxy
|
||
for rows.Next() {
|
||
p, err := scanProxy(rows)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
proxies = append(proxies, *p)
|
||
}
|
||
return proxies, nil
|
||
}
|
||
|
||
// ========== Subscription proxy methods ==========
|
||
|
||
// AddProxyWithSource 新增代理并指定来源和订阅ID
|
||
func (s *Storage) AddProxyWithSource(address, protocol, source string, subscriptionID ...int64) error {
|
||
subID := int64(0)
|
||
if len(subscriptionID) > 0 {
|
||
subID = subscriptionID[0]
|
||
}
|
||
result, err := s.db.Exec(
|
||
`INSERT OR IGNORE INTO proxies (address, protocol, source, subscription_id) VALUES (?, ?, ?, ?)`,
|
||
address, protocol, source, subID,
|
||
)
|
||
if err != nil {
|
||
log.Printf("[storage] AddProxyWithSource %s error: %v", address, err)
|
||
return err
|
||
}
|
||
affected, _ := result.RowsAffected()
|
||
if affected == 0 {
|
||
// 已存在,更新 source 和 subscription_id
|
||
_, err = s.db.Exec(`UPDATE proxies SET source = ?, subscription_id = ? WHERE address = ?`, source, subID, address)
|
||
}
|
||
return err
|
||
}
|
||
|
||
// DeleteBySubscriptionID 删除指定订阅的所有代理
|
||
func (s *Storage) DeleteBySubscriptionID(subscriptionID int64) (int64, error) {
|
||
res, err := s.db.Exec(`DELETE FROM proxies WHERE subscription_id = ?`, subscriptionID)
|
||
if err != nil {
|
||
return 0, err
|
||
}
|
||
return res.RowsAffected()
|
||
}
|
||
|
||
// DisableProxy 禁用代理(软删除,用于订阅代理)
|
||
func (s *Storage) DisableProxy(address string) error {
|
||
_, err := s.db.Exec(
|
||
`UPDATE proxies SET status = 'disabled' WHERE address = ?`,
|
||
address,
|
||
)
|
||
return err
|
||
}
|
||
|
||
// EnableProxy 启用代理(从禁用状态恢复)
|
||
func (s *Storage) EnableProxy(address string) error {
|
||
_, err := s.db.Exec(
|
||
`UPDATE proxies SET status = 'active', fail_count = 0 WHERE address = ?`,
|
||
address,
|
||
)
|
||
return err
|
||
}
|
||
|
||
// GetDisabledCustomProxies 获取所有被禁用的订阅代理
|
||
func (s *Storage) GetDisabledCustomProxies() ([]Proxy, error) {
|
||
rows, err := s.db.Query(
|
||
`SELECT `+proxyColumns+`
|
||
FROM proxies
|
||
WHERE source = 'custom' AND status = 'disabled'`,
|
||
)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
defer rows.Close()
|
||
|
||
var proxies []Proxy
|
||
for rows.Next() {
|
||
p, err := scanProxy(rows)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
proxies = append(proxies, *p)
|
||
}
|
||
return proxies, nil
|
||
}
|
||
|
||
// CountBySource 按来源统计可用代理数量
|
||
func (s *Storage) CountBySource(source string) (int, error) {
|
||
var count int
|
||
err := s.db.QueryRow(
|
||
`SELECT COUNT(*) FROM proxies WHERE source = ? AND status IN ('active', 'degraded') AND fail_count < 3`,
|
||
source,
|
||
).Scan(&count)
|
||
return count, err
|
||
}
|
||
|
||
// DeleteBySource 删除指定来源的所有代理
|
||
func (s *Storage) DeleteBySource(source string) (int64, error) {
|
||
res, err := s.db.Exec(`DELETE FROM proxies WHERE source = ?`, source)
|
||
if err != nil {
|
||
return 0, err
|
||
}
|
||
return res.RowsAffected()
|
||
}
|
||
|
||
// DeleteCustomProxiesNotIn 删除不在给定地址列表中的订阅代理
|
||
func (s *Storage) DeleteCustomProxiesNotIn(addresses []string) (int64, error) {
|
||
if len(addresses) == 0 {
|
||
return s.DeleteBySource("custom")
|
||
}
|
||
placeholders := make([]string, len(addresses))
|
||
args := make([]interface{}, len(addresses))
|
||
for i, addr := range addresses {
|
||
placeholders[i] = "?"
|
||
args[i] = addr
|
||
}
|
||
query := `DELETE FROM proxies WHERE source = 'custom' AND address NOT IN (` + strings.Join(placeholders, ",") + `)`
|
||
res, err := s.db.Exec(query, args...)
|
||
if err != nil {
|
||
return 0, err
|
||
}
|
||
return res.RowsAffected()
|
||
}
|
||
|
||
// ========== 订阅 CRUD ==========
|
||
|
||
// AddSubscription 添加订阅(自动去重:相同 URL 或 file_path 不重复添加)
|
||
func (s *Storage) AddSubscription(name, url, filePath, format string, refreshMin int) (int64, error) {
|
||
// 去重检查
|
||
if url != "" {
|
||
var existID int64
|
||
err := s.db.QueryRow(`SELECT id FROM subscriptions WHERE url = ? AND url != ''`, url).Scan(&existID)
|
||
if err == nil {
|
||
return 0, fmt.Errorf("该订阅 URL 已存在")
|
||
}
|
||
}
|
||
if filePath != "" {
|
||
var existID int64
|
||
err := s.db.QueryRow(`SELECT id FROM subscriptions WHERE file_path = ? AND file_path != ''`, filePath).Scan(&existID)
|
||
if err == nil {
|
||
return 0, fmt.Errorf("该订阅文件已存在")
|
||
}
|
||
}
|
||
|
||
res, err := s.db.Exec(
|
||
`INSERT INTO subscriptions (name, url, file_path, format, refresh_min) VALUES (?, ?, ?, ?, ?)`,
|
||
name, url, filePath, format, refreshMin,
|
||
)
|
||
if err != nil {
|
||
return 0, err
|
||
}
|
||
return res.LastInsertId()
|
||
}
|
||
|
||
// CountBySubscriptionID 统计指定订阅的可用/禁用代理数
|
||
func (s *Storage) CountBySubscriptionID(subID int64) (active int, disabled int) {
|
||
s.db.QueryRow(
|
||
`SELECT COUNT(*) FROM proxies WHERE subscription_id = ? AND status IN ('active', 'degraded') AND fail_count < 3`,
|
||
subID,
|
||
).Scan(&active)
|
||
s.db.QueryRow(
|
||
`SELECT COUNT(*) FROM proxies WHERE subscription_id = ? AND status = 'disabled'`,
|
||
subID,
|
||
).Scan(&disabled)
|
||
return
|
||
}
|
||
|
||
// AddContributedSubscription 添加访客贡献的订阅
|
||
func (s *Storage) AddContributedSubscription(name, url string, refreshMin int) (int64, error) {
|
||
if url == "" {
|
||
return 0, fmt.Errorf("URL 不能为空")
|
||
}
|
||
// 去重
|
||
var existID int64
|
||
err := s.db.QueryRow(`SELECT id FROM subscriptions WHERE url = ? AND url != ''`, url).Scan(&existID)
|
||
if err == nil {
|
||
return 0, fmt.Errorf("该订阅 URL 已存在")
|
||
}
|
||
|
||
res, err := s.db.Exec(
|
||
`INSERT INTO subscriptions (name, url, format, refresh_min, contributed) VALUES (?, ?, 'auto', ?, 1)`,
|
||
name, url, refreshMin,
|
||
)
|
||
if err != nil {
|
||
return 0, err
|
||
}
|
||
return res.LastInsertId()
|
||
}
|
||
|
||
// UpdateSubscription 更新订阅
|
||
func (s *Storage) UpdateSubscription(id int64, name, url, filePath, format string, refreshMin int) error {
|
||
_, err := s.db.Exec(
|
||
`UPDATE subscriptions SET name = ?, url = ?, file_path = ?, format = ?, refresh_min = ? WHERE id = ?`,
|
||
name, url, filePath, format, refreshMin, id,
|
||
)
|
||
return err
|
||
}
|
||
|
||
// DeleteSubscription 删除订阅
|
||
func (s *Storage) DeleteSubscription(id int64) error {
|
||
_, err := s.db.Exec(`DELETE FROM subscriptions WHERE id = ?`, id)
|
||
return err
|
||
}
|
||
|
||
// GetSubscriptions 获取所有订阅
|
||
func (s *Storage) GetSubscriptions() ([]Subscription, error) {
|
||
rows, err := s.db.Query(
|
||
`SELECT ` + subColumns + `
|
||
FROM subscriptions ORDER BY created_at DESC`,
|
||
)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
defer rows.Close()
|
||
|
||
var subs []Subscription
|
||
for rows.Next() {
|
||
sub, err := scanSubscription(rows)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
subs = append(subs, *sub)
|
||
}
|
||
return subs, nil
|
||
}
|
||
|
||
// GetSubscription 获取单个订阅
|
||
func (s *Storage) GetSubscription(id int64) (*Subscription, error) {
|
||
rows, err := s.db.Query(
|
||
`SELECT ` + subColumns + `
|
||
FROM subscriptions WHERE id = ?`, id,
|
||
)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
defer rows.Close()
|
||
|
||
if rows.Next() {
|
||
return scanSubscription(rows)
|
||
}
|
||
return nil, fmt.Errorf("subscription %d not found", id)
|
||
}
|
||
|
||
// UpdateSubscriptionFetch 更新订阅的最后拉取时间和代理数
|
||
func (s *Storage) UpdateSubscriptionFetch(id int64, proxyCount int) error {
|
||
_, err := s.db.Exec(
|
||
`UPDATE subscriptions SET last_fetch = CURRENT_TIMESTAMP, proxy_count = ? WHERE id = ?`,
|
||
proxyCount, id,
|
||
)
|
||
return err
|
||
}
|
||
|
||
// UpdateSubscriptionSuccess 记录订阅最后一次有可用节点的时间
|
||
func (s *Storage) UpdateSubscriptionSuccess(id int64) error {
|
||
_, err := s.db.Exec(
|
||
`UPDATE subscriptions SET last_success = CURRENT_TIMESTAMP WHERE id = ?`, id,
|
||
)
|
||
return err
|
||
}
|
||
|
||
// GetStaleSubscriptions 获取连续 N 天无可用节点的订阅
|
||
func (s *Storage) GetStaleSubscriptions(staleDays int) ([]Subscription, error) {
|
||
rows, err := s.db.Query(
|
||
`SELECT `+subColumns+`
|
||
FROM subscriptions
|
||
WHERE status = 'active'
|
||
AND (last_success IS NULL OR JULIANDAY('now') - JULIANDAY(last_success) > ?)
|
||
AND JULIANDAY('now') - JULIANDAY(created_at) > ?`,
|
||
staleDays, staleDays,
|
||
)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
defer rows.Close()
|
||
|
||
var subs []Subscription
|
||
for rows.Next() {
|
||
sub, err := scanSubscription(rows)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
subs = append(subs, *sub)
|
||
}
|
||
return subs, nil
|
||
}
|
||
|
||
// ToggleSubscription 切换订阅状态
|
||
func (s *Storage) ToggleSubscription(id int64) error {
|
||
_, err := s.db.Exec(
|
||
`UPDATE subscriptions SET status = CASE WHEN status = 'active' THEN 'paused' ELSE 'active' END WHERE id = ?`,
|
||
id,
|
||
)
|
||
return err
|
||
}
|
||
|
||
// subColumns is the column list used by the subscriptions SELECT queries.
// Its order must match the destinations passed to rows.Scan in
// scanSubscription.
const subColumns = `id, name, url, file_path, format, refresh_min, last_fetch, last_success, status, proxy_count, created_at, contributed`
|
||
|
||
func scanSubscription(rows *sql.Rows) (*Subscription, error) {
|
||
sub := &Subscription{}
|
||
var lastFetch, lastSuccess sql.NullTime
|
||
var contributed int
|
||
if err := rows.Scan(&sub.ID, &sub.Name, &sub.URL, &sub.FilePath, &sub.Format,
|
||
&sub.RefreshMin, &lastFetch, &lastSuccess, &sub.Status, &sub.ProxyCount, &sub.CreatedAt, &contributed); err != nil {
|
||
return nil, err
|
||
}
|
||
if lastFetch.Valid {
|
||
sub.LastFetch = lastFetch.Time
|
||
}
|
||
if lastSuccess.Valid {
|
||
sub.LastSuccess = lastSuccess.Time
|
||
}
|
||
sub.Contributed = contributed == 1
|
||
return sub, nil
|
||
}
|
||
|
||
// Close 关闭数据库
|
||
func (s *Storage) Close() error {
|
||
return s.db.Close()
|
||
}
|
||
|
||
// GetDB 获取数据库实例(供其他模块使用)
|
||
func (s *Storage) GetDB() *sql.DB {
|
||
return s.db
|
||
}
|