基于飞浆NLP的BERT-finetuning新闻文本分类

目录

1.数据预处理

2.加载模型

3.批训练

4.准确率

1.数据预处理

导入所需库

import numpy as np
from paddle.io import DataLoader,TensorDataset
from paddlenlp.transformers import BertForSequenceClassification, BertTokenizer
from sklearn.model_selection import train_test_split
import paddle
import matplotlib.pyplot as plt
import jieba

训练集格式 标签ID+\t+标签+\t+原文标题

contents=[]
datas=[]
labels=[]
with open('data/data126283/data/Train.txt',mode='r',encoding='utf-8') as f:contents=f.read().split('\n')
for item in contents:if item=='':continuelabels.append(item.split('\t')[0])datas.append(remove_stopwords(jieba.cut(item.split('\t')[-1])))datas=convert(datas)

去除停用词、

stop=[]
with open('stop.txt',mode='r',encoding='utf-8') as f:stop=f.read().split('\n')
stop_word={}
for s in stop:stop_word[s]=True
def remove_stopwords(datas):  filtered_words = [text for text in datas if text not in stop_word]return ' '.join(filtered_words)  

进行中文分词、转换为token序列

tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')def convert(datas, max_seq_length=40):ans=[]for text in datas:input_ids = tokenizer(text, max_seq_len=max_seq_length)['input_ids']input_ids = input_ids[:max_seq_length]  # 截断input_ids = input_ids + [tokenizer.pad_token_id] * (max_seq_length - len(input_ids))  # 填充ans.append(input_ids)return ans

导入数据,进行预处理,数据集在最后

contents=[]
datas=[]
labels=[]
with open('data/data126283/data/Train.txt',mode='r',encoding='utf-8') as f:contents=f.read().split('\n')
for item in contents:if item=='':continuelabels.append(item.split('\t')[0])datas.append(remove_stopwords(jieba.cut(item.split('\t')[-1])))datas=convert(datas)

 

2.加载模型 

加载预训练模型,冻结大部分参数
model = BertForSequenceClassification.from_pretrained('bert-base-chinese')
model.classifier = paddle.nn.Linear(768, 14)
for name, param in model.named_parameters():if "classifier" not in name and 'bert.pooler.dense' not in name and 'bert.encoder.layers.11' not in name:param.stop_gradient = True

ps:如果只保留classifier用来训练,效果欠佳。

设置超参数,学习率初始设为0.01~0.1

epochs=2
batch_size=1024*4
learning_rate=0.001

损失函数和优化器

criterion = paddle.nn.CrossEntropyLoss()
optimizer = paddle.optimizer.Adam(learning_rate=learning_rate, parameters=model.parameters())

3.批训练

划分训练集和测试集

datas=np.array(datas)
labels=np.array(labels)
x_train,x_test,y_train,y_test=train_test_split(datas,labels,random_state=42,test_size=0.2)
train_dataset=TensorDataset([x_train,y_train])
train_loader=DataLoader(train_dataset,shuffle=True,batch_size=batch_size)

迭代分批训练,可视化损失函数

total_loss=[]
for epoch in range(epochs):for batch_data,batch_label in train_loader:batch_label=paddle.to_tensor(batch_label,dtype='int64')batch_data=paddle.to_tensor(batch_data,dtype='int64')outputs=model(batch_data)loss=criterion(outputs,batch_label)print(epoch,loss.numpy()[0])total_loss.append(loss.numpy()[0])optimizer.clear_grad()loss.backward()optimizer.step()
paddle.save({'model':model.state_dict()},'model.param')
paddle.save({'optimizer':optimizer.state_dict()},'optimizer.param')
plt.plot(range(len(total_loss)),total_loss)
plt.show()

4.准确率

在测试集上如法炮制,查看准确率

total_loss=[]
x_test=np.array(x_test)
y_test=np.array(y_test)
test_dataset=TensorDataset([x_test,y_test])
test_loader=DataLoader(test_dataset,shuffle=True,batch_size=batch_size)with paddle.no_grad():for batch_data,batch_label in test_loader:batch_label=paddle.to_tensor(batch_label,dtype='int64')batch_data=paddle.to_tensor(batch_data,dtype='int64')outputs=model(batch_data)loss=criterion(outputs,batch_label)print(loss)outputs=paddle.argmax(outputs,axis=1)total_loss.append(loss.numpy()[0])score=0for predict,label in zip(outputs,batch_label):if predict==label:score+=1print(score/len(batch_label))
plt.plot(range(len(total_loss)),total_loss)
plt.show()

