feat: save file by cmd

This commit is contained in:
krau
2024-11-09 11:34:28 +08:00
parent 454d69c9d4
commit e3cd659eb3
2 changed files with 105 additions and 5 deletions

View File

@@ -130,13 +130,114 @@ func setDefaultStorage(ctx *ext.Context, update *ext.Update) error {
}
func saveCmd(ctx *ext.Context, update *ext.Update) error {
// TODO: Implement save command
res, ok := update.EffectiveMessage.GetReplyTo()
if !ok || res == nil {
ctx.Reply(update, "请回复要保存的文件", nil)
return dispatcher.EndGroups
}
replyHeader, ok := res.(*tg.MessageReplyHeader)
if !ok {
ctx.Reply(update, "请回复要保存的文件", nil)
return dispatcher.EndGroups
}
replyToMsgID, ok := replyHeader.GetReplyToMsgID()
if !ok {
ctx.Reply(update, "请回复要保存的文件", nil)
return dispatcher.EndGroups
}
msg, err := GetTGMessage(ctx, Client, replyToMsgID)
supported, _ := supportedMediaFilter(msg)
if !supported {
ctx.Reply(update, "不支持的消息类型或消息中没有文件", nil)
return dispatcher.EndGroups
}
user, err := dao.GetUserByUserID(update.GetUserChat().GetID())
if err != nil {
logger.L.Errorf("Failed to get user: %s", err)
return dispatcher.EndGroups
}
replied, err := ctx.Reply(update, "正在获取文件信息...", nil)
if err != nil {
logger.L.Errorf("Failed to reply: %s", err)
return dispatcher.EndGroups
}
file, err := FileFromMessage(ctx, Client, update.EffectiveChat().GetID(), msg.ID)
if err != nil {
logger.L.Errorf("Failed to get file from message: %s", err)
ctx.EditMessage(update.EffectiveChat().GetID(), &tg.MessagesEditMessageRequest{
Message: "无法获取文件",
ID: replied.ID,
})
return dispatcher.EndGroups
}
if file.FileName == "" {
ctx.EditMessage(update.EffectiveChat().GetID(), &tg.MessagesEditMessageRequest{
Message: "无法获取文件名",
ID: replied.ID,
})
return dispatcher.EndGroups
}
if err := dao.AddReceivedFile(&types.ReceivedFile{
Processing: false,
FileName: file.FileName,
ChatID: update.EffectiveChat().GetID(),
MessageID: replyToMsgID,
ReplyMessageID: replied.ID,
}); err != nil {
logger.L.Errorf("Failed to add received file: %s", err)
if _, err := ctx.EditMessage(update.EffectiveChat().GetID(), &tg.MessagesEditMessageRequest{
Message: "无法保存文件",
ID: replied.ID,
}); err != nil {
logger.L.Errorf("Failed to edit message: %s", err)
}
return dispatcher.EndGroups
}
if !user.Silent {
text := "请选择存储位置"
_, err = ctx.EditMessage(update.EffectiveChat().GetID(), &tg.MessagesEditMessageRequest{
Message: text,
ReplyMarkup: getAddTaskMarkup(msg.ID),
ID: replied.ID,
})
if err != nil {
logger.L.Errorf("Failed to reply: %s", err)
}
return dispatcher.EndGroups
}
if user.DefaultStorage == "" {
ctx.Reply(update, "请先使用 /storage 设置默认存储位置", nil)
return dispatcher.EndGroups
}
queue.AddTask(types.Task{
Ctx: ctx,
Status: types.Pending,
File: file,
Storage: types.StorageType(user.DefaultStorage),
ChatID: update.EffectiveChat().GetID(),
ReplyMessageID: replied.ID,
MessageID: msg.ID,
})
_, err = ctx.EditMessage(update.EffectiveChat().GetID(), &tg.MessagesEditMessageRequest{
Message: fmt.Sprintf("已添加到队列: %s\n当前排队任务数: %d", file.FileName, queue.Len()),
ID: replied.ID,
})
if err != nil {
logger.L.Errorf("Failed to edit message: %s", err)
}
return dispatcher.EndGroups
}
func handleFileMessage(ctx *ext.Context, update *ext.Update) error {
logger.L.Trace("Got media: ", update.EffectiveMessage.Media.TypeName())
supported, err := supportedMediaFilter(update.EffectiveMessage)
supported, err := supportedMediaFilter(update.EffectiveMessage.Message)
if err != nil {
return err
}
@@ -226,7 +327,7 @@ func handleFileMessage(ctx *ext.Context, update *ext.Update) error {
func AddToQueue(ctx *ext.Context, update *ext.Update) error {
args := strings.Split(string(update.CallbackQuery.Data), " ")
messageID, _ := strconv.Atoi(args[1])
logger.L.Trace("Got add to queue: chatID: %d, messageID: %d, storage: %s", update.EffectiveChat().GetID(), messageID, args[2])
logger.L.Tracef("Got add to queue: chatID: %d, messageID: %d, storage: %s", update.EffectiveChat().GetID(), messageID, args[2])
record, err := dao.GetReceivedFileByChatAndMessageID(update.EffectiveChat().GetID(), messageID)
if err != nil {
logger.L.Errorf("Failed to get received file: %s", err)

View File

@@ -6,7 +6,6 @@ import (
"github.com/celestix/gotgproto"
"github.com/celestix/gotgproto/dispatcher"
tgTypes "github.com/celestix/gotgproto/types"
"github.com/gotd/td/tg"
"github.com/krau/SaveAny-Bot/common"
"github.com/krau/SaveAny-Bot/logger"
@@ -14,7 +13,7 @@ import (
"github.com/krau/SaveAny-Bot/types"
)
func supportedMediaFilter(m *tgTypes.Message) (bool, error) {
func supportedMediaFilter(m *tg.Message) (bool, error) {
if not := m.Media == nil; not {
return false, dispatcher.EndGroups
}