基于BERT的语义分析实现


✨✨ 欢迎大家来访Srlua的博文(づ ̄3 ̄)づ╭❤~✨✨

🌟🌟 欢迎各位亲爱的读者,感谢你们抽出宝贵的时间来阅读我的文章。

我是Srlua小谢,在这里我会分享我的知识和经验。🎥

希望在这里,我们能一起探索IT世界的奥妙,提升我们的技能。🔮

记得先点赞👍后阅读哦~ 👏👏

📘📚 所属专栏:传知代码论文复现

欢迎访问我的主页:Srlua小谢 获取更多信息和资源。✨✨🌙🌙

​​

​​

目录

概述

语义分类

文本分类

情感分类

实现原理

核心逻辑

pre_deal.py

train.py

test_demo.py

实现方式&演示效果

训练阶段

测试阶段


本文所有资源均可在该地址处获取。

概述

在之前的文章中,我们介绍了BERT模型。BERT作为一种预训练语言模型,它具有很好的兼容性,能够运用在各种下游任务中,本文的主要目的是利用数据集来对BERT进行训练,从而实现一个语义分类的模型。

语义分类

语义分类是自然语言处理任务中的一种,包含文本分类、情感分析

文本分类

文本分类是指给定文本a,将文本分类为n个类别中的一个或多个。常见的应用包括文本话题分类,情感分类,具体的分类方向有有二分类,多分类和多标签分类。
文本分类可以采用传统机器学习方法(贝叶斯,svm等)和深度学习方法(fastText,TextCNN等)实现。
举例而言,对于一个对话数据集,我们可以用1、2、3表示他们的话题,如家庭、学校、工作等,而文本分类的目的,则是把这些文本的话题划分到给定的三种类别中。

情感分类

情感分析是自然语言处理中常见的场景,比如商品评价等。通过情感分析,可以挖掘产品在各个维度的优劣。情感分类其实也是一种特殊的文本分类,只是他更聚焦于情感匹配词典。
举例而言,情感分类可以用0/1表示负面评价/正面评价,例子如下:

0,不好的,319房间有故臭味。要求换房说满了,我是3月去的。在路上认识了一个上海人,他说他退房前也住的319,也是一股臭味。而且这个去不掉,特别是晚上,很浓。不知道是厕所的还是窗外的。服务一般,门前有绿皮公交去莫高窟,不过敦煌宾馆也有,下次住敦煌宾馆。再也不住这个酒店了,热水要放半个小时才有。
1,不错的酒店,大堂和餐厅的环境都不错。但由于给我的是一间走廊尽头的房间,所以房型看上去有点奇怪。客厅和卧室是连在一起的,面积偏小。服务还算到位,总的来说,性价比还是不错的。

本文将以情感二分类为例,实现如何利用BERT进行语义分析。

实现原理

首先,基于BERT预训练模型,能将一个文本转换成向量,作为模型的输入。
在BERT预训练模型的基础上,新增一个全连接层,将输入的向量通过训练转化成一个tensor作为输出,其中这个tensor的维度则是需要分类的种类,具体的值表示每个种类的概率。例如:

[0.25,0.75] 

指代的是有0.25的概率属于第一类,有0.75的概率属于第二类,因此,理论输出结果是把该文本分为第二类。

核心逻辑

pre_deal.py

import csv
import random
from datasets import load_datasetdef read_file(file_path):csv_reader = csv.reader(open(file_path, encoding='UTF-8'))num = 0data = []for row in csv_reader:if num == 0:num = 1continuecomment_data = [row[1], int(row[0])]if len(comment_data[0]) > 500:text=comment_data[0]sub_texts, start, length = [], 0, len(text)while start < length:piecedata=[text[start: start + 500], comment_data[1]]data.append(piecedata)start += 500else:data.append(comment_data)random.shuffle(data)return data

对输入的csv文件进行处理,其中我们默认csv文件的格式是[label,text],将用于训练的内容读取出来,转化为numpy格式,其中,如果遇到有些文本过长(超过模型的输入),将其截断,分为多个文本段来输入。在最后,会通过shuffle函数进行打乱。

train.py

train.py定义了几个函数,用于训练。
首先是Bertmodel类,定义了基于Bert的训练模型:

class Bertmodel(nn.Module):def __init__(self, output_dim, model_path):super(Bertmodel, self).__init__()# 导入bert模型self.bert = BertModel.from_pretrained(model_path)# 外接全连接层self.layer1 = nn.Linear(768, output_dim)def forward(self, tokens):res = self.bert(**tokens)res = self.layer1(res[1])res = res.softmax(dim=1)return res

