如何使用pytorch的Dataset, 来定义自己的Dataset

Dataset与DataLoader的关系

在这里插入图片描述
在这里插入图片描述

  1. Dataset: 构建一个数据集,其中含有所有的数据样本
  2. DataLoader:将构建好的Dataset,通过shuffle、划分batch、多线程num_workers运行的方式,加载到可训练的迭代容器。
import torch
from torch.utils.data import Dataset, DataLoaderclass MyDataset(Dataset):"""创建自己的数据集"""def __init__(self):"""初始化构建数据集所需要的参数"""passdef __getitem__(self, index):"""来获取数据集中样本的索引"""passdef __len__(self):"""获取数据集中的样本个数"""pass# 实例化自定义的数据集
dataset = MyDataset()
# 将自定义的数据集加载到可训练的迭代容器
train_loader = DataLoader(dataset=dataset,  # 自定义的数据集batch_size=32,  # 数据集中小批量的大小shuffle=True,  # 是否要打乱数据集中样本的次序num_workers=2)  # 是否要并行

实战1:CSV数据集(结构化数据集)

import torch
import numpy as np
from torch.utils.data import Dataset, DataLoaderclass MyDataset(Dataset):"""创建自己的数据集"""def __init__(self, filepath):"""初始化构建数据集所需要的参数"""xy = np.loadtxt(filepath, delimiter=',', dtype=np.float32)self.len = xy.shape[0]  # 查看数据集中样本的个数self.x_data = torch.from_numpy(xy[:, :-1])self.y_data = torch.from_numpy(xy[:, [-1]])print("数据已准备好......")def __getitem__(self, index):"""为了支持下标操作, 即索引dataset[index]:来获取数据集中样本的索引"""return self.x_data[index], self.y_data[index]def __len__(self):"""为了使用len(dataset):获取数据集中的样本个数"""return self.lenfile = "D:\\BaiduNetdiskDownload\\Dataset_Dataload\\diabetes1.csv"""" 1.使用 MyDataset类 构建自己的dataset """
mydataset = MyDataset(file)
""" 2.使用 DataLoader 构建train_loader """
train_loader = DataLoader(dataset=mydataset,batch_size=32,shuffle=True,num_workers=0)class MyModel(torch.nn.Module):"""定义自己的模型"""def __init__(self):super().__init__()self.linear1 = torch.nn.Linear(8, 6)self.linear2 = torch.nn.Linear(6, 4)self.linear3 = torch.nn.Linear(4, 1)self.sigmooid = torch.nn.Sigmoid()def forward(self, x):x = self.sigmooid(self.linear1(x))x = self.sigmooid(self.linear2(x))x = self.sigmooid(self.linear3(x))return x# 实例化模型
model = MyModel()# 定义损失函数
criterion = torch.nn.BCELoss(size_average=True)
# 定义优化器
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)if __name__ == "__main__":for epoch in range(10):for i, data in enumerate(train_loader, 0):# 1. 准备数据inputs, labels = data# 2. 前向传播y_pred= model(inputs)loss = criterion(y_pred, labels)print(epoch, i, loss.item())# 3. 反向传播optimizer.zero_grad()loss.backward()# 4. 梯度更新optimizer.step()

在这里插入图片描述

实战2:图片数据集

├── flower_data
—├── flower_photos(解压的数据集文件夹,3670个样本)
—├── train(生成的训练集,3306个样本)
—└── val(生成的验证集,364个样本)

