PyTorch数据结构

前言:因为最近开始读深度学习代码,主要都是用PyTorch框架,所以来补一些PyTorch基础,先从数据结构入手。

PyTorch数据结构

  • PyTorch
  • PyTorch数据结构
    • 张量
      • 属性:维度、轴、形状
      • 常见的操作
    • 数据集
      • 构造代码
      • DataLoader
    • 模块
  • 参考

PyTorch

PyTorch:PyTorch是一个开源深度学习框架,有很多好用的深度学习工具,提供了丰富的库,可以很方便构建和训练神经网络模型。

  • GPU加持:PyTorch提供了GPU优化操作和管理,使得在GPU上运行模型更高效。
  • 提供预训练模型和模型库:PyTorch提供了很多预训练模型和模型库,能很方便进行深度学习模型的开发。
  • 支持分布式训练:PyTorch支持分布式训练,可以在多个GPU和多台机器上加速训练。
  • 动态计算图:PyTorch使用动态计算图来计算图,在运行时动态生成而不是编译时静态生成,可以观察动态生成的数据流向。
  • 自动求导:PyTorch内置了自动求导功能,避免手动去计算非常复杂的导数,极大地减少了构建模型的时间。

PyTorch数据结构

张量

Tensor(张量):Tensor是PyTorch中最基本的数据结构,类似于多维数组。它可以表示标量、向量、矩阵或任意维度的数组。

属性:维度、轴、形状

维度(Dimensions):维度又可以叫做阶(Rank),理解为数组的维度。只有标量就是0维度,一维数组就是1维度,其余以此类推。

轴(Axis):轴数和维数、阶数相同,多维张量需要索引才能引用到里面的内容,这个不同维度索引就是轴。例如形状3×4的张量,需要访问张量[0][2]位置的内容,轴指的就是“[]”里面的索引,第一个轴的长度是3,第二的轴的长度是4。

形状(Shape):张量的形状由每个轴的长度决定,轴的长度就是对应维度能索引的大小。

相关的代码:
常用的函数有Tensor.size()和Tensor.dim()。

import torch
# 创建3维张量
tensor_3d = torch.tensor([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]])
print("张量形状:", tensor_3d.size())# torch.Size([1,2,3]) 1个2*3的数组
print("轴数:", tensor_3d.dim())# 3
print(tensor_3d.size(1)==tensor_3d.size(-2))# True x.size(index)表示取出某个维度上的大小,正的表示从左到右,负的表示从右到左

常见的操作

重构(reshape):torch.reshape()可以重构张量的形状。过程是先把所有的内容按行排列,然后先分高纬再分低维度。

重构举例:

x = torch.tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11])
x.reshape(4,3) # 变成3行4列矩阵
x # tensor([[ 0,  1,  2],[ 3,  4,  5],[ 6,  7, 8],[ 9,  10,  11]])
x.reshape(2,3,2) 
# 先变一维 [ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11]
# 再变二维 [[ 0,  1,  2,  3,  4,  5],[ 6,  7,  8,  9, 10, 11]]
# 再变三维 [[[ 0,  1],[ 2,  3],[ 4,  5]],[[ 6,  7],[ 8,  9],[10, 11]]]
x # tensor([[[ 0,  1],[ 2,  3],[ 4,  5]],[[ 6,  7],[ 8,  9],[10, 11]]])

张量索引:

x = torch.tensor([[[ 0,  1,  2,  3],[ 4,  5,  6,  7],[ 8,  9, 10, 11]]])
print(x[0])# tensor([[ 0,  1,  2,  3],[ 4,  5,  6,  7],[ 8,  9, 10, 11]])
print(x[0][0])# tensor([0, 1, 2, 3])
print(A[0, 0:2, :])# tensor([[0, 1, 2, 3],[4, 5, 6, 7]])

拼接(cat和stack):

  • stack:在新创建的维度上进行拼接,会扩宽维度。
  • cat:按张量维度进行拼接。
