Compare commits
12 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6d315f7af2 | ||
|
|
5352491c76 | ||
|
|
3f914f7a64 | ||
|
|
8972d8a169 | ||
|
|
1339c69dbf | ||
|
|
63aeabb39b | ||
|
|
e60e983229 | ||
|
|
75e5fd10ea | ||
|
|
c8d8a2e0eb | ||
|
|
044e732084 | ||
|
|
0e951f641c | ||
|
|
8dd6265d55 |
6
.github/ISSUE_TEMPLATE/bug.yml
vendored
6
.github/ISSUE_TEMPLATE/bug.yml
vendored
@@ -5,6 +5,12 @@ labels:
|
||||
assignees:
|
||||
- krau
|
||||
body:
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: |
|
||||
# Please Search Before Submitting / 提交前请搜索
|
||||
Please make sure to search existing issues before submitting a new bug report.
|
||||
提交新的 Bug 报告前请务必搜索已有的 issue,避免重复
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: "👾 Description"
|
||||
|
||||
66
.github/ISSUE_TEMPLATE/feature.yml
vendored
66
.github/ISSUE_TEMPLATE/feature.yml
vendored
@@ -8,7 +8,69 @@ body:
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: |
|
||||
# Please describe the feature you want in detail
|
||||
Please describe the feature you want in detail.
|
||||
请详细描述你想要的功能。
|
||||
|
||||
---
|
||||
|
||||
## ⚠️ IMPORTANT NOTICE / 说明
|
||||
|
||||
Save Any Bot supports multiple storage backends, **including Telegram**.
|
||||
However, **all backends are treated equally**, keep this in mind when submitting feature requests.
|
||||
|
||||
Save Any Bot 支持多种存储后端,**包括 Telegram**。
|
||||
但**所有后端在设计上是平等的**,请在提出功能请求前务必理解这一点。
|
||||
|
||||
### ❌ Out of scope requests / 不在项目范围内的请求
|
||||
The following requests are **out of scope** and will be closed without discussion:
|
||||
|
||||
以下请求**不属于本项目设计范围**,将被直接关闭,不再讨论:
|
||||
|
||||
- Adding **Telegram-specific behaviors or exceptions**
|
||||
添加 **仅针对 Telegram 的特殊行为或例外逻辑**
|
||||
- Treating Telegram as anything other than a **generic file storage backend**
|
||||
将 Telegram 视为非“通用文件存储后端”的特殊存在
|
||||
- Saving or syncing **non-file content** (text messages, chat history, etc.)
|
||||
保存或同步 **非文件内容**(文本消息、聊天记录等)
|
||||
- Preserving or reconstructing original messages (e.g. 1:1 forwarding)
|
||||
保留或还原原始消息形态(例如 1:1 转发)
|
||||
- Perform special reprocessing on files to adapt to specific storage backends
|
||||
(e.g. splitting, re-encoding, transforming, etc.)
|
||||
为适配特定存储后端而对文件进行特殊处理
|
||||
(如分割、转码、重编码、转换格式等)
|
||||
- Any request that requires different logic *only because the backend is Telegram*
|
||||
任何**仅因后端是 Telegram 而需要不同逻辑**的请求
|
||||
|
||||
### ❌ Abuse-leaning or high-risk requests / 滥用倾向的请求
|
||||
Requests that may **enable or encourage** the following will NOT be accepted:
|
||||
|
||||
可能**促成或鼓励**以下行为的请求将不会被接受:
|
||||
|
||||
- Violating Telegram Terms of Service
|
||||
违反 Telegram 服务条款
|
||||
- Building traffic, mirror, or profit-oriented channels using third-party content
|
||||
利用第三方内容构建引流、镜像或牟利用途的频道
|
||||
|
||||
### ⚖️ Design principle / 设计原则
|
||||
Save Any Bot follows a **backend-agnostic design**:
|
||||
|
||||
Save Any Bot 遵循 **后端无关(backend-agnostic)** 的设计原则:
|
||||
|
||||
- If a feature cannot be implemented **uniformly across all backends**, it will not be added.
|
||||
如果某个功能无法在 **所有后端** 中统一实现,则不会被添加。
|
||||
- No backend-specific hacks or special cases will be introduced.
|
||||
不会引入任何后端特有的 hack 或特殊处理逻辑。
|
||||
|
||||
---
|
||||
|
||||
If your request falls into any of the categories above, please do not open an issue.
|
||||
Such issues will be closed.
|
||||
|
||||
如果你的请求符合以上任一情况,请不要提交 issue,
|
||||
相关 issue 将被直接关闭。
|
||||
|
||||
Thank you for respecting the scope and design principles of this project.
|
||||
感谢你的理解与支持。
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: "⭐️ Feature description"
|
||||
@@ -30,4 +92,4 @@ body:
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: |
|
||||
## Thank you for contributing to the project :slightly_smiling_face:
|
||||
## Thank you for contributing to the project :slightly_smiling_face:
|
||||
|
||||
@@ -39,10 +39,11 @@
|
||||
Create a `config.toml` file with the following content:
|
||||
|
||||
```toml
|
||||
lang = "en" # Language setting, "en" for English
|
||||
[telegram]
|
||||
token = "" # Your bot token, obtained from @BotFather
|
||||
[telegram.proxy]
|
||||
# Enable proxy for Telegram, currently only SOCKS5 is supported
|
||||
# Enable proxy for Telegram
|
||||
enable = false
|
||||
url = "socks5://127.0.0.1:7890"
|
||||
|
||||
|
||||
@@ -40,7 +40,7 @@
|
||||
[telegram]
|
||||
token = "" # 你的 Bot Token, 在 @BotFather 获取
|
||||
[telegram.proxy]
|
||||
# 启用代理连接 telegram, 当前只支持 socks5
|
||||
# 启用代理连接 telegram
|
||||
enable = false
|
||||
url = "socks5://127.0.0.1:7890"
|
||||
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"github.com/gotd/td/tg"
|
||||
"github.com/krau/SaveAny-Bot/common/i18n"
|
||||
"github.com/krau/SaveAny-Bot/common/i18n/i18nk"
|
||||
"github.com/krau/SaveAny-Bot/config"
|
||||
"github.com/krau/SaveAny-Bot/database"
|
||||
"github.com/krau/SaveAny-Bot/pkg/enums/fnamest"
|
||||
"github.com/krau/SaveAny-Bot/pkg/tcbdata"
|
||||
@@ -73,9 +74,9 @@ func handleConfigFnameSTCallback(ctx *ext.Context, update *ext.Update) error {
|
||||
return err
|
||||
}
|
||||
ctx.EditMessage(userID, &tg.MessagesEditMessageRequest{
|
||||
ID: update.CallbackQuery.GetMsgID(),
|
||||
ID: update.CallbackQuery.GetMsgID(),
|
||||
Message: i18n.T(i18nk.BotMsgConfigInfoFilenameStrategySet, map[string]any{
|
||||
"Strategy": fnamest.FnameSTDisplay[st],
|
||||
"Strategy": fnamest.GetDisplay(st, config.C().Lang),
|
||||
}),
|
||||
})
|
||||
return dispatcher.EndGroups
|
||||
@@ -84,7 +85,7 @@ func handleConfigFnameSTCallback(ctx *ext.Context, update *ext.Update) error {
|
||||
buttons := make([]tg.KeyboardButtonClass, 0, len(opts))
|
||||
for _, opt := range opts {
|
||||
buttons = append(buttons, &tg.KeyboardButtonCallback{
|
||||
Text: fnamest.FnameSTDisplay[opt],
|
||||
Text: fnamest.GetDisplay(opt, config.C().Lang),
|
||||
Data: fmt.Appendf(nil, "%s %s %s", tcbdata.TypeConfig, "fnamest", opt),
|
||||
})
|
||||
}
|
||||
@@ -100,9 +101,9 @@ func handleConfigFnameSTCallback(ctx *ext.Context, update *ext.Update) error {
|
||||
currentSt = fnamest.Default
|
||||
}
|
||||
ctx.EditMessage(userID, &tg.MessagesEditMessageRequest{
|
||||
ID: update.CallbackQuery.GetMsgID(),
|
||||
Message: i18n.T(i18nk.BotMsgConfigPromptSelectFilenameStrategy, map[string]any{
|
||||
"Strategy": fnamest.FnameSTDisplay[currentSt],
|
||||
ID: update.CallbackQuery.GetMsgID(),
|
||||
Message: i18n.T(i18nk.BotMsgConfigPromptSelectFilenameStrategy, map[string]any{
|
||||
"Strategy": fnamest.GetDisplay(currentSt, config.C().Lang),
|
||||
}),
|
||||
ReplyMarkup: markup,
|
||||
})
|
||||
|
||||
@@ -3,6 +3,7 @@ package handlers
|
||||
import (
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/celestix/gotgproto/ext"
|
||||
"github.com/charmbracelet/log"
|
||||
@@ -10,6 +11,8 @@ import (
|
||||
"github.com/krau/SaveAny-Bot/client/bot/handlers/utils/msgelem"
|
||||
"github.com/krau/SaveAny-Bot/common/i18n"
|
||||
"github.com/krau/SaveAny-Bot/common/i18n/i18nk"
|
||||
"github.com/krau/SaveAny-Bot/config"
|
||||
"github.com/krau/SaveAny-Bot/pkg/aria2"
|
||||
"github.com/krau/SaveAny-Bot/pkg/enums/tasktype"
|
||||
"github.com/krau/SaveAny-Bot/pkg/tcbdata"
|
||||
"github.com/krau/SaveAny-Bot/storage"
|
||||
@@ -50,3 +53,53 @@ func handleDlCmd(ctx *ext.Context, update *ext.Update) error {
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
var aria2ClientInitOnce sync.Once
|
||||
var aria2ClientInitErr error
|
||||
var aria2Client *aria2.Client
|
||||
|
||||
func handleAria2DlCmd(ctx *ext.Context, update *ext.Update) error {
|
||||
if !config.C().Aria2.Enable {
|
||||
ctx.Reply(update, ext.ReplyTextString(i18n.T(i18nk.BotMsgAria2ErrorAria2NotEnabled)), nil)
|
||||
return nil
|
||||
}
|
||||
logger := log.FromContext(ctx)
|
||||
args := strings.Split(update.EffectiveMessage.Text, " ")
|
||||
if len(args) < 2 {
|
||||
ctx.Reply(update, ext.ReplyTextString(i18n.T(i18nk.BotMsgDlUsage)), nil)
|
||||
return nil
|
||||
}
|
||||
links := args[1:]
|
||||
for i, link := range links {
|
||||
links[i] = strings.TrimSpace(link)
|
||||
}
|
||||
links = slice.Compact(links)
|
||||
if len(links) == 0 {
|
||||
ctx.Reply(update, ext.ReplyTextString(i18n.T(i18nk.BotMsgDlErrorNoValidLinks)), nil)
|
||||
return nil
|
||||
}
|
||||
logger.Debug("Adding aria2 download", "links", links)
|
||||
aria2ClientInitOnce.Do(func() {
|
||||
aria2Client, aria2ClientInitErr = aria2.NewClient(config.C().Aria2.Url, config.C().Aria2.Secret)
|
||||
})
|
||||
if aria2ClientInitErr != nil {
|
||||
logger.Error("Failed to initialize aria2 client", "error", aria2ClientInitErr)
|
||||
ctx.Reply(update, ext.ReplyTextString(i18n.T(i18nk.BotMsgAria2ErrorAria2ClientInitFailed, map[string]any{
|
||||
"Error": aria2ClientInitErr.Error(),
|
||||
})), nil)
|
||||
return nil
|
||||
}
|
||||
gid, err := aria2Client.AddURI(ctx, links, nil)
|
||||
if err != nil {
|
||||
logger.Error("Failed to add aria2 download", "error", err)
|
||||
ctx.Reply(update, ext.ReplyTextString(i18n.T(i18nk.BotMsgAria2ErrorAddingAria2Download, map[string]any{
|
||||
"Error": err.Error(),
|
||||
})), nil)
|
||||
return nil
|
||||
}
|
||||
logger.Info("Aria2 download added", "gid", gid)
|
||||
ctx.Reply(update, ext.ReplyTextString(i18n.T(i18nk.BotMsgAria2InfoAria2DownloadAdded, map[string]any{
|
||||
"GID": gid,
|
||||
})), nil)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -29,15 +29,17 @@ var CommandHandlers = []DescCommandHandler{
|
||||
{"rule", i18nk.BotMsgCmdRule, handleRuleCmd},
|
||||
{"save", i18nk.BotMsgCmdSave, handleSilentMode(handleSaveCmd, handleSilentSaveReplied)},
|
||||
{"dl", i18nk.BotMsgCmdDl, handleDlCmd},
|
||||
{"aria2dl", i18nk.BotMsgCmdAria2dl, handleAria2DlCmd},
|
||||
{"task", i18nk.BotMsgCmdTask, handleTaskCmd},
|
||||
{"cancel", i18nk.BotMsgCmdCancel, handleCancelCmd},
|
||||
{"watch", i18nk.BotMsgCmdWatch, handleWatchCmd},
|
||||
{"unwatch", i18nk.BotMsgCmdUnwatch, handleUnwatchCmd},
|
||||
{"lswatch", i18nk.BotMsgCmdLswatch, handleLswatchCmd},
|
||||
{"config", i18nk.BotMsgCmdConfig, handleConfigCmd},
|
||||
{"fnametmpl", i18nk.BotMsgCmdFnametmpl, handleConfigFnameTmpl},
|
||||
{"help", i18nk.BotMsgCmdHelp, handleHelpCmd},
|
||||
{"parser", i18nk.BotMsgCmdParser, handleParserCmd},
|
||||
{"watch", i18nk.BotMsgCmdWatch, handleWatchCmd},
|
||||
{"unwatch", i18nk.BotMsgCmdUnwatch, handleUnwatchCmd},
|
||||
{"lswatch", i18nk.BotMsgCmdLswatch, handleLswatchCmd},
|
||||
{"syncpeers", i18nk.BotMsgCmdSyncpeers, handleSyncpeersCmd},
|
||||
{"update", i18nk.BotMsgCmdUpdate, handleUpdateCmd},
|
||||
}
|
||||
|
||||
|
||||
62
client/bot/handlers/sync_peers.go
Normal file
62
client/bot/handlers/sync_peers.go
Normal file
@@ -0,0 +1,62 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
|
||||
"github.com/celestix/gotgproto/dispatcher"
|
||||
"github.com/celestix/gotgproto/ext"
|
||||
"github.com/celestix/gotgproto/storage"
|
||||
"github.com/charmbracelet/log"
|
||||
"github.com/gotd/td/telegram/query/dialogs"
|
||||
"github.com/krau/SaveAny-Bot/client/user"
|
||||
"github.com/krau/SaveAny-Bot/common/i18n"
|
||||
"github.com/krau/SaveAny-Bot/common/i18n/i18nk"
|
||||
"github.com/krau/SaveAny-Bot/config"
|
||||
)
|
||||
|
||||
var syncpeerMu sync.Mutex
|
||||
|
||||
func handleSyncpeersCmd(ctx *ext.Context, u *ext.Update) error {
|
||||
if !config.C().Telegram.Userbot.Enable {
|
||||
return dispatcher.EndGroups
|
||||
}
|
||||
syncpeerMu.Lock()
|
||||
defer syncpeerMu.Unlock()
|
||||
uctx := user.GetCtx()
|
||||
if uctx == nil {
|
||||
return dispatcher.EndGroups
|
||||
}
|
||||
ctx.Reply(u, ext.ReplyTextString(i18n.T(i18nk.BotMsgSyncpeersStart)), nil)
|
||||
tapi := uctx.Raw
|
||||
peerStorage := uctx.PeerStorage
|
||||
log.FromContext(ctx).Info("Starting to sync peers...")
|
||||
count := 0
|
||||
err := dialogs.NewQueryBuilder(tapi).GetDialogs().BatchSize(50).ForEach(ctx, func(ctx context.Context, e dialogs.Elem) error {
|
||||
for cid, channel := range e.Entities.Channels() {
|
||||
peerStorage.AddPeer(cid, channel.AccessHash, storage.TypeChannel, channel.Username)
|
||||
count++
|
||||
}
|
||||
for uid, user := range e.Entities.Users() {
|
||||
peerStorage.AddPeer(uid, user.AccessHash, storage.TypeUser, user.Username)
|
||||
count++
|
||||
}
|
||||
for gid := range e.Entities.Chats() {
|
||||
peerStorage.AddPeer(gid, storage.DefaultAccessHash, storage.TypeChat, storage.DefaultUsername)
|
||||
count++
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
log.FromContext(ctx).Error("Failed to sync peers", "error", err)
|
||||
ctx.Reply(u, ext.ReplyTextString(i18n.T(i18nk.BotMsgSyncpeersFailed, map[string]any{
|
||||
"Error": err.Error(),
|
||||
})), nil)
|
||||
return dispatcher.EndGroups
|
||||
}
|
||||
log.FromContext(ctx).Info("Finished syncing peers")
|
||||
ctx.Reply(u, ext.ReplyTextString(i18n.T(i18nk.BotMsgSyncpeersSuccess, map[string]any{
|
||||
"Count": count,
|
||||
})), nil)
|
||||
return dispatcher.EndGroups
|
||||
}
|
||||
@@ -30,6 +30,7 @@ type FilenameTemplateData struct {
|
||||
MsgTags string `json:"msgtags,omitempty"`
|
||||
MsgGen string `json:"msggen,omitempty"`
|
||||
MsgDate string `json:"msgdate,omitempty"`
|
||||
MsgRaw string `json:"msgraw,omitempty"`
|
||||
OrigName string `json:"origname,omitempty"`
|
||||
ChatID string `json:"chatid,omitempty"`
|
||||
}
|
||||
@@ -39,6 +40,7 @@ func (f FilenameTemplateData) ToMap() map[string]string {
|
||||
"msgid": f.MsgID,
|
||||
"msgtags": f.MsgTags,
|
||||
"msggen": f.MsgGen,
|
||||
"msgraw": f.MsgRaw,
|
||||
"msgdate": f.MsgDate,
|
||||
"origname": f.OrigName,
|
||||
"chatid": f.ChatID,
|
||||
@@ -108,8 +110,10 @@ func BuildFilenameTemplateData(message *tg.Message) map[string]string {
|
||||
t := time.Unix(int64(date), 0)
|
||||
return t.Format("2006-01-02_15-04-05")
|
||||
}(),
|
||||
MsgRaw: message.GetMessage(),
|
||||
ChatID: func() string {
|
||||
// 如果消息是频道的(从消息链接中fetch的) 直接使用其chat id, 无论它是否是从其他来源转发的
|
||||
// 如果消息是频道的(从消息链接中fetch的) 直接使用其chat id,
|
||||
// 无论它是否是从其他来源转发的
|
||||
if message.GetPost() {
|
||||
peer := message.GetPeerID()
|
||||
switch p := peer.(type) {
|
||||
|
||||
@@ -37,7 +37,7 @@ func main() {
|
||||
return err
|
||||
}
|
||||
|
||||
var content map[string]interface{}
|
||||
var content map[string]any
|
||||
if err := yaml.Unmarshal(data, &content); err != nil {
|
||||
return fmt.Errorf("failed to parse yaml %s: %w", path, err)
|
||||
}
|
||||
|
||||
@@ -4,10 +4,16 @@ package i18nk
|
||||
type Key string
|
||||
|
||||
const (
|
||||
BotMsgAria2ErrorAddingAria2Download Key = "bot.msg.aria2.error_adding_aria2_download"
|
||||
BotMsgAria2ErrorAria2ClientInitFailed Key = "bot.msg.aria2.error_aria2_client_init_failed"
|
||||
BotMsgAria2ErrorAria2NotEnabled Key = "bot.msg.aria2.error_aria2_not_enabled"
|
||||
BotMsgAria2InfoAddingAria2Download Key = "bot.msg.aria2.info_adding_aria2_download"
|
||||
BotMsgAria2InfoAria2DownloadAdded Key = "bot.msg.aria2.info_aria2_download_added"
|
||||
BotMsgCancelErrorCancelFailed Key = "bot.msg.cancel.error_cancel_failed"
|
||||
BotMsgCancelInfoCancelRequested Key = "bot.msg.cancel.info_cancel_requested"
|
||||
BotMsgCancelInfoCancellingTask Key = "bot.msg.cancel.info_cancelling_task"
|
||||
BotMsgCancelUsage Key = "bot.msg.cancel.usage"
|
||||
BotMsgCmdAria2dl Key = "bot.msg.cmd.aria2dl"
|
||||
BotMsgCmdCancel Key = "bot.msg.cmd.cancel"
|
||||
BotMsgCmdConfig Key = "bot.msg.cmd.config"
|
||||
BotMsgCmdDir Key = "bot.msg.cmd.dir"
|
||||
@@ -21,6 +27,7 @@ const (
|
||||
BotMsgCmdSilent Key = "bot.msg.cmd.silent"
|
||||
BotMsgCmdStart Key = "bot.msg.cmd.start"
|
||||
BotMsgCmdStorage Key = "bot.msg.cmd.storage"
|
||||
BotMsgCmdSyncpeers Key = "bot.msg.cmd.syncpeers"
|
||||
BotMsgCmdTask Key = "bot.msg.cmd.task"
|
||||
BotMsgCmdUnwatch Key = "bot.msg.cmd.unwatch"
|
||||
BotMsgCmdUpdate Key = "bot.msg.cmd.update"
|
||||
@@ -170,6 +177,10 @@ const (
|
||||
BotMsgSaveHelpText Key = "bot.msg.save_help_text"
|
||||
BotMsgStorageInfoFilenamePrefix Key = "bot.msg.storage.info_filename_prefix"
|
||||
BotMsgStorageInfoPromptSelectStorage Key = "bot.msg.storage.info_prompt_select_storage"
|
||||
BotMsgSyncpeersDone Key = "bot.msg.syncpeers.done"
|
||||
BotMsgSyncpeersFailed Key = "bot.msg.syncpeers.failed"
|
||||
BotMsgSyncpeersStart Key = "bot.msg.syncpeers.start"
|
||||
BotMsgSyncpeersSuccess Key = "bot.msg.syncpeers.success"
|
||||
BotMsgTasksCancelFailed Key = "bot.msg.tasks.cancel_failed"
|
||||
BotMsgTasksCancelRequestedPrefix Key = "bot.msg.tasks.cancel_requested_prefix"
|
||||
BotMsgTasksFieldCreated Key = "bot.msg.tasks.field_created"
|
||||
|
||||
@@ -38,6 +38,7 @@ bot:
|
||||
/watch - Watch chats and auto save (UserBot)
|
||||
/unwatch - Stop watching chats (UserBot)
|
||||
/lswatch - List watched chats (UserBot)
|
||||
/syncpeers - Sync peer chats (UserBot)
|
||||
/update - Check and upgrade to latest version
|
||||
|
||||
Usage guide: https://sabot.unv.app/usage
|
||||
@@ -59,6 +60,7 @@ bot:
|
||||
help: "Show help"
|
||||
parser: "Manage parsers"
|
||||
update: "Check for updates"
|
||||
syncpeers: "Sync peer chats (UserBot)"
|
||||
save_help_text: |
|
||||
Usage:
|
||||
|
||||
@@ -272,6 +274,7 @@ bot:
|
||||
- {{"{{.msgtags}}"}}: Tags in the message, joined with underscore
|
||||
- {{"{{.msggen}}"}}: Generated filename from the message
|
||||
- {{"{{.msgdate}}"}}: Message date, format YYYY-MM-DD_HH-MM-SS
|
||||
- {{"{{.msgraw}}"}}: Raw message text (unprocessed)
|
||||
- {{"{{.origname}}"}}: Original media filename (if any)
|
||||
- {{"{{.chatid}}"}}: Chat ID of the message
|
||||
|
||||
@@ -323,3 +326,7 @@ bot:
|
||||
direct_start: "Starting download, total size: {{.SizeMB}} MB ({{.Count}} files)"
|
||||
file_name_prefix: "Filename: "
|
||||
error_prefix: "\nError: "
|
||||
syncpeers:
|
||||
start: "Starting to sync peers..."
|
||||
done: "Peer sync completed, total {{.Count}} chats synced"
|
||||
failed: "Peer sync failed: {{.Error}}"
|
||||
|
||||
@@ -29,6 +29,7 @@ bot:
|
||||
/silent - 开关静默模式
|
||||
/storage - 设置默认存储位置
|
||||
/save [自定义文件名] - 保存文件
|
||||
/dl <链接1> <链接2> ... - 下载给定链接的文件
|
||||
/dir - 管理存储目录
|
||||
/rule - 管理规则
|
||||
/config - 修改配置
|
||||
@@ -38,6 +39,7 @@ bot:
|
||||
/watch - 监听聊天并自动保存 (UserBot)
|
||||
/unwatch - 取消监听聊天 (UserBot)
|
||||
/lswatch - 列出正在监听的聊天 (UserBot)
|
||||
/syncpeers - 同步对话列表 (UserBot)
|
||||
/update - 检查更新并升级
|
||||
|
||||
使用帮助: https://sabot.unv.app/usage
|
||||
@@ -49,11 +51,13 @@ bot:
|
||||
rule: "管理自动存储规则"
|
||||
save: "保存文件"
|
||||
dl: "下载给定链接的文件"
|
||||
aria2dl: "使用 Aria2 下载给定链接的文件"
|
||||
task: "管理任务队列"
|
||||
cancel: "取消任务"
|
||||
watch: "监听聊天(UserBot)"
|
||||
unwatch: "取消监听聊天(UserBot)"
|
||||
lswatch: "列出监听的聊天(UserBot)"
|
||||
syncpeers: "同步对话列表(UserBot)"
|
||||
config: "修改配置"
|
||||
fnametmpl: "设置文件命名模板"
|
||||
help: "显示帮助"
|
||||
@@ -272,6 +276,7 @@ bot:
|
||||
- {{"{{.msgtags}}"}}: 消息中的标签, 将以下划线分隔输出
|
||||
- {{"{{.msggen}}"}}: 根据消息生成的文件名
|
||||
- {{"{{.msgdate}}"}}: 消息日期, 格式 YYYY-MM-DD_HH-MM-SS
|
||||
- {{"{{.msgraw}}"}}: 消息的原始文本内容 (不经任何处理)
|
||||
- {{"{{.origname}}"}}: 媒体的原始文件名 (如果有)
|
||||
- {{"{{.chatid}}"}}: 消息的聊天ID
|
||||
|
||||
@@ -323,3 +328,13 @@ bot:
|
||||
direct_start: "开始下载, 总大小: {{.SizeMB}} MB ({{.Count}} 个文件)"
|
||||
file_name_prefix: "文件名: "
|
||||
error_prefix: "\n错误: "
|
||||
syncpeers:
|
||||
start: "正在同步对话列表..."
|
||||
success: "对话列表同步完成, 共同步 {{.Count}} 个对话"
|
||||
failed: "对话列表同步失败: {{.Error}}"
|
||||
aria2:
|
||||
error_aria2_not_enabled: "Aria2 功能未启用, 请在配置文件中启用"
|
||||
error_aria2_client_init_failed: "Aria2 客户端初始化失败: {{.Error}}"
|
||||
info_adding_aria2_download: "正在添加 Aria2 下载任务..."
|
||||
error_adding_aria2_download: "添加 Aria2 下载任务失败: {{.Error}}"
|
||||
info_aria2_download_added: "Aria2 下载任务已添加, GID: {{.GID}}"
|
||||
|
||||
@@ -113,29 +113,28 @@ func InputMessageClassSliceFromInt(ids []int) []tg.InputMessageClass {
|
||||
return result
|
||||
}
|
||||
|
||||
func GetMessagesRange(ctx *ext.Context, chatID int64, minId, maxId int) ([]*tg.Message, error) {
|
||||
if msg, err := getMessagesRange(ctx, chatID, minId, maxId); err == nil {
|
||||
return msg, nil
|
||||
func GetMessagesRange(ctx *ext.Context, chatID int64, minId, maxId int) (msg []*tg.Message, err error) {
|
||||
if msg, err = getMessagesRange(ctx, chatID, minId, maxId); err == nil {
|
||||
return
|
||||
}
|
||||
in := constant.TDLibPeerID(chatID)
|
||||
plain := in.ToPlain()
|
||||
|
||||
var channel constant.TDLibPeerID
|
||||
channel.Channel(plain)
|
||||
if msg, err := getMessagesRange(ctx, int64(channel), minId, maxId); err == nil {
|
||||
return msg, nil
|
||||
if msg, err = getMessagesRange(ctx, int64(channel), minId, maxId); err == nil {
|
||||
return
|
||||
}
|
||||
var userID constant.TDLibPeerID
|
||||
userID.User(plain)
|
||||
if msg, err := getMessagesRange(ctx, int64(userID), minId, maxId); err == nil {
|
||||
return msg, nil
|
||||
if msg, err = getMessagesRange(ctx, int64(userID), minId, maxId); err == nil {
|
||||
return
|
||||
}
|
||||
var chat constant.TDLibPeerID
|
||||
chat.Chat(plain)
|
||||
if msg, err := getMessagesRange(ctx, int64(chat), minId, maxId); err == nil {
|
||||
return msg, nil
|
||||
if msg, err = getMessagesRange(ctx, int64(chat), minId, maxId); err == nil {
|
||||
return
|
||||
}
|
||||
return nil, fmt.Errorf("failed to get messages range for chatID %d", chatID)
|
||||
return nil, fmt.Errorf("failed to get messages range for chat %d: %w", chatID, err)
|
||||
}
|
||||
|
||||
func getMessagesRange(ctx *ext.Context, chatID int64, minId, maxId int) ([]*tg.Message, error) {
|
||||
|
||||
@@ -1,6 +1,12 @@
|
||||
package tgutil
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
||||
"github.com/gotd/td/telegram/dcs"
|
||||
@@ -8,24 +14,108 @@ import (
|
||||
"golang.org/x/net/proxy"
|
||||
)
|
||||
|
||||
func newProxyDialer(proxyUrl string) (proxy.Dialer, error) {
|
||||
url, err := url.Parse(proxyUrl)
|
||||
// httpProxyDialer implements proxy.ContextDialer for HTTP CONNECT proxies
|
||||
type httpProxyDialer struct {
|
||||
proxyURL *url.URL
|
||||
forward proxy.Dialer
|
||||
}
|
||||
|
||||
func (d *httpProxyDialer) Dial(network, addr string) (net.Conn, error) {
|
||||
return d.DialContext(context.Background(), network, addr)
|
||||
}
|
||||
|
||||
func (d *httpProxyDialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
proxyAddr := d.proxyURL.Host
|
||||
if d.proxyURL.Port() == "" {
|
||||
if d.proxyURL.Scheme == "https" {
|
||||
proxyAddr = net.JoinHostPort(d.proxyURL.Hostname(), "443")
|
||||
} else {
|
||||
proxyAddr = net.JoinHostPort(d.proxyURL.Hostname(), "80")
|
||||
}
|
||||
}
|
||||
|
||||
var conn net.Conn
|
||||
var err error
|
||||
if ctxDialer, ok := d.forward.(proxy.ContextDialer); ok {
|
||||
conn, err = ctxDialer.DialContext(ctx, "tcp", proxyAddr)
|
||||
} else {
|
||||
conn, err = d.forward.Dial("tcp", proxyAddr)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to connect to proxy: %w", err)
|
||||
}
|
||||
|
||||
// Send CONNECT request
|
||||
connectReq := &http.Request{
|
||||
Method: "CONNECT",
|
||||
URL: &url.URL{Opaque: addr},
|
||||
Host: addr,
|
||||
Header: make(http.Header),
|
||||
}
|
||||
|
||||
// Add proxy authentication if provided
|
||||
if d.proxyURL.User != nil {
|
||||
username := d.proxyURL.User.Username()
|
||||
password, _ := d.proxyURL.User.Password()
|
||||
auth := base64.StdEncoding.EncodeToString([]byte(username + ":" + password))
|
||||
connectReq.Header.Set("Proxy-Authorization", "Basic "+auth)
|
||||
}
|
||||
|
||||
if err := connectReq.Write(conn); err != nil {
|
||||
conn.Close()
|
||||
return nil, fmt.Errorf("failed to write CONNECT request: %w", err)
|
||||
}
|
||||
|
||||
// Read response
|
||||
br := bufio.NewReader(conn)
|
||||
resp, err := http.ReadResponse(br, connectReq)
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
return nil, fmt.Errorf("failed to read CONNECT response: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
conn.Close()
|
||||
return nil, fmt.Errorf("proxy CONNECT failed with status: %s", resp.Status)
|
||||
}
|
||||
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
func newProxyDialer(proxyUrl string) (proxy.ContextDialer, error) {
|
||||
parsedURL, err := url.Parse(proxyUrl)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return proxy.FromURL(url, proxy.Direct)
|
||||
|
||||
switch parsedURL.Scheme {
|
||||
case "http", "https":
|
||||
return &httpProxyDialer{
|
||||
proxyURL: parsedURL,
|
||||
forward: proxy.Direct,
|
||||
}, nil
|
||||
case "socks5", "socks5h":
|
||||
dialer, err := proxy.FromURL(parsedURL, proxy.Direct)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return dialer.(proxy.ContextDialer), nil
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported proxy scheme: %s", parsedURL.Scheme)
|
||||
}
|
||||
}
|
||||
|
||||
func NewConfigProxyResolver() (dcs.Resolver, error) {
|
||||
resolver := dcs.DefaultResolver()
|
||||
if config.C().Proxy != "" {
|
||||
// gloabl proxy, which has lower priority
|
||||
// global proxy, which has lower priority
|
||||
dialer, err := newProxyDialer(config.C().Proxy)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resolver = dcs.Plain(dcs.PlainOptions{
|
||||
Dial: dialer.(proxy.ContextDialer).DialContext,
|
||||
Dial: dialer.DialContext,
|
||||
})
|
||||
}
|
||||
if config.C().Telegram.Proxy.Enable && config.C().Telegram.Proxy.URL != "" {
|
||||
@@ -34,7 +124,7 @@ func NewConfigProxyResolver() (dcs.Resolver, error) {
|
||||
return nil, err
|
||||
}
|
||||
resolver = dcs.Plain(dcs.PlainOptions{
|
||||
Dial: dialer.(proxy.ContextDialer).DialContext,
|
||||
Dial: dialer.DialContext,
|
||||
})
|
||||
}
|
||||
return resolver, nil
|
||||
|
||||
@@ -14,7 +14,7 @@ token = ""
|
||||
# app_id = 1025907
|
||||
# app_hash = "452b0359b988148995f22ff0f4229750"
|
||||
[telegram.proxy]
|
||||
# 启用代理连接 telegram, 只支持 socks5
|
||||
# 启用代理连接 telegram
|
||||
enable = false
|
||||
url = "socks5://127.0.0.1:7890"
|
||||
|
||||
|
||||
@@ -16,13 +16,14 @@ import (
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
Lang string `toml:"lang" mapstructure:"lang" json:"lang"`
|
||||
Workers int `toml:"workers" mapstructure:"workers"`
|
||||
Retry int `toml:"retry" mapstructure:"retry"`
|
||||
NoCleanCache bool `toml:"no_clean_cache" mapstructure:"no_clean_cache" json:"no_clean_cache"`
|
||||
Threads int `toml:"threads" mapstructure:"threads" json:"threads"`
|
||||
Stream bool `toml:"stream" mapstructure:"stream" json:"stream"`
|
||||
Proxy string `toml:"proxy" mapstructure:"proxy" json:"proxy"`
|
||||
Lang string `toml:"lang" mapstructure:"lang" json:"lang"`
|
||||
Workers int `toml:"workers" mapstructure:"workers"`
|
||||
Retry int `toml:"retry" mapstructure:"retry"`
|
||||
NoCleanCache bool `toml:"no_clean_cache" mapstructure:"no_clean_cache" json:"no_clean_cache"`
|
||||
Threads int `toml:"threads" mapstructure:"threads" json:"threads"`
|
||||
Stream bool `toml:"stream" mapstructure:"stream" json:"stream"`
|
||||
Proxy string `toml:"proxy" mapstructure:"proxy" json:"proxy"`
|
||||
Aria2 aria2Config `toml:"aria2" mapstructure:"aria2" json:"aria2"`
|
||||
|
||||
Cache cacheConfig `toml:"cache" mapstructure:"cache" json:"cache"`
|
||||
Users []userConfig `toml:"users" mapstructure:"users" json:"users"`
|
||||
@@ -34,6 +35,12 @@ type Config struct {
|
||||
Hook hookConfig `toml:"hook" mapstructure:"hook" json:"hook"`
|
||||
}
|
||||
|
||||
type aria2Config struct {
|
||||
Enable bool `toml:"enable" mapstructure:"enable" json:"enable"`
|
||||
Url string `toml:"url" mapstructure:"url" json:"url"`
|
||||
Secret string `toml:"secret" mapstructure:"secret" json:"secret"`
|
||||
}
|
||||
|
||||
var cfg = &Config{}
|
||||
|
||||
func C() Config {
|
||||
|
||||
@@ -30,6 +30,7 @@ base_path = "./downloads"
|
||||
|
||||
### Global Configuration
|
||||
|
||||
- `lang`: The language used by the Bot, default is `zh-CN` (Simplified Chinese). `en` is used for English.
|
||||
- `stream`: Whether to enable Stream mode, default is `false`. When enabled, the Bot will stream files directly to storage endpoints (if supported), without downloading them locally.
|
||||
{{< hint warning >}}
|
||||
Stream mode is very useful for deployment environments with limited disk space, but it also has some drawbacks:
|
||||
@@ -47,6 +48,7 @@ Stream mode is very useful for deployment environments with limited disk space,
|
||||
- `proxy`: Global proxy configuration. After setting this, all network connections inside the program will try to use this proxy. Optional.
|
||||
|
||||
```toml
|
||||
lang = "en"
|
||||
stream = false
|
||||
workers = 3
|
||||
threads = 4
|
||||
@@ -62,7 +64,7 @@ proxy = "socks5://127.0.0.1:7890"
|
||||
- `rpc_retry`: Number of retries for RPC requests, default is 5.
|
||||
- `proxy`: Proxy configuration, optional.
|
||||
- `enable`: Whether to enable the proxy.
|
||||
- `url`: Proxy address, only supports `socks5://`
|
||||
- `url`: Proxy address
|
||||
- `userbot`: Userbot configuration, optional.
|
||||
- `enable`: Enable userbot integration. Requires logging in with a user account; you should use your own API ID & Hash when enabling this.
|
||||
- `session`: Path to the userbot session file, default is `data/usersession.db`.
|
||||
|
||||
@@ -4,7 +4,7 @@ title: "Installation and Updates"
|
||||
|
||||
# Installation and Updates
|
||||
|
||||
## Deploy from Pre-compiled Files (Recommended)
|
||||
## Deploy from Pre-compiled Binary (Recommended)
|
||||
|
||||
Download the binary file for your platform from the [Release](https://github.com/krau/SaveAny-Bot/releases) page.
|
||||
|
||||
@@ -17,7 +17,7 @@ chmod +x saveany-bot
|
||||
./saveany-bot
|
||||
```
|
||||
|
||||
### Process Monitoring
|
||||
### Daemon
|
||||
|
||||
{{< tabs "daemon" >}}
|
||||
{{< tab "systemd (Regular Linux)" >}}
|
||||
|
||||
@@ -62,7 +62,7 @@ proxy = "socks5://127.0.0.1:7890"
|
||||
- `rpc_retry`: RPC 请求重试次数, 默认为 5.
|
||||
- `proxy`: 代理配置, 可选.
|
||||
- `enable`: 是否启用代理.
|
||||
- `url`: 代理地址, 只支持 `socks5://`
|
||||
- `url`: 代理地址
|
||||
- `userbot`: userbot 配置, 可选.
|
||||
- `enable`: 启用 userbot 集成, 需要登录用户账号, 此时请务必使用自己的 api id & hash.
|
||||
- `session`: userbot 会话文件路径, 默认为 `data/usersession.db`.
|
||||
|
||||
@@ -46,8 +46,8 @@ func (k *KemonoParser) CanHandle(text string) bool {
|
||||
|
||||
var path string
|
||||
for _, domain := range kemonoDomains {
|
||||
if idx := strings.Index(text, domain); idx != -1 {
|
||||
remaining := text[idx+len(domain):]
|
||||
if _, after, ok := strings.Cut(text, domain); ok {
|
||||
remaining := after
|
||||
if len(remaining) > 0 && remaining[0] == '/' {
|
||||
path = remaining[1:]
|
||||
}
|
||||
|
||||
546
pkg/aria2/client.go
Normal file
546
pkg/aria2/client.go
Normal file
@@ -0,0 +1,546 @@
|
||||
package aria2
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrInvalidURL = errors.New("aria2: invalid URL")
|
||||
ErrRPCFailed = errors.New("aria2: RPC call failed")
|
||||
ErrInvalidResponse = errors.New("aria2: invalid response")
|
||||
)
|
||||
|
||||
// Client represents an aria2 JSON-RPC client
|
||||
type Client struct {
|
||||
url string
|
||||
secret string
|
||||
client *http.Client
|
||||
id atomic.Int64
|
||||
}
|
||||
|
||||
// rpcRequest represents a JSON-RPC 2.0 request
|
||||
type rpcRequest struct {
|
||||
Jsonrpc string `json:"jsonrpc"`
|
||||
ID string `json:"id"`
|
||||
Method string `json:"method"`
|
||||
Params []any `json:"params"`
|
||||
}
|
||||
|
||||
// rpcResponse represents a JSON-RPC 2.0 response
|
||||
type rpcResponse struct {
|
||||
Jsonrpc string `json:"jsonrpc"`
|
||||
ID string `json:"id"`
|
||||
Result json.RawMessage `json:"result,omitempty"`
|
||||
Error *rpcError `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
// rpcError represents a JSON-RPC 2.0 error
|
||||
type rpcError struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
func (e *rpcError) Error() string {
|
||||
return fmt.Sprintf("aria2 RPC error %d: %s", e.Code, e.Message)
|
||||
}
|
||||
|
||||
// Options for download
|
||||
type Options map[string]any
|
||||
|
||||
// Status represents the status of a download
|
||||
type Status struct {
|
||||
GID string `json:"gid"`
|
||||
Status string `json:"status"`
|
||||
TotalLength string `json:"totalLength"`
|
||||
CompletedLength string `json:"completedLength"`
|
||||
UploadLength string `json:"uploadLength"`
|
||||
Bitfield string `json:"bitfield,omitempty"`
|
||||
DownloadSpeed string `json:"downloadSpeed"`
|
||||
UploadSpeed string `json:"uploadSpeed"`
|
||||
InfoHash string `json:"infoHash,omitempty"`
|
||||
NumSeeders string `json:"numSeeders,omitempty"`
|
||||
Seeder string `json:"seeder,omitempty"`
|
||||
PieceLength string `json:"pieceLength,omitempty"`
|
||||
NumPieces string `json:"numPieces,omitempty"`
|
||||
Connections string `json:"connections"`
|
||||
ErrorCode string `json:"errorCode,omitempty"`
|
||||
ErrorMessage string `json:"errorMessage,omitempty"`
|
||||
FollowedBy []string `json:"followedBy,omitempty"`
|
||||
Following string `json:"following,omitempty"`
|
||||
BelongsTo string `json:"belongsTo,omitempty"`
|
||||
Dir string `json:"dir"`
|
||||
Files []File `json:"files"`
|
||||
BitTorrent struct {
|
||||
AnnounceList [][]string `json:"announceList,omitempty"`
|
||||
Comment string `json:"comment,omitempty"`
|
||||
CreationDate int64 `json:"creationDate,omitempty"`
|
||||
Mode string `json:"mode,omitempty"`
|
||||
Info struct {
|
||||
Name string `json:"name,omitempty"`
|
||||
} `json:"info"`
|
||||
} `json:"bittorrent"`
|
||||
VerifiedLength string `json:"verifiedLength,omitempty"`
|
||||
VerifyIntegrityPending string `json:"verifyIntegrityPending,omitempty"`
|
||||
}
|
||||
|
||||
// File represents a file in the download
|
||||
type File struct {
|
||||
Index string `json:"index"`
|
||||
Path string `json:"path"`
|
||||
Length string `json:"length"`
|
||||
CompletedLength string `json:"completedLength"`
|
||||
Selected string `json:"selected"`
|
||||
URIs []URI `json:"uris"`
|
||||
}
|
||||
|
||||
// URI represents a URI for a file
|
||||
type URI struct {
|
||||
URI string `json:"uri"`
|
||||
Status string `json:"status"`
|
||||
}
|
||||
|
||||
// GlobalStat represents global statistics
|
||||
type GlobalStat struct {
|
||||
DownloadSpeed string `json:"downloadSpeed"`
|
||||
UploadSpeed string `json:"uploadSpeed"`
|
||||
NumActive string `json:"numActive"`
|
||||
NumWaiting string `json:"numWaiting"`
|
||||
NumStopped string `json:"numStopped"`
|
||||
NumStoppedTotal string `json:"numStoppedTotal"`
|
||||
}
|
||||
|
||||
// Version represents aria2 version information
|
||||
type Version struct {
|
||||
Version string `json:"version"`
|
||||
EnabledFeatures []string `json:"enabledFeatures"`
|
||||
}
|
||||
|
||||
// NewClient creates a new aria2 client
|
||||
// url: aria2 RPC URL (e.g., "http://localhost:6800/jsonrpc")
|
||||
// secret: aria2 RPC secret token (optional, use empty string if not set)
|
||||
func NewClient(url, secret string) (*Client, error) {
|
||||
if url == "" {
|
||||
return nil, ErrInvalidURL
|
||||
}
|
||||
|
||||
return &Client{
|
||||
url: url,
|
||||
secret: secret,
|
||||
client: &http.Client{},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// NewClientWithHTTPClient creates a new aria2 client with custom HTTP client
|
||||
func NewClientWithHTTPClient(url, secret string, httpClient *http.Client) (*Client, error) {
|
||||
if url == "" {
|
||||
return nil, ErrInvalidURL
|
||||
}
|
||||
|
||||
if httpClient == nil {
|
||||
httpClient = &http.Client{}
|
||||
}
|
||||
|
||||
return &Client{
|
||||
url: url,
|
||||
secret: secret,
|
||||
client: httpClient,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// call makes a JSON-RPC call to aria2
|
||||
func (c *Client) call(ctx context.Context, method string, params []any, result any) error {
|
||||
// Prepare params with secret token if set
|
||||
var rpcParams []any
|
||||
if c.secret != "" {
|
||||
rpcParams = append([]any{fmt.Sprintf("token:%s", c.secret)}, params...)
|
||||
} else {
|
||||
rpcParams = params
|
||||
}
|
||||
|
||||
// Create request
|
||||
reqID := fmt.Sprintf("%d", c.id.Add(1))
|
||||
req := &rpcRequest{
|
||||
Jsonrpc: "2.0",
|
||||
ID: reqID,
|
||||
Method: method,
|
||||
Params: rpcParams,
|
||||
}
|
||||
|
||||
// Marshal request
|
||||
reqBody, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%w: failed to marshal request: %v", ErrRPCFailed, err)
|
||||
}
|
||||
|
||||
// Create HTTP request
|
||||
httpReq, err := http.NewRequestWithContext(ctx, "POST", c.url, bytes.NewReader(reqBody))
|
||||
if err != nil {
|
||||
return fmt.Errorf("%w: failed to create request: %v", ErrRPCFailed, err)
|
||||
}
|
||||
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
|
||||
// Send request
|
||||
resp, err := c.client.Do(httpReq)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%w: failed to send request: %v", ErrRPCFailed, err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// Read response
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%w: failed to read response: %v", ErrRPCFailed, err)
|
||||
}
|
||||
|
||||
// Check HTTP status
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return fmt.Errorf("%w: HTTP %d: %s", ErrRPCFailed, resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
// Parse response
|
||||
var rpcResp rpcResponse
|
||||
if err := json.Unmarshal(body, &rpcResp); err != nil {
|
||||
return fmt.Errorf("%w: failed to unmarshal response: %v", ErrInvalidResponse, err)
|
||||
}
|
||||
|
||||
// Check for RPC error
|
||||
if rpcResp.Error != nil {
|
||||
return rpcResp.Error
|
||||
}
|
||||
|
||||
// Check response ID
|
||||
if rpcResp.ID != reqID {
|
||||
return fmt.Errorf("%w: response ID mismatch", ErrInvalidResponse)
|
||||
}
|
||||
|
||||
// Unmarshal result if needed
|
||||
if result != nil {
|
||||
if err := json.Unmarshal(rpcResp.Result, result); err != nil {
|
||||
return fmt.Errorf("%w: failed to unmarshal result: %v", ErrInvalidResponse, err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddURI adds a new download with URIs
|
||||
func (c *Client) AddURI(ctx context.Context, uris []string, options Options) (string, error) {
|
||||
var gid string
|
||||
params := []any{uris}
|
||||
if options != nil {
|
||||
params = append(params, options)
|
||||
}
|
||||
err := c.call(ctx, "aria2.addUri", params, &gid)
|
||||
return gid, err
|
||||
}
|
||||
|
||||
// AddTorrent adds a new download with torrent file content
|
||||
func (c *Client) AddTorrent(ctx context.Context, torrent []byte, uris []string, options Options) (string, error) {
|
||||
var gid string
|
||||
params := []any{torrent}
|
||||
if len(uris) > 0 {
|
||||
params = append(params, uris)
|
||||
}
|
||||
if options != nil {
|
||||
params = append(params, options)
|
||||
}
|
||||
err := c.call(ctx, "aria2.addTorrent", params, &gid)
|
||||
return gid, err
|
||||
}
|
||||
|
||||
// AddMetalink adds a new download with metalink file content
|
||||
func (c *Client) AddMetalink(ctx context.Context, metalink []byte, options Options) ([]string, error) {
|
||||
var gids []string
|
||||
params := []any{metalink}
|
||||
if options != nil {
|
||||
params = append(params, options)
|
||||
}
|
||||
err := c.call(ctx, "aria2.addMetalink", params, &gids)
|
||||
return gids, err
|
||||
}
|
||||
|
||||
// Remove removes the download denoted by gid
|
||||
func (c *Client) Remove(ctx context.Context, gid string) (string, error) {
|
||||
var result string
|
||||
err := c.call(ctx, "aria2.remove", []any{gid}, &result)
|
||||
return result, err
|
||||
}
|
||||
|
||||
// ForceRemove removes the download denoted by gid forcefully
|
||||
func (c *Client) ForceRemove(ctx context.Context, gid string) (string, error) {
|
||||
var result string
|
||||
err := c.call(ctx, "aria2.forceRemove", []any{gid}, &result)
|
||||
return result, err
|
||||
}
|
||||
|
||||
// Pause pauses the download denoted by gid
|
||||
func (c *Client) Pause(ctx context.Context, gid string) (string, error) {
|
||||
var result string
|
||||
err := c.call(ctx, "aria2.pause", []any{gid}, &result)
|
||||
return result, err
|
||||
}
|
||||
|
||||
// PauseAll pauses all downloads
|
||||
func (c *Client) PauseAll(ctx context.Context) (string, error) {
|
||||
var result string
|
||||
err := c.call(ctx, "aria2.pauseAll", []any{}, &result)
|
||||
return result, err
|
||||
}
|
||||
|
||||
// ForcePause pauses the download denoted by gid forcefully
|
||||
func (c *Client) ForcePause(ctx context.Context, gid string) (string, error) {
|
||||
var result string
|
||||
err := c.call(ctx, "aria2.forcePause", []any{gid}, &result)
|
||||
return result, err
|
||||
}
|
||||
|
||||
// ForcePauseAll pauses all downloads forcefully
|
||||
func (c *Client) ForcePauseAll(ctx context.Context) (string, error) {
|
||||
var result string
|
||||
err := c.call(ctx, "aria2.forcePauseAll", []any{}, &result)
|
||||
return result, err
|
||||
}
|
||||
|
||||
// Unpause unpauses the download denoted by gid
|
||||
func (c *Client) Unpause(ctx context.Context, gid string) (string, error) {
|
||||
var result string
|
||||
err := c.call(ctx, "aria2.unpause", []any{gid}, &result)
|
||||
return result, err
|
||||
}
|
||||
|
||||
// UnpauseAll unpauses all downloads
|
||||
func (c *Client) UnpauseAll(ctx context.Context) (string, error) {
|
||||
var result string
|
||||
err := c.call(ctx, "aria2.unpauseAll", []any{}, &result)
|
||||
return result, err
|
||||
}
|
||||
|
||||
// TellStatus returns the progress of the download denoted by gid
|
||||
func (c *Client) TellStatus(ctx context.Context, gid string, keys ...string) (*Status, error) {
|
||||
var status Status
|
||||
params := []any{gid}
|
||||
if len(keys) > 0 {
|
||||
params = append(params, keys)
|
||||
}
|
||||
err := c.call(ctx, "aria2.tellStatus", params, &status)
|
||||
return &status, err
|
||||
}
|
||||
|
||||
// GetURIs returns the URIs used in the download denoted by gid
|
||||
func (c *Client) GetURIs(ctx context.Context, gid string) ([]URI, error) {
|
||||
var uris []URI
|
||||
err := c.call(ctx, "aria2.getUris", []any{gid}, &uris)
|
||||
return uris, err
|
||||
}
|
||||
|
||||
// GetFiles returns the file list of the download denoted by gid
|
||||
func (c *Client) GetFiles(ctx context.Context, gid string) ([]File, error) {
|
||||
var files []File
|
||||
err := c.call(ctx, "aria2.getFiles", []any{gid}, &files)
|
||||
return files, err
|
||||
}
|
||||
|
||||
// GetPeers returns a list of peers of the download denoted by gid
|
||||
func (c *Client) GetPeers(ctx context.Context, gid string) ([]any, error) {
|
||||
var peers []any
|
||||
err := c.call(ctx, "aria2.getPeers", []any{gid}, &peers)
|
||||
return peers, err
|
||||
}
|
||||
|
||||
// GetServers returns currently connected HTTP(S)/FTP/SFTP servers of the download denoted by gid
|
||||
func (c *Client) GetServers(ctx context.Context, gid string) ([]any, error) {
|
||||
var servers []any
|
||||
err := c.call(ctx, "aria2.getServers", []any{gid}, &servers)
|
||||
return servers, err
|
||||
}
|
||||
|
||||
// TellActive returns a list of active downloads
|
||||
func (c *Client) TellActive(ctx context.Context, keys ...string) ([]Status, error) {
|
||||
var statuses []Status
|
||||
params := []any{}
|
||||
if len(keys) > 0 {
|
||||
params = append(params, keys)
|
||||
}
|
||||
err := c.call(ctx, "aria2.tellActive", params, &statuses)
|
||||
return statuses, err
|
||||
}
|
||||
|
||||
// TellWaiting returns a list of waiting downloads
|
||||
func (c *Client) TellWaiting(ctx context.Context, offset, num int, keys ...string) ([]Status, error) {
|
||||
var statuses []Status
|
||||
params := []any{offset, num}
|
||||
if len(keys) > 0 {
|
||||
params = append(params, keys)
|
||||
}
|
||||
err := c.call(ctx, "aria2.tellWaiting", params, &statuses)
|
||||
return statuses, err
|
||||
}
|
||||
|
||||
// TellStopped returns a list of stopped downloads
|
||||
func (c *Client) TellStopped(ctx context.Context, offset, num int, keys ...string) ([]Status, error) {
|
||||
var statuses []Status
|
||||
params := []any{offset, num}
|
||||
if len(keys) > 0 {
|
||||
params = append(params, keys)
|
||||
}
|
||||
err := c.call(ctx, "aria2.tellStopped", params, &statuses)
|
||||
return statuses, err
|
||||
}
|
||||
|
||||
// ChangePosition changes the position of the download denoted by gid
|
||||
func (c *Client) ChangePosition(ctx context.Context, gid string, pos int, how string) (int, error) {
|
||||
var result int
|
||||
err := c.call(ctx, "aria2.changePosition", []any{gid, pos, how}, &result)
|
||||
return result, err
|
||||
}
|
||||
|
||||
// ChangeURI changes the URI of the download denoted by gid
|
||||
func (c *Client) ChangeURI(ctx context.Context, gid string, fileIndex int, delURIs []string, addURIs []string) ([]int, error) {
|
||||
var result []int
|
||||
params := []any{gid, fileIndex, delURIs, addURIs}
|
||||
err := c.call(ctx, "aria2.changeUri", params, &result)
|
||||
return result, err
|
||||
}
|
||||
|
||||
// GetOption returns options of the download denoted by gid
|
||||
func (c *Client) GetOption(ctx context.Context, gid string) (Options, error) {
|
||||
var options Options
|
||||
err := c.call(ctx, "aria2.getOption", []any{gid}, &options)
|
||||
return options, err
|
||||
}
|
||||
|
||||
// ChangeOption changes options of the download denoted by gid dynamically
|
||||
func (c *Client) ChangeOption(ctx context.Context, gid string, options Options) (string, error) {
|
||||
var result string
|
||||
err := c.call(ctx, "aria2.changeOption", []any{gid, options}, &result)
|
||||
return result, err
|
||||
}
|
||||
|
||||
// GetGlobalOption returns the global options
|
||||
func (c *Client) GetGlobalOption(ctx context.Context) (Options, error) {
|
||||
var options Options
|
||||
err := c.call(ctx, "aria2.getGlobalOption", []any{}, &options)
|
||||
return options, err
|
||||
}
|
||||
|
||||
// ChangeGlobalOption changes global options dynamically
|
||||
func (c *Client) ChangeGlobalOption(ctx context.Context, options Options) (string, error) {
|
||||
var result string
|
||||
err := c.call(ctx, "aria2.changeGlobalOption", []any{options}, &result)
|
||||
return result, err
|
||||
}
|
||||
|
||||
// GetGlobalStat returns global statistics such as the overall download and upload speed
|
||||
func (c *Client) GetGlobalStat(ctx context.Context) (*GlobalStat, error) {
|
||||
var stat GlobalStat
|
||||
err := c.call(ctx, "aria2.getGlobalStat", []any{}, &stat)
|
||||
return &stat, err
|
||||
}
|
||||
|
||||
// PurgeDownloadResult purges completed/error/removed downloads
|
||||
func (c *Client) PurgeDownloadResult(ctx context.Context) (string, error) {
|
||||
var result string
|
||||
err := c.call(ctx, "aria2.purgeDownloadResult", []any{}, &result)
|
||||
return result, err
|
||||
}
|
||||
|
||||
// RemoveDownloadResult removes a completed/error/removed download denoted by gid
|
||||
func (c *Client) RemoveDownloadResult(ctx context.Context, gid string) (string, error) {
|
||||
var result string
|
||||
err := c.call(ctx, "aria2.removeDownloadResult", []any{gid}, &result)
|
||||
return result, err
|
||||
}
|
||||
|
||||
// GetVersion returns the version of aria2 and the list of enabled features
|
||||
func (c *Client) GetVersion(ctx context.Context) (*Version, error) {
|
||||
var version Version
|
||||
err := c.call(ctx, "aria2.getVersion", []any{}, &version)
|
||||
return &version, err
|
||||
}
|
||||
|
||||
// GetSessionInfo returns session information
|
||||
func (c *Client) GetSessionInfo(ctx context.Context) (map[string]any, error) {
|
||||
var info map[string]any
|
||||
err := c.call(ctx, "aria2.getSessionInfo", []any{}, &info)
|
||||
return info, err
|
||||
}
|
||||
|
||||
// Shutdown shuts down aria2
|
||||
func (c *Client) Shutdown(ctx context.Context) (string, error) {
|
||||
var result string
|
||||
err := c.call(ctx, "aria2.shutdown", []any{}, &result)
|
||||
return result, err
|
||||
}
|
||||
|
||||
// ForceShutdown shuts down aria2 forcefully
|
||||
func (c *Client) ForceShutdown(ctx context.Context) (string, error) {
|
||||
var result string
|
||||
err := c.call(ctx, "aria2.forceShutdown", []any{}, &result)
|
||||
return result, err
|
||||
}
|
||||
|
||||
// SaveSession saves the current session to a file
|
||||
func (c *Client) SaveSession(ctx context.Context) (string, error) {
|
||||
var result string
|
||||
err := c.call(ctx, "aria2.saveSession", []any{}, &result)
|
||||
return result, err
|
||||
}
|
||||
|
||||
// MultiCall executes multiple method calls in a single request (system.multicall)
|
||||
func (c *Client) MultiCall(ctx context.Context, calls []map[string]any) ([]any, error) {
|
||||
var results []any
|
||||
err := c.call(ctx, "system.multicall", []any{calls}, &results)
|
||||
return results, err
|
||||
}
|
||||
|
||||
// ListMethods lists all available RPC methods
|
||||
func (c *Client) ListMethods(ctx context.Context) ([]string, error) {
|
||||
var methods []string
|
||||
err := c.call(ctx, "system.listMethods", []any{}, &methods)
|
||||
return methods, err
|
||||
}
|
||||
|
||||
// ListNotifications lists all available RPC notifications
|
||||
func (c *Client) ListNotifications(ctx context.Context) ([]string, error) {
|
||||
var notifications []string
|
||||
err := c.call(ctx, "system.listNotifications", []any{}, ¬ifications)
|
||||
return notifications, err
|
||||
}
|
||||
|
||||
// IsDownloadComplete checks if the download is complete
|
||||
func (s *Status) IsDownloadComplete() bool {
|
||||
return s.Status == "complete"
|
||||
}
|
||||
|
||||
// IsDownloadActive checks if the download is active
|
||||
func (s *Status) IsDownloadActive() bool {
|
||||
return s.Status == "active"
|
||||
}
|
||||
|
||||
// IsDownloadWaiting checks if the download is waiting
|
||||
func (s *Status) IsDownloadWaiting() bool {
|
||||
return s.Status == "waiting"
|
||||
}
|
||||
|
||||
// IsDownloadPaused checks if the download is paused
|
||||
func (s *Status) IsDownloadPaused() bool {
|
||||
return s.Status == "paused"
|
||||
}
|
||||
|
||||
// IsDownloadError checks if the download has an error
|
||||
func (s *Status) IsDownloadError() bool {
|
||||
return s.Status == "error"
|
||||
}
|
||||
|
||||
// IsDownloadRemoved checks if the download is removed
|
||||
func (s *Status) IsDownloadRemoved() bool {
|
||||
return s.Status == "removed"
|
||||
}
|
||||
322
pkg/aria2/client_test.go
Normal file
322
pkg/aria2/client_test.go
Normal file
@@ -0,0 +1,322 @@
|
||||
package aria2
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestNewClient(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
url string
|
||||
secret string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "valid client",
|
||||
url: "http://localhost:6800/jsonrpc",
|
||||
secret: "test-secret",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "valid client without secret",
|
||||
url: "http://localhost:6800/jsonrpc",
|
||||
secret: "",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "invalid empty url",
|
||||
url: "",
|
||||
secret: "test-secret",
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
client, err := NewClient(tt.url, tt.secret)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("NewClient() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !tt.wantErr && client == nil {
|
||||
t.Error("NewClient() returned nil client")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestClient_AddURI(t *testing.T) {
|
||||
// Create a mock server
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != "POST" {
|
||||
t.Errorf("Expected POST request, got %s", r.Method)
|
||||
}
|
||||
|
||||
var req rpcRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
t.Errorf("Failed to decode request: %v", err)
|
||||
}
|
||||
|
||||
// Verify method
|
||||
if req.Method != "aria2.addUri" {
|
||||
t.Errorf("Expected method aria2.addUri, got %s", req.Method)
|
||||
}
|
||||
|
||||
// Send response
|
||||
resp := rpcResponse{
|
||||
Jsonrpc: "2.0",
|
||||
ID: req.ID,
|
||||
Result: json.RawMessage(`"2089b05ecca3d829"`),
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client, err := NewClient(server.URL, "")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create client: %v", err)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
gid, err := client.AddURI(ctx, []string{"http://example.com/file.txt"}, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("AddURI() error = %v", err)
|
||||
}
|
||||
|
||||
if gid != "2089b05ecca3d829" {
|
||||
t.Errorf("Expected gid 2089b05ecca3d829, got %s", gid)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClient_TellStatus(t *testing.T) {
|
||||
// Create a mock server
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
var req rpcRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
t.Errorf("Failed to decode request: %v", err)
|
||||
}
|
||||
|
||||
// Verify method
|
||||
if req.Method != "aria2.tellStatus" {
|
||||
t.Errorf("Expected method aria2.tellStatus, got %s", req.Method)
|
||||
}
|
||||
|
||||
// Send response
|
||||
status := Status{
|
||||
GID: "2089b05ecca3d829",
|
||||
Status: "active",
|
||||
TotalLength: "1024000",
|
||||
CompletedLength: "512000",
|
||||
DownloadSpeed: "102400",
|
||||
Files: []File{},
|
||||
}
|
||||
result, _ := json.Marshal(status)
|
||||
|
||||
resp := rpcResponse{
|
||||
Jsonrpc: "2.0",
|
||||
ID: req.ID,
|
||||
Result: result,
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client, err := NewClient(server.URL, "")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create client: %v", err)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
status, err := client.TellStatus(ctx, "2089b05ecca3d829")
|
||||
if err != nil {
|
||||
t.Fatalf("TellStatus() error = %v", err)
|
||||
}
|
||||
|
||||
if status.GID != "2089b05ecca3d829" {
|
||||
t.Errorf("Expected gid 2089b05ecca3d829, got %s", status.GID)
|
||||
}
|
||||
|
||||
if status.Status != "active" {
|
||||
t.Errorf("Expected status active, got %s", status.Status)
|
||||
}
|
||||
|
||||
if !status.IsDownloadActive() {
|
||||
t.Error("Expected download to be active")
|
||||
}
|
||||
}
|
||||
|
||||
func TestClient_WithSecret(t *testing.T) {
|
||||
expectedSecret := "my-secret-token"
|
||||
|
||||
// Create a mock server
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
var req rpcRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
t.Errorf("Failed to decode request: %v", err)
|
||||
}
|
||||
|
||||
// Verify secret token is included in params
|
||||
if len(req.Params) == 0 {
|
||||
t.Error("Expected params to contain secret token")
|
||||
} else {
|
||||
token, ok := req.Params[0].(string)
|
||||
if !ok || token != "token:"+expectedSecret {
|
||||
t.Errorf("Expected token:%s, got %v", expectedSecret, req.Params[0])
|
||||
}
|
||||
}
|
||||
|
||||
// Send response
|
||||
version := Version{
|
||||
Version: "1.36.0",
|
||||
EnabledFeatures: []string{"Async DNS", "BitTorrent", "HTTP", "HTTPS"},
|
||||
}
|
||||
result, _ := json.Marshal(version)
|
||||
|
||||
resp := rpcResponse{
|
||||
Jsonrpc: "2.0",
|
||||
ID: req.ID,
|
||||
Result: result,
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client, err := NewClient(server.URL, expectedSecret)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create client: %v", err)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
version, err := client.GetVersion(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("GetVersion() error = %v", err)
|
||||
}
|
||||
|
||||
if version.Version != "1.36.0" {
|
||||
t.Errorf("Expected version 1.36.0, got %s", version.Version)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClient_ContextCancellation(t *testing.T) {
|
||||
// Create a mock server that delays response
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(rpcResponse{
|
||||
Jsonrpc: "2.0",
|
||||
ID: "1",
|
||||
Result: json.RawMessage(`"OK"`),
|
||||
})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client, err := NewClient(server.URL, "")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create client: %v", err)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
_, err = client.GetVersion(ctx)
|
||||
if err == nil {
|
||||
t.Error("Expected context cancellation error, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestClient_RPCError(t *testing.T) {
|
||||
// Create a mock server that returns an error
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
var req rpcRequest
|
||||
json.NewDecoder(r.Body).Decode(&req)
|
||||
|
||||
resp := rpcResponse{
|
||||
Jsonrpc: "2.0",
|
||||
ID: req.ID,
|
||||
Error: &rpcError{
|
||||
Code: 1,
|
||||
Message: "Unauthorized",
|
||||
},
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client, err := NewClient(server.URL, "")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create client: %v", err)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
_, err = client.GetVersion(ctx)
|
||||
if err == nil {
|
||||
t.Error("Expected RPC error, got nil")
|
||||
}
|
||||
|
||||
var rpcErr *rpcError
|
||||
if !errors.As(err, &rpcErr) {
|
||||
t.Errorf("Expected rpcError, got %T", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStatus_DownloadStatus(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
status string
|
||||
check func(*Status) bool
|
||||
}{
|
||||
{
|
||||
name: "active",
|
||||
status: "active",
|
||||
check: (*Status).IsDownloadActive,
|
||||
},
|
||||
{
|
||||
name: "waiting",
|
||||
status: "waiting",
|
||||
check: (*Status).IsDownloadWaiting,
|
||||
},
|
||||
{
|
||||
name: "paused",
|
||||
status: "paused",
|
||||
check: (*Status).IsDownloadPaused,
|
||||
},
|
||||
{
|
||||
name: "error",
|
||||
status: "error",
|
||||
check: (*Status).IsDownloadError,
|
||||
},
|
||||
{
|
||||
name: "complete",
|
||||
status: "complete",
|
||||
check: (*Status).IsDownloadComplete,
|
||||
},
|
||||
{
|
||||
name: "removed",
|
||||
status: "removed",
|
||||
check: (*Status).IsDownloadRemoved,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
s := &Status{Status: tt.status}
|
||||
if !tt.check(s) {
|
||||
t.Errorf("Expected status %s check to return true", tt.status)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
97
pkg/aria2/example/main.go
Normal file
97
pkg/aria2/example/main.go
Normal file
@@ -0,0 +1,97 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"time"
|
||||
|
||||
"github.com/krau/SaveAny-Bot/pkg/aria2"
|
||||
)
|
||||
|
||||
func main() {
|
||||
// Create aria2 client
|
||||
client, err := aria2.NewClient("http://localhost:6800/jsonrpc", "")
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Get aria2 version
|
||||
version, err := client.GetVersion(ctx)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
fmt.Printf("aria2 version: %s\n", version.Version)
|
||||
fmt.Printf("Enabled features: %v\n", version.EnabledFeatures)
|
||||
|
||||
// Add a download
|
||||
uris := []string{"https://example.com/file.zip"}
|
||||
options := aria2.Options{
|
||||
"dir": "/downloads",
|
||||
}
|
||||
|
||||
gid, err := client.AddURI(ctx, uris, options)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
fmt.Printf("Download started with GID: %s\n", gid)
|
||||
|
||||
// Monitor download progress
|
||||
for {
|
||||
status, err := client.TellStatus(ctx, gid)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
fmt.Printf("Status: %s, Progress: %s/%s bytes, Speed: %s bytes/s\n",
|
||||
status.Status,
|
||||
status.CompletedLength,
|
||||
status.TotalLength,
|
||||
status.DownloadSpeed,
|
||||
)
|
||||
|
||||
if status.IsDownloadComplete() {
|
||||
fmt.Println("Download completed!")
|
||||
break
|
||||
}
|
||||
|
||||
if status.IsDownloadError() {
|
||||
fmt.Printf("Download error: %s\n", status.ErrorMessage)
|
||||
break
|
||||
}
|
||||
|
||||
time.Sleep(1 * time.Second)
|
||||
}
|
||||
|
||||
// Get global statistics
|
||||
stat, err := client.GetGlobalStat(ctx)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
fmt.Printf("Global stats - Download speed: %s, Active: %s, Waiting: %s\n",
|
||||
stat.DownloadSpeed,
|
||||
stat.NumActive,
|
||||
stat.NumWaiting,
|
||||
)
|
||||
|
||||
// List active downloads
|
||||
activeDownloads, err := client.TellActive(ctx)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
fmt.Printf("Active downloads: %d\n", len(activeDownloads))
|
||||
for _, download := range activeDownloads {
|
||||
fmt.Printf(" GID: %s, Status: %s\n", download.GID, download.Status)
|
||||
}
|
||||
|
||||
// Example with context timeout
|
||||
ctxWithTimeout, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
_, err = client.TellStatus(ctxWithTimeout, gid)
|
||||
if err != nil {
|
||||
log.Printf("Request failed: %v\n", err)
|
||||
}
|
||||
}
|
||||
@@ -8,8 +8,17 @@ default, message, template
|
||||
) */
|
||||
type FnameST string
|
||||
|
||||
var FnameSTDisplay = map[FnameST]string{
|
||||
Default: "默认",
|
||||
Message: "优先从消息生成",
|
||||
Template: "自定义模板",
|
||||
var fnameSTDisplay = map[FnameST]map[string]string{
|
||||
Default: {"zh-CN": "默认", "en": "Default"},
|
||||
Message: {"zh-CN": "优先从消息生成", "en": "Gen From Msg First"},
|
||||
Template: {"zh-CN": "自定义模板", "en": "Template"},
|
||||
}
|
||||
|
||||
func GetDisplay(st FnameST, lang string) string {
|
||||
if display, ok := fnameSTDisplay[st]; ok {
|
||||
if str, ok := display[lang]; ok {
|
||||
return str
|
||||
}
|
||||
}
|
||||
return fnameSTDisplay[st]["en"]
|
||||
}
|
||||
|
||||
@@ -85,7 +85,7 @@ func TestConcurrencySafety(t *testing.T) {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for i := 0; i < n; i++ {
|
||||
for i := range n {
|
||||
q.Add(newTask(fmt.Sprintf("p%d", i)))
|
||||
}
|
||||
}()
|
||||
|
||||
@@ -181,6 +181,11 @@ func (t *Telegram) Save(ctx context.Context, r io.Reader, storagePath string) er
|
||||
switch mtypeStr {
|
||||
case "video/mp4":
|
||||
info, err := getMP4Meta(rs)
|
||||
if err != nil {
|
||||
// Fallback to ffprobe if gomedia fails (e.g., malformed MP4)
|
||||
rs.Seek(0, io.SeekStart)
|
||||
info, err = getVideoMetadata(rs)
|
||||
}
|
||||
if err == nil {
|
||||
media = doc.Video().
|
||||
Duration(time.Duration(info.Duration)*time.Second).
|
||||
|
||||
@@ -21,12 +21,19 @@ type VideoMetadata struct {
|
||||
}
|
||||
|
||||
// a go native way to get mp4 video metadata
|
||||
func getMP4Meta(rs io.ReadSeeker) (*VideoMetadata, error) {
|
||||
func getMP4Meta(rs io.ReadSeeker) (metadata *VideoMetadata, err error) {
|
||||
// Recover from panics in the gomedia library (e.g., "no vosdata" panic)
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err = fmt.Errorf("panic while parsing MP4: %v", r)
|
||||
}
|
||||
}()
|
||||
|
||||
d := mp4.CreateMp4Demuxer(rs)
|
||||
|
||||
tracks, err := d.ReadHead()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
tracks, e := d.ReadHead()
|
||||
if e != nil {
|
||||
return nil, e
|
||||
}
|
||||
|
||||
for _, track := range tracks {
|
||||
|
||||
Reference in New Issue
Block a user