TaperNorm

Gated removal of normalization in Transformers for stable training and inference-time folding.

Goal: remove per-token normalization inside Transformer blocks at convergence (cheaper + simpler inference) without destabilizing training.

Idea

Use normalization when optimization is fragile, then taper it away with a single global gate so each norm layer becomes a fixed sample-independent scaling that can be folded into adjacent linear projections at inference.
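A minimal sketch of such a gated layer, assuming a PyTorch module, an RMSNorm-style normalization branch, a learned elementwise scale as the sample-independent branch, and a shared learned gain applied to both; the method name `set_gate` and the parameter names are illustrative, not the paper's exact interface.

```python
import torch
import torch.nn as nn

class TaperNorm(nn.Module):
    """Interpolates between tokenwise RMS normalization (g = 1) and a
    learned sample-independent scaling (g = 0) via a single global gate g."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))       # shared learned gain
        self.fixed_scale = nn.Parameter(torch.ones(dim))  # sample-independent branch
        self.register_buffer("gate", torch.ones(()))      # global g in [0, 1]

    def set_gate(self, g: float) -> None:
        self.gate.fill_(g)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Tokenwise RMSNorm branch (data-dependent statistics).
        rms = x.pow(2).mean(dim=-1, keepdim=True).add(self.eps).rsqrt()
        normed = x * rms
        # Sample-independent branch: a fixed elementwise scaling of x.
        scaled = x * self.fixed_scale
        # The global gate blends the two branches.
        g = self.gate
        return self.weight * (g * normed + (1.0 - g) * scaled)
```

At g = 0 the whole layer reduces to multiplication by `weight * fixed_scale`, a fixed sample-independent scaling, which is what makes inference-time folding into adjacent linear projections possible.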

What I did

  • Designed TaperNorm, a drop-in replacement for RMSNorm/LayerNorm that interpolates between:
    • a standard tokenwise normalization branch (early training), and
    • a learned sample-independent linear/affine map (late training).
  • Built a training schedule with a global gate: warmup at g=1, EMA-based calibration at the taper start, then cosine decay to g=0 (see the schedule sketch after this list).
  • Analyzed the training dynamics and identified scale anchoring at the output as the key stabilizer: removing the final norm can induce logit chasing, and a simple fixed-target scale loss can replace the anchor.
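One way the gate schedule could be written, assuming step-indexed training and hyperparameters `taper_start` / `taper_end` (illustrative names); the EMA-based calibration at the taper start is only indicated in a comment, and the exact rule may differ.

```python
import math

def gate_schedule(step: int, taper_start: int, taper_end: int) -> float:
    """Global gate g(step): hold at 1.0 during warmup, then cosine-decay to 0.0."""
    if step < taper_start:
        return 1.0                                  # pure normalization while optimization is fragile
    if step >= taper_end:
        return 0.0                                  # normalization fully removed
    t = (step - taper_start) / (taper_end - taper_start)
    return 0.5 * (1.0 + math.cos(math.pi * t))      # smooth decay from 1 to 0

# Illustrative use inside a training loop:
# for step in range(total_steps):
#     g = gate_schedule(step, taper_start=20_000, taper_end=80_000)
#     for layer in tapernorm_layers:
#         layer.set_gate(g)
#     # at step == taper_start, calibrate each layer's fixed_scale, e.g. from an
#     # EMA of the observed per-token RMS, so the two branches roughly agree
#     # when the decay begins (assumed calibration detail).
```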

Key results

  • Competitive loss vs. normalized baselines under identical setups while removing per-token statistics in internal layers.
  • After folding the internal scalings into adjacent projections, a last-token-logits microbenchmark shows up to 1.22× higher throughput on an NVIDIA H100 (bf16); a folding sketch follows below.
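A sketch of the folding step, under the assumption that a fully tapered layer has collapsed to an elementwise input scale `s` feeding a linear projection: the scale can be absorbed into the projection's weight matrix, so the norm layer disappears at inference. This is the generic fold for a diagonal scaling, not necessarily the paper's exact procedure.

```python
import torch
import torch.nn as nn

@torch.no_grad()
def fold_scale_into_linear(scale: torch.Tensor, linear: nn.Linear) -> None:
    """Absorb an elementwise input scaling s into the following projection:
    (W @ diag(s)) x + b == W (s * x) + b."""
    linear.weight.mul_(scale)  # broadcasts over rows: column j is scaled by s[j]

# Illustrative check with hypothetical dimensions:
d_model, d_out = 16, 32
s = torch.rand(d_model) + 0.5        # stands in for weight * fixed_scale at g = 0
proj = nn.Linear(d_model, d_out)
x = torch.randn(4, d_model)

y_ref = proj(s * x)                  # scale applied explicitly
fold_scale_into_linear(s, proj)
y_fold = proj(x)                     # scale folded into the weights
assert torch.allclose(y_ref, y_fold, atol=1e-5)
```

The bias term, if present, is unaffected because the scaling acts on the input side of the projection.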

References

2026

  1. Gated Removal of Normalization in Transformers Enables Stable Training and Efficient Inference
    Andrei Kanavalau, Carmen Amo Alonso, and Sanjay Lall
    2026
    arXiv link coming soon.