模型微调入门介绍一

  备注:模型微调系列的博客部分内容来源于极客时间大模型微调训练营素材,撰写模型微调一系列博客,主要是期望把训练营的内容内化成自己的知识,我自己写的这一系列博客除了采纳部分训练营的内容外,还会扩展细化某些具体细节知识点。

  模型微调大致会有下面5大步骤,其中数据下载主要用transformers库中的datasets来完成,数据预处理部分会用到tokenizer对象。本篇博客会重点介绍数据加载和数据预处理部分,剩余的三个步骤会通过一个简单的例子来简要介绍,后面会有专门的博客来介绍超参数如何设置和结果评估等内容。

数据下载

  datasets 是由 Hugging Face 提供的一个 Python 库,用于访问和使用大量自然语言处理(NLP)数据集。该库旨在使研究人员和开发人员能够轻松地获取、处理和使用各种 NLP 数据集,从而促进自然语言处理模型的研究和开发。datasets提供的常用function如下图所示:

load_dataset(name, split=None):用于加载指定名称的数据集。可以通过 split 参数指定加载数据集的特定拆分(如 "train"、"validation"、"test" 等)。
list_datasets():列出所有可用的数据集名称。
load_metric(name):加载指定名称的评估指标,用于评估模型性能,后面会有专门的一篇博客进行介绍。
load_from_disk(path) 和 save_to_disk(path, data):用于从磁盘加载数据集或将数据集保存到磁盘。
shuffle(seed=None):用于对数据集进行随机洗牌。可以通过 seed 参数指定随机数生成器的种子。
train_test_split(test_size=0.2, seed=None):用于将数据集拆分为训练集和测试集。

数据预处理

清洗数据

  在进行数据预处理的时候,通常需要分析是否需要进行数据清洗。例如,如果原始数据中存在一些特殊符号需要进行清理,通常会自定义清理方法对原始数据进行清洗。具体demo code如下图所示,具体的clean_text方法需要结合具体的数据进行自定义。

import re
import stringdef clean_text(text):# 将文本转换为小写text = text.lower()# 去除标点符号text = text.translate(str.maketrans("", "", string.punctuation))# 去除数字text = re.sub(r'\d+', '', text)# 去除多余的空格text = re.sub(r'\s+', ' ', text).strip()# 处理缩写词,这里只是一个简单的示例text = re.sub(r"won't", "will not", text)text = re.sub(r"can't", "can not", text)# 添加更多的缩写词处理..return text# 示例文本
raw_text = "Hello, how are you? This is an example text with some numbers like 123 and punctuations!!!"# 进行文本清理
cleaned_text = clean_text(raw_text)# 输出结果
print("Original Text:")
print(raw_text)
print("\nCleaned Text:")
print(cleaned_text)

Tokenzier进行数据预处理

 除了数据清洗,在做数据预处理的时候,通常会调用tokenizer的方法进行填充、截断等预处理,那么tokenizer具体提供了哪些参数呢?初始化tokenizer对象时,主要有以下参数:

max_length:控制分词后的最大序列长度。文本将被截断或填充以适应这个长度。
truncation:控制是否对文本进行截断,以适应 max_length。可以设置为 True(默认)或 False。
padding:控制是否对文本进行填充,以适应 max_length。
return_tensors:控制返回的结果是否应该是 PyTorch 或 TensorFlow 张量。可以设置为 'pt'、'tf' 或 None(默认)。
add_special_tokens:控制是否添加特殊令牌,如 [CLS]、[SEP] 或 [MASK]。可以设置为 True(默认)或 False。
is_split_into_words:控制输入文本是否已经是分好词的形式。如果设置为 True,分词器将跳过分词步骤。可以设置为 False 或 True(默认)。
return_attention_mask:控制是否返回 attention mask,指示模型在输入序列中哪些标记是有效的。可以设置为 True 或 False(默认)。
return_offsets_mapping:控制是否返回标记的偏移映射,即每个标记在原始文本中的起始和结束位置。可以设置为 True 或 False(默认)。
return_token_type_ids:控制是否返回用于区分文本段的 token type ids。可以设置为 True 或 False(默认)

 以下面的demo code为例,当设置padding=“max_length”后,如果内容长度低于10,会对内容进行自动填充。tokenizer对象返回一个字典类型,包含inputs_ids,token_type_ids,attention_mask。其中inputs_ids是真正的对输入文本的编码,attention_mask用于标记哪些是真正的输入文本转换的内容,哪些是填充内容,标记为0的即为填充内容。

