feat: batch save files

This commit is contained in:
krau
2025-04-12 16:27:23 +08:00
parent 725acd0199
commit c8c348a182
5 changed files with 236 additions and 59 deletions

View File

@@ -29,24 +29,40 @@ func handleLinkMessage(ctx *ext.Context, update *ext.Update) error {
if len(strSlice) < 3 { if len(strSlice) < 3 {
return dispatcher.ContinueGroups return dispatcher.ContinueGroups
} }
messageID, err := strconv.Atoi(strSlice[2]) messageID, err := strconv.Atoi(strSlice[len(strSlice)-1])
if err != nil { if err != nil {
common.Log.Errorf("解析消息 ID 失败: %s", err) common.Log.Errorf("解析消息 ID 失败: %s", err)
ctx.Reply(update, ext.ReplyTextString("无法解析消息 ID"), nil) ctx.Reply(update, ext.ReplyTextString("无法解析消息 ID"), nil)
return dispatcher.EndGroups return dispatcher.EndGroups
} }
chatUsername := strSlice[1] var linkChatID int64
linkChat, err := ctx.ResolveUsername(chatUsername) if len(strSlice) == 3 {
if err != nil { chatUsername := strSlice[1]
common.Log.Errorf("解析 Chat ID 失败: %s", err) linkChat, err := ctx.ResolveUsername(chatUsername)
ctx.Reply(update, ext.ReplyTextString("无法解析 Chat ID"), nil) if err != nil {
return dispatcher.EndGroups common.Log.Errorf("解析用户名失败: %s", err)
} ctx.Reply(update, ext.ReplyTextString("解析用户名失败"), nil)
if linkChat == nil { return dispatcher.EndGroups
common.Log.Errorf("无法找到聊天: %s", chatUsername) }
ctx.Reply(update, ext.ReplyTextString("无法找到聊天"), nil) if linkChat == nil {
common.Log.Errorf("无法找到聊天: %s", chatUsername)
ctx.Reply(update, ext.ReplyTextString("无法找到聊天"), nil)
return dispatcher.EndGroups
}
linkChatID = linkChat.GetID()
} else if len(strSlice) == 4 {
chatID, err := strconv.Atoi(strSlice[2])
if err != nil {
common.Log.Errorf("解析 Chat ID 失败: %s", err)
ctx.Reply(update, ext.ReplyTextString("解析 Chat ID 失败"), nil)
return dispatcher.EndGroups
}
linkChatID = int64(chatID)
} else {
ctx.Reply(update, ext.ReplyTextString("无法解析链接"), nil)
return dispatcher.EndGroups return dispatcher.EndGroups
} }
user, err := dao.GetUserByChatID(update.GetUserChat().GetID()) user, err := dao.GetUserByChatID(update.GetUserChat().GetID())
if err != nil { if err != nil {
common.Log.Errorf("获取用户失败: %s", err) common.Log.Errorf("获取用户失败: %s", err)
@@ -65,7 +81,7 @@ func handleLinkMessage(ctx *ext.Context, update *ext.Update) error {
return dispatcher.EndGroups return dispatcher.EndGroups
} }
file, err := FileFromMessage(ctx, linkChat.GetID(), messageID, "") file, err := FileFromMessage(ctx, linkChatID, messageID, "")
if err != nil { if err != nil {
common.Log.Errorf("获取文件失败: %s", err) common.Log.Errorf("获取文件失败: %s", err)
ctx.Reply(update, ext.ReplyTextString("获取文件失败: "+err.Error()), nil) ctx.Reply(update, ext.ReplyTextString("获取文件失败: "+err.Error()), nil)
@@ -78,7 +94,7 @@ func handleLinkMessage(ctx *ext.Context, update *ext.Update) error {
receivedFile := &dao.ReceivedFile{ receivedFile := &dao.ReceivedFile{
Processing: false, Processing: false,
FileName: file.FileName, FileName: file.FileName,
ChatID: linkChat.GetID(), ChatID: linkChatID,
MessageID: messageID, MessageID: messageID,
ReplyMessageID: replied.ID, ReplyMessageID: replied.ID,
ReplyChatID: update.GetUserChat().GetID(), ReplyChatID: update.GetUserChat().GetID(),
@@ -92,7 +108,7 @@ func handleLinkMessage(ctx *ext.Context, update *ext.Update) error {
return dispatcher.EndGroups return dispatcher.EndGroups
} }
if !user.Silent || user.DefaultStorage == "" { if !user.Silent || user.DefaultStorage == "" {
return ProvideSelectMessage(ctx, update, file.FileName, linkChat.GetID(), messageID, replied.ID) return ProvideSelectMessage(ctx, update, file.FileName, linkChatID, messageID, replied.ID)
} }
return HandleSilentAddTask(ctx, update, user, &types.Task{ return HandleSilentAddTask(ctx, update, user, &types.Task{
Ctx: ctx, Ctx: ctx,
@@ -100,7 +116,7 @@ func handleLinkMessage(ctx *ext.Context, update *ext.Update) error {
File: file, File: file,
StorageName: user.DefaultStorage, StorageName: user.DefaultStorage,
UserID: user.ChatID, UserID: user.ChatID,
FileChatID: linkChat.GetID(), FileChatID: linkChatID,
FileMessageID: messageID, FileMessageID: messageID,
ReplyMessageID: replied.ID, ReplyMessageID: replied.ID,
ReplyChatID: update.GetUserChat().GetID(), ReplyChatID: update.GetUserChat().GetID(),

View File

@@ -2,6 +2,7 @@ package bot
import ( import (
"fmt" "fmt"
"strconv"
"strings" "strings"
"github.com/celestix/gotgproto/dispatcher" "github.com/celestix/gotgproto/dispatcher"
@@ -9,25 +10,50 @@ import (
"github.com/gotd/td/tg" "github.com/gotd/td/tg"
"github.com/krau/SaveAny-Bot/common" "github.com/krau/SaveAny-Bot/common"
"github.com/krau/SaveAny-Bot/dao" "github.com/krau/SaveAny-Bot/dao"
"github.com/krau/SaveAny-Bot/queue"
"github.com/krau/SaveAny-Bot/storage" "github.com/krau/SaveAny-Bot/storage"
"github.com/krau/SaveAny-Bot/types" "github.com/krau/SaveAny-Bot/types"
) )
func sendSaveHelp(ctx *ext.Context, update *ext.Update) error {
helpText := `
使用方法:
1. 使用该命令回复要保存的文件, 可选文件名参数.
示例:
/save custom_file_name.mp4
2. 设置默认存储后, 发送 /save <频道ID/用户名> <消息ID范围> 来批量保存文件. 遵从存储规则, 若未匹配到任何规则则使用默认存储.
示例:
/save @moreacg 114-514
`
ctx.Reply(update, ext.ReplyTextString(helpText), nil)
return dispatcher.EndGroups
}
func saveCmd(ctx *ext.Context, update *ext.Update) error { func saveCmd(ctx *ext.Context, update *ext.Update) error {
res, ok := update.EffectiveMessage.GetReplyTo() args := strings.Split(update.EffectiveMessage.Text, " ")
if !ok || res == nil { if len(args) >= 3 {
ctx.Reply(update, ext.ReplyTextString("请回复要保存的文件"), nil) return handleBatchSave(ctx, update, args[1:])
return dispatcher.EndGroups
} }
replyHeader, ok := res.(*tg.MessageReplyHeader)
if !ok { replyToMsgID := func() int {
ctx.Reply(update, ext.ReplyTextString("请回复要保存的文件"), nil) res, ok := update.EffectiveMessage.GetReplyTo()
return dispatcher.EndGroups if !ok || res == nil {
} return 0
replyToMsgID, ok := replyHeader.GetReplyToMsgID() }
if !ok { replyHeader, ok := res.(*tg.MessageReplyHeader)
ctx.Reply(update, ext.ReplyTextString("请回复要保存的文件"), nil) if !ok {
return dispatcher.EndGroups return 0
}
replyToMsgID, ok := replyHeader.GetReplyToMsgID()
if !ok {
return 0
}
return replyToMsgID
}()
if replyToMsgID == 0 {
return sendSaveHelp(ctx, update)
} }
user, err := dao.GetUserByChatID(update.GetUserChat().GetID()) user, err := dao.GetUserByChatID(update.GetUserChat().GetID())
@@ -113,3 +139,125 @@ func saveCmd(ctx *ext.Context, update *ext.Update) error {
UserID: user.ChatID, UserID: user.ChatID,
}) })
} }
func handleBatchSave(ctx *ext.Context, update *ext.Update, args []string) error {
// args: [0] = @channel, [1] = 114-514
chatArg := args[0]
var chatID int64
var err error
msgIdSlice := strings.Split(args[1], "-")
if len(msgIdSlice) != 2 {
ctx.Reply(update, ext.ReplyTextString("无效的消息ID范围"), nil)
return dispatcher.EndGroups
}
minMsgID, minerr := strconv.ParseInt(msgIdSlice[0], 10, 64)
maxMsgID, maxerr := strconv.ParseInt(msgIdSlice[1], 10, 64)
if minerr != nil || maxerr != nil {
ctx.Reply(update, ext.ReplyTextString("无效的消息ID范围"), nil)
return dispatcher.EndGroups
}
if minMsgID > maxMsgID || minMsgID <= 0 || maxMsgID <= 0 {
ctx.Reply(update, ext.ReplyTextString("无效的消息ID范围"), nil)
return dispatcher.EndGroups
}
user, err := dao.GetUserByChatID(update.GetUserChat().GetID())
if err != nil {
common.Log.Errorf("获取用户失败: %s", err)
ctx.Reply(update, ext.ReplyTextString("获取用户失败"), nil)
return dispatcher.EndGroups
}
if user.DefaultStorage == "" {
ctx.Reply(update, ext.ReplyTextString("请先设置默认存储"), nil)
return dispatcher.EndGroups
}
storages := storage.GetUserStorages(user.ChatID)
if len(storages) == 0 {
ctx.Reply(update, ext.ReplyTextString("无可用的存储"), nil)
return dispatcher.EndGroups
}
if strings.HasPrefix(chatArg, "@") {
chatUsername := strings.TrimPrefix(chatArg, "@")
chat, err := ctx.ResolveUsername(chatUsername)
if err != nil {
common.Log.Errorf("解析频道用户名失败: %s", err)
ctx.Reply(update, ext.ReplyTextString("解析频道用户名失败"), nil)
return dispatcher.EndGroups
}
if chat == nil {
ctx.Reply(update, ext.ReplyTextString("无法找到聊天"), nil)
return dispatcher.EndGroups
}
chatID = chat.GetID()
} else {
chatID, err = strconv.ParseInt(chatArg, 10, 64)
if err != nil {
ctx.Reply(update, ext.ReplyTextString("无效的频道ID或用户名"), nil)
return dispatcher.EndGroups
}
}
if chatID == 0 {
ctx.Reply(update, ext.ReplyTextString("无效的频道ID或用户名"), nil)
return dispatcher.EndGroups
}
replied, err := ctx.Reply(update, ext.ReplyTextString("正在批量保存..."), nil)
if err != nil {
common.Log.Errorf("回复失败: %s", err)
return dispatcher.EndGroups
}
total := maxMsgID - minMsgID + 1
successadd := 0
failedGetFile := 0
failedGetMsg := 0
failedSaveDB := 0
for i := minMsgID; i <= maxMsgID; i++ {
file, err := FileFromMessage(ctx, chatID, int(i), "")
if err != nil {
common.Log.Errorf("获取文件失败: %s", err)
failedGetFile++
continue
}
if file.FileName == "" {
message, err := GetTGMessage(ctx, chatID, int(i))
if err != nil {
common.Log.Errorf("获取消息失败: %s", err)
failedGetMsg++
continue
}
file.FileName = GenFileNameFromMessage(*message, file)
}
receivedFile := &dao.ReceivedFile{
Processing: false,
FileName: file.FileName,
ChatID: chatID,
MessageID: int(i),
ReplyChatID: update.GetUserChat().GetID(),
ReplyMessageID: 0,
}
if err := dao.SaveReceivedFile(receivedFile); err != nil {
common.Log.Errorf("保存接收的文件失败: %s", err)
failedSaveDB++
continue
}
task := &types.Task{
Ctx: ctx,
Status: types.Pending,
File: file,
StorageName: user.DefaultStorage,
FileChatID: chatID,
FileMessageID: int(i),
UserID: user.ChatID,
ReplyMessageID: 0,
ReplyChatID: update.GetUserChat().GetID(),
}
queue.AddTask(task)
successadd++
}
ctx.EditMessage(update.EffectiveChat().GetID(), &tg.MessagesEditMessageRequest{
Message: fmt.Sprintf("批量保存完成\n成功添加: %d/%d\n获取文件失败: %d\n获取消息失败: %d\n保存数据库失败: %d", successadd, total, failedGetFile, failedGetMsg, failedSaveDB),
ID: replied.ID,
})
return dispatcher.EndGroups
}

View File

@@ -46,7 +46,7 @@ func worker(queue *queue.TaskQueue, semaphore chan struct{}) {
extCtx, ok := task.Ctx.(*ext.Context) extCtx, ok := task.Ctx.(*ext.Context)
if !ok { if !ok {
common.Log.Errorf("Context is not *ext.Context: %T", task.Ctx) common.Log.Errorf("Context is not *ext.Context: %T", task.Ctx)
} else { } else if task.ReplyMessageID != 0 {
extCtx.EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{ extCtx.EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{
Message: fmt.Sprintf("文件保存成功\n [%s]: %s", task.StorageName, task.StoragePath), Message: fmt.Sprintf("文件保存成功\n [%s]: %s", task.StorageName, task.StoragePath),
ID: task.ReplyMessageID, ID: task.ReplyMessageID,
@@ -57,7 +57,7 @@ func worker(queue *queue.TaskQueue, semaphore chan struct{}) {
extCtx, ok := task.Ctx.(*ext.Context) extCtx, ok := task.Ctx.(*ext.Context)
if !ok { if !ok {
common.Log.Errorf("Context is not *ext.Context: %T", task.Ctx) common.Log.Errorf("Context is not *ext.Context: %T", task.Ctx)
} else { } else if task.ReplyMessageID != 0 {
extCtx.EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{ extCtx.EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{
Message: "文件保存失败\n" + task.Error.Error(), Message: "文件保存失败\n" + task.Error.Error(),
ID: task.ReplyMessageID, ID: task.ReplyMessageID,
@@ -68,7 +68,7 @@ func worker(queue *queue.TaskQueue, semaphore chan struct{}) {
extCtx, ok := task.Ctx.(*ext.Context) extCtx, ok := task.Ctx.(*ext.Context)
if !ok { if !ok {
common.Log.Errorf("Context is not *ext.Context: %T", task.Ctx) common.Log.Errorf("Context is not *ext.Context: %T", task.Ctx)
} else { } else if task.ReplyMessageID != 0 {
extCtx.EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{ extCtx.EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{
Message: "任务已取消", Message: "任务已取消",
ID: task.ReplyMessageID, ID: task.ReplyMessageID,

View File

@@ -63,12 +63,14 @@ func processPendingTask(task *types.Task) error {
if config.Cfg.Stream { if config.Cfg.Stream {
if !notsupportStream { if !notsupportStream {
text, entities := buildProgressMessageEntity(task, 0, task.StartTime, 0) text, entities := buildProgressMessageEntity(task, 0, task.StartTime, 0)
ctx.EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{ if task.ReplyMessageID != 0 {
Message: text, ctx.EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{
Entities: entities, Message: text,
ID: task.ReplyMessageID, Entities: entities,
ReplyMarkup: cancelMarkUp, ID: task.ReplyMessageID,
}) ReplyMarkup: cancelMarkUp,
})
}
pr, pw := io.Pipe() pr, pw := io.Pipe()
defer pr.Close() defer pr.Close()
@@ -97,11 +99,14 @@ func processPendingTask(task *types.Task) error {
return nil return nil
} }
common.Log.Warnf("存储 %s 不支持流式传输: %s", task.StorageName, notsupportStreamStorage.NotSupportStream()) common.Log.Warnf("存储 %s 不支持流式传输: %s", task.StorageName, notsupportStreamStorage.NotSupportStream())
ctx.EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{
Message: fmt.Sprintf("存储 %s 不支持流式传输: %s\n正在使用普通下载...", task.StorageName, notsupportStreamStorage.NotSupportStream()), if task.ReplyMessageID != 0 {
ID: task.ReplyMessageID, ctx.EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{
ReplyMarkup: cancelMarkUp, Message: fmt.Sprintf("存储 %s 不支持流式传输: %s\n正在使用普通下载...", task.StorageName, notsupportStreamStorage.NotSupportStream()),
}) ID: task.ReplyMessageID,
ReplyMarkup: cancelMarkUp,
})
}
} }
cacheDestPath := filepath.Join(config.Cfg.Temp.BasePath, task.FileName()) cacheDestPath := filepath.Join(config.Cfg.Temp.BasePath, task.FileName())
@@ -114,12 +119,14 @@ func processPendingTask(task *types.Task) error {
} }
text, entities := buildProgressMessageEntity(task, 0, task.StartTime, 0) text, entities := buildProgressMessageEntity(task, 0, task.StartTime, 0)
ctx.EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{ if task.ReplyMessageID != 0 {
Message: text, ctx.EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{
Entities: entities, Message: text,
ID: task.ReplyMessageID, Entities: entities,
ReplyMarkup: cancelMarkUp, ID: task.ReplyMessageID,
}) ReplyMarkup: cancelMarkUp,
})
}
progressCallback := buildProgressCallback(ctx, task, getProgressUpdateCount(task.File.FileSize)) progressCallback := buildProgressCallback(ctx, task, getProgressUpdateCount(task.File.FileSize))
dest, err := NewTaskLocalFile(cacheDestPath, task.File.FileSize, progressCallback) dest, err := NewTaskLocalFile(cacheDestPath, task.File.FileSize, progressCallback)
@@ -137,11 +144,12 @@ func processPendingTask(task *types.Task) error {
fixTaskFileExt(task, cacheDestPath) fixTaskFileExt(task, cacheDestPath)
common.Log.Infof("Downloaded file: %s", cacheDestPath) common.Log.Infof("Downloaded file: %s", cacheDestPath)
ctx.EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{ if task.ReplyMessageID != 0 {
Message: fmt.Sprintf("下载完成: %s\n正在转存文件...", task.FileName()), ctx.EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{
ID: task.ReplyMessageID, Message: fmt.Sprintf("下载完成: %s\n正在转存文件...", task.FileName()),
}) ID: task.ReplyMessageID,
})
}
return saveFileWithRetry(cancelCtx, task.StoragePath, taskStorage, cacheDestPath) return saveFileWithRetry(cancelCtx, task.StoragePath, taskStorage, cacheDestPath)
} }
@@ -169,12 +177,14 @@ func processTelegraph(extCtx *ext.Context, cancelCtx context.Context, task *type
common.Log.Errorf("Failed to build entities: %s", err) common.Log.Errorf("Failed to build entities: %s", err)
} }
extCtx.EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{ if task.ReplyMessageID != 0 {
Message: text, extCtx.EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{
Entities: entities, Message: text,
ID: task.ReplyMessageID, Entities: entities,
ReplyMarkup: getCancelTaskMarkup(task), ID: task.ReplyMessageID,
}) ReplyMarkup: getCancelTaskMarkup(task),
})
}
resultCh := make(chan error) resultCh := make(chan error)
go func() { go func() {

View File

@@ -145,6 +145,9 @@ func buildProgressCallback(ctx *ext.Context, task *types.Task, updateCount int)
if task.File.FileSize < 1024*1024*50 || progressInt == 0 || progressInt%int(100/updateCount) != 0 { if task.File.FileSize < 1024*1024*50 || progressInt == 0 || progressInt%int(100/updateCount) != 0 {
return return
} }
if task.ReplyMessageID == 0 {
return
}
text, entities := buildProgressMessageEntity(task, bytesRead, task.StartTime, progress) text, entities := buildProgressMessageEntity(task, bytesRead, task.StartTime, progress)
ctx.EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{ ctx.EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{
Message: text, Message: text,