Compare commits

...

18 Commits

Author SHA1 Message Date
krau
7d899ae088 ci: Is anyone really using Windows ARM? 2025-03-01 14:01:10 +08:00
krau
7e67bdb7e2 fix: update executable compression condition for Windows ARM64 in build-release workflow 2025-03-01 13:55:42 +08:00
krau
0071780ff4 typo: deploy 2025-03-01 13:44:04 +08:00
krau
0a95431468 feat: add name to build release workflow 2025-03-01 13:39:08 +08:00
krau
34525c5b11 feat: add docs 2025-03-01 13:37:09 +08:00
krau
6ac6d79fb6 feat: update docker-compose.yml to use host network mode for accessing host services 2025-03-01 12:31:20 +08:00
krau
f21a82ad43 chore: clean up README.md by removing unnecessary demo video section 2025-03-01 12:29:43 +08:00
Krau
73f6647f8d Merge pull request #33 from krau/dev-stream
impl webdav stream mode & progress callback for stream mode
2025-03-01 12:24:46 +08:00
krau
6fbb4609f9 feat: show progress for stream mode 2025-03-01 12:22:50 +08:00
krau
802c908384 feat: refactor webdav client and implement custom upload stream handling 2025-03-01 12:06:55 +08:00
Krau
5d403056d0 Merge pull request #32 from krau/dev-stream
feat: add stream upload support and related configurations
2025-02-28 12:17:10 +08:00
krau
8e2dd37155 feat: add stream upload support and related configurations 2025-02-28 11:09:24 +08:00
krau
9c7ed833fd ci: add upx support 2025-02-28 09:45:34 +08:00
Krau
f9d601bd8a Merge pull request #30 from krau/dev
feat: cancel task
2025-02-27 22:34:58 +08:00
krau
152f473131 fix: delete done task 2025-02-27 22:25:10 +08:00
krau
7015081a84 feat: add context cancellation handling in saveFileWithRetry function 2025-02-27 22:07:41 +08:00
krau
be6444cf96 feat: implement task cancellation feature and update task handling 2025-02-27 22:02:16 +08:00
krau
98ba7c50e7 refactor: remove unused StoragePath initialization in AddToQueue function 2025-02-27 21:32:14 +08:00
28 changed files with 692 additions and 77 deletions

View File

@@ -1,3 +1,5 @@
name: Build Release
on:
push:
tags:
@@ -36,6 +38,9 @@ jobs:
matrix:
goos: [linux, darwin, windows]
goarch: [amd64, arm64]
exclude:
- goos: windows
goarch: arm64
steps:
- name: Checkout
uses: actions/checkout@v4
@@ -53,6 +58,7 @@ jobs:
goos: ${{ matrix.goos }}
goarch: ${{ matrix.goarch }}
github_token: ${{ secrets.GITHUB_TOKEN }}
executable_compression: upx
extra_files: |
LICENSE
README.md

22
.github/workflows/docs.yml vendored Normal file
View File

@@ -0,0 +1,22 @@
name: Deploy Docs
on:
push:
branches:
- main
paths:
- "docs/**"
workflow_dispatch:
jobs:
deploy:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
with:
python-version: "3.11"
- uses: actions/cache@v4
with:
key: ${{ github.ref }}
path: .cache
- run: pip install mkdocs-material
- run: cd docs && mkdocs gh-deploy --force

View File