# [2,3]->[2,9]
x=torch.ones((2, 3))
y = torch.cat([x, x], dim=0)
print(y)# tensor([[1., 1., 1.],[1., 1., 1.],[1., 1., 1.],[1., 1., 1.]])
# [2,3]->[2, 3, 2]
y = torch.stack([x, x], dim=2)
print(y)# tensor([[[1., 1.],[1., 1.],[1., 1.]],[[1., 1.],[1., 1.],[1., 1.]]])

升维(unsqueeze):指定维度插入新的维度。

x = torch.tensor([1,2,3,4])
y = x.unsqueeze(dim=0)
print(y)# tensor([[1],[2],[3],[4]])

降维(squeeze):移除制定或维度大小为1的维度。

x = torch.tensor([[[1],[2]],[[3],[4]]])
y = x.squeeze(2)
print(y)# tensor([[1, 2],[3, 4]])

升维和降维的好处,在于深度学习通常做运算会有维度要求。

数据集

Dataset(数据集):Dataset是一个抽象类,用于表示数据集。

使用:通过继承Dataset类,可以自定义数据集,并实现数据加载、预处理和获取样本等功能。PyTorch还提供了一些内置的数据集类,如MNIST、CIFAR-10等,用于方便地加载常用的数据集。

构造代码

代码:
通过继承该类来自定义自己的数据集类,在继承时要求必须重载__len__()和__getitem__()这两个方法。

  • __len__():返回的是数据集的大小。
  • __getitem__():实现索引数据集中的某一个数据。
import torch
from torch.utils.data import Datasetclass BasicDataset(Dataset):# 继承Datasetdef __init__(self, data_tensor, target_tensor):self.data_tensor = data_tensorself.target_tensor = target_tensordef __getitem__(self, index):return self.data_tensor[index], self.target_tensor[index]def __len__(self):return self.data_tensor.size(0)# 生成数据
data_tensor = torch.randn(4, 3)# 生成一个每个元素服从正态分布的4行3列随机张量
target_tensor = torch.rand(10)# 从区间[0,1)的均匀分布中随机抽取一个随机数生成一个张量# 将数据封装成Dataset
tensor_dataset = BasicDataset(data_tensor, target_tensor)print(tensor_dataset[1])# 调用__getitem__print(len(tensor_dataset))# 调用__len__

DataLoader

DataLoader:DataLoader将Dataset对象或自定义数据类的对象封装成一个迭代器,通过迭代器可以输出Dataset的内容。

DataLoader参数:

  • dataset:表示Dataset类,数据从哪读取以及如何读取。
  • batch_size:表示批大小。
  • shuffle:表示每个epoch要不要重新打乱数据,默认false。
  • num_works:用多少个子进程读取数据。
  • drop_last:表示当样本数不能被batch_size整除时,是否舍弃最后一批数据。

batch和epoch的区别:

  • 一个epoch就是将所有训练样本训练一次的过程,神经网络的训练往往会需要很多次epoch才会loss收敛到合适的程度。
  • 将整个训练样本分成若干个Batch。

使用代码:

# batch_size设置为2,shuffle=False不打乱数据顺序,num_workers=1使用1个子进程
dataloader = BasicDataset(dataset, batch_size=2, shuffle=False, num_workers=1)# 以for循环形式输出
for input, target in dataloader:print(input, target)

模块

Module(模块):Module是PyTorch中用于构建模型的基类。通过继承Module类,可以定义自己的模型,并实现前向传播和反向传播等方法。Module提供了参数管理、模型保存和加载等功能,方便模型的训练和部署。

实际去看深度学习代码的时候,会发现定义模型的类,都是继承nn.Module(模块)。

代码技巧:

  • 网络中具有可学习参数的层(如全连接层、卷积层等)放在构造函数__init__()中。不具有参数的也可以放入(ReLU、dropout、BatchNormanation),如果不写在构造函数的话,可以在forward方法中用nn.functional来代替。
  • forward方法是必须要重写的,是实现模型的功能,实现各个层之间的连接关系的核心。在阅读模型的时候,阅读forward是搞懂模型运作流程最好的方式。
