|
| 1 | +--- |
| 2 | +title: Auxiliary-loss Load Balancing in MoEs (1) |
| 3 | +date: 2025-07-06 |
| 4 | +tags: [cs, ai, notes] |
| 5 | +author: R |
| 6 | +location: Above Ilulissat, Greenland while on a plane from New York to Hong Kong |
| 7 | +--- |
| 8 | + |
| 9 | +I'm currently reading [this conference paper for ICLR 2025](https://arxiv.org/pdf/2408.15664) (Wang et al., 2025) as I'm preparing for my internship, but after going through the intro I'd like to take down some notes, as there are a lot of ideas and lessons that I've learned while reading it. Italics in this note are directly quoted from the paper. |
| 10 | + |
| 11 | +## MoEs |
| 12 | +After opening the paper I encountered the concept of MoEs. To get myself more familiar, I read [this blog on Hugging Face](https://huggingface.co/blog/moe) (Sanseviero et al., 2023), which was really helpful—highly recommended. MoE stands for **M**ixture **o**f **E**xperts; a famous example of its type is DeepSeek. It has many advantages, as the authors wrote: easy to scale to a large number of parameters, manageable costs, etc. |
| 13 | + |
| 14 | +*Let $u_t$ denote the input of the $t$-th token to an $N$-expert MoE layer, the output $h_t$ is computed as follows:* |
| 15 | +Let |
| 16 | +- $N$ be the number of experts, |
| 17 | +- $K$ the number of experts selected per token, |
| 18 | +- $T$ the total number of tokens in the batch, |
| 19 | +- $\mathbf{u}_t\in\R^d$ the input for token $t$, |
| 20 | +- $\text{FFN}_i: \R^d\to\R^d$ the $i$-th expert network, |
| 21 | +- $e_i\in\R^d$ the centroid (a learned parameter vector) of expert $i$, |
| 22 | +- $G\colon\R\to\R_{>0}$ a positive gating function (e.g. $\exp$ or $\text{sigmoid}$; $\text{softmax}$ also fits if it is applied across the $N$ logits of a token rather than to each logit on its own), |
| 23 | +- $s_{i,t}$ the *raw gating score* for expert $i$ on token $t$, obtained by applying $G$ to the dot product of input $\mathbf{u}_t$ and expert centroid $e_i$, and |
| 24 | +- $g_{i,t}$ the *pruned gating weight*: it equals $s_{i,t}$ if $s_{i,t}$ ranks among the top-$K$ scores for token $t$, and zero otherwise. |
| 25 | + |
| 26 | +Compute for each token $t$ and expert $i$: |
| 27 | +$$ |
| 28 | +\begin{align*} |
| 29 | + g_{i,t} &= |
| 30 | + \begin{cases} |
| 31 | + s_{i,t}, & s_{i,t} \in \text{Topk}(\{s_{j,t} | 1 \leq j \leq N\} , K)\\\\ |
| 32 | + & (\text{if $s_{i,t}$ is among the top-$K$ scores})\\\\ |
| 33 | + 0, & \text{otherwise} |
| 34 | + \end{cases} \\\\ |
| 35 | + s_{i,t} &= G(\textbf{u}_t^\top e_i) |
| 36 | +\end{align*} |
| 37 | +$$ |
| 38 | + |
| 39 | +and form the layer output |
| 40 | + |
| 41 | +$$ |
| 42 | +\textbf{h}_t = \textbf{u}_t + \sum^N_{i=1} g_{i,t} \text{FFN}_i (\textbf{u}_t) |
| 43 | +$$ |
| 44 | + |
| 45 | +So $G$ can be any positive function $\R \to \R_{>0}$. Conventional choices are $\exp$, softmax, and sigmoid (to be honest, I had to look up exactly what softmax and sigmoid are). The paper uses the latter two. |
| 46 | + |
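| 47 | +So the gating function decides which experts are consulted: each token is processed only by its top-$K$ experts, because $g_{i,t} = 0$ removes every other expert from the sum. |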
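To make the definitions concrete for myself, here is a minimal pure-Python sketch of one MoE layer for a single token. Everything in it is an assumption for illustration: the helper names (`topk_gates`, `moe_layer`), the toy experts that just scale their input, the hand-picked centroids, and sigmoid as the choice of $G$. It is not the paper's implementation.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def topk_gates(u, centroids, k, gate=sigmoid):
    """Return (scores, gates) for one token u: s_{i,t} and the pruned g_{i,t}."""
    scores = [gate(dot(u, e)) for e in centroids]
    # Indices of the k largest scores; sorted() is stable, so ties go to the lower index.
    top = set(sorted(range(len(scores)), key=lambda i: -scores[i])[:k])
    gates = [scores[i] if i in top else 0.0 for i in range(len(scores))]
    return scores, gates


def moe_layer(u, centroids, experts, k):
    """h_t = u_t + sum_i g_{i,t} * FFN_i(u_t), evaluating only the selected experts."""
    _, gates = topk_gates(u, centroids, k)
    h = list(u)  # residual connection
    for g, ffn in zip(gates, experts):
        if g > 0.0:  # experts outside the top-k are never run
            out = ffn(u)
            h = [hi + g * oi for hi, oi in zip(h, out)]
    return h


# Toy setup (hypothetical): 4 experts in 2 dimensions, expert i scales its input by i + 1.
centroids = [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0], [0.0, -1.0]]
experts = [lambda u, c=c: [c * x for x in u] for c in (1.0, 2.0, 3.0, 4.0)]

u = [2.0, 1.0]
scores, gates = topk_gates(u, centroids, k=2)
h = moe_layer(u, centroids, experts, k=2)
```

With $K = 2$, only the two experts whose centroids align best with `u` get a nonzero gate; the other two `FFN`s are skipped entirely, which is where the cost savings come from.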
| 48 | + |
| 49 | +## Problem: Imbalanced routing |
| 50 | +But one problem MoEs often experience is imbalanced routing (a small number of experts receive most tokens), thus creating *a risk of routing collapse (Shazeer et al., 2017), where the model consistently selects only a few experts, hindering sufficient training of the other experts*, or a *computational bottleneck due to load imbalance*. |
| 51 | + |
| 52 | +At first I wondered how this could cause a computational bottleneck, since I assumed the work could simply be scaled out through parallelism or some other way. That turns out not to be easily achievable: different machines host different experts, so throughput depends on how much load each individual expert receives, and the busiest one holds everyone else up. |
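A toy calculation helped me see this. The sketch below assumes one expert per device and a cost that is linear in the number of tokens; both are simplifications of mine, and the token counts are made up.

```python
def step_time(tokens_per_expert, seconds_per_token=1.0):
    """Experts run in parallel on separate devices, so a step waits for the busiest one."""
    return max(tokens_per_expert) * seconds_per_token


def utilization(tokens_per_expert):
    """Fraction of total device time spent on useful work during one step."""
    busiest = max(tokens_per_expert)
    return sum(tokens_per_expert) / (len(tokens_per_expert) * busiest)


balanced = [256, 256, 256, 256]  # 1024 tokens spread evenly
skewed = [700, 200, 100, 24]     # the same 1024 tokens, routed unevenly

balanced_time = step_time(balanced)
skewed_time = step_time(skewed)
skewed_util = utilization(skewed)
```

Under these assumptions the same 1024 tokens take 700 time units instead of 256 when the routing is skewed, and most of the devices sit idle for most of the step.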
| 53 | + |
| 54 | +Plus, the training loop would need a substantial redesign to use the idle computational power to catch up. Even if I create replicas for the "hot" experts on more hosts, they need to be in sync, which creates a lot of cost by itself. Merging gradients across replicas requires collective operations every step; at that point it will just recreate the original problem we’re trying to overcome if one of these slows down... |
| 55 | + |
| 56 | +### Solution: Auxiliary-loss |
| 57 | +To address this issue, an auxiliary loss is added to the training objective to encourage balanced load and thus avoid imbalanced routing when training MoEs. It does this by penalizing routing that concentrates on only a few experts. The loss is computed from the gating scores, so it acts mostly on the gating function (the router). |
| 58 | + |
| 59 | +**Key variables:** |
| 60 | +- $N$: number of experts in the MoE layer |
| 61 | +- $K$: number of experts selected per token (top-K) |
| 62 | +- $T$: total number of tokens in the batch |
| 63 | +- $\mathbb{1}$: indicator function (equals 1 if condition is true, 0 otherwise) |
| 64 | +- $\alpha$: balancing‐loss weight (manually set hyperparameter) |
| 65 | + |
| 66 | +Defined as such: |
| 67 | + |
| 68 | +- **Normalized load** |
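| 69 | +  $f_i$ := the fraction of the $KT$ token-to-expert assignments that go to expert $i$, scaled by $N$ so that perfectly uniform routing gives $f_i = 1$: |
| 70 | +  $$ |
| 71 | +  f_i = \frac{N}{KT} \sum_{t=1}^T \mathbb{1} \bigl(s_{i,t} \in \text{Topk}(\{s_{j,t} | 1 \leq j \leq N\} , K)\bigr) |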
| 72 | + $$ |
| 73 | + |
| 74 | +- **Average gating weight** |
| 75 | + $P_i$:= the mean score assigned by the gate to expert $i$: |
| 76 | + $$ |
| 77 | + P_i = \frac{1}{T} \sum_{t=1}^T s_{i,t} |
| 78 | + $$ |
| 79 | + |
| 80 | +Combine these into a single penalty term: |
| 81 | + |
| 82 | +$$\mathcal{L}_{\mathrm{balance}} = \alpha \sum_{i=1}^N f_i P_i$$ |
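Here is a small pure-Python sketch of $f_i$, $P_i$, and $\mathcal{L}_{\mathrm{balance}}$, computed from a table of gating scores. The function name `balance_loss`, the score tables, and the value $\alpha = 0.01$ are all made up for illustration.

```python
def balance_loss(scores, k, alpha):
    """scores[t][i] holds s_{i,t}. Returns (f, P, loss) as defined above."""
    T, N = len(scores), len(scores[0])
    counts = [0] * N
    for row in scores:
        # Experts selected for this token: the k largest scores.
        top = sorted(range(N), key=lambda i: -row[i])[:k]
        for i in top:
            counts[i] += 1
    f = [N * c / (k * T) for c in counts]                      # normalized load
    P = [sum(row[i] for row in scores) / T for i in range(N)]  # average gating score
    loss = alpha * sum(fi * pi for fi, pi in zip(f, P))
    return f, P, loss


# Hypothetical softmax-like scores for T = 4 tokens and N = 4 experts, with K = 1.
balanced = [
    [0.7, 0.1, 0.1, 0.1],
    [0.1, 0.7, 0.1, 0.1],
    [0.1, 0.1, 0.7, 0.1],
    [0.1, 0.1, 0.1, 0.7],
]
collapsed = [[0.7, 0.1, 0.1, 0.1] for _ in range(4)]  # every token picks expert 0

f_bal, P_bal, loss_bal = balance_loss(balanced, k=1, alpha=0.01)
f_col, P_col, loss_col = balance_loss(collapsed, k=1, alpha=0.01)
```

In the balanced table every $f_i$ is 1 and every $P_i$ is 0.25, so the sum $\sum_i f_i P_i$ is 1. In the collapsed table expert 0 has $f_0 = 4$ and $P_0 = 0.7$, so the sum is 2.8 and the penalty is larger.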
| 83 | + |
| 84 | + |
| 85 | +**Regularization terms:** |
| 86 | +Introduce two small-weight penalties on the imbalance of $\{P_i\}$ and $\{f_i\}$: |
| 87 | + |
| 88 | +\begin{align*} |
| 89 | +    \mathcal{L}_P &= \lambda_P \operatorname{CV}^2(\{P_i\}) \\\\ |
| 90 | +    \mathcal{L}_f &= \lambda_f \operatorname{CV}^2(\{f_i\}) |
| 91 | +\end{align*} |
| 92 | + |
| 93 | +where typically $\lambda_{P} \approx \lambda_{f} \sim 10^{-2}$. |
| 94 | + |
| 95 | +> These terms are actually optional; for a simpler setup, just use $\mathcal{L}_{\text{total}} = \mathcal{L}_{\text{task}} + \mathcal{L}_{\text{balance}}$. |
| 96 | +> I write it this way just to follow the [original MoE auxiliary-loss formulation paper (Shazeer et al. (2017))](https://arxiv.org/pdf/1701.06538). |
| 97 | + |
| 98 | + |
| 99 | +**Imbalance metric: coefficient of variation squared** |
| 100 | +For any set of scalars $\{z_i\}_{i=1}^N$, define |
| 101 | + |
| 102 | +$$ |
| 103 | + \text{CV}^2(\{z_i\}) = |
| 104 | + \frac{\frac{1}{N} \sum_{i=1}^N z_i^2 - (\frac{1}{N} \sum_{i=1}^N z_i )^2} |
| 105 | + {(\frac{1}{N} \sum_{i=1}^N z_i )^2}, |
| 106 | +$$ |
| 107 | + |
| 108 | +which satisfies $\text{CV}^2=0$ exactly when all $z_i$ are equal. |
| 109 | + |
| 110 | +> By the way, this looks very much like the variance. |
| 111 | +> Write $\mu = \tfrac1N \sum_i z_i$ and $\nu = \tfrac1N \sum_i z_i^2$. Then |
| 112 | +> $$ |
| 113 | +> \text{CV}^2 = \frac{\nu - \mu^2}{\mu^2} = \frac{\text{Var}}{(\text{Mean})^2} |
| 114 | +> $$ |
| 115 | +> Its partial derivative with respect to one coordinate $z_k$ is |
| 116 | +> $$ |
| 117 | +> \frac{\partial \text{CV}^2}{\partial z_k} |
| 118 | +> = \frac{2}{N}\Bigl(\frac{z_k}{\mu^2} - \frac{\nu}{\mu^3}\Bigr). |
| 119 | +> $$ |
| 120 | +> > Details: |
| 121 | +\begin{align*} |
| 122 | +    \frac{\partial}{\partial z_k} \Bigl(\tfrac{\nu - \mu^2}{\mu^2}\Bigr) |
| 123 | +    &= \frac{\partial}{\partial z_k} \Bigl(\frac{\nu}{\mu^2} - 1\Bigr) = \frac{1}{\mu^2} \frac{\partial\nu}{\partial z_k} - \frac{2\nu}{\mu^3} \frac{\partial\mu}{\partial z_k}\\\\ |
| 124 | +    &= \frac{1}{\mu^2} \frac{2z_k}{N} - \frac{2\nu}{\mu^3} \frac{1}{N}\\\\ |
| 125 | +    &= \frac{2}{N}\Bigl(\frac{z_k}{\mu^2} - \frac{\nu}{\mu^3}\Bigr). |
| 126 | +\end{align*} |
| 127 | +> |
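| 128 | +> Because $\nu/\mu^3$ is the same constant for all $k$, the gradient is positive exactly when $z_k > \nu/\mu$ and negative when $z_k < \nu/\mu$ (for $\mu > 0$), so gradient descent pushes down the most heavily loaded experts and pushes up the rest. Since $\nu/\mu = \mu(1 + \text{CV}^2)$, this threshold sits close to the mean $\mu$ once the load is nearly balanced. In other words, this is the derivative of a variance term normalized by $\mu^2$. |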
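As a sanity check on the derivative, here is a short pure-Python sketch that compares the closed form against a central finite difference. The function names and the sample vector `z` are my own, chosen only for illustration.

```python
def cv_squared(z):
    """CV^2 = (nu - mu^2) / mu^2, with mu the mean and nu the mean of squares."""
    n = len(z)
    mu = sum(z) / n
    nu = sum(x * x for x in z) / n
    return (nu - mu * mu) / (mu * mu)


def cv_squared_grad(z):
    """Closed form: dCV^2/dz_k = (2/N) * (z_k / mu^2 - nu / mu^3)."""
    n = len(z)
    mu = sum(z) / n
    nu = sum(x * x for x in z) / n
    return [(2.0 / n) * (x / mu**2 - nu / mu**3) for x in z]


def finite_diff_grad(fn, z, eps=1e-6):
    """Central finite-difference approximation of the gradient of fn at z."""
    grad = []
    for k in range(len(z)):
        up = list(z)
        up[k] += eps
        down = list(z)
        down[k] -= eps
        grad.append((fn(up) - fn(down)) / (2 * eps))
    return grad


z = [3.0, 1.0, 0.5, 0.5]  # one heavily used expert and three lightly used ones
analytic = cv_squared_grad(z)
numeric = finite_diff_grad(cv_squared, z)

# One small gradient-descent step should reduce the imbalance.
lr = 0.1
z_next = [x - lr * g for x, g in zip(z, analytic)]
```

The two gradients agree to within finite-difference error, and `cv_squared(z_next)` comes out smaller than `cv_squared(z)`.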
| 129 | + |
| 130 | + |
| 131 | +**Total training objective** |
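| 132 | +Combine with the primary task loss $\mathcal{L}_{\text{task}}$ (dropping $\mathcal{L}_{P}$ and $\mathcal{L}_{f}$ gives the simpler variant): |
| 133 | +$$\mathcal{L}_{\text{total}} = \mathcal{L}_{\text{task}} + \mathcal{L}_{\text{balance}} + \mathcal{L}_{P} + \mathcal{L}_{f}.$$ |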
| 134 | + |
| 135 | +#### Intuition |
| 136 | +- The penalty for expert $i$ is the product $f_i P_i$, so it grows when either factor grows. $f_i$ is built from indicator functions and carries no gradient, so backpropagation reaches the router parameters through $P_i$: the gradient of $\mathcal{L}_{\mathrm{balance}}$ with respect to $s_{i,t}$ is $\alpha f_i / T$, which pushes scores down hardest for the experts that currently receive the most tokens. This drives the routing distribution toward uniformity. |
| 137 | +- Minimizing $\text{CV}^2$ drives the variance of $\{P_i\}$ or $\{f_i\}$ toward zero relative to their mean (see derivation of $\partial \text{CV}^2/\partial z_k$ above). |
| 138 | +- Any expert $i$ with above-average usage widens the spread of $\{P_i\}$ or $\{f_i\}$, which increases the penalty. |
| 139 | + |
| 140 | + |
| 141 | +#### Drawbacks |
| 142 | +The ICLR 2025 paper argues that the auxiliary loss introduces unwanted gradients that interfere with the language-modeling objective, which would explain why MoE models trained with it perform worse on some metrics. |
| 143 | + |
| 144 | +However, I wasn't really convinced by this reasoning. The improvement in validation perplexity was not that large (I was expecting a bigger gap), and although there are plenty of other models they could have chosen, they picked a small one. The load-balance results look fine, though, and that's the main point of the paper, so it's good. |
| 145 | + |
| 146 | +The true drawback, in my opinion, comes with the act of rebalancing through auxiliary loss itself. |
| 147 | +- The idea of MoE is having many highly specialized experts; auxiliary loss fights any concentration of weight, even if that concentration was beneficial for modeling those tokens. |
| 148 | +- The balancing gradient for an expert involves all experts' totals, so the update to the logit for one expert depends on every other expert’s load. That coupling can drown out more specialized signals. |
| 149 | +- Naturally, experts that are good at certain tokens are expected to get those; trying to make the router equalize loads regardless of quality can route a token to a weaker expert, simply because the "best" expert is already slightly busier. |
| 150 | + |
| 151 | +(TBC) |