GT_BERT文本分类

目录

  • GT-BERT
  • 结束语
  • 代码实现
  • 整个项目源码(数据集模型)

GT-BERT

在为了使 BERT 模型能够得到广泛的应用,在保证模型分类准确率不降低的情况下,减少模型参数规模并降低时间复杂度,提出一种基于半监督生成对抗网络与 BERT 的文本分类模型 GT-BERT。模型的整体框架如图3所示。
在这里插入图片描述

首先,对BERT进行压缩,通过实验验证选择使用BERT-of-theseus方法进行压缩得到BERT-theseus模型。损失函数设定为文本分类常用的交叉熵损失:
在这里插入图片描述

其中,为训练集的第j个样本,是的标签,C和c表示标签集合和一个类标签。接着,在压缩之后,从SS-GANs角度扩展BERT-theseus模型进行微调。在预训练过的BERT-theseus模型中添加两个组件:(1)添加特定任务层;(2)添加SS-GANs层来实现半监督学习。本研究假定K类句子分类任务,给定输入句子s=(, ,…,),其中开头的为分类特殊标记“[CLS]”,结尾的为句子分隔特殊标记“[SEP]”,其余部分对输入句子进行切分后标记序列输入BERT模型后得到编码向量序列为=(,…,)。
将生成器G生成的假样本向量与真实无标注数据输入BERT-theseus中所提取的特征向量,分别输入至判别器D中,利用对抗训练来不断强化判别器D。与此同时,利用少量标注数据对判别器D进行分类训练,从而进一步提高模型整体质量。
其中,生成器G输出服从正态分布的“噪声”,采用CNN网络,将输出空间映射到样本空间,记作∈。 判别器D也为CNN网络,它在输入中接收向量∈,其中可以为真实标注或者未标注样本 ,也可以为生成器生成的假样本数据。在前向传播阶段,当样本为真实样本时,即=,判别器D会将样本分类在K类之中。当样本为假样本时,即=,判别器D会把样本相对应的分类于K+1类别中。在此阶段生成器G和判别器D的损失分别被记作和,训练过程中G和D通过相互博弈而优化损失。
在反向传播中,未标注样本只增加。标注的真实样本只会影响,在最后和都会受到G的影响,即当D找不出生成样本时,将会受到惩罚,反亦然。在更新D时,改变BERT-theseus的权重来进行微调。训练完成后,生成器G会被舍弃,同时保留完整的BERT-theseus模型与判别器D进行分类任务的预测。

结束语

该文提出了一种用于文本分类任务的GT-BERT模型。首先,使用 theseus方法对BERT进行压缩,在不降低分类性能的前提下,有效降低了BERT 的参数规模和时间复杂度。然后,引人SS-GAN框架改进模型的训练方式,使 BERT-theseus模型能有效利用无标注数据,并实验了多组生成器与判别器的组合方式,获取了最优的生成器判别器组合配置,进一步提升了模型的分类性能。

代码实现

