refactor: complete core features
This commit is contained in:
@@ -14,11 +14,9 @@ import (
|
||||
"github.com/celestix/gotgproto/dispatcher/handlers"
|
||||
"github.com/celestix/gotgproto/dispatcher/handlers/filters"
|
||||
"github.com/celestix/gotgproto/ext"
|
||||
"github.com/krau/SaveAny-Bot/common"
|
||||
"github.com/krau/SaveAny-Bot/config"
|
||||
"github.com/krau/SaveAny-Bot/dao"
|
||||
"github.com/krau/SaveAny-Bot/logger"
|
||||
"github.com/krau/SaveAny-Bot/model"
|
||||
"github.com/krau/SaveAny-Bot/queue"
|
||||
"github.com/krau/SaveAny-Bot/storage"
|
||||
"github.com/krau/SaveAny-Bot/types"
|
||||
@@ -41,7 +39,8 @@ const noPermissionText string = `
|
||||
`
|
||||
|
||||
func checkPermission(ctx *ext.Context, update *ext.Update) error {
|
||||
if !slice.Contain(config.Cfg.Telegram.Admins, update.EffectiveUser().ID) {
|
||||
userID := update.GetUserChat().GetID()
|
||||
if !slice.Contain(config.Cfg.Telegram.Admins, userID) {
|
||||
ctx.Reply(update, noPermissionText, nil)
|
||||
return dispatcher.EndGroups
|
||||
}
|
||||
@@ -49,7 +48,7 @@ func checkPermission(ctx *ext.Context, update *ext.Update) error {
|
||||
}
|
||||
|
||||
func start(ctx *ext.Context, update *ext.Update) error {
|
||||
if err := dao.CreateUser(update.EffectiveUser().ID); err != nil {
|
||||
if err := dao.CreateUser(update.GetUserChat().GetID()); err != nil {
|
||||
logger.L.Errorf("Failed to create user: %s", err)
|
||||
return dispatcher.EndGroups
|
||||
}
|
||||
@@ -74,7 +73,7 @@ func help(ctx *ext.Context, update *ext.Update) error {
|
||||
}
|
||||
|
||||
func silent(ctx *ext.Context, update *ext.Update) error {
|
||||
user, err := dao.GetUserByUserID(update.EffectiveUser().ID)
|
||||
user, err := dao.GetUserByUserID(update.GetUserChat().GetID())
|
||||
if err != nil {
|
||||
logger.L.Errorf("Failed to get user: %s", err)
|
||||
return dispatcher.EndGroups
|
||||
@@ -116,7 +115,7 @@ func setDefaultStorage(ctx *ext.Context, update *ext.Update) error {
|
||||
ctx.Reply(update, "存储位置不存在", nil)
|
||||
return dispatcher.EndGroups
|
||||
}
|
||||
user, err := dao.GetUserByUserID(update.EffectiveUser().ID)
|
||||
user, err := dao.GetUserByUserID(update.GetUserChat().GetID())
|
||||
if err != nil {
|
||||
logger.L.Errorf("Failed to get user: %s", err)
|
||||
return dispatcher.EndGroups
|
||||
@@ -145,7 +144,7 @@ func handleFileMessage(ctx *ext.Context, update *ext.Update) error {
|
||||
return dispatcher.EndGroups
|
||||
}
|
||||
|
||||
user, err := dao.GetUserByUserID(update.EffectiveUser().ID)
|
||||
user, err := dao.GetUserByUserID(update.GetUserChat().GetID())
|
||||
if err != nil {
|
||||
logger.L.Errorf("Failed to get user: %s", err)
|
||||
return dispatcher.EndGroups
|
||||
@@ -157,7 +156,7 @@ func handleFileMessage(ctx *ext.Context, update *ext.Update) error {
|
||||
return dispatcher.EndGroups
|
||||
}
|
||||
media := update.EffectiveMessage.Media
|
||||
file, err := common.FileFromMedia(media)
|
||||
file, err := FileFromMedia(media)
|
||||
if err != nil {
|
||||
logger.L.Errorf("Failed to get file from media: %s", err)
|
||||
ctx.Reply(update, "无法获取文件", nil)
|
||||
@@ -168,7 +167,7 @@ func handleFileMessage(ctx *ext.Context, update *ext.Update) error {
|
||||
return dispatcher.EndGroups
|
||||
}
|
||||
|
||||
if err := dao.AddReceivedFile(&model.ReceivedFile{
|
||||
if err := dao.AddReceivedFile(&types.ReceivedFile{
|
||||
Processing: false,
|
||||
FileName: file.FileName,
|
||||
ChatID: update.EffectiveChat().GetID(),
|
||||
@@ -210,7 +209,7 @@ func handleFileMessage(ctx *ext.Context, update *ext.Update) error {
|
||||
queue.AddTask(types.Task{
|
||||
Ctx: ctx,
|
||||
Status: types.Pending,
|
||||
FileName: file.FileName,
|
||||
File: file,
|
||||
Storage: types.StorageType(user.DefaultStorage),
|
||||
ChatID: update.EffectiveChat().GetID(),
|
||||
ReplyMessageID: msg.ID,
|
||||
@@ -234,17 +233,29 @@ func AddToQueue(ctx *ext.Context, update *ext.Update) error {
|
||||
ctx.AnswerCallback(&tg.MessagesSetBotCallbackAnswerRequest{
|
||||
QueryID: update.CallbackQuery.QueryID,
|
||||
Alert: true,
|
||||
Message: "无法添加到队列",
|
||||
Message: "查询记录失败",
|
||||
CacheTime: 5,
|
||||
})
|
||||
return dispatcher.EndGroups
|
||||
}
|
||||
file, err := FileFromMessage(ctx, Client, record.ChatID, record.MessageID)
|
||||
if err != nil {
|
||||
logger.L.Errorf("Failed to get file from message: %s", err)
|
||||
ctx.AnswerCallback(&tg.MessagesSetBotCallbackAnswerRequest{
|
||||
QueryID: update.CallbackQuery.QueryID,
|
||||
Alert: true,
|
||||
Message: "获取消息文件失败",
|
||||
CacheTime: 5,
|
||||
})
|
||||
return dispatcher.EndGroups
|
||||
}
|
||||
|
||||
queue.AddTask(types.Task{
|
||||
Ctx: ctx,
|
||||
Status: types.Pending,
|
||||
FileName: record.FileName,
|
||||
File: file,
|
||||
Storage: types.StorageType(args[2]),
|
||||
ChatID: update.EffectiveChat().GetID(),
|
||||
ChatID: record.ChatID,
|
||||
ReplyMessageID: record.ReplyMessageID,
|
||||
MessageID: record.MessageID,
|
||||
})
|
||||
|
||||
75
bot/utils.go
75
bot/utils.go
@@ -1,15 +1,20 @@
|
||||
package bot
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/celestix/gotgproto"
|
||||
"github.com/celestix/gotgproto/dispatcher"
|
||||
"github.com/celestix/gotgproto/types"
|
||||
tgTypes "github.com/celestix/gotgproto/types"
|
||||
"github.com/gotd/td/tg"
|
||||
"github.com/krau/SaveAny-Bot/common"
|
||||
"github.com/krau/SaveAny-Bot/logger"
|
||||
"github.com/krau/SaveAny-Bot/storage"
|
||||
"github.com/krau/SaveAny-Bot/types"
|
||||
)
|
||||
|
||||
func supportedMediaFilter(m *types.Message) (bool, error) {
|
||||
func supportedMediaFilter(m *tgTypes.Message) (bool, error) {
|
||||
if not := m.Media == nil; not {
|
||||
return false, dispatcher.EndGroups
|
||||
}
|
||||
@@ -69,3 +74,69 @@ func getAddTaskMarkup(messageID int) *tg.ReplyInlineMarkup {
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func FileFromMedia(media tg.MessageMediaClass) (*types.File, error) {
|
||||
switch media := media.(type) {
|
||||
case *tg.MessageMediaDocument:
|
||||
document, ok := media.Document.AsNotEmpty()
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unexpected type %T", media)
|
||||
}
|
||||
var fileName string
|
||||
for _, attribute := range document.Attributes {
|
||||
if name, ok := attribute.(*tg.DocumentAttributeFilename); ok {
|
||||
fileName = name.FileName
|
||||
break
|
||||
}
|
||||
}
|
||||
return &types.File{
|
||||
Location: document.AsInputDocumentFileLocation(),
|
||||
FileSize: document.Size,
|
||||
FileName: fileName,
|
||||
MimeType: document.MimeType,
|
||||
ID: document.ID,
|
||||
}, nil
|
||||
}
|
||||
return nil, fmt.Errorf("unexpected type %T", media)
|
||||
}
|
||||
|
||||
func FileFromMessage(ctx context.Context, client *gotgproto.Client, chatID int64, messageID int) (*types.File, error) {
|
||||
key := fmt.Sprintf("file:%d:%d", chatID, messageID)
|
||||
logger.L.Debugf("Getting file: %s", key)
|
||||
var cachedFile types.File
|
||||
err := common.Cache.Get(key, &cachedFile)
|
||||
if err == nil {
|
||||
return &cachedFile, nil
|
||||
}
|
||||
|
||||
message, err := GetTGMessage(ctx, client, messageID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
file, err := FileFromMedia(message.Media)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := common.Cache.Set(key, file, 3600); err != nil {
|
||||
logger.L.Errorf("Failed to cache file: %s", err)
|
||||
}
|
||||
return file, nil
|
||||
}
|
||||
|
||||
func GetTGMessage(ctx context.Context, client *gotgproto.Client, messageID int) (*tg.Message, error) {
|
||||
logger.L.Debugf("Fetching message: %d", messageID)
|
||||
res, err := client.API().MessagesGetMessages(ctx, []tg.InputMessageClass{
|
||||
&tg.InputMessageID{
|
||||
ID: messageID,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
messages := res.(*tg.MessagesMessages)
|
||||
msg := messages.Messages[0]
|
||||
if _, ok := msg.(*tg.Message); !ok {
|
||||
return nil, fmt.Errorf("unexpected type %T, this file may be deleted", msg)
|
||||
}
|
||||
return msg.(*tg.Message), nil
|
||||
}
|
||||
|
||||
63
common/cache.go
Normal file
63
common/cache.go
Normal file
@@ -0,0 +1,63 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/gob"
|
||||
"sync"
|
||||
|
||||
"github.com/coocood/freecache"
|
||||
"github.com/gotd/td/tg"
|
||||
"github.com/krau/SaveAny-Bot/types"
|
||||
)
|
||||
|
||||
type CommonCache struct {
|
||||
cache *freecache.Cache
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
var Cache *CommonCache
|
||||
|
||||
func initCache() {
|
||||
gob.Register(types.File{})
|
||||
gob.Register(tg.InputDocumentFileLocation{})
|
||||
Cache = &CommonCache{cache: freecache.NewCache(10 * 1024 * 1024)}
|
||||
}
|
||||
|
||||
func GetCache() *CommonCache {
|
||||
return Cache
|
||||
}
|
||||
|
||||
func (c *CommonCache) Get(key string, value *types.File) error {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
data, err := Cache.cache.Get([]byte(key))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
dec := gob.NewDecoder(bytes.NewReader(data))
|
||||
err = dec.Decode(&value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *CommonCache) Set(key string, value *types.File, expireSeconds int) error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
var buf bytes.Buffer
|
||||
enc := gob.NewEncoder(&buf)
|
||||
err := enc.Encode(value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
Cache.cache.Set([]byte(key), buf.Bytes(), expireSeconds)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *CommonCache) Delete(key string) error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
Cache.cache.Del([]byte(key))
|
||||
return nil
|
||||
}
|
||||
@@ -2,4 +2,5 @@ package common
|
||||
|
||||
func Init() {
|
||||
initClient()
|
||||
initCache()
|
||||
}
|
||||
|
||||
@@ -1,35 +0,0 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/gotd/td/tg"
|
||||
"github.com/krau/SaveAny-Bot/logger"
|
||||
"github.com/krau/SaveAny-Bot/types"
|
||||
)
|
||||
|
||||
func FileFromMedia(media tg.MessageMediaClass) (*types.File, error) {
|
||||
logger.L.Debug("FileFromMedia")
|
||||
switch media := media.(type) {
|
||||
case *tg.MessageMediaDocument:
|
||||
document, ok := media.Document.AsNotEmpty()
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unexpected type %T", media)
|
||||
}
|
||||
var fileName string
|
||||
for _, attribute := range document.Attributes {
|
||||
if name, ok := attribute.(*tg.DocumentAttributeFilename); ok {
|
||||
fileName = name.FileName
|
||||
break
|
||||
}
|
||||
}
|
||||
return &types.File{
|
||||
Location: document.AsInputDocumentFileLocation(),
|
||||
FileSize: document.Size,
|
||||
FileName: fileName,
|
||||
MimeType: document.MimeType,
|
||||
ID: document.ID,
|
||||
}, nil
|
||||
}
|
||||
return nil, fmt.Errorf("unexpected type %T", media)
|
||||
}
|
||||
94
core/core.go
94
core/core.go
@@ -3,57 +3,73 @@ package core
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/celestix/gotgproto/ext"
|
||||
"github.com/gotd/td/tg"
|
||||
"github.com/krau/SaveAny-Bot/bot"
|
||||
"github.com/krau/SaveAny-Bot/common"
|
||||
"github.com/krau/SaveAny-Bot/config"
|
||||
"github.com/krau/SaveAny-Bot/logger"
|
||||
"github.com/krau/SaveAny-Bot/queue"
|
||||
"github.com/krau/SaveAny-Bot/storage"
|
||||
"github.com/krau/SaveAny-Bot/types"
|
||||
)
|
||||
|
||||
func processPendingTask(task types.Task) error {
|
||||
logger.L.Debugf("Start processing task: %s", task.FileName)
|
||||
time.Sleep(10 * time.Second)
|
||||
logger.L.Debugf("Task done: %s", task.FileName)
|
||||
func processPendingTask(task *types.Task) error {
|
||||
logger.L.Debugf("Start processing task: %s", task.String())
|
||||
|
||||
// os.MkdirAll(config.Cfg.Temp.BasePath, os.ModePerm)
|
||||
os.MkdirAll(config.Cfg.Temp.BasePath, os.ModePerm)
|
||||
|
||||
// message, err := bot.Client.GetMessageByID(task.ChatID, task.MessageID)
|
||||
// if err != nil {
|
||||
// return err
|
||||
// }
|
||||
// logger.L.Debugf("Start downloading file: %s", task.FileName)
|
||||
// bot.Client.EditMessage(task.ChatID, task.ReplyMessageID, "正在下载文件...")
|
||||
// dest, err := message.Download(&telegram.DownloadOptions{
|
||||
// FileName: common.GetCacheFilePath(task.FileName),
|
||||
// Threads: config.Cfg.Threads,
|
||||
// ChunkSize: config.Cfg.ChunkSize,
|
||||
// // ProgressCallback: func(totalBytes, downloadedBytes int64) {},
|
||||
// })
|
||||
// if err != nil {
|
||||
// return err
|
||||
// }
|
||||
logger.L.Debugf("Start downloading file: %s", task.String())
|
||||
|
||||
// defer func() {
|
||||
// if config.Cfg.Temp.CacheTTL > 0 {
|
||||
// common.RmFileAfter(dest, time.Duration(config.Cfg.Temp.CacheTTL)*time.Second)
|
||||
// } else {
|
||||
// if err := os.Remove(dest); err != nil {
|
||||
// logger.L.Errorf("Failed to purge file: %s", err)
|
||||
// }
|
||||
// }
|
||||
// }()
|
||||
// if task.StoragePath == "" {
|
||||
// task.StoragePath = task.FileName
|
||||
// }
|
||||
task.Ctx.(*ext.Context).EditMessage(task.ChatID, &tg.MessagesEditMessageRequest{
|
||||
Message: "开始下载文件...",
|
||||
ID: task.ReplyMessageID,
|
||||
})
|
||||
|
||||
// bot.Client.EditMessage(task.ChatID, task.ReplyMessageID, "下载完成, 正在转存文件...")
|
||||
// if err := storage.Save(task.Storage, task.Ctx, dest, task.StoragePath); err != nil {
|
||||
// return err
|
||||
// }
|
||||
readCloser, err := NewTelegramReader(task.Ctx, bot.Client, task.File.Location, 0, task.File.FileSize-1, task.File.FileSize)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Failed to create reader: %w", err)
|
||||
}
|
||||
defer readCloser.Close()
|
||||
|
||||
dest, err := os.Create(common.GetCacheFilePath(task.FileName()))
|
||||
if err != nil {
|
||||
return fmt.Errorf("Failed to create file: %w", err)
|
||||
}
|
||||
logger.L.Debug("Created file: ", dest.Name())
|
||||
defer dest.Close()
|
||||
|
||||
if _, err := io.CopyN(dest, readCloser, task.File.FileSize); err != nil {
|
||||
return fmt.Errorf("Failed to download file: %w", err)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if config.Cfg.Temp.CacheTTL > 0 {
|
||||
common.RmFileAfter(dest.Name(), time.Duration(config.Cfg.Temp.CacheTTL)*time.Second)
|
||||
} else {
|
||||
if err := os.Remove(dest.Name()); err != nil {
|
||||
logger.L.Errorf("Failed to purge file: %s", err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
if task.StoragePath == "" {
|
||||
task.StoragePath = task.File.FileName
|
||||
}
|
||||
|
||||
task.Ctx.(*ext.Context).EditMessage(task.ChatID, &tg.MessagesEditMessageRequest{
|
||||
Message: "下载完成, 正在转存文件...",
|
||||
ID: task.ReplyMessageID,
|
||||
})
|
||||
|
||||
if err := storage.Save(task.Storage, task.Ctx, dest.Name(), task.StoragePath); err != nil {
|
||||
return fmt.Errorf("Failed to save file: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -61,12 +77,12 @@ func worker(queue *queue.TaskQueue, semaphore chan struct{}) {
|
||||
for {
|
||||
semaphore <- struct{}{}
|
||||
task := queue.GetTask()
|
||||
logger.L.Debugf("Got task: %s", task.FileName)
|
||||
logger.L.Debugf("Got task: %s", task.String())
|
||||
|
||||
switch task.Status {
|
||||
case types.Pending:
|
||||
logger.L.Infof("Processing task: %s", task.String())
|
||||
if err := processPendingTask(task); err != nil {
|
||||
if err := processPendingTask(&task); err != nil {
|
||||
logger.L.Errorf("Failed to do task: %s", err)
|
||||
task.Error = err
|
||||
if errors.Is(err, context.Canceled) {
|
||||
@@ -97,7 +113,7 @@ func worker(queue *queue.TaskQueue, semaphore chan struct{}) {
|
||||
logger.L.Errorf("Unknown task status: %s", task.Status)
|
||||
}
|
||||
<-semaphore
|
||||
logger.L.Debugf("Task done: %s", task.FileName)
|
||||
logger.L.Debugf("Task done: %s", task.String())
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
135
core/reader.go
Normal file
135
core/reader.go
Normal file
@@ -0,0 +1,135 @@
|
||||
package core
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/celestix/gotgproto"
|
||||
"github.com/gotd/td/tg"
|
||||
)
|
||||
|
||||
type telegramReader struct {
|
||||
ctx context.Context
|
||||
client *gotgproto.Client
|
||||
location *tg.InputDocumentFileLocation
|
||||
start int64
|
||||
end int64
|
||||
next func() ([]byte, error)
|
||||
buffer []byte
|
||||
bytesread int64
|
||||
chunkSize int64
|
||||
i int64
|
||||
contentLength int64
|
||||
}
|
||||
|
||||
func (*telegramReader) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *telegramReader) Read(p []byte) (n int, err error) {
|
||||
if r.bytesread == r.contentLength {
|
||||
return 0, io.EOF
|
||||
}
|
||||
|
||||
if r.i >= int64(len(r.buffer)) {
|
||||
r.buffer, err = r.next()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if len(r.buffer) == 0 {
|
||||
r.next = r.partStream()
|
||||
r.buffer, err = r.next()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
}
|
||||
r.i = 0
|
||||
}
|
||||
n = copy(p, r.buffer[r.i:])
|
||||
r.i += int64(n)
|
||||
r.bytesread += int64(n)
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func NewTelegramReader(
|
||||
ctx context.Context,
|
||||
client *gotgproto.Client,
|
||||
location *tg.InputDocumentFileLocation,
|
||||
start int64,
|
||||
end int64,
|
||||
contentLength int64,
|
||||
) (io.ReadCloser, error) {
|
||||
|
||||
r := &telegramReader{
|
||||
ctx: ctx,
|
||||
location: location,
|
||||
client: client,
|
||||
start: start,
|
||||
end: end,
|
||||
chunkSize: int64(1024 * 1024),
|
||||
contentLength: contentLength,
|
||||
}
|
||||
|
||||
r.next = r.partStream()
|
||||
return r, nil
|
||||
}
|
||||
|
||||
func (r *telegramReader) chunk(offset int64, limit int64) ([]byte, error) {
|
||||
|
||||
req := &tg.UploadGetFileRequest{
|
||||
Offset: offset,
|
||||
Limit: int(limit),
|
||||
Location: r.location,
|
||||
}
|
||||
|
||||
res, err := r.client.API().UploadGetFile(r.ctx, req)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
switch result := res.(type) {
|
||||
case *tg.UploadFile:
|
||||
return result.Bytes, nil
|
||||
default:
|
||||
return nil, fmt.Errorf("unexpected type %T", r)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *telegramReader) partStream() func() ([]byte, error) {
|
||||
|
||||
start := r.start
|
||||
end := r.end
|
||||
offset := start - (start % r.chunkSize)
|
||||
|
||||
firstPartCut := start - offset
|
||||
lastPartCut := (end % r.chunkSize) + 1
|
||||
partCount := int((end - offset + r.chunkSize) / r.chunkSize)
|
||||
currentPart := 1
|
||||
|
||||
readData := func() ([]byte, error) {
|
||||
if currentPart > partCount {
|
||||
return make([]byte, 0), nil
|
||||
}
|
||||
res, err := r.chunk(offset, r.chunkSize)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(res) == 0 {
|
||||
return res, nil
|
||||
} else if partCount == 1 {
|
||||
res = res[firstPartCut:lastPartCut]
|
||||
} else if currentPart == 1 {
|
||||
res = res[firstPartCut:]
|
||||
} else if currentPart == partCount {
|
||||
res = res[:lastPartCut]
|
||||
}
|
||||
|
||||
currentPart++
|
||||
offset += r.chunkSize
|
||||
return res, nil
|
||||
}
|
||||
return readData
|
||||
}
|
||||
@@ -7,7 +7,7 @@ import (
|
||||
"github.com/glebarez/sqlite"
|
||||
"github.com/krau/SaveAny-Bot/config"
|
||||
"github.com/krau/SaveAny-Bot/logger"
|
||||
"github.com/krau/SaveAny-Bot/model"
|
||||
"github.com/krau/SaveAny-Bot/types"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
@@ -25,7 +25,7 @@ func Init() {
|
||||
os.Exit(1)
|
||||
}
|
||||
logger.L.Debug("Database connected")
|
||||
db.AutoMigrate(&model.ReceivedFile{}, &model.User{})
|
||||
db.AutoMigrate(&types.ReceivedFile{}, &types.User{})
|
||||
|
||||
for _, admin := range config.Cfg.Telegram.Admins {
|
||||
CreateUser(int64(admin))
|
||||
|
||||
12
dao/file.go
12
dao/file.go
@@ -1,13 +1,13 @@
|
||||
package dao
|
||||
|
||||
import "github.com/krau/SaveAny-Bot/model"
|
||||
import "github.com/krau/SaveAny-Bot/types"
|
||||
|
||||
func AddReceivedFile(receivedFile *model.ReceivedFile) error {
|
||||
func AddReceivedFile(receivedFile *types.ReceivedFile) error {
|
||||
return db.Create(receivedFile).Error
|
||||
}
|
||||
|
||||
func GetReceivedFileByChatAndMessageID(chatID int64, messageID int) (*model.ReceivedFile, error) {
|
||||
var receivedFile model.ReceivedFile
|
||||
func GetReceivedFileByChatAndMessageID(chatID int64, messageID int) (*types.ReceivedFile, error) {
|
||||
var receivedFile types.ReceivedFile
|
||||
err := db.Where("chat_id = ? AND message_id = ?", chatID, messageID).First(&receivedFile).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -15,10 +15,10 @@ func GetReceivedFileByChatAndMessageID(chatID int64, messageID int) (*model.Rece
|
||||
return &receivedFile, nil
|
||||
}
|
||||
|
||||
func UpdateReceivedFile(receivedFile *model.ReceivedFile) error {
|
||||
func UpdateReceivedFile(receivedFile *types.ReceivedFile) error {
|
||||
return db.Save(receivedFile).Error
|
||||
}
|
||||
|
||||
func DeleteReceivedFile(receivedFile *model.ReceivedFile) error {
|
||||
func DeleteReceivedFile(receivedFile *types.ReceivedFile) error {
|
||||
return db.Delete(receivedFile).Error
|
||||
}
|
||||
|
||||
10
dao/user.go
10
dao/user.go
@@ -1,22 +1,22 @@
|
||||
package dao
|
||||
|
||||
import (
|
||||
"github.com/krau/SaveAny-Bot/model"
|
||||
"github.com/krau/SaveAny-Bot/types"
|
||||
)
|
||||
|
||||
func CreateUser(userID int64) error {
|
||||
if _, err := GetUserByUserID(userID); err == nil {
|
||||
return nil
|
||||
}
|
||||
return db.Create(&model.User{UserID: userID}).Error
|
||||
return db.Create(&types.User{UserID: userID}).Error
|
||||
}
|
||||
|
||||
func GetUserByUserID(userID int64) (*model.User, error) {
|
||||
var user model.User
|
||||
func GetUserByUserID(userID int64) (*types.User, error) {
|
||||
var user types.User
|
||||
err := db.Where("user_id = ?", userID).First(&user).Error
|
||||
return &user, err
|
||||
}
|
||||
|
||||
func UpdateUser(user *model.User) error {
|
||||
func UpdateUser(user *types.User) error {
|
||||
return db.Save(user).Error
|
||||
}
|
||||
|
||||
2
go.mod
2
go.mod
@@ -19,6 +19,7 @@ require (
|
||||
require (
|
||||
github.com/AnimeKaizoku/cacher v1.0.2 // indirect
|
||||
github.com/cenkalti/backoff/v4 v4.3.0 // indirect
|
||||
github.com/cespare/xxhash/v2 v2.2.0 // indirect
|
||||
github.com/cloudflare/circl v1.5.0 // indirect
|
||||
github.com/dustin/go-humanize v1.0.1 // indirect
|
||||
github.com/glebarez/go-sqlite v1.22.0 // indirect
|
||||
@@ -69,6 +70,7 @@ require (
|
||||
|
||||
require (
|
||||
github.com/andybalholm/brotli v1.1.1 // indirect
|
||||
github.com/coocood/freecache v1.2.4
|
||||
github.com/duke-git/lancet/v2 v2.3.3
|
||||
github.com/fsnotify/fsnotify v1.8.0 // indirect
|
||||
github.com/glebarez/sqlite v1.11.0
|
||||
|
||||
5
go.sum
5
go.sum
@@ -8,8 +8,13 @@ github.com/celestix/gotgproto v1.0.0-beta18 h1:7884H/il+mzNreOQ4SqoMa4S5njt3UmGP
|
||||
github.com/celestix/gotgproto v1.0.0-beta18/go.mod h1:osZOlN5irPByA0+3IPsZOH+Ibs0tOMSKmIdgGYEBRgE=
|
||||
github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK343L8=
|
||||
github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE=
|
||||
github.com/cespare/xxhash/v2 v2.1.2/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
|
||||
github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44=
|
||||
github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
|
||||
github.com/cloudflare/circl v1.5.0 h1:hxIWksrX6XN5a1L2TI/h53AGPhNHoUBo+TD1ms9+pys=
|
||||
github.com/cloudflare/circl v1.5.0/go.mod h1:uddAzsPgqdMAYatqJ0lsjX1oECcQLIlRpzZh3pJrofs=
|
||||
github.com/coocood/freecache v1.2.4 h1:UdR6Yz/X1HW4fZOuH0Z94KwG851GWOSknua5VUbb/5M=
|
||||
github.com/coocood/freecache v1.2.4/go.mod h1:RBUWa/Cy+OHdfTGFEhEuE1pMCMX51Ncizj7rthiQ3vk=
|
||||
github.com/cpuguy83/go-md2man/v2 v2.0.4/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM=
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package model
|
||||
package types
|
||||
|
||||
import (
|
||||
"gorm.io/gorm"
|
||||
@@ -7,10 +7,10 @@ import (
|
||||
type ReceivedFile struct {
|
||||
gorm.Model
|
||||
Processing bool
|
||||
FileName string
|
||||
ChatID int64
|
||||
MessageID int
|
||||
ReplyMessageID int
|
||||
FileName string
|
||||
}
|
||||
|
||||
type User struct {
|
||||
@@ -2,6 +2,7 @@ package types
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/gotd/td/tg"
|
||||
)
|
||||
@@ -30,7 +31,7 @@ type Task struct {
|
||||
Ctx context.Context
|
||||
Error error
|
||||
Status TaskStatus
|
||||
FileName string
|
||||
File *File
|
||||
Storage StorageType
|
||||
StoragePath string
|
||||
|
||||
@@ -40,7 +41,11 @@ type Task struct {
|
||||
}
|
||||
|
||||
func (t Task) String() string {
|
||||
return t.FileName
|
||||
return fmt.Sprintf("[%d:%d]:%s", t.ChatID, t.MessageID, t.File.FileName)
|
||||
}
|
||||
|
||||
func (t Task) FileName() string {
|
||||
return t.File.FileName
|
||||
}
|
||||
|
||||
type File struct {
|
||||
|
||||
Reference in New Issue
Block a user