Conversation
Includes the necessary modules (LN, bias, scale, constant posemb) and a notebook showing it working on MNIST. Initial tuning shows that momentum w/ dualize is not quite as performant as Adam. More of a demonstration, but could be merged.
In most cases image data has a channel dimension (usually 3), so making the input shape NHWC makes the ViT more applicable externally. Unexpectedly, bias alone stabilizes the model, though it is not nearly as performant.
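For reference, a minimal sketch (hypothetical module name, not the PR's code) of a patch embedding that consumes NHWC images directly, which is what accepting the usual `[N, H, W, C]` layout amounts to:

```python
import flax.linen as nn

class PatchEmbed(nn.Module):
    patch_size: int = 16
    dim: int = 384

    @nn.compact
    def __call__(self, x):
        # x: [N, H, W, C] images, e.g. C = 3 for RGB; flax Conv is NHWC by default.
        x = nn.Conv(self.dim,
                    kernel_size=(self.patch_size, self.patch_size),
                    strides=(self.patch_size, self.patch_size))(x)
        n, h, w, d = x.shape
        return x.reshape(n, h * w, d)  # [N, num_patches, dim] patch tokens
```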
Just like not caching RoPE, it is more in the spirit of JAX to defer constant materialization to the forward pass; it also allows distributed init without triggering a disallowed host-to-device transfer error.
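A minimal sketch of what that deferral can look like, assuming a Flax module (names are hypothetical, not the PR's code): the constant is computed inside `__call__` instead of being precomputed and stored as a concrete array at init, and under `jit` it is typically constant-folded by XLA.

```python
import jax.numpy as jnp
import flax.linen as nn

class RoPEFreqs(nn.Module):
    rope_dim: int = 32

    @nn.compact
    def __call__(self, seq_len: int):
        # Recomputed each forward pass rather than cached in setup(), so no
        # concrete host array needs to be created and transferred at init.
        inv_freq = 1.0 / (10000.0 ** (jnp.arange(0, self.rope_dim, 2) / self.rope_dim))
        pos = jnp.arange(seq_len)
        angles = pos[:, None] * inv_freq[None, :]  # [seq_len, rope_dim // 2]
        return jnp.sin(angles), jnp.cos(angles)
```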
I have ported this "dualized ViT" to Big Vision and started training a dualized ViT-S/16 on ImageNet-1k. Implementation: I made a custom branch of modula that dry-runs
LR=0.05, WD=0.005 or 0.0001 (decoupled from LR), momentum w/ beta=0.95. I don't have great intuition here; perhaps it still needs LR warm-up even though training is stable without it?
Notable architecture (or rather, just scaling) differences from the baseline:
Optimizer differences from the "conventional" muon:
If necessary it may be interesting to bisect these differences.
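For concreteness, a minimal sketch of what a dualized momentum step with these hyperparameters could look like for a single 2D weight; the function names and the exact momentum/weight-decay conventions are assumptions, not the PR's actual optimizer:

```python
import jax.numpy as jnp

def dualize(g):
    # Map a 2D matrix to the nearest semi-orthogonal matrix, U @ V^T of its SVD
    # (the spectral-norm "dual" used in Muon / modular dualization).
    u, _, vt = jnp.linalg.svd(g, full_matrices=False)
    return u @ vt

def update(w, g, m, lr=0.05, wd=0.005, beta=0.95):
    # Momentum buffer (one common convention; Muon variants differ in details).
    m = beta * m + (1.0 - beta) * g
    # Dualized step on the momentum.
    w = w - lr * dualize(m)
    # Decoupled weight decay; whether it is additionally scaled by lr is a
    # convention choice and just a guess at the setup described above.
    w = w - wd * w
    return w, m
```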
x1 = x[..., self.rope_dim:]  # shape [batch, n_heads, seq_len, rope_dim]
x2 = x[..., :self.rope_dim]  # shape [batch, n_heads, seq_len, rope_dim]

# Why is the order reversed!?
I noticed this too and corrected it in #13. I guess it doesn't really matter for performance?
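For context, the usual rotate-half formulation looks roughly like the sketch below (a hypothetical helper, not this repo's code). The slicing order only decides which half of the head dimension gets negated; as long as queries and keys use the same convention, attention scores still depend only on relative position, so swapping the halves shouldn't affect performance.

```python
import jax.numpy as jnp

def apply_rope(x, sin, cos):
    # x: [batch, n_heads, seq_len, head_dim]; sin, cos: [seq_len, head_dim // 2].
    x1, x2 = jnp.split(x, 2, axis=-1)
    sin = jnp.concatenate([sin, sin], axis=-1)
    cos = jnp.concatenate([cos, cos], axis=-1)
    rotated = jnp.concatenate([-x2, x1], axis=-1)  # the "rotate half" trick
    return x * cos + rotated * sin
```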

