Fine-tuning and testing a custom model with Hugging Face: multi-task BERT

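Background:

I needed to turn BERT into a multi-task model, but the official BertForSequenceClassification only supports single-task heads such as binary and multi-class classification, not multi-task learning. Switching to multi-task means changing the output layer, the loss, and the evaluation. The same approach also works if you want to append extra layers (for example a fully connected layer) to the end of BERT.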
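The post does not show the evaluation change, so the following is only a minimal sketch of what per-task evaluation could look like, assuming the model returns one list of logits per task and the labels are stored as one row per example with one column per task. The names `argmax` and `per_task_accuracy`, and all the numbers, are made up for illustration.

```python
from typing import List, Sequence


def argmax(scores: Sequence[float]) -> int:
    """Index of the largest score (the first one wins on ties)."""
    return max(range(len(scores)), key=lambda i: scores[i])


def per_task_accuracy(
    task_logits: List[List[List[float]]],  # indexed as [task][example][class]
    labels: List[List[int]],               # indexed as [example][task]
) -> List[float]:
    """Accuracy computed separately for every task head."""
    accuracies = []
    for task_idx, logits_for_task in enumerate(task_logits):
        correct = sum(
            argmax(example_logits) == labels[example_idx][task_idx]
            for example_idx, example_logits in enumerate(logits_for_task)
        )
        accuracies.append(correct / len(logits_for_task))
    return accuracies


# Two examples and two tasks (3 classes and 2 classes); values are invented.
task_logits = [
    [[0.1, 2.0, -1.0], [1.5, 0.2, 0.3]],  # logits of task 0
    [[0.9, 0.1], [0.2, 0.8]],             # logits of task 1
]
labels = [[1, 0], [2, 1]]

print(per_task_accuracy(task_logits, labels))  # [0.5, 1.0]
```

Reporting one number per task, rather than a single pooled accuracy, keeps a weak task from being hidden by a strong one.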

Code

Modifying the model

Here BertForSequenceClassification is adapted for multi-task classification.

import torch
import torch.nn as nn
from typing import List, Optional, Tuple, Union
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from transformers import BertPreTrainedModel, BertModel
from transformers.modeling_outputs import SequenceClassifierOutput
from transformers.utils import (
    add_code_sample_docstrings,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
)

_CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION = "textattack/bert-base-uncased-yelp-polarity"
_CONFIG_FOR_DOC = "BertConfig"
_SEQ_CLASS_EXPECTED_OUTPUT = "'LABEL_1'"
_SEQ_CLASS_EXPECTED_LOSS = 0.01

BERT_START_DOCSTRING = r"""

    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
    and behavior.

    Parameters:
        config ([`BertConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

BERT_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `({0})`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*):
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
            1]`:

            - 0 corresponds to a *sentence A* token,
            - 1 corresponds to a *sentence B* token.

            [What are token type IDs?](../glossary#token-type-ids)
        position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.

            [What are position IDs?](../glossary#position-ids)
        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""


@add_start_docstrings(
    """
    Bert Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled
    output) e.g. for GLUE tasks.
    """,
    BERT_START_DOCSTRING,
)
class BertForSequenceClassification_Multitask(BertPreTrainedModel):
    def __init__(self, config, task_output_dims):
        super().__init__(config)
        # task_output_dims: number of classes of each task, e.g. [6, 63]
        self.task_output_dims = task_output_dims
        self.num_labels = config.num_labels
        self.config = config

        self.bert = BertModel(config)
        classifier_dropout = (
            config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
        )
        self.dropout = nn.Dropout(classifier_dropout)
        # One linear head per task. config.hidden_size replaces the hard-coded 768
        # of the original post (the two are equal for bert-base).
        self.classifiers = nn.ModuleList(
            [nn.Linear(config.hidden_size, output_dim) for output_dim in task_output_dims]
        )
        # Single head used by the non-multi-task branches of forward(). The original
        # post called self.classifier there without ever defining it.
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION,
        output_type=SequenceClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
        expected_output=_SEQ_CLASS_EXPECTED_OUTPUT,
        expected_loss=_SEQ_CLASS_EXPECTED_LOSS,
    )
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy). For
            `multi_task_classification` the shape is `(batch_size, num_tasks)`, with column `i` holding the class
            index of task `i`.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        pooled_output = outputs[1]
        pooled_output = self.dropout(pooled_output)

        if self.config.problem_type == "multi_task_classification":
            # One logits tensor per task, each of shape (batch_size, task_output_dims[i])
            logits = [classifier(pooled_output) for classifier in self.classifiers]
        else:
            logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            # "multi_task_classification" cannot be inferred from the label dtype (the original
            # `labels.dtype == list` test was never true), so it has to be set explicitly
            # through problem_type when the model is created.
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)
            elif self.config.problem_type == "multi_task_classification":
                # Cross-entropy per task, summed over tasks with equal weight
                loss_fct = CrossEntropyLoss()
                loss_list = [loss_fct(logits[i], labels[:, i]) for i in range(len(self.task_output_dims))]
                loss = torch.sum(torch.stack(loss_list))

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

# Creating the model
# The original call was
model = BertForSequenceClassification.from_pretrained(pretrained_model_name_or_path, num_labels=2, hidden_dropout_prob=dropout)
# It now becomes (pjwk_cates is the author's own label list and is not defined in this post)
model = BertForSequenceClassification_Multitask.from_pretrained(
    pretrained_model_name_or_path,
    num_labels=len(pjwk_cates),
    hidden_dropout_prob=dropout,
    task_output_dims=[6, 63],
    problem_type="multi_task_classification",
)
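To make the multi-task loss concrete, here is a small worked example in plain Python rather than PyTorch. It is a sketch under the assumption that each task uses `CrossEntropyLoss` with its default mean reduction and that the per-task losses are summed with equal weight, as in the `multi_task_classification` branch above. The helper names `cross_entropy` and `multi_task_loss` and the toy batch are hypothetical; only the task sizes 6 and 63 come from `task_output_dims=[6, 63]`.

```python
import math
from typing import List


def cross_entropy(logits: List[float], target: int) -> float:
    """Cross-entropy of one example: log-sum-exp of the logits minus the target logit."""
    peak = max(logits)  # subtracting the max keeps exp() numerically stable
    log_sum_exp = peak + math.log(sum(math.exp(x - peak) for x in logits))
    return log_sum_exp - logits[target]


def multi_task_loss(task_logits: List[List[List[float]]], labels: List[List[int]]) -> float:
    """Sum over tasks of the batch-mean cross-entropy.

    task_logits is indexed as [task][example][class]; labels as [example][task].
    """
    total = 0.0
    for task_idx, logits_for_task in enumerate(task_logits):
        losses = [
            cross_entropy(example_logits, labels[example_idx][task_idx])
            for example_idx, example_logits in enumerate(logits_for_task)
        ]
        total += sum(losses) / len(losses)
    return total


# A batch of two examples with all-zero logits, i.e. a model that has learned nothing yet.
task_logits = [
    [[0.0] * 6, [0.0] * 6],    # task 0: 6 classes
    [[0.0] * 63, [0.0] * 63],  # task 1: 63 classes
]
labels = [[3, 10], [0, 62]]

loss = multi_task_loss(task_logits, labels)
print(loss)  # equals ln(6) + ln(63), roughly 5.93
```

With uniform predictions each task contributes the log of its class count, so the 63-class task dominates the sum. If that imbalance matters in practice, weighting the per-task losses is one option, though the post itself uses a plain sum.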

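Loading the model for testing

At test time, loading the checkpoint fails because the stock transformers source does not accept problem_type = "multi_task_classification", so the value has to be added to the list of allowed problem types. The simplest way is to wait for the error and patch the file it points to. In my environment that is line 347 of /home/anaconda3/envs/bert/lib/python3.8/site-packages/transformers/configuration_utils.py.

# Add multi_task_classification to the allowed values
allowed_problem_types = ("regression", "single_label_classification", "multi_label_classification", "multi_task_classification")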
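If you would rather not edit files under site-packages, one possible alternative is to remove `problem_type` from the checkpoint's saved `config.json` before loading, and then pass `problem_type="multi_task_classification"` to `from_pretrained` again, as the training call does. This is only a sketch and rests on two assumptions: that the checkpoint directory contains a `config.json` written by `save_pretrained`, and that your transformers version validates `problem_type` only when it is read from that file. The helper name `strip_problem_type` is made up, and the demo runs on a throwaway file rather than a real checkpoint.

```python
import json
import os
import tempfile
from typing import Optional


def strip_problem_type(config_path: str) -> Optional[str]:
    """Remove `problem_type` from a saved config.json and return the removed value."""
    with open(config_path, encoding="utf-8") as f:
        config = json.load(f)
    removed = config.pop("problem_type", None)
    with open(config_path, "w", encoding="utf-8") as f:
        json.dump(config, f, indent=2)
    return removed


# Demo on a temporary file standing in for <checkpoint_dir>/config.json.
with tempfile.TemporaryDirectory() as checkpoint_dir:
    config_path = os.path.join(checkpoint_dir, "config.json")
    with open(config_path, "w", encoding="utf-8") as f:
        json.dump({"hidden_size": 768, "problem_type": "multi_task_classification"}, f)

    removed = strip_problem_type(config_path)
    with open(config_path, encoding="utf-8") as f:
        cleaned = json.load(f)

print(removed)  # multi_task_classification
print(cleaned)  # {'hidden_size': 768}
```

The trade-off is that the checkpoint no longer records its own problem type, so every loading script has to supply it. Patching the library, as described above, avoids that but must be redone after each reinstall or upgrade.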
