Pytorch公共数据集、tensorboard、DataLoader使用

本文将主要介绍torchvision.datasets的使用,并以CIFAR-10为例进行介绍,对可视化工具tensorboard进行介绍,包括安装,使用,可视化过程等,最后介绍DataLoader的使用。希望对你有帮助

Pytorch公共数据集

torchvision.datasets.*
在这里插入图片描述
torchvision是pytorch的一个图形库,torchvision包由流行的数据集、模型架构和计算机视觉的通用图像转换组成。例如tensorboard、transfroms

在这里将主要介绍torchvision.datasets.*

在这里插入图片描述

在datasets中包含了许多公共的应用于图像领域的数据集。包含:图像分类、图像检测或分割、光流法、立体声匹配等

在本章当中,将以图像分类领域的CIFAR10数据集作为torchvision.datasets的例子进行介绍,因为他比较小,下载比较快。

CIFAR-10是一个更接近普适物体的彩色图像数据集。CIFAR-10 是由Hinton 的学生Alex Krizhevsky 和Ilya Sutskever 整理的一个用于识别普适物体的小型数据集。一共包含10 个类别的RGB 彩色图片。

每个图片的尺寸为32 × 32 ,每个类别有6000个图像,数据集中一共有50000 张训练图片和10000 张测试图片。

下面是数据集中的类,以及每个类的10张随机图像

在这里插入图片描述

参数介绍

这些数据集的参数也是大同小异,由于CIFAR10数据集较小,下载就快。大家可以触类旁通

在这里插入图片描述

  • root :即指定数据集要下载在哪一个文件夹里面
  • train(bool):如果True即为训练集,否则False则为测试集
  • transform :进行图像变换的各种操作,如Resize、RandomCrop、Compose
  • target_transform :对于标签进行transform 操作
  • download :是否下载数据集,建议设置为True即可
import torch
import torchvision
from torch.utils.data import Dataset
from torch.utils.tensorboard import SummaryWriter
#transform属性
trans_tool = torchvision.transforms.Compose([torchvision.transforms.ToTensor()  # 转为Tensor类型# torchvision.transforms.Resize((5, 5))  # 进行大小裁剪
])# 数据集划分
tran_dataset = torchvision.datasets.CIFAR10(root="./dataset",train=True,transform=trans_tool,download=True)
test_dataset = torchvision.datasets.CIFAR10(root="./dataset",train=False,transform=trans_tool,download=True)
print(tran_dataset[0])  
#Tensorboard
writer = SummaryWriter("logs")
for i in range(10):#显示测试集前10的图片img, label = tran_dataset[i]writer.add_image("CIFAR10",img,i)
writer.close()
Files already downloaded and verified
Files already downloaded and verified
(tensor([[[0.2314, 0.1686, 0.1961,  ..., 0.6196, 0.5961, 0.5804],[0.0627, 0.0000, 0.0706,  ..., 0.4824, 0.4667, 0.4784],[0.0980, 0.0627, 0.1922,  ..., 0.4627, 0.4706, 0.4275],...,[0.8157, 0.7882, 0.7765,  ..., 0.6275, 0.2196, 0.2078],[0.7059, 0.6784, 0.7294,  ..., 0.7216, 0.3804, 0.3255],[0.6941, 0.6588, 0.7020,  ..., 0.8471, 0.5922, 0.4824]],[[0.2431, 0.1804, 0.1882,  ..., 0.5176, 0.4902, 0.4863],[0.0784, 0.0000, 0.0314,  ..., 0.3451, 0.3255, 0.3412],[0.0941, 0.0275, 0.1059,  ..., 0.3294, 0.3294, 0.2863],...,[0.6667, 0.6000, 0.6314,  ..., 0.5216, 0.1216, 0.1333],[0.5451, 0.4824, 0.5647,  ..., 0.5804, 0.2431, 0.2078],[0.5647, 0.5059, 0.5569,  ..., 0.7216, 0.4627, 0.3608]],[[0.2471, 0.1765, 0.1686,  ..., 0.4235, 0.4000, 0.4039],[0.0784, 0.0000, 0.0000,  ..., 0.2157, 0.1961, 0.2235],[0.0824, 0.0000, 0.0314,  ..., 0.1961, 0.1961, 0.1647],...,[0.3765, 0.1333, 0.1020,  ..., 0.2745, 0.0275, 0.0784],[0.3765, 0.1647, 0.1176,  ..., 0.3686, 0.1333, 0.1333],[0.4549, 0.3686, 0.3412,  ..., 0.5490, 0.3294, 0.2824]]]), 6)

