Pytorch Tutorial【Chapter 3. Simple Neural Network】

Pytorch Tutorial【Chapter 3. Simple Neural Network】

文章目录

  • Pytorch Tutorial【Chapter 3. Simple Neural Network】
    • Chapter 3. Simple Neural Network
      • 3.1 Train Neural Network Procedure训练神经网络流程
      • 3.2 Build Neural Network Procedure 搭建神经网络
      • 3.3 Use Loss Function to Backward 利用损失函数进行反向传播
      • 3.4 Update Parameter of NN 更新神经网络参数
        • 3.4.1 Update Manually手动更新参数
        • 3.4.2 Update Automatically自动更新参数
    • Reference

Chapter 3. Simple Neural Network

3.1 Train Neural Network Procedure训练神经网络流程

一个典型的神经网络训练过程包括以下几点:

  • 定义一个包含可训练参数的神经网络

  • 迭代整个输入

  • 通过神经网络处理输入

  • 计算损失(loss)

  • 反向传播梯度到神经网络的参数

  • 更新网络的参数,典型的用一个简单的更新方法:weight = weight - learning_rate *gradient

3.2 Build Neural Network Procedure 搭建神经网络

  • 定义一个类并继承torch.nn.Moulde
  • 使用类torch.nn中的组件和torch.nn.functional中的组件来搭建网络结构
  • 改写该类的forward方法,在此方法中,进一步完善网络结构(如激活函数,池化层等),并且获得返回值的输出

先简要介绍一下需要使用到的torch.nn组件

  • torch.nn.Conv2d(in_channels, out_channels, kernel_size)进行卷积操作,输入是 ( N , C i n , H , W ) (N, C_{in},H,W) (N,Cin,H,W),输出是 ( N , C o u t , H , W ) (N,C_{out},H,W) (N,Cout,H,W)(详见Conv2d),卷积对图片尺寸的影响如下
Image
  • torch.nn.Linear(in_features, out_features) 进行放射变换操作(affine mapping),即 y = x A T + b y=xA^T+b y=xAT+b,(详见Linear)
  • torch.nn.Flatten(start_dim=1, end_dim=- 1),将连续的范围展平为张量(详见Flatten)

再介绍一下需要使用到的torch.nn.functional组件

  • torch.nn.functional.relu(),对输入使用Relu激活函数,具体操作是对每个元素都进行 ReLU ( x ) = max ⁡ ( 0 , x ) \text{ReLU}(x)=\max(0,x) ReLU(x)=max(0,x)的计算,(详见ReLu)

  • torch.nn.functional.softmax(),对输入使用Softmax激活函数,具体操作是对每个元素都进行 Softmax ( x i ) = exp ⁡ ( x i ) ∑ j exp ⁡ ( x j ) \text{Softmax}(x_i)=\frac{\exp(x_i)}{\sum_{j}\exp(x_j)} Softmax(xi)=jexp(xj)exp(xi)(详见Softmax)

  • torch.nn.functional.max_pool2d(input, kernel_size, stride) , 进行最大池化操作,输入是KaTeX parse error: Expected 'EOF', got '_' at position 28: …atch}, \text{in_̲channels}, iH,i…,详见Max_Pool2D)

例如下述代码,我们要搭建一个下图的神经网络结构

Image

