1. KV Cache

Transformer models generate tokens autoregressively, producing one token at a time, where each new token depends on all previous tokens. A vanilla transformer layer computes $$ \operatorname{Attention}(Q, K, V) = \operatorname{softmax}\left(\frac{QK^T}{\sqrt{d}}\right)V $$ Consider the sequence “The cat sat on the.” To generate the next token “mat” at time step $t$, we only need the query of the last input token, “the,” to attend to the keys and values of the whole prefix “The cat sat on the”: $$ \operatorname{Attention}(Q_t, K_{1:t}, V_{1:t}) = \operatorname{softmax}\left(\frac{Q_t K_{1:t}^T}{\sqrt{d}}\right)V_{1:t} $$ We observe that:

  • $Q_t$, $K_t$, and $V_t$ depend only on the current input token (i.e., the last generated token).
  • $Q_{1:t-1}$ is never used again after step $t-1$.
  • $K_{1:t-1}$ and $V_{1:t-1}$ never change after step $t-1$.

Therefore, we can cache $K$ and $V$ and reuse them at the next step, skipping their recomputation and accelerating inference. Across a sequence of length $T$, this reduces the redundant $K$/$V$ projection work from $O(T^2)$ to $O(T)$, trading memory for speed.
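
To see why this is valid, here is a minimal numerical sketch (single-head attention with random weights; all names are illustrative): attending with only the last token’s query over cached $K$ and $V$ yields exactly the same output as recomputing everything from scratch.

import torch

torch.manual_seed(0)
d, T = 16, 5
Wq, Wk, Wv = (torch.randn(d, d) for _ in range(3))
x = torch.randn(T, d)  # embeddings of “The cat sat on the”

# No cache: project K, V for all T tokens at every step
K_full, V_full = x @ Wk, x @ Wv
q_t = x[-1:] @ Wq  # only the last token's query is needed
full = torch.softmax(q_t @ K_full.T / d ** 0.5, dim=-1) @ V_full

# With cache: K, V of the first T-1 tokens were saved at earlier steps
K_prev, V_prev = x[:-1] @ Wk, x[:-1] @ Wv  # what the cache already holds
K = torch.cat([K_prev, x[-1:] @ Wk])       # append only the new K
V = torch.cat([V_prev, x[-1:] @ Wv])       # append only the new V
cached = torch.softmax(q_t @ K.T / d ** 0.5, dim=-1) @ V

assert torch.allclose(full, cached)  # same output, one new projection per step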

Assume the batch size is $B$, the sequence length is $T$, the embedding dimension is $d$, the number of layers is $L$, and the precision is $P$ bytes. The KV cache then occupies $2 \cdot B \cdot T \cdot d \cdot L \cdot P$ bytes, where the factor $2$ accounts for $K$ and $V$.
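
For a sense of scale, a back-of-the-envelope example, assuming Llama-2-7B-like dimensions ($d = 4096$, $L = 32$) and fp16 precision ($P = 2$); the function name is just for illustration:

def kv_cache_bytes(B, T, d, L, P):
    return 2 * B * T * d * L * P  # 2 accounts for K and V

# 1 sequence of 4096 tokens, d = 4096, 32 layers, fp16 (2 bytes)
print(kv_cache_bytes(B=1, T=4096, d=4096, L=32, P=2) / 2 ** 30)  # 2.0 (GiB)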

2. Implementation

Code is adapted from [1] and [2].

import torch
import torch.nn as nn
import torch.nn.functional as F


class KVCache:
    def __init__(self, batch_size, num_heads, seq_len, head_dim, num_layers):
        self.kv_shape = (num_layers, 2, batch_size, num_heads, seq_len, head_dim)
        self.kv_cache = None
        self.pos = 0
    
    def insert_kv(self, layer_idx, k, v):
        if self.kv_cache is None:
            self.kv_cache = torch.empty(self.kv_shape, dtype=k.dtype, device=k.device)
            
        B, H, T_add, D = k.size()  # prefill: T_add = prompt length (e.g., 100); decode: T_add = 1
        t0, t1 = self.pos, self.pos + T_add
        
        # Insert new K, V at current position
        self.kv_cache[layer_idx, 0, :, :, t0:t1] = k
        self.kv_cache[layer_idx, 1, :, :, t0:t1] = v
        
        # Return full cache up to current position
        key_view = self.kv_cache[layer_idx, 0, :, :, :t1]
        value_view = self.kv_cache[layer_idx, 1, :, :, :t1]
        
        # After LAST layer, we have processed all layers and need to update the pos
        if layer_idx == self.kv_cache.size(0) - 1:
            self.pos = t1
        
        return key_view, value_view
    
    def prefill(self, other):
        # Copy another cache's prefix into this (possibly larger) cache
        dtype, device = other.kv_cache.dtype, other.kv_cache.device
        self.kv_cache = torch.empty(self.kv_shape, dtype=dtype, device=device)
        # Broadcasts along the batch dim if this cache has a larger batch size
        self.kv_cache[:, :, :, :, :other.pos, :] = other.kv_cache[:, :, :, :, :other.pos, :]
        self.pos = other.pos
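
A quick usage sketch of the cache above, with made-up shapes (2 layers, 4 heads, head dim 8, capacity 16): prefill a 5-token prompt, then decode one token.

cache = KVCache(batch_size=1, num_heads=4, seq_len=16, head_dim=8, num_layers=2)

# Prefill: insert a 5-token prompt into every layer
k = v = torch.randn(1, 4, 5, 8)
for layer in range(2):
    keys, values = cache.insert_kv(layer, k, v)
print(cache.pos, keys.shape)  # 5 torch.Size([1, 4, 5, 8])

# Decode: insert a single new token; the returned views now cover 6 positions
k = v = torch.randn(1, 4, 1, 8)
for layer in range(2):
    keys, values = cache.insert_kv(layer, k, v)
print(cache.pos, keys.shape)  # 6 torch.Size([1, 4, 6, 8])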
        

class Engine:
    def generate(self, tokens, max_tokens, ...):
        # Phase 1: Prefill - process entire prompt
        kv_cache_prefill = KVCache(batch_size=1, seq_len=len(tokens), ...)
        ids = torch.tensor([tokens], dtype=torch.long, device=device)
        logits = self.model.forward(ids, kv_cache=kv_cache_prefill)
        
        # Clone cache for decode phase (potentially with larger batch for multiple samples)
        kv_cache_decode = KVCache(batch_size=num_samples, seq_len=len(tokens)+max_tokens, ...)
        kv_cache_decode.prefill(kv_cache_prefill)
        
        # Phase 2: Decode - one token at a time
        num_generated = 0
        while num_generated < max_tokens:
            # First iteration samples from the prefill logits; later ones from decode logits
            next_token = sample(logits[:, -1, :])
            yield next_token
            num_generated += 1
            
            # Forward only the new token; the cache already covers everything before it
            ids = torch.tensor([[next_token]], dtype=torch.long, device=device)
            logits = self.model.forward(ids, kv_cache=kv_cache_decode)

            
class CausalSelfAttention(nn.Module):
    def forward(self, x, cos_sin, kv_cache):
        B, T, C = x.size()
        
        # Compute Q, K, V for new tokens only
        q = self.c_q(x).view(B, T, self.n_head, self.head_dim)
        k = self.c_k(x).view(B, T, self.n_kv_head, self.head_dim)
        v = self.c_v(x).view(B, T, self.n_kv_head, self.head_dim)
        
        # Apply RoPE and QK-norm BEFORE caching
        cos, sin = cos_sin
        q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin)
        q, k = norm(q), norm(k)
        q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
        
        # Insert into cache, retrieve full K, V history
        if kv_cache is not None:
            k, v = kv_cache.insert_kv(self.layer_idx, k, v)
        
        Tq = q.size(2)  # Number of new queries
        Tk = k.size(2)  # Total keys (cached + new)
        
        # Attention with appropriate masking
        if kv_cache is None or Tq == Tk:
            # Training or full prefill: standard causal attention
            # Queries:  q1  q2  q3  q4
            # Keys:     k1  k2  k3  k4
            # Causal Mask (1 = attend, 0 = mask):
            #       k1  k2  k3  k4
            # q1 [  1   0   0   0  ]
            # q2 [  1   1   0   0  ]
            # q3 [  1   1   1   0  ]
            # q4 [  1   1   1   1  ]
            y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        elif Tq == 1:
            # Single token decode: attend to all cached keys
            # Query:    q5 (just one!)
            # Keys:     k1  k2  k3  k4  k5 (from cache + new)
            # No mask needed:
            #       k1  k2  k3  k4  k5
            # q5 [  1   1   1   1   1  ]
            y = F.scaled_dot_product_attention(q, k, v, is_causal=False)
        else:
            # Chunk decode: custom mask for prefix + causal within chunk
            # Query:    q5  q6  q7
            # Keys:     k1  k2  k3  k4  k5  k6  k7
            #           ^^^^^^^^^^^^^^  ^^^^^^^^^^
            #           cached (prefix) new (causal within)
            # Custom Mask:
            #       k1  k2  k3  k4   k5  k6  k7
            # q5 [  1   1   1   1 |  1   0   0  ]
            # q6 [  1   1   1   1 |  1   1   0  ]
            # q7 [  1   1   1   1 |  1   1   1  ]
            attn_mask = torch.zeros((Tq, Tk), dtype=torch.bool, device=q.device)
            prefix_len = Tk - Tq
            attn_mask[:, :prefix_len] = True  # every query sees the whole cached prefix
            attn_mask[:, prefix_len:] = torch.tril(
                torch.ones((Tq, Tq), dtype=torch.bool, device=q.device)
            )  # causal within the new chunk
            y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
        
        return self.c_proj(y.transpose(1, 2).contiguous().view(B, T, -1))
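
As a sanity check of the chunked-decode branch, the custom mask should reproduce exactly the rows that full causal attention computes for the same queries (a standalone sketch with random tensors):

# B=1, H=2, D=8; a 4-token cached prefix followed by a 3-token chunk
q_all = torch.randn(1, 2, 7, 8)
k_all = torch.randn(1, 2, 7, 8)
v_all = torch.randn(1, 2, 7, 8)

# Reference: full causal attention over all 7 positions, keep the last 3 rows
full = F.scaled_dot_product_attention(q_all, k_all, v_all, is_causal=True)[:, :, 4:]

# Chunked: only the 3 new queries, prefix fully visible, causal within the chunk
attn_mask = torch.zeros((3, 7), dtype=torch.bool)
attn_mask[:, :4] = True
attn_mask[:, 4:] = torch.tril(torch.ones((3, 3), dtype=torch.bool))
chunked = F.scaled_dot_product_attention(q_all[:, :, 4:], k_all, v_all, attn_mask=attn_mask)

assert torch.allclose(full, chunked, atol=1e-6)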

References

[1] https://github.com/karpathy/nanochat/blob/master/nanochat/engine.py

[2] https://github.com/karpathy/nanochat/blob/master/nanochat/gpt.py