KV Cache and Inference Optimization
What Is the KV Cache
Concept
During autoregressive generation, the model generates one token at a time. At each step, it must compute the attention output for the current token, which requires the Key and Value matrices for all preceding tokens.
Without caching, this would be catastrophically expensive:
- Generating token 100 requires computing K and V for tokens 0–99 from scratch
- Generating token 101 requires the same K and V for tokens 0–99 again (redundant work)
- Total compute scales as O(n²) in the number of generated tokens
The KV cache solution: Store the Key and Value projections for all past tokens. At each new generation step, compute K and V only for the new token and append them to the cache.
Without KV cache (step n):
Compute K, V for tokens 0..n → O(n) matrix ops
Total across all steps: O(n²)
With KV cache (step n):
Load K[0..n-1], V[0..n-1] from cache
Compute K[n], V[n] for new token only
Append K[n], V[n] to cache
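Total across all steps: O(n) K/V projections (attention still reads the whole cache at every step, so the attention itself remains O(n²) overall)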
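A toy single-head decode loop makes the saving concrete. This is a minimal sketch under stated assumptions: a 4-dimensional head, random weight matrices, and fixed random vectors standing in for the embeddings of generated tokens. A real decoder would feed back its own sampled tokens and run many layers and heads. Both versions produce the same attention outputs, and only the number of K/V projections differs.

```python
import math
import random

D = 4  # toy head dimension (illustrative assumption)
random.seed(0)

def rand_matrix(rows, cols):
    return [[random.gauss(0.0, 1.0) for _ in range(cols)] for _ in range(rows)]

def matvec(W, x):
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]

def attend(q, keys, values):
    """Scaled dot-product attention of one query over every cached position."""
    scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(D) for k in keys]
    m = max(scores)  # subtract the max for a numerically stable softmax
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    return [sum(w * v[i] for w, v in zip(weights, values)) / total for i in range(D)]

W_q, W_k, W_v = rand_matrix(D, D), rand_matrix(D, D), rand_matrix(D, D)
xs = [[random.gauss(0.0, 1.0) for _ in range(D)] for _ in range(6)]  # stand-in token embeddings

def decode_without_cache(xs):
    outputs, kv_projections = [], 0
    for t in range(len(xs)):
        keys = [matvec(W_k, x) for x in xs[: t + 1]]    # recompute K for tokens 0..t
        values = [matvec(W_v, x) for x in xs[: t + 1]]  # recompute V for tokens 0..t
        kv_projections += 2 * (t + 1)
        outputs.append(attend(matvec(W_q, xs[t]), keys, values))
    return outputs, kv_projections

def decode_with_cache(xs):
    k_cache, v_cache, outputs, kv_projections = [], [], [], 0
    for x in xs:
        k_cache.append(matvec(W_k, x))  # project only the new token...
        v_cache.append(matvec(W_v, x))  # ...and append it to the cache
        kv_projections += 2
        outputs.append(attend(matvec(W_q, x), k_cache, v_cache))  # still scans all cached positions
    return outputs, kv_projections

out_plain, ops_plain = decode_without_cache(xs)
out_cached, ops_cached = decode_with_cache(xs)
print(ops_plain, ops_cached)  # 42 12: K/V projections for 6 tokens without vs with the cache
```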
Two phases of inference:
| Phase | What happens | Bottleneck |
|---|---|---|
| Prefill | Process entire input prompt in parallel (like training) | Compute-bound (many tokens processed at once) |
| Decode | Generate one token at a time, reading KV cache | Memory-bandwidth-bound (reading large KV cache per step) |
Tricky Q: Why is prefill faster per-token than decoding?
Prefill processes all prompt tokens in parallel, so each weight loaded from HBM is reused across many tokens and the GPU runs near its peak compute (compute-bound). Decoding produces one token per step, yet each step must still read all model weights plus the entire KV cache from HBM. Moving that much data for a single token's worth of math leaves most of the compute idle, so decode is memory-bandwidth-bound.
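A back-of-envelope sketch shows why. Counting only the weight matmuls (about 2 FLOPs per parameter per token), a forward pass over n_tokens performs roughly n_tokens FLOPs per byte of BF16 weights read, so prefill reuses each weight many times while single-sequence decode uses it once. The decode floor below assumes every step streams all weights plus the KV cache from HBM. The 2 TB/s bandwidth is a hypothetical figure rather than a specific GPU, and batching, activations, and kernel overheads are ignored.

```python
def arithmetic_intensity(n_params, n_tokens, bytes_per_param=2):
    """FLOPs per byte of weights read for one forward pass (weight matmuls only)."""
    flops = 2 * n_params * n_tokens          # ~1 multiply + 1 add per parameter per token
    bytes_read = n_params * bytes_per_param  # each weight is read from HBM once per pass
    return flops / bytes_read

def decode_step_floor_s(weight_bytes, kv_cache_bytes, hbm_bytes_per_s):
    """Memory-bound lower bound on one decode step for a single sequence."""
    return (weight_bytes + kv_cache_bytes) / hbm_bytes_per_s

N_PARAMS = 8_000_000_000  # roughly LLaMA-3 8B
print(arithmetic_intensity(N_PARAMS, n_tokens=4096))  # prefill of a 4K prompt: 4096.0
print(arithmetic_intensity(N_PARAMS, n_tokens=1))     # one decode step: 1.0

kv_4k = 32 * 4096 * 2 * 8 * 128 * 2  # LLaMA-3 8B KV cache at 4K context, in bytes
step = decode_step_floor_s(N_PARAMS * 2, kv_4k, hbm_bytes_per_s=2e12)  # hypothetical 2 TB/s
print(f"{step * 1e3:.1f} ms per token at best")
```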
KV Cache Memory Math
Concept
KV cache memory is a critical capacity constraint in production. Every generated sequence consumes cache memory proportional to its length.
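Formula per token per layer:
$$\text{bytes per token per layer} = 2 \times n_{\text{kv\_heads}} \times d_{\text{head}} \times \text{dtype\_bytes}$$
Factor of 2: one K matrix + one V matrix.
Total KV cache for a single sequence:
$$\text{total bytes} = 2 \times n_{\text{layers}} \times \text{seq\_len} \times n_{\text{kv\_heads}} \times d_{\text{head}} \times \text{dtype\_bytes}$$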
Example: LLaMA-3 8B in BF16 (2 bytes):
- n_layers = 32, n_kv_heads = 8, d_head = 128, seq_len = 4096 (4K context)
- Per token per layer: 2 × 8 × 128 × 2 = 4096 bytes = 4 KB
- For 4K tokens: 4 KB × 4096 tokens × 32 layers = 512 MB
- For 128K tokens: 4 KB × 131072 tokens × 32 layers = 16 GB, just for the KV cache!
Practical implication: Long-context inference is as much a memory problem as a compute problem. A 70B model serving many concurrent 128K-context sequences can need more GPU memory for KV caches than for its own weights.
def kv_cache_memory_gb(n_layers, seq_len, n_kv_heads, d_head, dtype_bytes=2):
    """Estimate KV cache memory for a single sequence, in GB (1 GB = 1024**3 bytes here)."""
    # layers × tokens × (K and V) × KV heads × head dim × bytes per element
    bytes_total = n_layers * seq_len * 2 * n_kv_heads * d_head * dtype_bytes
    return bytes_total / (1024**3)

# LLaMA-3 8B at different context lengths
for ctx in [2048, 8192, 32768, 131072]:
    gb = kv_cache_memory_gb(n_layers=32, seq_len=ctx, n_kv_heads=8, d_head=128)
    print(f" {ctx:>7,} tokens: {gb:.2f} GB KV cache")
#   2,048 tokens: 0.25 GB
#   8,192 tokens: 1.00 GB
#  32,768 tokens: 4.00 GB
# 131,072 tokens: 16.00 GB ← just KV cache for one sequence
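To put the 70B remark above in numbers, the same arithmetic can be turned around: given a memory budget, how many full-length sequences fit? This sketch makes its assumptions explicit: LLaMA-3 70B's published shape (80 layers, 8 KV heads, d_head = 128), about 70B parameters in BF16, a hypothetical node of 8 × 80 GB GPUs, and no allowance for activations, CUDA context, or framework overhead, so real capacity is lower.

```python
def kv_bytes_per_sequence(n_layers, seq_len, n_kv_heads, d_head, dtype_bytes=2):
    return n_layers * seq_len * 2 * n_kv_heads * d_head * dtype_bytes

def max_concurrent_sequences(total_mem_bytes, weight_bytes, kv_bytes_per_seq):
    """Sequences that fit after the weights are loaded (ignores activations and overhead)."""
    free = total_mem_bytes - weight_bytes
    return max(free // kv_bytes_per_seq, 0)

kv_128k = kv_bytes_per_sequence(n_layers=80, seq_len=131072, n_kv_heads=8, d_head=128)
print(kv_128k / 1024**3)  # 40.0 GiB of KV cache for ONE 128K-token sequence

weights = 70_000_000_000 * 2  # ~70B params in BF16
node = 8 * 80 * 10**9         # hypothetical 8 x 80 GB GPUs
print(max_concurrent_sequences(node, weights, kv_128k))  # 11
```

At that point the KV caches (11 × 40 GiB, about 440 GiB) take more than three times the memory of the roughly 130 GiB of weights.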
MHA vs MQA vs GQA
Concept
These three variants trade KV cache size for generation quality.
Multi-Head Attention (MHA):
- Each of the h query heads has its own distinct K and V projection
- KV cache: 2 × h × d_head per token per layer
- Best quality, but the most memory-intensive KV cache
Multi-Query Attention (MQA):
- All h query heads share a single K and V projection
- KV cache: 2 × 1 × d_head per token per layer (1/h of MHA), an h× reduction in KV cache memory
- Quality slightly lower, since all heads see the same K/V space
- Used by: Falcon-7B, Gemma 1 (2B), and other early efficiency-focused models
Grouped-Query Attention (GQA):
- h query heads split into G groups; each group shares one K/V pair
- KV cache: 2 × G × d_head per token per layer (G/h the size of MHA, i.e. an h/G× reduction)
- LLaMA-3 8B: h = 32, G = 8 → 4× smaller KV cache than MHA
- Quality nearly identical to MHA for most tasks
- Used by: LLaMA-3, Gemma 2, Mistral; the current production standard
MHA (h=32): K₁V₁ K₂V₂ K₃V₃ ... K₃₂V₃₂ (32 KV pairs per token)
GQA (G=8): K₁V₁ K₂V₂ ... K₈V₈ (8 KV pairs per token, groups of 4)
MQA: K V (1 KV pair per token)
Memory comparison for LLaMA-3 8B at 32K context:
MHA (G=32): 32 × 32768 × 2 × 32 × 128 × 2 bytes = 16 GB
GQA (G=8): 32 × 32768 × 2 × 8 × 128 × 2 bytes = 4 GB ← LLaMA-3 actual
MQA (G=1): 32 × 32768 × 2 × 1 × 128 × 2 bytes = 0.5 GB
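In GQA, query head q simply reads the K/V of group q // (h / G). The sketch below shows that mapping and an expand-by-repetition step that lets a standard per-head attention routine consume a GQA cache. The function names are illustrative, strings stand in for K/V tensors, and many optimized kernels (FlashAttention, for example) index the shared K/V directly instead of materializing copies.

```python
def kv_head_for_query_head(q_head, n_heads, n_kv_heads):
    """Map a query head index to the KV head (group) it shares."""
    assert n_heads % n_kv_heads == 0, "h must be divisible by G"
    group_size = n_heads // n_kv_heads
    return q_head // group_size

def expand_kv(kv_per_group, n_heads):
    """Repeat each group's K (or V) so every query head gets one: G entries -> h entries."""
    group_size = n_heads // len(kv_per_group)
    return [kv for kv in kv_per_group for _ in range(group_size)]

h, G = 32, 8  # LLaMA-3 8B
print([kv_head_for_query_head(q, h, G) for q in range(8)])  # [0, 0, 0, 0, 1, 1, 1, 1]
cache_per_token = [f"KV{g}" for g in range(G)]  # what is actually stored: 8 entries
per_head_view = expand_kv(cache_per_token, h)   # what the 32 query heads read
print(len(cache_per_token), len(per_head_view), h // G)  # 8 32 4 -> 4x smaller cache than MHA
```

Setting G = h recovers MHA and G = 1 recovers MQA, so the same mapping covers all three variants.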
Paged Attention (vLLM)
Concept
Production LLM serving has a fundamental memory fragmentation problem. Traditional KV cache allocation reserves a contiguous memory region per sequence, sized for the maximum context length. This wastes memory because:
1. Most sequences are much shorter than the maximum context
2. Memory is reserved upfront but used gradually as tokens are generated
3. Different sequences have different lengths → external fragmentation
Paged Attention (Kwon et al., 2023 — the key innovation behind vLLM) borrows the virtual memory concept from operating systems:
Physical GPU memory is divided into fixed-size "blocks" (e.g., 16 tokens each)
For each sequence, a "page table" maps logical positions to physical blocks:
Sequence A: [block 3, block 7, block 12, ...] (non-contiguous physical)
Sequence B: [block 1, block 4, ...]
As a sequence grows, new blocks are allocated on demand — no pre-reservation
When a sequence finishes, its blocks are freed and immediately reusable
Results:
- Near-zero memory waste from fragmentation (< 4% vs ~60–80% with contiguous allocation)
- Higher GPU utilization → more sequences in flight simultaneously → 2–4× higher throughput
- Enables efficient KV cache sharing for prefix caching (see below)
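The block-table bookkeeping fits in a few dozen lines. This is a minimal sketch, not vLLM's implementation: the class and method names are hypothetical, the 16-token block size follows the example above, and it tracks only which physical block and offset hold each token (no actual K/V tensors). It also compares reserved slots against a contiguous allocator that reserves a hypothetical max_len of 4096 per sequence.

```python
class BlockAllocator:
    """Toy paged KV cache: fixed-size physical blocks plus a per-sequence block table."""

    def __init__(self, num_blocks, block_size=16):
        self.block_size = block_size
        self.free_blocks = list(range(num_blocks))
        self.block_tables = {}  # seq_id -> physical block ids (need not be contiguous)
        self.lengths = {}       # seq_id -> number of tokens stored

    def append_token(self, seq_id):
        """Reserve a slot for one new token, taking a new block only when the last is full."""
        table = self.block_tables.setdefault(seq_id, [])
        n = self.lengths.get(seq_id, 0)
        if n % self.block_size == 0:  # last block is full (or the sequence has none yet)
            if not self.free_blocks:
                raise MemoryError("out of KV blocks; a real scheduler would preempt")
            table.append(self.free_blocks.pop())
        self.lengths[seq_id] = n + 1
        return table[n // self.block_size], n % self.block_size  # (physical block, offset)

    def free(self, seq_id):
        """Return a finished sequence's blocks to the pool for immediate reuse."""
        self.free_blocks.extend(self.block_tables.pop(seq_id, []))
        self.lengths.pop(seq_id, None)

alloc = BlockAllocator(num_blocks=256, block_size=16)
seq_lens = {"A": 500, "B": 130, "C": 1000}
for step in range(max(seq_lens.values())):  # interleave decode steps, as a server would
    for seq_id, n_tokens in seq_lens.items():
        if step < n_tokens:
            alloc.append_token(seq_id)

used = sum(seq_lens.values())
paged_slots = sum(len(t) for t in alloc.block_tables.values()) * alloc.block_size
contiguous_slots = len(seq_lens) * 4096  # reserve max_len up front per sequence
print(f"paged waste: {1 - used / paged_slots:.0%}, contiguous waste: {1 - used / contiguous_slots:.0%}")
# paged waste: 2%, contiguous waste: 87%
alloc.free("A")
print(len(alloc.free_blocks))  # 184 blocks back in the pool
```

With paging, waste is confined to the partially filled last block of each sequence, which is why it stays small as sequences grow.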
Prefix Caching
Concept
Many production workloads have repeated prompt prefixes:
- Chatbot: same system prompt for every conversation
- RAG: same retrieved context chunks for many queries
- Agent: same tool definitions in every turn
Prefix caching: Compute the KV cache for the shared prefix once, store it, and reuse it across all requests that share that prefix.
Request 1: [SYSTEM_PROMPT][DOCS][User: "Summarize?"]
↑ compute KV cache
Request 2: [SYSTEM_PROMPT][DOCS][User: "What is the key point?"]
↑ reuse KV cache from request 1 (prefix match)
↑ only compute KV for "What is the key point?"
ROI:
- System prompts are often 500–2000 tokens
- At 1000 req/min with a 1000-token system prompt, prefix caching avoids re-prefilling roughly 1,000,000 prompt tokens per minute (every request after the first hits the cache)
- Effective prefill latency improvement: often a 50–90% reduction for cacheable content
Paged Attention's block-based addressing makes prefix caching efficient: the block tables of different requests can point at the same physical blocks, which are reference-counted and copied only if a request needs to modify a shared block (copy-on-write).
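One way to find reusable prefixes is to key each full block by all tokens up to and including it, so two requests share a physical block only when their entire prefixes match. The sketch below illustrates that idea with reference counts. The names are hypothetical, token IDs are plain integers, a tuple of the prefix stands in for a hash, and eviction, reference release, and blocks for the partial trailing tokens are omitted.

```python
class PrefixCache:
    """Toy prefix cache: each full block is keyed by the whole token prefix ending at it."""

    def __init__(self, block_size=16):
        self.block_size = block_size
        self.blocks = {}    # prefix key -> physical block id
        self.refcount = {}  # physical block id -> number of requests using it
        self.next_block = 0

    def admit(self, tokens):
        """Return (block table for the full blocks, number of prompt tokens needing prefill)."""
        table, reused_tokens = [], 0
        for i in range(len(tokens) // self.block_size):
            key = tuple(tokens[: (i + 1) * self.block_size])  # whole prefix, not just this block
            if key in self.blocks:
                block = self.blocks[key]  # cache hit: reuse the stored K/V block
                reused_tokens += self.block_size
            else:
                block = self.next_block   # cache miss: allocate and (conceptually) prefill
                self.next_block += 1
                self.blocks[key] = block
            self.refcount[block] = self.refcount.get(block, 0) + 1
            table.append(block)
        return table, len(tokens) - reused_tokens

cache = PrefixCache(block_size=16)
system_prompt = list(range(1000, 1032))  # 32 shared tokens = 2 full blocks
t1, prefill1 = cache.admit(system_prompt + [1, 2, 3, 4, 5])
t2, prefill2 = cache.admit(system_prompt + [9, 8, 7])
print(prefill1, prefill2)               # 37 3: request 2 prefills only its own suffix
print(t1 == t2, cache.refcount[t1[0]])  # True 2: same physical blocks, shared twice
```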
Speculative Decoding
Concept
Speculative decoding uses a small "draft" model to propose the next K tokens (cheaply, one after another), then verifies all of them with the large "target" model in a single forward pass. This replaces up to K sequential target decode steps with one batched verification step.
Why this works:
- The small draft model (e.g., 1B) is fast but lower quality
- The large target model (e.g., 70B) is high quality but slow
- Key insight: decode is memory-bandwidth-bound, so scoring K candidate tokens in one target forward pass costs about as much as generating a single token. If the draft model gets the next K tokens right (which it often does for common phrases), the target model accepts all K at once: K tokens for the cost of ~1 decode step
Standard decode (3 tokens):
step 1: target model → token A (full forward pass)
step 2: target model → token B (full forward pass)
step 3: target model → token C (full forward pass)
Total: 3 full forward passes
Speculative decode (3 tokens):
step 1: draft model → proposes [A, B, C] (3 cheap forward passes)
step 2: target model verifies [A, B, C] in ONE batch forward pass
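- If all 3 correct: accept all 3, plus one bonus token from the target's prediction at the next position
- If only first 2 correct: accept 2, reject C, sample a correction token from the target
Total: 1 full target forward pass (plus 3 cheap draft passes) for up to 4 tokens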
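The accept/reject loop can be sketched with two toy "models" that are just deterministic next-token functions whose rules are invented for illustration. This is the greedy variant only: a drafted token is accepted if it equals the target's argmax. Sampling-based speculative decoding instead accepts each drafted token with probability min(1, p_target / p_draft) and resamples on rejection, which preserves the target's output distribution. In a real system the target's k + 1 predictions come from one batched forward pass; here a list comprehension stands in for that pass.

```python
def target_next(seq):
    """Toy 'large model': deterministic greedy next token (invented rule)."""
    return (seq[-1] * 3 + 1) % 10

def draft_next(seq):
    """Toy 'small model': agrees with the target except right after token 4."""
    return 0 if seq[-1] == 4 else target_next(seq)

def speculative_step(seq, k=3):
    # 1. The draft proposes k tokens autoregressively (k cheap forward passes).
    draft, ctx = [], list(seq)
    for _ in range(k):
        tok = draft_next(ctx)
        draft.append(tok)
        ctx.append(tok)
    # 2. The target scores all k + 1 positions; this stands in for ONE batched forward pass.
    target_preds = [target_next(seq + draft[:i]) for i in range(k + 1)]
    # 3. Keep the longest agreeing prefix, then take the target's own token.
    accepted = []
    for proposed, expected in zip(draft, target_preds):
        if proposed != expected:
            return seq + accepted + [expected]  # correction replaces the first wrong draft
        accepted.append(proposed)
    return seq + accepted + [target_preds[k]]  # all k accepted: bonus token for free

def speculative_generate(prompt, n_new, k=3):
    seq, target_passes = list(prompt), 0
    while len(seq) < len(prompt) + n_new:
        seq = speculative_step(seq, k)
        target_passes += 1
    return seq[: len(prompt) + n_new], target_passes

def greedy_generate(prompt, n_new):
    seq = list(prompt)
    for _ in range(n_new):
        seq.append(target_next(seq))  # one target forward pass per token
    return seq

spec, passes = speculative_generate([1], n_new=6)
print(spec, passes)             # [1, 4, 3, 0, 1, 4, 3] 2
print(greedy_generate([1], 6))  # [1, 4, 3, 0, 1, 4, 3], but it needs 6 target passes
```

The output is identical to plain greedy decoding with the target; only the number of target passes changes.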
Speedup factors:
- 2–3× for common tasks with high draft model acceptance rates
- Speedup is higher for tasks with more predictable token sequences (code, structured output, common phrases)
- Speedup is lower for highly creative or diverse generation
Variants and adoption:
- Medusa: adds multiple decoding heads to the original model (no separate draft model)
- SpecInfer, Lookahead Decoding: variations on the theme
- Used in: open-source serving stacks such as vLLM, and in large-scale production serving (Google has publicly described using it in Search)
Continuous Batching
Concept
Static batching (naive approach): wait until a fixed batch of N requests is assembled, decode the whole batch step by step, and return results only when every sequence is done. Problem: sequences finish at different times, so the batch slots of finished sequences sit idle until the longest sequence in the batch finishes.
Static batch of 4 sequences:
seq A: ████ (4 tokens needed)
seq B: ██ (2 tokens needed)
seq C: ██████████ (10 tokens needed)
seq D: ███ (3 tokens needed)
GPU waits until seq C (10 tokens) finishes → seqs A, B, D waste 6, 8, 7 slots
Continuous batching (in-flight batching): After each decode step, check if any sequences have completed. Remove completed sequences and insert new waiting requests into the batch immediately.
Step 1: [A, B, C, D]
Step 2: [A, B, C, D] ← B completes
Step 3: [A, E, C, D] ← new request E fills B's slot immediately
...
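A step-level simulation makes the difference measurable. This sketch assumes every active sequence emits exactly one token per step, that a new request can join with no prefill cost, and a batch of 4 slots. The first four lengths match sequences A, B, C, D above; the next four (5, 2, 6, 3) are hypothetical queued requests.

```python
from collections import deque

def static_batching_steps(lengths, batch_size):
    """Each fixed batch runs until its longest sequence finishes."""
    return sum(max(lengths[i:i + batch_size]) for i in range(0, len(lengths), batch_size))

def continuous_batching_steps(lengths, batch_size):
    """Refill free slots from the queue after every decode step."""
    queue, active, steps = deque(lengths), [], 0
    while queue or active:
        while queue and len(active) < batch_size:
            active.append(queue.popleft())         # admit waiting requests into free slots
        active = [r - 1 for r in active if r > 1]  # one decode step; drop sequences that just finished
        steps += 1
    return steps

lengths = [4, 2, 10, 3, 5, 2, 6, 3]  # tokens each request must generate
useful = sum(lengths)
for name, fn in [("static", static_batching_steps), ("continuous", continuous_batching_steps)]:
    steps = fn(lengths, batch_size=4)
    print(f"{name:>10}: {steps} steps, slot utilization {useful / (steps * 4):.0%}")
# static: 16 steps, continuous: 10 steps for the same 35 tokens
```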
Why this matters:
- GPU utilization jumps from ~30–50% (static) to ~80–95% (continuous)
- Throughput (tokens/second) increases by 5–10× in typical workloads
- The major LLM serving systems (vLLM, TGI, SGLang) all use continuous batching
Decode Latency vs Throughput Trade-off
Concept
Batch size and latency vs throughput:
- Small batch (1 request): lowest latency (TTFT + decode), GPU underutilized, low throughput
- Large batch (many requests): GPU fully utilized, high throughput, but each individual request waits longer (queuing + longer decode steps)
- Latency and throughput are fundamentally at odds: you must tune batch size for your SLO
Time-To-First-Token (TTFT): How long from request submission to the first generated token. Dominated by:
1. Queue waiting time (if server is busy)
2. Prefill compute (processing the input prompt)
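Tokens Per Second (TPS) / throughput: How fast new tokens are generated after TTFT. Per request, this is the inverse of the time per decode step; summed over all concurrent requests, it is the server's throughput. Dominated by:
1. Decode speed per step (KV cache read + attention + FFN)
2. Number of concurrent requests sharing the GPU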
Rule of thumb: For user-facing chat applications, TTFT < 500ms is usually required. For batch document processing, throughput matters more than TTFT.
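These metrics combine into a simple latency model: a request's end-to-end time is roughly TTFT plus one inter-token interval (often called time per output token, TPOT) for each token after the first. The operating points below are hypothetical, chosen only to show the shape of the batch-size trade-off, and the model ignores queueing and the growth of TPOT with context length.

```python
def request_latency_s(ttft_s, n_output_tokens, tpot_s):
    """End-to-end latency of one request: first token, then n - 1 decode steps."""
    return ttft_s + (n_output_tokens - 1) * tpot_s

def server_throughput_tok_s(batch_size, tpot_s):
    """Aggregate decode throughput when batch_size sequences each emit a token per step."""
    return batch_size / tpot_s

# Hypothetical operating points: a bigger batch makes each decode step slower.
for batch_size, tpot_s in [(1, 0.015), (32, 0.030)]:
    latency = request_latency_s(ttft_s=0.2, n_output_tokens=200, tpot_s=tpot_s)
    throughput = server_throughput_tok_s(batch_size, tpot_s)
    print(f"batch {batch_size:>2}: {latency:.2f} s per request, {throughput:,.0f} tokens/s total")
```

In this made-up example, the larger batch roughly doubles each request's latency while raising total throughput about 16×.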
Comprehensive Speed-Up Techniques Reference
Concept
A consolidated reference of all major LLM inference and training speed-up techniques. Many are covered in depth elsewhere — this table gives you the full landscape for interviews.
| Technique | How It Works | Speedup / Savings | Where Covered |
|---|---|---|---|
| Quantization | Reduce weight/activation precision (FP16→INT8→INT4) | 2–4× memory, 1.5–3× latency | GPU & Hardware |
| KV-Cache Quantization | Store KV cache in INT8/INT4 instead of FP16 | Reduces KV memory 2–4× | This file |
| Flash Attention | Tiling + recomputation so the full n × n attention matrix is never materialized: compute stays O(n²) but memory is O(n) | Attention memory O(n²) → O(n) (savings grow with sequence length), ~2× speed | Attention Mechanisms |
| Speculative Decoding | Small draft model proposes K tokens; large target verifies all in one pass | 2–3× decode speedup | This file |
| LoRA (at inference) | Merged LoRA weights add zero latency; kept unmerged, many adapters can share one base model at a small per-request overhead | Zero overhead vs base (merged) | Fine-Tuning |
| Pruning | Remove low-magnitude weights (unstructured) or entire heads/layers (structured). Structured pruning is inference-friendly; unstructured needs sparse hardware support. | 10–50% size, 10–30% speedup | — |
| Knowledge Distillation | Train a smaller "student" model to mimic a larger "teacher" via soft probability targets (not just hard labels). Result: student achieves near-teacher quality at fraction of size. | 3–10× smaller model | — |
| Weight Sharing | Share weight matrices across layers or sub-components (ALBERT uses cross-layer parameter sharing). Reduces parameter count without a full distillation pipeline, but not compute: a shared layer still runs once per use. | 2–4× smaller | — |
| Sparse Attention | Replace full O(n²) attention with local windows, global tokens, or hash-based routing (Longformer, BigBird, Reformer) | O(n log n) or O(n) attention | Attention Mechanisms |
| Batching & Dynamic Batching | Group multiple requests into one GPU pass; dynamic = fill slots as requests arrive/complete | 5–10× throughput | This file (continuous batching) |
| Model Serving Optimization | Frameworks (vLLM, TGI, SGLang) combining paged attention, continuous batching, prefix caching in one stack | Combined 10–20× improvement | Production Deployment |
| Tensor Parallelism | Split individual weight matrices across GPUs column/row-wise; each GPU holds a slice | Lower per-token latency, but sublinear in GPU count (all-reduce communication every layer) | GPU & Hardware |
| Pipeline Parallelism | Assign different transformer layers to different GPUs — pipeline them with micro-batches | Enables models too large for one GPU | GPU & Hardware |
| Paged Attention | Virtual memory for KV cache: non-contiguous blocks, eliminates fragmentation, enables prefix sharing | Near-100% KV cache memory utilization (< 4% waste) | This file |
| Mixed Precision Inference | Run matmuls in FP16/BF16 on dedicated tensor cores while keeping numerically sensitive steps (accumulation, softmax, normalization) in FP32. FP32 master weights are a training technique; inference typically stores weights directly in FP16/BF16. | ~2× speed and half the weight memory vs FP32, near-FP32 quality | GPU & Hardware |
| Early Exit / Token-Level Pruning | Shallow layers output confident predictions early — skip remaining layers for "easy" tokens or inputs. Works best on classification; harder to implement for generation. | 20–50% compute reduction on easy inputs | — |
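As a concrete instance of the Quantization and KV-Cache Quantization rows, here is symmetric per-tensor absmax INT8 quantization, one of the simplest schemes. It is a sketch with illustrative helper names; production methods typically use per-channel or per-group scales and handle outliers more carefully.

```python
def quantize_int8(values):
    """Symmetric absmax quantization: one FP scale per tensor, int8 codes in [-127, 127]."""
    scale = max(abs(v) for v in values) / 127 or 1.0  # guard against an all-zero tensor
    codes = [max(-127, min(127, round(v / scale))) for v in values]
    return codes, scale

def dequantize_int8(codes, scale):
    return [c * scale for c in codes]

weights = [0.42, -1.27, 0.003, 0.9, -0.35]
codes, scale = quantize_int8(weights)
restored = dequantize_int8(codes, scale)
print(codes)  # int8 codes: 1 byte each instead of 2 (BF16) or 4 (FP32)
print(max(abs(a - b) for a, b in zip(weights, restored)) <= scale / 2)  # error stays within half a step
```

Small values such as 0.003 collapse to 0 because a single scale must also cover the largest magnitude, which is why finer-grained scales matter in practice.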
Most impactful combination in production:
Quantization (INT8/INT4) → 2–4× less weight memory vs FP16
+ Flash Attention → efficient long context
+ Paged Attention (vLLM) → max GPU utilization
+ Continuous batching → max throughput
+ Speculative decoding (optional) → latency for interactive use
Pruning vs Distillation: when to use each:
- Pruning: Already have a large model you want to compress; best for structured pruning (remove whole heads/layers); requires hardware that exploits sparsity for unstructured gains.
- Distillation: Want a general-purpose smaller model trained from scratch with teacher guidance; better final quality than pruning at the same size; requires a training pipeline.
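The soft-target idea behind distillation fits in a few lines. This sketch follows the classic temperature-scaled formulation (Hinton et al., 2015): soften teacher and student logits with a temperature T, take the KL divergence from teacher to student, and scale by T² so gradient magnitudes stay comparable. In practice this term is usually mixed with the ordinary hard-label cross-entropy; that mixing weight is a tuning choice and is omitted here.

```python
import math

def softmax(logits, temperature=1.0):
    scaled = [z / temperature for z in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def distillation_loss(student_logits, teacher_logits, temperature=2.0):
    """KL(teacher || student) on temperature-softened distributions, scaled by T^2."""
    p_teacher = softmax(teacher_logits, temperature)
    p_student = softmax(student_logits, temperature)
    kl = sum(p * math.log(p / q) for p, q in zip(p_teacher, p_student) if p > 0)
    return kl * temperature ** 2

teacher = [4.0, 1.5, 0.5]  # confident in class 0, but ranks class 1 above class 2
print(distillation_loss([4.0, 1.5, 0.5], teacher))      # 0.0: student matches the teacher
print(distillation_loss([4.0, 0.5, 1.5], teacher) > 0)  # True: same top-1, wrong ranking of the rest
```

The second student would score perfectly against hard labels, yet the soft-target loss still penalizes it, which is the extra signal distillation provides.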
Study Notes
Must-know for interviews:
- KV cache stores K and V for all past tokens per layer, avoiding O(n²) recomputation during decode
- Memory per token per layer = 2 × n_kv_heads × d_head × bytes (know how to derive this)
- GQA reduces KV cache memory by sharing K/V across groups of heads; LLaMA-3 and Gemma 2 use this
- Paged Attention (vLLM) uses virtual memory for KV blocks, eliminating fragmentation and enabling prefix sharing
- Prefix caching reuses the KV cache for shared prompt prefixes, with high ROI for chatbot system prompts
- Speculative decoding: draft proposes K tokens, target verifies them in one pass → 2–3× decode speedup
- Continuous batching: remove finished sequences and insert new ones mid-batch → 5–10× throughput vs static batching
- Prefill is compute-bound; decode is memory-bandwidth-bound
Quick recall Q&A:
- What two phases does LLM inference have? Prefill (process the prompt in parallel) and decode (generate one token at a time).
- Why does a large batch size improve throughput but hurt latency? More sequences share the GPU → higher utilization → more tokens/second in total. But each sequence waits longer for the batch to cycle → higher per-request latency.
- What is Paged Attention? A virtual memory system for KV cache blocks: non-contiguous physical allocation with page tables, eliminating memory fragmentation.
- How does GQA differ from MHA? GQA groups query heads and shares a single K/V pair per group; MHA has unique K/V per head. GQA reduces KV cache by h/G× with minimal quality loss.
- When does speculative decoding NOT help? When the draft model's acceptance rate is low, i.e. highly creative, diverse generation where the draft model's predictions are often wrong.