.

.

深度学习基础

为理解 vLLM 原理打下必要的基础知识

本部分将介绍理解 vLLM 所需的深度学习基础知识,包括神经网络、Transformer 架构、注意力机制等核心概念。

1 - 神经网络基础

神经网络基础

本章将介绍神经网络的基本概念,为理解 Transformer 和 LLM 打下基础。


引言

如果你是深度学习的初学者,本章将帮助你建立必要的基础知识。我们将从最简单的神经元开始,逐步介绍神经网络的核心概念。

如果你已经熟悉这些内容,可以快速浏览或直接跳到下一章。


1. 从生物神经元到人工神经元

1.1 生物神经元

人脑中有约 860 亿个神经元,它们通过突触相互连接。每个神经元:

  • 通过树突接收来自其他神经元的信号
  • 细胞体中处理这些信号
  • 通过轴突将信号传递给其他神经元

当接收到的信号强度超过某个阈值时,神经元就会"激活"并发出信号。

1.2 人工神经元

人工神经元是对生物神经元的数学抽象:

graph LR
    subgraph 输入
        X1[x₁]
        X2[x₂]
        X3[x₃]
    end

    subgraph 神经元
        W1[w₁] --> SUM((Σ))
        W2[w₂] --> SUM
        W3[w₃] --> SUM
        B[b<br/>偏置] --> SUM
        SUM --> ACT[激活函数<br/>f]
    end

    X1 --> W1
    X2 --> W2
    X3 --> W3
    ACT --> Y[y<br/>输出]

    style SUM fill:#e3f2fd
    style ACT fill:#c8e6c9

数学表达

y = f(w₁x₁ + w₂x₂ + w₃x₃ + b)

或者用向量形式:

y = f(w · x + b)

其中:

  • x:输入向量
  • w:权重向量(需要学习的参数)
  • b:偏置(需要学习的参数)
  • f:激活函数
  • y:输出

1.3 为什么需要激活函数?

如果没有激活函数,神经网络无论多少层,都只能表达线性函数

# 两层无激活函数的网络
y = W₂(W₁x + b₁) + b₂
  = W₂W₁x + W₂b₁ + b₂
  = W'x + b'  # 仍然是线性的!

激活函数引入非线性,使神经网络能够学习复杂的模式。


2. 激活函数详解

2.1 经典激活函数

Sigmoid

σ(x) = 1 / (1 + e^(-x))

特点

  • 输出范围 (0, 1)
  • 适合二分类的输出层
  • 问题:梯度消失(输入很大或很小时,梯度接近 0)

Tanh

tanh(x) = (e^x - e^(-x)) / (e^x + e^(-x))

特点

  • 输出范围 (-1, 1)
  • 零中心化
  • 问题:同样有梯度消失问题

ReLU(Rectified Linear Unit)

ReLU(x) = max(0, x)

特点

  • 计算简单高效
  • 缓解梯度消失
  • 问题:负值区域梯度为 0(“死神经元”)

2.2 现代激活函数

GELU(Gaussian Error Linear Unit)

GELU(x) = x · Φ(x)

其中 Φ(x) 是标准正态分布的累积分布函数。

近似计算

GELU(x) ≈ 0.5x(1 + tanh(√(2/π)(x + 0.044715x³)))

特点

  • 平滑的非线性
  • 在 Transformer 和 LLM 中广泛使用
  • 比 ReLU 表现更好

SiLU / Swish

SiLU(x) = x · σ(x) = x / (1 + e^(-x))

特点

  • 平滑、非单调
  • 与 GELU 类似的效果

2.3 激活函数对比

graph LR
    subgraph 激活函数特性对比
        R[ReLU] --> R1[简单高效]
        R --> R2[可能死神经元]

        G[GELU] --> G1[平滑非线性]
        G --> G2[Transformer 首选]

        S[SiLU] --> S1[平滑非单调]
        S --> S2[LLaMA 使用]
    end
函数公式范围使用场景
ReLUmax(0, x)[0, +∞)传统 CNN
GELUx·Φ(x)(-∞, +∞)BERT, GPT
SiLUx·σ(x)(-∞, +∞)LLaMA, Qwen

3. 张量(Tensor)基础

3.1 什么是张量

张量是多维数组的通称:

graph TD
    subgraph 张量的维度
        S[标量 Scalar<br/>0维<br/>例: 3.14]
        V[向量 Vector<br/>1维<br/>例: [1, 2, 3]]
        M[矩阵 Matrix<br/>2维<br/>例: [[1,2], [3,4]]]
        T[张量 Tensor<br/>N维<br/>例: 3D, 4D, ...]
    end

    S --> V --> M --> T

3.2 张量的形状(Shape)

张量的形状描述了每个维度的大小:

import torch

# 标量
scalar = torch.tensor(3.14)
print(scalar.shape)  # torch.Size([])

# 向量
vector = torch.tensor([1, 2, 3])
print(vector.shape)  # torch.Size([3])

# 矩阵
matrix = torch.tensor([[1, 2], [3, 4], [5, 6]])
print(matrix.shape)  # torch.Size([3, 2])  # 3行2列

# 3D 张量
tensor_3d = torch.randn(2, 3, 4)
print(tensor_3d.shape)  # torch.Size([2, 3, 4])

3.3 LLM 中的常见张量形状

在 LLM 中,我们经常遇到以下形状的张量:

张量形状说明
输入 token IDs[batch_size, seq_len]批次中的 token 索引
Embedding 输出[batch_size, seq_len, hidden_dim]词向量表示
Attention 权重[batch_size, num_heads, seq_len, seq_len]注意力分数
KV Cache[num_layers, 2, batch_size, num_heads, seq_len, head_dim]键值缓存
Logits[batch_size, seq_len, vocab_size]输出概率分布

示例

# 假设配置
batch_size = 4      # 批次大小
seq_len = 512       # 序列长度
hidden_dim = 4096   # 隐藏维度
num_heads = 32      # 注意力头数
head_dim = 128      # 每个头的维度 (hidden_dim / num_heads)
vocab_size = 32000  # 词表大小

# 输入
input_ids = torch.randint(0, vocab_size, (batch_size, seq_len))
# Shape: [4, 512]

# Embedding 后
embeddings = torch.randn(batch_size, seq_len, hidden_dim)
# Shape: [4, 512, 4096]

# Attention 输出
attention_output = torch.randn(batch_size, seq_len, hidden_dim)
# Shape: [4, 512, 4096]

# 最终 logits
logits = torch.randn(batch_size, seq_len, vocab_size)
# Shape: [4, 512, 32000]

3.4 常用张量操作

import torch

# 创建张量
x = torch.randn(2, 3, 4)  # 随机正态分布
y = torch.zeros(2, 3, 4)  # 全零
z = torch.ones(2, 3, 4)   # 全一

# 形状操作
x.view(2, 12)      # 重塑形状 [2, 3, 4] → [2, 12]
x.reshape(6, 4)    # 重塑形状 [2, 3, 4] → [6, 4]
x.transpose(1, 2)  # 交换维度 [2, 3, 4] → [2, 4, 3]
x.permute(2, 0, 1) # 重排维度 [2, 3, 4] → [4, 2, 3]

# 数学运算
x + y              # 逐元素加法
x * y              # 逐元素乘法
x @ y.transpose(-1, -2)  # 矩阵乘法
torch.softmax(x, dim=-1) # Softmax

# 索引和切片
x[0]               # 第一个样本
x[:, 0, :]         # 所有样本的第一个位置
x[..., -1]         # 最后一个维度的最后一个元素

4. 矩阵乘法与 GPU 加速

4.1 矩阵乘法基础

矩阵乘法是神经网络的核心操作:

C = A × B

其中 A: [M, K], B: [K, N], C: [M, N]

计算复杂度:O(M × K × N)

# PyTorch 矩阵乘法
A = torch.randn(64, 128)   # [M, K]
B = torch.randn(128, 256)  # [K, N]
C = A @ B                   # [M, N] = [64, 256]
# 或者
C = torch.matmul(A, B)

4.2 批量矩阵乘法(BMM)

在处理批次数据时,我们需要批量矩阵乘法:

# 批量矩阵乘法
batch_A = torch.randn(32, 64, 128)   # [batch, M, K]
batch_B = torch.randn(32, 128, 256)  # [batch, K, N]
batch_C = torch.bmm(batch_A, batch_B) # [batch, M, N] = [32, 64, 256]

4.3 为什么 GPU 适合矩阵运算

graph TB
    subgraph CPU
        C1[核心 1]
        C2[核心 2]
        C3[核心 3]
        C4[核心 4]
        C5[...]
        C6[核心 16]
    end

    subgraph GPU
        G1[核心 1]
        G2[核心 2]
        G3[...]
        G4[核心 10000+]
    end

    subgraph 特点对比
        CP[CPU: 少量强核心<br/>适合复杂顺序任务]
        GP[GPU: 大量弱核心<br/>适合简单并行任务]
    end

    style G1 fill:#c8e6c9
    style G2 fill:#c8e6c9
    style G4 fill:#c8e6c9

GPU 优势

特点CPUGPU
核心数4-641000-10000+
单核性能
并行度极高
适合任务复杂逻辑、分支大规模并行计算

矩阵乘法的每个输出元素可以独立计算,非常适合 GPU 的大规模并行架构。

4.4 实际性能对比

import torch
import time

# 创建大矩阵
A = torch.randn(4096, 4096)
B = torch.randn(4096, 4096)

# CPU 计算
start = time.time()
C_cpu = A @ B
cpu_time = time.time() - start

# GPU 计算
A_gpu = A.cuda()
B_gpu = B.cuda()
torch.cuda.synchronize()
start = time.time()
C_gpu = A_gpu @ B_gpu
torch.cuda.synchronize()
gpu_time = time.time() - start

print(f"CPU: {cpu_time:.3f}s, GPU: {gpu_time:.3f}s")
print(f"加速比: {cpu_time/gpu_time:.1f}x")
# 典型输出: CPU: 2.5s, GPU: 0.01s, 加速比: 250x

5. 多层神经网络

5.1 网络结构

多层神经网络(MLP,Multi-Layer Perceptron)由多个层堆叠而成:

