Neural Networks

神经网络能够使用torch.nn包构建神经网络。

现在你已经对autogard有了初步的了解,nn基于autograd来定义模型并进行微分。一个nn.Module包含层,和一个forward(input)方法并返回output。

以如下分类数字图片的网络所示:

这是一个简单的前馈网络。它接受输入,经过一层接着一层的神经网络层,最终得到输出。

一个神经网络典型的训练流程如下:

  • 定义拥有可学习的参数的神经网络
  • 迭代数据集作为输入
  • 经过网络处理输入
  • 计算损失(离正确输出的距离)
  • 反向传播梯度到网络参数
  • 更新网络的权重,比如简单的更新规则:weight=weight-learning_rate*gradient

定义网络

让我们定义这个网络:  

import torch.nn as nn
import torch.nn.functional as F
class Net(nn.Module):def __init__(self):super(Net,self).__init__()# 一个输入图片通道,六个输出通道,5*5的卷积核self.conv1=nn.Conv2d(1,6,5)self.conv2=nn.Conv2d(6,16,5)# 一个仿射操作:y=wx+bself.fc1=nn.Linear(16*5*5,120)self.fc2=nn.Linear(120,84)self.fc3=nn.Linear(84,10)def forward(self,x):# 2*2窗口的最大赤化x=F.max_pool2d(F.relu(self.conv1(x)),(2,2))# 如果是一个方块就只需要指定一个长度x=F.max_pool2d(F.relu(self.conv2(x)),2)x=x.view(-1,self.num_flat_features(x))x=F.relu(self.fc1(x))x=F.relu(self.fc2(x))x=self.fc3(x)return xdef num_flat_features(self,x):
     #第一个尺寸是batch sizesize
=x.size()[1:]print(size)num_features=1for s in size:num_features*=sreturn num_featuresnet=Net() print(net)
out:
Net((conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))(conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))(fc1): Linear(in_features=400, out_features=120, bias=True)(fc2): Linear(in_features=120, out_features=84, bias=True)(fc3): Linear(in_features=84, out_features=10, bias=True)
)

你只需要定义forward函数,backward函数(计算梯度的地方)是自动定义的。你能够在forward中使用任意的tensor运算。

模型可学习的参数将通过net.parameters()返回

params=list(net.parameters())
print(len(params))
print(params[0].size())

out:
10
torch.Size([6, 1, 5, 5])

让我们试一下随机的32*32输入,注意:这个网络(LeNet)期望的输入尺寸是32*32。为了在MNIST数据集上使用这个网络,请将数据集的图片调整到32*32。

input=torch.randn(1,1,32,32)
out=net(input)
print(out)

out:
tensor([[ 0.0355, -0.0294, -0.0025, -0.0743, -0.0168, -0.0202, -0.0558,
0.0803, -0.0162, -0.1153]])

将所有参数的梯度缓冲变为0并使用随机梯度进行后向传播:

net.zero_grad()
out.backward(torch.randn(1,10))

 

!注意:

torch.nn只支持最小批。整个torch.nn包只支持输入的样本是一个最小批,而不是一个单一样本.

举例来说,nn.Conv2d将会接收4维的Tensor,nSamples*nChannels*Heights*Width.

如果你有一个单一样本,可以使用input.unsqueeze(0)来增加一个虚假的批维度。

在进行进一步处理前,让我们简要重复目前为止出现的类。

扼要重述:

  • torch.Tensor- 一个支持自动求导操作比如backward()的多维数组。同时保留关于tensor的梯度.
  • nn.Module- 神经网络模型。简易的封装参数的方法,帮助将它们转移到GPU上,导出加载等等.
  • nn.Parameters - 一类Tensor,在作为Module属性指定时会自动注册为一个parameter.
  • autograd.Function- 自动求导操作前向与后向的实现。每个tensor操作,至少创建一个Functional节点,它连接到创建Tensor的函数并编码它的历史

在这一节,我们包含了:

  • 定义一个神经网络
  • 处理输入并调用后向传播