利用tensorboard查看,在控制台输入即可:

tensorboard --logdir 目录

在这里插入图片描述

关于torchvision.datasets.CIFAR10介绍已经讲解完毕,后续内容为扩展内容,包括:tensorboard、DataLoader的使用

tensorboard可视化工具

torch.utils.tensorboard

在Pytorch发布后,网络及训练过程的可视化工具也相应地被开发出来,方便用户监督所建立模型的结构和训练过程

深度学习网络通常具有很深的层次结构,而且层与层之间通常会有并联、串联等连接方式,利用有效的工具将建立的深度学习网络结构有层次化的展示,这就需要使用相关的深度学习网络结构可视化库。

从Pytorch1.1之后,加入了tensorboard

一般安装新版的pytorch会自动安装,如果没安装,则在终端命令行下使用下面命令即可安装

pip install tensorboard
  • add_image()添加图片

  • add_scalar()添加标量数据

主要代码如下

from torch.utils.tensorboard import SummaryWriterwriter = SummaryWriter("logs")  # 创建SummaryWriter,将运行结果存logs文件夹中
for i in range(100):writer.add_scalar("y=2x",2*i,i)  # 第一个参数相当于标题,第二个参数就相当于纵坐标的值,第三个参数就相当于横坐标的值
writer.close()

可视化操作

在终端输入:tensorboard --logdir 目录
在这里插入图片描述

访问:http://localhost:6006即可

在这里插入图片描述

writer.add_image的例子

from PIL import Image
from torchvision import transforms
from torch.utils.tensorboard import SummaryWriterimg_path = r"./pic.png"
# 打开一张图片
img = Image.open(img_path)
# 使用transforms对图像进行变换
# 实例化totensor对象
to_tens = transforms.ToTensor()
# 将pic变成Tensor类型的图片
tens_img = to_tens(img) # 自动调用call函数
#print(tens_img)# 使用上一篇文章中tensorboard进行查看
writer = SummaryWriter("transforms_logs")
writer.add_image("test_transforms",tens_img) # 标题,图像类型
writer.close()

DataLoader的使用

from torch.utils.data import DataLoader

torch的DataLoader主要是用来装载数据,就是给定已知的数据集,把数据集装载进DataLoaer,然后送入深度学习网络进行训练。

在torch.utils.data.DataLoader()参数中,只有dataset为必填参数,其他的均有默认值,下文介绍几个重要的参数

在这里插入图片描述

  • dataset:表示要读取的数据集

  • batch_size:表示每次从数据集中取多少个数据

  • shuffle:表示是否为乱序取出,True表示前后不一样

  • num_workers :表示是否多进程读取数据(默认为0);

  • drop_last : 表示当样本数不能被batchsize整除时(即总数据集/batch_size 不能除尽,有余数时),最后一批数据(余数)是否舍弃(default:
    False)

  • pin_memory: 如果为True会将数据放置到GPU上去(默认为false)

还是以上文的CIFAR10的测试集为例

from torch.utils.data import DataLoader
import torchvision
test_set = torchvision.datasets.CIFAR10("dataset",train=False,transform=torchvision.transforms.ToTensor(),download=True)
Files already downloaded and verified
# 创建DataLoader实例
test_loader = DataLoader(dataset=test_set, # 引入数据集batch_size=4, # 每次取4个数据shuffle=True, # 打乱顺序num_workers=0, # 非多进程drop_last=False # 最后数据(余数)不舍弃
)

利用DataLoader的完整代码如下

