优化sdk

This commit is contained in:
sky22333
2026-03-17 23:21:27 +08:00
parent da0586b5ca
commit af647f979b
10 changed files with 260 additions and 65 deletions

View File

@@ -96,9 +96,20 @@ func main() {
}
```
`target_id` 可选;不传时 SDK 会自动按 `target_type` 回填最近采集目标。
`target_type``target_id` 都不传时SDK 会使用最近一次采集到的目标类型与 ID。
如需在短生命周期任务中关闭采集器,可使用:
```go
client, err := qqbot.NewWithOptions(cfg, qqbot.ClientOptions{
StartCollector: false,
})
```
## 4. 目标采集
启动服务后,用自己的 QQ 给机器人发消息,系统会自动采集目标并写入 `data/known_targets.json`
启动服务后,用自己的 QQ 给机器人发消息,系统会自动采集目标并写入 `targets.file_path` 对应的文件(默认 `data/targets.json`
可通过 `GET /api/v1/targets` 查看。
## 5. 常用命令

View File

@@ -10,10 +10,8 @@ import (
"syscall"
"github.com/sky22333/qqbot/config"
"github.com/sky22333/qqbot/internal/collector"
"github.com/sky22333/qqbot/internal/bootstrap"
"github.com/sky22333/qqbot/internal/httpserver"
"github.com/sky22333/qqbot/internal/notifier"
"github.com/sky22333/qqbot/internal/targets"
)
func main() {
@@ -31,30 +29,13 @@ func main() {
}
logger := slog.New(slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{Level: logLevel}))
flushInterval, err := cfg.TargetsFlushInterval()
components, err := bootstrap.New(cfg, logger, bootstrap.Options{StartCollector: true})
if err != nil {
panic(err)
}
targetStore, err := targets.NewStore(cfg.Targets.FilePath, cfg.Targets.MaxRecords, flushInterval)
if err != nil {
panic(err)
}
targetCollector, err := collector.New(cfg, logger, targetStore)
if err != nil {
panic(err)
}
targetCollector.Start()
logger.Info("采集器已启动")
n, err := notifier.New(cfg, logger)
if err != nil {
panic(err)
}
n.SetTargetStore(targetStore)
server, err := httpserver.New(cfg, logger, n, targetStore)
server, err := httpserver.New(cfg, logger, components.Notifier, components.Targets)
if err != nil {
panic(err)
}
@@ -83,8 +64,6 @@ func main() {
ctx, cancel := context.WithTimeout(context.Background(), shutdownTimeout)
defer cancel()
_ = server.Shutdown(ctx)
targetCollector.Stop()
n.Close()
_ = targetStore.Close()
components.Close()
logger.Info("服务已退出")
}

View File

@@ -107,26 +107,57 @@ func Default() Config {
}
func Load(path string) (Config, error) {
return loadWithValidator(path, Config.ValidateForServer)
}
func LoadSDK(path string) (Config, error) {
return loadWithValidator(path, Config.ValidateForSDK)
}
func loadWithValidator(path string, validator func(Config) error) (Config, error) {
cfg := Default()
if _, err := toml.DecodeFile(path, &cfg); err != nil {
return Config{}, err
}
if err := cfg.Validate(); err != nil {
if err := validator(cfg); err != nil {
return Config{}, err
}
return cfg, nil
}
func (c Config) Validate() error {
func (c Config) ValidateForSDK() error {
return c.validateCommon()
}
func (c Config) ValidateForServer() error {
if err := c.validateCommon(); err != nil {
return err
}
if strings.TrimSpace(c.Server.ListenAddr) == "" {
return errors.New("server.listen_addr 不能为空")
}
if c.Server.MaxBodyBytes <= 0 {
return errors.New("server.max_body_bytes 必须大于 0")
}
if _, err := c.ReadTimeout(); err != nil {
return fmt.Errorf("server.read_timeout 无效: %w", err)
}
if _, err := c.WriteTimeout(); err != nil {
return fmt.Errorf("server.write_timeout 无效: %w", err)
}
if _, err := c.ShutdownTimeout(); err != nil {
return fmt.Errorf("server.shutdown_timeout 无效: %w", err)
}
return nil
}
func (c Config) validateCommon() error {
if strings.TrimSpace(c.QQBot.AppID) == "" {
return errors.New("qqbot.app_id 不能为空")
}
if strings.TrimSpace(c.QQBot.ClientSecret) == "" {
return errors.New("qqbot.client_secret 不能为空")
}
if strings.TrimSpace(c.Server.ListenAddr) == "" {
return errors.New("server.listen_addr 不能为空")
}
if c.Dispatch.QueueSize <= 0 {
return errors.New("dispatch.queue_size 必须大于 0")
}
@@ -139,21 +170,9 @@ func (c Config) Validate() error {
if c.Dispatch.RetryBackoffMS <= 0 {
return errors.New("dispatch.retry_backoff_ms 必须大于 0")
}
if c.Server.MaxBodyBytes <= 0 {
return errors.New("server.max_body_bytes 必须大于 0")
}
if _, err := c.RequestTimeout(); err != nil {
return fmt.Errorf("qqbot.request_timeout 无效: %w", err)
}
if _, err := c.ReadTimeout(); err != nil {
return fmt.Errorf("server.read_timeout 无效: %w", err)
}
if _, err := c.WriteTimeout(); err != nil {
return fmt.Errorf("server.write_timeout 无效: %w", err)
}
if _, err := c.ShutdownTimeout(); err != nil {
return fmt.Errorf("server.shutdown_timeout 无效: %w", err)
}
if _, err := c.EnqueueTimeout(); err != nil {
return fmt.Errorf("dispatch.enqueue_timeout 无效: %w", err)
}

30
config/config_test.go Normal file
View File

@@ -0,0 +1,30 @@
package config
import (
"os"
"path/filepath"
"testing"
)
func TestLoadSDKAllowsServerOnlyConstraintsMissing(t *testing.T) {
path := filepath.Join(t.TempDir(), "sdk.toml")
content := `
[qqbot]
app_id = "123"
client_secret = "secret"
[server]
listen_addr = ""
max_body_bytes = 0
`
if err := os.WriteFile(path, []byte(content), 0o644); err != nil {
t.Fatalf("写入配置文件失败: %v", err)
}
if _, err := LoadSDK(path); err != nil {
t.Fatalf("LoadSDK 不应因 server 字段失败: %v", err)
}
if _, err := Load(path); err == nil {
t.Fatalf("Load 应校验 server 字段并返回错误")
}
}

View File

@@ -1,11 +0,0 @@
package qqbot
import "github.com/sky22333/qqbot/config"
func LoadConfig(path string) (Config, error) {
return config.Load(path)
}
func DefaultConfig() Config {
return config.Default()
}

View File

@@ -0,0 +1,68 @@
package bootstrap
import (
"log/slog"
"github.com/sky22333/qqbot/config"
"github.com/sky22333/qqbot/internal/collector"
"github.com/sky22333/qqbot/internal/notifier"
"github.com/sky22333/qqbot/internal/targets"
)
type Options struct {
StartCollector bool
}
type Components struct {
Notifier *notifier.Notifier
Targets *targets.Store
Collector *collector.Collector
}
func New(cfg config.Config, logger *slog.Logger, opts Options) (*Components, error) {
flushInterval, err := cfg.TargetsFlushInterval()
if err != nil {
return nil, err
}
targetStore, err := targets.NewStore(cfg.Targets.FilePath, cfg.Targets.MaxRecords, flushInterval)
if err != nil {
return nil, err
}
n, err := notifier.New(cfg, logger)
if err != nil {
_ = targetStore.Close()
return nil, err
}
n.SetTargetStore(targetStore)
c := &Components{
Notifier: n,
Targets: targetStore,
}
if !opts.StartCollector {
return c, nil
}
targetCollector, err := collector.New(cfg, logger, targetStore)
if err != nil {
n.Close()
_ = targetStore.Close()
return nil, err
}
targetCollector.Start()
c.Collector = targetCollector
return c, nil
}
func (c *Components) Close() {
if c == nil {
return
}
if c.Collector != nil {
c.Collector.Stop()
}
if c.Notifier != nil {
c.Notifier.Close()
}
if c.Targets != nil {
_ = c.Targets.Close()
}
}

View File

@@ -0,0 +1,64 @@
package notifier
import (
"context"
"path/filepath"
"strings"
"testing"
"time"
"github.com/sky22333/qqbot/config"
"github.com/sky22333/qqbot/internal/targets"
"github.com/sky22333/qqbot/message"
)
func TestEnqueueFillTargetFromLatestByType(t *testing.T) {
cfg := config.Default()
cfg.Dispatch.Workers = 0
cfg.Targets.FilePath = filepath.Join(t.TempDir(), "targets.json")
store, err := targets.NewStore(cfg.Targets.FilePath, cfg.Targets.MaxRecords, 10*time.Millisecond)
if err != nil {
t.Fatalf("创建目标存储失败: %v", err)
}
defer store.Close()
if err := store.Upsert(message.TargetC2C, "user-001", "m1", "hello"); err != nil {
t.Fatalf("写入目标失败: %v", err)
}
n, err := New(cfg, nil)
if err != nil {
t.Fatalf("创建 notifier 失败: %v", err)
}
defer n.Close()
n.SetTargetStore(store)
_, err = n.Enqueue(context.Background(), message.PushRequest{
TargetType: message.TargetC2C,
Content: "通知",
})
if err != nil {
t.Fatalf("应自动补全 target_id但返回错误: %v", err)
}
}
func TestEnqueueWithoutStoreReturnsTargetIDError(t *testing.T) {
cfg := config.Default()
cfg.Dispatch.Workers = 0
n, err := New(cfg, nil)
if err != nil {
t.Fatalf("创建 notifier 失败: %v", err)
}
defer n.Close()
_, err = n.Enqueue(context.Background(), message.PushRequest{
TargetType: message.TargetC2C,
Content: "通知",
})
if err == nil {
t.Fatalf("预期返回 target_id 不能为空")
}
if !strings.Contains(err.Error(), "target_id 不能为空") {
t.Fatalf("错误信息不符合预期: %v", err)
}
}

30
sdk.go
View File

@@ -6,27 +6,37 @@ import (
"os"
"github.com/sky22333/qqbot/config"
"github.com/sky22333/qqbot/internal/notifier"
"github.com/sky22333/qqbot/internal/bootstrap"
)
type Client struct {
notifier *notifier.Notifier
components *bootstrap.Components
}
type ClientOptions struct {
StartCollector bool
}
func New(cfg Config) (*Client, error) {
if err := cfg.Validate(); err != nil {
return NewWithOptions(cfg, ClientOptions{
StartCollector: true,
})
}
func NewWithOptions(cfg Config, opts ClientOptions) (*Client, error) {
if err := cfg.ValidateForSDK(); err != nil {
return nil, err
}
logger := slog.New(slog.NewJSONHandler(os.Stdout, nil))
n, err := notifier.New(cfg, logger)
components, err := bootstrap.New(cfg, logger, bootstrap.Options{StartCollector: opts.StartCollector})
if err != nil {
return nil, err
}
return &Client{notifier: n}, nil
return &Client{components: components}, nil
}
func NewFromConfigFile(path string) (*Client, error) {
cfg, err := config.Load(path)
cfg, err := config.LoadSDK(path)
if err != nil {
return nil, err
}
@@ -34,17 +44,17 @@ func NewFromConfigFile(path string) (*Client, error) {
}
func (c *Client) Send(ctx context.Context, req PushRequest) (PushResult, error) {
return c.notifier.Send(ctx, req)
return c.components.Notifier.Send(ctx, req)
}
func (c *Client) Enqueue(ctx context.Context, req PushRequest) (string, error) {
return c.notifier.Enqueue(ctx, req)
return c.components.Notifier.Enqueue(ctx, req)
}
func (c *Client) GetStatus(requestID string) (DeliveryStatus, bool) {
return c.notifier.GetStatus(requestID)
return c.components.Notifier.GetStatus(requestID)
}
func (c *Client) Close() {
c.notifier.Close()
c.components.Close()
}

23
sdk_options_test.go Normal file
View File

@@ -0,0 +1,23 @@
package qqbot
import (
"path/filepath"
"testing"
"github.com/sky22333/qqbot/config"
)
func TestNewWithOptionsWithoutCollector(t *testing.T) {
cfg := config.Default()
cfg.QQBot.AppID = "123"
cfg.QQBot.ClientSecret = "secret"
cfg.Server.ListenAddr = ""
cfg.Server.MaxBodyBytes = 0
cfg.Targets.FilePath = filepath.Join(t.TempDir(), "targets.json")
client, err := NewWithOptions(cfg, ClientOptions{StartCollector: false})
if err != nil {
t.Fatalf("创建客户端失败: %v", err)
}
client.Close()
}

View File

@@ -6,6 +6,8 @@ import (
"strings"
"testing"
"time"
"github.com/sky22333/qqbot/config"
)
func TestSendNotificationWithConfig(t *testing.T) {
@@ -19,7 +21,7 @@ func TestSendNotificationWithConfig(t *testing.T) {
configPath = "configs/config.toml"
}
cfg, err := LoadConfig(configPath)
cfg, err := config.LoadSDK(configPath)
if err != nil {
t.Fatalf("加载配置失败: %v", err)
}