模型微调DPO入门

一、定义

  1. 定义
  2. 数据集格式
  3. llamafactory 训练案例
  4. 入门文档阅读

二、实现

  1. 定义
    DPO通过直接优化语言模型来实现对其行为的精确控制,而无需使用复杂的强化学习,也可以有效学习到人类偏好,DPO相较于RLHF更容易实现且易于训练,效果更好.
    DPO 是一种自动微调方法,它通过最大化预训练模型在特定任务上的奖励来优化模型参数。与传统的微调方法相比,DPO 绕过了建模奖励函数这一步,而是通过直接在偏好数据上优化模型来提高性能。相对RLHF两阶段而言具有多项优越性:
    (1)简单性:DPO更容易实施和培训,使其更易于使用。
    (2)稳定性:不易陷入局部最优,保证训练过程更加可靠。
    (3)效率:与RLHF 相比, DPO 需要更少的计算资源和数据,使其计算量轻。
    (4)有效性:实验结果表明,DPO在情感控制、摘要和对话生成等任务中可以优于 RLHF 。

  2. 数据集格式

[{"conversations": [{"from": "human","value": "国会的转发\n美国国会由众议院和参议院组成,每两年换届一次(参议员任期为6年,但参议院选举是错位的,使得国会的组成仍然每两年变化一次)。这两年期间按顺序标记,第115届国会发生在2017-2018年。\n\n密歇根大学信息学院的研究人员在这段时间内收集了现任国会议员(我们将“国会议员”缩写为MoC)的Twitter帖子,并对它们进行编码,标记为原创声明或其他用户提交的转发。我们将重点关注转发数据。这些发布的数据不仅包括转发的文本,还包括国会议员的信息和原始推文的帐户。\n#python:\n\nimport pandas as pd\nimport numpy as np\nimport matplotlib.pyplot as plt\nimport seaborn as sb\nimport statsmodels.api as sm\nimport os\nfrom pandas.plotting import register\\_matplotlib\\_converters\nregister\\_matplotlib\\_converters()\n\npd.options.mode.chained\\_assignment = None\n\n# 在接下来的内容中,我们将加载数据,但出于速度原因,我们将排除推文本身的文本。\n\ntweets = pd.read\\_csv(os.path.join(base, \"retweets\\_congress\\_115th.csv.gz\"), \n usecols = lambda x: x != \"full\\_text\",\n dtype = {\"tweet\\_id\\_str\": \"str\"},\n index\\_col = \"tweet\\_id\\_str\")\ntweets.shape\nout: (209856, 21)\n\ntweets.columns\nout:Index(['user\\_id\\_str', 'user\\_screen\\_name', 'user\\_followers\\_count', 'created\\_at',\n 'retweet\\_count', 'retweeted\\_status', 'year', 'bioguide', 'birthday',\n 'gender', 'name', 'state', 'chamber', 'party', 'region', 'rt\\_user\\_id',\n 'rt\\_screen\\_name', 'rt\\_followers\\_count', 'rt\\_bio', 'age', 'age\\_bucket'],\n dtype='object')\n\n第(a)部分\n国会主要由两个政党(民主党和共和党)组成,还有一些独立派。哪个政党负责大部分的转发?\n\n计算推文的政党的边际概率,并将结果保存到party\\_marg。显示这个结果。\n\n第(b)部分\n我们可能会问的一个问题是,“国会成员彼此转发的频率是多少?”\n\n创建一个名为“is\\_congress\\_retweet”的新列,指示转发是否来自MoC。首先使用.unique()方法获取唯一的“user\\_id\\_str”值集合。然后使用这组id,在“rt\\_user\\_id”列上使用.isin(values)方法创建新的布尔列“is\\_congress\\_retweet”。\n\n找到转发MoC的边际概率,并将结果保存到“congress\\_retweet\\_marg”。显示这些结果。是更常见地转发国会议员还是来自国会外部?"}],"chosen": {"from": "gpt","value": "计算推文的政党边际概率,我们可以使用以下代码:\n```python\ntweets['party'] = tweets['party'].replace(0, 'Democrat')\ntweets['party'] = tweets['party'].replace(1, 'Republican')\ntweets['party'] = tweets['party'].replace(2, 'Independent')\n\nparty_counts = tweets.groupby('party').size().sort_values(ascending=False)\n\nparty_marg = party_counts / party_counts.sum()\n\nprint(\"政党的边际概率:\")\nprint(party_marg)\n```\n这段代码将'party'列中的数字值替换为相应的政党名称,然后按政党对推文进行分组并计算每个组的大小。最后,通过将每个组的大小除以总推文数来计算每个政党的边际概率。\n\n输出应该类似于:\n```markdown\n政党的边际概率:\n0.522152    民主党\n0.445541    共和党\n0.032307    独立派\n```\n这表明大多数转发来自民主党,其次是共和党,然后是独立派。\n\n要创建一个名为\"is\\_congress\\_retweet\"的新列,指示转发是否来自国会议员,我们可以使用以下代码:\n```python\ncongress_ids = tweets['user_id_str'].unique()\nis_congress_retweet = tweets['rt_user_id'].isin(congress_ids)\ntweets['is_congress_retweet'] = is_congress_retweet\n\ncongress_retweet_marg = tweets['is_congress_retweet'].mean()\nprint(\"转发国会议员的边际概率:\")\nprint(congress_retweet_marg)\n```\n这段代码首先获取唯一的\"user\\_id\\_str\"值集合,对应于国会议员的集合。然后,在\"rt\\_user\\_id\"列上使用.isin(values)方法创建一个新的布尔列\"is\\_congress\\_retweet\",指示转发是否来自国会议员。最后,通过取\"is\\_congress\\_retweet\"列的平均值来计算转发国会议员的边际概率。\n\n输出应该类似于:\n```markdown\n转发国会议员的边际概率:\n0.775151\n```\n这表明大多数转发不是来自国会议员,而是来自国会之外。"},"rejected": {"from": "gpt","value": "回答问题的第(a)部分,您可以通过首先使用groupby()函数按政党对数据进行分组,然后计算每个组的转发数来计算政党的边际概率。具体来说,您可以使用以下代码:\n```scss\nparty_count = tweets.groupby('party').count()\nparty_count.loc['Democrats','retweet_count'] / party_count.loc['Republicans','retweet_count']\n```\n回答问题的第(b)部分,您可以创建一个新列来指示转发是否来自国会议员,方法是使用“bioguide”列上的.unique()方法获取唯一的国会议员用户ID集合,然后使用该集合创建一个布尔列,指示转发是否来自国会议员,使用“rt_user_id”列上的.isin()方法。最后,使用.sum()方法来计算转发国会议员的边际概率,然后将其与转发来自国会以外的人的概率进行比较。以下是相关代码:\n```makefile\nimport pandas as pd\ntweets = pd.read_csv('retweets_congress_115th.csv')\n\n# 获取唯一的国会议员ID集合\nmovers_ids = tweets['bioguide'].unique()\n\n# 创建一个新的布尔列,指示转发是否来自国会议员\ntweets['is_congress_retweet'] = tweets['rt_user_id'].isin(movers_ids)\n\n# 计算转发国会议员或来自国会以外的人的边际概率\nmovers_retweets = tweets[tweets['is_congress_retweet'] == True].shape[0]\nnon_movers_retweets = tweets[tweets['is_congress_retweet'] == False].shape[0]\nproportion_movers_retweets = movers_retweets / (movers_retweets + non_movers_retweets)\nprint('Proportion of retweets from MoCs:', proportion_movers_retweets)\n```"}}]

