一起学Hugging Face Transformers(13)- 模型微调之自定义训练循环

文章目录

  • 前言
  • 一、什么是训练循环
    • 1. 训练循环的关键步骤
    • 2. 示例
    • 3. 训练循环的重要性
  • 二、使用 Hugging Face Transformers 库实现自定义训练循环
    • 1. 前期准备
      • 1)安装依赖
      • 2)导入必要的库
    • 2. 加载数据和模型
      • 1) 加载数据集
      • 2) 加载预训练模型和分词器
      • 3) 预处理数据
      • 4) 创建数据加载器
    • 3. 自定义训练循环
      • 1) 定义优化器和学习率调度器
      • 2) 定义训练和评估函数
      • 3) 运行训练和评估
  • 总结


前言

Hugging Face Transformers 库为 NLP 模型的预训练和微调提供了丰富的工具和简便的方法。虽然 Trainer API 简化了许多常见任务,但有时我们需要更多的控制权和灵活性,这时可以实现自定义训练循环。本文将介绍什么是训练循环以及如何使用 Hugging Face Transformers 库实现自定义训练循环。


一、什么是训练循环

在模型微调过程中,训练循环是指模型训练的核心过程,通过多次迭代数据集来调整模型的参数,使其在特定任务上表现更好。训练循环包含以下几个关键步骤:

1. 训练循环的关键步骤

1) 前向传播(Forward Pass)

  • 模型接收输入数据并通过网络进行计算,生成预测输出。这一步是将输入数据通过模型的各层逐步传递,计算出最终的预测结果。

2) 计算损失(Compute Loss)

  • 将模型的预测输出与真实标签进行比较,计算损失函数的值。损失函数是一个衡量预测结果与真实值之间差距的指标,常用的损失函数有交叉熵损失(用于分类任务)和均方误差(用于回归任务)。

3) 反向传播(Backward Pass)

  • 根据损失函数的值,计算每个参数对损失的贡献,得到梯度。反向传播使用链式法则,将损失对每个参数的梯度计算出来。

4) 参数更新(Parameter Update)

  • 使用优化算法(如梯度下降、Adam 等)根据计算出的梯度调整模型的参数。优化算法会更新每个参数,使损失函数的值逐步减小,模型的预测性能逐步提高。

5) 重复以上步骤

  • 以上过程在整个数据集上进行多次(多个epoch),每次遍历数据集被称为一个epoch。随着训练的进行,模型的性能会不断提升。

2. 示例

假设你在微调一个BERT模型用于情感分析任务,训练循环的步骤如下:

1) 前向传播

  • 输入一条文本评论,模型通过各层网络计算,生成预测的情感标签(如正面或负面)。

2) 计算损失

  • 将模型的预测标签与实际标签进行比较,计算交叉熵损失。

3) 反向传播

  • 计算损失对每个模型参数的梯度,确定每个参数需要调整的方向和幅度。

4) 参数更新

  • 使用Adam优化器,根据计算出的梯度调整模型的参数。

5) 重复以上步骤

  • 在整个训练数据集上进行多次迭代,不断调整参数,使模型的预测精度逐步提高。

3. 训练循环的重要性

训练循环是模型微调的核心,通过多次迭代和参数更新,使模型能够从数据中学习,逐步提高在特定任务上的性能。理解训练循环的各个步骤和原理,有助于更好地调试和优化模型,获得更好的结果。

在实际应用中,训练循环可能会包含一些额外的步骤和技术,例如:

  • 批量训练(Mini-Batch Training):将数据集分成小批量,每次训练一个批量,降低计算资源的需求。
  • 学习率调度(Learning Rate Scheduling):动态调整学习率,以提高训练效率和模型性能。
  • 正则化技术(Regularization Techniques):如Dropout、权重衰减等,防止模型过拟合。

这些技术和方法结合使用,可以进一步提升模型微调的效果和性能。

二、使用 Hugging Face Transformers 库实现自定义训练循环

1. 前期准备

1)安装依赖

首先,确保已经安装了必要的库:

pip install transformers datasets torch

2)导入必要的库

import torch
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, AutoModelForSequenceClassification, get_scheduler
from datasets import load_dataset
from tqdm.auto import tqdm

2. 加载数据和模型

1) 加载数据集

这里我们以 IMDb 电影评论数据集为例:

dataset = load_dataset("imdb")

2) 加载预训练模型和分词器

我们将使用 distilbert-base-uncased 作为基础模型:

model_name = "distilbert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2)

3) 预处理数据

定义一个预处理函数,并将其应用到数据集:

def preprocess_function(examples):return tokenizer(examples["text"], truncation=True, padding="max_length", max_length=512)encoded_dataset = dataset.map(preprocess_function, batched=True)
encoded_dataset = encoded_dataset.rename_column("label", "labels")
encoded_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])

