TaperNorm
Gated removal of normalization in Transformers for stable training and inference-time folding.
Goal: remove per-token normalization inside Transformer blocks at convergence (cheaper + simpler inference) without destabilizing training.
Idea
Use normalization when optimization is fragile, then taper it away with a single global gate so each norm layer becomes a fixed sample-independent scaling that can be folded into adjacent linear projections at inference.
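Below is a minimal sketch of that interpolation in PyTorch, assuming an RMSNorm-style normalization branch and a single scalar gate g shared across layers; the class name, the per-channel `fixed_scale`, and the parameter layout are illustrative choices, not the paper's exact implementation.

```python
import torch
import torch.nn as nn


class TaperNorm(nn.Module):
    """Sketch: interpolate between tokenwise RMS normalization (g = 1) and a
    learned, sample-independent per-channel scaling (g = 0)."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))      # usual norm gain
        # Learned, sample-independent scale for the g = 0 branch; in practice
        # it would be initialized by EMA calibration at the taper start.
        self.fixed_scale = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor, g: float) -> torch.Tensor:
        # Tokenwise branch: standard RMSNorm, depends on per-token statistics.
        inv_rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        normed = x * inv_rms * self.weight
        # Sample-independent branch: a plain diagonal scaling that can be
        # folded into the adjacent linear projection once g reaches 0.
        scaled = x * self.fixed_scale * self.weight
        return g * normed + (1.0 - g) * scaled
```

With g = 1 this reduces to RMSNorm; with g = 0 it is a fixed diagonal map, which is what makes the inference-time fold possible.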
What I did
- Designed TaperNorm, a drop-in replacement for RMSNorm/LayerNorm that interpolates between:
  - a standard tokenwise normalization branch (early training), and
  - a learned, sample-independent linear/affine map (late training).
- Built a training schedule with a global gate: warmup at g = 1, EMA-based calibration at the taper start, then cosine decay to g = 0 (see the sketch after this list).
- Analyzed the training dynamics and identified scale anchoring at the output as the key stabilizer: removing the final norm can induce logit chasing, and a simple fixed-target scale loss can replace the anchor.
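A minimal sketch of the gate schedule, the EMA calibration, and the fixed-target scale loss, building on the TaperNorm sketch above. Only the warmup → calibrate → cosine-decay shape and the idea of an output-scale loss come from the description; the step counts, EMA decay, target scale, and loss weight are placeholders.

```python
import math
import torch


def gate_value(step: int, warmup_steps: int, taper_steps: int) -> float:
    """Global gate g: held at 1.0 during warmup, then cosine-decayed to 0.0."""
    if step < warmup_steps:
        return 1.0
    t = min((step - warmup_steps) / max(taper_steps, 1), 1.0)
    return 0.5 * (1.0 + math.cos(math.pi * t))  # 1 -> 0


@torch.no_grad()
def calibrate_fixed_scale(tapernorm: "TaperNorm", x: torch.Tensor,
                          ema_decay: float = 0.99) -> None:
    """Around the taper start, track an EMA of the observed tokenwise inverse
    RMS and write it into the sample-independent scale (illustrative choice)."""
    inv_rms = torch.rsqrt(x.pow(2).mean(dim=-1) + tapernorm.eps)  # (batch, seq)
    tapernorm.fixed_scale.mul_(ema_decay).add_((1.0 - ema_decay) * inv_rms.mean())


def logit_scale_loss(logits: torch.Tensor, target_rms: float = 10.0,
                     weight: float = 1e-3) -> torch.Tensor:
    """Auxiliary loss replacing the output-scale anchor: penalize drift of the
    logits' RMS away from a fixed target (target value is a placeholder)."""
    rms = logits.pow(2).mean().sqrt()
    return weight * (rms - target_rms) ** 2
```

In a training loop one would compute `g = gate_value(step, ...)` once per step, pass it to every TaperNorm, and add `logit_scale_loss(logits)` to the objective after the final norm is removed.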
Key results
- Competitive loss vs. normalized baselines under identical setups while removing per-token statistics in internal layers.
- After folding internal scalings into adjacent projections, a last-token-logits microbenchmark shows up to 1.22× the throughput of the normalized baseline on an NVIDIA H100 (bf16).
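A minimal sketch of the folding step, assuming a TaperNorm (as sketched above) sits immediately before a linear projection; the function name and the way the folded module is rebuilt are illustrative.

```python
import torch
import torch.nn as nn


@torch.no_grad()
def fold_tapernorm_into_linear(tn: "TaperNorm", proj: nn.Linear) -> nn.Linear:
    """At g = 0, TaperNorm(x) = x * fixed_scale * weight is a diagonal map, so
    proj(TaperNorm(x)) = (W @ diag(s)) x + b; bake diag(s) into W directly."""
    s = tn.fixed_scale * tn.weight                            # (in_features,)
    folded = nn.Linear(proj.in_features, proj.out_features,
                       bias=proj.bias is not None)
    folded = folded.to(device=proj.weight.device, dtype=proj.weight.dtype)
    folded.weight.copy_(proj.weight * s.unsqueeze(0))         # scale W's columns
    if proj.bias is not None:
        folded.bias.copy_(proj.bias)
    return folded
```

After this fold the norm layer can be dropped from the inference graph entirely, since its effect now lives inside the projection weights.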
References
- Gated Removal of Normalization in Transformers Enables Stable Training and Efficient Inference, 2026. arXiv link coming soon.