卷积神经网络|制作自己的Dataset

在编写代码训练神经网络之前,导入数据是必不可少的。PyTorch提供了许多预加载的数据集(如FashionMNIST),这些数据集 子类并实现特定于特定数据的函数。 

它们可用于对模型进行原型设计和基准测试,加载这些数据集是十分简单的。好吧,那如何加载自己制作的数据集呢?

简单来讲,自定义数据集类必须实现三个函数:__init__、__len__和__getitem__。下面代码就实现了一个Dataset

import osimport torchfrom torch.utils.data import Datasetfrom torchvision import transformsfrom PIL import Imageimport numpy as npclass MyDataset(Dataset):    def __init__(self, path_file,transform=None,label_transform=None):        self.path_file=path_file        self.imgs=[name for name in os.listdir(path_file)]#获取path_file路径下所有文件名        self.transform = transform        self.label_transform = label_transform    def __len__(self):        return len(self.imgs)    def __getitem__(self, idx):        #get the image        img_path = os.path.join(self.path_file,self.imgs[idx])#获得图片完整路径        image=Image.open(img_path)        image=image.resize((28,28))#修改图片为默认大小        image = np.array(image)        image=torch.from_numpy(image)#将numpy数组转换为张量        image=image.permute(2,0,1)#将H,W,C转换为C,H,W        if self.transform:            image = self.transform(image)        #get the label        str1=self.imgs[idx].split('.')        label=torch.tensor(eval(str1[1]))        if self.label_transform:            label=self.label_transform(label)         return image, label

注:上述代码从路径path_file读取文件,准确来讲应该是我们准备的训练图片,格式如下:     

                 cat1.0.jpg

                  cat2.0.jpg

                  ...

                  dog1.1.jpg

                  dog2.1.jpg

                  ...

图片名重要含义:类别(0,1等)

而cat1,dog1这些并不重要,因为0,1,已经反映了图片的类别,这里仅仅是一个习惯,同样jpg也是如此。

实际上,在我们准备图片时,图片名往往不是这样,但直接写个简单的文件处理程序便很容易转变为上述格式

之所以这样命名,就是为容易获得图片和对应的类别,也就是实现自己的Dataset。当然,其它还有许多方法,但核心就是加载自己的数据时获得图片和对应的类别。

再次看一下实现自己的Dataset的架构:

class CustomImageDataset(Dataset):    def __init__(self, path_file, transform=None, target_transform=None):        ...        ...        ...    def __len__(self):        return len(...)    def __getitem__(self, idx):        ...        ...        ...        if self.transform:            image = self.transform(image)        if self.label_transform:            label = self.label_transform(label)        return image, label

在训练模型时,我们通常希望 在“小批量”中传递样本,在每个时期重新洗牌数据以减少模型过度拟合,并使用 Python 的 加快数据检索速度。

DataLoader是一个迭代对象,它在一个简单的 API 中为我们抽象了这种复杂性。下面我们将Dataset带入DataLoader.

path="E:\\3-10\\dogandcats\\train"#图片所在目录training_data=MyDataset(path)train_dataloader = torch.utils.data.DataLoader(training_data, batch_size=2, shuffle=True)

让我们run一下:

>>> trainimg,label=next(iter(train_dataloader))>>> trainimg.size()torch.Size([2, 3, 28, 28])>>> label.size()torch.Size([2])

结果符合预期,与在使用pytorch预加载的数据集格式一样!

点点点,赞和在看都在这儿!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/595099.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

阿里云服务器8080端口安全组开通图文教程

阿里云服务器8080端口开放在安全组中放行,Tomcat默认使用8080端口,8080端口也用于www代理服务,阿腾云atengyun.com以8080端口为例来详细说下阿里云服务器8080端口开启教程教程: 阿里云服务器8080端口开启教程 阿里云服务器8080端…

SkyWalking 快速入门

SkyWalking 是一个基于 Java 开发的分布式系统的应用程序性能监视工具,专为微服务、云原生架构和基于容器(Docker、K8s、Mesos)架构而设计。 一、SkyWalking 简介 SkyWalking 是观察性分析平台和应用性能管理系统。 提供分布式追踪、服务网格…

输入输出流

1.输入输出流 输入/输出流类:iostream---------i input(输入) o output(输出) stream:流 iostream: istream类:输入流类-------------cin:输入流类的对象 ostream类…

使用Tensorboard可视化网络结构(基于pytorch)

前言 我们在搭建网络模型的时候,通常希望可以对自己搭建好的网络模型有一个比较好的直观感受,从而更好地了解网络模型的结构,Tensorboard工具的使用就给我们提供了方便的途径 Tensorboard概况 Tensorboard是由Google公司开源的一款可视化工…

【大模型+编程助手】国内编程助手安装与使用(CodeGeeX,Baidu Comate)

百度 Comate (可试用):https://comate.baidu.com/ 清华CodeGeeX (开源,暂时免费):https://codegeex.cn/ 华为:https://devcloud.cn-north-4.huaweicloud.com/codeartside/home?productsnap# 开发平台VScod…

贪吃蛇C语言实现(有源码)

前言 之前学了一点easyx图形库的使用&#xff0c;掌握一些基本用法后就用贪吃蛇来进行实战了&#xff0c;运行视频放在csdn视频那一栏了&#xff0c;之前的烟花也是。 1.头文件 #define _CRT_SECURE_NO_WARNINGS 1 #include<easyx.h> #include<conio.h> #includ…

【Vue2+3入门到实战】(21)认识Vue3、使用create-vue搭建Vue3项目、熟悉项目和关键文件

