
Multi-head attention: why one head is not enough

Run h independent attention computations in parallel, then concatenate. Each head specializes in a different relation. The mechanism most senior candidates can write but few can motivate.

Reviewed · 3 min read

One-line definition

Multi-head attention projects Q, K, and V into h lower-dimensional subspaces, runs scaled dot-product attention independently in each, and concatenates the results before a final output projection. Same FLOPs as one large head; very different inductive bias.

Why it matters

Single-head attention computes one weighted average per position. That single distribution has to encode every relation the model needs: syntactic, positional, semantic, coreferential. In practice it cannot, and ablations show that single-head transformers underperform multi-head transformers at matched parameter count (Vaswani et al., 2017).

Multiple heads let different attention patterns coexist. One head learns “previous token,” another “matching bracket,” another “this noun’s modifier.” Probing studies on BERT show many heads fire on syntactic dependencies that linguists recognize (Clark et al., 2019).

The mechanism

Given input X of shape (n, d_model) and head count h, with per-head dimension d_k = d_model / h:

  1. Project: Q = X·W_Q, K = X·W_K, V = X·W_V, each of shape (n, d_model). Reshape each to (h, n, d_k).
  2. Per-head attention: for each head i, head_i = softmax(Q_i·K_iᵀ / √d_k)·V_i.
  3. Concatenate: stack the heads back into shape (n, d_model).
  4. Output projection: output = Concat(head_1, …, head_h)·W_O.

Total parameters: 4·d_model² (the four projection matrices W_Q, W_K, W_V, W_O). FLOPs: O(n·d_model² + n²·d_model). Identical to single-head; the heads share the budget.
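
A minimal NumPy sketch of the four steps above. Shapes follow the list; the weights are random placeholders and the function names are illustrative, not a reference implementation:

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)           # subtract max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def multi_head_attention(X, W_Q, W_K, W_V, W_O, h):
    """X: (n, d_model); each weight matrix: (d_model, d_model)."""
    n, d_model = X.shape
    d_k = d_model // h

    def split_heads(M):                                # (n, d_model) -> (h, n, d_k)
        return M.reshape(n, h, d_k).transpose(1, 0, 2)

    # 1. Project and reshape
    Q, K, V = split_heads(X @ W_Q), split_heads(X @ W_K), split_heads(X @ W_V)

    # 2. Scaled dot-product attention, independently per head
    scores = Q @ K.transpose(0, 2, 1) / np.sqrt(d_k)   # (h, n, n)
    heads = softmax(scores, axis=-1) @ V               # (h, n, d_k)

    # 3. Concatenate heads back to (n, d_model)
    concat = heads.transpose(1, 0, 2).reshape(n, d_model)

    # 4. Output projection
    return concat @ W_O

rng = np.random.default_rng(0)
n, d_model, h = 10, 64, 8
W_Q, W_K, W_V, W_O = (rng.standard_normal((d_model, d_model)) / np.sqrt(d_model)
                      for _ in range(4))
out = multi_head_attention(rng.standard_normal((n, d_model)), W_Q, W_K, W_V, W_O, h)
print(out.shape)                                       # (10, 64)
print(sum(w.size for w in (W_Q, W_K, W_V, W_O)))       # 4 * d_model**2 = 16384
```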

Why split the dimension

If you keep d_k = d_model per head and run h heads, you multiply parameters and compute by h. Splitting d_model across the h heads (d_k = d_model / h) keeps the cost matched to a single-head baseline, so any gain is attributable to the multiplicity itself, not extra capacity. This is the design choice that makes the comparison meaningful.
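
A quick arithmetic check of that claim; d_model = 512 is an arbitrary example value:

```python
# Parameter count of the Q/K/V/output projections, split vs. unsplit heads.
d_model = 512
for h in (1, 8, 16):
    split = 4 * d_model * d_model        # d_k = d_model / h: constant in h
    unsplit = 4 * h * d_model * d_model  # d_k = d_model per head: grows linearly in h
    print(f"h={h:2d}  split={split:,}  unsplit={unsplit:,}")
# split stays 1,048,576 for every h; unsplit grows to 16,777,216 at h=16
```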

Variants

  • Multi-query attention (MQA): share K and V across all heads; only Q is per-head. The KV-cache shrinks by a factor of h (a cache-size sketch follows this list). See GQA and MQA.
  • Grouped-query attention (GQA): share across groups of heads. Compromise between full MHA and MQA. The Llama 2/3 default.
  • Cross-attention: Q comes from one sequence, K and V from another. See self-attention vs cross-attention.
  • Sliding-window / sparse: restrict each head to a local window or learned sparse pattern.
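
A rough cache-size comparison under assumed, illustrative settings (8k context, 32 query heads, d_k = 128, 8 KV groups); the numbers are not tied to any particular model:

```python
# Per-layer KV-cache entries for one sequence: 2 (keys + values) * context * kv_heads * d_k.
n_ctx, d_k = 8192, 128

def kv_entries(kv_heads):
    return 2 * n_ctx * kv_heads * d_k

print("MHA (32 KV heads):", kv_entries(32))   # 67,108,864 entries per layer
print("GQA ( 8 KV heads):", kv_entries(8))    # 4x smaller than MHA
print("MQA ( 1 KV head): ", kv_entries(1))    # 32x smaller than MHA
```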

Tradeoffs

  • Head count: 8 to 32 is typical. More heads with smaller d_k can hurt expressiveness; fewer heads with larger d_k loses specialization. d_k of 64 to 128 is the modern sweet spot.
  • KV-cache memory scales linearly with the number of heads h in vanilla MHA. The motivation for MQA and GQA at long context.

Common pitfalls

  • Equating “more heads” with “more capacity.” Splitting fixes the parameter budget; it is a structural choice, not a scale-up.
  • Reading the post-softmax weights as “what the model attends to.” Heads are mixed in W_O. Single-head probes can be misleading.
  • Treating MHA as the compute bottleneck. At typical context lengths the FFN is the larger block: attention compute scales with n²·d_model + n·d_model² while FFN compute scales with n·d_model², so the FFN dominates until n reaches a few times d_model (a rough count follows this list).
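
A rough per-layer FLOP count that makes the crossover concrete. The constants assume a 4x FFN expansion and count one multiply-add as one op; treat them as order-of-magnitude, not exact:

```python
# Attention ~ 4*n*d^2 (projections) + 2*n^2*d (scores + values); FFN with 4x expansion ~ 8*n*d^2.
# Ratio = 0.5 + n / (4*d), so the FFN dominates until n reaches roughly 2*d.
d = 4096
for n in (1_024, 8_192, 65_536):
    attn = 4 * n * d**2 + 2 * n**2 * d
    ffn = 8 * n * d**2
    print(f"n={n:>6}  attention/FFN = {attn / ffn:.2f}")
# n=  1024  attention/FFN = 0.56
# n=  8192  attention/FFN = 1.00
# n= 65536  attention/FFN = 4.50
```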