diff --git a/.github/workflows/build-release.yml b/.github/workflows/build-release.yml index fb9211e..24846b4 100644 --- a/.github/workflows/build-release.yml +++ b/.github/workflows/build-release.yml @@ -63,9 +63,9 @@ jobs: README.md ldflags: >- -s -w - -X "github.com/krau/SaveAny-Bot/common.Version=${{ env.VERSION }}" - -X "github.com/krau/SaveAny-Bot/common.BuildTime=${{ format(github.event.repository.updated_at, 'yyyy-MM-dd HH:mm:ss') }}" - -X "github.com/krau/SaveAny-Bot/common.GitCommit=${{ github.sha }}" + -X "github.com/krau/SaveAny-Bot/pkg/consts.Version=${{ env.VERSION }}" + -X "github.com/krau/SaveAny-Bot/pkg/consts.BuildTime=${{ format(github.event.repository.updated_at, 'yyyy-MM-dd HH:mm:ss') }}" + -X "github.com/krau/SaveAny-Bot/pkg/consts.GitCommit=${{ github.sha }}" binary_name: saveany-bot env: VERSION: ${{ env.VERSION }} diff --git a/.gitignore b/.gitignore index 27d3162..6b27064 100644 --- a/.gitignore +++ b/.gitignore @@ -3,7 +3,7 @@ logs/ tmp/ data/ downloads/ -cache/ session.* cache.db -.vscode/ \ No newline at end of file +.vscode/ +temp/ \ No newline at end of file diff --git a/bot/handle_add_task.go b/bot/handle_add_task.go deleted file mode 100644 index d34dfb0..0000000 --- a/bot/handle_add_task.go +++ /dev/null @@ -1,216 +0,0 @@ -package bot - -import ( - "errors" - "fmt" - "path" - "strconv" - "strings" - - "github.com/celestix/gotgproto/dispatcher" - "github.com/celestix/gotgproto/ext" - "github.com/duke-git/lancet/v2/slice" - "github.com/gotd/td/telegram/message/entity" - "github.com/gotd/td/telegram/message/styling" - "github.com/gotd/td/tg" - "github.com/krau/SaveAny-Bot/common" - "github.com/krau/SaveAny-Bot/config" - "github.com/krau/SaveAny-Bot/dao" - "github.com/krau/SaveAny-Bot/queue" - "github.com/krau/SaveAny-Bot/types" - "github.com/krau/SaveAny-Bot/userclient" - "gorm.io/gorm" -) - -func AddToQueue(ctx *ext.Context, update *ext.Update) error { - if !slice.Contain(config.Cfg.GetUsersID(), update.CallbackQuery.UserID) { - ctx.AnswerCallback(&tg.MessagesSetBotCallbackAnswerRequest{ - QueryID: update.CallbackQuery.QueryID, - Alert: true, - Message: "你没有权限", - CacheTime: 5, - }) - return dispatcher.EndGroups - } - args := strings.Split(string(update.CallbackQuery.Data), " ") - addToDir := args[0] == "add_to_dir" // 已经选择了路径 - cbDataId, _ := strconv.Atoi(args[1]) - cbData, err := dao.GetCallbackData(uint(cbDataId)) - if err != nil { - common.Log.Errorf("获取回调数据失败: %s", err) - ctx.AnswerCallback(&tg.MessagesSetBotCallbackAnswerRequest{ - QueryID: update.CallbackQuery.QueryID, - Alert: true, - Message: "获取回调数据失败", - CacheTime: 5, - }) - return dispatcher.EndGroups - } - - data := strings.Split(cbData, " ") - fileChatID, _ := strconv.Atoi(data[0]) - fileMessageID, _ := strconv.Atoi(data[1]) - storageName := data[2] - dirIdInt, _ := strconv.Atoi(data[3]) - dirId := uint(dirIdInt) - - user, err := dao.GetUserByChatID(update.CallbackQuery.UserID) - if err != nil { - common.Log.Errorf("获取用户失败: %s", err) - ctx.AnswerCallback(&tg.MessagesSetBotCallbackAnswerRequest{ - QueryID: update.CallbackQuery.QueryID, - Alert: true, - Message: "获取用户失败", - CacheTime: 5, - }) - return dispatcher.EndGroups - } - - if !addToDir { - dirs, err := dao.GetDirsByUserIDAndStorageName(user.ID, storageName) - if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { - common.Log.Errorf("获取路径失败: %s", err) - ctx.AnswerCallback(&tg.MessagesSetBotCallbackAnswerRequest{ - QueryID: update.CallbackQuery.QueryID, - Alert: true, - Message: "获取路径失败", - CacheTime: 5, - }) - return dispatcher.EndGroups - } - if len(dirs) != 0 { - markup, err := getSelectDirMarkup(fileChatID, fileMessageID, storageName, dirs) - if err != nil { - common.Log.Errorf("获取路径失败: %s", err) - ctx.AnswerCallback(&tg.MessagesSetBotCallbackAnswerRequest{ - QueryID: update.CallbackQuery.QueryID, - Alert: true, - Message: "获取路径失败", - CacheTime: 5, - }) - return dispatcher.EndGroups - } - _, err = ctx.EditMessage(update.EffectiveChat().GetID(), &tg.MessagesEditMessageRequest{ - ID: update.CallbackQuery.GetMsgID(), - Message: "请选择要保存到的路径", - ReplyMarkup: markup, - }) - if err != nil { - common.Log.Errorf("编辑消息失败: %s", err) - } - return dispatcher.EndGroups - } - } - - common.Log.Tracef("Got add to queue: chatID: %d, messageID: %d, storage: %s", fileChatID, fileMessageID, storageName) - record, err := dao.GetReceivedFileByChatAndMessageID(int64(fileChatID), fileMessageID) - if err != nil { - common.Log.Errorf("获取记录失败: %s", err) - ctx.AnswerCallback(&tg.MessagesSetBotCallbackAnswerRequest{ - QueryID: update.CallbackQuery.QueryID, - Alert: true, - Message: "查询记录失败", - CacheTime: 5, - }) - return dispatcher.EndGroups - } - if update.CallbackQuery.MsgID != record.ReplyMessageID { - record.ReplyMessageID = update.CallbackQuery.MsgID - if _, err := dao.SaveReceivedFile(record); err != nil { - common.Log.Errorf("更新记录失败: %s", err) - } - } - - var dir *dao.Dir - if addToDir && dirId != 0 { - dir, err = dao.GetDirByID(dirId) - if err != nil { - common.Log.Errorf("获取路径失败: %s", err) - ctx.AnswerCallback(&tg.MessagesSetBotCallbackAnswerRequest{ - QueryID: update.CallbackQuery.QueryID, - Alert: true, - Message: "获取路径失败", - CacheTime: 5, - }) - return dispatcher.EndGroups - } - } - - var task types.Task - if record.IsTelegraph { - task = types.Task{ - Ctx: ctx, - Status: types.Pending, - IsTelegraph: true, - TelegraphURL: record.TelegraphURL, - StorageName: storageName, - FileChatID: record.ChatID, - FileMessageID: record.MessageID, - ReplyMessageID: record.ReplyMessageID, - ReplyChatID: record.ReplyChatID, - UserID: update.GetUserChat().GetID(), - } - if dir != nil { - task.StoragePath = path.Join(dir.Path, record.FileName) - } - } else { - var file *types.File - var err error - if record.UseUserClient && userclient.UC != nil { - uctx := userclient.UC.CreateContext() - file, err = FileFromMessage(uctx, record.ChatID, record.MessageID, record.FileName) - } else { - file, err = FileFromMessage(ctx, record.ChatID, record.MessageID, record.FileName) - } - if err != nil { - common.Log.Errorf("获取消息中的文件失败: %s", err) - ctx.AnswerCallback(&tg.MessagesSetBotCallbackAnswerRequest{ - QueryID: update.CallbackQuery.QueryID, - Alert: true, - Message: fmt.Sprintf("获取消息中的文件失败: %s", err), - CacheTime: 5, - }) - return dispatcher.EndGroups - } - - task = types.Task{ - Ctx: ctx, - Status: types.Pending, - UseUserClient: record.UseUserClient, - FileDBID: record.ID, - File: file, - StorageName: storageName, - FileChatID: record.ChatID, - ReplyMessageID: record.ReplyMessageID, - FileMessageID: record.MessageID, - ReplyChatID: record.ReplyChatID, - UserID: update.GetUserChat().GetID(), - } - if dir != nil { - task.StoragePath = path.Join(dir.Path, file.FileName) - } - } - - queue.AddTask(&task) - - entityBuilder := entity.Builder{} - var entities []tg.MessageEntityClass - text := fmt.Sprintf("已添加到任务队列\n文件名: %s\n当前排队任务数: %d", record.FileName, queue.Len()) - if err := styling.Perform(&entityBuilder, - styling.Plain("已添加到任务队列\n文件名: "), - styling.Code(record.FileName), - styling.Plain("\n当前排队任务数: "), - styling.Bold(strconv.Itoa(queue.Len())), - ); err != nil { - common.Log.Errorf("Failed to build entity: %s", err) - } else { - text, entities = entityBuilder.Complete() - } - - ctx.EditMessage(update.EffectiveChat().GetID(), &tg.MessagesEditMessageRequest{ - Message: text, - Entities: entities, - ID: record.ReplyMessageID, - }) - return dispatcher.EndGroups -} diff --git a/bot/handle_cancel_task.go b/bot/handle_cancel_task.go deleted file mode 100644 index 512df00..0000000 --- a/bot/handle_cancel_task.go +++ /dev/null @@ -1,27 +0,0 @@ -package bot - -import ( - "strings" - - "github.com/celestix/gotgproto/dispatcher" - "github.com/celestix/gotgproto/ext" - "github.com/gotd/td/tg" - "github.com/krau/SaveAny-Bot/queue" -) - -func cancelTask(ctx *ext.Context, update *ext.Update) error { - key := strings.Split(string(update.CallbackQuery.Data), " ")[1] - ok := queue.CancelTask(key) - if ok { - ctx.AnswerCallback(&tg.MessagesSetBotCallbackAnswerRequest{ - QueryID: update.CallbackQuery.QueryID, - Message: "任务已取消", - }) - return dispatcher.EndGroups - } - ctx.AnswerCallback(&tg.MessagesSetBotCallbackAnswerRequest{ - QueryID: update.CallbackQuery.QueryID, - Message: "任务取消失败", - }) - return dispatcher.EndGroups -} diff --git a/bot/handle_dir.go b/bot/handle_dir.go deleted file mode 100644 index 172d984..0000000 --- a/bot/handle_dir.go +++ /dev/null @@ -1,110 +0,0 @@ -package bot - -import ( - "fmt" - "strconv" - "strings" - - "github.com/celestix/gotgproto/dispatcher" - "github.com/celestix/gotgproto/ext" - "github.com/gotd/td/telegram/message/styling" - "github.com/krau/SaveAny-Bot/common" - "github.com/krau/SaveAny-Bot/dao" - "github.com/krau/SaveAny-Bot/storage" -) - -func sendDirHelp(ctx *ext.Context, update *ext.Update, userChatID int64) error { - dirs, err := dao.GetUserDirsByChatID(userChatID) - if err != nil { - common.Log.Errorf("获取用户路径失败: %s", err) - ctx.Reply(update, ext.ReplyTextString("获取用户路径失败"), nil) - return dispatcher.EndGroups - } - ctx.Reply(update, ext.ReplyTextStyledTextArray( - []styling.StyledTextOption{ - styling.Bold("使用方法: /dir <操作> <参数...>"), - styling.Plain("\n\n可用操作:\n"), - styling.Code("add"), - styling.Plain(" <存储名> <路径> - 添加路径\n"), - styling.Code("del"), - styling.Plain(" <路径ID> - 删除路径\n"), - styling.Plain("\n添加路径示例:\n"), - styling.Code("/dir add local1 path/to/dir"), - styling.Plain("\n\n删除路径示例:\n"), - styling.Code("/dir del 3"), - styling.Plain("\n\n当前已添加的路径:\n"), - styling.Blockquote(func() string { - var sb strings.Builder - for _, dir := range dirs { - sb.WriteString(fmt.Sprintf("%d: ", dir.ID)) - sb.WriteString(dir.StorageName) - sb.WriteString(" - ") - sb.WriteString(dir.Path) - sb.WriteString("\n") - } - return sb.String() - }(), true), - }, - ), nil) - return dispatcher.EndGroups -} - -func dirCmd(ctx *ext.Context, update *ext.Update) error { - args := strings.Split(update.EffectiveMessage.Text, " ") - if len(args) < 2 { - return sendDirHelp(ctx, update, update.GetUserChat().GetID()) - } - user, err := dao.GetUserByChatID(update.GetUserChat().GetID()) - if err != nil { - common.Log.Errorf("获取用户失败: %s", err) - ctx.Reply(update, ext.ReplyTextString("获取用户失败"), nil) - return dispatcher.EndGroups - } - switch args[1] { - case "add": - // /dir add local1 path/to/dir - if len(args) < 4 { - return sendDirHelp(ctx, update, update.GetUserChat().GetID()) - } - return addDir(ctx, update, user, args[2], args[3]) - case "del": - // /dir del 3 - if len(args) < 3 { - return sendDirHelp(ctx, update, update.GetUserChat().GetID()) - } - dirID, err := strconv.Atoi(args[2]) - if err != nil { - ctx.Reply(update, ext.ReplyTextString("路径ID无效"), nil) - return dispatcher.EndGroups - } - return delDir(ctx, update, dirID) - default: - ctx.Reply(update, ext.ReplyTextString("未知操作"), nil) - return dispatcher.EndGroups - } -} - -func addDir(ctx *ext.Context, update *ext.Update, user *dao.User, storageName, path string) error { - if _, err := storage.GetStorageByUserIDAndName(user.ChatID, storageName); err != nil { - ctx.Reply(update, ext.ReplyTextString(err.Error()), nil) - return dispatcher.EndGroups - } - - if err := dao.CreateDirForUser(user.ID, storageName, path); err != nil { - common.Log.Errorf("创建路径失败: %s", err) - ctx.Reply(update, ext.ReplyTextString("创建路径失败"), nil) - return dispatcher.EndGroups - } - ctx.Reply(update, ext.ReplyTextString("路径添加成功"), nil) - return dispatcher.EndGroups -} - -func delDir(ctx *ext.Context, update *ext.Update, dirID int) error { - if err := dao.DeleteDirByID(uint(dirID)); err != nil { - common.Log.Errorf("删除路径失败: %s", err) - ctx.Reply(update, ext.ReplyTextString("删除路径失败"), nil) - return dispatcher.EndGroups - } - ctx.Reply(update, ext.ReplyTextString("路径删除成功"), nil) - return dispatcher.EndGroups -} diff --git a/bot/handle_file.go b/bot/handle_file.go deleted file mode 100644 index e56ae55..0000000 --- a/bot/handle_file.go +++ /dev/null @@ -1,86 +0,0 @@ -package bot - -import ( - "fmt" - - "github.com/celestix/gotgproto/dispatcher" - "github.com/celestix/gotgproto/ext" - "github.com/gotd/td/tg" - "github.com/krau/SaveAny-Bot/common" - "github.com/krau/SaveAny-Bot/dao" - "github.com/krau/SaveAny-Bot/types" -) - -func handleFileMessage(ctx *ext.Context, update *ext.Update) error { - common.Log.Trace("Got media: ", update.EffectiveMessage.Media.TypeName()) - supported, err := supportedMediaFilter(update.EffectiveMessage.Message) - if err != nil { - return err - } - if !supported { - return dispatcher.EndGroups - } - - user, err := dao.GetUserByChatID(update.GetUserChat().GetID()) - if err != nil { - common.Log.Errorf("获取用户失败: %s", err) - ctx.Reply(update, ext.ReplyTextString("获取用户失败"), nil) - return dispatcher.EndGroups - } - // storages := storage.GetUserStorages(user.ChatID) - // if len(storages) == 0 { - // ctx.Reply(update, ext.ReplyTextString("无可用的存储"), nil) - // return dispatcher.EndGroups - // } - - msg, err := ctx.Reply(update, ext.ReplyTextString("正在获取文件信息..."), nil) - if err != nil { - common.Log.Errorf("回复失败: %s", err) - return dispatcher.EndGroups - } - media := update.EffectiveMessage.Media - file, err := FileFromMedia(media, "") - if err != nil { - common.Log.Errorf("获取文件失败: %s", err) - ctx.Reply(update, ext.ReplyTextString(fmt.Sprintf("获取文件失败: %s", err)), nil) - return dispatcher.EndGroups - } - if file.FileName == "" { - file.FileName = GenFileNameFromMessage(*update.EffectiveMessage.Message, file) - } - - record, err := dao.SaveReceivedFile(&dao.ReceivedFile{ - Processing: false, - FileName: file.FileName, - ChatID: update.EffectiveChat().GetID(), - MessageID: update.EffectiveMessage.ID, - ReplyMessageID: msg.ID, - ReplyChatID: update.GetUserChat().GetID(), - }) - if err != nil { - common.Log.Errorf("添加接收的文件失败: %s", err) - if _, err := ctx.EditMessage(update.EffectiveChat().GetID(), &tg.MessagesEditMessageRequest{ - Message: fmt.Sprintf("添加接收的文件失败: %s", err), - ID: msg.ID, - }); err != nil { - common.Log.Errorf("编辑消息失败: %s", err) - } - return dispatcher.EndGroups - } - - if !user.Silent || user.DefaultStorage == "" { - return ProvideSelectMessage(ctx, update, file.FileName, update.EffectiveChat().GetID(), update.EffectiveMessage.ID, msg.ID) - } - return HandleSilentAddTask(ctx, update, user, &types.Task{ - Ctx: ctx, - Status: types.Pending, - FileDBID: record.ID, - File: file, - StorageName: user.DefaultStorage, - FileChatID: update.EffectiveChat().GetID(), - ReplyMessageID: msg.ID, - ReplyChatID: update.GetUserChat().GetID(), - FileMessageID: update.EffectiveMessage.ID, - UserID: user.ChatID, - }) -} diff --git a/bot/handle_link.go b/bot/handle_link.go deleted file mode 100644 index b777016..0000000 --- a/bot/handle_link.go +++ /dev/null @@ -1,235 +0,0 @@ -package bot - -import ( - "errors" - "fmt" - "net/url" - "regexp" - "strconv" - "strings" - - "github.com/celestix/gotgproto/dispatcher" - "github.com/celestix/gotgproto/ext" - "github.com/gotd/td/tg" - "github.com/krau/SaveAny-Bot/common" - "github.com/krau/SaveAny-Bot/dao" - "github.com/krau/SaveAny-Bot/types" - "github.com/krau/SaveAny-Bot/userclient" -) - -var ( - linkRegexString = `https?://t\.me/(?:c/\d+|[a-zA-Z0-9_]+)/\d+(?:\?[^\s]*)?` - linkRegex = regexp.MustCompile(linkRegexString) -) - -type parseResult struct { - ChatID int64 - MessageID int - Files []*types.File - UserClient bool -} - -func parseLink(ctx *ext.Context, link string) (*parseResult, error) { - u, err := url.Parse(link) - if err != nil { - return nil, fmt.Errorf("无法解析链接: %s", err) - } - strSlice := strings.Split(u.Path, "/") - if len(strSlice) < 3 { - return nil, fmt.Errorf("链接格式错误: %s", link) - } - messageID, err := strconv.Atoi(strSlice[len(strSlice)-1]) - if err != nil { - return nil, fmt.Errorf("无法解析消息 ID: %s", err) - } - var chatID int64 - if len(strSlice) == 3 { - chatUsername := strSlice[1] - peer := ctx.PeerStorage.GetPeerByUsername(chatUsername) - if peer != nil { - chatID = peer.ID - } else { - linkChat, err := ctx.ResolveUsername(chatUsername) - if err != nil { - return nil, fmt.Errorf("解析用户名失败: %s", err) - } - if linkChat == nil { - return nil, fmt.Errorf("找不到该聊天: %s", chatUsername) - } - chatID = linkChat.GetID() - } - } else if len(strSlice) == 4 { - chatIDInt, err := strconv.Atoi(strSlice[2]) - if err != nil { - return nil, fmt.Errorf("无法解析 Chat ID: %s", err) - } - chatID = int64(chatIDInt) - } else { - return nil, errors.New("链接格式不正确,无法解析 Chat ID") - } - if chatID == 0 || messageID == 0 { - return nil, fmt.Errorf("链接中缺少 Chat ID 或 Message ID: %s", link) - } - msg, _, err := tryFetchMessage(ctx, chatID, messageID) - if err != nil { - return nil, fmt.Errorf("获取消息失败: %s", err) - } - mediaGroup, isGroup := msg.GetGroupedID() - if u.Query().Has("single") || !isGroup || (mediaGroup == 0) || userclient.UC == nil { - file, useUserClient, err := tryFetchFileFromMessage(ctx, chatID, messageID, "") - if err != nil { - return nil, fmt.Errorf("获取文件失败: %s", err) - } - if file.FileName == "" { - file.FileName = GenFileNameFromMessage(*msg, file) - } - return &parseResult{ - ChatID: chatID, - MessageID: messageID, - Files: []*types.File{file}, - UserClient: useUserClient, - }, nil - } - groupMessages, isUserClient, err := tryGetMediaGroup(chatID, messageID, mediaGroup) - if err != nil { - return nil, fmt.Errorf("获取媒体组消息失败: %s", err) - } - var files []*types.File - for _, groupMsg := range groupMessages { - file, err := FileFromMedia(groupMsg.Media, "") - if err != nil { - return nil, fmt.Errorf("获取媒体文件失败: %s", err) - } - if file.FileName == "" { - file.FileName = GenFileNameFromMessage(*groupMsg, file) - } - files = append(files, file) - } - return &parseResult{ - ChatID: chatID, - MessageID: messageID, - Files: files, - UserClient: isUserClient, - }, nil -} - -// use passed ctx client to fetch file from message, -// -// if failed try using userclient -func tryFetchFileFromMessage(ctx *ext.Context, chatID int64, messageID int, fileName string) (*types.File, bool, error) { - file, err := FileFromMessage(ctx, chatID, messageID, fileName) - if err == nil { - return file, false, nil - } - if (strings.Contains(err.Error(), "peer not found") || strings.Contains(err.Error(), "unexpected message type")) && userclient.UC != nil { - common.Log.Warnf("无法获取文件 %d:%d, 尝试使用 userbot: %s", chatID, messageID, err) - uctx := userclient.GetCtx() - peer := uctx.PeerStorage.GetInputPeerById(chatID) - if peer == nil { - return nil, true, fmt.Errorf("failed to get peer for chat %d: %w", chatID, err) - } - msg, err := GetSingleHistoryMessage(uctx, uctx.Raw, peer, messageID) - if err != nil { - return nil, true, err - } - file, err = FileFromMedia(msg.Media, fileName) - if err != nil { - return nil, true, fmt.Errorf("failed to get file from userbot message %d:%d: %w", chatID, messageID, err) - } - return file, true, nil - } - return nil, false, err -} - -func tryGetMediaGroup(chatID int64, messageID int, mediaGroupID int64) ([]*tg.Message, bool, error) { - if userclient.UC != nil { - uctx := userclient.GetCtx() - messages, err := GetMediaGroup(uctx, chatID, messageID, mediaGroupID) - if err != nil { - return nil, true, fmt.Errorf("failed to get media group from userbot: %w", err) - } - return messages, true, nil - } - return nil, false, errors.New("userclient is not available, cannot fetch media group") -} - -func tryFetchMessage(ctx *ext.Context, chatID int64, messageID int) (*tg.Message, bool, error) { - msg, err := GetTGMessage(ctx, chatID, messageID) - if err == nil { - return msg, false, nil - } - if userclient.UC != nil && (strings.Contains(err.Error(), "peer not found") || strings.Contains(err.Error(), "unexpected message type")) { - common.Log.Warnf("无法获取消息 %d:%d, 尝试使用 userbot: %s", chatID, messageID, err) - uctx := userclient.GetCtx() - msg, err := GetTGMessage(uctx, chatID, messageID) - if err == nil { - return msg, true, nil - } - return nil, true, fmt.Errorf("获取消息失败: %w", err) - } - return nil, false, fmt.Errorf("获取消息失败: %s", err) -} - -func handleLinkMessage(ctx *ext.Context, update *ext.Update) error { - common.Log.Trace("Got link message") - link := linkRegex.FindString(update.EffectiveMessage.Text) - if link == "" { - return dispatcher.ContinueGroups - } - result, err := parseLink(ctx, link) - if err != nil { - common.Log.Errorf("解析链接失败: %s", err) - ctx.Reply(update, ext.ReplyTextString("解析链接失败"), nil) - return dispatcher.EndGroups - } - - user, err := dao.GetUserByChatID(update.GetUserChat().GetID()) - if err != nil { - common.Log.Errorf("获取用户失败: %s", err) - ctx.Reply(update, ext.ReplyTextString("获取用户失败"), nil) - return dispatcher.EndGroups - } - - replied, err := ctx.Reply(update, ext.ReplyTextString("正在获取文件..."), nil) - if err != nil { - common.Log.Errorf("回复失败: %s", err) - return dispatcher.EndGroups - } - - // TODO: handle group files - receivedFile := &dao.ReceivedFile{ - Processing: false, - FileName: result.Files[0].FileName, - ChatID: result.ChatID, - MessageID: result.MessageID, - ReplyMessageID: replied.ID, - ReplyChatID: update.GetUserChat().GetID(), - UseUserClient: result.UserClient, - } - record, err := dao.SaveReceivedFile(receivedFile) - if err != nil { - common.Log.Errorf("保存接收的文件失败: %s", err) - ctx.EditMessage(update.EffectiveChat().GetID(), &tg.MessagesEditMessageRequest{ - Message: "无法保存文件: " + err.Error(), - ID: replied.ID, - }) - return dispatcher.EndGroups - } - file := result.Files[0] - if !user.Silent || user.DefaultStorage == "" { - return ProvideSelectMessage(ctx, update, file.FileName, result.ChatID, result.MessageID, replied.ID) - } - return HandleSilentAddTask(ctx, update, user, &types.Task{ - Ctx: ctx, - Status: types.Pending, - FileDBID: record.ID, - UseUserClient: result.UserClient, - File: file, - StorageName: user.DefaultStorage, - UserID: user.ChatID, - FileChatID: result.ChatID, - FileMessageID: result.MessageID, - ReplyMessageID: replied.ID, - ReplyChatID: update.GetUserChat().GetID(), - }) -} diff --git a/bot/handle_rule.go b/bot/handle_rule.go deleted file mode 100644 index d2b8100..0000000 --- a/bot/handle_rule.go +++ /dev/null @@ -1,141 +0,0 @@ -package bot - -import ( - "fmt" - "strconv" - "strings" - - "github.com/celestix/gotgproto/dispatcher" - "github.com/celestix/gotgproto/ext" - "github.com/duke-git/lancet/v2/slice" - "github.com/gotd/td/telegram/message/styling" - "github.com/krau/SaveAny-Bot/common" - "github.com/krau/SaveAny-Bot/dao" - "github.com/krau/SaveAny-Bot/types" -) - -func sendRuleHelp(ctx *ext.Context, update *ext.Update, userChatID int64) error { - user, err := dao.GetUserByChatID(userChatID) - if err != nil { - common.Log.Errorf("获取用户规则失败: %s", err) - ctx.Reply(update, ext.ReplyTextString("获取用户规则失败"), nil) - return dispatcher.EndGroups - } - ctx.Reply(update, ext.ReplyTextStyledTextArray( - []styling.StyledTextOption{ - styling.Bold("使用方法: /rule <操作> <参数...>"), - styling.Bold(fmt.Sprintf("\n当前已%s规则模式", map[bool]string{true: "启用", false: "禁用"}[user.ApplyRule])), - styling.Plain("\n\n可用操作:\n"), - styling.Code("switch"), - styling.Plain(" - 开关规则模式\n"), - styling.Code("add"), - styling.Plain(" <类型> <数据> <存储名> <路径> - 添加规则\n"), - styling.Code("del"), - styling.Plain(" <规则ID> - 删除规则\n"), - styling.Plain("\n当前已添加的规则:\n"), - styling.Blockquote(func() string { - var sb strings.Builder - for _, rule := range user.Rules { - ruleText := fmt.Sprintf("%s %s %s %s", rule.Type, rule.Data, rule.StorageName, rule.DirPath) - sb.WriteString(fmt.Sprintf("%d: %s\n", rule.ID, ruleText)) - } - return sb.String() - }(), true), - }, - ), nil) - return dispatcher.EndGroups -} - -func ruleCmd(ctx *ext.Context, update *ext.Update) error { - args := strings.Split(update.EffectiveMessage.Text, " ") - if len(args) < 2 { - return sendRuleHelp(ctx, update, update.GetUserChat().GetID()) - } - user, err := dao.GetUserByChatID(update.GetUserChat().GetID()) - if err != nil { - common.Log.Errorf("获取用户失败: %s", err) - ctx.Reply(update, ext.ReplyTextString("获取用户失败"), nil) - return dispatcher.EndGroups - } - switch args[1] { - case "switch": - // /rule switch - return switchApplyRule(ctx, update, user) - case "add": - // /rule add - if len(args) < 6 { - return sendRuleHelp(ctx, update, user.ChatID) - } - return addRule(ctx, update, user, args) - case "del": - // /rule del - if len(args) < 3 { - return sendRuleHelp(ctx, update, user.ChatID) - } - ruleID := args[2] - id, err := strconv.Atoi(ruleID) - if err != nil { - ctx.Reply(update, ext.ReplyTextString("无效的规则ID"), nil) - return dispatcher.EndGroups - } - if err := dao.DeleteRule(uint(id)); err != nil { - common.Log.Errorf("删除规则失败: %s", err) - ctx.Reply(update, ext.ReplyTextString("删除规则失败"), nil) - return dispatcher.EndGroups - } - ctx.Reply(update, ext.ReplyTextString("删除规则成功"), nil) - return dispatcher.EndGroups - default: - return sendRuleHelp(ctx, update, user.ChatID) - } -} - -func switchApplyRule(ctx *ext.Context, update *ext.Update, user *dao.User) error { - applyRule := !user.ApplyRule - if err := dao.UpdateUserApplyRule(user.ChatID, applyRule); err != nil { - common.Log.Errorf("更新用户失败: %s", err) - ctx.Reply(update, ext.ReplyTextString("更新用户失败"), nil) - return dispatcher.EndGroups - } - if applyRule { - ctx.Reply(update, ext.ReplyTextString("已启用规则模式"), nil) - } else { - ctx.Reply(update, ext.ReplyTextString("已禁用规则模式"), nil) - } - return dispatcher.EndGroups -} - -func addRule(ctx *ext.Context, update *ext.Update, user *dao.User, args []string) error { - // /rule add - ruleType := args[2] - ruleData := args[3] - storageName := args[4] - dirPath := args[5] - - if !slice.Contain(types.RuleTypes, types.RuleType(ruleType)) { - var ruleTypesStylingArray []styling.StyledTextOption - ruleTypesStylingArray = append(ruleTypesStylingArray, styling.Bold("无效的规则类型, 可用类型:\n")) - for i, ruleType := range types.RuleTypes { - ruleTypesStylingArray = append(ruleTypesStylingArray, styling.Code(string(ruleType))) - if i != len(types.RuleTypes)-1 { - ruleTypesStylingArray = append(ruleTypesStylingArray, styling.Plain(", ")) - } - } - ctx.Reply(update, ext.ReplyTextStyledTextArray(ruleTypesStylingArray), nil) - return dispatcher.EndGroups - } - rule := &dao.Rule{ - Type: ruleType, - Data: ruleData, - StorageName: storageName, - DirPath: dirPath, - UserID: user.ID, - } - if err := dao.CreateRule(rule); err != nil { - common.Log.Errorf("添加规则失败: %s", err) - ctx.Reply(update, ext.ReplyTextString("添加规则失败"), nil) - return dispatcher.EndGroups - } - ctx.Reply(update, ext.ReplyTextString("添加规则成功"), nil) - return dispatcher.EndGroups -} diff --git a/bot/handle_save.go b/bot/handle_save.go deleted file mode 100644 index 7091bb7..0000000 --- a/bot/handle_save.go +++ /dev/null @@ -1,266 +0,0 @@ -package bot - -import ( - "fmt" - "strconv" - "strings" - - "github.com/celestix/gotgproto/dispatcher" - "github.com/celestix/gotgproto/ext" - "github.com/gotd/td/tg" - "github.com/krau/SaveAny-Bot/common" - "github.com/krau/SaveAny-Bot/dao" - "github.com/krau/SaveAny-Bot/queue" - "github.com/krau/SaveAny-Bot/storage" - "github.com/krau/SaveAny-Bot/types" -) - -func sendSaveHelp(ctx *ext.Context, update *ext.Update) error { - helpText := ` -使用方法: - -1. 使用该命令回复要保存的文件, 可选文件名参数. -示例: -/save custom_file_name.mp4 - -2. 设置默认存储后, 发送 /save <频道ID/用户名> <消息ID范围> 来批量保存文件. 遵从存储规则, 若未匹配到任何规则则使用默认存储. -示例: -/save @moreacg 114-514 - ` - ctx.Reply(update, ext.ReplyTextString(helpText), nil) - return dispatcher.EndGroups -} - -func saveCmd(ctx *ext.Context, update *ext.Update) error { - args := strings.Split(update.EffectiveMessage.Text, " ") - if len(args) >= 3 { - return handleBatchSave(ctx, update, args[1:]) - } - - replyToMsgID := func() int { - res, ok := update.EffectiveMessage.GetReplyTo() - if !ok || res == nil { - return 0 - } - replyHeader, ok := res.(*tg.MessageReplyHeader) - if !ok { - return 0 - } - replyToMsgID, ok := replyHeader.GetReplyToMsgID() - if !ok { - return 0 - } - return replyToMsgID - }() - if replyToMsgID == 0 { - return sendSaveHelp(ctx, update) - } - - user, err := dao.GetUserByChatID(update.GetUserChat().GetID()) - if err != nil { - common.Log.Errorf("获取用户失败: %s", err) - ctx.Reply(update, ext.ReplyTextString("获取用户失败"), nil) - return dispatcher.EndGroups - } - - // storages := storage.GetUserStorages(user.ChatID) - // if len(storages) == 0 { - // ctx.Reply(update, ext.ReplyTextString("无可用的存储"), nil) - // return dispatcher.EndGroups - // } - - msg, err := GetTGMessage(ctx, update.EffectiveChat().GetID(), replyToMsgID) - if err != nil { - common.Log.Errorf("获取消息失败: %s", err) - ctx.Reply(update, ext.ReplyTextString("无法获取消息"), nil) - return dispatcher.EndGroups - } - - supported, _ := supportedMediaFilter(msg) - if !supported { - ctx.Reply(update, ext.ReplyTextString("不支持的消息类型或消息中没有文件"), nil) - return dispatcher.EndGroups - } - - replied, err := ctx.Reply(update, ext.ReplyTextString("正在获取文件信息..."), nil) - if err != nil { - common.Log.Errorf("回复失败: %s", err) - return dispatcher.EndGroups - } - - cmdText := update.EffectiveMessage.Text - customFileName := strings.TrimSpace(strings.TrimPrefix(cmdText, "/save")) - - file, err := FileFromMessage(ctx, update.EffectiveChat().GetID(), msg.ID, customFileName) - if err != nil { - common.Log.Errorf("获取文件失败: %s", err) - ctx.EditMessage(update.EffectiveChat().GetID(), &tg.MessagesEditMessageRequest{ - Message: fmt.Sprintf("获取文件失败: %s", err), - ID: replied.ID, - }) - return dispatcher.EndGroups - } - - if file.FileName == "" { - file.FileName = GenFileNameFromMessage(*msg, file) - } - receivedFile := &dao.ReceivedFile{ - Processing: false, - FileName: file.FileName, - ChatID: update.EffectiveChat().GetID(), - MessageID: replyToMsgID, - ReplyMessageID: replied.ID, - ReplyChatID: update.GetUserChat().GetID(), - } - - record, err := dao.SaveReceivedFile(receivedFile) - if err != nil { - common.Log.Errorf("保存接收的文件失败: %s", err) - if _, err := ctx.EditMessage(update.EffectiveChat().GetID(), &tg.MessagesEditMessageRequest{ - Message: fmt.Sprintf("保存接收的文件失败: %s", err), - ID: replied.ID, - }); err != nil { - common.Log.Errorf("编辑消息失败: %s", err) - } - return dispatcher.EndGroups - } - if !user.Silent || user.DefaultStorage == "" { - return ProvideSelectMessage(ctx, update, file.FileName, update.EffectiveChat().GetID(), msg.ID, replied.ID) - } - return HandleSilentAddTask(ctx, update, user, &types.Task{ - Ctx: ctx, - Status: types.Pending, - FileDBID: record.ID, - File: file, - StorageName: user.DefaultStorage, - FileChatID: update.EffectiveChat().GetID(), - ReplyMessageID: replied.ID, - ReplyChatID: update.GetUserChat().GetID(), - FileMessageID: msg.ID, - UserID: user.ChatID, - }) -} - -func handleBatchSave(ctx *ext.Context, update *ext.Update, args []string) error { - // args: [0] = @channel, [1] = 114-514 - chatArg := args[0] - var chatID int64 - var err error - msgIdSlice := strings.Split(args[1], "-") - if len(msgIdSlice) != 2 { - ctx.Reply(update, ext.ReplyTextString("无效的消息ID范围"), nil) - return dispatcher.EndGroups - } - minMsgID, minerr := strconv.ParseInt(msgIdSlice[0], 10, 64) - maxMsgID, maxerr := strconv.ParseInt(msgIdSlice[1], 10, 64) - if minerr != nil || maxerr != nil { - ctx.Reply(update, ext.ReplyTextString("无效的消息ID范围"), nil) - return dispatcher.EndGroups - } - if minMsgID > maxMsgID || minMsgID <= 0 || maxMsgID <= 0 { - ctx.Reply(update, ext.ReplyTextString("无效的消息ID范围"), nil) - return dispatcher.EndGroups - } - user, err := dao.GetUserByChatID(update.GetUserChat().GetID()) - if err != nil { - common.Log.Errorf("获取用户失败: %s", err) - ctx.Reply(update, ext.ReplyTextString("获取用户失败"), nil) - return dispatcher.EndGroups - } - if user.DefaultStorage == "" { - ctx.Reply(update, ext.ReplyTextString("请先设置默认存储"), nil) - return dispatcher.EndGroups - } - storages := storage.GetUserStorages(user.ChatID) - if len(storages) == 0 { - ctx.Reply(update, ext.ReplyTextString("无可用的存储"), nil) - return dispatcher.EndGroups - } - - if strings.HasPrefix(chatArg, "@") { - chatUsername := strings.TrimPrefix(chatArg, "@") - chat, err := ctx.ResolveUsername(chatUsername) - if err != nil { - common.Log.Errorf("解析频道用户名失败: %s", err) - ctx.Reply(update, ext.ReplyTextString("解析频道用户名失败"), nil) - return dispatcher.EndGroups - } - if chat == nil { - ctx.Reply(update, ext.ReplyTextString("无法找到聊天"), nil) - return dispatcher.EndGroups - } - chatID = chat.GetID() - } else { - chatID, err = strconv.ParseInt(chatArg, 10, 64) - if err != nil { - ctx.Reply(update, ext.ReplyTextString("无效的频道ID或用户名"), nil) - return dispatcher.EndGroups - } - } - if chatID == 0 { - ctx.Reply(update, ext.ReplyTextString("无效的频道ID或用户名"), nil) - return dispatcher.EndGroups - } - - replied, err := ctx.Reply(update, ext.ReplyTextString("正在批量保存..."), nil) - if err != nil { - common.Log.Errorf("回复失败: %s", err) - return dispatcher.EndGroups - } - - total := maxMsgID - minMsgID + 1 - successadd := 0 - failedGetFile := 0 - failedGetMsg := 0 - failedSaveDB := 0 - for i := minMsgID; i <= maxMsgID; i++ { - file, err := FileFromMessage(ctx, chatID, int(i), "") - if err != nil { - common.Log.Errorf("获取文件失败: %s", err) - failedGetFile++ - continue - } - if file.FileName == "" { - message, err := GetTGMessage(ctx, chatID, int(i)) - if err != nil { - common.Log.Errorf("获取消息失败: %s", err) - failedGetMsg++ - continue - } - file.FileName = GenFileNameFromMessage(*message, file) - } - receivedFile := &dao.ReceivedFile{ - Processing: false, - FileName: file.FileName, - ChatID: chatID, - MessageID: int(i), - ReplyChatID: update.GetUserChat().GetID(), - ReplyMessageID: 0, - } - record, err := dao.SaveReceivedFile(receivedFile) - if err != nil { - common.Log.Errorf("保存接收的文件失败: %s", err) - failedSaveDB++ - continue - } - task := &types.Task{ - Ctx: ctx, - Status: types.Pending, - FileDBID: record.ID, - File: file, - StorageName: user.DefaultStorage, - FileChatID: chatID, - FileMessageID: int(i), - UserID: user.ChatID, - ReplyMessageID: 0, - ReplyChatID: update.GetUserChat().GetID(), - } - queue.AddTask(task) - successadd++ - } - ctx.EditMessage(update.EffectiveChat().GetID(), &tg.MessagesEditMessageRequest{ - Message: fmt.Sprintf("批量添加任务完成\n成功添加: %d/%d\n获取文件失败: %d\n获取消息失败: %d\n保存数据库失败: %d", successadd, total, failedGetFile, failedGetMsg, failedSaveDB), - ID: replied.ID, - }) - return dispatcher.EndGroups -} diff --git a/bot/handle_send.go b/bot/handle_send.go deleted file mode 100644 index ffb654c..0000000 --- a/bot/handle_send.go +++ /dev/null @@ -1,95 +0,0 @@ -package bot - -import ( - "fmt" - "strconv" - "strings" - - "github.com/celestix/gotgproto/dispatcher" - "github.com/celestix/gotgproto/ext" - tgtypes "github.com/celestix/gotgproto/types" - "github.com/gotd/td/tg" -) - -func copyMediaToChat(ctx *ext.Context, msg *tg.Message, chatID int64) (*tgtypes.Message, error) { - media, ok := msg.GetMedia() - if !ok { - return nil, fmt.Errorf("获取媒体失败") - } - - req := &tg.MessagesSendMediaRequest{ - InvertMedia: msg.InvertMedia, - Message: msg.Message, - } - - switch m := media.(type) { - case *tg.MessageMediaDocument: - document, ok := m.Document.AsNotEmpty() - if !ok { - return nil, ErrEmptyDocument - } - inputMedia := &tg.InputMediaDocument{ - ID: document.AsInput(), - } - inputMedia.SetFlags() - req.Media = inputMedia - - case *tg.MessageMediaPhoto: - photo, ok := m.Photo.AsNotEmpty() - if !ok { - return nil, ErrEmptyPhoto - } - inputMedia := &tg.InputMediaPhoto{ - ID: photo.AsInput(), - } - inputMedia.SetFlags() - req.Media = inputMedia - - default: - return nil, fmt.Errorf("不支持的媒体类型: %T", media) - } - - req.SetEntities(msg.Entities) - req.SetFlags() - - return ctx.SendMedia(chatID, req) -} - -func sendFileToTelegram(ctx *ext.Context, update *ext.Update) error { - args := strings.Split(string(update.CallbackQuery.Data), " ") - if len(args) < 3 { - ctx.AnswerCallback(&tg.MessagesSetBotCallbackAnswerRequest{ - QueryID: update.CallbackQuery.QueryID, - Alert: true, - Message: "参数错误", - CacheTime: 5, - }) - return dispatcher.EndGroups - } - fileChatID, _ := strconv.Atoi(args[1]) - fileMessageID, _ := strconv.Atoi(args[2]) - fileMessage, err := GetTGMessage(ctx, int64(fileChatID), fileMessageID) - if err != nil { - ctx.AnswerCallback(&tg.MessagesSetBotCallbackAnswerRequest{ - QueryID: update.CallbackQuery.QueryID, - Alert: true, - Message: "无法获取文件消息", - CacheTime: 5, - }) - return dispatcher.EndGroups - } - _, err = copyMediaToChat(ctx, fileMessage, update.EffectiveChat().GetID()) - if err != nil { - ctx.AnswerCallback(&tg.MessagesSetBotCallbackAnswerRequest{ - QueryID: update.CallbackQuery.QueryID, - Alert: true, - Message: fmt.Sprintf("发送文件失败: %s", err), - CacheTime: 5, - }) - } else { - ctx.AnswerCallback(&tg.MessagesSetBotCallbackAnswerRequest{ - QueryID: update.CallbackQuery.QueryID, - }) - } - return dispatcher.EndGroups -} diff --git a/bot/handle_silent.go b/bot/handle_silent.go deleted file mode 100644 index eaeb1ee..0000000 --- a/bot/handle_silent.go +++ /dev/null @@ -1,30 +0,0 @@ -package bot - -import ( - "fmt" - - "github.com/celestix/gotgproto/dispatcher" - "github.com/celestix/gotgproto/ext" - "github.com/krau/SaveAny-Bot/common" - "github.com/krau/SaveAny-Bot/dao" -) - -func silent(ctx *ext.Context, update *ext.Update) error { - user, err := dao.GetUserByChatID(update.GetUserChat().GetID()) - if err != nil { - common.Log.Errorf("获取用户失败: %s", err) - return dispatcher.EndGroups - } - if !user.Silent && user.DefaultStorage == "" { - ctx.Reply(update, ext.ReplyTextString("请先使用 /storage 设置默认存储位置"), nil) - return dispatcher.EndGroups - } - user.Silent = !user.Silent - if err := dao.UpdateUser(user); err != nil { - common.Log.Errorf("更新用户失败: %s", err) - ctx.Reply(update, ext.ReplyTextString("更新用户失败"), nil) - return dispatcher.EndGroups - } - ctx.Reply(update, ext.ReplyTextString(fmt.Sprintf("已%s静默模式", map[bool]string{true: "开启", false: "关闭"}[user.Silent])), nil) - return dispatcher.EndGroups -} diff --git a/bot/handle_storage.go b/bot/handle_storage.go deleted file mode 100644 index 13b314a..0000000 --- a/bot/handle_storage.go +++ /dev/null @@ -1,99 +0,0 @@ -package bot - -import ( - "fmt" - "strconv" - "strings" - - "github.com/celestix/gotgproto/dispatcher" - "github.com/celestix/gotgproto/ext" - "github.com/gotd/td/tg" - "github.com/krau/SaveAny-Bot/common" - "github.com/krau/SaveAny-Bot/dao" - "github.com/krau/SaveAny-Bot/storage" -) - -func storageCmd(ctx *ext.Context, update *ext.Update) error { - userChatID := update.GetUserChat().GetID() - storages := storage.GetUserStorages(userChatID) - if len(storages) == 0 { - ctx.Reply(update, ext.ReplyTextString("无可用的存储"), nil) - return dispatcher.EndGroups - } - markup, err := getSetDefaultStorageMarkup(userChatID, storages) - if err != nil { - common.Log.Errorf("Failed to get markup: %s", err) - ctx.Reply(update, ext.ReplyTextString("获取存储位置失败"), nil) - return dispatcher.EndGroups - } - ctx.Reply(update, ext.ReplyTextString("请选择要设为默认的存储位置"), &ext.ReplyOpts{ - Markup: markup, - }) - return dispatcher.EndGroups -} - -func setDefaultStorage(ctx *ext.Context, update *ext.Update) error { - args := strings.Split(string(update.CallbackQuery.Data), " ") - userID, _ := strconv.Atoi(args[1]) - if userID != int(update.CallbackQuery.GetUserID()) { - ctx.AnswerCallback(&tg.MessagesSetBotCallbackAnswerRequest{ - QueryID: update.CallbackQuery.QueryID, - Alert: true, - Message: "你没有权限", - CacheTime: 5, - }) - return dispatcher.EndGroups - } - cbDataId, _ := strconv.Atoi(args[2]) - storageName, err := dao.GetCallbackData(uint(cbDataId)) - if err != nil { - common.Log.Errorf("获取回调数据失败: %s", err) - ctx.AnswerCallback(&tg.MessagesSetBotCallbackAnswerRequest{ - QueryID: update.CallbackQuery.QueryID, - Alert: true, - Message: "获取回调数据失败", - CacheTime: 5, - }) - return dispatcher.EndGroups - } - - selectedStorage, err := storage.GetStorageByName(storageName) - - if err != nil { - common.Log.Errorf("获取指定存储失败: %s", err) - ctx.AnswerCallback(&tg.MessagesSetBotCallbackAnswerRequest{ - QueryID: update.CallbackQuery.QueryID, - Alert: true, - Message: "获取指定存储失败", - CacheTime: 5, - }) - return dispatcher.EndGroups - } - user, err := dao.GetUserByChatID(int64(userID)) - if err != nil { - common.Log.Errorf("Failed to get user: %s", err) - ctx.AnswerCallback(&tg.MessagesSetBotCallbackAnswerRequest{ - QueryID: update.CallbackQuery.QueryID, - Alert: true, - Message: "获取用户失败", - CacheTime: 5, - }) - return dispatcher.EndGroups - } - user.DefaultStorage = storageName - if err := dao.UpdateUser(user); err != nil { - common.Log.Errorf("Failed to update user: %s", err) - ctx.AnswerCallback(&tg.MessagesSetBotCallbackAnswerRequest{ - QueryID: update.CallbackQuery.QueryID, - Alert: true, - Message: "更新用户失败", - CacheTime: 5, - }) - return dispatcher.EndGroups - } - ctx.EditMessage(update.EffectiveChat().GetID(), &tg.MessagesEditMessageRequest{ - Message: fmt.Sprintf("已将 %s (%s) 设为默认存储位置", selectedStorage.Name(), selectedStorage.Type()), - ID: update.CallbackQuery.GetMsgID(), - }) - return dispatcher.EndGroups -} diff --git a/bot/handle_telegraph.go b/bot/handle_telegraph.go deleted file mode 100644 index 9d1e0a1..0000000 --- a/bot/handle_telegraph.go +++ /dev/null @@ -1,114 +0,0 @@ -package bot - -import ( - "fmt" - "net/http" - "net/url" - "regexp" - "strings" - "time" - - "github.com/celestix/gotgproto/dispatcher" - "github.com/celestix/gotgproto/ext" - "github.com/celestix/telegraph-go/v2" - "github.com/gotd/td/tg" - "github.com/krau/SaveAny-Bot/common" - "github.com/krau/SaveAny-Bot/config" - "github.com/krau/SaveAny-Bot/dao" - "github.com/krau/SaveAny-Bot/storage" - "github.com/krau/SaveAny-Bot/types" -) - -var ( - TelegraphClient *telegraph.TelegraphClient - TelegraphUrlRegexString = `https://telegra.ph/.*` - TelegraphUrlRegex = regexp.MustCompile(TelegraphUrlRegexString) -) - -func InitTelegraphClient() { - var httpClient *http.Client - if config.Cfg.Telegram.Proxy.Enable { - proxyUrl, err := url.Parse(config.Cfg.Telegram.Proxy.URL) - if err != nil { - fmt.Println("Error parsing proxy URL:", err) - return - } - proxy := http.ProxyURL(proxyUrl) - httpClient = &http.Client{ - Transport: &http.Transport{ - Proxy: proxy, - }, - Timeout: 30 * time.Second, - } - } else { - httpClient = &http.Client{ - Timeout: 30 * time.Second, - } - } - TelegraphClient = telegraph.GetTelegraphClient(&telegraph.ClientOpt{HttpClient: httpClient}) -} - -func handleTelegraph(ctx *ext.Context, update *ext.Update) error { - common.Log.Trace("Got telegraph link") - tgphUrl := TelegraphUrlRegex.FindString(update.EffectiveMessage.Text) - if tgphUrl == "" { - return dispatcher.ContinueGroups - } - replied, err := ctx.Reply(update, ext.ReplyTextString("正在获取文件..."), nil) - if err != nil { - common.Log.Errorf("回复失败: %s", err) - return dispatcher.EndGroups - } - user, err := dao.GetUserByChatID(update.GetUserChat().GetID()) - if err != nil { - common.Log.Errorf("获取用户失败: %s", err) - ctx.Reply(update, ext.ReplyTextString("获取用户失败"), nil) - return dispatcher.EndGroups - } - storages := storage.GetUserStorages(user.ChatID) - - if len(storages) == 0 { - ctx.Reply(update, ext.ReplyTextString("无可用的存储"), nil) - return dispatcher.EndGroups - } - - tgphPath := strings.Split(tgphUrl, "/")[len(strings.Split(tgphUrl, "/"))-1] - fileName, err := url.PathUnescape(tgphPath) - if err != nil { - common.Log.Errorf("解析 Telegraph 路径失败: %s", err) - fileName = tgphPath - } - - record := &dao.ReceivedFile{ - Processing: false, - FileName: fileName, - ChatID: update.EffectiveChat().GetID(), - MessageID: update.EffectiveMessage.GetID(), - ReplyMessageID: replied.ID, - ReplyChatID: update.EffectiveChat().GetID(), - IsTelegraph: true, - TelegraphURL: tgphUrl, - } - if _, err := dao.SaveReceivedFile(record); err != nil { - common.Log.Errorf("保存接收的文件失败: %s", err) - ctx.EditMessage(update.EffectiveChat().GetID(), &tg.MessagesEditMessageRequest{ - Message: "无法保存文件: " + err.Error(), - ID: replied.ID, - }) - return dispatcher.EndGroups - } - - if !user.Silent || user.DefaultStorage == "" { - return ProvideSelectMessage(ctx, update, fileName, update.EffectiveChat().GetID(), update.EffectiveMessage.GetID(), replied.ID) - } - return HandleSilentAddTask(ctx, update, user, &types.Task{ - Ctx: ctx, - Status: types.Pending, - StorageName: user.DefaultStorage, - UserID: user.ChatID, - ReplyMessageID: replied.ID, - ReplyChatID: update.GetUserChat().GetID(), - IsTelegraph: true, - TelegraphURL: tgphUrl, - }) -} diff --git a/bot/handlers.go b/bot/handlers.go deleted file mode 100644 index d4eef3c..0000000 --- a/bot/handlers.go +++ /dev/null @@ -1,41 +0,0 @@ -package bot - -import ( - "github.com/celestix/gotgproto/dispatcher" - "github.com/celestix/gotgproto/dispatcher/handlers" - "github.com/celestix/gotgproto/dispatcher/handlers/filters" - "github.com/celestix/gotgproto/ext" - "github.com/krau/SaveAny-Bot/common" -) - -func RegisterHandlers(disp dispatcher.Dispatcher) { - disp.AddHandler(handlers.NewMessage(filters.Message.ChatType(filters.ChatTypeChannel), func(ctx *ext.Context, u *ext.Update) error { - return dispatcher.EndGroups - })) - disp.AddHandler(handlers.NewMessage(filters.Message.ChatType(filters.ChatTypeChat), func(ctx *ext.Context, u *ext.Update) error { - return dispatcher.EndGroups - })) - disp.AddHandler(handlers.NewMessage(filters.Message.All, checkPermission)) - disp.AddHandler(handlers.NewCommand("start", start)) - disp.AddHandler(handlers.NewCommand("help", help)) - disp.AddHandler(handlers.NewCommand("silent", silent)) - disp.AddHandler(handlers.NewCommand("storage", storageCmd)) - disp.AddHandler(handlers.NewCommand("save", saveCmd)) - disp.AddHandler(handlers.NewCommand("dir", dirCmd)) - disp.AddHandler(handlers.NewCommand("rule", ruleCmd)) - linkRegexFilter, err := filters.Message.Regex(linkRegexString) - if err != nil { - common.Log.Panicf("创建正则表达式过滤器失败: %s", err) - } - disp.AddHandler(handlers.NewMessage(linkRegexFilter, handleLinkMessage)) - telegraphUrlRegexFilter, err := filters.Message.Regex(TelegraphUrlRegexString) - if err != nil { - common.Log.Panicf("创建 Telegraph URL 正则表达式过滤器失败: %s", err) - } - disp.AddHandler(handlers.NewMessage(telegraphUrlRegexFilter, handleTelegraph)) - disp.AddHandler(handlers.NewCallbackQuery(filters.CallbackQuery.Prefix("add"), AddToQueue)) - disp.AddHandler(handlers.NewCallbackQuery(filters.CallbackQuery.Prefix("set_default"), setDefaultStorage)) - disp.AddHandler(handlers.NewCallbackQuery(filters.CallbackQuery.Prefix("cancel"), cancelTask)) - disp.AddHandler(handlers.NewCallbackQuery(filters.CallbackQuery.Prefix("send_here"), sendFileToTelegram)) - disp.AddHandler(handlers.NewMessage(filters.Message.Media, handleFileMessage)) -} diff --git a/bot/middlewares.go b/bot/middlewares.go deleted file mode 100644 index be2b5da..0000000 --- a/bot/middlewares.go +++ /dev/null @@ -1,37 +0,0 @@ -package bot - -import ( - "time" - - "github.com/celestix/gotgproto/dispatcher" - "github.com/celestix/gotgproto/ext" - "github.com/duke-git/lancet/v2/slice" - "github.com/gotd/contrib/middleware/floodwait" - "github.com/gotd/contrib/middleware/ratelimit" - "github.com/gotd/td/telegram" - "github.com/krau/SaveAny-Bot/config" - "golang.org/x/time/rate" -) - -func FloodWaitMiddleware() []telegram.Middleware { - waiter := floodwait.NewSimpleWaiter().WithMaxRetries(uint(config.Cfg.Telegram.FloodRetry)) - ratelimiter := ratelimit.New(rate.Every(time.Millisecond*100), 5) - return []telegram.Middleware{ - waiter, - ratelimiter, - } -} - -const noPermissionText string = ` -您不在白名单中, 无法使用此 Bot. -您可以部署自己的实例: https://github.com/krau/SaveAny-Bot -` - -func checkPermission(ctx *ext.Context, update *ext.Update) error { - userID := update.GetUserChat().GetID() - if !slice.Contain(config.Cfg.GetUsersID(), userID) { - ctx.Reply(update, ext.ReplyTextString(noPermissionText), nil) - return dispatcher.EndGroups - } - return dispatcher.ContinueGroups -} \ No newline at end of file diff --git a/bot/utils.go b/bot/utils.go deleted file mode 100644 index 3ed3649..0000000 --- a/bot/utils.go +++ /dev/null @@ -1,450 +0,0 @@ -package bot - -import ( - "context" - "errors" - "fmt" - "strconv" - "strings" - "time" - - "github.com/celestix/gotgproto/dispatcher" - "github.com/celestix/gotgproto/ext" - "github.com/gabriel-vasile/mimetype" - "github.com/gotd/td/telegram/message/entity" - "github.com/gotd/td/telegram/message/styling" - "github.com/gotd/td/telegram/query" - "github.com/gotd/td/tg" - "github.com/krau/SaveAny-Bot/common" - "github.com/krau/SaveAny-Bot/dao" - "github.com/krau/SaveAny-Bot/queue" - "github.com/krau/SaveAny-Bot/storage" - "github.com/krau/SaveAny-Bot/types" -) - -var ( - ErrEmptyDocument = errors.New("document is empty") - ErrEmptyPhoto = errors.New("photo is empty") - ErrEmptyPhotoSize = errors.New("photo size is empty") - ErrEmptyPhotoSizes = errors.New("photo size slice is empty") - ErrNoStorages = errors.New("no available storage") - ErrEmptyMessage = errors.New("message is empty") -) - -func supportedMediaFilter(m *tg.Message) (bool, error) { - if not := m.Media == nil; not { - return false, dispatcher.EndGroups - } - switch m.Media.(type) { - case *tg.MessageMediaDocument: - return true, nil - case *tg.MessageMediaPhoto: - return true, nil - default: - return false, nil - } -} - -func getSelectStorageMarkup(userChatID int64, fileChatID, fileMessageID int) (*tg.ReplyInlineMarkup, error) { - user, err := dao.GetUserByChatID(userChatID) - if err != nil { - return nil, fmt.Errorf("failed to get user by chat ID: %d, error: %w", userChatID, err) - } - storages := storage.GetUserStorages(user.ChatID) - // if len(storages) == 0 { - // return nil, ErrNoStorages - // } - - buttons := make([]tg.KeyboardButtonClass, 0) - for _, storage := range storages { - cbData := fmt.Sprintf("%d %d %s 0", fileChatID, fileMessageID, storage.Name()) // 0 for empty dir id - cbDataId, err := dao.CreateCallbackData(cbData) - if err != nil { - return nil, fmt.Errorf("failed to create callback data: %w", err) - } - buttons = append(buttons, &tg.KeyboardButtonCallback{ - Text: storage.Name(), - Data: fmt.Appendf(nil, "add %d", cbDataId), - }) - } - markup := &tg.ReplyInlineMarkup{} - for i := 0; i < len(buttons); i += 3 { - row := tg.KeyboardButtonRow{} - row.Buttons = buttons[i:min(i+3, len(buttons))] - markup.Rows = append(markup.Rows, row) - } - markup.Rows = append(markup.Rows, tg.KeyboardButtonRow{ - Buttons: []tg.KeyboardButtonClass{ - &tg.KeyboardButtonCallback{ - Text: "发送到当前聊天", - Data: []byte(fmt.Sprintf("send_here %d %d", fileChatID, fileMessageID)), - }, - }, - }) - return markup, nil -} - -func getSelectDirMarkup(fileChatID, fileMessageID int, storageName string, dirs []dao.Dir) (*tg.ReplyInlineMarkup, error) { - buttons := make([]tg.KeyboardButtonClass, 0) - for _, dir := range dirs { - if dir.ID == 0 || dir.StorageName != storageName { - return nil, fmt.Errorf("unexpected dir: %v", dir) - } - cbDataId, err := dao.CreateCallbackData(fmt.Sprintf("%d %d %s %d", fileChatID, fileMessageID, storageName, dir.ID)) - if err != nil { - return nil, fmt.Errorf("failed to create callback data: %w", err) - } - buttons = append(buttons, &tg.KeyboardButtonCallback{ - Text: dir.Path, - Data: []byte(fmt.Sprintf("add_to_dir %d", cbDataId)), - }) - } - markup := &tg.ReplyInlineMarkup{} - for i := 0; i < len(buttons); i += 3 { - row := tg.KeyboardButtonRow{} - row.Buttons = buttons[i:min(i+3, len(buttons))] - markup.Rows = append(markup.Rows, row) - } - return markup, nil -} - -func getSetDefaultStorageMarkup(userChatID int64, storages []storage.Storage) (*tg.ReplyInlineMarkup, error) { - buttons := make([]tg.KeyboardButtonClass, 0) - for _, storage := range storages { - cbDataId, err := dao.CreateCallbackData(storage.Name()) - if err != nil { - return nil, fmt.Errorf("failed to create callback data: %w", err) - } - buttons = append(buttons, &tg.KeyboardButtonCallback{ - Text: storage.Name(), - Data: []byte(fmt.Sprintf("set_default %d %d", userChatID, cbDataId)), - }) - } - markup := &tg.ReplyInlineMarkup{} - for i := 0; i < len(buttons); i += 3 { - row := tg.KeyboardButtonRow{} - row.Buttons = buttons[i:min(i+3, len(buttons))] - markup.Rows = append(markup.Rows, row) - } - return markup, nil -} - -func FileFromMedia(media tg.MessageMediaClass, customFileName string) (*types.File, error) { - switch media := media.(type) { - case *tg.MessageMediaDocument: - document, ok := media.Document.AsNotEmpty() - if !ok { - return nil, ErrEmptyDocument - } - if customFileName != "" { - return &types.File{ - Location: document.AsInputDocumentFileLocation(), - FileSize: document.Size, - FileName: customFileName, - }, nil - } - fileName := "" - for _, attribute := range document.Attributes { - if name, ok := attribute.(*tg.DocumentAttributeFilename); ok { - fileName = name.GetFileName() - break - } - } - return &types.File{ - Location: document.AsInputDocumentFileLocation(), - FileSize: document.Size, - FileName: fileName, - }, nil - case *tg.MessageMediaPhoto: - photo, ok := media.Photo.AsNotEmpty() - if !ok { - return nil, ErrEmptyPhoto - } - sizes := photo.Sizes - if len(sizes) == 0 { - return nil, ErrEmptyPhotoSizes - } - photoSize := sizes[len(sizes)-1] - size, ok := photoSize.AsNotEmpty() - if !ok { - return nil, ErrEmptyPhotoSize - } - location := new(tg.InputPhotoFileLocation) - location.ID = photo.GetID() - location.AccessHash = photo.GetAccessHash() - location.FileReference = photo.GetFileReference() - location.ThumbSize = size.GetType() - fileName := customFileName - if fileName == "" { - fileName = fmt.Sprintf("photo_%s_%d.jpg", time.Now().Format("2006-01-02_15-04-05"), photo.GetID()) - } - return &types.File{ - Location: location, - FileSize: 0, - FileName: fileName, - }, nil - - } - return nil, fmt.Errorf("unexpected type %T", media) -} - -func FileFromMessage(ctx *ext.Context, chatID int64, messageID int, customFileName string) (*types.File, error) { - key := fmt.Sprintf("file:%d:%d", chatID, messageID) - cachedFile, err := common.CacheGet[*types.File](ctx, key) - if err == nil { - if customFileName != "" { - cachedFile.FileName = customFileName - } - return cachedFile, nil - } - common.Log.Debugf("Getting file: %s", key) - message, err := GetTGMessage(ctx, chatID, messageID) - if err != nil { - return nil, err - } - file, err := FileFromMedia(message.Media, customFileName) - if err != nil { - return nil, err - } - if err := common.CacheSet(ctx, key, file); err != nil { - common.Log.Errorf("Failed to cache file: %s", err) - } - return file, nil -} - -func GetTGMessage(ctx *ext.Context, chatId int64, messageID int) (*tg.Message, error) { - key := fmt.Sprintf("message:%d:%d", chatId, messageID) - cacheMessage, err := common.CacheGet[*tg.Message](ctx, key) - if err == nil { - return cacheMessage, nil - } - common.Log.Debugf("Fetching message: %d:%d", chatId, messageID) - messages, err := ctx.GetMessages(chatId, []tg.InputMessageClass{&tg.InputMessageID{ID: messageID}}) - if err != nil { - return nil, err - } - if len(messages) == 0 { - return nil, ErrEmptyMessage - } - msg := messages[0] - tgMessage, ok := msg.(*tg.Message) - if !ok { - return nil, fmt.Errorf("unexpected message type: %T", msg) - } - if err := common.CacheSet(ctx, key, tgMessage); err != nil { - common.Log.Errorf("Failed to cache message: %s", err) - } - return tgMessage, nil -} - -// Userbot only -// -// https://github.com/iyear/tdl/blob/fbb396da774ba544e527c3ef41c44921ad74ee98/core/util/tutil/tutil.go#L174 -func GetSingleHistoryMessage(ctx context.Context, client *tg.Client, peer tg.InputPeerClass, msg int) (*tg.Message, error) { - it := query.Messages(client).GetHistory(peer).OffsetID(msg + 1).BatchSize(1).Iter() - - if !it.Next(ctx) { - return nil, fmt.Errorf("failed to get message %d from %s: %w", msg, peer, it.Err()) - } - - m, ok := it.Value().Msg.(*tg.Message) - if !ok { - return nil, fmt.Errorf("invalid message %d", msg) - } - - if m.GetID() != msg { - return nil, fmt.Errorf("the message %d/%d may be deleted", GetInputPeerID(peer), msg) - } - return m, nil -} - -// Userbot only -func GetHistoryMessages(ctx context.Context, client *tg.Client, peer tg.InputPeerClass, startID, limit int) ([]*tg.Message, error) { - endID := startID + limit - 1 - msgs, err := client.MessagesGetHistory(ctx, &tg.MessagesGetHistoryRequest{ - Peer: peer, - OffsetID: startID, - Limit: limit, - AddOffset: startID - endID, - }) - if err != nil { - return nil, fmt.Errorf("failed to get history messages: %w", err) - } - var msgClass []tg.MessageClass - switch msgsv := msgs.(type) { - case *tg.MessagesMessages: - msgClass = msgsv.GetMessages() - case *tg.MessagesMessagesSlice: - msgClass = msgsv.GetMessages() - case *tg.MessagesChannelMessages: - msgClass = msgsv.GetMessages() - default: - return nil, fmt.Errorf("unexpected messages type: %T", msgs) - } - - messageBatch := make([]*tg.Message, 0, 100) - - for _, msg := range msgClass { - msgNotEmpty, ok := msg.AsNotEmpty() - if !ok { - continue - } - switch msgNotEmptyV := msgNotEmpty.(type) { - case *tg.Message: - messageBatch = append(messageBatch, msgNotEmptyV) - default: - common.Log.Warnf("Unexpected message type: %T, skipping", msgNotEmptyV) - continue - } - } - if len(messageBatch) == 0 { - return nil, fmt.Errorf("no messages found for peer %s with startID %d and limit %d", peer, startID, limit) - } - return messageBatch, nil -} - -func GetMediaGroup(ctx *ext.Context, chatID int64, messageID int, groupID int64) ([]*tg.Message, error) { - peer := ctx.PeerStorage.GetInputPeerById(chatID) - if peer == nil { - return nil, fmt.Errorf("无法获取聊天 %d 的输入 Peer", chatID) - } - messages, err := GetHistoryMessages(ctx, ctx.Raw, peer, messageID-9, 20) - if err != nil { - return nil, fmt.Errorf("获取消息失败: %s", err) - } - var groupMessages []*tg.Message - for _, msg := range messages { - gID, isGroup := msg.GetGroupedID() - if isGroup && gID == groupID { - groupMessages = append(groupMessages, msg) - } - } - if len(groupMessages) == 0 || (len(groupMessages) == 1 && groupMessages[0].ID == messageID) { - return nil, fmt.Errorf("未找到媒体组 %d 中的消息", groupID) - } - return groupMessages, nil -} - -func GetInputPeerID(peer tg.InputPeerClass) int64 { - switch p := peer.(type) { - case *tg.InputPeerUser: - return p.UserID - case *tg.InputPeerChat: - return p.ChatID - case *tg.InputPeerChannel: - return p.ChannelID - } - - return 0 -} - -func ProvideSelectMessage(ctx *ext.Context, update *ext.Update, fileName string, chatID int64, fileMsgID, toEditMsgID int) error { - entityBuilder := entity.Builder{} - var entities []tg.MessageEntityClass - text := fmt.Sprintf("文件名: %s\n请选择存储位置", fileName) - if err := styling.Perform(&entityBuilder, - styling.Plain("文件名: "), - styling.Code(fileName), - styling.Plain("\n请选择存储位置"), - ); err != nil { - common.Log.Errorf("Failed to build entity: %s", err) - } else { - text, entities = entityBuilder.Complete() - } - markup, err := getSelectStorageMarkup(update.GetUserChat().GetID(), int(chatID), fileMsgID) - if errors.Is(err, ErrNoStorages) { - common.Log.Errorf("Failed to get select storage markup: %s", err) - ctx.EditMessage(update.EffectiveChat().GetID(), &tg.MessagesEditMessageRequest{ - Message: "无可用存储", - ID: toEditMsgID, - }) - return dispatcher.EndGroups - } else if err != nil { - common.Log.Errorf("Failed to get select storage markup: %s", err) - ctx.EditMessage(update.EffectiveChat().GetID(), &tg.MessagesEditMessageRequest{ - Message: "无法获取存储", - ID: toEditMsgID, - }) - return dispatcher.EndGroups - } - _, err = ctx.EditMessage(update.EffectiveChat().GetID(), &tg.MessagesEditMessageRequest{ - Message: text, - Entities: entities, - ReplyMarkup: markup, - ID: toEditMsgID, - }) - if err != nil { - common.Log.Errorf("Failed to reply: %s", err) - } - return dispatcher.EndGroups -} - -func HandleSilentAddTask(ctx *ext.Context, update *ext.Update, user *dao.User, task *types.Task) error { - if user.DefaultStorage == "" { - ctx.EditMessage(update.EffectiveChat().GetID(), &tg.MessagesEditMessageRequest{ - Message: "请先使用 /storage 设置默认存储位置", - ID: task.ReplyMessageID, - }) - return dispatcher.EndGroups - } - queue.AddTask(task) - ctx.EditMessage(update.EffectiveChat().GetID(), &tg.MessagesEditMessageRequest{ - Message: fmt.Sprintf("已添加到队列: %s\n当前排队任务数: %d", task.FileName(), queue.Len()), - ID: task.ReplyMessageID, - }) - return dispatcher.EndGroups -} - -func GenFileNameFromMessage(message tg.Message, file *types.File) string { - if file.FileName != "" { - return file.FileName - } - fileName := genFileNameFromMessageText(message, file) - media, ok := message.GetMedia() - if !ok { - return fileName - } - ext, ok := extraMediaExt(media) - if ok { - return fileName + ext - } - return fileName -} - -func genFileNameFromMessageText(message tg.Message, file *types.File) string { - text := strings.TrimSpace(message.GetMessage()) - if text == "" { - return file.Hash() - } - tags := common.ExtractTagsFromText(text) - if len(tags) > 0 { - return fmt.Sprintf("%s_%s", strings.Join(tags, "_"), strconv.Itoa(message.GetID())) - } - // 删除换行和特殊字符 - text = strings.Map(func(r rune) rune { - if r == '\n' || r == '\r' || r == '\t' || r == ' ' { - return '_' - } - return r - }, text) - runes := []rune(text) - return string(runes[:min(128, len(runes))]) -} - -func extraMediaExt(media tg.MessageMediaClass) (string, bool) { - switch media := media.(type) { - case *tg.MessageMediaDocument: - doc, ok := media.Document.AsNotEmpty() - if !ok { - return "", false - } - ext := mimetype.Lookup(doc.MimeType).Extension() - if ext == "" { - return "", false - } - return ext, true - case *tg.MessageMediaPhoto: - return ".jpg", true - } - return "", false -} diff --git a/bot/bot.go b/client/bot/bot.go similarity index 67% rename from bot/bot.go rename to client/bot/bot.go index 71ade9d..d442f7b 100644 --- a/bot/bot.go +++ b/client/bot/bot.go @@ -6,12 +6,16 @@ import ( "time" "github.com/celestix/gotgproto" + "github.com/celestix/gotgproto/dispatcher" + "github.com/celestix/gotgproto/ext" "github.com/celestix/gotgproto/sessionMaker" - "github.com/glebarez/sqlite" + "github.com/charmbracelet/log" "github.com/gotd/td/telegram/dcs" "github.com/gotd/td/tg" - "github.com/krau/SaveAny-Bot/common" + "github.com/krau/SaveAny-Bot/client/bot/handlers" + "github.com/krau/SaveAny-Bot/client/middleware" "github.com/krau/SaveAny-Bot/config" + "github.com/ncruces/go-sqlite3/gormlite" "golang.org/x/net/proxy" ) @@ -25,11 +29,8 @@ func newProxyDialer(proxyUrl string) (proxy.Dialer, error) { return proxy.FromURL(url, proxy.Direct) } -func Init() { - common.Log.Info("初始化 Telegram 客户端...") - ctx, cancel := context.WithTimeout(context.Background(), time.Duration(config.Cfg.Telegram.Timeout)*time.Second) - defer cancel() - go InitTelegraphClient() +func Init(ctx context.Context) { + log.FromContext(ctx).Info("初始化 Bot...") resultChan := make(chan struct { client *gotgproto.Client err error @@ -55,11 +56,17 @@ func Init() { config.Cfg.Telegram.AppHash, gotgproto.ClientTypeBot(config.Cfg.Telegram.Token), &gotgproto.ClientOpts{ - Session: sessionMaker.SqlSession(sqlite.Open(config.Cfg.DB.Session)), + Session: sessionMaker.SqlSession(gormlite.Open(config.Cfg.DB.Session)), DisableCopyright: true, - Middlewares: FloodWaitMiddleware(), + Middlewares: middleware.NewDefaultMiddlewares(ctx, 5*time.Minute), Resolver: resolver, + Context: ctx, MaxRetries: config.Cfg.Telegram.RpcRetry, + AutoFetchReply: true, + ErrorHandler: func(ctx *ext.Context, u *ext.Update, s string) error { + log.FromContext(ctx).Errorf("Unhandled error: %s", s) + return dispatcher.EndGroups + }, }, ) if err != nil { @@ -69,6 +76,9 @@ func Init() { }{nil, err} return } + client.API().BotsSetBotCommands(ctx, &tg.BotsSetBotCommandsRequest{ + Scope: &tg.BotCommandScopeDefault{}, + }) _, err = client.API().BotsSetBotCommands(ctx, &tg.BotsSetBotCommandsRequest{ Scope: &tg.BotCommandScopeDefault{}, Commands: []tg.BotCommand{ @@ -89,13 +99,13 @@ func Init() { select { case <-ctx.Done(): - common.Log.Panic("初始化客户端失败: 超时") + log.FromContext(ctx).Errorf("已取消 Bot 初始化: %s", ctx.Err()) case result := <-resultChan: if result.err != nil { - common.Log.Panicf("初始化客户端失败: %s", result.err) + log.FromContext(ctx).Fatalf("初始化 Bot 失败: %s", result.err) } Client = result.client - RegisterHandlers(Client.Dispatcher) - common.Log.Info("客户端初始化完成") + handlers.Register(Client.Dispatcher) + log.FromContext(ctx).Info("Bot 初始化完成") } } diff --git a/client/bot/handlers/add_task.go b/client/bot/handlers/add_task.go new file mode 100644 index 0000000..40c348c --- /dev/null +++ b/client/bot/handlers/add_task.go @@ -0,0 +1,80 @@ +package handlers + +import ( + "errors" + "fmt" + "strings" + + "github.com/celestix/gotgproto/dispatcher" + "github.com/celestix/gotgproto/ext" + "github.com/charmbracelet/log" + "github.com/gotd/td/tg" + "github.com/krau/SaveAny-Bot/client/bot/handlers/utils/msgelem" + "github.com/krau/SaveAny-Bot/client/bot/handlers/utils/shortcut" + "github.com/krau/SaveAny-Bot/database" + "github.com/krau/SaveAny-Bot/pkg/enums/tasktype" + "github.com/krau/SaveAny-Bot/pkg/tcbdata" + "github.com/krau/SaveAny-Bot/storage" + "gorm.io/gorm" +) + +func handleAddCallback(ctx *ext.Context, update *ext.Update) error { + dataid := strings.Split(string(update.CallbackQuery.Data), " ")[1] + data, err := shortcut.GetCallbackDataWithAnswer[tcbdata.Add](ctx, update, dataid) + if err != nil { + return err + } + queryID := update.CallbackQuery.GetQueryID() + msgID := update.CallbackQuery.GetMsgID() + userID := update.CallbackQuery.GetUserID() + + selectedStorage, err := storage.GetStorageByUserIDAndName(ctx, userID, data.SelectedStorName) + if err != nil { + log.FromContext(ctx).Errorf("Failed to get storage: %s", err) + ctx.AnswerCallback(msgelem.AlertCallbackAnswer(queryID, "存储获取失败: "+err.Error())) + return dispatcher.EndGroups + } + dirs, err := database.GetDirsByUserChatIDAndStorageName(ctx, userID, data.SelectedStorName) + if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { + return fmt.Errorf("获取用户目录失败: %w", err) + } + + if !data.SettedDir && len(dirs) != 0 { + // ask for directory selection + markup, err := msgelem.BuildSetDirKeyboard(dirs, dataid) + if err != nil { + log.FromContext(ctx).Errorf("Failed to build directory keyboard: %s", err) + ctx.AnswerCallback(msgelem.AlertCallbackAnswer(queryID, "目录键盘构建失败: "+err.Error())) + return dispatcher.EndGroups + } + ctx.EditMessage(userID, &tg.MessagesEditMessageRequest{ + ID: update.CallbackQuery.GetMsgID(), + Message: "请选择要存储到的目录", + ReplyMarkup: markup, + }) + return dispatcher.EndGroups + } + + dirPath := "" + if data.DirID != 0 { + dir, err := database.GetDirByID(ctx, data.DirID) + if err != nil { + ctx.AnswerCallback(msgelem.AlertCallbackAnswer(queryID, "获取目录失败: "+err.Error())) + return dispatcher.EndGroups + } + dirPath = dir.Path + } + + switch data.TaskType { + case tasktype.TaskTypeTgfiles: + if data.AsBatch { + return shortcut.CreateAndAddBatchTGFileTaskWithEdit(ctx, userID, selectedStorage, dirPath, data.Files, msgID) + } + return shortcut.CreateAndAddTGFileTaskWithEdit(ctx, userID, selectedStorage, dirPath, data.Files[0], msgID) + case tasktype.TaskTypeTphpics: + return shortcut.CreateAndAddTphTaskWithEdit(ctx, userID, data.TphPageNode, data.TphDirPath, data.TphPics, selectedStorage, msgID) + default: + log.FromContext(ctx).Errorf("Unsupported task type: %s", data.TaskType) + } + return dispatcher.EndGroups +} diff --git a/client/bot/handlers/cancel_task.go b/client/bot/handlers/cancel_task.go new file mode 100644 index 0000000..ecb5fdc --- /dev/null +++ b/client/bot/handlers/cancel_task.go @@ -0,0 +1,28 @@ +package handlers + +import ( + "strings" + + "github.com/celestix/gotgproto/dispatcher" + "github.com/celestix/gotgproto/ext" + "github.com/charmbracelet/log" + "github.com/gotd/td/tg" + "github.com/krau/SaveAny-Bot/client/bot/handlers/utils/msgelem" + "github.com/krau/SaveAny-Bot/core" +) + +func handleCancelCallback(ctx *ext.Context, update *ext.Update) error { + taskid := strings.Split(string(update.CallbackQuery.Data), " ")[1] + if err := core.CancelTask(ctx, taskid); err != nil { + log.FromContext(ctx).Errorf("error cancelling task %s: %v", taskid, err) + ctx.AnswerCallback(msgelem.AlertCallbackAnswer(update.CallbackQuery.GetQueryID(), "取消任务失败: "+err.Error())) + return dispatcher.EndGroups + } + + ctx.EditMessage(update.CallbackQuery.GetUserID(), &tg.MessagesEditMessageRequest{ + ID: update.CallbackQuery.GetMsgID(), + Message: "正在取消任务...", + }) + + return dispatcher.EndGroups +} diff --git a/client/bot/handlers/dir.go b/client/bot/handlers/dir.go new file mode 100644 index 0000000..d30ec40 --- /dev/null +++ b/client/bot/handlers/dir.go @@ -0,0 +1,74 @@ +package handlers + +import ( + "strconv" + "strings" + + "github.com/celestix/gotgproto/dispatcher" + "github.com/celestix/gotgproto/ext" + "github.com/charmbracelet/log" + "github.com/krau/SaveAny-Bot/client/bot/handlers/utils/msgelem" + "github.com/krau/SaveAny-Bot/database" + "github.com/krau/SaveAny-Bot/storage" +) + +func handleDirCmd(ctx *ext.Context, update *ext.Update) error { + logger := log.FromContext(ctx) + args := strings.Split(update.EffectiveMessage.Text, " ") + userChatID := update.GetUserChat().GetID() + dirs, err := database.GetUserDirsByChatID(ctx, userChatID) + if err != nil { + logger.Errorf("获取用户文件夹失败: %s", err) + ctx.Reply(update, ext.ReplyTextString("获取用户文件夹失败"), nil) + return dispatcher.EndGroups + } + if len(args) < 2 { + ctx.Reply(update, ext.ReplyTextStyledTextArray(msgelem.BuildDirHelpStyling(dirs)), nil) + return dispatcher.EndGroups + } + user, err := database.GetUserByChatID(ctx, update.GetUserChat().GetID()) + if err != nil { + logger.Errorf("获取用户失败: %s", err) + ctx.Reply(update, ext.ReplyTextString("获取用户失败"), nil) + return dispatcher.EndGroups + } + switch args[1] { + case "add": + // /dir add local1 path/to/dir + if len(args) < 4 { + ctx.Reply(update, ext.ReplyTextStyledTextArray(msgelem.BuildDirHelpStyling(dirs)), nil) + return dispatcher.EndGroups + } + if _, err := storage.GetStorageByUserIDAndName(ctx, user.ChatID, args[2]); err != nil { + ctx.Reply(update, ext.ReplyTextString(err.Error()), nil) + return dispatcher.EndGroups + } + + if err := database.CreateDirForUser(ctx, user.ID, args[2], args[3]); err != nil { + logger.Errorf("创建文件夹失败: %s", err) + ctx.Reply(update, ext.ReplyTextString("创建文件夹失败"), nil) + return dispatcher.EndGroups + } + ctx.Reply(update, ext.ReplyTextString("文件夹添加成功"), nil) + case "del": + // /dir del 3 + if len(args) < 3 { + ctx.Reply(update, ext.ReplyTextStyledTextArray(msgelem.BuildDirHelpStyling(dirs)), nil) + return dispatcher.EndGroups + } + dirID, err := strconv.Atoi(args[2]) + if err != nil { + ctx.Reply(update, ext.ReplyTextString("文件夹ID无效"), nil) + return dispatcher.EndGroups + } + if err := database.DeleteDirByID(ctx, uint(dirID)); err != nil { + logger.Errorf("删除文件夹失败: %s", err) + ctx.Reply(update, ext.ReplyTextString("删除文件夹失败"), nil) + return dispatcher.EndGroups + } + ctx.Reply(update, ext.ReplyTextString("文件夹删除成功"), nil) + default: + ctx.Reply(update, ext.ReplyTextString("未知操作"), nil) + } + return dispatcher.EndGroups +} diff --git a/bot/handle_start.go b/client/bot/handlers/help.go similarity index 54% rename from bot/handle_start.go rename to client/bot/handlers/help.go index 9c542da..5a6ff3d 100644 --- a/bot/handle_start.go +++ b/client/bot/handlers/help.go @@ -1,23 +1,15 @@ -package bot +package handlers import ( "fmt" "github.com/celestix/gotgproto/dispatcher" "github.com/celestix/gotgproto/ext" - "github.com/krau/SaveAny-Bot/common" - "github.com/krau/SaveAny-Bot/dao" + "github.com/krau/SaveAny-Bot/pkg/consts" ) -func start(ctx *ext.Context, update *ext.Update) error { - if err := dao.CreateUser(update.GetUserChat().GetID()); err != nil { - common.Log.Errorf("创建用户失败: %s", err) - return dispatcher.EndGroups - } - return help(ctx, update) -} - -const helpText string = ` +func handleHelpCmd(ctx *ext.Context, update *ext.Update) error { + const helpText string = ` Save Any Bot - 转存你的 Telegram 文件 版本: %s , 提交: %s 命令: @@ -33,8 +25,6 @@ Save Any Bot - 转存你的 Telegram 文件 向 Bot 发送(转发)文件, 或发送一个公开频道的消息链接以保存文件 ` - -func help(ctx *ext.Context, update *ext.Update) error { - ctx.Reply(update, ext.ReplyTextString(fmt.Sprintf(helpText, common.Version, common.GitCommit)), nil) + ctx.Reply(update, ext.ReplyTextString(fmt.Sprintf(helpText, consts.Version, consts.GitCommit)), nil) return dispatcher.EndGroups } diff --git a/client/bot/handlers/link.go b/client/bot/handlers/link.go new file mode 100644 index 0000000..54606f8 --- /dev/null +++ b/client/bot/handlers/link.go @@ -0,0 +1,63 @@ +package handlers + +import ( + "fmt" + + "github.com/celestix/gotgproto/dispatcher" + "github.com/celestix/gotgproto/ext" + "github.com/charmbracelet/log" + "github.com/krau/SaveAny-Bot/client/bot/handlers/utils/msgelem" + "github.com/krau/SaveAny-Bot/client/bot/handlers/utils/shortcut" + "github.com/krau/SaveAny-Bot/pkg/tcbdata" + "github.com/krau/SaveAny-Bot/storage" +) + +func handleMessageLink(ctx *ext.Context, update *ext.Update) error { + replied, files, editReplied, err := shortcut.GetFilesFromUpdateLinkMessageWithReplyEdit(ctx, update) + if err != nil { + return err + } + logger := log.FromContext(ctx) + userId := update.GetUserChat().GetID() + stors := storage.GetUserStorages(ctx, userId) + if len(files) == 1 { + req, err := msgelem.BuildAddOneSelectStorageMessage(ctx, stors, files[0], replied.ID) + if err != nil { + logger.Errorf("构建存储选择消息失败: %s", err) + editReplied("构建存储选择消息失败: "+err.Error(), nil) + return dispatcher.EndGroups + } + ctx.EditMessage(update.EffectiveChat().GetID(), req) + return dispatcher.EndGroups + } + markup, err := msgelem.BuildAddSelectStorageKeyboard(stors, tcbdata.Add{ + Files: files, + }) + if err != nil { + logger.Errorf("构建存储选择键盘失败: %s", err) + editReplied("构建存储选择键盘失败: "+err.Error(), nil) + return dispatcher.EndGroups + } + editReplied(fmt.Sprintf("找到 %d 个文件, 请选择存储位置", len(files)), + markup) + return dispatcher.EndGroups +} + +func handleSilentSaveLink(ctx *ext.Context, update *ext.Update) error { + logger := log.FromContext(ctx) + stor := storage.FromContext(ctx) + if stor == nil { + logger.Warn("Context storage is nil") + ctx.Reply(update, ext.ReplyTextString("未找到存储"), nil) + return dispatcher.EndGroups + } + replied, files, _, err := shortcut.GetFilesFromUpdateLinkMessageWithReplyEdit(ctx, update) + if err != nil { + return err + } + userId := update.GetUserChat().GetID() + if len(files) == 1 { + return shortcut.CreateAndAddTGFileTaskWithEdit(ctx, userId, stor, "", files[0], replied.ID) + } + return shortcut.CreateAndAddBatchTGFileTaskWithEdit(ctx, userId, stor, "", files, replied.ID) +} diff --git a/client/bot/handlers/media.go b/client/bot/handlers/media.go new file mode 100644 index 0000000..4e63261 --- /dev/null +++ b/client/bot/handlers/media.go @@ -0,0 +1,48 @@ +package handlers + +import ( + "github.com/celestix/gotgproto/dispatcher" + "github.com/celestix/gotgproto/ext" + "github.com/charmbracelet/log" + "github.com/krau/SaveAny-Bot/client/bot/handlers/utils/msgelem" + "github.com/krau/SaveAny-Bot/client/bot/handlers/utils/shortcut" + "github.com/krau/SaveAny-Bot/storage" +) + +func handleMediaMessage(ctx *ext.Context, update *ext.Update) error { + logger := log.FromContext(ctx) + message := update.EffectiveMessage.Message + logger.Debugf("Got media: %s", message.Media.TypeName()) + msg, file, err := shortcut.GetFileFromMessageWithReply(ctx, update, message) + if err != nil { + return err + } + userId := update.GetUserChat().GetID() + stors := storage.GetUserStorages(ctx, userId) + req, err := msgelem.BuildAddOneSelectStorageMessage(ctx, stors, file, msg.ID) + if err != nil { + logger.Errorf("构建存储选择消息失败: %s", err) + ctx.Reply(update, ext.ReplyTextString("构建存储选择消息失败: "+err.Error()), nil) + return dispatcher.EndGroups + } + ctx.EditMessage(update.EffectiveChat().GetID(), req) + return dispatcher.EndGroups +} + +func handleSilentSaveMedia(ctx *ext.Context, update *ext.Update) error { + logger := log.FromContext(ctx) + stor := storage.FromContext(ctx) + if stor == nil { + logger.Warn("Context storage is nil") + ctx.Reply(update, ext.ReplyTextString("未找到存储"), nil) + return dispatcher.EndGroups + } + message := update.EffectiveMessage.Message + logger.Debugf("Got media: %s", message.Media.TypeName()) + userID := update.GetUserChat().GetID() + msg, file, err := shortcut.GetFileFromMessageWithReply(ctx, update, message) + if err != nil { + return err + } + return shortcut.CreateAndAddTGFileTaskWithEdit(ctx, userID, stor, "", file, msg.ID) +} diff --git a/client/bot/handlers/middleware.go b/client/bot/handlers/middleware.go new file mode 100644 index 0000000..98945d6 --- /dev/null +++ b/client/bot/handlers/middleware.go @@ -0,0 +1,49 @@ +package handlers + +import ( + "github.com/celestix/gotgproto/dispatcher" + "github.com/celestix/gotgproto/ext" + "github.com/duke-git/lancet/v2/slice" + "github.com/krau/SaveAny-Bot/config" + "github.com/krau/SaveAny-Bot/database" + "github.com/krau/SaveAny-Bot/storage" +) + +func checkPermission(ctx *ext.Context, update *ext.Update) error { + userID := update.GetUserChat().GetID() + if !slice.Contain(config.Cfg.GetUsersID(), userID) { + const noPermissionText string = ` +您不在白名单中, 无法使用此 Bot. +您可以部署自己的实例: https://github.com/krau/SaveAny-Bot +` + ctx.Reply(update, ext.ReplyTextString(noPermissionText), nil) + return dispatcher.EndGroups + } + + return dispatcher.ContinueGroups +} + +func handleSilentMode(next func(*ext.Context, *ext.Update) error, handler func(*ext.Context, *ext.Update) error) func(*ext.Context, *ext.Update) error { + return func(ctx *ext.Context, update *ext.Update) error { + userID := update.GetUserChat().GetID() + user, err := database.GetUserByChatID(ctx, userID) + if err != nil { + ctx.Reply(update, ext.ReplyTextString("获取用户信息失败: "+err.Error()), nil) + return dispatcher.EndGroups + } + if !user.Silent { + return next(ctx, update) + } + if user.DefaultStorage == "" { + ctx.Reply(update, ext.ReplyTextString("您已开启静默模式, 但未设置默认存储端, 请先使用 /storage 设置"), nil) + return next(ctx, update) + } + stor, err := storage.GetStorageByUserIDAndName(ctx, userID, user.DefaultStorage) + if err != nil { + ctx.Reply(update, ext.ReplyTextString("获取默认存储失败: "+err.Error()), nil) + return dispatcher.EndGroups + } + ctx.Context = storage.WithContext(ctx.Context, stor) + return handler(ctx, update) + } +} diff --git a/client/bot/handlers/register.go b/client/bot/handlers/register.go new file mode 100644 index 0000000..40e2863 --- /dev/null +++ b/client/bot/handlers/register.go @@ -0,0 +1,41 @@ +package handlers + +import ( + "github.com/celestix/gotgproto/dispatcher" + "github.com/celestix/gotgproto/dispatcher/handlers" + "github.com/celestix/gotgproto/dispatcher/handlers/filters" + "github.com/celestix/gotgproto/ext" + "github.com/krau/SaveAny-Bot/client/bot/handlers/utils/re" + "github.com/krau/SaveAny-Bot/pkg/tcbdata" +) + +func Register(disp dispatcher.Dispatcher) { + disp.AddHandler(handlers.NewMessage(filters.Message.ChatType(filters.ChatTypeChannel), func(ctx *ext.Context, u *ext.Update) error { + return dispatcher.EndGroups + })) + disp.AddHandler(handlers.NewMessage(filters.Message.ChatType(filters.ChatTypeChat), func(ctx *ext.Context, u *ext.Update) error { + return dispatcher.EndGroups + })) + disp.AddHandler(handlers.NewMessage(filters.Message.All, checkPermission)) + disp.AddHandler(handlers.NewCommand("start", handleHelpCmd)) + disp.AddHandler(handlers.NewCommand("help", handleHelpCmd)) + disp.AddHandler(handlers.NewCommand("silent", handleSilentCmd)) + disp.AddHandler(handlers.NewCommand("storage", handleStorageCmd)) + disp.AddHandler(handlers.NewCommand("dir", handleDirCmd)) + disp.AddHandler(handlers.NewCommand("rule", handleRuleCmd)) + disp.AddHandler(handlers.NewCommand("save", handleSilentMode(handleSaveCmd, handleSilentSaveReplied))) + disp.AddHandler(handlers.NewCallbackQuery(filters.CallbackQuery.Prefix(tcbdata.TypeAdd), handleAddCallback)) + disp.AddHandler(handlers.NewCallbackQuery(filters.CallbackQuery.Prefix(tcbdata.TypeSetDefault), handleSetDefaultCallback)) + disp.AddHandler(handlers.NewCallbackQuery(filters.CallbackQuery.Prefix("cancel"), handleCancelCallback)) + linkRegexFilter, err := filters.Message.Regex(re.TgMessageLinkRegexString) + if err != nil { + panic("failed to create regex filter: " + err.Error()) + } + disp.AddHandler(handlers.NewMessage(linkRegexFilter, handleSilentMode(handleMessageLink, handleSilentSaveLink))) + telegraphUrlRegexFilter, err := filters.Message.Regex(re.TelegraphUrlRegexString) + if err != nil { + panic("failed to create Telegraph URL regex filter: " + err.Error()) + } + disp.AddHandler(handlers.NewMessage(telegraphUrlRegexFilter, handleSilentMode(handleTelegraphUrlMessage, handleSilentSaveTelegraph))) + disp.AddHandler(handlers.NewMessage(filters.Message.Media, handleSilentMode(handleMediaMessage, handleSilentSaveMedia))) +} diff --git a/client/bot/handlers/rule.go b/client/bot/handlers/rule.go new file mode 100644 index 0000000..eba856a --- /dev/null +++ b/client/bot/handlers/rule.go @@ -0,0 +1,101 @@ +package handlers + +import ( + "fmt" + "strconv" + "strings" + + "github.com/celestix/gotgproto/dispatcher" + "github.com/celestix/gotgproto/ext" + "github.com/charmbracelet/log" + "github.com/duke-git/lancet/v2/slice" + "github.com/krau/SaveAny-Bot/client/bot/handlers/utils/msgelem" + "github.com/krau/SaveAny-Bot/database" + "github.com/krau/SaveAny-Bot/pkg/enums/rule" +) + +func handleRuleCmd(ctx *ext.Context, update *ext.Update) error { + logger := log.FromContext(ctx) + args := strings.Split(update.EffectiveMessage.Text, " ") + userChatID := update.GetUserChat().GetID() + user, err := database.GetUserByChatID(ctx, userChatID) + if err != nil { + logger.Errorf("获取用户规则失败: %s", err) + ctx.Reply(update, ext.ReplyTextString("获取用户规则失败"), nil) + return dispatcher.EndGroups + } + if len(args) < 2 { + ctx.Reply(update, ext.ReplyTextStyledTextArray(msgelem.BuildRuleHelpStyling(user.ApplyRule, user.Rules)), nil) + return dispatcher.EndGroups + } + switch args[1] { + case "switch": + // /rule switch + applyRule := !user.ApplyRule + if err := database.UpdateUserApplyRule(ctx, user.ChatID, applyRule); err != nil { + logger.Errorf("更新用户失败: %s", err) + ctx.Reply(update, ext.ReplyTextString("更新用户失败"), nil) + return dispatcher.EndGroups + } + ctx.Reply(update, ext.ReplyTextString(fmt.Sprintf("已%s规则模式", map[bool]string{true: "启用", false: "禁用"}[applyRule])), nil) + case "add": + // /rule add + if len(args) < 6 { + ctx.Reply(update, ext.ReplyTextStyledTextArray(msgelem.BuildRuleHelpStyling(user.ApplyRule, user.Rules)), nil) + return dispatcher.EndGroups + } + ruleTypeArg := args[2] + ruleType, err := func() (rule.RuleType, error) { + for _, t := range rule.Values() { + if strings.EqualFold(t.String(), ruleTypeArg) { + return t, nil + } + } + return rule.RuleType(""), fmt.Errorf("无效的规则类型: %s\n可用: %v", ruleTypeArg, slice.Join(rule.Values(), ", ")) + }() + if err != nil { + ctx.Reply(update, ext.ReplyTextString(err.Error()), nil) + return dispatcher.EndGroups + } + + ruleData := args[3] + storageName := args[4] + dirPath := args[5] + + rd := &database.Rule{ + Type: ruleType.String(), + Data: ruleData, + StorageName: storageName, + DirPath: dirPath, + UserID: user.ID, + } + if err := database.CreateRule(ctx, rd); err != nil { + logger.Errorf("创建规则失败: %s", err) + ctx.Reply(update, ext.ReplyTextString("创建规则失败"), nil) + return dispatcher.EndGroups + } + ctx.Reply(update, ext.ReplyTextString("创建规则成功"), nil) + case "del": + // /rule del + if len(args) < 3 { + ctx.Reply(update, ext.ReplyTextString("请提供规则ID"), nil) + return dispatcher.EndGroups + } + ruleID := args[2] + id, err := strconv.Atoi(ruleID) + if err != nil { + ctx.Reply(update, ext.ReplyTextString("无效的规则ID"), nil) + return dispatcher.EndGroups + } + if err := database.DeleteRule(ctx, uint(id)); err != nil { + logger.Errorf("删除规则失败: %s", err) + ctx.Reply(update, ext.ReplyTextString("删除规则失败"), nil) + return dispatcher.EndGroups + } + ctx.Reply(update, ext.ReplyTextString("删除规则成功"), nil) + default: + ctx.Reply(update, ext.ReplyTextStyledTextArray(msgelem.BuildRuleHelpStyling(user.ApplyRule, user.Rules)), nil) + return dispatcher.EndGroups + } + return dispatcher.EndGroups +} diff --git a/client/bot/handlers/save.go b/client/bot/handlers/save.go new file mode 100644 index 0000000..f0dab17 --- /dev/null +++ b/client/bot/handlers/save.go @@ -0,0 +1,168 @@ +package handlers + +import ( + "fmt" + "strings" + + "github.com/celestix/gotgproto/dispatcher" + "github.com/celestix/gotgproto/ext" + "github.com/charmbracelet/log" + "github.com/gotd/td/tg" + "github.com/krau/SaveAny-Bot/client/bot/handlers/utils/mediautil" + "github.com/krau/SaveAny-Bot/client/bot/handlers/utils/msgelem" + "github.com/krau/SaveAny-Bot/client/bot/handlers/utils/shortcut" + "github.com/krau/SaveAny-Bot/common/utils/strutil" + "github.com/krau/SaveAny-Bot/common/utils/tgutil" + "github.com/krau/SaveAny-Bot/pkg/tcbdata" + "github.com/krau/SaveAny-Bot/pkg/tfile" + + "github.com/krau/SaveAny-Bot/storage" +) + +func handleSaveCmd(ctx *ext.Context, update *ext.Update) error { + logger := log.FromContext(ctx) + args := strings.Split(string(update.EffectiveMessage.Text), " ") + if len(args) >= 3 { + return handleBatchSave(ctx, update, args[1], args[2]) + } + replyTo := update.EffectiveMessage.ReplyToMessage + if replyTo == nil || replyTo.Message == nil { + ctx.Reply(update, ext.ReplyTextString(msgelem.SaveHelpText), nil) + return dispatcher.EndGroups + } + genFilename := func() string { + if len(args) > 1 { + return args[1] + } + filename := tgutil.GenFileNameFromMessage(*replyTo.Message) + return filename + }() + option := tfile.WithNameIfEmpty(genFilename) + if len(args) > 1 { + option = tfile.WithName(genFilename) + } + msg, file, err := shortcut.GetFileFromMessageWithReply(ctx, update, replyTo.Message, option) + if err != nil { + return err + } + userId := update.GetUserChat().GetID() + stors := storage.GetUserStorages(ctx, userId) + req, err := msgelem.BuildAddOneSelectStorageMessage(ctx, stors, file, msg.ID) + if err != nil { + logger.Errorf("构建存储选择消息失败: %s", err) + ctx.Reply(update, ext.ReplyTextString("构建存储选择消息失败: "+err.Error()), nil) + return dispatcher.EndGroups + } + ctx.EditMessage(update.EffectiveChat().GetID(), req) + return dispatcher.EndGroups +} + +func handleSilentSaveReplied(ctx *ext.Context, update *ext.Update) error { + args := strings.Split(string(update.EffectiveMessage.Text), " ") + if len(args) >= 3 { + return handleBatchSave(ctx, update, args[1], args[2]) + } + logger := log.FromContext(ctx) + stor := storage.FromContext(ctx) + if stor == nil { + logger.Warn("Context storage is nil") + ctx.Reply(update, ext.ReplyTextString("未找到存储"), nil) + return dispatcher.EndGroups + } + replyTo := update.EffectiveMessage.ReplyToMessage + if replyTo == nil || replyTo.Message == nil { + ctx.Reply(update, ext.ReplyTextString(msgelem.SaveHelpText), nil) + return dispatcher.EndGroups + } + genFilename := func() string { + if len(args) > 1 { + return args[1] + } + filename := tgutil.GenFileNameFromMessage(*replyTo.Message) + return filename + }() + option := tfile.WithNameIfEmpty(genFilename) + if len(args) > 1 { + option = tfile.WithName(genFilename) + } + msg, file, err := shortcut.GetFileFromMessageWithReply(ctx, update, replyTo.Message, option) + if err != nil { + return err + } + return shortcut.CreateAndAddTGFileTaskWithEdit(ctx, update.GetUserChat().GetID(), stor, "", file, msg.GetID()) +} + +func handleBatchSave(ctx *ext.Context, update *ext.Update, chatArg string, msgIdRangeArg string) error { + startID, endID, err := strutil.ParseIntStrRange(msgIdRangeArg, "-") + if err != nil { + ctx.Reply(update, ext.ReplyTextString("无效的消息ID范围: "+err.Error()), nil) + return dispatcher.EndGroups + } + chatID, err := tgutil.ParseChatID(ctx, chatArg) + if err != nil { + ctx.Reply(update, ext.ReplyTextString("无效的ID或用户名: "+err.Error()), nil) + return dispatcher.EndGroups + } + + replied, err := ctx.Reply(update, ext.ReplyTextString("正在获取消息..."), nil) + if err != nil { + log.FromContext(ctx).Errorf("回复失败: %s", err) + return dispatcher.EndGroups + } + + // TODO: generator istead of get all messages + msgs, err := tgutil.GetMessagesRange(ctx, chatID, int(startID), int(endID)) + if err != nil { + ctx.Reply(update, ext.ReplyTextString("获取消息失败: "+err.Error()), nil) + return dispatcher.EndGroups + } + if len(msgs) == 0 { + ctx.Reply(update, ext.ReplyTextString("没有找到指定范围内的消息"), nil) + return dispatcher.EndGroups + } + files := make([]tfile.TGFileMessage, 0, len(msgs)) + for _, msg := range msgs { + media, ok := msg.GetMedia() + if !ok { + continue + } + supported := mediautil.IsSupported(media) + if !supported { + continue + } + file, err := tfile.FromMediaMessage(media, msg, tfile.WithNameIfEmpty(tgutil.GenFileNameFromMessage(*msg))) + if err != nil { + log.FromContext(ctx).Errorf("获取文件失败: %s", err) + continue + } + files = append(files, file) + } + if len(files) == 0 { + ctx.Reply(update, ext.ReplyTextString("没有找到指定范围内的可保存消息"), nil) + return dispatcher.EndGroups + } + stor := storage.FromContext(ctx) + if stor == nil { + // not in silent mode + stors := storage.GetUserStorages(ctx, update.GetUserChat().GetID()) + markup, err := msgelem.BuildAddSelectStorageKeyboard(stors, tcbdata.Add{ + Files: files, + }) + if err != nil { + log.FromContext(ctx).Errorf("构建存储选择键盘失败: %s", err) + ctx.EditMessage(update.EffectiveChat().GetID(), &tg.MessagesEditMessageRequest{ + ID: replied.ID, + Message: "构建存储选择键盘失败: " + err.Error(), + }) + return dispatcher.EndGroups + } + ctx.EditMessage(update.EffectiveChat().GetID(), &tg.MessagesEditMessageRequest{ + ID: replied.ID, + Message: fmt.Sprintf("找到 %d 个文件, 请选择存储位置", len(files)), + ReplyMarkup: markup, + }) + return dispatcher.EndGroups + } + return shortcut.CreateAndAddBatchTGFileTaskWithEdit(ctx, update.GetUserChat().GetID(), stor, "", files, replied.ID) + +} diff --git a/client/bot/handlers/silent.go b/client/bot/handlers/silent.go new file mode 100644 index 0000000..53c46bb --- /dev/null +++ b/client/bot/handlers/silent.go @@ -0,0 +1,104 @@ +package handlers + +import ( + "strings" + + "github.com/celestix/gotgproto/dispatcher" + "github.com/celestix/gotgproto/ext" + "github.com/gotd/td/tg" + "github.com/krau/SaveAny-Bot/client/bot/handlers/utils/msgelem" + "github.com/krau/SaveAny-Bot/common/cache" + "github.com/krau/SaveAny-Bot/database" + "github.com/krau/SaveAny-Bot/pkg/tcbdata" + "github.com/krau/SaveAny-Bot/storage" +) + +func handleSilentCmd(ctx *ext.Context, update *ext.Update) error { + user, err := database.GetUserByChatID(ctx, update.GetUserChat().GetID()) + if err != nil { + ctx.Reply(update, ext.ReplyTextString("获取用户信息失败: "+err.Error()), nil) + return nil + } + if !user.Silent && user.DefaultStorage == "" { + ctx.Reply(update, ext.ReplyTextString("请先使用 /storage 设置默认存储位置"), nil) + return nil + } + user.Silent = !user.Silent + if err := database.UpdateUser(ctx, user); err != nil { + ctx.Reply(update, ext.ReplyTextString("更新用户信息失败: "+err.Error()), nil) + return nil + } + responseText := "已" + map[bool]string{true: "开启", false: "关闭"}[user.Silent] + "静默模式" + ctx.Reply(update, ext.ReplyTextString(responseText), nil) + return dispatcher.EndGroups +} + +func handleSetDefaultCallback(ctx *ext.Context, update *ext.Update) error { + dataid := strings.Split(string(update.CallbackQuery.Data), " ")[1] + data, ok := cache.Get[tcbdata.SetDefaultStorage](dataid) + if !ok { + ctx.AnswerCallback(&tg.MessagesSetBotCallbackAnswerRequest{ + QueryID: update.CallbackQuery.GetQueryID(), + Alert: true, + Message: "数据已过期", + CacheTime: 5, + }) + return dispatcher.EndGroups + } + userID := update.CallbackQuery.GetUserID() + + storageName := data.StorageName + selectedStorage, err := storage.GetStorageByUserIDAndName(ctx, userID, storageName) + if err != nil { + ctx.AnswerCallback(&tg.MessagesSetBotCallbackAnswerRequest{ + QueryID: update.CallbackQuery.GetQueryID(), + Alert: true, + Message: "存储获取失败: " + err.Error(), + CacheTime: 5, + }) + return dispatcher.EndGroups + } + user, err := database.GetUserByChatID(ctx, userID) + if err != nil { + ctx.AnswerCallback(&tg.MessagesSetBotCallbackAnswerRequest{ + QueryID: update.CallbackQuery.GetQueryID(), + Alert: true, + Message: "获取用户信息失败: " + err.Error(), + CacheTime: 5, + }) + return dispatcher.EndGroups + } + user.DefaultStorage = selectedStorage.Name() + if err := database.UpdateUser(ctx, user); err != nil { + ctx.AnswerCallback(&tg.MessagesSetBotCallbackAnswerRequest{ + QueryID: update.CallbackQuery.GetQueryID(), + Alert: true, + Message: "更新用户信息失败: " + err.Error(), + CacheTime: 5, + }) + return dispatcher.EndGroups + } + ctx.EditMessage(userID, &tg.MessagesEditMessageRequest{ + ID: update.CallbackQuery.GetMsgID(), + Message: "已将默认存储位置设置为: " + selectedStorage.Name(), + }) + return dispatcher.EndGroups +} + +func handleStorageCmd(ctx *ext.Context, update *ext.Update) error { + userID := update.GetUserChat().GetID() + storages := storage.GetUserStorages(ctx, userID) + if len(storages) == 0 { + ctx.Reply(update, ext.ReplyTextString("无可用的存储"), nil) + return nil + } + markup, err := msgelem.BuildSetDefaultStorageMarkup(ctx, userID, storages) + if err != nil { + ctx.Reply(update, ext.ReplyTextString("获取存储失败: "+err.Error()), nil) + return nil + } + ctx.Reply(update, ext.ReplyTextString("请选择要设为默认的存储位置"), &ext.ReplyOpts{ + Markup: markup, + }) + return dispatcher.EndGroups +} diff --git a/client/bot/handlers/telegraph.go b/client/bot/handlers/telegraph.go new file mode 100644 index 0000000..b239949 --- /dev/null +++ b/client/bot/handlers/telegraph.go @@ -0,0 +1,76 @@ +package handlers + +import ( + "fmt" + + "github.com/celestix/gotgproto/dispatcher" + "github.com/celestix/gotgproto/ext" + "github.com/charmbracelet/log" + "github.com/gotd/td/telegram/message/entity" + "github.com/gotd/td/telegram/message/styling" + "github.com/gotd/td/tg" + "github.com/krau/SaveAny-Bot/client/bot/handlers/utils/msgelem" + "github.com/krau/SaveAny-Bot/client/bot/handlers/utils/shortcut" + "github.com/krau/SaveAny-Bot/pkg/enums/tasktype" + "github.com/krau/SaveAny-Bot/pkg/tcbdata" + "github.com/krau/SaveAny-Bot/storage" +) + +func handleTelegraphUrlMessage(ctx *ext.Context, update *ext.Update) error { + logger := log.FromContext(ctx) + + msg, result, err := shortcut.GetTphPicsFromMessageWithReply(ctx, update) + if err != nil { + return err + } + userID := update.GetUserChat().GetID() + stors := storage.GetUserStorages(ctx, userID) + markup, err := msgelem.BuildAddSelectStorageKeyboard(stors, tcbdata.Add{ + TaskType: tasktype.TaskTypeTphpics, + TphPageNode: result.Page, + TphDirPath: result.TphDir, + TphPics: result.Pics, + }) + if err != nil { + logger.Errorf("构建存储选择键盘失败: %s", err) + ctx.Reply(update, ext.ReplyTextString("构建存储选择键盘失败: "+err.Error()), nil) + return dispatcher.EndGroups + } + + eb := entity.Builder{} + if err := styling.Perform(&eb, + styling.Plain("标题: "), + styling.Code(result.Page.Title), + styling.Plain("\n图片数量: "), + styling.Code(fmt.Sprintf("%d", len(result.Pics))), + styling.Plain("\n请选择存储位置"), + ); err != nil { + log.FromContext(ctx).Errorf("Failed to build entity: %s", err) + return dispatcher.EndGroups + } + text, entities := eb.Complete() + ctx.EditMessage(userID, &tg.MessagesEditMessageRequest{ + Message: text, + ID: msg.ID, + ReplyMarkup: markup, + Entities: entities, + }) + return dispatcher.EndGroups +} + +func handleSilentSaveTelegraph(ctx *ext.Context, update *ext.Update) error { + logger := log.FromContext(ctx) + stor := storage.FromContext(ctx) + if stor == nil { + logger.Warn("Context storage is nil") + ctx.Reply(update, ext.ReplyTextString("未找到存储"), nil) + return dispatcher.EndGroups + } + msg, result, err := shortcut.GetTphPicsFromMessageWithReply(ctx, update) + if err != nil { + return err + } + userID := update.GetUserChat().GetID() + return shortcut.CreateAndAddTphTaskWithEdit(ctx, userID, result.Page, result.TphDir, result.Pics, stor, msg.ID) + +} diff --git a/client/bot/handlers/utils/mediautil/media.go b/client/bot/handlers/utils/mediautil/media.go new file mode 100644 index 0000000..1530952 --- /dev/null +++ b/client/bot/handlers/utils/mediautil/media.go @@ -0,0 +1,12 @@ +package mediautil + +import "github.com/gotd/td/tg" + +func IsSupported(media tg.MessageMediaClass) bool { + switch media.(type) { + case *tg.MessageMediaDocument, *tg.MessageMediaPhoto: + return true + default: + return false + } +} diff --git a/client/bot/handlers/utils/msgelem/callback.go b/client/bot/handlers/utils/msgelem/callback.go new file mode 100644 index 0000000..f521f53 --- /dev/null +++ b/client/bot/handlers/utils/msgelem/callback.go @@ -0,0 +1,12 @@ +package msgelem + +import "github.com/gotd/td/tg" + +func AlertCallbackAnswer(queryID int64, text string) *tg.MessagesSetBotCallbackAnswerRequest { + return &tg.MessagesSetBotCallbackAnswerRequest{ + QueryID: queryID, + Alert: true, + Message: text, + CacheTime: 5, + } +} diff --git a/client/bot/handlers/utils/msgelem/dir.go b/client/bot/handlers/utils/msgelem/dir.go new file mode 100644 index 0000000..1482742 --- /dev/null +++ b/client/bot/handlers/utils/msgelem/dir.go @@ -0,0 +1,36 @@ +package msgelem + +import ( + "fmt" + "strings" + + "github.com/gotd/td/telegram/message/styling" + "github.com/krau/SaveAny-Bot/database" +) + +func BuildDirHelpStyling(dirs []database.Dir) []styling.StyledTextOption { + return []styling.StyledTextOption{ + styling.Bold("使用方法: /dir <操作> <参数...>"), + styling.Plain("\n\n可用操作:\n"), + styling.Code("add"), + styling.Plain(" <存储名> <路径> - 添加路径\n"), + styling.Code("del"), + styling.Plain(" <路径ID> - 删除路径\n"), + styling.Plain("\n添加路径示例:\n"), + styling.Code("/dir add local1 path/to/dir"), + styling.Plain("\n\n删除路径示例:\n"), + styling.Code("/dir del 3"), + styling.Plain("\n\n当前已添加的路径:\n"), + styling.Blockquote(func() string { + var sb strings.Builder + for _, dir := range dirs { + sb.WriteString(fmt.Sprintf("%d: ", dir.ID)) + sb.WriteString(dir.StorageName) + sb.WriteString(" - ") + sb.WriteString(dir.Path) + sb.WriteString("\n") + } + return sb.String() + }(), true), + } +} diff --git a/client/bot/handlers/utils/msgelem/rule.go b/client/bot/handlers/utils/msgelem/rule.go new file mode 100644 index 0000000..a708038 --- /dev/null +++ b/client/bot/handlers/utils/msgelem/rule.go @@ -0,0 +1,32 @@ +package msgelem + +import ( + "fmt" + "strings" + + "github.com/gotd/td/telegram/message/styling" + "github.com/krau/SaveAny-Bot/database" +) + +func BuildRuleHelpStyling(enabled bool, rules []database.Rule) []styling.StyledTextOption { + return []styling.StyledTextOption{ + styling.Bold("使用方法: /rule <操作> <参数...>"), + styling.Bold(fmt.Sprintf("\n当前已%s规则模式", map[bool]string{true: "启用", false: "禁用"}[enabled])), + styling.Plain("\n\n可用操作:\n"), + styling.Code("switch"), + styling.Plain(" - 开关规则模式\n"), + styling.Code("add"), + styling.Plain(" <类型> <数据> <存储名> <路径> - 添加规则\n"), + styling.Code("del"), + styling.Plain(" <规则ID> - 删除规则\n"), + styling.Plain("\n当前已添加的规则:\n"), + styling.Blockquote(func() string { + var sb strings.Builder + for _, rule := range rules { + ruleText := fmt.Sprintf("%s %s %s %s", rule.Type, rule.Data, rule.StorageName, rule.DirPath) + sb.WriteString(fmt.Sprintf("%d: %s\n", rule.ID, ruleText)) + } + return sb.String() + }(), true), + } +} diff --git a/client/bot/handlers/utils/msgelem/save.go b/client/bot/handlers/utils/msgelem/save.go new file mode 100644 index 0000000..ed04e3f --- /dev/null +++ b/client/bot/handlers/utils/msgelem/save.go @@ -0,0 +1,15 @@ +package msgelem + +const ( + SaveHelpText = ` + 使用方法: + + 1. 使用该命令回复要保存的文件, 可选文件名参数. + 示例: + /save custom_file_name.mp4 + + 2. 设置默认存储后, 发送 /save <频道ID/用户名> <消息ID范围> 来批量保存文件. 遵从存储规则, 若未匹配到任何规则则使用默认存储. + 示例: + /save @moreacg 114-514 + ` +) diff --git a/client/bot/handlers/utils/msgelem/storage.go b/client/bot/handlers/utils/msgelem/storage.go new file mode 100644 index 0000000..70fda81 --- /dev/null +++ b/client/bot/handlers/utils/msgelem/storage.go @@ -0,0 +1,169 @@ +package msgelem + +import ( + "context" + "fmt" + + "github.com/charmbracelet/log" + "github.com/gotd/td/telegram/message/entity" + "github.com/gotd/td/telegram/message/styling" + "github.com/gotd/td/tg" + "github.com/krau/SaveAny-Bot/common/cache" + "github.com/krau/SaveAny-Bot/database" + "github.com/krau/SaveAny-Bot/pkg/enums/tasktype" + "github.com/krau/SaveAny-Bot/pkg/tcbdata" + "github.com/krau/SaveAny-Bot/pkg/tfile" + "github.com/krau/SaveAny-Bot/storage" + "github.com/rs/xid" +) + +func BuildAddSelectStorageKeyboard(stors []storage.Storage, adddata tcbdata.Add) (*tg.ReplyInlineMarkup, error) { + taskType := adddata.TaskType + if taskType == "" { + if len(adddata.Files) > 0 { + taskType = tasktype.TaskTypeTgfiles + } else if adddata.TphPageNode != nil { + taskType = tasktype.TaskTypeTphpics + } else { + return nil, fmt.Errorf("unknown task type: %s", taskType) + } + } + + buttons := make([]tg.KeyboardButtonClass, 0) + for _, storage := range stors { + data := tcbdata.Add{ + TaskType: taskType, + SelectedStorName: storage.Name(), + + Files: adddata.Files, + AsBatch: len(adddata.Files) > 1, + + TphPageNode: adddata.TphPageNode, + TphPics: adddata.TphPics, + TphDirPath: adddata.TphDirPath, + } + dataid := xid.New().String() + err := cache.Set(dataid, data) + if err != nil { + return nil, err + } + buttons = append(buttons, &tg.KeyboardButtonCallback{ + Text: storage.Name(), + Data: fmt.Appendf(nil, "%s %s", tcbdata.TypeAdd, dataid), + }) + } + markup := &tg.ReplyInlineMarkup{} + for i := 0; i < len(buttons); i += 3 { + row := tg.KeyboardButtonRow{} + row.Buttons = buttons[i:min(i+3, len(buttons))] + markup.Rows = append(markup.Rows, row) + } + return markup, nil +} + +func BuildAddOneSelectStorageMessage(ctx context.Context, stors []storage.Storage, file tfile.TGFileMessage, msgId int) (*tg.MessagesEditMessageRequest, error) { + eb := entity.Builder{} + var entities []tg.MessageEntityClass + text := fmt.Sprintf("文件名: %s\n请选择存储位置", file.Name()) + if err := styling.Perform(&eb, + styling.Plain("文件名: "), + styling.Code(file.Name()), + styling.Plain("\n请选择存储位置"), + ); err != nil { + log.FromContext(ctx).Errorf("Failed to build entity: %s", err) + } else { + text, entities = eb.Complete() + } + markup, err := BuildAddSelectStorageKeyboard(stors, tcbdata.Add{ + TaskType: tasktype.TaskTypeTgfiles, + Files: []tfile.TGFileMessage{file}, + AsBatch: false, + }) + if err != nil { + return nil, fmt.Errorf("failed to build storage keyboard: %w", err) + } + return &tg.MessagesEditMessageRequest{ + Message: text, + Entities: entities, + ReplyMarkup: markup, + ID: msgId, + }, nil +} + +func BuildSetDefaultStorageMarkup(ctx context.Context, userID int64, stors []storage.Storage) (*tg.ReplyInlineMarkup, error) { + buttons := make([]tg.KeyboardButtonClass, 0) + for _, storage := range stors { + data := tcbdata.SetDefaultStorage{ + StorageName: storage.Name(), + } + dataid := xid.New().String() + err := cache.Set(dataid, data) + if err != nil { + return nil, err + } + buttons = append(buttons, &tg.KeyboardButtonCallback{ + Text: storage.Name(), + Data: fmt.Appendf(nil, "%s %s", tcbdata.TypeSetDefault, dataid), + }) + } + markup := &tg.ReplyInlineMarkup{} + for i := 0; i < len(buttons); i += 3 { + row := tg.KeyboardButtonRow{} + row.Buttons = buttons[i:min(i+3, len(buttons))] + markup.Rows = append(markup.Rows, row) + } + return markup, nil +} + +func BuildSetDirKeyboard(dirs []database.Dir, dataid string) (*tg.ReplyInlineMarkup, error) { + data, ok := cache.Get[tcbdata.Add](dataid) + if !ok { + return nil, fmt.Errorf("failed to get data from cache: %s", dataid) + } + if data.DirID != 0 || data.SettedDir { + log.Warnf("Data already has a directory set: %d, %t", data.DirID, data.SettedDir) + return nil, fmt.Errorf("data already has a directory set") + } + buttons := make([]tg.KeyboardButtonClass, 0) + for _, dir := range dirs { + dirDataId := xid.New().String() + dirData := tcbdata.Add{ + Files: data.Files, + SelectedStorName: data.SelectedStorName, + AsBatch: data.AsBatch, + DirID: dir.ID, + SettedDir: true, + } + err := cache.Set(dirDataId, dirData) + if err != nil { + return nil, fmt.Errorf("failed to set directory data in cache: %w", err) + } + buttons = append(buttons, &tg.KeyboardButtonCallback{ + Text: dir.Path, + Data: fmt.Appendf(nil, "%s %s", tcbdata.TypeAdd, dirDataId), + }) + } + dirDefaultDataId := xid.New().String() + dirDefaultData := tcbdata.Add{ + Files: data.Files, + SelectedStorName: data.SelectedStorName, + AsBatch: data.AsBatch, + DirID: 0, + SettedDir: true, + } + err := cache.Set(dirDefaultDataId, dirDefaultData) + if err != nil { + return nil, fmt.Errorf("failed to set default directory data in cache: %w", err) + } + buttons = append(buttons, &tg.KeyboardButtonCallback{ + Text: "默认", + Data: fmt.Appendf(nil, "%s %s", tcbdata.TypeAdd, dirDefaultDataId), + }) + markup := &tg.ReplyInlineMarkup{} + for i := 0; i < len(buttons); i += 3 { + row := tg.KeyboardButtonRow{} + row.Buttons = buttons[i:min(i+3, len(buttons))] + markup.Rows = append(markup.Rows, row) + } + return markup, nil +} diff --git a/client/bot/handlers/utils/msgelem/task.go b/client/bot/handlers/utils/msgelem/task.go new file mode 100644 index 0000000..9026be3 --- /dev/null +++ b/client/bot/handlers/utils/msgelem/task.go @@ -0,0 +1,33 @@ +package msgelem + +import ( + "context" + "fmt" + "strconv" + + "github.com/charmbracelet/log" + "github.com/gotd/td/telegram/message/entity" + "github.com/gotd/td/telegram/message/styling" + "github.com/gotd/td/tg" +) + +func BuildTaskAddedEntities( + ctx context.Context, + filename string, + queueLength int, +) (string, []tg.MessageEntityClass) { + entityBuilder := entity.Builder{} + var entities []tg.MessageEntityClass + text := fmt.Sprintf("已添加到任务队列\n文件名: %s\n当前排队任务数: %d", filename, queueLength) + if err := styling.Perform(&entityBuilder, + styling.Plain("已添加到任务队列\n文件名: "), + styling.Code(filename), + styling.Plain("\n当前排队任务数: "), + styling.Bold(strconv.Itoa(queueLength)), + ); err != nil { + log.FromContext(ctx).Errorf("Failed to build entity: %s", err) + } else { + text, entities = entityBuilder.Complete() + } + return text, entities +} diff --git a/client/bot/handlers/utils/re/regexp.go b/client/bot/handlers/utils/re/regexp.go new file mode 100644 index 0000000..7562247 --- /dev/null +++ b/client/bot/handlers/utils/re/regexp.go @@ -0,0 +1,10 @@ +package re + +import "regexp" + +var ( + TgMessageLinkRegexString = `https?://t\.me/(?:c/\d+|[A-Za-z0-9_]+)/\d+(?:/\d+)?(?:\?[^\s#]*[A-Za-z0-9_])?\b` + TgMessageLinkRegexp = regexp.MustCompile(TgMessageLinkRegexString) + TelegraphUrlRegexString = `https://telegra.ph/.*` + TelegraphUrlRegexp = regexp.MustCompile(TelegraphUrlRegexString) +) diff --git a/client/bot/handlers/utils/ruleutil/rule.go b/client/bot/handlers/utils/ruleutil/rule.go new file mode 100644 index 0000000..3c5cf34 --- /dev/null +++ b/client/bot/handlers/utils/ruleutil/rule.go @@ -0,0 +1,80 @@ +package ruleutil + +import ( + "context" + + "github.com/charmbracelet/log" + "github.com/krau/SaveAny-Bot/database" + "github.com/krau/SaveAny-Bot/pkg/consts" + ruleenum "github.com/krau/SaveAny-Bot/pkg/enums/rule" + "github.com/krau/SaveAny-Bot/pkg/rule" + "github.com/krau/SaveAny-Bot/pkg/tfile" +) + +type ruleInput struct { + File tfile.TGFileMessage +} + +type ruleInputOption func(*ruleInput) + +func NewInput(file tfile.TGFileMessage, opts ...ruleInputOption) *ruleInput { + input := &ruleInput{ + File: file, + } + for _, opt := range opts { + opt(input) + } + return input +} + +type matchedStorName string + +func (m matchedStorName) String() string { + return string(m) +} + +func (m matchedStorName) IsValid() bool { + return m != "" && m != consts.RuleStorNameChosen +} + +func ApplyRule(ctx context.Context, rules []database.Rule, inputs *ruleInput) (matchedStorageName matchedStorName, dirPath string) { + if inputs == nil || len(rules) == 0 { + return "", "" + } + logger := log.FromContext(ctx) + for _, ur := range rules { + switch ur.Type { + case ruleenum.FileNameRegex.String(): + ru, err := rule.NewRuleFileNameRegex(ur.StorageName, ur.DirPath, ur.Data) + if err != nil { + logger.Errorf("Failed to create rule: %s", err) + continue + } + ok, err := ru.Match(inputs.File) + if err != nil { + logger.Errorf("Failed to match rule: %s", err) + continue + } + if ok { + dirPath = ru.StoragePath() + matchedStorageName = matchedStorName(ru.StorageName()) + } + case ruleenum.MessageRegex.String(): + ru, err := rule.NewRuleMessageRegex(ur.StorageName, ur.DirPath, ur.Data) + if err != nil { + logger.Errorf("Failed to create rule: %s", err) + continue + } + ok, err := ru.Match(inputs.File.Message().GetMessage()) + if err != nil { + logger.Errorf("Failed to match rule: %s", err) + continue + } + if ok { + dirPath = ru.StoragePath() + matchedStorageName = matchedStorName(ru.StorageName()) + } + } + } + return +} diff --git a/client/bot/handlers/utils/shortcut/message.go b/client/bot/handlers/utils/shortcut/message.go new file mode 100644 index 0000000..8e4f7e4 --- /dev/null +++ b/client/bot/handlers/utils/shortcut/message.go @@ -0,0 +1,193 @@ +// Some shortcuts for duplicate code in handlers, they should return dispatcher errors +package shortcut + +import ( + "encoding/json" + "net/url" + "strings" + + "github.com/celestix/gotgproto/dispatcher" + "github.com/celestix/gotgproto/ext" + "github.com/celestix/gotgproto/types" + "github.com/charmbracelet/log" + "github.com/gotd/td/tg" + "github.com/krau/SaveAny-Bot/client/bot/handlers/utils/mediautil" + "github.com/krau/SaveAny-Bot/client/bot/handlers/utils/msgelem" + "github.com/krau/SaveAny-Bot/client/bot/handlers/utils/re" + "github.com/krau/SaveAny-Bot/common/cache" + "github.com/krau/SaveAny-Bot/common/utils/tgutil" + "github.com/krau/SaveAny-Bot/common/utils/tphutil" + "github.com/krau/SaveAny-Bot/pkg/telegraph" + "github.com/krau/SaveAny-Bot/pkg/tfile" +) + +// 获取消息中的文件并回复等待消息, 返回等待消息, 获取到的文件 +func GetFileFromMessageWithReply(ctx *ext.Context, update *ext.Update, message *tg.Message, tfileopts ...tfile.TGFileOptions) (replied *types.Message, + file tfile.TGFileMessage, err error, +) { + logger := log.FromContext(ctx) + media := message.Media + supported := mediautil.IsSupported(media) + if !supported { + ctx.Reply(update, ext.ReplyTextString("不支持的消息类型"), nil) + return nil, nil, dispatcher.EndGroups + } + + replied, err = ctx.Reply(update, ext.ReplyTextString("正在获取文件信息..."), nil) + if err != nil { + logger.Errorf("Failed to reply: %s", err) + return nil, nil, dispatcher.EndGroups + } + options := []tfile.TGFileOptions{ + tfile.WithMessage(message), + } + if len(tfileopts) > 0 { + options = append(options, tfileopts...) + } else { + options = append(options, tfile.WithNameIfEmpty(tgutil.GenFileNameFromMessage(*message))) + } + file, err = tfile.FromMediaMessage(media, message, options...) + if err != nil { + logger.Errorf("Failed to get file from media: %s", err) + ctx.Reply(update, ext.ReplyTextString("获取文件失败: "+err.Error()), nil) + return nil, nil, dispatcher.EndGroups + } + return replied, file, nil +} + +type EditMessageFunc func(text string, markup tg.ReplyMarkupClass) + +// 获取链接中的文件并回复等待消息 +func GetFilesFromUpdateLinkMessageWithReplyEdit(ctx *ext.Context, update *ext.Update) (replied *types.Message, files []tfile.TGFileMessage, editReplied EditMessageFunc, err error) { + logger := log.FromContext(ctx) + msgLinks := re.TgMessageLinkRegexp.FindAllString(update.EffectiveMessage.GetMessage(), -1) + if len(msgLinks) == 0 { + logger.Warn("no matched message links but called handleMessageLink") + return nil, nil, nil, dispatcher.EndGroups + } + replied, err = ctx.Reply(update, ext.ReplyTextString("正在获取消息..."), nil) + if err != nil { + logger.Errorf("failed to reply: %s", err) + return nil, nil, nil, dispatcher.EndGroups + } + editReplied = func(text string, markup tg.ReplyMarkupClass) { + if _, err := ctx.EditMessage(update.EffectiveChat().GetID(), &tg.MessagesEditMessageRequest{ + ID: replied.ID, + Message: text, + ReplyMarkup: markup, + }); err != nil { + logger.Errorf("failed to edit message: %s", err) + } + } + + files = make([]tfile.TGFileMessage, 0, len(msgLinks)) + for _, link := range msgLinks { + chatId, msgId, err := tgutil.ParseMessageLink(ctx, link) + if err != nil { + logger.Errorf("failed to parse message link %s: %s", link, err) + continue + } + msg, err := tgutil.GetMessageByID(ctx, chatId, msgId) + if err != nil { + logger.Errorf("failed to get message by ID: %s", err) + continue + } + media, ok := msg.GetMedia() + if !ok { + logger.Debugf("message %d has no media", msg.GetID()) + continue + } + file, err := tfile.FromMediaMessage(media, msg, tfile.WithNameIfEmpty(tgutil.GenFileNameFromMessage(*msg))) + if err != nil { + logger.Errorf("failed to create file from media: %s", err) + continue + } + files = append(files, file) + } + if len(files) == 0 { + editReplied("没有找到可保存的文件", nil) + return nil, nil, nil, dispatcher.EndGroups + } + return replied, files, editReplied, nil +} + +func GetCallbackDataWithAnswer[DataType any](ctx *ext.Context, update *ext.Update, dataid string) (DataType, error) { + data, ok := cache.Get[DataType](dataid) + if !ok { + log.FromContext(ctx).Warnf("Invalid data ID: %s", dataid) + queryID := update.CallbackQuery.GetQueryID() + ctx.AnswerCallback(msgelem.AlertCallbackAnswer(queryID, "数据已过期或无效")) + var zero DataType + return zero, dispatcher.EndGroups + } + return data, nil +} + +type TelegraphResult struct { + Pics []string `json:"pics"` // image urls + TphDir string `json:"tph_dir"` // telegraph path, unescaped + Page *telegraph.Page `json:"page"` // telegraph page node +} + +// return replied message, image urls, telegraph path(unescaped), error +func GetTphPicsFromMessageWithReply(ctx *ext.Context, update *ext.Update) (*types.Message, *TelegraphResult, error) { + logger := log.FromContext(ctx) + tphurl := re.TelegraphUrlRegexp.FindString(update.EffectiveMessage.GetMessage()) // TODO: batch urls + if tphurl == "" { + logger.Warnf("No telegraph url found but called handleTelegraph") + return nil, nil, dispatcher.ContinueGroups + } + pagepath := strings.Split(tphurl, "/")[len(strings.Split(tphurl, "/"))-1] + tphdir, err := url.PathUnescape(pagepath) + if err != nil { + logger.Errorf("Failed to unescape telegraph path: %s", err) + ctx.Reply(update, ext.ReplyTextString("解析 telegraph 路径失败: "+err.Error()), nil) + return nil, nil, dispatcher.EndGroups + } + msg, err := ctx.Reply(update, ext.ReplyTextString("正在获取 telegraph 页面..."), nil) + if err != nil { + logger.Errorf("Failed to reply to update: %s", err) + return nil, nil, dispatcher.EndGroups + } + page, err := tphutil.DefaultClient().GetPage(ctx, pagepath) + if err != nil { + logger.Errorf("Failed to get telegraph page: %s", err) + ctx.Reply(update, ext.ReplyTextString("获取 telegraph 页面失败: "+err.Error()), nil) + return nil, nil, dispatcher.EndGroups + } + imgs := make([]string, 0) + for _, elem := range page.Content { + var node telegraph.NodeElement + data, err := json.Marshal(elem) + if err != nil { + logger.Errorf("Failed to marshal element: %s", err) + continue + } + err = json.Unmarshal(data, &node) + if err != nil { + logger.Errorf("Failed to unmarshal element: %s", err) + continue + } + + if len(node.Children) != 0 { + for _, child := range node.Children { + imgs = append(imgs, tphutil.GetNodeImages(child)...) + } + } + if node.Tag == "img" { + if src, ok := node.Attrs["src"]; ok { + imgs = append(imgs, src) + } + } + } + if len(imgs) == 0 { + logger.Warn("No images found in telegraph page") + ctx.Reply(update, ext.ReplyTextString("在 telegraph 页面中未找到图片"), nil) + return nil, nil, dispatcher.EndGroups + } + return msg, &TelegraphResult{ + Pics: imgs, + TphDir: tphdir, + Page: page, + }, nil +} diff --git a/client/bot/handlers/utils/shortcut/tftask.go b/client/bot/handlers/utils/shortcut/tftask.go new file mode 100644 index 0000000..594e71c --- /dev/null +++ b/client/bot/handlers/utils/shortcut/tftask.go @@ -0,0 +1,152 @@ +package shortcut + +import ( + "fmt" + "path" + + "github.com/celestix/gotgproto/dispatcher" + "github.com/celestix/gotgproto/ext" + "github.com/charmbracelet/log" + "github.com/gotd/td/tg" + "github.com/krau/SaveAny-Bot/client/bot/handlers/utils/msgelem" + "github.com/krau/SaveAny-Bot/client/bot/handlers/utils/ruleutil" + "github.com/krau/SaveAny-Bot/common/utils/tgutil" + "github.com/krau/SaveAny-Bot/core" + "github.com/krau/SaveAny-Bot/core/batchtftask" + "github.com/krau/SaveAny-Bot/core/tftask" + "github.com/krau/SaveAny-Bot/database" + "github.com/krau/SaveAny-Bot/pkg/tfile" + "github.com/krau/SaveAny-Bot/storage" + "github.com/rs/xid" +) + +// 创建一个 tftask.TGFileTask 并添加到任务队列中, 以编辑消息的方式反馈结果 +func CreateAndAddTGFileTaskWithEdit(ctx *ext.Context, userID int64, stor storage.Storage, dirPath string, file tfile.TGFileMessage, trackMsgID int) error { + logger := log.FromContext(ctx) + user, err := database.GetUserByChatID(ctx, userID) + if err != nil { + logger.Errorf("Failed to get user by chat ID: %s", err) + ctx.EditMessage(userID, &tg.MessagesEditMessageRequest{ + ID: trackMsgID, + Message: "获取用户失败: " + err.Error(), + }) + return dispatcher.EndGroups + } + if user.ApplyRule && user.Rules != nil { + matchedStorageName, matchedDirPath := ruleutil.ApplyRule(ctx, user.Rules, ruleutil.NewInput(file)) + dirPath = matchedDirPath + if matchedStorageName.IsValid() { + stor, err = storage.GetStorageByUserIDAndName(ctx, user.ChatID, matchedStorageName.String()) + if err != nil { + logger.Errorf("Failed to get storage by user ID and name: %s", err) + ctx.EditMessage(userID, &tg.MessagesEditMessageRequest{ + ID: trackMsgID, + Message: "获取存储失败: " + err.Error(), + }) + return dispatcher.EndGroups + } + } + } + + storagePath := stor.JoinStoragePath(path.Join(dirPath, file.Name())) + injectCtx := tgutil.ExtWithContext(ctx.Context, ctx) + taskid := xid.New().String() + task, err := tftask.NewTGFileTask(taskid, injectCtx, file, ctx.Raw, stor, storagePath, + tftask.NewProgressTrack( + trackMsgID, + userID)) + if err != nil { + logger.Errorf("create task failed: %s", err) + ctx.EditMessage(userID, &tg.MessagesEditMessageRequest{ + ID: trackMsgID, + Message: "创建任务失败: " + err.Error(), + }) + return dispatcher.EndGroups + } + if err := core.AddTask(injectCtx, task); err != nil { + logger.Errorf("add task failed: %s", err) + ctx.EditMessage(userID, &tg.MessagesEditMessageRequest{ + ID: trackMsgID, + Message: "添加任务失败: " + err.Error(), + }) + return dispatcher.EndGroups + } + text, entities := msgelem.BuildTaskAddedEntities(ctx, file.Name(), core.GetLength(injectCtx)) + ctx.EditMessage(userID, &tg.MessagesEditMessageRequest{ + ID: trackMsgID, + Message: text, + Entities: entities, + }) + + return dispatcher.EndGroups +} + +// 创建一个 batchtftask.BatchTGFileTask 并添加到任务队列中, 以编辑消息的方式反馈结果 +func CreateAndAddBatchTGFileTaskWithEdit(ctx *ext.Context, userID int64, stor storage.Storage, dirPath string, files []tfile.TGFileMessage, trackMsgID int) error { + logger := log.FromContext(ctx) + user, err := database.GetUserByChatID(ctx, userID) + if err != nil { + logger.Errorf("Failed to get user by chat ID: %s", err) + ctx.EditMessage(userID, &tg.MessagesEditMessageRequest{ + ID: trackMsgID, + Message: "获取用户失败: " + err.Error(), + }) + return dispatcher.EndGroups + } + useRule := user.ApplyRule && user.Rules != nil + applyRule := func(file tfile.TGFileMessage) (string, string) { + if !useRule { + return stor.Name(), dirPath + } + storName, dirP := ruleutil.ApplyRule(ctx, user.Rules, ruleutil.NewInput(file)) + if !storName.IsValid() { + return stor.Name(), dirP + } + return storName.String(), dirP + } + + elems := make([]batchtftask.TaskElement, 0, len(files)) + for _, file := range files { + storName, dirPath := applyRule(file) + fileStor := stor + if storName != stor.Name() && storName != "" { + fileStor, err = storage.GetStorageByUserIDAndName(ctx, user.ChatID, storName) + if err != nil { + logger.Errorf("Failed to get storage by user ID and name: %s", err) + ctx.EditMessage(userID, &tg.MessagesEditMessageRequest{ + ID: trackMsgID, + Message: "获取存储失败: " + err.Error(), + }) + return dispatcher.EndGroups + } + } + storPath := fileStor.JoinStoragePath(path.Join(dirPath, file.Name())) + elem, err := batchtftask.NewTaskElement(fileStor, storPath, file) + if err != nil { + logger.Errorf("Failed to create task element: %s", err) + ctx.EditMessage(userID, &tg.MessagesEditMessageRequest{ + ID: trackMsgID, + Message: "任务创建失败: " + err.Error(), + }) + return dispatcher.EndGroups + } + elems = append(elems, *elem) + } + injectCtx := tgutil.ExtWithContext(ctx.Context, ctx) + taskid := xid.New().String() + task := batchtftask.NewBatchTGFileTask(taskid, injectCtx, elems, ctx.Raw, batchtftask.NewProgressTracker(trackMsgID, userID), true) + if err := core.AddTask(injectCtx, task); err != nil { + logger.Errorf("Failed to add batch task: %s", err) + ctx.EditMessage(userID, &tg.MessagesEditMessageRequest{ + ID: trackMsgID, + Message: "批量任务添加失败: " + err.Error(), + }) + return dispatcher.EndGroups + } + ctx.EditMessage(userID, &tg.MessagesEditMessageRequest{ + ID: trackMsgID, + Message: fmt.Sprintf("已添加批量任务, 共 %d 个文件", len(files)), + ReplyMarkup: nil, + }) + return dispatcher.EndGroups +} diff --git a/client/bot/handlers/utils/shortcut/tphtask.go b/client/bot/handlers/utils/shortcut/tphtask.go new file mode 100644 index 0000000..4d1c1a2 --- /dev/null +++ b/client/bot/handlers/utils/shortcut/tphtask.go @@ -0,0 +1,50 @@ +package shortcut + +import ( + "github.com/celestix/gotgproto/dispatcher" + "github.com/celestix/gotgproto/ext" + "github.com/charmbracelet/log" + "github.com/gotd/td/tg" + "github.com/krau/SaveAny-Bot/client/bot/handlers/utils/msgelem" + "github.com/krau/SaveAny-Bot/common/utils/tgutil" + "github.com/krau/SaveAny-Bot/common/utils/tphutil" + "github.com/krau/SaveAny-Bot/core" + "github.com/krau/SaveAny-Bot/core/tphtask" + "github.com/krau/SaveAny-Bot/pkg/telegraph" + "github.com/krau/SaveAny-Bot/storage" + "github.com/rs/xid" +) + +func CreateAndAddTphTaskWithEdit(ctx *ext.Context, + userID int64, + tphpage *telegraph.Page, + dirPath string, // unescaped ph path for file storage + pics []string, + stor storage.Storage, + trackMsgID int) error { + injectCtx := tgutil.ExtWithContext(ctx.Context, ctx) + task := tphtask.NewTask(xid.New().String(), + injectCtx, + tphpage.Path, + pics, + stor, + stor.JoinStoragePath(dirPath), + tphutil.DefaultClient(), + tphtask.NewProgress(trackMsgID, userID), + ) + if err := core.AddTask(injectCtx, task); err != nil { + log.FromContext(ctx).Errorf("Failed to add task: %s", err) + ctx.EditMessage(userID, &tg.MessagesEditMessageRequest{ + ID: trackMsgID, + Message: "任务添加失败: " + err.Error(), + }) + return dispatcher.EndGroups + } + text, entities := msgelem.BuildTaskAddedEntities(ctx, tphpage.Title, core.GetLength(ctx)) + ctx.EditMessage(userID, &tg.MessagesEditMessageRequest{ + ID: trackMsgID, + Message: text, + Entities: entities, + }) + return dispatcher.EndGroups +} diff --git a/userclient/middlewares/middlewares.go b/client/middleware/default.go similarity index 73% rename from userclient/middlewares/middlewares.go rename to client/middleware/default.go index 3b36bef..5910ca4 100644 --- a/userclient/middlewares/middlewares.go +++ b/client/middleware/default.go @@ -1,4 +1,4 @@ -package middlewares +package middleware import ( "context" @@ -7,10 +7,11 @@ import ( "github.com/cenkalti/backoff/v4" "github.com/gotd/contrib/middleware/floodwait" "github.com/gotd/td/telegram" - "github.com/krau/SaveAny-Bot/userclient/middlewares/recovery" - "github.com/krau/SaveAny-Bot/userclient/middlewares/retry" + "github.com/krau/SaveAny-Bot/client/middleware/recovery" + "github.com/krau/SaveAny-Bot/client/middleware/retry" ) +// https://github.com/iyear/tdl/blob/master/core/tclient/tclient.go func NewDefaultMiddlewares(ctx context.Context, timeout time.Duration) []telegram.Middleware { return []telegram.Middleware{ recovery.New(ctx, newBackoff(timeout)), @@ -21,7 +22,6 @@ func NewDefaultMiddlewares(ctx context.Context, timeout time.Duration) []telegra func newBackoff(timeout time.Duration) backoff.BackOff { b := backoff.NewExponentialBackOff() - b.Multiplier = 1.1 b.MaxElapsedTime = timeout b.MaxInterval = 10 * time.Second diff --git a/client/middleware/floodwait.go b/client/middleware/floodwait.go new file mode 100644 index 0000000..4c553c4 --- /dev/null +++ b/client/middleware/floodwait.go @@ -0,0 +1,19 @@ +package middleware + +import ( + "time" + + "github.com/gotd/contrib/middleware/floodwait" + "github.com/gotd/contrib/middleware/ratelimit" + "github.com/gotd/td/telegram" + "golang.org/x/time/rate" +) + +func NewFloodWaitMiddlewares(maxRetries uint) []telegram.Middleware { + waiter := floodwait.NewSimpleWaiter().WithMaxRetries(maxRetries) + ratelimiter := ratelimit.New(rate.Every(time.Millisecond*100), 5) + return []telegram.Middleware{ + waiter, + ratelimiter, + } +} diff --git a/userclient/middlewares/recovery/recovery.go b/client/middleware/recovery/recovery.go similarity index 91% rename from userclient/middlewares/recovery/recovery.go rename to client/middleware/recovery/recovery.go index fbbd562..0b13548 100644 --- a/userclient/middlewares/recovery/recovery.go +++ b/client/middleware/recovery/recovery.go @@ -5,12 +5,12 @@ import ( "time" "github.com/cenkalti/backoff/v4" + "github.com/charmbracelet/log" "github.com/go-faster/errors" "github.com/gotd/td/bin" "github.com/gotd/td/telegram" "github.com/gotd/td/tg" "github.com/gotd/td/tgerr" - "github.com/krau/SaveAny-Bot/common" ) type recovery struct { @@ -39,7 +39,7 @@ func (r *recovery) Handle(next tg.Invoker) telegram.InvokeFunc { return nil }, r.backoff, func(err error, duration time.Duration) { - common.Log.Debug("Wait for connection recovery", "error", err, "duration", duration) + log.FromContext(ctx).Debug("Wait for connection recovery", "error", err, "duration", duration) }) } } diff --git a/userclient/middlewares/retry/retry.go b/client/middleware/retry/retry.go similarity index 90% rename from userclient/middlewares/retry/retry.go rename to client/middleware/retry/retry.go index 325531d..eb66628 100644 --- a/userclient/middlewares/retry/retry.go +++ b/client/middleware/retry/retry.go @@ -4,12 +4,12 @@ import ( "context" "fmt" + "github.com/charmbracelet/log" "github.com/go-faster/errors" "github.com/gotd/td/bin" "github.com/gotd/td/telegram" "github.com/gotd/td/tg" "github.com/gotd/td/tgerr" - "github.com/krau/SaveAny-Bot/common" ) var internalErrors = []string{ @@ -33,7 +33,7 @@ func (r retry) Handle(next tg.Invoker) telegram.InvokeFunc { for retries < r.max { if err := next.Invoke(ctx, input, output); err != nil { if tgerr.Is(err, r.errors...) { - common.Log.Debug("retry middleware", "retries", retries, "error", err) + log.FromContext(ctx).Debug("retry middleware", "retries", retries, "error", err) retries++ continue } diff --git a/userclient/auth.go b/client/user/auth.go similarity index 99% rename from userclient/auth.go rename to client/user/auth.go index d20eb5b..048a5e4 100644 --- a/userclient/auth.go +++ b/client/user/auth.go @@ -1,4 +1,4 @@ -package userclient +package user import ( "strings" diff --git a/userclient/userclient.go b/client/user/userclient.go similarity index 65% rename from userclient/userclient.go rename to client/user/userclient.go index b3b9e25..004e4e5 100644 --- a/userclient/userclient.go +++ b/client/user/userclient.go @@ -1,4 +1,4 @@ -package userclient +package user import ( "context" @@ -7,10 +7,10 @@ import ( "github.com/celestix/gotgproto" "github.com/celestix/gotgproto/ext" "github.com/celestix/gotgproto/sessionMaker" - "github.com/glebarez/sqlite" - "github.com/krau/SaveAny-Bot/common" + "github.com/charmbracelet/log" + "github.com/krau/SaveAny-Bot/client/middleware" "github.com/krau/SaveAny-Bot/config" - "github.com/krau/SaveAny-Bot/userclient/middlewares" + "github.com/ncruces/go-sqlite3/gormlite" ) var UC *gotgproto.Client @@ -26,7 +26,7 @@ func GetCtx() *ext.Context { } func Login(ctx context.Context) (*gotgproto.Client, error) { - common.Log.Debug("Logging in as user client") + log.FromContext(ctx).Debug("Logging in as user client") if UC != nil { return UC, nil } @@ -40,11 +40,11 @@ func Login(ctx context.Context) (*gotgproto.Client, error) { config.Cfg.Telegram.AppHash, gotgproto.ClientTypePhone(""), &gotgproto.ClientOpts{ - Session: sessionMaker.SqlSession(sqlite.Open(config.Cfg.Telegram.Userbot.Session)), - AuthConversator: &termialAuthConversator{}, - // Context: ctx, + Session: sessionMaker.SqlSession(gormlite.Open(config.Cfg.Telegram.Userbot.Session)), + AuthConversator: &termialAuthConversator{}, + Context: ctx, DisableCopyright: true, - Middlewares: middlewares.NewDefaultMiddlewares(ctx, 5*time.Minute), + Middlewares: middleware.NewDefaultMiddlewares(ctx, 5*time.Minute), }, ) if err != nil { @@ -70,10 +70,6 @@ func Login(ctx context.Context) (*gotgproto.Client, error) { return nil, r.err } UC = r.client - // disp := UC.Dispatcher - // disp.AddHandler(handlers.NewAnyUpdate(func(ctx *ext.Context, u *ext.Update) error { - // return dispatcher.EndGroups - // })) return UC, nil } } diff --git a/cmd/geni18n/main.go b/cmd/geni18n/main.go index 91d6e95..1c12130 100644 --- a/cmd/geni18n/main.go +++ b/cmd/geni18n/main.go @@ -1,4 +1,4 @@ -// cmd/gen_i18n/main.go +// cmd/geni18n/main.go package main import ( @@ -14,8 +14,8 @@ import ( ) func main() { - dir := flag.String("dir", "./i18n/locale", "Locales directory path") - out := flag.String("out", "i18n/i18nk/keys.go", "Output file path") + dir := flag.String("dir", "./common/i18n/locale", "Locales directory path") + out := flag.String("out", "common/i18n/i18nk/keys.go", "Output file path") pkg := flag.String("pkg", "i18nk", "Package name for generated file") flag.Parse() diff --git a/cmd/root.go b/cmd/root.go index 79b84eb..4f4f334 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -1,6 +1,7 @@ package cmd import ( + "context" "fmt" "github.com/spf13/cobra" @@ -12,8 +13,8 @@ var rootCmd = &cobra.Command{ Run: Run, } -func Execute() { - if err := rootCmd.Execute(); err != nil { +func Execute(ctx context.Context) { + if err := rootCmd.ExecuteContext(ctx); err != nil { fmt.Println(err) } } diff --git a/cmd/run.go b/cmd/run.go index a2d76e2..5a9a071 100644 --- a/cmd/run.go +++ b/cmd/run.go @@ -4,46 +4,78 @@ import ( "context" "fmt" "os" - "os/signal" "path/filepath" - "syscall" + "time" "slices" - "github.com/krau/SaveAny-Bot/bot" - "github.com/krau/SaveAny-Bot/common" + "github.com/charmbracelet/log" + "github.com/krau/SaveAny-Bot/client/bot" + userclient "github.com/krau/SaveAny-Bot/client/user" + "github.com/krau/SaveAny-Bot/common/i18n" + "github.com/krau/SaveAny-Bot/common/i18n/i18nk" + "github.com/krau/SaveAny-Bot/common/utils/fsutil" "github.com/krau/SaveAny-Bot/config" "github.com/krau/SaveAny-Bot/core" - "github.com/krau/SaveAny-Bot/dao" - "github.com/krau/SaveAny-Bot/i18n" - "github.com/krau/SaveAny-Bot/i18n/i18nk" + "github.com/krau/SaveAny-Bot/database" "github.com/krau/SaveAny-Bot/storage" - "github.com/krau/SaveAny-Bot/userclient" "github.com/spf13/cobra" ) -func Run(_ *cobra.Command, _ []string) { - InitAll() - core.Run() +func Run(cmd *cobra.Command, _ []string) { + ctx := cmd.Context() + logger := log.NewWithOptions(os.Stdout, log.Options{ + Level: log.DebugLevel, + ReportTimestamp: true, + TimeFormat: time.TimeOnly, + ReportCaller: true, + }) + ctx = log.WithContext(ctx, logger) - quit := make(chan os.Signal, 1) - signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM) - sig := <-quit - common.Log.Info(sig, i18n.T(i18nk.Exiting)) - defer common.Log.Info(i18n.T(i18nk.Bye)) + initAll(ctx) + core.Run(ctx) + + <-ctx.Done() + logger.Info(i18n.T(i18nk.Exiting)) + defer logger.Info(i18n.T(i18nk.Bye)) + cleanCache() +} + +func initAll(ctx context.Context) { + if err := config.Init(ctx); err != nil { + fmt.Println("Failed to load config:", err) + os.Exit(1) + } + logger := log.FromContext(ctx) + i18n.Init(config.Cfg.Lang) + logger.Info(i18n.T(i18nk.Initing)) + if config.Cfg.Telegram.Userbot.Enable { + uc, err := userclient.Login(ctx) + if err != nil { + logger.Fatalf("User client login failed: %s", err) + } + logger.Infof("User client logged in as %s", uc.Self.FirstName) + } + database.Init(ctx) + storage.LoadStorages(ctx) + + bot.Init(ctx) +} + +func cleanCache() { if config.Cfg.NoCleanCache { return } if config.Cfg.Temp.BasePath != "" && !config.Cfg.Stream { if slices.Contains([]string{"/", ".", "\\", ".."}, filepath.Clean(config.Cfg.Temp.BasePath)) { - common.Log.Error(i18n.T(i18nk.InvalidCacheDir, map[string]any{ + log.Error(i18n.T(i18nk.InvalidCacheDir, map[string]any{ "Path": config.Cfg.Temp.BasePath, })) return } currentDir, err := os.Getwd() if err != nil { - common.Log.Error(i18n.T(i18nk.GetWorkdirFailed, map[string]any{ + log.Error(i18n.T(i18nk.GetWorkdirFailed, map[string]any{ "Error": err, })) return @@ -51,42 +83,18 @@ func Run(_ *cobra.Command, _ []string) { cachePath := filepath.Join(currentDir, config.Cfg.Temp.BasePath) cachePath, err = filepath.Abs(cachePath) if err != nil { - common.Log.Error(i18n.T(i18nk.GetCacheAbsPathFailed, map[string]any{ + log.Error(i18n.T(i18nk.GetCacheAbsPathFailed, map[string]any{ "Error": err, })) return } - common.Log.Info(i18n.T(i18nk.CleaningCache, map[string]any{ + log.Info(i18n.T(i18nk.CleaningCache, map[string]any{ "Path": cachePath, })) - if err := common.RemoveAllInDir(cachePath); err != nil { - common.Log.Error(i18n.T(i18nk.CleanCacheFailed, map[string]any{ + if err := fsutil.RemoveAllInDir(cachePath); err != nil { + log.Error(i18n.T(i18nk.CleanCacheFailed, map[string]any{ "Error": err, })) } } } - -func InitAll() { - if err := config.Init(); err != nil { - fmt.Println("Failed to load config:", err) - os.Exit(1) - } - common.InitLogger() - i18n.Init(config.Cfg.Lang) - common.Log.Info(i18n.T(i18nk.Initing)) - dao.Init() - storage.LoadStorages() - common.Init() - if config.Cfg.Telegram.Userbot.Enable { - ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM) - defer cancel() - uc, err := userclient.Login(ctx) - if err != nil { - common.Log.Errorf("User client login failed: %s", err) - os.Exit(1) - } - common.Log.Infof("User client logged in as %s", uc.Self.FirstName) - } - bot.Init() -} diff --git a/cmd/version.go b/cmd/version.go index 5dfb7ab..4061093 100644 --- a/cmd/version.go +++ b/cmd/version.go @@ -4,7 +4,7 @@ import ( "fmt" "runtime" - "github.com/krau/SaveAny-Bot/common" + "github.com/krau/SaveAny-Bot/pkg/consts" "github.com/rhysd/go-github-selfupdate/selfupdate" "github.com/blang/semver" @@ -16,7 +16,7 @@ var VersionCmd = &cobra.Command{ Aliases: []string{"v"}, Short: "Print the version number of saveany-bot", Run: func(cmd *cobra.Command, args []string) { - fmt.Printf("saveany-bot version: %s %s/%s\nBuildTime: %s, Commit: %s\n", common.Version, runtime.GOOS, runtime.GOARCH, common.BuildTime, common.GitCommit) + fmt.Printf("saveany-bot version: %s %s/%s\nBuildTime: %s, Commit: %s\n", consts.Version, runtime.GOOS, runtime.GOARCH, consts.BuildTime, consts.GitCommit) }, } @@ -25,14 +25,14 @@ var upgradeCmd = &cobra.Command{ Aliases: []string{"up"}, Short: "Upgrade saveany-bot to the latest version", Run: func(cmd *cobra.Command, args []string) { - v := semver.MustParse(common.Version) + v := semver.MustParse(consts.Version) latest, err := selfupdate.UpdateSelf(v, "krau/SaveAny-Bot") if err != nil { fmt.Println("Binary update failed:", err) return } if latest.Version.Equals(v) { - fmt.Println("Current binary is the latest version", common.Version) + fmt.Println("Current binary is the latest version", consts.Version) } else { fmt.Println("Successfully updated to version", latest.Version) fmt.Println("Release note:\n", latest.ReleaseNotes) diff --git a/common/cache.go b/common/cache.go deleted file mode 100644 index 5a317d2..0000000 --- a/common/cache.go +++ /dev/null @@ -1,38 +0,0 @@ -package common - -import ( - "context" - "time" - - "github.com/eko/gocache/lib/v4/cache" - gocachestore "github.com/eko/gocache/store/go_cache/v4" - gocache "github.com/patrickmn/go-cache" -) - -var Cache *cache.Cache[any] - -func initCache() { - gocacheClient := gocache.New(time.Hour*1, time.Minute*10) - gocacheStore := gocachestore.NewGoCache(gocacheClient) - cacheManager := cache.New[any](gocacheStore) - Cache = cacheManager -} - -func CacheGet[T any](ctx context.Context, key string) (T, error) { - data, err := Cache.Get(ctx, key) - if err != nil { - return *new(T), err - } - if v, ok := data.(T); ok { - return v, nil - } - return *new(T), nil -} - -func CacheSet(ctx context.Context, key string, value any) error { - return Cache.Set(ctx, key, value) -} - -func CacheDelete(ctx context.Context, key string) error { - return Cache.Delete(ctx, key) -} diff --git a/common/cache/ristretto.go b/common/cache/ristretto.go new file mode 100644 index 0000000..978c370 --- /dev/null +++ b/common/cache/ristretto.go @@ -0,0 +1,50 @@ +package cache + +import ( + "fmt" + + "github.com/charmbracelet/log" + "github.com/dgraph-io/ristretto/v2" +) + +var cache *ristretto.Cache[string, any] + + +// TODO: maybe we should use simple ttl cache instead of ristretto... +func init() { + c, err := ristretto.NewCache(&ristretto.Config[string, any]{ + NumCounters: 1e5, + MaxCost: 1e6, // 1000000 / 112 ≈ 8928 + BufferItems: 64, + OnReject: func(item *ristretto.Item[any]) { + log.Warnf("Cache item rejected: key=%d, value=%v", item.Key, item.Value) + }, + }) + if err != nil { + log.Fatalf("failed to create ristretto cache: %v", err) + } + cache = c +} + +func Set(key string, value any) error { + ok := cache.Set(key, value, 0) + if !ok { + return fmt.Errorf("failed to set value in cache") + } + cache.Wait() + return nil +} + +func Get[T any](key string) (T, bool) { + v, ok := cache.Get(key) + if !ok { + var zero T + return zero, false + } + vT, ok := v.(T) + if !ok { + var zero T + return zero, false + } + return vT, true +} diff --git a/common/common.go b/common/common.go deleted file mode 100644 index f15d177..0000000 --- a/common/common.go +++ /dev/null @@ -1,5 +0,0 @@ -package common - -func Init() { - initCache() -} diff --git a/i18n/i18n.go b/common/i18n/i18n.go similarity index 100% rename from i18n/i18n.go rename to common/i18n/i18n.go diff --git a/i18n/i18nk/keys.go b/common/i18n/i18nk/keys.go similarity index 100% rename from i18n/i18nk/keys.go rename to common/i18n/i18nk/keys.go diff --git a/i18n/locale/zh-Hans.toml b/common/i18n/locale/zh-Hans.toml similarity index 100% rename from i18n/locale/zh-Hans.toml rename to common/i18n/locale/zh-Hans.toml diff --git a/common/logger.go b/common/logger.go deleted file mode 100644 index b9b981c..0000000 --- a/common/logger.go +++ /dev/null @@ -1,43 +0,0 @@ -package common - -import ( - "github.com/gookit/slog" - "github.com/gookit/slog/handler" - "github.com/gookit/slog/rotatefile" - "github.com/krau/SaveAny-Bot/config" -) - -var Log *slog.Logger - -func InitLogger() { - if Log != nil { - return - } - Log = slog.New() - logLevel := slog.LevelByName(config.Cfg.Log.Level) - logFilePath := config.Cfg.Log.File - logBackupNum := config.Cfg.Log.BackupCount - var logLevels []slog.Level - for _, level := range slog.AllLevels { - if level <= logLevel { - logLevels = append(logLevels, level) - } - } - tem := "[{{datetime}}] [{{level}}] [{{caller}}] {{message}} {{data}} {{extra}}\n" - consoleH := handler.NewConsoleHandler(logLevels) - consoleH.Formatter().(*slog.TextFormatter).SetTemplate(tem) - Log.AddHandler(consoleH) - if logFilePath != "" && logBackupNum > 0 { - fileH, err := handler.NewTimeRotateFile( - logFilePath, - rotatefile.EveryDay, - handler.WithLogLevels(slog.AllLevels), - handler.WithBackupNum(logBackupNum), - ) - fileH.Formatter().(*slog.TextFormatter).SetTemplate(tem) - if err != nil { - panic(err) - } - Log.AddHandler(fileH) - } -} diff --git a/common/os.go b/common/os.go deleted file mode 100644 index 6003b44..0000000 --- a/common/os.go +++ /dev/null @@ -1,48 +0,0 @@ -package common - -import ( - "os" - "path/filepath" - "time" - - "github.com/krau/SaveAny-Bot/i18n" - "github.com/krau/SaveAny-Bot/i18n/i18nk" -) - -func RmFileAfter(path string, td time.Duration) { - _, err := os.Stat(path) - if err != nil { - Log.Errorf(i18n.T(i18nk.CreateRmTimerFailed, map[string]any{ - "Path": path, - "Error": err, - })) - return - } - Log.Debugf(i18n.T(i18nk.RemoveFileAfter, map[string]any{ - "Duration": td.String(), - "Path": path, - })) - time.AfterFunc(td, func() { - if err := os.Remove(path); err != nil { - Log.Errorf(i18n.T(i18nk.RemoveFileFailed, map[string]any{ - "Path": path, - "Error": err, - })) - } - }) -} - -// 删除目录下的所有内容, 但不删除目录本身 -func RemoveAllInDir(dirPath string) error { - entries, err := os.ReadDir(dirPath) - if err != nil { - return err - } - for _, entry := range entries { - entryPath := filepath.Join(dirPath, entry.Name()) - if err := os.RemoveAll(entryPath); err != nil { - return err - } - } - return nil -} diff --git a/common/tdler/dler.go b/common/tdler/dler.go new file mode 100644 index 0000000..88d51e8 --- /dev/null +++ b/common/tdler/dler.go @@ -0,0 +1,18 @@ +package tdler + +import ( + "github.com/gotd/td/telegram/downloader" + "github.com/krau/SaveAny-Bot/common/utils/dlutil" + "github.com/krau/SaveAny-Bot/config" + "github.com/krau/SaveAny-Bot/pkg/consts/tglimit" + "github.com/krau/SaveAny-Bot/pkg/tfile" +) + +type Client interface { + downloader.Client +} + +func NewDownloader(client Client, file tfile.TGFile) *downloader.Builder { + return downloader.NewDownloader().WithPartSize(tglimit.MaxPartSize). + Download(client, file.Location()).WithThreads(dlutil.BestThreads(file.Size(), config.Cfg.Threads)) +} diff --git a/common/utils.go b/common/utils.go deleted file mode 100644 index 4bc7814..0000000 --- a/common/utils.go +++ /dev/null @@ -1,26 +0,0 @@ -package common - -import ( - "crypto/md5" - "encoding/hex" - "regexp" -) - -func HashString(s string) string { - hash := md5.New() - hash.Write([]byte(s)) - return hex.EncodeToString(hash.Sum(nil)) -} - -var TagRe = regexp.MustCompile(`(?:^|[\p{Zs}\s.,!?(){}[\]<>\"\',。!?():;、])#([\p{L}\d_]+)`) - -func ExtractTagsFromText(text string) []string { - matches := TagRe.FindAllStringSubmatch(text, -1) - tags := make([]string, 0) - for _, match := range matches { - if len(match) > 1 { - tags = append(tags, match[1]) - } - } - return tags -} diff --git a/common/utils/dlutil/dl.go b/common/utils/dlutil/dl.go new file mode 100644 index 0000000..9ad027f --- /dev/null +++ b/common/utils/dlutil/dl.go @@ -0,0 +1,33 @@ +package dlutil + +import "time" + +var threadsLevels = []struct { + threads int + size int64 +}{ + {1, 10 << 20}, + {2, 50 << 20}, + {4, 200 << 20}, + {8, 500 << 20}, +} + +func BestThreads(size int64, max int) int { + for _, thread := range threadsLevels { + if size < thread.size { + return min(thread.threads, max) + } + } + return max +} + +func GetSpeed(downloaded int64, startTime time.Time) float64 { + if startTime.IsZero() { + return 0 + } + elapsed := time.Since(startTime).Seconds() + if elapsed <= 0 { + return 0 + } + return float64(downloaded) / elapsed +} diff --git a/common/utils/fsutil/fs.go b/common/utils/fsutil/fs.go new file mode 100644 index 0000000..9d451c5 --- /dev/null +++ b/common/utils/fsutil/fs.go @@ -0,0 +1,57 @@ +package fsutil + +import ( + "os" + "path/filepath" + + "github.com/gabriel-vasile/mimetype" +) + +// 删除文件夹内的所有文件和子目录, 但不删除文件夹本身 +func RemoveAllInDir(dirPath string) error { + entries, err := os.ReadDir(dirPath) + if err != nil { + return err + } + for _, entry := range entries { + entryPath := filepath.Join(dirPath, entry.Name()) + if err := os.RemoveAll(entryPath); err != nil { + return err + } + } + return nil +} + +func DetectFileExt(fp string) string { + mt, err := mimetype.DetectFile(fp) + if err != nil { + return "" + } + return mt.Extension() +} + +type File struct { + *os.File +} + +func (f *File) Remove() error { + return os.Remove(f.Name()) +} + +func (f *File) CloseAndRemove() error { + if err := f.Close(); err != nil { + return err + } + return f.Remove() +} + +func CreateFile(fp string) (*File, error) { + if err := os.MkdirAll(filepath.Dir(fp), os.ModePerm); err != nil { + return nil, err + } + file, err := os.Create(fp) + if err != nil { + return nil, err + } + return &File{File: file}, nil +} diff --git a/common/utils/ioutil/writer.go b/common/utils/ioutil/writer.go new file mode 100644 index 0000000..f34833a --- /dev/null +++ b/common/utils/ioutil/writer.go @@ -0,0 +1,49 @@ +package ioutil + +import "io" + +type ProgressWriterAt struct { + wrAt io.WriterAt + onWrite func(n int) +} + +func (p *ProgressWriterAt) WriteAt(buf []byte, off int64) (n int, err error) { + n, err = p.wrAt.WriteAt(buf, off) + if n > 0 { + p.onWrite(n) + } + return +} + +func NewProgressWriterAt( + wrAt io.WriterAt, + onWrite func(n int), +) *ProgressWriterAt { + return &ProgressWriterAt{ + wrAt: wrAt, + onWrite: onWrite, + } +} + +type ProgressWriter struct { + wr io.Writer + onWrite func(n int) +} + +func (p *ProgressWriter) Write(buf []byte) (n int, err error) { + n, err = p.wr.Write(buf) + if n > 0 { + p.onWrite(n) + } + return +} + +func NewProgressWriter( + wr io.Writer, + onWrite func(n int), +) *ProgressWriter { + return &ProgressWriter{ + wr: wr, + onWrite: onWrite, + } +} diff --git a/common/utils/strutil/string.go b/common/utils/strutil/string.go new file mode 100644 index 0000000..8eb7f1f --- /dev/null +++ b/common/utils/strutil/string.go @@ -0,0 +1,50 @@ +package strutil + +import ( + "crypto/md5" + "encoding/hex" + "fmt" + "regexp" + "strconv" + "strings" + + "github.com/duke-git/lancet/v2/slice" +) + +func HashString(s string) string { + hash := md5.New() + hash.Write([]byte(s)) + return hex.EncodeToString(hash.Sum(nil)) +} + +var TagRe = regexp.MustCompile(`(?:^|[\p{Zs}\s.,!?(){}[\]<>\"\',。!?():;、])#([\p{L}\d_]+)`) + +func ExtractTagsFromText(text string) []string { + matches := TagRe.FindAllStringSubmatch(text, -1) + tags := make([]string, 0) + for _, match := range matches { + if len(match) > 1 { + tags = append(tags, match[1]) + } + } + return slice.Compact(tags) +} + +func ParseIntStrRange(input string, sep string) (int64, int64, error) { + parts := strings.Split(input, sep) + if len(parts) != 2 { + return 0, 0, fmt.Errorf("invalid range format: %s", input) + } + min, err := strconv.ParseInt(strings.TrimSpace(parts[0]), 10, 64) + if err != nil { + return 0, 0, fmt.Errorf("invalid minimum value: %s", parts[0]) + } + max, err := strconv.ParseInt(strings.TrimSpace(parts[1]), 10, 64) + if err != nil { + return 0, 0, fmt.Errorf("invalid maximum value: %s", parts[1]) + } + if min > max { + min, max = max, min + } + return min, max, nil +} diff --git a/common/utils/tgutil/context.go b/common/utils/tgutil/context.go new file mode 100644 index 0000000..a0ce1f4 --- /dev/null +++ b/common/utils/tgutil/context.go @@ -0,0 +1,22 @@ +package tgutil + +import ( + "context" + + "github.com/celestix/gotgproto/ext" +) + +type contextKey struct{} + +var extKey = contextKey{} + +func ExtFromContext(ctx context.Context) *ext.Context { + if extCtx, ok := ctx.Value(extKey).(*ext.Context); ok { + return extCtx + } + return nil +} + +func ExtWithContext(ctx context.Context, extCtx *ext.Context) context.Context { + return context.WithValue(ctx, extKey, extCtx) +} diff --git a/common/utils/tgutil/message.go b/common/utils/tgutil/message.go new file mode 100644 index 0000000..4b7da85 --- /dev/null +++ b/common/utils/tgutil/message.go @@ -0,0 +1,183 @@ +package tgutil + +import ( + "fmt" + "strconv" + "strings" + + "github.com/celestix/gotgproto/ext" + "github.com/duke-git/lancet/v2/maputil" + "github.com/duke-git/lancet/v2/mathutil" + "github.com/duke-git/lancet/v2/slice" + lcstrutil "github.com/duke-git/lancet/v2/strutil" + "github.com/duke-git/lancet/v2/validator" + "github.com/gabriel-vasile/mimetype" + "github.com/gotd/td/tg" + "github.com/krau/SaveAny-Bot/common/cache" + "github.com/krau/SaveAny-Bot/common/utils/strutil" + "github.com/rs/xid" +) + +func GenFileNameFromMessage(message tg.Message) string { + ext := func(media tg.MessageMediaClass) string { + switch media := media.(type) { + case *tg.MessageMediaDocument: + doc, ok := media.Document.AsNotEmpty() + if !ok { + return "" + } + ext := mimetype.Lookup(doc.MimeType).Extension() + if ext == "" { + return "" + } + return ext + case *tg.MessageMediaPhoto: + return ".jpg" + } + return "" + }(message.Media) + text := strings.TrimSpace(message.GetMessage()) + if text == "" { + return fmt.Sprintf("%d_%s%s", message.GetID(), xid.New().String(), ext) + } + filename := func() string { + tags := strutil.ExtractTagsFromText(text) + if len(tags) > 0 { + tagStrRunes := make([]rune, 0, 64) + for i, tag := range tags { + if i > 0 { + tagStrRunes = append(tagStrRunes, '_') + } + tagStrRunes = append(tagStrRunes, []rune(tag)...) + if len(tagStrRunes) >= 64 { + break + } + } + tagStr := string(tagStrRunes) + return fmt.Sprintf("%s_%s", tagStr, strconv.Itoa(message.GetID())) + } + text = lcstrutil.Substring(strings.Map(func(r rune) rune { + if r < 0x20 || r == 0x7F { + return '_' + } + switch r { + // invalid characters + case '/', '\\', + ':', '*', '?', '"', '<', '>', '|': + return '_' + // empty + case ' ', '\t', '\r', '\n': + return '_' + } + if validator.IsPrintable(string(r)) { + return r + } + return '_' + }, text), 0, 64) + text = strings.Join(strings.FieldsFunc(text, func(r rune) bool { + return r == '_' || r == ' ' + }), "_") + return text + }() + + if filename == "" { + filename = fmt.Sprintf("%d_%s", message.GetID(), xid.New().String()) + } + return filename + ext +} + +func BuildCancelButton(taskID string) tg.KeyboardButtonClass { + return &tg.KeyboardButtonCallback{ + Text: "取消任务", + Data: fmt.Appendf(nil, "cancel %s", taskID), + } +} + +func InputMessageClassSliceFromInt(ids []int) []tg.InputMessageClass { + result := make([]tg.InputMessageClass, 0, len(ids)) + for _, id := range ids { + result = append(result, &tg.InputMessageID{ + ID: id, + }) + } + return result +} + +func GetMessagesRange(ctx *ext.Context, chatID int64, minId, maxId int) ([]*tg.Message, error) { + if minId > maxId { + return nil, fmt.Errorf("minId (%d) cannot be greater than maxId (%d)", minId, maxId) + } + total := maxId - minId + 1 + msgIds := mathutil.Range(minId, total) + toFetchIds := make([]int, 0, total) + cached := make(map[int]*tg.Message, total) + for _, id := range msgIds { + if msg, ok := cache.Get[*tg.Message](fmt.Sprintf("tgmsg:%d:%d:%d", ctx.Self.ID, chatID, id)); ok { + cached[id] = msg + } else { + toFetchIds = append(toFetchIds, id) + } + } + if len(toFetchIds) == 0 { + return maputil.Values(cached), nil + } + + result := make([]*tg.Message, 0, total) + chunks := slice.Chunk(toFetchIds, 100) + for _, chunk := range chunks { + msgs, err := ctx.GetMessages(chatID, InputMessageClassSliceFromInt(chunk)) + if err != nil { + return nil, err + } + if len(msgs) == 0 { + continue + } + for _, msg := range msgs { + if msg == nil { + continue + } + tgMessage, ok := msg.(*tg.Message) + if !ok { + continue + } + if tgMessage.GetID() < minId || tgMessage.GetID() > maxId { + continue + } + result = append(result, tgMessage) + } + } + + for _, msg := range result { + cache.Set(fmt.Sprintf("tgmsg:%d:%d:%d", ctx.Self.ID, chatID, msg.GetID()), msg) + } + for _, msg := range cached { + if msg == nil { + continue + } + result = append(result, msg) + } + return result, nil +} + +func GetMessageByID(ctx *ext.Context, chatID int64, msgID int) (*tg.Message, error) { + key := fmt.Sprintf("tgmsg:%d:%d:%d", ctx.Self.ID, chatID, msgID) + if msg, ok := cache.Get[*tg.Message](key); ok { + return msg, nil + } + msgs, err := ctx.GetMessages(chatID, []tg.InputMessageClass{ + &tg.InputMessageID{ID: msgID}, + }) + if err != nil { + return nil, fmt.Errorf("failed to get message by ID: %w", err) + } + if len(msgs) == 0 { + return nil, fmt.Errorf("message not found: chatID=%d, msgID=%d", chatID, msgID) + } + msg := msgs[0] + tgm, ok := msg.(*tg.Message) + if !ok { + return nil, fmt.Errorf("unexpected message type: %T", msg) + } + cache.Set(key, tgm) + return tgm, nil +} diff --git a/common/utils/tgutil/resolve.go b/common/utils/tgutil/resolve.go new file mode 100644 index 0000000..9fca6eb --- /dev/null +++ b/common/utils/tgutil/resolve.go @@ -0,0 +1,119 @@ +package tgutil + +import ( + "fmt" + "net/url" + "strconv" + "strings" + + "github.com/celestix/gotgproto/ext" + "github.com/duke-git/lancet/v2/validator" + "github.com/gotd/td/tg" +) + +func ParseChatID(ctx *ext.Context, idOrUsername string) (int64, error) { + idOrUsername = strings.TrimPrefix(idOrUsername, "@") + if validator.IsIntStr(idOrUsername) { + chatID, err := strconv.Atoi(idOrUsername) + if err != nil { + return 0, err + } + return int64(chatID), nil + } + username := idOrUsername + peer := ctx.PeerStorage.GetPeerByUsername(username) + if peer != nil && peer.ID != 0 { + return peer.ID, nil + } + chat, err := ctx.ResolveUsername(username) + if err != nil { + return 0, err + } + if chat == nil { + return 0, fmt.Errorf("no chat found for username: %s", idOrUsername) + } + chatID := chat.GetID() + if chatID == 0 { + return 0, fmt.Errorf("chat ID is zero for username: %s", idOrUsername) + } + return chatID, nil +} + +// return: ChatID, MessageID, error +func ParseMessageLink(ctx *ext.Context, link string) (int64, int, error) { + u, err := url.Parse(link) + if err != nil { + return 0, 0, fmt.Errorf("invalid URL: %w", err) + } + paths := strings.Split(strings.TrimPrefix(u.Path, "/"), "/") + + if cmt := u.Query().Get("comment"); cmt != "" { + // 频道评论的消息链接 + // https://t.me/acherkrau/123?comment=2 + chid, err := ParseChatID(ctx, paths[0]) + if err != nil { + return 0, 0, fmt.Errorf("failed to parse chat ID: %w", err) + } + chatfull, err := ctx.GetChat(chid) + if err != nil { + return 0, 0, fmt.Errorf("failed to get chat: %w", err) + } + chfull, ok := chatfull.(*tg.ChannelFull) + if !ok { + return 0, 0, fmt.Errorf("chat is not a channel: %s", chatfull.TypeName()) + } + linkChatId, ok := chfull.GetLinkedChatID() + if !ok { + return 0, 0, fmt.Errorf("channel has no linked chat") + } + msgID, err := strconv.Atoi(cmt) + if err != nil { + return 0, 0, fmt.Errorf("failed to parse comment ID: %w", err) + } + return linkChatId, msgID, nil + } + + switch len(paths) { + case 2: // https://t.me/acherkrau/123 + chatID, err := ParseChatID(ctx, paths[0]) + if err != nil { + return 0, 0, fmt.Errorf("failed to parse chat ID: %w", err) + } + msgID, err := strconv.Atoi(paths[1]) + if err != nil { + return 0, 0, fmt.Errorf("failed to parse message ID: %w", err) + } + return chatID, msgID, nil + case 3: + // https://t.me/c/123456789/123 + // https://t.me/acherkrau/123/456 , 456: message thread ID + chatPart, msgPart := paths[1], paths[2] + if paths[0] != "c" { + chatPart = paths[0] + } + chatID, err := ParseChatID(ctx, chatPart) + if err != nil { + return 0, 0, fmt.Errorf("failed to parse chat ID: %w", err) + } + msgID, err := strconv.Atoi(msgPart) + if err != nil { + return 0, 0, fmt.Errorf("failed to parse message ID: %w", err) + } + return chatID, msgID, nil + case 4: + // https://t.me/c/123456789/111/456 111: topic id + if paths[0] != "c" { + return 0, 0, fmt.Errorf("invalid message link format: %s", link) + } + chatID, err := ParseChatID(ctx, paths[1]) + if err != nil { + return 0, 0, fmt.Errorf("failed to parse chat ID: %w", err) + } + msgID, err := strconv.Atoi(paths[3]) + if err != nil { + return 0, 0, fmt.Errorf("failed to parse message ID: %w", err) + } + return chatID, msgID, nil + } + return 0, 0, fmt.Errorf("invalid message link format: %s", link) +} diff --git a/common/utils/tphutil/tph.go b/common/utils/tphutil/tph.go new file mode 100644 index 0000000..86df916 --- /dev/null +++ b/common/utils/tphutil/tph.go @@ -0,0 +1,51 @@ +package tphutil + +import ( + "encoding/json" + + "github.com/krau/SaveAny-Bot/config" + "github.com/krau/SaveAny-Bot/pkg/telegraph" +) + +var tphClient *telegraph.Client + +func DefaultClient() *telegraph.Client { + if tphClient != nil { + return tphClient + } + if config.Cfg.Telegram.Proxy.Enable && config.Cfg.Telegram.Proxy.URL != "" { + proxyUrl := config.Cfg.Telegram.Proxy.URL + var err error + tphClient, err = telegraph.NewClientWithProxy(proxyUrl) + if err != nil { + tphClient = telegraph.NewClient() + } + } else { + tphClient = telegraph.NewClient() + } + return tphClient +} + +func GetNodeImages(node telegraph.Node) []string { + var srcs []string + + var nodeElement telegraph.NodeElement + data, err := json.Marshal(node) + if err != nil { + return srcs + } + err = json.Unmarshal(data, &nodeElement) + if err != nil { + return srcs + } + + if nodeElement.Tag == "img" { + if src, exists := nodeElement.Attrs["src"]; exists { + srcs = append(srcs, src) + } + } + for _, child := range nodeElement.Children { + srcs = append(srcs, GetNodeImages(child)...) + } + return srcs +} diff --git a/common/version.go b/common/version.go deleted file mode 100644 index 4d889ed..0000000 --- a/common/version.go +++ /dev/null @@ -1,7 +0,0 @@ -package common - -var ( - Version string = "dev" - BuildTime string = "unknown" - GitCommit string = "unknown" -) diff --git a/config/storage/alist.go b/config/storage/alist.go index b946bf5..9449b24 100644 --- a/config/storage/alist.go +++ b/config/storage/alist.go @@ -3,7 +3,7 @@ package storage import ( "fmt" - "github.com/krau/SaveAny-Bot/types" + storenum "github.com/krau/SaveAny-Bot/pkg/enums/storage" ) type AlistStorageConfig struct { @@ -29,8 +29,8 @@ func (a *AlistStorageConfig) Validate() error { return nil } -func (a *AlistStorageConfig) GetType() types.StorageType { - return types.StorageTypeAlist +func (a *AlistStorageConfig) GetType() storenum.StorageType { + return storenum.Alist } func (a *AlistStorageConfig) GetName() string { diff --git a/config/storage/factory.go b/config/storage/factory.go index 6dd8202..5d81dbc 100644 --- a/config/storage/factory.go +++ b/config/storage/factory.go @@ -4,16 +4,17 @@ import ( "fmt" "reflect" - "github.com/krau/SaveAny-Bot/types" + storenum "github.com/krau/SaveAny-Bot/pkg/enums/storage" "github.com/mitchellh/mapstructure" "github.com/spf13/viper" ) -var storageFactories = map[types.StorageType]func(cfg *BaseConfig) (StorageConfig, error){ - types.StorageTypeLocal: createStorageConfig(&LocalStorageConfig{}), - types.StorageTypeAlist: createStorageConfig(&AlistStorageConfig{}), - types.StorageTypeWebdav: createStorageConfig(&WebdavStorageConfig{}), - types.StorageTypeMinio: createStorageConfig(&MinioStorageConfig{}), +var storageFactories = map[storenum.StorageType]func(cfg *BaseConfig) (StorageConfig, error){ + storenum.Local: createStorageConfig(&LocalStorageConfig{}), + storenum.Alist: createStorageConfig(&AlistStorageConfig{}), + storenum.Webdav: createStorageConfig(&WebdavStorageConfig{}), + storenum.Minio: createStorageConfig(&MinioStorageConfig{}), + storenum.Telegram: createStorageConfig(&TelegramStorageConfig{}), } func createStorageConfig(configType StorageConfig) func(cfg *BaseConfig) (StorageConfig, error) { @@ -41,8 +42,12 @@ func LoadStorageConfigs(v *viper.Viper) ([]StorageConfig, error) { if !baseCfg.Enable { continue } + st, err := storenum.ParseStorageType(baseCfg.Type) + if err != nil { + return nil, fmt.Errorf("invalid storage type %s for %s: %w", baseCfg.Type, baseCfg.Name, err) + } - factory, ok := storageFactories[types.StorageType(baseCfg.Type)] + factory, ok := storageFactories[st] if !ok { return nil, fmt.Errorf("unsupported storage type: %s", baseCfg.Type) } diff --git a/config/storage/local.go b/config/storage/local.go index a77b859..46fbf29 100644 --- a/config/storage/local.go +++ b/config/storage/local.go @@ -3,7 +3,7 @@ package storage import ( "fmt" - "github.com/krau/SaveAny-Bot/types" + storenum "github.com/krau/SaveAny-Bot/pkg/enums/storage" ) type LocalStorageConfig struct { @@ -18,8 +18,8 @@ func (l *LocalStorageConfig) Validate() error { return nil } -func (l *LocalStorageConfig) GetType() types.StorageType { - return types.StorageTypeLocal +func (l *LocalStorageConfig) GetType() storenum.StorageType { + return storenum.Local } func (l *LocalStorageConfig) GetName() string { diff --git a/config/storage/minio.go b/config/storage/minio.go index 98807a3..8e9cd20 100644 --- a/config/storage/minio.go +++ b/config/storage/minio.go @@ -3,7 +3,7 @@ package storage import ( "fmt" - "github.com/krau/SaveAny-Bot/types" + storenum "github.com/krau/SaveAny-Bot/pkg/enums/storage" ) type MinioStorageConfig struct { @@ -32,8 +32,8 @@ func (m *MinioStorageConfig) Validate() error { return nil } -func (m *MinioStorageConfig) GetType() types.StorageType { - return types.StorageTypeMinio +func (m *MinioStorageConfig) GetType() storenum.StorageType { + return storenum.Minio } func (m *MinioStorageConfig) GetName() string { diff --git a/config/storage/telegram.go b/config/storage/telegram.go new file mode 100644 index 0000000..8487235 --- /dev/null +++ b/config/storage/telegram.go @@ -0,0 +1,32 @@ +package storage + +import ( + "fmt" + + storenum "github.com/krau/SaveAny-Bot/pkg/enums/storage" +) + +type TelegramStorageConfig struct { + BaseConfig + ChatID int64 `toml:"chat_id" mapstructure:"chat_id" json:"chat_id"` + RateLimit int `toml:"rate_limit" mapstructure:"rate_limit" json:"rate_limit"` + RateBurst int `toml:"rate_burst" mapstructure:"rate_burst" json:"rate_burst"` +} + +func (m *TelegramStorageConfig) Validate() error { + if m.ChatID == 0 { + return fmt.Errorf("chat_id is required for telegram storage") + } + if m.RateLimit < 0 || m.RateBurst < 0 { + return fmt.Errorf("rate_limit and rate_burst must be greater than 0 for telegram storage") + } + return nil +} + +func (m *TelegramStorageConfig) GetType() storenum.StorageType { + return storenum.Telegram +} + +func (m *TelegramStorageConfig) GetName() string { + return m.Name +} diff --git a/config/storage/types.go b/config/storage/types.go index e3579ad..73d3bdc 100644 --- a/config/storage/types.go +++ b/config/storage/types.go @@ -1,10 +1,12 @@ package storage -import "github.com/krau/SaveAny-Bot/types" +import ( + storenum "github.com/krau/SaveAny-Bot/pkg/enums/storage" +) type StorageConfig interface { Validate() error - GetType() types.StorageType + GetType() storenum.StorageType GetName() string } diff --git a/config/storage/webdav.go b/config/storage/webdav.go index f542965..93aaac5 100644 --- a/config/storage/webdav.go +++ b/config/storage/webdav.go @@ -3,7 +3,7 @@ package storage import ( "fmt" - "github.com/krau/SaveAny-Bot/types" + storenum "github.com/krau/SaveAny-Bot/pkg/enums/storage" ) type WebdavStorageConfig struct { @@ -27,8 +27,8 @@ func (w *WebdavStorageConfig) Validate() error { return nil } -func (w *WebdavStorageConfig) GetType() types.StorageType { - return types.StorageTypeWebdav +func (w *WebdavStorageConfig) GetType() storenum.StorageType { + return storenum.Webdav } func (w *WebdavStorageConfig) GetName() string { diff --git a/config/viper.go b/config/viper.go index 4b4134c..f8c065a 100644 --- a/config/viper.go +++ b/config/viper.go @@ -1,15 +1,16 @@ package config import ( + "context" "errors" "fmt" "os" "strings" "github.com/duke-git/lancet/v2/slice" + "github.com/krau/SaveAny-Bot/common/i18n" + "github.com/krau/SaveAny-Bot/common/i18n/i18nk" "github.com/krau/SaveAny-Bot/config/storage" - "github.com/krau/SaveAny-Bot/i18n" - "github.com/krau/SaveAny-Bot/i18n/i18nk" "github.com/spf13/viper" ) @@ -53,7 +54,6 @@ type telegramConfig struct { AppHash string `toml:"app_hash" mapstructure:"app_hash" json:"app_hash"` Timeout int `toml:"timeout" mapstructure:"timeout" json:"timeout"` Proxy proxyConfig `toml:"proxy" mapstructure:"proxy"` - FloodRetry int `toml:"flood_retry" mapstructure:"flood_retry" json:"flood_retry"` RpcRetry int `toml:"rpc_retry" mapstructure:"rpc_retry" json:"rpc_retry"` Userbot userbotConfig `toml:"userbot" mapstructure:"userbot" json:"userbot"` } @@ -79,7 +79,7 @@ func (c Config) GetStorageByName(name string) storage.StorageConfig { return nil } -func Init() error { +func Init(ctx context.Context) error { viper.SetConfigName("config") viper.AddConfigPath(".") viper.AddConfigPath("/etc/saveany/") diff --git a/core/batchtftask/execute.go b/core/batchtftask/execute.go new file mode 100644 index 0000000..3affcb9 --- /dev/null +++ b/core/batchtftask/execute.go @@ -0,0 +1,122 @@ +package batchtftask + +import ( + "context" + "fmt" + "io" + "os" + "path" + + "github.com/charmbracelet/log" + "github.com/duke-git/lancet/v2/retry" + "github.com/krau/SaveAny-Bot/common/tdler" + "github.com/krau/SaveAny-Bot/common/utils/fsutil" + "github.com/krau/SaveAny-Bot/common/utils/ioutil" + "github.com/krau/SaveAny-Bot/config" + "github.com/krau/SaveAny-Bot/pkg/enums/key" + "golang.org/x/sync/errgroup" +) + +func (t *Task) Execute(ctx context.Context) error { + logger := log.FromContext(ctx).WithPrefix(fmt.Sprintf("batch_file[%s]", t.ID)) + logger.Info("Starting batch file task") + t.Progress.OnStart(ctx, t) + workers := config.Cfg.Workers + eg, gctx := errgroup.WithContext(ctx) + eg.SetLimit(workers) + for _, elem := range t.Elems { + elem := elem + eg.Go(func() error { + if t.processing[elem.ID] != nil { + return fmt.Errorf("element with ID %s is already being processed", elem.ID) + } + t.processing[elem.ID] = &elem + defer func() { + delete(t.processing, elem.ID) + }() + return t.processElement(gctx, elem) + }) + } + err := eg.Wait() + if err != nil { + logger.Errorf("Error during batch file processing: %v", err) + } else { + logger.Info("Batch file task completed successfully") + } + t.Progress.OnDone(ctx, t, err) + return err +} + +func (t *Task) processElement(ctx context.Context, elem TaskElement) error { + logger := log.FromContext(ctx).WithPrefix(fmt.Sprintf("file[%s]", elem.File.Name())) + if elem.stream { + pr, pw := io.Pipe() + defer pr.Close() + errg, uploadCtx := errgroup.WithContext(ctx) + errg.Go(func() error { + return elem.Storage.Save(uploadCtx, pr, elem.Path) + }) + wr := ioutil.NewProgressWriter(pw, func(n int) { + t.downloaded.Add(int64(n)) + t.Progress.OnProgress(ctx, t) + }) + errg.Go(func() error { + logger.Info("Starting file download in stream mode") + _, err := tdler.NewDownloader(t.client, elem.File).Stream(uploadCtx, wr) + if closeErr := pw.CloseWithError(err); closeErr != nil { + logger.Errorf("Failed to close pipe writer: %v", closeErr) + } + return err + }) + if err := errg.Wait(); err != nil { + return fmt.Errorf("failed to download file in stream mode: %w", err) + } + logger.Info("File downloaded successfully in stream mode") + return nil + } + logger.Info("Starting file download") + localFile, err := fsutil.CreateFile(elem.localPath) + if err != nil { + return fmt.Errorf("failed to create local file: %w", err) + } + defer func() { + if err := localFile.CloseAndRemove(); err != nil { + logger.Errorf("Failed to close local file: %v", err) + } + }() + wrAt := ioutil.NewProgressWriterAt(localFile, func(n int) { + t.downloaded.Add(int64(n)) + t.Progress.OnProgress(ctx, t) + }) + _, err = tdler.NewDownloader(t.client, elem.File).Parallel(ctx, wrAt) + if err != nil { + return fmt.Errorf("failed to download file: %w", err) + } + logger.Info("File downloaded successfully") + if path.Ext(elem.FileName()) == "" { + ext := fsutil.DetectFileExt(elem.localPath) + if ext != "" { + elem.Path = elem.Path + ext + } + } + var fileStat os.FileInfo + fileStat, err = os.Stat(elem.localPath) + if err != nil { + return fmt.Errorf("failed to get file stat: %w", err) + } + vctx := context.WithValue(ctx, key.ContextKeyContentLength, fileStat.Size()) + err = retry.Retry(func() error { + var file *os.File + file, err = os.Open(elem.localPath) + if err != nil { + return fmt.Errorf("failed to open cache file: %w", err) + } + defer file.Close() + if err = elem.Storage.Save(vctx, file, elem.Path); err != nil { + logger.Errorf("Failed to save file: %s, retrying...", err) + return err + } + return nil + }, retry.Context(vctx), retry.RetryTimes(uint(config.Cfg.Retry))) + return err +} diff --git a/core/batchtftask/progress.go b/core/batchtftask/progress.go new file mode 100644 index 0000000..25ec691 --- /dev/null +++ b/core/batchtftask/progress.go @@ -0,0 +1,176 @@ +package batchtftask + +import ( + "context" + "errors" + "fmt" + "strconv" + "sync/atomic" + "time" + + "github.com/charmbracelet/log" + "github.com/duke-git/lancet/v2/slice" + "github.com/gotd/td/telegram/message/entity" + "github.com/gotd/td/telegram/message/styling" + "github.com/gotd/td/tg" + "github.com/krau/SaveAny-Bot/common/utils/dlutil" + "github.com/krau/SaveAny-Bot/common/utils/tgutil" +) + +type ProgressTracker interface { + OnStart(ctx context.Context, info TaskInfo) + OnProgress(ctx context.Context, info TaskInfo) + OnDone(ctx context.Context, info TaskInfo, err error) +} + +type Progress struct { + MessageID int + ChatID int64 + start time.Time + lastUpdatePercent atomic.Int32 +} + +func (p *Progress) OnStart(ctx context.Context, info TaskInfo) { + p.start = time.Now() + p.lastUpdatePercent.Store(0) + log.FromContext(ctx).Debugf("Batch task progress tracking started for message %d in chat %d", p.MessageID, p.ChatID) + entityBuilder := entity.Builder{} + var entities []tg.MessageEntityClass + if err := styling.Perform(&entityBuilder, + styling.Plain("开始执行批量下载任务\n总大小: "), + styling.Code(fmt.Sprintf("%.2f MB (%d个文件)", float64(info.TotalSize())/(1024*1024), info.Count())), + ); err != nil { + log.FromContext(ctx).Errorf("Failed to build entities: %s", err) + return + } + text, entities := entityBuilder.Complete() + req := &tg.MessagesEditMessageRequest{ + ID: p.MessageID, + } + req.SetMessage(text) + req.SetEntities(entities) + req.SetReplyMarkup(&tg.ReplyInlineMarkup{ + Rows: []tg.KeyboardButtonRow{ + { + Buttons: []tg.KeyboardButtonClass{ + tgutil.BuildCancelButton(info.TaskID()), + }, + }, + }}, + ) + ext := tgutil.ExtFromContext(ctx) + if ext != nil { + ext.EditMessage(p.ChatID, req) + return + } +} + +func (p *Progress) OnProgress(ctx context.Context, info TaskInfo) { + if !shouldUpdateProgress(info.TotalSize(), info.Downloaded(), int(p.lastUpdatePercent.Load())) { + return + } + percent := int((info.Downloaded() * 100) / info.TotalSize()) + if p.lastUpdatePercent.Load() == int32(percent) { + return + } + p.lastUpdatePercent.Store(int32(percent)) + log.FromContext(ctx).Debugf("Progress update: %s, %d/%d", info.TaskID(), info.Downloaded(), info.TotalSize()) + entityBuilder := entity.Builder{} + var entities []tg.MessageEntityClass + if err := styling.Perform(&entityBuilder, + styling.Plain("正在处理批量下载任务\n总大小: "), + styling.Code(fmt.Sprintf("%.2f MB (%d个文件)", float64(info.TotalSize())/(1024*1024), info.Count())), + styling.Plain("\n正在处理:\n"), + func() styling.StyledTextOption { + var lines []string + for _, elem := range info.Processing() { + lines = append(lines, fmt.Sprintf(" - %s (%.2f MB)", elem.FileName(), float64(elem.FileSize())/(1024*1024))) + } + if len(lines) == 0 { + lines = append(lines, " - 无") + } + return styling.Plain(slice.Join(lines, "\n")) + }(), + styling.Plain("\n平均速度: "), + styling.Bold(fmt.Sprintf("%.2f MB/s", dlutil.GetSpeed(info.Downloaded(), p.start)/(1024*1024))), + styling.Plain("\n当前进度: "), + styling.Bold(fmt.Sprintf("%.2f%%", float64(info.Downloaded())/float64(info.TotalSize())*100)), + ); err != nil { + log.FromContext(ctx).Errorf("Failed to build entities: %s", err) + return + } + text, entities := entityBuilder.Complete() + req := &tg.MessagesEditMessageRequest{ + ID: p.MessageID, + } + req.SetMessage(text) + req.SetEntities(entities) + req.SetReplyMarkup(&tg.ReplyInlineMarkup{ + Rows: []tg.KeyboardButtonRow{ + { + Buttons: []tg.KeyboardButtonClass{ + tgutil.BuildCancelButton(info.TaskID()), + }, + }, + }}, + ) + ext := tgutil.ExtFromContext(ctx) + if ext != nil { + ext.EditMessage(p.ChatID, req) + return + } +} + +func (p *Progress) OnDone(ctx context.Context, info TaskInfo, err error) { + if err != nil { + log.FromContext(ctx).Errorf("Batch task %s failed: %s", info.TaskID(), err) + } else { + log.FromContext(ctx).Debugf("Batch task %s completed successfully", info.TaskID()) + } + entityBuilder := entity.Builder{} + var stylingErr error + + if err != nil { + if errors.Is(err, context.Canceled) { + stylingErr = styling.Perform(&entityBuilder, + styling.Plain("任务已取消"), + ) + } else { + stylingErr = styling.Perform(&entityBuilder, + styling.Plain("处理失败, 错误:\n "), + styling.Code(err.Error()), + ) + } + } else { + stylingErr = styling.Perform(&entityBuilder, + styling.Plain("处理完成\n文件数: "), + styling.Code(strconv.Itoa(info.Count())), + styling.Plain("\n总大小: "), + styling.Code(fmt.Sprintf("%.2f MB", float64(info.TotalSize())/(1024*1024))), + ) + } + + if stylingErr != nil { + log.FromContext(ctx).Errorf("Failed to build entities: %s", stylingErr) + return + } + + text, entities := entityBuilder.Complete() + req := &tg.MessagesEditMessageRequest{ + ID: p.MessageID, + } + req.SetMessage(text) + req.SetEntities(entities) + + ext := tgutil.ExtFromContext(ctx) + if ext != nil { + ext.EditMessage(p.ChatID, req) + } +} + +func NewProgressTracker(messageID int, chatID int64) ProgressTracker { + return &Progress{ + MessageID: messageID, + ChatID: chatID, + } +} diff --git a/core/batchtftask/task.go b/core/batchtftask/task.go new file mode 100644 index 0000000..6edfce3 --- /dev/null +++ b/core/batchtftask/task.go @@ -0,0 +1,94 @@ +package batchtftask + +import ( + "context" + "fmt" + "path/filepath" + "sync/atomic" + + "github.com/krau/SaveAny-Bot/common/tdler" + "github.com/krau/SaveAny-Bot/config" + "github.com/krau/SaveAny-Bot/pkg/tfile" + "github.com/krau/SaveAny-Bot/storage" + "github.com/rs/xid" +) + +type TaskElement struct { + ID string + Storage storage.Storage + Path string + File tfile.TGFile + localPath string + stream bool +} + +type Task struct { + ID string + Ctx context.Context + Elems []TaskElement + Progress ProgressTracker + IgnoreErrors bool // if true, errors during processing will be ignored + downloaded atomic.Int64 + client tdler.Client + totalSize int64 + processing map[string]TaskElementInfo + failed map[string]error // errors for each element +} + +func NewTaskElement( + stor storage.Storage, + path string, + file tfile.TGFile, +) (*TaskElement, error) { + id := xid.New().String() + _, ok := stor.(storage.StorageCannotStream) + if !config.Cfg.Stream || ok { + cachePath, err := filepath.Abs(filepath.Join(config.Cfg.Temp.BasePath, fmt.Sprintf("%s_%s", id, file.Name()))) + if err != nil { + return nil, fmt.Errorf("failed to get absolute path for cache: %w", err) + } + return &TaskElement{ + ID: id, + Storage: stor, + Path: path, + File: file, + localPath: cachePath, + }, nil + } + return &TaskElement{ + ID: id, + Storage: stor, + Path: path, + File: file, + stream: true, + }, nil +} + +func NewBatchTGFileTask( + id string, + ctx context.Context, + files []TaskElement, + client tdler.Client, + progress ProgressTracker, + ignoreErrors bool, +) *Task { + task := &Task{ + ID: id, + Ctx: ctx, + client: client, + Elems: files, + Progress: progress, + downloaded: atomic.Int64{}, + totalSize: func() int64 { + var total int64 + for _, elem := range files { + total += elem.File.Size() + } + return total + }(), + processing: make(map[string]TaskElementInfo), + IgnoreErrors: ignoreErrors, + failed: make(map[string]error), + } + return task +} diff --git a/core/batchtftask/taskinfo.go b/core/batchtftask/taskinfo.go new file mode 100644 index 0000000..c702da5 --- /dev/null +++ b/core/batchtftask/taskinfo.go @@ -0,0 +1,56 @@ +package batchtftask + +type TaskElementInfo interface { + FileName() string + FileSize() int64 + StoragePath() string + StorageName() string +} + +func (e *TaskElement) FileName() string { + return e.File.Name() +} + +func (e *TaskElement) FileSize() int64 { + return e.File.Size() +} + +func (e *TaskElement) StoragePath() string { + return e.Path +} + +func (e *TaskElement) StorageName() string { + return e.Storage.Name() +} + +type TaskInfo interface { + TaskID() string + TotalSize() int64 + Downloaded() int64 + Count() int + Processing() []TaskElementInfo +} + +func (t *Task) TaskID() string { + return t.ID +} + +func (t *Task) TotalSize() int64 { + return t.totalSize +} + +func (t *Task) Downloaded() int64 { + return t.downloaded.Load() +} + +func (t *Task) Count() int { + return len(t.Elems) +} + +func (t *Task) Processing() []TaskElementInfo { + processing := make([]TaskElementInfo, 0, len(t.Elems)) + for _, elem := range t.processing { + processing = append(processing, elem) + } + return processing +} diff --git a/core/batchtftask/utils.go b/core/batchtftask/utils.go new file mode 100644 index 0000000..9c38ecf --- /dev/null +++ b/core/batchtftask/utils.go @@ -0,0 +1,32 @@ +package batchtftask + +var progressUpdatesLevels = []struct { + size int64 // 文件大小阈值 + stepPercent int // 每多少 % 更新一次 +}{ + {10 << 20, 100}, + {50 << 20, 20}, + {200 << 20, 10}, + {500 << 20, 5}, +} + +func shouldUpdateProgress(total, downloaded int64, lastUpdatePercent int) bool { + if total <= 0 || downloaded <= 0 { + return false + } + + percent := int((downloaded * 100) / total) + if percent <= lastUpdatePercent { + return false + } + + step := progressUpdatesLevels[len(progressUpdatesLevels)-1].stepPercent + for _, lvl := range progressUpdatesLevels { + if total < lvl.size { + step = lvl.stepPercent + break + } + } + + return percent >= lastUpdatePercent+step +} diff --git a/core/core.go b/core/core.go index c580e1a..588b13e 100644 --- a/core/core.go +++ b/core/core.go @@ -2,92 +2,62 @@ package core import ( "context" - "errors" - "fmt" - "github.com/celestix/gotgproto/ext" - "github.com/gotd/td/telegram/downloader" - "github.com/gotd/td/tg" - "github.com/krau/SaveAny-Bot/common" + "github.com/charmbracelet/log" "github.com/krau/SaveAny-Bot/config" - "github.com/krau/SaveAny-Bot/queue" - "github.com/krau/SaveAny-Bot/types" + "github.com/krau/SaveAny-Bot/pkg/queue" ) -var Downloader *downloader.Downloader +var queueInstance *queue.TaskQueue[Exectable] -func init() { - Downloader = downloader.NewDownloader().WithPartSize(1024 * 1024) +type Exectable interface { + TaskID() string + Execute(ctx context.Context) error } -func worker(queue *queue.TaskQueue, semaphore chan struct{}) { +func worker(ctx context.Context, qe *queue.TaskQueue[Exectable], semaphore chan struct{}) { for { semaphore <- struct{}{} - task := queue.GetTask() - common.Log.Debugf("Got task: %s", task.String()) - - switch task.Status { - case types.Pending: - common.Log.Infof("Processing task: %s", task.String()) - if err := processPendingTask(task); err != nil { - task.Error = err - if errors.Is(err, context.Canceled) { - task.Status = types.Canceled - } else { - common.Log.Errorf("Failed to do task: %s", err) - task.Status = types.Failed - } - } else { - task.Status = types.Succeeded - } - queue.AddTask(task) - case types.Succeeded: - common.Log.Infof("Task succeeded: %s", task.String()) - extCtx, ok := task.Ctx.(*ext.Context) - if !ok { - common.Log.Errorf("Context is not *ext.Context: %T", task.Ctx) - } else if task.ReplyMessageID != 0 { - extCtx.EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{ - Message: fmt.Sprintf("文件保存成功\n [%s]: %s", task.StorageName, task.StoragePath), - ID: task.ReplyMessageID, - }) - } - case types.Failed: - common.Log.Errorf("Task failed: %s", task.String()) - extCtx, ok := task.Ctx.(*ext.Context) - if !ok { - common.Log.Errorf("Context is not *ext.Context: %T", task.Ctx) - } else if task.ReplyMessageID != 0 { - extCtx.EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{ - Message: "文件保存失败\n" + task.Error.Error(), - ID: task.ReplyMessageID, - }) - } - case types.Canceled: - common.Log.Infof("Task canceled: %s", task.String()) - extCtx, ok := task.Ctx.(*ext.Context) - if !ok { - common.Log.Errorf("Context is not *ext.Context: %T", task.Ctx) - } else if task.ReplyMessageID != 0 { - extCtx.EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{ - Message: "任务已取消", - ID: task.ReplyMessageID, - }) - } - default: - common.Log.Errorf("Unknown task status: %s", task.Status) + qtask, err := qe.Get() + if err != nil { + break // queue closed and empty } + log.FromContext(ctx).Infof("Processing task: %s", qtask.ID) + task := qtask.Data + if err := task.Execute(qtask.Context()); err != nil { + log.FromContext(ctx).Errorf("Failed to execute task %s: %v", qtask.ID, err) + } else { + log.FromContext(ctx).Infof("Task %s completed successfully", qtask.ID) + } + qe.Done(qtask.ID) <-semaphore - common.Log.Debugf("Task done: %s; status: %s", task.String(), task.Status) - queue.DoneTask(task) } } -func Run() { - common.Log.Info("Start processing tasks...") +func Run(ctx context.Context) { + log.FromContext(ctx).Info("Start processing tasks...") semaphore := make(chan struct{}, config.Cfg.Workers) - for i := 0; i < config.Cfg.Workers; i++ { - go worker(queue.Queue, semaphore) + if queueInstance == nil { + queueInstance = queue.NewTaskQueue[Exectable]() + } + for range config.Cfg.Workers { + go worker(ctx, queueInstance, semaphore) } } + +func AddTask(ctx context.Context, task Exectable) error { + return queueInstance.Add(queue.NewTask(ctx, task.TaskID(), task)) +} + +func CancelTask(ctx context.Context, id string) error { + err := queueInstance.CancelTask(id) + return err +} + +func GetLength(ctx context.Context) int { + if queueInstance == nil { + return 0 + } + return queueInstance.ActiveLength() +} diff --git a/core/download.go b/core/download.go deleted file mode 100644 index 92f15e1..0000000 --- a/core/download.go +++ /dev/null @@ -1,291 +0,0 @@ -package core - -import ( - "context" - "encoding/json" - "fmt" - "io" - "net/http" - "path" - "path/filepath" - "strings" - "time" - - "github.com/celestix/gotgproto/ext" - "github.com/celestix/telegraph-go/v2" - "github.com/duke-git/lancet/v2/fileutil" - "github.com/gotd/td/telegram/message/entity" - "github.com/gotd/td/telegram/message/styling" - "github.com/gotd/td/tg" - "github.com/krau/SaveAny-Bot/bot" - "github.com/krau/SaveAny-Bot/common" - "github.com/krau/SaveAny-Bot/config" - "github.com/krau/SaveAny-Bot/storage" - "github.com/krau/SaveAny-Bot/types" - "github.com/krau/SaveAny-Bot/userclient" - "golang.org/x/sync/errgroup" -) - -func processPendingTask(task *types.Task) error { - common.Log.Infof("Start processing task: %s", task.String()) - - if task.FileName() == "" { - task.File.FileName = fmt.Sprintf("%d_%d_%s", task.FileChatID, task.FileMessageID, task.File.Hash()) - } - - taskStorage, storagePath, err := getStorageAndPathForTask(task) - if err != nil { - return err - } - if taskStorage == nil { - return fmt.Errorf("not found storage: %s", task.StorageName) - } - task.StoragePath = storagePath - - ctx, ok := task.Ctx.(*ext.Context) - if !ok { - return fmt.Errorf("context is not *ext.Context: %T", task.Ctx) - } - - cancelCtx, cancel := context.WithCancel(ctx) - task.Cancel = cancel - - if task.IsTelegraph { - return processTelegraph(ctx, cancelCtx, task, taskStorage) - } - - if task.File.FileSize == 0 { - return processPhoto(task, taskStorage) - } - api := bot.Client.API() - if task.UseUserClient && userclient.UC != nil { - api = userclient.UC.API() - } - downloadBuilder := Downloader.Download(api, task.File.Location).WithThreads(getTaskThreads(task.File.FileSize)) - - notsupportStreamStorage, notsupportStream := taskStorage.(storage.StorageNotSupportStream) - cancelMarkUp := getCancelTaskMarkup(task) - - if config.Cfg.Stream { - if !notsupportStream { - text, entities := buildProgressMessageEntity(task, 0, task.StartTime, 0) - if task.ReplyMessageID != 0 { - ctx.EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{ - Message: text, - Entities: entities, - ID: task.ReplyMessageID, - ReplyMarkup: cancelMarkUp, - }) - } - - pr, pw := io.Pipe() - defer pr.Close() - - task.StartTime = time.Now() - progressCallback := buildProgressCallback(ctx, task, getProgressUpdateCount(task.File.FileSize)) - - progressStream := NewProgressStream(pw, task.File.FileSize, progressCallback) - - eg, uploadCtx := errgroup.WithContext(cancelCtx) - - eg.Go(func() error { - return taskStorage.Save(uploadCtx, pr, task.StoragePath) - }) - eg.Go(func() error { - _, err := downloadBuilder.Stream(uploadCtx, progressStream) - if closeErr := pw.CloseWithError(err); closeErr != nil { - common.Log.Errorf("Failed to close pipe writer: %v", closeErr) - } - return err - }) - if err := eg.Wait(); err != nil { - return err - } - - return nil - } - common.Log.Warnf("存储 %s 不支持流式传输: %s", task.StorageName, notsupportStreamStorage.NotSupportStream()) - - if task.ReplyMessageID != 0 { - ctx.EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{ - Message: fmt.Sprintf("存储 %s 不支持流式传输: %s\n正在使用普通下载...", task.StorageName, notsupportStreamStorage.NotSupportStream()), - ID: task.ReplyMessageID, - ReplyMarkup: cancelMarkUp, - }) - } - } - - cacheDestPath := filepath.Join(config.Cfg.Temp.BasePath, task.FileName()) - cacheDestPath, err = filepath.Abs(cacheDestPath) - if err != nil { - return fmt.Errorf("处理路径失败: %w", err) - } - if err := fileutil.CreateDir(filepath.Dir(cacheDestPath)); err != nil { - return fmt.Errorf("创建目录失败: %w", err) - } - - text, entities := buildProgressMessageEntity(task, 0, task.StartTime, 0) - if task.ReplyMessageID != 0 { - ctx.EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{ - Message: text, - Entities: entities, - ID: task.ReplyMessageID, - ReplyMarkup: cancelMarkUp, - }) - } - - progressCallback := buildProgressCallback(ctx, task, getProgressUpdateCount(task.File.FileSize)) - dest, err := NewTaskLocalFile(cacheDestPath, task.File.FileSize, progressCallback) - if err != nil { - return fmt.Errorf("创建文件失败: %w", err) - } - defer dest.Close() - task.StartTime = time.Now() - _, err = downloadBuilder.Parallel(cancelCtx, dest) - if err != nil { - return fmt.Errorf("下载文件失败: %w", err) - } - defer cleanCacheFile(cacheDestPath) - - fixTaskFileExt(task, cacheDestPath) - - common.Log.Infof("Downloaded file: %s", cacheDestPath) - if task.ReplyMessageID != 0 { - ctx.EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{ - Message: fmt.Sprintf("下载完成: %s\n正在转存文件...", task.FileName()), - ID: task.ReplyMessageID, - }) - } - return saveFileWithRetry(cancelCtx, task.StoragePath, taskStorage, cacheDestPath) -} - -func processTelegraph(extCtx *ext.Context, cancelCtx context.Context, task *types.Task, taskStorage storage.Storage) error { - if bot.TelegraphClient == nil { - return fmt.Errorf("telegraph client is not initialized") - } - tgphUrl := task.TelegraphURL - tgphPath := strings.Split(tgphUrl, "/")[len(strings.Split(tgphUrl, "/"))-1] - if tgphUrl == "" || tgphPath == "" { - return fmt.Errorf("invalid telegraph url") - } - entityBuilder := entity.Builder{} - text := fmt.Sprintf("正在下载 Telegraph \n文件夹: %s\n保存路径: %s", - task.FileName(), - fmt.Sprintf("[%s]:%s", task.StorageName, task.StoragePath), - ) - var entities []tg.MessageEntityClass - if err := styling.Perform(&entityBuilder, - styling.Plain("正在下载 Telegraph \n文件夹: "), - styling.Code(task.FileName()), - styling.Plain("\n保存路径: "), - styling.Code(fmt.Sprintf("[%s]:%s", task.StorageName, task.StoragePath)), - ); err != nil { - common.Log.Errorf("Failed to build entities: %s", err) - } - - if task.ReplyMessageID != 0 { - extCtx.EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{ - Message: text, - Entities: entities, - ID: task.ReplyMessageID, - ReplyMarkup: getCancelTaskMarkup(task), - }) - } - - resultCh := make(chan error) - go func() { - page, err := bot.TelegraphClient.GetPage(tgphPath, true) - if err != nil { - resultCh <- fmt.Errorf("获取 telegraph 页面失败: %w", err) - return - } - imgs := make([]string, 0) - for _, element := range page.Content { - var node telegraph.NodeElement - data, err := json.Marshal(element) - if err != nil { - common.Log.Errorf("Failed to marshal element: %s", err) - continue - } - err = json.Unmarshal(data, &node) - if err != nil { - common.Log.Errorf("Failed to unmarshal element: %s", err) - continue - } - - if len(node.Children) != 0 { - for _, child := range node.Children { - imgs = append(imgs, getNodeImages(child)...) - } - } - - if node.Tag == "img" { - if src, ok := node.Attrs["src"]; ok { - imgs = append(imgs, src) - } - } - - } - if len(imgs) == 0 { - resultCh <- fmt.Errorf("没有找到图片") - return - } - hc := bot.TelegraphClient.HttpClient - eg, ectx := errgroup.WithContext(cancelCtx) - eg.SetLimit(config.Cfg.Workers) // TODO: use a new config field for this - for i, img := range imgs { - if strings.HasPrefix(img, "/file/") { - img = "https://telegra.ph" + img - } - eg.Go(func() error { - var lastErr error - for attempt := range config.Cfg.Retry { - if attempt > 0 { - retryDelay := time.Duration(attempt*attempt) * time.Second - select { - case <-ectx.Done(): - return ectx.Err() - case <-time.After(retryDelay): - } - common.Log.Debugf("Retrying to download image %s (attempt %d)", img, attempt+1) - } - req, err := http.NewRequestWithContext(ectx, http.MethodGet, img, nil) - if err != nil { - lastErr = fmt.Errorf("创建请求失败: %w", err) - continue - } - resp, err := hc.Do(req) - if err != nil { - lastErr = fmt.Errorf("发送请求失败: %w", err) - continue - } - defer resp.Body.Close() - if resp.StatusCode != http.StatusOK { - lastErr = fmt.Errorf("请求图片失败: %s", resp.Status) - continue - } - targetPath := path.Join(task.StoragePath, fmt.Sprintf("%d%s", i+1, path.Ext(img))) - err = taskStorage.Save(ectx, resp.Body, targetPath) - if err != nil { - lastErr = fmt.Errorf("保存图片失败: %w", err) - continue - } - common.Log.Infof("Saved image: %s", targetPath) - return nil - } - return lastErr - }) - } - if err := eg.Wait(); err != nil { - resultCh <- err - return - } - resultCh <- nil - }() - select { - case err := <-resultCh: - return err - case <-cancelCtx.Done(): - return cancelCtx.Err() - } -} diff --git a/core/download_test.go b/core/download_test.go deleted file mode 100644 index f0c6444..0000000 --- a/core/download_test.go +++ /dev/null @@ -1,80 +0,0 @@ -package core - -import ( - "reflect" - "testing" - - "github.com/celestix/telegraph-go/v2" -) - -func TestGetImgSrcs(t *testing.T) { - complexStructure := telegraph.NodeElement{ - Tag: "div", - Children: []telegraph.Node{ - telegraph.NodeElement{ - Tag: "figure", - Children: []telegraph.Node{ - telegraph.NodeElement{ - Tag: "img", - Attrs: map[string]string{ - "src": "https://example.com/image1.png", - }, - }, - telegraph.NodeElement{ - Tag: "p", - Children: []telegraph.Node{ - "A text node", - }, - }, - telegraph.NodeElement{ - Tag: "figure", - Children: []telegraph.Node{ - telegraph.NodeElement{ - Tag: "img", - Attrs: map[string]string{ - "src": "https://example.com/image2.png", - }, - }, - }, - }, - }, - }, - telegraph.NodeElement{ - Tag: "img", - Attrs: map[string]string{ - "src": "https://example.com/image3.png", - }, - }, - "text node", - telegraph.NodeElement{ - Tag: "div", - Children: []telegraph.Node{ - telegraph.NodeElement{ - Tag: "span", - Children: []telegraph.Node{ - telegraph.NodeElement{ - Tag: "img", - Attrs: map[string]string{ - "src": "https://example.com/image4.png", - }, - }, - }, - }, - }, - }, - }, - } - - expected := []string{ - "https://example.com/image1.png", - "https://example.com/image2.png", - "https://example.com/image3.png", - "https://example.com/image4.png", - } - - got := getNodeImages(complexStructure) - - if !reflect.DeepEqual(expected, got) { - t.Errorf("expected %v,got %v", expected, got) - } -} diff --git a/core/rule.go b/core/rule.go deleted file mode 100644 index b04dcbd..0000000 --- a/core/rule.go +++ /dev/null @@ -1,110 +0,0 @@ -package core - -import ( - "fmt" - "path" - "regexp" - - "github.com/celestix/gotgproto/ext" - "github.com/krau/SaveAny-Bot/bot" - "github.com/krau/SaveAny-Bot/common" - "github.com/krau/SaveAny-Bot/dao" - "github.com/krau/SaveAny-Bot/storage" - "github.com/krau/SaveAny-Bot/types" -) - -func getStorageAndPathForTask(task *types.Task) (storage.Storage, string, error) { - user, err := dao.GetUserByChatID(task.UserID) - if err != nil { - return nil, "", fmt.Errorf("failed to get user by chat ID: %w", err) - } - if task.StoragePath == "" { - task.StoragePath = task.FileName() - } - taskStorage, err := storage.GetStorageByUserIDAndName(task.UserID, task.StorageName) - if err != nil { - return nil, "", err - } - storagePath := taskStorage.JoinStoragePath(*task) - - var ruleTaskStorage storage.Storage - var ruleStoragePath string - if user.ApplyRule && user.Rules != nil { - for _, rule := range user.Rules { - matchStorage, matchStoragePath := applyRule(&rule, *task) - if matchStorage != nil && matchStoragePath != "" { - ruleTaskStorage = matchStorage - ruleStoragePath = matchStoragePath - common.Log.Debugf("Rule matched: %s, %s", ruleTaskStorage.Name(), ruleStoragePath) - return ruleTaskStorage, ruleStoragePath, nil - } - } - } - - if taskStorage.Exists(task.Ctx, storagePath) { - ext := path.Ext(task.FileName()) - name := task.FileName()[:len(task.FileName())-len(ext)] - task.File.FileName = fmt.Sprintf("%s_%d%s", name, task.FileDBID, ext) - task.StoragePath = task.File.FileName - storagePath = taskStorage.JoinStoragePath(*task) - } - - return taskStorage, storagePath, nil -} - -func applyRule(rule *dao.Rule, task types.Task) (storage.Storage, string) { - var DirPath, StorageName string - switch rule.Type { - case string(types.RuleTypeFileNameRegex): - ruleRegex, err := regexp.Compile(rule.Data) - if err != nil { - common.Log.Errorf("failed to compile regex: %s", err) - return nil, "" - } - if !ruleRegex.MatchString(task.FileName()) { - return nil, "" - } - DirPath = rule.DirPath - StorageName = rule.StorageName - case string(types.RuleTypeMessageRegex): - ruleRegex, err := regexp.Compile(rule.Data) - if err != nil { - common.Log.Errorf("failed to compile regex: %s", err) - return nil, "" - } - ctx, ok := task.Ctx.(*ext.Context) - if !ok { - common.Log.Fatalf("context is not *ext.Context: %T", task.Ctx) - return nil, "" - } - msg, err := bot.GetTGMessage(ctx, task.FileChatID, task.FileMessageID) - if err != nil { - common.Log.Errorf("failed to get message: %s", err) - return nil, "" - } - if msg == nil { - return nil, "" - } - if !ruleRegex.MatchString(msg.GetMessage()) { - return nil, "" - } - DirPath = rule.DirPath - StorageName = rule.StorageName - default: - common.Log.Errorf("unknown rule type: %s", rule.Type) - return nil, "" - } - taskStorageName := func() string { - if StorageName == "" || StorageName == "CHOSEN" { - return task.StorageName - } - return StorageName - }() - taskStorage, err := storage.GetStorageByUserIDAndName(task.UserID, taskStorageName) - if err != nil { - common.Log.Errorf("failed to get storage: %s", err) - return nil, "" - } - task.StoragePath = path.Join(DirPath, task.StoragePath) - return taskStorage, taskStorage.JoinStoragePath(task) -} diff --git a/core/tftask/execute.go b/core/tftask/execute.go new file mode 100644 index 0000000..29a74c4 --- /dev/null +++ b/core/tftask/execute.go @@ -0,0 +1,82 @@ +package tftask + +import ( + "context" + "fmt" + "os" + "path" + "time" + + "github.com/charmbracelet/log" + "github.com/krau/SaveAny-Bot/common/tdler" + "github.com/krau/SaveAny-Bot/common/utils/fsutil" + "github.com/krau/SaveAny-Bot/config" + "github.com/krau/SaveAny-Bot/pkg/enums/key" +) + +func (t *TGFileTask) Execute(ctx context.Context) error { + logger := log.FromContext(ctx).WithPrefix(fmt.Sprintf("file[%s]", t.File.Name())) + t.Progress.OnStart(ctx, t) + if t.stream { + return executeStream(ctx, t) + } + + logger.Info("Starting file download") + localFile, err := fsutil.CreateFile(t.localPath) + if err != nil { + return fmt.Errorf("failed to create local file: %w", err) + } + defer func() { + if err := localFile.CloseAndRemove(); err != nil { + logger.Errorf("Failed to close local file: %v", err) + } + }() + wrAt := newWriterAt(ctx, localFile, t.Progress, t) + + defer func() { + t.Progress.OnDone(ctx, t, err) + }() + _, err = tdler.NewDownloader(t.client, t.File).Parallel(ctx, wrAt) + if err != nil { + return fmt.Errorf("failed to download file: %w", err) + } + logger.Infof("File downloaded successfully") + if path.Ext(t.File.Name()) == "" { + ext := fsutil.DetectFileExt(t.localPath) + if ext != "" { + t.Path = t.Path + ext + } + } + var fileStat os.FileInfo + fileStat, err = os.Stat(t.localPath) + if err != nil { + return fmt.Errorf("failed to get file stat: %w", err) + } + vctx := context.WithValue(ctx, key.ContextKeyContentLength, fileStat.Size()) + for i := range config.Cfg.Retry + 1 { + if err = vctx.Err(); err != nil { + return fmt.Errorf("context canceled while saving file: %w", err) + } + var file *os.File + file, err = os.Open(t.localPath) + if err != nil { + return fmt.Errorf("failed to open cache file: %w", err) + } + defer file.Close() + if err = t.Storage.Save(vctx, file, t.Path); err != nil { + if i == config.Cfg.Retry { + return fmt.Errorf("failed to save file: %w", err) + } + logger.Errorf("Failed to save file: %s, retrying...", err) + select { + case <-vctx.Done(): + return fmt.Errorf("context canceled during retry delay: %w", vctx.Err()) + case <-time.After(time.Duration(i*500) * time.Millisecond): + } + continue + } + return nil + } + return fmt.Errorf("failed to save file after retries") + +} diff --git a/core/tftask/progress.go b/core/tftask/progress.go new file mode 100644 index 0000000..3b5cf9b --- /dev/null +++ b/core/tftask/progress.go @@ -0,0 +1,186 @@ +package tftask + +import ( + "context" + "errors" + "fmt" + "sync/atomic" + "time" + + "github.com/charmbracelet/log" + "github.com/gotd/td/telegram/message/entity" + "github.com/gotd/td/telegram/message/styling" + "github.com/gotd/td/tg" + "github.com/krau/SaveAny-Bot/common/utils/dlutil" + "github.com/krau/SaveAny-Bot/common/utils/tgutil" +) + +type ProgressTracker interface { + OnStart(ctx context.Context, info TaskInfo) + OnProgress(ctx context.Context, info TaskInfo, downloaded, total int64) + OnDone(ctx context.Context, info TaskInfo, err error) +} + +type Progress struct { + MessageID int + ChatID int64 + start time.Time + lastUpdatePercent atomic.Int32 +} + +func (p *Progress) OnStart(ctx context.Context, info TaskInfo) { + p.start = time.Now() + p.lastUpdatePercent.Store(0) + log.FromContext(ctx).Debugf("Progress tracking started for message %d in chat %d", p.MessageID, p.ChatID) + entityBuilder := entity.Builder{} + var entities []tg.MessageEntityClass + if err := styling.Perform(&entityBuilder, + styling.Plain("开始下载\n文件名: "), + styling.Code(info.FileName()), + styling.Plain("\n保存路径: "), + styling.Code(fmt.Sprintf("[%s]:%s", info.StorageName(), info.StoragePath())), + styling.Plain("\n文件大小: "), + styling.Code(fmt.Sprintf("%.2f MB", float64(info.FileSize())/(1024*1024))), + ); err != nil { + log.FromContext(ctx).Errorf("Failed to build entities: %s", err) + return + } + text, entities := entityBuilder.Complete() + req := &tg.MessagesEditMessageRequest{ + ID: p.MessageID, + } + req.SetMessage(text) + req.SetEntities(entities) + req.SetReplyMarkup(&tg.ReplyInlineMarkup{ + Rows: []tg.KeyboardButtonRow{ + { + Buttons: []tg.KeyboardButtonClass{ + tgutil.BuildCancelButton(info.TaskID()), + }, + }, + }}, + ) + ext := tgutil.ExtFromContext(ctx) + if ext != nil { + ext.EditMessage(p.ChatID, req) + return + } +} + +func (p *Progress) OnProgress(ctx context.Context, info TaskInfo, downloaded, total int64) { + if !shouldUpdateProgress(total, downloaded, int(p.lastUpdatePercent.Load())) { + return + } + percent := int32((downloaded * 100) / total) + if p.lastUpdatePercent.Load() == percent { + return + } + p.lastUpdatePercent.Store(percent) + log.FromContext(ctx).Debugf("Progress update: %s, %d/%d", info.FileName(), downloaded, total) + entityBuilder := entity.Builder{} + var entities []tg.MessageEntityClass + if err := styling.Perform(&entityBuilder, + styling.Plain("正在处理下载任务\n文件名: "), + styling.Code(info.FileName()), + styling.Plain("\n保存路径: "), + styling.Code(fmt.Sprintf("[%s]:%s", info.StorageName(), info.StoragePath())), + styling.Plain("\n文件大小: "), + styling.Code(fmt.Sprintf("%.2f MB", float64(total)/(1024*1024))), + styling.Plain("\n平均速度: "), + styling.Bold(fmt.Sprintf("%.2f MB/s", dlutil.GetSpeed(downloaded, p.start)/(1024*1024))), + styling.Plain("\n当前进度: "), + styling.Bold(fmt.Sprintf("%.2f%%", float64(downloaded)/float64(total)*100)), + ); err != nil { + log.FromContext(ctx).Errorf("Failed to build entities: %s", err) + return + } + text, entities := entityBuilder.Complete() + req := &tg.MessagesEditMessageRequest{ + ID: p.MessageID, + } + req.SetMessage(text) + req.SetEntities(entities) + req.SetReplyMarkup(&tg.ReplyInlineMarkup{ + Rows: []tg.KeyboardButtonRow{ + { + Buttons: []tg.KeyboardButtonClass{ + tgutil.BuildCancelButton(info.TaskID()), + }, + }, + }}, + ) + ext := tgutil.ExtFromContext(ctx) + if ext != nil { + ext.EditMessage(p.ChatID, req) + return + } + +} + +func (p *Progress) OnDone(ctx context.Context, info TaskInfo, err error) { + if err != nil { + log.FromContext(ctx).Errorf("Progress error for file [%s]: %v", info.FileName(), err) + } else { + log.FromContext(ctx).Debugf("Progress done for file [%s]", info.FileName()) + } + + entityBuilder := entity.Builder{} + var stylingErr error + + if err != nil { + if errors.Is(err, context.Canceled) { + stylingErr = styling.Perform(&entityBuilder, + styling.Plain("任务已取消\n文件名: "), + styling.Code(info.FileName()), + ) + } else { + stylingErr = styling.Perform(&entityBuilder, + styling.Plain("下载失败\n文件名: "), + styling.Code(info.FileName()), + styling.Plain("\n错误: "), + styling.Bold(err.Error()), + ) + } + } else { + stylingErr = styling.Perform(&entityBuilder, + styling.Plain("下载完成\n文件名: "), + styling.Code(info.FileName()), + styling.Plain("\n保存路径: "), + styling.Code(fmt.Sprintf("[%s]:%s", info.StorageName(), info.StoragePath())), + ) + } + + if stylingErr != nil { + log.FromContext(ctx).Errorf("Failed to build entities: %s", stylingErr) + return + } + + text, entities := entityBuilder.Complete() + req := &tg.MessagesEditMessageRequest{ + ID: p.MessageID, + } + req.SetMessage(text) + req.SetEntities(entities) + + ext := tgutil.ExtFromContext(ctx) + if ext != nil { + ext.EditMessage(p.ChatID, req) + } +} + +type ProgressOption func(*Progress) + +func NewProgressTrack( + messageID int, + chatID int64, + opts ...ProgressOption, +) ProgressTracker { + p := &Progress{ + MessageID: messageID, + ChatID: chatID, + } + for _, opt := range opts { + opt(p) + } + return p +} diff --git a/core/tftask/stream.go b/core/tftask/stream.go new file mode 100644 index 0000000..854db27 --- /dev/null +++ b/core/tftask/stream.go @@ -0,0 +1,40 @@ +package tftask + +import ( + "context" + "fmt" + "io" + + "github.com/charmbracelet/log" + "github.com/krau/SaveAny-Bot/common/tdler" + "golang.org/x/sync/errgroup" +) + +func executeStream(ctx context.Context, task *TGFileTask) error { + logger := log.FromContext(ctx).WithPrefix(fmt.Sprintf("file[%s]", task.File.Name())) + + pr, pw := io.Pipe() + defer pr.Close() + errg, uploadCtx := errgroup.WithContext(ctx) + errg.Go(func() error { + return task.Storage.Save(uploadCtx, pr, task.Path) + }) + wr := newWriter(ctx, pw, task.Progress, task) + errg.Go(func() error { + logger.Info("Starting file download in stream mode") + _, err := tdler.NewDownloader(task.client, task.File).Stream(uploadCtx, wr) + if closeErr := pw.CloseWithError(err); closeErr != nil { + logger.Errorf("Failed to close pipe writer: %v", closeErr) + } + return err + }) + var err error + defer func() { + task.Progress.OnDone(ctx, task, err) + }() + if err = errg.Wait(); err != nil { + return err + } + logger.Info("File downloaded successfully in stream mode") + return nil +} diff --git a/core/tftask/taskinfo.go b/core/tftask/taskinfo.go new file mode 100644 index 0000000..abcae29 --- /dev/null +++ b/core/tftask/taskinfo.go @@ -0,0 +1,29 @@ +package tftask + +type TaskInfo interface { + TaskID() string + FileName() string + FileSize() int64 + StoragePath() string + StorageName() string +} + +func (t *TGFileTask) TaskID() string { + return t.ID +} + +func (t *TGFileTask) FileName() string { + return t.File.Name() +} + +func (t *TGFileTask) FileSize() int64 { + return t.File.Size() +} + +func (t *TGFileTask) StoragePath() string { + return t.Path +} + +func (t *TGFileTask) StorageName() string { + return t.Storage.Name() +} diff --git a/core/tftask/tftask.go b/core/tftask/tftask.go new file mode 100644 index 0000000..f83aad6 --- /dev/null +++ b/core/tftask/tftask.go @@ -0,0 +1,64 @@ +package tftask + +import ( + "context" + "fmt" + "path/filepath" + + "github.com/krau/SaveAny-Bot/common/tdler" + "github.com/krau/SaveAny-Bot/config" + "github.com/krau/SaveAny-Bot/pkg/tfile" + "github.com/krau/SaveAny-Bot/storage" +) + +type TGFileTask struct { + ID string + Ctx context.Context + File tfile.TGFile + Storage storage.Storage + Path string + Progress ProgressTracker + client tdler.Client + stream bool // true if the file should be downloaded in stream mode + localPath string +} + +func NewTGFileTask( + id string, + ctx context.Context, + file tfile.TGFile, + client tdler.Client, + stor storage.Storage, + path string, + progress ProgressTracker, +) (*TGFileTask, error) { + _, ok := stor.(storage.StorageCannotStream) + if !config.Cfg.Stream || ok { + cachePath, err := filepath.Abs(filepath.Join(config.Cfg.Temp.BasePath, fmt.Sprintf("%s_%s", id, file.Name()))) + if err != nil { + return nil, fmt.Errorf("failed to get absolute path for cache: %w", err) + } + tftask := &TGFileTask{ + ID: id, + Ctx: ctx, + client: client, + File: file, + Storage: stor, + Path: path, + Progress: progress, + localPath: cachePath, + } + return tftask, nil + } + tfileTask := &TGFileTask{ + ID: id, + Ctx: ctx, + client: client, + File: file, + Storage: stor, + Path: path, + Progress: progress, + stream: true, + } + return tfileTask, nil +} diff --git a/core/tftask/util.go b/core/tftask/util.go new file mode 100644 index 0000000..15cff77 --- /dev/null +++ b/core/tftask/util.go @@ -0,0 +1,32 @@ +package tftask + +var progressUpdatesLevels = []struct { + size int64 // 文件大小阈值 + stepPercent int // 每多少 % 更新一次 +}{ + {10 << 20, 100}, + {50 << 20, 20}, + {200 << 20, 10}, + {500 << 20, 5}, +} + +func shouldUpdateProgress(total, downloaded int64, lastUpdatePercent int) bool { + if total <= 0 || downloaded <= 0 { + return false + } + + percent := int((downloaded * 100) / total) + if percent <= lastUpdatePercent { + return false + } + + step := progressUpdatesLevels[len(progressUpdatesLevels)-1].stepPercent + for _, lvl := range progressUpdatesLevels { + if total < lvl.size { + step = lvl.stepPercent + break + } + } + + return percent >= lastUpdatePercent+step +} diff --git a/core/tftask/writer.go b/core/tftask/writer.go new file mode 100644 index 0000000..d5df4e9 --- /dev/null +++ b/core/tftask/writer.go @@ -0,0 +1,75 @@ +package tftask + +import ( + "context" + "io" + "sync/atomic" +) + +type ProgressWriterAt struct { + ctx context.Context + wrAt io.WriterAt + progress ProgressTracker + downloaded *atomic.Int64 + total int64 + info TaskInfo +} + +func (w *ProgressWriterAt) WriteAt(p []byte, off int64) (int, error) { + at, err := w.wrAt.WriteAt(p, off) + if err != nil { + return 0, err + } + w.progress.OnProgress(w.ctx, w.info, w.downloaded.Add(int64(at)), w.total) + return at, nil +} + +func newWriterAt( + ctx context.Context, + wrAt io.WriterAt, + progress ProgressTracker, + taskInfo TaskInfo, +) *ProgressWriterAt { + return &ProgressWriterAt{ + ctx: ctx, + progress: progress, + downloaded: &atomic.Int64{}, + total: taskInfo.FileSize(), + wrAt: wrAt, + info: taskInfo, + } +} + +type ProgressWriter struct { + ctx context.Context + wrAt io.Writer + progress ProgressTracker + downloaded *atomic.Int64 + total int64 + info TaskInfo +} + +func (w *ProgressWriter) Write(p []byte) (int, error) { + at, err := w.wrAt.Write(p) + if err != nil { + return 0, err + } + w.progress.OnProgress(w.ctx, w.info, w.downloaded.Add(int64(at)), w.total) + return at, nil +} + +func newWriter( + ctx context.Context, + wr io.Writer, + progress ProgressTracker, + taskInfo TaskInfo, +) *ProgressWriter { + return &ProgressWriter{ + ctx: ctx, + progress: progress, + downloaded: &atomic.Int64{}, + total: taskInfo.FileSize(), + wrAt: wr, + info: taskInfo, + } +} diff --git a/core/tphtask/execute.go b/core/tphtask/execute.go new file mode 100644 index 0000000..53edc95 --- /dev/null +++ b/core/tphtask/execute.go @@ -0,0 +1,94 @@ +package tphtask + +import ( + "context" + "fmt" + "io" + "path" + "path/filepath" + + "github.com/charmbracelet/log" + "github.com/duke-git/lancet/v2/retry" + "github.com/krau/SaveAny-Bot/common/utils/fsutil" + "github.com/krau/SaveAny-Bot/config" + "go.uber.org/multierr" + "golang.org/x/sync/errgroup" +) + +func (t *Task) Execute(ctx context.Context) error { + logger := log.FromContext(ctx) + logger.Infof("Starting Telegraph task %s", t.PhPath) + t.progress.OnStart(ctx, t) + eg, gctx := errgroup.WithContext(ctx) + eg.SetLimit(config.Cfg.Workers) + for i, pic := range t.Pics { + pic := pic + i := i + eg.Go(func() error { + err := t.processPic(gctx, pic, i) + if err != nil { + logger.Errorf("Error processing picture %s: %v", pic, err) + return fmt.Errorf("failed to process picture %s: %w", pic, err) + } + t.downloaded.Add(1) + t.progress.OnProgress(gctx, t) + return nil + }) + } + err := eg.Wait() + if err != nil { + logger.Errorf("Error during Telegraph task execution: %v", err) + } else { + logger.Infof("Telegraph task %s completed successfully", t.PhPath) + } + t.progress.OnDone(ctx, t, err) + return err +} + +func (t *Task) processPic(ctx context.Context, picUrl string, index int) error { + retryOpts := []retry.Option{ + retry.Context(ctx), + retry.RetryTimes(uint(config.Cfg.Retry)), + } + var lastErr error + err := retry.Retry(func() error { + var body io.ReadCloser + body, lastErr = t.client.Download(ctx, picUrl) + if lastErr != nil { + lastErr = fmt.Errorf("failed to download picture %s: %w", picUrl, lastErr) + return lastErr + } + defer body.Close() + filename := fmt.Sprintf("%d%s", index+1, path.Ext(picUrl)) + if t.cannotStream { + cacheFile, err := fsutil.CreateFile(filepath.Join(config.Cfg.Temp.BasePath, + fmt.Sprintf("tph_%s_%s", t.TaskID(), filename), + )) + if err != nil { + lastErr = fmt.Errorf("failed to create cache file for picture %s: %w", filename, err) + return lastErr + } + defer func() { + if err := cacheFile.CloseAndRemove(); err != nil { + logger := log.FromContext(ctx) + logger.Errorf("Failed to close and remove cache file for picture %s: %v", filename, err) + } + }() + _, lastErr = io.Copy(cacheFile, body) + if lastErr != nil { + lastErr = fmt.Errorf("failed to copy picture %s to cache file: %w", filename, lastErr) + return lastErr + } + lastErr = t.Stor.Save(ctx, cacheFile, path.Join(t.StorPath, filename)) + } else { + lastErr = t.Stor.Save(ctx, body, path.Join(t.StorPath, filename)) + } + + if lastErr != nil { + lastErr = fmt.Errorf("failed to save picture %s: %w", filename, lastErr) + return lastErr + } + return nil + }, retryOpts...) + return multierr.Combine(err, lastErr) +} diff --git a/core/tphtask/progress.go b/core/tphtask/progress.go new file mode 100644 index 0000000..df0b7fd --- /dev/null +++ b/core/tphtask/progress.go @@ -0,0 +1,150 @@ +package tphtask + +import ( + "context" + "errors" + "fmt" + + "github.com/charmbracelet/log" + "github.com/gotd/td/telegram/message/entity" + "github.com/gotd/td/telegram/message/styling" + "github.com/gotd/td/tg" + "github.com/krau/SaveAny-Bot/common/utils/tgutil" +) + +type ProgressTracker interface { + OnStart(ctx context.Context, info TaskInfo) + OnProgress(ctx context.Context, info TaskInfo) + OnDone(ctx context.Context, info TaskInfo, err error) +} + +type Progress struct { + MessageID int + ChatID int64 +} + +func (p *Progress) OnStart(ctx context.Context, info TaskInfo) { + logger := log.FromContext(ctx) + logger.Debugf("Telegraph task progress tracking started for message %d in chat %d", p.MessageID, p.ChatID) + entityBuilder := entity.Builder{} + var entities []tg.MessageEntityClass + if err := styling.Perform(&entityBuilder, + styling.Plain("开始下载Telegraph\n图片数量: "), + styling.Code(fmt.Sprintf("%d", info.TotalPics())), + ); err != nil { + log.FromContext(ctx).Errorf("Failed to build entities: %s", err) + return + } + text, entities := entityBuilder.Complete() + req := &tg.MessagesEditMessageRequest{ + ID: p.MessageID, + } + req.SetMessage(text) + req.SetEntities(entities) + req.SetReplyMarkup(&tg.ReplyInlineMarkup{ + Rows: []tg.KeyboardButtonRow{ + { + Buttons: []tg.KeyboardButtonClass{ + tgutil.BuildCancelButton(info.TaskID()), + }, + }, + }}, + ) + ext := tgutil.ExtFromContext(ctx) + if ext != nil { + ext.EditMessage(p.ChatID, req) + return + } +} + +func (p *Progress) OnProgress(ctx context.Context, info TaskInfo) { + if !shouldUpdateProgress(info.Downloaded(), int64(info.TotalPics())) { + return + } + log.FromContext(ctx).Debugf("Progress update: %s, %d/%d", info.TaskID(), info.Downloaded(), info.TotalPics()) + entityBuilder := entity.Builder{} + var entities []tg.MessageEntityClass + if err := styling.Perform(&entityBuilder, + styling.Plain("正在下载\n当前进度: "), + styling.Code(fmt.Sprintf("%d/%d", info.Downloaded(), info.TotalPics())), + ); err != nil { + log.FromContext(ctx).Errorf("Failed to build entities: %s", err) + return + } + text, entities := entityBuilder.Complete() + req := &tg.MessagesEditMessageRequest{ + ID: p.MessageID, + } + req.SetMessage(text) + req.SetEntities(entities) + req.SetReplyMarkup(&tg.ReplyInlineMarkup{ + Rows: []tg.KeyboardButtonRow{ + { + Buttons: []tg.KeyboardButtonClass{ + tgutil.BuildCancelButton(info.TaskID()), + }, + }, + }}, + ) + ext := tgutil.ExtFromContext(ctx) + if ext != nil { + ext.EditMessage(p.ChatID, req) + return + } +} + +func (p *Progress) OnDone(ctx context.Context, info TaskInfo, err error) { + logger := log.FromContext(ctx) + if err != nil { + if errors.Is(err, context.Canceled) { + logger.Infof("Telegraph task %s was canceled", info.TaskID()) + ext := tgutil.ExtFromContext(ctx) + if ext != nil { + ext.EditMessage(p.ChatID, &tg.MessagesEditMessageRequest{ + ID: p.MessageID, + Message: fmt.Sprintf("处理已取消: %s", info.TaskID()), + }) + } + } else { + logger.Errorf("Telegraph task %s failed: %s", info.TaskID(), err) + ext := tgutil.ExtFromContext(ctx) + if ext != nil { + ext.EditMessage(p.ChatID, &tg.MessagesEditMessageRequest{ + ID: p.MessageID, + Message: fmt.Sprintf("处理失败: %s", err.Error()), + }) + } + } + return + } + logger.Infof("Telegraph task %s completed successfully", info.TaskID()) + + entityBuilder := entity.Builder{} + if err := styling.Perform(&entityBuilder, + styling.Plain("处理完成\n图片数量: "), + styling.Code(fmt.Sprintf("%d", info.TotalPics())), + styling.Plain("\n保存路径: "), + styling.Code(fmt.Sprintf("[%s]:%s", info.StorageName(), info.StoragePath())), + ); err != nil { + logger.Errorf("Failed to build entities: %s", err) + return + } + text, entities := entityBuilder.Complete() + req := &tg.MessagesEditMessageRequest{ + ID: p.MessageID, + } + req.SetMessage(text) + req.SetEntities(entities) + + ext := tgutil.ExtFromContext(ctx) + if ext != nil { + ext.EditMessage(p.ChatID, req) + } +} + +func NewProgress(messageID int, chatID int64) *Progress { + return &Progress{ + MessageID: messageID, + ChatID: chatID, + } +} diff --git a/core/tphtask/task.go b/core/tphtask/task.go new file mode 100644 index 0000000..cff20d9 --- /dev/null +++ b/core/tphtask/task.go @@ -0,0 +1,51 @@ +package tphtask + +import ( + "context" + "sync/atomic" + + "github.com/krau/SaveAny-Bot/pkg/telegraph" + "github.com/krau/SaveAny-Bot/storage" +) + +type Task struct { + ID string + Ctx context.Context + PhPath string + Pics []string + Stor storage.Storage + StorPath string + client *telegraph.Client + progress ProgressTracker + + cannotStream bool + totalpics int + downloaded atomic.Int64 +} + +func NewTask( + id string, + ctx context.Context, + phPath string, + pics []string, + stor storage.Storage, + storPath string, + client *telegraph.Client, + progress ProgressTracker, +) *Task { + _, cannotStream := stor.(storage.StorageCannotStream) + tphtask := &Task{ + ID: id, + Ctx: ctx, + PhPath: phPath, + Pics: pics, + Stor: stor, + StorPath: storPath, + client: client, + progress: progress, + cannotStream: cannotStream, + totalpics: len(pics), + downloaded: atomic.Int64{}, + } + return tphtask +} diff --git a/core/tphtask/taskinfo.go b/core/tphtask/taskinfo.go new file mode 100644 index 0000000..abc33cd --- /dev/null +++ b/core/tphtask/taskinfo.go @@ -0,0 +1,34 @@ +package tphtask + +type TaskInfo interface { + TaskID() string + Phpath() string + TotalPics() int + Downloaded() int64 + StorageName() string + StoragePath() string +} + +func (t *Task) TaskID() string { + return t.ID +} + +func (t *Task) Phpath() string { + return t.PhPath +} + +func (t *Task) TotalPics() int { + return t.totalpics +} + +func (t *Task) Downloaded() int64 { + return t.downloaded.Load() +} + +func (t *Task) StorageName() string { + return t.Stor.Name() +} + +func (t *Task) StoragePath() string { + return t.StorPath +} diff --git a/core/tphtask/utils.go b/core/tphtask/utils.go new file mode 100644 index 0000000..c2e83a1 --- /dev/null +++ b/core/tphtask/utils.go @@ -0,0 +1,13 @@ +package tphtask + +func shouldUpdateProgress(downloaded int64, total int64) bool { + if total <= 0 || downloaded <= 0 { + return false + } + + step := int64(10) + if downloaded < step { + return downloaded == total + } + return downloaded%step == 0 || downloaded == total +} diff --git a/core/utils.go b/core/utils.go deleted file mode 100644 index 4c83e76..0000000 --- a/core/utils.go +++ /dev/null @@ -1,303 +0,0 @@ -package core - -import ( - "bytes" - "context" - "encoding/json" - "fmt" - "io" - "os" - "path" - "time" - - "github.com/celestix/gotgproto/ext" - "github.com/celestix/telegraph-go/v2" - "github.com/gabriel-vasile/mimetype" - "github.com/gotd/td/telegram/message/entity" - "github.com/gotd/td/telegram/message/styling" - "github.com/gotd/td/tg" - "github.com/krau/SaveAny-Bot/bot" - "github.com/krau/SaveAny-Bot/common" - "github.com/krau/SaveAny-Bot/config" - "github.com/krau/SaveAny-Bot/storage" - "github.com/krau/SaveAny-Bot/types" - "github.com/krau/SaveAny-Bot/userclient" -) - -func saveFileWithRetry(ctx context.Context, storagePath string, taskStorage storage.Storage, cacheFilePath string) error { - file, err := os.Open(cacheFilePath) - if err != nil { - return fmt.Errorf("failed to open cache file: %w", err) - } - defer file.Close() - fileStat, err := file.Stat() - if err != nil { - return fmt.Errorf("failed to get file stat: %w", err) - } - vctx := context.WithValue(ctx, types.ContextKeyContentLength, fileStat.Size()) - for i := 0; i <= config.Cfg.Retry; i++ { - if err := vctx.Err(); err != nil { - return fmt.Errorf("context canceled while saving file: %w", err) - } - file, err := os.Open(cacheFilePath) - if err != nil { - return fmt.Errorf("failed to open cache file: %w", err) - } - defer file.Close() - if err := taskStorage.Save(vctx, file, storagePath); err != nil { - if i == config.Cfg.Retry { - return fmt.Errorf("failed to save file: %w", err) - } - common.Log.Errorf("Failed to save file: %s, retrying...", err) - select { - case <-vctx.Done(): - return fmt.Errorf("context canceled during retry delay: %w", vctx.Err()) - case <-time.After(time.Duration(i*500) * time.Millisecond): - } - continue - } - return nil - } - return nil -} - -func processPhoto(task *types.Task, taskStorage storage.Storage) error { - api := bot.Client.API() - if task.UseUserClient && userclient.UC != nil { - api = userclient.UC.API() - } - res, err := api.UploadGetFile(task.Ctx, &tg.UploadGetFileRequest{ - Location: task.File.Location, - Offset: 0, - Limit: 1024 * 1024, - }) - if err != nil { - return fmt.Errorf("failed to get file: %w", err) - } - - result, ok := res.(*tg.UploadFile) - if !ok { - return fmt.Errorf("unexpected type %T", res) - } - - common.Log.Infof("Downloaded photo: %s", task.FileName()) - - return taskStorage.Save(task.Ctx, bytes.NewReader(result.Bytes), task.StoragePath) -} - -func cleanCacheFile(destPath string) { - if config.Cfg.Temp.CacheTTL > 0 { - common.RmFileAfter(destPath, time.Duration(config.Cfg.Temp.CacheTTL)*time.Second) - } else { - if err := os.Remove(destPath); err != nil { - common.Log.Errorf("Failed to purge file: %s", err) - } - } -} - -// 获取进度需要更新的次数 -func getProgressUpdateCount(fileSize int64) int { - updateCount := 5 - if fileSize > 1024*1024*1000 { - updateCount = 50 - } else if fileSize > 1024*1024*500 { - updateCount = 20 - } else if fileSize > 1024*1024*200 { - updateCount = 10 - } - return updateCount -} - -func getSpeed(bytesRead int64, startTime time.Time) string { - if startTime.IsZero() { - return "0MB/s" - } - elapsed := time.Since(startTime) - speed := float64(bytesRead) / 1024 / 1024 / elapsed.Seconds() - return fmt.Sprintf("%.2fMB/s", speed) -} - -func buildProgressMessageEntity(task *types.Task, bytesRead int64, startTime time.Time, progress float64) (string, []tg.MessageEntityClass) { - entityBuilder := entity.Builder{} - text := fmt.Sprintf("正在处理下载任务\n文件名: %s\n保存路径: %s\n平均速度: %s\n当前进度: %.2f%%", - task.FileName(), - fmt.Sprintf("[%s]:%s", task.StorageName, task.StoragePath), - getSpeed(bytesRead, startTime), - progress, - ) - var entities []tg.MessageEntityClass - if err := styling.Perform(&entityBuilder, - styling.Plain("正在处理下载任务\n文件名: "), - styling.Code(task.FileName()), - styling.Plain("\n保存路径: "), - styling.Code(fmt.Sprintf("[%s]:%s", task.StorageName, task.StoragePath)), - styling.Plain("\n平均速度: "), - styling.Bold(getSpeed(bytesRead, task.StartTime)), - styling.Plain("\n当前进度: "), - styling.Bold(fmt.Sprintf("%.2f%%", progress)), - ); err != nil { - common.Log.Errorf("Failed to build entities: %s", err) - return text, entities - } - return entityBuilder.Complete() -} - -func buildProgressCallback(ctx *ext.Context, task *types.Task, updateCount int) func(bytesRead, contentLength int64) { - return func(bytesRead, contentLength int64) { - progress := float64(bytesRead) / float64(contentLength) * 100 - common.Log.Tracef("Downloading %s: %.2f%%", task.String(), progress) - progressInt := int(progress) - if task.File.FileSize < 1024*1024*50 || progressInt == 0 || progressInt%int(100/updateCount) != 0 { - return - } - if task.ReplyMessageID == 0 { - return - } - text, entities := buildProgressMessageEntity(task, bytesRead, task.StartTime, progress) - ctx.EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{ - Message: text, - Entities: entities, - ID: task.ReplyMessageID, - ReplyMarkup: getCancelTaskMarkup(task), - }) - } -} - -func getCancelTaskMarkup(task *types.Task) *tg.ReplyInlineMarkup { - return &tg.ReplyInlineMarkup{ - Rows: []tg.KeyboardButtonRow{{Buttons: []tg.KeyboardButtonClass{&tg.KeyboardButtonCallback{Text: "取消任务", Data: fmt.Appendf(nil, "cancel %s", task.Key())}}}}, - } -} - -func fixTaskFileExt(task *types.Task, localFilePath string) { - if path.Ext(task.FileName()) == "" { - mimeType, err := mimetype.DetectFile(localFilePath) - if err != nil { - common.Log.Errorf("Failed to detect mime type: %s", err) - } else { - task.File.FileName = fmt.Sprintf("%s%s", task.FileName(), mimeType.Extension()) - task.StoragePath = fmt.Sprintf("%s%s", task.StoragePath, mimeType.Extension()) - } - } -} - -func getTaskThreads(fileSize int64) int { - threads := 1 - if fileSize > 1024*1024*100 { - threads = config.Cfg.Threads - } else if fileSize > 1024*1024*50 { - threads = config.Cfg.Threads / 2 - } - return threads -} - -type TaskLocalFile struct { - file *os.File - size int64 - done int64 - progressCallback func(bytesRead, contentLength int64) - callbackTimes int64 - nextCallbackAt int64 - callbackInterval int64 -} - -func (t *TaskLocalFile) Read(p []byte) (n int, err error) { - return t.file.Read(p) -} - -func (t *TaskLocalFile) Close() error { - return t.file.Close() -} -func (t *TaskLocalFile) WriteAt(p []byte, off int64) (int, error) { - n, err := t.file.WriteAt(p, off) - if err != nil { - return n, err - } - t.done += int64(n) - if t.progressCallback != nil && t.done >= t.nextCallbackAt { - t.progressCallback(t.done, t.size) - t.nextCallbackAt += t.callbackInterval - } - return n, nil -} - -func NewTaskLocalFile(filePath string, fileSize int64, progressCallback func(bytesRead, contentLength int64)) (*TaskLocalFile, error) { - file, err := os.Create(filePath) - if err != nil { - return nil, fmt.Errorf("failed to open file: %w", err) - } - var callbackInterval int64 - callbackInterval = fileSize / 100 - if callbackInterval == 0 { - callbackInterval = 1 - } - return &TaskLocalFile{ - file: file, - size: fileSize, - progressCallback: progressCallback, - callbackTimes: 100, - nextCallbackAt: callbackInterval, - callbackInterval: callbackInterval, - }, nil -} - -type ProgressStream struct { - writer io.Writer - size int64 - done int64 - callback func(bytesRead, contentLength int64) - nextAt int64 - interval int64 -} - -func (ps *ProgressStream) Write(p []byte) (n int, err error) { - n, err = ps.writer.Write(p) - if err != nil { - return n, err - } - ps.done += int64(n) - if ps.callback != nil && ps.done >= ps.nextAt { - ps.callback(ps.done, ps.size) - ps.nextAt += ps.interval - } - return n, nil -} - -func NewProgressStream(writer io.Writer, size int64, callback func(bytesRead, contentLength int64)) *ProgressStream { - var interval int64 - interval = size / 100 - if interval == 0 { - interval = 1 - } - return &ProgressStream{ - writer: writer, - size: size, - callback: callback, - nextAt: interval, - interval: interval, - } -} - -func getNodeImages(node telegraph.Node) []string { - var srcs []string - - var nodeElement telegraph.NodeElement - data, err := json.Marshal(node) - if err != nil { - return srcs - } - err = json.Unmarshal(data, &nodeElement) - if err != nil { - return srcs - } - - if nodeElement.Tag == "img" { - if src, exists := nodeElement.Attrs["src"]; exists { - srcs = append(srcs, src) - } - } - for _, child := range nodeElement.Children { - srcs = append(srcs, getNodeImages(child)...) - } - return srcs -} diff --git a/dao/callback_data.go b/dao/callback_data.go deleted file mode 100644 index 7b96e65..0000000 --- a/dao/callback_data.go +++ /dev/null @@ -1,19 +0,0 @@ -package dao - -func CreateCallbackData(data string) (uint, error) { - callbackData := CallbackData{ - Data: data, - } - err := db.Create(&callbackData).Error - return callbackData.ID, err -} - -func GetCallbackData(id uint) (string, error) { - var callbackData CallbackData - err := db.First(&callbackData, id).Error - return callbackData.Data, err -} - -func DeleteCallbackData(id uint) error { - return db.Unscoped().Where("id = ?", id).Delete(&CallbackData{}).Error -} diff --git a/dao/db.go b/dao/db.go deleted file mode 100644 index a0e516e..0000000 --- a/dao/db.go +++ /dev/null @@ -1,116 +0,0 @@ -package dao - -import ( - "errors" - "fmt" - "os" - "path/filepath" - "time" - - "github.com/krau/SaveAny-Bot/common" - "github.com/krau/SaveAny-Bot/config" - _ "github.com/ncruces/go-sqlite3/embed" - "github.com/ncruces/go-sqlite3/gormlite" - "gorm.io/gorm" - glogger "gorm.io/gorm/logger" -) - -var db *gorm.DB - -func Init() { - if err := os.MkdirAll(filepath.Dir(config.Cfg.DB.Path), 0755); err != nil { - common.Log.Panic("Failed to create data directory: ", err) - } - var err error - db, err = gorm.Open(gormlite.Open(config.Cfg.DB.Path), &gorm.Config{ - Logger: glogger.New(common.Log, glogger.Config{ - Colorful: true, - SlowThreshold: time.Second * 5, - LogLevel: glogger.Error, - IgnoreRecordNotFoundError: true, - ParameterizedQueries: true, - }), - PrepareStmt: true, - }) - if err != nil { - common.Log.Panic("Failed to open database: ", err) - } - common.Log.Debug("Database connected") - if err := db.AutoMigrate(&ReceivedFile{}, &User{}, &Dir{}, &CallbackData{}, &Rule{}); err != nil { - common.Log.Panic("迁移数据库失败, 如果您从旧版本升级, 建议手动删除数据库文件后重试: ", err) - } - if err := syncUsers(); err != nil { - common.Log.Panic("Failed to sync users:", err) - } - common.Log.Debug("Database migrated") - if config.Cfg.DB.Expire == 0 { - return - } - if err := cleanExpiredData(db); err != nil { - common.Log.Error("Failed to clean expired data: ", err) - } else { - common.Log.Debug("Cleaned expired data") - } - go cleanJob(db) -} - -func syncUsers() error { - dbUsers, err := GetAllUsers() - if err != nil { - return fmt.Errorf("failed to get users: %w", err) - } - - dbUserMap := make(map[int64]User) - for _, u := range dbUsers { - dbUserMap[u.ChatID] = u - } - - cfgUserMap := make(map[int64]struct{}) - for _, u := range config.Cfg.Users { - cfgUserMap[u.ID] = struct{}{} - } - - for cfgID := range cfgUserMap { - if _, exists := dbUserMap[cfgID]; !exists { - if err := CreateUser(cfgID); err != nil { - return fmt.Errorf("failed to create user %d: %w", cfgID, err) - } - common.Log.Infof("创建用户: %d", cfgID) - } - } - - for dbID, dbUser := range dbUserMap { - if _, exists := cfgUserMap[dbID]; !exists { - if err := DeleteUser(&dbUser); err != nil { - return fmt.Errorf("failed to delete user %d: %w", dbID, err) - } - common.Log.Infof("删除用户: %d", dbID) - } - } - - return nil -} - -func cleanExpiredData(db *gorm.DB) error { - var fileErr error - if err := db.Where("updated_at < ?", time.Now().Add(-time.Duration(config.Cfg.DB.Expire)*time.Second)).Unscoped().Delete(&ReceivedFile{}).Error; err != nil { - fileErr = fmt.Errorf("failed to delete expired files: %w", err) - } - var cbErr error - if err := db.Where("updated_at < ?", time.Now().Add(-time.Duration(config.Cfg.DB.Expire)*time.Second)).Unscoped().Delete(&CallbackData{}).Error; err != nil { - cbErr = fmt.Errorf("failed to delete expired callback data: %w", err) - } - return errors.Join(fileErr, cbErr) -} - -func cleanJob(db *gorm.DB) { - tick := time.NewTicker(time.Duration(config.Cfg.DB.Expire) * time.Second) - defer tick.Stop() - for range tick.C { - if err := cleanExpiredData(db); err != nil { - common.Log.Error("Failed to clean expired data: ", err) - } else { - common.Log.Debug("Cleaned expired data") - } - } -} diff --git a/dao/dir.go b/dao/dir.go deleted file mode 100644 index 6b5fcb2..0000000 --- a/dao/dir.go +++ /dev/null @@ -1,47 +0,0 @@ -package dao - -func CreateDirForUser(userID uint, storageName, path string) error { - dir := Dir{ - UserID: userID, - StorageName: storageName, - Path: path, - } - return db.Create(&dir).Error -} - -func GetDirByID(id uint) (*Dir, error) { - dir := &Dir{} - err := db.First(dir, id).Error - if err != nil { - return nil, err - } - return dir, err -} - -func GetUserDirs(userID uint) ([]Dir, error) { - var dirs []Dir - err := db.Where("user_id = ?", userID).Find(&dirs).Error - return dirs, err -} - -func GetUserDirsByChatID(chatID int64) ([]Dir, error) { - user, err := GetUserByChatID(chatID) - if err != nil { - return nil, err - } - return GetUserDirs(user.ID) -} - -func GetDirsByUserIDAndStorageName(userID uint, storageName string) ([]Dir, error) { - var dirs []Dir - err := db.Where("user_id = ? AND storage_name = ?", userID, storageName).Find(&dirs).Error - return dirs, err -} - -func DeleteDirForUser(userID uint, storageName, path string) error { - return db.Unscoped().Where("user_id = ? AND storage_name = ? AND path = ?", userID, storageName, path).Delete(&Dir{}).Error -} - -func DeleteDirByID(id uint) error { - return db.Unscoped().Delete(&Dir{}, id).Error -} \ No newline at end of file diff --git a/dao/file.go b/dao/file.go deleted file mode 100644 index 01fad70..0000000 --- a/dao/file.go +++ /dev/null @@ -1,36 +0,0 @@ -package dao - -func SaveReceivedFile(receivedFile *ReceivedFile) (*ReceivedFile, error) { - record, err := GetReceivedFileByChatAndMessageID(receivedFile.ChatID, receivedFile.MessageID) - if err == nil { - receivedFile.ID = record.ID - } - db.Save(receivedFile) - return receivedFile, db.Error -} - -func BatchSaveReceivedFiles(receivedFiles []*ReceivedFile) error { - if len(receivedFiles) == 0 { - return nil - } - for _, file := range receivedFiles { - record, err := GetReceivedFileByChatAndMessageID(file.ChatID, file.MessageID) - if err == nil { - file.ID = record.ID - } - } - return db.Save(receivedFiles).Error -} - -func GetReceivedFileByChatAndMessageID(chatID int64, messageID int) (*ReceivedFile, error) { - var receivedFile ReceivedFile - err := db.Where("chat_id = ? AND message_id = ?", chatID, messageID).First(&receivedFile).Error - if err != nil { - return nil, err - } - return &receivedFile, nil -} - -func DeleteReceivedFile(receivedFile *ReceivedFile) error { - return db.Unscoped().Delete(receivedFile).Error -} diff --git a/dao/model.go b/dao/model.go deleted file mode 100644 index 7c14402..0000000 --- a/dao/model.go +++ /dev/null @@ -1,51 +0,0 @@ -package dao - -import ( - "gorm.io/gorm" -) - -type ReceivedFile struct { - gorm.Model - Processing bool - // Which chat the file is from - ChatID int64 `gorm:"uniqueIndex:idx_chat_id_message_id;not null"` - // Which message the file is from - MessageID int `gorm:"uniqueIndex:idx_chat_id_message_id;not null"` - ReplyMessageID int - ReplyChatID int64 - FileName string - IsTelegraph bool - TelegraphURL string - UseUserClient bool // Whether to use userbot client to fetch the file -} - -type User struct { - gorm.Model - ChatID int64 `gorm:"uniqueIndex;not null"` - Silent bool - DefaultStorage string // Default storage name - Dirs []Dir - ApplyRule bool - Rules []Rule -} - -type Dir struct { - gorm.Model - UserID uint - StorageName string - Path string -} - -type CallbackData struct { - gorm.Model - Data string -} - -type Rule struct { - gorm.Model - UserID uint - Type string - Data string - StorageName string - DirPath string -} diff --git a/dao/rule.go b/dao/rule.go deleted file mode 100644 index 452c5b2..0000000 --- a/dao/rule.go +++ /dev/null @@ -1,22 +0,0 @@ -package dao - -func CreateRule(rule *Rule) error { - return db.Create(rule).Error -} - -func DeleteRule(ruleID uint) error { - return db.Unscoped().Delete(&Rule{}, ruleID).Error -} - -func UpdateUserApplyRule(chatID int64, applyRule bool) error { - return db.Model(&User{}).Where("chat_id = ?", chatID).Update("apply_rule", applyRule).Error -} - -func GetRulesByUserChatID(chatID int64) ([]Rule, error) { - var rules []Rule - err := db.Where("user_id = (SELECT id FROM users WHERE chat_id = ?)", chatID).Find(&rules).Error - if err != nil { - return nil, err - } - return rules, nil -} diff --git a/dao/user.go b/dao/user.go deleted file mode 100644 index e3d010e..0000000 --- a/dao/user.go +++ /dev/null @@ -1,33 +0,0 @@ -package dao - -func CreateUser(chatID int64) error { - if _, err := GetUserByChatID(chatID); err == nil { - return nil - } - return db.Create(&User{ChatID: chatID}).Error -} - -func GetAllUsers() ([]User, error) { - var users []User - err := db.Preload("Dirs"). - Preload("Rules"). - Find(&users).Error - return users, err -} - -func GetUserByChatID(chatID int64) (*User, error) { - var user User - err := db. - Preload("Dirs"). - Preload("Rules"). - Where("chat_id = ?", chatID).First(&user).Error - return &user, err -} - -func UpdateUser(user *User) error { - return db.Save(user).Error -} - -func DeleteUser(user *User) error { - return db.Unscoped().Select("Dirs", "Rules").Delete(user).Error -} diff --git a/database/db.go b/database/db.go new file mode 100644 index 0000000..afba26c --- /dev/null +++ b/database/db.go @@ -0,0 +1,86 @@ +package database + +import ( + "context" + "fmt" + "os" + "path/filepath" + "time" + + "github.com/charmbracelet/log" + "github.com/krau/SaveAny-Bot/config" + _ "github.com/ncruces/go-sqlite3/embed" + "github.com/ncruces/go-sqlite3/gormlite" + "gorm.io/gorm" + glogger "gorm.io/gorm/logger" +) + +var db *gorm.DB + +func Init(ctx context.Context) { + logger := log.FromContext(ctx) + if err := os.MkdirAll(filepath.Dir(config.Cfg.DB.Path), 0755); err != nil { + logger.Fatal("Failed to create data directory: ", err) + } + var err error + db, err = gorm.Open(gormlite.Open(config.Cfg.DB.Path), &gorm.Config{ + Logger: glogger.New(logger, glogger.Config{ + Colorful: true, + SlowThreshold: time.Second * 5, + LogLevel: glogger.Error, + IgnoreRecordNotFoundError: true, + ParameterizedQueries: true, + }), + PrepareStmt: true, + }) + if err != nil { + logger.Fatal("Failed to open database: ", err) + } + logger.Debug("Database connected") + if err := db.AutoMigrate(&User{}, &Dir{}, &Rule{}); err != nil { + logger.Fatal("迁移数据库失败, 如果您从旧版本升级, 建议手动删除数据库文件后重试: ", err) + } + if err := syncUsers(ctx); err != nil { + logger.Fatal("Failed to sync users:", err) + } + logger.Debug("Database migrated") + logger.Info("Database initialized") +} + +func syncUsers(ctx context.Context) error { + logger := log.FromContext(ctx) + dbUsers, err := GetAllUsers(ctx) + if err != nil { + return fmt.Errorf("failed to get users: %w", err) + } + + dbUserMap := make(map[int64]User) + for _, u := range dbUsers { + dbUserMap[u.ChatID] = u + } + + cfgUserMap := make(map[int64]struct{}) + for _, u := range config.Cfg.Users { + cfgUserMap[u.ID] = struct{}{} + } + + for cfgID := range cfgUserMap { + if _, exists := dbUserMap[cfgID]; !exists { + if err := CreateUser(ctx, cfgID); err != nil { + return fmt.Errorf("failed to create user %d: %w", cfgID, err) + } + logger.Infof("创建用户: %d", cfgID) + } + } + + for dbID, dbUser := range dbUserMap { + if _, exists := cfgUserMap[dbID]; !exists { + if err := DeleteUser(ctx, &dbUser); err != nil { + return fmt.Errorf("failed to delete user %d: %w", dbID, err) + } + logger.Infof("删除用户: %d", dbID) + } + } + + return nil +} diff --git a/database/dir.go b/database/dir.go new file mode 100644 index 0000000..3ae65dd --- /dev/null +++ b/database/dir.go @@ -0,0 +1,57 @@ +package database + +import "context" + +func CreateDirForUser(ctx context.Context, userID uint, storageName, path string) error { + dir := Dir{ + UserID: userID, + StorageName: storageName, + Path: path, + } + return db.WithContext(ctx).Create(&dir).Error +} + +func GetDirByID(ctx context.Context, id uint) (*Dir, error) { + dir := &Dir{} + err := db.WithContext(ctx).First(dir, id).Error + if err != nil { + return nil, err + } + return dir, err +} + +func GetUserDirs(ctx context.Context, userID uint) ([]Dir, error) { + var dirs []Dir + err := db.WithContext(ctx).Where("user_id = ?", userID).Find(&dirs).Error + return dirs, err +} + +func GetUserDirsByChatID(ctx context.Context, chatID int64) ([]Dir, error) { + user, err := GetUserByChatID(ctx, chatID) + if err != nil { + return nil, err + } + return GetUserDirs(ctx, user.ID) +} + +func GetDirsByUserIDAndStorageName(ctx context.Context, userID uint, storageName string) ([]Dir, error) { + var dirs []Dir + err := db.WithContext(ctx).Where("user_id = ? AND storage_name = ?", userID, storageName).Find(&dirs).Error + return dirs, err +} + +func GetDirsByUserChatIDAndStorageName(ctx context.Context, chatID int64, storageName string) ([]Dir, error) { + user, err := GetUserByChatID(ctx, chatID) + if err != nil { + return nil, err + } + return GetDirsByUserIDAndStorageName(ctx, user.ID, storageName) +} + +func DeleteDirForUser(ctx context.Context, userID uint, storageName, path string) error { + return db.WithContext(ctx).Unscoped().Where("user_id = ? AND storage_name = ? AND path = ?", userID, storageName, path).Delete(&Dir{}).Error +} + +func DeleteDirByID(ctx context.Context, id uint) error { + return db.WithContext(ctx).Unscoped().Delete(&Dir{}, id).Error +} diff --git a/database/model.go b/database/model.go new file mode 100644 index 0000000..dc455ea --- /dev/null +++ b/database/model.go @@ -0,0 +1,31 @@ +package database + +import ( + "gorm.io/gorm" +) + +type User struct { + gorm.Model + ChatID int64 `gorm:"uniqueIndex;not null"` + Silent bool + DefaultStorage string + Dirs []Dir + ApplyRule bool + Rules []Rule +} + +type Dir struct { + gorm.Model + UserID uint + StorageName string + Path string +} + +type Rule struct { + gorm.Model + UserID uint + Type string + Data string + StorageName string + DirPath string +} diff --git a/database/rule.go b/database/rule.go new file mode 100644 index 0000000..937e4ab --- /dev/null +++ b/database/rule.go @@ -0,0 +1,24 @@ +package database + +import "context" + +func CreateRule(ctx context.Context, rule *Rule) error { + return db.WithContext(ctx).Create(rule).Error +} + +func DeleteRule(ctx context.Context, ruleID uint) error { + return db.WithContext(ctx).Unscoped().Delete(&Rule{}, ruleID).Error +} + +func UpdateUserApplyRule(ctx context.Context, chatID int64, applyRule bool) error { + return db.WithContext(ctx).Model(&User{}).Where("chat_id = ?", chatID).Update("apply_rule", applyRule).Error +} + +func GetRulesByUserChatID(ctx context.Context, chatID int64) ([]Rule, error) { + var rules []Rule + err := db.WithContext(ctx).Where("user_id = (SELECT id FROM users WHERE chat_id = ?)", chatID).Find(&rules).Error + if err != nil { + return nil, err + } + return rules, nil +} diff --git a/database/user.go b/database/user.go new file mode 100644 index 0000000..62953bf --- /dev/null +++ b/database/user.go @@ -0,0 +1,40 @@ +package database + +import "context" + +func CreateUser(ctx context.Context, chatID int64) error { + if _, err := GetUserByChatID(ctx, chatID); err == nil { + return nil + } + return db.Create(&User{ChatID: chatID}).Error +} + +func GetAllUsers(ctx context.Context) ([]User, error) { + var users []User + err := db.Preload("Dirs"). + WithContext(ctx). + Preload("Rules"). + Find(&users).Error + return users, err +} + +func GetUserByChatID(ctx context.Context, chatID int64) (*User, error) { + var user User + err := db. + Preload("Dirs"). + WithContext(ctx). + Preload("Rules"). + Where("chat_id = ?", chatID).First(&user).Error + return &user, err +} + +func UpdateUser(ctx context.Context, user *User) error { + if _, err := GetUserByChatID(ctx, user.ChatID); err != nil { + return err + } + return db.WithContext(ctx).Save(user).Error +} + +func DeleteUser(ctx context.Context, user *User) error { + return db.WithContext(ctx).Unscoped().Select("Dirs", "Rules").Delete(user).Error +} diff --git a/go.mod b/go.mod index 4aa60b6..775c807 100644 --- a/go.mod +++ b/go.mod @@ -5,20 +5,17 @@ go 1.23.5 require ( github.com/blang/semver v3.5.1+incompatible github.com/celestix/gotgproto v1.0.0-beta21 - github.com/celestix/telegraph-go/v2 v2.0.4 github.com/cenkalti/backoff/v4 v4.3.0 github.com/charmbracelet/huh v0.7.0 github.com/charmbracelet/log v0.4.2 - github.com/eko/gocache/lib/v4 v4.2.0 - github.com/eko/gocache/store/go_cache/v4 v4.2.2 github.com/fatih/color v1.18.0 github.com/gabriel-vasile/mimetype v1.4.9 github.com/go-faster/errors v0.7.1 - github.com/gookit/slog v0.5.8 github.com/gotd/contrib v0.21.0 github.com/gotd/td v0.125.0 github.com/minio/minio-go/v7 v7.0.92 github.com/rhysd/go-github-selfupdate v1.2.3 + github.com/rs/xid v1.6.0 github.com/spf13/cobra v1.9.1 github.com/spf13/viper v1.20.1 golang.org/x/net v0.41.0 @@ -29,7 +26,6 @@ require ( github.com/AnimeKaizoku/cacher v1.0.2 // indirect github.com/atotto/clipboard v0.1.4 // indirect github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect - github.com/beorn7/perks v1.0.1 // indirect github.com/catppuccin/go v0.3.0 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/charmbracelet/bubbles v0.21.0 // indirect @@ -53,7 +49,6 @@ require ( github.com/go-logfmt/logfmt v0.6.0 // indirect github.com/go-viper/mapstructure/v2 v2.2.1 // indirect github.com/goccy/go-json v0.10.5 // indirect - github.com/golang/mock v1.6.0 // indirect github.com/google/go-github/v30 v30.1.0 // indirect github.com/google/go-querystring v1.1.0 // indirect github.com/google/uuid v1.6.0 // indirect @@ -74,20 +69,14 @@ require ( github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 // indirect github.com/muesli/cancelreader v0.2.2 // indirect github.com/muesli/termenv v0.16.0 // indirect - github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect github.com/ncruces/go-strftime v0.1.9 // indirect github.com/ncruces/julianday v1.0.0 // indirect github.com/ogen-go/ogen v1.14.0 // indirect github.com/onsi/gomega v1.36.2 // indirect github.com/philhofer/fwd v1.1.3-0.20240916144458-20a13a1f6b7c // indirect github.com/pkg/errors v0.9.1 // indirect - github.com/prometheus/client_golang v1.22.0 // indirect - github.com/prometheus/client_model v0.6.2 // indirect - github.com/prometheus/common v0.64.0 // indirect - github.com/prometheus/procfs v0.16.1 // indirect github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect github.com/rivo/uniseg v0.4.7 // indirect - github.com/rs/xid v1.6.0 // indirect github.com/segmentio/asm v1.2.0 // indirect github.com/tcnksm/go-gitconfig v0.1.2 // indirect github.com/tetratelabs/wazero v1.9.0 // indirect @@ -97,14 +86,12 @@ require ( go.opentelemetry.io/otel/metric v1.36.0 // indirect go.opentelemetry.io/otel/trace v1.36.0 // indirect go.uber.org/atomic v1.11.0 // indirect - go.uber.org/mock v0.5.2 // indirect go.uber.org/zap v1.27.0 // indirect golang.org/x/crypto v0.39.0 // indirect golang.org/x/mod v0.25.0 // indirect golang.org/x/oauth2 v0.30.0 // indirect - golang.org/x/term v0.32.0 // indirect golang.org/x/tools v0.34.0 // indirect - google.golang.org/protobuf v1.36.6 // indirect + gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect gopkg.in/yaml.v2 v2.4.0 // indirect modernc.org/libc v1.65.10 // indirect modernc.org/mathutil v1.7.1 // indirect @@ -114,19 +101,16 @@ require ( ) require ( + github.com/dgraph-io/ristretto/v2 v2.2.0 github.com/duke-git/lancet/v2 v2.3.6 github.com/fsnotify/fsnotify v1.9.0 // indirect - github.com/glebarez/sqlite v1.11.0 - github.com/gookit/color v1.5.4 // indirect - github.com/gookit/goutil v0.6.18 // indirect - github.com/gookit/gsr v0.1.1 // indirect + github.com/glebarez/sqlite v1.11.0 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/klauspost/compress v1.18.0 // indirect github.com/mitchellh/mapstructure v1.5.0 github.com/ncruces/go-sqlite3 v0.26.1 github.com/ncruces/go-sqlite3/gormlite v0.24.0 github.com/nicksnyder/go-i18n/v2 v2.6.0 - github.com/patrickmn/go-cache v2.1.0+incompatible github.com/pelletier/go-toml/v2 v2.2.4 github.com/sagikazarmark/locafero v0.9.0 // indirect github.com/sourcegraph/conc v0.3.0 // indirect @@ -134,9 +118,8 @@ require ( github.com/spf13/cast v1.9.2 // indirect github.com/spf13/pflag v1.0.6 // indirect github.com/subosito/gotenv v1.6.0 // indirect - github.com/valyala/bytebufferpool v1.0.0 // indirect github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect - go.uber.org/multierr v1.11.0 // indirect + go.uber.org/multierr v1.11.0 golang.org/x/exp v0.0.0-20250606033433-dcc06ee1d476 // indirect golang.org/x/sync v0.15.0 golang.org/x/sys v0.33.0 // indirect diff --git a/go.sum b/go.sum index 5a46c31..32f47cd 100644 --- a/go.sum +++ b/go.sum @@ -10,16 +10,12 @@ github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiE github.com/aymanbagabas/go-osc52/v2 v2.0.1/go.mod h1:uYgXzlJ7ZpABp8OJ+exZzJJhRNQ2ASbcXHWsFqH8hp8= github.com/aymanbagabas/go-udiff v0.2.0 h1:TK0fH4MteXUDspT88n8CKzvK0X9O2xu9yQjWpi6yML8= github.com/aymanbagabas/go-udiff v0.2.0/go.mod h1:RE4Ex0qsGkTAJoQdQQCA0uG+nAzJO/pI/QwceO5fgrA= -github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= -github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= github.com/blang/semver v3.5.1+incompatible h1:cQNTCjp13qL8KC3Nbxr/y2Bqb63oX6wdnnjpJbkM4JQ= github.com/blang/semver v3.5.1+incompatible/go.mod h1:kRBLl5iJ+tD4TcOOxsy/0fnwebNt5EWlYSAyrTnjyyk= github.com/catppuccin/go v0.3.0 h1:d+0/YicIq+hSTo5oPuRi5kOpqkVA5tAsU6dNhvRu+aY= github.com/catppuccin/go v0.3.0/go.mod h1:8IHJuMGaUUjQM82qBrGNBv7LFq6JI3NnQCF6MOlZjpc= github.com/celestix/gotgproto v1.0.0-beta21 h1:VUuAC/Kj5Sdu/WZan3ZUb0GFNAavFxMYxmHAhCBX0J8= github.com/celestix/gotgproto v1.0.0-beta21/go.mod h1:viDkHe9rBegJoEE/jNuFfbBM0XZ3pSx/ugjaNaVnbvU= -github.com/celestix/telegraph-go/v2 v2.0.4 h1:w8HWymJFhMSMPjdGoyTh3/NqE3eXAT1njTvelh0338k= -github.com/celestix/telegraph-go/v2 v2.0.4/go.mod h1:vu2LtqM7MgOAJ2LDF8XK27DWdd1QYLBfZGhalEh086Y= github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK343L8= github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE= github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= @@ -61,16 +57,16 @@ github.com/creack/pty v1.1.24 h1:bJrF4RRfyJnbTJqzRLHzcGaZK1NeM5kTC9jGgovnR1s= github.com/creack/pty v1.1.24/go.mod h1:08sCNb52WyoAwi2QDyzUCTgcvVFhUzewun7wtTfvcwE= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dgraph-io/ristretto/v2 v2.2.0 h1:bkY3XzJcXoMuELV8F+vS8kzNgicwQFAaGINAEJdWGOM= +github.com/dgraph-io/ristretto/v2 v2.2.0/go.mod h1:RZrm63UmcBAaYWC1DotLYBmTvgkrs0+XhBd7Npn7/zI= +github.com/dgryski/go-farm v0.0.0-20240924180020-3414d57e47da h1:aIftn67I1fkbMa512G+w+Pxci9hJPB8oMnkcP3iZF38= +github.com/dgryski/go-farm v0.0.0-20240924180020-3414d57e47da/go.mod h1:SqUrOPUnsFjfmXRMNPybcSiG0BgUW2AuFH8PAnS2iTw= github.com/dlclark/regexp2 v1.11.5 h1:Q/sSnsKerHeCkc/jSTNq1oCm7KiVgUMZRDUoRu0JQZQ= github.com/dlclark/regexp2 v1.11.5/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= github.com/duke-git/lancet/v2 v2.3.6 h1:NKxSSh+dlgp37funvxLCf3xLBeUYa7VW1thYQP6j3Y8= github.com/duke-git/lancet/v2 v2.3.6/go.mod h1:zGa2R4xswg6EG9I6WnyubDbFO/+A/RROxIbXcwryTsc= github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= -github.com/eko/gocache/lib/v4 v4.2.0 h1:MNykyi5Xw+5Wu3+PUrvtOCaKSZM1nUSVftbzmeC7Yuw= -github.com/eko/gocache/lib/v4 v4.2.0/go.mod h1:7ViVmbU+CzDHzRpmB4SXKyyzyuJ8A3UW3/cszpcqB4M= -github.com/eko/gocache/store/go_cache/v4 v4.2.2 h1:tAI9nl6TLoJyKG1ujF0CS0n/IgTEMl+NivxtR5R3/hw= -github.com/eko/gocache/store/go_cache/v4 v4.2.2/go.mod h1:T9zkHokzr8K9EiC7RfMbDg6HSwaV6rv3UdcNu13SGcA= github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f h1:Y/CXytFA4m6baUTXGLOoWe4PQhGxaX0KpnayAqC48p4= github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f/go.mod h1:vw97MGsxSvLiUE2X8qFplwetxpGLQrlU1Q9AUEIzCaM= github.com/fatih/color v1.18.0 h1:S8gINlzdQ840/4pfAwic/ZE0djQEH3wM94VfqLTZcOM= @@ -109,8 +105,6 @@ github.com/go-viper/mapstructure/v2 v2.2.1 h1:ZAaOCxANMuZx5RCeg0mBdEZk7DZasvvZIx github.com/go-viper/mapstructure/v2 v2.2.1/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM= github.com/goccy/go-json v0.10.5 h1:Fq85nIqj+gXn/S5ahsiTlK3TmC85qgirsdTP/+DeaC4= github.com/goccy/go-json v0.10.5/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M= -github.com/golang/mock v1.6.0 h1:ErTB+efbowRARo13NNdxyJji2egdxLGQhRaY+DUumQc= -github.com/golang/mock v1.6.0/go.mod h1:p6yTPP+5HYm5mzsMV8JkE6ZKdX+/wYM6Hr+LicevLPs= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= @@ -125,14 +119,6 @@ github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17k github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= -github.com/gookit/color v1.5.4 h1:FZmqs7XOyGgCAxmWyPslpiok1k05wmY3SJTytgvYFs0= -github.com/gookit/color v1.5.4/go.mod h1:pZJOeOS8DM43rXbp4AZo1n9zCU2qjpcRko0b6/QJi9w= -github.com/gookit/goutil v0.6.18 h1:MUVj0G16flubWT8zYVicIuisUiHdgirPAkmnfD2kKgw= -github.com/gookit/goutil v0.6.18/go.mod h1:AY/5sAwKe7Xck+mEbuxj0n/bc3qwrGNe3Oeulln7zBA= -github.com/gookit/gsr v0.1.1 h1:TaHD3M7qa6lcAf9D2J4mGNg+QjgDtD1bw7uctF8RXOM= -github.com/gookit/gsr v0.1.1/go.mod h1:7wv4Y4WCnil8+DlDYHBjidzrEzfHhXEoFjEA0pPPWpI= -github.com/gookit/slog v0.5.8 h1:XZCeHLQvvOZWcSUDZcqxXITsL9+d1ESsKZoASBmK1lI= -github.com/gookit/slog v0.5.8/go.mod h1:s0ViFOY/IgUuT4MDPF0l9x5/npcciy8pL4xwWZadnoc= github.com/gotd/contrib v0.21.0 h1:4Fj05jnyBE84toXZl7mVTvt7f732n5uglvztyG6nTr4= github.com/gotd/contrib v0.21.0/go.mod h1:ENoUh75IhHGxfz/puVJg8BU4ZF89yrL6Q47TyoNqFYo= github.com/gotd/ige v0.2.2 h1:XQ9dJZwBfDnOGSTxKXBGP4gMud3Qku2ekScRjDWWfEk= @@ -156,14 +142,13 @@ github.com/klauspost/cpuid/v2 v2.0.1/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa02 github.com/klauspost/cpuid/v2 v2.2.10 h1:tBs3QSyvjDyFTq3uoc/9xFpCuOsJQFNPiAhYdw2skhE= github.com/klauspost/cpuid/v2 v2.2.10/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= +github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= -github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= -github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= github.com/lucasb-eyer/go-colorful v1.2.0 h1:1nnpGOrhyZZuNyfu1QjKiUICQ74+3FNCN69Aj6K7nkY= github.com/lucasb-eyer/go-colorful v1.2.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0= github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE= @@ -190,8 +175,6 @@ github.com/muesli/cancelreader v0.2.2 h1:3I4Kt4BQjOR54NavqnDogx/MIoWBFa0StPA8ELU github.com/muesli/cancelreader v0.2.2/go.mod h1:3XuTXfFS2VjM+HTLZY9Ak0l6eUKfijIfMUZ4EgX0QYo= github.com/muesli/termenv v0.16.0 h1:S5AlUN9dENB57rsbnkPyfdGuWIlkmzJjbFf0Tf5FWUc= github.com/muesli/termenv v0.16.0/go.mod h1:ZRfOIKPFDYQoDFF4Olj7/QJbW60Ol/kL1pU3VfY/Cnk= -github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA= -github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ= github.com/ncruces/go-sqlite3 v0.26.1 h1:lBXmbmucH1Bsj57NUQR6T84UoMN7jnNImhF+ibEITJU= github.com/ncruces/go-sqlite3 v0.26.1/go.mod h1:XFTPtFIo1DmGCh+XVP8KGn9b/o2f+z0WZuT09x2N6eo= github.com/ncruces/go-sqlite3/gormlite v0.24.0 h1:81sHeq3CCdhjoqAB650n5wEdRlLO9VBvosArskcN3+c= @@ -208,8 +191,6 @@ github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+W github.com/onsi/gomega v1.4.2/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY= github.com/onsi/gomega v1.36.2 h1:koNYke6TVk6ZmnyHrCXba/T/MoLBXFjeC1PtvYgw0A8= github.com/onsi/gomega v1.36.2/go.mod h1:DdwyADRjrc825LhMEkD76cHR5+pUnjhUN8GlHlRPHzY= -github.com/patrickmn/go-cache v2.1.0+incompatible h1:HRMgzkcYKYpi3C8ajMPV8OFXaaRUnok+kx1WdO15EQc= -github.com/patrickmn/go-cache v2.1.0+incompatible/go.mod h1:3Qf8kWWT7OJRJbdiICTKqZju1ZixQ/KpMGzzAfe6+WQ= github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0t5Ec4= github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY= github.com/philhofer/fwd v1.1.3-0.20240916144458-20a13a1f6b7c h1:dAMKvw0MlJT1GshSTtih8C2gDs04w8dReiOGXrGLNoY= @@ -218,14 +199,6 @@ github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/prometheus/client_golang v1.22.0 h1:rb93p9lokFEsctTys46VnV1kLCDpVZ0a/Y92Vm0Zc6Q= -github.com/prometheus/client_golang v1.22.0/go.mod h1:R7ljNsLXhuQXYZYtw6GAE9AZg8Y7vEW5scdCXrWRXC0= -github.com/prometheus/client_model v0.6.2 h1:oBsgwpGs7iVziMvrGhE53c/GrLUsZdHnqNwqPLxwZyk= -github.com/prometheus/client_model v0.6.2/go.mod h1:y3m2F6Gdpfy6Ut/GBsUqTWZqCUvMVzSfMLjcu6wAwpE= -github.com/prometheus/common v0.64.0 h1:pdZeA+g617P7oGv1CzdTzyeShxAGrTBsolKNOLQPGO4= -github.com/prometheus/common v0.64.0/go.mod h1:0gZns+BLRQ3V6NdaerOhMbwwRbNh9hkGINtQAsP5GS8= -github.com/prometheus/procfs v0.16.1 h1:hZ15bTNuirocR6u0JZ6BAHHmwS1p8B4P6MRqxtzMyRg= -github.com/prometheus/procfs v0.16.1/go.mod h1:teAbpZRB1iIAJYREa1LsoWUXykVXA1KlTmWl8x/U+Is= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= github.com/rhysd/go-github-selfupdate v1.2.3 h1:iaa+J202f+Nc+A8zi75uccC8Wg3omaM7HDeimXA22Ag= @@ -267,11 +240,8 @@ github.com/tinylib/msgp v1.3.0/go.mod h1:ykjzy2wzgrlvpDCRc4LA8UXy6D8bzMSuAF3WD57 github.com/ulikunitz/xz v0.5.9/go.mod h1:nbz6k7qbPmH4IRqmfOplQw/tblSgqTqBwxkY0oWt/14= github.com/ulikunitz/xz v0.5.12 h1:37Nm15o69RwBkXM0J6A5OlE67RZTfzUxTj8fB3dfcsc= github.com/ulikunitz/xz v0.5.12/go.mod h1:nbz6k7qbPmH4IRqmfOplQw/tblSgqTqBwxkY0oWt/14= -github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= -github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no= github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e/go.mod h1:RbqR21r5mrJuqunuUZ/Dhy/avygyECGrLceyNeo4LiM= -github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA= go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A= go.opentelemetry.io/otel v1.36.0 h1:UumtzIklRBY6cI/lllNZlALOF5nNIzJVb16APdvgTXg= @@ -284,28 +254,22 @@ go.uber.org/atomic v1.11.0 h1:ZvwS0R+56ePWxUNi+Atn9dWONBPp/AUETXlHW0DxSjE= go.uber.org/atomic v1.11.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0= go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= -go.uber.org/mock v0.5.2 h1:LbtPTcP8A5k9WPXj54PPPbjcI4Y6lhyOZXn+VS7wNko= -go.uber.org/mock v0.5.2/go.mod h1:wLlUxC2vVTPTaE3UD51E0BGOAElKrILxhVSDYQLld5o= go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0= go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y= go.uber.org/zap v1.27.0 h1:aJMhYGrd5QSmlpLMr2MftRKl7t8J8PTZPA732ud/XR8= go.uber.org/zap v1.27.0/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= -golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20201221181555-eec23a3978ad/go.mod h1:jdWPYTVW3xRLrWPugEBEK3UY2ZEsg3UU495nc5E+M+I= golang.org/x/crypto v0.39.0 h1:SHs+kF4LP+f+p14esP5jAoDpHU8Gu/v9lFRK6IT5imM= golang.org/x/crypto v0.39.0/go.mod h1:L+Xg3Wf6HoL4Bn4238Z6ft6KfEpN0tJGo53AAPC632U= golang.org/x/exp v0.0.0-20250606033433-dcc06ee1d476 h1:bsqhLWFR6G6xiQcb+JoGqdKdRU6WzPWmK8E0jxTjzo4= golang.org/x/exp v0.0.0-20250606033433-dcc06ee1d476/go.mod h1:3//PLf8L/X+8b4vuAfHzxeRUl04Adcb341+IGKfnqS8= -golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.25.0 h1:n7a+ZbQKQA/Ysbyb0/6IbB1H/X41mKgbhfv7AfG/44w= golang.org/x/mod v0.25.0/go.mod h1:IXM97Txy2VM4PJ3gI61r1YEk/gAj6zAHN3AdZt6S9Ww= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= -golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= golang.org/x/net v0.41.0 h1:vBTly1HeNPEn3wtREYfy4GZ/NECgw2Cnl+nK6Nz3uvw= golang.org/x/net v0.41.0/go.mod h1:B/K4NNqkfmg07DQYrbwvSluqCJOOXwUjeb/5lOisjbA= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= @@ -313,46 +277,29 @@ golang.org/x/oauth2 v0.0.0-20181106182150-f42d05182288/go.mod h1:N/0e6XlmueqKjAG golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI= golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKlU= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.15.0 h1:KWH3jNZsfyT6xfAfKiz6MRNmd46ByHDYaZ7KSkCtdW8= golang.org/x/sync v0.15.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw= golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= -golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= -golang.org/x/term v0.32.0 h1:DR4lr0TjUs3epypdhTOkMmuF5CDFJ/8pOnbzMZPQ7bg= -golang.org/x/term v0.32.0/go.mod h1:uZG1FhGx848Sqfsq4/DlJr3xGGsYMu/L5GW4abiaEPQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= -golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.5/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.26.0 h1:P42AVeLghgTYr4+xUnTRKDMqpar+PtX7KWuNQL21L8M= golang.org/x/text v0.26.0/go.mod h1:QK15LZJUUQVJxhz7wXgxSy/CJaTFjd0G+YLonydOVQA= golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE= golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= -golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= golang.org/x/tools v0.34.0 h1:qIpSLOxeCYGg9TrcJokLBG4KFA6d795g0xkBkiESGlo= golang.org/x/tools v0.34.0/go.mod h1:pAP9OwEaY1CAW3HOmg3hLZC5Z0CCmzjAF2UQMSqNARg= -golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= google.golang.org/appengine v1.3.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= -google.golang.org/protobuf v1.36.6 h1:z1NpPI8ku2WgiWnf+t9wTPsn6eP1L7ksHUlkfLvd9xY= -google.golang.org/protobuf v1.36.6/go.mod h1:jduwjTPXsFjZGTmRluh+L6NjiWu7pchiJ2/5YcXBHnY= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= diff --git a/main.go b/main.go index d378a23..79e9a42 100644 --- a/main.go +++ b/main.go @@ -1,7 +1,15 @@ package main -import "github.com/krau/SaveAny-Bot/cmd" +import ( + "context" + "os" + "os/signal" + + "github.com/krau/SaveAny-Bot/cmd" +) func main() { - cmd.Execute() + ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt) + defer cancel() + cmd.Execute(ctx) } diff --git a/pkg/consts/specific.go b/pkg/consts/specific.go new file mode 100644 index 0000000..3db7125 --- /dev/null +++ b/pkg/consts/specific.go @@ -0,0 +1,5 @@ +package consts + +const ( + RuleStorNameChosen = "CHOSEN" +) diff --git a/pkg/consts/tglimit/tglimit.go b/pkg/consts/tglimit/tglimit.go new file mode 100644 index 0000000..75c88fc --- /dev/null +++ b/pkg/consts/tglimit/tglimit.go @@ -0,0 +1,6 @@ +package tglimit + +const ( + MaxPartSize = 1024 * 1024 + MaxUploadPartSize = 512 * 1024 +) diff --git a/pkg/consts/version.go b/pkg/consts/version.go new file mode 100644 index 0000000..9dc5fe5 --- /dev/null +++ b/pkg/consts/version.go @@ -0,0 +1,9 @@ +package consts + +// inject version by '-X' flag +// go build -ldflags "-X github.com/krau/SaveAny-Bot/pkg/consts.Version=${{ env.VERSION }}" +var ( + Version string = "dev" + BuildTime string = "unknown" + GitCommit string = "unknown" +) diff --git a/pkg/enums/key/context_key.go b/pkg/enums/key/context_key.go new file mode 100644 index 0000000..4c41074 --- /dev/null +++ b/pkg/enums/key/context_key.go @@ -0,0 +1,5 @@ +package key + +//go:generate go-enum --values --names --flag --nocase +// ENUM(content-length) +type ContextKey string diff --git a/pkg/enums/key/context_key_enum.go b/pkg/enums/key/context_key_enum.go new file mode 100644 index 0000000..942dc66 --- /dev/null +++ b/pkg/enums/key/context_key_enum.go @@ -0,0 +1,82 @@ +// Code generated by go-enum DO NOT EDIT. +// Version: 0.6.1 +// Revision: a6f63bddde05aca4221df9c8e9e6d7d9674b1cb4 +// Build Date: 2025-03-18T23:42:14Z +// Built By: goreleaser + +package key + +import ( + "fmt" + "strings" +) + +const ( + // ContextKeyContentLength is a ContextKey of type content-length. + ContextKeyContentLength ContextKey = "content-length" +) + +var ErrInvalidContextKey = fmt.Errorf("not a valid ContextKey, try [%s]", strings.Join(_ContextKeyNames, ", ")) + +var _ContextKeyNames = []string{ + string(ContextKeyContentLength), +} + +// ContextKeyNames returns a list of possible string values of ContextKey. +func ContextKeyNames() []string { + tmp := make([]string, len(_ContextKeyNames)) + copy(tmp, _ContextKeyNames) + return tmp +} + +// ContextKeyValues returns a list of the values for ContextKey +func ContextKeyValues() []ContextKey { + return []ContextKey{ + ContextKeyContentLength, + } +} + +// String implements the Stringer interface. +func (x ContextKey) String() string { + return string(x) +} + +// IsValid provides a quick way to determine if the typed value is +// part of the allowed enumerated values +func (x ContextKey) IsValid() bool { + _, err := ParseContextKey(string(x)) + return err == nil +} + +var _ContextKeyValue = map[string]ContextKey{ + "content-length": ContextKeyContentLength, +} + +// ParseContextKey attempts to convert a string to a ContextKey. +func ParseContextKey(name string) (ContextKey, error) { + if x, ok := _ContextKeyValue[name]; ok { + return x, nil + } + // Case insensitive parse, do a separate lookup to prevent unnecessary cost of lowercasing a string if we don't need to. + if x, ok := _ContextKeyValue[strings.ToLower(name)]; ok { + return x, nil + } + return ContextKey(""), fmt.Errorf("%s is %w", name, ErrInvalidContextKey) +} + +// Set implements the Golang flag.Value interface func. +func (x *ContextKey) Set(val string) error { + v, err := ParseContextKey(val) + *x = v + return err +} + +// Get implements the Golang flag.Getter interface func. +func (x *ContextKey) Get() interface{} { + return *x +} + +// Type implements the github.com/spf13/pFlag Value interface. +func (x *ContextKey) Type() string { + return "ContextKey" +} diff --git a/pkg/enums/rule/ruletype.go b/pkg/enums/rule/ruletype.go new file mode 100644 index 0000000..62b88ee --- /dev/null +++ b/pkg/enums/rule/ruletype.go @@ -0,0 +1,16 @@ +package rule + +type RuleType string + +const ( + FileNameRegex RuleType = "FILENAME-REGEX" + MessageRegex RuleType = "MESSAGE-REGEX" +) + +func (r RuleType) String() string { + return string(r) +} + +func Values() []RuleType { + return []RuleType{FileNameRegex, MessageRegex} +} diff --git a/pkg/enums/storage/storages.go b/pkg/enums/storage/storages.go new file mode 100644 index 0000000..b901e61 --- /dev/null +++ b/pkg/enums/storage/storages.go @@ -0,0 +1,9 @@ +package storage + +//go:generate go-enum --values --names --noprefix --flag --nocase + +// StorageType +/* ENUM( +local, webdav, alist, minio, telegram +) */ +type StorageType string diff --git a/pkg/enums/storage/storages_enum.go b/pkg/enums/storage/storages_enum.go new file mode 100644 index 0000000..79377a3 --- /dev/null +++ b/pkg/enums/storage/storages_enum.go @@ -0,0 +1,102 @@ +// Code generated by go-enum DO NOT EDIT. +// Version: 0.6.1 +// Revision: a6f63bddde05aca4221df9c8e9e6d7d9674b1cb4 +// Build Date: 2025-03-18T23:42:14Z +// Built By: goreleaser + +package storage + +import ( + "fmt" + "strings" +) + +const ( + // Local is a StorageType of type local. + Local StorageType = "local" + // Webdav is a StorageType of type webdav. + Webdav StorageType = "webdav" + // Alist is a StorageType of type alist. + Alist StorageType = "alist" + // Minio is a StorageType of type minio. + Minio StorageType = "minio" + // Telegram is a StorageType of type telegram. + Telegram StorageType = "telegram" +) + +var ErrInvalidStorageType = fmt.Errorf("not a valid StorageType, try [%s]", strings.Join(_StorageTypeNames, ", ")) + +var _StorageTypeNames = []string{ + string(Local), + string(Webdav), + string(Alist), + string(Minio), + string(Telegram), +} + +// StorageTypeNames returns a list of possible string values of StorageType. +func StorageTypeNames() []string { + tmp := make([]string, len(_StorageTypeNames)) + copy(tmp, _StorageTypeNames) + return tmp +} + +// StorageTypeValues returns a list of the values for StorageType +func StorageTypeValues() []StorageType { + return []StorageType{ + Local, + Webdav, + Alist, + Minio, + Telegram, + } +} + +// String implements the Stringer interface. +func (x StorageType) String() string { + return string(x) +} + +// IsValid provides a quick way to determine if the typed value is +// part of the allowed enumerated values +func (x StorageType) IsValid() bool { + _, err := ParseStorageType(string(x)) + return err == nil +} + +var _StorageTypeValue = map[string]StorageType{ + "local": Local, + "webdav": Webdav, + "alist": Alist, + "minio": Minio, + "telegram": Telegram, +} + +// ParseStorageType attempts to convert a string to a StorageType. +func ParseStorageType(name string) (StorageType, error) { + if x, ok := _StorageTypeValue[name]; ok { + return x, nil + } + // Case insensitive parse, do a separate lookup to prevent unnecessary cost of lowercasing a string if we don't need to. + if x, ok := _StorageTypeValue[strings.ToLower(name)]; ok { + return x, nil + } + return StorageType(""), fmt.Errorf("%s is %w", name, ErrInvalidStorageType) +} + +// Set implements the Golang flag.Value interface func. +func (x *StorageType) Set(val string) error { + v, err := ParseStorageType(val) + *x = v + return err +} + +// Get implements the Golang flag.Getter interface func. +func (x *StorageType) Get() interface{} { + return *x +} + +// Type implements the github.com/spf13/pFlag Value interface. +func (x *StorageType) Type() string { + return "StorageType" +} diff --git a/pkg/enums/tasktype/tasktype.go b/pkg/enums/tasktype/tasktype.go new file mode 100644 index 0000000..26a3239 --- /dev/null +++ b/pkg/enums/tasktype/tasktype.go @@ -0,0 +1,5 @@ +package tasktype + +//go:generate go-enum --values --names --flag --nocase +// ENUM(tgfiles,tphpics) +type TaskType string diff --git a/pkg/enums/tasktype/tasktype_enum.go b/pkg/enums/tasktype/tasktype_enum.go new file mode 100644 index 0000000..f8c117f --- /dev/null +++ b/pkg/enums/tasktype/tasktype_enum.go @@ -0,0 +1,87 @@ +// Code generated by go-enum DO NOT EDIT. +// Version: 0.6.1 +// Revision: a6f63bddde05aca4221df9c8e9e6d7d9674b1cb4 +// Build Date: 2025-03-18T23:42:14Z +// Built By: goreleaser + +package tasktype + +import ( + "fmt" + "strings" +) + +const ( + // TaskTypeTgfiles is a TaskType of type tgfiles. + TaskTypeTgfiles TaskType = "tgfiles" + // TaskTypeTphpics is a TaskType of type tphpics. + TaskTypeTphpics TaskType = "tphpics" +) + +var ErrInvalidTaskType = fmt.Errorf("not a valid TaskType, try [%s]", strings.Join(_TaskTypeNames, ", ")) + +var _TaskTypeNames = []string{ + string(TaskTypeTgfiles), + string(TaskTypeTphpics), +} + +// TaskTypeNames returns a list of possible string values of TaskType. +func TaskTypeNames() []string { + tmp := make([]string, len(_TaskTypeNames)) + copy(tmp, _TaskTypeNames) + return tmp +} + +// TaskTypeValues returns a list of the values for TaskType +func TaskTypeValues() []TaskType { + return []TaskType{ + TaskTypeTgfiles, + TaskTypeTphpics, + } +} + +// String implements the Stringer interface. +func (x TaskType) String() string { + return string(x) +} + +// IsValid provides a quick way to determine if the typed value is +// part of the allowed enumerated values +func (x TaskType) IsValid() bool { + _, err := ParseTaskType(string(x)) + return err == nil +} + +var _TaskTypeValue = map[string]TaskType{ + "tgfiles": TaskTypeTgfiles, + "tphpics": TaskTypeTphpics, +} + +// ParseTaskType attempts to convert a string to a TaskType. +func ParseTaskType(name string) (TaskType, error) { + if x, ok := _TaskTypeValue[name]; ok { + return x, nil + } + // Case insensitive parse, do a separate lookup to prevent unnecessary cost of lowercasing a string if we don't need to. + if x, ok := _TaskTypeValue[strings.ToLower(name)]; ok { + return x, nil + } + return TaskType(""), fmt.Errorf("%s is %w", name, ErrInvalidTaskType) +} + +// Set implements the Golang flag.Value interface func. +func (x *TaskType) Set(val string) error { + v, err := ParseTaskType(val) + *x = v + return err +} + +// Get implements the Golang flag.Getter interface func. +func (x *TaskType) Get() interface{} { + return *x +} + +// Type implements the github.com/spf13/pFlag Value interface. +func (x *TaskType) Type() string { + return "TaskType" +} diff --git a/pkg/queue/queue.go b/pkg/queue/queue.go new file mode 100644 index 0000000..4198791 --- /dev/null +++ b/pkg/queue/queue.go @@ -0,0 +1,241 @@ +package queue + +import ( + "container/list" + "errors" + "fmt" + "sync" +) + +type TaskQueue[T any] struct { + tasks *list.List + taskMap map[string]*Task[T] + runningTaskMap map[string]*Task[T] + mu sync.RWMutex + cond *sync.Cond + closed bool +} + +func NewTaskQueue[T any]() *TaskQueue[T] { + tq := &TaskQueue[T]{ + tasks: list.New(), + taskMap: make(map[string]*Task[T]), + runningTaskMap: make(map[string]*Task[T]), + } + tq.cond = sync.NewCond(&tq.mu) + return tq +} + +func (tq *TaskQueue[T]) Add(task *Task[T]) error { + tq.mu.Lock() + defer tq.mu.Unlock() + + if tq.closed { + return errors.New("queue is closed") + } + + if _, exists := tq.taskMap[task.ID]; exists { + return fmt.Errorf("task with ID %s already exists", task.ID) + } + + if task.IsCancelled() { + return fmt.Errorf("task %s has been cancelled", task.ID) + } + + element := tq.tasks.PushBack(task) + task.element = element + tq.taskMap[task.ID] = task + + tq.cond.Signal() + return nil +} + +func (tq *TaskQueue[T]) Get() (*Task[T], error) { + tq.mu.Lock() + defer tq.mu.Unlock() + + for tq.tasks.Len() == 0 && !tq.closed { + tq.cond.Wait() + } + + if tq.closed && tq.tasks.Len() == 0 { + return nil, fmt.Errorf("queue is closed and empty") + } + + for tq.tasks.Len() > 0 { + element := tq.tasks.Front() + task := element.Value.(*Task[T]) + + tq.tasks.Remove(element) + task.element = nil + + if !task.IsCancelled() { + tq.runningTaskMap[task.ID] = task + return task, nil + } + } + + if !tq.closed { + return tq.Get() + } + + return nil, fmt.Errorf("queue is closed and empty") +} + +func (tq *TaskQueue[T]) Done(taskID string) { + tq.mu.Lock() + defer tq.mu.Unlock() + + delete(tq.taskMap, taskID) + delete(tq.runningTaskMap, taskID) +} + +func (tq *TaskQueue[T]) Peek() (*Task[T], error) { + tq.mu.RLock() + defer tq.mu.RUnlock() + + if tq.tasks.Len() == 0 { + return nil, fmt.Errorf("queue is empty") + } + + for element := tq.tasks.Front(); element != nil; element = element.Next() { + task := element.Value.(*Task[T]) + if !task.IsCancelled() { + return task, nil + } + } + + return nil, fmt.Errorf("queue has no valid tasks") +} + +func (tq *TaskQueue[T]) Length() int { + tq.mu.RLock() + defer tq.mu.RUnlock() + return tq.tasks.Len() +} + +func (tq *TaskQueue[T]) ActiveLength() int { + tq.mu.RLock() + defer tq.mu.RUnlock() + + count := 0 + for element := tq.tasks.Front(); element != nil; element = element.Next() { + task := element.Value.(*Task[T]) + if !task.IsCancelled() { + count++ + } + } + return count +} + +func (tq *TaskQueue[T]) CancelTask(taskID string) error { + tq.mu.RLock() + task, exists := tq.taskMap[taskID] + if !exists { + task, exists = tq.runningTaskMap[taskID] + } + tq.mu.RUnlock() + + if !exists { + return fmt.Errorf("task %s does not exist", taskID) + } + + task.Cancel() + return nil +} + +func (tq *TaskQueue[T]) RemoveTask(taskID string) error { + tq.mu.Lock() + defer tq.mu.Unlock() + + task, exists := tq.taskMap[taskID] + if !exists { + _, exists = tq.runningTaskMap[taskID] + if exists { + delete(tq.runningTaskMap, taskID) + } + return fmt.Errorf("task %s is already running, cannot remove from queue", taskID) + } + + if task.element != nil { + tq.tasks.Remove(task.element) + } + delete(tq.taskMap, taskID) + task.Cancel() + return nil +} + +func (tq *TaskQueue[T]) CancelAll() { + tq.mu.RLock() + tasks := make([]*Task[T], 0, tq.tasks.Len()) + for element := tq.tasks.Front(); element != nil; element = element.Next() { + tasks = append(tasks, element.Value.(*Task[T])) + } + tq.mu.RUnlock() + + for _, task := range tasks { + task.Cancel() + } +} + +func (tq *TaskQueue[T]) GetTask(taskID string) (*Task[T], error) { + tq.mu.RLock() + defer tq.mu.RUnlock() + + task, exists := tq.taskMap[taskID] + if !exists { + return nil, fmt.Errorf("task %s does not exist", taskID) + } + + return task, nil +} + +func (tq *TaskQueue[T]) Close() { + tq.mu.Lock() + defer tq.mu.Unlock() + + tq.closed = true + tq.cond.Broadcast() +} + +func (tq *TaskQueue[T]) IsClosed() bool { + tq.mu.RLock() + defer tq.mu.RUnlock() + return tq.closed +} + +func (tq *TaskQueue[T]) Clear() { + tq.mu.Lock() + defer tq.mu.Unlock() + + for element := tq.tasks.Front(); element != nil; element = element.Next() { + task := element.Value.(*Task[T]) + task.Cancel() + } + + tq.tasks.Init() + tq.taskMap = make(map[string]*Task[T]) +} + +func (tq *TaskQueue[T]) CleanupCancelled() int { + tq.mu.Lock() + defer tq.mu.Unlock() + + removed := 0 + element := tq.tasks.Front() + + for element != nil { + next := element.Next() + task := element.Value.(*Task[T]) + + if task.IsCancelled() { + tq.tasks.Remove(element) + delete(tq.taskMap, task.ID) + removed++ + } + + element = next + } + + return removed +} diff --git a/pkg/queue/queue_test.go b/pkg/queue/queue_test.go new file mode 100644 index 0000000..57cdb9e --- /dev/null +++ b/pkg/queue/queue_test.go @@ -0,0 +1,172 @@ +package queue_test + +import ( + "context" + "fmt" + "sync" + "testing" + + "github.com/krau/SaveAny-Bot/pkg/queue" +) + +// helper to create a simple Task with integer payload +func newTask(id string) *queue.Task[int] { + return queue.NewTask(context.Background(), id, 0) +} + +func TestAddAndLength(t *testing.T) { + q := queue.NewTaskQueue[int]() + if q.Length() != 0 { + t.Fatalf("expected length 0, got %d", q.Length()) + } + t1 := newTask("t1") + if err := q.Add(t1); err != nil { + t.Fatalf("unexpected error on Add: %v", err) + } + if q.Length() != 1 { + t.Fatalf("expected length 1, got %d", q.Length()) + } +} + +func TestDuplicateAdd(t *testing.T) { + q := queue.NewTaskQueue[int]() + t1 := newTask("dup") + if err := q.Add(t1); err != nil { + t.Fatalf("unexpected error on first Add: %v", err) + } + if err := q.Add(t1); err == nil { + t.Fatal("expected error on duplicate Add, got nil") + } +} + +func TestGetAndPeek(t *testing.T) { + q := queue.NewTaskQueue[int]() + t1 := newTask("a") + t2 := newTask("b") + q.Add(t1) + q.Add(t2) + // Peek should return t1 + peeked, err := q.Peek() + if err != nil { + t.Fatalf("unexpected error on Peek: %v", err) + } + if peeked.ID != "a" { + t.Fatalf("expected Peek ID 'a', got '%s'", peeked.ID) + } + // Get should return t1 then t2 + first, err := q.Get() + if err != nil { + t.Fatalf("unexpected error on Get: %v", err) + } + if first.ID != "a" { + t.Fatalf("expected first Get ID 'a', got '%s'", first.ID) + } + second, err := q.Get() + if err != nil { + t.Fatalf("unexpected error on second Get: %v", err) + } + if second.ID != "b" { + t.Fatalf("expected second Get ID 'b', got '%s'", second.ID) + } +} + +func TestCancelAndActiveLength(t *testing.T) { + q := queue.NewTaskQueue[int]() + t1 := newTask("1") + t2 := newTask("2") + q.Add(t1) + q.Add(t2) + // Cancel t1 + if err := q.CancelTask("1"); err != nil { + t.Fatalf("unexpected error on CancelTask: %v", err) + } + // Length counts all entries + if q.Length() != 2 { + t.Fatalf("expected total length 2, got %d", q.Length()) + } + // ActiveLength skips cancelled + if got := q.ActiveLength(); got != 1 { + t.Fatalf("expected active length 1, got %d", got) + } +} + +func TestRemoveTask(t *testing.T) { + q := queue.NewTaskQueue[int]() + t1 := newTask("r1") + q.Add(t1) + if err := q.RemoveTask("r1"); err != nil { + t.Fatalf("unexpected error on RemoveTask: %v", err) + } + if q.Length() != 0 { + t.Fatalf("expected length 0 after remove, got %d", q.Length()) + } +} + +func TestClearAndCleanupCancelled(t *testing.T) { + q := queue.NewTaskQueue[int]() + tasks := []*queue.Task[int]{newTask("c1"), newTask("c2"), newTask("c3")} + for _, tsk := range tasks { + q.Add(tsk) + } + // Cancel one + q.CancelTask("c2") + // Cleanup cancelled + removed := q.CleanupCancelled() + if removed != 1 { + t.Fatalf("expected removed 1, got %d", removed) + } + if q.ActiveLength() != 2 { + t.Fatalf("expected active length 2 after cleanup, got %d", q.ActiveLength()) + } + // Clear all + q.Clear() + if q.Length() != 0 { + t.Fatalf("expected length 0 after clear, got %d", q.Length()) + } +} + +func TestCloseBehavior(t *testing.T) { + q := queue.NewTaskQueue[int]() + done := make(chan struct{}) + // consumer + go func() { + _, err := q.Get() + if err == nil { + t.Errorf("expected error when getting from closed empty queue, got nil") + } + close(done) + }() + // allow goroutine to block + + // close queue + q.Close() + <-done +} + +func TestConcurrencySafety(t *testing.T) { + q := queue.NewTaskQueue[int]() + var wg sync.WaitGroup + n := 1000 + // producers + wg.Add(1) + go func() { + defer wg.Done() + for i := 0; i < n; i++ { + q.Add(newTask(fmt.Sprintf("p%d", i))) + } + }() + // consumers + wg.Add(1) + go func() { + defer wg.Done() + count := 0 + for count < n { + _, err := q.Get() + if err != nil { + continue + } + count++ + } + }() + wg.Wait() +} diff --git a/pkg/queue/task.go b/pkg/queue/task.go new file mode 100644 index 0000000..6858041 --- /dev/null +++ b/pkg/queue/task.go @@ -0,0 +1,44 @@ +package queue + +import ( + "container/list" + "context" + "time" +) + +type Task[T any] struct { + ID string + Data T + ctx context.Context + cancel context.CancelFunc + created time.Time + element *list.Element +} + +func NewTask[T any](ctx context.Context, id string, data T) *Task[T] { + cancelCtx, cancel := context.WithCancel(ctx) + return &Task[T]{ + ID: id, + Data: data, + ctx: cancelCtx, + cancel: cancel, + created: time.Now(), + } +} + +func (t *Task[T]) IsCancelled() bool { + select { + case <-t.ctx.Done(): + return true + default: + return false + } +} + +func (t *Task[T]) Cancel() { + t.cancel() +} + +func (t *Task[T]) Context() context.Context { + return t.ctx +} diff --git a/pkg/rule/filename_regex.go b/pkg/rule/filename_regex.go new file mode 100644 index 0000000..4b80180 --- /dev/null +++ b/pkg/rule/filename_regex.go @@ -0,0 +1,45 @@ +package rule + +import ( + "regexp" + + ruleenum "github.com/krau/SaveAny-Bot/pkg/enums/rule" + "github.com/krau/SaveAny-Bot/pkg/tfile" +) + +type RuleFileNameRegex struct { + storInfo + regex *regexp.Regexp +} + +var _ RuleClass[tfile.TGFile] = (*RuleFileNameRegex)(nil) + +func (r RuleFileNameRegex) Type() ruleenum.RuleType { + return ruleenum.FileNameRegex +} + +func (r RuleFileNameRegex) Match(input tfile.TGFile) (bool, error) { + return r.regex.MatchString(input.Name()), nil +} + +func (r RuleFileNameRegex) StorageName() string { + return r.storName +} + +func (r RuleFileNameRegex) StoragePath() string { + return r.storPath +} + +func NewRuleFileNameRegex(storName, storPath, regexStr string) (*RuleFileNameRegex, error) { + regex, err := regexp.Compile(regexStr) + if err != nil { + return nil, err + } + return &RuleFileNameRegex{ + storInfo: storInfo{ + storName: storName, + storPath: storPath, + }, + regex: regex, + }, nil +} diff --git a/pkg/rule/message_regex.go b/pkg/rule/message_regex.go new file mode 100644 index 0000000..dade55e --- /dev/null +++ b/pkg/rule/message_regex.go @@ -0,0 +1,43 @@ +package rule + +import ( + "regexp" + + ruleenum "github.com/krau/SaveAny-Bot/pkg/enums/rule" +) + +var _ RuleClass[string] = (*RuleMessageRegex)(nil) + +type RuleMessageRegex struct { + storInfo + regex *regexp.Regexp +} + +func (r RuleMessageRegex) Type() ruleenum.RuleType { + return ruleenum.MessageRegex +} + +func (r RuleMessageRegex) Match(input string) (bool, error) { + return r.regex.MatchString(input), nil +} + +func (r RuleMessageRegex) StorageName() string { + return r.storName +} +func (r RuleMessageRegex) StoragePath() string { + return r.storPath +} + +func NewRuleMessageRegex(storName, storPath, regexStr string) (*RuleMessageRegex, error) { + regex, err := regexp.Compile(regexStr) + if err != nil { + return nil, err + } + return &RuleMessageRegex{ + storInfo: storInfo{ + storName: storName, + storPath: storPath, + }, + regex: regex, + }, nil +} diff --git a/pkg/rule/rule.go b/pkg/rule/rule.go new file mode 100644 index 0000000..064ac85 --- /dev/null +++ b/pkg/rule/rule.go @@ -0,0 +1,17 @@ +package rule + +import ( + ruleenum "github.com/krau/SaveAny-Bot/pkg/enums/rule" +) + +type RuleClass[InputType any] interface { + Type() ruleenum.RuleType + Match(input InputType) (bool, error) + StorageName() string + StoragePath() string +} + +type storInfo struct { + storName string + storPath string +} diff --git a/pkg/tcbdata/data.go b/pkg/tcbdata/data.go new file mode 100644 index 0000000..a256e8c --- /dev/null +++ b/pkg/tcbdata/data.go @@ -0,0 +1,44 @@ +package tcbdata + +import ( + "github.com/krau/SaveAny-Bot/pkg/enums/tasktype" + "github.com/krau/SaveAny-Bot/pkg/telegraph" + "github.com/krau/SaveAny-Bot/pkg/tfile" +) + +const ( + TypeAdd = "add" + TypeSetDefault = "setdefault" +) + +// type TaskDataTGFiles struct { +// Files []tfile.TGFileMessage +// AsBatch bool +// } + +// type TaskDataTelegraph struct { +// Pics []string +// PageNode *telegraph.Page +// } + +// type TaskDataType interface { +// TaskDataTGFiles | TaskDataTelegraph +// } + +type Add struct { + TaskType tasktype.TaskType + SelectedStorName string + DirID uint + SettedDir bool + // tfiles + Files []tfile.TGFileMessage + AsBatch bool + // tphpics + TphPageNode *telegraph.Page + TphPics []string + TphDirPath string // unescaped telegraph.Page.Path +} + +type SetDefaultStorage struct { + StorageName string +} diff --git a/pkg/telegraph/client.go b/pkg/telegraph/client.go new file mode 100644 index 0000000..b2516d6 --- /dev/null +++ b/pkg/telegraph/client.go @@ -0,0 +1,150 @@ +// https://github.com/celestix/telegraph-go + +package telegraph + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "strings" +) + +// Page object represents a page on Telegraph. +type Page struct { + // Path to the page. + Path string `json:"path"` + // URL of the page. + Url string `json:"url"` + // Title of the page. + Title string `json:"title"` + // Description of the page. + Description string `json:"description"` + // Optional. Name of the author, displayed below the title. + AuthorName string `json:"author_name,omitempty"` + // Optional. Profile link, opened when users click on the author's name below the title. Can be any link, not necessarily to a Telegram profile or channel. + AuthorUrl string `json:"author_url,omitempty"` + // Optional. Image URL of the page. + ImageUrl string `json:"image_url,omitempty"` + // Optional. Content of the page. + Content []Node `json:"content,omitempty"` + // Number of page views for the page. + Views int64 `json:"views"` + // Optional. Only returned if access_token passed. True, if the target Telegraph account can edit the page. + CanEdit bool `json:"can_edit,omitempty"` +} + +// Node is abstract object represents a DOM Node. It can be a String which represents a DOM text node or a +// NodeElement object. +type Node any + +// NodeElement represents a DOM element node. +type NodeElement struct { + // Name of the DOM element. Available tags: a, aside, b, blockquote, br, code, em, figcaption, figure, + // h3, h4, hr, i, iframe, img, li, ol, p, pre, s, strong, u, ul, video.Client + Tag string `json:"tag"` + + // Attributes of the DOM element. Key of object represents name of attribute, value represents value + // of attribute. Available attributes: href, src. + Attrs map[string]string `json:"attrs,omitempty"` + + // List of child nodes for the DOM element. + Children []Node `json:"children,omitempty"` +} + +type Client struct { + client *http.Client +} + +type Body struct { + // Ok: if true, request was successful, and result can be found in the Result field. + // If false, error can be explained in Error field. + Ok bool `json:"ok"` + // Error: contains a human-readable description of the error result. + Error string `json:"error"` + // Result: result of requests (if Ok) + Result json.RawMessage `json:"result"` +} + +const ( + ApiUrl = "https://api.telegra.ph/" +) + +func (c *Client) InvokeRequest(ctx context.Context, method string, params url.Values) (json.RawMessage, error) { + r, err := http.NewRequestWithContext(ctx, http.MethodPost, ApiUrl+method, strings.NewReader(params.Encode())) + if err != nil { + return nil, fmt.Errorf("failed to build POST request to %s: %w", method, err) + } + + resp, err := c.client.Do(r) + if err != nil { + return nil, fmt.Errorf("failed to execute POST request to %s: %w", method, err) + } + + defer func() { + _ = resp.Body.Close() + }() + + var b Body + if err = json.NewDecoder(resp.Body).Decode(&b); err != nil { + return nil, fmt.Errorf("failed to parse response from %s: %w", method, err) + } + if !b.Ok { + return nil, fmt.Errorf("failed to %s: %s", method, b.Error) + } + return b.Result, nil +} + +func (c *Client) GetPage(ctx context.Context, phpath string) (*Page, error) { + var ( + u = url.Values{} + a Page + ) + u.Add("path", phpath) + u.Add("return_content", "true") + r, err := c.InvokeRequest(ctx, "getPage", u) + if err != nil { + return nil, err + } + return &a, json.Unmarshal(r, &a) +} + +// Helper to use the client(*http.Client) to download a file from a given URL. +func (c *Client) Download(ctx context.Context, durl string) (io.ReadCloser, error) { + r, err := http.NewRequestWithContext(ctx, http.MethodGet, durl, nil) + if err != nil { + return nil, err + } + resp, err := c.client.Do(r) + if err != nil { + return nil, err + } + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("failed to download file from %s: %s", durl, resp.Status) + } + return resp.Body, nil +} + +func NewClient() *Client { + return &Client{ + client: &http.Client{}, + } +} + +func NewClientWithProxy(proxyUrl string) (*Client, error) { + u, err := url.Parse(proxyUrl) + if err != nil { + return nil, err + } + p := http.ProxyURL(u) + httpClient := &http.Client{ + Transport: &http.Transport{ + Proxy: p, + }, + } + return &Client{ + client: httpClient, + }, nil +} diff --git a/pkg/tfile/opts.go b/pkg/tfile/opts.go new file mode 100644 index 0000000..3bc8c97 --- /dev/null +++ b/pkg/tfile/opts.go @@ -0,0 +1,38 @@ +package tfile + +import "github.com/gotd/td/tg" + +type TGFileOptions func(*tgFile) + +func WithMessage(msg *tg.Message) TGFileOptions { + return func(f *tgFile) { + f.message = msg + } +} +func WithName(name string) TGFileOptions { + return func(f *tgFile) { + f.name = name + } +} + +func WithNameIfEmpty(name string) TGFileOptions { + return func(f *tgFile) { + if f.name == "" { + f.name = name + } + } +} + +func WithSize(size int64) TGFileOptions { + return func(f *tgFile) { + f.size = size + } +} + +func WithSizeIfZero(size int64) TGFileOptions { + return func(f *tgFile) { + if f.size == 0 { + f.size = size + } + } +} diff --git a/pkg/tfile/tgfile.go b/pkg/tfile/tgfile.go new file mode 100644 index 0000000..3ba3605 --- /dev/null +++ b/pkg/tfile/tgfile.go @@ -0,0 +1,126 @@ +package tfile + +import ( + "errors" + "fmt" + "time" + + "github.com/gotd/td/tg" +) + +type TGFile interface { + Location() tg.InputFileLocationClass + Size() int64 + Name() string +} + +type TGFileMessage interface { + TGFile + Message() *tg.Message +} + +type tgFile struct { + location tg.InputFileLocationClass + size int64 + name string + message *tg.Message +} + +func (f *tgFile) Location() tg.InputFileLocationClass { + return f.location +} + +func (f *tgFile) Size() int64 { + return f.size +} + +func (f *tgFile) Name() string { + return f.name +} + +func (f *tgFile) Message() *tg.Message { + return f.message +} + +func NewTGFile(location tg.InputFileLocationClass, size int64, name string, + opts ...TGFileOptions, +) TGFile { + f := &tgFile{ + location: location, + size: size, + name: name, + } + for _, opt := range opts { + opt(f) + } + return f +} + +func FromMedia(media tg.MessageMediaClass, opts ...TGFileOptions) (TGFile, error) { + switch m := media.(type) { + case *tg.MessageMediaDocument: + document, ok := m.Document.AsNotEmpty() + if !ok { + return nil, errors.New("document is empty") + } + fileName := "" + for _, attribute := range document.Attributes { + if name, ok := attribute.(*tg.DocumentAttributeFilename); ok { + fileName = name.GetFileName() + break + } + } + file := &tgFile{ + location: document.AsInputDocumentFileLocation(), + size: document.Size, + name: fileName, + } + for _, opt := range opts { + opt(file) + } + return file, nil + case *tg.MessageMediaPhoto: + photo, ok := m.Photo.AsNotEmpty() + if !ok { + return nil, errors.New("photo is empty") + } + sizes := photo.Sizes + if len(sizes) == 0 { + return nil, errors.New("photo sizes are empty") + } + photoSize := sizes[len(sizes)-1] + size, ok := photoSize.AsNotEmpty() + if !ok { + return nil, errors.New("photo size is empty") + } + location := new(tg.InputPhotoFileLocation) + location.ID = photo.GetID() + location.AccessHash = photo.GetAccessHash() + location.FileReference = photo.GetFileReference() + location.ThumbSize = size.GetType() + fileName := fmt.Sprintf("photo_%s_%d.jpg", time.Now().Format("2006-01-02_15-04-05"), photo.GetID()) + file := &tgFile{ + location: location, + size: 0, + name: fileName, + } + for _, opt := range opts { + opt(file) + } + return file, nil + } + return nil, fmt.Errorf("unsupported media type: %T", media) +} + +func FromMediaMessage(media tg.MessageMediaClass, msg *tg.Message, opts ...TGFileOptions) (TGFileMessage, error) { + file, err := FromMedia(media, opts...) + if err != nil { + return nil, err + } + return &tgFile{ + location: file.Location(), + size: file.Size(), + name: file.Name(), + message: msg, + }, nil +} diff --git a/queue/queue.go b/queue/queue.go deleted file mode 100644 index 78401fc..0000000 --- a/queue/queue.go +++ /dev/null @@ -1,110 +0,0 @@ -package queue - -import ( - "container/list" - "sync" - - "github.com/krau/SaveAny-Bot/types" -) - -type TaskQueue struct { - list *list.List - cond *sync.Cond - mutex *sync.Mutex - activeMap map[string]*types.Task -} - -func (q *TaskQueue) AddTask(task *types.Task) { - q.mutex.Lock() - defer q.mutex.Unlock() - q.list.PushBack(task) - q.cond.Signal() - if task.Status != types.Pending { - delete(q.activeMap, task.Key()) - } -} - -func (q *TaskQueue) GetTask() *types.Task { - q.mutex.Lock() - defer q.mutex.Unlock() - for q.list.Len() == 0 { - q.cond.Wait() - } - e := q.list.Front() - task := e.Value.(*types.Task) - q.list.Remove(e) - if task.Status == types.Pending { - q.activeMap[task.Key()] = task - } - return task -} - -func (q *TaskQueue) DoneTask(task *types.Task) { - q.mutex.Lock() - defer q.mutex.Unlock() - delete(q.activeMap, task.Key()) -} - -func (q *TaskQueue) CancelTask(key string) bool { - q.mutex.Lock() - defer q.mutex.Unlock() - if task, ok := q.activeMap[key]; ok { - if task.Cancel != nil { - task.Cancel() - return true - } - } - for e := q.list.Front(); e != nil; e = e.Next() { - task := e.Value.(*types.Task) - if task.Key() == key { - if task.Cancel != nil { - task.Cancel() - } - q.list.Remove(e) - return true - } - } - return false -} - -func (q *TaskQueue) Len() int { - q.mutex.Lock() - defer q.mutex.Unlock() - return q.list.Len() -} - -var Queue *TaskQueue - -func init() { - Queue = NewQueue() -} - -func NewQueue() *TaskQueue { - m := &sync.Mutex{} - return &TaskQueue{ - list: list.New(), - cond: sync.NewCond(m), - mutex: m, - activeMap: make(map[string]*types.Task), - } -} - -func AddTask(task *types.Task) { - Queue.AddTask(task) -} - -func GetTask() *types.Task { - return Queue.GetTask() -} - -func Len() int { - return Queue.Len() -} - -func CancelTask(key string) bool { - return Queue.CancelTask(key) -} - -func DoneTask(task *types.Task) { - Queue.DoneTask(task) -} diff --git a/storage/alist/alist.go b/storage/alist/alist.go index 296ae42..d04f7ac 100644 --- a/storage/alist/alist.go +++ b/storage/alist/alist.go @@ -1,6 +1,7 @@ package alist import ( + "bytes" "context" "encoding/json" "fmt" @@ -8,11 +9,13 @@ import ( "net/http" "net/url" "path" + "strings" "time" - "github.com/krau/SaveAny-Bot/common" + "github.com/charmbracelet/log" config "github.com/krau/SaveAny-Bot/config/storage" - "github.com/krau/SaveAny-Bot/types" + "github.com/krau/SaveAny-Bot/pkg/enums/key" + storenum "github.com/krau/SaveAny-Bot/pkg/enums/storage" ) type Alist struct { @@ -21,9 +24,10 @@ type Alist struct { baseURL string loginInfo *loginRequest config config.AlistStorageConfig + logger *log.Logger } -func (a *Alist) Init(cfg config.StorageConfig) error { +func (a *Alist) Init(ctx context.Context, cfg config.StorageConfig) error { alistConfig, ok := cfg.(*config.AlistStorageConfig) if !ok { return fmt.Errorf("failed to cast alist config") @@ -32,45 +36,46 @@ func (a *Alist) Init(cfg config.StorageConfig) error { return err } a.config = *alistConfig - a.baseURL = alistConfig.URL a.client = getHttpClient() + a.logger = log.FromContext(ctx).WithPrefix(fmt.Sprintf("alist[%s]", alistConfig.Name)) + if alistConfig.Token != "" { a.token = alistConfig.Token ctx, cancel := context.WithTimeout(context.Background(), 1*time.Minute) defer cancel() req, err := http.NewRequestWithContext(ctx, http.MethodGet, a.baseURL+"/api/me", nil) if err != nil { - common.Log.Fatalf("Failed to create request: %v", err) + a.logger.Fatalf("Failed to create request: %v", err) return err } req.Header.Set("Authorization", a.token) resp, err := a.client.Do(req) if err != nil { - common.Log.Fatalf("Failed to send request: %v", err) + a.logger.Fatalf("Failed to send request: %v", err) return err } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { - common.Log.Fatalf("Failed to get alist user info: %s", resp.Status) + a.logger.Fatalf("Failed to get alist user info: %s", resp.Status) return err } body, err := io.ReadAll(resp.Body) if err != nil { - common.Log.Fatalf("Failed to read response body: %v", err) + a.logger.Fatalf("Failed to read response body: %v", err) return err } var meResp meResponse if err := json.Unmarshal(body, &meResp); err != nil { - common.Log.Fatalf("Failed to unmarshal me response: %v", err) + a.logger.Fatalf("Failed to unmarshal me response: %v", err) return err } if meResp.Code != http.StatusOK { - common.Log.Fatalf("Failed to get alist user info: %s", meResp.Message) + a.logger.Fatalf("Failed to get alist user info: %s", meResp.Message) return err } - common.Log.Debugf("Logged in Alist as %s", meResp.Data.Username) + a.logger.Debugf("Logged in Alist as %s", meResp.Data.Username) return nil } a.loginInfo = &loginRequest{ @@ -78,18 +83,18 @@ func (a *Alist) Init(cfg config.StorageConfig) error { Password: alistConfig.Password, } - if err := a.getToken(); err != nil { - common.Log.Fatalf("Failed to login to Alist: %v", err) + if err := a.getToken(ctx); err != nil { + a.logger.Fatalf("Failed to login to Alist: %v", err) return err } - common.Log.Debug("Logged in to Alist") + a.logger.Debug("Logged in to Alist") go a.refreshToken(*alistConfig) return nil } -func (a *Alist) Type() types.StorageType { - return types.StorageTypeAlist +func (a *Alist) Type() storenum.StorageType { + return storenum.Alist } func (a *Alist) Name() string { @@ -97,16 +102,23 @@ func (a *Alist) Name() string { } func (a *Alist) Save(ctx context.Context, reader io.Reader, storagePath string) error { - common.Log.Infof("Saving file to %s", storagePath) + a.logger.Infof("Saving file to %s", storagePath) + + ext := path.Ext(storagePath) + base := strings.TrimSuffix(storagePath, ext) + candidate := storagePath + for i := 1; a.Exists(ctx, candidate); i++ { + candidate = fmt.Sprintf("%s_%d%s", base, i, ext) + } req, err := http.NewRequestWithContext(ctx, http.MethodPut, a.baseURL+"/api/fs/put", reader) if err != nil { return fmt.Errorf("failed to create request: %w", err) } req.Header.Set("Authorization", a.token) - req.Header.Set("File-Path", url.PathEscape(storagePath)) + req.Header.Set("File-Path", url.PathEscape(candidate)) req.Header.Set("Content-Type", "application/octet-stream") - if length := ctx.Value(types.ContextKeyContentLength); length != nil { + if length := ctx.Value(key.ContextKeyContentLength); length != nil { length, ok := length.(int64) if ok { req.ContentLength = length @@ -140,15 +152,66 @@ func (a *Alist) Save(ctx context.Context, reader io.Reader, storagePath string) return nil } -func (a *Alist) NotSupportStream() string { - return "Alist does not support chunked transfer encoding" -} - -func (a *Alist) JoinStoragePath(task types.Task) string { - return path.Join(a.config.BasePath, task.StoragePath) +func (a *Alist) JoinStoragePath(p string) string { + return path.Join(a.config.BasePath, p) } func (a *Alist) Exists(ctx context.Context, storagePath string) bool { - // TODO: Implement it. - return false + // POST /api/fs/get + /* + body: + { + "path": "/t", + "password": "", + "page": 1, + "per_page": 0, + "refresh": false + } + */ + body := map[string]any{ + "path": storagePath, + "password": "", + } + bodyBytes, err := json.Marshal(body) + if err != nil { + a.logger.Errorf("Failed to marshal request body: %v", err) + return false + } + req, err := http.NewRequestWithContext(ctx, http.MethodPost, a.baseURL+"/api/fs/get", bytes.NewBuffer(bodyBytes)) + if err != nil { + a.logger.Errorf("Failed to create request: %v", err) + return false + } + req.Header.Set("Authorization", a.token) + req.Header.Set("Content-Type", "application/json") + resp, err := a.client.Do(req) + if err != nil { + a.logger.Errorf("Failed to send request: %v", err) + return false + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + return false + } + data, err := io.ReadAll(resp.Body) + if err != nil { + a.logger.Errorf("Failed to read response body: %v", err) + return false + } + var fsGetResp fsGetResponse + if err := json.Unmarshal(data, &fsGetResp); err != nil { + a.logger.Errorf("Failed to unmarshal fs get response: %v", err) + return false + } + if fsGetResp.Code != http.StatusOK { + a.logger.Errorf("Failed to get file info from Alist: %d, %s", fsGetResp.Code, fsGetResp.Message) + return false + } + return true + +} + +// Impl StorageCannotStream interface +func (a *Alist) CannotStream() string { + return "Alist does not support chunked transfer encoding" } diff --git a/storage/alist/token.go b/storage/alist/token.go index 12a0bd0..3776308 100644 --- a/storage/alist/token.go +++ b/storage/alist/token.go @@ -2,17 +2,17 @@ package alist import ( "bytes" + "context" "encoding/json" "fmt" "io" "net/http" "time" - "github.com/krau/SaveAny-Bot/common" config "github.com/krau/SaveAny-Bot/config/storage" ) -func (a *Alist) getToken() error { +func (a *Alist) getToken(ctx context.Context) error { loginBody, err := json.Marshal(a.loginInfo) if err != nil { return fmt.Errorf("failed to marshal login request: %w", err) @@ -51,15 +51,15 @@ func (a *Alist) getToken() error { func (a *Alist) refreshToken(cfg config.AlistStorageConfig) { tokenExp := cfg.TokenExp if tokenExp <= 0 { - common.Log.Warn("Invalid token expiration time, using default value") + a.logger.Warn("Invalid token expiration time, using default value") tokenExp = 3600 } for { time.Sleep(time.Duration(tokenExp) * time.Second) - if err := a.getToken(); err != nil { - common.Log.Errorf("Failed to refresh jwt token: %v", err) + if err := a.getToken(context.Background()); err != nil { + a.logger.Errorf("Failed to refresh jwt token: %v", err) continue } - common.Log.Info("Refreshed Alist jwt token") + a.logger.Info("Refreshed Alist jwt token") } } diff --git a/storage/alist/types.go b/storage/alist/types.go index c0bec32..59be4d5 100644 --- a/storage/alist/types.go +++ b/storage/alist/types.go @@ -42,3 +42,8 @@ type putResponse struct { } `json:"task"` } `json:"data"` } + +type fsGetResponse struct { + Code int `json:"code"` + Message string `json:"message"` +} diff --git a/storage/context.go b/storage/context.go new file mode 100644 index 0000000..7ee5ec8 --- /dev/null +++ b/storage/context.go @@ -0,0 +1,22 @@ +package storage + +import "context" + +type contextKey struct{} + +var storageKey = contextKey{} + +func WithContext(ctx context.Context, storage Storage) context.Context { + if storage == nil { + return ctx + } + return context.WithValue(ctx, storageKey, storage) +} + +func FromContext(ctx context.Context) Storage { + storage, ok := ctx.Value(storageKey).(Storage) + if !ok { + return nil + } + return storage +} diff --git a/storage/load.go b/storage/load.go new file mode 100644 index 0000000..d37c8f7 --- /dev/null +++ b/storage/load.go @@ -0,0 +1,80 @@ +package storage + +import ( + "context" + "fmt" + + "github.com/charmbracelet/log" + "github.com/krau/SaveAny-Bot/config" +) + +var UserStorages = make(map[int64][]Storage) + +// GetStorageByName returns storage by name from cache or creates new one +func getStorageByName(ctx context.Context, name string) (Storage, error) { + if name == "" { + return nil, ErrStorageNameEmpty + } + + storage, ok := Storages[name] + if ok { + return storage, nil + } + cfg := config.Cfg.GetStorageByName(name) + if cfg == nil { + return nil, fmt.Errorf("未找到存储 %s", name) + } + + storage, err := NewStorage(ctx, cfg) + if err != nil { + return nil, err + } + Storages[name] = storage + return storage, nil +} + +// 检查 user 是否可用指定的 storage, 若不可用则返回未找到错误 +func GetStorageByUserIDAndName(ctx context.Context, chatID int64, name string) (Storage, error) { + if name == "" { + return nil, ErrStorageNameEmpty + } + + if !config.Cfg.HasStorage(chatID, name) { + return nil, fmt.Errorf("没有找到用户 %d 的存储 %s", chatID, name) + } + + return getStorageByName(ctx, name) +} + +func GetUserStorages(ctx context.Context, chatID int64) []Storage { + if chatID <= 0 { + return nil + } + if storages, ok := UserStorages[chatID]; ok { + return storages + } + var storages []Storage + for _, name := range config.Cfg.GetStorageNamesByUserID(chatID) { + storage, err := getStorageByName(ctx, name) + if err != nil { + continue + } + storages = append(storages, storage) + } + return storages +} + +func LoadStorages(ctx context.Context) { + logger := log.FromContext(ctx) + logger.Info("加载存储...") + for _, storage := range config.Cfg.Storages { + _, err := getStorageByName(ctx, storage.GetName()) + if err != nil { + logger.Errorf("加载存储 %s 失败: %v", storage.GetName(), err) + } + } + logger.Infof("成功加载 %d 个存储", len(Storages)) + for user := range config.Cfg.GetUsersID() { + UserStorages[int64(user)] = GetUserStorages(ctx, int64(user)) + } +} diff --git a/storage/local/local.go b/storage/local/local.go index 4723d9e..34ed4be 100644 --- a/storage/local/local.go +++ b/storage/local/local.go @@ -6,18 +6,20 @@ import ( "io" "os" "path/filepath" + "strings" + "github.com/charmbracelet/log" "github.com/duke-git/lancet/v2/fileutil" - "github.com/krau/SaveAny-Bot/common" config "github.com/krau/SaveAny-Bot/config/storage" - "github.com/krau/SaveAny-Bot/types" + storenum "github.com/krau/SaveAny-Bot/pkg/enums/storage" ) type Local struct { config config.LocalStorageConfig + logger *log.Logger } -func (l *Local) Init(cfg config.StorageConfig) error { +func (l *Local) Init(ctx context.Context, cfg config.StorageConfig) error { localConfig, ok := cfg.(*config.LocalStorageConfig) if !ok { return fmt.Errorf("failed to cast local config") @@ -30,25 +32,33 @@ func (l *Local) Init(cfg config.StorageConfig) error { if err != nil { return fmt.Errorf("failed to create local storage directory: %w", err) } + l.logger = log.FromContext(ctx).WithPrefix(fmt.Sprintf("local[%s]", l.config.Name)) return nil } -func (l *Local) Type() types.StorageType { - return types.StorageTypeLocal +func (l *Local) Type() storenum.StorageType { + return storenum.Local } func (l *Local) Name() string { return l.config.Name } -func (l *Local) JoinStoragePath(task types.Task) string { - return filepath.Join(l.config.BasePath, task.StoragePath) +func (l *Local) JoinStoragePath(path string) string { + return filepath.Join(l.config.BasePath, path) } func (l *Local) Save(ctx context.Context, r io.Reader, storagePath string) error { - common.Log.Infof("Saving file to %s", storagePath) + l.logger.Infof("Saving file to %s", storagePath) - absPath, err := filepath.Abs(storagePath) + ext := filepath.Ext(storagePath) + base := strings.TrimSuffix(storagePath, ext) + candidate := storagePath + for i := 1; l.Exists(ctx, candidate); i++ { + candidate = fmt.Sprintf("%s_%d%s", base, i, ext) + } + + absPath, err := filepath.Abs(candidate) if err != nil { return err } diff --git a/storage/minio/client.go b/storage/minio/client.go index 85c9bcf..f26e9bc 100644 --- a/storage/minio/client.go +++ b/storage/minio/client.go @@ -5,10 +5,11 @@ import ( "fmt" "io" "path" + "strings" - "github.com/krau/SaveAny-Bot/common" + "github.com/charmbracelet/log" config "github.com/krau/SaveAny-Bot/config/storage" - "github.com/krau/SaveAny-Bot/types" + storenum "github.com/krau/SaveAny-Bot/pkg/enums/storage" "github.com/minio/minio-go/v7" "github.com/minio/minio-go/v7/pkg/credentials" ) @@ -16,9 +17,10 @@ import ( type Minio struct { config config.MinioStorageConfig client *minio.Client + logger *log.Logger } -func (m *Minio) Init(cfg config.StorageConfig) error { +func (m *Minio) Init(ctx context.Context, cfg config.StorageConfig) error { minioConfig, ok := cfg.(*config.MinioStorageConfig) if !ok { return fmt.Errorf("failed to cast minio config") @@ -27,6 +29,7 @@ func (m *Minio) Init(cfg config.StorageConfig) error { return err } m.config = *minioConfig + m.logger = log.FromContext(ctx).WithPrefix(fmt.Sprintf("minio[%s]", m.config.Name)) client, err := minio.New(m.config.Endpoint, &minio.Options{ Creds: credentials.NewStaticV4(m.config.AccessKeyID, m.config.SecretAccessKey, ""), @@ -36,7 +39,7 @@ func (m *Minio) Init(cfg config.StorageConfig) error { return fmt.Errorf("failed to create minio client: %w", err) } - exists, err := client.BucketExists(context.Background(), m.config.BucketName) + exists, err := client.BucketExists(ctx, m.config.BucketName) if err != nil { return fmt.Errorf("failed to check bucket existence: %w", err) } @@ -48,22 +51,29 @@ func (m *Minio) Init(cfg config.StorageConfig) error { return nil } -func (m *Minio) Type() types.StorageType { - return types.StorageTypeMinio +func (m *Minio) Type() storenum.StorageType { + return storenum.Minio } func (m *Minio) Name() string { return m.config.Name } -func (m *Minio) JoinStoragePath(task types.Task) string { - return path.Join(m.config.BasePath, task.StoragePath) +func (m *Minio) JoinStoragePath(p string) string { + return path.Join(m.config.BasePath, p) } func (m *Minio) Save(ctx context.Context, r io.Reader, storagePath string) error { - common.Log.Infof("Saving file from reader to %s", storagePath) + m.logger.Infof("Saving file from reader to %s", storagePath) - _, err := m.client.PutObject(ctx, m.config.BucketName, storagePath, r, -1, minio.PutObjectOptions{}) + ext := path.Ext(storagePath) + base := strings.TrimSuffix(storagePath, ext) + candidate := storagePath + for i := 1; m.Exists(ctx, candidate); i++ { + candidate = fmt.Sprintf("%s_%d%s", base, i, ext) + } + + _, err := m.client.PutObject(ctx, m.config.BucketName, candidate, r, -1, minio.PutObjectOptions{}) if err != nil { return fmt.Errorf("failed to upload file to minio: %w", err) } @@ -72,15 +82,7 @@ func (m *Minio) Save(ctx context.Context, r io.Reader, storagePath string) error } func (m *Minio) Exists(ctx context.Context, storagePath string) bool { - common.Log.Debugf("Checking if file exists at %s", storagePath) - // TODO: test it. + m.logger.Debugf("Checking if file exists at %s", storagePath) _, err := m.client.StatObject(ctx, m.config.BucketName, storagePath, minio.StatObjectOptions{}) - if err != nil { - if minio.ToErrorResponse(err).Code == "NoSuchKey" { - return false // File does not exist - } - return false - } - - return true + return err == nil } diff --git a/storage/storage.go b/storage/storage.go index 5ff4c04..a7c7e9e 100644 --- a/storage/storage.go +++ b/storage/storage.go @@ -5,121 +5,51 @@ import ( "fmt" "io" - "github.com/krau/SaveAny-Bot/common" - "github.com/krau/SaveAny-Bot/config" - sc "github.com/krau/SaveAny-Bot/config/storage" + storcfg "github.com/krau/SaveAny-Bot/config/storage" + storenum "github.com/krau/SaveAny-Bot/pkg/enums/storage" "github.com/krau/SaveAny-Bot/storage/alist" "github.com/krau/SaveAny-Bot/storage/local" "github.com/krau/SaveAny-Bot/storage/minio" + "github.com/krau/SaveAny-Bot/storage/telegram" "github.com/krau/SaveAny-Bot/storage/webdav" - "github.com/krau/SaveAny-Bot/types" ) type Storage interface { - Init(cfg sc.StorageConfig) error - Type() types.StorageType + Init(ctx context.Context, cfg storcfg.StorageConfig) error + Type() storenum.StorageType Name() string - JoinStoragePath(task types.Task) string + JoinStoragePath(p string) string Save(ctx context.Context, reader io.Reader, storagePath string) error Exists(ctx context.Context, storagePath string) bool } -type StorageNotSupportStream interface { +type StorageCannotStream interface { Storage - NotSupportStream() string + CannotStream() string } var Storages = make(map[string]Storage) -var UserStorages = make(map[int64][]Storage) - -// GetStorageByName returns storage by name from cache or creates new one -func GetStorageByName(name string) (Storage, error) { - if name == "" { - return nil, ErrStorageNameEmpty - } - - storage, ok := Storages[name] - if ok { - return storage, nil - } - cfg := config.Cfg.GetStorageByName(name) - if cfg == nil { - return nil, fmt.Errorf("未找到存储 %s", name) - } - - storage, err := NewStorage(cfg) - if err != nil { - return nil, err - } - Storages[name] = storage - return storage, nil -} - -// 检查 user 是否可用指定的 storage, 若不可用则返回未找到错误 -func GetStorageByUserIDAndName(chatID int64, name string) (Storage, error) { - if name == "" { - return nil, ErrStorageNameEmpty - } - - if !config.Cfg.HasStorage(chatID, name) { - return nil, fmt.Errorf("没有找到用户 %d 的存储 %s", chatID, name) - } - - return GetStorageByName(name) -} - -func GetUserStorages(chatID int64) []Storage { - if chatID <= 0 { - return nil - } - if storages, ok := UserStorages[chatID]; ok { - return storages - } - var storages []Storage - for _, name := range config.Cfg.GetStorageNamesByUserID(chatID) { - storage, err := GetStorageByName(name) - if err != nil { - continue - } - storages = append(storages, storage) - } - return storages -} - type StorageConstructor func() Storage -var storageConstructors = map[string]StorageConstructor{ - string(types.StorageTypeAlist): func() Storage { return new(alist.Alist) }, - string(types.StorageTypeLocal): func() Storage { return new(local.Local) }, - string(types.StorageTypeWebdav): func() Storage { return new(webdav.Webdav) }, - string(types.StorageTypeMinio): func() Storage { return new(minio.Minio) }, +var storageConstructors = map[storenum.StorageType]StorageConstructor{ + storenum.Alist: func() Storage { return new(alist.Alist) }, + storenum.Local: func() Storage { return new(local.Local) }, + storenum.Webdav: func() Storage { return new(webdav.Webdav) }, + storenum.Minio: func() Storage { return new(minio.Minio) }, + storenum.Telegram: func() Storage { return new(telegram.Telegram) }, } -func NewStorage(cfg sc.StorageConfig) (Storage, error) { - constructor, ok := storageConstructors[string(cfg.GetType())] +func NewStorage(ctx context.Context, cfg storcfg.StorageConfig) (Storage, error) { + constructor, ok := storageConstructors[cfg.GetType()] if !ok { return nil, fmt.Errorf("不支持的存储类型: %s", cfg.GetType()) } storage := constructor() - if err := storage.Init(cfg); err != nil { + if err := storage.Init(ctx, cfg); err != nil { return nil, fmt.Errorf("初始化 %s 存储失败: %w", cfg.GetName(), err) } return storage, nil } - -func LoadStorages() { - common.Log.Info("加载存储...") - for _, storage := range config.Cfg.Storages { - _, err := GetStorageByName(storage.GetName()) - if err != nil { - common.Log.Errorf("加载存储 %s 失败: %v", storage.GetName(), err) - } - } - common.Log.Infof("成功加载 %d 个存储", len(Storages)) - for user := range config.Cfg.GetUsersID() { - UserStorages[int64(user)] = GetUserStorages(int64(user)) - } -} diff --git a/storage/telegram/telegram.go b/storage/telegram/telegram.go new file mode 100644 index 0000000..09ee55f --- /dev/null +++ b/storage/telegram/telegram.go @@ -0,0 +1,111 @@ +package telegram + +import ( + "context" + "fmt" + "io" + "path" + "time" + + "github.com/gabriel-vasile/mimetype" + "github.com/gotd/td/telegram/message" + "github.com/gotd/td/telegram/message/styling" + "github.com/gotd/td/telegram/uploader" + "github.com/krau/SaveAny-Bot/common/utils/tgutil" + "github.com/krau/SaveAny-Bot/config" + storconfig "github.com/krau/SaveAny-Bot/config/storage" + "github.com/krau/SaveAny-Bot/pkg/consts/tglimit" + storenum "github.com/krau/SaveAny-Bot/pkg/enums/storage" + "github.com/rs/xid" + "golang.org/x/time/rate" +) + +type Telegram struct { + config storconfig.TelegramStorageConfig + limiter *rate.Limiter +} + +func (t *Telegram) Init(ctx context.Context, cfg storconfig.StorageConfig) error { + telegramConfig, ok := cfg.(*storconfig.TelegramStorageConfig) + if !ok { + return fmt.Errorf("failed to cast telegram config") + } + if err := telegramConfig.Validate(); err != nil { + return err + } + t.config = *telegramConfig + if t.config.RateLimit <= 0 || t.config.RateBurst <= 0 { + t.config.RateLimit = 2 + t.config.RateBurst = 1 + } + t.limiter = rate.NewLimiter(rate.Every(time.Duration(t.config.RateLimit)*time.Second), t.config.RateBurst) + return nil +} + +func (t *Telegram) Type() storenum.StorageType { + return storenum.Telegram +} + +func (t *Telegram) Name() string { + return t.config.Name +} + +func (t *Telegram) JoinStoragePath(p string) string { + return path.Clean(p) +} + +func (t *Telegram) Exists(ctx context.Context, storagePath string) bool { + return false +} + +func (t *Telegram) Save(ctx context.Context, r io.Reader, storagePath string) error { + if err := t.limiter.Wait(ctx); err != nil { + return fmt.Errorf("rate limit failed: %w", err) + } + rs, ok := r.(io.ReadSeeker) + if !ok || rs == nil { + return fmt.Errorf("reader must implement io.ReadSeeker") + } + tctx := tgutil.ExtFromContext(ctx) + if tctx == nil { + return fmt.Errorf("failed to get telegram context") + } + peer := tctx.PeerStorage.GetInputPeerById(t.config.ChatID) + if peer == nil { + return fmt.Errorf("failed to get input peer for chat ID %d", t.config.ChatID) + } + mtype, err := mimetype.DetectReader(rs) + if err != nil { + return fmt.Errorf("failed to detect mimetype: %w", err) + } + filename := path.Base(storagePath) + if filename == "" { + filename = xid.New().String() + mtype.Extension() + } + if _, err := rs.Seek(0, io.SeekStart); err != nil { + return fmt.Errorf("failed to seek reader: %w", err) + } + upler := uploader.NewUploader(tctx.Raw). + WithPartSize(tglimit.MaxUploadPartSize). + WithThreads(config.Cfg.Threads) + + file, err := upler.FromReader(ctx, filename, rs) + if err != nil { + return fmt.Errorf("failed to upload file to telegram: %w", err) + } + + caption := styling.Plain(filename) + docb := message.UploadedDocument(file, caption). + Filename(filename). + ForceFile(true). + MIME(mtype.String()) + + var mediaOpt message.MediaOption = docb + sender := tctx.Sender + _, err = sender.WithUploader(upler).To(peer).Media(ctx, mediaOpt) + return err +} + +func (t *Telegram) CannotStream() string { + return "Telegram storage must use a ReaderSeeker" +} diff --git a/storage/webdav/client.go b/storage/webdav/client.go index 9182bb8..2b3f520 100644 --- a/storage/webdav/client.go +++ b/storage/webdav/client.go @@ -9,7 +9,7 @@ import ( "path" "strings" - "github.com/krau/SaveAny-Bot/types" + "github.com/krau/SaveAny-Bot/pkg/enums/key" ) type Client struct { @@ -54,7 +54,7 @@ func (c *Client) doRequest(ctx context.Context, method WebdavMethod, url string, req.Header.Set("Depth", "1") } if method == WebdavMethodPut && ctx != nil { - if length := ctx.Value(types.ContextKeyContentLength); length != nil { + if length := ctx.Value(key.ContextKeyContentLength); length != nil { if l, ok := length.(int64); ok { req.ContentLength = l } diff --git a/storage/webdav/webdav.go b/storage/webdav/webdav.go index a76819b..e1ba082 100644 --- a/storage/webdav/webdav.go +++ b/storage/webdav/webdav.go @@ -6,19 +6,21 @@ import ( "io" "net/http" "path" + "strings" "time" - "github.com/krau/SaveAny-Bot/common" + "github.com/charmbracelet/log" config "github.com/krau/SaveAny-Bot/config/storage" - "github.com/krau/SaveAny-Bot/types" + storenum "github.com/krau/SaveAny-Bot/pkg/enums/storage" ) type Webdav struct { config config.WebdavStorageConfig client *Client + logger *log.Logger } -func (w *Webdav) Init(cfg config.StorageConfig) error { +func (w *Webdav) Init(ctx context.Context, cfg config.StorageConfig) error { webdavConfig, ok := cfg.(*config.WebdavStorageConfig) if !ok { return fmt.Errorf("failed to cast webdav config") @@ -27,42 +29,51 @@ func (w *Webdav) Init(cfg config.StorageConfig) error { return err } w.config = *webdavConfig + w.logger = log.FromContext(ctx).WithPrefix(fmt.Sprintf("webdav[%s]", w.config.Name)) w.client = NewClient(w.config.URL, w.config.Username, w.config.Password, &http.Client{ Timeout: time.Hour * 12, }) return nil } -func (w *Webdav) Type() types.StorageType { - return types.StorageTypeWebdav +func (w *Webdav) Type() storenum.StorageType { + return storenum.Webdav } func (w *Webdav) Name() string { return w.config.Name } -func (w *Webdav) JoinStoragePath(task types.Task) string { - return path.Join(w.config.BasePath, task.StoragePath) +func (w *Webdav) JoinStoragePath(p string) string { + return path.Join(w.config.BasePath, p) } func (w *Webdav) Save(ctx context.Context, r io.Reader, storagePath string) error { - common.Log.Infof("Saving file to %s", storagePath) - if err := w.client.MkDir(ctx, path.Dir(storagePath)); err != nil { - common.Log.Errorf("Failed to create directory %s: %v", path.Dir(storagePath), err) + w.logger.Infof("Saving file to %s", storagePath) + + ext := path.Ext(storagePath) + base := strings.TrimSuffix(storagePath, ext) + candidate := storagePath + for i := 1; w.Exists(ctx, candidate); i++ { + candidate = fmt.Sprintf("%s_%d%s", base, i, ext) + } + + if err := w.client.MkDir(ctx, path.Dir(candidate)); err != nil { + w.logger.Errorf("Failed to create directory %s: %v", path.Dir(candidate), err) return ErrFailedToCreateDirectory } - if err := w.client.WriteFile(ctx, storagePath, r); err != nil { - common.Log.Errorf("Failed to write file %s: %v", storagePath, err) + if err := w.client.WriteFile(ctx, candidate, r); err != nil { + w.logger.Errorf("Failed to write file %s: %v", candidate, err) return ErrFailedToWriteFile } return nil } func (w *Webdav) Exists(ctx context.Context, storagePath string) bool { - common.Log.Debugf("Checking if file exists at %s", storagePath) + w.logger.Debugf("Checking if file exists at %s", storagePath) exists, err := w.client.Exists(ctx, storagePath) if err != nil { - common.Log.Errorf("Failed to check if file exists at %s: %v", storagePath, err) + w.logger.Errorf("Failed to check if file exists at %s: %v", storagePath, err) return false } return exists diff --git a/types/task.go b/types/task.go deleted file mode 100644 index 25ed4b6..0000000 --- a/types/task.go +++ /dev/null @@ -1,84 +0,0 @@ -package types - -import ( - "context" - "crypto/md5" - "encoding/hex" - "fmt" - "net/url" - "strings" - "time" - - "github.com/gotd/td/tg" -) - -type Task struct { - Ctx context.Context - Cancel context.CancelFunc - Error error - UseUserClient bool - Status TaskStatus - StorageName string - StoragePath string - StartTime time.Time - - FileDBID uint - File *File - FileMessageID int - FileChatID int64 - - IsTelegraph bool - TelegraphURL string - - // to track the reply message - ReplyMessageID int - ReplyChatID int64 - UserID int64 -} - -func (t Task) Key() string { - if t.IsTelegraph { - return hashStr(t.TelegraphURL) - } - return fmt.Sprintf("%d:%d", t.FileChatID, t.FileMessageID) -} - -func (t Task) String() string { - if t.IsTelegraph { - return fmt.Sprintf("[telegraph]:%s", t.TelegraphURL) - } - return fmt.Sprintf("[%d:%d]:%s", t.FileChatID, t.FileMessageID, t.File.FileName) -} - -func (t Task) FileName() string { - if t.IsTelegraph { - tgphPath := strings.Split(t.TelegraphURL, "/")[len(strings.Split(t.TelegraphURL, "/"))-1] - tgphPathUnescaped, err := url.PathUnescape(tgphPath) - if err != nil { - return tgphPath - } - return tgphPathUnescaped - } - return t.File.FileName -} - -type File struct { - Location tg.InputFileLocationClass - FileSize int64 - FileName string -} - -func (f File) Hash() string { - locationBytes := []byte(f.Location.String()) - fileSizeBytes := []byte(fmt.Sprintf("%d", f.FileSize)) - fileNameBytes := []byte(f.FileName) - - structBytes := append(locationBytes, fileSizeBytes...) - structBytes = append(structBytes, fileNameBytes...) - - hash := md5.New() - hash.Write(structBytes) - hashBytes := hash.Sum(nil) - - return hex.EncodeToString(hashBytes) -} diff --git a/types/types.go b/types/types.go deleted file mode 100644 index 4e11bbf..0000000 --- a/types/types.go +++ /dev/null @@ -1,42 +0,0 @@ -package types - -type TaskStatus string - -const ( - Pending TaskStatus = "pending" - Succeeded TaskStatus = "succeeded" - Failed TaskStatus = "failed" - Canceled TaskStatus = "canceled" -) - -type StorageType string - -const ( - StorageTypeLocal StorageType = "local" - StorageTypeWebdav StorageType = "webdav" - StorageTypeAlist StorageType = "alist" - StorageTypeMinio StorageType = "minio" -) - -var StorageTypes = []StorageType{StorageTypeLocal, StorageTypeAlist, StorageTypeWebdav, StorageTypeMinio} -var StorageTypeDisplay = map[StorageType]string{ - StorageTypeLocal: "本地磁盘", - StorageTypeWebdav: "WebDAV", - StorageTypeAlist: "Alist", - StorageTypeMinio: "Minio", -} - -type ContextKey string - -const ( - ContextKeyContentLength ContextKey = "content-length" -) - -type RuleType string - -const ( - RuleTypeFileNameRegex RuleType = "FILENAME-REGEX" - RuleTypeMessageRegex RuleType = "MESSAGE-REGEX" -) - -var RuleTypes = []RuleType{RuleTypeFileNameRegex, RuleTypeMessageRegex} \ No newline at end of file diff --git a/types/utils.go b/types/utils.go deleted file mode 100644 index 85ab6f4..0000000 --- a/types/utils.go +++ /dev/null @@ -1,12 +0,0 @@ -package types - -import ( - "crypto/md5" - "encoding/hex" -) - -func hashStr(s string) string { - hash := md5.New() - hash.Write([]byte(s)) - return hex.EncodeToString(hash.Sum(nil)) -}