该模型由Bert和一个全连接层组成,最后经过softmax激活函数。
其次是一个评估函数,用来计算模型结果的准确性

def evaluate(net, comments_data, labels_data, device, tokenizer):ans = 0 # 输出结果i = 0step = 8 # 每轮一次读取多少条数据tot = len(comments_data)while i <= tot:print(i)comments = comments_data[i: min(i + step, tot)]tokens_X = tokenizer(comments, padding=True, truncation=True, return_tensors='pt').to(device=device)res = net(tokens_X)  # 获得到预测结果y = torch.tensor(labels_data[i: min(i + step, tot)]).reshape(-1).to(device=device)ans += (res.argmax(axis=1) == y).sum()i += stepreturn ans / tot

原理就是,将文本转化为tokens,输入给模型,而后利用返回的结果,计算准确性
下面展示了开始训练的主函数,在训练的过程中,进行后向传播,储存checkpoints模型

def training(net, tokenizer, loss, optimizer, train_comments, train_labels, test_comments, test_labels,device, epochs):max_acc = 0.5  # 初始化模型最大精度为0.5for epoch in tqdm(range(epochs)):step = 8i, sum_loss = 0, 0tot=len(train_comments)while i < tot:comments = train_comments[i: min(i + step, tot)]tokens_X = tokenizer(comments, padding=True, truncation=True, return_tensors='pt').to(device=device)res = net(tokens_X)y = torch.tensor(train_labels[i: min(i + step, len(train_comments))]).reshape(-1).to(device=device)optimizer.zero_grad()  # 清空梯度l = loss(res, y)  # 计算损失l.backward()  # 后向传播optimizer.step()  # 更新梯度sum_loss += l.detach()  # 累加损失i += steptrain_acc = evaluate(net, train_comments, train_labels)test_acc = evaluate(net, test_comments, test_labels)print('\n--epoch', epoch + 1, '\t--loss:', sum_loss / (len(train_comments) / 8), '\t--train_acc:', train_acc,'\t--test_acc', test_acc)# 保存模型参数,并重设最大值if test_acc > max_acc:# 更新历史最大精确度max_acc = test_acc# 保存模型max_acc = test_acctorch.save({'epoch': epoch,'state_dict': net.state_dict(),'optimizer': optimizer.state_dict()}, 'model/checkpoint_net.pth')

训练结果表示如下:

