HuggingFace模型头的自定义

 

在线工具推荐:  Three.js AI纹理开发包 -  YOLO合成数据生成器 -  GLTF/GLB在线编辑 -  3D模型格式在线转换 -  可编程3D场景编辑器

在本文中我们将介绍如何使HuggingFace的模型适应你的任务,在Pytorch中建立自定义模型头并将其连接到HF模型的主体,并端到端地训练系统。

1、HF模型头和模型体

这是典型的HF模型的样子:

为什么我需要单独使用模型头(Model Head)和模型体(Model Body)?

一些HF的模型针对下游任务(例如提问或文本分类)训练,并包含有关其权重培训的数据的知识。

有时,尤其是当我们手头的任务包含很少的数据或领域特定(例如医学或运动特定任务)时,我们可以在HUB上使用其他任务训练的模型(不一定与我们的任务相同的任务 手但属于相同领域,例如运动或药物),并利用一些验证的知识来提高我们模型在我们自己任务的性能表现。

  • 一个非常简单的例子是,如果说我们有一个小数据集,比如分类某些财务报表是积极还是负面的。 但是,我们进入了HF,发现许多模型已经经过与金融相关的问答数据集的训练,那么 我们可以使用这些模型的某些层来改进自己的任务。
  • 另一个简单的示例是,某个特定领域的模型经过巨大数据集的训练学会了将文本从中分为5个类别。 假设我们有类似的分类任务,在同一域中的一个完全不同的数据集,只想将数据分类为2个类别而不是5。 这时我们也可以复用模型主体,添加自己的模型头来增强我们自己任务的特定领域知识。

这就是我们要做的事情的示意图:

2、自定义HF模型头

我们的任务是简单的,从Kaggle上的这个数据集进行讽刺检测。

你可以在此处查看完整的代码。 为了时间的考虑,我没有在下面包括预处理和一些训练的详细信息,因此请确保查看整个代码的笔记本。

我将使用一个在大量推文上训练的模型,有5个分类输出不同的情感类型。我们将提取模型体,在pytorch中添加自定义层(2个标签,讽刺/不讽刺),并训练新的模型。

注意:你可以在此示例中使用任何模型(不一定是对分类训练的模型),因为我们只会使用该模型主体并拆除模型头。

这就是我们的工作流程:

我将跳过数据预处理步骤,然后直接跳到主类,但是你可以在本节开头的链接中查看整个代码。

3、令牌化和动态填充

使用如下代码将文本转化为令牌并进行动态填充:

checkpoint = "cardiffnlp/twitter-roberta-base-emotion"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
tokenizer.model_max_len=512def tokenize(batch):return tokenizer(batch["headline"], truncation=True,max_length=512)tokenized_dataset = data.map(tokenize, batched=True)
print(tokenized_dataset)tokenized_dataset.set_format("torch",columns=["input_ids", "attention_mask", "label"])
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

结果如下:

DatasetDict({train: Dataset({features: ['headline', 'label', 'input_ids', 'attention_mask'],num_rows: 22802})test: Dataset({features: ['headline', 'label', 'input_ids', 'attention_mask'],num_rows: 2851})valid: Dataset({features: ['headline', 'label', 'input_ids', 'attention_mask'],num_rows: 2850})
})

4、提取模型体并添加我们自己的层

代码如下:

class CustomModel(nn.Module):def __init__(self,checkpoint,num_labels): super(CustomModel,self).__init__() self.num_labels = num_labels #Load Model with given checkpoint and extract its bodyself.model = model = AutoModel.from_pretrained(checkpoint,config=AutoConfig.from_pretrained(checkpoint, output_attentions=True,output_hidden_states=True))self.dropout = nn.Dropout(0.1) self.classifier = nn.Linear(768,num_labels) # load and initialize weightsdef forward(self, input_ids=None, attention_mask=None,labels=None):#Extract outputs from the bodyoutputs = self.model(input_ids=input_ids, attention_mask=attention_mask)#Add custom layerssequence_output = self.dropout(outputs[0]) #outputs[0]=last hidden statelogits = self.classifier(sequence_output[:,0,:].view(-1,768)) # calculate lossesloss = Noneif labels is not None:loss_fct = nn.CrossEntropyLoss()loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))return TokenClassifierOutput(loss=loss, logits=logits, hidden_states=outputs.hidden_states,attentions=outputs.attentions)

