PyTorch 的各个核心模块和它们的功能

1. torch

核心功能
  • 张量操作:PyTorch 的张量是一个多维数组,类似于 NumPy 的 ndarray,但支持 GPU 加速。
  • 数学运算:提供了各种数学运算,包括线性代数操作、随机数生成等。
  • 自动微分torch.autograd 模块用于自动计算梯度。
  • 设备管理:允许在 CPU 和 GPU 之间移动张量。

示例代码

import torch# 创建张量
x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
y = torch.tensor([4.0, 5.0, 6.0])# 张量加法
z = x + y
print(f'z: {z}')# 计算梯度
z.sum().backward() # 求和的原因是求梯度需要是一个标量
print(f'Gradients of x: {x.grad}')

2. torch.nn

核心功能
  • 构建神经网络模块nn.Module 是所有神经网络模块的基类。
  • 常用层:如卷积层、池化层、全连接层、激活函数、归一化层等。
  • 损失函数:如交叉熵损失、均方误差损失等。

示例代码

import torch.nn as nn# 定义一个简单的前馈神经网络
class SimpleNet(nn.Module):def __init__(self):super(SimpleNet, self).__init__()self.fc1 = nn.Linear(10, 5)self.fc2 = nn.Linear(5, 1)def forward(self, x):x = torch.relu(self.fc1(x))x = self.fc2(x)return xmodel = SimpleNet()
print(model)

3. torch.optim

核心功能
  • 优化算法:包括 SGD、Adam、RMSprop 等。
  • 学习率调度器:用于动态调整学习率,如 StepLRExponentialLR

示例代码

import torch.optim as optim# 定义模型
model = SimpleNet()# 定义优化器
optimizer = optim.Adam(model.parameters(), lr=0.001)# 更新模型参数
optimizer.zero_grad()
output = model(torch.randn(1, 10))
loss = torch.mean(output)
loss.backward()
optimizer.step()# 学习率调度器
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
scheduler.step()

4. torch.utils.data

核心功能
  • 数据集Dataset 类用于自定义数据集。
  • 数据加载器DataLoader 用于批量加载数据,支持多线程加载。
  • 数据变换:通过 torchvision.transforms 可以对数据进行预处理和增强。

示例代码

from torch.utils.data import Dataset, DataLoader# 自定义数据集
class MyDataset(Dataset):def __init__(self, data):self.data = datadef __len__(self):return len(self.data)def __getitem__(self, idx):return self.data[idx]dataset = MyDataset([1, 2, 3, 4])
dataloader = DataLoader(dataset, batch_size=2, shuffle=True)for batch in dataloader:print(batch)

5. torchvision

核心功能
  • 数据集:提供了常用的计算机视觉数据集,如 MNIST、CIFAR-10、ImageNet 等。
  • 预训练模型:如 ResNet、VGG、AlexNet 等。
  • 数据变换:如图像调整大小、裁剪、归一化等。

示例代码

import torchvision.transforms as transforms
import torchvision.datasets as datasets# 定义数据预处理
transform = transforms.Compose([transforms.Resize((32, 32)),transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,))
])# 下载 MNIST 数据集
dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
dataloader = DataLoader(dataset, batch_size=64, shuffle=True)for images, labels in dataloader:print(images.shape, labels.shape)break

6. torch.jit

核心功能
  • TorchScript:通过脚本化和追踪将 Python 模型转换为 TorchScript 模型,提高执行效率并支持跨平台部署。
  • 脚本化torch.jit.script 用于将 Python 代码转换为 TorchScript 代码。
  • 追踪torch.jit.trace 用于通过追踪模型的执行流程创建 TorchScript 模型。

示例代码

import torch.jit# 定义简单模型
class SimpleNet(nn.Module):def forward(self, x):return x * 2model = SimpleNet()# 脚本化模型
scripted_model = torch.jit.script(model)
print(scripted_model)# 追踪模型
traced_model = torch.jit.trace(model, torch.randn(1, 10))
print(traced_model)

7. torch.cuda

核心功能
  • 设备管理:提供与 GPU 相关的操作,如设备计数、设备选择等。
  • 张量迁移:将张量从 CPU 移动到 GPU,以利用 GPU 加速计算。

示例代码

if torch.cuda.is_available():device = torch.device("cuda")x = torch.tensor([1.0, 2.0, 3.0]).to(device)print(f'GPU tensor: {x}')
else:print("CUDA is not available.")

8. torch.autograd