import torch
import torch.nn.functional as Fclass MyNet(torch.nn.Module):def __init__(self):super(MyNet, self).__init__()  # 第一句话,调用父类的构造函数self.conv1 = torch.nn.Conv2d(3, 32, 3, 1, 1)# self.relu1=torch.nn.ReLU()# self.max_pooling1=torch.nn.MaxPool2d(2,1)self.conv2 = torch.nn.Conv2d(3, 32, 3, 1, 1)# self.relu2=torch.nn.ReLU()# self.max_pooling2=torch.nn.MaxPool2d(2,1)self.dense1 = torch.nn.Linear(32 * 3 * 3, 128)self.dense2 = torch.nn.Linear(128, 10)def forward(self, x):x = self.conv1(x)x = F.relu(x)# self.relu1(x)x = F.max_pool2d(x)# self.max_pooling1(x)x = self.conv2(x)x = F.relu(x)#  self.relu2(x)x = F.max_pool2d(x)# self.max_pooling2(x)x = self.dense1(x)x = self.dense2(x)return x

参考

nn.Module类详解
Dataset和DataLoader
PyTorch数据结构

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/779270.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Redis 教程系列之Redis 事务(十六)

Redis 事务可以一次执行多个命令, 并且带有以下两个重要的保证: 事务是一个单独的隔离操作:事务中的所有命令都会序列化、按顺序地执行。事务在执行的过程中,不会被其他客户端发送来的命令请求所打断。事务是一个原子操作&#x…

2024年03月CCF-GESP编程能力等级认证C++编程八级真题解析

本文收录于专栏《C++等级认证CCF-GESP真题解析》,专栏总目录:点这里。订阅后可阅读专栏内所有文章。 一、单选题(每题 2 分,共 30 分) 第 1 题 为丰富食堂菜谱,炒菜部进行头脑风暴。肉类有鸡肉、牛肉、羊肉、猪肉4种,切法有肉排、肉块、肉末3种,配菜有圆白菜、油菜、…

react useState的初始化函数 初始化值为props时的同步问题 | setState函数的使用与异步更新

文章目录 react setState函数useState()钩子创建state如何根据props更新state值 setState的参数是下一个状态statesetState的参数是更新函数functionsetState异步与同步合成事件setState 实现原理 react setState函数 useState()钩子创建state const [state, setState] useS…

大数据做「AI大模型」数据清洗调优基础篇

关于本文 近期一直在协助做AI大模型数据清洗调优的工作,主要就是使用大数据计算引擎Spark做一些原始数据的清洗工作,整体数据量大约6PB-8PB之间,那么对于整个大数据量的处理性能将是一个重大的挑战,关于具体的调优参数配置项暂时不…

【论文阅读+复现】AniPortrait: Audio-Driven Synthesis of Photorealistic Portrait Animation

AniPortrait:音频驱动的逼真肖像动画合成。 code:Zejun-Yang/AniPortrait: AniPortrait: Audio-Driven Synthesis of Photorealistic Portrait Animation (github.com) paper:[2403.17694] AniPortrait: Audio-Driven Synthesis of Photoreal…

Soot入门学习笔记

Soot 适合参考的文档和教程如下: 北京大学软件分析技术 南京大学软件分析 Tutorials for soot McGill University 198:515 (vt.edu) 比较好的笔记资料: 南京大学《软件分析》课程笔记 比较好的入门作业或者案例: CSCE710 Assignmen…

MySQL8 中文参考翻译完成

MySQL8 中文参考前言和法律声明第一章 一般信息1.1 关于本手册1.2 MySQL 数据库管理系统概述1.2.1 什么是 MySQL?1.2.2 MySQL 的主要特性1.2.3 MySQL 的历史1.3 MySQL 8.0 中的新功能1.4 MySQL 8.0 中新增、弃用或删除的服务器和状态变量和选项1.5 如何报告错误或问…

产品说明书二维码生成的优势:数字化时代的智能选择

随着二维码技术的不断发展,越来越多的企业开始选择使用二维码来展示产品使用说明,以取代传统的纸质说明书。这一趋势不仅符合数字化时代的潮流,更为消费者提供了更便捷、更智能的产品使用体验。以下是产品说明书二维码生成的优势:…

Android WebView的使用与后退键处理

