diff --git a/storage/s3/client.go b/pkg/s3/client.go similarity index 65% rename from storage/s3/client.go rename to pkg/s3/client.go index fcc3a77..e65362f 100644 --- a/storage/s3/client.go +++ b/pkg/s3/client.go @@ -12,8 +12,6 @@ import ( "sort" "strings" "time" - - storconfig "github.com/krau/SaveAny-Bot/config/storage" ) type Client struct { @@ -26,32 +24,56 @@ type Client struct { pathStyle bool } -func NewClient(cfg storconfig.S3StorageConfig) (*Client, error) { - endpoint := cfg.Endpoint - if !strings.HasPrefix(endpoint, "http") { - if cfg.UseSSL { - endpoint = "https://" + endpoint - } else { - endpoint = "http://" + endpoint +type Config struct { + Endpoint string + Region string + BucketName string + AccessKeyID string + SecretAccessKey string + PathStyle bool + HttpClient *http.Client +} + +func (c *Config) ApplyDefaults() { + if c.HttpClient == nil { + c.HttpClient = http.DefaultClient + } + if c.Endpoint == "" { + switch c.Region { + case "us-east-1", "": + c.Endpoint = "https://s3.amazonaws.com" + default: + c.Endpoint = fmt.Sprintf("https://s3.%s.amazonaws.com", c.Region) } } +} +func NewClient(cfg *Config) (*Client, error) { + cfg.ApplyDefaults() return &Client{ - endpoint: endpoint, + endpoint: cfg.Endpoint, region: cfg.Region, bucket: cfg.BucketName, accessKey: cfg.AccessKeyID, secretKey: cfg.SecretAccessKey, - pathStyle: !cfg.VirtualHost, - httpClient: http.DefaultClient, + httpClient: cfg.HttpClient, + pathStyle: cfg.PathStyle, }, nil } func (c *Client) HeadBucket(ctx context.Context) error { - url := c.buildURL("") - req, _ := http.NewRequestWithContext(ctx, "HEAD", url, nil) + url, err := c.buildURL("") + if err != nil { + return err + } + req, err := http.NewRequestWithContext(ctx, "HEAD", url, nil) + if err != nil { + return err + } - signRequest(req, c.region, c.accessKey, c.secretKey, hashSHA256(nil)) + if err := signRequest(req, c.region, c.accessKey, c.secretKey, hashSHA256(nil)); err != nil { + return err + } resp, err := c.httpClient.Do(req) if err != nil { @@ -66,8 +88,17 @@ func (c *Client) HeadBucket(ctx context.Context) error { } func (c *Client) Exists(ctx context.Context, key string) bool { - req, _ := http.NewRequestWithContext(ctx, "HEAD", c.buildURL(key), nil) - signRequest(req, c.region, c.accessKey, c.secretKey, hashSHA256(nil)) + url, err := c.buildURL(key) + if err != nil { + return false + } + req, err := http.NewRequestWithContext(ctx, "HEAD", url, nil) + if err != nil { + return false + } + if err := signRequest(req, c.region, c.accessKey, c.secretKey, hashSHA256(nil)); err != nil { + return false + } resp, err := c.httpClient.Do(req) if err != nil { @@ -79,12 +110,21 @@ func (c *Client) Exists(ctx context.Context, key string) bool { } func (c *Client) Put(ctx context.Context, key string, r io.Reader, size int64) error { - req, _ := http.NewRequestWithContext(ctx, "PUT", c.buildURL(key), r) + url, err := c.buildURL(key) + if err != nil { + return err + } + req, err := http.NewRequestWithContext(ctx, "PUT", url, r) + if err != nil { + return err + } if size >= 0 { req.ContentLength = size } - signRequest(req, c.region, c.accessKey, c.secretKey, "UNSIGNED-PAYLOAD") + if err := signRequest(req, c.region, c.accessKey, c.secretKey, "UNSIGNED-PAYLOAD"); err != nil { + return err + } resp, err := c.httpClient.Do(req) if err != nil { @@ -98,14 +138,17 @@ func (c *Client) Put(ctx context.Context, key string, r io.Reader, size int64) e return nil } -func (c *Client) buildURL(key string) string { +func (c *Client) buildURL(key string) (string, error) { if c.pathStyle { - return fmt.Sprintf("%s/%s/%s", c.endpoint, c.bucket, key) + return fmt.Sprintf("%s/%s/%s", c.endpoint, c.bucket, key), nil + } + u, err := url.Parse(c.endpoint) + if err != nil { + return "", err } - u, _ := url.Parse(c.endpoint) u.Host = c.bucket + "." + u.Host u.Path = "/" + key - return u.String() + return u.String(), nil } func hmacSHA256(key []byte, data string) []byte { diff --git a/storage/s3/s3.go b/storage/s3/s3.go index 246298e..623544a 100644 --- a/storage/s3/s3.go +++ b/storage/s3/s3.go @@ -11,12 +11,13 @@ import ( storconfig "github.com/krau/SaveAny-Bot/config/storage" "github.com/krau/SaveAny-Bot/pkg/enums/ctxkey" storenum "github.com/krau/SaveAny-Bot/pkg/enums/storage" + "github.com/krau/SaveAny-Bot/pkg/s3" "github.com/rs/xid" ) type S3 struct { config storconfig.S3StorageConfig - client *Client + client *s3.Client logger *log.Logger } @@ -30,7 +31,14 @@ func (m *S3) Init(ctx context.Context, cfg storconfig.StorageConfig) error { } m.config = *s3cfg m.logger = log.FromContext(ctx).WithPrefix(fmt.Sprintf("s3[%s]", m.config.Name)) - client, err := NewClient(m.config) + client, err := s3.NewClient(&s3.Config{ + Endpoint: m.config.Endpoint, + Region: m.config.Region, + AccessKeyID: m.config.AccessKeyID, + SecretAccessKey: m.config.SecretAccessKey, + BucketName: m.config.BucketName, + PathStyle: !m.config.VirtualHost, + }) if err != nil { return fmt.Errorf("failed to create s3 client: %w", err) }