diff --git a/cmd/cli/commands/integration_test.go b/cmd/cli/commands/integration_test.go index d2640461..437ad33e 100644 --- a/cmd/cli/commands/integration_test.go +++ b/cmd/cli/commands/integration_test.go @@ -155,16 +155,26 @@ func ociRegistry(t *testing.T, ctx context.Context, net *testcontainers.DockerNe return registryURL } -func dockerModelRunner(t *testing.T, ctx context.Context, net *testcontainers.DockerNetwork) string { +// dmrConfig holds configuration options for Docker Model Runner container. +type dmrConfig struct { + envVars map[string]string // Optional environment variables to set + logMsg string // Custom log message (defaults to "Starting DMR container...") +} + +// startDockerModelRunner starts a DMR container with the given configuration. +// If config.envVars is nil or empty, no extra environment variables are set. +func startDockerModelRunner(t *testing.T, ctx context.Context, net *testcontainers.DockerNetwork, config dmrConfig) string { containerCustomizerOpts := []testcontainers.ContainerCustomizer{ testcontainers.WithExposedPorts("12434/tcp"), testcontainers.WithWaitStrategy(wait.ForHTTP("/engines/status").WithPort("12434/tcp").WithStartupTimeout(10 * time.Second)), - testcontainers.WithEnv(map[string]string{ - "DEFAULT_REGISTRY": "registry.local:5000", - "INSECURE_REGISTRY": "true", - }), network.WithNetwork([]string{"dmr"}, net), } + + // Add environment variables if provided + if len(config.envVars) > 0 { + containerCustomizerOpts = append(containerCustomizerOpts, testcontainers.WithEnv(config.envVars)) + } + if os.Getenv("BUILD_DMR") == "1" { t.Log("Building DMR container...") out, err := exec.CommandContext(ctx, "make", "-C", "../../..", "docker-build").CombinedOutput() @@ -175,7 +185,13 @@ func dockerModelRunner(t *testing.T, ctx context.Context, net *testcontainers.Do // Always pull the image if it's not build locally. containerCustomizerOpts = append(containerCustomizerOpts, testcontainers.WithAlwaysPull()) } - t.Log("Starting DMR container...") + + logMsg := config.logMsg + if logMsg == "" { + logMsg = "Starting DMR container..." + } + t.Log(logMsg) + ctr, err := testcontainers.Run( ctx, "docker/model-runner:latest", containerCustomizerOpts..., @@ -191,6 +207,17 @@ func dockerModelRunner(t *testing.T, ctx context.Context, net *testcontainers.Do return dmrURL } +// dockerModelRunner starts a DMR container configured for local registry tests. +// Sets DEFAULT_REGISTRY and INSECURE_REGISTRY environment variables. +func dockerModelRunner(t *testing.T, ctx context.Context, net *testcontainers.DockerNetwork) string { + return startDockerModelRunner(t, ctx, net, dmrConfig{ + envVars: map[string]string{ + "DEFAULT_REGISTRY": "registry.local:5000", + "INSECURE_REGISTRY": "true", + }, + }) +} + // removeModel removes a model from the local store func removeModel(client *desktop.Client, modelID string, force bool) error { _, err := client.Remove([]string{modelID}, force) @@ -1125,6 +1152,87 @@ func int32ptr(n int32) *int32 { return &n } +// setupDockerHubTestEnv creates a test environment for Docker Hub tests. +// Unlike setupTestEnv, this does NOT set DEFAULT_REGISTRY, so it uses +// the real Docker Hub (index.docker.io) as the default registry. +// This is used to test that pulling from Docker Hub works correctly. +func setupDockerHubTestEnv(t *testing.T) *testEnv { + ctx := context.Background() + + // Create a custom network for container communication + net, err := network.New(ctx) + require.NoError(t, err) + testcontainers.CleanupNetwork(t, net) + + // dockerModelRunnerForDockerHub starts a DMR container configured for Docker Hub tests. + // it uses the real Docker Hub as the default registry. + dmrURL := startDockerModelRunner(t, ctx, net, dmrConfig{ + logMsg: "Starting DMR container for Docker Hub tests (no DEFAULT_REGISTRY)...", + }) + + modelRunnerCtx, err := desktop.NewContextForTest(dmrURL, nil, types.ModelRunnerEngineKindMoby) + require.NoError(t, err, "Failed to create model runner context") + + client := desktop.New(modelRunnerCtx) + if !client.Status().Running { + t.Fatal("DMR is not running") + } + + return &testEnv{ + ctx: ctx, + client: client, + net: net, + } +} + +// TestIntegration_PullFromDockerHub is a smoke test that pulls a real model +// from Docker Hub to verify that the OCI registry code works correctly +// with the real Docker Hub registry (index.docker.io -> registry-1.docker.io). +// +// This test catches regressions where the code doesn't properly handle +// Docker Hub's hostname remapping requirements. +func TestIntegration_PullFromDockerHub(t *testing.T) { + env := setupDockerHubTestEnv(t) + + // Ensure no models exist initially + models, err := listModels(false, env.client, true, false, "") + require.NoError(t, err) + if len(models) != 0 { + t.Fatal("Expected no initial models, but found some") + } + + // Pull a small model from Docker Hub + // ai/smollm2:135M-Q4_0 is a small model that's quick to download + modelRef := "ai/smollm2:135M-Q4_0" + t.Logf("Pulling model from Docker Hub: %s", modelRef) + + err = pullModel(newPullCmd(), env.client, modelRef) + require.NoError(t, err, "Failed to pull model from Docker Hub: %s", modelRef) + + // Verify the model was pulled + t.Log("Verifying model was pulled successfully") + models, err = listModels(false, env.client, true, false, "") + require.NoError(t, err) + require.NotEmpty(t, strings.TrimSpace(models), "Model should exist after pull from Docker Hub") + + // Verify we can inspect the model + model, err := env.client.Inspect(modelRef, false) + require.NoError(t, err, "Failed to inspect model pulled from Docker Hub") + require.NotEmpty(t, model.ID, "Model ID should not be empty") + + t.Logf("✓ Successfully pulled model from Docker Hub: %s (ID: %s)", modelRef, model.ID[7:19]) + + // Cleanup: remove the model + t.Logf("Cleaning up: removing model %s", model.ID[7:19]) + err = removeModel(env.client, model.ID, true) + require.NoError(t, err, "Failed to remove model") + + // Verify model was removed + models, err = listModels(false, env.client, true, false, "") + require.NoError(t, err) + require.Empty(t, strings.TrimSpace(models), "Model should be removed after cleanup") +} + // normalizeRef normalizes a reference to its fully qualified form. // This is used in tests to compare against the stored tags which are always normalized. func normalizeRef(t *testing.T, ref string) string { diff --git a/go.work.sum b/go.work.sum index 2854cdf8..8cc55c05 100644 --- a/go.work.sum +++ b/go.work.sum @@ -10,6 +10,7 @@ cloud.google.com/go/compute v1.23.4 h1:EBT9Nw4q3zyE7G45Wvv3MzolIrCJEuHys5muLY0wv cloud.google.com/go/compute v1.23.4/go.mod h1:/EJMj55asU6kAFnuZET8zqgwgJ9FvXWXOkkfQZa4ioI= cloud.google.com/go/compute/metadata v0.3.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k= cloud.google.com/go/compute/metadata v0.6.0/go.mod h1:FjyFAW1MW0C203CEOMDTu3Dk1FlqW3Rga40jzHL4hfg= +cloud.google.com/go/compute/metadata v0.7.0 h1:PBWF+iiAerVNe8UCHxdOt6eHLVc3ydFeOCw78U8ytSU= cloud.google.com/go/compute/metadata v0.7.0/go.mod h1:j5MvL9PprKL39t166CoB1uVHfQMs4tFQZZcKwksXUjo= codeberg.org/go-fonts/dejavu v0.4.0 h1:2yn58Vkh4CFK3ipacWUAIE3XVBGNa0y1bc95Bmfx91I= codeberg.org/go-fonts/dejavu v0.4.0/go.mod h1:abni088lmhQJvso2Lsb7azCKzwkfcnttl6tL1UTWKzg= @@ -777,6 +778,7 @@ golang.org/x/oauth2 v0.26.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbht golang.org/x/oauth2 v0.27.0/go.mod h1:onh5ek6nERTohokkhCD/y2cV4Do3fxFHFuAejCkRWT8= golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKlU= golang.org/x/oauth2 v0.31.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA= +golang.org/x/oauth2 v0.32.0 h1:jsCblLleRMDrxMN29H3z/k1KliIvpLgCkE6R8FXXNgY= golang.org/x/oauth2 v0.32.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= diff --git a/pkg/distribution/oci/remote/remote.go b/pkg/distribution/oci/remote/remote.go index 89c6ad96..db6ec4b5 100644 --- a/pkg/distribution/oci/remote/remote.go +++ b/pkg/distribution/oci/remote/remote.go @@ -330,17 +330,31 @@ func isManifestMediaType(mediaType string) bool { return false } +// isHuggingFaceRegistry returns true if the host is a HuggingFace registry. +// HuggingFace doesn't serve manifests via /blobs/ endpoint, only via /manifests/. +func isHuggingFaceRegistry(host string) bool { + return strings.Contains(host, "huggingface.co") || strings.Contains(host, "hf.co") +} + // Fetch fetches content by descriptor. For manifests, it uses /manifests/ endpoint // to support registries like HuggingFace that don't serve manifests via /blobs/. +// For HuggingFace, we try /manifests/ first for ALL content types since they don't +// serve any manifest-like content via /blobs/. func (f *manifestFetcher) Fetch(ctx context.Context, desc v1.Descriptor) (io.ReadCloser, error) { - // For non-manifest content, use the underlying fetcher - if !isManifestMediaType(desc.MediaType) { + registry := f.ref.Context().Registry + isHF := isHuggingFaceRegistry(registry.RegistryStr()) + + // For HuggingFace, try /manifests/ first for any JSON-like content + // since they don't serve manifests via /blobs/ at all + shouldUseManifestEndpoint := isHF && (desc.MediaType == "application/json" || strings.Contains(desc.MediaType, "+json")) + + // For non-manifest content on non-HF registries, use the underlying fetcher + if !shouldUseManifestEndpoint { return f.underlying.Fetch(ctx, desc) } // For manifests, fetch via /manifests/ endpoint to support HuggingFace - // Build the manifest URL: /v2//manifests/ - registry := f.ref.Context().Registry + // Build the manifest URL: /v2//manifests/ repo := f.ref.Context().RepositoryStr() // Determine scheme based on plainHTTP flag or registry's default scheme @@ -349,11 +363,18 @@ func (f *manifestFetcher) Fetch(ctx context.Context, desc v1.Descriptor) (io.Rea scheme = "http" } + // For HuggingFace, use tag instead of digest because HF doesn't support + // fetching manifests by digest, only by tag + manifestRef := f.ref.Identifier() + if manifestRef == "" { + manifestRef = "latest" + } + url := fmt.Sprintf("%s://%s/v2/%s/manifests/%s", scheme, registry.RegistryStr(), repo, - desc.Digest.String()) + manifestRef) req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, http.NoBody) if err != nil {