Multi-Head Attention
Walk through every operation in multi-head attention, from raw input to final output — with real numbers, interactive heatmaps, and intuition.
Example: 4 tokens, embedding dim 6, 2 heads (head_dim = 3)
STEP 0
Input — What are Q, K, V?
Three input matrices of shape (seq_len × embed_dim), plus four learned weight matrices (Wq, Wk, Wv, Wo), each of shape (embed_dim × embed_dim).
📖 Analogy: Think of a library — you walk in with a question (Query), the books have titles on the spine (Keys) and content inside (Values).
🔑 Key Insight: The weight matrices (Wq, Wk, Wv) are learned during training — they transform raw embeddings into useful query/key/value spaces.
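To make the shapes concrete, here is a minimal NumPy sketch of the running example (4 tokens, embedding dim 6, 2 heads). It assumes self-attention, so the same embeddings feed all three inputs, and it uses random values as stand-ins for trained weights. The variable names are illustrative, not taken from any particular library.

```python
import numpy as np

rng = np.random.default_rng(0)  # fixed seed so the sketch is reproducible

seq_len, embed_dim, num_heads = 4, 6, 2
head_dim = embed_dim // num_heads  # 6 / 2 = 3

# Assumption: self-attention, so one embedding matrix serves as all three inputs.
x = rng.standard_normal((seq_len, embed_dim))
query_in, key_in, value_in = x, x, x

# Four weight matrices, stored as (out_features, in_features).
# Random values stand in for weights that training would normally produce.
Wq, Wk, Wv, Wo = (rng.standard_normal((embed_dim, embed_dim)) for _ in range(4))

print(query_in.shape, Wq.shape, head_dim)  # (4, 6) (6, 6) 3
```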
STEP 1
Linear Projection — Creating Different Views
Q = Input × Wqᵀ
K = Input × Wkᵀ
V = Input × Wvᵀ
Each input is multiplied by its weight matrix, projecting into a new space.
📖 Analogy: The librarian gives you a reading guide (Wq), reorganizes the catalog (Wk), and rearranges the book summaries (Wv) — each projection reshapes the information to make matching more useful.
⚠️ What if we didn't? Without projection, you'd compare raw embeddings directly. No flexibility to learn what aspects of a token matter for matching.
🔑 Key Insight: Each weight matrix applies a learned linear map (it can rotate, scale, and shear) that moves the vector into a new space where matching is more meaningful.
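The projection step can be sketched in a few lines of NumPy. This assumes weights stored as (out_features, in_features), which is why the transpose appears, and it leaves out bias terms, which the walkthrough does not mention. Random values stand in for learned weights.

```python
import numpy as np

rng = np.random.default_rng(0)
seq_len, embed_dim = 4, 6

x = rng.standard_normal((seq_len, embed_dim))     # input embeddings (L, E)
Wq = rng.standard_normal((embed_dim, embed_dim))  # stand-ins for learned weights
Wk = rng.standard_normal((embed_dim, embed_dim))
Wv = rng.standard_normal((embed_dim, embed_dim))

# One matrix multiply per projection; the shape stays (L, E).
Q = x @ Wq.T
K = x @ Wk.T
V = x @ Wv.T

# Row i of Q is simply Wq applied to token i's embedding.
same_as_per_token = np.allclose(Q[0], Wq @ x[0])
print(Q.shape, K.shape, V.shape, same_as_per_token)  # (4, 6) (4, 6) (4, 6) True
```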
STEP 2
Reshape — Splitting into Attention Heads
Each projected vector is split into num_heads chunks of size head_dim. For our example: a 6-dim vector → 2 chunks of 3 dims.
📖 Analogy: Two librarians divide the work — one specializes in topic matching (dims 0–2), the other in style matching (dims 3–5). Each gets their own slice of the catalog.
⚠️ What if we didn't? With 1 head, all attention mixes into one pattern. Multiple heads learn SEVERAL independent relationship types at once.
🔑 Key Insight: head_dim = embed_dim / num_heads (so embed_dim must be divisible by num_heads). We redistribute capacity rather than lose it.
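A small sketch of the split, using consecutive integers in place of real projected values so the chunks are easy to follow by eye. The array here is a made-up stand-in for a projected Q matrix.

```python
import numpy as np

seq_len, embed_dim, num_heads = 4, 6, 2
head_dim = embed_dim // num_heads  # 3

# Stand-in for a projected matrix: token 0 is [0..5], token 1 is [6..11], ...
Q = np.arange(seq_len * embed_dim).reshape(seq_len, embed_dim)

# Split the last axis into (num_heads, head_dim); no values are changed.
Q_heads = Q.reshape(seq_len, num_heads, head_dim)

print(Q_heads.shape)            # (4, 2, 3)
print(Q_heads[0, 0].tolist())   # token 0, head 0: [0, 1, 2]
print(Q_heads[0, 1].tolist())   # token 0, head 1: [3, 4, 5]
```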
STEP 3
Transpose — Making Heads Independent
Reshape gives (seq_len, num_heads, head_dim). Transpose to (num_heads, seq_len, head_dim) so each head becomes an independent batch.
📖 Analogy: Each librarian goes to their own reading room — they can now search the shelves independently without getting in each other's way.
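⚠️ What if we didn't? A batched QKᵀ on (seq_len, num_heads, head_dim) would treat each token as the batch and compare heads with each other, giving (L, H, H) scores instead of the (H, L, L) token-to-token scores we want. Total nonsense!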
🔑 Key Insight: Shape (H, L, D) means: for each of H heads, compute attention over L tokens each with D dimensions.
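The following sketch shows the transpose and why it matters for the batched matrix multiply. It reuses consecutive integers as stand-in values; the comparison of the two score shapes is the point, not the numbers.

```python
import numpy as np

L, H, D = 4, 2, 3
Q_heads = np.arange(L * H * D).reshape(L, H, D)  # (seq_len, num_heads, head_dim)

# Swap the first two axes so the head index becomes the batch axis.
Q_t = Q_heads.transpose(1, 0, 2)                 # (H, L, D)

# Head 0 now holds dims 0-2 of every token.
print(Q_t[0].tolist())  # [[0, 1, 2], [6, 7, 8], [12, 13, 14], [18, 19, 20]]

# Batched matmul uses the leading axis as the batch:
with_transpose = Q_t @ Q_t.transpose(0, 2, 1)             # token vs token, per head
without_transpose = Q_heads @ Q_heads.transpose(0, 2, 1)  # head vs head, per token

print(with_transpose.shape, without_transpose.shape)  # (2, 4, 4) (4, 2, 2)
```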
STEP 4
Attention — Where the Magic Happens ✨
scores = QKᵀ / √dₖ → weights = softmax(scores) → output = weights × V
For each head independently: compute scores, scale by √dₖ (dₖ = head_dim, 3 in our example), apply softmax across each row, then take the weighted average of the values.
📖 Analogy: Each reader walks through the library asking "which books are relevant to me?" The attention weights are like borrowing priority — a score for how much each book's content matters to your question.
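⚠️ What if we didn't scale? The variance of a dot product grows with head_dim, so with large dims the raw QKᵀ scores become large in magnitude → softmax saturates toward argmax → nearly all weight lands on ONE token. No smooth blending, and the gradients through the softmax become tiny.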
🔑 Key Insight: Each output token is a blend of all inputs, weighted by relevance.
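Here is a hedged NumPy sketch of the attention step for all heads at once, followed by a quick comparison of scaled and unscaled scores at a larger head_dim. The helper names `softmax` and `attention` and the `scale` flag are illustrative choices, and masking and dropout are left out.

```python
import numpy as np

def softmax(z, axis=-1):
    # Subtract the row max first for numerical stability.
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def attention(Q, K, V, scale=True):
    """Q, K, V have shape (H, L, D). Returns (output, weights)."""
    d_k = Q.shape[-1]
    scores = Q @ K.transpose(0, 2, 1)        # (H, L, L)
    if scale:
        scores = scores / np.sqrt(d_k)
    weights = softmax(scores, axis=-1)       # each row sums to 1
    return weights @ V, weights              # (H, L, D), (H, L, L)

rng = np.random.default_rng(0)
H, L, D = 2, 4, 3
Q, K, V = (rng.standard_normal((H, L, D)) for _ in range(3))
out, weights = attention(Q, K, V)
print(weights.shape, out.shape)  # (2, 4, 4) (2, 4, 3)

# With a large head_dim, skipping the scale pushes softmax toward one-hot rows.
big_D = 512
Qb, Kb, Vb = (rng.standard_normal((H, L, big_D)) for _ in range(3))
_, w_scaled = attention(Qb, Kb, Vb, scale=True)
_, w_raw = attention(Qb, Kb, Vb, scale=False)
print("mean largest weight, scaled:  ", w_scaled.max(axis=-1).mean())
print("mean largest weight, unscaled:", w_raw.max(axis=-1).mean())
```

The largest weight in each row can only grow when the scale is removed, because dividing by √dₖ acts like raising the softmax temperature.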
STEP 5
Merge — Reassembling Perspectives
Concatenate all head outputs back into (seq_len × embed_dim).
📖 Analogy: Each librarian returns with their own stack of relevant books — you gather all the stacks onto one shared reading table, combining every perspective.
🔑 Key Insight: (H, L, D) → (L, H, D) → (L, H×D) = (L, embed_dim): transpose back, then flatten the last two axes. Just memory rearrangement, no arithmetic.
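The merge can be sketched as a transpose followed by a reshape. Consecutive integers stand in for real head outputs so it is easy to see which head each value came from.

```python
import numpy as np

H, L, D = 2, 4, 3
# Stand-in head outputs: head 0 holds 0..11, head 1 holds 12..23.
head_out = np.arange(H * L * D).reshape(H, L, D)

# Undo the earlier transpose, then flatten (H, D) into one embedding axis.
merged = head_out.transpose(1, 0, 2).reshape(L, H * D)

print(merged.shape)        # (4, 6)
print(merged[0].tolist())  # token 0: [0, 1, 2, 12, 13, 14]

# Same result as concatenating the per-head outputs along the feature axis.
same_as_concat = np.array_equal(
    merged, np.concatenate([head_out[0], head_out[1]], axis=-1)
)
print(same_as_concat)  # True
```

In NumPy the reshape after a transpose copies the data, but no arithmetic is performed on the values.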
STEP 6
Output Projection — Final Mix
Output = Merged × Woᵀ
The merged matrix is multiplied by the output weight matrix Wo.
📖 Analogy: The head librarian (Wo) reviews all the stacks on the table and composes a final research summary — blending insights from every librarian into one coherent answer.
⚠️ What if we didn't? Each head's result would stay confined to its own slice of dimensions, and the model would have no way to learn that "head 1 is about syntax and should be weighted more."
🔑 Key Insight: The ENTIRE MHA block: Output = Concat(head₁, head₂, …) × Woᵀ, where each headᵢ = softmax(QᵢKᵢᵀ/√dₖ) × Vᵢ and dₖ = head_dim.
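A short sketch of the output projection, with a small experiment showing that a dense Wo lets every output dimension draw on every head. Random values stand in for the merged head outputs and for a trained Wo, and no bias term is assumed.

```python
import numpy as np

rng = np.random.default_rng(0)
L, H, D = 4, 2, 3
E = H * D

merged = rng.standard_normal((L, E))  # stand-in for the concatenated head outputs
Wo = rng.standard_normal((E, E))      # stand-in for the learned output weights

output = merged @ Wo.T                # (L, E)

# Zero out the second head's slice (dims 3-5) and project again.
merged_first_head_only = merged.copy()
merged_first_head_only[:, D:] = 0.0
output_first_head_only = merged_first_head_only @ Wo.T

# With a dense Wo, every output dimension changes: the heads are mixed.
changed = ~np.isclose(output, output_first_head_only)
print(output.shape, changed.all())

# With an identity matrix in place of Wo, the heads would pass through unmixed.
unmixed = merged @ np.eye(E).T
print(np.allclose(unmixed, merged))  # True
```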
SUMMARY
The Complete MHA Pipeline
| Step | Operation | Shape Transform |
|---|---|---|
| 0 | Input Q, K, V | (L, E) × 3 |
| 1 | Linear projection | (L, E) → (L, E) × 3 |
| 2 | Reshape / split heads | (L, E) → (L, H, D) |
| 3 | Transpose | (L, H, D) → (H, L, D) |
| 4 | Scaled dot-product attention | → weights (H, L, L), output (H, L, D) |
| 5 | Merge (concatenate heads) | (H, L, D) → (L, E) |
| 6 | Output projection | (L, E) × Woᵀ → (L, E) |
MHA(Q,K,V) = Concat(head₁, …, headₕ) · Woᵀ
headᵢ = softmax(Qᵢ Kᵢᵀ / √dₖ) Vᵢ, with dₖ = head_dim
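The whole pipeline fits in one short function. This is a minimal sketch under stated assumptions, not a reference implementation: no bias terms, no masking, no dropout, no batch axis, and weights stored as (out_features, in_features). The function name `multi_head_attention` and its signature are illustrative.

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)  # numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def multi_head_attention(x_q, x_k, x_v, Wq, Wk, Wv, Wo, num_heads):
    L, E = x_q.shape
    D = E // num_heads                       # head_dim

    # Step 1: linear projections, (L, E) -> (L, E)
    Q, K, V = x_q @ Wq.T, x_k @ Wk.T, x_v @ Wv.T

    # Steps 2-3: split into heads and move the head axis first, -> (H, L, D)
    def split_heads(t):
        return t.reshape(t.shape[0], num_heads, D).transpose(1, 0, 2)
    Qh, Kh, Vh = split_heads(Q), split_heads(K), split_heads(V)

    # Step 4: scaled dot-product attention per head
    weights = softmax(Qh @ Kh.transpose(0, 2, 1) / np.sqrt(D))  # (H, L, L)
    heads = weights @ Vh                                        # (H, L, D)

    # Step 5: merge heads, (H, L, D) -> (L, E)
    merged = heads.transpose(1, 0, 2).reshape(L, E)

    # Step 6: output projection
    return merged @ Wo.T, weights

rng = np.random.default_rng(0)
x = rng.standard_normal((4, 6))  # 4 tokens, embedding dim 6
Wq, Wk, Wv, Wo = (rng.standard_normal((6, 6)) for _ in range(4))

out, attn = multi_head_attention(x, x, x, Wq, Wk, Wv, Wo, num_heads=2)
print(out.shape, attn.shape)  # (4, 6) (2, 4, 4)
```

Passing the same `x` three times gives self-attention; passing a different matrix for the key and value inputs would give cross-attention with the same code.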
All operations are differentiable — the entire block learns end-to-end via backpropagation.
That's the beauty of the Transformer architecture!