如你所见,我们首先是继承Pytorch中的 nn.Module,使用AutoModel(来自transformers库)提取加载了指定检查点的模型主体。

请注意, forward() 方法返回 TokenClassifierOutput,从而确保我们输出的格式与HF预训练模型一致。

5、端到端训练新的模型

代码如下:

from tqdm.auto import tqdmprogress_bar_train = tqdm(range(num_training_steps))
progress_bar_eval = tqdm(range(num_epochs * len(eval_dataloader)))for epoch in range(num_epochs):model.train()for batch in train_dataloader:batch = {k: v.to(device) for k, v in batch.items()}outputs = model(**batch)loss = outputs.lossloss.backward()optimizer.step()lr_scheduler.step()optimizer.zero_grad()progress_bar_train.update(1)model.eval()for batch in eval_dataloader:batch = {k: v.to(device) for k, v in batch.items()}with torch.no_grad():outputs = model(**batch)logits = outputs.logitspredictions = torch.argmax(logits, dim=-1)metric.add_batch(predictions=predictions, references=batch["labels"])progress_bar_eval.update(1)print(metric.compute())model.eval()test_dataloader = DataLoader(tokenized_dataset["test"], batch_size=32, collate_fn=data_collator
)for batch in test_dataloader:batch = {k: v.to(device) for k, v in batch.items()}with torch.no_grad():outputs = model(**batch)logits = outputs.logitspredictions = torch.argmax(logits, dim=-1)metric.add_batch(predictions=predictions, references=batch["labels"])metric.compute()

结果如下:

  0%|          | 0/2139 [00:00<?, ?it/s]0%|          | 0/270 [00:00<?, ?it/s]
{'f1': 0.9335347432024169}
{'f1': 0.9360090874668686}
{'f1': 0.9274912756882513}

如你所见,我们使用此方法实现了不错的性能。 请记住,该博客的目的不是分析此特定数据集的性能,而是要学习如何使用预训练的身体并添加自定义头。

6、结束语

在本文中,我们看到了如何在HF预训练模型上添加自定义层。

一些收获:

  • 在我们拥有特定于域的数据集并希望利用在同一域(任务 - 努力的task-agnostic)上训练的模型以增强小型数据集中的性能的情况下,此技术特别有用。
  • 我们可以选择接受过与自己任务不同的下游任务训练的模型,并且仍然使用该模型主体的知识。
  • 如果你的数据集足够大且通用,那么这可能根本不需要,在这种情况下,你可以使用 AutoModeForSequenceCecrification或使用 BERT 解决的任何其他任务。 实际上,如果是这样,我强烈建议不要建立自己的模型头。

原文链接:HF自定义模型头 - BimAnt

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/139819.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

如何导出PPT画的图为高清图片?插入到world后不压缩图像的设置方法?

期刊投稿的时候&#xff0c;需要图片保持一定的清晰度数&#xff0c;那么我们怎么才能从PPT中导出符合要求的图片呢&#xff1f; 对于矢量图绘图软件所画的图&#xff0c;直接导出即可。 而PPT导出的图片清晰度在60pi&#xff0c;就很模糊。 整体思路&#xff1a; PPT绘图——…

吃透 Spring 系列—MVC部分

目录 ◆ SpringMVC简介 - SpringMVC概述 - SpringMVC快速入门 - Controller中访问容器中的Bean - SpringMVC关键组件浅析 ◆ SpringMVC的请求处理 - 请求映射路径的配置 - 请求数据的接收 - Javaweb常用对象获取 - 请求静态资源 - 注解驱动 标签 ◆ SpringMV…