@@ -1,6 +1,5 @@
<div align="center">
# <img src="docs/logo.jpg" width="45" align="center"> Save Any Bot
**简体中文** | [English](README_EN.md)
@@ -9,15 +8,6 @@
> _就像 PikPak Bot 一样_
</div
Demo Video:
<div align="center">
[SaveAny-Bot 演示视频 The Demo of SaveAny-Bot.webm](https://github.com/user-attachments/assets/a0de2453-a4d1-4a12-81fb-9d84856dce09)
</div>
## 部署

View File

@@ -153,7 +153,6 @@ func AddToQueue(ctx *ext.Context, update *ext.Update) error {
Status: types.Pending,
File: file,
StorageName: storageName,
StoragePath: path.Join(),
FileChatID: record.ChatID,
ReplyMessageID: record.ReplyMessageID,
FileMessageID: record.MessageID,
@@ -164,7 +163,7 @@ func AddToQueue(ctx *ext.Context, update *ext.Update) error {
task.StoragePath = path.Join(dir.Path, file.FileName)
}
queue.AddTask(task)
queue.AddTask(&task)
entityBuilder := entity.Builder{}
var entities []tg.MessageEntityClass

27
bot/handle_cancel_task.go Normal file
View File

@@ -0,0 +1,27 @@
package bot
import (
"strings"
"github.com/celestix/gotgproto/dispatcher"
"github.com/celestix/gotgproto/ext"
"github.com/gotd/td/tg"
"github.com/krau/SaveAny-Bot/queue"
)
func cancelTask(ctx *ext.Context, update *ext.Update) error {
key := strings.Split(string(update.CallbackQuery.Data), " ")[1]
ok := queue.CancelTask(key)
if ok {
ctx.AnswerCallback(&tg.MessagesSetBotCallbackAnswerRequest{
QueryID: update.CallbackQuery.QueryID,
Message: "任务已取消",
})
return dispatcher.EndGroups
}
ctx.AnswerCallback(&tg.MessagesSetBotCallbackAnswerRequest{
QueryID: update.CallbackQuery.QueryID,
Message: "任务取消失败",
})
return dispatcher.EndGroups
}

View File

@@ -22,5 +22,6 @@ func RegisterHandlers(dispatcher dispatcher.Dispatcher) {
dispatcher.AddHandler(handlers.NewMessage(linkRegexFilter, handleLinkMessage))
dispatcher.AddHandler(handlers.NewCallbackQuery(filters.CallbackQuery.Prefix("add"), AddToQueue))
dispatcher.AddHandler(handlers.NewCallbackQuery(filters.CallbackQuery.Prefix("set_default"), setDefaultStorage))
dispatcher.AddHandler(handlers.NewCallbackQuery(filters.CallbackQuery.Prefix("cancel"), cancelTask))
dispatcher.AddHandler(handlers.NewMessage(filters.Message.Media, handleFileMessage))
}

View File

@@ -264,7 +264,7 @@ func HandleSilentAddTask(ctx *ext.Context, update *ext.Update, user *dao.User, t
})
return dispatcher.EndGroups
}
queue.AddTask(*task)
queue.AddTask(task)
ctx.EditMessage(update.EffectiveChat().GetID(), &tg.MessagesEditMessageRequest{
Message: fmt.Sprintf("已添加到队列: %s\n当前排队任务数: %d", task.FileName(), queue.Len()),
ID: task.ReplyMessageID,

View File

@@ -2,6 +2,7 @@
workers = 4 # 同时下载文件数
retry = 3 # 下载失败重试次数
threads = 4 # 单个任务下载最大线程数
stream = false # 使用stream模式, 详情请查看文档
[telegram]
# Bot Token

View File

@@ -13,6 +13,7 @@ type Config struct {
Retry int `toml:"retry" mapstructure:"retry"`
NoCleanCache bool `toml:"no_clean_cache" mapstructure:"no_clean_cache" json:"no_clean_cache"`
Threads int `toml:"threads" mapstructure:"threads" json:"threads"`
Stream bool `toml:"stream" mapstructure:"stream" json:"stream"`
Users []userConfig `toml:"users" mapstructure:"users" json:"users"`

View File

@@ -22,13 +22,12 @@ func worker(queue *queue.TaskQueue, semaphore chan struct{}) {
switch task.Status {
case types.Pending:
logger.L.Infof("Processing task: %s", task.String())
if err := processPendingTask(&task); err != nil {
logger.L.Errorf("Failed to do task: %s", err)
if err := processPendingTask(task); err != nil {
task.Error = err
if errors.Is(err, context.Canceled) {
logger.L.Debugf("Task canceled: %s", task.String())
task.Status = types.Canceled
} else {
logger.L.Errorf("Failed to do task: %s", err)
task.Status = types.Failed
}
} else {
@@ -37,23 +36,43 @@ func worker(queue *queue.TaskQueue, semaphore chan struct{}) {
queue.AddTask(task)
case types.Succeeded:
logger.L.Infof("Task succeeded: %s", task.String())
task.Ctx.(*ext.Context).EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{
Message: fmt.Sprintf("文件保存成功\n [%s]: %s", task.StorageName, task.StoragePath),
ID: task.ReplyMessageID,
})
extCtx, ok := task.Ctx.(*ext.Context)
if !ok {
logger.L.Errorf("Context is not *ext.Context: %T", task.Ctx)
} else {
extCtx.EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{
Message: fmt.Sprintf("文件保存成功\n [%s]: %s", task.StorageName, task.StoragePath),
ID: task.ReplyMessageID,
})
}
case types.Failed:
logger.L.Errorf("Task failed: %s", task.String())
task.Ctx.(*ext.Context).EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{
Message: "文件保存失败\n" + task.Error.Error(),
ID: task.ReplyMessageID,
})
extCtx, ok := task.Ctx.(*ext.Context)
if !ok {
logger.L.Errorf("Context is not *ext.Context: %T", task.Ctx)
} else {
extCtx.EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{
Message: "文件保存失败\n" + task.Error.Error(),
ID: task.ReplyMessageID,
})
}
case types.Canceled:
logger.L.Infof("Task canceled: %s", task.String())
extCtx, ok := task.Ctx.(*ext.Context)
if !ok {
logger.L.Errorf("Context is not *ext.Context: %T", task.Ctx)
} else {
extCtx.EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{
Message: "任务已取消",
ID: task.ReplyMessageID,
})
}
default:
logger.L.Errorf("Unknown task status: %s", task.Status)
}
<-semaphore
logger.L.Debugf("Task done: %s", task.String())
logger.L.Debugf("Task done: %s; status: %s", task.String(), task.Status)
queue.DoneTask(task)
}
}

