@@ -208,9 +208,15 @@ static bool load_checkpoint(const char *path, CkptHeader *hdr,
208208 return true ;
209209}
210210
211- #define MAX_COMPILES 100
/*
 * Returns the compile budget for this process.
 *
 * The value is read from the ANE_MAX_COMPILES environment variable and is
 * compared in main() against g_compile_count + KERNELS_PER_STEP to decide
 * when to stop compiling new kernels.
 *
 * The variable must hold a plain positive decimal integer that fits in an
 * int. If it is unset, empty, non-numeric, has trailing characters, is
 * zero/negative, or is out of range, the built-in default of 100 is returned.
 * (atoi() was used previously; it cannot report errors and has undefined
 * behaviour on out-of-range input, so a typo silently became a budget of 0.)
 */
static inline int get_max_compiles_tiny(void) {
    enum { DEFAULT_MAX_COMPILES = 100 };

    const char *env = getenv("ANE_MAX_COMPILES");
    if (env == NULL || *env == '\0') {
        return DEFAULT_MAX_COMPILES;
    }

    char *end = NULL;
    const long parsed = strtol(env, &end, 10);
    /* Largest int value, computed without requiring <limits.h>. */
    const long int_max = (long)(~0u >> 1);

    /* end == env: no digits consumed; *end != 0: trailing junk such as "12abc". */
    if (end == env || *end != '\0' || parsed < 1 || parsed > int_max) {
        return DEFAULT_MAX_COMPILES;
    }
    return (int)parsed;
}
/*
 * Returns the number of gradient-accumulation steps per batch.
 *
 * The value is read from the ANE_ACCUM_STEPS environment variable. main()
 * uses it as the upper bound of its accumulation loops and as a divisor in
 * the efficiency report, so it must be at least 1: with a value of 0 those
 * loops would execute no iterations at all.
 *
 * The variable must hold a plain positive decimal integer that fits in an
 * int. If it is unset, empty, non-numeric, has trailing characters, is
 * zero/negative, or is out of range, the built-in default of 10 is returned.
 * (atoi() was used previously; it cannot report errors and has undefined
 * behaviour on out-of-range input.)
 */