4) 创建数据加载器

train_dataloader = DataLoader(encoded_dataset["train"], batch_size=8, shuffle=True)
eval_dataloader = DataLoader(encoded_dataset["test"], batch_size=8)

3. 自定义训练循环

1) 定义优化器和学习率调度器

optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
num_epochs = 3
num_training_steps = num_epochs * len(train_dataloader)
lr_scheduler = get_scheduler(name="linear", optimizer=optimizer, num_warmup_steps=0, num_training_steps=num_training_steps
)

2) 定义训练和评估函数

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model.to(device)def train_loop():model.train()for batch in tqdm(train_dataloader):batch = {k: v.to(device) for k, v in batch.items()}outputs = model(**batch)loss = outputs.lossloss.backward()optimizer.step()lr_scheduler.step()optimizer.zero_grad()def eval_loop():model.eval()total_loss = 0correct_predictions = 0with torch.no_grad():for batch in tqdm(eval_dataloader):batch = {k: v.to(device) for k, v in batch.items()}outputs = model(**batch)loss = outputs.losslogits = outputs.logitstotal_loss += loss.item()predictions = torch.argmax(logits, dim=-1)correct_predictions += (predictions == batch["labels"]).sum().item()avg_loss = total_loss / len(eval_dataloader)accuracy = correct_predictions / len(eval_dataloader.dataset)return avg_loss, accuracy

3) 运行训练和评估

for epoch in range(num_epochs):print(f"Epoch {epoch + 1}/{num_epochs}")train_loop()avg_loss, accuracy = eval_loop()print(f"Validation Loss: {avg_loss:.4f}, Accuracy: {accuracy:.4f}")

总结

通过上述步骤,我们实现了使用 Hugging Face Transformers 库的自定义训练循环。这种方法提供了更大的灵活性,可以根据具体需求调整训练过程。无论是优化器、学习率调度器,还是其他训练策略,都可以根据需要进行定制。希望这篇文章能帮助你更好地理解和实现自定义训练循环,为你的 NLP 项目提供更强大的支持。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/41666.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

玉石风能否接棒黏土风?一探AI绘画新风尚

在数字艺术的浪潮中,AI绘画平台以其独特的创造力和便捷性,正在逐步改变我们对艺术的传统认知。从黏土风的温暖质感到琉璃玉石的细腻光泽,每一次风格的转变都引领着新的潮流。今天,我们将聚焦玉石风,探讨它是否能成为下一个流行的艺术滤镜,并提供一种在线体验的方式,让你…

Python | Leetcode Python题解之第221题最大正方形

题目: 题解: class Solution:def maximalSquare(self, matrix: List[List[str]]) -> int:if len(matrix) 0 or len(matrix[0]) 0:return 0maxSide 0rows, columns len(matrix), len(matrix[0])dp [[0] * columns for _ in range(rows)]for i in…

使用Python实现深度学习模型:模型监控与性能优化

在深度学习模型的实际应用中,模型的性能监控与优化是确保其稳定性和高效性的关键步骤。本文将介绍如何使用Python实现深度学习模型的监控与性能优化,涵盖数据准备、模型训练、监控工具和优化策略等内容。 目录 引言模型监控概述性能优化概述实现步骤数据准备模型训练模型监控…

梧桐数据库:语法分析模块概述

语法分析模块是数据库系统的重要组成部分,它负责将用户输入的 SQL 语句转换为内部表示形式,以便后续的处理和执行。在数据库系统中,语法分析模块是连接用户与数据库的桥梁。它的主要任务是将用户输入的 SQL 语句进行解析,检查语法…

Kafka(一)基础介绍

一,Kafka集群 一个典型的 Kafka 体系架构包括若Producer、Broker、Consumer,以及一个ZooKeeper集群,如图所示。 ZooKeeper:Kafka负责集群元数据的管理、控制器的选举等操作的; Producer:将消息发送到Broker…

随着云计算和容器技术的广泛应用,如何在这些环境中有效地运用 Shell 进行自动化部署和管理?

在云计算和容器技术的环境中,Shell 脚本可以被用于自动化部署和管理任务。下面是一些在这些环境中有效使用 Shell 进行自动化部署和管理的方法: 在云环境中,使用云服务提供商的 API 进行自动化管理。例如,使用命令行工具或 SDK 来…

14 - Python网络应用开发

网络应用开发 发送电子邮件 在即时通信软件如此发达的今天,电子邮件仍然是互联网上使用最为广泛的应用之一,公司向应聘者发出录用通知、网站向用户发送一个激活账号的链接、银行向客户推广它们的理财产品等几乎都是通过电子邮件来完成的,而…