View File

@@ -1,6 +1,7 @@
package core
import (
"context"
"fmt"
"path/filepath"
"time"
@@ -48,26 +49,62 @@ func processPendingTask(task *types.Task) error {
return fmt.Errorf("context is not *ext.Context: %T", task.Ctx)
}
cancelCtx, cancel := context.WithCancel(ctx)
task.Cancel = cancel
downloadBuider := Downloader.Download(bot.Client.API(), task.File.Location).WithThreads(getTaskThreads(task.File.FileSize))
taskStreamStorage, isStreamStorage := taskStorage.(storage.StreamStorage)
if config.Cfg.Stream {
if !isStreamStorage {
logger.L.Warnf("存储 %s 不支持流式上传", taskStorage.Name())
} else {
text, entities := buildProgressMessageEntity(task, 0, task.StartTime, 0)
ctx.EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{
Message: text,
Entities: entities,
ID: task.ReplyMessageID,
ReplyMarkup: getCancelTaskMarkup(task),
})
uploadStream, err := taskStreamStorage.NewUploadStream(cancelCtx, task.StoragePath)
if err != nil {
return fmt.Errorf("创建上传流失败: %w", err)
}
defer uploadStream.Close()
task.StartTime = time.Now()
progressCallback := buildProgressCallback(ctx, task, getProgressUpdateCount(task.File.FileSize))
progressStream := NewProgressStream(uploadStream, task.File.FileSize, progressCallback)
_, err = downloadBuider.Stream(cancelCtx, progressStream)
if err != nil {
return fmt.Errorf("下载文件失败: %w", err)
}
logger.L.Infof("Uploaded file: %s", task.StoragePath)
return nil
}
}
text, entities := buildProgressMessageEntity(task, 0, task.StartTime, 0)
ctx.EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{
Message: text,
Entities: entities,
ID: task.ReplyMessageID,
Message: text,
Entities: entities,
ID: task.ReplyMessageID,
ReplyMarkup: getCancelTaskMarkup(task),
})
progressCallback := buildProgressCallback(ctx, task, getProgressUpdateCount(task.File.FileSize))
progressCallback := buildProgressCallback(ctx, task, getProgressUpdateCount(task.File.FileSize))
dest, err := NewTaskLocalFile(cacheDestPath, task.File.FileSize, progressCallback)
if err != nil {
return fmt.Errorf("创建文件失败: %w", err)
}
defer dest.Close()
task.StartTime = time.Now()
downloadBuider := Downloader.Download(bot.Client.API(), task.File.Location).WithThreads(getTaskThreads(task.File.FileSize))
_, err = downloadBuider.Parallel(ctx, dest)
_, err = downloadBuider.Parallel(cancelCtx, dest)
if err != nil {
return fmt.Errorf("下载文件失败: %w", err)
}
defer cleanCacheFile(cacheDestPath)
fixTaskFileExt(task, cacheDestPath)
@@ -78,5 +115,5 @@ func processPendingTask(task *types.Task) error {
ID: task.ReplyMessageID,
})
return saveFileWithRetry(task, taskStorage, cacheDestPath)
return saveFileWithRetry(cancelCtx, task, taskStorage, cacheDestPath)
}

View File

