【PyTorch】数据集

文章目录

  • 1. 创建数据集
    • 1.1. 直接继承Dataset类
    • 1.2. 使用TensorDataset类
  • 2. 数据集的划分
  • 3. 加载数据集
  • 4. 将数据转移到GPU

1. 创建数据集

主要是将数据集读入内存,并用Dataset类封装。

1.1. 直接继承Dataset类

必须要重写__getitem__方法,用于根据索引获得相应样本数据。必要时还可以重写__len__方法,用于返回数据集的大小。

from torch.utils.data import Datasetclass BostonHousingDataset(Dataset):"""定义波士顿房价数据集"""def __init__(self):self.data = np.load('../dataset/boston_housing/boston_housing.npz')def __getitem__(self, index):return self.data['x'][index], self.data['y'][index]def __len__(self):return self.data['x'].shape[0]

1.2. 使用TensorDataset类

将多个张量组合成一个数据集,要保证所有张量的第一个维度相等,保证每批样本数据格式相同。

import torch
from torch.utils.data import TensorDatasetdata = np.load('../dataset/boston_housing/boston_housing.npz')
X = torch.tensor(data['x'])
y = torch.tensor(data['y'])
dataset = TensorDataset(X, y)

2. 数据集的划分

数据集可以划分为训练集、验证集和测试集。

  • 训练集:用于模型拟合的数据样本集合。
  • 验证集:通常被用来调整模型的参数,以找出效果最佳的模型。
  • 测试集:用于训练好的模型性能评估的数据样本集合。
from torch.utils.data import random_splittrain_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = random_split(dataset, [train_size, test_size])

3. 加载数据集

使用DataLoader类将Dataset封装的数据集分成批次并进行迭代,以便于模型训练。DataLoader常用参数如下:

  • dataset
    要加载的数据集。
  • batch_size
    每个数据批次中包含的样本数。默认为1。
  • shuffle
    是否打乱数据集。默认为False。
  • num_workers
    使用几个进程来加载数据。默认为0,即在主进程中加载数据。
  • drop_last
    当数据集样本数不能被batch_size整除时,是否舍弃最后一个不完整的batch。默认为False。
from torch.utils.data import DataLoaderdataloader = DataLoader(dataset, batch_size=16, shuffle=True)

4. 将数据转移到GPU

一般在要运算时才将数据转移到GPU,有以下两种方法:

  1. var.to(device)
  2. var.cuda()
import torchdevice = torch.device("cuda" if torch.cuda.is_available() else "cpu")
for X,y in dataloader:# 将数据转移到GPUX = X.to(device)y = y.to(device)# 也可以X = X.cuda()y = y.cuda()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/194109.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【代码随想录算法训练营-第二天】【数组】977.有序数组的平方 ,209.长度最小的子数组 ,59.螺旋矩阵II

977.有序数组的平方 看完思路后一遍AC 思路剖析: 因为提到了时间复杂度为O(n),自然想到只能遍历一遍又因为只规定了时间复杂度,但是没有规定空间复杂度,所以可以考虑在定义一个数组【这一步没有考虑出来,是看了思路的…

数据结构和算法-哈夫曼树以相关代码实现

文章目录 总览带权路径长度哈夫曼树的定义哈夫曼树的构造法1法2 哈夫曼编码英文字母频次总结实验内容: 哈夫曼树一、上机实验的问题和要求(需求分析):二、程序设计的基本思想,原理和算法描述:三、调试和运行…

Matter学习笔记(3)——交互模型

一、简介 1.1 交互方式 交互模型层定义了客户端和服务器设备之间可以执行哪些交互。发起交互的节点称为发起者(通常为客户端设备),作为交互的接收者的节点称为目标(通常为服务器设备)。 节点通过以下方式进行交互&a…

Spring Initial 脚手架国内镜像地址

官方的脚手架下载太慢了,并且现在没有了Java8的选项,所以找到国内的脚手架镜像地址,推荐给大家。 首先说官方的脚手架 官方的脚手架地址为: https://start.spring.io/ 但是可以看到,并没有了Java8的选项。 所以推荐…

3dMax拼图生成工具Puzzle2D使用教程

Puzzle2D for 3dsMax拼图生成工具使用教程 Puzzle2D简介: 2D拼图随机生成器(英文:Puzzle2D) ,是一款由#沐风课堂#用MAXScript脚本语言开发的3dsMax建模小工具,可以随机创建2D可编辑样条线拼图图形。可批量…

【tensorflow学习-选择动作】 学习tensorflow代码调用过程