import torchvision
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter# 准备测试集
test_set = torchvision.datasets.CIFAR10("dataset",train=False,transform=torchvision.transforms.ToTensor(),download=True)# 创建test_loader实例
test_loader = DataLoader(dataset=test_set, # 引入数据集batch_size=4, # 每次取4个数据shuffle=True, # 打乱顺序num_workers=0, # 非多进程drop_last=False # 最后数据(余数)不舍弃
)img,index = test_set[0]
print(img.shape) # 查看图片大小 torch.Size([3, 32, 32]) C h w,即三通道 32*32
print(index) # 查看图片标签
# 遍历test_loader
for data in test_loader:img,target = dataprint(img.shape) # 查看图片信息torch.Size([4, 3, 32, 32])表示一次4张图片,图片为3通道RGB,大小为32*32print(target)  # tensor([4, 9, 8, 8])表示4张图片的target
# 在tensorboard 中显示
writer = SummaryWriter("logs")
step = 0
for data in test_loader:img, target = datawriter.add_images("test_loader",img,step)step = step+1
writer.close()

tensorboard显示如下

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/117812.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

深度学习系列52:多目标跟踪

1. 评价指标 1)FP:False Positive,即真实情况中没有,但跟踪算法误检出有目标存在。 2)FN:False Negative,即真实情况中有,但跟踪算法漏检了。 3)IDS:ID Sw…

IntelliJ IDEA 2023.2正式发布,新UI和Profiler转正

