From 6896bdc852d179d674e2e0dfd7c152b771ca53f3 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Mon, 19 Jan 2026 06:14:48 +0000 Subject: [PATCH] Fix performance issues, add media group support, and improve filename handling Co-authored-by: krau <71133316+krau@users.noreply.github.com> --- api/handlers.go | 181 ++++++++++++++++++++++++++++++++++++------------ 1 file changed, 136 insertions(+), 45 deletions(-) diff --git a/api/handlers.go b/api/handlers.go index f1daef6..0e7bdfb 100644 --- a/api/handlers.go +++ b/api/handlers.go @@ -4,18 +4,25 @@ import ( "bytes" "context" "encoding/json" + "fmt" "io" + "math" "net/http" + "net/url" "path" "sync" + "sync/atomic" "time" + "github.com/celestix/gotgproto/ext" "github.com/charmbracelet/log" + "github.com/gotd/td/tg" "github.com/krau/SaveAny-Bot/client/bot" "github.com/krau/SaveAny-Bot/common/utils/tgutil" "github.com/krau/SaveAny-Bot/config" "github.com/krau/SaveAny-Bot/core" tftask "github.com/krau/SaveAny-Bot/core/tasks/tfile" + "github.com/krau/SaveAny-Bot/database" "github.com/krau/SaveAny-Bot/pkg/queue" "github.com/krau/SaveAny-Bot/pkg/tfile" "github.com/krau/SaveAny-Bot/storage" @@ -72,9 +79,9 @@ type taskStatus struct { Title string CreatedAt time.Time Error string - Downloaded int64 - Total int64 - ProgressPct float64 + Downloaded atomic.Int64 // Use atomic for lock-free updates + Total atomic.Int64 // Use atomic for lock-free updates + ProgressPct uint64 // Store as uint64 bits of float64 for atomic access } func handleHealth(w http.ResponseWriter, r *http.Request) { @@ -102,9 +109,16 @@ func handleCreateTask(w http.ResponseWriter, r *http.Request) { logger := log.FromContext(r.Context()).WithPrefix("api") + // Get user from database + userDB, err := database.GetUserByChatID(r.Context(), req.UserID) + if err != nil { + logger.Errorf("Failed to get user: %v", err) + respondError(w, "user not found", http.StatusBadRequest) + return + } + // Get storage var stor storage.Storage - var err error if req.StorageName != "" { stor, err = storage.GetStorageByUserIDAndName(r.Context(), req.UserID, req.StorageName) if err != nil { @@ -129,6 +143,13 @@ func handleCreateTask(w http.ResponseWriter, r *http.Request) { return } + linkUrl, err := url.Parse(req.TelegramURL) + if err != nil { + logger.Errorf("Failed to parse URL: %v", err) + respondError(w, "invalid telegram URL format", http.StatusBadRequest) + return + } + chatID, msgID, err := tgutil.ParseMessageLink(botCtx, req.TelegramURL) if err != nil { logger.Errorf("Failed to parse Telegram URL: %v", err) @@ -151,53 +172,118 @@ func handleCreateTask(w http.ResponseWriter, r *http.Request) { return } - // Create TGFile from message media - tgFile, err := tfile.FromMediaMessage(media, botCtx.Raw, msg) - if err != nil { - logger.Errorf("Failed to create TGFile: %v", err) - respondError(w, "invalid message format", http.StatusBadRequest) + // Collect files - handle both single and grouped messages + files := make([]tfile.TGFileMessage, 0) + + // Check for grouped messages (media group) + groupID, isGroup := msg.GetGroupedID() + if isGroup && groupID != 0 && !linkUrl.Query().Has("single") { + // Handle media group + gmsgs, err := tgutil.GetGroupedMessages(botCtx, chatID, msg) + if err != nil { + logger.Errorf("Failed to get grouped messages: %v", err) + // Fall back to single message + file, err := createTGFileWithMedia(botCtx, msg, media, userDB) + if err != nil { + logger.Errorf("Failed to create TGFile: %v", err) + respondError(w, "invalid message format", http.StatusBadRequest) + return + } + files = append(files, file) + } else { + // Process all messages in the group + for _, gmsg := range gmsgs { + if gmsg.Media == nil { + continue + } + gMedia, ok := gmsg.GetMedia() + if !ok { + continue + } + file, err := createTGFileWithMedia(botCtx, gmsg, gMedia, userDB) + if err != nil { + logger.Warnf("Failed to create TGFile for grouped message: %v", err) + continue + } + files = append(files, file) + } + } + } else { + // Single message + file, err := createTGFileWithMedia(botCtx, msg, media, userDB) + if err != nil { + logger.Errorf("Failed to create TGFile: %v", err) + respondError(w, "invalid message format", http.StatusBadRequest) + return + } + files = append(files, file) + } + + if len(files) == 0 { + respondError(w, "no savable files found", http.StatusBadRequest) return } - // Create task + // Create tasks for all files + taskIDs := make([]string, 0, len(files)) dirPath := req.DirPath if dirPath == "" { dirPath = "/" } - storagePath := stor.JoinStoragePath(path.Join(dirPath, tgFile.Name())) - taskID := xid.New().String() - // Create context with bot extension injectCtx := tgutil.ExtWithContext(r.Context(), botCtx) - task, err := tftask.NewTGFileTask(taskID, injectCtx, tgFile, stor, storagePath, &apiProgressTracker{ - taskID: taskID, - }) - if err != nil { - logger.Errorf("Failed to create task: %v", err) - respondError(w, "failed to create task", http.StatusInternalServerError) - return - } + for _, tgFile := range files { + storagePath := stor.JoinStoragePath(path.Join(dirPath, tgFile.Name())) + taskID := xid.New().String() - // Track task status - trackTask(taskID, task.Title(), "queued") + task, err := tftask.NewTGFileTask(taskID, injectCtx, tgFile, stor, storagePath, &apiProgressTracker{ + taskID: taskID, + }) + if err != nil { + logger.Errorf("Failed to create task: %v", err) + respondError(w, "failed to create task", http.StatusInternalServerError) + return + } - // Add task to queue - if err := core.AddTask(injectCtx, task); err != nil { - logger.Errorf("Failed to add task: %v", err) - updateTaskStatus(taskID, "failed", err.Error()) - respondError(w, "failed to add task to queue", http.StatusInternalServerError) - return + // Track task status + trackTask(taskID, task.Title(), "queued") + + // Add task to queue + if err := core.AddTask(injectCtx, task); err != nil { + logger.Errorf("Failed to add task: %v", err) + updateTaskStatus(taskID, "failed", err.Error()) + respondError(w, "failed to add task to queue", http.StatusInternalServerError) + return + } + + taskIDs = append(taskIDs, taskID) } // Send success response w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusCreated) - json.NewEncoder(w).Encode(CreateTaskResponse{ - TaskID: taskID, - Message: "task created successfully", - }) + + // Return first task ID for single file, or all task IDs for media group + if len(taskIDs) == 1 { + json.NewEncoder(w).Encode(CreateTaskResponse{ + TaskID: taskIDs[0], + Message: "task created successfully", + }) + } else { + json.NewEncoder(w).Encode(map[string]interface{}{ + "task_ids": taskIDs, + "message": fmt.Sprintf("%d tasks created successfully", len(taskIDs)), + }) + } +} + +// createTGFileWithMedia creates a TGFile with proper filename handling using user's strategy +func createTGFileWithMedia(botCtx *ext.Context, msg *tg.Message, media tg.MessageMediaClass, userDB *database.User) (tfile.TGFileMessage, error) { + // Use the same filename generation logic as bot handlers + opts := []tfile.TGFileOption{tfile.WithNameIfEmpty(tgutil.GenFileNameFromMessage(*msg))} + return tfile.FromMediaMessage(media, botCtx.Raw, msg, opts...) } func handleGetTask(w http.ResponseWriter, r *http.Request) { @@ -223,9 +309,9 @@ func handleGetTask(w http.ResponseWriter, r *http.Request) { Title: status.Title, CreatedAt: status.CreatedAt, Error: status.Error, - Downloaded: status.Downloaded, - Total: status.Total, - ProgressPct: status.ProgressPct, + Downloaded: status.Downloaded.Load(), + Total: status.Total.Load(), + ProgressPct: math.Float64frombits(atomic.LoadUint64((*uint64)(&status.ProgressPct))), }) } @@ -310,13 +396,18 @@ func (a *apiProgressTracker) OnStart(ctx context.Context, info tftask.TaskInfo) } func (a *apiProgressTracker) OnProgress(ctx context.Context, info tftask.TaskInfo, downloaded int64, total int64) { - taskStatusesMu.Lock() - defer taskStatusesMu.Unlock() - if ts, exists := taskStatuses[a.taskID]; exists { - ts.Downloaded = downloaded - ts.Total = total + // Use atomic operations to avoid mutex locks for better performance + // OnProgress is called very frequently during downloads + taskStatusesMu.RLock() + ts, exists := taskStatuses[a.taskID] + taskStatusesMu.RUnlock() + + if exists { + ts.Downloaded.Store(downloaded) + ts.Total.Store(total) if total > 0 { - ts.ProgressPct = float64(downloaded) / float64(total) * 100.0 + progressPct := float64(downloaded) / float64(total) * 100.0 + atomic.StoreUint64((*uint64)(&ts.ProgressPct), math.Float64bits(progressPct)) } } } @@ -354,9 +445,9 @@ func sendWebhook(taskID, status, errorMsg string) { Title: ts.Title, CreatedAt: ts.CreatedAt, Error: errorMsg, - Downloaded: ts.Downloaded, - Total: ts.Total, - ProgressPct: ts.ProgressPct, + Downloaded: ts.Downloaded.Load(), + Total: ts.Total.Load(), + ProgressPct: math.Float64frombits(atomic.LoadUint64((*uint64)(&ts.ProgressPct))), } body, err := json.Marshal(payload)