使用Pytorch构建神经网络

构建神经网络的典型流程

  • 定义一个拥有可学习参数的神经网络
  • 遍历训练数据集
  • 处理输入数据使其流经神经网络
  • 计算损失值
  • 将网络参数的梯度进行反向传播
  • 以一定的规则更新网络的权重

我们首先定义一个Pytorch实现的神经网络:

# 导入若干工具包
import torch
import torch.nn as nn
import torch.nn.functional as F# 定义一个简单的网络类
class Net(nn.Module):def __init__(self):super(Net, self).__init__()# 定义第一层卷积神经网络, 输入通道维度=1, 输出通道维度=6, 卷积核大小3*3self.conv1 = nn.Conv2d(1, 6, 3)# 定义第二层卷积神经网络, 输入通道维度=6, 输出通道维度=16, 卷积核大小3*3self.conv2 = nn.Conv2d(6, 16, 3)# 定义三层全连接网络self.fc1 = nn.Linear(16 * 6 * 6, 120)self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, 10)def forward(self, x):# (2, 2)的池化窗口下执行最大池化操作x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))x = F.max_pool2d(F.relu(self.conv2(x)), 2)x = x.view(-1, self.num_flat_features(x))x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))x = self.fc3(x)return xdef num_flat_features(self, x):# 计算size, 除了第0个维度上的batch_sizesize = x.size()[1:]num_features = 1for s in size:num_features *= sreturn num_featuresnet = Net()
print(net)

运行结果
在这里插入图片描述
注意:
模型中所有的可训练参数, 可以通过net.parameters()来获得.

params = list(net.parameters())
print(len(params))
print(params[0].size())

运行结果:
在这里插入图片描述

  • 假设图像的输入尺寸为32 * 32:
input = torch.randn(1, 1, 32, 32)
out = net(input)
print(out)

运行结果
在这里插入图片描述

  • 有了输出张量后, 就可以执行梯度归零和反向传播的操作了.
net.zero_grad()
out.backward(torch.randn(1, 10))
  • 注意
    - torch.nn构建的神经网络只支持mini-batches的输入, 不支持单一样本的输入.
    - 比如: nn.Conv2d 需要一个4D Tensor, 形状为(nSamples, nChannels, Height, Width). 如果你的输入只有单一样本形式, 则需要执行input.unsqueeze(0), 主动将3D Tensor扩充成4D Tensor.

损失函数

  • 损失函数的输入是一个输入的pair: (output, target), 然后计算出一个数值来评估output和target之间的差距大小.
  • 在torch.nn中有若干不同的损失函数可供使用, 比如nn.MSELoss就是通过计算均方差损失来评估输入和目标值之间的差距
  • 应用nn.MSELoss计算损失的一个例子:
output = net(input)
target = torch.randn(10)# 改变target的形状为二维张量, 为了和output匹配
target = target.view(1, -1)
criterion = nn.MSELoss()loss = criterion(output, target)
print(loss)

运行结果:
在这里插入图片描述

  • 关于方向传播的链条: 如果我们跟踪loss反向传播的方向, 使用.grad_fn属性打印, 将可以看到一张完整的计算图如下:
input -> conv2d -> relu -> maxpool2d -> conv2d -> relu -> maxpool2d-> view -> linear -> relu -> linear -> relu -> linear-> MSELoss-> loss
  • 当调用loss.backward()时, 整张计算图将对loss进行自动求导, 所有属性requires_grad=True的Tensors都将参与梯度求导的运算, 并将梯度累加到Tensors中的.grad属性中.
print(loss.grad_fn)  # MSELoss
print(loss.grad_fn.next_functions[0][0])  # Linear
print(loss.grad_fn.next_functions[0][0].next_functions[0][0])  # ReLU

运行结果:
在这里插入图片描述
反向传播(backpropagation)

  • 在Pytorch中执行反向传播非常简便, 全部的操作就是loss.backward().
  • 在执行反向传播之前, 要先将梯度清零,否则梯度会在不同的批次数据之间被累加.
    执行一个反向传播的小例子:
# Pytorch中执行梯度清零的代码
net.zero_grad()print('conv1.bias.grad before backward')
print(net.conv1.bias.grad)# Pytorch中执行反向传播的代码
loss.backward()print('conv1.bias.grad after backward')
print(net.conv1.bias.grad)

运行结果:
在这里插入图片描述
更新网络参数

  • 更新参数最简单的算法就是SGD(随机梯度下降).
  • 具体的算法公式表达式为: weight = weight - learning_rate
    gradient 首先用传统的Python代码来实现SGD如下:
learning_rate = 0.01
for f in net.parameters():f.data.sub_(f.grad.data * learning_rate)

然后使用Pytorch官方推荐的标准代码如下:

# 首先导入优化器的包, optim中包含若干常用的优化算法, 比如SGD, Adam等
import torch.optim as optim# 通过optim创建优化器对象
optimizer = optim.SGD(net.parameters(), lr=0.01)# 将优化器执行梯度清零的操作
optimizer.zero_grad()output = net(input)
loss = criterion(output, target)# 对损失值执行反向传播的操作
loss.backward()
# 参数的更新通过一行标准代码来执行
optimizer.step()

小节总结
学习了构建一个神经网络的典型流程:

  • 定义一个拥有可学习参数的神经网络
  • 遍历训练数据集
  • 处理输入数据使其流经神经网络
  • 计算损失值
  • 将网络参数的梯度进行反向传播
  • 以一定的规则更新网络的权重

学习了损失函数的定义:

  • 采用torch.nn.MSELoss()计算均方误差.
  • 通过loss.backward()进行反向传播计算时, 整张计算图将对loss进行自动求导,
    所有属性requires_grad=True的Tensors都将参与梯度求导的运算, 并将梯度累加到Tensors中的.grad属性中.

学习了反向传播的计算方法:

  • 在Pytorch中执行反向传播非常简便, 全部的操作就是loss.backward().
  • 在执行反向传播之前, 要先将梯度清零, 否则梯度会在不同的批次数据之间被累加.
  • net.zero_grad()
  • loss.backward()

学习了参数的更新方法:

  • 定义优化器来执行参数的优化与更新.

    optimizer = optim.SGD(net.parameters(), lr=0.01)

  • 通过优化器来执行具体的参数更新.

    optimizer.step()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/94784.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

亲,您的假期余额已经严重不足了......

引言 大家好,我是亿元程序员,一位有着8年游戏行业经验的主程。 转眼八天长假已经接近尾声了,今天来总结一下大家的假期,聊一聊假期关于学习的看法,并预估一下大家节后大家上班时的样子。 1.放假前一天 即将迎来八天…

基于Web安全的Python编程(1)

目录 一、http协议基础知识介绍 1、http协议分类 2、请求方法 3、什么是URL 4、请求头 5、响应状态码 二、常用Python库、函数、操作 三、http常用请求方法 1、不带参请求 2、带参数请求(get和post存在细微区别) 四、http响应属性获取 1、获取…

计算机网络(六):应用层

