Understand the optimization technique in LLMs to speed up token generation

The general overview (Image by author).

The Big Picture

Before we dive into attention heads, KV caches, and the mechanics of generation, it helps to zoom out and see what an autoregressive language model actually is at a glance.

A prompt enters as plain text: “How are you?”. A tokenizer chops it into vocabulary IDs (here 3, 7, 1, 9), prefixed with a BOS ("beginning of sequence") token. Each ID is just an integer pointing into a lookup table: a learned matrix of shape (vocab_size, c) in which every row is the embedding vector for one token in the vocabulary. Selecting the rows for our 5 input IDs produces X, a (5, 4) matrix: five tokens, each living in a 4-dimensional embedding space. This is where text leaves the world of symbols and enters the world of vectors. We use toy dimensions throughout (c = 4 and a 12-token vocabulary) to keep the shapes readable.

From here, X flows through a stack of decoder blocks. Every block shares the same architecture, multi-head self-attention followed by an MLP, and each one transforms its input into a refined representation of the same (5, 4) shape. A key ingredient that makes deep transformers trainable is the residual connection wrapped around every block: instead of replacing its input, each block adds to it (X₁ = X + block_output). Information flows along a continuous "residual stream" that each layer edits rather than overwrites. Stack three of these blocks and you get X₃, the final hidden state.

The last step inverts the first. The unembedding matrix, which in many models is simply the lookup table transposed (weight tying, possible because input and output share the same vocabulary), projects each row of X₃ back into vocabulary space, producing a (5, 12) logits matrix: a score for every vocabulary token at every position. For next-token generation, only the last row matters. Its argmax is the token the model considers most likely to come next. Here, that's token ID 5.

That’s the whole forward pass at altitude.
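To make the shapes concrete, here is a minimal NumPy sketch of the forward pass at altitude. It is a toy under stated assumptions: the weights are random rather than learned, the BOS ID of 0 is made up, the embedding table doubles as the unembedding (weight tying), and `block` is a hypothetical stand-in for a real decoder block whose only job is to keep the (5, 4) → (5, 4) shape contract. Because the weights are random, the predicted ID will generally not be the article's token 5.

```python
import numpy as np

rng = np.random.default_rng(0)

vocab_size, c, n_layers = 12, 4, 3           # toy sizes used in the article
BOS = 0                                       # assumption: BOS gets ID 0
token_ids = np.array([BOS, 3, 7, 1, 9])       # "How are you?" after tokenization

# Lookup table (random here, learned in practice): one c-dim row per vocab entry.
embedding = rng.normal(size=(vocab_size, c))
layer_weights = [rng.normal(size=(c, c)) * 0.5 for _ in range(n_layers)]

def block(x, w):
    """Hypothetical stand-in for attention + MLP: any (T, c) -> (T, c) map."""
    return np.tanh(x @ w)

X = embedding[token_ids]                      # (5, 4): text becomes vectors
for w in layer_weights:
    X = X + block(X, w)                       # residual stream: add, don't overwrite
X3 = X                                        # final hidden state after 3 blocks

logits = X3 @ embedding.T                     # tied unembedding: (5, 4) @ (4, 12) -> (5, 12)
next_id = int(np.argmax(logits[-1]))          # only the last position predicts the next token
print(X3.shape, logits.shape)                 # (5, 4) (5, 12)
```

Weight tying is only one option; swapping `embedding.T` for a separate (4, 12) matrix would model an untied output layer without changing any shapes.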
The rest of this article zooms in on what happens inside one of those decoder blocks, and on the optimization, KV caching, that makes generating long sequences practical.
Let's zoom in and see what happens inside a single decoder layer during the first forward pass.

The Prefill Forward Pass (Image by author)

The Prefill Forward Pass

Before a language model can generate a single new token, it has to process the prompt. This step, called prefill, runs the entire input sequence through the network in one parallel forward pass. Its job is twofold: produce the first predicted token, and populate the KV cache so that subsequent decode steps stay cheap.

Let’s walk through what happens to a 5-token prompt in a tiny model with hidden dimension c = 4, 2 attention heads, and a vocabulary of 12 tokens.

From tokens to Q, K, V

The input X arrives as a (5, 4) matrix: 5 tokens, each represented by a 4-dimensional embedding pulled from the lookup table. Three learned projection matrices Wq, Wk, Wv, each of shape (4, 4), transform X into the query, key, and value matrices Q, K, V, all of shape (5, 4).

Because we have 2 heads, each (5, 4) matrix is split column-wise into two (5, 2) slices, one slice per head. Each head computes attention independently in its own 2-dimensional subspace.

Attention within a head

Inside a single head, attention is a weighted lookup. The head’s Q slice (5, 2) is multiplied by the transpose of its K slice, (2, 5), to produce a (5, 5) matrix of attention scores: every token's query dotted with every token's key. The scores are scaled by 1/√2 (one over the square root of the head dimension), a causal mask sets every entry above the diagonal to −∞ (this is an autoregressive model, so token t must not attend to any token after t), and a row-wise softmax turns each row into a probability distribution over "which past tokens should I pull information from."

These weights then multiply the head’s V slice (5, 2), yielding the head's output of shape (5, 2): each token now holds a context-aware mix of value vectors from its allowed positions.

Concatenation and projection

The two heads’ outputs are concatenated back into a (5, 4) matrix, then passed through a (4, 4) output projection.
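The prefill attention described above fits in a few lines of NumPy. This is a minimal sketch assuming random weights in place of learned ones and the standard 1/√d_head scaling; names such as `X_prime` and `head_outputs` are illustrative, not taken from any library.

```python
import numpy as np

rng = np.random.default_rng(0)
T, c, n_heads = 5, 4, 2
d_head = c // n_heads                          # 2 dimensions per head

X = rng.normal(size=(T, c))                    # embeddings of the 5 prompt tokens
Wq, Wk, Wv, Wo = (rng.normal(size=(c, c)) for _ in range(4))

def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)      # subtract row max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

Q, K, V = X @ Wq, X @ Wk, X @ Wv               # each (5, 4)

# Causal mask: True strictly above the diagonal, i.e. key position > query position.
mask = np.triu(np.ones((T, T), dtype=bool), k=1)

head_outputs = []
for h in range(n_heads):
    cols = slice(h * d_head, (h + 1) * d_head) # this head's column slice
    q, k, v = Q[:, cols], K[:, cols], V[:, cols]   # each (5, 2)
    scores = q @ k.T / np.sqrt(d_head)         # (5, 5), scaled
    scores = np.where(mask, -np.inf, scores)   # hide future positions before softmax
    weights = softmax(scores)                  # each row sums to 1
    head_outputs.append(weights @ v)           # (5, 2)

X_prime = np.concatenate(head_outputs, axis=1) @ Wo   # concat -> (5, 4), then output projection
```

Applying the mask before the softmax (rather than after) is what keeps each row a valid distribution: masked entries become exactly zero and the remaining weights still sum to 1.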
The result, X', is again (5, 4), the same shape as the input, but every row now reflects information gathered from across the sequence.

The MLP

Each token’s vector is then sent independently through a two-layer MLP. W_up of shape (4, 8) expands each row to 8 dimensions, GELU adds non-linearity, and W_down of shape (8, 4) projects back down. The output X₁ is (5, 4), and in a real model it would feed into the next transformer block. Stack a few of these (here, 3 layers) and you have the full forward pass. Let's assume this is the final layer.

Logits and the first prediction

After the final layer, the (5, 4) hidden states are multiplied by the transpose of the (12, 4) unembedding matrix to produce logits of shape (5, 12): a score for every vocabulary token at every position. For generation, only the last row matters: it tells us what the model thinks comes after the fifth and final prompt token. Argmax (or sampling) over that row gives us the first generated token, in our case token ID 5.

What the cache holds onto

Here’s the quiet but crucial part: during this single pass, every layer computed K and V of shape (5, 4) for the prompt. Those tensors get stored. They are everything future tokens will ever need to know about the prompt at this layer. The embeddings, the queries, the MLP activations: all discarded. From here on, generation moves into decode mode, processing one new token at a time and reading from this cache instead of redoing the work.

Now let’s look at the big picture of what happens when we generate the next token with the KV cache.

Second Forward Pass with KV Cache (Image by author)

The Decode Step with KV Cache

Once prefill is done, the model switches into decode mode. Every subsequent token is generated by a forward pass that looks structurally similar to prefill, but operates on just one row at a time, leaning on the KV cache to remember everything that came before.

Let’s continue our example. Prefill predicted token ID 5, so we now feed token 5 back in as the input for the next step.
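Before following token 5 into decode mode, the prefill pieces above can be assembled into one function. The sketch is hedged the same way as before: random weights, a single layer treated as the final one, residual additions and layer norms left out to mirror the walkthrough, and the common tanh approximation of GELU. The hypothetical `prefill` function returns the logits together with the (K, V) pair that would be stored in the cache; nothing else survives the call.

```python
import numpy as np

rng = np.random.default_rng(1)
T, c, n_heads, vocab_size, d_ff = 5, 4, 2, 12, 8
d_head = c // n_heads

# Toy weights for one layer (residual adds and layer norms omitted, as in the text).
Wq, Wk, Wv, Wo = (rng.normal(size=(c, c)) for _ in range(4))
W_up, W_down = rng.normal(size=(c, d_ff)), rng.normal(size=(d_ff, c))
W_unembed = rng.normal(size=(vocab_size, c))           # (12, 4)

def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def gelu(x):
    # tanh approximation of GELU
    return 0.5 * x * (1 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * x**3)))

def prefill(X):
    """Run one layer over the whole prompt; return logits and the (K, V) cache."""
    Q, K, V = X @ Wq, X @ Wk, X @ Wv
    mask = np.triu(np.ones((len(X), len(X)), dtype=bool), k=1)
    heads = []
    for h in range(n_heads):
        s = slice(h * d_head, (h + 1) * d_head)
        scores = np.where(mask, -np.inf, Q[:, s] @ K[:, s].T / np.sqrt(d_head))
        heads.append(softmax(scores) @ V[:, s])
    X_prime = np.concatenate(heads, axis=1) @ Wo       # (T, 4)
    X1 = gelu(X_prime @ W_up) @ W_down                 # MLP, row by row: (T, 8) -> (T, 4)
    logits = X1 @ W_unembed.T                          # (T, 12)
    return logits, (K, V)                              # Q, X', MLP activations are dropped

X = rng.normal(size=(T, c))
logits, (K_cache, V_cache) = prefill(X)
first_token = int(np.argmax(logits[-1]))               # only the last row is used
```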
One token in, one token out

The new input X is a single row of shape (1, 4): the embedding of token ID 5, looked up from the same table used during prefill. The 5 prompt tokens are not re-fed. They don't need to be: everything the model will ever need from them at this layer is already sitting in the cache.

Multiplying this (1, 4) row by Wq, Wk, Wv (each still (4, 4)) yields a fresh Q, K, and V, each of shape (1, 4). Only the new token gets its query, key, and value computed.

Appending to the cache

The newly computed K and V rows are appended to the cached K and V matrices from the previous step. The cache, which held (5, 4) after prefill, now holds (6, 4): five rows from the prompt plus one fresh row for token 5. This concatenated tensor is what attention reads from.

Attention against the cache

Splitting across heads as before, each head now has a query of shape (1, 2) and a full key/value matrix of shape (6, 2). The product Q · Kᵀ produces a (1, 6) score row: the new token's attention weights over all 6 positions, itself included. No causal mask is needed here: every cached position is at or before the current one by construction, so every score is valid.

Softmax turns this into a probability distribution, and the weighted sum over V (6, 2) produces a (1, 2) head output. Concatenating both heads gives (1, 4), and the output projection (4, 4) yields X' of shape (1, 4).

Why this matters

Compare the shapes. Prefill processed a (5, 4) input and ran every operation on 5 rows in parallel, which is necessary to populate the cache. Decode processes a (1, 4) input and runs every operation on a single row, with the cache quietly supplying the historical context where it's needed (inside attention). The MLP, the projections, and the unembedding each do roughly 1/N of the work they would do if the whole N-token sequence were re-run without a cache.

This is what makes long-sequence generation tractable.
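The decode step can be sketched the same way. Under the same toy assumptions (random weights, a random vector standing in for the embedding of token ID 5, and hypothetical helper names `attention` and `decode_step`), the code projects only the new row, appends its K and V to the cache, and attends over all 6 cached positions without a mask. The last lines check the claim from the text: the result matches the final row of a full, masked recomputation over all 6 tokens.

```python
import numpy as np

rng = np.random.default_rng(2)
c, n_heads = 4, 2
d_head = c // n_heads
Wq, Wk, Wv, Wo = (rng.normal(size=(c, c)) for _ in range(4))

def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def attention(Q, K, V, causal):
    """Multi-head attention; Q is (T_q, c), K and V are (T_k, c)."""
    out = []
    for h in range(n_heads):
        s = slice(h * d_head, (h + 1) * d_head)
        scores = Q[:, s] @ K[:, s].T / np.sqrt(d_head)
        if causal:                                    # used for the full T x T pass only
            T = len(Q)
            future = np.triu(np.ones((T, T), dtype=bool), k=1)
            scores = np.where(future, -np.inf, scores)
        out.append(softmax(scores) @ V[:, s])
    return np.concatenate(out, axis=1) @ Wo

def decode_step(x_new, K_cache, V_cache):
    """Project only the new row, append its K/V, attend over the whole cache."""
    q, k, v = x_new @ Wq, x_new @ Wk, x_new @ Wv     # each (1, 4)
    K_cache = np.concatenate([K_cache, k], axis=0)   # (5, 4) -> (6, 4)
    V_cache = np.concatenate([V_cache, v], axis=0)
    x_prime = attention(q, K_cache, V_cache, causal=False)   # (1, 4), no mask needed
    return x_prime, K_cache, V_cache

# Prefill's K/V projections of the 5-token prompt fill the cache.
X_prompt = rng.normal(size=(5, c))
K_cache, V_cache = X_prompt @ Wk, X_prompt @ Wv

# Decode: one new embedding (stand-in for token ID 5's row).
x_new = rng.normal(size=(1, c))
x_prime, K_cache, V_cache = decode_step(x_new, K_cache, V_cache)

# Sanity check: identical to the last row of a full causal pass over all 6 tokens.
X_full = np.concatenate([X_prompt, x_new], axis=0)
full = attention(X_full @ Wq, X_full @ Wk, X_full @ Wv, causal=True)
print(np.allclose(x_prime, full[-1:]))               # True
```

The check is the whole argument in miniature: the full pass does 6 rows of projection work to produce a last row that the cached path gets from 1.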
Without the KV cache, every new token would mean redoing the entire prefill over a sequence one token longer than the last, so the total cost of generating N tokens would grow at least quadratically in N. With it, each new token costs roughly a constant amount of compute for the projections and MLP, plus an attention pass over a cache that grows by one row per step.

Generating a token is, at its core, a small amount of fresh work standing on the shoulders of a lot of remembered work.
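A rough way to see the scaling argument is to count work instead of timing it. The hypothetical helpers below count two proxies for compute when generating `new_tokens` tokens from a `prompt_len`-token prompt: token rows pushed through the per-token layers (projections, MLP, unembedding), and query-key dot products in one attention head, with masked entries counted too. They ignore model width and head count, so treat them as a back-of-the-envelope sketch, not a FLOP model.

```python
def rows_processed(prompt_len, new_tokens, use_cache):
    """Token rows pushed through the projections, MLP, and unembedding."""
    if use_cache:
        # prefill handles the prompt once; each later step handles a single row
        return prompt_len + (new_tokens - 1)
    # without a cache, step i re-runs the full sequence of length prompt_len + i
    return sum(prompt_len + i for i in range(new_tokens))


def attention_scores(prompt_len, new_tokens, use_cache):
    """Query-key dot products in one head (masked entries included)."""
    if use_cache:
        # prefill: prompt_len x prompt_len; decode step i: 1 query vs prompt_len + i keys
        return prompt_len**2 + sum(prompt_len + i for i in range(1, new_tokens))
    return sum((prompt_len + i) ** 2 for i in range(new_tokens))


# The article's 5-token prompt, generating 4 tokens.
for use_cache in (False, True):
    print(use_cache, rows_processed(5, 4, use_cache), attention_scores(5, 4, use_cache))
# False 26 174
# True 8 46
```

The no-cache row count grows with the square of `new_tokens` while the cached count grows linearly; the no-cache attention count grows faster still, roughly with the cube, which is why "at least quadratically" is the safe way to state it.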