feat: batch save files

This commit is contained in:
krau
2025-04-12 16:27:23 +08:00
parent 725acd0199
commit c8c348a182
5 changed files with 236 additions and 59 deletions

View File

@@ -29,24 +29,40 @@ func handleLinkMessage(ctx *ext.Context, update *ext.Update) error {
if len(strSlice) < 3 {
return dispatcher.ContinueGroups
}
messageID, err := strconv.Atoi(strSlice[2])
messageID, err := strconv.Atoi(strSlice[len(strSlice)-1])
if err != nil {
common.Log.Errorf("解析消息 ID 失败: %s", err)
ctx.Reply(update, ext.ReplyTextString("无法解析消息 ID"), nil)
return dispatcher.EndGroups
}
chatUsername := strSlice[1]
linkChat, err := ctx.ResolveUsername(chatUsername)
if err != nil {
common.Log.Errorf("解析 Chat ID 失败: %s", err)
ctx.Reply(update, ext.ReplyTextString("无法解析 Chat ID"), nil)
return dispatcher.EndGroups
}
if linkChat == nil {
common.Log.Errorf("无法找到聊天: %s", chatUsername)
ctx.Reply(update, ext.ReplyTextString("无法找到聊天"), nil)
var linkChatID int64
if len(strSlice) == 3 {
chatUsername := strSlice[1]
linkChat, err := ctx.ResolveUsername(chatUsername)
if err != nil {
common.Log.Errorf("解析用户名失败: %s", err)
ctx.Reply(update, ext.ReplyTextString("解析用户名失败"), nil)
return dispatcher.EndGroups
}
if linkChat == nil {
common.Log.Errorf("无法找到聊天: %s", chatUsername)
ctx.Reply(update, ext.ReplyTextString("无法找到聊天"), nil)
return dispatcher.EndGroups
}
linkChatID = linkChat.GetID()
} else if len(strSlice) == 4 {
chatID, err := strconv.Atoi(strSlice[2])
if err != nil {
common.Log.Errorf("解析 Chat ID 失败: %s", err)
ctx.Reply(update, ext.ReplyTextString("解析 Chat ID 失败"), nil)
return dispatcher.EndGroups
}
linkChatID = int64(chatID)
} else {
ctx.Reply(update, ext.ReplyTextString("无法解析链接"), nil)
return dispatcher.EndGroups
}
user, err := dao.GetUserByChatID(update.GetUserChat().GetID())
if err != nil {
common.Log.Errorf("获取用户失败: %s", err)
@@ -65,7 +81,7 @@ func handleLinkMessage(ctx *ext.Context, update *ext.Update) error {
return dispatcher.EndGroups
}
file, err := FileFromMessage(ctx, linkChat.GetID(), messageID, "")
file, err := FileFromMessage(ctx, linkChatID, messageID, "")
if err != nil {
common.Log.Errorf("获取文件失败: %s", err)
ctx.Reply(update, ext.ReplyTextString("获取文件失败: "+err.Error()), nil)
@@ -78,7 +94,7 @@ func handleLinkMessage(ctx *ext.Context, update *ext.Update) error {
receivedFile := &dao.ReceivedFile{
Processing: false,
FileName: file.FileName,
ChatID: linkChat.GetID(),
ChatID: linkChatID,
MessageID: messageID,
ReplyMessageID: replied.ID,
ReplyChatID: update.GetUserChat().GetID(),
@@ -92,7 +108,7 @@ func handleLinkMessage(ctx *ext.Context, update *ext.Update) error {
return dispatcher.EndGroups
}
if !user.Silent || user.DefaultStorage == "" {
return ProvideSelectMessage(ctx, update, file.FileName, linkChat.GetID(), messageID, replied.ID)
return ProvideSelectMessage(ctx, update, file.FileName, linkChatID, messageID, replied.ID)
}
return HandleSilentAddTask(ctx, update, user, &types.Task{
Ctx: ctx,
@@ -100,7 +116,7 @@ func handleLinkMessage(ctx *ext.Context, update *ext.Update) error {
File: file,
StorageName: user.DefaultStorage,
UserID: user.ChatID,
FileChatID: linkChat.GetID(),
FileChatID: linkChatID,
FileMessageID: messageID,
ReplyMessageID: replied.ID,
ReplyChatID: update.GetUserChat().GetID(),

View File

@@ -2,6 +2,7 @@ package bot
import (
"fmt"
"strconv"
"strings"
"github.com/celestix/gotgproto/dispatcher"
@@ -9,25 +10,50 @@ import (
"github.com/gotd/td/tg"
"github.com/krau/SaveAny-Bot/common"
"github.com/krau/SaveAny-Bot/dao"
"github.com/krau/SaveAny-Bot/queue"
"github.com/krau/SaveAny-Bot/storage"
"github.com/krau/SaveAny-Bot/types"
)
func sendSaveHelp(ctx *ext.Context, update *ext.Update) error {
helpText := `
使用方法:
1. 使用该命令回复要保存的文件, 可选文件名参数.
示例:
/save custom_file_name.mp4
2. 设置默认存储后, 发送 /save <频道ID/用户名> <消息ID范围> 来批量保存文件. 遵从存储规则, 若未匹配到任何规则则使用默认存储.
示例:
/save @moreacg 114-514
`
ctx.Reply(update, ext.ReplyTextString(helpText), nil)
return dispatcher.EndGroups
}
func saveCmd(ctx *ext.Context, update *ext.Update) error {
res, ok := update.EffectiveMessage.GetReplyTo()
if !ok || res == nil {
ctx.Reply(update, ext.ReplyTextString("请回复要保存的文件"), nil)
return dispatcher.EndGroups
args := strings.Split(update.EffectiveMessage.Text, " ")
if len(args) >= 3 {
return handleBatchSave(ctx, update, args[1:])
}
replyHeader, ok := res.(*tg.MessageReplyHeader)
if !ok {
ctx.Reply(update, ext.ReplyTextString("请回复要保存的文件"), nil)
return dispatcher.EndGroups
}
replyToMsgID, ok := replyHeader.GetReplyToMsgID()
if !ok {
ctx.Reply(update, ext.ReplyTextString("请回复要保存的文件"), nil)
return dispatcher.EndGroups
replyToMsgID := func() int {
res, ok := update.EffectiveMessage.GetReplyTo()
if !ok || res == nil {
return 0
}
replyHeader, ok := res.(*tg.MessageReplyHeader)
if !ok {
return 0
}
replyToMsgID, ok := replyHeader.GetReplyToMsgID()
if !ok {
return 0
}
return replyToMsgID
}()
if replyToMsgID == 0 {
return sendSaveHelp(ctx, update)
}
user, err := dao.GetUserByChatID(update.GetUserChat().GetID())
@@ -113,3 +139,125 @@ func saveCmd(ctx *ext.Context, update *ext.Update) error {
UserID: user.ChatID,
})
}
func handleBatchSave(ctx *ext.Context, update *ext.Update, args []string) error {
// args: [0] = @channel, [1] = 114-514
chatArg := args[0]
var chatID int64
var err error
msgIdSlice := strings.Split(args[1], "-")
if len(msgIdSlice) != 2 {
ctx.Reply(update, ext.ReplyTextString("无效的消息ID范围"), nil)
return dispatcher.EndGroups
}
minMsgID, minerr := strconv.ParseInt(msgIdSlice[0], 10, 64)
maxMsgID, maxerr := strconv.ParseInt(msgIdSlice[1], 10, 64)
if minerr != nil || maxerr != nil {
ctx.Reply(update, ext.ReplyTextString("无效的消息ID范围"), nil)
return dispatcher.EndGroups
}
if minMsgID > maxMsgID || minMsgID <= 0 || maxMsgID <= 0 {
ctx.Reply(update, ext.ReplyTextString("无效的消息ID范围"), nil)
return dispatcher.EndGroups
}
user, err := dao.GetUserByChatID(update.GetUserChat().GetID())
if err != nil {
common.Log.Errorf("获取用户失败: %s", err)
ctx.Reply(update, ext.ReplyTextString("获取用户失败"), nil)
return dispatcher.EndGroups
}
if user.DefaultStorage == "" {
ctx.Reply(update, ext.ReplyTextString("请先设置默认存储"), nil)
return dispatcher.EndGroups
}
storages := storage.GetUserStorages(user.ChatID)
if len(storages) == 0 {
ctx.Reply(update, ext.ReplyTextString("无可用的存储"), nil)
return dispatcher.EndGroups
}
if strings.HasPrefix(chatArg, "@") {
chatUsername := strings.TrimPrefix(chatArg, "@")
chat, err := ctx.ResolveUsername(chatUsername)
if err != nil {
common.Log.Errorf("解析频道用户名失败: %s", err)
ctx.Reply(update, ext.ReplyTextString("解析频道用户名失败"), nil)
return dispatcher.EndGroups
}
if chat == nil {
ctx.Reply(update, ext.ReplyTextString("无法找到聊天"), nil)
return dispatcher.EndGroups
}
chatID = chat.GetID()
} else {
chatID, err = strconv.ParseInt(chatArg, 10, 64)
if err != nil {
ctx.Reply(update, ext.ReplyTextString("无效的频道ID或用户名"), nil)
return dispatcher.EndGroups
}
}
if chatID == 0 {
ctx.Reply(update, ext.ReplyTextString("无效的频道ID或用户名"), nil)
return dispatcher.EndGroups
}
replied, err := ctx.Reply(update, ext.ReplyTextString("正在批量保存..."), nil)
if err != nil {
common.Log.Errorf("回复失败: %s", err)
return dispatcher.EndGroups
}
total := maxMsgID - minMsgID + 1
successadd := 0
failedGetFile := 0
failedGetMsg := 0
failedSaveDB := 0
for i := minMsgID; i <= maxMsgID; i++ {
file, err := FileFromMessage(ctx, chatID, int(i), "")
if err != nil {
common.Log.Errorf("获取文件失败: %s", err)
failedGetFile++
continue
}
if file.FileName == "" {
message, err := GetTGMessage(ctx, chatID, int(i))
if err != nil {
common.Log.Errorf("获取消息失败: %s", err)
failedGetMsg++
continue
}
file.FileName = GenFileNameFromMessage(*message, file)
}
receivedFile := &dao.ReceivedFile{
Processing: false,
FileName: file.FileName,
ChatID: chatID,
MessageID: int(i),
ReplyChatID: update.GetUserChat().GetID(),
ReplyMessageID: 0,
}
if err := dao.SaveReceivedFile(receivedFile); err != nil {
common.Log.Errorf("保存接收的文件失败: %s", err)
failedSaveDB++
continue
}
task := &types.Task{
Ctx: ctx,
Status: types.Pending,
File: file,
StorageName: user.DefaultStorage,
FileChatID: chatID,
FileMessageID: int(i),
UserID: user.ChatID,
ReplyMessageID: 0,
ReplyChatID: update.GetUserChat().GetID(),
}
queue.AddTask(task)
successadd++
}
ctx.EditMessage(update.EffectiveChat().GetID(), &tg.MessagesEditMessageRequest{
Message: fmt.Sprintf("批量保存完成\n成功添加: %d/%d\n获取文件失败: %d\n获取消息失败: %d\n保存数据库失败: %d", successadd, total, failedGetFile, failedGetMsg, failedSaveDB),
ID: replied.ID,
})
return dispatcher.EndGroups
}

View File

@@ -46,7 +46,7 @@ func worker(queue *queue.TaskQueue, semaphore chan struct{}) {
extCtx, ok := task.Ctx.(*ext.Context)
if !ok {
common.Log.Errorf("Context is not *ext.Context: %T", task.Ctx)
} else {
} else if task.ReplyMessageID != 0 {
extCtx.EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{
Message: fmt.Sprintf("文件保存成功\n [%s]: %s", task.StorageName, task.StoragePath),
ID: task.ReplyMessageID,
@@ -57,7 +57,7 @@ func worker(queue *queue.TaskQueue, semaphore chan struct{}) {
extCtx, ok := task.Ctx.(*ext.Context)
if !ok {
common.Log.Errorf("Context is not *ext.Context: %T", task.Ctx)
} else {
} else if task.ReplyMessageID != 0 {
extCtx.EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{
Message: "文件保存失败\n" + task.Error.Error(),
ID: task.ReplyMessageID,
@@ -68,7 +68,7 @@ func worker(queue *queue.TaskQueue, semaphore chan struct{}) {
extCtx, ok := task.Ctx.(*ext.Context)
if !ok {
common.Log.Errorf("Context is not *ext.Context: %T", task.Ctx)
} else {
} else if task.ReplyMessageID != 0 {
extCtx.EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{
Message: "任务已取消",
ID: task.ReplyMessageID,

View File

@@ -63,12 +63,14 @@ func processPendingTask(task *types.Task) error {
if config.Cfg.Stream {
if !notsupportStream {
text, entities := buildProgressMessageEntity(task, 0, task.StartTime, 0)
ctx.EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{
Message: text,
Entities: entities,
ID: task.ReplyMessageID,
ReplyMarkup: cancelMarkUp,
})
if task.ReplyMessageID != 0 {
ctx.EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{
Message: text,
Entities: entities,
ID: task.ReplyMessageID,
ReplyMarkup: cancelMarkUp,
})
}
pr, pw := io.Pipe()
defer pr.Close()
@@ -97,11 +99,14 @@ func processPendingTask(task *types.Task) error {
return nil
}
common.Log.Warnf("存储 %s 不支持流式传输: %s", task.StorageName, notsupportStreamStorage.NotSupportStream())
ctx.EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{
Message: fmt.Sprintf("存储 %s 不支持流式传输: %s\n正在使用普通下载...", task.StorageName, notsupportStreamStorage.NotSupportStream()),
ID: task.ReplyMessageID,
ReplyMarkup: cancelMarkUp,
})
if task.ReplyMessageID != 0 {
ctx.EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{
Message: fmt.Sprintf("存储 %s 不支持流式传输: %s\n正在使用普通下载...", task.StorageName, notsupportStreamStorage.NotSupportStream()),
ID: task.ReplyMessageID,
ReplyMarkup: cancelMarkUp,
})
}
}
cacheDestPath := filepath.Join(config.Cfg.Temp.BasePath, task.FileName())
@@ -114,12 +119,14 @@ func processPendingTask(task *types.Task) error {
}
text, entities := buildProgressMessageEntity(task, 0, task.StartTime, 0)
ctx.EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{
Message: text,
Entities: entities,
ID: task.ReplyMessageID,
ReplyMarkup: cancelMarkUp,
})
if task.ReplyMessageID != 0 {
ctx.EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{
Message: text,
Entities: entities,
ID: task.ReplyMessageID,
ReplyMarkup: cancelMarkUp,
})
}
progressCallback := buildProgressCallback(ctx, task, getProgressUpdateCount(task.File.FileSize))
dest, err := NewTaskLocalFile(cacheDestPath, task.File.FileSize, progressCallback)
@@ -137,11 +144,12 @@ func processPendingTask(task *types.Task) error {
fixTaskFileExt(task, cacheDestPath)
common.Log.Infof("Downloaded file: %s", cacheDestPath)
ctx.EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{
Message: fmt.Sprintf("下载完成: %s\n正在转存文件...", task.FileName()),
ID: task.ReplyMessageID,
})
if task.ReplyMessageID != 0 {
ctx.EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{
Message: fmt.Sprintf("下载完成: %s\n正在转存文件...", task.FileName()),
ID: task.ReplyMessageID,
})
}
return saveFileWithRetry(cancelCtx, task.StoragePath, taskStorage, cacheDestPath)
}
@@ -169,12 +177,14 @@ func processTelegraph(extCtx *ext.Context, cancelCtx context.Context, task *type
common.Log.Errorf("Failed to build entities: %s", err)
}
extCtx.EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{
Message: text,
Entities: entities,
ID: task.ReplyMessageID,
ReplyMarkup: getCancelTaskMarkup(task),
})
if task.ReplyMessageID != 0 {
extCtx.EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{
Message: text,
Entities: entities,
ID: task.ReplyMessageID,
ReplyMarkup: getCancelTaskMarkup(task),
})
}
resultCh := make(chan error)
go func() {

View File

@@ -145,6 +145,9 @@ func buildProgressCallback(ctx *ext.Context, task *types.Task, updateCount int)
if task.File.FileSize < 1024*1024*50 || progressInt == 0 || progressInt%int(100/updateCount) != 0 {
return
}
if task.ReplyMessageID == 0 {
return
}
text, entities := buildProgressMessageEntity(task, bytesRead, task.StartTime, progress)
ctx.EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{
Message: text,