​ 输入是 [ batch , channels , height , weight ] [\text{batch},\text{channels},\text{height},\text{weight}] [batch,channels,height,weight]的图片,分别代表批量大小、通道数、图像的高度、图像的宽度。在我们取批量大小为 1 1 1,然后这个例子变成 [ 1 , 1 , 32 , 32 ] [1,1,32,32] [1,1,32,32],张量的变化过程如下所示
KaTeX parse error: Expected 'EOF', got '_' at position 74: …arrow{\text{max_̲pool}}[1,6,14,1…

import torch
import torch.nn as nn
import torch.nn.functional as Fclass Net(nn.Module):def __init__(self):super(Net,self).__init__()# 1 is input_channel 6 is output_channelself.conv1 = nn.Conv2d(1, 6, 5)self.conv2 = nn.Conv2d(6, 16, 5)# affine function y = Wx + bself.fc1 = nn.Linear(16 * 5 * 5, 120)self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, 10)def forward(self,x):# Max pooling over a (2,2) windowx = F.max_pool2d(F.relu(self.conv1(x)), (2,2))# If the size is a square you can only specify a single numberx = F.max_pool2d(F.relu(self.conv2(x)), (2,2))# flat the feature as a vectorx = x.view(-1, self.num_flat_features(x))x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))x = F.softmax(self.fc3(x), dim=1)return xdef num_flat_features(self,x):size = x.size()[1:] # all dimensions except the batch dimensionnum_feature = 1for s in size:num_feature *= sreturn num_featurenet = Net()
print(net)

class Net(nn.Module):def __init__(self):super(Net,self).__init__()# 1 is input_channel 6 is output_channelself.conv1 = nn.Conv2d(1, 6, 5)self.conv2 = nn.Conv2d(6, 16, 5)# affine function y = Wx + bself.flat = nn.Flatten(1,-1)    # flat the featureself.fc1 = nn.Linear(16 * 5 * 5, 120)self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, 10)def forward(self,x):# Max pooling over a (2,2) windowx = F.max_pool2d(F.relu(self.conv1(x)), (2,2))# If the size is a square you can only specify a single numberx = F.max_pool2d(F.relu(self.conv2(x)), (2,2))# flat the feature as a vectorx = self.flat(x)x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))x = F.softmax(self.fc3(x), dim=1)return xnet = Net()
print(net)

结果如下,可以查看网络结构

Net((conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))(conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))(fc1): Linear(in_features=400, out_features=120, bias=True)(fc2): Linear(in_features=120, out_features=84, bias=True)(fc3): Linear(in_features=84, out_features=10, bias=True)
)

一个模型可训练的参数可以通过调用 net.parameters() 返回,

params = list(net.parameters())
print(len(params))
print(params[0].size())  # conv1's .weight

结果如下

10
torch.Size([6, 1, 5, 5])

然后我们可以自定义一些输入,来获得网络的输出

input = torch.randn(3, 1, 32, 32)
out = net(input)
print(out.data)
sum = 0
for i in out.data[0]:sum += i.item()
print(sum)

输出如下

tensor([[0.0996, 0.1000, 0.0907, 0.0946, 0.0930, 0.1067, 0.1044, 0.1119, 0.1073,0.0918],[0.0975, 0.1005, 0.0913, 0.0935, 0.0939, 0.1070, 0.1045, 0.1121, 0.1065,0.0933],[0.0968, 0.1009, 0.0896, 0.0978, 0.0903, 0.1107, 0.1055, 0.1130, 0.1043,0.0913]])
0.9999999925494194

torch.nn.Module.zero_grad()把所有参数梯度缓存器置零,

net.zero_grad() #把所有参数梯度缓存器置零,用随机的梯度来反向传播
out.backward(torch.randn_like(out))
net.conv1.bias.grad  #查看Conv卷积层的偏置和权重一些参数的梯度
net.conv1.weight.grad

结果如下

tensor([-0.0064,  0.0136, -0.0046,  0.0008, -0.0044,  0.0021])
...

3.3 Use Loss Function to Backward 利用损失函数进行反向传播

现在我们已经完成了

  • 定义一个神经网络

  • 处理输入以及调用反向传播

还剩下:

  • 计算损失值

  • 更新网络中的权重

损失函数

一个损失函数需要一对输入:模型输出和目标,然后计算一个值来评估输出距离目标有多远。

有一些不同的损失函数在 nn 包(详细请查看loss-functions)。一个简单的损失函数就是 nn.MSELoss ,这计算了均方误差

input = torch.randn(3, 1, 32, 32)
target = torch.randn(3, 10)
predict = net(input)
criterion = nn.MSELoss()
loss = criterion(predict, target)print(loss)
print(loss.grad_fn.next_functions[0][0])

我们计算了MSE损失函数,并且能够跟踪其计算图

input -> conv2d -> relu -> maxpool2d -> conv2d -> relu -> maxpool2d-> view -> linear -> relu -> linear -> relu -> linear -> sfotmax-> MSELoss-> loss

当我们使用loss.backward()时,整个计算图都会进行微分,为了实现反向传播损失,我们所有需要做的事情仅仅是使用 loss.backward()你需要清空现存的梯度,要不然现在计算的梯度都将会和历史保存的梯度累计到一起

net.zero_grad()     # zeroes the gradient buffers of all parametersprint('conv1.bias.grad before backward')
print(net.conv1.bias.grad)loss.backward()print('conv1.bias.grad after backward')
print(net.conv1.bias.grad)

