Neural Networks

神经网络能够使用torch.nn包构建神经网络。

现在你已经对autogard有了初步的了解,nn基于autograd来定义模型并进行微分。一个nn.Module包含层,和一个forward(input)方法并返回output。

以如下分类数字图片的网络所示:

这是一个简单的前馈网络。它接受输入,经过一层接着一层的神经网络层,最终得到输出。

一个神经网络典型的训练流程如下:

  • 定义拥有可学习的参数的神经网络
  • 迭代数据集作为输入
  • 经过网络处理输入
  • 计算损失(离正确输出的距离)
  • 反向传播梯度到网络参数
  • 更新网络的权重,比如简单的更新规则:weight=weight-learning_rate*gradient

定义网络

让我们定义这个网络:  

import torch.nn as nn
import torch.nn.functional as F
class Net(nn.Module):def __init__(self):super(Net,self).__init__()# 一个输入图片通道,六个输出通道,5*5的卷积核self.conv1=nn.Conv2d(1,6,5)self.conv2=nn.Conv2d(6,16,5)# 一个仿射操作:y=wx+bself.fc1=nn.Linear(16*5*5,120)self.fc2=nn.Linear(120,84)self.fc3=nn.Linear(84,10)def forward(self,x):# 2*2窗口的最大赤化x=F.max_pool2d(F.relu(self.conv1(x)),(2,2))# 如果是一个方块就只需要指定一个长度x=F.max_pool2d(F.relu(self.conv2(x)),2)x=x.view(-1,self.num_flat_features(x))x=F.relu(self.fc1(x))x=F.relu(self.fc2(x))x=self.fc3(x)return xdef num_flat_features(self,x):
     #第一个尺寸是batch sizesize
=x.size()[1:]print(size)num_features=1for s in size:num_features*=sreturn num_featuresnet=Net() print(net)
out:
Net((conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))(conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))(fc1): Linear(in_features=400, out_features=120, bias=True)(fc2): Linear(in_features=120, out_features=84, bias=True)(fc3): Linear(in_features=84, out_features=10, bias=True)
)

你只需要定义forward函数,backward函数(计算梯度的地方)是自动定义的。你能够在forward中使用任意的tensor运算。

模型可学习的参数将通过net.parameters()返回

params=list(net.parameters())
print(len(params))
print(params[0].size())

out:
10
torch.Size([6, 1, 5, 5])

让我们试一下随机的32*32输入,注意:这个网络(LeNet)期望的输入尺寸是32*32。为了在MNIST数据集上使用这个网络,请将数据集的图片调整到32*32。

input=torch.randn(1,1,32,32)
out=net(input)
print(out)

out:
tensor([[ 0.0355, -0.0294, -0.0025, -0.0743, -0.0168, -0.0202, -0.0558,
0.0803, -0.0162, -0.1153]])

将所有参数的梯度缓冲变为0并使用随机梯度进行后向传播:

net.zero_grad()
out.backward(torch.randn(1,10))

 

!注意:

torch.nn只支持最小批。整个torch.nn包只支持输入的样本是一个最小批,而不是一个单一样本.

举例来说,nn.Conv2d将会接收4维的Tensor,nSamples*nChannels*Heights*Width.

如果你有一个单一样本,可以使用input.unsqueeze(0)来增加一个虚假的批维度。

在进行进一步处理前,让我们简要重复目前为止出现的类。

扼要重述:

  • torch.Tensor- 一个支持自动求导操作比如backward()的多维数组。同时保留关于tensor的梯度.
  • nn.Module- 神经网络模型。简易的封装参数的方法,帮助将它们转移到GPU上,导出加载等等.
  • nn.Parameters - 一类Tensor,在作为Module属性指定时会自动注册为一个parameter.
  • autograd.Function- 自动求导操作前向与后向的实现。每个tensor操作,至少创建一个Functional节点,它连接到创建Tensor的函数并编码它的历史

