Transformers 微调

Transformers 微调

  • 基于 Transformers 实现模型微调训练的主要流程
    • 数据字段
    • 数据拆分(分成训练跟测试)
    • 下载数据集
    • 数据集抽样
    • 预处理数据
    • 数据抽样
    • 微调训练配置
    • 加载 BERT 模型
    • 训练超参数(TrainingArguments)
    • 模型权重保存路径(output_dir)
  • 开始训练
    • 实例化训练器(Trainer):可用nvidia-smi 查看使用率
    • 保存模型和训练状态

基于 Transformers 实现模型微调训练的主要流程

  1. 数据集下载

  2. 数据预处理

  3. 训练超参数配置

  4. 训练评估指标设置

  5. 训练器基本介绍

  6. 实战训练

  7. 模型保存

一个典型的数据点包括文本和相应的标签。来自YelpReviewFull测试集的示例如下:

{'label': 0,'text': 'I got \'new\' tires from them and within two weeks got a flat. I took my car to a local mechanic to see if i could get the hole patched, but they said the reason I had a flat was because the previous patch had blown - WAIT, WHAT? I just got the tire and never needed to have it patched? This was supposed to be a new tire. \\nI took the tire over to Flynn\'s and they told me that someone punctured my tire, then tried to patch it. So there are resentful tire slashers? I find that very unlikely. After arguing with the guy and telling him that his logic was far fetched he said he\'d give me a new tire \\"this time\\". \\nI will never go back to Flynn\'s b/c of the way this guy treated me and the simple fact that they gave me a used tire!'
}

数据字段

