Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DO NOT MERGE] Support multiple replicas for model adapter #205

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
502 changes: 295 additions & 207 deletions pkg/controller/modeladapter/modeladapter_controller.go

Large diffs are not rendered by default.

11 changes: 6 additions & 5 deletions pkg/controller/modeladapter/resources.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,24 +22,25 @@ import (
discoveryv1 "k8s.io/api/discovery/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/util/intstr"
"k8s.io/utils/ptr"
)

func buildModelAdapterEndpointSlice(instance *modelv1alpha1.ModelAdapter, pod *corev1.Pod) (*discoveryv1.EndpointSlice, error) {
func buildModelAdapterEndpointSlice(instance *modelv1alpha1.ModelAdapter, pods []*corev1.Pod) (*discoveryv1.EndpointSlice, error) {
serviceLabels := map[string]string{
"kubernetes.io/service-name": instance.Name,
}

addresses := []discoveryv1.Endpoint{
{
Addresses: []string{pod.Status.PodIP},
Addresses: ExtractPodIPs(pods),
},
}

ports := []discoveryv1.EndpointPort{
{
Name: stringPtr("http"),
Protocol: protocolPtr(corev1.ProtocolTCP),
Port: int32Ptr(8000),
Name: ptr.To("http"),
Protocol: ptr.To(corev1.ProtocolTCP),
Port: ptr.To(int32(8000)),
},
}

Expand Down
10 changes: 6 additions & 4 deletions pkg/controller/modeladapter/resources_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,16 @@ func TestBuildModelAdapterEndpointSlice(t *testing.T) {
}

// Mock input for Pod
pod := &corev1.Pod{
Status: corev1.PodStatus{
PodIP: "192.168.1.1",
pods := []*corev1.Pod{
{
Status: corev1.PodStatus{
PodIP: "192.168.1.1",
},
},
}

// Call the function to test
endpointSlice, err := buildModelAdapterEndpointSlice(instance, pod)
endpointSlice, err := buildModelAdapterEndpointSlice(instance, pods)

// Assert no errors
assert.NoError(t, err)
Expand Down
6 changes: 3 additions & 3 deletions pkg/controller/modeladapter/scheduling/leastadapters.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ func NewLeastAdapters(c *cache.Cache) Scheduler {
}
}

func (r leastAdapters) SelectPod(ctx context.Context, pods []v1.Pod) (*v1.Pod, error) {
selectedPod := v1.Pod{}
func (r leastAdapters) SelectPod(ctx context.Context, pods []*v1.Pod) (*v1.Pod, error) {
selectedPod := &v1.Pod{}
modelAdapterCountMin := math.MaxInt

for _, pod := range pods {
Expand All @@ -51,5 +51,5 @@ func (r leastAdapters) SelectPod(ctx context.Context, pods []v1.Pod) (*v1.Pod, e
}

klog.Infof("pod selected with least model adapters: %s", selectedPod.Name)
return &selectedPod, nil
return selectedPod, nil
}
2 changes: 1 addition & 1 deletion pkg/controller/modeladapter/scheduling/scheduler.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ import (

type Scheduler interface {
// SelectPod returns the pod to schedule model adapter
SelectPod(ctx context.Context, pods []v1.Pod) (*v1.Pod, error)
SelectPod(ctx context.Context, pods []*v1.Pod) (*v1.Pod, error)
}

// NewScheduler leverages the factory method to choose the right scheduler
Expand Down
90 changes: 72 additions & 18 deletions pkg/controller/modeladapter/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,31 +19,16 @@ package modeladapter
import (
"errors"
"fmt"

"net/url"
"os"
"strings"

corev1 "k8s.io/api/core/v1"

modelv1alpha1 "github.com/aibrix/aibrix/api/model/v1alpha1"
corev1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
)

func stringPtr(s string) *string {
return &s
}

func protocolPtr(p corev1.Protocol) *corev1.Protocol {
return &p
}

func int32Ptr(i int32) *int32 {
return &i
}

func mapPtr(m map[string]string) *map[string]string {
return &m
}

func validateModelAdapter(instance *modelv1alpha1.ModelAdapter) error {
if instance.Spec.ArtifactURL == "" {
return fmt.Errorf("artifactURL is required")
Expand Down Expand Up @@ -127,3 +112,72 @@ func extractHuggingFacePath(artifactURL string) (string, error) {

return path, nil
}

// ExtractPodNames takes a list of Pods and returns a list of their names.
func ExtractPodNames(pods []*corev1.Pod) []string {
podNames := make([]string, len(pods))
for i, pod := range pods {
podNames[i] = pod.Name
}
return podNames
}

// ExtractPodIPs takes a list of Pods and returns a list of their ips.
func ExtractPodIPs(pods []*corev1.Pod) []string {
podIPs := make([]string, len(pods))
for i, pod := range pods {
podIPs[i] = pod.Status.PodIP
}
return podIPs
}

func getCandidatePods(backendPods []corev1.Pod, podsWithModelAdapter []*corev1.Pod) []*corev1.Pod {
// Step 1: Create a set of pod names from podsWithModelAdapter
modelAdapterPodNames := make(map[string]bool)
for _, pod := range podsWithModelAdapter {
modelAdapterPodNames[pod.Name] = true
}

// Step 2: Iterate through backendPods and find pods that are not in podsWithModelAdapter
var candidatePods []*corev1.Pod
for _, pod := range backendPods {
// Create a copy of the loop variable to avoid exporting the same pointer
podCopy := pod.DeepCopy()
if !modelAdapterPodNames[pod.Name] {
// Pod is not in podsWithModelAdapter, add to candidatePods
candidatePods = append(candidatePods, podCopy)
}
}

return candidatePods
}

// Difference returns the elements in `a` that are not in `b`
func Difference(a, b []string) []string {
// Create a set (map) to store elements of b for quick lookup
bSet := make(map[string]struct{}, len(b))
for _, item := range b {
bSet[item] = struct{}{}
}

// Iterate over a and keep only elements not in bSet
var diff []string
for _, item := range a {
if _, found := bSet[item]; !found {
diff = append(diff, item)
}
}

return diff
}

// NewCondition creates a new replicaset condition.
func NewCondition(condType string, status metav1.ConditionStatus, reason, msg string) metav1.Condition {
return metav1.Condition{
Type: condType,
Status: status,
LastTransitionTime: metav1.Now(),
Reason: reason,
Message: msg,
}
}
3 changes: 2 additions & 1 deletion pkg/controller/modeladapter/utils_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (

"github.com/stretchr/testify/assert"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/utils/ptr"

modelv1alpha1 "github.com/aibrix/aibrix/api/model/v1alpha1"
)
Expand All @@ -36,7 +37,7 @@ func TestValidateModelAdapter(t *testing.T) {
PodSelector: &metav1.LabelSelector{
MatchLabels: map[string]string{"app": "test"},
},
Replicas: int32Ptr(1),
Replicas: ptr.To(int32(1)),
},
}

Expand Down
Loading