feat: batch save files
This commit is contained in:
@@ -29,24 +29,40 @@ func handleLinkMessage(ctx *ext.Context, update *ext.Update) error {
|
||||
if len(strSlice) < 3 {
|
||||
return dispatcher.ContinueGroups
|
||||
}
|
||||
messageID, err := strconv.Atoi(strSlice[2])
|
||||
messageID, err := strconv.Atoi(strSlice[len(strSlice)-1])
|
||||
if err != nil {
|
||||
common.Log.Errorf("解析消息 ID 失败: %s", err)
|
||||
ctx.Reply(update, ext.ReplyTextString("无法解析消息 ID"), nil)
|
||||
return dispatcher.EndGroups
|
||||
}
|
||||
chatUsername := strSlice[1]
|
||||
linkChat, err := ctx.ResolveUsername(chatUsername)
|
||||
if err != nil {
|
||||
common.Log.Errorf("解析 Chat ID 失败: %s", err)
|
||||
ctx.Reply(update, ext.ReplyTextString("无法解析 Chat ID"), nil)
|
||||
return dispatcher.EndGroups
|
||||
}
|
||||
if linkChat == nil {
|
||||
common.Log.Errorf("无法找到聊天: %s", chatUsername)
|
||||
ctx.Reply(update, ext.ReplyTextString("无法找到聊天"), nil)
|
||||
var linkChatID int64
|
||||
if len(strSlice) == 3 {
|
||||
chatUsername := strSlice[1]
|
||||
linkChat, err := ctx.ResolveUsername(chatUsername)
|
||||
if err != nil {
|
||||
common.Log.Errorf("解析用户名失败: %s", err)
|
||||
ctx.Reply(update, ext.ReplyTextString("解析用户名失败"), nil)
|
||||
return dispatcher.EndGroups
|
||||
}
|
||||
if linkChat == nil {
|
||||
common.Log.Errorf("无法找到聊天: %s", chatUsername)
|
||||
ctx.Reply(update, ext.ReplyTextString("无法找到聊天"), nil)
|
||||
return dispatcher.EndGroups
|
||||
}
|
||||
linkChatID = linkChat.GetID()
|
||||
} else if len(strSlice) == 4 {
|
||||
chatID, err := strconv.Atoi(strSlice[2])
|
||||
if err != nil {
|
||||
common.Log.Errorf("解析 Chat ID 失败: %s", err)
|
||||
ctx.Reply(update, ext.ReplyTextString("解析 Chat ID 失败"), nil)
|
||||
return dispatcher.EndGroups
|
||||
}
|
||||
linkChatID = int64(chatID)
|
||||
} else {
|
||||
ctx.Reply(update, ext.ReplyTextString("无法解析链接"), nil)
|
||||
return dispatcher.EndGroups
|
||||
}
|
||||
|
||||
user, err := dao.GetUserByChatID(update.GetUserChat().GetID())
|
||||
if err != nil {
|
||||
common.Log.Errorf("获取用户失败: %s", err)
|
||||
@@ -65,7 +81,7 @@ func handleLinkMessage(ctx *ext.Context, update *ext.Update) error {
|
||||
return dispatcher.EndGroups
|
||||
}
|
||||
|
||||
file, err := FileFromMessage(ctx, linkChat.GetID(), messageID, "")
|
||||
file, err := FileFromMessage(ctx, linkChatID, messageID, "")
|
||||
if err != nil {
|
||||
common.Log.Errorf("获取文件失败: %s", err)
|
||||
ctx.Reply(update, ext.ReplyTextString("获取文件失败: "+err.Error()), nil)
|
||||
@@ -78,7 +94,7 @@ func handleLinkMessage(ctx *ext.Context, update *ext.Update) error {
|
||||
receivedFile := &dao.ReceivedFile{
|
||||
Processing: false,
|
||||
FileName: file.FileName,
|
||||
ChatID: linkChat.GetID(),
|
||||
ChatID: linkChatID,
|
||||
MessageID: messageID,
|
||||
ReplyMessageID: replied.ID,
|
||||
ReplyChatID: update.GetUserChat().GetID(),
|
||||
@@ -92,7 +108,7 @@ func handleLinkMessage(ctx *ext.Context, update *ext.Update) error {
|
||||
return dispatcher.EndGroups
|
||||
}
|
||||
if !user.Silent || user.DefaultStorage == "" {
|
||||
return ProvideSelectMessage(ctx, update, file.FileName, linkChat.GetID(), messageID, replied.ID)
|
||||
return ProvideSelectMessage(ctx, update, file.FileName, linkChatID, messageID, replied.ID)
|
||||
}
|
||||
return HandleSilentAddTask(ctx, update, user, &types.Task{
|
||||
Ctx: ctx,
|
||||
@@ -100,7 +116,7 @@ func handleLinkMessage(ctx *ext.Context, update *ext.Update) error {
|
||||
File: file,
|
||||
StorageName: user.DefaultStorage,
|
||||
UserID: user.ChatID,
|
||||
FileChatID: linkChat.GetID(),
|
||||
FileChatID: linkChatID,
|
||||
FileMessageID: messageID,
|
||||
ReplyMessageID: replied.ID,
|
||||
ReplyChatID: update.GetUserChat().GetID(),
|
||||
|
||||
@@ -2,6 +2,7 @@ package bot
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/celestix/gotgproto/dispatcher"
|
||||
@@ -9,25 +10,50 @@ import (
|
||||
"github.com/gotd/td/tg"
|
||||
"github.com/krau/SaveAny-Bot/common"
|
||||
"github.com/krau/SaveAny-Bot/dao"
|
||||
"github.com/krau/SaveAny-Bot/queue"
|
||||
"github.com/krau/SaveAny-Bot/storage"
|
||||
"github.com/krau/SaveAny-Bot/types"
|
||||
)
|
||||
|
||||
func sendSaveHelp(ctx *ext.Context, update *ext.Update) error {
|
||||
helpText := `
|
||||
使用方法:
|
||||
|
||||
1. 使用该命令回复要保存的文件, 可选文件名参数.
|
||||
示例:
|
||||
/save custom_file_name.mp4
|
||||
|
||||
2. 设置默认存储后, 发送 /save <频道ID/用户名> <消息ID范围> 来批量保存文件. 遵从存储规则, 若未匹配到任何规则则使用默认存储.
|
||||
示例:
|
||||
/save @moreacg 114-514
|
||||
`
|
||||
ctx.Reply(update, ext.ReplyTextString(helpText), nil)
|
||||
return dispatcher.EndGroups
|
||||
}
|
||||
|
||||
func saveCmd(ctx *ext.Context, update *ext.Update) error {
|
||||
res, ok := update.EffectiveMessage.GetReplyTo()
|
||||
if !ok || res == nil {
|
||||
ctx.Reply(update, ext.ReplyTextString("请回复要保存的文件"), nil)
|
||||
return dispatcher.EndGroups
|
||||
args := strings.Split(update.EffectiveMessage.Text, " ")
|
||||
if len(args) >= 3 {
|
||||
return handleBatchSave(ctx, update, args[1:])
|
||||
}
|
||||
replyHeader, ok := res.(*tg.MessageReplyHeader)
|
||||
if !ok {
|
||||
ctx.Reply(update, ext.ReplyTextString("请回复要保存的文件"), nil)
|
||||
return dispatcher.EndGroups
|
||||
}
|
||||
replyToMsgID, ok := replyHeader.GetReplyToMsgID()
|
||||
if !ok {
|
||||
ctx.Reply(update, ext.ReplyTextString("请回复要保存的文件"), nil)
|
||||
return dispatcher.EndGroups
|
||||
|
||||
replyToMsgID := func() int {
|
||||
res, ok := update.EffectiveMessage.GetReplyTo()
|
||||
if !ok || res == nil {
|
||||
return 0
|
||||
}
|
||||
replyHeader, ok := res.(*tg.MessageReplyHeader)
|
||||
if !ok {
|
||||
return 0
|
||||
}
|
||||
replyToMsgID, ok := replyHeader.GetReplyToMsgID()
|
||||
if !ok {
|
||||
return 0
|
||||
}
|
||||
return replyToMsgID
|
||||
}()
|
||||
if replyToMsgID == 0 {
|
||||
return sendSaveHelp(ctx, update)
|
||||
}
|
||||
|
||||
user, err := dao.GetUserByChatID(update.GetUserChat().GetID())
|
||||
@@ -113,3 +139,125 @@ func saveCmd(ctx *ext.Context, update *ext.Update) error {
|
||||
UserID: user.ChatID,
|
||||
})
|
||||
}
|
||||
|
||||
func handleBatchSave(ctx *ext.Context, update *ext.Update, args []string) error {
|
||||
// args: [0] = @channel, [1] = 114-514
|
||||
chatArg := args[0]
|
||||
var chatID int64
|
||||
var err error
|
||||
msgIdSlice := strings.Split(args[1], "-")
|
||||
if len(msgIdSlice) != 2 {
|
||||
ctx.Reply(update, ext.ReplyTextString("无效的消息ID范围"), nil)
|
||||
return dispatcher.EndGroups
|
||||
}
|
||||
minMsgID, minerr := strconv.ParseInt(msgIdSlice[0], 10, 64)
|
||||
maxMsgID, maxerr := strconv.ParseInt(msgIdSlice[1], 10, 64)
|
||||
if minerr != nil || maxerr != nil {
|
||||
ctx.Reply(update, ext.ReplyTextString("无效的消息ID范围"), nil)
|
||||
return dispatcher.EndGroups
|
||||
}
|
||||
if minMsgID > maxMsgID || minMsgID <= 0 || maxMsgID <= 0 {
|
||||
ctx.Reply(update, ext.ReplyTextString("无效的消息ID范围"), nil)
|
||||
return dispatcher.EndGroups
|
||||
}
|
||||
user, err := dao.GetUserByChatID(update.GetUserChat().GetID())
|
||||
if err != nil {
|
||||
common.Log.Errorf("获取用户失败: %s", err)
|
||||
ctx.Reply(update, ext.ReplyTextString("获取用户失败"), nil)
|
||||
return dispatcher.EndGroups
|
||||
}
|
||||
if user.DefaultStorage == "" {
|
||||
ctx.Reply(update, ext.ReplyTextString("请先设置默认存储"), nil)
|
||||
return dispatcher.EndGroups
|
||||
}
|
||||
storages := storage.GetUserStorages(user.ChatID)
|
||||
if len(storages) == 0 {
|
||||
ctx.Reply(update, ext.ReplyTextString("无可用的存储"), nil)
|
||||
return dispatcher.EndGroups
|
||||
}
|
||||
|
||||
if strings.HasPrefix(chatArg, "@") {
|
||||
chatUsername := strings.TrimPrefix(chatArg, "@")
|
||||
chat, err := ctx.ResolveUsername(chatUsername)
|
||||
if err != nil {
|
||||
common.Log.Errorf("解析频道用户名失败: %s", err)
|
||||
ctx.Reply(update, ext.ReplyTextString("解析频道用户名失败"), nil)
|
||||
return dispatcher.EndGroups
|
||||
}
|
||||
if chat == nil {
|
||||
ctx.Reply(update, ext.ReplyTextString("无法找到聊天"), nil)
|
||||
return dispatcher.EndGroups
|
||||
}
|
||||
chatID = chat.GetID()
|
||||
} else {
|
||||
chatID, err = strconv.ParseInt(chatArg, 10, 64)
|
||||
if err != nil {
|
||||
ctx.Reply(update, ext.ReplyTextString("无效的频道ID或用户名"), nil)
|
||||
return dispatcher.EndGroups
|
||||
}
|
||||
}
|
||||
if chatID == 0 {
|
||||
ctx.Reply(update, ext.ReplyTextString("无效的频道ID或用户名"), nil)
|
||||
return dispatcher.EndGroups
|
||||
}
|
||||
|
||||
replied, err := ctx.Reply(update, ext.ReplyTextString("正在批量保存..."), nil)
|
||||
if err != nil {
|
||||
common.Log.Errorf("回复失败: %s", err)
|
||||
return dispatcher.EndGroups
|
||||
}
|
||||
|
||||
total := maxMsgID - minMsgID + 1
|
||||
successadd := 0
|
||||
failedGetFile := 0
|
||||
failedGetMsg := 0
|
||||
failedSaveDB := 0
|
||||
for i := minMsgID; i <= maxMsgID; i++ {
|
||||
file, err := FileFromMessage(ctx, chatID, int(i), "")
|
||||
if err != nil {
|
||||
common.Log.Errorf("获取文件失败: %s", err)
|
||||
failedGetFile++
|
||||
continue
|
||||
}
|
||||
if file.FileName == "" {
|
||||
message, err := GetTGMessage(ctx, chatID, int(i))
|
||||
if err != nil {
|
||||
common.Log.Errorf("获取消息失败: %s", err)
|
||||
failedGetMsg++
|
||||
continue
|
||||
}
|
||||
file.FileName = GenFileNameFromMessage(*message, file)
|
||||
}
|
||||
receivedFile := &dao.ReceivedFile{
|
||||
Processing: false,
|
||||
FileName: file.FileName,
|
||||
ChatID: chatID,
|
||||
MessageID: int(i),
|
||||
ReplyChatID: update.GetUserChat().GetID(),
|
||||
ReplyMessageID: 0,
|
||||
}
|
||||
if err := dao.SaveReceivedFile(receivedFile); err != nil {
|
||||
common.Log.Errorf("保存接收的文件失败: %s", err)
|
||||
failedSaveDB++
|
||||
continue
|
||||
}
|
||||
task := &types.Task{
|
||||
Ctx: ctx,
|
||||
Status: types.Pending,
|
||||
File: file,
|
||||
StorageName: user.DefaultStorage,
|
||||
FileChatID: chatID,
|
||||
FileMessageID: int(i),
|
||||
UserID: user.ChatID,
|
||||
ReplyMessageID: 0,
|
||||
ReplyChatID: update.GetUserChat().GetID(),
|
||||
}
|
||||
queue.AddTask(task)
|
||||
successadd++
|
||||
}
|
||||
ctx.EditMessage(update.EffectiveChat().GetID(), &tg.MessagesEditMessageRequest{
|
||||
Message: fmt.Sprintf("批量保存完成\n成功添加: %d/%d\n获取文件失败: %d\n获取消息失败: %d\n保存数据库失败: %d", successadd, total, failedGetFile, failedGetMsg, failedSaveDB),
|
||||
ID: replied.ID,
|
||||
})
|
||||
return dispatcher.EndGroups
|
||||
}
|
||||
|
||||
@@ -46,7 +46,7 @@ func worker(queue *queue.TaskQueue, semaphore chan struct{}) {
|
||||
extCtx, ok := task.Ctx.(*ext.Context)
|
||||
if !ok {
|
||||
common.Log.Errorf("Context is not *ext.Context: %T", task.Ctx)
|
||||
} else {
|
||||
} else if task.ReplyMessageID != 0 {
|
||||
extCtx.EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{
|
||||
Message: fmt.Sprintf("文件保存成功\n [%s]: %s", task.StorageName, task.StoragePath),
|
||||
ID: task.ReplyMessageID,
|
||||
@@ -57,7 +57,7 @@ func worker(queue *queue.TaskQueue, semaphore chan struct{}) {
|
||||
extCtx, ok := task.Ctx.(*ext.Context)
|
||||
if !ok {
|
||||
common.Log.Errorf("Context is not *ext.Context: %T", task.Ctx)
|
||||
} else {
|
||||
} else if task.ReplyMessageID != 0 {
|
||||
extCtx.EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{
|
||||
Message: "文件保存失败\n" + task.Error.Error(),
|
||||
ID: task.ReplyMessageID,
|
||||
@@ -68,7 +68,7 @@ func worker(queue *queue.TaskQueue, semaphore chan struct{}) {
|
||||
extCtx, ok := task.Ctx.(*ext.Context)
|
||||
if !ok {
|
||||
common.Log.Errorf("Context is not *ext.Context: %T", task.Ctx)
|
||||
} else {
|
||||
} else if task.ReplyMessageID != 0 {
|
||||
extCtx.EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{
|
||||
Message: "任务已取消",
|
||||
ID: task.ReplyMessageID,
|
||||
|
||||
@@ -63,12 +63,14 @@ func processPendingTask(task *types.Task) error {
|
||||
if config.Cfg.Stream {
|
||||
if !notsupportStream {
|
||||
text, entities := buildProgressMessageEntity(task, 0, task.StartTime, 0)
|
||||
ctx.EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{
|
||||
Message: text,
|
||||
Entities: entities,
|
||||
ID: task.ReplyMessageID,
|
||||
ReplyMarkup: cancelMarkUp,
|
||||
})
|
||||
if task.ReplyMessageID != 0 {
|
||||
ctx.EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{
|
||||
Message: text,
|
||||
Entities: entities,
|
||||
ID: task.ReplyMessageID,
|
||||
ReplyMarkup: cancelMarkUp,
|
||||
})
|
||||
}
|
||||
|
||||
pr, pw := io.Pipe()
|
||||
defer pr.Close()
|
||||
@@ -97,11 +99,14 @@ func processPendingTask(task *types.Task) error {
|
||||
return nil
|
||||
}
|
||||
common.Log.Warnf("存储 %s 不支持流式传输: %s", task.StorageName, notsupportStreamStorage.NotSupportStream())
|
||||
ctx.EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{
|
||||
Message: fmt.Sprintf("存储 %s 不支持流式传输: %s\n正在使用普通下载...", task.StorageName, notsupportStreamStorage.NotSupportStream()),
|
||||
ID: task.ReplyMessageID,
|
||||
ReplyMarkup: cancelMarkUp,
|
||||
})
|
||||
|
||||
if task.ReplyMessageID != 0 {
|
||||
ctx.EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{
|
||||
Message: fmt.Sprintf("存储 %s 不支持流式传输: %s\n正在使用普通下载...", task.StorageName, notsupportStreamStorage.NotSupportStream()),
|
||||
ID: task.ReplyMessageID,
|
||||
ReplyMarkup: cancelMarkUp,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
cacheDestPath := filepath.Join(config.Cfg.Temp.BasePath, task.FileName())
|
||||
@@ -114,12 +119,14 @@ func processPendingTask(task *types.Task) error {
|
||||
}
|
||||
|
||||
text, entities := buildProgressMessageEntity(task, 0, task.StartTime, 0)
|
||||
ctx.EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{
|
||||
Message: text,
|
||||
Entities: entities,
|
||||
ID: task.ReplyMessageID,
|
||||
ReplyMarkup: cancelMarkUp,
|
||||
})
|
||||
if task.ReplyMessageID != 0 {
|
||||
ctx.EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{
|
||||
Message: text,
|
||||
Entities: entities,
|
||||
ID: task.ReplyMessageID,
|
||||
ReplyMarkup: cancelMarkUp,
|
||||
})
|
||||
}
|
||||
|
||||
progressCallback := buildProgressCallback(ctx, task, getProgressUpdateCount(task.File.FileSize))
|
||||
dest, err := NewTaskLocalFile(cacheDestPath, task.File.FileSize, progressCallback)
|
||||
@@ -137,11 +144,12 @@ func processPendingTask(task *types.Task) error {
|
||||
fixTaskFileExt(task, cacheDestPath)
|
||||
|
||||
common.Log.Infof("Downloaded file: %s", cacheDestPath)
|
||||
ctx.EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{
|
||||
Message: fmt.Sprintf("下载完成: %s\n正在转存文件...", task.FileName()),
|
||||
ID: task.ReplyMessageID,
|
||||
})
|
||||
|
||||
if task.ReplyMessageID != 0 {
|
||||
ctx.EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{
|
||||
Message: fmt.Sprintf("下载完成: %s\n正在转存文件...", task.FileName()),
|
||||
ID: task.ReplyMessageID,
|
||||
})
|
||||
}
|
||||
return saveFileWithRetry(cancelCtx, task.StoragePath, taskStorage, cacheDestPath)
|
||||
}
|
||||
|
||||
@@ -169,12 +177,14 @@ func processTelegraph(extCtx *ext.Context, cancelCtx context.Context, task *type
|
||||
common.Log.Errorf("Failed to build entities: %s", err)
|
||||
}
|
||||
|
||||
extCtx.EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{
|
||||
Message: text,
|
||||
Entities: entities,
|
||||
ID: task.ReplyMessageID,
|
||||
ReplyMarkup: getCancelTaskMarkup(task),
|
||||
})
|
||||
if task.ReplyMessageID != 0 {
|
||||
extCtx.EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{
|
||||
Message: text,
|
||||
Entities: entities,
|
||||
ID: task.ReplyMessageID,
|
||||
ReplyMarkup: getCancelTaskMarkup(task),
|
||||
})
|
||||
}
|
||||
|
||||
resultCh := make(chan error)
|
||||
go func() {
|
||||
|
||||
@@ -145,6 +145,9 @@ func buildProgressCallback(ctx *ext.Context, task *types.Task, updateCount int)
|
||||
if task.File.FileSize < 1024*1024*50 || progressInt == 0 || progressInt%int(100/updateCount) != 0 {
|
||||
return
|
||||
}
|
||||
if task.ReplyMessageID == 0 {
|
||||
return
|
||||
}
|
||||
text, entities := buildProgressMessageEntity(task, bytesRead, task.StartTime, progress)
|
||||
ctx.EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{
|
||||
Message: text,
|
||||
|
||||
Reference in New Issue
Block a user