// splitWriter is an io.WriteCloser that spreads the byte stream written to it
// across numbered part files (<baseName>.zip.001, <baseName>.zip.002, ...),
// starting a new file whenever the current one has reached partSize bytes.
type splitWriter struct {
	baseName    string   // output path without the ".zip.NNN" suffix
	partSize    int64    // maximum number of bytes per part file
	currentPart int      // number of part files created so far
	currentSize int64    // bytes already written to currentFile
	currentFile *os.File // part being written; nil before the first write and after Close
	totalParts  int      // number of parts produced, recorded by finalize
	closed      bool     // set by Close; later writes are rejected
}

// newSplitWriter returns a splitWriter that creates its part files next to
// baseName, each holding at most partSize bytes. No file is created until the
// first non-empty Write.
func newSplitWriter(baseName string, partSize int64) *splitWriter {
	return &splitWriter{
		baseName: baseName,
		partSize: partSize,
	}
}

// Write implements io.Writer. It fills the current part up to partSize and
// rolls over to the next part file as often as needed to consume p.
func (w *splitWriter) Write(p []byte) (n int, err error) {
	if w.closed {
		return 0, os.ErrClosed
	}
	// A non-positive part size can never make progress: every iteration would
	// open a fresh, empty part file forever.
	if w.partSize <= 0 {
		return 0, fmt.Errorf("invalid part size: %d", w.partSize)
	}
	for n < len(p) {
		if w.currentFile == nil || w.currentSize >= w.partSize {
			if err := w.nextPart(); err != nil {
				return n, err
			}
		}

		// Never write past the end of the current part.
		chunk := p[n:]
		if room := w.partSize - w.currentSize; int64(len(chunk)) > room {
			chunk = chunk[:room]
		}

		nw, err := w.currentFile.Write(chunk)
		n += nw
		w.currentSize += int64(nw)
		if err != nil {
			return n, err
		}
	}
	return n, nil
}

// Close closes the part file that is currently open, if any. It is safe to
// call more than once; only the first call can return an error.
func (w *splitWriter) Close() error {
	w.closed = true
	if w.currentFile == nil {
		return nil
	}
	err := w.currentFile.Close()
	w.currentFile = nil
	return err
}

// nextPart closes the current part (if any) and creates the next part file.
func (w *splitWriter) nextPart() error {
	if w.currentFile != nil {
		err := w.currentFile.Close()
		w.currentFile = nil
		if err != nil {
			return err
		}
	}

	file, err := os.Create(w.partName(w.currentPart))
	if err != nil {
		return err
	}

	w.currentFile = file
	w.currentSize = 0
	w.currentPart++
	return nil
}

// partName returns the path of the part with the given zero-based index:
// file.zip.001, file.zip.002, ...
func (w *splitWriter) partName(partNum int) string {
	return fmt.Sprintf("%s.zip.%03d", w.baseName, partNum+1)
}

// finalize records the number of parts written. When everything fit into a
// single part, that part is renamed from <baseName>.zip.001 to <baseName>.zip.
func (w *splitWriter) finalize() error {
	w.totalParts = w.currentPart
	if w.totalParts == 1 {
		return os.Rename(w.partName(0), w.baseName+".zip")
	}
	return nil
}

// ctxReadFunc adapts a function to io.Reader.
type ctxReadFunc func(p []byte) (int, error)

// Read implements io.Reader.
func (f ctxReadFunc) Read(p []byte) (int, error) { return f(p) }

