如何从零开始训练一个语言模型

如何从零开始训练一个语言模型

RLHF
SFT
Pretrain
SFT Data
Pretrain Data
SSL
SFT
Reward Model
Human
Align Dataset
LLM
SFT Dataset
Base Model
Language Model
SSL Dataset
GPT4指令数据
BELLE指令数据
X X指令数据集
维基百科
百度百科
X X百科
任何开源文本
Pretrain Process
SFT Process

  本文主要三个方面介绍语言模型的训练过程,主要包括:数据集介绍(包含预训练数据和微调数据),数据的预处理,模型训练和微调,但不涉及对齐阶段(RLHF),对齐需要对齐的数据,也需要不同的预处理方式,对齐的目的是构建一个可以与人类价值观保持一致的LLM,减少虚假有害信息的输出。

数据集

Pretrain Data:

预训练数据主要来自从互联网上收集的文本数据,token的规模大概在trillion级别,整体质量偏低。

SFT Data:

SFT(Supervised Fine-Tuning)数据一般由指令,输入,响应组成,指令和输入一起组成prompt,作为模型的输入,响应作为标签。这类数据对质量要求较高,一般由人工构造,也可由GPT4生成。

预处理

分词Tokenizer:把文本序列转为为token序列。

Pretrain Process:

预训练是通过自监督(SSL)的方式训练,也就是预测下个词(token),数据处理方式如下:

def __getitem__(self, index: int):sample = self.data[index]X=np.array(sample[:-1]).astype(np.int64)Y=np.array(sample[1:]).astype(np.int64)return torch.from_numpy(X),torch.from_numpy(Y)

例如:文本分词后:sample = [1, 2, 3, 4, 5, 6]

  • x : 1, 2, 3, 4, 5
  • y : 2, 3, 4, 5, 6
SFT Process:

SFT(Supervised Fine-Tuning)阶段喂给模型的示例遵循(prompt、response)的格式,prompt包含:指令+输入,也称为指令数据,数据处理方式如下:

  • 拼接指令和输入
# 拼接指令和输入字符
q_lst, a_lst = [],[]
for per in data:q=per['instruction']i=per['input']a=per['output']q=q+iq_lst.append(q)a_lst.append(a)
df=pd.DataFrame(columns=['prompt','answer'])
df['prompt']=q_lst
df['answer']=a_lst
  • 拼接提示和响应,并添加分割符,同时生成掩码,掩码的作用是在计算loss时屏蔽prompt部分。
def __getitem__(self, index: int):sample = self.df.iloc[index]# 分词tokenizerprompt = self.tokenizer.encode(sample['prompt'],add_special_tokens=False)answer = self.tokenizer.encode(sample['answer'],add_special_tokens=False)# 截断最大长度if len(prompt) > self.prompt_max_len:prompt = prompt[:self.prompt_max_len-2]if len(answer) > self.answer_max_len:answer = answer[:self.answer_max_len-2]# 拼接提示和响应,同时添加特殊token,标识提示和响应结束inputs = prompt+[self.bos]+answer+[self.eos]# 掩码长度=提示长度prompt_length = inputs.index(self.bos)mask_position = prompt_length - 1# 填充至最大长度pad_len = self.max_length - len(inputs)inputs = inputs + [self.pad] * pad_lenif pad_len==0:# 屏蔽提示和填充位置loss_mask = [0]*prompt_length+[1]*(len(inputs[mask_position+1:]))else:loss_mask = [0]*prompt_length+[1]*(len(inputs[mask_position+1:-pad_len])) + [0]*pad_leninputs = np.array(inputs)X=np.array(inputs[:-1]).astype(np.int64)Y=np.array(inputs[1:]).astype(np.int64)loss_mask=np.array(loss_mask[:-1])return torch.from_numpy(X),torch.from_numpy(Y),torch.from_numpy(loss_mask)

例如:bos : 8, eos : 16, pad : 0,max_length = 16