在这一节,我们包含了:

  • 定义一个神经网络
  • 处理输入并调用后向传播

还剩下:

  • 计算损失
  • 更新网络的权重

 

损失函数:

 损失函数接收对(输出,目标)作为输入,计算一个值估计输出与目标之间的距离。

nn包下有一些不同的损失函数。一个简单的损失是nn.MSELoss,它计算的是输入与输出之间的均方误差。

比如:

output=net(input)
target=torch.randn(10)
target=target.view(1,-1)
criterion=nn.MSELoss()loss=criterion(output,target)
print(loss)
out:
tensor(1.1941)

现在,如果你如果按照loss的反向传播方向,使用.grad_fn属性,你将会看到一个计算图如下所示:

input -> conv2d -> relu -> maxpool2d -> conv2d -> relu -> maxpool2d -> view -> linear -> relu -> linear -> relu -> linear -> MSELoss -> loss

所以,当你调用loss.backward(),整个图关于损失求导,并且图中所有requires_grad=True的tensor将会有它们的.grad属性。Tensor的梯度是累加的。

为了说明这一点,我们跟踪backward的部分步骤:

print(loss.grad_fn)  #MSELoss
print(loss.grad_fn.next_functions[0][0]) #Linear
print(loss.grad_fn.next_functions[0][0].next_functions[0][0]) # ReLU
out:
<MseLossBackward object at 0x0000020E2E1289B0>
<AddmmBackward object at 0x0000020E2BF48048>
<ExpandBackward object at 0x0000020E2BF48048>

 

Backprop

为了反向传播error,我们需要做的就是loss.backward()。你需要清除现有的梯度,否则梯度将会累计到现有梯度上。

 现在我们会调用loss.backward(),观察调用backward前后conv1层偏差的梯度。

net.zero_grad()print('conv1.bias.grad before backward')
print(net.conv1.bias.grad)
loss.backward()
print('conv1.bias.grad after backward')
print(net.conv1.bias.grad)

out:
conv1.bias.grad before backward
None # 上一个版本将会是一个为0的向量
conv1.bias.grad after backward
tensor(1.00000e-03 *
[ 4.0788, 1.9541, 5.8585, -2.3754, 2.3815, 1.3351])

 

 现在我们知道了如何使用loss函数

稍后阅读:

神经网络包包含各种模型和loss函数,它们组成了深度神经网络的构建区块。完整的文档在这里。http://pytorch.org/docs/nn

 剩下来需要学习的是:

  • 更新网络的权重

 

更新权重

实际中使用的最简单更新规则是随机梯度下降(SGD):

weight=weight-learning_rate*gradient

 我们能够使用简单的python代码实现:

learning_rate=0.01
for f in net.parameters():f.data.sub_(f.grad.data*learning_rate)

 

然而,当我们使用神经网络,你想要使用各种不同的更新规则比如SGD,Nesterov-SGD,Adam,RMSProp等。为了做到这一点,我们建立了一个小的包torch.optim实现了这些方法。使用它非常简单。

import torch.optim as optim#create your optimizer
optimizer =optim.SGD(net.parameters(),lr=0.01)# in your training loop
optimizer.zero_grad()
output=net(input)
loss=criterion(output,target)
loss.backward()
optimizer.step()

 

!注意:

 手动使用optimizer.zero_grad()来将梯度缓冲变为0。这在Backprop章节进行了解释,因为梯度是累加的。

转载于:https://www.cnblogs.com/Thinker-pcw/p/9635572.html

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/365777.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

文件服务器磁盘配额管理,Windows2008 (FSRM)文件服务器资源管理器网站文件夹磁盘配额管理...

在windows server里提供了一个功能强大的管理工具&#xff0c;就是文件服务器资源管理器。简称FSRM(File Server Resource Manager)通过此工具&#xff0c;可能对网站进去配额以及相关服务。我们通过使用FSRM&#xff0c;可以为文件夹和卷设置配额&#xff0c;主动屏蔽文件&…