import torch
from transformers import BertTokenizer, BertModel
from torch.utils.data import DataLoader, Dataset
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
import torch.nn as nn
import torch.optim as optim
import os
from glob import globtorch.autograd.set_detect_anomaly(True)# 定义数据集类
class TextDataset(Dataset):def __init__(self, texts, labels, tokenizer, max_len):self.texts = textsself.labels = labelsself.tokenizer = tokenizerself.max_len = max_lendef __len__(self):return len(self.texts)def __getitem__(self, idx):text = self.texts[idx]label = self.labels[idx]encoding = self.tokenizer.encode_plus(text,add_special_tokens=True,max_length=self.max_len,return_token_type_ids=False,padding='max_length',truncation=True,return_attention_mask=True,return_tensors='pt',)return {'text': text,'input_ids': encoding['input_ids'].flatten(),'attention_mask': encoding['attention_mask'].flatten(),'label': torch.tensor(label, dtype=torch.long)}# 加载数据集函数
def load_data(dataset_name):if dataset_name == '20ng':dirs = glob("E:/python_project/GT_BERT/dateset/20_newsgroups/20_newsgroups/*")texts = []labels = []for i, d in enumerate(dirs):for j in glob(d + "/*")[:10]:try:with open(j, "r", encoding="utf-8") as f:one = f.read()except:continuetexts.append(one)labels.append(i)elif dataset_name == 'sst5':data_dir = 'path/to/sst/data'def load_sst_data(data_dir, split):sentences = []labels = []with open(os.path.join(data_dir, f'{split}.txt')) as f:for line in f:label, sentence = line.strip().split(' ', 1)sentences.append(sentence)labels.append(int(label))return sentences, labelstexts, labels = load_sst_data(data_dir, 'train')elif dataset_name == 'mr':file_path = 'path/to/mr/data'def load_mr_data(file_path):sentences = []labels = []with open(file_path) as f:for line in f:label, sentence = line.strip().split(' ', 1)sentences.append(sentence)labels.append(int(label))return sentences, labelstexts, labels = load_mr_data(file_path)elif dataset_name == 'trec':file_path = 'path/to/trec/data'def load_trec_data(file_path):sentences = []labels = []with open(file_path) as f:for line in f:label, sentence = line.strip().split(' ', 1)sentences.append(sentence)labels.append(label)return sentences, labelstexts, labels = load_trec_data(file_path)else:raise ValueError("Unsupported dataset")return texts, labels# 默认加载 20 News Group 数据集
dataset_name = '20ng'
texts, labels = load_data(dataset_name)label_encoder = LabelEncoder()
labels = label_encoder.fit_transform(labels)# 使用BERT的tokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
max_len = 128# 将数据集划分为训练集和验证集
train_texts, val_texts, train_labels, val_labels = train_test_split(texts, labels, test_size=0.2)
train_dataset = TextDataset(train_texts, train_labels, tokenizer, max_len)
val_dataset = TextDataset(val_texts, val_labels, tokenizer, max_len)
train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=16, shuffle=False)# 定义BERT编码器
class BERTTextEncoder(nn.Module):def __init__(self):super(BERTTextEncoder, self).__init__()self.bert = BertModel.from_pretrained('bert-base-uncased')def forward(self, input_ids, attention_mask):outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)pooled_output = outputs[1]return pooled_output# 定义生成器
class Generator(nn.Module):def __init__(self, noise_dim, output_dim):super(Generator, self).__init__()self.fc = nn.Sequential(nn.Linear(noise_dim, 128),nn.ReLU(),nn.Linear(128, output_dim),nn.Tanh())def forward(self, noise):return self.fc(noise)# 定义判别器
class Discriminator(nn.Module):def __init__(self, input_dim):super(Discriminator, self).__init__()self.fc = nn.Sequential(nn.Linear(input_dim, 128),nn.ReLU(),nn.Linear(128, 1),nn.Sigmoid())def forward(self, features):return self.fc(features)# 定义完整的GT-BERT模型
class GTBERTModel(nn.Module):def __init__(self, bert_encoder, noise_dim, output_dim, num_classes):super(GTBERTModel, self).__init__()self.bert_encoder = bert_encoderself.generator = Generator(noise_dim, output_dim)self.discriminator = Discriminator(output_dim)self.classifier = nn.Linear(output_dim, num_classes)def forward(self, input_ids, attention_mask, noise):real_features = self.bert_encoder(input_ids, attention_mask)fake_features = self.generator(noise)disc_real = self.discriminator(real_features)disc_fake = self.discriminator(fake_features)class_output = self.classifier(real_features)return class_output, disc_real, disc_fake# 初始化模型和超参数
noise_dim = 100
output_dim = 768
num_classes = len(set(labels))
bert_encoder = BERTTextEncoder()
model = GTBERTModel(bert_encoder, noise_dim, output_dim, num_classes)# 定义损失函数和优化器
criterion_class = nn.CrossEntropyLoss()
criterion_disc = nn.BCELoss()
optimizer_G = optim.Adam(model.generator.parameters(), lr=0.0002)
optimizer_D = optim.Adam(model.discriminator.parameters(), lr=0.0002)
optimizer_BERT = optim.Adam(model.bert_encoder.parameters(), lr=2e-5)
optimizer_classifier = optim.Adam(model.classifier.parameters(), lr=2e-5)num_epochs = 10# 训练循环
e_id = 1
for epoch in range(num_epochs):model.train()for batch in train_dataloader:e_id += 1input_ids = batch['input_ids']attention_mask = batch['attention_mask']labels = batch['label']# 生成噪声noise = torch.randn(input_ids.size(0), noise_dim)# 获取模型输出class_output, disc_real, disc_fake = model(input_ids, attention_mask, noise)# 计算损失real_labels = torch.ones(input_ids.size(0), 1)fake_labels = torch.zeros(input_ids.size(0), 1)loss_real = criterion_disc(disc_real, real_labels)loss_fake = criterion_disc(disc_fake, fake_labels)loss_class = criterion_class(class_output, labels)if e_id % 5 == 0:# 优化判别器optimizer_D.zero_grad()loss_D = (loss_real + loss_fake) / 2loss_D.backward(retain_graph=True)optimizer_D.step()elif e_id % 2 == 0:# 优化生成器loss_G = criterion_disc(disc_fake, real_labels)optimizer_G.zero_grad()loss_G.backward(retain_graph=True)optimizer_G.step()else:# 优化BERT和分类器optimizer_BERT.zero_grad()optimizer_classifier.zero_grad()loss_class.backward()optimizer_BERT.step()optimizer_classifier.step()print(f'Epoch [{epoch + 1}/{num_epochs}], Loss D: {loss_D.item()}, Loss G: {loss_G.item()}, Loss Class: {loss_class.item()}')# 验证模型
model.eval()
val_loss = 0
correct = 0
with torch.no_grad():for batch in val_dataloader:input_ids = batch['input_ids']attention_mask = batch['attention_mask']labels = batch['label']noise = torch.randn(input_ids.size(0), noise_dim)class_output, disc_real, disc_fake = model(input_ids, attention_mask, noise)loss = criterion_class(class_output, labels)val_loss += loss.item()pred = class_output.argmax(dim=1, keepdim=True)correct += pred.eq(labels.view_as(pred)).sum().item()val_loss /= len(val_dataloader.dataset)
accuracy = correct / len(val_dataloader.dataset)
print(f'Validation Loss: {val_loss}, Accuracy: {accuracy}')