目录 一、认识Vue31. Vue2 选项式 API vs Vue3 组合式API2. Vue3的优势 二、 使用create-vue搭建Vue3项目1. 认识create-vue2. 使用create-vue创建项目 三、 熟悉项目和关键文件四、总结 一、认识Vue3 1. Vue2 选项式 API vs Vue3 组合式API <script> export default {…

力扣题:高精度运算-1.2

力扣题-1.2 [力扣刷题攻略] Re&#xff1a;从零开始的力扣刷题生活 力扣题1&#xff1a;415. 字符串相加 解题思想&#xff1a;从后往前遍历两个字符串,然后进行相加即可 class Solution(object):def addStrings(self, num1, num2):""":type num1: str:type …

Navicat 技术干货 | 如何查看关系型数据库(MySQL、PostgreSQL、SQL Server、 Oracle)查询的运行时间

在数据库优化中&#xff0c;理解和监控查询运行时间是至关重要的。无论你是数据库管理员、开发人员或是参与性能调优的人员&#xff0c;知道如何查看查询运行时间能为你的数据库操作提供有价值的参考。本文中&#xff0c;我们将探索几款热门的关系数据库&#xff08;如 MySQL、…

ubuntu下快速安装使用docker

一、ubuntu下安装docker 1、命令行终端内直接输入docker 可以看到安装docker的命令提示 2、安装需要注意的几个点 (1)需要管理员权限 (2)更新软件源后再进行安装 命令行输入命令 sudo apt-get update #更新软件源 sudo apt install docker.io #安装docker 如图所示 二…

PostgreSQL荣获DB-Engines 2023年度数据库

数据库流行度排名网站 DB-Engines 2024 年 1 月 2 日发布文章宣称&#xff0c;PostgreSQL 荣获 2023 年度数据库管理系统称号。 PostgreSQL 在过去一年中获得了比其他 417 个产品更多的流行度增长&#xff0c;因此获得了 2023 年度 DBMS。 DB-Engines 通过计算每种数据库 2024 …

性能优化-OpenMP基础教程(一)

本文主要介绍OpenMP并行编程技术&#xff0c;编程模型、指令和函数的介绍、以及OpenMP实战的几个例子。希望给OpenMP并行编程者提供指导。 &#x1f3ac;个人简介&#xff1a;一个全栈工程师的升级之路&#xff01; &#x1f4cb;个人专栏&#xff1a;高性能&#xff08;HPC&am…

Mysql count统计去重的数据

不去重&#xff0c;是4 &#xff1a; SELECT COUNT(NAME) FROM test2 明显里面包含了2个 name 等于 mike的数据&#xff0c; 所以需要做去重 &#xff1a; 通过结合 count 函数和 DISTINCT 关键字 SELECT COUNT(DISTINCT NAME) FROM test2 好了就到这。

消息中间件 —— ActiveMQ 使用及原理详解

目录 一. 前言 二. JMS 规范 2.1. 基本概念 2.2. JMS 体系结构 三. ActiveMQ 使用 3.1. ActiveMQ Classic 和 ActiveMQ Artemis 3.2. Queue 模式&#xff08;P2P&#xff09; 3.3. Topic 模式&#xff08;Pub/Sub&#xff09; 3.4. 持久订阅 3.5. 消息传递的可靠性 …

数模学习day06-主成分分析

主成分分析(Principal Component Analysis,PCA)主成分分析是一种降维算法&#xff0c;它能将多个指标转换为少数几个主成分&#xff0c;这些主成分是原始变量的线性组合&#xff0c;且彼此之间互不相关&#xff0c;其能反映出原始数据的大部分信息。一般来说当研究的问题涉及到…

当hashCode相同时,equals是否也相同?

在Java中&#xff0c;理解对象的这两个基本方法—hashCode和equals对于编码是至关重要的&#xff0c;尤其是在处理集合类如HashMap和HashSet时。然而&#xff0c;一个常见的误解是&#xff0c;如果两个对象有相同的哈希码&#xff08;hashCode&#xff09;&#xff0c;那么它们…

iec104和iec61850

iec104和iec61850 IEC104 规约详细解读(一) 协议结构 IEC104 规约详细解读(二)交互流程以及协议解析 61850开发知识总结与分享【1】 Get the necesarry projects next to each other in the same directory; $ git clone https://github.com/robidev/iec61850_open_server.g…

ES(Elasticsearch)的基本使用

一、常见的NoSQL解决方案 1、redis Redis是一个基于内存的 key-value 结构数据库。Redis是一款采用key-value数据存储格式的内存级NoSQL数据库&#xff0c;重点关注数据存储格式&#xff0c;是key-value格式&#xff0c;也就是键值对的存储形式。与MySQL数据库不同&#xff0…

DNS安全与访问控制

一、DNS安全 1、DNSSEC原理 DNSSEC依靠数字签名保证DNS应答报文的真实性和完整性。权威域名服务器用自己的私有密钥对资源记录&#xff08;Resource Record, RR&#xff09;进行签名&#xff0c;解析服务器用权威服务器的公开密钥对收到的应答信息进行验证。如果验证失败&…

数字信号处理期末复习——计算小题(二)

个人名片&#xff1a; &#x1f981;作者简介&#xff1a;一名喜欢分享和记录学习的在校大学生 &#x1f42f;个人主页&#xff1a;妄北y &#x1f427;个人QQ&#xff1a;2061314755 &#x1f43b;个人邮箱&#xff1a;2061314755qq.com &#x1f989;个人WeChat&#xff1a;V…