Theory of MoE

Basic Formula Definitions

For a vector $w$, let $\|w\|_2$ and $\|w\|_\infty$ denote its $\ell_2$ norm and $\ell_\infty$ norm, respectively.

Given positive constants $c_1, c_2$, we define:

  • $x = \Omega(y)$, if $x > c_2 |y|$;
  • $x = \Theta(y)$, if $c_1 |y| < x < c_2 |y|$;
  • $x = O(y)$, if $x < c_1 |y|$;
  • $x = o(y)$, if $\dfrac{x}{y} \to 0$.

Where:

  • $O(y)$: upper bound, meaning "grows no faster than $y$".
  • $\Omega(y)$: lower bound, meaning "grows at least as fast as $y$".
  • $\Theta(y)$: both upper and lower bounds are on the order of $y$, meaning "same order as $y$".
  • $o(y)$: strictly much smaller than $y$, ultimately approaching $0$.

Key Assumptions:

  1. This paper aims only to derive closed-form forgetting formulas, so it simplifies directly to a linear model: $f(X) = X^{\top} w,\; w \in \mathbb{R}^d$

  2. This paper only discusses task-wise routing methods. During data generation, each sample contains only one signal, with all other entries as Gaussian noise. This is again for model simplification. In practical engineering, tokens are implicitly routed to various experts rather than using manually specified routing.

Dataset Generation Rules

At each training round $t \in [T]$, when a new task $n_t$ arrives, the dataset $\mathcal{D}_t = (X_t, y_t)$ is generated as follows:

  1. Sample the ground truth vector for the task
    • Uniformly sample a ground truth vector $w_{n_t}$ from the task pool $\mathcal{W} = \{w_1, \dots, w_N\}$, and set $w_{n_t}$ as the ground truth for the current task.
  2. Generate the scaling coefficient
    • Independently sample a random variable $\beta_t \in (0, C)$, where $C = O(1)$.
  3. Construct the input feature matrix $X_t$
    • The matrix consists of $s_t$ samples:
      • One sample is defined as $\beta_t v_{n_t}$, where $v_{n_t}$ is the feature signal of task $n_t$.
      • The remaining $s_t - 1$ samples are drawn from a Gaussian distribution $\mathcal{N}(0, \sigma_t^2 I_d)$, where $\sigma_t \ge 0$ is the noise level.
  4. Generate the output labels $y_t$
    • Via linear regression: $y_t = X_t^\top w_{n_t}$

Result:
Dataset $\mathcal{D}_t = (X_t, y_t)$, corresponding to a linear regression task.
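The four generation steps above can be sketched in numpy. This is a minimal illustration; the function and argument names (`generate_task_dataset`, `v_signal`, etc.) are mine, not from the paper:

```python
import numpy as np

def generate_task_dataset(w_gt, v_signal, s_t, sigma_t, C=1.0, rng=None):
    """Generate one round's dataset D_t = (X_t, y_t) following the rules above.

    w_gt:     ground-truth vector w_{n_t}, shape (d,)
    v_signal: feature signal v_{n_t} of the task, shape (d,)
    s_t:      number of samples; sigma_t: noise level.
    """
    rng = np.random.default_rng() if rng is None else rng
    d = w_gt.shape[0]
    beta_t = rng.uniform(0.0, C)                      # step 2: scaling coefficient in (0, C)
    signal = beta_t * v_signal                        # step 3: one signal sample
    noise = rng.normal(0.0, sigma_t, (s_t - 1, d))    # step 3: s_t - 1 Gaussian noise samples
    X_t = np.vstack([signal, noise]).T                # d x s_t, one sample per column
    y_t = X_t.T @ w_gt                                # step 4: y_t = X_t^T w_{n_t}
    return X_t, y_t
```

Note the convention implied by $y_t = X_t^\top w_{n_t}$: samples are columns of $X_t$, so $X_t$ is $d \times s_t$.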

  • This paper uses Top-1 expert selection only.

Formula Theory:

Expert parameter update: When the router selects a particular expert, all other experts remain unchanged; only the selected expert is updated, according to the following formula:

$w_t^{(m_t)} = w_{t-1}^{(m_t)} + X_t (X_t^\top X_t)^{-1}(y_t - X_t^\top w_{t-1}^{(m_t)})$

Derivation of the Expert Parameter Update Formula

Objective: At round $t$, expert $m_t$ must fit the task dataset $(X_t, y_t)$:
$\min_{w}\ \|X_t^\top w - y_t\|_2^2$

Problem: Under overparameterization ($s_t < d$), the solution is non-unique; directly computing the least-squares solution discards historical information.
Therefore, the paper reformulates it as a constrained optimization:

$\min_w \ \|w - w_{t-1}^{(m_t)}\|_2^2 \quad \text{s.t.}\ \ X_t^\top w = y_t$

Solution: Using Lagrange multipliers or residual projection, the update is:

$w_t^{(m_t)} = w_{t-1}^{(m_t)} + X_t (X_t^\top X_t)^{-1}\,(y_t - X_t^\top w_{t-1}^{(m_t)})$

Interpretation:

  • $(y_t - X_t^\top w_{t-1})$ = residual = true output − old prediction
  • $X_t (X_t^\top X_t)^{-1}$ = the correction term that projects the residual back into parameter space
  • The entire expression = a least-squares correction near the old parameters

Properties:

  • Guarantees $X_t^\top w_t = y_t$ → the new parameters perfectly fit the current task
  • Stays as close as possible to $w_{t-1}$ → minimizes catastrophic forgetting
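A minimal numpy sketch of this update, treating $X_t$ as a $d \times s_t$ matrix with linearly independent columns so the Gram matrix is invertible (the function name is mine):

```python
import numpy as np

def expert_update(w_prev, X_t, y_t):
    """Projection update: among all w with X_t^T w = y_t, return the one
    closest to w_prev (minimum ||w - w_prev||_2)."""
    gram = X_t.T @ X_t                   # s_t x s_t Gram matrix
    residual = y_t - X_t.T @ w_prev      # y_t - X_t^T w_{t-1}
    return w_prev + X_t @ np.linalg.solve(gram, residual)
```

After the update, $X_t^\top w_t = y_t$ holds exactly, and the change $w_t - w_{t-1}$ lies entirely in the column space of $X_t$, which is why directions orthogonal to the current task's data are left untouched.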

Auxiliary loss (also commonly referred to as the load-balancing loss):

$L_t^{\text{aux}}(\Theta_t, \mathcal{D}_t) = \alpha \cdot M \cdot \sum_{m \in [M]} f_t^{(m)} \cdot P_t^{(m)}$

Auxiliary Loss

Parameter explanation

  • $\alpha$: weighting coefficient, controls the proportion of the auxiliary loss in the total loss
  • $M$: number of experts
  • $f_t^{(m)}$: frequency with which expert $m$ has been selected in the first $t$ rounds (historical usage)
  • $P_t^{(m)}$: average routing probability assigned to expert $m$ by the router at round $t$

Purpose

  • Penalizes experts that have been frequently used historically and are still assigned high probability in the current round
  • Encourages the router to make greater use of underutilized experts
  • Achieves load balancing to prevent experts from being over- or under-used
  • The trailing term is intuitively clear: when an expert $m$ has been used many times historically and is still assigned large logits in the current round, this loss term becomes very large, suppressing the router's preference for a few experts and thus preventing routing collapse
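A small numerical sketch of why this term discourages collapse (helper name and the toy values $M = 4$, $\alpha = 0.01$ are mine):

```python
import numpy as np

def aux_loss(alpha, f, P):
    """L_aux = alpha * M * sum_m f^(m) * P^(m)."""
    M = len(f)
    return alpha * M * float(np.dot(f, P))

M, alpha = 4, 0.01
balanced = np.full(M, 1.0 / M)   # uniform usage and routing
collapsed = np.eye(M)[0]         # one expert takes everything

loss_bal = aux_loss(alpha, balanced, balanced)    # alpha * M * M * (1/M)^2 = alpha
loss_col = aux_loss(alpha, collapsed, collapsed)  # alpha * M * 1 = alpha * M
```

Collapsed routing pays a loss $M$ times larger than perfectly balanced routing, so the gradient pushes probability mass toward underused experts.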

Locality loss:

$L_t^{\text{loc}}(\Theta_t, \mathcal{D}_t) = \sum_{m \in [M]} \pi_m(X_t,\Theta_t)\, \|w_t^{(m)} - w_{t-1}^{(m)}\|_2$

Locality Loss

Parameter explanation

  • $\pi_m(X_t,\Theta_t)$: probability assigned to expert $m$ by the router (softmax output)
  • $w_t^{(m)}$: parameters of expert $m$ under the current task
  • $w_{t-1}^{(m)}$: parameters of expert $m$ from the previous round

Purpose

  • Constrains expert parameter updates from deviating too far from historical values
  • Encourages similar tasks to be routed to the same expert, thereby reducing loss
  • Reduces forgetting (updates for new tasks do not completely overwrite old knowledge)
  • Improves expert specialization: each expert gradually stabilizes on a particular type of task
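The locality loss is straightforward to compute; a minimal sketch (function name is mine):

```python
import numpy as np

def locality_loss(pi, W_t, W_prev):
    """L_loc = sum_m pi_m * ||w_t^(m) - w_{t-1}^(m)||_2."""
    return float(sum(p * np.linalg.norm(wt - wp)
                     for p, wt, wp in zip(pi, W_t, W_prev)))
```

Because each parameter drift is weighted by the routing probability, an expert only pays for moving its parameters when the router actually sends the task to it, which nudges similar tasks toward the expert whose parameters already fit them.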

Training loss:

$L_t^{\text{tr}}(w_t^{(m_t)}, \mathcal{D}_t) = \frac{1}{s_t}\,\|X_t^\top w_t^{(m_t)} - y_t\|_2^2$

Training Loss

Parameter explanation

  • $s_t$: number of data samples for the current task
  • $X_t$: feature matrix
  • $y_t$: output label vector
  • $w_t^{(m_t)}$: parameters of the expert selected at round $t$

Purpose

  • Essentially the mean squared error (MSE) of least-squares regression
  • Makes the selected expert fit the current task data
  • Ensures the expert can capture the true signal (ground truth) of the task
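As a sanity check, the training loss is plain per-sample MSE over the $s_t$ columns of $X_t$ (helper name is mine):

```python
import numpy as np

def train_loss(w, X_t, y_t):
    """L_tr = (1 / s_t) * ||X_t^T w - y_t||_2^2, with samples as columns of X_t."""
    s_t = X_t.shape[1]
    return float(np.sum((X_t.T @ w - y_t) ** 2)) / s_t
```

Note that an expert updated by the constrained least-squares rule above satisfies $X_t^\top w_t = y_t$ exactly, so this term is driven to zero on the current task.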

Total loss:

$L_t^{\text{task}} = L_t^{\text{tr}} + L_t^{\text{loc}} + L_t^{\text{aux}}$

With the above total loss function, router parameter updates can be performed during training.

Router update formula:

$\theta_{t+1}^{(m)} = \theta_t^{(m)} - \eta \cdot \nabla_{\theta^{(m)}} L_t^{\text{task}}(\Theta_t, w_t^{(m_t)}, \mathcal{D}_t), \quad \forall m \in [M]$
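A toy sketch of one such gradient step. Real implementations backpropagate through $L_t^{\text{task}}$; the central finite-difference gradient below is only an illustrative stand-in, and the function name is mine:

```python
import numpy as np

def router_step(theta, loss_fn, eta=0.1, eps=1e-5):
    """One gradient-descent step: theta <- theta - eta * grad L(theta).

    loss_fn maps a parameter array to a scalar loss; the gradient is
    approximated by central finite differences."""
    grad = np.zeros_like(theta)
    for idx in np.ndindex(theta.shape):
        hi, lo = theta.copy(), theta.copy()
        hi[idx] += eps
        lo[idx] -= eps
        grad[idx] = (loss_fn(hi) - loss_fn(lo)) / (2 * eps)
    return theta - eta * grad
```

Applied to any smooth surrogate loss, one step moves the router parameters downhill, which is all the update formula asserts.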

Tricks:

Early Termination

In continual learning (CL) scenarios, if the gating network continues to update indefinitely, the allocation probabilities across different experts may gradually converge as more tasks arrive, eventually causing expert differentiation to collapse and routing errors. To address this, an Early Termination mechanism must be introduced.

  • Core Idea
    After sufficient rounds of task exploration ($T_1$ rounds), the expert assignments in MoE should gradually converge. Continuing to train the gating network at this point no longer yields benefits and instead leads to overfitting and blurring of task boundaries. Therefore, at an appropriate time, updates to the router parameters $\Theta_t$ should be terminated to maintain the stability of expert assignments.

  • Convergence Criterion
    Define a convergence indicator $I^{(m)}$ to measure whether expert $m$ has converged:

    $I^{(m)} = \big| h_m(X_t, \theta_t) - h_{m_t}(X_t, \theta_t) \big|$

    where $h_m(X_t,\theta_t)$ denotes the gating output of expert $m$ on the current input, and $h_{m_t}(X_t,\theta_t)$ denotes the output of the expert actually selected by the router.

    • If this gap is larger than the threshold $\Gamma$, expert $m$ has not yet converged and $\Theta_t$ should continue to be updated.
    • If this gap is smaller than the threshold $\Gamma$, the gating network is considered converged and updates to $\Theta_t$ are stopped.
    • This prevents the router from continuing to update after convergence, which would otherwise destroy expert assignments. It also ensures that different experts stably serve their respective task clusters. Combined with the constraints of $L^{loc}$ and $L^{aux}$, the early termination mechanism enables the system to maintain balance and low forgetting in CL environments over the long term.
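The stopping rule stated above can be sketched as follows, mirroring the criterion as written (function and argument names are mine):

```python
import numpy as np

def router_converged(h, m_t, gamma):
    """Return True when |h_m - h_{m_t}| < gamma for every expert m,
    i.e. when the stated criterion says to stop updating Theta.

    h:    gating outputs for all M experts on the current input
    m_t:  index of the expert selected by the router
    """
    gaps = np.abs(h - h[m_t])
    return bool(np.all(gaps < gamma))
```

In a training loop this check would gate the router update: once it returns True, $\Theta_t$ is frozen while the experts continue to be updated by the closed-form rule.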

Multiple Variants of Locality Loss

  • Parameter Locality
$L^{loc}_{param} = \sum_{m \in [M]} \pi_m(X_t,\Theta_t)\,\|w_t^{(m)} - w_{t-1}^{(m)}\|_2$
- The method used in the preceding sections.
- Ensures that the parameter differences for the same expert across adjacent tasks are not too large.
  • Representation Locality — Constraints can be applied directly to the representations (hidden states) output by each expert.

    - For example:
    $L^{loc}_{repr} = \sum_{m \in [M]} \pi_m(X_t,\Theta_t)\,\|f_m(X_t) - f_m(X_{t-1})\|_2$
    - Keeps similar inputs stable on the same expert.
  • Routing Locality — Constrains the router's assignment probabilities from jumping too drastically between tasks.

    - Of the form:
    $L^{loc}_{route} = \sum_{m \in [M]} \|\pi_m(X_t,\Theta_t) - \pi_m(X_{t-1},\Theta_{t-1})\|_2$
  • Task Embedding Locality

    • If task embeddings can be constructed (e.g., via meta-learning or contrastive learning), one can define:
      • Similar tasks → routed to the same expert
      • Dissimilar tasks → differentiated as much as possible
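The representation- and routing-locality variants above differ only in which quantity is held stable; a compact sketch (helper names are mine, and since each $\pi_m$ is a scalar probability, its "norm" reduces to an absolute value):

```python
import numpy as np

def repr_locality(pi, F_t, F_prev):
    """L_repr: keep each expert's representation of consecutive inputs stable.
    F_t[m] and F_prev[m] are expert m's hidden states f_m(X_t), f_m(X_{t-1})."""
    return float(sum(p * np.linalg.norm(a - b)
                     for p, a, b in zip(pi, F_t, F_prev)))

def route_locality(pi_t, pi_prev):
    """L_route: penalize jumps in routing probabilities between tasks."""
    return float(np.sum(np.abs(np.asarray(pi_t) - np.asarray(pi_prev))))
```

Parameter locality anchors the weights, representation locality anchors the outputs, and routing locality anchors the assignment itself; which one is appropriate depends on whether experts share inputs across rounds.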

Involution Hell © 2026 by Community, licensed under CC BY-NC-SA 4.0