Skip to content

Commit 598b4b2

Browse files
committed
Add more tests, register prefix cache scorer and address other review comments
1 parent 6707274 commit 598b4b2

File tree

9 files changed

+95
-27
lines changed

9 files changed

+95
-27
lines changed

cmd/epp/runner/runner.go

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -240,7 +240,7 @@ func (r *Runner) Run(ctx context.Context) error {
240240
}
241241

242242
// --- Setup Datastore ---
243-
epf, err := r.setupMetricsCollection(setupLog, r.featureGates[datalayer.FeatureGate])
243+
epf, err := r.setupMetricsCollection(setupLog, r.featureGates[datalayer.ExperimentalDatalayerFeatureGate])
244244
if err != nil {
245245
return err
246246
}
@@ -372,7 +372,7 @@ func (r *Runner) Run(ctx context.Context) error {
372372
MetricsStalenessThreshold: *metricsStalenessThreshold,
373373
Director: director,
374374
SaturationDetector: saturationDetector,
375-
UseExperimentalDatalayerV2: r.featureGates[datalayer.FeatureGate], // pluggable data layer feature flag
375+
UseExperimentalDatalayerV2: r.featureGates[datalayer.ExperimentalDatalayerFeatureGate], // pluggable data layer feature flag
376376
}
377377
if err := serverRunner.SetupWithManager(ctx, mgr); err != nil {
378378
setupLog.Error(err, "Failed to setup EPP controllers")
@@ -463,7 +463,7 @@ func (r *Runner) parseConfigurationPhaseOne(ctx context.Context) (*configapi.End
463463
}
464464
}
465465

466-
loader.RegisterFeatureGate(datalayer.FeatureGate)
466+
loader.RegisterFeatureGate(datalayer.ExperimentalDatalayerFeatureGate)
467467
loader.RegisterFeatureGate(flowcontrol.FeatureGate)
468468
loader.RegisterFeatureGate(datalayer.PrepareDataPluginsFeatureGate)
469469

