package api import ( "bytes" "context" "encoding/json" "fmt" "io" "math" "net/http" "net/url" "path" "strings" "sync" "sync/atomic" "time" "github.com/celestix/gotgproto/ext" "github.com/charmbracelet/log" "github.com/gotd/td/tg" "github.com/krau/SaveAny-Bot/client/bot" "github.com/krau/SaveAny-Bot/client/bot/handlers/utils/ruleutil" "github.com/krau/SaveAny-Bot/common/utils/tgutil" "github.com/krau/SaveAny-Bot/config" "github.com/krau/SaveAny-Bot/core" tftask "github.com/krau/SaveAny-Bot/core/tasks/tfile" "github.com/krau/SaveAny-Bot/database" "github.com/krau/SaveAny-Bot/pkg/queue" "github.com/krau/SaveAny-Bot/pkg/tfile" "github.com/krau/SaveAny-Bot/storage" "github.com/rs/xid" ) // Request/Response types type CreateTaskRequest struct { TelegramURL string `json:"telegram_url"` StorageName string `json:"storage_name,omitempty"` DirPath string `json:"dir_path,omitempty"` UserID int64 `json:"user_id"` } type CreateTaskResponse struct { TaskID string `json:"task_id"` Message string `json:"message"` } type TaskStatusResponse struct { TaskID string `json:"task_id"` Status string `json:"status"` // queued, running, completed, failed, canceled Title string `json:"title"` CreatedAt time.Time `json:"created_at"` Error string `json:"error,omitempty"` Downloaded int64 `json:"downloaded,omitempty"` // Bytes downloaded Total int64 `json:"total,omitempty"` // Total bytes ProgressPct float64 `json:"progress_pct,omitempty"` // Progress percentage (0-100) } type ListTasksResponse struct { Queued []TaskInfo `json:"queued"` Running []TaskInfo `json:"running"` } type TaskInfo struct { ID string `json:"id"` Title string `json:"title"` } type ErrorResponse struct { Error string `json:"error"` } // Task tracking var ( taskStatuses = make(map[string]*taskStatus) taskStatusesMu sync.RWMutex ) type taskStatus struct { ID string Status string Title string CreatedAt time.Time Error string Downloaded atomic.Int64 // Use atomic for lock-free updates Total atomic.Int64 // Use atomic for lock-free updates ProgressPct uint64 // Store as uint64 bits of float64 for atomic access } func handleHealth(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) json.NewEncoder(w).Encode(map[string]string{"status": "ok"}) } func handleCreateTask(w http.ResponseWriter, r *http.Request) { var req CreateTaskRequest if err := json.NewDecoder(r.Body).Decode(&req); err != nil { respondError(w, "invalid request body", http.StatusBadRequest) return } // Validate request if req.TelegramURL == "" { respondError(w, "telegram_url is required", http.StatusBadRequest) return } if req.UserID <= 0 { respondError(w, "user_id is required and must be positive", http.StatusBadRequest) return } logger := log.FromContext(r.Context()).WithPrefix("api") // Get user from database userDB, err := database.GetUserByChatID(r.Context(), req.UserID) if err != nil { logger.Errorf("Failed to get user: %v", err) respondError(w, "user not found", http.StatusBadRequest) return } // Get storage var stor storage.Storage if req.StorageName != "" { stor, err = storage.GetStorageByUserIDAndName(r.Context(), req.UserID, req.StorageName) if err != nil { logger.Errorf("Failed to get storage: %v", err) respondError(w, "storage not found", http.StatusBadRequest) return } } else { // Use first available storage for the user storages := storage.GetUserStorages(r.Context(), req.UserID) if len(storages) == 0 { respondError(w, "no storage available for user", http.StatusBadRequest) return } stor = storages[0] } // Parse Telegram URL botCtx := bot.ExtContext() if botCtx == nil { respondError(w, "bot not initialized", http.StatusInternalServerError) return } linkUrl, err := url.Parse(req.TelegramURL) if err != nil { logger.Errorf("Failed to parse URL: %v", err) respondError(w, "invalid telegram URL format", http.StatusBadRequest) return } chatID, msgID, err := tgutil.ParseMessageLink(botCtx, req.TelegramURL) if err != nil { logger.Errorf("Failed to parse Telegram URL: %v", err) respondError(w, "invalid telegram URL format", http.StatusBadRequest) return } // Get message from Telegram msg, err := tgutil.GetMessageByID(botCtx, chatID, msgID) if err != nil { logger.Errorf("Failed to get message: %v", err) respondError(w, "failed to retrieve message", http.StatusBadRequest) return } // Check if message has media media, ok := msg.GetMedia() if !ok { respondError(w, "message has no media", http.StatusBadRequest) return } // Collect files - handle both single and grouped messages files := make([]tfile.TGFileMessage, 0) // Check for grouped messages (media group) groupID, isGroup := msg.GetGroupedID() if isGroup && groupID != 0 && !linkUrl.Query().Has("single") { // Handle media group gmsgs, err := tgutil.GetGroupedMessages(botCtx, chatID, msg) if err != nil { logger.Errorf("Failed to get grouped messages: %v", err) // Fall back to single message file, err := createTGFileWithMedia(botCtx, msg, media, userDB) if err != nil { logger.Errorf("Failed to create TGFile: %v", err) respondError(w, "invalid message format", http.StatusBadRequest) return } files = append(files, file) } else { // Process all messages in the group for _, gmsg := range gmsgs { if gmsg.Media == nil { continue } gMedia, ok := gmsg.GetMedia() if !ok { continue } file, err := createTGFileWithMedia(botCtx, gmsg, gMedia, userDB) if err != nil { logger.Warnf("Failed to create TGFile for grouped message: %v", err) continue } files = append(files, file) } } } else { // Single message file, err := createTGFileWithMedia(botCtx, msg, media, userDB) if err != nil { logger.Errorf("Failed to create TGFile: %v", err) respondError(w, "invalid message format", http.StatusBadRequest) return } files = append(files, file) } if len(files) == 0 { respondError(w, "no savable files found", http.StatusBadRequest) return } // Create tasks for all files with proper album handling taskIDs := make([]string, 0, len(files)) baseDirPath := req.DirPath if baseDirPath == "" { baseDirPath = "/" } // Create context with bot extension injectCtx := tgutil.ExtWithContext(r.Context(), botCtx) // Apply storage rules if enabled for the user useRule := userDB.ApplyRule && userDB.Rules != nil // Helper to apply rules to a file applyRule := func(file tfile.TGFileMessage) (storage.Storage, ruleutil.MatchedDirPath) { fileStor := stor dirPath := ruleutil.MatchedDirPath(baseDirPath) if useRule { matched, matchedStorName, matchedDirPath := ruleutil.ApplyRule(injectCtx, userDB.Rules, ruleutil.NewInput(file)) if matched { // Rule matched, apply overrides if matchedDirPath != "" { dirPath = matchedDirPath } if matchedStorName.Usable() { var err error fileStor, err = storage.GetStorageByUserIDAndName(injectCtx, userDB.ChatID, matchedStorName.String()) if err != nil { logger.Errorf("Failed to get storage from rule: %v", err) // Fall back to original storage fileStor = stor } } } } return fileStor, dirPath } // Separate files into regular and album files type albumFile struct { file tfile.TGFileMessage storage storage.Storage dirPath ruleutil.MatchedDirPath } albumFiles := make(map[int64][]albumFile) for _, tgFile := range files { fileStor, dirPath := applyRule(tgFile) // Check if this needs album handling (NEW-FOR-ALBUM rule) if dirPath.NeedNewForAlbum() { groupId, isGroup := tgFile.Message().GetGroupedID() if !isGroup || groupId == 0 { logger.Warnf("File %s has NEW-FOR-ALBUM rule but is not in a group, treating as regular file", tgFile.Name()) // Treat as regular file with base dir path storagePath := fileStor.JoinStoragePath(path.Join(baseDirPath, tgFile.Name())) taskID := xid.New().String() task, err := tftask.NewTGFileTask(taskID, injectCtx, tgFile, fileStor, storagePath, &apiProgressTracker{ taskID: taskID, }) if err != nil { logger.Errorf("Failed to create task: %v", err) respondError(w, "failed to create task", http.StatusInternalServerError) return } trackTask(taskID, task.Title(), "queued") if err := core.AddTask(injectCtx, task); err != nil { logger.Errorf("Failed to add task: %v", err) updateTaskStatus(taskID, "failed", err.Error()) respondError(w, "failed to add task to queue", http.StatusInternalServerError) return } taskIDs = append(taskIDs, taskID) continue } // Group by album ID if _, ok := albumFiles[groupId]; !ok { albumFiles[groupId] = make([]albumFile, 0) } albumFiles[groupId] = append(albumFiles[groupId], albumFile{ file: tgFile, storage: fileStor, dirPath: dirPath, }) } else { // Regular file - create task immediately storagePath := fileStor.JoinStoragePath(path.Join(dirPath.String(), tgFile.Name())) taskID := xid.New().String() task, err := tftask.NewTGFileTask(taskID, injectCtx, tgFile, fileStor, storagePath, &apiProgressTracker{ taskID: taskID, }) if err != nil { logger.Errorf("Failed to create task: %v", err) respondError(w, "failed to create task", http.StatusInternalServerError) return } trackTask(taskID, task.Title(), "queued") if err := core.AddTask(injectCtx, task); err != nil { logger.Errorf("Failed to add task: %v", err) updateTaskStatus(taskID, "failed", err.Error()) respondError(w, "failed to add task to queue", http.StatusInternalServerError) return } taskIDs = append(taskIDs, taskID) } } // Handle album files - group them into subdirectories for _, afiles := range albumFiles { if len(afiles) <= 1 { // Single file in album, treat as regular for _, af := range afiles { storagePath := af.storage.JoinStoragePath(path.Join(baseDirPath, af.file.Name())) taskID := xid.New().String() task, err := tftask.NewTGFileTask(taskID, injectCtx, af.file, af.storage, storagePath, &apiProgressTracker{ taskID: taskID, }) if err != nil { logger.Errorf("Failed to create task: %v", err) respondError(w, "failed to create task", http.StatusInternalServerError) return } trackTask(taskID, task.Title(), "queued") if err := core.AddTask(injectCtx, task); err != nil { logger.Errorf("Failed to add task: %v", err) updateTaskStatus(taskID, "failed", err.Error()) respondError(w, "failed to add task to queue", http.StatusInternalServerError) return } taskIDs = append(taskIDs, taskID) } continue } // Multiple files in album - create subdirectory named after first file // Remove extension from first file's name to use as directory name albumDir := strings.TrimSuffix(path.Base(afiles[0].file.Name()), path.Ext(afiles[0].file.Name())) albumStor := afiles[0].storage for _, af := range afiles { // All files go into the album subdirectory storagePath := albumStor.JoinStoragePath(path.Join(baseDirPath, albumDir, af.file.Name())) taskID := xid.New().String() task, err := tftask.NewTGFileTask(taskID, injectCtx, af.file, albumStor, storagePath, &apiProgressTracker{ taskID: taskID, }) if err != nil { logger.Errorf("Failed to create task for album file: %v", err) respondError(w, "failed to create task", http.StatusInternalServerError) return } trackTask(taskID, task.Title(), "queued") if err := core.AddTask(injectCtx, task); err != nil { logger.Errorf("Failed to add task: %v", err) updateTaskStatus(taskID, "failed", err.Error()) respondError(w, "failed to add task to queue", http.StatusInternalServerError) return } taskIDs = append(taskIDs, taskID) } } // Send success response w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusCreated) // Return first task ID for single file, or all task IDs for media group if len(taskIDs) == 1 { json.NewEncoder(w).Encode(CreateTaskResponse{ TaskID: taskIDs[0], Message: "task created successfully", }) } else { json.NewEncoder(w).Encode(map[string]interface{}{ "task_ids": taskIDs, "message": fmt.Sprintf("%d tasks created successfully", len(taskIDs)), }) } } // createTGFileWithMedia creates a TGFile with proper filename handling using user's strategy func createTGFileWithMedia(botCtx *ext.Context, msg *tg.Message, media tg.MessageMediaClass, userDB *database.User) (tfile.TGFileMessage, error) { // Use the same filename generation logic as bot handlers opts := []tfile.TGFileOption{tfile.WithNameIfEmpty(tgutil.GenFileNameFromMessage(*msg))} return tfile.FromMediaMessage(media, botCtx.Raw, msg, opts...) } func handleGetTask(w http.ResponseWriter, r *http.Request) { taskID := r.PathValue("id") if taskID == "" { respondError(w, "task_id is required", http.StatusBadRequest) return } taskStatusesMu.RLock() status, exists := taskStatuses[taskID] taskStatusesMu.RUnlock() if !exists { respondError(w, "task not found", http.StatusNotFound) return } w.Header().Set("Content-Type", "application/json") json.NewEncoder(w).Encode(TaskStatusResponse{ TaskID: status.ID, Status: status.Status, Title: status.Title, CreatedAt: status.CreatedAt, Error: status.Error, Downloaded: status.Downloaded.Load(), Total: status.Total.Load(), ProgressPct: math.Float64frombits(atomic.LoadUint64((*uint64)(&status.ProgressPct))), }) } func handleListTasks(w http.ResponseWriter, r *http.Request) { queued := core.GetQueuedTasks(r.Context()) running := core.GetRunningTasks(r.Context()) response := ListTasksResponse{ Queued: convertTaskInfos(queued), Running: convertTaskInfos(running), } w.Header().Set("Content-Type", "application/json") json.NewEncoder(w).Encode(response) } func handleCancelTask(w http.ResponseWriter, r *http.Request) { taskID := r.PathValue("id") if taskID == "" { respondError(w, "task_id is required", http.StatusBadRequest) return } if err := core.CancelTask(r.Context(), taskID); err != nil { log.FromContext(r.Context()).Errorf("Failed to cancel task %s: %v", taskID, err) respondError(w, "failed to cancel task", http.StatusInternalServerError) return } updateTaskStatus(taskID, "canceled", "") w.Header().Set("Content-Type", "application/json") json.NewEncoder(w).Encode(map[string]string{"message": "task canceled"}) } // Helper functions func respondError(w http.ResponseWriter, message string, statusCode int) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(statusCode) json.NewEncoder(w).Encode(ErrorResponse{Error: message}) } func trackTask(taskID, title, status string) { taskStatusesMu.Lock() defer taskStatusesMu.Unlock() taskStatuses[taskID] = &taskStatus{ ID: taskID, Status: status, Title: title, CreatedAt: time.Now(), } } func updateTaskStatus(taskID, status, errorMsg string) { taskStatusesMu.Lock() defer taskStatusesMu.Unlock() if ts, exists := taskStatuses[taskID]; exists { ts.Status = status ts.Error = errorMsg } } func convertTaskInfos(tasks []queue.TaskInfo) []TaskInfo { result := make([]TaskInfo, len(tasks)) for i, t := range tasks { result[i] = TaskInfo{ ID: t.ID, Title: t.Title, } } return result } // apiProgressTracker implements tftask.ProgressTracker for API tasks type apiProgressTracker struct { taskID string } func (a *apiProgressTracker) OnStart(ctx context.Context, info tftask.TaskInfo) { updateTaskStatus(a.taskID, "running", "") } func (a *apiProgressTracker) OnProgress(ctx context.Context, info tftask.TaskInfo, downloaded int64, total int64) { // Use atomic operations to avoid mutex locks for better performance // OnProgress is called very frequently during downloads taskStatusesMu.RLock() ts, exists := taskStatuses[a.taskID] taskStatusesMu.RUnlock() if exists { ts.Downloaded.Store(downloaded) ts.Total.Store(total) if total > 0 { progressPct := float64(downloaded) / float64(total) * 100.0 atomic.StoreUint64((*uint64)(&ts.ProgressPct), math.Float64bits(progressPct)) } } } func (a *apiProgressTracker) OnDone(ctx context.Context, info tftask.TaskInfo, err error) { if err != nil { updateTaskStatus(a.taskID, "failed", err.Error()) sendWebhook(a.taskID, "failed", err.Error()) } else { updateTaskStatus(a.taskID, "completed", "") sendWebhook(a.taskID, "completed", "") } } // sendWebhook sends a callback to the configured webhook URL func sendWebhook(taskID, status, errorMsg string) { cfg := config.C() if cfg.API.WebhookURL == "" { return } taskStatusesMu.RLock() ts, exists := taskStatuses[taskID] taskStatusesMu.RUnlock() if !exists { return } logger := log.WithPrefix("webhook") payload := TaskStatusResponse{ TaskID: ts.ID, Status: status, Title: ts.Title, CreatedAt: ts.CreatedAt, Error: errorMsg, Downloaded: ts.Downloaded.Load(), Total: ts.Total.Load(), ProgressPct: math.Float64frombits(atomic.LoadUint64((*uint64)(&ts.ProgressPct))), } body, err := json.Marshal(payload) if err != nil { logger.Errorf("Failed to marshal webhook payload: %v", err) return } ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() req, err := http.NewRequestWithContext(ctx, "POST", cfg.API.WebhookURL, bytes.NewReader(body)) if err != nil { logger.Errorf("Failed to create webhook request: %v", err) return } req.Header.Set("Content-Type", "application/json") if cfg.API.Token != "" { req.Header.Set("Authorization", "Bearer "+cfg.API.Token) } resp, err := http.DefaultClient.Do(req) if err != nil { logger.Errorf("Failed to send webhook: %v", err) return } defer resp.Body.Close() if resp.StatusCode >= 400 { body, err := io.ReadAll(resp.Body) if err != nil { logger.Errorf("Webhook returned error status %d, failed to read response body: %v", resp.StatusCode, err) } else { logger.Errorf("Webhook returned error status %d: %s", resp.StatusCode, string(body)) } } }