整个项目源码(数据集模型)

项目

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/30461.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【ajax基础04】form-serialize插件

目录 一:form-serialize插件 作用: 语法格式: 一:form-serialize插件 作用: 快速且大量的收集表单元素的值 例如上图对于多表单元素的情形,单靠通过”选择器获取节点.value”值的形式,获取…

使用 GCD 实现属性的多读单写

使用 Grand Central Dispatch (GCD) 实现多读单写的属性 首先需要确保在多线程环境下的线程安全性。可以使用 GCD 提供的读写锁机制 dispatch_rwlock_t 或者 dispatch_queue_t 来实现这个功能。 Swift版本的实现 怎样创建一个并发队列 ?// 使用 Swift 来实现的首…

.net 奇葩问题调试经历之1——在红外相机获取温度时异常

📢欢迎点赞 :👍 收藏 ⭐留言 📝 如有错误敬请指正,赐人玫瑰,手留余香!📢本文作者:由webmote 原创📢作者格言:新的征程,我们面对的不仅仅是技术还有人心,人心不可测,海水不可量,唯有技术,才是深沉黑夜中的一座闪烁的灯塔序言 我们在研发中,经常除了造产品…

吉时利Keithley2602B数字源表

吉时利Keithley2602B数字源表 2601B、2602B、2604B 系统 Sourcemeter SMU 仪器 2601B、2602B 和 2604B 系统 Sourcemeter SMU 仪器为 40W DC / 200W 脉冲 SMU,支持 10A 脉冲,3A 至 100fA 和 40V 至 100nV DC。它们将精密电源、实际电流源、6 位数字万用…

使用asyncua模块的call_method方法调用OPC UA的Server端方法报错:asyncio.exceptions.TimeoutError

使用asyncua模块的call_method方法调用OPC UA的Server端方法报错:asyncio.exceptions.TimeoutError 报错信息如下: Traceback (most recent call last): asyncio.run(main()) File “D:\miniconda3\envs\py31013\lib\asyncio\runners.py”, line 44, in…

反激开关电源整流桥选型及计算

整流桥的作用就是把输入交流电压整形成直流电压,把正弦波整成馒头波,由于整流管的单向导电 性,在输入电压瞬时值小于滤波电容上电压时整流桥,在这个时候是不导通的,使整流桥的电流变 成2-3ms左右的窄脉冲。为获得所需…

【数据结构】选择题

在数据结构中,从逻辑上可以把数据结构分为(线性结构和非线性结构) 当输入规模为n时,下列算法渐进复杂性中最低的是() 时间复杂度 某线性表采用顺序存储结构,每个元素占4个存储单元&#xf…

13.3 Go 性能优化

💝💝💝欢迎莅临我的博客,很高兴能够在这里和您见面!希望您在这里可以感受到一份轻松愉快的氛围,不仅可以获得有趣的内容和知识,也可以畅所欲言、分享您的想法和见解。 推荐:「stormsha的主页」…