inputs = prompt + [bos] + answer + [eos] = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16],

  • pad_len = 0:

  • prompt = [1, 2, 3, 4, 5, 6, 7]

  • answer = [9, 10, 11, 12, 13, 14, 15]

  • inputs = prompt + [bos] + answer + [eos] = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]

    • x = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]
    • y = [2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]
    • mask = [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1]
  • pad_len > 0:

  • prompt = [1, 2, 3, 4, 5, 6, 7]

  • answer = [9, 10, 11, 12, 13]

  • inputs = prompt + [bos] + answer + [eos] + [pad]*2 = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 16, 0, 0]

    • x = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 16, 0, 0]
    • y = [2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 16, 0, 0]
    • mask = [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0]

预训练阶段

预训练阶段采用标准的语言模型建模来最大化目标函数:

L p r e t r a i n ( X ) = ∑ i l o g P ( x i ∣ x i − k , . . . , x i − 1 ; Θ ) L_{pretrain}(\mathcal{X}) = \sum_i logP(x_i|x_{i-k},...,x_{i-1};\mathcal{\Theta}) Lpretrain(X)=ilogP(xixik,...,xi1;Θ)

  • x = x 1 , . . . , x n \mathcal{x} = {x_1, ..., x_n} x=x1,...,xn :语料

  • k k k : 上下文长度

  • P P P : 条件概率由参数为 Θ \Theta Θ的神经网络模型建模

神经网络模型(包含多个transformer模块),模型输入经过分词后(tokenzier)后的token序列,首先经过嵌入层,然后经过transformer_block,最后经过输出层输出token概率分布。

h 0 = X W e + W p h_0 = XW_e + W_p h0=XWe+Wp

h l = t r a n s f o r m e r b l o c k ( h l − 1 ) , ∀ i ∈ [ 1 , n ] h_l = transformer_{block}(h_{l-1}), \forall i \in [1,n] hl=transformerblock(hl1),i[1,n]

P ( u ) = s o f t m a x ( h n W e T ) P(u) = softmax(h_nW_e^T) P(u)=softmax(hnWeT)

  • W e W_e We : 嵌入矩阵
  • W p W_p Wp : 位置嵌入矩阵

微调阶段

微调阶段的数据前面已经提过,由3部分组成: X = { X i n s t r u c t i o n , X i n p u t , X a n s w e r } \mathcal{X} = \{X_{instruction} , X_{input},X_{answer}\} X={Xinstruction,Xinput,Xanswer}

经过预处理后: X = X i n s t r u c t i o n + X i n p u t + b o s + X a n s w e r + e o s \mathcal{X} = X_{instruction}+X_{input}+bos+X_{answer}+eos X=Xinstruction+Xinput+bos+Xanswer+eos

在微调阶段,模型结构不变,目标改变为:

L s f t ( X a n s w e r ) = ∑ i = l o c a l ( b o s ) l o c a l ( e o s ) l o g P ( x i ∣ x i − k , . . . , x i − 1 ; Θ ) L_{sft}(\mathcal{X_{answer}}) = \sum_{i=local(bos)}^{local(eos)} logP(x_i|x_{i-k},...,x_{i-1};\mathcal{\Theta}) Lsft(Xanswer)=i=local(bos)local(eos)logP(xixik,...,xi1;Θ)

在微调阶段只关注answer部分token序列的联合概率分布最大化。

  经过SFT(Supervised Fine-Tuning)阶段,通过给模型展示如何正确地响应不同的提示(指令)(例如问答,摘要,翻译等)的示例,模型会学会模仿示例数据中的响应行为,学会问答、翻译、摘要等能力。指令微调优势在于,对于任何特定任务的专用模型,只需要在通用大模型的基础上通过特定任务的指令数据进行微调,就可以解锁LLM在特定任务上的能力,不在需要从头去构建专用的小模型。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/805362.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Redis 缓存穿透、缓存击穿、缓存雪崩区别和解决方案