其中chosen 代表是好的回答, rejected代表的是不好的回答

  1. llamafactory 训练案例
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/train_lora/llama3_lora_dpo.yaml 
  1. 入门文档阅读
    https://huggingface.co/docs/trl/dpo_trainer
    代码:https://github.com/huggingface/trl/blob/main/examples/scripts/dpo.py
    https://github.com/huggingface/trl/tree/main
    数据格式:
def process(row):row["prompt"] = tokenizer.apply_chat_template(row["chosen"][:-1], tokenize=False)row["chosen"] = tokenizer.apply_chat_template([row["chosen"][-1]], tokenize=False)row["rejected"] = tokenizer.apply_chat_template([row["rejected"][-1]], tokenize=False)return row

获取正向的预测、反向的预测—>reward_accuracies = (chosen_rewards > rejected_rewards).float() —>纠正loss---->反向传播

def dpo_loss(self,policy_chosen_logps: torch.FloatTensor,policy_rejected_logps: torch.FloatTensor,reference_chosen_logps: torch.FloatTensor,reference_rejected_logps: torch.FloatTensor,reference_free: bool = True,
) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:"""Compute the DPO loss for a batch of policy and reference model log probabilities.Args:policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (batch_size,)policy_rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (batch_size,)reference_chosen_logps: Log probabilities of the reference model for the chosen responses. Shape: (batch_size,)reference_rejected_logps: Log probabilities of the reference model for the rejected responses. Shape: (batch_size,)beta: Temperature parameter for the DPO loss, typically something in the range of 0.1 to 0.5. We ignore the reference model as beta -> 0.reference_free: If True, we ignore the _provided_ reference model and implicitly use a reference model that assigns equal probability to all responses.Returns:A tuple of three tensors: (losses, chosen_rewards, rejected_rewards).The losses tensor contains the DPO loss for each example in the batch.The chosen_rewards and rejected_rewards tensors contain the rewards for the chosen and rejected responses, respectively."""pi_logratios = policy_chosen_logps - policy_rejected_logpsref_logratios = reference_chosen_logps - reference_rejected_logpsif reference_free:ref_logratios = 0logits = pi_logratios - ref_logratioslosses = -F.logsigmoid(self.beta * logits)chosen_rewards = self.beta * (policy_chosen_logps - reference_chosen_logps).detach()rejected_rewards = self.beta * (policy_rejected_logps - reference_rejected_logps).detach()return losses, chosen_rewards, rejected_rewards
def get_batch_loss_metrics(self,model,batch: Dict[str, Union[List, torch.LongTensor]],train_eval: Literal["train", "eval"] = "train",
):"""Compute the DPO loss and other metrics for the given batch of inputs for train or test."""metrics = {}(policy_chosen_logps,policy_rejected_logps,policy_chosen_logits,policy_rejected_logits,policy_chosen_logps_avg,) = self.concatenated_forward(model, batch)# if reference_chosen_logps and reference_rejected_logps in batch use them, otherwise use the reference modelif ("reference_chosen_logps" in batchand "reference_rejected_logps" in batchand self.args.rpo_alpha is not None):reference_chosen_logps = batch["reference_chosen_logps"]reference_rejected_logps = batch["reference_rejected_logps"]else:with torch.no_grad():if self.ref_model is None:with self.null_ref_context():(reference_chosen_logps,reference_rejected_logps,_,_,_,) = self.concatenated_forward(self.model, batch)else:(reference_chosen_logps,reference_rejected_logps,_,_,_,) = self.concatenated_forward(self.ref_model, batch)losses, chosen_rewards, rejected_rewards = self.dpo_loss(policy_chosen_logps,policy_rejected_logps,reference_chosen_logps,reference_rejected_logps,)reward_accuracies = (chosen_rewards > rejected_rewards).float()if self.args.rpo_alpha is not None:losses = losses * self.args.rpo_alpha - policy_chosen_logps_avgprefix = "eval_" if train_eval == "eval" else ""metrics[f"{prefix}rewards/chosen"] = chosen_rewards.mean().cpu()metrics[f"{prefix}rewards/rejected"] = rejected_rewards.mean().cpu()metrics[f"{prefix}rewards/accuracies"] = reward_accuracies.mean().cpu()metrics[f"{prefix}rewards/margins"] = (chosen_rewards - rejected_rewards).mean().cpu()metrics[f"{prefix}logps/rejected"] = policy_rejected_logps.detach().mean().cpu()metrics[f"{prefix}logps/chosen"] = policy_chosen_logps.detach().mean().cpu()metrics[f"{prefix}logits/rejected"] = policy_rejected_logits.detach().mean().cpu()metrics[f"{prefix}logits/chosen"] = policy_chosen_logits.detach().mean().cpu()return losses.mean(), metricsdef compute_loss(self,model: Union[PreTrainedModel, nn.Module],inputs: Dict[str, Union[torch.Tensor, Any]],return_outputs=False,
) -> Union[torch.Tensor, Tuple[torch.Tensor, Dict[str, torch.Tensor]]]:if not self.use_dpo_data_collator:warnings.warn("compute_loss is only implemented for DPODataCollatorWithPadding, and you passed a datacollator that is different than ""DPODataCollatorWithPadding - you might see unexpected behavior. Alternatively, you can implement your own prediction_step method if you are using a custom data collator")compute_loss_context_manager = torch.cuda.amp.autocast if self._peft_has_been_casted_to_bf16 else nullcontextwith compute_loss_context_manager():loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="train")# Make sure to move the loss to the device the original accumulating loss is at back in the `Trainer` class:loss = loss.to(self.args.device)# force log the metricsself.store_metrics(metrics, train_eval="train")if return_outputs:return (loss, metrics)return loss

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/35561.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