--epoch 0 	--train_acc: tensor(0.6525, device='cuda:1') 	--test_acc tensor(0.6572, device='cuda:1')0%|          | 0/20 [00:00<?, ?it/s]5%|▌         | 1/20 [01:48<34:28, 108.88s/it]10%|█         | 2/20 [03:38<32:43, 109.10s/it]15%|█▌        | 3/20 [05:27<30:56, 109.20s/it]20%|██        | 4/20 [07:15<29:02, 108.93s/it]25%|██▌       | 5/20 [09:06<27:23, 109.58s/it]30%|███       | 6/20 [10:55<25:29, 109.26s/it]35%|███▌      | 7/20 [12:44<23:40, 109.28s/it]40%|████      | 8/20 [14:33<21:51, 109.29s/it]45%|████▌     | 9/20 [16:23<20:04, 109.49s/it]50%|█████     | 10/20 [18:13<18:15, 109.59s/it]55%|█████▌    | 11/20 [20:03<16:27, 109.72s/it]60%|██████    | 12/20 [21:52<14:35, 109.45s/it]65%|██████▌   | 13/20 [23:41<12:45, 109.35s/it]70%|███████   | 14/20 [25:30<10:54, 109.14s/it]75%|███████▌  | 15/20 [27:19<09:05, 109.03s/it]80%|████████  | 16/20 [29:07<07:15, 108.84s/it]85%|████████▌ | 17/20 [30:56<05:26, 108.86s/it]90%|█████████ | 18/20 [32:44<03:37, 108.75s/it]95%|█████████▌| 19/20 [34:33<01:48, 108.73s/it]
100%|██████████| 20/20 [36:22<00:00, 108.71s/it]
100%|██████████| 20/20 [36:22<00:00, 109.11s/it]--epoch 1 	--loss: tensor(1.2426, device='cuda:1') 	--train_acc: tensor(0.6759, device='cuda:1') 	--test_acc tensor(0.6789, device='cuda:1')--epoch 2 	--loss: tensor(1.0588, device='cuda:1') 	--train_acc: tensor(0.8800, device='cuda:1') 	--test_acc tensor(0.8708, device='cuda:1')--epoch 3 	--loss: tensor(0.8543, device='cuda:1') 	--train_acc: tensor(0.8988, device='cuda:1') 	--test_acc tensor(0.8887, device='cuda:1')--epoch 4 	--loss: tensor(0.8208, device='cuda:1') 	--train_acc: tensor(0.9111, device='cuda:1') 	--test_acc tensor(0.8990, device='cuda:1')--epoch 5 	--loss: tensor(0.8024, device='cuda:1') 	--train_acc: tensor(0.9206, device='cuda:1') 	--test_acc tensor(0.9028, device='cuda:1')--epoch 6 	--loss: tensor(0.7882, device='cuda:1') 	--train_acc: tensor(0.9227, device='cuda:1') 	--test_acc tensor(0.9024, device='cuda:1')--epoch 7 	--loss: tensor(0.7749, device='cuda:1') 	--train_acc: tensor(0.9288, device='cuda:1') 	--test_acc tensor(0.9036, device='cuda:1')--epoch 8 	--loss: tensor(0.7632, device='cuda:1') 	--train_acc: tensor(0.9352, device='cuda:1') 	--test_acc tensor(0.9061, device='cuda:1')--epoch 9 	--loss: tensor(0.7524, device='cuda:1') 	--train_acc: tensor(0.9421, device='cuda:1') 	--test_acc tensor(0.9090, device='cuda:1')--epoch 10 	--loss: tensor(0.7445, device='cuda:1') 	--train_acc: tensor(0.9443, device='cuda:1') 	--test_acc tensor(0.9103, device='cuda:1')--epoch 11 	--loss: tensor(0.7397, device='cuda:1') 	--train_acc: tensor(0.9480, device='cuda:1') 	--test_acc tensor(0.9128, device='cuda:1')--epoch 12 	--loss: tensor(0.7321, device='cuda:1') 	--train_acc: tensor(0.9505, device='cuda:1') 	--test_acc tensor(0.9123, device='cuda:1')--epoch 13 	--loss: tensor(0.7272, device='cuda:1') 	--train_acc: tensor(0.9533, device='cuda:1') 	--test_acc tensor(0.9140, device='cuda:1')--epoch 14 	--loss: tensor(0.7256, device='cuda:1') 	--train_acc: tensor(0.9532, device='cuda:1') 	--test_acc tensor(0.9111, device='cuda:1')--epoch 15 	--loss: tensor(0.7186, device='cuda:1') 	--train_acc: tensor(0.9573, device='cuda:1') 	--test_acc tensor(0.9123, device='cuda:1')--epoch 16 	--loss: tensor(0.7135, device='cuda:1') 	--train_acc: tensor(0.9592, device='cuda:1') 	--test_acc tensor(0.9136, device='cuda:1')--epoch 17 	--loss: tensor(0.7103, device='cuda:1') 	--train_acc: tensor(0.9601, device='cuda:1') 	--test_acc tensor(0.9128, device='cuda:1')--epoch 18 	--loss: tensor(0.7091, device='cuda:1') 	--train_acc: tensor(0.9590, device='cuda:1') 	--test_acc tensor(0.9086, device='cuda:1')--epoch 19 	--loss: tensor(0.7084, device='cuda:1') 	--train_acc: tensor(0.9626, device='cuda:1') 	--test_acc tensor(0.9123, device='cuda:1')--epoch 20 	--loss: tensor(0.7038, device='cuda:1') 	--train_acc: tensor(0.9628, device='cuda:1') 	--test_acc tensor(0.9107, device='cuda:1')

最终训练结果,在训练集上达到了96.28%的准确率,在测试集上达到了91.07%的准确率

test_demo.py

这个函数提供了一个调用我们储存的checkpoint模型来进行预测的方式,将input转化为berttokens,而后输入给模型,返回输出结果。

input_text=['这里环境很好,风光美丽,下次还会再来的。']
Bert_model_path = 'xxxx'
output_path='xxxx'
device = torch.device('cpu')
checkpoint = torch.load(output_path,map_location='cpu')model = Bertmodel(output_dim=2,model_path=Bert_model_path)
model.load_state_dict(checkpoint,False)
# print(model)
tokenizer = BertTokenizer.from_pretrained(Bert_model_path,model_max_length=512)tokens_X = tokenizer(input_text, padding=True, truncation=True, return_tensors='pt').to(device='cpu')
model.eval()
output=model(tokens_X)
print(output)
out = torch.unsqueeze(output.argmax(dim=1), dim=1)
result = out.numpy()
print(result)
if result[0][0]==1:print("positive")
else:print("negative")

