bert模型数据集加载方式

数据集构造

无论是机器学习还是深度学习对于数据集的构造都是十分重要。

现记录一下PyTorch 的 torch.utils.data.Dataset 类的子类。Dataset 类是PyTorch框架中用于处理数据的基本组件,它允许用户定义自己的数据集类,以满足特定任务的需求。

Dataset是一个抽象基类,用于创建自定义数据集。它定义了两个核心方法:getitemlen,它们是所有数据集必须实现的方法。

类定子类:
重写 init 方法来初始化数据集,可能需要加载数据、解析数据等。
重写 getitem 方法来根据索引返回数据集中的一个样本,通常会包含数据的加载、解码等操作。
重写 len 方法来返回数据集中样本的数量。

import pandas as pd
from transformers import BertTokenizerFast
import torch# 读取数据
df = pd.read_csv("./a.csv", encoding="utf-8")
texts = df["content"][:10].tolist()
labels = df["punish_result"][:10].tolist()texts = list(map(lambda x: str(x), texts))
# texts和labels是一个list,可以自己构造一个# Hugging Face下载这个模型google-bert/bert-base-chinese
model_name = "./bert-base-chinese" 
# 加载分词器
tokenizer = BertTokenizerFast.from_pretrained(model_name)# 对文本进行编码
# truncation=True 文本超过max_length进行截断处理
# padding=True 文本不足max_length进行pad处理 补0
train_encodings = tokenizer(texts, truncation=True, padding=True, max_length=32)# 封装数据为PyTorch Dataset
class TextDataset(torch.utils.data.Dataset):def __init__(self, encodings, labels):self.encodings = encodingsself.labels = labelsdef __getitem__(self, idx):# item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}# 等价上面注释写法,for循环比较好理解item = {}for key, val in self.encodings.items():item[key] = torch.tensor(val[idx])item["labels"] = torch.tensor(self.labels[idx])return itemdef __len__(self):return len(self.labels)train_dataset = TextDataset(train_encodings, labels)for dta in train_dataset:print(dta)break# 打印数据如下:# {'input_ids': tensor([ 101, 1585,  511,  872, 1962, 8024, 2769, 6821, 6804, 3221,  976, 6858,
#         7599, 6392, 1906, 4638,  511, 2769, 2682, 7309,  671,  678, 8024, 1493,
#         6821, 6804, 7444, 6206, 6821,  671, 1779,  102]),
# 'token_type_ids': tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
#         0, 0, 0, 0, 0, 0, 0, 0]),
# 'attention_mask': tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
#         1, 1, 1, 1, 1, 1, 1, 1]),
# 'labels': tensor(0)
# }

上述代码主要通过加载bert-base-chinese模型的分词器处理原始数据,之后实现一个Dataset的子类将数据封装到PyTorch框架可识别数据结构。

数据集构造二
# 自定义数据集
class Dataset(torch.utils.data.Dataset):def __init__(self, path):df = pd.read_csv(path, encoding="utf-8")self.texts = df["content"].tolist()self.labels = df["punish_result"].tolist()def __len__(self):return len(self.labels)def __getitem__(self, i):text = str(self.texts[i])label = int(self.labels[i])return text, labelpath = "./data/abuse.csv"
dataset = Dataset(path)# len(dataset), type(dataset.texts[0])
# 17182 老板百亿检测干什么

加载词典和分词器

# 加载字典和分词工具
token = BertTokenizer.from_pretrained("./bert-base-chinese")

设置辅助函数

def collate_fn(data):sents = [i[0] for i in data]labels = [i[1] for i in data]"""batch_text_or_text_pairs:类型: 列表或元组的列表。含义: 输入的文本数据,可以是单个文本列表(如果只处理单个句子)或配对的文本(如对话或翻译任务中的源语言和目标语言句子)。truncation:类型: 布尔值。含义: 是否对超过最大长度的文本进行截断。设置为 True 表示会截断超出长度限制的文本。padding:类型: 字符串。含义: 决定如何填充短于最大长度的文本。'max_length' 表示所有样本都会被填充到max_length的长度,以确保批次内的所有元素长度一致。max_length:类型: 整数。含义: 设定的最大序列长度。所有输入的文本将会被截断或填充到这个长度。return_tensors:类型: 字符串。含义: 指定返回的张量类型。'pt' 表示返回 PyTorch 张量,其他可能的选项有 'tf'(TensorFlow 张量)或 'np'(NumPy 数组)。return_length:类型: 布尔值。含义: 如果设置为 True,函数还会返回一个列表,其中包含每个输入文本的原始长度,这对于知道哪些部分是填充的很有用。"""data = token.batch_encode_plus(batch_text_or_text_pairs=sents,truncation=True,padding="max_length",max_length=500,return_tensors="pt",return_length=True,)# input_ids: 编码之后的数字input_ids = data["input_ids"]# attention_mask是一个与输入tokens相同形状的二维数组# 1 表示有效的位置,即非填充的tokens。这些位置在计算注意力分数时会被考虑。# 0 表示填充的位置,模型在计算注意力时会忽略这些位置。attention_mask = data["attention_mask"]token_type_ids = data["token_type_ids"]labels = torch.LongTensor(labels)# print(data['length'], data['length'].max())# tensor([ 56,  71,  32, 159,  34, 179,  33,  79,  49,  33,  98,  89, 212,  41,#      63,  58]) tensor(212)return input_ids, attention_mask, token_type_ids, labels

