AI开发学习之——PyTorch框架

PyTorch 简介

PyTorch (Python torch)是由 Facebook AI 研究团队开发的开源机器学习库,广泛应用于深度学习研究和生产。它以动态计算图和易用性著称,支持 GPU 加速计算,并提供丰富的工具和模块。

PyTorch的主要特点

  1. 动态计算图:PyTorch 使用动态计算图(Autograd),允许在运行时修改图结构,便于调试和实验。
  2. GPU 加速:支持 CUDA,能够利用 GPU 进行高效计算。
  3. 模块化设计:提供 torch.nn 等模块,便于构建和训练神经网络。
  4. 丰富的生态系统:包括 TorchVision、TorchText 和 TorchAudio 等,支持多种任务。、

PyTorch的安装

通过以下命令安装 PyTorch:

pip install torch torchvision

如果国内的速度慢,可以使用-i 参数使用国内的仓库源。

pip3 install torch -i https://pypi.tuna.tsinghua.edu.cn/simple

除了清华的源之外,也可以使用科大或是北外的数据源。

  • https://mirrors.bfsu.edu.cn/pypi/web/simple

  • https://mirrors.ustc.edu.cn/pypi/web/simple

使用示例

1. 张量操作
import torch# 创建张量
x = torch.tensor([1.0, 2.0, 3.0])
y = torch.tensor([4.0, 5.0, 6.0])# 加法
z = x + y
print(z)  # 输出: tensor([5., 7., 9.])

这里的输出为什么不是 tensor([5.0, 7.0, 9.0])呢?
在Python的浮点数表示中,.0后缀通常用于明确表示一个数是浮点数(float),而不是整数(int)。然而,在大多数情况下,Python和许多库(包括PyTorch,这里提到的tensor是由PyTorch生成的)在打印浮点数时,如果小数点后没有额外的数字,它们可能会省略.0后缀以简化输出。

当使用科学计算库如NumPy或PyTorch时,它们通常有统一的输出格式,尤其是在处理数组或tensor时。在你的例子中,tensor([5., 7., 9.])tensor([5.0, 7.0, 9.0])在数值上是完全相同的,只是表示形式略有不同。PyTorch选择省略小数点后没有数字的.0后缀,以使输出更简洁。

这种输出格式的选择主要是出于可读性和简洁性的考虑,并不影响tensor中存储的实际数值。在数值计算中,5.5.0都被视为浮点数,并且在计算中没有任何区别。

2. 自动求导
import torch# 创建需要梯度的张量
x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)# 定义函数
y = x * 2
z = y.mean()# 反向传播
z.backward()# 查看梯度
print(x.grad)  # 输出: tensor([0.6667, 0.6667, 0.6667])

这里的结果是怎么来的呢?

这段代码演示了 PyTorch 中的**自动微分(Autograd)**机制,通过计算梯度来实现反向传播。我们来逐步分析代码的运算过程。


1. 创建需要梯度的张量
x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
  • x 是一个包含 [1.0, 2.0, 3.0] 的 1 阶张量(向量)。
  • requires_grad=True 表示 PyTorch 需要跟踪对 x 的所有操作,以便后续计算梯度。

2. 定义函数
y = x * 2
z = y.mean()
  • y = x * 2:对 x 逐元素乘以 2,得到 y = [2.0, 4.0, 6.0]
  • z = y.mean():计算 y 的均值,即:
    在这里插入图片描述

3. 反向传播
z.backward()
  • z.backward() 表示从 z 开始反向传播,计算 zx 的梯度。
  • 由于 z 是一个标量(单个值),PyTorch 会自动计算 zx 的梯度。

4. 梯度计算

PyTorch 通过链式法则计算梯度。具体步骤如下:

(1)计算 zy 的梯度

  • z = y.mean() 可以写成:
    在这里插入图片描述

  • 因此,zy 的梯度为:
    在这里插入图片描述

(2)计算 yx 的梯度

  • y = x * 2 可以写成:
    yi​=2xi​
  • 因此,yx 的梯度为:
    在这里插入图片描述

(3)计算 zx 的梯度
根据链式法则:
在这里插入图片描述

将结果代入:
在这里插入图片描述


5. 查看梯度
print(x.grad)  # 输出: tensor([0.6667, 0.6667, 0.6667])
  • x.grad 存储了 zx 的梯度,结果为:
    在这里插入图片描述

总结

这段代码的运算过程如下:

  1. 创建需要梯度的张量 x
  2. 定义函数 y = x * 2z = y.mean()
  3. 通过 z.backward() 计算 zx 的梯度。
  4. 根据链式法则,梯度计算结果为 [0.6667, 0.6667, 0.6667]

