Compare commits
3 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
300f7723af | ||
|
|
491ba55f1e | ||
|
|
32519b8c08 |
@@ -59,41 +59,50 @@ func processPendingTask(task *types.Task) error {
|
||||
|
||||
downloadBuilder := Downloader.Download(bot.Client.API(), task.File.Location).WithThreads(getTaskThreads(task.File.FileSize))
|
||||
|
||||
notsupportStreamStorage, notsupportStream := taskStorage.(storage.StorageNotSupportStream)
|
||||
cancelMarkUp := getCancelTaskMarkup(task)
|
||||
if config.Cfg.Stream {
|
||||
if !notsupportStream {
|
||||
text, entities := buildProgressMessageEntity(task, 0, task.StartTime, 0)
|
||||
ctx.EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{
|
||||
Message: text,
|
||||
Entities: entities,
|
||||
ID: task.ReplyMessageID,
|
||||
ReplyMarkup: cancelMarkUp,
|
||||
})
|
||||
|
||||
text, entities := buildProgressMessageEntity(task, 0, task.StartTime, 0)
|
||||
ctx.EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{
|
||||
Message: text,
|
||||
Entities: entities,
|
||||
ID: task.ReplyMessageID,
|
||||
ReplyMarkup: getCancelTaskMarkup(task),
|
||||
})
|
||||
pr, pw := io.Pipe()
|
||||
defer pr.Close()
|
||||
|
||||
pr, pw := io.Pipe()
|
||||
defer pr.Close()
|
||||
task.StartTime = time.Now()
|
||||
progressCallback := buildProgressCallback(ctx, task, getProgressUpdateCount(task.File.FileSize))
|
||||
|
||||
task.StartTime = time.Now()
|
||||
progressCallback := buildProgressCallback(ctx, task, getProgressUpdateCount(task.File.FileSize))
|
||||
progressStream := NewProgressStream(pw, task.File.FileSize, progressCallback)
|
||||
|
||||
progressStream := NewProgressStream(pw, task.File.FileSize, progressCallback)
|
||||
eg, uploadCtx := errgroup.WithContext(cancelCtx)
|
||||
|
||||
eg, uploadCtx := errgroup.WithContext(cancelCtx)
|
||||
|
||||
eg.Go(func() error {
|
||||
return taskStorage.Save(uploadCtx, pr, task.StoragePath)
|
||||
})
|
||||
eg.Go(func() error {
|
||||
_, err := downloadBuilder.Stream(uploadCtx, progressStream)
|
||||
if closeErr := pw.CloseWithError(err); closeErr != nil {
|
||||
common.Log.Errorf("Failed to close pipe writer: %v", closeErr)
|
||||
eg.Go(func() error {
|
||||
return taskStorage.Save(uploadCtx, pr, task.StoragePath)
|
||||
})
|
||||
eg.Go(func() error {
|
||||
_, err := downloadBuilder.Stream(uploadCtx, progressStream)
|
||||
if closeErr := pw.CloseWithError(err); closeErr != nil {
|
||||
common.Log.Errorf("Failed to close pipe writer: %v", closeErr)
|
||||
}
|
||||
return err
|
||||
})
|
||||
if err := eg.Wait(); err != nil {
|
||||
return err
|
||||
}
|
||||
return err
|
||||
})
|
||||
if err := eg.Wait(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
return nil
|
||||
}
|
||||
common.Log.Warnf("存储 %s 不支持流式传输: %s", task.StorageName, notsupportStreamStorage.NotSupportStream())
|
||||
ctx.EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{
|
||||
Message: fmt.Sprintf("存储 %s 不支持流式传输: %s\n正在使用普通下载...", task.StorageName, notsupportStreamStorage.NotSupportStream()),
|
||||
ID: task.ReplyMessageID,
|
||||
ReplyMarkup: cancelMarkUp,
|
||||
})
|
||||
}
|
||||
|
||||
cacheDestPath := filepath.Join(config.Cfg.Temp.BasePath, task.FileName())
|
||||
@@ -110,7 +119,7 @@ func processPendingTask(task *types.Task) error {
|
||||
Message: text,
|
||||
Entities: entities,
|
||||
ID: task.ReplyMessageID,
|
||||
ReplyMarkup: getCancelTaskMarkup(task),
|
||||
ReplyMarkup: cancelMarkUp,
|
||||
})
|
||||
|
||||
progressCallback := buildProgressCallback(ctx, task, getProgressUpdateCount(task.File.FileSize))
|
||||
|
||||
@@ -35,3 +35,7 @@ Bot 接受两种消息: 文件和链接.
|
||||
- 无法使用多线程从 telegram 下载文件, 速度较慢.
|
||||
- 网络不稳定时, 任务失败率高.
|
||||
- 无法在中间层对文件进行处理, 例如自动文件类型识别.
|
||||
|
||||
**不支持** Stream 模式的存储端:
|
||||
|
||||
- alist
|
||||
|
||||
@@ -140,6 +140,10 @@ func (a *Alist) Save(ctx context.Context, reader io.Reader, storagePath string)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Alist) NotSupportStream() string {
|
||||
return "Alist does not support chunked transfer encoding"
|
||||
}
|
||||
|
||||
func (a *Alist) JoinStoragePath(task types.Task) string {
|
||||
return path.Join(a.config.BasePath, task.StoragePath)
|
||||
}
|
||||
|
||||
@@ -23,6 +23,11 @@ type Storage interface {
|
||||
Save(ctx context.Context, reader io.Reader, storagePath string) error
|
||||
}
|
||||
|
||||
type StorageNotSupportStream interface {
|
||||
Storage
|
||||
NotSupportStream() string
|
||||
}
|
||||
|
||||
var Storages = make(map[string]Storage)
|
||||
|
||||
var UserStorages = make(map[int64][]Storage)
|
||||
|
||||
130
storage/webdav/client._test.go
Normal file
130
storage/webdav/client._test.go
Normal file
@@ -0,0 +1,130 @@
|
||||
package webdav
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"golang.org/x/net/webdav"
|
||||
)
|
||||
|
||||
func setupWebDAVServer(t *testing.T) (*httptest.Server, string) {
|
||||
t.Helper()
|
||||
tempDir, err := os.MkdirTemp("", "webdav_test")
|
||||
if err != nil {
|
||||
t.Fatalf("mk temp dir failed: %v", err)
|
||||
}
|
||||
|
||||
handler := &webdav.Handler{
|
||||
Prefix: "/",
|
||||
FileSystem: webdav.Dir(tempDir),
|
||||
LockSystem: webdav.NewMemLS(),
|
||||
}
|
||||
|
||||
server := httptest.NewServer(handler)
|
||||
return server, tempDir
|
||||
}
|
||||
|
||||
func TestMkDirAndExists(t *testing.T) {
|
||||
server, tempDir := setupWebDAVServer(t)
|
||||
defer os.RemoveAll(tempDir)
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient(server.URL, "", "", nil)
|
||||
ctx := context.Background()
|
||||
|
||||
testpaths := []string{"testdir", "testdir/subdir", "testdir/子目录", "/testdir/测试路径/测试路径2"}
|
||||
for _, p := range testpaths {
|
||||
exists, err := client.Exists(ctx, p)
|
||||
if err != nil {
|
||||
t.Fatalf("Call Exists Err: %v", err)
|
||||
}
|
||||
if exists {
|
||||
t.Fatalf("Dir should not exist")
|
||||
}
|
||||
|
||||
if err := client.MkDir(ctx, p); err != nil {
|
||||
t.Fatalf("Call MkDir Err: %v", err)
|
||||
}
|
||||
|
||||
exists, err = client.Exists(ctx, p)
|
||||
if err != nil {
|
||||
t.Fatalf("Call Exists Err: %v", err)
|
||||
}
|
||||
if !exists {
|
||||
t.Fatalf("Dir should exist")
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestWriteFile(t *testing.T) {
|
||||
server, tempDir := setupWebDAVServer(t)
|
||||
defer os.RemoveAll(tempDir)
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient(server.URL, "", "", nil)
|
||||
ctx := context.Background()
|
||||
|
||||
testCases := []struct {
|
||||
remotePath string
|
||||
content string
|
||||
}{
|
||||
{
|
||||
remotePath: "hello.txt",
|
||||
content: "Hello webdav",
|
||||
},
|
||||
{
|
||||
remotePath: "nested/dir/test.txt",
|
||||
content: "Nested file",
|
||||
},
|
||||
{
|
||||
remotePath: "empty.txt",
|
||||
content: "",
|
||||
},
|
||||
{
|
||||
remotePath: "unicode.txt",
|
||||
content: "测试",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.remotePath, func(t *testing.T) {
|
||||
dir := path.Dir(tc.remotePath)
|
||||
if dir != "." {
|
||||
if err := client.MkDir(ctx, dir); err != nil {
|
||||
t.Fatalf("创建目录 %s 失败: %v", dir, err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := client.WriteFile(ctx, tc.remotePath, strings.NewReader(tc.content)); err != nil {
|
||||
t.Fatalf("写入文件 %s 失败: %v", tc.remotePath, err)
|
||||
}
|
||||
|
||||
localPath := filepath.Join(tempDir, tc.remotePath)
|
||||
data, err := os.ReadFile(localPath)
|
||||
if err != nil {
|
||||
t.Fatalf("读取文件 %s 失败: %v", localPath, err)
|
||||
}
|
||||
if string(data) != tc.content {
|
||||
t.Fatalf("文件内容不匹配: got %s, want %s", string(data), tc.content)
|
||||
}
|
||||
|
||||
appended := tc.content + " Overwritten."
|
||||
if err := client.WriteFile(ctx, tc.remotePath, strings.NewReader(appended)); err != nil {
|
||||
t.Fatalf("覆盖写入文件 %s 失败: %v", tc.remotePath, err)
|
||||
}
|
||||
data, err = os.ReadFile(localPath)
|
||||
if err != nil {
|
||||
t.Fatalf("读取覆盖后的文件 %s 失败: %v", localPath, err)
|
||||
}
|
||||
if string(data) != appended {
|
||||
t.Fatalf("文件覆盖后的内容不匹配: got %s, want %s", string(data), appended)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -48,18 +48,55 @@ func (c *Client) doRequest(ctx context.Context, method, url string, body io.Read
|
||||
return c.httpClient.Do(req)
|
||||
}
|
||||
|
||||
func (c *Client) MkDir(ctx context.Context, dirPath string) error {
|
||||
url := c.BaseURL + dirPath
|
||||
resp, err := c.doRequest(ctx, "MKCOL", url, nil)
|
||||
func (c *Client) Exists(ctx context.Context, remotePath string) (bool, error) {
|
||||
url := c.BaseURL + remotePath
|
||||
resp, err := c.doRequest(ctx, "PROPFIND", url, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
return false, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode >= 200 && resp.StatusCode < 300 {
|
||||
return true, nil
|
||||
}
|
||||
if resp.StatusCode == http.StatusNotFound {
|
||||
return false, nil
|
||||
}
|
||||
return false, fmt.Errorf("PROPFIND: %s", resp.Status)
|
||||
}
|
||||
|
||||
func (c *Client) MkDir(ctx context.Context, dirPath string) error {
|
||||
dirPath = strings.Trim(dirPath, "/")
|
||||
if dirPath == "" {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("MKCOL: %s", resp.Status)
|
||||
parts := strings.Split(dirPath, "/")
|
||||
currentPath := ""
|
||||
for i, part := range parts {
|
||||
if i > 0 {
|
||||
currentPath += "/"
|
||||
}
|
||||
currentPath += part
|
||||
|
||||
exists, err := c.Exists(ctx, currentPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if exists {
|
||||
continue
|
||||
}
|
||||
url := c.BaseURL + currentPath
|
||||
resp, err := c.doRequest(ctx, "MKCOL", url, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
resp.Body.Close()
|
||||
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
return fmt.Errorf("MKCOL %s: %s", currentPath, resp.Status)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) WriteFile(ctx context.Context, remotePath string, content io.Reader) error {
|
||||
|
||||
Reference in New Issue
Block a user