Transformer 中 Decoder 端的具体输入 在结构和流程上稍复杂,因为它既要生成目标序列,又要利用 Encoder 编码的上下文语义。下面我将从整体流程、输入组成、作用及代码角度全面讲解 Decoder 的输入。
一、Decoder 的输入来源概览
Transformer 解码器的输入可以分为两个部分:
一、来自目标序列的输入(Target Input)
-
通常是目标序列(如英文翻译句子)的 “左移一位” 版本,作为 Decoder 的输入。
-
在训练阶段:会将目标句子中除了最后一个 token 以外的部分作为 Decoder 输入。
-
在推理阶段:逐步生成,每次输入当前已生成的 token。
举例:
目标句子:["I", "love", "you", "</s>"]
Decoder 输入:["<s>", "I", "love", "you"]
Decoder 输出:["I", "love", "you", "</s>"]
二、来自编码器的上下文表示(Encoder Output)
-
Encoder 会对源语言句子(如中文)生成上下文语义表示,传递给 Decoder。
-
Decoder 的中间模块(cross attention)会利用这些表示来生成输出。
二、Decoder 每层的结构和输入细节
Decoder 中的每一层包含三个主要子层,输入的处理如下:
输入 = token_embedding + positional_encoding
↓
1. Masked Multi-Head Self-Attention(只能看到过去)
↓
2. Encoder-Decoder Attention(关注原句的语义)
↓
3. Feed-Forward 层
↓
输出
1. Masked Multi-Head Self-Attention
-
输入:Decoder 端的 token 序列
-
加掩码(mask)防止当前 token “偷看”未来 token
-
实现“自回归”生成(auto-regressive generation)
2. Encoder-Decoder Attention
-
输入 query:来自 Decoder 的中间结果
-
输入 key, value:来自 Encoder 的输出表示
-
作用:将当前生成与源语言语义对齐
3. Feed-Forward Network
-
对每个位置的输出进行非线性变换
三、训练 vs 推理中的输入差异
阶段 | Decoder 的输入 | 特点 |
---|---|---|
训练阶段 | 全部目标序列左移一位(教师强制) | 可并行训练 |
推理阶段 | 初始输入为 <s> 或 <bos> ,逐个生成并反馈 | 只能一步一步生成(自回归) |
四、PyTorch 示例代码(Decoder 输入)
# 假设目标序列为:["I", "love", "you", "</s>"]
# 对应 ID: [1, 23, 45, 2]
# Decoder 输入为:["<s>", "I", "love", "you"] => [0, 1, 23, 45]
import torch
import torch.nn as nn
vocab_size = 5000
d_model = 512
max_len = 100
# Token Embedding
embedding = nn.Embedding(vocab_size, d_model)
# Positional Encoding (simplified)
position_ids = torch.arange(0, max_len).unsqueeze(0)
pos_embedding = nn.Embedding(max_len, d_model)
# 输入序列
decoder_input_ids = torch.tensor([[0, 1, 23, 45]]) # batch_size=1
# 构造 Decoder 的输入表示
token_embed = embedding(decoder_input_ids)
position_embed = pos_embedding(position_ids[:, :token_embed.size(1)])
decoder_input = token_embed + position_embed
五、总结:Decoder 的输入组成表
模块 | 输入来源 | 说明 |
---|---|---|
Token Embedding | 左移后的目标序列 | 用于构建输入 token 表示 |
Positional Encoding | 位置索引 | 编码 token 顺序 |
Masked Self-Attention | Decoder 自己生成的 token 序列 | 加掩码,防止看未来 |
Cross Attention | Encoder 的输出 | 引入源语言语义信息 |
Feed-Forward Network | 上一步注意力机制的输出 | 非线性转换,提升表达能力 |