还剩下:

  • 计算损失
  • 更新网络的权重

 

损失函数:

 损失函数接收对(输出,目标)作为输入,计算一个值估计输出与目标之间的距离。

nn包下有一些不同的损失函数。一个简单的损失是nn.MSELoss,它计算的是输入与输出之间的均方误差。

比如:

output=net(input)
target=torch.randn(10)
target=target.view(1,-1)
criterion=nn.MSELoss()loss=criterion(output,target)
print(loss)
out:
tensor(1.1941)

现在,如果你如果按照loss的反向传播方向,使用.grad_fn属性,你将会看到一个计算图如下所示:

input -> conv2d -> relu -> maxpool2d -> conv2d -> relu -> maxpool2d -> view -> linear -> relu -> linear -> relu -> linear -> MSELoss -> loss

所以,当你调用loss.backward(),整个图关于损失求导,并且图中所有requires_grad=True的tensor将会有它们的.grad属性。Tensor的梯度是累加的。

为了说明这一点,我们跟踪backward的部分步骤:

print(loss.grad_fn)  #MSELoss
print(loss.grad_fn.next_functions[0][0]) #Linear
print(loss.grad_fn.next_functions[0][0].next_functions[0][0]) # ReLU
out:
<MseLossBackward object at 0x0000020E2E1289B0>
<AddmmBackward object at 0x0000020E2BF48048>
<ExpandBackward object at 0x0000020E2BF48048>

 

Backprop

为了反向传播error,我们需要做的就是loss.backward()。你需要清除现有的梯度,否则梯度将会累计到现有梯度上。

 现在我们会调用loss.backward(),观察调用backward前后conv1层偏差的梯度。

net.zero_grad()print('conv1.bias.grad before backward')
print(net.conv1.bias.grad)
loss.backward()
print('conv1.bias.grad after backward')
print(net.conv1.bias.grad)

out:
conv1.bias.grad before backward
None # 上一个版本将会是一个为0的向量
conv1.bias.grad after backward
tensor(1.00000e-03 *
[ 4.0788, 1.9541, 5.8585, -2.3754, 2.3815, 1.3351])

 

 现在我们知道了如何使用loss函数

稍后阅读:

神经网络包包含各种模型和loss函数,它们组成了深度神经网络的构建区块。完整的文档在这里。http://pytorch.org/docs/nn

 剩下来需要学习的是:

  • 更新网络的权重

 

更新权重

实际中使用的最简单更新规则是随机梯度下降(SGD):

weight=weight-learning_rate*gradient

 我们能够使用简单的python代码实现:

learning_rate=0.01
for f in net.parameters():f.data.sub_(f.grad.data*learning_rate)

 

然而,当我们使用神经网络,你想要使用各种不同的更新规则比如SGD,Nesterov-SGD,Adam,RMSProp等。为了做到这一点,我们建立了一个小的包torch.optim实现了这些方法。使用它非常简单。

import torch.optim as optim#create your optimizer
optimizer =optim.SGD(net.parameters(),lr=0.01)# in your training loop
optimizer.zero_grad()
output=net(input)
loss=criterion(output,target)
loss.backward()
optimizer.step()

 

!注意:

 手动使用optimizer.zero_grad()来将梯度缓冲变为0。这在Backprop章节进行了解释,因为梯度是累加的。

转载于:https://www.cnblogs.com/Thinker-pcw/p/9635572.html

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/365777.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

文件服务器磁盘配额管理,Windows2008 (FSRM)文件服务器资源管理器网站文件夹磁盘配额管理...

在windows server里提供了一个功能强大的管理工具&#xff0c;就是文件服务器资源管理器。简称FSRM(File Server Resource Manager)通过此工具&#xff0c;可能对网站进去配额以及相关服务。我们通过使用FSRM&#xff0c;可以为文件夹和卷设置配额&#xff0c;主动屏蔽文件&…

SpringBoot项目中,获取配置文件信息

