DPO: Direct Preference Optimization 介绍

DPO 是 RLHF 的屌丝版本,RLHF 需要加载 4 个模型(2个推理,2个训练),DPO 只需要加载 2 个模型(1个推理,一个训练)。

RLHF:

DPO:

 

DPO 原理

DPO 的本质是监督对比学习:通过对每条prompt提供两条不同的answer,并给出这两个answer的偏好偏序,让模型输出更接近good answer,同时更远离 bad answer。

这个过程中并不强制要求上述两者同时满足,只要接近good answer的程度大于bad answer就是有效的训练,比如与good answer远离了,但是与bad answer远离的更多也是有效的。

DPO loss

 

σ :sigmoid函数

β :超参数,一般在0.1 - 0.5之间

y_w :某条偏好数据中好的response,w就是win的意思

y_l :某条偏好数据中差的response,l就是loss的意思,所以偏好数据也叫comparision data

\pi_\theta(y_w|x) :给定输入x, 当前policy model生成好的response的累积概率(每个tokne的概率求和,具体看代码)

\pi_{ref}(y_l|x) :给定输入x, 原始模型(reference model)生成坏的response的累积概率

开始训练时,reference model和policy model都是同一个模型,只不过在训练过程中reference model不会更新权重。

简化形式:忽略 logsigmoid 并取对数

由于最初loss前面是有个负号的,所以优化目标是让本简化公式最大,即希望左半部分和右半部分的margin越大越好,左半部分的含义是good response相较于没训练之前的累积概率差值,右半部分代表bad response相较于没训练之前的累计概率差值,如果这个差值,即margin变大了。

 DPO 数据集

可以由prompt 模板: Human: prompt. Assistant: chosen/rejected 构成如下数据:Anthropic/hh-rlhf dataset

 DPO trainer 期望数据集具有非常特定的格式。 给定两个句子时,模型将被训练为直接优化偏好:那一个句子最相关。

Huagging Face DPO Trainer

与 PPO 期望 AutoModelForCausalLMWithValueHead 作为值函数相比,DPO 训练器期望 AutoModelForCausalLM 模型。 

 dpo_trainer = DPOTrainer(model,model_ref,args=training_args,beta=0.1,train_dataset=train_dataset,tokenizer=tokenizer,
)

Loss 选择:

  • RSO 作者建议在 SLiC 论文中的归一化似然上使用 hinge损失。 DPOTrainer 可以通过 loss_type="hinge" 参数切换到此损失,这种情况下的 beta 是margin的倒数。
  • IPO 作者对 DPO 算法提供了更深入的理论理解,并识别了过度拟合的问题,并提出了一种替代损失,可以通过训练器的 loss_type="ipo" 参数来使用。
  • cDPO 是对 DPO 损失的调整,其中我们假设偏好标签有一定的噪声,可以通过 label_smoothing 参数(0 到 0.5 之间)传递到 DPOTrainer,然后使用保守的 DPO 损失。 使用 loss_type="cdpo" 参数给训练器来使用它。
  • KTO 损失的导出是为了直接最大化 LLM 代的效用,而不是偏好的对数似然。 因此,数据集不一定是偏好,而是期望的完成与不期望的完成。 对于 DPOTrainer 所需的配对偏好数据,请使用训练器的 loss_type="kto_pair" 参数来利用此损失,而对于所需和不需要的数据的更一般情况,请使用尚未实现的 KTOTrainer。

简单实例

