Conversation
j321m
left a comment
Claude-assisted review. I'd like to see a multi-GPU run (smoke test). Please address the comments.
```python
moe_router_z_loss_factor: float = 0.0,
activation_function: str = "swiglu",
init_scale: float = 1.0,
**_ignored_kwargs,
```
Is `**_ignored_kwargs` necessary?
As far as I understand our config system, it's necessary in order to keep MoE configs similar to how it's done in small_moe.yaml, where we just set

```yaml
ff_layer_fn:
  _target_: src.core.moe.MoE
```

to use MoE (because we keep `ff_layer_fn` from the base config and only replace `_target_` in it). Please let me know if you'd prefer to change the config structure to something like `override /ff_layer@model.encoder.block_fn.ff_layer_fn: moe` - then we can get rid of `**_ignored_kwargs`.
I think it's better to have separate base YAMLs for dense and MoE than `**kwargs`.
`override /ff_layer@model.encoder.block_fn.ff_layer_fn: moe` is also a good idea (even better, but it may need some config refactoring).
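For concreteness, a hedged sketch of what the proposed Hydra-style override could look like. The file paths, the `ff_layer` config group, and the `num_experts` fields are all hypothetical here, assumed only for illustration; only the `_target_` and the override line come from the thread above:

```yaml
# configs/ff_layer/moe.yaml  (hypothetical location and fields)
_target_: src.core.moe.MoE
num_experts: 16
num_experts_per_tok: 2

# configs/experiment/small_moe.yaml  (hypothetical location)
defaults:
  - override /ff_layer@model.encoder.block_fn.ff_layer_fn: moe
```

With a dedicated config group like this, the MoE constructor would receive only its own fields, so `**_ignored_kwargs` would no longer be needed.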
```python
router_logits = router_logits.to(dtype=torch.float32)
router_probs = F.softmax(router_logits, dim=-1)
# For each token, keep only the top-k experts and their routing probabilities
topk_probs, selected_experts = torch.topk(
```
Question: should the routing weights sum to 1 when `num_experts_per_tok > 1`?
I added the option to normalize them.
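To make the question concrete, here is a NumPy sketch of top-k routing with an optional normalization flag (the actual PR uses `torch.topk` and `F.softmax`; the function name and `normalize` parameter here are illustrative, not the PR's API):

```python
import numpy as np

def topk_routing(router_logits, k, normalize=False):
    """NumPy sketch of top-k expert routing.

    Without normalization, a token's kept weights sum to less than 1
    whenever k < num_experts; normalize=True rescales them to sum to 1.
    """
    # Softmax over experts in float32, as in the reviewed code.
    logits = router_logits.astype(np.float32)
    logits = logits - logits.max(axis=-1, keepdims=True)  # numerical stability
    probs = np.exp(logits)
    probs /= probs.sum(axis=-1, keepdims=True)

    # Keep only the k largest routing probabilities per token.
    selected_experts = np.argsort(-probs, axis=-1)[..., :k]
    topk_probs = np.take_along_axis(probs, selected_experts, axis=-1)

    if normalize:
        # Renormalize so each token's kept weights sum to 1.
        topk_probs = topk_probs / topk_probs.sum(axis=-1, keepdims=True)
    return topk_probs, selected_experts
```

Mixtral-style implementations renormalize the kept weights; Switch-style top-1 routing makes the question moot, so exposing it as an option is a reasonable middle ground.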
```diff
@@ -228,11 +312,19 @@ def _update_processed_tokens(self, batch):

     def log_metrics(self, loss, grad_norm):
```
The function `log_metrics` ignores the `loss` argument. Also, I'm not sure I like `self._last_reported_loss`, `self._last_moe_router_z_loss`, etc.; it makes the code more error-prone in my opinion, but I'm open to discussion.
```diff
     self.metric_logger.set_tokens(self.processed_tokens)
-    self.metric_logger.log("train/loss", loss.item())
+    self.metric_logger.log("train/loss", self._last_reported_loss.item())
     self.metric_logger.log(
```
MoE metrics will get logged even for dense models; do we want that?
Removed MoE metric logging for dense models.
The `eval` function calls `self.calculate_loss(batch)`, which overwrites `self._last_reported_loss`, resulting in the same eval and train loss. It works now only because `log_metrics` happens to get called after `eval`.
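A minimal sketch of the hazard described above (class and method names are illustrative stand-ins, not the PR's code): because the loss is cached as hidden state, any call that computes a loss silently changes what gets logged as the train loss.

```python
# Hypothetical minimal repro of the cached-loss pitfall.
class Trainer:
    def __init__(self):
        self._last_reported_loss = None

    def calculate_loss(self, batch):
        loss = sum(batch) / len(batch)     # stand-in for the real loss
        self._last_reported_loss = loss    # hidden side effect
        return loss

    def eval(self, eval_batch):
        # Reuses calculate_loss, clobbering the cached train loss.
        return self.calculate_loss(eval_batch)

    def train_loss_for_logging(self):
        return self._last_reported_loss


t = Trainer()
train_loss = t.calculate_loss([1.0, 2.0, 3.0])  # train step: loss = 2.0
t.eval([10.0, 20.0])                            # eval step: loss = 15.0
stale = t.train_loss_for_logging()              # 15.0, not the train loss
```

Passing the loss into `log_metrics` explicitly (and actually using the argument) removes this ordering dependency entirely.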
```python
self.moe_load_balancing_loss_factor = moe_load_balancing_loss_factor
self.moe_router_z_loss_factor = moe_router_z_loss_factor
self.is_moe = True
self.aux_loss = None
```
What is `aux_loss` for? It is set to the same value as `load_balancing_loss`.
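For reference, a NumPy sketch of the two standard MoE auxiliary losses being discussed (Switch Transformer style); the function name and signature are illustrative, not the PR's API. If `aux_loss` only ever mirrors the load-balancing term, it is indeed redundant:

```python
import numpy as np

def moe_aux_losses(router_logits, selected_experts, num_experts):
    """Sketch of the two common MoE auxiliary losses.

    load_balancing = num_experts * sum_i f_i * P_i, where f_i is the fraction
    of routing slots assigned to expert i and P_i is the mean router
    probability of expert i (minimized at 1.0 for a perfectly uniform router);
    z_loss = mean(logsumexp(logits)^2), which keeps router logits small.
    """
    logits = router_logits.astype(np.float32)
    shifted = logits - logits.max(axis=-1, keepdims=True)
    exp = np.exp(shifted)
    probs = exp / exp.sum(axis=-1, keepdims=True)

    # f_i: fraction of routing slots dispatched to each expert.
    counts = np.bincount(selected_experts.ravel(), minlength=num_experts)
    f = counts / selected_experts.size
    # P_i: mean router probability mass on each expert.
    p = probs.mean(axis=0)
    load_balancing = num_experts * float((f * p).sum())

    # Router z-loss: squared logsumexp of the (unshifted) logits.
    lse = np.log(exp.sum(axis=-1)) + logits.max(axis=-1)
    z_loss = float((lse ** 2).mean())
    return load_balancing, z_loss
```

For a uniform router that spreads tokens evenly, `load_balancing` evaluates to exactly 1.0, its minimum.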
Link to multi-GPU run: wandb
kuba-krj
left a comment
Added the changes and ran a test identical to the previous 2-GPU run to check that the results are unchanged: wandb link. Please let me know if the PR looks good now or if additional changes are needed.
Removing `**kwargs` is very important; separate dense and MoE configs should solve the problem.
Adding MoE to our codebase, written with the assistance of Codex.
MFU

MFU calculated on 1 GPU is not great: ~8% with `batch_size=32` (the largest that could fit on 1 GPU) and `seq_len=1024`. This is ~2x slower than dense with the same number of active params, trained with `batch_size=64` (also the largest that could fit). Possibly MFU is better on multi-GPU due to the larger batch size we can use, but those experiments are waiting in the queue; I will update when they are finished.

Correctness
I compared the dense model (settings as above) with MoE where E = {1, 2, 4, 16}. The results look reasonable: E=1 matches dense, and models get better with more experts.
Link to verification experiments: wandb project