缓存穿透 什么是缓存穿透? 缓存穿透说简单点就是大量请求的 key 是不合理的,根本不存在于缓存中,也不存在于数据库中 。这就导致这些请求直接到了数据库上,根本没有经过缓存这一层,对数据库造成了巨大的压力&#xf…

2、Qt UI控件 -- qucsdk项目使用

前言:上一篇文章讲了qucsdk的环境部署,可以在QDesigner和Qt Creator中看到qucsdk控件,这一篇来讲下在项目中使用qucsdk库中的控件。 一、准备材料 要想使用第三方库,需要三个先决条件, 1、控件的头文件 2、动/静态链…

【C++造神计划】定义常量

1 宏常量(macro constants) 使用预处理器指令 #define 可以将那些经常使用的常量定义为你自己取的名字而不需要借助于变量 编译器在遇到 #define 指令的时候,做的只是把任何出现这些常量名的地方替换成它们被定义为的代码 #define 指令不是…

rollup 插件架构-装饰器模式增添插件性能分析

文章目录 输入 rollup 配置根据用户配置开启插件性能分析性能分析函数实现分级输出结果装饰器模式拓展组件 输入 rollup 配置 初始化计时器,构建完成时输出每个阶段的耗时、内存占用等信息,会 wrapper 相应 hook 方法,添加计时相关功能 initialiseTime…

记录vue之npm run serve报错SET NODE_OPTIONS

> vue-antd-pro3.0.0 serve > SET NODE_OPTIONS--openssl-legacy-provider && vue-cli-service servesh: SET: command not found 一定要注意:将 SET NODE_OPTIONS–openssl-legacy-provider && 删除即可

17 - Games101 - 笔记 - 材质与外观

**17 **材质与外观 材质与BRDF 自然界中的材质:丝绸、头发、蝴蝶翅膀表面、寿司表面等等 图形学中的材质:同一个模型之所以渲染出不同结果的原因就是因为材质。在图形学中是给不同的物体指定不同的材质,知道它们如何和光线作用后就能正确的…

C++11 数据结构0 什么是 “数据结构“?数据,数据对象,数据元素,数据项 概念。算法的基本概念 和 算法的度量,大O表示法,空间换时间的代码

数据: 是能输入计算机且能被计算机处理的各种符号的集合。数值型的数据:整数和实数。非数值型的数据:文字、图像、图形、声音等。 数据对象: 性质相同的 "数据元素" 的集合 例如一个 int arr[10], Teacher tea[3]; 数…

汽车4S行业的信息化特点与BI建设挑战

汽车行业也是一个非常大的行业,上下游非常广,像主机厂,上游的零配件,下游的汽车流通,汽车流通之后的汽车后市场,整个链条比较长。今天主要讲的是汽车流通,汽车4S集团。一个汽车4S集团下面授权代…

MySQL高级篇(存储引擎InnoDB、MyISAM、Memory)

目录 1、存储引擎简介 1.1、查询建表语句,默认存储引擎:InnoDB 1.2、查看当前数据库支持的存储引擎 1.3、创建表,并指定存储引擎 2、 存储引擎-InnoDB介绍 2.1、存储引擎特点 3、MyISAM存储引擎 4、Memory存储引擎 5、InnoDB、MyISAM、Memory…

HTML基础(3)

1、内联框架 iframe用于在网页内显示网页&#xff0c;语法如下&#xff1a; <iframe src"URL"></iframe> URL指向隔离页面 hight&#xff0c;weight设置高宽&#xff0c;删除边框将frameborder设置为0 <td> <iframe frameborder"0&qu…

AI技术创业机会之农业与食品科技

农业与食品科技领域在人工智能&#xff08;AI&#xff09;技术的推动下正经历深刻变革&#xff0c;为创业者提供了丰富的创业机会。以下详述了农业与食品科技背景下AI技术的创业机会及其具体细节与内容&#xff0c;以5000字篇幅深入探讨各细分领域&#xff0c;为有志于投身这一…

