mirror of
https://github.com/krau/SaveAny-Bot.git
synced 2026-06-30 03:31:24 +08:00
Address PR feedback: remove redundant files, format code, add progress tracking
Co-authored-by: krau <71133316+krau@users.noreply.github.com>
This commit is contained in:
@@ -1,26 +0,0 @@
|
||||
# API Module
|
||||
|
||||
This module provides a RESTful HTTP API for programmatic file downloads from Telegram.
|
||||
|
||||
## Features
|
||||
|
||||
- **RESTful API endpoints** for creating, querying, and canceling download tasks
|
||||
- **Bearer token authentication** for API access control
|
||||
- **IP whitelist** support for additional security
|
||||
- **Webhook callbacks** for task completion notifications
|
||||
- **Task status tracking** (queued, running, completed, failed, canceled)
|
||||
- **Graceful shutdown** with proper cleanup
|
||||
|
||||
## Usage
|
||||
|
||||
See the full documentation at:
|
||||
- English: `/docs/content/en/usage/api.md`
|
||||
- Chinese: `/docs/content/zh/usage/api.md`
|
||||
|
||||
## Architecture
|
||||
|
||||
- `server.go` - HTTP server initialization and route registration
|
||||
- `handlers.go` - API endpoint handlers and business logic
|
||||
- `middleware.go` - Authentication and logging middleware
|
||||
|
||||
The API integrates with the existing task queue system and uses the bot's Telegram client to fetch messages.
|
||||
@@ -24,10 +24,10 @@ import (
|
||||
|
||||
// Request/Response types
|
||||
type CreateTaskRequest struct {
|
||||
TelegramURL string `json:"telegram_url"`
|
||||
StorageName string `json:"storage_name,omitempty"`
|
||||
DirPath string `json:"dir_path,omitempty"`
|
||||
UserID int64 `json:"user_id"`
|
||||
TelegramURL string `json:"telegram_url"`
|
||||
StorageName string `json:"storage_name,omitempty"`
|
||||
DirPath string `json:"dir_path,omitempty"`
|
||||
UserID int64 `json:"user_id"`
|
||||
}
|
||||
|
||||
type CreateTaskResponse struct {
|
||||
@@ -36,11 +36,14 @@ type CreateTaskResponse struct {
|
||||
}
|
||||
|
||||
type TaskStatusResponse struct {
|
||||
TaskID string `json:"task_id"`
|
||||
Status string `json:"status"` // queued, running, completed, failed, canceled
|
||||
Title string `json:"title"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
Error string `json:"error,omitempty"`
|
||||
TaskID string `json:"task_id"`
|
||||
Status string `json:"status"` // queued, running, completed, failed, canceled
|
||||
Title string `json:"title"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
Error string `json:"error,omitempty"`
|
||||
Downloaded int64 `json:"downloaded,omitempty"` // Bytes downloaded
|
||||
Total int64 `json:"total,omitempty"` // Total bytes
|
||||
ProgressPct float64 `json:"progress_pct,omitempty"` // Progress percentage (0-100)
|
||||
}
|
||||
|
||||
type ListTasksResponse struct {
|
||||
@@ -64,11 +67,14 @@ var (
|
||||
)
|
||||
|
||||
type taskStatus struct {
|
||||
ID string
|
||||
Status string
|
||||
Title string
|
||||
CreatedAt time.Time
|
||||
Error string
|
||||
ID string
|
||||
Status string
|
||||
Title string
|
||||
CreatedAt time.Time
|
||||
Error string
|
||||
Downloaded int64
|
||||
Total int64
|
||||
ProgressPct float64
|
||||
}
|
||||
|
||||
func handleHealth(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -212,11 +218,14 @@ func handleGetTask(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(TaskStatusResponse{
|
||||
TaskID: status.ID,
|
||||
Status: status.Status,
|
||||
Title: status.Title,
|
||||
CreatedAt: status.CreatedAt,
|
||||
Error: status.Error,
|
||||
TaskID: status.ID,
|
||||
Status: status.Status,
|
||||
Title: status.Title,
|
||||
CreatedAt: status.CreatedAt,
|
||||
Error: status.Error,
|
||||
Downloaded: status.Downloaded,
|
||||
Total: status.Total,
|
||||
ProgressPct: status.ProgressPct,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -301,7 +310,15 @@ func (a *apiProgressTracker) OnStart(ctx context.Context, info tftask.TaskInfo)
|
||||
}
|
||||
|
||||
func (a *apiProgressTracker) OnProgress(ctx context.Context, info tftask.TaskInfo, downloaded int64, total int64) {
|
||||
// No-op for API tasks
|
||||
taskStatusesMu.Lock()
|
||||
defer taskStatusesMu.Unlock()
|
||||
if ts, exists := taskStatuses[a.taskID]; exists {
|
||||
ts.Downloaded = downloaded
|
||||
ts.Total = total
|
||||
if total > 0 {
|
||||
ts.ProgressPct = float64(downloaded) / float64(total) * 100.0
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (a *apiProgressTracker) OnDone(ctx context.Context, info tftask.TaskInfo, err error) {
|
||||
@@ -332,11 +349,14 @@ func sendWebhook(taskID, status, errorMsg string) {
|
||||
logger := log.WithPrefix("webhook")
|
||||
|
||||
payload := TaskStatusResponse{
|
||||
TaskID: ts.ID,
|
||||
Status: status,
|
||||
Title: ts.Title,
|
||||
CreatedAt: ts.CreatedAt,
|
||||
Error: errorMsg,
|
||||
TaskID: ts.ID,
|
||||
Status: status,
|
||||
Title: ts.Title,
|
||||
CreatedAt: ts.CreatedAt,
|
||||
Error: errorMsg,
|
||||
Downloaded: ts.Downloaded,
|
||||
Total: ts.Total,
|
||||
ProgressPct: ts.ProgressPct,
|
||||
}
|
||||
|
||||
body, err := json.Marshal(payload)
|
||||
|
||||
@@ -54,12 +54,12 @@ func loggingMiddleware(logger *log.Logger) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
start := time.Now()
|
||||
|
||||
|
||||
// Wrap response writer to capture status code
|
||||
wrapper := &responseWriter{ResponseWriter: w, statusCode: http.StatusOK}
|
||||
|
||||
|
||||
next.ServeHTTP(wrapper, r)
|
||||
|
||||
|
||||
logger.Infof("%s %s %d %s", r.Method, r.URL.Path, wrapper.statusCode, time.Since(start))
|
||||
})
|
||||
}
|
||||
|
||||
@@ -61,7 +61,7 @@ func TestIsIPAllowed(t *testing.T) {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := isIPAllowed(tt.clientIP, tt.allowedIPs)
|
||||
if result != tt.expected {
|
||||
t.Errorf("isIPAllowed(%q, %v) = %v, want %v",
|
||||
t.Errorf("isIPAllowed(%q, %v) = %v, want %v",
|
||||
tt.clientIP, tt.allowedIPs, result, tt.expected)
|
||||
}
|
||||
})
|
||||
@@ -110,11 +110,11 @@ func TestAuthMiddleware_HealthCheck(t *testing.T) {
|
||||
|
||||
func TestGetClientIP(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
remoteAddr string
|
||||
xForwardedFor string
|
||||
xRealIP string
|
||||
expectedIP string
|
||||
name string
|
||||
remoteAddr string
|
||||
xForwardedFor string
|
||||
xRealIP string
|
||||
expectedIP string
|
||||
}{
|
||||
{
|
||||
name: "RemoteAddr only",
|
||||
|
||||
@@ -20,15 +20,15 @@ func Init(ctx context.Context) error {
|
||||
}
|
||||
|
||||
logger := log.FromContext(ctx).WithPrefix("api")
|
||||
|
||||
|
||||
mux := http.NewServeMux()
|
||||
|
||||
|
||||
// Register API routes
|
||||
registerRoutes(mux)
|
||||
|
||||
|
||||
// Wrap with middleware
|
||||
handler := loggingMiddleware(logger)(authMiddleware(mux))
|
||||
|
||||
|
||||
server = &http.Server{
|
||||
Addr: fmt.Sprintf(":%d", cfg.API.Port),
|
||||
Handler: handler,
|
||||
@@ -36,34 +36,34 @@ func Init(ctx context.Context) error {
|
||||
WriteTimeout: 15 * time.Second,
|
||||
IdleTimeout: 60 * time.Second,
|
||||
}
|
||||
|
||||
|
||||
go func() {
|
||||
logger.Infof("Starting API server on port %d", cfg.API.Port)
|
||||
if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||
logger.Errorf("API server error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
|
||||
// Graceful shutdown on context cancellation
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
|
||||
if err := server.Shutdown(shutdownCtx); err != nil {
|
||||
logger.Errorf("Failed to shutdown API server: %v", err)
|
||||
} else {
|
||||
logger.Info("API server stopped")
|
||||
}
|
||||
}()
|
||||
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func registerRoutes(mux *http.ServeMux) {
|
||||
// Health check endpoint (no auth required)
|
||||
mux.HandleFunc("/health", handleHealth)
|
||||
|
||||
|
||||
// API v1 endpoints
|
||||
mux.HandleFunc("POST /api/v1/tasks", handleCreateTask)
|
||||
mux.HandleFunc("GET /api/v1/tasks/{id}", handleGetTask)
|
||||
|
||||
Reference in New Issue
Block a user