1.在配置文件中设置信息&#xff0c;格式如下 wechat:mpAppId: wxdf2b09f280e6e6e2mpAppSecret: f924b2e9f140ac98f9cb5317a8951c71 如果是多级目录&#xff0c;则 project:url:sell: http://localhost:8080 2.获取配置文件信息&#xff08;三种方法&#xff09; 2.1Configurat…

oppo 手机侧滑快捷菜单_OPPO刚秀出卷轴屏手机,就被打了一记响亮的“耳光”

在刚刚过去的未来科技大会上&#xff0c;我国国产手机厂商 oppo可谓是出尽了风头&#xff0c;因为他们推出一款名叫“OPPO X 2021”的卷轴屏概念手机&#xff0c;并且展出了可操作的概念机实物&#xff0c;着实让所有人都惊艳了一把。因此我国的一些自媒体又嗨了&#xff0c;用…

UVA1602 Lattice Animals 搜索+剪枝

题目大意 给出一个$w\times h$的网格&#xff0c;定义一个连通块为一个元素个数为$n$的方格的集合$A,\forall x\in A, \exists y\in A$&#xff0c;使得$x,y$有一条公共边。现要求一个元素个数极多的连通块的集合$K_N$&#xff0c;使得$\forall A,B\in K_n$&#xff0c;不存在一…

python怎么打开程序管理器_Python 进程管理工具 Supervisor 使用教程

因为我的个人网站 restran.net 已经启用&#xff0c;博客园的内容已经不再更新。请访问我的个人网站获取这篇文章的最新内容&#xff0c;Python 进程管理工具 Supervisor 使用教程 Supervisor 是基于 Python 的进程管理工具&#xff0c;只能运行在 Unix-Like 的系统上&#xff…

ft服务器设置传输协议,ft服务器设置成主动模式

ft服务器设置成主动模式 内容精选换一换如果您选择使用SFS Turbo实现文件共享存储&#xff0c;此章节操作可跳过&#xff0c;您可以参见《SAP HANA用户指南》中的“格式化磁盘”章节&#xff0c;挂载Backup卷。NFS Server磁盘需要格式化&#xff0c;并挂载到相应的目录后&#…

mysql text字段导出_Mysql数据库的各种命令:

一、连接MYSQL格式&#xff1a; mysql -h主机地址 -u用户名 -p用户密码1、连接到本机上的MYSQL。 首先打开DOS窗口&#xff0c;然后进入目录mysqlbin&#xff0c;再键入命令mysql -u root -p&#xff0c;回车后提示你输密码. 注意用户名前可以有空格也可以没有空格&#xff0c;…

运行,JUnit! 跑!!!

JUnit与JavaScript和SVN一起是程序员经常开始使用的一些技术&#xff0c;甚至没有读过一篇博客文章&#xff0c;更不用说一本书了。 也许这是一件好事&#xff0c;因为它们看起来足够简单且易于理解&#xff0c;因此我们无需任何手册即可立即使用它们&#xff0c;但这也意味着它…

css3图形绘制

以下几个例子主要是运用了css3中border、bordr-radius、transform、伪元素等属性来完成的&#xff0c;我们先了解下它们的基本原理。 border&#xff1a;简单的来说border语法主要包含&#xff08;border-width、border-style、border-color&#xff09;三个属性。 „ border-t…

vueh5调用摄像头拍照_潜望式拍照5G手机盘点:售价相差数千元 究竟怎么选?

【dogkeji-科技犬】最近很多网友询问科技犬&#xff0c;目前支持50X潜望式长焦手机都有哪些&#xff0c;可否进行相关手机的推荐&#xff0c;今天就应大家的需求来盘点一下&#xff0c;给各位网友一些参考。推荐一&#xff0c;三星 Galaxy S20 U三星Galaxy S20 5G系列不仅搭载了…

09 事件对象

