八、大模型之Fine-Tuning(1)

1 什么时候需要Fine-Tuning

  1. 有私有部署的需求
  2. 开源模型原生的能力不满足业务需求

2 训练模型利器Hugging Face

  1. 官网(https://huggingface.co/)
  2. 相当于面向NLP模型的Github
  3. 基于transformer的开源模型非常全
  4. 封装了模型、数据集、训练器等,资源下载方面
  5. 安装依赖
# pip 安装
pip install transformers # 安装最新版本
pip install transformers == 4.30 # 安装指定版本
# conda安装
conda install -c huggingface transformers  # 只4.0以后的版本

3 案例

3.1 操作流程

加载数据集—>数据预处理—>数据规整器—>训练器
在这里插入图片描述

3.2 实现

  1. 导包
import datasets
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModel
from transformers import AutoModelForCausalLM
from transformers import TrainingArguments, Seq2SeqTrainingArguments
from transformers import Trainer, Seq2SeqTrainer
import transformers
from transformers import DataCollatorWithPadding
from transformers import TextGenerationPipeline
import torch
import numpy as np
import os, re
from tqdm import tqdm
import torch.nn as nn
  1. 加载数据集
    通过HuggingFace,可以指定数据集名称,运行时自动下载
# 数据集名称
DATASET_NAME = "rotten_tomatoes" # 加载数据集
raw_datasets = load_dataset(DATASET_NAME)# 训练集
raw_train_dataset = raw_datasets["train"]# 验证集
raw_valid_dataset = raw_datasets["validation"]

在这里插入图片描述
3. 加载模型

# 模型名称
MODEL_NAME = "gpt2" # 加载模型 
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME,trust_remote_code=True)

在这里插入图片描述
4. 加载Tokenizer
通过HuggingFace,可以指定模型名称,运行自动下载对应Tokenizer

# 加载tokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME,trust_remote_code=True)
tokenizer.add_special_tokens({'pad_token': '[PAD]'})
tokenizer.pad_token_id = 0# 设置随机种子:同个种子的随机序列可复现
transformers.set_seed(42)# 标签集
named_labels = ['neg','pos']# 标签转 token_id
label_ids = [tokenizer(named_labels[i],add_special_tokens=False)["input_ids"][0] for i in range(len(named_labels))
]

在这里插入图片描述
5. 处理数据集:转成模型接受的输入格式

  • 拼接输入输出:<INPUT TOKEN IDS><EOS_TOKEN_ID><OUTPUT TOKEN IDS>
  • PAD成相等长度:
    • <INPUT 1.1><INPUT 1.2>…<EOS_TOKEN_ID><OUTPUT TOKEN IDS><PAD>…<PAD>
    • <INPUT 2.1><INPUT 2.2>…<EOS_TOKEN_ID><OUTPUT TOKEN IDS><PAD>…<PAD>
  • 标识出参与 Loss 计算的 Tokens (只有输出 Token 参与 Loss 计算)
    • <-100><-100>…<OUTPUT TOKEN IDS><-100>…<-100>
MAX_LEN=32   #最大序列长度(输入+输出)
DATA_BODY_KEY = "text" # 数据集中的输入字段名
DATA_LABEL_KEY = "label" #数据集中输出字段名# 定义数据处理函数,把原始数据转成input_ids, attention_mask, labels
def process_fn(examples):model_inputs = {"input_ids": [],"attention_mask": [],"labels": [],}for i in range(len(examples[DATA_BODY_KEY])):inputs = tokenizer(examples[DATA_BODY_KEY][i],add_special_tokens=False)label = label_ids[examples[DATA_LABEL_KEY][i]]input_ids = inputs["input_ids"] + [tokenizer.eos_token_id, label]raw_len = len(input_ids)input_len = len(inputs["input_ids"]) + 1if raw_len >= MAX_LEN:input_ids = input_ids[-MAX_LEN:]attention_mask = [1] * MAX_LENlabels = [-100]*(MAX_LEN - 1) + [label]else:input_ids = input_ids + [tokenizer.pad_token_id] * (MAX_LEN - raw_len)attention_mask = [1] * raw_len + [0] * (MAX_LEN - raw_len)labels = [-100]*input_len + [label] + [-100] * (MAX_LEN - raw_len)model_inputs["input_ids"].append(input_ids)model_inputs["attention_mask"].append(attention_mask)model_inputs["labels"].append(labels)return model_inputs