graph LR
    subgraph 输入层
        I1((x₁))
        I2((x₂))
        I3((x₃))
    end

    subgraph 隐藏层1
        H11((h₁))
        H12((h₂))
        H13((h₃))
        H14((h₄))
    end

    subgraph 隐藏层2
        H21((h₁))
        H22((h₂))
    end

    subgraph 输出层
        O1((y₁))
        O2((y₂))
    end

    I1 --> H11
    I1 --> H12
    I1 --> H13
    I1 --> H14
    I2 --> H11
    I2 --> H12
    I2 --> H13
    I2 --> H14
    I3 --> H11
    I3 --> H12
    I3 --> H13
    I3 --> H14

    H11 --> H21
    H11 --> H22
    H12 --> H21
    H12 --> H22
    H13 --> H21
    H13 --> H22
    H14 --> H21
    H14 --> H22

    H21 --> O1
    H21 --> O2
    H22 --> O1
    H22 --> O2

5.2 前向传播

前向传播计算从输入到输出的过程:

import torch
import torch.nn as nn

class SimpleMLP(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.layer1 = nn.Linear(input_dim, hidden_dim)
        self.activation = nn.GELU()
        self.layer2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        # x: [batch_size, input_dim]
        x = self.layer1(x)      # [batch_size, hidden_dim]
        x = self.activation(x)   # [batch_size, hidden_dim]
        x = self.layer2(x)       # [batch_size, output_dim]
        return x

# 使用
model = SimpleMLP(768, 3072, 768)
input_data = torch.randn(32, 768)  # batch_size=32
output = model(input_data)  # [32, 768]

5.3 参数量计算

对于一个全连接层 nn.Linear(in_features, out_features)

参数量 = in_features × out_features + out_features(偏置)

示例

# 层: Linear(768, 3072)
# 权重参数: 768 × 3072 = 2,359,296
# 偏置参数: 3072
# 总计: 2,362,368 ≈ 2.36M

6. 语言模型基础概念

6.1 什么是语言模型

语言模型是一个概率模型,用于预测文本序列的概率:

P(w₁, w₂, ..., wₙ) = P(w₁) × P(w₂|w₁) × P(w₃|w₁,w₂) × ... × P(wₙ|w₁,...,wₙ₋₁)

核心任务:给定前文,预测下一个词的概率分布。

graph LR
    I[输入: 'The cat sat on the'] --> LM[语言模型]
    LM --> O[输出概率分布:<br/>mat: 0.3<br/>floor: 0.2<br/>roof: 0.15<br/>...]

6.2 Token 和词表

Token:文本的基本单位,可以是:

  • 单词:“hello”、“world”
  • 子词:“play” + “ing” = “playing”
  • 字符:“h”、“e”、“l”、“l”、“o”

词表(Vocabulary):所有可能 token 的集合

# 常见词表大小
# GPT-2: 50257
# LLaMA: 32000
# Qwen: 151936

# Tokenization 示例
text = "Hello, how are you?"
tokens = tokenizer.encode(text)
# tokens = [15496, 11, 703, 527, 499, 30]

6.3 Embedding

Embedding 将离散的 token ID 转换为连续的向量:

graph LR
    T[Token ID: 15496] --> E[Embedding 层<br/>查表]
    E --> V[向量: [0.1, -0.2, 0.5, ...]]

    subgraph Embedding 矩阵
        EM[矩阵大小: vocab_size × hidden_dim<br/>例: 32000 × 4096]
    end

    style V fill:#c8e6c9
import torch.nn as nn

# Embedding 层
vocab_size = 32000
hidden_dim = 4096

embedding = nn.Embedding(vocab_size, hidden_dim)

# 使用
token_ids = torch.tensor([15496, 11, 703])  # 3 个 token
vectors = embedding(token_ids)  # [3, 4096]

7. 推理 vs 训练

7.1 训练过程

graph LR
    subgraph 前向传播
        I[输入 X] --> M[模型] --> O[输出 Y]
    end

    subgraph 损失计算
        O --> L[Loss 函数]
        T[真实标签] --> L
        L --> LV[Loss 值]
    end

    subgraph 反向传播
        LV --> G[计算梯度]
        G --> U[更新参数]
        U --> M
    end

训练需要

  • 前向传播:计算预测值
  • 损失计算:比较预测与真实值
  • 反向传播:计算梯度
  • 参数更新:使用优化器更新权重

7.2 推理过程

graph LR
    I[输入 X] --> M[模型<br/>权重固定] --> O[输出 Y]

    style M fill:#c8e6c9

推理只需要

  • 前向传播:计算预测值
  • 不需要梯度计算
  • 不需要参数更新

7.3 推理优化的重要性

对比项训练推理
目标学习参数使用参数
频率一次(或少数几次)大量重复
延迟要求不敏感敏感(用户等待)
批次大小可以较大通常较小
内存模式需要存储梯度不需要梯度

推理优化的核心目标

  • 降低延迟(用户体验)
  • 提高吞吐量(服务更多用户)
  • 减少显存占用(支持更大模型或更多并发)

这正是 vLLM 要解决的问题!


8. 本章小结

核心概念

  1. 神经元:接收输入、加权求和、应用激活函数、产生输出
  2. 激活函数:引入非线性,GELU 是 LLM 的常用选择
  3. 张量:多维数组,神经网络中数据的载体
  4. 矩阵乘法:神经网络的核心计算,GPU 加速的关键

关键公式

神经元输出: y = f(w · x + b)
全连接层参数量: in_features × out_features + out_features

LLM 相关

  • Token:文本的基本单位
  • Embedding:将 token ID 转换为向量
  • 语言模型:预测下一个 token 的概率分布
  • 推理:使用训练好的模型进行预测

与 vLLM 的关联

  • 张量形状理解对于理解 vLLM 的内存管理至关重要
  • GPU 并行计算是 vLLM 性能优化的基础
  • 推理优化是 vLLM 的核心目标

思考题

  1. 为什么现代 LLM 普遍使用 GELU 而不是 ReLU?
  2. 如果一个模型有 7B 参数,使用 FP16 精度,需要多少显存存储权重?
  3. 批量矩阵乘法如何帮助提高 GPU 利用率?

下一步

神经网络基础已经介绍完毕,接下来我们将学习 LLM 的核心架构——Transformer:

👉 下一章:Transformer 架构详解

2 - Transformer 架构详解

Transformer 架构详解

本章将详细介绍 Transformer 架构,这是现代大语言模型的基础。


引言

2017 年,Google 发表了划时代的论文《Attention Is All You Need》,提出了 Transformer 架构。这个架构彻底改变了自然语言处理领域,成为了 GPT、BERT、LLaMA 等现代 LLM 的基础。

理解 Transformer 架构是理解 vLLM 优化原理的关键。


1. Transformer 的诞生背景

1.1 RNN/LSTM 的局限

在 Transformer 之前,序列建模主要依赖 RNN(循环神经网络)和 LSTM(长短期记忆网络):

graph LR
    subgraph RNN 的顺序处理
        X1[x₁] --> H1[h₁] --> H2[h₂] --> H3[h₃] --> H4[h₄]
        X2[x₂] --> H2
        X3[x₃] --> H3
        X4[x₄] --> H4
    end

RNN 的问题

问题说明
顺序依赖必须按顺序处理,无法并行
长距离依赖难以捕获长序列中的远距离关系
梯度问题长序列训练时梯度消失或爆炸
训练慢无法充分利用 GPU 并行能力

1.2 Attention 的突破

Transformer 的核心创新是自注意力机制(Self-Attention)

  • 可以直接建立序列中任意两个位置之间的关系
  • 所有位置可以并行计算
  • 没有顺序依赖
graph TB
    subgraph Self-Attention
        X1[x₁] <--> X2[x₂]
        X1 <--> X3[x₃]
        X1 <--> X4[x₄]
        X2 <--> X3
        X2 <--> X4
        X3 <--> X4
    end

2. Transformer 整体架构

2.1 原始 Encoder-Decoder 结构

原始 Transformer 包含 Encoder 和 Decoder 两部分:

graph TB
    subgraph 输入
        I[源序列<br/>例: 英文句子]
    end

    subgraph Encoder
        E1[Embedding + 位置编码]
        E2[Multi-Head Attention]
        E3[Feed Forward]
        E4[× N 层]
        E1 --> E2 --> E3
        E3 -.-> E4
    end

    subgraph Decoder
        D1[Embedding + 位置编码]
        D2[Masked Multi-Head Attention]
        D3[Cross Attention]
        D4[Feed Forward]
        D5[× N 层]
        D1 --> D2 --> D3 --> D4
        D4 -.-> D5
    end

    subgraph 输出
        O[目标序列<br/>例: 中文翻译]
    end

    I --> E1
    E4 --> D3
    D5 --> O

应用场景

  • 机器翻译(英→中)
  • 文本摘要
  • BERT(仅 Encoder)
  • T5(完整 Encoder-Decoder)

2.2 Decoder-Only 架构(现代 LLM)

现代大语言模型(GPT 系列、LLaMA、Qwen 等)都采用 Decoder-Only 架构:

graph TD
    subgraph Decoder-Only 架构
        I[输入 tokens] --> EMB[Embedding Layer]
        EMB --> PE[+ 位置编码]
        PE --> B1[Transformer Block 1]
        B1 --> B2[Transformer Block 2]
        B2 --> B3[...]
        B3 --> BN[Transformer Block N]
        BN --> LN[Layer Norm]
        LN --> LM[LM Head<br/>Linear: hidden → vocab]
        LM --> O[输出 logits]
    end

    style EMB fill:#e3f2fd
    style B1 fill:#c8e6c9
    style B2 fill:#c8e6c9
    style BN fill:#c8e6c9
    style LM fill:#fff9c4

为什么 Decoder-Only 成为主流?

优势说明
统一架构预训练和下游任务使用相同架构
自回归生成天然适合文本生成任务
扩展性参数量扩展效果好
简单高效架构简单,训练推理更高效

2.3 单层 Transformer Block 结构

每个 Transformer Block 包含以下组件:

graph TD
    subgraph Transformer Block
        I[输入 X] --> LN1[Layer Norm 1]
        LN1 --> ATT[Multi-Head<br/>Self-Attention]
        ATT --> ADD1[+]
        I --> ADD1
        ADD1 --> LN2[Layer Norm 2]
        LN2 --> FFN[Feed Forward<br/>Network]
        FFN --> ADD2[+]
        ADD1 --> ADD2
        ADD2 --> O[输出]
    end

    style ATT fill:#bbdefb
    style FFN fill:#c8e6c9

关键组件

  1. Layer Normalization:归一化,稳定训练
  2. Multi-Head Self-Attention:捕获序列内的关系
  3. Feed Forward Network (FFN):非线性变换
  4. 残差连接:缓解梯度消失,帮助信息流动

3. Embedding 层

3.1 Token Embedding

Token Embedding 将离散的 token ID 映射为连续的向量:

import torch.nn as nn

class TokenEmbedding(nn.Module):
    def __init__(self, vocab_size, hidden_dim):
        super().__init__()
        # 创建嵌入矩阵: [vocab_size, hidden_dim]
        self.embedding = nn.Embedding(vocab_size, hidden_dim)

    def forward(self, token_ids):
        # token_ids: [batch_size, seq_len]
        # 返回: [batch_size, seq_len, hidden_dim]
        return self.embedding(token_ids)

# 示例
vocab_size = 32000
hidden_dim = 4096
embedding = TokenEmbedding(vocab_size, hidden_dim)

# 输入 token IDs
token_ids = torch.tensor([[1, 234, 567], [89, 10, 1112]])  # [2, 3]
# 输出嵌入向量
vectors = embedding(token_ids)  # [2, 3, 4096]

3.2 Embedding 矩阵的参数量

参数量 = vocab_size × hidden_dim

示例(LLaMA-2-7B):

参数量 = 32000 × 4096 = 131,072,000 ≈ 131M

占 7B 模型总参数的约 1.9%


4. 位置编码(Positional Encoding)

4.1 为什么需要位置信息

Self-Attention 本身不包含位置信息——它只看 token 之间的关系,不知道它们的顺序。

# 这两个序列的 Attention 计算结果相同(如果没有位置编码)
"猫 追 狗"
"狗 追 猫"

位置编码为每个位置添加独特的信息,让模型知道 token 的顺序。

4.2 正弦位置编码

原始 Transformer 使用正弦/余弦函数:

PE(pos, 2i) = sin(pos / 10000^(2i/d))
PE(pos, 2i+1) = cos(pos / 10000^(2i/d))

其中:

  • pos:位置索引
  • i:维度索引
  • d:总维度数
import numpy as np

def sinusoidal_position_encoding(max_len, hidden_dim):
    position = np.arange(max_len)[:, np.newaxis]
    div_term = np.exp(np.arange(0, hidden_dim, 2) * -(np.log(10000.0) / hidden_dim))

    pe = np.zeros((max_len, hidden_dim))
    pe[:, 0::2] = np.sin(position * div_term)
    pe[:, 1::2] = np.cos(position * div_term)
    return pe

# 生成位置编码
pe = sinusoidal_position_encoding(512, 4096)
# Shape: [512, 4096]

4.3 RoPE(旋转位置编码)

现代 LLM(如 LLaMA、Qwen)使用 RoPE(Rotary Position Embedding)

graph LR
    subgraph RoPE 原理
        Q[Query 向量] --> R1[旋转矩阵<br/>R(pos)]
        R1 --> RQ[旋转后的 Query]

        K[Key 向量] --> R2[旋转矩阵<br/>R(pos)]
        R2 --> RK[旋转后的 Key]
    end

RoPE 的优势

  • 相对位置信息自然编码
  • 支持任意长度外推
  • 计算高效
# RoPE 的核心思想(简化)
def rotate_half(x):
    x1, x2 = x[..., :x.shape[-1]//2], x[..., x.shape[-1]//2:]
    return torch.cat([-x2, x1], dim=-1)

def apply_rope(q, k, cos, sin):
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed

5. Multi-Head Attention

这是 Transformer 的核心组件,详细原理将在下一章介绍。这里给出结构概览:

graph TD
    subgraph Multi-Head Attention
        I[输入 X] --> WQ[W_Q]
        I --> WK[W_K]
        I --> WV[W_V]

        WQ --> Q[Query]
        WK --> K[Key]
        WV --> V[Value]

        Q --> SPLIT1[Split Heads]
        K --> SPLIT2[Split Heads]
        V --> SPLIT3[Split Heads]

        SPLIT1 --> H1[Head 1]
        SPLIT1 --> H2[Head 2]
        SPLIT1 --> HN[Head N]

        SPLIT2 --> H1
        SPLIT2 --> H2
        SPLIT2 --> HN

        SPLIT3 --> H1
        SPLIT3 --> H2
        SPLIT3 --> HN

        H1 --> CAT[Concat]
        H2 --> CAT
        HN --> CAT

        CAT --> WO[W_O]
        WO --> O[输出]
    end

参数量

Q, K, V 投影: 3 × hidden_dim × hidden_dim
输出投影: hidden_dim × hidden_dim
总计: 4 × hidden_dim²

示例(hidden_dim = 4096):

参数量 = 4 × 4096² = 67,108,864 ≈ 67M

6. Feed Forward Network (FFN)

6.1 基本结构

FFN 是一个简单的两层全连接网络:

graph LR
    I[输入<br/>hidden_dim] --> L1[Linear 1<br/>hidden → intermediate]
    L1 --> ACT[激活函数<br/>GELU/SiLU]
    ACT --> L2[Linear 2<br/>intermediate → hidden]
    L2 --> O[输出<br/>hidden_dim]
class FeedForward(nn.Module):
    def __init__(self, hidden_dim, intermediate_dim):
        super().__init__()
        self.up_proj = nn.Linear(hidden_dim, intermediate_dim)
        self.down_proj = nn.Linear(intermediate_dim, hidden_dim)
        self.activation = nn.GELU()

    def forward(self, x):
        # x: [batch, seq_len, hidden_dim]
        x = self.up_proj(x)       # [batch, seq_len, intermediate_dim]
        x = self.activation(x)     # [batch, seq_len, intermediate_dim]
        x = self.down_proj(x)      # [batch, seq_len, hidden_dim]
        return x

6.2 SwiGLU 变体

LLaMA 等模型使用 SwiGLU 激活函数:

graph LR
    I[输入] --> G[Gate Proj]
    I --> U[Up Proj]
    G --> SILU[SiLU 激活]
    SILU --> MUL[×]
    U --> MUL
    MUL --> D[Down Proj]
    D --> O[输出]
class SwiGLUFeedForward(nn.Module):
    def __init__(self, hidden_dim, intermediate_dim):
        super().__init__()
        self.gate_proj = nn.Linear(hidden_dim, intermediate_dim)
        self.up_proj = nn.Linear(hidden_dim, intermediate_dim)
        self.down_proj = nn.Linear(intermediate_dim, hidden_dim)

    def forward(self, x):
        gate = torch.nn.functional.silu(self.gate_proj(x))
        up = self.up_proj(x)
        return self.down_proj(gate * up)

6.3 FFN 参数量

标准 FFN

参数量 = 2 × hidden_dim × intermediate_dim

SwiGLU FFN(有三个投影矩阵):

参数量 = 3 × hidden_dim × intermediate_dim

示例(LLaMA-7B,hidden=4096,intermediate=11008):

参数量 = 3 × 4096 × 11008 = 135,266,304 ≈ 135M

7. Layer Normalization

7.1 为什么需要归一化

深层网络中,每层输出的分布会发生变化(Internal Covariate Shift),导致:

  • 训练不稳定
  • 需要较小的学习率
  • 收敛慢

Layer Normalization 将每层输出归一化到均值 0、方差 1 的分布。

7.2 计算公式

LayerNorm(x) = γ × (x - μ) / √(σ² + ε) + β

其中:

  • μ:均值
  • σ²:方差
  • ε:防止除零的小常数
  • γ, β:可学习的缩放和偏移参数
class LayerNorm(nn.Module):
    def __init__(self, hidden_dim, eps=1e-5):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_dim))
        self.bias = nn.Parameter(torch.zeros(hidden_dim))
        self.eps = eps

    def forward(self, x):
        mean = x.mean(dim=-1, keepdim=True)
        var = x.var(dim=-1, keepdim=True, unbiased=False)
        return self.weight * (x - mean) / torch.sqrt(var + self.eps) + self.bias