a actor.choose_action(s) def choose_action(self, s):s s[np.newaxis, :]return self.sess.run(self.action, {self.s: s}) # get probabilities for all actions输入:s 输出:self.sess.run(self.action, {self.s: s}) :a

解决:UnboundLocalError: local variable ‘js’ referenced before assignment

解决:UnboundLocalError: local variable ‘js’ referenced before assignment 文章目录 解决:UnboundLocalError: local variable js referenced before assignment背景报错问题报错翻译报错位置代码报错原因解决方法今天的分享就到此结束了 背景 在使…

MongoDB的原子性和多文档事务处理

原子性和事务处理是数据库操作的核心,保证了数据的准确性。依据数据库原子性,数据库和使用数据库的人员定义事务处理的方式。本文依据Mongodb的官方文档,整理Mongodb数据库的原子性和事务处理方法。 Mongodb的原子操作 Mongodb中&#xff0c…

实战案例:chatglm3 基础模型多轮对话微调

chatglm3 发布了,这次还发了base版本的模型,意味着我们可以基于这个base模型去自由地做SFT了。 本项目实现了基于base模型的SFT。 base模型 https://huggingface.co/THUDM/chatglm3-6b-base由于模型较大,建议离线下载后放在代码目录&#…

OSG编程指南:专栏内容介绍及目录

1、专栏介绍 OpenSceneGraph(OSG)场景图形系统是一个基于工业标准 OpenGL 的软件接口,它让程序员能够更加快速、便捷地创建高性能、跨平台的交互式图形程序。本专栏基于 OSG 3.6.5版本进行源码的编写及扩展,也通用于其他OSG版本的…

OpenTelemetry系列 - 第2篇 Java端接入OpenTelemetry

目录 一、架构说明二、方式1 - 自动化2.1 opentelemetry-javaagent.jar(Java8 )2.2 使用opentelemetry-javaagent.jar完成自动注入2.3 配置opentelemetry-javaagent.jar2.4 使用注解(WithSpan, SpanAttribute)2.5.1 代码集成WithS…

【栈和队列(2)】

文章目录 前言队列队列方法队列模拟实现循环队列练习1 队列实现栈 前言 队列和栈是相反的,栈是先进后出,队列是先进先出,相当于排队打饭,排第一的是最先打到饭出去的。 队列 队列:只允许在一端进行插入数据操作&…

20、Resnet 为什么这么重要

(本文已加入“计算机视觉入门与调优”专栏,点击专栏查看更多文章信息) resnet 这一网络的重要性,上一节大概介绍了一下,可以从以下两个方面来有所体现:第一是 resnet 广泛的作为其他神经网络的 back bone&…

Redis集合对象

一. 编码 集合对象的编码可以是intset或者hashtable。 intset编码的集合对象使用整数集合作为底层实现,集合对象包含的所有元素都保存在整数集合里面。 127.0.0.1:6379> sadd numbers 1 3 5 (integer) 3 127.0.0.1:6379> object encoding numbers "ints…

详细学习Pyqt5的10种容器(Containers)

Pyqt5相关文章: 快速掌握Pyqt5的三种主窗口 快速掌握Pyqt5的2种弹簧 快速掌握Pyqt5的5种布局 快速弄懂Pyqt5的5种项目视图(Item View) 快速弄懂Pyqt5的4种项目部件(Item Widget) 快速掌握Pyqt5的6种按钮 快速掌握Pyqt5的10种容器&…

Django rest froamwork-序列化关系

关系字段用于表示模型关系。它们可以应用于 ForeignKey、ManyToManyField 和OneToOneField 关系,也可以应用于反向关系和自定义关系(如GenericForeignKey)。 注意:关系字段是在relations.py中声明的,但按照惯例&#…

使用凌鲨进行内网穿透

为了方便在本地进行开发和调试工作,有时候需要安全地连接内网或Kubernetes集群中的服务。 在net proxy server中可以限制访问用户,也可以设置端口转发的密码。 使用 连接端口转发服务 列出可转发端口 可转发端口是服务端设置的,不会暴露真…

自恋的领导

自恋的领导》??? 在职场中,我曾经遇到过一位自恋狂的领导。他总是自吹自擂,自我标榜,而且对团队合作态度消极,经常拖后腿。他的言行举止充满了负能量,让人感到非常不舒服。例如&…

13 OAuth2.0实战:微服务接收身份信息

上一节介绍了网关层面的统一认证鉴权,将解析过的身份信息加密放入请求头传递给下游微服务; 那么下游微服务如何接收网关传递的身份信息? 很简单,只需要在每个服务的过滤器中从请求头接收,将其解密。 木谷博客系统中是将该过滤器统一放在blog-common-starter中,这样后续…