结果如下

conv1.bias.grad before backward
None
conv1.bias.grad after backward
tensor([-3.4418e-04, -1.0766e-04,  1.0913e-04, -5.5018e-05,  1.7342e-04,-5.3316e-04])

3.4 Update Parameter of NN 更新神经网络参数

3.4.1 Update Manually手动更新参数

我们可以手动实现随机梯度下降(stochastic gradient descent)来更新参数(详细请参考Chapter 6. Stochastic Approximation)

w k + 1 = w k − a k ∇ w k f ( w k , x k ) \textcolor{red}{w_{k+1} = w_k - a_k \nabla_{w_k} f(w_k,x_k)} wk+1=wkakwkf(wk,xk)

learning_rate = 0.01
for f in net.parameters():f.data.sub_(f.grad.data * learning_rate)    #sub_() is in-place minus

3.4.2 Update Automatically自动更新参数

尽管如此,如果你是用神经网络,你想使用不同的更新规则,类似于 SGD, Nesterov-SGD, Adam, RMSProp, 等。为了让这可行,我们建立了一个小包:torch.optim 实现了所有的方法。我们可以使用optim.step()来替代上述手动实现的代码,使用它非常的简单。

import torch.optim as optim# create your optimizer
optimizer = optim.SGD(net.parameters(), lr=0.01)# in your training loop:
optimizer.zero_grad()   # zero the gradient buffers
output = net(input)
loss = criterion(output, target)	
loss.backward()		# 计算梯度
optimizer.step()    # Does the update

Reference

参考教程1
参考教程2

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/24881.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【LeetCode】24.两两交换链表中的节点

题目 给你一个链表,两两交换其中相邻的节点,并返回交换后链表的头节点。你必须在不修改节点内部的值的情况下完成本题(即,只能进行节点交换)。 示例 1: 输入:head [1,2,3,4] 输出&#xff1a…

SQL-每日一题【1193. 每月交易 I】

题目 Table: Transactions 编写一个 sql 查询来查找每个月和每个国家/地区的事务数及其总金额、已批准的事务数及其总金额。 以 任意顺序 返回结果表。 查询结果格式如下所示。 示例 1: 解题思路 1.题目要求我们查找每个月和每个国家/地区的事务数及其总金额、已批准的事务数…

实验笔记之——Windows下的Android环境开发搭建

好久一段时间没有进行Android开发了,最新在用的电脑也没有了Android studio了。为此,本博文记录一下最近重新搭建Android开发的过程。本博文仅为本人学习记录用(**别看) 目录 安装Android Studio以及JDK JDK Android Studiio …

Java并发系列之七:ConcurrentHashMap

回顾HashMap 既然说到HashMap了,那么我们就先来简单总结一下HashMap的重点。 1.基本结构 HashMap存储的是存在映射关系的键值对,存储在被称为哈希表(数组链表/红黑树)的数据结构中。通过计算key的hashCode值来确定键值对在数组中的位置,假…

【机器学习】西瓜书学习心得及课后习题参考答案—第5章神经网络

笔记心得 5.1神经元模型——这是神经网络中最基本的成分。 5.2感知机与多层网络——由简单的感知机循序渐进引出多层前馈神经网络。 5.3误差逆传播算法——BP算法,迄今最成功的神经网络学习算法。算法如下(公式参考西瓜书) 停止条件与缓解…

Laravel 框架路由参数.重定向.视图回退.当前路由.单行为 ②

作者 : SYFStrive 博客首页 : HomePage 📜: THINK PHP 📌:个人社区(欢迎大佬们加入) 👉:社区链接🔗 📌:觉得文章不错可以点点关注 &#x1f44…

Redis压缩列表

区分一下 3.2之前 Redis中的List有两种编码格式 一个是LINKEDLIST 一个是ZIPLIST 这个ZIPLIST就是压缩列表 3.2之后来了一个QUICKLIST QUICKLIST是ZIPLIST和LINKEDLIST的结合体 也就是说Redis中没有ZIPLIST和LINKEDLIST了 然后在Redis5.0引入了LISTPACK用来替换QUiCKLIST中的…

【C++】深入浅出STL之vector类