最后在验证集上输出要求的类别

arr=['财经','彩票','房产','股票','家居','教育','科技','社会','时尚','时政','体育','星座','游戏','娱乐']
evals=[]
contetns=[]
with open('data/data126283/data/Test.txt',mode='r',encoding='utf-8') as f:contents=f.read().split('\n')
for item in contents:if item=='':continueevals.append(item)
evals=convert(evals)
evals=np.array(evals)
with paddle.no_grad():for i in range(0,len(evals),2048):i=min(len(evals),i)batch_data=evals[i:i+2048]batch_data=paddle.to_tensor(batch_data,dtype='int64')predict=model(batch_data)predict=list(paddle.argmax(predict,axis=1))print(i,len(predict))for j in range(len(predict)):predict[j]=arr[predict[j]]with open('result.txt',mode='a',encoding='utf-8') as f:f.write('\n'.join(predict))f.write('\n')

ps:注意最后的f.write('\n'),否则除第一次,每次打印少一行,很坑

最后损失函数收敛在0.2或0.1左右比较正常,四舍五入差不多90准确率,当然如果你解冻更多参数,自然可以更加精确,看运行环境的配置了,建议不要使用免费平台配置,否则比乌龟还慢。。

欢迎提出问题

数据集

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/134343.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Flink SQL Window TopN 详解

Window TopN 定义(⽀持 Streaming): Window TopN 是特殊的 TopN,返回结果是每⼀个窗⼝内的 N 个最⼩值或者最⼤值。 应⽤场景: TopN 会出现中间结果,出现回撤数据,Window TopN 不会出现回撤数据…

蓝桥杯每日一题2023.11.5

题目描述 方格分割 - 蓝桥云课 (lanqiao.cn) 题目分析 对于每个图我们可以从中间开始搜索,如果到达边界点就说明找到了一种对称的方法,我们可以直接对此进行答案记录每次进行回溯就会找到不同的图像,如果是一样的图像则算一种情况&#xff…

[Kettle] Excel输入

Excel文件采用表格的形式,数据显示直观,操作方便 Excel文件采用工作表存储数据,一个文件有多张不同名称的工作表,分别存放相同字段或不同字段的数据 数据源 物理成绩(Kettle数据集2).xls https://download.csdn.net/download/H…

AI智能公文写作助手“文山会海“

公文写作痛点 没思路,公文写作无从下手公文类型繁多,一时难以全面掌握公文内容组织难度大,不易清晰、有逻辑的进行表达时间紧任务急,往往需要在有限的时间内完成大量写作工作反复修改优化,需满足更多新要求&#xff0…

JavaScript+Flask 实现视频上传的简单demo

前言 需求说明 在网页上选择本地视频并上传到后端服务器后端接收到视频后存储到本地,然后进行处理 技术栈: 前端采用原生HTMLJavaScript 后端采用Flask框架 前端代码 操作步骤: 选中视频文件获取文件内容及文件名将文件内容和文件名封…

Linux学习之vim跳转到特定行数

参考的博客:《Vim跳到最后一行的方法》 《oeasy教您玩转vim - 14 - # 行头行尾》 《Linux:vim 中跳到首行和最后一行》 想要跳到特定行的话,可以在命令模式和正常模式进行跳转。要是对于vim的四种模式不太熟的话,可以到博客《Linu…

在 Python 中打印二叉树

文章目录 Python 中的二叉树树的遍历顺序中序遍历树先序遍历树后序遍历二叉树在Python中的实现使用 Python 打印整个二叉树代码分析本文将讨论二叉树以及我们如何使用它。 我们还将看到如何使用 Python 打印它。 我们将了解在处理二叉树时使用的术语。 我们还将研究使用 Pytho…

使用VSCODE链接Anaconda

打代码还是在VSCODE里得劲 所以得想个办法在VSCODE里运行py文件 一开始在插件商店寻找插件 但是没有发现什么有效果的 幸运的是VSCODE支持自己选择Python的解释器 打开VSCODE 按住CtrlShiftP 输入Select Interpreter 如果电脑已经安装上了Python的环境 VSCODE会默认选择普通…