PyTorch 的自动微分机制使得梯度计算变得非常简单,尤其是在深度学习模型中,这种机制可以自动计算损失函数对模型参数的梯度,从而支持梯度下降等优化算法。

3. 简单神经网络
import torch
import torch.nn as nn
import torch.optim as optim# 定义网络
class SimpleNet(nn.Module):def __init__(self):super(SimpleNet, self).__init__()self.fc = nn.Linear(1, 1)def forward(self, x):return self.fc(x)# 创建网络、损失函数和优化器
model = SimpleNet()
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)# 训练数据
x = torch.tensor([[1.0], [2.0], [3.0], [4.0]])
y = torch.tensor([[2.0], [4.0], [6.0], [8.0]])# 训练过程
for epoch in range(100):optimizer.zero_grad()outputs = model(x)loss = criterion(outputs, y)loss.backward()optimizer.step()if (epoch+1) % 10 == 0:print(f'Epoch [{epoch+1}/100], Loss: {loss.item():.4f}')
4. 使用 GPU
import torch# 检查 GPU 是否可用
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')# 创建张量并移动到 GPU
x = torch.tensor([1.0, 2.0, 3.0]).to(device)
y = torch.tensor([4.0, 5.0, 6.0]).to(device)# 在 GPU 上执行加法
z = x + y
print(z)  # 输出: tensor([5., 7., 9.], device='cuda:0')

torchtorchvisiontorchaudio

torchtorchvisiontorchaudio 是 PyTorch 生态系统中的三个核心库,分别用于通用深度学习、计算机视觉和音频处理任务。以下是它们的详细介绍和作用:


1. torch

torch 是 PyTorch 的核心库,提供了深度学习的基础功能,包括张量操作、自动求导、神经网络模块等。

主要功能:
  • 张量操作:支持高效的张量计算(如加法、乘法、矩阵运算等)。
  • 自动求导:通过 Autograd 模块实现自动微分,便于梯度计算和优化。
  • 神经网络模块:提供 torch.nn 模块,包含各种层(如全连接层、卷积层)和损失函数。
  • 优化器:提供 torch.optim 模块,包含 SGD、Adam 等优化算法。
  • GPU 加速:支持 CUDA,可以利用 GPU 进行高性能计算。
使用场景:
  • 构建和训练深度学习模型。
  • 实现自定义的数学运算和算法。
  • 进行张量计算和数值模拟。
示例:
import torch# 创建张量
x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)# 定义计算
y = x * 2
z = y.mean()# 自动求导
z.backward()# 查看梯度
print(x.grad)  # 输出: tensor([0.6667, 0.6667, 0.6667])

2. torchvision

torchvision 是 PyTorch 的计算机视觉库,提供了常用的数据集、模型架构和图像处理工具。

主要功能:
  • 数据集:提供常用的计算机视觉数据集(如 MNIST、CIFAR-10、ImageNet)。
  • 模型架构:包含预训练的经典模型(如 ResNet、VGG、AlexNet)。
  • 图像处理工具:提供数据增强和转换工具(如裁剪、旋转、归一化)。
  • 实用工具:包括可视化工具和评估指标。
使用场景:
  • 图像分类、目标检测、分割等计算机视觉任务。
  • 加载和处理图像数据。
  • 使用预训练模型进行迁移学习。
示例:
import torchvision
import torchvision.transforms as transforms
from torchvision.models import resnet18# 数据预处理
transform = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])# 加载数据集
dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)# 加载预训练模型
model = resnet18(pretrained=True)

3. torchaudio

torchaudio 是 PyTorch 的音频处理库,提供了音频数据的加载、处理和转换工具。

主要功能:
  • 音频加载和保存:支持多种音频格式(如 WAV、MP3)。
  • 音频处理:提供音频信号处理工具(如重采样、频谱图生成)。
  • 数据集:包含常用的音频数据集(如 LibriSpeech、VoxCeleb)。
  • 特征提取:支持提取 MFCC、Mel 频谱等音频特征。
使用场景:
  • 语音识别、语音合成、音频分类等任务。
  • 音频数据的预处理和特征提取。
  • 加载和处理音频数据集。
示例:
import torchaudio
import torchaudio.transforms as T# 加载音频文件
waveform, sample_rate = torchaudio.load('example.wav')# 重采样
resampler = T.Resample(orig_freq=sample_rate, new_freq=16000)
resampled_waveform = resampler(waveform)# 提取 Mel 频谱
mel_spectrogram = T.MelSpectrogram(sample_rate=16000)(resampled_waveform)

