python-pytorc+bert句子分类0.1.000

这里写目录标题

    • 引入包
    • 加载预训练模型
    • 加载数据文件
    • 定义数据
    • 实例化数据集
    • 使用loader加载数据
      • 设定最大句子长度
      • 定义加padding的函数
      • 定义加collate_fn函数
      • 使用DataLoader加载数据
    • 定义模型
      • 测试预训练模型输出
      • 测试预训练模型输出
      • 定义自己的模型
    • 参考

引入包

import torch
from torch import nn
from torch.utils.data import DataLoader,Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor

加载预训练模型

from transformers import BertTokenizer,BertForSequenceClassification,BertConfig
config=BertConfig.from_pretrained("D:\\jpdir\\bert\\bertchinese",num_labels=10)
tokenizer = BertTokenizer.from_pretrained("D:\\jpdir\\bert\\bertchinese")
model = BertForSequenceClassification.from_pretrained("D:\\jpdir\\bert\\bertchinese",config=config)
d:\python\python37\lib\site-packages\tensorflow\python\framework\dtypes.py:526: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'._np_qint8 = np.dtype([("qint8", np.int8, 1)])
d:\python\python37\lib\site-packages\tensorflow\python\framework\dtypes.py:527: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'._np_quint8 = np.dtype([("quint8", np.uint8, 1)])
d:\python\python37\lib\site-packages\tensorflow\python\framework\dtypes.py:528: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'._np_qint16 = np.dtype([("qint16", np.int16, 1)])
d:\python\python37\lib\site-packages\tensorflow\python\framework\dtypes.py:529: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'._np_quint16 = np.dtype([("quint16", np.uint16, 1)])
d:\python\python37\lib\site-packages\tensorflow\python\framework\dtypes.py:530: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'._np_qint32 = np.dtype([("qint32", np.int32, 1)])
d:\python\python37\lib\site-packages\tensorflow\python\framework\dtypes.py:535: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.np_resource = np.dtype([("resource", np.ubyte, 1)])

加载数据文件

# 整理训练数据
x_train=[]
x_test=[]
with open("D:\\jpdir\\bert\\bertdata\\Multi-classification\\train.txt","r",encoding="utf-8") as f:lines=f.readlines()for line in lines:x_train.append(line.split("\t")[0])x_test.append(line.split("\t")[1].replace("\n",""))# 整理测试数据
y_train=[]
y_test=[]
with open("D:\\jpdir\\bert\\bertdata\\Multi-classification\\test.txt","r",encoding="utf-8") as f:lines=f.readlines()for line in lines:y_train.append(line.split("\t")[0])y_test.append(line.split("\t")[1].replace("\n",""))

定义数据

class CustomDataset(Dataset):def __init__(self,data_path):# 初始化数据集的过程,例如加载数据等# 假设我们有一个数据列表self.data = []with open(data_path,"r",encoding="utf-8") as f:lines=f.readlines()for line in lines:self.data.append(line)def __len__(self):# 返回数据集的长度return len(self.data)def __getitem__(self, index):# 根据索引获取一个样本line=self.data[index]content=line.split("\t")[0]label=line.split("\t")[1].replace("\n","").replace("\"","")return content,label

实例化数据集

train_data= CustomDataset("D:\\jpdir\\bert\\bertdata\\Multi-classification\\train.txt")
test_data= CustomDataset("D:\\jpdir\\bert\\bertdata\\Multi-classification\\test.txt")
len(train_data),len(test_data)
(4610, 4768)

使用loader加载数据

设定最大句子长度

maxlenhth=32

定义加padding的函数

不够maxlength,就加pad,这的pad对应的索引是0
def add_padding(data):if len(data)<maxlenhth:for x in torch.arange(maxlenhth-len(data)):data.append(0)return data

定义加collate_fn函数

这里处理tokenizer和paading

def collate_fn(batchData,tokenizer):scentence=[line[0] for line in batchData]label=[int(line[1]) for line in batchData]scentence=torch.tensor([add_padding(tokenizer.encode(one,max_length=32,add_special_tokens=True)) for one in scentence])label=torch.tensor(label)return scentence,label

使用DataLoader加载数据

