基于 Python 的自然语言处理系列(83):InstructGPT 原理与实现

📌 论文地址:Training language models to follow instructions with human feedback
💻 参考项目:instructGOOSE

📷 模型架构图:

一、引言:为什么需要 InstructGPT?

        传统的语言模型往往依赖于“最大似然训练”,学会如何生成符合语法的文本,但却不一定符合人类的指令意图。OpenAI 提出的 InstructGPT 是一种结合 人类反馈监督 + 强化学习(RLHF) 的新训练范式,其目标是使语言模型更能“听人话”。

                InstructGPT 的三阶段训练流程如下:

  1. SFT(Supervised Fine-tuning):使用人工标注的指令-回复数据进行有监督微调。

  2. RM(Reward Modeling):让人工对模型生成的多个候选回复打分,从而训练一个奖励模型。

  3. PPO(Proximal Policy Optimization):使用 RL 算法训练语言模型,使其生成的回复最大化奖励模型的得分。

        本篇将结合 instructGOOSE 项目,对上述三阶段进行端到端复现,使用的数据集为 IMDb 影评文本,语言模型为 GPT-2

二、准备工作:环境与设备

# 安装依赖
# pip3 install instruct_gooseimport os
import torchfrom datasets import load_dataset
from torch.utils.data import DataLoader, random_split
from tqdm.auto import tqdm# 设置 GPU 设备(如使用 Colab 建议 comment 掉代理配置)
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
os.environ['http_proxy']  = 'http://192.41.170.23:3128'
os.environ['https_proxy'] = 'http://192.41.170.23:3128'device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

三、加载 IMDb 数据集并构建 DataLoader

dataset = load_dataset("imdb", split="train")# 为演示快速收敛,仅使用前 10 条数据
dataset, _ = random_split(dataset, lengths=[10, len(dataset) - 10])train_dataloader = DataLoader(dataset, batch_size=4, shuffle=True)

四、加载 GPT-2 模型与 InstructGPT 工具链

from transformers import AutoTokenizer, AutoModelForCausalLM
from instruct_goose import Agent, RewardModel, RLHFTrainer, RLHFConfig, create_reference_modelmodel_name_or_path = "gpt2"# 加载主模型与奖励模型
model_base = AutoModelForCausalLM.from_pretrained(model_name_or_path)
reward_model = RewardModel(model_name_or_path)# 加载 tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, padding_side="left")
tokenizer.pad_token = tokenizer.eos_token
eos_token_id = tokenizer.eos_token_id

五、创建 RL 模型代理与参考模型

# 构造 Agent(语言模型 + Value 网络 + 采样接口)
model = Agent(model_base)
ref_model = create_reference_model(model)

六、训练配置与 RLHFTrainer 初始化

max_new_tokens = 20
generation_kwargs = {"min_length": -1,"top_k": 0.0,"top_p": 1.0,"do_sample": True,"pad_token_id": eos_token_id,"max_new_tokens": max_new_tokens
}config = RLHFConfig()  # 可使用默认参数trainer = RLHFTrainer(model, ref_model, config)

七、基于 PPO 的 InstructGPT 强化训练

from torch import optimoptimizer = optim.Adam(model.parameters(), lr=1e-3)
num_epochs = 3for epoch in range(num_epochs):for step, batch in enumerate(tqdm(train_dataloader)):# Step 1: 编码输入inputs = tokenizer(batch["text"],padding=True,truncation=True,return_tensors="pt")inputs = {k: v.to(device) for k, v in inputs.items()}# Step 2: 使用主模型生成回复response_ids = model.generate(inputs["input_ids"], attention_mask=inputs["attention_mask"],**generation_kwargs)response_ids = response_ids[:, -max_new_tokens:]response_attention_mask = torch.ones_like(response_ids)# Step 3: 拼接 query + response,使用 Reward Model 评估得分with torch.no_grad():input_pairs = torch.stack([torch.cat([q, r], dim=0)for q, r in zip(inputs["input_ids"], response_ids)]).to(device)rewards = reward_model(input_pairs)# Step 4: 计算 PPO 损失并反向传播loss = trainer.compute_loss(query_ids=inputs["input_ids"],query_attention_mask=inputs["attention_mask"],response_ids=response_ids,response_attention_mask=response_attention_mask,rewards=rewards)optimizer.zero_grad()loss.backward()optimizer.step()print(f"[Epoch {epoch+1}] Loss = {loss.item():.4f}")