面试题-Redis简介

1.主流应用框架 概念: 穿透查询:数据库中的名词,与逐层查询不同,通过超链接可直接查询想要的结果,更加方便快捷 熔断机制:指软件系统中,由于某些原因使得服务出现了过载现象,为防止…

「2024中国数据要素产业图谱1.0版」重磅发布,景联文科技凭借高质量数据采集服务入选!

近日,景联文科技入选数据猿和上海大数据联盟发布的《2024中国数据要素产业图谱1.0版》数据采集服务板块。 景联文科技是专业数据服务公司,提供从数据采集、清洗、标注的全流程数据解决方案,协助人工智能企业解决整个AI链条中数据采集和数据标…

【面试题】SpringBoot面试题

目录 Spring Boot 的核心注解是哪个?它主要由哪几个注解组成的?如何理解 Spring Boot 中的 Starters?Spring Boot 的核心配置文件有哪几个?它们的区别是什么?Spring Boot、Spring MVC 和 Spring 有什么区别&#xff1f…

Maven高级的多环境配置与应用

多环境配置与应用 这一节中,我们会讲两个内容,分别是多环境开发和跳过测试 5.1 多环境开发 我们平常都是在自己的开发环境进行开发,当开发完成后,需要把开发的功能部署到测试环境供测试人员进行测试使用,等测试人员测…

