Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions include/ml_lib/core/attention-layer.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,16 @@ class AttentionLayer {
Matrix V_cache;
std::vector<Matrix> attention_weights_cache;

// KV cache for inference
Matrix kv_K_cache;
Matrix kv_V_cache;

public:
AttentionLayer(int embed_dim, int num_heads);

Matrix forward(const Matrix& input);
Matrix forward_cached(const Matrix& input);
void clear_cache();
Matrix backward(const Matrix& grad_output);
void update(Optimizer* opt);
};
2 changes: 1 addition & 1 deletion include/ml_lib/core/masking.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,4 @@ class Masking {
~Masking();
Matrix apply(const Matrix& x);

}
};
1 change: 1 addition & 0 deletions include/ml_lib/core/sin-pos-encode.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,5 @@ class SinPositionalEncoding {
SinPositionalEncoding(int features, int max_seq_len);

Matrix forward(const Matrix &input);
Matrix forward(const Matrix &input, int position);
};
2 changes: 2 additions & 0 deletions include/ml_lib/core/transformer-block.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ class TransformerBlock {
TransformerBlock(int embed_dim, int num_heads, int ff_dim);

Matrix forward(const Matrix& input);
Matrix forward_cached(const Matrix& input);
void clear_cache();
Matrix backward(const Matrix& grad_output);
void update(Optimizer* opt);
};
2 changes: 2 additions & 0 deletions include/ml_lib/math/matrix.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ class Matrix {
bool operator!=(const Matrix& other) const;
bool approxEqual(const Matrix& other, double epsilon = 1e-9) const;

Matrix verticalConcat(const Matrix& other) const;

static EliminationResult forwardElimination(const Matrix& m, const Matrix& aug = Matrix(0, 0));
static EliminationResult backwardElimination(const Matrix& m, const Matrix& aug = Matrix(0, 0));

Expand Down
3 changes: 3 additions & 0 deletions include/ml_lib/models/transformer.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,4 +33,7 @@ class Transformer : public GradientModel {
Matrix forward(const std::vector<int>& tokens);
void backward(const Matrix& y_true) override;
void update() override;

std::vector<int> generate(const std::vector<int>& prompt, int max_tokens);
void clear_cache();
};
67 changes: 67 additions & 0 deletions source/core/attention-layer.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include "ml_lib/core/attention-layer.h"
#include "ml_lib/core/masking.h"
#include "softmax.h"
#include <cmath>
#include <stdexcept>
Expand Down Expand Up @@ -76,6 +77,11 @@ Matrix AttentionLayer::forward(const Matrix& input)
}
}

// causal masking
Masking causal_mask(seq_len, seq_len);
scores = causal_mask.apply(scores);

// apply softmax
Matrix attn_weights = Softmax::apply(scores);
attention_weights_cache[h] = attn_weights;

Expand All @@ -91,6 +97,67 @@ Matrix AttentionLayer::forward(const Matrix& input)
return output * W_o;
}

/// Single-step attention that reuses keys/values cached by earlier calls.
///
/// Projects the one-row `input` through W_q / W_k / W_v, appends the new key
/// and value rows to the KV cache, and attends the new query over every cached
/// row, one head at a time. No causal mask is applied here: the cache only
/// contains rows appended by this call and the calls before it.
///
/// @param input  Matrix with exactly one row (the column count is not checked
///               here; it must be compatible with W_q / W_k / W_v).
/// @return The 1 x embed_dim concatenation of the head outputs multiplied by W_o.
/// @throws std::invalid_argument if `input` does not have exactly one row.
Matrix AttentionLayer::forward_cached(const Matrix& input)
{
    // Only query row 0 is consumed and the output is a single row, so any other
    // row count would silently corrupt the cache. Reject it before mutating state.
    if (input.rows() != 1) {
        throw std::invalid_argument("forward_cached expects a single-row input.");
    }

    const Matrix Q_new = input * W_q;
    const Matrix K_new = input * W_k;
    const Matrix V_new = input * W_v;

    // append to KV cache (verticalConcat returns its argument when the cache is empty)
    kv_K_cache = kv_K_cache.verticalConcat(K_new);
    kv_V_cache = kv_V_cache.verticalConcat(V_new);

    const int cached_len = kv_K_cache.rows();
    const double scale = 1.0 / std::sqrt(static_cast<double>(head_dim));

    Matrix output(1, embed_dim);

    for (int h = 0; h < num_heads; h++) {
        const int head_start = h * head_dim;

        // slice out this head's columns from the query and the cached keys/values
        Matrix Q_h(1, head_dim);
        Matrix K_h(cached_len, head_dim);
        Matrix V_h(cached_len, head_dim);

        for (int j = 0; j < head_dim; j++) {
            Q_h(0, j) = Q_new(0, head_start + j);
        }
        for (int i = 0; i < cached_len; i++) {
            for (int j = 0; j < head_dim; j++) {
                K_h(i, j) = kv_K_cache(i, head_start + j);
                V_h(i, j) = kv_V_cache(i, head_start + j);
            }
        }

        // scaled dot-product scores of the new query against every cached key
        Matrix scores = Q_h * K_h.transpose();
        for (int j = 0; j < cached_len; j++) {
            scores(0, j) *= scale;
        }

        // apply softmax and mix the cached values
        Matrix attn_weights = Softmax::apply(scores);
        Matrix head_output = attn_weights * V_h;

        // write this head's result into its column range of the output row
        for (int j = 0; j < head_dim; j++) {
            output(0, head_start + j) = head_output(0, j);
        }
    }

    return output * W_o;
}

