refactor: refactor task logic for better scalability (#76)

* refactor: a big refactor. wip

* refactor: port handle file

* refactor: place all handlers

* fix: task info nil pointer

* feat: enhance task progress tracking and context management

* feat: cancel task

* feat: stream mode

* feat: silent mode

* feat: dir cmd

* refactor: remove unused old file

* feat: rule cmd

* feat: handle silent mode

* feat: batch task

* fix: batch task progress and temp file cleanup

* refactor: update file creation and cleanup methods for better resource management

* feat: add save command with silent mode handling

* feat: message link

* feat: update message prompts to include file count in storage selection

* feat: slient save links

* refactor: reduce dup code

* feat: rule type

* feat: chose dir

* feat: refactor file handling and storage rules, improve error handling and logging

* feat: rule mode

* feat: telegraph pics

* fix: tphpics nil pointer and inaccurate dirpath

* feat: silent save telegraph

* feat: add suffix to avoid file overwrite

* feat: new storage telegram

* chore: tidy go mod
This commit is contained in:
Krau
2025-06-15 23:57:49 +08:00
committed by GitHub
parent 280745cae3
commit 900823cdb9
150 changed files with 5730 additions and 3923 deletions

View File

@@ -1,6 +1,7 @@
package alist
import (
"bytes"
"context"
"encoding/json"
"fmt"
@@ -8,11 +9,13 @@ import (
"net/http"
"net/url"
"path"
"strings"
"time"
"github.com/krau/SaveAny-Bot/common"
"github.com/charmbracelet/log"
config "github.com/krau/SaveAny-Bot/config/storage"
"github.com/krau/SaveAny-Bot/types"
"github.com/krau/SaveAny-Bot/pkg/enums/key"
storenum "github.com/krau/SaveAny-Bot/pkg/enums/storage"
)
type Alist struct {
@@ -21,9 +24,10 @@ type Alist struct {
baseURL string
loginInfo *loginRequest
config config.AlistStorageConfig
logger *log.Logger
}
func (a *Alist) Init(cfg config.StorageConfig) error {
func (a *Alist) Init(ctx context.Context, cfg config.StorageConfig) error {
alistConfig, ok := cfg.(*config.AlistStorageConfig)
if !ok {
return fmt.Errorf("failed to cast alist config")
@@ -32,45 +36,46 @@ func (a *Alist) Init(cfg config.StorageConfig) error {
return err
}
a.config = *alistConfig
a.baseURL = alistConfig.URL
a.client = getHttpClient()
a.logger = log.FromContext(ctx).WithPrefix(fmt.Sprintf("alist[%s]", alistConfig.Name))
if alistConfig.Token != "" {
a.token = alistConfig.Token
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Minute)
defer cancel()
req, err := http.NewRequestWithContext(ctx, http.MethodGet, a.baseURL+"/api/me", nil)
if err != nil {
common.Log.Fatalf("Failed to create request: %v", err)
a.logger.Fatalf("Failed to create request: %v", err)
return err
}
req.Header.Set("Authorization", a.token)
resp, err := a.client.Do(req)
if err != nil {
common.Log.Fatalf("Failed to send request: %v", err)
a.logger.Fatalf("Failed to send request: %v", err)
return err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
common.Log.Fatalf("Failed to get alist user info: %s", resp.Status)
a.logger.Fatalf("Failed to get alist user info: %s", resp.Status)
return err
}
body, err := io.ReadAll(resp.Body)
if err != nil {
common.Log.Fatalf("Failed to read response body: %v", err)
a.logger.Fatalf("Failed to read response body: %v", err)
return err
}
var meResp meResponse
if err := json.Unmarshal(body, &meResp); err != nil {
common.Log.Fatalf("Failed to unmarshal me response: %v", err)
a.logger.Fatalf("Failed to unmarshal me response: %v", err)
return err
}
if meResp.Code != http.StatusOK {
common.Log.Fatalf("Failed to get alist user info: %s", meResp.Message)
a.logger.Fatalf("Failed to get alist user info: %s", meResp.Message)
return err
}
common.Log.Debugf("Logged in Alist as %s", meResp.Data.Username)
a.logger.Debugf("Logged in Alist as %s", meResp.Data.Username)
return nil
}
a.loginInfo = &loginRequest{
@@ -78,18 +83,18 @@ func (a *Alist) Init(cfg config.StorageConfig) error {
Password: alistConfig.Password,
}
if err := a.getToken(); err != nil {
common.Log.Fatalf("Failed to login to Alist: %v", err)
if err := a.getToken(ctx); err != nil {
a.logger.Fatalf("Failed to login to Alist: %v", err)
return err
}
common.Log.Debug("Logged in to Alist")
a.logger.Debug("Logged in to Alist")
go a.refreshToken(*alistConfig)
return nil
}
func (a *Alist) Type() types.StorageType {
return types.StorageTypeAlist
func (a *Alist) Type() storenum.StorageType {
return storenum.Alist
}
func (a *Alist) Name() string {
@@ -97,16 +102,23 @@ func (a *Alist) Name() string {
}
func (a *Alist) Save(ctx context.Context, reader io.Reader, storagePath string) error {
common.Log.Infof("Saving file to %s", storagePath)
a.logger.Infof("Saving file to %s", storagePath)
ext := path.Ext(storagePath)
base := strings.TrimSuffix(storagePath, ext)
candidate := storagePath
for i := 1; a.Exists(ctx, candidate); i++ {
candidate = fmt.Sprintf("%s_%d%s", base, i, ext)
}
req, err := http.NewRequestWithContext(ctx, http.MethodPut, a.baseURL+"/api/fs/put", reader)
if err != nil {
return fmt.Errorf("failed to create request: %w", err)
}
req.Header.Set("Authorization", a.token)
req.Header.Set("File-Path", url.PathEscape(storagePath))
req.Header.Set("File-Path", url.PathEscape(candidate))
req.Header.Set("Content-Type", "application/octet-stream")
if length := ctx.Value(types.ContextKeyContentLength); length != nil {
if length := ctx.Value(key.ContextKeyContentLength); length != nil {
length, ok := length.(int64)
if ok {
req.ContentLength = length
@@ -140,15 +152,66 @@ func (a *Alist) Save(ctx context.Context, reader io.Reader, storagePath string)
return nil
}
func (a *Alist) NotSupportStream() string {
return "Alist does not support chunked transfer encoding"
}
func (a *Alist) JoinStoragePath(task types.Task) string {
return path.Join(a.config.BasePath, task.StoragePath)
func (a *Alist) JoinStoragePath(p string) string {
return path.Join(a.config.BasePath, p)
}
func (a *Alist) Exists(ctx context.Context, storagePath string) bool {
// TODO: Implement it.
return false
// POST /api/fs/get
/*
body:
{
"path": "/t",
"password": "",
"page": 1,
"per_page": 0,
"refresh": false
}
*/
body := map[string]any{
"path": storagePath,
"password": "",
}
bodyBytes, err := json.Marshal(body)
if err != nil {
a.logger.Errorf("Failed to marshal request body: %v", err)
return false
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, a.baseURL+"/api/fs/get", bytes.NewBuffer(bodyBytes))
if err != nil {
a.logger.Errorf("Failed to create request: %v", err)
return false
}
req.Header.Set("Authorization", a.token)
req.Header.Set("Content-Type", "application/json")
resp, err := a.client.Do(req)
if err != nil {
a.logger.Errorf("Failed to send request: %v", err)
return false
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return false
}
data, err := io.ReadAll(resp.Body)
if err != nil {
a.logger.Errorf("Failed to read response body: %v", err)
return false
}
var fsGetResp fsGetResponse
if err := json.Unmarshal(data, &fsGetResp); err != nil {
a.logger.Errorf("Failed to unmarshal fs get response: %v", err)
return false
}
if fsGetResp.Code != http.StatusOK {
a.logger.Errorf("Failed to get file info from Alist: %d, %s", fsGetResp.Code, fsGetResp.Message)
return false
}
return true
}
// Impl StorageCannotStream interface
func (a *Alist) CannotStream() string {
return "Alist does not support chunked transfer encoding"
}

View File

@@ -2,17 +2,17 @@ package alist
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"time"
"github.com/krau/SaveAny-Bot/common"
config "github.com/krau/SaveAny-Bot/config/storage"
)
func (a *Alist) getToken() error {
func (a *Alist) getToken(ctx context.Context) error {
loginBody, err := json.Marshal(a.loginInfo)
if err != nil {
return fmt.Errorf("failed to marshal login request: %w", err)
@@ -51,15 +51,15 @@ func (a *Alist) getToken() error {
func (a *Alist) refreshToken(cfg config.AlistStorageConfig) {
tokenExp := cfg.TokenExp
if tokenExp <= 0 {
common.Log.Warn("Invalid token expiration time, using default value")
a.logger.Warn("Invalid token expiration time, using default value")
tokenExp = 3600
}
for {
time.Sleep(time.Duration(tokenExp) * time.Second)
if err := a.getToken(); err != nil {
common.Log.Errorf("Failed to refresh jwt token: %v", err)
if err := a.getToken(context.Background()); err != nil {
a.logger.Errorf("Failed to refresh jwt token: %v", err)
continue
}
common.Log.Info("Refreshed Alist jwt token")
a.logger.Info("Refreshed Alist jwt token")
}
}

View File

@@ -42,3 +42,8 @@ type putResponse struct {
} `json:"task"`
} `json:"data"`
}
type fsGetResponse struct {
Code int `json:"code"`
Message string `json:"message"`
}

22
storage/context.go Normal file
View File

@@ -0,0 +1,22 @@
package storage
import "context"
type contextKey struct{}
var storageKey = contextKey{}
func WithContext(ctx context.Context, storage Storage) context.Context {
if storage == nil {
return ctx
}
return context.WithValue(ctx, storageKey, storage)
}
func FromContext(ctx context.Context) Storage {
storage, ok := ctx.Value(storageKey).(Storage)
if !ok {
return nil
}
return storage
}

80
storage/load.go Normal file
View File

@@ -0,0 +1,80 @@
package storage
import (
"context"
"fmt"
"github.com/charmbracelet/log"
"github.com/krau/SaveAny-Bot/config"
)
var UserStorages = make(map[int64][]Storage)
// GetStorageByName returns storage by name from cache or creates new one
func getStorageByName(ctx context.Context, name string) (Storage, error) {
if name == "" {
return nil, ErrStorageNameEmpty
}
storage, ok := Storages[name]
if ok {
return storage, nil
}
cfg := config.Cfg.GetStorageByName(name)
if cfg == nil {
return nil, fmt.Errorf("未找到存储 %s", name)
}
storage, err := NewStorage(ctx, cfg)
if err != nil {
return nil, err
}
Storages[name] = storage
return storage, nil
}
// 检查 user 是否可用指定的 storage, 若不可用则返回未找到错误
func GetStorageByUserIDAndName(ctx context.Context, chatID int64, name string) (Storage, error) {
if name == "" {
return nil, ErrStorageNameEmpty
}
if !config.Cfg.HasStorage(chatID, name) {
return nil, fmt.Errorf("没有找到用户 %d 的存储 %s", chatID, name)
}
return getStorageByName(ctx, name)
}
func GetUserStorages(ctx context.Context, chatID int64) []Storage {
if chatID <= 0 {
return nil
}
if storages, ok := UserStorages[chatID]; ok {
return storages
}
var storages []Storage
for _, name := range config.Cfg.GetStorageNamesByUserID(chatID) {
storage, err := getStorageByName(ctx, name)
if err != nil {
continue
}
storages = append(storages, storage)
}
return storages
}
func LoadStorages(ctx context.Context) {
logger := log.FromContext(ctx)
logger.Info("加载存储...")
for _, storage := range config.Cfg.Storages {
_, err := getStorageByName(ctx, storage.GetName())
if err != nil {
logger.Errorf("加载存储 %s 失败: %v", storage.GetName(), err)
}
}
logger.Infof("成功加载 %d 个存储", len(Storages))
for user := range config.Cfg.GetUsersID() {
UserStorages[int64(user)] = GetUserStorages(ctx, int64(user))
}
}

View File

@@ -6,18 +6,20 @@ import (
"io"
"os"
"path/filepath"
"strings"
"github.com/charmbracelet/log"
"github.com/duke-git/lancet/v2/fileutil"
"github.com/krau/SaveAny-Bot/common"
config "github.com/krau/SaveAny-Bot/config/storage"
"github.com/krau/SaveAny-Bot/types"
storenum "github.com/krau/SaveAny-Bot/pkg/enums/storage"
)
type Local struct {
config config.LocalStorageConfig
logger *log.Logger
}
func (l *Local) Init(cfg config.StorageConfig) error {
func (l *Local) Init(ctx context.Context, cfg config.StorageConfig) error {
localConfig, ok := cfg.(*config.LocalStorageConfig)
if !ok {
return fmt.Errorf("failed to cast local config")
@@ -30,25 +32,33 @@ func (l *Local) Init(cfg config.StorageConfig) error {
if err != nil {
return fmt.Errorf("failed to create local storage directory: %w", err)
}
l.logger = log.FromContext(ctx).WithPrefix(fmt.Sprintf("local[%s]", l.config.Name))
return nil
}
func (l *Local) Type() types.StorageType {
return types.StorageTypeLocal
func (l *Local) Type() storenum.StorageType {
return storenum.Local
}
func (l *Local) Name() string {
return l.config.Name
}
func (l *Local) JoinStoragePath(task types.Task) string {
return filepath.Join(l.config.BasePath, task.StoragePath)
func (l *Local) JoinStoragePath(path string) string {
return filepath.Join(l.config.BasePath, path)
}
func (l *Local) Save(ctx context.Context, r io.Reader, storagePath string) error {
common.Log.Infof("Saving file to %s", storagePath)
l.logger.Infof("Saving file to %s", storagePath)
absPath, err := filepath.Abs(storagePath)
ext := filepath.Ext(storagePath)
base := strings.TrimSuffix(storagePath, ext)
candidate := storagePath
for i := 1; l.Exists(ctx, candidate); i++ {
candidate = fmt.Sprintf("%s_%d%s", base, i, ext)
}
absPath, err := filepath.Abs(candidate)
if err != nil {
return err
}

View File

@@ -5,10 +5,11 @@ import (
"fmt"
"io"
"path"
"strings"
"github.com/krau/SaveAny-Bot/common"
"github.com/charmbracelet/log"
config "github.com/krau/SaveAny-Bot/config/storage"
"github.com/krau/SaveAny-Bot/types"
storenum "github.com/krau/SaveAny-Bot/pkg/enums/storage"
"github.com/minio/minio-go/v7"
"github.com/minio/minio-go/v7/pkg/credentials"
)
@@ -16,9 +17,10 @@ import (
type Minio struct {
config config.MinioStorageConfig
client *minio.Client
logger *log.Logger
}
func (m *Minio) Init(cfg config.StorageConfig) error {
func (m *Minio) Init(ctx context.Context, cfg config.StorageConfig) error {
minioConfig, ok := cfg.(*config.MinioStorageConfig)
if !ok {
return fmt.Errorf("failed to cast minio config")
@@ -27,6 +29,7 @@ func (m *Minio) Init(cfg config.StorageConfig) error {
return err
}
m.config = *minioConfig
m.logger = log.FromContext(ctx).WithPrefix(fmt.Sprintf("minio[%s]", m.config.Name))
client, err := minio.New(m.config.Endpoint, &minio.Options{
Creds: credentials.NewStaticV4(m.config.AccessKeyID, m.config.SecretAccessKey, ""),
@@ -36,7 +39,7 @@ func (m *Minio) Init(cfg config.StorageConfig) error {
return fmt.Errorf("failed to create minio client: %w", err)
}
exists, err := client.BucketExists(context.Background(), m.config.BucketName)
exists, err := client.BucketExists(ctx, m.config.BucketName)
if err != nil {
return fmt.Errorf("failed to check bucket existence: %w", err)
}
@@ -48,22 +51,29 @@ func (m *Minio) Init(cfg config.StorageConfig) error {
return nil
}
func (m *Minio) Type() types.StorageType {
return types.StorageTypeMinio
func (m *Minio) Type() storenum.StorageType {
return storenum.Minio
}
func (m *Minio) Name() string {
return m.config.Name
}
func (m *Minio) JoinStoragePath(task types.Task) string {
return path.Join(m.config.BasePath, task.StoragePath)
func (m *Minio) JoinStoragePath(p string) string {
return path.Join(m.config.BasePath, p)
}
func (m *Minio) Save(ctx context.Context, r io.Reader, storagePath string) error {
common.Log.Infof("Saving file from reader to %s", storagePath)
m.logger.Infof("Saving file from reader to %s", storagePath)
_, err := m.client.PutObject(ctx, m.config.BucketName, storagePath, r, -1, minio.PutObjectOptions{})
ext := path.Ext(storagePath)
base := strings.TrimSuffix(storagePath, ext)
candidate := storagePath
for i := 1; m.Exists(ctx, candidate); i++ {
candidate = fmt.Sprintf("%s_%d%s", base, i, ext)
}
_, err := m.client.PutObject(ctx, m.config.BucketName, candidate, r, -1, minio.PutObjectOptions{})
if err != nil {
return fmt.Errorf("failed to upload file to minio: %w", err)
}
@@ -72,15 +82,7 @@ func (m *Minio) Save(ctx context.Context, r io.Reader, storagePath string) error
}
func (m *Minio) Exists(ctx context.Context, storagePath string) bool {
common.Log.Debugf("Checking if file exists at %s", storagePath)
// TODO: test it.
m.logger.Debugf("Checking if file exists at %s", storagePath)
_, err := m.client.StatObject(ctx, m.config.BucketName, storagePath, minio.StatObjectOptions{})
if err != nil {
if minio.ToErrorResponse(err).Code == "NoSuchKey" {
return false // File does not exist
}
return false
}
return true
return err == nil
}

View File

@@ -5,121 +5,51 @@ import (
"fmt"
"io"
"github.com/krau/SaveAny-Bot/common"
"github.com/krau/SaveAny-Bot/config"
sc "github.com/krau/SaveAny-Bot/config/storage"
storcfg "github.com/krau/SaveAny-Bot/config/storage"
storenum "github.com/krau/SaveAny-Bot/pkg/enums/storage"
"github.com/krau/SaveAny-Bot/storage/alist"
"github.com/krau/SaveAny-Bot/storage/local"
"github.com/krau/SaveAny-Bot/storage/minio"
"github.com/krau/SaveAny-Bot/storage/telegram"
"github.com/krau/SaveAny-Bot/storage/webdav"
"github.com/krau/SaveAny-Bot/types"
)
type Storage interface {
Init(cfg sc.StorageConfig) error
Type() types.StorageType
Init(ctx context.Context, cfg storcfg.StorageConfig) error
Type() storenum.StorageType
Name() string
JoinStoragePath(task types.Task) string
JoinStoragePath(p string) string
Save(ctx context.Context, reader io.Reader, storagePath string) error
Exists(ctx context.Context, storagePath string) bool
}
type StorageNotSupportStream interface {
type StorageCannotStream interface {
Storage
NotSupportStream() string
CannotStream() string
}
var Storages = make(map[string]Storage)
var UserStorages = make(map[int64][]Storage)
// GetStorageByName returns storage by name from cache or creates new one
func GetStorageByName(name string) (Storage, error) {
if name == "" {
return nil, ErrStorageNameEmpty
}
storage, ok := Storages[name]
if ok {
return storage, nil
}
cfg := config.Cfg.GetStorageByName(name)
if cfg == nil {
return nil, fmt.Errorf("未找到存储 %s", name)
}
storage, err := NewStorage(cfg)
if err != nil {
return nil, err
}
Storages[name] = storage
return storage, nil
}
// 检查 user 是否可用指定的 storage, 若不可用则返回未找到错误
func GetStorageByUserIDAndName(chatID int64, name string) (Storage, error) {
if name == "" {
return nil, ErrStorageNameEmpty
}
if !config.Cfg.HasStorage(chatID, name) {
return nil, fmt.Errorf("没有找到用户 %d 的存储 %s", chatID, name)
}
return GetStorageByName(name)
}
func GetUserStorages(chatID int64) []Storage {
if chatID <= 0 {
return nil
}
if storages, ok := UserStorages[chatID]; ok {
return storages
}
var storages []Storage
for _, name := range config.Cfg.GetStorageNamesByUserID(chatID) {
storage, err := GetStorageByName(name)
if err != nil {
continue
}
storages = append(storages, storage)
}
return storages
}
type StorageConstructor func() Storage
var storageConstructors = map[string]StorageConstructor{
string(types.StorageTypeAlist): func() Storage { return new(alist.Alist) },
string(types.StorageTypeLocal): func() Storage { return new(local.Local) },
string(types.StorageTypeWebdav): func() Storage { return new(webdav.Webdav) },
string(types.StorageTypeMinio): func() Storage { return new(minio.Minio) },
var storageConstructors = map[storenum.StorageType]StorageConstructor{
storenum.Alist: func() Storage { return new(alist.Alist) },
storenum.Local: func() Storage { return new(local.Local) },
storenum.Webdav: func() Storage { return new(webdav.Webdav) },
storenum.Minio: func() Storage { return new(minio.Minio) },
storenum.Telegram: func() Storage { return new(telegram.Telegram) },
}
func NewStorage(cfg sc.StorageConfig) (Storage, error) {
constructor, ok := storageConstructors[string(cfg.GetType())]
func NewStorage(ctx context.Context, cfg storcfg.StorageConfig) (Storage, error) {
constructor, ok := storageConstructors[cfg.GetType()]
if !ok {
return nil, fmt.Errorf("不支持的存储类型: %s", cfg.GetType())
}
storage := constructor()
if err := storage.Init(cfg); err != nil {
if err := storage.Init(ctx, cfg); err != nil {
return nil, fmt.Errorf("初始化 %s 存储失败: %w", cfg.GetName(), err)
}
return storage, nil
}
func LoadStorages() {
common.Log.Info("加载存储...")
for _, storage := range config.Cfg.Storages {
_, err := GetStorageByName(storage.GetName())
if err != nil {
common.Log.Errorf("加载存储 %s 失败: %v", storage.GetName(), err)
}
}
common.Log.Infof("成功加载 %d 个存储", len(Storages))
for user := range config.Cfg.GetUsersID() {
UserStorages[int64(user)] = GetUserStorages(int64(user))
}
}

View File

@@ -0,0 +1,111 @@
package telegram
import (
"context"
"fmt"
"io"
"path"
"time"
"github.com/gabriel-vasile/mimetype"
"github.com/gotd/td/telegram/message"
"github.com/gotd/td/telegram/message/styling"
"github.com/gotd/td/telegram/uploader"
"github.com/krau/SaveAny-Bot/common/utils/tgutil"
"github.com/krau/SaveAny-Bot/config"
storconfig "github.com/krau/SaveAny-Bot/config/storage"
"github.com/krau/SaveAny-Bot/pkg/consts/tglimit"
storenum "github.com/krau/SaveAny-Bot/pkg/enums/storage"
"github.com/rs/xid"
"golang.org/x/time/rate"
)
type Telegram struct {
config storconfig.TelegramStorageConfig
limiter *rate.Limiter
}
func (t *Telegram) Init(ctx context.Context, cfg storconfig.StorageConfig) error {
telegramConfig, ok := cfg.(*storconfig.TelegramStorageConfig)
if !ok {
return fmt.Errorf("failed to cast telegram config")
}
if err := telegramConfig.Validate(); err != nil {
return err
}
t.config = *telegramConfig
if t.config.RateLimit <= 0 || t.config.RateBurst <= 0 {
t.config.RateLimit = 2
t.config.RateBurst = 1
}
t.limiter = rate.NewLimiter(rate.Every(time.Duration(t.config.RateLimit)*time.Second), t.config.RateBurst)
return nil
}
func (t *Telegram) Type() storenum.StorageType {
return storenum.Telegram
}
func (t *Telegram) Name() string {
return t.config.Name
}
func (t *Telegram) JoinStoragePath(p string) string {
return path.Clean(p)
}
func (t *Telegram) Exists(ctx context.Context, storagePath string) bool {
return false
}
func (t *Telegram) Save(ctx context.Context, r io.Reader, storagePath string) error {
if err := t.limiter.Wait(ctx); err != nil {
return fmt.Errorf("rate limit failed: %w", err)
}
rs, ok := r.(io.ReadSeeker)
if !ok || rs == nil {
return fmt.Errorf("reader must implement io.ReadSeeker")
}
tctx := tgutil.ExtFromContext(ctx)
if tctx == nil {
return fmt.Errorf("failed to get telegram context")
}
peer := tctx.PeerStorage.GetInputPeerById(t.config.ChatID)
if peer == nil {
return fmt.Errorf("failed to get input peer for chat ID %d", t.config.ChatID)
}
mtype, err := mimetype.DetectReader(rs)
if err != nil {
return fmt.Errorf("failed to detect mimetype: %w", err)
}
filename := path.Base(storagePath)
if filename == "" {
filename = xid.New().String() + mtype.Extension()
}
if _, err := rs.Seek(0, io.SeekStart); err != nil {
return fmt.Errorf("failed to seek reader: %w", err)
}
upler := uploader.NewUploader(tctx.Raw).
WithPartSize(tglimit.MaxUploadPartSize).
WithThreads(config.Cfg.Threads)
file, err := upler.FromReader(ctx, filename, rs)
if err != nil {
return fmt.Errorf("failed to upload file to telegram: %w", err)
}
caption := styling.Plain(filename)
docb := message.UploadedDocument(file, caption).
Filename(filename).
ForceFile(true).
MIME(mtype.String())
var mediaOpt message.MediaOption = docb
sender := tctx.Sender
_, err = sender.WithUploader(upler).To(peer).Media(ctx, mediaOpt)
return err
}
func (t *Telegram) CannotStream() string {
return "Telegram storage must use a ReaderSeeker"
}

View File

@@ -9,7 +9,7 @@ import (
"path"
"strings"
"github.com/krau/SaveAny-Bot/types"
"github.com/krau/SaveAny-Bot/pkg/enums/key"
)
type Client struct {
@@ -54,7 +54,7 @@ func (c *Client) doRequest(ctx context.Context, method WebdavMethod, url string,
req.Header.Set("Depth", "1")
}
if method == WebdavMethodPut && ctx != nil {
if length := ctx.Value(types.ContextKeyContentLength); length != nil {
if length := ctx.Value(key.ContextKeyContentLength); length != nil {
if l, ok := length.(int64); ok {
req.ContentLength = l
}

View File

@@ -6,19 +6,21 @@ import (
"io"
"net/http"
"path"
"strings"
"time"
"github.com/krau/SaveAny-Bot/common"
"github.com/charmbracelet/log"
config "github.com/krau/SaveAny-Bot/config/storage"
"github.com/krau/SaveAny-Bot/types"
storenum "github.com/krau/SaveAny-Bot/pkg/enums/storage"
)
type Webdav struct {
config config.WebdavStorageConfig
client *Client
logger *log.Logger
}
func (w *Webdav) Init(cfg config.StorageConfig) error {
func (w *Webdav) Init(ctx context.Context, cfg config.StorageConfig) error {
webdavConfig, ok := cfg.(*config.WebdavStorageConfig)
if !ok {
return fmt.Errorf("failed to cast webdav config")
@@ -27,42 +29,51 @@ func (w *Webdav) Init(cfg config.StorageConfig) error {
return err
}
w.config = *webdavConfig
w.logger = log.FromContext(ctx).WithPrefix(fmt.Sprintf("webdav[%s]", w.config.Name))
w.client = NewClient(w.config.URL, w.config.Username, w.config.Password, &http.Client{
Timeout: time.Hour * 12,
})
return nil
}
func (w *Webdav) Type() types.StorageType {
return types.StorageTypeWebdav
func (w *Webdav) Type() storenum.StorageType {
return storenum.Webdav
}
func (w *Webdav) Name() string {
return w.config.Name
}
func (w *Webdav) JoinStoragePath(task types.Task) string {
return path.Join(w.config.BasePath, task.StoragePath)
func (w *Webdav) JoinStoragePath(p string) string {
return path.Join(w.config.BasePath, p)
}
func (w *Webdav) Save(ctx context.Context, r io.Reader, storagePath string) error {
common.Log.Infof("Saving file to %s", storagePath)
if err := w.client.MkDir(ctx, path.Dir(storagePath)); err != nil {
common.Log.Errorf("Failed to create directory %s: %v", path.Dir(storagePath), err)
w.logger.Infof("Saving file to %s", storagePath)
ext := path.Ext(storagePath)
base := strings.TrimSuffix(storagePath, ext)
candidate := storagePath
for i := 1; w.Exists(ctx, candidate); i++ {
candidate = fmt.Sprintf("%s_%d%s", base, i, ext)
}
if err := w.client.MkDir(ctx, path.Dir(candidate)); err != nil {
w.logger.Errorf("Failed to create directory %s: %v", path.Dir(candidate), err)
return ErrFailedToCreateDirectory
}
if err := w.client.WriteFile(ctx, storagePath, r); err != nil {
common.Log.Errorf("Failed to write file %s: %v", storagePath, err)
if err := w.client.WriteFile(ctx, candidate, r); err != nil {
w.logger.Errorf("Failed to write file %s: %v", candidate, err)
return ErrFailedToWriteFile
}
return nil
}
func (w *Webdav) Exists(ctx context.Context, storagePath string) bool {
common.Log.Debugf("Checking if file exists at %s", storagePath)
w.logger.Debugf("Checking if file exists at %s", storagePath)
exists, err := w.client.Exists(ctx, storagePath)
if err != nil {
common.Log.Errorf("Failed to check if file exists at %s: %v", storagePath, err)
w.logger.Errorf("Failed to check if file exists at %s: %v", storagePath, err)
return false
}
return exists