‘text’: 评论文本使用双引号(“)转义,任何内部双引号都通过2个双引号(”")转义。换行符使用反斜杠后跟一个 “n” 字符转义,即 “\n”。

‘label’: 对应于评论的分数(介于1和5之间)。

数据拆分(分成训练跟测试)

Yelp评论完整星级数据集是通过随机选取每个1到5星评论的130,000个训练样本和10,000个测试样本构建的。总共有650,000个训练样本和50,000个测试样本。

下载数据集

import os# 代理的地址,格式为 http://ip:port
http_proxy="http://proxy.sensetime.com:3128/"
https_proxy="http://proxy.sensetime.com:3128/"
# 设置代理
os.environ["HTTP_PROXY"] = http_proxy
os.environ["HTTPS_PROXY"] = https_proxy
from datasets import load_dataset
dataset = load_dataset("yelp_review_full")
#得到的dataset 其实就是一个字典(key:value格式)train 跟test就是这个下载下来的数据集的key。而dataset["train"] 通过这个可以拿到Dataset格式的训练数据集(集合)
print(dataset["train"][0])

数据集抽样

import random
import pandas as pd
import datasets
from IPython.display import display, HTML
#用于从给定的数据集 (dataset) 中随机选择一些示例并显示
def show_random_elements(dataset, num_examples=10):assert num_examples <= len(dataset), "Can't pick more elements than there are in the dataset."picks = []for _ in range(num_examples):pick = random.randint(0, len(dataset)-1)print(pick)while pick in picks:pick = random.randint(0, len(dataset)-1)picks.append(pick)#将从数据集中随机选择的示例创建为 Pandas DataFramedf = pd.DataFrame(dataset[picks])for column, typ in dataset.features.items():#遍历数据集的所有特征if isinstance(typ, datasets.ClassLabel):#检查特征是否是分类标签#如果是分类标签,将使用 lambda 函数将标签的索引映射到实际的类别名称df[column] = df[column].transform(lambda i: typ.names[i])display(HTML(df.to_html()))
#可以print(show_random_elements(dataset["train"]) 查看效果

预处理数据

下载数据集到本地后,使用 Tokenizer 来处理文本,对于长度不等的输入数据,可以使用填充(padding)和截断(truncation)策略来处理。

Datasets 的 map 方法,支持一次性在整个数据集上应用预处理函数。

下面使用填充到最大长度的策略,处理整个数据集:

from transformers import AutoTokenizer
#用于加载预训练的文本处理模型(Tokenizer),以便将文本数据转换为模型可以接受的输入格式
tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")def tokenize_function(examples):return tokenizer(examples["text"], padding="max_length", truncation=True)tokenized_datasets = dataset.map(tokenize_function, batched=True)#刚刚生成的dataset 通过map的方法,把里面的每个样本都进行tokenize_function操作,生成处理过的数据集tokenized_datasets#可以show_random_elements(tokenized_datasets["train"], num_examples=1)查看效果

数据抽样

使用 1000 个数据样本,在 BERT 上演示小规模训练(基于 Pytorch Trainer)

shuffle()函数会随机重新排列列的值。如果您希望对用于洗牌数据集的算法有更多控制,可以在此函数中指定generator参数来使用不同的numpy.random.Generator。


small_train_dataset = tokenized_datasets["train"].shuffle(seed=42).select(range(1000))
small_eval_dataset = tokenized_datasets["test"].shuffle(seed=42).select(range(1000))

微调训练配置

加载 BERT 模型

警告通知我们正在丢弃一些权重(vocab_transform和vocab_layer_norm 层),并随机初始化其他一些权重(pre_classifier和classifier 层)。在微调模型情况下是绝对正常的,因为我们正在删除用于预训练模型的掩码语言建模任务的头部,并用一个新的头部替换它,对于这个新头部,我们没有预训练的权重,所以库会警告我们在用它进行推理之前应该对这个模型进行微调,而这正是我们要做的事情。

from transformers import AutoModelForSequenceClassification
model = AutoModelForSequenceClassification.from_pretrained("bert-base-cased", num_labels=5)

训练超参数(TrainingArguments)

完整配置参数与默认值:https://huggingface.co/docs/transformers/v4.36.1/en/main_classes/trainer#transformers.TrainingArguments

源代码定义:https://github.com/huggingface/transformers/blob/v4.36.1/src/transformers/training_args.py#L161

模型权重保存路径(output_dir)

from transformers import TrainingArguments
model_dir = "models/bert-base-cased"# logging_steps 默认值为500,根据我们的训练数据和步长,将其设置为100, num_train_epochs 默认为3
training_args = TrainingArguments(output_dir=f"{model_dir}/test_trainer",logging_dir=f"{model_dir}/test_trainer/runs",logging_steps=100)
# 完整的超参数配置
print(training_args)

训练过程中的指标评估(Evaluate)
Hugging Face Evaluate 库 支持使用一行代码,获得数十种不同领域(自然语言处理、计算机视觉、强化学习等)的评估方法。 当前支持 完整评估指标:https://huggingface.co/evaluate-metric

训练器(Trainer)在训练过程中不会自动评估模型性能。因此,我们需要向训练器传递一个函数来计算和报告指标。

Evaluate库提供了一个简单的准确率函数,您可以使用evaluate.load函数加载

import numpy as np
import evaluatemetric = evaluate.load("accuracy")

接着,调用 compute 函数来计算预测的准确率。

在将预测传递给 compute 函数之前,我们需要将 logits 转换为预测值(所有Transformers 模型都返回 logits)。


def compute_metrics(eval_pred):logits, labels = eval_predpredictions = np.argmax(logits, axis=-1)return metric.compute(predictions=predictions, references=labels)

训练过程指标监控
通常,为了监控训练过程中的评估指标变化,我们可以在TrainingArguments指定evaluation_strategy参数,以便在 epoch 结束时报告评估指标。

from transformers import TrainingArguments, Trainer
training_args = TrainingArguments(output_dir=f"{model_dir}/test_trainer", evaluation_strategy="epoch",logging_dir=f"{model_dir}/test_trainer/runs",logging_steps=100)

开始训练

实例化训练器(Trainer):可用nvidia-smi 查看使用率

trainer = Trainer(model=model,args=training_args,train_dataset=small_train_dataset,eval_dataset=small_eval_dataset,compute_metrics=compute_metrics,
)
trainer.train()small_test_dataset = tokenized_datasets["test"].shuffle(seed=64).select(range(100))
trainer.evaluate(small_test_dataset)

保存模型和训练状态

使用 trainer.save_model 方法保存模型,后续可以通过 from_pretrained() 方法重新加载

使用 trainer.save_state 方法保存训练状态

trainer.save_model(f"{model_dir}/finetuned-trainer")
trainer.save_state()

微调代码示例

import os
# 代理的地址,格式为 http://ip:port
http_proxy="http://proxy.sensetime.com:3128/"
https_proxy="http://proxy.sensetime.com:3128/"
# 设置代理
os.environ["HTTP_PROXY"] = http_proxy
os.environ["HTTPS_PROXY"] = https_proxy## 下载数据集
from datasets import load_dataset
dataset = load_dataset("yelp_review_full")
#得到的dataset 其实就是一个字典(key:value格式)train 跟test就是这个下载下来的数据集的key。而dataset["train"] 通过这个可以拿到Dataset格式的训练数据集(集合)
#print(dataset["train"][0])可以查看数据集的大概的结构import random
import pandas as pd
import datasets
from IPython.display import display, HTML
#用于从给定的数据集 (dataset) 中随机选择一些示例并显示
def show_random_elements(dataset, num_examples=10):assert num_examples <= len(dataset), "Can't pick more elements than there are in the dataset."picks = []for _ in range(num_examples):pick = random.randint(0, len(dataset)-1)print(pick)while pick in picks:pick = random.randint(0, len(dataset)-1)picks.append(pick)#将从数据集中随机选择的示例创建为 Pandas DataFramedf = pd.DataFrame(dataset[picks])for column, typ in dataset.features.items():#遍历数据集的所有特征if isinstance(typ, datasets.ClassLabel):#检查特征是否是分类标签#如果是分类标签,将使用 lambda 函数将标签的索引映射到实际的类别名称df[column] = df[column].transform(lambda i: typ.names[i])display(HTML(df.to_html()))#from transformers import AutoTokenizer#用于从Hugging Face加载预训练的文本处理模型(Tokenizer),以便将文本数据转换为模型可以接受的输入格式
tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")def tokenize_function(examples):return tokenizer(examples["text"], padding="max_length", truncation=True)tokenized_datasets = dataset.map(tokenize_function, batched=True)
show_random_elements(tokenized_datasets["train"], num_examples=1)# 使用 1000 个数据样本,在 BERT 上演示小规模训练(基于 Pytorch Trainer)
small_train_dataset = tokenized_datasets["train"].shuffle(seed=42).select(range(1000))
small_eval_dataset = tokenized_datasets["test"].shuffle(seed=42).select(range(1000))# 微调训练配置
# 从Hugging Face加载BERT 模型
from transformers import AutoModelForSequenceClassification
model = AutoModelForSequenceClassification.from_pretrained("bert-base-cased", num_labels=5)from transformers import TrainingArguments
model_dir = "models/bert-base-cased"
# logging_steps 默认值为500,根据我们的训练数据和步长,将其设置为100, num_train_epochs 默认为3
from transformers import TrainingArguments, Trainer
training_args = TrainingArguments(output_dir=f"{model_dir}/test_trainer", evaluation_strategy="epoch",logging_dir=f"{model_dir}/test_trainer/runs",logging_steps=100)# Evaluate库提供了一个简单的准确率函数,使用`evaluate.load`函数加载
import numpy as np
import evaluate
metric = evaluate.load("accuracy")# `compute` 函数来计算预测的准确率。
def compute_metrics(eval_pred):logits, labels = eval_predpredictions = np.argmax(logits, axis=-1)return metric.compute(predictions=predictions, references=labels)### 实例化训练器(Trainer)
trainer = Trainer(model=model,args=training_args,train_dataset=small_train_dataset,eval_dataset=small_eval_dataset,compute_metrics=compute_metrics)trainer.train()
small_test_dataset = tokenized_datasets["test"].shuffle(seed=64).select(range(100))
trainer.evaluate(small_test_dataset)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/817059.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

2024.4.19 Python爬虫复习day07 可视化3

综合案例 需求: 已知2020年疫情数据,都是json数据,需要从文件中读出,进行处理和分析,最终实现数据可视化折线图 相关知识点: json json简介: 本质是一个特定格式的字符串 举例: [{},{},{}] 或者 {}python中json包: import jsonpython数据转为json数据: 变量接收json…

微服务架构使用和docker部署方法(若依)

这里以若依官方网站开源的微服务框架为例子记录使用方法过程。 开源地址&#xff1a;RuoYi-Cloud: &#x1f389; 基于Spring Boot、Spring Cloud & Alibaba的分布式微服务架构权限管理系统&#xff0c;同时提供了 Vue3 的版本 下载后&#xff0c;用IDEA社区版开发工具打…

【量化交易】顶底分型策略

在众多的量化策略中&#xff0c;顶底分型策略因其独特的市场趋势捕捉能力和简洁的实现方式而受到许多投资者的青睐。本文将详细介绍顶底分型策略的原理&#xff0c;并展示如何使用Python在聚宽平台上实现这一策略。 感兴趣的朋友&#xff0c;可以在下方公号内回复&#xff1a;0…

GNU Radio Radar Toolbox编译及安装

文章目录 前言一、GNU Radio Radar Toolbox 介绍二、gr-radar 安装三、具体使用四、OFDM 雷达仿真 前言 GNU Radio Radar Toolbox&#xff08;gr-radar&#xff09;是一个开放源码的工具箱&#xff0c;用于 GNU Radio 生态系统&#xff0c;主要目的是为雷达信号处理提供必要的…

vue源码解析——diff算法/双端比对/patchFlag/最长递增子序列

虚拟dom——virtual dom&#xff0c;提供一种简单js对象去代替复杂的 dom 对象&#xff0c;从而优化 dom 操作。virtual dom 是“解决过多的操作 dom 影响性能”的一种解决方案。virtual dom 很多时候都不是最优的操作&#xff0c;但它具有普适性&#xff0c;在效率、可维护性之…

Leetcode 3111. Minimum Rectangles to Cover Points

Leetcode 3111. Minimum Rectangles to Cover Points 1. 解题思路2. 代码实现 题目链接&#xff1a;3111. Minimum Rectangles to Cover Points 1. 解题思路 这一题在这次比赛的4道题当中算是比较简单的&#xff0c;基本就只需要将所有的点排序之后然后使用贪婪算法来cover住…

【C++造神计划】运算符

1 赋值运算符 赋值运算符的功能是将一个值赋给一个变量 int a 5; // 将整数 5 赋给变量 a 运算符左边的部分叫作 lvalue&#xff08;left value&#xff09;&#xff0c;右边的部分叫作 rvalue&#xff08;right value&#xff09; 左边 lvalue 必须是一个变量 右边 rval…

木马免杀代码之python反序列化分离免杀

本篇文章主要用到python来对CobaltStrike生成的Shellcode进行分离免杀处理, 因此要求读者要有一定的python基础, 下面我会介绍pyhon反序列化免杀所需用到的相关函数和库 exec函数 exec函数是python的内置函数, 其功能与eval()函数相同, 但不同的是exec函数支持多行python代码…

我国新戊二醇产能逐渐增长 市场集中度有望进一步提升

我国新戊二醇产能逐渐增长 市场集中度有望进一步提升 新戊二醇&#xff08;NPG&#xff09;又称为2,2-二甲基-1,3-丙二醇&#xff0c;化学式为C5H12O2&#xff0c;熔点为124-130℃。新戊二醇多表现为一种无特殊气味的白色结晶固体&#xff0c;易溶于水及醇、醚等溶液。新戊二醇…

为什么看到这么多人不推荐C++?

前几天逛知乎的时候&#xff0c;看到一个问题&#xff1a; 看到这个问题我倒是想吐槽几句了。 C一直没找到自己的定位&#xff01; C语言&#xff1a;我是搞系统编程开发的&#xff0c;操作系统、数据库、编译器、网络协议栈全是我写的。 PHP&#xff1a;我是搞后端业务开发…

docker compose安装及安装慢解决办法

docker compose安装 Compose下载添加执行权限创建软链测试安装结果 Compose下载 curl -SL "https://github.com/docker/compose/releases/download/v2.26.1/docker-compose-$(uname -s)-$(uname -m)" -o /usr/local/bin/docker-compose上述compose是在docker官方git…

一年期SSL证书怎么申请?

申请SSL证书三步走 JoySSL_JoySSL SSL证书_JoySSL https证书-JoySSL 一、选证书类型 根据网站性质与安全需求&#xff0c;选定合适的SSL证书&#xff1a; - 域名验证证书&#xff08;DV&#xff09;&#xff1a;快速验证域名所有权&#xff0c;适用于个人网站、博客&#xff…

ReentrantLock源码阅读

1. 概述 lock锁, 基于队列同步器AQS, 实现公平锁、非公平锁 队列同步器AQS可以阅读我这篇文章&#xff1a; 点击传送 实现了Lock接口: public class ReentrantLock implements Lock// 加锁 获取不到锁一直等待 void lock(); // 加锁 获取不到锁一直等待 等待过程可以被中断…

websocket原理及简单入门

在了解websocket之前,我们先来了解一下websocket出现之前的世界 当我们在开腾讯会议或视频通话时,我们自己的影像会传给对方,对方的影像也能同时传给我们,这就是即时通讯技术 即时通讯技术是实现&#xff1a;服务器端可以时地将数据的更新或变化反应到客户端&#xff0c;在Web中…

Python中操作Excel表对象并打包为脚本

一、准备工作 pip install pandas pip install openpyxl pip install pyinstaller 数据表格&#xff1a; 数据表下载 二、执行写入操作 import pandas as pd # pyinstaller --onefile attendance_records_score.py # 打包 # 读取源Excel文件&#xff08;假设源表有列A…

【攻防世界】php_rce (ThinkPHP5)

进入题目环境&#xff0c;查看页面信息&#xff1a; 页面提示 ThinkPHP V5&#xff0c;猜测存在ThinkPHP5 版本框架的漏洞&#xff0c;于是查找 ThinkPHP5 的攻击POC。 构造 payload: http://61.147.171.105:50126/?sindex/think\app/invokefunction&functioncall_user_f…

【Go语言快速上手(一)】 初识Go语言

&#x1f493;博主CSDN主页:杭电码农-NEO&#x1f493;   ⏩专栏分类:Go语言专栏⏪   &#x1f69a;代码仓库:NEO的学习日记&#x1f69a;   &#x1f339;关注我&#x1faf5;带你学习更多Go语言知识   &#x1f51d;&#x1f51d; Go快速上手 1. 前言2. Go语言简介(为…

模拟Android系统Zygote启动流程

版权声明&#xff1a;本文为梦想全栈程序猿原创文章&#xff0c;转载请附上原文出处链接和本声明 前言&#xff1a; 转眼时间过去了10年了&#xff0c;回顾整个10年的工作历程&#xff0c;做了3年的手机&#xff0c;4年左右的Android指纹相关的工作&#xff0c;3年左右的跟传感…

什么是三次握手和四次握手

三次握手和四次挥手是TCP协议中用于建立和终止TCP连接的重要机制。 三次握手是TCP连接建立的过程&#xff0c;具体步骤如下&#xff1a; 客户端发送一个带有SYN标志的数据包给服务端&#xff0c;表示希望建立连接。服务端收到后&#xff0c;回传一个带有SYN/ACK标志的数据包&…

亚马逊CloudFront使用体验

前言 首先在体验CloudFront之前&#xff0c;先介绍一下什么是CDN&#xff0c;以及CDN的基本原理。 CDN是Content Delivery Network&#xff08;内容分发网络&#xff09;的缩写&#xff0c;是一种利用分布式节点技术&#xff0c;在全球部署服务器&#xff0c;即时地将网站、应…