
Multi-Head Attention — Step by Step

Walk through every operation in multi-head attention, from raw input to final output, with intuition for each step.

Example: 4 tokens, embedding dim 6, 2 heads (head_dim = 3)

STEP 0

Input — What are Q, K, V?

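Three input matrices of shape (seq_len × embed_dim), plus four learned weight matrices (Wq, Wk, Wv, Wo), each of shape (embed_dim × embed_dim). In self-attention, all three inputs are the same token embeddings.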

📖 Analogy: Think of a library — you walk in with a question (Query), the books have titles on the spine (Keys) and content inside (Values).
🔑 Key Insight: The weight matrices (Wq, Wk, Wv) are learned during training — they transform raw embeddings into useful query/key/value spaces.
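Here is a minimal plain-Python sketch of the Step 0 setup, using no libraries. It assumes self-attention, so Q, K and V all come from one embedding matrix. The random values stand in for learned weights, and `rand_matrix` and `shape` are hypothetical helpers written for this walkthrough.

```python
import random

SEQ_LEN, EMBED_DIM, NUM_HEADS = 4, 6, 2   # the example sizes used throughout
HEAD_DIM = EMBED_DIM // NUM_HEADS          # 3


def rand_matrix(rows, cols, rng):
    """A rows x cols matrix (list of lists) with small random entries."""
    return [[rng.uniform(-1.0, 1.0) for _ in range(cols)] for _ in range(rows)]


def shape(m):
    """(rows, cols) of a list-of-lists matrix."""
    return (len(m), len(m[0]))


rng = random.Random(0)  # fixed seed so the example is reproducible

# Assumption: self-attention, so query, key and value inputs are the same embeddings.
x = rand_matrix(SEQ_LEN, EMBED_DIM, rng)
q_in, k_in, v_in = x, x, x

# Four weight matrices, each (embed_dim x embed_dim); random stand-ins for learned values.
weights = {name: rand_matrix(EMBED_DIM, EMBED_DIM, rng) for name in ("Wq", "Wk", "Wv", "Wo")}

print(shape(q_in), {name: shape(w) for name, w in weights.items()})
```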
STEP 1

Linear Projection — Creating Different Views

Q = Input × Wqᵀ    K = Input × Wkᵀ    V = Input × Wvᵀ

Each input is multiplied by the transpose of its own weight matrix, projecting it into a new space.

📖 Analogy: The librarian gives you a reading guide (Wq), reorganizes the catalog (Wk), and rearranges the book summaries (Wv) — each projection reshapes the information to make matching more useful.
⚠️ What if we didn't? Without projection, you'd compare raw embeddings directly. No flexibility to learn what aspects of a token matter for matching.
🔑 Key Insight: The weight matrix is a learned linear map. It can rotate, stretch and mix the vector's dimensions, moving it into a new space where matching is more meaningful.
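Step 1 can be sketched in plain Python as follows. The sketch follows the Input × Wᵀ convention used above. PyTorch's nn.Linear uses the same convention but also adds a bias, which this sketch omits. `transpose`, `matmul` and `project` are hypothetical helper names.

```python
def transpose(m):
    """Swap rows and columns of a list-of-lists matrix."""
    return [list(col) for col in zip(*m)]


def matmul(a, b):
    """(n x k) @ (k x m) -> (n x m), in plain Python."""
    if len(a[0]) != len(b):
        raise ValueError("inner dimensions must match")
    # zip(*b) iterates over the columns of b
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def project(inp, w):
    """Linear projection without bias: inp x w^T, as in Q = Input x Wq^T."""
    return matmul(inp, transpose(w))


# Tiny 2-dim example with a non-symmetric W, so the transpose actually matters.
inp = [[1, 2], [3, 4]]
w = [[1, 0],
     [1, 1]]
print(project(inp, w))  # [[1, 3], [3, 7]]
```

Multiplying by Wᵀ means that output dimension j is the dot product of the input row with row j of W. For example, for the first token, 1·1 + 2·1 = 3 in dimension 1.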
STEP 2

Reshape — Splitting into Attention Heads

Each projected vector is split into num_heads chunks of size head_dim. For our example: a 6-dim vector → 2 chunks of 3 dims.

📖 Analogy: Two librarians divide the work — one specializes in topic matching (dims 0–2), the other in style matching (dims 3–5). Each gets their own slice of the catalog.
⚠️ What if we didn't? With 1 head, all attention mixes into one pattern. Multiple heads can learn SEVERAL independent relationship types at once.
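🔑 Key Insight: head_dim = embed_dim / num_heads, so embed_dim must be divisible by num_heads. Together, the heads still cover all embed_dim dimensions: we redistribute capacity, not lose it.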
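The split in Step 2 only slices each token's row into contiguous chunks, which is what a row-major reshape to (L, H, D) does. Here is a minimal sketch using a hypothetical `split_heads` helper:

```python
def split_heads(m, num_heads):
    """(L, E) -> (L, H, D): cut each token's E-dim row into H contiguous chunks of D = E // H."""
    embed_dim = len(m[0])
    if embed_dim % num_heads != 0:
        raise ValueError("embed_dim must be divisible by num_heads")
    head_dim = embed_dim // num_heads
    return [[row[h * head_dim:(h + 1) * head_dim] for h in range(num_heads)] for row in m]


# 4 tokens x 6 dims with easy-to-read values: token t, dim j -> 10*t + j
q = [[10 * t + j for j in range(6)] for t in range(4)]
q_split = split_heads(q, num_heads=2)
print(q_split[0])  # token 0 -> [[0, 1, 2], [3, 4, 5]]
```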
STEP 3

Transpose — Making Heads Independent

Reshape gives (seq_len, num_heads, head_dim). Transpose to (num_heads, seq_len, head_dim) so each head becomes an independent batch.

📖 Analogy: Each librarian goes to their own reading room — they can now search the shelves independently without getting in each other's way.
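⚠️ What if we didn't? A batched matmul over (L, H, D) would treat tokens as the batch. It would compute H × H scores between heads for each token, instead of L × L scores between tokens. Total nonsense!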
🔑 Key Insight: Shape (H, L, D) means: for each of H heads, compute attention over L tokens each with D dimensions.
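Step 3 only regroups the data so that each head gets its own (L, D) slice. Here is a minimal sketch, with `transpose_heads` as a hypothetical name:

```python
def transpose_heads(x):
    """(L, H, D) -> (H, L, D): regroup so that heads[h][t] is x[t][h]."""
    seq_len, num_heads = len(x), len(x[0])
    return [[x[t][h] for t in range(seq_len)] for h in range(num_heads)]


# Hypothetical (L=4, H=2, D=3) values that encode their position: 100*token + 10*head + dim
x = [[[100 * t + 10 * h + d for d in range(3)] for h in range(2)] for t in range(4)]
heads = transpose_heads(x)
print(heads[1])  # head 1 sees every token: [[10, 11, 12], [110, 111, 112], [210, 211, 212], [310, 311, 312]]
```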
STEP 4

Attention — Where the Magic Happens ✨

scores = QKᵀ / √d  →  weights = softmax(scores)  →  output = weights × V

For each head independently: compute the scores QKᵀ, scale them by 1/√d (d = head_dim, 3 in our example), apply softmax along each row, then take the weighted average of the value vectors.
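Here is Step 4 for a single head as a minimal plain-Python sketch. The helper names are hypothetical. In the full block, this runs once per head on that head's (L, D) slices.

```python
import math


def softmax(row):
    """Numerically stable softmax of one list of scores."""
    m = max(row)
    exps = [math.exp(s - m) for s in row]
    total = sum(exps)
    return [e / total for e in exps]


def attention_head(q, k, v):
    """One head: q, k, v are (L, D). Returns (weights (L, L), output (L, D))."""
    scale = 1.0 / math.sqrt(len(q[0]))  # 1 / sqrt(head_dim)
    # scores[i][j] = (q_i . k_j) / sqrt(d): how well query i matches key j
    scores = [[sum(a * b for a, b in zip(qi, kj)) * scale for kj in k] for qi in q]
    # softmax along each row, so each query's weights over the keys sum to 1
    weights = [softmax(row) for row in scores]
    # each output row is the weighted average of the value rows
    output = [[sum(w * vj[c] for w, vj in zip(wrow, v)) for c in range(len(v[0]))]
              for wrow in weights]
    return weights, output


# Hypothetical values for one head: 4 tokens, head_dim = 3.
v = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [1.0, 1.0, 1.0]]
k = v
q_zero = [[0.0, 0.0, 0.0]] * 4        # queries with no preference: every score is 0
weights, out = attention_head(q_zero, k, v)
print(weights[0])  # [0.25, 0.25, 0.25, 0.25]: uniform attention
print(out[0])      # [0.5, 0.5, 0.5]: the plain average of the value rows

weights, out = attention_head(v, k, v)  # queries now match keys that share their dimensions
print([round(w, 3) for w in weights[0]])  # largest weights on tokens 0 and 3, which share dim 0 with query 0
```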

📖 Analogy: Each reader walks through the library asking "which books are relevant to me?" The attention weights are like borrowing priority — a score for how much each book's content matters to your question.
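⚠️ What if we didn't scale? The dot product of two d-dim vectors with independent, zero-mean, unit-variance entries has variance d, so with a large head_dim the scores become large. Softmax then saturates toward argmax: nearly all the weight goes to ONE token, and the gradients become tiny. No smooth blending!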
🔑 Key Insight: Each output token is a weighted average of all the value vectors. The weights are set by how well its query matches each key.
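The scaling argument can be checked numerically. This sketch assumes query and key entries drawn independently from a standard normal distribution, so it illustrates the argument rather than showing trained values. With d = 512, the raw dot products have a standard deviation of about √512 ≈ 22.6, while the scaled ones have a standard deviation of about 1.

```python
import math
import random


def softmax(row):
    """Numerically stable softmax of one list of scores."""
    m = max(row)
    exps = [math.exp(s - m) for s in row]
    total = sum(exps)
    return [e / total for e in exps]


rng = random.Random(0)
d = 512  # a large head_dim, chosen for illustration
query = [rng.gauss(0.0, 1.0) for _ in range(d)]
keys = [[rng.gauss(0.0, 1.0) for _ in range(d)] for _ in range(4)]

raw = [sum(a * b for a, b in zip(query, key)) for key in keys]  # unscaled q . k
scaled = [s / math.sqrt(d) for s in raw]                        # q . k / sqrt(d)

print("raw scores:         ", [round(s, 1) for s in raw])
print("max weight unscaled:", round(max(softmax(raw)), 4))
print("max weight scaled:  ", round(max(softmax(scaled)), 4))
```

Dividing the scores by √d is the same as raising the softmax temperature. As a result, the scaled weights are always at least as spread out as the raw ones.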
STEP 5

Merge — Reassembling Perspectives

Concatenate all head outputs back into (seq_len × embed_dim).

📖 Analogy: Each librarian returns with their own stack of relevant books — you gather all the stacks onto one shared reading table, combining every perspective.
🔑 Key Insight: (H, L, D) → (L, H×D) = (L, embed_dim). Just memory rearrangement, no computation.
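Merging is the inverse of Steps 2 and 3. Here is a minimal sketch with hypothetical helper names that also checks the round trip:

```python
def merge_heads(heads):
    """(H, L, D) -> (L, H*D): place each token's head outputs side by side."""
    num_heads, seq_len = len(heads), len(heads[0])
    return [[value for h in range(num_heads) for value in heads[h][t]] for t in range(seq_len)]


def split_and_transpose(m, num_heads):
    """(L, E) -> (H, L, D): Steps 2 and 3 combined, used for the round-trip check."""
    d = len(m[0]) // num_heads
    return [[row[h * d:(h + 1) * d] for row in m] for h in range(num_heads)]


# Hypothetical head outputs: H=2 heads, L=2 tokens, D=3 dims
heads = [[[0, 1, 2], [10, 11, 12]],    # head 1
         [[3, 4, 5], [13, 14, 15]]]    # head 2
print(merge_heads(heads))  # [[0, 1, 2, 3, 4, 5], [10, 11, 12, 13, 14, 15]]

x = [[10 * t + j for j in range(6)] for t in range(4)]
assert merge_heads(split_and_transpose(x, 2)) == x  # merging undoes split + transpose
```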
STEP 6

Output Projection — Final Mix

Output = Merged × Woᵀ

The merged matrix is multiplied by the output weight matrix Wo.

📖 Analogy: The head librarian (Wo) reviews all the stacks on the table and composes a final research summary — blending insights from every librarian into one coherent answer.
⚠️ What if we didn't? Each head's output would stay in its own slice of the vector. The model would have no way to mix information across heads or to learn that "head 1 is about syntax and should be weighted more."
🔑 Key Insight: The ENTIRE MHA block: Output = Concat(head₁, …, headₕ) × Woᵀ, where each headᵢ = softmax(QᵢKᵢᵀ/√d) × Vᵢ and d = head_dim.
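Step 6 is the same kind of projection as Step 1. The hypothetical Wo below is hand-picked rather than learned. It makes output dimension 0 weight head 1's first dimension 4× more than head 2's, which is the kind of cross-head mixing the callout describes.

```python
def transpose(m):
    """Swap rows and columns of a list-of-lists matrix."""
    return [list(col) for col in zip(*m)]


def matmul(a, b):
    """(n x k) @ (k x m) -> (n x m), in plain Python."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def output_projection(merged, wo):
    """(L, E) x Wo^T -> (L, E); bias omitted."""
    return matmul(merged, transpose(wo))


EMBED_DIM = 6
# Hypothetical Wo: start from the identity (each output dim copies its input dim) ...
wo = [[1.0 if i == j else 0.0 for j in range(EMBED_DIM)] for i in range(EMBED_DIM)]
# ... then make output dim 0 = 2.0 * head 1's dim 0 + 0.5 * head 2's dim 0 (merged dims 0 and 3).
wo[0] = [2.0, 0.0, 0.0, 0.5, 0.0, 0.0]

# One merged token: dims 0-2 come from head 1, dims 3-5 from head 2.
merged = [[1.0, 0.0, 0.0, 1.0, 0.0, 0.0]]
out = output_projection(merged, wo)
print(out[0])  # [2.5, 0.0, 0.0, 1.0, 0.0, 0.0]
```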
SUMMARY

The Complete MHA Pipeline

Step  Operation                      Shape transform
0     Input Q, K, V                  (L, E) × 3
1     Linear projection              (L, E) → (L, E) × 3
2     Reshape / split heads          (L, E) → (L, H, D)
3     Transpose                      (L, H, D) → (H, L, D)
4     Scaled dot-product attention   (H, L, D) → weights (H, L, L), output (H, L, D)
5     Merge (concatenate heads)      (H, L, D) → (L, E)
6     Output projection              (L, E) × Woᵀ → (L, E)

L = seq_len, E = embed_dim, H = num_heads, D = head_dim (in our example, L = 4, E = 6, H = 2, D = 3).
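MHA(Q, K, V) = Concat(head₁, …, headₕ) · Woᵀ

headᵢ = softmax(Qᵢ Kᵢᵀ / √dₖ) Vᵢ

Here Qᵢ, Kᵢ and Vᵢ are head i's slices of the projected Q, K and V (Steps 1 to 3), h = num_heads, and dₖ = head_dim.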
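Putting the pieces together, here is the whole pipeline from the summary table as one self-contained plain-Python sketch. It assumes no biases, no masking and no dropout, and every function name is hypothetical.

```python
import math
import random


def transpose(m):
    return [list(col) for col in zip(*m)]


def matmul(a, b):
    """(n x k) @ (k x m) -> (n x m)."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def softmax(row):
    """Numerically stable softmax over one row of scores."""
    m = max(row)
    exps = [math.exp(s - m) for s in row]
    total = sum(exps)
    return [e / total for e in exps]


def split_heads(m, num_heads):
    """Steps 2 + 3: (L, E) -> (H, L, D)."""
    d = len(m[0]) // num_heads
    return [[row[h * d:(h + 1) * d] for row in m] for h in range(num_heads)]


def merge_heads(heads):
    """Step 5: (H, L, D) -> (L, H*D)."""
    return [[val for head in heads for val in head[t]] for t in range(len(heads[0]))]


def attention_head(q, k, v):
    """Step 4 for one head: softmax(q k^T / sqrt(d)) v."""
    scale = 1.0 / math.sqrt(len(q[0]))
    scores = [[sum(a * b for a, b in zip(qi, kj)) * scale for kj in k] for qi in q]
    return matmul([softmax(r) for r in scores], v)


def multi_head_attention(q_in, k_in, v_in, wq, wk, wv, wo, num_heads):
    if len(q_in[0]) % num_heads:
        raise ValueError("embed_dim must be divisible by num_heads")
    # Step 1: linear projections, x W^T (no bias)
    q, k, v = (matmul(m, transpose(w)) for m, w in ((q_in, wq), (k_in, wk), (v_in, wv)))
    # Steps 2-3: split each projection into heads -> (H, L, D)
    qh, kh, vh = (split_heads(m, num_heads) for m in (q, k, v))
    # Step 4: attention runs independently for each head
    heads = [attention_head(qi, ki, vi) for qi, ki, vi in zip(qh, kh, vh)]
    # Steps 5-6: concatenate the heads, then mix them with the output projection
    return matmul(merge_heads(heads), transpose(wo))


def rand_matrix(rows, cols, rng):
    return [[rng.uniform(-1.0, 1.0) for _ in range(cols)] for _ in range(rows)]


rng = random.Random(0)
x = rand_matrix(4, 6, rng)                                  # 4 tokens, embed_dim 6
wq, wk, wv, wo = (rand_matrix(6, 6, rng) for _ in range(4))  # random stand-ins for learned weights
out = multi_head_attention(x, x, x, wq, wk, wv, wo, num_heads=2)  # self-attention
print(len(out), len(out[0]))  # 4 6
```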

All operations are differentiable — the entire block learns end-to-end via backpropagation.
That's the beauty of the Transformer architecture!