1. MindSpore 总体架构
MindSpore是面向端-边-云全场景的深度学习框架。其典型栈分层如下:
- 前端(Frontend):以
nn.Cell
为核心的网络构建范式、自动微分、数据管道(Dataset)等。 - 中间表示(MindIR):图级中间表示,承载算子融合、常量折叠、内存分配与形状推断等编译优化。
- 执行模式:
- Graph Mode:将 Python 描述的网络转为静态图,进行深度图优化后下发后端(适合昇腾的高性能训练/推理)。
- PyNative Mode:动态图风格,便于调试与原型验证。
- 自动并行:内置数据并行、模型并行、流水线并行与切分策略搜索,适配大模型训练。
- 后端(Backend):对接昇腾(Ascend)、GPU、CPU 等,调用相应算子库执行。
对大模型而言,MindSpore 的 图编译+ 自动并行能显著降低显存开销并提升执行效率,是在昇腾平台进行参数高效微调(如 Prompt Tuning)的关键支撑。
2. Prompt Tuning 基本概念
Prompt Tuning是一种 参数高效微调(PEFT)思路:冻结预训练大模型的主体参数,仅训练规模很小的“提示参数(prompt parameters)”,即可将通用模型适配到新任务。常见变体:
- Prompt Tuning / Soft Prompt:为输入序列前置一段可学习的连续向量(非离散文本),仅训练这些向量。
- Prefix Tuning:为每层注意力模块注入一段可学习的 K/V 前缀(不直接改变输入 token),影响注意力分布。
- P-Tuning v2:用一个小型提示编码器(如 MLP 或轻量 Transformer)生成上下文相关的提示向量,提升表达能力。
- (对比)LoRA:对线性层施加低秩适配矩阵,也属于 PEFT,但不等同于 Prompt Tuning。
优势:显著减少可训练参数与显存占用,训练更快、迁移更广,适合算力/数据有限的场景。
3. 在 MindSpore 架构中的实现思路
3.1 Soft Prompt Tuning(输入侧)
核心思想:在 embedding 之后、Transformer 编码之前,拼接一段长度为 P
的可学习提示向量(形状 [P, H]
,H
为隐藏维),对每个 batch 复制成 [B, P, H]
与原始 token embedding [B, L, H]
在序列维拼接为 [B, P+L, H]
。仅更新提示参数,冻结大模型权重。
实现要点(MindSpore):
- 使用
Parameter(Tensor(...))
定义soft_prompt
,requires_grad=True
。 - 冻结 backbone:遍历
net.trainable_params()
,将非提示参数的requires_grad=False
。 - 注意力 mask与 位置编码(或 RoPE)需相应右移
P
位;padding/masking 必须覆盖提示段与真实 token。
3.2 Prefix Tuning(注意力侧)
核心思想:在每一层 Multi-Head Attention 中,构造长度为 P
的 K/V 前缀(形如 [B, num_heads, P, head_dim]
),与由真实 token 产生的 K/V
在序列维拼接;Q
仍来自真实 token。这样在不改变输入序列的情况下调控注意力。
实现要点:
- 为每层准备一个小型前缀生成器(可为直接可学习的参数,或经 MLP 投影)。
- 在前向传播中,于各层注意力计算前注入前缀 K/V;相应扩展
attn_mask
。 - 仍然只训练前缀相关参数。
3.3 P-Tuning v2(提示编码器)
核心思想:使用轻量网络(例如 MLP/小 Transformer)从任务条件、标签描述或输入摘要生成提示向量,解决固定软提示表达力不足的问题。
实现要点:
- 定义
PromptEncoder(nn.Cell)
,输入可为任务标签、离散 prompt id、任务描述 embedding 等,输出[B, P, H]
。 - 保持主干模型冻结,仅训练提示编码器参数。
- 与 4.1/4.2 的注入位置类似:要么拼接到输入 embedding,要么投递到各层注意力的 K/V。
4. 昇腾 + MindSpore 下的性能与工程要点
- Graph Mode 优先:静态图优化(图算融合、常量折叠、内存复用)对 NPU 至关重要。
- 混合精度:启用 FP16/BF16 训练提示参数,配合损失缩放稳定数值。
- 固定形状/对齐长度:对序列长度进行 padding 到固定值,利于编译优化与算子并行。
- 通信与并行:Prompt Tuning 通常参数小,数据并行(DP)即可;若大 batch,可结合梯度累积。
- 编译缓存:复用图编译缓存能缩短反复实验的冷启动时间。
- 算子选择:尽量复用原模型 embedding/attention 的高效算子;前缀注入避免 Python 侧循环。