实现方式&演示效果

训练阶段

首先找到能够拿来训练的数据,运行pre_deal.py进行预处理,而后可以在main.py修改模型的相关参数,运行main.py开始训练。
这个过程,可能会收到硬件条件的影响,推荐使用cuda进行训练。如果实在训练不了,可以直接调用附件中对应的训练好的模型来进行预测。

测试阶段

运行test_demo.py,测试输入文本的分类结果
输入

input_text=['这里环境很好,风光美丽,下次还会再来的。']

输出

tensor([[0.3191, 0.6809]], grad_fn=<SoftmaxBackward0>)
[[1]]
positive

得出,这句话的情感分类是positive(正面)

​​

希望对你有帮助!加油!

若您认为本文内容有益,请不吝赐予赞同并订阅,以便持续接收有价值的信息。衷心感谢您的关注和支持!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/62396.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

HarmonyOS NEXT应用开发,关于useNormalizedOHMUrl选项的坑

起因是这样的&#xff1a;我这库打包发布出问题了&#xff0c;这个有遇到的吗&#xff1f; 源码里面就没有 request .d.ts,这打包后哪来个这文件&#xff1f;且漏掉了其他文件。 猫哥csdn.yyz_1987 为啥我打包的har里面&#xff0c;只有接口&#xff0c;没有具体实现呢&#x…

单点登录原理

允许跨域–>单点登录。 例如https://www.jd.com/ 同一个浏览器下&#xff1a;通过登录页面产生的cookie里的一个随机字符串的标识&#xff0c;在其他子域名下访问共享cookie获取标识进行单点登录&#xff0c;如果没有该标识则返回登录页进行登录。 在hosts文件下面做的域名…

基于Java的小程序电商商城开源设计源码

近年来电商模式的发展越来越成熟&#xff0c;基于 Java 开发的小程序电商商城开源源码&#xff0c;为众多开发者和企业提供了构建个性化电商平台的有力工具。 基于Java的电子商城购物平台小程序的设计在手机上运行&#xff0c;可以实现管理员&#xff1b;首页、个人中心、用户…

Linux查看网络基础命令

文章目录 Linux网络基础命令1. ifconfig 和 ip一、ifconfig命令二、ip命令 2. ss命令一、基本用法二、常用选项三、输出信息四、使用示例 3. sar 命令一、使用sar查看网络使用情况 4. ping 命令一、基本用法二、常用选项三、输出结果四、使用示例 Linux网络基础命令 1. ifconf…

SpringMVC工作原理【流程图+文字详解SpringMVC工作原理】

SpringMVC工作原理 前端控制器&#xff1a;DispactherServlet处理器映射器&#xff1a;HandlerMapping处理器适配器&#xff1a;HandlerAdapter处理器&#xff1a;Handler&#xff0c;视图解析器&#xff1a;ViewResolver视图&#xff1a;View 首先用户通过浏览器发起HTTP请求…

12寸先进封装设备之-晶圆减薄一体机

晶圆减薄一体机在先进封装厂中的主要作用是对已完成功能的晶圆(主要是硅晶片)的背面基体材料进行磨削,去掉一定厚度的材料,以满足后续封装工艺的要求以及芯片的物理强度、散热性和尺寸要求。随着3D封装技术的发展,晶圆厚度需要减薄至50-100μm甚至更薄,以实现更好的散热效…

Online Judge——【前端项目初始化】项目通用布局开发及初始化

目录 一、新建layouts二、更新App.vue文件三、选择一个布局&#xff08;Layout&#xff09;四、通用菜单Menu的实现菜单路由改为读取路由文件 五、绑定跳转事件六、同步路由到菜单项 一、新建layouts 这里新建一个专门存放布局的布局文件layouts&#xff1a; 然后在该文件夹&…

十四(AJAX)、AJAX、axios、常用请求方法(GET POST...)、HTTP协议、接口文档、form-serialize

1. AJAX介绍及axios基本使用 <!DOCTYPE html> <html lang"zh-CN"><head><meta charset"UTF-8" /><meta http-equiv"X-UA-Compatible" content"IEedge" /><meta name"viewport" content&q…

53 基于单片机的8路抢答器加记分

目录 一、主要功能 二、硬件资源 三、程序编程 四、实现现象 一、主要功能 首先有三个按键 分别为开始 暂停 复位&#xff0c;然后八个选手按键&#xff0c;开机显示四条杠&#xff0c;然后按一号选手按键&#xff0c;数码管显示&#xff13;&#xff10;&#xff0c;这…

