mirror of
https://github.com/krau/SaveAny-Bot.git
synced 2026-05-10 17:52:44 +08:00
Compare commits
11 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ab822c2fe6 | ||
|
|
2579044841 | ||
|
|
88a02aae8d | ||
|
|
ab374a870b | ||
|
|
3a1b8f34ea | ||
|
|
c4eb824457 | ||
|
|
692e970772 | ||
|
|
80696c9661 | ||
|
|
18cd480264 | ||
|
|
dfde65c28e | ||
|
|
968547b005 |
@@ -2,6 +2,8 @@
|
|||||||
|
|
||||||
# <img src="docs/logo.jpg" width="45" align="center"> Save Any Bot
|
# <img src="docs/logo.jpg" width="45" align="center"> Save Any Bot
|
||||||
|
|
||||||
|
**简体中文** | [English](README_EN.md)
|
||||||
|
|
||||||
把 Telegram 的文件保存到各类存储端.
|
把 Telegram 的文件保存到各类存储端.
|
||||||
|
|
||||||
> _就像 PikPak Bot 一样_
|
> _就像 PikPak Bot 一样_
|
||||||
@@ -60,9 +62,9 @@ systemctl enable --now saveany-bot
|
|||||||
|
|
||||||
#### Docker Compose
|
#### Docker Compose
|
||||||
|
|
||||||
下载 [docker-compose.yml](https://github.com/krau/SaveAny-Bot/blob/main/docker-compose.yml) 文件, 并修改其中的配置.
|
下载 [docker-compose.yml](https://github.com/krau/SaveAny-Bot/blob/main/docker-compose.yml) 文件, 在同目录下新建 `config.toml` 文件, 参考 [config.toml.example](https://github.com/krau/SaveAny-Bot/blob/main/config.example.toml) 编辑配置文件.
|
||||||
|
|
||||||
运行:
|
启动:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
docker compose up -d
|
docker compose up -d
|
||||||
@@ -94,7 +96,7 @@ docker restart saveany-bot
|
|||||||
|
|
||||||
## 使用
|
## 使用
|
||||||
|
|
||||||
向 Bot 发送(转发)文件, 按照提示操作.
|
向 Bot 发送(转发)文件, 或发送公开频道的消息链接, 按照提示操作.
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
|
|||||||
108
README_EN.md
Normal file
108
README_EN.md
Normal file
@@ -0,0 +1,108 @@
|
|||||||
|
<div align="center">
|
||||||
|
|
||||||
|
# <img src="docs/logo.jpg" width="45" align="center"> Save Any Bot
|
||||||
|
|
||||||
|
[简体中文](README.md) | **English**
|
||||||
|
|
||||||
|
Save Telegram files to various storage endpoints.
|
||||||
|
|
||||||
|
> _Just like PikPak Bot_
|
||||||
|
|
||||||
|
</div>
|
||||||
|
|
||||||
|
Demo Video:
|
||||||
|
|
||||||
|
<div align="center">
|
||||||
|
|
||||||
|
[SaveAny-Bot Demo Video.webm](https://github.com/user-attachments/assets/a0de2453-a4d1-4a12-81fb-9d84856dce09)
|
||||||
|
|
||||||
|
</div>
|
||||||
|
|
||||||
|
## Deployment
|
||||||
|
|
||||||
|
### Deploy from Binary
|
||||||
|
|
||||||
|
Download the binary file for your platform from the [Release](https://github.com/krau/SaveAny-Bot/releases) page.
|
||||||
|
|
||||||
|
Create a `config.toml` file in the extracted directory, refer to [config.toml.example](https://github.com/krau/SaveAny-Bot/blob/main/config.example.toml) for configuration.
|
||||||
|
|
||||||
|
Run:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
chmod +x saveany-bot
|
||||||
|
./saveany-bot
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Add as systemd Service
|
||||||
|
|
||||||
|
Create file `/etc/systemd/system/saveany-bot.service` and write the following content:
|
||||||
|
|
||||||
|
```
|
||||||
|
[Unit]
|
||||||
|
Description=SaveAnyBot
|
||||||
|
After=systemd-user-sessions.service
|
||||||
|
|
||||||
|
[Service]
|
||||||
|
Type=simple
|
||||||
|
WorkingDirectory=/yourpath/
|
||||||
|
ExecStart=/yourpath/saveany-bot
|
||||||
|
Restart=on-failure
|
||||||
|
|
||||||
|
[Install]
|
||||||
|
WantedBy=multi-user.target
|
||||||
|
```
|
||||||
|
|
||||||
|
Enable auto-start and start the service:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
systemctl enable --now saveany-bot
|
||||||
|
```
|
||||||
|
|
||||||
|
### Deploy with Docker
|
||||||
|
|
||||||
|
#### Docker Compose
|
||||||
|
|
||||||
|
Download [docker-compose.yml](https://github.com/krau/SaveAny-Bot/blob/main/docker-compose.yml) file and create a `config.toml` file in the same directory, refer to [config.toml.example](https://github.com/krau/SaveAny-Bot/blob/main/config.example.toml) for configuration.
|
||||||
|
|
||||||
|
Run:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
docker compose up -d
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Docker
|
||||||
|
|
||||||
|
```shell
|
||||||
|
docker run -d --name saveany-bot \
|
||||||
|
-v /path/to/config.toml:/app/config.toml \
|
||||||
|
-v /path/to/downloads:/app/downloads \
|
||||||
|
ghcr.io/krau/saveany-bot:latest
|
||||||
|
```
|
||||||
|
|
||||||
|
## Update
|
||||||
|
|
||||||
|
Use `upgrade` or `up` command to upgrade to the latest version:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
./saveany-bot upgrade
|
||||||
|
```
|
||||||
|
|
||||||
|
If deployed with Docker, use the following commands to update:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
docker pull ghcr.io/krau/saveany-bot:latest
|
||||||
|
docker restart saveany-bot
|
||||||
|
```
|
||||||
|
|
||||||
|
## Usage
|
||||||
|
|
||||||
|
Send (forward) files to the Bot and follow the prompts.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Thanks
|
||||||
|
|
||||||
|
- [gotd](https://github.com/gotd/td)
|
||||||
|
- [TG-FileStreamBot](https://github.com/EverythingSuckz/TG-FileStreamBot)
|
||||||
|
- [gotgproto](https://github.com/celestix/gotgproto)
|
||||||
|
- All the dependencies
|
||||||
@@ -1,6 +1,9 @@
|
|||||||
package bootstrap
|
package bootstrap
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
|
||||||
"github.com/krau/SaveAny-Bot/bot"
|
"github.com/krau/SaveAny-Bot/bot"
|
||||||
"github.com/krau/SaveAny-Bot/common"
|
"github.com/krau/SaveAny-Bot/common"
|
||||||
"github.com/krau/SaveAny-Bot/config"
|
"github.com/krau/SaveAny-Bot/config"
|
||||||
@@ -10,12 +13,14 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func InitAll() {
|
func InitAll() {
|
||||||
config.Init()
|
if err := config.Init(); err != nil {
|
||||||
|
fmt.Println("加载配置文件失败: ", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
logger.InitLogger()
|
logger.InitLogger()
|
||||||
logger.L.Info("Running...")
|
logger.L.Info("正在启动 SaveAny-Bot...")
|
||||||
|
storage.LoadStorages()
|
||||||
common.Init()
|
common.Init()
|
||||||
storage.Init()
|
|
||||||
dao.Init()
|
dao.Init()
|
||||||
bot.Init()
|
bot.Init()
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -27,7 +27,7 @@ func newProxyDialer(proxyUrl string) (proxy.Dialer, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func Init() {
|
func Init() {
|
||||||
logger.L.Info("Initializing client...")
|
logger.L.Info("初始化 Telegram 客户端...")
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
|
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
resultChan := make(chan struct {
|
resultChan := make(chan struct {
|
||||||
@@ -76,7 +76,6 @@ func Init() {
|
|||||||
{Command: "silent", Description: "开启/关闭静默模式"},
|
{Command: "silent", Description: "开启/关闭静默模式"},
|
||||||
{Command: "storage", Description: "设置默认存储端"},
|
{Command: "storage", Description: "设置默认存储端"},
|
||||||
{Command: "save", Description: "保存所回复的文件"},
|
{Command: "save", Description: "保存所回复的文件"},
|
||||||
{Command: "path", Description: "更改保存路径配置"},
|
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
resultChan <- struct {
|
resultChan <- struct {
|
||||||
@@ -87,15 +86,15 @@ func Init() {
|
|||||||
|
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
logger.L.Fatal("Failed to initialize client: timeout")
|
logger.L.Fatal("初始化客户端失败: 超时")
|
||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
case result := <-resultChan:
|
case result := <-resultChan:
|
||||||
if result.err != nil {
|
if result.err != nil {
|
||||||
logger.L.Fatalf("Failed to initialize client: %s", result.err)
|
logger.L.Fatalf("初始化客户端失败: %s", result.err)
|
||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
Client = result.client
|
Client = result.client
|
||||||
RegisterHandlers(Client.Dispatcher)
|
RegisterHandlers(Client.Dispatcher)
|
||||||
logger.L.Info("Client initialized")
|
logger.L.Info("客户端初始化完成")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ import (
|
|||||||
"github.com/gotd/td/tg"
|
"github.com/gotd/td/tg"
|
||||||
"github.com/krau/SaveAny-Bot/dao"
|
"github.com/krau/SaveAny-Bot/dao"
|
||||||
"github.com/krau/SaveAny-Bot/logger"
|
"github.com/krau/SaveAny-Bot/logger"
|
||||||
|
"github.com/krau/SaveAny-Bot/storage"
|
||||||
"github.com/krau/SaveAny-Bot/types"
|
"github.com/krau/SaveAny-Bot/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -31,41 +32,49 @@ func handleLinkMessage(ctx *ext.Context, update *ext.Update) error {
|
|||||||
}
|
}
|
||||||
messageID, err := strconv.Atoi(strSlice[2])
|
messageID, err := strconv.Atoi(strSlice[2])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.L.Errorf("Failed to parse message ID: %s", err)
|
logger.L.Errorf("解析消息 ID 失败: %s", err)
|
||||||
ctx.Reply(update, ext.ReplyTextString("Failed to parse message ID"), nil)
|
ctx.Reply(update, ext.ReplyTextString("无法解析消息 ID"), nil)
|
||||||
return dispatcher.EndGroups
|
return dispatcher.EndGroups
|
||||||
}
|
}
|
||||||
chatUsername := strSlice[1]
|
chatUsername := strSlice[1]
|
||||||
linkChat, err := ctx.ResolveUsername(chatUsername)
|
linkChat, err := ctx.ResolveUsername(chatUsername)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.L.Errorf("Failed to resolve chat ID: %s", err)
|
logger.L.Errorf("解析 Chat ID 失败: %s", err)
|
||||||
ctx.Reply(update, ext.ReplyTextString("Failed to resolve chat ID"), nil)
|
ctx.Reply(update, ext.ReplyTextString("无法解析 Chat ID"), nil)
|
||||||
return dispatcher.EndGroups
|
return dispatcher.EndGroups
|
||||||
}
|
}
|
||||||
if linkChat == nil {
|
if linkChat == nil {
|
||||||
logger.L.Errorf("Cannot find chat: %s", chatUsername)
|
logger.L.Errorf("无法找到聊天: %s", chatUsername)
|
||||||
ctx.Reply(update, ext.ReplyTextString("Cannot find chat"), nil)
|
ctx.Reply(update, ext.ReplyTextString("无法找到聊天"), nil)
|
||||||
return dispatcher.EndGroups
|
return dispatcher.EndGroups
|
||||||
}
|
}
|
||||||
user, err := dao.GetUserByUserID(update.GetUserChat().GetID())
|
user, err := dao.GetUserByChatID(update.GetUserChat().GetID())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.L.Errorf("Failed to get user: %s", err)
|
logger.L.Errorf("获取用户失败: %s", err)
|
||||||
|
ctx.Reply(update, ext.ReplyTextString("获取用户失败"), nil)
|
||||||
|
return dispatcher.EndGroups
|
||||||
|
}
|
||||||
|
storages := storage.GetUserStorages(user.ChatID)
|
||||||
|
|
||||||
|
if len(storages) == 0 {
|
||||||
|
ctx.Reply(update, ext.ReplyTextString("无可用的存储"), nil)
|
||||||
return dispatcher.EndGroups
|
return dispatcher.EndGroups
|
||||||
}
|
}
|
||||||
replied, err := ctx.Reply(update, ext.ReplyTextString("正在获取文件..."), nil)
|
replied, err := ctx.Reply(update, ext.ReplyTextString("正在获取文件..."), nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.L.Errorf("Failed to reply: %s", err)
|
logger.L.Errorf("回复失败: %s", err)
|
||||||
return dispatcher.EndGroups
|
return dispatcher.EndGroups
|
||||||
}
|
}
|
||||||
|
|
||||||
file, err := FileFromMessage(ctx, linkChat.GetID(), messageID, "")
|
file, err := FileFromMessage(ctx, linkChat.GetID(), messageID, "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.L.Errorf("Failed to get file from message: %s", err)
|
logger.L.Errorf("获取文件失败: %s", err)
|
||||||
ctx.Reply(update, ext.ReplyTextString("获取文件失败: "+err.Error()), nil)
|
ctx.Reply(update, ext.ReplyTextString("获取文件失败: "+err.Error()), nil)
|
||||||
return dispatcher.EndGroups
|
return dispatcher.EndGroups
|
||||||
}
|
}
|
||||||
|
// TODO: Better file name
|
||||||
if file.FileName == "" {
|
if file.FileName == "" {
|
||||||
logger.L.Warnf("Empty file name, use generated name")
|
logger.L.Warnf("文件名为空,使用生成的名称")
|
||||||
file.FileName = fmt.Sprintf("%d_%d_%s", linkChat.GetID(), messageID, file.Hash())
|
file.FileName = fmt.Sprintf("%d_%d_%s", linkChat.GetID(), messageID, file.Hash())
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -78,21 +87,21 @@ func handleLinkMessage(ctx *ext.Context, update *ext.Update) error {
|
|||||||
ReplyChatID: update.GetUserChat().GetID(),
|
ReplyChatID: update.GetUserChat().GetID(),
|
||||||
}
|
}
|
||||||
if err := dao.SaveReceivedFile(receivedFile); err != nil {
|
if err := dao.SaveReceivedFile(receivedFile); err != nil {
|
||||||
logger.L.Errorf("Failed to save received file: %s", err)
|
logger.L.Errorf("保存接收的文件失败: %s", err)
|
||||||
ctx.EditMessage(update.EffectiveChat().GetID(), &tg.MessagesEditMessageRequest{
|
ctx.EditMessage(update.EffectiveChat().GetID(), &tg.MessagesEditMessageRequest{
|
||||||
Message: "无法保存文件: " + err.Error(),
|
Message: "无法保存文件: " + err.Error(),
|
||||||
ID: replied.ID,
|
ID: replied.ID,
|
||||||
})
|
})
|
||||||
return dispatcher.EndGroups
|
return dispatcher.EndGroups
|
||||||
}
|
}
|
||||||
if !user.Silent {
|
if !user.Silent || user.DefaultStorage == "" {
|
||||||
return ProvideSelectMessage(ctx, update, file, int(linkChat.GetID()), messageID, replied.ID)
|
return ProvideSelectMessage(ctx, update, file, linkChat.GetID(), messageID, replied.ID)
|
||||||
}
|
}
|
||||||
return HandleSilentAddTask(ctx, update, user, &types.Task{
|
return HandleSilentAddTask(ctx, update, user, &types.Task{
|
||||||
Ctx: ctx,
|
Ctx: ctx,
|
||||||
Status: types.Pending,
|
Status: types.Pending,
|
||||||
File: file,
|
File: file,
|
||||||
Storage: types.StorageType(user.DefaultStorage),
|
StorageName: user.DefaultStorage,
|
||||||
FileChatID: linkChat.GetID(),
|
FileChatID: linkChat.GetID(),
|
||||||
FileMessageID: messageID,
|
FileMessageID: messageID,
|
||||||
ReplyMessageID: replied.ID,
|
ReplyMessageID: replied.ID,
|
||||||
|
|||||||
75
bot/handler_conversation.go
Normal file
75
bot/handler_conversation.go
Normal file
@@ -0,0 +1,75 @@
|
|||||||
|
package bot
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ConversationType string
|
||||||
|
|
||||||
|
type ConversationState struct {
|
||||||
|
sync.Mutex
|
||||||
|
conversationType ConversationType
|
||||||
|
InConversation bool
|
||||||
|
data map[ConversationType]map[string]interface{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ConversationState) Reset() {
|
||||||
|
c.Lock()
|
||||||
|
defer c.Unlock()
|
||||||
|
c.InConversation = false
|
||||||
|
c.conversationType = ""
|
||||||
|
c.data = make(map[ConversationType]map[string]interface{})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ConversationState) SetConversationType(t ConversationType) {
|
||||||
|
c.Lock()
|
||||||
|
defer c.Unlock()
|
||||||
|
c.conversationType = t
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ConversationState) GetData(key string) interface{} {
|
||||||
|
if c.data == nil || c.data[c.conversationType] == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return c.data[c.conversationType][key]
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ConversationState) SetData(key string, value interface{}) {
|
||||||
|
c.Lock()
|
||||||
|
defer c.Unlock()
|
||||||
|
if c.data == nil {
|
||||||
|
c.data = make(map[ConversationType]map[string]interface{})
|
||||||
|
}
|
||||||
|
if c.data[c.conversationType] == nil {
|
||||||
|
c.data[c.conversationType] = make(map[string]interface{})
|
||||||
|
}
|
||||||
|
c.data[c.conversationType][key] = value
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: Implement conversation handling
|
||||||
|
// var userConversationState = make(map[int64]*ConversationState)
|
||||||
|
|
||||||
|
// func handleConversation(ctx *ext.Context, update *ext.Update) error {
|
||||||
|
// userID := update.EffectiveUser().GetID()
|
||||||
|
// state, ok := userConversationState[userID]
|
||||||
|
// if !ok {
|
||||||
|
// return dispatcher.ContinueGroups
|
||||||
|
// }
|
||||||
|
// if update.EffectiveMessage.Text == "/cancel" {
|
||||||
|
// state.Reset()
|
||||||
|
// ctx.Reply(update, ext.ReplyTextString("已取消"), nil)
|
||||||
|
// return dispatcher.EndGroups
|
||||||
|
// }
|
||||||
|
// if !state.InConversation {
|
||||||
|
// return dispatcher.ContinueGroups
|
||||||
|
// }
|
||||||
|
// return handleConversationState(ctx, update, state)
|
||||||
|
// }
|
||||||
|
|
||||||
|
// func handleConversationState(ctx *ext.Context, update *ext.Update, state *ConversationState) error {
|
||||||
|
// switch state.conversationType {
|
||||||
|
// default:
|
||||||
|
// logger.L.Errorf("Unknown conversation type: %s", state.conversationType)
|
||||||
|
// }
|
||||||
|
// return dispatcher.EndGroups
|
||||||
|
// }
|
||||||
256
bot/handlers.go
256
bot/handlers.go
@@ -6,7 +6,6 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/duke-git/lancet/v2/slice"
|
"github.com/duke-git/lancet/v2/slice"
|
||||||
"github.com/gookit/goutil/maputil"
|
|
||||||
"github.com/gotd/td/telegram/message/entity"
|
"github.com/gotd/td/telegram/message/entity"
|
||||||
"github.com/gotd/td/telegram/message/styling"
|
"github.com/gotd/td/telegram/message/styling"
|
||||||
"github.com/gotd/td/tg"
|
"github.com/gotd/td/tg"
|
||||||
@@ -28,26 +27,26 @@ func RegisterHandlers(dispatcher dispatcher.Dispatcher) {
|
|||||||
dispatcher.AddHandler(handlers.NewCommand("start", start))
|
dispatcher.AddHandler(handlers.NewCommand("start", start))
|
||||||
dispatcher.AddHandler(handlers.NewCommand("help", help))
|
dispatcher.AddHandler(handlers.NewCommand("help", help))
|
||||||
dispatcher.AddHandler(handlers.NewCommand("silent", silent))
|
dispatcher.AddHandler(handlers.NewCommand("silent", silent))
|
||||||
dispatcher.AddHandler(handlers.NewCommand("storage", setDefaultStorage))
|
dispatcher.AddHandler(handlers.NewCommand("storage", storageCmd))
|
||||||
dispatcher.AddHandler(handlers.NewCommand("save", saveCmd))
|
dispatcher.AddHandler(handlers.NewCommand("save", saveCmd))
|
||||||
dispatcher.AddHandler(handlers.NewCommand("path", setPath))
|
|
||||||
linkRegexFilter, err := filters.Message.Regex(linkRegexString)
|
linkRegexFilter, err := filters.Message.Regex(linkRegexString)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.L.Panicf("Failed to create regex filter: %s", err)
|
logger.L.Panicf("创建正则表达式过滤器失败: %s", err)
|
||||||
}
|
}
|
||||||
dispatcher.AddHandler(handlers.NewMessage(linkRegexFilter, handleLinkMessage))
|
dispatcher.AddHandler(handlers.NewMessage(linkRegexFilter, handleLinkMessage))
|
||||||
dispatcher.AddHandler(handlers.NewCallbackQuery(filters.CallbackQuery.Prefix("add"), AddToQueue))
|
dispatcher.AddHandler(handlers.NewCallbackQuery(filters.CallbackQuery.Prefix("add"), AddToQueue))
|
||||||
|
dispatcher.AddHandler(handlers.NewCallbackQuery(filters.CallbackQuery.Prefix("set_default"), setDefaultStorage))
|
||||||
dispatcher.AddHandler(handlers.NewMessage(filters.Message.Media, handleFileMessage))
|
dispatcher.AddHandler(handlers.NewMessage(filters.Message.Media, handleFileMessage))
|
||||||
}
|
}
|
||||||
|
|
||||||
const noPermissionText string = `
|
const noPermissionText string = `
|
||||||
本 Bot 仅限个人使用.
|
您不在白名单中, 无法使用此 Bot.
|
||||||
您可以部署自己的实例: https://github.com/krau/SaveAny-Bot
|
您可以部署自己的实例: https://github.com/krau/SaveAny-Bot
|
||||||
`
|
`
|
||||||
|
|
||||||
func checkPermission(ctx *ext.Context, update *ext.Update) error {
|
func checkPermission(ctx *ext.Context, update *ext.Update) error {
|
||||||
userID := update.GetUserChat().GetID()
|
userID := update.GetUserChat().GetID()
|
||||||
if !slice.Contain(config.Cfg.Telegram.Admins, userID) {
|
if !slice.Contain(config.Cfg.GetUsersID(), userID) {
|
||||||
ctx.Reply(update, ext.ReplyTextString(noPermissionText), nil)
|
ctx.Reply(update, ext.ReplyTextString(noPermissionText), nil)
|
||||||
return dispatcher.EndGroups
|
return dispatcher.EndGroups
|
||||||
}
|
}
|
||||||
@@ -56,7 +55,7 @@ func checkPermission(ctx *ext.Context, update *ext.Update) error {
|
|||||||
|
|
||||||
func start(ctx *ext.Context, update *ext.Update) error {
|
func start(ctx *ext.Context, update *ext.Update) error {
|
||||||
if err := dao.CreateUser(update.GetUserChat().GetID()); err != nil {
|
if err := dao.CreateUser(update.GetUserChat().GetID()); err != nil {
|
||||||
logger.L.Errorf("Failed to create user: %s", err)
|
logger.L.Errorf("创建用户失败: %s", err)
|
||||||
return dispatcher.EndGroups
|
return dispatcher.EndGroups
|
||||||
}
|
}
|
||||||
return help(ctx, update)
|
return help(ctx, update)
|
||||||
@@ -67,10 +66,9 @@ Save Any Bot - 转存你的 Telegram 文件
|
|||||||
命令:
|
命令:
|
||||||
/start - 开始使用
|
/start - 开始使用
|
||||||
/help - 显示帮助
|
/help - 显示帮助
|
||||||
/silent - 静默模式
|
/silent - 开关静默模式
|
||||||
/storage - 设置默认存储位置
|
/storage - 设置默认存储位置
|
||||||
/save [自定义文件名] - 保存文件
|
/save [自定义文件名] - 保存文件
|
||||||
/path <存储类型> <路径> - 更改文件保存路径
|
|
||||||
|
|
||||||
静默模式: 开启后 Bot 直接保存到收到的文件到默认位置, 不再询问
|
静默模式: 开启后 Bot 直接保存到收到的文件到默认位置, 不再询问
|
||||||
|
|
||||||
@@ -85,58 +83,21 @@ func help(ctx *ext.Context, update *ext.Update) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func silent(ctx *ext.Context, update *ext.Update) error {
|
func silent(ctx *ext.Context, update *ext.Update) error {
|
||||||
user, err := dao.GetUserByUserID(update.GetUserChat().GetID())
|
user, err := dao.GetUserByChatID(update.GetUserChat().GetID())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.L.Errorf("Failed to get user: %s", err)
|
logger.L.Errorf("获取用户失败: %s", err)
|
||||||
return dispatcher.EndGroups
|
return dispatcher.EndGroups
|
||||||
}
|
}
|
||||||
user.Silent = !user.Silent
|
user.Silent = !user.Silent
|
||||||
if err := dao.UpdateUser(user); err != nil {
|
if err := dao.UpdateUser(user); err != nil {
|
||||||
logger.L.Errorf("Failed to update user: %s", err)
|
logger.L.Errorf("更新用户失败: %s", err)
|
||||||
|
ctx.Reply(update, ext.ReplyTextString("更新用户失败"), nil)
|
||||||
return dispatcher.EndGroups
|
return dispatcher.EndGroups
|
||||||
}
|
}
|
||||||
ctx.Reply(update, ext.ReplyTextString(fmt.Sprintf("已%s静默模式", map[bool]string{true: "开启", false: "关闭"}[user.Silent])), nil)
|
ctx.Reply(update, ext.ReplyTextString(fmt.Sprintf("已%s静默模式", map[bool]string{true: "开启", false: "关闭"}[user.Silent])), nil)
|
||||||
return dispatcher.EndGroups
|
return dispatcher.EndGroups
|
||||||
}
|
}
|
||||||
|
|
||||||
func setDefaultStorage(ctx *ext.Context, update *ext.Update) error {
|
|
||||||
if len(storage.Storages) == 0 {
|
|
||||||
ctx.Reply(update, ext.ReplyTextString("未配置存储"), nil)
|
|
||||||
return dispatcher.EndGroups
|
|
||||||
}
|
|
||||||
args := strings.Split(update.EffectiveMessage.Text, " ")
|
|
||||||
avaliableStorages := maputil.Keys(storage.Storages)
|
|
||||||
if len(args) < 2 {
|
|
||||||
text := []styling.StyledTextOption{
|
|
||||||
styling.Plain("请提供存储位置名称, 可用项:"),
|
|
||||||
}
|
|
||||||
for _, name := range avaliableStorages {
|
|
||||||
text = append(text, styling.Plain("\n"))
|
|
||||||
text = append(text, styling.Code(name))
|
|
||||||
}
|
|
||||||
text = append(text, styling.Plain("\n示例: /storage local"))
|
|
||||||
ctx.Reply(update, ext.ReplyTextStyledTextArray(text), nil)
|
|
||||||
return dispatcher.EndGroups
|
|
||||||
}
|
|
||||||
storageName := args[1]
|
|
||||||
if !slice.Contain(avaliableStorages, storageName) {
|
|
||||||
ctx.Reply(update, ext.ReplyTextString("存储位置不存在"), nil)
|
|
||||||
return dispatcher.EndGroups
|
|
||||||
}
|
|
||||||
user, err := dao.GetUserByUserID(update.GetUserChat().GetID())
|
|
||||||
if err != nil {
|
|
||||||
logger.L.Errorf("Failed to get user: %s", err)
|
|
||||||
return dispatcher.EndGroups
|
|
||||||
}
|
|
||||||
user.DefaultStorage = storageName
|
|
||||||
if err := dao.UpdateUser(user); err != nil {
|
|
||||||
logger.L.Errorf("Failed to update user: %s", err)
|
|
||||||
return dispatcher.EndGroups
|
|
||||||
}
|
|
||||||
ctx.Reply(update, ext.ReplyTextString(fmt.Sprintf("已设置默认存储位置为 %s", storageName)), nil)
|
|
||||||
return dispatcher.EndGroups
|
|
||||||
}
|
|
||||||
|
|
||||||
func saveCmd(ctx *ext.Context, update *ext.Update) error {
|
func saveCmd(ctx *ext.Context, update *ext.Update) error {
|
||||||
res, ok := update.EffectiveMessage.GetReplyTo()
|
res, ok := update.EffectiveMessage.GetReplyTo()
|
||||||
if !ok || res == nil {
|
if !ok || res == nil {
|
||||||
@@ -154,9 +115,23 @@ func saveCmd(ctx *ext.Context, update *ext.Update) error {
|
|||||||
return dispatcher.EndGroups
|
return dispatcher.EndGroups
|
||||||
}
|
}
|
||||||
|
|
||||||
|
user, err := dao.GetUserByChatID(update.GetUserChat().GetID())
|
||||||
|
if err != nil {
|
||||||
|
logger.L.Errorf("获取用户失败: %s", err)
|
||||||
|
ctx.Reply(update, ext.ReplyTextString("获取用户失败"), nil)
|
||||||
|
return dispatcher.EndGroups
|
||||||
|
}
|
||||||
|
|
||||||
|
storages := storage.GetUserStorages(user.ChatID)
|
||||||
|
|
||||||
|
if len(storages) == 0 {
|
||||||
|
ctx.Reply(update, ext.ReplyTextString("无可用的存储"), nil)
|
||||||
|
return dispatcher.EndGroups
|
||||||
|
}
|
||||||
|
|
||||||
msg, err := GetTGMessage(ctx, update.EffectiveChat().GetID(), replyToMsgID)
|
msg, err := GetTGMessage(ctx, update.EffectiveChat().GetID(), replyToMsgID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.L.Errorf("Failed to get message: %s", err)
|
logger.L.Errorf("获取消息失败: %s", err)
|
||||||
ctx.Reply(update, ext.ReplyTextString("无法获取消息"), nil)
|
ctx.Reply(update, ext.ReplyTextString("无法获取消息"), nil)
|
||||||
return dispatcher.EndGroups
|
return dispatcher.EndGroups
|
||||||
}
|
}
|
||||||
@@ -167,15 +142,9 @@ func saveCmd(ctx *ext.Context, update *ext.Update) error {
|
|||||||
return dispatcher.EndGroups
|
return dispatcher.EndGroups
|
||||||
}
|
}
|
||||||
|
|
||||||
user, err := dao.GetUserByUserID(update.GetUserChat().GetID())
|
|
||||||
if err != nil {
|
|
||||||
logger.L.Errorf("Failed to get user: %s", err)
|
|
||||||
return dispatcher.EndGroups
|
|
||||||
}
|
|
||||||
|
|
||||||
replied, err := ctx.Reply(update, ext.ReplyTextString("正在获取文件信息..."), nil)
|
replied, err := ctx.Reply(update, ext.ReplyTextString("正在获取文件信息..."), nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.L.Errorf("Failed to reply: %s", err)
|
logger.L.Errorf("回复失败: %s", err)
|
||||||
return dispatcher.EndGroups
|
return dispatcher.EndGroups
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -184,13 +153,15 @@ func saveCmd(ctx *ext.Context, update *ext.Update) error {
|
|||||||
|
|
||||||
file, err := FileFromMessage(ctx, update.EffectiveChat().GetID(), msg.ID, customFileName)
|
file, err := FileFromMessage(ctx, update.EffectiveChat().GetID(), msg.ID, customFileName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.L.Errorf("Failed to get file from message: %s", err)
|
logger.L.Errorf("获取文件失败: %s", err)
|
||||||
ctx.EditMessage(update.EffectiveChat().GetID(), &tg.MessagesEditMessageRequest{
|
ctx.EditMessage(update.EffectiveChat().GetID(), &tg.MessagesEditMessageRequest{
|
||||||
Message: fmt.Sprintf("获取文件失败: %s", err),
|
Message: fmt.Sprintf("获取文件失败: %s", err),
|
||||||
ID: replied.ID,
|
ID: replied.ID,
|
||||||
})
|
})
|
||||||
return dispatcher.EndGroups
|
return dispatcher.EndGroups
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO: better file name
|
||||||
if file.FileName == "" {
|
if file.FileName == "" {
|
||||||
file.FileName = fmt.Sprintf("%d_%d_%s", update.EffectiveChat().GetID(), replyToMsgID, file.Hash())
|
file.FileName = fmt.Sprintf("%d_%d_%s", update.EffectiveChat().GetID(), replyToMsgID, file.Hash())
|
||||||
}
|
}
|
||||||
@@ -204,72 +175,103 @@ func saveCmd(ctx *ext.Context, update *ext.Update) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if err := dao.SaveReceivedFile(receivedFile); err != nil {
|
if err := dao.SaveReceivedFile(receivedFile); err != nil {
|
||||||
logger.L.Errorf("Failed to save received file: %s", err)
|
logger.L.Errorf("保存接收的文件失败: %s", err)
|
||||||
if _, err := ctx.EditMessage(update.EffectiveChat().GetID(), &tg.MessagesEditMessageRequest{
|
if _, err := ctx.EditMessage(update.EffectiveChat().GetID(), &tg.MessagesEditMessageRequest{
|
||||||
Message: fmt.Sprintf("Failed to save received file: %s", err),
|
Message: fmt.Sprintf("保存接收的文件失败: %s", err),
|
||||||
ID: replied.ID,
|
ID: replied.ID,
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
logger.L.Errorf("Failed to edit message: %s", err)
|
logger.L.Errorf("编辑消息失败: %s", err)
|
||||||
}
|
}
|
||||||
return dispatcher.EndGroups
|
return dispatcher.EndGroups
|
||||||
}
|
}
|
||||||
if !user.Silent {
|
if !user.Silent || user.DefaultStorage == "" {
|
||||||
return ProvideSelectMessage(ctx, update, file, int(update.EffectiveChat().GetID()), msg.ID, replied.ID)
|
return ProvideSelectMessage(ctx, update, file, update.EffectiveChat().GetID(), msg.ID, replied.ID)
|
||||||
}
|
}
|
||||||
return HandleSilentAddTask(ctx, update, user, &types.Task{
|
return HandleSilentAddTask(ctx, update, user, &types.Task{
|
||||||
Ctx: ctx,
|
Ctx: ctx,
|
||||||
Status: types.Pending,
|
Status: types.Pending,
|
||||||
File: file,
|
File: file,
|
||||||
Storage: types.StorageType(user.DefaultStorage),
|
StorageName: user.DefaultStorage,
|
||||||
FileChatID: update.EffectiveChat().GetID(),
|
FileChatID: update.EffectiveChat().GetID(),
|
||||||
ReplyMessageID: replied.ID,
|
ReplyMessageID: replied.ID,
|
||||||
ReplyChatID: update.GetUserChat().GetID(),
|
ReplyChatID: update.GetUserChat().GetID(),
|
||||||
FileMessageID: msg.ID,
|
FileMessageID: msg.ID,
|
||||||
|
UserID: user.ChatID,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func setPath(ctx *ext.Context, update *ext.Update) error {
|
func storageCmd(ctx *ext.Context, update *ext.Update) error {
|
||||||
if len(storage.Storages) == 0 {
|
user, err := dao.GetUserByChatID(update.GetUserChat().GetID())
|
||||||
ctx.Reply(update, ext.ReplyTextString("未配置存储"), nil)
|
if err != nil {
|
||||||
|
logger.L.Errorf("获取用户失败: %s", err)
|
||||||
|
ctx.Reply(update, ext.ReplyTextString("获取用户失败"), nil)
|
||||||
return dispatcher.EndGroups
|
return dispatcher.EndGroups
|
||||||
}
|
}
|
||||||
if update.EffectiveMessage == nil {
|
storages := storage.GetUserStorages(user.ChatID)
|
||||||
logger.L.Error("No effective message")
|
if len(storages) == 0 {
|
||||||
|
ctx.Reply(update, ext.ReplyTextString("无可用的存储"), nil)
|
||||||
return dispatcher.EndGroups
|
return dispatcher.EndGroups
|
||||||
}
|
}
|
||||||
args := strings.Split(update.EffectiveMessage.Text, " ")
|
|
||||||
if len(args) < 3 {
|
ctx.Reply(update, ext.ReplyTextString("请选择要设为默认的存储位置"), &ext.ReplyOpts{
|
||||||
text := []styling.StyledTextOption{
|
Markup: getSetDefaultStorageMarkup(user.ChatID, storages),
|
||||||
styling.Plain("请提供存储位置名称和路径, 可用项:"),
|
})
|
||||||
}
|
|
||||||
for name := range storage.Storages {
|
return dispatcher.EndGroups
|
||||||
text = append(text, styling.Plain("\n"))
|
}
|
||||||
text = append(text, styling.Code(string(name)))
|
|
||||||
}
|
func setDefaultStorage(ctx *ext.Context, update *ext.Update) error {
|
||||||
text = append(text, styling.Plain("\n示例: /path local /path/to/save"))
|
args := strings.Split(string(update.CallbackQuery.Data), " ")
|
||||||
ctx.Reply(update, ext.ReplyTextStyledTextArray(text), nil)
|
userID, _ := strconv.Atoi(args[1])
|
||||||
|
storageNameHash := args[2]
|
||||||
|
if userID != int(update.CallbackQuery.GetUserID()) {
|
||||||
|
ctx.AnswerCallback(&tg.MessagesSetBotCallbackAnswerRequest{
|
||||||
|
QueryID: update.CallbackQuery.QueryID,
|
||||||
|
Alert: true,
|
||||||
|
Message: "你没有权限",
|
||||||
|
CacheTime: 5,
|
||||||
|
})
|
||||||
return dispatcher.EndGroups
|
return dispatcher.EndGroups
|
||||||
}
|
}
|
||||||
storageName := args[1]
|
storageName := storageHashName[storageNameHash]
|
||||||
if _, ok := storage.Storages[types.StorageType(storageName)]; !ok {
|
selectedStorage, err := storage.GetStorageByName(storageName)
|
||||||
ctx.Reply(update, ext.ReplyTextString("存储位置不存在"), nil)
|
|
||||||
|
if err != nil {
|
||||||
|
logger.L.Errorf("获取指定存储失败: %s", err)
|
||||||
|
ctx.AnswerCallback(&tg.MessagesSetBotCallbackAnswerRequest{
|
||||||
|
QueryID: update.CallbackQuery.QueryID,
|
||||||
|
Alert: true,
|
||||||
|
Message: "获取指定存储失败",
|
||||||
|
CacheTime: 5,
|
||||||
|
})
|
||||||
return dispatcher.EndGroups
|
return dispatcher.EndGroups
|
||||||
}
|
}
|
||||||
path := strings.Join(args[2:], " ")
|
user, err := dao.GetUserByChatID(int64(userID))
|
||||||
switch storageName {
|
if err != nil {
|
||||||
case "local":
|
logger.L.Errorf("Failed to get user: %s", err)
|
||||||
config.Set("storage.local.base_path", path)
|
ctx.AnswerCallback(&tg.MessagesSetBotCallbackAnswerRequest{
|
||||||
case "webdav":
|
QueryID: update.CallbackQuery.QueryID,
|
||||||
config.Set("storage.webdav.base_path", path)
|
Alert: true,
|
||||||
case "alist":
|
Message: "获取用户失败",
|
||||||
config.Set("storage.alist.base_path", path)
|
CacheTime: 5,
|
||||||
}
|
})
|
||||||
if err := config.ReloadConfig(); err != nil {
|
|
||||||
logger.L.Errorf("Failed to reload config: %s", err)
|
|
||||||
ctx.Reply(update, ext.ReplyTextString("设置失败: "+err.Error()), nil)
|
|
||||||
return dispatcher.EndGroups
|
return dispatcher.EndGroups
|
||||||
}
|
}
|
||||||
ctx.Reply(update, ext.ReplyTextString("设置成功"), nil)
|
user.DefaultStorage = storageName
|
||||||
|
if err := dao.UpdateUser(user); err != nil {
|
||||||
|
logger.L.Errorf("Failed to update user: %s", err)
|
||||||
|
ctx.AnswerCallback(&tg.MessagesSetBotCallbackAnswerRequest{
|
||||||
|
QueryID: update.CallbackQuery.QueryID,
|
||||||
|
Alert: true,
|
||||||
|
Message: "更新用户失败",
|
||||||
|
CacheTime: 5,
|
||||||
|
})
|
||||||
|
return dispatcher.EndGroups
|
||||||
|
}
|
||||||
|
ctx.EditMessage(update.EffectiveChat().GetID(), &tg.MessagesEditMessageRequest{
|
||||||
|
Message: fmt.Sprintf("已将 %s (%s) 设为默认存储位置", selectedStorage.Name(), selectedStorage.Type()),
|
||||||
|
ID: update.CallbackQuery.GetMsgID(),
|
||||||
|
})
|
||||||
return dispatcher.EndGroups
|
return dispatcher.EndGroups
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -283,21 +285,27 @@ func handleFileMessage(ctx *ext.Context, update *ext.Update) error {
|
|||||||
return dispatcher.EndGroups
|
return dispatcher.EndGroups
|
||||||
}
|
}
|
||||||
|
|
||||||
user, err := dao.GetUserByUserID(update.GetUserChat().GetID())
|
user, err := dao.GetUserByChatID(update.GetUserChat().GetID())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.L.Errorf("Failed to get user: %s", err)
|
logger.L.Errorf("获取用户失败: %s", err)
|
||||||
|
ctx.Reply(update, ext.ReplyTextString("获取用户失败"), nil)
|
||||||
|
return dispatcher.EndGroups
|
||||||
|
}
|
||||||
|
storages := storage.GetUserStorages(user.ChatID)
|
||||||
|
if len(storages) == 0 {
|
||||||
|
ctx.Reply(update, ext.ReplyTextString("无可用的存储"), nil)
|
||||||
return dispatcher.EndGroups
|
return dispatcher.EndGroups
|
||||||
}
|
}
|
||||||
|
|
||||||
msg, err := ctx.Reply(update, ext.ReplyTextString("正在获取文件信息..."), nil)
|
msg, err := ctx.Reply(update, ext.ReplyTextString("正在获取文件信息..."), nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.L.Errorf("Failed to reply: %s", err)
|
logger.L.Errorf("回复失败: %s", err)
|
||||||
return dispatcher.EndGroups
|
return dispatcher.EndGroups
|
||||||
}
|
}
|
||||||
media := update.EffectiveMessage.Media
|
media := update.EffectiveMessage.Media
|
||||||
file, err := FileFromMedia(media, "")
|
file, err := FileFromMedia(media, "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.L.Errorf("Failed to get file from media: %s", err)
|
logger.L.Errorf("获取文件失败: %s", err)
|
||||||
ctx.Reply(update, ext.ReplyTextString(fmt.Sprintf("获取文件失败: %s", err)), nil)
|
ctx.Reply(update, ext.ReplyTextString(fmt.Sprintf("获取文件失败: %s", err)), nil)
|
||||||
return dispatcher.EndGroups
|
return dispatcher.EndGroups
|
||||||
}
|
}
|
||||||
@@ -313,33 +321,34 @@ func handleFileMessage(ctx *ext.Context, update *ext.Update) error {
|
|||||||
ReplyMessageID: msg.ID,
|
ReplyMessageID: msg.ID,
|
||||||
ReplyChatID: update.GetUserChat().GetID(),
|
ReplyChatID: update.GetUserChat().GetID(),
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
logger.L.Errorf("Failed to add received file: %s", err)
|
logger.L.Errorf("添加接收的文件失败: %s", err)
|
||||||
if _, err := ctx.EditMessage(update.EffectiveChat().GetID(), &tg.MessagesEditMessageRequest{
|
if _, err := ctx.EditMessage(update.EffectiveChat().GetID(), &tg.MessagesEditMessageRequest{
|
||||||
Message: fmt.Sprintf("Failed to add received file: %s", err),
|
Message: fmt.Sprintf("添加接收的文件失败: %s", err),
|
||||||
ID: msg.ID,
|
ID: msg.ID,
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
logger.L.Errorf("Failed to edit message: %s", err)
|
logger.L.Errorf("编辑消息失败: %s", err)
|
||||||
}
|
}
|
||||||
return dispatcher.EndGroups
|
return dispatcher.EndGroups
|
||||||
}
|
}
|
||||||
|
|
||||||
if !user.Silent {
|
if !user.Silent || user.DefaultStorage == "" {
|
||||||
return ProvideSelectMessage(ctx, update, file, int(update.EffectiveChat().GetID()), update.EffectiveMessage.ID, msg.ID)
|
return ProvideSelectMessage(ctx, update, file, update.EffectiveChat().GetID(), update.EffectiveMessage.ID, msg.ID)
|
||||||
}
|
}
|
||||||
return HandleSilentAddTask(ctx, update, user, &types.Task{
|
return HandleSilentAddTask(ctx, update, user, &types.Task{
|
||||||
Ctx: ctx,
|
Ctx: ctx,
|
||||||
Status: types.Pending,
|
Status: types.Pending,
|
||||||
File: file,
|
File: file,
|
||||||
Storage: types.StorageType(user.DefaultStorage),
|
StorageName: user.DefaultStorage,
|
||||||
FileChatID: update.EffectiveChat().GetID(),
|
FileChatID: update.EffectiveChat().GetID(),
|
||||||
ReplyMessageID: msg.ID,
|
ReplyMessageID: msg.ID,
|
||||||
ReplyChatID: update.GetUserChat().GetID(),
|
ReplyChatID: update.GetUserChat().GetID(),
|
||||||
FileMessageID: update.EffectiveMessage.ID,
|
FileMessageID: update.EffectiveMessage.ID,
|
||||||
|
UserID: user.ChatID,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func AddToQueue(ctx *ext.Context, update *ext.Update) error {
|
func AddToQueue(ctx *ext.Context, update *ext.Update) error {
|
||||||
if !slice.Contain(config.Cfg.Telegram.Admins, update.CallbackQuery.UserID) {
|
if !slice.Contain(config.Cfg.GetUsersID(), update.CallbackQuery.UserID) {
|
||||||
ctx.AnswerCallback(&tg.MessagesSetBotCallbackAnswerRequest{
|
ctx.AnswerCallback(&tg.MessagesSetBotCallbackAnswerRequest{
|
||||||
QueryID: update.CallbackQuery.QueryID,
|
QueryID: update.CallbackQuery.QueryID,
|
||||||
Alert: true,
|
Alert: true,
|
||||||
@@ -349,13 +358,24 @@ func AddToQueue(ctx *ext.Context, update *ext.Update) error {
|
|||||||
return dispatcher.EndGroups
|
return dispatcher.EndGroups
|
||||||
}
|
}
|
||||||
args := strings.Split(string(update.CallbackQuery.Data), " ")
|
args := strings.Split(string(update.CallbackQuery.Data), " ")
|
||||||
chatID, _ := strconv.Atoi(args[1])
|
fileChatID, _ := strconv.Atoi(args[1])
|
||||||
messageID, _ := strconv.Atoi(args[2])
|
fileMessageID, _ := strconv.Atoi(args[2])
|
||||||
storageName := args[3]
|
storageNameHash := args[3]
|
||||||
logger.L.Tracef("Got add to queue: chatID: %d, messageID: %d, storage: %s", chatID, messageID, storageName)
|
storageName := storageHashName[storageNameHash]
|
||||||
record, err := dao.GetReceivedFileByChatAndMessageID(int64(chatID), messageID)
|
if storageName == "" {
|
||||||
|
logger.L.Errorf("未知存储位置哈希: %d", storageNameHash)
|
||||||
|
ctx.AnswerCallback(&tg.MessagesSetBotCallbackAnswerRequest{
|
||||||
|
QueryID: update.CallbackQuery.QueryID,
|
||||||
|
Alert: true,
|
||||||
|
Message: "未知存储位置",
|
||||||
|
CacheTime: 5,
|
||||||
|
})
|
||||||
|
return dispatcher.EndGroups
|
||||||
|
}
|
||||||
|
logger.L.Tracef("Got add to queue: chatID: %d, messageID: %d, storage: %s", fileChatID, fileMessageID, storageName)
|
||||||
|
record, err := dao.GetReceivedFileByChatAndMessageID(int64(fileChatID), fileMessageID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.L.Errorf("Failed to get received file: %s", err)
|
logger.L.Errorf("获取记录失败: %s", err)
|
||||||
ctx.AnswerCallback(&tg.MessagesSetBotCallbackAnswerRequest{
|
ctx.AnswerCallback(&tg.MessagesSetBotCallbackAnswerRequest{
|
||||||
QueryID: update.CallbackQuery.QueryID,
|
QueryID: update.CallbackQuery.QueryID,
|
||||||
Alert: true,
|
Alert: true,
|
||||||
@@ -367,13 +387,12 @@ func AddToQueue(ctx *ext.Context, update *ext.Update) error {
|
|||||||
if update.CallbackQuery.MsgID != record.ReplyMessageID {
|
if update.CallbackQuery.MsgID != record.ReplyMessageID {
|
||||||
record.ReplyMessageID = update.CallbackQuery.MsgID
|
record.ReplyMessageID = update.CallbackQuery.MsgID
|
||||||
if err := dao.SaveReceivedFile(record); err != nil {
|
if err := dao.SaveReceivedFile(record); err != nil {
|
||||||
logger.L.Errorf("Failed to update received file: %s", err)
|
logger.L.Errorf("更新接收的文件失败: %s", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
file, err := FileFromMessage(ctx, record.ChatID, record.MessageID, record.FileName)
|
file, err := FileFromMessage(ctx, record.ChatID, record.MessageID, record.FileName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.L.Errorf("Failed to get file from message: %s", err)
|
logger.L.Errorf("获取消息中的文件失败: %s", err)
|
||||||
ctx.AnswerCallback(&tg.MessagesSetBotCallbackAnswerRequest{
|
ctx.AnswerCallback(&tg.MessagesSetBotCallbackAnswerRequest{
|
||||||
QueryID: update.CallbackQuery.QueryID,
|
QueryID: update.CallbackQuery.QueryID,
|
||||||
Alert: true,
|
Alert: true,
|
||||||
@@ -387,11 +406,12 @@ func AddToQueue(ctx *ext.Context, update *ext.Update) error {
|
|||||||
Ctx: ctx,
|
Ctx: ctx,
|
||||||
Status: types.Pending,
|
Status: types.Pending,
|
||||||
File: file,
|
File: file,
|
||||||
Storage: types.StorageType(storageName),
|
StorageName: storageName,
|
||||||
FileChatID: record.ChatID,
|
FileChatID: record.ChatID,
|
||||||
ReplyMessageID: record.ReplyMessageID,
|
ReplyMessageID: record.ReplyMessageID,
|
||||||
FileMessageID: record.MessageID,
|
FileMessageID: record.MessageID,
|
||||||
ReplyChatID: record.ReplyChatID,
|
ReplyChatID: record.ReplyChatID,
|
||||||
|
UserID: update.EffectiveUser().GetID(),
|
||||||
})
|
})
|
||||||
|
|
||||||
entityBuilder := entity.Builder{}
|
entityBuilder := entity.Builder{}
|
||||||
|
|||||||
101
bot/utils.go
101
bot/utils.go
@@ -11,6 +11,7 @@ import (
|
|||||||
"github.com/gotd/td/telegram/message/styling"
|
"github.com/gotd/td/telegram/message/styling"
|
||||||
"github.com/gotd/td/tg"
|
"github.com/gotd/td/tg"
|
||||||
"github.com/krau/SaveAny-Bot/common"
|
"github.com/krau/SaveAny-Bot/common"
|
||||||
|
"github.com/krau/SaveAny-Bot/dao"
|
||||||
"github.com/krau/SaveAny-Bot/logger"
|
"github.com/krau/SaveAny-Bot/logger"
|
||||||
"github.com/krau/SaveAny-Bot/queue"
|
"github.com/krau/SaveAny-Bot/queue"
|
||||||
"github.com/krau/SaveAny-Bot/storage"
|
"github.com/krau/SaveAny-Bot/storage"
|
||||||
@@ -22,6 +23,7 @@ var (
|
|||||||
ErrEmptyPhoto = errors.New("photo is empty")
|
ErrEmptyPhoto = errors.New("photo is empty")
|
||||||
ErrEmptyPhotoSize = errors.New("photo size is empty")
|
ErrEmptyPhotoSize = errors.New("photo size is empty")
|
||||||
ErrEmptyPhotoSizes = errors.New("photo size slice is empty")
|
ErrEmptyPhotoSizes = errors.New("photo size slice is empty")
|
||||||
|
ErrNoStorages = errors.New("no available storage")
|
||||||
)
|
)
|
||||||
|
|
||||||
func supportedMediaFilter(m *tg.Message) (bool, error) {
|
func supportedMediaFilter(m *tg.Message) (bool, error) {
|
||||||
@@ -38,49 +40,52 @@ func supportedMediaFilter(m *tg.Message) (bool, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
var StorageDisplayNames = map[string]string{
|
// for callback data
|
||||||
"all": "全部",
|
var storageHashName = map[string]string{}
|
||||||
"local": "服务器磁盘",
|
|
||||||
"alist": "Alist",
|
|
||||||
"webdav": "WebDAV",
|
|
||||||
}
|
|
||||||
|
|
||||||
func getAddTaskMarkup(chatID, messageID int) *tg.ReplyInlineMarkup {
|
func getSelectStorageMarkup(userChatID int64, fileChatID, fileMessageID int) (*tg.ReplyInlineMarkup, error) {
|
||||||
storageButtons := make([]tg.KeyboardButtonClass, 0)
|
user, err := dao.GetUserByChatID(userChatID)
|
||||||
for _, name := range storage.StorageKeys {
|
if err != nil {
|
||||||
storageButtons = append(storageButtons, &tg.KeyboardButtonCallback{
|
return nil, err
|
||||||
Text: StorageDisplayNames[string(name)],
|
}
|
||||||
Data: []byte(fmt.Sprintf("add %d %d %s", chatID, messageID, name)),
|
storages := storage.GetUserStorages(user.ChatID)
|
||||||
|
|
||||||
|
buttons := make([]tg.KeyboardButtonClass, 0)
|
||||||
|
for _, storage := range storages {
|
||||||
|
nameHash := common.HashString(storage.Name())
|
||||||
|
storageHashName[nameHash] = storage.Name()
|
||||||
|
buttons = append(buttons, &tg.KeyboardButtonCallback{
|
||||||
|
Text: storage.Name(),
|
||||||
|
Data: []byte(fmt.Sprintf("add %d %d %s", fileChatID, fileMessageID, nameHash)),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
markup := &tg.ReplyInlineMarkup{}
|
||||||
|
for i := 0; i < len(buttons); i += 3 {
|
||||||
|
row := tg.KeyboardButtonRow{}
|
||||||
|
row.Buttons = buttons[i:min(i+3, len(buttons))]
|
||||||
|
markup.Rows = append(markup.Rows, row)
|
||||||
|
}
|
||||||
|
return markup, nil
|
||||||
|
}
|
||||||
|
|
||||||
if len(storageButtons) < 1 {
|
func getSetDefaultStorageMarkup(userChatID int64, storages []storage.Storage) *tg.ReplyInlineMarkup {
|
||||||
return nil
|
buttons := make([]tg.KeyboardButtonClass, 0)
|
||||||
|
for _, storage := range storages {
|
||||||
|
nameHash := common.HashString(storage.Name())
|
||||||
|
storageHashName[nameHash] = storage.Name()
|
||||||
|
buttons = append(buttons, &tg.KeyboardButtonCallback{
|
||||||
|
Text: storage.Name(),
|
||||||
|
Data: []byte(fmt.Sprintf("set_default %d %s", userChatID, nameHash)),
|
||||||
|
})
|
||||||
}
|
}
|
||||||
if len(storageButtons) == 1 {
|
markup := &tg.ReplyInlineMarkup{}
|
||||||
return &tg.ReplyInlineMarkup{
|
for i := 0; i < len(buttons); i += 3 {
|
||||||
Rows: []tg.KeyboardButtonRow{
|
row := tg.KeyboardButtonRow{}
|
||||||
{
|
row.Buttons = buttons[i:min(i+3, len(buttons))]
|
||||||
Buttons: storageButtons,
|
markup.Rows = append(markup.Rows, row)
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return &tg.ReplyInlineMarkup{
|
|
||||||
Rows: []tg.KeyboardButtonRow{
|
|
||||||
{
|
|
||||||
Buttons: storageButtons,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Buttons: []tg.KeyboardButtonClass{
|
|
||||||
&tg.KeyboardButtonCallback{
|
|
||||||
Text: "全部",
|
|
||||||
Data: []byte(fmt.Sprintf("add %d %d all", chatID, messageID)),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
return markup
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func FileFromMedia(media tg.MessageMediaClass, customFileName string) (*types.File, error) {
|
func FileFromMedia(media tg.MessageMediaClass, customFileName string) (*types.File, error) {
|
||||||
@@ -181,7 +186,7 @@ func GetTGMessage(ctx *ext.Context, chatId int64, messageID int) (*tg.Message, e
|
|||||||
return tgMessage, nil
|
return tgMessage, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func ProvideSelectMessage(ctx *ext.Context, update *ext.Update, file *types.File, chatID int, fileMsgID, toEditMsgID int) error {
|
func ProvideSelectMessage(ctx *ext.Context, update *ext.Update, file *types.File, chatID int64, fileMsgID, toEditMsgID int) error {
|
||||||
entityBuilder := entity.Builder{}
|
entityBuilder := entity.Builder{}
|
||||||
var entities []tg.MessageEntityClass
|
var entities []tg.MessageEntityClass
|
||||||
text := fmt.Sprintf("文件名: %s\n请选择存储位置", file.FileName)
|
text := fmt.Sprintf("文件名: %s\n请选择存储位置", file.FileName)
|
||||||
@@ -194,10 +199,26 @@ func ProvideSelectMessage(ctx *ext.Context, update *ext.Update, file *types.File
|
|||||||
} else {
|
} else {
|
||||||
text, entities = entityBuilder.Complete()
|
text, entities = entityBuilder.Complete()
|
||||||
}
|
}
|
||||||
_, err := ctx.EditMessage(update.EffectiveChat().GetID(), &tg.MessagesEditMessageRequest{
|
markup, err := getSelectStorageMarkup(update.EffectiveUser().GetID(), int(chatID), fileMsgID)
|
||||||
|
if errors.Is(err, ErrNoStorages) {
|
||||||
|
logger.L.Errorf("Failed to get select storage markup: %s", err)
|
||||||
|
ctx.EditMessage(update.EffectiveChat().GetID(), &tg.MessagesEditMessageRequest{
|
||||||
|
Message: "无可用存储",
|
||||||
|
ID: toEditMsgID,
|
||||||
|
})
|
||||||
|
return dispatcher.EndGroups
|
||||||
|
} else if err != nil {
|
||||||
|
logger.L.Errorf("Failed to get select storage markup: %s", err)
|
||||||
|
ctx.EditMessage(update.EffectiveChat().GetID(), &tg.MessagesEditMessageRequest{
|
||||||
|
Message: "无法获取存储",
|
||||||
|
ID: toEditMsgID,
|
||||||
|
})
|
||||||
|
return dispatcher.EndGroups
|
||||||
|
}
|
||||||
|
_, err = ctx.EditMessage(update.EffectiveChat().GetID(), &tg.MessagesEditMessageRequest{
|
||||||
Message: text,
|
Message: text,
|
||||||
Entities: entities,
|
Entities: entities,
|
||||||
ReplyMarkup: getAddTaskMarkup(chatID, fileMsgID),
|
ReplyMarkup: markup,
|
||||||
ID: toEditMsgID,
|
ID: toEditMsgID,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
12
common/utils.go
Normal file
12
common/utils.go
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
package common
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/md5"
|
||||||
|
"encoding/hex"
|
||||||
|
)
|
||||||
|
|
||||||
|
func HashString(s string) string {
|
||||||
|
hash := md5.New()
|
||||||
|
hash.Write([]byte(s))
|
||||||
|
return hex.EncodeToString(hash.Sum(nil))
|
||||||
|
}
|
||||||
@@ -4,11 +4,9 @@ retry = 3 # 下载失败重试次数
|
|||||||
[telegram]
|
[telegram]
|
||||||
# Bot Token
|
# Bot Token
|
||||||
token = ""
|
token = ""
|
||||||
# 允许使用的用户 id 列表
|
|
||||||
admins = [777000]
|
|
||||||
# Telegram API 配置, 若不配置也可运行, 将使用默认的 API ID 和 API HASH
|
# Telegram API 配置, 若不配置也可运行, 将使用默认的 API ID 和 API HASH
|
||||||
# 推荐使用自己的 API ID 和 API HASH (https://my.telegram.org)
|
# 推荐使用自己的 API ID 和 API HASH (https://my.telegram.org)
|
||||||
# app_id = 123456
|
# app_id = 123456
|
||||||
# app_hash = "0123456789abcdef0123456789abcdef"
|
# app_hash = "0123456789abcdef0123456789abcdef"
|
||||||
|
|
||||||
[telegram.proxy]
|
[telegram.proxy]
|
||||||
@@ -17,37 +15,70 @@ enable = false
|
|||||||
url = "socks5://127.0.0.1:7890"
|
url = "socks5://127.0.0.1:7890"
|
||||||
|
|
||||||
|
|
||||||
[storage]
|
# 存储配置列表
|
||||||
[storage.alist] # Alist
|
[[storages]]
|
||||||
|
# 标识名, 需要唯一
|
||||||
|
name = "本机1"
|
||||||
|
# 存储类型, 目前可用: local , alist , webdav
|
||||||
|
type = "local"
|
||||||
enable = true
|
enable = true
|
||||||
base_path = "/telegram" # 保存路径
|
base_path = "./downloads"
|
||||||
username = "admin" # 用户名
|
|
||||||
password = "password" # 密码
|
|
||||||
url = "https://alist.com" # Alist 地址
|
|
||||||
token_exp = 86400 # token 过期时间, 单位: 秒
|
|
||||||
# 可直接使用 token 授权, 此时不能自动刷新登录信息
|
|
||||||
# 配置 token 后, username , password , token_exp 将被忽略
|
|
||||||
token = "jwt_token"
|
|
||||||
|
|
||||||
[storage.local] # 本地磁盘
|
[[storages]]
|
||||||
|
name = "本机2"
|
||||||
|
type = "local"
|
||||||
enable = true
|
enable = true
|
||||||
base_path = "downloads/" # 保存路径
|
base_path = "./downloads/2"
|
||||||
|
|
||||||
[storage.webdav] # WebDav
|
[[storages]]
|
||||||
|
name = "MyAlist"
|
||||||
|
type = "alist"
|
||||||
enable = true
|
enable = true
|
||||||
base_path = "/telegram"
|
base_path = '/'
|
||||||
username = "admin"
|
url = 'https://alist.com'
|
||||||
password = "password"
|
username = 'admin'
|
||||||
url = "https://alist.com/dav"
|
password = 'password'
|
||||||
|
token_exp = 86400
|
||||||
|
# alist 可直接使用 token 登录, 此时 username, password, token_exp 将被忽略
|
||||||
|
# 请自行在 alist 侧配置合理的 token 过期时间
|
||||||
|
# token = ""
|
||||||
|
|
||||||
|
|
||||||
[log]
|
[[storages]]
|
||||||
# 日志等级
|
name = "MyWebdav"
|
||||||
level = "DEBUG"
|
type = "webdav"
|
||||||
|
base_path = '/path/telegram'
|
||||||
|
enable = true
|
||||||
|
url = 'https://example.com/dav'
|
||||||
|
username = 'username'
|
||||||
|
password = 'password'
|
||||||
|
|
||||||
[temp]
|
|
||||||
base_path = "cache/" # 下载文件临时目录, 请不要在此目录下存放任何其他文件
|
|
||||||
cache_ttl = 30 # 临时文件保存时间, 单位: 秒
|
|
||||||
|
|
||||||
[db]
|
# 用户列表
|
||||||
path = "data/data.db" # 数据库文件路径
|
[[users]]
|
||||||
|
# user id
|
||||||
|
id = 123456
|
||||||
|
# 存储名称过滤列表
|
||||||
|
storages = ["MyAlist", "本机1"]
|
||||||
|
# 开启黑名单模式, 过滤列表中的存储将无法使用, 默认为白名单模式
|
||||||
|
blacklist = false
|
||||||
|
|
||||||
|
[[users]]
|
||||||
|
id = 114514
|
||||||
|
# 将列表留空并开启黑名单模式以允许使用所有存储
|
||||||
|
storages = []
|
||||||
|
blacklist = true
|
||||||
|
|
||||||
|
|
||||||
|
# [log]
|
||||||
|
# # 日志等级
|
||||||
|
# level = "DEBUG"
|
||||||
|
|
||||||
|
# [temp]
|
||||||
|
# # 下载文件临时目录, 请不要在此目录下存放任何其他文件
|
||||||
|
# base_path = "cache/"
|
||||||
|
# # 临时文件保存时间, 单位: 秒
|
||||||
|
# cache_ttl = 30
|
||||||
|
|
||||||
|
# [db]
|
||||||
|
# path = "data/data.db" # 数据库文件路径
|
||||||
|
|||||||
95
config/deprecated.go
Normal file
95
config/deprecated.go
Normal file
@@ -0,0 +1,95 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strconv"
|
||||||
|
|
||||||
|
"github.com/krau/SaveAny-Bot/types"
|
||||||
|
"gorm.io/datatypes"
|
||||||
|
)
|
||||||
|
|
||||||
|
// for compatibility
|
||||||
|
type deprecatedStorageConfig struct {
|
||||||
|
Alist alistConfig `toml:"alist" mapstructure:"alist"`
|
||||||
|
Local localConfig `toml:"local" mapstructure:"local"`
|
||||||
|
Webdav webdavConfig `toml:"webdav" mapstructure:"webdav"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type alistConfig struct {
|
||||||
|
Enable bool `toml:"enable" mapstructure:"enable" json:"enable"`
|
||||||
|
URL string `toml:"url" mapstructure:"url" json:"url"`
|
||||||
|
Username string `toml:"username" mapstructure:"username" json:"username"`
|
||||||
|
Password string `toml:"password" mapstructure:"password" json:"password"`
|
||||||
|
Token string `toml:"token" mapstructure:"token" json:"token"`
|
||||||
|
BasePath string `toml:"base_path" mapstructure:"base_path" json:"base_path"`
|
||||||
|
TokenExp int64 `toml:"token_exp" mapstructure:"token_exp" json:"token_exp"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *alistConfig) ToJSON() datatypes.JSON {
|
||||||
|
tokenExp := strconv.FormatInt(a.TokenExp, 10)
|
||||||
|
return datatypes.JSON([]byte(`{"url":"` + a.URL + `","username":"` + a.Username + `","password":"` + a.Password + `","token":"` + a.Token + `","base_path":"` + a.BasePath + `","token_exp":` + tokenExp + `}`))
|
||||||
|
}
|
||||||
|
|
||||||
|
type localConfig struct {
|
||||||
|
Enable bool `toml:"enable" mapstructure:"enable" json:"enable"`
|
||||||
|
BasePath string `toml:"base_path" mapstructure:"base_path" json:"base_path"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *localConfig) ToJSON() datatypes.JSON {
|
||||||
|
return datatypes.JSON([]byte(`{"base_path":"` + l.BasePath + `"}`))
|
||||||
|
}
|
||||||
|
|
||||||
|
type webdavConfig struct {
|
||||||
|
Enable bool `toml:"enable" mapstructure:"enable" json:"enable"`
|
||||||
|
URL string `toml:"url" mapstructure:"url" json:"url"`
|
||||||
|
Username string `toml:"username" mapstructure:"username" json:"username"`
|
||||||
|
Password string `toml:"password" mapstructure:"password" json:"password"`
|
||||||
|
BasePath string `toml:"base_path" mapstructure:"base_path" json:"base_path"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *webdavConfig) ToJSON() datatypes.JSON {
|
||||||
|
return datatypes.JSON([]byte(`{"url":"` + w.URL + `","username":"` + w.Username + `","password":"` + w.Password + `","base_path":"` + w.BasePath + `"}`))
|
||||||
|
}
|
||||||
|
|
||||||
|
func transformDeprecatedStorageConfig() {
|
||||||
|
if Cfg.DeprecatedStorage.Alist.Enable {
|
||||||
|
alistStorage := &AlistStorageConfig{
|
||||||
|
NewStorageConfig: NewStorageConfig{
|
||||||
|
Name: "Alist",
|
||||||
|
Enable: true,
|
||||||
|
Type: string(types.StorageTypeAlist),
|
||||||
|
},
|
||||||
|
URL: Cfg.DeprecatedStorage.Alist.URL,
|
||||||
|
Username: Cfg.DeprecatedStorage.Alist.Username,
|
||||||
|
Password: Cfg.DeprecatedStorage.Alist.Password,
|
||||||
|
Token: Cfg.DeprecatedStorage.Alist.Token,
|
||||||
|
BasePath: Cfg.DeprecatedStorage.Alist.BasePath,
|
||||||
|
TokenExp: Cfg.DeprecatedStorage.Alist.TokenExp,
|
||||||
|
}
|
||||||
|
Cfg.Storages = append(Cfg.Storages, alistStorage)
|
||||||
|
}
|
||||||
|
if Cfg.DeprecatedStorage.Local.Enable {
|
||||||
|
localStorage := &LocalStorageConfig{
|
||||||
|
NewStorageConfig: NewStorageConfig{
|
||||||
|
Name: "Local",
|
||||||
|
Enable: true,
|
||||||
|
Type: string(types.StorageTypeLocal),
|
||||||
|
},
|
||||||
|
BasePath: Cfg.DeprecatedStorage.Local.BasePath,
|
||||||
|
}
|
||||||
|
Cfg.Storages = append(Cfg.Storages, localStorage)
|
||||||
|
}
|
||||||
|
if Cfg.DeprecatedStorage.Webdav.Enable {
|
||||||
|
webdavStorage := &WebdavStorageConfig{
|
||||||
|
NewStorageConfig: NewStorageConfig{
|
||||||
|
Name: "Webdav",
|
||||||
|
Enable: true,
|
||||||
|
Type: string(types.StorageTypeWebdav),
|
||||||
|
},
|
||||||
|
URL: Cfg.DeprecatedStorage.Webdav.URL,
|
||||||
|
Username: Cfg.DeprecatedStorage.Webdav.Username,
|
||||||
|
Password: Cfg.DeprecatedStorage.Webdav.Password,
|
||||||
|
BasePath: Cfg.DeprecatedStorage.Webdav.BasePath,
|
||||||
|
}
|
||||||
|
Cfg.Storages = append(Cfg.Storages, webdavStorage)
|
||||||
|
}
|
||||||
|
}
|
||||||
104
config/storage_factory.go
Normal file
104
config/storage_factory.go
Normal file
@@ -0,0 +1,104 @@
|
|||||||
|
// storage_config.go
|
||||||
|
|
||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/krau/SaveAny-Bot/types"
|
||||||
|
"github.com/mitchellh/mapstructure"
|
||||||
|
"github.com/spf13/viper"
|
||||||
|
)
|
||||||
|
|
||||||
|
type StorageConfig interface {
|
||||||
|
Validate() error
|
||||||
|
GetType() types.StorageType
|
||||||
|
GetName() string
|
||||||
|
}
|
||||||
|
|
||||||
|
// Base storage config
|
||||||
|
type NewStorageConfig struct {
|
||||||
|
Name string `toml:"name" mapstructure:"name" json:"name"`
|
||||||
|
Type string `toml:"type" mapstructure:"type" json:"type"`
|
||||||
|
Enable bool `toml:"enable" mapstructure:"enable" json:"enable"`
|
||||||
|
RawConfig map[string]interface{} `toml:"-" mapstructure:",remain"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type StorageConfigFactory func(cfg *NewStorageConfig) (StorageConfig, error)
|
||||||
|
|
||||||
|
var storageFactories = make(map[string]StorageConfigFactory)
|
||||||
|
|
||||||
|
func RegisterStorageFactory(storageType string, factory StorageConfigFactory) {
|
||||||
|
storageFactories[storageType] = factory
|
||||||
|
}
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
RegisterStorageFactory(string(types.StorageTypeLocal), newLocalStorageConfig)
|
||||||
|
RegisterStorageFactory(string(types.StorageTypeAlist), newAlistStorageConfig)
|
||||||
|
RegisterStorageFactory(string(types.StorageTypeWebdav), newWebdavStorageConfig)
|
||||||
|
}
|
||||||
|
|
||||||
|
func newLocalStorageConfig(cfg *NewStorageConfig) (StorageConfig, error) {
|
||||||
|
var localCfg LocalStorageConfig
|
||||||
|
localCfg.NewStorageConfig = *cfg
|
||||||
|
|
||||||
|
if err := mapstructure.Decode(cfg.RawConfig, &localCfg); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to decode local storage config: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &localCfg, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func newAlistStorageConfig(cfg *NewStorageConfig) (StorageConfig, error) {
|
||||||
|
var alistCfg AlistStorageConfig
|
||||||
|
alistCfg.NewStorageConfig = *cfg
|
||||||
|
|
||||||
|
if err := mapstructure.Decode(cfg.RawConfig, &alistCfg); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to decode alist storage config: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &alistCfg, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func newWebdavStorageConfig(cfg *NewStorageConfig) (StorageConfig, error) {
|
||||||
|
var webdavCfg WebdavStorageConfig
|
||||||
|
webdavCfg.NewStorageConfig = *cfg
|
||||||
|
|
||||||
|
if err := mapstructure.Decode(cfg.RawConfig, &webdavCfg); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to decode webdav storage config: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &webdavCfg, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func LoadStorageConfigs(v *viper.Viper) ([]StorageConfig, error) {
|
||||||
|
var baseConfigs []NewStorageConfig
|
||||||
|
if err := v.UnmarshalKey("storages", &baseConfigs); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to unmarshal storage configs: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var configs []StorageConfig
|
||||||
|
for _, baseCfg := range baseConfigs {
|
||||||
|
if !baseCfg.Enable {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
factory, ok := storageFactories[baseCfg.Type]
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("unsupported storage type: %s", baseCfg.Type)
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg, err := factory(&baseCfg)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create storage config for %s: %w", baseCfg.Name, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := cfg.Validate(); err != nil {
|
||||||
|
return nil, fmt.Errorf("invalid storage config for %s: %w", baseCfg.Name, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
configs = append(configs, cfg)
|
||||||
|
}
|
||||||
|
|
||||||
|
return configs, nil
|
||||||
|
}
|
||||||
106
config/storages.go
Normal file
106
config/storages.go
Normal file
@@ -0,0 +1,106 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/krau/SaveAny-Bot/types"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (c *Config) GetStoragesByType(storageType types.StorageType) []StorageConfig {
|
||||||
|
var storages []StorageConfig
|
||||||
|
for _, storage := range c.Storages {
|
||||||
|
if storage.GetType() == storageType {
|
||||||
|
storages = append(storages, storage)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return storages
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Config) GetStorageByName(name string) StorageConfig {
|
||||||
|
for _, storage := range c.Storages {
|
||||||
|
if storage.GetName() == name {
|
||||||
|
return storage
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type LocalStorageConfig struct {
|
||||||
|
NewStorageConfig
|
||||||
|
BasePath string `toml:"base_path" mapstructure:"base_path" json:"base_path"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *LocalStorageConfig) Validate() error {
|
||||||
|
if l.BasePath == "" {
|
||||||
|
return fmt.Errorf("path is required for local storage")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *LocalStorageConfig) GetType() types.StorageType {
|
||||||
|
return types.StorageTypeLocal
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *LocalStorageConfig) GetName() string {
|
||||||
|
return l.Name
|
||||||
|
}
|
||||||
|
|
||||||
|
type AlistStorageConfig struct {
|
||||||
|
NewStorageConfig
|
||||||
|
URL string `toml:"url" mapstructure:"url" json:"url"`
|
||||||
|
Username string `toml:"username" mapstructure:"username" json:"username"`
|
||||||
|
Password string `toml:"password" mapstructure:"password" json:"password"`
|
||||||
|
Token string `toml:"token" mapstructure:"token" json:"token"`
|
||||||
|
BasePath string `toml:"base_path" mapstructure:"base_path" json:"base_path"`
|
||||||
|
TokenExp int64 `toml:"token_exp" mapstructure:"token_exp" json:"token_exp"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *AlistStorageConfig) Validate() error {
|
||||||
|
if a.URL == "" {
|
||||||
|
return fmt.Errorf("url is required for alist storage")
|
||||||
|
}
|
||||||
|
if a.Token == "" && (a.Username == "" || a.Password == "") {
|
||||||
|
return fmt.Errorf("username and password or token is required for alist storage")
|
||||||
|
}
|
||||||
|
if a.BasePath == "" {
|
||||||
|
return fmt.Errorf("base_path is required for alist storage")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *AlistStorageConfig) GetType() types.StorageType {
|
||||||
|
return types.StorageTypeAlist
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *AlistStorageConfig) GetName() string {
|
||||||
|
return a.Name
|
||||||
|
}
|
||||||
|
|
||||||
|
type WebdavStorageConfig struct {
|
||||||
|
NewStorageConfig
|
||||||
|
URL string `toml:"url" mapstructure:"url" json:"url"`
|
||||||
|
Username string `toml:"username" mapstructure:"username" json:"username"`
|
||||||
|
Password string `toml:"password" mapstructure:"password" json:"password"`
|
||||||
|
BasePath string `toml:"base_path" mapstructure:"base_path" json:"base_path"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *WebdavStorageConfig) Validate() error {
|
||||||
|
if w.URL == "" {
|
||||||
|
return fmt.Errorf("url is required for webdav storage")
|
||||||
|
}
|
||||||
|
if w.Username == "" || w.Password == "" {
|
||||||
|
return fmt.Errorf("username and password is required for webdav storage")
|
||||||
|
}
|
||||||
|
if w.BasePath == "" {
|
||||||
|
return fmt.Errorf("base_path is required for webdav storage")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *WebdavStorageConfig) GetType() types.StorageType {
|
||||||
|
return types.StorageTypeWebdav
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *WebdavStorageConfig) GetName() string {
|
||||||
|
return w.Name
|
||||||
|
}
|
||||||
49
config/user.go
Normal file
49
config/user.go
Normal file
@@ -0,0 +1,49 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/duke-git/lancet/v2/slice"
|
||||||
|
)
|
||||||
|
|
||||||
|
type userConfig struct {
|
||||||
|
ID int64 `toml:"id" mapstructure:"id" json:"id"` // telegram user id
|
||||||
|
Storages []string `toml:"storages" mapstructure:"storages" json:"storages"` // storage names
|
||||||
|
Blacklist bool `toml:"blacklist" mapstructure:"blacklist" json:"blacklist"` // 黑名单模式, storage names 中的存储将不会被使用, 默认为白名单模式
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Config) GetStorageNamesByUserID(userID int64) []string {
|
||||||
|
for _, user := range c.Users {
|
||||||
|
if user.ID == userID {
|
||||||
|
if user.Blacklist {
|
||||||
|
allStorages := make([]string, 0, len(c.Storages))
|
||||||
|
for _, storage := range c.Storages {
|
||||||
|
allStorages = append(allStorages, storage.GetName())
|
||||||
|
}
|
||||||
|
return slice.Compact(slice.Difference(allStorages, user.Storages))
|
||||||
|
} else {
|
||||||
|
return user.Storages
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Config) GetUsersID() []int64 {
|
||||||
|
var ids []int64
|
||||||
|
for _, user := range c.Users {
|
||||||
|
ids = append(ids, user.ID)
|
||||||
|
}
|
||||||
|
return ids
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Config) HasStorage(userID int64, storageName string) bool {
|
||||||
|
for _, user := range c.Users {
|
||||||
|
if user.ID == userID {
|
||||||
|
if user.Blacklist {
|
||||||
|
return !slice.Contain(user.Storages, storageName)
|
||||||
|
} else {
|
||||||
|
return slice.Contain(user.Storages, storageName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
127
config/viper.go
127
config/viper.go
@@ -9,26 +9,29 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type Config struct {
|
type Config struct {
|
||||||
Workers int `toml:"workers" mapstructure:"workers"`
|
Workers int `toml:"workers" mapstructure:"workers"`
|
||||||
Retry int `toml:"retry" mapstructure:"retry"`
|
Retry int `toml:"retry" mapstructure:"retry"`
|
||||||
NoCleanCache bool `toml:"no_clean_cache" mapstructure:"no_clean_cache"`
|
NoCleanCache bool `toml:"no_clean_cache" mapstructure:"no_clean_cache" json:"no_clean_cache"`
|
||||||
|
Users []userConfig `toml:"users" mapstructure:"users" json:"users"`
|
||||||
|
|
||||||
Temp tempConfig `toml:"temp" mapstructure:"temp"`
|
Temp tempConfig `toml:"temp" mapstructure:"temp"`
|
||||||
Log logConfig `toml:"log" mapstructure:"log"`
|
Log logConfig `toml:"log" mapstructure:"log"`
|
||||||
DB dbConfig `toml:"db" mapstructure:"db"`
|
DB dbConfig `toml:"db" mapstructure:"db"`
|
||||||
Telegram telegramConfig `toml:"telegram" mapstructure:"telegram"`
|
Telegram telegramConfig `toml:"telegram" mapstructure:"telegram"`
|
||||||
Storage storageConfig `toml:"storage" mapstructure:"storage"`
|
Storages []StorageConfig `toml:"-" mapstructure:"-" json:"storages"`
|
||||||
|
// Deprecated
|
||||||
|
DeprecatedStorage deprecatedStorageConfig `toml:"storage" mapstructure:"storage"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type tempConfig struct {
|
type tempConfig struct {
|
||||||
BasePath string `toml:"base_path" mapstructure:"base_path"`
|
BasePath string `toml:"base_path" mapstructure:"base_path" json:"base_path"`
|
||||||
CacheTTL int64 `toml:"cache_ttl" mapstructure:"cache_ttl"`
|
CacheTTL int64 `toml:"cache_ttl" mapstructure:"cache_ttl" json:"cache_ttl"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type logConfig struct {
|
type logConfig struct {
|
||||||
Level string `toml:"level" mapstructure:"level"`
|
Level string `toml:"level" mapstructure:"level"`
|
||||||
File string `toml:"file" mapstructure:"file"`
|
File string `toml:"file" mapstructure:"file"`
|
||||||
BackupCount uint `toml:"backup_count" mapstructure:"backup_count"`
|
BackupCount uint `toml:"backup_count" mapstructure:"backup_count" json:"backup_count"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type dbConfig struct {
|
type dbConfig struct {
|
||||||
@@ -37,10 +40,12 @@ type dbConfig struct {
|
|||||||
|
|
||||||
type telegramConfig struct {
|
type telegramConfig struct {
|
||||||
Token string `toml:"token" mapstructure:"token"`
|
Token string `toml:"token" mapstructure:"token"`
|
||||||
AppID int `toml:"app_id" mapstructure:"app_id"`
|
AppID int `toml:"app_id" mapstructure:"app_id" json:"app_id"`
|
||||||
AppHash string `toml:"app_hash" mapstructure:"app_hash"`
|
AppHash string `toml:"app_hash" mapstructure:"app_hash" json:"app_hash"`
|
||||||
Admins []int64 `toml:"admins" mapstructure:"admins"`
|
|
||||||
Proxy proxyConfig `toml:"proxy" mapstructure:"proxy"`
|
Proxy proxyConfig `toml:"proxy" mapstructure:"proxy"`
|
||||||
|
|
||||||
|
// Deprecated
|
||||||
|
Admins []int64 `toml:"admins" mapstructure:"admins"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type proxyConfig struct {
|
type proxyConfig struct {
|
||||||
@@ -48,38 +53,9 @@ type proxyConfig struct {
|
|||||||
URL string `toml:"url" mapstructure:"url"`
|
URL string `toml:"url" mapstructure:"url"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type storageConfig struct {
|
|
||||||
Alist alistConfig `toml:"alist" mapstructure:"alist"`
|
|
||||||
Local localConfig `toml:"local" mapstructure:"local"`
|
|
||||||
Webdav webdavConfig `toml:"webdav" mapstructure:"webdav"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type alistConfig struct {
|
|
||||||
Enable bool `toml:"enable" mapstructure:"enable"`
|
|
||||||
URL string `toml:"url" mapstructure:"url"`
|
|
||||||
Username string `toml:"username" mapstructure:"username"`
|
|
||||||
Password string `toml:"password" mapstructure:"password"`
|
|
||||||
Token string `toml:"token" mapstructure:"token"`
|
|
||||||
BasePath string `toml:"base_path" mapstructure:"base_path"`
|
|
||||||
TokenExp int64 `toml:"token_exp" mapstructure:"token_exp"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type localConfig struct {
|
|
||||||
Enable bool `toml:"enable" mapstructure:"enable"`
|
|
||||||
BasePath string `toml:"base_path" mapstructure:"base_path"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type webdavConfig struct {
|
|
||||||
Enable bool `toml:"enable" mapstructure:"enable"`
|
|
||||||
URL string `toml:"url" mapstructure:"url"`
|
|
||||||
Username string `toml:"username" mapstructure:"username"`
|
|
||||||
Password string `toml:"password" mapstructure:"password"`
|
|
||||||
BasePath string `toml:"base_path" mapstructure:"base_path"`
|
|
||||||
}
|
|
||||||
|
|
||||||
var Cfg *Config
|
var Cfg *Config
|
||||||
|
|
||||||
func Init() {
|
func Init() error {
|
||||||
viper.SetConfigName("config")
|
viper.SetConfigName("config")
|
||||||
viper.AddConfigPath(".")
|
viper.AddConfigPath(".")
|
||||||
viper.AddConfigPath("/etc/saveany/")
|
viper.AddConfigPath("/etc/saveany/")
|
||||||
@@ -104,10 +80,11 @@ func Init() {
|
|||||||
|
|
||||||
viper.SetDefault("db.path", "data/saveany.db")
|
viper.SetDefault("db.path", "data/saveany.db")
|
||||||
|
|
||||||
viper.SetDefault("storage.alist.base_path", "/")
|
if err := viper.SafeWriteConfigAs("config.toml"); err != nil {
|
||||||
viper.SetDefault("storage.alist.token_exp", 3600)
|
if _, ok := err.(viper.ConfigFileAlreadyExistsError); !ok {
|
||||||
|
return fmt.Errorf("error saving default config: %w", err)
|
||||||
viper.SafeWriteConfigAs("config.toml")
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if err := viper.ReadInConfig(); err != nil {
|
if err := viper.ReadInConfig(); err != nil {
|
||||||
fmt.Println("Error reading config file, ", err)
|
fmt.Println("Error reading config file, ", err)
|
||||||
@@ -115,14 +92,62 @@ func Init() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
Cfg = &Config{}
|
Cfg = &Config{}
|
||||||
|
|
||||||
if err := viper.Unmarshal(Cfg); err != nil {
|
if err := viper.Unmarshal(Cfg); err != nil {
|
||||||
fmt.Println("Error unmarshalling config file, ", err)
|
fmt.Println("Error unmarshalling config file, ", err)
|
||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
if Cfg.Workers < 1 || Cfg.Retry < 1 {
|
|
||||||
fmt.Println("Invalid workers or retry value")
|
if Cfg.Telegram.Admins != nil {
|
||||||
os.Exit(1)
|
fmt.Println("警告: 你正在使用旧版 Telegram 管理员配置, 该配置下的用户将可用所有存储.\ntelegram.admins 未来版本将会被废弃, 请参考新的配置文件模板, 使用 users 配置替代.")
|
||||||
|
for _, admin := range Cfg.Telegram.Admins {
|
||||||
|
found := false
|
||||||
|
for _, user := range Cfg.Users {
|
||||||
|
if user.ID == admin {
|
||||||
|
found = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if found {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
Cfg.Users = append(Cfg.Users, userConfig{
|
||||||
|
ID: admin,
|
||||||
|
Storages: []string{},
|
||||||
|
Blacklist: true,
|
||||||
|
})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
storagesConfig, err := LoadStorageConfigs(viper.GetViper())
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error loading storage configs: %w", err)
|
||||||
|
}
|
||||||
|
Cfg.Storages = storagesConfig
|
||||||
|
|
||||||
|
if Cfg.DeprecatedStorage != (deprecatedStorageConfig{}) {
|
||||||
|
fmt.Println("\n警告: 你正在使用旧版存储配置, 未来版本将会被废弃.\n请参考新的配置文件模板.")
|
||||||
|
transformDeprecatedStorageConfig()
|
||||||
|
}
|
||||||
|
|
||||||
|
storageNames := make(map[string]struct{})
|
||||||
|
for _, storage := range Cfg.Storages {
|
||||||
|
if _, ok := storageNames[storage.GetName()]; ok {
|
||||||
|
return fmt.Errorf("重复的存储名: %s", storage.GetName())
|
||||||
|
}
|
||||||
|
storageNames[storage.GetName()] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Printf("已加载 %d 个存储:\n", len(Cfg.Storages))
|
||||||
|
for _, storage := range Cfg.Storages {
|
||||||
|
fmt.Printf(" - %s (%s)\n", storage.GetName(), storage.GetType())
|
||||||
|
}
|
||||||
|
|
||||||
|
if Cfg.Workers < 1 || Cfg.Retry < 1 {
|
||||||
|
return fmt.Errorf("workers 和 retry 必须大于 0, 当前值: workers=%d, retry=%d", Cfg.Workers, Cfg.Retry)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func Set(key string, value any) {
|
func Set(key string, value any) {
|
||||||
|
|||||||
29
core/core.go
29
core/core.go
@@ -19,6 +19,7 @@ import (
|
|||||||
"github.com/krau/SaveAny-Bot/config"
|
"github.com/krau/SaveAny-Bot/config"
|
||||||
"github.com/krau/SaveAny-Bot/logger"
|
"github.com/krau/SaveAny-Bot/logger"
|
||||||
"github.com/krau/SaveAny-Bot/queue"
|
"github.com/krau/SaveAny-Bot/queue"
|
||||||
|
"github.com/krau/SaveAny-Bot/storage"
|
||||||
"github.com/krau/SaveAny-Bot/types"
|
"github.com/krau/SaveAny-Bot/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -30,26 +31,24 @@ func processPendingTask(task *types.Task) error {
|
|||||||
cacheDestPath := filepath.Join(config.Cfg.Temp.BasePath, task.FileName())
|
cacheDestPath := filepath.Join(config.Cfg.Temp.BasePath, task.FileName())
|
||||||
cacheDestPath, err := filepath.Abs(cacheDestPath)
|
cacheDestPath, err := filepath.Abs(cacheDestPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to get absolute path: %w", err)
|
return fmt.Errorf("处理路径失败: %w", err)
|
||||||
}
|
}
|
||||||
if err := fileutil.CreateDir(filepath.Dir(cacheDestPath)); err != nil {
|
if err := fileutil.CreateDir(filepath.Dir(cacheDestPath)); err != nil {
|
||||||
return fmt.Errorf("failed to create directory: %w", err)
|
return fmt.Errorf("创建目录失败: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if task.StoragePath == "" {
|
if task.StoragePath == "" {
|
||||||
task.StoragePath = task.File.FileName
|
task.StoragePath = task.File.FileName
|
||||||
}
|
}
|
||||||
switch task.Storage {
|
|
||||||
case types.Local:
|
taskStorage, err := storage.GetStorageByUserIDAndName(task.UserID, task.StorageName)
|
||||||
task.StoragePath = filepath.Join(config.Cfg.Storage.Local.BasePath, task.StoragePath)
|
if err != nil {
|
||||||
case types.Webdav:
|
return err
|
||||||
task.StoragePath = path.Join(config.Cfg.Storage.Webdav.BasePath, task.StoragePath)
|
|
||||||
case types.Alist:
|
|
||||||
task.StoragePath = path.Join(config.Cfg.Storage.Alist.BasePath, task.StoragePath)
|
|
||||||
}
|
}
|
||||||
|
task.StoragePath = taskStorage.JoinStoragePath(*task)
|
||||||
|
|
||||||
if task.File.FileSize == 0 {
|
if task.File.FileSize == 0 {
|
||||||
return processPhoto(task, cacheDestPath)
|
return processPhoto(task, taskStorage, cacheDestPath)
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx := task.Ctx.(*ext.Context)
|
ctx := task.Ctx.(*ext.Context)
|
||||||
@@ -81,18 +80,18 @@ func processPendingTask(task *types.Task) error {
|
|||||||
0, task.File.FileSize-1, task.File.FileSize,
|
0, task.File.FileSize-1, task.File.FileSize,
|
||||||
progressCallback, task.File.FileSize/100)
|
progressCallback, task.File.FileSize/100)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to create reader: %w", err)
|
return fmt.Errorf("创建下载失败: %w", err)
|
||||||
}
|
}
|
||||||
defer readCloser.Close()
|
defer readCloser.Close()
|
||||||
|
|
||||||
dest, err := os.Create(cacheDestPath)
|
dest, err := os.Create(cacheDestPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to create file: %w", err)
|
return fmt.Errorf("创建文件失败: %w", err)
|
||||||
}
|
}
|
||||||
defer dest.Close()
|
defer dest.Close()
|
||||||
task.StartTime = time.Now()
|
task.StartTime = time.Now()
|
||||||
if _, err := io.CopyN(dest, readCloser, task.File.FileSize); err != nil {
|
if _, err := io.CopyN(dest, readCloser, task.File.FileSize); err != nil {
|
||||||
return fmt.Errorf("failed to download file: %w", err)
|
return fmt.Errorf("下载文件失败: %w", err)
|
||||||
}
|
}
|
||||||
defer cleanCacheFile(cacheDestPath)
|
defer cleanCacheFile(cacheDestPath)
|
||||||
if path.Ext(task.FileName()) == "" {
|
if path.Ext(task.FileName()) == "" {
|
||||||
@@ -111,7 +110,7 @@ func processPendingTask(task *types.Task) error {
|
|||||||
ID: task.ReplyMessageID,
|
ID: task.ReplyMessageID,
|
||||||
})
|
})
|
||||||
|
|
||||||
return saveFileWithRetry(task, cacheDestPath)
|
return saveFileWithRetry(task, taskStorage, cacheDestPath)
|
||||||
}
|
}
|
||||||
|
|
||||||
func worker(queue *queue.TaskQueue, semaphore chan struct{}) {
|
func worker(queue *queue.TaskQueue, semaphore chan struct{}) {
|
||||||
@@ -139,7 +138,7 @@ func worker(queue *queue.TaskQueue, semaphore chan struct{}) {
|
|||||||
case types.Succeeded:
|
case types.Succeeded:
|
||||||
logger.L.Infof("Task succeeded: %s", task.String())
|
logger.L.Infof("Task succeeded: %s", task.String())
|
||||||
task.Ctx.(*ext.Context).EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{
|
task.Ctx.(*ext.Context).EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{
|
||||||
Message: fmt.Sprintf("文件保存成功\n [%s]: %s", task.Storage, task.StoragePath),
|
Message: fmt.Sprintf("文件保存成功\n [%s]: %s", task.StorageName, task.StoragePath),
|
||||||
ID: task.ReplyMessageID,
|
ID: task.ReplyMessageID,
|
||||||
})
|
})
|
||||||
case types.Failed:
|
case types.Failed:
|
||||||
|
|||||||
@@ -16,9 +16,9 @@ import (
|
|||||||
"github.com/krau/SaveAny-Bot/types"
|
"github.com/krau/SaveAny-Bot/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
func saveFileWithRetry(task *types.Task, localFilePath string) error {
|
func saveFileWithRetry(task *types.Task, taskStorage storage.Storage, localFilePath string) error {
|
||||||
for i := 0; i <= config.Cfg.Retry; i++ {
|
for i := 0; i <= config.Cfg.Retry; i++ {
|
||||||
if err := storage.Save(task.Storage, task.Ctx, localFilePath, task.StoragePath); err != nil {
|
if err := taskStorage.Save(task.Ctx, localFilePath, task.StoragePath); err != nil {
|
||||||
if i == config.Cfg.Retry {
|
if i == config.Cfg.Retry {
|
||||||
return fmt.Errorf("failed to save file: %w", err)
|
return fmt.Errorf("failed to save file: %w", err)
|
||||||
}
|
}
|
||||||
@@ -30,7 +30,7 @@ func saveFileWithRetry(task *types.Task, localFilePath string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func processPhoto(task *types.Task, cachePath string) error {
|
func processPhoto(task *types.Task, taskStorage storage.Storage, cachePath string) error {
|
||||||
res, err := bot.Client.API().UploadGetFile(task.Ctx, &tg.UploadGetFileRequest{
|
res, err := bot.Client.API().UploadGetFile(task.Ctx, &tg.UploadGetFileRequest{
|
||||||
Location: task.File.Location,
|
Location: task.File.Location,
|
||||||
Offset: 0,
|
Offset: 0,
|
||||||
@@ -53,7 +53,7 @@ func processPhoto(task *types.Task, cachePath string) error {
|
|||||||
|
|
||||||
logger.L.Infof("Downloaded file: %s", cachePath)
|
logger.L.Infof("Downloaded file: %s", cachePath)
|
||||||
|
|
||||||
return saveFileWithRetry(task, cachePath)
|
return saveFileWithRetry(task, taskStorage, cachePath)
|
||||||
}
|
}
|
||||||
|
|
||||||
func getProgressBar(progress float64, totalCount int) string {
|
func getProgressBar(progress float64, totalCount int) string {
|
||||||
@@ -104,7 +104,7 @@ func buildProgressMessageEntity(task *types.Task, barTotalCount int, bytesRead i
|
|||||||
entityBuilder := entity.Builder{}
|
entityBuilder := entity.Builder{}
|
||||||
text := fmt.Sprintf("正在处理下载任务\n文件名: %s\n保存路径: %s\n平均速度: %s\n当前进度: [%s] %.2f%%",
|
text := fmt.Sprintf("正在处理下载任务\n文件名: %s\n保存路径: %s\n平均速度: %s\n当前进度: [%s] %.2f%%",
|
||||||
task.FileName(),
|
task.FileName(),
|
||||||
fmt.Sprintf("[%s]:%s", task.Storage, task.StoragePath),
|
fmt.Sprintf("[%s]:%s", task.StorageName, task.StoragePath),
|
||||||
getSpeed(bytesRead, startTime),
|
getSpeed(bytesRead, startTime),
|
||||||
getProgressBar(progress, barTotalCount),
|
getProgressBar(progress, barTotalCount),
|
||||||
progress,
|
progress,
|
||||||
@@ -114,7 +114,7 @@ func buildProgressMessageEntity(task *types.Task, barTotalCount int, bytesRead i
|
|||||||
styling.Plain("正在处理下载任务\n文件名: "),
|
styling.Plain("正在处理下载任务\n文件名: "),
|
||||||
styling.Code(task.FileName()),
|
styling.Code(task.FileName()),
|
||||||
styling.Plain("\n保存路径: "),
|
styling.Plain("\n保存路径: "),
|
||||||
styling.Code(fmt.Sprintf("[%s]:%s", task.Storage, task.StoragePath)),
|
styling.Code(fmt.Sprintf("[%s]:%s", task.StorageName, task.StoragePath)),
|
||||||
styling.Plain("\n平均速度: "),
|
styling.Plain("\n平均速度: "),
|
||||||
styling.Bold(getSpeed(bytesRead, task.StartTime)),
|
styling.Bold(getSpeed(bytesRead, task.StartTime)),
|
||||||
styling.Plain("\n当前进度:\n "),
|
styling.Plain("\n当前进度:\n "),
|
||||||
|
|||||||
12
dao/db.go
12
dao/db.go
@@ -16,7 +16,7 @@ import (
|
|||||||
var db *gorm.DB
|
var db *gorm.DB
|
||||||
|
|
||||||
func Init() {
|
func Init() {
|
||||||
if err := os.MkdirAll(filepath.Dir(config.Cfg.DB.Path), 755); err != nil {
|
if err := os.MkdirAll(filepath.Dir(config.Cfg.DB.Path), 0755); err != nil {
|
||||||
logger.L.Fatal("Failed to create data directory: ", err)
|
logger.L.Fatal("Failed to create data directory: ", err)
|
||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
@@ -36,9 +36,13 @@ func Init() {
|
|||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
logger.L.Debug("Database connected")
|
logger.L.Debug("Database connected")
|
||||||
db.AutoMigrate(&types.ReceivedFile{}, &types.User{})
|
if err := db.AutoMigrate(&types.ReceivedFile{}, &types.User{}); err != nil {
|
||||||
|
logger.L.Fatal("迁移数据库失败, 如果您从旧版本升级, 建议手动删除数据库文件后重试: ", err)
|
||||||
|
}
|
||||||
|
|
||||||
for _, admin := range config.Cfg.Telegram.Admins {
|
for _, admin := range config.Cfg.GetUsersID() {
|
||||||
CreateUser(int64(admin))
|
if err := CreateUser(int64(admin)); err != nil {
|
||||||
|
logger.L.Fatal("Failed to create admin user: ", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
20
dao/user.go
20
dao/user.go
@@ -4,19 +4,29 @@ import (
|
|||||||
"github.com/krau/SaveAny-Bot/types"
|
"github.com/krau/SaveAny-Bot/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
func CreateUser(userID int64) error {
|
func CreateUser(chatID int64) error {
|
||||||
if _, err := GetUserByUserID(userID); err == nil {
|
if _, err := GetUserByChatID(chatID); err == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return db.Create(&types.User{UserID: userID}).Error
|
return db.Create(&types.User{ChatID: chatID}).Error
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetUserByUserID(userID int64) (*types.User, error) {
|
func GetAllUsers() ([]types.User, error) {
|
||||||
|
var users []types.User
|
||||||
|
err := db.Find(&users).Error
|
||||||
|
return users, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetUserByChatID(chatID int64) (*types.User, error) {
|
||||||
var user types.User
|
var user types.User
|
||||||
err := db.Where("user_id = ?", userID).First(&user).Error
|
err := db.Where("chat_id = ?", chatID).First(&user).Error
|
||||||
return &user, err
|
return &user, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func UpdateUser(user *types.User) error {
|
func UpdateUser(user *types.User) error {
|
||||||
return db.Save(user).Error
|
return db.Save(user).Error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func DeleteUser(user *types.User) error {
|
||||||
|
return db.Delete(user).Error
|
||||||
|
}
|
||||||
|
|||||||
@@ -3,30 +3,8 @@ services:
|
|||||||
image: ghcr.io/krau/saveany-bot:latest
|
image: ghcr.io/krau/saveany-bot:latest
|
||||||
container_name: saveany-bot
|
container_name: saveany-bot
|
||||||
restart: unless-stopped
|
restart: unless-stopped
|
||||||
environment:
|
|
||||||
- SAVEANY_TELEGRAM_TOKEN=bot_token
|
|
||||||
- SAVEANY_TELEGRAM_ADMINS=admin_id1,admin_id2
|
|
||||||
# 推荐使用自己的 API ID 和 API HASH (https://my.telegram.org)
|
|
||||||
# 若不配置也可运行, 将使用默认的 API ID 和 API HASH
|
|
||||||
# - SAVEANY_TELEGRAM_APP_ID=app_id
|
|
||||||
# - SAVEANY_TELEGRAM_APP_HASH=app_hash
|
|
||||||
|
|
||||||
# 本地存储
|
|
||||||
- SAVEANY_STORAGE_LOCAL_ENABLE=true
|
|
||||||
- SAVEANY_STORAGE_LOCAL_BASE_PATH=/app/downloads
|
|
||||||
# Alist
|
|
||||||
- SAVEANY_STORAGE_ALIST_ENABLE=true
|
|
||||||
- SAVEANY_STORAGE_ALIST_BASE_PATH=/saveany
|
|
||||||
- SAVEANY_STORAGE_ALIST_URL=https://example.com
|
|
||||||
- SAVEANY_STORAGE_ALIST_USERNAME=username
|
|
||||||
- SAVEANY_STORAGE_ALIST_PASSWORD=password
|
|
||||||
# webdav
|
|
||||||
- SAVEANY_STORAGE_WEBDAV_ENABLE=true
|
|
||||||
- SAVEANY_STORAGE_WEBDAV_BASE_PATH=/saveany
|
|
||||||
- SAVEANY_STORAGE_WEBDAV_URL=https://example.com
|
|
||||||
- SAVEANY_STORAGE_WEBDAV_USERNAME=username
|
|
||||||
- SAVEANY_STORAGE_WEBDAV_PASSWORD=password
|
|
||||||
volumes:
|
volumes:
|
||||||
- ./data:/app/data
|
- ./data:/app/data
|
||||||
|
- ./config.toml:/app/config.toml
|
||||||
- ./downloads:/app/downloads
|
- ./downloads:/app/downloads
|
||||||
- ./cache:/app/cache
|
- ./cache:/app/cache
|
||||||
4
go.mod
4
go.mod
@@ -17,6 +17,7 @@ require (
|
|||||||
)
|
)
|
||||||
|
|
||||||
require (
|
require (
|
||||||
|
filippo.io/edwards25519 v1.1.0 // indirect
|
||||||
github.com/AnimeKaizoku/cacher v1.0.2 // indirect
|
github.com/AnimeKaizoku/cacher v1.0.2 // indirect
|
||||||
github.com/cenkalti/backoff/v4 v4.3.0 // indirect
|
github.com/cenkalti/backoff/v4 v4.3.0 // indirect
|
||||||
github.com/cespare/xxhash/v2 v2.3.0 // indirect
|
github.com/cespare/xxhash/v2 v2.3.0 // indirect
|
||||||
@@ -31,6 +32,7 @@ require (
|
|||||||
github.com/go-faster/jx v1.1.0 // indirect
|
github.com/go-faster/jx v1.1.0 // indirect
|
||||||
github.com/go-faster/xor v1.0.0 // indirect
|
github.com/go-faster/xor v1.0.0 // indirect
|
||||||
github.com/go-faster/yaml v0.4.6 // indirect
|
github.com/go-faster/yaml v0.4.6 // indirect
|
||||||
|
github.com/go-sql-driver/mysql v1.8.1 // indirect
|
||||||
github.com/google/go-github/v30 v30.1.0 // indirect
|
github.com/google/go-github/v30 v30.1.0 // indirect
|
||||||
github.com/google/go-querystring v1.1.0 // indirect
|
github.com/google/go-querystring v1.1.0 // indirect
|
||||||
github.com/google/pprof v0.0.0-20250128161936-077ca0a936bf // indirect
|
github.com/google/pprof v0.0.0-20250128161936-077ca0a936bf // indirect
|
||||||
@@ -60,6 +62,7 @@ require (
|
|||||||
golang.org/x/oauth2 v0.26.0 // indirect
|
golang.org/x/oauth2 v0.26.0 // indirect
|
||||||
golang.org/x/tools v0.30.0 // indirect
|
golang.org/x/tools v0.30.0 // indirect
|
||||||
gopkg.in/yaml.v2 v2.4.0 // indirect
|
gopkg.in/yaml.v2 v2.4.0 // indirect
|
||||||
|
gorm.io/driver/mysql v1.5.6 // indirect
|
||||||
modernc.org/libc v1.61.13 // indirect
|
modernc.org/libc v1.61.13 // indirect
|
||||||
modernc.org/mathutil v1.7.1 // indirect
|
modernc.org/mathutil v1.7.1 // indirect
|
||||||
modernc.org/memory v1.8.2 // indirect
|
modernc.org/memory v1.8.2 // indirect
|
||||||
@@ -97,5 +100,6 @@ require (
|
|||||||
golang.org/x/text v0.22.0 // indirect
|
golang.org/x/text v0.22.0 // indirect
|
||||||
gopkg.in/ini.v1 v1.67.0 // indirect
|
gopkg.in/ini.v1 v1.67.0 // indirect
|
||||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||||
|
gorm.io/datatypes v1.2.5
|
||||||
gorm.io/gorm v1.25.12
|
gorm.io/gorm v1.25.12
|
||||||
)
|
)
|
||||||
|
|||||||
10
go.sum
10
go.sum
@@ -1,3 +1,5 @@
|
|||||||
|
filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA=
|
||||||
|
filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4=
|
||||||
github.com/AnimeKaizoku/cacher v1.0.2 h1:7Bf5qRylWb7q2Evib0OXlhG37/t7BP2HK/7IyPvSmGQ=
|
github.com/AnimeKaizoku/cacher v1.0.2 h1:7Bf5qRylWb7q2Evib0OXlhG37/t7BP2HK/7IyPvSmGQ=
|
||||||
github.com/AnimeKaizoku/cacher v1.0.2/go.mod h1:jw0de/b0K6W7Y3T9rHCMGVKUf6oG7hENNcssxYcZTCc=
|
github.com/AnimeKaizoku/cacher v1.0.2/go.mod h1:jw0de/b0K6W7Y3T9rHCMGVKUf6oG7hENNcssxYcZTCc=
|
||||||
github.com/blang/semver v3.5.1+incompatible h1:cQNTCjp13qL8KC3Nbxr/y2Bqb63oX6wdnnjpJbkM4JQ=
|
github.com/blang/semver v3.5.1+incompatible h1:cQNTCjp13qL8KC3Nbxr/y2Bqb63oX6wdnnjpJbkM4JQ=
|
||||||
@@ -55,6 +57,9 @@ github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY=
|
|||||||
github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
|
github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
|
||||||
github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
|
github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
|
||||||
github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
|
github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
|
||||||
|
github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI=
|
||||||
|
github.com/go-sql-driver/mysql v1.8.1 h1:LedoTUt/eveggdHS9qUFC1EFSa8bU2+1pZjSRpvNJ1Y=
|
||||||
|
github.com/go-sql-driver/mysql v1.8.1/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqwuVSjsCm7DZg=
|
||||||
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
|
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
|
||||||
github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
|
github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
|
||||||
github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||||
@@ -265,6 +270,11 @@ gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
|
|||||||
gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
|
gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
|
||||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||||
|
gorm.io/datatypes v1.2.5 h1:9UogU3jkydFVW1bIVVeoYsTpLRgwDVW3rHfJG6/Ek9I=
|
||||||
|
gorm.io/datatypes v1.2.5/go.mod h1:I5FUdlKpLb5PMqeMQhm30CQ6jXP8Rj89xkTeCSAaAD4=
|
||||||
|
gorm.io/driver/mysql v1.5.6 h1:Ld4mkIickM+EliaQZQx3uOJDJHtrd70MxAUqWqlx3Y8=
|
||||||
|
gorm.io/driver/mysql v1.5.6/go.mod h1:sEtPWMiqiN1N1cMXoXmBbd8C6/l+TESwriotuRRpkDM=
|
||||||
|
gorm.io/gorm v1.25.7/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8=
|
||||||
gorm.io/gorm v1.25.12 h1:I0u8i2hWQItBq1WfE0o2+WuL9+8L21K9e2HHSTE/0f8=
|
gorm.io/gorm v1.25.12 h1:I0u8i2hWQItBq1WfE0o2+WuL9+8L21K9e2HHSTE/0f8=
|
||||||
gorm.io/gorm v1.25.12/go.mod h1:xh7N7RHfYlNc5EmcI/El95gXusucDrQnHXe0+CgWcLQ=
|
gorm.io/gorm v1.25.12/go.mod h1:xh7N7RHfYlNc5EmcI/El95gXusucDrQnHXe0+CgWcLQ=
|
||||||
modernc.org/cc/v4 v4.24.4 h1:TFkx1s6dCkQpd6dKurBNmpo+G8Zl4Sq/ztJ+2+DEsh0=
|
modernc.org/cc/v4 v4.24.4 h1:TFkx1s6dCkQpd6dKurBNmpo+G8Zl4Sq/ztJ+2+DEsh0=
|
||||||
|
|||||||
@@ -1,19 +1,19 @@
|
|||||||
package alist
|
package alist
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"os"
|
"os"
|
||||||
|
"path"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/krau/SaveAny-Bot/config"
|
"github.com/krau/SaveAny-Bot/config"
|
||||||
"github.com/krau/SaveAny-Bot/logger"
|
"github.com/krau/SaveAny-Bot/logger"
|
||||||
|
"github.com/krau/SaveAny-Bot/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Alist struct {
|
type Alist struct {
|
||||||
@@ -21,154 +21,80 @@ type Alist struct {
|
|||||||
token string
|
token string
|
||||||
baseURL string
|
baseURL string
|
||||||
loginInfo *loginRequest
|
loginInfo *loginRequest
|
||||||
|
config config.AlistStorageConfig
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
func (a *Alist) Init(cfg config.StorageConfig) error {
|
||||||
ErrAlistLoginFailed = errors.New("failed to login to Alist")
|
alistConfig, ok := cfg.(*config.AlistStorageConfig)
|
||||||
)
|
if !ok {
|
||||||
|
return fmt.Errorf("failed to cast alist config")
|
||||||
type loginRequest struct {
|
|
||||||
Username string `json:"username"`
|
|
||||||
Password string `json:"password"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type loginResponse struct {
|
|
||||||
Code int `json:"code"`
|
|
||||||
Message string `json:"message"`
|
|
||||||
Data struct {
|
|
||||||
Token string `json:"token"`
|
|
||||||
} `json:"data"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type meResponse struct {
|
|
||||||
Code int `json:"code"`
|
|
||||||
Message string `json:"message"`
|
|
||||||
Data struct {
|
|
||||||
ID int `json:"id"`
|
|
||||||
Username string `json:"username"`
|
|
||||||
} `json:"data"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type putResponse struct {
|
|
||||||
Code int `json:"code"`
|
|
||||||
Message string `json:"message"`
|
|
||||||
Data struct {
|
|
||||||
Task struct {
|
|
||||||
ID string `json:"id"`
|
|
||||||
Name string `json:"name"`
|
|
||||||
State int `json:"state"`
|
|
||||||
Status string `json:"status"`
|
|
||||||
Progress int `json:"progress"`
|
|
||||||
Error string `json:"error"`
|
|
||||||
} `json:"task"`
|
|
||||||
} `json:"data"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *Alist) getToken() error {
|
|
||||||
loginBody, err := json.Marshal(a.loginInfo)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to marshal login request: %w", err)
|
|
||||||
}
|
}
|
||||||
|
if err := alistConfig.Validate(); err != nil {
|
||||||
req, err := http.NewRequest(http.MethodPost, a.baseURL+"/api/auth/login", bytes.NewBuffer(loginBody))
|
return err
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to create login request: %w", err)
|
|
||||||
}
|
}
|
||||||
req.Header.Set("Content-Type", "application/json")
|
a.config = *alistConfig
|
||||||
|
|
||||||
resp, err := a.client.Do(req)
|
a.baseURL = alistConfig.URL
|
||||||
if err != nil {
|
a.client = getHttpClient()
|
||||||
return fmt.Errorf("failed to send login request: %w", err)
|
if alistConfig.Token != "" {
|
||||||
}
|
a.token = alistConfig.Token
|
||||||
defer resp.Body.Close()
|
|
||||||
|
|
||||||
body, err := io.ReadAll(resp.Body)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to read login response: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
var loginResp loginResponse
|
|
||||||
if err := json.Unmarshal(body, &loginResp); err != nil {
|
|
||||||
return fmt.Errorf("failed to unmarshal login response: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if loginResp.Code != http.StatusOK {
|
|
||||||
return fmt.Errorf("%w: %s", ErrAlistLoginFailed, loginResp.Message)
|
|
||||||
}
|
|
||||||
|
|
||||||
a.token = loginResp.Data.Token
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *Alist) refreshToken() {
|
|
||||||
for {
|
|
||||||
time.Sleep(time.Duration(config.Cfg.Storage.Alist.TokenExp) * time.Second)
|
|
||||||
if err := a.getToken(); err != nil {
|
|
||||||
logger.L.Errorf("Failed to refresh jwt token: %v", err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
logger.L.Info("Refreshed Alist jwt token")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *Alist) Init() {
|
|
||||||
a.baseURL = config.Cfg.Storage.Alist.URL
|
|
||||||
a.client = &http.Client{
|
|
||||||
Timeout: 12 * time.Hour,
|
|
||||||
Transport: &http.Transport{
|
|
||||||
TLSHandshakeTimeout: 10 * time.Second,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
if config.Cfg.Storage.Alist.Token != "" {
|
|
||||||
a.token = config.Cfg.Storage.Alist.Token
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Minute)
|
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Minute)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, a.baseURL+"/api/me", nil)
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, a.baseURL+"/api/me", nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.L.Fatalf("Failed to create request: %v", err)
|
logger.L.Fatalf("Failed to create request: %v", err)
|
||||||
os.Exit(1)
|
return err
|
||||||
}
|
}
|
||||||
req.Header.Set("Authorization", a.token)
|
req.Header.Set("Authorization", a.token)
|
||||||
|
|
||||||
resp, err := a.client.Do(req)
|
resp, err := a.client.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.L.Fatalf("Failed to send request: %v", err)
|
logger.L.Fatalf("Failed to send request: %v", err)
|
||||||
os.Exit(1)
|
return err
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
if resp.StatusCode != http.StatusOK {
|
if resp.StatusCode != http.StatusOK {
|
||||||
logger.L.Fatalf("Failed to get alist user info: %s", resp.Status)
|
logger.L.Fatalf("Failed to get alist user info: %s", resp.Status)
|
||||||
os.Exit(1)
|
return err
|
||||||
}
|
}
|
||||||
body, err := io.ReadAll(resp.Body)
|
body, err := io.ReadAll(resp.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.L.Fatalf("Failed to read response body: %v", err)
|
logger.L.Fatalf("Failed to read response body: %v", err)
|
||||||
os.Exit(1)
|
return err
|
||||||
}
|
}
|
||||||
var meResp meResponse
|
var meResp meResponse
|
||||||
if err := json.Unmarshal(body, &meResp); err != nil {
|
if err := json.Unmarshal(body, &meResp); err != nil {
|
||||||
logger.L.Fatalf("Failed to unmarshal me response: %v", err)
|
logger.L.Fatalf("Failed to unmarshal me response: %v", err)
|
||||||
os.Exit(1)
|
return err
|
||||||
}
|
}
|
||||||
if meResp.Code != http.StatusOK {
|
if meResp.Code != http.StatusOK {
|
||||||
logger.L.Fatalf("Failed to get alist user info: %s", meResp.Message)
|
logger.L.Fatalf("Failed to get alist user info: %s", meResp.Message)
|
||||||
os.Exit(1)
|
return err
|
||||||
}
|
}
|
||||||
logger.L.Debugf("Logged in Alist as %s", meResp.Data.Username)
|
logger.L.Debugf("Logged in Alist as %s", meResp.Data.Username)
|
||||||
return
|
return nil
|
||||||
}
|
}
|
||||||
a.loginInfo = &loginRequest{
|
a.loginInfo = &loginRequest{
|
||||||
Username: config.Cfg.Storage.Alist.Username,
|
Username: alistConfig.Username,
|
||||||
Password: config.Cfg.Storage.Alist.Password,
|
Password: alistConfig.Password,
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := a.getToken(); err != nil {
|
if err := a.getToken(); err != nil {
|
||||||
logger.L.Fatalf("Failed to login to Alist: %v", err)
|
logger.L.Fatalf("Failed to login to Alist: %v", err)
|
||||||
os.Exit(1)
|
return err
|
||||||
}
|
}
|
||||||
logger.L.Debug("Logged in to Alist")
|
logger.L.Debug("Logged in to Alist")
|
||||||
|
|
||||||
go a.refreshToken()
|
go a.refreshToken(*alistConfig)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Alist) Type() types.StorageType {
|
||||||
|
return types.StorageTypeAlist
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Alist) Name() string {
|
||||||
|
return a.config.Name
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Alist) Save(ctx context.Context, filePath, storagePath string) error {
|
func (a *Alist) Save(ctx context.Context, filePath, storagePath string) error {
|
||||||
@@ -219,3 +145,7 @@ func (a *Alist) Save(ctx context.Context, filePath, storagePath string) error {
|
|||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *Alist) JoinStoragePath(task types.Task) string {
|
||||||
|
return path.Join(a.config.BasePath, task.StoragePath)
|
||||||
|
}
|
||||||
|
|||||||
65
storage/alist/token.go
Normal file
65
storage/alist/token.go
Normal file
@@ -0,0 +1,65 @@
|
|||||||
|
package alist
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/krau/SaveAny-Bot/config"
|
||||||
|
"github.com/krau/SaveAny-Bot/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (a *Alist) getToken() error {
|
||||||
|
loginBody, err := json.Marshal(a.loginInfo)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to marshal login request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req, err := http.NewRequest(http.MethodPost, a.baseURL+"/api/auth/login", bytes.NewBuffer(loginBody))
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create login request: %w", err)
|
||||||
|
}
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
resp, err := a.client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to send login request: %w", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
body, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to read login response: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var loginResp loginResponse
|
||||||
|
if err := json.Unmarshal(body, &loginResp); err != nil {
|
||||||
|
return fmt.Errorf("failed to unmarshal login response: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if loginResp.Code != http.StatusOK {
|
||||||
|
return fmt.Errorf("%w: %s", ErrAlistLoginFailed, loginResp.Message)
|
||||||
|
}
|
||||||
|
|
||||||
|
a.token = loginResp.Data.Token
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Alist) refreshToken(cfg config.AlistStorageConfig) {
|
||||||
|
tokenExp := cfg.TokenExp
|
||||||
|
if tokenExp <= 0 {
|
||||||
|
logger.L.Warn("Invalid token expiration time, using default value")
|
||||||
|
tokenExp = 3600
|
||||||
|
}
|
||||||
|
for {
|
||||||
|
time.Sleep(time.Duration(tokenExp) * time.Second)
|
||||||
|
if err := a.getToken(); err != nil {
|
||||||
|
logger.L.Errorf("Failed to refresh jwt token: %v", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
logger.L.Info("Refreshed Alist jwt token")
|
||||||
|
}
|
||||||
|
}
|
||||||
44
storage/alist/types.go
Normal file
44
storage/alist/types.go
Normal file
@@ -0,0 +1,44 @@
|
|||||||
|
package alist
|
||||||
|
|
||||||
|
import "errors"
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrAlistLoginFailed = errors.New("failed to login to Alist")
|
||||||
|
)
|
||||||
|
|
||||||
|
type loginRequest struct {
|
||||||
|
Username string `json:"username"`
|
||||||
|
Password string `json:"password"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type loginResponse struct {
|
||||||
|
Code int `json:"code"`
|
||||||
|
Message string `json:"message"`
|
||||||
|
Data struct {
|
||||||
|
Token string `json:"token"`
|
||||||
|
} `json:"data"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type meResponse struct {
|
||||||
|
Code int `json:"code"`
|
||||||
|
Message string `json:"message"`
|
||||||
|
Data struct {
|
||||||
|
ID int `json:"id"`
|
||||||
|
Username string `json:"username"`
|
||||||
|
} `json:"data"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type putResponse struct {
|
||||||
|
Code int `json:"code"`
|
||||||
|
Message string `json:"message"`
|
||||||
|
Data struct {
|
||||||
|
Task struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
State int `json:"state"`
|
||||||
|
Status string `json:"status"`
|
||||||
|
Progress int `json:"progress"`
|
||||||
|
Error string `json:"error"`
|
||||||
|
} `json:"task"`
|
||||||
|
} `json:"data"`
|
||||||
|
}
|
||||||
23
storage/alist/utils.go
Normal file
23
storage/alist/utils.go
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
package alist
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
httpClient *http.Client
|
||||||
|
)
|
||||||
|
|
||||||
|
func getHttpClient() *http.Client {
|
||||||
|
if httpClient != nil {
|
||||||
|
return httpClient
|
||||||
|
}
|
||||||
|
httpClient = &http.Client{
|
||||||
|
Timeout: 12 * time.Hour,
|
||||||
|
Transport: &http.Transport{
|
||||||
|
TLSHandshakeTimeout: 10 * time.Second,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
return httpClient
|
||||||
|
}
|
||||||
9
storage/errs.go
Normal file
9
storage/errs.go
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
package storage
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrStorageNameEmpty = errors.New("storage name is empty")
|
||||||
|
)
|
||||||
@@ -2,22 +2,41 @@ package local
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
|
||||||
"github.com/duke-git/lancet/v2/fileutil"
|
"github.com/duke-git/lancet/v2/fileutil"
|
||||||
"github.com/krau/SaveAny-Bot/config"
|
"github.com/krau/SaveAny-Bot/config"
|
||||||
"github.com/krau/SaveAny-Bot/logger"
|
"github.com/krau/SaveAny-Bot/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Local struct{}
|
type Local struct {
|
||||||
|
config config.LocalStorageConfig
|
||||||
|
}
|
||||||
|
|
||||||
func (l *Local) Init() {
|
func (l *Local) Init(cfg config.StorageConfig) error {
|
||||||
err := os.MkdirAll(config.Cfg.Storage.Local.BasePath, os.ModePerm)
|
localConfig, ok := cfg.(*config.LocalStorageConfig)
|
||||||
if err != nil {
|
if !ok {
|
||||||
logger.L.Fatalf("Failed to create local storage directory: %s", err)
|
return fmt.Errorf("failed to cast local config")
|
||||||
os.Exit(1)
|
|
||||||
}
|
}
|
||||||
|
if err := localConfig.Validate(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
l.config = *localConfig
|
||||||
|
err := os.MkdirAll(localConfig.BasePath, os.ModePerm)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create local storage directory: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *Local) Type() types.StorageType {
|
||||||
|
return types.StorageTypeLocal
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *Local) Name() string {
|
||||||
|
return l.config.Name
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *Local) Save(ctx context.Context, filePath, storagePath string) error {
|
func (l *Local) Save(ctx context.Context, filePath, storagePath string) error {
|
||||||
@@ -30,3 +49,7 @@ func (l *Local) Save(ctx context.Context, filePath, storagePath string) error {
|
|||||||
}
|
}
|
||||||
return fileutil.CopyFile(filePath, storagePath)
|
return fileutil.CopyFile(filePath, storagePath)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (l *Local) JoinStoragePath(task types.Task) string {
|
||||||
|
return filepath.Join(l.config.BasePath, task.StoragePath)
|
||||||
|
}
|
||||||
|
|||||||
@@ -2,12 +2,8 @@ package storage
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"fmt"
|
||||||
"path"
|
|
||||||
"path/filepath"
|
|
||||||
"sync"
|
|
||||||
|
|
||||||
"github.com/duke-git/lancet/v2/slice"
|
|
||||||
"github.com/krau/SaveAny-Bot/config"
|
"github.com/krau/SaveAny-Bot/config"
|
||||||
"github.com/krau/SaveAny-Bot/logger"
|
"github.com/krau/SaveAny-Bot/logger"
|
||||||
"github.com/krau/SaveAny-Bot/storage/alist"
|
"github.com/krau/SaveAny-Bot/storage/alist"
|
||||||
@@ -17,68 +13,92 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type Storage interface {
|
type Storage interface {
|
||||||
Init()
|
Init(cfg config.StorageConfig) error
|
||||||
|
Type() types.StorageType
|
||||||
|
Name() string
|
||||||
|
JoinStoragePath(task types.Task) string
|
||||||
Save(cttx context.Context, localFilePath, storagePath string) error
|
Save(cttx context.Context, localFilePath, storagePath string) error
|
||||||
}
|
}
|
||||||
|
|
||||||
var Storages = make(map[types.StorageType]Storage)
|
var Storages = make(map[string]Storage)
|
||||||
var StorageKeys = make([]types.StorageType, 0)
|
|
||||||
|
|
||||||
func Init() {
|
// GetStorageByName returns storage by name from cache or creates new one
|
||||||
logger.L.Debug("Initializing storage...")
|
func GetStorageByName(name string) (Storage, error) {
|
||||||
if config.Cfg.Storage.Alist.Enable {
|
if name == "" {
|
||||||
Storages[types.Alist] = new(alist.Alist)
|
return nil, ErrStorageNameEmpty
|
||||||
Storages[types.Alist].Init()
|
|
||||||
}
|
|
||||||
if config.Cfg.Storage.Local.Enable {
|
|
||||||
Storages[types.Local] = new(local.Local)
|
|
||||||
Storages[types.Local].Init()
|
|
||||||
}
|
|
||||||
if config.Cfg.Storage.Webdav.Enable {
|
|
||||||
Storages[types.Webdav] = new(webdav.Webdav)
|
|
||||||
Storages[types.Webdav].Init()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for k := range Storages {
|
storage, ok := Storages[name]
|
||||||
StorageKeys = append(StorageKeys, k)
|
if ok {
|
||||||
|
return storage, nil
|
||||||
|
}
|
||||||
|
cfg := config.Cfg.GetStorageByName(name)
|
||||||
|
if cfg == nil {
|
||||||
|
return nil, fmt.Errorf("未找到存储 %s", name)
|
||||||
}
|
}
|
||||||
|
|
||||||
slice.Sort(StorageKeys)
|
storage, err := NewStorage(cfg)
|
||||||
|
if err != nil {
|
||||||
logger.L.Debug("Storage initialized")
|
return nil, err
|
||||||
|
}
|
||||||
|
Storages[name] = storage
|
||||||
|
return storage, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func Save(storageType types.StorageType, ctx context.Context, filePath, storagePath string) error {
|
// 检查 user 是否可用指定的 storage, 若不可用则返回未找到错误
|
||||||
logger.L.Debugf("Saving file %s to storage: [%s] %s", filePath, storageType, storagePath)
|
func GetStorageByUserIDAndName(chatID int64, name string) (Storage, error) {
|
||||||
if ctx == nil {
|
if name == "" {
|
||||||
ctx = context.Background()
|
return nil, ErrStorageNameEmpty
|
||||||
}
|
}
|
||||||
if storageType != types.StorageAll {
|
|
||||||
return Storages[storageType].Save(ctx, filePath, storagePath)
|
if !config.Cfg.HasStorage(chatID, name) {
|
||||||
|
return nil, fmt.Errorf("没有找到用户 %d 的存储 %s", chatID, name)
|
||||||
}
|
}
|
||||||
errs := make([]error, 0)
|
|
||||||
var wg sync.WaitGroup
|
return GetStorageByName(name)
|
||||||
for _, storage := range Storages {
|
}
|
||||||
wg.Add(1)
|
|
||||||
go func(storage Storage) {
|
func GetUserStorages(chatID int64) []Storage {
|
||||||
defer wg.Done()
|
var storages []Storage
|
||||||
storageDestPath := storagePath
|
for _, name := range config.Cfg.GetStorageNamesByUserID(chatID) {
|
||||||
switch storage.(type) {
|
storage, err := GetStorageByName(name)
|
||||||
case *local.Local:
|
if err != nil {
|
||||||
storageDestPath = filepath.Join(config.Cfg.Storage.Local.BasePath, storagePath)
|
continue
|
||||||
case *webdav.Webdav:
|
}
|
||||||
storageDestPath = path.Join(config.Cfg.Storage.Webdav.BasePath, storagePath)
|
storages = append(storages, storage)
|
||||||
case *alist.Alist:
|
}
|
||||||
storageDestPath = path.Join(config.Cfg.Storage.Alist.BasePath, storagePath)
|
return storages
|
||||||
}
|
}
|
||||||
if err := storage.Save(ctx, filePath, storageDestPath); err != nil {
|
|
||||||
errs = append(errs, err)
|
type StorageConstructor func() Storage
|
||||||
}
|
|
||||||
}(storage)
|
var storageConstructors = map[string]StorageConstructor{
|
||||||
}
|
string(types.StorageTypeAlist): func() Storage { return new(alist.Alist) },
|
||||||
wg.Wait()
|
string(types.StorageTypeLocal): func() Storage { return new(local.Local) },
|
||||||
if len(errs) > 0 {
|
string(types.StorageTypeWebdav): func() Storage { return new(webdav.Webdav) },
|
||||||
return errors.Join(errs...)
|
}
|
||||||
}
|
|
||||||
return nil
|
func NewStorage(cfg config.StorageConfig) (Storage, error) {
|
||||||
|
constructor, ok := storageConstructors[string(cfg.GetType())]
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("不支持的存储类型: %s", cfg.GetType())
|
||||||
|
}
|
||||||
|
|
||||||
|
storage := constructor()
|
||||||
|
if err := storage.Init(cfg); err != nil {
|
||||||
|
return nil, fmt.Errorf("初始化 %s 存储失败: %w", cfg.GetName(), err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return storage, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func LoadStorages() {
|
||||||
|
logger.L.Info("加载存储...")
|
||||||
|
for _, storage := range config.Cfg.Storages {
|
||||||
|
_, err := GetStorageByName(storage.GetName())
|
||||||
|
if err != nil {
|
||||||
|
logger.L.Errorf("加载存储 %s 失败: %v", storage.GetName(), err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
logger.L.Infof("成功加载 %d 个存储", len(Storages))
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,33 +2,50 @@ package webdav
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
"path"
|
"path"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/krau/SaveAny-Bot/config"
|
"github.com/krau/SaveAny-Bot/config"
|
||||||
"github.com/krau/SaveAny-Bot/logger"
|
"github.com/krau/SaveAny-Bot/logger"
|
||||||
|
"github.com/krau/SaveAny-Bot/types"
|
||||||
"github.com/studio-b12/gowebdav"
|
"github.com/studio-b12/gowebdav"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Webdav struct{}
|
type Webdav struct {
|
||||||
|
config config.WebdavStorageConfig
|
||||||
|
client *gowebdav.Client
|
||||||
|
}
|
||||||
|
|
||||||
var (
|
func (w *Webdav) Init(cfg config.StorageConfig) error {
|
||||||
Client *gowebdav.Client
|
webdavConfig, ok := cfg.(*config.WebdavStorageConfig)
|
||||||
)
|
if !ok {
|
||||||
|
return fmt.Errorf("failed to cast webdav config")
|
||||||
func (w *Webdav) Init() {
|
|
||||||
webdavConfig := config.Cfg.Storage.Webdav
|
|
||||||
Client = gowebdav.NewClient(webdavConfig.URL, webdavConfig.Username, webdavConfig.Password)
|
|
||||||
if err := Client.Connect(); err != nil {
|
|
||||||
logger.L.Fatalf("Failed to connect to webdav server: %v", err)
|
|
||||||
os.Exit(1)
|
|
||||||
}
|
}
|
||||||
Client.SetTimeout(24 * time.Hour)
|
if err := webdavConfig.Validate(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
w.config = *webdavConfig
|
||||||
|
client := gowebdav.NewClient(webdavConfig.URL, webdavConfig.Username, webdavConfig.Password)
|
||||||
|
if err := client.Connect(); err != nil {
|
||||||
|
return fmt.Errorf("failed to connect to webdav server: %w", err)
|
||||||
|
}
|
||||||
|
client.SetTimeout(12 * time.Hour)
|
||||||
|
w.client = client
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *Webdav) Type() types.StorageType {
|
||||||
|
return types.StorageTypeWebdav
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *Webdav) Name() string {
|
||||||
|
return w.config.Name
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *Webdav) Save(ctx context.Context, filePath, storagePath string) error {
|
func (w *Webdav) Save(ctx context.Context, filePath, storagePath string) error {
|
||||||
if err := Client.MkdirAll(path.Dir(storagePath), os.ModePerm); err != nil {
|
if err := w.client.MkdirAll(path.Dir(storagePath), os.ModePerm); err != nil {
|
||||||
logger.L.Errorf("Failed to create directory %s: %v", path.Dir(storagePath), err)
|
logger.L.Errorf("Failed to create directory %s: %v", path.Dir(storagePath), err)
|
||||||
return ErrFailedToCreateDirectory
|
return ErrFailedToCreateDirectory
|
||||||
}
|
}
|
||||||
@@ -39,9 +56,13 @@ func (w *Webdav) Save(ctx context.Context, filePath, storagePath string) error {
|
|||||||
}
|
}
|
||||||
defer file.Close()
|
defer file.Close()
|
||||||
|
|
||||||
if err := Client.WriteStream(storagePath, file, os.ModePerm); err != nil {
|
if err := w.client.WriteStream(storagePath, file, os.ModePerm); err != nil {
|
||||||
logger.L.Errorf("Failed to write file %s: %v", storagePath, err)
|
logger.L.Errorf("Failed to write file %s: %v", storagePath, err)
|
||||||
return ErrFailedToWriteFile
|
return ErrFailedToWriteFile
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (w *Webdav) JoinStoragePath(task types.Task) string {
|
||||||
|
return path.Join(w.config.BasePath, task.StoragePath)
|
||||||
|
}
|
||||||
|
|||||||
@@ -6,9 +6,11 @@ import (
|
|||||||
|
|
||||||
type ReceivedFile struct {
|
type ReceivedFile struct {
|
||||||
gorm.Model
|
gorm.Model
|
||||||
Processing bool
|
Processing bool
|
||||||
ChatID int64 `gorm:"uniqueIndex:idx_chat_id_message_id;not null"`
|
// Which chat the file is from
|
||||||
MessageID int `gorm:"uniqueIndex:idx_chat_id_message_id;not null"`
|
ChatID int64 `gorm:"uniqueIndex:idx_chat_id_message_id;not null"`
|
||||||
|
// Which message the file is from
|
||||||
|
MessageID int `gorm:"uniqueIndex:idx_chat_id_message_id;not null"`
|
||||||
ReplyMessageID int
|
ReplyMessageID int
|
||||||
ReplyChatID int64
|
ReplyChatID int64
|
||||||
FileName string
|
FileName string
|
||||||
@@ -16,7 +18,7 @@ type ReceivedFile struct {
|
|||||||
|
|
||||||
type User struct {
|
type User struct {
|
||||||
gorm.Model
|
gorm.Model
|
||||||
UserID int64 `gorm:"uniqueIndex"`
|
ChatID int64 `gorm:"uniqueIndex;not null"`
|
||||||
Silent bool
|
Silent bool
|
||||||
DefaultStorage string
|
DefaultStorage string // Default storage name
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -22,27 +22,34 @@ var (
|
|||||||
type StorageType string
|
type StorageType string
|
||||||
|
|
||||||
var (
|
var (
|
||||||
StorageAll StorageType = "all"
|
StorageTypeLocal StorageType = "local"
|
||||||
Local StorageType = "local"
|
StorageTypeWebdav StorageType = "webdav"
|
||||||
Webdav StorageType = "webdav"
|
StorageTypeAlist StorageType = "alist"
|
||||||
Alist StorageType = "alist"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var StorageTypes = []StorageType{Local, Alist, Webdav, StorageAll}
|
var StorageTypes = []StorageType{StorageTypeLocal, StorageTypeAlist, StorageTypeWebdav}
|
||||||
|
var StorageTypeDisplay = map[StorageType]string{
|
||||||
|
StorageTypeLocal: "本地磁盘",
|
||||||
|
StorageTypeWebdav: "WebDAV",
|
||||||
|
StorageTypeAlist: "Alist",
|
||||||
|
}
|
||||||
|
|
||||||
type Task struct {
|
type Task struct {
|
||||||
Ctx context.Context
|
Ctx context.Context
|
||||||
Error error
|
Error error
|
||||||
Status TaskStatus
|
Status TaskStatus
|
||||||
File *File
|
File *File
|
||||||
Storage StorageType
|
StorageName string
|
||||||
StoragePath string
|
StoragePath string
|
||||||
StartTime time.Time
|
StartTime time.Time
|
||||||
|
|
||||||
FileMessageID int
|
FileMessageID int
|
||||||
FileChatID int64
|
FileChatID int64
|
||||||
|
// to track the reply message
|
||||||
ReplyMessageID int
|
ReplyMessageID int
|
||||||
ReplyChatID int64
|
ReplyChatID int64
|
||||||
|
// to track the user
|
||||||
|
UserID int64
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t Task) String() string {
|
func (t Task) String() string {
|
||||||
|
|||||||
Reference in New Issue
Block a user