@@ -506,9 +506,15 @@ func (r *Runner) parseConfigurationPhaseTwo(ctx context.Context, rawConfig *conf
506506
// Add requestControl plugins
507507
r.requestControlConfig.AddPlugins(handle.GetAllPlugins()...)
508508

509-
// Sort prepare data plugins in DAG order (topological sort). Also check prepare data plugins for cycles.
510-
if r.requestControlConfig.PrepareDataPluginGraph(r.featureGates[datalayer.PrepareDataPluginsFeatureGate]) != nil {
511-
return nil, errors.New("failed to load the configuration - prepare data plugins have cyclic dependencies")
509+
// TODO(rahulgurnani): Remove feature gate check once prepare data plugins are stable.
510+
if r.featureGates[datalayer.PrepareDataPluginsFeatureGate] {
511+
// Sort prepare data plugins in DAG order (topological sort). Also check prepare data plugins for cycles.
512+
if r.requestControlConfig.PrepareDataPluginGraph() != nil {
513+
return nil, errors.New("failed to load the configuration - prepare data plugins have cyclic dependencies")
514+
}
515+
plugins.Register(scorer.PrefixCacheMatchScorerType, scorer.PrefixCacheScorerFactory) // register PrefixCacheMatchScorer used by prepare data plugins
516+
} else {
517+
r.requestControlConfig.WithPrepareDataPlugins()
512518
}
513519

514520
// Handle deprecated configuration options
@@ -531,7 +537,7 @@ func (r *Runner) deprecatedConfigurationHelper(cfg *config.Config, logger logr.L
531537

532538
if _, ok := os.LookupEnv(enableExperimentalDatalayerV2); ok {
533539
logger.Info("Enabling the experimental Data Layer V2 using environment variables is deprecated and will be removed in next version")
534-
r.featureGates[datalayer.FeatureGate] = env.GetEnvBool(enableExperimentalDatalayerV2, false, logger)
540+
r.featureGates[datalayer.ExperimentalDatalayerFeatureGate] = env.GetEnvBool(enableExperimentalDatalayerV2, false, logger)
535541
}
536542
if _, ok := os.LookupEnv(enableExperimentalFlowControlLayer); ok {
537543
logger.Info("Enabling the experimental Flow Control layer using environment variables is deprecated and will be removed in next version")

pkg/epp/config/loader/configloader_test.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ func TestLoadRawConfiguration(t *testing.T) {
100100
},
101101
},
102102
},
103-
FeatureGates: configapi.FeatureGates{datalayer.FeatureGate},
103+
FeatureGates: configapi.FeatureGates{datalayer.ExperimentalDatalayerFeatureGate},
104104
SaturationDetector: &configapi.SaturationDetector{
105105
MetricsStalenessThreshold: metav1.Duration{Duration: 150 * time.Millisecond},
106106
},
@@ -206,7 +206,7 @@ func TestLoadRawConfigurationWithDefaults(t *testing.T) {
206206
},
207207
},
208208
},
209-
FeatureGates: configapi.FeatureGates{datalayer.FeatureGate},
209+
FeatureGates: configapi.FeatureGates{datalayer.ExperimentalDatalayerFeatureGate},
210210
SaturationDetector: &configapi.SaturationDetector{
211211
QueueDepthThreshold: saturationdetector.DefaultQueueDepthThreshold,
212212
KVCacheUtilThreshold: saturationdetector.DefaultKVCacheUtilThreshold,
@@ -488,7 +488,7 @@ func TestLoadConfig(t *testing.T) {
488488
}
489489

490490
func registerNeededFeatureGates() {
491-
RegisterFeatureGate(datalayer.FeatureGate)
491+
RegisterFeatureGate(datalayer.ExperimentalDatalayerFeatureGate)
492492
}
493493

494494
func registerNeededPlgugins() {

pkg/epp/datalayer/factory.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,8 @@ import (
2626
)
2727

2828
const (
29-
FeatureGate = "dataLayer"
30-
PrepareDataPluginsFeatureGate = "prepareDataPlugins"
29+
ExperimentalDatalayerFeatureGate = "dataLayer"
30+
PrepareDataPluginsFeatureGate = "prepareDataPlugins"
3131
)
3232

3333
// PoolInfo represents the DataStore information needed for endpoints.

pkg/epp/requestcontrol/director.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -388,6 +388,9 @@ func (d *Director) runPreRequestPlugins(ctx context.Context, request *scheduling
388388

389389
func (d *Director) runPrepareDataPlugins(ctx context.Context,
390390
request *schedulingtypes.LLMRequest, pods []schedulingtypes.Pod) error {
391+
if len(d.requestControlPlugins.prepareDataPlugins) == 0 {
392+
return nil
393+
}
391394
return prepareDataPluginsWithTimeout(prepareDataTimeout, d.requestControlPlugins.prepareDataPlugins, ctx, request, pods)
392395
}
393396

pkg/epp/requestcontrol/request_control_config.go

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -107,9 +107,8 @@ func (c *Config) AddPlugins(pluginObjects ...plugins.Plugin) {
107107

108108
// PrepareDataPluginGraph creates data dependency graph and sorts the plugins in topological order.
109109
// If a cycle is detected, it returns an error.
110-
func (c *Config) PrepareDataPluginGraph(enablePrepareDataPlugins bool) error {
111-
if !enablePrepareDataPlugins {
112-
c.prepareDataPlugins = []PrepareDataPlugin{}
110+
func (c *Config) PrepareDataPluginGraph() error {
111+
if len(c.prepareDataPlugins) == 0 {
113112
return nil
114113
}
115114
dag := buildDAG(c.prepareDataPlugins)

pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -222,11 +222,6 @@ func (p *Plugin) PrepareRequestData(ctx context.Context, request *types.LLMReque
222222
PrefixHashes: hashes,
223223
PrefixCacheServers: p.matchLongestPrefix(ctx, hashes),
224224
}
225-
for server, matchLen := range state.PrefixCacheServers {
226-
log.FromContext(ctx).V(logutil.TRACE).Info("prefix cached state", "server", server, "longest-prefix-match", matchLen)
227-
228-
}
229-
230225
total := len(state.PrefixHashes)
231226

232227
for _, pod := range pods {

pkg/epp/scheduling/framework/plugins/multi/prefix/plugin_test.go

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@ import (
2929

3030
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend"
3131
backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics"
32+
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer"
33+
dplugins "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer/plugins"
3234
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins"
3335
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/requestcontrol"
3436
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
@@ -575,6 +577,67 @@ func randomPrompt(n int) string {
575577
return sb.String()
576578
}
577579

580+
func TestPrepareRequestData(t *testing.T) {
581+
config := Config{
582+
BlockSize: 4,
583+
MaxPrefixBlocksToMatch: DefaultMaxPrefixBlocks,
584+
LRUCapacityPerServer: DefaultLRUCapacityPerServer,
585+
}
586+
plugin := New(context.Background(), config)
587+
588+
pod1 := &types.PodMetrics{Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod1"}}, MetricsState: backendmetrics.NewMetricsState(), AttributeMap: datalayer.NewAttributes()}
589+
pod2 := &types.PodMetrics{Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod2"}}, MetricsState: backendmetrics.NewMetricsState(), AttributeMap: datalayer.NewAttributes()}
590+
pods := []types.Pod{pod1, pod2}
591+
592+
// First request to populate cache.
593+
req1 := &types.LLMRequest{
594+
RequestId: uuid.NewString(),
595+
TargetModel: "test-model1",
596+
Body: &types.LLMRequestBody{
597+
Completions: &types.CompletionsRequest{
598+
Prompt: "aaaabbbb",
599+
},
600+
},
601+
}
602+
_ = plugin.Score(context.Background(), types.NewCycleState(), req1, pods)
603+
schedulingResult := &types.SchedulingResult{
604+
PrimaryProfileName: "default",
605+
ProfileResults: map[string]*types.ProfileRunResult{
606+
"default": {TargetPods: []types.Pod{pod1}},
607+
},
608+
}
609+
plugin.PreRequest(context.Background(), req1, schedulingResult)
610+
plugin.wg.Wait()
611+
612+
// Second request that shares a prefix.
613+
req2 := &types.LLMRequest{
614+
RequestId: uuid.NewString(),
615+
TargetModel: "test-model1",
616+
Body: &types.LLMRequestBody{
617+
Completions: &types.CompletionsRequest{
618+
Prompt: "aaaacccc",
619+
},
620+
},
621+
}
622+
623+
err := plugin.PrepareRequestData(context.Background(), req2, pods)
624+
assert.NoError(t, err)
625+
626+
// Verify pod1 has the correct prefix match info
627+
info1, ok := pod1.Get(dplugins.PrefixCacheMatchInfoKey)
628+
assert.True(t, ok)
629+
prefixInfo1 := info1.(*dplugins.PrefixCacheMatchInfo)
630+
assert.Equal(t, 1, prefixInfo1.MatchLength()) // "aaaa" matches
631+
assert.Equal(t, 2, prefixInfo1.TotalLength()) // "aaaacccc" -> 2 blocks
632+
633+
// Verify pod2 has match info that records zero matched blocks
634+
info2, ok := pod2.Get(dplugins.PrefixCacheMatchInfoKey)
635+
assert.True(t, ok)
636+
prefixInfo2 := info2.(*dplugins.PrefixCacheMatchInfo)
637+
assert.Equal(t, 0, prefixInfo2.MatchLength()) // No match for pod2
638+
assert.Equal(t, 2, prefixInfo2.TotalLength())
639+
}
640+
578641
// BenchmarkPrefixPluginChatCompletionsStress is a stress test for chat completions with varying message counts and lengths
579642
func BenchmarkPrefixPluginChatCompletionsStress(b *testing.B) {
580643
blockSize := 8

pkg/epp/scheduling/framework/plugins/scorer/prefix_cache_match_scorer.go

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,9 @@ func (s *PrefixCacheScorer) TypedName() plugins.TypedName {
6060

6161
// Consumes returns the list of data that is consumed by the plugin.
6262
func (s *PrefixCacheScorer) Consumes() map[string]any {
63-
return map[string]any{}
63+
return map[string]any{
64+
dplugins.PrefixCacheMatchInfoKey: &dplugins.PrefixCacheMatchInfo{},
65+
}
6466
}
6567

6668
// WithName sets the name of the scorer.
@@ -74,12 +76,12 @@ func (s *PrefixCacheScorer) Score(_ context.Context, cycleState *types.CycleStat
7476
scores := make(map[types.Pod]float64, len(pods))
7577

7678
for _, pod := range pods {
77-
matchPercent, ok := pod.Get(dplugins.PrefixCacheMatchInfoKey)
79+
matchInfo, ok := pod.Get(dplugins.PrefixCacheMatchInfoKey)
7880
if !ok {
7981
scores[pod] = 0.0
8082
continue
8183
}
82-
scores[pod] = float64(matchPercent.(*dplugins.PrefixCacheMatchInfo).MatchLength()) / float64(matchPercent.(*dplugins.PrefixCacheMatchInfo).TotalLength()) * 100
84+
scores[pod] = float64(matchInfo.(*dplugins.PrefixCacheMatchInfo).MatchLength()) / float64(matchInfo.(*dplugins.PrefixCacheMatchInfo).TotalLength())
8385
}
8486
return scores
8587
}

pkg/epp/scheduling/framework/plugins/scorer/prefix_cache_match_scorer_test.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -86,8 +86,8 @@ func TestPrefixCacheScorer_Score(t *testing.T) {
8686
name: "pods with prefix cache match percent",
8787
pods: []types.Pod{pod1, pod2},
8888
expected: map[types.Pod]float64{
89-
pod1: 50.0,
90-
pod2: 100.0,
89+
pod1: 0.5,
90+
pod2: 1.0,
9191
},
9292
},
9393
{
@@ -101,7 +101,7 @@ func TestPrefixCacheScorer_Score(t *testing.T) {
101101
name: "mixed pods",
102102
pods: []types.Pod{pod1, pod3},
103103
expected: map[types.Pod]float64{
104-
pod1: 50.0,
104+
pod1: 0.5,
105105
pod3: 0.0,
106106
},
107107
},
@@ -146,5 +146,5 @@ func TestPrefixCacheScorer_TypedName(t *testing.T) {
146146
func TestPrefixCacheScorer_Consumes(t *testing.T) {
147147
scorer := NewPrefixCacheScorer()
148148
consumes := scorer.Consumes()
149-
assert.Empty(t, consumes)
149+
assert.Contains(t, consumes, dplugins.PrefixCacheMatchInfoKey)
150150
}

0 commit comments

Comments
 (0)