LLM:强化学习DPO实现


前言

Deepseek R1把强化学习又提升一个高度,最近也在尝试学习强化学习,在此记录一下DPO的学习过程与实现过程

paper:https://arxiv.org/abs/2305.18290


一、DPO理论

DPO(Direct Preference Optimization)的基本思想很简单,是一种直接优化用户或专家偏好的方法。通过收集用户对模型输出的偏好数据,直接优化模型的参数,使得模型输出更符合用户的偏好

DPO通常涉及构建一个偏好模型,该模型能够预测用户在给定选择中的偏好。然后,使用这些偏好数据来优化生成模型的参数。这种方法不依赖于强化学习,而是直接利用偏好数据进行优化。

其数据设置通常为输入question,chosen输出(偏好输出),rejected输出,DPO不需要复杂的reward函数,只需要设置好数据就可以,因此其方法简单高效,比较实用小任务的微调,对于相对复杂的任务表现会较差。

关于公式推导可以参考DPO loss详解
DPO本质上是Bradley-Terry模型 和 KL散度的一个结合(RL其实就是在最大化奖励,最小化KL散度, 这里的Bradley-Terry模型其实就是一种奖励模型)

不想推导公式的知道DPO最后优化的是一个loss就好

请添加图片描述

该loss表示要增加chosen和rejected的预测差距,同时保持参考模型(原始模型)和新模型(强化学习训练中的模型)的差距尽量小。

二、实践

1.Huggingface – trl

trl集成了目前所有主流的强化学习方法,使用起来还是非常方便的,但缺点也很明显,基本只能使用hugging face所支持的主流大模型才能正常使用,对于魔改或者自定义LLM,使用起来可能问题较多
在这里插入图片描述

from datasets import load_dataset
import torch
from trl import DPOTrainer, DPOConfig

from datasets import load_dataset

train_dataset = load_dataset('json', data_files='...json')

dpo_config = DPOConfig(
    beta=0.1,  # 控制训练稳定性,数值越大越稳定但收敛慢
    learning_rate=5e-6,
    gradient_accumulation_steps=4,
    max_length=512,  # 训练时的最大 token 长度
    num_train_epochs=3,
    output_dir="./dpo_model",
)

trainer = DPOTrainer(
    model=model,
    ref_model=ref_model,   
    args=dpo_config,
    train_dataset=train_dataset,
    tokenizer=model.decoder.tokenizer,
)

trainer.train()

2.pytorch 实现 DPO (LLM)

前面提到,DPO本质上最终优化的是一个loss函数,因此对于DPO强化学习可以当成一个普通的训练任务来完成

DPO loss函数如下


