Add KV cache support for efficient parallel PFN evaluation #27

Open

eytan wants to merge 2 commits into SamuelGabriel:main from eytan:kv_cache_take2

Conversation


@eytan eytan commented Jan 26, 2026

This PR implements context caching for PFN models to reduce memory usage and
improve runtime performance during parallel evaluations, such as multi-start
optimization in BayesOpt.

The key insight is that when evaluating many candidate points against
the same training data, we can cache the key-value representations from
the training context and reuse them across all parallel evaluations,
rather than recomputing them for each candidate.
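To make the idea concrete, here is a minimal single-head sketch of the caching pattern, assuming a standard scaled-dot-product attention. The class and method names (`CachedAttention`, `cache_context`) are illustrative only; the PR's actual implementation lives in its multi-head attention module and may differ in structure.

```python
# Minimal sketch of the KV-caching idea (hypothetical names, single head).
import torch
import torch.nn.functional as F


class CachedAttention(torch.nn.Module):
    """Attention that can cache K/V for a fixed training context."""

    def __init__(self, dim: int):
        super().__init__()
        self.q_proj = torch.nn.Linear(dim, dim)
        self.k_proj = torch.nn.Linear(dim, dim)
        self.v_proj = torch.nn.Linear(dim, dim)
        self._kv_cache = None  # (K, V) for the training context, or None

    def cache_context(self, context: torch.Tensor) -> None:
        # Compute K/V once for the shared training context.
        self._kv_cache = (self.k_proj(context), self.v_proj(context))

    def clear_cache(self) -> None:
        self._kv_cache = None

    def forward(self, query, context=None):
        q = self.q_proj(query)
        if self._kv_cache is not None:
            k, v = self._kv_cache  # reuse cached training-context K/V
        else:
            k, v = self.k_proj(context), self.v_proj(context)
        attn = F.softmax(q @ k.transpose(-2, -1) / q.shape[-1] ** 0.5, dim=-1)
        return attn @ v
```

With the cache populated, a batch of candidate queries (one per restart) attends to the same K/V tensors instead of materializing per-restart copies, which is where both the time and memory savings come from.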

Changes:

  • Multi-head attention now supports caching and reusing KV representations
  • Encoder state can be saved/restored for MultivariatePFNModel
  • PFNModel gains cache_training_context() context manager for easy cache management
  • Extensive tests to make sure the caching works correctly
  • Updated copies of files from botorch_community/. I will eventually upstream these into botorch and ax.
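To illustrate the intended usage pattern of the context manager, here is a hedged, toy-sized sketch. `ToyPFNModel`, `_encode`, and `predict` are stand-ins invented for this example; only the `cache_training_context()` name comes from the PR, and its real signature may differ.

```python
# Hypothetical sketch of a cache_training_context() context manager.
from contextlib import contextmanager


class ToyPFNModel:
    """Toy stand-in that counts how often the training context is encoded."""

    def __init__(self):
        self.encode_calls = 0
        self._cached = None

    def _encode(self, train_data):
        self.encode_calls += 1
        return sum(train_data)  # placeholder for the real encoder pass

    @contextmanager
    def cache_training_context(self, train_data):
        self._cached = self._encode(train_data)  # encode once
        try:
            yield self
        finally:
            self._cached = None  # always restore the uncached state

    def predict(self, x, train_data):
        # Reuse the cached encoding when available, else recompute it.
        enc = self._cached if self._cached is not None else self._encode(train_data)
        return enc + x


model = ToyPFNModel()
train = [1.0, 2.0, 3.0]
with model.cache_training_context(train):
    preds = [model.predict(x, train) for x in range(32)]  # e.g. 32 restarts
# Only one encoder pass despite 32 evaluations.
```

The context-manager shape matters: the `finally` block guarantees the cache is cleared even if an evaluation raises, so the model never silently serves stale training-context state afterwards.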

To give a sense of the speedup and memory savings, here are some benchmarks I ran on my laptop (CPU):

d=6 (low-dimensional)
┌──────────┬───────────────┬─────────────┬─────────┬──────────────┬────────────┬───────────┐
│ Restarts │ Standard Time │ Cached Time │ Speedup │ Standard Mem │ Cached Mem │ Mem Saved │
├──────────┼───────────────┼─────────────┼─────────┼──────────────┼────────────┼───────────┤
│ 4        │ 0.106s        │ 0.034s      │ 3.1x    │ 252 MB       │ 89 MB      │ 64.7%     │
├──────────┼───────────────┼─────────────┼─────────┼──────────────┼────────────┼───────────┤
│ 8        │ 0.165s        │ 0.031s      │ 5.3x    │ 348 MB       │ 108 MB     │ 69.0%     │
├──────────┼───────────────┼─────────────┼─────────┼──────────────┼────────────┼───────────┤
│ 16       │ 0.307s        │ 0.048s      │ 6.4x    │ 583 MB       │ 134 MB     │ 76.9%     │
├──────────┼───────────────┼─────────────┼─────────┼──────────────┼────────────┼───────────┤
│ 32       │ 0.614s        │ 0.061s      │ 10.1x   │ 1042 MB      │ 152 MB     │ 85.4%     │
└──────────┴───────────────┴─────────────┴─────────┴──────────────┴────────────┴───────────┘

d=100 (high-dimensional)
┌──────────┬───────────────┬─────────────┬─────────┬──────────────┬────────────┬───────────┐
│ Restarts │ Standard Time │ Cached Time │ Speedup │ Standard Mem │ Cached Mem │ Mem Saved │
├──────────┼───────────────┼─────────────┼─────────┼──────────────┼────────────┼───────────┤
│ 4        │ 0.899s        │ 0.148s      │ 6.1x    │ 2.0 GB       │ 313 MB     │ 84.7%     │
├──────────┼───────────────┼─────────────┼─────────┼──────────────┼────────────┼───────────┤
│ 8        │ 1.836s        │ 0.183s      │ 10.0x   │ 3.9 GB       │ 490 MB     │ 87.8%     │
├──────────┼───────────────┼─────────────┼─────────┼──────────────┼────────────┼───────────┤
│ 16       │ 3.525s        │ 0.293s      │ 12.0x   │ 7.8 GB       │ 715 MB     │ 91.0%     │
├──────────┼───────────────┼─────────────┼─────────┼──────────────┼────────────┼───────────┤
│ 32       │ 6.568s        │ 0.458s      │ 14.3x   │ 15.5 GB      │ 1.2 GB     │ 92.2%     │
└──────────┴───────────────┴─────────────┴─────────┴──────────────┴────────────┴───────────┘
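For reference, the derived columns relate to the measured ones as Speedup = Standard Time / Cached Time and Mem Saved = 1 − Cached Mem / Standard Mem. Checking the d=100, 16-restart row:

```python
# Verify the derived columns for one benchmark row (d=100, 16 restarts).
standard_t, cached_t = 3.525, 0.293      # seconds, from the table
standard_mem, cached_mem = 7.8 * 1024, 715.0  # GB -> MB, and MB

speedup = standard_t / cached_t
mem_saved = 1 - cached_mem / standard_mem
print(f"{speedup:.1f}x, {mem_saved:.1%}")  # prints "12.0x, 91.0%"
```

Both derived values match the table, so the columns are internally consistent (small rounding differences are possible in rows where memory is reported in whole GB).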

Commits

add mtpfn codebase

Implements context caching for PFN models (see description above). Additionally:
- Added discretized acquisition functions compatible with Riemann posteriors