Skip to content

Commit

Permalink
feat(worker): replace defaultImage with overrides struct (#293)
Browse files Browse the repository at this point in the history
This commit introduces an `overrides` struct to replace the `defaultImage`. The new struct allows overriding both the default image and pipeline-specific images. This enhancement enables orchestrators and developers to specify custom images for specific pipelines, providing greater flexibility and configurability.
---------

Co-authored-by: Rick Staa <[email protected]>
Co-authored-by: Victor Elias <[email protected]>
  • Loading branch information
3 people authored Jan 28, 2025
1 parent a11302b commit 5fc1474
Show file tree
Hide file tree
Showing 3 changed files with 114 additions and 18 deletions.
33 changes: 23 additions & 10 deletions worker/docker.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,21 +56,27 @@ var containerHostPorts = map[string]string{
"live-video-to-video": "8900",
}

// Mapping for per pipeline container images.
// Default pipeline container image mapping to use if no overrides are provided.
var defaultBaseImage = "livepeer/ai-runner:latest"
var pipelineToImage = map[string]string{
"segment-anything-2": "livepeer/ai-runner:segment-anything-2",
"text-to-speech": "livepeer/ai-runner:text-to-speech",
"audio-to-text": "livepeer/ai-runner:audio-to-text",
"llm": "livepeer/ai-runner:llm",
}

var livePipelineToImage = map[string]string{
"streamdiffusion": "livepeer/ai-runner:live-app-streamdiffusion",
"comfyui": "livepeer/ai-runner:live-app-comfyui",
"segment_anything_2": "livepeer/ai-runner:live-app-segment_anything_2",
"noop": "livepeer/ai-runner:live-app-noop",
}

type ImageOverrides struct {
Default string `json:"default"`
Batch map[string]string `json:"batch"`
Live map[string]string `json:"live"`
}

// DockerClient is an interface for the Docker client, allowing for mocking in tests.
// NOTE: ensure any docker.Client methods used in this package are added.
type DockerClient interface {
Expand All @@ -91,9 +97,9 @@ var _ DockerClient = (*docker.Client)(nil)
var dockerWaitUntilRunningFunc = dockerWaitUntilRunning

type DockerManager struct {
defaultImage string
gpus []string
modelDir string
gpus []string
modelDir string
overrides ImageOverrides

dockerClient DockerClient
// gpu ID => container name
Expand All @@ -103,7 +109,7 @@ type DockerManager struct {
mu *sync.Mutex
}

func NewDockerManager(defaultImage string, gpus []string, modelDir string, client DockerClient) (*DockerManager, error) {
func NewDockerManager(overrides ImageOverrides, gpus []string, modelDir string, client DockerClient) (*DockerManager, error) {
ctx, cancel := context.WithTimeout(context.Background(), containerTimeout)
if err := removeExistingContainers(ctx, client); err != nil {
cancel()
Expand All @@ -112,9 +118,9 @@ func NewDockerManager(defaultImage string, gpus []string, modelDir string, clien
cancel()

manager := &DockerManager{
defaultImage: defaultImage,
gpus: gpus,
modelDir: modelDir,
overrides: overrides,
dockerClient: client,
gpuContainers: make(map[string]string),
containers: make(map[string]*RunnerContainer),
Expand Down Expand Up @@ -215,17 +221,24 @@ func (m *DockerManager) returnContainer(rc *RunnerContainer) {
func (m *DockerManager) getContainerImageName(pipeline, modelID string) (string, error) {
if pipeline == "live-video-to-video" {
// We currently use the model ID as the live pipeline name for legacy reasons.
if image, ok := livePipelineToImage[modelID]; ok {
if image, ok := m.overrides.Live[modelID]; ok {
return image, nil
} else if image, ok := livePipelineToImage[modelID]; ok {
return image, nil
}
return "", fmt.Errorf("no container image found for live pipeline %s", modelID)
}

if image, ok := pipelineToImage[pipeline]; ok {
if image, ok := m.overrides.Batch[pipeline]; ok {
return image, nil
} else if image, ok := pipelineToImage[pipeline]; ok {
return image, nil
}

return m.defaultImage, nil
if m.overrides.Default != "" {
return m.overrides.Default, nil
}
return defaultBaseImage, nil
}

// HasCapacity checks if an unused managed container exists or if a GPU is available for a new container.
Expand Down
95 changes: 89 additions & 6 deletions worker/docker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,9 +96,9 @@ func NewMockServer() *MockServer {
// createDockerManager creates a DockerManager with a mock DockerClient.
func createDockerManager(mockDockerClient *MockDockerClient) *DockerManager {
return &DockerManager{
defaultImage: "default-image",
gpus: []string{"gpu0"},
modelDir: "/models",
overrides: ImageOverrides{Default: "default-image"},
dockerClient: mockDockerClient,
gpuContainers: make(map[string]string),
containers: make(map[string]*RunnerContainer),
Expand All @@ -110,10 +110,10 @@ func TestNewDockerManager(t *testing.T) {
mockDockerClient := new(MockDockerClient)

createAndVerifyManager := func() *DockerManager {
manager, err := NewDockerManager("default-image", []string{"gpu0"}, "/models", mockDockerClient)
manager, err := NewDockerManager(ImageOverrides{Default: "default-image"}, []string{"gpu0"}, "/models", mockDockerClient)
require.NoError(t, err)
require.NotNil(t, manager)
require.Equal(t, "default-image", manager.defaultImage)
require.Equal(t, "default-image", manager.overrides.Default)
require.Equal(t, []string{"gpu0"}, manager.gpus)
require.Equal(t, "/models", manager.modelDir)
require.Equal(t, mockDockerClient, manager.dockerClient)
Expand Down Expand Up @@ -301,47 +301,130 @@ func TestDockerManager_returnContainer(t *testing.T) {

func TestDockerManager_getContainerImageName(t *testing.T) {
mockDockerClient := new(MockDockerClient)
manager := createDockerManager(mockDockerClient)
dockerManager := createDockerManager(mockDockerClient)

tests := []struct {
name string
setup func(*DockerManager, *MockDockerClient)
pipeline string
modelID string
expectedImage string
expectError bool
}{
{
name: "live-video-to-video with valid modelID",
setup: func(dockerManager *DockerManager, mockDockerClient *MockDockerClient) {},
pipeline: "live-video-to-video",
modelID: "streamdiffusion",
expectedImage: "livepeer/ai-runner:live-app-streamdiffusion",
expectError: false,
},
{
name: "live-video-to-video with invalid modelID",
setup: func(dockerManager *DockerManager, mockDockerClient *MockDockerClient) {},
pipeline: "live-video-to-video",
modelID: "invalid-model",
expectError: true,
},
{
name: "valid pipeline",
setup: func(dockerManager *DockerManager, mockDockerClient *MockDockerClient) {},
pipeline: "text-to-speech",
modelID: "",
expectedImage: "livepeer/ai-runner:text-to-speech",
expectError: false,
},
{
name: "invalid pipeline",
setup: func(dockerManager *DockerManager, mockDockerClient *MockDockerClient) {},
pipeline: "invalid-pipeline",
modelID: "",
expectedImage: "default-image",
expectError: false,
},
{
name: "override default image",
setup: func(dockerManager *DockerManager, mockDockerClient *MockDockerClient) {
dockerManager.overrides = ImageOverrides{
Default: "custom-image",
}
},
pipeline: "",
modelID: "",
expectedImage: "custom-image",
expectError: false,
},
{
name: "override batch image",
setup: func(dockerManager *DockerManager, mockDockerClient *MockDockerClient) {
dockerManager.overrides = ImageOverrides{
Batch: map[string]string{
"text-to-speech": "custom-image",
},
}
},
pipeline: "text-to-speech",
modelID: "",
expectedImage: "custom-image",
expectError: false,
},
{
name: "override live image",
setup: func(dockerManager *DockerManager, mockDockerClient *MockDockerClient) {
dockerManager.overrides = ImageOverrides{
Live: map[string]string{
"streamdiffusion": "custom-image",
},
}
},
pipeline: "live-video-to-video",
modelID: "streamdiffusion",
expectedImage: "custom-image",
expectError: false,
},
{
name: "non-overridden batch image",
setup: func(dockerManager *DockerManager, mockDockerClient *MockDockerClient) {
dockerManager.overrides = ImageOverrides{
Default: "default-image",
Batch: map[string]string{
"text-to-speech": "custom-batch-image",
},
Live: map[string]string{
"streamdiffusion": "custom-live-image",
},
}
},
pipeline: "audio-to-text",
modelID: "",
expectedImage: "livepeer/ai-runner:audio-to-text",
expectError: false,
},
{
name: "non-overridden live image",
setup: func(dockerManager *DockerManager, mockDockerClient *MockDockerClient) {
dockerManager.overrides = ImageOverrides{
Default: "default-image",
Batch: map[string]string{
"text-to-speech": "custom-batch-image",
},
Live: map[string]string{
"streamdiffusion": "custom-live-image",
},
}
},
pipeline: "live-video-to-video",
modelID: "comfyui",
expectedImage: "livepeer/ai-runner:live-app-comfyui",
expectError: false,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
image, err := manager.getContainerImageName(tt.pipeline, tt.modelID)
tt.setup(dockerManager, mockDockerClient)

image, err := dockerManager.getContainerImageName(tt.pipeline, tt.modelID)
if tt.expectError {
require.Error(t, err)
require.Equal(t, fmt.Sprintf("no container image found for live pipeline %s", tt.modelID), err.Error())
Expand Down Expand Up @@ -500,7 +583,7 @@ func TestDockerManager_createContainer(t *testing.T) {
dockerManager.gpus = []string{gpu}
dockerManager.gpuContainers = make(map[string]string)
dockerManager.containers = make(map[string]*RunnerContainer)
dockerManager.defaultImage = containerImage
dockerManager.overrides.Default = containerImage

mockDockerClient.On("ContainerCreate", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(container.CreateResponse{ID: containerID}, nil)
mockDockerClient.On("ContainerStart", mock.Anything, containerID, mock.Anything).Return(nil)
Expand Down
4 changes: 2 additions & 2 deletions worker/worker.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,13 +51,13 @@ type Worker struct {
mu *sync.Mutex
}

func NewWorker(defaultImage string, gpus []string, modelDir string) (*Worker, error) {
func NewWorker(imageOverrides ImageOverrides, gpus []string, modelDir string) (*Worker, error) {
dockerClient, err := docker.NewClientWithOpts(docker.FromEnv, docker.WithAPIVersionNegotiation())
if err != nil {
return nil, err
}

manager, err := NewDockerManager(defaultImage, gpus, modelDir, dockerClient)
manager, err := NewDockerManager(imageOverrides, gpus, modelDir, dockerClient)
if err != nil {
return nil, err
}
Expand Down

0 comments on commit 5fc1474

Please sign in to comment.