def dpo_loss(model: nn.Module, reference_model: nn.Module, inputs: dict, beta: float ) -> torch.Tensor:
    """
    Computes the Direct Preference Optimization loss for a batch of inputs

    Args:
        model: The current policy model.
        reference_model: The reference policy model.
        inputs: A batch of inputs from the DataLoader, containing:
            - 'instruction_ids': Tensor of input token IDs.
            - 'chosen_ids': Tensor of chosen output token IDs.
            - 'rejected_ids': Tensor of rejected output token IDs.
        beta: The temperature controlling strength of KL penalty.
        device: The device to perform computations on.

    Returns:
        loss: The computed DPO loss.
    """
    # Extract input tensors and move them to the specified device
    instruction_ids = inputs['instruction_ids'].cuda() #to(device)
    chosen_ids = inputs['chosen_ids'].cuda()   #to(device)
    rejected_ids = inputs['rejected_ids'].cuda()    #to(device)

    # Concatenate instruction_ids with chosen_ids and rejected_ids to get the full sequences for model input
    chosen_input = torch.cat([instruction_ids, chosen_ids], dim=1)
    rejected_input = torch.cat([instruction_ids, rejected_ids], dim=1)

    # Get model outputs for both chosen and rejected inputs
    chosen_outputs = model(chosen_input)
    rejected_outputs = model(rejected_input)

    # Get reference model outputs for both chosen and rejected inputs
    with torch.no_grad():
        ref_chosen_outputs = reference_model(chosen_input)
        ref_rejected_outputs = reference_model(rejected_input)

    # Extract logits from the model outputs and reference model outputs
    chosen_logits = chosen_outputs.logits
    rejected_logits = rejected_outputs.logits
    ref_chosen_logits = ref_chosen_outputs.logits
    ref_rejected_logits = ref_rejected_outputs.logits

    # Compute log probabilities using log_softmax
    chosen_log_probs = F.log_softmax(chosen_logits, dim=-1)
    rejected_log_probs = F.log_softmax(rejected_logits, dim=-1)
    ref_chosen_log_probs = F.log_softmax(ref_chosen_logits, dim=-1)
    ref_rejected_log_probs = F.log_softmax(ref_rejected_logits, dim=-1)

    # Shift the sequences to get the targets
    chosen_targets = chosen_input[:, 1:]
    rejected_targets = rejected_input[:, 1:]

    # Adjust log_probs to align with the targets
    chosen_log_probs = chosen_log_probs[:, :-1, :]
    rejected_log_probs = rejected_log_probs[:, :-1, :]
    ref_chosen_log_probs = ref_chosen_log_probs[:, :-1, :]
    ref_rejected_log_probs = ref_rejected_log_probs[:, :-1, :]

    # Gather the log probabilities corresponding to the target tokens
    chosen_token_log_probs = chosen_log_probs.gather(2, chosen_targets.unsqueeze(-1)).squeeze(-1)
    rejected_token_log_probs = rejected_log_probs.gather(2, rejected_targets.unsqueeze(-1)).squeeze(-1)

    ref_chosen_token_log_probs = ref_chosen_log_probs.gather(2, chosen_targets.unsqueeze(-1)).squeeze(-1)
    ref_rejected_token_log_probs = ref_rejected_log_probs.gather(2, rejected_targets.unsqueeze(-1)).squeeze(-1)

    # Get the length of the prompt sequence
    input_length = instruction_ids.size(1)

    # Create masks to exclude the prompt tokens from loss computation, we only care about the response tokens
    chosen_mask = torch.ones_like(chosen_token_log_probs)
    rejected_mask = torch.ones_like(rejected_token_log_probs)

    # Zero out positions corresponding to the prompt tokens
    chosen_mask[:, :input_length - 1] = 0      #### 过滤掉prompt部分
    rejected_mask[:, :input_length - 1] = 0

    # Apply masks to the token log probabilities
    chosen_token_log_probs = chosen_token_log_probs * chosen_mask
    rejected_token_log_probs = rejected_token_log_probs * rejected_mask
    ref_chosen_token_log_probs = ref_chosen_token_log_probs * chosen_mask
    ref_rejected_token_log_probs = ref_rejected_token_log_probs * rejected_mask

    # Sum the log probabilities over the response tokens
    chosen_log_prob = chosen_token_log_probs.sum(dim=1)
    rejected_log_prob = rejected_token_log_probs.sum(dim=1)
    ref_chosen_log_prob = ref_chosen_token_log_probs.sum(dim=1)
    ref_rejected_log_prob = ref_rejected_token_log_probs.sum(dim=1)

    # Compute the difference in log probabilities adjusted by the reference model
    log_prob_diff = (chosen_log_prob - rejected_log_prob) - (ref_chosen_log_prob - ref_rejected_log_prob)

    # Compute the DPO loss using the logistic loss function
    loss = -F.logsigmoid(beta * log_prob_diff).mean()

    return loss


可以发现dpo loss函数计算了 model 在输入chosen id和rejected id时输出的token概率,同时计算了ref_model 在输入chosen id和rejected id时输出的token概率,并基于token id找到具体token的概率值计算loss,ref_model的目的是为了不让模型偏离原始模型,可以理解为KL散度。

具体的: chosen_log_probs 的shape是 (batch_size, sequence_length, vocab_size),表示每个 token 在 vocab_size 个类别上的 log 概率分布。

gather(2, chosen_targets.unsqueeze(-1)) 在 dim=2(vocab_size 维度)上,根据 chosen_targets 指定的索引提取相应的 log 概率值。

chosen_targets 存储的是每个 token 的正确索引,所以这个操作会选取 chosen_log_probs 在 vocab_size 维度上对应 chosen_targets 位置的概率。

可以发现dpo的最终目的就是rejected对应的token prob尽可能小,同时不要偏离原模型太多,从而实现达到期望输出。

3.pytorch 实现 DPO (vLLM)

对于视觉语言模型其实和LLM差不多,只要理解了DPO的基本理念,改起来就很简单。

对于vLLM来说,decoder的输入需要image+prompt+label的拼接,因此核心改动在于model的infer传参,和mask的设置