八、推理测试与结果展示

# 输入一句文本进行测试
input_text = dataset[0]['text']
input_ids = tokenizer.encode(input_text, return_tensors="pt").to(device)output = model_base.generate(input_ids, max_length=256,num_beams=5, no_repeat_ngram_size=2,top_k=50, top_p=0.95, temperature=0.7
)generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
print("🧠 模型生成结果:\n", generated_text)

九、总结

        本篇我们复现了 InstructGPT 的核心训练框架,依赖于三大模块:

  • 语言模型(GPT2);

  • 奖励模型(RewardModel);

  • 强化训练器(RLHFTrainer + PPO loss)。

        通过引入人类反馈偏好作为优化目标,InstructGPT 展现出更强的任务理解与指令遵循能力,已经成为 ChatGPT 训练体系的核心组成部分之一。

🔮 下一篇预告

        📘《基于 Python 的自然语言处理系列(84):SFT原理与实践》

如果你觉得这篇博文对你有帮助,请点赞、收藏、关注我,并且可以打赏支持我!

欢迎关注我的后续博文,我将分享更多关于人工智能、自然语言处理和计算机视觉的精彩内容。

谢谢大家的支持!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/902540.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

零基础入门 Verilog VHDL:在线仿真与 FPGA 实战全流程指南

摘要 本文面向零基础读者,全面详解 Verilog 与 VHDL 两大主流硬件描述语言(HDL)的核心概念、典型用法及开发流程。文章在浅显易懂的语言下,配合多组可在线验证的示例代码、PlantUML 电路结构图,让你在 EDA Playground 上动手体验数字电路设计与仿真,并深入了解从 HDL 编写…

Kubernetes控制平面组件:API Server详解(二)

云原生学习路线导航页(持续更新中) kubernetes学习系列快捷链接 Kubernetes架构原则和对象设计(一)Kubernetes架构原则和对象设计(二)Kubernetes架构原则和对象设计(三)Kubernetes控…

云服务器存储空间不足导致的docker image运行失败或Not enough space in /var/cache/apt/archives

最近遇到了两次空间不足导致docker实例下的mongodb运行失败的问题。 排查错误 首先用nettools看下mongodb端口有没有被占用: sudo apt install net-tools netstat --all --program | grep 27017 原因和解决方案 系统日志文件太大 一般情况下日志文件不会很大…

爬虫学习——下载文件和图片、模拟登录方式进行信息获取

一、下载文件和图片 Scrapy中有两个类用于专门下载文件和图片,FilesPipeline和ImagesPipeline,其本质就是一个专门的下载器,其使用的方式就是将文件或图片的url传给它(eg:item[“file_urls”])。使用之前需要在settings.py文件中对其进行声明…

拒绝用电“盲人摸象”,体验智能微断的无缝升级

🌟 为什么需要智能微型断路器? 传统断路器只能被动保护电路,而安科瑞智能微型断路器不仅能实时监测用电数据,还能远程控制、主动预警,堪称用电安全的“全能卫士”!无论是家庭、工厂还是商业楼宇&#xff0…

如何优雅地为 Axios 配置失败重试与最大尝试次数

在 Vue 3 中,除了使用自定义的 useRequest 钩子函数外,还可以通过 axios 的拦截器 或 axios-retry 插件实现接口请求失败后的重试逻辑。以下是两种具体方案的实现方式: 方案一:使用 axios 拦截器实现重试 实现步骤: 通…

【Leetcode刷题随笔】242.有效的字母异位词

1. 题目描述 给定两个仅包含小写字母的字符串 s 和 t ,编写一个函数来判断 t 是否是 s 的 字母异位词。 字母异位词定义:两个字符串包含的字母种类和数量完全相同,但顺序可以不同(例如 “listen” 和 “silent”)。 …

示例:spring xml+注解混合配置

以下是一个 Spring XML 注解的混合配置示例,结合了 XML 的基础设施配置(如数据源、事务管理器)和注解的便捷性(如依赖注入、事务声明)。所有业务层代码通过注解简化,但核心配置仍通过 XML 管理。 1. 项目结…

Crawl4AI:打破数据孤岛,开启大语言模型的实时智能新时代