7.3 RMSNorm

LLaMA 等模型使用 RMSNorm,去掉了均值中心化:

RMSNorm(x) = γ × x / √(mean(x²) + ε)

优势:计算更简单,效果相当。

class RMSNorm(nn.Module):
    def __init__(self, hidden_dim, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_dim))
        self.eps = eps

    def forward(self, x):
        rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return self.weight * x / rms

7.4 Pre-Norm vs Post-Norm

graph TB
    subgraph Post-Norm
        I1[输入] --> ATT1[Attention]
        ATT1 --> ADD1[+]
        I1 --> ADD1
        ADD1 --> LN1[LayerNorm]
    end

    subgraph Pre-Norm(现代 LLM 常用)
        I2[输入] --> LN2[LayerNorm]
        LN2 --> ATT2[Attention]
        ATT2 --> ADD2[+]
        I2 --> ADD2
    end

    style LN2 fill:#c8e6c9

Pre-Norm 优势

  • 训练更稳定
  • 允许更深的网络
  • 更容易收敛

8. 残差连接

8.1 什么是残差连接

残差连接让信息可以"跳过"某些层直接传递:

output = x + Layer(x)

8.2 为什么残差连接重要

graph LR
    subgraph 无残差
        X1[x] --> L1[Layer 1] --> L2[Layer 2] --> L3[Layer 3] --> Y1[y]
    end

    subgraph 有残差
        X2[x] --> LA[Layer 1] --> LB[Layer 2] --> LC[Layer 3] --> Y2[y]
        X2 --> Y2
        LA --> LB
        LB --> LC
    end

优势

  • 缓解梯度消失
  • 允许训练更深的网络
  • 信息直接传递不会丢失

9. 完整 Transformer Block 代码

