mirror of
https://github.com/krau/SaveAny-Bot.git
synced 2026-05-10 17:52:44 +08:00
refactor: refactor task logic for better scalability (#76)
* refactor: a big refactor. wip * refactor: port handle file * refactor: place all handlers * fix: task info nil pointer * feat: enhance task progress tracking and context management * feat: cancel task * feat: stream mode * feat: silent mode * feat: dir cmd * refactor: remove unused old file * feat: rule cmd * feat: handle silent mode * feat: batch task * fix: batch task progress and temp file cleanup * refactor: update file creation and cleanup methods for better resource management * feat: add save command with silent mode handling * feat: message link * feat: update message prompts to include file count in storage selection * feat: slient save links * refactor: reduce dup code * feat: rule type * feat: chose dir * feat: refactor file handling and storage rules, improve error handling and logging * feat: rule mode * feat: telegraph pics * fix: tphpics nil pointer and inaccurate dirpath * feat: silent save telegraph * feat: add suffix to avoid file overwrite * feat: new storage telegram * chore: tidy go mod
This commit is contained in:
5
pkg/consts/specific.go
Normal file
5
pkg/consts/specific.go
Normal file
@@ -0,0 +1,5 @@
|
||||
package consts
|
||||
|
||||
const (
|
||||
RuleStorNameChosen = "CHOSEN"
|
||||
)
|
||||
6
pkg/consts/tglimit/tglimit.go
Normal file
6
pkg/consts/tglimit/tglimit.go
Normal file
@@ -0,0 +1,6 @@
|
||||
package tglimit
|
||||
|
||||
const (
|
||||
MaxPartSize = 1024 * 1024
|
||||
MaxUploadPartSize = 512 * 1024
|
||||
)
|
||||
9
pkg/consts/version.go
Normal file
9
pkg/consts/version.go
Normal file
@@ -0,0 +1,9 @@
|
||||
package consts
|
||||
|
||||
// inject version by '-X' flag
|
||||
// go build -ldflags "-X github.com/krau/SaveAny-Bot/pkg/consts.Version=${{ env.VERSION }}"
|
||||
var (
|
||||
Version string = "dev"
|
||||
BuildTime string = "unknown"
|
||||
GitCommit string = "unknown"
|
||||
)
|
||||
5
pkg/enums/key/context_key.go
Normal file
5
pkg/enums/key/context_key.go
Normal file
@@ -0,0 +1,5 @@
|
||||
package key
|
||||
|
||||
//go:generate go-enum --values --names --flag --nocase
|
||||
// ENUM(content-length)
|
||||
type ContextKey string
|
||||
82
pkg/enums/key/context_key_enum.go
Normal file
82
pkg/enums/key/context_key_enum.go
Normal file
@@ -0,0 +1,82 @@
|
||||
// Code generated by go-enum DO NOT EDIT.
|
||||
// Version: 0.6.1
|
||||
// Revision: a6f63bddde05aca4221df9c8e9e6d7d9674b1cb4
|
||||
// Build Date: 2025-03-18T23:42:14Z
|
||||
// Built By: goreleaser
|
||||
|
||||
package key
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
// ContextKeyContentLength is a ContextKey of type content-length.
|
||||
ContextKeyContentLength ContextKey = "content-length"
|
||||
)
|
||||
|
||||
var ErrInvalidContextKey = fmt.Errorf("not a valid ContextKey, try [%s]", strings.Join(_ContextKeyNames, ", "))
|
||||
|
||||
var _ContextKeyNames = []string{
|
||||
string(ContextKeyContentLength),
|
||||
}
|
||||
|
||||
// ContextKeyNames returns a list of possible string values of ContextKey.
|
||||
func ContextKeyNames() []string {
|
||||
tmp := make([]string, len(_ContextKeyNames))
|
||||
copy(tmp, _ContextKeyNames)
|
||||
return tmp
|
||||
}
|
||||
|
||||
// ContextKeyValues returns a list of the values for ContextKey
|
||||
func ContextKeyValues() []ContextKey {
|
||||
return []ContextKey{
|
||||
ContextKeyContentLength,
|
||||
}
|
||||
}
|
||||
|
||||
// String implements the Stringer interface.
|
||||
func (x ContextKey) String() string {
|
||||
return string(x)
|
||||
}
|
||||
|
||||
// IsValid provides a quick way to determine if the typed value is
|
||||
// part of the allowed enumerated values
|
||||
func (x ContextKey) IsValid() bool {
|
||||
_, err := ParseContextKey(string(x))
|
||||
return err == nil
|
||||
}
|
||||
|
||||
var _ContextKeyValue = map[string]ContextKey{
|
||||
"content-length": ContextKeyContentLength,
|
||||
}
|
||||
|
||||
// ParseContextKey attempts to convert a string to a ContextKey.
|
||||
func ParseContextKey(name string) (ContextKey, error) {
|
||||
if x, ok := _ContextKeyValue[name]; ok {
|
||||
return x, nil
|
||||
}
|
||||
// Case insensitive parse, do a separate lookup to prevent unnecessary cost of lowercasing a string if we don't need to.
|
||||
if x, ok := _ContextKeyValue[strings.ToLower(name)]; ok {
|
||||
return x, nil
|
||||
}
|
||||
return ContextKey(""), fmt.Errorf("%s is %w", name, ErrInvalidContextKey)
|
||||
}
|
||||
|
||||
// Set implements the Golang flag.Value interface func.
|
||||
func (x *ContextKey) Set(val string) error {
|
||||
v, err := ParseContextKey(val)
|
||||
*x = v
|
||||
return err
|
||||
}
|
||||
|
||||
// Get implements the Golang flag.Getter interface func.
|
||||
func (x *ContextKey) Get() interface{} {
|
||||
return *x
|
||||
}
|
||||
|
||||
// Type implements the github.com/spf13/pFlag Value interface.
|
||||
func (x *ContextKey) Type() string {
|
||||
return "ContextKey"
|
||||
}
|
||||
16
pkg/enums/rule/ruletype.go
Normal file
16
pkg/enums/rule/ruletype.go
Normal file
@@ -0,0 +1,16 @@
|
||||
package rule
|
||||
|
||||
type RuleType string
|
||||
|
||||
const (
|
||||
FileNameRegex RuleType = "FILENAME-REGEX"
|
||||
MessageRegex RuleType = "MESSAGE-REGEX"
|
||||
)
|
||||
|
||||
func (r RuleType) String() string {
|
||||
return string(r)
|
||||
}
|
||||
|
||||
func Values() []RuleType {
|
||||
return []RuleType{FileNameRegex, MessageRegex}
|
||||
}
|
||||
9
pkg/enums/storage/storages.go
Normal file
9
pkg/enums/storage/storages.go
Normal file
@@ -0,0 +1,9 @@
|
||||
package storage
|
||||
|
||||
//go:generate go-enum --values --names --noprefix --flag --nocase
|
||||
|
||||
// StorageType
|
||||
/* ENUM(
|
||||
local, webdav, alist, minio, telegram
|
||||
) */
|
||||
type StorageType string
|
||||
102
pkg/enums/storage/storages_enum.go
Normal file
102
pkg/enums/storage/storages_enum.go
Normal file
@@ -0,0 +1,102 @@
|
||||
// Code generated by go-enum DO NOT EDIT.
|
||||
// Version: 0.6.1
|
||||
// Revision: a6f63bddde05aca4221df9c8e9e6d7d9674b1cb4
|
||||
// Build Date: 2025-03-18T23:42:14Z
|
||||
// Built By: goreleaser
|
||||
|
||||
package storage
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
// Local is a StorageType of type local.
|
||||
Local StorageType = "local"
|
||||
// Webdav is a StorageType of type webdav.
|
||||
Webdav StorageType = "webdav"
|
||||
// Alist is a StorageType of type alist.
|
||||
Alist StorageType = "alist"
|
||||
// Minio is a StorageType of type minio.
|
||||
Minio StorageType = "minio"
|
||||
// Telegram is a StorageType of type telegram.
|
||||
Telegram StorageType = "telegram"
|
||||
)
|
||||
|
||||
var ErrInvalidStorageType = fmt.Errorf("not a valid StorageType, try [%s]", strings.Join(_StorageTypeNames, ", "))
|
||||
|
||||
var _StorageTypeNames = []string{
|
||||
string(Local),
|
||||
string(Webdav),
|
||||
string(Alist),
|
||||
string(Minio),
|
||||
string(Telegram),
|
||||
}
|
||||
|
||||
// StorageTypeNames returns a list of possible string values of StorageType.
|
||||
func StorageTypeNames() []string {
|
||||
tmp := make([]string, len(_StorageTypeNames))
|
||||
copy(tmp, _StorageTypeNames)
|
||||
return tmp
|
||||
}
|
||||
|
||||
// StorageTypeValues returns a list of the values for StorageType
|
||||
func StorageTypeValues() []StorageType {
|
||||
return []StorageType{
|
||||
Local,
|
||||
Webdav,
|
||||
Alist,
|
||||
Minio,
|
||||
Telegram,
|
||||
}
|
||||
}
|
||||
|
||||
// String implements the Stringer interface.
|
||||
func (x StorageType) String() string {
|
||||
return string(x)
|
||||
}
|
||||
|
||||
// IsValid provides a quick way to determine if the typed value is
|
||||
// part of the allowed enumerated values
|
||||
func (x StorageType) IsValid() bool {
|
||||
_, err := ParseStorageType(string(x))
|
||||
return err == nil
|
||||
}
|
||||
|
||||
var _StorageTypeValue = map[string]StorageType{
|
||||
"local": Local,
|
||||
"webdav": Webdav,
|
||||
"alist": Alist,
|
||||
"minio": Minio,
|
||||
"telegram": Telegram,
|
||||
}
|
||||
|
||||
// ParseStorageType attempts to convert a string to a StorageType.
|
||||
func ParseStorageType(name string) (StorageType, error) {
|
||||
if x, ok := _StorageTypeValue[name]; ok {
|
||||
return x, nil
|
||||
}
|
||||
// Case insensitive parse, do a separate lookup to prevent unnecessary cost of lowercasing a string if we don't need to.
|
||||
if x, ok := _StorageTypeValue[strings.ToLower(name)]; ok {
|
||||
return x, nil
|
||||
}
|
||||
return StorageType(""), fmt.Errorf("%s is %w", name, ErrInvalidStorageType)
|
||||
}
|
||||
|
||||
// Set implements the Golang flag.Value interface func.
|
||||
func (x *StorageType) Set(val string) error {
|
||||
v, err := ParseStorageType(val)
|
||||
*x = v
|
||||
return err
|
||||
}
|
||||
|
||||
// Get implements the Golang flag.Getter interface func.
|
||||
func (x *StorageType) Get() interface{} {
|
||||
return *x
|
||||
}
|
||||
|
||||
// Type implements the github.com/spf13/pFlag Value interface.
|
||||
func (x *StorageType) Type() string {
|
||||
return "StorageType"
|
||||
}
|
||||
5
pkg/enums/tasktype/tasktype.go
Normal file
5
pkg/enums/tasktype/tasktype.go
Normal file
@@ -0,0 +1,5 @@
|
||||
package tasktype
|
||||
|
||||
//go:generate go-enum --values --names --flag --nocase
|
||||
// ENUM(tgfiles,tphpics)
|
||||
type TaskType string
|
||||
87
pkg/enums/tasktype/tasktype_enum.go
Normal file
87
pkg/enums/tasktype/tasktype_enum.go
Normal file
@@ -0,0 +1,87 @@
|
||||
// Code generated by go-enum DO NOT EDIT.
|
||||
// Version: 0.6.1
|
||||
// Revision: a6f63bddde05aca4221df9c8e9e6d7d9674b1cb4
|
||||
// Build Date: 2025-03-18T23:42:14Z
|
||||
// Built By: goreleaser
|
||||
|
||||
package tasktype
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
// TaskTypeTgfiles is a TaskType of type tgfiles.
|
||||
TaskTypeTgfiles TaskType = "tgfiles"
|
||||
// TaskTypeTphpics is a TaskType of type tphpics.
|
||||
TaskTypeTphpics TaskType = "tphpics"
|
||||
)
|
||||
|
||||
var ErrInvalidTaskType = fmt.Errorf("not a valid TaskType, try [%s]", strings.Join(_TaskTypeNames, ", "))
|
||||
|
||||
var _TaskTypeNames = []string{
|
||||
string(TaskTypeTgfiles),
|
||||
string(TaskTypeTphpics),
|
||||
}
|
||||
|
||||
// TaskTypeNames returns a list of possible string values of TaskType.
|
||||
func TaskTypeNames() []string {
|
||||
tmp := make([]string, len(_TaskTypeNames))
|
||||
copy(tmp, _TaskTypeNames)
|
||||
return tmp
|
||||
}
|
||||
|
||||
// TaskTypeValues returns a list of the values for TaskType
|
||||
func TaskTypeValues() []TaskType {
|
||||
return []TaskType{
|
||||
TaskTypeTgfiles,
|
||||
TaskTypeTphpics,
|
||||
}
|
||||
}
|
||||
|
||||
// String implements the Stringer interface.
|
||||
func (x TaskType) String() string {
|
||||
return string(x)
|
||||
}
|
||||
|
||||
// IsValid provides a quick way to determine if the typed value is
|
||||
// part of the allowed enumerated values
|
||||
func (x TaskType) IsValid() bool {
|
||||
_, err := ParseTaskType(string(x))
|
||||
return err == nil
|
||||
}
|
||||
|
||||
var _TaskTypeValue = map[string]TaskType{
|
||||
"tgfiles": TaskTypeTgfiles,
|
||||
"tphpics": TaskTypeTphpics,
|
||||
}
|
||||
|
||||
// ParseTaskType attempts to convert a string to a TaskType.
|
||||
func ParseTaskType(name string) (TaskType, error) {
|
||||
if x, ok := _TaskTypeValue[name]; ok {
|
||||
return x, nil
|
||||
}
|
||||
// Case insensitive parse, do a separate lookup to prevent unnecessary cost of lowercasing a string if we don't need to.
|
||||
if x, ok := _TaskTypeValue[strings.ToLower(name)]; ok {
|
||||
return x, nil
|
||||
}
|
||||
return TaskType(""), fmt.Errorf("%s is %w", name, ErrInvalidTaskType)
|
||||
}
|
||||
|
||||
// Set implements the Golang flag.Value interface func.
|
||||
func (x *TaskType) Set(val string) error {
|
||||
v, err := ParseTaskType(val)
|
||||
*x = v
|
||||
return err
|
||||
}
|
||||
|
||||
// Get implements the Golang flag.Getter interface func.
|
||||
func (x *TaskType) Get() interface{} {
|
||||
return *x
|
||||
}
|
||||
|
||||
// Type implements the github.com/spf13/pFlag Value interface.
|
||||
func (x *TaskType) Type() string {
|
||||
return "TaskType"
|
||||
}
|
||||
241
pkg/queue/queue.go
Normal file
241
pkg/queue/queue.go
Normal file
@@ -0,0 +1,241 @@
|
||||
package queue
|
||||
|
||||
import (
|
||||
"container/list"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type TaskQueue[T any] struct {
|
||||
tasks *list.List
|
||||
taskMap map[string]*Task[T]
|
||||
runningTaskMap map[string]*Task[T]
|
||||
mu sync.RWMutex
|
||||
cond *sync.Cond
|
||||
closed bool
|
||||
}
|
||||
|
||||
func NewTaskQueue[T any]() *TaskQueue[T] {
|
||||
tq := &TaskQueue[T]{
|
||||
tasks: list.New(),
|
||||
taskMap: make(map[string]*Task[T]),
|
||||
runningTaskMap: make(map[string]*Task[T]),
|
||||
}
|
||||
tq.cond = sync.NewCond(&tq.mu)
|
||||
return tq
|
||||
}
|
||||
|
||||
func (tq *TaskQueue[T]) Add(task *Task[T]) error {
|
||||
tq.mu.Lock()
|
||||
defer tq.mu.Unlock()
|
||||
|
||||
if tq.closed {
|
||||
return errors.New("queue is closed")
|
||||
}
|
||||
|
||||
if _, exists := tq.taskMap[task.ID]; exists {
|
||||
return fmt.Errorf("task with ID %s already exists", task.ID)
|
||||
}
|
||||
|
||||
if task.IsCancelled() {
|
||||
return fmt.Errorf("task %s has been cancelled", task.ID)
|
||||
}
|
||||
|
||||
element := tq.tasks.PushBack(task)
|
||||
task.element = element
|
||||
tq.taskMap[task.ID] = task
|
||||
|
||||
tq.cond.Signal()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (tq *TaskQueue[T]) Get() (*Task[T], error) {
|
||||
tq.mu.Lock()
|
||||
defer tq.mu.Unlock()
|
||||
|
||||
for tq.tasks.Len() == 0 && !tq.closed {
|
||||
tq.cond.Wait()
|
||||
}
|
||||
|
||||
if tq.closed && tq.tasks.Len() == 0 {
|
||||
return nil, fmt.Errorf("queue is closed and empty")
|
||||
}
|
||||
|
||||
for tq.tasks.Len() > 0 {
|
||||
element := tq.tasks.Front()
|
||||
task := element.Value.(*Task[T])
|
||||
|
||||
tq.tasks.Remove(element)
|
||||
task.element = nil
|
||||
|
||||
if !task.IsCancelled() {
|
||||
tq.runningTaskMap[task.ID] = task
|
||||
return task, nil
|
||||
}
|
||||
}
|
||||
|
||||
if !tq.closed {
|
||||
return tq.Get()
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("queue is closed and empty")
|
||||
}
|
||||
|
||||
func (tq *TaskQueue[T]) Done(taskID string) {
|
||||
tq.mu.Lock()
|
||||
defer tq.mu.Unlock()
|
||||
|
||||
delete(tq.taskMap, taskID)
|
||||
delete(tq.runningTaskMap, taskID)
|
||||
}
|
||||
|
||||
func (tq *TaskQueue[T]) Peek() (*Task[T], error) {
|
||||
tq.mu.RLock()
|
||||
defer tq.mu.RUnlock()
|
||||
|
||||
if tq.tasks.Len() == 0 {
|
||||
return nil, fmt.Errorf("queue is empty")
|
||||
}
|
||||
|
||||
for element := tq.tasks.Front(); element != nil; element = element.Next() {
|
||||
task := element.Value.(*Task[T])
|
||||
if !task.IsCancelled() {
|
||||
return task, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("queue has no valid tasks")
|
||||
}
|
||||
|
||||
func (tq *TaskQueue[T]) Length() int {
|
||||
tq.mu.RLock()
|
||||
defer tq.mu.RUnlock()
|
||||
return tq.tasks.Len()
|
||||
}
|
||||
|
||||
func (tq *TaskQueue[T]) ActiveLength() int {
|
||||
tq.mu.RLock()
|
||||
defer tq.mu.RUnlock()
|
||||
|
||||
count := 0
|
||||
for element := tq.tasks.Front(); element != nil; element = element.Next() {
|
||||
task := element.Value.(*Task[T])
|
||||
if !task.IsCancelled() {
|
||||
count++
|
||||
}
|
||||
}
|
||||
return count
|
||||
}
|
||||
|
||||
func (tq *TaskQueue[T]) CancelTask(taskID string) error {
|
||||
tq.mu.RLock()
|
||||
task, exists := tq.taskMap[taskID]
|
||||
if !exists {
|
||||
task, exists = tq.runningTaskMap[taskID]
|
||||
}
|
||||
tq.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
return fmt.Errorf("task %s does not exist", taskID)
|
||||
}
|
||||
|
||||
task.Cancel()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (tq *TaskQueue[T]) RemoveTask(taskID string) error {
|
||||
tq.mu.Lock()
|
||||
defer tq.mu.Unlock()
|
||||
|
||||
task, exists := tq.taskMap[taskID]
|
||||
if !exists {
|
||||
_, exists = tq.runningTaskMap[taskID]
|
||||
if exists {
|
||||
delete(tq.runningTaskMap, taskID)
|
||||
}
|
||||
return fmt.Errorf("task %s is already running, cannot remove from queue", taskID)
|
||||
}
|
||||
|
||||
if task.element != nil {
|
||||
tq.tasks.Remove(task.element)
|
||||
}
|
||||
delete(tq.taskMap, taskID)
|
||||
task.Cancel()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (tq *TaskQueue[T]) CancelAll() {
|
||||
tq.mu.RLock()
|
||||
tasks := make([]*Task[T], 0, tq.tasks.Len())
|
||||
for element := tq.tasks.Front(); element != nil; element = element.Next() {
|
||||
tasks = append(tasks, element.Value.(*Task[T]))
|
||||
}
|
||||
tq.mu.RUnlock()
|
||||
|
||||
for _, task := range tasks {
|
||||
task.Cancel()
|
||||
}
|
||||
}
|
||||
|
||||
func (tq *TaskQueue[T]) GetTask(taskID string) (*Task[T], error) {
|
||||
tq.mu.RLock()
|
||||
defer tq.mu.RUnlock()
|
||||
|
||||
task, exists := tq.taskMap[taskID]
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("task %s does not exist", taskID)
|
||||
}
|
||||
|
||||
return task, nil
|
||||
}
|
||||
|
||||
func (tq *TaskQueue[T]) Close() {
|
||||
tq.mu.Lock()
|
||||
defer tq.mu.Unlock()
|
||||
|
||||
tq.closed = true
|
||||
tq.cond.Broadcast()
|
||||
}
|
||||
|
||||
func (tq *TaskQueue[T]) IsClosed() bool {
|
||||
tq.mu.RLock()
|
||||
defer tq.mu.RUnlock()
|
||||
return tq.closed
|
||||
}
|
||||
|
||||
func (tq *TaskQueue[T]) Clear() {
|
||||
tq.mu.Lock()
|
||||
defer tq.mu.Unlock()
|
||||
|
||||
for element := tq.tasks.Front(); element != nil; element = element.Next() {
|
||||
task := element.Value.(*Task[T])
|
||||
task.Cancel()
|
||||
}
|
||||
|
||||
tq.tasks.Init()
|
||||
tq.taskMap = make(map[string]*Task[T])
|
||||
}
|
||||
|
||||
func (tq *TaskQueue[T]) CleanupCancelled() int {
|
||||
tq.mu.Lock()
|
||||
defer tq.mu.Unlock()
|
||||
|
||||
removed := 0
|
||||
element := tq.tasks.Front()
|
||||
|
||||
for element != nil {
|
||||
next := element.Next()
|
||||
task := element.Value.(*Task[T])
|
||||
|
||||
if task.IsCancelled() {
|
||||
tq.tasks.Remove(element)
|
||||
delete(tq.taskMap, task.ID)
|
||||
removed++
|
||||
}
|
||||
|
||||
element = next
|
||||
}
|
||||
|
||||
return removed
|
||||
}
|
||||
172
pkg/queue/queue_test.go
Normal file
172
pkg/queue/queue_test.go
Normal file
@@ -0,0 +1,172 @@
|
||||
package queue_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/krau/SaveAny-Bot/pkg/queue"
|
||||
)
|
||||
|
||||
// helper to create a simple Task with integer payload
|
||||
func newTask(id string) *queue.Task[int] {
|
||||
return queue.NewTask(context.Background(), id, 0)
|
||||
}
|
||||
|
||||
func TestAddAndLength(t *testing.T) {
|
||||
q := queue.NewTaskQueue[int]()
|
||||
if q.Length() != 0 {
|
||||
t.Fatalf("expected length 0, got %d", q.Length())
|
||||
}
|
||||
t1 := newTask("t1")
|
||||
if err := q.Add(t1); err != nil {
|
||||
t.Fatalf("unexpected error on Add: %v", err)
|
||||
}
|
||||
if q.Length() != 1 {
|
||||
t.Fatalf("expected length 1, got %d", q.Length())
|
||||
}
|
||||
}
|
||||
|
||||
func TestDuplicateAdd(t *testing.T) {
|
||||
q := queue.NewTaskQueue[int]()
|
||||
t1 := newTask("dup")
|
||||
if err := q.Add(t1); err != nil {
|
||||
t.Fatalf("unexpected error on first Add: %v", err)
|
||||
}
|
||||
if err := q.Add(t1); err == nil {
|
||||
t.Fatal("expected error on duplicate Add, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetAndPeek(t *testing.T) {
|
||||
q := queue.NewTaskQueue[int]()
|
||||
t1 := newTask("a")
|
||||
t2 := newTask("b")
|
||||
q.Add(t1)
|
||||
q.Add(t2)
|
||||
// Peek should return t1
|
||||
peeked, err := q.Peek()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error on Peek: %v", err)
|
||||
}
|
||||
if peeked.ID != "a" {
|
||||
t.Fatalf("expected Peek ID 'a', got '%s'", peeked.ID)
|
||||
}
|
||||
// Get should return t1 then t2
|
||||
first, err := q.Get()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error on Get: %v", err)
|
||||
}
|
||||
if first.ID != "a" {
|
||||
t.Fatalf("expected first Get ID 'a', got '%s'", first.ID)
|
||||
}
|
||||
second, err := q.Get()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error on second Get: %v", err)
|
||||
}
|
||||
if second.ID != "b" {
|
||||
t.Fatalf("expected second Get ID 'b', got '%s'", second.ID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCancelAndActiveLength(t *testing.T) {
|
||||
q := queue.NewTaskQueue[int]()
|
||||
t1 := newTask("1")
|
||||
t2 := newTask("2")
|
||||
q.Add(t1)
|
||||
q.Add(t2)
|
||||
// Cancel t1
|
||||
if err := q.CancelTask("1"); err != nil {
|
||||
t.Fatalf("unexpected error on CancelTask: %v", err)
|
||||
}
|
||||
// Length counts all entries
|
||||
if q.Length() != 2 {
|
||||
t.Fatalf("expected total length 2, got %d", q.Length())
|
||||
}
|
||||
// ActiveLength skips cancelled
|
||||
if got := q.ActiveLength(); got != 1 {
|
||||
t.Fatalf("expected active length 1, got %d", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRemoveTask(t *testing.T) {
|
||||
q := queue.NewTaskQueue[int]()
|
||||
t1 := newTask("r1")
|
||||
q.Add(t1)
|
||||
if err := q.RemoveTask("r1"); err != nil {
|
||||
t.Fatalf("unexpected error on RemoveTask: %v", err)
|
||||
}
|
||||
if q.Length() != 0 {
|
||||
t.Fatalf("expected length 0 after remove, got %d", q.Length())
|
||||
}
|
||||
}
|
||||
|
||||
func TestClearAndCleanupCancelled(t *testing.T) {
|
||||
q := queue.NewTaskQueue[int]()
|
||||
tasks := []*queue.Task[int]{newTask("c1"), newTask("c2"), newTask("c3")}
|
||||
for _, tsk := range tasks {
|
||||
q.Add(tsk)
|
||||
}
|
||||
// Cancel one
|
||||
q.CancelTask("c2")
|
||||
// Cleanup cancelled
|
||||
removed := q.CleanupCancelled()
|
||||
if removed != 1 {
|
||||
t.Fatalf("expected removed 1, got %d", removed)
|
||||
}
|
||||
if q.ActiveLength() != 2 {
|
||||
t.Fatalf("expected active length 2 after cleanup, got %d", q.ActiveLength())
|
||||
}
|
||||
// Clear all
|
||||
q.Clear()
|
||||
if q.Length() != 0 {
|
||||
t.Fatalf("expected length 0 after clear, got %d", q.Length())
|
||||
}
|
||||
}
|
||||
|
||||
func TestCloseBehavior(t *testing.T) {
|
||||
q := queue.NewTaskQueue[int]()
|
||||
done := make(chan struct{})
|
||||
// consumer
|
||||
go func() {
|
||||
_, err := q.Get()
|
||||
if err == nil {
|
||||
t.Errorf("expected error when getting from closed empty queue, got nil")
|
||||
}
|
||||
close(done)
|
||||
}()
|
||||
// allow goroutine to block
|
||||
|
||||
// close queue
|
||||
q.Close()
|
||||
<-done
|
||||
}
|
||||
|
||||
func TestConcurrencySafety(t *testing.T) {
|
||||
q := queue.NewTaskQueue[int]()
|
||||
var wg sync.WaitGroup
|
||||
n := 1000
|
||||
// producers
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for i := 0; i < n; i++ {
|
||||
q.Add(newTask(fmt.Sprintf("p%d", i)))
|
||||
}
|
||||
}()
|
||||
// consumers
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
count := 0
|
||||
for count < n {
|
||||
_, err := q.Get()
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
count++
|
||||
}
|
||||
}()
|
||||
wg.Wait()
|
||||
}
|
||||
44
pkg/queue/task.go
Normal file
44
pkg/queue/task.go
Normal file
@@ -0,0 +1,44 @@
|
||||
package queue
|
||||
|
||||
import (
|
||||
"container/list"
|
||||
"context"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Task[T any] struct {
|
||||
ID string
|
||||
Data T
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
created time.Time
|
||||
element *list.Element
|
||||
}
|
||||
|
||||
func NewTask[T any](ctx context.Context, id string, data T) *Task[T] {
|
||||
cancelCtx, cancel := context.WithCancel(ctx)
|
||||
return &Task[T]{
|
||||
ID: id,
|
||||
Data: data,
|
||||
ctx: cancelCtx,
|
||||
cancel: cancel,
|
||||
created: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Task[T]) IsCancelled() bool {
|
||||
select {
|
||||
case <-t.ctx.Done():
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Task[T]) Cancel() {
|
||||
t.cancel()
|
||||
}
|
||||
|
||||
func (t *Task[T]) Context() context.Context {
|
||||
return t.ctx
|
||||
}
|
||||
45
pkg/rule/filename_regex.go
Normal file
45
pkg/rule/filename_regex.go
Normal file
@@ -0,0 +1,45 @@
|
||||
package rule
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
|
||||
ruleenum "github.com/krau/SaveAny-Bot/pkg/enums/rule"
|
||||
"github.com/krau/SaveAny-Bot/pkg/tfile"
|
||||
)
|
||||
|
||||
type RuleFileNameRegex struct {
|
||||
storInfo
|
||||
regex *regexp.Regexp
|
||||
}
|
||||
|
||||
var _ RuleClass[tfile.TGFile] = (*RuleFileNameRegex)(nil)
|
||||
|
||||
func (r RuleFileNameRegex) Type() ruleenum.RuleType {
|
||||
return ruleenum.FileNameRegex
|
||||
}
|
||||
|
||||
func (r RuleFileNameRegex) Match(input tfile.TGFile) (bool, error) {
|
||||
return r.regex.MatchString(input.Name()), nil
|
||||
}
|
||||
|
||||
func (r RuleFileNameRegex) StorageName() string {
|
||||
return r.storName
|
||||
}
|
||||
|
||||
func (r RuleFileNameRegex) StoragePath() string {
|
||||
return r.storPath
|
||||
}
|
||||
|
||||
func NewRuleFileNameRegex(storName, storPath, regexStr string) (*RuleFileNameRegex, error) {
|
||||
regex, err := regexp.Compile(regexStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &RuleFileNameRegex{
|
||||
storInfo: storInfo{
|
||||
storName: storName,
|
||||
storPath: storPath,
|
||||
},
|
||||
regex: regex,
|
||||
}, nil
|
||||
}
|
||||
43
pkg/rule/message_regex.go
Normal file
43
pkg/rule/message_regex.go
Normal file
@@ -0,0 +1,43 @@
|
||||
package rule
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
|
||||
ruleenum "github.com/krau/SaveAny-Bot/pkg/enums/rule"
|
||||
)
|
||||
|
||||
var _ RuleClass[string] = (*RuleMessageRegex)(nil)
|
||||
|
||||
type RuleMessageRegex struct {
|
||||
storInfo
|
||||
regex *regexp.Regexp
|
||||
}
|
||||
|
||||
func (r RuleMessageRegex) Type() ruleenum.RuleType {
|
||||
return ruleenum.MessageRegex
|
||||
}
|
||||
|
||||
func (r RuleMessageRegex) Match(input string) (bool, error) {
|
||||
return r.regex.MatchString(input), nil
|
||||
}
|
||||
|
||||
func (r RuleMessageRegex) StorageName() string {
|
||||
return r.storName
|
||||
}
|
||||
func (r RuleMessageRegex) StoragePath() string {
|
||||
return r.storPath
|
||||
}
|
||||
|
||||
func NewRuleMessageRegex(storName, storPath, regexStr string) (*RuleMessageRegex, error) {
|
||||
regex, err := regexp.Compile(regexStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &RuleMessageRegex{
|
||||
storInfo: storInfo{
|
||||
storName: storName,
|
||||
storPath: storPath,
|
||||
},
|
||||
regex: regex,
|
||||
}, nil
|
||||
}
|
||||
17
pkg/rule/rule.go
Normal file
17
pkg/rule/rule.go
Normal file
@@ -0,0 +1,17 @@
|
||||
package rule
|
||||
|
||||
import (
|
||||
ruleenum "github.com/krau/SaveAny-Bot/pkg/enums/rule"
|
||||
)
|
||||
|
||||
type RuleClass[InputType any] interface {
|
||||
Type() ruleenum.RuleType
|
||||
Match(input InputType) (bool, error)
|
||||
StorageName() string
|
||||
StoragePath() string
|
||||
}
|
||||
|
||||
type storInfo struct {
|
||||
storName string
|
||||
storPath string
|
||||
}
|
||||
44
pkg/tcbdata/data.go
Normal file
44
pkg/tcbdata/data.go
Normal file
@@ -0,0 +1,44 @@
|
||||
package tcbdata
|
||||
|
||||
import (
|
||||
"github.com/krau/SaveAny-Bot/pkg/enums/tasktype"
|
||||
"github.com/krau/SaveAny-Bot/pkg/telegraph"
|
||||
"github.com/krau/SaveAny-Bot/pkg/tfile"
|
||||
)
|
||||
|
||||
const (
|
||||
TypeAdd = "add"
|
||||
TypeSetDefault = "setdefault"
|
||||
)
|
||||
|
||||
// type TaskDataTGFiles struct {
|
||||
// Files []tfile.TGFileMessage
|
||||
// AsBatch bool
|
||||
// }
|
||||
|
||||
// type TaskDataTelegraph struct {
|
||||
// Pics []string
|
||||
// PageNode *telegraph.Page
|
||||
// }
|
||||
|
||||
// type TaskDataType interface {
|
||||
// TaskDataTGFiles | TaskDataTelegraph
|
||||
// }
|
||||
|
||||
type Add struct {
|
||||
TaskType tasktype.TaskType
|
||||
SelectedStorName string
|
||||
DirID uint
|
||||
SettedDir bool
|
||||
// tfiles
|
||||
Files []tfile.TGFileMessage
|
||||
AsBatch bool
|
||||
// tphpics
|
||||
TphPageNode *telegraph.Page
|
||||
TphPics []string
|
||||
TphDirPath string // unescaped telegraph.Page.Path
|
||||
}
|
||||
|
||||
type SetDefaultStorage struct {
|
||||
StorageName string
|
||||
}
|
||||
150
pkg/telegraph/client.go
Normal file
150
pkg/telegraph/client.go
Normal file
@@ -0,0 +1,150 @@
|
||||
// https://github.com/celestix/telegraph-go
|
||||
|
||||
package telegraph
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Page object represents a page on Telegraph.
|
||||
type Page struct {
|
||||
// Path to the page.
|
||||
Path string `json:"path"`
|
||||
// URL of the page.
|
||||
Url string `json:"url"`
|
||||
// Title of the page.
|
||||
Title string `json:"title"`
|
||||
// Description of the page.
|
||||
Description string `json:"description"`
|
||||
// Optional. Name of the author, displayed below the title.
|
||||
AuthorName string `json:"author_name,omitempty"`
|
||||
// Optional. Profile link, opened when users click on the author's name below the title. Can be any link, not necessarily to a Telegram profile or channel.
|
||||
AuthorUrl string `json:"author_url,omitempty"`
|
||||
// Optional. Image URL of the page.
|
||||
ImageUrl string `json:"image_url,omitempty"`
|
||||
// Optional. Content of the page.
|
||||
Content []Node `json:"content,omitempty"`
|
||||
// Number of page views for the page.
|
||||
Views int64 `json:"views"`
|
||||
// Optional. Only returned if access_token passed. True, if the target Telegraph account can edit the page.
|
||||
CanEdit bool `json:"can_edit,omitempty"`
|
||||
}
|
||||
|
||||
// Node is abstract object represents a DOM Node. It can be a String which represents a DOM text node or a
|
||||
// NodeElement object.
|
||||
type Node any
|
||||
|
||||
// NodeElement represents a DOM element node.
|
||||
type NodeElement struct {
|
||||
// Name of the DOM element. Available tags: a, aside, b, blockquote, br, code, em, figcaption, figure,
|
||||
// h3, h4, hr, i, iframe, img, li, ol, p, pre, s, strong, u, ul, video.Client
|
||||
Tag string `json:"tag"`
|
||||
|
||||
// Attributes of the DOM element. Key of object represents name of attribute, value represents value
|
||||
// of attribute. Available attributes: href, src.
|
||||
Attrs map[string]string `json:"attrs,omitempty"`
|
||||
|
||||
// List of child nodes for the DOM element.
|
||||
Children []Node `json:"children,omitempty"`
|
||||
}
|
||||
|
||||
type Client struct {
|
||||
client *http.Client
|
||||
}
|
||||
|
||||
type Body struct {
|
||||
// Ok: if true, request was successful, and result can be found in the Result field.
|
||||
// If false, error can be explained in Error field.
|
||||
Ok bool `json:"ok"`
|
||||
// Error: contains a human-readable description of the error result.
|
||||
Error string `json:"error"`
|
||||
// Result: result of requests (if Ok)
|
||||
Result json.RawMessage `json:"result"`
|
||||
}
|
||||
|
||||
const (
|
||||
ApiUrl = "https://api.telegra.ph/"
|
||||
)
|
||||
|
||||
func (c *Client) InvokeRequest(ctx context.Context, method string, params url.Values) (json.RawMessage, error) {
|
||||
r, err := http.NewRequestWithContext(ctx, http.MethodPost, ApiUrl+method, strings.NewReader(params.Encode()))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to build POST request to %s: %w", method, err)
|
||||
}
|
||||
|
||||
resp, err := c.client.Do(r)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to execute POST request to %s: %w", method, err)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
_ = resp.Body.Close()
|
||||
}()
|
||||
|
||||
var b Body
|
||||
if err = json.NewDecoder(resp.Body).Decode(&b); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse response from %s: %w", method, err)
|
||||
}
|
||||
if !b.Ok {
|
||||
return nil, fmt.Errorf("failed to %s: %s", method, b.Error)
|
||||
}
|
||||
return b.Result, nil
|
||||
}
|
||||
|
||||
func (c *Client) GetPage(ctx context.Context, phpath string) (*Page, error) {
|
||||
var (
|
||||
u = url.Values{}
|
||||
a Page
|
||||
)
|
||||
u.Add("path", phpath)
|
||||
u.Add("return_content", "true")
|
||||
r, err := c.InvokeRequest(ctx, "getPage", u)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &a, json.Unmarshal(r, &a)
|
||||
}
|
||||
|
||||
// Helper to use the client(*http.Client) to download a file from a given URL.
|
||||
func (c *Client) Download(ctx context.Context, durl string) (io.ReadCloser, error) {
|
||||
r, err := http.NewRequestWithContext(ctx, http.MethodGet, durl, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp, err := c.client.Do(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("failed to download file from %s: %s", durl, resp.Status)
|
||||
}
|
||||
return resp.Body, nil
|
||||
}
|
||||
|
||||
func NewClient() *Client {
|
||||
return &Client{
|
||||
client: &http.Client{},
|
||||
}
|
||||
}
|
||||
|
||||
func NewClientWithProxy(proxyUrl string) (*Client, error) {
|
||||
u, err := url.Parse(proxyUrl)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
p := http.ProxyURL(u)
|
||||
httpClient := &http.Client{
|
||||
Transport: &http.Transport{
|
||||
Proxy: p,
|
||||
},
|
||||
}
|
||||
return &Client{
|
||||
client: httpClient,
|
||||
}, nil
|
||||
}
|
||||
38
pkg/tfile/opts.go
Normal file
38
pkg/tfile/opts.go
Normal file
@@ -0,0 +1,38 @@
|
||||
package tfile
|
||||
|
||||
import "github.com/gotd/td/tg"
|
||||
|
||||
type TGFileOptions func(*tgFile)
|
||||
|
||||
func WithMessage(msg *tg.Message) TGFileOptions {
|
||||
return func(f *tgFile) {
|
||||
f.message = msg
|
||||
}
|
||||
}
|
||||
func WithName(name string) TGFileOptions {
|
||||
return func(f *tgFile) {
|
||||
f.name = name
|
||||
}
|
||||
}
|
||||
|
||||
func WithNameIfEmpty(name string) TGFileOptions {
|
||||
return func(f *tgFile) {
|
||||
if f.name == "" {
|
||||
f.name = name
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func WithSize(size int64) TGFileOptions {
|
||||
return func(f *tgFile) {
|
||||
f.size = size
|
||||
}
|
||||
}
|
||||
|
||||
func WithSizeIfZero(size int64) TGFileOptions {
|
||||
return func(f *tgFile) {
|
||||
if f.size == 0 {
|
||||
f.size = size
|
||||
}
|
||||
}
|
||||
}
|
||||
126
pkg/tfile/tgfile.go
Normal file
126
pkg/tfile/tgfile.go
Normal file
@@ -0,0 +1,126 @@
|
||||
package tfile
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/gotd/td/tg"
|
||||
)
|
||||
|
||||
type TGFile interface {
|
||||
Location() tg.InputFileLocationClass
|
||||
Size() int64
|
||||
Name() string
|
||||
}
|
||||
|
||||
type TGFileMessage interface {
|
||||
TGFile
|
||||
Message() *tg.Message
|
||||
}
|
||||
|
||||
type tgFile struct {
|
||||
location tg.InputFileLocationClass
|
||||
size int64
|
||||
name string
|
||||
message *tg.Message
|
||||
}
|
||||
|
||||
func (f *tgFile) Location() tg.InputFileLocationClass {
|
||||
return f.location
|
||||
}
|
||||
|
||||
func (f *tgFile) Size() int64 {
|
||||
return f.size
|
||||
}
|
||||
|
||||
func (f *tgFile) Name() string {
|
||||
return f.name
|
||||
}
|
||||
|
||||
func (f *tgFile) Message() *tg.Message {
|
||||
return f.message
|
||||
}
|
||||
|
||||
func NewTGFile(location tg.InputFileLocationClass, size int64, name string,
|
||||
opts ...TGFileOptions,
|
||||
) TGFile {
|
||||
f := &tgFile{
|
||||
location: location,
|
||||
size: size,
|
||||
name: name,
|
||||
}
|
||||
for _, opt := range opts {
|
||||
opt(f)
|
||||
}
|
||||
return f
|
||||
}
|
||||
|
||||
func FromMedia(media tg.MessageMediaClass, opts ...TGFileOptions) (TGFile, error) {
|
||||
switch m := media.(type) {
|
||||
case *tg.MessageMediaDocument:
|
||||
document, ok := m.Document.AsNotEmpty()
|
||||
if !ok {
|
||||
return nil, errors.New("document is empty")
|
||||
}
|
||||
fileName := ""
|
||||
for _, attribute := range document.Attributes {
|
||||
if name, ok := attribute.(*tg.DocumentAttributeFilename); ok {
|
||||
fileName = name.GetFileName()
|
||||
break
|
||||
}
|
||||
}
|
||||
file := &tgFile{
|
||||
location: document.AsInputDocumentFileLocation(),
|
||||
size: document.Size,
|
||||
name: fileName,
|
||||
}
|
||||
for _, opt := range opts {
|
||||
opt(file)
|
||||
}
|
||||
return file, nil
|
||||
case *tg.MessageMediaPhoto:
|
||||
photo, ok := m.Photo.AsNotEmpty()
|
||||
if !ok {
|
||||
return nil, errors.New("photo is empty")
|
||||
}
|
||||
sizes := photo.Sizes
|
||||
if len(sizes) == 0 {
|
||||
return nil, errors.New("photo sizes are empty")
|
||||
}
|
||||
photoSize := sizes[len(sizes)-1]
|
||||
size, ok := photoSize.AsNotEmpty()
|
||||
if !ok {
|
||||
return nil, errors.New("photo size is empty")
|
||||
}
|
||||
location := new(tg.InputPhotoFileLocation)
|
||||
location.ID = photo.GetID()
|
||||
location.AccessHash = photo.GetAccessHash()
|
||||
location.FileReference = photo.GetFileReference()
|
||||
location.ThumbSize = size.GetType()
|
||||
fileName := fmt.Sprintf("photo_%s_%d.jpg", time.Now().Format("2006-01-02_15-04-05"), photo.GetID())
|
||||
file := &tgFile{
|
||||
location: location,
|
||||
size: 0,
|
||||
name: fileName,
|
||||
}
|
||||
for _, opt := range opts {
|
||||
opt(file)
|
||||
}
|
||||
return file, nil
|
||||
}
|
||||
return nil, fmt.Errorf("unsupported media type: %T", media)
|
||||
}
|
||||
|
||||
func FromMediaMessage(media tg.MessageMediaClass, msg *tg.Message, opts ...TGFileOptions) (TGFileMessage, error) {
|
||||
file, err := FromMedia(media, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &tgFile{
|
||||
location: file.Location(),
|
||||
size: file.Size(),
|
||||
name: file.Name(),
|
||||
message: msg,
|
||||
}, nil
|
||||
}
|
||||
Reference in New Issue
Block a user