From 3488d20cad93ab458cd870ee40fde6917e785ef9 Mon Sep 17 00:00:00 2001 From: ProdigiousPersonn Date: Fri, 6 Feb 2026 12:56:15 -0800 Subject: [PATCH] KV Cache --- include/ml_lib/core/attention-layer.h | 6 +++ include/ml_lib/core/masking.h | 2 +- include/ml_lib/core/sin-pos-encode.h | 1 + include/ml_lib/core/transformer-block.h | 2 + include/ml_lib/math/matrix.h | 2 + include/ml_lib/models/transformer.h | 3 ++ source/core/attention-layer.cpp | 67 +++++++++++++++++++++++++ source/core/sin-pos-encode.cpp | 10 ++++ source/core/transformer-block.cpp | 19 +++++++ source/math/matrix.cpp | 20 ++++++++ source/models/transformer.cpp | 57 +++++++++++++++++++++ 11 files changed, 188 insertions(+), 1 deletion(-) diff --git a/include/ml_lib/core/attention-layer.h b/include/ml_lib/core/attention-layer.h index 5e56d68..23c6b2f 100644 --- a/include/ml_lib/core/attention-layer.h +++ b/include/ml_lib/core/attention-layer.h @@ -29,10 +29,16 @@ class AttentionLayer { Matrix V_cache; std::vector attention_weights_cache; + // KV cache for inference + Matrix kv_K_cache; + Matrix kv_V_cache; + public: AttentionLayer(int embed_dim, int num_heads); Matrix forward(const Matrix& input); + Matrix forward_cached(const Matrix& input); + void clear_cache(); Matrix backward(const Matrix& grad_output); void update(Optimizer* opt); }; diff --git a/include/ml_lib/core/masking.h b/include/ml_lib/core/masking.h index d02ef8a..326fe97 100644 --- a/include/ml_lib/core/masking.h +++ b/include/ml_lib/core/masking.h @@ -9,4 +9,4 @@ class Masking { ~Masking(); Matrix apply(const Matrix& x); -} \ No newline at end of file +}; \ No newline at end of file diff --git a/include/ml_lib/core/sin-pos-encode.h b/include/ml_lib/core/sin-pos-encode.h index 0d90d37..e98a26b 100644 --- a/include/ml_lib/core/sin-pos-encode.h +++ b/include/ml_lib/core/sin-pos-encode.h @@ -13,4 +13,5 @@ class SinPositionalEncoding { SinPositionalEncoding(int features, int max_seq_len); Matrix forward(const Matrix &input); + Matrix forward(const Matrix &input, int position); }; \ No newline at end of file diff --git a/include/ml_lib/core/transformer-block.h b/include/ml_lib/core/transformer-block.h index 41f3f9d..e4183b6 100644 --- a/include/ml_lib/core/transformer-block.h +++ b/include/ml_lib/core/transformer-block.h @@ -20,6 +20,8 @@ class TransformerBlock { TransformerBlock(int embed_dim, int num_heads, int ff_dim); Matrix forward(const Matrix& input); + Matrix forward_cached(const Matrix& input); + void clear_cache(); Matrix backward(const Matrix& grad_output); void update(Optimizer* opt); }; diff --git a/include/ml_lib/math/matrix.h b/include/ml_lib/math/matrix.h index 17889b1..c12f9a1 100644 --- a/include/ml_lib/math/matrix.h +++ b/include/ml_lib/math/matrix.h @@ -47,6 +47,8 @@ class Matrix { bool operator!=(const Matrix& other) const; bool approxEqual(const Matrix& other, double epsilon = 1e-9) const; + Matrix verticalConcat(const Matrix& other) const; + static EliminationResult forwardElimination(const Matrix& m, const Matrix& aug = Matrix(0, 0)); static EliminationResult backwardElimination(const Matrix& m, const Matrix& aug = Matrix(0, 0)); diff --git a/include/ml_lib/models/transformer.h b/include/ml_lib/models/transformer.h index da1dcc8..e518d41 100644 --- a/include/ml_lib/models/transformer.h +++ b/include/ml_lib/models/transformer.h @@ -33,4 +33,7 @@ class Transformer : public GradientModel { Matrix forward(const std::vector& tokens); void backward(const Matrix& y_true) override; void update() override; + + std::vector generate(const std::vector& prompt, int max_tokens); + void clear_cache(); }; diff --git a/source/core/attention-layer.cpp b/source/core/attention-layer.cpp index e805bb2..721bfc4 100644 --- a/source/core/attention-layer.cpp +++ b/source/core/attention-layer.cpp @@ -1,4 +1,5 @@ #include "ml_lib/core/attention-layer.h" +#include "ml_lib/core/masking.h" #include "softmax.h" #include #include @@ -76,6 +77,11 @@ Matrix AttentionLayer::forward(const Matrix& input) } } + // casual masking + Masking causal_mask(seq_len, seq_len); + scores = causal_mask.apply(scores); + + // apply softmax Matrix attn_weights = Softmax::apply(scores); attention_weights_cache[h] = attn_weights; @@ -91,6 +97,67 @@ Matrix AttentionLayer::forward(const Matrix& input) return output * W_o; } +Matrix AttentionLayer::forward_cached(const Matrix& input) +{ + Matrix Q_new = input * W_q; + Matrix K_new = input * W_k; + Matrix V_new = input * W_v; + + // append to KV cache + if (kv_K_cache.empty()) { + kv_K_cache = K_new; + kv_V_cache = V_new; + } else { + kv_K_cache = kv_K_cache.verticalConcat(K_new); + kv_V_cache = kv_V_cache.verticalConcat(V_new); + } + + int cached_len = kv_K_cache.rows(); + double scale = 1.0 / std::sqrt((double)head_dim); + + Matrix output(1, embed_dim); + + for (int h = 0; h < num_heads; h++) { + int head_start = h * head_dim; + + Matrix Q_h(1, head_dim); + Matrix K_h(cached_len, head_dim); + Matrix V_h(cached_len, head_dim); + + for (int j = 0; j < head_dim; j++) { + Q_h(0, j) = Q_new(0, head_start + j); + } + for (int i = 0; i < cached_len; i++) { + for (int j = 0; j < head_dim; j++) { + K_h(i, j) = kv_K_cache(i, head_start + j); + V_h(i, j) = kv_V_cache(i, head_start + j); + } + } + + // scores + Matrix scores = Q_h * K_h.transpose(); + for (int j = 0; j < cached_len; j++) { + scores(0, j) *= scale; + } + + // apply softmax + Matrix attn_weights = Softmax::apply(scores); + Matrix head_output = attn_weights * V_h; + + for (int j = 0; j < head_dim; j++) { + output(0, head_start + j) = head_output(0, j); + } + } + + return output * W_o; +} + +void AttentionLayer::clear_cache() +{ + kv_K_cache = Matrix(); + kv_V_cache = Matrix(); +} + Matrix AttentionLayer::backward(const Matrix& grad_output) { int seq_len = grad_output.rows(); diff --git a/source/core/sin-pos-encode.cpp b/source/core/sin-pos-encode.cpp index f17c33f..5c2df97 100644 --- a/source/core/sin-pos-encode.cpp +++ b/source/core/sin-pos-encode.cpp @@ -33,3 +33,13 @@ Matrix SinPositionalEncoding::forward(const Matrix &input) return output; } +Matrix SinPositionalEncoding::forward(const Matrix &input, int position) +{ + Matrix output(input.rows(), input.cols()); + + for (int j = 0; j < input.cols(); j++) { + output(0, j) = input(0, j) + pe_matrix(position, j); + } + return output; +} + diff --git a/source/core/transformer-block.cpp b/source/core/transformer-block.cpp index 6c7b8f2..a440e38 100644 --- a/source/core/transformer-block.cpp +++ b/source/core/transformer-block.cpp @@ -31,6 +31,25 @@ Matrix TransformerBlock::forward(const Matrix& input) return output; } +Matrix TransformerBlock::forward_cached(const Matrix& input) +{ + Matrix attn_out = attention.forward_cached(input); + + Matrix residual1 = input + attn_out; + Matrix normed1 = norm1.forward(residual1); + + Matrix ff_out = ff2.forward(ff1.forward(normed1)); + Matrix residual2 = normed1 + ff_out; + Matrix output = norm2.forward(residual2); + + return output; +} + +void TransformerBlock::clear_cache() +{ + attention.clear_cache(); +} + Matrix TransformerBlock::backward(const Matrix& grad_output) { Matrix grad_norm2 = norm2.backward(grad_output); diff --git a/source/math/matrix.cpp b/source/math/matrix.cpp index c499c84..cec8742 100644 --- a/source/math/matrix.cpp +++ b/source/math/matrix.cpp @@ -402,6 +402,26 @@ Matrix Matrix::transpose() const { return result; } +Matrix Matrix::verticalConcat(const Matrix& other) const { + if (empty()) return other; + if (other.empty()) return *this; + if (m_cols != other.m_cols) { + throw std::invalid_argument("Column count must match for vertical concatenation."); + } + Matrix result(m_rows + other.m_rows, m_cols); + for (int i = 0; i < m_rows; i++) { + for (int j = 0; j < m_cols; j++) { + result(i, j) = (*this)(i, j); + } + } + for (int i = 0; i < other.m_rows; i++) { + for (int j = 0; j < m_cols; j++) { + result(m_rows + i, j) = other(i, j); + } + } + return result; +} + double Matrix::determinant() const { if (empty()) { return 0.0; diff --git a/source/models/transformer.cpp b/source/models/transformer.cpp index 01657d4..36ea63d 100644 --- a/source/models/transformer.cpp +++ b/source/models/transformer.cpp @@ -65,3 +65,60 @@ void Transformer::update() } output_projection.update(optimizer.get()); } + +std::vector Transformer::generate(const std::vector& prompt, int max_tokens) +{ + clear_cache(); + + std::vector output = prompt; + Matrix x; + + // prefill tokens to build up the KV cache + for (int i = 0; i < (int)prompt.size(); i++) { + Matrix embedded = embedding.forward({prompt[i]}); + x = pos_encoding.forward(embedded, i); + + for (auto& block : blocks) { + x = block->forward_cached(x); + } + } + + // get first predicted token from last prefill + Matrix logits = output_projection.forward(x); + int pos = prompt.size(); + + for (int t = 0; t < max_tokens; t++) { + // pick token with highest logit + int best_token = 0; + double best_val = logits(0, 0); + for (int v = 1; v < vocab_size; v++) { + if (logits(0, v) > best_val) { + best_val = logits(0, v); + best_token = v; + } + } + + output.push_back(best_token); + pos++; + + if (pos >= max_seq_len) break; + if (t == max_tokens - 1) break; + + // process new token through cache + Matrix embedded = embedding.forward({best_token}); + x = pos_encoding.forward(embedded, pos - 1); + for (auto& block : blocks) { + x = block->forward_cached(x); + } + logits = output_projection.forward(x); + } + + return output; +} + +void Transformer::clear_cache() +{ + for (auto& block : blocks) { + block->clear_cache(); + } +}