核心功能
  • 自动微分:提供自动计算梯度的功能,支持反向传播算法。
  • 计算图:动态构建计算图,并根据图计算梯度。

示例代码

x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
y = x + 2
z = y * y * 3
out = z.mean()# 反向传播计算梯度
out.backward()
print(x.grad)  # 输出 x 的梯度

9. torch.multiprocessing

核心功能
  • 多进程并行:用于在多核 CPU 上实现数据并行和模型并行,提高计算效率。
  • 与 Python 标准库 multiprocessing 的兼容:提供与标准库相似的接口。

示例代码

import torch.multiprocessing as mpdef worker(rank, data):print(f'Worker {rank} processing data: {data}')if __name__ == '__main__':data = [1, 2, 3, 4]mp.spawn(worker, args=(data,), nprocs=4)

10. torch.distributed

核心功能
  • 分布式训练:支持在多个 GPU 和多台机器上进行分布式训练。
  • 通信接口:提供多种通信后端,如 Gloo、NCCL 等。

示例代码

import torch
import torch.distributed as distdef init_process(rank, size, fn, backend='gloo'):dist.init_process_group(backend, rank=rank, world_size=size)fn(rank, size)def example(rank, size):tensor = torch.zeros(1)if rank == 0:tensor += 1dist.send(tensor, dst=1)else:dist.recv(tensor, src=0)print(f'Rank {rank} has data {tensor[0]}')if __name__ == "__main__":size = 2processes = []for rank in range(size):p = mp.Process(target=init_process, args=(rank, size, example))p.start()processes.append(p)for p in processes:p.join()

通过这些模块,PyTorch 提供了构建、训练、优化和部署深度学习模型所需的全面支持。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/876643.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Linux——简介

Linux的组成 Linux系统一般由四个主要部分组成:内核、shell、文件系统和应用程序。 内核:是操作系统的核心,负责管理系统的进程、内存、设备驱动程序、文件和网络系统等,决定着系统的性能和稳定性。shell:是系统的用…

2024:Qt--编译配置Protobuf(windows10) 配图详解

这里写自定义目录标题 一、准备1、Window10系统2、Qt Creator 5.0.2 Based on Qt 5.15.2 (MSVC 2019, 64 bit)3、protobuf-3.15.0(本示例使用版本)4、cmake-3.21.3-windows-x86_64(本示例使用,下载的zip直接解压使用) …

自编码器(autoencoder)

1.自编码器的由来 最初的自编码器是用来降维的,后来也逐渐用于去噪、生成任务。 2.自编码器的基本结构 自编码器(autoencoder)内部有一个隐藏层 h,可以产生编码(code)表示输入。该网络可以看作由两部分组…

ArcGIS Desktop使用入门(四)——ArcMap软件彻底卸载删除干净