#!/usr/bin/env python
# -*- encoding: utf-8 -*-import torch
import torch.nn.functional as F
from transformers import LlamaForCausalLM, LlamaConfig
from copy import deepcopytorch.manual_seed(0)
if __name__ == "__main__":# 超参数beta = 0.1# 加载模型policy_model = LlamaForCausalLM(config=LlamaConfig(vocab_size=1000, num_hidden_layers=1, hidden_size=128))reference_model = deepcopy(policy_model)# dataprompt_ids = [1, 2, 3, 4, 5, 6]good_response_ids = [7, 8, 9, 10]# 对loss稍加修改可以应对一个good和多个bad的情况bad_response_ids_list = [[1, 2, 3, 0], [4, 5, 6, 0]]# 转换成模型输入 [3, 10]input_ids = torch.LongTensor([prompt_ids + good_response_ids, *[prompt_ids + bad_response_ids for bad_response_ids in bad_response_ids_list]])# labels 提前做个shift [3, 9]labels = torch.LongTensor([[-100] * len(prompt_ids) + good_response_ids,*[[-100] * len(prompt_ids) + bad_response_ids for bad_response_ids in bad_response_ids_list]])[:, 1:]loss_mask = (labels != -100)labels[labels == -100] = 0# 计算 policy model的log prob# policy_model(input_ids)["logits"] [3, 10, 1000] 句末的推理结果无效直接忽略logits = policy_model(input_ids)["logits"][:, :-1, :]per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)all_logps = (per_token_logps * loss_mask).sum(-1)# 暂时写死第一个是good response的概率, 三个例子中第一个是 good answer, 后两个是 bad answerpolicy_good_logps, policy_bad_logps = all_logps[:1], all_logps[1:]# 计算 reference model的log probwith torch.no_grad():logits = reference_model(input_ids)["logits"][:, :-1, :]per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)all_logps = (per_token_logps * loss_mask).sum(-1)# 暂时写死第一个是good response的概率reference_good_logps, reference_bad_logps = all_logps[:1], all_logps[1:]# 计算loss,会自动进行广播logits = (policy_good_logps - reference_good_logps) - (policy_bad_logps - reference_bad_logps)loss = -F.logsigmoid(beta * logits).mean()print(loss)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/54044.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