Redis报错:MISCONF Redis is configured to save RDB snapshots

错误提示内容: 2024-06-25 16:30:49 : Connection: Redis_Server > [runCommand] PING 2024-06-25 16:30:49 : Connection: Redis_Server > Response received : -MISCONF Redis is configured to save RDB snapshots, but it is currently not able to pers…

Qt Quick Effect Maker 工具使用介绍

一、介绍 随着 Qt 版本的不断升级,越来越多的功能被加入 Qt,一些新的工具也随之应运而生,Qt Quick Effect Maker 工具是 Qt 6.5 之后才新添加的工具,之前的名字应该是叫做 Qt shader tool 这个模块。 以下是官方的释义:Qt Quick Effect Maker是一个用于为Qt Quick创建自定…

C语⾔数据类型和变量

C语⾔数据类型和变量 1.数据类型介绍1.1 字符型1.2 整型1.3 浮点型1.4 布尔类型1.5 各种数据类型的长度1.5.1 sizeof操作符1.5.2 数据类型长度1.5.3 sizeof中表达式不计算 2. signed 和 unsigned3. 数据类型的取值范围4. 变量4.1 变量的创建4.2 变量的分类 5. 算术操作符&#…

Vue2+TS el-table简单封装 和 使用

1.封装的组件写法 <template><div style"height: calc( 100% - 33px);width:100%;position:relative"><!-- 权限管理标题显示与否 --><div ref"operationBtnbox" class"operation-Btn-box" v-if"showOperationBtn&qu…

React Hooks 小记(七)_useReducer

useReducer usereducer 相当于 复杂的 useState 当状态更新逻辑较复杂时可以考虑使用 useReducer。useReducer 可以同时更新多个状态&#xff0c;而且能把对状态的修改从组件中独立出来。 相比于 useState&#xff0c;useReducer 可以更好的描述“如何更新状态”。例如&#…

Zookeeper 集群的应用场景

Zookeeper 集群的应用场景 Zookeeper 是一个分布式协调服务,主要用于管理分布式应用中的配置、同步和命名等任务。由于其高可用性、 一致性和可靠性,Zookeeper 被广泛应用于各种分布式系统中。以下是 Zookeeper 集群的一些典型应用场景: 1. 配置管理 Zookeeper 可以用来集…