除了上面的字段外,还可以设置是否返回tensor,是否添加特殊标记等。以下面的例子为例,在encode中添加了特殊标记,设置了返回张量,则返回的内容是tensor张量。

from transformers import BertTokenizer# 初始化 BertTokenizer
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")# 定义文本
text = "Hello, how are you? I hope everything is going well."# 使用 tokenizer.tokenize 进行分词
tokens = tokenizer.tokenize(text)
print("Tokens after tokenization:", tokens)# 使用 tokenizer.encode 将文本编码成模型输入的标识符序列
input_ids = tokenizer.encode(text, max_length=15, truncation=True, padding="max_length", add_special_tokens=True)
print("Input IDs after encoding:", input_ids)# 使用 tokenizer.decode 将模型输出的标识符序列解码为文本
decoded_text = tokenizer.decode(input_ids)
print("Decoded text:", decoded_text)# 使用 tokenizer.encode_plus 获取详细的编码结果,包括 attention mask 和 token type ids
encoding_result = tokenizer.encode_plus(text, max_length=15, truncation=True, padding="max_length", add_special_tokens=True, return_tensors="pt")
print("Detailed encoding result:", encoding_result)

 打印出来的结果如下图所示:

在上面调用tokenizer的方法时,有直接调用encode,有调用encode_plus,还有直接初始化tokenizer对象,那么他们之间有什么区别么?

encode与encode_plus的区别

encode方法:该方法用于将输入文本编码转换为模型输入的整数序列(input IDs)。它只返回输入文本的编码结果。
使用场景: 适用于单一文本序列的编码,例如一个问题或一段文本。

encode_plus方法:该方法除了生成整数序列(input IDs)外,还会生成注意力掩码(attention mask)、段落标记(segment IDs)等其他有用的信息,通常用于训练和评估中。返回一个字典,包含编码后的各种信息。
使用场景: 适用于处理多个文本序列,例如一个问题和一个上下文文本。

encode_plus与直接调用tokenizer对象本质上无区别:在 Hugging Face Transformers 库中,直接调用 tokenizer 对象和调用 tokenizer.encode 方法的本质是相同的,都是为了将文本转换为模型可接受的输入标识符序列。这两种方式实际上等效,都是通过 tokenizer 对象的编码方法完成的。

数据处理的具体例子

 在数据预处理过程中,不同的数据类型预处理的步骤不同,以huggingface中的squad数据集和yelp_review_full数据集为例,squad是从上下文context中寻找question的答案。yelp_review_full数据集是对一系列评论以及评论的分数数据。squad用于训练问答系统模型,yelp_review_full用于训练文本分类、情感分类模型。

squad数据集

yelp_review_full数据集

 下面以yelp_review_full为例子,看看如何完成数据预处理与模型微调训练。下面的代码是加载yelp_review_full的数据完成模型的微调。在数据预处理部分,调用tokenizer对象,将truncation设置为true,以及设置了padding="max_length".没有复杂的预处理过程。