三者的关系

  • torch 是核心库,提供基础功能(如张量计算、自动求导、神经网络模块)。
  • torchvision 是基于 torch 的扩展库,专注于计算机视觉任务。
  • torchaudio 是基于 torch 的扩展库,专注于音频处理任务。

三者可以结合使用,例如:

  • 使用 torchvision 处理图像数据,用 torch 构建和训练模型。
  • 使用 torchaudio 处理音频数据,用 torch 构建语音识别模型。

安装

可以通过以下命令安装这三个库:

pip install torch torchvision torchaudio

总结

  • torch:核心库,提供深度学习的基础功能。
  • torchvision:计算机视觉库,提供数据集、模型和图像处理工具。
  • torchaudio:音频处理库,提供音频加载、处理和特征提取工具。

三者共同构成了 PyTorch 的完整生态系统,适用于各种深度学习任务。



本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/894404.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Python安居客二手小区数据爬取(2025年)

目录 2025年安居客二手小区数据爬取观察目标网页观察详情页数据准备工作:安装装备就像打游戏代码详解:每行代码都是你的小兵完整代码大放送爬取结果 2025年安居客二手小区数据爬取 这段时间需要爬取安居客二手小区数据,看了一下相关教程基本…

OpenCV:开运算

目录 1. 简述 2. 用腐蚀和膨胀实现开运算 2.1 代码示例 2.2 运行结果 3. 开运算接口 3.1 参数详解 3.2 代码示例 3.3 运行结果 4. 开运算应用场景 5. 注意事项 6. 总结 相关阅读 OpenCV:图像的腐蚀与膨胀-CSDN博客 OpenCV:闭运算-CSDN博客 …

JavaWeb入门-请求响应(Day3)

(一)请求响应概述 请求(HttpServletRequest):获取请求数据 响应(HttpServletResponse):设置响应数据 BS架构:Browser/Server,浏览器/服务器架构模式。客户端只需要浏览器就可访问,应用程序的逻辑和数据都存储在服务端(维护方便,响应速度一般) CS架构:Client/ser…

【SLAM】于AutoDL云上GPU运行GCNv2_SLAM的记录

配置GCNv2_SLAM所需环境并实现AutoDL云端运行项目的全过程记录。 本文首发于❄慕雪的寒舍 1. 引子 前几天写了一篇在本地虚拟机里面CPU运行GCNv2_SLAM项目的博客:链接,关于GCNv2_SLAM项目相关的介绍请移步此文章,本文不再重复说明。 GCNv2:…

罗格斯大学:通过输入嵌入对齐选择agent

📖标题:AgentRec: Agent Recommendation Using Sentence Embeddings Aligned to Human Feedback 🌐来源:arXiv, 2501.13333 🌟摘要 🔸多代理系统必须决定哪个代理最适合给定的任务。我们提出了一种新的架…

团体程序设计天梯赛-练习集——L1-025 正整数A+B

一年之际在于春,新年的第一天,大家敲代码了吗?哈哈 前言 这道题分值是15分,值这个分,有一小点运算,难度不大,虽然说做出来了,但是有两个小疑点。 L1-025 正整数AB 题的目标很简单…

Leetcode:598

1,题目 2,思路 脑筋急转弯,看题目一时半会还没搞懂意思。 其实不然就是说ops是个矩阵集合,集合的每个矩阵有俩个元素理解为行列边距 m和n是理解为一个主矩阵,计算ops的每个小矩阵还有这个主矩阵的交集返回面积 3&…

web前端12--表单和表格

1、表格标签 使用<table>标签来定义表格 HTML 中的表格和Excel中的表格是类似的&#xff0c;都包括行、列、单元格、表头等元素。 区别&#xff1a;HTML表格在功能方面远没有Excel表格强大&#xff0c;HTML表格不支持排序、求和、方差等数学计算&#xff0c;主要用于布…

【AI】探索自然语言处理(NLP):从基础到前沿技术及代码实践

Hi &#xff01; 云边有个稻草人-CSDN博客 必须有为成功付出代价的决心&#xff0c;然后想办法付出这个代价。 目录 引言 1. 什么是自然语言处理&#xff08;NLP&#xff09;&#xff1f; 2. NLP的基础技术 2.1 词袋模型&#xff08;Bag-of-Words&#xff0c;BoW&#xff…

第1章 量子暗网中的血色黎明

月球暗面的危机与阴谋 量子隧穿效应催生的幽蓝电弧&#xff0c;于环形山表面肆意跳跃&#xff0c;仿若无数奋力挣扎的机械蠕虫&#xff0c;将月球暗面的死寂打破&#xff0c;徒增几分诡异。艾丽伫立在被遗弃的“广寒宫”量子基站顶端&#xff0c;机械义眼之中&#xff0c;倒映着…

