diff --git a/pkg/s3/client.go b/pkg/s3/client.go index e65362f..9d05c20 100644 --- a/pkg/s3/client.go +++ b/pkg/s3/client.go @@ -133,7 +133,7 @@ func (c *Client) Put(ctx context.Context, key string, r io.Reader, size int64) e defer resp.Body.Close() if resp.StatusCode >= 300 { - return fmt.Errorf("put object failed: %s", resp.Status) + return responseError("put object", resp) } return nil } @@ -170,10 +170,21 @@ func signRequest(req *http.Request, region, accessKey, secretKey string, payload req.Header.Set("x-amz-date", amzDate) req.Header.Set("x-amz-content-sha256", payloadHash) - // Canonical headers - var headers []string + // Canonical headers. Host is required by SigV4, but Go stores it on + // Request.Host/URL.Host rather than in Request.Header. + headerValues := map[string]string{ + "host": req.URL.Host, + } + if req.Host != "" { + headerValues["host"] = req.Host + } for k := range req.Header { - headers = append(headers, strings.ToLower(k)) + headerValues[strings.ToLower(k)] = strings.TrimSpace(req.Header.Get(k)) + } + + var headers []string + for k := range headerValues { + headers = append(headers, k) } sort.Strings(headers) @@ -181,7 +192,7 @@ func signRequest(req *http.Request, region, accessKey, secretKey string, payload for _, k := range headers { canonicalHeaders.WriteString(k) canonicalHeaders.WriteString(":") - canonicalHeaders.WriteString(strings.TrimSpace(req.Header.Get(k))) + canonicalHeaders.WriteString(headerValues[k]) canonicalHeaders.WriteString("\n") } @@ -189,7 +200,7 @@ func signRequest(req *http.Request, region, accessKey, secretKey string, payload canonicalRequest := strings.Join([]string{ req.Method, - req.URL.EscapedPath(), + canonicalURI(req.URL.Path), req.URL.RawQuery, canonicalHeaders.String(), signedHeaders, @@ -219,3 +230,54 @@ func signRequest(req *http.Request, region, accessKey, secretKey string, payload req.Header.Set("Authorization", auth) return nil } + +func responseError(operation string, resp *http.Response) error { + body, err := io.ReadAll(io.LimitReader(resp.Body, 4096)) + if err != nil { + return fmt.Errorf("%s failed: %s", operation, resp.Status) + } + + message := strings.TrimSpace(string(body)) + if message == "" { + return fmt.Errorf("%s failed: %s", operation, resp.Status) + } + + return fmt.Errorf("%s failed: %s: %s", operation, resp.Status, message) +} + +func canonicalURI(path string) string { + if path == "" { + return "/" + } + + var b strings.Builder + for i := 0; i < len(path); i++ { + c := path[i] + if shouldEscapePathByte(c) { + b.WriteByte('%') + b.WriteByte("0123456789ABCDEF"[c>>4]) + b.WriteByte("0123456789ABCDEF"[c&15]) + continue + } + b.WriteByte(c) + } + return b.String() +} + +func shouldEscapePathByte(c byte) bool { + if c >= 'A' && c <= 'Z' { + return false + } + if c >= 'a' && c <= 'z' { + return false + } + if c >= '0' && c <= '9' { + return false + } + switch c { + case '-', '.', '_', '~', '/': + return false + default: + return true + } +}