loader = DataLoader(train_data, 5, shuffle=True,collate_fn=lambda x:collate_fn(x,tokenizer))
data_iter = iter(loader)
print(len(data_iter))# 看下数据
data = next(data_iter)
"长度:",len(data[0]),"data[0]:",data[0],"data[1]:",data[1],"data:",data,data[0].size(),data[1].unsqueeze(1).size()
922('长度:',5,'data[0]:',tensor([[ 101,  517,  682, 1957, 3187, 3127,  518, 3119, 6228, 1086, 1932, 1094,3209, 3241,  677, 4028, 1920, 5310, 2229,  102,    0,    0,    0,    0,0,    0,    0,    0,    0,    0,    0,    0],[ 101, 2349, 7561,  680, 2357, 3306, 2199, 6158, 5739, 1744, 1957, 4374,2970, 6224, 1217, 2135, 2196, 4265, 2900, 3189, 1377, 2521,  102,    0,0,    0,    0,    0,    0,    0,    0,    0],[ 101, 4242, 6946, 3215, 3777, 9560, 7555, 4680, 8183, 2398, 6629,  122,118,  124, 2233, 1762, 1545, 1059, 3621, 8380, 2835,  102,    0,    0,0,    0,    0,    0,    0,    0,    0,    0],[ 101, 3791, 1744, 8226,  674,  782, 7770, 5440,  868, 3152, 1091,  100,3152, 1265, 3221, 1415,  886,  782, 2814, 3289,  100,  102,    0,    0,0,    0,    0,    0,    0,    0,    0,    0],[ 101,  517, 7987,  722, 6484,  518,  100,  100, 2845, 1399, 2661, 5683,2458, 1423,  102,    0,    0,    0,    0,    0,    0,    0,    0,    0,0,    0,    0,    0,    0,    0,    0,    0]]),'data[1]:',tensor([9, 7, 1, 3, 8]),'data:',(tensor([[ 101,  517,  682, 1957, 3187, 3127,  518, 3119, 6228, 1086, 1932, 1094,3209, 3241,  677, 4028, 1920, 5310, 2229,  102,    0,    0,    0,    0,0,    0,    0,    0,    0,    0,    0,    0],[ 101, 2349, 7561,  680, 2357, 3306, 2199, 6158, 5739, 1744, 1957, 4374,2970, 6224, 1217, 2135, 2196, 4265, 2900, 3189, 1377, 2521,  102,    0,0,    0,    0,    0,    0,    0,    0,    0],[ 101, 4242, 6946, 3215, 3777, 9560, 7555, 4680, 8183, 2398, 6629,  122,118,  124, 2233, 1762, 1545, 1059, 3621, 8380, 2835,  102,    0,    0,0,    0,    0,    0,    0,    0,    0,    0],[ 101, 3791, 1744, 8226,  674,  782, 7770, 5440,  868, 3152, 1091,  100,3152, 1265, 3221, 1415,  886,  782, 2814, 3289,  100,  102,    0,    0,0,    0,    0,    0,    0,    0,    0,    0],[ 101,  517, 7987,  722, 6484,  518,  100,  100, 2845, 1399, 2661, 5683,2458, 1423,  102,    0,    0,    0,    0,    0,    0,    0,    0,    0,0,    0,    0,    0,    0,    0,    0,    0]]),tensor([9, 7, 1, 3, 8])),torch.Size([5, 32]),torch.Size([5, 1]))

定义模型

测试预训练模型输出

BertForSequenceClassification的输入input_ids size是[batch_size,maxlength],labels的size是[batch_size,1]
input_ids 是中文转成设定的数字
lables是数据的分类标签

测试预训练模型输出

loss 损失值
logits 概率分布
input_ids = torch.tensor(tokenizer.encode("词汇阅读是关键 08年考研暑期英语复习全指南",max_length=32,add_special_tokens=True)).unsqueeze(0)  # Batch size 1
labels = torch.tensor([1]).unsqueeze(0)  # Batch size 1
outputs = model(input_ids, labels=labels)
print(outputs)
loss, logits = outputs
loss, logits
(tensor(2.2565, grad_fn=<NllLossBackward0>), tensor([[ 0.5478, -0.0462, -0.2125, -0.8165,  0.1208, -0.4684, -0.9593,  0.4391,0.1320, -1.0400]], grad_fn=<AddmmBackward0>))(tensor(2.2565, grad_fn=<NllLossBackward0>),tensor([[ 0.5478, -0.0462, -0.2125, -0.8165,  0.1208, -0.4684, -0.9593,  0.4391,0.1320, -1.0400]], grad_fn=<AddmmBackward0>))

定义自己的模型