AI-ISP论文Learning to See in the Dark解读

论文地址&#xff1a;Learning to See in the Dark 图1. 利用卷积网络进行极微光成像。黑暗的室内环境。相机处的照度小于0.1勒克斯。索尼α7S II传感器曝光时间为1/30秒。(a) 相机在ISO 8000下拍摄的图像。(b) 相机在ISO 409600下拍摄的图像。该图像存在噪点和色彩偏差。©…

【Git】初识Git Git基本操作详解

文章目录 学习目标Ⅰ. 初始 Git&#x1f4a5;注意事项 Ⅱ. Git 安装Linux-centos安装Git Ⅲ. Git基本操作一、创建git本地仓库 -- git init二、配置 Git -- git config三、认识工作区、暂存区、版本库① 工作区② 暂存区③ 版本库④ 三者的关系 四、添加、提交更改、查看提交日…

使用 Spring JDBC 进行数据库操作:深入解析 JdbcTemplate

目录 1. Spring JDBC 简介 2. JdbcTemplate 介绍 3. 创建数据库和表 4. 配置 Spring JDBC 5. 创建实体类 6. 使用 JdbcTemplate 实现增、删、改、查操作 7. Spring JDBC 优点 8. 小结 1. Spring JDBC 简介 Spring JDBC 是 Spring 框架中的一个模块&#xff0c;旨在简化…

BUUCTF [Black Watch 入群题]PWN1 题解

1.下载文件 exeinfo checksec 32位 IDA32 看到关键函数 read两次 第一次read的变量s在bss段&#xff1b;第二次的buf到ebp距离为 24 但是第二次的read字节只能刚好填满返回地址 传不进去变量 所以想到栈迁移 将栈移动到变量s所在位置上来 同时 这题开了NX 无直接的binsh和s…

Cubemx文件系统挂载多设备

cubumx版本&#xff1a;6.13.0 芯片&#xff1a;STM32F407VET6 在上一篇文章中介绍了Cubemx的FATFS和SD卡的配置&#xff0c;由于SD卡使用的是SDIO通讯&#xff0c;因此具体驱动不需要自己实现&#xff0c;Cubemx中就可以直接配置然后生成SDIO的驱动&#xff0c;并将SD卡驱动和…

java练习(2)

回文数&#xff08;题目来自力扣&#xff09; 给你一个整数 x &#xff0c;如果 x 是一个回文整数&#xff0c;返回 true &#xff1b;否则&#xff0c;返回 false 。 回文数 是指正序&#xff08;从左向右&#xff09;和倒序&#xff08;从右向左&#xff09;读都是一样的整…

使用 Tauri 2 + Next.js 开发跨平台桌面应用实践:Singbox GUI 实践

Singbox GUI 实践 最近用 Tauri Next.js 做了个项目 - Singbox GUI&#xff0c;是个给 sing-box 用的图形界面工具。支持 Windows、Linux 和 macOS。作为第一次接触这两个框架的新手&#xff0c;感觉收获还蛮多的&#xff0c;今天来分享下开发过程中的一些经验~ 为啥要做这个…

ComfyUI安装调用DeepSeek——DeepSeek多模态之图形模型安装问题解决(ComfyUI-Janus-Pro)

ComfyUI 的 Janus-Pro 节点&#xff0c;一个统一的多模态理解和生成框架。 试用&#xff1a; https://huggingface.co/spaces/deepseek-ai/Janus-1.3B https://huggingface.co/spaces/deepseek-ai/Janus-Pro-7B https://huggingface.co/spaces/deepseek-ai/JanusFlow-1.3B 安装…

索引的底层数据结构、B+树的结构、为什么InnoDB使用B+树而不是B树呢

索引的底层数据结构 MySQL中常用的是Hash索引和B树索引 Hash索引&#xff1a;基于哈希表实现的&#xff0c;查找速度非常快&#xff0c;但是由于哈希表的特性&#xff0c;不支持范围查找和排序&#xff0c;在MySQL中支持的哈希索引是自适应的&#xff0c;不能手动创建 B树的…

RK3568中使用QT opencv(显示基础图像)

文章目录 一、查看对应的开发环境是否有opencv的库二、QT使用opencv一、查看对应的开发环境是否有opencv的库 在开发板中的/usr/lib目录下查看是否有opencv的库: 这里使用的是正点原子的ubuntu虚拟机,在他的虚拟机里面已经安装好了opencv的库。 二、QT使用opencv 在QT pr…