Add HTTP API server for file downloads from Telegram links

Co-authored-by: krau <71133316+krau@users.noreply.github.com>
This commit is contained in:
copilot-swe-agent[bot]
2026-01-19 04:42:55 +00:00
parent 7def7f5b28
commit ac10c32215
6 changed files with 593 additions and 1 deletions

371
api/handlers.go Normal file
View File

@@ -0,0 +1,371 @@
package api
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"path"
"sync"
"time"
"github.com/charmbracelet/log"
"github.com/krau/SaveAny-Bot/client/bot"
"github.com/krau/SaveAny-Bot/common/utils/tgutil"
"github.com/krau/SaveAny-Bot/config"
"github.com/krau/SaveAny-Bot/core"
tftask "github.com/krau/SaveAny-Bot/core/tasks/tfile"
"github.com/krau/SaveAny-Bot/pkg/queue"
"github.com/krau/SaveAny-Bot/pkg/tfile"
"github.com/krau/SaveAny-Bot/storage"
"github.com/rs/xid"
)
// Request/Response types
type CreateTaskRequest struct {
TelegramURL string `json:"telegram_url"`
StorageName string `json:"storage_name,omitempty"`
DirPath string `json:"dir_path,omitempty"`
UserID int64 `json:"user_id"`
}
type CreateTaskResponse struct {
TaskID string `json:"task_id"`
Message string `json:"message"`
}
type TaskStatusResponse struct {
TaskID string `json:"task_id"`
Status string `json:"status"` // queued, running, completed, failed, canceled
Title string `json:"title"`
CreatedAt time.Time `json:"created_at"`
Error string `json:"error,omitempty"`
}
type ListTasksResponse struct {
Queued []TaskInfo `json:"queued"`
Running []TaskInfo `json:"running"`
}
type TaskInfo struct {
ID string `json:"id"`
Title string `json:"title"`
}
type ErrorResponse struct {
Error string `json:"error"`
}
// Task tracking
var (
taskStatuses = make(map[string]*taskStatus)
taskStatusesMu sync.RWMutex
)
type taskStatus struct {
ID string
Status string
Title string
CreatedAt time.Time
Error string
}
func handleHealth(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(map[string]string{"status": "ok"})
}
func handleCreateTask(w http.ResponseWriter, r *http.Request) {
var req CreateTaskRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
respondError(w, "invalid request body", http.StatusBadRequest)
return
}
// Validate request
if req.TelegramURL == "" {
respondError(w, "telegram_url is required", http.StatusBadRequest)
return
}
if req.UserID == 0 {
respondError(w, "user_id is required", http.StatusBadRequest)
return
}
logger := log.FromContext(r.Context()).WithPrefix("api")
// Get storage
var stor storage.Storage
var err error
if req.StorageName != "" {
stor, err = storage.GetStorageByUserIDAndName(r.Context(), req.UserID, req.StorageName)
if err != nil {
logger.Errorf("Failed to get storage: %v", err)
respondError(w, fmt.Sprintf("storage not found: %v", err), http.StatusBadRequest)
return
}
} else {
// Use first available storage for the user
storages := storage.GetUserStorages(r.Context(), req.UserID)
if len(storages) == 0 {
respondError(w, "no storage available for user", http.StatusBadRequest)
return
}
stor = storages[0]
}
// Parse Telegram URL
botCtx := bot.ExtContext()
if botCtx == nil {
respondError(w, "bot not initialized", http.StatusInternalServerError)
return
}
chatID, msgID, err := tgutil.ParseMessageLink(botCtx, req.TelegramURL)
if err != nil {
logger.Errorf("Failed to parse Telegram URL: %v", err)
respondError(w, fmt.Sprintf("invalid telegram URL: %v", err), http.StatusBadRequest)
return
}
// Get message from Telegram
msg, err := tgutil.GetMessageByID(botCtx, chatID, msgID)
if err != nil {
logger.Errorf("Failed to get message: %v", err)
respondError(w, fmt.Sprintf("failed to get telegram message: %v", err), http.StatusInternalServerError)
return
}
// Check if message has media
media, ok := msg.GetMedia()
if !ok {
respondError(w, "message has no media", http.StatusBadRequest)
return
}
// Create TGFile from message media
tgFile, err := tfile.FromMediaMessage(media, botCtx.Raw, msg)
if err != nil {
logger.Errorf("Failed to create TGFile: %v", err)
respondError(w, fmt.Sprintf("failed to create file from message: %v", err), http.StatusBadRequest)
return
}
// Create task
dirPath := req.DirPath
if dirPath == "" {
dirPath = "/"
}
storagePath := stor.JoinStoragePath(path.Join(dirPath, tgFile.Name()))
taskID := xid.New().String()
// Create context with bot extension
injectCtx := tgutil.ExtWithContext(r.Context(), botCtx)
task, err := tftask.NewTGFileTask(taskID, injectCtx, tgFile, stor, storagePath, &apiProgressTracker{
taskID: taskID,
})
if err != nil {
logger.Errorf("Failed to create task: %v", err)
respondError(w, fmt.Sprintf("failed to create task: %v", err), http.StatusInternalServerError)
return
}
// Track task status
trackTask(taskID, task.Title(), "queued")
// Add task to queue
if err := core.AddTask(injectCtx, task); err != nil {
logger.Errorf("Failed to add task: %v", err)
updateTaskStatus(taskID, "failed", err.Error())
respondError(w, fmt.Sprintf("failed to add task: %v", err), http.StatusInternalServerError)
return
}
// Send success response
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusCreated)
json.NewEncoder(w).Encode(CreateTaskResponse{
TaskID: taskID,
Message: "task created successfully",
})
}
func handleGetTask(w http.ResponseWriter, r *http.Request) {
taskID := r.PathValue("id")
if taskID == "" {
respondError(w, "task_id is required", http.StatusBadRequest)
return
}
taskStatusesMu.RLock()
status, exists := taskStatuses[taskID]
taskStatusesMu.RUnlock()
if !exists {
respondError(w, "task not found", http.StatusNotFound)
return
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(TaskStatusResponse{
TaskID: status.ID,
Status: status.Status,
Title: status.Title,
CreatedAt: status.CreatedAt,
Error: status.Error,
})
}
func handleListTasks(w http.ResponseWriter, r *http.Request) {
queued := core.GetQueuedTasks(r.Context())
running := core.GetRunningTasks(r.Context())
response := ListTasksResponse{
Queued: convertTaskInfos(queued),
Running: convertTaskInfos(running),
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(response)
}
func handleCancelTask(w http.ResponseWriter, r *http.Request) {
taskID := r.PathValue("id")
if taskID == "" {
respondError(w, "task_id is required", http.StatusBadRequest)
return
}
if err := core.CancelTask(r.Context(), taskID); err != nil {
respondError(w, fmt.Sprintf("failed to cancel task: %v", err), http.StatusInternalServerError)
return
}
updateTaskStatus(taskID, "canceled", "")
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]string{"message": "task canceled"})
}
// Helper functions
func respondError(w http.ResponseWriter, message string, statusCode int) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(statusCode)
json.NewEncoder(w).Encode(ErrorResponse{Error: message})
}
func trackTask(taskID, title, status string) {
taskStatusesMu.Lock()
defer taskStatusesMu.Unlock()
taskStatuses[taskID] = &taskStatus{
ID: taskID,
Status: status,
Title: title,
CreatedAt: time.Now(),
}
}
func updateTaskStatus(taskID, status, errorMsg string) {
taskStatusesMu.Lock()
defer taskStatusesMu.Unlock()
if ts, exists := taskStatuses[taskID]; exists {
ts.Status = status
ts.Error = errorMsg
}
}
func convertTaskInfos(tasks []queue.TaskInfo) []TaskInfo {
result := make([]TaskInfo, len(tasks))
for i, t := range tasks {
result[i] = TaskInfo{
ID: t.ID,
Title: t.Title,
}
}
return result
}
// apiProgressTracker implements tftask.ProgressTracker for API tasks
type apiProgressTracker struct {
taskID string
}
func (a *apiProgressTracker) OnStart(ctx context.Context, info tftask.TaskInfo) {
updateTaskStatus(a.taskID, "running", "")
}
func (a *apiProgressTracker) OnProgress(ctx context.Context, info tftask.TaskInfo, downloaded int64, total int64) {
// No-op for API tasks
}
func (a *apiProgressTracker) OnDone(ctx context.Context, info tftask.TaskInfo, err error) {
if err != nil {
updateTaskStatus(a.taskID, "failed", err.Error())
sendWebhook(a.taskID, "failed", err.Error())
} else {
updateTaskStatus(a.taskID, "completed", "")
sendWebhook(a.taskID, "completed", "")
}
}
// sendWebhook sends a callback to the configured webhook URL
func sendWebhook(taskID, status, errorMsg string) {
cfg := config.C()
if cfg.API.WebhookURL == "" {
return
}
taskStatusesMu.RLock()
ts, exists := taskStatuses[taskID]
taskStatusesMu.RUnlock()
if !exists {
return
}
payload := TaskStatusResponse{
TaskID: ts.ID,
Status: status,
Title: ts.Title,
CreatedAt: ts.CreatedAt,
Error: errorMsg,
}
body, err := json.Marshal(payload)
if err != nil {
log.Errorf("Failed to marshal webhook payload: %v", err)
return
}
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
req, err := http.NewRequestWithContext(ctx, "POST", cfg.API.WebhookURL, bytes.NewReader(body))
if err != nil {
log.Errorf("Failed to create webhook request: %v", err)
return
}
req.Header.Set("Content-Type", "application/json")
if cfg.API.Token != "" {
req.Header.Set("Authorization", "Bearer "+cfg.API.Token)
}
resp, err := http.DefaultClient.Do(req)
if err != nil {
log.Errorf("Failed to send webhook: %v", err)
return
}
defer resp.Body.Close()
if resp.StatusCode >= 400 {
body, _ := io.ReadAll(resp.Body)
log.Errorf("Webhook returned error status %d: %s", resp.StatusCode, string(body))
}
}