加载数据集

"""
dataset:
类型: torch.utils.data.Dataset 的实例。
含义: 指定要加载的数据集。dataset 参数接收之前定义的 TextDataset 实例,包含了预处理过的文本数据和标签。
batch_size:
类型: 整数。
含义: 每个批次(batch)中的样本数量。在这个例子中,设置为 16,意味着数据加载器每次返回的将是包含16个样本的数据批次,用于模型训练或评估。
collate_fn:
类型: 可调用对象(如函数)。
含义: 用于整理一个批次的数据。当从数据集中取出多个样本时,collate_fn 会被调用来将这些样本打包成一个批次。这对于处理变长序列(如文本)特别有用,因为需要对不同长度的序列进行填充或截断以适应批处理。如果没有提供,默认的 collate_fn 可能不适用于所有情况,特别是当数据具有复杂结构时。
shuffle:
类型: 布尔值。
含义: 是否在每个 epoch 开始时对数据集进行随机洗牌。设置为 True 表示在训练过程中数据会随机排序,有助于提高模型的泛化能力。对于验证或测试集,通常应设为 False。
drop_last:
类型: 布尔值。
含义: 如果设置为 True,在最后一个批次不足以填满整个 batch_size 时,这个批次将会被丢弃。如果设为 False,则最后一个批次可能包含少于 batch_size 的样本数量。这在某些模型训练中是有用的,尤其是当模型设计要求固定的批次大小时。
"""
loader = torch.utils.data.DataLoader(dataset=dataset, batch_size=16, collate_fn=collate_fn, shuffle=True, drop_last=True
)# for i, (input_ids, attention_mask, token_type_ids, labels) in enumerate(loader):
#     print(i, input_ids.shape, attention_mask.shape, token_type_ids.shape, labels.shape)
#     # 0 torch.Size([16, 500]) torch.Size([16, 500]) torch.Size([16, 500]) torch.Size([16])
#     break

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/18701.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

重学英语:输出的重要性

精通一门外语的四要素:听,说,读,写 输入:听,读 输出:写,说 因为输入是我们可以单独完成,不需要有人互动,所以我们做得最多 输出练习做得很少,…

Redis中的数据结构与内部编码

本篇文章主要是对 Redis 常见的数据结构进行讲解,同时还对其所对应的不同的内部编码进行讲解。希望本篇文章会对你有所帮助。 文章目录 一、五大数据结构 二、数据结构对应的编码方式 String hash list set zset 🙋‍♂️ 作者:Ggggggtm &…

js 面试题学习笔记一