from transformers import AutoModelForSequenceClassification, AutoTokenizer, Trainer, TrainingArguments
from datasets import load_dataset, load_metric
import evaluate# 1. 加载YelpReviewFull数据集
dataset = load_dataset("yelp_review_full")# 2. 选择并加载BERT模型和标记器
model_name = "bert-base-uncased"
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=5)  # num_labels=5表示5种分类任务
tokenizer = AutoTokenizer.from_pretrained(model_name)# 3. 对原始数据进行标记化
def tokenize_function(examples):return tokenizer(examples["text"], padding="max_length", truncation=True)tokenized_datasets = dataset.map(tokenize_function, batched=True)# 4. 定义训练参数
training_args = TrainingArguments(output_dir="./yelp_review_model",  # 保存微调模型的目录per_device_train_batch_size=8,      # 每个设备的训练批次大小evaluation_strategy="steps",        # 在每个 steps 后进行评估eval_steps=500,                     # 每 500 个 steps 进行一次评估save_steps=500,                     # 每 500 个 steps 保存一次模型num_train_epochs=3,                 # 微调的轮数logging_dir="./logs"               # 保存训练日志的目录
)# 5. 定义compute_metrics函数计算准确度
metric = evaluate.load("accuracy")
def compute_metrics(p):preds = p.predictions.argmax(axis=1)return metric.compute(predictions=preds, references=p.label_ids)small_train_data=tokenized_datasets["train"].shuffle(seed=42).select(range(5000))
small_test_data=tokenized_datasets["test"].shuffle(seed=42).select(range(1000))
# 6. 定义Trainer对象
trainer = Trainer(model=model,args=training_args,train_dataset=small_train_data,eval_dataset=small_test_data,
#     train_dataset=tokenized_datasets["train"],
#     eval_dataset=tokenized_datasets["test"],compute_metrics=compute_metrics,   # 使用定义的compute_metrics函数
)# 7. 微调BERT模型
trainer.train()# 8. 输出评估结果
results = trainer.evaluate()
print("Results:", results)

因为只选取了部分数据进行训练,正确率是0.632.训练结果如下图所示:

对于用于训练问答系统模型的squad数据,预处理步骤会多一些,所以会在下一篇博客中做专门的介绍。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/360735.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

npp夜光数据介绍 viirs_科研成果快报第177期:中国地区长时序AVHRR气溶胶数据的主要问题: 气溶胶反演频次与重污染天气...

中国地区长时序AVHRR气溶胶数据的主要问题:气溶胶反演频次与重污染天气A critical view of long-term AVHRR aerosol data record in China: Retrieval frequency and heavy pollution成果信息Minghui Tao, Rong Li, Lili Wang et al. (2020)A critical view of lon…

使用Eclipse创建一个Android程序方法