// CreateSplitZip stores the content of reader, uncompressed, as the single
// entry fileName of a zip archive and writes that archive as a sequence of
// part files <outputBase>.zip.001, <outputBase>.zip.002, ... of at most
// partSize bytes each. If the archive fits into one part, the result is a
// single <outputBase>.zip file instead.
//
// If reader is an io.ReadSeeker it is rewound to the start first. size is the
// number of bytes reader is expected to yield; a different amount is reported
// as an error. The directory of outputBase is created when missing. Copying
// stops with the context's error once ctx is cancelled. On error, part files
// written so far are left in place for the caller to clean up.
func CreateSplitZip(ctx context.Context, reader io.Reader, size int64, fileName, outputBase string, partSize int64) error {
	if ctx == nil {
		// Earlier versions ignored ctx entirely, so keep tolerating nil.
		ctx = context.Background()
	}
	if partSize <= 0 {
		return fmt.Errorf("invalid part size: %d", partSize)
	}
	// seek the reader if possible
	if rs, ok := reader.(io.ReadSeeker); ok {
		if _, err := rs.Seek(0, io.SeekStart); err != nil {
			return fmt.Errorf("failed to seek reader: %w", err)
		}
	}
	if err := os.MkdirAll(filepath.Dir(outputBase), os.ModePerm); err != nil {
		return fmt.Errorf("failed to create output directory: %w", err)
	}

	parts := newSplitWriter(outputBase, partSize)
	// Safety net for the early returns below. Close is idempotent and the
	// success path checks its error explicitly, so the result is ignored here.
	defer parts.Close()

	archive := zip.NewWriter(parts)
	entry, err := archive.CreateHeader(&zip.FileHeader{
		Name:     fileName,
		Method:   zip.Store, // just store without compression
		Modified: time.Now(),
	})
	if err != nil {
		return fmt.Errorf("failed to create zip header: %w", err)
	}

	// Check for cancellation before every read so that an aborted task does
	// not keep copying a multi-gigabyte file to disk.
	src := ctxReadFunc(func(p []byte) (int, error) {
		if err := ctx.Err(); err != nil {
			return 0, err
		}
		return reader.Read(p)
	})
	copied, err := io.Copy(entry, src)
	if err != nil {
		return fmt.Errorf("failed to write data: %w", err)
	}
	if copied != size {
		return fmt.Errorf("incomplete write: expected %d bytes, got %d bytes", size, copied)
	}
	if err := archive.Close(); err != nil {
		return fmt.Errorf("failed to close zip writer: %w", err)
	}
	if err := parts.Close(); err != nil {
		return fmt.Errorf("failed to close split writer: %w", err)
	}
	if err := parts.finalize(); err != nil {
		return fmt.Errorf("failed to rename split files: %w", err)
	}
	return nil
}
b/storage/telegram/split_test.go @@ -0,0 +1,55 @@ +package telegram + +import ( + "os" + "path/filepath" + "testing" +) + +func TestCreateSplitZip(t *testing.T) { + input := "tests/testfile.dat" + file, err := os.Open(input) + if err != nil { + t.Fatalf("failed to open test file: %v", err) + } + defer file.Close() + fileName := filepath.Base(input) + fileInfo, err := file.Stat() + if err != nil { + t.Fatalf("failed to stat test file: %v", err) + } + fileSize := fileInfo.Size() + + tests := []struct { + partSize int64 + output string + }{ + {partSize: int64(1024 * 1024 * 500), output: "tests/split_test_output_500MB"}, + {partSize: int64(1024 * 1024 * 100), output: "tests/split_test_output_100MB"}, + } + + for _, tt := range tests { + err = CreateSplitZip(t.Context(), file, fileSize, fileName, tt.output, tt.partSize) + if err != nil { + t.Fatalf("CreateSplitZip failed: %v", err) + } + matched, err := filepath.Glob(tt.output + ".z*") + if err != nil { + t.Fatalf("failed to glob split files: %v", err) + } + if len(matched) == 0 { + t.Fatalf("no split files found") + } + t.Logf("Created %d split files", len(matched)) + for _, f := range matched { + info, err := os.Stat(f) + if err != nil { + t.Fatalf("failed to stat file %s: %v", f, err) + } + if info.Size() > tt.partSize { + t.Errorf("file %s exceeds part size: %d > %d", f, info.Size(), tt.partSize) + } + t.Logf(" - %s (%d bytes)", f, info.Size()) + } + } +} diff --git a/storage/telegram/telegram.go b/storage/telegram/telegram.go index 7c0ad79..30a3bfc 100644 --- a/storage/telegram/telegram.go +++ b/storage/telegram/telegram.go @@ -4,10 +4,13 @@ import ( "context" "fmt" "io" + "os" "path" + "path/filepath" "strings" "time" + "github.com/celestix/gotgproto/ext" "github.com/charmbracelet/log" "github.com/duke-git/lancet/v2/slice" "github.com/duke-git/lancet/v2/validator" @@ -16,6 +19,7 @@ import ( "github.com/gotd/td/telegram/message/styling" "github.com/gotd/td/telegram/uploader" "github.com/gotd/td/tg" + 
// Size thresholds, in bytes, used by the Telegram storage. They are typed as
// int64 so that passing them to formatting functions does not overflow int on
// 32-bit targets.
const (
	// DefaultSplitSize is the part size used when split_size_mb is not set:
	// 2000 MiB, matching the documented default of that option. Telegram's
	// file API documents a 4000-part limit of 512 KiB parts for regular
	// accounts, i.e. 2000 MiB, so a 2048 MiB part could not be uploaded.
	DefaultSplitSize int64 = 2000 * 1024 * 1024
	// MaxUploadFileSize is the size above which a file is skipped when
	// skip_large is enabled: 2 GiB.
	MaxUploadFileSize int64 = 2 * 1024 * 1024 * 1024
)
io.Reader, storagePath string) er } upler := uploader.NewUploader(tctx.Raw). WithPartSize(tglimit.MaxUploadPartSize). - WithThreads(config.C().Threads) + WithThreads(dlutil.BestThreads(size, config.C().Threads)) + if size > splitSize { + // large file, use split uploader + return t.splitUpload(tctx, rs, filename, upler, peer, size, splitSize) + } var file tg.InputFileClass - size := func() int64 { - if length := ctx.Value(ctxkey.ContentLength); length != nil { - if l, ok := length.(int64); ok { - return l - } - } - return -1 // unknown size - }() if size < 0 { file, err = upler.FromReader(ctx, filename, rs) } else { @@ -186,3 +208,91 @@ func (t *Telegram) Save(ctx context.Context, r io.Reader, storagePath string) er func (t *Telegram) CannotStream() string { return "Telegram storage must use a ReaderSeeker" } + +func (t *Telegram) splitUpload(ctx *ext.Context, rs io.ReadSeeker, filename string, upler *uploader.Uploader, peer tg.InputPeerClass, fileSize, splitSize int64) error { + tempId := xid.New().String() + outputBase := filepath.Join(config.C().Temp.BasePath, tempId, strings.Split(filename, ".")[0]) + defer func() { + // cleanup temp files + if err := os.RemoveAll(filepath.Join(config.C().Temp.BasePath, tempId)); err != nil { + log.FromContext(ctx).Warnf("Failed to cleanup temp split files: %s", err) + } + }() + if err := CreateSplitZip(ctx, rs, fileSize, filename, outputBase, splitSize); err != nil { + return fmt.Errorf("failed to create split zip: %w", err) + } + matched, err := filepath.Glob(outputBase + ".z*") + if err != nil { + return fmt.Errorf("failed to glob split files: %w", err) + } + inputFiles := make([]tg.InputFileClass, 0, len(matched)) + for _, partPath := range matched { + // 串行上传, 不然容易被tg风控 + err = func() error { + partFile, err := os.Open(partPath) + if err != nil { + return fmt.Errorf("failed to open split part %s: %w", partPath, err) + } + defer partFile.Close() + partInfo, err := partFile.Stat() + if err != nil { + return 
fmt.Errorf("failed to stat split part %s: %w", partPath, err) + } + partFileSize := partInfo.Size() + partName := filepath.Base(partPath) + partInputFile, err := upler.Upload(ctx, uploader.NewUpload(partName, partFile, partFileSize)) + if err != nil { + return fmt.Errorf("failed to upload split part %s: %w", partPath, err) + } + inputFiles = append(inputFiles, partInputFile) + return nil + }() + if err != nil { + return fmt.Errorf("failed to upload split part %s: %w", partPath, err) + } + } + if len(inputFiles) == 1 { + // only one part, send as normal file + // shoud not happen as we already check fileSize > splitSize + doc := message.UploadedDocument(inputFiles[0]). + Filename(filepath.Base(matched[0])). + ForceFile(true). + MIME("application/zip") + _, err = ctx.Sender. + WithUploader(upler). + To(peer). + Media(ctx, doc) + return err + } + + multiMedia := make([]message.MultiMediaOption, 0, len(inputFiles)) + for i, inputFile := range inputFiles { + doc := message.UploadedDocument(inputFile). + Filename(filepath.Base(matched[i])). + MIME("application/zip") + multiMedia = append(multiMedia, doc) + } + + sender := ctx.Sender + + if len(multiMedia) <= 10 { + _, err = sender.WithUploader(upler). + To(peer). + Album(ctx, multiMedia[0], multiMedia[1:]...) + return err + } + + // more than 10 parts, send in batches, each batch up to 10 parts + for i := 0; i < len(multiMedia); i += 10 { + end := min(i+10, len(multiMedia)) + batch := multiMedia[i:end] + _, err = sender.WithUploader(upler). + To(peer). + Album(ctx, batch[0], batch[1:]...) + if err != nil { + return fmt.Errorf("failed to send album batch: %w", err) + } + } + return nil + +}