Compare commits

..

3 Commits

5 changed files with 172 additions and 30 deletions

View File

@@ -10,6 +10,7 @@ import (
"github.com/celestix/gotgproto/sessionMaker"
"github.com/glebarez/sqlite"
"github.com/gotd/td/telegram/dcs"
"github.com/gotd/td/tg"
"github.com/krau/SaveAny-Bot/config"
"github.com/krau/SaveAny-Bot/logger"
"golang.org/x/net/proxy"
@@ -60,6 +61,24 @@ func Init() {
Resolver: resolver,
},
)
if err != nil {
resultChan <- struct {
client *gotgproto.Client
err error
}{nil, err}
return
}
_, err = client.API().BotsSetBotCommands(ctx, &tg.BotsSetBotCommandsRequest{
Scope: &tg.BotCommandScopeDefault{},
Commands: []tg.BotCommand{
{Command: "start", Description: "开始使用"},
{Command: "help", Description: "显示帮助"},
{Command: "silent", Description: "开启/关闭静默模式"},
{Command: "storage", Description: "设置默认存储端"},
{Command: "save", Description: "保存所回复的文件"},
{Command: "path", Description: "更改保存路径配置"},
},
})
resultChan <- struct {
client *gotgproto.Client
err error

View File

@@ -8,6 +8,7 @@ import (
"github.com/duke-git/lancet/v2/slice"
"github.com/gookit/goutil/maputil"
"github.com/gotd/td/telegram/message/entity"
"github.com/gotd/td/telegram/message/styling"
"github.com/gotd/td/tg"
@@ -30,6 +31,7 @@ func RegisterHandlers(dispatcher dispatcher.Dispatcher) {
dispatcher.AddHandler(handlers.NewCommand("silent", silent))
dispatcher.AddHandler(handlers.NewCommand("storage", setDefaultStorage))
dispatcher.AddHandler(handlers.NewCommand("save", saveCmd))
dispatcher.AddHandler(handlers.NewCommand("path", setPath))
dispatcher.AddHandler(handlers.NewCallbackQuery(filters.CallbackQuery.Prefix("add"), AddToQueue))
dispatcher.AddHandler(handlers.NewMessage(filters.Message.Media, handleFileMessage))
}
@@ -57,13 +59,14 @@ func start(ctx *ext.Context, update *ext.Update) error {
}
const helpText string = `
SaveAny Bot - 转存你的 Telegram 文件
Save Any Bot - 转存你的 Telegram 文件
命令:
/start - 开始使用
/help - 显示帮助
/silent - 静默模式
/storage - 设置默认存储位置
/save [自定义文件名] - 保存文件
/path <存储类型> <路径> - 更改文件保存路径
静默模式: 开启后 Bot 直接保存到收到的文件到默认位置, 不再询问
`
@@ -84,12 +87,7 @@ func silent(ctx *ext.Context, update *ext.Update) error {
logger.L.Errorf("Failed to update user: %s", err)
return dispatcher.EndGroups
}
ctx.Reply(update, ext.ReplyTextString(fmt.Sprintf("已%s静默模式", func() string {
if user.Silent {
return "开启"
}
return "关闭"
}())), nil)
ctx.Reply(update, ext.ReplyTextString(fmt.Sprintf("已%s静默模式", map[bool]string{true: "开启", false: "关闭"}[user.Silent])), nil)
return dispatcher.EndGroups
}
@@ -149,6 +147,11 @@ func saveCmd(ctx *ext.Context, update *ext.Update) error {
}
msg, err := GetTGMessage(ctx, Client, replyToMsgID)
if err != nil {
logger.L.Errorf("Failed to get message: %s", err)
ctx.Reply(update, ext.ReplyTextString("无法获取消息"), nil)
return dispatcher.EndGroups
}
supported, _ := supportedMediaFilter(msg)
if !supported {
@@ -209,9 +212,21 @@ func saveCmd(ctx *ext.Context, update *ext.Update) error {
}
if !user.Silent {
text := "请选择存储位置"
entityBuilder := entity.Builder{}
var entities []tg.MessageEntityClass
text := fmt.Sprintf("文件名: %s\n请选择存储位置", file.FileName)
if err := styling.Perform(&entityBuilder,
styling.Plain("文件名: "),
styling.Code(file.FileName),
styling.Plain("\n请选择存储位置"),
); err != nil {
logger.L.Errorf("Failed to build entity: %s", err)
} else {
text, entities = entityBuilder.Complete()
}
_, err = ctx.EditMessage(update.EffectiveChat().GetID(), &tg.MessagesEditMessageRequest{
Message: text,
Entities: entities,
ReplyMarkup: getAddTaskMarkup(msg.ID),
ID: replied.ID,
})
@@ -244,6 +259,51 @@ func saveCmd(ctx *ext.Context, update *ext.Update) error {
return dispatcher.EndGroups
}
func setPath(ctx *ext.Context, update *ext.Update) error {
if len(storage.Storages) == 0 {
ctx.Reply(update, ext.ReplyTextString("未配置存储"), nil)
return dispatcher.EndGroups
}
if update.EffectiveMessage == nil {
logger.L.Error("No effective message")
return dispatcher.EndGroups
}
args := strings.Split(update.EffectiveMessage.Text, " ")
if len(args) < 3 {
text := []styling.StyledTextOption{
styling.Plain("请提供存储位置名称和路径, 可用项:"),
}
for name := range storage.Storages {
text = append(text, styling.Plain("\n"))
text = append(text, styling.Code(string(name)))
}
text = append(text, styling.Plain("\n示例: /path local /path/to/save"))
ctx.Reply(update, ext.ReplyTextStyledTextArray(text), nil)
return dispatcher.EndGroups
}
storageName := args[1]
if _, ok := storage.Storages[types.StorageType(storageName)]; !ok {
ctx.Reply(update, ext.ReplyTextString("存储位置不存在"), nil)
return dispatcher.EndGroups
}
path := strings.Join(args[2:], " ")
switch storageName {
case "local":
config.Set("storage.local.base_path", path)
case "webdav":
config.Set("storage.webdav.base_path", path)
case "alist":
config.Set("storage.alist.base_path", path)
}
if err := config.ReloadConfig(); err != nil {
logger.L.Errorf("Failed to reload config: %s", err)
ctx.Reply(update, ext.ReplyTextString("设置失败: "+err.Error()), nil)
return dispatcher.EndGroups
}
ctx.Reply(update, ext.ReplyTextString("设置成功"), nil)
return dispatcher.EndGroups
}
func handleFileMessage(ctx *ext.Context, update *ext.Update) error {
logger.L.Trace("Got media: ", update.EffectiveMessage.Media.TypeName())
supported, err := supportedMediaFilter(update.EffectiveMessage.Message)
@@ -299,9 +359,21 @@ func handleFileMessage(ctx *ext.Context, update *ext.Update) error {
}
if !user.Silent {
text := "请选择存储位置"
entityBuilder := entity.Builder{}
var entities []tg.MessageEntityClass
text := fmt.Sprintf("文件名: %s\n请选择存储位置", file.FileName)
if err := styling.Perform(&entityBuilder,
styling.Plain("文件名: "),
styling.Code(file.FileName),
styling.Plain("\n请选择存储位置"),
); err != nil {
logger.L.Errorf("Failed to build entity: %s", err)
} else {
text, entities = entityBuilder.Complete()
}
_, err = ctx.EditMessage(update.EffectiveChat().GetID(), &tg.MessagesEditMessageRequest{
Message: text,
Entities: entities,
ReplyMarkup: getAddTaskMarkup(update.EffectiveMessage.ID),
ID: msg.ID,
})
@@ -388,9 +460,25 @@ func AddToQueue(ctx *ext.Context, update *ext.Update) error {
ReplyMessageID: record.ReplyMessageID,
MessageID: record.MessageID,
})
entityBuilder := entity.Builder{}
var entities []tg.MessageEntityClass
text := fmt.Sprintf("已添加到任务队列\n文件名: %s\n当前排队任务数: %d", record.FileName, queue.Len())
if err := styling.Perform(&entityBuilder,
styling.Plain("已添加到任务队列\n文件名: "),
styling.Code(record.FileName),
styling.Plain("\n当前排队任务数: "),
styling.Bold(strconv.Itoa(queue.Len())),
); err != nil {
logger.L.Errorf("Failed to build entity: %s", err)
} else {
text, entities = entityBuilder.Complete()
}
ctx.EditMessage(update.EffectiveChat().GetID(), &tg.MessagesEditMessageRequest{
Message: fmt.Sprintf("已添加到队列: %s\n当前排队任务数: %d", record.FileName, queue.Len()),
ID: record.ReplyMessageID,
Message: text,
Entities: entities,
ID: record.ReplyMessageID,
})
return dispatcher.EndGroups
}

View File

@@ -121,3 +121,20 @@ func Init() {
os.Exit(1)
}
}
func Set(key string, value any) {
viper.Set(key, value)
}
func ReloadConfig() error {
if err := viper.WriteConfig(); err != nil {
return err
}
if err := viper.ReadInConfig(); err != nil {
return err
}
if error := viper.Unmarshal(Cfg); error != nil {
return error
}
return nil
}

View File

@@ -57,30 +57,20 @@ func processPendingTask(task *types.Task) error {
if task.File.FileSize < 1024*1024*50 || int(progress)%(100/barTotalCount) != 0 {
return
}
text := fmt.Sprintf("正在处理下载任务\n文件名: %s\n保存路径: %s\n平均速度: %s\n当前进度: [%s] %.2f%%",
task.FileName(),
fmt.Sprintf("[%s]:%s", task.Storage, task.StoragePath),
getSpeed(bytesRead, task.StartTime),
getProgressBar(progress, barTotalCount),
progress,
)
text, entities := buildProgressMessageEntity(task, barTotalCount, bytesRead, task.StartTime, progress)
ctx.EditMessage(task.ChatID, &tg.MessagesEditMessageRequest{
Message: text,
ID: task.ReplyMessageID,
Message: text,
Entities: entities,
ID: task.ReplyMessageID,
})
}
text, entities := buildProgressMessageEntity(task, barTotalCount, 0, task.StartTime, 0)
ctx.EditMessage(task.ChatID, &tg.MessagesEditMessageRequest{
Message: fmt.Sprintf("正在处理下载任务\n文件名: %s\n保存路径: %s\n平均速度: %s\n当前进度: [%s] %.2f%%",
task.FileName(),
fmt.Sprintf("[%s]:%s", task.Storage, task.StoragePath),
"0B/s",
getProgressBar(0, barTotalCount),
0.0,
),
ID: task.ReplyMessageID,
Message: text,
Entities: entities,
ID: task.ReplyMessageID,
})
readCloser, err := NewTelegramReader(task.Ctx, bot.Client, &task.File.Location,
0, task.File.FileSize-1, task.File.FileSize,
progressCallback, task.File.FileSize/100)
@@ -135,7 +125,7 @@ func worker(queue *queue.TaskQueue, semaphore chan struct{}) {
case types.Succeeded:
logger.L.Infof("Task succeeded: %s", task.String())
task.Ctx.(*ext.Context).EditMessage(task.ChatID, &tg.MessagesEditMessageRequest{
Message: "保存成功\n" + task.FileName(),
Message: fmt.Sprintf("文件保存成功\n [%s]: %s", task.Storage, task.StoragePath),
ID: task.ReplyMessageID,
})
case types.Failed:

View File

@@ -5,6 +5,8 @@ import (
"os"
"time"
"github.com/gotd/td/telegram/message/entity"
"github.com/gotd/td/telegram/message/styling"
"github.com/gotd/td/tg"
"github.com/krau/SaveAny-Bot/bot"
"github.com/krau/SaveAny-Bot/common"
@@ -97,3 +99,29 @@ func getSpeed(bytesRead int64, startTime time.Time) string {
speed := float64(bytesRead) / 1024 / 1024 / elapsed.Seconds()
return fmt.Sprintf("%.2fMB/s", speed)
}
func buildProgressMessageEntity(task *types.Task, barTotalCount int, bytesRead int64, startTime time.Time, progress float64) (string, []tg.MessageEntityClass) {
entityBuilder := entity.Builder{}
text := fmt.Sprintf("正在处理下载任务\n文件名: %s\n保存路径: %s\n平均速度: %s\n当前进度: [%s] %.2f%%",
task.FileName(),
fmt.Sprintf("[%s]:%s", task.Storage, task.StoragePath),
getSpeed(bytesRead, startTime),
getProgressBar(progress, barTotalCount),
progress,
)
var entities []tg.MessageEntityClass
if err := styling.Perform(&entityBuilder,
styling.Plain("正在处理下载任务\n文件名: "),
styling.Code(task.FileName()),
styling.Plain("\n保存路径: "),
styling.Code(fmt.Sprintf("[%s]:%s", task.Storage, task.StoragePath)),
styling.Plain("\n平均速度: "),
styling.Bold(getSpeed(bytesRead, task.StartTime)),
styling.Plain("\n当前进度:\n "),
styling.Code(fmt.Sprintf("[%s] %.2f%%", getProgressBar(progress, barTotalCount), progress)),
); err != nil {
logger.L.Errorf("Failed to build entities: %s", err)
return text, entities
}
return entityBuilder.Complete()
}