Pytorch 三小时极限入门教程

一、引言

在当今的人工智能领域,深度学习占据了举足轻重的地位。而 Pytorch 作为一款广受欢迎的深度学习框架,以其简洁、灵活的特性,吸引了大量开发者投身其中。无论是科研人员探索前沿的神经网络架构,还是工程师将深度学习技术落地到实际项目,Pytorch 都提供了强大的支持。本教程将带你从零基础开始,一步步深入了解 Pytorch 的核心知识,助你顺利踏上深度学习的征程。

二、Pytorch 基础环境搭建

安装 Anaconda

Anaconda 是一个强大的 Python 包管理器和环境管理器,方便我们创建独立的 Python 开发环境。首先,从 Anaconda 官方网站下载对应操作系统的安装包,一路默认安装即可。安装完成后,打开终端(Linux/Mac)或命令提示符(Windows),输入 conda --version 验证是否安装成功。

创建虚拟环境

使用 conda create -n pytorch_env python=3.8 创建一个名为 pytorch_env 的虚拟环境,这里指定 Python 版本为 3.8,你可以根据实际需求调整。激活虚拟环境,在 Linux/Mac 下使用 source activate pytorch_env,Windows 下使用 activate pytorch_env。

安装 Pytorch

访问 Pytorch 官方网站,根据你的系统配置(如 CUDA 是否可用)选择合适的安装命令。例如,如果你的电脑有 NVIDIA GPU 且支持 CUDA 11.3,安装命令可能为 conda install pytorch torchvision torchaudio cudatoolkit=11.3 -c pytorch。如果没有 GPU,则选择 CPU 版本的安装命令,如 conda install pytorch torchvision torchaudio cpuonly -c pytorch。安装完成后,在 Python 交互式环境中输入 import torch,没有报错则说明安装成功。

三、张量(Tensor):深度学习的基石

张量的定义与创建

张量是 Pytorch 中最基本的数据结构,类似于 NumPy 中的数组,但具有更强的功能。可以使用 torch.tensor() 函数从 Python 列表或 NumPy 数组创建张量,例如:

import torchimport numpy as np# 从列表创建张量data_list = [1, 2, 3, 4]tensor_from_list = torch.tensor(data_list)# 从 NumPy 数组创建张量np_array = np.array([5, 6, 7, 8])tensor_from_numpy = torch.from_numpy(np_array)

还可以使用 torch.zeros()、torch.ones()、torch.rand() 等函数创建具有特定形状的全 0、全 1 或随机值张量。

张量的属性与操作

张量具有形状(shape)、数据类型(dtype)等属性。可以通过 .shape 和 .dtype 来访问,例如:

tensor = torch.rand(3, 4)print(tensor.shape)print(tensor.dtype)

张量支持丰富的数学运算,如加法、减法、乘法、除法等,操作符重载使得代码简洁直观:

a = torch.rand(2, 3)b = torch.rand(2, 3)c = a + bd = a * b

同时,也有大量的函数可供调用,像 torch.sum()、torch.mean() 等用于统计计算。

四、自动求导(Autograd):神经网络训练的关键

自动求导原理简介

在深度学习中,模型训练的核心是反向传播算法,而 Pytorch 的自动求导机制极大地简化了这一过程。当创建一个张量时,如果设置 requires_grad=True,Pytorch 会记录该张量上的所有操作,构建一个计算图。在反向传播时,利用这个计算图自动计算梯度。

示例:简单函数求导

x = torch.tensor([2.], requires_grad=True)y = x ** 2 + 3 * xy.backward()print(x.grad)

这里定义了一个简单的函数 ,对 x 求导后,x.grad 存储了梯度值,即 在 时的值 7。

 复杂模型中的应用

在构建神经网络时,模型参数都设置为 requires_grad=True。在每一次前向传播计算损失后,通过 loss.backward() 反向传播梯度,然后使用优化器(如 SGD、Adam 等)根据梯度更新参数,实现模型的训练。

五、神经网络模块(nn.Module):构建模型的利器

自定义神经网络

