如何基于PyTorch框架自定义数据集类获取数据

在PyTorch框架中,可以通过自定义数据集类来加载和处理数据

要自定义数据集类,需要继承 PyTorch提供的 torch.utils.data.Dataset类,并实现两个主要方法:__len__ __getitem__

下面是一个示例,展示如何基于PyTorch框架来自定义数据集类以获取数据:

import torch
from torch.utils.data import Datasetclass CustomDataset(Dataset):def __init__(self, data):self.data = datadef __len__(self):return len(self.data)def __getitem__(self, index):item = self.data[index]# 在这里对数据进行预处理、转换等操作# 返回一个样本(通常是一个字典)return item# 创建数据集实例
data = [...]  # 数据列表,包含训练样本
dataset = CustomDataset(data)# 创建数据加载器
batch_size = 32
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)# 遍历数据加载器获取数据批次
for batch in dataloader:# 处理每个批次的数据inputs = batch['input']labels = batch['label']# 在这里进行模型训练、推理等操作

在此示例中,定义了一个名为 CustomDataset 的自定义数据集类,该类继承自torch.utils.data. Dataset

__init__方法是构造函数,传入数据列表 data 并将其保存为类的属性 self.data

__len__方法返回数据集的长度,即样本数量

__getitem__方法通过索引获取单个样本

然后,创建了一个数据集实例 dataset,并使用 torch.utils.data.DataLoader 创建了一个数据加载器 dataloader

通过遍历数据加载器可以获取每个批次 输入数据inputs 以及 标签数据labels,进行模型训练、推理等操作

注意:根据具体的应用需求,可以在__getitem__方法中对数据进行预处理、转换等操作,并将处理后的样本作为字典或其他形式返回, 这样,在训练过程中可以方便地获取输入数据和标签数据 ,并进行相应的操作

下面再来看一个例子,该例通过在 __getitem__方法中对数据进行预处理,并最终返回一个包含图片数据、对应的标签数据以及图像文件名的字典

class BipedDataset(Dataset):  # 定义了一个名为BipedDataset的类,它继承自PyTorch的Dataset类,用于自定义数据集'''用于构建一个自定义数据集,可以在训练神经网络时使用它提供了加载图像、预处理数据等功能,以便用于深度学习模型的训练'''def __init__(self,data_root,  img_height,img_width,mean_bgr,  # 图像的均值(以BGR通道顺序表示)train_mode='train',  # 训练模式,可以是 'train' 或 'test' 之一,默认为 'traincrop_img=False,arg=None):'''这是类的构造函数,用于初始化对象的属性它接受许多参数,包括数据根目录 data_root、图像高度 img_height、图像宽度 img_width、均值 mean_bgr、训练模式 train_mode 等'''self.data_root = data_rootself.train_mode = train_modeself.img_height = img_heightself.img_width = img_widthself.mean_bgr = mean_bgrself.crop_img = crop_imgself.arg = argself.data_index = self._build_index()def _build_index(self):  # 用于构建数据索引data_root = os.path.abspath(self.data_root)sample_indices = []  # 用于存储图像和标签的文件路径对# 构建图像和标签的文件路径,其中 images_path 和 labels_path 分别指向数据集中图像和标签的存储路径# 使用两层循环遍历图像目录中的所有文件,构建图像和标签的文件路径,并将其添加到 sample_indices 列表中images_path = os.path.join(data_root,'edges\\imgs',self.train_mode)labels_path = os.path.join(data_root,'edges\\labels',self.train_mode)for file_name_ext in os.listdir(images_path):file_name = os.path.splitext(file_name_ext)[0]sample_indices.append(( os.path.join(images_path, file_name + '.tif'),os.path.join(labels_path, file_name + '.tif'), ))return sample_indices  # 返回构建好的图像和标签的文件路径对列表def __len__(self):  # 返回数据集的长度,即样本的数量return len(self.data_index)def __getitem__(self, idx):  # 用于获取指定索引处的数据样本,它接受一个索引 idx 作为参数# get data sample'''首先,根据索引获取图像路径和标签路径然后,使用OpenCV加载图像和标签接下来,调用self.transform方法进行数据变换最后,返回一个包含图像、对应标签以及图像文件名的字典'''image_path, label_path = self.data_index[idx]# load dataimage = cv2.imread(image_path, cv2.IMREAD_COLOR)label = cv2.imread(label_path, cv2.IMREAD_GRAYSCALE)image, label = self.transform(img=image, gt=label)  # transform方法:用于对图像和标签进行预处理img_name = os.path.basename(image_path)file_name = os.path.splitext(img_name)[0] + ".png"return dict(images=image, labels=label, file_names=file_name)def transform(self, img, gt):# 将标签转换为浮点型数组,并将其归一化到 [0, 1] 的范围内gt = np.array(gt, dtype=np.float32)if len(gt.shape) == 3:gt = gt[:, :, 0]gt /= 255.  # 将图像转换为浮点型数组,并减去均值 self.mean_bgrimg = np.array(img, dtype=np.float32)img -= self.mean_bgri_h, i_w, _ = img.shape  # 获取图像的高度、宽度和通道数# 根据设定的裁剪大小 crop_size 对图像进行裁剪或缩放crop_size = self.img_height if self.img_height == self.img_width else None  # 对于裁剪过程,它会在图像中随机选择一个位置来裁剪if i_w > crop_size and i_h > crop_size:i = random.randint(0, i_h - crop_size)j = random.randint(0, i_w - crop_size)img = img[i:i + crop_size, j:j + crop_size]gt = gt[i:i + crop_size, j:j + crop_size]else:  #  如果图像的尺寸小于 crop_size,则会使用双线性插值进行缩放# New addidingsimg = cv2.resize(img, dsize=(crop_size, crop_size))gt = cv2.resize(gt, dsize=(crop_size, crop_size))# 对标签gt进行一些额外的处理,然后将图像img和标签gt转换为PyTorch的张量形式gt[gt > 0.1] += 0.2  gt = np.clip(gt, 0., 1.)img = img.transpose((2, 0, 1))img = torch.from_numpy(img.copy()).float()gt = torch.from_numpy(np.array([gt])).float()return img, gt