要编写Android程序,需要安装JDK、Eclipse和Android SDK。 Android SDK的安装路径不要在program file或program file(x86)下,否则在debug时会碰很奇怪的问题。最好直接放在C:\Android下。(如果非要放在Program files下也可以,在ecl…

如何使用Hibernate批处理DELETE语句

介绍 在我以前的文章中 ,我解释了批处理INSERT和UPDATE语句所需的Hibernate配置。 这篇文章将继续本主题的DELETE语句批处理。 领域模型实体 我们将从以下实体模型开始: Post实体与Comment具有一对多关联,并且与PostDetails实体具有一对一…

蓝点linux_新闻速读 gt; Windows 10 的 Linux 内核将像驱动程序一样由微软更新服务进行更新 | Linux 中国...

本文字数:3252,阅读时长大约:4 分钟导读:• Ubuntu 发行商 Canonical 将参加微软欧洲虚拟开源峰会 • 树莓派支持 Vulkan 最新进展:通过 70000 项测试 • 谷歌浏览器开始隐藏 URL 详细路径,未来地址栏将只显…

struts2-通配符和动态方法调用

通配符举例--BookAction 1 public class BookAction extends ActionSupport {2 3 public String execute() throws Exception {4 System.out.println("BookAction ********** execute()");5 return null;6 }7 /*8 * 显示图书添加页…

JavaFX技巧18:路径剪切

我最近注意到,我致力于ControlsFX项目的PopOver控件无法正确剪切其内容。 当我为FlexCalendarFX框架开发手风琴弹出窗口时,这一点变得显而易见。 每当最后一个标题窗格扩展时,其底角不再是圆角而是正方形。 在将红色矩形作为内容放置到标题窗…

关于erlang的套接字编程

套接字编程即熟悉的Socket编程,根据传输层协议,可分为:UDP协议和TCP协议.下面写一个简单的例子,再重新认识下它: 1.在同一主机节点下启动两个Erlang节点. a).在第一个Erlang节点下,打开端口为1234的UDP套接…

kotlin 添加第一个 集合_Flutter开发必学Dart语法篇之集合操作符函数与源码分析...

简述:在上一篇文章中,我们全面地分析了常用集合的使用以及集合部分源码的分析。那么这一节讲点更实用的内容,绝对可以提高你的Flutter开发效率的函数,那就是集合中常用的操作符函数。这次说的内容的比较简单就是怎么用,以及源码内…

在Java中确定文件类型

以编程方式确定文件的类型可能非常棘手,并且已经提出并实现了许多基于内容的文件标识方法。 Java中有几种可用于检测文件类型的实现,其中大多数很大程度上或完全基于文件的扩展名。 这篇文章介绍了Java中最常见的文件类型检测实现。 本文介绍了几种在Ja…

程序员编程艺术第十一章:最长公共子序列(LCS)问题

程序员编程艺术第十一章:最长公共子序列(LCS)问题 0、前言 程序员编程艺术系列重新开始创作了(前十章,请参考程序员编程艺术第一~十章集锦与总结)。回顾之前的前十章,有些代码是值得商榷的,因当时的代码只顾…

gateway 过滤器执行顺序_Gateway网关源码解析—路由(1.1)之RouteDefinitionLocator一览...

一、概述本文主要对 路由定义定位器 RouteDefinitionLocator 做整体的认识。在 《Spring-Cloud-Gateway 源码解析 —— 网关初始化》 中,我们看到路由相关的组件 RouteDefinitionLocator / RouteLocator 的初始化。涉及到的类比较多,我们用下图重新梳理下…

ERP开发中应用字符串解析实现界面翻译智能化

ERP中要实现界面多语言的功能,则要对各种情况的字符串进行处理并作出翻译。有些字符串的翻译是有规律可行的,遵循相应的模板模式,解析字符串,可以实现机器翻译的效果。 请看帐套数据库表的设计ADCOMP CREATE TABLE dbo.ADCOMP(REC…

参数详解 复制进程_如何优化PostgreSQL逻辑复制

How to Optimize PostgreSQL Logical Replication逻辑复制( Logical Replication )或 Pglogical 是表级别的复制。两者都是基于 WAL 的复制机制,允许在两个实例之间复制指定表的WAL 。这两个看起来让人迷惑,到底有什么区别呢? Logical Replic…

Android Studio使用说明

声明: 本博客文章原创类别的均为个人原创,版权所有。转载请注明出处: http://blog.csdn.net/ml3947,另外本人的个人博客:http://www.wjfxgame.com。 凌晨的Google I/O大会上,宣布了Android Studio,引起了现场开发者的一片欢呼。那么&#x…

有些窗口底部被任务栏挡住了_开始使用 Tint2 吧,一款 Linux 中的开源任务栏

Tint2 是我们在开源工具系列中的第 14 个工具,它将在 2019 年提高你的工作效率,能在任何窗口管理器中提供一致的用户体验。-- Kevin Sonney每年年初似乎都有疯狂的冲动想提高工作效率。新年的决心,渴望开启新的一年,当然&#xff…

从jHiccup开始

写完“如何在生产中检测和诊断慢速代码”一文后,我受到读者的鼓励,尝试从Azul系统尝试jHiccup 。 去年,我参加了jHiccup的创建者Gil Tene的演讲,探讨了测量延迟的正确方法,其中,他向我们介绍了jHiccup。 它…

华为内部面试题库---(6)

1.在SMP体系结构中,中断亲和性是指将一个或者多个中断绑定到特定CPU core上运行,下列说法错误的是:A.每个硬件设备都会在/proc/irq下有个中断号命令的目录来标志中断亲和性B.IRQ#目录下smp_affinity文件,通过设置CPU位掩码&#x…

基元需要走吗?

我目前正在使用JSF作为视图技术,使用JPA作为持久层的企业应用程序。 它可能是支持bean或服务方法中的某种东西,但令我震惊:是否有充分的理由在企业应用程序中使用原语? 当我开始围绕J2SE 1.2使用Java进行编程(或者是J…

输入参数_太实用!输入参数1秒算出功率,这款计算工具又快又准

随着互联网红利的不断加深,到了后期,不断地各种工具开始涌现,方便了很多用户,填补了市场上的很多空白,有生活娱乐类、提高效率类、垂直专业类、系统工具类等等。工业行业作为各行各业的大头,机械化、智能化…

如何编写NetBeans插件

是否想在NetBeans IDE中添加功能或自动执行某些操作? 跟随我们编写您的第一个NetBeans插件。 让我们超越简单的工具栏示例 ,创建一个可以自动更新的插件。 该代码基于NetBeans的WakaTime插件 。 我们的示例插件将仅打印Hello World语句,并在…