static inline int get_accum_steps_tiny(void) {
    enum { DEFAULT_ACCUM_STEPS = 10 };

    const char *env = getenv("ANE_ACCUM_STEPS");
    if (env == NULL || *env == '\0') {
        return DEFAULT_ACCUM_STEPS;
    }

    char *end = NULL;
    const long parsed = strtol(env, &end, 10);
    /* Largest int value, computed without requiring <limits.h>. */
    const long int_max = (long)(~0u >> 1);

    /* end == env: no digits consumed; *end != 0: trailing junk such as "8x". */
    if (end == env || *end != '\0' || parsed < 1 || parsed > int_max) {
        return DEFAULT_ACCUM_STEPS;
    }
    return (int)parsed;
}
212219#define KERNELS_PER_STEP 4
213- #define ACCUM_STEPS 10
214220
215221// === Pipeline: background compile via GCD ===
216222typedef struct {
@@ -241,6 +247,8 @@ int main(int argc, char *argv[]) {
241247 float lr = 1 .0f ;
242248 int start_step = 0 ;
243249 bool resuming = false ;
250+ int accum_steps = get_accum_steps_tiny ();
251+ int max_compiles = get_max_compiles_tiny ();
244252
245253 float *W1 = (float *)malloc (H * D * sizeof (float ));
246254 float *W2 = (float *)malloc (D * H * sizeof (float ));
@@ -288,12 +296,12 @@ int main(int argc, char *argv[]) {
288296 for (int i = 0 ; i < D*H; i++) W2[i] = 0 .01f * cosf (i * 0 .9f + 1 .1f );
289297 printf (" === ANE Training: Pipeline Parallel + Grad Accumulation ===\n " );
290298 printf (" x:[%d ,%d ] -> W1:[%d ,%d ] -> ReLU -> W2:[%d ,%d ] -> y:[%d ,%d ]\n " , S,D, H,D, D,H, S,D);
291- printf (" Accum %d steps per recompile | Pipeline: compile overlaps ANE eval\n " , ACCUM_STEPS );
299+ printf (" Accum %d steps per recompile | Pipeline: compile overlaps ANE eval\n " , accum_steps );
292300 printf (" ANE FP16 peak: 15.8 TFLOPS (M4) | Weights: %.1f KB\n\n " , weight_bytes/1024.0 );
293301 printf (" FLOPs/step: ANE=%.0f (fwd+bwd) CPU=%.0f (dW) Total=%.0f \n " ,
294302 ane_flops_per_step, cpu_flops_per_step, total_flops_per_step);
295303 printf (" Steps: %d , LR: %.4f , exec() budget: %d compiles\n\n " ,
296- total_steps, lr, MAX_COMPILES );
304+ total_steps, lr, max_compiles );
297305 }
298306
299307 float *x = (float *)calloc (S * D, sizeof (float ));
@@ -342,7 +350,7 @@ int main(int argc, char *argv[]) {
342350 int step = start_step;
343351 while (step < total_steps) {
344352 // Check compile budget
345- if (g_compile_count + KERNELS_PER_STEP > MAX_COMPILES ) {
353+ if (g_compile_count + KERNELS_PER_STEP > max_compiles ) {
346354 free_kern (k1_fwd); free_kern (k2_fwd);
347355 free_kern (k1_bwd); free_kern (k2_bwd);
348356 save_checkpoint (CKPT_PATH, step, last_loss, D, H, S, total_steps, lr, W1, W2,
@@ -368,7 +376,7 @@ int main(int argc, char *argv[]) {
368376 // So we need to update weights BEFORE launching background compile
369377
370378 uint64_t t_batch = mach_absolute_time ();
371- for (int a = 0 ; a < ACCUM_STEPS && step < total_steps; a++, step++) {
379+ for (int a = 0 ; a < accum_steps && step < total_steps; a++, step++) {
372380 ane_eval_k (k1_fwd, x, h, D, H, S);
373381 for (int i = 0 ; i < S*H; i++) h_relu[i] = h[i] > 0 ? h[i] : 0 ;
374382 ane_eval_k (k2_fwd, h_relu, y, H, D, S);
@@ -422,7 +430,7 @@ int main(int argc, char *argv[]) {
422430 // Pipeline: launch background compile with updated weights,
423431 // then immediately start NEXT batch's ANE evals with OLD kernels
424432 // while compile runs concurrently on GCD queue
425- bool can_pipeline = (step < total_steps) && (g_compile_count + KERNELS_PER_STEP <= MAX_COMPILES );
433+ bool can_pipeline = (step < total_steps) && (g_compile_count + KERNELS_PER_STEP <= max_compiles );
426434
427435 if (can_pipeline) {
428436 // Snapshot weights for background compile
@@ -455,7 +463,7 @@ int main(int argc, char *argv[]) {
455463 int steps_overlap = 0 ;
456464 uint64_t t_overlap = mach_absolute_time ();
457465
458- for (int a = 0 ; a < ACCUM_STEPS && step < total_steps; a++, step++) {
466+ for (int a = 0 ; a < accum_steps && step < total_steps; a++, step++) {
459467 ane_eval_k (k1_fwd, x, h, D, H, S);
460468 for (int i = 0 ; i < S*H; i++) h_relu[i] = h[i] > 0 ? h[i] : 0 ;
461469 ane_eval_k (k2_fwd, h_relu, y, H, D, S);
@@ -562,7 +570,7 @@ int main(int argc, char *argv[]) {
562570 // === Efficiency Report ===
563571 printf (" \n === Efficiency Report ===\n " );
564572 printf (" Total steps: %d \n " , total_steps_done);
565- printf (" Total batches: %d (accum %d steps each)\n " , total_batches, ACCUM_STEPS );
573+ printf (" Total batches: %d (accum %d steps each)\n " , total_batches, accum_steps );
566574 printf (" Wall time: %.0f ms\n " , total_wall_ms);
567575 printf (" Compile time: %.0f ms (%.1f%% )\n " , total_compile_ms, 100.0 *total_compile_ms/total_wall_ms);
568576 printf (" Train time: %.0f ms (%.1f%% )\n " , total_train_ms, 100.0 *total_train_ms/total_wall_ms);
@@ -589,8 +597,8 @@ int main(int argc, char *argv[]) {
589597 printf (" Weight params: %d (%.1f KB FP16)\n " ,
590598 H*D + D*H, weight_bytes / 1024.0 );
591599 printf (" Compile amortization: %.1f ms compile / %d steps = %.2f ms/step overhead\n " ,
592- total_compile_ms / total_batches, ACCUM_STEPS ,
593- total_compile_ms / total_batches / ACCUM_STEPS );
600+ total_compile_ms / total_batches, accum_steps ,
601+ total_compile_ms / total_batches / accum_steps );
594602 printf (" Compile fraction: %.1f%% of wall time\n " , 100.0 * total_compile_ms / total_wall_ms);
595603 printf (" Train fraction: %.1f%% of wall time (useful work)\n " , 100.0 * total_train_ms / total_wall_ms);
596604
0 commit comments