【深度学习】各种卷积—卷积、反卷积、空洞卷积、可分离卷积、分组卷积

在全连接神经网络中&#xff0c;每个神经元都和上一层的所有神经元彼此连接&#xff0c;这会导致网络的参数量非常大&#xff0c;难以实现复杂数据的处理。为了改善这种情况&#xff0c;卷积神经网络应运而生。 一、卷积 在信号处理中&#xff0c;卷积被定义为一个函数经过翻转…

前端页面或弹窗在线预览文件的N种方式

需求&#xff1a;后端返回给前端一个地址后&#xff0c;在前端页面上或则在弹框中显示在线的文档、表格、图片、pdf、video等等&#xff0c;嵌入到前端页面 方式一&#xff1a; 使用vue-office 地址&#xff1a;vue-office简介 | vue-office 个人感觉这个插件是最好用的&#x…

Windsurf可以上传图片开发UI了

背景 曾经羡慕Cursor的“画图”开发功能&#xff0c;这不Windsurf安排上了。 Upload Images to Cascade Cascade now supports uploading images on premium models Ask Cascade to build or tweak UI from on image upload New keybindings Keybindings to navigate betwe…

ArraList和LinkedList区别

文章目录 一、结构不同二、访问速度三、插入和删除操作的不同1、决定效率有两个因素&#xff1a;数据量和位置。2、普遍说法是“LinkedList添加删除快”&#xff0c;这里是有前提条件的 四、内存占用情况五、使用场景六、总结 一、结构不同 LinkedList&#xff1a;它基于双向链…

【模型剪枝】YOLOv8 模型剪枝实战 | 稀疏化-剪枝-微调

文章目录 0. 前言1. 模型剪枝概念2. 模型剪枝实操2.1 稀疏化训练2.2 模型剪枝2.3 模型微调总结0. 前言 无奈之下,我还是写了【模型剪枝】教程🤦‍♂️。回想当年,在写《YOLOv5/v7进阶实战专栏》 时,我经历了许多挫折,才最终完成了【模型剪枝】和【模型蒸馏】的内容。当时…

关于函数式接口和编程的解析和案例实战

文章目录 匿名内部类“匿名”在哪里 函数式编程lambda表达式的条件Supplier使用示例 ConsumeracceptandThen使用场景 FunctionalBiFunctionalTriFunctional 匿名内部类 匿名内部类的学习和使用是实现lambda表达式和函数式编程的基础。是想一下&#xff0c;我们在使用接口中的方…

学习笔记:黑马程序员JavaWeb开发教程(2024.11.29)

10.5 案例-部门管理-新增 如何接收来自前端的数据: 接收到json数据之后&#xff0c;利用RequestBody注解&#xff0c;将前端响应回来的json格式的数据封装到实体类中 对代码中Controller层的优化 发现路径中都有/depts&#xff0c;可以将每个方法对应请求路径中的…

数据库管理-第268期 srvctl在ADG备库添加PDB的service报错,看如何解决(20241129)

数据库管理268期 2024-11-29 数据库管理-第268期 srvctl在ADG备库添加PDB的service报错&#xff0c;看如何解决&#xff08;20241129&#xff09;1 背景2 处理过程3 原因总结 数据库管理-第268期 srvctl在ADG备库添加PDB的service报错&#xff0c;看如何解决&#xff08;202411…

brew安装mongodb和php-mongodb扩展新手教程

1、首先保证macos下成功安装了Homebrew&#xff0c; 在终端输入如下命令&#xff1a; brew search mongodb 搜索是不是有mongodb资源&#xff0c; 演示效果如下&#xff1a; 2、下面来介绍Brew 安装 MongoDB&#xff0c;代码如下&#xff1a; brew tap mongodb/brew brew in…

国产FPGA+DSP 双FMC 6U VPX处理板

高性能国产化信号处理平台采用6U VPX架构&#xff0c;双FMC接口国产V7 FPGA 国产多核 DSP 的硬件架构&#xff0c;可以完成一体化电子系统、有源相控阵雷达、电子侦察、MIMO 通信、声呐等领域的高速实时信号处理。 信号处理平台的组成框图如图 1 所示&#xff0c; DSP处理器采…

AI数据分析工具(二)

豆包-免费 优点 强大的数据处理能力&#xff1a; 豆包能够与Excel无缝集成&#xff0c;支持多种数据类型的导入&#xff0c;包括文本、数字、日期等&#xff0c;使得数据整理和分析变得更加便捷。豆包提供了丰富的数据处理功能&#xff0c;如数据去重、填充缺失值、转换格式等…