Fix performance issues, add media group support, and improve filename handling

Co-authored-by: krau <71133316+krau@users.noreply.github.com>
This commit is contained in:
copilot-swe-agent[bot]
2026-01-19 06:14:48 +00:00
parent 3a6402a71b
commit 6896bdc852

View File

@@ -4,18 +4,25 @@ import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"math"
"net/http"
"net/url"
"path"
"sync"
"sync/atomic"
"time"
"github.com/celestix/gotgproto/ext"
"github.com/charmbracelet/log"
"github.com/gotd/td/tg"
"github.com/krau/SaveAny-Bot/client/bot"
"github.com/krau/SaveAny-Bot/common/utils/tgutil"
"github.com/krau/SaveAny-Bot/config"
"github.com/krau/SaveAny-Bot/core"
tftask "github.com/krau/SaveAny-Bot/core/tasks/tfile"
"github.com/krau/SaveAny-Bot/database"
"github.com/krau/SaveAny-Bot/pkg/queue"
"github.com/krau/SaveAny-Bot/pkg/tfile"
"github.com/krau/SaveAny-Bot/storage"
@@ -72,9 +79,9 @@ type taskStatus struct {
Title string
CreatedAt time.Time
Error string
Downloaded int64
Total int64
ProgressPct float64
Downloaded atomic.Int64 // Use atomic for lock-free updates
Total atomic.Int64 // Use atomic for lock-free updates
ProgressPct uint64 // Store as uint64 bits of float64 for atomic access
}
func handleHealth(w http.ResponseWriter, r *http.Request) {
@@ -102,9 +109,16 @@ func handleCreateTask(w http.ResponseWriter, r *http.Request) {
logger := log.FromContext(r.Context()).WithPrefix("api")
// Get user from database
userDB, err := database.GetUserByChatID(r.Context(), req.UserID)
if err != nil {
logger.Errorf("Failed to get user: %v", err)
respondError(w, "user not found", http.StatusBadRequest)
return
}
// Get storage
var stor storage.Storage
var err error
if req.StorageName != "" {
stor, err = storage.GetStorageByUserIDAndName(r.Context(), req.UserID, req.StorageName)
if err != nil {
@@ -129,6 +143,13 @@ func handleCreateTask(w http.ResponseWriter, r *http.Request) {
return
}
linkUrl, err := url.Parse(req.TelegramURL)
if err != nil {
logger.Errorf("Failed to parse URL: %v", err)
respondError(w, "invalid telegram URL format", http.StatusBadRequest)
return
}
chatID, msgID, err := tgutil.ParseMessageLink(botCtx, req.TelegramURL)
if err != nil {
logger.Errorf("Failed to parse Telegram URL: %v", err)
@@ -151,53 +172,118 @@ func handleCreateTask(w http.ResponseWriter, r *http.Request) {
return
}
// Create TGFile from message media
tgFile, err := tfile.FromMediaMessage(media, botCtx.Raw, msg)
if err != nil {
logger.Errorf("Failed to create TGFile: %v", err)
respondError(w, "invalid message format", http.StatusBadRequest)
// Collect files - handle both single and grouped messages
files := make([]tfile.TGFileMessage, 0)
// Check for grouped messages (media group)
groupID, isGroup := msg.GetGroupedID()
if isGroup && groupID != 0 && !linkUrl.Query().Has("single") {
// Handle media group
gmsgs, err := tgutil.GetGroupedMessages(botCtx, chatID, msg)
if err != nil {
logger.Errorf("Failed to get grouped messages: %v", err)
// Fall back to single message
file, err := createTGFileWithMedia(botCtx, msg, media, userDB)
if err != nil {
logger.Errorf("Failed to create TGFile: %v", err)
respondError(w, "invalid message format", http.StatusBadRequest)
return
}
files = append(files, file)
} else {
// Process all messages in the group
for _, gmsg := range gmsgs {
if gmsg.Media == nil {
continue
}
gMedia, ok := gmsg.GetMedia()
if !ok {
continue
}
file, err := createTGFileWithMedia(botCtx, gmsg, gMedia, userDB)
if err != nil {
logger.Warnf("Failed to create TGFile for grouped message: %v", err)
continue
}
files = append(files, file)
}
}
} else {
// Single message
file, err := createTGFileWithMedia(botCtx, msg, media, userDB)
if err != nil {
logger.Errorf("Failed to create TGFile: %v", err)
respondError(w, "invalid message format", http.StatusBadRequest)
return
}
files = append(files, file)
}
if len(files) == 0 {
respondError(w, "no savable files found", http.StatusBadRequest)
return
}
// Create task
// Create tasks for all files
taskIDs := make([]string, 0, len(files))
dirPath := req.DirPath
if dirPath == "" {
dirPath = "/"
}
storagePath := stor.JoinStoragePath(path.Join(dirPath, tgFile.Name()))
taskID := xid.New().String()
// Create context with bot extension
injectCtx := tgutil.ExtWithContext(r.Context(), botCtx)
task, err := tftask.NewTGFileTask(taskID, injectCtx, tgFile, stor, storagePath, &apiProgressTracker{
taskID: taskID,
})
if err != nil {
logger.Errorf("Failed to create task: %v", err)
respondError(w, "failed to create task", http.StatusInternalServerError)
return
}
for _, tgFile := range files {
storagePath := stor.JoinStoragePath(path.Join(dirPath, tgFile.Name()))
taskID := xid.New().String()
// Track task status
trackTask(taskID, task.Title(), "queued")
task, err := tftask.NewTGFileTask(taskID, injectCtx, tgFile, stor, storagePath, &apiProgressTracker{
taskID: taskID,
})
if err != nil {
logger.Errorf("Failed to create task: %v", err)
respondError(w, "failed to create task", http.StatusInternalServerError)
return
}
// Add task to queue
if err := core.AddTask(injectCtx, task); err != nil {
logger.Errorf("Failed to add task: %v", err)
updateTaskStatus(taskID, "failed", err.Error())
respondError(w, "failed to add task to queue", http.StatusInternalServerError)
return
// Track task status
trackTask(taskID, task.Title(), "queued")
// Add task to queue
if err := core.AddTask(injectCtx, task); err != nil {
logger.Errorf("Failed to add task: %v", err)
updateTaskStatus(taskID, "failed", err.Error())
respondError(w, "failed to add task to queue", http.StatusInternalServerError)
return
}
taskIDs = append(taskIDs, taskID)
}
// Send success response
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusCreated)
json.NewEncoder(w).Encode(CreateTaskResponse{
TaskID: taskID,
Message: "task created successfully",
})
// Return first task ID for single file, or all task IDs for media group
if len(taskIDs) == 1 {
json.NewEncoder(w).Encode(CreateTaskResponse{
TaskID: taskIDs[0],
Message: "task created successfully",
})
} else {
json.NewEncoder(w).Encode(map[string]interface{}{
"task_ids": taskIDs,
"message": fmt.Sprintf("%d tasks created successfully", len(taskIDs)),
})
}
}
// createTGFileWithMedia creates a TGFile with proper filename handling using user's strategy
func createTGFileWithMedia(botCtx *ext.Context, msg *tg.Message, media tg.MessageMediaClass, userDB *database.User) (tfile.TGFileMessage, error) {
// Use the same filename generation logic as bot handlers
opts := []tfile.TGFileOption{tfile.WithNameIfEmpty(tgutil.GenFileNameFromMessage(*msg))}
return tfile.FromMediaMessage(media, botCtx.Raw, msg, opts...)
}
func handleGetTask(w http.ResponseWriter, r *http.Request) {
@@ -223,9 +309,9 @@ func handleGetTask(w http.ResponseWriter, r *http.Request) {
Title: status.Title,
CreatedAt: status.CreatedAt,
Error: status.Error,
Downloaded: status.Downloaded,
Total: status.Total,
ProgressPct: status.ProgressPct,
Downloaded: status.Downloaded.Load(),
Total: status.Total.Load(),
ProgressPct: math.Float64frombits(atomic.LoadUint64((*uint64)(&status.ProgressPct))),
})
}
@@ -310,13 +396,18 @@ func (a *apiProgressTracker) OnStart(ctx context.Context, info tftask.TaskInfo)
}
func (a *apiProgressTracker) OnProgress(ctx context.Context, info tftask.TaskInfo, downloaded int64, total int64) {
taskStatusesMu.Lock()
defer taskStatusesMu.Unlock()
if ts, exists := taskStatuses[a.taskID]; exists {
ts.Downloaded = downloaded
ts.Total = total
// Use atomic operations to avoid mutex locks for better performance
// OnProgress is called very frequently during downloads
taskStatusesMu.RLock()
ts, exists := taskStatuses[a.taskID]
taskStatusesMu.RUnlock()
if exists {
ts.Downloaded.Store(downloaded)
ts.Total.Store(total)
if total > 0 {
ts.ProgressPct = float64(downloaded) / float64(total) * 100.0
progressPct := float64(downloaded) / float64(total) * 100.0
atomic.StoreUint64((*uint64)(&ts.ProgressPct), math.Float64bits(progressPct))
}
}
}
@@ -354,9 +445,9 @@ func sendWebhook(taskID, status, errorMsg string) {
Title: ts.Title,
CreatedAt: ts.CreatedAt,
Error: errorMsg,
Downloaded: ts.Downloaded,
Total: ts.Total,
ProgressPct: ts.ProgressPct,
Downloaded: ts.Downloaded.Load(),
Total: ts.Total.Load(),
ProgressPct: math.Float64frombits(atomic.LoadUint64((*uint64)(&ts.ProgressPct))),
}
body, err := json.Marshal(payload)