主函数文件main.py
import osimport torch
from torchvision import transformsfrom my_dataset import MyDataSet
from utils import read_split_data, plot_data_loader_imageroot = "../data/flower_data/flower_photos"  # 数据集所在根目录def main():device = torch.device("cuda" if torch.cuda.is_available() else "cpu")print("using {} device.".format(device))train_images_path, train_images_label, val_images_path, val_images_label = read_split_data(root)data_transform = {"train": transforms.Compose([transforms.RandomResizedCrop(224),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),"val": transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}train_data_set = MyDataSet(images_path=train_images_path,images_class=train_images_label,transform=data_transform["train"])batch_size = 8nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workersprint('Using {} dataloader workers'.format(nw))train_loader = torch.utils.data.DataLoader(train_data_set,batch_size=batch_size,shuffle=True,num_workers=nw,collate_fn=train_data_set.collate_fn)# plot_data_loader_image(train_loader)for epoch in range(100):for step, data in enumerate(train_loader):images, labels = data# 然后在进行相应的训练操作即可if __name__ == '__main__':main()
自定义数据集文件my_dataset.py
from PIL import Image
import torch
from torch.utils.data import Datasetclass MyDataSet(Dataset):"""自定义数据集"""def __init__(self, images_path: list, images_class: list, transform=None):self.images_path = images_pathself.images_class = images_classself.transform = transformdef __len__(self):return len(self.images_path)def __getitem__(self, item):img = Image.open(self.images_path[item])# RGB为彩色图片,L为灰度图片if img.mode != 'RGB':raise ValueError("image: {} isn't RGB mode.".format(self.images_path[item]))label = self.images_class[item]if self.transform is not None:img = self.transform(img)return img, label@staticmethoddef collate_fn(batch):# 官方实现的default_collate可以参考# https://github.com/pytorch/pytorch/blob/67b7e751e6b5931a9f45274653f4f653a4e6cdf6/torch/utils/data/_utils/collate.pyimages, labels = tuple(zip(*batch))images = torch.stack(images, dim=0)labels = torch.as_tensor(labels)return images, labels
功能文件utils.py(训练集、验证集的划分与可视化)
import os
import json
import pickle
import randomimport matplotlib.pyplot as pltdef read_split_data(root: str, val_rate: float = 0.2):random.seed(0)  # 保证随机结果可复现assert os.path.exists(root), "dataset root: {} does not exist.".format(root)  # 判断路径是否存在# 遍历文件夹,一个文件夹对应一个类别flower_class = [cla for cla in os.listdir(root) if os.path.isdir(os.path.join(root, cla))]# 排序,保证顺序一致flower_class.sort()# 生成类别名称以及对应的数字索引: 字典{’花名‘:0,’花名‘:1,···}class_indices = dict((k, v) for v, k in enumerate(flower_class))json_str = json.dumps(dict((val, key) for key, val in class_indices.items()), indent=4)  # 将花名与对应的序号分行保存with open('class_indices.json', 'w') as json_file:json_file.write(json_str)train_images_path = []  # 存储训练集的所有图片路径train_images_label = []  # 存储训练集图片对应索引信息val_images_path = []  # 存储验证集的所有图片路径val_images_label = []  # 存储验证集图片对应索引信息every_class_num = []  # 存储每个类别的样本总数supported = [".jpg", ".JPG", ".png", ".PNG"]  # 支持的文件后缀类型# 遍历每个文件夹下的文件for cla in flower_class:cla_path = os.path.join(root, cla)# 遍历获取supported支持的所有文件路径images = [os.path.join(root, cla, i) for i in os.listdir(cla_path)if os.path.splitext(i)[-1] in supported]# 获取该类别对应的索引image_class = class_indices[cla]# 记录该类别的样本数量every_class_num.append(len(images))# 按比例随机采样验证样本val_path = random.sample(images, k=int(len(images) * val_rate))for img_path in images:if img_path in val_path:  # 如果该路径在采样的验证集样本中则存入验证集val_images_path.append(img_path)val_images_label.append(image_class)else:  # 否则存入训练集train_images_path.append(img_path)train_images_label.append(image_class)print("{} images were found in the dataset.".format(sum(every_class_num)))print("{} images for training.".format(len(train_images_path)))print("{} images for validation.".format(len(val_images_path)))plot_image = Trueif plot_image:# 绘制每种类别个数柱状图plt.bar(range(len(flower_class)), every_class_num, align='center')# 将横坐标0,1,2,3,4替换为相应的类别名称plt.xticks(range(len(flower_class)), flower_class)# 在柱状图上添加数值标签for i, v in enumerate(every_class_num):plt.text(x=i, y=v + 5, s=str(v), ha='center')# 设置x坐标plt.xlabel('image class')# 设置y坐标plt.ylabel('number of images')# 设置柱状图的标题plt.title('flower class distribution')plt.show()return train_images_path, train_images_label, val_images_path, val_images_labeldef plot_data_loader_image(data_loader):batch_size = data_loader.batch_sizeplot_num = min(batch_size, 4)json_path = './class_indices.json'assert os.path.exists(json_path), json_path + " does not exist."json_file = open(json_path, 'r')class_indices = json.load(json_file)for data in data_loader:images, labels = datafor i in range(plot_num):# [C, H, W] -> [H, W, C]img = images[i].numpy().transpose(1, 2, 0)# 反Normalize操作img = (img * [0.229, 0.224, 0.225] + [0.485, 0.456, 0.406]) * 255label = labels[i].item()plt.subplot(1, plot_num, i+1)plt.xlabel(class_indices[str(label)])plt.xticks([])  # 去掉x轴的刻度plt.yticks([])  # 去掉y轴的刻度plt.imshow(img.astype('uint8'))plt.show()def write_pickle(list_info: list, file_name: str):with open(file_name, 'wb') as f:pickle.dump(list_info, f)def read_pickle(file_name: str) -> list:with open(file_name, 'rb') as f:info_list = pickle.load(f)return info_list

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/639277.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Qt6入门教程 9:QWidget、QMainWindow和QDialog