6.定义数据规整器:训练时自动将数据拆分成Batch

# 定义数据校准器(自动生成batch)
collater = DataCollatorWithPadding(tokenizer=tokenizer, return_tensors="pt",
)

7.定义训练超参

LR=2e-5         # 学习率
BATCH_SIZE=8    # Batch大小
INTERVAL=100    # 每多少步打一次 log / 做一次 eval# 定义训练参数
training_args = TrainingArguments(output_dir="./output",              # checkpoint保存路径evaluation_strategy="steps",        # 按步数计算eval频率overwrite_output_dir=True,num_train_epochs=1,                 # 训练epoch数per_device_train_batch_size=BATCH_SIZE,     # 每张卡的batch大小gradient_accumulation_steps=1,              # 累加几个step做一次参数更新per_device_eval_batch_size=BATCH_SIZE,      # evaluation batch sizeeval_steps=INTERVAL,                # 每N步eval一次logging_steps=INTERVAL,             # 每N步log一次save_steps=INTERVAL,                # 每N步保存一个checkpointlearning_rate=LR,                   # 学习率
)

8.定义训练器

# 节省显存
model.gradient_checkpointing_enable()# 定义训练器
trainer = Trainer(model=model, # 待训练模型args=training_args, # 训练参数data_collator=collater, # 数据校准器train_dataset=tokenized_train_dataset,  # 训练集eval_dataset=tokenized_valid_dataset,   # 验证集# compute_metrics=compute_metric,         # 计算自定义评估指标
)

8.训练

trainer.train()

总结

  1. 加载数据集
  2. 数据预处理
    • 将输入输出按特定格式拼接
    • 文本转Token IDs
    • 通过labels标识出哪部分是输出(只有输出的token参与loss计算)
  3. 加载模型、Tokenizer
  4. 定义数据规则整器
  5. 定义训练超参:学习率、批次大小
  6. 定义训练器
  7. 开始训练

4 大模型训练相关技术

  1. 神经网络
    在这里插入图片描述

  2. 常用的激活函数
    在这里插入图片描述

  3. 梯度下降
    在这里插入图片描述

  4. 学习率
    在这里插入图片描述

  5. 求解器

为了让训练过程更好的收敛,人们设计了很多更复杂的求解器

  • 比如:SGD、L-BFGS、Rprop、RMSprop、Adam、AdamW、AdaGrad、AdaDelta 等等
  • 但是,好在对于Transformer最常用的就是 Adam 或者 AdamW
  1. 一些常用的损失函数
  • 两个数值的差距,Mean Squared Error: ℓ M S E = 1 N ∑ i = 1 N ( y i − y ^ i ) 2 \ell_{\mathrm{MSE}}=\frac{1}{N}\sum_{i=1}^N(y_i-\hat{y}_i)^2 MSE=N1i=1N(yiy^i)2 (等价于欧式距离,见下文)

  • 两个向量之间的(欧式)距离: ℓ ( y , y ^ ) = ∥ y − y ^ ∥ \ell(\mathbf{y},\mathbf{\hat{y}})=\|\mathbf{y}-\mathbf{\hat{y}}\| (y,y^)=yy^

  • 两个向量之间的夹角(余弦距离):
    在这里插入图片描述

  • 两个概率分布之间的差异,交叉熵: ℓ C E ( p , q ) = − ∑ i p i log ⁡ q i \ell_{\mathrm{CE}}(p,q)=-\sum_i p_i\log q_i CE(p,q)=ipilogqi ——假设是概率分布 p,q 是离散的

  • 这些损失函数也可以组合使用(在模型蒸馏的场景常见这种情况),例如 L = L 1 + λ L 2 L=L_1+\lambda L_2 L=L1+λL2,其中 λ \lambda λ是一个预先定义的权重,也叫一个「超参」

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/783940.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

[Windows]修改默认远程端口3389

