feat(aria2): add Aria2 download command and client integration

This commit is contained in:
krau
2026-01-03 17:40:55 +08:00
parent c8d8a2e0eb
commit 75e5fd10ea
8 changed files with 1048 additions and 7 deletions

View File

@@ -3,6 +3,7 @@ package handlers
import (
"net/url"
"strings"
"sync"
"github.com/celestix/gotgproto/ext"
"github.com/charmbracelet/log"
@@ -10,6 +11,8 @@ import (
"github.com/krau/SaveAny-Bot/client/bot/handlers/utils/msgelem"
"github.com/krau/SaveAny-Bot/common/i18n"
"github.com/krau/SaveAny-Bot/common/i18n/i18nk"
"github.com/krau/SaveAny-Bot/config"
"github.com/krau/SaveAny-Bot/pkg/aria2"
"github.com/krau/SaveAny-Bot/pkg/enums/tasktype"
"github.com/krau/SaveAny-Bot/pkg/tcbdata"
"github.com/krau/SaveAny-Bot/storage"
@@ -50,3 +53,53 @@ func handleDlCmd(ctx *ext.Context, update *ext.Update) error {
})
return nil
}
var aria2ClientInitOnce sync.Once
var aria2ClientInitErr error
var aria2Client *aria2.Client
func handleAria2DlCmd(ctx *ext.Context, update *ext.Update) error {
if !config.C().Aria2.Enable {
ctx.Reply(update, ext.ReplyTextString(i18n.T(i18nk.BotMsgAria2ErrorAria2NotEnabled)), nil)
return nil
}
logger := log.FromContext(ctx)
args := strings.Split(update.EffectiveMessage.Text, " ")
if len(args) < 2 {
ctx.Reply(update, ext.ReplyTextString(i18n.T(i18nk.BotMsgDlUsage)), nil)
return nil
}
links := args[1:]
for i, link := range links {
links[i] = strings.TrimSpace(link)
}
links = slice.Compact(links)
if len(links) == 0 {
ctx.Reply(update, ext.ReplyTextString(i18n.T(i18nk.BotMsgDlErrorNoValidLinks)), nil)
return nil
}
logger.Debug("Adding aria2 download", "links", links)
aria2ClientInitOnce.Do(func() {
aria2Client, aria2ClientInitErr = aria2.NewClient(config.C().Aria2.Url, config.C().Aria2.Secret)
})
if aria2ClientInitErr != nil {
logger.Error("Failed to initialize aria2 client", "error", aria2ClientInitErr)
ctx.Reply(update, ext.ReplyTextString(i18n.T(i18nk.BotMsgAria2ErrorAria2ClientInitFailed, map[string]any{
"Error": aria2ClientInitErr.Error(),
})), nil)
return nil
}
gid, err := aria2Client.AddURI(ctx, links, nil)
if err != nil {
logger.Error("Failed to add aria2 download", "error", err)
ctx.Reply(update, ext.ReplyTextString(i18n.T(i18nk.BotMsgAria2ErrorAddingAria2Download, map[string]any{
"Error": err.Error(),
})), nil)
return nil
}
logger.Info("Aria2 download added", "gid", gid)
ctx.Reply(update, ext.ReplyTextString(i18n.T(i18nk.BotMsgAria2InfoAria2DownloadAdded, map[string]any{
"GID": gid,
})), nil)
return nil
}

View File