import torch
import torch.nn as nn

class TransformerBlock(nn.Module):
    def __init__(self, hidden_dim, num_heads, intermediate_dim):
        super().__init__()
        self.norm1 = RMSNorm(hidden_dim)
        self.attention = MultiHeadAttention(hidden_dim, num_heads)
        self.norm2 = RMSNorm(hidden_dim)
        self.ffn = SwiGLUFeedForward(hidden_dim, intermediate_dim)

    def forward(self, x, attention_mask=None):
        # Pre-Norm + Attention + 残差
        residual = x
        x = self.norm1(x)
        x = self.attention(x, attention_mask)
        x = residual + x

        # Pre-Norm + FFN + 残差
        residual = x
        x = self.norm2(x)
        x = self.ffn(x)
        x = residual + x

        return x

10. 参数量计算实战

10.1 LLaMA-2-7B 参数分布

组件公式参数量
Embeddingvocab × hidden32000 × 4096 = 131M
每层 Attention Qhidden × hidden4096² = 16.8M
每层 Attention Khidden × (hidden/n_heads × n_kv_heads)4096 × 4096 = 16.8M
每层 Attention Vhidden × (hidden/n_heads × n_kv_heads)4096 × 4096 = 16.8M
每层 Attention Ohidden × hidden4096² = 16.8M
每层 FFN gatehidden × intermediate4096 × 11008 = 45.1M
每层 FFN uphidden × intermediate4096 × 11008 = 45.1M
每层 FFN downintermediate × hidden11008 × 4096 = 45.1M
每层 Norm2 × hidden2 × 4096 = 8K
LM Headhidden × vocab4096 × 32000 = 131M

每层总计:约 202M 参数 32 层总计:32 × 202M = 6.46B 加上 Embedding 和 LM Head:约 6.7B

10.2 参数分布饼图

pie title LLaMA-7B 参数分布
    "Attention (Q/K/V/O)" : 32
    "FFN" : 65
    "Embedding + LM Head" : 2
    "Norm" : 1

关键观察

  • FFN 占比最大(约 65%)
  • Attention 其次(约 32%)
  • Embedding 占比很小(约 2%)

这解释了为什么 vLLM 主要优化 Attention 和内存管理,而不是 FFN。


11. 本章小结

架构要点

  1. Decoder-Only 架构:现代 LLM 的主流选择
  2. Transformer Block:Attention + FFN + Norm + 残差
  3. 位置编码:RoPE 是现代标准

关键组件

组件作用现代实现
EmbeddingToken → Vector直接查表
位置编码注入位置信息RoPE
Self-Attention捕获序列关系Multi-Head
FFN非线性变换SwiGLU
Layer Norm稳定训练RMSNorm
残差连接信息直传Pre-Norm

参数分布

  • FFN 占主导(约 65%)
  • Attention 约 32%
  • Embedding 约 2%

与 vLLM 的关联

  • Attention 计算是 KV Cache 优化的核心
  • 参数分布影响显存使用和优化策略
  • 位置编码影响序列长度支持

思考题

  1. 为什么 Decoder-Only 架构在 LLM 中比 Encoder-Decoder 更流行?
  2. RoPE 相比正弦位置编码有什么优势?
  3. 为什么 FFN 的参数量比 Attention 多,但 vLLM 主要优化 Attention?

下一步

Transformer 架构介绍完毕,接下来我们将深入学习其核心——注意力机制:

👉 下一章:注意力机制原理

3 - 注意力机制原理

注意力机制原理

本章将深入介绍自注意力机制的数学原理和计算过程,这是理解 vLLM 核心优化的关键。


引言

注意力机制是 Transformer 的核心创新,也是 vLLM 优化的主要目标。理解注意力机制的计算过程,对于理解 KV Cache 和 PagedAttention 至关重要。


1. 注意力的直觉理解

1.1 人类注意力的类比

想象你在阅读一篇文章,当你看到"他"这个代词时,你会自动"关注"前文中提到的人名,以理解"他"指的是谁。

这就是注意力机制的核心思想:让模型学会"关注"序列中最相关的部分

graph LR
    subgraph 阅读理解
        T1[张三] --> T2[今天] --> T3[去了] --> T4[公园]
        T4 --> T5[他]
        T5 -.->|关注| T1
    end

1.2 从"全局视野"到"重点关注"

没有注意力机制时,模型只能看到固定窗口内的信息。有了注意力机制:

graph TB
    subgraph 固定窗口
        FW[只能看到附近几个 token]
    end

    subgraph 注意力机制
        ATT[可以关注序列中任意位置<br/>并根据相关性分配权重]
    end

    style ATT fill:#c8e6c9

2. 自注意力(Self-Attention)计算

2.1 Query、Key、Value 的含义

自注意力使用三个向量:

向量类比作用
Query (Q)“我要找什么”当前位置的查询向量
Key (K)“我是什么”每个位置的索引向量
Value (V)“我的内容”每个位置的值向量

直觉理解

  • Q 是"问题"
  • K 是"索引/标签"
  • V 是"内容"
  • 计算 Q 和所有 K 的相似度,用相似度加权所有 V

2.2 计算公式

自注意力的核心公式:

$$ \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right) V $$

其中:

  • $Q$:Query 矩阵,形状 $[seq_len, d_k]$
  • $K$:Key 矩阵,形状 $[seq_len, d_k]$
  • $V$:Value 矩阵,形状 $[seq_len, d_v]$
  • $d_k$:Key 的维度(用于缩放)

2.3 计算步骤详解

flowchart TD
    subgraph 步骤1: 生成 Q, K, V
        X[输入 X<br/>seq_len × hidden_dim]
        X --> WQ[W_Q 投影]
        X --> WK[W_K 投影]
        X --> WV[W_V 投影]
        WQ --> Q[Query<br/>seq_len × d_k]
        WK --> K[Key<br/>seq_len × d_k]
        WV --> V[Value<br/>seq_len × d_v]
    end

    subgraph 步骤2: 计算注意力分数
        Q --> MM[Q × K^T]
        K --> MM
        MM --> SC[÷ √d_k<br/>缩放]
        SC --> MASK[+ Mask<br/>可选]
        MASK --> SM[Softmax]
        SM --> ATT[注意力权重<br/>seq_len × seq_len]
    end

    subgraph 步骤3: 加权求和
        ATT --> OUT[× V]
        V --> OUT
        OUT --> O[输出<br/>seq_len × d_v]
    end

    style SC fill:#fff9c4
    style SM fill:#c8e6c9

2.4 逐步计算示例

假设我们有一个简单的序列,3 个 token,每个 token 的隐藏维度是 4:

import torch
import torch.nn.functional as F

# 输入
seq_len = 3
d_k = 4

# 假设 Q, K, V 已经通过线性投影得到
Q = torch.tensor([
    [1.0, 0.0, 1.0, 0.0],  # token 0 的 query
    [0.0, 1.0, 0.0, 1.0],  # token 1 的 query
    [1.0, 1.0, 0.0, 0.0],  # token 2 的 query
])

K = torch.tensor([
    [1.0, 0.0, 0.0, 1.0],  # token 0 的 key
    [0.0, 1.0, 1.0, 0.0],  # token 1 的 key
    [1.0, 1.0, 1.0, 1.0],  # token 2 的 key
])

V = torch.tensor([
    [1.0, 2.0, 3.0, 4.0],  # token 0 的 value
    [5.0, 6.0, 7.0, 8.0],  # token 1 的 value
    [9.0, 10., 11., 12.],  # token 2 的 value
])

# 步骤 1: 计算 Q × K^T
scores = Q @ K.T
print("注意力分数 (未缩放):")
print(scores)
# tensor([[1., 1., 2.],
#         [1., 1., 2.],
#         [1., 1., 3.]])

# 步骤 2: 缩放
d_k = 4
scaled_scores = scores / (d_k ** 0.5)
print("\n缩放后的分数:")
print(scaled_scores)

# 步骤 3: Softmax
attention_weights = F.softmax(scaled_scores, dim=-1)
print("\n注意力权重:")
print(attention_weights)
# 每行和为 1

# 步骤 4: 加权求和
output = attention_weights @ V
print("\n输出:")
print(output)

2.5 注意力权重可视化

注意力权重形成一个 [seq_len, seq_len] 的矩阵:

         Token 0  Token 1  Token 2
Token 0 [  0.30    0.30     0.40  ]  # Token 0 关注谁
Token 1 [  0.30    0.30     0.40  ]  # Token 1 关注谁
Token 2 [  0.20    0.20     0.60  ]  # Token 2 关注谁

每一行表示一个 token 对所有 token 的注意力分布(和为 1)。


3. 缩放因子 √d 的作用

3.1 为什么需要缩放

当 $d_k$ 较大时,$QK^T$ 的点积结果会变得很大。这会导致:

  1. Softmax 饱和:大值经过 softmax 后趋近于 1,小值趋近于 0
  2. 梯度消失:softmax 在饱和区域的梯度接近 0
graph LR
    subgraph 无缩放
        S1[大的点积值] --> SM1[Softmax 饱和]
        SM1 --> G1[梯度消失]
    end

    subgraph 有缩放
        S2[缩放后的点积] --> SM2[Softmax 正常]
        SM2 --> G2[梯度正常]
    end

    style G1 fill:#ffcdd2
    style G2 fill:#c8e6c9

3.2 数学解释

假设 Q 和 K 的元素服从均值 0、方差 1 的分布,那么:

  • $Q \cdot K$ 的均值为 0
  • $Q \cdot K$ 的方差为 $d_k$

除以 $\sqrt{d_k}$ 后,方差变为 1,分布更稳定。


4. 多头注意力(Multi-Head Attention)

4.1 为什么需要多头

单头注意力只能学习一种"关注模式"。多头注意力让模型同时学习多种不同的关系:

graph TB
    subgraph 多头注意力的优势
        H1[Head 1<br/>关注语法关系]
        H2[Head 2<br/>关注语义关系]
        H3[Head 3<br/>关注位置关系]
        H4[Head 4<br/>关注其他模式]
    end

4.2 多头计算过程