116
api/middleware.go Normal file
View File

@@ -0,0 +1,116 @@
package api
import (
"net"
"net/http"
"strings"
"time"
"github.com/charmbracelet/log"
"github.com/krau/SaveAny-Bot/config"
)
// authMiddleware validates API token and IP restrictions
func authMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Skip auth for health check
if r.URL.Path == "/health" {
next.ServeHTTP(w, r)
return
}
cfg := config.C()
// Check IP whitelist if configured
if len(cfg.API.TrustedIPs) > 0 {
clientIP := getClientIP(r)
if !isIPAllowed(clientIP, cfg.API.TrustedIPs) {
http.Error(w, `{"error":"forbidden: IP not allowed"}`, http.StatusForbidden)
return
}
}
// Check token if configured
if cfg.API.Token != "" {
authHeader := r.Header.Get("Authorization")
if authHeader == "" {
http.Error(w, `{"error":"unauthorized: missing token"}`, http.StatusUnauthorized)
return
}
token := strings.TrimPrefix(authHeader, "Bearer ")
if token != cfg.API.Token {
http.Error(w, `{"error":"unauthorized: invalid token"}`, http.StatusUnauthorized)
return
}
}
next.ServeHTTP(w, r)
})
}
// loggingMiddleware logs HTTP requests
func loggingMiddleware(logger *log.Logger) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
start := time.Now()
// Wrap response writer to capture status code
wrapper := &responseWriter{ResponseWriter: w, statusCode: http.StatusOK}
next.ServeHTTP(wrapper, r)
logger.Infof("%s %s %d %s", r.Method, r.URL.Path, wrapper.statusCode, time.Since(start))
})
}
}
type responseWriter struct {
http.ResponseWriter
statusCode int
}
func (rw *responseWriter) WriteHeader(code int) {
rw.statusCode = code
rw.ResponseWriter.WriteHeader(code)
}
// getClientIP extracts the real client IP from the request
func getClientIP(r *http.Request) string {
// Check X-Forwarded-For header
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
ips := strings.Split(xff, ",")
if len(ips) > 0 {
return strings.TrimSpace(ips[0])
}
}
// Check X-Real-IP header
if xri := r.Header.Get("X-Real-IP"); xri != "" {
return xri
}
// Fall back to RemoteAddr
ip, _, _ := net.SplitHostPort(r.RemoteAddr)
return ip
}
// isIPAllowed checks if the client IP is in the allowed list
func isIPAllowed(clientIP string, allowedIPs []string) bool {
for _, allowedIP := range allowedIPs {
if clientIP == allowedIP || allowedIP == "*" {
return true
}
// Support CIDR notation
if strings.Contains(allowedIP, "/") {
_, ipNet, err := net.ParseCIDR(allowedIP)
if err != nil {
continue
}
if ipNet.Contains(net.ParseIP(clientIP)) {
return true
}
}
}
return false
}

