Files
GoProxy/fetcher/source_manager.go
isboyjc f55209d8d3 feat: init
2026-03-29 03:31:59 +08:00

133 lines
3.6 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package fetcher
import (
"database/sql"
"log"
"sync"
"time"
)
// SourceManager is a circuit breaker for proxy sources. It keeps each
// source's health in the source_status table of db and decides whether a
// source may currently be used.
type SourceManager struct {
	db *sql.DB
	// mu guards this manager's read-then-write sequences on source_status:
	// CanUseSource takes the read lock, RecordSuccess and RecordFail take the
	// write lock.
	mu sync.RWMutex
}

// NewSourceManager returns a SourceManager backed by db. The handle is stored
// as-is; no schema is created or checked here.
func NewSourceManager(db *sql.DB) *SourceManager {
	sm := new(SourceManager)
	sm.db = db
	return sm
}
// CanUseSource reports whether the proxy source identified by url may be
// fetched from right now.
//
// A source with no row in source_status is usable. A source whose status is
// "disabled" is unusable until disabled_until has passed. Once that cooldown
// has expired, its row is reset to status 'active' with consecutive_fails = 0,
// and the source is reported usable again. A "disabled" source with a NULL
// disabled_until stays unusable. Any other status is usable.
//
// Errors other than "no such row" are logged, and the source is still treated
// as usable (fail-open), exactly as for an unknown source.
func (sm *SourceManager) CanUseSource(url string) bool {
	sm.mu.RLock()
	defer sm.mu.RUnlock()

	var status string
	var disabledUntil sql.NullTime
	err := sm.db.QueryRow(
		`SELECT status, disabled_until FROM source_status WHERE url = ?`,
		url,
	).Scan(&status, &disabledUntil)
	if err != nil {
		// Unknown source: usable by default. Other errors fail open as well,
		// but are logged instead of being silently mistaken for "no row".
		if err != sql.ErrNoRows {
			log.Printf("[source] query status for %s: %v", url, err)
		}
		return true
	}

	if status != "disabled" {
		return true
	}
	if !disabledUntil.Valid {
		// Disabled with no cooldown deadline: never re-enabled here.
		return false
	}
	if time.Now().Before(disabledUntil.Time) {
		return false // still cooling down
	}

	// The cooldown has expired, so re-activate the source. This write runs under
	// the read lock. Other CanUseSource callers can only issue the same
	// idempotent reset, and RecordSuccess/RecordFail (write lock) are excluded
	// while the read lock is held.
	if _, err := sm.db.Exec(
		`UPDATE source_status SET status = 'active', consecutive_fails = 0 WHERE url = ?`,
		url,
	); err != nil {
		log.Printf("[source] reset status for %s: %v", url, err)
	}
	return true
}
// RecordSuccess records a successful fetch from the source identified by url.
//
// It upserts the source_status row. A new row starts with success_count = 1;
// an existing row has success_count incremented. In both cases
// consecutive_fails is reset to 0, last_success is set to CURRENT_TIMESTAMP,
// and status becomes 'active', so a single success clears any degraded or
// disabled state. Database errors are logged, not returned.
func (sm *SourceManager) RecordSuccess(url string) {
	sm.mu.Lock()
	defer sm.mu.Unlock()

	if _, err := sm.db.Exec(`
		INSERT INTO source_status (url, success_count, consecutive_fails, last_success, status)
		VALUES (?, 1, 0, CURRENT_TIMESTAMP, 'active')
		ON CONFLICT(url) DO UPDATE SET
			success_count = success_count + 1,
			consecutive_fails = 0,
			last_success = CURRENT_TIMESTAMP,
			status = 'active'
	`, url); err != nil {
		log.Printf("[source] record success for %s: %v", url, err)
	}
}
// RecordFail records a failed fetch from the source identified by url and
// applies the circuit-breaker thresholds.
//
// It upserts the source_status row: fail_count and consecutive_fails are
// incremented, and last_fail is set to CURRENT_TIMESTAMP. It then re-reads
// consecutive_fails and acts on it:
//   - if consecutive_fails >= disableThreshold, status becomes 'disabled' and
//     disabled_until is set to cooldownMinutes from now;
//   - otherwise, if consecutive_fails >= failThreshold, status becomes
//     'degraded'.
//
// The disable check runs first, so if disableThreshold <= failThreshold the
// degraded state is never reached. Database errors are logged, not returned.
// If the upsert or the read-back fails, no status change is attempted.
func (sm *SourceManager) RecordFail(url string, failThreshold, disableThreshold, cooldownMinutes int) {
	sm.mu.Lock()
	defer sm.mu.Unlock()

	// Increment the failure counters.
	if _, err := sm.db.Exec(`
		INSERT INTO source_status (url, fail_count, consecutive_fails, last_fail)
		VALUES (?, 1, 1, CURRENT_TIMESTAMP)
		ON CONFLICT(url) DO UPDATE SET
			fail_count = fail_count + 1,
			consecutive_fails = consecutive_fails + 1,
			last_fail = CURRENT_TIMESTAMP
	`, url); err != nil {
		log.Printf("[source] record failure for %s: %v", url, err)
		return
	}

	// Re-read the count to decide whether to degrade or disable the source.
	var consecutiveFails int
	if err := sm.db.QueryRow(
		`SELECT consecutive_fails FROM source_status WHERE url = ?`, url,
	).Scan(&consecutiveFails); err != nil {
		log.Printf("[source] read consecutive_fails for %s: %v", url, err)
		return
	}

	switch {
	case consecutiveFails >= disableThreshold:
		disabledUntil := time.Now().Add(time.Duration(cooldownMinutes) * time.Minute)
		if _, err := sm.db.Exec(
			`UPDATE source_status SET status = 'disabled', disabled_until = ? WHERE url = ?`,
			disabledUntil, url,
		); err != nil {
			log.Printf("[source] disable %s: %v", url, err)
			return
		}
		log.Printf("[source] ⛔ 禁用源(连续失败%d次): %s (冷却%d分钟)", consecutiveFails, url, cooldownMinutes)
	case consecutiveFails >= failThreshold:
		if _, err := sm.db.Exec(
			`UPDATE source_status SET status = 'degraded' WHERE url = ?`, url,
		); err != nil {
			log.Printf("[source] degrade %s: %v", url, err)
			return
		}
		log.Printf("[source] ⚠️ 降级源(连续失败%d次): %s", consecutiveFails, url)
	}
}
// GetSourceStats returns one map per row of source_status, ordered by
// success_count descending.
//
// Each map has the keys "url", "success_count", "fail_count",
// "consecutive_fails", "last_success", "last_fail" and "status". The counts
// are ints, url and status are strings, and last_success/last_fail are
// sql.NullTime values (Valid is false when the column is NULL). The returned
// slice is nil when the table is empty.
//
// It holds the read lock so it cannot observe a half-applied RecordFail, for
// example an incremented consecutive_fails whose status has not yet been
// updated. A scan or iteration error is returned instead of yielding partial
// or half-filled rows.
func (sm *SourceManager) GetSourceStats() ([]map[string]interface{}, error) {
	sm.mu.RLock()
	defer sm.mu.RUnlock()

	rows, err := sm.db.Query(`
		SELECT url, success_count, fail_count, consecutive_fails,
		       last_success, last_fail, status
		FROM source_status
		ORDER BY success_count DESC
	`)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var stats []map[string]interface{}
	for rows.Next() {
		var (
			url, status                               string
			successCount, failCount, consecutiveFails int
			lastSuccess, lastFail                     sql.NullTime
		)
		if err := rows.Scan(&url, &successCount, &failCount, &consecutiveFails, &lastSuccess, &lastFail, &status); err != nil {
			return nil, err
		}
		stats = append(stats, map[string]interface{}{
			"url":               url,
			"success_count":     successCount,
			"fail_count":        failCount,
			"consecutive_fails": consecutiveFails,
			"last_success":      lastSuccess,
			"last_fail":         lastFail,
			"status":            status,
		})
	}
	// rows.Next returns false on error as well as at the end; distinguish them.
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return stats, nil
}