void AttentionLayer::clear_cache()
{
kv_K_cache = Matrix();
kv_V_cache = Matrix();
}

Matrix AttentionLayer::backward(const Matrix& grad_output)
{
int seq_len = grad_output.rows();
Expand Down
10 changes: 10 additions & 0 deletions source/core/sin-pos-encode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,13 @@ Matrix SinPositionalEncoding::forward(const Matrix &input)
return output;
}

/// Adds the positional-encoding row for `position` to the first row of `input`.
///
/// The result has the same dimensions as `input`, but only row 0 is written;
/// any further rows keep whatever the Matrix constructor gave them.
/// `position` is used directly as a row index into pe_matrix and is not
/// range-checked here.
///
/// @param input     matrix whose first row is the embedding to offset.
/// @param position  row of pe_matrix to add.
/// @return a new matrix; row 0 holds input row 0 plus the encoding row.
Matrix SinPositionalEncoding::forward(const Matrix &input, int position)
{
    const auto num_cols = input.cols();
    Matrix encoded(input.rows(), input.cols());

    int col = 0;
    while (col < num_cols) {
        const auto shifted = input(0, col) + pe_matrix(position, col);
        encoded(0, col) = shifted;
        ++col;
    }

    return encoded;
}

19 changes: 19 additions & 0 deletions source/core/transformer-block.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,25 @@ Matrix TransformerBlock::forward(const Matrix& input)
return output;
}

/// Runs the block using the attention layer's cached path (forward_cached).
///
/// Order of operations: cached attention -> residual add -> norm1 ->
/// ff1 -> ff2 -> residual add (onto the norm1 output) -> norm2.
///
/// @param input  activations passed to the attention layer's forward_cached.
/// @return the output of norm2.
Matrix TransformerBlock::forward_cached(const Matrix& input)
{
    // attention sub-layer with its skip connection
    Matrix attended = attention.forward_cached(input);
    Matrix attn_sum = input + attended;
    Matrix x = norm1.forward(attn_sum);

    // feed-forward sub-layer; the skip connection starts from the normalised x
    Matrix hidden = ff1.forward(x);
    Matrix projected = ff2.forward(hidden);
    Matrix ff_sum = x + projected;

    return norm2.forward(ff_sum);
}

void TransformerBlock::clear_cache()
{
attention.clear_cache();
}

Matrix TransformerBlock::backward(const Matrix& grad_output)
{
Matrix grad_norm2 = norm2.backward(grad_output);
Expand Down
20 changes: 20 additions & 0 deletions source/math/matrix.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -402,6 +402,26 @@ Matrix Matrix::transpose() const {
return result;
}

/// Stacks `other` underneath this matrix and returns the combined matrix.
///
/// If either operand is empty, a copy of the other operand is returned as-is.
///
/// @param other  matrix whose rows are placed below this matrix's rows.
/// @return a new matrix with m_rows + other.m_rows rows and m_cols columns.
/// @throws std::invalid_argument if both are non-empty and the column counts differ.
Matrix Matrix::verticalConcat(const Matrix& other) const {
    if (empty()) {
        return other;
    }
    if (other.empty()) {
        return *this;
    }
    if (m_cols != other.m_cols) {
        throw std::invalid_argument("Column count must match for vertical concatenation.");
    }

    Matrix stacked(m_rows + other.m_rows, m_cols);

    // single pass over the destination rows, picking the source by row index
    for (int r = 0; r < m_rows + other.m_rows; ++r) {
        const bool from_top = r < m_rows;
        for (int c = 0; c < m_cols; ++c) {
            stacked(r, c) = from_top ? (*this)(r, c) : other(r - m_rows, c);
        }
    }

    return stacked;
}

double Matrix::determinant() const {
if (empty()) {
return 0.0;
Expand Down
57 changes: 57 additions & 0 deletions source/models/transformer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,3 +65,60 @@ void Transformer::update()
}
output_projection.update(optimizer.get());
}

std::vector<int> Transformer::generate(const std::vector<int>& prompt, int max_tokens)
{
clear_cache();

std::vector<int> output = prompt;
Matrix x;

// prefill tokens to build up the KV cache
for (int i = 0; i < (int)prompt.size(); i++) {
Matrix embedded = embedding.forward({prompt[i]});
x = pos_encoding.forward(embedded, i);

for (auto& block : blocks) {
x = block->forward_cached(x);
}
}

// get first predicted token from last prefill
Matrix logits = output_projection.forward(x);
int pos = prompt.size();

for (int t = 0; t < max_tokens; t++) {
// pick token with highest logit
int best_token = 0;
double best_val = logits(0, 0);
for (int v = 1; v < vocab_size; v++) {
if (logits(0, v) > best_val) {
best_val = logits(0, v);
best_token = v;
}
}

output.push_back(best_token);
pos++;

if (pos >= max_seq_len) break;
if (t == max_tokens - 1) break;

// process new token through cache
Matrix embedded = embedding.forward({best_token});
x = pos_encoding.forward(embedded, pos - 1);
for (auto& block : blocks) {
x = block->forward_cached(x);
}
logits = output_projection.forward(x);
}

return output;
}

void Transformer::clear_cache()
{
for (auto& block : blocks) {
block->clear_cache();
}
}
Loading