C++ 获取数组大小、多维数组操作详解

获取数组的大小 要获取数组的大小&#xff0c;可以使用 sizeof() 运算符&#xff1a; 示例 int myNumbers[5] {10, 20, 30, 40, 50}; cout << sizeof(myNumbers);结果&#xff1a; 20为什么结果显示为 20 而不是 5&#xff0c;当数组包含 5 个元素时&#xff1f; 这…

麒麟v10安装mysql-8.0.35

因为要修复漏洞的原因&#xff0c;这两天将麒麟v10操作系统的服务器上的MySQL版本由5.7.27升级到8.0.35&#xff08;mysql安装包下载地址&#xff1a;MySQL :: Download MySQL Community Server (Archived Versions)&#xff09;&#xff0c;mysql的安装过程主要参考了这个博主…

JavaScript-throw、try,2024年前端高级面试题总结

提交电话 二、xml初识 xml文件是用来做什么的 核心思想&#xff1a; 答&#xff1a;存储数据 延伸问题&#xff1a; xml是怎样存储数据的&#xff1f; 答&#xff1a;以标签的形式存储 例: coco 什么是xml元素? 元素该如何编写? xml中的元素其实就是一个个的标签 标签…

面试官为什么喜欢考察Vue底层原理

&#x1f90d; 前端开发工程师、技术日更博主、已过CET6 &#x1f368; 阿珊和她的猫_CSDN博客专家、23年度博客之星前端领域TOP1 &#x1f560; 牛客高级专题作者、打造专栏《前端面试必备》 、《2024面试高频手撕题》 &#x1f35a; 蓝桥云课签约作者、上架课程《Vue.js 和 E…

国产低代码工具,轻松搞定数据迁移

在日常的业务系统升级或者数据维护过程中&#xff0c;数据迁移是各个企业用户不得不面临的问题&#xff0c;尤其是数据迁移过程中要保障数据完整性、统一性和及时性&#xff0c;同时也需要注意源数据中的数据质量问题&#xff0c;比如缺失、无效、错误等问题&#xff0c;需要在…

前端面试算法题1

1.已知&#xff1a; • 布局分为&#xff1a;父元素A和N个子元素B&#xff1b; • A宽度不固定&#xff1a;最小宽度为1000px&#xff0c;内部边距是32px • B的宽度不固定&#xff1a;相邻两个B元素的间距是16px&#xff0c;所有B的宽度相同&#xff0c;边框为1像素&#x…

docker-如何离线安装部署

离线安装Docker通常涉及到以下几个主要步骤&#xff0c;这里是一个简化的流程概述&#xff0c;适用于大多数Linux发行版&#xff08;如Ubuntu、CentOS等&#xff09;&#xff1a; 下载离线安装包 访问Docker官方下载页面或者使用已有的网络环境提前下载所需的Docker引擎安装包和…

【JAVA语言-第19话】多线程详细解析(一)

目录 多线程 1.1 并发和并行 1.2 线程和进程 1.2.1 进程 1.2.2 线程 1.3 单线程 1.3.1 单线程案例 1.4 创建多线程的方式 1.4.1 继承Thread类 1.4.2 实现Runnable接口 1.4.3 使用匿名内部类 1.5 Thread类 1.5.1 构造方法 1.5.2 常用方法 1.5.3 Thread类中…

《QT实用小工具·二十四》各种数学和数据的坐标演示图

1、概述 源码放在文章末尾 该项目实现了各种数学和数据的坐标演示图&#xff0c;下面是demo演示&#xff1a; 项目部分代码如下&#xff1a; #ifndef FRMMAIN_H #define FRMMAIN_H#include <QWidget> class QAbstractButton;namespace Ui { class frmMain; }class fr…