YoloV10 训练自己的数据集(推理,转化,C#部署)

目录 一、下载 三、开始训练 train.py detect.py export.py 超参数都在这个路径下 四、C#读取yolov10模型进行部署推理 如下程序是用来配置openvino 配置好引用后就可以生成dll了 再创建一个控件,作为显示 net framework 4.8版本的 再nuget工具箱里下载 …

价值流与核心理论框架对比解析:企业业务架构优化的全景指南

企业架构优化中的理论框架选择 随着数字化转型和全球竞争的加剧,企业管理者越来越意识到优化业务流程以提升竞争力的重要性。然而,在众多优化方法中,企业如何选择最适合自己的理论框架成为一大挑战。由The Open Group发布的《价值流指南》系…

密码学基础--ECDSA算法入门

目录 1.ECDSA签名长度的疑惑 2.ECDSA原理 2.1 生成签名 2.2 验签过程 2.3 签名编码问题 3.小结 1.ECDSA签名长度的疑惑 我们来看看ECDSA签名长什么样子,使用MuscleV02自动生成密钥对,并对message"0x11223344”进行签名,结果如下&a…

Java的衍生生态有哪些?恐怖如斯的JAVA

Java的衍生生态极其丰富,涵盖了多个层面和领域。以下是Java衍生生态的一些主要方面: 1. 开源工具 开发工具:如Eclipse,这是一款非常优秀的Java IDE工具,支持Java以及其他语言的代码编写。Spring官方还基于Eclipse开发…

Golang开发之路

✨✨ 欢迎大家来到景天科技苑✨✨ 🎈🎈 养成好习惯,先赞后看哦~🎈🎈 🏆 作者简介:景天科技苑 🏆《头衔》:大厂架构师,华为云开发者社区专家博主,…

混合整数规划及其MATLAB实现

目录 引言 混合整数规划的基本模型 混合整数规划的求解方法 MATLAB中的混合整数规划实现 示例:多变量系统的混合整数规划 表格总结:混合整数规划的求解方法与适用场景 结论 引言 混合整数规划(Mixed Integer Programming, MIP&#xf…

多线程学习篇二:Thread常见方法

1. 常见方法 方法名 static 功能说明 注意点 start() 启动一个新线程,在新线程里面运行run方法 start 方法只是让线程进入就绪,里面代码不一定立刻运行(CPU 的时间片还没分给它)。每个线程对象的 start 方法只能调用一次,如果调用了多…

【Hadoop|MapReduce篇】MapReduce概述

1. MapReduce定义 MapReduce是一个分布式运算程序的编程框架,是用户开发“基于Hadoop的数据分析应用”的核心框架。 MapReduce核心功能是将用户编写的业务逻辑代码和自带默认组件整合成一个完整的分布式运算程序,并发运行在一个Hadoop集群上。 2. Map…

【绿盟科技盟管家-注册/登录安全分析报告】

前言 由于网站注册入口容易被黑客攻击,存在如下安全问题: 暴力破解密码,造成用户信息泄露短信盗刷的安全问题,影响业务及导致用户投诉带来经济损失,尤其是后付费客户,风险巨大,造成亏损无底洞…

linux 最简单配置免密登录

需求:两台服务器互信登录需要拉起对端服务 ip: 192.168.1.133 192.168.1.137 一、配置主机hosts,IP及主机名,两台都需要 二、192.168.1.137服务器,生成密钥 ssh-keygen -t rsa三、追加到文件 ~/.ssh/authorized_key…

2024年第二届《英语世界》杯全国大学生英语听力大赛

下周开考! 一、主办单位 商务印书馆《英语世界》杂志社 二、时间安排 赛事报名时间:即日起-2024年11月15日 正式比赛阶段:第一场:2024年9月22日10:00-22:00 第二场:2024年10月27日10:00-22:00 第三场&#xff1…

QT::QComboBox自定义左击事件信号

因为QComboBox没有自定义的clink信号&#xff0c;所以自己新建一个MyComBox类继承QComboBox&#xff0c;并且添加自定义的左击信号&#xff0c;以及使用该信号连接一个槽函数 mycombobox.h #ifndef MYCOMBOBOX_H #define MYCOMBOBOX_H#include <QComboBox> #include &l…

Baumer工业相机堡盟工业相机如何通过BGAPI SDK设置相机的图像剪切(ROI)功能(C语言)

Baumer工业相机堡盟工业相机如何通过BGAPI SDK设置相机的图像剪切&#xff08;ROI&#xff09;功能&#xff08;C语言&#xff09; Baumer工业相机Baumer工业相机的图像剪切&#xff08;ROI&#xff09;功能的技术背景CameraExplorer如何使用图像剪切&#xff08;ROI&#xff0…

复旦:EoT下Muti-agentllm曾带给我的启发

结合最近的一些经历&#xff0c;回忆起很早之前探索Agent时阅读过的一篇自来复旦/NUS/上海AI Lab的泛CoT框架思想论文&#xff0c;文中提出了一种名为“思想交换”&#xff08;Exchange-of-Thought, EoT&#xff09;的新框架&#xff0c;该框架允许在问题解决过程中进行跨模型交…

android 老项目中用到的jar包不存在,通过离线的方法加载

1、之前的项目用的jar包&#xff0c;已经不在远程仓库中&#xff0c;只能手工去下载&#xff0c;并且安装。 // implementation com.github.nostra13:Android-Universal-Image-Loader // implementation com.github.lecho:hellocharts-android:v1.5.8 这…

信息安全工程师(1)计算机网络分类

一、按分布范围分类 广域网&#xff08;WAN&#xff09;&#xff1a; 定义&#xff1a;广域网的任务是提供长距离通信&#xff0c;运送主机所发送的数据。其覆盖范围通常是直径为几十千米到几千千米的区域&#xff0c;因此也被称为远程网。特点&#xff1a;连接广域网的各个结点…

智能语音技术在人机交互中的应用与发展

摘要&#xff1a;本文主要探讨智能自动语音识别技术与语音合成技术在构建智能口语系统方面的作用。这两项技术实现了人机语音通信&#xff0c;建立起能听能说的智能口语系统。同时&#xff0c;引入开源 AI 智能名片小程序&#xff0c;分析其在智能语音技术应用场景下的意义与发…

实现CPU压力测试工具的C语言实现

实现CPU压力测试工具的C语言实现 一、背景与需求二、伪代码设计三、C语言实现四、编译和运行五、注意事项在软件开发和系统维护中,CPU压力测试是一项重要任务,用于评估系统的稳定性和性能。本篇文章将详细介绍如何使用C语言结合伪代码实现一个简单的CPU压力测试工具。 一、…

软媒市场新趋势:自助发布与一手资源渠道商自助发稿的崛起

在当今这个信息爆炸的时代,软媒市场作为品牌传播的重要阵地,正经历着前所未有的变革。随着技术的不断进步和消费者行为的日益多样化,传统的营销方式已难以满足企业的需求。在这样的背景下,自助发布与一手资源渠道商自助发稿的模式应运而生,为企业的品牌宣传开辟了新的道路。 自…

旺店通ERP集成用友BIP(旺店通主供应链)

源系统成集云目标系统 用友BIP介绍 用友BIP是以数智底座以及财务、人力、供应链、营销、采购、制造、研发、项目、资产、协同等数智化服务成就的数智平台&#xff0c;同时也预置了很多跨行业通用的SaaS服务&#xff0c;在营销、采购、制造、财务、人力、协同等核心业务领域提供…