在此处就定义完成了一个数据集类 BipedDataset

如何使用自定义的 BipedDataset 类来对数据进行加载呢?下面以加载验证集数据为例来进行说明

首先,对这个类进行实例化得到实例化后的数据集对象 dataset_val

dataset_val = BipedDataset(args.input_dir,img_width =args.img_width,img_height =args.img_height,mean_bgr =args.mean_pixel_values,train_mode ='test',arg =args)

其次,将该对象传入DataLoader中创建验证集数据加载器 dataloader_val

dataloader_val = DataLoader(dataset_val,batch_size=1,shuffle=False,num_workers=args.workers)

然后,将数据集加载器 dataloader_val 作为参数传入进行验证过程的函数 validate_one_epoch 中

 val_precision,val_recall,val_IoU = validate_one_epoch(epoch,dataloader_val,model,device,img_test_dir,arg=args)
def validate_one_epoch(epoch, dataloader, model, device, output_dir, arg=None):precision = 0.0recall = 0.0IoU = 0.0model.eval()  with torch.no_grad():for _, sample_batched in enumerate(dataloader):images = sample_batched['images'].to(device)labels = sample_batched['labels'].to(device)file_names = sample_batched['file_names']   preds = model(images)labels = normalize_image(labels)preds = normalize_image(preds)precision += calculate_precision(preds, labels)recall += calculate_recall(preds, labels)IoU += calculate_iou(preds, labels)save_image_batch_to_disk(preds, output_dir, file_names,arg=arg)precision = precision / len(dataloader)recall = recall / len(dataloader)IoU = IoU / len(dataloader)print(time.ctime(), '[Val_Epoch]: {0} Precision:{1}  Recall:{2}  IoU:{3} '.format(epoch, precision, recall, IoU))print(f"第{epoch}次迭代的验证精确度为{precision},验证召回率为{recall},验证交并比为{IoU}")return precision, recall, IoU