# Define model
class NeuralNetwork(nn.Module):def __init__(self):super().__init__()self.flatten = nn.Flatten()self.linear_relu_stack = nn.Sequential(nn.Linear(28*28, 512),nn.ReLU(),nn.Linear(512, 512),nn.ReLU(),nn.Linear(512, 10))def forward(self, x):x = self.flatten(x)logits = self.linear_relu_stack(x)return logitsmodel1 = NeuralNetwork()
print(model1)
NeuralNetwork((flatten): Flatten(start_dim=1, end_dim=-1)(linear_relu_stack): Sequential((0): Linear(in_features=784, out_features=512, bias=True)(1): ReLU()(2): Linear(in_features=512, out_features=512, bias=True)(3): ReLU()(4): Linear(in_features=512, out_features=10, bias=True))
)
optimizer = torch.optim.AdamW(model.parameters(),lr=0.001)
criterion = torch.nn.CrossEntropyLoss()
model.train()for i,batch in enumerate(loader):optimizer.zero_grad()scentenses,labels=batchoutput=model(scentenses,labels=labels.unsqueeze(1))loss,logits=outputloss.backward()optimizer.step()print(i,loss.item())

参考

https://blog.51cto.com/u_15127680/3841198

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/34319.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

this.$prompt 提示框增加文本域并修改文本域高度

2024.06.24今天我学习了如何对提示框增加文本域的方法&#xff0c;效果如下&#xff1a; 代码如下&#xff1a; <script>methods:{reject_event(){this.$prompt(驳回内容, 提示, {confirmButtonText: 确定,cancelButtonText: 取消,inputType: textarea,inputPlaceholder…

计算机网络(数据链路层)

数据链路层概述 数据链路层位于计算机网络的低层&#xff0c;且在物理层之上&#xff0c;数据链路层使用的信道主要有以下两种类型。 &#xff08;1&#xff09;点对点通信。在信道上使用一对一的点对点通信。 &#xff08;2&#xff09;广播信道。这种信道使用一对多的广播通…

【linux】详解——库

目录 概述 库 库函数 静态库 动态库 制作动静态库 使用动静态库 如何让系统默认找到第三方库 lib和lib64的区别 /和/usr/和/usr/local下lib和lib64的区别 环境变量 配置相关文件 个人主页&#xff1a;东洛的克莱斯韦克-CSDN博客 简介&#xff1a;C站最萌博主 相关…

DDK电通拧紧MFC-S060控制器过流维修

一、DDK伺服拧紧轴控制器过流故障的成因 1. 电源电压过低&#xff1a;当电源电压过低时&#xff0c;控制器可能会出现过流现象。 2. 负载过大&#xff1a;当负载过大时&#xff0c;DDK电通拧紧机控制器MFC-S060的电流也会随之增大&#xff0c;可能导致过流故障。 3. 控制器内部…

自动调整QTableView列宽以适应窗口大小

问题描述 十年前&#xff0c;有人提出了一个问题&#xff1a;当我使用自定义模型来展示 QTableView&#xff0c;并固定了三列时&#xff0c;初始窗口显示正常&#xff0c;但当我调整窗口大小时&#xff0c;QTableView 会随之调整大小&#xff0c;而列宽却保持不变。我想让列宽…

远程连接mysql数据库的详细配置

1. 确认 MySQL 服务器配置 首先&#xff0c;确认 MySQL 服务器的配置允许远程连接。您需要编辑 MySQL 的配置文件&#xff0c;并确保以下设置正确&#xff1a; bind-address&#xff1a;这个参数控制 MySQL 监听的 IP 地址。如果要允许任何 IP 地址连接&#xff0c;请将其设置…

手写 Promise 的实现

手写 Promise 的实现 从实现原理的角度分析 Promise 是什么 从语法上说&#xff0c;Promise 是一个对象&#xff0c;从它可以获取异步操作的消息。ES6 原生提供了Promise对象。 Promise内部有三种状态&#xff1a;pending&#xff08;进行中&#xff09;、fulfilled&#xf…

开箱即用:一个易用的开源表单工具!【送源码】

随着互联网的普及&#xff0c;表单应用场景越来越广泛&#xff0c;从网站注册、调查问卷到考试测评&#xff0c;无处不在。传统的表单制作方式需要一定的代码基础&#xff0c;对于不懂编程的小伙伴来说&#xff0c;无疑是一道门槛。 今天&#xff0c;给大家分享一款开源的表单…

如何理解redis是单线程的

写在文章开头 在面试时我们经常会问到这样一道题 你刚刚说redis是单线程的&#xff0c;那你能不能告诉我它是如何基于单个线程完成指令接收与连接接入的&#xff1f;这时候我们经常会得到沉默&#xff0c;所以对于这道题&#xff0c;笔者会直接通过3.0.0源码分析的角度来剖析…