文章目录 注册表编辑器找到注册信息找到端口配置修改端口重启远程连接服务远程连接 因为不想使用windos默认远程3389端口&#xff0c;所以考虑换成其他的端口。保证安全&#xff08;虽然windows不是那么安全&#xff09;。 注册表编辑器 windos搜索注册表编辑器 找到注册信息…

网上国网App启动鸿蒙原生应用开发,鸿蒙开发前景怎么样?

从华为宣布全面启动鸿蒙生态原生应用一来&#xff0c;各种各样的新闻就没有停过&#xff0c;如&#xff1a;阿里、京东、小红书……等大厂的加入&#xff0c;而这次他们又与一个国企大厂进行合作&#xff1a; 作为特大型国有重点骨干企业&#xff0c;国家电网承担着保障安全、经…

HAL库 USART通讯

1、UASRT简介 串口通讯 (Serial Communication) 是一种设备间非常常用的串行通讯方式&#xff0c;因为它简单便捷&#xff0c;因此大部分电子设备都支持该通讯方式&#xff0c;电子工程师在调试设备时也经常使用该通讯方式输出调试信息。 按位发送和接收的接口。如RS-232/422/…

C语言-文件操作

&#x1f308;很高兴可以来阅读我的博客&#xff01;&#x1f31f;我热衷于分享&#x1f58a;学习经验&#xff0c;&#x1f3eb;多彩生活&#xff0c;精彩足球赛事⚽&#x1f517;我的CSDN&#xff1a; Kevin ’ s blog&#x1f4c2;专栏收录&#xff1a;C预言 1. 文件的作用 …

【核弹级软安全事件】XZ Utils库中发现秘密后门,影响主要Linux发行版,软件供应链安全大事件

Red Hat 发布了一份“紧急安全警报”&#xff0c;警告称两款流行的数据压缩库XZ Utils&#xff08;先前称为LZMA Utils&#xff09;的两个版本已被植入恶意代码后门&#xff0c;这些代码旨在允许未授权的远程访问。 此次软件供应链攻击被追踪为CVE-2024-3094&#xff0c;其CVS…

基于SSM远程同步课堂系统

基于SSM远程同步课堂系统的设计与实现 摘要 在这样一个网络数据大爆炸的时代&#xff0c;人们获取知识、获取信息的通道非常的多元化&#xff0c;通过网络来实现数据信息的获取成为了现在非常常见的一种方式&#xff0c;而通过网络进行教学&#xff0c;在网络上进行远程的课堂…

vue3 视频播放功能整体复盘梳理

回顾工作中对视频的处理&#xff0c;让工作中处理的问题的经验固化成成果&#xff0c;不仅仅是完成任务&#xff0c;还能解答任务的知识点。 遇到的问题 1、如何隐藏下载按钮&#xff1f; video 标签中的controlslist属性是可以用来控制播放器上空间的显示&#xff0c;在原来默…

视觉里程计之对极几何

视觉里程计之对极几何 前言 上一个章节介绍了视觉里程计关于特征点的一些内容&#xff0c;相信大家对视觉里程计关于特征的描述已经有了一定的认识。本章节给大家介绍视觉里程计另外一个概念&#xff0c;对极几何。 对极几何 对极几何是立体视觉中的几何关系&#xff0c;描…

LeetCode刷题【链表,图论,回溯】

目录 链表138. 随机链表的复制148. 排序链表146. LRU 缓存 图论200. 岛屿数量994. 腐烂的橘子207. 课程表 回溯 链表 138. 随机链表的复制 给你一个长度为 n 的链表&#xff0c;每个节点包含一个额外增加的随机指针 random &#xff0c;该指针可以指向链表中的任何节点或空节…

搜索与图论——Floyd算法求最短路

floyd算法用来求多源汇最短路 用邻接矩阵来存所有的边 时间复杂度O(n^3) #include<iostream> #include<cstring> #include<algorithm>using namespace std;const int N 20010,INF 1e9;int n,m,k; int g[N][N];void floyd(){for(int k 1;k < n;k ){f…

C++从入门到精通——引用()