python读取hdf5文件_Python处理HDF5文件

笔记地址&#xff1a;有道云笔记 h5py 的安装 pip install h5py 读取 H5 文件 import h5py import numpy as np # 打开文件 f h5py.File(test-dev.h5, r) H5 中的group可以类比为字典&#xff0c;因此我们可以用keys()来获取键值。 >>> f.keys() [umy_xmax, umy_xmin…

数据分析sql面试必会6题经典_面试数据分析会遇到的SQL题

「1」说在前面数据存放在数据库里&#xff0c;以表的形式分门别类。宜家的商品(数据)放在宜家的仓库(数据库)里&#xff0c;以货位的形式分门别类。在宜家&#xff0c;可以通过商品上的编号&#xff0c;查到商品在仓库的排号和位号&#xff0c;取到商品。SQL 语言是一种通用的数…

Hibernate 4.3 ORM工具

Hibernate最近发布了Hibernate ORM 4.3的最终版本&#xff0c;它是一个基于Java的ORM框架&#xff0c;它还支持存储过程和实体图。 发行了ORM Tool Hibernate 4.3&#xff0c;实现了JPA 2.1规范&#xff0c;引入了该发行版的主要功能&#xff0c;简而言之&#xff1a; 支持在提…

301、404、200、304、500HTTP状态

一些常见的状态码为&#xff1a; 200 - 服务器成功返回网页 404 - 请求的网页不存在 503 - 服务器超时 下面提供 HTTP 状态码的完整列表。点击链接可了解详情。您也可以访问 HTTP 状态码上的 W3C 页获取更多信息。 一、临时响应  1xx(临时响应) 表示临时响应并需要请求者继续…

SpringBoot项目中,获取配置文件信息

1.在配置文件中设置信息&#xff0c;格式如下 wechat:mpAppId: wxdf2b09f280e6e6e2mpAppSecret: f924b2e9f140ac98f9cb5317a8951c71 如果是多级目录&#xff0c;则 project:url:sell: http://localhost:8080 2.获取配置文件信息&#xff08;三种方法&#xff09; 2.1Configurat…

ajax环境配置tomcat,jcreator+tomcat环境配置

有的时候因为机器硬件原因&#xff0c;在使用eclipse的时候明显表现不足&#xff0c;其实&#xff0c;仔细想想&#xff0c;我们做web开发的时候&#xff0c;java方面也就是一些非gui类的开发&#xff0c;比如&#xff1a;action&#xff0c;service&#xff0c;dao等等。这样的…

oppo 手机侧滑快捷菜单_OPPO刚秀出卷轴屏手机,就被打了一记响亮的“耳光”

在刚刚过去的未来科技大会上&#xff0c;我国国产手机厂商 oppo可谓是出尽了风头&#xff0c;因为他们推出一款名叫“OPPO X 2021”的卷轴屏概念手机&#xff0c;并且展出了可操作的概念机实物&#xff0c;着实让所有人都惊艳了一把。因此我国的一些自媒体又嗨了&#xff0c;用…

python爬取会议论文pdf_【python2.7】爬取知网论文

# -*- coding: utf-8 -*-import timeimport urllibimport urllib2import cookielibfrom lxml import etreeimport random爬取第一页&#xff0c;获取共页数爬取第二页至最后一页# 下载当前页所有文章的pdf或cajdef download_paper(treedata, opener, localdir):传入参数&#x…

活性卡桑德拉

或是冒险从Cassandra被动地读取数据。 总览 让我们首先尝试从编程的角度定义什么是反应性。 功能反应式编程是使用功能性编程的构建块进行反应式编程的编程范例。 函数式编程是一种编程范例&#xff0c;是一种构建计算机程序的结构和元素的样式&#xff0c;这种处理将计算视…

UVA1602 Lattice Animals 搜索+剪枝

