**Fixed:**
* model definition
* sequence length (?)

**Variable:**
* batch size
* number of ranks
* parallelization strategy (Megatron, FSDP2, HSDP, DP, TP, PP once ready, CP needed?) — see the sweep sketch below
* (selective) activation checkpointing
* torch.compile
* PyTorch Flash Attention vs. Dao Flash Attention (need to check whether PyTorch internally dispatches to the Dao Flash Attention kernels anyway) — see the backend sketch below
* special kernels?
* CPU offloading (optional)
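
A minimal sketch of how the fixed vs. variable split could be encoded as a sweep; all class names, strategy strings, and values here are illustrative placeholders, not an agreed benchmark matrix.

```python
import itertools
from dataclasses import dataclass


@dataclass(frozen=True)
class FixedConfig:
    model_name: str   # fixed model definition (placeholder, not decided here)
    seq_len: int      # fixed (?) sequence length


@dataclass(frozen=True)
class RunConfig:
    batch_size: int
    num_ranks: int
    parallel_strategy: str   # "megatron" | "fsdp2" | "hsdp" | "dp" | "tp" | "pp" | "cp"
    activation_ckpt: str     # "none" | "selective" | "full"
    torch_compile: bool
    attention_impl: str      # "pytorch_sdpa" | "dao_flash_attn"
    cpu_offload: bool = False  # optional


FIXED = FixedConfig(model_name="<fixed-model>", seq_len=4096)  # placeholder values


def sweep():
    """Yield one RunConfig per combination of the variable dimensions."""
    for bs, ranks, strat, ckpt, compiled, attn in itertools.product(
        (1, 2, 4),                       # batch size
        (8, 16, 32),                     # number of ranks
        ("dp", "fsdp2", "hsdp", "megatron"),  # parallelization strategy
        ("none", "selective", "full"),   # activation checkpointing
        (False, True),                   # torch.compile
        ("pytorch_sdpa", "dao_flash_attn"),
    ):
        yield RunConfig(bs, ranks, strat, ckpt, compiled, attn)
```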
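
On the PyTorch vs. Dao Flash Attention question: a quick way to pin each implementation for an A/B comparison, assuming a recent PyTorch (>= 2.3 for `torch.nn.attention.sdpa_kernel`) and the external `flash-attn` package being installed. To my knowledge PyTorch's flash backend is its own port of the FlashAttention-2 kernels rather than a call into the external `flash-attn` package, but that is exactly what needs verifying.

```python
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel  # PyTorch >= 2.3


def sdpa_flash(q, k, v):
    # Inputs: (batch, heads, seq, head_dim). Forces SDPA onto its
    # flash-attention backend; raises if that backend cannot be used.
    with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
        return F.scaled_dot_product_attention(q, k, v, is_causal=True)


def dao_flash(q, k, v):
    # Inputs: (batch, seq, heads, head_dim) — note the different layout.
    # Requires the external `flash-attn` package (Dao et al.).
    from flash_attn import flash_attn_func
    return flash_attn_func(q, k, v, causal=True)
```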