Leetcode421. 数组中两个数的最大异或值

Every day a Leetcode 题目来源&#xff1a;421. 数组中两个数的最大异或值 解法1&#xff1a;贪心 位运算 初始化答案 ans 0。从最高位 high_bit 开始枚举 i&#xff0c;也就是 max⁡(nums) 的二进制长度减一。设 newAns ans 2i&#xff0c;看能否从数组 nums 中选两个…

【ATTCK】MITRE Caldera -前瞻规划器

CALDERA是一个由python语言编写的红蓝对抗工具&#xff08;攻击模拟工具&#xff09;。它是MITRE公司发起的一个研究项目&#xff0c;该工具的攻击流程是建立在ATT&CK攻击行为模型和知识库之上的&#xff0c;能够较真实地APT攻击行为模式。 通过CALDERA工具&#xff0c;安全…

深入了解JVM和垃圾回收算法

1.什么是JVM&#xff1f; JVM是Java虚拟机&#xff08;Java Virtual Machine&#xff09;的缩写&#xff0c;是Java程序运行的核心组件。JVM是一个虚拟的计算机&#xff0c;它提供了一个独立的运行环境&#xff0c;可以在不同的操作系统上运行Java程序。 2.如何判断可回收垃圾…

机器学习数据预处理——Word2Vec的使用

引言&#xff1a; Word2Vec 是一种强大的词向量表示方法&#xff0c;通常通过训练神经网络来学习词汇中的词语嵌入。它可以捕捉词语之间的语义关系&#xff0c;对于许多自然语言处理任务&#xff0c;包括情感分析&#xff0c;都表现出色。 代码&#xff1a; 重点代码&#…

C# PaddleInference.PP-HumanSeg 人像分割 替换背景色

效果 项目 VS2022.net4.8OpenCvSharp4Sdcb.PaddleInference 包含4个分割模型 modnet-hrnet_w18 modnet-mobilenetv2 ppmatting-hrnet_w18-human_512 ppmattingv2-stdc1-human_512 代码 using OpenCvSharp; using Sdcb.PaddleInference; using System; using System.Col…

酷开科技智能大屏OS Coolita亮相第134届中国进出口商品交易会

作为中国外贸的“风向标”和“晴雨表”&#xff0c;广交会因其历史长、规模大、商品种类全、到会客商多、成交效果好&#xff0c;被称为“中国第一展”&#xff0c;它见证了中国改革开放的时代大潮与对外贸易的蓬勃发展。 2023年10月15日&#xff0c;第134届中国进出口商品交易…

【Spring Cloud】声明性REST客户端:Feign

Spring Cloud Feign ——fallback 服务降级 1. Feign 简介2. Feign 的基础使用2.1 普通 HTTP 请求2.2 Feign 远程调用上传文件接口 1. Feign 简介 Feign 是一个声明式的 HTTP 客户端&#xff0c;它简化了编写基于 REST 的服务间通信代码的过程。在 Spring Cloud 中&#xff0c…

【论文阅读】PSDF Fusion:用于动态 3D 数据融合和场景重建的概率符号距离函数

【论文阅读】PSDF Fusion&#xff1a;用于动态 3D 数据融合和场景重建的概率符号距离函数 Abstract1 Introduction3 Overview3.1 Hybrid Data Structure3.2 3D Representations3.3 Pipeline 4 PSDF Fusion and Surface Reconstruction4.1 PSDF Fusion4.2 Inlier Ratio Evaluati…

AI爆文变现脚本:易用且免费的自动写作脚本更新了

之前给大家分享的AI爆文变现写作脚本 由于时间仓促&#xff0c;加上我对很多东西不熟悉 免费版本对新手小白来说&#xff0c;安装部署起来是非常的困难 于是这几天我加班加点把整个软件的部署简化 现在无需复杂的环境配置安装&#xff0c;下载配置下就可以使用了。 免费版…