[数据集][目标检测]花生米计数霉变检测数据集VOC+YOLO格式387张2类别

数据集格式&#xff1a;Pascal VOC格式YOLO格式(不包含分割路径的txt文件&#xff0c;仅仅包含jpg图片以及对应的VOC格式xml文件和yolo格式txt文件) 图片数量(jpg文件个数)&#xff1a;387 标注数量(xml文件个数)&#xff1a;387 标注数量(txt文件个数)&#xff1a;387 标注类别…

使用Leaflet和瓦片地图实现离线地图的技术指南

引言 在现代的Web应用中&#xff0c;地图服务扮演着越来越重要的角色。然而&#xff0c;在一些特殊环境下&#xff0c;如偏远地区或网络环境不稳定的情况下&#xff0c;依赖在线地图服务可能会受到限制。因此&#xff0c;实现离线地图功能成为了一个重要的需求。本文将介绍如何…

【数据库】Oracle 分区表与 TRUNC 函数的优化应用

在 Oracle 数据库中&#xff0c;分区表是一种强大的数据管理工具&#xff0c;它允许将大型表分割成更小、更易于管理的部分&#xff0c;称为分区。每个分区可以独立地进行管理&#xff0c;包括备份、恢复和优化。分区表特别适用于处理大量数据&#xff0c;可以显著提高查询性能…

Redis入门与应用(1)

Redis的技术全景 Redis是一个开源的基于键值对&#xff08;Key-Value&#xff09;的NoSQL数据库&#xff0c;使用ANSI C语言编写&#xff0c;支持网络&#xff0c;基于内存但支持持久化。它性能优越&#xff0c;并提供多种语言的API。我们可以将Redis视为一个巨大的Map&#x…

《Java面试题集中营》- Java并发

《Java并发编程的艺术》、《Java并发编程之美》 运行中的线程能否强制杀死 Jdk提供了stop()方法用于强制停止线程&#xff0c;但官方并不建议使用&#xff0c;因为强制停止线程会导致线程使用的资源&#xff0c;比如文件描述符、网络连接处于不正常的状态。建议使用标志位的方…

秋招突击——第九弹——Redis缓存

文章目录 引言正文缓存基础旁路缓存模式&#xff08;重点&#xff09;读穿透&#xff08;了解&#xff09;写穿透&#xff08;了解&#xff09;异步缓存写入模式面试重点 缓存异常场景缓存穿透缓存击穿缓存雪崩面试重点 缓存一致性怎么保证&#xff1f;缓存一致性问题是什么方案…

[职场] 策略运营求职简历范文精选 #知识分享#微信#微信

策略运营求职简历范文精选 策略运营是用户运营的一种模式&#xff0c;主要针对于用户量级在千人到百万人规模的运营。下面是策略运营求职简历范文精选&#xff0c;供大家参考。 个人信息 姓名&#xff1a;蓝山 年龄&#xff1a;33岁 地址&#xff1a;北京 工作经验&#x…

C++STL梳理

CSTL标准手册&#xff1a; https://cplusplus.com/reference/stl/ https://cplusplus.com/reference/vector/vector/at/ 1、STL基础 1.1、STL基本组成(6大组件13个头文件) 通常认为&#xff0c;STL 是由容器、算法、迭代器、函数对象、适配器、内存分配器这 6 部分构成&…

JS延迟加载的方式有哪些

JavaScript延迟加载&#xff08;也称为懒加载&#xff09;是一种优化网页性能的技术&#xff0c;它允许脚本在页面加载完成后再执行&#xff0c;从而加快页面初始加载速度。 以下是几种常见的JavaScript延迟加载方式&#xff1a; 异步加载&#xff08;async&#xff09;: 使用a…

信息检索(54):On the Effect of Low-Frequency Terms on Neural-IR Models

On the Effect of Low-Frequency Terms on Neural-IR Models 摘要1 引言2 背景和相关工作3 实验设计4 词汇量的影响5 包含低频词的查询6 结论 发布时间&#xff08;2019&#xff09; 低频词对于神经检索模型的影响 摘要 低频词是信息检索模型面临的一个反复出现的挑战&#…

Java代码如何优化的?

1、单一职责 2、注释 3、公共类/方法抽离 4、单元测试 5、SQL优化 6、代码reviewe 7、库存以前是直接操作数据库--->lua 8、日志----->ELK