72
api/server.go Normal file
View File

@@ -0,0 +1,72 @@
package api
import (
"context"
"fmt"
"net/http"
"time"
"github.com/charmbracelet/log"
"github.com/krau/SaveAny-Bot/config"
)
var server *http.Server
// Init initializes and starts the HTTP API server
func Init(ctx context.Context) error {
cfg := config.C()
if !cfg.API.Enable {
return nil
}
logger := log.FromContext(ctx).WithPrefix("api")
mux := http.NewServeMux()
// Register API routes
registerRoutes(mux)
// Wrap with middleware
handler := loggingMiddleware(logger)(authMiddleware(mux))
server = &http.Server{
Addr: fmt.Sprintf(":%d", cfg.API.Port),
Handler: handler,
ReadTimeout: 15 * time.Second,
WriteTimeout: 15 * time.Second,
IdleTimeout: 60 * time.Second,
}
go func() {
logger.Infof("Starting API server on port %d", cfg.API.Port)
if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
logger.Errorf("API server error: %v", err)
}
}()
// Graceful shutdown on context cancellation
go func() {
<-ctx.Done()
shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := server.Shutdown(shutdownCtx); err != nil {
logger.Errorf("Failed to shutdown API server: %v", err)
} else {
logger.Info("API server stopped")
}
}()
return nil
}
func registerRoutes(mux *http.ServeMux) {
// Health check endpoint (no auth required)
mux.HandleFunc("/health", handleHealth)
// API v1 endpoints
mux.HandleFunc("POST /api/v1/tasks", handleCreateTask)
mux.HandleFunc("GET /api/v1/tasks/{id}", handleGetTask)
mux.HandleFunc("GET /api/v1/tasks", handleListTasks)
mux.HandleFunc("DELETE /api/v1/tasks/{id}", handleCancelTask)
}