目录 一.QWidget 1.窗口和控件 2.事件 二.QMainWindow 三.QDialog 1.模态对话框 1.1模态对话框 1.2.半模态对话框 2.非模态对话框 在用Qt Creator创建Qt Widgets项目时,会默认提供三种基类以供选择,它们分别是QWidget、QMainWIndow和QDialog&am…

SQL 注入总结(详细)

一、前言 这篇文章是最近学习 SQL 注入后的笔记,里面整理了 SQL 常见的注入方式,供大家学习了解 SQL 注入的原理及方法,也方便后续自己回顾,如有什么错误的地方欢迎指出! 二、判断注入类型 按照注入点类型分类 数字型…

外贸自建站如何建立?海洋建站的操作指南?

外贸自建站的建站流程什么?做跨境怎么搭建外贸网站? 外贸自建站成为企业开拓国际市场、提升品牌形象的重要途径。然而,对于许多企业而言,如何高效地进行外贸自建站仍然是一个挑战。海洋建站将带您一步步探讨外贸自建站的关键步骤…

计算机网络——面试问题

1 从输⼊ URL 到⻚⾯展示到底发⽣了什么? 1. 先检查浏览器缓存⾥是否有缓存该资源,如果有直接返回;如果没有进⼊下⼀ 步⽹络请求。 2. ⽹络请求前,进⾏ DNS 解析 ,以获取请求域名的 IP地址 。 3. 浏览器与服务器…

《WebKit 技术内幕》之七(3): 渲染基础

3 渲染方式 3.1 绘图上下文(GraphicsContext) 上面介绍了WebKit的内部表示结构,RenderObject对象知道如何绘制自己,但是,问题是RenderObject对象用什么来绘制内容呢?在WebKit中,绘图操作被定…

xcode 设置 ios苹果图标,为Flutter应用程序配置iOS图标

图标设置 1,根据图片构建各类尺寸的图标2.xcode打开ios文件3.xcode设置图标4.打包提交审核,即可(打包教程可通过我的主页查找) 1,根据图片构建各类尺寸的图标 工具网址:https://icon.wuruihong.com/ 下载之后文件目录如下 拷贝到项目的ios\Runner\Assets.xcassets\AppIcon.ap…

java简单的抽奖工具类(含测试方法)

