mirror of
https://github.com/krau/SaveAny-Bot.git
synced 2026-05-11 17:09:41 +08:00
fix: add VirtualHost option to S3StorageConfig and implement endpoint validation, close #150
This commit is contained in:
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/url"
|
||||
"path"
|
||||
"strings"
|
||||
|
||||
@@ -35,21 +36,52 @@ func (m *S3) Init(ctx context.Context, cfg storconfig.StorageConfig) error {
|
||||
|
||||
m.config = *s3Config
|
||||
m.logger = log.FromContext(ctx).WithPrefix(fmt.Sprintf("s3[%s]", m.config.Name))
|
||||
loadOpts := make([]config.LoadOptionsFunc, 0)
|
||||
if m.config.Region != "" {
|
||||
loadOpts = append(loadOpts, config.WithRegion(m.config.Region))
|
||||
}
|
||||
if endpoint := m.config.Endpoint; endpoint != "" {
|
||||
if !strings.HasPrefix(endpoint, "http://") && !strings.HasPrefix(endpoint, "https://") {
|
||||
if m.config.UseSSL {
|
||||
endpoint = "https://" + endpoint
|
||||
} else {
|
||||
endpoint = "http://" + endpoint
|
||||
}
|
||||
}
|
||||
|
||||
if _, err := url.Parse(endpoint); err != nil {
|
||||
return fmt.Errorf("invalid s3 endpoint %q: %w", m.config.Endpoint, err)
|
||||
}
|
||||
loadOpts = append(loadOpts, config.WithBaseEndpoint(endpoint))
|
||||
}
|
||||
loadOpts = append(loadOpts, config.WithCredentialsProvider(
|
||||
credentials.NewStaticCredentialsProvider(
|
||||
m.config.AccessKeyID,
|
||||
m.config.SecretAccessKey,
|
||||
"",
|
||||
),
|
||||
))
|
||||
awsCfg, err := config.LoadDefaultConfig(
|
||||
ctx,
|
||||
config.WithRegion(m.config.Region),
|
||||
config.WithCredentialsProvider(
|
||||
credentials.NewStaticCredentialsProvider(
|
||||
m.config.AccessKeyID,
|
||||
m.config.SecretAccessKey,
|
||||
"",
|
||||
),
|
||||
),
|
||||
func() []func(*config.LoadOptions) error {
|
||||
// wtf aws sdk
|
||||
// https://github.com/aws/aws-sdk-go-v2/issues/2193
|
||||
funcs := make([]func(*config.LoadOptions) error, 0)
|
||||
for _, fn := range loadOpts {
|
||||
funcs = append(funcs, fn)
|
||||
}
|
||||
return funcs
|
||||
}()...,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to load AWS config: %w", err)
|
||||
}
|
||||
m.client = s3.NewFromConfig(awsCfg)
|
||||
|
||||
m.client = s3.NewFromConfig(awsCfg, func(o *s3.Options) {
|
||||
// Path style: https://s3.amazonaws.com/mybucket/path/to/file.jpg
|
||||
// virtual hosted style: https://mybucket.s3.amazonaws.com/path/to/file.jpg
|
||||
o.UsePathStyle = !m.config.VirtualHost
|
||||
})
|
||||
|
||||
// Check if bucket exists
|
||||
_, err = m.client.HeadBucket(ctx, &s3.HeadBucketInput{
|
||||
|
||||
72
storage/s3/s3_test.go
Normal file
72
storage/s3/s3_test.go
Normal file
@@ -0,0 +1,72 @@
|
||||
package s3
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/charmbracelet/log"
|
||||
"github.com/johannesboyne/gofakes3"
|
||||
"github.com/johannesboyne/gofakes3/backend/s3mem"
|
||||
storconfig "github.com/krau/SaveAny-Bot/config/storage"
|
||||
)
|
||||
|
||||
func newTestContext(t *testing.T) context.Context {
|
||||
t.Helper()
|
||||
logger := log.NewWithOptions(nil, log.Options{ReportTimestamp: false})
|
||||
ctx := context.Background()
|
||||
return log.WithContext(ctx, logger)
|
||||
}
|
||||
|
||||
func newFakeS3(t *testing.T) (*S3, *storconfig.S3StorageConfig) {
|
||||
t.Helper()
|
||||
|
||||
backend := s3mem.New()
|
||||
fakeSrv := gofakes3.New(backend)
|
||||
ts := httptest.NewServer(fakeSrv.Server())
|
||||
t.Cleanup(ts.Close)
|
||||
|
||||
cfg := &storconfig.S3StorageConfig{
|
||||
BaseConfig: storconfig.BaseConfig{
|
||||
Name: "test-s3",
|
||||
Type: "s3",
|
||||
Enable: true,
|
||||
},
|
||||
Endpoint: ts.URL,
|
||||
AccessKeyID: "test-access-key",
|
||||
SecretAccessKey: "test-secret",
|
||||
BucketName: "test-bucket",
|
||||
BasePath: "base",
|
||||
Region: "us-east-1",
|
||||
}
|
||||
|
||||
if err := backend.CreateBucket("test-bucket"); err != nil {
|
||||
t.Fatalf("failed to create fake bucket: %v", err)
|
||||
}
|
||||
|
||||
s := &S3{}
|
||||
ctx := newTestContext(t)
|
||||
if err := s.Init(ctx, cfg); err != nil {
|
||||
t.Fatalf("init s3 failed: %v", err)
|
||||
}
|
||||
|
||||
return s, cfg
|
||||
}
|
||||
|
||||
func TestS3_SaveAndExists(t *testing.T) {
|
||||
s, _ := newFakeS3(t)
|
||||
ctx := context.Background()
|
||||
|
||||
content := []byte("hello world")
|
||||
reader := bytes.NewReader(content)
|
||||
key := "foo/bar.txt"
|
||||
|
||||
if err := s.Save(ctx, reader, key); err != nil {
|
||||
t.Fatalf("Save failed: %v", err)
|
||||
}
|
||||
|
||||
if !s.Exists(ctx, key) {
|
||||
t.Fatalf("Exists should return true for saved key")
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user