View File

@@ -10,6 +10,7 @@ import (
"slices"
"github.com/charmbracelet/log"
"github.com/krau/SaveAny-Bot/api"
"github.com/krau/SaveAny-Bot/client/bot"
userclient "github.com/krau/SaveAny-Bot/client/user"
"github.com/krau/SaveAny-Bot/common/cache"
@@ -76,7 +77,11 @@ func initAll(ctx context.Context, cmd *cobra.Command) (<-chan struct{}, error) {
logger.Fatal("User login failed", "error", err)
}
}
return bot.Init(ctx), nil
exitChan := bot.Init(ctx)
if err := api.Init(ctx); err != nil {
return nil, fmt.Errorf("failed to init API server: %w", err)
}
return exitChan, nil
}
func cleanCache() {

View File

@@ -29,6 +29,19 @@ secret = ""
# 转存完成后删除 Aria2 下载的本地文件
remove_after_transfer = true
# HTTP API 配置
[api]
# 启用 HTTP API 服务
enable = false
# API 服务监听端口
port = 8080
# API 访问令牌 (留空则不验证)
token = ""
# 任务完成回调 Webhook URL (留空则不回调)
webhook_url = ""
# 可信任的 IP 地址列表 (留空则不限制), 支持单个 IP 或 CIDR 格式
# trusted_ips = ["127.0.0.1", "192.168.1.0/24"]
# 存储列表
[[storages]]
# 标识名, 需要唯一

View File

@@ -33,6 +33,7 @@ type Config struct {
Storages []storage.StorageConfig `toml:"-" mapstructure:"-" json:"storages"`
Parser parserConfig `toml:"parser" mapstructure:"parser" json:"parser"`
Hook hookConfig `toml:"hook" mapstructure:"hook" json:"hook"`
API apiConfig `toml:"api" mapstructure:"api" json:"api"`
}
type aria2Config struct {
@@ -42,6 +43,14 @@ type aria2Config struct {
KeepFile bool `toml:"keep_file" mapstructure:"keep_file" json:"keep_file"`
}
type apiConfig struct {
Enable bool `toml:"enable" mapstructure:"enable" json:"enable"`
Port int `toml:"port" mapstructure:"port" json:"port"`
Token string `toml:"token" mapstructure:"token" json:"token"`
WebhookURL string `toml:"webhook_url" mapstructure:"webhook_url" json:"webhook_url"`
TrustedIPs []string `toml:"trusted_ips" mapstructure:"trusted_ips" json:"trusted_ips"`
}
var cfg = &Config{}
func C() Config {
@@ -115,6 +124,12 @@ func Init(ctx context.Context, configFile ...string) error {
// 数据库
"db.path": "data/saveany.db",
"db.session": "data/session.db",
// API
"api.enable": false,
"api.port": 8080,
"api.token": "",
"api.webhook_url": "",
}
for key, value := range defaultConfigs {