1、什么是防抖和节流?有什么区别?如何实现? 防抖:触发高频事件后N秒内函数只会执行一次,如果N秒高频事件再次被触发,则重新计算时间。(a时间触发,5秒内执行一次,但是第4…

10G UDP协议栈 (9)UDP模块

目录 一、UDP协议简单介绍 二、UDP功能实现 三、仿真 一、UDP协议简单介绍 UDP协议和TCP协议同位于传输层,介于网络层(IP)和应用层之间:UDP数据部分为应用层报文,而UDP报文在IP中承载。 UDP 报文格式相对于简单&am…

电脑出现:excel词典(xllex.dll)文件丢失或损坏的错误提示怎么办?有效的将丢失的xllex.dll修复

当遇到 Excel 提示“词典 (xllex.dll) 文件丢失或损坏”的问题时,通常意味着该动态链接库文件(Dynamic Link Library,DLL),它与拼写检查功能相关联的,无法被正确找到或者合适地使用。那么有什么办法可以解决…

LLVM技术在GaussDB等数据库中的应用

目录 LLVM和数据库 LLVM适用场景 LLVM对所有类型的SQL都会有收益吗? LLVM在OLTP中就一定没有收益吗? GaussDB中的LLVM 1. LLVM在华为应用于数据库的时间线 2. GaussDB LLVM实现简析 3. GaussDB LLVM支持加速的场景 支持LLVM的表达式&#xff1a…

vue项目出现多次ElMessage

问题: 解决方法: let message null if (message null) { message ElMessage.error(“登录过期,请重新登录”); } 最终效果:只出现一个弹框

Orange AIpro Color triangle帧率测试

OpenGL概述 OpenGL ES是KHRNOS Group推出的嵌入式加速3D图像标准,它是嵌入式平台上的专业图形程序接口,它是OpenGL的一个子集,旨在提供高效、轻量级的图形渲染功能。现推出的最新版本是OpenGL ES 3.2。OpenGL和OpenCV OpenCL不同,…

实操专区-第15周-课堂练习专区-漏斗图与金字塔图

实操专区-第15周-课堂练习专区-漏斗图 下载安装ECharts,完成如下样式图形。 代码和截图上传 基本要求:下图3选1,完成代码和截图 完成 3.1.3.16 漏斗图中的任务点 基本要求:2个选一个完成,多做1个加2分。 请用班级学号姓…

银行对公贷款软件业务流程详解

对公贷款业务是指商业银行向企事业单位提供资金支持,用于资本扩充、生产经营、项目建设等方面的融资。其目的在于支持企事业单位的发展,推动经济增长。通过提供资金支持,企事业单位可以获得必要的资金来扩大生产规模、提高生产能力、研发新产…

第8周 分布式事务与数据一致性主流解决方案落地

第8周 分布式事务与数据一致性主流解决方案落地 1. 最终一致性原理与解析2. 微服务的解耦3. 本地消息存储4. 自定义事务管理器5. 本地消息删除********************************************************************************** 本周拓展数据的一致性落地,采用弱…

【Java EE】网络原理——HTTP请求

目录 1.认识URL 2.认识“方法(method)” 2.1GET方法 2.1.1使用Fiddler观察GET请求 2.1.2 GET请求的特点 2.2 POST方法 2.2.1 使用FIddler观察POST方法 2.2.2 POST请求的特点 3.认识请求“报头”(header) 3.1 Host 3.2 C…

Spring MVC 工作流程源码分析

前言: 我们知道 Spring MVC 的核心是前端控制器 DispatcherServlet,客户端所有的请求都会交给 DispatcherServlet 来处理,本篇我我们来分析 Spring MVC 处理客户端请求的流程,也就是工作流程。 Sping MVC 只是储备传送门&#x…

Java整合EasyExcel实战——3(上下列相同合并单元格策略)

参考&#xff1a;https://juejin.cn/post/7322156759443095561?searchId202405262043517631094B7CCB463FDA06https://juejin.cn/post/7322156759443095561?searchId202405262043517631094B7CCB463FDA06 准备条件 依赖 <dependency><groupId>com.alibaba</gr…

邻接矩阵广度优先遍历

关于图的遍历实际上就两种 广度优先和深度优先&#xff0c;一般关于图的遍历都是基于邻接矩阵的&#xff0c;考试这些&#xff0c;用的也是邻接矩阵。 本篇文章先介绍广度优先遍历的原理&#xff0c;和代码实现 什么是图的广度优先遍历&#xff1f; 这其实和二叉树的层序遍…

新人学习笔记之(数组1)

一、数组的概念 1.数组&#xff08;Array&#xff09;可以把一组相关的数据一起存放&#xff0c;并提供方便的访问&#xff08;获取&#xff09;方式 2.数组是指一组数据的集合&#xff0c;其中的每个数据被称作元素&#xff0c;在数组中可以存放任意类型的元素&#xff0c;数组…

数据结构——二叉树的基本应用

在此之前我们已经初步了解了二叉树&#xff0c;在介绍堆的基本应用时&#xff0c;我们已经具体介绍了完全二叉树的基本应用&#xff0c;本章我们介绍二叉树的基本应用&#xff0c;这个不止指的是完全二叉树&#xff0c;而是指泛型的二叉树。 二叉树的基本应用&#xff0c;由于…

代码随想录算法训练营第54天|● 392.判断子序列 ● 115.不同的子序列

392. 判断子序列 这个微软面试的时候考过 双指针就行 编辑距离入门题&#xff1a; 思路是一样的 相同字符1 否则从前面顺下来 class Solution:def isSubsequence(self, s: str, t: str) -> bool:dp[[0]*(len(t)1) for _ in range(len(s)1)]for i in range(1,len(s)1):f…

aspose-*的使用

文章目录 aspose-*一、依赖--maven二、需求1、word------>pdf2、doc------>docx2、xls------>xlsx aspose-* 一、依赖–maven 备注&#xff1a;第三方的jar包可以从资源中下载&#xff0c;有上传的 <!--aspose依赖--><dependency><groupId>aspose…

刷代码随想录有感(81):贪心算法——分发饼干

题干&#xff1a; class Solution { public:int findContentChildren(vector<int>& g, vector<int>& s) {sort(g.begin(), g.end());sort(s.begin(), s.end());int index s.size() - 1;int res 0;for(int i g.size() - 1; i > 0; i--){if(index >…