Skip to content

Commit 46319e9

Browse files
committed
[feat] Add IOSurface pool and env-configurable ACCUM_STEPS/MAX_COMPILES (upstream PR maderix#33)
1 parent b8c497c commit 46319e9

4 files changed

Lines changed: 70 additions & 26 deletions

File tree

training/stories_config.h

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -22,19 +22,22 @@
2222
#define SEQ 256
2323
#define NLAYERS 12
2424
#define VOCAB 32000
25-
#define DEFAULT_ACCUM_STEPS 10
26-
#define MAX_COMPILES 100
27-
static int g_accum_steps = DEFAULT_ACCUM_STEPS;
25+
#define ACCUM_STEPS_DEFAULT 10
26+
#define MAX_COMPILES_DEFAULT 100
27+
#define ACCUM_STEPS ACCUM_STEPS_DEFAULT
28+
#define MAX_COMPILES MAX_COMPILES_DEFAULT
2829

29-
static void init_accum_steps(void) {
30+
static inline int get_accum_steps(void) {
3031
const char *env = getenv("ANE_ACCUM_STEPS");
31-
if (env && env[0]) {
32-
int v = atoi(env);
33-
if (v > 0 && v <= 10000) g_accum_steps = v;
34-
}
32+
if (env) { int v = atoi(env); if (v > 0 && v <= 10000) return v; }
33+
return ACCUM_STEPS_DEFAULT;
3534
}
3635

37-
#define ACCUM_STEPS g_accum_steps
36+
static inline int get_max_compiles(void) {
37+
const char *env = getenv("ANE_MAX_COMPILES");
38+
if (env) { int v = atoi(env); if (v > 0) return v; }
39+
return MAX_COMPILES_DEFAULT;
40+
}
3841

3942
// Per compile: 5 weight-bearing kernels per layer + 1 classifier = 5*12+1 = 61
4043
// Plus 1 static (sdpaBwd2 per layer, no weights) = 12 more but those are weight-free
@@ -97,7 +100,7 @@ typedef struct {
97100
} LayerGrads;
98101

99102
// ANE kernels per layer
100-
typedef struct { void *model; IOSurfaceRef ioIn, ioOut; void *request; void *tmpDir; } Kern;
103+
// Handle for one compiled ANE kernel together with the resources it holds.
typedef struct {
    void *model;         // retained model object; unloaded and CFRelease'd in free_kern
    IOSurfaceRef ioIn;   // input surface obtained from make_surface
    IOSurfaceRef ioOut;  // output surface obtained from make_surface
    void *request;       // retained request object; CFRelease'd in free_kern
    void *tmpDir;        // retained path that free_kern removes from disk, then releases
    size_t inBytes;      // byte size ioIn was requested with (used when returning it to the pool)
    size_t outBytes;     // byte size ioOut was requested with (used when returning it to the pool)
} Kern;
101104
typedef struct {
102105
Kern *fwdAttn, *fwdFFN, *ffnBwd, *sdpaBwd1, *sdpaBwd2, *qkvBwd;
103106
} LayerKernels;

training/stories_io.h

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,41 @@
33
#include "stories_config.h"
44
#include <arm_neon.h>
55

6+
// IOSurface pool — reuse freed surfaces of the same size
7+
#define IOSURF_POOL_MAX 128
8+
static struct {
9+
IOSurfaceRef surfaces[IOSURF_POOL_MAX];
10+
size_t sizes[IOSURF_POOL_MAX];
11+
int count;
12+
} g_iosurf_pool = { .count = 0 };
13+
614
// Returns an IOSurface for a buffer of `bytes` bytes.
//
// A pooled surface whose recorded size equals `bytes` is preferred; it is
// removed from the pool and handed to the caller. When the pool has no match,
// a new surface is created as a single row of `bytes` one-byte elements.
static IOSurfaceRef make_surface(size_t bytes) {
    // Find the first pooled slot recorded with exactly this size.
    int hit = -1;
    for (int slot = 0; slot < g_iosurf_pool.count; slot++) {
        if (g_iosurf_pool.sizes[slot] == bytes) {
            hit = slot;
            break;
        }
    }

    if (hit >= 0) {
        IOSurfaceRef reused = g_iosurf_pool.surfaces[hit];
        // Swap-remove: move the last occupied slot into the hole and shrink.
        int last = g_iosurf_pool.count - 1;
        g_iosurf_pool.surfaces[hit] = g_iosurf_pool.surfaces[last];
        g_iosurf_pool.sizes[hit] = g_iosurf_pool.sizes[last];
        g_iosurf_pool.count = last;
        return reused;
    }

    // Nothing reusable: create a fresh surface.
    NSDictionary *props = @{
        (id)kIOSurfaceWidth: @(bytes),
        (id)kIOSurfaceHeight: @1,
        (id)kIOSurfaceBytesPerElement: @1,
        (id)kIOSurfaceBytesPerRow: @(bytes),
        (id)kIOSurfaceAllocSize: @(bytes),
        (id)kIOSurfacePixelFormat: @0,
    };
    return IOSurfaceCreate((__bridge CFDictionaryRef)props);
}
1230

31+
// Hands a surface back for later reuse by make_surface().
//
// s:     the surface being given up by its owner; NULL is ignored.
// bytes: the size the surface was requested with; make_surface() matches
//        pooled surfaces against this value.
//
// The surface is stored in the pool while a slot is free; once the pool holds
// IOSURF_POOL_MAX entries the surface is released instead.
static void pool_return_surface(IOSurfaceRef s, size_t bytes) {
    // A NULL surface (e.g. from a failed creation) must not be pooled, since
    // make_surface() would later hand it out as a valid reuse, and must not
    // reach CFRelease, which does not accept NULL.
    if (!s) return;

    if (g_iosurf_pool.count < IOSURF_POOL_MAX) {
        g_iosurf_pool.surfaces[g_iosurf_pool.count] = s;
        g_iosurf_pool.sizes[g_iosurf_pool.count] = bytes;
        g_iosurf_pool.count++;
    } else {
        // Pool is full: drop the reference rather than keep the surface.
        CFRelease(s);
    }
}
40+
1341
static NSData *build_blob(const float *w, int rows, int cols) {
1442
size_t ws=(size_t)rows*cols*2, tot=128+ws; // size_t prevents int overflow (CRIT-04)
1543
uint8_t *b=(uint8_t*)calloc(tot,1);
@@ -121,6 +149,8 @@ static Kern *compile_kern_mil_w(NSString *mil, NSDictionary *weights, int ic_byt
121149
k->model = (void*)CFBridgingRetain(mdl);
122150
k->ioIn = make_surface(ic_bytes);
123151
k->ioOut = make_surface(oc_bytes);
152+
k->inBytes = ic_bytes;
153+
k->outBytes = oc_bytes;
124154
id wI = ((id(*)(Class,SEL,IOSurfaceRef))objc_msgSend)(g_AIO, @selector(objectWithIOSurface:), k->ioIn);
125155
id wO = ((id(*)(Class,SEL,IOSurfaceRef))objc_msgSend)(g_AIO, @selector(objectWithIOSurface:), k->ioOut);
126156
k->request = (void*)CFBridgingRetain(((id(*)(Class,SEL,id,id,id,id,id,id,id))objc_msgSend)(g_AR,
@@ -134,7 +164,8 @@ static void free_kern(Kern *k) {
134164
if (!k) return;
135165
id mdl = (__bridge id)k->model; NSError *e = nil;
136166
((BOOL(*)(id,SEL,unsigned int,NSError**))objc_msgSend)(mdl, @selector(unloadWithQoS:error:), 21, &e);
137-
CFRelease(k->ioIn); CFRelease(k->ioOut);
167+
pool_return_surface(k->ioIn, k->inBytes);
168+
pool_return_surface(k->ioOut, k->outBytes);
138169
[[NSFileManager defaultManager] removeItemAtPath:(__bridge id)k->tmpDir error:nil];
139170
CFRelease(k->model); CFRelease(k->request); CFRelease(k->tmpDir);
140171
free(k);

training/tiny_train.m

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -208,9 +208,15 @@ static bool load_checkpoint(const char *path, CkptHeader *hdr,
208208
return true;
209209
}
210210

211-
#define MAX_COMPILES 100
211+
// Returns the compile budget for this program.
//
// Reads the ANE_MAX_COMPILES environment variable. The value is accepted only
// when it parses to a positive integer that fits in an int; otherwise (unset,
// non-numeric, zero, negative, or out of int range) the default of 100 is
// returned. A budget of zero or less would make the compile-budget check in
// the training loop fail on every pass, so such values are never returned.
static inline int get_max_compiles_tiny(void) {
    const char *env = getenv("ANE_MAX_COMPILES");
    if (env) {
        // strtol instead of atoi: atoi has undefined behaviour when the text
        // does not fit in an int.
        long v = strtol(env, NULL, 10);
        // The round-trip comparison rejects values that narrowing to int
        // would alter.
        if (v > 0 && v == (long)(int)v) return (int)v;
    }
    return 100;
}
215+
// Returns the number of gradient-accumulation steps per recompile.
//
// Reads the ANE_ACCUM_STEPS environment variable. The value is accepted only
// when it parses to an integer in (0, 10000]; otherwise (unset, non-numeric,
// zero, negative, or too large) the default of 10 is returned. A value of
// zero or less must never be returned: the accumulation loops would then make
// no progress and the efficiency report divides by this number.
static inline int get_accum_steps_tiny(void) {
    const char *env = getenv("ANE_ACCUM_STEPS");
    if (env) {
        // strtol instead of atoi: atoi has undefined behaviour when the text
        // does not fit in an int, whereas strtol saturates and the range
        // check below rejects the result.
        long v = strtol(env, NULL, 10);
        if (v > 0 && v <= 10000) return (int)v;
    }
    return 10;
}
212219
#define KERNELS_PER_STEP 4
213-
#define ACCUM_STEPS 10
214220

215221
// === Pipeline: background compile via GCD ===
216222
typedef struct {
@@ -241,6 +247,8 @@ int main(int argc, char *argv[]) {
241247
float lr = 1.0f;
242248
int start_step = 0;
243249
bool resuming = false;
250+
int accum_steps = get_accum_steps_tiny();
251+
int max_compiles = get_max_compiles_tiny();
244252

245253
float *W1 = (float*)malloc(H * D * sizeof(float));
246254
float *W2 = (float*)malloc(D * H * sizeof(float));
@@ -288,12 +296,12 @@ int main(int argc, char *argv[]) {
288296
for (int i = 0; i < D*H; i++) W2[i] = 0.01f * cosf(i * 0.9f + 1.1f);
289297
printf("=== ANE Training: Pipeline Parallel + Grad Accumulation ===\n");
290298
printf("x:[%d,%d] -> W1:[%d,%d] -> ReLU -> W2:[%d,%d] -> y:[%d,%d]\n", S,D, H,D, D,H, S,D);
291-
printf("Accum %d steps per recompile | Pipeline: compile overlaps ANE eval\n", ACCUM_STEPS);
299+
printf("Accum %d steps per recompile | Pipeline: compile overlaps ANE eval\n", accum_steps);
292300
printf("ANE FP16 peak: 15.8 TFLOPS (M4) | Weights: %.1f KB\n\n", weight_bytes/1024.0);
293301
printf("FLOPs/step: ANE=%.0f (fwd+bwd) CPU=%.0f (dW) Total=%.0f\n",
294302
ane_flops_per_step, cpu_flops_per_step, total_flops_per_step);
295303
printf("Steps: %d, LR: %.4f, exec() budget: %d compiles\n\n",
296-
total_steps, lr, MAX_COMPILES);
304+
total_steps, lr, max_compiles);
297305
}
298306

299307
float *x = (float*)calloc(S * D, sizeof(float));
@@ -342,7 +350,7 @@ int main(int argc, char *argv[]) {
342350
int step = start_step;
343351
while (step < total_steps) {
344352
// Check compile budget
345-
if (g_compile_count + KERNELS_PER_STEP > MAX_COMPILES) {
353+
if (g_compile_count + KERNELS_PER_STEP > max_compiles) {
346354
free_kern(k1_fwd); free_kern(k2_fwd);
347355
free_kern(k1_bwd); free_kern(k2_bwd);
348356
save_checkpoint(CKPT_PATH, step, last_loss, D, H, S, total_steps, lr, W1, W2,
@@ -368,7 +376,7 @@ int main(int argc, char *argv[]) {
368376
// So we need to update weights BEFORE launching background compile
369377

370378
uint64_t t_batch = mach_absolute_time();
371-
for (int a = 0; a < ACCUM_STEPS && step < total_steps; a++, step++) {
379+
for (int a = 0; a < accum_steps && step < total_steps; a++, step++) {
372380
ane_eval_k(k1_fwd, x, h, D, H, S);
373381
for (int i = 0; i < S*H; i++) h_relu[i] = h[i] > 0 ? h[i] : 0;
374382
ane_eval_k(k2_fwd, h_relu, y, H, D, S);
@@ -422,7 +430,7 @@ int main(int argc, char *argv[]) {
422430
// Pipeline: launch background compile with updated weights,
423431
// then immediately start NEXT batch's ANE evals with OLD kernels
424432
// while compile runs concurrently on GCD queue
425-
bool can_pipeline = (step < total_steps) && (g_compile_count + KERNELS_PER_STEP <= MAX_COMPILES);
433+
bool can_pipeline = (step < total_steps) && (g_compile_count + KERNELS_PER_STEP <= max_compiles);
426434

427435
if (can_pipeline) {
428436
// Snapshot weights for background compile
@@ -455,7 +463,7 @@ int main(int argc, char *argv[]) {
455463
int steps_overlap = 0;
456464
uint64_t t_overlap = mach_absolute_time();
457465

458-
for (int a = 0; a < ACCUM_STEPS && step < total_steps; a++, step++) {
466+
for (int a = 0; a < accum_steps && step < total_steps; a++, step++) {
459467
ane_eval_k(k1_fwd, x, h, D, H, S);
460468
for (int i = 0; i < S*H; i++) h_relu[i] = h[i] > 0 ? h[i] : 0;
461469
ane_eval_k(k2_fwd, h_relu, y, H, D, S);
@@ -562,7 +570,7 @@ int main(int argc, char *argv[]) {
562570
// === Efficiency Report ===
563571
printf("\n=== Efficiency Report ===\n");
564572
printf("Total steps: %d\n", total_steps_done);
565-
printf("Total batches: %d (accum %d steps each)\n", total_batches, ACCUM_STEPS);
573+
printf("Total batches: %d (accum %d steps each)\n", total_batches, accum_steps);
566574
printf("Wall time: %.0f ms\n", total_wall_ms);
567575
printf("Compile time: %.0f ms (%.1f%%)\n", total_compile_ms, 100.0*total_compile_ms/total_wall_ms);
568576
printf("Train time: %.0f ms (%.1f%%)\n", total_train_ms, 100.0*total_train_ms/total_wall_ms);
@@ -589,8 +597,8 @@ int main(int argc, char *argv[]) {
589597
printf("Weight params: %d (%.1f KB FP16)\n",
590598
H*D + D*H, weight_bytes / 1024.0);
591599
printf("Compile amortization: %.1f ms compile / %d steps = %.2f ms/step overhead\n",
592-
total_compile_ms / total_batches, ACCUM_STEPS,
593-
total_compile_ms / total_batches / ACCUM_STEPS);
600+
total_compile_ms / total_batches, accum_steps,
601+
total_compile_ms / total_batches / accum_steps);
594602
printf("Compile fraction: %.1f%% of wall time\n", 100.0 * total_compile_ms / total_wall_ms);
595603
printf("Train fraction: %.1f%% of wall time (useful work)\n", 100.0 * total_train_ms / total_wall_ms);
596604

training/train_large.m

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,8 @@ int main(int argc, char *argv[]) {
192192
float lr = 3e-4f;
193193
float adam_b1=0.9f, adam_b2=0.999f, adam_eps=1e-8f;
194194
int adam_t = 0, start_step = 0;
195+
int accum_steps = get_accum_steps();
196+
int max_compiles = get_max_compiles();
195197

196198
// Parse args
197199
const char *ckpt_path = CKPT_PATH_DEFAULT;
@@ -271,7 +273,7 @@ int main(int argc, char *argv[]) {
271273
printf("Params: %.2fM (transformer %.2fM + embed %.2fM)\n", tp/1e6, xfmr_params/1e6, embed_params/1e6);
272274
printf("Kernels: %d (%d weight-bearing + %d static sdpaBwd2)\n",
273275
TOTAL_WEIGHT_KERNELS+NLAYERS, TOTAL_WEIGHT_KERNELS, NLAYERS);
274-
printf("Accum %d steps per recompile | Adam LR=%.1e b1=%.1f b2=%.3f\n", ACCUM_STEPS, lr, adam_b1, adam_b2);
276+
printf("Accum %d steps per recompile | Adam LR=%.1e b1=%.1f b2=%.3f\n", accum_steps, lr, adam_b1, adam_b2);
275277
double fwd_f = NLAYERS*(4.0*2*DIM*DIM*SEQ + 2.0*2*DIM*HIDDEN*SEQ + 2.0*HIDDEN*DIM*SEQ);
276278
double bwd_dx_f = fwd_f, bwd_dw_f = fwd_f;
277279
double sdpa_f = NLAYERS*2.0*HEADS*5*SEQ*SEQ*HD;
@@ -361,7 +363,7 @@ int main(int argc, char *argv[]) {
361363
int step = start_step;
362364
while (step < total_steps) {
363365
// Check compile budget
364-
if (g_compile_count + TOTAL_WEIGHT_KERNELS > MAX_COMPILES) {
366+
if (g_compile_count + TOTAL_WEIGHT_KERNELS > max_compiles) {
365367
for (int L=0; L<NLAYERS; L++) { free_layer_kernels(&kern[L]); free_kern(sdpaBwd2[L]); }
366368
double wall = tb_ms(mach_absolute_time() - t_wall_start);
367369
save_checkpoint(ckpt_path, step, total_steps, lr, last_loss,
@@ -387,7 +389,7 @@ int main(int argc, char *argv[]) {
387389
compile_ok = false; break;
388390
}
389391
}
390-
if (!compile_ok) { g_compile_count = MAX_COMPILES; continue; }
392+
if (!compile_ok) { g_compile_count = max_compiles; continue; }
391393

392394
// Re-compile sdpaBwd2 if needed (after exec restart)
393395
for (int L=0; L<NLAYERS; L++) {
@@ -410,7 +412,7 @@ int main(int argc, char *argv[]) {
410412
uint64_t tt = mach_absolute_time();
411413
double t_ane=0,t_io=0,t_elem=0,t_rms=0,t_cblas_wait=0,t_cls=0;
412414

413-
for (int a=0; a<ACCUM_STEPS && step<total_steps; a++, step++) {
415+
for (int a=0; a<accum_steps && step<total_steps; a++, step++) {
414416
uint64_t t0,t1;
415417
// Sample random position in token data
416418
size_t max_pos = n_tokens - SEQ - 1;

0 commit comments

Comments
 (0)