graph TD
    X[输入 X<br/>batch × seq × hidden] --> SPLIT[分割成多个头]

    subgraph 并行计算
        SPLIT --> H1[Head 1<br/>Attention]
        SPLIT --> H2[Head 2<br/>Attention]
        SPLIT --> H3[Head 3<br/>Attention]
        SPLIT --> HN[Head N<br/>Attention]
    end

    H1 --> CAT[Concat]
    H2 --> CAT
    H3 --> CAT
    HN --> CAT

    CAT --> WO[W_O 投影]
    WO --> O[输出]

4.3 代码实现

class MultiHeadAttention(nn.Module):
    def __init__(self, hidden_dim, num_heads):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = hidden_dim // num_heads

        # Q, K, V 投影
        self.q_proj = nn.Linear(hidden_dim, hidden_dim)
        self.k_proj = nn.Linear(hidden_dim, hidden_dim)
        self.v_proj = nn.Linear(hidden_dim, hidden_dim)

        # 输出投影
        self.o_proj = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, x):
        batch_size, seq_len, _ = x.shape

        # 投影
        Q = self.q_proj(x)  # [batch, seq, hidden]
        K = self.k_proj(x)
        V = self.v_proj(x)

        # 重塑为多头: [batch, seq, num_heads, head_dim]
        Q = Q.view(batch_size, seq_len, self.num_heads, self.head_dim)
        K = K.view(batch_size, seq_len, self.num_heads, self.head_dim)
        V = V.view(batch_size, seq_len, self.num_heads, self.head_dim)

        # 转置: [batch, num_heads, seq, head_dim]
        Q = Q.transpose(1, 2)
        K = K.transpose(1, 2)
        V = V.transpose(1, 2)

        # 注意力计算
        scores = Q @ K.transpose(-2, -1) / (self.head_dim ** 0.5)
        attn_weights = F.softmax(scores, dim=-1)
        output = attn_weights @ V  # [batch, num_heads, seq, head_dim]

        # 合并多头
        output = output.transpose(1, 2)  # [batch, seq, num_heads, head_dim]
        output = output.reshape(batch_size, seq_len, -1)  # [batch, seq, hidden]

        # 输出投影
        output = self.o_proj(output)

        return output

4.4 头数与维度的关系

hidden_dim = num_heads × head_dim

常见配置

模型hidden_dimnum_headshead_dim
GPT-2 Small7681264
GPT-2 Large12802064
LLaMA-7B409632128
LLaMA-70B819264128

5. Masked Attention(因果掩码)

5.1 为什么需要掩码

在语言模型中,预测下一个 token 时不能看到未来的 token。因果掩码确保每个位置只能关注它之前的位置。

graph LR
    subgraph 无掩码(双向注意力)
        A1[token 1] <--> A2[token 2]
        A1 <--> A3[token 3]
        A2 <--> A3
    end

    subgraph 有掩码(单向注意力)
        B1[token 1]
        B2[token 2] --> B1
        B3[token 3] --> B1
        B3 --> B2
    end

5.2 掩码矩阵

因果掩码是一个下三角矩阵:

seq_len = 4
mask = torch.tril(torch.ones(seq_len, seq_len))
print(mask)
# tensor([[1., 0., 0., 0.],
#         [1., 1., 0., 0.],
#         [1., 1., 1., 0.],
#         [1., 1., 1., 1.]])

可视化

         位置 0  位置 1  位置 2  位置 3
位置 0  [  1      0       0       0   ]  → 只能看自己
位置 1  [  1      1       0       0   ]  → 可看 0, 1
位置 2  [  1      1       1       0   ]  → 可看 0, 1, 2
位置 3  [  1      1       1       1   ]  → 可看全部

5.3 应用掩码

在 softmax 之前应用掩码,将不允许关注的位置设为负无穷:

def masked_attention(Q, K, V, mask):
    d_k = Q.shape[-1]
    scores = Q @ K.transpose(-2, -1) / (d_k ** 0.5)

    # 应用掩码:将 mask=0 的位置设为 -inf
    scores = scores.masked_fill(mask == 0, float('-inf'))

    attn_weights = F.softmax(scores, dim=-1)
    output = attn_weights @ V
    return output

掩码后的注意力分数

before softmax:
[[ 0.5   -inf   -inf   -inf]
 [ 0.3    0.7   -inf   -inf]
 [ 0.2    0.4    0.6   -inf]
 [ 0.1    0.3    0.5    0.8]]

after softmax:
[[1.00   0.00   0.00   0.00]  # 只关注位置 0
 [0.40   0.60   0.00   0.00]  # 关注位置 0, 1
 [0.25   0.33   0.42   0.00]  # 关注位置 0, 1, 2
 [0.15   0.22   0.28   0.35]] # 关注全部

6. 注意力的计算复杂度

6.1 时间复杂度

核心计算 $QK^T$ 和 $(\text{softmax})V$:

  • $QK^T$:$[n, d] \times [d, n] = O(n^2 d)$
  • $\text{Attention} \times V$:$[n, n] \times [n, d] = O(n^2 d)$

总时间复杂度:$O(n^2 d)$

其中 $n$ 是序列长度,$d$ 是维度。

6.2 空间复杂度

需要存储注意力权重矩阵:

空间复杂度:$O(n^2)$

6.3 长序列的挑战

graph LR
    subgraph 序列长度影响
        L1[n=512] --> C1[计算量 262K]
        L2[n=2048] --> C2[计算量 4.2M]
        L3[n=8192] --> C3[计算量 67M]
        L4[n=32768] --> C4[计算量 1B]
    end

当序列长度增加 4 倍,计算量增加 16 倍!这是长序列 LLM 面临的核心挑战。

6.4 优化方法简介

方法原理复杂度
Flash AttentionIO 优化,减少内存访问O(n²) 但更快
Sparse Attention稀疏注意力模式O(n√n) 或 O(n)
Linear Attention核方法近似O(n)
Sliding Window只关注局部窗口O(nw)

vLLM 主要使用 Flash Attention 作为注意力后端。


7. Grouped-Query Attention (GQA)

7.1 传统 MHA vs GQA

为了减少 KV Cache 的内存占用,现代模型使用 GQA:

graph TB
    subgraph MHA(Multi-Head Attention)
        MQ1[Q Head 1] --> MK1[K Head 1]
        MQ2[Q Head 2] --> MK2[K Head 2]
        MQ3[Q Head 3] --> MK3[K Head 3]
        MQ4[Q Head 4] --> MK4[K Head 4]
    end

    subgraph GQA(Grouped-Query Attention)
        GQ1[Q Head 1] --> GK1[K Group 1]
        GQ2[Q Head 2] --> GK1
        GQ3[Q Head 3] --> GK2[K Group 2]
        GQ4[Q Head 4] --> GK2
    end

7.2 GQA 的优势

特性MHAGQA
Q headsNN
K/V headsNN/group_size
KV Cache 大小100%减少到 1/group_size
模型质量基准接近基准

示例(LLaMA-2-70B):

  • Q heads: 64
  • KV heads: 8
  • KV Cache 减少 8 倍!

8. 注意力与 KV Cache 的关系

8.1 为什么需要缓存 K 和 V

在自回归生成中,每生成一个新 token,都需要计算它与所有历史 token 的注意力。

不使用 KV Cache:每次都重新计算所有 token 的 K 和 V 使用 KV Cache:缓存历史 token 的 K 和 V,只计算新 token 的

这正是下一章的主题!

8.2 预览:KV Cache 的作用

sequenceDiagram
    participant New as 新 Token
    participant Cache as KV Cache
    participant ATT as Attention

    Note over Cache: 存储历史 token 的 K, V

    New->>ATT: 计算新 token 的 Q, K, V
    Cache->>ATT: 提供历史 K, V
    ATT->>ATT: Q_new × [K_cache, K_new]^T
    ATT->>ATT: Attention × [V_cache, V_new]
    ATT->>Cache: 将 K_new, V_new 加入缓存

9. 本章小结

核心公式

$$ \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right) V $$

关键概念

概念说明
Q/K/VQuery(查询)、Key(键)、Value(值)
缩放因子$\sqrt{d_k}$,防止 softmax 饱和
多头注意力并行学习多种注意力模式
因果掩码防止看到未来 token
GQA减少 KV heads,降低内存占用

计算复杂度

  • 时间复杂度:$O(n^2 d)$
  • 空间复杂度:$O(n^2)$
  • 长序列是主要挑战

与 vLLM 的关联

  • KV Cache 是注意力优化的核心
  • PagedAttention 优化 K/V 的内存管理
  • Flash Attention 优化注意力计算速度

思考题

  1. 如果没有缩放因子 $\sqrt{d_k}$,会发生什么?
  2. 为什么 GQA 可以在减少 KV heads 的同时保持模型质量?
  3. 在因果掩码下,位置 0 的 token 只能关注自己,这会影响模型效果吗?

下一步

理解了注意力机制后,我们将深入学习 KV Cache 的概念和作用:

👉 下一章:KV Cache 概念

4 - KV Cache 概念

KV Cache 概念

本章将详细介绍 KV Cache 的概念、作用和实现原理,这是理解 vLLM 核心优化的关键。


引言

KV Cache 是 LLM 推理中最重要的优化技术之一。它通过缓存历史计算结果,避免重复计算,显著提升推理速度。理解 KV Cache 对于理解 vLLM 的 PagedAttention 至关重要。


1. 为什么需要 KV Cache

1.1 自回归生成的特点

LLM 生成文本是自回归的:每次只生成一个 token,然后将其加入输入,继续生成下一个。

sequenceDiagram
    participant User as 用户
    participant LLM as LLM

    User->>LLM: "今天天气"
    LLM-->>LLM: 计算所有 token 的 Attention
    LLM->>User: "很"

    User->>LLM: "今天天气很"
    LLM-->>LLM: 重新计算所有 token 的 Attention?
    LLM->>User: "好"

    User->>LLM: "今天天气很好"
    LLM-->>LLM: 又重新计算所有?
    LLM->>User: "。"

1.2 没有 KV Cache 时的重复计算

在注意力计算中,每个 token 需要:

  1. 计算自己的 Q(Query)
  2. 计算自己的 K(Key)和 V(Value)
  3. 用 Q 与所有 K 计算注意力
  4. 用注意力加权所有 V

