refactor: refactor task logic for better scalability (#76)
* refactor: a big refactor. wip * refactor: port handle file * refactor: place all handlers * fix: task info nil pointer * feat: enhance task progress tracking and context management * feat: cancel task * feat: stream mode * feat: silent mode * feat: dir cmd * refactor: remove unused old file * feat: rule cmd * feat: handle silent mode * feat: batch task * fix: batch task progress and temp file cleanup * refactor: update file creation and cleanup methods for better resource management * feat: add save command with silent mode handling * feat: message link * feat: update message prompts to include file count in storage selection * feat: slient save links * refactor: reduce dup code * feat: rule type * feat: chose dir * feat: refactor file handling and storage rules, improve error handling and logging * feat: rule mode * feat: telegraph pics * fix: tphpics nil pointer and inaccurate dirpath * feat: silent save telegraph * feat: add suffix to avoid file overwrite * feat: new storage telegram * chore: tidy go mod
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
package alist
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
@@ -8,11 +9,13 @@ import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
"path"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/krau/SaveAny-Bot/common"
|
||||
"github.com/charmbracelet/log"
|
||||
config "github.com/krau/SaveAny-Bot/config/storage"
|
||||
"github.com/krau/SaveAny-Bot/types"
|
||||
"github.com/krau/SaveAny-Bot/pkg/enums/key"
|
||||
storenum "github.com/krau/SaveAny-Bot/pkg/enums/storage"
|
||||
)
|
||||
|
||||
type Alist struct {
|
||||
@@ -21,9 +24,10 @@ type Alist struct {
|
||||
baseURL string
|
||||
loginInfo *loginRequest
|
||||
config config.AlistStorageConfig
|
||||
logger *log.Logger
|
||||
}
|
||||
|
||||
func (a *Alist) Init(cfg config.StorageConfig) error {
|
||||
func (a *Alist) Init(ctx context.Context, cfg config.StorageConfig) error {
|
||||
alistConfig, ok := cfg.(*config.AlistStorageConfig)
|
||||
if !ok {
|
||||
return fmt.Errorf("failed to cast alist config")
|
||||
@@ -32,45 +36,46 @@ func (a *Alist) Init(cfg config.StorageConfig) error {
|
||||
return err
|
||||
}
|
||||
a.config = *alistConfig
|
||||
|
||||
a.baseURL = alistConfig.URL
|
||||
a.client = getHttpClient()
|
||||
a.logger = log.FromContext(ctx).WithPrefix(fmt.Sprintf("alist[%s]", alistConfig.Name))
|
||||
|
||||
if alistConfig.Token != "" {
|
||||
a.token = alistConfig.Token
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Minute)
|
||||
defer cancel()
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, a.baseURL+"/api/me", nil)
|
||||
if err != nil {
|
||||
common.Log.Fatalf("Failed to create request: %v", err)
|
||||
a.logger.Fatalf("Failed to create request: %v", err)
|
||||
return err
|
||||
}
|
||||
req.Header.Set("Authorization", a.token)
|
||||
|
||||
resp, err := a.client.Do(req)
|
||||
if err != nil {
|
||||
common.Log.Fatalf("Failed to send request: %v", err)
|
||||
a.logger.Fatalf("Failed to send request: %v", err)
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
common.Log.Fatalf("Failed to get alist user info: %s", resp.Status)
|
||||
a.logger.Fatalf("Failed to get alist user info: %s", resp.Status)
|
||||
return err
|
||||
}
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
common.Log.Fatalf("Failed to read response body: %v", err)
|
||||
a.logger.Fatalf("Failed to read response body: %v", err)
|
||||
return err
|
||||
}
|
||||
var meResp meResponse
|
||||
if err := json.Unmarshal(body, &meResp); err != nil {
|
||||
common.Log.Fatalf("Failed to unmarshal me response: %v", err)
|
||||
a.logger.Fatalf("Failed to unmarshal me response: %v", err)
|
||||
return err
|
||||
}
|
||||
if meResp.Code != http.StatusOK {
|
||||
common.Log.Fatalf("Failed to get alist user info: %s", meResp.Message)
|
||||
a.logger.Fatalf("Failed to get alist user info: %s", meResp.Message)
|
||||
return err
|
||||
}
|
||||
common.Log.Debugf("Logged in Alist as %s", meResp.Data.Username)
|
||||
a.logger.Debugf("Logged in Alist as %s", meResp.Data.Username)
|
||||
return nil
|
||||
}
|
||||
a.loginInfo = &loginRequest{
|
||||
@@ -78,18 +83,18 @@ func (a *Alist) Init(cfg config.StorageConfig) error {
|
||||
Password: alistConfig.Password,
|
||||
}
|
||||
|
||||
if err := a.getToken(); err != nil {
|
||||
common.Log.Fatalf("Failed to login to Alist: %v", err)
|
||||
if err := a.getToken(ctx); err != nil {
|
||||
a.logger.Fatalf("Failed to login to Alist: %v", err)
|
||||
return err
|
||||
}
|
||||
common.Log.Debug("Logged in to Alist")
|
||||
a.logger.Debug("Logged in to Alist")
|
||||
|
||||
go a.refreshToken(*alistConfig)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Alist) Type() types.StorageType {
|
||||
return types.StorageTypeAlist
|
||||
func (a *Alist) Type() storenum.StorageType {
|
||||
return storenum.Alist
|
||||
}
|
||||
|
||||
func (a *Alist) Name() string {
|
||||
@@ -97,16 +102,23 @@ func (a *Alist) Name() string {
|
||||
}
|
||||
|
||||
func (a *Alist) Save(ctx context.Context, reader io.Reader, storagePath string) error {
|
||||
common.Log.Infof("Saving file to %s", storagePath)
|
||||
a.logger.Infof("Saving file to %s", storagePath)
|
||||
|
||||
ext := path.Ext(storagePath)
|
||||
base := strings.TrimSuffix(storagePath, ext)
|
||||
candidate := storagePath
|
||||
for i := 1; a.Exists(ctx, candidate); i++ {
|
||||
candidate = fmt.Sprintf("%s_%d%s", base, i, ext)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPut, a.baseURL+"/api/fs/put", reader)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
req.Header.Set("Authorization", a.token)
|
||||
req.Header.Set("File-Path", url.PathEscape(storagePath))
|
||||
req.Header.Set("File-Path", url.PathEscape(candidate))
|
||||
req.Header.Set("Content-Type", "application/octet-stream")
|
||||
if length := ctx.Value(types.ContextKeyContentLength); length != nil {
|
||||
if length := ctx.Value(key.ContextKeyContentLength); length != nil {
|
||||
length, ok := length.(int64)
|
||||
if ok {
|
||||
req.ContentLength = length
|
||||
@@ -140,15 +152,66 @@ func (a *Alist) Save(ctx context.Context, reader io.Reader, storagePath string)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Alist) NotSupportStream() string {
|
||||
return "Alist does not support chunked transfer encoding"
|
||||
}
|
||||
|
||||
func (a *Alist) JoinStoragePath(task types.Task) string {
|
||||
return path.Join(a.config.BasePath, task.StoragePath)
|
||||
func (a *Alist) JoinStoragePath(p string) string {
|
||||
return path.Join(a.config.BasePath, p)
|
||||
}
|
||||
|
||||
func (a *Alist) Exists(ctx context.Context, storagePath string) bool {
|
||||
// TODO: Implement it.
|
||||
return false
|
||||
// POST /api/fs/get
|
||||
/*
|
||||
body:
|
||||
{
|
||||
"path": "/t",
|
||||
"password": "",
|
||||
"page": 1,
|
||||
"per_page": 0,
|
||||
"refresh": false
|
||||
}
|
||||
*/
|
||||
body := map[string]any{
|
||||
"path": storagePath,
|
||||
"password": "",
|
||||
}
|
||||
bodyBytes, err := json.Marshal(body)
|
||||
if err != nil {
|
||||
a.logger.Errorf("Failed to marshal request body: %v", err)
|
||||
return false
|
||||
}
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, a.baseURL+"/api/fs/get", bytes.NewBuffer(bodyBytes))
|
||||
if err != nil {
|
||||
a.logger.Errorf("Failed to create request: %v", err)
|
||||
return false
|
||||
}
|
||||
req.Header.Set("Authorization", a.token)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
resp, err := a.client.Do(req)
|
||||
if err != nil {
|
||||
a.logger.Errorf("Failed to send request: %v", err)
|
||||
return false
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return false
|
||||
}
|
||||
data, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
a.logger.Errorf("Failed to read response body: %v", err)
|
||||
return false
|
||||
}
|
||||
var fsGetResp fsGetResponse
|
||||
if err := json.Unmarshal(data, &fsGetResp); err != nil {
|
||||
a.logger.Errorf("Failed to unmarshal fs get response: %v", err)
|
||||
return false
|
||||
}
|
||||
if fsGetResp.Code != http.StatusOK {
|
||||
a.logger.Errorf("Failed to get file info from Alist: %d, %s", fsGetResp.Code, fsGetResp.Message)
|
||||
return false
|
||||
}
|
||||
return true
|
||||
|
||||
}
|
||||
|
||||
// Impl StorageCannotStream interface
|
||||
func (a *Alist) CannotStream() string {
|
||||
return "Alist does not support chunked transfer encoding"
|
||||
}
|
||||
|
||||
@@ -2,17 +2,17 @@ package alist
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/krau/SaveAny-Bot/common"
|
||||
config "github.com/krau/SaveAny-Bot/config/storage"
|
||||
)
|
||||
|
||||
func (a *Alist) getToken() error {
|
||||
func (a *Alist) getToken(ctx context.Context) error {
|
||||
loginBody, err := json.Marshal(a.loginInfo)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal login request: %w", err)
|
||||
@@ -51,15 +51,15 @@ func (a *Alist) getToken() error {
|
||||
func (a *Alist) refreshToken(cfg config.AlistStorageConfig) {
|
||||
tokenExp := cfg.TokenExp
|
||||
if tokenExp <= 0 {
|
||||
common.Log.Warn("Invalid token expiration time, using default value")
|
||||
a.logger.Warn("Invalid token expiration time, using default value")
|
||||
tokenExp = 3600
|
||||
}
|
||||
for {
|
||||
time.Sleep(time.Duration(tokenExp) * time.Second)
|
||||
if err := a.getToken(); err != nil {
|
||||
common.Log.Errorf("Failed to refresh jwt token: %v", err)
|
||||
if err := a.getToken(context.Background()); err != nil {
|
||||
a.logger.Errorf("Failed to refresh jwt token: %v", err)
|
||||
continue
|
||||
}
|
||||
common.Log.Info("Refreshed Alist jwt token")
|
||||
a.logger.Info("Refreshed Alist jwt token")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -42,3 +42,8 @@ type putResponse struct {
|
||||
} `json:"task"`
|
||||
} `json:"data"`
|
||||
}
|
||||
|
||||
type fsGetResponse struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
22
storage/context.go
Normal file
22
storage/context.go
Normal file
@@ -0,0 +1,22 @@
|
||||
package storage
|
||||
|
||||
import "context"
|
||||
|
||||
type contextKey struct{}
|
||||
|
||||
var storageKey = contextKey{}
|
||||
|
||||
func WithContext(ctx context.Context, storage Storage) context.Context {
|
||||
if storage == nil {
|
||||
return ctx
|
||||
}
|
||||
return context.WithValue(ctx, storageKey, storage)
|
||||
}
|
||||
|
||||
func FromContext(ctx context.Context) Storage {
|
||||
storage, ok := ctx.Value(storageKey).(Storage)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
return storage
|
||||
}
|
||||
80
storage/load.go
Normal file
80
storage/load.go
Normal file
@@ -0,0 +1,80 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/charmbracelet/log"
|
||||
"github.com/krau/SaveAny-Bot/config"
|
||||
)
|
||||
|
||||
var UserStorages = make(map[int64][]Storage)
|
||||
|
||||
// GetStorageByName returns storage by name from cache or creates new one
|
||||
func getStorageByName(ctx context.Context, name string) (Storage, error) {
|
||||
if name == "" {
|
||||
return nil, ErrStorageNameEmpty
|
||||
}
|
||||
|
||||
storage, ok := Storages[name]
|
||||
if ok {
|
||||
return storage, nil
|
||||
}
|
||||
cfg := config.Cfg.GetStorageByName(name)
|
||||
if cfg == nil {
|
||||
return nil, fmt.Errorf("未找到存储 %s", name)
|
||||
}
|
||||
|
||||
storage, err := NewStorage(ctx, cfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
Storages[name] = storage
|
||||
return storage, nil
|
||||
}
|
||||
|
||||
// 检查 user 是否可用指定的 storage, 若不可用则返回未找到错误
|
||||
func GetStorageByUserIDAndName(ctx context.Context, chatID int64, name string) (Storage, error) {
|
||||
if name == "" {
|
||||
return nil, ErrStorageNameEmpty
|
||||
}
|
||||
|
||||
if !config.Cfg.HasStorage(chatID, name) {
|
||||
return nil, fmt.Errorf("没有找到用户 %d 的存储 %s", chatID, name)
|
||||
}
|
||||
|
||||
return getStorageByName(ctx, name)
|
||||
}
|
||||
|
||||
func GetUserStorages(ctx context.Context, chatID int64) []Storage {
|
||||
if chatID <= 0 {
|
||||
return nil
|
||||
}
|
||||
if storages, ok := UserStorages[chatID]; ok {
|
||||
return storages
|
||||
}
|
||||
var storages []Storage
|
||||
for _, name := range config.Cfg.GetStorageNamesByUserID(chatID) {
|
||||
storage, err := getStorageByName(ctx, name)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
storages = append(storages, storage)
|
||||
}
|
||||
return storages
|
||||
}
|
||||
|
||||
func LoadStorages(ctx context.Context) {
|
||||
logger := log.FromContext(ctx)
|
||||
logger.Info("加载存储...")
|
||||
for _, storage := range config.Cfg.Storages {
|
||||
_, err := getStorageByName(ctx, storage.GetName())
|
||||
if err != nil {
|
||||
logger.Errorf("加载存储 %s 失败: %v", storage.GetName(), err)
|
||||
}
|
||||
}
|
||||
logger.Infof("成功加载 %d 个存储", len(Storages))
|
||||
for user := range config.Cfg.GetUsersID() {
|
||||
UserStorages[int64(user)] = GetUserStorages(ctx, int64(user))
|
||||
}
|
||||
}
|
||||
@@ -6,18 +6,20 @@ import (
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/charmbracelet/log"
|
||||
"github.com/duke-git/lancet/v2/fileutil"
|
||||
"github.com/krau/SaveAny-Bot/common"
|
||||
config "github.com/krau/SaveAny-Bot/config/storage"
|
||||
"github.com/krau/SaveAny-Bot/types"
|
||||
storenum "github.com/krau/SaveAny-Bot/pkg/enums/storage"
|
||||
)
|
||||
|
||||
type Local struct {
|
||||
config config.LocalStorageConfig
|
||||
logger *log.Logger
|
||||
}
|
||||
|
||||
func (l *Local) Init(cfg config.StorageConfig) error {
|
||||
func (l *Local) Init(ctx context.Context, cfg config.StorageConfig) error {
|
||||
localConfig, ok := cfg.(*config.LocalStorageConfig)
|
||||
if !ok {
|
||||
return fmt.Errorf("failed to cast local config")
|
||||
@@ -30,25 +32,33 @@ func (l *Local) Init(cfg config.StorageConfig) error {
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create local storage directory: %w", err)
|
||||
}
|
||||
l.logger = log.FromContext(ctx).WithPrefix(fmt.Sprintf("local[%s]", l.config.Name))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *Local) Type() types.StorageType {
|
||||
return types.StorageTypeLocal
|
||||
func (l *Local) Type() storenum.StorageType {
|
||||
return storenum.Local
|
||||
}
|
||||
|
||||
func (l *Local) Name() string {
|
||||
return l.config.Name
|
||||
}
|
||||
|
||||
func (l *Local) JoinStoragePath(task types.Task) string {
|
||||
return filepath.Join(l.config.BasePath, task.StoragePath)
|
||||
func (l *Local) JoinStoragePath(path string) string {
|
||||
return filepath.Join(l.config.BasePath, path)
|
||||
}
|
||||
|
||||
func (l *Local) Save(ctx context.Context, r io.Reader, storagePath string) error {
|
||||
common.Log.Infof("Saving file to %s", storagePath)
|
||||
l.logger.Infof("Saving file to %s", storagePath)
|
||||
|
||||
absPath, err := filepath.Abs(storagePath)
|
||||
ext := filepath.Ext(storagePath)
|
||||
base := strings.TrimSuffix(storagePath, ext)
|
||||
candidate := storagePath
|
||||
for i := 1; l.Exists(ctx, candidate); i++ {
|
||||
candidate = fmt.Sprintf("%s_%d%s", base, i, ext)
|
||||
}
|
||||
|
||||
absPath, err := filepath.Abs(candidate)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -5,10 +5,11 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"path"
|
||||
"strings"
|
||||
|
||||
"github.com/krau/SaveAny-Bot/common"
|
||||
"github.com/charmbracelet/log"
|
||||
config "github.com/krau/SaveAny-Bot/config/storage"
|
||||
"github.com/krau/SaveAny-Bot/types"
|
||||
storenum "github.com/krau/SaveAny-Bot/pkg/enums/storage"
|
||||
"github.com/minio/minio-go/v7"
|
||||
"github.com/minio/minio-go/v7/pkg/credentials"
|
||||
)
|
||||
@@ -16,9 +17,10 @@ import (
|
||||
type Minio struct {
|
||||
config config.MinioStorageConfig
|
||||
client *minio.Client
|
||||
logger *log.Logger
|
||||
}
|
||||
|
||||
func (m *Minio) Init(cfg config.StorageConfig) error {
|
||||
func (m *Minio) Init(ctx context.Context, cfg config.StorageConfig) error {
|
||||
minioConfig, ok := cfg.(*config.MinioStorageConfig)
|
||||
if !ok {
|
||||
return fmt.Errorf("failed to cast minio config")
|
||||
@@ -27,6 +29,7 @@ func (m *Minio) Init(cfg config.StorageConfig) error {
|
||||
return err
|
||||
}
|
||||
m.config = *minioConfig
|
||||
m.logger = log.FromContext(ctx).WithPrefix(fmt.Sprintf("minio[%s]", m.config.Name))
|
||||
|
||||
client, err := minio.New(m.config.Endpoint, &minio.Options{
|
||||
Creds: credentials.NewStaticV4(m.config.AccessKeyID, m.config.SecretAccessKey, ""),
|
||||
@@ -36,7 +39,7 @@ func (m *Minio) Init(cfg config.StorageConfig) error {
|
||||
return fmt.Errorf("failed to create minio client: %w", err)
|
||||
}
|
||||
|
||||
exists, err := client.BucketExists(context.Background(), m.config.BucketName)
|
||||
exists, err := client.BucketExists(ctx, m.config.BucketName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to check bucket existence: %w", err)
|
||||
}
|
||||
@@ -48,22 +51,29 @@ func (m *Minio) Init(cfg config.StorageConfig) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Minio) Type() types.StorageType {
|
||||
return types.StorageTypeMinio
|
||||
func (m *Minio) Type() storenum.StorageType {
|
||||
return storenum.Minio
|
||||
}
|
||||
|
||||
func (m *Minio) Name() string {
|
||||
return m.config.Name
|
||||
}
|
||||
|
||||
func (m *Minio) JoinStoragePath(task types.Task) string {
|
||||
return path.Join(m.config.BasePath, task.StoragePath)
|
||||
func (m *Minio) JoinStoragePath(p string) string {
|
||||
return path.Join(m.config.BasePath, p)
|
||||
}
|
||||
|
||||
func (m *Minio) Save(ctx context.Context, r io.Reader, storagePath string) error {
|
||||
common.Log.Infof("Saving file from reader to %s", storagePath)
|
||||
m.logger.Infof("Saving file from reader to %s", storagePath)
|
||||
|
||||
_, err := m.client.PutObject(ctx, m.config.BucketName, storagePath, r, -1, minio.PutObjectOptions{})
|
||||
ext := path.Ext(storagePath)
|
||||
base := strings.TrimSuffix(storagePath, ext)
|
||||
candidate := storagePath
|
||||
for i := 1; m.Exists(ctx, candidate); i++ {
|
||||
candidate = fmt.Sprintf("%s_%d%s", base, i, ext)
|
||||
}
|
||||
|
||||
_, err := m.client.PutObject(ctx, m.config.BucketName, candidate, r, -1, minio.PutObjectOptions{})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to upload file to minio: %w", err)
|
||||
}
|
||||
@@ -72,15 +82,7 @@ func (m *Minio) Save(ctx context.Context, r io.Reader, storagePath string) error
|
||||
}
|
||||
|
||||
func (m *Minio) Exists(ctx context.Context, storagePath string) bool {
|
||||
common.Log.Debugf("Checking if file exists at %s", storagePath)
|
||||
// TODO: test it.
|
||||
m.logger.Debugf("Checking if file exists at %s", storagePath)
|
||||
_, err := m.client.StatObject(ctx, m.config.BucketName, storagePath, minio.StatObjectOptions{})
|
||||
if err != nil {
|
||||
if minio.ToErrorResponse(err).Code == "NoSuchKey" {
|
||||
return false // File does not exist
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
return err == nil
|
||||
}
|
||||
|
||||
@@ -5,121 +5,51 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/krau/SaveAny-Bot/common"
|
||||
"github.com/krau/SaveAny-Bot/config"
|
||||
sc "github.com/krau/SaveAny-Bot/config/storage"
|
||||
storcfg "github.com/krau/SaveAny-Bot/config/storage"
|
||||
storenum "github.com/krau/SaveAny-Bot/pkg/enums/storage"
|
||||
"github.com/krau/SaveAny-Bot/storage/alist"
|
||||
"github.com/krau/SaveAny-Bot/storage/local"
|
||||
"github.com/krau/SaveAny-Bot/storage/minio"
|
||||
"github.com/krau/SaveAny-Bot/storage/telegram"
|
||||
"github.com/krau/SaveAny-Bot/storage/webdav"
|
||||
"github.com/krau/SaveAny-Bot/types"
|
||||
)
|
||||
|
||||
type Storage interface {
|
||||
Init(cfg sc.StorageConfig) error
|
||||
Type() types.StorageType
|
||||
Init(ctx context.Context, cfg storcfg.StorageConfig) error
|
||||
Type() storenum.StorageType
|
||||
Name() string
|
||||
JoinStoragePath(task types.Task) string
|
||||
JoinStoragePath(p string) string
|
||||
Save(ctx context.Context, reader io.Reader, storagePath string) error
|
||||
Exists(ctx context.Context, storagePath string) bool
|
||||
}
|
||||
|
||||
type StorageNotSupportStream interface {
|
||||
type StorageCannotStream interface {
|
||||
Storage
|
||||
NotSupportStream() string
|
||||
CannotStream() string
|
||||
}
|
||||
|
||||
var Storages = make(map[string]Storage)
|
||||
|
||||
var UserStorages = make(map[int64][]Storage)
|
||||
|
||||
// GetStorageByName returns storage by name from cache or creates new one
|
||||
func GetStorageByName(name string) (Storage, error) {
|
||||
if name == "" {
|
||||
return nil, ErrStorageNameEmpty
|
||||
}
|
||||
|
||||
storage, ok := Storages[name]
|
||||
if ok {
|
||||
return storage, nil
|
||||
}
|
||||
cfg := config.Cfg.GetStorageByName(name)
|
||||
if cfg == nil {
|
||||
return nil, fmt.Errorf("未找到存储 %s", name)
|
||||
}
|
||||
|
||||
storage, err := NewStorage(cfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
Storages[name] = storage
|
||||
return storage, nil
|
||||
}
|
||||
|
||||
// 检查 user 是否可用指定的 storage, 若不可用则返回未找到错误
|
||||
func GetStorageByUserIDAndName(chatID int64, name string) (Storage, error) {
|
||||
if name == "" {
|
||||
return nil, ErrStorageNameEmpty
|
||||
}
|
||||
|
||||
if !config.Cfg.HasStorage(chatID, name) {
|
||||
return nil, fmt.Errorf("没有找到用户 %d 的存储 %s", chatID, name)
|
||||
}
|
||||
|
||||
return GetStorageByName(name)
|
||||
}
|
||||
|
||||
func GetUserStorages(chatID int64) []Storage {
|
||||
if chatID <= 0 {
|
||||
return nil
|
||||
}
|
||||
if storages, ok := UserStorages[chatID]; ok {
|
||||
return storages
|
||||
}
|
||||
var storages []Storage
|
||||
for _, name := range config.Cfg.GetStorageNamesByUserID(chatID) {
|
||||
storage, err := GetStorageByName(name)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
storages = append(storages, storage)
|
||||
}
|
||||
return storages
|
||||
}
|
||||
|
||||
type StorageConstructor func() Storage
|
||||
|
||||
var storageConstructors = map[string]StorageConstructor{
|
||||
string(types.StorageTypeAlist): func() Storage { return new(alist.Alist) },
|
||||
string(types.StorageTypeLocal): func() Storage { return new(local.Local) },
|
||||
string(types.StorageTypeWebdav): func() Storage { return new(webdav.Webdav) },
|
||||
string(types.StorageTypeMinio): func() Storage { return new(minio.Minio) },
|
||||
var storageConstructors = map[storenum.StorageType]StorageConstructor{
|
||||
storenum.Alist: func() Storage { return new(alist.Alist) },
|
||||
storenum.Local: func() Storage { return new(local.Local) },
|
||||
storenum.Webdav: func() Storage { return new(webdav.Webdav) },
|
||||
storenum.Minio: func() Storage { return new(minio.Minio) },
|
||||
storenum.Telegram: func() Storage { return new(telegram.Telegram) },
|
||||
}
|
||||
|
||||
func NewStorage(cfg sc.StorageConfig) (Storage, error) {
|
||||
constructor, ok := storageConstructors[string(cfg.GetType())]
|
||||
func NewStorage(ctx context.Context, cfg storcfg.StorageConfig) (Storage, error) {
|
||||
constructor, ok := storageConstructors[cfg.GetType()]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("不支持的存储类型: %s", cfg.GetType())
|
||||
}
|
||||
|
||||
storage := constructor()
|
||||
if err := storage.Init(cfg); err != nil {
|
||||
if err := storage.Init(ctx, cfg); err != nil {
|
||||
return nil, fmt.Errorf("初始化 %s 存储失败: %w", cfg.GetName(), err)
|
||||
}
|
||||
|
||||
return storage, nil
|
||||
}
|
||||
|
||||
func LoadStorages() {
|
||||
common.Log.Info("加载存储...")
|
||||
for _, storage := range config.Cfg.Storages {
|
||||
_, err := GetStorageByName(storage.GetName())
|
||||
if err != nil {
|
||||
common.Log.Errorf("加载存储 %s 失败: %v", storage.GetName(), err)
|
||||
}
|
||||
}
|
||||
common.Log.Infof("成功加载 %d 个存储", len(Storages))
|
||||
for user := range config.Cfg.GetUsersID() {
|
||||
UserStorages[int64(user)] = GetUserStorages(int64(user))
|
||||
}
|
||||
}
|
||||
|
||||
111
storage/telegram/telegram.go
Normal file
111
storage/telegram/telegram.go
Normal file
@@ -0,0 +1,111 @@
|
||||
package telegram
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"path"
|
||||
"time"
|
||||
|
||||
"github.com/gabriel-vasile/mimetype"
|
||||
"github.com/gotd/td/telegram/message"
|
||||
"github.com/gotd/td/telegram/message/styling"
|
||||
"github.com/gotd/td/telegram/uploader"
|
||||
"github.com/krau/SaveAny-Bot/common/utils/tgutil"
|
||||
"github.com/krau/SaveAny-Bot/config"
|
||||
storconfig "github.com/krau/SaveAny-Bot/config/storage"
|
||||
"github.com/krau/SaveAny-Bot/pkg/consts/tglimit"
|
||||
storenum "github.com/krau/SaveAny-Bot/pkg/enums/storage"
|
||||
"github.com/rs/xid"
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
type Telegram struct {
|
||||
config storconfig.TelegramStorageConfig
|
||||
limiter *rate.Limiter
|
||||
}
|
||||
|
||||
func (t *Telegram) Init(ctx context.Context, cfg storconfig.StorageConfig) error {
|
||||
telegramConfig, ok := cfg.(*storconfig.TelegramStorageConfig)
|
||||
if !ok {
|
||||
return fmt.Errorf("failed to cast telegram config")
|
||||
}
|
||||
if err := telegramConfig.Validate(); err != nil {
|
||||
return err
|
||||
}
|
||||
t.config = *telegramConfig
|
||||
if t.config.RateLimit <= 0 || t.config.RateBurst <= 0 {
|
||||
t.config.RateLimit = 2
|
||||
t.config.RateBurst = 1
|
||||
}
|
||||
t.limiter = rate.NewLimiter(rate.Every(time.Duration(t.config.RateLimit)*time.Second), t.config.RateBurst)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *Telegram) Type() storenum.StorageType {
|
||||
return storenum.Telegram
|
||||
}
|
||||
|
||||
func (t *Telegram) Name() string {
|
||||
return t.config.Name
|
||||
}
|
||||
|
||||
func (t *Telegram) JoinStoragePath(p string) string {
|
||||
return path.Clean(p)
|
||||
}
|
||||
|
||||
func (t *Telegram) Exists(ctx context.Context, storagePath string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (t *Telegram) Save(ctx context.Context, r io.Reader, storagePath string) error {
|
||||
if err := t.limiter.Wait(ctx); err != nil {
|
||||
return fmt.Errorf("rate limit failed: %w", err)
|
||||
}
|
||||
rs, ok := r.(io.ReadSeeker)
|
||||
if !ok || rs == nil {
|
||||
return fmt.Errorf("reader must implement io.ReadSeeker")
|
||||
}
|
||||
tctx := tgutil.ExtFromContext(ctx)
|
||||
if tctx == nil {
|
||||
return fmt.Errorf("failed to get telegram context")
|
||||
}
|
||||
peer := tctx.PeerStorage.GetInputPeerById(t.config.ChatID)
|
||||
if peer == nil {
|
||||
return fmt.Errorf("failed to get input peer for chat ID %d", t.config.ChatID)
|
||||
}
|
||||
mtype, err := mimetype.DetectReader(rs)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to detect mimetype: %w", err)
|
||||
}
|
||||
filename := path.Base(storagePath)
|
||||
if filename == "" {
|
||||
filename = xid.New().String() + mtype.Extension()
|
||||
}
|
||||
if _, err := rs.Seek(0, io.SeekStart); err != nil {
|
||||
return fmt.Errorf("failed to seek reader: %w", err)
|
||||
}
|
||||
upler := uploader.NewUploader(tctx.Raw).
|
||||
WithPartSize(tglimit.MaxUploadPartSize).
|
||||
WithThreads(config.Cfg.Threads)
|
||||
|
||||
file, err := upler.FromReader(ctx, filename, rs)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to upload file to telegram: %w", err)
|
||||
}
|
||||
|
||||
caption := styling.Plain(filename)
|
||||
docb := message.UploadedDocument(file, caption).
|
||||
Filename(filename).
|
||||
ForceFile(true).
|
||||
MIME(mtype.String())
|
||||
|
||||
var mediaOpt message.MediaOption = docb
|
||||
sender := tctx.Sender
|
||||
_, err = sender.WithUploader(upler).To(peer).Media(ctx, mediaOpt)
|
||||
return err
|
||||
}
|
||||
|
||||
func (t *Telegram) CannotStream() string {
|
||||
return "Telegram storage must use a ReaderSeeker"
|
||||
}
|
||||
@@ -9,7 +9,7 @@ import (
|
||||
"path"
|
||||
"strings"
|
||||
|
||||
"github.com/krau/SaveAny-Bot/types"
|
||||
"github.com/krau/SaveAny-Bot/pkg/enums/key"
|
||||
)
|
||||
|
||||
type Client struct {
|
||||
@@ -54,7 +54,7 @@ func (c *Client) doRequest(ctx context.Context, method WebdavMethod, url string,
|
||||
req.Header.Set("Depth", "1")
|
||||
}
|
||||
if method == WebdavMethodPut && ctx != nil {
|
||||
if length := ctx.Value(types.ContextKeyContentLength); length != nil {
|
||||
if length := ctx.Value(key.ContextKeyContentLength); length != nil {
|
||||
if l, ok := length.(int64); ok {
|
||||
req.ContentLength = l
|
||||
}
|
||||
|
||||
@@ -6,19 +6,21 @@ import (
|
||||
"io"
|
||||
"net/http"
|
||||
"path"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/krau/SaveAny-Bot/common"
|
||||
"github.com/charmbracelet/log"
|
||||
config "github.com/krau/SaveAny-Bot/config/storage"
|
||||
"github.com/krau/SaveAny-Bot/types"
|
||||
storenum "github.com/krau/SaveAny-Bot/pkg/enums/storage"
|
||||
)
|
||||
|
||||
type Webdav struct {
|
||||
config config.WebdavStorageConfig
|
||||
client *Client
|
||||
logger *log.Logger
|
||||
}
|
||||
|
||||
func (w *Webdav) Init(cfg config.StorageConfig) error {
|
||||
func (w *Webdav) Init(ctx context.Context, cfg config.StorageConfig) error {
|
||||
webdavConfig, ok := cfg.(*config.WebdavStorageConfig)
|
||||
if !ok {
|
||||
return fmt.Errorf("failed to cast webdav config")
|
||||
@@ -27,42 +29,51 @@ func (w *Webdav) Init(cfg config.StorageConfig) error {
|
||||
return err
|
||||
}
|
||||
w.config = *webdavConfig
|
||||
w.logger = log.FromContext(ctx).WithPrefix(fmt.Sprintf("webdav[%s]", w.config.Name))
|
||||
w.client = NewClient(w.config.URL, w.config.Username, w.config.Password, &http.Client{
|
||||
Timeout: time.Hour * 12,
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *Webdav) Type() types.StorageType {
|
||||
return types.StorageTypeWebdav
|
||||
func (w *Webdav) Type() storenum.StorageType {
|
||||
return storenum.Webdav
|
||||
}
|
||||
|
||||
func (w *Webdav) Name() string {
|
||||
return w.config.Name
|
||||
}
|
||||
|
||||
func (w *Webdav) JoinStoragePath(task types.Task) string {
|
||||
return path.Join(w.config.BasePath, task.StoragePath)
|
||||
func (w *Webdav) JoinStoragePath(p string) string {
|
||||
return path.Join(w.config.BasePath, p)
|
||||
}
|
||||
|
||||
func (w *Webdav) Save(ctx context.Context, r io.Reader, storagePath string) error {
|
||||
common.Log.Infof("Saving file to %s", storagePath)
|
||||
if err := w.client.MkDir(ctx, path.Dir(storagePath)); err != nil {
|
||||
common.Log.Errorf("Failed to create directory %s: %v", path.Dir(storagePath), err)
|
||||
w.logger.Infof("Saving file to %s", storagePath)
|
||||
|
||||
ext := path.Ext(storagePath)
|
||||
base := strings.TrimSuffix(storagePath, ext)
|
||||
candidate := storagePath
|
||||
for i := 1; w.Exists(ctx, candidate); i++ {
|
||||
candidate = fmt.Sprintf("%s_%d%s", base, i, ext)
|
||||
}
|
||||
|
||||
if err := w.client.MkDir(ctx, path.Dir(candidate)); err != nil {
|
||||
w.logger.Errorf("Failed to create directory %s: %v", path.Dir(candidate), err)
|
||||
return ErrFailedToCreateDirectory
|
||||
}
|
||||
if err := w.client.WriteFile(ctx, storagePath, r); err != nil {
|
||||
common.Log.Errorf("Failed to write file %s: %v", storagePath, err)
|
||||
if err := w.client.WriteFile(ctx, candidate, r); err != nil {
|
||||
w.logger.Errorf("Failed to write file %s: %v", candidate, err)
|
||||
return ErrFailedToWriteFile
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *Webdav) Exists(ctx context.Context, storagePath string) bool {
|
||||
common.Log.Debugf("Checking if file exists at %s", storagePath)
|
||||
w.logger.Debugf("Checking if file exists at %s", storagePath)
|
||||
exists, err := w.client.Exists(ctx, storagePath)
|
||||
if err != nil {
|
||||
common.Log.Errorf("Failed to check if file exists at %s: %v", storagePath, err)
|
||||
w.logger.Errorf("Failed to check if file exists at %s: %v", storagePath, err)
|
||||
return false
|
||||
}
|
||||
return exists
|
||||
|
||||
Reference in New Issue
Block a user