重构Docker代理,支持认证透传和可配置Docker Hub上游,并补充边界测试

This commit is contained in:
starry
2026-05-16 02:58:31 +08:00
committed by GitHub
parent ba83a44492
commit 53cc1761ce
2 changed files with 937 additions and 517 deletions

View File

@@ -1,92 +1,142 @@
package handlers
import (
"context"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/google/go-containerregistry/pkg/authn"
"github.com/google/go-containerregistry/pkg/name"
"github.com/google/go-containerregistry/pkg/v1/remote"
"hubproxy/config"
"hubproxy/utils"
)
// DockerProxy Docker代理配置
type DockerProxy struct {
registry name.Registry
options []remote.Option
type registryTarget struct {
Name string
Upstream string
AuthRealm string
AuthService string
AutoLibraryPrefix bool
}
var dockerProxy *DockerProxy
const (
dockerHubName = "docker.io"
dockerHubUpstream = "https://registry-1.docker.io"
dockerHubAuthRealm = "https://auth.docker.io/token"
dockerHubAuthService = "registry.docker.io"
)
// RegistryDetector Registry检测器
type RegistryDetector struct{}
var hopByHopHeaders = map[string]struct{}{
"connection": {},
"keep-alive": {},
"proxy-authenticate": {},
"proxy-authorization": {},
"te": {},
"trailer": {},
"transfer-encoding": {},
"upgrade": {},
}
// detectRegistryDomain 检测Registry域名并返回域名和剩余路径
func (rd *RegistryDetector) detectRegistryDomain(c *gin.Context, path string) (string, string) {
var forwardedRequestHeaders = []string{
"Authorization",
"Accept",
"Range",
"If-Range",
"If-Match",
"If-None-Match",
"If-Modified-Since",
"If-Unmodified-Since",
}
// InitDockerProxy is kept as the Docker proxy initialization hook. The online
// registry proxy is intentionally stateless and uses the shared HTTP client.
func InitDockerProxy() {}
func defaultRegistryTarget() registryTarget {
cfg := config.GetConfig()
if mapping, exists := cfg.Registries[dockerHubName]; exists && mapping.Enabled {
target := registryTargetFromMapping(dockerHubName, mapping)
target.AuthService = dockerHubAuthService
target.AutoLibraryPrefix = true
return target
}
return registryTarget{
Name: dockerHubName,
Upstream: dockerHubUpstream,
AuthRealm: dockerHubAuthRealm,
AuthService: dockerHubAuthService,
AutoLibraryPrefix: true,
}
}
func registryTargetFromMapping(name string, mapping config.RegistryMapping) registryTarget {
upstream := strings.TrimRight(strings.TrimSpace(mapping.Upstream), "/")
if upstream == "" {
upstream = name
}
if !strings.HasPrefix(upstream, "http://") && !strings.HasPrefix(upstream, "https://") {
upstream = "https://" + upstream
}
authRealm := strings.TrimSpace(mapping.AuthHost)
if authRealm == "" {
authRealm = strings.TrimPrefix(upstream, "https://")
authRealm = strings.TrimPrefix(authRealm, "http://")
}
if !strings.HasPrefix(authRealm, "http://") && !strings.HasPrefix(authRealm, "https://") {
authRealm = "https://" + authRealm
}
authService := strings.TrimPrefix(strings.TrimPrefix(upstream, "https://"), "http://")
return registryTarget{
Name: name,
Upstream: upstream,
AuthRealm: authRealm,
AuthService: authService,
AutoLibraryPrefix: false,
}
}
func resolveRegistryTarget(c *gin.Context, pathWithoutV2 string) (registryTarget, string) {
cfg := config.GetConfig()
// 兼容Containerd的ns参数
if ns := c.Query("ns"); ns != "" {
if ns := strings.TrimSpace(c.Query("ns")); ns != "" {
if mapping, exists := cfg.Registries[ns]; exists && mapping.Enabled {
return ns, path
return registryTargetFromMapping(ns, mapping), pathWithoutV2
}
}
for domain := range cfg.Registries {
if strings.HasPrefix(path, domain+"/") {
remainingPath := strings.TrimPrefix(path, domain+"/")
return domain, remainingPath
for domain, mapping := range cfg.Registries {
if mapping.Enabled && strings.HasPrefix(pathWithoutV2, domain+"/") {
return registryTargetFromMapping(domain, mapping), strings.TrimPrefix(pathWithoutV2, domain+"/")
}
}
return "", path
return defaultRegistryTarget(), pathWithoutV2
}
// isRegistryEnabled 检查Registry是否启用
func (rd *RegistryDetector) isRegistryEnabled(domain string) bool {
func resolveTokenTarget(c *gin.Context) (registryTarget, bool) {
name := strings.Trim(strings.TrimSpace(c.Param("path")), "/")
if name == "" {
return defaultRegistryTarget(), true
}
if name == dockerHubName || name == "dockerhub" || name == "registry-1.docker.io" {
return defaultRegistryTarget(), true
}
cfg := config.GetConfig()
if mapping, exists := cfg.Registries[domain]; exists {
return mapping.Enabled
if mapping, exists := cfg.Registries[name]; exists && mapping.Enabled {
return registryTargetFromMapping(name, mapping), true
}
return false
return registryTarget{}, false
}
// getRegistryMapping 获取Registry映射配置
func (rd *RegistryDetector) getRegistryMapping(domain string) (config.RegistryMapping, bool) {
cfg := config.GetConfig()
mapping, exists := cfg.Registries[domain]
return mapping, exists && mapping.Enabled
}
var registryDetector = &RegistryDetector{}
// InitDockerProxy 初始化Docker代理
func InitDockerProxy() {
registry, err := name.NewRegistry("registry-1.docker.io")
if err != nil {
fmt.Printf("创建Docker registry失败: %v\n", err)
return
}
options := []remote.Option{
remote.WithAuth(authn.Anonymous),
remote.WithUserAgent("hubproxy/go-containerregistry"),
remote.WithTransport(utils.GetGlobalHTTPClient().Transport),
}
dockerProxy = &DockerProxy{
registry: registry,
options: options,
}
}
// ProxyDockerRegistryGin 标准Docker Registry API v2代理
// ProxyDockerRegistryGin proxies Docker Registry API v2 requests transparently.
func ProxyDockerRegistryGin(c *gin.Context) {
path := c.Request.URL.Path
@@ -95,528 +145,305 @@ func ProxyDockerRegistryGin(c *gin.Context) {
return
}
if strings.HasPrefix(path, "/v2/") {
handleRegistryRequest(c, path)
} else {
if !strings.HasPrefix(path, "/v2/") {
c.String(http.StatusNotFound, "Docker Registry API v2 only")
return
}
handleRegistryRequest(c, path)
}
// handleRegistryRequest 处理Registry请求
func handleRegistryRequest(c *gin.Context, path string) {
pathWithoutV2 := strings.TrimPrefix(path, "/v2/")
target, targetPath := resolveRegistryTarget(c, pathWithoutV2)
if registryDomain, remainingPath := registryDetector.detectRegistryDomain(c, pathWithoutV2); registryDomain != "" {
if registryDetector.isRegistryEnabled(registryDomain) {
c.Set("target_registry_domain", registryDomain)
c.Set("target_path", remainingPath)
handleMultiRegistryRequest(c, registryDomain, remainingPath)
return
}
}
imageName, apiType, reference := parseRegistryPath(pathWithoutV2)
imageName, apiType, _ := parseRegistryPath(targetPath)
if imageName == "" || apiType == "" {
c.String(http.StatusBadRequest, "Invalid path format")
return
}
if !strings.Contains(imageName, "/") {
if target.AutoLibraryPrefix && !strings.Contains(imageName, "/") {
imageName = "library/" + imageName
targetPath = strings.TrimPrefix(targetPath, strings.TrimPrefix(imageName, "library/"))
targetPath = imageName + targetPath
}
if allowed, reason := utils.GlobalAccessController.CheckDockerAccess(imageName); !allowed {
fmt.Printf("Docker镜像 %s 访问被拒绝: %s\n", imageName, reason)
c.String(http.StatusForbidden, "镜像访问被限制")
accessName := imageName
if target.Name != dockerHubName {
accessName = target.Name + "/" + imageName
}
if allowed, reason := utils.GlobalAccessController.CheckDockerAccess(accessName); !allowed {
fmt.Printf("Docker image %s access denied: %s\n", accessName, reason)
c.String(http.StatusForbidden, reason)
return
}
imageRef := fmt.Sprintf("%s/%s", dockerProxy.registry.Name(), imageName)
switch apiType {
case "manifests":
handleManifestRequest(c, imageRef, reference)
case "blobs":
handleBlobRequest(c, imageRef, reference)
case "tags":
handleTagsRequest(c, imageRef)
default:
c.String(http.StatusNotFound, "API endpoint not found")
}
proxyRegistryHTTP(c, target, "/v2/"+targetPath)
}
// parseRegistryPath 解析Registry路径
// parseRegistryPath parses a Docker Registry v2 path without the leading /v2/.
func parseRegistryPath(path string) (imageName, apiType, reference string) {
if idx := strings.Index(path, "/manifests/"); idx != -1 {
imageName = path[:idx]
apiType = "manifests"
reference = path[idx+len("/manifests/"):]
return
return path[:idx], "manifests", path[idx+len("/manifests/"):]
}
if idx := strings.Index(path, "/blobs/"); idx != -1 {
imageName = path[:idx]
apiType = "blobs"
reference = path[idx+len("/blobs/"):]
return
return path[:idx], "blobs", path[idx+len("/blobs/"):]
}
if idx := strings.Index(path, "/tags/list"); idx != -1 {
imageName = path[:idx]
apiType = "tags"
reference = "list"
return
return path[:idx], "tags", "list"
}
return "", "", ""
}
// handleManifestRequest 处理manifest请求
func handleManifestRequest(c *gin.Context, imageRef, reference string) {
if utils.IsCacheEnabled() && c.Request.Method == http.MethodGet {
cacheKey := utils.BuildManifestCacheKey(imageRef, reference)
if cachedItem := utils.GlobalCache.Get(cacheKey); cachedItem != nil {
utils.WriteCachedResponse(c, cachedItem)
return
}
}
var ref name.Reference
var err error
if strings.HasPrefix(reference, "sha256:") {
ref, err = name.NewDigest(fmt.Sprintf("%s@%s", imageRef, reference))
} else {
ref, err = name.NewTag(fmt.Sprintf("%s:%s", imageRef, reference))
}
if err != nil {
fmt.Printf("解析镜像引用失败: %v\n", err)
c.String(http.StatusBadRequest, "Invalid reference")
return
}
if c.Request.Method == http.MethodHead {
desc, err := remote.Head(ref, dockerProxy.options...)
if err != nil {
fmt.Printf("HEAD请求失败: %v\n", err)
c.String(http.StatusNotFound, "Manifest not found")
return
}
c.Header("Content-Type", string(desc.MediaType))
c.Header("Docker-Content-Digest", desc.Digest.String())
c.Header("Content-Length", fmt.Sprintf("%d", desc.Size))
c.Status(http.StatusOK)
} else {
desc, err := remote.Get(ref, dockerProxy.options...)
if err != nil {
fmt.Printf("GET请求失败: %v\n", err)
c.String(http.StatusNotFound, "Manifest not found")
return
}
headers := map[string]string{
"Docker-Content-Digest": desc.Digest.String(),
"Content-Length": fmt.Sprintf("%d", len(desc.Manifest)),
}
if utils.IsCacheEnabled() {
cacheKey := utils.BuildManifestCacheKey(imageRef, reference)
ttl := utils.GetManifestTTL(reference)
utils.GlobalCache.Set(cacheKey, desc.Manifest, string(desc.MediaType), headers, ttl)
}
c.Header("Content-Type", string(desc.MediaType))
for key, value := range headers {
c.Header(key, value)
}
c.Data(http.StatusOK, string(desc.MediaType), desc.Manifest)
}
}
// handleBlobRequest 处理blob请求
func handleBlobRequest(c *gin.Context, imageRef, digest string) {
digestRef, err := name.NewDigest(fmt.Sprintf("%s@%s", imageRef, digest))
if err != nil {
fmt.Printf("解析digest引用失败: %v\n", err)
c.String(http.StatusBadRequest, "Invalid digest reference")
return
}
layer, err := remote.Layer(digestRef, dockerProxy.options...)
if err != nil {
fmt.Printf("获取layer失败: %v\n", err)
c.String(http.StatusNotFound, "Layer not found")
return
}
size, err := layer.Size()
if err != nil {
fmt.Printf("获取layer大小失败: %v\n", err)
c.String(http.StatusInternalServerError, "Failed to get layer size")
return
}
reader, err := layer.Compressed()
if err != nil {
fmt.Printf("获取layer内容失败: %v\n", err)
c.String(http.StatusInternalServerError, "Failed to get layer content")
return
}
defer reader.Close()
c.Header("Content-Type", "application/octet-stream")
c.Header("Content-Length", fmt.Sprintf("%d", size))
c.Header("Docker-Content-Digest", digest)
c.Status(http.StatusOK)
if _, err := io.Copy(c.Writer, reader); err != nil {
fmt.Printf("复制layer内容失败: %v\n", err)
}
}
// handleTagsRequest 处理tags列表请求
func handleTagsRequest(c *gin.Context, imageRef string) {
repo, err := name.NewRepository(imageRef)
if err != nil {
fmt.Printf("解析repository失败: %v\n", err)
c.String(http.StatusBadRequest, "Invalid repository")
return
}
tags, err := remote.List(repo, dockerProxy.options...)
if err != nil {
fmt.Printf("获取tags失败: %v\n", err)
c.String(http.StatusNotFound, "Tags not found")
return
}
response := map[string]interface{}{
"name": strings.TrimPrefix(imageRef, dockerProxy.registry.Name()+"/"),
"tags": tags,
}
c.JSON(http.StatusOK, response)
}
// ProxyDockerAuthGin Docker认证代理
// ProxyDockerAuthGin forwards Docker token requests, including client Basic
// credentials, to the selected upstream auth service.
func ProxyDockerAuthGin(c *gin.Context) {
if utils.IsTokenCacheEnabled() {
proxyDockerAuthWithCache(c)
} else {
proxyDockerAuthOriginal(c)
}
}
// proxyDockerAuthWithCache 带缓存的认证代理
func proxyDockerAuthWithCache(c *gin.Context) {
cacheKey := utils.BuildTokenCacheKey(c.Request.URL.RawQuery)
if cachedToken := utils.GlobalCache.GetToken(cacheKey); cachedToken != "" {
utils.WriteTokenResponse(c, cachedToken)
target, ok := resolveTokenTarget(c)
if !ok {
c.String(http.StatusBadRequest, "Unknown registry target")
return
}
recorder := &ResponseRecorder{
ResponseWriter: c.Writer,
statusCode: 200,
}
c.Writer = recorder
proxyDockerAuthOriginal(c)
if recorder.statusCode == 200 && len(recorder.body) > 0 {
ttl := utils.ExtractTTLFromResponse(recorder.body)
utils.GlobalCache.SetToken(cacheKey, string(recorder.body), ttl)
}
c.Writer = recorder.ResponseWriter
c.Data(recorder.statusCode, "application/json", recorder.body)
}
// ResponseRecorder HTTP响应记录器
type ResponseRecorder struct {
gin.ResponseWriter
statusCode int
body []byte
}
func (r *ResponseRecorder) WriteHeader(code int) {
r.statusCode = code
}
func (r *ResponseRecorder) Write(data []byte) (int, error) {
r.body = append(r.body, data...)
return len(data), nil
}
func proxyDockerAuthOriginal(c *gin.Context) {
var authURL string
if targetDomain, exists := c.Get("target_registry_domain"); exists {
if mapping, found := registryDetector.getRegistryMapping(targetDomain.(string)); found {
authURL = "https://" + mapping.AuthHost + c.Request.URL.Path
} else {
authURL = "https://auth.docker.io" + c.Request.URL.Path
cacheable := c.GetHeader("Authorization") == "" && utils.IsTokenCacheEnabled() && c.Request.Method == http.MethodGet
cacheKey := utils.BuildTokenCacheKey(target.Name + ":" + c.Request.URL.RawQuery)
if cacheable {
if cachedToken := utils.GlobalCache.GetToken(cacheKey); cachedToken != "" {
utils.WriteTokenResponse(c, cachedToken)
return
}
} else {
authURL = "https://auth.docker.io" + c.Request.URL.Path
}
if c.Request.URL.RawQuery != "" {
authURL += "?" + c.Request.URL.RawQuery
}
client := &http.Client{
Timeout: 30 * time.Second,
Transport: utils.GetGlobalHTTPClient().Transport,
}
req, err := http.NewRequestWithContext(
context.Background(),
c.Request.Method,
authURL,
c.Request.Body,
)
authURL, err := buildAuthURL(target, c.Request.URL.RawQuery)
if err != nil {
c.String(http.StatusInternalServerError, "Failed to create request")
c.String(http.StatusInternalServerError, "Failed to build auth request")
return
}
for key, values := range c.Request.Header {
for _, value := range values {
req.Header.Add(key, value)
}
req, err := http.NewRequestWithContext(c.Request.Context(), c.Request.Method, authURL, nil)
if err != nil {
c.String(http.StatusInternalServerError, "Failed to create auth request")
return
}
forwardSelectedRequestHeaders(req.Header, c.Request.Header)
resp, err := client.Do(req)
resp, err := utils.GetGlobalHTTPClient().Do(req)
if err != nil {
c.String(http.StatusBadGateway, "Auth request failed")
return
}
defer resp.Body.Close()
proxyHost := c.Request.Host
if proxyHost == "" {
cfg := config.GetConfig()
proxyHost = fmt.Sprintf("%s:%d", cfg.Server.Host, cfg.Server.Port)
if cfg.Server.Host == "0.0.0.0" {
proxyHost = fmt.Sprintf("localhost:%d", cfg.Server.Port)
}
}
for key, values := range resp.Header {
for _, value := range values {
if key == "Www-Authenticate" {
value = rewriteAuthHeader(value, proxyHost)
}
c.Header(key, value)
}
}
c.Status(resp.StatusCode)
if _, err := io.Copy(c.Writer, resp.Body); err != nil {
fmt.Printf("复制认证响应失败: %v\n", err)
}
}
// rewriteAuthHeader 重写认证头
func rewriteAuthHeader(authHeader, proxyHost string) string {
authHeader = strings.ReplaceAll(authHeader, "https://auth.docker.io", "http://"+proxyHost)
authHeader = strings.ReplaceAll(authHeader, "https://ghcr.io", "http://"+proxyHost)
authHeader = strings.ReplaceAll(authHeader, "https://gcr.io", "http://"+proxyHost)
authHeader = strings.ReplaceAll(authHeader, "https://quay.io", "http://"+proxyHost)
return authHeader
}
// handleMultiRegistryRequest 处理多Registry请求
func handleMultiRegistryRequest(c *gin.Context, registryDomain, remainingPath string) {
mapping, exists := registryDetector.getRegistryMapping(registryDomain)
if !exists {
c.String(http.StatusBadRequest, "Registry not configured")
return
}
imageName, apiType, reference := parseRegistryPath(remainingPath)
if imageName == "" || apiType == "" {
c.String(http.StatusBadRequest, "Invalid path format")
return
}
fullImageName := registryDomain + "/" + imageName
if allowed, reason := utils.GlobalAccessController.CheckDockerAccess(fullImageName); !allowed {
fmt.Printf("镜像 %s 访问被拒绝: %s\n", fullImageName, reason)
c.String(http.StatusForbidden, "镜像访问被限制")
return
}
upstreamImageRef := fmt.Sprintf("%s/%s", mapping.Upstream, imageName)
switch apiType {
case "manifests":
handleUpstreamManifestRequest(c, upstreamImageRef, reference, mapping)
case "blobs":
handleUpstreamBlobRequest(c, upstreamImageRef, reference, mapping)
case "tags":
handleUpstreamTagsRequest(c, upstreamImageRef, mapping)
default:
c.String(http.StatusNotFound, "API endpoint not found")
}
}
// handleUpstreamManifestRequest 处理上游Registry的manifest请求
func handleUpstreamManifestRequest(c *gin.Context, imageRef, reference string, mapping config.RegistryMapping) {
if utils.IsCacheEnabled() && c.Request.Method == http.MethodGet {
cacheKey := utils.BuildManifestCacheKey(imageRef, reference)
if cachedItem := utils.GlobalCache.Get(cacheKey); cachedItem != nil {
utils.WriteCachedResponse(c, cachedItem)
return
}
}
var ref name.Reference
var err error
if strings.HasPrefix(reference, "sha256:") {
ref, err = name.NewDigest(fmt.Sprintf("%s@%s", imageRef, reference))
} else {
ref, err = name.NewTag(fmt.Sprintf("%s:%s", imageRef, reference))
}
body, err := io.ReadAll(resp.Body)
if err != nil {
fmt.Printf("解析镜像引用失败: %v\n", err)
c.String(http.StatusBadRequest, "Invalid reference")
c.String(http.StatusBadGateway, "Failed to read auth response")
return
}
options := createUpstreamOptions(mapping)
copyResponseHeaders(c, resp.Header, target)
if cacheable && resp.StatusCode == http.StatusOK && len(body) > 0 {
utils.GlobalCache.SetToken(cacheKey, string(body), utils.ExtractTTLFromResponse(body))
}
c.Data(resp.StatusCode, resp.Header.Get("Content-Type"), body)
}
func buildAuthURL(target registryTarget, rawQuery string) (string, error) {
authURL, err := url.Parse(target.AuthRealm)
if err != nil {
return "", err
}
query := authURL.Query()
query.Set("service", target.AuthService)
if rawQuery != "" {
incoming, err := url.ParseQuery(rawQuery)
if err != nil {
return "", err
}
for key, values := range incoming {
if strings.EqualFold(key, "service") {
continue
}
for _, value := range values {
if strings.EqualFold(key, "scope") && target.AutoLibraryPrefix {
value = addLibraryPrefixToScope(value)
}
query.Add(key, value)
}
}
}
authURL.RawQuery = query.Encode()
return authURL.String(), nil
}
func addLibraryPrefixToScope(scope string) string {
parts := strings.Split(scope, ":")
if len(parts) != 3 || parts[0] != "repository" || strings.Contains(parts[1], "/") {
return scope
}
return "repository:library/" + parts[1] + ":" + parts[2]
}
func proxyRegistryHTTP(c *gin.Context, target registryTarget, upstreamPath string) {
targetURL := target.Upstream + upstreamPath
if c.Request.URL.RawQuery != "" {
targetURL += "?" + c.Request.URL.RawQuery
}
req, err := http.NewRequestWithContext(c.Request.Context(), c.Request.Method, targetURL, nil)
if err != nil {
c.String(http.StatusInternalServerError, "Failed to create registry request")
return
}
forwardSelectedRequestHeaders(req.Header, c.Request.Header)
resp, err := utils.GetGlobalHTTPClient().Do(req)
if err != nil {
c.String(http.StatusBadGateway, "Registry request failed")
return
}
defer resp.Body.Close()
copyResponseHeaders(c, resp.Header, target)
c.Status(resp.StatusCode)
if c.Request.Method == http.MethodHead {
desc, err := remote.Head(ref, options...)
if err != nil {
fmt.Printf("HEAD请求失败: %v\n", err)
c.String(http.StatusNotFound, "Manifest not found")
return
}
c.Header("Content-Type", string(desc.MediaType))
c.Header("Docker-Content-Digest", desc.Digest.String())
c.Header("Content-Length", fmt.Sprintf("%d", desc.Size))
c.Status(http.StatusOK)
} else {
desc, err := remote.Get(ref, options...)
if err != nil {
fmt.Printf("GET请求失败: %v\n", err)
c.String(http.StatusNotFound, "Manifest not found")
return
}
headers := map[string]string{
"Docker-Content-Digest": desc.Digest.String(),
"Content-Length": fmt.Sprintf("%d", len(desc.Manifest)),
}
if utils.IsCacheEnabled() {
cacheKey := utils.BuildManifestCacheKey(imageRef, reference)
ttl := utils.GetManifestTTL(reference)
utils.GlobalCache.Set(cacheKey, desc.Manifest, string(desc.MediaType), headers, ttl)
}
c.Header("Content-Type", string(desc.MediaType))
for key, value := range headers {
c.Header(key, value)
}
c.Data(http.StatusOK, string(desc.MediaType), desc.Manifest)
return
}
if _, err := io.Copy(c.Writer, resp.Body); err != nil {
fmt.Printf("Failed to stream registry response: %v\n", err)
}
}
// handleUpstreamBlobRequest 处理上游Registry的blob请求
func handleUpstreamBlobRequest(c *gin.Context, imageRef, digest string, mapping config.RegistryMapping) {
digestRef, err := name.NewDigest(fmt.Sprintf("%s@%s", imageRef, digest))
if err != nil {
fmt.Printf("解析digest引用失败: %v\n", err)
c.String(http.StatusBadRequest, "Invalid digest reference")
return
}
options := createUpstreamOptions(mapping)
layer, err := remote.Layer(digestRef, options...)
if err != nil {
fmt.Printf("获取layer失败: %v\n", err)
c.String(http.StatusNotFound, "Layer not found")
return
}
size, err := layer.Size()
if err != nil {
fmt.Printf("获取layer大小失败: %v\n", err)
c.String(http.StatusInternalServerError, "Failed to get layer size")
return
}
reader, err := layer.Compressed()
if err != nil {
fmt.Printf("获取layer内容失败: %v\n", err)
c.String(http.StatusInternalServerError, "Failed to get layer content")
return
}
defer reader.Close()
c.Header("Content-Type", "application/octet-stream")
c.Header("Content-Length", fmt.Sprintf("%d", size))
c.Header("Docker-Content-Digest", digest)
c.Status(http.StatusOK)
if _, err := io.Copy(c.Writer, reader); err != nil {
fmt.Printf("复制layer内容失败: %v\n", err)
func forwardSelectedRequestHeaders(dst http.Header, src http.Header) {
for _, name := range forwardedRequestHeaders {
for _, value := range src.Values(name) {
dst.Add(name, value)
}
}
}
// handleUpstreamTagsRequest 处理上游Registry的tags请求
func handleUpstreamTagsRequest(c *gin.Context, imageRef string, mapping config.RegistryMapping) {
repo, err := name.NewRepository(imageRef)
if err != nil {
fmt.Printf("解析repository失败: %v\n", err)
c.String(http.StatusBadRequest, "Invalid repository")
return
}
func copyResponseHeaders(c *gin.Context, headers http.Header, target registryTarget) {
for name, values := range headers {
if shouldSkipResponseHeader(name) {
continue
}
options := createUpstreamOptions(mapping)
tags, err := remote.List(repo, options...)
if err != nil {
fmt.Printf("获取tags失败: %v\n", err)
c.String(http.StatusNotFound, "Tags not found")
return
for _, value := range values {
if strings.EqualFold(name, "WWW-Authenticate") {
value = rewriteAuthChallenge(value, target, publicBaseURL(c))
c.Header("WWW-Authenticate", value)
} else {
c.Header(name, value)
}
}
}
response := map[string]interface{}{
"name": strings.TrimPrefix(imageRef, mapping.Upstream+"/"),
"tags": tags,
}
c.JSON(http.StatusOK, response)
}
// createUpstreamOptions 创建上游Registry选项
func createUpstreamOptions(mapping config.RegistryMapping) []remote.Option {
options := []remote.Option{
remote.WithAuth(authn.Anonymous),
remote.WithUserAgent("hubproxy/go-containerregistry"),
remote.WithTransport(utils.GetGlobalHTTPClient().Transport),
}
// 预留将来不同Registry的差异化认证逻辑扩展点
switch mapping.AuthType {
case "github":
case "google":
case "quay":
}
return options
func shouldSkipResponseHeader(name string) bool {
_, hopByHop := hopByHopHeaders[strings.ToLower(name)]
return hopByHop
}
func publicBaseURL(c *gin.Context) string {
proto := "http"
if c.Request.TLS != nil {
proto = "https"
}
if forwardedProto := strings.TrimSpace(c.GetHeader("X-Forwarded-Proto")); forwardedProto != "" {
proto = strings.Split(forwardedProto, ",")[0]
}
host := c.Request.Host
if host == "" {
cfg := config.GetConfig()
host = fmt.Sprintf("%s:%d", cfg.Server.Host, cfg.Server.Port)
if cfg.Server.Host == "0.0.0.0" {
host = fmt.Sprintf("localhost:%d", cfg.Server.Port)
}
}
return strings.TrimRight(proto+"://"+host, "/")
}
func rewriteAuthChallenge(authHeader string, target registryTarget, baseURL string) string {
if !strings.HasPrefix(strings.TrimSpace(strings.ToLower(authHeader)), "bearer ") {
return authHeader
}
scope := bearerParam(authHeader, "scope")
challenge := fmt.Sprintf(
`Bearer realm="%s/token/%s",service="%s"`,
baseURL,
escapeAuthParam(target.Name),
escapeAuthParam(target.AuthService),
)
if scope != "" {
challenge += fmt.Sprintf(`,scope="%s"`, escapeAuthParam(scope))
}
return challenge
}
func bearerParam(authHeader, paramName string) string {
input := strings.TrimSpace(authHeader)
if len(input) < len("Bearer ") || !strings.EqualFold(input[:len("Bearer ")], "Bearer ") {
return ""
}
input = strings.TrimSpace(input[len("Bearer "):])
for input != "" {
input = strings.TrimLeft(input, ", \t")
key, rest, found := strings.Cut(input, "=")
if !found {
return ""
}
key = strings.TrimSpace(key)
rest = strings.TrimSpace(rest)
var value string
if strings.HasPrefix(rest, `"`) {
rest = rest[1:]
var b strings.Builder
escaped := false
end := -1
for i, r := range rest {
if escaped {
b.WriteRune(r)
escaped = false
continue
}
if r == '\\' {
escaped = true
continue
}
if r == '"' {
end = i + 1
break
}
b.WriteRune(r)
}
if end == -1 {
return ""
}
value = b.String()
input = rest[end:]
} else {
value, input, _ = strings.Cut(rest, ",")
value = strings.TrimSpace(value)
}
if strings.EqualFold(key, paramName) {
return value
}
}
return ""
}
func escapeAuthParam(value string) string {
value = strings.ReplaceAll(value, `\`, `\\`)
return strings.ReplaceAll(value, `"`, `\"`)
}

View File

@@ -1,6 +1,56 @@
package handlers
import "testing"
import (
"fmt"
"io"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"strings"
"sync"
"sync/atomic"
"testing"
"github.com/gin-gonic/gin"
"hubproxy/config"
"hubproxy/utils"
)
type zeroReader struct{}
func (zeroReader) Read(p []byte) (int, error) {
for i := range p {
p[i] = 0
}
return len(p), nil
}
type discardResponseWriter struct {
header http.Header
status int
bytes int64
}
func newDiscardResponseWriter() *discardResponseWriter {
return &discardResponseWriter{header: make(http.Header)}
}
func (w *discardResponseWriter) Header() http.Header {
return w.header
}
func (w *discardResponseWriter) WriteHeader(status int) {
w.status = status
}
func (w *discardResponseWriter) Write(p []byte) (int, error) {
if w.status == 0 {
w.status = http.StatusOK
}
w.bytes += int64(len(p))
return len(p), nil
}
func TestParseRegistryPath(t *testing.T) {
tests := []struct {
@@ -28,3 +78,546 @@ func TestParseRegistryPathInvalid(t *testing.T) {
t.Fatalf("invalid path parsed as %q %q %q", image, apiType, reference)
}
}
type testEnv interface {
Helper()
TempDir() string
Setenv(string, string)
Fatal(...interface{})
}
func initDockerProxyTest(t testEnv, configBody string) {
t.Helper()
path := filepath.Join(t.TempDir(), "config.toml")
if err := os.WriteFile(path, []byte(configBody), 0644); err != nil {
t.Fatal(err)
}
t.Setenv("CONFIG_PATH", path)
if err := config.LoadConfig(); err != nil {
t.Fatal(err)
}
utils.InitHTTPClients()
}
func TestRewriteAuthChallengePreservesScopeAndUsesProxyRealm(t *testing.T) {
target := registryTarget{
Name: "ghcr.io",
AuthService: "ghcr.io",
}
got := rewriteAuthChallenge(
`Bearer realm="https://ghcr.io/token",service="ghcr.io",scope="repository:owner/image:pull"`,
target,
"https://proxy.example.com",
)
want := `Bearer realm="https://proxy.example.com/token/ghcr.io",service="ghcr.io",scope="repository:owner/image:pull"`
if got != want {
t.Fatalf("challenge = %q, want %q", got, want)
}
}
func TestBuildAuthURLForDockerHubAddsLibraryScopeAndService(t *testing.T) {
got, err := buildAuthURL(
defaultRegistryTarget(),
"service=ignored&scope=repository%3Aalpine%3Apull&client_id=docker",
)
if err != nil {
t.Fatal(err)
}
if !strings.HasPrefix(got, dockerHubAuthRealm+"?") {
t.Fatalf("auth URL = %q", got)
}
if !strings.Contains(got, "service=registry.docker.io") {
t.Fatalf("auth URL missing service: %q", got)
}
if !strings.Contains(got, "scope=repository%3Alibrary%2Falpine%3Apull") {
t.Fatalf("auth URL missing normalized scope: %q", got)
}
}
func TestDockerIODefaultTargetUsesBuiltInWhenUnconfigured(t *testing.T) {
initDockerProxyTest(t, "")
target := defaultRegistryTarget()
if target.Upstream != dockerHubUpstream {
t.Fatalf("Upstream = %q, want %q", target.Upstream, dockerHubUpstream)
}
if target.AuthRealm != dockerHubAuthRealm {
t.Fatalf("AuthRealm = %q, want %q", target.AuthRealm, dockerHubAuthRealm)
}
if !target.AutoLibraryPrefix {
t.Fatal("AutoLibraryPrefix = false, want true")
}
}
func TestDockerIODefaultTargetCanBeOverriddenByConfig(t *testing.T) {
initDockerProxyTest(t, `
[registries."docker.io"]
upstream = "mirror.local"
authHost = "auth.mirror.local/token"
authType = "docker"
enabled = true
`)
target := defaultRegistryTarget()
if target.Upstream != "https://mirror.local" {
t.Fatalf("Upstream = %q, want custom mirror", target.Upstream)
}
if target.AuthRealm != "https://auth.mirror.local/token" {
t.Fatalf("AuthRealm = %q, want custom auth realm", target.AuthRealm)
}
if target.AuthService != dockerHubAuthService {
t.Fatalf("AuthService = %q, want %q", target.AuthService, dockerHubAuthService)
}
if !target.AutoLibraryPrefix {
t.Fatal("AutoLibraryPrefix = false, want true")
}
}
func TestDockerIODefaultTargetIgnoresDisabledOverride(t *testing.T) {
initDockerProxyTest(t, `
[registries."docker.io"]
upstream = "mirror.local"
authHost = "auth.mirror.local/token"
authType = "docker"
enabled = false
`)
target := defaultRegistryTarget()
if target.Upstream != dockerHubUpstream {
t.Fatalf("Upstream = %q, want built-in %q", target.Upstream, dockerHubUpstream)
}
}
func TestProxyDockerRegistryTransparentlyForwardsAuthAndRewritesChallenge(t *testing.T) {
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/v2/team/app/manifests/latest" {
t.Fatalf("upstream path = %q", r.URL.Path)
}
if got := r.Header.Get("Authorization"); got != "Bearer client-token" {
t.Fatalf("Authorization = %q", got)
}
if got := r.Header.Get("Accept"); got != "application/vnd.docker.distribution.manifest.v2+json" {
t.Fatalf("Accept = %q", got)
}
if got := r.Header.Get("Range"); got != "bytes=0-99" {
t.Fatalf("Range = %q", got)
}
w.Header().Set("WWW-Authenticate", `Bearer realm="https://upstream.example/token",service="upstream.example",scope="repository:team/app:pull"`)
w.WriteHeader(http.StatusUnauthorized)
}))
defer upstream.Close()
initDockerProxyTest(t, `
[registries."test.local"]
upstream = "`+upstream.URL+`"
authHost = "https://auth.test.local/token"
authType = "anonymous"
enabled = true
`)
gin.SetMode(gin.TestMode)
router := gin.New()
router.Any("/v2/*path", ProxyDockerRegistryGin)
req := httptest.NewRequest(http.MethodGet, "/v2/test.local/team/app/manifests/latest", nil)
req.Host = "proxy.example.com"
req.Header.Set("X-Forwarded-Proto", "https")
req.Header.Set("Authorization", "Bearer client-token")
req.Header.Set("Accept", "application/vnd.docker.distribution.manifest.v2+json")
req.Header.Set("Range", "bytes=0-99")
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusUnauthorized {
t.Fatalf("status = %d, want 401; body=%s", w.Code, w.Body.String())
}
wantChallenge := `Bearer realm="https://proxy.example.com/token/test.local",service="` + strings.TrimPrefix(upstream.URL, "http://") + `",scope="repository:team/app:pull"`
if got := w.Header().Get("WWW-Authenticate"); got != wantChallenge {
t.Fatalf("WWW-Authenticate = %q, want %q", got, wantChallenge)
}
}
func TestProxyDockerAuthForwardsBasicCredentials(t *testing.T) {
authServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if got := r.Header.Get("Authorization"); got != "Basic dXNlcjpwYXNz" {
t.Fatalf("Authorization = %q", got)
}
if got := r.URL.Query().Get("service"); got != "127.0.0.1" {
t.Fatalf("service = %q", got)
}
if got := r.URL.Query().Get("scope"); got != "repository:team/app:pull" {
t.Fatalf("scope = %q", got)
}
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"token":"secret","expires_in":3600}`))
}))
defer authServer.Close()
initDockerProxyTest(t, `
[registries."test.local"]
upstream = "https://127.0.0.1"
authHost = "`+authServer.URL+`"
authType = "anonymous"
enabled = true
`)
gin.SetMode(gin.TestMode)
router := gin.New()
router.Any("/token/*path", ProxyDockerAuthGin)
req := httptest.NewRequest(http.MethodGet, "/token/test.local?scope=repository:team/app:pull", nil)
req.Header.Set("Authorization", "Basic dXNlcjpwYXNz")
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("status = %d, want 200; body=%s", w.Code, w.Body.String())
}
if got := w.Body.String(); !strings.Contains(got, `"token":"secret"`) {
t.Fatalf("body = %q", got)
}
}
func TestDockerHubShortNameIsProxiedWithLibraryPrefix(t *testing.T) {
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/v2/library/nginx/manifests/latest" {
t.Fatalf("upstream path = %q", r.URL.Path)
}
w.Header().Set("Content-Type", "application/vnd.docker.distribution.manifest.v2+json")
w.Header().Set("Docker-Content-Digest", "sha256:abc")
_, _ = w.Write([]byte(`{"schemaVersion":2}`))
}))
defer upstream.Close()
gin.SetMode(gin.TestMode)
router := gin.New()
router.Any("/v2/*path", ProxyDockerRegistryGin)
target := defaultRegistryTarget()
target.Upstream = upstream.URL
req := httptest.NewRequest(http.MethodGet, "/v2/nginx/manifests/latest", nil)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = req
proxyRegistryHTTP(c, target, "/v2/library/nginx/manifests/latest")
if w.Code != http.StatusOK {
t.Fatalf("status = %d, want 200; body=%s", w.Code, w.Body.String())
}
if got := w.Header().Get("Docker-Content-Digest"); got != "sha256:abc" {
t.Fatalf("Docker-Content-Digest = %q", got)
}
}
func TestProxyDockerRegistryHeadReturnsHeadersWithoutBody(t *testing.T) {
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodHead {
t.Fatalf("method = %s, want HEAD", r.Method)
}
w.Header().Set("Content-Length", "123")
w.Header().Set("Docker-Content-Digest", "sha256:head")
w.WriteHeader(http.StatusOK)
}))
defer upstream.Close()
initDockerProxyTest(t, `
[registries."test.local"]
upstream = "`+upstream.URL+`"
authHost = "https://auth.test.local/token"
authType = "anonymous"
enabled = true
`)
gin.SetMode(gin.TestMode)
router := gin.New()
router.Any("/v2/*path", ProxyDockerRegistryGin)
req := httptest.NewRequest(http.MethodHead, "/v2/test.local/team/app/blobs/sha256:abc", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("status = %d, want 200", w.Code)
}
if got := w.Header().Get("Docker-Content-Digest"); got != "sha256:head" {
t.Fatalf("Docker-Content-Digest = %q", got)
}
if body := w.Body.String(); body != "" {
t.Fatalf("HEAD body = %q, want empty", body)
}
}
func TestProxyDockerRegistryStreamsBlobAndSkipsHopByHopHeaders(t *testing.T) {
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Connection", "close")
w.Header().Set("Transfer-Encoding", "chunked")
w.Header().Set("Content-Type", "application/octet-stream")
_, _ = io.WriteString(w, "layer-data")
}))
defer upstream.Close()
initDockerProxyTest(t, `
[registries."test.local"]
upstream = "`+upstream.URL+`"
authHost = "https://auth.test.local/token"
authType = "anonymous"
enabled = true
`)
gin.SetMode(gin.TestMode)
router := gin.New()
router.Any("/v2/*path", ProxyDockerRegistryGin)
req := httptest.NewRequest(http.MethodGet, "/v2/test.local/team/app/blobs/sha256:abc", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("status = %d, want 200; body=%s", w.Code, w.Body.String())
}
if got := w.Body.String(); got != "layer-data" {
t.Fatalf("body = %q", got)
}
if got := w.Header().Get("Connection"); got != "" {
t.Fatalf("Connection header leaked: %q", got)
}
}
func TestProxyDockerRegistryUsesNsQueryForContainerd(t *testing.T) {
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/v2/team/app/manifests/latest" {
t.Fatalf("upstream path = %q", r.URL.Path)
}
if got := r.URL.Query().Get("ns"); got != "test.local" {
t.Fatalf("ns query = %q", got)
}
w.WriteHeader(http.StatusOK)
}))
defer upstream.Close()
initDockerProxyTest(t, `
[registries."test.local"]
upstream = "`+upstream.URL+`"
authHost = "https://auth.test.local/token"
authType = "anonymous"
enabled = true
`)
gin.SetMode(gin.TestMode)
router := gin.New()
router.Any("/v2/*path", ProxyDockerRegistryGin)
req := httptest.NewRequest(http.MethodGet, "/v2/team/app/manifests/latest?ns=test.local", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("status = %d, want 200; body=%s", w.Code, w.Body.String())
}
}
func TestProxyDockerAuthCachesOnlyAnonymousTokenRequests(t *testing.T) {
var hits int32
authServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
count := atomic.AddInt32(&hits, 1)
w.Header().Set("Content-Type", "application/json")
_, _ = fmt.Fprintf(w, `{"token":"token-%d","expires_in":3600}`, count)
}))
defer authServer.Close()
initDockerProxyTest(t, `
[registries."test.local"]
upstream = "https://test.local"
authHost = "`+authServer.URL+`"
authType = "anonymous"
enabled = true
`)
utils.GlobalCache = &utils.UniversalCache{}
gin.SetMode(gin.TestMode)
router := gin.New()
router.Any("/token/*path", ProxyDockerAuthGin)
for i := 0; i < 2; i++ {
req := httptest.NewRequest(http.MethodGet, "/token/test.local?scope=repository:team/app:pull", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("anonymous request %d status = %d; body=%s", i, w.Code, w.Body.String())
}
if got := w.Body.String(); !strings.Contains(got, `"token":"token-1"`) {
t.Fatalf("anonymous request %d body = %q", i, got)
}
}
if got := atomic.LoadInt32(&hits); got != 1 {
t.Fatalf("anonymous token hits = %d, want 1", got)
}
for i := 0; i < 2; i++ {
req := httptest.NewRequest(http.MethodGet, "/token/test.local?scope=repository:team/app:pull", nil)
req.Header.Set("Authorization", "Basic dXNlcjpwYXNz")
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("authenticated request %d status = %d; body=%s", i, w.Code, w.Body.String())
}
}
if got := atomic.LoadInt32(&hits); got != 3 {
t.Fatalf("authenticated token hits total = %d, want 3", got)
}
}
func TestProxyDockerAuthRejectsUnknownRegistry(t *testing.T) {
initDockerProxyTest(t, "")
gin.SetMode(gin.TestMode)
router := gin.New()
router.Any("/token/*path", ProxyDockerAuthGin)
req := httptest.NewRequest(http.MethodGet, "/token/missing.local?scope=repository:team/app:pull", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusBadRequest {
t.Fatalf("status = %d, want 400; body=%s", w.Code, w.Body.String())
}
}
func TestProxyDockerRegistryConcurrentRequests(t *testing.T) {
const requests = 64
var hits int32
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
atomic.AddInt32(&hits, 1)
if got := r.Header.Get("Authorization"); got == "" {
t.Fatal("missing Authorization")
}
w.Header().Set("Content-Type", "application/octet-stream")
_, _ = w.Write([]byte("ok"))
}))
defer upstream.Close()
initDockerProxyTest(t, `
[registries."test.local"]
upstream = "`+upstream.URL+`"
authHost = "https://auth.test.local/token"
authType = "anonymous"
enabled = true
`)
gin.SetMode(gin.TestMode)
router := gin.New()
router.Any("/v2/*path", ProxyDockerRegistryGin)
var wg sync.WaitGroup
errs := make(chan string, requests)
for i := 0; i < requests; i++ {
wg.Add(1)
go func(i int) {
defer wg.Done()
req := httptest.NewRequest(http.MethodGet, "/v2/test.local/team/app/blobs/sha256:abc", nil)
req.Header.Set("Authorization", fmt.Sprintf("Bearer token-%d", i))
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusOK || w.Body.String() != "ok" {
errs <- fmt.Sprintf("request %d status=%d body=%q", i, w.Code, w.Body.String())
}
}(i)
}
wg.Wait()
close(errs)
for err := range errs {
t.Fatal(err)
}
if got := atomic.LoadInt32(&hits); got != requests {
t.Fatalf("hits = %d, want %d", got, requests)
}
}
func TestProxyDockerRegistryLargeBlobStreamsWithoutRecorderBuffer(t *testing.T) {
const blobSize = 8 << 20
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/octet-stream")
w.Header().Set("Content-Length", fmt.Sprintf("%d", blobSize))
_, _ = io.CopyN(w, zeroReader{}, blobSize)
}))
defer upstream.Close()
initDockerProxyTest(t, `
[registries."test.local"]
upstream = "`+upstream.URL+`"
authHost = "https://auth.test.local/token"
authType = "anonymous"
enabled = true
`)
gin.SetMode(gin.TestMode)
router := gin.New()
router.Any("/v2/*path", ProxyDockerRegistryGin)
req := httptest.NewRequest(http.MethodGet, "/v2/test.local/team/app/blobs/sha256:large", nil)
w := newDiscardResponseWriter()
router.ServeHTTP(w, req)
if w.status != http.StatusOK {
t.Fatalf("status = %d, want 200", w.status)
}
if w.bytes != blobSize {
t.Fatalf("streamed bytes = %d, want %d", w.bytes, blobSize)
}
if got := w.Header().Get("Content-Length"); got != fmt.Sprintf("%d", blobSize) {
t.Fatalf("Content-Length = %q", got)
}
}
func BenchmarkProxyDockerRegistryBlobStreaming(b *testing.B) {
const blobSize = 1 << 20
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/octet-stream")
w.Header().Set("Content-Length", fmt.Sprintf("%d", blobSize))
_, _ = io.CopyN(w, zeroReader{}, blobSize)
}))
defer upstream.Close()
initDockerProxyTest(b, `
[registries."bench.local"]
upstream = "`+upstream.URL+`"
authHost = "https://auth.bench.local/token"
authType = "anonymous"
enabled = true
`)
gin.SetMode(gin.TestMode)
router := gin.New()
router.Any("/v2/*path", ProxyDockerRegistryGin)
b.ReportAllocs()
b.SetBytes(blobSize)
b.ResetTimer()
for i := 0; i < b.N; i++ {
req := httptest.NewRequest(http.MethodGet, "/v2/bench.local/team/app/blobs/sha256:bench", nil)
w := newDiscardResponseWriter()
router.ServeHTTP(w, req)
if w.status != http.StatusOK || w.bytes != blobSize {
b.Fatalf("status=%d bytes=%d", w.status, w.bytes)
}
}
}