构造第一个模型
前言:集齐龙珠,召唤“神龙”——LLaMA的总装时刻
在过去的几章中,我们像锻造师一样,精心打造了Transformer的每一个核心零件:
我们学会了“Attention”这个灵魂
我们升级了“Attention”的威力,使其具备了“多视角”能力
我们还补齐了它的“左膀右臂”——前馈网络、残差连接和层归一化
现在,是时候进行最终的总装了!我们将把这些经过“千锤百炼”的零件,按照Meta AI发布的LLaMA(当今开源大模型的“事实标准”)的精妙设计蓝图,层层拼接,最终构建出一个与真实LLaMA架构**“神形兼备”**的语言模型。
这趟“总装”之旅,将让你对现代大模型的内部运作,有一个前所未有的、深刻的理解。
第一章:LLaMA的“进化”回顾与核心组件 (快速复习)
快速回顾LLaMA相比于原始Transformer架构,在三个核心组件上的“魔鬼细节”改进,并重申多头注意力与前馈网络的核心地位。
1.1 RMSNorm:更高效的归一化
LLaMA抛弃了传统的LayerNorm,转而使用RMSNorm。它通过只计算均方根(RMS),省略了“减去均值”这一步。这不仅计算速度更快,而且在实践中发现效果几乎不受影响,甚至在深层网络中表现更稳定。它是一个典型的“减法智慧”。
1.2 SwiGLU:更智能的信息筛选门
原始Transformer的前馈网络通常使用ReLU或GELU。LLaMA引入了SwiGLU激活函数。这种激活函数带有一种“门控”机制,能让前馈网络动态地、根据输入内容来筛选和控制信息流,从而学习到更复杂的特征表示。它比简单的ReLU更具“智能”。
1 .3 旋转位置编码 (RoPE):优雅注入相对位置感
原始的Transformer使用绝对位置编码,简单地将位置信息“加”到词嵌入上。LLaMA则采用了旋转位置编码(RoPE)。RoPE通过复数旋转的方式,将位置信息**“乘”入到Q(查询)和K(键)向量中。这种方式能更好地捕捉词与词之间的相对位置关系**,并且在处理比训练时更长的文本时,表现出更好的外推性和稳定性。
1.4 多头注意力与前馈网络:两大核心算子
这两个模块依然是Transformer Block中最重要的两大计算单元。它们负责信息的交互、融合(多头注意力)和深度加工、抽象(前馈网络)。LLaMA的改进,就是将上面三个“魔鬼细节”巧妙地集成到这两个核心模块中。
第二章:【代码实战】逐一实现LLaMA的“进化版”核心组件
我们将把LLaMA的“魔鬼细节”转化为可运行的PyTorch代码模块。这些代码将是基于我们之前基础组件的升级版。为了代码的组织性和可读性,我们将把这些独立的模块定义在一个名为llama_components.py的文件中。
2.1实现RMSNorm层
LLaMA的RMSNorm比标准的nn.LayerNorm更简单、更高效。
# llama_components.py
import torch
import torch.nn as nn
import torch.nn.functional as F
class RMSNorm(nn.Module):
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
self.eps = eps
# gamma参数,用于缩放归一化后的输出,是可学习的
# 它的作用是让模型可以自由地调整归一化后的输出范围
self.weight = nn.Parameter(torch.ones(dim))
def _norm(self, x):
# 计算均方根 (Root Mean Square)
# x.pow(2).mean(-1, keepdim=True) 计算最后一个维度的平方均值
# torch.rsqrt 是计算 1 / sqrt(),比先sqrt再求倒数更数值稳定
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x):
# type_as(self.weight) 确保输入张量x的数据类型与self.weight一致,避免类型不匹配问题
return self._norm(x.type_as(self.weight)) * self.weight
# --- 测试 RMSNorm ---
if __name__ == "__main__":
print("--- 测试 RMSNorm ---")
dim = 768
rmsnorm = RMSNorm(dim)
test_tensor = torch.randn(2, 10, dim) # [batch_size, seq_len, embed_dim]
output_rmsnorm = rmsnorm(test_tensor)
print(f"RMSNorm 输入形状: {test_tensor.shape}, 输出形状: {output_rmsnorm.shape}")
# 均方根应接近1 (在RMSNorm后),验证归一化效果
print(f"RMSNorm 输出的均方根: {torch.sqrt(output_rmsnorm.pow(2).mean(-1)).mean().item():.4f} (应接近1.0)")
print("-" * 30)
RMSNorm的核心在于_norm方法,它通过计算输入张量最后一个维度(特征维度)的均方根,然后用其倒数来缩放输入。self.weight是一个可学习的参数,它允许模型在归一化后对特征进行额外的缩放。
2.2 实现SwiGLU激活函数
LLaMA的前馈网络中,使用了SwiGLU激活函数,它比ReLU或GELU更复杂,但具有门控机制。
# llama_components.py (续)
class SwiGLU(nn.Module):
# SwiGLU作为FFN的激活函数,其输入是升维后的结果
# 例如,FFN的第一层可能将 embed_dim 映射到 2 * hidden_dim
# 然后SwiGLU接收这个 2 * hidden_dim 的输入,并将其切半处理
def forward(self, x: torch.Tensor):
# x的形状: [batch_size, seq_len, 2 * hidden_dim]
# x.shape[-1] 是最后一个维度的大小,即 2 * hidden_dim
# x[..., :x.shape[-1] // 2] 取前半部分,x[..., x.shape[-1] // 2:] 取后半部分
# F.silu 是 Swish 激活函数,等同于 x * sigmoid(x)
return F.silu(x[..., :x.shape[-1] // 2]) * x[..., x.shape[-1] // 2:]
# --- 测试 SwiGLU (模拟FFN的中间输出) ---
if __name__ == "__main__":
print("--- 测试 SwiGLU ---")
embed_dim = 768
hidden_dim = embed_dim * 4 # LLaMA中FFN中间维度通常是embed_dim的4倍
# 模拟FFN第一层将embed_dim投影到 hidden_dim (实际是2*hidden_dim for SwiGLU input)
test_tensor_input = torch.randn(2, 10, hidden_dim) # 假设这里已经是FFN第一层输出的hidden_dim部分
# 为了测试SwiGLU的forward,我们需要手动构建一个满足其forward输入格式的tensor
# SwiGLU的输入是 Linear(embed_dim, 2 * hidden_dim) 的输出
# 假设这里的 hidden_dim 已经就是 SwiGLU 实际处理的 hidden_dim
# 所以我们构造一个形状为 [B, L, 2*hidden_dim] 的 tensor
mock_ffn_output = torch.randn(2, 10, 2 * hidden_dim)
swiglu = SwiGLU()
# 实际FFN中,这个x会是 w1(x) 和 w3(x) 拼接后的结果
output_swiglu = swiglu(mock_ffn_output)
print(f"SwiGLU 输入形状: {mock_ffn_output.shape}, 输出形状: {output_swiglu.shape}")
# 输出应该是 [B, L, hidden_dim],因为 2*hidden_dim 被处理后变成了 hidden_dim
assert output_swiglu.shape[-1] == hidden_dim
print("-" * 30)
【代码解读】
SwiGLU通过将输入x切成两半,其中一半应用SiLU(Swish)激活函数,然后与另一半进行逐元素乘法,实现了一个“门控”效果。这种非线性处理能更细致地控制信息流动,提升模型的表达能力。
2.3 实现RoPE(旋转位置编码)
这是LLaMA中最精妙的部分之一。RoPE将位置信息通过旋转的方式注入到Q和K向量中,而不是简单相加。
# llama_components.py (续)
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
"""
预计算旋转位置编码所需的频率。
dim: 旋转编码作用的向量维度 (即 head_dim)
end: 最大序列长度 (max_seq_len)
theta: 频率计算的基数,通常取10000.0
"""
# 计算频率基 freqs: 1 / (theta^(2i/dim))
# torch.arange(0, dim, 2) 得到 [0, 2, 4, ..., dim-2]
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
# t: 位置索引 [0, 1, ..., end-1]
t = torch.arange(end, device=freqs.device)
# 外积: 得到 freqs_table[pos_idx, freq_idx] = pos_idx * freqs[freq_idx]
freqs = torch.outer(t, freqs).float()
# torch.polar(abs, angle) 返回复数 abs * (cos(angle) + i * sin(angle))
# 这里的abs是1,所以就是 cos(freq) + i * sin(freq)
freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
return freqs_cis # 形状: [end, dim // 2] (复数维度)
def apply_rotary_emb(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor):
"""
对Q和K向量应用旋转位置编码。
xq, xk 形状: [bsz, seqlen, n_heads, head_dim] (Attention模块内部的维度顺序)
freqs_cis 形状: [seqlen, head_dim // 2] (复数维度)
"""
# 将Q, K的head_dim维度分成两半,并转换为复数
# reshape成 [bsz, seqlen, n_heads, head_dim // 2, 2]
# 然后 view_as_complex 转换为 [bsz, seqlen, n_heads, head_dim // 2] (复数)
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
# 应用旋转(复数乘法)
# freqs_cis 会自动广播到 n_heads 维度
# 注意: freqs_cis[:xq.shape[1]] 确保只取到当前序列长度的频率
xq_out = xq_ * freqs_cis[:xq.shape[1]].unsqueeze(1) # unsqueeze(1)是为了与n_heads维度匹配
xk_out = xk_ * freqs_cis[:xk.shape[1]].unsqueeze(1)
# 将复数结果转换回实数并调整形状
# view_as_real 将复数分解为实部和虚部 [bsz, seqlen, n_heads, head_dim // 2, 2]
# flatten(2) 将后两个维度压平,恢复为 [bsz, seqlen, n_heads, head_dim]
xq_out = torch.view_as_real(xq_out).flatten(2)
xk_out = torch.view_as_real(xk_out).flatten(2)
# .type_as(xq) 确保输出数据类型与输入xq保持一致
return xq_out.type_as(xq), xk_out.type_as(xk)
# --- 测试 RoPE (概念验证) ---
if __name__ == "__main__":
print("--- 测试 RoPE ---")
head_dim = 64
seq_len = 10
n_heads = 2
xq_test = torch.randn(1, seq_len, n_heads, head_dim) # 模拟Attention内部的Q
xk_test = torch.randn(1, seq_len, n_heads, head_dim) # 模拟Attention内部的K
freqs_cis = precompute_freqs_cis(head_dim, seq_len)
xq_rotated, xk_rotated = apply_rotary_emb(xq_test, xk_test, freqs_cis)
print(f"RoPE前 Query形状: {xq_test.shape}, 后 Query形状: {xq_rotated.shape}")
print(f"RoPE前 Key形状: {xk_test.shape}, 后 Key形状: {xk_rotated.shape}")
# 验证输出形状与输入保持一致
assert xq_test.shape == xq_rotated.shape
assert xk_test.shape == xk_rotated.shape
print("-" * 30)
【代码解读】
precompute_freqs_cis预先计算了旋转所需的频率。apply_rotary_emb是核心,它将Q和K向量的最后一个维度转换为复数,与预计算的复数频率相乘(复数乘法在几何上就是旋转),从而将位置信息编码进去。unsqueeze(1)和permute等操作用于确保维度广播和矩阵乘法时的正确对齐。
2.4 实现应用了RoPE的SelfAttention模块
将多头注意力的核心逻辑与RoPE结合。这里我们直接实现一个LLaMA风格的Attention模块,它将RoPE的应用集成在forward方法内部。
# llama_components.py (续)
class Attention(nn.Module):
def __init__(self, embed_dim: int, n_heads: int):
super().__init__()
self.n_heads = n_heads
self.head_dim = embed_dim // n_heads # 每个头的维度
assert self.head_dim * n_heads == embed_dim, "embed_dim must be divisible by n_heads"
# Q, K, V 的投影层,注意bias=False,是LLaMA的特性
# self.wq = nn.Linear(embed_dim, embed_dim, bias=False)
# self.wk = nn.Linear(embed_dim, embed_dim, bias=False)
# self.wv = nn.Linear(embed_dim, embed_dim, bias=False)
# LLaMA中QKV是合在一起的,更高效 (尽管这里我们分开写为了清晰)
self.wq = nn.Linear(embed_dim, n_heads * self.head_dim, bias=False)
self.wk = nn.Linear(embed_dim, n_heads * self.head_dim, bias=False)
self.wv = nn.Linear(embed_dim, n_heads * self.head_dim, bias=False)
# 最终的输出投影层
self.wo = nn.Linear(n_heads * self.head_dim, embed_dim, bias=False)
def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor, mask: torch.Tensor = None):
bsz, seqlen, embed_dim = x.shape
# 1. 生成Q, K, V (通过线性投影)
xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
# 2. 分头操作 & 调整维度顺序
# 原始: [bsz, seqlen, embed_dim]
# 拆分并调整为: [bsz, seqlen, n_heads, head_dim] -> [bsz, n_heads, seqlen, head_dim]
xq = xq.view(bsz, seqlen, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
xk = xk.view(bsz, seqlen, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
xv = xv.view(bsz, seqlen, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
# 3. 应用RoPE (旋转位置编码)
# freqs_cis 的形状 [seqlen, head_dim // 2]
# xq, xk 的形状 [bsz, n_heads, seqlen, head_dim]
# apply_rotary_emb 内部会处理维度匹配
xq_out, xk_out = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
# 4. 计算注意力分数
# (bsz, n_heads, seqlen, head_dim) @ (bsz, n_heads, head_dim, seqlen) -> (bsz, n_heads, seqlen, seqlen)
scores = torch.matmul(xq_out, xk_out.transpose(2, 3)) / (self.head_dim**0.5)
# 应用注意力遮罩 (例如:因果遮罩,防止预测未来)
if mask is not None:
# mask形状: [1, 1, seqlen, seqlen]
scores = scores + mask # mask通常包含负无穷大,加到scores上会让mask掉的位置softmax后为0
attention_weights = F.softmax(scores.float(), dim=-1).type_as(xq_out) # Softmax归一化
# 5. 加权求和
# (bsz, n_heads, seqlen, seqlen) @ (bsz, n_heads, seqlen, head_dim) -> (bsz, n_heads, seqlen, head_dim)
output = torch.matmul(attention_weights, xv)
# 6. 拼接各个头 & 最终投影
# 将多头输出的维度调整回来
# 原始: [bsz, n_heads, seqlen, head_dim]
# 调整为: [bsz, seqlen, n_heads, head_dim] -> [bsz, seqlen, embed_dim]
output = output.permute(0, 2, 1, 3).reshape(bsz, seqlen, self.n_heads * self.head_dim)
# 最终投影层
return self.wo(output)
# --- 测试 Attention ---
if __name__ == "__main__":
print("--- 测试 Attention ---")
embed_dim = 768
n_heads = 12
seqlen = 10
bsz = 2
x_test = torch.randn(bsz, seqlen, embed_dim)
# 预计算 RoPE 频率,需要最大序列长度
max_seq_len_for_rope = 2048 # 例如 LLaMA 2 的默认值
freqs_cis = precompute_freqs_cis(embed_dim // n_heads, max_seq_len_for_rope)
attention_module = Attention(embed_dim, n_heads)
output_attention = attention_module(x_test, freqs_cis)
print(f"Attention 输入形状: {x_test.shape}, 输出形状: {output_attention.shape}")
assert x_test.shape == output_attention.shape # 验证输入输出形状一致
print("-" * 30)
【代码解读】
这个Attention模块是LLaMA中多头注意力机制的完整实现,它集成了RoPE的应用、多头的切分与拼接,以及注意力分数的计算与最终输出投影。
2.5 实现LLaMA风格的FeedForward模块
结合了SwiGLU激活函数。
# llama_components.py (续)
class FeedForward(nn.Module):
def __init__(self, embed_dim: int, hidden_dim: int):
super().__init__()
# LLaMA中FFN的三个线性层,注意bias=False
self.w1 = nn.Linear(embed_dim, hidden_dim, bias=False) # 升维 + 第一个门控部分
self.w2 = nn.Linear(hidden_dim, embed_dim, bias=False) # 降维
self.w3 = nn.Linear(embed_dim, hidden_dim, bias=False) # 第二个门控部分
def forward(self, x: torch.Tensor):
# x的形状: [batch_size, seq_len, embed_dim]
# LLaMA的SwiGLU实现是 (swish(w1(x)) * w3(x)) @ w2
# w1(x) 和 w3(x) 的输出形状都是 [batch_size, seq_len, hidden_dim]
return self.w2(F.silu(self.w1(x)) * self.w3(x))
# --- 测试 FeedForward ---
if __name__ == "__main__":
print("--- 测试 FeedForward ---")
embed_dim = 768
hidden_dim = embed_dim * 4 # LLaMA通常将FFN的隐藏层维度设置为嵌入维度的4倍
ffn_module = FeedForward(embed_dim, hidden_dim)
test_tensor_input = torch.randn(2, 10, embed_dim)
output_ffn = ffn_module(test_tensor_input)
print(f"FeedForward 输入形状: {test_tensor_input.shape}, 输出形状: {output_ffn.shape}")
assert test_tensor_input.shape == output_ffn.shape # 验证输入输出形状一致
print("-" * 30)
【代码解读】
这个FFN模块完全按照LLaMA的SwiGLU实现方式,通过self.w1, self.w3和F.silu实现门控,然后通过self.w2进行最终的投影。
第三章:【终极组装】构建完整的Transformer Block与LLaMA模型
将所有“进化版”的零件,按照LLaMA的“预归一化”结构,层层拼接,最终构建成一个完整的LLaMA模型。
3.1 LLaMA Block的“预归一化”结构:稳固的基石
LLaMA Transformer Block的数据流:
输入x
x -> RMSNorm 1 (对输入进行归一化)
归一化后的 x -> Attention (with RoPE)
Attention输出 -> 与原始x进行残差连接 (h = x + attention_output)
结果h -> RMSNorm 2 (对Attention后的结果进行归一化)
归一化后的结果 -> FeedForward (with SwiGLU)
FeedForward输出 -> 与第二次归一化的输入h进行残差连接 (out = h + ffn_output)
最终输出out
3.2 将所有组件拼接成一个LLaMA风格的TransformerBlock
# llama_model.py
import torch
import torch.nn as nn
import torch.nn.functional as F
# 导入我们之前在llama_components.py中实现的模块
from llama_components import RMSNorm, Attention, FeedForward, precompute_freqs_cis # 确保llama_components.py在同一目录
class TransformerBlock(nn.Module):
def __init__(self, embed_dim: int, n_heads: int, hidden_dim: int, norm_eps: float):
super().__init__()
# 注意力模块
self.attention = Attention(embed_dim, n_heads)
# 前馈网络模块
self.feed_forward = FeedForward(embed_dim, hidden_dim)
# 两个RMSNorm层,用于预归一化
self.attention_norm = RMSNorm(embed_dim, eps=norm_eps)
self.ffn_norm = RMSNorm(embed_dim, eps=norm_eps)
def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor, mask: torch.Tensor = None):
# 1. 预归一化 + Attention + 残差连接
# attention_norm(x) 对Attention的输入进行归一化
h = x + self.attention(self.attention_norm(x), freqs_cis, mask)
# 2. 预归一化 + FeedForward + 残差连接
# ffn_norm(h) 对FeedForward的输入进行归一化 (注意这里是h,即attention后的结果)
out = h + self.feed_forward(self.ffn_norm(h))
return out
# --- 测试 TransformerBlock ---
if __name__ == "__main__":
print("--- 测试 TransformerBlock ---")
embed_dim = 768
n_heads = 12
hidden_dim = embed_dim * 4
norm_eps = 1e-6
seqlen = 10
bsz = 2
x_test = torch.randn(bsz, seqlen, embed_dim)
# RoPE频率需要根据head_dim计算,并传入Attention模块
max_seq_len_for_rope = 2048 # 例如 LLaMA 2 的默认值
freqs_cis = precompute_freqs_cis(embed_dim // n_heads, max_seq_len_for_rope)
# 模拟因果遮罩
mask = torch.full((1, 1, seqlen, seqlen), float("-inf")).triu(diagonal=1)
block = TransformerBlock(embed_dim, n_heads, hidden_dim, norm_eps)
output_block = block(x_test, freqs_cis, mask=mask)
print(f"TransformerBlock 输入形状: {x_test.shape}, 输出形状: {output_block.shape}")
assert x_test.shape == output_block.shape # 验证输入输出形状一致
print("-" * 30)
【代码解读】
这个TransformerBlock完整地实现了LLaMA的**预归一化(Pre-Norm)**结构。残差连接的写法h = x + …和out = h + …清晰地展示了输入如何被直接加到子模块的输出上,以构建“高速公路”。
3.3 堆叠Blocks,构建完整的LLaMA模型类
# llama_model.py (续)
class LLaMA(nn.Module):
def __init__(self, vocab_size: int, embed_dim: int, n_layers: int, n_heads: int, hidden_dim: int, norm_eps: float, max_seq_len: int):
super().__init__()
self.vocab_size = vocab_size
self.n_layers = n_layers
self.max_seq_len = max_seq_len
self.embed_dim = embed_dim
self.n_heads = n_heads
# 词嵌入层
self.tok_embeddings = nn.Embedding(vocab_size, embed_dim)
# 预计算旋转位置编码频率 (整个模型共享一套,根据最大序列长度计算)
self.freqs_cis = precompute_freqs_cis(self.embed_dim // self.n_heads, self.max_seq_len)
# 堆叠Transformer Blocks
self.layers = nn.ModuleList([
TransformerBlock(embed_dim, n_heads, hidden_dim, norm_eps) for _ in range(n_layers)
])
self.norm = RMSNorm(embed_dim, eps=norm_eps) # 最终的RMSNorm层
# 映射到词汇表大小的输出层 (预测下一个Token)
self.output = nn.Linear(embed_dim, vocab_size, bias=False)
def forward(self, tokens: torch.Tensor, mask: torch.Tensor = None):
_bsz, seqlen = tokens.shape # batch_size, sequence_length
# 词嵌入
h = self.tok_embeddings(tokens) # 形状: [bsz, seqlen, embed_dim]
# 获取当前序列长度对应的RoPE频率
freqs_cis = self.freqs_cis[:seqlen]
# 逐层通过Transformer Blocks
for layer in self.layers:
h = layer(h, freqs_cis, mask) # 形状始终保持 [bsz, seqlen, embed_dim]
# 最终归一化和输出
h = self.norm(h) # 形状: [bsz, seqlen, embed_dim]
output = self.output(h) # 形状: [bsz, seqlen, vocab_size] (Logits)
return output
# --- 测试 LLaMA 模型 (主程序) ---
if __name__ == "__main__":
print("--- 测试 完整 LLaMA 模型架构 ---")
# 模拟LLaMA的参数 (简化版,真实LLaMA参数更大)
vocab_size = 32000 # LLaMA 2常用的词汇表大小,这里只是一个示例值
embed_dim = 768 # 嵌入维度 (d_model)
n_layers = 6 # 堆叠6个Transformer Block (真实LLaMA-7B有32层)
n_heads = 12 # 头部数量
hidden_dim = embed_dim * 4 # FFN的隐藏层维度
norm_eps = 1e-6
max_seq_len = 256 # 最大序列长度,用于预计算RoPE
# 实例化我们的LLaMA模型
model = LLaMA(vocab_size, embed_dim, n_layers, n_heads, hidden_dim, norm_eps, max_seq_len)
# 模拟输入Token IDs (batch_size=2, seq_len=10)
bsz_test, seqlen_test = 2, 10
dummy_tokens = torch.randint(0, vocab_size, (bsz_test, seqlen_test))
# 为自回归语言模型生成因果遮罩 (Causal Mask)
# 形状: [1, 1, seqlen, seqlen]
# triu(diagonal=1) 将上三角部分设置为-inf,表示未来信息不可见
mask = torch.full((1, 1, seqlen_test, seqlen_test), float("-inf")).triu(diagonal=1)
# 将模型和输入数据移动到设备
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
dummy_tokens = dummy_tokens.to(device)
mask = mask.to(device)
# 前向传播
output_logits = model(dummy_tokens, mask=mask)
print(f"完整LLaMA模型 输入形状: {dummy_tokens.shape}")
print(f"完整LLaMA模型 输出Logits形状: {output_logits.shape}") # [bsz, seqlen, vocab_size]
assert output_logits.shape == torch.Size([bsz_test, seqlen_test, vocab_size])
print("\n✅ 完整 LLaMA 模型架构测试通过!")
【代码解读】
LLaMA类:这是整个模型的顶层容器,它封装了词嵌入层、堆叠的TransformerBlock以及最终的输出层。
freqs_cis:RoPE的频率在模型初始化时就预计算好了,这样在每次前向传播时,可以直接取出对应长度的频率。
mask:这里加入了**因果遮罩(Causal Mask)**的生成逻辑,这对于自回归语言模型至关重要,它确保模型在预测当前词时不会“偷看”到未来的词。
第四章:运行与验证:让我们的“迷你LLaMA”开口说话
4.1 实例化模型并进行文本生成
们将实例化我们自己构建的LLaMA模型,并实现一个简单的文本生成方法,验证我们的“创世”成果。
由于我们没有进行预训练,我们自己构建的模型权重是随机的。所以它不会生成有意义的文本,但只要它能生成符合语法结构的、连贯的Token ID序列,就足以证明我们整个架构的搭建是成功的!
我们将实现一个简单的**贪婪解码(Greedy Decoding)**生成函数,它每次选择概率最高的下一个Token。
# llama_inference.py (这是一个独立文件,用于演示生成)
import torch
# 导入我们之前定义的LLaMA模型和RoPE的预计算函数
from llama_model import LLaMA, precompute_freqs_cis # 确保 llama_model.py 在同一目录
from transformers import AutoTokenizer # 用于加载Hugging Face的tokenizer来编码输入和解码输出
# --- 1. 定义模型参数和加载模型 ---
# 这些参数必须和 llama_model.py 中 LLaMA 模型的 __init__ 参数保持一致
# 我们这里使用GPT-2的词汇表大小,因为它有预训练好的tokenizer,方便演示
vocab_size = 50257 # GPT-2的词汇表大小 (tokenizer.vocab_size)
embed_dim = 768
n_layers = 6
n_heads = 12
hidden_dim = embed_dim * 4
norm_eps = 1e-6
max_seq_len = 256 # 模型最大序列长度
# 实例化我们自己构建的LLaMA模型
model = LLaMA(vocab_size, embed_dim, n_layers, n_heads, hidden_dim, norm_eps, max_seq_len)
# 如果你训练过模型并保存了权重,可以在这里加载
# model.load_state_dict(torch.load("path/to/your/trained_llama_weights.pth"))
model.eval() # 设置为评估模式
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
# 使用GPT-2的tokenizer进行文本处理(因为它和我们模型的embedding_dim匹配,且易于使用)
tokenizer = AutoTokenizer.from_pretrained("gpt2")
print("模型和Tokenizer加载完成!")
# --- 2. 文本生成函数 ---
def generate_text(model, tokenizer, prompt: str, max_new_tokens: int = 50):
model.eval() # 确保模型处于评估模式
# 编码Prompt为Token IDs
input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
generated_ids = input_ids # 初始化生成序列
# 循环生成新的Token,直到达到最大生成数量或遇到结束符
for _ in range(max_new_tokens):
# 检查当前序列长度是否超出模型最大长度
current_seq_len = generated_ids.shape[1]
if current_seq_len >= model.max_seq_len:
print(f"警告:生成序列 ({current_seq_len}) 达到模型最大长度 ({model.max_seq_len}),已截断。")
break
# 计算因果遮罩 (Causal Mask)
# 这是一个上三角矩阵,对角线和左下角为0,右上角为负无穷大,确保模型无法看到未来的Token
mask = torch.full((1, 1, current_seq_len, current_seq_len), float("-inf"), device=device).triu(diagonal=1)
# 前向传播,获取logits (预测分数)
# 我们只关心最后一个Token的预测,所以输入只取最后一个Token
with torch.no_grad():
# output_logits 的形状是 [bsz, seqlen, vocab_size]
output_logits = model(generated_ids, mask=mask)
# 取得最后一个时间步的logits,形状 [bsz, vocab_size]
next_token_logits = output_logits[:, -1, :]
# 贪婪解码: 选择概率最高的Token作为下一个Token
# torch.argmax 找到最大值的索引
# .unsqueeze(0) 增加一个维度,以便于后续与 generated_ids 拼接
next_token_id = torch.argmax(next_token_logits, dim=-1).unsqueeze(0)
# 将新生成的Token拼接到已生成的序列中
generated_ids = torch.cat((generated_ids, next_token_id), dim=1)
# 解码并打印当前生成的文本 (可以根据需要调整打印频率)
# skip_special_tokens=True 确保不会打印出如 <|endoftext|> 这样的特殊token
decoded_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
print(f"生成中 ({_ + 1}/{max_new_tokens}): {decoded_text}")
# 简单停止条件: 如果生成了结束符,就停止
if next_token_id.item() == tokenizer.eos_token_id:
print("生成结束符,停止生成。")
break
final_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
return final_text
# --- 3. 运行生成 ---
if __name__ == "__main__":
prompt_text = "The quick brown fox jumps over the lazy"
print(f"--- 启动生成,Prompt: '{prompt_text}' ---")
generated_text = generate_text(model, tokenizer, prompt_text, max_new_tokens=50)
print("\n--- 最终生成结果 ---")
print(generated_text)
【代码解读与验证】
运行这段llama_inference.py脚本,你会看到模型一步步地“吐”出Token,并拼接成文本。
关键验证点:
代码能跑:最基本的成功,证明你所有组件的接口都是正确的,并且能够协同工作。
形状匹配:在打印输出时,你会看到每个Tensor的形状都符合我们的预期 [bsz, seqlen, embed_dim] 和 [bsz, seqlen, vocab_size]。
文本连贯性 (尽管随机):由于模型没有经过预训练,它的权重是随机的,所以生成的文本会是“胡言乱语”。但重要的是,这些“胡言乱语”是符合Token ID序列的,并且会呈现出一些语法上的连贯性(比如,经常出现“the”、“is”这样的高频词,甚至会形成一些“像模像样”的单词组合,尽管没有意义)。这证明了我们搭建的LLaMA架构,能够正确地处理和生成语言序列。
这个结果,就是你亲手构建的LLaMA架构的“第一次呼吸”!你已经用代码验证了LLaMA模型从输入到输出的整个前向传播和生成流程是正确的。
【一张图看懂】原始Transformer vs. LLaMA Block的架构差异
KV Cache的秘密:为什么你的聊天机器人能“秒回”?
在自回归语言模型(如GPT、LLaMA)生成文本时,模型需要一步一步地预测下一个词。每预测一个新词,都需要重新计算这个新词与所有历史词之间的注意力。如果序列很长,每次都重新计算所有历史K(Key)和V(Value)向量,效率会非常低。
KV Cache(Key-Value Cache)是一种绝妙的优化技巧。它的核心思想是:
当模型计算完第一个词的K和V向量后,会把它们**“缓存”**起来。
在预测第二个词时,只需要计算第二个词自己的Q、K、V。然后,将新计算的K和V,与缓存里的所有
历史K和V拼接起来,再进行注意力计算。
这样,每次迭代只需要对当前这一个新词进行大部分前向传播,而不需要重复计算整个历史K和V,从而让文本生成(推理)的速度提升数十倍到数百倍!
这就像你听写时,每次只写一个字,但你不需要把之前写过的所有字都重新写一遍。你只需要记住当前的笔迹,然后接着往下写。
总结与展望:你已拥有构建SOTA模型的“蓝图”
恭喜你,伟大的“AI架构师”!今天,你已经完成了对当今最成功的开源大模型——LLaMA——的一次**“像素级”复刻**。
✨ 本章惊喜概括 ✨
你掌握了什么? | 对应的LLaMA“魔鬼细节” |
---|---|
更高效的归一化 | ✅ RMSNorm的原理与实现 |
更智能的前馈网络 | ✅ SwiGLU的“门控”思想与实现 |
更优雅的位置编码 | ✅ RoPE的相对位置思想与实现 |
完整的SOTA模型架构 | ✅ 从零构建了一个LLaMA风格的语言模型 |
推理加速的秘密 | ✅ 了解了KV Cache的核心原理 |
你不再仅仅是知道Transformer是什么,你现在深刻地理解了它是如何从一个“经典设计”,一步步进化成一个主宰当今AI世界的“性能猛兽”的。你手中掌握的,是一份真正能构建SOTA模型的“架构蓝图”。 | |
🔮 敬请期待! 我们的第三篇章**《模型架构全景拆解》将继续深入。在下一章,我们将把目光转向视觉领域**,去解剖另一个同样重要、同样强大的架构——Diffusion UNet,看看AI绘画的背后,又是怎样一番精妙的设计。 |