Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
151 changes: 151 additions & 0 deletions picolm/model.c
Original file line number Diff line number Diff line change
Expand Up @@ -705,6 +705,157 @@ float *model_forward(model_t *m, int token, int pos) {
return s->logits;
}

float *model_forward_batch(model_t *m, const int *tokens, int num_tokens, int start_pos) {
    /* Forward pass over a contiguous batch of tokens beginning at start_pos.
     * Each token is pushed through all transformer layers sequentially; its
     * K/V vectors are appended to the FP16 cache so later tokens (in this
     * batch and in future calls) can attend to it.  Returns the logits
     * [vocab_size] for the LAST token of the batch, or NULL on empty input,
     * a batch that would overrun the KV cache, or allocation failure. */
    if (!tokens || num_tokens <= 0) return NULL;

    model_config_t *c = &m->config;
    model_weights_t *w = &m->weights;
    run_state_t *s = &m->state;

    int dim = c->n_embd;
    int n_ffn = c->n_ffn;
    int n_heads = c->n_heads;
    int n_kv_heads = c->n_kv_heads;
    int head_dim = c->head_dim;
    int kv_dim = n_kv_heads * head_dim;
    int kv_mul = n_heads / n_kv_heads; /* query heads per KV head (GQA) */
    int seq_len = c->max_seq_len;
    int half_dim = head_dim / 2;

    /* Refuse batches that would write past the end of the KV cache. */
    if (start_pos < 0 || start_pos + num_tokens > seq_len) return NULL;

    /* Heap scratch: the residual stream for every token in the batch,
     * shape [num_tokens, dim], plus a per-head attention accumulator.
     * (The accumulator was a fixed float[256] on the stack before, which
     * silently overflowed for head_dim > 256; size it to head_dim.) */
    float *batch_x = malloc((size_t)num_tokens * (size_t)dim * sizeof(float));
    float *acc = malloc((size_t)head_dim * sizeof(float));
    if (!batch_x || !acc) {
        free(batch_x);
        free(acc);
        return NULL;
    }

    float inv_sqrt_hd = 1.0f / sqrtf((float)head_dim);

    /* 1. Embedding lookup for all tokens in the batch */
    size_t row_bytes = gguf_type_row_size(w->type_token_embd, dim);
    for (int i = 0; i < num_tokens; i++) {
        const void *embd_row = (const uint8_t *)w->token_embd + (size_t)tokens[i] * row_bytes;
        dequantize_row(embd_row, batch_x + (size_t)i * dim, dim, w->type_token_embd);
    }

    /* 2. Transformer layers.  Layer-major order is valid because token i at
     * layer l only attends to positions <= start_pos + i, whose layer-l K/V
     * were written either by a previous call or earlier in this inner loop. */
    for (int l = 0; l < c->n_layers; l++) {
        layer_weights_t *lw = &w->layers[l];
        uint16_t *kcache_layer = s->key_cache + (size_t)l * seq_len * kv_dim;
        uint16_t *vcache_layer = s->val_cache + (size_t)l * seq_len * kv_dim;

        for (int i = 0; i < num_tokens; i++) {
            int pos = start_pos + i;

            /* Pre-computed RoPE tables for this absolute position */
            const float *cos_pos = s->rope_cos + (size_t)pos * half_dim;
            const float *sin_pos = s->rope_sin + (size_t)pos * half_dim;

            memcpy(s->x, batch_x + (size_t)i * dim, (size_t)dim * sizeof(float));

            /* ---- Attention ---- */
            rmsnorm(s->xb, s->x, s->attn_norm_w[l], dim);

            /* Q projection */
            matmul(s->q, s->xb, lw->attn_q, dim, dim, lw->type_attn_q);

            /* K: project into a float temp, rotate, then store as FP16 */
            float *k_tmp = s->xb2; /* reuse xb2 as temp for K (kv_dim <= dim) */
            matmul(k_tmp, s->xb, lw->attn_k, dim, kv_dim, lw->type_attn_k);

            /* Apply RoPE to Q and K (using pre-computed tables) */
            rope(s->q, k_tmp, head_dim, n_heads, n_kv_heads, cos_pos, sin_pos);

            /* Store rotated K as FP16 at this position */
            uint16_t *key_pos_fp16 = kcache_layer + (size_t)pos * kv_dim;
            for (int d = 0; d < kv_dim; d++) {
                key_pos_fp16[d] = fp32_to_fp16(k_tmp[d]);
            }

            /* V projection -> store directly as FP16 (V is not rotated) */
            float *v_tmp = s->xb2;
            matmul(v_tmp, s->xb, lw->attn_v, dim, kv_dim, lw->type_attn_v);
            uint16_t *val_pos_fp16 = vcache_layer + (size_t)pos * kv_dim;
            for (int d = 0; d < kv_dim; d++) {
                val_pos_fp16[d] = fp32_to_fp16(v_tmp[d]);
            }

            /* ---- Attention with online softmax: one pass over the cache,
             * no O(seq_len) score buffer ---- */
            for (int h = 0; h < n_heads; h++) {
                float *qh = s->q + h * head_dim;
                int kv_h = h / kv_mul; /* KV head shared by this query head */
                float *xbh = s->xb + h * head_dim;

                float max_score = -1e30f;
                float sum_exp = 0.0f;
                memset(acc, 0, (size_t)head_dim * sizeof(float));

                for (int t = 0; t <= pos; t++) {
                    /* score = dot(Q_h, K_t) / sqrt(head_dim) */
                    const uint16_t *kt = kcache_layer + (size_t)t * kv_dim + kv_h * head_dim;
                    float score = 0.0f;
                    for (int d = 0; d < head_dim; d++) {
                        score += qh[d] * fp16_to_fp32(kt[d]);
                    }
                    score *= inv_sqrt_hd;

                    const uint16_t *vt = vcache_layer + (size_t)t * kv_dim + kv_h * head_dim;

                    if (score > max_score) {
                        /* New running max: rescale the old sum and accumulator
                         * so exponentials stay relative to the current max. */
                        float correction = expf(max_score - score);
                        sum_exp = sum_exp * correction + 1.0f;
                        for (int d = 0; d < head_dim; d++) {
                            acc[d] = acc[d] * correction + fp16_to_fp32(vt[d]);
                        }
                        max_score = score;
                    } else {
                        /* att_w: renamed from `w`, which shadowed the
                         * model_weights_t pointer above. */
                        float att_w = expf(score - max_score);
                        sum_exp += att_w;
                        for (int d = 0; d < head_dim; d++) {
                            acc[d] += att_w * fp16_to_fp32(vt[d]);
                        }
                    }
                }

                /* Normalize by the softmax partition sum */
                float inv_sum = 1.0f / sum_exp;
                for (int d = 0; d < head_dim; d++) {
                    xbh[d] = acc[d] * inv_sum;
                }
            }

            /* Output projection + residual */
            matmul(s->xb2, s->xb, lw->attn_output, dim, dim, lw->type_attn_output);
            vec_add(s->x, s->xb2, dim);

            /* ---- FFN (SwiGLU): down(silu(gate(x)) * up(x)) ---- */
            rmsnorm(s->xb, s->x, s->ffn_norm_w[l], dim);

            matmul(s->hb, s->xb, lw->ffn_gate, dim, n_ffn, lw->type_ffn_gate);
            matmul(s->hb2, s->xb, lw->ffn_up, dim, n_ffn, lw->type_ffn_up);

            silu(s->hb, n_ffn);
            elemwise_mul(s->hb, s->hb, s->hb2, n_ffn);

            matmul(s->xb, s->hb, lw->ffn_down, n_ffn, dim, lw->type_ffn_down);
            vec_add(s->x, s->xb, dim);

            /* Write the layer output back into the batch residual stream */
            memcpy(batch_x + (size_t)i * dim, s->x, (size_t)dim * sizeof(float));
        }
    }

    /* 3. Final RMSNorm — only the last token's logits are needed */
    memcpy(s->x, batch_x + (size_t)(num_tokens - 1) * dim, (size_t)dim * sizeof(float));
    rmsnorm(s->x, s->x, s->output_norm_w, dim);

    /* 4. Output projection -> logits */
    matmul(s->logits, s->x, w->output, dim, c->vocab_size, w->type_output);

    free(acc);
    free(batch_x);

    return s->logits;
}