文章目录 结果代码 结果 代码 import lombok.AllArgsConstructor; import lombok.Data; import lombok.ToString;import java.util.ArrayList; import java.util.List;/****/ public class LotteryUtils {public static void main(String[] args) throws InterruptedException…

PythonNet,Csharp如何白嫖Python生态和使用Matplotlib

文章目录 前言PythonNet环境配置Python环境配置Csharp Nuget配置运行代码测试运行结果 总结 前言 我既然用Csharp去尝试学习机器视觉,我就想试试用Csharp去使用Python的库。 这个世界上有没有编程语言既有Python的开发效率,又有C/C/ PythonNet Pythonne…

Android:JNI实战,加载三方库、编译C/C++

一.概述 Android Jni机制让开发者可以在Java端调用到C/C,也是Android应用开发需要掌握的一项重要的基础技能。 计划分两篇博文讲述Jni实战开发。 本篇主要从项目架构上剖析一个Android App如何通过Jni机制加载三方库和C/C文件。 二.Native C Android Studio可…

精准核酸检测 - 华为OD统一考试

OD统一考试(C卷) 分值: 100分 题解: Java / Python / C 题目描述 为了达到新冠疫情精准防控的需要,为了避免全员核酸检测带来的浪费,需要精准圈定可能被感染的人群。 现在根据传染病流调以及大数据分析&a…

【代码实战】从0到1实现transformer

获取数据 import pathlibimport tensorflow as tf# download dataset provided by Anki: https://www.manythings.org/anki/ text_file tf.keras.utils.get_file(fname"fra-eng.zip",origin"http://storage.googleapis.com/download.tensorflow.org/data/fra-…

transdata笔记:手机数据处理

1 mobile_stay_duration 每个停留点白天和夜间的持续时间 transbigdata.mobile_stay_duration(staydata, col[stime, etime], start_hour8, end_hour20) 1.1 主要参数 staydata停留数据(每一行是一条数据)col 列名,顺序为[‘starttime’,…

[足式机器人]Part2 Dr. CAN学习笔记- 最优控制Optimal Control Ch07-2 动态规划 Dynamic Programming

本文仅供学习使用 本文参考: B站:DR_CAN Dr. CAN学习笔记 - 最优控制Optimal Control Ch07-2 动态规划 Dynamic Programming 1. 基本概念2. 代码详解3. 简单一维案例 1. 基本概念 Richoard Bell man 最优化理论: An optimal policy has the …

纯C无操作系统轻量协程库Protothread使用记录

文章目录 目的源码说明使用演示总结 目的 在单片机开发中很多时候都是无操作系统环境,这时候如果要实现异步操作,并且流程逻辑比较复杂时处理起来会稍稍麻烦。这时候可以试试 Protothread 这个协程库。 官网: https://dunkels.com/adam/pt/…

深入剖析:Kafka流数据处理引擎的核心面试问题解析75问(5.7万字参考答案)

Kafka 是一款开源的分布式流处理平台,被广泛应用于构建实时数据管道、日志聚合、事件驱动的架构等场景。本文将深入探究 Kafka 的基本原理、特点以及其在实际应用中的价值和作用。 Kafka 的基本原理是建立在发布-订阅模式之上的。生产者将消息发布到主题&#xff08…

37-WEB漏洞-反序列化之PHPJAVA全解(上)

WEB漏洞-反序列化之PHP&JAVA全解(上) 一、PHP 反序列化原理二、案例演示2.1、无类测试2.1.1、本地2.1.2、CTF 反序列化小真题2.1.3、CTF 反序列化类似题 2.2、有类魔术方法触发2.2.1、本地2.2.2、网鼎杯 2020 青龙大真题 三、参考资料 一、PHP 反序列…

SpringMVC(八)处理AJAX请求

一、处理AJAX之准备工作: 首先我们创建一个新的工程: 我们将pom.xml复制过来: <?xml version="1.0" encoding="UTF-8"?> <project xmlns="http://maven.apache.org/POM/4.0.0"xmlns:xsi="http://www.w3.org/2001/XMLSchema-in…

【项目日记(三)】内存池的整体框架设计

&#x1f493;博主CSDN主页:杭电码农-NEO&#x1f493;   ⏩专栏分类:项目日记-高并发内存池⏪   &#x1f69a;代码仓库:NEO的学习日记&#x1f69a;   &#x1f339;关注我&#x1faf5;带你做项目   &#x1f51d;&#x1f51d; 开发环境: Visual Studio 2022 项目日…

MES管理系统为何成为汽配企业的刚需

随着经济全球化、产品定制化及安全法规的严格化&#xff0c;汽配企业的经营环境变得越来越复杂。中国劳动力资源和原辅料成本的持续上升&#xff0c;导致行业利润率不断下滑。为了应对这些挑战&#xff0c;汽配企业需要引入一种精益制造和管理的工具&#xff0c;而MES管理系统正…

四款通用组织架构图模板-一键高清导出

组织架构图作为一种直观的图形化工具&#xff0c;能够帮助我们更好地理解和规划组织结构&#xff0c;提高工作效率。今天&#xff0c;我们就为大家带来四款通用组织架构图模板&#xff0c;让你一键高清导出&#xff0c;轻松搞定组织架构设计&#xff01; 第一款&#xff1a;某基…