Polar Factor Beyond Newton-Schulz - Fast Matrix Inverse Square Root
A deep dive into computing the orthonormal polar factor (matrix sign function) for tall matrices using minimax polynomials, Jacobi preconditioning, and online certificates, moving beyond standard Newton-Schulz iterations.
The Muon optimizer has found huge empirical success in machine learning. It is essentially signSGD for matrices (or Lion, once momentum is included). For the update, we approximate the sign function on the singular values of the momentum matrix, which amounts to computing its polar factor.
Polar factor
Goal: Given \(G\in \mathbb{R}^{m \times n}\) tall (\(m \ge n\)), compute the (column-)orthonormal polar factor
\[ \mathrm{polar}(G):=G(G^\top G)^{-1/2} \]
For the compact SVD \(G=U\Sigma V^\top\), \(\mathrm{polar}(G)=UV^\top\). This is the “directional” component in the polar decomposition \(G=\mathrm{polar}(G)\,\vert G\vert \), analogous to the polar form of a complex number \(z=e^{i\theta}\cdot r\) as direction times magnitude.
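As a concrete reference point, here is a small NumPy sketch that computes the polar factor exactly via the compact SVD and checks it against the defining formula \(G(G^\top G)^{-1/2}\) (the helper name `polar_svd` is ours, for illustration):

```python
import numpy as np

# Reference polar factor via compact SVD: polar(G) = U V^T = G (G^T G)^{-1/2}.
def polar_svd(G):
    U, s, Vt = np.linalg.svd(G, full_matrices=False)
    return U @ Vt

rng = np.random.default_rng(0)
G = rng.standard_normal((32, 4))
P = polar_svd(G)

# Columns are orthonormal: P^T P = I.
assert np.allclose(P.T @ P, np.eye(4), atol=1e-10)

# Same result from the defining formula G (G^T G)^{-1/2}.
B = G.T @ G
w, Q = np.linalg.eigh(B)
B_inv_sqrt = Q @ np.diag(w ** -0.5) @ Q.T
assert np.allclose(P, G @ B_inv_sqrt, atol=1e-8)
```

The SVD route is the accuracy gold standard; everything below is about avoiding it on the GPU fast path.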
In Muon, we typically do not need high accuracy, but we do want:
- a fast GPU path (mostly GEMMs),
- numerical stability in bf16,
- a way to certify that \(\sigma_i(U)\) are close to \(1\).
Newton-Schulz/Polar Express iterations: normalize the singular values into the interval \([0,1]\), then iterate directly on \(G\) with rectangular GEMMs.
Potential opportunity for \(m \gg n\): compute \((G^\top G)^{-1/2}\) on the small side and multiply once, optionally refining with full polar steps. This enables some nicer theoretical machinery, e.g. (precomputed) online coefficient scheduling instead of Polar Express's fixed offline coefficients.
Plan: Gram-side polar factor for tall gradients using minimax polynomials + Jacobi preconditioning + online selection
Goal
Given \(G \in \mathbb{R}^{m \times n}\) tall (\(m \ge n\)), compute the orthonormal polar factor
We want a fast ML-friendly approximation that:
- uses only 2 rectangular GEMMs (form \(B=G^\top G\), final multiply \(G\widetilde Z\)),
- does the iterative work on small \(n \times n\) matrices,
- is stable in bf16 (fp32 accumulate where needed),
- provides an online certificate that singular values of the returned factor are close to \(1\).
Key idea. Gram-side inverse square root
Let
\[ B := G^\top G \in \mathbb{R}^{n \times n}. \]
Compute \(\widetilde Z \approx B^{-1/2}\) using only \(n\times n\) work, then output
\[ \widetilde U := G \widetilde Z. \]
This is the same structural win that Polar Express exploits for rectangular matrices: form a Gram matrix once, iterate on the small side, then do one final rectangular multiply. Polar Express formalizes this as “Fast Polynomial Iteration for Rectangular Matrices” (Algorithm 4) (Amsel et al., 2025).
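A NumPy sketch of the structural win (an exact eigendecomposition stands in here for the polynomial iteration on the small side; the heavy \(m \times n\) work is just two GEMMs):

```python
import numpy as np

rng = np.random.default_rng(1)
m, n = 4096, 64
G = rng.standard_normal((m, n))

# Rectangular GEMM #1: form the small (n x n) Gram matrix.
B = G.T @ G

# All iterative work happens on the n x n side; an exact eigendecomposition
# is used here as a stand-in for the polynomial iteration described below.
w, Q = np.linalg.eigh(B)
Z = Q @ ((w ** -0.5)[:, None] * Q.T)  # Z = B^{-1/2}

# Rectangular GEMM #2: one final multiply.
U_tilde = G @ Z

assert np.allclose(U_tilde.T @ U_tilde, np.eye(n), atol=1e-8)
```

For \(m \gg n\), the two rectangular GEMMs cost \(O(mn^2)\) while every iteration step costs only \(O(n^3)\).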
What we can certify online (stronger than rectangular direct iterations)
Define the Gram residual
\[ E := \widetilde Z^\top B \widetilde Z - I \;=\; \widetilde U^\top \widetilde U - I. \]
If \(\Vert E\Vert _2 \le \eta\), then
\[ \sigma_i(\widetilde U) \in \big[\sqrt{1-\eta},\;\sqrt{1+\eta}\,\big] \quad \text{for all } i. \]
Since \(\Vert E\Vert _2 \le \Vert E\Vert _F\), we can use the cheap sufficient check \(\Vert E\Vert _F \le \eta\) (all on \(n \times n\)). This gives a reliable online proxy for “how safe/aggressive can we be”.
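A minimal sketch of the certificate (the helper name `gram_certificate` is ours; it recomputes \(B\) for clarity, whereas in practice \(B\) is already available):

```python
import numpy as np

def gram_certificate(G, Z_tilde, eta):
    """Cheap sufficient check ||E||_F <= eta, with E = Z^T B Z - I computed
    entirely on the small n x n side. If it holds, the singular values of
    G @ Z_tilde lie in [sqrt(1-eta), sqrt(1+eta)]."""
    B = G.T @ G
    E = Z_tilde.T @ B @ Z_tilde - np.eye(B.shape[0])
    return np.linalg.norm(E, "fro") <= eta

rng = np.random.default_rng(2)
G = rng.standard_normal((256, 16))
w, Q = np.linalg.eigh(G.T @ G)
Z = Q @ ((w ** -0.5)[:, None] * Q.T)   # near-exact B^{-1/2}

assert gram_certificate(G, Z, eta=1e-8)
# A deliberately sloppy factor fails the check.
assert not gram_certificate(G, 1.05 * Z, eta=1e-8)
```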
Why we do NOT use AOL here (replace with unbiased Jacobi on \(B\))
Turbo-Muon’s AOL is a column scaling applied to \(G\) (so it changes the target to \(\mathrm{polar}(G S)\) and introduces bias) (Boissin et al., 2025). Since we are already working on the square SPD Gram matrix \(B\), we can get the spectrum-improving benefits without bias using an SPD congruence scaling:
\[ \widetilde B := D B D, \qquad D = \mathrm{diag}(d),\quad d_i = (B_{ii}+\epsilon)^{-1/2}. \]
This changes conditioning but not the mathematical target (up to numerical error).
Empirically, Jacobi scaling (unit-diagonal) is often the best simple choice among diagonal scalings.
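A small NumPy demonstration of why the Jacobi congruence helps: for an SPD matrix with badly scaled rows/columns, \(\widetilde B = DBD\) has unit diagonal and a much smaller condition number (the construction here is synthetic, for illustration):

```python
import numpy as np

rng = np.random.default_rng(3)
n = 32
# SPD matrix, then deliberately mis-scale rows/columns over six decades.
M = rng.standard_normal((n, n))
B = M @ M.T + 1e-3 * np.eye(n)
scale = 10.0 ** rng.uniform(-3, 3, size=n)
B = scale[:, None] * B * scale[None, :]

# Jacobi (unit-diagonal) congruence: B_tilde = D B D with D_ii = B_ii^{-1/2}.
d = 1.0 / np.sqrt(np.diag(B))
B_tilde = d[:, None] * B * d[None, :]

assert np.allclose(np.diag(B_tilde), 1.0)
# The congruence typically shrinks the condition number dramatically.
assert np.linalg.cond(B_tilde) < np.linalg.cond(B)
```

By van der Sluis's classical result, the unit-diagonal scaling is within a factor of \(n\) of the optimal diagonal congruence in condition number.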
Stability rules. bf16-safe iterations
Polar Express identifies low-precision issues when iterating via Gram-side polynomial compositions (their Algorithm 4) and suggests:
- add a ridge early to avoid spurious indefiniteness from roundoff,
- restart compositions to avoid ill-conditioned intermediate factors (Amsel et al., 2025).
We adopt the same philosophy:
- always symmetrize \(B\) and ridge it,
- use restart blocks when composing aggressive polynomials,
- do all small-side iteration in fp32 (or at least fp32 accumulate and residual checks).
Core iteration: minimax-polynomial inverse square root for SPD matrices
Template (“drive the Gram to \(I\)”)
We compute an inverse square root of an SPD matrix \(A\) by maintaining \(Z_k \approx A^{-1/2}\) and driving
\[ S_k := Z_k^\top A Z_k \;\longrightarrow\; I. \]
Update:
\[ Z_{k+1} := Z_k\, q(S_k), \]
so eigenvalues evolve as
\[ \lambda \;\mapsto\; \lambda\, q(\lambda)^2. \]
This matches the standard Newton-style “matrix-multiplication only” inverse-root framework (no factorizations), e.g. in analyses of inverse \(p\)th-root iterations (Guo and Higham, 2006).
Why minimax (Polar Express port)
Polar Express selects per-step polynomials using minimax optimization on an interval to get strong worst-case contraction (Amsel et al., 2025). We port that idea to the SPD eigenvalue map.
For a spectral interval \([\ell,u]\), choose the degree-\(d\) polynomial \(q\) by
\[ q := \operatorname*{arg\,min}_{\deg q \le d}\;\max_{\lambda\in[\ell,u]} \left\vert \sqrt{\lambda}\,q(\lambda)-1\right\vert. \]
If \(\left\vert \sqrt{\lambda}\,q(\lambda)-1\right\vert \le\varepsilon\) on \([\ell,u]\), then
\[ \lambda\, q(\lambda)^2 \in \big[(1-\varepsilon)^2,\,(1+\varepsilon)^2\big] \quad \text{for all } \lambda\in[\ell,u], \]
giving a clean contraction/interval propagation rule.
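A scalar sanity check of the eigenvalue map \(\lambda \mapsto \lambda q(\lambda)^2\), using the classical degree-1 Newton-Schulz polynomial \(q(\lambda)=(3-\lambda)/2\) (not a minimax table entry, just the simplest member of this family) to show contraction toward \(1\):

```python
import numpy as np

# Under Z <- Z q(S), each eigenvalue of S = Z^T A Z evolves as
# lam -> lam * q(lam)^2. The Newton-Schulz choice q(lam) = (3 - lam)/2
# fixes lam = 1 and contracts quadratically near it.
q = lambda lam: (3.0 - lam) / 2.0

lam = np.array([1e-3, 0.1, 0.5, 0.9, 1.0])
for _ in range(20):
    lam = lam * q(lam) ** 2

assert np.allclose(lam, 1.0, atol=1e-10)
```

A minimax \(q\) tuned to the current interval contracts much faster in the early, wide-spectrum steps; the fixed Newton-Schulz polynomial is the conservative baseline.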
Online coefficients: dense offline grid + online selection (recommended)
We do not solve minimax online; instead we precompute a dense coefficient table offline and select online based on the measured residual.
Offline
Precompute two families:
Phase 1 (global) polynomials:
- intervals \([\ell,1]\) with \(\ell\) log-spaced (e.g. \(\ell\in\{10^{-4},10^{-3},\dots ,0.5\}\)),
- minimax \(q_{\ell}\) for each interval.
Phase 2 (local, symmetric-around-1) polynomials:
- represent \(S = I + R\) and approximate \((I+R)^{-1/2}\),
- intervals \(r\in[-\rho,\rho]\) with \(\rho\) on a grid (e.g. \(\rho\in\{0.02,0.05,0.1,0.2,0.35,0.5,0.7,0.9\}\)),
- minimax \(p_{\rho}\) approximating \((1+r)^{-1/2}\) on \([-\rho,\rho]\).
Optionally impose stability constraints in the offline solve (recommended for bf16):
- \(q(\lambda) > 0\) on the interval (SPD preservation),
- cap overshoot: ensure \(\lambda q(\lambda)^2\) stays in a controlled range,
- limit slope near \(1\) to avoid local amplification.
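A sketch of the offline Phase-2 table build. A dense-grid Chebyshev least-squares fit is used here as a cheap stand-in for a true minimax (Remez) solve; for smooth targets the Chebyshev fit is close to minimax-optimal. The helper name `build_phase2_table` and the degree are illustrative assumptions:

```python
import numpy as np
from numpy.polynomial import chebyshev as C

def build_phase2_table(rhos, degree=4, grid=2001):
    """For each radius rho, fit p_rho(r) ~ (1+r)^{-1/2} on [-rho, rho].
    Returns {rho: (chebyshev coeffs in r/rho, measured uniform error)}."""
    table = {}
    for rho in rhos:
        r = np.linspace(-rho, rho, grid)
        target = (1.0 + r) ** -0.5
        coeffs = C.chebfit(r / rho, target, degree)
        err = np.max(np.abs(C.chebval(r / rho, coeffs) - target))
        table[rho] = (coeffs, err)
    return table

table = build_phase2_table([0.02, 0.05, 0.1, 0.2, 0.35, 0.5])
# Smaller radii admit far smaller uniform error at fixed degree.
assert table[0.02][1] < table[0.5][1] < 1e-2
```

In a production table one would also enforce the stability constraints above (positivity, overshoot cap, slope limit) as explicit constraints in the offline solve.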
Online selection
At each step compute
\[ S := Z^\top A Z, \qquad \delta_S := \Vert S-I\Vert _F. \]
Then \(\Vert S-I\Vert _2 \le \delta_S\), so
\[ \lambda(S) \subset [\,1-\delta_S,\; 1+\delta_S\,]. \]
Pick a slightly inflated design radius
\[ \rho_{\text{design}} := \gamma\,\delta_S \qquad (\gamma > 1), \]
and choose the nearest tabulated polynomial \(p_{\rho_{\text{design}}}\) (Phase 2) or, in Phase 1, choose a conservative \(\ell\) schedule.
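The online selection step reduces to a table lookup. A minimal sketch (the helper name `select_rho` and the default \(\gamma\) are illustrative assumptions), choosing the smallest tabulated radius that still covers the inflated measured residual:

```python
import bisect

def select_rho(delta_S, rhos, gamma=1.25):
    """Pick the smallest tabulated radius covering gamma * delta_S.
    rhos must be sorted ascending; falls back to the widest interval."""
    target = gamma * delta_S
    i = bisect.bisect_left(rhos, target)
    return rhos[min(i, len(rhos) - 1)]

rhos = [0.02, 0.05, 0.1, 0.2, 0.35, 0.5, 0.7, 0.9]
assert select_rho(0.03, rhos) == 0.05   # 0.0375 -> covered by 0.05
assert select_rho(0.4, rhos) == 0.5     # 0.5 -> exact boundary
assert select_rho(2.0, rhos) == 0.9     # out of range -> widest entry
```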
Two-phase scheme (safe globalization, aggressive local polish)
Phase 0: Form \(B\) and apply unbiased preconditioning
- \(B \leftarrow G^\top G\) (fp32 accumulate)
- \(B \leftarrow \tfrac12(B+B^\top)\)
- Ridge: \(B \leftarrow B + \delta I\)
- Jacobi: \(D_{ii} \leftarrow (B_{ii}+\epsilon)^{-1/2}\)
- \(\widetilde B \leftarrow DBD\) (elementwise scaling: \(\widetilde B_{ij}=d_i B_{ij} d_j\))
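The Phase-0 steps above can be sketched as follows (float64 stands in for fp32 accumulation; the helper name `phase0` is ours):

```python
import numpy as np

def phase0(G, ridge=1e-6, eps=1e-12):
    """Form the Gram matrix; symmetrize, ridge, and Jacobi-scale it.
    Returns (B_tilde, d) with B_tilde = D B D of unit diagonal."""
    B = (G.T @ G).astype(np.float64)      # fp32-accumulate stand-in
    B = 0.5 * (B + B.T)                   # enforce exact symmetry
    B += ridge * np.eye(B.shape[0])       # guard against roundoff indefiniteness
    d = 1.0 / np.sqrt(np.diag(B) + eps)   # Jacobi: D_ii = (B_ii + eps)^{-1/2}
    return d[:, None] * B * d[None, :], d

rng = np.random.default_rng(4)
B_tilde, d = phase0(rng.standard_normal((512, 32)))
assert np.allclose(np.diag(B_tilde), 1.0, atol=1e-6)
assert np.all(np.linalg.eigvalsh(B_tilde) > 0)   # SPD preserved
```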
Phase 1: Safe scaling to \((0,1]\) and global minimax steps
- Upper bound \(\Lambda \ge \lambda_{\max}(\widetilde B)\) (Gershgorin \(\Vert \widetilde B\Vert _\infty\))
Scale:
\[ \alpha := \Lambda^{-1/2},\qquad A := \alpha^2 \widetilde B, \]
so \(\lambda(A)\subset(0,1]\).
- Initialize \(Z \leftarrow I\)
- Repeat in restart blocks (\(T_{\text{block}}\in\{2,3\}\)):
- \(S \leftarrow Z^\top A Z\)
- if \(\Vert S-I\Vert _F \le \rho_{\text{switch}}\) (e.g. \(0.5\)): break
  - choose \(q_\ell\) (table lookup for a conservative \(\ell\)) and apply \(Z \leftarrow Z\,q_\ell(S)\)
  - restart: recompute \(S\) in fp32 and reselect coefficients
Phase 2: Local symmetric-around-1 steps (aggressive but certified)
Now \(\Vert S-I\Vert _F\) is small enough that we can safely use symmetric intervals around \(1\).
Repeat for \(t=1,2\) (often 1 is enough):
- \(S \leftarrow Z^\top A Z\)
- \(\delta_S \leftarrow \Vert S-I\Vert _F\)
- if \(\delta_S \le \eta\): stop
- \(\rho_{\text{design}} \leftarrow \gamma\delta_S\)
lookup \(p_{\rho_{\text{design}}}\) and apply:
\[ Z \leftarrow Z\,p_{\rho_{\text{design}}}(S-I) \]
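One Phase-2 step can be sketched as below. The degree-2 Taylor polynomial of \((1+r)^{-1/2}\), namely \(p(r) = 1 - r/2 + 3r^2/8\), stands in for a tabulated minimax \(p_\rho\); `phase2_step` is an illustrative name:

```python
import numpy as np

def phase2_step(A, Z):
    """One local polish step: Z <- Z p(S - I), S = Z^T A Z."""
    n = A.shape[0]
    S = Z.T @ A @ Z
    R = S - np.eye(n)
    return Z @ (np.eye(n) - 0.5 * R + 0.375 * (R @ R))

rng = np.random.default_rng(5)
M = rng.standard_normal((16, 16))
A = M @ M.T / 16 + 0.5 * np.eye(16)
A /= np.linalg.eigvalsh(A).max()          # lambda(A) in (0, 1]
w, Q = np.linalg.eigh(A)
# Start from a slightly perturbed A^{-1/2}, so S is already close to I.
Z = Q @ ((w ** -0.5 * (1 + 0.03 * rng.standard_normal(16)))[:, None] * Q.T)

before = np.linalg.norm(Z.T @ A @ Z - np.eye(16), "fro")
Z = phase2_step(A, Z)
after = np.linalg.norm(Z.T @ A @ Z - np.eye(16), "fro")
assert after < 0.1 * before   # one local step contracts the residual sharply
```

Per eigenvalue the new residual is \(O(\vert r\vert^3)\) for this degree-2 polynomial, which is why one or two Phase-2 steps usually suffice.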
Finish: map back and form \(\widetilde U\)
- \(\widetilde B^{-1/2} \approx \alpha Z\)
Map back. Note that a two-sided congruence \(D(\cdot)D\) undoes the Jacobi scaling for \(B^{-1}\) but not for \(B^{-1/2}\); the consistent choice is the one-sided factor
\[ \widetilde Z := \alpha\, D Z, \]
which satisfies \(\widetilde Z^\top B \widetilde Z = Z^\top A Z \approx I\), exactly the quantity the iteration drives to \(I\). Output:
\[ \widetilde U = G\widetilde Z \]
Certification and optional polish
Compute the Gram residual
\[ E := \widetilde Z^\top B \widetilde Z - I \]
and check \(\Vert E\Vert _F \le \eta\).
Restarts (important for bf16)
Use short composition blocks (\(T_{\text{block}}\in\{2,3\}\)), then recompute \(S\) and reselect coefficients. This mirrors Polar Express’s practical stabilization for Gram-side rectangular acceleration (Amsel et al., 2025).
Unbiased, minimax, Jacobi, online selection
Input: \(G\), ridge \(\delta\), Jacobi eps \(\epsilon\), tol \(\eta\), switch \(\rho_{\text{switch}}\), inflate \(\gamma\), coefficient tables
- \(B \leftarrow G^\top G\) (fp32 accumulate)
- \(B \leftarrow \tfrac12(B+B^\top) + \delta I\)
- \(d_i \leftarrow (B_{ii}+\epsilon)^{-1/2}\), \(D=\mathrm{diag}(d)\)
- \(\widetilde B \leftarrow DBD\)
- \(\Lambda \leftarrow\) upper bound on \(\lambda_{\max}(\widetilde B)\)
- \(\alpha \leftarrow \Lambda^{-1/2}\), \(A \leftarrow \alpha^2 \widetilde B\)
- \(Z \leftarrow I\)
Phase 1:
- repeat (restart blocks):
  - \(S \leftarrow Z^\top A Z\)
  - if \(\Vert S-I\Vert _F \le \rho_{\text{switch}}\): break
  - select minimax \(q_\ell\) for a conservative \([\ell,1]\)
  - \(Z \leftarrow Z\,q_\ell(S)\)
Phase 2:
- for \(t=1,2\):
- \(S \leftarrow Z^\top A Z\)
- \(\delta_S \leftarrow \Vert S-I\Vert _F\)
- if \(\delta_S \le \eta\): break
- \(\rho_{\text{design}} \leftarrow \gamma\delta_S\)
- select minimax \(p_{\rho_{\text{design}}}\)
- \(Z \leftarrow Z\,p_{\rho_{\text{design}}}(S-I)\)
Finish:
- \(Z_{\widetilde B} \leftarrow \alpha Z\) (approx \(\widetilde B^{-1/2}\))
- \(\widetilde Z \leftarrow D Z_{\widetilde B}\) (one-sided unscaling, so that \(\widetilde Z^\top B \widetilde Z = Z^\top A Z\))
- \(\widetilde U \leftarrow G\widetilde Z\)
- \(E \leftarrow \widetilde Z^\top B \widetilde Z - I\); if \(\Vert E\Vert _F > \eta\), do one more Phase-2 step
Return: \(\widetilde U\)
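An end-to-end NumPy sketch of the two-phase pipeline. Classical Newton-Schulz polynomials stand in for the precomputed minimax tables (\(q(\lambda)=(3-\lambda)/2\) globally, the degree-2 Taylor polynomial of \((1+r)^{-1/2}\) locally); the function name `polar_gram`, the one-sided unscaling, and all parameter defaults are illustrative assumptions, not a definitive implementation:

```python
import numpy as np

def polar_gram(G, ridge=1e-9, eps=1e-12, eta=1e-8, rho_switch=0.5, max_iter=60):
    n = G.shape[1]
    I = np.eye(n)
    # Phase 0: Gram, symmetrize, ridge, Jacobi congruence.
    B = G.T @ G
    B = 0.5 * (B + B.T) + ridge * I
    d = 1.0 / np.sqrt(np.diag(B) + eps)
    B_t = d[:, None] * B * d[None, :]
    # Safe scaling: Gershgorin-type bound puts lambda(A) in (0, 1].
    Lam = np.linalg.norm(B_t, np.inf)
    alpha = Lam ** -0.5
    A = alpha ** 2 * B_t
    Z = I.copy()
    for _ in range(max_iter):
        S = Z.T @ A @ Z
        R = S - I
        delta = np.linalg.norm(R, "fro")
        if delta <= eta:
            break
        if delta > rho_switch:
            # Phase 1 stand-in: Newton-Schulz q(l) = (3 - l)/2.
            Z = Z @ (1.5 * I - 0.5 * S)
        else:
            # Phase 2 stand-in: p(r) = 1 - r/2 + 3r^2/8.
            Z = Z @ (I - 0.5 * R + 0.375 * (R @ R))
    # One-sided unscaling keeps Z_tilde^T B Z_tilde = Z^T A Z.
    Z_tilde = alpha * d[:, None] * Z
    return G @ Z_tilde, Z_tilde

rng = np.random.default_rng(6)
G = rng.standard_normal((2048, 64))
U_t, Z_t = polar_gram(G)
E = U_t.T @ U_t - np.eye(64)
assert np.linalg.norm(E, "fro") < 1e-6   # online certificate holds
```

Swapping the stand-in polynomials for the tabulated minimax families (with online \(\rho\)/\(\ell\) selection) reduces the iteration count without changing this control flow.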
What “dense coefficients” buys you
A dense coefficient grid lets you select a nearly optimal minimax polynomial for the actual measured residual each step (interval-driven updates), matching the spirit of Polar Express (Amsel et al., 2025), but with a stronger online interval proxy because \(S\) is small SPD.
It improves:
- early contraction when the spectrum is wide,
- iteration count when the spectrum is already tight,
- stability: you can inflate the interval by \(\gamma\) and still stay close to minimax-optimal.
This is the clean way to be “more aggressive” while controlling effective convergence radius in bf16.