问题:历史 token 的 K 和 V 每次都要重新计算!

flowchart TD
    subgraph Step 1: 处理 'Hello'
        A1[Hello] --> K1[计算 K₁]
        A1 --> V1[计算 V₁]
        A1 --> Q1[计算 Q₁]
    end

    subgraph Step 2: 处理 'Hello World'
        B1[Hello] --> K1_2[重新计算 K₁]
        B1 --> V1_2[重新计算 V₁]
        B2[World] --> K2[计算 K₂]
        B2 --> V2[计算 V₂]
        B2 --> Q2[计算 Q₂]
    end

    subgraph Step 3: 处理 'Hello World !'
        C1[Hello] --> K1_3[再次计算 K₁]
        C1 --> V1_3[再次计算 V₁]
        C2[World] --> K2_3[再次计算 K₂]
        C2 --> V2_3[再次计算 V₂]
        C3[!] --> K3[计算 K₃]
        C3 --> V3[计算 V₃]
        C3 --> Q3[计算 Q₃]
    end

    style K1_2 fill:#ffcdd2
    style V1_2 fill:#ffcdd2
    style K1_3 fill:#ffcdd2
    style V1_3 fill:#ffcdd2
    style K2_3 fill:#ffcdd2
    style V2_3 fill:#ffcdd2

1.3 计算量分析

生成 N 个 token,不使用 KV Cache:

Step需要计算的 K/V累计 K/V 计算次数
111
22(重新计算 1 + 新的 1)1 + 2 = 3
33(重新计算 2 + 新的 1)3 + 3 = 6
NN1 + 2 + … + N = N(N+1)/2

时间复杂度:$O(N^2)$


2. KV Cache 工作原理

2.1 核心思想

观察:在自回归生成中,历史 token 的 K 和 V 不会改变。

解决方案:计算一次后缓存起来,后续直接使用。

flowchart TD
    subgraph 使用 KV Cache
        subgraph Step 1
            S1A[Hello] --> S1K[计算 K₁]
            S1A --> S1V[计算 V₁]
            S1K --> Cache1[(缓存 K₁)]
            S1V --> Cache1
        end

        subgraph Step 2
            Cache1 --> Use1[使用缓存的 K₁, V₁]
            S2A[World] --> S2K[计算 K₂]
            S2A --> S2V[计算 V₂]
            S2K --> Cache2[(缓存 K₁, K₂)]
            S2V --> Cache2
        end

        subgraph Step 3
            Cache2 --> Use2[使用缓存的 K₁, K₂, V₁, V₂]
            S3A[!] --> S3K[计算 K₃]
            S3A --> S3V[计算 V₃]
        end
    end

    style Use1 fill:#c8e6c9
    style Use2 fill:#c8e6c9

2.2 计算量对比

使用 KV Cache 后:

Step需要计算的 K/V累计 K/V 计算次数
111
21(只计算新的)1 + 1 = 2
31(只计算新的)2 + 1 = 3
N1N

时间复杂度:$O(N)$

加速比:从 $O(N^2)$ 到 $O(N)$,生成 1000 个 token 时加速约 500 倍!

2.3 图解对比

graph TD
    subgraph 无 KV Cache
        A1[Token 1] --> C1[计算全部 K,V]
        A2[Token 1,2] --> C2[计算全部 K,V]
        A3[Token 1,2,3] --> C3[计算全部 K,V]
        A4[Token 1,2,3,4] --> C4[计算全部 K,V]
        style A1 fill:#ffcdd2
        style A2 fill:#ffcdd2
        style A3 fill:#ffcdd2
        style A4 fill:#ffcdd2
    end

    subgraph 有 KV Cache
        B1[Token 1] --> D1[计算 K₁,V₁ + 缓存]
        B2[Token 2] --> D2[计算 K₂,V₂ + 读缓存]
        B3[Token 3] --> D3[计算 K₃,V₃ + 读缓存]
        B4[Token 4] --> D4[计算 K₄,V₄ + 读缓存]
        D1 --> Cache[(KV Cache)]
        D2 --> Cache
        D3 --> Cache
        D4 --> Cache
        Cache --> D2
        Cache --> D3
        Cache --> D4
        style B1 fill:#c8e6c9
        style B2 fill:#c8e6c9
        style B3 fill:#c8e6c9
        style B4 fill:#c8e6c9
    end

3. KV Cache 的数据结构

3.1 基本形状

KV Cache 需要存储每层的 K 和 V:

# KV Cache 形状
# 方式 1: 分开存储
k_cache = torch.zeros(num_layers, batch_size, num_heads, max_seq_len, head_dim)
v_cache = torch.zeros(num_layers, batch_size, num_heads, max_seq_len, head_dim)

# 方式 2: 合并存储
kv_cache = torch.zeros(num_layers, 2, batch_size, num_heads, max_seq_len, head_dim)
# kv_cache[:, 0, ...] 是 K
# kv_cache[:, 1, ...] 是 V

3.2 维度解释

维度含义示例值
num_layersTransformer 层数32
2K 和 V2
batch_size批次大小1-64
num_heads注意力头数(或 KV heads)32 或 8
max_seq_len最大序列长度4096
head_dim每个头的维度128

3.3 代码示例

class KVCache:
    def __init__(self, num_layers, num_heads, head_dim, max_seq_len, dtype=torch.float16):
        self.num_layers = num_layers
        self.max_seq_len = max_seq_len

        # 预分配 K 和 V 缓存
        # 形状: [num_layers, 2, max_batch, num_heads, max_seq_len, head_dim]
        self.cache = None
        self.current_len = 0

    def allocate(self, batch_size):
        self.cache = torch.zeros(
            self.num_layers, 2, batch_size, self.num_heads,
            self.max_seq_len, self.head_dim,
            dtype=self.dtype, device='cuda'
        )
        self.current_len = 0

    def update(self, layer_idx, new_k, new_v):
        """添加新的 K, V 到缓存"""
        # new_k, new_v: [batch, num_heads, new_len, head_dim]
        new_len = new_k.shape[2]
        start_pos = self.current_len
        end_pos = start_pos + new_len

        self.cache[layer_idx, 0, :, :, start_pos:end_pos, :] = new_k
        self.cache[layer_idx, 1, :, :, start_pos:end_pos, :] = new_v

        if layer_idx == self.num_layers - 1:
            self.current_len = end_pos

    def get(self, layer_idx):
        """获取当前层的完整 K, V"""
        k = self.cache[layer_idx, 0, :, :, :self.current_len, :]
        v = self.cache[layer_idx, 1, :, :, :self.current_len, :]
        return k, v

4. 显存占用详细计算

4.1 计算公式

KV Cache 显存 = 2 × num_layers × num_kv_heads × head_dim × seq_len × batch_size × bytes_per_element

简化版(使用 hidden_dim):

KV Cache 显存 = 2 × num_layers × hidden_dim × seq_len × batch_size × bytes_per_element

注意:如果使用 GQA,num_kv_heads 可能小于 num_attention_heads。

4.2 LLaMA-2-7B 示例

模型参数

  • num_layers: 32
  • hidden_dim: 4096
  • num_kv_heads: 32(MHA)
  • head_dim: 128
  • 精度: FP16(2 bytes)

单个请求不同序列长度的 KV Cache

序列长度计算大小
5122 × 32 × 4096 × 512 × 2256 MB
10242 × 32 × 4096 × 1024 × 2512 MB
20482 × 32 × 4096 × 2048 × 21 GB
40962 × 32 × 4096 × 4096 × 22 GB
81922 × 32 × 4096 × 8192 × 24 GB

4.3 LLaMA-2-70B 示例(使用 GQA)

模型参数

  • num_layers: 80
  • hidden_dim: 8192
  • num_kv_heads: 8(GQA,原本是 64 个 attention heads)
  • head_dim: 128
  • 精度: FP16

单个请求 4096 序列长度

KV Cache = 2 × 80 × 8 × 128 × 4096 × 2 = 1.34 GB

对比 MHA(如果 kv_heads = 64):

KV Cache = 2 × 80 × 64 × 128 × 4096 × 2 = 10.7 GB

GQA 节省了 8 倍显存!

4.4 显存占用可视化

pie title 7B 模型显存分布(单请求 2048 tokens)
    "模型权重 (14GB)" : 14
    "KV Cache (1GB)" : 1
    "激活值等 (1GB)" : 1
pie title 7B 模型显存分布(32 并发 × 2048 tokens)
    "模型权重 (14GB)" : 14
    "KV Cache (32GB)" : 32
    "激活值等 (2GB)" : 2

5. KV Cache 管理的挑战

5.1 动态序列长度

KV Cache 的大小随着生成过程动态增长:

graph LR
    subgraph 生成过程
        S1[Step 1<br/>KV: 10 tokens]
        S2[Step 2<br/>KV: 11 tokens]
        S3[Step 3<br/>KV: 12 tokens]
        SN[Step N<br/>KV: N+10 tokens]
        S1 --> S2 --> S3 --> SN
    end

问题:在请求开始时,我们不知道最终会生成多少 token!

5.2 预分配策略的问题

传统方案:预分配最大可能长度(如 4096 tokens)

预分配: 4096 tokens × 每token 0.5MB = 2GB
实际使用: 100 tokens × 0.5MB = 50MB
浪费: 1.95GB (97.5%)
graph TB
    subgraph 预分配的浪费
        Alloc[预分配 2GB]
        Used[实际使用 50MB]
        Waste[浪费 1.95GB]
        Alloc --> Used
        Alloc --> Waste
    end

    style Waste fill:#ffcdd2

5.3 显存碎片化

当多个请求同时运行时,问题更加严重:

显存状态:
+--------+--------+--------+--------+--------+
| Req A  | Req B  | Req C  | Req D  | 空闲   |
| 2GB    | 2GB    | 2GB    | 2GB    | 碎片   |
| 用50MB | 用100MB| 用30MB | 用200MB|        |
+--------+--------+--------+--------+--------+

实际使用: 380MB
预分配: 8GB
浪费: 7.62GB (95%!)

5.4 这就是 PagedAttention 要解决的问题!

