Skip to content

Commit a468127

Browse files
committed
Updates to code due to type name changes
Signed-off-by: Shmuel Kallner <[email protected]>
1 parent 732c856 commit a468127

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

49 files changed

+1104
-1106
lines changed

pkg/epp/backend/metrics/fake.go

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -25,14 +25,13 @@ import (
2525
"k8s.io/apimachinery/pkg/types"
2626
"sigs.k8s.io/controller-runtime/pkg/log"
2727

28-
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend"
2928
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer"
3029
logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
3130
)
3231

3332
// FakePodMetrics is an implementation of PodMetrics that doesn't run the async refresh loop.
3433
type FakePodMetrics struct {
35-
Pod *backend.Pod
34+
Metadata *datalayer.EndpointMetadata
3635
Metrics *MetricsState
3736
Attributes *datalayer.Attributes
3837
}
@@ -41,16 +40,16 @@ func (fpm *FakePodMetrics) String() string {
4140
return fmt.Sprintf("Metadata: %v; Metrics: %v", fpm.GetMetadata(), fpm.GetMetrics())
4241
}
4342

44-
func (fpm *FakePodMetrics) GetMetadata() *backend.Pod {
45-
return fpm.Pod
43+
func (fpm *FakePodMetrics) GetMetadata() *datalayer.EndpointMetadata {
44+
return fpm.Metadata
4645
}
4746

4847
func (fpm *FakePodMetrics) GetMetrics() *MetricsState {
4948
return fpm.Metrics
5049
}
5150

5251
func (fpm *FakePodMetrics) UpdateMetadata(metadata *datalayer.EndpointMetadata) {
53-
fpm.Pod = metadata
52+
fpm.Metadata = metadata
5453
}
5554
func (fpm *FakePodMetrics) GetAttributes() *datalayer.Attributes {
5655
return fpm.Attributes
@@ -72,7 +71,7 @@ type FakePodMetricsClient struct {
7271
Res map[types.NamespacedName]*MetricsState
7372
}
7473

75-
func (f *FakePodMetricsClient) FetchMetrics(ctx context.Context, pod *backend.Pod, existing *MetricsState) (*MetricsState, error) {
74+
func (f *FakePodMetricsClient) FetchMetrics(ctx context.Context, pod *datalayer.EndpointMetadata, existing *MetricsState) (*MetricsState, error) {
7675
f.errMu.RLock()
7776
err, ok := f.Err[pod.NamespacedName]
7877
f.errMu.RUnlock()

pkg/epp/backend/metrics/metrics.go

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ import (
2828
"github.com/prometheus/common/model"
2929
"go.uber.org/multierr"
3030

31-
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend"
31+
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer"
3232
)
3333

3434
const (
@@ -50,22 +50,22 @@ type PodMetricsClientImpl struct {
5050
}
5151

5252
// FetchMetrics fetches metrics from a given pod, clones the existing metrics object and returns an updated one.
53-
func (p *PodMetricsClientImpl) FetchMetrics(ctx context.Context, pod *backend.Pod, existing *MetricsState) (*MetricsState, error) {
54-
url := p.getMetricEndpoint(pod)
53+
func (p *PodMetricsClientImpl) FetchMetrics(ctx context.Context, metadata *datalayer.EndpointMetadata, existing *MetricsState) (*MetricsState, error) {
54+
url := p.getMetricEndpoint(metadata)
5555
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
5656
if err != nil {
5757
return nil, fmt.Errorf("failed to create request: %v", err)
5858
}
5959
resp, err := p.Client.Do(req)
6060
if err != nil {
61-
return nil, fmt.Errorf("failed to fetch metrics from %s: %w", pod.NamespacedName, err)
61+
return nil, fmt.Errorf("failed to fetch metrics from %s: %w", metadata.NamespacedName, err)
6262
}
6363
defer func() {
6464
_ = resp.Body.Close()
6565
}()
6666

6767
if resp.StatusCode != http.StatusOK {
68-
return nil, fmt.Errorf("unexpected status code from %s: %v", pod.NamespacedName, resp.StatusCode)
68+
return nil, fmt.Errorf("unexpected status code from %s: %v", metadata.NamespacedName, resp.StatusCode)
6969
}
7070

7171
parser := expfmt.NewTextParser(model.LegacyValidation)
@@ -76,8 +76,8 @@ func (p *PodMetricsClientImpl) FetchMetrics(ctx context.Context, pod *backend.Po
7676
return p.promToPodMetrics(metricFamilies, existing)
7777
}
7878

79-
func (p *PodMetricsClientImpl) getMetricEndpoint(pod *backend.Pod) string {
80-
return p.ModelServerMetricsScheme + "://" + pod.GetMetricsHost() + p.ModelServerMetricsPath
79+
func (p *PodMetricsClientImpl) getMetricEndpoint(metadata *datalayer.EndpointMetadata) string {
80+
return p.ModelServerMetricsScheme + "://" + metadata.GetMetricsHost() + p.ModelServerMetricsPath
8181
}
8282

8383
// promToPodMetrics updates internal pod metrics with scraped Prometheus metrics.

pkg/epp/backend/metrics/metrics_test.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ import (
3131
"google.golang.org/protobuf/proto"
3232
"k8s.io/apimachinery/pkg/types"
3333

34-
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend"
34+
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer"
3535
logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
3636
)
3737

@@ -577,7 +577,7 @@ func TestPromToPodMetrics(t *testing.T) {
577577
// there's no server running on the specified port.
578578
func TestFetchMetrics(t *testing.T) {
579579
ctx := logutil.NewTestLoggerIntoContext(context.Background())
580-
pod := &backend.Pod{
580+
metadata := &datalayer.EndpointMetadata{
581581
Address: "127.0.0.1",
582582
Port: "9999",
583583
MetricsHost: "127.0.0.1:9999",
@@ -594,7 +594,7 @@ func TestFetchMetrics(t *testing.T) {
594594
Client: http.DefaultClient,
595595
}
596596

597-
_, err := p.FetchMetrics(ctx, pod, existing) // Use a port that's unlikely to be in use
597+
_, err := p.FetchMetrics(ctx, metadata, existing) // Use a port that's unlikely to be in use
598598
if err == nil {
599599
t.Errorf("FetchMetrics() expected error, got nil")
600600
}

pkg/epp/config/loader/configloader_test.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -792,8 +792,8 @@ func (f *test1) TypedName() plugins.TypedName {
792792
}
793793

794794
// Filter filters out pods that doesn't meet the filter criteria.
795-
func (f *test1) Filter(_ context.Context, _ *types.CycleState, _ *types.LLMRequest, pods []types.Pod) []types.Pod {
796-
return pods
795+
func (f *test1) Filter(_ context.Context, _ *types.CycleState, _ *types.LLMRequest, endpoints []types.Endpoint) []types.Endpoint {
796+
return endpoints
797797
}
798798

799799
// compile-time type validation
@@ -813,8 +813,8 @@ func (m *test2) TypedName() plugins.TypedName {
813813
return m.typedName
814814
}
815815

816-
func (m *test2) Score(_ context.Context, _ *types.CycleState, _ *types.LLMRequest, _ []types.Pod) map[types.Pod]float64 {
817-
return map[types.Pod]float64{}
816+
func (m *test2) Score(_ context.Context, _ *types.CycleState, _ *types.LLMRequest, _ []types.Endpoint) map[types.Endpoint]float64 {
817+
return map[types.Endpoint]float64{}
818818
}
819819

820820
// compile-time type validation
@@ -834,7 +834,7 @@ func (p *testPicker) TypedName() plugins.TypedName {
834834
return p.typedName
835835
}
836836

837-
func (p *testPicker) Pick(_ context.Context, _ *types.CycleState, _ []*types.ScoredPod) *types.ProfileRunResult {
837+
func (p *testPicker) Pick(_ context.Context, _ *types.CycleState, _ []*types.ScoredEndpoint) *types.ProfileRunResult {
838838
return nil
839839
}
840840

pkg/epp/handlers/request.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -40,15 +40,15 @@ func (s *StreamingServer) HandleRequestHeaders(reqCtx *RequestContext, req *extP
4040

4141
// an EoS in the request headers means this request has no body or trailers.
4242
if req.RequestHeaders.EndOfStream {
43-
// We will route this request to a random pod as this is assumed to just be a GET
43+
// We will route this request to a random endpoint as this is assumed to just be a GET
4444
// More context: https://github.com/kubernetes-sigs/gateway-api-inference-extension/pull/526
4545
// The above PR will address endpoint admission, but currently any request without a body will be
46-
// routed to a random upstream pod.
47-
pod := s.director.GetRandomPod()
48-
if pod == nil {
46+
// routed to a random upstream endpoint.
47+
endpoint := s.director.GetRandomEndpoint()
48+
if endpoint == nil {
4949
return errutil.Error{Code: errutil.Internal, Msg: "no pods available in datastore"}
5050
}
51-
reqCtx.TargetEndpoint = pod.GetIPAddress() + ":" + pod.GetPort()
51+
reqCtx.TargetEndpoint = endpoint.GetIPAddress() + ":" + endpoint.GetPort()
5252
reqCtx.RequestSize = 0
5353
reqCtx.reqHeaderResp = s.generateRequestHeaderResponse(reqCtx)
5454
return nil

pkg/epp/handlers/response_test.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ import (
2323

2424
"github.com/google/go-cmp/cmp"
2525

26-
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend"
26+
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer"
2727
logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
2828
)
2929

@@ -103,8 +103,8 @@ func (m *mockDirector) HandleResponseReceived(ctx context.Context, reqCtx *Reque
103103
func (m *mockDirector) HandlePreRequest(ctx context.Context, reqCtx *RequestContext) (*RequestContext, error) {
104104
return reqCtx, nil
105105
}
106-
func (m *mockDirector) GetRandomPod() *backend.Pod {
107-
return &backend.Pod{}
106+
func (m *mockDirector) GetRandomEndpoint() *datalayer.EndpointMetadata {
107+
return &datalayer.EndpointMetadata{}
108108
}
109109
func (m *mockDirector) HandleRequest(ctx context.Context, reqCtx *RequestContext) (*RequestContext, error) {
110110
return reqCtx, nil

pkg/epp/handlers/server.go

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@ import (
3131
"google.golang.org/grpc/status"
3232

3333
"sigs.k8s.io/controller-runtime/pkg/log"
34-
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend"
3534
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer"
3635
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics"
3736
schedulingtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
@@ -57,7 +56,7 @@ type Director interface {
5756
HandleResponseReceived(ctx context.Context, reqCtx *RequestContext) (*RequestContext, error)
5857
HandleResponseBodyStreaming(ctx context.Context, reqCtx *RequestContext) (*RequestContext, error)
5958
HandleResponseBodyComplete(ctx context.Context, reqCtx *RequestContext) (*RequestContext, error)
60-
GetRandomPod() *backend.Pod
59+
GetRandomEndpoint() *datalayer.EndpointMetadata
6160
}
6261

6362
type Datastore interface {
@@ -76,7 +75,7 @@ type StreamingServer struct {
7675
// Specifically, there are fields related to the ext-proc protocol, and then fields related to the lifecycle of the request.
7776
// We should split these apart as this monolithic object exposes too much data to too many layers.
7877
type RequestContext struct {
79-
TargetPod *backend.Pod
78+
TargetPod *datalayer.EndpointMetadata
8079
TargetEndpoint string
8180
IncomingModelName string
8281
TargetModelName string

pkg/epp/requestcontrol/dag_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,8 @@ func (m *mockPrepareRequestDataP) Consumes() map[string]any {
4545
return m.consumes
4646
}
4747

48-
func (m *mockPrepareRequestDataP) PrepareRequestData(ctx context.Context, request *types.LLMRequest, pods []types.Pod) error {
49-
pods[0].Put(mockProducedDataKey, mockProducedDataType{value: 42})
48+
func (m *mockPrepareRequestDataP) PrepareRequestData(ctx context.Context, request *types.LLMRequest, endpoints []types.Endpoint) error {
49+
endpoints[0].Put(mockProducedDataKey, mockProducedDataType{value: 42})
5050
return nil
5151
}
5252

pkg/epp/requestcontrol/director.go

Lines changed: 24 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@ import (
2828

2929
"sigs.k8s.io/controller-runtime/pkg/log"
3030
"sigs.k8s.io/gateway-api-inference-extension/apix/v1alpha2"
31-
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend"
3231
backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics"
3332
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer"
3433
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datastore"
@@ -57,7 +56,7 @@ type Datastore interface {
5756

5857
// Scheduler defines the interface required by the Director for scheduling.
5958
type Scheduler interface {
60-
Schedule(ctx context.Context, request *schedulingtypes.LLMRequest, candidatePods []schedulingtypes.Pod) (result *schedulingtypes.SchedulingResult, err error)
59+
Schedule(ctx context.Context, request *schedulingtypes.LLMRequest, candidatePods []schedulingtypes.Endpoint) (result *schedulingtypes.SchedulingResult, err error)
6160
}
6261

6362
// NewDirectorWithConfig creates a new Director instance with all dependencies.
@@ -243,34 +242,34 @@ func (d *Director) prepareRequest(ctx context.Context, reqCtx *handlers.RequestC
243242
return reqCtx, errutil.Error{Code: errutil.Internal, Msg: "results must be greater than zero"}
244243
}
245244
// primary profile is used to set destination
246-
targetPods := []*backend.Pod{}
245+
targetMetadatas := []*datalayer.EndpointMetadata{}
247246
targetEndpoints := []string{}
248247

249-
for _, pod := range result.ProfileResults[result.PrimaryProfileName].TargetPods {
250-
curPod := pod.GetPod()
251-
curEndpoint := net.JoinHostPort(curPod.GetIPAddress(), curPod.GetPort())
252-
targetPods = append(targetPods, curPod)
248+
for _, pod := range result.ProfileResults[result.PrimaryProfileName].TargetEndpoints {
249+
curMetadata := pod.GetMetadata()
250+
curEndpoint := net.JoinHostPort(curMetadata.GetIPAddress(), curMetadata.GetPort())
251+
targetMetadatas = append(targetMetadatas, curMetadata)
253252
targetEndpoints = append(targetEndpoints, curEndpoint)
254253
}
255254

256255
multiEndpointString := strings.Join(targetEndpoints, ",")
257256
logger.V(logutil.VERBOSE).Info("Request handled", "objectiveKey", reqCtx.ObjectiveKey, "incomingModelName", reqCtx.IncomingModelName, "targetModel", reqCtx.TargetModelName, "endpoint", multiEndpointString)
258257

259-
reqCtx.TargetPod = targetPods[0]
258+
reqCtx.TargetPod = targetMetadatas[0]
260259
reqCtx.TargetEndpoint = multiEndpointString
261260

262261
d.runPreRequestPlugins(ctx, reqCtx.SchedulingRequest, result)
263262

264263
return reqCtx, nil
265264
}
266265

267-
func (d *Director) toSchedulerPodMetrics(pods []backendmetrics.PodMetrics) []schedulingtypes.Pod {
268-
pm := make([]schedulingtypes.Pod, len(pods))
266+
func (d *Director) toSchedulerPodMetrics(pods []backendmetrics.PodMetrics) []schedulingtypes.Endpoint {
267+
pm := make([]schedulingtypes.Endpoint, len(pods))
269268
for i, pod := range pods {
270269
if pod.GetAttributes() != nil {
271-
pm[i] = &schedulingtypes.PodMetrics{Pod: pod.GetMetadata().Clone(), MetricsState: pod.GetMetrics().Clone(), AttributeMap: pod.GetAttributes().Clone()}
270+
pm[i] = &schedulingtypes.PodMetrics{EndpointMetadata: pod.GetMetadata().Clone(), MetricsState: pod.GetMetrics().Clone(), AttributeMap: pod.GetAttributes().Clone()}
272271
} else {
273-
pm[i] = &schedulingtypes.PodMetrics{Pod: pod.GetMetadata().Clone(), MetricsState: pod.GetMetrics().Clone(), AttributeMap: datalayer.NewAttributes()}
272+
pm[i] = &schedulingtypes.PodMetrics{EndpointMetadata: pod.GetMetadata().Clone(), MetricsState: pod.GetMetrics().Clone(), AttributeMap: datalayer.NewAttributes()}
274273
}
275274
}
276275

@@ -321,7 +320,7 @@ func (d *Director) HandleResponseBodyComplete(ctx context.Context, reqCtx *handl
321320
return reqCtx, nil
322321
}
323322

324-
func (d *Director) GetRandomPod() *backend.Pod {
323+
func (d *Director) GetRandomEndpoint() *datalayer.EndpointMetadata {
325324
pods := d.datastore.PodList(datastore.AllPodsPredicate)
326325
if len(pods) == 0 {
327326
return nil
@@ -344,16 +343,18 @@ func (d *Director) runPreRequestPlugins(ctx context.Context, request *scheduling
344343
}
345344

346345
func (d *Director) runPrepareDataPlugins(ctx context.Context,
347-
request *schedulingtypes.LLMRequest, pods []schedulingtypes.Pod) error {
348-
return prepareDataPluginsWithTimeout(prepareDataTimeout, d.requestControlPlugins.prepareDataPlugins, ctx, request, pods)
346+
request *schedulingtypes.LLMRequest, endpoints []schedulingtypes.Endpoint) error {
347+
return prepareDataPluginsWithTimeout(
348+
prepareDataTimeout, d.requestControlPlugins.prepareDataPlugins, ctx, request, endpoints)
349+
349350
}
350351

351352
func (d *Director) runAdmissionPlugins(ctx context.Context,
352-
request *schedulingtypes.LLMRequest, pods []schedulingtypes.Pod) bool {
353+
request *schedulingtypes.LLMRequest, endpoints []schedulingtypes.Endpoint) bool {
353354
loggerDebug := log.FromContext(ctx).V(logutil.DEBUG)
354355
for _, plugin := range d.requestControlPlugins.admissionPlugins {
355356
loggerDebug.Info("Running AdmitRequest plugin", "plugin", plugin.TypedName())
356-
if denyReason := plugin.AdmitRequest(ctx, request, pods); denyReason != nil {
357+
if denyReason := plugin.AdmitRequest(ctx, request, endpoints); denyReason != nil {
357358
loggerDebug.Info("AdmitRequest plugin denied the request", "plugin", plugin.TypedName(), "reason", denyReason.Error())
358359
return false
359360
}
@@ -362,34 +363,34 @@ func (d *Director) runAdmissionPlugins(ctx context.Context,
362363
return true
363364
}
364365

365-
func (d *Director) runResponseReceivedPlugins(ctx context.Context, request *schedulingtypes.LLMRequest, response *Response, targetPod *backend.Pod) {
366+
func (d *Director) runResponseReceivedPlugins(ctx context.Context, request *schedulingtypes.LLMRequest, response *Response, targetEndpoint *datalayer.EndpointMetadata) {
366367
loggerDebug := log.FromContext(ctx).V(logutil.DEBUG)
367368
for _, plugin := range d.requestControlPlugins.responseReceivedPlugins {
368369
loggerDebug.Info("Running ResponseReceived plugin", "plugin", plugin.TypedName())
369370
before := time.Now()
370-
plugin.ResponseReceived(ctx, request, response, targetPod)
371+
plugin.ResponseReceived(ctx, request, response, targetEndpoint)
371372
metrics.RecordPluginProcessingLatency(ResponseReceivedExtensionPoint, plugin.TypedName().Type, plugin.TypedName().Name, time.Since(before))
372373
loggerDebug.Info("Completed running ResponseReceived plugin successfully", "plugin", plugin.TypedName())
373374
}
374375
}
375376

376-
func (d *Director) runResponseStreamingPlugins(ctx context.Context, request *schedulingtypes.LLMRequest, response *Response, targetPod *backend.Pod) {
377+
func (d *Director) runResponseStreamingPlugins(ctx context.Context, request *schedulingtypes.LLMRequest, response *Response, targetEndpoint *datalayer.EndpointMetadata) {
377378
loggerTrace := log.FromContext(ctx).V(logutil.TRACE)
378379
for _, plugin := range d.requestControlPlugins.responseStreamingPlugins {
379380
loggerTrace.Info("Running ResponseStreaming plugin", "plugin", plugin.TypedName())
380381
before := time.Now()
381-
plugin.ResponseStreaming(ctx, request, response, targetPod)
382+
plugin.ResponseStreaming(ctx, request, response, targetEndpoint)
382383
metrics.RecordPluginProcessingLatency(ResponseStreamingExtensionPoint, plugin.TypedName().Type, plugin.TypedName().Name, time.Since(before))
383384
loggerTrace.Info("Completed running ResponseStreaming plugin successfully", "plugin", plugin.TypedName())
384385
}
385386
}
386387

387-
func (d *Director) runResponseCompletePlugins(ctx context.Context, request *schedulingtypes.LLMRequest, response *Response, targetPod *backend.Pod) {
388+
func (d *Director) runResponseCompletePlugins(ctx context.Context, request *schedulingtypes.LLMRequest, response *Response, targetEndpoint *datalayer.EndpointMetadata) {
388389
loggerDebug := log.FromContext(ctx).V(logutil.DEBUG)
389390
for _, plugin := range d.requestControlPlugins.responseCompletePlugins {
390391
loggerDebug.Info("Running ResponseComplete plugin", "plugin", plugin.TypedName())
391392
before := time.Now()
392-
plugin.ResponseComplete(ctx, request, response, targetPod)
393+
plugin.ResponseComplete(ctx, request, response, targetEndpoint)
393394
metrics.RecordPluginProcessingLatency(ResponseCompleteExtensionPoint, plugin.TypedName().Type, plugin.TypedName().Name, time.Since(before))
394395
loggerDebug.Info("Completed running ResponseComplete plugin successfully", "plugin", plugin.TypedName())
395396
}

0 commit comments

Comments
 (0)