【加载数据--自定义自己的Dataset类】

【加载数据自定义自己的Dataset类】

  • 1 加载数据
  • 2 数据转换
  • 3 自定义Dataset类
  • 4 划分训练集和测试集
  • 5 提取一批次数据并绘制样例图

假设有四种天气图片数据全部存放与一个文件夹中,如下图所示:

├─dataset2
│      cloudy1.jpg
│      cloudy10.jpg
│      cloudy100.jpg
│      cloudy101.jpg
│      cloudy102.jpg
│      cloudy103.jpg
│      cloudy104.jpg
│      cloudy105.jpg
......

1 加载数据

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
import torchvision
import glob
from torchvision import transforms
from torch.utils.data import Dataset
from PIL import Imageimport glob
img_dir = r'./dataset2/*.jpg'
imgs = glob.glob(img_dir) # 读取所有图片路径
print(imgs[:3]) # 打印前3张图片species = ['cloudy', 'rain', 'shine', 'sunrise']species_to_idx = dict((c, i) for i, c in enumerate(species))		# 建立类别和序号字典
print(species_to_idx)idx_to_species = dict((v, k) for k, v in species_to_idx.items())	# 反转类别和序号
print(idx_to_species)

输出如下:

['./dataset2\\cloudy1.jpg','./dataset2\\cloudy10.jpg','./dataset2\\cloudy100.jpg']{'cloudy': 0, 'rain': 1, 'shine': 2, 'sunrise': 3}{0: 'cloudy', 1: 'rain', 2: 'shine', 3: 'sunrise'}

读取路径加载序号作为标签

labels = []
for img in imgs:for i, c in enumerate(species):if c in img:labels.append(i)print(labels[:3])

输出如下:

[0, 0, 0]

方法1:提前划分训练集和测试集,使用乱序后的index进行划分

np.random.seed(2022)
index = np.random.permutation(count)
imgs = np.array(imgs)[index]
labels = np.array(labels, dtype=np.int64)[index]sep = int(count*0.8)
train_imgs = imgs[ :sep]
train_labels = labels[ :sep]
test_imgs = imgs[sep: ]
test_labels = labels[sep: ]

2 数据转换

transforms = transforms.Compose([transforms.Resize((96, 96)),transforms.ToTensor(),transforms.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5])
])

3 自定义Dataset类

class WT_dataset(Dataset):def __init__(self, imgs_path, lables):self.imgs_path = imgs_pathself.lables = lablesdef __getitem__(self, index):img_path = self.imgs_path[index]lable = self.lables[index]pil_img = Image.open(img_path)pil_img = pil_img.convert("RGB")pil_img = transforms(pil_img)return pil_img, labledef __len__(self):return len(self.imgs_path)# 加载数据
dataset = WT_dataset(imgs, labels)

4 划分训练集和测试集

count = len(dataset)
print(count)# 方法2:划分训练集和测试集
train_count = int(0.8*count)
test_count = count - train_count
train_dataset, test_dataset = data.random_split(dataset, [train_count, test_count])
print(len(train_dataset), len(test_dataset))# 批量加载数据
BTACH_SIZE = 16
train_dl = torch.utils.data.DataLoader(train_dataset,batch_size=BTACH_SIZE,shuffle=True
)test_dl = torch.utils.data.DataLoader(test_dataset,batch_size=BTACH_SIZE,
)

5 提取一批次数据并绘制样例图

imgs, labels = next(iter(train_dl))	#提取一批次数据
print(imgs.shape)
im = imgs[0].permute(1, 2, 0)	# 将通道所在列放在后
print(im.shape)plt.figure(figsize=(12, 8))
for i, (img, label) in enumerate(zip(imgs[:6], labels[:6])):img = (img.permute(1, 2, 0).numpy() + 1)/2plt.subplot(2, 3, i+1)plt.title(idx_to_species.get(label.item()))plt.imshow(img)
plt.savefig('pics/example1.jpg', dpi=400)

输出如下:

torch.Size([16, 3, 96, 96])torch.Size([3, 96, 96])torch.Size([96, 96, 3])

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/90523.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

物联网、工业大数据平台 TDengine 与苍穹地理信息平台完成兼容互认证

当前,在政府、军事、城市规划、自然资源管理等领域,企业对地理信息的需求迅速增加,人们需要更有效地管理和分析地理数据,以进行决策和规划。在此背景下,“GIS 基础平台”应运而生,它通常指的是一个地理信息…

FL Studio21.1电脑试用体验版音乐制作软件

我一直以来对音乐艺术都很感兴趣。最近我接触到了一款名为 FL Studio 的电脑版音乐制作软件,深感其强大功能和广泛适用性。通过使用这款软件,我不仅深入了解了音乐制作的过程与技巧,也加深了对音乐创作的理解。 FL Studio 最初是一款针对 MI…

四川玖璨电子商务有限公司抖音培训引领电商新潮

近年来,随着电子商务的迅猛发展,抖音这个社交媒体平台也逐渐成为了商家必争之地。四川玖璨电子商务有限公司抖音培训,为你解锁电商流量密码,助你一飞冲天! 一、抖音电商:下一个电商蓝海 作为拥有海量用户的…

爬虫抓取数据时显示超时,是爬虫IP质量问题?

当我们进行网络爬虫开发时,有时会遇到抓取数据时出现超时的情况。这可能是由于目标网站对频繁请求做了限制,或者是由于网络环境不稳定造成的。其中,爬虫IP的质量也是导致超时的一个重要因素。本文将探讨抓取数据时出现超时的原因,…