C++auto 关键字

auto 关键字在 C 语言中就已经存在了,只不过在 C 语言中它的作用是声明自动变量,例如 auto int z 123 ; z 本来是局部变量,加上 auto 后变成了局部的自动变量,就是当前变量的生存周期是由编译器自动决定的&#xff0c…

Redis中的渐进式遍历-Scan命令

之前我们学习过遍历命令keys,而keys *是一次性的把整个redis中所有的key都获取到.在不知道当前redis中有多少key的情况下,这个操作是非常危险的,可能会一下子得到太多的key而阻塞redis服务器.从而使其他redis客户端卡顿. 通过渐进式遍历,就可以做到,既可以获取到所有的key,同时…

视频集中存储EasyCVR平台播放一段时间后出现黑屏是什么原因?该如何解决?

安防视频监控/视频集中存储/云存储/磁盘阵列EasyCVR平台可拓展性强、视频能力灵活、部署轻快,可支持的主流标准协议有国标GB28181、RTSP/Onvif、RTMP等,以及支持厂家私有协议与SDK接入,包括海康Ehome、海大宇等设备的SDK等。平台既具备传统安…

零代码编程:用ChatGPT批量提取flash动画swf文件中的mp3

文件夹:C:\迅雷下载\有声绘本_flash[淘宝-珍奥下载]\有声绘本 flash,里面有多个flash文件,怎么转换成mp3文件呢? 可以使用swfextract工具从Flash动画中提取音频,下载地址是http://www.swftools.org/download.html,也…

Python学习-shutil模块和OS模块学习

shutil模块 针对文件的拷贝,删除,移动,压缩和解压操作 # 1.copyfileobj只能复制文件内容,无法复制权限#复制文件时,要选择自己有权限的目录执行操作,创建的文件会根据系统umask设定的参数来指定用户权限 s…

云计算和大数据技术

一、云计算技术的概述 云计算是一种基于互联网的计算模式,它将计算资源(包括硬件和软件)通过网络提供给用户,使用户能够方便地访问和使用这些资源。云计算技术可以分为三个层次:基础设施即服务 (IaaS),平台…

理解交叉熵(Cross Entropy)

交叉熵(Cross-Entropy)是一种用于衡量两个概率分布之间的距离或相似性的度量方法。在机器学习中,交叉熵通常用于损失函数,用于评估模型的预测结果与实际标签之间的差异。 在分类问题中,交叉熵损失函数通常用于多分类问…

项目构建工具maven的基本配置+idea 中配置 maven

👑 博主简介:知名开发工程师 👣 出没地点:北京 💊 2023年目标:成为一个大佬 ——————————————————————————————————————————— 版权声明:本文为原创文…

数据结构与算法—双链表

前言 前面有很详细的讲过线性表(顺序表和链表),当时讲的链表以单链表为主,但在实际应用中双链表有很多应用场景,例如大家熟知的LinkedList。 双链表与单链表区别 单链表和双链表都是线性表的链式实现,它们的主要区别在于节点结构…

一文掌握 Apache SkyWalking

Apache SkyWalking SkyWalking是一个开源可观测平台,用于收集、分析、聚合和可视化来自服务和云原生基础设施的数据。SkyWalking 提供了一种简单的方法来保持分布式系统的清晰视图,甚至跨云。它是一种现代APM,专为云原生、基于容器的分布式系…

【图像分类】【深度学习】【Pytorch版本】AlexNet模型算法详解

【图像分类】【深度学习】【Pytorch版本】AlexNet模型算法详解 文章目录 【图像分类】【深度学习】【Pytorch版本】AlexNet模型算法详解前言AlexNet讲解卷积层的作用卷积过程特征图的大小计算公式Dropout的作用AlexNet模型结构 AlexNet Pytorch代码完整代码总结 前言 AlexNet是…

Apache Doris 是什么

Apache Doris 是一个开源的、基于MPP(Massively Parallel Processing)架构的SQL数据仓库。它旨在提供高性能、高可靠性的数据分析服务,特别适合处理大规模数据集。 Doris 的主要特点包括: 高性能:通过MPP架构&#xf…