你好,我是YourBatman:做爱做之事❣交配交之人。 📚前言 北京时间2023年7月26日,IntelliJ IDEA 2023.2正式发布。老规矩,吃肉之前,可以先把这几碗汤干了,更有助于消化(每篇都很顶哦…

mac苹果电脑使用耳机听不到声音

大家在使用耳机收听音乐时候?是否经常遇到声音和音频播放问题的情况。这里小编为大家带来了三种不同的方法,帮助大家解决耳机在macOS系统电脑上怎么听不到任何声音的教程。如果大家对这篇文章感兴趣,那就来看下面的具体步骤吧。 方法一、检查…

【机器学习合集】优化目标与评估指标合集 ->(个人学习记录笔记)

文章目录 优化目标与评估指标1. 优化目标1.1 两类基础任务与常见优化目标1.2 分类任务损失0-1损失交叉熵损失与KL散度softmax损失的理解与改进Hinge损失 1.3 回归任务损失L1/L2距离L1/L2距离的改进 Huber loss 2. 评测指标2.1 分类任务中评测指标准确率(查准率)/召回率(查全率)…

入门人工智能 —— 学习数据持久化、使用 Python 将数据保存到mysql(7)

入门人工智能 —— 学习数据持久化、使用 Python 将数据保存到mysql 什么是数据持久化?使用 Python 进行数据持久化步骤 1: 安装 MySQL步骤 2: 安装必要的 Python 库步骤 3: 连接到 MySQL 数据库步骤 4: 创建数据表步骤 5: 插入数据步骤 6: 查询数据步骤 7: 关闭连接…

类图表示法

设计模式,用设计图表示的话,主要用到类图。常见UML类图如下: 1、类图:矩形框,代表一个类(Class)。类图分为三层,第一层显示类的名称,如果是抽象类,则用斜体显…

mac安装nodejs,跑vue程序

1. 下载node.js for mac,地址:Node.js。一路安装就可以了,无需修改。 2. mac终端,查看node和npm的版本。 3. 配置环境变量, vim .bash_profile增加PATH$PATH:/usr/local/bin/ 4. 但是毕竟npm安装一些东西还是太慢了所…

霍尔电流传感器如何应用在数据中心电量监测的-安科瑞 蒋静

摘要:数据中心供电电源质量的好坏直接影响到IT设备的安全运行,因此对数据中心直流列头柜电源进出线实行监测非常重要,而通过霍尔电流传感器可以采集主进线电流、多路支路直流电流和漏电流。 关键词:数据中心;直流列头…

好用的Visio绘图文件工具 VSD Viewer最新 for mac

VSD Viewer是一款可以查看Microsoft Visio绘图文件的工具,适用于Windows和macOS操作系统。它具有以下优点: 直观易用:VSD Viewer的用户界面非常简单直观,易于使用。支持多种文件格式:VSD Viewer支持多种Visio文件格式…

Rust逆向学习 (2)

文章目录 Guess a number0x01. Guess a number .part 1line 1loopline 3~7match 0x02. Reverse for enum0x03. Reverse for Tuple0x04. Guess a number .part 20x05. 总结 在上一篇文章中,我们比较完美地完成了第一次Rust ELF的逆向工作,但第一次编写的R…

公司电脑屏幕录制软件有什么功能

电脑屏幕录制软件有很多,今天简单说说说它的基础功能和附属功能: 基础功能: 1、屏幕录像 支持对所选电脑的屏幕进行录制,并且支持调整截屏频度、画面质量、单个视频时长等。 2、实时屏幕 可以对对方电脑进行实时屏幕查看&…

linux网络测试命令

文章目录 一.route命令解释二.traceroute命令三.nslookup命令四.本地主机映射文件五.修改网络配置文件六.设置网络接口参数 一.route命令解释 Destination(目标):这一列显示要路由的目标网络或主机的IP地址。它标识了数据包要发送到的目的地。…

浙江环保用电计量adw300-hj治污产污生产设备监测

浙江环保用电计量表,浙江环保用电能表,浙江环保督查计量电表,环保设备能耗采集表 企业基本信息 企业名称:XXXXXXXXXXX 企业地址:XX省XX市 工 程 量:X台监测仪表 预计工期:X天 监测点位信息…

实战经验分享:打造千万级直播项目,如何选择适合的长连接技术,告别CRUD开发

前言 其实不管大厂、小厂,做业务开发的同学都知道,写一个功能,有中台,有架构,有API,有SDK,很多可复用的代码直接调一下RPC接口或者一个注解就搞定了复杂的操作,所以很多螺丝钉们都没…

OPC UA:工业领域的“HTML”

OPC UA是工业自动化领域的一项重要的通信协议。它的特点是包括了信息模型构建方法。能够建立工业领域各种事物的信息模型。在工业自动化行业,OPCUA 类似互联网行业的HTTP协议和“HTML”语言。能够准确,可靠地描述复杂系统中各个元素,并且实现…

机器学习中常见的特征工程处理

一、特征工程 特征工程(Feature Engineering)对特征进行进一步分析,并对数据进行处理。 常见的特征工程包括:异常值处理、缺失值处理、数据分桶、特征处理、特征构造、特征筛选及降维等。 1、异常值处理 具体实现 from scipy.s…

桶装水订水送水小程序开发搭建;

上门送水小程序桶装水配送是一款的同城上门配送平台,为用户提供便捷的桶装水配送服务。解决用户在获取干净健康的饮用水方面的需求,提供高效、便捷的在线预约和下单服务。 小程序平台开发,具备强大的技术支持和良好的用户体验。用户可以通过…

跨平台开发技术

目录 1.Qt1.简介2.优势3.劣势 2.NET CoreVue1.简介2.优点 3.Flutter1.简介2.优点3.缺点 4.Maui1.简介2.优点3.缺点 5.Avalonia1.简介2.优点3.缺点 6. Cordova1.简介2.优点3.缺点 7.Electron1.简介2.优点3.缺点 个人搜集资料并总结了一些跨平台开发技术,如有不足欢迎…

分享一下怎么做一个房间预定链接

在旅游行业中,房间预定是非常重要的一环。随着互联网的普及和旅游业的发展,越来越多的人选择在网上预订房间。本文将介绍如何制作一个房间预定链接,以及推广该链接的方法和策略,帮助读者更好地了解房间预定的需求和实现方式。 一、…

隧道代理 vs 普通代理:哪种更适合您的爬虫应用?

前言 随着互联网的普及,爬虫技术在多个领域得到广泛应用。在进行爬虫开发时,代理服务器是不可或缺的工具之一。代理服务器可以隐藏客户端的真实 IP 地址和位置,从而保护客户端的隐私,同时通过代理可以绕过一些网络限制和安全机制…