Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 5 additions & 6 deletions pkg/epp/backend/metrics/fake.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,13 @@ import (
"k8s.io/apimachinery/pkg/types"
"sigs.k8s.io/controller-runtime/pkg/log"

"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer"
logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
)

// FakePodMetrics is an implementation of PodMetrics that doesn't run the async refresh loop.
type FakePodMetrics struct {
Pod *backend.Pod
Metadata *datalayer.EndpointMetadata
Metrics *MetricsState
Attributes *datalayer.Attributes
}
Expand All @@ -41,16 +40,16 @@ func (fpm *FakePodMetrics) String() string {
return fmt.Sprintf("Metadata: %v; Metrics: %v", fpm.GetMetadata(), fpm.GetMetrics())
}

func (fpm *FakePodMetrics) GetMetadata() *backend.Pod {
return fpm.Pod
func (fpm *FakePodMetrics) GetMetadata() *datalayer.EndpointMetadata {
return fpm.Metadata
}

func (fpm *FakePodMetrics) GetMetrics() *MetricsState {
return fpm.Metrics
}

func (fpm *FakePodMetrics) UpdateMetadata(metadata *datalayer.EndpointMetadata) {
fpm.Pod = metadata
fpm.Metadata = metadata
}
func (fpm *FakePodMetrics) GetAttributes() *datalayer.Attributes {
return fpm.Attributes
Expand All @@ -72,7 +71,7 @@ type FakePodMetricsClient struct {
Res map[types.NamespacedName]*MetricsState
}

func (f *FakePodMetricsClient) FetchMetrics(ctx context.Context, pod *backend.Pod, existing *MetricsState) (*MetricsState, error) {
func (f *FakePodMetricsClient) FetchMetrics(ctx context.Context, pod *datalayer.EndpointMetadata, existing *MetricsState) (*MetricsState, error) {
f.errMu.RLock()
err, ok := f.Err[pod.NamespacedName]
f.errMu.RUnlock()
Expand Down
14 changes: 7 additions & 7 deletions pkg/epp/backend/metrics/metrics.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ import (
"github.com/prometheus/common/model"
"go.uber.org/multierr"

"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer"
)

const (
Expand All @@ -50,22 +50,22 @@ type PodMetricsClientImpl struct {
}

// FetchMetrics fetches metrics from a given pod, clones the existing metrics object and returns an updated one.
func (p *PodMetricsClientImpl) FetchMetrics(ctx context.Context, pod *backend.Pod, existing *MetricsState) (*MetricsState, error) {
url := p.getMetricEndpoint(pod)
func (p *PodMetricsClientImpl) FetchMetrics(ctx context.Context, metadata *datalayer.EndpointMetadata, existing *MetricsState) (*MetricsState, error) {
url := p.getMetricEndpoint(metadata)
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return nil, fmt.Errorf("failed to create request: %v", err)
}
resp, err := p.Client.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to fetch metrics from %s: %w", pod.NamespacedName, err)
return nil, fmt.Errorf("failed to fetch metrics from %s: %w", metadata.NamespacedName, err)
}
defer func() {
_ = resp.Body.Close()
}()

if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("unexpected status code from %s: %v", pod.NamespacedName, resp.StatusCode)
return nil, fmt.Errorf("unexpected status code from %s: %v", metadata.NamespacedName, resp.StatusCode)
}

parser := expfmt.NewTextParser(model.LegacyValidation)
Expand All @@ -76,8 +76,8 @@ func (p *PodMetricsClientImpl) FetchMetrics(ctx context.Context, pod *backend.Po
return p.promToPodMetrics(metricFamilies, existing)
}

func (p *PodMetricsClientImpl) getMetricEndpoint(pod *backend.Pod) string {
return p.ModelServerMetricsScheme + "://" + pod.GetMetricsHost() + p.ModelServerMetricsPath
func (p *PodMetricsClientImpl) getMetricEndpoint(metadata *datalayer.EndpointMetadata) string {
return p.ModelServerMetricsScheme + "://" + metadata.GetMetricsHost() + p.ModelServerMetricsPath
}

// promToPodMetrics updates internal pod metrics with scraped Prometheus metrics.
Expand Down
6 changes: 3 additions & 3 deletions pkg/epp/backend/metrics/metrics_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ import (
"google.golang.org/protobuf/proto"
"k8s.io/apimachinery/pkg/types"

"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer"
logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
)

Expand Down Expand Up @@ -577,7 +577,7 @@ func TestPromToPodMetrics(t *testing.T) {
// there's no server running on the specified port.
func TestFetchMetrics(t *testing.T) {
ctx := logutil.NewTestLoggerIntoContext(context.Background())
pod := &backend.Pod{
metadata := &datalayer.EndpointMetadata{
Address: "127.0.0.1",
Port: "9999",
MetricsHost: "127.0.0.1:9999",
Expand All @@ -594,7 +594,7 @@ func TestFetchMetrics(t *testing.T) {
Client: http.DefaultClient,
}

_, err := p.FetchMetrics(ctx, pod, existing) // Use a port that's unlikely to be in use
_, err := p.FetchMetrics(ctx, metadata, existing) // Use a port that's unlikely to be in use
if err == nil {
t.Errorf("FetchMetrics() expected error, got nil")
}
Expand Down
23 changes: 0 additions & 23 deletions pkg/epp/backend/pod.go

This file was deleted.

4 changes: 2 additions & 2 deletions pkg/epp/config/loader/configloader_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -408,14 +408,14 @@ func (m *mockPlugin) TypedName() plugins.TypedName { return m.t }
// Mock Scorer
type mockScorer struct{ mockPlugin }

func (m *mockScorer) Score(context.Context, *types.CycleState, *types.LLMRequest, []types.Pod) map[types.Pod]float64 {
func (m *mockScorer) Score(context.Context, *types.CycleState, *types.LLMRequest, []types.Endpoint) map[types.Endpoint]float64 {
return nil
}

// Mock Picker
type mockPicker struct{ mockPlugin }

func (m *mockPicker) Pick(context.Context, *types.CycleState, []*types.ScoredPod) *types.ProfileRunResult {
func (m *mockPicker) Pick(context.Context, *types.CycleState, []*types.ScoredEndpoint) *types.ProfileRunResult {
return nil
}

Expand Down
10 changes: 5 additions & 5 deletions pkg/epp/handlers/request.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,15 +40,15 @@ func (s *StreamingServer) HandleRequestHeaders(reqCtx *RequestContext, req *extP

// an EoS in the request headers means this request has no body or trailers.
if req.RequestHeaders.EndOfStream {
// We will route this request to a random pod as this is assumed to just be a GET
// We will route this request to a random endpoint as this is assumed to just be a GET
// More context: https://github.com/kubernetes-sigs/gateway-api-inference-extension/pull/526
// The above PR will address endpoint admission, but currently any request without a body will be
// routed to a random upstream pod.
pod := s.director.GetRandomPod()
if pod == nil {
// routed to a random upstream endpoint.
endpoint := s.director.GetRandomEndpoint()
if endpoint == nil {
return errutil.Error{Code: errutil.Internal, Msg: "no pods available in datastore"}
}
reqCtx.TargetEndpoint = pod.GetIPAddress() + ":" + pod.GetPort()
reqCtx.TargetEndpoint = endpoint.GetIPAddress() + ":" + endpoint.GetPort()
reqCtx.RequestSize = 0
reqCtx.reqHeaderResp = s.generateRequestHeaderResponse(reqCtx)
return nil
Expand Down
6 changes: 3 additions & 3 deletions pkg/epp/handlers/response_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import (

"github.com/google/go-cmp/cmp"

"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer"
logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
)

Expand Down Expand Up @@ -103,8 +103,8 @@ func (m *mockDirector) HandleResponseReceived(ctx context.Context, reqCtx *Reque
func (m *mockDirector) HandlePreRequest(ctx context.Context, reqCtx *RequestContext) (*RequestContext, error) {
return reqCtx, nil
}
func (m *mockDirector) GetRandomPod() *backend.Pod {
return &backend.Pod{}
func (m *mockDirector) GetRandomEndpoint() *datalayer.EndpointMetadata {
return &datalayer.EndpointMetadata{}
}
func (m *mockDirector) HandleRequest(ctx context.Context, reqCtx *RequestContext) (*RequestContext, error) {
return reqCtx, nil
Expand Down
5 changes: 2 additions & 3 deletions pkg/epp/handlers/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ import (
"google.golang.org/grpc/status"

"sigs.k8s.io/controller-runtime/pkg/log"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics"
schedulingtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
Expand All @@ -57,7 +56,7 @@ type Director interface {
HandleResponseReceived(ctx context.Context, reqCtx *RequestContext) (*RequestContext, error)
HandleResponseBodyStreaming(ctx context.Context, reqCtx *RequestContext) (*RequestContext, error)
HandleResponseBodyComplete(ctx context.Context, reqCtx *RequestContext) (*RequestContext, error)
GetRandomPod() *backend.Pod
GetRandomEndpoint() *datalayer.EndpointMetadata
}

type Datastore interface {
Expand All @@ -76,7 +75,7 @@ type StreamingServer struct {
// Specifically, there are fields related to the ext-proc protocol, and then fields related to the lifecycle of the request.
// We should split these apart as this monolithic object exposes too much data to too many layers.
type RequestContext struct {
TargetPod *backend.Pod
TargetPod *datalayer.EndpointMetadata
TargetEndpoint string
IncomingModelName string
TargetModelName string
Expand Down
4 changes: 2 additions & 2 deletions pkg/epp/requestcontrol/dag_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,8 @@ func (m *mockPrepareRequestDataP) Consumes() map[string]any {
return m.consumes
}

func (m *mockPrepareRequestDataP) PrepareRequestData(ctx context.Context, request *types.LLMRequest, pods []types.Pod) error {
pods[0].Put(mockProducedDataKey, mockProducedDataType{value: 42})
func (m *mockPrepareRequestDataP) PrepareRequestData(ctx context.Context, request *types.LLMRequest, endpoints []types.Endpoint) error {
endpoints[0].Put(mockProducedDataKey, mockProducedDataType{value: 42})
return nil
}

Expand Down
47 changes: 24 additions & 23 deletions pkg/epp/requestcontrol/director.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ import (

"sigs.k8s.io/controller-runtime/pkg/log"
"sigs.k8s.io/gateway-api-inference-extension/apix/v1alpha2"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend"
backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datastore"
Expand Down Expand Up @@ -57,7 +56,7 @@ type Datastore interface {

// Scheduler defines the interface required by the Director for scheduling.
type Scheduler interface {
Schedule(ctx context.Context, request *schedulingtypes.LLMRequest, candidatePods []schedulingtypes.Pod) (result *schedulingtypes.SchedulingResult, err error)
Schedule(ctx context.Context, request *schedulingtypes.LLMRequest, candidatePods []schedulingtypes.Endpoint) (result *schedulingtypes.SchedulingResult, err error)
}

// NewDirectorWithConfig creates a new Director instance with all dependencies.
Expand Down Expand Up @@ -245,34 +244,34 @@ func (d *Director) prepareRequest(ctx context.Context, reqCtx *handlers.RequestC
return reqCtx, errutil.Error{Code: errutil.Internal, Msg: "results must be greater than zero"}
}
// primary profile is used to set destination
targetPods := []*backend.Pod{}
targetMetadatas := []*datalayer.EndpointMetadata{}
targetEndpoints := []string{}

for _, pod := range result.ProfileResults[result.PrimaryProfileName].TargetPods {
curPod := pod.GetPod()
curEndpoint := net.JoinHostPort(curPod.GetIPAddress(), curPod.GetPort())
targetPods = append(targetPods, curPod)
for _, pod := range result.ProfileResults[result.PrimaryProfileName].TargetEndpoints {
curMetadata := pod.GetMetadata()
curEndpoint := net.JoinHostPort(curMetadata.GetIPAddress(), curMetadata.GetPort())
targetMetadatas = append(targetMetadatas, curMetadata)
targetEndpoints = append(targetEndpoints, curEndpoint)
}

multiEndpointString := strings.Join(targetEndpoints, ",")
logger.V(logutil.VERBOSE).Info("Request handled", "objectiveKey", reqCtx.ObjectiveKey, "incomingModelName", reqCtx.IncomingModelName, "targetModel", reqCtx.TargetModelName, "endpoint", multiEndpointString)

reqCtx.TargetPod = targetPods[0]
reqCtx.TargetPod = targetMetadatas[0]
reqCtx.TargetEndpoint = multiEndpointString

d.runPreRequestPlugins(ctx, reqCtx.SchedulingRequest, result)

return reqCtx, nil
}

func (d *Director) toSchedulerPodMetrics(pods []backendmetrics.PodMetrics) []schedulingtypes.Pod {
pm := make([]schedulingtypes.Pod, len(pods))
func (d *Director) toSchedulerPodMetrics(pods []backendmetrics.PodMetrics) []schedulingtypes.Endpoint {
pm := make([]schedulingtypes.Endpoint, len(pods))
for i, pod := range pods {
if pod.GetAttributes() != nil {
pm[i] = &schedulingtypes.PodMetrics{Pod: pod.GetMetadata().Clone(), MetricsState: pod.GetMetrics().Clone(), AttributeMap: pod.GetAttributes().Clone()}
pm[i] = &schedulingtypes.PodMetrics{EndpointMetadata: pod.GetMetadata().Clone(), MetricsState: pod.GetMetrics().Clone(), AttributeMap: pod.GetAttributes().Clone()}
} else {
pm[i] = &schedulingtypes.PodMetrics{Pod: pod.GetMetadata().Clone(), MetricsState: pod.GetMetrics().Clone(), AttributeMap: datalayer.NewAttributes()}
pm[i] = &schedulingtypes.PodMetrics{EndpointMetadata: pod.GetMetadata().Clone(), MetricsState: pod.GetMetrics().Clone(), AttributeMap: datalayer.NewAttributes()}
}
}

Expand Down Expand Up @@ -323,7 +322,7 @@ func (d *Director) HandleResponseBodyComplete(ctx context.Context, reqCtx *handl
return reqCtx, nil
}

func (d *Director) GetRandomPod() *backend.Pod {
func (d *Director) GetRandomEndpoint() *datalayer.EndpointMetadata {
pods := d.datastore.PodList(datastore.AllPodsPredicate)
if len(pods) == 0 {
return nil
Expand All @@ -346,16 +345,18 @@ func (d *Director) runPreRequestPlugins(ctx context.Context, request *scheduling
}

func (d *Director) runPrepareDataPlugins(ctx context.Context,
request *schedulingtypes.LLMRequest, pods []schedulingtypes.Pod) error {
return prepareDataPluginsWithTimeout(prepareDataTimeout, d.requestControlPlugins.prepareDataPlugins, ctx, request, pods)
request *schedulingtypes.LLMRequest, endpoints []schedulingtypes.Endpoint) error {
return prepareDataPluginsWithTimeout(
prepareDataTimeout, d.requestControlPlugins.prepareDataPlugins, ctx, request, endpoints)

}

func (d *Director) runAdmissionPlugins(ctx context.Context,
request *schedulingtypes.LLMRequest, pods []schedulingtypes.Pod) bool {
request *schedulingtypes.LLMRequest, endpoints []schedulingtypes.Endpoint) bool {
loggerDebug := log.FromContext(ctx).V(logutil.DEBUG)
for _, plugin := range d.requestControlPlugins.admissionPlugins {
loggerDebug.Info("Running AdmitRequest plugin", "plugin", plugin.TypedName())
if denyReason := plugin.AdmitRequest(ctx, request, pods); denyReason != nil {
if denyReason := plugin.AdmitRequest(ctx, request, endpoints); denyReason != nil {
loggerDebug.Info("AdmitRequest plugin denied the request", "plugin", plugin.TypedName(), "reason", denyReason.Error())
return false
}
Expand All @@ -364,34 +365,34 @@ func (d *Director) runAdmissionPlugins(ctx context.Context,
return true
}

func (d *Director) runResponseReceivedPlugins(ctx context.Context, request *schedulingtypes.LLMRequest, response *Response, targetPod *backend.Pod) {
func (d *Director) runResponseReceivedPlugins(ctx context.Context, request *schedulingtypes.LLMRequest, response *Response, targetEndpoint *datalayer.EndpointMetadata) {
loggerDebug := log.FromContext(ctx).V(logutil.DEBUG)
for _, plugin := range d.requestControlPlugins.responseReceivedPlugins {
loggerDebug.Info("Running ResponseReceived plugin", "plugin", plugin.TypedName())
before := time.Now()
plugin.ResponseReceived(ctx, request, response, targetPod)
plugin.ResponseReceived(ctx, request, response, targetEndpoint)
metrics.RecordPluginProcessingLatency(ResponseReceivedExtensionPoint, plugin.TypedName().Type, plugin.TypedName().Name, time.Since(before))
loggerDebug.Info("Completed running ResponseReceived plugin successfully", "plugin", plugin.TypedName())
}
}

func (d *Director) runResponseStreamingPlugins(ctx context.Context, request *schedulingtypes.LLMRequest, response *Response, targetPod *backend.Pod) {
func (d *Director) runResponseStreamingPlugins(ctx context.Context, request *schedulingtypes.LLMRequest, response *Response, targetEndpoint *datalayer.EndpointMetadata) {
loggerTrace := log.FromContext(ctx).V(logutil.TRACE)
for _, plugin := range d.requestControlPlugins.responseStreamingPlugins {
loggerTrace.Info("Running ResponseStreaming plugin", "plugin", plugin.TypedName())
before := time.Now()
plugin.ResponseStreaming(ctx, request, response, targetPod)
plugin.ResponseStreaming(ctx, request, response, targetEndpoint)
metrics.RecordPluginProcessingLatency(ResponseStreamingExtensionPoint, plugin.TypedName().Type, plugin.TypedName().Name, time.Since(before))
loggerTrace.Info("Completed running ResponseStreaming plugin successfully", "plugin", plugin.TypedName())
}
}

func (d *Director) runResponseCompletePlugins(ctx context.Context, request *schedulingtypes.LLMRequest, response *Response, targetPod *backend.Pod) {
func (d *Director) runResponseCompletePlugins(ctx context.Context, request *schedulingtypes.LLMRequest, response *Response, targetEndpoint *datalayer.EndpointMetadata) {
loggerDebug := log.FromContext(ctx).V(logutil.DEBUG)
for _, plugin := range d.requestControlPlugins.responseCompletePlugins {
loggerDebug.Info("Running ResponseComplete plugin", "plugin", plugin.TypedName())
before := time.Now()
plugin.ResponseComplete(ctx, request, response, targetPod)
plugin.ResponseComplete(ctx, request, response, targetEndpoint)
metrics.RecordPluginProcessingLatency(ResponseCompleteExtensionPoint, plugin.TypedName().Type, plugin.TypedName().Name, time.Since(before))
loggerDebug.Info("Completed running ResponseComplete plugin successfully", "plugin", plugin.TypedName())
}
Expand Down
Loading