KV-Cache
This article walks through a simple but rigorous demonstration of KV caching in transformer autoregressive decoding.
The Attention Mechanism
In autoregressive generation, we generate tokens one at a time. Given a sequence of input embeddings $x_1, \dots, x_t$, we feed the embedding matrix $X \in \mathbb{R}^{B \times t \times d_{\text{model}}}$ into the model to predict $x_{t+1}$. Note that:
- $B$ denotes batch size
- $t$ denotes sequence length
- $d_{\text{model}}$ denotes model dimension
- $n_{\text{vocab}}$ denotes output vocabulary size
We will focus on the self-attention block (where the KV cache is pertinent). For exposition, we omit the batch dimension for now and consider a single sequence, so $X \in \mathbb{R}^{t \times d_{\text{model}}}$. The self-attention block computes:

$$\text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right)V$$
Where:
- $Q = X W_Q \in \mathbb{R}^{t \times d_k}$ (queries)
- $K = X W_K \in \mathbb{R}^{t \times d_k}$ (keys)
- $V = X W_V \in \mathbb{R}^{t \times d_v}$ (values)
The transformer produces an output of shape $t \times d_{\text{model}}$ at the last layer. The output at the final position (call it $h_t$) is then used to generate the logits over the vocabulary for the next token:

$$\text{logits} = h_t W_{\text{out}}, \qquad W_{\text{out}} \in \mathbb{R}^{d_{\text{model}} \times n_{\text{vocab}}}$$
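To keep the shapes straight before moving to the worked example, here is a minimal PyTorch sketch of a single attention pass followed by the logit projection. The dimensions ($t = 4$, $d_{\text{model}} = d_k = d_v = 8$, a vocabulary of 50) are arbitrary values chosen for illustration, not taken from the article's example.

```python
import torch
import torch.nn.functional as F

# Illustrative dimensions (arbitrary, for shape-checking only)
t, d_model, d_k, d_v, vocab_size = 4, 8, 8, 8, 50

X = torch.randn(t, d_model)  # one sequence, batch dimension omitted
W_Q, W_K, W_V = (torch.randn(d_model, d) for d in (d_k, d_k, d_v))
W_out = torch.randn(d_v, vocab_size)

Q, K, V = X @ W_Q, X @ W_K, X @ W_V      # row i of each matrix is q_i / k_i / v_i
scores = Q @ K.T / d_k ** 0.5            # (t, t) attention scores
mask = torch.triu(torch.ones(t, t), diagonal=1).bool()
A = F.softmax(scores.masked_fill(mask, float('-inf')), dim=-1)
O = A @ V                                # (t, d_v)

logits = O[-1] @ W_out                   # next-token logits from the last position
print(Q.shape, K.shape, V.shape, O.shape, logits.shape)
# torch.Size([4, 8]) torch.Size([4, 8]) torch.Size([4, 8]) torch.Size([4, 8]) torch.Size([50])
```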
Worked Example
Let's make this concrete. Suppose we have:
- 2 tokens already: $x_1, x_2$ (each is a 3-dimensional embedding vector, so $d_{\text{model}} = 3$)
- A small projection dimension $d_k = d_v$ for the $Q$, $K$, $V$ projections

Let's set up our input $X$ (with $x_1$ and $x_2$ as its rows) and projection matrices $W_Q$, $W_K$, $W_V$. Notice that each row of $X$ corresponds to one position in the sequence.
Step 1: Generate token 3 (predict $x_3$ given $x_1, x_2$)
We compute $Q = XW_Q$, $K = XW_K$, and $V = XW_V$ for all positions; row $i$ of each matrix is $q_i$, $k_i$, and $v_i$ respectively.
Let's write out the full attention computation (omitting softmax and scaling for clarity):

$$QK^\top = \begin{bmatrix} q_1 k_1^\top & q_1 k_2^\top \\ q_2 k_1^\top & q_2 k_2^\top \end{bmatrix}$$
After applying the causal mask $M$ (setting future positions to $-\infty$) and softmax, we get attention weights $A$:

$$A = \text{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}} + M\right) = \begin{bmatrix} 1 & 0 \\ a_{21} & a_{22} \end{bmatrix}, \qquad a_{21} + a_{22} = 1$$
Importantly, observe that the bottom row of $A$ depends only on $q_2$ and $K$ (through the scores $q_2 k_1^\top$ and $q_2 k_2^\top$). $q_1$ is not used at all. We will see shortly that $q_1$ is in fact not needed to generate the next token.
Finally, the output is $O = AV$:

$$O = \begin{bmatrix} o_1 \\ o_2 \end{bmatrix} = \begin{bmatrix} v_1 \\ a_{21} v_1 + a_{22} v_2 \end{bmatrix}$$
To predict $x_3$, we only need the bottom row of $O$, namely $o_2$. But this in turn only uses the bottom row of $A$, which depends only on $q_2$ (together with all of $K$ and $V$). Hence, we can simplify the computation of the bottom row of $O$ as follows:

$$o_2 = \text{softmax}\!\left(\frac{q_2 K^\top}{\sqrt{d_k}}\right) V$$
Summary: To decode at step $t$, we need:
- Only $q_t$ (the query for the last position)
- All of $K$ (i.e. $k_1, \dots, k_t$)
- All of $V$ (i.e. $v_1, \dots, v_t$)
Now suppose we want to generate token $x_4$. We will need:

$$o_3 = \text{softmax}\!\left(\frac{q_3 K^\top}{\sqrt{d_k}}\right) V, \qquad K = \begin{bmatrix} k_1 \\ k_2 \\ k_3 \end{bmatrix}, \quad V = \begin{bmatrix} v_1 \\ v_2 \\ v_3 \end{bmatrix}$$
Key Insight: Due to the causal mask, $k_1, k_2$ and $v_1, v_2$ did not change. Hence we can cache them and just append $k_3$ and $v_3$ to get the new $K$ and $V$. This is the main idea of the KV cache.
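As a sanity check of this key insight, the sketch below (with randomly generated matrices standing in for the concrete numbers, and an assumed $d_k = 2$) compares the last row of the full masked attention against the single-query computation that reuses cached $K$ and $V$. The two agree.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
d_model, d_k = 3, 2
W_Q, W_K, W_V = (torch.randn(d_model, d_k) for _ in range(3))

def full_masked_attention(X):
    Q, K, V = X @ W_Q, X @ W_K, X @ W_V
    t = X.shape[0]
    scores = Q @ K.T / d_k ** 0.5
    mask = torch.triu(torch.ones(t, t), diagonal=1).bool()
    A = F.softmax(scores.masked_fill(mask, float('-inf')), dim=-1)
    return A @ V

# Sequence of 3 tokens: x1, x2 and the newly appended x3
X = torch.randn(3, d_model)

# (a) Recompute everything from scratch and take the last row
o3_full = full_masked_attention(X)[-1]

# (b) Reuse cached K, V for x1, x2 and only project the new token x3
K_cache, V_cache = X[:2] @ W_K, X[:2] @ W_V
q3, k3, v3 = X[2:] @ W_Q, X[2:] @ W_K, X[2:] @ W_V
K = torch.cat([K_cache, k3])
V = torch.cat([V_cache, v3])
o3_cached = (F.softmax(q3 @ K.T / d_k ** 0.5, dim=-1) @ V)[0]

print(torch.allclose(o3_full, o3_cached))  # True
```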
Naive Implementation (No Cache)
Without caching, at each generation step we recompute K and V for the entire sequence:
```python
import torch
import torch.nn.functional as F


def naive_attention(X, W_q, W_k, W_v):
    """
    X: (batch, seq_len, d_model) - full sequence
    Returns: (batch, seq_len, d_v)
    """
    Q = X @ W_q  # (batch, seq_len, d_k)
    K = X @ W_k  # (batch, seq_len, d_k)
    V = X @ W_v  # (batch, seq_len, d_v)

    d_k = K.shape[-1]
    scores = (Q @ K.transpose(-2, -1)) / (d_k ** 0.5)

    # Causal mask: prevent attending to future tokens
    seq_len = X.shape[1]
    mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()
    scores = scores.masked_fill(mask, float('-inf'))

    attn = F.softmax(scores, dim=-1)
    return attn @ V


def naive_generate(prompt_embeddings, W_q, W_k, W_v, W_out, num_tokens):
    """
    Generate tokens without KV cache.
    Each step recomputes attention over the ENTIRE sequence.
    """
    X = prompt_embeddings  # (1, prompt_len, d_model)
    for step in range(num_tokens):
        cur_len = X.shape[1]

        # Recompute attention for ALL positions (wasteful!)
        attn_out = naive_attention(X, W_q, W_k, W_v)

        # Get logits for last position only
        last_hidden = attn_out[:, -1, :]  # (1, d_v)
        logits = last_hidden @ W_out      # (1, vocab_size)

        # Sample next token (simplified: just argmax)
        next_token_id = logits.argmax(dim=-1)

        # Embed next token and append to sequence
        # (In practice, look up the embedding of next_token_id; here we simulate it)
        next_embedding = torch.randn(1, 1, X.shape[-1])
        X = torch.cat([X, next_embedding], dim=1)

        print(f"Step {step + 1}: Computed K,V for {cur_len} positions")
    return X
```
Complexity without cache: For generating $T$ new tokens with prompt length $t_0$:
- Step 1: Compute K,V for $t_0$ positions
- Step 2: Compute K,V for $t_0 + 1$ positions
- ...
- Step $T$: Compute K,V for $t_0 + T - 1$ positions

Total K,V computations: $\sum_{i=0}^{T-1} (t_0 + i) = T t_0 + \frac{T(T-1)}{2} = O(t_0 T + T^2)$
With KV Cache
The KV cache stores previously computed keys and values. At each step, we only compute K,V for the new token and concatenate with the cache:
```python
def cached_attention(x_new, W_q, W_k, W_v, kv_cache):
    """
    x_new: (batch, 1, d_model) - ONLY the new token
    kv_cache: tuple of (K_cache, V_cache) or None
        K_cache: (batch, prev_seq_len, d_k)
        V_cache: (batch, prev_seq_len, d_v)
    Returns: attention output, updated cache
    """
    # Compute Q, K, V for the NEW token only
    q_new = x_new @ W_q  # (batch, 1, d_k)
    k_new = x_new @ W_k  # (batch, 1, d_k)
    v_new = x_new @ W_v  # (batch, 1, d_v)

    if kv_cache is None:
        K = k_new
        V = v_new
    else:
        K_cache, V_cache = kv_cache
        # Append new K,V to cache
        K = torch.cat([K_cache, k_new], dim=1)  # (batch, seq_len, d_k)
        V = torch.cat([V_cache, v_new], dim=1)  # (batch, seq_len, d_v)

    # Attention: new query attends to ALL keys
    d_k = K.shape[-1]
    scores = (q_new @ K.transpose(-2, -1)) / (d_k ** 0.5)  # (batch, 1, seq_len)

    # No causal mask needed: q_new is at position seq_len,
    # and we only have K,V up to seq_len (no future tokens)
    attn = F.softmax(scores, dim=-1)
    out = attn @ V  # (batch, 1, d_v)

    # Return output and updated cache
    return out, (K, V)


def cached_generate(prompt_embeddings, W_q, W_k, W_v, W_out, num_tokens):
    """
    Generate tokens WITH KV cache.
    Two phases:
      1. Prefill: Process entire prompt at once, build initial cache
      2. Decode: Generate one token at a time, updating cache incrementally
    """
    # === PREFILL PHASE ===
    # Process entire prompt to build initial KV cache
    X = prompt_embeddings  # (1, prompt_len, d_model)
    K_cache = X @ W_k      # (1, prompt_len, d_k)
    V_cache = X @ W_v      # (1, prompt_len, d_v)

    # Compute attention for prompt (need full Q for all positions)
    Q = X @ W_q
    d_k = K_cache.shape[-1]
    scores = (Q @ K_cache.transpose(-2, -1)) / (d_k ** 0.5)
    seq_len = X.shape[1]
    mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()
    scores = scores.masked_fill(mask, float('-inf'))
    attn = F.softmax(scores, dim=-1)
    attn_out = attn @ V_cache

    last_hidden = attn_out[:, -1:, :]  # Keep dim for next iteration
    kv_cache = (K_cache, V_cache)
    print(f"Prefill: Built cache with {K_cache.shape[1]} positions")

    # === DECODE PHASE ===
    for step in range(num_tokens):
        # Get logits from last hidden state
        logits = last_hidden.squeeze(1) @ W_out
        next_token_id = logits.argmax(dim=-1)

        # Embed next token (simulated)
        next_embedding = torch.randn(1, 1, X.shape[-1])

        # Compute attention with cache - only process NEW token!
        last_hidden, kv_cache = cached_attention(
            next_embedding, W_q, W_k, W_v, kv_cache
        )
        print(f"Step {step + 1}: Computed K,V for 1 position (cache size: {kv_cache[0].shape[1]})")

    return kv_cache
```
Complexity with cache: For generating $T$ new tokens with prompt length $t_0$:
- Prefill: Compute K,V for $t_0$ positions (once)
- Each decode step: Compute K,V for 1 position

Total K,V computations: $t_0 + T = O(t_0 + T)$
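To make the two totals concrete, here is a small counting sketch of my own (using prompt length 5, matching the complete example later in the article, and a few choices of $T$):

```python
def naive_kv_count(prompt_len, num_tokens):
    # Step i recomputes K,V for the whole current sequence of prompt_len + i tokens
    return sum(prompt_len + i for i in range(num_tokens))

def cached_kv_count(prompt_len, num_tokens):
    # Prefill computes K,V for the prompt once, then 1 new position per decode step
    return prompt_len + num_tokens

for T in (3, 100, 1000):
    print(T, naive_kv_count(5, T), cached_kv_count(5, T))
# 3 18 8
# 100 5450 105
# 1000 504500 1005
```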
Comparison
| Aspect | Without Cache | With Cache |
|---|---|---|
| K,V computations per step | $O(t_0 + T)$ (entire sequence) | $O(1)$ (new token only) |
| Total K,V computations | $O(t_0 T + T^2)$ | $O(t_0 + T)$ |
| Memory overhead | None | $O\big((t_0 + T)\, d_k\big)$ per layer for K and V |
| Attention score computation | $O\big((t_0 + T)^2\big)$ per step | $O(t_0 + T)$ per step |
The KV cache trades memory for computation. This is almost always worthwhile because:
- Matrix multiplications (K,V projections) are expensive
- Sufficient accelerator memory for the cache is typically available at inference time (though it can become the bottleneck for long sequences and large batches)
- The speedup is dramatic for long sequences
Complete Working Example
```python
import torch
import torch.nn.functional as F

torch.manual_seed(42)

# Dimensions
batch_size = 1
d_model = 64
d_k = d_v = 32
vocab_size = 100
prompt_len = 5
num_generate = 3

# Initialize weights
W_q = torch.randn(d_model, d_k) / (d_model ** 0.5)
W_k = torch.randn(d_model, d_k) / (d_model ** 0.5)
W_v = torch.randn(d_model, d_v) / (d_model ** 0.5)
W_out = torch.randn(d_v, vocab_size) / (d_v ** 0.5)

# Simulate prompt embeddings
prompt = torch.randn(batch_size, prompt_len, d_model)

print("=== WITHOUT KV CACHE ===")
naive_generate(prompt.clone(), W_q, W_k, W_v, W_out, num_generate)

print("\n=== WITH KV CACHE ===")
cached_generate(prompt.clone(), W_q, W_k, W_v, W_out, num_generate)
```
Output:
```
=== WITHOUT KV CACHE ===
Step 1: Computed K,V for 5 positions
Step 2: Computed K,V for 6 positions
Step 3: Computed K,V for 7 positions

=== WITH KV CACHE ===
Prefill: Built cache with 5 positions
Step 1: Computed K,V for 1 position (cache size: 6)
Step 2: Computed K,V for 1 position (cache size: 7)
Step 3: Computed K,V for 1 position (cache size: 8)
```
Why Only Cache K and V (Not Q)?
During autoregressive generation:
- Query (Q): We only need the query for the current position to compute attention scores. Past queries are never reused.
- Keys (K): The current query must attend to all past keys. These are reused every step.
- Values (V): Once we have attention weights, we aggregate all past values. These are reused every step.
Mathematically, for generating the token at position $t + 1$:

$$o_t = \text{softmax}\!\left(\frac{q_t K^\top}{\sqrt{d_k}}\right) V, \qquad K = \begin{bmatrix} k_1 \\ \vdots \\ k_t \end{bmatrix}, \quad V = \begin{bmatrix} v_1 \\ \vdots \\ v_t \end{bmatrix}$$

Only $q_t$, $k_t$, and $v_t$ are new. All $k_i$ and $v_i$ for $i < t$ were computed in previous steps.
Multiple Layers: Each Layer Has Its Own Cache
Real transformers stack self-attention layers. Each layer maintains its own KV cache because:
- Each layer has different projection weights: layer $\ell$ has its own $W_Q^{(\ell)}$, $W_K^{(\ell)}$, and $W_V^{(\ell)}$
- Each layer receives different input: the input to layer $\ell$ is the output of layer $\ell - 1$
Therefore, even for the same token position, the K and V values differ across layers.
```python
def multi_layer_cached_generate(prompt_embeddings, layers, W_out, num_tokens):
    """
    layers: list of dicts, each with W_q, W_k, W_v for that layer
    KV cache structure: list of (K_cache, V_cache) tuples, one per layer

    Note: this simplified stack feeds each layer's attention output directly
    into the next layer, so it assumes d_v == d_model (no output projection,
    MLP, or residual connections).
    """
    num_layers = len(layers)
    X = prompt_embeddings  # (1, prompt_len, d_model)

    # === PREFILL: Build KV cache for ALL layers ===
    kv_caches = []
    hidden = X
    for layer_idx, layer in enumerate(layers):
        W_q, W_k, W_v = layer['W_q'], layer['W_k'], layer['W_v']

        # Compute K, V for this layer's input
        K = hidden @ W_k  # (1, prompt_len, d_k)
        V = hidden @ W_v  # (1, prompt_len, d_v)
        kv_caches.append((K, V))

        # Compute attention output (with causal mask)
        Q = hidden @ W_q
        d_k = K.shape[-1]
        scores = (Q @ K.transpose(-2, -1)) / (d_k ** 0.5)
        seq_len = hidden.shape[1]
        mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()
        scores = scores.masked_fill(mask, float('-inf'))
        attn = F.softmax(scores, dim=-1)
        hidden = attn @ V  # Output becomes input to next layer

    last_hidden = hidden[:, -1:, :]
    print(f"Prefill: Built {num_layers} KV caches, each with {X.shape[1]} positions")

    # === DECODE: Update each layer's cache incrementally ===
    for step in range(num_tokens):
        logits = last_hidden.squeeze(1) @ W_out
        next_token_id = logits.argmax(dim=-1)
        next_embedding = torch.randn(1, 1, X.shape[-1])

        # Forward through all layers, updating each cache
        hidden = next_embedding
        new_caches = []
        for layer_idx, layer in enumerate(layers):
            W_q, W_k, W_v = layer['W_q'], layer['W_k'], layer['W_v']
            K_cache, V_cache = kv_caches[layer_idx]

            # Compute K, V for NEW token only at this layer
            q_new = hidden @ W_q  # (1, 1, d_k)
            k_new = hidden @ W_k  # (1, 1, d_k)
            v_new = hidden @ W_v  # (1, 1, d_v)

            # Append to this layer's cache
            K = torch.cat([K_cache, k_new], dim=1)
            V = torch.cat([V_cache, v_new], dim=1)
            new_caches.append((K, V))

            # Attention: new query attends to all cached K, V
            d_k = K.shape[-1]
            scores = (q_new @ K.transpose(-2, -1)) / (d_k ** 0.5)
            attn = F.softmax(scores, dim=-1)
            hidden = attn @ V  # (1, 1, d_v) - input to next layer

        kv_caches = new_caches
        last_hidden = hidden
        cache_size = kv_caches[0][0].shape[1]
        print(f"Step {step + 1}: Updated {num_layers} caches (each now size {cache_size})")

    return kv_caches
```
Memory for the KV cache with $L$ layers:

$$\text{KV cache size} = 2 \times L \times B \times S \times d_{\text{model}} \times (\text{bytes per element})$$

where $S$ is the current sequence length and $B$ the batch size. The factor of 2 is for K and V. For large models (many layers, large $d_{\text{model}}$) and long sequences, this becomes the dominant memory cost during inference.
Example: Llama 2 70B has:
- $L = 80$ layers
- $d_{\text{head}} = 128$ per head, 64 heads → $d_{\text{model}} = 8192$ total
- For a 4096-token sequence in fp16 (2 bytes per element, assuming full multi-head K,V): $2 \times 80 \times 4096 \times 8192 \times 2 \text{ bytes} \approx 10.7\text{ GB}$ per sequence
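A small helper of my own reproduces this back-of-the-envelope figure. Note that it assumes all 64 heads store their own K and V; Llama 2 70B actually uses grouped-query attention with fewer key/value heads, which shrinks the real cache substantially.

```python
def kv_cache_bytes(num_layers, seq_len, num_heads, head_dim, bytes_per_elem=2, batch_size=1):
    # Factor of 2: both K and V are stored at every layer and position
    return 2 * num_layers * batch_size * seq_len * num_heads * head_dim * bytes_per_elem

size = kv_cache_bytes(num_layers=80, seq_len=4096, num_heads=64, head_dim=128)
print(f"{size / 2**30:.1f} GiB")  # 10.0 GiB
```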
This is why KV cache compression techniques (quantization, eviction, etc.) are active research areas.