Address PR feedback: remove redundant files, format code, add progress tracking

Co-authored-by: krau <71133316+krau@users.noreply.github.com>
This commit is contained in:
copilot-swe-agent[bot]
2026-01-19 05:27:12 +00:00
parent 127901fd24
commit 20a5e317ae
6 changed files with 63 additions and 174 deletions

View File

@@ -1,26 +0,0 @@
# API Module
This module provides a RESTful HTTP API for programmatic file downloads from Telegram.
## Features
- **RESTful API endpoints** for creating, querying, and canceling download tasks
- **Bearer token authentication** for API access control
- **IP whitelist** support for additional security
- **Webhook callbacks** for task completion notifications
- **Task status tracking** (queued, running, completed, failed, canceled)
- **Graceful shutdown** with proper cleanup
## Usage
See the full documentation at:
- English: `/docs/content/en/usage/api.md`
- Chinese: `/docs/content/zh/usage/api.md`
## Architecture
- `server.go` - HTTP server initialization and route registration
- `handlers.go` - API endpoint handlers and business logic
- `middleware.go` - Authentication and logging middleware
The API integrates with the existing task queue system and uses the bot's Telegram client to fetch messages.

View File

@@ -24,10 +24,10 @@ import (
// Request/Response types
type CreateTaskRequest struct {
TelegramURL string `json:"telegram_url"`
StorageName string `json:"storage_name,omitempty"`
DirPath string `json:"dir_path,omitempty"`
UserID int64 `json:"user_id"`
TelegramURL string `json:"telegram_url"`
StorageName string `json:"storage_name,omitempty"`
DirPath string `json:"dir_path,omitempty"`
UserID int64 `json:"user_id"`
}
type CreateTaskResponse struct {
@@ -36,11 +36,14 @@ type CreateTaskResponse struct {
}
type TaskStatusResponse struct {
TaskID string `json:"task_id"`
Status string `json:"status"` // queued, running, completed, failed, canceled
Title string `json:"title"`
CreatedAt time.Time `json:"created_at"`
Error string `json:"error,omitempty"`
TaskID string `json:"task_id"`
Status string `json:"status"` // queued, running, completed, failed, canceled
Title string `json:"title"`
CreatedAt time.Time `json:"created_at"`
Error string `json:"error,omitempty"`
Downloaded int64 `json:"downloaded,omitempty"` // Bytes downloaded
Total int64 `json:"total,omitempty"` // Total bytes
ProgressPct float64 `json:"progress_pct,omitempty"` // Progress percentage (0-100)
}
type ListTasksResponse struct {
@@ -64,11 +67,14 @@ var (
)
type taskStatus struct {
ID string
Status string
Title string
CreatedAt time.Time
Error string
ID string
Status string
Title string
CreatedAt time.Time
Error string
Downloaded int64
Total int64
ProgressPct float64
}
func handleHealth(w http.ResponseWriter, r *http.Request) {
@@ -212,11 +218,14 @@ func handleGetTask(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(TaskStatusResponse{
TaskID: status.ID,
Status: status.Status,
Title: status.Title,
CreatedAt: status.CreatedAt,
Error: status.Error,
TaskID: status.ID,
Status: status.Status,
Title: status.Title,
CreatedAt: status.CreatedAt,
Error: status.Error,
Downloaded: status.Downloaded,
Total: status.Total,
ProgressPct: status.ProgressPct,
})
}
@@ -301,7 +310,15 @@ func (a *apiProgressTracker) OnStart(ctx context.Context, info tftask.TaskInfo)
}
func (a *apiProgressTracker) OnProgress(ctx context.Context, info tftask.TaskInfo, downloaded int64, total int64) {
// No-op for API tasks
taskStatusesMu.Lock()
defer taskStatusesMu.Unlock()
if ts, exists := taskStatuses[a.taskID]; exists {
ts.Downloaded = downloaded
ts.Total = total
if total > 0 {
ts.ProgressPct = float64(downloaded) / float64(total) * 100.0
}
}
}
func (a *apiProgressTracker) OnDone(ctx context.Context, info tftask.TaskInfo, err error) {
@@ -332,11 +349,14 @@ func sendWebhook(taskID, status, errorMsg string) {
logger := log.WithPrefix("webhook")
payload := TaskStatusResponse{
TaskID: ts.ID,
Status: status,
Title: ts.Title,
CreatedAt: ts.CreatedAt,
Error: errorMsg,
TaskID: ts.ID,
Status: status,
Title: ts.Title,
CreatedAt: ts.CreatedAt,
Error: errorMsg,
Downloaded: ts.Downloaded,
Total: ts.Total,
ProgressPct: ts.ProgressPct,
}
body, err := json.Marshal(payload)

View File

@@ -54,12 +54,12 @@ func loggingMiddleware(logger *log.Logger) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
start := time.Now()
// Wrap response writer to capture status code
wrapper := &responseWriter{ResponseWriter: w, statusCode: http.StatusOK}
next.ServeHTTP(wrapper, r)
logger.Infof("%s %s %d %s", r.Method, r.URL.Path, wrapper.statusCode, time.Since(start))
})
}

View File

@@ -61,7 +61,7 @@ func TestIsIPAllowed(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
result := isIPAllowed(tt.clientIP, tt.allowedIPs)
if result != tt.expected {
t.Errorf("isIPAllowed(%q, %v) = %v, want %v",
t.Errorf("isIPAllowed(%q, %v) = %v, want %v",
tt.clientIP, tt.allowedIPs, result, tt.expected)
}
})
@@ -110,11 +110,11 @@ func TestAuthMiddleware_HealthCheck(t *testing.T) {
func TestGetClientIP(t *testing.T) {
tests := []struct {
name string
remoteAddr string
xForwardedFor string
xRealIP string
expectedIP string
name string
remoteAddr string
xForwardedFor string
xRealIP string
expectedIP string
}{
{
name: "RemoteAddr only",

View File

@@ -20,15 +20,15 @@ func Init(ctx context.Context) error {
}
logger := log.FromContext(ctx).WithPrefix("api")
mux := http.NewServeMux()
// Register API routes
registerRoutes(mux)
// Wrap with middleware
handler := loggingMiddleware(logger)(authMiddleware(mux))
server = &http.Server{
Addr: fmt.Sprintf(":%d", cfg.API.Port),
Handler: handler,
@@ -36,34 +36,34 @@ func Init(ctx context.Context) error {
WriteTimeout: 15 * time.Second,
IdleTimeout: 60 * time.Second,
}
go func() {
logger.Infof("Starting API server on port %d", cfg.API.Port)
if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
logger.Errorf("API server error: %v", err)
}
}()
// Graceful shutdown on context cancellation
go func() {
<-ctx.Done()
shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := server.Shutdown(shutdownCtx); err != nil {
logger.Errorf("Failed to shutdown API server: %v", err)
} else {
logger.Info("API server stopped")
}
}()
return nil
}
func registerRoutes(mux *http.ServeMux) {
// Health check endpoint (no auth required)
mux.HandleFunc("/health", handleHealth)
// API v1 endpoints
mux.HandleFunc("POST /api/v1/tasks", handleCreateTask)
mux.HandleFunc("GET /api/v1/tasks/{id}", handleGetTask)