@@ -1,7 +1,9 @@
package core
import (
"context"
"fmt"
"io"
"os"
"path"
"time"
@@ -19,13 +21,21 @@ import (
"github.com/krau/SaveAny-Bot/types"
)
func saveFileWithRetry(task *types.Task, taskStorage storage.Storage, localFilePath string) error {
func saveFileWithRetry(ctx context.Context, task *types.Task, taskStorage storage.Storage, localFilePath string) error {
for i := 0; i <= config.Cfg.Retry; i++ {
if err := taskStorage.Save(task.Ctx, localFilePath, task.StoragePath); err != nil {
if err := ctx.Err(); err != nil {
return fmt.Errorf("context canceled while saving file: %w", err)
}
if err := taskStorage.Save(ctx, localFilePath, task.StoragePath); err != nil {
if i == config.Cfg.Retry {
return fmt.Errorf("failed to save file: %w", err)
}
logger.L.Errorf("Failed to save file: %s, retrying...", err)
select {
case <-ctx.Done():
return fmt.Errorf("context canceled during retry delay: %w", ctx.Err())
case <-time.After(time.Duration(i*500) * time.Millisecond):
}
continue
}
return nil
@@ -56,22 +66,9 @@ func processPhoto(task *types.Task, taskStorage storage.Storage, cachePath strin
logger.L.Infof("Downloaded file: %s", cachePath)
return saveFileWithRetry(task, taskStorage, cachePath)
return saveFileWithRetry(task.Ctx, task, taskStorage, cachePath)
}
// func getProgressBar(progress float64, updateCount int) string {
// bar := ""
// barSize := 100 / updateCount
// for i := 0; i < updateCount; i++ {
// if progress >= float64(barSize*(i+1)) {
// bar += "█"
// } else {
// bar += "░"
// }
// }
// return bar
// }
func cleanCacheFile(destPath string) {
if config.Cfg.Temp.CacheTTL > 0 {
common.RmFileAfter(destPath, time.Duration(config.Cfg.Temp.CacheTTL)*time.Second)
@@ -139,13 +136,20 @@ func buildProgressCallback(ctx *ext.Context, task *types.Task, updateCount int)
}
text, entities := buildProgressMessageEntity(task, bytesRead, task.StartTime, progress)
ctx.EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{
Message: text,
Entities: entities,
ID: task.ReplyMessageID,
Message: text,
Entities: entities,
ID: task.ReplyMessageID,
ReplyMarkup: getCancelTaskMarkup(task),
})
}
}
func getCancelTaskMarkup(task *types.Task) *tg.ReplyInlineMarkup {
return &tg.ReplyInlineMarkup{
Rows: []tg.KeyboardButtonRow{{Buttons: []tg.KeyboardButtonClass{&tg.KeyboardButtonCallback{Text: "取消任务", Data: fmt.Appendf(nil, "cancel %s", task.Key())}}}},
}
}
func fixTaskFileExt(task *types.Task, localFilePath string) {
if path.Ext(task.FileName()) == "" {
mimeType, err := mimetype.DetectFile(localFilePath)
@@ -217,3 +221,40 @@ func NewTaskLocalFile(filePath string, fileSize int64, progressCallback func(byt
callbackInterval: callbackInterval,
}, nil
}
type ProgressStream struct {
writer io.Writer
size int64
done int64
callback func(bytesRead, contentLength int64)
nextAt int64
interval int64
}
func (ps *ProgressStream) Write(p []byte) (n int, err error) {
n, err = ps.writer.Write(p)
if err != nil {
return n, err
}
ps.done += int64(n)
if ps.callback != nil && ps.done >= ps.nextAt {
ps.callback(ps.done, ps.size)
ps.nextAt += ps.interval
}
return n, nil
}
func NewProgressStream(writer io.Writer, size int64, callback func(bytesRead, contentLength int64)) *ProgressStream {
var interval int64
interval = size / 100
if interval == 0 {
interval = 1
}
return &ProgressStream{
writer: writer,
size: size,
callback: callback,
nextAt: interval,
interval: interval,
}
}

View File

@@ -7,4 +7,7 @@ services:
- ./data:/app/data
- ./config.toml:/app/config.toml
- ./downloads:/app/downloads
- ./cache:/app/cache
- ./cache:/app/cache
# 使用 host 模式以便访问宿主机服务 (如代理)
# 如果你对 Docker 网络模式熟悉, 可以自行修改
network_mode: host

94
docs/docs/deploy.md Normal file
View File

@@ -0,0 +1,94 @@
# 部署指南
## 从二进制文件部署
在 [Release](https://github.com/krau/SaveAny-Bot/releases) 页面下载对应平台的二进制文件.
在解压后目录新建 `config.toml` 文件, 参考 [config.example.toml](./config.example.toml) 编辑配置文件.
运行:
```bash
chmod +x saveany-bot
./saveany-bot
```
### 添加为 systemd 服务
创建文件 `/etc/systemd/system/saveany-bot.service` 并写入以下内容:
```
[Unit]
Description=SaveAnyBot
After=systemd-user-sessions.service
[Service]
Type=simple
WorkingDirectory=/yourpath/
ExecStart=/yourpath/saveany-bot
Restart=on-failure
[Install]
WantedBy=multi-user.target
```
设为开机启动并启动服务:
```bash
systemctl enable --now saveany-bot
```
### 为OpenWrt及衍生系统添加开机自启动服务
创建文件 ` /etc/init.d/saveanybot` ,参考[saveanybot](./docs/saveanybot)自行修改.
`chmod +x /etc/init.d/saveanybot`
完成后,将文件复制到 `/etc/rc.d`并重命名为`S99saveanybot`.
`chmod +x /etc/rc.d/S99saveanybot`
### 为OpenWrt及衍生系统添加快捷指令
创建文件` /usr/bin/sabot` ,参考[sabot](./docs/sabot)自行配置修改,注意此处文件编码仅支持 ANSI 936 .
`chmod +x /usr/bin/sabot`
之后,终端输入`sabot start|stop|restart|status|enable|disable`即可.
## 使用 Docker 部署
### Docker Compose
下载 [docker-compose.yml](./docker-compose.yml) 文件, 在同目录下新建 `config.toml` 文件, 参考 [config.example.toml](./config.example.toml) 编辑配置文件.
启动:
```bash
docker compose up -d
```
### Docker
```shell
docker run -d --name saveany-bot \
-v /path/to/config.toml:/app/config.toml \
-v /path/to/downloads:/app/downloads \
ghcr.io/krau/saveany-bot:latest
```
## 更新
使用 `upgrade``up` 升级到最新版
```bash
./saveany-bot upgrade
```
如果是 Docker 部署, 使用以下命令更新:
```bash
docker pull ghcr.io/krau/saveany-bot:latest
docker restart saveany-bot
```

20
docs/docs/faq.md Normal file
View File

@@ -0,0 +1,20 @@
# 常见问题
## 上传 alist 失败也会显示成功
这是 alist 的上传实现导致的问题, 上传到 alist 的文件实际上会被 alist 暂存在本地, 在客户端上传结束后 alist 就返回成功, 然后 alist 会在后台将文件上传到对应的存储.
目前 bot 是根据 alist 的返回判断是否成功, 无法获知 alist 的后台上传任务是否成功.
在 alist 管理页面适当调整上传分片大小, 为 alist 使用更稳定的网络环境部署, 都可以减少这种情况的发生.
## Bot 提示下载成功但是 alist 未显示
检查 alist 后台 > 任务 > 上传 中对应的上传任务的状态, 如果任务状态为成功但目录中不显示, 是由于 alist 缓存了目录结构, 参考文档可以调整缓存时间
https://alist.nn.ci/zh/guide/drivers/common.html#缓存过期
## docker部署配置了代理后仍无法连接 telegram (初始化客户端超时)
docker 不能直接访问宿主机网络, 如果你不熟悉其用法, 请将容器设为 host 模式:

35
docs/docs/help.md Normal file
View File

@@ -0,0 +1,35 @@
# 使用帮助
## 保存文件
Bot 接受两种消息: 文件和链接.
目前, 链接仅支持公开频道 (具有用户名) 的链接, 例如: `https://t.me/acherkrau/1097`.
**即使频道禁止了转发和保存, Bot 依然可以下载其文件.**
## 静默模式 (silent)
使用 `/silent` 命令可以开关静默模式.
默认情况下不开启静默模式, Bot 会询问你每个文件的保存位置.
开启静默模式后, Bot 会直接保存文件到默认位置, 无需确认.
在开启静默模式之前, 需要使用 `/storage` 命令设置默认保存位置.
## Stream 模式
在配置文件中将 `stream` 设置为 `true` 可以开启 Stream 模式.
未开启时, Bot 处理任务分为两步: 下载和上传. Bot 会将文件暂存到本地, 然后上传到对应存储位置, 最后删除本地文件.
开启后, Bot 将直接将文件流式传输到存储端, 不需要下载到本地.
该功能对于硬盘空间有限的部署环境十分有用, 然而相较于普通模式也具有一些弊端:
- 无法使用多线程从 telegram 下载文件, 速度较慢.
- 网络不稳定时, 任务失败率高.
- 无法在中间层对文件进行处理, 例如自动文件类型识别.
虽然目前 Bot 适配的所有存储端 (Alist, 本地磁盘, Webdav) 都支持 Stream 模式, 但今后可能会有不支持的存储端, 此时即使开启 Stream 模式, Bot 也会自动切换到普通模式.

7
docs/docs/index.md Normal file
View File

@@ -0,0 +1,7 @@
# SaveAnyBot 文档
SaveAnyBot 是一个可以保存 Telegram 上的文件到云存储的机器人, 就像 PikPak Bot 一样.
不同的是, SaveAnyBot 提供更灵活的存储端选择, 并实现一些更强大的功能.
本项目以 AGPL-3.0 协议开源, 请遵守协议使用.

33
docs/mkdocs.yml Normal file
View File

@@ -0,0 +1,33 @@
site_name: SaveAnyBot 官方文档
site_author: Krau
site_description: SaveAnyBot 是一个可以保存 Telegram 上的文件到多种云存储的机器人, 本文档将帮助你了解如何部署和使用它.
repo_name: krau/saveany-bot
repo_url: https://github.com/krau/saveany-bot
copyright: CC BY-NC-SA 4.0
theme:
name: material
language: zh
highlightjs: true
palette:
- media: "(prefers-color-scheme)"
toggle:
icon: material/brightness-auto
name: 切换主题
- media: "(prefers-color-scheme: light)"
scheme: default
primary: indigo
toggle:
icon: material/brightness-7
name: 暗色模式
- media: "(prefers-color-scheme: dark)"
scheme: slate
primary: blue grey
toggle:
icon: material/brightness-4
name: 亮色模式
nav:
- index.md
- deploy.md
- help.md
- faq.md

1
go.mod
View File

@@ -12,7 +12,6 @@ require (
github.com/rhysd/go-github-selfupdate v1.2.3
github.com/spf13/cobra v1.8.1
github.com/spf13/viper v1.19.0
github.com/studio-b12/gowebdav v0.10.0
golang.org/x/net v0.35.0
golang.org/x/time v0.10.0
)

2
go.sum
View File

@@ -172,8 +172,6 @@ github.com/spf13/viper v1.19.0 h1:RWq5SEjt8o25SROyN3z2OrDB9l7RPd3lwTWU8EcEdcI=
github.com/spf13/viper v1.19.0/go.mod h1:GQUN9bilAbhU/jgc1bKs99f/suXKeUMct8Adx5+Ntkg=
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/studio-b12/gowebdav v0.10.0 h1:Yewz8FFiadcGEu4hxS/AAJQlHelndqln1bns3hcJIYc=
github.com/studio-b12/gowebdav v0.10.0/go.mod h1:bHA7t77X/QFExdeAnDzK6vKM34kEZAcE1OX4MfiwjkE=
github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8=
github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU=
github.com/tcnksm/go-gitconfig v0.1.2 h1:iiDhRitByXAEyjgBqsKi9QU4o2TNtv9kPP3RgPgXBPw=

View File

@@ -8,30 +8,65 @@ import (
)
type TaskQueue struct {
list *list.List
cond *sync.Cond
mutex *sync.Mutex
list *list.List
cond *sync.Cond
mutex *sync.Mutex
activeMap map[string]*types.Task
}
func (q *TaskQueue) AddTask(task types.Task) {
func (q *TaskQueue) AddTask(task *types.Task) {
q.mutex.Lock()
defer q.mutex.Unlock()
q.list.PushBack(task)
q.cond.Signal()
if task.Status != types.Pending {
delete(q.activeMap, task.Key())
}
}
func (q *TaskQueue) GetTask() types.Task {
func (q *TaskQueue) GetTask() *types.Task {
q.mutex.Lock()
defer q.mutex.Unlock()
for q.list.Len() == 0 {
q.cond.Wait()
}
e := q.list.Front()
task := e.Value.(types.Task)
task := e.Value.(*types.Task)
q.list.Remove(e)
if task.Status == types.Pending {
q.activeMap[task.Key()] = task
}
return task
}
func (q *TaskQueue) DoneTask(task *types.Task) {
q.mutex.Lock()
defer q.mutex.Unlock()
delete(q.activeMap, task.Key())
}
func (q *TaskQueue) CancelTask(key string) bool {
q.mutex.Lock()
defer q.mutex.Unlock()
if task, ok := q.activeMap[key]; ok {
if task.Cancel != nil {
task.Cancel()
return true
}
}
for e := q.list.Front(); e != nil; e = e.Next() {
task := e.Value.(*types.Task)
if task.Key() == key {
if task.Cancel != nil {
task.Cancel()
}
q.list.Remove(e)
return true
}
}
return false
}
func (q *TaskQueue) Len() int {
q.mutex.Lock()
defer q.mutex.Unlock()
@@ -47,20 +82,29 @@ func init() {
func NewQueue() *TaskQueue {
m := &sync.Mutex{}
return &TaskQueue{
list: list.New(),
cond: sync.NewCond(m),
mutex: m,
list: list.New(),
cond: sync.NewCond(m),
mutex: m,
activeMap: make(map[string]*types.Task),
}
}
func AddTask(task types.Task) {
func AddTask(task *types.Task) {
Queue.AddTask(task)
}
func GetTask() types.Task {
func GetTask() *types.Task {
return Queue.GetTask()
}
func Len() int {
return Queue.Len()
}
func CancelTask(key string) bool {
return Queue.CancelTask(key)
}
func DoneTask(task *types.Task) {
Queue.DoneTask(task)
}

View File

@@ -9,6 +9,7 @@ import (
"net/url"
"os"
"path"
"sync"
"time"
"github.com/krau/SaveAny-Bot/config"
@@ -150,3 +151,88 @@ func (a *Alist) Save(ctx context.Context, filePath, storagePath string) error {
func (a *Alist) JoinStoragePath(task types.Task) string {
return path.Join(a.config.BasePath, task.StoragePath)
}
type uploadStream struct {
ctx context.Context
client *http.Client
token string
storagePath string
baseURL string
pr *io.PipeReader
pw *io.PipeWriter
errChan chan error
once sync.Once
}
func (us *uploadStream) Write(p []byte) (int, error) {
return us.pw.Write(p)
}
func (us *uploadStream) Close() error {
var uploadErr error
us.once.Do(func() {
if err := us.pw.Close(); err != nil {
uploadErr = fmt.Errorf("failed to close pipe writer: %w", err)
return
}
if err := <-us.errChan; err != nil {
uploadErr = err
}
})
return uploadErr
}
func (a *Alist) NewUploadStream(ctx context.Context, storagePath string) (io.WriteCloser, error) {
if a.token == "" {
if err := a.getToken(); err != nil {
return nil, fmt.Errorf("not logged in to Alist: %w", err)
}
}
pr, pw := io.Pipe()
// 创建上传流对象
us := &uploadStream{
ctx: ctx,
client: a.client,
token: a.token,
storagePath: storagePath,
baseURL: a.baseURL,
pr: pr,
pw: pw,
errChan: make(chan error, 1),
}
go func() {
defer close(us.errChan)
req, err := http.NewRequestWithContext(ctx, http.MethodPut, a.baseURL+"/api/fs/put", pr)
if err != nil {
us.errChan <- fmt.Errorf("failed to create request: %w", err)
return
}
req.Header.Set("Authorization", a.token)
req.Header.Set("File-Path", url.PathEscape(storagePath))
req.Header.Set("As-Task", "true")
req.Header.Set("Content-Type", "application/octet-stream")
resp, err := a.client.Do(req)
if err != nil {
us.errChan <- fmt.Errorf("failed to send request: %w", err)
return
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
us.errChan <- fmt.Errorf("failed to upload file, status code: %d, response: %s", resp.StatusCode, string(body))
return
}
us.errChan <- nil
}()
return us, nil
}

View File

@@ -3,6 +3,7 @@ package local
import (
"context"
"fmt"
"io"
"os"
"path/filepath"
@@ -55,3 +56,18 @@ func (l *Local) Save(ctx context.Context, filePath, storagePath string) error {
func (l *Local) JoinStoragePath(task types.Task) string {
return filepath.Join(l.config.BasePath, task.StoragePath)
}
func (l *Local) NewUploadStream(ctx context.Context, path string) (io.WriteCloser, error) {
absPath, err := filepath.Abs(path)
if err != nil {
return nil, err
}
if err := fileutil.CreateDir(filepath.Dir(absPath)); err != nil {
return nil, err
}
file, err := os.Create(absPath)
if err != nil {
return nil, err
}
return file, nil
}

View File

@@ -3,6 +3,7 @@ package storage
import (
"context"
"fmt"
"io"
"github.com/krau/SaveAny-Bot/config"
"github.com/krau/SaveAny-Bot/logger"
@@ -20,6 +21,11 @@ type Storage interface {
Save(cttx context.Context, localFilePath, storagePath string) error
}
type StreamStorage interface {
Storage
NewUploadStream(ctx context.Context, path string) (io.WriteCloser, error)
}
var Storages = make(map[string]Storage)
var UserStorages = make(map[int64][]Storage)

70
storage/webdav/client.go Normal file
View File

@@ -0,0 +1,70 @@
package webdav
import (
"context"
"fmt"
"io"
"net/http"
"strings"
)
type Client struct {
BaseURL string
Username string
Password string
httpClient *http.Client
}
func NewClient(baseURL, username, password string, httpClient *http.Client) *Client {
if !strings.HasSuffix(baseURL, "/") {
baseURL += "/"
}
if httpClient == nil {
httpClient = http.DefaultClient
}
return &Client{
BaseURL: baseURL,
Username: username,
Password: password,
httpClient: httpClient,
}
}
func (c *Client) doRequest(ctx context.Context, method, url string, body io.Reader) (*http.Response, error) {
req, err := http.NewRequestWithContext(ctx, method, url, body)
if err != nil {
return nil, err
}
if c.Username != "" && c.Password != "" {
req.SetBasicAuth(c.Username, c.Password)
}
return c.httpClient.Do(req)
}
func (c *Client) MkDir(ctx context.Context, dirPath string) error {
url := c.BaseURL + dirPath
resp, err := c.doRequest(ctx, "MKCOL", url, nil)
if err != nil {
return err
}
defer resp.Body.Close()
if resp.StatusCode >= 200 && resp.StatusCode < 300 {
return nil
}
return fmt.Errorf("MKCOL: %s", resp.Status)
}
func (c *Client) WriteFile(ctx context.Context, remotePath string, content io.Reader) error {
url := c.BaseURL + remotePath
resp, err := c.doRequest(ctx, "PUT", url, content)
if err != nil {
return err
}
defer resp.Body.Close()
if resp.StatusCode >= 200 && resp.StatusCode < 300 {
return nil
}
return fmt.Errorf("PUT: %s", resp.Status)
}

58
storage/webdav/stream.go Normal file
View File

@@ -0,0 +1,58 @@
package webdav
import (
"context"
"fmt"
"io"
"path"
"github.com/krau/SaveAny-Bot/logger"
)
type WebdavWriter struct {
pipeWriter *io.PipeWriter
done chan error
path string
}
func (w *WebdavWriter) Write(p []byte) (n int, err error) {
return w.pipeWriter.Write(p)
}
func (w *WebdavWriter) Close() error {
if err := w.pipeWriter.Close(); err != nil {
return err
}
if err := <-w.done; err != nil {
return fmt.Errorf("upload failed: %w", err)
}
return nil
}
func (w *Webdav) NewUploadStream(ctx context.Context, storagePath string) (io.WriteCloser, error) {
if err := w.client.MkDir(ctx, path.Dir(storagePath)); err != nil {
logger.L.Errorf("Failed to create directory %s: %v", path.Dir(storagePath), err)
return nil, ErrFailedToCreateDirectory
}
pipeReader, pipeWriter := io.Pipe()
done := make(chan error, 1)
go func() {
defer func() {
if err := recover(); err != nil {
done <- fmt.Errorf("panic during upload: %v", err)
}
}()
err := w.client.WriteFile(ctx, storagePath, pipeReader)
pipeReader.Close()
done <- err
}()
return &WebdavWriter{
pipeWriter: pipeWriter,
done: done,
path: storagePath,
}, nil
}

View File

@@ -3,6 +3,7 @@ package webdav
import (
"context"
"fmt"
"net/http"
"os"
"path"
"time"
@@ -10,12 +11,11 @@ import (
"github.com/krau/SaveAny-Bot/config"
"github.com/krau/SaveAny-Bot/logger"
"github.com/krau/SaveAny-Bot/types"
"github.com/studio-b12/gowebdav"
)
type Webdav struct {
config config.WebdavStorageConfig
client *gowebdav.Client
client *Client
}
func (w *Webdav) Init(cfg config.StorageConfig) error {
@@ -27,12 +27,9 @@ func (w *Webdav) Init(cfg config.StorageConfig) error {
return err
}
w.config = *webdavConfig
client := gowebdav.NewClient(webdavConfig.URL, webdavConfig.Username, webdavConfig.Password)
if err := client.Connect(); err != nil {
return fmt.Errorf("failed to connect to webdav server: %w", err)
}
client.SetTimeout(12 * time.Hour)
w.client = client
w.client = NewClient(w.config.URL, w.config.Username, w.config.Password, &http.Client{
Timeout: time.Hour * 12,
})
return nil
}
@@ -46,7 +43,7 @@ func (w *Webdav) Name() string {
func (w *Webdav) Save(ctx context.Context, filePath, storagePath string) error {
logger.L.Infof("Saving file %s to %s", filePath, storagePath)
if err := w.client.MkdirAll(path.Dir(storagePath), os.ModePerm); err != nil {
if err := w.client.MkDir(ctx, path.Dir(storagePath)); err != nil {
logger.L.Errorf("Failed to create directory %s: %v", path.Dir(storagePath), err)
return ErrFailedToCreateDirectory
}
@@ -57,7 +54,7 @@ func (w *Webdav) Save(ctx context.Context, filePath, storagePath string) error {
}
defer file.Close()
if err := w.client.WriteStream(storagePath, file, os.ModePerm); err != nil {
if err := w.client.WriteFile(ctx, storagePath, file); err != nil {
logger.L.Errorf("Failed to write file %s: %v", storagePath, err)
return ErrFailedToWriteFile
}

View File

@@ -36,6 +36,7 @@ var StorageTypeDisplay = map[StorageType]string{
type Task struct {
Ctx context.Context
Cancel context.CancelFunc
Error error
Status TaskStatus
File *File
@@ -52,6 +53,10 @@ type Task struct {
UserID int64
}
func (t Task) Key() string {
return fmt.Sprintf("%d:%d", t.FileChatID, t.FileMessageID)
}
func (t Task) String() string {
return fmt.Sprintf("[%d:%d]:%s", t.FileChatID, t.FileMessageID, t.File.FileName)
}