上篇介绍完我们js的事件流的概念之后&#xff0c;相信大家对事件流也有所了解了。那么接下来我们看一下jquery的事件操作。 在说jquery的每个事件之前&#xff0c;我们先来看一下事件对象 事件对象 Event 对象代表事件的状态&#xff0c;比如事件在其中发生的元素、键盘按键的状…

使用Struts2,Hibernate和MySQL创建个人MusicManager Web应用程序的研讨会

概述&#xff1a; 在本研讨会教程中&#xff0c;我们将使用Struts 2&#xff0c;Hibernate和MySQL数据库开发一个个人音乐管理器应用程序。 该Web应用程序可用于将您的音乐收藏添加到数据库中。 我们将显示用于添加唱片的表格&#xff0c;并在下面显示所有音乐收藏。 通过单击“…

链表快速排序python_Python一行代码实现快速排序的方法

今天将单独为大家介绍一下快速排序&#xff01; 一、算法介绍 排序算法&#xff08;Sorting algorithm&#xff09;是计算机科学最古老、最基本的课题之一。要想成为合格的程序员&#xff0c;就必须理解和掌握各种排序算法。其中"快速排序"&#xff08;Quicksort&…

自定义滚动条样式

啥都不说先看图: 注: 只适合chrom,不适用IE和fireFox 下面展示代码: 1 <html lang"en">2 <head>3 <meta charset"UTF-8">4 <title>CSS3自定义滚动条-轩枫阁</title>5 <style>6 header7 {8 font-family: …

rust为什么显示不了国服_Rust编程语言初探

静态、强类型而又不带垃圾收集的编程语言领域内&#xff0c;很久没有新加入者参与竞争了&#xff0c;大概大部分开发者认为传统的C/C的思路已经不太适合新时代的编程需求&#xff0c;即便有Ken Tompson这样的大神参与设计的golang也采用了GC的思路来设计其新一代的语言&#xf…

wps表格粗线和细线区别_详解论文中的表格技术

今天我们主要学习的技能如下&#xff1a;• 怎样用word做论文要求的三线表• 三线表中辅助线的断开• 表格或者图片自动编号1. 先普及一下&#xff0c;论文中的三线表吧。三线表以其形式简洁、功能分明、阅读方便而在科技论文中被推荐使用。三线表通常只有3条线&#xff0c;即顶…

如何自定义CSS滚动条的样式?

欢迎大家前往腾讯云 社区&#xff0c;获取更多腾讯海量技术实践干货哦~ 本文由前端林子发表 本文会介绍CSS滚动条选择器&#xff0c;并在demo中展示如何在Webkit内核浏览器和IE浏览器中&#xff0c;自定义一个横向以及一个纵向的滚动条。 0.需求 有的时候我们不想使用浏览器默…

RabbitMQ基础知识

RabbitMQ基础知识 一、背景RabbitMQ是一个由erlang开发的AMQP&#xff08;Advanced Message Queue &#xff09;的开源实现。AMQP 的出现其实也是应了广大人民群众的需求&#xff0c;虽然在同步消息通讯的世界里有很多公开标准&#xff08;如 COBAR的 IIOP &#xff0c;或者是 …

iview 级联选择组件_使用 element-ui 级联插件遇到的坑

需求描述【省市区三级联动】组件&#xff1a;Cascader 级联选择器后端需要所选中的地区的名字&#xff0c;如&#xff1a;[北京市, 北京市, 东城区]获取后端省市区具体列表的接口返回数据&#xff1a;// 省 - 参数1 [{value: 1,label: 北京市},... ] // 市 - 参数2 [{value: 1,…

深入理解CPU和异构计算芯片GPU/FPGA/ASIC (上篇)

王玉伟&#xff0c;腾讯TEG架构平台部平台开发中心基础研发组资深工程师&#xff0c;专注于为数据中心提供高效的异构加速云解决方案。目前&#xff0c;FPGA已在腾讯海量图片处理以及检测领域已规模上线。 随着互联网用户的快速增长&#xff0c;数据体量的急剧膨胀&#xff0c;…