前端开发 vs. 后端开发:编程之路的选择

文章目录 前端开发:用户界面的创造者1. HTML/CSS/JavaScript:2. 用户体验设计:3. 响应式设计:4. 前端框架: 后端开发:数据和逻辑的构建者1. 服务器端编程:2. 数据库:3. 安全性&#…

删除有序数组里的重复项 -力扣(Java)

给你一个 非严格递增排列 的数组 nums ,请你 原地 删除重复出现的元素,使每个元素 只出现一次 ,返回删除后数组的新长度。元素的 相对顺序 应该保持 一致 。然后返回 nums 中唯一元素的个数。 考虑 nums 的唯一元素的数量为 k ,你…

深度学习入门教学——对抗攻击和防御

目录 一、对抗样本 二、对抗攻击 三、对抗防御 一、对抗样本 对抗样本是指对机器学习模型的输入做微小的故意扰动,导致模型输出结果出现错误的样本。深度神经网络在经过大量数据训练后,可以实现非常复杂的功能。在语音识别、图像识别、自然语言处理等任务上被广…

整型提升——(巩固提高——字符截取oneNote笔记详解)

文章目录 前言一、整型提升是什么?二、详细图解1.图解展示 总结 前言 提示:这里可以添加本文要记录的大概内容: 整型提升是数据存储的重要题型,也是计算机组成原理的核心知识点。学习c语言进阶的时候,了解内存中数据怎么存&#…

求和——快速幂

# 求和 ## 题目描述 求 1^b2^b…… a^b 的和除以 10^4 的余数。 ## 输入格式 第一行一个整数 N,表示共有 N 组测试数据。 对于每组数据,一行两个整数 a,b。 ## 输出格式 对于每组数据,一行一个整数,表示答案。 ### 样例输入 …

APS手动编译,CLion测试

一、简介 APSI——Asymmetric PSI: 私用集交集(PSI)是指这样一种功能,即双方都持有一组私用项,可以在不向对方透露任何其他信息的情况下检查他们有哪些共同项。集合大小的上限被假定为公共信息,不受保护。 …

QtCreator报大量未知标识符错误的解决方法

目录 前言背景介绍问题1问题1解决方法问题2问题2 解决方法总结 前言 本文记录了在使用QtCreator开发时遇到的一个错误,导致编译时出现大量的“未知标识符”,经过一番努力最终解决了这个问题,特在此记录。 背景介绍 Qt项目在麒麟V10 系统下…

【DTEmpower案例操作教程】向导式建模

DTEmpower是由天洑软件自主研发的一款通用的智能数据建模软件,致力于帮助工程师及工科专业学生,利用工业领域中的仿真、试验、测量等各类数据进行挖掘分析,建立高质量的数据模型,实现快速设计评估、实时仿真预测、系统参数预警、设…

X509证书结构

使用ASN.1语言描述,我们可以将X509Certificate抽象为以下结构: Certificate :: SEQUENCE {tbsCertificate TBSCertificate,signatureAlgorithm AlgorithmIdentifier,signature BIT STRING }即基本证书域、签名算法、签名值。 其…

手机上记录的备忘录内容怎么分享到电脑上查看?

手机已经成为了我们生活中不可或缺的一部分,我们用它来处理琐碎事务,记录生活点滴,手机备忘录就是我们常用的工具之一。但随着工作的需要,我们往往会遇到一个问题:手机上记录的备忘录内容,如何方便地分享到…

设计模式——3. 抽象工厂模式

1. 说明 抽象工厂模式(Abstract Factory Pattern)是一种创建型设计模式,它提供了一种创建一组相关或依赖对象的方式,而无需指定它们的具体类。抽象工厂模式是工厂模式的扩展,它关注于创建一组相关的对象家族,而不仅仅是一个单一的对象。 抽象工厂模式通常涉及以下几个角…

微信小游戏从零到上线系列文章整理,建议收藏

引言 本系列是《从零开始开发贪吃蛇小游戏到上线系列》,欢迎大家关注分享收藏订阅。 大家中秋快乐,我是亿元程序员,一位有着8年游戏行业经验的主程。前面笔者给大家讲解了微信小游戏如何从零到上线的流程。可能很多小伙伴都还没有看到。 本…

【Oracle】Oracle系列之十一--PL/SQL

文章目录 往期回顾前言1. PL/SQL语句块组成2. 变量的声明与使用(1)变量声明(2)变量赋值 3. 控制语句(1)分支语句(2)循环语句 4. 异常处理(1)系统异常&#xf…

某高校的毕设

最近通过某个平台接的单子,最后Kali做的测试没有公开可以私聊给教程。 下面是规划与配置 1.vlan方面:推荐一个vlan下的所有主机为一个子网网段 连接电脑和http客户端的接口配置为access接口 交换机与交换机或路由器连接的接口配置为trunk接口---也可以…

Golang中的类型转换介绍

Golang中存在4种类型转换,分别是:断言、显式、隐式、强制。下面我将一一介绍每种转换使用场景和方法 一、断言类型转换 主要是判断变量是否可以转换成某一类型。断言主要用于变量是interface{}类型(接口类型)的情况,…

Python-表白小程序练习

测试代码 在结果导向的今天,切勿眼高于顶,不论用任何方法能转换、拿出实际成果东西才是关键,即使一个制作很简易的程序,你想将其最终生成可运行的版本也是需要下一番功夫的。不要努力成为一个嘴炮成功者,要努力成为一个有价值的人…