feat: file name staregy

This commit is contained in:
krau
2025-08-23 17:16:51 +08:00
parent 7300e54c40
commit 68e5a51300
10 changed files with 265 additions and 20 deletions

View File

@@ -0,0 +1,103 @@
package handlers
import (
"fmt"
"strings"
"github.com/celestix/gotgproto/dispatcher"
"github.com/celestix/gotgproto/ext"
"github.com/gotd/td/tg"
"github.com/krau/SaveAny-Bot/database"
"github.com/krau/SaveAny-Bot/pkg/enums/fnamest"
"github.com/krau/SaveAny-Bot/pkg/tcbdata"
)
func handleConfigCmd(ctx *ext.Context, update *ext.Update) error {
ctx.Reply(update, ext.ReplyTextString("请选择要配置的选项"), &ext.ReplyOpts{
Markup: &tg.ReplyInlineMarkup{
Rows: []tg.KeyboardButtonRow{
{
Buttons: []tg.KeyboardButtonClass{
&tg.KeyboardButtonCallback{
Text: "文件名策略",
Data: fmt.Appendf(nil, "%s %s", tcbdata.TypeConfig, "fnamest"),
},
},
},
},
},
})
return dispatcher.EndGroups
}
func handleConfigCallback(ctx *ext.Context, update *ext.Update) error {
args := strings.Fields(string(update.CallbackQuery.Data))
invaildDataAnswer := func() error {
ctx.AnswerCallback(&tg.MessagesSetBotCallbackAnswerRequest{
QueryID: update.CallbackQuery.GetQueryID(),
Alert: true,
Message: "无效的回调数据",
CacheTime: 5,
})
return dispatcher.EndGroups
}
if len(args) < 2 {
return invaildDataAnswer()
}
switch args[1] {
case "fnamest":
return handleConfigFnameSTCallback(ctx, update)
default:
return invaildDataAnswer()
}
}
func handleConfigFnameSTCallback(ctx *ext.Context, update *ext.Update) error {
userID := update.CallbackQuery.GetUserID()
user, err := database.GetUserByChatID(ctx, userID)
if err != nil {
return err
}
args := strings.Fields(string(update.CallbackQuery.Data))
if len(args) == 3 {
selected := args[2]
st, err := fnamest.ParseFnameST(selected)
if err != nil {
return err
}
user.FilenameStrategy = st.String()
if err := database.UpdateUser(ctx, user); err != nil {
return err
}
ctx.EditMessage(userID, &tg.MessagesEditMessageRequest{
ID: update.CallbackQuery.GetMsgID(),
Message: fmt.Sprintf("已将文件名策略设置为: %s", fnamest.FnameSTDisplay[st]),
})
return dispatcher.EndGroups
}
opts := fnamest.FnameSTValues()
buttons := make([]tg.KeyboardButtonClass, 0, len(opts))
for _, opt := range opts {
buttons = append(buttons, &tg.KeyboardButtonCallback{
Text: fnamest.FnameSTDisplay[opt],
Data: fmt.Appendf(nil, "%s %s %s", tcbdata.TypeConfig, "fnamest", opt),
})
}
markup := &tg.ReplyInlineMarkup{Rows: []tg.KeyboardButtonRow{
{Buttons: buttons},
}}
currentStStr := user.FilenameStrategy
if currentStStr == "" {
currentStStr = fnamest.Default.String()
}
currentSt, err := fnamest.ParseFnameST(currentStStr)
if err != nil {
currentSt = fnamest.Default
}
ctx.EditMessage(userID, &tg.MessagesEditMessageRequest{
ID: update.CallbackQuery.GetMsgID(),
Message: fmt.Sprintf("请选择文件名策略, 当前策略: %s", fnamest.FnameSTDisplay[currentSt]),
ReplyMarkup: markup,
})
return dispatcher.EndGroups
}

View File