系列文章目录 ArcGIS Desktop使用入门(一)软件初认识 ArcGIS Desktop使用入门(二)常用工具条——标准工具 ArcGIS Desktop使用入门(二)常用工具条——编辑器 ArcGIS Desktop使用入门(二&#x…

支持向量机回归及其应用(附Python 案例代码)

使用支持向量机回归估计房价 让我们看看如何使用支持向量机(SVM)的概念构建一个回归器来估计房价。我们将使用sklearn中提供的数据集,其中每个数据点由13个属性定义。我们的目标是根据这些属性估计房价。 引言 支持向量回归(SV…

vim的使用及退出码(return 0)

linux基础之vim快速入门 linux基础之vim快速入门_基本linux vim-CSDN博客https://blog.csdn.net/ypxcan/article/details/119878137?ops_request_misc&request_id&biz_id102&utm_termvim%E7%BC%96%E8%BE%91%E5%99%A8%E5%A4%8D%E5%88%B6%E7%B2%98%E8%B4%B4%E4%BA%…

Java(十)——接口

个人简介 👀个人主页: 前端杂货铺 ⚡开源项目: rich-vue3 (基于 Vue3 TS Pinia Element Plus Spring全家桶 MySQL) 🙋‍♂️学习方向: 主攻前端方向,正逐渐往全干发展 &#x1…

JAVA零基础学习3(Scanner类,字符串,StringBuilder,StringJoinder,ArrayList成员方法)

JAVA零基础学习3 Scanner类输入示例代码代码解释完整代码1. 读取字符串2. 读取整数3. 读取浮点数4. 读取布尔值5. 读取单个单词6. 读取长整型数7. 读取短整型数8. 读取字节数注意事项总结 API 字符串解释示例解释解决方法示例:使用 StringBuilder String…

口碑爆棚的高分法国电影,一起在光影中领略法式魅力吧!

文章目录 引言《与玛格丽特的午后》(网友评分:9.1)《午夜巴黎》(网友评分:8.3)《玫瑰人生》(网友评分:8.4)《双姝奇缘》(网友评分:8.7)《巴黎淘气帮》(网友评分:8.6)《触不可及》(网友评分:9.3)《爱在日落黄昏时》(网友评分:8.9)《悲惨世界》(网友评分:…

VScode使用Github Copilot插件时出现read ECONNREST问题的解决方法

文章目录 read ECONNREST查看是否仍是 Copilot 会员查看控制台输出网络连接问题浏览器设置问题笔者的话 read ECONNREST 最近使用 Copilot 时一直出现 read ECONNREST 问题,这个表示连接被对方重置了,就是说在读取数据时连接被关闭。 我首先怀疑是不是…

[023-2].第2节:SpringBoot中接收参数相关注解

我的后端学习大纲 SpringBoot学习大纲 1.1.基本介绍: SpringBoot接收客户端提交的数据、参数会使用的一些注解: 1.PathVarible2. RequestHeader3.RequestParam4.CookieValue5.RequestBody6.RequestAttribute 1.2.接收参数相关注解与应用实例:…

无人机制造工艺流程详解

一、需求分析 无人机制造的第一步是需求分析。这一阶段主要明确无人机的使用场景、功能要求、性能指标以及成本预算等。通过与客户或项目团队的深入沟通,确保对无人机的需求有全面而准确的理解。同时,也需要进行市场调研,了解同类型产品的特…

科普文:docker基础概念、软件安装和常用命令

docker基本概念 一 容器的概念 1. 什么是容器:容器是在隔离的环境里面运行的一个进程,这个隔离的环境有自己的系统目录文件,有自己的ip地址,主机名等。也可以说:容器是一种轻量级虚拟化的技术。 2. 容器相对于kvm虚…

如何使用 SQLite ?

SQLite 是一个轻量级、嵌入式的关系型数据库管理系统(RDBMS)。它是一种 C 库,实现了自给自足、无服务器、零配置、事务性 SQL 数据库引擎。SQLite 的源代码是开放的,完全在公共领域。它被广泛用于各种应用程序,包括浏览…

【C语言】函数的递归

目录 一、什么是递归 二、递归的思想 三、递归的限制条件 四、递归中的栈溢出 五、递归举例 (1)例1:n的阶乘 (2)例子2:顺序打印一个数的每一位 六、递归和迭代 七、拓展题目 (1&#…

Chainlit一个快速构建成式AI应用的Python框架,无缝集成与多平台部署

概述 Chainlit 是一个开源 Python 包,用于构建和部署生成式 AI 应用的开源框架。它提供了一种简单的方法来创建交互式的用户界面,这些界面可以与 LLM(大型语言模型)驱动的应用程序进行通信。Chainlit 旨在帮助开发者快速构建基于…

如何知道一个字段在selenium中是否可编辑?

这篇文章将检查我们如何使用Java检查selenium webdriver中的字段是否可编辑。 我们如何知道我们是否可以编辑字段?“readonly”属性控制字段的可编辑性。如果元素上存在“readonly”属性,则无法编辑或操作该元素或字段。 因此,如果我们找到一…

强大的开源网络攻击面分析工具:Hetty

Hetty:深入网络的每一个角落,Hetty让安全漏洞无处遁形。- 精选真开源,释放新价值。 概览 Hetty作为一个专为网络攻击面分析而设计的开源HTTP/1.1客户端,其设计初衷是为了帮助安全研究人员和渗透测试人员深入挖掘潜在的网络漏洞。…

[网鼎杯 2020 朱雀组]Nmap(详细解读版)

这道题考察nmap的一些用法,以及escapeshellarg和escapeshellcmd两个函数的绕过,可以看这里PHP escapeshellarg()escapeshellcmd() 之殇 (seebug.org) 两种解题方法: 第一种通过nmap的-iL参数读取扫描一个文件到指定文件中第二种是利用nmap的参数写入we…

NTC测温

前言 假设已知ad-温度转换表ad_table[100]; 数组元素ad_table[0] ~ ad_table[99] 对应温度0 ~ 99℃;已知MCU检测到NTC两端电压ad值位temp_ad,请写出将temp_ad转换成温度值的程序代码,要求温度值精确到0.1℃ 代码 为了将AD值转换为精确到0.1℃的温度值…