
Polar Factor Beyond Newton-Schulz - Fast Matrix Inverse Square Root

A deep dive into computing the orthonormal polar factor (matrix sign function) for tall matrices using minimax polynomials, Jacobi preconditioning, and online certificates, moving beyond standard Newton-Schulz iterations.


The Muon optimizer has found huge empirical success in machine learning. It is essentially signSGD for matrices (or, once momentum is included, a matrix analogue of Lion). For the update, we need to approximate the sign function applied to the singular values of the momentum matrix, i.e., compute its polar factor.

Polar factor

Goal: Given \(G\in \mathbb{R}^{m \times n}\) tall (\(m \ge n\)), compute the (column-)orthonormal polar factor

\[ \mathrm{polar}(G):=G(G^\top G)^{-1/2} \]

For the compact SVD \(G=U\Sigma V^\top\), \(\mathrm{polar}(G)=UV^\top\). This is the “directional” component in the polar decomposition \(G=\mathrm{polar}(G) \vert G\vert \), similar to the polar coordinates of a complex number \(z=e^{i\theta}\cdot r\):

\[ \vert G\vert := \sqrt{ G^\top G } \quad \text{("stretch" part: modulus of matrix)} \]
\[ \mathrm{polar}(G)=G\vert G\vert ^{-1} \quad\text{("direction" part: unitary polar factor)} \]
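As a quick sanity check, the SVD characterization can be verified numerically. This is a plain fp64 NumPy sketch (a reference implementation, not the GPU path this post is about):

```python
import numpy as np

def polar_svd(G):
    """Exact polar factor via compact SVD: polar(G) = U V^T."""
    U, _, Vt = np.linalg.svd(G, full_matrices=False)
    return U @ Vt

rng = np.random.default_rng(0)
G = rng.standard_normal((128, 16))
P = polar_svd(G)

# polar(G) has orthonormal columns ...
ortho_err = np.linalg.norm(P.T @ P - np.eye(16))
# ... and agrees with G (G^T G)^{-1/2} computed via an eigendecomposition.
w, Q = np.linalg.eigh(G.T @ G)
ref = G @ (Q @ np.diag(w ** -0.5) @ Q.T)
match_err = np.linalg.norm(P - ref)
```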

In Muon, we typically do not need high accuracy, but we do want:

  1. a fast GPU path (mostly GEMMs),
  2. numerical stability in bf16,
  3. a way to certify that the singular values of the returned factor are close to \(1\).

Newton-Schulz / Polar Express iterations normalize the singular values into the unit interval \([0,1]\), then iterate directly on \(G\) with rectangular GEMMs.

Potential opportunity for \(m \gg n\): compute \((G^\top G)^{-1/2}\) on the small side and multiply once; the result can be refined with full polar steps. This opens up some nicer theoretical machinery, e.g. online coefficient scheduling from a precomputed table instead of the fixed offline coefficients of Polar Express.

Plan: Gram-side polar factor for tall gradients using minimax polynomials + Jacobi preconditioning + online selection

Goal

Given \(G \in \mathbb{R}^{m \times n}\) tall (\(m \ge n\)), compute the orthonormal polar factor

\[ \mathrm{polar}(G) := G(G^\top G)^{-1/2}. \]

We want a fast ML-friendly approximation that:

  • uses only 2 rectangular GEMMs (form \(B=G^\top G\), final multiply \(G\widetilde Z\)),
  • does the iterative work on small \(n \times n\) matrices,
  • is stable in bf16 (fp32 accumulate where needed),
  • provides an online certificate that singular values of the returned factor are close to \(1\).

Key idea. Gram-side inverse square root

Let

\[ B := G^\top G \in \mathbb{R}^{n \times n}. \]

Compute \(\widetilde Z \approx B^{-1/2}\) using only \(n\times n\) work, then output

\[ \widetilde U := G \widetilde Z. \]

This is the same structural win that Polar Express exploits for rectangular matrices: form a Gram matrix once, iterate on the small side, then do one final rectangular multiply. Polar Express formalizes this as “Fast Polynomial Iteration for Rectangular Matrices” (Algorithm 4) (Amsel et al., 2025).
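To make the cost structure concrete, here is a minimal NumPy sketch of the Gram-side pattern, with a dense eigendecomposition standing in for the small-side polynomial iteration developed below. The point is that only two operations touch the \(m \times n\) shape:

```python
import numpy as np

def polar_gram_side(G):
    """Gram-side pattern: 2 rectangular GEMMs plus n x n work in between.
    np.linalg.eigh stands in here for the polynomial iteration on B."""
    B = G.T @ G                 # rectangular GEMM 1 -> n x n
    w, Q = np.linalg.eigh(B)    # small-side work (placeholder for the iteration)
    Z = (Q * w ** -0.5) @ Q.T   # Z = B^{-1/2}
    return G @ Z                # rectangular GEMM 2

rng = np.random.default_rng(1)
G = rng.standard_normal((512, 32))
U = polar_gram_side(G)
ortho_err = np.linalg.norm(U.T @ U - np.eye(32))
```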

What we can certify online (stronger than rectangular direct iterations)

Define the Gram residual

\[ E := \widetilde U^\top \widetilde U - I = \widetilde Z^\top B \widetilde Z - I. \]

If \(\Vert E\Vert _2 \le \eta\), then

\[ \sqrt{1-\eta} \le \sigma_i(\widetilde U) \le \sqrt{1+\eta}. \]

Since \(\Vert E\Vert _2 \le \Vert E\Vert _F\), we can use the cheap sufficient check \(\Vert E\Vert _F \le \eta\) (all on \(n \times n\)). This gives a reliable online proxy for “how safe/aggressive can we be”.
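The certificate can be exercised directly. In this sketch, a deliberately rough \(\widetilde Z\) is produced by a few fixed Newton-Schulz steps so the residual is visibly nonzero; the interval bound on the singular values holds regardless of how accurate \(\widetilde Z\) is:

```python
import numpy as np

rng = np.random.default_rng(2)
G = rng.standard_normal((256, 16))
B = G.T @ G
I = np.eye(16)

# Deliberately rough Z: scale B into (0, 1], run only 3 Newton-Schulz steps.
Lam = np.max(np.sum(np.abs(B), axis=1))   # Gershgorin-style upper bound
A = B / Lam
Z = I.copy()
for _ in range(3):
    S = Z.T @ A @ Z
    Z = Z @ (1.5 * I - 0.5 * S)
Z = Z / np.sqrt(Lam)                      # undo scaling: A^{-1/2} -> B^{-1/2}

eta = np.linalg.norm(Z.T @ B @ Z - I, 'fro')   # cheap certificate ||E||_F
sv = np.linalg.svd(G @ Z, compute_uv=False)
lo, hi = np.sqrt(max(1.0 - eta, 0.0)), np.sqrt(1.0 + eta)
inside = bool(np.all((sv >= lo - 1e-12) & (sv <= hi + 1e-12)))
```

All certificate arithmetic is on \(n \times n\) matrices; only the final singular-value check (done here for illustration) would touch the tall shape.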

Why we do NOT use AOL here (Jacobi congruence on \(B\) instead)

Turbo-Muon’s AOL is a column scaling applied to \(G\) (so it changes the target to \(\mathrm{polar}(G S)\) and introduces bias) (Boissin et al., 2025). Since we are already working on the square SPD Gram matrix \(B\), we can get the same spectrum-improving benefit from an SPD congruence scaling:

\[ \widetilde B := D B D. \]

Because \(D\) and \(B\) need not commute, the inverse square root does not factor through the congruence exactly: \(D\,\widetilde B^{-1/2}\,D = B^{-1/2}\) only when \(DB = BD\). What the congruence does preserve exactly is the orthonormality certificate: if \(Z^\top \widetilde B Z = I\), then \(\widetilde Z := DZ\) satisfies \(\widetilde Z^\top B \widetilde Z = I\), so \(G\widetilde Z\) has exactly orthonormal columns. Conditioning improves, and any residual rotation relative to \(\mathrm{polar}(G)\) can be removed with a final polar polish when \(D\) is far from the identity.

Empirically, Jacobi scaling (unit-diagonal) is often the best simple choice:

\[ D := \mathrm{diag}(d), \qquad d_i = (B_{ii}+\epsilon)^{-1/2}. \]
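A small experiment illustrating why the unit-diagonal choice helps; the badly scaled columns here are synthetic:

```python
import numpy as np

rng = np.random.default_rng(3)
# Badly scaled columns -> Gram matrix with a wide spectrum.
scales = 10.0 ** rng.uniform(-2, 2, size=16)
G = rng.standard_normal((256, 16)) * scales
B = G.T @ G
eps = 1e-12

d = (np.diag(B) + eps) ** -0.5        # Jacobi scaling d_i = (B_ii + eps)^{-1/2}
B_t = d[:, None] * B * d[None, :]     # B_tilde_ij = d_i B_ij d_j (unit diagonal)

cond_before = np.linalg.cond(B)
cond_after = np.linalg.cond(B_t)
```

Note the scaling is applied elementwise; no \(n \times n\) GEMM is needed to form \(\widetilde B\).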

Stability rules. bf16-safe iterations

Polar Express identifies low-precision issues when iterating via Gram-side polynomial compositions (their Algorithm 4) and suggests:

  • add a ridge early to avoid spurious indefiniteness from roundoff,
  • restart compositions to avoid ill-conditioned intermediate factors (Amsel et al., 2025).

We adopt the same philosophy:

  • always symmetrize \(B\) and ridge it,
  • use restart blocks when composing aggressive polynomials,
  • do all small-side iteration in fp32 (or at least fp32 accumulate and residual checks).

Core iteration: minimax-polynomial inverse square root for SPD matrices

Template (“drive the Gram to \(I\)”)

We compute an inverse square root of an SPD matrix \(A\) by maintaining \(Z_k \approx A^{-1/2}\) and driving

\[ S_k := Z_k^\top A Z_k \to I. \]

Update:

\[ Z_{k+1} = Z_k\,q_k(S_k), \]

so eigenvalues evolve as

\[ \lambda \mapsto \lambda' = \lambda\,q_k(\lambda)^2. \]

This matches the standard Newton-style “matrix-multiplication only” inverse-root framework (no factorizations), e.g. in analyses of inverse \(p\)th-root iterations (Guo and Higham, 2006).
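To see the template in action on scalars, take the classical Newton-Schulz choice \(q(\lambda)=(3-\lambda)/2\) (the minimax \(q_k\) below replace this fixed polynomial with per-interval optimal ones) and iterate the eigenvalue map:

```python
import numpy as np

def q_ns(lam):
    """Fixed Newton-Schulz polynomial for the inverse square root."""
    return (3.0 - lam) / 2.0

# Under Z_{k+1} = Z_k q(S_k), each eigenvalue evolves as lam -> lam * q(lam)^2.
lam = np.array([0.05, 0.3, 0.9, 1.0])
for _ in range(8):
    lam = lam * q_ns(lam) ** 2
# All entries are driven toward the fixed point 1.
```

Near \(\lambda = 1\) the map contracts quadratically (\(1-e \mapsto 1 - \tfrac34 e^2 + O(e^3)\)), which is why the small eigenvalue \(0.05\) dominates the iteration count.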

Why minimax (Polar Express port)

Polar Express selects per-step polynomials using minimax optimization on an interval to get strong worst-case contraction (Amsel et al., 2025). We port that idea to the SPD eigenvalue map.

For a spectral interval \([\ell,u]\), choose degree-\(d\) polynomial \(q\) by

\[ q^\ast \in \arg\min_{q\in\mathcal{P}_d}\;\max_{\lambda\in[\ell,u]} \left\vert \sqrt{\lambda}\,q(\lambda) - 1\right\vert . \]

If \(\left\vert \sqrt{\lambda}\,q(\lambda)-1\right\vert \le\varepsilon\) on \([\ell,u]\), then

\[ \lambda' = (\sqrt{\lambda}\,q(\lambda))^2 \in [(1-\varepsilon)^2,(1+\varepsilon)^2], \]

giving a clean contraction/interval propagation rule.

We do not solve minimax online; instead we precompute a dense coefficient table offline and select online based on the measured residual.

Offline

Precompute two families:

Phase 1 (global) polynomials:

  • intervals \([\ell,1]\) with \(\ell\) log-spaced (e.g. \(\ell\in\{10^{-4},10^{-3},\dots ,0.5\}\)),
  • minimax \(q_{\ell}\) for each interval.

Phase 2 (local, symmetric-around-1) polynomials:

  • represent \(S = I + R\) and approximate \((I+R)^{-1/2}\),
  • intervals \(r\in[-\rho,\rho]\) with \(\rho\) on a grid (e.g. \(\rho\in\{0.02,0.05,0.1,0.2,0.35,0.5,0.7,0.9\}\)),
  • minimax \(p_{\rho}\) approximating \((1+r)^{-1/2}\) on \([-\rho,\rho]\).

Optionally impose stability constraints in the offline solve (recommended for bf16):

  • \(q(\lambda) > 0\) on the interval (SPD preservation),
  • cap overshoot: ensure \(\lambda q(\lambda)^2\) stays in a controlled range,
  • limit slope near \(1\) to avoid local amplification.
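The offline solve can be prototyped with a weighted least-squares proxy for the minimax objective; a real implementation would use a Remez exchange for the true minimax polynomial, and all names here are illustrative:

```python
import numpy as np

def fit_q(ell, deg=3, npts=2000):
    """Least-squares proxy for min_q max_{lam in [ell,1]} |sqrt(lam) q(lam) - 1|.
    Minimizes the squared error on a dense grid; a Remez exchange would
    give the true minimax polynomial with sharper constants."""
    lam = np.linspace(ell, 1.0, npts)
    # Columns: sqrt(lam) * lam^j, so that V @ c = sqrt(lam) * q(lam).
    V = np.sqrt(lam)[:, None] * lam[:, None] ** np.arange(deg + 1)
    coeffs, *_ = np.linalg.lstsq(V, np.ones(npts), rcond=None)
    return coeffs  # [c_0, ..., c_deg], lowest degree first

# Phase-1 table: one polynomial per log-spaced lower endpoint.
table = {ell: fit_q(ell) for ell in (1e-3, 1e-2, 1e-1, 0.5)}

# Worst-case error of one entry on its design interval.
lam = np.linspace(0.5, 1.0, 4001)
q = np.polyval(table[0.5][::-1], lam)  # polyval wants highest degree first
err = np.max(np.abs(np.sqrt(lam) * q - 1.0))
```

The stability constraints (positivity, overshoot caps, slope limits) would enter this solve as extra constraints, which is easiest with an LP/SOCP formulation of the discretized minimax problem.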

Online selection

At each step compute

\[ S = Z^\top A Z,\qquad \delta_S := \Vert S-I\Vert _F. \]

Then \(\Vert S-I\Vert _2 \le \delta_S\), so

\[ \lambda(S)\subset[1-\delta_S,\,1+\delta_S]. \]

Pick a slightly inflated design radius

\[ \rho_{\text{design}} := \gamma\,\delta_S,\qquad \gamma\in[1.1,1.5], \]

and choose the nearest polynomial \(p_{\rho_{\text{design}}}\) (Phase 2) or, in Phase 1, choose a conservative \(\ell\) schedule.
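The lookup itself is a few lines; `RHO_GRID` mirrors the \(\rho\) grid above and `select_rho` is an illustrative name:

```python
import numpy as np

# Hypothetical Phase-2 grid of design radii (matching the offline table).
RHO_GRID = np.array([0.02, 0.05, 0.1, 0.2, 0.35, 0.5, 0.7, 0.9])

def select_rho(S, gamma=1.25):
    """Measure delta_S = ||S - I||_F, inflate by gamma, and pick the
    smallest tabulated radius covering the inflated interval."""
    n = S.shape[0]
    delta = np.linalg.norm(S - np.eye(n), 'fro')
    target = gamma * delta
    candidates = RHO_GRID[RHO_GRID >= target]
    return RHO_GRID[-1] if candidates.size == 0 else candidates[0]

S = np.eye(8) + 0.03 * np.eye(8)   # residual with ||S - I||_F = 0.03 * sqrt(8)
rho = select_rho(S)                # inflated target ~0.106 -> picks 0.2
```

Rounding the design radius up (never down) keeps the measured spectrum strictly inside the polynomial's design interval.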


Two-phase scheme (safe globalization, aggressive local polish)

Phase 0: Form \(B\) and apply unbiased preconditioning

  1. \(B \leftarrow G^\top G\) (fp32 accumulate)
  2. \(B \leftarrow \tfrac12(B+B^\top)\)
  3. Ridge: \(B \leftarrow B + \delta I\)
  4. Jacobi: \(D_{ii} \leftarrow (B_{ii}+\epsilon)^{-1/2}\)
  5. \(\widetilde B \leftarrow DBD\) (elementwise scaling: \(\widetilde B_{ij}=d_i B_{ij} d_j\))

Phase 1: Safe scaling to \((0,1]\) and global minimax steps

  1. Upper bound \(\Lambda \ge \lambda_{\max}(\widetilde B)\) (Gershgorin \(\Vert \widetilde B\Vert _\infty\))
  2. Scale:

    \[ \alpha := \Lambda^{-1/2},\qquad A := \alpha^2 \widetilde B \]

    so \(\lambda(A)\subset(0,1]\)

  3. Initialize \(Z \leftarrow I\)
  4. Repeat in restart blocks (\(T_{\text{block}}\in\{2,3\}\)):
    • \(S \leftarrow Z^\top A Z\)
    • if \(\Vert S-I\Vert _F \le \rho_{\text{switch}}\) (e.g. \(0.5\)): break
    • choose \(q_\ell\) (table lookup for a conservative \(\ell\)) and apply:

      \[ Z \leftarrow Z\,q_\ell(S) \]
    • restart: recompute \(S\) in fp32 and reselect coefficients

Phase 2: Local symmetric-around-1 steps (aggressive but certified)

Now \(\Vert S-I\Vert _F\) is small enough that we can safely use symmetric intervals around \(1\).

Repeat for \(t=1,2\) (often 1 is enough):

  • \(S \leftarrow Z^\top A Z\)
  • \(\delta_S \leftarrow \Vert S-I\Vert _F\)
  • if \(\delta_S \le \eta\): stop
  • \(\rho_{\text{design}} \leftarrow \gamma\delta_S\)
  • lookup \(p_{\rho_{\text{design}}}\) and apply:

    \[ Z \leftarrow Z\,p_{\rho_{\text{design}}}(S-I) \]

Finish: map back through \(D\) and form \(\widetilde U\)

  1. \(\widetilde B^{-1/2} \approx \alpha Z\)
  2. Map back:

    \[ \widetilde Z := \alpha\, D Z \qquad \text{(so that } \widetilde Z^\top B \widetilde Z = Z^\top A Z \approx I \text{)} \]
  3. Output:

    \[ \widetilde U = G\widetilde Z \]

Certification and optional polish

Compute

\[ E = \widetilde Z^\top B \widetilde Z - I \]

and check \(\Vert E\Vert _F \le \eta\).


Restarts (important for bf16)

Use short composition blocks (\(T_{\text{block}}\in\{2,3\}\)), then recompute \(S\) and reselect coefficients. This mirrors Polar Express’s practical stabilization for Gram-side rectangular acceleration (Amsel et al., 2025).


Unbiased, minimax, Jacobi, online selection

Input: \(G\), ridge \(\delta\), Jacobi eps \(\epsilon\), tol \(\eta\), switch \(\rho_{\text{switch}}\), inflate \(\gamma\), coefficient tables

  1. \(B \leftarrow G^\top G\) (fp32 accumulate)
  2. \(B \leftarrow \tfrac12(B+B^\top) + \delta I\)
  3. \(d_i \leftarrow (B_{ii}+\epsilon)^{-1/2}\), \(D=\mathrm{diag}(d)\)
  4. \(\widetilde B \leftarrow DBD\)
  5. \(\Lambda \leftarrow\) upper bound on \(\lambda_{\max}(\widetilde B)\)
  6. \(\alpha \leftarrow \Lambda^{-1/2}\), \(A \leftarrow \alpha^2 \widetilde B\)
  7. \(Z \leftarrow I\)

Phase 1:

  1. repeat (restart blocks):
    • \(S \leftarrow Z^\top A Z\)
    • if \(\Vert S-I\Vert _F \le \rho_{\text{switch}}\): break
    • select minimax \(q_\ell\) for a conservative \([\ell,1]\)
    • \(Z \leftarrow Z\,q_\ell(S)\)

Phase 2:

  1. for \(t=1,2\):
    1. \(S \leftarrow Z^\top A Z\)
    2. \(\delta_S \leftarrow \Vert S-I\Vert _F\)
    3. if \(\delta_S \le \eta\): break
    4. \(\rho_{\text{design}} \leftarrow \gamma\delta_S\)
    5. select minimax \(p_{\rho_{\text{design}}}\)
    6. \(Z \leftarrow Z\,p_{\rho_{\text{design}}}(S-I)\)

Finish:

  1. \(Z_{\widetilde B} \leftarrow \alpha Z\) (approx \(\widetilde B^{-1/2}\))
  2. \(\widetilde Z \leftarrow D Z_{\widetilde B}\) (so that \(\widetilde Z^\top B \widetilde Z = Z^\top A Z \approx I\))
  3. \(\widetilde U \leftarrow G\widetilde Z\)
  4. \(E \leftarrow \widetilde Z^\top B \widetilde Z - I\); if \(\Vert E\Vert _F > \eta\), do one more Phase-2 step

Return: \(\widetilde U\)
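Putting the pieces together, here is a compact fp64 sketch of the whole pipeline. The fixed Newton-Schulz cubic stands in for both minimax coefficient tables (so the two phases collapse into one loop), and the final map \(\widetilde Z = \alpha D Z\) makes \(\widetilde Z^\top B \widetilde Z = Z^\top A Z\) hold exactly, so the certificate reflects the iteration's true residual:

```python
import numpy as np

def gram_side_polar(G, ridge=1e-7, eps=1e-12, eta=1e-3, max_steps=30):
    """Sketch of the Gram-side pipeline. The fixed Newton-Schulz polynomial
    q(S) = (3I - S)/2 stands in for the minimax table; everything else
    (Jacobi congruence, Gershgorin scaling, certificate) follows the text."""
    n = G.shape[1]
    I = np.eye(n)
    # Phase 0: Gram matrix, symmetrize, ridge, Jacobi congruence.
    B = G.T @ G
    B = 0.5 * (B + B.T) + ridge * I
    d = (np.diag(B) + eps) ** -0.5
    B_t = d[:, None] * B * d[None, :]           # B_tilde = D B D
    # Scaling: Gershgorin bound on lambda_max, map spectrum into (0, 1].
    Lam = np.max(np.sum(np.abs(B_t), axis=1))   # ||B_tilde||_inf
    alpha = Lam ** -0.5
    A = alpha ** 2 * B_t
    # Iterate Z <- Z q(S), S = Z^T A Z, until the Gram residual is small.
    Z = I.copy()
    for _ in range(max_steps):
        S = Z.T @ A @ Z
        if np.linalg.norm(S - I) <= eta:
            break
        Z = Z @ (1.5 * I - 0.5 * S)             # stand-in for q_ell / p_rho
    # Finish: Z_tilde = alpha D Z, so Z_tilde^T B Z_tilde = Z^T A Z ~ I.
    Z_tilde = alpha * (d[:, None] * Z)
    U = G @ Z_tilde
    cert = np.linalg.norm(Z_tilde.T @ B @ Z_tilde - I)
    return U, cert

rng = np.random.default_rng(0)
G = rng.standard_normal((512, 32))
U, cert = gram_side_polar(G)
ortho_err = np.linalg.norm(U.T @ U - np.eye(32))
```

Dropping in per-step table lookups for the fixed cubic, restart blocks, and bf16 storage with fp32 accumulation recovers the full algorithm above.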


What “dense coefficients” buys you

A dense coefficient grid lets you select a nearly optimal minimax polynomial for the actual measured residual each step (interval-driven updates), matching the spirit of Polar Express (Amsel et al., 2025), but with a stronger online interval proxy because \(S\) is small SPD.

It improves:

  • early contraction when the spectrum is wide,
  • iteration count when the spectrum is already tight,
  • stability: you can inflate the interval by \(\gamma\) and still stay close to minimax-optimal.

This is the clean way to be “more aggressive” while controlling effective convergence radius in bf16.

This post is licensed under CC BY 4.0 by the author.