TaperNorm

Gated removal of normalization in Transformers for stable training and inference-time folding.

Goal: remove per-token normalization inside Transformer blocks at convergence (cheaper + simpler inference) without destabilizing training.

Idea

Use normalization when optimization is fragile, then taper it away with a single global gate so each norm layer becomes a fixed sample-independent scaling that can be folded into adjacent linear projections at inference.
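As a minimal sketch of the gated interpolation (names `gamma`, `alpha`, and the single-tensor signature are illustrative, not the paper's exact parameterization):

```python
import numpy as np

def taper_norm(x, gamma, alpha, g, eps=1e-6):
    """Blend a tokenwise RMSNorm branch with a sample-independent
    scaling branch via a global gate g in [0, 1].

    x:     (..., d) activations
    gamma: learned RMSNorm scale (assumed name)
    alpha: learned fixed scale for the norm-free branch (assumed name)
    """
    # Tokenwise statistics: used fully at g = 1 (early training).
    rms = np.sqrt(np.mean(x**2, axis=-1, keepdims=True) + eps)
    normed = gamma * x / rms
    # Sample-independent scaling: at g = 0 this is all that remains,
    # so the layer can be folded into adjacent linear projections.
    linear = alpha * x
    return g * normed + (1.0 - g) * linear
```

At `g = 0` the layer is a plain per-channel scaling, which is what makes inference-time folding possible.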

What I did

  • Designed TaperNorm, a drop-in replacement for RMSNorm/LayerNorm that interpolates between:
    • a standard tokenwise normalization branch (early training), and
    • a learned sample-independent linear/affine map (late training).
  • Built a training schedule with a global gate: warmup at g=1, EMA-based calibration at the taper start, then cosine decay to g=0.
  • Analyzed the training dynamics and identified scale anchoring at the output as the key stabilizer: removing the final norm can induce logit chasing, and a simple fixed-target scale loss on the output can stand in for the anchor.
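The gate schedule above can be sketched as follows (the knobs `warmup_end` and `taper_end` are hypothetical names; the EMA calibration at the taper start is omitted for brevity):

```python
import math

def gate_schedule(step, warmup_end, taper_end):
    """Global gate g: held at 1 through warmup, then cosine-decayed
    to 0 between warmup_end and taper_end (step indices, assumed)."""
    if step < warmup_end:
        return 1.0          # full normalization during fragile early training
    if step >= taper_end:
        return 0.0          # norm fully removed; layers are foldable
    t = (step - warmup_end) / (taper_end - warmup_end)
    return 0.5 * (1.0 + math.cos(math.pi * t))
```

A cosine ramp gives a smooth hand-off, so no single step sees an abrupt change in the effective layer.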

Key results

  • Competitive loss versus normalized baselines under identical training setups, while eliminating per-token normalization statistics from internal layers.
  • After folding internal scalings into adjacent projections, a last-token logits microbenchmark reports up to 1.22× throughput on an NVIDIA H100 (bf16).
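The folding step relies on a simple identity: once the gate reaches 0, each TaperNorm is a per-channel scale, and `W @ (alpha * x) = (W * alpha) @ x`, so the scale disappears into the next projection's weights. A sketch (function name is illustrative):

```python
import numpy as np

def fold_scale_into_linear(W, alpha):
    """Fold a per-channel scale alpha (the g = 0 TaperNorm branch)
    into the following projection W, eliminating the layer at
    inference: y = W @ (alpha * x) = (W * alpha) @ x."""
    # alpha broadcasts across W's input-dimension columns,
    # i.e. W * alpha == W @ diag(alpha).
    return W * alpha

# Folding is exact: the two computations agree to numerical precision.
rng = np.random.default_rng(0)
W, alpha, x = rng.normal(size=(4, 8)), rng.normal(size=8), rng.normal(size=8)
assert np.allclose(fold_scale_into_linear(W, alpha) @ x, W @ (alpha * x))
```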

References

2026

  1. Gated Removal of Normalization in Transformers Enables Stable Training and Efficient Inference
    Andrei Kanavalau, Carmen Amo Alonso, and Sanjay Lall
    2026