fix: not pass content length when uploading in non stream mode

This commit is contained in:
krau
2025-03-26 10:22:38 +08:00
parent 347a60f1f7
commit 7ffd9891a0
7 changed files with 144 additions and 111 deletions

View File

@@ -191,7 +191,7 @@ func processTelegraph(extCtx *ext.Context, cancelCtx context.Context, task *type
if len(node.Children) != 0 {
for _, child := range node.Children {
imgs = append(imgs, GetImages(child)...)
imgs = append(imgs, getNodeImages(child)...)
}
}
@@ -265,27 +265,3 @@ func processTelegraph(extCtx *ext.Context, cancelCtx context.Context, task *type
return cancelCtx.Err()
}
}
func GetImages(node telegraph.Node) []string {
var srcs []string
var nodeElement telegraph.NodeElement
data, err := json.Marshal(node)
if err != nil {
return srcs
}
err = json.Unmarshal(data, &nodeElement)
if err != nil {
return srcs
}
if nodeElement.Tag == "img" {
if src, exists := nodeElement.Attrs["src"]; exists {
srcs = append(srcs, src)
}
}
for _, child := range nodeElement.Children {
srcs = append(srcs, GetImages(child)...)
}
return srcs
}

View File

@@ -72,7 +72,7 @@ func TestGetImgSrcs(t *testing.T) {
"https://example.com/image4.png",
}
got := GetImages(complexStructure)
got := getNodeImages(complexStructure)
if !reflect.DeepEqual(expected, got) {
t.Errorf("expected %vgot %v", expected, got)

View File

@@ -3,6 +3,7 @@ package core
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"os"
@@ -10,6 +11,7 @@ import (
"time"
"github.com/celestix/gotgproto/ext"
"github.com/celestix/telegraph-go/v2"
"github.com/gabriel-vasile/mimetype"
"github.com/gotd/td/telegram/message/entity"
"github.com/gotd/td/telegram/message/styling"
@@ -22,22 +24,33 @@ import (
)
func saveFileWithRetry(ctx context.Context, storagePath string, taskStorage storage.Storage, cacheFilePath string) error {
file, err := os.Open(cacheFilePath)
if err != nil {
return fmt.Errorf("failed to open cache file: %w", err)
}
defer file.Close()
fileStat, err := file.Stat()
if err != nil {
return fmt.Errorf("failed to get file stat: %w", err)
}
vctx := context.WithValue(ctx, types.ContextKeyContentLength, fileStat.Size())
for i := 0; i <= config.Cfg.Retry; i++ {
if err := ctx.Err(); err != nil {
if err := vctx.Err(); err != nil {
return fmt.Errorf("context canceled while saving file: %w", err)
}
file, err := os.Open(cacheFilePath)
if err != nil {
return fmt.Errorf("failed to open cache file: %w", err)
}
if err := taskStorage.Save(ctx, file, storagePath); err != nil {
defer file.Close()
if err := taskStorage.Save(vctx, file, storagePath); err != nil {
if i == config.Cfg.Retry {
return fmt.Errorf("failed to save file: %w", err)
}
common.Log.Errorf("Failed to save file: %s, retrying...", err)
select {
case <-ctx.Done():
return fmt.Errorf("context canceled during retry delay: %w", ctx.Err())
case <-vctx.Done():
return fmt.Errorf("context canceled during retry delay: %w", vctx.Err())
case <-time.After(time.Duration(i*500) * time.Millisecond):
}
continue
@@ -256,3 +269,27 @@ func NewProgressStream(writer io.Writer, size int64, callback func(bytesRead, co
interval: interval,
}
}
func getNodeImages(node telegraph.Node) []string {
var srcs []string
var nodeElement telegraph.NodeElement
data, err := json.Marshal(node)
if err != nil {
return srcs
}
err = json.Unmarshal(data, &nodeElement)
if err != nil {
return srcs
}
if nodeElement.Tag == "img" {
if src, exists := nodeElement.Attrs["src"]; exists {
srcs = append(srcs, src)
}
}
for _, child := range nodeElement.Children {
srcs = append(srcs, getNodeImages(child)...)
}
return srcs
}

View File

@@ -106,6 +106,12 @@ func (a *Alist) Save(ctx context.Context, reader io.Reader, storagePath string)
req.Header.Set("Authorization", a.token)
req.Header.Set("File-Path", url.PathEscape(storagePath))
req.Header.Set("Content-Type", "application/octet-stream")
if length := ctx.Value(types.ContextKeyContentLength); length != nil {
length, ok := length.(int64)
if ok {
req.ContentLength = length
}
}
resp, err := a.client.Do(req)
if err != nil {

View File

@@ -6,6 +6,8 @@ import (
"io"
"net/http"
"strings"
"github.com/krau/SaveAny-Bot/types"
)
type Client struct {
@@ -38,6 +40,11 @@ func (c *Client) doRequest(ctx context.Context, method, url string, body io.Read
if c.Username != "" && c.Password != "" {
req.SetBasicAuth(c.Username, c.Password)
}
if length := ctx.Value(types.ContextKeyContentLength); length != nil {
if l, ok := length.(int64); ok {
req.ContentLength = l
}
}
return c.httpClient.Do(req)
}

82
types/task.go Normal file
View File

@@ -0,0 +1,82 @@
package types
import (
"context"
"crypto/md5"
"encoding/hex"
"fmt"
"net/url"
"strings"
"time"
"github.com/gotd/td/tg"
)
type Task struct {
Ctx context.Context
Cancel context.CancelFunc
Error error
Status TaskStatus
StorageName string
StoragePath string
StartTime time.Time
File *File
FileMessageID int
FileChatID int64
IsTelegraph bool
TelegraphURL string
// to track the reply message
ReplyMessageID int
ReplyChatID int64
UserID int64
}
func (t Task) Key() string {
if t.IsTelegraph {
return hashStr(t.TelegraphURL)
}
return fmt.Sprintf("%d:%d", t.FileChatID, t.FileMessageID)
}
func (t Task) String() string {
if t.IsTelegraph {
return fmt.Sprintf("[telegraph]:%s", t.TelegraphURL)
}
return fmt.Sprintf("[%d:%d]:%s", t.FileChatID, t.FileMessageID, t.File.FileName)
}
func (t Task) FileName() string {
if t.IsTelegraph {
tgphPath := strings.Split(t.TelegraphURL, "/")[len(strings.Split(t.TelegraphURL, "/"))-1]
tgphPathUnescaped, err := url.PathUnescape(tgphPath)
if err != nil {
return tgphPath
}
return tgphPathUnescaped
}
return t.File.FileName
}
type File struct {
Location tg.InputFileLocationClass
FileSize int64
FileName string
}
func (f File) Hash() string {
locationBytes := []byte(f.Location.String())
fileSizeBytes := []byte(fmt.Sprintf("%d", f.FileSize))
fileNameBytes := []byte(f.FileName)
structBytes := append(locationBytes, fileSizeBytes...)
structBytes = append(structBytes, fileNameBytes...)
hash := md5.New()
hash.Write(structBytes)
hashBytes := hash.Sum(nil)
return hex.EncodeToString(hashBytes)
}

View File

@@ -1,20 +1,8 @@
package types
import (
"context"
"crypto/md5"
"encoding/hex"
"fmt"
"net/url"
"strings"
"time"
"github.com/gotd/td/tg"
)
type TaskStatus string
var (
const (
Pending TaskStatus = "pending"
Succeeded TaskStatus = "succeeded"
Failed TaskStatus = "failed"
@@ -23,7 +11,7 @@ var (
type StorageType string
var (
const (
StorageTypeLocal StorageType = "local"
StorageTypeWebdav StorageType = "webdav"
StorageTypeAlist StorageType = "alist"
@@ -38,71 +26,8 @@ var StorageTypeDisplay = map[StorageType]string{
StorageTypeMinio: "Minio",
}
type Task struct {
Ctx context.Context
Cancel context.CancelFunc
Error error
Status TaskStatus
StorageName string
StoragePath string
StartTime time.Time
type ContextKey string
File *File
FileMessageID int
FileChatID int64
IsTelegraph bool
TelegraphURL string
// to track the reply message
ReplyMessageID int
ReplyChatID int64
UserID int64
}
func (t Task) Key() string {
if t.IsTelegraph {
return hashStr(t.TelegraphURL)
}
return fmt.Sprintf("%d:%d", t.FileChatID, t.FileMessageID)
}
func (t Task) String() string {
if t.IsTelegraph {
return fmt.Sprintf("[telegraph]:%s", t.TelegraphURL)
}
return fmt.Sprintf("[%d:%d]:%s", t.FileChatID, t.FileMessageID, t.File.FileName)
}
func (t Task) FileName() string {
if t.IsTelegraph {
tgphPath := strings.Split(t.TelegraphURL, "/")[len(strings.Split(t.TelegraphURL, "/"))-1]
tgphPathUnescaped, err := url.PathUnescape(tgphPath)
if err != nil {
return tgphPath
}
return tgphPathUnescaped
}
return t.File.FileName
}
type File struct {
Location tg.InputFileLocationClass
FileSize int64
FileName string
}
func (f File) Hash() string {
locationBytes := []byte(f.Location.String())
fileSizeBytes := []byte(fmt.Sprintf("%d", f.FileSize))
fileNameBytes := []byte(f.FileName)
structBytes := append(locationBytes, fileSizeBytes...)
structBytes = append(structBytes, fileNameBytes...)
hash := md5.New()
hash.Write(structBytes)
hashBytes := hash.Sum(nil)
return hex.EncodeToString(hashBytes)
}
const (
ContextKeyContentLength ContextKey = "content-length"
)