PyTorch深度学习——数据输入和预处理

pytorch数据载入

数据载入

在使用pytorch构建和训练模型的过程中,需要经常把原始数据(比如图片、音频)转化为张量的格式,为了方便地批量处理图片数据,pytorch引入了一系列工具来对这个过程进行包装

torch.utils.data.DataLoader

pytorch提供的一个用于数据加载的工具类,用于批量加载数据并为模型提供输入。它可以将数据集包装成一个可迭代的对象,方便地进行数据加载和批处理操作

Pytorch torch.utils.data.DataLoader 用法详细介绍-CSDN博客

DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,batch_sampler=None, num_workers=0, collate_fn=None,pin_memory=False, drop_last=False, timeout=0,worker_init_fn=None, *, prefetch_factor=2,persistent_workers=False)

参数说明

  • dataset :要从中加载数据的数据集(一个torch.utils.data.DataLoader的实例
  • batch_size:每批次要装载多少样品(迷你批次的大小)
  • shuffle :设置为True以使数据在每个时期都重新洗牌
  • sampler :定义从数据集中抽取样本的策略
  • batch_sampler:类似于采样器sampler,但一次返回一个迷你批次的索引,sampler只返回一个下标索引, 与batch_size,shuffle,sampler和drop_last互斥
  • num_workers :多少个子流程用于数据加载。 0表示将在主进程中加载数据 (默认值:0)
  • collate_fn :把一批 dataset 的实例转化为包含迷你批次数据的张量
  • pin_memory :如果为True,则数据加载器在将张量返回之前将其复制到CUDA固定的内存中。 如果您的数据元素是自定义类型,或者您的collate_fn返回的是一个自定义类型的批处理
  • drop_last :决定是否将最后一个迷你批次的数据丢掉
  • timeout :如果为正,则为从工作人员收集批次的超时值。 应始终为非负数。 (默认值:0)
  • worker_init_fn:如果非None,这个函数将在每个工作子进程上被调用,并接收工作进程ID(一个在[0, num_workers - 1]范围内的整数)作为输入,它在设置随机种子之后、但在数据加载之前被调用。(默认:None)
  • prefetch_factor :每个子流程预先加载的样本数。 2表示将在所有子流程中预取总共2 * num_workers个样本。 (默认值:2)
  • persistent_workers :如果为True,则一次使用数据集后,数据加载器将不会关闭工作进程。 这样可以使Worker Dataset实例保持活动状态。 (默认值:False)

映射类型的数据集

为了能够使用 DataLoader 类,首先需要构造关于单个数据的 torch.ulits.data.Dataset 类,这个类有两种:一种是映射类型(Map-Style),对于这个类型,每个数据都有一个对应的索引,通过输入具体的索引,就能得到对应的数据

import torch.utils.data as dataclass MyDataset(data.Dataset):def __init__(self, data_list):self.data_list = data_listdef __len__(self):return len(self.data_list)def __getitem__(self, index):return self.data_list[index]

一般来说,对于这个类,主要需要重写两个方法:一个是 __getitem__ ,该方法是python内置的操作符方法,对应的操作符是索引操作符 [],通过输入整数数据索引,其大小在0至N-1之间(N微数据的总数目),返回具体的某一条数据记录,这就是该方法需要完成的任务,而具体的逻辑需要根据数据集的类型来决定,另一个方法是 __len__ ,该方法返回数据的总数

在python,如果一个Dataset类重写了该方法,可以通过使用 len 内置函数来获取数据的数目

torchvision工具包的使用

PyTorch:Torchvision的简单介绍与使用-CSDN博客

可迭代类型的数据集

from torch.utils.data import IterableDatasetclass MyIterableDataset(IterableDataset):def __init__(self, file_path):self.file_path = file_pathdef __iter__(self):with open(self.file_path, 'r') as file_obj:for line in file_obj:line_data = line.strip('\n').split(',')yield line_dataif __name__ == '__main__':dataset = MyIterableDataset('test_csv.csv')for data in dataset:print(data)

pytorch模型的保存和加载

序列化和反序列化

由于pytorch的模块和张量的本质是 torch.nn.Module 和 torch.tensor 类的实例,而pytorch自带了一系列的方法,可以将这些类的实例转化为字符串,所以这些势力可以通过python序列化方法进行序列化(serialization)和反序列化(unserialization)

pytorch的实现里集成了python自带的pickle包对模块和张量进行序列化,张量序列化的本质是把张量的信息,包括数据类型和存储位置,以及携带的数据等转化为字符串,而这些字符串时候可以通过使用python自带的文件IO函数进行存储,这个过程是可逆的,即可以通过文件IO函数来读取存储的字符串,然后将字符串逆向解析成pytorch的模块和张量

torch.save(obj,f,pickle_module=pickle,pickle_protocol=2)
torch.load(f,map_location=None,pickle_module=pickle,**pickle_load_args)

torch.save 函数传入的第一个参数是pytorch中可以被序列化的对象,包括模型和张量等,第二个参数是存储文件的路径,序列化的结果将会被保留在这个路径里面,第三个参数是默认的,传入的是序列化的库,可以使用pytorch默认的序列化库pickle,第四个参数是pickle协议,即如何把对象转化为字符串的规范,上述使用的协议版本是2

与 torch.save 函数对应的是 torch.load 函数,该函数在给定序列化后的文件路径之后,就能输出 pytorch 的对象,第一个参数是文件路径之后,第二个参数是张量存储位置的映射,如果存储时的模型在CPU上,可以直接使用默认参数,但当存储的模型在GPU上,torch.load 的默认行为是先把模型载入CPU中,然后转移到保存时的GPU上,加入载入模型的时候是在另外一台计算机上,而计算机没有GPU或GPU的型号对不上就会报错

此时可以使用 map_loactin 函数,设置 map_loactin = 'CPU',这样就会把模型保留在CPU里面,不再移动到GPU中,pickle_module 参数和 torch.save 里的同名参数的作用一致

在pytorch中,模型的保存方法有两种,第一种是直接保存模型的实例(因为模型本身可以被序列化),第二种是保存模型的状态字典(State Dict),一个模型的状态字典包含模型所有参数的名字以及名字对应的张量,通过调用 state_dict 方法,就可获取当前模型的状态字典

状态字典的保存和载入

由于pytorch模块的实现依赖具体的pytorch版本,所以会存在一种情况:使用某一个版本保存的序列化文件无法被另一个版本的pytorch载入,相比之下,pytorch的张量变动较小,二状态字典只含有张量参数的名字和张量参数的具体信息,预模块的实现关联较小,因此更加推荐使用 state_dict 方法来获取状态字典,然后保存该张量字典来保存模型,这样可以实现最大限度地减小代码对pytorch版本的依赖性

另外在训练的时候,不仅要保存模型的相关信息,还要保存优化器的相关信息,因为可能需要从存储的检查点出发,继续进行训练,pytorch中参数:当前的学习率,当前梯度的指数移动平均等,通过调用优化器的 state_dict 方法和 load_state_dict 方法,可以让优化器输出和载入相关的状态信息

save_info = { # 保存的信息"iter_num":iter_num,  # 迭代步数"optimizer":optimizer.state_dict,  # 优化器的状态字典"model":model.state_dict(),  # 模型的状态字典
}
# 保存信息
torch.save(save_info,save_path)
# 载入信息
save_info = torch.load(save_path)
optimizer.load_stste_dict(save_info["optimizer"])
model.load_stste_dict(save_info["model"])

pytorch数据可视化

tensorboard是一个数据可视化工具,能直观的显示深度学习中张量的变化,从这个变幻的过程中很容易的可以了解到模型在训练中的行为,包括但不限于损失函数的下降趋势是否合理,张量分量的分布是否在训练过程中发生变化

Pytorch:Tensorboard的安装及常用类的使用【图表+图片方法的使用】-CSDN博客

pytorch进阶 可视化工具TensorBoard的使用_pip install future tensorboard-CSDN博客

PyCharm中TensorBoard的安装和使用_phyton怎么安装 tensorborad-CSDN博客

Tensorboard的使用 ---- SummaryWriter类(pytorch版)-CSDN博客

pytorch模型的并行化

多GPU训练:PyTorch中的数据并行与模型并行-CSDN博客

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/831602.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

计算机网络chapter2——应用层

文章目录 第2章 应用层章节引出—— 2.1应用层协议原理2.1.1 网络应用程序体系结构(1)客户-服务器体系结构(2)对等(P2P)体系结构2.1.2 进程通信1.客户和服务器进程2.进程与计算机网络之间的接口3. 进程寻址 2.1.3 可供应用程序使用…

STM32数字示波器+详细注释+上位机程序+硬件

目录 1、设计指标: 2、功能: 3、上位机的程序 ​4、测试的照片 5、PCB 6、模拟电路板 7、程序 资料下载地址:STM32数字示波器详细注释上位机程序硬件 1、设计指标: 主控: STM32…

中华科技控股集团:人工智能标准化引领者与数字化服务新航程的启航者

4月30日, 矗立于时代科技潮头的中华科技控股集团,自2010年在香港这片国际金融沃土上诞生以来,便以其独特的国资背景与全球化视野,肩负起推动中国科技进步与产业升级的重任。作为国资委麾下的重要一员,中华科技始终坚持创新驱动发展…

MLP手写数字识别(2)-模型构建、训练与识别(tensorflow)

查看tensorflow版本 import tensorflow as tfprint(Tensorflow Version:{}.format(tf.__version__)) print(tf.config.list_physical_devices())1.MNIST的数据集下载与预处理 import tensorflow as tf from keras.datasets import mnist from keras.utils import to_categori…

三生随记——博物馆的深夜秘密

博物馆的深夜秘密 第一章:古老的博物馆 在一座历史悠久的城市中,矗立着一座庞大的博物馆,收藏着无数的历史文物和艺术品。这座博物馆的外观宏伟壮观,充满了历史的痕迹。然而,随着夜幕的降临,博物馆里似乎隐…

Adobe 更新 Firefly Image 3 图像生成模型

一个工具或者模型,对于初次使用的人来说,易用性和超出预期的效果很能吸引使用者,suno和mj在这方面我感觉确实不错,第一次使用感觉很惊艳。 Adobe 更新 Firefly Image 3 图像生成模型,我用了mj的提示词,最后…

Python的控制流

Python中的控制流是指通过条件语句和循环来控制程序的执行流程。控制流使程序能够根据不同的条件执行不同的代码块,或者重复执行特定的代码块。本文将详细介绍Python中的条件语句(if语句)和循环(for循环和while循环)&a…

报错cannot import name ‘MultiHeadAttention‘ from ‘tensorflow.keras.layers‘

小伙伴们大家好,废话不多说,直接上解决方案 简单粗暴 ↓ ↓ ↓ ↓ ↓ ↓ ↓ 遇到这个问题我也是找了很多办法 我的python版本是3.8.5,我装的tensorflow版本是2.2.0,说实话 pip install keras-multi-head 这个方法…

低频卡 LF 的应用与技术特点

低频卡 LF(Low Frequency)在现代生活中有着广泛的应用,展现出独特的技术优势。 在畜牧业管理中,LF 卡被广泛用于动物标识,如电子耳标或项圈。这些卡可以实时追踪动物的健康状况、繁殖情况和移动轨迹,为畜牧…

【Cpp】类和对象#拷贝构造 赋值重载

标题:【Cpp】类和对象#拷贝构造 赋值重载 水墨不写bug 目录 (一)拷贝构造 (二)赋值重载 (三)浅拷贝与深拷贝 正文开始: (一)拷贝构造 拷贝构造函数&…

后端python构网并生成纹理图片发回给cesium做贴地处理

在后端Python中,你可以使用一些库来进行网格构建和纹理生成,然后将生成的纹理图片发送给Cesium进行贴地处理。以下是一种可能的方法: 构建网格:使用点的坐标信息和索引信息,可以使用一些三角网格生成算法来构建网格。你…

UG NX二次开发(C#)-获取Part中对象创建时的序号(*)

文章目录 1、前言2、UG NX的对象序号讲解3、采用UG NX二次开发或者建模序号4、注意事项1、前言 在UG NX中,我们创建任意一个对象,都会在模型历史中添加一个创建对象的编号,即是对象序号,这个是递增的,当删除中间产生的对象时,其序号会重新按照建模顺序重新排布。今天一个…

MLP实现fashion_mnist数据集分类(2)-函数式API构建模型(tensorflow)

使用函数式API构建模型,使得模型可以处理多输入多输出。 1、查看tensorflow版本 import tensorflow as tfprint(Tensorflow Version:{}.format(tf.__version__)) print(tf.config.list_physical_devices())2、fashion_mnist数据集分类模型 2.1 使用Sequential构建…

C++关联容器1——关联容器概述,map,set介绍,pair类型

关联容器 关联容器支持高效的关键字查找和访问。 两个主要的关联容器(associative-container)类型是map和set。 map中的元素是一些关键字一值(key-value)对:关键字起到索引的作用,值则表示与索引相关联的数据。 se…

内网安全-代理Socks协议路由不出网后渗透通讯CS-MSF控制上线简单总结

我这里只记录原理,具体操作看文章后半段或者这篇文章内网渗透—代理Socks协议、路由不出网、后渗透通讯、CS-MSF控制上线_内网渗透 代理-CSDN博客 注意这里是解决后渗透通讯问题,之后怎么提权,控制后面再说 背景 只有win7有网,其…

26 JavaScript学习:JSON和void

JSON 英文全称 JavaScript Object NotationJSON 是一种轻量级的数据交换格式。JSON是独立的语言JSON 易于理解。 JSON 实例 简单的 JSON 字符串实例: "{\"name\": \"Alice\", \"age\": 25, \"city\": \"San Francisco\&…

PX4二次开发快速入门(三):自定义串口驱动

文章目录 前言 前言 软件:PX4 1.14.0稳定版 硬件:纳雷NRA12,pixhawk4 仿照原生固件tfmini的驱动进行编写 源码地址: https://gitee.com/Mbot_admin/px4-1.14.0-csdn 修改 src/drivers/distance_sensor/CMakeLists.txt 添加 add…

Servlet详解(从xml到注解)

文章目录 概述介绍作用 快速入门Servelt的执行原理执行流程:执行原理 生命周期概述API 服务器启动,立刻加载Servlet对象(理解)实现Servlet方式(三种)实现Servlet接口实现GenericServlet抽象类,只重写service方法实现HttpServlet实现类实现Htt…

NodeJs入门知识

**************************************************************************************************************************************************************************** 1、配置Node.js与npm下载(精力所致,必有精品) …

算法--分治法

分治法是一种算法设计策略,它将一个复杂的问题分解成两个或多个相同或相似的子问题,直到这些子问题可以简单地直接解决。然后,这些子问题的解被合并以产生原始问题的解。 分治法通常遵循以下三个步骤: 分解:将原问题…