From 7ffd9891a0c0d3ad8982994962d458b912eab044 Mon Sep 17 00:00:00 2001 From: krau <71133316+krau@users.noreply.github.com> Date: Wed, 26 Mar 2025 10:22:38 +0800 Subject: [PATCH] fix: not pass content length when uploading in non stream mode --- core/download.go | 26 +----------- core/download_test.go | 2 +- core/utils.go | 45 +++++++++++++++++++-- storage/alist/alist.go | 6 +++ storage/webdav/client.go | 7 ++++ types/task.go | 82 +++++++++++++++++++++++++++++++++++++ types/types.go | 87 +++------------------------------------- 7 files changed, 144 insertions(+), 111 deletions(-) create mode 100644 types/task.go diff --git a/core/download.go b/core/download.go index 1c0384d..1acf8d0 100644 --- a/core/download.go +++ b/core/download.go @@ -191,7 +191,7 @@ func processTelegraph(extCtx *ext.Context, cancelCtx context.Context, task *type if len(node.Children) != 0 { for _, child := range node.Children { - imgs = append(imgs, GetImages(child)...) + imgs = append(imgs, getNodeImages(child)...) } } @@ -265,27 +265,3 @@ func processTelegraph(extCtx *ext.Context, cancelCtx context.Context, task *type return cancelCtx.Err() } } - -func GetImages(node telegraph.Node) []string { - var srcs []string - - var nodeElement telegraph.NodeElement - data, err := json.Marshal(node) - if err != nil { - return srcs - } - err = json.Unmarshal(data, &nodeElement) - if err != nil { - return srcs - } - - if nodeElement.Tag == "img" { - if src, exists := nodeElement.Attrs["src"]; exists { - srcs = append(srcs, src) - } - } - for _, child := range nodeElement.Children { - srcs = append(srcs, GetImages(child)...) - } - return srcs -} diff --git a/core/download_test.go b/core/download_test.go index bb2c28a..f0c6444 100644 --- a/core/download_test.go +++ b/core/download_test.go @@ -72,7 +72,7 @@ func TestGetImgSrcs(t *testing.T) { "https://example.com/image4.png", } - got := GetImages(complexStructure) + got := getNodeImages(complexStructure) if !reflect.DeepEqual(expected, got) { t.Errorf("expected %v,got %v", expected, got) diff --git a/core/utils.go b/core/utils.go index 4d724e9..98ae77d 100644 --- a/core/utils.go +++ b/core/utils.go @@ -3,6 +3,7 @@ package core import ( "bytes" "context" + "encoding/json" "fmt" "io" "os" @@ -10,6 +11,7 @@ import ( "time" "github.com/celestix/gotgproto/ext" + "github.com/celestix/telegraph-go/v2" "github.com/gabriel-vasile/mimetype" "github.com/gotd/td/telegram/message/entity" "github.com/gotd/td/telegram/message/styling" @@ -22,22 +24,33 @@ import ( ) func saveFileWithRetry(ctx context.Context, storagePath string, taskStorage storage.Storage, cacheFilePath string) error { + file, err := os.Open(cacheFilePath) + if err != nil { + return fmt.Errorf("failed to open cache file: %w", err) + } + defer file.Close() + fileStat, err := file.Stat() + if err != nil { + return fmt.Errorf("failed to get file stat: %w", err) + } + vctx := context.WithValue(ctx, types.ContextKeyContentLength, fileStat.Size()) for i := 0; i <= config.Cfg.Retry; i++ { - if err := ctx.Err(); err != nil { + if err := vctx.Err(); err != nil { return fmt.Errorf("context canceled while saving file: %w", err) } file, err := os.Open(cacheFilePath) if err != nil { return fmt.Errorf("failed to open cache file: %w", err) } - if err := taskStorage.Save(ctx, file, storagePath); err != nil { + defer file.Close() + if err := taskStorage.Save(vctx, file, storagePath); err != nil { if i == config.Cfg.Retry { return fmt.Errorf("failed to save file: %w", err) } common.Log.Errorf("Failed to save file: %s, retrying...", err) select { - case <-ctx.Done(): - return fmt.Errorf("context canceled during retry delay: %w", ctx.Err()) + case <-vctx.Done(): + return fmt.Errorf("context canceled during retry delay: %w", vctx.Err()) case <-time.After(time.Duration(i*500) * time.Millisecond): } continue @@ -256,3 +269,27 @@ func NewProgressStream(writer io.Writer, size int64, callback func(bytesRead, co interval: interval, } } + +func getNodeImages(node telegraph.Node) []string { + var srcs []string + + var nodeElement telegraph.NodeElement + data, err := json.Marshal(node) + if err != nil { + return srcs + } + err = json.Unmarshal(data, &nodeElement) + if err != nil { + return srcs + } + + if nodeElement.Tag == "img" { + if src, exists := nodeElement.Attrs["src"]; exists { + srcs = append(srcs, src) + } + } + for _, child := range nodeElement.Children { + srcs = append(srcs, getNodeImages(child)...) + } + return srcs +} diff --git a/storage/alist/alist.go b/storage/alist/alist.go index a2ca632..f56f8df 100644 --- a/storage/alist/alist.go +++ b/storage/alist/alist.go @@ -106,6 +106,12 @@ func (a *Alist) Save(ctx context.Context, reader io.Reader, storagePath string) req.Header.Set("Authorization", a.token) req.Header.Set("File-Path", url.PathEscape(storagePath)) req.Header.Set("Content-Type", "application/octet-stream") + if length := ctx.Value(types.ContextKeyContentLength); length != nil { + length, ok := length.(int64) + if ok { + req.ContentLength = length + } + } resp, err := a.client.Do(req) if err != nil { diff --git a/storage/webdav/client.go b/storage/webdav/client.go index 8092603..b4ddcc1 100644 --- a/storage/webdav/client.go +++ b/storage/webdav/client.go @@ -6,6 +6,8 @@ import ( "io" "net/http" "strings" + + "github.com/krau/SaveAny-Bot/types" ) type Client struct { @@ -38,6 +40,11 @@ func (c *Client) doRequest(ctx context.Context, method, url string, body io.Read if c.Username != "" && c.Password != "" { req.SetBasicAuth(c.Username, c.Password) } + if length := ctx.Value(types.ContextKeyContentLength); length != nil { + if l, ok := length.(int64); ok { + req.ContentLength = l + } + } return c.httpClient.Do(req) } diff --git a/types/task.go b/types/task.go new file mode 100644 index 0000000..7b554cd --- /dev/null +++ b/types/task.go @@ -0,0 +1,82 @@ +package types + +import ( + "context" + "crypto/md5" + "encoding/hex" + "fmt" + "net/url" + "strings" + "time" + + "github.com/gotd/td/tg" +) + +type Task struct { + Ctx context.Context + Cancel context.CancelFunc + Error error + Status TaskStatus + StorageName string + StoragePath string + StartTime time.Time + + File *File + FileMessageID int + FileChatID int64 + + IsTelegraph bool + TelegraphURL string + + // to track the reply message + ReplyMessageID int + ReplyChatID int64 + UserID int64 +} + +func (t Task) Key() string { + if t.IsTelegraph { + return hashStr(t.TelegraphURL) + } + return fmt.Sprintf("%d:%d", t.FileChatID, t.FileMessageID) +} + +func (t Task) String() string { + if t.IsTelegraph { + return fmt.Sprintf("[telegraph]:%s", t.TelegraphURL) + } + return fmt.Sprintf("[%d:%d]:%s", t.FileChatID, t.FileMessageID, t.File.FileName) +} + +func (t Task) FileName() string { + if t.IsTelegraph { + tgphPath := strings.Split(t.TelegraphURL, "/")[len(strings.Split(t.TelegraphURL, "/"))-1] + tgphPathUnescaped, err := url.PathUnescape(tgphPath) + if err != nil { + return tgphPath + } + return tgphPathUnescaped + } + return t.File.FileName +} + +type File struct { + Location tg.InputFileLocationClass + FileSize int64 + FileName string +} + +func (f File) Hash() string { + locationBytes := []byte(f.Location.String()) + fileSizeBytes := []byte(fmt.Sprintf("%d", f.FileSize)) + fileNameBytes := []byte(f.FileName) + + structBytes := append(locationBytes, fileSizeBytes...) + structBytes = append(structBytes, fileNameBytes...) + + hash := md5.New() + hash.Write(structBytes) + hashBytes := hash.Sum(nil) + + return hex.EncodeToString(hashBytes) +} diff --git a/types/types.go b/types/types.go index 5f06acc..25d4d65 100644 --- a/types/types.go +++ b/types/types.go @@ -1,20 +1,8 @@ package types -import ( - "context" - "crypto/md5" - "encoding/hex" - "fmt" - "net/url" - "strings" - "time" - - "github.com/gotd/td/tg" -) - type TaskStatus string -var ( +const ( Pending TaskStatus = "pending" Succeeded TaskStatus = "succeeded" Failed TaskStatus = "failed" @@ -23,7 +11,7 @@ var ( type StorageType string -var ( +const ( StorageTypeLocal StorageType = "local" StorageTypeWebdav StorageType = "webdav" StorageTypeAlist StorageType = "alist" @@ -38,71 +26,8 @@ var StorageTypeDisplay = map[StorageType]string{ StorageTypeMinio: "Minio", } -type Task struct { - Ctx context.Context - Cancel context.CancelFunc - Error error - Status TaskStatus - StorageName string - StoragePath string - StartTime time.Time +type ContextKey string - File *File - FileMessageID int - FileChatID int64 - - IsTelegraph bool - TelegraphURL string - - // to track the reply message - ReplyMessageID int - ReplyChatID int64 - UserID int64 -} - -func (t Task) Key() string { - if t.IsTelegraph { - return hashStr(t.TelegraphURL) - } - return fmt.Sprintf("%d:%d", t.FileChatID, t.FileMessageID) -} - -func (t Task) String() string { - if t.IsTelegraph { - return fmt.Sprintf("[telegraph]:%s", t.TelegraphURL) - } - return fmt.Sprintf("[%d:%d]:%s", t.FileChatID, t.FileMessageID, t.File.FileName) -} - -func (t Task) FileName() string { - if t.IsTelegraph { - tgphPath := strings.Split(t.TelegraphURL, "/")[len(strings.Split(t.TelegraphURL, "/"))-1] - tgphPathUnescaped, err := url.PathUnescape(tgphPath) - if err != nil { - return tgphPath - } - return tgphPathUnescaped - } - return t.File.FileName -} - -type File struct { - Location tg.InputFileLocationClass - FileSize int64 - FileName string -} - -func (f File) Hash() string { - locationBytes := []byte(f.Location.String()) - fileSizeBytes := []byte(fmt.Sprintf("%d", f.FileSize)) - fileNameBytes := []byte(f.FileName) - - structBytes := append(locationBytes, fileSizeBytes...) - structBytes = append(structBytes, fileNameBytes...) - - hash := md5.New() - hash.Write(structBytes) - hashBytes := hash.Sum(nil) - - return hex.EncodeToString(hashBytes) -} +const ( + ContextKeyContentLength ContextKey = "content-length" +)