传统方案的问题:

  1. 预分配浪费:每个请求预留最大空间
  2. 内部碎片:实际使用远小于预分配
  3. 外部碎片:释放后的空间不连续

PagedAttention 的解决方案(下一部分详细介绍):

  1. 按需分配:用多少分配多少
  2. 分块管理:固定大小的块,减少碎片
  3. 非连续存储:块可以不连续

6. Prefill 和 Decode 中的 KV Cache

6.1 Prefill 阶段

处理输入 prompt,一次性计算所有输入 token 的 K、V:

flowchart LR
    subgraph Prefill
        I[输入: 'Hello, how are you?'<br/>5 tokens]
        C[并行计算 K₁...K₅, V₁...V₅]
        S[存入 KV Cache]
        I --> C --> S
    end

特点

  • 批量计算,效率高
  • 计算密集型
  • KV Cache 从 0 增长到输入长度

6.2 Decode 阶段

逐个生成 token,每次只计算新 token 的 K、V:

flowchart TD
    subgraph Decode 循环
        R[读取 KV Cache]
        N[新 token]
        C[计算 K_new, V_new]
        A[Attention: Q_new × [K_cache; K_new]]
        U[更新 KV Cache]
        O[输出 token]

        R --> A
        N --> C --> A
        A --> U --> O
        O -.->|下一轮| N
    end

特点

  • 增量计算,每次只算 1 个
  • 内存密集型(需要读取整个 KV Cache)
  • KV Cache 每步增长 1

6.3 两阶段的 KV Cache 操作对比

操作PrefillDecode
K/V 计算批量(N 个)单个(1 个)
KV Cache 读取全部
KV Cache 写入N 个1 个
计算/访存比

7. vLLM 中的 KV Cache 相关代码

7.1 关键文件位置

功能文件
KV Cache 管理vllm/v1/core/kv_cache_manager.py
块池vllm/v1/core/block_pool.py
块表vllm/v1/worker/block_table.py
KV Cache 接口vllm/v1/kv_cache_interface.py

7.2 数据结构预览

# vllm/v1/core/block_pool.py 中的块定义
@dataclass
class KVCacheBlock:
    block_id: int          # 块 ID
    ref_cnt: int           # 引用计数
    block_hash: Optional[BlockHash]  # 用于前缀缓存

# vllm/v1/worker/block_table.py 中的块表
class BlockTable:
    """管理逻辑块到物理块的映射"""
    def __init__(self, ...):
        self.block_table: torch.Tensor  # 形状: [max_blocks]

8. 本章小结

核心概念

  1. KV Cache 的作用:缓存历史 token 的 K、V,避免重复计算
  2. 加速效果:从 $O(N^2)$ 降到 $O(N)$,约 500 倍加速(N=1000)
  3. 显存占用:随序列长度线性增长,可能成为主要显存消耗

关键公式

KV Cache = 2 × num_layers × num_kv_heads × head_dim × seq_len × bytes

管理挑战

  • 动态增长:序列长度在生成过程中不断增加
  • 预分配浪费:传统方案浪费 60-80% 显存
  • 碎片化:多请求并发时问题更严重

与 vLLM 的关联

  • PagedAttention:解决 KV Cache 的显存浪费问题
  • 分块管理:将 KV Cache 分成固定大小的块
  • 按需分配:用多少分配多少,不预留

思考题

  1. 如果一个模型使用 GQA,KV heads 是 attention heads 的 1/8,KV Cache 显存会减少多少?
  2. 为什么 Decode 阶段是"内存密集型"而不是"计算密集型"?
  3. 如果 vLLM 要支持无限长度的上下文,KV Cache 管理会面临什么额外挑战?

下一步

了解了 KV Cache 后,让我们来看看 LLM 完整的生成过程:

👉 下一章:LLM 生成过程

5 - LLM 生成过程

LLM 生成过程

本章将详细介绍 LLM 文本生成的完整流程,包括 Prefill、Decode 两个阶段以及各种采样策略。


引言

LLM 生成文本是一个复杂的过程,涉及 tokenization、模型前向传播、采样等多个环节。理解这个过程对于理解 vLLM 的优化策略至关重要。


1. 生成流程概览

1.1 完整流程图

sequenceDiagram
    participant User as 用户
    participant Tok as Tokenizer
    participant Model as LLM
    participant Sampler as 采样器
    participant DeTok as Detokenizer

    User->>Tok: "Hello, world"
    Tok->>Model: [15496, 11, 995]

    rect rgb(200, 230, 200)
        Note over Model: Prefill 阶段
        Model->>Model: 处理所有输入 tokens
        Model->>Model: 初始化 KV Cache
        Model->>Sampler: logits
        Sampler->>Model: 第一个输出 token
    end

    rect rgb(200, 200, 230)
        Note over Model: Decode 阶段
        loop 直到停止条件
            Model->>Model: 处理 1 个新 token
            Model->>Model: 更新 KV Cache
            Model->>Sampler: logits
            Sampler->>Model: 下一个 token
        end
    end

    Model->>DeTok: [318, 716, 257, ...]
    DeTok->>User: "I am a language model..."

1.2 两阶段模型

LLM 生成分为两个截然不同的阶段:

阶段Prefill(预填充)Decode(解码)
处理内容整个输入 prompt新生成的 token
每次处理N 个 tokens1 个 token
KV Cache初始化增量更新
计算特性计算密集型内存密集型
GPU 利用率

2. Prefill 阶段详解

2.1 输入处理:Tokenization

第一步是将文本转换为 token IDs:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")

text = "Hello, how are you?"
tokens = tokenizer.encode(text)
print(tokens)  # [1, 15043, 29892, 920, 526, 366, 29973]
print(tokenizer.convert_ids_to_tokens(tokens))
# ['<s>', 'Hello', ',', 'how', 'are', 'you', '?']

2.2 并行计算所有 Token

在 Prefill 阶段,所有输入 token 可以并行处理:

flowchart TD
    subgraph Prefill 并行处理
        I[输入: token_ids<br/>[1, 15043, 29892, 920, 526, 366]]
        E[Embedding Layer<br/>并行查表]
        PE[Position Encoding<br/>添加位置信息]

        subgraph Transformer Layers
            L1[Layer 1]
            L2[Layer 2]
            LN[Layer N]
        end

        LH[LM Head]
        O[Logits<br/>[seq_len, vocab_size]]

        I --> E --> PE --> L1 --> L2 --> LN --> LH --> O
    end

    style E fill:#e3f2fd
    style L1 fill:#c8e6c9
    style L2 fill:#c8e6c9
    style LN fill:#c8e6c9

2.3 KV Cache 初始化与填充

Prefill 期间,计算并存储所有输入 token 的 K、V:

def prefill(model, input_ids, kv_cache):
    """
    input_ids: [batch_size, seq_len]
    """
    batch_size, seq_len = input_ids.shape

    # Embedding
    hidden_states = model.embed_tokens(input_ids)  # [batch, seq, hidden]

    # 遍历每一层
    for layer_idx, layer in enumerate(model.layers):
        # 计算 Q, K, V
        q = layer.q_proj(hidden_states)
        k = layer.k_proj(hidden_states)
        v = layer.v_proj(hidden_states)

        # 存入 KV Cache
        kv_cache.update(layer_idx, k, v)

        # 自注意力计算
        # ... (使用完整的 K, V,应用因果掩码)

        # FFN
        # ...

    # LM Head
    logits = model.lm_head(hidden_states)

    # 只返回最后一个位置的 logits(用于预测下一个 token)
    return logits[:, -1, :]  # [batch, vocab_size]

2.4 生成第一个 Token

使用最后一个位置的 logits 生成第一个输出 token:

def generate_first_token(logits, sampling_params):
    """
    logits: [batch_size, vocab_size]
    """
    # 应用采样策略
    next_token = sample(logits, sampling_params)  # [batch_size, 1]
    return next_token

3. Decode 阶段详解

3.1 单 Token 增量计算

Decode 阶段每次只处理一个新 token:

flowchart LR
    subgraph Decode 增量计算
        NT[新 token]
        E[Embedding]
        Q[计算 Q_new]
        KV[计算 K_new, V_new]
        Cache[(读取 KV Cache)]
        ATT[Attention<br/>Q_new × [K_cache; K_new]ᵀ]
        Update[更新 KV Cache]
        FFN[FFN]
        LM[LM Head]
        O[Logits]

        NT --> E --> Q
        E --> KV
        Cache --> ATT
        KV --> ATT
        Q --> ATT
        ATT --> FFN --> LM --> O
        KV --> Update --> Cache
    end

3.2 如何利用 KV Cache

def decode_step(model, new_token_id, kv_cache, position):
    """
    new_token_id: [batch_size, 1]
    position: 当前位置索引
    """
    # Embedding
    hidden_states = model.embed_tokens(new_token_id)  # [batch, 1, hidden]

    # 遍历每一层
    for layer_idx, layer in enumerate(model.layers):
        # 只计算新 token 的 Q, K, V
        q_new = layer.q_proj(hidden_states)  # [batch, 1, hidden]
        k_new = layer.k_proj(hidden_states)
        v_new = layer.v_proj(hidden_states)

        # 从缓存获取历史 K, V
        k_cache, v_cache = kv_cache.get(layer_idx)

        # 合并:[k_cache, k_new] 和 [v_cache, v_new]
        k_full = torch.cat([k_cache, k_new], dim=2)
        v_full = torch.cat([v_cache, v_new], dim=2)

        # 更新缓存
        kv_cache.update(layer_idx, k_new, v_new)

        # 注意力计算:Q_new (1个) 与 K_full (N+1个)
        # scores: [batch, heads, 1, N+1]
        scores = (q_new @ k_full.transpose(-2, -1)) / sqrt(head_dim)

        # 无需因果掩码(新 token 可以看到所有历史)
        attn_weights = F.softmax(scores, dim=-1)

        # 加权求和
        attn_output = attn_weights @ v_full  # [batch, heads, 1, head_dim]

        # ... FFN 等

    # LM Head
    logits = model.lm_head(hidden_states)  # [batch, 1, vocab_size]

    return logits.squeeze(1)  # [batch, vocab_size]