【MAVEN学习 | 第1篇】Maven介绍与安装

文章目录 前言 一. Maven主要作用1.1 依赖管理1.2 项目构建 二. Maven安装和配置2.1 安装2.2 配置环境变量2.3 命令测试2.4 配置文件(1)依赖本地缓存位置(本地仓库位置)(2)配置国内阿里镜像(3&a…

WPS相同字体但是部分文字样式不一样解决办法

如下图,在使用wps编辑文档的时候发现有些电脑的文字字体很奇怪,但是把鼠标移到这个文字的位置,发现它和其他正常文字的字体是一样的,都是仿宋_GB2312 正常电脑的文字如下图所示 打开C:\Windows找到Fonts这个文件夹 把仿宋_GB2312这…

【启明智显产品介绍】工业级HMI芯片Model3芯片详解(二)图像显示

Model3芯片是一款集大容量存储、宽温操作范围及多功能接口于一身的MCU,配备了 2D 图像加速引擎和 PNG 解码/JPEG 编解码引擎,可以满足各类交互设计场景和多媒体互动需求,具有高可靠性、高安全性、高开放度的特点,可以面向于泛工业…

Stable Diffusion 3 大模型文生图实践

windows教程2024年最新Stable Diffusion本地化部署详细攻略,手把手教程(建议收藏!!)_stable diffusion 本地部署-CSDN博客 linux本地安装教程 1.前期准备工作 1)创建conda环境 conda create --name stable3 python3.10 2)下…

【UBEMX安装和使用】

UBEMX安装 1. UBEMX介绍2. 官网下载软件3. 安装步骤下载和关联的STM32Cube固件包 1. UBEMX介绍 STM32CubeMX是一种图形工具,通过分步过程可以非常轻松地配置STM32微控制器和微处理器,以及为Arm Cortex-M内核或面向Arm Cortex-A内核的特定Linux设备树生成…

Flutter调用本地web

前言: 在目前Flutter 环境中,使用在线 webview 是一种很常见的行为 而在 app 环境中,离线使用则更有必要 1.环境准备 将依赖导入 2.引入前端代码 前端代码有两种情况 一种是使用打包工具 build 而来的前端代码 另一种情况是直接使用 HTML 文件 …

YoloV8改进策略:Block篇|即插即用|StarNet,重写星操作,使用Block改进YoloV8(全网首发)

摘要 本文主要集中在介绍和分析一种新兴的学习范式——星操作(Star Operation),这是一种通过元素级乘法融合不同子空间特征的方法,通过元素级乘法(类似于“星”形符号的乘法操作)将不同子空间的特征进行融…

java:动态代理和cglib代理的简单例子

# 项目代码资源&#xff1a; 可能还在审核中&#xff0c;请等待。。。 https://download.csdn.net/download/chenhz2284/89457803 # 项目代码 【pom.xml】 <dependency><groupId>cglib</groupId><artifactId>cglib</artifactId><version&…

WGCLOUD的web ssh提示websocket服务连接已断开

这个问题一般是server主机没有开放端口9998&#xff0c;因为9998是web ssh的端口&#xff0c;需要开放 我们只要在防火墙&#xff0c;或者安全软件&#xff0c;把这个端口开放了就可以了

小白学-WEBGL

第一天&#xff1a; 1.canvas和webgl的区别 Canvas 和 WebGL 都是用于在网页上绘制图形的技术&#xff0c;它们通过浏览器提供的 API 使开发者能够创建丰富的视觉内容&#xff0c;但它们的工作原理和用途有所不同。 Canvas Canvas API 提供了一个通过 JavaScript 和 HTML <…

Xtuner微调

环境安装 studio-conda xtuner0.1.17 conda activate xtuner0.1.17 进入家目录 &#xff08;~的意思是 “当前用户的home路径”&#xff09; cd ~ 创建版本文件夹并进入&#xff0c;以跟随本教程 mkdir -p /root/xtuner0117 && cd /root/xtuner0117 拉取 0.1.17 的版…

Java IO模型BIO、NIO、AIO介绍

第一章 BIO、NIO、AIO课程介绍 1.1 课程说明 在java的软件设计开发中&#xff0c;通信架构是不可避免的&#xff0c;我们在进行不同系统或者不同进程之间的数据交互&#xff0c;或者在高并发下的通信场景下都需要用到网络通信相关的技术&#xff0c;对于一些经验丰富的程序员来…