下面是我自己的vLLM模型的修改版本

    # encoder_outputs = model.encoder(image_tensor)
    chosen_logits = model(image_tensor, chosen_ids, chosen_attention_mask, instruction_ids).logits  
    rejected_logits = model(image_tensor, rejected_ids, rejected_attention_mask, instruction_ids).logits
    # print(logits.shape)
    
    with torch.no_grad():
        ref_chosen_logits = reference_model(image_tensor, chosen_ids, chosen_attention_mask, instruction_ids).logits 
        ref_rejected_logits = reference_model(image_tensor, rejected_ids, rejected_attention_mask, instruction_ids).logits 


    # Compute log probabilities using log_softmax
    chosen_log_probs = F.log_softmax(chosen_logits, dim=-1)
    rejected_log_probs = F.log_softmax(rejected_logits, dim=-1)
    ref_chosen_log_probs = F.log_softmax(ref_chosen_logits, dim=-1)
    ref_rejected_log_probs = F.log_softmax(ref_rejected_logits, dim=-1)

    
    pad_id = torch.full((batch_size, pre_token_length), fill_value=model.decoder.tokenizer.pad_token_id).cuda()
    pad_mask = torch.full((batch_size, pre_token_length), fill_value=0).cuda()
    
    chosen_targets = torch.cat([pad_id, chosen_ids], dim=1)
    rejected_targets = torch.cat([pad_id, rejected_ids], dim=1)
        

    # Gather the log probabilities corresponding to the target tokens
    chosen_token_log_probs = chosen_log_probs.gather(2, chosen_targets.unsqueeze(-1)).squeeze(-1)
    rejected_token_log_probs = rejected_log_probs.gather(2, rejected_targets.unsqueeze(-1)).squeeze(-1)

    ref_chosen_token_log_probs = ref_chosen_log_probs.gather(2, chosen_targets.unsqueeze(-1)).squeeze(-1)
    ref_rejected_token_log_probs = ref_rejected_log_probs.gather(2, rejected_targets.unsqueeze(-1)).squeeze(-1)

    chosen_mask = torch.cat([pad_mask, chosen_attention_mask], dim=1)
    rejected_mask = torch.cat([pad_mask, rejected_attention_mask], dim=1)
    

其实这个部分的改动主要在于明确模型的输出,保持mask和输出的对应(mask遮住image+prompt部分即可)


总结

总的来说DPO是一种简单高效的强化学习方法,DPO 直接优化策略,使其输出结果更符合人类或模型的偏好,不再依赖于显式的奖励建模或复杂的策略梯度估计,训练更简单稳定。其主要思想是通过最大化策略对“更优偏好结果”与“较差结果”之间的概率比,从而优化策略。
请添加图片描述

### DPO 的核心原理 DPO(Direct Preference Optimization)是一种用于训练大型语言模型(LLM)的直接偏好优化方法。与传统的强化学习方法(如 PPO)不同,DPO 不需要显式地构建奖励模型,而是直接利用人类反馈的“选优对比”数据来学习偏好。这种方法简化了训练流程,提高了训练的稳定性和效率。 DPO 的核心思想是通过对比两个不同策略生成的响应来优化模型。具体来说,DPO 使用人类偏好数据来训练模型,使得模型能够更好地理解和生成符合用户偏好的文本。这种数据通常是由人类标注者提供的,他们会对两个不同的响应进行评分,选择哪一个更好。 ### DPO实现实现上,DPO 使用一个参考模型(`model_ref`)和一个可训练模型(`model`)。这两个模型都是因果语言模型(`AutoModelForCausalLM`),而不是像 PPO 中那样需要一个带有值函数头的模型(`AutoModelForCausalLMWithValueHead`)。DPO 训练器通过比较两个模型生成的响应来优化策略。 以下是一个简单的 DPO 训练器的初始化示例: ```python dpo_trainer = DPOTrainer( model, model_ref, args=training_args, beta=0.1, train_dataset=train_dataset, tokenizer=tokenizer, ) ``` 在这个示例中,`model` 是可训练的模型,`model_ref` 是参考模型,`training_args` 包含了训练参数,`beta` 是一个超参数,用于控制偏好数据的影响程度,`train_dataset` 是训练数据集,`tokenizer` 是用于处理文本的分词器。 ### DPO 的优势 1. **简化训练流程**:DPO 无需显式构建奖励模型,而是直接利用人类偏好数据进行训练,简化了训练流程。 2. **提高训练效率**:由于不需要训练奖励模型,DPO 的训练过程更加高效。 3. **稳定性**:DPO 在训练过程中表现出更高的稳定性,减少了调试的复杂性。 4. **应用场景广泛**:DPO 可以应用于大型语言模型的对齐、基于人类偏好的文本生成、对话系统优化以及指令跟随能力的提升等场景 [^3]。 ### DPO 的挑战 尽管 DPO 具有诸多优势,但在实际应用中也面临一些挑战。例如,如何有效地收集和处理人类偏好数据,以及如何选择合适的超参数(如 `beta`)来控制偏好数据的影响程度。此外,DPO 的效果在很大程度上依赖于高质量的偏好数据,因此数据的质量和数量是影响 DPO 性能的关键因素。 ###
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值