[工业自动化-16]:西门子S7-15xxx编程 - 软件编程 - 西门子仿真软件PLCSIM

目录 前言&#xff1a; 一、PLCSIM仿真软件 1.1 PLCSIM仿真软件基础版&#xff08;内嵌&#xff09; 1.2 PLCSIM仿真软件与PLCSIM仿真软件高级版的区别&#xff1f; 1.3 PLCSIM使用 前言&#xff1a; PLC集成开发环境是运行在Host主机上&#xff0c;Host主机与PLC可以通过…

音视频基础知识

图像&#xff08;YUV RGB&#xff09; ​​​​​​​​​​​​​​这个讲的比较好 RGB颜色编码 图像显示主要是由像素组成&#xff0c;每个像素点的颜色组成都是采用RGB格式&#xff0c;RGB就是红、绿、蓝&#xff0c;RGB分别取不同的值&#xff0c;展示不同的颜色。 YUV…

二十五、W5100S/W5500+RP2040树莓派Pico<Modebus TCP Server示例>

文章目录 1 前言2 简介2 .1 什么是Modbus TCP&#xff1f;2.2 Modbus TCP指令介绍2.3 请求数据过程2.4 Modbus TCP协议优点2.5 Modbus TCP应用场景 3 WIZnet以太网芯片4 Modbus TCP示例概述以及使用4.1 流程图4.2 准备工作核心4.3 连接方式4.4 主要代码概述4.5 结果演示 5 注意…

阿里云OSS和腾讯云COS对象存储介绍和简单使用

对象存储指的是一种云存储服务&#xff0c;其主要是将数据以对象的形式存储在云端&#xff0c;并且提供了完全的API调用&#xff0c;这些API包括上传&#xff0c;下载&#xff0c;删除&#xff0c;复制&#xff0c;预览&#xff0c;权限设置等等。OSS对象存储和COS对象存储都是…

设计模式之十一:代理模式

代理可以控制和管理访问。 RMI提供了客户辅助对象和服务辅助对象&#xff0c;为客户辅助对象创建和服务对象相同的方法。RMI的好处在于你不必亲自写任何网络或I/O代码。客户程序调用远程方法就和运行在客户自己本地JVM对对象进行正常方法调用一样。 步骤一&#xff1a;制作远程…

js案例:打地鼠游戏(打灰太狼)

效果预览图 游戏规则 当灰太狼出现的时候鼠标左键点击灰太狼加10分&#xff0c;小灰灰出现的时候鼠标左键点小灰灰击减10分&#xff0c;不点击不减分不加分。 整体思路 1.把获取背景图片中每个地洞的位置&#xff0c;把所有位置放到一个数组中。 2.封装随机数函数&#xff0c;随…

腾讯域名优惠卷领取

腾讯域名到到期了&#xff0c;听说申请此计划&#xff0c;可获得优惠卷&#xff0c;看到网上5年域名只需要10元&#xff0c;姑且试试看。 我的博客即将同步至腾讯云开发者社区&#xff0c;邀请大家一同入驻&#xff1a;https://cloud.tencent.com/developer/support-plan?in…

【hacker送书第一期】嵌入式虚拟化技术与应用

第一期图书推荐 前言为什么嵌入式系统需要虚拟化技术&#xff1f;专家推荐本书适用群体内容简介目录权威作者团队参与方式 前言 随着物联网设备的爆炸式增长和万物互联应用的快速发展&#xff0c;虚拟化技术在嵌入式系统上受到了业界越来越多的关注、重视和实际应用。嵌入式系…

Win10专业版安装wsl-ubuntu子系统

文章目录 一、查看是否满足安装要求二、管理员权限启动 Windows PowerShell三、启用Windows10子系统功能四、启用虚拟机平台功能五、重启电脑六、下载 Linux 内核更新包&#xff08;适用于 x64 计算机的 WSL2 Linux 内核更新包&#xff09;七、将 WSL 2 设置为默认版本八、打开…