Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 114 additions & 6 deletions cmd/cli/commands/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -155,16 +155,26 @@ func ociRegistry(t *testing.T, ctx context.Context, net *testcontainers.DockerNe
return registryURL
}

func dockerModelRunner(t *testing.T, ctx context.Context, net *testcontainers.DockerNetwork) string {
// dmrConfig holds configuration options for Docker Model Runner container.
type dmrConfig struct {
envVars map[string]string // Optional environment variables to set
logMsg string // Custom log message (defaults to "Starting DMR container...")
}

// startDockerModelRunner starts a DMR container with the given configuration.
// If config.envVars is nil or empty, no extra environment variables are set.
func startDockerModelRunner(t *testing.T, ctx context.Context, net *testcontainers.DockerNetwork, config dmrConfig) string {
containerCustomizerOpts := []testcontainers.ContainerCustomizer{
testcontainers.WithExposedPorts("12434/tcp"),
testcontainers.WithWaitStrategy(wait.ForHTTP("/engines/status").WithPort("12434/tcp").WithStartupTimeout(10 * time.Second)),
testcontainers.WithEnv(map[string]string{
"DEFAULT_REGISTRY": "registry.local:5000",
"INSECURE_REGISTRY": "true",
}),
network.WithNetwork([]string{"dmr"}, net),
}

// Add environment variables if provided
if len(config.envVars) > 0 {
containerCustomizerOpts = append(containerCustomizerOpts, testcontainers.WithEnv(config.envVars))
}

if os.Getenv("BUILD_DMR") == "1" {
t.Log("Building DMR container...")
out, err := exec.CommandContext(ctx, "make", "-C", "../../..", "docker-build").CombinedOutput()
Expand All @@ -175,7 +185,13 @@ func dockerModelRunner(t *testing.T, ctx context.Context, net *testcontainers.Do
// Always pull the image if it's not build locally.
containerCustomizerOpts = append(containerCustomizerOpts, testcontainers.WithAlwaysPull())
}
t.Log("Starting DMR container...")

logMsg := config.logMsg
if logMsg == "" {
logMsg = "Starting DMR container..."
}
t.Log(logMsg)

ctr, err := testcontainers.Run(
ctx, "docker/model-runner:latest",
containerCustomizerOpts...,
Expand All @@ -191,6 +207,17 @@ func dockerModelRunner(t *testing.T, ctx context.Context, net *testcontainers.Do
return dmrURL
}

// dockerModelRunner starts a DMR container configured for local registry tests.
// Sets DEFAULT_REGISTRY and INSECURE_REGISTRY environment variables.
func dockerModelRunner(t *testing.T, ctx context.Context, net *testcontainers.DockerNetwork) string {
return startDockerModelRunner(t, ctx, net, dmrConfig{
envVars: map[string]string{
"DEFAULT_REGISTRY": "registry.local:5000",
"INSECURE_REGISTRY": "true",
},
})
}

// removeModel removes a model from the local store
func removeModel(client *desktop.Client, modelID string, force bool) error {
_, err := client.Remove([]string{modelID}, force)
Expand Down Expand Up @@ -1125,6 +1152,87 @@ func int32ptr(n int32) *int32 {
return &n
}

// setupDockerHubTestEnv creates a test environment for Docker Hub tests.
// Unlike setupTestEnv, this does NOT set DEFAULT_REGISTRY, so it uses
// the real Docker Hub (index.docker.io) as the default registry.
// This is used to test that pulling from Docker Hub works correctly.
func setupDockerHubTestEnv(t *testing.T) *testEnv {
ctx := context.Background()

// Create a custom network for container communication
net, err := network.New(ctx)
require.NoError(t, err)
testcontainers.CleanupNetwork(t, net)

// dockerModelRunnerForDockerHub starts a DMR container configured for Docker Hub tests.
// it uses the real Docker Hub as the default registry.
dmrURL := startDockerModelRunner(t, ctx, net, dmrConfig{
logMsg: "Starting DMR container for Docker Hub tests (no DEFAULT_REGISTRY)...",
})

modelRunnerCtx, err := desktop.NewContextForTest(dmrURL, nil, types.ModelRunnerEngineKindMoby)
require.NoError(t, err, "Failed to create model runner context")

client := desktop.New(modelRunnerCtx)
if !client.Status().Running {
t.Fatal("DMR is not running")
}

return &testEnv{
ctx: ctx,
client: client,
net: net,
}
}

// TestIntegration_PullFromDockerHub is a smoke test that pulls a real model
// from Docker Hub to verify that the OCI registry code works correctly
// with the real Docker Hub registry (index.docker.io -> registry-1.docker.io).
//
// This test catches regressions where the code doesn't properly handle
// Docker Hub's hostname remapping requirements.
func TestIntegration_PullFromDockerHub(t *testing.T) {
env := setupDockerHubTestEnv(t)

// Ensure no models exist initially
models, err := listModels(false, env.client, true, false, "")
require.NoError(t, err)
if len(models) != 0 {
t.Fatal("Expected no initial models, but found some")
}

// Pull a small model from Docker Hub
// ai/smollm2:135M-Q4_0 is a small model that's quick to download
modelRef := "ai/smollm2:135M-Q4_0"
t.Logf("Pulling model from Docker Hub: %s", modelRef)

err = pullModel(newPullCmd(), env.client, modelRef)
require.NoError(t, err, "Failed to pull model from Docker Hub: %s", modelRef)

// Verify the model was pulled
t.Log("Verifying model was pulled successfully")
models, err = listModels(false, env.client, true, false, "")
require.NoError(t, err)
require.NotEmpty(t, strings.TrimSpace(models), "Model should exist after pull from Docker Hub")

// Verify we can inspect the model
model, err := env.client.Inspect(modelRef, false)
require.NoError(t, err, "Failed to inspect model pulled from Docker Hub")
require.NotEmpty(t, model.ID, "Model ID should not be empty")

t.Logf("✓ Successfully pulled model from Docker Hub: %s (ID: %s)", modelRef, model.ID[7:19])

// Cleanup: remove the model
t.Logf("Cleaning up: removing model %s", model.ID[7:19])
err = removeModel(env.client, model.ID, true)
require.NoError(t, err, "Failed to remove model")

// Verify model was removed
models, err = listModels(false, env.client, true, false, "")
require.NoError(t, err)
require.Empty(t, strings.TrimSpace(models), "Model should be removed after cleanup")
}

// normalizeRef normalizes a reference to its fully qualified form.
// This is used in tests to compare against the stored tags which are always normalized.
func normalizeRef(t *testing.T, ref string) string {
Expand Down
2 changes: 2 additions & 0 deletions go.work.sum
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ cloud.google.com/go/compute v1.23.4 h1:EBT9Nw4q3zyE7G45Wvv3MzolIrCJEuHys5muLY0wv
cloud.google.com/go/compute v1.23.4/go.mod h1:/EJMj55asU6kAFnuZET8zqgwgJ9FvXWXOkkfQZa4ioI=
cloud.google.com/go/compute/metadata v0.3.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k=
cloud.google.com/go/compute/metadata v0.6.0/go.mod h1:FjyFAW1MW0C203CEOMDTu3Dk1FlqW3Rga40jzHL4hfg=
cloud.google.com/go/compute/metadata v0.7.0 h1:PBWF+iiAerVNe8UCHxdOt6eHLVc3ydFeOCw78U8ytSU=
cloud.google.com/go/compute/metadata v0.7.0/go.mod h1:j5MvL9PprKL39t166CoB1uVHfQMs4tFQZZcKwksXUjo=
codeberg.org/go-fonts/dejavu v0.4.0 h1:2yn58Vkh4CFK3ipacWUAIE3XVBGNa0y1bc95Bmfx91I=
codeberg.org/go-fonts/dejavu v0.4.0/go.mod h1:abni088lmhQJvso2Lsb7azCKzwkfcnttl6tL1UTWKzg=
Expand Down Expand Up @@ -777,6 +778,7 @@ golang.org/x/oauth2 v0.26.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbht
golang.org/x/oauth2 v0.27.0/go.mod h1:onh5ek6nERTohokkhCD/y2cV4Do3fxFHFuAejCkRWT8=
golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKlU=
golang.org/x/oauth2 v0.31.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA=
golang.org/x/oauth2 v0.32.0 h1:jsCblLleRMDrxMN29H3z/k1KliIvpLgCkE6R8FXXNgY=
golang.org/x/oauth2 v0.32.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA=
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
Expand Down
31 changes: 26 additions & 5 deletions pkg/distribution/oci/remote/remote.go
Original file line number Diff line number Diff line change
Expand Up @@ -330,17 +330,31 @@ func isManifestMediaType(mediaType string) bool {
return false
}

// isHuggingFaceRegistry returns true if the host is a HuggingFace registry.
// HuggingFace doesn't serve manifests via /blobs/ endpoint, only via /manifests/.
func isHuggingFaceRegistry(host string) bool {
return strings.Contains(host, "huggingface.co") || strings.Contains(host, "hf.co")
}

// Fetch fetches content by descriptor. For manifests, it uses /manifests/ endpoint
// to support registries like HuggingFace that don't serve manifests via /blobs/.
// For HuggingFace, we try /manifests/ first for ALL content types since they don't
// serve any manifest-like content via /blobs/.
func (f *manifestFetcher) Fetch(ctx context.Context, desc v1.Descriptor) (io.ReadCloser, error) {
// For non-manifest content, use the underlying fetcher
if !isManifestMediaType(desc.MediaType) {
registry := f.ref.Context().Registry
isHF := isHuggingFaceRegistry(registry.RegistryStr())

// For HuggingFace, try /manifests/ first for any JSON-like content
// since they don't serve manifests via /blobs/ at all
shouldUseManifestEndpoint := isHF && (desc.MediaType == "application/json" || strings.Contains(desc.MediaType, "+json"))

// For non-manifest content on non-HF registries, use the underlying fetcher
if !shouldUseManifestEndpoint {
return f.underlying.Fetch(ctx, desc)
}

// For manifests, fetch via /manifests/ endpoint to support HuggingFace
// Build the manifest URL: /v2/<repo>/manifests/<digest>
registry := f.ref.Context().Registry
// Build the manifest URL: /v2/<repo>/manifests/<reference>
repo := f.ref.Context().RepositoryStr()

// Determine scheme based on plainHTTP flag or registry's default scheme
Expand All @@ -349,11 +363,18 @@ func (f *manifestFetcher) Fetch(ctx context.Context, desc v1.Descriptor) (io.Rea
scheme = "http"
}

// For HuggingFace, use tag instead of digest because HF doesn't support
// fetching manifests by digest, only by tag
manifestRef := f.ref.Identifier()
if manifestRef == "" {
manifestRef = "latest"
}

url := fmt.Sprintf("%s://%s/v2/%s/manifests/%s",
scheme,
registry.RegistryStr(),
repo,
desc.Digest.String())
manifestRef)

req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, http.NoBody)
if err != nil {
Expand Down
Loading