参考引用 计算机网络微课堂-湖科大教书匠计算机网络(第7版)-谢希仁 1. 应用层概述 应用层是计算机网络体系结构的最顶层,是设计和建立计算机网络的最终目的,也是计算机网络中发展最快的部分 早期基于文本的应用 (电子邮件、远程登…

分布式架构篇

1、微服务 微服务架构风格,就像是把一个单独的应用程序开发为一套小服务,每个服务运行在自己的进程中,并使用轻量级机制通信,通常是 HTTP API。这些服务围绕业务能力来构建,并通过完全自动化部署机制来独立部署。这些…

Spring 原理

它是一个全面的、企业应用开发一站式的解决方案,贯穿表现层、业务层、持久层。但是 Spring仍然可以和其他的框架无缝整合。 1 Spring 特点 轻量级控制反转面向切面容器框架集合 2 Spring 核心组件 3 Spring 常用模块 4 Spring 主要包 5 Spring 常用注解 bean…

第十七章:Java连接数据库jdbc(java和myql数据库连接)

1.进入命令行:输入cmd,以管理员身份运行 windowsr 2.登录mysql 3.创建库和表 4.使用Java命令查询数据库操作 添加包 导入包的快捷键 选择第四个 找到包的位置 导入成功 创建java项目 二:连接数据库: 第一步:注册驱动…

设计模式 - 策略模式

目录 一. 前言 二. 实现 一. 前言 策略模式 (Strategy Pattern) 是指对一系列的算法定义,并将每一个算法封装起来,而且使它们还可以相互替换。此模式让算法的变化独立于使用算法的客户。 与状态模式的比较 状态模式的类图和策略模式类似,并…

VUE3照本宣科——内置指令与自定义指令及插槽

VUE3照本宣科——内置指令与自定义指令及插槽 前言一、内置指令1.v-text2.v-html3.v-show4.v-if5.v-else6.v-else-if7.v-for8.v-on9.v-bind10.v-model11.v-slot12.v-pre13.v-once14.v-memo15.v-cloak 二、自定义指令三、插槽1.v-slot2.useSlots3.defineSlots() 前言 &#x1f…

Windows下启动freeRDP并自适应远端桌面大小

几个二进制文件 xfreerdp # Linux下的,an X11 Remote Desktop Protocol (RDP) client which is part of the FreeRDP project wfreerdp.exe # Windows下的,freerdp2.0 主程序,freerdp3.0将废弃 sdl-freerdp.exe # Windows下的&…

【AI视野·今日NLP 自然语言处理论文速览 第四十三期】Thu, 28 Sep 2023

AI视野今日CS.NLP 自然语言处理论文速览 Thu, 28 Sep 2023 Totally 38 papers 👉上期速览✈更多精彩请移步主页 Daily Computation and Language Papers Cross-Modal Multi-Tasking for Speech-to-Text Translation via Hard Parameter Sharing Authors Brian Yan,…

STM32CubeMX学习笔记-USB接口使用(CDC虚拟串口)

STM32CubeMX学习笔记-USB接口使用(CDC虚拟串口) 一、USB简介二、新建工程1. 打开 STM32CubeMX 软件,点击“新建工程”2. 选择 MCU 和封装3. 配置时钟4. 配置调试模式 三、USB3.1 参数配置3.3 配置时钟3.4 USB Device 四、生成代码五、查看端口…

MySQL5.7版本与8.0版本在Ubuntu(WSL环境)系统安装

目录 前提条件 1. MySQL5.7版本在Ubuntu(WSL环境)系统安装 1. 1 下载apt仓库文件 1.2 配置apt仓库 1.3 更新apt仓库的信息 1.4 检查是否成功配置MySQL5.7的仓库 5. 安装MySQL5.7 1.6 启动MySQL 1.7 对MySQL进行初始化 1.7.1 输入密码 …

Lucene学习总结之Lucene的索引文件格式

当我们真正进入到Lucene源代码之中的时候,我们会发现: Lucene的索引过程,就是按照全文检索的基本过程,将倒排表写成此文件格式的过程。Lucene的搜索过程,就是按照此文件格式将索引进去的信息读出来,然后计算每篇文档打…

数据结构 2.1 线性表的定义和基本操作

数据结构三要素——逻辑结构、数据的运算、存储结构(物理结构) 线性表的逻辑结构 线性表是具有相同数据类型的n(n>0)个数据元素的有限序列,其中n为表长,当n0时,线性表是一个空表。 每个数…

单层神经网络

神经网络 人工神经网络(Artificial Neural Network,ANN),简称神经网络(Neural Network,NN),是一种模仿生物神经网络的结构和功能的数学模型或计算模型。1943年,McCulloc…

SpringMVC(二)@RequestMapping注解

我们先新建一个Module。 我们的依赖如下所示&#xff1a; <?xml version"1.0" encoding"UTF-8"?> <project xmlns"http://maven.apache.org/POM/4.0.0"xmlns:xsi"http://www.w3.org/2001/XMLSchema-instance"xsi:schemaL…

uni-app:获取元素宽高

效果 代码 这里我定义的宽为500px,高为200排序,控制台输出的结果是502,202。原因是我设置了上下左右宽度各为1px的border边框导致 核心代码分析 // const query uni.createSelectorQuery();表示创建了一个选择器查询实例。通过这个实例&#xff0c;你可以使用不同的方法来选择…

实验3.2 分期付款计算器

目录 实验目的‪‬‪‬‪‬‪‬‪‬‮‬‭‬‪‬‪‬‪‬‪‬‪‬‪‬‮‬‪‬‫‬‪‬‪‬‪‬‪‬‪‬‮‬‫‬‭‬‪‬‪‬‪‬‪‬‪‬‮‬‪‬‪‬‪‬‪‬‪‬‪‬‪‬‮‬‫‬‪‬‪‬‪‬‪‬‪‬‪‬‮‬‪‬‪‬ 实验内容‪‬‪‬‪‬‪‬‪‬‮‬‭‬‪‬‪‬‪‬…

Android LitePal byte[]类型字段不被创建

我创建了以下实体类&#xff0c;主要是用户分享的内容、分享的照片、分享的标题&#xff0c;然后百度了一下LitePal可以识别byte[]&#xff0c;因为需要文件的上传与读取&#xff1a; public class Context extends LitePalSupport {private Integer ContextId;private String…

一文拿捏Spring事务之、ACID、隔离级别、失效场景

1.&#x1f31f;Spring事务 1.编程式事务 事务管理代码嵌入嵌入到业务代码中&#xff0c;来控制事务的提交和回滚&#xff0c;例如TransactionManager 2.声明式事务 使用aop对方法前后进行拦截&#xff0c;然后在目标方法开始之前创建或者加入一个事务&#xff0c;执行完目…