3.3 Decode 循环

def decode_loop(model, first_token, kv_cache, max_tokens, stop_token_id):
    """完整的 decode 循环"""
    generated_tokens = [first_token]
    current_token = first_token
    position = kv_cache.current_len

    for step in range(max_tokens):
        # 执行一步 decode
        logits = decode_step(model, current_token, kv_cache, position)

        # 采样下一个 token
        next_token = sample(logits, sampling_params)

        # 检查停止条件
        if next_token == stop_token_id:
            break

        generated_tokens.append(next_token)
        current_token = next_token
        position += 1

    return generated_tokens

4. 采样策略详解

4.1 从 Logits 到概率分布

模型输出的是 logits(未归一化的分数),需要转换为概率分布:

# logits: [vocab_size]
# 例如: [-1.2, 0.5, 2.3, -0.1, ...]

# 转换为概率
probs = F.softmax(logits, dim=-1)
# probs: [0.01, 0.05, 0.30, 0.03, ...]  和为 1

4.2 Greedy Decoding(贪婪解码)

最简单的策略:每次选择概率最高的 token。

def greedy_decode(logits):
    return torch.argmax(logits, dim=-1)

特点

  • 确定性(相同输入总是相同输出)
  • 可能陷入重复
  • 不适合创意生成

4.3 Temperature(温度)

Temperature 控制概率分布的"尖锐"程度:

def apply_temperature(logits, temperature):
    return logits / temperature
graph LR
    subgraph Temperature 效果
        T1[T=0.1<br/>非常尖锐<br/>几乎是 Greedy]
        T2[T=1.0<br/>原始分布]
        T3[T=2.0<br/>更平滑<br/>更随机]
    end
Temperature效果适用场景
< 1.0更确定,偏向高概率事实性回答
= 1.0原始分布一般场景
> 1.0更随机,更多样创意写作

4.4 Top-k Sampling

只从概率最高的 k 个 token 中采样:

def top_k_sampling(logits, k):
    # 找到 top-k 的值和索引
    top_k_logits, top_k_indices = torch.topk(logits, k)

    # 将其他位置设为 -inf
    filtered_logits = torch.full_like(logits, float('-inf'))
    filtered_logits.scatter_(-1, top_k_indices, top_k_logits)

    # 重新计算概率并采样
    probs = F.softmax(filtered_logits, dim=-1)
    return torch.multinomial(probs, num_samples=1)

示例(k=3):

原始概率: [0.40, 0.30, 0.15, 0.10, 0.05]
Top-3:    [0.40, 0.30, 0.15, 0.00, 0.00]
归一化后: [0.47, 0.35, 0.18, 0.00, 0.00]

4.5 Top-p (Nucleus) Sampling

选择累积概率达到 p 的最小 token 集合:

def top_p_sampling(logits, p):
    # 排序
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    probs = F.softmax(sorted_logits, dim=-1)

    # 计算累积概率
    cumsum_probs = torch.cumsum(probs, dim=-1)

    # 找到累积概率 > p 的位置
    sorted_indices_to_remove = cumsum_probs > p
    # 保留第一个超过阈值的
    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
    sorted_indices_to_remove[..., 0] = False

    # 过滤
    sorted_logits[sorted_indices_to_remove] = float('-inf')

    # 采样
    probs = F.softmax(sorted_logits, dim=-1)
    return torch.multinomial(probs, num_samples=1)

示例(p=0.9):

排序后概率: [0.40, 0.30, 0.15, 0.10, 0.05]
累积概率:   [0.40, 0.70, 0.85, 0.95, 1.00]
                                ↑ 超过 0.9
保留:       [0.40, 0.30, 0.15, 0.10]  累积 = 0.95

4.6 采样策略对比

graph TD
    subgraph 采样策略选择
        G[Greedy<br/>确定性、可能重复]
        TK[Top-k<br/>固定数量的候选]
        TP[Top-p<br/>动态数量的候选]
        T[Temperature<br/>控制随机程度]

        G --> |适合| F[事实问答]
        TK --> |适合| C1[通用对话]
        TP --> |适合| C2[创意写作]
        T --> |配合| TK
        T --> |配合| TP
    end

4.7 常用参数组合

场景TemperatureTop-pTop-k
代码生成0.1-0.3--
事实问答0.0-0.50.9-
通用对话0.7-0.90.940
创意写作1.0-1.20.9550
脑暴创意1.5-2.00.98100

5. 停止条件

5.1 常见停止条件

def check_stop_condition(token_id, generated_tokens, params):
    # 1. 生成了 EOS token
    if token_id == params.eos_token_id:
        return True, "EOS"

    # 2. 达到最大长度
    if len(generated_tokens) >= params.max_tokens:
        return True, "MAX_LENGTH"

    # 3. 遇到停止字符串
    text = tokenizer.decode(generated_tokens)
    for stop_str in params.stop_strings:
        if stop_str in text:
            return True, "STOP_STRING"

    return False, None

5.2 vLLM 中的停止条件

# vllm/sampling_params.py
class SamplingParams:
    max_tokens: int = 16           # 最大生成 token 数
    stop: List[str] = []           # 停止字符串
    stop_token_ids: List[int] = [] # 停止 token ID
    include_stop_str_in_output: bool = False
    ignore_eos: bool = False       # 是否忽略 EOS

6. 计算特性对比

6.1 Prefill vs Decode

graph LR
    subgraph Prefill
        P1[处理 N 个 tokens]
        P2[计算量: O(N² × d)]
        P3[内存访问: O(N × d)]
        P4[计算密度: 高]
    end

    subgraph Decode
        D1[处理 1 个 token]
        D2[计算量: O(N × d)]
        D3[内存访问: O(N × d)]
        D4[计算密度: 低]
    end
特性PrefillDecode
每次处理 tokensN1
Attention 计算Q[N] × K[N]ᵀQ[1] × K[N]ᵀ
计算量O(N²d)O(Nd)
内存读取模型权重模型权重 + KV Cache
计算/访存比
GPU 利用率50-80%10-30%
瓶颈计算内存带宽

6.2 GPU 利用率可视化

gantt
    title GPU 利用率时间线
    dateFormat X
    axisFormat %s

    section GPU 计算
    Prefill (高利用率) :done, p, 0, 20
    Decode Step 1 (低利用率) :crit, d1, 20, 25
    Decode Step 2 (低利用率) :crit, d2, 25, 30
    Decode Step 3 (低利用率) :crit, d3, 30, 35
    ...更多 decode steps :crit, dn, 35, 80

6.3 批处理的重要性

单独处理一个 decode step 时,GPU 大部分时间在等待数据传输。通过批处理多个请求,可以提高 GPU 利用率:

# 单请求
def decode_single(request):
    read_weights()      # 14GB
    process_1_token()   # 很小的计算量
    # GPU 大部分时间空闲

# 批处理
def decode_batch(requests, batch_size=32):
    read_weights()      # 14GB(只读一次)
    process_32_tokens() # 32 倍的计算量
    # GPU 利用率提高 32 倍

7. 完整生成示例

7.1 代码示例

def generate(model, tokenizer, prompt, max_tokens=100, temperature=0.8, top_p=0.9):
    # 1. Tokenization
    input_ids = tokenizer.encode(prompt, return_tensors='pt').cuda()

    # 2. 初始化 KV Cache
    kv_cache = KVCache(model.config)
    kv_cache.allocate(batch_size=1)

    # 3. Prefill 阶段
    logits = prefill(model, input_ids, kv_cache)

    # 4. 采样第一个 token
    sampling_params = SamplingParams(temperature=temperature, top_p=top_p)
    first_token = sample(logits, sampling_params)
    generated_tokens = [first_token.item()]

    # 5. Decode 循环
    current_token = first_token
    for _ in range(max_tokens - 1):
        # Decode 一步
        logits = decode_step(model, current_token, kv_cache)

        # 采样
        next_token = sample(logits, sampling_params)

        # 检查停止条件
        if next_token.item() == tokenizer.eos_token_id:
            break

        generated_tokens.append(next_token.item())
        current_token = next_token

    # 6. Detokenization
    output_text = tokenizer.decode(generated_tokens)
    return output_text

# 使用
output = generate(model, tokenizer, "Once upon a time", max_tokens=50)
print(output)

7.2 时序图

sequenceDiagram
    participant T as Tokenizer
    participant P as Prefill
    participant D as Decode
    participant S as Sampler
    participant C as KV Cache

    Note over T,C: 输入: "Hello"

    T->>P: token_ids = [1, 15043]
    P->>C: 初始化缓存
    P->>C: 存储 K[0:2], V[0:2]
    P->>S: logits
    S->>D: token_id = 318 ("I")

    loop Decode 循环
        D->>C: 读取 K[0:n], V[0:n]
        D->>C: 写入 K[n], V[n]
        D->>S: logits
        S->>D: next_token
    end

    Note over T,C: 输出: "I am fine"

8. 本章小结

生成流程

  1. Tokenization:文本 → Token IDs
  2. Prefill:并行处理输入,初始化 KV Cache
  3. Decode:逐个生成 token,增量更新 KV Cache
  4. Sampling:从 logits 采样 token
  5. Detokenization:Token IDs → 文本

两阶段特性

阶段PrefillDecode
并行度低(每次 1 token)
计算密度
瓶颈计算内存带宽
优化重点并行计算批处理

采样策略

  • Greedy:确定性,取最大概率
  • Temperature:控制随机程度
  • Top-k:限制候选数量
  • Top-p:动态限制累积概率

与 vLLM 的关联

  • Continuous Batching:动态组合 Prefill 和 Decode
  • Chunked Prefill:分块处理长输入
  • 采样优化:批量采样提高效率

思考题

  1. 为什么 Decode 阶段不能像 Prefill 那样并行处理多个 token?
  2. 如果使用 temperature=0,结果会和 greedy decoding 一样吗?
  3. vLLM 的 Continuous Batching 如何同时处理 Prefill 和 Decode 请求?

下一步

深度学习基础部分已经完成!接下来我们将进入核心模块详解,首先介绍 vLLM 的核心创新——PagedAttention:

👉 下一章:PagedAttention 分页注意力