huggingface NLP-微调一个预训练模型

微调一个预训练模型
在这里插入图片描述

1 预处理数据

1.1 处理数据
1.1.1 fine-tune
使用tokenizer后的token 进行训练

batch = tokenizer(sequences, padding=True, truncation=True, return_tensors="pt")# This is new
batch["labels"] = torch.tensor([1, 1])optimizer = AdamW(model.parameters())
loss = model(**batch).loss
loss.backward()
optimizer.step()

1.2 从模型中心(Hub)加载数据集
1.2.1 数据集
DatasetDict对象,其中包含训练集、验证集和测试集
。每一个集合都包含几个列(sentence1, sentence2, label, and idx)以及一个代表行数的变量,即每个集合中的行的个数
下载数据集并缓存到 ~/.cache/huggingface/datasets. 回想一下第2章,您可以通过设置HF_HOME环境变量来自定义缓存的文件夹。
1.3 预处理数据集
1.3.1 预处理数据集,我们需要将文本转换为模型能够理解的数字
1.3.2 类型标记ID(token_type_ids)的作用就是告诉模型输入的哪一部分是第一句,哪一部分是第二句
1.3.3 不一定具有类型标记ID(token_type_ids
1.3.4 将数据保存为数据集,我们将使用Dataset.map()
调用map时使用了batch =True,这样函数就可以同时应用到数据集的多个元素上,而不是分别应用到每个元素上。这将使我们的预处理快许多
1.3.5 省略padding参数
在标记的时候将所有样本填充到最大长度的效率不高
一个更好的做法:在构建批处理时填充样本更好,因为这样我们只需要填充到该批处理中的最大长度,而不是整个数据集的最大长度。当输入长度变化很大时,这可以节省大量时间和处理能力!
1.3.6 将所有示例填充到最长元素的长度——我们称之为动态填充
为了解决句子长度统一的问题,我们必须定义一个collate函数,该函数会将每个batch句子填充到正确的长度
transformer库通过DataCollatorWithPadding为我们提供了这样一个函数
It’s when you pad your inputs when the batch is created, to the maximum length of the sentences inside that batch.
1.3.7 数据集是以Apache Arrow文件存储在磁盘上
1.4 benefits
1.4.1 The results of the function are cached, so it won’t take any time if we re-execute the code.
1.4.2 It can apply multiprocessing to go faster than applying the function on each element of the dataset.
1.4.3 It does not load the whole dataset into memory, saving the results as soon as one element is processed.

2 使用 Trainer API 微调模型(非并行批量模式)
2.1 TrainingArguments 类
2.1.1 它将包含 Trainer用于训练和评估的所有超参数
2.1.2 可以只调整部分默认参数进行微调

training_args = TrainingArguments("test-trainer", evaluation_strategy="epoch")

2.2 Training
2.2.1 简单training

trainer = Trainer(model,training_args,train_dataset=tokenized_datasets["train"],eval_dataset=tokenized_datasets["validation"],data_collator=data_collator,tokenizer=tokenizer,
)
trainer.train()

2.3 评估
2.3.1 使用 Trainer.predict() 命令来使用我们的模型进行预测
predict() 的输出结果是具有三个字段的命名元组: predictions , label_ids , 和 metrics
metrics 字段将只包含传递的数据集的loss,以及一些运行时间(预测所需的总时间和平均时间)
要将我们的预测的可以与真正的标签进行比较,我们需要在第二个轴上取最大值的索引

preds = np.argmax(predictions.predictions, axis=-1)
def compute_metrics(eval_preds):metric = evaluate.load("glue", "mrpc")logits, labels = eval_predspredictions = np.argmax(logits, axis=-1)return metric.compute(predictions=predictions, references=labels)

如何使用compute_metrics()函数定义一个新的 Trainer
2.4 AutoModelForSequenceClassification
2.4.1 when we used AutoModelForSequenceClassification with bert-base-uncased, we got warnings when instantiating the model. The pretrained head is not used for the sequence classification task, so it’s discarded and a new head is instantiated(实例化) with random weights.

3 完整的训练,使用Accelerator和scheduler

3.1 训练前的数据准备
3.1.1 删除与模型不期望的值相对应的列(如sentence1和sentence2列)。
3.1.2 将列名label重命名为labels(因为模型期望参数是labels)。
3.1.3 设置数据集的格式,使其返回 PyTorch 张量而不是列表。
3.1.4 代码

tokenized_datasets = tokenized_datasets.remove_columns(["sentence1", "sentence2", "idx"])
tokenized_datasets = tokenized_datasets.rename_column("label", "labels")
tokenized_datasets.set_format("torch")
tokenized_datasets["train"].column_names

3.2 data loader
3.2.1 from torch.utils.data import DataLoader

train_dataloader = DataLoader(tokenized_datasets["train"], shuffle=True, batch_size=8, collate_fn=data_collator
)
eval_dataloader = DataLoader(tokenized_datasets["validation"], batch_size=8, collate_fn=data_collator
)

3.3 优化器和学习率调度器
3.3.1 optimizer = AdamW(model.parameters(), lr=5e-5)

lr_scheduler = get_scheduler("linear",optimizer=optimizer,num_warmup_steps=0,num_training_steps=num_training_steps,
)

3.4 训练循环
3.4.1

num_epochs = 3
num_training_steps = num_epochs * len(train_dataloader)

3.5 使用accelerator
3.5.1

train_dl, eval_dl, model, optimizer = accelerator.prepare(train_dataloader, eval_dataloader, model, optimizer
)num_epochs = 3
num_training_steps = num_epochs * len(train_dl)

3.5.2 分布式
accelerate config
accelerate launch train.py
3.5.3 With 🤗Accelerate, your training loops will work for multiple GPUs and TPUs.

4 注意train只会对参数有调整

4.1 超参数(Hyperparameter),是机器学习算法中的调优参数,用于控制模型的学习过程和结构。 与模型参数(Model Parameter)不同,模型参数是在训练过程中通过数据学习得到的,而超参数是在训练之前由开发者或实践者直接设定的,并且在训练过程中保持不变。
4.2 需要自己设定,不是机器自己找出来的,称为超参数(hyperparameter)。
4.2.1 需要人工设置: 超参数的值不是通过训练过程自动学习得到的,而是需要训练者根据经验或实验来设定。
4.2.2 影响模型性能: 超参数的选择会直接影响模型的训练过程和最终性能。
4.2.3 需要优化: 为了获得更好的模型性能,通常需要对超参数进行优化,选择最优的超参数组合。
4.3 validated 数据集作用
4.3.1 验证集,用于挑选超参数的数据子集。
4.3.2

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/64237.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Unity NTPComponent应用, 实现一个无后端高效获取网络时间的组件

无后端高效获取网络时间的组件 废话不多说,直接上源码m_NowSerivceTime 一个基于你发行游戏地区的时间偏移, 比如北京时区就是 8, 巴西就是-3,美国就是-5using Newtonsoft.Json; 如果这里报错, 就说明项目没有 NewtonsoftJson插件…

华为ensp中nat server 公网访问内网服务器

作者主页:点击! ENSP专栏:点击! 创作时间:2024年4月15日17点30分 💯趣站推荐💯 前些天发现了一个巨牛的🤖人工智能学习网站,通俗易懂,风趣幽默,…

R语言学习笔记-1

1. 基础操作和函数 清空环境:rm(list ls()) 用于清空当前的R环境。 打印输出:print("Hello, world") 用于输出文本到控制台。 查看已安装包和加载包: search():查看当前加载的包。install.packages("package_na…

二、FIFO缓存

FIFO缓存 1.FIFO缓存介绍2.FIFO缓存实现3.FIFO缓存总结 1.FIFO缓存介绍 FIFO(First-In-First-Out)缓存 是一种简单的缓存淘汰策略,它基于先进先出的原则来管理数据。当缓存达到容量限制并需要淘汰元素时,最先进入缓存的元素会被移…

Spring Cloud与Spring Cloud Alibaba:全面解析与核心要点

Spring Cloud与Spring Cloud Alibaba:全面解析与核心要点 一、引言 在当今的分布式系统开发领域,Spring Cloud和Spring Cloud Alibaba都是极为重要的框架。它们为构建大规模、高可用、分布式的应用系统提供了丰富的工具和组件。本文将深入探讨Spring C…

账号下的用户列表表格分析

好的,这是您提供的 el-table 组件中所有列的字段信息,以表格形式展示: 列标题 (label)字段属性 (prop)对齐方式 (align)宽度 (width)是否可排序 (sortable)说明IDidcenter100否管理员的唯一标识符头像avatarcenter90否管理员的头像 URL 或路…

GPT-SoVITS语音合成模型部署及使用

1、概述 GPT-SoVITS是一款开源的语音合成模型,结合了深度学习和声学技术,能够实现高质量的语音生成。其独特之处在于支持使用参考音频进行零样本语音合成,即使没有直接的训练数据,模型仍能生成相似风格的语音。用户可以通过微调模…

TongWe7.0-东方通TongWeb控制台无法访问 排查

问题描述:无法访问TongWeb的控制台 逐项排查: 1、控制台访问地址是否正确:http://IP:9060/console #IP是服务器的实际IP地址 2、确认TongWeb进程是否存在,执行命令:ps -ef|grep tongweb 3、确认TongWeb服务启动…

研发文档管理系统:国内外9大选择比较

文章主要对比了9款国内外研发文档管理系统:1.PingCode; 2. Worktile; 3. 飞书; 4. 石墨文档; 5. 腾讯文档; 6. 蓝湖; 7. Confluence; 8. Notion; 9. Slab。 在企业研发过…

【ABAP SAP】开发-BUG修补记录_采购申请打印时品名规格品牌为空

项目场景: TCODE:自开发程序ZMMF004 采购申请打印 问题描述 ZMMF004打印的时候,有的采购申请的品名、规格、品牌为空 原因分析: 1、首先我通过写SQL语句查底表来看这几条采购申请本身有无品名、规格、品牌 SQL语句如下,只需修…

Ubuntu 20.04 24.04 双网卡 Bond 配置指南

前言:在现代服务器管理中,网络的稳定性和可靠性至关重要。为了提高网络的冗余性和负载能力,我们经常需要配置多个网络接口以实现链路聚合或故障转移。Ubuntu系统自17.10版本起,引入了Netplan作为新的网络配置抽象化工具&#xff0…

OCR实践—PaddleOCR

有个项目需求,对拍摄的问卷图片,进行自动得分统计【得分是在相应的分数下面打对号】,输出到excel文件 原始问卷文件见下图,真实的图片因使用手机拍摄的图片,存在一定的畸变, 技术调研 传统方法 传统方法…

ubuntu+ros新手笔记(五):初探anaconda+cuda+pytorch

深度学习三件套:初探anacondacudapytorch 系统ubuntu22.04 ros2 humble 1.初探anaconda 1.1 安装 安装过程参照【详细】Ubuntu 下安装 Anaconda 1.2 创建和删除环境 创建新环境 conda create -n your_env_name pythonx.x比如我创建了一个名为“py312“的环境…

【测试】Pytest

建议关注、收藏! 目录 功能pytest 自动化测试工具。 功能 单元测试:用于验证代码的最小功能单元(如函数、方法)的正确性。 简单的语法:不需要继承特定类或使用复杂的结构。断言语句简化。 自动发现测试:P…

Unity性能优化---使用SpriteAtlas创建图集进行批次优化

在日常游戏开发中,UI是不可缺少的模块,而在UI中又使用着大量的图片,特别是2D游戏还有很多精灵图片存在,如果不加以处理,会导致很高的Batches,影响性能。 比如如下的例子: Batches是9&#xff0…

环境和工程搭建

1.案例介绍 1.1 需求 实现⼀个电商平台 该如何实现呢? 如果把这些功能全部写在⼀个服务⾥, 这个服务将是巨⼤的. 巨多的会员, 巨⼤的流量, 微服务架构是最好的选择. 微服务应⽤开发的第⼀步, 就是服务拆分. 拆分后才能进⾏"各⾃开发" 1.2 服务拆分 拆分原则 …

解决Jmeter HTTP Cookie管理器cookie不生效

解决Jmeter HTTP Cookie管理器cookie不生效问题 解决Jmeter HTTP Cookie管理器cookie不生效问题1、设置Jmeter HTTP Cookie管理器cookie后,发起的请求显示[no cookies]jmeter问题复现:这里同样使用postman进行重试,发现是可以正常获取数据的&…

操作系统课后习题2.2节

操作系统课后习题2.2节 第1题 CPU的效率指的是CPU的执行速度,这个是由CPU的设计和它的硬件来决定的,具体的调度算法是不能提高CPU的效率的; 第3题 互斥性: 指的是进程之间的同步互斥关系,进程是一个动态的过程&#…

二叉搜索树 平衡树(c嘎嘎版)

定义: 二叉搜索树是一种二叉树的树形数据结构,其定义如下: 空树是二叉搜索树。 若二叉搜索树的左子树不为空,则其左子树上所有点的附加权值均小于其根节点的值。 若二叉搜索树的右子树不为空,则其右子树上所有点的附加权值均大于其根节点的值。 二叉搜索树的左右子树均为…

Everything搜索实现

最近编写NTFS文件实时搜索工具, 类似 Everything 这样, 速度快还小巧, 于是花了约3周进行研究, 总结下学习过程中一些经验 实现分3部分完成 一. 解析NTFS 主文件表(MFT) 这一步是获取文件数据的唯一迅速且可靠的来源 NTFS_MFT_Parse.h #pragma once #include "NTFS_Bas…