feat: add custom file name support for saved files and improve error messages

This commit is contained in:
krau
2025-02-12 12:02:28 +08:00
parent 1701d1ab86
commit 930e838b2e
5 changed files with 60 additions and 27 deletions

View File

@@ -1,6 +1,7 @@
package bot
import (
"errors"
"fmt"
"strconv"
"strings"
@@ -146,6 +147,7 @@ func saveCmd(ctx *ext.Context, update *ext.Update) error {
ctx.Reply(update, ext.ReplyTextString("请回复要保存的文件"), nil)
return dispatcher.EndGroups
}
msg, err := GetTGMessage(ctx, Client, replyToMsgID)
supported, _ := supportedMediaFilter(msg)
@@ -165,7 +167,11 @@ func saveCmd(ctx *ext.Context, update *ext.Update) error {
logger.L.Errorf("Failed to reply: %s", err)
return dispatcher.EndGroups
}
file, err := FileFromMessage(ctx, Client, update.EffectiveChat().GetID(), msg.ID)
cmdText := update.EffectiveMessage.Text
customFileName := strings.TrimSpace(strings.TrimPrefix(cmdText, "/save"))
file, err := FileFromMessage(ctx, Client, update.EffectiveChat().GetID(), msg.ID, customFileName)
if err != nil {
logger.L.Errorf("Failed to get file from message: %s", err)
ctx.EditMessage(update.EffectiveChat().GetID(), &tg.MessagesEditMessageRequest{
@@ -183,16 +189,18 @@ func saveCmd(ctx *ext.Context, update *ext.Update) error {
return dispatcher.EndGroups
}
if err := dao.AddReceivedFile(&types.ReceivedFile{
receivedFile := &types.ReceivedFile{
Processing: false,
FileName: file.FileName,
ChatID: update.EffectiveChat().GetID(),
MessageID: replyToMsgID,
ReplyMessageID: replied.ID,
}); err != nil {
logger.L.Errorf("Failed to add received file: %s", err)
}
if err := dao.SaveReceivedFile(receivedFile); err != nil {
logger.L.Errorf("Failed to save received file: %s", err)
if _, err := ctx.EditMessage(update.EffectiveChat().GetID(), &tg.MessagesEditMessageRequest{
Message: "无法保存文件",
Message: fmt.Sprintf("Failed to save received file: %s", err),
ID: replied.ID,
}); err != nil {
logger.L.Errorf("Failed to edit message: %s", err)
@@ -258,10 +266,14 @@ func handleFileMessage(ctx *ext.Context, update *ext.Update) error {
return dispatcher.EndGroups
}
media := update.EffectiveMessage.Media
file, err := FileFromMedia(media)
file, err := FileFromMedia(media, "")
if err != nil {
logger.L.Errorf("Failed to get file from media: %s", err)
ctx.Reply(update, ext.ReplyTextString("无法获取文件"), nil)
if errors.Is(err, ErrEmptyFileName) {
ctx.Reply(update, ext.ReplyTextString("无法获取文件名, 请使用 /save <自定义文件名> 回复此文件"), nil)
} else {
ctx.Reply(update, ext.ReplyTextString(fmt.Sprintf("获取文件失败: %s", err)), nil)
}
return dispatcher.EndGroups
}
if file.FileName == "" {
@@ -269,7 +281,7 @@ func handleFileMessage(ctx *ext.Context, update *ext.Update) error {
return dispatcher.EndGroups
}
if err := dao.AddReceivedFile(&types.ReceivedFile{
if err := dao.SaveReceivedFile(&types.ReceivedFile{
Processing: false,
FileName: file.FileName,
ChatID: update.EffectiveChat().GetID(),
@@ -278,12 +290,11 @@ func handleFileMessage(ctx *ext.Context, update *ext.Update) error {
}); err != nil {
logger.L.Errorf("Failed to add received file: %s", err)
if _, err := ctx.EditMessage(update.EffectiveChat().GetID(), &tg.MessagesEditMessageRequest{
Message: "无法保存文件",
Message: fmt.Sprintf("Failed to add received file: %s", err),
ID: msg.ID,
}); err != nil {
logger.L.Errorf("Failed to edit message: %s", err)
}
return dispatcher.EndGroups
}
@@ -351,18 +362,18 @@ func AddToQueue(ctx *ext.Context, update *ext.Update) error {
}
if update.CallbackQuery.MsgID != record.ReplyMessageID {
record.ReplyMessageID = update.CallbackQuery.MsgID
if err := dao.UpdateReceivedFile(record); err != nil {
if err := dao.SaveReceivedFile(record); err != nil {
logger.L.Errorf("Failed to update received file: %s", err)
}
}
file, err := FileFromMessage(ctx, Client, record.ChatID, record.MessageID)
file, err := FileFromMessage(ctx, Client, record.ChatID, record.MessageID, record.FileName)
if err != nil {
logger.L.Errorf("Failed to get file from message: %s", err)
ctx.AnswerCallback(&tg.MessagesSetBotCallbackAnswerRequest{
QueryID: update.CallbackQuery.QueryID,
Alert: true,
Message: "获取消息文件失败",
Message: fmt.Sprintf("获取消息中的文件失败: %s", err),
CacheTime: 5,
})
return dispatcher.EndGroups

View File

@@ -82,14 +82,21 @@ func getAddTaskMarkup(messageID int) *tg.ReplyInlineMarkup {
}
}
func FileFromMedia(media tg.MessageMediaClass) (*types.File, error) {
func FileFromMedia(media tg.MessageMediaClass, customFileName string) (*types.File, error) {
switch media := media.(type) {
case *tg.MessageMediaDocument:
document, ok := media.Document.AsNotEmpty()
if !ok {
return nil, ErrEmptyDocument
}
var fileName string
if customFileName != "" {
return &types.File{
Location: document.AsInputDocumentFileLocation(),
FileSize: document.Size,
FileName: customFileName,
}, nil
}
fileName := ""
for _, attribute := range document.Attributes {
if name, ok := attribute.(*tg.DocumentAttributeFilename); ok {
fileName = name.GetFileName()
@@ -123,17 +130,21 @@ func FileFromMedia(media tg.MessageMediaClass) (*types.File, error) {
location.AccessHash = photo.GetAccessHash()
location.FileReference = photo.GetFileReference()
location.ThumbSize = size.GetType()
fileName := customFileName
if fileName == "" {
fileName = fmt.Sprintf("photo_%s_%d.jpg", time.Now().Format("2006-01-02_15-04-05"), photo.GetID())
}
return &types.File{
Location: location,
FileSize: 0,
FileName: fmt.Sprintf("photo_%s_%d.jpg", time.Now().Format("2006-01-02_15-04-05"), photo.GetID()),
FileName: fileName,
}, nil
}
return nil, fmt.Errorf("unexpected type %T", media)
}
func FileFromMessage(ctx context.Context, client *gotgproto.Client, chatID int64, messageID int) (*types.File, error) {
func FileFromMessage(ctx context.Context, client *gotgproto.Client, chatID int64, messageID int, customFileName string) (*types.File, error) {
key := fmt.Sprintf("file:%d:%d", chatID, messageID)
logger.L.Debugf("Getting file: %s", key)
var cachedFile types.File
@@ -146,7 +157,7 @@ func FileFromMessage(ctx context.Context, client *gotgproto.Client, chatID int64
if err != nil {
return nil, err
}
file, err := FileFromMedia(message.Media)
file, err := FileFromMedia(message.Media, customFileName)
if err != nil {
return nil, err
}

View File

@@ -3,12 +3,14 @@ package dao
import (
"os"
"path/filepath"
"time"
"github.com/glebarez/sqlite"
"github.com/krau/SaveAny-Bot/config"
"github.com/krau/SaveAny-Bot/logger"
"github.com/krau/SaveAny-Bot/types"
"gorm.io/gorm"
glogger "gorm.io/gorm/logger"
)
var db *gorm.DB
@@ -19,7 +21,16 @@ func Init() {
os.Exit(1)
}
var err error
db, err = gorm.Open(sqlite.Open(config.Cfg.DB.Path), &gorm.Config{})
db, err = gorm.Open(sqlite.Open(config.Cfg.DB.Path), &gorm.Config{
Logger: glogger.New(logger.L, glogger.Config{
Colorful: true,
SlowThreshold: time.Second * 5,
LogLevel: glogger.Error,
IgnoreRecordNotFoundError: true,
ParameterizedQueries: true,
}),
PrepareStmt: true,
})
if err != nil {
logger.L.Fatal("Failed to open database: ", err)
os.Exit(1)

View File

@@ -2,8 +2,12 @@ package dao
import "github.com/krau/SaveAny-Bot/types"
func AddReceivedFile(receivedFile *types.ReceivedFile) error {
return db.Create(receivedFile).Error
func SaveReceivedFile(receivedFile *types.ReceivedFile) error {
record, err := GetReceivedFileByChatAndMessageID(receivedFile.ChatID, receivedFile.MessageID)
if err == nil {
receivedFile.ID = record.ID
}
return db.Save(receivedFile).Error
}
func GetReceivedFileByChatAndMessageID(chatID int64, messageID int) (*types.ReceivedFile, error) {
@@ -15,10 +19,6 @@ func GetReceivedFileByChatAndMessageID(chatID int64, messageID int) (*types.Rece
return &receivedFile, nil
}
func UpdateReceivedFile(receivedFile *types.ReceivedFile) error {
return db.Save(receivedFile).Error
}
func DeleteReceivedFile(receivedFile *types.ReceivedFile) error {
return db.Delete(receivedFile).Error
}

View File

@@ -7,8 +7,8 @@ import (
type ReceivedFile struct {
gorm.Model
Processing bool
ChatID int64
MessageID int
ChatID int64 `gorm:"uniqueIndex:idx_chat_id_message_id;not null"`
MessageID int `gorm:"uniqueIndex:idx_chat_id_message_id;not null"`
ReplyMessageID int
FileName string
}