Jk moe #172

Merged
kuba-krj merged 16 commits into main from jk_moe on Apr 14, 2026
Conversation

@kuba-krj
Contributor

@kuba-krj kuba-krj commented Mar 26, 2026

Adding MoE to our codebase, written with the assistance of Codex.

MFU

MFU calculated on 1 GPU is not great: ~8% with the following settings:

  dmodel: 1024
  dff: 2816
  dhead: 64
  n_blocks: 16
  q_heads: 16
  kv_heads: 16

  num_experts: 16
  num_experts_per_tok: 1
  capacity_factor: 1.25

with batch size=32 (the largest that fits on 1 GPU) and seq_len=1024. This is ~2x slower than a dense model with the same number of active params, trained with batch size=64 (also the largest that fits). MFU may be better on multi-GPU thanks to the larger batch size we can use, but those experiments are still waiting in the queue; I will update when they finish.
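For context, a rough sketch of how an MFU number like this can be estimated from the settings above, using the common 6 * active_params * tokens approximation for training FLOPs. The tokens/s and peak-FLOPs figures below are placeholders, not measurements from this PR, and the param count ignores embeddings and norms:

```python
# Rough MFU estimate: achieved training FLOPs per second vs. GPU peak.
# Placeholder throughput/peak numbers; only the parameter arithmetic
# comes from the settings in this PR.

def transformer_active_params(dmodel, dff, n_blocks, q_heads, kv_heads, dhead,
                              experts_per_tok=1):
    # Attention: Q and output projections sized by q_heads, K/V by kv_heads.
    attn = n_blocks * dmodel * dhead * (2 * q_heads + 2 * kv_heads)
    # SwiGLU FF: three dmodel x dff matrices; in the MoE case only
    # experts_per_tok experts are active per token.
    ff = n_blocks * 3 * dmodel * dff * experts_per_tok
    return attn + ff

def mfu(tokens_per_sec, active_params, peak_flops):
    return 6 * active_params * tokens_per_sec / peak_flops

params = transformer_active_params(
    dmodel=1024, dff=2816, n_blocks=16, q_heads=16, kv_heads=16, dhead=64)
# e.g. with an assumed 10k tokens/s on a 312 TFLOP/s (bf16) A100:
print(f"active params: {params / 1e6:.0f}M, MFU: {mfu(10_000, params, 312e12):.1%}")
```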

Correctness

I compared the dense model (settings as above) with MoE for E = {1, 2, 4, 16}. The results look reasonable: E=1 matches dense, and models improve with more experts.

Link to verification experiments: wandb project

@kuba-krj kuba-krj marked this pull request as ready for review March 28, 2026 19:14
@kuba-krj kuba-krj requested a review from j321m March 28, 2026 19:14
Collaborator

@j321m j321m left a comment


Claude-assisted review.

I'd like to see a multi-GPU run (smoke test).

Please address the comments.

Collaborator


remove this file from PR

Collaborator


remove / rename

Contributor Author


Renamed

Comment thread src/core/moe.py Outdated
moe_router_z_loss_factor: float = 0.0,
activation_function: str = "swiglu",
init_scale: float = 1.0,
**_ignored_kwargs,
Collaborator


is **_ignored_kwargs necessary?

Contributor Author


As far as I understand our config system, it's necessary in order to keep MoE configs similar to how it's done in small_moe.yaml, where we just set

ff_layer_fn:
        _target_: src.core.moe.MoE

to use MoE (because we keep ff_layer_fn from the base config and only replace `_target_` in it). Please let me know if you'd prefer to change the config structure to something like `override /ff_layer@model.encoder.block_fn.ff_layer_fn: moe` - we could then get rid of **_ignored_kwargs.

Collaborator


I think it's better to have separate base yamls for dense and MoE than **kwargs.

Collaborator


override /ff_layer@model.encoder.block_fn.ff_layer_fn: moe
is also a good idea (even better, but may need some config refactoring)
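To make the two options concrete, a hypothetical sketch of what the override-based layout could look like (file paths and config-group names here are assumptions about the repo's hydra layout, not its actual structure):

```yaml
# configs/ff_layer/moe.yaml (hypothetical path)
_target_: src.core.moe.MoE
num_experts: 16
num_experts_per_tok: 1
capacity_factor: 1.25

# experiment config: swap the dense FF layer for MoE via a group override,
# so MoE no longer needs **_ignored_kwargs to swallow dense-only params
defaults:
  - override /ff_layer@model.encoder.block_fn.ff_layer_fn: moe
```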

Comment thread src/core/moe.py
router_logits = router_logits.to(dtype=torch.float32)
router_probs = F.softmax(router_logits, dim=-1)
# For each token, keep only the top-k experts and their routing probabilities
topk_probs, selected_experts = torch.topk(
Collaborator


question: should the routing weights sum to 1, when num_experts_per_tok > 1?

Contributor Author


I added the option to normalize
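The question above comes down to a numeric fact: after a softmax over all experts, the top-k probabilities alone do not sum to 1, so some implementations renormalize them. A pure-Python sketch of the idea (the actual module uses torch.softmax/torch.topk; this is just an illustration, not the repo's code):

```python
import math

def softmax(logits):
    # Numerically stable softmax over a list of router logits.
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    s = sum(exps)
    return [e / s for e in exps]

def topk_route(logits, k, normalize=False):
    # Pick the k most probable experts and their routing weights;
    # optionally renormalize the weights so they sum to 1.
    probs = softmax(logits)
    ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    experts = ranked[:k]
    weights = [probs[i] for i in experts]
    if normalize:
        total = sum(weights)
        weights = [w / total for w in weights]
    return experts, weights

_, w = topk_route([2.0, 1.0, 0.5, -1.0], k=2)
print(sum(w))   # less than 1 without normalization
_, w = topk_route([2.0, 1.0, 0.5, -1.0], k=2, normalize=True)
print(sum(w))   # sums to 1 after normalization
```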

Comment thread src/core/trainer.py Outdated
@@ -228,11 +312,19 @@ def _update_processed_tokens(self, batch):

def log_metrics(self, loss, grad_norm):
Collaborator


The log_metrics function ignores its loss argument. Also, I'm not sure I like self._last_reported_loss, self._last_moe_router_z_loss, etc.; it makes the code more error-prone in my opinion, but I'm open to discussion.

Contributor Author


Refactored

Comment thread src/core/trainer.py Outdated
self.metric_logger.set_tokens(self.processed_tokens)
self.metric_logger.log("train/loss", loss.item())
self.metric_logger.log("train/loss", self._last_reported_loss.item())
self.metric_logger.log(
Collaborator


MoE metrics will get logged even for dense models, do we want that?

Contributor Author


Removed MoE metric logging for dense.

Comment thread src/core/trainer.py Outdated
Collaborator


The eval function calls self.calculate_loss(batch), which overwrites self._last_reported_loss, resulting in the same eval and train loss.

It only works now because log_metrics gets called after eval.

Contributor Author


Refactored

Comment thread src/core/moe.py Outdated
self.moe_load_balancing_loss_factor = moe_load_balancing_loss_factor
self.moe_router_z_loss_factor = moe_router_z_loss_factor
self.is_moe = True
self.aux_loss = None
Collaborator


what is aux_loss for? it is set to the same value as load_balancing_loss

Contributor Author


removed
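For reference on the two auxiliary losses discussed in this thread, these are the standard formulations (Switch-Transformer-style load balancing, ST-MoE-style router z-loss). The pure-Python sketch below is a generic illustration, not necessarily identical to what src/core/moe.py computes:

```python
import math

def load_balancing_loss(router_probs, selected_experts, num_experts):
    # num_experts * sum_e f_e * P_e, where f_e is the fraction of tokens
    # routed to expert e and P_e is the mean router probability for e.
    # Minimized (value 1.0) when routing is perfectly uniform.
    n_tokens = len(router_probs)
    f = [0.0] * num_experts
    p = [0.0] * num_experts
    for probs, expert in zip(router_probs, selected_experts):
        f[expert] += 1.0 / n_tokens
        for e in range(num_experts):
            p[e] += probs[e] / n_tokens
    return num_experts * sum(fe * pe for fe, pe in zip(f, p))

def router_z_loss(router_logits):
    # Mean over tokens of logsumexp(logits)^2; discourages the router
    # from producing very large logits.
    total = 0.0
    for logits in router_logits:
        m = max(logits)
        lse = m + math.log(sum(math.exp(x - m) for x in logits))
        total += lse * lse
    return total / len(router_logits)
```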

@kuba-krj
Contributor Author

kuba-krj commented Apr 2, 2026

I'd like to see multigpu run (smoke test)

Link to the multi-GPU run: wandb

Contributor Author

@kuba-krj kuba-krj left a comment


Added the changes and ran a test identical to the previous 2-GPU run to check that the results are unchanged: wandb link. Please let me know if the PR looks good now or if additional changes are needed.

@j321m
Collaborator

j321m commented Apr 10, 2026

Removing **kwargs is very important; separate dense and MoE config lines should solve the problem.

Collaborator

@j321m j321m left a comment


awesome

@kuba-krj kuba-krj merged commit d12c8f0 into main Apr 14, 2026
1 check passed