最后,我们可以看到将 dataloader_val验证集数据加载器 传入 函数validate_one_epoch 中,通过遍历 dataloader 中的数据,可以通过 自定义类BipedDataset 返回的包含三个元素的字典来获取图像数据、对应的标签数据以及图像文件名,如下图所示

 images = sample_batched['images'].to(device)labels = sample_batched['labels'].to(device)file_names = sample_batched['file_names']   

综上所述, 就是关于如何基于PyTorch深度学习框架自定义数据集来获取数据的详细步骤了,如果你觉得有用,麻烦点赞关注一下哈,谢谢!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/586190.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

R503S指纹识别模块的通信协议

1 物理层 物理层采用串口通讯,波特率 57600,8 位数据位,1 位停止位,无奇偶校验位。 2 数据包格式 模块采用 UART 与上位机通讯,对命令、数据、结果的接收和发送,都采用数据包的形式。对于多字节的&#x…

用CSS中的动画效果做一个转动的表

<!DOCTYPE html> <html lang"en"><head><meta charset"utf-8"><title></title><style>*{margin:0;padding:0;} /*制作表的样式*/.clock{width: 500px;height: 500px;margin:0 auto;margin-top:100px;border-rad…

JavaScript(注释,数据类型,运算符,条件语句)

一 注释 1.1 单行注释 //这是单行注释 1.2 多行注释 /*这是多行注释*/ 1.3 嵌套在HTML文件中注释 <!--注释--> 1.4 注释的快捷键 ctrl/ 二 JavaScript输出方式 2.1 在浏览器中展示对话框&#xff0c;弹出要展…

强化学习计划

文章目录 强化学习强化学习解决的是什么样的问题&#xff1f;举出强化学习与有监督学习的异同点。有监督学习靠样本标签训练模型&#xff0c;强化学习靠的是什么&#xff1f;强化学习的损失函数&#xff08;loss function&#xff09;是什么&#xff1f;写贝尔曼方程&#xff0…

【Linux Shell学习笔记】Linux Shell的流控制

1、 if条件判断 1.1 格式 1.1.1 单分支 if [ 判断表达式 ];then 代码块 fi 1.1.2 双分支 if [ 判断表达式 ];then 代码1 else 代码2 fi 1.1.3 多分支 if [ 判断表达式1 ];then 代码1 elif [ 判断表达式2 ];then 代码2 elif [ 判断表达式3 ];then 代…

【数据结构】双向带头循环链表的实现

前言&#xff1a;在前面我们学习了顺序表、单向链表&#xff0c;今天我们在单链表的基础上进一步来模拟实现一个带头双向链表。 &#x1f496; 博主CSDN主页:卫卫卫的个人主页 &#x1f49e; &#x1f449; 专栏分类:数据结构 &#x1f448; &#x1f4af;代码仓库:卫卫周大胖的…

USB -- STM32F103复合设备(HID+MassStorage)传输讲解(十)

目录 链接快速定位 前沿 1 描述符讲解 1.1 设备描述符 1.2 配置描述符 1.3 接口描述符 1.4 功能描述符 1.5 端点描述符 1.6 字符串描述符 1.7 报告描述符 2 运行演示 链接快速定位 USB -- 初识USB协议&#xff08;一&#xff09; 源码下载请参考链接&#xff1a;…

修改字符串(c++题解)

题目描述 给你一个长度为 的字符串 &#xff0c;由大写和小写英文字母组成。 对字符串 进行 次修改。由两个整数和一个字符组成的元组 表示 -th 修改 &#xff0c;如下所示。 如果是&#xff0c;则将的个字符改为。如果是 &#xff0c;将 中的所有大写字母转换为小写字…

java中PhantomReference WeakReference SoftReference垃圾回收触发时机以及使用场景

java 中对象引用一般引用分为四种情况 强引用 即我们平常创建的对象 Object obj new Object() 垃圾回收触发时机 在没设置 jvm 参数 -XX:PretenureSizeThreshold 和 -XX:MaxTenuringThreshold 的情况下 -XX:PretenureSizeThreshold 的值为 0&#xff0c;即未设置大对象直接…

三巨头对决:深入了解pnpm、yarn与npm

欢迎来到我的博客&#xff0c;代码的世界里&#xff0c;每一行都是一个故事 三巨头对决&#xff1a;深入了解pnpm、yarn与npm 前言包管理器简介npm&#xff08;Node Package Manager&#xff09;&#xff1a;Yarn&#xff1a;pnpm&#xff08;Performant Npm&#xff09;&#…

Linux 服务器安全策略技巧:使用数字证书进行认证

什么是数字证书? 数字证书是一种用于验证和加密网络通信的安全工具。它是由认证机构(CA)颁发的一种电子文件,用于证明某个实体的身份。数字证书包含了实体的公钥和其他相关信息,可以用于验证实体的身份和确保通信的机密性。 为什么使用数字证书进行认证? 在Linux服务器…

基于Mapify的在线艺术地图设计

地图是传递空间信息的有效载体&#xff0c;更加美观、生动的地图产品也是我们追求目标。 那么&#xff0c;我们如何才能制出如下图所示这样一幅艺术性较高的地图呢&#xff1f;今天我们来一探究竟吧&#xff01; 按照惯例&#xff0c;现将网址给出&#xff1a; https://www.m…

微信小程序实现一个天气预报应用程序

微信小程序实现一个天气预报应用程序 第一步创建一个项目第二步项目目录下找到 pages/index/index.wxml 文件第三步在 pages/index/index.wxss 文件中写入样式第四步在 pages/index/index.js 文件中添加以下代码项目简介 第一步创建一个项目 第二步项目目录下找到 pages/index…

在 Python 中编写循环Loops的艺术

在 Python 中编写循环Loops的艺术(The Art of Writing Loops in Python) 文章目录 在 Python 中编写循环Loops的艺术(The Art of Writing Loops in Python)一次获取索引Indexes和值Values通过 Product 函数避免嵌套循环Nested Loops使用 Itertools 模块编写花式循环进行无限循环…

SpringBoot知识

1、Spring和SpringBoot对比 2、版本调整 &#xff08;1&#xff09;先排除是否是JDK与SpringBoot的版本不一致导致的&#xff1a;如JDK1.8和SpringBoot3.1.5冲突&#xff1b; &#xff08;2&#xff09;调整编译版本 &#xff08;3&#xff09;调整maven的jdk &#xff08;4&…

Vscode运行调试文件

文章目录 vscode调试运行流程vscode 执行报错settings.json成功截图 vscode调试运行流程 vscode左侧菜单栏点击运行调试icon&#xff0c;点击菜单右侧栏运行和调试按钮&#xff0c;选择node调试器&#xff0c;js文件行数左边点击添加红色断点&#xff0c;运行当前文件 vscode…

【docker实战】01 Linux上docker的安装

Docker CE是免费的Docker产品的新名称&#xff0c;Docker CE包含了完整的Docker平台&#xff0c;非常适合开发人员和运维团队构建容器APP。 Ubuntu 14.04/16.04&#xff08;使用 apt-get 进行安装&#xff09; # step 1: 安装必要的一些系统工具 sudo apt-get update sudo ap…

湘潭大学-2023年下学期-c语言-作业0x0a-综合1

A 求最小公倍数 #include<stdio.h>int gcd(int a,int b) {return b>0?gcd(b,a%b):a; }int main() {int a,b;while(~scanf("%d%d",&a,&b)){if(a0&&b0) break;printf("%d\n",a*b/gcd(a,b));}return 0; }记住最大公约数的函数&…

gitee上的vue大屏项目

在 Gitee 上,有几个值得注意的 Vue 大屏项目:vue-big-screen-plugin (Gitee): 这是一个基于 Vue3、Typescript、DataV 和 ECharts5 框架的可视化大屏项目。它使用 .vue 和 .tsx 文件构建界面,并采用新版动态屏幕适配方案。这个项目支持数据的动态刷新渲染,内部的 DataV 和 …

linux 网络系统管理 技能大赛 mail赛题配置

比赛 Postfix sdskill.org 的邮件发送服务器 支持smtps(465)协议连接&#xff0c;使用Rserver颁发的证书,证书路径/CA/cacert.pem; 创建邮箱账户“user1~user99”&#xff08;共99个用户&#xff09;&#xff0c;密码为Chinaskill20! Dovecot sdskill.org 的邮件接收服务…