@@ -29,6 +29,7 @@ var CommandHandlers = []DescCommandHandler{
{"rule", i18nk.BotMsgCmdRule, handleRuleCmd},
{"save", i18nk.BotMsgCmdSave, handleSilentMode(handleSaveCmd, handleSilentSaveReplied)},
{"dl", i18nk.BotMsgCmdDl, handleDlCmd},
{"aria2dl", i18nk.BotMsgCmdAria2dl, handleAria2DlCmd},
{"task", i18nk.BotMsgCmdTask, handleTaskCmd},
{"cancel", i18nk.BotMsgCmdCancel, handleCancelCmd},
{"config", i18nk.BotMsgCmdConfig, handleConfigCmd},

View File

@@ -4,10 +4,16 @@ package i18nk
type Key string
const (
BotMsgAria2ErrorAddingAria2Download Key = "bot.msg.aria2.error_adding_aria2_download"
BotMsgAria2ErrorAria2ClientInitFailed Key = "bot.msg.aria2.error_aria2_client_init_failed"
BotMsgAria2ErrorAria2NotEnabled Key = "bot.msg.aria2.error_aria2_not_enabled"
BotMsgAria2InfoAddingAria2Download Key = "bot.msg.aria2.info_adding_aria2_download"
BotMsgAria2InfoAria2DownloadAdded Key = "bot.msg.aria2.info_aria2_download_added"
BotMsgCancelErrorCancelFailed Key = "bot.msg.cancel.error_cancel_failed"
BotMsgCancelInfoCancelRequested Key = "bot.msg.cancel.info_cancel_requested"
BotMsgCancelInfoCancellingTask Key = "bot.msg.cancel.info_cancelling_task"
BotMsgCancelUsage Key = "bot.msg.cancel.usage"
BotMsgCmdAria2dl Key = "bot.msg.cmd.aria2dl"
BotMsgCmdCancel Key = "bot.msg.cmd.cancel"
BotMsgCmdConfig Key = "bot.msg.cmd.config"
BotMsgCmdDir Key = "bot.msg.cmd.dir"
@@ -171,6 +177,7 @@ const (
BotMsgSaveHelpText Key = "bot.msg.save_help_text"
BotMsgStorageInfoFilenamePrefix Key = "bot.msg.storage.info_filename_prefix"
BotMsgStorageInfoPromptSelectStorage Key = "bot.msg.storage.info_prompt_select_storage"
BotMsgSyncpeersDone Key = "bot.msg.syncpeers.done"
BotMsgSyncpeersFailed Key = "bot.msg.syncpeers.failed"
BotMsgSyncpeersStart Key = "bot.msg.syncpeers.start"
BotMsgSyncpeersSuccess Key = "bot.msg.syncpeers.success"

View File

@@ -29,6 +29,7 @@ bot:
/silent - 开关静默模式
/storage - 设置默认存储位置
/save [自定义文件名] - 保存文件
/dl <链接1> <链接2> ... - 下载给定链接的文件
/dir - 管理存储目录
/rule - 管理规则
/config - 修改配置
@@ -50,6 +51,7 @@ bot:
rule: "管理自动存储规则"
save: "保存文件"
dl: "下载给定链接的文件"
aria2dl: "使用 Aria2 下载给定链接的文件"
task: "管理任务队列"
cancel: "取消任务"
watch: "监听聊天(UserBot)"
@@ -329,3 +331,9 @@ bot:
start: "正在同步对话列表..."
success: "对话列表同步完成, 共同步 {{.Count}} 个对话"
failed: "对话列表同步失败: {{.Error}}"
aria2:
error_aria2_not_enabled: "Aria2 功能未启用, 请在配置文件中启用"
error_aria2_client_init_failed: "Aria2 客户端初始化失败: {{.Error}}"
info_adding_aria2_download: "正在添加 Aria2 下载任务..."
error_adding_aria2_download: "添加 Aria2 下载任务失败: {{.Error}}"
info_aria2_download_added: "Aria2 下载任务已添加, GID: {{.GID}}"

View File

@@ -16,13 +16,14 @@ import (
)
type Config struct {
Lang string `toml:"lang" mapstructure:"lang" json:"lang"`
Workers int `toml:"workers" mapstructure:"workers"`
Retry int `toml:"retry" mapstructure:"retry"`
NoCleanCache bool `toml:"no_clean_cache" mapstructure:"no_clean_cache" json:"no_clean_cache"`
Threads int `toml:"threads" mapstructure:"threads" json:"threads"`
Stream bool `toml:"stream" mapstructure:"stream" json:"stream"`
Proxy string `toml:"proxy" mapstructure:"proxy" json:"proxy"`
Lang string `toml:"lang" mapstructure:"lang" json:"lang"`
Workers int `toml:"workers" mapstructure:"workers"`
Retry int `toml:"retry" mapstructure:"retry"`
NoCleanCache bool `toml:"no_clean_cache" mapstructure:"no_clean_cache" json:"no_clean_cache"`
Threads int `toml:"threads" mapstructure:"threads" json:"threads"`
Stream bool `toml:"stream" mapstructure:"stream" json:"stream"`
Proxy string `toml:"proxy" mapstructure:"proxy" json:"proxy"`
Aria2 aria2Config `toml:"aria2" mapstructure:"aria2" json:"aria2"`
Cache cacheConfig `toml:"cache" mapstructure:"cache" json:"cache"`
Users []userConfig `toml:"users" mapstructure:"users" json:"users"`
@@ -34,6 +35,12 @@ type Config struct {
Hook hookConfig `toml:"hook" mapstructure:"hook" json:"hook"`
}
type aria2Config struct {
Enable bool `toml:"enable" mapstructure:"enable" json:"enable"`
Url string `toml:"url" mapstructure:"url" json:"url"`
Secret string `toml:"secret" mapstructure:"secret" json:"secret"`
}
var cfg = &Config{}
func C() Config {

546
pkg/aria2/client.go Normal file
View File

@@ -0,0 +1,546 @@
package aria2
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"sync/atomic"
)
var (
ErrInvalidURL = errors.New("aria2: invalid URL")
ErrRPCFailed = errors.New("aria2: RPC call failed")
ErrInvalidResponse = errors.New("aria2: invalid response")
)
// Client represents an aria2 JSON-RPC client
type Client struct {
url string
secret string
client *http.Client
id atomic.Int64
}
// rpcRequest represents a JSON-RPC 2.0 request
type rpcRequest struct {
Jsonrpc string `json:"jsonrpc"`
ID string `json:"id"`
Method string `json:"method"`
Params []any `json:"params"`
}
// rpcResponse represents a JSON-RPC 2.0 response
type rpcResponse struct {
Jsonrpc string `json:"jsonrpc"`
ID string `json:"id"`
Result json.RawMessage `json:"result,omitempty"`
Error *rpcError `json:"error,omitempty"`
}
// rpcError represents a JSON-RPC 2.0 error
type rpcError struct {
Code int `json:"code"`
Message string `json:"message"`
}
func (e *rpcError) Error() string {
return fmt.Sprintf("aria2 RPC error %d: %s", e.Code, e.Message)
}
// Options for download
type Options map[string]any
// Status represents the status of a download
type Status struct {
GID string `json:"gid"`
Status string `json:"status"`
TotalLength string `json:"totalLength"`
CompletedLength string `json:"completedLength"`
UploadLength string `json:"uploadLength"`
Bitfield string `json:"bitfield,omitempty"`
DownloadSpeed string `json:"downloadSpeed"`
UploadSpeed string `json:"uploadSpeed"`
InfoHash string `json:"infoHash,omitempty"`
NumSeeders string `json:"numSeeders,omitempty"`
Seeder string `json:"seeder,omitempty"`
PieceLength string `json:"pieceLength,omitempty"`
NumPieces string `json:"numPieces,omitempty"`
Connections string `json:"connections"`
ErrorCode string `json:"errorCode,omitempty"`
ErrorMessage string `json:"errorMessage,omitempty"`
FollowedBy []string `json:"followedBy,omitempty"`
Following string `json:"following,omitempty"`
BelongsTo string `json:"belongsTo,omitempty"`
Dir string `json:"dir"`
Files []File `json:"files"`
BitTorrent struct {
AnnounceList [][]string `json:"announceList,omitempty"`
Comment string `json:"comment,omitempty"`
CreationDate int64 `json:"creationDate,omitempty"`
Mode string `json:"mode,omitempty"`
Info struct {
Name string `json:"name,omitempty"`
} `json:"info"`
} `json:"bittorrent"`
VerifiedLength string `json:"verifiedLength,omitempty"`
VerifyIntegrityPending string `json:"verifyIntegrityPending,omitempty"`
}
// File represents a file in the download
type File struct {
Index string `json:"index"`
Path string `json:"path"`
Length string `json:"length"`
CompletedLength string `json:"completedLength"`
Selected string `json:"selected"`
URIs []URI `json:"uris"`
}
// URI represents a URI for a file
type URI struct {
URI string `json:"uri"`
Status string `json:"status"`
}
// GlobalStat represents global statistics
type GlobalStat struct {
DownloadSpeed string `json:"downloadSpeed"`
UploadSpeed string `json:"uploadSpeed"`
NumActive string `json:"numActive"`
NumWaiting string `json:"numWaiting"`
NumStopped string `json:"numStopped"`
NumStoppedTotal string `json:"numStoppedTotal"`
}
// Version represents aria2 version information
type Version struct {
Version string `json:"version"`
EnabledFeatures []string `json:"enabledFeatures"`
}
// NewClient creates a new aria2 client
// url: aria2 RPC URL (e.g., "http://localhost:6800/jsonrpc")
// secret: aria2 RPC secret token (optional, use empty string if not set)
func NewClient(url, secret string) (*Client, error) {
if url == "" {
return nil, ErrInvalidURL
}
return &Client{
url: url,
secret: secret,
client: &http.Client{},
}, nil
}
// NewClientWithHTTPClient creates a new aria2 client with custom HTTP client
func NewClientWithHTTPClient(url, secret string, httpClient *http.Client) (*Client, error) {
if url == "" {
return nil, ErrInvalidURL
}
if httpClient == nil {
httpClient = &http.Client{}
}
return &Client{
url: url,
secret: secret,
client: httpClient,
}, nil
}
// call makes a JSON-RPC call to aria2
func (c *Client) call(ctx context.Context, method string, params []any, result any) error {
// Prepare params with secret token if set
var rpcParams []any
if c.secret != "" {
rpcParams = append([]any{fmt.Sprintf("token:%s", c.secret)}, params...)
} else {
rpcParams = params
}
// Create request
reqID := fmt.Sprintf("%d", c.id.Add(1))
req := &rpcRequest{
Jsonrpc: "2.0",
ID: reqID,
Method: method,
Params: rpcParams,
}
// Marshal request
reqBody, err := json.Marshal(req)
if err != nil {
return fmt.Errorf("%w: failed to marshal request: %v", ErrRPCFailed, err)
}
// Create HTTP request
httpReq, err := http.NewRequestWithContext(ctx, "POST", c.url, bytes.NewReader(reqBody))
if err != nil {
return fmt.Errorf("%w: failed to create request: %v", ErrRPCFailed, err)
}
httpReq.Header.Set("Content-Type", "application/json")
// Send request
resp, err := c.client.Do(httpReq)
if err != nil {
return fmt.Errorf("%w: failed to send request: %v", ErrRPCFailed, err)
}
defer resp.Body.Close()
// Read response
body, err := io.ReadAll(resp.Body)
if err != nil {
return fmt.Errorf("%w: failed to read response: %v", ErrRPCFailed, err)
}
// Check HTTP status
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("%w: HTTP %d: %s", ErrRPCFailed, resp.StatusCode, string(body))
}
// Parse response
var rpcResp rpcResponse
if err := json.Unmarshal(body, &rpcResp); err != nil {
return fmt.Errorf("%w: failed to unmarshal response: %v", ErrInvalidResponse, err)
}
// Check for RPC error
if rpcResp.Error != nil {
return rpcResp.Error
}
// Check response ID
if rpcResp.ID != reqID {
return fmt.Errorf("%w: response ID mismatch", ErrInvalidResponse)
}
// Unmarshal result if needed
if result != nil {
if err := json.Unmarshal(rpcResp.Result, result); err != nil {
return fmt.Errorf("%w: failed to unmarshal result: %v", ErrInvalidResponse, err)
}
}
return nil
}
// AddURI adds a new download with URIs
func (c *Client) AddURI(ctx context.Context, uris []string, options Options) (string, error) {
var gid string
params := []any{uris}
if options != nil {
params = append(params, options)
}
err := c.call(ctx, "aria2.addUri", params, &gid)
return gid, err
}
// AddTorrent adds a new download with torrent file content
func (c *Client) AddTorrent(ctx context.Context, torrent []byte, uris []string, options Options) (string, error) {
var gid string
params := []any{torrent}
if len(uris) > 0 {
params = append(params, uris)
}
if options != nil {
params = append(params, options)
}
err := c.call(ctx, "aria2.addTorrent", params, &gid)
return gid, err
}
// AddMetalink adds a new download with metalink file content
func (c *Client) AddMetalink(ctx context.Context, metalink []byte, options Options) ([]string, error) {
var gids []string
params := []any{metalink}
if options != nil {
params = append(params, options)
}
err := c.call(ctx, "aria2.addMetalink", params, &gids)
return gids, err
}
// Remove removes the download denoted by gid
func (c *Client) Remove(ctx context.Context, gid string) (string, error) {
var result string
err := c.call(ctx, "aria2.remove", []any{gid}, &result)
return result, err
}
// ForceRemove removes the download denoted by gid forcefully
func (c *Client) ForceRemove(ctx context.Context, gid string) (string, error) {
var result string
err := c.call(ctx, "aria2.forceRemove", []any{gid}, &result)
return result, err
}
// Pause pauses the download denoted by gid
func (c *Client) Pause(ctx context.Context, gid string) (string, error) {
var result string
err := c.call(ctx, "aria2.pause", []any{gid}, &result)
return result, err
}
// PauseAll pauses all downloads
func (c *Client) PauseAll(ctx context.Context) (string, error) {
var result string
err := c.call(ctx, "aria2.pauseAll", []any{}, &result)
return result, err
}
// ForcePause pauses the download denoted by gid forcefully
func (c *Client) ForcePause(ctx context.Context, gid string) (string, error) {
var result string
err := c.call(ctx, "aria2.forcePause", []any{gid}, &result)
return result, err
}
// ForcePauseAll pauses all downloads forcefully
func (c *Client) ForcePauseAll(ctx context.Context) (string, error) {
var result string
err := c.call(ctx, "aria2.forcePauseAll", []any{}, &result)
return result, err
}
// Unpause unpauses the download denoted by gid
func (c *Client) Unpause(ctx context.Context, gid string) (string, error) {
var result string
err := c.call(ctx, "aria2.unpause", []any{gid}, &result)
return result, err
}
// UnpauseAll unpauses all downloads
func (c *Client) UnpauseAll(ctx context.Context) (string, error) {
var result string
err := c.call(ctx, "aria2.unpauseAll", []any{}, &result)
return result, err
}
// TellStatus returns the progress of the download denoted by gid
func (c *Client) TellStatus(ctx context.Context, gid string, keys ...string) (*Status, error) {
var status Status
params := []any{gid}
if len(keys) > 0 {
params = append(params, keys)
}
err := c.call(ctx, "aria2.tellStatus", params, &status)
return &status, err
}
// GetURIs returns the URIs used in the download denoted by gid
func (c *Client) GetURIs(ctx context.Context, gid string) ([]URI, error) {
var uris []URI
err := c.call(ctx, "aria2.getUris", []any{gid}, &uris)
return uris, err
}
// GetFiles returns the file list of the download denoted by gid
func (c *Client) GetFiles(ctx context.Context, gid string) ([]File, error) {
var files []File
err := c.call(ctx, "aria2.getFiles", []any{gid}, &files)
return files, err
}
// GetPeers returns a list of peers of the download denoted by gid
func (c *Client) GetPeers(ctx context.Context, gid string) ([]any, error) {
var peers []any
err := c.call(ctx, "aria2.getPeers", []any{gid}, &peers)
return peers, err
}
// GetServers returns currently connected HTTP(S)/FTP/SFTP servers of the download denoted by gid
func (c *Client) GetServers(ctx context.Context, gid string) ([]any, error) {
var servers []any
err := c.call(ctx, "aria2.getServers", []any{gid}, &servers)
return servers, err
}
// TellActive returns a list of active downloads
func (c *Client) TellActive(ctx context.Context, keys ...string) ([]Status, error) {
var statuses []Status
params := []any{}
if len(keys) > 0 {
params = append(params, keys)
}
err := c.call(ctx, "aria2.tellActive", params, &statuses)
return statuses, err
}
// TellWaiting returns a list of waiting downloads
func (c *Client) TellWaiting(ctx context.Context, offset, num int, keys ...string) ([]Status, error) {
var statuses []Status
params := []any{offset, num}
if len(keys) > 0 {
params = append(params, keys)
}
err := c.call(ctx, "aria2.tellWaiting", params, &statuses)
return statuses, err
}
// TellStopped returns a list of stopped downloads
func (c *Client) TellStopped(ctx context.Context, offset, num int, keys ...string) ([]Status, error) {
var statuses []Status
params := []any{offset, num}
if len(keys) > 0 {
params = append(params, keys)
}
err := c.call(ctx, "aria2.tellStopped", params, &statuses)
return statuses, err
}
// ChangePosition changes the position of the download denoted by gid
func (c *Client) ChangePosition(ctx context.Context, gid string, pos int, how string) (int, error) {
var result int
err := c.call(ctx, "aria2.changePosition", []any{gid, pos, how}, &result)
return result, err
}
// ChangeURI changes the URI of the download denoted by gid
func (c *Client) ChangeURI(ctx context.Context, gid string, fileIndex int, delURIs []string, addURIs []string) ([]int, error) {
var result []int
params := []any{gid, fileIndex, delURIs, addURIs}
err := c.call(ctx, "aria2.changeUri", params, &result)
return result, err
}
// GetOption returns options of the download denoted by gid
func (c *Client) GetOption(ctx context.Context, gid string) (Options, error) {
var options Options
err := c.call(ctx, "aria2.getOption", []any{gid}, &options)
return options, err
}
// ChangeOption changes options of the download denoted by gid dynamically
func (c *Client) ChangeOption(ctx context.Context, gid string, options Options) (string, error) {
var result string
err := c.call(ctx, "aria2.changeOption", []any{gid, options}, &result)
return result, err
}
// GetGlobalOption returns the global options
func (c *Client) GetGlobalOption(ctx context.Context) (Options, error) {
var options Options
err := c.call(ctx, "aria2.getGlobalOption", []any{}, &options)
return options, err
}
// ChangeGlobalOption changes global options dynamically
func (c *Client) ChangeGlobalOption(ctx context.Context, options Options) (string, error) {
var result string
err := c.call(ctx, "aria2.changeGlobalOption", []any{options}, &result)
return result, err
}
// GetGlobalStat returns global statistics such as the overall download and upload speed
func (c *Client) GetGlobalStat(ctx context.Context) (*GlobalStat, error) {
var stat GlobalStat
err := c.call(ctx, "aria2.getGlobalStat", []any{}, &stat)
return &stat, err
}
// PurgeDownloadResult purges completed/error/removed downloads
func (c *Client) PurgeDownloadResult(ctx context.Context) (string, error) {
var result string
err := c.call(ctx, "aria2.purgeDownloadResult", []any{}, &result)
return result, err
}
// RemoveDownloadResult removes a completed/error/removed download denoted by gid
func (c *Client) RemoveDownloadResult(ctx context.Context, gid string) (string, error) {
var result string
err := c.call(ctx, "aria2.removeDownloadResult", []any{gid}, &result)
return result, err
}
// GetVersion returns the version of aria2 and the list of enabled features
func (c *Client) GetVersion(ctx context.Context) (*Version, error) {
var version Version
err := c.call(ctx, "aria2.getVersion", []any{}, &version)
return &version, err
}
// GetSessionInfo returns session information
func (c *Client) GetSessionInfo(ctx context.Context) (map[string]any, error) {
var info map[string]any
err := c.call(ctx, "aria2.getSessionInfo", []any{}, &info)
return info, err
}
// Shutdown shuts down aria2
func (c *Client) Shutdown(ctx context.Context) (string, error) {
var result string
err := c.call(ctx, "aria2.shutdown", []any{}, &result)
return result, err
}
// ForceShutdown shuts down aria2 forcefully
func (c *Client) ForceShutdown(ctx context.Context) (string, error) {
var result string
err := c.call(ctx, "aria2.forceShutdown", []any{}, &result)
return result, err
}
// SaveSession saves the current session to a file
func (c *Client) SaveSession(ctx context.Context) (string, error) {
var result string
err := c.call(ctx, "aria2.saveSession", []any{}, &result)
return result, err
}
// MultiCall executes multiple method calls in a single request (system.multicall)
func (c *Client) MultiCall(ctx context.Context, calls []map[string]any) ([]any, error) {
var results []any
err := c.call(ctx, "system.multicall", []any{calls}, &results)
return results, err
}
// ListMethods lists all available RPC methods
func (c *Client) ListMethods(ctx context.Context) ([]string, error) {
var methods []string
err := c.call(ctx, "system.listMethods", []any{}, &methods)
return methods, err
}
// ListNotifications lists all available RPC notifications
func (c *Client) ListNotifications(ctx context.Context) ([]string, error) {
var notifications []string
err := c.call(ctx, "system.listNotifications", []any{}, &notifications)
return notifications, err
}
// IsDownloadComplete checks if the download is complete
func (s *Status) IsDownloadComplete() bool {
return s.Status == "complete"
}
// IsDownloadActive checks if the download is active
func (s *Status) IsDownloadActive() bool {
return s.Status == "active"
}
// IsDownloadWaiting checks if the download is waiting
func (s *Status) IsDownloadWaiting() bool {
return s.Status == "waiting"
}
// IsDownloadPaused checks if the download is paused
func (s *Status) IsDownloadPaused() bool {
return s.Status == "paused"
}
// IsDownloadError checks if the download has an error
func (s *Status) IsDownloadError() bool {
return s.Status == "error"
}
// IsDownloadRemoved checks if the download is removed
func (s *Status) IsDownloadRemoved() bool {
return s.Status == "removed"
}

322
pkg/aria2/client_test.go Normal file
View File

@@ -0,0 +1,322 @@
package aria2
import (
"context"
"encoding/json"
"errors"
"net/http"
"net/http/httptest"
"testing"
"time"
)
func TestNewClient(t *testing.T) {
tests := []struct {
name string
url string
secret string
wantErr bool
}{
{
name: "valid client",
url: "http://localhost:6800/jsonrpc",
secret: "test-secret",
wantErr: false,
},
{
name: "valid client without secret",
url: "http://localhost:6800/jsonrpc",
secret: "",
wantErr: false,
},
{
name: "invalid empty url",
url: "",
secret: "test-secret",
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
client, err := NewClient(tt.url, tt.secret)
if (err != nil) != tt.wantErr {
t.Errorf("NewClient() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !tt.wantErr && client == nil {
t.Error("NewClient() returned nil client")
}
})
}
}
func TestClient_AddURI(t *testing.T) {
// Create a mock server
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != "POST" {
t.Errorf("Expected POST request, got %s", r.Method)
}
var req rpcRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
t.Errorf("Failed to decode request: %v", err)
}
// Verify method
if req.Method != "aria2.addUri" {
t.Errorf("Expected method aria2.addUri, got %s", req.Method)
}
// Send response
resp := rpcResponse{
Jsonrpc: "2.0",
ID: req.ID,
Result: json.RawMessage(`"2089b05ecca3d829"`),
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(resp)
}))
defer server.Close()
client, err := NewClient(server.URL, "")
if err != nil {
t.Fatalf("Failed to create client: %v", err)
}
ctx := context.Background()
gid, err := client.AddURI(ctx, []string{"http://example.com/file.txt"}, nil)
if err != nil {
t.Fatalf("AddURI() error = %v", err)
}
if gid != "2089b05ecca3d829" {
t.Errorf("Expected gid 2089b05ecca3d829, got %s", gid)
}
}
func TestClient_TellStatus(t *testing.T) {
// Create a mock server
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var req rpcRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
t.Errorf("Failed to decode request: %v", err)
}
// Verify method
if req.Method != "aria2.tellStatus" {
t.Errorf("Expected method aria2.tellStatus, got %s", req.Method)
}
// Send response
status := Status{
GID: "2089b05ecca3d829",
Status: "active",
TotalLength: "1024000",
CompletedLength: "512000",
DownloadSpeed: "102400",
Files: []File{},
}
result, _ := json.Marshal(status)
resp := rpcResponse{
Jsonrpc: "2.0",
ID: req.ID,
Result: result,
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(resp)
}))
defer server.Close()
client, err := NewClient(server.URL, "")
if err != nil {
t.Fatalf("Failed to create client: %v", err)
}
ctx := context.Background()
status, err := client.TellStatus(ctx, "2089b05ecca3d829")
if err != nil {
t.Fatalf("TellStatus() error = %v", err)
}
if status.GID != "2089b05ecca3d829" {
t.Errorf("Expected gid 2089b05ecca3d829, got %s", status.GID)
}
if status.Status != "active" {
t.Errorf("Expected status active, got %s", status.Status)
}
if !status.IsDownloadActive() {
t.Error("Expected download to be active")
}
}
func TestClient_WithSecret(t *testing.T) {
expectedSecret := "my-secret-token"
// Create a mock server
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var req rpcRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
t.Errorf("Failed to decode request: %v", err)
}
// Verify secret token is included in params
if len(req.Params) == 0 {
t.Error("Expected params to contain secret token")
} else {
token, ok := req.Params[0].(string)
if !ok || token != "token:"+expectedSecret {
t.Errorf("Expected token:%s, got %v", expectedSecret, req.Params[0])
}
}
// Send response
version := Version{
Version: "1.36.0",
EnabledFeatures: []string{"Async DNS", "BitTorrent", "HTTP", "HTTPS"},
}
result, _ := json.Marshal(version)
resp := rpcResponse{
Jsonrpc: "2.0",
ID: req.ID,
Result: result,
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(resp)
}))
defer server.Close()
client, err := NewClient(server.URL, expectedSecret)
if err != nil {
t.Fatalf("Failed to create client: %v", err)
}
ctx := context.Background()
version, err := client.GetVersion(ctx)
if err != nil {
t.Fatalf("GetVersion() error = %v", err)
}
if version.Version != "1.36.0" {
t.Errorf("Expected version 1.36.0, got %s", version.Version)
}
}
func TestClient_ContextCancellation(t *testing.T) {
// Create a mock server that delays response
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
time.Sleep(100 * time.Millisecond)
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(rpcResponse{
Jsonrpc: "2.0",
ID: "1",
Result: json.RawMessage(`"OK"`),
})
}))
defer server.Close()
client, err := NewClient(server.URL, "")
if err != nil {
t.Fatalf("Failed to create client: %v", err)
}
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
defer cancel()
_, err = client.GetVersion(ctx)
if err == nil {
t.Error("Expected context cancellation error, got nil")
}
}
func TestClient_RPCError(t *testing.T) {
// Create a mock server that returns an error
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var req rpcRequest
json.NewDecoder(r.Body).Decode(&req)
resp := rpcResponse{
Jsonrpc: "2.0",
ID: req.ID,
Error: &rpcError{
Code: 1,
Message: "Unauthorized",
},
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(resp)
}))
defer server.Close()
client, err := NewClient(server.URL, "")
if err != nil {
t.Fatalf("Failed to create client: %v", err)
}
ctx := context.Background()
_, err = client.GetVersion(ctx)
if err == nil {
t.Error("Expected RPC error, got nil")
}
var rpcErr *rpcError
if !errors.As(err, &rpcErr) {
t.Errorf("Expected rpcError, got %T", err)
}
}
func TestStatus_DownloadStatus(t *testing.T) {
tests := []struct {
name string
status string
check func(*Status) bool
}{
{
name: "active",
status: "active",
check: (*Status).IsDownloadActive,
},
{
name: "waiting",
status: "waiting",
check: (*Status).IsDownloadWaiting,
},
{
name: "paused",
status: "paused",
check: (*Status).IsDownloadPaused,
},
{
name: "error",
status: "error",
check: (*Status).IsDownloadError,
},
{
name: "complete",
status: "complete",
check: (*Status).IsDownloadComplete,
},
{
name: "removed",
status: "removed",
check: (*Status).IsDownloadRemoved,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
s := &Status{Status: tt.status}
if !tt.check(s) {
t.Errorf("Expected status %s check to return true", tt.status)
}
})
}
}