C的引用 前言一、C引用概念二、引用特性交换指针引用 三、常引用保证值不变权限的方法权限的放大权限的缩小权限的平移类型转换临时变量 四、引用的使用场景1. 做参数2. 做返回值 五、传值、传引用效率比较值和引用的作为返回值类型的性能比较 六、引用和指针的区别引用和指针的…

AI算法中的关键先生 - 反向转播与戴维莱姆哈特

0. 引言 机器学习的自动推导过程中有一个关键步骤&#xff0c;就是自动求解过程的参数反向传播过程&#xff0c;这个工作据说是这个人做的&#xff1a; Remembering David E. Rumelhart (1942-2011) – Association for Psychological Science – APSAPS Fellow and Charter …

CDR2024软件免费版 CDR绘制矢量图形插画 平面设计CDR教程 cdr2024注册机 cdr2024安装包百度云 cdr快捷键

CDR2024软件免费版是由加拿大Corel公司研发的平面设计类软件。CDR2024软件免费版应用的行业非常广泛&#xff0c;影视动画、网页设等行业都可以应用此软件。CDR2024软件免费版为专业工作者提供了更加精细的操作功能&#xff0c;用户可以更加精确的对图片进行修改&#xff0c;降…

【MySQL】DQL-查询语句全解 [ 基础/条件/分组/排序/分页查询 ](附带代码演示&案例练习)

前言 大家好吖&#xff0c;欢迎来到 YY 滴MySQL系列 &#xff0c;热烈欢迎&#xff01; 本章主要内容面向接触过C Linux的老铁 主要内容含&#xff1a; 欢迎订阅 YY滴C专栏&#xff01;更多干货持续更新&#xff01;以下是传送门&#xff01; YY的《C》专栏YY的《C11》专栏YY的…

解决一个有意思的抛硬币问题,计算连续两次正面所需次数的数学期望

文章目录 一、问题与分析二、基本的数学推导三、代码示例 &#x1f349; CSDN 叶庭云&#xff1a;https://yetingyun.blog.csdn.net/ 一、问题与分析 问题&#xff1a;对于质地均匀的硬币&#xff0c;连续两次得到正面所需的次数数学期望是多少&#xff1f; 关键词&#xff1a;…

10天学会kotlin DAY6 继承、类、重载

kotlin 继承与重载 前言 1、open 关键字 2、类型转换 3、Any 超类 4、对象声明 5、对象表达式 6、伴生对象 7、嵌套类和内部类 8、数据类 9、copy 函数 10、运算符重载 11、枚举类定义函数 12、代数数据类型 13、密封类 14、数据类的小结 总结 前言 使用纯代码…

「MySQL」索引事务

&#x1f387;个人主页&#xff1a;Ice_Sugar_7 &#x1f387;所属专栏&#xff1a;数据库 &#x1f387;欢迎点赞收藏加关注哦&#xff01; 索引&事务 &#x1f349;索引&#x1f34c;特点&#x1f34c;通过 SQL 操作索引&#x1f34c;底层数据结构 &#x1f349;事务&…

Nginx的反向代理

Nginx的反向代理 location ^~ /aaa {proxy_pass http://192.168.15.78/; } 1. 跨域 2.Nginx 代理服务器缓存 3.Nginx 负载均衡 4. 动静分离 Nginx的跨域 跨源资源共享 (CORS) 是一种机制&#xff0c;它使用额外的 HTTP 标头让用户代理获得访问来自不同来域的服务器上选定资…

3.31学习总结

算法 解题思路 使用dfs,对蛋糕每层可能的高度和半径进行穷举.通过观察我们可以知道第一层的圆面积是它上面所有蛋糕层的圆面积之和,所以我们只要去求每层的侧面积就行了. 因为题目要求Ri > Ri1且Hi > Hi1,所以我们可以求出每层的最小体积和侧面积,用两个数组分别储存起来…

C语言实现猜数字游戏(有提示,限制次数版)

这次的猜数字游戏我添加了新的功能&#xff1a;为玩家添加了提示&#xff0c;以及输入数字的限制次数。 首先&#xff0c;我们的猜数字游戏需要一个菜单&#xff0c;来让玩家可以选择玩游戏还是退出游戏&#xff0c;所以我们需要开始就打印一个菜单&#xff1a; int main() {…