void model_free(model_t *m) {
if (m->state.mem_block) {
free(m->state.mem_block);
Expand Down
4 changes: 4 additions & 0 deletions picolm/model.h
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,10 @@ int model_load(model_t *m, const char *path, int max_seq_len);
/* Run one forward pass. Returns pointer to logits[vocab_size]. */
float *model_forward(model_t *m, int token, int pos);

/* Run a forward pass for a contiguous batch of tokens starting at position start_pos,
 * filling the KV cache for every position in the batch. Returns a pointer to the
 * logits[vocab_size] for the LAST token in the batch, or NULL on failure
 * (empty batch or allocation failure). */
float *model_forward_batch(model_t *m, const int *tokens, int num_tokens, int start_pos);

/* Free all resources. */
void model_free(model_t *m);

Expand Down
11 changes: 11 additions & 0 deletions picolm/picolm.c
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,17 @@ int main(int argc, char **argv) {
total_steps = model.config.max_seq_len;
}

float* logits = model_forward_batch(&model, prompt_tokens, n_prompt, 0); /* prefill all prompt tokens in batch */
grammar_apply(&grammar, logits, model.config.vocab_size);
int next = sampler_sample(&sampler, logits, model.config.vocab_size);
grammar_advance(&grammar, &tokenizer, next);
const char* piece = tokenizer_decode(&tokenizer, token, next);
printf("%s", piece);
fflush(stdout);
pos = n_prompt;
token = next;
total_gen++;

for (; pos < total_steps; pos++) {
/* Determine which token to feed */
if (pos < start_pos) {
Expand Down