mirror of
https://github.com/krau/SaveAny-Bot.git
synced 2026-05-10 17:52:44 +08:00
Compare commits
33 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
0071780ff4 | ||
|
|
0a95431468 | ||
|
|
34525c5b11 | ||
|
|
6ac6d79fb6 | ||
|
|
f21a82ad43 | ||
|
|
73f6647f8d | ||
|
|
6fbb4609f9 | ||
|
|
802c908384 | ||
|
|
5d403056d0 | ||
|
|
8e2dd37155 | ||
|
|
9c7ed833fd | ||
|
|
f9d601bd8a | ||
|
|
152f473131 | ||
|
|
7015081a84 | ||
|
|
be6444cf96 | ||
|
|
98ba7c50e7 | ||
|
|
0c31d908cc | ||
|
|
9e776b22fb | ||
|
|
d6f8603656 | ||
|
|
9c42bee662 | ||
|
|
b96340dd46 | ||
|
|
a5ba01e219 | ||
|
|
d00e907735 | ||
|
|
418f9bd2bc | ||
|
|
28b4585dba | ||
|
|
d2669f0c99 | ||
|
|
c9921926e3 | ||
|
|
d7cd2ede01 | ||
|
|
ed21b65c98 | ||
|
|
8975589c43 | ||
|
|
27dca2e343 | ||
|
|
5c8261c34a | ||
|
|
cbc2dc82d8 |
3
.github/workflows/build-release.yml
vendored
3
.github/workflows/build-release.yml
vendored
@@ -1,3 +1,5 @@
|
|||||||
|
name: Build Release
|
||||||
|
|
||||||
on:
|
on:
|
||||||
push:
|
push:
|
||||||
tags:
|
tags:
|
||||||
@@ -53,6 +55,7 @@ jobs:
|
|||||||
goos: ${{ matrix.goos }}
|
goos: ${{ matrix.goos }}
|
||||||
goarch: ${{ matrix.goarch }}
|
goarch: ${{ matrix.goarch }}
|
||||||
github_token: ${{ secrets.GITHUB_TOKEN }}
|
github_token: ${{ secrets.GITHUB_TOKEN }}
|
||||||
|
executable_compression: upx
|
||||||
extra_files: |
|
extra_files: |
|
||||||
LICENSE
|
LICENSE
|
||||||
README.md
|
README.md
|
||||||
|
|||||||
22
.github/workflows/docs.yml
vendored
Normal file
22
.github/workflows/docs.yml
vendored
Normal file
@@ -0,0 +1,22 @@
|
|||||||
|
name: Deploy Docs
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches:
|
||||||
|
- main
|
||||||
|
paths:
|
||||||
|
- "docs/**"
|
||||||
|
workflow_dispatch:
|
||||||
|
jobs:
|
||||||
|
deploy:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
- uses: actions/setup-python@v5
|
||||||
|
with:
|
||||||
|
python-version: "3.11"
|
||||||
|
- uses: actions/cache@v4
|
||||||
|
with:
|
||||||
|
key: ${{ github.ref }}
|
||||||
|
path: .cache
|
||||||
|
- run: pip install mkdocs-material
|
||||||
|
- run: cd docs && mkdocs gh-deploy --force
|
||||||
31
README.md
31
README.md
@@ -8,14 +8,6 @@
|
|||||||
|
|
||||||
> _就像 PikPak Bot 一样_
|
> _就像 PikPak Bot 一样_
|
||||||
|
|
||||||
</div
|
|
||||||
|
|
||||||
Demo Video:
|
|
||||||
|
|
||||||
<div align="center">
|
|
||||||
|
|
||||||
[SaveAny-Bot 演示视频 | The Demo of SaveAny-Bot.webm](https://github.com/user-attachments/assets/a0de2453-a4d1-4a12-81fb-9d84856dce09)
|
|
||||||
|
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
## 部署
|
## 部署
|
||||||
@@ -24,7 +16,7 @@ Demo Video:
|
|||||||
|
|
||||||
在 [Release](https://github.com/krau/SaveAny-Bot/releases) 页面下载对应平台的二进制文件.
|
在 [Release](https://github.com/krau/SaveAny-Bot/releases) 页面下载对应平台的二进制文件.
|
||||||
|
|
||||||
在解压后目录新建 `config.toml` 文件, 参考 [config.toml.example](https://github.com/krau/SaveAny-Bot/blob/main/config.example.toml) 编辑配置文件.
|
在解压后目录新建 `config.toml` 文件, 参考 [config.example.toml](./config.example.toml) 编辑配置文件.
|
||||||
|
|
||||||
运行:
|
运行:
|
||||||
|
|
||||||
@@ -58,11 +50,30 @@ WantedBy=multi-user.target
|
|||||||
systemctl enable --now saveany-bot
|
systemctl enable --now saveany-bot
|
||||||
```
|
```
|
||||||
|
|
||||||
|
#### 为OpenWrt及衍生系统添加开机自启动服务
|
||||||
|
|
||||||
|
创建文件 ` /etc/init.d/saveanybot` ,参考[saveanybot](./docs/saveanybot)自行修改.
|
||||||
|
|
||||||
|
`chmod +x /etc/init.d/saveanybot`
|
||||||
|
|
||||||
|
完成后,将文件复制到 `/etc/rc.d`并重命名为`S99saveanybot`.
|
||||||
|
|
||||||
|
`chmod +x /etc/rc.d/S99saveanybot`
|
||||||
|
|
||||||
|
#### 为OpenWrt及衍生系统添加快捷指令
|
||||||
|
|
||||||
|
创建文件` /usr/bin/sabot` ,参考[sabot](./docs/sabot)自行配置修改,注意此处文件编码仅支持 ANSI 936 .
|
||||||
|
|
||||||
|
`chmod +x /usr/bin/sabot`
|
||||||
|
|
||||||
|
之后,终端输入`sabot start|stop|restart|status|enable|disable`即可.
|
||||||
|
|
||||||
|
|
||||||
### 使用 Docker 部署
|
### 使用 Docker 部署
|
||||||
|
|
||||||
#### Docker Compose
|
#### Docker Compose
|
||||||
|
|
||||||
下载 [docker-compose.yml](https://github.com/krau/SaveAny-Bot/blob/main/docker-compose.yml) 文件, 在同目录下新建 `config.toml` 文件, 参考 [config.toml.example](https://github.com/krau/SaveAny-Bot/blob/main/config.example.toml) 编辑配置文件.
|
下载 [docker-compose.yml](./docker-compose.yml) 文件, 在同目录下新建 `config.toml` 文件, 参考 [config.example.toml](./config.example.toml) 编辑配置文件.
|
||||||
|
|
||||||
启动:
|
启动:
|
||||||
|
|
||||||
|
|||||||
@@ -1,26 +0,0 @@
|
|||||||
package bootstrap
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"os"
|
|
||||||
|
|
||||||
"github.com/krau/SaveAny-Bot/bot"
|
|
||||||
"github.com/krau/SaveAny-Bot/common"
|
|
||||||
"github.com/krau/SaveAny-Bot/config"
|
|
||||||
"github.com/krau/SaveAny-Bot/dao"
|
|
||||||
"github.com/krau/SaveAny-Bot/logger"
|
|
||||||
"github.com/krau/SaveAny-Bot/storage"
|
|
||||||
)
|
|
||||||
|
|
||||||
func InitAll() {
|
|
||||||
if err := config.Init(); err != nil {
|
|
||||||
fmt.Println("加载配置文件失败: ", err)
|
|
||||||
os.Exit(1)
|
|
||||||
}
|
|
||||||
logger.InitLogger()
|
|
||||||
logger.L.Info("正在启动 SaveAny-Bot...")
|
|
||||||
storage.LoadStorages()
|
|
||||||
common.Init()
|
|
||||||
dao.Init()
|
|
||||||
bot.Init()
|
|
||||||
}
|
|
||||||
@@ -76,6 +76,7 @@ func Init() {
|
|||||||
{Command: "silent", Description: "开启/关闭静默模式"},
|
{Command: "silent", Description: "开启/关闭静默模式"},
|
||||||
{Command: "storage", Description: "设置默认存储端"},
|
{Command: "storage", Description: "设置默认存储端"},
|
||||||
{Command: "save", Description: "保存所回复的文件"},
|
{Command: "save", Description: "保存所回复的文件"},
|
||||||
|
{Command: "dir", Description: "管理存储文件夹"},
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
resultChan <- struct {
|
resultChan <- struct {
|
||||||
|
|||||||
188
bot/handle_add_task.go
Normal file
188
bot/handle_add_task.go
Normal file
@@ -0,0 +1,188 @@
|
|||||||
|
package bot
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"path"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/celestix/gotgproto/dispatcher"
|
||||||
|
"github.com/celestix/gotgproto/ext"
|
||||||
|
"github.com/duke-git/lancet/v2/slice"
|
||||||
|
"github.com/gotd/td/telegram/message/entity"
|
||||||
|
"github.com/gotd/td/telegram/message/styling"
|
||||||
|
"github.com/gotd/td/tg"
|
||||||
|
"github.com/krau/SaveAny-Bot/config"
|
||||||
|
"github.com/krau/SaveAny-Bot/dao"
|
||||||
|
"github.com/krau/SaveAny-Bot/logger"
|
||||||
|
"github.com/krau/SaveAny-Bot/queue"
|
||||||
|
"github.com/krau/SaveAny-Bot/types"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
func AddToQueue(ctx *ext.Context, update *ext.Update) error {
|
||||||
|
// TODO: 回调数据用户独立鉴权 (处理 bot 在群聊中的情况)
|
||||||
|
if !slice.Contain(config.Cfg.GetUsersID(), update.CallbackQuery.UserID) {
|
||||||
|
ctx.AnswerCallback(&tg.MessagesSetBotCallbackAnswerRequest{
|
||||||
|
QueryID: update.CallbackQuery.QueryID,
|
||||||
|
Alert: true,
|
||||||
|
Message: "你没有权限",
|
||||||
|
CacheTime: 5,
|
||||||
|
})
|
||||||
|
return dispatcher.EndGroups
|
||||||
|
}
|
||||||
|
args := strings.Split(string(update.CallbackQuery.Data), " ")
|
||||||
|
addToDir := args[0] == "add_to_dir"
|
||||||
|
cbDataId, _ := strconv.Atoi(args[1])
|
||||||
|
cbData, err := dao.GetCallbackData(uint(cbDataId))
|
||||||
|
if err != nil {
|
||||||
|
logger.L.Errorf("获取回调数据失败: %s", err)
|
||||||
|
ctx.AnswerCallback(&tg.MessagesSetBotCallbackAnswerRequest{
|
||||||
|
QueryID: update.CallbackQuery.QueryID,
|
||||||
|
Alert: true,
|
||||||
|
Message: "获取回调数据失败",
|
||||||
|
CacheTime: 5,
|
||||||
|
})
|
||||||
|
return dispatcher.EndGroups
|
||||||
|
}
|
||||||
|
|
||||||
|
data := strings.Split(cbData, " ")
|
||||||
|
fileChatID, _ := strconv.Atoi(data[0])
|
||||||
|
fileMessageID, _ := strconv.Atoi(data[1])
|
||||||
|
storageName := data[2]
|
||||||
|
dirIdInt, _ := strconv.Atoi(data[3])
|
||||||
|
dirId := uint(dirIdInt)
|
||||||
|
|
||||||
|
user, err := dao.GetUserByChatID(update.CallbackQuery.UserID)
|
||||||
|
if err != nil {
|
||||||
|
logger.L.Errorf("获取用户失败: %s", err)
|
||||||
|
ctx.AnswerCallback(&tg.MessagesSetBotCallbackAnswerRequest{
|
||||||
|
QueryID: update.CallbackQuery.QueryID,
|
||||||
|
Alert: true,
|
||||||
|
Message: "获取用户失败",
|
||||||
|
CacheTime: 5,
|
||||||
|
})
|
||||||
|
return dispatcher.EndGroups
|
||||||
|
}
|
||||||
|
|
||||||
|
if !addToDir {
|
||||||
|
dirs, err := dao.GetDirsByUserIDAndStorageName(user.ID, storageName)
|
||||||
|
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
|
logger.L.Errorf("获取路径失败: %s", err)
|
||||||
|
ctx.AnswerCallback(&tg.MessagesSetBotCallbackAnswerRequest{
|
||||||
|
QueryID: update.CallbackQuery.QueryID,
|
||||||
|
Alert: true,
|
||||||
|
Message: "获取路径失败",
|
||||||
|
CacheTime: 5,
|
||||||
|
})
|
||||||
|
return dispatcher.EndGroups
|
||||||
|
}
|
||||||
|
if len(dirs) != 0 {
|
||||||
|
markup, err := getSelectDirMarkup(fileChatID, fileMessageID, storageName, dirs)
|
||||||
|
if err != nil {
|
||||||
|
logger.L.Errorf("获取路径失败: %s", err)
|
||||||
|
ctx.AnswerCallback(&tg.MessagesSetBotCallbackAnswerRequest{
|
||||||
|
QueryID: update.CallbackQuery.QueryID,
|
||||||
|
Alert: true,
|
||||||
|
Message: "获取路径失败",
|
||||||
|
CacheTime: 5,
|
||||||
|
})
|
||||||
|
return dispatcher.EndGroups
|
||||||
|
}
|
||||||
|
_, err = ctx.EditMessage(update.EffectiveChat().GetID(), &tg.MessagesEditMessageRequest{
|
||||||
|
ID: update.CallbackQuery.GetMsgID(),
|
||||||
|
Message: "请选择要保存到的路径",
|
||||||
|
ReplyMarkup: markup,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
logger.L.Errorf("编辑消息失败: %s", err)
|
||||||
|
}
|
||||||
|
return dispatcher.EndGroups
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.L.Tracef("Got add to queue: chatID: %d, messageID: %d, storage: %s", fileChatID, fileMessageID, storageName)
|
||||||
|
record, err := dao.GetReceivedFileByChatAndMessageID(int64(fileChatID), fileMessageID)
|
||||||
|
if err != nil {
|
||||||
|
logger.L.Errorf("获取记录失败: %s", err)
|
||||||
|
ctx.AnswerCallback(&tg.MessagesSetBotCallbackAnswerRequest{
|
||||||
|
QueryID: update.CallbackQuery.QueryID,
|
||||||
|
Alert: true,
|
||||||
|
Message: "查询记录失败",
|
||||||
|
CacheTime: 5,
|
||||||
|
})
|
||||||
|
return dispatcher.EndGroups
|
||||||
|
}
|
||||||
|
if update.CallbackQuery.MsgID != record.ReplyMessageID {
|
||||||
|
record.ReplyMessageID = update.CallbackQuery.MsgID
|
||||||
|
if err := dao.SaveReceivedFile(record); err != nil {
|
||||||
|
logger.L.Errorf("更新接收的文件失败: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var dir *dao.Dir
|
||||||
|
if addToDir && dirId != 0 {
|
||||||
|
dir, err = dao.GetDirByID(dirId)
|
||||||
|
if err != nil {
|
||||||
|
logger.L.Errorf("获取路径失败: %s", err)
|
||||||
|
ctx.AnswerCallback(&tg.MessagesSetBotCallbackAnswerRequest{
|
||||||
|
QueryID: update.CallbackQuery.QueryID,
|
||||||
|
Alert: true,
|
||||||
|
Message: "获取路径失败",
|
||||||
|
CacheTime: 5,
|
||||||
|
})
|
||||||
|
return dispatcher.EndGroups
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
file, err := FileFromMessage(ctx, record.ChatID, record.MessageID, record.FileName)
|
||||||
|
if err != nil {
|
||||||
|
logger.L.Errorf("获取消息中的文件失败: %s", err)
|
||||||
|
ctx.AnswerCallback(&tg.MessagesSetBotCallbackAnswerRequest{
|
||||||
|
QueryID: update.CallbackQuery.QueryID,
|
||||||
|
Alert: true,
|
||||||
|
Message: fmt.Sprintf("获取消息中的文件失败: %s", err),
|
||||||
|
CacheTime: 5,
|
||||||
|
})
|
||||||
|
return dispatcher.EndGroups
|
||||||
|
}
|
||||||
|
|
||||||
|
task := types.Task{
|
||||||
|
Ctx: ctx,
|
||||||
|
Status: types.Pending,
|
||||||
|
File: file,
|
||||||
|
StorageName: storageName,
|
||||||
|
FileChatID: record.ChatID,
|
||||||
|
ReplyMessageID: record.ReplyMessageID,
|
||||||
|
FileMessageID: record.MessageID,
|
||||||
|
ReplyChatID: record.ReplyChatID,
|
||||||
|
UserID: update.GetUserChat().GetID(),
|
||||||
|
}
|
||||||
|
if dir != nil {
|
||||||
|
task.StoragePath = path.Join(dir.Path, file.FileName)
|
||||||
|
}
|
||||||
|
|
||||||
|
queue.AddTask(&task)
|
||||||
|
|
||||||
|
entityBuilder := entity.Builder{}
|
||||||
|
var entities []tg.MessageEntityClass
|
||||||
|
text := fmt.Sprintf("已添加到任务队列\n文件名: %s\n当前排队任务数: %d", record.FileName, queue.Len())
|
||||||
|
if err := styling.Perform(&entityBuilder,
|
||||||
|
styling.Plain("已添加到任务队列\n文件名: "),
|
||||||
|
styling.Code(record.FileName),
|
||||||
|
styling.Plain("\n当前排队任务数: "),
|
||||||
|
styling.Bold(strconv.Itoa(queue.Len())),
|
||||||
|
); err != nil {
|
||||||
|
logger.L.Errorf("Failed to build entity: %s", err)
|
||||||
|
} else {
|
||||||
|
text, entities = entityBuilder.Complete()
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx.EditMessage(update.EffectiveChat().GetID(), &tg.MessagesEditMessageRequest{
|
||||||
|
Message: text,
|
||||||
|
Entities: entities,
|
||||||
|
ID: record.ReplyMessageID,
|
||||||
|
})
|
||||||
|
return dispatcher.EndGroups
|
||||||
|
}
|
||||||
27
bot/handle_cancel_task.go
Normal file
27
bot/handle_cancel_task.go
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
package bot
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/celestix/gotgproto/dispatcher"
|
||||||
|
"github.com/celestix/gotgproto/ext"
|
||||||
|
"github.com/gotd/td/tg"
|
||||||
|
"github.com/krau/SaveAny-Bot/queue"
|
||||||
|
)
|
||||||
|
|
||||||
|
func cancelTask(ctx *ext.Context, update *ext.Update) error {
|
||||||
|
key := strings.Split(string(update.CallbackQuery.Data), " ")[1]
|
||||||
|
ok := queue.CancelTask(key)
|
||||||
|
if ok {
|
||||||
|
ctx.AnswerCallback(&tg.MessagesSetBotCallbackAnswerRequest{
|
||||||
|
QueryID: update.CallbackQuery.QueryID,
|
||||||
|
Message: "任务已取消",
|
||||||
|
})
|
||||||
|
return dispatcher.EndGroups
|
||||||
|
}
|
||||||
|
ctx.AnswerCallback(&tg.MessagesSetBotCallbackAnswerRequest{
|
||||||
|
QueryID: update.CallbackQuery.QueryID,
|
||||||
|
Message: "任务取消失败",
|
||||||
|
})
|
||||||
|
return dispatcher.EndGroups
|
||||||
|
}
|
||||||
88
bot/handle_dir.go
Normal file
88
bot/handle_dir.go
Normal file
@@ -0,0 +1,88 @@
|
|||||||
|
package bot
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/celestix/gotgproto/dispatcher"
|
||||||
|
"github.com/celestix/gotgproto/ext"
|
||||||
|
"github.com/gotd/td/telegram/message/styling"
|
||||||
|
"github.com/krau/SaveAny-Bot/dao"
|
||||||
|
"github.com/krau/SaveAny-Bot/logger"
|
||||||
|
"github.com/krau/SaveAny-Bot/storage"
|
||||||
|
)
|
||||||
|
|
||||||
|
func dirCmd(ctx *ext.Context, update *ext.Update) error {
|
||||||
|
args := strings.Split(strings.TrimPrefix(update.EffectiveMessage.Text, "/dir "), " ")
|
||||||
|
if len(args) < 3 {
|
||||||
|
dirs, err := dao.GetUserDirsByChatID(update.GetUserChat().GetID())
|
||||||
|
if err != nil {
|
||||||
|
logger.L.Errorf("获取用户路径失败: %s", err)
|
||||||
|
ctx.Reply(update, ext.ReplyTextString("获取用户路径失败"), nil)
|
||||||
|
return dispatcher.EndGroups
|
||||||
|
}
|
||||||
|
ctx.Reply(update, ext.ReplyTextStyledTextArray(
|
||||||
|
[]styling.StyledTextOption{
|
||||||
|
styling.Bold("使用方法: /dir <操作> <存储名> <路径>"),
|
||||||
|
styling.Plain("\n\n可用操作:\n"),
|
||||||
|
styling.Code("add"),
|
||||||
|
styling.Plain(" - 添加路径\n"),
|
||||||
|
styling.Code("del"),
|
||||||
|
styling.Plain(" - 删除路径\n"),
|
||||||
|
styling.Plain("\n示例:\n"),
|
||||||
|
styling.Code("/dir add local1 path/to/dir"),
|
||||||
|
styling.Plain("\n\n当前已添加的路径:\n"),
|
||||||
|
styling.Blockquote(func() string {
|
||||||
|
var sb strings.Builder
|
||||||
|
for _, dir := range dirs {
|
||||||
|
sb.WriteString(dir.StorageName)
|
||||||
|
sb.WriteString(" - ")
|
||||||
|
sb.WriteString(dir.Path)
|
||||||
|
sb.WriteString("\n")
|
||||||
|
}
|
||||||
|
return sb.String()
|
||||||
|
}(), true),
|
||||||
|
},
|
||||||
|
), nil)
|
||||||
|
return dispatcher.EndGroups
|
||||||
|
}
|
||||||
|
user, err := dao.GetUserByChatID(update.GetUserChat().GetID())
|
||||||
|
if err != nil {
|
||||||
|
logger.L.Errorf("获取用户失败: %s", err)
|
||||||
|
ctx.Reply(update, ext.ReplyTextString("获取用户失败"), nil)
|
||||||
|
return dispatcher.EndGroups
|
||||||
|
}
|
||||||
|
switch args[0] {
|
||||||
|
case "add":
|
||||||
|
return addDir(ctx, update, user, args[1], args[2])
|
||||||
|
case "del":
|
||||||
|
return delDir(ctx, update, user, args[1], args[2])
|
||||||
|
default:
|
||||||
|
ctx.Reply(update, ext.ReplyTextString("未知操作"), nil)
|
||||||
|
return dispatcher.EndGroups
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func addDir(ctx *ext.Context, update *ext.Update, user *dao.User, storageName, path string) error {
|
||||||
|
if _, err := storage.GetStorageByUserIDAndName(user.ChatID, storageName); err != nil {
|
||||||
|
ctx.Reply(update, ext.ReplyTextString(err.Error()), nil)
|
||||||
|
return dispatcher.EndGroups
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := dao.CreateDirForUser(user.ID, storageName, path); err != nil {
|
||||||
|
logger.L.Errorf("创建路径失败: %s", err)
|
||||||
|
ctx.Reply(update, ext.ReplyTextString("创建路径失败"), nil)
|
||||||
|
return dispatcher.EndGroups
|
||||||
|
}
|
||||||
|
ctx.Reply(update, ext.ReplyTextString("路径添加成功"), nil)
|
||||||
|
return dispatcher.EndGroups
|
||||||
|
}
|
||||||
|
|
||||||
|
func delDir(ctx *ext.Context, update *ext.Update, user *dao.User, storageName, path string) error {
|
||||||
|
if err := dao.DeleteDirForUser(user.ID, storageName, path); err != nil {
|
||||||
|
logger.L.Errorf("删除路径失败: %s", err)
|
||||||
|
ctx.Reply(update, ext.ReplyTextString("删除路径失败"), nil)
|
||||||
|
return dispatcher.EndGroups
|
||||||
|
}
|
||||||
|
ctx.Reply(update, ext.ReplyTextString("路径删除成功"), nil)
|
||||||
|
return dispatcher.EndGroups
|
||||||
|
}
|
||||||
85
bot/handle_file.go
Normal file
85
bot/handle_file.go
Normal file
@@ -0,0 +1,85 @@
|
|||||||
|
package bot
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/celestix/gotgproto/dispatcher"
|
||||||
|
"github.com/celestix/gotgproto/ext"
|
||||||
|
"github.com/gotd/td/tg"
|
||||||
|
"github.com/krau/SaveAny-Bot/dao"
|
||||||
|
"github.com/krau/SaveAny-Bot/logger"
|
||||||
|
"github.com/krau/SaveAny-Bot/storage"
|
||||||
|
"github.com/krau/SaveAny-Bot/types"
|
||||||
|
)
|
||||||
|
|
||||||
|
func handleFileMessage(ctx *ext.Context, update *ext.Update) error {
|
||||||
|
logger.L.Trace("Got media: ", update.EffectiveMessage.Media.TypeName())
|
||||||
|
supported, err := supportedMediaFilter(update.EffectiveMessage.Message)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if !supported {
|
||||||
|
return dispatcher.EndGroups
|
||||||
|
}
|
||||||
|
|
||||||
|
user, err := dao.GetUserByChatID(update.GetUserChat().GetID())
|
||||||
|
if err != nil {
|
||||||
|
logger.L.Errorf("获取用户失败: %s", err)
|
||||||
|
ctx.Reply(update, ext.ReplyTextString("获取用户失败"), nil)
|
||||||
|
return dispatcher.EndGroups
|
||||||
|
}
|
||||||
|
storages := storage.GetUserStorages(user.ChatID)
|
||||||
|
if len(storages) == 0 {
|
||||||
|
ctx.Reply(update, ext.ReplyTextString("无可用的存储"), nil)
|
||||||
|
return dispatcher.EndGroups
|
||||||
|
}
|
||||||
|
|
||||||
|
msg, err := ctx.Reply(update, ext.ReplyTextString("正在获取文件信息..."), nil)
|
||||||
|
if err != nil {
|
||||||
|
logger.L.Errorf("回复失败: %s", err)
|
||||||
|
return dispatcher.EndGroups
|
||||||
|
}
|
||||||
|
media := update.EffectiveMessage.Media
|
||||||
|
file, err := FileFromMedia(media, "")
|
||||||
|
if err != nil {
|
||||||
|
logger.L.Errorf("获取文件失败: %s", err)
|
||||||
|
ctx.Reply(update, ext.ReplyTextString(fmt.Sprintf("获取文件失败: %s", err)), nil)
|
||||||
|
return dispatcher.EndGroups
|
||||||
|
}
|
||||||
|
if file.FileName == "" {
|
||||||
|
file.FileName = fmt.Sprintf("%d_%d_%s", update.EffectiveChat().GetID(), update.EffectiveMessage.ID, file.Hash())
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := dao.SaveReceivedFile(&dao.ReceivedFile{
|
||||||
|
Processing: false,
|
||||||
|
FileName: file.FileName,
|
||||||
|
ChatID: update.EffectiveChat().GetID(),
|
||||||
|
MessageID: update.EffectiveMessage.ID,
|
||||||
|
ReplyMessageID: msg.ID,
|
||||||
|
ReplyChatID: update.GetUserChat().GetID(),
|
||||||
|
}); err != nil {
|
||||||
|
logger.L.Errorf("添加接收的文件失败: %s", err)
|
||||||
|
if _, err := ctx.EditMessage(update.EffectiveChat().GetID(), &tg.MessagesEditMessageRequest{
|
||||||
|
Message: fmt.Sprintf("添加接收的文件失败: %s", err),
|
||||||
|
ID: msg.ID,
|
||||||
|
}); err != nil {
|
||||||
|
logger.L.Errorf("编辑消息失败: %s", err)
|
||||||
|
}
|
||||||
|
return dispatcher.EndGroups
|
||||||
|
}
|
||||||
|
|
||||||
|
if !user.Silent || user.DefaultStorage == "" {
|
||||||
|
return ProvideSelectMessage(ctx, update, file, update.EffectiveChat().GetID(), update.EffectiveMessage.ID, msg.ID)
|
||||||
|
}
|
||||||
|
return HandleSilentAddTask(ctx, update, user, &types.Task{
|
||||||
|
Ctx: ctx,
|
||||||
|
Status: types.Pending,
|
||||||
|
File: file,
|
||||||
|
StorageName: user.DefaultStorage,
|
||||||
|
FileChatID: update.EffectiveChat().GetID(),
|
||||||
|
ReplyMessageID: msg.ID,
|
||||||
|
ReplyChatID: update.GetUserChat().GetID(),
|
||||||
|
FileMessageID: update.EffectiveMessage.ID,
|
||||||
|
UserID: user.ChatID,
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -78,7 +78,7 @@ func handleLinkMessage(ctx *ext.Context, update *ext.Update) error {
|
|||||||
file.FileName = fmt.Sprintf("%d_%d_%s", linkChat.GetID(), messageID, file.Hash())
|
file.FileName = fmt.Sprintf("%d_%d_%s", linkChat.GetID(), messageID, file.Hash())
|
||||||
}
|
}
|
||||||
|
|
||||||
receivedFile := &types.ReceivedFile{
|
receivedFile := &dao.ReceivedFile{
|
||||||
Processing: false,
|
Processing: false,
|
||||||
FileName: file.FileName,
|
FileName: file.FileName,
|
||||||
ChatID: linkChat.GetID(),
|
ChatID: linkChat.GetID(),
|
||||||
|
|||||||
116
bot/handle_save.go
Normal file
116
bot/handle_save.go
Normal file
@@ -0,0 +1,116 @@
|
|||||||
|
package bot
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/celestix/gotgproto/dispatcher"
|
||||||
|
"github.com/celestix/gotgproto/ext"
|
||||||
|
"github.com/gotd/td/tg"
|
||||||
|
"github.com/krau/SaveAny-Bot/dao"
|
||||||
|
"github.com/krau/SaveAny-Bot/logger"
|
||||||
|
"github.com/krau/SaveAny-Bot/storage"
|
||||||
|
"github.com/krau/SaveAny-Bot/types"
|
||||||
|
)
|
||||||
|
|
||||||
|
func saveCmd(ctx *ext.Context, update *ext.Update) error {
|
||||||
|
res, ok := update.EffectiveMessage.GetReplyTo()
|
||||||
|
if !ok || res == nil {
|
||||||
|
ctx.Reply(update, ext.ReplyTextString("请回复要保存的文件"), nil)
|
||||||
|
return dispatcher.EndGroups
|
||||||
|
}
|
||||||
|
replyHeader, ok := res.(*tg.MessageReplyHeader)
|
||||||
|
if !ok {
|
||||||
|
ctx.Reply(update, ext.ReplyTextString("请回复要保存的文件"), nil)
|
||||||
|
return dispatcher.EndGroups
|
||||||
|
}
|
||||||
|
replyToMsgID, ok := replyHeader.GetReplyToMsgID()
|
||||||
|
if !ok {
|
||||||
|
ctx.Reply(update, ext.ReplyTextString("请回复要保存的文件"), nil)
|
||||||
|
return dispatcher.EndGroups
|
||||||
|
}
|
||||||
|
|
||||||
|
user, err := dao.GetUserByChatID(update.GetUserChat().GetID())
|
||||||
|
if err != nil {
|
||||||
|
logger.L.Errorf("获取用户失败: %s", err)
|
||||||
|
ctx.Reply(update, ext.ReplyTextString("获取用户失败"), nil)
|
||||||
|
return dispatcher.EndGroups
|
||||||
|
}
|
||||||
|
|
||||||
|
storages := storage.GetUserStorages(user.ChatID)
|
||||||
|
|
||||||
|
if len(storages) == 0 {
|
||||||
|
ctx.Reply(update, ext.ReplyTextString("无可用的存储"), nil)
|
||||||
|
return dispatcher.EndGroups
|
||||||
|
}
|
||||||
|
|
||||||
|
msg, err := GetTGMessage(ctx, update.EffectiveChat().GetID(), replyToMsgID)
|
||||||
|
if err != nil {
|
||||||
|
logger.L.Errorf("获取消息失败: %s", err)
|
||||||
|
ctx.Reply(update, ext.ReplyTextString("无法获取消息"), nil)
|
||||||
|
return dispatcher.EndGroups
|
||||||
|
}
|
||||||
|
|
||||||
|
supported, _ := supportedMediaFilter(msg)
|
||||||
|
if !supported {
|
||||||
|
ctx.Reply(update, ext.ReplyTextString("不支持的消息类型或消息中没有文件"), nil)
|
||||||
|
return dispatcher.EndGroups
|
||||||
|
}
|
||||||
|
|
||||||
|
replied, err := ctx.Reply(update, ext.ReplyTextString("正在获取文件信息..."), nil)
|
||||||
|
if err != nil {
|
||||||
|
logger.L.Errorf("回复失败: %s", err)
|
||||||
|
return dispatcher.EndGroups
|
||||||
|
}
|
||||||
|
|
||||||
|
cmdText := update.EffectiveMessage.Text
|
||||||
|
customFileName := strings.TrimSpace(strings.TrimPrefix(cmdText, "/save"))
|
||||||
|
|
||||||
|
file, err := FileFromMessage(ctx, update.EffectiveChat().GetID(), msg.ID, customFileName)
|
||||||
|
if err != nil {
|
||||||
|
logger.L.Errorf("获取文件失败: %s", err)
|
||||||
|
ctx.EditMessage(update.EffectiveChat().GetID(), &tg.MessagesEditMessageRequest{
|
||||||
|
Message: fmt.Sprintf("获取文件失败: %s", err),
|
||||||
|
ID: replied.ID,
|
||||||
|
})
|
||||||
|
return dispatcher.EndGroups
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: better file name
|
||||||
|
if file.FileName == "" {
|
||||||
|
file.FileName = fmt.Sprintf("%d_%d_%s", update.EffectiveChat().GetID(), replyToMsgID, file.Hash())
|
||||||
|
}
|
||||||
|
receivedFile := &dao.ReceivedFile{
|
||||||
|
Processing: false,
|
||||||
|
FileName: file.FileName,
|
||||||
|
ChatID: update.EffectiveChat().GetID(),
|
||||||
|
MessageID: replyToMsgID,
|
||||||
|
ReplyMessageID: replied.ID,
|
||||||
|
ReplyChatID: update.GetUserChat().GetID(),
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := dao.SaveReceivedFile(receivedFile); err != nil {
|
||||||
|
logger.L.Errorf("保存接收的文件失败: %s", err)
|
||||||
|
if _, err := ctx.EditMessage(update.EffectiveChat().GetID(), &tg.MessagesEditMessageRequest{
|
||||||
|
Message: fmt.Sprintf("保存接收的文件失败: %s", err),
|
||||||
|
ID: replied.ID,
|
||||||
|
}); err != nil {
|
||||||
|
logger.L.Errorf("编辑消息失败: %s", err)
|
||||||
|
}
|
||||||
|
return dispatcher.EndGroups
|
||||||
|
}
|
||||||
|
if !user.Silent || user.DefaultStorage == "" {
|
||||||
|
return ProvideSelectMessage(ctx, update, file, update.EffectiveChat().GetID(), msg.ID, replied.ID)
|
||||||
|
}
|
||||||
|
return HandleSilentAddTask(ctx, update, user, &types.Task{
|
||||||
|
Ctx: ctx,
|
||||||
|
Status: types.Pending,
|
||||||
|
File: file,
|
||||||
|
StorageName: user.DefaultStorage,
|
||||||
|
FileChatID: update.EffectiveChat().GetID(),
|
||||||
|
ReplyMessageID: replied.ID,
|
||||||
|
ReplyChatID: update.GetUserChat().GetID(),
|
||||||
|
FileMessageID: msg.ID,
|
||||||
|
UserID: user.ChatID,
|
||||||
|
})
|
||||||
|
}
|
||||||
30
bot/handle_silent.go
Normal file
30
bot/handle_silent.go
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
package bot
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/celestix/gotgproto/dispatcher"
|
||||||
|
"github.com/celestix/gotgproto/ext"
|
||||||
|
"github.com/krau/SaveAny-Bot/dao"
|
||||||
|
"github.com/krau/SaveAny-Bot/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
func silent(ctx *ext.Context, update *ext.Update) error {
|
||||||
|
user, err := dao.GetUserByChatID(update.GetUserChat().GetID())
|
||||||
|
if err != nil {
|
||||||
|
logger.L.Errorf("获取用户失败: %s", err)
|
||||||
|
return dispatcher.EndGroups
|
||||||
|
}
|
||||||
|
if !user.Silent && user.DefaultStorage == "" {
|
||||||
|
ctx.Reply(update, ext.ReplyTextString("请先使用 /storage 设置默认存储位置"), nil)
|
||||||
|
return dispatcher.EndGroups
|
||||||
|
}
|
||||||
|
user.Silent = !user.Silent
|
||||||
|
if err := dao.UpdateUser(user); err != nil {
|
||||||
|
logger.L.Errorf("更新用户失败: %s", err)
|
||||||
|
ctx.Reply(update, ext.ReplyTextString("更新用户失败"), nil)
|
||||||
|
return dispatcher.EndGroups
|
||||||
|
}
|
||||||
|
ctx.Reply(update, ext.ReplyTextString(fmt.Sprintf("已%s静默模式", map[bool]string{true: "开启", false: "关闭"}[user.Silent])), nil)
|
||||||
|
return dispatcher.EndGroups
|
||||||
|
}
|
||||||
37
bot/handle_start.go
Normal file
37
bot/handle_start.go
Normal file
@@ -0,0 +1,37 @@
|
|||||||
|
package bot
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/celestix/gotgproto/dispatcher"
|
||||||
|
"github.com/celestix/gotgproto/ext"
|
||||||
|
"github.com/krau/SaveAny-Bot/dao"
|
||||||
|
"github.com/krau/SaveAny-Bot/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
func start(ctx *ext.Context, update *ext.Update) error {
|
||||||
|
if err := dao.CreateUser(update.GetUserChat().GetID()); err != nil {
|
||||||
|
logger.L.Errorf("创建用户失败: %s", err)
|
||||||
|
return dispatcher.EndGroups
|
||||||
|
}
|
||||||
|
return help(ctx, update)
|
||||||
|
}
|
||||||
|
|
||||||
|
const helpText string = `
|
||||||
|
Save Any Bot - 转存你的 Telegram 文件
|
||||||
|
命令:
|
||||||
|
/start - 开始使用
|
||||||
|
/help - 显示帮助
|
||||||
|
/silent - 开关静默模式
|
||||||
|
/storage - 设置默认存储位置
|
||||||
|
/save [自定义文件名] - 保存文件
|
||||||
|
|
||||||
|
静默模式: 开启后 Bot 直接保存到收到的文件到默认位置, 不再询问
|
||||||
|
|
||||||
|
默认存储位置: 在静默模式下保存到的位置
|
||||||
|
|
||||||
|
向 Bot 发送(转发)文件, 或发送一个公开频道的消息链接以保存文件
|
||||||
|
`
|
||||||
|
|
||||||
|
func help(ctx *ext.Context, update *ext.Update) error {
|
||||||
|
ctx.Reply(update, ext.ReplyTextString(helpText), nil)
|
||||||
|
return dispatcher.EndGroups
|
||||||
|
}
|
||||||
99
bot/handle_storage.go
Normal file
99
bot/handle_storage.go
Normal file
@@ -0,0 +1,99 @@
|
|||||||
|
package bot
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/celestix/gotgproto/dispatcher"
|
||||||
|
"github.com/celestix/gotgproto/ext"
|
||||||
|
"github.com/gotd/td/tg"
|
||||||
|
"github.com/krau/SaveAny-Bot/dao"
|
||||||
|
"github.com/krau/SaveAny-Bot/logger"
|
||||||
|
"github.com/krau/SaveAny-Bot/storage"
|
||||||
|
)
|
||||||
|
|
||||||
|
func storageCmd(ctx *ext.Context, update *ext.Update) error {
|
||||||
|
userChatID := update.GetUserChat().GetID()
|
||||||
|
storages := storage.GetUserStorages(userChatID)
|
||||||
|
if len(storages) == 0 {
|
||||||
|
ctx.Reply(update, ext.ReplyTextString("无可用的存储"), nil)
|
||||||
|
return dispatcher.EndGroups
|
||||||
|
}
|
||||||
|
markup, err := getSetDefaultStorageMarkup(userChatID, storages)
|
||||||
|
if err != nil {
|
||||||
|
logger.L.Errorf("Failed to get markup: %s", err)
|
||||||
|
ctx.Reply(update, ext.ReplyTextString("获取存储位置失败"), nil)
|
||||||
|
return dispatcher.EndGroups
|
||||||
|
}
|
||||||
|
ctx.Reply(update, ext.ReplyTextString("请选择要设为默认的存储位置"), &ext.ReplyOpts{
|
||||||
|
Markup: markup,
|
||||||
|
})
|
||||||
|
return dispatcher.EndGroups
|
||||||
|
}
|
||||||
|
|
||||||
|
func setDefaultStorage(ctx *ext.Context, update *ext.Update) error {
|
||||||
|
args := strings.Split(string(update.CallbackQuery.Data), " ")
|
||||||
|
userID, _ := strconv.Atoi(args[1])
|
||||||
|
if userID != int(update.CallbackQuery.GetUserID()) {
|
||||||
|
ctx.AnswerCallback(&tg.MessagesSetBotCallbackAnswerRequest{
|
||||||
|
QueryID: update.CallbackQuery.QueryID,
|
||||||
|
Alert: true,
|
||||||
|
Message: "你没有权限",
|
||||||
|
CacheTime: 5,
|
||||||
|
})
|
||||||
|
return dispatcher.EndGroups
|
||||||
|
}
|
||||||
|
cbDataId, _ := strconv.Atoi(args[2])
|
||||||
|
storageName, err := dao.GetCallbackData(uint(cbDataId))
|
||||||
|
if err != nil {
|
||||||
|
logger.L.Errorf("获取回调数据失败: %s", err)
|
||||||
|
ctx.AnswerCallback(&tg.MessagesSetBotCallbackAnswerRequest{
|
||||||
|
QueryID: update.CallbackQuery.QueryID,
|
||||||
|
Alert: true,
|
||||||
|
Message: "获取回调数据失败",
|
||||||
|
CacheTime: 5,
|
||||||
|
})
|
||||||
|
return dispatcher.EndGroups
|
||||||
|
}
|
||||||
|
|
||||||
|
selectedStorage, err := storage.GetStorageByName(storageName)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
logger.L.Errorf("获取指定存储失败: %s", err)
|
||||||
|
ctx.AnswerCallback(&tg.MessagesSetBotCallbackAnswerRequest{
|
||||||
|
QueryID: update.CallbackQuery.QueryID,
|
||||||
|
Alert: true,
|
||||||
|
Message: "获取指定存储失败",
|
||||||
|
CacheTime: 5,
|
||||||
|
})
|
||||||
|
return dispatcher.EndGroups
|
||||||
|
}
|
||||||
|
user, err := dao.GetUserByChatID(int64(userID))
|
||||||
|
if err != nil {
|
||||||
|
logger.L.Errorf("Failed to get user: %s", err)
|
||||||
|
ctx.AnswerCallback(&tg.MessagesSetBotCallbackAnswerRequest{
|
||||||
|
QueryID: update.CallbackQuery.QueryID,
|
||||||
|
Alert: true,
|
||||||
|
Message: "获取用户失败",
|
||||||
|
CacheTime: 5,
|
||||||
|
})
|
||||||
|
return dispatcher.EndGroups
|
||||||
|
}
|
||||||
|
user.DefaultStorage = storageName
|
||||||
|
if err := dao.UpdateUser(user); err != nil {
|
||||||
|
logger.L.Errorf("Failed to update user: %s", err)
|
||||||
|
ctx.AnswerCallback(&tg.MessagesSetBotCallbackAnswerRequest{
|
||||||
|
QueryID: update.CallbackQuery.QueryID,
|
||||||
|
Alert: true,
|
||||||
|
Message: "更新用户失败",
|
||||||
|
CacheTime: 5,
|
||||||
|
})
|
||||||
|
return dispatcher.EndGroups
|
||||||
|
}
|
||||||
|
ctx.EditMessage(update.EffectiveChat().GetID(), &tg.MessagesEditMessageRequest{
|
||||||
|
Message: fmt.Sprintf("已将 %s (%s) 设为默认存储位置", selectedStorage.Name(), selectedStorage.Type()),
|
||||||
|
ID: update.CallbackQuery.GetMsgID(),
|
||||||
|
})
|
||||||
|
return dispatcher.EndGroups
|
||||||
|
}
|
||||||
418
bot/handlers.go
418
bot/handlers.go
@@ -1,25 +1,10 @@
|
|||||||
package bot
|
package bot
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
|
||||||
"strconv"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/duke-git/lancet/v2/slice"
|
|
||||||
"github.com/gotd/td/telegram/message/entity"
|
|
||||||
"github.com/gotd/td/telegram/message/styling"
|
|
||||||
"github.com/gotd/td/tg"
|
|
||||||
|
|
||||||
"github.com/celestix/gotgproto/dispatcher"
|
"github.com/celestix/gotgproto/dispatcher"
|
||||||
"github.com/celestix/gotgproto/dispatcher/handlers"
|
"github.com/celestix/gotgproto/dispatcher/handlers"
|
||||||
"github.com/celestix/gotgproto/dispatcher/handlers/filters"
|
"github.com/celestix/gotgproto/dispatcher/handlers/filters"
|
||||||
"github.com/celestix/gotgproto/ext"
|
|
||||||
"github.com/krau/SaveAny-Bot/config"
|
|
||||||
"github.com/krau/SaveAny-Bot/dao"
|
|
||||||
"github.com/krau/SaveAny-Bot/logger"
|
"github.com/krau/SaveAny-Bot/logger"
|
||||||
"github.com/krau/SaveAny-Bot/queue"
|
|
||||||
"github.com/krau/SaveAny-Bot/storage"
|
|
||||||
"github.com/krau/SaveAny-Bot/types"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func RegisterHandlers(dispatcher dispatcher.Dispatcher) {
|
func RegisterHandlers(dispatcher dispatcher.Dispatcher) {
|
||||||
@@ -29,6 +14,7 @@ func RegisterHandlers(dispatcher dispatcher.Dispatcher) {
|
|||||||
dispatcher.AddHandler(handlers.NewCommand("silent", silent))
|
dispatcher.AddHandler(handlers.NewCommand("silent", silent))
|
||||||
dispatcher.AddHandler(handlers.NewCommand("storage", storageCmd))
|
dispatcher.AddHandler(handlers.NewCommand("storage", storageCmd))
|
||||||
dispatcher.AddHandler(handlers.NewCommand("save", saveCmd))
|
dispatcher.AddHandler(handlers.NewCommand("save", saveCmd))
|
||||||
|
dispatcher.AddHandler(handlers.NewCommand("dir", dirCmd))
|
||||||
linkRegexFilter, err := filters.Message.Regex(linkRegexString)
|
linkRegexFilter, err := filters.Message.Regex(linkRegexString)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.L.Panicf("创建正则表达式过滤器失败: %s", err)
|
logger.L.Panicf("创建正则表达式过滤器失败: %s", err)
|
||||||
@@ -36,406 +22,6 @@ func RegisterHandlers(dispatcher dispatcher.Dispatcher) {
|
|||||||
dispatcher.AddHandler(handlers.NewMessage(linkRegexFilter, handleLinkMessage))
|
dispatcher.AddHandler(handlers.NewMessage(linkRegexFilter, handleLinkMessage))
|
||||||
dispatcher.AddHandler(handlers.NewCallbackQuery(filters.CallbackQuery.Prefix("add"), AddToQueue))
|
dispatcher.AddHandler(handlers.NewCallbackQuery(filters.CallbackQuery.Prefix("add"), AddToQueue))
|
||||||
dispatcher.AddHandler(handlers.NewCallbackQuery(filters.CallbackQuery.Prefix("set_default"), setDefaultStorage))
|
dispatcher.AddHandler(handlers.NewCallbackQuery(filters.CallbackQuery.Prefix("set_default"), setDefaultStorage))
|
||||||
|
dispatcher.AddHandler(handlers.NewCallbackQuery(filters.CallbackQuery.Prefix("cancel"), cancelTask))
|
||||||
dispatcher.AddHandler(handlers.NewMessage(filters.Message.Media, handleFileMessage))
|
dispatcher.AddHandler(handlers.NewMessage(filters.Message.Media, handleFileMessage))
|
||||||
}
|
}
|
||||||
|
|
||||||
const noPermissionText string = `
|
|
||||||
您不在白名单中, 无法使用此 Bot.
|
|
||||||
您可以部署自己的实例: https://github.com/krau/SaveAny-Bot
|
|
||||||
`
|
|
||||||
|
|
||||||
func checkPermission(ctx *ext.Context, update *ext.Update) error {
|
|
||||||
userID := update.GetUserChat().GetID()
|
|
||||||
if !slice.Contain(config.Cfg.GetUsersID(), userID) {
|
|
||||||
ctx.Reply(update, ext.ReplyTextString(noPermissionText), nil)
|
|
||||||
return dispatcher.EndGroups
|
|
||||||
}
|
|
||||||
return dispatcher.ContinueGroups
|
|
||||||
}
|
|
||||||
|
|
||||||
func start(ctx *ext.Context, update *ext.Update) error {
|
|
||||||
if err := dao.CreateUser(update.GetUserChat().GetID()); err != nil {
|
|
||||||
logger.L.Errorf("创建用户失败: %s", err)
|
|
||||||
return dispatcher.EndGroups
|
|
||||||
}
|
|
||||||
return help(ctx, update)
|
|
||||||
}
|
|
||||||
|
|
||||||
const helpText string = `
|
|
||||||
Save Any Bot - 转存你的 Telegram 文件
|
|
||||||
命令:
|
|
||||||
/start - 开始使用
|
|
||||||
/help - 显示帮助
|
|
||||||
/silent - 开关静默模式
|
|
||||||
/storage - 设置默认存储位置
|
|
||||||
/save [自定义文件名] - 保存文件
|
|
||||||
|
|
||||||
静默模式: 开启后 Bot 直接保存到收到的文件到默认位置, 不再询问
|
|
||||||
|
|
||||||
默认存储位置: 在静默模式下保存到的位置
|
|
||||||
|
|
||||||
向 Bot 发送(转发)文件, 或发送一个公开频道的消息链接以保存文件
|
|
||||||
`
|
|
||||||
|
|
||||||
func help(ctx *ext.Context, update *ext.Update) error {
|
|
||||||
ctx.Reply(update, ext.ReplyTextString(helpText), nil)
|
|
||||||
return dispatcher.EndGroups
|
|
||||||
}
|
|
||||||
|
|
||||||
func silent(ctx *ext.Context, update *ext.Update) error {
|
|
||||||
user, err := dao.GetUserByChatID(update.GetUserChat().GetID())
|
|
||||||
if err != nil {
|
|
||||||
logger.L.Errorf("获取用户失败: %s", err)
|
|
||||||
return dispatcher.EndGroups
|
|
||||||
}
|
|
||||||
if !user.Silent && user.DefaultStorage == "" {
|
|
||||||
ctx.Reply(update, ext.ReplyTextString("请先使用 /storage 设置默认存储位置"), nil)
|
|
||||||
return dispatcher.EndGroups
|
|
||||||
}
|
|
||||||
user.Silent = !user.Silent
|
|
||||||
if err := dao.UpdateUser(user); err != nil {
|
|
||||||
logger.L.Errorf("更新用户失败: %s", err)
|
|
||||||
ctx.Reply(update, ext.ReplyTextString("更新用户失败"), nil)
|
|
||||||
return dispatcher.EndGroups
|
|
||||||
}
|
|
||||||
ctx.Reply(update, ext.ReplyTextString(fmt.Sprintf("已%s静默模式", map[bool]string{true: "开启", false: "关闭"}[user.Silent])), nil)
|
|
||||||
return dispatcher.EndGroups
|
|
||||||
}
|
|
||||||
|
|
||||||
func saveCmd(ctx *ext.Context, update *ext.Update) error {
|
|
||||||
res, ok := update.EffectiveMessage.GetReplyTo()
|
|
||||||
if !ok || res == nil {
|
|
||||||
ctx.Reply(update, ext.ReplyTextString("请回复要保存的文件"), nil)
|
|
||||||
return dispatcher.EndGroups
|
|
||||||
}
|
|
||||||
replyHeader, ok := res.(*tg.MessageReplyHeader)
|
|
||||||
if !ok {
|
|
||||||
ctx.Reply(update, ext.ReplyTextString("请回复要保存的文件"), nil)
|
|
||||||
return dispatcher.EndGroups
|
|
||||||
}
|
|
||||||
replyToMsgID, ok := replyHeader.GetReplyToMsgID()
|
|
||||||
if !ok {
|
|
||||||
ctx.Reply(update, ext.ReplyTextString("请回复要保存的文件"), nil)
|
|
||||||
return dispatcher.EndGroups
|
|
||||||
}
|
|
||||||
|
|
||||||
user, err := dao.GetUserByChatID(update.GetUserChat().GetID())
|
|
||||||
if err != nil {
|
|
||||||
logger.L.Errorf("获取用户失败: %s", err)
|
|
||||||
ctx.Reply(update, ext.ReplyTextString("获取用户失败"), nil)
|
|
||||||
return dispatcher.EndGroups
|
|
||||||
}
|
|
||||||
|
|
||||||
storages := storage.GetUserStorages(user.ChatID)
|
|
||||||
|
|
||||||
if len(storages) == 0 {
|
|
||||||
ctx.Reply(update, ext.ReplyTextString("无可用的存储"), nil)
|
|
||||||
return dispatcher.EndGroups
|
|
||||||
}
|
|
||||||
|
|
||||||
msg, err := GetTGMessage(ctx, update.EffectiveChat().GetID(), replyToMsgID)
|
|
||||||
if err != nil {
|
|
||||||
logger.L.Errorf("获取消息失败: %s", err)
|
|
||||||
ctx.Reply(update, ext.ReplyTextString("无法获取消息"), nil)
|
|
||||||
return dispatcher.EndGroups
|
|
||||||
}
|
|
||||||
|
|
||||||
supported, _ := supportedMediaFilter(msg)
|
|
||||||
if !supported {
|
|
||||||
ctx.Reply(update, ext.ReplyTextString("不支持的消息类型或消息中没有文件"), nil)
|
|
||||||
return dispatcher.EndGroups
|
|
||||||
}
|
|
||||||
|
|
||||||
replied, err := ctx.Reply(update, ext.ReplyTextString("正在获取文件信息..."), nil)
|
|
||||||
if err != nil {
|
|
||||||
logger.L.Errorf("回复失败: %s", err)
|
|
||||||
return dispatcher.EndGroups
|
|
||||||
}
|
|
||||||
|
|
||||||
cmdText := update.EffectiveMessage.Text
|
|
||||||
customFileName := strings.TrimSpace(strings.TrimPrefix(cmdText, "/save"))
|
|
||||||
|
|
||||||
file, err := FileFromMessage(ctx, update.EffectiveChat().GetID(), msg.ID, customFileName)
|
|
||||||
if err != nil {
|
|
||||||
logger.L.Errorf("获取文件失败: %s", err)
|
|
||||||
ctx.EditMessage(update.EffectiveChat().GetID(), &tg.MessagesEditMessageRequest{
|
|
||||||
Message: fmt.Sprintf("获取文件失败: %s", err),
|
|
||||||
ID: replied.ID,
|
|
||||||
})
|
|
||||||
return dispatcher.EndGroups
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO: better file name
|
|
||||||
if file.FileName == "" {
|
|
||||||
file.FileName = fmt.Sprintf("%d_%d_%s", update.EffectiveChat().GetID(), replyToMsgID, file.Hash())
|
|
||||||
}
|
|
||||||
receivedFile := &types.ReceivedFile{
|
|
||||||
Processing: false,
|
|
||||||
FileName: file.FileName,
|
|
||||||
ChatID: update.EffectiveChat().GetID(),
|
|
||||||
MessageID: replyToMsgID,
|
|
||||||
ReplyMessageID: replied.ID,
|
|
||||||
ReplyChatID: update.GetUserChat().GetID(),
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := dao.SaveReceivedFile(receivedFile); err != nil {
|
|
||||||
logger.L.Errorf("保存接收的文件失败: %s", err)
|
|
||||||
if _, err := ctx.EditMessage(update.EffectiveChat().GetID(), &tg.MessagesEditMessageRequest{
|
|
||||||
Message: fmt.Sprintf("保存接收的文件失败: %s", err),
|
|
||||||
ID: replied.ID,
|
|
||||||
}); err != nil {
|
|
||||||
logger.L.Errorf("编辑消息失败: %s", err)
|
|
||||||
}
|
|
||||||
return dispatcher.EndGroups
|
|
||||||
}
|
|
||||||
if !user.Silent || user.DefaultStorage == "" {
|
|
||||||
return ProvideSelectMessage(ctx, update, file, update.EffectiveChat().GetID(), msg.ID, replied.ID)
|
|
||||||
}
|
|
||||||
return HandleSilentAddTask(ctx, update, user, &types.Task{
|
|
||||||
Ctx: ctx,
|
|
||||||
Status: types.Pending,
|
|
||||||
File: file,
|
|
||||||
StorageName: user.DefaultStorage,
|
|
||||||
FileChatID: update.EffectiveChat().GetID(),
|
|
||||||
ReplyMessageID: replied.ID,
|
|
||||||
ReplyChatID: update.GetUserChat().GetID(),
|
|
||||||
FileMessageID: msg.ID,
|
|
||||||
UserID: user.ChatID,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func storageCmd(ctx *ext.Context, update *ext.Update) error {
|
|
||||||
user, err := dao.GetUserByChatID(update.GetUserChat().GetID())
|
|
||||||
if err != nil {
|
|
||||||
logger.L.Errorf("获取用户失败: %s", err)
|
|
||||||
ctx.Reply(update, ext.ReplyTextString("获取用户失败"), nil)
|
|
||||||
return dispatcher.EndGroups
|
|
||||||
}
|
|
||||||
storages := storage.GetUserStorages(user.ChatID)
|
|
||||||
if len(storages) == 0 {
|
|
||||||
ctx.Reply(update, ext.ReplyTextString("无可用的存储"), nil)
|
|
||||||
return dispatcher.EndGroups
|
|
||||||
}
|
|
||||||
|
|
||||||
ctx.Reply(update, ext.ReplyTextString("请选择要设为默认的存储位置"), &ext.ReplyOpts{
|
|
||||||
Markup: getSetDefaultStorageMarkup(user.ChatID, storages),
|
|
||||||
})
|
|
||||||
|
|
||||||
return dispatcher.EndGroups
|
|
||||||
}
|
|
||||||
|
|
||||||
func setDefaultStorage(ctx *ext.Context, update *ext.Update) error {
|
|
||||||
args := strings.Split(string(update.CallbackQuery.Data), " ")
|
|
||||||
userID, _ := strconv.Atoi(args[1])
|
|
||||||
storageNameHash := args[2]
|
|
||||||
if userID != int(update.CallbackQuery.GetUserID()) {
|
|
||||||
ctx.AnswerCallback(&tg.MessagesSetBotCallbackAnswerRequest{
|
|
||||||
QueryID: update.CallbackQuery.QueryID,
|
|
||||||
Alert: true,
|
|
||||||
Message: "你没有权限",
|
|
||||||
CacheTime: 5,
|
|
||||||
})
|
|
||||||
return dispatcher.EndGroups
|
|
||||||
}
|
|
||||||
storageName := storageHashName[storageNameHash]
|
|
||||||
selectedStorage, err := storage.GetStorageByName(storageName)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
logger.L.Errorf("获取指定存储失败: %s", err)
|
|
||||||
ctx.AnswerCallback(&tg.MessagesSetBotCallbackAnswerRequest{
|
|
||||||
QueryID: update.CallbackQuery.QueryID,
|
|
||||||
Alert: true,
|
|
||||||
Message: "获取指定存储失败",
|
|
||||||
CacheTime: 5,
|
|
||||||
})
|
|
||||||
return dispatcher.EndGroups
|
|
||||||
}
|
|
||||||
user, err := dao.GetUserByChatID(int64(userID))
|
|
||||||
if err != nil {
|
|
||||||
logger.L.Errorf("Failed to get user: %s", err)
|
|
||||||
ctx.AnswerCallback(&tg.MessagesSetBotCallbackAnswerRequest{
|
|
||||||
QueryID: update.CallbackQuery.QueryID,
|
|
||||||
Alert: true,
|
|
||||||
Message: "获取用户失败",
|
|
||||||
CacheTime: 5,
|
|
||||||
})
|
|
||||||
return dispatcher.EndGroups
|
|
||||||
}
|
|
||||||
user.DefaultStorage = storageName
|
|
||||||
if err := dao.UpdateUser(user); err != nil {
|
|
||||||
logger.L.Errorf("Failed to update user: %s", err)
|
|
||||||
ctx.AnswerCallback(&tg.MessagesSetBotCallbackAnswerRequest{
|
|
||||||
QueryID: update.CallbackQuery.QueryID,
|
|
||||||
Alert: true,
|
|
||||||
Message: "更新用户失败",
|
|
||||||
CacheTime: 5,
|
|
||||||
})
|
|
||||||
return dispatcher.EndGroups
|
|
||||||
}
|
|
||||||
ctx.EditMessage(update.EffectiveChat().GetID(), &tg.MessagesEditMessageRequest{
|
|
||||||
Message: fmt.Sprintf("已将 %s (%s) 设为默认存储位置", selectedStorage.Name(), selectedStorage.Type()),
|
|
||||||
ID: update.CallbackQuery.GetMsgID(),
|
|
||||||
})
|
|
||||||
return dispatcher.EndGroups
|
|
||||||
}
|
|
||||||
|
|
||||||
func handleFileMessage(ctx *ext.Context, update *ext.Update) error {
|
|
||||||
logger.L.Trace("Got media: ", update.EffectiveMessage.Media.TypeName())
|
|
||||||
supported, err := supportedMediaFilter(update.EffectiveMessage.Message)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if !supported {
|
|
||||||
return dispatcher.EndGroups
|
|
||||||
}
|
|
||||||
|
|
||||||
user, err := dao.GetUserByChatID(update.GetUserChat().GetID())
|
|
||||||
if err != nil {
|
|
||||||
logger.L.Errorf("获取用户失败: %s", err)
|
|
||||||
ctx.Reply(update, ext.ReplyTextString("获取用户失败"), nil)
|
|
||||||
return dispatcher.EndGroups
|
|
||||||
}
|
|
||||||
storages := storage.GetUserStorages(user.ChatID)
|
|
||||||
if len(storages) == 0 {
|
|
||||||
ctx.Reply(update, ext.ReplyTextString("无可用的存储"), nil)
|
|
||||||
return dispatcher.EndGroups
|
|
||||||
}
|
|
||||||
|
|
||||||
msg, err := ctx.Reply(update, ext.ReplyTextString("正在获取文件信息..."), nil)
|
|
||||||
if err != nil {
|
|
||||||
logger.L.Errorf("回复失败: %s", err)
|
|
||||||
return dispatcher.EndGroups
|
|
||||||
}
|
|
||||||
media := update.EffectiveMessage.Media
|
|
||||||
file, err := FileFromMedia(media, "")
|
|
||||||
if err != nil {
|
|
||||||
logger.L.Errorf("获取文件失败: %s", err)
|
|
||||||
ctx.Reply(update, ext.ReplyTextString(fmt.Sprintf("获取文件失败: %s", err)), nil)
|
|
||||||
return dispatcher.EndGroups
|
|
||||||
}
|
|
||||||
if file.FileName == "" {
|
|
||||||
file.FileName = fmt.Sprintf("%d_%d_%s", update.EffectiveChat().GetID(), update.EffectiveMessage.ID, file.Hash())
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := dao.SaveReceivedFile(&types.ReceivedFile{
|
|
||||||
Processing: false,
|
|
||||||
FileName: file.FileName,
|
|
||||||
ChatID: update.EffectiveChat().GetID(),
|
|
||||||
MessageID: update.EffectiveMessage.ID,
|
|
||||||
ReplyMessageID: msg.ID,
|
|
||||||
ReplyChatID: update.GetUserChat().GetID(),
|
|
||||||
}); err != nil {
|
|
||||||
logger.L.Errorf("添加接收的文件失败: %s", err)
|
|
||||||
if _, err := ctx.EditMessage(update.EffectiveChat().GetID(), &tg.MessagesEditMessageRequest{
|
|
||||||
Message: fmt.Sprintf("添加接收的文件失败: %s", err),
|
|
||||||
ID: msg.ID,
|
|
||||||
}); err != nil {
|
|
||||||
logger.L.Errorf("编辑消息失败: %s", err)
|
|
||||||
}
|
|
||||||
return dispatcher.EndGroups
|
|
||||||
}
|
|
||||||
|
|
||||||
if !user.Silent || user.DefaultStorage == "" {
|
|
||||||
return ProvideSelectMessage(ctx, update, file, update.EffectiveChat().GetID(), update.EffectiveMessage.ID, msg.ID)
|
|
||||||
}
|
|
||||||
return HandleSilentAddTask(ctx, update, user, &types.Task{
|
|
||||||
Ctx: ctx,
|
|
||||||
Status: types.Pending,
|
|
||||||
File: file,
|
|
||||||
StorageName: user.DefaultStorage,
|
|
||||||
FileChatID: update.EffectiveChat().GetID(),
|
|
||||||
ReplyMessageID: msg.ID,
|
|
||||||
ReplyChatID: update.GetUserChat().GetID(),
|
|
||||||
FileMessageID: update.EffectiveMessage.ID,
|
|
||||||
UserID: user.ChatID,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func AddToQueue(ctx *ext.Context, update *ext.Update) error {
|
|
||||||
if !slice.Contain(config.Cfg.GetUsersID(), update.CallbackQuery.UserID) {
|
|
||||||
ctx.AnswerCallback(&tg.MessagesSetBotCallbackAnswerRequest{
|
|
||||||
QueryID: update.CallbackQuery.QueryID,
|
|
||||||
Alert: true,
|
|
||||||
Message: "你没有权限",
|
|
||||||
CacheTime: 5,
|
|
||||||
})
|
|
||||||
return dispatcher.EndGroups
|
|
||||||
}
|
|
||||||
args := strings.Split(string(update.CallbackQuery.Data), " ")
|
|
||||||
fileChatID, _ := strconv.Atoi(args[1])
|
|
||||||
fileMessageID, _ := strconv.Atoi(args[2])
|
|
||||||
storageNameHash := args[3]
|
|
||||||
storageName := storageHashName[storageNameHash]
|
|
||||||
if storageName == "" {
|
|
||||||
logger.L.Errorf("未知存储位置哈希: %d", storageNameHash)
|
|
||||||
ctx.AnswerCallback(&tg.MessagesSetBotCallbackAnswerRequest{
|
|
||||||
QueryID: update.CallbackQuery.QueryID,
|
|
||||||
Alert: true,
|
|
||||||
Message: "未知存储位置",
|
|
||||||
CacheTime: 5,
|
|
||||||
})
|
|
||||||
return dispatcher.EndGroups
|
|
||||||
}
|
|
||||||
logger.L.Tracef("Got add to queue: chatID: %d, messageID: %d, storage: %s", fileChatID, fileMessageID, storageName)
|
|
||||||
record, err := dao.GetReceivedFileByChatAndMessageID(int64(fileChatID), fileMessageID)
|
|
||||||
if err != nil {
|
|
||||||
logger.L.Errorf("获取记录失败: %s", err)
|
|
||||||
ctx.AnswerCallback(&tg.MessagesSetBotCallbackAnswerRequest{
|
|
||||||
QueryID: update.CallbackQuery.QueryID,
|
|
||||||
Alert: true,
|
|
||||||
Message: "查询记录失败",
|
|
||||||
CacheTime: 5,
|
|
||||||
})
|
|
||||||
return dispatcher.EndGroups
|
|
||||||
}
|
|
||||||
if update.CallbackQuery.MsgID != record.ReplyMessageID {
|
|
||||||
record.ReplyMessageID = update.CallbackQuery.MsgID
|
|
||||||
if err := dao.SaveReceivedFile(record); err != nil {
|
|
||||||
logger.L.Errorf("更新接收的文件失败: %s", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
file, err := FileFromMessage(ctx, record.ChatID, record.MessageID, record.FileName)
|
|
||||||
if err != nil {
|
|
||||||
logger.L.Errorf("获取消息中的文件失败: %s", err)
|
|
||||||
ctx.AnswerCallback(&tg.MessagesSetBotCallbackAnswerRequest{
|
|
||||||
QueryID: update.CallbackQuery.QueryID,
|
|
||||||
Alert: true,
|
|
||||||
Message: fmt.Sprintf("获取消息中的文件失败: %s", err),
|
|
||||||
CacheTime: 5,
|
|
||||||
})
|
|
||||||
return dispatcher.EndGroups
|
|
||||||
}
|
|
||||||
|
|
||||||
queue.AddTask(types.Task{
|
|
||||||
Ctx: ctx,
|
|
||||||
Status: types.Pending,
|
|
||||||
File: file,
|
|
||||||
StorageName: storageName,
|
|
||||||
FileChatID: record.ChatID,
|
|
||||||
ReplyMessageID: record.ReplyMessageID,
|
|
||||||
FileMessageID: record.MessageID,
|
|
||||||
ReplyChatID: record.ReplyChatID,
|
|
||||||
UserID: update.EffectiveUser().GetID(),
|
|
||||||
})
|
|
||||||
|
|
||||||
entityBuilder := entity.Builder{}
|
|
||||||
var entities []tg.MessageEntityClass
|
|
||||||
text := fmt.Sprintf("已添加到任务队列\n文件名: %s\n当前排队任务数: %d", record.FileName, queue.Len())
|
|
||||||
if err := styling.Perform(&entityBuilder,
|
|
||||||
styling.Plain("已添加到任务队列\n文件名: "),
|
|
||||||
styling.Code(record.FileName),
|
|
||||||
styling.Plain("\n当前排队任务数: "),
|
|
||||||
styling.Bold(strconv.Itoa(queue.Len())),
|
|
||||||
); err != nil {
|
|
||||||
logger.L.Errorf("Failed to build entity: %s", err)
|
|
||||||
} else {
|
|
||||||
text, entities = entityBuilder.Complete()
|
|
||||||
}
|
|
||||||
|
|
||||||
ctx.EditMessage(update.EffectiveChat().GetID(), &tg.MessagesEditMessageRequest{
|
|
||||||
Message: text,
|
|
||||||
Entities: entities,
|
|
||||||
ID: record.ReplyMessageID,
|
|
||||||
})
|
|
||||||
return dispatcher.EndGroups
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -3,9 +3,13 @@ package bot
|
|||||||
import (
|
import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/celestix/gotgproto/dispatcher"
|
||||||
|
"github.com/celestix/gotgproto/ext"
|
||||||
|
"github.com/duke-git/lancet/v2/slice"
|
||||||
"github.com/gotd/contrib/middleware/floodwait"
|
"github.com/gotd/contrib/middleware/floodwait"
|
||||||
"github.com/gotd/contrib/middleware/ratelimit"
|
"github.com/gotd/contrib/middleware/ratelimit"
|
||||||
"github.com/gotd/td/telegram"
|
"github.com/gotd/td/telegram"
|
||||||
|
"github.com/krau/SaveAny-Bot/config"
|
||||||
"golang.org/x/time/rate"
|
"golang.org/x/time/rate"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -17,3 +21,17 @@ func FloodWaitMiddleware() []telegram.Middleware {
|
|||||||
ratelimiter,
|
ratelimiter,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const noPermissionText string = `
|
||||||
|
您不在白名单中, 无法使用此 Bot.
|
||||||
|
您可以部署自己的实例: https://github.com/krau/SaveAny-Bot
|
||||||
|
`
|
||||||
|
|
||||||
|
func checkPermission(ctx *ext.Context, update *ext.Update) error {
|
||||||
|
userID := update.GetUserChat().GetID()
|
||||||
|
if !slice.Contain(config.Cfg.GetUsersID(), userID) {
|
||||||
|
ctx.Reply(update, ext.ReplyTextString(noPermissionText), nil)
|
||||||
|
return dispatcher.EndGroups
|
||||||
|
}
|
||||||
|
return dispatcher.ContinueGroups
|
||||||
|
}
|
||||||
|
|||||||
65
bot/utils.go
65
bot/utils.go
@@ -24,6 +24,7 @@ var (
|
|||||||
ErrEmptyPhotoSize = errors.New("photo size is empty")
|
ErrEmptyPhotoSize = errors.New("photo size is empty")
|
||||||
ErrEmptyPhotoSizes = errors.New("photo size slice is empty")
|
ErrEmptyPhotoSizes = errors.New("photo size slice is empty")
|
||||||
ErrNoStorages = errors.New("no available storage")
|
ErrNoStorages = errors.New("no available storage")
|
||||||
|
ErrEmptyMessage = errors.New("message is empty")
|
||||||
)
|
)
|
||||||
|
|
||||||
func supportedMediaFilter(m *tg.Message) (bool, error) {
|
func supportedMediaFilter(m *tg.Message) (bool, error) {
|
||||||
@@ -40,23 +41,26 @@ func supportedMediaFilter(m *tg.Message) (bool, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// for callback data
|
|
||||||
var storageHashName = map[string]string{}
|
|
||||||
|
|
||||||
func getSelectStorageMarkup(userChatID int64, fileChatID, fileMessageID int) (*tg.ReplyInlineMarkup, error) {
|
func getSelectStorageMarkup(userChatID int64, fileChatID, fileMessageID int) (*tg.ReplyInlineMarkup, error) {
|
||||||
user, err := dao.GetUserByChatID(userChatID)
|
user, err := dao.GetUserByChatID(userChatID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, fmt.Errorf("failed to get user by chat ID: %d, error: %w", userChatID, err)
|
||||||
}
|
}
|
||||||
storages := storage.GetUserStorages(user.ChatID)
|
storages := storage.GetUserStorages(user.ChatID)
|
||||||
|
if len(storages) == 0 {
|
||||||
|
return nil, ErrNoStorages
|
||||||
|
}
|
||||||
|
|
||||||
buttons := make([]tg.KeyboardButtonClass, 0)
|
buttons := make([]tg.KeyboardButtonClass, 0)
|
||||||
for _, storage := range storages {
|
for _, storage := range storages {
|
||||||
nameHash := common.HashString(storage.Name())
|
cbData := fmt.Sprintf("%d %d %s 0", fileChatID, fileMessageID, storage.Name()) // 0 for empty dir id
|
||||||
storageHashName[nameHash] = storage.Name()
|
cbDataId, err := dao.CreateCallbackData(cbData)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create callback data: %w", err)
|
||||||
|
}
|
||||||
buttons = append(buttons, &tg.KeyboardButtonCallback{
|
buttons = append(buttons, &tg.KeyboardButtonCallback{
|
||||||
Text: storage.Name(),
|
Text: storage.Name(),
|
||||||
Data: []byte(fmt.Sprintf("add %d %d %s", fileChatID, fileMessageID, nameHash)),
|
Data: []byte(fmt.Sprintf("add %d", cbDataId)),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
markup := &tg.ReplyInlineMarkup{}
|
markup := &tg.ReplyInlineMarkup{}
|
||||||
@@ -68,14 +72,19 @@ func getSelectStorageMarkup(userChatID int64, fileChatID, fileMessageID int) (*t
|
|||||||
return markup, nil
|
return markup, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func getSetDefaultStorageMarkup(userChatID int64, storages []storage.Storage) *tg.ReplyInlineMarkup {
|
func getSelectDirMarkup(fileChatID, fileMessageID int, storageName string, dirs []dao.Dir) (*tg.ReplyInlineMarkup, error) {
|
||||||
buttons := make([]tg.KeyboardButtonClass, 0)
|
buttons := make([]tg.KeyboardButtonClass, 0)
|
||||||
for _, storage := range storages {
|
for _, dir := range dirs {
|
||||||
nameHash := common.HashString(storage.Name())
|
if dir.ID == 0 || dir.StorageName != storageName {
|
||||||
storageHashName[nameHash] = storage.Name()
|
return nil, fmt.Errorf("unexpected dir: %v", dir)
|
||||||
|
}
|
||||||
|
cbDataId, err := dao.CreateCallbackData(fmt.Sprintf("%d %d %s %d", fileChatID, fileMessageID, storageName, dir.ID))
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create callback data: %w", err)
|
||||||
|
}
|
||||||
buttons = append(buttons, &tg.KeyboardButtonCallback{
|
buttons = append(buttons, &tg.KeyboardButtonCallback{
|
||||||
Text: storage.Name(),
|
Text: dir.Path,
|
||||||
Data: []byte(fmt.Sprintf("set_default %d %s", userChatID, nameHash)),
|
Data: []byte(fmt.Sprintf("add_to_dir %d", cbDataId)),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
markup := &tg.ReplyInlineMarkup{}
|
markup := &tg.ReplyInlineMarkup{}
|
||||||
@@ -84,8 +93,28 @@ func getSetDefaultStorageMarkup(userChatID int64, storages []storage.Storage) *t
|
|||||||
row.Buttons = buttons[i:min(i+3, len(buttons))]
|
row.Buttons = buttons[i:min(i+3, len(buttons))]
|
||||||
markup.Rows = append(markup.Rows, row)
|
markup.Rows = append(markup.Rows, row)
|
||||||
}
|
}
|
||||||
return markup
|
return markup, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func getSetDefaultStorageMarkup(userChatID int64, storages []storage.Storage) (*tg.ReplyInlineMarkup, error) {
|
||||||
|
buttons := make([]tg.KeyboardButtonClass, 0)
|
||||||
|
for _, storage := range storages {
|
||||||
|
cbDataId, err := dao.CreateCallbackData(storage.Name())
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create callback data: %w", err)
|
||||||
|
}
|
||||||
|
buttons = append(buttons, &tg.KeyboardButtonCallback{
|
||||||
|
Text: storage.Name(),
|
||||||
|
Data: []byte(fmt.Sprintf("set_default %d %d", userChatID, cbDataId)),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
markup := &tg.ReplyInlineMarkup{}
|
||||||
|
for i := 0; i < len(buttons); i += 3 {
|
||||||
|
row := tg.KeyboardButtonRow{}
|
||||||
|
row.Buttons = buttons[i:min(i+3, len(buttons))]
|
||||||
|
markup.Rows = append(markup.Rows, row)
|
||||||
|
}
|
||||||
|
return markup, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func FileFromMedia(media tg.MessageMediaClass, customFileName string) (*types.File, error) {
|
func FileFromMedia(media tg.MessageMediaClass, customFileName string) (*types.File, error) {
|
||||||
@@ -176,7 +205,7 @@ func GetTGMessage(ctx *ext.Context, chatId int64, messageID int) (*tg.Message, e
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if len(messages) == 0 {
|
if len(messages) == 0 {
|
||||||
return nil, errors.New("no messages found")
|
return nil, ErrEmptyMessage
|
||||||
}
|
}
|
||||||
msg := messages[0]
|
msg := messages[0]
|
||||||
tgMessage, ok := msg.(*tg.Message)
|
tgMessage, ok := msg.(*tg.Message)
|
||||||
@@ -199,7 +228,7 @@ func ProvideSelectMessage(ctx *ext.Context, update *ext.Update, file *types.File
|
|||||||
} else {
|
} else {
|
||||||
text, entities = entityBuilder.Complete()
|
text, entities = entityBuilder.Complete()
|
||||||
}
|
}
|
||||||
markup, err := getSelectStorageMarkup(update.EffectiveUser().GetID(), int(chatID), fileMsgID)
|
markup, err := getSelectStorageMarkup(update.GetUserChat().GetID(), int(chatID), fileMsgID)
|
||||||
if errors.Is(err, ErrNoStorages) {
|
if errors.Is(err, ErrNoStorages) {
|
||||||
logger.L.Errorf("Failed to get select storage markup: %s", err)
|
logger.L.Errorf("Failed to get select storage markup: %s", err)
|
||||||
ctx.EditMessage(update.EffectiveChat().GetID(), &tg.MessagesEditMessageRequest{
|
ctx.EditMessage(update.EffectiveChat().GetID(), &tg.MessagesEditMessageRequest{
|
||||||
@@ -227,7 +256,7 @@ func ProvideSelectMessage(ctx *ext.Context, update *ext.Update, file *types.File
|
|||||||
return dispatcher.EndGroups
|
return dispatcher.EndGroups
|
||||||
}
|
}
|
||||||
|
|
||||||
func HandleSilentAddTask(ctx *ext.Context, update *ext.Update, user *types.User, task *types.Task) error {
|
func HandleSilentAddTask(ctx *ext.Context, update *ext.Update, user *dao.User, task *types.Task) error {
|
||||||
if user.DefaultStorage == "" {
|
if user.DefaultStorage == "" {
|
||||||
ctx.EditMessage(update.EffectiveChat().GetID(), &tg.MessagesEditMessageRequest{
|
ctx.EditMessage(update.EffectiveChat().GetID(), &tg.MessagesEditMessageRequest{
|
||||||
Message: "请先使用 /storage 设置默认存储位置",
|
Message: "请先使用 /storage 设置默认存储位置",
|
||||||
@@ -235,7 +264,7 @@ func HandleSilentAddTask(ctx *ext.Context, update *ext.Update, user *types.User,
|
|||||||
})
|
})
|
||||||
return dispatcher.EndGroups
|
return dispatcher.EndGroups
|
||||||
}
|
}
|
||||||
queue.AddTask(*task)
|
queue.AddTask(task)
|
||||||
ctx.EditMessage(update.EffectiveChat().GetID(), &tg.MessagesEditMessageRequest{
|
ctx.EditMessage(update.EffectiveChat().GetID(), &tg.MessagesEditMessageRequest{
|
||||||
Message: fmt.Sprintf("已添加到队列: %s\n当前排队任务数: %d", task.FileName(), queue.Len()),
|
Message: fmt.Sprintf("已添加到队列: %s\n当前排队任务数: %d", task.FileName(), queue.Len()),
|
||||||
ID: task.ReplyMessageID,
|
ID: task.ReplyMessageID,
|
||||||
|
|||||||
21
cmd/run.go
21
cmd/run.go
@@ -1,20 +1,24 @@
|
|||||||
package cmd
|
package cmd
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
"os/signal"
|
"os/signal"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"syscall"
|
"syscall"
|
||||||
|
|
||||||
"github.com/krau/SaveAny-Bot/bootstrap"
|
"github.com/krau/SaveAny-Bot/bot"
|
||||||
|
"github.com/krau/SaveAny-Bot/common"
|
||||||
"github.com/krau/SaveAny-Bot/config"
|
"github.com/krau/SaveAny-Bot/config"
|
||||||
"github.com/krau/SaveAny-Bot/core"
|
"github.com/krau/SaveAny-Bot/core"
|
||||||
|
"github.com/krau/SaveAny-Bot/dao"
|
||||||
"github.com/krau/SaveAny-Bot/logger"
|
"github.com/krau/SaveAny-Bot/logger"
|
||||||
|
"github.com/krau/SaveAny-Bot/storage"
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
)
|
)
|
||||||
|
|
||||||
func Run(_ *cobra.Command, _ []string) {
|
func Run(_ *cobra.Command, _ []string) {
|
||||||
bootstrap.InitAll()
|
InitAll()
|
||||||
core.Run()
|
core.Run()
|
||||||
|
|
||||||
quit := make(chan os.Signal, 1)
|
quit := make(chan os.Signal, 1)
|
||||||
@@ -49,3 +53,16 @@ func Run(_ *cobra.Command, _ []string) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func InitAll() {
|
||||||
|
if err := config.Init(); err != nil {
|
||||||
|
fmt.Println("加载配置文件失败: ", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
logger.InitLogger()
|
||||||
|
logger.L.Info("正在启动 SaveAny-Bot...")
|
||||||
|
dao.Init()
|
||||||
|
storage.LoadStorages()
|
||||||
|
common.Init()
|
||||||
|
bot.Init()
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,8 +1,12 @@
|
|||||||
|
#创建文件时,若需要保留中文注释,请务必确保本文件编码为 UTF-8 ,否则会无法读取。
|
||||||
workers = 4 # 同时下载文件数
|
workers = 4 # 同时下载文件数
|
||||||
retry = 3 # 下载失败重试次数
|
retry = 3 # 下载失败重试次数
|
||||||
|
threads = 4 # 单个任务下载最大线程数
|
||||||
|
stream = false # 使用stream模式, 详情请查看文档
|
||||||
|
|
||||||
[telegram]
|
[telegram]
|
||||||
# Bot Token
|
# Bot Token
|
||||||
|
# 更换 Bot Token 后请删除数据库文件和 session.db
|
||||||
token = ""
|
token = ""
|
||||||
# Telegram API 配置, 若不配置也可运行, 将使用默认的 API ID 和 API HASH
|
# Telegram API 配置, 若不配置也可运行, 将使用默认的 API ID 和 API HASH
|
||||||
# 推荐使用自己的 API ID 和 API HASH (https://my.telegram.org)
|
# 推荐使用自己的 API ID 和 API HASH (https://my.telegram.org)
|
||||||
@@ -21,7 +25,9 @@ url = "socks5://127.0.0.1:7890"
|
|||||||
name = "本机1"
|
name = "本机1"
|
||||||
# 存储类型, 目前可用: local , alist , webdav
|
# 存储类型, 目前可用: local , alist , webdav
|
||||||
type = "local"
|
type = "local"
|
||||||
|
# 启用存储
|
||||||
enable = true
|
enable = true
|
||||||
|
# 文件保存根路径
|
||||||
base_path = "./downloads"
|
base_path = "./downloads"
|
||||||
|
|
||||||
[[storages]]
|
[[storages]]
|
||||||
@@ -33,12 +39,12 @@ base_path = "./downloads/2"
|
|||||||
[[storages]]
|
[[storages]]
|
||||||
name = "MyAlist"
|
name = "MyAlist"
|
||||||
type = "alist"
|
type = "alist"
|
||||||
enable = true
|
enable = false #记得启用
|
||||||
base_path = '/'
|
base_path = '/'
|
||||||
url = 'https://alist.com'
|
url = 'https://alist.com'
|
||||||
username = 'admin'
|
username = 'admin'
|
||||||
password = 'password'
|
password = 'password'
|
||||||
token_exp = 86400
|
token_exp = 86400 # 86400--1天 604800--7天 1296000--15天 2592000--30天 15552000--180天
|
||||||
# alist 可直接使用 token 登录, 此时 username, password, token_exp 将被忽略
|
# alist 可直接使用 token 登录, 此时 username, password, token_exp 将被忽略
|
||||||
# 请自行在 alist 侧配置合理的 token 过期时间
|
# 请自行在 alist 侧配置合理的 token 过期时间
|
||||||
# token = ""
|
# token = ""
|
||||||
@@ -47,8 +53,8 @@ token_exp = 86400
|
|||||||
[[storages]]
|
[[storages]]
|
||||||
name = "MyWebdav"
|
name = "MyWebdav"
|
||||||
type = "webdav"
|
type = "webdav"
|
||||||
|
enable = false
|
||||||
base_path = '/path/telegram'
|
base_path = '/path/telegram'
|
||||||
enable = true
|
|
||||||
url = 'https://example.com/dav'
|
url = 'https://example.com/dav'
|
||||||
username = 'username'
|
username = 'username'
|
||||||
password = 'password'
|
password = 'password'
|
||||||
@@ -56,20 +62,24 @@ password = 'password'
|
|||||||
|
|
||||||
# 用户列表
|
# 用户列表
|
||||||
[[users]]
|
[[users]]
|
||||||
# user id
|
# telegram user id
|
||||||
id = 123456
|
id = 114514
|
||||||
# 存储名称过滤列表
|
# 开启黑名单,开启后下方留空以使用所有存储,反之则为白名单,白名单请在下方输入允许的存储名
|
||||||
storages = ["MyAlist", "本机1"]
|
blacklist = true
|
||||||
# 开启黑名单模式, 过滤列表中的存储将无法使用, 默认为白名单模式
|
# 将列表留空并开启黑名单模式以允许使用所有存储,此处示例为黑名单模式,用户114514 可使用所有存储
|
||||||
blacklist = false
|
storages = []
|
||||||
|
|
||||||
|
|
||||||
[[users]]
|
[[users]]
|
||||||
id = 114514
|
id = 123456
|
||||||
# 将列表留空并开启黑名单模式以允许使用所有存储
|
blacklist = false #开启白名单模式,此时,用户123456 仅可使用下方列表中的存储
|
||||||
storages = []
|
# 此时该用户只能使用名为 本机1 的存储
|
||||||
blacklist = true
|
storages = ["本机1"]
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# 其他配置
|
||||||
|
|
||||||
# [log]
|
# [log]
|
||||||
# # 日志等级
|
# # 日志等级
|
||||||
# level = "DEBUG"
|
# level = "DEBUG"
|
||||||
@@ -81,4 +91,4 @@ blacklist = true
|
|||||||
# cache_ttl = 30
|
# cache_ttl = 30
|
||||||
|
|
||||||
# [db]
|
# [db]
|
||||||
# path = "data/data.db" # 数据库文件路径
|
# path = "data/data.db" # 数据库文件路径
|
||||||
@@ -9,10 +9,13 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type Config struct {
|
type Config struct {
|
||||||
Workers int `toml:"workers" mapstructure:"workers"`
|
Workers int `toml:"workers" mapstructure:"workers"`
|
||||||
Retry int `toml:"retry" mapstructure:"retry"`
|
Retry int `toml:"retry" mapstructure:"retry"`
|
||||||
NoCleanCache bool `toml:"no_clean_cache" mapstructure:"no_clean_cache" json:"no_clean_cache"`
|
NoCleanCache bool `toml:"no_clean_cache" mapstructure:"no_clean_cache" json:"no_clean_cache"`
|
||||||
Users []userConfig `toml:"users" mapstructure:"users" json:"users"`
|
Threads int `toml:"threads" mapstructure:"threads" json:"threads"`
|
||||||
|
Stream bool `toml:"stream" mapstructure:"stream" json:"stream"`
|
||||||
|
|
||||||
|
Users []userConfig `toml:"users" mapstructure:"users" json:"users"`
|
||||||
|
|
||||||
Temp tempConfig `toml:"temp" mapstructure:"temp"`
|
Temp tempConfig `toml:"temp" mapstructure:"temp"`
|
||||||
Log logConfig `toml:"log" mapstructure:"log"`
|
Log logConfig `toml:"log" mapstructure:"log"`
|
||||||
@@ -67,6 +70,7 @@ func Init() error {
|
|||||||
|
|
||||||
viper.SetDefault("workers", 3)
|
viper.SetDefault("workers", 3)
|
||||||
viper.SetDefault("retry", 3)
|
viper.SetDefault("retry", 3)
|
||||||
|
viper.SetDefault("threads", 4)
|
||||||
|
|
||||||
viper.SetDefault("telegram.app_id", 1025907)
|
viper.SetDefault("telegram.app_id", 1025907)
|
||||||
viper.SetDefault("telegram.app_hash", "452b0359b988148995f22ff0f4229750")
|
viper.SetDefault("telegram.app_hash", "452b0359b988148995f22ff0f4229750")
|
||||||
|
|||||||
143
core/core.go
143
core/core.go
@@ -4,115 +4,15 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
|
||||||
"os"
|
|
||||||
"path"
|
|
||||||
"path/filepath"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/gabriel-vasile/mimetype"
|
|
||||||
|
|
||||||
"github.com/celestix/gotgproto/ext"
|
"github.com/celestix/gotgproto/ext"
|
||||||
"github.com/duke-git/lancet/v2/fileutil"
|
|
||||||
"github.com/gotd/td/tg"
|
"github.com/gotd/td/tg"
|
||||||
"github.com/krau/SaveAny-Bot/bot"
|
|
||||||
"github.com/krau/SaveAny-Bot/config"
|
"github.com/krau/SaveAny-Bot/config"
|
||||||
"github.com/krau/SaveAny-Bot/logger"
|
"github.com/krau/SaveAny-Bot/logger"
|
||||||
"github.com/krau/SaveAny-Bot/queue"
|
"github.com/krau/SaveAny-Bot/queue"
|
||||||
"github.com/krau/SaveAny-Bot/storage"
|
|
||||||
"github.com/krau/SaveAny-Bot/types"
|
"github.com/krau/SaveAny-Bot/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
func processPendingTask(task *types.Task) error {
|
|
||||||
logger.L.Debugf("Start processing task: %s", task.String())
|
|
||||||
if task.FileName() == "" {
|
|
||||||
task.File.FileName = fmt.Sprintf("%d_%d_%s", task.FileChatID, task.FileMessageID, task.File.Hash())
|
|
||||||
}
|
|
||||||
cacheDestPath := filepath.Join(config.Cfg.Temp.BasePath, task.FileName())
|
|
||||||
cacheDestPath, err := filepath.Abs(cacheDestPath)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("处理路径失败: %w", err)
|
|
||||||
}
|
|
||||||
if err := fileutil.CreateDir(filepath.Dir(cacheDestPath)); err != nil {
|
|
||||||
return fmt.Errorf("创建目录失败: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if task.StoragePath == "" {
|
|
||||||
task.StoragePath = task.File.FileName
|
|
||||||
}
|
|
||||||
|
|
||||||
taskStorage, err := storage.GetStorageByUserIDAndName(task.UserID, task.StorageName)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
task.StoragePath = taskStorage.JoinStoragePath(*task)
|
|
||||||
|
|
||||||
if task.File.FileSize == 0 {
|
|
||||||
return processPhoto(task, taskStorage, cacheDestPath)
|
|
||||||
}
|
|
||||||
|
|
||||||
ctx := task.Ctx.(*ext.Context)
|
|
||||||
|
|
||||||
barTotalCount := calculateBarTotalCount(task.File.FileSize)
|
|
||||||
|
|
||||||
progressCallback := func(bytesRead, contentLength int64) {
|
|
||||||
progress := float64(bytesRead) / float64(contentLength) * 100
|
|
||||||
logger.L.Tracef("Downloading %s: %.2f%%", task.String(), progress)
|
|
||||||
if task.File.FileSize < 1024*1024*50 || int(progress)%(100/barTotalCount) != 0 {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
text, entities := buildProgressMessageEntity(task, barTotalCount, bytesRead, task.StartTime, progress)
|
|
||||||
ctx.EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{
|
|
||||||
Message: text,
|
|
||||||
Entities: entities,
|
|
||||||
ID: task.ReplyMessageID,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
text, entities := buildProgressMessageEntity(task, barTotalCount, 0, task.StartTime, 0)
|
|
||||||
ctx.EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{
|
|
||||||
Message: text,
|
|
||||||
Entities: entities,
|
|
||||||
ID: task.ReplyMessageID,
|
|
||||||
})
|
|
||||||
|
|
||||||
readCloser, err := NewTelegramReader(task.Ctx, bot.Client, &task.File.Location,
|
|
||||||
0, task.File.FileSize-1, task.File.FileSize,
|
|
||||||
progressCallback, task.File.FileSize/100)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("创建下载失败: %w", err)
|
|
||||||
}
|
|
||||||
defer readCloser.Close()
|
|
||||||
|
|
||||||
dest, err := os.Create(cacheDestPath)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("创建文件失败: %w", err)
|
|
||||||
}
|
|
||||||
defer dest.Close()
|
|
||||||
task.StartTime = time.Now()
|
|
||||||
if _, err := io.CopyN(dest, readCloser, task.File.FileSize); err != nil {
|
|
||||||
return fmt.Errorf("下载文件失败: %w", err)
|
|
||||||
}
|
|
||||||
defer cleanCacheFile(cacheDestPath)
|
|
||||||
if path.Ext(task.FileName()) == "" {
|
|
||||||
mimeType, err := mimetype.DetectFile(cacheDestPath)
|
|
||||||
if err != nil {
|
|
||||||
logger.L.Errorf("Failed to detect mime type: %s", err)
|
|
||||||
} else {
|
|
||||||
task.File.FileName = fmt.Sprintf("%s%s", task.FileName(), mimeType.Extension())
|
|
||||||
task.StoragePath = fmt.Sprintf("%s%s", task.StoragePath, mimeType.Extension())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.L.Infof("Downloaded file: %s", cacheDestPath)
|
|
||||||
ctx.EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{
|
|
||||||
Message: fmt.Sprintf("下载完成: %s\n正在转存文件...", task.FileName()),
|
|
||||||
ID: task.ReplyMessageID,
|
|
||||||
})
|
|
||||||
|
|
||||||
return saveFileWithRetry(task, taskStorage, cacheDestPath)
|
|
||||||
}
|
|
||||||
|
|
||||||
func worker(queue *queue.TaskQueue, semaphore chan struct{}) {
|
func worker(queue *queue.TaskQueue, semaphore chan struct{}) {
|
||||||
for {
|
for {
|
||||||
semaphore <- struct{}{}
|
semaphore <- struct{}{}
|
||||||
@@ -122,13 +22,12 @@ func worker(queue *queue.TaskQueue, semaphore chan struct{}) {
|
|||||||
switch task.Status {
|
switch task.Status {
|
||||||
case types.Pending:
|
case types.Pending:
|
||||||
logger.L.Infof("Processing task: %s", task.String())
|
logger.L.Infof("Processing task: %s", task.String())
|
||||||
if err := processPendingTask(&task); err != nil {
|
if err := processPendingTask(task); err != nil {
|
||||||
logger.L.Errorf("Failed to do task: %s", err)
|
|
||||||
task.Error = err
|
task.Error = err
|
||||||
if errors.Is(err, context.Canceled) {
|
if errors.Is(err, context.Canceled) {
|
||||||
logger.L.Debugf("Task canceled: %s", task.String())
|
|
||||||
task.Status = types.Canceled
|
task.Status = types.Canceled
|
||||||
} else {
|
} else {
|
||||||
|
logger.L.Errorf("Failed to do task: %s", err)
|
||||||
task.Status = types.Failed
|
task.Status = types.Failed
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
@@ -137,23 +36,43 @@ func worker(queue *queue.TaskQueue, semaphore chan struct{}) {
|
|||||||
queue.AddTask(task)
|
queue.AddTask(task)
|
||||||
case types.Succeeded:
|
case types.Succeeded:
|
||||||
logger.L.Infof("Task succeeded: %s", task.String())
|
logger.L.Infof("Task succeeded: %s", task.String())
|
||||||
task.Ctx.(*ext.Context).EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{
|
extCtx, ok := task.Ctx.(*ext.Context)
|
||||||
Message: fmt.Sprintf("文件保存成功\n [%s]: %s", task.StorageName, task.StoragePath),
|
if !ok {
|
||||||
ID: task.ReplyMessageID,
|
logger.L.Errorf("Context is not *ext.Context: %T", task.Ctx)
|
||||||
})
|
} else {
|
||||||
|
extCtx.EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{
|
||||||
|
Message: fmt.Sprintf("文件保存成功\n [%s]: %s", task.StorageName, task.StoragePath),
|
||||||
|
ID: task.ReplyMessageID,
|
||||||
|
})
|
||||||
|
}
|
||||||
case types.Failed:
|
case types.Failed:
|
||||||
logger.L.Errorf("Task failed: %s", task.String())
|
logger.L.Errorf("Task failed: %s", task.String())
|
||||||
task.Ctx.(*ext.Context).EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{
|
extCtx, ok := task.Ctx.(*ext.Context)
|
||||||
Message: "文件保存失败\n" + task.Error.Error(),
|
if !ok {
|
||||||
ID: task.ReplyMessageID,
|
logger.L.Errorf("Context is not *ext.Context: %T", task.Ctx)
|
||||||
})
|
} else {
|
||||||
|
extCtx.EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{
|
||||||
|
Message: "文件保存失败\n" + task.Error.Error(),
|
||||||
|
ID: task.ReplyMessageID,
|
||||||
|
})
|
||||||
|
}
|
||||||
case types.Canceled:
|
case types.Canceled:
|
||||||
logger.L.Infof("Task canceled: %s", task.String())
|
logger.L.Infof("Task canceled: %s", task.String())
|
||||||
|
extCtx, ok := task.Ctx.(*ext.Context)
|
||||||
|
if !ok {
|
||||||
|
logger.L.Errorf("Context is not *ext.Context: %T", task.Ctx)
|
||||||
|
} else {
|
||||||
|
extCtx.EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{
|
||||||
|
Message: "任务已取消",
|
||||||
|
ID: task.ReplyMessageID,
|
||||||
|
})
|
||||||
|
}
|
||||||
default:
|
default:
|
||||||
logger.L.Errorf("Unknown task status: %s", task.Status)
|
logger.L.Errorf("Unknown task status: %s", task.Status)
|
||||||
}
|
}
|
||||||
<-semaphore
|
<-semaphore
|
||||||
logger.L.Debugf("Task done: %s", task.String())
|
logger.L.Debugf("Task done: %s; status: %s", task.String(), task.Status)
|
||||||
|
queue.DoneTask(task)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
119
core/download.go
Normal file
119
core/download.go
Normal file
@@ -0,0 +1,119 @@
|
|||||||
|
package core
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"path/filepath"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/celestix/gotgproto/ext"
|
||||||
|
"github.com/duke-git/lancet/v2/fileutil"
|
||||||
|
"github.com/gotd/td/tg"
|
||||||
|
"github.com/krau/SaveAny-Bot/bot"
|
||||||
|
"github.com/krau/SaveAny-Bot/config"
|
||||||
|
"github.com/krau/SaveAny-Bot/logger"
|
||||||
|
"github.com/krau/SaveAny-Bot/storage"
|
||||||
|
"github.com/krau/SaveAny-Bot/types"
|
||||||
|
)
|
||||||
|
|
||||||
|
func processPendingTask(task *types.Task) error {
|
||||||
|
logger.L.Debugf("Start processing task: %s", task.String())
|
||||||
|
if task.FileName() == "" {
|
||||||
|
task.File.FileName = fmt.Sprintf("%d_%d_%s", task.FileChatID, task.FileMessageID, task.File.Hash())
|
||||||
|
}
|
||||||
|
cacheDestPath := filepath.Join(config.Cfg.Temp.BasePath, task.FileName())
|
||||||
|
cacheDestPath, err := filepath.Abs(cacheDestPath)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("处理路径失败: %w", err)
|
||||||
|
}
|
||||||
|
if err := fileutil.CreateDir(filepath.Dir(cacheDestPath)); err != nil {
|
||||||
|
return fmt.Errorf("创建目录失败: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if task.StoragePath == "" {
|
||||||
|
task.StoragePath = task.File.FileName
|
||||||
|
}
|
||||||
|
|
||||||
|
taskStorage, err := storage.GetStorageByUserIDAndName(task.UserID, task.StorageName)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
task.StoragePath = taskStorage.JoinStoragePath(*task)
|
||||||
|
|
||||||
|
if task.File.FileSize == 0 {
|
||||||
|
return processPhoto(task, taskStorage, cacheDestPath)
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, ok := task.Ctx.(*ext.Context)
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("context is not *ext.Context: %T", task.Ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
cancelCtx, cancel := context.WithCancel(ctx)
|
||||||
|
task.Cancel = cancel
|
||||||
|
|
||||||
|
downloadBuider := Downloader.Download(bot.Client.API(), task.File.Location).WithThreads(getTaskThreads(task.File.FileSize))
|
||||||
|
|
||||||
|
taskStreamStorage, isStreamStorage := taskStorage.(storage.StreamStorage)
|
||||||
|
if config.Cfg.Stream {
|
||||||
|
if !isStreamStorage {
|
||||||
|
logger.L.Warnf("存储 %s 不支持流式上传", taskStorage.Name())
|
||||||
|
} else {
|
||||||
|
text, entities := buildProgressMessageEntity(task, 0, task.StartTime, 0)
|
||||||
|
ctx.EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{
|
||||||
|
Message: text,
|
||||||
|
Entities: entities,
|
||||||
|
ID: task.ReplyMessageID,
|
||||||
|
ReplyMarkup: getCancelTaskMarkup(task),
|
||||||
|
})
|
||||||
|
uploadStream, err := taskStreamStorage.NewUploadStream(cancelCtx, task.StoragePath)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("创建上传流失败: %w", err)
|
||||||
|
}
|
||||||
|
defer uploadStream.Close()
|
||||||
|
|
||||||
|
task.StartTime = time.Now()
|
||||||
|
progressCallback := buildProgressCallback(ctx, task, getProgressUpdateCount(task.File.FileSize))
|
||||||
|
|
||||||
|
progressStream := NewProgressStream(uploadStream, task.File.FileSize, progressCallback)
|
||||||
|
|
||||||
|
_, err = downloadBuider.Stream(cancelCtx, progressStream)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("下载文件失败: %w", err)
|
||||||
|
}
|
||||||
|
logger.L.Infof("Uploaded file: %s", task.StoragePath)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
text, entities := buildProgressMessageEntity(task, 0, task.StartTime, 0)
|
||||||
|
ctx.EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{
|
||||||
|
Message: text,
|
||||||
|
Entities: entities,
|
||||||
|
ID: task.ReplyMessageID,
|
||||||
|
ReplyMarkup: getCancelTaskMarkup(task),
|
||||||
|
})
|
||||||
|
|
||||||
|
progressCallback := buildProgressCallback(ctx, task, getProgressUpdateCount(task.File.FileSize))
|
||||||
|
dest, err := NewTaskLocalFile(cacheDestPath, task.File.FileSize, progressCallback)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("创建文件失败: %w", err)
|
||||||
|
}
|
||||||
|
defer dest.Close()
|
||||||
|
task.StartTime = time.Now()
|
||||||
|
_, err = downloadBuider.Parallel(cancelCtx, dest)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("下载文件失败: %w", err)
|
||||||
|
}
|
||||||
|
defer cleanCacheFile(cacheDestPath)
|
||||||
|
|
||||||
|
fixTaskFileExt(task, cacheDestPath)
|
||||||
|
|
||||||
|
logger.L.Infof("Downloaded file: %s", cacheDestPath)
|
||||||
|
ctx.EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{
|
||||||
|
Message: fmt.Sprintf("下载完成: %s\n正在转存文件...", task.FileName()),
|
||||||
|
ID: task.ReplyMessageID,
|
||||||
|
})
|
||||||
|
|
||||||
|
return saveFileWithRetry(cancelCtx, task, taskStorage, cacheDestPath)
|
||||||
|
}
|
||||||
9
core/downloader.go
Normal file
9
core/downloader.go
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
package core
|
||||||
|
|
||||||
|
import "github.com/gotd/td/telegram/downloader"
|
||||||
|
|
||||||
|
var Downloader *downloader.Downloader
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
Downloader = downloader.NewDownloader().WithPartSize(1024 * 1024)
|
||||||
|
}
|
||||||
154
core/reader.go
154
core/reader.go
@@ -1,154 +0,0 @@
|
|||||||
package core
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/celestix/gotgproto"
|
|
||||||
"github.com/gotd/td/tg"
|
|
||||||
"github.com/krau/SaveAny-Bot/config"
|
|
||||||
)
|
|
||||||
|
|
||||||
type telegramReader struct {
|
|
||||||
client *gotgproto.Client
|
|
||||||
location *tg.InputFileLocationClass
|
|
||||||
bytesread int64
|
|
||||||
chunkSize int64
|
|
||||||
i int64
|
|
||||||
contentLength int64
|
|
||||||
start int64
|
|
||||||
end int64
|
|
||||||
next func() ([]byte, error)
|
|
||||||
progressCallback func(bytesRead, contentLength int64)
|
|
||||||
callbackInterval int64
|
|
||||||
lastProgress int64
|
|
||||||
buffer []byte
|
|
||||||
ctx context.Context
|
|
||||||
}
|
|
||||||
|
|
||||||
func (*telegramReader) Close() error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *telegramReader) Read(p []byte) (n int, err error) {
|
|
||||||
if r.bytesread == r.contentLength {
|
|
||||||
return 0, io.EOF
|
|
||||||
}
|
|
||||||
|
|
||||||
if r.i >= int64(len(r.buffer)) {
|
|
||||||
r.buffer, err = r.next()
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
if len(r.buffer) == 0 {
|
|
||||||
r.next = r.partStream()
|
|
||||||
r.buffer, err = r.next()
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
r.i = 0
|
|
||||||
}
|
|
||||||
n = copy(p, r.buffer[r.i:])
|
|
||||||
r.i += int64(n)
|
|
||||||
r.bytesread += int64(n)
|
|
||||||
|
|
||||||
if r.progressCallback != nil && (r.bytesread-r.lastProgress >= r.callbackInterval || r.bytesread == r.contentLength) {
|
|
||||||
r.progressCallback(r.bytesread, r.contentLength)
|
|
||||||
r.lastProgress = r.bytesread
|
|
||||||
}
|
|
||||||
|
|
||||||
return n, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewTelegramReader(
|
|
||||||
ctx context.Context,
|
|
||||||
client *gotgproto.Client,
|
|
||||||
location *tg.InputFileLocationClass,
|
|
||||||
start int64,
|
|
||||||
end int64,
|
|
||||||
contentLength int64,
|
|
||||||
progressCallback func(bytesRead, contentLength int64),
|
|
||||||
callbackInterval int64,
|
|
||||||
) (io.ReadCloser, error) {
|
|
||||||
|
|
||||||
r := &telegramReader{
|
|
||||||
ctx: ctx,
|
|
||||||
location: location,
|
|
||||||
client: client,
|
|
||||||
start: start,
|
|
||||||
end: end,
|
|
||||||
chunkSize: int64(1024 * 1024),
|
|
||||||
contentLength: contentLength,
|
|
||||||
progressCallback: progressCallback,
|
|
||||||
callbackInterval: callbackInterval,
|
|
||||||
}
|
|
||||||
|
|
||||||
r.next = r.partStream()
|
|
||||||
return r, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *telegramReader) chunk(offset int64, limit int64) ([]byte, error) {
|
|
||||||
var lastError error
|
|
||||||
for i := 0; i < config.Cfg.Retry; i++ {
|
|
||||||
req := &tg.UploadGetFileRequest{
|
|
||||||
Offset: offset,
|
|
||||||
Limit: int(limit),
|
|
||||||
Location: *r.location,
|
|
||||||
}
|
|
||||||
res, err := r.client.API().UploadGetFile(r.ctx, req)
|
|
||||||
if err != nil {
|
|
||||||
if strings.Contains(err.Error(), tg.ErrTimeout) {
|
|
||||||
lastError = err
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
switch result := res.(type) {
|
|
||||||
case *tg.UploadFile:
|
|
||||||
return result.Bytes, nil
|
|
||||||
default:
|
|
||||||
return nil, fmt.Errorf("unexpected type %T", r)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil, lastError
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *telegramReader) partStream() func() ([]byte, error) {
|
|
||||||
|
|
||||||
start := r.start
|
|
||||||
end := r.end
|
|
||||||
offset := start - (start % r.chunkSize)
|
|
||||||
|
|
||||||
firstPartCut := start - offset
|
|
||||||
lastPartCut := (end % r.chunkSize) + 1
|
|
||||||
partCount := int((end - offset + r.chunkSize) / r.chunkSize)
|
|
||||||
currentPart := 1
|
|
||||||
|
|
||||||
readData := func() ([]byte, error) {
|
|
||||||
if currentPart > partCount {
|
|
||||||
return make([]byte, 0), nil
|
|
||||||
}
|
|
||||||
res, err := r.chunk(offset, r.chunkSize)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if len(res) == 0 {
|
|
||||||
return res, nil
|
|
||||||
} else if partCount == 1 {
|
|
||||||
res = res[firstPartCut:lastPartCut]
|
|
||||||
} else if currentPart == 1 {
|
|
||||||
res = res[firstPartCut:]
|
|
||||||
} else if currentPart == partCount {
|
|
||||||
res = res[:lastPartCut]
|
|
||||||
}
|
|
||||||
|
|
||||||
currentPart++
|
|
||||||
offset += r.chunkSize
|
|
||||||
return res, nil
|
|
||||||
}
|
|
||||||
return readData
|
|
||||||
}
|
|
||||||
187
core/utils.go
187
core/utils.go
@@ -1,10 +1,15 @@
|
|||||||
package core
|
package core
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
"os"
|
"os"
|
||||||
|
"path"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/celestix/gotgproto/ext"
|
||||||
|
"github.com/gabriel-vasile/mimetype"
|
||||||
"github.com/gotd/td/telegram/message/entity"
|
"github.com/gotd/td/telegram/message/entity"
|
||||||
"github.com/gotd/td/telegram/message/styling"
|
"github.com/gotd/td/telegram/message/styling"
|
||||||
"github.com/gotd/td/tg"
|
"github.com/gotd/td/tg"
|
||||||
@@ -16,13 +21,21 @@ import (
|
|||||||
"github.com/krau/SaveAny-Bot/types"
|
"github.com/krau/SaveAny-Bot/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
func saveFileWithRetry(task *types.Task, taskStorage storage.Storage, localFilePath string) error {
|
func saveFileWithRetry(ctx context.Context, task *types.Task, taskStorage storage.Storage, localFilePath string) error {
|
||||||
for i := 0; i <= config.Cfg.Retry; i++ {
|
for i := 0; i <= config.Cfg.Retry; i++ {
|
||||||
if err := taskStorage.Save(task.Ctx, localFilePath, task.StoragePath); err != nil {
|
if err := ctx.Err(); err != nil {
|
||||||
|
return fmt.Errorf("context canceled while saving file: %w", err)
|
||||||
|
}
|
||||||
|
if err := taskStorage.Save(ctx, localFilePath, task.StoragePath); err != nil {
|
||||||
if i == config.Cfg.Retry {
|
if i == config.Cfg.Retry {
|
||||||
return fmt.Errorf("failed to save file: %w", err)
|
return fmt.Errorf("failed to save file: %w", err)
|
||||||
}
|
}
|
||||||
logger.L.Errorf("Failed to save file: %s, retrying...", err)
|
logger.L.Errorf("Failed to save file: %s, retrying...", err)
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return fmt.Errorf("context canceled during retry delay: %w", ctx.Err())
|
||||||
|
case <-time.After(time.Duration(i*500) * time.Millisecond):
|
||||||
|
}
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
@@ -53,20 +66,7 @@ func processPhoto(task *types.Task, taskStorage storage.Storage, cachePath strin
|
|||||||
|
|
||||||
logger.L.Infof("Downloaded file: %s", cachePath)
|
logger.L.Infof("Downloaded file: %s", cachePath)
|
||||||
|
|
||||||
return saveFileWithRetry(task, taskStorage, cachePath)
|
return saveFileWithRetry(task.Ctx, task, taskStorage, cachePath)
|
||||||
}
|
|
||||||
|
|
||||||
func getProgressBar(progress float64, totalCount int) string {
|
|
||||||
bar := ""
|
|
||||||
barSize := 100 / totalCount
|
|
||||||
for i := 0; i < totalCount; i++ {
|
|
||||||
if int(progress)/barSize > i {
|
|
||||||
bar += "█"
|
|
||||||
} else {
|
|
||||||
bar += "░"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return bar
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func cleanCacheFile(destPath string) {
|
func cleanCacheFile(destPath string) {
|
||||||
@@ -79,16 +79,17 @@ func cleanCacheFile(destPath string) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func calculateBarTotalCount(fileSize int64) int {
|
// 获取进度需要更新的次数
|
||||||
barTotalCount := 5
|
func getProgressUpdateCount(fileSize int64) int {
|
||||||
|
updateCount := 5
|
||||||
if fileSize > 1024*1024*1000 {
|
if fileSize > 1024*1024*1000 {
|
||||||
barTotalCount = 40
|
updateCount = 50
|
||||||
} else if fileSize > 1024*1024*500 {
|
} else if fileSize > 1024*1024*500 {
|
||||||
barTotalCount = 20
|
updateCount = 20
|
||||||
} else if fileSize > 1024*1024*200 {
|
} else if fileSize > 1024*1024*200 {
|
||||||
barTotalCount = 10
|
updateCount = 10
|
||||||
}
|
}
|
||||||
return barTotalCount
|
return updateCount
|
||||||
}
|
}
|
||||||
|
|
||||||
func getSpeed(bytesRead int64, startTime time.Time) string {
|
func getSpeed(bytesRead int64, startTime time.Time) string {
|
||||||
@@ -100,13 +101,12 @@ func getSpeed(bytesRead int64, startTime time.Time) string {
|
|||||||
return fmt.Sprintf("%.2fMB/s", speed)
|
return fmt.Sprintf("%.2fMB/s", speed)
|
||||||
}
|
}
|
||||||
|
|
||||||
func buildProgressMessageEntity(task *types.Task, barTotalCount int, bytesRead int64, startTime time.Time, progress float64) (string, []tg.MessageEntityClass) {
|
func buildProgressMessageEntity(task *types.Task, bytesRead int64, startTime time.Time, progress float64) (string, []tg.MessageEntityClass) {
|
||||||
entityBuilder := entity.Builder{}
|
entityBuilder := entity.Builder{}
|
||||||
text := fmt.Sprintf("正在处理下载任务\n文件名: %s\n保存路径: %s\n平均速度: %s\n当前进度: [%s] %.2f%%",
|
text := fmt.Sprintf("正在处理下载任务\n文件名: %s\n保存路径: %s\n平均速度: %s\n当前进度: %.2f%%",
|
||||||
task.FileName(),
|
task.FileName(),
|
||||||
fmt.Sprintf("[%s]:%s", task.StorageName, task.StoragePath),
|
fmt.Sprintf("[%s]:%s", task.StorageName, task.StoragePath),
|
||||||
getSpeed(bytesRead, startTime),
|
getSpeed(bytesRead, startTime),
|
||||||
getProgressBar(progress, barTotalCount),
|
|
||||||
progress,
|
progress,
|
||||||
)
|
)
|
||||||
var entities []tg.MessageEntityClass
|
var entities []tg.MessageEntityClass
|
||||||
@@ -117,11 +117,144 @@ func buildProgressMessageEntity(task *types.Task, barTotalCount int, bytesRead i
|
|||||||
styling.Code(fmt.Sprintf("[%s]:%s", task.StorageName, task.StoragePath)),
|
styling.Code(fmt.Sprintf("[%s]:%s", task.StorageName, task.StoragePath)),
|
||||||
styling.Plain("\n平均速度: "),
|
styling.Plain("\n平均速度: "),
|
||||||
styling.Bold(getSpeed(bytesRead, task.StartTime)),
|
styling.Bold(getSpeed(bytesRead, task.StartTime)),
|
||||||
styling.Plain("\n当前进度:\n "),
|
styling.Plain("\n当前进度: "),
|
||||||
styling.Code(fmt.Sprintf("[%s] %.2f%%", getProgressBar(progress, barTotalCount), progress)),
|
styling.Bold(fmt.Sprintf("%.2f%%", progress)),
|
||||||
); err != nil {
|
); err != nil {
|
||||||
logger.L.Errorf("Failed to build entities: %s", err)
|
logger.L.Errorf("Failed to build entities: %s", err)
|
||||||
return text, entities
|
return text, entities
|
||||||
}
|
}
|
||||||
return entityBuilder.Complete()
|
return entityBuilder.Complete()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func buildProgressCallback(ctx *ext.Context, task *types.Task, updateCount int) func(bytesRead, contentLength int64) {
|
||||||
|
return func(bytesRead, contentLength int64) {
|
||||||
|
progress := float64(bytesRead) / float64(contentLength) * 100
|
||||||
|
logger.L.Tracef("Downloading %s: %.2f%%", task.String(), progress)
|
||||||
|
progressInt := int(progress)
|
||||||
|
if task.File.FileSize < 1024*1024*50 || progressInt == 0 || progressInt%int(100/updateCount) != 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
text, entities := buildProgressMessageEntity(task, bytesRead, task.StartTime, progress)
|
||||||
|
ctx.EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{
|
||||||
|
Message: text,
|
||||||
|
Entities: entities,
|
||||||
|
ID: task.ReplyMessageID,
|
||||||
|
ReplyMarkup: getCancelTaskMarkup(task),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func getCancelTaskMarkup(task *types.Task) *tg.ReplyInlineMarkup {
|
||||||
|
return &tg.ReplyInlineMarkup{
|
||||||
|
Rows: []tg.KeyboardButtonRow{{Buttons: []tg.KeyboardButtonClass{&tg.KeyboardButtonCallback{Text: "取消任务", Data: fmt.Appendf(nil, "cancel %s", task.Key())}}}},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func fixTaskFileExt(task *types.Task, localFilePath string) {
|
||||||
|
if path.Ext(task.FileName()) == "" {
|
||||||
|
mimeType, err := mimetype.DetectFile(localFilePath)
|
||||||
|
if err != nil {
|
||||||
|
logger.L.Errorf("Failed to detect mime type: %s", err)
|
||||||
|
} else {
|
||||||
|
task.File.FileName = fmt.Sprintf("%s%s", task.FileName(), mimeType.Extension())
|
||||||
|
task.StoragePath = fmt.Sprintf("%s%s", task.StoragePath, mimeType.Extension())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func getTaskThreads(fileSize int64) int {
|
||||||
|
threads := 1
|
||||||
|
if fileSize > 1024*1024*100 {
|
||||||
|
threads = config.Cfg.Threads
|
||||||
|
} else if fileSize > 1024*1024*50 {
|
||||||
|
threads = config.Cfg.Threads / 2
|
||||||
|
}
|
||||||
|
return threads
|
||||||
|
}
|
||||||
|
|
||||||
|
type TaskLocalFile struct {
|
||||||
|
file *os.File
|
||||||
|
size int64
|
||||||
|
done int64
|
||||||
|
progressCallback func(bytesRead, contentLength int64)
|
||||||
|
callbackTimes int64
|
||||||
|
nextCallbackAt int64
|
||||||
|
callbackInterval int64
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *TaskLocalFile) Read(p []byte) (n int, err error) {
|
||||||
|
return t.file.Read(p)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *TaskLocalFile) Close() error {
|
||||||
|
return t.file.Close()
|
||||||
|
}
|
||||||
|
func (t *TaskLocalFile) WriteAt(p []byte, off int64) (int, error) {
|
||||||
|
n, err := t.file.WriteAt(p, off)
|
||||||
|
if err != nil {
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
|
t.done += int64(n)
|
||||||
|
if t.progressCallback != nil && t.done >= t.nextCallbackAt {
|
||||||
|
t.progressCallback(t.done, t.size)
|
||||||
|
t.nextCallbackAt += t.callbackInterval
|
||||||
|
}
|
||||||
|
return n, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewTaskLocalFile(filePath string, fileSize int64, progressCallback func(bytesRead, contentLength int64)) (*TaskLocalFile, error) {
|
||||||
|
file, err := os.Create(filePath)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to open file: %w", err)
|
||||||
|
}
|
||||||
|
var callbackInterval int64
|
||||||
|
callbackInterval = fileSize / 100
|
||||||
|
if callbackInterval == 0 {
|
||||||
|
callbackInterval = 1
|
||||||
|
}
|
||||||
|
return &TaskLocalFile{
|
||||||
|
file: file,
|
||||||
|
size: fileSize,
|
||||||
|
progressCallback: progressCallback,
|
||||||
|
callbackTimes: 100,
|
||||||
|
nextCallbackAt: callbackInterval,
|
||||||
|
callbackInterval: callbackInterval,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type ProgressStream struct {
|
||||||
|
writer io.Writer
|
||||||
|
size int64
|
||||||
|
done int64
|
||||||
|
callback func(bytesRead, contentLength int64)
|
||||||
|
nextAt int64
|
||||||
|
interval int64
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ps *ProgressStream) Write(p []byte) (n int, err error) {
|
||||||
|
n, err = ps.writer.Write(p)
|
||||||
|
if err != nil {
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
|
ps.done += int64(n)
|
||||||
|
if ps.callback != nil && ps.done >= ps.nextAt {
|
||||||
|
ps.callback(ps.done, ps.size)
|
||||||
|
ps.nextAt += ps.interval
|
||||||
|
}
|
||||||
|
return n, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewProgressStream(writer io.Writer, size int64, callback func(bytesRead, contentLength int64)) *ProgressStream {
|
||||||
|
var interval int64
|
||||||
|
interval = size / 100
|
||||||
|
if interval == 0 {
|
||||||
|
interval = 1
|
||||||
|
}
|
||||||
|
return &ProgressStream{
|
||||||
|
writer: writer,
|
||||||
|
size: size,
|
||||||
|
callback: callback,
|
||||||
|
nextAt: interval,
|
||||||
|
interval: interval,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
19
dao/callback_data.go
Normal file
19
dao/callback_data.go
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
package dao
|
||||||
|
|
||||||
|
func CreateCallbackData(data string) (uint, error) {
|
||||||
|
callbackData := CallbackData{
|
||||||
|
Data: data,
|
||||||
|
}
|
||||||
|
err := db.Create(&callbackData).Error
|
||||||
|
return callbackData.ID, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetCallbackData(id uint) (string, error) {
|
||||||
|
var callbackData CallbackData
|
||||||
|
err := db.First(&callbackData, id).Error
|
||||||
|
return callbackData.Data, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func DeleteCallbackData(id uint) error {
|
||||||
|
return db.Unscoped().Where("id = ?", id).Delete(&CallbackData{}).Error
|
||||||
|
}
|
||||||
@@ -9,7 +9,6 @@ import (
|
|||||||
"github.com/glebarez/sqlite"
|
"github.com/glebarez/sqlite"
|
||||||
"github.com/krau/SaveAny-Bot/config"
|
"github.com/krau/SaveAny-Bot/config"
|
||||||
"github.com/krau/SaveAny-Bot/logger"
|
"github.com/krau/SaveAny-Bot/logger"
|
||||||
"github.com/krau/SaveAny-Bot/types"
|
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
glogger "gorm.io/gorm/logger"
|
glogger "gorm.io/gorm/logger"
|
||||||
)
|
)
|
||||||
@@ -37,7 +36,7 @@ func Init() {
|
|||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
logger.L.Debug("Database connected")
|
logger.L.Debug("Database connected")
|
||||||
if err := db.AutoMigrate(&types.ReceivedFile{}, &types.User{}); err != nil {
|
if err := db.AutoMigrate(&ReceivedFile{}, &User{}, &Dir{}, &CallbackData{}); err != nil {
|
||||||
logger.L.Fatal("迁移数据库失败, 如果您从旧版本升级, 建议手动删除数据库文件后重试: ", err)
|
logger.L.Fatal("迁移数据库失败, 如果您从旧版本升级, 建议手动删除数据库文件后重试: ", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -52,7 +51,7 @@ func syncUsers() error {
|
|||||||
return fmt.Errorf("failed to get users: %w", err)
|
return fmt.Errorf("failed to get users: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
dbUserMap := make(map[int64]types.User)
|
dbUserMap := make(map[int64]User)
|
||||||
for _, u := range dbUsers {
|
for _, u := range dbUsers {
|
||||||
dbUserMap[u.ChatID] = u
|
dbUserMap[u.ChatID] = u
|
||||||
}
|
}
|
||||||
|
|||||||
43
dao/dir.go
Normal file
43
dao/dir.go
Normal file
@@ -0,0 +1,43 @@
|
|||||||
|
package dao
|
||||||
|
|
||||||
|
func CreateDirForUser(userID uint, storageName, path string) error {
|
||||||
|
dir := Dir{
|
||||||
|
UserID: userID,
|
||||||
|
StorageName: storageName,
|
||||||
|
Path: path,
|
||||||
|
}
|
||||||
|
return db.Create(&dir).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetDirByID(id uint) (*Dir, error) {
|
||||||
|
dir := &Dir{}
|
||||||
|
err := db.First(dir, id).Error
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return dir, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetUserDirs(userID uint) ([]Dir, error) {
|
||||||
|
var dirs []Dir
|
||||||
|
err := db.Where("user_id = ?", userID).Find(&dirs).Error
|
||||||
|
return dirs, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetUserDirsByChatID(chatID int64) ([]Dir, error) {
|
||||||
|
user, err := GetUserByChatID(chatID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return GetUserDirs(user.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetDirsByUserIDAndStorageName(userID uint, storageName string) ([]Dir, error) {
|
||||||
|
var dirs []Dir
|
||||||
|
err := db.Where("user_id = ? AND storage_name = ?", userID, storageName).Find(&dirs).Error
|
||||||
|
return dirs, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func DeleteDirForUser(userID uint, storageName, path string) error {
|
||||||
|
return db.Unscoped().Where("user_id = ? AND storage_name = ? AND path = ?", userID, storageName, path).Delete(&Dir{}).Error
|
||||||
|
}
|
||||||
10
dao/file.go
10
dao/file.go
@@ -1,8 +1,6 @@
|
|||||||
package dao
|
package dao
|
||||||
|
|
||||||
import "github.com/krau/SaveAny-Bot/types"
|
func SaveReceivedFile(receivedFile *ReceivedFile) error {
|
||||||
|
|
||||||
func SaveReceivedFile(receivedFile *types.ReceivedFile) error {
|
|
||||||
record, err := GetReceivedFileByChatAndMessageID(receivedFile.ChatID, receivedFile.MessageID)
|
record, err := GetReceivedFileByChatAndMessageID(receivedFile.ChatID, receivedFile.MessageID)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
receivedFile.ID = record.ID
|
receivedFile.ID = record.ID
|
||||||
@@ -10,8 +8,8 @@ func SaveReceivedFile(receivedFile *types.ReceivedFile) error {
|
|||||||
return db.Save(receivedFile).Error
|
return db.Save(receivedFile).Error
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetReceivedFileByChatAndMessageID(chatID int64, messageID int) (*types.ReceivedFile, error) {
|
func GetReceivedFileByChatAndMessageID(chatID int64, messageID int) (*ReceivedFile, error) {
|
||||||
var receivedFile types.ReceivedFile
|
var receivedFile ReceivedFile
|
||||||
err := db.Where("chat_id = ? AND message_id = ?", chatID, messageID).First(&receivedFile).Error
|
err := db.Where("chat_id = ? AND message_id = ?", chatID, messageID).First(&receivedFile).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -19,6 +17,6 @@ func GetReceivedFileByChatAndMessageID(chatID int64, messageID int) (*types.Rece
|
|||||||
return &receivedFile, nil
|
return &receivedFile, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func DeleteReceivedFile(receivedFile *types.ReceivedFile) error {
|
func DeleteReceivedFile(receivedFile *ReceivedFile) error {
|
||||||
return db.Unscoped().Delete(receivedFile).Error
|
return db.Unscoped().Delete(receivedFile).Error
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
package types
|
package dao
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
@@ -21,4 +21,17 @@ type User struct {
|
|||||||
ChatID int64 `gorm:"uniqueIndex;not null"`
|
ChatID int64 `gorm:"uniqueIndex;not null"`
|
||||||
Silent bool
|
Silent bool
|
||||||
DefaultStorage string // Default storage name
|
DefaultStorage string // Default storage name
|
||||||
|
Dirs []Dir
|
||||||
|
}
|
||||||
|
|
||||||
|
type Dir struct {
|
||||||
|
gorm.Model
|
||||||
|
UserID uint
|
||||||
|
StorageName string
|
||||||
|
Path string
|
||||||
|
}
|
||||||
|
|
||||||
|
type CallbackData struct {
|
||||||
|
gorm.Model
|
||||||
|
Data string
|
||||||
}
|
}
|
||||||
26
dao/user.go
26
dao/user.go
@@ -1,32 +1,30 @@
|
|||||||
package dao
|
package dao
|
||||||
|
|
||||||
import (
|
|
||||||
"github.com/krau/SaveAny-Bot/types"
|
|
||||||
)
|
|
||||||
|
|
||||||
func CreateUser(chatID int64) error {
|
func CreateUser(chatID int64) error {
|
||||||
if _, err := GetUserByChatID(chatID); err == nil {
|
if _, err := GetUserByChatID(chatID); err == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return db.Create(&types.User{ChatID: chatID}).Error
|
return db.Create(&User{ChatID: chatID}).Error
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetAllUsers() ([]types.User, error) {
|
func GetAllUsers() ([]User, error) {
|
||||||
var users []types.User
|
var users []User
|
||||||
err := db.Find(&users).Error
|
err := db.Preload("Dirs").Find(&users).Error
|
||||||
return users, err
|
return users, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetUserByChatID(chatID int64) (*types.User, error) {
|
func GetUserByChatID(chatID int64) (*User, error) {
|
||||||
var user types.User
|
var user User
|
||||||
err := db.Where("chat_id = ?", chatID).First(&user).Error
|
err := db.
|
||||||
|
Preload("Dirs").
|
||||||
|
Where("chat_id = ?", chatID).First(&user).Error
|
||||||
return &user, err
|
return &user, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func UpdateUser(user *types.User) error {
|
func UpdateUser(user *User) error {
|
||||||
return db.Save(user).Error
|
return db.Save(user).Error
|
||||||
}
|
}
|
||||||
|
|
||||||
func DeleteUser(user *types.User) error {
|
func DeleteUser(user *User) error {
|
||||||
return db.Unscoped().Delete(user).Error
|
return db.Unscoped().Select("Dirs").Delete(user).Error
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -7,4 +7,7 @@ services:
|
|||||||
- ./data:/app/data
|
- ./data:/app/data
|
||||||
- ./config.toml:/app/config.toml
|
- ./config.toml:/app/config.toml
|
||||||
- ./downloads:/app/downloads
|
- ./downloads:/app/downloads
|
||||||
- ./cache:/app/cache
|
- ./cache:/app/cache
|
||||||
|
# 使用 host 模式以便访问宿主机服务 (如代理)
|
||||||
|
# 如果你对 Docker 网络模式熟悉, 可以自行修改
|
||||||
|
network_mode: host
|
||||||
94
docs/docs/deploy.md
Normal file
94
docs/docs/deploy.md
Normal file
@@ -0,0 +1,94 @@
|
|||||||
|
# 部署指南
|
||||||
|
|
||||||
|
## 从二进制文件部署
|
||||||
|
|
||||||
|
在 [Release](https://github.com/krau/SaveAny-Bot/releases) 页面下载对应平台的二进制文件.
|
||||||
|
|
||||||
|
在解压后目录新建 `config.toml` 文件, 参考 [config.example.toml](./config.example.toml) 编辑配置文件.
|
||||||
|
|
||||||
|
运行:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
chmod +x saveany-bot
|
||||||
|
./saveany-bot
|
||||||
|
```
|
||||||
|
|
||||||
|
### 添加为 systemd 服务
|
||||||
|
|
||||||
|
创建文件 `/etc/systemd/system/saveany-bot.service` 并写入以下内容:
|
||||||
|
|
||||||
|
```
|
||||||
|
[Unit]
|
||||||
|
Description=SaveAnyBot
|
||||||
|
After=systemd-user-sessions.service
|
||||||
|
|
||||||
|
[Service]
|
||||||
|
Type=simple
|
||||||
|
WorkingDirectory=/yourpath/
|
||||||
|
ExecStart=/yourpath/saveany-bot
|
||||||
|
Restart=on-failure
|
||||||
|
|
||||||
|
[Install]
|
||||||
|
WantedBy=multi-user.target
|
||||||
|
```
|
||||||
|
|
||||||
|
设为开机启动并启动服务:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
systemctl enable --now saveany-bot
|
||||||
|
```
|
||||||
|
|
||||||
|
### 为OpenWrt及衍生系统添加开机自启动服务
|
||||||
|
|
||||||
|
创建文件 ` /etc/init.d/saveanybot` ,参考[saveanybot](./docs/saveanybot)自行修改.
|
||||||
|
|
||||||
|
`chmod +x /etc/init.d/saveanybot`
|
||||||
|
|
||||||
|
完成后,将文件复制到 `/etc/rc.d`并重命名为`S99saveanybot`.
|
||||||
|
|
||||||
|
`chmod +x /etc/rc.d/S99saveanybot`
|
||||||
|
|
||||||
|
### 为OpenWrt及衍生系统添加快捷指令
|
||||||
|
|
||||||
|
创建文件` /usr/bin/sabot` ,参考[sabot](./docs/sabot)自行配置修改,注意此处文件编码仅支持 ANSI 936 .
|
||||||
|
|
||||||
|
`chmod +x /usr/bin/sabot`
|
||||||
|
|
||||||
|
之后,终端输入`sabot start|stop|restart|status|enable|disable`即可.
|
||||||
|
|
||||||
|
|
||||||
|
## 使用 Docker 部署
|
||||||
|
|
||||||
|
### Docker Compose
|
||||||
|
|
||||||
|
下载 [docker-compose.yml](./docker-compose.yml) 文件, 在同目录下新建 `config.toml` 文件, 参考 [config.example.toml](./config.example.toml) 编辑配置文件.
|
||||||
|
|
||||||
|
启动:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
docker compose up -d
|
||||||
|
```
|
||||||
|
|
||||||
|
### Docker
|
||||||
|
|
||||||
|
```shell
|
||||||
|
docker run -d --name saveany-bot \
|
||||||
|
-v /path/to/config.toml:/app/config.toml \
|
||||||
|
-v /path/to/downloads:/app/downloads \
|
||||||
|
ghcr.io/krau/saveany-bot:latest
|
||||||
|
```
|
||||||
|
|
||||||
|
## 更新
|
||||||
|
|
||||||
|
使用 `upgrade` 或 `up` 升级到最新版
|
||||||
|
|
||||||
|
```bash
|
||||||
|
./saveany-bot upgrade
|
||||||
|
```
|
||||||
|
|
||||||
|
如果是 Docker 部署, 使用以下命令更新:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
docker pull ghcr.io/krau/saveany-bot:latest
|
||||||
|
docker restart saveany-bot
|
||||||
|
```
|
||||||
20
docs/docs/faq.md
Normal file
20
docs/docs/faq.md
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
# 常见问题
|
||||||
|
|
||||||
|
## 上传 alist 失败也会显示成功
|
||||||
|
|
||||||
|
这是 alist 的上传实现导致的问题, 上传到 alist 的文件实际上会被 alist 暂存在本地, 在客户端上传结束后 alist 就返回成功, 然后 alist 会在后台将文件上传到对应的存储.
|
||||||
|
|
||||||
|
目前 bot 是根据 alist 的返回判断是否成功, 无法获知 alist 的后台上传任务是否成功.
|
||||||
|
|
||||||
|
在 alist 管理页面适当调整上传分片大小, 为 alist 使用更稳定的网络环境部署, 都可以减少这种情况的发生.
|
||||||
|
|
||||||
|
## Bot 提示下载成功但是 alist 未显示
|
||||||
|
|
||||||
|
检查 alist 后台 > 任务 > 上传 中对应的上传任务的状态, 如果任务状态为成功但目录中不显示, 是由于 alist 缓存了目录结构, 参考文档可以调整缓存时间
|
||||||
|
|
||||||
|
https://alist.nn.ci/zh/guide/drivers/common.html#缓存过期
|
||||||
|
|
||||||
|
## docker部署配置了代理后仍无法连接 telegram (初始化客户端超时)
|
||||||
|
|
||||||
|
docker 不能直接访问宿主机网络, 如果你不熟悉其用法, 请将容器设为 host 模式:
|
||||||
|
|
||||||
35
docs/docs/help.md
Normal file
35
docs/docs/help.md
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
# 使用帮助
|
||||||
|
|
||||||
|
## 保存文件
|
||||||
|
|
||||||
|
Bot 接受两种消息: 文件和链接.
|
||||||
|
|
||||||
|
目前, 链接仅支持公开频道 (具有用户名) 的链接, 例如: `https://t.me/acherkrau/1097`.
|
||||||
|
|
||||||
|
**即使频道禁止了转发和保存, Bot 依然可以下载其文件.**
|
||||||
|
|
||||||
|
## 静默模式 (silent)
|
||||||
|
|
||||||
|
使用 `/silent` 命令可以开关静默模式.
|
||||||
|
|
||||||
|
默认情况下不开启静默模式, Bot 会询问你每个文件的保存位置.
|
||||||
|
|
||||||
|
开启静默模式后, Bot 会直接保存文件到默认位置, 无需确认.
|
||||||
|
|
||||||
|
在开启静默模式之前, 需要使用 `/storage` 命令设置默认保存位置.
|
||||||
|
|
||||||
|
## Stream 模式
|
||||||
|
|
||||||
|
在配置文件中将 `stream` 设置为 `true` 可以开启 Stream 模式.
|
||||||
|
|
||||||
|
未开启时, Bot 处理任务分为两步: 下载和上传. Bot 会将文件暂存到本地, 然后上传到对应存储位置, 最后删除本地文件.
|
||||||
|
|
||||||
|
开启后, Bot 将直接将文件流式传输到存储端, 不需要下载到本地.
|
||||||
|
|
||||||
|
该功能对于硬盘空间有限的部署环境十分有用, 然而相较于普通模式也具有一些弊端:
|
||||||
|
|
||||||
|
- 无法使用多线程从 telegram 下载文件, 速度较慢.
|
||||||
|
- 网络不稳定时, 任务失败率高.
|
||||||
|
- 无法在中间层对文件进行处理, 例如自动文件类型识别.
|
||||||
|
|
||||||
|
虽然目前 Bot 适配的所有存储端 (Alist, 本地磁盘, Webdav) 都支持 Stream 模式, 但今后可能会有不支持的存储端, 此时即使开启 Stream 模式, Bot 也会自动切换到普通模式.
|
||||||
7
docs/docs/index.md
Normal file
7
docs/docs/index.md
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
# SaveAnyBot 文档
|
||||||
|
|
||||||
|
SaveAnyBot 是一个可以保存 Telegram 上的文件到云存储的机器人, 就像 PikPak Bot 一样.
|
||||||
|
|
||||||
|
不同的是, SaveAnyBot 提供更灵活的存储端选择, 并实现一些更强大的功能.
|
||||||
|
|
||||||
|
本项目以 AGPL-3.0 协议开源, 请遵守协议使用.
|
||||||
33
docs/mkdocs.yml
Normal file
33
docs/mkdocs.yml
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
site_name: SaveAnyBot 官方文档
|
||||||
|
site_author: Krau
|
||||||
|
site_description: SaveAnyBot 是一个可以保存 Telegram 上的文件到多种云存储的机器人, 本文档将帮助你了解如何部署和使用它.
|
||||||
|
repo_name: krau/saveany-bot
|
||||||
|
repo_url: https://github.com/krau/saveany-bot
|
||||||
|
copyright: CC BY-NC-SA 4.0
|
||||||
|
theme:
|
||||||
|
name: material
|
||||||
|
language: zh
|
||||||
|
highlightjs: true
|
||||||
|
palette:
|
||||||
|
- media: "(prefers-color-scheme)"
|
||||||
|
toggle:
|
||||||
|
icon: material/brightness-auto
|
||||||
|
name: 切换主题
|
||||||
|
- media: "(prefers-color-scheme: light)"
|
||||||
|
scheme: default
|
||||||
|
primary: indigo
|
||||||
|
toggle:
|
||||||
|
icon: material/brightness-7
|
||||||
|
name: 暗色模式
|
||||||
|
- media: "(prefers-color-scheme: dark)"
|
||||||
|
scheme: slate
|
||||||
|
primary: blue grey
|
||||||
|
toggle:
|
||||||
|
icon: material/brightness-4
|
||||||
|
name: 亮色模式
|
||||||
|
|
||||||
|
nav:
|
||||||
|
- index.md
|
||||||
|
- deploy.md
|
||||||
|
- help.md
|
||||||
|
- faq.md
|
||||||
28
docs/sabot
Normal file
28
docs/sabot
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
#!/bin/sh
|
||||||
|
|
||||||
|
case "$1" in
|
||||||
|
start)
|
||||||
|
/etc/init.d/saveanybot start
|
||||||
|
;;
|
||||||
|
stop)
|
||||||
|
/etc/init.d/saveanybot stop
|
||||||
|
;;
|
||||||
|
restart)
|
||||||
|
/etc/init.d/saveanybot restart
|
||||||
|
;;
|
||||||
|
status)
|
||||||
|
/etc/init.d/saveanybot status
|
||||||
|
;;
|
||||||
|
enable)
|
||||||
|
/etc/init.d/saveanybot enable
|
||||||
|
echo "Enable SaveAnyBot auto-start."
|
||||||
|
;;
|
||||||
|
disable)
|
||||||
|
/etc/init.d/saveanybot disable
|
||||||
|
echo "Disable SaveAnyBot auto-start."
|
||||||
|
;;
|
||||||
|
*)
|
||||||
|
echo "Usage: $0 {start|stop|restart|status|enable|disable}"
|
||||||
|
exit 1
|
||||||
|
;;
|
||||||
|
esac
|
||||||
34
docs/saveanybot
Normal file
34
docs/saveanybot
Normal file
@@ -0,0 +1,34 @@
|
|||||||
|
#!/bin/sh /etc/rc.common
|
||||||
|
|
||||||
|
# This is the OpenWRT init.d script for SaveAnyBot
|
||||||
|
|
||||||
|
START=99 # 设置启动顺序,数字越大越后启动
|
||||||
|
STOP=10 # 设置停止顺序,数字越小越先停止
|
||||||
|
|
||||||
|
# 脚本描述
|
||||||
|
description="SaveAnyBot"
|
||||||
|
|
||||||
|
# 设置工作目录和执行文件路径
|
||||||
|
WORKING_DIR="/mnt/mmc1-1/SaveAnyBot"
|
||||||
|
EXEC_PATH="$WORKING_DIR/saveany-bot"
|
||||||
|
|
||||||
|
# 启动函数
|
||||||
|
start() {
|
||||||
|
echo "Starting SaveAnyBot..."
|
||||||
|
# 切换到工作目录并执行程序
|
||||||
|
cd $WORKING_DIR
|
||||||
|
$EXEC_PATH &
|
||||||
|
}
|
||||||
|
|
||||||
|
# 停止函数
|
||||||
|
stop() {
|
||||||
|
echo "Stopping SaveAnyBot..."
|
||||||
|
# 查找并杀死进程
|
||||||
|
killall saveany-bot
|
||||||
|
}
|
||||||
|
|
||||||
|
# 重启函数
|
||||||
|
reload() {
|
||||||
|
stop
|
||||||
|
start
|
||||||
|
}
|
||||||
1
go.mod
1
go.mod
@@ -12,7 +12,6 @@ require (
|
|||||||
github.com/rhysd/go-github-selfupdate v1.2.3
|
github.com/rhysd/go-github-selfupdate v1.2.3
|
||||||
github.com/spf13/cobra v1.8.1
|
github.com/spf13/cobra v1.8.1
|
||||||
github.com/spf13/viper v1.19.0
|
github.com/spf13/viper v1.19.0
|
||||||
github.com/studio-b12/gowebdav v0.10.0
|
|
||||||
golang.org/x/net v0.35.0
|
golang.org/x/net v0.35.0
|
||||||
golang.org/x/time v0.10.0
|
golang.org/x/time v0.10.0
|
||||||
)
|
)
|
||||||
|
|||||||
2
go.sum
2
go.sum
@@ -172,8 +172,6 @@ github.com/spf13/viper v1.19.0 h1:RWq5SEjt8o25SROyN3z2OrDB9l7RPd3lwTWU8EcEdcI=
|
|||||||
github.com/spf13/viper v1.19.0/go.mod h1:GQUN9bilAbhU/jgc1bKs99f/suXKeUMct8Adx5+Ntkg=
|
github.com/spf13/viper v1.19.0/go.mod h1:GQUN9bilAbhU/jgc1bKs99f/suXKeUMct8Adx5+Ntkg=
|
||||||
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
|
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
|
||||||
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||||
github.com/studio-b12/gowebdav v0.10.0 h1:Yewz8FFiadcGEu4hxS/AAJQlHelndqln1bns3hcJIYc=
|
|
||||||
github.com/studio-b12/gowebdav v0.10.0/go.mod h1:bHA7t77X/QFExdeAnDzK6vKM34kEZAcE1OX4MfiwjkE=
|
|
||||||
github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8=
|
github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8=
|
||||||
github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU=
|
github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU=
|
||||||
github.com/tcnksm/go-gitconfig v0.1.2 h1:iiDhRitByXAEyjgBqsKi9QU4o2TNtv9kPP3RgPgXBPw=
|
github.com/tcnksm/go-gitconfig v0.1.2 h1:iiDhRitByXAEyjgBqsKi9QU4o2TNtv9kPP3RgPgXBPw=
|
||||||
|
|||||||
@@ -8,30 +8,65 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type TaskQueue struct {
|
type TaskQueue struct {
|
||||||
list *list.List
|
list *list.List
|
||||||
cond *sync.Cond
|
cond *sync.Cond
|
||||||
mutex *sync.Mutex
|
mutex *sync.Mutex
|
||||||
|
activeMap map[string]*types.Task
|
||||||
}
|
}
|
||||||
|
|
||||||
func (q *TaskQueue) AddTask(task types.Task) {
|
func (q *TaskQueue) AddTask(task *types.Task) {
|
||||||
q.mutex.Lock()
|
q.mutex.Lock()
|
||||||
defer q.mutex.Unlock()
|
defer q.mutex.Unlock()
|
||||||
q.list.PushBack(task)
|
q.list.PushBack(task)
|
||||||
q.cond.Signal()
|
q.cond.Signal()
|
||||||
|
if task.Status != types.Pending {
|
||||||
|
delete(q.activeMap, task.Key())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (q *TaskQueue) GetTask() types.Task {
|
func (q *TaskQueue) GetTask() *types.Task {
|
||||||
q.mutex.Lock()
|
q.mutex.Lock()
|
||||||
defer q.mutex.Unlock()
|
defer q.mutex.Unlock()
|
||||||
for q.list.Len() == 0 {
|
for q.list.Len() == 0 {
|
||||||
q.cond.Wait()
|
q.cond.Wait()
|
||||||
}
|
}
|
||||||
e := q.list.Front()
|
e := q.list.Front()
|
||||||
task := e.Value.(types.Task)
|
task := e.Value.(*types.Task)
|
||||||
q.list.Remove(e)
|
q.list.Remove(e)
|
||||||
|
if task.Status == types.Pending {
|
||||||
|
q.activeMap[task.Key()] = task
|
||||||
|
}
|
||||||
return task
|
return task
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (q *TaskQueue) DoneTask(task *types.Task) {
|
||||||
|
q.mutex.Lock()
|
||||||
|
defer q.mutex.Unlock()
|
||||||
|
delete(q.activeMap, task.Key())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *TaskQueue) CancelTask(key string) bool {
|
||||||
|
q.mutex.Lock()
|
||||||
|
defer q.mutex.Unlock()
|
||||||
|
if task, ok := q.activeMap[key]; ok {
|
||||||
|
if task.Cancel != nil {
|
||||||
|
task.Cancel()
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for e := q.list.Front(); e != nil; e = e.Next() {
|
||||||
|
task := e.Value.(*types.Task)
|
||||||
|
if task.Key() == key {
|
||||||
|
if task.Cancel != nil {
|
||||||
|
task.Cancel()
|
||||||
|
}
|
||||||
|
q.list.Remove(e)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
func (q *TaskQueue) Len() int {
|
func (q *TaskQueue) Len() int {
|
||||||
q.mutex.Lock()
|
q.mutex.Lock()
|
||||||
defer q.mutex.Unlock()
|
defer q.mutex.Unlock()
|
||||||
@@ -47,20 +82,29 @@ func init() {
|
|||||||
func NewQueue() *TaskQueue {
|
func NewQueue() *TaskQueue {
|
||||||
m := &sync.Mutex{}
|
m := &sync.Mutex{}
|
||||||
return &TaskQueue{
|
return &TaskQueue{
|
||||||
list: list.New(),
|
list: list.New(),
|
||||||
cond: sync.NewCond(m),
|
cond: sync.NewCond(m),
|
||||||
mutex: m,
|
mutex: m,
|
||||||
|
activeMap: make(map[string]*types.Task),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func AddTask(task types.Task) {
|
func AddTask(task *types.Task) {
|
||||||
Queue.AddTask(task)
|
Queue.AddTask(task)
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetTask() types.Task {
|
func GetTask() *types.Task {
|
||||||
return Queue.GetTask()
|
return Queue.GetTask()
|
||||||
}
|
}
|
||||||
|
|
||||||
func Len() int {
|
func Len() int {
|
||||||
return Queue.Len()
|
return Queue.Len()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func CancelTask(key string) bool {
|
||||||
|
return Queue.CancelTask(key)
|
||||||
|
}
|
||||||
|
|
||||||
|
func DoneTask(task *types.Task) {
|
||||||
|
Queue.DoneTask(task)
|
||||||
|
}
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ import (
|
|||||||
"net/url"
|
"net/url"
|
||||||
"os"
|
"os"
|
||||||
"path"
|
"path"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/krau/SaveAny-Bot/config"
|
"github.com/krau/SaveAny-Bot/config"
|
||||||
@@ -98,6 +99,7 @@ func (a *Alist) Name() string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (a *Alist) Save(ctx context.Context, filePath, storagePath string) error {
|
func (a *Alist) Save(ctx context.Context, filePath, storagePath string) error {
|
||||||
|
logger.L.Infof("Saving file %s to %s", filePath, storagePath)
|
||||||
file, err := os.Open(filePath)
|
file, err := os.Open(filePath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to open file: %w", err)
|
return fmt.Errorf("failed to open file: %w", err)
|
||||||
@@ -149,3 +151,88 @@ func (a *Alist) Save(ctx context.Context, filePath, storagePath string) error {
|
|||||||
func (a *Alist) JoinStoragePath(task types.Task) string {
|
func (a *Alist) JoinStoragePath(task types.Task) string {
|
||||||
return path.Join(a.config.BasePath, task.StoragePath)
|
return path.Join(a.config.BasePath, task.StoragePath)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type uploadStream struct {
|
||||||
|
ctx context.Context
|
||||||
|
client *http.Client
|
||||||
|
token string
|
||||||
|
storagePath string
|
||||||
|
baseURL string
|
||||||
|
pr *io.PipeReader
|
||||||
|
pw *io.PipeWriter
|
||||||
|
errChan chan error
|
||||||
|
once sync.Once
|
||||||
|
}
|
||||||
|
|
||||||
|
func (us *uploadStream) Write(p []byte) (int, error) {
|
||||||
|
return us.pw.Write(p)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (us *uploadStream) Close() error {
|
||||||
|
var uploadErr error
|
||||||
|
us.once.Do(func() {
|
||||||
|
if err := us.pw.Close(); err != nil {
|
||||||
|
uploadErr = fmt.Errorf("failed to close pipe writer: %w", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := <-us.errChan; err != nil {
|
||||||
|
uploadErr = err
|
||||||
|
}
|
||||||
|
})
|
||||||
|
return uploadErr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Alist) NewUploadStream(ctx context.Context, storagePath string) (io.WriteCloser, error) {
|
||||||
|
if a.token == "" {
|
||||||
|
if err := a.getToken(); err != nil {
|
||||||
|
return nil, fmt.Errorf("not logged in to Alist: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pr, pw := io.Pipe()
|
||||||
|
|
||||||
|
// 创建上传流对象
|
||||||
|
us := &uploadStream{
|
||||||
|
ctx: ctx,
|
||||||
|
client: a.client,
|
||||||
|
token: a.token,
|
||||||
|
storagePath: storagePath,
|
||||||
|
baseURL: a.baseURL,
|
||||||
|
pr: pr,
|
||||||
|
pw: pw,
|
||||||
|
errChan: make(chan error, 1),
|
||||||
|
}
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
defer close(us.errChan)
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodPut, a.baseURL+"/api/fs/put", pr)
|
||||||
|
if err != nil {
|
||||||
|
us.errChan <- fmt.Errorf("failed to create request: %w", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
req.Header.Set("Authorization", a.token)
|
||||||
|
req.Header.Set("File-Path", url.PathEscape(storagePath))
|
||||||
|
req.Header.Set("As-Task", "true")
|
||||||
|
req.Header.Set("Content-Type", "application/octet-stream")
|
||||||
|
|
||||||
|
resp, err := a.client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
us.errChan <- fmt.Errorf("failed to send request: %w", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
body, _ := io.ReadAll(resp.Body)
|
||||||
|
us.errChan <- fmt.Errorf("failed to upload file, status code: %d, response: %s", resp.StatusCode, string(body))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
us.errChan <- nil
|
||||||
|
}()
|
||||||
|
|
||||||
|
return us, nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -3,11 +3,13 @@ package local
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
|
||||||
"github.com/duke-git/lancet/v2/fileutil"
|
"github.com/duke-git/lancet/v2/fileutil"
|
||||||
"github.com/krau/SaveAny-Bot/config"
|
"github.com/krau/SaveAny-Bot/config"
|
||||||
|
"github.com/krau/SaveAny-Bot/logger"
|
||||||
"github.com/krau/SaveAny-Bot/types"
|
"github.com/krau/SaveAny-Bot/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -40,6 +42,7 @@ func (l *Local) Name() string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (l *Local) Save(ctx context.Context, filePath, storagePath string) error {
|
func (l *Local) Save(ctx context.Context, filePath, storagePath string) error {
|
||||||
|
logger.L.Infof("Saving file %s to %s", filePath, storagePath)
|
||||||
absPath, err := filepath.Abs(storagePath)
|
absPath, err := filepath.Abs(storagePath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -53,3 +56,18 @@ func (l *Local) Save(ctx context.Context, filePath, storagePath string) error {
|
|||||||
func (l *Local) JoinStoragePath(task types.Task) string {
|
func (l *Local) JoinStoragePath(task types.Task) string {
|
||||||
return filepath.Join(l.config.BasePath, task.StoragePath)
|
return filepath.Join(l.config.BasePath, task.StoragePath)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (l *Local) NewUploadStream(ctx context.Context, path string) (io.WriteCloser, error) {
|
||||||
|
absPath, err := filepath.Abs(path)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if err := fileutil.CreateDir(filepath.Dir(absPath)); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
file, err := os.Create(absPath)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return file, nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package storage
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
|
|
||||||
"github.com/krau/SaveAny-Bot/config"
|
"github.com/krau/SaveAny-Bot/config"
|
||||||
"github.com/krau/SaveAny-Bot/logger"
|
"github.com/krau/SaveAny-Bot/logger"
|
||||||
@@ -20,8 +21,15 @@ type Storage interface {
|
|||||||
Save(cttx context.Context, localFilePath, storagePath string) error
|
Save(cttx context.Context, localFilePath, storagePath string) error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type StreamStorage interface {
|
||||||
|
Storage
|
||||||
|
NewUploadStream(ctx context.Context, path string) (io.WriteCloser, error)
|
||||||
|
}
|
||||||
|
|
||||||
var Storages = make(map[string]Storage)
|
var Storages = make(map[string]Storage)
|
||||||
|
|
||||||
|
var UserStorages = make(map[int64][]Storage)
|
||||||
|
|
||||||
// GetStorageByName returns storage by name from cache or creates new one
|
// GetStorageByName returns storage by name from cache or creates new one
|
||||||
func GetStorageByName(name string) (Storage, error) {
|
func GetStorageByName(name string) (Storage, error) {
|
||||||
if name == "" {
|
if name == "" {
|
||||||
@@ -59,6 +67,12 @@ func GetStorageByUserIDAndName(chatID int64, name string) (Storage, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func GetUserStorages(chatID int64) []Storage {
|
func GetUserStorages(chatID int64) []Storage {
|
||||||
|
if chatID <= 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if storages, ok := UserStorages[chatID]; ok {
|
||||||
|
return storages
|
||||||
|
}
|
||||||
var storages []Storage
|
var storages []Storage
|
||||||
for _, name := range config.Cfg.GetStorageNamesByUserID(chatID) {
|
for _, name := range config.Cfg.GetStorageNamesByUserID(chatID) {
|
||||||
storage, err := GetStorageByName(name)
|
storage, err := GetStorageByName(name)
|
||||||
@@ -101,4 +115,7 @@ func LoadStorages() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
logger.L.Infof("成功加载 %d 个存储", len(Storages))
|
logger.L.Infof("成功加载 %d 个存储", len(Storages))
|
||||||
|
for user := range config.Cfg.GetUsersID() {
|
||||||
|
UserStorages[int64(user)] = GetUserStorages(int64(user))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
70
storage/webdav/client.go
Normal file
70
storage/webdav/client.go
Normal file
@@ -0,0 +1,70 @@
|
|||||||
|
package webdav
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Client struct {
|
||||||
|
BaseURL string
|
||||||
|
Username string
|
||||||
|
Password string
|
||||||
|
httpClient *http.Client
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewClient(baseURL, username, password string, httpClient *http.Client) *Client {
|
||||||
|
if !strings.HasSuffix(baseURL, "/") {
|
||||||
|
baseURL += "/"
|
||||||
|
}
|
||||||
|
if httpClient == nil {
|
||||||
|
httpClient = http.DefaultClient
|
||||||
|
}
|
||||||
|
return &Client{
|
||||||
|
BaseURL: baseURL,
|
||||||
|
Username: username,
|
||||||
|
Password: password,
|
||||||
|
httpClient: httpClient,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) doRequest(ctx context.Context, method, url string, body io.Reader) (*http.Response, error) {
|
||||||
|
req, err := http.NewRequestWithContext(ctx, method, url, body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if c.Username != "" && c.Password != "" {
|
||||||
|
req.SetBasicAuth(c.Username, c.Password)
|
||||||
|
}
|
||||||
|
return c.httpClient.Do(req)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) MkDir(ctx context.Context, dirPath string) error {
|
||||||
|
url := c.BaseURL + dirPath
|
||||||
|
resp, err := c.doRequest(ctx, "MKCOL", url, nil)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode >= 200 && resp.StatusCode < 300 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return fmt.Errorf("MKCOL: %s", resp.Status)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) WriteFile(ctx context.Context, remotePath string, content io.Reader) error {
|
||||||
|
url := c.BaseURL + remotePath
|
||||||
|
resp, err := c.doRequest(ctx, "PUT", url, content)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode >= 200 && resp.StatusCode < 300 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return fmt.Errorf("PUT: %s", resp.Status)
|
||||||
|
}
|
||||||
58
storage/webdav/stream.go
Normal file
58
storage/webdav/stream.go
Normal file
@@ -0,0 +1,58 @@
|
|||||||
|
package webdav
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"path"
|
||||||
|
|
||||||
|
"github.com/krau/SaveAny-Bot/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
type WebdavWriter struct {
|
||||||
|
pipeWriter *io.PipeWriter
|
||||||
|
done chan error
|
||||||
|
path string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *WebdavWriter) Write(p []byte) (n int, err error) {
|
||||||
|
return w.pipeWriter.Write(p)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *WebdavWriter) Close() error {
|
||||||
|
if err := w.pipeWriter.Close(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := <-w.done; err != nil {
|
||||||
|
return fmt.Errorf("upload failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *Webdav) NewUploadStream(ctx context.Context, storagePath string) (io.WriteCloser, error) {
|
||||||
|
if err := w.client.MkDir(ctx, path.Dir(storagePath)); err != nil {
|
||||||
|
logger.L.Errorf("Failed to create directory %s: %v", path.Dir(storagePath), err)
|
||||||
|
return nil, ErrFailedToCreateDirectory
|
||||||
|
}
|
||||||
|
pipeReader, pipeWriter := io.Pipe()
|
||||||
|
done := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
defer func() {
|
||||||
|
if err := recover(); err != nil {
|
||||||
|
done <- fmt.Errorf("panic during upload: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
err := w.client.WriteFile(ctx, storagePath, pipeReader)
|
||||||
|
|
||||||
|
pipeReader.Close()
|
||||||
|
done <- err
|
||||||
|
}()
|
||||||
|
|
||||||
|
return &WebdavWriter{
|
||||||
|
pipeWriter: pipeWriter,
|
||||||
|
done: done,
|
||||||
|
path: storagePath,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
@@ -3,6 +3,7 @@ package webdav
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"path"
|
"path"
|
||||||
"time"
|
"time"
|
||||||
@@ -10,12 +11,11 @@ import (
|
|||||||
"github.com/krau/SaveAny-Bot/config"
|
"github.com/krau/SaveAny-Bot/config"
|
||||||
"github.com/krau/SaveAny-Bot/logger"
|
"github.com/krau/SaveAny-Bot/logger"
|
||||||
"github.com/krau/SaveAny-Bot/types"
|
"github.com/krau/SaveAny-Bot/types"
|
||||||
"github.com/studio-b12/gowebdav"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type Webdav struct {
|
type Webdav struct {
|
||||||
config config.WebdavStorageConfig
|
config config.WebdavStorageConfig
|
||||||
client *gowebdav.Client
|
client *Client
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *Webdav) Init(cfg config.StorageConfig) error {
|
func (w *Webdav) Init(cfg config.StorageConfig) error {
|
||||||
@@ -27,12 +27,9 @@ func (w *Webdav) Init(cfg config.StorageConfig) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
w.config = *webdavConfig
|
w.config = *webdavConfig
|
||||||
client := gowebdav.NewClient(webdavConfig.URL, webdavConfig.Username, webdavConfig.Password)
|
w.client = NewClient(w.config.URL, w.config.Username, w.config.Password, &http.Client{
|
||||||
if err := client.Connect(); err != nil {
|
Timeout: time.Hour * 12,
|
||||||
return fmt.Errorf("failed to connect to webdav server: %w", err)
|
})
|
||||||
}
|
|
||||||
client.SetTimeout(12 * time.Hour)
|
|
||||||
w.client = client
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -45,7 +42,8 @@ func (w *Webdav) Name() string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (w *Webdav) Save(ctx context.Context, filePath, storagePath string) error {
|
func (w *Webdav) Save(ctx context.Context, filePath, storagePath string) error {
|
||||||
if err := w.client.MkdirAll(path.Dir(storagePath), os.ModePerm); err != nil {
|
logger.L.Infof("Saving file %s to %s", filePath, storagePath)
|
||||||
|
if err := w.client.MkDir(ctx, path.Dir(storagePath)); err != nil {
|
||||||
logger.L.Errorf("Failed to create directory %s: %v", path.Dir(storagePath), err)
|
logger.L.Errorf("Failed to create directory %s: %v", path.Dir(storagePath), err)
|
||||||
return ErrFailedToCreateDirectory
|
return ErrFailedToCreateDirectory
|
||||||
}
|
}
|
||||||
@@ -56,7 +54,7 @@ func (w *Webdav) Save(ctx context.Context, filePath, storagePath string) error {
|
|||||||
}
|
}
|
||||||
defer file.Close()
|
defer file.Close()
|
||||||
|
|
||||||
if err := w.client.WriteStream(storagePath, file, os.ModePerm); err != nil {
|
if err := w.client.WriteFile(ctx, storagePath, file); err != nil {
|
||||||
logger.L.Errorf("Failed to write file %s: %v", storagePath, err)
|
logger.L.Errorf("Failed to write file %s: %v", storagePath, err)
|
||||||
return ErrFailedToWriteFile
|
return ErrFailedToWriteFile
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -36,6 +36,7 @@ var StorageTypeDisplay = map[StorageType]string{
|
|||||||
|
|
||||||
type Task struct {
|
type Task struct {
|
||||||
Ctx context.Context
|
Ctx context.Context
|
||||||
|
Cancel context.CancelFunc
|
||||||
Error error
|
Error error
|
||||||
Status TaskStatus
|
Status TaskStatus
|
||||||
File *File
|
File *File
|
||||||
@@ -52,6 +53,10 @@ type Task struct {
|
|||||||
UserID int64
|
UserID int64
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t Task) Key() string {
|
||||||
|
return fmt.Sprintf("%d:%d", t.FileChatID, t.FileMessageID)
|
||||||
|
}
|
||||||
|
|
||||||
func (t Task) String() string {
|
func (t Task) String() string {
|
||||||
return fmt.Sprintf("[%d:%d]:%s", t.FileChatID, t.FileMessageID, t.File.FileName)
|
return fmt.Sprintf("[%d:%d]:%s", t.FileChatID, t.FileMessageID, t.File.FileName)
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user