@@ -13,6 +13,8 @@ import (
"github.com/krau/SaveAny-Bot/client/bot/handlers/utils/msgelem"
"github.com/krau/SaveAny-Bot/client/bot/handlers/utils/shortcut"
"github.com/krau/SaveAny-Bot/common/utils/tgutil"
"github.com/krau/SaveAny-Bot/database"
"github.com/krau/SaveAny-Bot/pkg/enums/fnamest"
"github.com/krau/SaveAny-Bot/pkg/tcbdata"
"github.com/krau/SaveAny-Bot/pkg/tfile"
"github.com/krau/SaveAny-Bot/storage"
@@ -26,12 +28,22 @@ func handleMediaMessage(ctx *ext.Context, update *ext.Update) error {
return handleGroupMediaMessage(ctx, update, message, groupID)
}
logger.Debugf("Got media: %s", message.Media.TypeName())
msg, file, err := shortcut.GetFileFromMessageWithReply(ctx, update, message)
userId := update.GetUserChat().GetID()
userDB, err := database.GetUserByChatID(ctx, userId)
if err != nil {
return err
}
userId := update.GetUserChat().GetID()
tfOpts := make([]tfile.TGFileOption, 0)
switch userDB.FilenameStrategy {
case fnamest.Message.String():
tfOpts = append(tfOpts, tfile.WithName(tgutil.GenFileNameFromMessage(*message)))
default:
}
msg, file, err := shortcut.GetFileFromMessageWithReply(ctx, update, message, tfOpts...)
if err != nil {
return err
}
stors := storage.GetUserStorages(ctx, userId)
req, err := msgelem.BuildAddOneSelectStorageMessage(ctx, stors, file, msg.ID)
if err != nil {
@@ -58,7 +70,17 @@ func handleSilentSaveMedia(ctx *ext.Context, update *ext.Update) error {
}
logger.Debugf("Got media: %s", message.Media.TypeName())
userID := update.GetUserChat().GetID()
msg, file, err := shortcut.GetFileFromMessageWithReply(ctx, update, message)
userDB, err := database.GetUserByChatID(ctx, userID)
if err != nil {
return err
}
tfOpts := make([]tfile.TGFileOption, 0)
switch userDB.FilenameStrategy {
case fnamest.Message.String():
tfOpts = append(tfOpts, tfile.WithName(tgutil.GenFileNameFromMessage(*message)))
default:
}
msg, file, err := shortcut.GetFileFromMessageWithReply(ctx, update, message, tfOpts...)
if err != nil {
return err
}

View File

@@ -40,9 +40,11 @@ func Register(disp dispatcher.Dispatcher) {
disp.AddHandler(handlers.NewCommand("watch", handleWatchCmd))
disp.AddHandler(handlers.NewCommand("unwatch", handleUnwatchCmd))
disp.AddHandler(handlers.NewCommand("save", handleSilentMode(handleSaveCmd, handleSilentSaveReplied)))
disp.AddHandler(handlers.NewCommand("config", handleConfigCmd))
disp.AddHandler(handlers.NewCallbackQuery(filters.CallbackQuery.Prefix(tcbdata.TypeAdd), handleAddCallback))
disp.AddHandler(handlers.NewCallbackQuery(filters.CallbackQuery.Prefix(tcbdata.TypeSetDefault), handleSetDefaultCallback))
disp.AddHandler(handlers.NewCallbackQuery(filters.CallbackQuery.Prefix("cancel"), handleCancelCallback))
disp.AddHandler(handlers.NewCallbackQuery(filters.CallbackQuery.Prefix(tcbdata.TypeCancel), handleCancelCallback))
disp.AddHandler(handlers.NewCallbackQuery(filters.CallbackQuery.Prefix(tcbdata.TypeConfig), handleConfigCallback))
linkRegexFilter, err := filters.Message.Regex(re.TgMessageLinkRegexString)
if err != nil {
panic("failed to create regex filter: " + err.Error())

View File

@@ -20,12 +20,14 @@ import (
"github.com/krau/SaveAny-Bot/common/utils/tgutil"
"github.com/krau/SaveAny-Bot/common/utils/tphutil"
"github.com/krau/SaveAny-Bot/config"
"github.com/krau/SaveAny-Bot/database"
"github.com/krau/SaveAny-Bot/pkg/enums/fnamest"
"github.com/krau/SaveAny-Bot/pkg/telegraph"
"github.com/krau/SaveAny-Bot/pkg/tfile"
)
// 获取消息中的文件并回复等待消息, 返回等待消息, 获取到的文件
func GetFileFromMessageWithReply(ctx *ext.Context, update *ext.Update, message *tg.Message, tfileopts ...tfile.TGFileOptions) (replied *types.Message,
func GetFileFromMessageWithReply(ctx *ext.Context, update *ext.Update, message *tg.Message, tfileopts ...tfile.TGFileOption) (replied *types.Message,
file tfile.TGFileMessage, err error,
) {
logger := log.FromContext(ctx)
@@ -40,7 +42,7 @@ func GetFileFromMessageWithReply(ctx *ext.Context, update *ext.Update, message *
logger.Errorf("Failed to reply: %s", err)
return nil, nil, dispatcher.EndGroups
}
options := []tfile.TGFileOptions{
options := []tfile.TGFileOption{
tfile.WithMessage(message),
}
if len(tfileopts) > 0 {
@@ -81,7 +83,12 @@ func GetFilesFromUpdateLinkMessageWithReplyEdit(ctx *ext.Context, update *ext.Up
logger.Errorf("failed to edit message: %s", err)
}
}
user, err := database.GetUserByChatID(ctx, update.GetUserChat().GetID())
if err != nil {
logger.Errorf("failed to get user from db: %s", err)
editReplied("获取用户信息失败: "+err.Error(), nil)
return nil, nil, nil, dispatcher.EndGroups
}
files = make([]tfile.TGFileMessage, 0, len(msgLinks))
addFile := func(client downloader.Client, msg *tg.Message) {
if msg == nil || msg.Media == nil {
@@ -93,7 +100,14 @@ func GetFilesFromUpdateLinkMessageWithReplyEdit(ctx *ext.Context, update *ext.Up
logger.Debugf("message %d has no media", msg.GetID())
return
}
file, err := tfile.FromMediaMessage(media, client, msg, tfile.WithNameIfEmpty(tgutil.GenFileNameFromMessage(*msg)))
var opt tfile.TGFileOption
switch user.FilenameStrategy {
case fnamest.Message.String():
opt = tfile.WithName(tgutil.GenFileNameFromMessage(*msg))
default:
opt = tfile.WithNameIfEmpty(tgutil.GenFileNameFromMessage(*msg))
}
file, err := tfile.FromMediaMessage(media, client, msg, opt)
if err != nil {
logger.Errorf("failed to create file from media: %s", err)
return

View File

@@ -14,7 +14,7 @@ import (
func handleWatchCmd(ctx *ext.Context, update *ext.Update) error {
logger := log.FromContext(ctx)
args := strings.Split(string(update.EffectiveMessage.Text), " ")
args := strings.Split(update.EffectiveMessage.Text, " ")
if len(args) < 2 {
ctx.Reply(update, ext.ReplyTextString(msgelem.WatchHelpText), nil)
return dispatcher.EndGroups
@@ -82,7 +82,7 @@ func handleWatchCmd(ctx *ext.Context, update *ext.Update) error {
func handleUnwatchCmd(ctx *ext.Context, update *ext.Update) error {
logger := log.FromContext(ctx)
args := strings.Split(string(update.EffectiveMessage.Text), " ")
args := strings.Split(update.EffectiveMessage.Text, " ")
if len(args) < 2 {
ctx.Reply(update, ext.ReplyTextString("请提供要取消监听的聊天ID或用户名"), nil)
return dispatcher.EndGroups

View File

@@ -0,0 +1,14 @@
package fnamest
//go:generate go-enum --values --names --noprefix --flag --nocase
// FnameST
/* ENUM(
default, message
) */
type FnameST string
var FnameSTDisplay = map[FnameST]string{
Default: "默认",
Message: "优先从消息生成",
}

View File

@@ -0,0 +1,87 @@
// Code generated by go-enum DO NOT EDIT.
// Version: 0.6.1
// Revision: a6f63bddde05aca4221df9c8e9e6d7d9674b1cb4
// Build Date: 2025-03-18T23:42:14Z
// Built By: goreleaser
package fnamest
import (
"fmt"
"strings"
)
const (
// Default is a FnameST of type default.
Default FnameST = "default"
// Message is a FnameST of type message.
Message FnameST = "message"
)
var ErrInvalidFnameST = fmt.Errorf("not a valid FnameST, try [%s]", strings.Join(_FnameSTNames, ", "))
var _FnameSTNames = []string{
string(Default),
string(Message),
}
// FnameSTNames returns a list of possible string values of FnameST.
func FnameSTNames() []string {
tmp := make([]string, len(_FnameSTNames))
copy(tmp, _FnameSTNames)
return tmp
}
// FnameSTValues returns a list of the values for FnameST
func FnameSTValues() []FnameST {
return []FnameST{
Default,
Message,
}
}
// String implements the Stringer interface.
func (x FnameST) String() string {
return string(x)
}
// IsValid provides a quick way to determine if the typed value is
// part of the allowed enumerated values
func (x FnameST) IsValid() bool {
_, err := ParseFnameST(string(x))
return err == nil
}
var _FnameSTValue = map[string]FnameST{
"default": Default,
"message": Message,
}
// ParseFnameST attempts to convert a string to a FnameST.
func ParseFnameST(name string) (FnameST, error) {
if x, ok := _FnameSTValue[name]; ok {
return x, nil
}
// Case insensitive parse, do a separate lookup to prevent unnecessary cost of lowercasing a string if we don't need to.
if x, ok := _FnameSTValue[strings.ToLower(name)]; ok {
return x, nil
}
return FnameST(""), fmt.Errorf("%s is %w", name, ErrInvalidFnameST)
}
// Set implements the Golang flag.Value interface func.
func (x *FnameST) Set(val string) error {
v, err := ParseFnameST(val)
*x = v
return err
}
// Get implements the Golang flag.Getter interface func.
func (x *FnameST) Get() interface{} {
return *x
}
// Type implements the github.com/spf13/pFlag Value interface.
func (x *FnameST) Type() string {
return "FnameST"
}

View File

@@ -10,6 +10,8 @@ import (
const (
TypeAdd = "add"
TypeSetDefault = "setdefault"
TypeConfig = "config"
TypeCancel = "cancel"
)
// type TaskDataTGFiles struct {

View File

@@ -2,20 +2,21 @@ package tfile
import "github.com/gotd/td/tg"
type TGFileOptions func(*tgFile)
type TGFileOption func(*tgFile)
func WithMessage(msg *tg.Message) TGFileOptions {
func WithMessage(msg *tg.Message) TGFileOption {
return func(f *tgFile) {
f.message = msg
}
}
func WithName(name string) TGFileOptions {
func WithName(name string) TGFileOption {
return func(f *tgFile) {
f.name = name
}
}
func WithNameIfEmpty(name string) TGFileOptions {
func WithNameIfEmpty(name string) TGFileOption {
return func(f *tgFile) {
if f.name == "" {
f.name = name
@@ -23,13 +24,13 @@ func WithNameIfEmpty(name string) TGFileOptions {
}
}
func WithSize(size int64) TGFileOptions {
func WithSize(size int64) TGFileOption {
return func(f *tgFile) {
f.size = size
}
}
func WithSizeIfZero(size int64) TGFileOptions {
func WithSizeIfZero(size int64) TGFileOption {
return func(f *tgFile) {
if f.size == 0 {
f.size = size

View File

@@ -54,7 +54,7 @@ func NewTGFile(
dler downloader.Client,
size int64,
name string,
opts ...TGFileOptions,
opts ...TGFileOption,
) TGFile {
f := &tgFile{
location: location,
@@ -68,7 +68,7 @@ func NewTGFile(
return f
}
func FromMedia(media tg.MessageMediaClass, client downloader.Client, opts ...TGFileOptions) (TGFile, error) {
func FromMedia(media tg.MessageMediaClass, client downloader.Client, opts ...TGFileOption) (TGFile, error) {
switch m := media.(type) {
case *tg.MessageMediaDocument:
document, ok := m.Document.AsNotEmpty()
@@ -125,7 +125,7 @@ func FromMedia(media tg.MessageMediaClass, client downloader.Client, opts ...TGF
return nil, fmt.Errorf("unsupported media type: %T", media)
}
func FromMediaMessage(media tg.MessageMediaClass, client downloader.Client, msg *tg.Message, opts ...TGFileOptions) (TGFileMessage, error) {
func FromMediaMessage(media tg.MessageMediaClass, client downloader.Client, msg *tg.Message, opts ...TGFileOption) (TGFileMessage, error) {
file, err := FromMedia(media, client, opts...)
if err != nil {
return nil, err