Compare commits

...

76 Commits

Author SHA1 Message Date
krau
50fba3f910 feat: add configurable timeout for Telegram client initialization 2025-04-07 10:23:50 +08:00
krau
87d3f14392 docs: update README to include sponsorship information and improve formatting 2025-04-05 00:01:48 +08:00
krau
30452c8d46 docs: consolidate message link information in help.md 2025-04-04 08:46:48 +08:00
krau
300f7723af fix: enhance webdav client impl 2025-03-31 17:34:24 +08:00
krau
491ba55f1e feat: add support for handling unsupported stream storage in download process 2025-03-26 10:35:40 +08:00
krau
32519b8c08 docs: add note about unsupported storage backends in Stream mode 2025-03-26 10:25:36 +08:00
krau
7ffd9891a0 fix: not pass content length when uploading in non stream mode 2025-03-26 10:22:38 +08:00
krau
347a60f1f7 fix: implement image extraction from Telegraph nodes 2025-03-24 22:04:55 +08:00
krau
da69fe1354 feat: enhance file name generation to include media extensions 2025-03-24 21:36:13 +08:00
krau
746ca026ba docs: remove outdated information about stream mode support in help documentation 2025-03-22 15:48:36 +08:00
krau
a8c64675e5 docs: update help documentation to include supported message links 2025-03-22 15:48:05 +08:00
krau
3918f6eee2 feat: add version and commit information to help text in start command 2025-03-22 15:45:34 +08:00
krau
8d44b43c82 fix: remove caching logic for Telegram messages in GetTGMessage function, close #40 2025-03-22 15:41:20 +08:00
krau
f14c4367f8 feat: cancel download telegraph task 2025-03-22 12:08:19 +08:00
krau
3e3a320672 feat: download telegraph images , close #5 2025-03-22 11:52:43 +08:00
krau
19efab0665 feat: implement GenFileNameFromMessage function for improved file naming 2025-03-22 09:33:50 +08:00
krau
635f00ac71 fix: reorganize cache destination path handling in processPendingTask function 2025-03-21 23:28:14 +08:00
krau
2d2becccf6 refactor: update storage interface to use io.Reader for Save method and remove stream implementations 2025-03-21 23:05:09 +08:00
krau
ed0837a89b refactor: replace logger usage with common.Log for consistent logging 2025-03-21 21:07:53 +08:00
krau
65fee89e14 feat: refactor storage configuration to use dedicated storage package and add new storage types
BREAKING CHANGE: remove deprecated config
2025-03-21 20:52:41 +08:00
krau
8e180006f0 chore: update dependencies to latest versions 2025-03-16 21:55:52 +08:00
krau
721c9666eb refactor: streamline storage configuration loading and remove redundant code 2025-03-11 22:24:52 +08:00
krau
6f35401181 docs: update links in README_EN.md for consistency 2025-03-11 21:46:04 +08:00
Krau
72ae2ce079 Merge pull request #35 from ysicing/main
feat: add Minio storage support
2025-03-11 21:41:43 +08:00
ysicing
495ad3ea5c feat: add Minio storage support
Signed-off-by: ysicing <i@ysicing.me>
2025-03-11 21:29:35 +08:00
krau
3def9df4b4 docs: update alist faq 2025-03-03 10:59:08 +08:00
krau
790a32d297 fix(alist): do not upload file as task to prevent alist cache full file 2025-03-03 10:58:03 +08:00
krau
f7779224ef docs: update example link 2025-03-01 15:54:21 +08:00
krau
7d899ae088 ci: Is anyone really using Windows ARM? 2025-03-01 14:01:10 +08:00
krau
7e67bdb7e2 fix: update executable compression condition for Windows ARM64 in build-release workflow 2025-03-01 13:55:42 +08:00
krau
0071780ff4 typo: deploy 2025-03-01 13:44:04 +08:00
krau
0a95431468 feat: add name to build release workflow 2025-03-01 13:39:08 +08:00
krau
34525c5b11 feat: add docs 2025-03-01 13:37:09 +08:00
krau
6ac6d79fb6 feat: update docker-compose.yml to use host network mode for accessing host services 2025-03-01 12:31:20 +08:00
krau
f21a82ad43 chore: clean up README.md by removing unnecessary demo video section 2025-03-01 12:29:43 +08:00
Krau
73f6647f8d Merge pull request #33 from krau/dev-stream
impl webdav stream mode & progress callback for stream mode
2025-03-01 12:24:46 +08:00
krau
6fbb4609f9 feat: show progress for stream mode 2025-03-01 12:22:50 +08:00
krau
802c908384 feat: refactor webdav client and implement custom upload stream handling 2025-03-01 12:06:55 +08:00
Krau
5d403056d0 Merge pull request #32 from krau/dev-stream
feat: add stream upload support and related configurations
2025-02-28 12:17:10 +08:00
krau
8e2dd37155 feat: add stream upload support and related configurations 2025-02-28 11:09:24 +08:00
krau
9c7ed833fd ci: add upx support 2025-02-28 09:45:34 +08:00
Krau
f9d601bd8a Merge pull request #30 from krau/dev
feat: cancel task
2025-02-27 22:34:58 +08:00
krau
152f473131 fix: delete done task 2025-02-27 22:25:10 +08:00
krau
7015081a84 feat: add context cancellation handling in saveFileWithRetry function 2025-02-27 22:07:41 +08:00
krau
be6444cf96 feat: implement task cancellation feature and update task handling 2025-02-27 22:02:16 +08:00
krau
98ba7c50e7 refactor: remove unused StoragePath initialization in AddToQueue function 2025-02-27 21:32:14 +08:00
krau
0c31d908cc feat: add dir command at init and show dirs in dir command help 2025-02-25 16:23:24 +08:00
krau
9e776b22fb feat: set dir for storages 2025-02-25 16:17:20 +08:00
krau
d6f8603656 docs: update change bot token comment 2025-02-25 15:09:44 +08:00
krau
9c42bee662 refactor: spilt handlers file 2025-02-24 17:50:35 +08:00
krau
b96340dd46 refactor: add err var ErrEmptyMessage 2025-02-24 17:41:36 +08:00
krau
a5ba01e219 typo: config example 2025-02-24 17:34:44 +08:00
krau
d00e907735 typo: config example 2025-02-24 17:34:11 +08:00
Twilight
418f9bd2bc 更详细的 config 配置以及更完善的 README (#25)
* Add files via upload

* Add files via upload

* Add OpenWrt auto-start and shortcut script instructions, Optimize file link reference method.

* add more detailed instructions.

* Add files via upload
2025-02-24 17:32:53 +08:00
krau
28b4585dba chore: update configuration for user storage filtering and add base path for file saving 2025-02-23 18:12:23 +08:00
krau
d2669f0c99 feat: add logging for file save operations in storage modules 2025-02-21 14:04:32 +08:00
krau
c9921926e3 chore: add configurable thread count 2025-02-21 13:53:46 +08:00
krau
d7cd2ede01 feat: add configurable thread count for file processing 2025-02-21 13:51:30 +08:00
krau
ed21b65c98 perf: refactor file download to support multithreading 2025-02-21 13:49:15 +08:00
krau
8975589c43 refactor: file download process and enhance progress tracking 2025-02-21 11:16:45 +08:00
krau
27dca2e343 perf: add UserStorages map and implement GetUserStorages function for user-specific storage retrieval 2025-02-20 22:57:45 +08:00
krau
5c8261c34a refactor: improve error handling in getSelectStorageMarkup for user retrieval 2025-02-20 22:53:08 +08:00
krau
cbc2dc82d8 fix: update EffectiveUser cannot obtain the accurate user, use GetUserChat instead 2025-02-20 22:52:16 +08:00
krau
09a7c5597d fix: add UserID to link message and enforce default storage setting in silent handler 2025-02-19 14:33:03 +08:00
krau
f73f18e90d fix: update user and file deletion to use unscoped delete; add user synchronization logic 2025-02-19 14:19:39 +08:00
krau
ab822c2fe6 fix: update create user to new config 2025-02-19 14:06:33 +08:00
krau
2579044841 fix: update permission check 2025-02-19 13:56:30 +08:00
krau
88a02aae8d chore: update config example and docker compose file 2025-02-19 13:42:12 +08:00
krau
ab374a870b chore: update readme and add english version
Co-authored-by: AHCorn <42889600+AHCorn@users.noreply.github.com>
2025-02-19 13:41:57 +08:00
krau
3a1b8f34ea chore: translate some import log to cn 2025-02-19 12:36:48 +08:00
krau
c4eb824457 feat: set default storage by inline keyboard 2025-02-19 12:23:12 +08:00
krau
692e970772 feat!: (WIP) switched back to using config files config storages because the conversation handling is shit 2025-02-19 11:05:30 +08:00
krau
80696c9661 feat: (WIP) add storage
Co-authored-by: AHCorn <42889600+AHCorn@users.noreply.github.com>
2025-02-18 22:53:07 +08:00
krau
18cd480264 fix: add json tag for config 2025-02-18 19:53:01 +08:00
krau
dfde65c28e feat: (WIP) migrate storage configuration to user-specific models and remove deprecated storage loading 2025-02-18 19:45:06 +08:00
krau
968547b005 feat!: (WIP) decouple storage, users, and configuration files to support multiple users 2025-02-18 17:17:02 +08:00
71 changed files with 3688 additions and 1363 deletions

View File

@@ -1,3 +1,5 @@
name: Build Release
on:
push:
tags:
@@ -36,6 +38,9 @@ jobs:
matrix:
goos: [linux, darwin, windows]
goarch: [amd64, arm64]
exclude:
- goos: windows
goarch: arm64
steps:
- name: Checkout
uses: actions/checkout@v4
@@ -53,6 +58,7 @@ jobs:
goos: ${{ matrix.goos }}
goarch: ${{ matrix.goarch }}
github_token: ${{ secrets.GITHUB_TOKEN }}
executable_compression: upx
extra_files: |
LICENSE
README.md

22
.github/workflows/docs.yml vendored Normal file
View File

@@ -0,0 +1,22 @@
name: Deploy Docs
on:
push:
branches:
- main
paths:
- "docs/**"
workflow_dispatch:
jobs:
deploy:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
with:
python-version: "3.11"
- uses: actions/cache@v4
with:
key: ${{ github.ref }}
path: .cache
- run: pip install mkdocs-material
- run: cd docs && mkdocs gh-deploy --force

View File

@@ -2,18 +2,12 @@
# <img src="docs/logo.jpg" width="45" align="center"> Save Any Bot
**简体中文** | [English](README_EN.md)
把 Telegram 的文件保存到各类存储端.
> _就像 PikPak Bot 一样_
</div
Demo Video:
<div align="center">
[SaveAny-Bot 演示视频 The Demo of SaveAny-Bot.webm](https://github.com/user-attachments/assets/a0de2453-a4d1-4a12-81fb-9d84856dce09)
</div>
## 部署
@@ -22,7 +16,7 @@ Demo Video:
在 [Release](https://github.com/krau/SaveAny-Bot/releases) 页面下载对应平台的二进制文件.
在解压后目录新建 `config.toml` 文件, 参考 [config.toml.example](https://github.com/krau/SaveAny-Bot/blob/main/config.example.toml) 编辑配置文件.
在解压后目录新建 `config.toml` 文件, 参考 [config.example.toml](./config.example.toml) 编辑配置文件.
运行:
@@ -56,13 +50,31 @@ WantedBy=multi-user.target
systemctl enable --now saveany-bot
```
#### 为 OpenWrt 及衍生系统添加开机自启动服务
创建文件 ` /etc/init.d/saveanybot` ,参考[saveanybot](./docs/saveanybot)自行修改.
`chmod +x /etc/init.d/saveanybot`
完成后,将文件复制到 `/etc/rc.d`并重命名为`S99saveanybot`.
`chmod +x /etc/rc.d/S99saveanybot`
#### 为 OpenWrt 及衍生系统添加快捷指令
创建文件` /usr/bin/sabot` ,参考[sabot](./docs/sabot)自行配置修改,注意此处文件编码仅支持 ANSI 936 .
`chmod +x /usr/bin/sabot`
之后,终端输入`sabot start|stop|restart|status|enable|disable`即可.
### 使用 Docker 部署
#### Docker Compose
下载 [docker-compose.yml](https://github.com/krau/SaveAny-Bot/blob/main/docker-compose.yml) 文件, 并修改其中的配置.
下载 [docker-compose.yml](./docker-compose.yml) 文件, 在同目录下新建 `config.toml` 文件, 参考 [config.example.toml](./config.example.toml) 编辑配置文件.
运行:
启动:
```bash
docker compose up -d
@@ -94,10 +106,18 @@ docker restart saveany-bot
## 使用
向 Bot 发送(转发)文件, 按照提示操作.
向 Bot 发送(转发)文件, 或发送公开频道的消息链接, 按照提示操作.
---
## 赞助
本项目受到 [YxVM](https://yxvm.com/) 与 [NodeSupport](https://github.com/NodeSeekDev/NodeSupport) 的支持.
如果这个项目对你有帮助, 你可以考虑通过以下方式赞助我:
- [爱发电](https://afdian.com/a/acherkrau)
## Thanks
- [gotd](https://github.com/gotd/td)

108
README_EN.md Normal file
View File

@@ -0,0 +1,108 @@
<div align="center">
# <img src="docs/logo.jpg" width="45" align="center"> Save Any Bot
[简体中文](README.md) | **English**
Save Telegram files to various storage endpoints.
> _Just like PikPak Bot_
</div>
## Deployment
### Deploy from Binary
Download the binary file for your platform from the [Release](https://github.com/krau/SaveAny-Bot/releases) page.
Create a `config.toml` file in the extracted directory, refer to [config.example.toml](https://github.com/krau/SaveAny-Bot/blob/main/config.example.toml) for configuration.
Run:
```bash
chmod +x saveany-bot
./saveany-bot
```
#### Add as systemd Service
Create file `/etc/systemd/system/saveany-bot.service` and write the following content:
```
[Unit]
Description=SaveAnyBot
After=systemd-user-sessions.service
[Service]
Type=simple
WorkingDirectory=/yourpath/
ExecStart=/yourpath/saveany-bot
Restart=on-failure
[Install]
WantedBy=multi-user.target
```
Enable auto-start and start the service:
```bash
systemctl enable --now saveany-bot
```
### Deploy with Docker
#### Docker Compose
Download [docker-compose.yml](https://github.com/krau/SaveAny-Bot/blob/main/docker-compose.yml) file and create a `config.toml` file in the same directory, refer to [config.example.toml](https://github.com/krau/SaveAny-Bot/blob/main/config.example.toml) for configuration.
Run:
```bash
docker compose up -d
```
#### Docker
```shell
docker run -d --name saveany-bot \
-v /path/to/config.toml:/app/config.toml \
-v /path/to/downloads:/app/downloads \
ghcr.io/krau/saveany-bot:latest
```
## Update
Use `upgrade` or `up` command to upgrade to the latest version:
```bash
./saveany-bot upgrade
```
If deployed with Docker, use the following commands to update:
```bash
docker pull ghcr.io/krau/saveany-bot:latest
docker restart saveany-bot
```
## Usage
Send (forward) files to the Bot and follow the prompts.
---
## Sponsors
This project is supported by [YxVM](https://yxvm.com/) and [NodeSupport](https://github.com/NodeSeekDev/NodeSupport).
You can consider sponsoring me if this project helps you:
- [Afdian](https://afdian.com/a/acherkrau)
## Thanks
- [gotd](https://github.com/gotd/td)
- [TG-FileStreamBot](https://github.com/EverythingSuckz/TG-FileStreamBot)
- [gotgproto](https://github.com/celestix/gotgproto)
- All the dependencies

View File

@@ -1,21 +0,0 @@
package bootstrap
import (
"github.com/krau/SaveAny-Bot/bot"
"github.com/krau/SaveAny-Bot/common"
"github.com/krau/SaveAny-Bot/config"
"github.com/krau/SaveAny-Bot/dao"
"github.com/krau/SaveAny-Bot/logger"
"github.com/krau/SaveAny-Bot/storage"
)
func InitAll() {
config.Init()
logger.InitLogger()
logger.L.Info("Running...")
common.Init()
storage.Init()
dao.Init()
bot.Init()
}

View File

@@ -11,8 +11,8 @@ import (
"github.com/glebarez/sqlite"
"github.com/gotd/td/telegram/dcs"
"github.com/gotd/td/tg"
"github.com/krau/SaveAny-Bot/common"
"github.com/krau/SaveAny-Bot/config"
"github.com/krau/SaveAny-Bot/logger"
"golang.org/x/net/proxy"
)
@@ -27,9 +27,10 @@ func newProxyDialer(proxyUrl string) (proxy.Dialer, error) {
}
func Init() {
logger.L.Info("Initializing client...")
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
common.Log.Info("初始化 Telegram 客户端...")
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(config.Cfg.Telegram.Timeout)*time.Second)
defer cancel()
go InitTelegraphClient()
resultChan := make(chan struct {
client *gotgproto.Client
err error
@@ -76,7 +77,7 @@ func Init() {
{Command: "silent", Description: "开启/关闭静默模式"},
{Command: "storage", Description: "设置默认存储端"},
{Command: "save", Description: "保存所回复的文件"},
{Command: "path", Description: "更改保存路径配置"},
{Command: "dir", Description: "管理存储文件夹"},
},
})
resultChan <- struct {
@@ -87,15 +88,15 @@ func Init() {
select {
case <-ctx.Done():
logger.L.Fatal("Failed to initialize client: timeout")
common.Log.Fatal("初始化客户端失败: 超时")
os.Exit(1)
case result := <-resultChan:
if result.err != nil {
logger.L.Fatalf("Failed to initialize client: %s", result.err)
common.Log.Fatalf("初始化客户端失败: %s", result.err)
os.Exit(1)
}
Client = result.client
RegisterHandlers(Client.Dispatcher)
logger.L.Info("Client initialized")
common.Log.Info("客户端初始化完成")
}
}

207
bot/handle_add_task.go Normal file
View File

@@ -0,0 +1,207 @@
package bot
import (
"errors"
"fmt"
"path"
"strconv"
"strings"
"github.com/celestix/gotgproto/dispatcher"
"github.com/celestix/gotgproto/ext"
"github.com/duke-git/lancet/v2/slice"
"github.com/gotd/td/telegram/message/entity"
"github.com/gotd/td/telegram/message/styling"
"github.com/gotd/td/tg"
"github.com/krau/SaveAny-Bot/common"
"github.com/krau/SaveAny-Bot/config"
"github.com/krau/SaveAny-Bot/dao"
"github.com/krau/SaveAny-Bot/queue"
"github.com/krau/SaveAny-Bot/types"
"gorm.io/gorm"
)
func AddToQueue(ctx *ext.Context, update *ext.Update) error {
// TODO: 回调数据用户独立鉴权 (处理 bot 在群聊中的情况)
if !slice.Contain(config.Cfg.GetUsersID(), update.CallbackQuery.UserID) {
ctx.AnswerCallback(&tg.MessagesSetBotCallbackAnswerRequest{
QueryID: update.CallbackQuery.QueryID,
Alert: true,
Message: "你没有权限",
CacheTime: 5,
})
return dispatcher.EndGroups
}
args := strings.Split(string(update.CallbackQuery.Data), " ")
addToDir := args[0] == "add_to_dir" // 已经选择了路径
cbDataId, _ := strconv.Atoi(args[1])
cbData, err := dao.GetCallbackData(uint(cbDataId))
if err != nil {
common.Log.Errorf("获取回调数据失败: %s", err)
ctx.AnswerCallback(&tg.MessagesSetBotCallbackAnswerRequest{
QueryID: update.CallbackQuery.QueryID,
Alert: true,
Message: "获取回调数据失败",
CacheTime: 5,
})
return dispatcher.EndGroups
}
data := strings.Split(cbData, " ")
fileChatID, _ := strconv.Atoi(data[0])
fileMessageID, _ := strconv.Atoi(data[1])
storageName := data[2]
dirIdInt, _ := strconv.Atoi(data[3])
dirId := uint(dirIdInt)
user, err := dao.GetUserByChatID(update.CallbackQuery.UserID)
if err != nil {
common.Log.Errorf("获取用户失败: %s", err)
ctx.AnswerCallback(&tg.MessagesSetBotCallbackAnswerRequest{
QueryID: update.CallbackQuery.QueryID,
Alert: true,
Message: "获取用户失败",
CacheTime: 5,
})
return dispatcher.EndGroups
}
if !addToDir {
dirs, err := dao.GetDirsByUserIDAndStorageName(user.ID, storageName)
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
common.Log.Errorf("获取路径失败: %s", err)
ctx.AnswerCallback(&tg.MessagesSetBotCallbackAnswerRequest{
QueryID: update.CallbackQuery.QueryID,
Alert: true,
Message: "获取路径失败",
CacheTime: 5,
})
return dispatcher.EndGroups
}
if len(dirs) != 0 {
markup, err := getSelectDirMarkup(fileChatID, fileMessageID, storageName, dirs)
if err != nil {
common.Log.Errorf("获取路径失败: %s", err)
ctx.AnswerCallback(&tg.MessagesSetBotCallbackAnswerRequest{
QueryID: update.CallbackQuery.QueryID,
Alert: true,
Message: "获取路径失败",
CacheTime: 5,
})
return dispatcher.EndGroups
}
_, err = ctx.EditMessage(update.EffectiveChat().GetID(), &tg.MessagesEditMessageRequest{
ID: update.CallbackQuery.GetMsgID(),
Message: "请选择要保存到的路径",
ReplyMarkup: markup,
})
if err != nil {
common.Log.Errorf("编辑消息失败: %s", err)
}
return dispatcher.EndGroups
}
}
common.Log.Tracef("Got add to queue: chatID: %d, messageID: %d, storage: %s", fileChatID, fileMessageID, storageName)
record, err := dao.GetReceivedFileByChatAndMessageID(int64(fileChatID), fileMessageID)
if err != nil {
common.Log.Errorf("获取记录失败: %s", err)
ctx.AnswerCallback(&tg.MessagesSetBotCallbackAnswerRequest{
QueryID: update.CallbackQuery.QueryID,
Alert: true,
Message: "查询记录失败",
CacheTime: 5,
})
return dispatcher.EndGroups
}
if update.CallbackQuery.MsgID != record.ReplyMessageID {
record.ReplyMessageID = update.CallbackQuery.MsgID
if err := dao.SaveReceivedFile(record); err != nil {
common.Log.Errorf("更新接收的文件失败: %s", err)
}
}
var dir *dao.Dir
if addToDir && dirId != 0 {
dir, err = dao.GetDirByID(dirId)
if err != nil {
common.Log.Errorf("获取路径失败: %s", err)
ctx.AnswerCallback(&tg.MessagesSetBotCallbackAnswerRequest{
QueryID: update.CallbackQuery.QueryID,
Alert: true,
Message: "获取路径失败",
CacheTime: 5,
})
return dispatcher.EndGroups
}
}
var task types.Task
if record.IsTelegraph {
task = types.Task{
Ctx: ctx,
Status: types.Pending,
IsTelegraph: true,
TelegraphURL: record.TelegraphURL,
StorageName: storageName,
FileChatID: record.ChatID,
FileMessageID: record.MessageID,
ReplyMessageID: record.ReplyMessageID,
ReplyChatID: record.ReplyChatID,
UserID: update.GetUserChat().GetID(),
}
if dir != nil {
task.StoragePath = path.Join(dir.Path, record.FileName)
}
} else {
file, err := FileFromMessage(ctx, record.ChatID, record.MessageID, record.FileName)
if err != nil {
common.Log.Errorf("获取消息中的文件失败: %s", err)
ctx.AnswerCallback(&tg.MessagesSetBotCallbackAnswerRequest{
QueryID: update.CallbackQuery.QueryID,
Alert: true,
Message: fmt.Sprintf("获取消息中的文件失败: %s", err),
CacheTime: 5,
})
return dispatcher.EndGroups
}
task = types.Task{
Ctx: ctx,
Status: types.Pending,
File: file,
StorageName: storageName,
FileChatID: record.ChatID,
ReplyMessageID: record.ReplyMessageID,
FileMessageID: record.MessageID,
ReplyChatID: record.ReplyChatID,
UserID: update.GetUserChat().GetID(),
}
if dir != nil {
task.StoragePath = path.Join(dir.Path, file.FileName)
}
}
queue.AddTask(&task)
entityBuilder := entity.Builder{}
var entities []tg.MessageEntityClass
text := fmt.Sprintf("已添加到任务队列\n文件名: %s\n当前排队任务数: %d", record.FileName, queue.Len())
if err := styling.Perform(&entityBuilder,
styling.Plain("已添加到任务队列\n文件名: "),
styling.Code(record.FileName),
styling.Plain("\n当前排队任务数: "),
styling.Bold(strconv.Itoa(queue.Len())),
); err != nil {
common.Log.Errorf("Failed to build entity: %s", err)
} else {
text, entities = entityBuilder.Complete()
}
ctx.EditMessage(update.EffectiveChat().GetID(), &tg.MessagesEditMessageRequest{
Message: text,
Entities: entities,
ID: record.ReplyMessageID,
})
return dispatcher.EndGroups
}

27
bot/handle_cancel_task.go Normal file
View File

@@ -0,0 +1,27 @@
package bot
import (
"strings"
"github.com/celestix/gotgproto/dispatcher"
"github.com/celestix/gotgproto/ext"
"github.com/gotd/td/tg"
"github.com/krau/SaveAny-Bot/queue"
)
func cancelTask(ctx *ext.Context, update *ext.Update) error {
key := strings.Split(string(update.CallbackQuery.Data), " ")[1]
ok := queue.CancelTask(key)
if ok {
ctx.AnswerCallback(&tg.MessagesSetBotCallbackAnswerRequest{
QueryID: update.CallbackQuery.QueryID,
Message: "任务已取消",
})
return dispatcher.EndGroups
}
ctx.AnswerCallback(&tg.MessagesSetBotCallbackAnswerRequest{
QueryID: update.CallbackQuery.QueryID,
Message: "任务取消失败",
})
return dispatcher.EndGroups
}

View File

@@ -0,0 +1,75 @@
package bot
import (
"sync"
)
type ConversationType string
type ConversationState struct {
sync.Mutex
conversationType ConversationType
InConversation bool
data map[ConversationType]map[string]interface{}
}
func (c *ConversationState) Reset() {
c.Lock()
defer c.Unlock()
c.InConversation = false
c.conversationType = ""
c.data = make(map[ConversationType]map[string]interface{})
}
func (c *ConversationState) SetConversationType(t ConversationType) {
c.Lock()
defer c.Unlock()
c.conversationType = t
}
func (c *ConversationState) GetData(key string) interface{} {
if c.data == nil || c.data[c.conversationType] == nil {
return nil
}
return c.data[c.conversationType][key]
}
func (c *ConversationState) SetData(key string, value interface{}) {
c.Lock()
defer c.Unlock()
if c.data == nil {
c.data = make(map[ConversationType]map[string]interface{})
}
if c.data[c.conversationType] == nil {
c.data[c.conversationType] = make(map[string]interface{})
}
c.data[c.conversationType][key] = value
}
// TODO: Implement conversation handling
// var userConversationState = make(map[int64]*ConversationState)
// func handleConversation(ctx *ext.Context, update *ext.Update) error {
// userID := update.EffectiveUser().GetID()
// state, ok := userConversationState[userID]
// if !ok {
// return dispatcher.ContinueGroups
// }
// if update.EffectiveMessage.Text == "/cancel" {
// state.Reset()
// ctx.Reply(update, ext.ReplyTextString("已取消"), nil)
// return dispatcher.EndGroups
// }
// if !state.InConversation {
// return dispatcher.ContinueGroups
// }
// return handleConversationState(ctx, update, state)
// }
// func handleConversationState(ctx *ext.Context, update *ext.Update, state *ConversationState) error {
// switch state.conversationType {
// default:
// common.Log.Errorf("Unknown conversation type: %s", state.conversationType)
// }
// return dispatcher.EndGroups
// }

88
bot/handle_dir.go Normal file
View File

@@ -0,0 +1,88 @@
package bot
import (
"strings"
"github.com/celestix/gotgproto/dispatcher"
"github.com/celestix/gotgproto/ext"
"github.com/gotd/td/telegram/message/styling"
"github.com/krau/SaveAny-Bot/common"
"github.com/krau/SaveAny-Bot/dao"
"github.com/krau/SaveAny-Bot/storage"
)
func dirCmd(ctx *ext.Context, update *ext.Update) error {
args := strings.Split(strings.TrimPrefix(update.EffectiveMessage.Text, "/dir "), " ")
if len(args) < 3 {
dirs, err := dao.GetUserDirsByChatID(update.GetUserChat().GetID())
if err != nil {
common.Log.Errorf("获取用户路径失败: %s", err)
ctx.Reply(update, ext.ReplyTextString("获取用户路径失败"), nil)
return dispatcher.EndGroups
}
ctx.Reply(update, ext.ReplyTextStyledTextArray(
[]styling.StyledTextOption{
styling.Bold("使用方法: /dir <操作> <存储名> <路径>"),
styling.Plain("\n\n可用操作:\n"),
styling.Code("add"),
styling.Plain(" - 添加路径\n"),
styling.Code("del"),
styling.Plain(" - 删除路径\n"),
styling.Plain("\n示例:\n"),
styling.Code("/dir add local1 path/to/dir"),
styling.Plain("\n\n当前已添加的路径:\n"),
styling.Blockquote(func() string {
var sb strings.Builder
for _, dir := range dirs {
sb.WriteString(dir.StorageName)
sb.WriteString(" - ")
sb.WriteString(dir.Path)
sb.WriteString("\n")
}
return sb.String()
}(), true),
},
), nil)
return dispatcher.EndGroups
}
user, err := dao.GetUserByChatID(update.GetUserChat().GetID())
if err != nil {
common.Log.Errorf("获取用户失败: %s", err)
ctx.Reply(update, ext.ReplyTextString("获取用户失败"), nil)
return dispatcher.EndGroups
}
switch args[0] {
case "add":
return addDir(ctx, update, user, args[1], args[2])
case "del":
return delDir(ctx, update, user, args[1], args[2])
default:
ctx.Reply(update, ext.ReplyTextString("未知操作"), nil)
return dispatcher.EndGroups
}
}
func addDir(ctx *ext.Context, update *ext.Update, user *dao.User, storageName, path string) error {
if _, err := storage.GetStorageByUserIDAndName(user.ChatID, storageName); err != nil {
ctx.Reply(update, ext.ReplyTextString(err.Error()), nil)
return dispatcher.EndGroups
}
if err := dao.CreateDirForUser(user.ID, storageName, path); err != nil {
common.Log.Errorf("创建路径失败: %s", err)
ctx.Reply(update, ext.ReplyTextString("创建路径失败"), nil)
return dispatcher.EndGroups
}
ctx.Reply(update, ext.ReplyTextString("路径添加成功"), nil)
return dispatcher.EndGroups
}
func delDir(ctx *ext.Context, update *ext.Update, user *dao.User, storageName, path string) error {
if err := dao.DeleteDirForUser(user.ID, storageName, path); err != nil {
common.Log.Errorf("删除路径失败: %s", err)
ctx.Reply(update, ext.ReplyTextString("删除路径失败"), nil)
return dispatcher.EndGroups
}
ctx.Reply(update, ext.ReplyTextString("路径删除成功"), nil)
return dispatcher.EndGroups
}

85
bot/handle_file.go Normal file
View File

@@ -0,0 +1,85 @@
package bot
import (
"fmt"
"github.com/celestix/gotgproto/dispatcher"
"github.com/celestix/gotgproto/ext"
"github.com/gotd/td/tg"
"github.com/krau/SaveAny-Bot/common"
"github.com/krau/SaveAny-Bot/dao"
"github.com/krau/SaveAny-Bot/storage"
"github.com/krau/SaveAny-Bot/types"
)
func handleFileMessage(ctx *ext.Context, update *ext.Update) error {
common.Log.Trace("Got media: ", update.EffectiveMessage.Media.TypeName())
supported, err := supportedMediaFilter(update.EffectiveMessage.Message)
if err != nil {
return err
}
if !supported {
return dispatcher.EndGroups
}
user, err := dao.GetUserByChatID(update.GetUserChat().GetID())
if err != nil {
common.Log.Errorf("获取用户失败: %s", err)
ctx.Reply(update, ext.ReplyTextString("获取用户失败"), nil)
return dispatcher.EndGroups
}
storages := storage.GetUserStorages(user.ChatID)
if len(storages) == 0 {
ctx.Reply(update, ext.ReplyTextString("无可用的存储"), nil)
return dispatcher.EndGroups
}
msg, err := ctx.Reply(update, ext.ReplyTextString("正在获取文件信息..."), nil)
if err != nil {
common.Log.Errorf("回复失败: %s", err)
return dispatcher.EndGroups
}
media := update.EffectiveMessage.Media
file, err := FileFromMedia(media, "")
if err != nil {
common.Log.Errorf("获取文件失败: %s", err)
ctx.Reply(update, ext.ReplyTextString(fmt.Sprintf("获取文件失败: %s", err)), nil)
return dispatcher.EndGroups
}
if file.FileName == "" {
file.FileName = GenFileNameFromMessage(*update.EffectiveMessage.Message, file)
}
if err := dao.SaveReceivedFile(&dao.ReceivedFile{
Processing: false,
FileName: file.FileName,
ChatID: update.EffectiveChat().GetID(),
MessageID: update.EffectiveMessage.ID,
ReplyMessageID: msg.ID,
ReplyChatID: update.GetUserChat().GetID(),
}); err != nil {
common.Log.Errorf("添加接收的文件失败: %s", err)
if _, err := ctx.EditMessage(update.EffectiveChat().GetID(), &tg.MessagesEditMessageRequest{
Message: fmt.Sprintf("添加接收的文件失败: %s", err),
ID: msg.ID,
}); err != nil {
common.Log.Errorf("编辑消息失败: %s", err)
}
return dispatcher.EndGroups
}
if !user.Silent || user.DefaultStorage == "" {
return ProvideSelectMessage(ctx, update, file.FileName, update.EffectiveChat().GetID(), update.EffectiveMessage.ID, msg.ID)
}
return HandleSilentAddTask(ctx, update, user, &types.Task{
Ctx: ctx,
Status: types.Pending,
File: file,
StorageName: user.DefaultStorage,
FileChatID: update.EffectiveChat().GetID(),
ReplyMessageID: msg.ID,
ReplyChatID: update.GetUserChat().GetID(),
FileMessageID: update.EffectiveMessage.ID,
UserID: user.ChatID,
})
}

View File

@@ -1,7 +1,6 @@
package bot
import (
"fmt"
"regexp"
"strconv"
"strings"
@@ -9,8 +8,9 @@ import (
"github.com/celestix/gotgproto/dispatcher"
"github.com/celestix/gotgproto/ext"
"github.com/gotd/td/tg"
"github.com/krau/SaveAny-Bot/common"
"github.com/krau/SaveAny-Bot/dao"
"github.com/krau/SaveAny-Bot/logger"
"github.com/krau/SaveAny-Bot/storage"
"github.com/krau/SaveAny-Bot/types"
)
@@ -20,7 +20,7 @@ var (
)
func handleLinkMessage(ctx *ext.Context, update *ext.Update) error {
logger.L.Trace("Got link message")
common.Log.Trace("Got link message")
link := linkRegex.FindString(update.EffectiveMessage.Text)
if link == "" {
return dispatcher.ContinueGroups
@@ -31,45 +31,51 @@ func handleLinkMessage(ctx *ext.Context, update *ext.Update) error {
}
messageID, err := strconv.Atoi(strSlice[2])
if err != nil {
logger.L.Errorf("Failed to parse message ID: %s", err)
ctx.Reply(update, ext.ReplyTextString("Failed to parse message ID"), nil)
common.Log.Errorf("解析消息 ID 失败: %s", err)
ctx.Reply(update, ext.ReplyTextString("无法解析消息 ID"), nil)
return dispatcher.EndGroups
}
chatUsername := strSlice[1]
linkChat, err := ctx.ResolveUsername(chatUsername)
if err != nil {
logger.L.Errorf("Failed to resolve chat ID: %s", err)
ctx.Reply(update, ext.ReplyTextString("Failed to resolve chat ID"), nil)
common.Log.Errorf("解析 Chat ID 失败: %s", err)
ctx.Reply(update, ext.ReplyTextString("无法解析 Chat ID"), nil)
return dispatcher.EndGroups
}
if linkChat == nil {
logger.L.Errorf("Cannot find chat: %s", chatUsername)
ctx.Reply(update, ext.ReplyTextString("Cannot find chat"), nil)
common.Log.Errorf("无法找到聊天: %s", chatUsername)
ctx.Reply(update, ext.ReplyTextString("无法找到聊天"), nil)
return dispatcher.EndGroups
}
user, err := dao.GetUserByUserID(update.GetUserChat().GetID())
user, err := dao.GetUserByChatID(update.GetUserChat().GetID())
if err != nil {
logger.L.Errorf("Failed to get user: %s", err)
common.Log.Errorf("获取用户失败: %s", err)
ctx.Reply(update, ext.ReplyTextString("获取用户失败"), nil)
return dispatcher.EndGroups
}
storages := storage.GetUserStorages(user.ChatID)
if len(storages) == 0 {
ctx.Reply(update, ext.ReplyTextString("无可用的存储"), nil)
return dispatcher.EndGroups
}
replied, err := ctx.Reply(update, ext.ReplyTextString("正在获取文件..."), nil)
if err != nil {
logger.L.Errorf("Failed to reply: %s", err)
common.Log.Errorf("回复失败: %s", err)
return dispatcher.EndGroups
}
file, err := FileFromMessage(ctx, linkChat.GetID(), messageID, "")
if err != nil {
logger.L.Errorf("Failed to get file from message: %s", err)
common.Log.Errorf("获取文件失败: %s", err)
ctx.Reply(update, ext.ReplyTextString("获取文件失败: "+err.Error()), nil)
return dispatcher.EndGroups
}
if file.FileName == "" {
logger.L.Warnf("Empty file name, use generated name")
file.FileName = fmt.Sprintf("%d_%d_%s", linkChat.GetID(), messageID, file.Hash())
file.FileName = GenFileNameFromMessage(*update.EffectiveMessage.Message, file)
}
receivedFile := &types.ReceivedFile{
receivedFile := &dao.ReceivedFile{
Processing: false,
FileName: file.FileName,
ChatID: linkChat.GetID(),
@@ -78,21 +84,22 @@ func handleLinkMessage(ctx *ext.Context, update *ext.Update) error {
ReplyChatID: update.GetUserChat().GetID(),
}
if err := dao.SaveReceivedFile(receivedFile); err != nil {
logger.L.Errorf("Failed to save received file: %s", err)
common.Log.Errorf("保存接收的文件失败: %s", err)
ctx.EditMessage(update.EffectiveChat().GetID(), &tg.MessagesEditMessageRequest{
Message: "无法保存文件: " + err.Error(),
ID: replied.ID,
})
return dispatcher.EndGroups
}
if !user.Silent {
return ProvideSelectMessage(ctx, update, file, int(linkChat.GetID()), messageID, replied.ID)
if !user.Silent || user.DefaultStorage == "" {
return ProvideSelectMessage(ctx, update, file.FileName, linkChat.GetID(), messageID, replied.ID)
}
return HandleSilentAddTask(ctx, update, user, &types.Task{
Ctx: ctx,
Status: types.Pending,
File: file,
Storage: types.StorageType(user.DefaultStorage),
StorageName: user.DefaultStorage,
UserID: user.ChatID,
FileChatID: linkChat.GetID(),
FileMessageID: messageID,
ReplyMessageID: replied.ID,

115
bot/handle_save.go Normal file
View File

@@ -0,0 +1,115 @@
package bot
import (
"fmt"
"strings"
"github.com/celestix/gotgproto/dispatcher"
"github.com/celestix/gotgproto/ext"
"github.com/gotd/td/tg"
"github.com/krau/SaveAny-Bot/common"
"github.com/krau/SaveAny-Bot/dao"
"github.com/krau/SaveAny-Bot/storage"
"github.com/krau/SaveAny-Bot/types"
)
func saveCmd(ctx *ext.Context, update *ext.Update) error {
res, ok := update.EffectiveMessage.GetReplyTo()
if !ok || res == nil {
ctx.Reply(update, ext.ReplyTextString("请回复要保存的文件"), nil)
return dispatcher.EndGroups
}
replyHeader, ok := res.(*tg.MessageReplyHeader)
if !ok {
ctx.Reply(update, ext.ReplyTextString("请回复要保存的文件"), nil)
return dispatcher.EndGroups
}
replyToMsgID, ok := replyHeader.GetReplyToMsgID()
if !ok {
ctx.Reply(update, ext.ReplyTextString("请回复要保存的文件"), nil)
return dispatcher.EndGroups
}
user, err := dao.GetUserByChatID(update.GetUserChat().GetID())
if err != nil {
common.Log.Errorf("获取用户失败: %s", err)
ctx.Reply(update, ext.ReplyTextString("获取用户失败"), nil)
return dispatcher.EndGroups
}
storages := storage.GetUserStorages(user.ChatID)
if len(storages) == 0 {
ctx.Reply(update, ext.ReplyTextString("无可用的存储"), nil)
return dispatcher.EndGroups
}
msg, err := GetTGMessage(ctx, update.EffectiveChat().GetID(), replyToMsgID)
if err != nil {
common.Log.Errorf("获取消息失败: %s", err)
ctx.Reply(update, ext.ReplyTextString("无法获取消息"), nil)
return dispatcher.EndGroups
}
supported, _ := supportedMediaFilter(msg)
if !supported {
ctx.Reply(update, ext.ReplyTextString("不支持的消息类型或消息中没有文件"), nil)
return dispatcher.EndGroups
}
replied, err := ctx.Reply(update, ext.ReplyTextString("正在获取文件信息..."), nil)
if err != nil {
common.Log.Errorf("回复失败: %s", err)
return dispatcher.EndGroups
}
cmdText := update.EffectiveMessage.Text
customFileName := strings.TrimSpace(strings.TrimPrefix(cmdText, "/save"))
file, err := FileFromMessage(ctx, update.EffectiveChat().GetID(), msg.ID, customFileName)
if err != nil {
common.Log.Errorf("获取文件失败: %s", err)
ctx.EditMessage(update.EffectiveChat().GetID(), &tg.MessagesEditMessageRequest{
Message: fmt.Sprintf("获取文件失败: %s", err),
ID: replied.ID,
})
return dispatcher.EndGroups
}
if file.FileName == "" {
file.FileName = GenFileNameFromMessage(*msg, file)
}
receivedFile := &dao.ReceivedFile{
Processing: false,
FileName: file.FileName,
ChatID: update.EffectiveChat().GetID(),
MessageID: replyToMsgID,
ReplyMessageID: replied.ID,
ReplyChatID: update.GetUserChat().GetID(),
}
if err := dao.SaveReceivedFile(receivedFile); err != nil {
common.Log.Errorf("保存接收的文件失败: %s", err)
if _, err := ctx.EditMessage(update.EffectiveChat().GetID(), &tg.MessagesEditMessageRequest{
Message: fmt.Sprintf("保存接收的文件失败: %s", err),
ID: replied.ID,
}); err != nil {
common.Log.Errorf("编辑消息失败: %s", err)
}
return dispatcher.EndGroups
}
if !user.Silent || user.DefaultStorage == "" {
return ProvideSelectMessage(ctx, update, file.FileName, update.EffectiveChat().GetID(), msg.ID, replied.ID)
}
return HandleSilentAddTask(ctx, update, user, &types.Task{
Ctx: ctx,
Status: types.Pending,
File: file,
StorageName: user.DefaultStorage,
FileChatID: update.EffectiveChat().GetID(),
ReplyMessageID: replied.ID,
ReplyChatID: update.GetUserChat().GetID(),
FileMessageID: msg.ID,
UserID: user.ChatID,
})
}

30
bot/handle_silent.go Normal file
View File

@@ -0,0 +1,30 @@
package bot
import (
"fmt"
"github.com/celestix/gotgproto/dispatcher"
"github.com/celestix/gotgproto/ext"
"github.com/krau/SaveAny-Bot/common"
"github.com/krau/SaveAny-Bot/dao"
)
func silent(ctx *ext.Context, update *ext.Update) error {
user, err := dao.GetUserByChatID(update.GetUserChat().GetID())
if err != nil {
common.Log.Errorf("获取用户失败: %s", err)
return dispatcher.EndGroups
}
if !user.Silent && user.DefaultStorage == "" {
ctx.Reply(update, ext.ReplyTextString("请先使用 /storage 设置默认存储位置"), nil)
return dispatcher.EndGroups
}
user.Silent = !user.Silent
if err := dao.UpdateUser(user); err != nil {
common.Log.Errorf("更新用户失败: %s", err)
ctx.Reply(update, ext.ReplyTextString("更新用户失败"), nil)
return dispatcher.EndGroups
}
ctx.Reply(update, ext.ReplyTextString(fmt.Sprintf("已%s静默模式", map[bool]string{true: "开启", false: "关闭"}[user.Silent])), nil)
return dispatcher.EndGroups
}

40
bot/handle_start.go Normal file
View File

@@ -0,0 +1,40 @@
package bot
import (
"fmt"
"github.com/celestix/gotgproto/dispatcher"
"github.com/celestix/gotgproto/ext"
"github.com/krau/SaveAny-Bot/common"
"github.com/krau/SaveAny-Bot/dao"
)
func start(ctx *ext.Context, update *ext.Update) error {
if err := dao.CreateUser(update.GetUserChat().GetID()); err != nil {
common.Log.Errorf("创建用户失败: %s", err)
return dispatcher.EndGroups
}
return help(ctx, update)
}
const helpText string = `
Save Any Bot - 转存你的 Telegram 文件
版本: %s , 提交: %s
命令:
/start - 开始使用
/help - 显示帮助
/silent - 开关静默模式
/storage - 设置默认存储位置
/save [自定义文件名] - 保存文件
静默模式: 开启后 Bot 直接保存到收到的文件到默认位置, 不再询问
默认存储位置: 在静默模式下保存到的位置
向 Bot 发送(转发)文件, 或发送一个公开频道的消息链接以保存文件
`
func help(ctx *ext.Context, update *ext.Update) error {
ctx.Reply(update, ext.ReplyTextString(fmt.Sprintf(helpText, common.Version, common.GitCommit[:7])), nil)
return dispatcher.EndGroups
}

99
bot/handle_storage.go Normal file
View File

@@ -0,0 +1,99 @@
package bot
import (
"fmt"
"strconv"
"strings"
"github.com/celestix/gotgproto/dispatcher"
"github.com/celestix/gotgproto/ext"
"github.com/gotd/td/tg"
"github.com/krau/SaveAny-Bot/common"
"github.com/krau/SaveAny-Bot/dao"
"github.com/krau/SaveAny-Bot/storage"
)
func storageCmd(ctx *ext.Context, update *ext.Update) error {
userChatID := update.GetUserChat().GetID()
storages := storage.GetUserStorages(userChatID)
if len(storages) == 0 {
ctx.Reply(update, ext.ReplyTextString("无可用的存储"), nil)
return dispatcher.EndGroups
}
markup, err := getSetDefaultStorageMarkup(userChatID, storages)
if err != nil {
common.Log.Errorf("Failed to get markup: %s", err)
ctx.Reply(update, ext.ReplyTextString("获取存储位置失败"), nil)
return dispatcher.EndGroups
}
ctx.Reply(update, ext.ReplyTextString("请选择要设为默认的存储位置"), &ext.ReplyOpts{
Markup: markup,
})
return dispatcher.EndGroups
}
func setDefaultStorage(ctx *ext.Context, update *ext.Update) error {
args := strings.Split(string(update.CallbackQuery.Data), " ")
userID, _ := strconv.Atoi(args[1])
if userID != int(update.CallbackQuery.GetUserID()) {
ctx.AnswerCallback(&tg.MessagesSetBotCallbackAnswerRequest{
QueryID: update.CallbackQuery.QueryID,
Alert: true,
Message: "你没有权限",
CacheTime: 5,
})
return dispatcher.EndGroups
}
cbDataId, _ := strconv.Atoi(args[2])
storageName, err := dao.GetCallbackData(uint(cbDataId))
if err != nil {
common.Log.Errorf("获取回调数据失败: %s", err)
ctx.AnswerCallback(&tg.MessagesSetBotCallbackAnswerRequest{
QueryID: update.CallbackQuery.QueryID,
Alert: true,
Message: "获取回调数据失败",
CacheTime: 5,
})
return dispatcher.EndGroups
}
selectedStorage, err := storage.GetStorageByName(storageName)
if err != nil {
common.Log.Errorf("获取指定存储失败: %s", err)
ctx.AnswerCallback(&tg.MessagesSetBotCallbackAnswerRequest{
QueryID: update.CallbackQuery.QueryID,
Alert: true,
Message: "获取指定存储失败",
CacheTime: 5,
})
return dispatcher.EndGroups
}
user, err := dao.GetUserByChatID(int64(userID))
if err != nil {
common.Log.Errorf("Failed to get user: %s", err)
ctx.AnswerCallback(&tg.MessagesSetBotCallbackAnswerRequest{
QueryID: update.CallbackQuery.QueryID,
Alert: true,
Message: "获取用户失败",
CacheTime: 5,
})
return dispatcher.EndGroups
}
user.DefaultStorage = storageName
if err := dao.UpdateUser(user); err != nil {
common.Log.Errorf("Failed to update user: %s", err)
ctx.AnswerCallback(&tg.MessagesSetBotCallbackAnswerRequest{
QueryID: update.CallbackQuery.QueryID,
Alert: true,
Message: "更新用户失败",
CacheTime: 5,
})
return dispatcher.EndGroups
}
ctx.EditMessage(update.EffectiveChat().GetID(), &tg.MessagesEditMessageRequest{
Message: fmt.Sprintf("已将 %s (%s) 设为默认存储位置", selectedStorage.Name(), selectedStorage.Type()),
ID: update.CallbackQuery.GetMsgID(),
})
return dispatcher.EndGroups
}

114
bot/handle_telegraph.go Normal file
View File

@@ -0,0 +1,114 @@
package bot
import (
"fmt"
"net/http"
"net/url"
"regexp"
"strings"
"time"
"github.com/celestix/gotgproto/dispatcher"
"github.com/celestix/gotgproto/ext"
"github.com/celestix/telegraph-go/v2"
"github.com/gotd/td/tg"
"github.com/krau/SaveAny-Bot/common"
"github.com/krau/SaveAny-Bot/config"
"github.com/krau/SaveAny-Bot/dao"
"github.com/krau/SaveAny-Bot/storage"
"github.com/krau/SaveAny-Bot/types"
)
var (
TelegraphClient *telegraph.TelegraphClient
TelegraphUrlRegexString = `https://telegra.ph/.*`
TelegraphUrlRegex = regexp.MustCompile(TelegraphUrlRegexString)
)
func InitTelegraphClient() {
var httpClient *http.Client
if config.Cfg.Telegram.Proxy.Enable {
proxyUrl, err := url.Parse(config.Cfg.Telegram.Proxy.URL)
if err != nil {
fmt.Println("Error parsing proxy URL:", err)
return
}
proxy := http.ProxyURL(proxyUrl)
httpClient = &http.Client{
Transport: &http.Transport{
Proxy: proxy,
},
Timeout: 30 * time.Second,
}
} else {
httpClient = &http.Client{
Timeout: 30 * time.Second,
}
}
TelegraphClient = telegraph.GetTelegraphClient(&telegraph.ClientOpt{HttpClient: httpClient})
}
func handleTelegraph(ctx *ext.Context, update *ext.Update) error {
common.Log.Trace("Got telegraph link")
tgphUrl := TelegraphUrlRegex.FindString(update.EffectiveMessage.Text)
if tgphUrl == "" {
return dispatcher.ContinueGroups
}
replied, err := ctx.Reply(update, ext.ReplyTextString("正在获取文件..."), nil)
if err != nil {
common.Log.Errorf("回复失败: %s", err)
return dispatcher.EndGroups
}
user, err := dao.GetUserByChatID(update.GetUserChat().GetID())
if err != nil {
common.Log.Errorf("获取用户失败: %s", err)
ctx.Reply(update, ext.ReplyTextString("获取用户失败"), nil)
return dispatcher.EndGroups
}
storages := storage.GetUserStorages(user.ChatID)
if len(storages) == 0 {
ctx.Reply(update, ext.ReplyTextString("无可用的存储"), nil)
return dispatcher.EndGroups
}
tgphPath := strings.Split(tgphUrl, "/")[len(strings.Split(tgphUrl, "/"))-1]
fileName, err := url.PathUnescape(tgphPath)
if err != nil {
common.Log.Errorf("解析 Telegraph 路径失败: %s", err)
fileName = tgphPath
}
record := &dao.ReceivedFile{
Processing: false,
FileName: fileName,
ChatID: update.EffectiveChat().GetID(),
MessageID: update.EffectiveMessage.GetID(),
ReplyMessageID: replied.ID,
ReplyChatID: update.EffectiveChat().GetID(),
IsTelegraph: true,
TelegraphURL: tgphUrl,
}
if err := dao.SaveReceivedFile(record); err != nil {
common.Log.Errorf("保存接收的文件失败: %s", err)
ctx.EditMessage(update.EffectiveChat().GetID(), &tg.MessagesEditMessageRequest{
Message: "无法保存文件: " + err.Error(),
ID: replied.ID,
})
return dispatcher.EndGroups
}
if !user.Silent || user.DefaultStorage == "" {
return ProvideSelectMessage(ctx, update, fileName, update.EffectiveChat().GetID(), update.EffectiveMessage.GetID(), replied.ID)
}
return HandleSilentAddTask(ctx, update, user, &types.Task{
Ctx: ctx,
Status: types.Pending,
StorageName: user.DefaultStorage,
UserID: user.ChatID,
ReplyMessageID: replied.ID,
ReplyChatID: update.GetUserChat().GetID(),
IsTelegraph: true,
TelegraphURL: tgphUrl,
})
}

View File

@@ -1,26 +1,10 @@
package bot
import (
"fmt"
"strconv"
"strings"
"github.com/duke-git/lancet/v2/slice"
"github.com/gookit/goutil/maputil"
"github.com/gotd/td/telegram/message/entity"
"github.com/gotd/td/telegram/message/styling"
"github.com/gotd/td/tg"
"github.com/celestix/gotgproto/dispatcher"
"github.com/celestix/gotgproto/dispatcher/handlers"
"github.com/celestix/gotgproto/dispatcher/handlers/filters"
"github.com/celestix/gotgproto/ext"
"github.com/krau/SaveAny-Bot/config"
"github.com/krau/SaveAny-Bot/dao"
"github.com/krau/SaveAny-Bot/logger"
"github.com/krau/SaveAny-Bot/queue"
"github.com/krau/SaveAny-Bot/storage"
"github.com/krau/SaveAny-Bot/types"
"github.com/krau/SaveAny-Bot/common"
)
func RegisterHandlers(dispatcher dispatcher.Dispatcher) {
@@ -28,390 +12,21 @@ func RegisterHandlers(dispatcher dispatcher.Dispatcher) {
dispatcher.AddHandler(handlers.NewCommand("start", start))
dispatcher.AddHandler(handlers.NewCommand("help", help))
dispatcher.AddHandler(handlers.NewCommand("silent", silent))
dispatcher.AddHandler(handlers.NewCommand("storage", setDefaultStorage))
dispatcher.AddHandler(handlers.NewCommand("storage", storageCmd))
dispatcher.AddHandler(handlers.NewCommand("save", saveCmd))
dispatcher.AddHandler(handlers.NewCommand("path", setPath))
dispatcher.AddHandler(handlers.NewCommand("dir", dirCmd))
linkRegexFilter, err := filters.Message.Regex(linkRegexString)
if err != nil {
logger.L.Panicf("Failed to create regex filter: %s", err)
common.Log.Panicf("创建正则表达式过滤器失败: %s", err)
}
dispatcher.AddHandler(handlers.NewMessage(linkRegexFilter, handleLinkMessage))
telegraphUrlRegexFilter, err := filters.Message.Regex(TelegraphUrlRegexString)
if err != nil {
common.Log.Panicf("创建 Telegraph URL 正则表达式过滤器失败: %s", err)
}
dispatcher.AddHandler(handlers.NewMessage(telegraphUrlRegexFilter, handleTelegraph))
dispatcher.AddHandler(handlers.NewCallbackQuery(filters.CallbackQuery.Prefix("add"), AddToQueue))
dispatcher.AddHandler(handlers.NewCallbackQuery(filters.CallbackQuery.Prefix("set_default"), setDefaultStorage))
dispatcher.AddHandler(handlers.NewCallbackQuery(filters.CallbackQuery.Prefix("cancel"), cancelTask))
dispatcher.AddHandler(handlers.NewMessage(filters.Message.Media, handleFileMessage))
}
const noPermissionText string = `
本 Bot 仅限个人使用.
您可以部署自己的实例: https://github.com/krau/SaveAny-Bot
`
func checkPermission(ctx *ext.Context, update *ext.Update) error {
userID := update.GetUserChat().GetID()
if !slice.Contain(config.Cfg.Telegram.Admins, userID) {
ctx.Reply(update, ext.ReplyTextString(noPermissionText), nil)
return dispatcher.EndGroups
}
return dispatcher.ContinueGroups
}
func start(ctx *ext.Context, update *ext.Update) error {
if err := dao.CreateUser(update.GetUserChat().GetID()); err != nil {
logger.L.Errorf("Failed to create user: %s", err)
return dispatcher.EndGroups
}
return help(ctx, update)
}
const helpText string = `
Save Any Bot - 转存你的 Telegram 文件
命令:
/start - 开始使用
/help - 显示帮助
/silent - 静默模式
/storage - 设置默认存储位置
/save [自定义文件名] - 保存文件
/path <存储类型> <路径> - 更改文件保存路径
静默模式: 开启后 Bot 直接保存到收到的文件到默认位置, 不再询问
默认存储位置: 在静默模式下保存到的位置
向 Bot 发送(转发)文件, 或发送一个公开频道的消息链接以保存文件
`
func help(ctx *ext.Context, update *ext.Update) error {
ctx.Reply(update, ext.ReplyTextString(helpText), nil)
return dispatcher.EndGroups
}
func silent(ctx *ext.Context, update *ext.Update) error {
user, err := dao.GetUserByUserID(update.GetUserChat().GetID())
if err != nil {
logger.L.Errorf("Failed to get user: %s", err)
return dispatcher.EndGroups
}
user.Silent = !user.Silent
if err := dao.UpdateUser(user); err != nil {
logger.L.Errorf("Failed to update user: %s", err)
return dispatcher.EndGroups
}
ctx.Reply(update, ext.ReplyTextString(fmt.Sprintf("已%s静默模式", map[bool]string{true: "开启", false: "关闭"}[user.Silent])), nil)
return dispatcher.EndGroups
}
func setDefaultStorage(ctx *ext.Context, update *ext.Update) error {
if len(storage.Storages) == 0 {
ctx.Reply(update, ext.ReplyTextString("未配置存储"), nil)
return dispatcher.EndGroups
}
args := strings.Split(update.EffectiveMessage.Text, " ")
avaliableStorages := maputil.Keys(storage.Storages)
if len(args) < 2 {
text := []styling.StyledTextOption{
styling.Plain("请提供存储位置名称, 可用项:"),
}
for _, name := range avaliableStorages {
text = append(text, styling.Plain("\n"))
text = append(text, styling.Code(name))
}
text = append(text, styling.Plain("\n示例: /storage local"))
ctx.Reply(update, ext.ReplyTextStyledTextArray(text), nil)
return dispatcher.EndGroups
}
storageName := args[1]
if !slice.Contain(avaliableStorages, storageName) {
ctx.Reply(update, ext.ReplyTextString("存储位置不存在"), nil)
return dispatcher.EndGroups
}
user, err := dao.GetUserByUserID(update.GetUserChat().GetID())
if err != nil {
logger.L.Errorf("Failed to get user: %s", err)
return dispatcher.EndGroups
}
user.DefaultStorage = storageName
if err := dao.UpdateUser(user); err != nil {
logger.L.Errorf("Failed to update user: %s", err)
return dispatcher.EndGroups
}
ctx.Reply(update, ext.ReplyTextString(fmt.Sprintf("已设置默认存储位置为 %s", storageName)), nil)
return dispatcher.EndGroups
}
func saveCmd(ctx *ext.Context, update *ext.Update) error {
res, ok := update.EffectiveMessage.GetReplyTo()
if !ok || res == nil {
ctx.Reply(update, ext.ReplyTextString("请回复要保存的文件"), nil)
return dispatcher.EndGroups
}
replyHeader, ok := res.(*tg.MessageReplyHeader)
if !ok {
ctx.Reply(update, ext.ReplyTextString("请回复要保存的文件"), nil)
return dispatcher.EndGroups
}
replyToMsgID, ok := replyHeader.GetReplyToMsgID()
if !ok {
ctx.Reply(update, ext.ReplyTextString("请回复要保存的文件"), nil)
return dispatcher.EndGroups
}
msg, err := GetTGMessage(ctx, update.EffectiveChat().GetID(), replyToMsgID)
if err != nil {
logger.L.Errorf("Failed to get message: %s", err)
ctx.Reply(update, ext.ReplyTextString("无法获取消息"), nil)
return dispatcher.EndGroups
}
supported, _ := supportedMediaFilter(msg)
if !supported {
ctx.Reply(update, ext.ReplyTextString("不支持的消息类型或消息中没有文件"), nil)
return dispatcher.EndGroups
}
user, err := dao.GetUserByUserID(update.GetUserChat().GetID())
if err != nil {
logger.L.Errorf("Failed to get user: %s", err)
return dispatcher.EndGroups
}
replied, err := ctx.Reply(update, ext.ReplyTextString("正在获取文件信息..."), nil)
if err != nil {
logger.L.Errorf("Failed to reply: %s", err)
return dispatcher.EndGroups
}
cmdText := update.EffectiveMessage.Text
customFileName := strings.TrimSpace(strings.TrimPrefix(cmdText, "/save"))
file, err := FileFromMessage(ctx, update.EffectiveChat().GetID(), msg.ID, customFileName)
if err != nil {
logger.L.Errorf("Failed to get file from message: %s", err)
ctx.EditMessage(update.EffectiveChat().GetID(), &tg.MessagesEditMessageRequest{
Message: fmt.Sprintf("获取文件失败: %s", err),
ID: replied.ID,
})
return dispatcher.EndGroups
}
if file.FileName == "" {
file.FileName = fmt.Sprintf("%d_%d_%s", update.EffectiveChat().GetID(), replyToMsgID, file.Hash())
}
receivedFile := &types.ReceivedFile{
Processing: false,
FileName: file.FileName,
ChatID: update.EffectiveChat().GetID(),
MessageID: replyToMsgID,
ReplyMessageID: replied.ID,
ReplyChatID: update.GetUserChat().GetID(),
}
if err := dao.SaveReceivedFile(receivedFile); err != nil {
logger.L.Errorf("Failed to save received file: %s", err)
if _, err := ctx.EditMessage(update.EffectiveChat().GetID(), &tg.MessagesEditMessageRequest{
Message: fmt.Sprintf("Failed to save received file: %s", err),
ID: replied.ID,
}); err != nil {
logger.L.Errorf("Failed to edit message: %s", err)
}
return dispatcher.EndGroups
}
if !user.Silent {
return ProvideSelectMessage(ctx, update, file, int(update.EffectiveChat().GetID()), msg.ID, replied.ID)
}
return HandleSilentAddTask(ctx, update, user, &types.Task{
Ctx: ctx,
Status: types.Pending,
File: file,
Storage: types.StorageType(user.DefaultStorage),
FileChatID: update.EffectiveChat().GetID(),
ReplyMessageID: replied.ID,
ReplyChatID: update.GetUserChat().GetID(),
FileMessageID: msg.ID,
})
}
func setPath(ctx *ext.Context, update *ext.Update) error {
if len(storage.Storages) == 0 {
ctx.Reply(update, ext.ReplyTextString("未配置存储"), nil)
return dispatcher.EndGroups
}
if update.EffectiveMessage == nil {
logger.L.Error("No effective message")
return dispatcher.EndGroups
}
args := strings.Split(update.EffectiveMessage.Text, " ")
if len(args) < 3 {
text := []styling.StyledTextOption{
styling.Plain("请提供存储位置名称和路径, 可用项:"),
}
for name := range storage.Storages {
text = append(text, styling.Plain("\n"))
text = append(text, styling.Code(string(name)))
}
text = append(text, styling.Plain("\n示例: /path local /path/to/save"))
ctx.Reply(update, ext.ReplyTextStyledTextArray(text), nil)
return dispatcher.EndGroups
}
storageName := args[1]
if _, ok := storage.Storages[types.StorageType(storageName)]; !ok {
ctx.Reply(update, ext.ReplyTextString("存储位置不存在"), nil)
return dispatcher.EndGroups
}
path := strings.Join(args[2:], " ")
switch storageName {
case "local":
config.Set("storage.local.base_path", path)
case "webdav":
config.Set("storage.webdav.base_path", path)
case "alist":
config.Set("storage.alist.base_path", path)
}
if err := config.ReloadConfig(); err != nil {
logger.L.Errorf("Failed to reload config: %s", err)
ctx.Reply(update, ext.ReplyTextString("设置失败: "+err.Error()), nil)
return dispatcher.EndGroups
}
ctx.Reply(update, ext.ReplyTextString("设置成功"), nil)
return dispatcher.EndGroups
}
func handleFileMessage(ctx *ext.Context, update *ext.Update) error {
logger.L.Trace("Got media: ", update.EffectiveMessage.Media.TypeName())
supported, err := supportedMediaFilter(update.EffectiveMessage.Message)
if err != nil {
return err
}
if !supported {
return dispatcher.EndGroups
}
user, err := dao.GetUserByUserID(update.GetUserChat().GetID())
if err != nil {
logger.L.Errorf("Failed to get user: %s", err)
return dispatcher.EndGroups
}
msg, err := ctx.Reply(update, ext.ReplyTextString("正在获取文件信息..."), nil)
if err != nil {
logger.L.Errorf("Failed to reply: %s", err)
return dispatcher.EndGroups
}
media := update.EffectiveMessage.Media
file, err := FileFromMedia(media, "")
if err != nil {
logger.L.Errorf("Failed to get file from media: %s", err)
ctx.Reply(update, ext.ReplyTextString(fmt.Sprintf("获取文件失败: %s", err)), nil)
return dispatcher.EndGroups
}
if file.FileName == "" {
file.FileName = fmt.Sprintf("%d_%d_%s", update.EffectiveChat().GetID(), update.EffectiveMessage.ID, file.Hash())
}
if err := dao.SaveReceivedFile(&types.ReceivedFile{
Processing: false,
FileName: file.FileName,
ChatID: update.EffectiveChat().GetID(),
MessageID: update.EffectiveMessage.ID,
ReplyMessageID: msg.ID,
ReplyChatID: update.GetUserChat().GetID(),
}); err != nil {
logger.L.Errorf("Failed to add received file: %s", err)
if _, err := ctx.EditMessage(update.EffectiveChat().GetID(), &tg.MessagesEditMessageRequest{
Message: fmt.Sprintf("Failed to add received file: %s", err),
ID: msg.ID,
}); err != nil {
logger.L.Errorf("Failed to edit message: %s", err)
}
return dispatcher.EndGroups
}
if !user.Silent {
return ProvideSelectMessage(ctx, update, file, int(update.EffectiveChat().GetID()), update.EffectiveMessage.ID, msg.ID)
}
return HandleSilentAddTask(ctx, update, user, &types.Task{
Ctx: ctx,
Status: types.Pending,
File: file,
Storage: types.StorageType(user.DefaultStorage),
FileChatID: update.EffectiveChat().GetID(),
ReplyMessageID: msg.ID,
ReplyChatID: update.GetUserChat().GetID(),
FileMessageID: update.EffectiveMessage.ID,
})
}
func AddToQueue(ctx *ext.Context, update *ext.Update) error {
if !slice.Contain(config.Cfg.Telegram.Admins, update.CallbackQuery.UserID) {
ctx.AnswerCallback(&tg.MessagesSetBotCallbackAnswerRequest{
QueryID: update.CallbackQuery.QueryID,
Alert: true,
Message: "你没有权限",
CacheTime: 5,
})
return dispatcher.EndGroups
}
args := strings.Split(string(update.CallbackQuery.Data), " ")
chatID, _ := strconv.Atoi(args[1])
messageID, _ := strconv.Atoi(args[2])
storageName := args[3]
logger.L.Tracef("Got add to queue: chatID: %d, messageID: %d, storage: %s", chatID, messageID, storageName)
record, err := dao.GetReceivedFileByChatAndMessageID(int64(chatID), messageID)
if err != nil {
logger.L.Errorf("Failed to get received file: %s", err)
ctx.AnswerCallback(&tg.MessagesSetBotCallbackAnswerRequest{
QueryID: update.CallbackQuery.QueryID,
Alert: true,
Message: "查询记录失败",
CacheTime: 5,
})
return dispatcher.EndGroups
}
if update.CallbackQuery.MsgID != record.ReplyMessageID {
record.ReplyMessageID = update.CallbackQuery.MsgID
if err := dao.SaveReceivedFile(record); err != nil {
logger.L.Errorf("Failed to update received file: %s", err)
}
}
file, err := FileFromMessage(ctx, record.ChatID, record.MessageID, record.FileName)
if err != nil {
logger.L.Errorf("Failed to get file from message: %s", err)
ctx.AnswerCallback(&tg.MessagesSetBotCallbackAnswerRequest{
QueryID: update.CallbackQuery.QueryID,
Alert: true,
Message: fmt.Sprintf("获取消息中的文件失败: %s", err),
CacheTime: 5,
})
return dispatcher.EndGroups
}
queue.AddTask(types.Task{
Ctx: ctx,
Status: types.Pending,
File: file,
Storage: types.StorageType(storageName),
FileChatID: record.ChatID,
ReplyMessageID: record.ReplyMessageID,
FileMessageID: record.MessageID,
ReplyChatID: record.ReplyChatID,
})
entityBuilder := entity.Builder{}
var entities []tg.MessageEntityClass
text := fmt.Sprintf("已添加到任务队列\n文件名: %s\n当前排队任务数: %d", record.FileName, queue.Len())
if err := styling.Perform(&entityBuilder,
styling.Plain("已添加到任务队列\n文件名: "),
styling.Code(record.FileName),
styling.Plain("\n当前排队任务数: "),
styling.Bold(strconv.Itoa(queue.Len())),
); err != nil {
logger.L.Errorf("Failed to build entity: %s", err)
} else {
text, entities = entityBuilder.Complete()
}
ctx.EditMessage(update.EffectiveChat().GetID(), &tg.MessagesEditMessageRequest{
Message: text,
Entities: entities,
ID: record.ReplyMessageID,
})
return dispatcher.EndGroups
}

View File

@@ -3,9 +3,13 @@ package bot
import (
"time"
"github.com/celestix/gotgproto/dispatcher"
"github.com/celestix/gotgproto/ext"
"github.com/duke-git/lancet/v2/slice"
"github.com/gotd/contrib/middleware/floodwait"
"github.com/gotd/contrib/middleware/ratelimit"
"github.com/gotd/td/telegram"
"github.com/krau/SaveAny-Bot/config"
"golang.org/x/time/rate"
)
@@ -17,3 +21,17 @@ func FloodWaitMiddleware() []telegram.Middleware {
ratelimiter,
}
}
const noPermissionText string = `
您不在白名单中, 无法使用此 Bot.
您可以部署自己的实例: https://github.com/krau/SaveAny-Bot
`
func checkPermission(ctx *ext.Context, update *ext.Update) error {
userID := update.GetUserChat().GetID()
if !slice.Contain(config.Cfg.GetUsersID(), userID) {
ctx.Reply(update, ext.ReplyTextString(noPermissionText), nil)
return dispatcher.EndGroups
}
return dispatcher.ContinueGroups
}

View File

@@ -3,15 +3,18 @@ package bot
import (
"errors"
"fmt"
"strconv"
"strings"
"time"
"github.com/celestix/gotgproto/dispatcher"
"github.com/celestix/gotgproto/ext"
"github.com/gabriel-vasile/mimetype"
"github.com/gotd/td/telegram/message/entity"
"github.com/gotd/td/telegram/message/styling"
"github.com/gotd/td/tg"
"github.com/krau/SaveAny-Bot/common"
"github.com/krau/SaveAny-Bot/logger"
"github.com/krau/SaveAny-Bot/dao"
"github.com/krau/SaveAny-Bot/queue"
"github.com/krau/SaveAny-Bot/storage"
"github.com/krau/SaveAny-Bot/types"
@@ -22,6 +25,8 @@ var (
ErrEmptyPhoto = errors.New("photo is empty")
ErrEmptyPhotoSize = errors.New("photo size is empty")
ErrEmptyPhotoSizes = errors.New("photo size slice is empty")
ErrNoStorages = errors.New("no available storage")
ErrEmptyMessage = errors.New("message is empty")
)
func supportedMediaFilter(m *tg.Message) (bool, error) {
@@ -38,49 +43,80 @@ func supportedMediaFilter(m *tg.Message) (bool, error) {
}
}
var StorageDisplayNames = map[string]string{
"all": "全部",
"local": "服务器磁盘",
"alist": "Alist",
"webdav": "WebDAV",
}
func getSelectStorageMarkup(userChatID int64, fileChatID, fileMessageID int) (*tg.ReplyInlineMarkup, error) {
user, err := dao.GetUserByChatID(userChatID)
if err != nil {
return nil, fmt.Errorf("failed to get user by chat ID: %d, error: %w", userChatID, err)
}
storages := storage.GetUserStorages(user.ChatID)
if len(storages) == 0 {
return nil, ErrNoStorages
}
func getAddTaskMarkup(chatID, messageID int) *tg.ReplyInlineMarkup {
storageButtons := make([]tg.KeyboardButtonClass, 0)
for _, name := range storage.StorageKeys {
storageButtons = append(storageButtons, &tg.KeyboardButtonCallback{
Text: StorageDisplayNames[string(name)],
Data: []byte(fmt.Sprintf("add %d %d %s", chatID, messageID, name)),
buttons := make([]tg.KeyboardButtonClass, 0)
for _, storage := range storages {
cbData := fmt.Sprintf("%d %d %s 0", fileChatID, fileMessageID, storage.Name()) // 0 for empty dir id
cbDataId, err := dao.CreateCallbackData(cbData)
if err != nil {
return nil, fmt.Errorf("failed to create callback data: %w", err)
}
buttons = append(buttons, &tg.KeyboardButtonCallback{
Text: storage.Name(),
Data: []byte(fmt.Sprintf("add %d", cbDataId)),
})
}
markup := &tg.ReplyInlineMarkup{}
for i := 0; i < len(buttons); i += 3 {
row := tg.KeyboardButtonRow{}
row.Buttons = buttons[i:min(i+3, len(buttons))]
markup.Rows = append(markup.Rows, row)
}
return markup, nil
}
if len(storageButtons) < 1 {
return nil
}
if len(storageButtons) == 1 {
return &tg.ReplyInlineMarkup{
Rows: []tg.KeyboardButtonRow{
{
Buttons: storageButtons,
},
},
func getSelectDirMarkup(fileChatID, fileMessageID int, storageName string, dirs []dao.Dir) (*tg.ReplyInlineMarkup, error) {
buttons := make([]tg.KeyboardButtonClass, 0)
for _, dir := range dirs {
if dir.ID == 0 || dir.StorageName != storageName {
return nil, fmt.Errorf("unexpected dir: %v", dir)
}
cbDataId, err := dao.CreateCallbackData(fmt.Sprintf("%d %d %s %d", fileChatID, fileMessageID, storageName, dir.ID))
if err != nil {
return nil, fmt.Errorf("failed to create callback data: %w", err)
}
buttons = append(buttons, &tg.KeyboardButtonCallback{
Text: dir.Path,
Data: []byte(fmt.Sprintf("add_to_dir %d", cbDataId)),
})
}
return &tg.ReplyInlineMarkup{
Rows: []tg.KeyboardButtonRow{
{
Buttons: storageButtons,
},
{
Buttons: []tg.KeyboardButtonClass{
&tg.KeyboardButtonCallback{
Text: "全部",
Data: []byte(fmt.Sprintf("add %d %d all", chatID, messageID)),
},
},
},
},
markup := &tg.ReplyInlineMarkup{}
for i := 0; i < len(buttons); i += 3 {
row := tg.KeyboardButtonRow{}
row.Buttons = buttons[i:min(i+3, len(buttons))]
markup.Rows = append(markup.Rows, row)
}
return markup, nil
}
func getSetDefaultStorageMarkup(userChatID int64, storages []storage.Storage) (*tg.ReplyInlineMarkup, error) {
buttons := make([]tg.KeyboardButtonClass, 0)
for _, storage := range storages {
cbDataId, err := dao.CreateCallbackData(storage.Name())
if err != nil {
return nil, fmt.Errorf("failed to create callback data: %w", err)
}
buttons = append(buttons, &tg.KeyboardButtonCallback{
Text: storage.Name(),
Data: []byte(fmt.Sprintf("set_default %d %d", userChatID, cbDataId)),
})
}
markup := &tg.ReplyInlineMarkup{}
for i := 0; i < len(buttons); i += 3 {
row := tg.KeyboardButtonRow{}
row.Buttons = buttons[i:min(i+3, len(buttons))]
markup.Rows = append(markup.Rows, row)
}
return markup, nil
}
func FileFromMedia(media tg.MessageMediaClass, customFileName string) (*types.File, error) {
@@ -144,7 +180,7 @@ func FileFromMedia(media tg.MessageMediaClass, customFileName string) (*types.Fi
func FileFromMessage(ctx *ext.Context, chatID int64, messageID int, customFileName string) (*types.File, error) {
key := fmt.Sprintf("file:%d:%d", chatID, messageID)
logger.L.Debugf("Getting file: %s", key)
common.Log.Debugf("Getting file: %s", key)
var cachedFile types.File
err := common.Cache.Get(key, &cachedFile)
if err == nil {
@@ -159,19 +195,19 @@ func FileFromMessage(ctx *ext.Context, chatID int64, messageID int, customFileNa
return nil, err
}
if err := common.Cache.Set(key, file, 3600); err != nil {
logger.L.Errorf("Failed to cache file: %s", err)
common.Log.Errorf("Failed to cache file: %s", err)
}
return file, nil
}
func GetTGMessage(ctx *ext.Context, chatId int64, messageID int) (*tg.Message, error) {
logger.L.Debugf("Fetching message: %d", messageID)
common.Log.Debugf("Fetching message: %d", messageID)
messages, err := ctx.GetMessages(chatId, []tg.InputMessageClass{&tg.InputMessageID{ID: messageID}})
if err != nil {
return nil, err
}
if len(messages) == 0 {
return nil, errors.New("no messages found")
return nil, ErrEmptyMessage
}
msg := messages[0]
tgMessage, ok := msg.(*tg.Message)
@@ -181,32 +217,48 @@ func GetTGMessage(ctx *ext.Context, chatId int64, messageID int) (*tg.Message, e
return tgMessage, nil
}
func ProvideSelectMessage(ctx *ext.Context, update *ext.Update, file *types.File, chatID int, fileMsgID, toEditMsgID int) error {
func ProvideSelectMessage(ctx *ext.Context, update *ext.Update, fileName string, chatID int64, fileMsgID, toEditMsgID int) error {
entityBuilder := entity.Builder{}
var entities []tg.MessageEntityClass
text := fmt.Sprintf("文件名: %s\n请选择存储位置", file.FileName)
text := fmt.Sprintf("文件名: %s\n请选择存储位置", fileName)
if err := styling.Perform(&entityBuilder,
styling.Plain("文件名: "),
styling.Code(file.FileName),
styling.Code(fileName),
styling.Plain("\n请选择存储位置"),
); err != nil {
logger.L.Errorf("Failed to build entity: %s", err)
common.Log.Errorf("Failed to build entity: %s", err)
} else {
text, entities = entityBuilder.Complete()
}
_, err := ctx.EditMessage(update.EffectiveChat().GetID(), &tg.MessagesEditMessageRequest{
markup, err := getSelectStorageMarkup(update.GetUserChat().GetID(), int(chatID), fileMsgID)
if errors.Is(err, ErrNoStorages) {
common.Log.Errorf("Failed to get select storage markup: %s", err)
ctx.EditMessage(update.EffectiveChat().GetID(), &tg.MessagesEditMessageRequest{
Message: "无可用存储",
ID: toEditMsgID,
})
return dispatcher.EndGroups
} else if err != nil {
common.Log.Errorf("Failed to get select storage markup: %s", err)
ctx.EditMessage(update.EffectiveChat().GetID(), &tg.MessagesEditMessageRequest{
Message: "无法获取存储",
ID: toEditMsgID,
})
return dispatcher.EndGroups
}
_, err = ctx.EditMessage(update.EffectiveChat().GetID(), &tg.MessagesEditMessageRequest{
Message: text,
Entities: entities,
ReplyMarkup: getAddTaskMarkup(chatID, fileMsgID),
ReplyMarkup: markup,
ID: toEditMsgID,
})
if err != nil {
logger.L.Errorf("Failed to reply: %s", err)
common.Log.Errorf("Failed to reply: %s", err)
}
return dispatcher.EndGroups
}
func HandleSilentAddTask(ctx *ext.Context, update *ext.Update, user *types.User, task *types.Task) error {
func HandleSilentAddTask(ctx *ext.Context, update *ext.Update, user *dao.User, task *types.Task) error {
if user.DefaultStorage == "" {
ctx.EditMessage(update.EffectiveChat().GetID(), &tg.MessagesEditMessageRequest{
Message: "请先使用 /storage 设置默认存储位置",
@@ -214,10 +266,57 @@ func HandleSilentAddTask(ctx *ext.Context, update *ext.Update, user *types.User,
})
return dispatcher.EndGroups
}
queue.AddTask(*task)
queue.AddTask(task)
ctx.EditMessage(update.EffectiveChat().GetID(), &tg.MessagesEditMessageRequest{
Message: fmt.Sprintf("已添加到队列: %s\n当前排队任务数: %d", task.FileName(), queue.Len()),
ID: task.ReplyMessageID,
})
return dispatcher.EndGroups
}
func GenFileNameFromMessage(message tg.Message, file *types.File) string {
if file.FileName != "" {
return file.FileName
}
fileName := genFileNameFromMessageText(message, file)
media, ok := message.GetMedia()
if !ok {
return fileName
}
ext, ok := extraMediaExt(media)
if ok {
return fileName + ext
}
return fileName
}
func genFileNameFromMessageText(message tg.Message, file *types.File) string {
text := strings.TrimSpace(message.GetMessage())
if text == "" {
return file.Hash()
}
tags := common.ExtractTagsFromText(text)
if len(tags) > 0 {
return fmt.Sprintf("%s_%s", strings.Join(tags, "_"), strconv.Itoa(message.GetID()))
}
runes := []rune(text)
return string(runes[:min(128, len(runes))])
}
func extraMediaExt(media tg.MessageMediaClass) (string, bool) {
switch media := media.(type) {
case *tg.MessageMediaDocument:
doc, ok := media.Document.AsNotEmpty()
if !ok {
return "", false
}
ext := mimetype.Lookup(doc.MimeType).Extension()
if ext == "" {
return "", false
}
return ext, true
case *tg.MessageMediaPhoto:
return ".jpg", true
}
return "", false
}

View File

@@ -1,51 +1,67 @@
package cmd
import (
"fmt"
"os"
"os/signal"
"path/filepath"
"syscall"
"github.com/krau/SaveAny-Bot/bootstrap"
"slices"
"github.com/krau/SaveAny-Bot/bot"
"github.com/krau/SaveAny-Bot/common"
"github.com/krau/SaveAny-Bot/config"
"github.com/krau/SaveAny-Bot/core"
"github.com/krau/SaveAny-Bot/logger"
"github.com/krau/SaveAny-Bot/dao"
"github.com/krau/SaveAny-Bot/storage"
"github.com/spf13/cobra"
)
func Run(_ *cobra.Command, _ []string) {
bootstrap.InitAll()
InitAll()
core.Run()
quit := make(chan os.Signal, 1)
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
sig := <-quit
logger.L.Info(sig, ", exitting...")
defer logger.L.Info("Bye!")
common.Log.Info(sig, ", exitting...")
defer common.Log.Info("Bye!")
if config.Cfg.NoCleanCache {
return
}
if config.Cfg.Temp.BasePath != "" {
for _, path := range []string{"/", ".", "\\", ".."} {
if filepath.Clean(config.Cfg.Temp.BasePath) == path {
logger.L.Error("Invalid cache dir: ", config.Cfg.Temp.BasePath)
return
}
if config.Cfg.Temp.BasePath != "" && !config.Cfg.Stream {
if slices.Contains([]string{"/", ".", "\\", ".."}, filepath.Clean(config.Cfg.Temp.BasePath)) {
common.Log.Error("无效的缓存文件夹: ", config.Cfg.Temp.BasePath)
return
}
currentDir, err := os.Getwd()
if err != nil {
logger.L.Error("Failed to get current dir: ", err)
common.Log.Error("获取工作目录失败: ", err)
return
}
cachePath := filepath.Join(currentDir, config.Cfg.Temp.BasePath)
cachePath, err = filepath.Abs(cachePath)
if err != nil {
logger.L.Error("Failed to get absolute path: ", err)
common.Log.Error("获取缓存绝对路径失败: ", err)
return
}
logger.L.Info("Cleaning cache dir: ", cachePath)
common.Log.Info("正在清理缓存文件夹: ", cachePath)
if err := os.RemoveAll(cachePath); err != nil {
logger.L.Error("Failed to clean cache dir: ", err)
common.Log.Error("清理缓存失败: ", err)
}
}
}
func InitAll() {
if err := config.Init(); err != nil {
fmt.Println("加载配置文件失败: ", err)
os.Exit(1)
}
common.InitLogger()
common.Log.Info("正在启动 SaveAny-Bot...")
dao.Init()
storage.LoadStorages()
common.Init()
bot.Init()
}

View File

@@ -24,7 +24,7 @@ func initCache() {
Cache = &CommonCache{cache: freecache.NewCache(10 * 1024 * 1024)}
}
func (c *CommonCache) Get(key string, value *types.File) error {
func (c *CommonCache) Get(key string, value any) error {
c.mu.RLock()
defer c.mu.RUnlock()
data, err := Cache.cache.Get([]byte(key))
@@ -39,7 +39,7 @@ func (c *CommonCache) Get(key string, value *types.File) error {
return nil
}
func (c *CommonCache) Set(key string, value *types.File, expireSeconds int) error {
func (c *CommonCache) Set(key string, value any, expireSeconds int) error {
c.mu.Lock()
defer c.mu.Unlock()
var buf bytes.Buffer

View File

@@ -1,21 +1,20 @@
package logger
package common
import (
"github.com/krau/SaveAny-Bot/config"
"github.com/gookit/slog"
"github.com/gookit/slog/handler"
"github.com/gookit/slog/rotatefile"
"github.com/krau/SaveAny-Bot/config"
)
var L *slog.Logger
var Log *slog.Logger
func InitLogger() {
if L != nil {
if Log != nil {
return
}
slog.DefaultChannelName = "SaveAnyBot"
L = slog.New()
Log = slog.New()
logLevel := slog.LevelByName(config.Cfg.Log.Level)
logFilePath := config.Cfg.Log.File
logBackupNum := config.Cfg.Log.BackupCount
@@ -36,5 +35,5 @@ func InitLogger() {
if err != nil {
panic(err)
}
L.AddHandlers(consoleH, fileH)
Log.AddHandlers(consoleH, fileH)
}

View File

@@ -5,8 +5,6 @@ import (
"os"
"path/filepath"
"time"
"github.com/krau/SaveAny-Bot/logger"
)
// 创建文件, 自动创建目录
@@ -31,10 +29,10 @@ func PurgeFile(path string) error {
func RmFileAfter(path string, td time.Duration) {
_, err := os.Stat(path)
if err != nil {
logger.L.Errorf("Failed to create timer for %s: %s", path, err)
Log.Errorf("Failed to create timer for %s: %s", path, err)
return
}
logger.L.Debugf("Remove file after %s: %s", td, path)
Log.Debugf("Remove file after %s: %s", td, path)
time.AfterFunc(td, func() {
PurgeFile(path)
})

26
common/utils.go Normal file
View File

@@ -0,0 +1,26 @@
package common
import (
"crypto/md5"
"encoding/hex"
"regexp"
)
func HashString(s string) string {
hash := md5.New()
hash.Write([]byte(s))
return hex.EncodeToString(hash.Sum(nil))
}
var TagRe = regexp.MustCompile(`(?:^|[\p{Zs}\s.,!?(){}[\]<>\"\',。!?():;、])#([\p{L}\d_]+)`)
func ExtractTagsFromText(text string) []string {
matches := TagRe.FindAllStringSubmatch(text, -1)
tags := make([]string, 0)
for _, match := range matches {
if len(match) > 1 {
tags = append(tags, match[1])
}
}
return tags
}

View File

@@ -1,53 +1,100 @@
#创建文件时,若需要保留中文注释,请务必确保本文件编码为 UTF-8 ,否则会无法读取。
workers = 4 # 同时下载文件数
retry = 3 # 下载失败重试次数
threads = 4 # 单个任务下载最大线程数
stream = false # 使用stream模式, 详情请查看文档
[telegram]
# Bot Token
# 更换 Bot Token 后请删除数据库文件和 session.db
token = ""
# 允许使用的用户 id 列表
admins = [777000]
# Telegram API 配置, 若不配置也可运行, 将使用默认的 API ID 和 API HASH
# 推荐使用自己的 API ID 和 API HASH (https://my.telegram.org)
# app_id = 123456
# app_id = 123456
# app_hash = "0123456789abcdef0123456789abcdef"
# 初始化超时时间, 单位: 秒
timeout = 60
[telegram.proxy]
# 启用代理连接 telegram, 只支持 socks5
enable = false
url = "socks5://127.0.0.1:7890"
[storage]
[storage.alist] # Alist
# 存储配置列表
[[storages]]
# 标识名, 需要唯一
name = "本机1"
# 存储类型, 目前可用: local, alist, webdav, minio
type = "local"
# 启用存储
enable = true
base_path = "/telegram" # 保存路径
username = "admin" # 用户名
password = "password" # 密码
url = "https://alist.com" # Alist 地址
token_exp = 86400 # token 过期时间, 单位: 秒
# 可直接使用 token 授权, 此时不能自动刷新登录信息
# 配置 token 后, username , password , token_exp 将被忽略
token = "jwt_token"
# 文件保存路径
base_path = "./downloads"
[storage.local] # 本地磁盘
[[storages]]
name = "MyAlist"
type = "alist"
enable = false #记得启用
base_path = '/'
url = 'https://alist.com'
username = 'admin'
password = 'password'
token_exp = 86400 # 86400--1天 604800--7天 1296000--15天 2592000--30天 15552000--180天
# alist 可直接使用 token 登录, 此时 username, password, token_exp 将被忽略
# 请自行在 alist 侧配置合理的 token 过期时间
# token = ""
[[storages]]
name = "MyWebdav"
type = "webdav"
enable = false
base_path = '/path/telegram'
url = 'https://example.com/dav'
username = 'username'
password = 'password'
[[storages]]
name = "MyMinio"
type = "minio"
enable = true
base_path = "downloads/" # 保存路径
endpoint = 'play.min.io'
use_ssl = true
access_key_id = 'Q3AM3UQ867SPQQA43P2F'
secret_access_key = 'zuf+tfteSlswRu7BJ86wekitnifILbZam1KYY3TG'
bucket_name = 'saveanybot'
base_path = '/path/telegram'
[storage.webdav] # WebDav
enable = true
base_path = "/telegram"
username = "admin"
password = "password"
url = "https://alist.com/dav"
# 用户列表
[[users]]
# telegram user id
id = 114514
# 开启黑名单,开启后下方留空以使用所有存储,反之则为白名单,白名单请在下方输入允许的存储名
blacklist = true
# 将列表留空并开启黑名单模式以允许使用所有存储此处示例为黑名单模式用户114514 可使用所有存储
storages = []
[log]
# 日志等级
level = "DEBUG"
[[users]]
id = 123456
blacklist = false #开启白名单模式此时用户123456 仅可使用下方列表中的存储
# 此时该用户只能使用名为 本机1 的存储
storages = ["本机1"]
[temp]
base_path = "cache/" # 下载文件临时目录, 请不要在此目录下存放任何其他文件
cache_ttl = 30 # 临时文件保存时间, 单位: 秒
[db]
path = "data/data.db" # 数据库文件路径
# 其他配置
# [log]
# # 日志等级
# level = "DEBUG"
# [temp]
# # 下载文件临时目录, 请不要在此目录下存放任何其他文件
# base_path = "cache/"
# # 临时文件保存时间, 单位: 秒
# cache_ttl = 30
# [db]
# path = "data/data.db" # 数据库文件路径

38
config/storage/alist.go Normal file
View File

@@ -0,0 +1,38 @@
package storage
import (
"fmt"
"github.com/krau/SaveAny-Bot/types"
)
type AlistStorageConfig struct {
BaseConfig
URL string `toml:"url" mapstructure:"url" json:"url"`
Username string `toml:"username" mapstructure:"username" json:"username"`
Password string `toml:"password" mapstructure:"password" json:"password"`
Token string `toml:"token" mapstructure:"token" json:"token"`
BasePath string `toml:"base_path" mapstructure:"base_path" json:"base_path"`
TokenExp int64 `toml:"token_exp" mapstructure:"token_exp" json:"token_exp"`
}
func (a *AlistStorageConfig) Validate() error {
if a.URL == "" {
return fmt.Errorf("url is required for alist storage")
}
if a.Token == "" && (a.Username == "" || a.Password == "") {
return fmt.Errorf("username and password or token is required for alist storage")
}
if a.BasePath == "" {
return fmt.Errorf("base_path is required for alist storage")
}
return nil
}
func (a *AlistStorageConfig) GetType() types.StorageType {
return types.StorageTypeAlist
}
func (a *AlistStorageConfig) GetName() string {
return a.Name
}

63
config/storage/factory.go Normal file
View File

@@ -0,0 +1,63 @@
package storage
import (
"fmt"
"reflect"
"github.com/krau/SaveAny-Bot/types"
"github.com/mitchellh/mapstructure"
"github.com/spf13/viper"
)
var storageFactories = map[types.StorageType]func(cfg *BaseConfig) (StorageConfig, error){
types.StorageTypeLocal: createStorageConfig(&LocalStorageConfig{}),
types.StorageTypeAlist: createStorageConfig(&AlistStorageConfig{}),
types.StorageTypeWebdav: createStorageConfig(&WebdavStorageConfig{}),
types.StorageTypeMinio: createStorageConfig(&MinioStorageConfig{}),
}
func createStorageConfig(configType StorageConfig) func(cfg *BaseConfig) (StorageConfig, error) {
return func(cfg *BaseConfig) (StorageConfig, error) {
configValue := reflect.New(reflect.TypeOf(configType).Elem()).Interface().(StorageConfig)
reflect.ValueOf(configValue).Elem().FieldByName("BaseConfig").Set(reflect.ValueOf(*cfg))
if err := mapstructure.Decode(cfg.RawConfig, configValue); err != nil {
return nil, fmt.Errorf("failed to decode %s storage config: %w", cfg.Type, err)
}
return configValue, nil
}
}
func LoadStorageConfigs(v *viper.Viper) ([]StorageConfig, error) {
var baseConfigs []BaseConfig
if err := v.UnmarshalKey("storages", &baseConfigs); err != nil {
return nil, fmt.Errorf("failed to unmarshal storage configs: %w", err)
}
var configs []StorageConfig
for _, baseCfg := range baseConfigs {
if !baseCfg.Enable {
continue
}
factory, ok := storageFactories[types.StorageType(baseCfg.Type)]
if !ok {
return nil, fmt.Errorf("unsupported storage type: %s", baseCfg.Type)
}
cfg, err := factory(&baseCfg)
if err != nil {
return nil, fmt.Errorf("failed to create storage config for %s: %w", baseCfg.Name, err)
}
if err := cfg.Validate(); err != nil {
return nil, fmt.Errorf("invalid storage config for %s: %w", baseCfg.Name, err)
}
configs = append(configs, cfg)
}
return configs, nil
}

27
config/storage/local.go Normal file
View File

@@ -0,0 +1,27 @@
package storage
import (
"fmt"
"github.com/krau/SaveAny-Bot/types"
)
type LocalStorageConfig struct {
BaseConfig
BasePath string `toml:"base_path" mapstructure:"base_path" json:"base_path"`
}
func (l *LocalStorageConfig) Validate() error {
if l.BasePath == "" {
return fmt.Errorf("path is required for local storage")
}
return nil
}
func (l *LocalStorageConfig) GetType() types.StorageType {
return types.StorageTypeLocal
}
func (l *LocalStorageConfig) GetName() string {
return l.Name
}

41
config/storage/minio.go Normal file
View File

@@ -0,0 +1,41 @@
package storage
import (
"fmt"
"github.com/krau/SaveAny-Bot/types"
)
type MinioStorageConfig struct {
BaseConfig
Endpoint string `toml:"endpoint" mapstructure:"endpoint" json:"endpoint"`
AccessKeyID string `toml:"access_key_id" mapstructure:"access_key_id" json:"access_key_id"`
SecretAccessKey string `toml:"secret_access_key" mapstructure:"secret_access_key" json:"secret_access_key"`
BucketName string `toml:"bucket_name" mapstructure:"bucket_name" json:"bucket_name"`
UseSSL bool `toml:"use_ssl" mapstructure:"use_ssl" json:"use_ssl"`
BasePath string `toml:"base_path" mapstructure:"base_path" json:"base_path"`
}
func (m *MinioStorageConfig) Validate() error {
if m.Endpoint == "" {
return fmt.Errorf("endpoint is required for minio storage")
}
if m.AccessKeyID == "" || m.SecretAccessKey == "" {
return fmt.Errorf("access_key_id and secret_access_key are required for minio storage")
}
if m.BucketName == "" {
return fmt.Errorf("bucket_name is required for minio storage")
}
if m.BasePath == "" {
return fmt.Errorf("base_path is required for minio storage")
}
return nil
}
func (m *MinioStorageConfig) GetType() types.StorageType {
return types.StorageTypeMinio
}
func (m *MinioStorageConfig) GetName() string {
return m.Name
}

16
config/storage/types.go Normal file
View File

@@ -0,0 +1,16 @@
package storage
import "github.com/krau/SaveAny-Bot/types"
type StorageConfig interface {
Validate() error
GetType() types.StorageType
GetName() string
}
type BaseConfig struct {
Name string `toml:"name" mapstructure:"name" json:"name"`
Type string `toml:"type" mapstructure:"type" json:"type"`
Enable bool `toml:"enable" mapstructure:"enable" json:"enable"`
RawConfig map[string]any `toml:"-" mapstructure:",remain"`
}

36
config/storage/webdav.go Normal file
View File

@@ -0,0 +1,36 @@
package storage
import (
"fmt"
"github.com/krau/SaveAny-Bot/types"
)
type WebdavStorageConfig struct {
BaseConfig
URL string `toml:"url" mapstructure:"url" json:"url"`
Username string `toml:"username" mapstructure:"username" json:"username"`
Password string `toml:"password" mapstructure:"password" json:"password"`
BasePath string `toml:"base_path" mapstructure:"base_path" json:"base_path"`
}
func (w *WebdavStorageConfig) Validate() error {
if w.URL == "" {
return fmt.Errorf("url is required for webdav storage")
}
if w.Username == "" || w.Password == "" {
return fmt.Errorf("username and password is required for webdav storage")
}
if w.BasePath == "" {
return fmt.Errorf("base_path is required for webdav storage")
}
return nil
}
func (w *WebdavStorageConfig) GetType() types.StorageType {
return types.StorageTypeWebdav
}
func (w *WebdavStorageConfig) GetName() string {
return w.Name
}

49
config/user.go Normal file
View File

@@ -0,0 +1,49 @@
package config
import (
"github.com/duke-git/lancet/v2/slice"
)
type userConfig struct {
ID int64 `toml:"id" mapstructure:"id" json:"id"` // telegram user id
Storages []string `toml:"storages" mapstructure:"storages" json:"storages"` // storage names
Blacklist bool `toml:"blacklist" mapstructure:"blacklist" json:"blacklist"` // 黑名单模式, storage names 中的存储将不会被使用, 默认为白名单模式
}
func (c *Config) GetStorageNamesByUserID(userID int64) []string {
for _, user := range c.Users {
if user.ID == userID {
if user.Blacklist {
allStorages := make([]string, 0, len(c.Storages))
for _, storage := range c.Storages {
allStorages = append(allStorages, storage.GetName())
}
return slice.Compact(slice.Difference(allStorages, user.Storages))
} else {
return user.Storages
}
}
}
return nil
}
func (c *Config) GetUsersID() []int64 {
var ids []int64
for _, user := range c.Users {
ids = append(ids, user.ID)
}
return ids
}
func (c *Config) HasStorage(userID int64, storageName string) bool {
for _, user := range c.Users {
if user.ID == userID {
if user.Blacklist {
return !slice.Contain(user.Storages, storageName)
} else {
return slice.Contain(user.Storages, storageName)
}
}
}
return false
}

View File

@@ -5,30 +5,35 @@ import (
"os"
"strings"
"github.com/krau/SaveAny-Bot/config/storage"
"github.com/spf13/viper"
)
type Config struct {
Workers int `toml:"workers" mapstructure:"workers"`
Retry int `toml:"retry" mapstructure:"retry"`
NoCleanCache bool `toml:"no_clean_cache" mapstructure:"no_clean_cache"`
NoCleanCache bool `toml:"no_clean_cache" mapstructure:"no_clean_cache" json:"no_clean_cache"`
Threads int `toml:"threads" mapstructure:"threads" json:"threads"`
Stream bool `toml:"stream" mapstructure:"stream" json:"stream"`
Temp tempConfig `toml:"temp" mapstructure:"temp"`
Log logConfig `toml:"log" mapstructure:"log"`
DB dbConfig `toml:"db" mapstructure:"db"`
Telegram telegramConfig `toml:"telegram" mapstructure:"telegram"`
Storage storageConfig `toml:"storage" mapstructure:"storage"`
Users []userConfig `toml:"users" mapstructure:"users" json:"users"`
Temp tempConfig `toml:"temp" mapstructure:"temp"`
Log logConfig `toml:"log" mapstructure:"log"`
DB dbConfig `toml:"db" mapstructure:"db"`
Telegram telegramConfig `toml:"telegram" mapstructure:"telegram"`
Storages []storage.StorageConfig `toml:"-" mapstructure:"-" json:"storages"`
}
type tempConfig struct {
BasePath string `toml:"base_path" mapstructure:"base_path"`
CacheTTL int64 `toml:"cache_ttl" mapstructure:"cache_ttl"`
BasePath string `toml:"base_path" mapstructure:"base_path" json:"base_path"`
CacheTTL int64 `toml:"cache_ttl" mapstructure:"cache_ttl" json:"cache_ttl"`
}
type logConfig struct {
Level string `toml:"level" mapstructure:"level"`
File string `toml:"file" mapstructure:"file"`
BackupCount uint `toml:"backup_count" mapstructure:"backup_count"`
BackupCount uint `toml:"backup_count" mapstructure:"backup_count" json:"backup_count"`
}
type dbConfig struct {
@@ -37,10 +42,13 @@ type dbConfig struct {
type telegramConfig struct {
Token string `toml:"token" mapstructure:"token"`
AppID int `toml:"app_id" mapstructure:"app_id"`
AppHash string `toml:"app_hash" mapstructure:"app_hash"`
Admins []int64 `toml:"admins" mapstructure:"admins"`
AppID int `toml:"app_id" mapstructure:"app_id" json:"app_id"`
AppHash string `toml:"app_hash" mapstructure:"app_hash" json:"app_hash"`
Timeout int `toml:"timeout" mapstructure:"timeout" json:"timeout"`
Proxy proxyConfig `toml:"proxy" mapstructure:"proxy"`
// Deprecated
Admins []int64 `toml:"admins" mapstructure:"admins"`
}
type proxyConfig struct {
@@ -48,38 +56,18 @@ type proxyConfig struct {
URL string `toml:"url" mapstructure:"url"`
}
type storageConfig struct {
Alist alistConfig `toml:"alist" mapstructure:"alist"`
Local localConfig `toml:"local" mapstructure:"local"`
Webdav webdavConfig `toml:"webdav" mapstructure:"webdav"`
}
type alistConfig struct {
Enable bool `toml:"enable" mapstructure:"enable"`
URL string `toml:"url" mapstructure:"url"`
Username string `toml:"username" mapstructure:"username"`
Password string `toml:"password" mapstructure:"password"`
Token string `toml:"token" mapstructure:"token"`
BasePath string `toml:"base_path" mapstructure:"base_path"`
TokenExp int64 `toml:"token_exp" mapstructure:"token_exp"`
}
type localConfig struct {
Enable bool `toml:"enable" mapstructure:"enable"`
BasePath string `toml:"base_path" mapstructure:"base_path"`
}
type webdavConfig struct {
Enable bool `toml:"enable" mapstructure:"enable"`
URL string `toml:"url" mapstructure:"url"`
Username string `toml:"username" mapstructure:"username"`
Password string `toml:"password" mapstructure:"password"`
BasePath string `toml:"base_path" mapstructure:"base_path"`
}
var Cfg *Config
func Init() {
func (c Config) GetStorageByName(name string) storage.StorageConfig {
for _, storage := range c.Storages {
if storage.GetName() == name {
return storage
}
}
return nil
}
func Init() error {
viper.SetConfigName("config")
viper.AddConfigPath(".")
viper.AddConfigPath("/etc/saveany/")
@@ -91,9 +79,11 @@ func Init() {
viper.SetDefault("workers", 3)
viper.SetDefault("retry", 3)
viper.SetDefault("threads", 4)
viper.SetDefault("telegram.app_id", 1025907)
viper.SetDefault("telegram.app_hash", "452b0359b988148995f22ff0f4229750")
viper.SetDefault("telegram.timeout", 60)
viper.SetDefault("temp.base_path", "cache/")
viper.SetDefault("temp.cache_ttl", 3600)
@@ -104,10 +94,11 @@ func Init() {
viper.SetDefault("db.path", "data/saveany.db")
viper.SetDefault("storage.alist.base_path", "/")
viper.SetDefault("storage.alist.token_exp", 3600)
viper.SafeWriteConfigAs("config.toml")
if err := viper.SafeWriteConfigAs("config.toml"); err != nil {
if _, ok := err.(viper.ConfigFileAlreadyExistsError); !ok {
return fmt.Errorf("error saving default config: %w", err)
}
}
if err := viper.ReadInConfig(); err != nil {
fmt.Println("Error reading config file, ", err)
@@ -115,14 +106,36 @@ func Init() {
}
Cfg = &Config{}
if err := viper.Unmarshal(Cfg); err != nil {
fmt.Println("Error unmarshalling config file, ", err)
os.Exit(1)
}
if Cfg.Workers < 1 || Cfg.Retry < 1 {
fmt.Println("Invalid workers or retry value")
os.Exit(1)
storagesConfig, err := storage.LoadStorageConfigs(viper.GetViper())
if err != nil {
return fmt.Errorf("error loading storage configs: %w", err)
}
Cfg.Storages = storagesConfig
storageNames := make(map[string]struct{})
for _, storage := range Cfg.Storages {
if _, ok := storageNames[storage.GetName()]; ok {
return fmt.Errorf("重复的存储名: %s", storage.GetName())
}
storageNames[storage.GetName()] = struct{}{}
}
fmt.Printf("已加载 %d 个存储:\n", len(Cfg.Storages))
for _, storage := range Cfg.Storages {
fmt.Printf(" - %s (%s)\n", storage.GetName(), storage.GetType())
}
if Cfg.Workers < 1 || Cfg.Retry < 1 {
return fmt.Errorf("workers 和 retry 必须大于 0, 当前值: workers=%d, retry=%d", Cfg.Workers, Cfg.Retry)
}
return nil
}
func Set(key string, value any) {

View File

@@ -4,132 +4,37 @@ import (
"context"
"errors"
"fmt"
"io"
"os"
"path"
"path/filepath"
"time"
"github.com/gabriel-vasile/mimetype"
"github.com/celestix/gotgproto/ext"
"github.com/duke-git/lancet/v2/fileutil"
"github.com/gotd/td/telegram/downloader"
"github.com/gotd/td/tg"
"github.com/krau/SaveAny-Bot/bot"
"github.com/krau/SaveAny-Bot/common"
"github.com/krau/SaveAny-Bot/config"
"github.com/krau/SaveAny-Bot/logger"
"github.com/krau/SaveAny-Bot/queue"
"github.com/krau/SaveAny-Bot/types"
)
func processPendingTask(task *types.Task) error {
logger.L.Debugf("Start processing task: %s", task.String())
if task.FileName() == "" {
task.File.FileName = fmt.Sprintf("%d_%d_%s", task.FileChatID, task.FileMessageID, task.File.Hash())
}
cacheDestPath := filepath.Join(config.Cfg.Temp.BasePath, task.FileName())
cacheDestPath, err := filepath.Abs(cacheDestPath)
if err != nil {
return fmt.Errorf("failed to get absolute path: %w", err)
}
if err := fileutil.CreateDir(filepath.Dir(cacheDestPath)); err != nil {
return fmt.Errorf("failed to create directory: %w", err)
}
var Downloader *downloader.Downloader
if task.StoragePath == "" {
task.StoragePath = task.File.FileName
}
switch task.Storage {
case types.Local:
task.StoragePath = filepath.Join(config.Cfg.Storage.Local.BasePath, task.StoragePath)
case types.Webdav:
task.StoragePath = path.Join(config.Cfg.Storage.Webdav.BasePath, task.StoragePath)
case types.Alist:
task.StoragePath = path.Join(config.Cfg.Storage.Alist.BasePath, task.StoragePath)
}
if task.File.FileSize == 0 {
return processPhoto(task, cacheDestPath)
}
ctx := task.Ctx.(*ext.Context)
barTotalCount := calculateBarTotalCount(task.File.FileSize)
progressCallback := func(bytesRead, contentLength int64) {
progress := float64(bytesRead) / float64(contentLength) * 100
logger.L.Tracef("Downloading %s: %.2f%%", task.String(), progress)
if task.File.FileSize < 1024*1024*50 || int(progress)%(100/barTotalCount) != 0 {
return
}
text, entities := buildProgressMessageEntity(task, barTotalCount, bytesRead, task.StartTime, progress)
ctx.EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{
Message: text,
Entities: entities,
ID: task.ReplyMessageID,
})
}
text, entities := buildProgressMessageEntity(task, barTotalCount, 0, task.StartTime, 0)
ctx.EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{
Message: text,
Entities: entities,
ID: task.ReplyMessageID,
})
readCloser, err := NewTelegramReader(task.Ctx, bot.Client, &task.File.Location,
0, task.File.FileSize-1, task.File.FileSize,
progressCallback, task.File.FileSize/100)
if err != nil {
return fmt.Errorf("failed to create reader: %w", err)
}
defer readCloser.Close()
dest, err := os.Create(cacheDestPath)
if err != nil {
return fmt.Errorf("failed to create file: %w", err)
}
defer dest.Close()
task.StartTime = time.Now()
if _, err := io.CopyN(dest, readCloser, task.File.FileSize); err != nil {
return fmt.Errorf("failed to download file: %w", err)
}
defer cleanCacheFile(cacheDestPath)
if path.Ext(task.FileName()) == "" {
mimeType, err := mimetype.DetectFile(cacheDestPath)
if err != nil {
logger.L.Errorf("Failed to detect mime type: %s", err)
} else {
task.File.FileName = fmt.Sprintf("%s%s", task.FileName(), mimeType.Extension())
task.StoragePath = fmt.Sprintf("%s%s", task.StoragePath, mimeType.Extension())
}
}
logger.L.Infof("Downloaded file: %s", cacheDestPath)
ctx.EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{
Message: fmt.Sprintf("下载完成: %s\n正在转存文件...", task.FileName()),
ID: task.ReplyMessageID,
})
return saveFileWithRetry(task, cacheDestPath)
func init() {
Downloader = downloader.NewDownloader().WithPartSize(1024 * 1024)
}
func worker(queue *queue.TaskQueue, semaphore chan struct{}) {
for {
semaphore <- struct{}{}
task := queue.GetTask()
logger.L.Debugf("Got task: %s", task.String())
common.Log.Debugf("Got task: %s", task.String())
switch task.Status {
case types.Pending:
logger.L.Infof("Processing task: %s", task.String())
if err := processPendingTask(&task); err != nil {
logger.L.Errorf("Failed to do task: %s", err)
common.Log.Infof("Processing task: %s", task.String())
if err := processPendingTask(task); err != nil {
task.Error = err
if errors.Is(err, context.Canceled) {
logger.L.Debugf("Task canceled: %s", task.String())
task.Status = types.Canceled
} else {
common.Log.Errorf("Failed to do task: %s", err)
task.Status = types.Failed
}
} else {
@@ -137,29 +42,49 @@ func worker(queue *queue.TaskQueue, semaphore chan struct{}) {
}
queue.AddTask(task)
case types.Succeeded:
logger.L.Infof("Task succeeded: %s", task.String())
task.Ctx.(*ext.Context).EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{
Message: fmt.Sprintf("文件保存成功\n [%s]: %s", task.Storage, task.StoragePath),
ID: task.ReplyMessageID,
})
common.Log.Infof("Task succeeded: %s", task.String())
extCtx, ok := task.Ctx.(*ext.Context)
if !ok {
common.Log.Errorf("Context is not *ext.Context: %T", task.Ctx)
} else {
extCtx.EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{
Message: fmt.Sprintf("文件保存成功\n [%s]: %s", task.StorageName, task.StoragePath),
ID: task.ReplyMessageID,
})
}
case types.Failed:
logger.L.Errorf("Task failed: %s", task.String())
task.Ctx.(*ext.Context).EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{
Message: "文件保存失败\n" + task.Error.Error(),
ID: task.ReplyMessageID,
})
common.Log.Errorf("Task failed: %s", task.String())
extCtx, ok := task.Ctx.(*ext.Context)
if !ok {
common.Log.Errorf("Context is not *ext.Context: %T", task.Ctx)
} else {
extCtx.EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{
Message: "文件保存失败\n" + task.Error.Error(),
ID: task.ReplyMessageID,
})
}
case types.Canceled:
logger.L.Infof("Task canceled: %s", task.String())
common.Log.Infof("Task canceled: %s", task.String())
extCtx, ok := task.Ctx.(*ext.Context)
if !ok {
common.Log.Errorf("Context is not *ext.Context: %T", task.Ctx)
} else {
extCtx.EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{
Message: "任务已取消",
ID: task.ReplyMessageID,
})
}
default:
logger.L.Errorf("Unknown task status: %s", task.Status)
common.Log.Errorf("Unknown task status: %s", task.Status)
}
<-semaphore
logger.L.Debugf("Task done: %s", task.String())
common.Log.Debugf("Task done: %s; status: %s", task.String(), task.Status)
queue.DoneTask(task)
}
}
func Run() {
logger.L.Info("Start processing tasks...")
common.Log.Info("Start processing tasks...")
semaphore := make(chan struct{}, config.Cfg.Workers)
for i := 0; i < config.Cfg.Workers; i++ {
go worker(queue.Queue, semaphore)

276
core/download.go Normal file
View File

@@ -0,0 +1,276 @@
package core
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"path"
"path/filepath"
"strings"
"time"
"github.com/celestix/gotgproto/ext"
"github.com/celestix/telegraph-go/v2"
"github.com/duke-git/lancet/v2/fileutil"
"github.com/gotd/td/telegram/message/entity"
"github.com/gotd/td/telegram/message/styling"
"github.com/gotd/td/tg"
"github.com/krau/SaveAny-Bot/bot"
"github.com/krau/SaveAny-Bot/common"
"github.com/krau/SaveAny-Bot/config"
"github.com/krau/SaveAny-Bot/storage"
"github.com/krau/SaveAny-Bot/types"
"golang.org/x/sync/errgroup"
)
func processPendingTask(task *types.Task) error {
common.Log.Debugf("Start processing task: %s", task.String())
if task.FileName() == "" {
task.File.FileName = fmt.Sprintf("%d_%d_%s", task.FileChatID, task.FileMessageID, task.File.Hash())
}
if task.StoragePath == "" {
task.StoragePath = task.FileName()
}
taskStorage, err := storage.GetStorageByUserIDAndName(task.UserID, task.StorageName)
if err != nil {
return err
}
task.StoragePath = taskStorage.JoinStoragePath(*task)
ctx, ok := task.Ctx.(*ext.Context)
if !ok {
return fmt.Errorf("context is not *ext.Context: %T", task.Ctx)
}
cancelCtx, cancel := context.WithCancel(ctx)
task.Cancel = cancel
if task.IsTelegraph {
return processTelegraph(ctx, cancelCtx, task, taskStorage)
}
if task.File.FileSize == 0 {
return processPhoto(task, taskStorage)
}
downloadBuilder := Downloader.Download(bot.Client.API(), task.File.Location).WithThreads(getTaskThreads(task.File.FileSize))
notsupportStreamStorage, notsupportStream := taskStorage.(storage.StorageNotSupportStream)
cancelMarkUp := getCancelTaskMarkup(task)
if config.Cfg.Stream {
if !notsupportStream {
text, entities := buildProgressMessageEntity(task, 0, task.StartTime, 0)
ctx.EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{
Message: text,
Entities: entities,
ID: task.ReplyMessageID,
ReplyMarkup: cancelMarkUp,
})
pr, pw := io.Pipe()
defer pr.Close()
task.StartTime = time.Now()
progressCallback := buildProgressCallback(ctx, task, getProgressUpdateCount(task.File.FileSize))
progressStream := NewProgressStream(pw, task.File.FileSize, progressCallback)
eg, uploadCtx := errgroup.WithContext(cancelCtx)
eg.Go(func() error {
return taskStorage.Save(uploadCtx, pr, task.StoragePath)
})
eg.Go(func() error {
_, err := downloadBuilder.Stream(uploadCtx, progressStream)
if closeErr := pw.CloseWithError(err); closeErr != nil {
common.Log.Errorf("Failed to close pipe writer: %v", closeErr)
}
return err
})
if err := eg.Wait(); err != nil {
return err
}
return nil
}
common.Log.Warnf("存储 %s 不支持流式传输: %s", task.StorageName, notsupportStreamStorage.NotSupportStream())
ctx.EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{
Message: fmt.Sprintf("存储 %s 不支持流式传输: %s\n正在使用普通下载...", task.StorageName, notsupportStreamStorage.NotSupportStream()),
ID: task.ReplyMessageID,
ReplyMarkup: cancelMarkUp,
})
}
cacheDestPath := filepath.Join(config.Cfg.Temp.BasePath, task.FileName())
cacheDestPath, err = filepath.Abs(cacheDestPath)
if err != nil {
return fmt.Errorf("处理路径失败: %w", err)
}
if err := fileutil.CreateDir(filepath.Dir(cacheDestPath)); err != nil {
return fmt.Errorf("创建目录失败: %w", err)
}
text, entities := buildProgressMessageEntity(task, 0, task.StartTime, 0)
ctx.EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{
Message: text,
Entities: entities,
ID: task.ReplyMessageID,
ReplyMarkup: cancelMarkUp,
})
progressCallback := buildProgressCallback(ctx, task, getProgressUpdateCount(task.File.FileSize))
dest, err := NewTaskLocalFile(cacheDestPath, task.File.FileSize, progressCallback)
if err != nil {
return fmt.Errorf("创建文件失败: %w", err)
}
defer dest.Close()
task.StartTime = time.Now()
_, err = downloadBuilder.Parallel(cancelCtx, dest)
if err != nil {
return fmt.Errorf("下载文件失败: %w", err)
}
defer cleanCacheFile(cacheDestPath)
fixTaskFileExt(task, cacheDestPath)
common.Log.Infof("Downloaded file: %s", cacheDestPath)
ctx.EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{
Message: fmt.Sprintf("下载完成: %s\n正在转存文件...", task.FileName()),
ID: task.ReplyMessageID,
})
return saveFileWithRetry(cancelCtx, task.StoragePath, taskStorage, cacheDestPath)
}
func processTelegraph(extCtx *ext.Context, cancelCtx context.Context, task *types.Task, taskStorage storage.Storage) error {
if bot.TelegraphClient == nil {
return fmt.Errorf("telegraph client is not initialized")
}
tgphUrl := task.TelegraphURL
tgphPath := strings.Split(tgphUrl, "/")[len(strings.Split(tgphUrl, "/"))-1]
if tgphUrl == "" || tgphPath == "" {
return fmt.Errorf("invalid telegraph url")
}
entityBuilder := entity.Builder{}
text := fmt.Sprintf("正在下载 Telegraph \n文件夹: %s\n保存路径: %s",
task.FileName(),
fmt.Sprintf("[%s]:%s", task.StorageName, task.StoragePath),
)
var entities []tg.MessageEntityClass
if err := styling.Perform(&entityBuilder,
styling.Plain("正在下载 Telegraph \n文件夹: "),
styling.Code(task.FileName()),
styling.Plain("\n保存路径: "),
styling.Code(fmt.Sprintf("[%s]:%s", task.StorageName, task.StoragePath)),
); err != nil {
common.Log.Errorf("Failed to build entities: %s", err)
}
extCtx.EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{
Message: text,
Entities: entities,
ID: task.ReplyMessageID,
ReplyMarkup: getCancelTaskMarkup(task),
})
resultCh := make(chan error)
go func() {
page, err := bot.TelegraphClient.GetPage(tgphPath, true)
if err != nil {
resultCh <- fmt.Errorf("获取 telegraph 页面失败: %w", err)
return
}
imgs := make([]string, 0)
for _, element := range page.Content {
var node telegraph.NodeElement
data, err := json.Marshal(element)
if err != nil {
common.Log.Errorf("Failed to marshal element: %s", err)
continue
}
err = json.Unmarshal(data, &node)
if err != nil {
common.Log.Errorf("Failed to unmarshal element: %s", err)
continue
}
if len(node.Children) != 0 {
for _, child := range node.Children {
imgs = append(imgs, getNodeImages(child)...)
}
}
if node.Tag == "img" {
if src, ok := node.Attrs["src"]; ok {
imgs = append(imgs, src)
}
}
}
if len(imgs) == 0 {
resultCh <- fmt.Errorf("没有找到图片")
return
}
hc := bot.TelegraphClient.HttpClient
eg, ectx := errgroup.WithContext(cancelCtx)
eg.SetLimit(config.Cfg.Workers) // TODO: use a new config field for this
for i, img := range imgs {
if strings.HasPrefix(img, "/file/") {
img = "https://telegra.ph" + img
}
eg.Go(func() error {
var lastErr error
for attempt := range config.Cfg.Retry {
if attempt > 0 {
retryDelay := time.Duration(attempt*attempt) * time.Second
select {
case <-ectx.Done():
return ectx.Err()
case <-time.After(retryDelay):
}
common.Log.Debugf("Retrying to download image %s (attempt %d)", img, attempt+1)
}
req, err := http.NewRequestWithContext(ectx, http.MethodGet, img, nil)
if err != nil {
lastErr = fmt.Errorf("创建请求失败: %w", err)
continue
}
resp, err := hc.Do(req)
if err != nil {
lastErr = fmt.Errorf("发送请求失败: %w", err)
continue
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
lastErr = fmt.Errorf("请求图片失败: %s", resp.Status)
continue
}
targetPath := path.Join(task.StoragePath, fmt.Sprintf("%d%s", i+1, path.Ext(img)))
err = taskStorage.Save(ectx, resp.Body, targetPath)
if err != nil {
lastErr = fmt.Errorf("保存图片失败: %w", err)
continue
}
common.Log.Infof("Saved image: %s", targetPath)
return nil
}
return lastErr
})
}
if err := eg.Wait(); err != nil {
resultCh <- err
return
}
resultCh <- nil
}()
select {
case err := <-resultCh:
return err
case <-cancelCtx.Done():
return cancelCtx.Err()
}
}

80
core/download_test.go Normal file
View File

@@ -0,0 +1,80 @@
package core
import (
"reflect"
"testing"
"github.com/celestix/telegraph-go/v2"
)
func TestGetImgSrcs(t *testing.T) {
complexStructure := telegraph.NodeElement{
Tag: "div",
Children: []telegraph.Node{
telegraph.NodeElement{
Tag: "figure",
Children: []telegraph.Node{
telegraph.NodeElement{
Tag: "img",
Attrs: map[string]string{
"src": "https://example.com/image1.png",
},
},
telegraph.NodeElement{
Tag: "p",
Children: []telegraph.Node{
"A text node",
},
},
telegraph.NodeElement{
Tag: "figure",
Children: []telegraph.Node{
telegraph.NodeElement{
Tag: "img",
Attrs: map[string]string{
"src": "https://example.com/image2.png",
},
},
},
},
},
},
telegraph.NodeElement{
Tag: "img",
Attrs: map[string]string{
"src": "https://example.com/image3.png",
},
},
"text node",
telegraph.NodeElement{
Tag: "div",
Children: []telegraph.Node{
telegraph.NodeElement{
Tag: "span",
Children: []telegraph.Node{
telegraph.NodeElement{
Tag: "img",
Attrs: map[string]string{
"src": "https://example.com/image4.png",
},
},
},
},
},
},
},
}
expected := []string{
"https://example.com/image1.png",
"https://example.com/image2.png",
"https://example.com/image3.png",
"https://example.com/image4.png",
}
got := getNodeImages(complexStructure)
if !reflect.DeepEqual(expected, got) {
t.Errorf("expected %vgot %v", expected, got)
}
}

View File

@@ -1,154 +0,0 @@
package core
import (
"context"
"fmt"
"io"
"strings"
"github.com/celestix/gotgproto"
"github.com/gotd/td/tg"
"github.com/krau/SaveAny-Bot/config"
)
type telegramReader struct {
client *gotgproto.Client
location *tg.InputFileLocationClass
bytesread int64
chunkSize int64
i int64
contentLength int64
start int64
end int64
next func() ([]byte, error)
progressCallback func(bytesRead, contentLength int64)
callbackInterval int64
lastProgress int64
buffer []byte
ctx context.Context
}
func (*telegramReader) Close() error {
return nil
}
func (r *telegramReader) Read(p []byte) (n int, err error) {
if r.bytesread == r.contentLength {
return 0, io.EOF
}
if r.i >= int64(len(r.buffer)) {
r.buffer, err = r.next()
if err != nil {
return 0, err
}
if len(r.buffer) == 0 {
r.next = r.partStream()
r.buffer, err = r.next()
if err != nil {
return 0, err
}
}
r.i = 0
}
n = copy(p, r.buffer[r.i:])
r.i += int64(n)
r.bytesread += int64(n)
if r.progressCallback != nil && (r.bytesread-r.lastProgress >= r.callbackInterval || r.bytesread == r.contentLength) {
r.progressCallback(r.bytesread, r.contentLength)
r.lastProgress = r.bytesread
}
return n, nil
}
func NewTelegramReader(
ctx context.Context,
client *gotgproto.Client,
location *tg.InputFileLocationClass,
start int64,
end int64,
contentLength int64,
progressCallback func(bytesRead, contentLength int64),
callbackInterval int64,
) (io.ReadCloser, error) {
r := &telegramReader{
ctx: ctx,
location: location,
client: client,
start: start,
end: end,
chunkSize: int64(1024 * 1024),
contentLength: contentLength,
progressCallback: progressCallback,
callbackInterval: callbackInterval,
}
r.next = r.partStream()
return r, nil
}
func (r *telegramReader) chunk(offset int64, limit int64) ([]byte, error) {
var lastError error
for i := 0; i < config.Cfg.Retry; i++ {
req := &tg.UploadGetFileRequest{
Offset: offset,
Limit: int(limit),
Location: *r.location,
}
res, err := r.client.API().UploadGetFile(r.ctx, req)
if err != nil {
if strings.Contains(err.Error(), tg.ErrTimeout) {
lastError = err
continue
}
return nil, err
}
switch result := res.(type) {
case *tg.UploadFile:
return result.Bytes, nil
default:
return nil, fmt.Errorf("unexpected type %T", r)
}
}
return nil, lastError
}
func (r *telegramReader) partStream() func() ([]byte, error) {
start := r.start
end := r.end
offset := start - (start % r.chunkSize)
firstPartCut := start - offset
lastPartCut := (end % r.chunkSize) + 1
partCount := int((end - offset + r.chunkSize) / r.chunkSize)
currentPart := 1
readData := func() ([]byte, error) {
if currentPart > partCount {
return make([]byte, 0), nil
}
res, err := r.chunk(offset, r.chunkSize)
if err != nil {
return nil, err
}
if len(res) == 0 {
return res, nil
} else if partCount == 1 {
res = res[firstPartCut:lastPartCut]
} else if currentPart == 1 {
res = res[firstPartCut:]
} else if currentPart == partCount {
res = res[:lastPartCut]
}
currentPart++
offset += r.chunkSize
return res, nil
}
return readData
}

View File

@@ -1,28 +1,58 @@
package core
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"os"
"path"
"time"
"github.com/celestix/gotgproto/ext"
"github.com/celestix/telegraph-go/v2"
"github.com/gabriel-vasile/mimetype"
"github.com/gotd/td/telegram/message/entity"
"github.com/gotd/td/telegram/message/styling"
"github.com/gotd/td/tg"
"github.com/krau/SaveAny-Bot/bot"
"github.com/krau/SaveAny-Bot/common"
"github.com/krau/SaveAny-Bot/config"
"github.com/krau/SaveAny-Bot/logger"
"github.com/krau/SaveAny-Bot/storage"
"github.com/krau/SaveAny-Bot/types"
)
func saveFileWithRetry(task *types.Task, localFilePath string) error {
func saveFileWithRetry(ctx context.Context, storagePath string, taskStorage storage.Storage, cacheFilePath string) error {
file, err := os.Open(cacheFilePath)
if err != nil {
return fmt.Errorf("failed to open cache file: %w", err)
}
defer file.Close()
fileStat, err := file.Stat()
if err != nil {
return fmt.Errorf("failed to get file stat: %w", err)
}
vctx := context.WithValue(ctx, types.ContextKeyContentLength, fileStat.Size())
for i := 0; i <= config.Cfg.Retry; i++ {
if err := storage.Save(task.Storage, task.Ctx, localFilePath, task.StoragePath); err != nil {
if err := vctx.Err(); err != nil {
return fmt.Errorf("context canceled while saving file: %w", err)
}
file, err := os.Open(cacheFilePath)
if err != nil {
return fmt.Errorf("failed to open cache file: %w", err)
}
defer file.Close()
if err := taskStorage.Save(vctx, file, storagePath); err != nil {
if i == config.Cfg.Retry {
return fmt.Errorf("failed to save file: %w", err)
}
logger.L.Errorf("Failed to save file: %s, retrying...", err)
common.Log.Errorf("Failed to save file: %s, retrying...", err)
select {
case <-vctx.Done():
return fmt.Errorf("context canceled during retry delay: %w", vctx.Err())
case <-time.After(time.Duration(i*500) * time.Millisecond):
}
continue
}
return nil
@@ -30,7 +60,7 @@ func saveFileWithRetry(task *types.Task, localFilePath string) error {
return nil
}
func processPhoto(task *types.Task, cachePath string) error {
func processPhoto(task *types.Task, taskStorage storage.Storage) error {
res, err := bot.Client.API().UploadGetFile(task.Ctx, &tg.UploadGetFileRequest{
Location: task.File.Location,
Offset: 0,
@@ -45,28 +75,9 @@ func processPhoto(task *types.Task, cachePath string) error {
return fmt.Errorf("unexpected type %T", res)
}
if err := os.WriteFile(cachePath, result.Bytes, os.ModePerm); err != nil {
return fmt.Errorf("failed to write file: %w", err)
}
common.Log.Infof("Downloaded photo: %s", task.FileName())
defer cleanCacheFile(cachePath)
logger.L.Infof("Downloaded file: %s", cachePath)
return saveFileWithRetry(task, cachePath)
}
func getProgressBar(progress float64, totalCount int) string {
bar := ""
barSize := 100 / totalCount
for i := 0; i < totalCount; i++ {
if int(progress)/barSize > i {
bar += "█"
} else {
bar += "░"
}
}
return bar
return taskStorage.Save(task.Ctx, bytes.NewReader(result.Bytes), task.StoragePath)
}
func cleanCacheFile(destPath string) {
@@ -74,21 +85,22 @@ func cleanCacheFile(destPath string) {
common.RmFileAfter(destPath, time.Duration(config.Cfg.Temp.CacheTTL)*time.Second)
} else {
if err := os.Remove(destPath); err != nil {
logger.L.Errorf("Failed to purge file: %s", err)
common.Log.Errorf("Failed to purge file: %s", err)
}
}
}
func calculateBarTotalCount(fileSize int64) int {
barTotalCount := 5
// 获取进度需要更新的次数
func getProgressUpdateCount(fileSize int64) int {
updateCount := 5
if fileSize > 1024*1024*1000 {
barTotalCount = 40
updateCount = 50
} else if fileSize > 1024*1024*500 {
barTotalCount = 20
updateCount = 20
} else if fileSize > 1024*1024*200 {
barTotalCount = 10
updateCount = 10
}
return barTotalCount
return updateCount
}
func getSpeed(bytesRead int64, startTime time.Time) string {
@@ -100,13 +112,12 @@ func getSpeed(bytesRead int64, startTime time.Time) string {
return fmt.Sprintf("%.2fMB/s", speed)
}
func buildProgressMessageEntity(task *types.Task, barTotalCount int, bytesRead int64, startTime time.Time, progress float64) (string, []tg.MessageEntityClass) {
func buildProgressMessageEntity(task *types.Task, bytesRead int64, startTime time.Time, progress float64) (string, []tg.MessageEntityClass) {
entityBuilder := entity.Builder{}
text := fmt.Sprintf("正在处理下载任务\n文件名: %s\n保存路径: %s\n平均速度: %s\n当前进度: [%s] %.2f%%",
text := fmt.Sprintf("正在处理下载任务\n文件名: %s\n保存路径: %s\n平均速度: %s\n当前进度: %.2f%%",
task.FileName(),
fmt.Sprintf("[%s]:%s", task.Storage, task.StoragePath),
fmt.Sprintf("[%s]:%s", task.StorageName, task.StoragePath),
getSpeed(bytesRead, startTime),
getProgressBar(progress, barTotalCount),
progress,
)
var entities []tg.MessageEntityClass
@@ -114,14 +125,171 @@ func buildProgressMessageEntity(task *types.Task, barTotalCount int, bytesRead i
styling.Plain("正在处理下载任务\n文件名: "),
styling.Code(task.FileName()),
styling.Plain("\n保存路径: "),
styling.Code(fmt.Sprintf("[%s]:%s", task.Storage, task.StoragePath)),
styling.Code(fmt.Sprintf("[%s]:%s", task.StorageName, task.StoragePath)),
styling.Plain("\n平均速度: "),
styling.Bold(getSpeed(bytesRead, task.StartTime)),
styling.Plain("\n当前进度:\n "),
styling.Code(fmt.Sprintf("[%s] %.2f%%", getProgressBar(progress, barTotalCount), progress)),
styling.Plain("\n当前进度: "),
styling.Bold(fmt.Sprintf("%.2f%%", progress)),
); err != nil {
logger.L.Errorf("Failed to build entities: %s", err)
common.Log.Errorf("Failed to build entities: %s", err)
return text, entities
}
return entityBuilder.Complete()
}
func buildProgressCallback(ctx *ext.Context, task *types.Task, updateCount int) func(bytesRead, contentLength int64) {
return func(bytesRead, contentLength int64) {
progress := float64(bytesRead) / float64(contentLength) * 100
common.Log.Tracef("Downloading %s: %.2f%%", task.String(), progress)
progressInt := int(progress)
if task.File.FileSize < 1024*1024*50 || progressInt == 0 || progressInt%int(100/updateCount) != 0 {
return
}
text, entities := buildProgressMessageEntity(task, bytesRead, task.StartTime, progress)
ctx.EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{
Message: text,
Entities: entities,
ID: task.ReplyMessageID,
ReplyMarkup: getCancelTaskMarkup(task),
})
}
}
func getCancelTaskMarkup(task *types.Task) *tg.ReplyInlineMarkup {
return &tg.ReplyInlineMarkup{
Rows: []tg.KeyboardButtonRow{{Buttons: []tg.KeyboardButtonClass{&tg.KeyboardButtonCallback{Text: "取消任务", Data: fmt.Appendf(nil, "cancel %s", task.Key())}}}},
}
}
func fixTaskFileExt(task *types.Task, localFilePath string) {
if path.Ext(task.FileName()) == "" {
mimeType, err := mimetype.DetectFile(localFilePath)
if err != nil {
common.Log.Errorf("Failed to detect mime type: %s", err)
} else {
task.File.FileName = fmt.Sprintf("%s%s", task.FileName(), mimeType.Extension())
task.StoragePath = fmt.Sprintf("%s%s", task.StoragePath, mimeType.Extension())
}
}
}
func getTaskThreads(fileSize int64) int {
threads := 1
if fileSize > 1024*1024*100 {
threads = config.Cfg.Threads
} else if fileSize > 1024*1024*50 {
threads = config.Cfg.Threads / 2
}
return threads
}
type TaskLocalFile struct {
file *os.File
size int64
done int64
progressCallback func(bytesRead, contentLength int64)
callbackTimes int64
nextCallbackAt int64
callbackInterval int64
}
func (t *TaskLocalFile) Read(p []byte) (n int, err error) {
return t.file.Read(p)
}
func (t *TaskLocalFile) Close() error {
return t.file.Close()
}
func (t *TaskLocalFile) WriteAt(p []byte, off int64) (int, error) {
n, err := t.file.WriteAt(p, off)
if err != nil {
return n, err
}
t.done += int64(n)
if t.progressCallback != nil && t.done >= t.nextCallbackAt {
t.progressCallback(t.done, t.size)
t.nextCallbackAt += t.callbackInterval
}
return n, nil
}
func NewTaskLocalFile(filePath string, fileSize int64, progressCallback func(bytesRead, contentLength int64)) (*TaskLocalFile, error) {
file, err := os.Create(filePath)
if err != nil {
return nil, fmt.Errorf("failed to open file: %w", err)
}
var callbackInterval int64
callbackInterval = fileSize / 100
if callbackInterval == 0 {
callbackInterval = 1
}
return &TaskLocalFile{
file: file,
size: fileSize,
progressCallback: progressCallback,
callbackTimes: 100,
nextCallbackAt: callbackInterval,
callbackInterval: callbackInterval,
}, nil
}
type ProgressStream struct {
writer io.Writer
size int64
done int64
callback func(bytesRead, contentLength int64)
nextAt int64
interval int64
}
func (ps *ProgressStream) Write(p []byte) (n int, err error) {
n, err = ps.writer.Write(p)
if err != nil {
return n, err
}
ps.done += int64(n)
if ps.callback != nil && ps.done >= ps.nextAt {
ps.callback(ps.done, ps.size)
ps.nextAt += ps.interval
}
return n, nil
}
func NewProgressStream(writer io.Writer, size int64, callback func(bytesRead, contentLength int64)) *ProgressStream {
var interval int64
interval = size / 100
if interval == 0 {
interval = 1
}
return &ProgressStream{
writer: writer,
size: size,
callback: callback,
nextAt: interval,
interval: interval,
}
}
func getNodeImages(node telegraph.Node) []string {
var srcs []string
var nodeElement telegraph.NodeElement
data, err := json.Marshal(node)
if err != nil {
return srcs
}
err = json.Unmarshal(data, &nodeElement)
if err != nil {
return srcs
}
if nodeElement.Tag == "img" {
if src, exists := nodeElement.Attrs["src"]; exists {
srcs = append(srcs, src)
}
}
for _, child := range nodeElement.Children {
srcs = append(srcs, getNodeImages(child)...)
}
return srcs
}

19
dao/callback_data.go Normal file
View File

@@ -0,0 +1,19 @@
package dao
func CreateCallbackData(data string) (uint, error) {
callbackData := CallbackData{
Data: data,
}
err := db.Create(&callbackData).Error
return callbackData.ID, err
}
func GetCallbackData(id uint) (string, error) {
var callbackData CallbackData
err := db.First(&callbackData, id).Error
return callbackData.Data, err
}
func DeleteCallbackData(id uint) error {
return db.Unscoped().Where("id = ?", id).Delete(&CallbackData{}).Error
}

View File

@@ -1,14 +1,14 @@
package dao
import (
"fmt"
"os"
"path/filepath"
"time"
"github.com/glebarez/sqlite"
"github.com/krau/SaveAny-Bot/common"
"github.com/krau/SaveAny-Bot/config"
"github.com/krau/SaveAny-Bot/logger"
"github.com/krau/SaveAny-Bot/types"
"gorm.io/gorm"
glogger "gorm.io/gorm/logger"
)
@@ -16,13 +16,13 @@ import (
var db *gorm.DB
func Init() {
if err := os.MkdirAll(filepath.Dir(config.Cfg.DB.Path), 755); err != nil {
logger.L.Fatal("Failed to create data directory: ", err)
if err := os.MkdirAll(filepath.Dir(config.Cfg.DB.Path), 0755); err != nil {
common.Log.Fatal("Failed to create data directory: ", err)
os.Exit(1)
}
var err error
db, err = gorm.Open(sqlite.Open(config.Cfg.DB.Path), &gorm.Config{
Logger: glogger.New(logger.L, glogger.Config{
Logger: glogger.New(common.Log, glogger.Config{
Colorful: true,
SlowThreshold: time.Second * 5,
LogLevel: glogger.Error,
@@ -32,13 +32,52 @@ func Init() {
PrepareStmt: true,
})
if err != nil {
logger.L.Fatal("Failed to open database: ", err)
common.Log.Fatal("Failed to open database: ", err)
os.Exit(1)
}
logger.L.Debug("Database connected")
db.AutoMigrate(&types.ReceivedFile{}, &types.User{})
common.Log.Debug("Database connected")
if err := db.AutoMigrate(&ReceivedFile{}, &User{}, &Dir{}, &CallbackData{}); err != nil {
common.Log.Fatal("迁移数据库失败, 如果您从旧版本升级, 建议手动删除数据库文件后重试: ", err)
}
for _, admin := range config.Cfg.Telegram.Admins {
CreateUser(int64(admin))
if err := syncUsers(); err != nil {
common.Log.Fatal("Failed to sync users:", err)
}
}
func syncUsers() error {
dbUsers, err := GetAllUsers()
if err != nil {
return fmt.Errorf("failed to get users: %w", err)
}
dbUserMap := make(map[int64]User)
for _, u := range dbUsers {
dbUserMap[u.ChatID] = u
}
cfgUserMap := make(map[int64]struct{})
for _, u := range config.Cfg.Users {
cfgUserMap[u.ID] = struct{}{}
}
for cfgID := range cfgUserMap {
if _, exists := dbUserMap[cfgID]; !exists {
if err := CreateUser(cfgID); err != nil {
return fmt.Errorf("failed to create user %d: %w", cfgID, err)
}
common.Log.Infof("创建用户: %d", cfgID)
}
}
for dbID, dbUser := range dbUserMap {
if _, exists := cfgUserMap[dbID]; !exists {
if err := DeleteUser(&dbUser); err != nil {
return fmt.Errorf("failed to delete user %d: %w", dbID, err)
}
common.Log.Infof("删除用户: %d", dbID)
}
}
return nil
}

43
dao/dir.go Normal file
View File

@@ -0,0 +1,43 @@
package dao
func CreateDirForUser(userID uint, storageName, path string) error {
dir := Dir{
UserID: userID,
StorageName: storageName,
Path: path,
}
return db.Create(&dir).Error
}
func GetDirByID(id uint) (*Dir, error) {
dir := &Dir{}
err := db.First(dir, id).Error
if err != nil {
return nil, err
}
return dir, err
}
func GetUserDirs(userID uint) ([]Dir, error) {
var dirs []Dir
err := db.Where("user_id = ?", userID).Find(&dirs).Error
return dirs, err
}
func GetUserDirsByChatID(chatID int64) ([]Dir, error) {
user, err := GetUserByChatID(chatID)
if err != nil {
return nil, err
}
return GetUserDirs(user.ID)
}
func GetDirsByUserIDAndStorageName(userID uint, storageName string) ([]Dir, error) {
var dirs []Dir
err := db.Where("user_id = ? AND storage_name = ?", userID, storageName).Find(&dirs).Error
return dirs, err
}
func DeleteDirForUser(userID uint, storageName, path string) error {
return db.Unscoped().Where("user_id = ? AND storage_name = ? AND path = ?", userID, storageName, path).Delete(&Dir{}).Error
}

View File

@@ -1,8 +1,6 @@
package dao
import "github.com/krau/SaveAny-Bot/types"
func SaveReceivedFile(receivedFile *types.ReceivedFile) error {
func SaveReceivedFile(receivedFile *ReceivedFile) error {
record, err := GetReceivedFileByChatAndMessageID(receivedFile.ChatID, receivedFile.MessageID)
if err == nil {
receivedFile.ID = record.ID
@@ -10,8 +8,8 @@ func SaveReceivedFile(receivedFile *types.ReceivedFile) error {
return db.Save(receivedFile).Error
}
func GetReceivedFileByChatAndMessageID(chatID int64, messageID int) (*types.ReceivedFile, error) {
var receivedFile types.ReceivedFile
func GetReceivedFileByChatAndMessageID(chatID int64, messageID int) (*ReceivedFile, error) {
var receivedFile ReceivedFile
err := db.Where("chat_id = ? AND message_id = ?", chatID, messageID).First(&receivedFile).Error
if err != nil {
return nil, err
@@ -19,6 +17,6 @@ func GetReceivedFileByChatAndMessageID(chatID int64, messageID int) (*types.Rece
return &receivedFile, nil
}
func DeleteReceivedFile(receivedFile *types.ReceivedFile) error {
return db.Delete(receivedFile).Error
func DeleteReceivedFile(receivedFile *ReceivedFile) error {
return db.Unscoped().Delete(receivedFile).Error
}

39
dao/model.go Normal file
View File

@@ -0,0 +1,39 @@
package dao
import (
"gorm.io/gorm"
)
type ReceivedFile struct {
gorm.Model
Processing bool
// Which chat the file is from
ChatID int64 `gorm:"uniqueIndex:idx_chat_id_message_id;not null"`
// Which message the file is from
MessageID int `gorm:"uniqueIndex:idx_chat_id_message_id;not null"`
ReplyMessageID int
ReplyChatID int64
FileName string
IsTelegraph bool
TelegraphURL string
}
type User struct {
gorm.Model
ChatID int64 `gorm:"uniqueIndex;not null"`
Silent bool
DefaultStorage string // Default storage name
Dirs []Dir
}
type Dir struct {
gorm.Model
UserID uint
StorageName string
Path string
}
type CallbackData struct {
gorm.Model
Data string
}

View File

@@ -1,22 +1,30 @@
package dao
import (
"github.com/krau/SaveAny-Bot/types"
)
func CreateUser(userID int64) error {
if _, err := GetUserByUserID(userID); err == nil {
func CreateUser(chatID int64) error {
if _, err := GetUserByChatID(chatID); err == nil {
return nil
}
return db.Create(&types.User{UserID: userID}).Error
return db.Create(&User{ChatID: chatID}).Error
}
func GetUserByUserID(userID int64) (*types.User, error) {
var user types.User
err := db.Where("user_id = ?", userID).First(&user).Error
func GetAllUsers() ([]User, error) {
var users []User
err := db.Preload("Dirs").Find(&users).Error
return users, err
}
func GetUserByChatID(chatID int64) (*User, error) {
var user User
err := db.
Preload("Dirs").
Where("chat_id = ?", chatID).First(&user).Error
return &user, err
}
func UpdateUser(user *types.User) error {
func UpdateUser(user *User) error {
return db.Save(user).Error
}
func DeleteUser(user *User) error {
return db.Unscoped().Select("Dirs").Delete(user).Error
}

View File

@@ -3,30 +3,11 @@ services:
image: ghcr.io/krau/saveany-bot:latest
container_name: saveany-bot
restart: unless-stopped
environment:
- SAVEANY_TELEGRAM_TOKEN=bot_token
- SAVEANY_TELEGRAM_ADMINS=admin_id1,admin_id2
# 推荐使用自己的 API ID 和 API HASH (https://my.telegram.org)
# 若不配置也可运行, 将使用默认的 API ID 和 API HASH
# - SAVEANY_TELEGRAM_APP_ID=app_id
# - SAVEANY_TELEGRAM_APP_HASH=app_hash
# 本地存储
- SAVEANY_STORAGE_LOCAL_ENABLE=true
- SAVEANY_STORAGE_LOCAL_BASE_PATH=/app/downloads
# Alist
- SAVEANY_STORAGE_ALIST_ENABLE=true
- SAVEANY_STORAGE_ALIST_BASE_PATH=/saveany
- SAVEANY_STORAGE_ALIST_URL=https://example.com
- SAVEANY_STORAGE_ALIST_USERNAME=username
- SAVEANY_STORAGE_ALIST_PASSWORD=password
# webdav
- SAVEANY_STORAGE_WEBDAV_ENABLE=true
- SAVEANY_STORAGE_WEBDAV_BASE_PATH=/saveany
- SAVEANY_STORAGE_WEBDAV_URL=https://example.com
- SAVEANY_STORAGE_WEBDAV_USERNAME=username
- SAVEANY_STORAGE_WEBDAV_PASSWORD=password
volumes:
- ./data:/app/data
- ./config.toml:/app/config.toml
- ./downloads:/app/downloads
- ./cache:/app/cache
- ./cache:/app/cache
# 使用 host 模式以便访问宿主机服务 (如代理)
# 如果你对 Docker 网络模式熟悉, 可以自行修改
network_mode: host

94
docs/docs/deploy.md Normal file
View File

@@ -0,0 +1,94 @@
# 部署指南
## 从二进制文件部署
在 [Release](https://github.com/krau/SaveAny-Bot/releases) 页面下载对应平台的二进制文件.
在解压后目录新建 `config.toml` 文件, 参考 [config.example.toml](https://github.com/krau/SaveAny-Bot/blob/main/config.example.toml) 编辑配置文件.
运行:
```bash
chmod +x saveany-bot
./saveany-bot
```
### 添加为 systemd 服务
创建文件 `/etc/systemd/system/saveany-bot.service` 并写入以下内容:
```
[Unit]
Description=SaveAnyBot
After=systemd-user-sessions.service
[Service]
Type=simple
WorkingDirectory=/yourpath/
ExecStart=/yourpath/saveany-bot
Restart=on-failure
[Install]
WantedBy=multi-user.target
```
设为开机启动并启动服务:
```bash
systemctl enable --now saveany-bot
```
### 为OpenWrt及衍生系统添加开机自启动服务
创建文件 ` /etc/init.d/saveanybot` ,参考[saveanybot](https://github.com/krau/SaveAny-Bot/blob/main/docs/saveanybot)自行修改.
`chmod +x /etc/init.d/saveanybot`
完成后,将文件复制到 `/etc/rc.d`并重命名为`S99saveanybot`.
`chmod +x /etc/rc.d/S99saveanybot`
### 为OpenWrt及衍生系统添加快捷指令
创建文件` /usr/bin/sabot` ,参考[sabot](https://github.com/krau/SaveAny-Bot/blob/main/docs/sabot)自行配置修改,注意此处文件编码仅支持 ANSI 936 .
`chmod +x /usr/bin/sabot`
之后,终端输入`sabot start|stop|restart|status|enable|disable`即可.
## 使用 Docker 部署
### Docker Compose
下载 [docker-compose.yml](https://github.com/krau/SaveAny-Bot/blob/main/docker-compose.yml) 文件, 在同目录下新建 `config.toml` 文件, 参考 [config.example.toml](https://github.com/krau/SaveAny-Bot/blob/main/config.example.toml) 编辑配置文件.
启动:
```bash
docker compose up -d
```
### Docker
```shell
docker run -d --name saveany-bot \
-v /path/to/config.toml:/app/config.toml \
-v /path/to/downloads:/app/downloads \
ghcr.io/krau/saveany-bot:latest
```
## 更新
使用 `upgrade``up` 升级到最新版
```bash
./saveany-bot upgrade
```
如果是 Docker 部署, 使用以下命令更新:
```bash
docker pull ghcr.io/krau/saveany-bot:latest
docker restart saveany-bot
```

16
docs/docs/faq.md Normal file
View File

@@ -0,0 +1,16 @@
# 常见问题
## 上传 alist 失败也会显示成功
在 alist 管理页面适当调整上传分片大小, 为 alist 使用更稳定的网络环境部署, 都可以减少这种情况的发生.
## Bot 提示下载成功但是 alist 未显示
alist 缓存了目录结构, 参考文档可以调整缓存时间
https://alist.nn.ci/zh/guide/drivers/common.html#缓存过期
## docker部署配置了代理后仍无法连接 telegram (初始化客户端超时)
docker 不能直接访问宿主机网络, 如果你不熟悉其用法, 请将容器设为 host 模式:

38
docs/docs/help.md Normal file
View File

@@ -0,0 +1,38 @@
# 使用帮助
## 保存文件
Bot 接受两种消息: 文件和链接.
支持以下链接:
1. 公开频道 (具有用户名) 的消息链接, 例如: `https://t.me/acherkrau/1097`. **即使频道禁止了转发和保存, Bot 依然可以下载其文件.**
2. Telegra.ph 的文章链接, Bot 将下载其中的所有图片
## 静默模式 (silent)
使用 `/silent` 命令可以开关静默模式.
默认情况下不开启静默模式, Bot 会询问你每个文件的保存位置.
开启静默模式后, Bot 会直接保存文件到默认位置, 无需确认.
在开启静默模式之前, 需要使用 `/storage` 命令设置默认保存位置.
## Stream 模式
在配置文件中将 `stream` 设置为 `true` 可以开启 Stream 模式.
未开启时, Bot 处理任务分为两步: 下载和上传. Bot 会将文件暂存到本地, 然后上传到对应存储位置, 最后删除本地文件.
开启后, Bot 将直接将文件流式传输到存储端, 不需要下载到本地.
该功能对于硬盘空间有限的部署环境十分有用, 然而相较于普通模式也具有一些弊端:
- 无法使用多线程从 telegram 下载文件, 速度较慢.
- 网络不稳定时, 任务失败率高.
- 无法在中间层对文件进行处理, 例如自动文件类型识别.
**不支持** Stream 模式的存储端:
- alist

7
docs/docs/index.md Normal file
View File

@@ -0,0 +1,7 @@
# SaveAnyBot 文档
SaveAnyBot 是一个可以保存 Telegram 上的文件到云存储的机器人, 就像 PikPak Bot 一样.
不同的是, SaveAnyBot 提供更灵活的存储端选择, 并实现一些更强大的功能.
本项目以 AGPL-3.0 协议开源, 请遵守协议使用.

33
docs/mkdocs.yml Normal file
View File

@@ -0,0 +1,33 @@
site_name: SaveAnyBot 官方文档
site_author: Krau
site_description: SaveAnyBot 是一个可以保存 Telegram 上的文件到多种云存储的机器人, 本文档将帮助你了解如何部署和使用它.
repo_name: krau/saveany-bot
repo_url: https://github.com/krau/saveany-bot
copyright: CC BY-NC-SA 4.0
theme:
name: material
language: zh
highlightjs: true
palette:
- media: "(prefers-color-scheme)"
toggle:
icon: material/brightness-auto
name: 切换主题
- media: "(prefers-color-scheme: light)"
scheme: default
primary: indigo
toggle:
icon: material/brightness-7
name: 暗色模式
- media: "(prefers-color-scheme: dark)"
scheme: slate
primary: blue grey
toggle:
icon: material/brightness-4
name: 亮色模式
nav:
- index.md
- deploy.md
- help.md
- faq.md

28
docs/sabot Normal file
View File

@@ -0,0 +1,28 @@
#!/bin/sh
case "$1" in
start)
/etc/init.d/saveanybot start
;;
stop)
/etc/init.d/saveanybot stop
;;
restart)
/etc/init.d/saveanybot restart
;;
status)
/etc/init.d/saveanybot status
;;
enable)
/etc/init.d/saveanybot enable
echo "Enable SaveAnyBot auto-start."
;;
disable)
/etc/init.d/saveanybot disable
echo "Disable SaveAnyBot auto-start."
;;
*)
echo "Usage: $0 {start|stop|restart|status|enable|disable}"
exit 1
;;
esac

34
docs/saveanybot Normal file
View File

@@ -0,0 +1,34 @@
#!/bin/sh /etc/rc.common
# This is the OpenWRT init.d script for SaveAnyBot
START=99 # 设置启动顺序,数字越大越后启动
STOP=10 # 设置停止顺序,数字越小越先停止
# 脚本描述
description="SaveAnyBot"
# 设置工作目录和执行文件路径
WORKING_DIR="/mnt/mmc1-1/SaveAnyBot"
EXEC_PATH="$WORKING_DIR/saveany-bot"
# 启动函数
start() {
echo "Starting SaveAnyBot..."
# 切换到工作目录并执行程序
cd $WORKING_DIR
$EXEC_PATH &
}
# 停止函数
stop() {
echo "Stopping SaveAnyBot..."
# 查找并杀死进程
killall saveany-bot
}
# 重启函数
reload() {
stop
start
}

24
go.mod
View File

@@ -5,14 +5,16 @@ go 1.23.5
require (
github.com/blang/semver v3.5.1+incompatible
github.com/celestix/gotgproto v1.0.0-beta20.2
github.com/celestix/telegraph-go/v2 v2.0.4
github.com/gabriel-vasile/mimetype v1.4.8
github.com/gookit/slog v0.5.7
github.com/gotd/contrib v0.21.0
github.com/gotd/td v0.120.0
github.com/minio/minio-go/v7 v7.0.81
github.com/rhysd/go-github-selfupdate v1.2.3
github.com/spf13/cobra v1.8.1
github.com/spf13/viper v1.19.0
github.com/studio-b12/gowebdav v0.10.0
golang.org/x/net v0.35.0
golang.org/x/net v0.37.0
golang.org/x/time v0.10.0
)
@@ -24,13 +26,14 @@ require (
github.com/dlclark/regexp2 v1.11.5 // indirect
github.com/dustin/go-humanize v1.0.1 // indirect
github.com/fatih/color v1.18.0 // indirect
github.com/gabriel-vasile/mimetype v1.4.8 // indirect
github.com/ghodss/yaml v1.0.0 // indirect
github.com/glebarez/go-sqlite v1.22.0 // indirect
github.com/go-faster/errors v0.7.1 // indirect
github.com/go-faster/jx v1.1.0 // indirect
github.com/go-faster/xor v1.0.0 // indirect
github.com/go-faster/yaml v0.4.6 // indirect
github.com/go-ini/ini v1.67.0 // indirect
github.com/goccy/go-json v0.10.3 // indirect
github.com/google/go-github/v30 v30.1.0 // indirect
github.com/google/go-querystring v1.1.0 // indirect
github.com/google/pprof v0.0.0-20250128161936-077ca0a936bf // indirect
@@ -40,13 +43,16 @@ require (
github.com/inconshreveable/go-update v0.0.0-20160112193335-8152e7eb6ccf // indirect
github.com/jinzhu/inflection v1.0.0 // indirect
github.com/jinzhu/now v1.1.5 // indirect
github.com/klauspost/cpuid/v2 v2.2.8 // indirect
github.com/mattn/go-colorable v0.1.14 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/minio/md5-simd v1.1.2 // indirect
github.com/ncruces/go-strftime v0.1.9 // indirect
github.com/ogen-go/ogen v1.10.0 // indirect
github.com/onsi/gomega v1.36.2 // indirect
github.com/pkg/errors v0.9.1 // indirect
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
github.com/rs/xid v1.6.0 // indirect
github.com/segmentio/asm v1.2.0 // indirect
github.com/tcnksm/go-gitconfig v0.1.2 // indirect
github.com/ulikunitz/xz v0.5.12 // indirect
@@ -55,7 +61,7 @@ require (
go.opentelemetry.io/otel/trace v1.34.0 // indirect
go.uber.org/atomic v1.11.0 // indirect
go.uber.org/zap v1.27.0 // indirect
golang.org/x/crypto v0.33.0 // indirect
golang.org/x/crypto v0.36.0 // indirect
golang.org/x/mod v0.23.0 // indirect
golang.org/x/oauth2 v0.26.0 // indirect
golang.org/x/tools v0.30.0 // indirect
@@ -73,13 +79,13 @@ require (
github.com/fsnotify/fsnotify v1.8.0 // indirect
github.com/glebarez/sqlite v1.11.0
github.com/gookit/color v1.5.4 // indirect
github.com/gookit/goutil v0.6.18
github.com/gookit/goutil v0.6.18 // indirect
github.com/gookit/gsr v0.1.0 // indirect
github.com/hashicorp/hcl v1.0.0 // indirect
github.com/inconshreveable/mousetrap v1.1.0 // indirect
github.com/klauspost/compress v1.17.11 // indirect
github.com/magiconair/properties v1.8.9 // indirect
github.com/mitchellh/mapstructure v1.5.0 // indirect
github.com/mitchellh/mapstructure v1.5.0
github.com/pelletier/go-toml/v2 v2.2.3 // indirect
github.com/sagikazarmark/locafero v0.7.0 // indirect
github.com/sagikazarmark/slog-shim v0.1.0 // indirect
@@ -92,9 +98,9 @@ require (
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect
go.uber.org/multierr v1.11.0 // indirect
golang.org/x/exp v0.0.0-20250210185358-939b2ce775ac // indirect
golang.org/x/sync v0.11.0 // indirect
golang.org/x/sys v0.30.0 // indirect
golang.org/x/text v0.22.0 // indirect
golang.org/x/sync v0.12.0
golang.org/x/sys v0.31.0 // indirect
golang.org/x/text v0.23.0 // indirect
gopkg.in/ini.v1 v1.67.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
gorm.io/gorm v1.25.12

81
go.sum
View File

@@ -2,10 +2,10 @@ github.com/AnimeKaizoku/cacher v1.0.2 h1:7Bf5qRylWb7q2Evib0OXlhG37/t7BP2HK/7IyPv
github.com/AnimeKaizoku/cacher v1.0.2/go.mod h1:jw0de/b0K6W7Y3T9rHCMGVKUf6oG7hENNcssxYcZTCc=
github.com/blang/semver v3.5.1+incompatible h1:cQNTCjp13qL8KC3Nbxr/y2Bqb63oX6wdnnjpJbkM4JQ=
github.com/blang/semver v3.5.1+incompatible/go.mod h1:kRBLl5iJ+tD4TcOOxsy/0fnwebNt5EWlYSAyrTnjyyk=
github.com/celestix/gotgproto v1.0.0-beta20.1 h1:F7H08CuSiHP0YlZqATBi2wJvg7dxXFvFbpauWFd0IbI=
github.com/celestix/gotgproto v1.0.0-beta20.1/go.mod h1:j42ZhBMUke6QyBLvCgx8tA+TL9L3+pq/Q46B+b5+3aU=
github.com/celestix/gotgproto v1.0.0-beta20.2 h1:+WcsKdsyj4xy+TAV+4Sw6zp1xiQrIr4dMnM31+k8NYM=
github.com/celestix/gotgproto v1.0.0-beta20.2/go.mod h1:j42ZhBMUke6QyBLvCgx8tA+TL9L3+pq/Q46B+b5+3aU=
github.com/celestix/telegraph-go/v2 v2.0.4 h1:w8HWymJFhMSMPjdGoyTh3/NqE3eXAT1njTvelh0338k=
github.com/celestix/telegraph-go/v2 v2.0.4/go.mod h1:vu2LtqM7MgOAJ2LDF8XK27DWdd1QYLBfZGhalEh086Y=
github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK343L8=
github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE=
github.com/cespare/xxhash/v2 v2.1.2/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
@@ -19,8 +19,6 @@ github.com/cpuguy83/go-md2man/v2 v2.0.4/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46t
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM=
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dlclark/regexp2 v1.11.4 h1:rPYF9/LECdNymJufQKmri9gV604RvvABwgOA8un7yAo=
github.com/dlclark/regexp2 v1.11.4/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
github.com/dlclark/regexp2 v1.11.5 h1:Q/sSnsKerHeCkc/jSTNq1oCm7KiVgUMZRDUoRu0JQZQ=
github.com/dlclark/regexp2 v1.11.5/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
github.com/duke-git/lancet/v2 v2.3.4 h1:8XGI7P9w+/GqmEBEXYaH/XuNiM0f4/90Ioti0IvYJls=
@@ -51,10 +49,14 @@ github.com/go-faster/xor v1.0.0 h1:2o8vTOgErSGHP3/7XwA5ib1FTtUsNtwCoLLBjl31X38=
github.com/go-faster/xor v1.0.0/go.mod h1:x5CaDY9UKErKzqfRfFZdfu+OSTfoZny3w5Ak7UxcipQ=
github.com/go-faster/yaml v0.4.6 h1:lOK/EhI04gCpPgPhgt0bChS6bvw7G3WwI8xxVe0sw9I=
github.com/go-faster/yaml v0.4.6/go.mod h1:390dRIvV4zbnO7qC9FGo6YYutc+wyyUSHBgbXL52eXk=
github.com/go-ini/ini v1.67.0 h1:z6ZrTEZqSWOTyH2FlglNbNgARyHG8oLW9gMELqKr06A=
github.com/go-ini/ini v1.67.0/go.mod h1:ByCAeIL28uOIIG0E3PJtZPDL8WnHpFKFOtgjp+3Ies8=
github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY=
github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
github.com/goccy/go-json v0.10.3 h1:KZ5WoDbxAIgm2HNbYckL0se1fHD6rz5j4ywS6ebzDqA=
github.com/goccy/go-json v0.10.3/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M=
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
@@ -83,8 +85,6 @@ github.com/gotd/ige v0.2.2 h1:XQ9dJZwBfDnOGSTxKXBGP4gMud3Qku2ekScRjDWWfEk=
github.com/gotd/ige v0.2.2/go.mod h1:tuCRb+Y5Y3eNTo3ypIfNpQ4MFjrnONiL2jN2AKZXmb0=
github.com/gotd/neo v0.1.5 h1:oj0iQfMbGClP8xI59x7fE/uHoTJD7NZH9oV1WNuPukQ=
github.com/gotd/neo v0.1.5/go.mod h1:9A2a4bn9zL6FADufBdt7tZt+WMhvZoc5gWXihOPoiBQ=
github.com/gotd/td v0.118.0 h1:iPGkaOAd3QO72TcvzNJGKGpLDzYOW8GIz+Va2upxBbY=
github.com/gotd/td v0.118.0/go.mod h1:FUNVeJB9Id2Vqps9yF+8kmBNNyCGO6VXDyO8Ah7bVSw=
github.com/gotd/td v0.120.0 h1:XeiafJM82/9SaB+ZMjMm/dnUx5+avINwVZOEsnV0zMo=
github.com/gotd/td v0.120.0/go.mod h1:BCc2jFj1l5zP9Trk4J7nxeqW0KBGl6K95eXMgszkbOI=
github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4=
@@ -100,6 +100,9 @@ github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
github.com/klauspost/compress v1.17.11 h1:In6xLpyWOi1+C7tXUUWv2ot1QvBjxevKAaI6IXrJmUc=
github.com/klauspost/compress v1.17.11/go.mod h1:pMDklpSncoRMuLFrf1W9Ss9KT+0rH90U12bZKk7uwG0=
github.com/klauspost/cpuid/v2 v2.0.1/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
github.com/klauspost/cpuid/v2 v2.2.8 h1:+StwCXwm9PdpiEkPyzBXIy+M9KUb4ODm0Zarf1kS5BM=
github.com/klauspost/cpuid/v2 v2.2.8/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws=
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
@@ -113,12 +116,14 @@ github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHP
github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/minio/md5-simd v1.1.2 h1:Gdi1DZK69+ZVMoNHRXJyNcxrMA4dSxoYHZSQbirFg34=
github.com/minio/md5-simd v1.1.2/go.mod h1:MzdKDxYpY2BT9XQFocsiZf/NKVtR7nkE4RoEpN+20RM=
github.com/minio/minio-go/v7 v7.0.81 h1:SzhMN0TQ6T/xSBu6Nvw3M5M8voM+Ht8RH3hE8S7zxaA=
github.com/minio/minio-go/v7 v7.0.81/go.mod h1:84gmIilaX4zcvAWWzJ5Z1WI5axN+hAbM5w25xf8xvC0=
github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY=
github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo=
github.com/ncruces/go-strftime v0.1.9 h1:bY0MQC28UADQmHmaF5dgpLmImcShSi2kHU9XLdhx/f4=
github.com/ncruces/go-strftime v0.1.9/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
github.com/ogen-go/ogen v1.9.0 h1:n+lDQpiSFYC9G4hTvuNVWnqmIP0LR8ws0faDn9jX3hU=
github.com/ogen-go/ogen v1.9.0/go.mod h1:vkHpuRyzjdfuRCy81EShi4t9sIgZDcNPGmiDKipRloc=
github.com/ogen-go/ogen v1.10.0 h1:x3ukRtq/pdn/k8+pYBtqWceVASiSmgK9M5lrH89Q+04=
github.com/ogen-go/ogen v1.10.0/go.mod h1:WExXrswerPzGWD0NpzBFsz+5eQIbP7HAtZUmpV8dqqI=
github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE=
@@ -137,6 +142,8 @@ github.com/rhysd/go-github-selfupdate v1.2.3 h1:iaa+J202f+Nc+A8zi75uccC8Wg3omaM7
github.com/rhysd/go-github-selfupdate v1.2.3/go.mod h1:mp/N8zj6jFfBQy/XMYoWsmfzxazpPAODuqarmPDe2Rg=
github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII=
github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o=
github.com/rs/xid v1.6.0 h1:fV591PaemRlL6JfRxGDEPl69wICngIQ3shQtzfy2gxU=
github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0=
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
github.com/sagikazarmark/locafero v0.7.0 h1:5MqpDsTGNDhY8sGp0Aowyf0qKsPrhewaLSsFaodPcyo=
github.com/sagikazarmark/locafero v0.7.0/go.mod h1:2za3Cg5rMaTMoG/2Ulr9AwtFaIppKXTRYnozin4aB5k=
@@ -159,8 +166,6 @@ github.com/spf13/viper v1.19.0 h1:RWq5SEjt8o25SROyN3z2OrDB9l7RPd3lwTWU8EcEdcI=
github.com/spf13/viper v1.19.0/go.mod h1:GQUN9bilAbhU/jgc1bKs99f/suXKeUMct8Adx5+Ntkg=
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/studio-b12/gowebdav v0.10.0 h1:Yewz8FFiadcGEu4hxS/AAJQlHelndqln1bns3hcJIYc=
github.com/studio-b12/gowebdav v0.10.0/go.mod h1:bHA7t77X/QFExdeAnDzK6vKM34kEZAcE1OX4MfiwjkE=
github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8=
github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU=
github.com/tcnksm/go-gitconfig v0.1.2 h1:iiDhRitByXAEyjgBqsKi9QU4o2TNtv9kPP3RgPgXBPw=
@@ -190,63 +195,43 @@ go.uber.org/zap v1.27.0 h1:aJMhYGrd5QSmlpLMr2MftRKl7t8J8PTZPA732ud/XR8=
go.uber.org/zap v1.27.0/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20201221181555-eec23a3978ad/go.mod h1:jdWPYTVW3xRLrWPugEBEK3UY2ZEsg3UU495nc5E+M+I=
golang.org/x/crypto v0.32.0 h1:euUpcYgM8WcP71gNpTqQCn6rC2t6ULUPiOzfWaXVVfc=
golang.org/x/crypto v0.32.0/go.mod h1:ZnnJkOaASj8g0AjIduWNlq2NRxL0PlBrbKVyZ6V/Ugc=
golang.org/x/crypto v0.33.0 h1:IOBPskki6Lysi0lo9qQvbxiQ+FvsCC/YWOecCHAixus=
golang.org/x/crypto v0.33.0/go.mod h1:bVdXmD7IV/4GdElGPozy6U7lWdRXA4qyRVGJV57uQ5M=
golang.org/x/exp v0.0.0-20250128182459-e0ece0dbea4c h1:KL/ZBHXgKGVmuZBZ01Lt57yE5ws8ZPSkkihmEyq7FXc=
golang.org/x/exp v0.0.0-20250128182459-e0ece0dbea4c/go.mod h1:tujkw807nyEEAamNbDrEGzRav+ilXA7PCRAd6xsmwiU=
golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34=
golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc=
golang.org/x/exp v0.0.0-20250210185358-939b2ce775ac h1:l5+whBCLH3iH2ZNHYLbAe58bo7yrN4mVcnkHDYz5vvs=
golang.org/x/exp v0.0.0-20250210185358-939b2ce775ac/go.mod h1:hH+7mtFmImwwcMvScyxUhjuVHR3HGaDPMn9rMSUUbxo=
golang.org/x/mod v0.22.0 h1:D4nJWe9zXqHOmWqj4VMOJhvzj7bEZg4wEYa759z1pH4=
golang.org/x/mod v0.22.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY=
golang.org/x/mod v0.23.0 h1:Zb7khfcRGKk+kqfxFaP5tZqCnDZMjC5VtUBs87Hr6QM=
golang.org/x/mod v0.23.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY=
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.34.0 h1:Mb7Mrk043xzHgnRM88suvJFwzVrRfHEHJEl5/71CKw0=
golang.org/x/net v0.34.0/go.mod h1:di0qlW3YNM5oh6GqDGQr92MyTozJPmybPK4Ev/Gm31k=
golang.org/x/net v0.35.0 h1:T5GQRQb2y08kTAByq9L4/bz8cipCdA8FbRTXewonqY8=
golang.org/x/net v0.35.0/go.mod h1:EglIi67kWsHKlRzzVMUD93VMSWGFOMSZgxFjparz1Qk=
golang.org/x/net v0.37.0 h1:1zLorHbz+LYj7MQlSf1+2tPIIgibq2eL5xkrGk6f+2c=
golang.org/x/net v0.37.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8=
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
golang.org/x/oauth2 v0.0.0-20181106182150-f42d05182288/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
golang.org/x/oauth2 v0.25.0 h1:CY4y7XT9v0cRI9oupztF8AgiIu99L/ksR/Xp/6jrZ70=
golang.org/x/oauth2 v0.25.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI=
golang.org/x/oauth2 v0.26.0 h1:afQXWNNaeC4nvZ0Ed9XvCCzXM6UHJG7iCg0W4fPqSBE=
golang.org/x/oauth2 v0.26.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI=
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ=
golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sync v0.11.0 h1:GGz8+XQP4FvTTrjZPzNKTMFtSXH80RAzG+5ghFPgK9w=
golang.org/x/sync v0.11.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sync v0.12.0 h1:MHc5BpPuC30uJk597Ri8TV3CNZcTLu6B6z4lJy+g6Jw=
golang.org/x/sync v0.12.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.29.0 h1:TPYlXGxvx1MGTn2GiZDhnjPA9wZzZeGKHHmKhHYvgaU=
golang.org/x/sys v0.29.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.30.0 h1:QjkSwP/36a20jFYWkSue1YwXzLmsV5Gfq7Eiy72C1uc=
golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik=
golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw=
golang.org/x/term v0.28.0 h1:/Ts8HFuMR2E6IP/jlo7QVLZHggjKQbhu/7H0LJFr3Gg=
golang.org/x/term v0.28.0/go.mod h1:Sw/lC2IAUZ92udQNf3WodGtn4k/XoLyZoh8v/8uiwek=
golang.org/x/term v0.29.0 h1:L6pJp37ocefwRRtYPKSWOWzOtWSxVajvz2ldH/xi3iU=
golang.org/x/term v0.30.0 h1:PQ39fJZ+mfadBm0y5WlL4vlM7Sx1Hgf13sMIY2+QS9Y=
golang.org/x/term v0.30.0/go.mod h1:NYYFdzHoI5wRh/h5tDMdMqCqPJZEuNqVR5xJLd/n67g=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.5/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo=
golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
golang.org/x/text v0.22.0 h1:bofq7m3/HAFvbF51jz3Q9wLg3jkvSPuiZu/pD1XwgtM=
golang.org/x/text v0.22.0/go.mod h1:YRoo4H8PVmsu+E3Ou7cqLVH8oXWIHVoX0jqUWALQhfY=
golang.org/x/time v0.9.0 h1:EsRrnYcQiGH+5FfbgvV4AP7qEZstoyrHB0DzarOQ4ZY=
golang.org/x/time v0.9.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
golang.org/x/text v0.23.0 h1:D71I7dUrlY+VX0gQShAThNGHFxZ13dGLBHQLVl1mJlY=
golang.org/x/text v0.23.0/go.mod h1:/BLNzu4aZCJ1+kcD0DNRotWKage4q2rGVAg4o22unh4=
golang.org/x/time v0.10.0 h1:3usCWA8tQn0L8+hFJQNgzpWbd89begxN66o1Ojdn5L4=
golang.org/x/time v0.10.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.29.0 h1:Xx0h3TtM9rzQpQuR4dKLrdglAmCEN5Oi+P74JdhdzXE=
golang.org/x/tools v0.29.0/go.mod h1:KMQVMRsVxU6nHCFXrBPhDB8XncLNLM0lIy/F14RP588=
golang.org/x/tools v0.30.0 h1:BgcpHewrV5AUp2G9MebG4XPFI1E2W41zU1SaqVA9vJY=
golang.org/x/tools v0.30.0/go.mod h1:c347cR/OJfw5TI+GfX7RUPNMdDRRbjvYTS0jPyvsVtY=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
@@ -269,16 +254,12 @@ gorm.io/gorm v1.25.12 h1:I0u8i2hWQItBq1WfE0o2+WuL9+8L21K9e2HHSTE/0f8=
gorm.io/gorm v1.25.12/go.mod h1:xh7N7RHfYlNc5EmcI/El95gXusucDrQnHXe0+CgWcLQ=
modernc.org/cc/v4 v4.24.4 h1:TFkx1s6dCkQpd6dKurBNmpo+G8Zl4Sq/ztJ+2+DEsh0=
modernc.org/cc/v4 v4.24.4/go.mod h1:uVtb5OGqUKpoLWhqwNQo/8LwvoiEBLvZXIQ/SmO6mL0=
modernc.org/ccgo/v4 v4.23.15 h1:wFDan71KnYqeHz4eF63vmGE6Q6Pc0PUGDpP0PRMYjDc=
modernc.org/ccgo/v4 v4.23.15/go.mod h1:nJX30dks/IWuBOnVa7VRii9Me4/9TZ1SC9GNtmARTy0=
modernc.org/ccgo/v4 v4.23.16 h1:Z2N+kk38b7SfySC1ZkpGLN2vthNJP1+ZzGZIlH7uBxo=
modernc.org/ccgo/v4 v4.23.16/go.mod h1:nNma8goMTY7aQZQNTyN9AIoJfxav4nvTnvKThAeMDdo=
modernc.org/fileutil v1.3.0 h1:gQ5SIzK3H9kdfai/5x41oQiKValumqNTDXMvKo62HvE=
modernc.org/fileutil v1.3.0/go.mod h1:XatxS8fZi3pS8/hKG2GH/ArUogfxjpEKs3Ku3aK4JyQ=
modernc.org/gc/v2 v2.6.2 h1:YBXi5Kqp6aCK3fIxwKQ3/fErvawVKwjOLItxj1brGds=
modernc.org/gc/v2 v2.6.2/go.mod h1:YgIahr1ypgfe7chRuJi2gD7DBQiKSLMPgBQe9oIiito=
modernc.org/gc/v2 v2.6.3 h1:aJVhcqAte49LF+mGveZ5KPlsp4tdGdAOT4sipJXADjw=
modernc.org/libc v1.61.11 h1:6sZG8uB6EMMG7iTLPTndi8jyTdgAQNIeLGjCFICACZw=
modernc.org/libc v1.61.11/go.mod h1:HHX+srFdn839oaJRd0W8hBM3eg+mieyZCAjWwB08/nM=
modernc.org/gc/v2 v2.6.3/go.mod h1:YgIahr1ypgfe7chRuJi2gD7DBQiKSLMPgBQe9oIiito=
modernc.org/libc v1.61.13 h1:3LRd6ZO1ezsFiX1y+bHd1ipyEHIJKvuprv0sLTBwLW8=
modernc.org/libc v1.61.13/go.mod h1:8F/uJWL/3nNil0Lgt1Dpz+GgkApWh04N3el3hxJcA6E=
modernc.org/mathutil v1.7.1 h1:GCZVGXdaN8gTqB1Mf/usp1Y/hSqgI2vAGGP4jZMCxOU=
@@ -289,8 +270,6 @@ modernc.org/opt v0.1.4 h1:2kNGMRiUjrp4LcaPuLY2PzUfqM/w9N23quVwhKt5Qm8=
modernc.org/opt v0.1.4/go.mod h1:03fq9lsNfvkYSfxrfUhZCWPk1lm4cq4N+Bh//bEtgns=
modernc.org/sortutil v1.2.1 h1:+xyoGf15mM3NMlPDnFqrteY07klSFxLElE2PVuWIJ7w=
modernc.org/sortutil v1.2.1/go.mod h1:7ZI3a3REbai7gzCLcotuw9AC4VZVpYMjDzETGsSMqJE=
modernc.org/sqlite v1.34.5 h1:Bb6SR13/fjp15jt70CL4f18JIN7p7dnMExd+UFnF15g=
modernc.org/sqlite v1.34.5/go.mod h1:YLuNmX9NKs8wRNK2ko1LW1NGYcc9FkBO69JOt1AR9JE=
modernc.org/sqlite v1.35.0 h1:yQps4fegMnZFdphtzlfQTCNBWtS0CZv48pRpW3RFHRw=
modernc.org/sqlite v1.35.0/go.mod h1:9cr2sicr7jIaWTBKQmAxQLfBv9LL0su4ZTEV+utt3ic=
modernc.org/strutil v1.2.1 h1:UneZBkQA+DX2Rp35KcM69cSsNES9ly8mQWD71HKlOA0=

View File

@@ -8,30 +8,65 @@ import (
)
type TaskQueue struct {
list *list.List
cond *sync.Cond
mutex *sync.Mutex
list *list.List
cond *sync.Cond
mutex *sync.Mutex
activeMap map[string]*types.Task
}
func (q *TaskQueue) AddTask(task types.Task) {
func (q *TaskQueue) AddTask(task *types.Task) {
q.mutex.Lock()
defer q.mutex.Unlock()
q.list.PushBack(task)
q.cond.Signal()
if task.Status != types.Pending {
delete(q.activeMap, task.Key())
}
}
func (q *TaskQueue) GetTask() types.Task {
func (q *TaskQueue) GetTask() *types.Task {
q.mutex.Lock()
defer q.mutex.Unlock()
for q.list.Len() == 0 {
q.cond.Wait()
}
e := q.list.Front()
task := e.Value.(types.Task)
task := e.Value.(*types.Task)
q.list.Remove(e)
if task.Status == types.Pending {
q.activeMap[task.Key()] = task
}
return task
}
func (q *TaskQueue) DoneTask(task *types.Task) {
q.mutex.Lock()
defer q.mutex.Unlock()
delete(q.activeMap, task.Key())
}
func (q *TaskQueue) CancelTask(key string) bool {
q.mutex.Lock()
defer q.mutex.Unlock()
if task, ok := q.activeMap[key]; ok {
if task.Cancel != nil {
task.Cancel()
return true
}
}
for e := q.list.Front(); e != nil; e = e.Next() {
task := e.Value.(*types.Task)
if task.Key() == key {
if task.Cancel != nil {
task.Cancel()
}
q.list.Remove(e)
return true
}
}
return false
}
func (q *TaskQueue) Len() int {
q.mutex.Lock()
defer q.mutex.Unlock()
@@ -47,20 +82,29 @@ func init() {
func NewQueue() *TaskQueue {
m := &sync.Mutex{}
return &TaskQueue{
list: list.New(),
cond: sync.NewCond(m),
mutex: m,
list: list.New(),
cond: sync.NewCond(m),
mutex: m,
activeMap: make(map[string]*types.Task),
}
}
func AddTask(task types.Task) {
func AddTask(task *types.Task) {
Queue.AddTask(task)
}
func GetTask() types.Task {
func GetTask() *types.Task {
return Queue.GetTask()
}
func Len() int {
return Queue.Len()
}
func CancelTask(key string) bool {
return Queue.CancelTask(key)
}
func DoneTask(task *types.Task) {
Queue.DoneTask(task)
}

View File

@@ -1,19 +1,18 @@
package alist
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"os"
"path"
"time"
"github.com/krau/SaveAny-Bot/config"
"github.com/krau/SaveAny-Bot/logger"
"github.com/krau/SaveAny-Bot/common"
config "github.com/krau/SaveAny-Bot/config/storage"
"github.com/krau/SaveAny-Bot/types"
)
type Alist struct {
@@ -21,177 +20,98 @@ type Alist struct {
token string
baseURL string
loginInfo *loginRequest
config config.AlistStorageConfig
}
var (
ErrAlistLoginFailed = errors.New("failed to login to Alist")
)
type loginRequest struct {
Username string `json:"username"`
Password string `json:"password"`
}
type loginResponse struct {
Code int `json:"code"`
Message string `json:"message"`
Data struct {
Token string `json:"token"`
} `json:"data"`
}
type meResponse struct {
Code int `json:"code"`
Message string `json:"message"`
Data struct {
ID int `json:"id"`
Username string `json:"username"`
} `json:"data"`
}
type putResponse struct {
Code int `json:"code"`
Message string `json:"message"`
Data struct {
Task struct {
ID string `json:"id"`
Name string `json:"name"`
State int `json:"state"`
Status string `json:"status"`
Progress int `json:"progress"`
Error string `json:"error"`
} `json:"task"`
} `json:"data"`
}
func (a *Alist) getToken() error {
loginBody, err := json.Marshal(a.loginInfo)
if err != nil {
return fmt.Errorf("failed to marshal login request: %w", err)
func (a *Alist) Init(cfg config.StorageConfig) error {
alistConfig, ok := cfg.(*config.AlistStorageConfig)
if !ok {
return fmt.Errorf("failed to cast alist config")
}
req, err := http.NewRequest(http.MethodPost, a.baseURL+"/api/auth/login", bytes.NewBuffer(loginBody))
if err != nil {
return fmt.Errorf("failed to create login request: %w", err)
if err := alistConfig.Validate(); err != nil {
return err
}
req.Header.Set("Content-Type", "application/json")
a.config = *alistConfig
resp, err := a.client.Do(req)
if err != nil {
return fmt.Errorf("failed to send login request: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return fmt.Errorf("failed to read login response: %w", err)
}
var loginResp loginResponse
if err := json.Unmarshal(body, &loginResp); err != nil {
return fmt.Errorf("failed to unmarshal login response: %w", err)
}
if loginResp.Code != http.StatusOK {
return fmt.Errorf("%w: %s", ErrAlistLoginFailed, loginResp.Message)
}
a.token = loginResp.Data.Token
return nil
}
func (a *Alist) refreshToken() {
for {
time.Sleep(time.Duration(config.Cfg.Storage.Alist.TokenExp) * time.Second)
if err := a.getToken(); err != nil {
logger.L.Errorf("Failed to refresh jwt token: %v", err)
continue
}
logger.L.Info("Refreshed Alist jwt token")
}
}
func (a *Alist) Init() {
a.baseURL = config.Cfg.Storage.Alist.URL
a.client = &http.Client{
Timeout: 12 * time.Hour,
Transport: &http.Transport{
TLSHandshakeTimeout: 10 * time.Second,
},
}
if config.Cfg.Storage.Alist.Token != "" {
a.token = config.Cfg.Storage.Alist.Token
a.baseURL = alistConfig.URL
a.client = getHttpClient()
if alistConfig.Token != "" {
a.token = alistConfig.Token
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Minute)
defer cancel()
req, err := http.NewRequestWithContext(ctx, http.MethodGet, a.baseURL+"/api/me", nil)
if err != nil {
logger.L.Fatalf("Failed to create request: %v", err)
os.Exit(1)
common.Log.Fatalf("Failed to create request: %v", err)
return err
}
req.Header.Set("Authorization", a.token)
resp, err := a.client.Do(req)
if err != nil {
logger.L.Fatalf("Failed to send request: %v", err)
os.Exit(1)
common.Log.Fatalf("Failed to send request: %v", err)
return err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
logger.L.Fatalf("Failed to get alist user info: %s", resp.Status)
os.Exit(1)
common.Log.Fatalf("Failed to get alist user info: %s", resp.Status)
return err
}
body, err := io.ReadAll(resp.Body)
if err != nil {
logger.L.Fatalf("Failed to read response body: %v", err)
os.Exit(1)
common.Log.Fatalf("Failed to read response body: %v", err)
return err
}
var meResp meResponse
if err := json.Unmarshal(body, &meResp); err != nil {
logger.L.Fatalf("Failed to unmarshal me response: %v", err)
os.Exit(1)
common.Log.Fatalf("Failed to unmarshal me response: %v", err)
return err
}
if meResp.Code != http.StatusOK {
logger.L.Fatalf("Failed to get alist user info: %s", meResp.Message)
os.Exit(1)
common.Log.Fatalf("Failed to get alist user info: %s", meResp.Message)
return err
}
logger.L.Debugf("Logged in Alist as %s", meResp.Data.Username)
return
common.Log.Debugf("Logged in Alist as %s", meResp.Data.Username)
return nil
}
a.loginInfo = &loginRequest{
Username: config.Cfg.Storage.Alist.Username,
Password: config.Cfg.Storage.Alist.Password,
Username: alistConfig.Username,
Password: alistConfig.Password,
}
if err := a.getToken(); err != nil {
logger.L.Fatalf("Failed to login to Alist: %v", err)
os.Exit(1)
common.Log.Fatalf("Failed to login to Alist: %v", err)
return err
}
logger.L.Debug("Logged in to Alist")
common.Log.Debug("Logged in to Alist")
go a.refreshToken()
go a.refreshToken(*alistConfig)
return nil
}
func (a *Alist) Save(ctx context.Context, filePath, storagePath string) error {
file, err := os.Open(filePath)
if err != nil {
return fmt.Errorf("failed to open file: %w", err)
}
defer file.Close()
func (a *Alist) Type() types.StorageType {
return types.StorageTypeAlist
}
filestat, err := file.Stat()
if err != nil {
return fmt.Errorf("failed to get file stats: %w", err)
}
func (a *Alist) Name() string {
return a.config.Name
}
req, err := http.NewRequestWithContext(ctx, http.MethodPut, a.baseURL+"/api/fs/put", file)
func (a *Alist) Save(ctx context.Context, reader io.Reader, storagePath string) error {
common.Log.Infof("Saving file to %s", storagePath)
req, err := http.NewRequestWithContext(ctx, http.MethodPut, a.baseURL+"/api/fs/put", reader)
if err != nil {
return fmt.Errorf("failed to create request: %w", err)
}
req.Header.Set("Authorization", a.token)
req.Header.Set("File-Path", url.PathEscape(storagePath))
req.Header.Set("As-Task", "true")
req.Header.Set("Content-Type", "application/octet-stream")
req.ContentLength = filestat.Size()
if length := ctx.Value(types.ContextKeyContentLength); length != nil {
length, ok := length.(int64)
if ok {
req.ContentLength = length
}
}
resp, err := a.client.Do(req)
if err != nil {
@@ -219,3 +139,11 @@ func (a *Alist) Save(ctx context.Context, filePath, storagePath string) error {
return nil
}
func (a *Alist) NotSupportStream() string {
return "Alist does not support chunked transfer encoding"
}
func (a *Alist) JoinStoragePath(task types.Task) string {
return path.Join(a.config.BasePath, task.StoragePath)
}

65
storage/alist/token.go Normal file
View File

@@ -0,0 +1,65 @@
package alist
import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"time"
"github.com/krau/SaveAny-Bot/common"
config "github.com/krau/SaveAny-Bot/config/storage"
)
func (a *Alist) getToken() error {
loginBody, err := json.Marshal(a.loginInfo)
if err != nil {
return fmt.Errorf("failed to marshal login request: %w", err)
}
req, err := http.NewRequest(http.MethodPost, a.baseURL+"/api/auth/login", bytes.NewBuffer(loginBody))
if err != nil {
return fmt.Errorf("failed to create login request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
resp, err := a.client.Do(req)
if err != nil {
return fmt.Errorf("failed to send login request: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return fmt.Errorf("failed to read login response: %w", err)
}
var loginResp loginResponse
if err := json.Unmarshal(body, &loginResp); err != nil {
return fmt.Errorf("failed to unmarshal login response: %w", err)
}
if loginResp.Code != http.StatusOK {
return fmt.Errorf("%w: %s", ErrAlistLoginFailed, loginResp.Message)
}
a.token = loginResp.Data.Token
return nil
}
func (a *Alist) refreshToken(cfg config.AlistStorageConfig) {
tokenExp := cfg.TokenExp
if tokenExp <= 0 {
common.Log.Warn("Invalid token expiration time, using default value")
tokenExp = 3600
}
for {
time.Sleep(time.Duration(tokenExp) * time.Second)
if err := a.getToken(); err != nil {
common.Log.Errorf("Failed to refresh jwt token: %v", err)
continue
}
common.Log.Info("Refreshed Alist jwt token")
}
}

44
storage/alist/types.go Normal file
View File

@@ -0,0 +1,44 @@
package alist
import "errors"
var (
ErrAlistLoginFailed = errors.New("failed to login to Alist")
)
type loginRequest struct {
Username string `json:"username"`
Password string `json:"password"`
}
type loginResponse struct {
Code int `json:"code"`
Message string `json:"message"`
Data struct {
Token string `json:"token"`
} `json:"data"`
}
type meResponse struct {
Code int `json:"code"`
Message string `json:"message"`
Data struct {
ID int `json:"id"`
Username string `json:"username"`
} `json:"data"`
}
type putResponse struct {
Code int `json:"code"`
Message string `json:"message"`
Data struct {
Task struct {
ID string `json:"id"`
Name string `json:"name"`
State int `json:"state"`
Status string `json:"status"`
Progress int `json:"progress"`
Error string `json:"error"`
} `json:"task"`
} `json:"data"`
}

23
storage/alist/utils.go Normal file
View File

@@ -0,0 +1,23 @@
package alist
import (
"net/http"
"time"
)
var (
httpClient *http.Client
)
func getHttpClient() *http.Client {
if httpClient != nil {
return httpClient
}
httpClient = &http.Client{
Timeout: 12 * time.Hour,
Transport: &http.Transport{
TLSHandshakeTimeout: 10 * time.Second,
},
}
return httpClient
}

9
storage/errs.go Normal file
View File

@@ -0,0 +1,9 @@
package storage
import (
"errors"
)
var (
ErrStorageNameEmpty = errors.New("storage name is empty")
)

View File

@@ -2,25 +2,52 @@ package local
import (
"context"
"fmt"
"io"
"os"
"path/filepath"
"github.com/duke-git/lancet/v2/fileutil"
"github.com/krau/SaveAny-Bot/config"
"github.com/krau/SaveAny-Bot/logger"
"github.com/krau/SaveAny-Bot/common"
config "github.com/krau/SaveAny-Bot/config/storage"
"github.com/krau/SaveAny-Bot/types"
)
type Local struct{}
func (l *Local) Init() {
err := os.MkdirAll(config.Cfg.Storage.Local.BasePath, os.ModePerm)
if err != nil {
logger.L.Fatalf("Failed to create local storage directory: %s", err)
os.Exit(1)
}
type Local struct {
config config.LocalStorageConfig
}
func (l *Local) Save(ctx context.Context, filePath, storagePath string) error {
func (l *Local) Init(cfg config.StorageConfig) error {
localConfig, ok := cfg.(*config.LocalStorageConfig)
if !ok {
return fmt.Errorf("failed to cast local config")
}
if err := localConfig.Validate(); err != nil {
return err
}
l.config = *localConfig
err := os.MkdirAll(localConfig.BasePath, os.ModePerm)
if err != nil {
return fmt.Errorf("failed to create local storage directory: %w", err)
}
return nil
}
func (l *Local) Type() types.StorageType {
return types.StorageTypeLocal
}
func (l *Local) Name() string {
return l.config.Name
}
func (l *Local) JoinStoragePath(task types.Task) string {
return filepath.Join(l.config.BasePath, task.StoragePath)
}
func (l *Local) Save(ctx context.Context, r io.Reader, storagePath string) error {
common.Log.Infof("Saving file to %s", storagePath)
absPath, err := filepath.Abs(storagePath)
if err != nil {
return err
@@ -28,5 +55,11 @@ func (l *Local) Save(ctx context.Context, filePath, storagePath string) error {
if err := fileutil.CreateDir(filepath.Dir(absPath)); err != nil {
return err
}
return fileutil.CopyFile(filePath, storagePath)
file, err := os.Create(absPath)
if err != nil {
return err
}
defer file.Close()
_, err = io.Copy(file, r)
return err
}

72
storage/minio/client.go Normal file
View File

@@ -0,0 +1,72 @@
package minio
import (
"context"
"fmt"
"io"
"path"
"github.com/krau/SaveAny-Bot/common"
config "github.com/krau/SaveAny-Bot/config/storage"
"github.com/krau/SaveAny-Bot/types"
"github.com/minio/minio-go/v7"
"github.com/minio/minio-go/v7/pkg/credentials"
)
type Minio struct {
config config.MinioStorageConfig
client *minio.Client
}
func (m *Minio) Init(cfg config.StorageConfig) error {
minioConfig, ok := cfg.(*config.MinioStorageConfig)
if !ok {
return fmt.Errorf("failed to cast minio config")
}
if err := minioConfig.Validate(); err != nil {
return err
}
m.config = *minioConfig
client, err := minio.New(m.config.Endpoint, &minio.Options{
Creds: credentials.NewStaticV4(m.config.AccessKeyID, m.config.SecretAccessKey, ""),
Secure: m.config.UseSSL,
})
if err != nil {
return fmt.Errorf("failed to create minio client: %w", err)
}
exists, err := client.BucketExists(context.Background(), m.config.BucketName)
if err != nil {
return fmt.Errorf("failed to check bucket existence: %w", err)
}
if !exists {
return fmt.Errorf("bucket %s does not exist", m.config.BucketName)
}
m.client = client
return nil
}
func (m *Minio) Type() types.StorageType {
return types.StorageTypeMinio
}
func (m *Minio) Name() string {
return m.config.Name
}
func (m *Minio) JoinStoragePath(task types.Task) string {
return path.Join(m.config.BasePath, task.StoragePath)
}
func (m *Minio) Save(ctx context.Context, r io.Reader, storagePath string) error {
common.Log.Infof("Saving file from reader to %s", storagePath)
_, err := m.client.PutObject(ctx, m.config.BucketName, storagePath, r, -1, minio.PutObjectOptions{})
if err != nil {
return fmt.Errorf("failed to upload file to minio: %w", err)
}
return nil
}

View File

@@ -2,83 +2,123 @@ package storage
import (
"context"
"errors"
"path"
"path/filepath"
"sync"
"fmt"
"io"
"github.com/duke-git/lancet/v2/slice"
"github.com/krau/SaveAny-Bot/common"
"github.com/krau/SaveAny-Bot/config"
"github.com/krau/SaveAny-Bot/logger"
sc "github.com/krau/SaveAny-Bot/config/storage"
"github.com/krau/SaveAny-Bot/storage/alist"
"github.com/krau/SaveAny-Bot/storage/local"
"github.com/krau/SaveAny-Bot/storage/minio"
"github.com/krau/SaveAny-Bot/storage/webdav"
"github.com/krau/SaveAny-Bot/types"
)
type Storage interface {
Init()
Save(cttx context.Context, localFilePath, storagePath string) error
Init(cfg sc.StorageConfig) error
Type() types.StorageType
Name() string
JoinStoragePath(task types.Task) string
Save(ctx context.Context, reader io.Reader, storagePath string) error
}
var Storages = make(map[types.StorageType]Storage)
var StorageKeys = make([]types.StorageType, 0)
func Init() {
logger.L.Debug("Initializing storage...")
if config.Cfg.Storage.Alist.Enable {
Storages[types.Alist] = new(alist.Alist)
Storages[types.Alist].Init()
}
if config.Cfg.Storage.Local.Enable {
Storages[types.Local] = new(local.Local)
Storages[types.Local].Init()
}
if config.Cfg.Storage.Webdav.Enable {
Storages[types.Webdav] = new(webdav.Webdav)
Storages[types.Webdav].Init()
}
for k := range Storages {
StorageKeys = append(StorageKeys, k)
}
slice.Sort(StorageKeys)
logger.L.Debug("Storage initialized")
type StorageNotSupportStream interface {
Storage
NotSupportStream() string
}
func Save(storageType types.StorageType, ctx context.Context, filePath, storagePath string) error {
logger.L.Debugf("Saving file %s to storage: [%s] %s", filePath, storageType, storagePath)
if ctx == nil {
ctx = context.Background()
var Storages = make(map[string]Storage)
var UserStorages = make(map[int64][]Storage)
// GetStorageByName returns storage by name from cache or creates new one
func GetStorageByName(name string) (Storage, error) {
if name == "" {
return nil, ErrStorageNameEmpty
}
if storageType != types.StorageAll {
return Storages[storageType].Save(ctx, filePath, storagePath)
storage, ok := Storages[name]
if ok {
return storage, nil
}
errs := make([]error, 0)
var wg sync.WaitGroup
for _, storage := range Storages {
wg.Add(1)
go func(storage Storage) {
defer wg.Done()
storageDestPath := storagePath
switch storage.(type) {
case *local.Local:
storageDestPath = filepath.Join(config.Cfg.Storage.Local.BasePath, storagePath)
case *webdav.Webdav:
storageDestPath = path.Join(config.Cfg.Storage.Webdav.BasePath, storagePath)
case *alist.Alist:
storageDestPath = path.Join(config.Cfg.Storage.Alist.BasePath, storagePath)
}
if err := storage.Save(ctx, filePath, storageDestPath); err != nil {
errs = append(errs, err)
}
}(storage)
cfg := config.Cfg.GetStorageByName(name)
if cfg == nil {
return nil, fmt.Errorf("未找到存储 %s", name)
}
wg.Wait()
if len(errs) > 0 {
return errors.Join(errs...)
storage, err := NewStorage(cfg)
if err != nil {
return nil, err
}
Storages[name] = storage
return storage, nil
}
// 检查 user 是否可用指定的 storage, 若不可用则返回未找到错误
func GetStorageByUserIDAndName(chatID int64, name string) (Storage, error) {
if name == "" {
return nil, ErrStorageNameEmpty
}
if !config.Cfg.HasStorage(chatID, name) {
return nil, fmt.Errorf("没有找到用户 %d 的存储 %s", chatID, name)
}
return GetStorageByName(name)
}
func GetUserStorages(chatID int64) []Storage {
if chatID <= 0 {
return nil
}
if storages, ok := UserStorages[chatID]; ok {
return storages
}
var storages []Storage
for _, name := range config.Cfg.GetStorageNamesByUserID(chatID) {
storage, err := GetStorageByName(name)
if err != nil {
continue
}
storages = append(storages, storage)
}
return storages
}
type StorageConstructor func() Storage
var storageConstructors = map[string]StorageConstructor{
string(types.StorageTypeAlist): func() Storage { return new(alist.Alist) },
string(types.StorageTypeLocal): func() Storage { return new(local.Local) },
string(types.StorageTypeWebdav): func() Storage { return new(webdav.Webdav) },
string(types.StorageTypeMinio): func() Storage { return new(minio.Minio) },
}
func NewStorage(cfg sc.StorageConfig) (Storage, error) {
constructor, ok := storageConstructors[string(cfg.GetType())]
if !ok {
return nil, fmt.Errorf("不支持的存储类型: %s", cfg.GetType())
}
storage := constructor()
if err := storage.Init(cfg); err != nil {
return nil, fmt.Errorf("初始化 %s 存储失败: %w", cfg.GetName(), err)
}
return storage, nil
}
func LoadStorages() {
common.Log.Info("加载存储...")
for _, storage := range config.Cfg.Storages {
_, err := GetStorageByName(storage.GetName())
if err != nil {
common.Log.Errorf("加载存储 %s 失败: %v", storage.GetName(), err)
}
}
common.Log.Infof("成功加载 %d 个存储", len(Storages))
for user := range config.Cfg.GetUsersID() {
UserStorages[int64(user)] = GetUserStorages(int64(user))
}
return nil
}

View File

@@ -0,0 +1,130 @@
package webdav
import (
"context"
"net/http/httptest"
"os"
"path"
"path/filepath"
"strings"
"testing"
"golang.org/x/net/webdav"
)
func setupWebDAVServer(t *testing.T) (*httptest.Server, string) {
t.Helper()
tempDir, err := os.MkdirTemp("", "webdav_test")
if err != nil {
t.Fatalf("mk temp dir failed: %v", err)
}
handler := &webdav.Handler{
Prefix: "/",
FileSystem: webdav.Dir(tempDir),
LockSystem: webdav.NewMemLS(),
}
server := httptest.NewServer(handler)
return server, tempDir
}
func TestMkDirAndExists(t *testing.T) {
server, tempDir := setupWebDAVServer(t)
defer os.RemoveAll(tempDir)
defer server.Close()
client := NewClient(server.URL, "", "", nil)
ctx := context.Background()
testpaths := []string{"testdir", "testdir/subdir", "testdir/子目录", "/testdir/测试路径/测试路径2"}
for _, p := range testpaths {
exists, err := client.Exists(ctx, p)
if err != nil {
t.Fatalf("Call Exists Err: %v", err)
}
if exists {
t.Fatalf("Dir should not exist")
}
if err := client.MkDir(ctx, p); err != nil {
t.Fatalf("Call MkDir Err: %v", err)
}
exists, err = client.Exists(ctx, p)
if err != nil {
t.Fatalf("Call Exists Err: %v", err)
}
if !exists {
t.Fatalf("Dir should exist")
}
}
}
func TestWriteFile(t *testing.T) {
server, tempDir := setupWebDAVServer(t)
defer os.RemoveAll(tempDir)
defer server.Close()
client := NewClient(server.URL, "", "", nil)
ctx := context.Background()
testCases := []struct {
remotePath string
content string
}{
{
remotePath: "hello.txt",
content: "Hello webdav",
},
{
remotePath: "nested/dir/test.txt",
content: "Nested file",
},
{
remotePath: "empty.txt",
content: "",
},
{
remotePath: "unicode.txt",
content: "测试",
},
}
for _, tc := range testCases {
t.Run(tc.remotePath, func(t *testing.T) {
dir := path.Dir(tc.remotePath)
if dir != "." {
if err := client.MkDir(ctx, dir); err != nil {
t.Fatalf("创建目录 %s 失败: %v", dir, err)
}
}
if err := client.WriteFile(ctx, tc.remotePath, strings.NewReader(tc.content)); err != nil {
t.Fatalf("写入文件 %s 失败: %v", tc.remotePath, err)
}
localPath := filepath.Join(tempDir, tc.remotePath)
data, err := os.ReadFile(localPath)
if err != nil {
t.Fatalf("读取文件 %s 失败: %v", localPath, err)
}
if string(data) != tc.content {
t.Fatalf("文件内容不匹配: got %s, want %s", string(data), tc.content)
}
appended := tc.content + " Overwritten."
if err := client.WriteFile(ctx, tc.remotePath, strings.NewReader(appended)); err != nil {
t.Fatalf("覆盖写入文件 %s 失败: %v", tc.remotePath, err)
}
data, err = os.ReadFile(localPath)
if err != nil {
t.Fatalf("读取覆盖后的文件 %s 失败: %v", localPath, err)
}
if string(data) != appended {
t.Fatalf("文件覆盖后的内容不匹配: got %s, want %s", string(data), appended)
}
})
}
}

114
storage/webdav/client.go Normal file
View File

@@ -0,0 +1,114 @@
package webdav
import (
"context"
"fmt"
"io"
"net/http"
"strings"
"github.com/krau/SaveAny-Bot/types"
)
type Client struct {
BaseURL string
Username string
Password string
httpClient *http.Client
}
func NewClient(baseURL, username, password string, httpClient *http.Client) *Client {
if !strings.HasSuffix(baseURL, "/") {
baseURL += "/"
}
if httpClient == nil {
httpClient = http.DefaultClient
}
return &Client{
BaseURL: baseURL,
Username: username,
Password: password,
httpClient: httpClient,
}
}
func (c *Client) doRequest(ctx context.Context, method, url string, body io.Reader) (*http.Response, error) {
req, err := http.NewRequestWithContext(ctx, method, url, body)
if err != nil {
return nil, err
}
if c.Username != "" && c.Password != "" {
req.SetBasicAuth(c.Username, c.Password)
}
if length := ctx.Value(types.ContextKeyContentLength); length != nil {
if l, ok := length.(int64); ok {
req.ContentLength = l
}
}
return c.httpClient.Do(req)
}
func (c *Client) Exists(ctx context.Context, remotePath string) (bool, error) {
url := c.BaseURL + remotePath
resp, err := c.doRequest(ctx, "PROPFIND", url, nil)
if err != nil {
return false, err
}
defer resp.Body.Close()
if resp.StatusCode >= 200 && resp.StatusCode < 300 {
return true, nil
}
if resp.StatusCode == http.StatusNotFound {
return false, nil
}
return false, fmt.Errorf("PROPFIND: %s", resp.Status)
}
func (c *Client) MkDir(ctx context.Context, dirPath string) error {
dirPath = strings.Trim(dirPath, "/")
if dirPath == "" {
return nil
}
parts := strings.Split(dirPath, "/")
currentPath := ""
for i, part := range parts {
if i > 0 {
currentPath += "/"
}
currentPath += part
exists, err := c.Exists(ctx, currentPath)
if err != nil {
return err
}
if exists {
continue
}
url := c.BaseURL + currentPath
resp, err := c.doRequest(ctx, "MKCOL", url, nil)
if err != nil {
return err
}
resp.Body.Close()
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
return fmt.Errorf("MKCOL %s: %s", currentPath, resp.Status)
}
}
return nil
}
func (c *Client) WriteFile(ctx context.Context, remotePath string, content io.Reader) error {
url := c.BaseURL + remotePath
resp, err := c.doRequest(ctx, "PUT", url, content)
if err != nil {
return err
}
defer resp.Body.Close()
if resp.StatusCode >= 200 && resp.StatusCode < 300 {
return nil
}
return fmt.Errorf("PUT: %s", resp.Status)
}

View File

@@ -2,45 +2,57 @@ package webdav
import (
"context"
"os"
"fmt"
"io"
"net/http"
"path"
"time"
"github.com/krau/SaveAny-Bot/config"
"github.com/krau/SaveAny-Bot/logger"
"github.com/studio-b12/gowebdav"
"github.com/krau/SaveAny-Bot/common"
config "github.com/krau/SaveAny-Bot/config/storage"
"github.com/krau/SaveAny-Bot/types"
)
type Webdav struct{}
var (
Client *gowebdav.Client
)
func (w *Webdav) Init() {
webdavConfig := config.Cfg.Storage.Webdav
Client = gowebdav.NewClient(webdavConfig.URL, webdavConfig.Username, webdavConfig.Password)
if err := Client.Connect(); err != nil {
logger.L.Fatalf("Failed to connect to webdav server: %v", err)
os.Exit(1)
}
Client.SetTimeout(24 * time.Hour)
type Webdav struct {
config config.WebdavStorageConfig
client *Client
}
func (w *Webdav) Save(ctx context.Context, filePath, storagePath string) error {
if err := Client.MkdirAll(path.Dir(storagePath), os.ModePerm); err != nil {
logger.L.Errorf("Failed to create directory %s: %v", path.Dir(storagePath), err)
return ErrFailedToCreateDirectory
func (w *Webdav) Init(cfg config.StorageConfig) error {
webdavConfig, ok := cfg.(*config.WebdavStorageConfig)
if !ok {
return fmt.Errorf("failed to cast webdav config")
}
file, err := os.Open(filePath)
if err != nil {
logger.L.Errorf("Failed to open file %s: %v", filePath, err)
if err := webdavConfig.Validate(); err != nil {
return err
}
defer file.Close()
w.config = *webdavConfig
w.client = NewClient(w.config.URL, w.config.Username, w.config.Password, &http.Client{
Timeout: time.Hour * 12,
})
return nil
}
if err := Client.WriteStream(storagePath, file, os.ModePerm); err != nil {
logger.L.Errorf("Failed to write file %s: %v", storagePath, err)
func (w *Webdav) Type() types.StorageType {
return types.StorageTypeWebdav
}
func (w *Webdav) Name() string {
return w.config.Name
}
func (w *Webdav) JoinStoragePath(task types.Task) string {
return path.Join(w.config.BasePath, task.StoragePath)
}
func (w *Webdav) Save(ctx context.Context, r io.Reader, storagePath string) error {
common.Log.Infof("Saving file to %s", storagePath)
if err := w.client.MkDir(ctx, path.Dir(storagePath)); err != nil {
common.Log.Errorf("Failed to create directory %s: %v", path.Dir(storagePath), err)
return ErrFailedToCreateDirectory
}
if err := w.client.WriteFile(ctx, storagePath, r); err != nil {
common.Log.Errorf("Failed to write file %s: %v", storagePath, err)
return ErrFailedToWriteFile
}
return nil

View File

@@ -1,22 +0,0 @@
package types
import (
"gorm.io/gorm"
)
type ReceivedFile struct {
gorm.Model
Processing bool
ChatID int64 `gorm:"uniqueIndex:idx_chat_id_message_id;not null"`
MessageID int `gorm:"uniqueIndex:idx_chat_id_message_id;not null"`
ReplyMessageID int
ReplyChatID int64
FileName string
}
type User struct {
gorm.Model
UserID int64 `gorm:"uniqueIndex"`
Silent bool
DefaultStorage string
}

82
types/task.go Normal file
View File

@@ -0,0 +1,82 @@
package types
import (
"context"
"crypto/md5"
"encoding/hex"
"fmt"
"net/url"
"strings"
"time"
"github.com/gotd/td/tg"
)
type Task struct {
Ctx context.Context
Cancel context.CancelFunc
Error error
Status TaskStatus
StorageName string
StoragePath string
StartTime time.Time
File *File
FileMessageID int
FileChatID int64
IsTelegraph bool
TelegraphURL string
// to track the reply message
ReplyMessageID int
ReplyChatID int64
UserID int64
}
func (t Task) Key() string {
if t.IsTelegraph {
return hashStr(t.TelegraphURL)
}
return fmt.Sprintf("%d:%d", t.FileChatID, t.FileMessageID)
}
func (t Task) String() string {
if t.IsTelegraph {
return fmt.Sprintf("[telegraph]:%s", t.TelegraphURL)
}
return fmt.Sprintf("[%d:%d]:%s", t.FileChatID, t.FileMessageID, t.File.FileName)
}
func (t Task) FileName() string {
if t.IsTelegraph {
tgphPath := strings.Split(t.TelegraphURL, "/")[len(strings.Split(t.TelegraphURL, "/"))-1]
tgphPathUnescaped, err := url.PathUnescape(tgphPath)
if err != nil {
return tgphPath
}
return tgphPathUnescaped
}
return t.File.FileName
}
type File struct {
Location tg.InputFileLocationClass
FileSize int64
FileName string
}
func (f File) Hash() string {
locationBytes := []byte(f.Location.String())
fileSizeBytes := []byte(fmt.Sprintf("%d", f.FileSize))
fileNameBytes := []byte(f.FileName)
structBytes := append(locationBytes, fileSizeBytes...)
structBytes = append(structBytes, fileNameBytes...)
hash := md5.New()
hash.Write(structBytes)
hashBytes := hash.Sum(nil)
return hex.EncodeToString(hashBytes)
}

View File

@@ -1,18 +1,8 @@
package types
import (
"context"
"crypto/md5"
"encoding/hex"
"fmt"
"time"
"github.com/gotd/td/tg"
)
type TaskStatus string
var (
const (
Pending TaskStatus = "pending"
Succeeded TaskStatus = "succeeded"
Failed TaskStatus = "failed"
@@ -21,55 +11,23 @@ var (
type StorageType string
var (
StorageAll StorageType = "all"
Local StorageType = "local"
Webdav StorageType = "webdav"
Alist StorageType = "alist"
const (
StorageTypeLocal StorageType = "local"
StorageTypeWebdav StorageType = "webdav"
StorageTypeAlist StorageType = "alist"
StorageTypeMinio StorageType = "minio"
)
var StorageTypes = []StorageType{Local, Alist, Webdav, StorageAll}
type Task struct {
Ctx context.Context
Error error
Status TaskStatus
File *File
Storage StorageType
StoragePath string
StartTime time.Time
FileMessageID int
FileChatID int64
ReplyMessageID int
ReplyChatID int64
var StorageTypes = []StorageType{StorageTypeLocal, StorageTypeAlist, StorageTypeWebdav, StorageTypeMinio}
var StorageTypeDisplay = map[StorageType]string{
StorageTypeLocal: "本地磁盘",
StorageTypeWebdav: "WebDAV",
StorageTypeAlist: "Alist",
StorageTypeMinio: "Minio",
}
func (t Task) String() string {
return fmt.Sprintf("[%d:%d]:%s", t.FileChatID, t.FileMessageID, t.File.FileName)
}
type ContextKey string
func (t Task) FileName() string {
return t.File.FileName
}
type File struct {
Location tg.InputFileLocationClass
FileSize int64
FileName string
}
func (f File) Hash() string {
locationBytes := []byte(f.Location.String())
fileSizeBytes := []byte(fmt.Sprintf("%d", f.FileSize))
fileNameBytes := []byte(f.FileName)
structBytes := append(locationBytes, fileSizeBytes...)
structBytes = append(structBytes, fileNameBytes...)
hash := md5.New()
hash.Write(structBytes)
hashBytes := hash.Sum(nil)
return hex.EncodeToString(hashBytes)
}
const (
ContextKeyContentLength ContextKey = "content-length"
)

12
types/utils.go Normal file
View File

@@ -0,0 +1,12 @@
package types
import (
"crypto/md5"
"encoding/hex"
)
func hashStr(s string) string {
hash := md5.New()
hash.Write([]byte(s))
return hex.EncodeToString(hash.Sum(nil))
}