[AI 大模型] OpenAI ChatGPT

文章目录 ChatGPT 简介ChatGPT 的模型架构ChatGPT的发展历史节点爆发元年AI伦理和安全 ChatGPT 新技术1. 技术进步2. 应用领域3. 代码示例4. 对话示例 ChatGPT 简介 ChatGPT 是由 OpenAI 开发的一个大型语言模型,基于GPT-4架构。它能够理解和生成自然语言文本&…

学习笔记——动态路由——OSPF(特殊区域)

十、OSPF特殊区域 1、技术背景 早期路由器靠CPU计算转发,由于硬件技术限制问题,因此资源不是特别充足,因此是要节省资源使用,规划是非常必要的。 OSPF路由器需要同时维护域内路由、域间路由、外部路由信息数据库。当网络规模不…

电脑会议录音转文字工具哪个好?5个转文字工具简化工作流程

在如今忙碌的生活中,我们常常需要记录和回顾重要的对话和讨论。手写笔记可能跟不上速度,而录音则以其便捷性成为了捕捉信息的有力工具。但录音文件的后续处理,往往让人头疼不已。想象一下,如果能够瞬间将这些声音转化为文字&#…

spring-16

Spring 对 DAO 的支持 Spring 对 DAO 的支持是通过 Spring 框架的 JDBC 模块实现的,它提供了一系列的工具和类来简化数据访问对象(DAO)的开发和管理。 首先,我们需要在 Spring 配置文件中配置数据源和事务管理器: &l…

Java笔试|面试 —— 子类对象实例化全过程 (熟悉)

子类对象实例化全过程 (熟悉) (1)从结果的角度来看:体现为继承性 当创建子类对象后,子类对象就获取了其父类中声明的所有的属性和方法,在权限允许的情况下,可以直接调用。 (2)从过…

iptables实现端口转发ssh

iptables实现端口转发 实现使用防火墙9898端口访问内网front主机的22端口(ssh连接) 1. 防火墙配置(lb01) # 配置iptables # 这条命令的作用是将所有目的地为192.168.100.155且目标端口为19898的TCP数据包的目标IP地址改为10.0.0.148,并将目标…

Java策略模式在动态数据验证中的应用

在软件开发中,数据验证是一项至关重要的任务,它确保了数据的完整性和准确性,为后续的业务逻辑处理奠定了坚实的基础。然而,不同的数据来源往往需要不同的验证规则,如何在不破坏代码的整洁性和可维护性的同时&#xff0…

无向图中寻找指定路径:深度优先遍历算法

刷题记录 1. 节点依赖 背景: 类似于无向图中, 寻找从 起始节点 --> 目标节点 的 线路. 需求: 现在需要从 起始节点 A, 找到所有到 终点 H 的所有路径 A – B : 路径由一个对象构成 public class NodeAssociation {private String leftNodeName;private Stri…

数据编码的艺术:sklearn中的数据转换秘籍

数据编码的艺术:sklearn中的数据转换秘籍 在机器学习中,数据预处理是一个至关重要的步骤,它直接影响到模型的性能和结果的准确性。数据编码转换是数据预处理的一部分,它涉及将原始数据转换成适合模型训练的格式。scikit-learn&am…

Python 爬虫 tiktok关键词搜索用户数据信息 api接口

Tiktok APP API接口 Python 爬虫采集Tiktok数据 采集结果页面如下图: https://www.tiktok.com/search?qwwe&t1706679918408 请求API http://api.xxx.com/tt/search/user?keywordwwe&count10&offset0&tokentest 请求参数 返回示例 联系我们&…

178 折线图-柱形图-饼状图

1.折线图 1、QChart 类继承自 QGraphicsWidget,用于管理图表、图例和轴。2、QValueAxis 类专门用来自定义图表中 X 和 Y 坐标轴。3、QLineSeries 类专门用于折线图(曲线)的形式展示数据 //.pro QT core gui charts#ifndef WIDGET_H #defi…

探索邻近奥秘:SKlearn中K-近邻(KNN)算法的应用

探索邻近奥秘:SKlearn中K-近邻(KNN)算法的应用 在机器学习的世界里,K-近邻(K-Nearest Neighbors,简称KNN)算法以其简单直观而著称。KNN是一种基本的分类和回归方法,它的工作原理非常…

Error in onLoad hook: “SyntaxError: Unexpected token u in JSON at position 0“

1.接收页面报错 Error in onLoad hook: "SyntaxError: Unexpected token u in JSON at position 0" Unexpected token u in JSON at position 0 at JSON.parse (<anonymous>) 2.发送页面 &#xff0c;JSON.stringify(item) &#xff0c;将对象转换为 JSO…