社区团购小程序开发

在快节奏的现代生活中&#xff0c;人们越来越追求便利与效率。社区团购小程序应运而生&#xff0c;以其独特的优势成为连接社区居民与优质商品的重要桥梁。本文将探讨社区团购小程序的特点、优势以及未来发展趋势&#xff0c;为大家揭示这一新型购物模式的魅力。 社区团购小程序…

LLM与GPT的一些概念

LLM 大模型语言模型(Large Language Model,LLM)技术是近年来人工智能领域的重要突破,凭借其出色的语义理解和生成能力,正在广泛应用于各种自然语言处理场景。 基本原理 LLM 是基于深度学习的语言模型,通过学习大规模文本数据,获得对自然语言的深入理解。这种模型能够准确地预…

MAC 查看公钥私钥

电脑配置过公钥私钥&#xff0c;现在需要查看&#xff1a; 1、 查看本地是否存在SSH密钥 命令&#xff1a;ls -al ~/.ssh 如果在输出的文件列表中发现id_rsa和id_rsa.pub的存在&#xff0c;证明本地已经存在SSH密钥&#xff0c;请执行第3步 2、 生成SSH密钥 命令&#xff1…

一本好的电子画册应这样做,你做对了吗?

​一本好的电子画册&#xff0c;不仅要有吸引人的图文&#xff0c;还可能包括视频、音频等多媒体元素&#xff0c;为读者提供全方位的阅读体验。连贯性是指画册的整体设计风格、内容布局要协调一致&#xff0c;让读者在阅读过程中感受到流畅和自然。创新性则要求创作者在内容呈…

39 - 电影评分(高频 SQL 50 题基础版)

39 - 电影评分 (selectu.name as results fromMovieRating m left join Users u on m.user_idu.user_id GROUP BYm.user_id order by count(*) desc,u.name asc limit 1) union all (selectm1.title as results fromMovieRating m left join Movies m1 on m.movie_idm1.movie…

加速业务布局,30年老将加盟ATFX,掌舵运营新篇章

全球领先的差价合约经纪商ATFX日前宣布了一项重大人事任命&#xff0c;聘请业界资深人士约翰博格(John Bogue)为机构业务运营总监。约翰博格是一名行业老将&#xff0c;曾在差价合约界深耕三十余载。伴随其加入ATFX&#xff0c;相信他的深厚专业知识和从业经验将为ATFX机构业务…

Java序列化进阶:Java内置序列化的三种方式

Java序列化就是把Java对象按照一定的格式存到文件或者磁盘当中 序列化的进阶&#xff1a;即三种方式&#xff0c;任何一种方式都可以进行序列化和反序列化 如果将数据读写到文档&#xff0c; 一般通过 ObjectOutputStream 将数据写入到文件当中&#xff0c;就是一种序列化的…

数据分析python基础实战分析

数据分析python基础实战分析 安装python&#xff0c;建议安装Anaconda 【Anaconda下载链接】https://repo.anaconda.com/archive/ 记得勾选上这个框框 安装完后&#xff0c;然后把这两个框框给取消掉再点完成 在电脑搜索框输入"Jupyter"&#xff0c;牛马启动&am…

简单聊聊云硬盘的规格

云硬盘类型及对应性能介绍 衡量云硬盘性能的指标有很多种&#xff0c;例如IOPS&#xff0c;吞吐量&#xff0c;读写时延&#xff1a; IOPS&#xff1a;云硬盘每秒进行读写的操作次数&#xff0c;可以细分到单盘最大IOPS&#xff0c;基线IOPS&#xff0c;IOPS突发上限等等。吞…

司美格鲁肽在中国获批!深度解析报告附上

在中国&#xff0c;肥胖问题日益严重&#xff0c;但有效的治疗方法却相对匮乏。然而&#xff0c;这一现状随着国家药品监督管理局&#xff08;NMPA&#xff09;对诺和诺德公司研发的司美格鲁肽注射液&#xff08;商品名&#xff1a;诺和盈&#xff09;的批准而得到改变。6月25日…