97
pkg/aria2/example/main.go Normal file
View File

@@ -0,0 +1,97 @@
package main
import (
"context"
"fmt"
"log"
"time"
"github.com/krau/SaveAny-Bot/pkg/aria2"
)
func main() {
// Create aria2 client
client, err := aria2.NewClient("http://localhost:6800/jsonrpc", "")
if err != nil {
log.Fatal(err)
}
ctx := context.Background()
// Get aria2 version
version, err := client.GetVersion(ctx)
if err != nil {
log.Fatal(err)
}
fmt.Printf("aria2 version: %s\n", version.Version)
fmt.Printf("Enabled features: %v\n", version.EnabledFeatures)
// Add a download
uris := []string{"https://example.com/file.zip"}
options := aria2.Options{
"dir": "/downloads",
}
gid, err := client.AddURI(ctx, uris, options)
if err != nil {
log.Fatal(err)
}
fmt.Printf("Download started with GID: %s\n", gid)
// Monitor download progress
for {
status, err := client.TellStatus(ctx, gid)
if err != nil {
log.Fatal(err)
}
fmt.Printf("Status: %s, Progress: %s/%s bytes, Speed: %s bytes/s\n",
status.Status,
status.CompletedLength,
status.TotalLength,
status.DownloadSpeed,
)
if status.IsDownloadComplete() {
fmt.Println("Download completed!")
break
}
if status.IsDownloadError() {
fmt.Printf("Download error: %s\n", status.ErrorMessage)
break
}
time.Sleep(1 * time.Second)
}
// Get global statistics
stat, err := client.GetGlobalStat(ctx)
if err != nil {
log.Fatal(err)
}
fmt.Printf("Global stats - Download speed: %s, Active: %s, Waiting: %s\n",
stat.DownloadSpeed,
stat.NumActive,
stat.NumWaiting,
)
// List active downloads
activeDownloads, err := client.TellActive(ctx)
if err != nil {
log.Fatal(err)
}
fmt.Printf("Active downloads: %d\n", len(activeDownloads))
for _, download := range activeDownloads {
fmt.Printf(" GID: %s, Status: %s\n", download.GID, download.Status)
}
// Example with context timeout
ctxWithTimeout, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
_, err = client.TellStatus(ctxWithTimeout, gid)
if err != nil {
log.Printf("Request failed: %v\n", err)
}
}