diff --git a/submissions/external_tuning/muon/__init__.py b/submissions/external_tuning/muon/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/submissions/external_tuning/muon/pytorch/__init__.py b/submissions/external_tuning/muon/pytorch/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/submissions/external_tuning/muon/pytorch/docs/criteo1tb.txt b/submissions/external_tuning/muon/pytorch/docs/criteo1tb.txt new file mode 100644 index 00000000..4dddf2a5 --- /dev/null +++ b/submissions/external_tuning/muon/pytorch/docs/criteo1tb.txt @@ -0,0 +1,23 @@ +I1007 11:06:18.108439 23099690366016 utils.py:33] Muon params: + module.bot_mlp.0.weight (ndim=2) + module.bot_mlp.2.weight (ndim=2) + module.bot_mlp.4.weight (ndim=2) + module.top_mlp.0.weight (ndim=2) + module.top_mlp.2.weight (ndim=2) + module.top_mlp.4.weight (ndim=2) + module.top_mlp.6.weight (ndim=2) + module.top_mlp.9.weight (ndim=2) +I1007 11:06:18.108455 22916307100736 submission_runner.py:339] Initializing checkpoint and logger. +I1007 11:06:18.108510 23099690366016 utils.py:34] Adam params: + module.embedding_chunk_0 (ndim=2) + module.embedding_chunk_1 (ndim=2) + module.embedding_chunk_2 (ndim=2) + module.embedding_chunk_3 (ndim=2) + module.bot_mlp.0.bias (ndim=1) + module.bot_mlp.2.bias (ndim=1) + module.bot_mlp.4.bias (ndim=1) + module.top_mlp.0.bias (ndim=1) + module.top_mlp.2.bias (ndim=1) + module.top_mlp.4.bias (ndim=1) + module.top_mlp.6.bias (ndim=1) + module.top_mlp.9.bias (ndim=1) diff --git a/submissions/external_tuning/muon/pytorch/docs/fastmri.txt b/submissions/external_tuning/muon/pytorch/docs/fastmri.txt new file mode 100644 index 00000000..32c7bc2b --- /dev/null +++ b/submissions/external_tuning/muon/pytorch/docs/fastmri.txt @@ -0,0 +1,27 @@ +I1007 11:06:19.234473 22756046701632 utils.py:33] Muon params: + _orig_mod.module.down_sample_layers.0.conv_layers.0.weight (ndim=4) + _orig_mod.module.down_sample_layers.0.conv_layers.4.weight (ndim=4) + _orig_mod.module.down_sample_layers.1.conv_layers.0.weight (ndim=4) + _orig_mod.module.down_sample_layers.1.conv_layers.4.weight (ndim=4) + _orig_mod.module.down_sample_layers.2.conv_layers.0.weight (ndim=4) + _orig_mod.module.down_sample_layers.2.conv_layers.4.weight (ndim=4) + _orig_mod.module.down_sample_layers.3.conv_layers.0.weight (ndim=4) + _orig_mod.module.down_sample_layers.3.conv_layers.4.weight (ndim=4) + _orig_mod.module.conv.conv_layers.0.weight (ndim=4) + _orig_mod.module.conv.conv_layers.4.weight (ndim=4) + _orig_mod.module.up_conv.0.conv_layers.0.weight (ndim=4) + _orig_mod.module.up_conv.0.conv_layers.4.weight (ndim=4) + _orig_mod.module.up_conv.1.conv_layers.0.weight (ndim=4) + _orig_mod.module.up_conv.1.conv_layers.4.weight (ndim=4) + _orig_mod.module.up_conv.2.conv_layers.0.weight (ndim=4) + _orig_mod.module.up_conv.2.conv_layers.4.weight (ndim=4) + _orig_mod.module.up_conv.3.0.conv_layers.0.weight (ndim=4) + _orig_mod.module.up_conv.3.0.conv_layers.4.weight (ndim=4) + _orig_mod.module.up_conv.3.1.weight (ndim=4) + _orig_mod.module.up_transpose_conv.0.layers.0.weight (ndim=4) + _orig_mod.module.up_transpose_conv.1.layers.0.weight (ndim=4) + _orig_mod.module.up_transpose_conv.2.layers.0.weight (ndim=4) + _orig_mod.module.up_transpose_conv.3.layers.0.weight (ndim=4) +I1007 11:06:19.234533 23057129976896 utils.py:34] Adam params: + _orig_mod.module.up_conv.3.1.bias (ndim=1) +I1007 11:06:19.234585 22438544032832 submission_runner.py:339] Initializing checkpoint and logger. 
diff --git a/submissions/external_tuning/muon/pytorch/docs/finewebedu_lm.txt b/submissions/external_tuning/muon/pytorch/docs/finewebedu_lm.txt new file mode 100644 index 00000000..6635d99c --- /dev/null +++ b/submissions/external_tuning/muon/pytorch/docs/finewebedu_lm.txt @@ -0,0 +1,88 @@ +I0209 15:07:04.723858 22823061509952 utils.py:28] Adam params: + _orig_mod.module.layers.0.attn.attn_scale (ndim=0) + _orig_mod.module.layers.0.attn_norm.weight (ndim=1) + _orig_mod.module.layers.0.mlp_norm.weight (ndim=1) + _orig_mod.module.layers.1.attn.attn_scale (ndim=0) + _orig_mod.module.layers.1.attn_norm.weight (ndim=1) + _orig_mod.module.layers.1.mlp_norm.weight (ndim=1) + _orig_mod.module.layers.2.attn.attn_scale (ndim=0) + _orig_mod.module.layers.2.attn_norm.weight (ndim=1) + _orig_mod.module.layers.2.mlp_norm.weight (ndim=1) + _orig_mod.module.layers.3.attn.attn_scale (ndim=0) + _orig_mod.module.layers.3.attn_norm.weight (ndim=1) + _orig_mod.module.layers.3.mlp_norm.weight (ndim=1) + _orig_mod.module.layers.4.attn.attn_scale (ndim=0) + _orig_mod.module.layers.4.attn_norm.weight (ndim=1) + _orig_mod.module.layers.4.mlp_norm.weight (ndim=1) + _orig_mod.module.layers.5.attn.attn_scale (ndim=0) + _orig_mod.module.layers.5.attn_norm.weight (ndim=1) + _orig_mod.module.layers.5.mlp_norm.weight (ndim=1) + _orig_mod.module.layers.6.attn.attn_scale (ndim=0) + _orig_mod.module.layers.6.attn_norm.weight (ndim=1) + _orig_mod.module.layers.6.mlp_norm.weight (ndim=1) + _orig_mod.module.layers.7.attn.attn_scale (ndim=0) + _orig_mod.module.layers.7.attn_norm.weight (ndim=1) + _orig_mod.module.layers.7.mlp_norm.weight (ndim=1) + _orig_mod.module.layers.8.attn.attn_scale (ndim=0) + _orig_mod.module.layers.8.attn_norm.weight (ndim=1) + _orig_mod.module.layers.8.mlp_norm.weight (ndim=1) + _orig_mod.module.layers.9.attn.attn_scale (ndim=0) + _orig_mod.module.layers.9.attn_norm.weight (ndim=1) + _orig_mod.module.layers.9.mlp_norm.weight (ndim=1) + _orig_mod.module.layers.10.attn.attn_scale (ndim=0) + _orig_mod.module.layers.10.attn_norm.weight (ndim=1) + _orig_mod.module.layers.10.mlp_norm.weight (ndim=1) + _orig_mod.module.layers.11.attn.attn_scale (ndim=0) + _orig_mod.module.layers.11.attn_norm.weight (ndim=1) + _orig_mod.module.layers.11.mlp_norm.weight (ndim=1) + _orig_mod.module.out_norm.weight (ndim=1) +I0209 15:07:04.727331 23108874655552 utils.py:27] Muon params: + _orig_mod.module.embed_tokens.weight (ndim=2) + _orig_mod.module.layers.0.attn.w_qkv.weight (ndim=2) + _orig_mod.module.layers.0.attn.w_out.weight (ndim=2) + _orig_mod.module.layers.0.mlp.fc1.weight (ndim=2) + _orig_mod.module.layers.0.mlp.fc2.weight (ndim=2) + _orig_mod.module.layers.1.attn.w_qkv.weight (ndim=2) + _orig_mod.module.layers.1.attn.w_out.weight (ndim=2) + _orig_mod.module.layers.1.mlp.fc1.weight (ndim=2) + _orig_mod.module.layers.1.mlp.fc2.weight (ndim=2) + _orig_mod.module.layers.2.attn.w_qkv.weight (ndim=2) + _orig_mod.module.layers.2.attn.w_out.weight (ndim=2) + _orig_mod.module.layers.2.mlp.fc1.weight (ndim=2) + _orig_mod.module.layers.2.mlp.fc2.weight (ndim=2) + _orig_mod.module.layers.3.attn.w_qkv.weight (ndim=2) + _orig_mod.module.layers.3.attn.w_out.weight (ndim=2) + _orig_mod.module.layers.3.mlp.fc1.weight (ndim=2) + _orig_mod.module.layers.3.mlp.fc2.weight (ndim=2) + _orig_mod.module.layers.4.attn.w_qkv.weight (ndim=2) + _orig_mod.module.layers.4.attn.w_out.weight (ndim=2) + _orig_mod.module.layers.4.mlp.fc1.weight (ndim=2) + _orig_mod.module.layers.4.mlp.fc2.weight (ndim=2) + _orig_mod.module.layers.5.attn.w_qkv.weight 
(ndim=2) + _orig_mod.module.layers.5.attn.w_out.weight (ndim=2) + _orig_mod.module.layers.5.mlp.fc1.weight (ndim=2) + _orig_mod.module.layers.5.mlp.fc2.weight (ndim=2) + _orig_mod.module.layers.6.attn.w_qkv.weight (ndim=2) + _orig_mod.module.layers.6.attn.w_out.weight (ndim=2) + _orig_mod.module.layers.6.mlp.fc1.weight (ndim=2) + _orig_mod.module.layers.6.mlp.fc2.weight (ndim=2) + _orig_mod.module.layers.7.attn.w_qkv.weight (ndim=2) + _orig_mod.module.layers.7.attn.w_out.weight (ndim=2) + _orig_mod.module.layers.7.mlp.fc1.weight (ndim=2) + _orig_mod.module.layers.7.mlp.fc2.weight (ndim=2) + _orig_mod.module.layers.8.attn.w_qkv.weight (ndim=2) + _orig_mod.module.layers.8.attn.w_out.weight (ndim=2) + _orig_mod.module.layers.8.mlp.fc1.weight (ndim=2) + _orig_mod.module.layers.8.mlp.fc2.weight (ndim=2) + _orig_mod.module.layers.9.attn.w_qkv.weight (ndim=2) + _orig_mod.module.layers.9.attn.w_out.weight (ndim=2) + _orig_mod.module.layers.9.mlp.fc1.weight (ndim=2) + _orig_mod.module.layers.9.mlp.fc2.weight (ndim=2) + _orig_mod.module.layers.10.attn.w_qkv.weight (ndim=2) + _orig_mod.module.layers.10.attn.w_out.weight (ndim=2) + _orig_mod.module.layers.10.mlp.fc1.weight (ndim=2) + _orig_mod.module.layers.10.mlp.fc2.weight (ndim=2) + _orig_mod.module.layers.11.attn.w_qkv.weight (ndim=2) + _orig_mod.module.layers.11.attn.w_out.weight (ndim=2) + _orig_mod.module.layers.11.mlp.fc1.weight (ndim=2) + _orig_mod.module.layers.11.mlp.fc2.weight (ndim=2) diff --git a/submissions/external_tuning/muon/pytorch/docs/imagenet_resnet.txt b/submissions/external_tuning/muon/pytorch/docs/imagenet_resnet.txt new file mode 100644 index 00000000..534dc9e2 --- /dev/null +++ b/submissions/external_tuning/muon/pytorch/docs/imagenet_resnet.txt @@ -0,0 +1,64 @@ +I1007 11:06:28.469261 23133955605568 utils.py:33] Muon params: + _orig_mod.module.conv1.weight (ndim=4) + _orig_mod.module.layer1.0.conv1.weight (ndim=4) + _orig_mod.module.layer1.0.conv2.weight (ndim=4) + _orig_mod.module.layer1.0.conv3.weight (ndim=4) + _orig_mod.module.layer1.0.downsample.conv.weight (ndim=4) + _orig_mod.module.layer1.1.conv1.weight (ndim=4) + _orig_mod.module.layer1.1.conv2.weight (ndim=4) + _orig_mod.module.layer1.1.conv3.weight (ndim=4) + _orig_mod.module.layer1.2.conv1.weight (ndim=4) + _orig_mod.module.layer1.2.conv2.weight (ndim=4) + _orig_mod.module.layer1.2.conv3.weight (ndim=4) + _orig_mod.module.layer2.0.conv1.weight (ndim=4) + _orig_mod.module.layer2.0.conv2.weight (ndim=4) + _orig_mod.module.layer2.0.conv3.weight (ndim=4) + _orig_mod.module.layer2.0.downsample.conv.weight (ndim=4) + _orig_mod.module.layer2.1.conv1.weight (ndim=4) + _orig_mod.module.layer2.1.conv2.weight (ndim=4) + _orig_mod.module.layer2.1.conv3.weight (ndim=4) + _orig_mod.module.layer2.2.conv1.weight (ndim=4) + _orig_mod.module.layer2.2.conv2.weight (ndim=4) + _orig_mod.module.layer2.2.conv3.weight (ndim=4) + _orig_mod.module.layer2.3.conv1.weight (ndim=4) + _orig_mod.module.layer2.3.conv2.weight (ndim=4) + _orig_mod.module.layer2.3.conv3.weight (ndim=4) + _orig_mod.module.layer3.0.conv1.weight (ndim=4) + _orig_mod.module.layer3.0.conv2.weight (ndim=4) + _orig_mod.module.layer3.0.conv3.weight (ndim=4) + _orig_mod.module.layer3.0.downsample.conv.weight (ndim=4) + _orig_mod.module.layer3.1.conv1.weight (ndim=4) + _orig_mod.module.layer3.1.conv2.weight (ndim=4) + _orig_mod.module.layer3.1.conv3.weight (ndim=4) + _orig_mod.module.layer3.2.conv1.weight (ndim=4) + _orig_mod.module.layer3.2.conv2.weight (ndim=4) + _orig_mod.module.layer3.2.conv3.weight (ndim=4) + 
_orig_mod.module.layer3.3.conv1.weight (ndim=4) + _orig_mod.module.layer3.3.conv2.weight (ndim=4) + _orig_mod.module.layer3.3.conv3.weight (ndim=4) + _orig_mod.module.layer3.4.conv1.weight (ndim=4) + _orig_mod.module.layer3.4.conv2.weight (ndim=4) + _orig_mod.module.layer3.4.conv3.weight (ndim=4) + _orig_mod.module.layer3.5.conv1.weight (ndim=4) + _orig_mod.module.layer3.5.conv2.weight (ndim=4) + _orig_mod.module.layer3.5.conv3.weight (ndim=4) + _orig_mod.module.layer4.0.conv1.weight (ndim=4) + _orig_mod.module.layer4.0.conv2.weight (ndim=4) + _orig_mod.module.layer4.0.conv3.weight (ndim=4) + _orig_mod.module.layer4.0.downsample.conv.weight (ndim=4) + _orig_mod.module.layer4.1.conv1.weight (ndim=4) + _orig_mod.module.layer4.1.conv2.weight (ndim=4) + _orig_mod.module.layer4.1.conv3.weight (ndim=4) + _orig_mod.module.layer4.2.conv1.weight (ndim=4) + _orig_mod.module.layer4.2.conv2.weight (ndim=4) + _orig_mod.module.layer4.2.conv3.weight (ndim=4) + _orig_mod.module.fc.weight (ndim=2) +I1007 11:06:28.469334 23133955605568 utils.py:34] Adam params: + _orig_mod.module.bn1.weight (ndim=1) + _orig_mod.module.bn1.bias (ndim=1) + _orig_mod.module.layer1.0.bn1.weight (ndim=1) + _orig_mod.module.layer1.0.bn1.bias (ndim=1) + _orig_mod.module.layer1.0.bn2.weight (ndim=1) + _orig_mod.module.layer1.0.bn2.bias (ndim=1) + _orig_mod.module.layer1.0.bn3.weight (ndim=1) + _orig_mod.module.layer1.0.bn3.bias (ndim=1) \ No newline at end of file diff --git a/submissions/external_tuning/muon/pytorch/docs/imagenet_vit.txt b/submissions/external_tuning/muon/pytorch/docs/imagenet_vit.txt new file mode 100644 index 00000000..362800f6 --- /dev/null +++ b/submissions/external_tuning/muon/pytorch/docs/imagenet_vit.txt @@ -0,0 +1,144 @@ +I1007 11:06:25.785190 23124329362496 utils.py:33] Muon params: + module.pre_logits.weight (ndim=2) + module.conv_patch_extract.weight (ndim=4) + module.encoder.net.0.self_attention1.query.weight (ndim=2) + module.encoder.net.0.self_attention1.key.weight (ndim=2) + module.encoder.net.0.self_attention1.value.weight (ndim=2) + module.encoder.net.0.self_attention1.out.weight (ndim=2) + module.encoder.net.0.mlp3.linear1.weight (ndim=2) + module.encoder.net.0.mlp3.linear2.weight (ndim=2) + module.encoder.net.1.self_attention1.query.weight (ndim=2) + module.encoder.net.1.self_attention1.key.weight (ndim=2) + module.encoder.net.1.self_attention1.value.weight (ndim=2) + module.encoder.net.1.self_attention1.out.weight (ndim=2) + module.encoder.net.1.mlp3.linear1.weight (ndim=2) + module.encoder.net.1.mlp3.linear2.weight (ndim=2) + module.encoder.net.2.self_attention1.query.weight (ndim=2) + module.encoder.net.2.self_attention1.key.weight (ndim=2) + module.encoder.net.2.self_attention1.value.weight (ndim=2) + module.encoder.net.2.self_attention1.out.weight (ndim=2) + module.encoder.net.2.mlp3.linear1.weight (ndim=2) + module.encoder.net.2.mlp3.linear2.weight (ndim=2) + module.encoder.net.3.self_attention1.query.weight (ndim=2) + module.encoder.net.3.self_attention1.key.weight (ndim=2) + module.encoder.net.3.self_attention1.value.weight (ndim=2) + module.encoder.net.3.self_attention1.out.weight (ndim=2) + module.encoder.net.3.mlp3.linear1.weight (ndim=2) + module.encoder.net.3.mlp3.linear2.weight (ndim=2) + module.encoder.net.4.self_attention1.query.weight (ndim=2) + module.encoder.net.4.self_attention1.key.weight (ndim=2) + module.encoder.net.4.self_attention1.value.weight (ndim=2) + module.encoder.net.4.self_attention1.out.weight (ndim=2) + module.encoder.net.4.mlp3.linear1.weight (ndim=2) + 
module.encoder.net.4.mlp3.linear2.weight (ndim=2) + module.encoder.net.5.self_attention1.query.weight (ndim=2) + module.encoder.net.5.self_attention1.key.weight (ndim=2) + module.encoder.net.5.self_attention1.value.weight (ndim=2) + module.encoder.net.5.self_attention1.out.weight (ndim=2) + module.encoder.net.5.mlp3.linear1.weight (ndim=2) + module.encoder.net.5.mlp3.linear2.weight (ndim=2) + module.encoder.net.6.self_attention1.query.weight (ndim=2) + module.encoder.net.6.self_attention1.key.weight (ndim=2) + module.encoder.net.6.self_attention1.value.weight (ndim=2) + module.encoder.net.6.self_attention1.out.weight (ndim=2) + module.encoder.net.6.mlp3.linear1.weight (ndim=2) + module.encoder.net.6.mlp3.linear2.weight (ndim=2) + module.encoder.net.7.self_attention1.query.weight (ndim=2) + module.encoder.net.7.self_attention1.key.weight (ndim=2) + module.encoder.net.7.self_attention1.value.weight (ndim=2) + module.encoder.net.7.self_attention1.out.weight (ndim=2) + module.encoder.net.7.mlp3.linear1.weight (ndim=2) + module.encoder.net.7.mlp3.linear2.weight (ndim=2) + module.encoder.net.8.self_attention1.query.weight (ndim=2) + module.encoder.net.8.self_attention1.key.weight (ndim=2) + module.encoder.net.8.self_attention1.value.weight (ndim=2) + module.encoder.net.8.self_attention1.out.weight (ndim=2) + module.encoder.net.8.mlp3.linear1.weight (ndim=2) + module.encoder.net.8.mlp3.linear2.weight (ndim=2) + module.encoder.net.9.self_attention1.query.weight (ndim=2) + module.encoder.net.9.self_attention1.key.weight (ndim=2) + module.encoder.net.9.self_attention1.value.weight (ndim=2) + module.encoder.net.9.self_attention1.out.weight (ndim=2) + module.encoder.net.9.mlp3.linear1.weight (ndim=2) + module.encoder.net.9.mlp3.linear2.weight (ndim=2) + module.encoder.net.10.self_attention1.query.weight (ndim=2) + module.encoder.net.10.self_attention1.key.weight (ndim=2) + module.encoder.net.10.self_attention1.value.weight (ndim=2) + module.encoder.net.10.self_attention1.out.weight (ndim=2) + module.encoder.net.10.mlp3.linear1.weight (ndim=2) + module.encoder.net.10.mlp3.linear2.weight (ndim=2) + module.encoder.net.11.self_attention1.query.weight (ndim=2) + module.encoder.net.11.self_attention1.key.weight (ndim=2) + module.encoder.net.11.self_attention1.value.weight (ndim=2) + module.encoder.net.11.self_attention1.out.weight (ndim=2) + module.encoder.net.11.mlp3.linear1.weight (ndim=2) + module.encoder.net.11.mlp3.linear2.weight (ndim=2) + module.head.weight (ndim=2) +I1007 11:06:25.785268 23124329362496 utils.py:34] Adam params: + module.pre_logits.bias (ndim=1) + module.conv_patch_extract.bias (ndim=1) + module.encoder.net.0.layer_norm0.weight (ndim=1) + module.encoder.net.0.layer_norm0.bias (ndim=1) + module.encoder.net.0.self_attention1.query.bias (ndim=1) + module.encoder.net.0.self_attention1.key.bias (ndim=1) + module.encoder.net.0.self_attention1.value.bias (ndim=1) + module.encoder.net.0.self_attention1.out.bias (ndim=1) + module.encoder.net.0.layer_norm2.weight (ndim=1) + module.encoder.net.0.layer_norm2.bias (ndim=1) + module.encoder.net.0.mlp3.linear1.bias (ndim=1) + module.encoder.net.0.mlp3.linear2.bias (ndim=1) + module.encoder.net.1.layer_norm0.weight (ndim=1) + module.encoder.net.1.layer_norm0.bias (ndim=1) + module.encoder.net.1.self_attention1.query.bias (ndim=1) + module.encoder.net.1.self_attention1.key.bias (ndim=1) + module.encoder.net.1.self_attention1.value.bias (ndim=1) + module.encoder.net.1.self_attention1.out.bias (ndim=1) + module.encoder.net.1.layer_norm2.weight (ndim=1) 
+ module.encoder.net.1.layer_norm2.bias (ndim=1) + module.encoder.net.1.mlp3.linear1.bias (ndim=1) + module.encoder.net.1.mlp3.linear2.bias (ndim=1) + module.encoder.net.2.layer_norm0.weight (ndim=1) + module.encoder.net.2.layer_norm0.bias (ndim=1) + module.encoder.net.2.self_attention1.query.bias (ndim=1) + module.encoder.net.2.self_attention1.key.bias (ndim=1) + module.encoder.net.2.self_attention1.value.bias (ndim=1) + module.encoder.net.2.self_attention1.out.bias (ndim=1) + module.encoder.net.2.layer_norm2.weight (ndim=1) + module.encoder.net.2.layer_norm2.bias (ndim=1) + module.encoder.net.2.mlp3.linear1.bias (ndim=1) + module.encoder.net.2.mlp3.linear2.bias (ndim=1) + module.encoder.net.3.layer_norm0.weight (ndim=1) + module.encoder.net.3.layer_norm0.bias (ndim=1) + module.encoder.net.3.self_attention1.query.bias (ndim=1) + module.encoder.net.3.self_attention1.key.bias (ndim=1) + module.encoder.net.3.self_attention1.value.bias (ndim=1) + module.encoder.net.3.self_attention1.out.bias (ndim=1) + module.encoder.net.3.layer_norm2.weight (ndim=1) + module.encoder.net.3.layer_norm2.bias (ndim=1) + module.encoder.net.3.mlp3.linear1.bias (ndim=1) + module.encoder.net.3.mlp3.linear2.bias (ndim=1) + module.encoder.net.4.layer_norm0.weight (ndim=1) + module.encoder.net.4.layer_norm0.bias (ndim=1) + module.encoder.net.4.self_attention1.query.bias (ndim=1) + module.encoder.net.4.self_attention1.key.bias (ndim=1) + module.encoder.net.4.self_attention1.value.bias (ndim=1) + module.encoder.net.4.self_attention1.out.bias (ndim=1) + module.encoder.net.4.layer_norm2.weight (ndim=1) + module.encoder.net.4.layer_norm2.bias (ndim=1) + module.encoder.net.4.mlp3.linear1.bias (ndim=1) + module.encoder.net.4.mlp3.linear2.bias (ndim=1) + module.encoder.net.5.layer_norm0.weight (ndim=1) + module.encoder.net.5.layer_norm0.bias (ndim=1) + module.encoder.net.5.self_attention1.query.bias (ndim=1) + module.encoder.net.5.self_attention1.key.bias (ndim=1) + module.encoder.net.5.self_attention1.value.bias (ndim=1) + module.encoder.net.5.self_attention1.out.bias (ndim=1) + module.encoder.net.5.layer_norm2.weight (ndim=1) + module.encoder.net.5.layer_norm2.bias (ndim=1) + module.encoder.net.5.mlp3.linear1.bias (ndim=1) + module.encoder.net.5.mlp3.linear2.bias (ndim=1) + module.encoder.net.6.layer_norm0.weight (ndim=1) + module.encoder.net.6.layer_norm0.bias (ndim=1) + module.encoder.net.6.self_attention1.query.bias (ndim=1) + module.encoder.net.6.self_attention1.key.bias (ndim=1) + module.encoder.net.6.self_attention1.value.bias (ndim=1) diff --git a/submissions/external_tuning/muon/pytorch/docs/librispeech_conformer.txt b/submissions/external_tuning/muon/pytorch/docs/librispeech_conformer.txt new file mode 100644 index 00000000..c4cfa19b --- /dev/null +++ b/submissions/external_tuning/muon/pytorch/docs/librispeech_conformer.txt @@ -0,0 +1,140 @@ +I1007 11:06:14.662850 22646459978816 utils.py:33] Muon params: + module.subsample.conv1.kernel (ndim=4) + module.subsample.conv2.kernel (ndim=4) + module.subsample.linear.weight (ndim=2) + module.conformers.0.ff1.linear1.weight (ndim=2) + module.conformers.0.ff1.linear2.weight (ndim=2) + module.conformers.0.mhsa.self_attention.in_proj.weight (ndim=2) + module.conformers.0.mhsa.self_attention.out_proj.weight (ndim=2) + module.conformers.0.conv.lin1.weight (ndim=2) + module.conformers.0.conv.lin2.weight (ndim=2) + module.conformers.0.conv.conv1.weight (ndim=3) + module.conformers.0.conv.lin3.weight (ndim=2) + module.conformers.0.ff2.linear1.weight (ndim=2) + 
module.conformers.0.ff2.linear2.weight (ndim=2) + module.conformers.1.ff1.linear1.weight (ndim=2) + module.conformers.1.ff1.linear2.weight (ndim=2) + module.conformers.1.mhsa.self_attention.in_proj.weight (ndim=2) + module.conformers.1.mhsa.self_attention.out_proj.weight (ndim=2) + module.conformers.1.conv.lin1.weight (ndim=2) + module.conformers.1.conv.lin2.weight (ndim=2) + module.conformers.1.conv.conv1.weight (ndim=3) + module.conformers.1.conv.lin3.weight (ndim=2) + module.conformers.1.ff2.linear1.weight (ndim=2) + module.conformers.1.ff2.linear2.weight (ndim=2) + module.conformers.2.ff1.linear1.weight (ndim=2) + module.conformers.2.ff1.linear2.weight (ndim=2) + module.conformers.2.mhsa.self_attention.in_proj.weight (ndim=2) + module.conformers.2.mhsa.self_attention.out_proj.weight (ndim=2) + module.conformers.2.conv.lin1.weight (ndim=2) + module.conformers.2.conv.lin2.weight (ndim=2) + module.conformers.2.conv.conv1.weight (ndim=3) + module.conformers.2.conv.lin3.weight (ndim=2) + module.conformers.2.ff2.linear1.weight (ndim=2) + module.conformers.2.ff2.linear2.weight (ndim=2) + module.conformers.3.ff1.linear1.weight (ndim=2) + module.conformers.3.ff1.linear2.weight (ndim=2) + module.conformers.3.mhsa.self_attention.in_proj.weight (ndim=2) + module.conformers.3.mhsa.self_attention.out_proj.weight (ndim=2) + module.conformers.3.conv.lin1.weight (ndim=2) + module.conformers.3.conv.lin2.weight (ndim=2) + module.conformers.3.conv.conv1.weight (ndim=3) + module.conformers.3.conv.lin3.weight (ndim=2) + module.conformers.3.ff2.linear1.weight (ndim=2) + module.conformers.3.ff2.linear2.weight (ndim=2) + module.lin.weight (ndim=2) +I1007 11:06:14.662934 22646459978816 utils.py:34] Adam params: + module.subsample.conv1.bias (ndim=1) + module.subsample.conv2.bias (ndim=1) + module.subsample.linear.bias (ndim=1) + module.conformers.0.ff1.ln.scale (ndim=1) + module.conformers.0.ff1.ln.bias (ndim=1) + module.conformers.0.ff1.linear1.bias (ndim=1) + module.conformers.0.ff1.linear2.bias (ndim=1) + module.conformers.0.mhsa.ln.scale (ndim=1) + module.conformers.0.mhsa.ln.bias (ndim=1) + module.conformers.0.mhsa.self_attention.in_proj.bias (ndim=1) + module.conformers.0.mhsa.self_attention.out_proj.bias (ndim=1) + module.conformers.0.mhsa.self_attention.qs.scale (ndim=1) + module.conformers.0.conv.ln.scale (ndim=1) + module.conformers.0.conv.ln.bias (ndim=1) + module.conformers.0.conv.lin1.bias (ndim=1) + module.conformers.0.conv.lin2.bias (ndim=1) + module.conformers.0.conv.bn.scale (ndim=1) + module.conformers.0.conv.bn.bias (ndim=1) + module.conformers.0.conv.lin3.bias (ndim=1) + module.conformers.0.ff2.ln.scale (ndim=1) + module.conformers.0.ff2.ln.bias (ndim=1) + module.conformers.0.ff2.linear1.bias (ndim=1) + module.conformers.0.ff2.linear2.bias (ndim=1) + module.conformers.0.ln.scale (ndim=1) + module.conformers.0.ln.bias (ndim=1) + module.conformers.1.ff1.ln.scale (ndim=1) + module.conformers.1.ff1.ln.bias (ndim=1) + module.conformers.1.ff1.linear1.bias (ndim=1) + module.conformers.1.ff1.linear2.bias (ndim=1) + module.conformers.1.mhsa.ln.scale (ndim=1) + module.conformers.1.mhsa.ln.bias (ndim=1) + module.conformers.1.mhsa.self_attention.in_proj.bias (ndim=1) + module.conformers.1.mhsa.self_attention.out_proj.bias (ndim=1) + module.conformers.1.mhsa.self_attention.qs.scale (ndim=1) + module.conformers.1.conv.ln.scale (ndim=1) + module.conformers.1.conv.ln.bias (ndim=1) + module.conformers.1.conv.lin1.bias (ndim=1) + module.conformers.1.conv.lin2.bias (ndim=1) + module.conformers.1.conv.bn.scale 
(ndim=1) + module.conformers.1.conv.bn.bias (ndim=1) + module.conformers.1.conv.lin3.bias (ndim=1) + module.conformers.1.ff2.ln.scale (ndim=1) + module.conformers.1.ff2.ln.bias (ndim=1) + module.conformers.1.ff2.linear1.bias (ndim=1) + module.conformers.1.ff2.linear2.bias (ndim=1) + module.conformers.1.ln.scale (ndim=1) + module.conformers.1.ln.bias (ndim=1) + module.conformers.2.ff1.ln.scale (ndim=1) + module.conformers.2.ff1.ln.bias (ndim=1) + module.conformers.2.ff1.linear1.bias (ndim=1) + module.conformers.2.ff1.linear2.bias (ndim=1) + module.conformers.2.mhsa.ln.scale (ndim=1) + module.conformers.2.mhsa.ln.bias (ndim=1) + module.conformers.2.mhsa.self_attention.in_proj.bias (ndim=1) + module.conformers.2.mhsa.self_attention.out_proj.bias (ndim=1) + module.conformers.2.mhsa.self_attention.qs.scale (ndim=1) + module.conformers.2.conv.ln.scale (ndim=1) + module.conformers.2.conv.ln.bias (ndim=1) + module.conformers.2.conv.lin1.bias (ndim=1) + module.conformers.2.conv.lin2.bias (ndim=1) + module.conformers.2.conv.bn.scale (ndim=1) + module.conformers.2.conv.bn.bias (ndim=1) + module.conformers.2.conv.lin3.bias (ndim=1) + module.conformers.2.ff2.ln.scale (ndim=1) + module.conformers.2.ff2.ln.bias (ndim=1) + module.conformers.2.ff2.linear1.bias (ndim=1) + module.conformers.2.ff2.linear2.bias (ndim=1) + module.conformers.2.ln.scale (ndim=1) + module.conformers.2.ln.bias (ndim=1) + module.conformers.3.ff1.ln.scale (ndim=1) + module.conformers.3.ff1.ln.bias (ndim=1) + module.conformers.3.ff1.linear1.bias (ndim=1) + module.conformers.3.ff1.linear2.bias (ndim=1) + module.conformers.3.mhsa.ln.scale (ndim=1) + module.conformers.3.mhsa.ln.bias (ndim=1) + module.conformers.3.mhsa.self_attention.in_proj.bias (ndim=1) + module.conformers.3.mhsa.self_attention.out_proj.bias (ndim=1) + module.conformers.3.mhsa.self_attention.qs.scale (ndim=1) + module.conformers.3.conv.ln.scale (ndim=1) + module.conformers.3.conv.ln.bias (ndim=1) + module.conformers.3.conv.lin1.bias (ndim=1) + module.conformers.3.conv.lin2.bias (ndim=1) + module.conformers.3.conv.bn.scale (ndim=1) + module.conformers.3.conv.bn.bias (ndim=1) + module.conformers.3.conv.lin3.bias (ndim=1) + module.conformers.3.ff2.ln.scale (ndim=1) + module.conformers.3.ff2.ln.bias (ndim=1) + module.conformers.3.ff2.linear1.bias (ndim=1) + module.conformers.3.ff2.linear2.bias (ndim=1) + module.conformers.3.ln.scale (ndim=1) + module.conformers.3.ln.bias (ndim=1) + module.ln.scale (ndim=1) + module.ln.bias (ndim=1) + module.lin.bias (ndim=1) \ No newline at end of file diff --git a/submissions/external_tuning/muon/pytorch/docs/librispeech_deepspeech.txt b/submissions/external_tuning/muon/pytorch/docs/librispeech_deepspeech.txt new file mode 100644 index 00000000..170471c7 --- /dev/null +++ b/submissions/external_tuning/muon/pytorch/docs/librispeech_deepspeech.txt @@ -0,0 +1,84 @@ +I1007 11:12:13.333996 23141779166272 utils.py:33] Muon params: + module.subsample.conv1.kernel (ndim=4) + module.subsample.conv2.kernel (ndim=4) + module.subsample.lin.weight (ndim=2) + module.lstms.0.lstm.weight_ih_l0 (ndim=2) + module.lstms.0.lstm.weight_hh_l0 (ndim=2) + module.lstms.0.lstm.weight_ih_l0_reverse (ndim=2) + module.lstms.0.lstm.weight_hh_l0_reverse (ndim=2) + module.lstms.1.lstm.weight_ih_l0 (ndim=2) + module.lstms.1.lstm.weight_hh_l0 (ndim=2) + module.lstms.1.lstm.weight_ih_l0_reverse (ndim=2) + module.lstms.1.lstm.weight_hh_l0_reverse (ndim=2) + module.lstms.2.lstm.weight_ih_l0 (ndim=2) + module.lstms.2.lstm.weight_hh_l0 (ndim=2) + 
module.lstms.2.lstm.weight_ih_l0_reverse (ndim=2) + module.lstms.2.lstm.weight_hh_l0_reverse (ndim=2) + module.lstms.3.lstm.weight_ih_l0 (ndim=2) + module.lstms.3.lstm.weight_hh_l0 (ndim=2) + module.lstms.3.lstm.weight_ih_l0_reverse (ndim=2) + module.lstms.3.lstm.weight_hh_l0_reverse (ndim=2) + module.lstms.4.lstm.weight_ih_l0 (ndim=2) + module.lstms.4.lstm.weight_hh_l0 (ndim=2) + module.lstms.4.lstm.weight_ih_l0_reverse (ndim=2) + module.lstms.4.lstm.weight_hh_l0_reverse (ndim=2) + module.lstms.5.lstm.weight_ih_l0 (ndim=2) + module.lstms.5.lstm.weight_hh_l0 (ndim=2) + module.lstms.5.lstm.weight_ih_l0_reverse (ndim=2) + module.lstms.5.lstm.weight_hh_l0_reverse (ndim=2) + module.ffns.0.lin.weight (ndim=2) + module.ffns.1.lin.weight (ndim=2) + module.ffns.2.lin.weight (ndim=2) + module.lin.weight (ndim=2) +I1007 11:12:13.334092 23141779166272 utils.py:34] Adam params: + module.subsample.conv1.bias (ndim=1) + module.subsample.conv2.bias (ndim=1) + module.subsample.lin.bias (ndim=1) + module.lstms.0.bn_normalization_layer.weight (ndim=1) + module.lstms.0.bn_normalization_layer.bias (ndim=1) + module.lstms.0.lstm.bias_ih_l0 (ndim=1) + module.lstms.0.lstm.bias_hh_l0 (ndim=1) + module.lstms.0.lstm.bias_ih_l0_reverse (ndim=1) + module.lstms.0.lstm.bias_hh_l0_reverse (ndim=1) + module.lstms.1.bn_normalization_layer.weight (ndim=1) + module.lstms.1.bn_normalization_layer.bias (ndim=1) + module.lstms.1.lstm.bias_ih_l0 (ndim=1) + module.lstms.1.lstm.bias_hh_l0 (ndim=1) + module.lstms.1.lstm.bias_ih_l0_reverse (ndim=1) + module.lstms.1.lstm.bias_hh_l0_reverse (ndim=1) + module.lstms.2.bn_normalization_layer.weight (ndim=1) + module.lstms.2.bn_normalization_layer.bias (ndim=1) + module.lstms.2.lstm.bias_ih_l0 (ndim=1) + module.lstms.2.lstm.bias_hh_l0 (ndim=1) + module.lstms.2.lstm.bias_ih_l0_reverse (ndim=1) + module.lstms.2.lstm.bias_hh_l0_reverse (ndim=1) + module.lstms.3.bn_normalization_layer.weight (ndim=1) + module.lstms.3.bn_normalization_layer.bias (ndim=1) + module.lstms.3.lstm.bias_ih_l0 (ndim=1) + module.lstms.3.lstm.bias_hh_l0 (ndim=1) + module.lstms.3.lstm.bias_ih_l0_reverse (ndim=1) + module.lstms.3.lstm.bias_hh_l0_reverse (ndim=1) + module.lstms.4.bn_normalization_layer.weight (ndim=1) + module.lstms.4.bn_normalization_layer.bias (ndim=1) + module.lstms.4.lstm.bias_ih_l0 (ndim=1) + module.lstms.4.lstm.bias_hh_l0 (ndim=1) + module.lstms.4.lstm.bias_ih_l0_reverse (ndim=1) + module.lstms.4.lstm.bias_hh_l0_reverse (ndim=1) + module.lstms.5.bn_normalization_layer.weight (ndim=1) + module.lstms.5.bn_normalization_layer.bias (ndim=1) + module.lstms.5.lstm.bias_ih_l0 (ndim=1) + module.lstms.5.lstm.bias_hh_l0 (ndim=1) + module.lstms.5.lstm.bias_ih_l0_reverse (ndim=1) + module.lstms.5.lstm.bias_hh_l0_reverse (ndim=1) + module.ffns.0.bn_normalization_layer.weight (ndim=1) + module.ffns.0.bn_normalization_layer.bias (ndim=1) + module.ffns.0.lin.bias (ndim=1) + module.ffns.1.bn_normalization_layer.weight (ndim=1) + module.ffns.1.bn_normalization_layer.bias (ndim=1) + module.ffns.1.lin.bias (ndim=1) + module.ffns.2.bn_normalization_layer.weight (ndim=1) + module.ffns.2.bn_normalization_layer.bias (ndim=1) + module.ffns.2.lin.bias (ndim=1) + module.ln.scale (ndim=1) + module.ln.bias (ndim=1) + module.lin.bias (ndim=1) \ No newline at end of file diff --git a/submissions/external_tuning/muon/pytorch/docs/ogbg.txt b/submissions/external_tuning/muon/pytorch/docs/ogbg.txt new file mode 100644 index 00000000..cad5a81e --- /dev/null +++ b/submissions/external_tuning/muon/pytorch/docs/ogbg.txt @@ -0,0 +1,68 @@ 
+I1007 11:06:19.523296 22384967775296 utils.py:33] Muon params: + module.node_embedder.weight (ndim=2) + module.edge_embedder.weight (ndim=2) + module.graph_network.0.update_edge_fn.dense_0.weight (ndim=2) + module.graph_network.0.update_node_fn.dense_0.weight (ndim=2) + module.graph_network.0.update_global_fn.dense_0.weight (ndim=2) + module.graph_network.1.update_edge_fn.dense_0.weight (ndim=2) + module.graph_network.1.update_node_fn.dense_0.weight (ndim=2) + module.graph_network.1.update_global_fn.dense_0.weight (ndim=2) + module.graph_network.2.update_edge_fn.dense_0.weight (ndim=2) + module.graph_network.2.update_node_fn.dense_0.weight (ndim=2) + module.graph_network.2.update_global_fn.dense_0.weight (ndim=2) + module.graph_network.3.update_edge_fn.dense_0.weight (ndim=2) + module.graph_network.3.update_node_fn.dense_0.weight (ndim=2) + module.graph_network.3.update_global_fn.dense_0.weight (ndim=2) + module.graph_network.4.update_edge_fn.dense_0.weight (ndim=2) + module.graph_network.4.update_node_fn.dense_0.weight (ndim=2) + module.graph_network.4.update_global_fn.dense_0.weight (ndim=2) + module.decoder.weight (ndim=2) +I1007 11:06:19.523386 22384967775296 utils.py:34] Adam params: + module.node_embedder.bias (ndim=1) + module.edge_embedder.bias (ndim=1) + module.graph_network.0.update_edge_fn.dense_0.bias (ndim=1) + module.graph_network.0.update_edge_fn.norm_0.weight (ndim=1) + module.graph_network.0.update_edge_fn.norm_0.bias (ndim=1) + module.graph_network.0.update_node_fn.dense_0.bias (ndim=1) + module.graph_network.0.update_node_fn.norm_0.weight (ndim=1) + module.graph_network.0.update_node_fn.norm_0.bias (ndim=1) + module.graph_network.0.update_global_fn.dense_0.bias (ndim=1) + module.graph_network.0.update_global_fn.norm_0.weight (ndim=1) + module.graph_network.0.update_global_fn.norm_0.bias (ndim=1) + module.graph_network.1.update_edge_fn.dense_0.bias (ndim=1) + module.graph_network.1.update_edge_fn.norm_0.weight (ndim=1) + module.graph_network.1.update_edge_fn.norm_0.bias (ndim=1) + module.graph_network.1.update_node_fn.dense_0.bias (ndim=1) + module.graph_network.1.update_node_fn.norm_0.weight (ndim=1) + module.graph_network.1.update_node_fn.norm_0.bias (ndim=1) + module.graph_network.1.update_global_fn.dense_0.bias (ndim=1) + module.graph_network.1.update_global_fn.norm_0.weight (ndim=1) + module.graph_network.1.update_global_fn.norm_0.bias (ndim=1) + module.graph_network.2.update_edge_fn.dense_0.bias (ndim=1) + module.graph_network.2.update_edge_fn.norm_0.weight (ndim=1) + module.graph_network.2.update_edge_fn.norm_0.bias (ndim=1) + module.graph_network.2.update_node_fn.dense_0.bias (ndim=1) + module.graph_network.2.update_node_fn.norm_0.weight (ndim=1) + module.graph_network.2.update_node_fn.norm_0.bias (ndim=1) + module.graph_network.2.update_global_fn.dense_0.bias (ndim=1) + module.graph_network.2.update_global_fn.norm_0.weight (ndim=1) + module.graph_network.2.update_global_fn.norm_0.bias (ndim=1) + module.graph_network.3.update_edge_fn.dense_0.bias (ndim=1) + module.graph_network.3.update_edge_fn.norm_0.weight (ndim=1) + module.graph_network.3.update_edge_fn.norm_0.bias (ndim=1) + module.graph_network.3.update_node_fn.dense_0.bias (ndim=1) + module.graph_network.3.update_node_fn.norm_0.weight (ndim=1) + module.graph_network.3.update_node_fn.norm_0.bias (ndim=1) + module.graph_network.3.update_global_fn.dense_0.bias (ndim=1) + module.graph_network.3.update_global_fn.norm_0.weight (ndim=1) + module.graph_network.3.update_global_fn.norm_0.bias (ndim=1) + 
module.graph_network.4.update_edge_fn.dense_0.bias (ndim=1) + module.graph_network.4.update_edge_fn.norm_0.weight (ndim=1) + module.graph_network.4.update_edge_fn.norm_0.bias (ndim=1) + module.graph_network.4.update_node_fn.dense_0.bias (ndim=1) + module.graph_network.4.update_node_fn.norm_0.weight (ndim=1) + module.graph_network.4.update_node_fn.norm_0.bias (ndim=1) + module.graph_network.4.update_global_fn.dense_0.bias (ndim=1) + module.graph_network.4.update_global_fn.norm_0.weight (ndim=1) + module.graph_network.4.update_global_fn.norm_0.bias (ndim=1) + module.decoder.bias (ndim=1) \ No newline at end of file diff --git a/submissions/external_tuning/muon/pytorch/docs/wmt.txt b/submissions/external_tuning/muon/pytorch/docs/wmt.txt new file mode 100644 index 00000000..da9a9fb6 --- /dev/null +++ b/submissions/external_tuning/muon/pytorch/docs/wmt.txt @@ -0,0 +1,157 @@ +I1007 11:06:21.233899 22787679364160 utils.py:33] Muon params: + _orig_mod.module.encoder.encoder.layers.0.self_attn.in_proj.weight (ndim=2) + _orig_mod.module.encoder.encoder.layers.0.self_attn.out_proj.weight (ndim=2) + _orig_mod.module.encoder.encoder.layers.0.linear1.weight (ndim=2) + _orig_mod.module.encoder.encoder.layers.0.linear2.weight (ndim=2) + _orig_mod.module.encoder.encoder.layers.1.self_attn.in_proj.weight (ndim=2) + _orig_mod.module.encoder.encoder.layers.1.self_attn.out_proj.weight (ndim=2) + _orig_mod.module.encoder.encoder.layers.1.linear1.weight (ndim=2) + _orig_mod.module.encoder.encoder.layers.1.linear2.weight (ndim=2) + _orig_mod.module.encoder.encoder.layers.2.self_attn.in_proj.weight (ndim=2) + _orig_mod.module.encoder.encoder.layers.2.self_attn.out_proj.weight (ndim=2) + _orig_mod.module.encoder.encoder.layers.2.linear1.weight (ndim=2) + _orig_mod.module.encoder.encoder.layers.2.linear2.weight (ndim=2) + _orig_mod.module.encoder.encoder.layers.3.self_attn.in_proj.weight (ndim=2) + _orig_mod.module.encoder.encoder.layers.3.self_attn.out_proj.weight (ndim=2) + _orig_mod.module.encoder.encoder.layers.3.linear1.weight (ndim=2) + _orig_mod.module.encoder.encoder.layers.3.linear2.weight (ndim=2) + _orig_mod.module.encoder.encoder.layers.4.self_attn.in_proj.weight (ndim=2) + _orig_mod.module.encoder.encoder.layers.4.self_attn.out_proj.weight (ndim=2) + _orig_mod.module.encoder.encoder.layers.4.linear1.weight (ndim=2) + _orig_mod.module.encoder.encoder.layers.4.linear2.weight (ndim=2) + _orig_mod.module.encoder.encoder.layers.5.self_attn.in_proj.weight (ndim=2) + _orig_mod.module.encoder.encoder.layers.5.self_attn.out_proj.weight (ndim=2) + _orig_mod.module.encoder.encoder.layers.5.linear1.weight (ndim=2) + _orig_mod.module.encoder.encoder.layers.5.linear2.weight (ndim=2) + _orig_mod.module.decoder.decoder.layers.0.self_attn.in_proj.weight (ndim=2) + _orig_mod.module.decoder.decoder.layers.0.self_attn.out_proj.weight (ndim=2) + _orig_mod.module.decoder.decoder.layers.0.multihead_attn.q_proj.weight (ndim=2) + _orig_mod.module.decoder.decoder.layers.0.multihead_attn.kv_proj.weight (ndim=2) + _orig_mod.module.decoder.decoder.layers.0.multihead_attn.out_proj.weight (ndim=2) + _orig_mod.module.decoder.decoder.layers.0.linear1.weight (ndim=2) + _orig_mod.module.decoder.decoder.layers.0.linear2.weight (ndim=2) + _orig_mod.module.decoder.decoder.layers.1.self_attn.in_proj.weight (ndim=2) + _orig_mod.module.decoder.decoder.layers.1.self_attn.out_proj.weight (ndim=2) + _orig_mod.module.decoder.decoder.layers.1.multihead_attn.q_proj.weight (ndim=2) + 
_orig_mod.module.decoder.decoder.layers.1.multihead_attn.kv_proj.weight (ndim=2) + _orig_mod.module.decoder.decoder.layers.1.multihead_attn.out_proj.weight (ndim=2) + _orig_mod.module.decoder.decoder.layers.1.linear1.weight (ndim=2) + _orig_mod.module.decoder.decoder.layers.1.linear2.weight (ndim=2) + _orig_mod.module.decoder.decoder.layers.2.self_attn.in_proj.weight (ndim=2) + _orig_mod.module.decoder.decoder.layers.2.self_attn.out_proj.weight (ndim=2) + _orig_mod.module.decoder.decoder.layers.2.multihead_attn.q_proj.weight (ndim=2) + _orig_mod.module.decoder.decoder.layers.2.multihead_attn.kv_proj.weight (ndim=2) + _orig_mod.module.decoder.decoder.layers.2.multihead_attn.out_proj.weight (ndim=2) + _orig_mod.module.decoder.decoder.layers.2.linear1.weight (ndim=2) + _orig_mod.module.decoder.decoder.layers.2.linear2.weight (ndim=2) + _orig_mod.module.decoder.decoder.layers.3.self_attn.in_proj.weight (ndim=2) + _orig_mod.module.decoder.decoder.layers.3.self_attn.out_proj.weight (ndim=2) + _orig_mod.module.decoder.decoder.layers.3.multihead_attn.q_proj.weight (ndim=2) + _orig_mod.module.decoder.decoder.layers.3.multihead_attn.kv_proj.weight (ndim=2) + _orig_mod.module.decoder.decoder.layers.3.multihead_attn.out_proj.weight (ndim=2) + _orig_mod.module.decoder.decoder.layers.3.linear1.weight (ndim=2) + _orig_mod.module.decoder.decoder.layers.3.linear2.weight (ndim=2) + _orig_mod.module.decoder.decoder.layers.4.self_attn.in_proj.weight (ndim=2) + _orig_mod.module.decoder.decoder.layers.4.self_attn.out_proj.weight (ndim=2) + _orig_mod.module.decoder.decoder.layers.4.multihead_attn.q_proj.weight (ndim=2) + _orig_mod.module.decoder.decoder.layers.4.multihead_attn.kv_proj.weight (ndim=2) + _orig_mod.module.decoder.decoder.layers.4.multihead_attn.out_proj.weight (ndim=2) + _orig_mod.module.decoder.decoder.layers.4.linear1.weight (ndim=2) + _orig_mod.module.decoder.decoder.layers.4.linear2.weight (ndim=2) + _orig_mod.module.decoder.decoder.layers.5.self_attn.in_proj.weight (ndim=2) + _orig_mod.module.decoder.decoder.layers.5.self_attn.out_proj.weight (ndim=2) + _orig_mod.module.decoder.decoder.layers.5.multihead_attn.q_proj.weight (ndim=2) + _orig_mod.module.decoder.decoder.layers.5.multihead_attn.kv_proj.weight (ndim=2) + _orig_mod.module.decoder.decoder.layers.5.multihead_attn.out_proj.weight (ndim=2) + _orig_mod.module.decoder.decoder.layers.5.linear1.weight (ndim=2) + _orig_mod.module.decoder.decoder.layers.5.linear2.weight (ndim=2) +I1007 11:06:21.234023 22637082580032 utils.py:34] Adam params: + _orig_mod.module.shared_embedding.weight (ndim=2) + _orig_mod.module.encoder.encoder.layers.0.linear1.bias (ndim=1) + _orig_mod.module.encoder.encoder.layers.0.linear2.bias (ndim=1) + _orig_mod.module.encoder.encoder.layers.0.norm1.weight (ndim=1) + _orig_mod.module.encoder.encoder.layers.0.norm1.bias (ndim=1) + _orig_mod.module.encoder.encoder.layers.0.norm2.weight (ndim=1) + _orig_mod.module.encoder.encoder.layers.0.norm2.bias (ndim=1) + _orig_mod.module.encoder.encoder.layers.1.linear1.bias (ndim=1) + _orig_mod.module.encoder.encoder.layers.1.linear2.bias (ndim=1) + _orig_mod.module.encoder.encoder.layers.1.norm1.weight (ndim=1) + _orig_mod.module.encoder.encoder.layers.1.norm1.bias (ndim=1) + _orig_mod.module.encoder.encoder.layers.1.norm2.weight (ndim=1) + _orig_mod.module.encoder.encoder.layers.1.norm2.bias (ndim=1) + _orig_mod.module.encoder.encoder.layers.2.linear1.bias (ndim=1) + _orig_mod.module.encoder.encoder.layers.2.linear2.bias (ndim=1) + 
_orig_mod.module.encoder.encoder.layers.2.norm1.weight (ndim=1) + _orig_mod.module.encoder.encoder.layers.2.norm1.bias (ndim=1) + _orig_mod.module.encoder.encoder.layers.2.norm2.weight (ndim=1) + _orig_mod.module.encoder.encoder.layers.2.norm2.bias (ndim=1) + _orig_mod.module.encoder.encoder.layers.3.linear1.bias (ndim=1) + _orig_mod.module.encoder.encoder.layers.3.linear2.bias (ndim=1) + _orig_mod.module.encoder.encoder.layers.3.norm1.weight (ndim=1) + _orig_mod.module.encoder.encoder.layers.3.norm1.bias (ndim=1) + _orig_mod.module.encoder.encoder.layers.3.norm2.weight (ndim=1) + _orig_mod.module.encoder.encoder.layers.3.norm2.bias (ndim=1) + _orig_mod.module.encoder.encoder.layers.4.linear1.bias (ndim=1) + _orig_mod.module.encoder.encoder.layers.4.linear2.bias (ndim=1) + _orig_mod.module.encoder.encoder.layers.4.norm1.weight (ndim=1) + _orig_mod.module.encoder.encoder.layers.4.norm1.bias (ndim=1) + _orig_mod.module.encoder.encoder.layers.4.norm2.weight (ndim=1) + _orig_mod.module.encoder.encoder.layers.4.norm2.bias (ndim=1) + _orig_mod.module.encoder.encoder.layers.5.linear1.bias (ndim=1) + _orig_mod.module.encoder.encoder.layers.5.linear2.bias (ndim=1) + _orig_mod.module.encoder.encoder.layers.5.norm1.weight (ndim=1) + _orig_mod.module.encoder.encoder.layers.5.norm1.bias (ndim=1) + _orig_mod.module.encoder.encoder.layers.5.norm2.weight (ndim=1) + _orig_mod.module.encoder.encoder.layers.5.norm2.bias (ndim=1) + _orig_mod.module.encoder.encoder.norm.weight (ndim=1) + _orig_mod.module.encoder.encoder.norm.bias (ndim=1) + _orig_mod.module.decoder.decoder.layers.0.linear1.bias (ndim=1) + _orig_mod.module.decoder.decoder.layers.0.linear2.bias (ndim=1) + _orig_mod.module.decoder.decoder.layers.0.norm1.weight (ndim=1) + _orig_mod.module.decoder.decoder.layers.0.norm1.bias (ndim=1) + _orig_mod.module.decoder.decoder.layers.0.norm2.weight (ndim=1) + _orig_mod.module.decoder.decoder.layers.0.norm2.bias (ndim=1) + _orig_mod.module.decoder.decoder.layers.0.norm3.weight (ndim=1) + _orig_mod.module.decoder.decoder.layers.0.norm3.bias (ndim=1) + _orig_mod.module.decoder.decoder.layers.1.linear1.bias (ndim=1) + _orig_mod.module.decoder.decoder.layers.1.linear2.bias (ndim=1) + _orig_mod.module.decoder.decoder.layers.1.norm1.weight (ndim=1) + _orig_mod.module.decoder.decoder.layers.1.norm1.bias (ndim=1) + _orig_mod.module.decoder.decoder.layers.1.norm2.weight (ndim=1) + _orig_mod.module.decoder.decoder.layers.1.norm2.bias (ndim=1) + _orig_mod.module.decoder.decoder.layers.1.norm3.weight (ndim=1) + _orig_mod.module.decoder.decoder.layers.1.norm3.bias (ndim=1) + _orig_mod.module.decoder.decoder.layers.2.linear1.bias (ndim=1) + _orig_mod.module.decoder.decoder.layers.2.linear2.bias (ndim=1) + _orig_mod.module.decoder.decoder.layers.2.norm1.weight (ndim=1) + _orig_mod.module.decoder.decoder.layers.2.norm1.bias (ndim=1) + _orig_mod.module.decoder.decoder.layers.2.norm2.weight (ndim=1) + _orig_mod.module.decoder.decoder.layers.2.norm2.bias (ndim=1) + _orig_mod.module.decoder.decoder.layers.2.norm3.weight (ndim=1) + _orig_mod.module.decoder.decoder.layers.2.norm3.bias (ndim=1) + _orig_mod.module.decoder.decoder.layers.3.linear1.bias (ndim=1) + _orig_mod.module.decoder.decoder.layers.3.linear2.bias (ndim=1) + _orig_mod.module.decoder.decoder.layers.3.norm1.weight (ndim=1) + _orig_mod.module.decoder.decoder.layers.3.norm1.bias (ndim=1) + _orig_mod.module.decoder.decoder.layers.3.norm2.weight (ndim=1) + _orig_mod.module.decoder.decoder.layers.3.norm2.bias (ndim=1) + 
_orig_mod.module.decoder.decoder.layers.3.norm3.weight (ndim=1) + _orig_mod.module.decoder.decoder.layers.3.norm3.bias (ndim=1) + _orig_mod.module.decoder.decoder.layers.4.linear1.bias (ndim=1) + _orig_mod.module.decoder.decoder.layers.4.linear2.bias (ndim=1) + _orig_mod.module.decoder.decoder.layers.4.norm1.weight (ndim=1) + _orig_mod.module.decoder.decoder.layers.4.norm1.bias (ndim=1) + _orig_mod.module.decoder.decoder.layers.4.norm2.weight (ndim=1) + _orig_mod.module.decoder.decoder.layers.4.norm2.bias (ndim=1) + _orig_mod.module.decoder.decoder.layers.4.norm3.weight (ndim=1) + _orig_mod.module.decoder.decoder.layers.4.norm3.bias (ndim=1) + _orig_mod.module.decoder.decoder.layers.5.linear1.bias (ndim=1) + _orig_mod.module.decoder.decoder.layers.5.linear2.bias (ndim=1) + _orig_mod.module.decoder.decoder.layers.5.norm1.weight (ndim=1) + _orig_mod.module.decoder.decoder.layers.5.norm1.bias (ndim=1) + _orig_mod.module.decoder.decoder.layers.5.norm2.weight (ndim=1) + _orig_mod.module.decoder.decoder.layers.5.norm2.bias (ndim=1) + _orig_mod.module.decoder.decoder.layers.5.norm3.weight (ndim=1) + _orig_mod.module.decoder.decoder.layers.5.norm3.bias (ndim=1) + _orig_mod.module.decoder.decoder.norm.weight (ndim=1) + _orig_mod.module.decoder.decoder.norm.bias (ndim=1) \ No newline at end of file diff --git a/submissions/external_tuning/muon/pytorch/muon_algos.py b/submissions/external_tuning/muon/pytorch/muon_algos.py new file mode 100644 index 00000000..281f5581 --- /dev/null +++ b/submissions/external_tuning/muon/pytorch/muon_algos.py @@ -0,0 +1,340 @@ +""" +Muon PyTorch implementations. + +Two Muon implementations are provided: ``MuonVanilla`` and ``MuonDataParallel``. +""" + +import os +import torch +import torch.distributed as dist +from abc import ABC, abstractmethod + + +# Distributed settings +USE_DDP = 'RANK' in os.environ +RANK = int(os.environ['RANK']) if USE_DDP else 0 +LOCAL_RANK = int(os.environ['LOCAL_RANK']) if USE_DDP else 0 +WORLD_SIZE = int(os.environ["WORLD_SIZE"]) if USE_DDP else 1 + +# Default values for Newton-Schulz +NS_A, NS_B, NS_C = 3.4445, -4.7750, 2.0315 +NS_STEPS = 5 +NS_EPS = 1e-7 + + +@torch.compile() +def zeropower_via_newtonschulz5(G, steps=NS_STEPS, eps=NS_EPS): + """ + Newton-Schulz iteration to approximately orthogonalize G. + 5th-order odd polynomial to approximate sign(x) on [-1,1], + pushing singular values to {+1,-1}. + + M = U @ S @ V.T + sign(M) = U @ sign(S) @ V.T, odd matrix polynomial commutes with SVD + sign(x) ~= a*x + b*x^3 + c*x^5, x in [-1,1] + """ + if G.ndim != 2: + raise RuntimeError(f"Expected 2D tensor in N-S, found {G.ndim} instead.") + a, b, c = NS_A, NS_B, NS_C + X = G.bfloat16() + if G.size(0) > G.size(1): + X = X.T + + # Ensure spectral norm is at most 1.
+ # Ortho(cX)=Ortho(X), so we can normalize by ||X||_2 <= ||X||_F + X /= X.norm() + eps + + # NS iterations + for _ in range(steps): + A = X @ X.T + B = b * A + c * (A @ A) + X = a * X + B @ X + + if G.size(0) > G.size(1): + X = X.T + return X + + +@torch.compile() +@torch.no_grad() +def muon_update(g, m, beta, nesterov, ns_steps, ns_eps): + """Updates momentum ``m`` in-place and returns the Muon update.""" + m.mul_(beta).add_(g, alpha=1 - beta) + + if nesterov: + g = g.add(m, alpha=beta) + else: + g = m + + g = g.reshape(g.size(0), -1) # flatten trailing dims on 3D, 4D params + g = zeropower_via_newtonschulz5(g, steps=ns_steps, eps=ns_eps) + g = g.view(m.shape) # restore original shape + + return g + + +def _adjust_lr_to_match_adam(lr, param_shape): + # https://arxiv.org/pdf/2502.16982 + A, B = param_shape[:2] + return lr * 0.2 * (max(A, B) ** 0.5) + + +def _adjust_lr_spectral_norm(lr, param_shape): + # Adjust from spectral norm 1 to RMS operator norm 1 + # https://arxiv.org/abs/2310.17813 + fan_out, fan_in = param_shape[:2] + return lr * max(1.0, (fan_out / fan_in) ** 0.5) + + +def _param_to_complexity(p: torch.Tensor) -> int: + """Compute NS complexity on p.grad.""" + # Shape after flattening potential trailing dims (3D, 4D) + m, n = (p.shape[0], torch.tensor(p.shape[1:]).prod().item()) + # X @ X.T complexity: m^2n + # XX.T @ XX.T complexity: m^3 + # XX.TXX.T @ X complexity: m^2n + return 2 * (m ** 2) * n + m ** 3 + + +class MuonBase(torch.optim.Optimizer, ABC): + """Muon optimizer - Momentum Orthogonalized by Newton-Schulz. + + Abstract class. + """ + + def __init__( + self, + params, + lr=0.02, + weight_decay=0.0, + beta=0.95, + nesterov=True, + ns_steps=NS_STEPS, + ns_eps=NS_EPS, + adjust_lr=None, + ): + if not 0.0 <= lr: + raise ValueError(f'Invalid learning rate: {lr}') + if not 0.0 <= weight_decay: + raise ValueError(f'Invalid weight_decay: {weight_decay}') + if not 0.0 <= beta < 1.0: + raise ValueError(f'Invalid muon_beta parameter: {beta}') + if nesterov not in [True, False]: + raise ValueError(f'Invalid nesterov parameter: {nesterov}') + if not 0 < ns_steps: + raise ValueError(f'Invalid ns_steps parameter: {ns_steps}') + if not 0.0 <= ns_eps: + raise ValueError(f'Invalid ns_eps parameter: {ns_eps}') + if adjust_lr not in [None, 'spectral_norm', 'match_adam']: + raise ValueError(f'Invalid adjust_lr parameter: {adjust_lr}') + + defaults = dict( + lr = lr, + weight_decay = weight_decay, + beta = beta, + nesterov = nesterov, + ns_steps = ns_steps, + ns_eps = ns_eps, + ) + super().__init__(params, defaults) + + if adjust_lr is None: + self._adjust_lr = lambda lr, param_shape: lr + elif adjust_lr == 'spectral_norm': + self._adjust_lr = _adjust_lr_spectral_norm + elif adjust_lr == 'match_adam': + self._adjust_lr = _adjust_lr_to_match_adam + + @abstractmethod + @torch.no_grad() + def step(self, closure=None): + pass + + +class MuonVanilla(MuonBase): + """ + Single-device implementation: if used with DDP, + it will replicate computation across devices.
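+ Gradients are assumed to already be synchronized across devices (e.g. by
+ DDP's backward all-reduce) before ``step`` is called; no collective communication happens here.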
+ """ + def __init__(self, params, **kwargs): + super().__init__(params, **kwargs) + + + @torch.compile() + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + lr = group['lr'] + wd = group['weight_decay'] + beta = group['beta'] + nesterov = group['nesterov'] + ns_steps = group['ns_steps'] + ns_eps = group['ns_eps'] + + for p in group['params']: + if p.grad is None: + continue + g = p.grad + state = self.state[p] + + if len(state) == 0: + state['m'] = torch.zeros_like(p) + + g = muon_update(g, state['m'], beta=beta, nesterov=nesterov, ns_steps=ns_steps, ns_eps=ns_eps) + + adjusted_lr = self._adjust_lr(lr, p.shape) # optionally adjust lr + p.mul_(1 - lr * wd) # weigth decay + p.add_(g, alpha=-adjusted_lr) + + return loss + + +class MuonDataParallel(MuonBase): + """ + Distributed Data Parallel Muon Pytorch implementation. + + Modified from: https://github.com/KellerJordan/Muon/blob/master/muon.py#L98 + + For each param group, (sorted) parameters are processed in blocks of world_size. + Each block is distributed round-robin across devices. + + We sort parameters based on the corresponding Newton-Schultz complexity, + rather then based on thier size. + + ``step`` structure: + - ReduceScatter gradients round-robin + - Orthogonalize gradients locally, update param + - AllGather params round-robin + + Both collective operations are asynchronous, + allowing to overlap computation and communication. + We wait on reduce-scatter when updating, and wait for the all-gather + ops to finish at the end of step. + + Comms: one all-gather per block -> ~#params/WORLD_SIZE comms. + Space: O(largest_param) + """ + def __init__(self, params, **kwargs): + if not isinstance(params, list): + params = list(params) + + if not dist.is_initialized(): + raise ValueError('Using MuonDDP in a non-distributed run.') + + # Sort params to fairly distribute orthogonalization across devices + if isinstance(params[0], dict): # sort each param group individually + for group in params: + group["params"] = sorted(group["params"], key=_param_to_complexity, reverse=True) + else: + params = sorted(params, key=_param_to_complexity, reverse=True) + + super().__init__(params, **kwargs) + + + @torch.compile() + @torch.no_grad() + def step(self, closure=None): + """ + 1. ReduceScatter: process grads round-robin, ReduceScatter each block. Work handles are stored. + 2. AllGather: process params round-robin, wait for ReduceScatter handles on that block. + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + # 1. ReduceScatter grads + for group in self.param_groups: + group['reduce_handles'] = [] + + # References to grads, ensure valid tensors for reduce_scatter + grads = [p.grad if p.grad is not None else torch.zeros_like(p) + for p in group["params"]] + + # Pad grads so each reduce_scatter block is of size WORLD_SIZE. 
+ pad = (WORLD_SIZE - len(grads) % WORLD_SIZE) % WORLD_SIZE + grads_pad = grads + [torch.zeros_like(grads[-1])] * pad + + # Iterate over grads in blocks of WORLD_SIZE + for block_start in range(0, len(grads), WORLD_SIZE): + # Skip padded tensor when reducing + if block_start + RANK < len(grads): + receiv = grads_pad[block_start + RANK] # ref to p.grad + else: + receiv = torch.zeros_like(grads_pad[block_start + RANK]) # dummy buffer + + # ReduceScatter this block + with torch.no_grad(): + handle = dist.reduce_scatter( + receiv, + grads_pad[block_start:block_start + WORLD_SIZE], + op=dist.ReduceOp.AVG, + async_op=True + ) + group['reduce_handles'].append(handle) + + # 2. Update and AllGather (overlapped) + gather_handles = [] + for group in self.param_groups: + lr = group['lr'] + wd = group['weight_decay'] + beta = group['beta'] + nesterov = group['nesterov'] + ns_steps = group['ns_steps'] + ns_eps = group['ns_eps'] + params = group['params'] + reduce_handles = group['reduce_handles'] + + # Pad params so each all-gather block is of size WORLD_SIZE. + # list concat keeps param refs (not copies), so all_gather updates model params directly. + pad = (WORLD_SIZE - len(params) % WORLD_SIZE) % WORLD_SIZE + params_pad = params + [torch.empty_like(params[-1])] * pad + + # Iterate over params in blocks of WORLD_SIZE + for block_start in range(0, len(params), WORLD_SIZE): + # Wait for grads in this block. + reduce_handles.pop(0).wait() + + # Each device updates the RANK-th tensor in the block + if block_start + RANK < len(params): # skip padded tensors + p = params[block_start + RANK] # round-robin + if p.grad is None: + p.grad = torch.zeros_like(p) # ensure valid tensor for all_gather + + state = self.state[p] + + if len(state) == 0: + state['m'] = torch.zeros_like(p) + + g = muon_update(p.grad, state['m'], beta=beta, nesterov=nesterov, ns_steps=ns_steps, ns_eps=ns_eps) + + adjusted_lr = self._adjust_lr(lr, p.shape) # optionally adjust lr + p.mul_(1 - lr * wd) + p.add_(g, alpha=-adjusted_lr) + + # all-gather current block of params (including padded entries) + handle = dist.all_gather( + params_pad[block_start:block_start + WORLD_SIZE], + params_pad[block_start + RANK], + async_op=True + ) + gather_handles.append(handle) + + # ## debug + # assert len(reduce_handles) == 0, AssertionError('Some reduce futures were not consumed.') + + # Sync point + for handle in gather_handles: + handle.wait() + + return loss + + +__all__ = ['MuonVanilla', 'MuonDataParallel'] + diff --git a/submissions/external_tuning/muon/pytorch/muon_ddp.py b/submissions/external_tuning/muon/pytorch/muon_ddp.py new file mode 100644 index 00000000..770bc861 --- /dev/null +++ b/submissions/external_tuning/muon/pytorch/muon_ddp.py @@ -0,0 +1,262 @@ +""" +Distributed Data Parallel Muon PyTorch implementation. +See ``MuonDataParallel`` in muon_algos.py for more details. + +NOTE: gradient clipping is not supported, as gradients are reduced *inside* optimizer.step(), +so calling torch.nn.utils.clip_grad_norm_() would wrongly clip the local gradients.
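+Instead, gradients are averaged via ReduceScatter (ReduceOp.AVG) inside ``MuonDataParallel.step``,
+and DDP's backward all-reduce is skipped in ``update_params`` below.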
+""" + +from typing import Any, Dict, Iterator, List, Optional, Tuple + +import torch +import torch.distributed as dist +import torch.distributed.nn as dist_nn +from absl import logging +from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR + +from algoperf import spec +from algoperf.pytorch_utils import pytorch_setup + +from reference_algorithms.muon.pytorch.muon_algos import MuonDataParallel +from reference_algorithms.muon.pytorch.utils import _split_params_muon_adam + +USE_PYTORCH_DDP = pytorch_setup()[0] + + +def _pytorch_cosine_warmup(step_hint: int, hyperparameters, optimizer): + warmup_steps = int(hyperparameters.warmup_factor * step_hint) + warmup = LinearLR( + optimizer, start_factor=1e-10, end_factor=1.0, total_iters=warmup_steps + ) + cosine_steps = max(step_hint - warmup_steps, 1) + cosine_decay = CosineAnnealingLR(optimizer, T_max=cosine_steps) + return SequentialLR( + optimizer, schedulers=[warmup, cosine_decay], milestones=[warmup_steps] + ) + + +def init_optimizer_state( + workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState, +) -> spec.OptimizerState: + """Creates a Muon optimizer and a learning rate schedule.""" + del model_state + del rng + + if getattr(hyperparameters, 'grad_clip', None): + raise NotImplementedError('Gradient clipping not supported with custom ReduceScatter.') + + muon_params, adam_params = _split_params_muon_adam(model_params) + + optimizer_state = { + 'muon': MuonDataParallel( + muon_params, + lr=hyperparameters.learning_rate, # shared + weight_decay=hyperparameters.muon_weight_decay, # shared + beta=hyperparameters.muon_beta, + nesterov=hyperparameters.muon_nesterov, + ns_steps=hyperparameters.muon_ns_steps, + ns_eps=hyperparameters.muon_ns_eps, + adjust_lr=hyperparameters.muon_adjust_lr, + ), + 'adamw': torch.optim.AdamW( + adam_params, + lr=hyperparameters.learning_rate, # shared + weight_decay=hyperparameters.adamw_weight_decay, # shared + betas=(hyperparameters.adamw_beta1, hyperparameters.adamw_beta2), + eps=hyperparameters.adamw_eps, + fused=True, + ), + } + + # One scheduler per optimizer + optimizer_state["muon_scheduler"] = _pytorch_cosine_warmup( + workload.step_hint, hyperparameters, optimizer_state["muon"] + ) + optimizer_state["adamw_scheduler"] = _pytorch_cosine_warmup( + workload.step_hint, hyperparameters, optimizer_state["adamw"] + ) + + return optimizer_state + + +def update_params( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None, +) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params, updated_model_state).""" + del current_params_types + del loss_type + del train_state + del eval_results + + current_model = current_param_container + current_model.train() + optimizer_state['muon'].zero_grad() + optimizer_state['adamw'].zero_grad() + + # Skip AllReduce in backward pass + current_model.require_backward_grad_sync=False + + # Fwd pass + logits_batch, new_model_state = workload.model_fn( + params=current_model, + augmented_and_preprocessed_input_batch=batch, + model_state=model_state, + 
mode=spec.ForwardPassMode.TRAIN, + rng=rng, + update_batch_norm=True, + dropout_rate=hyperparameters.dropout_rate, + ) + + # Bwd pass + label_smoothing = ( + hyperparameters.label_smoothing + if hasattr(hyperparameters, 'label_smoothing') + else 0.0 + ) + + loss_dict = workload.loss_fn( + label_batch=batch['targets'], + logits_batch=logits_batch, + mask_batch=batch.get('weights'), + label_smoothing=label_smoothing, + ) + summed_loss = loss_dict['summed'] + n_valid_examples = loss_dict['n_valid_examples'] + if USE_PYTORCH_DDP: + # Use dist_nn.all_reduce to ensure correct loss and gradient scaling. + summed_loss = dist_nn.all_reduce(summed_loss) + n_valid_examples = dist_nn.all_reduce(n_valid_examples) + loss = summed_loss / n_valid_examples + + # Compute grads, but do not AllReduce them. + loss.backward() + + # All-reduce AdamW grads + for group in optimizer_state['adamw'].param_groups: + for p in group['params']: + if p.grad is not None: + dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) + + # Muon: ReduceScatter + update params + AllGather params + optimizer_state['muon'].step() + + # AdamW: update params + optimizer_state['adamw'].step() + + # Step LR schedulers + optimizer_state['muon_scheduler'].step() + optimizer_state['adamw_scheduler'].step() + + # # Log training metrics - loss, grad_norm, batch_size. + # if global_step <= 100 or global_step % 50 == 0: + # with torch.no_grad(): + # parameters = [p for p in current_model.parameters() if p.grad is not None] + # grad_norm = torch.norm( + # torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2 + # ) + # if workload.metrics_logger is not None: + # workload.metrics_logger.append_scalar_metrics( + # { + # 'loss': loss.item(), + # 'grad_norm': grad_norm.item(), + # }, + # global_step, + # ) + # logging.info( + # '%d) loss = %0.3f, grad_norm = %0.3f', + # global_step, + # loss.item(), + # grad_norm.item(), + # ) + + return (optimizer_state, current_param_container, new_model_state) + + +def prepare_for_eval( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, +) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params).""" + del workload + del hyperparameters + del current_params_types + del loss_type + del eval_results + del global_step + del rng + return (optimizer_state, current_param_container, model_state) + + +def get_batch_size(workload_name): + # Return the global batch size. 
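+ # NOTE: these global batch sizes are duplicated in muon_vanilla.py; keep the two functions in sync.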
+ if workload_name == 'criteo1tb': + return 262_144 + elif workload_name == 'fastmri': + return 32 + elif workload_name == 'imagenet_resnet': + return 1024 + elif workload_name == 'imagenet_resnet_silu': + return 512 + elif workload_name == 'imagenet_resnet_gelu': + return 512 + elif workload_name == 'imagenet_vit': + return 1024 + elif workload_name == 'librispeech_conformer': + return 256 + elif workload_name == 'librispeech_deepspeech': + return 256 + elif workload_name == 'ogbg': + return 512 + elif workload_name == 'wmt': + return 128 + elif workload_name == 'mnist': + return 16 + else: + raise ValueError(f'Unsupported workload name: {workload_name}.') + + +def data_selection( + workload: spec.Workload, + input_queue: Iterator[Dict[str, spec.Tensor]], + optimizer_state: spec.OptimizerState, + current_param_container: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + global_step: int, + rng: spec.RandomState, +) -> Dict[str, spec.Tensor]: + """Select data from the infinitely repeating, pre-shuffled input queue. + Each element of the queue is a batch of training examples and labels. + """ + del workload + del optimizer_state + del current_param_container + del model_state + del hyperparameters + del global_step + del rng + batch = next(input_queue) + return batch diff --git a/submissions/external_tuning/muon/pytorch/muon_vanilla.py b/submissions/external_tuning/muon/pytorch/muon_vanilla.py new file mode 100644 index 00000000..afe31a71 --- /dev/null +++ b/submissions/external_tuning/muon/pytorch/muon_vanilla.py @@ -0,0 +1,246 @@ +""" +Vanilla Muon PyTorch implementation. +See ``MuonVanilla`` in muon_algos.py for more details. +""" + +from typing import Any, Dict, Iterator, List, Optional, Tuple + +import torch +import torch.distributed.nn as dist_nn +from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR + +from algoperf import spec +from algoperf.pytorch_utils import pytorch_setup + +from submissions.external_tuning.muon.pytorch.muon_algos import MuonVanilla +from submissions.external_tuning.muon.pytorch.utils import _split_params_muon_adam + +USE_PYTORCH_DDP = pytorch_setup()[0] + + +def _pytorch_cosine_warmup(step_hint: int, hyperparameters, optimizer): + warmup_steps = int(hyperparameters.warmup_factor * step_hint) + warmup = LinearLR( + optimizer, start_factor=1e-10, end_factor=1.0, total_iters=warmup_steps + ) + cosine_steps = max(step_hint - warmup_steps, 1) + cosine_decay = CosineAnnealingLR(optimizer, T_max=cosine_steps) + return SequentialLR( + optimizer, schedulers=[warmup, cosine_decay], milestones=[warmup_steps] + ) + + +def init_optimizer_state( + workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState, +) -> spec.OptimizerState: + """Creates a Muon optimizer and a learning rate schedule.""" + del model_state + del rng + + muon_params, adam_params = _split_params_muon_adam(model_params) + + optimizer_state = { + 'muon': MuonVanilla( + muon_params, + lr=hyperparameters.learning_rate, # shared + weight_decay=hyperparameters.muon_weight_decay, # shared + beta=hyperparameters.muon_beta, + nesterov=hyperparameters.muon_nesterov, + ns_steps=hyperparameters.muon_ns_steps, + ns_eps=hyperparameters.muon_ns_eps, + adjust_lr=hyperparameters.muon_adjust_lr, + ), + 'adamw': torch.optim.AdamW( + adam_params, + lr=hyperparameters.learning_rate, # shared + weight_decay=hyperparameters.adamw_weight_decay, # 
shared + betas=(hyperparameters.adamw_beta1, hyperparameters.adamw_beta2), + eps=hyperparameters.adamw_eps, + fused=True, + ), + } + + # One scheduler per optimizer + optimizer_state["muon_scheduler"] = _pytorch_cosine_warmup( + workload.step_hint, hyperparameters, optimizer_state["muon"] + ) + optimizer_state["adamw_scheduler"] = _pytorch_cosine_warmup( + workload.step_hint, hyperparameters, optimizer_state["adamw"] + ) + + return optimizer_state + + +def update_params( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None, +) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params, updated_model_state).""" + del current_params_types + del loss_type + del train_state + del eval_results + + current_model = current_param_container + current_model.train() + optimizer_state["muon"].zero_grad() + optimizer_state["adamw"].zero_grad() + + # Fwd pass + logits_batch, new_model_state = workload.model_fn( + params=current_model, + augmented_and_preprocessed_input_batch=batch, + model_state=model_state, + mode=spec.ForwardPassMode.TRAIN, + rng=rng, + update_batch_norm=True, + dropout_rate=hyperparameters.dropout_rate, + ) + + # Bwd pass + label_smoothing = ( + hyperparameters.label_smoothing + if hasattr(hyperparameters, "label_smoothing") + else 0.0 + ) + if hasattr(hyperparameters, "grad_clip"): + grad_clip = hyperparameters.grad_clip + else: + grad_clip = None + + loss_dict = workload.loss_fn( + label_batch=batch["targets"], + logits_batch=logits_batch, + mask_batch=batch.get("weights"), + label_smoothing=label_smoothing, + ) + summed_loss = loss_dict["summed"] + n_valid_examples = loss_dict["n_valid_examples"] + if USE_PYTORCH_DDP: + # Use dist_nn.all_reduce to ensure correct loss and gradient scaling. + summed_loss = dist_nn.all_reduce(summed_loss) + n_valid_examples = dist_nn.all_reduce(n_valid_examples) + loss = summed_loss / n_valid_examples + + # AllReduce grads + loss.backward() + + if grad_clip is not None: + torch.nn.utils.clip_grad_norm_(current_model.parameters(), max_norm=grad_clip) + optimizer_state["muon"].step() + optimizer_state["adamw"].step() + optimizer_state["muon_scheduler"].step() + optimizer_state["adamw_scheduler"].step() + + # # Log training metrics - loss, grad_norm, batch_size. 
+ # if global_step <= 100 or global_step % 50 == 0: + # with torch.no_grad(): + # parameters = [p for p in current_model.parameters() if p.grad is not None] + # grad_norm = torch.norm( + # torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2 + # ) + # if workload.metrics_logger is not None: + # workload.metrics_logger.append_scalar_metrics( + # { + # 'loss': loss.item(), + # 'grad_norm': grad_norm.item(), + # }, + # global_step, + # ) + # logging.info( + # '%d) loss = %0.3f, grad_norm = %0.3f', + # global_step, + # loss.item(), + # grad_norm.item(), + # ) + + return (optimizer_state, current_param_container, new_model_state) + + +def prepare_for_eval( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, +) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params).""" + del workload + del hyperparameters + del current_params_types + del loss_type + del eval_results + del global_step + del rng + return (optimizer_state, current_param_container, model_state) + + +def get_batch_size(workload_name): + # Return the global batch size. + if workload_name == "criteo1tb": + return 262_144 + elif workload_name == "fastmri": + return 32 + elif workload_name == "imagenet_resnet": + return 1024 + elif workload_name == "imagenet_resnet_silu": + return 512 + elif workload_name == "imagenet_resnet_gelu": + return 512 + elif workload_name == "imagenet_vit": + return 1024 + elif workload_name == "librispeech_conformer": + return 256 + elif workload_name == "librispeech_deepspeech": + return 256 + elif workload_name == "ogbg": + return 512 + elif workload_name == "wmt": + return 128 + elif workload_name == "mnist": + return 16 + else: + raise ValueError(f"Unsupported workload name: {workload_name}.") + + +def data_selection( + workload: spec.Workload, + input_queue: Iterator[Dict[str, spec.Tensor]], + optimizer_state: spec.OptimizerState, + current_param_container: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + global_step: int, + rng: spec.RandomState, +) -> Dict[str, spec.Tensor]: + """Select data from the infinitely repeating, pre-shuffled input queue. + Each element of the queue is a batch of training examples and labels. 
+ """ + del workload + del optimizer_state + del current_param_container + del model_state + del hyperparameters + del global_step + del rng + batch = next(input_queue) + return batch diff --git a/submissions/external_tuning/muon/pytorch/utils.py b/submissions/external_tuning/muon/pytorch/utils.py new file mode 100644 index 00000000..01193cc0 --- /dev/null +++ b/submissions/external_tuning/muon/pytorch/utils.py @@ -0,0 +1,38 @@ +import logging + + +def _split_params_muon_adam(model): + """Split parameters: + - Muon: all matrix params (ndim ≥ 2) except embeddings + - Adam: 1D params, all embeddings + """ + ## too simplistic + # params = [p for p in model.parameters() if p.requires_grad] + # matrix_params = [p for p in params if p.ndim >= 2] + # non_matrix_params = [p for p in params if p.ndim < 2] + + muon_params, adam_params = [], [] + muon_infos, adam_infos = [], [] # for logging only + + for n, p in model.named_parameters(): + if not p.requires_grad: + continue + + # Assign embeddings to Adam (wmt, criteo, fwedu) + if "embedding" in n.lower(): + adam_params.append(p) + adam_infos.append(f"{n} (ndim={p.ndim})") + elif "lm_head" in n.lower(): + adam_params.append(p) + adam_infos.append(f'{n} (ndim={p.ndim})') + elif p.ndim >= 2: + muon_params.append(p) + muon_infos.append(f"{n} (ndim={p.ndim})") + else: + adam_params.append(p) + adam_infos.append(f"{n} (ndim={p.ndim})") + + logging.info("Muon params:\n\t" + "\n\t".join(muon_infos)) + logging.info("Adam params:\n\t" + "\n\t".join(adam_infos)) + + return muon_params, adam_params diff --git a/submissions/external_tuning/muon/tuning_search_space.json b/submissions/external_tuning/muon/tuning_search_space.json new file mode 100644 index 00000000..21ea3953 --- /dev/null +++ b/submissions/external_tuning/muon/tuning_search_space.json @@ -0,0 +1,21 @@ +[ + { + "learning_rate": 0.001, + + "muon_weight_decay": 0.1, + "muon_beta": 0.9, + "muon_nesterov": true, + "muon_ns_steps": 5, + "muon_ns_eps": 1e-7, + "muon_adjust_lr": null, + + "adamw_weight_decay": 0.1, + "adamw_beta1": 0.9, + "adamw_beta2": 0.999, + "adamw_eps": 1e-8, + + "dropout_rate": 0.0, + "label_smoothing": 0.1, + "warmup_factor": 0.1 + } +]