当大语言模型遇见数据饥渴症 在人工智能的竞技场上,大语言模型(LLMs)正以惊人的速度进化,但其认知能力的跃升始终面临一个根本性挑战——如何持续获取新鲜、结构化、高相关性的数据。传统数据供给方式如同输血式营养支持&#xff…

【机器学习-周总结】-第4周

以下是本周学习内容的整理总结,从技术学习、实战应用到科研辅助技能三个方面归纳: 文章目录 📘 一、技术学习模块:TCN 基础知识与结构理解🔹 博客1:【时序预测05】– TCN(Temporal Convolutiona…

Mysql--基础知识点--79.1--双主架构如何避免回环复制

1 避免回环过程 在MySQL双主架构中,GTID(全局事务标识符)通过以下流程避免数据回环: 1 事务提交与GTID生成 在Master1节点,事务提交时生成一个全局唯一的GTID(如3E11FA47-71CA-11E1-9E33-C80AA9429562:2…

安宝特科技 | AR眼镜在安保与安防领域的创新应用及前景

随着科技的不断进步,增强现实(AR)技术逐渐在多个领域展现出其独特的优势,尤其是在安保和安防方面。AR眼镜凭借其先进的功能,在机场、车站、海关、港口、工厂、园区、消防局和警察局等行业中为安保人员提供了更为高效、…

Linux第十讲:进程间通信IPC

Linux第十讲:进程间通信IPC 1.进程间通信介绍1.1什么是进程间通信1.2为什么要进程间通信1.3怎么进行进程间通信 2.管道2.1理解管道2.2匿名管道的实现代码2.3管道的五种特性2.3.1匿名管道,只能用来进行具有血缘关系的进程进行通信(通常是父子)2.3.2管道文…

微信小程序通过mqtt控制esp32

目录 1.注册巴法云 2.设备连接mqtt 3.微信小程序 备注 本文esp32用的是MicroPython固件,MQTT服务用的是巴法云。 本文参考巴法云官方教程:https://bemfa.blog.csdn.net/article/details/115282152 1.注册巴法云 注册登陆并新建一个topic&#xff…

SQLMesh隔离系统深度实践指南:动态模式映射与跨环境计算复用

在数据安全与开发效率的双重压力下,SQLMesh通过动态模式映射、跨环境计算复用和元数据隔离机制三大核心技术,完美解决了生产与非生产环境的数据壁垒问题。本文提供从环境配置到生产部署的完整实施框架,助您构建安全、高效、可扩展的数据工程体…

Spring Data详解:简化数据访问层的开发实践

1. 什么是Spring Data? Spring Data 是Spring生态中用于简化数据访问层(DAO)开发的核心模块,其目标是提供统一的编程模型,支持关系型数据库(如MySQL)、NoSQL(如MongoDB)…

15 nginx 中默认的 proxy_buffering 导致基于 http 的流式响应存在 buffer, 以 4kb 一批次返回

前言 这也是最近碰到的一个问题 直连 流式 http 服务, 发现 流式响应正常, 0.1 秒接收到一个响应 但是 经过 nginx 代理一层之后, 就发现了 类似于缓冲的效果, 1秒接收到 10个响应 最终 调试 发现是 nginx 的 proxy_buffering 配置引起的 然后 更新 proxy_buffering 为…

源超长视频生成模型:FramePack

FramePack 是一种下一帧(下一帧部分)预测神经网络结构,可以逐步生成视频。 FramePack 将输入上下文压缩为固定长度,使得生成工作量与视频长度无关。即使在笔记本电脑的 GPU 上,FramePack 也能处理大量帧,甚…

第6次课 贪心算法 A

向日葵朝着太阳转动,时刻追求自身成长的最大可能。 贪心策略在一轮轮的简单选择中,逐步导向最佳答案。 课堂学习 引入 贪心算法(英语:greedy algorithm),是用计算机来模拟一个「贪心」的人做出决策的过程…

Windows使用SonarQube时启动脚本自动关闭

一、解决的问题 Windows使用SonarQube时启动脚本自动关闭,并发生报错: ERROR: Elasticsearch did not exit normally - check the logs at E:\Inori_Code\Year3\SE\sonarqube-25.2.0.102705\sonarqube-25.2.0.102705\logs\sonarqube.log ERROR: Elastic…