目录 前言首先,我们需要在布局文件中添加webView组件在Activity中获取webView实例,并加载网页内容 前言 webView是Android中常用的组件之一,用于展示网页内容。它可以加载HTML文件、URL链接等网页内容,并提供交互功能。在使用webV…

C#_泛型_委托

文章目录 泛型泛型的使用泛型的约束委托委托的实例化多播委托委托的调用内置委托类型委托练习泛型委托Lambda表达式(进阶)上期习题答案本期习题 泛型 泛型(Generic) 是一种规范,它允许我们使用占位符来定义类和方法,编译器会在编…

Linux进程概念(下)

1. 进程的状态 为了弄明白正在运行的进程是什么意思,我们需要知道进程的不同状态。一个进程可以有几个状态(在Linux内核里,进程有时候也叫做任务)。 下面的状态在kernel源代码里定义: /* * The task state array is…

数对 离散化BIT

先把公式变个形&#xff0c;然后直接BIT 枚举右端点查询左端点累加答案 离散化好题&#xff0c;注意BIT写的时候右端点的范围是离散化区间的大小 #include<bits/stdc.h> using namespace std; #define int long long using ll long long; using pii pair<int,int&…

【ZZULIOJ】1011: 圆柱体表面积(Java)

目录 题目描述 输入 输出 样例输入 Copy 样例输出 Copy code 题目描述 输入圆柱体的底面半径r和高h&#xff0c;计算圆柱体的表面积并输出到屏幕上。要求定义圆周率为如下宏常量 #define PI 3.14159 输入 输入两个实数&#xff0c;为圆柱体的底面半径r和高h。 输出 输…

国内好用的chatGPT和AI绘图工具

分享一个比较好用的AI 分享一个比较好用的AI&#xff0c;只是需要开通会员&#xff0c;目前官网的价格是&#xff1a;298&#xff0c;开通之后可以使用chatgpt4、AI绘画、图片融合等等&#xff01;不开通的话是可以免费使用15次的&#xff0c;下面是一些介绍图片&#xff01;链…

UE5数字孪生系列笔记(三)

C创建Pawn类玩家 创建一个GameMode蓝图用来加载我们自定义的游戏Mode新建一个Pawn的C&#xff0c;MyCharacter类作为玩家&#xff0c;新建一个相机组件与相机臂组件&#xff0c;box组件作为根组件 // Fill out your copyright notice in the Description page of Project Set…

【python】网络编程socket TCP UDP

文章目录 socket常用方法TCP客户端服务器UDP客户端服务器网络编程就是实现两台计算机的通信 互联网协议族 即通用标准协议,任何私有网络只要支持这个协议,就可以接入互联网。 socket socke模块的socket()函数 import socketsock = socket.socket(Address Family,

Solidity Uniswap V2 Router swapTokensForExactTokens

最初的router合约实现了许多不同的交换方式。我们不会实现所有的方式&#xff0c;但我想向大家展示如何实现倒置交换&#xff1a;用未知量的输入Token交换精确量的输出代币。这是一个有趣的用例&#xff0c;可能并不常用&#xff0c;但仍有可能实现。 GitHub - XuHugo/solidit…

修改Docker Gitlab root 的密码

一、进入git docker 容器 docker exec -it [容器ID或名称] /bin/bash 二、查找并修改账号 user User.find_by(username: ‘root’) user.password ‘root********’ user.password_confirmation ‘root********’ user.save! 三、重启生效 附&#xff1a;第一次&#xff…

Golang实战:深入hash/crc64标准库的应用与技巧

Golang实战&#xff1a;深入hash/crc64标准库的应用与技巧 引言hash/crc64简介基本原理核心功能 环境准备安装Golang创建一个新的Golang项目引入hash/crc64包测试环境配置 hash/crc64的基本使用计算字符串的CRC64校验和计算文件的CRC64校验和 高级技巧与应用数据流和分块处理网…

Jmeter 配置说明之线程组

一、线程组介绍&#xff1a; 线程组元件是任何一个测试计划的开始点。在一个测试计划中的所有元件都必须在某个线程组下。所有的任务都是基于线程组&#xff1a; 通俗理解&#xff1a; 线程组&#xff1a;就是一个线程组&#xff0c;里面有若干个请求&#xff1b; 线程&am…