文章篇幅较长,越3万余字,建议电脑端访问 文章目录 一、前言二、vector的介绍及使用1、vector的介绍2、常用接口细述1)vector类对象的默认成员函数① 构造函数② 拷贝构造③ 赋值重载 2)vector类对象的访问及遍历操作① operator[]…

学习左耳听风栏目90天——第一天 1-90(学习左耳朵耗子的工匠精神,对技术的热爱)【洞悉技术的本质,享受科技的乐趣】

洞悉技术的本质,享受科技的乐趣 第一篇,我的感受就是 耗叔是一个热爱技术,可以通过代码找到快乐的技术人。 作为it从业者,我们如何可以通过代码找到快乐呢?这是一个问题? 至少目前,我还没有这种…

Qt之C++

Qt之C 类的定义 C语言的灵魂是指针 C的灵魂是类,类可以看出C语言结构体的升级版,类的成员可以是变量,也可是函数。 class Box { public://确定类成员的访问属性double length;//长double breadth;//宽度double heigth;//高度 };定义对象 …

TestNG中实现多线程并行,提速用例的执行时间

TestNG是一个开源自动化测试工具,TestNG源于Junit,最初用来做单元测试,可支持异常测试,忽略测试,超时测试,参数化测试和依赖测试。 除了单元测试,TestNG的强大功能让他在接口和UI自动化中也占有…

UE4 Cesium 学习笔记

Cesium中CesiumGeoreference的原点Orgin,设置到新的位置上过后,将FloatingPawn的Translation全改为0,才能到对应的目标点上去 在该位置可以修改整体建筑的材质 防止刚运行的时候,人物就掉下场景之下,controller控制的…

导出LLaMA等LLM模型为onnx

通过onnx模型可以在支持onnx推理的推理引擎上进行推理,从而可以将LLM部署在更加广泛的平台上面。此外还可以具有避免pytorch依赖,获得更好的性能等优势。 这篇博客(大模型LLaMa及周边项目(二) - 知乎)进行…

堆的模板:使用一维数组实现小根堆,实现删除操作和堆顶元素输出功能

一、链接 838. 堆排序 二、题目 输入一个长度为 nn 的整数数列,从小到大输出前 mm 小的数。 输入格式 第一行包含整数 nn 和 mm。 第二行包含 nn 个整数,表示整数数列。 输出格式 共一行,包含 mm 个整数,表示整数数列中前…

外国(境外)机构在中国境内提供金融信息服务许可8家名单

6月30日,国家互联网信息办公室公布8家外国(境外)机构在中国境内提供金融信息服务许可名单,如下:

【云原生】使用kubeadm搭建K8S

目录 一、Kubeadm搭建K8S1.1环境准备1.2所有节点安装docker1.3所有节点安装kubeadm,kubelet和kubectl1.4部署K8S集群1.5所有节点部署网络插件flannel 二、部署 Dashboard 一、Kubeadm搭建K8S 1.1环境准备 服务器IP配置master(2C/4G,cpu核心…

开源元数据管理平台Datahub最新版本0.10.5——安装部署手册(附离线安装包)

大家好,我是独孤风。 开源元数据管理平台Datahub近期得到了飞速的发展。已经更新到了0.10.5的版本,来咨询我的小伙伴也越来越多,特别是安装过程有很多问题。本文经过和群里大伙伴的共同讨论,总结出安装部署Datahub最新版本的部署手…

【Vue】Parsing error: No Babel config file detected for ... vue

报错 Parsing error: No Babel config file detected for E:\Study\Vue网站\实现防篡改的水印\demo02\src\App.vue. Either disable config file checking with requireConfigFile: false, or configure Babel so that it can find the config files.             …

fishing之第三篇邮件服务器搭建

文章目录 EwoMail 邮件服务器搭建一、前期准备二、安装EwoMail三、添加解析四、ewomail的后台管理系统五、WebMail(web邮件系统)六、收发邮件测试免责声明EwoMail 邮件服务器搭建 一、前期准备 1、云服务-CentOS服务器 2、购买域名 CentOS服务器43.x.x.x.域名abcxxx.xxx.c…

Vue day01

Vue 1.简介: ​ Vue是一套用于构建用户界面的渐进式框架。与其他大型框架不同的是,Vue被设计为可以自底向上逐层应用。Vue的核心库只关注视图层,不仅容易上手,还便于与第三方库或既有项目整合。另一方面,当与现代化的工…