继承 nn.Module 类可以方便地自定义神经网络。首先在 __init__() 函数中定义模型的层结构,如全连接层 nn.Linear,卷积层 nn.Conv2d 等,然后在 forward() 函数中定义数据的前向传播路径。

import torch.nn as nnclass SimpleNet(nn.Module):def __init__(self):super(SimpleNet, self).__init__()self.fc1 = nn.Linear(10, 20)self.fc2 = nn.Linear(20, 1)def forward(self, x):x = torch.relu(self.fc1(x))x = self.fc2(x)return x

这里定义了一个简单的两层全连接神经网络,输入维度为 10,中间层维度为 20,输出维度为 1,中间使用 ReLU 作为激活函数。

预训练模型的使用与微调

Pytorch 提供了丰富的预训练模型,如 ResNet、VGG 等经典的图像分类模型。可以通过 torchvision.models 模块加载预训练模型,然后根据自己的任务需求,修改最后几层的结构并进行微调。例如:

import torchvision.models as modelsresnet = models.resnet18(pretrained=True)# 修改最后一层输出维度为自定义类别数resnet.fc = nn.Linear(resnet.fc.in_features, 10)

这使得在数据量有限的情况下,也能利用预训练模型的强大特征提取能力,快速搭建高性能模型。

六、数据加载与预处理(DataLoader)

数据集类的构建

要使用自己的数据训练模型,需要构建自定义数据集类,继承 torch.utils.data.Dataset。在类中实现 __getitem__() 方法用于获取单个样本及其标签,__len__() 方法返回数据集的大小。例如,对于图像分类数据集:

from torch.utils.data import Datasetimport osimport cv2class ImageDataset(Dataset):def __init__(self, root_dir, transform=None):self.root_dir = root_dirself.image_files = os.listdir(root_dir)self.transform = transformdef __getitem__(self, index):image_path = os.path.join(self.root_dir, self.image_files[index])image = cv2.imread(image_path)image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)label = int(self.image_files[index].split('.')[0])if self.transform:image = self.transform(image)return image, labeldef __len__(self):return len(self.image_files)

数据加载器的使用

使用 torch.utils.data.DataLoader 将数据集封装成可迭代的数据加载器,方便在训练过程中批量获取数据。可以设置批量大小(batch_size)、是否打乱数据(shuffle)等参数,例如:

from torch.utils.data import DataLoaderdataset = ImageDataset(root_dir='data/images', transform=transforms.ToTensor())dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

在训练循环中,通过遍历数据加载器获取批量数据,送入模型进行训练。

七、模型训练与评估

训练循环

模型训练通常包括多个 epoch,每个 epoch 遍历一遍整个数据集。在每个 epoch 内,按批次获取数据,前向传播计算损失,反向传播更新参数。以下是一个简单的训练循环示例:

model = SimpleNet()criterion = nn.MSELoss()optimizer = torch.optim.SGD(model.parameters(), lr=0.01)for epoch in range(10):running_loss = 0.0for i, (inputs, labels) in enumerate(dataloader):optimizer.zero_grad()outputs = model(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item()print(f'Epoch {epoch + 1}, Loss: {running_loss / len(dataloader)}')

评估指标与方法

根据任务不同,评估指标各异。对于分类任务,常用准确率(Accuracy),可以通过比较模型预测结果与真实标签计算得出:

correct = 0total = 0with torch.no_grad():for inputs, labels in dataloader:outputs = model(inputs)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()accuracy = correct / totalprint(f'Accuracy: {accuracy}')

对于回归任务,可能使用均方误差(MSE)、平均绝对误差(MAE)等指标。

八、模型保存与加载

保存模型

可以使用 torch.save() 保存模型的参数或整个模型结构,例如保存模型参数:

torch.save(model.state_dict(), 'model.pth')

若要保存整个模型,包括结构和参数:

torch.save(model, 'whole_model.pth')

加载模型

加载模型参数时,先创建模型实例,再使用 model.load_state_dict(torch.load('model.pth')) 加载。若加载整个模型,则直接 model = torch.load('whole_model.pth')。加载后,模型即可用于预测或继续训练。

九、可视化工具(TensorBoard)

安装与配置

TensorBoard 是一个强大的可视化工具,用于监控模型训练过程。使用 pip install tensorboard 安装,在 Pytorch 代码中引入相关模块:

from torch.utils.tensorboard import SummaryWriter

创建一个 SummaryWriter 实例,指定日志目录,如 writer = SummaryWriter('logs')。

可视化训练过程

在训练过程中,可以使用 writer.add_scalar() 记录损失、准确率等指标随 epoch 的变化:

for epoch in range(10):# 训练代码...writer.add_scalar('Loss', running_loss / len(dataloader), epoch)writer.add_scalar('Accuracy', accuracy, epoch)writer.close()

运行 tensorboard --logdir=logs 命令后,在浏览器中打开相应地址,即可查看可视化图表,直观了解模型训练动态。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/66198.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

凸包(convex hull)简述

凸包(convex hull)简述 这里主要介绍二维凸包,二维凸多边形是指所有内角都在 [ 0 , Π ] [0,\Pi ] [0,Π]范围内的简单多边形。 凸包是指在平面上包含所有给定点的最小凸多边形。 数学定义:对于给定集合 X X X,所有…

小波与傅里叶变换在去噪效果上的对比分析-附Matlab源程序

👨‍🎓 博主简介:博士研究生 🔬 超级学长:超级学长实验室(提供各种程序开发、实验复现与论文指导) 📧 个人邮箱:easy_optics126.com 🕮 目 录 摘要一、…

CVPR2019 | AA | 特征空间扰动产生更具迁移性的对抗样本

Feature Space Perturbations Yield More Transferable Adversarial Examples 摘要-Abstract引言-Introduction相关工作-Related WorkTransferability Metrics-迁移性指标激活攻击方法-Activation Attack Methodology损失函数-Loss Function攻击算法-Attack Algorithm 实验设置…

游戏如何检测Root权限

Root权限,即超级用户权限,在Android系统中,获取Root权限意味着用户可以修改系统文件、移除预装应用、安装特殊应用等。 在Root环境下,游戏面临着相当大的安全隐患,用户获取了最高权限,意味着可以通过各类工…

MySQL性能优化explain关键字详解

系列文章目录 一、MySQL数据结构选择 二、MySQL性能优化explain关键字详解 三、MySQL索引优化 文章目录 系列文章目录一、explain是什么?二、explain字段详解2.1、ID2.2、select_type2.3、table2.4、partitions2.5、type(重点)2.6、key2.7、…

【Go学习】-01-5-网络编程

【Go学习】-01-5-网络编程 1 互联网协议介绍1.1 互联网分层模型 2 Go网络编程2.1 socket编程2.1.1 socket图解2.2.2 TCP编程2.2.3 UDP编程 2.3 http编程2.3.1 web工作流程2.3.2 HTTP协议 2.4 WebSocket编程2.5 聊天室的小例子2.5.1 server.go文件代码2.5.2 hub.go文件代码2.5.3…

推荐系统重排:MMR 多样性算法

和谐共存:相关性与多样性在MMR中共舞 推荐系统【多样性算法】系列文章(置顶) 1.推荐系统重排:MMR 多样性算法 2.推荐系统重排:DPP 多样性算法 引言 在信息检索和推荐系统中,提供既与用户查询高度相关的文…

简历_熟悉缓存高并发场景处理方法,如缓存穿透、缓存击穿、缓存雪崩

系列博客目录 文章目录 系列博客目录1.缓存穿透总结 2.缓存雪崩3.缓存击穿代码总结 1.缓存穿透 缓存穿透是指客户端请求的数据在缓存中和数据库中都不存在,这样缓存永远不会生效,这些请求都会打到数据库。 常见的解决方案有两种: 缓存空对…

Rabbitmq追问1

如果消费端代码异常,未手动确认,那么这个消息去哪里 2024-12-31 21:19:12 如果消费端代码发生异常,未手动确认(ACK)的情况下,消息的处理行为取决于消息队列的实现和配置,以下是基于 RabbitMQ …

STM32-笔记37-吸烟室管控系统项目

一、项目需求 1. 使用 mq-2 获取环境烟雾值,并显示在 LCD1602 上; 2. 按键修改阈值,并显示在 LCD1602 上; 3. 烟雾值超过阈值时,蜂鸣器长响,风扇打开;烟雾值小于阈值时,蜂鸣器不响…

2、pycharm常用快捷命令和配置【持续更新中】

1、常用快捷命令 Ctrl / 行注释/取消行注释 Ctrl Alt L 代码格式化 Ctrl Alt I 自动缩进 Tab / Shift Tab 缩进、不缩进当前行 Ctrl N 跳转到类 Ctrl 鼠标点击方法 可以跳转到方法所在的类 2、使用pip命令安装request库 命令:pip install requests 安装好了…

SpringCloud系列教程:微服务的未来(八)项目部署、DockerCompose

本博客将重点介绍如何在 Docker 环境中部署一个 Java 项目,并使用 Docker Compose 来简化和管理多个服务的协调部署。我们将通过一个典型的 Java Web 应用(如基于 Spring Boot 的应用)为例,演示如何构建、配置和运行 Docker 容器&…

微信小程序滑动解锁、滑动验证

微信小程序简单滑动解锁 效果 通过 movable-view (可移动的视图容器,在页面中可以拖拽滑动)实现的简单微信小程序滑动验证 movable-view 官方说明:https://developers.weixin.qq.com/miniprogram/dev/component/movable-view.ht…

Conda 安装 Jupyter Notebook

文章目录 1. 安装 Conda下载与安装步骤: 2. 创建虚拟环境3. 安装 Jupyter Notebook4. 启动 Jupyter Notebook5. 安装扩展功能(可选)6. 更新与维护7. 总结 Jupyter Notebook 是一款非常流行的交互式开发工具,尤其适合数据科学、机器…

【小程序开发】- 小程序版本迭代指南(版本发布教程)

一,版本号 版本号是小程序版本的标识,通常由一系列数字组成,如 1.0.0、1.1.0 等。版本号的格式通常是 主版本号.次版本号.修订号 主版本号:当小程序有重大更新或不兼容的更改时,主版本号会增加。 次版本号&#xff1a…

【保姆级】sql注入之堆叠注入

一、堆叠注入的原理 mysql数据库sql语句的默认结束符是以";"号结尾,在执行多条sql语句时就要使用结束符隔 开,而堆叠注入其实就是通过结束符来执行多条sql语句 比如我们在mysql的命令行界面执行一条查询语句,这时语句的结尾必须加上分号结束 select * fr…

Word如何设置整段背景色

1) 不是1),也不是2),而是3)的样式 2) 红色标出这个地方有上边框,点击“边框和底纹” 3)点击底纹Tab页,再填充,选择要的颜色就OK啦。

