Compare commits
27 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
300f7723af | ||
|
|
491ba55f1e | ||
|
|
32519b8c08 | ||
|
|
7ffd9891a0 | ||
|
|
347a60f1f7 | ||
|
|
da69fe1354 | ||
|
|
746ca026ba | ||
|
|
a8c64675e5 | ||
|
|
3918f6eee2 | ||
|
|
8d44b43c82 | ||
|
|
f14c4367f8 | ||
|
|
3e3a320672 | ||
|
|
19efab0665 | ||
|
|
635f00ac71 | ||
|
|
2d2becccf6 | ||
|
|
ed0837a89b | ||
|
|
65fee89e14 | ||
|
|
8e180006f0 | ||
|
|
721c9666eb | ||
|
|
6f35401181 | ||
|
|
72ae2ce079 | ||
|
|
495ad3ea5c | ||
|
|
3def9df4b4 | ||
|
|
790a32d297 | ||
|
|
f7779224ef | ||
|
|
7d899ae088 | ||
|
|
7e67bdb7e2 |
3
.github/workflows/build-release.yml
vendored
3
.github/workflows/build-release.yml
vendored
@@ -38,6 +38,9 @@ jobs:
|
||||
matrix:
|
||||
goos: [linux, darwin, windows]
|
||||
goarch: [amd64, arm64]
|
||||
exclude:
|
||||
- goos: windows
|
||||
goarch: arm64
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
12
README_EN.md
12
README_EN.md
@@ -10,21 +10,13 @@ Save Telegram files to various storage endpoints.
|
||||
|
||||
</div>
|
||||
|
||||
Demo Video:
|
||||
|
||||
<div align="center">
|
||||
|
||||
[SaveAny-Bot Demo Video.webm](https://github.com/user-attachments/assets/a0de2453-a4d1-4a12-81fb-9d84856dce09)
|
||||
|
||||
</div>
|
||||
|
||||
## Deployment
|
||||
|
||||
### Deploy from Binary
|
||||
|
||||
Download the binary file for your platform from the [Release](https://github.com/krau/SaveAny-Bot/releases) page.
|
||||
|
||||
Create a `config.toml` file in the extracted directory, refer to [config.toml.example](https://github.com/krau/SaveAny-Bot/blob/main/config.example.toml) for configuration.
|
||||
Create a `config.toml` file in the extracted directory, refer to [config.example.toml](https://github.com/krau/SaveAny-Bot/blob/main/config.example.toml) for configuration.
|
||||
|
||||
Run:
|
||||
|
||||
@@ -62,7 +54,7 @@ systemctl enable --now saveany-bot
|
||||
|
||||
#### Docker Compose
|
||||
|
||||
Download [docker-compose.yml](https://github.com/krau/SaveAny-Bot/blob/main/docker-compose.yml) file and create a `config.toml` file in the same directory, refer to [config.toml.example](https://github.com/krau/SaveAny-Bot/blob/main/config.example.toml) for configuration.
|
||||
Download [docker-compose.yml](https://github.com/krau/SaveAny-Bot/blob/main/docker-compose.yml) file and create a `config.toml` file in the same directory, refer to [config.example.toml](https://github.com/krau/SaveAny-Bot/blob/main/config.example.toml) for configuration.
|
||||
|
||||
Run:
|
||||
|
||||
|
||||
11
bot/bot.go
11
bot/bot.go
@@ -11,8 +11,8 @@ import (
|
||||
"github.com/glebarez/sqlite"
|
||||
"github.com/gotd/td/telegram/dcs"
|
||||
"github.com/gotd/td/tg"
|
||||
"github.com/krau/SaveAny-Bot/common"
|
||||
"github.com/krau/SaveAny-Bot/config"
|
||||
"github.com/krau/SaveAny-Bot/logger"
|
||||
"golang.org/x/net/proxy"
|
||||
)
|
||||
|
||||
@@ -27,7 +27,8 @@ func newProxyDialer(proxyUrl string) (proxy.Dialer, error) {
|
||||
}
|
||||
|
||||
func Init() {
|
||||
logger.L.Info("初始化 Telegram 客户端...")
|
||||
InitTelegraphClient()
|
||||
common.Log.Info("初始化 Telegram 客户端...")
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
|
||||
defer cancel()
|
||||
resultChan := make(chan struct {
|
||||
@@ -87,15 +88,15 @@ func Init() {
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
logger.L.Fatal("初始化客户端失败: 超时")
|
||||
common.Log.Fatal("初始化客户端失败: 超时")
|
||||
os.Exit(1)
|
||||
case result := <-resultChan:
|
||||
if result.err != nil {
|
||||
logger.L.Fatalf("初始化客户端失败: %s", result.err)
|
||||
common.Log.Fatalf("初始化客户端失败: %s", result.err)
|
||||
os.Exit(1)
|
||||
}
|
||||
Client = result.client
|
||||
RegisterHandlers(Client.Dispatcher)
|
||||
logger.L.Info("客户端初始化完成")
|
||||
common.Log.Info("客户端初始化完成")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -13,9 +13,9 @@ import (
|
||||
"github.com/gotd/td/telegram/message/entity"
|
||||
"github.com/gotd/td/telegram/message/styling"
|
||||
"github.com/gotd/td/tg"
|
||||
"github.com/krau/SaveAny-Bot/common"
|
||||
"github.com/krau/SaveAny-Bot/config"
|
||||
"github.com/krau/SaveAny-Bot/dao"
|
||||
"github.com/krau/SaveAny-Bot/logger"
|
||||
"github.com/krau/SaveAny-Bot/queue"
|
||||
"github.com/krau/SaveAny-Bot/types"
|
||||
"gorm.io/gorm"
|
||||
@@ -33,11 +33,11 @@ func AddToQueue(ctx *ext.Context, update *ext.Update) error {
|
||||
return dispatcher.EndGroups
|
||||
}
|
||||
args := strings.Split(string(update.CallbackQuery.Data), " ")
|
||||
addToDir := args[0] == "add_to_dir"
|
||||
addToDir := args[0] == "add_to_dir" // 已经选择了路径
|
||||
cbDataId, _ := strconv.Atoi(args[1])
|
||||
cbData, err := dao.GetCallbackData(uint(cbDataId))
|
||||
if err != nil {
|
||||
logger.L.Errorf("获取回调数据失败: %s", err)
|
||||
common.Log.Errorf("获取回调数据失败: %s", err)
|
||||
ctx.AnswerCallback(&tg.MessagesSetBotCallbackAnswerRequest{
|
||||
QueryID: update.CallbackQuery.QueryID,
|
||||
Alert: true,
|
||||
@@ -56,7 +56,7 @@ func AddToQueue(ctx *ext.Context, update *ext.Update) error {
|
||||
|
||||
user, err := dao.GetUserByChatID(update.CallbackQuery.UserID)
|
||||
if err != nil {
|
||||
logger.L.Errorf("获取用户失败: %s", err)
|
||||
common.Log.Errorf("获取用户失败: %s", err)
|
||||
ctx.AnswerCallback(&tg.MessagesSetBotCallbackAnswerRequest{
|
||||
QueryID: update.CallbackQuery.QueryID,
|
||||
Alert: true,
|
||||
@@ -69,7 +69,7 @@ func AddToQueue(ctx *ext.Context, update *ext.Update) error {
|
||||
if !addToDir {
|
||||
dirs, err := dao.GetDirsByUserIDAndStorageName(user.ID, storageName)
|
||||
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
logger.L.Errorf("获取路径失败: %s", err)
|
||||
common.Log.Errorf("获取路径失败: %s", err)
|
||||
ctx.AnswerCallback(&tg.MessagesSetBotCallbackAnswerRequest{
|
||||
QueryID: update.CallbackQuery.QueryID,
|
||||
Alert: true,
|
||||
@@ -81,7 +81,7 @@ func AddToQueue(ctx *ext.Context, update *ext.Update) error {
|
||||
if len(dirs) != 0 {
|
||||
markup, err := getSelectDirMarkup(fileChatID, fileMessageID, storageName, dirs)
|
||||
if err != nil {
|
||||
logger.L.Errorf("获取路径失败: %s", err)
|
||||
common.Log.Errorf("获取路径失败: %s", err)
|
||||
ctx.AnswerCallback(&tg.MessagesSetBotCallbackAnswerRequest{
|
||||
QueryID: update.CallbackQuery.QueryID,
|
||||
Alert: true,
|
||||
@@ -96,16 +96,16 @@ func AddToQueue(ctx *ext.Context, update *ext.Update) error {
|
||||
ReplyMarkup: markup,
|
||||
})
|
||||
if err != nil {
|
||||
logger.L.Errorf("编辑消息失败: %s", err)
|
||||
common.Log.Errorf("编辑消息失败: %s", err)
|
||||
}
|
||||
return dispatcher.EndGroups
|
||||
}
|
||||
}
|
||||
|
||||
logger.L.Tracef("Got add to queue: chatID: %d, messageID: %d, storage: %s", fileChatID, fileMessageID, storageName)
|
||||
common.Log.Tracef("Got add to queue: chatID: %d, messageID: %d, storage: %s", fileChatID, fileMessageID, storageName)
|
||||
record, err := dao.GetReceivedFileByChatAndMessageID(int64(fileChatID), fileMessageID)
|
||||
if err != nil {
|
||||
logger.L.Errorf("获取记录失败: %s", err)
|
||||
common.Log.Errorf("获取记录失败: %s", err)
|
||||
ctx.AnswerCallback(&tg.MessagesSetBotCallbackAnswerRequest{
|
||||
QueryID: update.CallbackQuery.QueryID,
|
||||
Alert: true,
|
||||
@@ -117,7 +117,7 @@ func AddToQueue(ctx *ext.Context, update *ext.Update) error {
|
||||
if update.CallbackQuery.MsgID != record.ReplyMessageID {
|
||||
record.ReplyMessageID = update.CallbackQuery.MsgID
|
||||
if err := dao.SaveReceivedFile(record); err != nil {
|
||||
logger.L.Errorf("更新接收的文件失败: %s", err)
|
||||
common.Log.Errorf("更新接收的文件失败: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -125,7 +125,7 @@ func AddToQueue(ctx *ext.Context, update *ext.Update) error {
|
||||
if addToDir && dirId != 0 {
|
||||
dir, err = dao.GetDirByID(dirId)
|
||||
if err != nil {
|
||||
logger.L.Errorf("获取路径失败: %s", err)
|
||||
common.Log.Errorf("获取路径失败: %s", err)
|
||||
ctx.AnswerCallback(&tg.MessagesSetBotCallbackAnswerRequest{
|
||||
QueryID: update.CallbackQuery.QueryID,
|
||||
Alert: true,
|
||||
@@ -136,31 +136,50 @@ func AddToQueue(ctx *ext.Context, update *ext.Update) error {
|
||||
}
|
||||
}
|
||||
|
||||
file, err := FileFromMessage(ctx, record.ChatID, record.MessageID, record.FileName)
|
||||
if err != nil {
|
||||
logger.L.Errorf("获取消息中的文件失败: %s", err)
|
||||
ctx.AnswerCallback(&tg.MessagesSetBotCallbackAnswerRequest{
|
||||
QueryID: update.CallbackQuery.QueryID,
|
||||
Alert: true,
|
||||
Message: fmt.Sprintf("获取消息中的文件失败: %s", err),
|
||||
CacheTime: 5,
|
||||
})
|
||||
return dispatcher.EndGroups
|
||||
}
|
||||
var task types.Task
|
||||
if record.IsTelegraph {
|
||||
task = types.Task{
|
||||
Ctx: ctx,
|
||||
Status: types.Pending,
|
||||
IsTelegraph: true,
|
||||
TelegraphURL: record.TelegraphURL,
|
||||
StorageName: storageName,
|
||||
FileChatID: record.ChatID,
|
||||
FileMessageID: record.MessageID,
|
||||
ReplyMessageID: record.ReplyMessageID,
|
||||
ReplyChatID: record.ReplyChatID,
|
||||
UserID: update.GetUserChat().GetID(),
|
||||
}
|
||||
if dir != nil {
|
||||
task.StoragePath = path.Join(dir.Path, record.FileName)
|
||||
}
|
||||
} else {
|
||||
file, err := FileFromMessage(ctx, record.ChatID, record.MessageID, record.FileName)
|
||||
if err != nil {
|
||||
common.Log.Errorf("获取消息中的文件失败: %s", err)
|
||||
ctx.AnswerCallback(&tg.MessagesSetBotCallbackAnswerRequest{
|
||||
QueryID: update.CallbackQuery.QueryID,
|
||||
Alert: true,
|
||||
Message: fmt.Sprintf("获取消息中的文件失败: %s", err),
|
||||
CacheTime: 5,
|
||||
})
|
||||
return dispatcher.EndGroups
|
||||
}
|
||||
|
||||
task := types.Task{
|
||||
Ctx: ctx,
|
||||
Status: types.Pending,
|
||||
File: file,
|
||||
StorageName: storageName,
|
||||
FileChatID: record.ChatID,
|
||||
ReplyMessageID: record.ReplyMessageID,
|
||||
FileMessageID: record.MessageID,
|
||||
ReplyChatID: record.ReplyChatID,
|
||||
UserID: update.GetUserChat().GetID(),
|
||||
}
|
||||
if dir != nil {
|
||||
task.StoragePath = path.Join(dir.Path, file.FileName)
|
||||
task = types.Task{
|
||||
Ctx: ctx,
|
||||
Status: types.Pending,
|
||||
File: file,
|
||||
StorageName: storageName,
|
||||
FileChatID: record.ChatID,
|
||||
ReplyMessageID: record.ReplyMessageID,
|
||||
FileMessageID: record.MessageID,
|
||||
ReplyChatID: record.ReplyChatID,
|
||||
UserID: update.GetUserChat().GetID(),
|
||||
}
|
||||
if dir != nil {
|
||||
task.StoragePath = path.Join(dir.Path, file.FileName)
|
||||
}
|
||||
}
|
||||
|
||||
queue.AddTask(&task)
|
||||
@@ -174,7 +193,7 @@ func AddToQueue(ctx *ext.Context, update *ext.Update) error {
|
||||
styling.Plain("\n当前排队任务数: "),
|
||||
styling.Bold(strconv.Itoa(queue.Len())),
|
||||
); err != nil {
|
||||
logger.L.Errorf("Failed to build entity: %s", err)
|
||||
common.Log.Errorf("Failed to build entity: %s", err)
|
||||
} else {
|
||||
text, entities = entityBuilder.Complete()
|
||||
}
|
||||
|
||||
@@ -69,7 +69,7 @@ func (c *ConversationState) SetData(key string, value interface{}) {
|
||||
// func handleConversationState(ctx *ext.Context, update *ext.Update, state *ConversationState) error {
|
||||
// switch state.conversationType {
|
||||
// default:
|
||||
// logger.L.Errorf("Unknown conversation type: %s", state.conversationType)
|
||||
// common.Log.Errorf("Unknown conversation type: %s", state.conversationType)
|
||||
// }
|
||||
// return dispatcher.EndGroups
|
||||
// }
|
||||
|
||||
@@ -6,8 +6,8 @@ import (
|
||||
"github.com/celestix/gotgproto/dispatcher"
|
||||
"github.com/celestix/gotgproto/ext"
|
||||
"github.com/gotd/td/telegram/message/styling"
|
||||
"github.com/krau/SaveAny-Bot/common"
|
||||
"github.com/krau/SaveAny-Bot/dao"
|
||||
"github.com/krau/SaveAny-Bot/logger"
|
||||
"github.com/krau/SaveAny-Bot/storage"
|
||||
)
|
||||
|
||||
@@ -16,7 +16,7 @@ func dirCmd(ctx *ext.Context, update *ext.Update) error {
|
||||
if len(args) < 3 {
|
||||
dirs, err := dao.GetUserDirsByChatID(update.GetUserChat().GetID())
|
||||
if err != nil {
|
||||
logger.L.Errorf("获取用户路径失败: %s", err)
|
||||
common.Log.Errorf("获取用户路径失败: %s", err)
|
||||
ctx.Reply(update, ext.ReplyTextString("获取用户路径失败"), nil)
|
||||
return dispatcher.EndGroups
|
||||
}
|
||||
@@ -47,7 +47,7 @@ func dirCmd(ctx *ext.Context, update *ext.Update) error {
|
||||
}
|
||||
user, err := dao.GetUserByChatID(update.GetUserChat().GetID())
|
||||
if err != nil {
|
||||
logger.L.Errorf("获取用户失败: %s", err)
|
||||
common.Log.Errorf("获取用户失败: %s", err)
|
||||
ctx.Reply(update, ext.ReplyTextString("获取用户失败"), nil)
|
||||
return dispatcher.EndGroups
|
||||
}
|
||||
@@ -69,7 +69,7 @@ func addDir(ctx *ext.Context, update *ext.Update, user *dao.User, storageName, p
|
||||
}
|
||||
|
||||
if err := dao.CreateDirForUser(user.ID, storageName, path); err != nil {
|
||||
logger.L.Errorf("创建路径失败: %s", err)
|
||||
common.Log.Errorf("创建路径失败: %s", err)
|
||||
ctx.Reply(update, ext.ReplyTextString("创建路径失败"), nil)
|
||||
return dispatcher.EndGroups
|
||||
}
|
||||
@@ -79,7 +79,7 @@ func addDir(ctx *ext.Context, update *ext.Update, user *dao.User, storageName, p
|
||||
|
||||
func delDir(ctx *ext.Context, update *ext.Update, user *dao.User, storageName, path string) error {
|
||||
if err := dao.DeleteDirForUser(user.ID, storageName, path); err != nil {
|
||||
logger.L.Errorf("删除路径失败: %s", err)
|
||||
common.Log.Errorf("删除路径失败: %s", err)
|
||||
ctx.Reply(update, ext.ReplyTextString("删除路径失败"), nil)
|
||||
return dispatcher.EndGroups
|
||||
}
|
||||
|
||||
@@ -6,14 +6,14 @@ import (
|
||||
"github.com/celestix/gotgproto/dispatcher"
|
||||
"github.com/celestix/gotgproto/ext"
|
||||
"github.com/gotd/td/tg"
|
||||
"github.com/krau/SaveAny-Bot/common"
|
||||
"github.com/krau/SaveAny-Bot/dao"
|
||||
"github.com/krau/SaveAny-Bot/logger"
|
||||
"github.com/krau/SaveAny-Bot/storage"
|
||||
"github.com/krau/SaveAny-Bot/types"
|
||||
)
|
||||
|
||||
func handleFileMessage(ctx *ext.Context, update *ext.Update) error {
|
||||
logger.L.Trace("Got media: ", update.EffectiveMessage.Media.TypeName())
|
||||
common.Log.Trace("Got media: ", update.EffectiveMessage.Media.TypeName())
|
||||
supported, err := supportedMediaFilter(update.EffectiveMessage.Message)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -24,7 +24,7 @@ func handleFileMessage(ctx *ext.Context, update *ext.Update) error {
|
||||
|
||||
user, err := dao.GetUserByChatID(update.GetUserChat().GetID())
|
||||
if err != nil {
|
||||
logger.L.Errorf("获取用户失败: %s", err)
|
||||
common.Log.Errorf("获取用户失败: %s", err)
|
||||
ctx.Reply(update, ext.ReplyTextString("获取用户失败"), nil)
|
||||
return dispatcher.EndGroups
|
||||
}
|
||||
@@ -36,18 +36,18 @@ func handleFileMessage(ctx *ext.Context, update *ext.Update) error {
|
||||
|
||||
msg, err := ctx.Reply(update, ext.ReplyTextString("正在获取文件信息..."), nil)
|
||||
if err != nil {
|
||||
logger.L.Errorf("回复失败: %s", err)
|
||||
common.Log.Errorf("回复失败: %s", err)
|
||||
return dispatcher.EndGroups
|
||||
}
|
||||
media := update.EffectiveMessage.Media
|
||||
file, err := FileFromMedia(media, "")
|
||||
if err != nil {
|
||||
logger.L.Errorf("获取文件失败: %s", err)
|
||||
common.Log.Errorf("获取文件失败: %s", err)
|
||||
ctx.Reply(update, ext.ReplyTextString(fmt.Sprintf("获取文件失败: %s", err)), nil)
|
||||
return dispatcher.EndGroups
|
||||
}
|
||||
if file.FileName == "" {
|
||||
file.FileName = fmt.Sprintf("%d_%d_%s", update.EffectiveChat().GetID(), update.EffectiveMessage.ID, file.Hash())
|
||||
file.FileName = GenFileNameFromMessage(*update.EffectiveMessage.Message, file)
|
||||
}
|
||||
|
||||
if err := dao.SaveReceivedFile(&dao.ReceivedFile{
|
||||
@@ -58,18 +58,18 @@ func handleFileMessage(ctx *ext.Context, update *ext.Update) error {
|
||||
ReplyMessageID: msg.ID,
|
||||
ReplyChatID: update.GetUserChat().GetID(),
|
||||
}); err != nil {
|
||||
logger.L.Errorf("添加接收的文件失败: %s", err)
|
||||
common.Log.Errorf("添加接收的文件失败: %s", err)
|
||||
if _, err := ctx.EditMessage(update.EffectiveChat().GetID(), &tg.MessagesEditMessageRequest{
|
||||
Message: fmt.Sprintf("添加接收的文件失败: %s", err),
|
||||
ID: msg.ID,
|
||||
}); err != nil {
|
||||
logger.L.Errorf("编辑消息失败: %s", err)
|
||||
common.Log.Errorf("编辑消息失败: %s", err)
|
||||
}
|
||||
return dispatcher.EndGroups
|
||||
}
|
||||
|
||||
if !user.Silent || user.DefaultStorage == "" {
|
||||
return ProvideSelectMessage(ctx, update, file, update.EffectiveChat().GetID(), update.EffectiveMessage.ID, msg.ID)
|
||||
return ProvideSelectMessage(ctx, update, file.FileName, update.EffectiveChat().GetID(), update.EffectiveMessage.ID, msg.ID)
|
||||
}
|
||||
return HandleSilentAddTask(ctx, update, user, &types.Task{
|
||||
Ctx: ctx,
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package bot
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
@@ -9,8 +8,8 @@ import (
|
||||
"github.com/celestix/gotgproto/dispatcher"
|
||||
"github.com/celestix/gotgproto/ext"
|
||||
"github.com/gotd/td/tg"
|
||||
"github.com/krau/SaveAny-Bot/common"
|
||||
"github.com/krau/SaveAny-Bot/dao"
|
||||
"github.com/krau/SaveAny-Bot/logger"
|
||||
"github.com/krau/SaveAny-Bot/storage"
|
||||
"github.com/krau/SaveAny-Bot/types"
|
||||
)
|
||||
@@ -21,7 +20,7 @@ var (
|
||||
)
|
||||
|
||||
func handleLinkMessage(ctx *ext.Context, update *ext.Update) error {
|
||||
logger.L.Trace("Got link message")
|
||||
common.Log.Trace("Got link message")
|
||||
link := linkRegex.FindString(update.EffectiveMessage.Text)
|
||||
if link == "" {
|
||||
return dispatcher.ContinueGroups
|
||||
@@ -32,25 +31,25 @@ func handleLinkMessage(ctx *ext.Context, update *ext.Update) error {
|
||||
}
|
||||
messageID, err := strconv.Atoi(strSlice[2])
|
||||
if err != nil {
|
||||
logger.L.Errorf("解析消息 ID 失败: %s", err)
|
||||
common.Log.Errorf("解析消息 ID 失败: %s", err)
|
||||
ctx.Reply(update, ext.ReplyTextString("无法解析消息 ID"), nil)
|
||||
return dispatcher.EndGroups
|
||||
}
|
||||
chatUsername := strSlice[1]
|
||||
linkChat, err := ctx.ResolveUsername(chatUsername)
|
||||
if err != nil {
|
||||
logger.L.Errorf("解析 Chat ID 失败: %s", err)
|
||||
common.Log.Errorf("解析 Chat ID 失败: %s", err)
|
||||
ctx.Reply(update, ext.ReplyTextString("无法解析 Chat ID"), nil)
|
||||
return dispatcher.EndGroups
|
||||
}
|
||||
if linkChat == nil {
|
||||
logger.L.Errorf("无法找到聊天: %s", chatUsername)
|
||||
common.Log.Errorf("无法找到聊天: %s", chatUsername)
|
||||
ctx.Reply(update, ext.ReplyTextString("无法找到聊天"), nil)
|
||||
return dispatcher.EndGroups
|
||||
}
|
||||
user, err := dao.GetUserByChatID(update.GetUserChat().GetID())
|
||||
if err != nil {
|
||||
logger.L.Errorf("获取用户失败: %s", err)
|
||||
common.Log.Errorf("获取用户失败: %s", err)
|
||||
ctx.Reply(update, ext.ReplyTextString("获取用户失败"), nil)
|
||||
return dispatcher.EndGroups
|
||||
}
|
||||
@@ -62,20 +61,18 @@ func handleLinkMessage(ctx *ext.Context, update *ext.Update) error {
|
||||
}
|
||||
replied, err := ctx.Reply(update, ext.ReplyTextString("正在获取文件..."), nil)
|
||||
if err != nil {
|
||||
logger.L.Errorf("回复失败: %s", err)
|
||||
common.Log.Errorf("回复失败: %s", err)
|
||||
return dispatcher.EndGroups
|
||||
}
|
||||
|
||||
file, err := FileFromMessage(ctx, linkChat.GetID(), messageID, "")
|
||||
if err != nil {
|
||||
logger.L.Errorf("获取文件失败: %s", err)
|
||||
common.Log.Errorf("获取文件失败: %s", err)
|
||||
ctx.Reply(update, ext.ReplyTextString("获取文件失败: "+err.Error()), nil)
|
||||
return dispatcher.EndGroups
|
||||
}
|
||||
// TODO: Better file name
|
||||
if file.FileName == "" {
|
||||
logger.L.Warnf("文件名为空,使用生成的名称")
|
||||
file.FileName = fmt.Sprintf("%d_%d_%s", linkChat.GetID(), messageID, file.Hash())
|
||||
file.FileName = GenFileNameFromMessage(*update.EffectiveMessage.Message, file)
|
||||
}
|
||||
|
||||
receivedFile := &dao.ReceivedFile{
|
||||
@@ -87,7 +84,7 @@ func handleLinkMessage(ctx *ext.Context, update *ext.Update) error {
|
||||
ReplyChatID: update.GetUserChat().GetID(),
|
||||
}
|
||||
if err := dao.SaveReceivedFile(receivedFile); err != nil {
|
||||
logger.L.Errorf("保存接收的文件失败: %s", err)
|
||||
common.Log.Errorf("保存接收的文件失败: %s", err)
|
||||
ctx.EditMessage(update.EffectiveChat().GetID(), &tg.MessagesEditMessageRequest{
|
||||
Message: "无法保存文件: " + err.Error(),
|
||||
ID: replied.ID,
|
||||
@@ -95,7 +92,7 @@ func handleLinkMessage(ctx *ext.Context, update *ext.Update) error {
|
||||
return dispatcher.EndGroups
|
||||
}
|
||||
if !user.Silent || user.DefaultStorage == "" {
|
||||
return ProvideSelectMessage(ctx, update, file, linkChat.GetID(), messageID, replied.ID)
|
||||
return ProvideSelectMessage(ctx, update, file.FileName, linkChat.GetID(), messageID, replied.ID)
|
||||
}
|
||||
return HandleSilentAddTask(ctx, update, user, &types.Task{
|
||||
Ctx: ctx,
|
||||
|
||||
@@ -7,8 +7,8 @@ import (
|
||||
"github.com/celestix/gotgproto/dispatcher"
|
||||
"github.com/celestix/gotgproto/ext"
|
||||
"github.com/gotd/td/tg"
|
||||
"github.com/krau/SaveAny-Bot/common"
|
||||
"github.com/krau/SaveAny-Bot/dao"
|
||||
"github.com/krau/SaveAny-Bot/logger"
|
||||
"github.com/krau/SaveAny-Bot/storage"
|
||||
"github.com/krau/SaveAny-Bot/types"
|
||||
)
|
||||
@@ -32,7 +32,7 @@ func saveCmd(ctx *ext.Context, update *ext.Update) error {
|
||||
|
||||
user, err := dao.GetUserByChatID(update.GetUserChat().GetID())
|
||||
if err != nil {
|
||||
logger.L.Errorf("获取用户失败: %s", err)
|
||||
common.Log.Errorf("获取用户失败: %s", err)
|
||||
ctx.Reply(update, ext.ReplyTextString("获取用户失败"), nil)
|
||||
return dispatcher.EndGroups
|
||||
}
|
||||
@@ -46,7 +46,7 @@ func saveCmd(ctx *ext.Context, update *ext.Update) error {
|
||||
|
||||
msg, err := GetTGMessage(ctx, update.EffectiveChat().GetID(), replyToMsgID)
|
||||
if err != nil {
|
||||
logger.L.Errorf("获取消息失败: %s", err)
|
||||
common.Log.Errorf("获取消息失败: %s", err)
|
||||
ctx.Reply(update, ext.ReplyTextString("无法获取消息"), nil)
|
||||
return dispatcher.EndGroups
|
||||
}
|
||||
@@ -59,7 +59,7 @@ func saveCmd(ctx *ext.Context, update *ext.Update) error {
|
||||
|
||||
replied, err := ctx.Reply(update, ext.ReplyTextString("正在获取文件信息..."), nil)
|
||||
if err != nil {
|
||||
logger.L.Errorf("回复失败: %s", err)
|
||||
common.Log.Errorf("回复失败: %s", err)
|
||||
return dispatcher.EndGroups
|
||||
}
|
||||
|
||||
@@ -68,7 +68,7 @@ func saveCmd(ctx *ext.Context, update *ext.Update) error {
|
||||
|
||||
file, err := FileFromMessage(ctx, update.EffectiveChat().GetID(), msg.ID, customFileName)
|
||||
if err != nil {
|
||||
logger.L.Errorf("获取文件失败: %s", err)
|
||||
common.Log.Errorf("获取文件失败: %s", err)
|
||||
ctx.EditMessage(update.EffectiveChat().GetID(), &tg.MessagesEditMessageRequest{
|
||||
Message: fmt.Sprintf("获取文件失败: %s", err),
|
||||
ID: replied.ID,
|
||||
@@ -76,9 +76,8 @@ func saveCmd(ctx *ext.Context, update *ext.Update) error {
|
||||
return dispatcher.EndGroups
|
||||
}
|
||||
|
||||
// TODO: better file name
|
||||
if file.FileName == "" {
|
||||
file.FileName = fmt.Sprintf("%d_%d_%s", update.EffectiveChat().GetID(), replyToMsgID, file.Hash())
|
||||
file.FileName = GenFileNameFromMessage(*msg, file)
|
||||
}
|
||||
receivedFile := &dao.ReceivedFile{
|
||||
Processing: false,
|
||||
@@ -90,17 +89,17 @@ func saveCmd(ctx *ext.Context, update *ext.Update) error {
|
||||
}
|
||||
|
||||
if err := dao.SaveReceivedFile(receivedFile); err != nil {
|
||||
logger.L.Errorf("保存接收的文件失败: %s", err)
|
||||
common.Log.Errorf("保存接收的文件失败: %s", err)
|
||||
if _, err := ctx.EditMessage(update.EffectiveChat().GetID(), &tg.MessagesEditMessageRequest{
|
||||
Message: fmt.Sprintf("保存接收的文件失败: %s", err),
|
||||
ID: replied.ID,
|
||||
}); err != nil {
|
||||
logger.L.Errorf("编辑消息失败: %s", err)
|
||||
common.Log.Errorf("编辑消息失败: %s", err)
|
||||
}
|
||||
return dispatcher.EndGroups
|
||||
}
|
||||
if !user.Silent || user.DefaultStorage == "" {
|
||||
return ProvideSelectMessage(ctx, update, file, update.EffectiveChat().GetID(), msg.ID, replied.ID)
|
||||
return ProvideSelectMessage(ctx, update, file.FileName, update.EffectiveChat().GetID(), msg.ID, replied.ID)
|
||||
}
|
||||
return HandleSilentAddTask(ctx, update, user, &types.Task{
|
||||
Ctx: ctx,
|
||||
|
||||
@@ -5,14 +5,14 @@ import (
|
||||
|
||||
"github.com/celestix/gotgproto/dispatcher"
|
||||
"github.com/celestix/gotgproto/ext"
|
||||
"github.com/krau/SaveAny-Bot/common"
|
||||
"github.com/krau/SaveAny-Bot/dao"
|
||||
"github.com/krau/SaveAny-Bot/logger"
|
||||
)
|
||||
|
||||
func silent(ctx *ext.Context, update *ext.Update) error {
|
||||
user, err := dao.GetUserByChatID(update.GetUserChat().GetID())
|
||||
if err != nil {
|
||||
logger.L.Errorf("获取用户失败: %s", err)
|
||||
common.Log.Errorf("获取用户失败: %s", err)
|
||||
return dispatcher.EndGroups
|
||||
}
|
||||
if !user.Silent && user.DefaultStorage == "" {
|
||||
@@ -21,7 +21,7 @@ func silent(ctx *ext.Context, update *ext.Update) error {
|
||||
}
|
||||
user.Silent = !user.Silent
|
||||
if err := dao.UpdateUser(user); err != nil {
|
||||
logger.L.Errorf("更新用户失败: %s", err)
|
||||
common.Log.Errorf("更新用户失败: %s", err)
|
||||
ctx.Reply(update, ext.ReplyTextString("更新用户失败"), nil)
|
||||
return dispatcher.EndGroups
|
||||
}
|
||||
|
||||
@@ -1,15 +1,17 @@
|
||||
package bot
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/celestix/gotgproto/dispatcher"
|
||||
"github.com/celestix/gotgproto/ext"
|
||||
"github.com/krau/SaveAny-Bot/common"
|
||||
"github.com/krau/SaveAny-Bot/dao"
|
||||
"github.com/krau/SaveAny-Bot/logger"
|
||||
)
|
||||
|
||||
func start(ctx *ext.Context, update *ext.Update) error {
|
||||
if err := dao.CreateUser(update.GetUserChat().GetID()); err != nil {
|
||||
logger.L.Errorf("创建用户失败: %s", err)
|
||||
common.Log.Errorf("创建用户失败: %s", err)
|
||||
return dispatcher.EndGroups
|
||||
}
|
||||
return help(ctx, update)
|
||||
@@ -17,6 +19,7 @@ func start(ctx *ext.Context, update *ext.Update) error {
|
||||
|
||||
const helpText string = `
|
||||
Save Any Bot - 转存你的 Telegram 文件
|
||||
版本: %s , 提交: %s
|
||||
命令:
|
||||
/start - 开始使用
|
||||
/help - 显示帮助
|
||||
@@ -32,6 +35,6 @@ Save Any Bot - 转存你的 Telegram 文件
|
||||
`
|
||||
|
||||
func help(ctx *ext.Context, update *ext.Update) error {
|
||||
ctx.Reply(update, ext.ReplyTextString(helpText), nil)
|
||||
ctx.Reply(update, ext.ReplyTextString(fmt.Sprintf(helpText, common.Version, common.GitCommit[:7])), nil)
|
||||
return dispatcher.EndGroups
|
||||
}
|
||||
|
||||
@@ -8,8 +8,8 @@ import (
|
||||
"github.com/celestix/gotgproto/dispatcher"
|
||||
"github.com/celestix/gotgproto/ext"
|
||||
"github.com/gotd/td/tg"
|
||||
"github.com/krau/SaveAny-Bot/common"
|
||||
"github.com/krau/SaveAny-Bot/dao"
|
||||
"github.com/krau/SaveAny-Bot/logger"
|
||||
"github.com/krau/SaveAny-Bot/storage"
|
||||
)
|
||||
|
||||
@@ -22,7 +22,7 @@ func storageCmd(ctx *ext.Context, update *ext.Update) error {
|
||||
}
|
||||
markup, err := getSetDefaultStorageMarkup(userChatID, storages)
|
||||
if err != nil {
|
||||
logger.L.Errorf("Failed to get markup: %s", err)
|
||||
common.Log.Errorf("Failed to get markup: %s", err)
|
||||
ctx.Reply(update, ext.ReplyTextString("获取存储位置失败"), nil)
|
||||
return dispatcher.EndGroups
|
||||
}
|
||||
@@ -47,7 +47,7 @@ func setDefaultStorage(ctx *ext.Context, update *ext.Update) error {
|
||||
cbDataId, _ := strconv.Atoi(args[2])
|
||||
storageName, err := dao.GetCallbackData(uint(cbDataId))
|
||||
if err != nil {
|
||||
logger.L.Errorf("获取回调数据失败: %s", err)
|
||||
common.Log.Errorf("获取回调数据失败: %s", err)
|
||||
ctx.AnswerCallback(&tg.MessagesSetBotCallbackAnswerRequest{
|
||||
QueryID: update.CallbackQuery.QueryID,
|
||||
Alert: true,
|
||||
@@ -60,7 +60,7 @@ func setDefaultStorage(ctx *ext.Context, update *ext.Update) error {
|
||||
selectedStorage, err := storage.GetStorageByName(storageName)
|
||||
|
||||
if err != nil {
|
||||
logger.L.Errorf("获取指定存储失败: %s", err)
|
||||
common.Log.Errorf("获取指定存储失败: %s", err)
|
||||
ctx.AnswerCallback(&tg.MessagesSetBotCallbackAnswerRequest{
|
||||
QueryID: update.CallbackQuery.QueryID,
|
||||
Alert: true,
|
||||
@@ -71,7 +71,7 @@ func setDefaultStorage(ctx *ext.Context, update *ext.Update) error {
|
||||
}
|
||||
user, err := dao.GetUserByChatID(int64(userID))
|
||||
if err != nil {
|
||||
logger.L.Errorf("Failed to get user: %s", err)
|
||||
common.Log.Errorf("Failed to get user: %s", err)
|
||||
ctx.AnswerCallback(&tg.MessagesSetBotCallbackAnswerRequest{
|
||||
QueryID: update.CallbackQuery.QueryID,
|
||||
Alert: true,
|
||||
@@ -82,7 +82,7 @@ func setDefaultStorage(ctx *ext.Context, update *ext.Update) error {
|
||||
}
|
||||
user.DefaultStorage = storageName
|
||||
if err := dao.UpdateUser(user); err != nil {
|
||||
logger.L.Errorf("Failed to update user: %s", err)
|
||||
common.Log.Errorf("Failed to update user: %s", err)
|
||||
ctx.AnswerCallback(&tg.MessagesSetBotCallbackAnswerRequest{
|
||||
QueryID: update.CallbackQuery.QueryID,
|
||||
Alert: true,
|
||||
|
||||
114
bot/handle_telegraph.go
Normal file
114
bot/handle_telegraph.go
Normal file
@@ -0,0 +1,114 @@
|
||||
package bot
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/celestix/gotgproto/dispatcher"
|
||||
"github.com/celestix/gotgproto/ext"
|
||||
"github.com/celestix/telegraph-go/v2"
|
||||
"github.com/gotd/td/tg"
|
||||
"github.com/krau/SaveAny-Bot/common"
|
||||
"github.com/krau/SaveAny-Bot/config"
|
||||
"github.com/krau/SaveAny-Bot/dao"
|
||||
"github.com/krau/SaveAny-Bot/storage"
|
||||
"github.com/krau/SaveAny-Bot/types"
|
||||
)
|
||||
|
||||
var (
|
||||
TelegraphClient *telegraph.TelegraphClient
|
||||
TelegraphUrlRegexString = `https://telegra.ph/.*`
|
||||
TelegraphUrlRegex = regexp.MustCompile(TelegraphUrlRegexString)
|
||||
)
|
||||
|
||||
func InitTelegraphClient() {
|
||||
var httpClient *http.Client
|
||||
if config.Cfg.Telegram.Proxy.Enable {
|
||||
proxyUrl, err := url.Parse(config.Cfg.Telegram.Proxy.URL)
|
||||
if err != nil {
|
||||
fmt.Println("Error parsing proxy URL:", err)
|
||||
return
|
||||
}
|
||||
proxy := http.ProxyURL(proxyUrl)
|
||||
httpClient = &http.Client{
|
||||
Transport: &http.Transport{
|
||||
Proxy: proxy,
|
||||
},
|
||||
Timeout: 30 * time.Second,
|
||||
}
|
||||
} else {
|
||||
httpClient = &http.Client{
|
||||
Timeout: 30 * time.Second,
|
||||
}
|
||||
}
|
||||
TelegraphClient = telegraph.GetTelegraphClient(&telegraph.ClientOpt{HttpClient: httpClient})
|
||||
}
|
||||
|
||||
func handleTelegraph(ctx *ext.Context, update *ext.Update) error {
|
||||
common.Log.Trace("Got telegraph link")
|
||||
tgphUrl := TelegraphUrlRegex.FindString(update.EffectiveMessage.Text)
|
||||
if tgphUrl == "" {
|
||||
return dispatcher.ContinueGroups
|
||||
}
|
||||
replied, err := ctx.Reply(update, ext.ReplyTextString("正在获取文件..."), nil)
|
||||
if err != nil {
|
||||
common.Log.Errorf("回复失败: %s", err)
|
||||
return dispatcher.EndGroups
|
||||
}
|
||||
user, err := dao.GetUserByChatID(update.GetUserChat().GetID())
|
||||
if err != nil {
|
||||
common.Log.Errorf("获取用户失败: %s", err)
|
||||
ctx.Reply(update, ext.ReplyTextString("获取用户失败"), nil)
|
||||
return dispatcher.EndGroups
|
||||
}
|
||||
storages := storage.GetUserStorages(user.ChatID)
|
||||
|
||||
if len(storages) == 0 {
|
||||
ctx.Reply(update, ext.ReplyTextString("无可用的存储"), nil)
|
||||
return dispatcher.EndGroups
|
||||
}
|
||||
|
||||
tgphPath := strings.Split(tgphUrl, "/")[len(strings.Split(tgphUrl, "/"))-1]
|
||||
fileName, err := url.PathUnescape(tgphPath)
|
||||
if err != nil {
|
||||
common.Log.Errorf("解析 Telegraph 路径失败: %s", err)
|
||||
fileName = tgphPath
|
||||
}
|
||||
|
||||
record := &dao.ReceivedFile{
|
||||
Processing: false,
|
||||
FileName: fileName,
|
||||
ChatID: update.EffectiveChat().GetID(),
|
||||
MessageID: update.EffectiveMessage.GetID(),
|
||||
ReplyMessageID: replied.ID,
|
||||
ReplyChatID: update.EffectiveChat().GetID(),
|
||||
IsTelegraph: true,
|
||||
TelegraphURL: tgphUrl,
|
||||
}
|
||||
if err := dao.SaveReceivedFile(record); err != nil {
|
||||
common.Log.Errorf("保存接收的文件失败: %s", err)
|
||||
ctx.EditMessage(update.EffectiveChat().GetID(), &tg.MessagesEditMessageRequest{
|
||||
Message: "无法保存文件: " + err.Error(),
|
||||
ID: replied.ID,
|
||||
})
|
||||
return dispatcher.EndGroups
|
||||
}
|
||||
|
||||
if !user.Silent || user.DefaultStorage == "" {
|
||||
return ProvideSelectMessage(ctx, update, fileName, update.EffectiveChat().GetID(), update.EffectiveMessage.GetID(), replied.ID)
|
||||
}
|
||||
return HandleSilentAddTask(ctx, update, user, &types.Task{
|
||||
Ctx: ctx,
|
||||
Status: types.Pending,
|
||||
StorageName: user.DefaultStorage,
|
||||
UserID: user.ChatID,
|
||||
ReplyMessageID: replied.ID,
|
||||
ReplyChatID: update.GetUserChat().GetID(),
|
||||
IsTelegraph: true,
|
||||
TelegraphURL: tgphUrl,
|
||||
})
|
||||
}
|
||||
@@ -4,7 +4,7 @@ import (
|
||||
"github.com/celestix/gotgproto/dispatcher"
|
||||
"github.com/celestix/gotgproto/dispatcher/handlers"
|
||||
"github.com/celestix/gotgproto/dispatcher/handlers/filters"
|
||||
"github.com/krau/SaveAny-Bot/logger"
|
||||
"github.com/krau/SaveAny-Bot/common"
|
||||
)
|
||||
|
||||
func RegisterHandlers(dispatcher dispatcher.Dispatcher) {
|
||||
@@ -17,9 +17,14 @@ func RegisterHandlers(dispatcher dispatcher.Dispatcher) {
|
||||
dispatcher.AddHandler(handlers.NewCommand("dir", dirCmd))
|
||||
linkRegexFilter, err := filters.Message.Regex(linkRegexString)
|
||||
if err != nil {
|
||||
logger.L.Panicf("创建正则表达式过滤器失败: %s", err)
|
||||
common.Log.Panicf("创建正则表达式过滤器失败: %s", err)
|
||||
}
|
||||
dispatcher.AddHandler(handlers.NewMessage(linkRegexFilter, handleLinkMessage))
|
||||
telegraphUrlRegexFilter, err := filters.Message.Regex(TelegraphUrlRegexString)
|
||||
if err != nil {
|
||||
common.Log.Panicf("创建 Telegraph URL 正则表达式过滤器失败: %s", err)
|
||||
}
|
||||
dispatcher.AddHandler(handlers.NewMessage(telegraphUrlRegexFilter, handleTelegraph))
|
||||
dispatcher.AddHandler(handlers.NewCallbackQuery(filters.CallbackQuery.Prefix("add"), AddToQueue))
|
||||
dispatcher.AddHandler(handlers.NewCallbackQuery(filters.CallbackQuery.Prefix("set_default"), setDefaultStorage))
|
||||
dispatcher.AddHandler(handlers.NewCallbackQuery(filters.CallbackQuery.Prefix("cancel"), cancelTask))
|
||||
|
||||
71
bot/utils.go
71
bot/utils.go
@@ -3,16 +3,18 @@ package bot
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/celestix/gotgproto/dispatcher"
|
||||
"github.com/celestix/gotgproto/ext"
|
||||
"github.com/gabriel-vasile/mimetype"
|
||||
"github.com/gotd/td/telegram/message/entity"
|
||||
"github.com/gotd/td/telegram/message/styling"
|
||||
"github.com/gotd/td/tg"
|
||||
"github.com/krau/SaveAny-Bot/common"
|
||||
"github.com/krau/SaveAny-Bot/dao"
|
||||
"github.com/krau/SaveAny-Bot/logger"
|
||||
"github.com/krau/SaveAny-Bot/queue"
|
||||
"github.com/krau/SaveAny-Bot/storage"
|
||||
"github.com/krau/SaveAny-Bot/types"
|
||||
@@ -178,7 +180,7 @@ func FileFromMedia(media tg.MessageMediaClass, customFileName string) (*types.Fi
|
||||
|
||||
func FileFromMessage(ctx *ext.Context, chatID int64, messageID int, customFileName string) (*types.File, error) {
|
||||
key := fmt.Sprintf("file:%d:%d", chatID, messageID)
|
||||
logger.L.Debugf("Getting file: %s", key)
|
||||
common.Log.Debugf("Getting file: %s", key)
|
||||
var cachedFile types.File
|
||||
err := common.Cache.Get(key, &cachedFile)
|
||||
if err == nil {
|
||||
@@ -193,13 +195,13 @@ func FileFromMessage(ctx *ext.Context, chatID int64, messageID int, customFileNa
|
||||
return nil, err
|
||||
}
|
||||
if err := common.Cache.Set(key, file, 3600); err != nil {
|
||||
logger.L.Errorf("Failed to cache file: %s", err)
|
||||
common.Log.Errorf("Failed to cache file: %s", err)
|
||||
}
|
||||
return file, nil
|
||||
}
|
||||
|
||||
func GetTGMessage(ctx *ext.Context, chatId int64, messageID int) (*tg.Message, error) {
|
||||
logger.L.Debugf("Fetching message: %d", messageID)
|
||||
common.Log.Debugf("Fetching message: %d", messageID)
|
||||
messages, err := ctx.GetMessages(chatId, []tg.InputMessageClass{&tg.InputMessageID{ID: messageID}})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -215,29 +217,29 @@ func GetTGMessage(ctx *ext.Context, chatId int64, messageID int) (*tg.Message, e
|
||||
return tgMessage, nil
|
||||
}
|
||||
|
||||
func ProvideSelectMessage(ctx *ext.Context, update *ext.Update, file *types.File, chatID int64, fileMsgID, toEditMsgID int) error {
|
||||
func ProvideSelectMessage(ctx *ext.Context, update *ext.Update, fileName string, chatID int64, fileMsgID, toEditMsgID int) error {
|
||||
entityBuilder := entity.Builder{}
|
||||
var entities []tg.MessageEntityClass
|
||||
text := fmt.Sprintf("文件名: %s\n请选择存储位置", file.FileName)
|
||||
text := fmt.Sprintf("文件名: %s\n请选择存储位置", fileName)
|
||||
if err := styling.Perform(&entityBuilder,
|
||||
styling.Plain("文件名: "),
|
||||
styling.Code(file.FileName),
|
||||
styling.Code(fileName),
|
||||
styling.Plain("\n请选择存储位置"),
|
||||
); err != nil {
|
||||
logger.L.Errorf("Failed to build entity: %s", err)
|
||||
common.Log.Errorf("Failed to build entity: %s", err)
|
||||
} else {
|
||||
text, entities = entityBuilder.Complete()
|
||||
}
|
||||
markup, err := getSelectStorageMarkup(update.GetUserChat().GetID(), int(chatID), fileMsgID)
|
||||
if errors.Is(err, ErrNoStorages) {
|
||||
logger.L.Errorf("Failed to get select storage markup: %s", err)
|
||||
common.Log.Errorf("Failed to get select storage markup: %s", err)
|
||||
ctx.EditMessage(update.EffectiveChat().GetID(), &tg.MessagesEditMessageRequest{
|
||||
Message: "无可用存储",
|
||||
ID: toEditMsgID,
|
||||
})
|
||||
return dispatcher.EndGroups
|
||||
} else if err != nil {
|
||||
logger.L.Errorf("Failed to get select storage markup: %s", err)
|
||||
common.Log.Errorf("Failed to get select storage markup: %s", err)
|
||||
ctx.EditMessage(update.EffectiveChat().GetID(), &tg.MessagesEditMessageRequest{
|
||||
Message: "无法获取存储",
|
||||
ID: toEditMsgID,
|
||||
@@ -251,7 +253,7 @@ func ProvideSelectMessage(ctx *ext.Context, update *ext.Update, file *types.File
|
||||
ID: toEditMsgID,
|
||||
})
|
||||
if err != nil {
|
||||
logger.L.Errorf("Failed to reply: %s", err)
|
||||
common.Log.Errorf("Failed to reply: %s", err)
|
||||
}
|
||||
return dispatcher.EndGroups
|
||||
}
|
||||
@@ -271,3 +273,50 @@ func HandleSilentAddTask(ctx *ext.Context, update *ext.Update, user *dao.User, t
|
||||
})
|
||||
return dispatcher.EndGroups
|
||||
}
|
||||
|
||||
func GenFileNameFromMessage(message tg.Message, file *types.File) string {
|
||||
if file.FileName != "" {
|
||||
return file.FileName
|
||||
}
|
||||
fileName := genFileNameFromMessageText(message, file)
|
||||
media, ok := message.GetMedia()
|
||||
if !ok {
|
||||
return fileName
|
||||
}
|
||||
ext, ok := extraMediaExt(media)
|
||||
if ok {
|
||||
return fileName + ext
|
||||
}
|
||||
return fileName
|
||||
}
|
||||
|
||||
func genFileNameFromMessageText(message tg.Message, file *types.File) string {
|
||||
text := strings.TrimSpace(message.GetMessage())
|
||||
if text == "" {
|
||||
return file.Hash()
|
||||
}
|
||||
tags := common.ExtractTagsFromText(text)
|
||||
if len(tags) > 0 {
|
||||
return fmt.Sprintf("%s_%s", strings.Join(tags, "_"), strconv.Itoa(message.GetID()))
|
||||
}
|
||||
runes := []rune(text)
|
||||
return string(runes[:min(128, len(runes))])
|
||||
}
|
||||
|
||||
func extraMediaExt(media tg.MessageMediaClass) (string, bool) {
|
||||
switch media := media.(type) {
|
||||
case *tg.MessageMediaDocument:
|
||||
doc, ok := media.Document.AsNotEmpty()
|
||||
if !ok {
|
||||
return "", false
|
||||
}
|
||||
ext := mimetype.Lookup(doc.MimeType).Extension()
|
||||
if ext == "" {
|
||||
return "", false
|
||||
}
|
||||
return ext, true
|
||||
case *tg.MessageMediaPhoto:
|
||||
return ".jpg", true
|
||||
}
|
||||
return "", false
|
||||
}
|
||||
|
||||
29
cmd/run.go
29
cmd/run.go
@@ -7,12 +7,13 @@ import (
|
||||
"path/filepath"
|
||||
"syscall"
|
||||
|
||||
"slices"
|
||||
|
||||
"github.com/krau/SaveAny-Bot/bot"
|
||||
"github.com/krau/SaveAny-Bot/common"
|
||||
"github.com/krau/SaveAny-Bot/config"
|
||||
"github.com/krau/SaveAny-Bot/core"
|
||||
"github.com/krau/SaveAny-Bot/dao"
|
||||
"github.com/krau/SaveAny-Bot/logger"
|
||||
"github.com/krau/SaveAny-Bot/storage"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
@@ -24,32 +25,30 @@ func Run(_ *cobra.Command, _ []string) {
|
||||
quit := make(chan os.Signal, 1)
|
||||
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
|
||||
sig := <-quit
|
||||
logger.L.Info(sig, ", exitting...")
|
||||
defer logger.L.Info("Bye!")
|
||||
common.Log.Info(sig, ", exitting...")
|
||||
defer common.Log.Info("Bye!")
|
||||
if config.Cfg.NoCleanCache {
|
||||
return
|
||||
}
|
||||
if config.Cfg.Temp.BasePath != "" {
|
||||
for _, path := range []string{"/", ".", "\\", ".."} {
|
||||
if filepath.Clean(config.Cfg.Temp.BasePath) == path {
|
||||
logger.L.Error("Invalid cache dir: ", config.Cfg.Temp.BasePath)
|
||||
return
|
||||
}
|
||||
if config.Cfg.Temp.BasePath != "" && !config.Cfg.Stream {
|
||||
if slices.Contains([]string{"/", ".", "\\", ".."}, filepath.Clean(config.Cfg.Temp.BasePath)) {
|
||||
common.Log.Error("无效的缓存文件夹: ", config.Cfg.Temp.BasePath)
|
||||
return
|
||||
}
|
||||
currentDir, err := os.Getwd()
|
||||
if err != nil {
|
||||
logger.L.Error("Failed to get current dir: ", err)
|
||||
common.Log.Error("获取工作目录失败: ", err)
|
||||
return
|
||||
}
|
||||
cachePath := filepath.Join(currentDir, config.Cfg.Temp.BasePath)
|
||||
cachePath, err = filepath.Abs(cachePath)
|
||||
if err != nil {
|
||||
logger.L.Error("Failed to get absolute path: ", err)
|
||||
common.Log.Error("获取缓存绝对路径失败: ", err)
|
||||
return
|
||||
}
|
||||
logger.L.Info("Cleaning cache dir: ", cachePath)
|
||||
common.Log.Info("正在清理缓存文件夹: ", cachePath)
|
||||
if err := os.RemoveAll(cachePath); err != nil {
|
||||
logger.L.Error("Failed to clean cache dir: ", err)
|
||||
common.Log.Error("清理缓存失败: ", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -59,8 +58,8 @@ func InitAll() {
|
||||
fmt.Println("加载配置文件失败: ", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
logger.InitLogger()
|
||||
logger.L.Info("正在启动 SaveAny-Bot...")
|
||||
common.InitLogger()
|
||||
common.Log.Info("正在启动 SaveAny-Bot...")
|
||||
dao.Init()
|
||||
storage.LoadStorages()
|
||||
common.Init()
|
||||
|
||||
@@ -24,7 +24,7 @@ func initCache() {
|
||||
Cache = &CommonCache{cache: freecache.NewCache(10 * 1024 * 1024)}
|
||||
}
|
||||
|
||||
func (c *CommonCache) Get(key string, value *types.File) error {
|
||||
func (c *CommonCache) Get(key string, value any) error {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
data, err := Cache.cache.Get([]byte(key))
|
||||
@@ -39,7 +39,7 @@ func (c *CommonCache) Get(key string, value *types.File) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *CommonCache) Set(key string, value *types.File, expireSeconds int) error {
|
||||
func (c *CommonCache) Set(key string, value any, expireSeconds int) error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
var buf bytes.Buffer
|
||||
|
||||
@@ -1,21 +1,20 @@
|
||||
package logger
|
||||
package common
|
||||
|
||||
import (
|
||||
"github.com/krau/SaveAny-Bot/config"
|
||||
|
||||
"github.com/gookit/slog"
|
||||
"github.com/gookit/slog/handler"
|
||||
"github.com/gookit/slog/rotatefile"
|
||||
"github.com/krau/SaveAny-Bot/config"
|
||||
)
|
||||
|
||||
var L *slog.Logger
|
||||
var Log *slog.Logger
|
||||
|
||||
func InitLogger() {
|
||||
if L != nil {
|
||||
if Log != nil {
|
||||
return
|
||||
}
|
||||
slog.DefaultChannelName = "SaveAnyBot"
|
||||
L = slog.New()
|
||||
Log = slog.New()
|
||||
logLevel := slog.LevelByName(config.Cfg.Log.Level)
|
||||
logFilePath := config.Cfg.Log.File
|
||||
logBackupNum := config.Cfg.Log.BackupCount
|
||||
@@ -36,5 +35,5 @@ func InitLogger() {
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
L.AddHandlers(consoleH, fileH)
|
||||
Log.AddHandlers(consoleH, fileH)
|
||||
}
|
||||
@@ -5,8 +5,6 @@ import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"github.com/krau/SaveAny-Bot/logger"
|
||||
)
|
||||
|
||||
// 创建文件, 自动创建目录
|
||||
@@ -31,10 +29,10 @@ func PurgeFile(path string) error {
|
||||
func RmFileAfter(path string, td time.Duration) {
|
||||
_, err := os.Stat(path)
|
||||
if err != nil {
|
||||
logger.L.Errorf("Failed to create timer for %s: %s", path, err)
|
||||
Log.Errorf("Failed to create timer for %s: %s", path, err)
|
||||
return
|
||||
}
|
||||
logger.L.Debugf("Remove file after %s: %s", td, path)
|
||||
Log.Debugf("Remove file after %s: %s", td, path)
|
||||
time.AfterFunc(td, func() {
|
||||
PurgeFile(path)
|
||||
})
|
||||
|
||||
@@ -3,6 +3,7 @@ package common
|
||||
import (
|
||||
"crypto/md5"
|
||||
"encoding/hex"
|
||||
"regexp"
|
||||
)
|
||||
|
||||
func HashString(s string) string {
|
||||
@@ -10,3 +11,16 @@ func HashString(s string) string {
|
||||
hash.Write([]byte(s))
|
||||
return hex.EncodeToString(hash.Sum(nil))
|
||||
}
|
||||
|
||||
var TagRe = regexp.MustCompile(`(?:^|[\p{Zs}\s.,!?(){}[\]<>\"\',。!?():;、])#([\p{L}\d_]+)`)
|
||||
|
||||
func ExtractTagsFromText(text string) []string {
|
||||
matches := TagRe.FindAllStringSubmatch(text, -1)
|
||||
tags := make([]string, 0)
|
||||
for _, match := range matches {
|
||||
if len(match) > 1 {
|
||||
tags = append(tags, match[1])
|
||||
}
|
||||
}
|
||||
return tags
|
||||
}
|
||||
|
||||
@@ -23,7 +23,7 @@ url = "socks5://127.0.0.1:7890"
|
||||
[[storages]]
|
||||
# 标识名, 需要唯一
|
||||
name = "本机1"
|
||||
# 存储类型, 目前可用: local , alist , webdav
|
||||
# 存储类型, 目前可用: local, alist, webdav, minio
|
||||
type = "local"
|
||||
# 启用存储
|
||||
enable = true
|
||||
@@ -59,6 +59,16 @@ url = 'https://example.com/dav'
|
||||
username = 'username'
|
||||
password = 'password'
|
||||
|
||||
[[storages]]
|
||||
name = "MyMinio"
|
||||
type = "minio"
|
||||
enable = true
|
||||
endpoint = 'play.min.io'
|
||||
use_ssl = true
|
||||
access_key_id = 'Q3AM3UQ867SPQQA43P2F'
|
||||
secret_access_key = 'zuf+tfteSlswRu7BJ86wekitnifILbZam1KYY3TG'
|
||||
bucket_name = 'saveanybot'
|
||||
base_path = '/path/telegram'
|
||||
|
||||
# 用户列表
|
||||
[[users]]
|
||||
@@ -91,4 +101,4 @@ storages = ["本机1"]
|
||||
# cache_ttl = 30
|
||||
|
||||
# [db]
|
||||
# path = "data/data.db" # 数据库文件路径
|
||||
# path = "data/data.db" # 数据库文件路径
|
||||
|
||||
@@ -1,95 +0,0 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
|
||||
"github.com/krau/SaveAny-Bot/types"
|
||||
"gorm.io/datatypes"
|
||||
)
|
||||
|
||||
// for compatibility
|
||||
type deprecatedStorageConfig struct {
|
||||
Alist alistConfig `toml:"alist" mapstructure:"alist"`
|
||||
Local localConfig `toml:"local" mapstructure:"local"`
|
||||
Webdav webdavConfig `toml:"webdav" mapstructure:"webdav"`
|
||||
}
|
||||
|
||||
type alistConfig struct {
|
||||
Enable bool `toml:"enable" mapstructure:"enable" json:"enable"`
|
||||
URL string `toml:"url" mapstructure:"url" json:"url"`
|
||||
Username string `toml:"username" mapstructure:"username" json:"username"`
|
||||
Password string `toml:"password" mapstructure:"password" json:"password"`
|
||||
Token string `toml:"token" mapstructure:"token" json:"token"`
|
||||
BasePath string `toml:"base_path" mapstructure:"base_path" json:"base_path"`
|
||||
TokenExp int64 `toml:"token_exp" mapstructure:"token_exp" json:"token_exp"`
|
||||
}
|
||||
|
||||
func (a *alistConfig) ToJSON() datatypes.JSON {
|
||||
tokenExp := strconv.FormatInt(a.TokenExp, 10)
|
||||
return datatypes.JSON([]byte(`{"url":"` + a.URL + `","username":"` + a.Username + `","password":"` + a.Password + `","token":"` + a.Token + `","base_path":"` + a.BasePath + `","token_exp":` + tokenExp + `}`))
|
||||
}
|
||||
|
||||
type localConfig struct {
|
||||
Enable bool `toml:"enable" mapstructure:"enable" json:"enable"`
|
||||
BasePath string `toml:"base_path" mapstructure:"base_path" json:"base_path"`
|
||||
}
|
||||
|
||||
func (l *localConfig) ToJSON() datatypes.JSON {
|
||||
return datatypes.JSON([]byte(`{"base_path":"` + l.BasePath + `"}`))
|
||||
}
|
||||
|
||||
type webdavConfig struct {
|
||||
Enable bool `toml:"enable" mapstructure:"enable" json:"enable"`
|
||||
URL string `toml:"url" mapstructure:"url" json:"url"`
|
||||
Username string `toml:"username" mapstructure:"username" json:"username"`
|
||||
Password string `toml:"password" mapstructure:"password" json:"password"`
|
||||
BasePath string `toml:"base_path" mapstructure:"base_path" json:"base_path"`
|
||||
}
|
||||
|
||||
func (w *webdavConfig) ToJSON() datatypes.JSON {
|
||||
return datatypes.JSON([]byte(`{"url":"` + w.URL + `","username":"` + w.Username + `","password":"` + w.Password + `","base_path":"` + w.BasePath + `"}`))
|
||||
}
|
||||
|
||||
func transformDeprecatedStorageConfig() {
|
||||
if Cfg.DeprecatedStorage.Alist.Enable {
|
||||
alistStorage := &AlistStorageConfig{
|
||||
NewStorageConfig: NewStorageConfig{
|
||||
Name: "Alist",
|
||||
Enable: true,
|
||||
Type: string(types.StorageTypeAlist),
|
||||
},
|
||||
URL: Cfg.DeprecatedStorage.Alist.URL,
|
||||
Username: Cfg.DeprecatedStorage.Alist.Username,
|
||||
Password: Cfg.DeprecatedStorage.Alist.Password,
|
||||
Token: Cfg.DeprecatedStorage.Alist.Token,
|
||||
BasePath: Cfg.DeprecatedStorage.Alist.BasePath,
|
||||
TokenExp: Cfg.DeprecatedStorage.Alist.TokenExp,
|
||||
}
|
||||
Cfg.Storages = append(Cfg.Storages, alistStorage)
|
||||
}
|
||||
if Cfg.DeprecatedStorage.Local.Enable {
|
||||
localStorage := &LocalStorageConfig{
|
||||
NewStorageConfig: NewStorageConfig{
|
||||
Name: "Local",
|
||||
Enable: true,
|
||||
Type: string(types.StorageTypeLocal),
|
||||
},
|
||||
BasePath: Cfg.DeprecatedStorage.Local.BasePath,
|
||||
}
|
||||
Cfg.Storages = append(Cfg.Storages, localStorage)
|
||||
}
|
||||
if Cfg.DeprecatedStorage.Webdav.Enable {
|
||||
webdavStorage := &WebdavStorageConfig{
|
||||
NewStorageConfig: NewStorageConfig{
|
||||
Name: "Webdav",
|
||||
Enable: true,
|
||||
Type: string(types.StorageTypeWebdav),
|
||||
},
|
||||
URL: Cfg.DeprecatedStorage.Webdav.URL,
|
||||
Username: Cfg.DeprecatedStorage.Webdav.Username,
|
||||
Password: Cfg.DeprecatedStorage.Webdav.Password,
|
||||
BasePath: Cfg.DeprecatedStorage.Webdav.BasePath,
|
||||
}
|
||||
Cfg.Storages = append(Cfg.Storages, webdavStorage)
|
||||
}
|
||||
}
|
||||
38
config/storage/alist.go
Normal file
38
config/storage/alist.go
Normal file
@@ -0,0 +1,38 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/krau/SaveAny-Bot/types"
|
||||
)
|
||||
|
||||
type AlistStorageConfig struct {
|
||||
BaseConfig
|
||||
URL string `toml:"url" mapstructure:"url" json:"url"`
|
||||
Username string `toml:"username" mapstructure:"username" json:"username"`
|
||||
Password string `toml:"password" mapstructure:"password" json:"password"`
|
||||
Token string `toml:"token" mapstructure:"token" json:"token"`
|
||||
BasePath string `toml:"base_path" mapstructure:"base_path" json:"base_path"`
|
||||
TokenExp int64 `toml:"token_exp" mapstructure:"token_exp" json:"token_exp"`
|
||||
}
|
||||
|
||||
func (a *AlistStorageConfig) Validate() error {
|
||||
if a.URL == "" {
|
||||
return fmt.Errorf("url is required for alist storage")
|
||||
}
|
||||
if a.Token == "" && (a.Username == "" || a.Password == "") {
|
||||
return fmt.Errorf("username and password or token is required for alist storage")
|
||||
}
|
||||
if a.BasePath == "" {
|
||||
return fmt.Errorf("base_path is required for alist storage")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *AlistStorageConfig) GetType() types.StorageType {
|
||||
return types.StorageTypeAlist
|
||||
}
|
||||
|
||||
func (a *AlistStorageConfig) GetName() string {
|
||||
return a.Name
|
||||
}
|
||||
63
config/storage/factory.go
Normal file
63
config/storage/factory.go
Normal file
@@ -0,0 +1,63 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
|
||||
"github.com/krau/SaveAny-Bot/types"
|
||||
"github.com/mitchellh/mapstructure"
|
||||
"github.com/spf13/viper"
|
||||
)
|
||||
|
||||
var storageFactories = map[types.StorageType]func(cfg *BaseConfig) (StorageConfig, error){
|
||||
types.StorageTypeLocal: createStorageConfig(&LocalStorageConfig{}),
|
||||
types.StorageTypeAlist: createStorageConfig(&AlistStorageConfig{}),
|
||||
types.StorageTypeWebdav: createStorageConfig(&WebdavStorageConfig{}),
|
||||
types.StorageTypeMinio: createStorageConfig(&MinioStorageConfig{}),
|
||||
}
|
||||
|
||||
func createStorageConfig(configType StorageConfig) func(cfg *BaseConfig) (StorageConfig, error) {
|
||||
return func(cfg *BaseConfig) (StorageConfig, error) {
|
||||
configValue := reflect.New(reflect.TypeOf(configType).Elem()).Interface().(StorageConfig)
|
||||
|
||||
reflect.ValueOf(configValue).Elem().FieldByName("BaseConfig").Set(reflect.ValueOf(*cfg))
|
||||
|
||||
if err := mapstructure.Decode(cfg.RawConfig, configValue); err != nil {
|
||||
return nil, fmt.Errorf("failed to decode %s storage config: %w", cfg.Type, err)
|
||||
}
|
||||
|
||||
return configValue, nil
|
||||
}
|
||||
}
|
||||
|
||||
func LoadStorageConfigs(v *viper.Viper) ([]StorageConfig, error) {
|
||||
var baseConfigs []BaseConfig
|
||||
if err := v.UnmarshalKey("storages", &baseConfigs); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal storage configs: %w", err)
|
||||
}
|
||||
|
||||
var configs []StorageConfig
|
||||
for _, baseCfg := range baseConfigs {
|
||||
if !baseCfg.Enable {
|
||||
continue
|
||||
}
|
||||
|
||||
factory, ok := storageFactories[types.StorageType(baseCfg.Type)]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unsupported storage type: %s", baseCfg.Type)
|
||||
}
|
||||
|
||||
cfg, err := factory(&baseCfg)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create storage config for %s: %w", baseCfg.Name, err)
|
||||
}
|
||||
|
||||
if err := cfg.Validate(); err != nil {
|
||||
return nil, fmt.Errorf("invalid storage config for %s: %w", baseCfg.Name, err)
|
||||
}
|
||||
|
||||
configs = append(configs, cfg)
|
||||
}
|
||||
|
||||
return configs, nil
|
||||
}
|
||||
27
config/storage/local.go
Normal file
27
config/storage/local.go
Normal file
@@ -0,0 +1,27 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/krau/SaveAny-Bot/types"
|
||||
)
|
||||
|
||||
type LocalStorageConfig struct {
|
||||
BaseConfig
|
||||
BasePath string `toml:"base_path" mapstructure:"base_path" json:"base_path"`
|
||||
}
|
||||
|
||||
func (l *LocalStorageConfig) Validate() error {
|
||||
if l.BasePath == "" {
|
||||
return fmt.Errorf("path is required for local storage")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *LocalStorageConfig) GetType() types.StorageType {
|
||||
return types.StorageTypeLocal
|
||||
}
|
||||
|
||||
func (l *LocalStorageConfig) GetName() string {
|
||||
return l.Name
|
||||
}
|
||||
41
config/storage/minio.go
Normal file
41
config/storage/minio.go
Normal file
@@ -0,0 +1,41 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/krau/SaveAny-Bot/types"
|
||||
)
|
||||
|
||||
type MinioStorageConfig struct {
|
||||
BaseConfig
|
||||
Endpoint string `toml:"endpoint" mapstructure:"endpoint" json:"endpoint"`
|
||||
AccessKeyID string `toml:"access_key_id" mapstructure:"access_key_id" json:"access_key_id"`
|
||||
SecretAccessKey string `toml:"secret_access_key" mapstructure:"secret_access_key" json:"secret_access_key"`
|
||||
BucketName string `toml:"bucket_name" mapstructure:"bucket_name" json:"bucket_name"`
|
||||
UseSSL bool `toml:"use_ssl" mapstructure:"use_ssl" json:"use_ssl"`
|
||||
BasePath string `toml:"base_path" mapstructure:"base_path" json:"base_path"`
|
||||
}
|
||||
|
||||
func (m *MinioStorageConfig) Validate() error {
|
||||
if m.Endpoint == "" {
|
||||
return fmt.Errorf("endpoint is required for minio storage")
|
||||
}
|
||||
if m.AccessKeyID == "" || m.SecretAccessKey == "" {
|
||||
return fmt.Errorf("access_key_id and secret_access_key are required for minio storage")
|
||||
}
|
||||
if m.BucketName == "" {
|
||||
return fmt.Errorf("bucket_name is required for minio storage")
|
||||
}
|
||||
if m.BasePath == "" {
|
||||
return fmt.Errorf("base_path is required for minio storage")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MinioStorageConfig) GetType() types.StorageType {
|
||||
return types.StorageTypeMinio
|
||||
}
|
||||
|
||||
func (m *MinioStorageConfig) GetName() string {
|
||||
return m.Name
|
||||
}
|
||||
16
config/storage/types.go
Normal file
16
config/storage/types.go
Normal file
@@ -0,0 +1,16 @@
|
||||
package storage
|
||||
|
||||
import "github.com/krau/SaveAny-Bot/types"
|
||||
|
||||
type StorageConfig interface {
|
||||
Validate() error
|
||||
GetType() types.StorageType
|
||||
GetName() string
|
||||
}
|
||||
|
||||
type BaseConfig struct {
|
||||
Name string `toml:"name" mapstructure:"name" json:"name"`
|
||||
Type string `toml:"type" mapstructure:"type" json:"type"`
|
||||
Enable bool `toml:"enable" mapstructure:"enable" json:"enable"`
|
||||
RawConfig map[string]any `toml:"-" mapstructure:",remain"`
|
||||
}
|
||||
36
config/storage/webdav.go
Normal file
36
config/storage/webdav.go
Normal file
@@ -0,0 +1,36 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/krau/SaveAny-Bot/types"
|
||||
)
|
||||
|
||||
type WebdavStorageConfig struct {
|
||||
BaseConfig
|
||||
URL string `toml:"url" mapstructure:"url" json:"url"`
|
||||
Username string `toml:"username" mapstructure:"username" json:"username"`
|
||||
Password string `toml:"password" mapstructure:"password" json:"password"`
|
||||
BasePath string `toml:"base_path" mapstructure:"base_path" json:"base_path"`
|
||||
}
|
||||
|
||||
func (w *WebdavStorageConfig) Validate() error {
|
||||
if w.URL == "" {
|
||||
return fmt.Errorf("url is required for webdav storage")
|
||||
}
|
||||
if w.Username == "" || w.Password == "" {
|
||||
return fmt.Errorf("username and password is required for webdav storage")
|
||||
}
|
||||
if w.BasePath == "" {
|
||||
return fmt.Errorf("base_path is required for webdav storage")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *WebdavStorageConfig) GetType() types.StorageType {
|
||||
return types.StorageTypeWebdav
|
||||
}
|
||||
|
||||
func (w *WebdavStorageConfig) GetName() string {
|
||||
return w.Name
|
||||
}
|
||||
@@ -1,104 +0,0 @@
|
||||
// storage_config.go
|
||||
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/krau/SaveAny-Bot/types"
|
||||
"github.com/mitchellh/mapstructure"
|
||||
"github.com/spf13/viper"
|
||||
)
|
||||
|
||||
type StorageConfig interface {
|
||||
Validate() error
|
||||
GetType() types.StorageType
|
||||
GetName() string
|
||||
}
|
||||
|
||||
// Base storage config
|
||||
type NewStorageConfig struct {
|
||||
Name string `toml:"name" mapstructure:"name" json:"name"`
|
||||
Type string `toml:"type" mapstructure:"type" json:"type"`
|
||||
Enable bool `toml:"enable" mapstructure:"enable" json:"enable"`
|
||||
RawConfig map[string]interface{} `toml:"-" mapstructure:",remain"`
|
||||
}
|
||||
|
||||
type StorageConfigFactory func(cfg *NewStorageConfig) (StorageConfig, error)
|
||||
|
||||
var storageFactories = make(map[string]StorageConfigFactory)
|
||||
|
||||
func RegisterStorageFactory(storageType string, factory StorageConfigFactory) {
|
||||
storageFactories[storageType] = factory
|
||||
}
|
||||
|
||||
func init() {
|
||||
RegisterStorageFactory(string(types.StorageTypeLocal), newLocalStorageConfig)
|
||||
RegisterStorageFactory(string(types.StorageTypeAlist), newAlistStorageConfig)
|
||||
RegisterStorageFactory(string(types.StorageTypeWebdav), newWebdavStorageConfig)
|
||||
}
|
||||
|
||||
func newLocalStorageConfig(cfg *NewStorageConfig) (StorageConfig, error) {
|
||||
var localCfg LocalStorageConfig
|
||||
localCfg.NewStorageConfig = *cfg
|
||||
|
||||
if err := mapstructure.Decode(cfg.RawConfig, &localCfg); err != nil {
|
||||
return nil, fmt.Errorf("failed to decode local storage config: %w", err)
|
||||
}
|
||||
|
||||
return &localCfg, nil
|
||||
}
|
||||
|
||||
func newAlistStorageConfig(cfg *NewStorageConfig) (StorageConfig, error) {
|
||||
var alistCfg AlistStorageConfig
|
||||
alistCfg.NewStorageConfig = *cfg
|
||||
|
||||
if err := mapstructure.Decode(cfg.RawConfig, &alistCfg); err != nil {
|
||||
return nil, fmt.Errorf("failed to decode alist storage config: %w", err)
|
||||
}
|
||||
|
||||
return &alistCfg, nil
|
||||
}
|
||||
|
||||
func newWebdavStorageConfig(cfg *NewStorageConfig) (StorageConfig, error) {
|
||||
var webdavCfg WebdavStorageConfig
|
||||
webdavCfg.NewStorageConfig = *cfg
|
||||
|
||||
if err := mapstructure.Decode(cfg.RawConfig, &webdavCfg); err != nil {
|
||||
return nil, fmt.Errorf("failed to decode webdav storage config: %w", err)
|
||||
}
|
||||
|
||||
return &webdavCfg, nil
|
||||
}
|
||||
|
||||
func LoadStorageConfigs(v *viper.Viper) ([]StorageConfig, error) {
|
||||
var baseConfigs []NewStorageConfig
|
||||
if err := v.UnmarshalKey("storages", &baseConfigs); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal storage configs: %w", err)
|
||||
}
|
||||
|
||||
var configs []StorageConfig
|
||||
for _, baseCfg := range baseConfigs {
|
||||
if !baseCfg.Enable {
|
||||
continue
|
||||
}
|
||||
|
||||
factory, ok := storageFactories[baseCfg.Type]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unsupported storage type: %s", baseCfg.Type)
|
||||
}
|
||||
|
||||
cfg, err := factory(&baseCfg)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create storage config for %s: %w", baseCfg.Name, err)
|
||||
}
|
||||
|
||||
if err := cfg.Validate(); err != nil {
|
||||
return nil, fmt.Errorf("invalid storage config for %s: %w", baseCfg.Name, err)
|
||||
}
|
||||
|
||||
configs = append(configs, cfg)
|
||||
}
|
||||
|
||||
return configs, nil
|
||||
}
|
||||
@@ -1,106 +0,0 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/krau/SaveAny-Bot/types"
|
||||
)
|
||||
|
||||
func (c *Config) GetStoragesByType(storageType types.StorageType) []StorageConfig {
|
||||
var storages []StorageConfig
|
||||
for _, storage := range c.Storages {
|
||||
if storage.GetType() == storageType {
|
||||
storages = append(storages, storage)
|
||||
}
|
||||
}
|
||||
return storages
|
||||
}
|
||||
|
||||
func (c *Config) GetStorageByName(name string) StorageConfig {
|
||||
for _, storage := range c.Storages {
|
||||
if storage.GetName() == name {
|
||||
return storage
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type LocalStorageConfig struct {
|
||||
NewStorageConfig
|
||||
BasePath string `toml:"base_path" mapstructure:"base_path" json:"base_path"`
|
||||
}
|
||||
|
||||
func (l *LocalStorageConfig) Validate() error {
|
||||
if l.BasePath == "" {
|
||||
return fmt.Errorf("path is required for local storage")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *LocalStorageConfig) GetType() types.StorageType {
|
||||
return types.StorageTypeLocal
|
||||
}
|
||||
|
||||
func (l *LocalStorageConfig) GetName() string {
|
||||
return l.Name
|
||||
}
|
||||
|
||||
type AlistStorageConfig struct {
|
||||
NewStorageConfig
|
||||
URL string `toml:"url" mapstructure:"url" json:"url"`
|
||||
Username string `toml:"username" mapstructure:"username" json:"username"`
|
||||
Password string `toml:"password" mapstructure:"password" json:"password"`
|
||||
Token string `toml:"token" mapstructure:"token" json:"token"`
|
||||
BasePath string `toml:"base_path" mapstructure:"base_path" json:"base_path"`
|
||||
TokenExp int64 `toml:"token_exp" mapstructure:"token_exp" json:"token_exp"`
|
||||
}
|
||||
|
||||
func (a *AlistStorageConfig) Validate() error {
|
||||
if a.URL == "" {
|
||||
return fmt.Errorf("url is required for alist storage")
|
||||
}
|
||||
if a.Token == "" && (a.Username == "" || a.Password == "") {
|
||||
return fmt.Errorf("username and password or token is required for alist storage")
|
||||
}
|
||||
if a.BasePath == "" {
|
||||
return fmt.Errorf("base_path is required for alist storage")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *AlistStorageConfig) GetType() types.StorageType {
|
||||
return types.StorageTypeAlist
|
||||
}
|
||||
|
||||
func (a *AlistStorageConfig) GetName() string {
|
||||
return a.Name
|
||||
}
|
||||
|
||||
type WebdavStorageConfig struct {
|
||||
NewStorageConfig
|
||||
URL string `toml:"url" mapstructure:"url" json:"url"`
|
||||
Username string `toml:"username" mapstructure:"username" json:"username"`
|
||||
Password string `toml:"password" mapstructure:"password" json:"password"`
|
||||
BasePath string `toml:"base_path" mapstructure:"base_path" json:"base_path"`
|
||||
}
|
||||
|
||||
func (w *WebdavStorageConfig) Validate() error {
|
||||
if w.URL == "" {
|
||||
return fmt.Errorf("url is required for webdav storage")
|
||||
}
|
||||
if w.Username == "" || w.Password == "" {
|
||||
return fmt.Errorf("username and password is required for webdav storage")
|
||||
}
|
||||
if w.BasePath == "" {
|
||||
return fmt.Errorf("base_path is required for webdav storage")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *WebdavStorageConfig) GetType() types.StorageType {
|
||||
return types.StorageTypeWebdav
|
||||
}
|
||||
|
||||
func (w *WebdavStorageConfig) GetName() string {
|
||||
return w.Name
|
||||
}
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/krau/SaveAny-Bot/config/storage"
|
||||
"github.com/spf13/viper"
|
||||
)
|
||||
|
||||
@@ -17,13 +18,11 @@ type Config struct {
|
||||
|
||||
Users []userConfig `toml:"users" mapstructure:"users" json:"users"`
|
||||
|
||||
Temp tempConfig `toml:"temp" mapstructure:"temp"`
|
||||
Log logConfig `toml:"log" mapstructure:"log"`
|
||||
DB dbConfig `toml:"db" mapstructure:"db"`
|
||||
Telegram telegramConfig `toml:"telegram" mapstructure:"telegram"`
|
||||
Storages []StorageConfig `toml:"-" mapstructure:"-" json:"storages"`
|
||||
// Deprecated
|
||||
DeprecatedStorage deprecatedStorageConfig `toml:"storage" mapstructure:"storage"`
|
||||
Temp tempConfig `toml:"temp" mapstructure:"temp"`
|
||||
Log logConfig `toml:"log" mapstructure:"log"`
|
||||
DB dbConfig `toml:"db" mapstructure:"db"`
|
||||
Telegram telegramConfig `toml:"telegram" mapstructure:"telegram"`
|
||||
Storages []storage.StorageConfig `toml:"-" mapstructure:"-" json:"storages"`
|
||||
}
|
||||
|
||||
type tempConfig struct {
|
||||
@@ -58,6 +57,15 @@ type proxyConfig struct {
|
||||
|
||||
var Cfg *Config
|
||||
|
||||
func (c Config) GetStorageByName(name string) storage.StorageConfig {
|
||||
for _, storage := range c.Storages {
|
||||
if storage.GetName() == name {
|
||||
return storage
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func Init() error {
|
||||
viper.SetConfigName("config")
|
||||
viper.AddConfigPath(".")
|
||||
@@ -102,38 +110,12 @@ func Init() error {
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if Cfg.Telegram.Admins != nil {
|
||||
fmt.Println("警告: 你正在使用旧版 Telegram 管理员配置, 该配置下的用户将可用所有存储.\ntelegram.admins 未来版本将会被废弃, 请参考新的配置文件模板, 使用 users 配置替代.")
|
||||
for _, admin := range Cfg.Telegram.Admins {
|
||||
found := false
|
||||
for _, user := range Cfg.Users {
|
||||
if user.ID == admin {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if found {
|
||||
continue
|
||||
}
|
||||
Cfg.Users = append(Cfg.Users, userConfig{
|
||||
ID: admin,
|
||||
Storages: []string{},
|
||||
Blacklist: true,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
storagesConfig, err := LoadStorageConfigs(viper.GetViper())
|
||||
storagesConfig, err := storage.LoadStorageConfigs(viper.GetViper())
|
||||
if err != nil {
|
||||
return fmt.Errorf("error loading storage configs: %w", err)
|
||||
}
|
||||
Cfg.Storages = storagesConfig
|
||||
|
||||
if Cfg.DeprecatedStorage != (deprecatedStorageConfig{}) {
|
||||
fmt.Println("\n警告: 你正在使用旧版存储配置, 未来版本将会被废弃.\n请参考新的配置文件模板.")
|
||||
transformDeprecatedStorageConfig()
|
||||
}
|
||||
|
||||
storageNames := make(map[string]struct{})
|
||||
for _, storage := range Cfg.Storages {
|
||||
if _, ok := storageNames[storage.GetName()]; ok {
|
||||
|
||||
33
core/core.go
33
core/core.go
@@ -6,28 +6,35 @@ import (
|
||||
"fmt"
|
||||
|
||||
"github.com/celestix/gotgproto/ext"
|
||||
"github.com/gotd/td/telegram/downloader"
|
||||
"github.com/gotd/td/tg"
|
||||
"github.com/krau/SaveAny-Bot/common"
|
||||
"github.com/krau/SaveAny-Bot/config"
|
||||
"github.com/krau/SaveAny-Bot/logger"
|
||||
"github.com/krau/SaveAny-Bot/queue"
|
||||
"github.com/krau/SaveAny-Bot/types"
|
||||
)
|
||||
|
||||
var Downloader *downloader.Downloader
|
||||
|
||||
func init() {
|
||||
Downloader = downloader.NewDownloader().WithPartSize(1024 * 1024)
|
||||
}
|
||||
|
||||
func worker(queue *queue.TaskQueue, semaphore chan struct{}) {
|
||||
for {
|
||||
semaphore <- struct{}{}
|
||||
task := queue.GetTask()
|
||||
logger.L.Debugf("Got task: %s", task.String())
|
||||
common.Log.Debugf("Got task: %s", task.String())
|
||||
|
||||
switch task.Status {
|
||||
case types.Pending:
|
||||
logger.L.Infof("Processing task: %s", task.String())
|
||||
common.Log.Infof("Processing task: %s", task.String())
|
||||
if err := processPendingTask(task); err != nil {
|
||||
task.Error = err
|
||||
if errors.Is(err, context.Canceled) {
|
||||
task.Status = types.Canceled
|
||||
} else {
|
||||
logger.L.Errorf("Failed to do task: %s", err)
|
||||
common.Log.Errorf("Failed to do task: %s", err)
|
||||
task.Status = types.Failed
|
||||
}
|
||||
} else {
|
||||
@@ -35,10 +42,10 @@ func worker(queue *queue.TaskQueue, semaphore chan struct{}) {
|
||||
}
|
||||
queue.AddTask(task)
|
||||
case types.Succeeded:
|
||||
logger.L.Infof("Task succeeded: %s", task.String())
|
||||
common.Log.Infof("Task succeeded: %s", task.String())
|
||||
extCtx, ok := task.Ctx.(*ext.Context)
|
||||
if !ok {
|
||||
logger.L.Errorf("Context is not *ext.Context: %T", task.Ctx)
|
||||
common.Log.Errorf("Context is not *ext.Context: %T", task.Ctx)
|
||||
} else {
|
||||
extCtx.EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{
|
||||
Message: fmt.Sprintf("文件保存成功\n [%s]: %s", task.StorageName, task.StoragePath),
|
||||
@@ -46,10 +53,10 @@ func worker(queue *queue.TaskQueue, semaphore chan struct{}) {
|
||||
})
|
||||
}
|
||||
case types.Failed:
|
||||
logger.L.Errorf("Task failed: %s", task.String())
|
||||
common.Log.Errorf("Task failed: %s", task.String())
|
||||
extCtx, ok := task.Ctx.(*ext.Context)
|
||||
if !ok {
|
||||
logger.L.Errorf("Context is not *ext.Context: %T", task.Ctx)
|
||||
common.Log.Errorf("Context is not *ext.Context: %T", task.Ctx)
|
||||
} else {
|
||||
extCtx.EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{
|
||||
Message: "文件保存失败\n" + task.Error.Error(),
|
||||
@@ -57,10 +64,10 @@ func worker(queue *queue.TaskQueue, semaphore chan struct{}) {
|
||||
})
|
||||
}
|
||||
case types.Canceled:
|
||||
logger.L.Infof("Task canceled: %s", task.String())
|
||||
common.Log.Infof("Task canceled: %s", task.String())
|
||||
extCtx, ok := task.Ctx.(*ext.Context)
|
||||
if !ok {
|
||||
logger.L.Errorf("Context is not *ext.Context: %T", task.Ctx)
|
||||
common.Log.Errorf("Context is not *ext.Context: %T", task.Ctx)
|
||||
} else {
|
||||
extCtx.EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{
|
||||
Message: "任务已取消",
|
||||
@@ -68,16 +75,16 @@ func worker(queue *queue.TaskQueue, semaphore chan struct{}) {
|
||||
})
|
||||
}
|
||||
default:
|
||||
logger.L.Errorf("Unknown task status: %s", task.Status)
|
||||
common.Log.Errorf("Unknown task status: %s", task.Status)
|
||||
}
|
||||
<-semaphore
|
||||
logger.L.Debugf("Task done: %s; status: %s", task.String(), task.Status)
|
||||
common.Log.Debugf("Task done: %s; status: %s", task.String(), task.Status)
|
||||
queue.DoneTask(task)
|
||||
}
|
||||
}
|
||||
|
||||
func Run() {
|
||||
logger.L.Info("Start processing tasks...")
|
||||
common.Log.Info("Start processing tasks...")
|
||||
semaphore := make(chan struct{}, config.Cfg.Workers)
|
||||
for i := 0; i < config.Cfg.Workers; i++ {
|
||||
go worker(queue.Queue, semaphore)
|
||||
|
||||
227
core/download.go
227
core/download.go
@@ -2,36 +2,37 @@ package core
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/celestix/gotgproto/ext"
|
||||
"github.com/celestix/telegraph-go/v2"
|
||||
"github.com/duke-git/lancet/v2/fileutil"
|
||||
"github.com/gotd/td/telegram/message/entity"
|
||||
"github.com/gotd/td/telegram/message/styling"
|
||||
"github.com/gotd/td/tg"
|
||||
"github.com/krau/SaveAny-Bot/bot"
|
||||
"github.com/krau/SaveAny-Bot/common"
|
||||
"github.com/krau/SaveAny-Bot/config"
|
||||
"github.com/krau/SaveAny-Bot/logger"
|
||||
"github.com/krau/SaveAny-Bot/storage"
|
||||
"github.com/krau/SaveAny-Bot/types"
|
||||
"golang.org/x/sync/errgroup"
|
||||
)
|
||||
|
||||
func processPendingTask(task *types.Task) error {
|
||||
logger.L.Debugf("Start processing task: %s", task.String())
|
||||
common.Log.Debugf("Start processing task: %s", task.String())
|
||||
if task.FileName() == "" {
|
||||
task.File.FileName = fmt.Sprintf("%d_%d_%s", task.FileChatID, task.FileMessageID, task.File.Hash())
|
||||
}
|
||||
cacheDestPath := filepath.Join(config.Cfg.Temp.BasePath, task.FileName())
|
||||
cacheDestPath, err := filepath.Abs(cacheDestPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("处理路径失败: %w", err)
|
||||
}
|
||||
if err := fileutil.CreateDir(filepath.Dir(cacheDestPath)); err != nil {
|
||||
return fmt.Errorf("创建目录失败: %w", err)
|
||||
}
|
||||
|
||||
if task.StoragePath == "" {
|
||||
task.StoragePath = task.File.FileName
|
||||
task.StoragePath = task.FileName()
|
||||
}
|
||||
|
||||
taskStorage, err := storage.GetStorageByUserIDAndName(task.UserID, task.StorageName)
|
||||
@@ -40,10 +41,6 @@ func processPendingTask(task *types.Task) error {
|
||||
}
|
||||
task.StoragePath = taskStorage.JoinStoragePath(*task)
|
||||
|
||||
if task.File.FileSize == 0 {
|
||||
return processPhoto(task, taskStorage, cacheDestPath)
|
||||
}
|
||||
|
||||
ctx, ok := task.Ctx.(*ext.Context)
|
||||
if !ok {
|
||||
return fmt.Errorf("context is not *ext.Context: %T", task.Ctx)
|
||||
@@ -52,38 +49,69 @@ func processPendingTask(task *types.Task) error {
|
||||
cancelCtx, cancel := context.WithCancel(ctx)
|
||||
task.Cancel = cancel
|
||||
|
||||
downloadBuider := Downloader.Download(bot.Client.API(), task.File.Location).WithThreads(getTaskThreads(task.File.FileSize))
|
||||
if task.IsTelegraph {
|
||||
return processTelegraph(ctx, cancelCtx, task, taskStorage)
|
||||
}
|
||||
|
||||
taskStreamStorage, isStreamStorage := taskStorage.(storage.StreamStorage)
|
||||
if task.File.FileSize == 0 {
|
||||
return processPhoto(task, taskStorage)
|
||||
}
|
||||
|
||||
downloadBuilder := Downloader.Download(bot.Client.API(), task.File.Location).WithThreads(getTaskThreads(task.File.FileSize))
|
||||
|
||||
notsupportStreamStorage, notsupportStream := taskStorage.(storage.StorageNotSupportStream)
|
||||
cancelMarkUp := getCancelTaskMarkup(task)
|
||||
if config.Cfg.Stream {
|
||||
if !isStreamStorage {
|
||||
logger.L.Warnf("存储 %s 不支持流式上传", taskStorage.Name())
|
||||
} else {
|
||||
if !notsupportStream {
|
||||
text, entities := buildProgressMessageEntity(task, 0, task.StartTime, 0)
|
||||
ctx.EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{
|
||||
Message: text,
|
||||
Entities: entities,
|
||||
ID: task.ReplyMessageID,
|
||||
ReplyMarkup: getCancelTaskMarkup(task),
|
||||
ReplyMarkup: cancelMarkUp,
|
||||
})
|
||||
uploadStream, err := taskStreamStorage.NewUploadStream(cancelCtx, task.StoragePath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("创建上传流失败: %w", err)
|
||||
}
|
||||
defer uploadStream.Close()
|
||||
|
||||
pr, pw := io.Pipe()
|
||||
defer pr.Close()
|
||||
|
||||
task.StartTime = time.Now()
|
||||
progressCallback := buildProgressCallback(ctx, task, getProgressUpdateCount(task.File.FileSize))
|
||||
|
||||
progressStream := NewProgressStream(uploadStream, task.File.FileSize, progressCallback)
|
||||
progressStream := NewProgressStream(pw, task.File.FileSize, progressCallback)
|
||||
|
||||
_, err = downloadBuider.Stream(cancelCtx, progressStream)
|
||||
if err != nil {
|
||||
return fmt.Errorf("下载文件失败: %w", err)
|
||||
eg, uploadCtx := errgroup.WithContext(cancelCtx)
|
||||
|
||||
eg.Go(func() error {
|
||||
return taskStorage.Save(uploadCtx, pr, task.StoragePath)
|
||||
})
|
||||
eg.Go(func() error {
|
||||
_, err := downloadBuilder.Stream(uploadCtx, progressStream)
|
||||
if closeErr := pw.CloseWithError(err); closeErr != nil {
|
||||
common.Log.Errorf("Failed to close pipe writer: %v", closeErr)
|
||||
}
|
||||
return err
|
||||
})
|
||||
if err := eg.Wait(); err != nil {
|
||||
return err
|
||||
}
|
||||
logger.L.Infof("Uploaded file: %s", task.StoragePath)
|
||||
|
||||
return nil
|
||||
}
|
||||
common.Log.Warnf("存储 %s 不支持流式传输: %s", task.StorageName, notsupportStreamStorage.NotSupportStream())
|
||||
ctx.EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{
|
||||
Message: fmt.Sprintf("存储 %s 不支持流式传输: %s\n正在使用普通下载...", task.StorageName, notsupportStreamStorage.NotSupportStream()),
|
||||
ID: task.ReplyMessageID,
|
||||
ReplyMarkup: cancelMarkUp,
|
||||
})
|
||||
}
|
||||
|
||||
cacheDestPath := filepath.Join(config.Cfg.Temp.BasePath, task.FileName())
|
||||
cacheDestPath, err = filepath.Abs(cacheDestPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("处理路径失败: %w", err)
|
||||
}
|
||||
if err := fileutil.CreateDir(filepath.Dir(cacheDestPath)); err != nil {
|
||||
return fmt.Errorf("创建目录失败: %w", err)
|
||||
}
|
||||
|
||||
text, entities := buildProgressMessageEntity(task, 0, task.StartTime, 0)
|
||||
@@ -91,7 +119,7 @@ func processPendingTask(task *types.Task) error {
|
||||
Message: text,
|
||||
Entities: entities,
|
||||
ID: task.ReplyMessageID,
|
||||
ReplyMarkup: getCancelTaskMarkup(task),
|
||||
ReplyMarkup: cancelMarkUp,
|
||||
})
|
||||
|
||||
progressCallback := buildProgressCallback(ctx, task, getProgressUpdateCount(task.File.FileSize))
|
||||
@@ -101,7 +129,7 @@ func processPendingTask(task *types.Task) error {
|
||||
}
|
||||
defer dest.Close()
|
||||
task.StartTime = time.Now()
|
||||
_, err = downloadBuider.Parallel(cancelCtx, dest)
|
||||
_, err = downloadBuilder.Parallel(cancelCtx, dest)
|
||||
if err != nil {
|
||||
return fmt.Errorf("下载文件失败: %w", err)
|
||||
}
|
||||
@@ -109,11 +137,140 @@ func processPendingTask(task *types.Task) error {
|
||||
|
||||
fixTaskFileExt(task, cacheDestPath)
|
||||
|
||||
logger.L.Infof("Downloaded file: %s", cacheDestPath)
|
||||
common.Log.Infof("Downloaded file: %s", cacheDestPath)
|
||||
ctx.EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{
|
||||
Message: fmt.Sprintf("下载完成: %s\n正在转存文件...", task.FileName()),
|
||||
ID: task.ReplyMessageID,
|
||||
})
|
||||
|
||||
return saveFileWithRetry(cancelCtx, task, taskStorage, cacheDestPath)
|
||||
return saveFileWithRetry(cancelCtx, task.StoragePath, taskStorage, cacheDestPath)
|
||||
}
|
||||
|
||||
func processTelegraph(extCtx *ext.Context, cancelCtx context.Context, task *types.Task, taskStorage storage.Storage) error {
|
||||
if bot.TelegraphClient == nil {
|
||||
return fmt.Errorf("telegraph client is not initialized")
|
||||
}
|
||||
tgphUrl := task.TelegraphURL
|
||||
tgphPath := strings.Split(tgphUrl, "/")[len(strings.Split(tgphUrl, "/"))-1]
|
||||
if tgphUrl == "" || tgphPath == "" {
|
||||
return fmt.Errorf("invalid telegraph url")
|
||||
}
|
||||
entityBuilder := entity.Builder{}
|
||||
text := fmt.Sprintf("正在下载 Telegraph \n文件夹: %s\n保存路径: %s",
|
||||
task.FileName(),
|
||||
fmt.Sprintf("[%s]:%s", task.StorageName, task.StoragePath),
|
||||
)
|
||||
var entities []tg.MessageEntityClass
|
||||
if err := styling.Perform(&entityBuilder,
|
||||
styling.Plain("正在下载 Telegraph \n文件夹: "),
|
||||
styling.Code(task.FileName()),
|
||||
styling.Plain("\n保存路径: "),
|
||||
styling.Code(fmt.Sprintf("[%s]:%s", task.StorageName, task.StoragePath)),
|
||||
); err != nil {
|
||||
common.Log.Errorf("Failed to build entities: %s", err)
|
||||
}
|
||||
|
||||
extCtx.EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{
|
||||
Message: text,
|
||||
Entities: entities,
|
||||
ID: task.ReplyMessageID,
|
||||
ReplyMarkup: getCancelTaskMarkup(task),
|
||||
})
|
||||
|
||||
resultCh := make(chan error)
|
||||
go func() {
|
||||
page, err := bot.TelegraphClient.GetPage(tgphPath, true)
|
||||
if err != nil {
|
||||
resultCh <- fmt.Errorf("获取 telegraph 页面失败: %w", err)
|
||||
return
|
||||
}
|
||||
imgs := make([]string, 0)
|
||||
for _, element := range page.Content {
|
||||
var node telegraph.NodeElement
|
||||
data, err := json.Marshal(element)
|
||||
if err != nil {
|
||||
common.Log.Errorf("Failed to marshal element: %s", err)
|
||||
continue
|
||||
}
|
||||
err = json.Unmarshal(data, &node)
|
||||
if err != nil {
|
||||
common.Log.Errorf("Failed to unmarshal element: %s", err)
|
||||
continue
|
||||
}
|
||||
|
||||
if len(node.Children) != 0 {
|
||||
for _, child := range node.Children {
|
||||
imgs = append(imgs, getNodeImages(child)...)
|
||||
}
|
||||
}
|
||||
|
||||
if node.Tag == "img" {
|
||||
if src, ok := node.Attrs["src"]; ok {
|
||||
imgs = append(imgs, src)
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
if len(imgs) == 0 {
|
||||
resultCh <- fmt.Errorf("没有找到图片")
|
||||
return
|
||||
}
|
||||
hc := bot.TelegraphClient.HttpClient
|
||||
eg, ectx := errgroup.WithContext(cancelCtx)
|
||||
eg.SetLimit(config.Cfg.Workers) // TODO: use a new config field for this
|
||||
for i, img := range imgs {
|
||||
if strings.HasPrefix(img, "/file/") {
|
||||
img = "https://telegra.ph" + img
|
||||
}
|
||||
eg.Go(func() error {
|
||||
var lastErr error
|
||||
for attempt := range config.Cfg.Retry {
|
||||
if attempt > 0 {
|
||||
retryDelay := time.Duration(attempt*attempt) * time.Second
|
||||
select {
|
||||
case <-ectx.Done():
|
||||
return ectx.Err()
|
||||
case <-time.After(retryDelay):
|
||||
}
|
||||
common.Log.Debugf("Retrying to download image %s (attempt %d)", img, attempt+1)
|
||||
}
|
||||
req, err := http.NewRequestWithContext(ectx, http.MethodGet, img, nil)
|
||||
if err != nil {
|
||||
lastErr = fmt.Errorf("创建请求失败: %w", err)
|
||||
continue
|
||||
}
|
||||
resp, err := hc.Do(req)
|
||||
if err != nil {
|
||||
lastErr = fmt.Errorf("发送请求失败: %w", err)
|
||||
continue
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
lastErr = fmt.Errorf("请求图片失败: %s", resp.Status)
|
||||
continue
|
||||
}
|
||||
targetPath := path.Join(task.StoragePath, fmt.Sprintf("%d%s", i+1, path.Ext(img)))
|
||||
err = taskStorage.Save(ectx, resp.Body, targetPath)
|
||||
if err != nil {
|
||||
lastErr = fmt.Errorf("保存图片失败: %w", err)
|
||||
continue
|
||||
}
|
||||
common.Log.Infof("Saved image: %s", targetPath)
|
||||
return nil
|
||||
}
|
||||
return lastErr
|
||||
})
|
||||
}
|
||||
if err := eg.Wait(); err != nil {
|
||||
resultCh <- err
|
||||
return
|
||||
}
|
||||
resultCh <- nil
|
||||
}()
|
||||
select {
|
||||
case err := <-resultCh:
|
||||
return err
|
||||
case <-cancelCtx.Done():
|
||||
return cancelCtx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
80
core/download_test.go
Normal file
80
core/download_test.go
Normal file
@@ -0,0 +1,80 @@
|
||||
package core
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/celestix/telegraph-go/v2"
|
||||
)
|
||||
|
||||
func TestGetImgSrcs(t *testing.T) {
|
||||
complexStructure := telegraph.NodeElement{
|
||||
Tag: "div",
|
||||
Children: []telegraph.Node{
|
||||
telegraph.NodeElement{
|
||||
Tag: "figure",
|
||||
Children: []telegraph.Node{
|
||||
telegraph.NodeElement{
|
||||
Tag: "img",
|
||||
Attrs: map[string]string{
|
||||
"src": "https://example.com/image1.png",
|
||||
},
|
||||
},
|
||||
telegraph.NodeElement{
|
||||
Tag: "p",
|
||||
Children: []telegraph.Node{
|
||||
"A text node",
|
||||
},
|
||||
},
|
||||
telegraph.NodeElement{
|
||||
Tag: "figure",
|
||||
Children: []telegraph.Node{
|
||||
telegraph.NodeElement{
|
||||
Tag: "img",
|
||||
Attrs: map[string]string{
|
||||
"src": "https://example.com/image2.png",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
telegraph.NodeElement{
|
||||
Tag: "img",
|
||||
Attrs: map[string]string{
|
||||
"src": "https://example.com/image3.png",
|
||||
},
|
||||
},
|
||||
"text node",
|
||||
telegraph.NodeElement{
|
||||
Tag: "div",
|
||||
Children: []telegraph.Node{
|
||||
telegraph.NodeElement{
|
||||
Tag: "span",
|
||||
Children: []telegraph.Node{
|
||||
telegraph.NodeElement{
|
||||
Tag: "img",
|
||||
Attrs: map[string]string{
|
||||
"src": "https://example.com/image4.png",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
expected := []string{
|
||||
"https://example.com/image1.png",
|
||||
"https://example.com/image2.png",
|
||||
"https://example.com/image3.png",
|
||||
"https://example.com/image4.png",
|
||||
}
|
||||
|
||||
got := getNodeImages(complexStructure)
|
||||
|
||||
if !reflect.DeepEqual(expected, got) {
|
||||
t.Errorf("expected %v,got %v", expected, got)
|
||||
}
|
||||
}
|
||||
@@ -1,9 +0,0 @@
|
||||
package core
|
||||
|
||||
import "github.com/gotd/td/telegram/downloader"
|
||||
|
||||
var Downloader *downloader.Downloader
|
||||
|
||||
func init() {
|
||||
Downloader = downloader.NewDownloader().WithPartSize(1024 * 1024)
|
||||
}
|
||||
@@ -1,7 +1,9 @@
|
||||
package core
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
@@ -9,6 +11,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/celestix/gotgproto/ext"
|
||||
"github.com/celestix/telegraph-go/v2"
|
||||
"github.com/gabriel-vasile/mimetype"
|
||||
"github.com/gotd/td/telegram/message/entity"
|
||||
"github.com/gotd/td/telegram/message/styling"
|
||||
@@ -16,24 +19,38 @@ import (
|
||||
"github.com/krau/SaveAny-Bot/bot"
|
||||
"github.com/krau/SaveAny-Bot/common"
|
||||
"github.com/krau/SaveAny-Bot/config"
|
||||
"github.com/krau/SaveAny-Bot/logger"
|
||||
"github.com/krau/SaveAny-Bot/storage"
|
||||
"github.com/krau/SaveAny-Bot/types"
|
||||
)
|
||||
|
||||
func saveFileWithRetry(ctx context.Context, task *types.Task, taskStorage storage.Storage, localFilePath string) error {
|
||||
func saveFileWithRetry(ctx context.Context, storagePath string, taskStorage storage.Storage, cacheFilePath string) error {
|
||||
file, err := os.Open(cacheFilePath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to open cache file: %w", err)
|
||||
}
|
||||
defer file.Close()
|
||||
fileStat, err := file.Stat()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get file stat: %w", err)
|
||||
}
|
||||
vctx := context.WithValue(ctx, types.ContextKeyContentLength, fileStat.Size())
|
||||
for i := 0; i <= config.Cfg.Retry; i++ {
|
||||
if err := ctx.Err(); err != nil {
|
||||
if err := vctx.Err(); err != nil {
|
||||
return fmt.Errorf("context canceled while saving file: %w", err)
|
||||
}
|
||||
if err := taskStorage.Save(ctx, localFilePath, task.StoragePath); err != nil {
|
||||
file, err := os.Open(cacheFilePath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to open cache file: %w", err)
|
||||
}
|
||||
defer file.Close()
|
||||
if err := taskStorage.Save(vctx, file, storagePath); err != nil {
|
||||
if i == config.Cfg.Retry {
|
||||
return fmt.Errorf("failed to save file: %w", err)
|
||||
}
|
||||
logger.L.Errorf("Failed to save file: %s, retrying...", err)
|
||||
common.Log.Errorf("Failed to save file: %s, retrying...", err)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return fmt.Errorf("context canceled during retry delay: %w", ctx.Err())
|
||||
case <-vctx.Done():
|
||||
return fmt.Errorf("context canceled during retry delay: %w", vctx.Err())
|
||||
case <-time.After(time.Duration(i*500) * time.Millisecond):
|
||||
}
|
||||
continue
|
||||
@@ -43,7 +60,7 @@ func saveFileWithRetry(ctx context.Context, task *types.Task, taskStorage storag
|
||||
return nil
|
||||
}
|
||||
|
||||
func processPhoto(task *types.Task, taskStorage storage.Storage, cachePath string) error {
|
||||
func processPhoto(task *types.Task, taskStorage storage.Storage) error {
|
||||
res, err := bot.Client.API().UploadGetFile(task.Ctx, &tg.UploadGetFileRequest{
|
||||
Location: task.File.Location,
|
||||
Offset: 0,
|
||||
@@ -58,15 +75,9 @@ func processPhoto(task *types.Task, taskStorage storage.Storage, cachePath strin
|
||||
return fmt.Errorf("unexpected type %T", res)
|
||||
}
|
||||
|
||||
if err := os.WriteFile(cachePath, result.Bytes, os.ModePerm); err != nil {
|
||||
return fmt.Errorf("failed to write file: %w", err)
|
||||
}
|
||||
common.Log.Infof("Downloaded photo: %s", task.FileName())
|
||||
|
||||
defer cleanCacheFile(cachePath)
|
||||
|
||||
logger.L.Infof("Downloaded file: %s", cachePath)
|
||||
|
||||
return saveFileWithRetry(task.Ctx, task, taskStorage, cachePath)
|
||||
return taskStorage.Save(task.Ctx, bytes.NewReader(result.Bytes), task.StoragePath)
|
||||
}
|
||||
|
||||
func cleanCacheFile(destPath string) {
|
||||
@@ -74,7 +85,7 @@ func cleanCacheFile(destPath string) {
|
||||
common.RmFileAfter(destPath, time.Duration(config.Cfg.Temp.CacheTTL)*time.Second)
|
||||
} else {
|
||||
if err := os.Remove(destPath); err != nil {
|
||||
logger.L.Errorf("Failed to purge file: %s", err)
|
||||
common.Log.Errorf("Failed to purge file: %s", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -120,7 +131,7 @@ func buildProgressMessageEntity(task *types.Task, bytesRead int64, startTime tim
|
||||
styling.Plain("\n当前进度: "),
|
||||
styling.Bold(fmt.Sprintf("%.2f%%", progress)),
|
||||
); err != nil {
|
||||
logger.L.Errorf("Failed to build entities: %s", err)
|
||||
common.Log.Errorf("Failed to build entities: %s", err)
|
||||
return text, entities
|
||||
}
|
||||
return entityBuilder.Complete()
|
||||
@@ -129,7 +140,7 @@ func buildProgressMessageEntity(task *types.Task, bytesRead int64, startTime tim
|
||||
func buildProgressCallback(ctx *ext.Context, task *types.Task, updateCount int) func(bytesRead, contentLength int64) {
|
||||
return func(bytesRead, contentLength int64) {
|
||||
progress := float64(bytesRead) / float64(contentLength) * 100
|
||||
logger.L.Tracef("Downloading %s: %.2f%%", task.String(), progress)
|
||||
common.Log.Tracef("Downloading %s: %.2f%%", task.String(), progress)
|
||||
progressInt := int(progress)
|
||||
if task.File.FileSize < 1024*1024*50 || progressInt == 0 || progressInt%int(100/updateCount) != 0 {
|
||||
return
|
||||
@@ -154,7 +165,7 @@ func fixTaskFileExt(task *types.Task, localFilePath string) {
|
||||
if path.Ext(task.FileName()) == "" {
|
||||
mimeType, err := mimetype.DetectFile(localFilePath)
|
||||
if err != nil {
|
||||
logger.L.Errorf("Failed to detect mime type: %s", err)
|
||||
common.Log.Errorf("Failed to detect mime type: %s", err)
|
||||
} else {
|
||||
task.File.FileName = fmt.Sprintf("%s%s", task.FileName(), mimeType.Extension())
|
||||
task.StoragePath = fmt.Sprintf("%s%s", task.StoragePath, mimeType.Extension())
|
||||
@@ -258,3 +269,27 @@ func NewProgressStream(writer io.Writer, size int64, callback func(bytesRead, co
|
||||
interval: interval,
|
||||
}
|
||||
}
|
||||
|
||||
func getNodeImages(node telegraph.Node) []string {
|
||||
var srcs []string
|
||||
|
||||
var nodeElement telegraph.NodeElement
|
||||
data, err := json.Marshal(node)
|
||||
if err != nil {
|
||||
return srcs
|
||||
}
|
||||
err = json.Unmarshal(data, &nodeElement)
|
||||
if err != nil {
|
||||
return srcs
|
||||
}
|
||||
|
||||
if nodeElement.Tag == "img" {
|
||||
if src, exists := nodeElement.Attrs["src"]; exists {
|
||||
srcs = append(srcs, src)
|
||||
}
|
||||
}
|
||||
for _, child := range nodeElement.Children {
|
||||
srcs = append(srcs, getNodeImages(child)...)
|
||||
}
|
||||
return srcs
|
||||
}
|
||||
|
||||
18
dao/db.go
18
dao/db.go
@@ -7,8 +7,8 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/glebarez/sqlite"
|
||||
"github.com/krau/SaveAny-Bot/common"
|
||||
"github.com/krau/SaveAny-Bot/config"
|
||||
"github.com/krau/SaveAny-Bot/logger"
|
||||
"gorm.io/gorm"
|
||||
glogger "gorm.io/gorm/logger"
|
||||
)
|
||||
@@ -17,12 +17,12 @@ var db *gorm.DB
|
||||
|
||||
func Init() {
|
||||
if err := os.MkdirAll(filepath.Dir(config.Cfg.DB.Path), 0755); err != nil {
|
||||
logger.L.Fatal("Failed to create data directory: ", err)
|
||||
common.Log.Fatal("Failed to create data directory: ", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
var err error
|
||||
db, err = gorm.Open(sqlite.Open(config.Cfg.DB.Path), &gorm.Config{
|
||||
Logger: glogger.New(logger.L, glogger.Config{
|
||||
Logger: glogger.New(common.Log, glogger.Config{
|
||||
Colorful: true,
|
||||
SlowThreshold: time.Second * 5,
|
||||
LogLevel: glogger.Error,
|
||||
@@ -32,16 +32,16 @@ func Init() {
|
||||
PrepareStmt: true,
|
||||
})
|
||||
if err != nil {
|
||||
logger.L.Fatal("Failed to open database: ", err)
|
||||
common.Log.Fatal("Failed to open database: ", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
logger.L.Debug("Database connected")
|
||||
common.Log.Debug("Database connected")
|
||||
if err := db.AutoMigrate(&ReceivedFile{}, &User{}, &Dir{}, &CallbackData{}); err != nil {
|
||||
logger.L.Fatal("迁移数据库失败, 如果您从旧版本升级, 建议手动删除数据库文件后重试: ", err)
|
||||
common.Log.Fatal("迁移数据库失败, 如果您从旧版本升级, 建议手动删除数据库文件后重试: ", err)
|
||||
}
|
||||
|
||||
if err := syncUsers(); err != nil {
|
||||
logger.L.Fatal("Failed to sync users:", err)
|
||||
common.Log.Fatal("Failed to sync users:", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -66,7 +66,7 @@ func syncUsers() error {
|
||||
if err := CreateUser(cfgID); err != nil {
|
||||
return fmt.Errorf("failed to create user %d: %w", cfgID, err)
|
||||
}
|
||||
logger.L.Infof("创建用户: %d", cfgID)
|
||||
common.Log.Infof("创建用户: %d", cfgID)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -75,7 +75,7 @@ func syncUsers() error {
|
||||
if err := DeleteUser(&dbUser); err != nil {
|
||||
return fmt.Errorf("failed to delete user %d: %w", dbID, err)
|
||||
}
|
||||
logger.L.Infof("删除用户: %d", dbID)
|
||||
common.Log.Infof("删除用户: %d", dbID)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -14,6 +14,8 @@ type ReceivedFile struct {
|
||||
ReplyMessageID int
|
||||
ReplyChatID int64
|
||||
FileName string
|
||||
IsTelegraph bool
|
||||
TelegraphURL string
|
||||
}
|
||||
|
||||
type User struct {
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
|
||||
在 [Release](https://github.com/krau/SaveAny-Bot/releases) 页面下载对应平台的二进制文件.
|
||||
|
||||
在解压后目录新建 `config.toml` 文件, 参考 [config.example.toml](./config.example.toml) 编辑配置文件.
|
||||
在解压后目录新建 `config.toml` 文件, 参考 [config.example.toml](https://github.com/krau/SaveAny-Bot/blob/main/config.example.toml) 编辑配置文件.
|
||||
|
||||
运行:
|
||||
|
||||
@@ -40,7 +40,7 @@ systemctl enable --now saveany-bot
|
||||
|
||||
### 为OpenWrt及衍生系统添加开机自启动服务
|
||||
|
||||
创建文件 ` /etc/init.d/saveanybot` ,参考[saveanybot](./docs/saveanybot)自行修改.
|
||||
创建文件 ` /etc/init.d/saveanybot` ,参考[saveanybot](https://github.com/krau/SaveAny-Bot/blob/main/docs/saveanybot)自行修改.
|
||||
|
||||
`chmod +x /etc/init.d/saveanybot`
|
||||
|
||||
@@ -50,7 +50,7 @@ systemctl enable --now saveany-bot
|
||||
|
||||
### 为OpenWrt及衍生系统添加快捷指令
|
||||
|
||||
创建文件` /usr/bin/sabot` ,参考[sabot](./docs/sabot)自行配置修改,注意此处文件编码仅支持 ANSI 936 .
|
||||
创建文件` /usr/bin/sabot` ,参考[sabot](https://github.com/krau/SaveAny-Bot/blob/main/docs/sabot)自行配置修改,注意此处文件编码仅支持 ANSI 936 .
|
||||
|
||||
`chmod +x /usr/bin/sabot`
|
||||
|
||||
@@ -61,7 +61,7 @@ systemctl enable --now saveany-bot
|
||||
|
||||
### Docker Compose
|
||||
|
||||
下载 [docker-compose.yml](./docker-compose.yml) 文件, 在同目录下新建 `config.toml` 文件, 参考 [config.example.toml](./config.example.toml) 编辑配置文件.
|
||||
下载 [docker-compose.yml](https://github.com/krau/SaveAny-Bot/blob/main/docker-compose.yml) 文件, 在同目录下新建 `config.toml` 文件, 参考 [config.example.toml](https://github.com/krau/SaveAny-Bot/blob/main/config.example.toml) 编辑配置文件.
|
||||
|
||||
启动:
|
||||
|
||||
|
||||
@@ -2,15 +2,11 @@
|
||||
|
||||
## 上传 alist 失败也会显示成功
|
||||
|
||||
这是 alist 的上传实现导致的问题, 上传到 alist 的文件实际上会被 alist 暂存在本地, 在客户端上传结束后 alist 就返回成功, 然后 alist 会在后台将文件上传到对应的存储.
|
||||
|
||||
目前 bot 是根据 alist 的返回判断是否成功, 无法获知 alist 的后台上传任务是否成功.
|
||||
|
||||
在 alist 管理页面适当调整上传分片大小, 为 alist 使用更稳定的网络环境部署, 都可以减少这种情况的发生.
|
||||
|
||||
## Bot 提示下载成功但是 alist 未显示
|
||||
|
||||
检查 alist 后台 > 任务 > 上传 中对应的上传任务的状态, 如果任务状态为成功但目录中不显示, 是由于 alist 缓存了目录结构, 参考文档可以调整缓存时间
|
||||
alist 缓存了目录结构, 参考文档可以调整缓存时间
|
||||
|
||||
https://alist.nn.ci/zh/guide/drivers/common.html#缓存过期
|
||||
|
||||
|
||||
@@ -4,10 +4,14 @@
|
||||
|
||||
Bot 接受两种消息: 文件和链接.
|
||||
|
||||
目前, 链接仅支持公开频道 (具有用户名) 的链接, 例如: `https://t.me/acherkrau/1097`.
|
||||
支持以下链接:
|
||||
|
||||
1. 公开频道 (具有用户名) 的消息链接, 例如: `https://t.me/acherkrau/1097`.
|
||||
|
||||
**即使频道禁止了转发和保存, Bot 依然可以下载其文件.**
|
||||
|
||||
2. Telegra.ph 的文章链接, Bot 将下载其中的所有图片
|
||||
|
||||
## 静默模式 (silent)
|
||||
|
||||
使用 `/silent` 命令可以开关静默模式.
|
||||
@@ -32,4 +36,6 @@ Bot 接受两种消息: 文件和链接.
|
||||
- 网络不稳定时, 任务失败率高.
|
||||
- 无法在中间层对文件进行处理, 例如自动文件类型识别.
|
||||
|
||||
虽然目前 Bot 适配的所有存储端 (Alist, 本地磁盘, Webdav) 都支持 Stream 模式, 但今后可能会有不支持的存储端, 此时即使开启 Stream 模式, Bot 也会自动切换到普通模式.
|
||||
**不支持** Stream 模式的存储端:
|
||||
|
||||
- alist
|
||||
|
||||
21
go.mod
21
go.mod
@@ -5,19 +5,20 @@ go 1.23.5
|
||||
require (
|
||||
github.com/blang/semver v3.5.1+incompatible
|
||||
github.com/celestix/gotgproto v1.0.0-beta20.2
|
||||
github.com/celestix/telegraph-go/v2 v2.0.4
|
||||
github.com/gabriel-vasile/mimetype v1.4.8
|
||||
github.com/gookit/slog v0.5.7
|
||||
github.com/gotd/contrib v0.21.0
|
||||
github.com/gotd/td v0.120.0
|
||||
github.com/minio/minio-go/v7 v7.0.81
|
||||
github.com/rhysd/go-github-selfupdate v1.2.3
|
||||
github.com/spf13/cobra v1.8.1
|
||||
github.com/spf13/viper v1.19.0
|
||||
golang.org/x/net v0.35.0
|
||||
golang.org/x/net v0.37.0
|
||||
golang.org/x/time v0.10.0
|
||||
)
|
||||
|
||||
require (
|
||||
filippo.io/edwards25519 v1.1.0 // indirect
|
||||
github.com/AnimeKaizoku/cacher v1.0.2 // indirect
|
||||
github.com/cenkalti/backoff/v4 v4.3.0 // indirect
|
||||
github.com/cespare/xxhash/v2 v2.3.0 // indirect
|
||||
@@ -31,7 +32,8 @@ require (
|
||||
github.com/go-faster/jx v1.1.0 // indirect
|
||||
github.com/go-faster/xor v1.0.0 // indirect
|
||||
github.com/go-faster/yaml v0.4.6 // indirect
|
||||
github.com/go-sql-driver/mysql v1.8.1 // indirect
|
||||
github.com/go-ini/ini v1.67.0 // indirect
|
||||
github.com/goccy/go-json v0.10.3 // indirect
|
||||
github.com/google/go-github/v30 v30.1.0 // indirect
|
||||
github.com/google/go-querystring v1.1.0 // indirect
|
||||
github.com/google/pprof v0.0.0-20250128161936-077ca0a936bf // indirect
|
||||
@@ -41,13 +43,16 @@ require (
|
||||
github.com/inconshreveable/go-update v0.0.0-20160112193335-8152e7eb6ccf // indirect
|
||||
github.com/jinzhu/inflection v1.0.0 // indirect
|
||||
github.com/jinzhu/now v1.1.5 // indirect
|
||||
github.com/klauspost/cpuid/v2 v2.2.8 // indirect
|
||||
github.com/mattn/go-colorable v0.1.14 // indirect
|
||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||
github.com/minio/md5-simd v1.1.2 // indirect
|
||||
github.com/ncruces/go-strftime v0.1.9 // indirect
|
||||
github.com/ogen-go/ogen v1.10.0 // indirect
|
||||
github.com/onsi/gomega v1.36.2 // indirect
|
||||
github.com/pkg/errors v0.9.1 // indirect
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
|
||||
github.com/rs/xid v1.6.0 // indirect
|
||||
github.com/segmentio/asm v1.2.0 // indirect
|
||||
github.com/tcnksm/go-gitconfig v0.1.2 // indirect
|
||||
github.com/ulikunitz/xz v0.5.12 // indirect
|
||||
@@ -56,12 +61,11 @@ require (
|
||||
go.opentelemetry.io/otel/trace v1.34.0 // indirect
|
||||
go.uber.org/atomic v1.11.0 // indirect
|
||||
go.uber.org/zap v1.27.0 // indirect
|
||||
golang.org/x/crypto v0.33.0 // indirect
|
||||
golang.org/x/crypto v0.36.0 // indirect
|
||||
golang.org/x/mod v0.23.0 // indirect
|
||||
golang.org/x/oauth2 v0.26.0 // indirect
|
||||
golang.org/x/tools v0.30.0 // indirect
|
||||
gopkg.in/yaml.v2 v2.4.0 // indirect
|
||||
gorm.io/driver/mysql v1.5.6 // indirect
|
||||
modernc.org/libc v1.61.13 // indirect
|
||||
modernc.org/mathutil v1.7.1 // indirect
|
||||
modernc.org/memory v1.8.2 // indirect
|
||||
@@ -94,11 +98,10 @@ require (
|
||||
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect
|
||||
go.uber.org/multierr v1.11.0 // indirect
|
||||
golang.org/x/exp v0.0.0-20250210185358-939b2ce775ac // indirect
|
||||
golang.org/x/sync v0.11.0 // indirect
|
||||
golang.org/x/sys v0.30.0 // indirect
|
||||
golang.org/x/text v0.22.0 // indirect
|
||||
golang.org/x/sync v0.12.0
|
||||
golang.org/x/sys v0.31.0 // indirect
|
||||
golang.org/x/text v0.23.0 // indirect
|
||||
gopkg.in/ini.v1 v1.67.0 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
gorm.io/datatypes v1.2.5
|
||||
gorm.io/gorm v1.25.12
|
||||
)
|
||||
|
||||
72
go.sum
72
go.sum
@@ -1,11 +1,11 @@
|
||||
filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA=
|
||||
filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4=
|
||||
github.com/AnimeKaizoku/cacher v1.0.2 h1:7Bf5qRylWb7q2Evib0OXlhG37/t7BP2HK/7IyPvSmGQ=
|
||||
github.com/AnimeKaizoku/cacher v1.0.2/go.mod h1:jw0de/b0K6W7Y3T9rHCMGVKUf6oG7hENNcssxYcZTCc=
|
||||
github.com/blang/semver v3.5.1+incompatible h1:cQNTCjp13qL8KC3Nbxr/y2Bqb63oX6wdnnjpJbkM4JQ=
|
||||
github.com/blang/semver v3.5.1+incompatible/go.mod h1:kRBLl5iJ+tD4TcOOxsy/0fnwebNt5EWlYSAyrTnjyyk=
|
||||
github.com/celestix/gotgproto v1.0.0-beta20.2 h1:+WcsKdsyj4xy+TAV+4Sw6zp1xiQrIr4dMnM31+k8NYM=
|
||||
github.com/celestix/gotgproto v1.0.0-beta20.2/go.mod h1:j42ZhBMUke6QyBLvCgx8tA+TL9L3+pq/Q46B+b5+3aU=
|
||||
github.com/celestix/telegraph-go/v2 v2.0.4 h1:w8HWymJFhMSMPjdGoyTh3/NqE3eXAT1njTvelh0338k=
|
||||
github.com/celestix/telegraph-go/v2 v2.0.4/go.mod h1:vu2LtqM7MgOAJ2LDF8XK27DWdd1QYLBfZGhalEh086Y=
|
||||
github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK343L8=
|
||||
github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE=
|
||||
github.com/cespare/xxhash/v2 v2.1.2/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
|
||||
@@ -49,17 +49,14 @@ github.com/go-faster/xor v1.0.0 h1:2o8vTOgErSGHP3/7XwA5ib1FTtUsNtwCoLLBjl31X38=
|
||||
github.com/go-faster/xor v1.0.0/go.mod h1:x5CaDY9UKErKzqfRfFZdfu+OSTfoZny3w5Ak7UxcipQ=
|
||||
github.com/go-faster/yaml v0.4.6 h1:lOK/EhI04gCpPgPhgt0bChS6bvw7G3WwI8xxVe0sw9I=
|
||||
github.com/go-faster/yaml v0.4.6/go.mod h1:390dRIvV4zbnO7qC9FGo6YYutc+wyyUSHBgbXL52eXk=
|
||||
github.com/go-ini/ini v1.67.0 h1:z6ZrTEZqSWOTyH2FlglNbNgARyHG8oLW9gMELqKr06A=
|
||||
github.com/go-ini/ini v1.67.0/go.mod h1:ByCAeIL28uOIIG0E3PJtZPDL8WnHpFKFOtgjp+3Ies8=
|
||||
github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY=
|
||||
github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
|
||||
github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
|
||||
github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
|
||||
github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI=
|
||||
github.com/go-sql-driver/mysql v1.8.1 h1:LedoTUt/eveggdHS9qUFC1EFSa8bU2+1pZjSRpvNJ1Y=
|
||||
github.com/go-sql-driver/mysql v1.8.1/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqwuVSjsCm7DZg=
|
||||
github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 h1:au07oEsX2xN0ktxqI+Sida1w446QrXBRJ0nee3SNZlA=
|
||||
github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9/go.mod h1:8vg3r2VgvsThLBIFL93Qb5yWzgyZWhEmBwUJWevAkK0=
|
||||
github.com/golang-sql/sqlexp v0.1.0 h1:ZCD6MBpcuOVfGVqsEmY5/4FtYiKz6tSyUv9LPEDei6A=
|
||||
github.com/golang-sql/sqlexp v0.1.0/go.mod h1:J4ad9Vo8ZCWQ2GMrC4UCQy1JpCbwU9m3EOqtpKwwwHI=
|
||||
github.com/goccy/go-json v0.10.3 h1:KZ5WoDbxAIgm2HNbYckL0se1fHD6rz5j4ywS6ebzDqA=
|
||||
github.com/goccy/go-json v0.10.3/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M=
|
||||
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
|
||||
github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
|
||||
github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||
@@ -97,20 +94,15 @@ github.com/inconshreveable/go-update v0.0.0-20160112193335-8152e7eb6ccf h1:WfD7V
|
||||
github.com/inconshreveable/go-update v0.0.0-20160112193335-8152e7eb6ccf/go.mod h1:hyb9oH7vZsitZCiBt0ZvifOrB+qc8PS5IiilCIb87rg=
|
||||
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
|
||||
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
|
||||
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
|
||||
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
|
||||
github.com/jackc/pgservicefile v0.0.0-20231201235250-de7065d80cb9 h1:L0QtFUgDarD7Fpv9jeVMgy/+Ec0mtnmYuImjTz6dtDA=
|
||||
github.com/jackc/pgservicefile v0.0.0-20231201235250-de7065d80cb9/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
|
||||
github.com/jackc/pgx/v5 v5.5.5 h1:amBjrZVmksIdNjxGW/IiIMzxMKZFelXbUoPNb+8sjQw=
|
||||
github.com/jackc/pgx/v5 v5.5.5/go.mod h1:ez9gk+OAat140fv9ErkZDYFWmXLfV+++K0uAOiwgm1A=
|
||||
github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk=
|
||||
github.com/jackc/puddle/v2 v2.2.1/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
|
||||
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
|
||||
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
|
||||
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
|
||||
github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
|
||||
github.com/klauspost/compress v1.17.11 h1:In6xLpyWOi1+C7tXUUWv2ot1QvBjxevKAaI6IXrJmUc=
|
||||
github.com/klauspost/compress v1.17.11/go.mod h1:pMDklpSncoRMuLFrf1W9Ss9KT+0rH90U12bZKk7uwG0=
|
||||
github.com/klauspost/cpuid/v2 v2.0.1/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
|
||||
github.com/klauspost/cpuid/v2 v2.2.8 h1:+StwCXwm9PdpiEkPyzBXIy+M9KUb4ODm0Zarf1kS5BM=
|
||||
github.com/klauspost/cpuid/v2 v2.2.8/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws=
|
||||
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
|
||||
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
|
||||
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
|
||||
@@ -124,10 +116,10 @@ github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHP
|
||||
github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8=
|
||||
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/mattn/go-sqlite3 v1.14.17 h1:mCRHCLDUBXgpKAqIKsaAaAsrAlbkeomtRFKXh2L6YIM=
|
||||
github.com/mattn/go-sqlite3 v1.14.17/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
|
||||
github.com/microsoft/go-mssqldb v1.7.2 h1:CHkFJiObW7ItKTJfHo1QX7QBBD1iV+mn1eOyRP3b/PA=
|
||||
github.com/microsoft/go-mssqldb v1.7.2/go.mod h1:kOvZKUdrhhFQmxLZqbwUV0rHkNkZpthMITIb2Ko1IoA=
|
||||
github.com/minio/md5-simd v1.1.2 h1:Gdi1DZK69+ZVMoNHRXJyNcxrMA4dSxoYHZSQbirFg34=
|
||||
github.com/minio/md5-simd v1.1.2/go.mod h1:MzdKDxYpY2BT9XQFocsiZf/NKVtR7nkE4RoEpN+20RM=
|
||||
github.com/minio/minio-go/v7 v7.0.81 h1:SzhMN0TQ6T/xSBu6Nvw3M5M8voM+Ht8RH3hE8S7zxaA=
|
||||
github.com/minio/minio-go/v7 v7.0.81/go.mod h1:84gmIilaX4zcvAWWzJ5Z1WI5axN+hAbM5w25xf8xvC0=
|
||||
github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY=
|
||||
github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo=
|
||||
github.com/ncruces/go-strftime v0.1.9 h1:bY0MQC28UADQmHmaF5dgpLmImcShSi2kHU9XLdhx/f4=
|
||||
@@ -150,6 +142,8 @@ github.com/rhysd/go-github-selfupdate v1.2.3 h1:iaa+J202f+Nc+A8zi75uccC8Wg3omaM7
|
||||
github.com/rhysd/go-github-selfupdate v1.2.3/go.mod h1:mp/N8zj6jFfBQy/XMYoWsmfzxazpPAODuqarmPDe2Rg=
|
||||
github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII=
|
||||
github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o=
|
||||
github.com/rs/xid v1.6.0 h1:fV591PaemRlL6JfRxGDEPl69wICngIQ3shQtzfy2gxU=
|
||||
github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0=
|
||||
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
|
||||
github.com/sagikazarmark/locafero v0.7.0 h1:5MqpDsTGNDhY8sGp0Aowyf0qKsPrhewaLSsFaodPcyo=
|
||||
github.com/sagikazarmark/locafero v0.7.0/go.mod h1:2za3Cg5rMaTMoG/2Ulr9AwtFaIppKXTRYnozin4aB5k=
|
||||
@@ -201,8 +195,8 @@ go.uber.org/zap v1.27.0 h1:aJMhYGrd5QSmlpLMr2MftRKl7t8J8PTZPA732ud/XR8=
|
||||
go.uber.org/zap v1.27.0/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E=
|
||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||
golang.org/x/crypto v0.0.0-20201221181555-eec23a3978ad/go.mod h1:jdWPYTVW3xRLrWPugEBEK3UY2ZEsg3UU495nc5E+M+I=
|
||||
golang.org/x/crypto v0.33.0 h1:IOBPskki6Lysi0lo9qQvbxiQ+FvsCC/YWOecCHAixus=
|
||||
golang.org/x/crypto v0.33.0/go.mod h1:bVdXmD7IV/4GdElGPozy6U7lWdRXA4qyRVGJV57uQ5M=
|
||||
golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34=
|
||||
golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc=
|
||||
golang.org/x/exp v0.0.0-20250210185358-939b2ce775ac h1:l5+whBCLH3iH2ZNHYLbAe58bo7yrN4mVcnkHDYz5vvs=
|
||||
golang.org/x/exp v0.0.0-20250210185358-939b2ce775ac/go.mod h1:hH+7mtFmImwwcMvScyxUhjuVHR3HGaDPMn9rMSUUbxo=
|
||||
golang.org/x/mod v0.23.0 h1:Zb7khfcRGKk+kqfxFaP5tZqCnDZMjC5VtUBs87Hr6QM=
|
||||
@@ -211,29 +205,30 @@ golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73r
|
||||
golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||
golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
|
||||
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
|
||||
golang.org/x/net v0.35.0 h1:T5GQRQb2y08kTAByq9L4/bz8cipCdA8FbRTXewonqY8=
|
||||
golang.org/x/net v0.35.0/go.mod h1:EglIi67kWsHKlRzzVMUD93VMSWGFOMSZgxFjparz1Qk=
|
||||
golang.org/x/net v0.37.0 h1:1zLorHbz+LYj7MQlSf1+2tPIIgibq2eL5xkrGk6f+2c=
|
||||
golang.org/x/net v0.37.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8=
|
||||
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
|
||||
golang.org/x/oauth2 v0.0.0-20181106182150-f42d05182288/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
|
||||
golang.org/x/oauth2 v0.26.0 h1:afQXWNNaeC4nvZ0Ed9XvCCzXM6UHJG7iCg0W4fPqSBE=
|
||||
golang.org/x/oauth2 v0.26.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI=
|
||||
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.11.0 h1:GGz8+XQP4FvTTrjZPzNKTMFtSXH80RAzG+5ghFPgK9w=
|
||||
golang.org/x/sync v0.11.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
||||
golang.org/x/sync v0.12.0 h1:MHc5BpPuC30uJk597Ri8TV3CNZcTLu6B6z4lJy+g6Jw=
|
||||
golang.org/x/sync v0.12.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
|
||||
golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.30.0 h1:QjkSwP/36a20jFYWkSue1YwXzLmsV5Gfq7Eiy72C1uc=
|
||||
golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik=
|
||||
golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
|
||||
golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw=
|
||||
golang.org/x/term v0.29.0 h1:L6pJp37ocefwRRtYPKSWOWzOtWSxVajvz2ldH/xi3iU=
|
||||
golang.org/x/term v0.29.0/go.mod h1:6bl4lRlvVuDgSf3179VpIxBF0o10JUpXWOnI7nErv7s=
|
||||
golang.org/x/term v0.30.0 h1:PQ39fJZ+mfadBm0y5WlL4vlM7Sx1Hgf13sMIY2+QS9Y=
|
||||
golang.org/x/term v0.30.0/go.mod h1:NYYFdzHoI5wRh/h5tDMdMqCqPJZEuNqVR5xJLd/n67g=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.3.5/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.22.0 h1:bofq7m3/HAFvbF51jz3Q9wLg3jkvSPuiZu/pD1XwgtM=
|
||||
golang.org/x/text v0.22.0/go.mod h1:YRoo4H8PVmsu+E3Ou7cqLVH8oXWIHVoX0jqUWALQhfY=
|
||||
golang.org/x/text v0.23.0 h1:D71I7dUrlY+VX0gQShAThNGHFxZ13dGLBHQLVl1mJlY=
|
||||
golang.org/x/text v0.23.0/go.mod h1:/BLNzu4aZCJ1+kcD0DNRotWKage4q2rGVAg4o22unh4=
|
||||
golang.org/x/time v0.10.0 h1:3usCWA8tQn0L8+hFJQNgzpWbd89begxN66o1Ojdn5L4=
|
||||
golang.org/x/time v0.10.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
|
||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
@@ -255,17 +250,6 @@ gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
|
||||
gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gorm.io/datatypes v1.2.5 h1:9UogU3jkydFVW1bIVVeoYsTpLRgwDVW3rHfJG6/Ek9I=
|
||||
gorm.io/datatypes v1.2.5/go.mod h1:I5FUdlKpLb5PMqeMQhm30CQ6jXP8Rj89xkTeCSAaAD4=
|
||||
gorm.io/driver/mysql v1.5.6 h1:Ld4mkIickM+EliaQZQx3uOJDJHtrd70MxAUqWqlx3Y8=
|
||||
gorm.io/driver/mysql v1.5.6/go.mod h1:sEtPWMiqiN1N1cMXoXmBbd8C6/l+TESwriotuRRpkDM=
|
||||
gorm.io/driver/postgres v1.5.0 h1:u2FXTy14l45qc3UeCJ7QaAXZmZfDDv0YrthvmRq1l0U=
|
||||
gorm.io/driver/postgres v1.5.0/go.mod h1:FUZXzO+5Uqg5zzwzv4KK49R8lvGIyscBOqYrtI1Ce9A=
|
||||
gorm.io/driver/sqlite v1.5.5 h1:7MDMtUZhV065SilG62E0MquljeArQZNfJnjd9i9gx3E=
|
||||
gorm.io/driver/sqlite v1.5.5/go.mod h1:6NgQ7sQWAIFsPrJJl1lSNSu2TABh0ZZ/zm5fosATavE=
|
||||
gorm.io/driver/sqlserver v1.5.4 h1:xA+Y1KDNspv79q43bPyjDMUgHoYHLhXYmdFcYPobg8g=
|
||||
gorm.io/driver/sqlserver v1.5.4/go.mod h1:+frZ/qYmuna11zHPlh5oc2O6ZA/lS88Keb0XSH1Zh/g=
|
||||
gorm.io/gorm v1.25.7/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8=
|
||||
gorm.io/gorm v1.25.12 h1:I0u8i2hWQItBq1WfE0o2+WuL9+8L21K9e2HHSTE/0f8=
|
||||
gorm.io/gorm v1.25.12/go.mod h1:xh7N7RHfYlNc5EmcI/El95gXusucDrQnHXe0+CgWcLQ=
|
||||
modernc.org/cc/v4 v4.24.4 h1:TFkx1s6dCkQpd6dKurBNmpo+G8Zl4Sq/ztJ+2+DEsh0=
|
||||
|
||||
@@ -7,13 +7,11 @@ import (
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"path"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/krau/SaveAny-Bot/config"
|
||||
"github.com/krau/SaveAny-Bot/logger"
|
||||
"github.com/krau/SaveAny-Bot/common"
|
||||
config "github.com/krau/SaveAny-Bot/config/storage"
|
||||
"github.com/krau/SaveAny-Bot/types"
|
||||
)
|
||||
|
||||
@@ -43,36 +41,36 @@ func (a *Alist) Init(cfg config.StorageConfig) error {
|
||||
defer cancel()
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, a.baseURL+"/api/me", nil)
|
||||
if err != nil {
|
||||
logger.L.Fatalf("Failed to create request: %v", err)
|
||||
common.Log.Fatalf("Failed to create request: %v", err)
|
||||
return err
|
||||
}
|
||||
req.Header.Set("Authorization", a.token)
|
||||
|
||||
resp, err := a.client.Do(req)
|
||||
if err != nil {
|
||||
logger.L.Fatalf("Failed to send request: %v", err)
|
||||
common.Log.Fatalf("Failed to send request: %v", err)
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
logger.L.Fatalf("Failed to get alist user info: %s", resp.Status)
|
||||
common.Log.Fatalf("Failed to get alist user info: %s", resp.Status)
|
||||
return err
|
||||
}
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
logger.L.Fatalf("Failed to read response body: %v", err)
|
||||
common.Log.Fatalf("Failed to read response body: %v", err)
|
||||
return err
|
||||
}
|
||||
var meResp meResponse
|
||||
if err := json.Unmarshal(body, &meResp); err != nil {
|
||||
logger.L.Fatalf("Failed to unmarshal me response: %v", err)
|
||||
common.Log.Fatalf("Failed to unmarshal me response: %v", err)
|
||||
return err
|
||||
}
|
||||
if meResp.Code != http.StatusOK {
|
||||
logger.L.Fatalf("Failed to get alist user info: %s", meResp.Message)
|
||||
common.Log.Fatalf("Failed to get alist user info: %s", meResp.Message)
|
||||
return err
|
||||
}
|
||||
logger.L.Debugf("Logged in Alist as %s", meResp.Data.Username)
|
||||
common.Log.Debugf("Logged in Alist as %s", meResp.Data.Username)
|
||||
return nil
|
||||
}
|
||||
a.loginInfo = &loginRequest{
|
||||
@@ -81,10 +79,10 @@ func (a *Alist) Init(cfg config.StorageConfig) error {
|
||||
}
|
||||
|
||||
if err := a.getToken(); err != nil {
|
||||
logger.L.Fatalf("Failed to login to Alist: %v", err)
|
||||
common.Log.Fatalf("Failed to login to Alist: %v", err)
|
||||
return err
|
||||
}
|
||||
logger.L.Debug("Logged in to Alist")
|
||||
common.Log.Debug("Logged in to Alist")
|
||||
|
||||
go a.refreshToken(*alistConfig)
|
||||
return nil
|
||||
@@ -98,28 +96,22 @@ func (a *Alist) Name() string {
|
||||
return a.config.Name
|
||||
}
|
||||
|
||||
func (a *Alist) Save(ctx context.Context, filePath, storagePath string) error {
|
||||
logger.L.Infof("Saving file %s to %s", filePath, storagePath)
|
||||
file, err := os.Open(filePath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to open file: %w", err)
|
||||
}
|
||||
defer file.Close()
|
||||
func (a *Alist) Save(ctx context.Context, reader io.Reader, storagePath string) error {
|
||||
common.Log.Infof("Saving file to %s", storagePath)
|
||||
|
||||
filestat, err := file.Stat()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get file stats: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPut, a.baseURL+"/api/fs/put", file)
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPut, a.baseURL+"/api/fs/put", reader)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
req.Header.Set("Authorization", a.token)
|
||||
req.Header.Set("File-Path", url.PathEscape(storagePath))
|
||||
req.Header.Set("As-Task", "true")
|
||||
req.Header.Set("Content-Type", "application/octet-stream")
|
||||
req.ContentLength = filestat.Size()
|
||||
if length := ctx.Value(types.ContextKeyContentLength); length != nil {
|
||||
length, ok := length.(int64)
|
||||
if ok {
|
||||
req.ContentLength = length
|
||||
}
|
||||
}
|
||||
|
||||
resp, err := a.client.Do(req)
|
||||
if err != nil {
|
||||
@@ -148,91 +140,10 @@ func (a *Alist) Save(ctx context.Context, filePath, storagePath string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Alist) NotSupportStream() string {
|
||||
return "Alist does not support chunked transfer encoding"
|
||||
}
|
||||
|
||||
func (a *Alist) JoinStoragePath(task types.Task) string {
|
||||
return path.Join(a.config.BasePath, task.StoragePath)
|
||||
}
|
||||
|
||||
type uploadStream struct {
|
||||
ctx context.Context
|
||||
client *http.Client
|
||||
token string
|
||||
storagePath string
|
||||
baseURL string
|
||||
pr *io.PipeReader
|
||||
pw *io.PipeWriter
|
||||
errChan chan error
|
||||
once sync.Once
|
||||
}
|
||||
|
||||
func (us *uploadStream) Write(p []byte) (int, error) {
|
||||
return us.pw.Write(p)
|
||||
}
|
||||
|
||||
func (us *uploadStream) Close() error {
|
||||
var uploadErr error
|
||||
us.once.Do(func() {
|
||||
if err := us.pw.Close(); err != nil {
|
||||
uploadErr = fmt.Errorf("failed to close pipe writer: %w", err)
|
||||
return
|
||||
}
|
||||
|
||||
if err := <-us.errChan; err != nil {
|
||||
uploadErr = err
|
||||
}
|
||||
})
|
||||
return uploadErr
|
||||
}
|
||||
|
||||
func (a *Alist) NewUploadStream(ctx context.Context, storagePath string) (io.WriteCloser, error) {
|
||||
if a.token == "" {
|
||||
if err := a.getToken(); err != nil {
|
||||
return nil, fmt.Errorf("not logged in to Alist: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
pr, pw := io.Pipe()
|
||||
|
||||
// 创建上传流对象
|
||||
us := &uploadStream{
|
||||
ctx: ctx,
|
||||
client: a.client,
|
||||
token: a.token,
|
||||
storagePath: storagePath,
|
||||
baseURL: a.baseURL,
|
||||
pr: pr,
|
||||
pw: pw,
|
||||
errChan: make(chan error, 1),
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer close(us.errChan)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPut, a.baseURL+"/api/fs/put", pr)
|
||||
if err != nil {
|
||||
us.errChan <- fmt.Errorf("failed to create request: %w", err)
|
||||
return
|
||||
}
|
||||
|
||||
req.Header.Set("Authorization", a.token)
|
||||
req.Header.Set("File-Path", url.PathEscape(storagePath))
|
||||
req.Header.Set("As-Task", "true")
|
||||
req.Header.Set("Content-Type", "application/octet-stream")
|
||||
|
||||
resp, err := a.client.Do(req)
|
||||
if err != nil {
|
||||
us.errChan <- fmt.Errorf("failed to send request: %w", err)
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
us.errChan <- fmt.Errorf("failed to upload file, status code: %d, response: %s", resp.StatusCode, string(body))
|
||||
return
|
||||
}
|
||||
|
||||
us.errChan <- nil
|
||||
}()
|
||||
|
||||
return us, nil
|
||||
}
|
||||
|
||||
@@ -8,8 +8,8 @@ import (
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/krau/SaveAny-Bot/config"
|
||||
"github.com/krau/SaveAny-Bot/logger"
|
||||
"github.com/krau/SaveAny-Bot/common"
|
||||
config "github.com/krau/SaveAny-Bot/config/storage"
|
||||
)
|
||||
|
||||
func (a *Alist) getToken() error {
|
||||
@@ -51,15 +51,15 @@ func (a *Alist) getToken() error {
|
||||
func (a *Alist) refreshToken(cfg config.AlistStorageConfig) {
|
||||
tokenExp := cfg.TokenExp
|
||||
if tokenExp <= 0 {
|
||||
logger.L.Warn("Invalid token expiration time, using default value")
|
||||
common.Log.Warn("Invalid token expiration time, using default value")
|
||||
tokenExp = 3600
|
||||
}
|
||||
for {
|
||||
time.Sleep(time.Duration(tokenExp) * time.Second)
|
||||
if err := a.getToken(); err != nil {
|
||||
logger.L.Errorf("Failed to refresh jwt token: %v", err)
|
||||
common.Log.Errorf("Failed to refresh jwt token: %v", err)
|
||||
continue
|
||||
}
|
||||
logger.L.Info("Refreshed Alist jwt token")
|
||||
common.Log.Info("Refreshed Alist jwt token")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,8 +8,8 @@ import (
|
||||
"path/filepath"
|
||||
|
||||
"github.com/duke-git/lancet/v2/fileutil"
|
||||
"github.com/krau/SaveAny-Bot/config"
|
||||
"github.com/krau/SaveAny-Bot/logger"
|
||||
"github.com/krau/SaveAny-Bot/common"
|
||||
config "github.com/krau/SaveAny-Bot/config/storage"
|
||||
"github.com/krau/SaveAny-Bot/types"
|
||||
)
|
||||
|
||||
@@ -41,8 +41,13 @@ func (l *Local) Name() string {
|
||||
return l.config.Name
|
||||
}
|
||||
|
||||
func (l *Local) Save(ctx context.Context, filePath, storagePath string) error {
|
||||
logger.L.Infof("Saving file %s to %s", filePath, storagePath)
|
||||
func (l *Local) JoinStoragePath(task types.Task) string {
|
||||
return filepath.Join(l.config.BasePath, task.StoragePath)
|
||||
}
|
||||
|
||||
func (l *Local) Save(ctx context.Context, r io.Reader, storagePath string) error {
|
||||
common.Log.Infof("Saving file to %s", storagePath)
|
||||
|
||||
absPath, err := filepath.Abs(storagePath)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -50,24 +55,11 @@ func (l *Local) Save(ctx context.Context, filePath, storagePath string) error {
|
||||
if err := fileutil.CreateDir(filepath.Dir(absPath)); err != nil {
|
||||
return err
|
||||
}
|
||||
return fileutil.CopyFile(filePath, storagePath)
|
||||
}
|
||||
|
||||
func (l *Local) JoinStoragePath(task types.Task) string {
|
||||
return filepath.Join(l.config.BasePath, task.StoragePath)
|
||||
}
|
||||
|
||||
func (l *Local) NewUploadStream(ctx context.Context, path string) (io.WriteCloser, error) {
|
||||
absPath, err := filepath.Abs(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := fileutil.CreateDir(filepath.Dir(absPath)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
file, err := os.Create(absPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return err
|
||||
}
|
||||
return file, nil
|
||||
defer file.Close()
|
||||
_, err = io.Copy(file, r)
|
||||
return err
|
||||
}
|
||||
|
||||
72
storage/minio/client.go
Normal file
72
storage/minio/client.go
Normal file
@@ -0,0 +1,72 @@
|
||||
package minio
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"path"
|
||||
|
||||
"github.com/krau/SaveAny-Bot/common"
|
||||
config "github.com/krau/SaveAny-Bot/config/storage"
|
||||
"github.com/krau/SaveAny-Bot/types"
|
||||
"github.com/minio/minio-go/v7"
|
||||
"github.com/minio/minio-go/v7/pkg/credentials"
|
||||
)
|
||||
|
||||
type Minio struct {
|
||||
config config.MinioStorageConfig
|
||||
client *minio.Client
|
||||
}
|
||||
|
||||
func (m *Minio) Init(cfg config.StorageConfig) error {
|
||||
minioConfig, ok := cfg.(*config.MinioStorageConfig)
|
||||
if !ok {
|
||||
return fmt.Errorf("failed to cast minio config")
|
||||
}
|
||||
if err := minioConfig.Validate(); err != nil {
|
||||
return err
|
||||
}
|
||||
m.config = *minioConfig
|
||||
|
||||
client, err := minio.New(m.config.Endpoint, &minio.Options{
|
||||
Creds: credentials.NewStaticV4(m.config.AccessKeyID, m.config.SecretAccessKey, ""),
|
||||
Secure: m.config.UseSSL,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create minio client: %w", err)
|
||||
}
|
||||
|
||||
exists, err := client.BucketExists(context.Background(), m.config.BucketName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to check bucket existence: %w", err)
|
||||
}
|
||||
if !exists {
|
||||
return fmt.Errorf("bucket %s does not exist", m.config.BucketName)
|
||||
}
|
||||
|
||||
m.client = client
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Minio) Type() types.StorageType {
|
||||
return types.StorageTypeMinio
|
||||
}
|
||||
|
||||
func (m *Minio) Name() string {
|
||||
return m.config.Name
|
||||
}
|
||||
|
||||
func (m *Minio) JoinStoragePath(task types.Task) string {
|
||||
return path.Join(m.config.BasePath, task.StoragePath)
|
||||
}
|
||||
|
||||
func (m *Minio) Save(ctx context.Context, r io.Reader, storagePath string) error {
|
||||
common.Log.Infof("Saving file from reader to %s", storagePath)
|
||||
|
||||
_, err := m.client.PutObject(ctx, m.config.BucketName, storagePath, r, -1, minio.PutObjectOptions{})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to upload file to minio: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -5,25 +5,27 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/krau/SaveAny-Bot/common"
|
||||
"github.com/krau/SaveAny-Bot/config"
|
||||
"github.com/krau/SaveAny-Bot/logger"
|
||||
sc "github.com/krau/SaveAny-Bot/config/storage"
|
||||
"github.com/krau/SaveAny-Bot/storage/alist"
|
||||
"github.com/krau/SaveAny-Bot/storage/local"
|
||||
"github.com/krau/SaveAny-Bot/storage/minio"
|
||||
"github.com/krau/SaveAny-Bot/storage/webdav"
|
||||
"github.com/krau/SaveAny-Bot/types"
|
||||
)
|
||||
|
||||
type Storage interface {
|
||||
Init(cfg config.StorageConfig) error
|
||||
Init(cfg sc.StorageConfig) error
|
||||
Type() types.StorageType
|
||||
Name() string
|
||||
JoinStoragePath(task types.Task) string
|
||||
Save(cttx context.Context, localFilePath, storagePath string) error
|
||||
Save(ctx context.Context, reader io.Reader, storagePath string) error
|
||||
}
|
||||
|
||||
type StreamStorage interface {
|
||||
type StorageNotSupportStream interface {
|
||||
Storage
|
||||
NewUploadStream(ctx context.Context, path string) (io.WriteCloser, error)
|
||||
NotSupportStream() string
|
||||
}
|
||||
|
||||
var Storages = make(map[string]Storage)
|
||||
@@ -90,9 +92,10 @@ var storageConstructors = map[string]StorageConstructor{
|
||||
string(types.StorageTypeAlist): func() Storage { return new(alist.Alist) },
|
||||
string(types.StorageTypeLocal): func() Storage { return new(local.Local) },
|
||||
string(types.StorageTypeWebdav): func() Storage { return new(webdav.Webdav) },
|
||||
string(types.StorageTypeMinio): func() Storage { return new(minio.Minio) },
|
||||
}
|
||||
|
||||
func NewStorage(cfg config.StorageConfig) (Storage, error) {
|
||||
func NewStorage(cfg sc.StorageConfig) (Storage, error) {
|
||||
constructor, ok := storageConstructors[string(cfg.GetType())]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("不支持的存储类型: %s", cfg.GetType())
|
||||
@@ -107,14 +110,14 @@ func NewStorage(cfg config.StorageConfig) (Storage, error) {
|
||||
}
|
||||
|
||||
func LoadStorages() {
|
||||
logger.L.Info("加载存储...")
|
||||
common.Log.Info("加载存储...")
|
||||
for _, storage := range config.Cfg.Storages {
|
||||
_, err := GetStorageByName(storage.GetName())
|
||||
if err != nil {
|
||||
logger.L.Errorf("加载存储 %s 失败: %v", storage.GetName(), err)
|
||||
common.Log.Errorf("加载存储 %s 失败: %v", storage.GetName(), err)
|
||||
}
|
||||
}
|
||||
logger.L.Infof("成功加载 %d 个存储", len(Storages))
|
||||
common.Log.Infof("成功加载 %d 个存储", len(Storages))
|
||||
for user := range config.Cfg.GetUsersID() {
|
||||
UserStorages[int64(user)] = GetUserStorages(int64(user))
|
||||
}
|
||||
|
||||
130
storage/webdav/client._test.go
Normal file
130
storage/webdav/client._test.go
Normal file
@@ -0,0 +1,130 @@
|
||||
package webdav
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"golang.org/x/net/webdav"
|
||||
)
|
||||
|
||||
func setupWebDAVServer(t *testing.T) (*httptest.Server, string) {
|
||||
t.Helper()
|
||||
tempDir, err := os.MkdirTemp("", "webdav_test")
|
||||
if err != nil {
|
||||
t.Fatalf("mk temp dir failed: %v", err)
|
||||
}
|
||||
|
||||
handler := &webdav.Handler{
|
||||
Prefix: "/",
|
||||
FileSystem: webdav.Dir(tempDir),
|
||||
LockSystem: webdav.NewMemLS(),
|
||||
}
|
||||
|
||||
server := httptest.NewServer(handler)
|
||||
return server, tempDir
|
||||
}
|
||||
|
||||
func TestMkDirAndExists(t *testing.T) {
|
||||
server, tempDir := setupWebDAVServer(t)
|
||||
defer os.RemoveAll(tempDir)
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient(server.URL, "", "", nil)
|
||||
ctx := context.Background()
|
||||
|
||||
testpaths := []string{"testdir", "testdir/subdir", "testdir/子目录", "/testdir/测试路径/测试路径2"}
|
||||
for _, p := range testpaths {
|
||||
exists, err := client.Exists(ctx, p)
|
||||
if err != nil {
|
||||
t.Fatalf("Call Exists Err: %v", err)
|
||||
}
|
||||
if exists {
|
||||
t.Fatalf("Dir should not exist")
|
||||
}
|
||||
|
||||
if err := client.MkDir(ctx, p); err != nil {
|
||||
t.Fatalf("Call MkDir Err: %v", err)
|
||||
}
|
||||
|
||||
exists, err = client.Exists(ctx, p)
|
||||
if err != nil {
|
||||
t.Fatalf("Call Exists Err: %v", err)
|
||||
}
|
||||
if !exists {
|
||||
t.Fatalf("Dir should exist")
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestWriteFile(t *testing.T) {
|
||||
server, tempDir := setupWebDAVServer(t)
|
||||
defer os.RemoveAll(tempDir)
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient(server.URL, "", "", nil)
|
||||
ctx := context.Background()
|
||||
|
||||
testCases := []struct {
|
||||
remotePath string
|
||||
content string
|
||||
}{
|
||||
{
|
||||
remotePath: "hello.txt",
|
||||
content: "Hello webdav",
|
||||
},
|
||||
{
|
||||
remotePath: "nested/dir/test.txt",
|
||||
content: "Nested file",
|
||||
},
|
||||
{
|
||||
remotePath: "empty.txt",
|
||||
content: "",
|
||||
},
|
||||
{
|
||||
remotePath: "unicode.txt",
|
||||
content: "测试",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.remotePath, func(t *testing.T) {
|
||||
dir := path.Dir(tc.remotePath)
|
||||
if dir != "." {
|
||||
if err := client.MkDir(ctx, dir); err != nil {
|
||||
t.Fatalf("创建目录 %s 失败: %v", dir, err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := client.WriteFile(ctx, tc.remotePath, strings.NewReader(tc.content)); err != nil {
|
||||
t.Fatalf("写入文件 %s 失败: %v", tc.remotePath, err)
|
||||
}
|
||||
|
||||
localPath := filepath.Join(tempDir, tc.remotePath)
|
||||
data, err := os.ReadFile(localPath)
|
||||
if err != nil {
|
||||
t.Fatalf("读取文件 %s 失败: %v", localPath, err)
|
||||
}
|
||||
if string(data) != tc.content {
|
||||
t.Fatalf("文件内容不匹配: got %s, want %s", string(data), tc.content)
|
||||
}
|
||||
|
||||
appended := tc.content + " Overwritten."
|
||||
if err := client.WriteFile(ctx, tc.remotePath, strings.NewReader(appended)); err != nil {
|
||||
t.Fatalf("覆盖写入文件 %s 失败: %v", tc.remotePath, err)
|
||||
}
|
||||
data, err = os.ReadFile(localPath)
|
||||
if err != nil {
|
||||
t.Fatalf("读取覆盖后的文件 %s 失败: %v", localPath, err)
|
||||
}
|
||||
if string(data) != appended {
|
||||
t.Fatalf("文件覆盖后的内容不匹配: got %s, want %s", string(data), appended)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -6,6 +6,8 @@ import (
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/krau/SaveAny-Bot/types"
|
||||
)
|
||||
|
||||
type Client struct {
|
||||
@@ -38,21 +40,63 @@ func (c *Client) doRequest(ctx context.Context, method, url string, body io.Read
|
||||
if c.Username != "" && c.Password != "" {
|
||||
req.SetBasicAuth(c.Username, c.Password)
|
||||
}
|
||||
if length := ctx.Value(types.ContextKeyContentLength); length != nil {
|
||||
if l, ok := length.(int64); ok {
|
||||
req.ContentLength = l
|
||||
}
|
||||
}
|
||||
return c.httpClient.Do(req)
|
||||
}
|
||||
|
||||
func (c *Client) MkDir(ctx context.Context, dirPath string) error {
|
||||
url := c.BaseURL + dirPath
|
||||
resp, err := c.doRequest(ctx, "MKCOL", url, nil)
|
||||
func (c *Client) Exists(ctx context.Context, remotePath string) (bool, error) {
|
||||
url := c.BaseURL + remotePath
|
||||
resp, err := c.doRequest(ctx, "PROPFIND", url, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
return false, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode >= 200 && resp.StatusCode < 300 {
|
||||
return true, nil
|
||||
}
|
||||
if resp.StatusCode == http.StatusNotFound {
|
||||
return false, nil
|
||||
}
|
||||
return false, fmt.Errorf("PROPFIND: %s", resp.Status)
|
||||
}
|
||||
|
||||
func (c *Client) MkDir(ctx context.Context, dirPath string) error {
|
||||
dirPath = strings.Trim(dirPath, "/")
|
||||
if dirPath == "" {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("MKCOL: %s", resp.Status)
|
||||
parts := strings.Split(dirPath, "/")
|
||||
currentPath := ""
|
||||
for i, part := range parts {
|
||||
if i > 0 {
|
||||
currentPath += "/"
|
||||
}
|
||||
currentPath += part
|
||||
|
||||
exists, err := c.Exists(ctx, currentPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if exists {
|
||||
continue
|
||||
}
|
||||
url := c.BaseURL + currentPath
|
||||
resp, err := c.doRequest(ctx, "MKCOL", url, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
resp.Body.Close()
|
||||
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
return fmt.Errorf("MKCOL %s: %s", currentPath, resp.Status)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) WriteFile(ctx context.Context, remotePath string, content io.Reader) error {
|
||||
|
||||
@@ -1,58 +0,0 @@
|
||||
package webdav
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"path"
|
||||
|
||||
"github.com/krau/SaveAny-Bot/logger"
|
||||
)
|
||||
|
||||
type WebdavWriter struct {
|
||||
pipeWriter *io.PipeWriter
|
||||
done chan error
|
||||
path string
|
||||
}
|
||||
|
||||
func (w *WebdavWriter) Write(p []byte) (n int, err error) {
|
||||
return w.pipeWriter.Write(p)
|
||||
}
|
||||
|
||||
func (w *WebdavWriter) Close() error {
|
||||
if err := w.pipeWriter.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := <-w.done; err != nil {
|
||||
return fmt.Errorf("upload failed: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *Webdav) NewUploadStream(ctx context.Context, storagePath string) (io.WriteCloser, error) {
|
||||
if err := w.client.MkDir(ctx, path.Dir(storagePath)); err != nil {
|
||||
logger.L.Errorf("Failed to create directory %s: %v", path.Dir(storagePath), err)
|
||||
return nil, ErrFailedToCreateDirectory
|
||||
}
|
||||
pipeReader, pipeWriter := io.Pipe()
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
defer func() {
|
||||
if err := recover(); err != nil {
|
||||
done <- fmt.Errorf("panic during upload: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
err := w.client.WriteFile(ctx, storagePath, pipeReader)
|
||||
|
||||
pipeReader.Close()
|
||||
done <- err
|
||||
}()
|
||||
|
||||
return &WebdavWriter{
|
||||
pipeWriter: pipeWriter,
|
||||
done: done,
|
||||
path: storagePath,
|
||||
}, nil
|
||||
}
|
||||
@@ -3,13 +3,13 @@ package webdav
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"path"
|
||||
"time"
|
||||
|
||||
"github.com/krau/SaveAny-Bot/config"
|
||||
"github.com/krau/SaveAny-Bot/logger"
|
||||
"github.com/krau/SaveAny-Bot/common"
|
||||
config "github.com/krau/SaveAny-Bot/config/storage"
|
||||
"github.com/krau/SaveAny-Bot/types"
|
||||
)
|
||||
|
||||
@@ -41,26 +41,19 @@ func (w *Webdav) Name() string {
|
||||
return w.config.Name
|
||||
}
|
||||
|
||||
func (w *Webdav) Save(ctx context.Context, filePath, storagePath string) error {
|
||||
logger.L.Infof("Saving file %s to %s", filePath, storagePath)
|
||||
func (w *Webdav) JoinStoragePath(task types.Task) string {
|
||||
return path.Join(w.config.BasePath, task.StoragePath)
|
||||
}
|
||||
|
||||
func (w *Webdav) Save(ctx context.Context, r io.Reader, storagePath string) error {
|
||||
common.Log.Infof("Saving file to %s", storagePath)
|
||||
if err := w.client.MkDir(ctx, path.Dir(storagePath)); err != nil {
|
||||
logger.L.Errorf("Failed to create directory %s: %v", path.Dir(storagePath), err)
|
||||
common.Log.Errorf("Failed to create directory %s: %v", path.Dir(storagePath), err)
|
||||
return ErrFailedToCreateDirectory
|
||||
}
|
||||
file, err := os.Open(filePath)
|
||||
if err != nil {
|
||||
logger.L.Errorf("Failed to open file %s: %v", filePath, err)
|
||||
return err
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
if err := w.client.WriteFile(ctx, storagePath, file); err != nil {
|
||||
logger.L.Errorf("Failed to write file %s: %v", storagePath, err)
|
||||
if err := w.client.WriteFile(ctx, storagePath, r); err != nil {
|
||||
common.Log.Errorf("Failed to write file %s: %v", storagePath, err)
|
||||
return ErrFailedToWriteFile
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *Webdav) JoinStoragePath(task types.Task) string {
|
||||
return path.Join(w.config.BasePath, task.StoragePath)
|
||||
}
|
||||
|
||||
82
types/task.go
Normal file
82
types/task.go
Normal file
@@ -0,0 +1,82 @@
|
||||
package types
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/md5"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gotd/td/tg"
|
||||
)
|
||||
|
||||
type Task struct {
|
||||
Ctx context.Context
|
||||
Cancel context.CancelFunc
|
||||
Error error
|
||||
Status TaskStatus
|
||||
StorageName string
|
||||
StoragePath string
|
||||
StartTime time.Time
|
||||
|
||||
File *File
|
||||
FileMessageID int
|
||||
FileChatID int64
|
||||
|
||||
IsTelegraph bool
|
||||
TelegraphURL string
|
||||
|
||||
// to track the reply message
|
||||
ReplyMessageID int
|
||||
ReplyChatID int64
|
||||
UserID int64
|
||||
}
|
||||
|
||||
func (t Task) Key() string {
|
||||
if t.IsTelegraph {
|
||||
return hashStr(t.TelegraphURL)
|
||||
}
|
||||
return fmt.Sprintf("%d:%d", t.FileChatID, t.FileMessageID)
|
||||
}
|
||||
|
||||
func (t Task) String() string {
|
||||
if t.IsTelegraph {
|
||||
return fmt.Sprintf("[telegraph]:%s", t.TelegraphURL)
|
||||
}
|
||||
return fmt.Sprintf("[%d:%d]:%s", t.FileChatID, t.FileMessageID, t.File.FileName)
|
||||
}
|
||||
|
||||
func (t Task) FileName() string {
|
||||
if t.IsTelegraph {
|
||||
tgphPath := strings.Split(t.TelegraphURL, "/")[len(strings.Split(t.TelegraphURL, "/"))-1]
|
||||
tgphPathUnescaped, err := url.PathUnescape(tgphPath)
|
||||
if err != nil {
|
||||
return tgphPath
|
||||
}
|
||||
return tgphPathUnescaped
|
||||
}
|
||||
return t.File.FileName
|
||||
}
|
||||
|
||||
type File struct {
|
||||
Location tg.InputFileLocationClass
|
||||
FileSize int64
|
||||
FileName string
|
||||
}
|
||||
|
||||
func (f File) Hash() string {
|
||||
locationBytes := []byte(f.Location.String())
|
||||
fileSizeBytes := []byte(fmt.Sprintf("%d", f.FileSize))
|
||||
fileNameBytes := []byte(f.FileName)
|
||||
|
||||
structBytes := append(locationBytes, fileSizeBytes...)
|
||||
structBytes = append(structBytes, fileNameBytes...)
|
||||
|
||||
hash := md5.New()
|
||||
hash.Write(structBytes)
|
||||
hashBytes := hash.Sum(nil)
|
||||
|
||||
return hex.EncodeToString(hashBytes)
|
||||
}
|
||||
@@ -1,18 +1,8 @@
|
||||
package types
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/md5"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/gotd/td/tg"
|
||||
)
|
||||
|
||||
type TaskStatus string
|
||||
|
||||
var (
|
||||
const (
|
||||
Pending TaskStatus = "pending"
|
||||
Succeeded TaskStatus = "succeeded"
|
||||
Failed TaskStatus = "failed"
|
||||
@@ -21,67 +11,23 @@ var (
|
||||
|
||||
type StorageType string
|
||||
|
||||
var (
|
||||
const (
|
||||
StorageTypeLocal StorageType = "local"
|
||||
StorageTypeWebdav StorageType = "webdav"
|
||||
StorageTypeAlist StorageType = "alist"
|
||||
StorageTypeMinio StorageType = "minio"
|
||||
)
|
||||
|
||||
var StorageTypes = []StorageType{StorageTypeLocal, StorageTypeAlist, StorageTypeWebdav}
|
||||
var StorageTypes = []StorageType{StorageTypeLocal, StorageTypeAlist, StorageTypeWebdav, StorageTypeMinio}
|
||||
var StorageTypeDisplay = map[StorageType]string{
|
||||
StorageTypeLocal: "本地磁盘",
|
||||
StorageTypeWebdav: "WebDAV",
|
||||
StorageTypeAlist: "Alist",
|
||||
StorageTypeMinio: "Minio",
|
||||
}
|
||||
|
||||
type Task struct {
|
||||
Ctx context.Context
|
||||
Cancel context.CancelFunc
|
||||
Error error
|
||||
Status TaskStatus
|
||||
File *File
|
||||
StorageName string
|
||||
StoragePath string
|
||||
StartTime time.Time
|
||||
type ContextKey string
|
||||
|
||||
FileMessageID int
|
||||
FileChatID int64
|
||||
// to track the reply message
|
||||
ReplyMessageID int
|
||||
ReplyChatID int64
|
||||
// to track the user
|
||||
UserID int64
|
||||
}
|
||||
|
||||
func (t Task) Key() string {
|
||||
return fmt.Sprintf("%d:%d", t.FileChatID, t.FileMessageID)
|
||||
}
|
||||
|
||||
func (t Task) String() string {
|
||||
return fmt.Sprintf("[%d:%d]:%s", t.FileChatID, t.FileMessageID, t.File.FileName)
|
||||
}
|
||||
|
||||
func (t Task) FileName() string {
|
||||
return t.File.FileName
|
||||
}
|
||||
|
||||
type File struct {
|
||||
Location tg.InputFileLocationClass
|
||||
FileSize int64
|
||||
FileName string
|
||||
}
|
||||
|
||||
func (f File) Hash() string {
|
||||
locationBytes := []byte(f.Location.String())
|
||||
fileSizeBytes := []byte(fmt.Sprintf("%d", f.FileSize))
|
||||
fileNameBytes := []byte(f.FileName)
|
||||
|
||||
structBytes := append(locationBytes, fileSizeBytes...)
|
||||
structBytes = append(structBytes, fileNameBytes...)
|
||||
|
||||
hash := md5.New()
|
||||
hash.Write(structBytes)
|
||||
hashBytes := hash.Sum(nil)
|
||||
|
||||
return hex.EncodeToString(hashBytes)
|
||||
}
|
||||
const (
|
||||
ContextKeyContentLength ContextKey = "content-length"
|
||||
)
|
||||
|
||||
12
types/utils.go
Normal file
12
types/utils.go
Normal file
@@ -0,0 +1,12 @@
|
||||
package types
|
||||
|
||||
import (
|
||||
"crypto/md5"
|
||||
"encoding/hex"
|
||||
)
|
||||
|
||||
func hashStr(s string) string {
|
||||
hash := md5.New()
|
||||
hash.Write([]byte(s))
|
||||
return hex.EncodeToString(hash.Sum(nil))
|
||||
}
|
||||
Reference in New Issue
Block a user