题目大意 给出一个$w\times h$的网格&#xff0c;定义一个连通块为一个元素个数为$n$的方格的集合$A,\forall x\in A, \exists y\in A$&#xff0c;使得$x,y$有一条公共边。现要求一个元素个数极多的连通块的集合$K_N$&#xff0c;使得$\forall A,B\in K_n$&#xff0c;不存在一…

js 停止事件冒泡 阻止浏览器的默认行为

在前端开发工作中&#xff0c;由于浏览器兼容性等问题&#xff0c;我们会经常用到“停止事件冒泡”和“阻止浏览器默认行为”。 浏览器默认行为&#xff1a; 在form中按回车键就会提交表单&#xff1b;单击鼠标右键就会弹出context menu. a标签 1..停止事件冒泡 JavaScrip…

魔域传说显示与服务器断开连接,《魔域传说》合服公告

8月2日合服公告亲爱的勇士&#xff0c;为了给大家提供更加优质的游戏体验&#xff0c;《魔域传说》将于2021年8月2日14:00对部分服务器进行合服维护&#xff0c;维护时长预计3小时&#xff0c;维护完成时间视维护情况可能提前或延后&#xff0c;在维护期间将不能登陆服务器&…

python怎么打开程序管理器_Python 进程管理工具 Supervisor 使用教程

因为我的个人网站 restran.net 已经启用&#xff0c;博客园的内容已经不再更新。请访问我的个人网站获取这篇文章的最新内容&#xff0c;Python 进程管理工具 Supervisor 使用教程 Supervisor 是基于 Python 的进程管理工具&#xff0c;只能运行在 Unix-Like 的系统上&#xff…

Hive的伴奏_OURDEN INSTRUMENTALS MIXTAPE Vol.108 “Sober” 伴奏合辑

OURDEN INSTRUMENTALS MIXTAPE Vol.108Sober曲目列表 Track List :Anne Tello – Love Transformation (Prod. By Peter Monk)Blac Youngsta – Left (Prod. By Yung Lan)Bling X – Missing You (Prod. By Phivestarr Productions)BlocBoy JB – Ali (Prod. By Denaro Love)Bl…

命名空间和程序集

命名空间 命名空间是在逻辑上分割代码&#xff0c;程序集是在物理上分割代码。 嵌套命名空间 namespace one { namespace two { } } 通过one.two引用内部嵌套的命名空间的代码。 命名空间不必和程序集同名。 类的可见性 internal 修饰的类&#xff0c;仅能在本程序集中访问。 p…

WebSocket和Java

WebSocket是一项很酷的新技术&#xff0c;它允许浏览器与服务器之间进行实时双向通信&#xff0c;而几乎没有开销。 我在这里想要做的是&#xff0c;提供一个非常简洁但足够全面的概述&#xff0c;以介绍如何开始使用该技术。 因此&#xff0c;从以下几件事开始&#xff1a; 在…

网页顶部进度条-NProcess.js

背景 有些网站&#xff0c;比如github上在查看项目文件夹层级时会在网页顶部出现一个 进度条&#xff0c;虽然是PC端却有一种移动端体验&#xff0c;个人认为可以提升使用体验&#xff0c;经查阅相关资料后&#xff0c;找到一个NProgress.js全站进度条插件 示例 在使用vue开发S…

点击图片放大至原始图片大小

有些时候为了排版的整洁&#xff0c;页面展示的图片不得不都是限定宽高的&#xff0c;如果想要点击图片放大至原始大小进行预览&#xff0c;再次点击回到原来样子&#xff0c;就要用到下面的代码了&#xff1a; var _w parseInt($(window).width());//获取浏览器的宽度$("…

ft服务器设置传输协议,ft服务器设置成主动模式

ft服务器设置成主动模式 内容精选换一换如果您选择使用SFS Turbo实现文件共享存储&#xff0c;此章节操作可跳过&#xff0c;您可以参见《SAP HANA用户指南》中的“格式化磁盘”章节&#xff0c;挂载Backup卷。NFS Server磁盘需要格式化&#xff0c;并挂载到相应的目录后&#…