前言
Deepseek R1把强化学习又提升一个高度,最近也在尝试学习强化学习,在此记录一下DPO的学习过程与实现过程
paper:https://arxiv.org/abs/2305.18290
一、DPO理论
DPO(Direct Preference Optimization)的基本思想很简单,是一种直接优化用户或专家偏好的方法。通过收集用户对模型输出的偏好数据,直接优化模型的参数,使得模型输出更符合用户的偏好。
DPO通常涉及构建一个偏好模型,该模型能够预测用户在给定选择中的偏好。然后,使用这些偏好数据来优化生成模型的参数。这种方法不依赖于强化学习,而是直接利用偏好数据进行优化。
其数据设置通常为输入question,chosen输出(偏好输出),rejected输出,DPO不需要复杂的reward函数,只需要设置好数据就可以,因此其方法简单高效,比较实用小任务的微调,对于相对复杂的任务表现会较差。
关于公式推导可以参考 “DPO loss详解”
DPO本质上是Bradley-Terry模型 和 KL散度的一个结合(RL其实就是在最大化奖励,最小化KL散度, 这里的Bradley-Terry模型其实就是一种奖励模型)
不想推导公式的知道DPO最后优化的是一个loss就好:
该loss表示要增加chosen和rejected的预测差距,同时保持参考模型(原始模型)和新模型(强化学习训练中的模型)的差距尽量小。
二、实践
1.Huggingface – trl
trl集成了目前所有主流的强化学习方法,使用起来还是非常方便的,但缺点也很明显,基本只能使用hugging face所支持的主流大模型才能正常使用,对于魔改或者自定义LLM,使用起来可能问题较多
from datasets import load_dataset
import torch
from trl import DPOTrainer, DPOConfig
from datasets import load_dataset
train_dataset = load_dataset('json', data_files='...json')
dpo_config = DPOConfig(
beta=0.1, # 控制训练稳定性,数值越大越稳定但收敛慢
learning_rate=5e-6,
gradient_accumulation_steps=4,
max_length=512, # 训练时的最大 token 长度
num_train_epochs=3,
output_dir="./dpo_model",
)
trainer = DPOTrainer(
model=model,
ref_model=ref_model,
args=dpo_config,
train_dataset=train_dataset,
tokenizer=model.decoder.tokenizer,
)
trainer.train()
2.pytorch 实现 DPO (LLM)
前面提到,DPO本质上最终优化的是一个loss函数,因此对于DPO强化学习可以当成一个普通的训练任务来完成
DPO loss函数如下
def dpo_loss(model: nn.Module, reference_model: nn.Module, inputs: dict, beta: float ) -> torch.Tensor:
"""
Computes the Direct Preference Optimization loss for a batch of inputs
Args:
model: The current policy model.
reference_model: The reference policy model.
inputs: A batch of inputs from the DataLoader, containing:
- 'instruction_ids': Tensor of input token IDs.
- 'chosen_ids': Tensor of chosen output token IDs.
- 'rejected_ids': Tensor of rejected output token IDs.
beta: The temperature controlling strength of KL penalty.
device: The device to perform computations on.
Returns:
loss: The computed DPO loss.
"""
# Extract input tensors and move them to the specified device
instruction_ids = inputs['instruction_ids'].cuda() #to(device)
chosen_ids = inputs['chosen_ids'].cuda() #to(device)
rejected_ids = inputs['rejected_ids'].cuda() #to(device)
# Concatenate instruction_ids with chosen_ids and rejected_ids to get the full sequences for model input
chosen_input = torch.cat([instruction_ids, chosen_ids], dim=1)
rejected_input = torch.cat([instruction_ids, rejected_ids], dim=1)
# Get model outputs for both chosen and rejected inputs
chosen_outputs = model(chosen_input)
rejected_outputs = model(rejected_input)
# Get reference model outputs for both chosen and rejected inputs
with torch.no_grad():
ref_chosen_outputs = reference_model(chosen_input)
ref_rejected_outputs = reference_model(rejected_input)
# Extract logits from the model outputs and reference model outputs
chosen_logits = chosen_outputs.logits
rejected_logits = rejected_outputs.logits
ref_chosen_logits = ref_chosen_outputs.logits
ref_rejected_logits = ref_rejected_outputs.logits
# Compute log probabilities using log_softmax
chosen_log_probs = F.log_softmax(chosen_logits, dim=-1)
rejected_log_probs = F.log_softmax(rejected_logits, dim=-1)
ref_chosen_log_probs = F.log_softmax(ref_chosen_logits, dim=-1)
ref_rejected_log_probs = F.log_softmax(ref_rejected_logits, dim=-1)
# Shift the sequences to get the targets
chosen_targets = chosen_input[:, 1:]
rejected_targets = rejected_input[:, 1:]
# Adjust log_probs to align with the targets
chosen_log_probs = chosen_log_probs[:, :-1, :]
rejected_log_probs = rejected_log_probs[:, :-1, :]
ref_chosen_log_probs = ref_chosen_log_probs[:, :-1, :]
ref_rejected_log_probs = ref_rejected_log_probs[:, :-1, :]
# Gather the log probabilities corresponding to the target tokens
chosen_token_log_probs = chosen_log_probs.gather(2, chosen_targets.unsqueeze(-1)).squeeze(-1)
rejected_token_log_probs = rejected_log_probs.gather(2, rejected_targets.unsqueeze(-1)).squeeze(-1)
ref_chosen_token_log_probs = ref_chosen_log_probs.gather(2, chosen_targets.unsqueeze(-1)).squeeze(-1)
ref_rejected_token_log_probs = ref_rejected_log_probs.gather(2, rejected_targets.unsqueeze(-1)).squeeze(-1)
# Get the length of the prompt sequence
input_length = instruction_ids.size(1)
# Create masks to exclude the prompt tokens from loss computation, we only care about the response tokens
chosen_mask = torch.ones_like(chosen_token_log_probs)
rejected_mask = torch.ones_like(rejected_token_log_probs)
# Zero out positions corresponding to the prompt tokens
chosen_mask[:, :input_length - 1] = 0 #### 过滤掉prompt部分
rejected_mask[:, :input_length - 1] = 0
# Apply masks to the token log probabilities
chosen_token_log_probs = chosen_token_log_probs * chosen_mask
rejected_token_log_probs = rejected_token_log_probs * rejected_mask
ref_chosen_token_log_probs = ref_chosen_token_log_probs * chosen_mask
ref_rejected_token_log_probs = ref_rejected_token_log_probs * rejected_mask
# Sum the log probabilities over the response tokens
chosen_log_prob = chosen_token_log_probs.sum(dim=1)
rejected_log_prob = rejected_token_log_probs.sum(dim=1)
ref_chosen_log_prob = ref_chosen_token_log_probs.sum(dim=1)
ref_rejected_log_prob = ref_rejected_token_log_probs.sum(dim=1)
# Compute the difference in log probabilities adjusted by the reference model
log_prob_diff = (chosen_log_prob - rejected_log_prob) - (ref_chosen_log_prob - ref_rejected_log_prob)
# Compute the DPO loss using the logistic loss function
loss = -F.logsigmoid(beta * log_prob_diff).mean()
return loss
可以发现dpo loss函数计算了 model 在输入chosen id和rejected id时输出的token概率,同时计算了ref_model 在输入chosen id和rejected id时输出的token概率,并基于token id找到具体token的概率值计算loss,ref_model的目的是为了不让模型偏离原始模型,可以理解为KL散度。
具体的: chosen_log_probs 的shape是 (batch_size, sequence_length, vocab_size),表示每个 token 在 vocab_size 个类别上的 log 概率分布。
gather(2, chosen_targets.unsqueeze(-1)) 在 dim=2(vocab_size 维度)上,根据 chosen_targets 指定的索引提取相应的 log 概率值。
chosen_targets 存储的是每个 token 的正确索引,所以这个操作会选取 chosen_log_probs 在 vocab_size 维度上对应 chosen_targets 位置的概率。
可以发现dpo的最终目的就是rejected对应的token prob尽可能小,同时不要偏离原模型太多,从而实现达到期望输出。
3.pytorch 实现 DPO (vLLM)
对于视觉语言模型其实和LLM差不多,只要理解了DPO的基本理念,改起来就很简单。
对于vLLM来说,decoder的输入需要image+prompt+label的拼接,因此核心改动在于model的infer传参,和mask的设置
下面是我自己的vLLM模型的修改版本
# encoder_outputs = model.encoder(image_tensor)
chosen_logits = model(image_tensor, chosen_ids, chosen_attention_mask, instruction_ids).logits
rejected_logits = model(image_tensor, rejected_ids, rejected_attention_mask, instruction_ids).logits
# print(logits.shape)
with torch.no_grad():
ref_chosen_logits = reference_model(image_tensor, chosen_ids, chosen_attention_mask, instruction_ids).logits
ref_rejected_logits = reference_model(image_tensor, rejected_ids, rejected_attention_mask, instruction_ids).logits
# Compute log probabilities using log_softmax
chosen_log_probs = F.log_softmax(chosen_logits, dim=-1)
rejected_log_probs = F.log_softmax(rejected_logits, dim=-1)
ref_chosen_log_probs = F.log_softmax(ref_chosen_logits, dim=-1)
ref_rejected_log_probs = F.log_softmax(ref_rejected_logits, dim=-1)
pad_id = torch.full((batch_size, pre_token_length), fill_value=model.decoder.tokenizer.pad_token_id).cuda()
pad_mask = torch.full((batch_size, pre_token_length), fill_value=0).cuda()
chosen_targets = torch.cat([pad_id, chosen_ids], dim=1)
rejected_targets = torch.cat([pad_id, rejected_ids], dim=1)
# Gather the log probabilities corresponding to the target tokens
chosen_token_log_probs = chosen_log_probs.gather(2, chosen_targets.unsqueeze(-1)).squeeze(-1)
rejected_token_log_probs = rejected_log_probs.gather(2, rejected_targets.unsqueeze(-1)).squeeze(-1)
ref_chosen_token_log_probs = ref_chosen_log_probs.gather(2, chosen_targets.unsqueeze(-1)).squeeze(-1)
ref_rejected_token_log_probs = ref_rejected_log_probs.gather(2, rejected_targets.unsqueeze(-1)).squeeze(-1)
chosen_mask = torch.cat([pad_mask, chosen_attention_mask], dim=1)
rejected_mask = torch.cat([pad_mask, rejected_attention_mask], dim=1)
其实这个部分的改动主要在于明确模型的输出,保持mask和输出的对应(mask遮住image+prompt部分即可)
总结
总的来说DPO是一种简单高效的强化学习方法,DPO 直接优化策略,使其输出结果更符合人类或模型的偏好,不再依赖于显式的奖励建模或复杂的策略梯度估计,训练更简单稳定。其主要思想是通过最大化策略对“更优偏好结果”与“较差结果”之间的概率比,从而优化策略。