Nginx:性能优化

性能优化是确保 Nginx 在高负载下依然能够高效运行的关键部分。通过合理的配置和调优,可以显著提升 Web 服务的响应速度、吞吐量以及资源利用率。 1. 调整工作进程数、并发连接数以及cpu亲和性 worker_processes:根据 CPU 核心数设置适当的工作进程数。一般cpu有多少核,就设…

分布式事务介绍 Seata架构与原理+部署TC服务 示例:黑马商城

1. 什么是分布式事务? 在分布式系统中,如果一个业务需要多个服务合作完成,而且每一个服务都有事务,多个事务必须同时成功或失败,这样的事务就是分布式事务。其中的每个服务的事务就是一个分支事务。整个业务称为全局事务。 打个比…

C#运动控制系统:雷赛控制卡实用完整例子 C#雷赛开发快速入门 C#雷赛运动控制系统实战例子 C#快速开发雷赛控制卡

雷赛控制技术 DMC系列运动控制卡是一款新型的 PCI/PCIe 总线运动控制卡。可以控制多个步进电机或数字式伺服电机;适合于多轴点位运动、插补运动、轨迹规划、手轮控制、编码器位置检测、IO 控制、位置比较、位置锁存等功能的应用。 DMC3000 系列卡的运动控制函数库功…