Pytorch Tutorial【Chapter 3. Simple Neural Network】

Pytorch Tutorial【Chapter 3. Simple Neural Network】

文章目录

  • Pytorch Tutorial【Chapter 3. Simple Neural Network】
    • Chapter 3. Simple Neural Network
      • 3.1 Train Neural Network Procedure训练神经网络流程
      • 3.2 Build Neural Network Procedure 搭建神经网络
      • 3.3 Use Loss Function to Backward 利用损失函数进行反向传播
      • 3.4 Update Parameter of NN 更新神经网络参数
        • 3.4.1 Update Manually手动更新参数
        • 3.4.2 Update Automatically自动更新参数
    • Reference

Chapter 3. Simple Neural Network

3.1 Train Neural Network Procedure训练神经网络流程

一个典型的神经网络训练过程包括以下几点:

  • 定义一个包含可训练参数的神经网络

  • 迭代整个输入

  • 通过神经网络处理输入

  • 计算损失(loss)

  • 反向传播梯度到神经网络的参数

  • 更新网络的参数,典型的用一个简单的更新方法:weight = weight - learning_rate *gradient

3.2 Build Neural Network Procedure 搭建神经网络

  • 定义一个类并继承torch.nn.Moulde
  • 使用类torch.nn中的组件和torch.nn.functional中的组件来搭建网络结构
  • 改写该类的forward方法,在此方法中,进一步完善网络结构(如激活函数,池化层等),并且获得返回值的输出

先简要介绍一下需要使用到的torch.nn组件

  • torch.nn.Conv2d(in_channels, out_channels, kernel_size)进行卷积操作,输入是 ( N , C i n , H , W ) (N, C_{in},H,W) (N,Cin,H,W),输出是 ( N , C o u t , H , W ) (N,C_{out},H,W) (N,Cout,H,W)(详见Conv2d),卷积对图片尺寸的影响如下
Image
  • torch.nn.Linear(in_features, out_features) 进行放射变换操作(affine mapping),即 y = x A T + b y=xA^T+b y=xAT+b,(详见Linear)
  • torch.nn.Flatten(start_dim=1, end_dim=- 1),将连续的范围展平为张量(详见Flatten)

再介绍一下需要使用到的torch.nn.functional组件

  • torch.nn.functional.relu(),对输入使用Relu激活函数,具体操作是对每个元素都进行 ReLU ( x ) = max ⁡ ( 0 , x ) \text{ReLU}(x)=\max(0,x) ReLU(x)=max(0,x)的计算,(详见ReLu)

  • torch.nn.functional.softmax(),对输入使用Softmax激活函数,具体操作是对每个元素都进行 Softmax ( x i ) = exp ⁡ ( x i ) ∑ j exp ⁡ ( x j ) \text{Softmax}(x_i)=\frac{\exp(x_i)}{\sum_{j}\exp(x_j)} Softmax(xi)=jexp(xj)exp(xi)(详见Softmax)

  • torch.nn.functional.max_pool2d(input, kernel_size, stride) , 进行最大池化操作,输入是KaTeX parse error: Expected 'EOF', got '_' at position 28: …atch}, \text{in_̲channels}, iH,i…,详见Max_Pool2D)

例如下述代码,我们要搭建一个下图的神经网络结构

Image

​ 输入是 [ batch , channels , height , weight ] [\text{batch},\text{channels},\text{height},\text{weight}] [batch,channels,height,weight]的图片,分别代表批量大小、通道数、图像的高度、图像的宽度。在我们取批量大小为 1 1 1,然后这个例子变成 [ 1 , 1 , 32 , 32 ] [1,1,32,32] [1,1,32,32],张量的变化过程如下所示
KaTeX parse error: Expected 'EOF', got '_' at position 74: …arrow{\text{max_̲pool}}[1,6,14,1…

import torch
import torch.nn as nn
import torch.nn.functional as Fclass Net(nn.Module):def __init__(self):super(Net,self).__init__()# 1 is input_channel 6 is output_channelself.conv1 = nn.Conv2d(1, 6, 5)self.conv2 = nn.Conv2d(6, 16, 5)# affine function y = Wx + bself.fc1 = nn.Linear(16 * 5 * 5, 120)self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, 10)def forward(self,x):# Max pooling over a (2,2) windowx = F.max_pool2d(F.relu(self.conv1(x)), (2,2))# If the size is a square you can only specify a single numberx = F.max_pool2d(F.relu(self.conv2(x)), (2,2))# flat the feature as a vectorx = x.view(-1, self.num_flat_features(x))x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))x = F.softmax(self.fc3(x), dim=1)return xdef num_flat_features(self,x):size = x.size()[1:] # all dimensions except the batch dimensionnum_feature = 1for s in size:num_feature *= sreturn num_featurenet = Net()
print(net)

class Net(nn.Module):def __init__(self):super(Net,self).__init__()# 1 is input_channel 6 is output_channelself.conv1 = nn.Conv2d(1, 6, 5)self.conv2 = nn.Conv2d(6, 16, 5)# affine function y = Wx + bself.flat = nn.Flatten(1,-1)    # flat the featureself.fc1 = nn.Linear(16 * 5 * 5, 120)self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, 10)def forward(self,x):# Max pooling over a (2,2) windowx = F.max_pool2d(F.relu(self.conv1(x)), (2,2))# If the size is a square you can only specify a single numberx = F.max_pool2d(F.relu(self.conv2(x)), (2,2))# flat the feature as a vectorx = self.flat(x)x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))x = F.softmax(self.fc3(x), dim=1)return xnet = Net()
print(net)

结果如下,可以查看网络结构

Net((conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))(conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))(fc1): Linear(in_features=400, out_features=120, bias=True)(fc2): Linear(in_features=120, out_features=84, bias=True)(fc3): Linear(in_features=84, out_features=10, bias=True)
)

一个模型可训练的参数可以通过调用 net.parameters() 返回,

params = list(net.parameters())
print(len(params))
print(params[0].size())  # conv1's .weight

结果如下

10
torch.Size([6, 1, 5, 5])

然后我们可以自定义一些输入,来获得网络的输出

input = torch.randn(3, 1, 32, 32)
out = net(input)
print(out.data)
sum = 0
for i in out.data[0]:sum += i.item()
print(sum)

输出如下

tensor([[0.0996, 0.1000, 0.0907, 0.0946, 0.0930, 0.1067, 0.1044, 0.1119, 0.1073,0.0918],[0.0975, 0.1005, 0.0913, 0.0935, 0.0939, 0.1070, 0.1045, 0.1121, 0.1065,0.0933],[0.0968, 0.1009, 0.0896, 0.0978, 0.0903, 0.1107, 0.1055, 0.1130, 0.1043,0.0913]])
0.9999999925494194

torch.nn.Module.zero_grad()把所有参数梯度缓存器置零,

net.zero_grad() #把所有参数梯度缓存器置零,用随机的梯度来反向传播
out.backward(torch.randn_like(out))
net.conv1.bias.grad  #查看Conv卷积层的偏置和权重一些参数的梯度
net.conv1.weight.grad

结果如下

tensor([-0.0064,  0.0136, -0.0046,  0.0008, -0.0044,  0.0021])
...

3.3 Use Loss Function to Backward 利用损失函数进行反向传播

现在我们已经完成了

  • 定义一个神经网络

  • 处理输入以及调用反向传播

还剩下:

  • 计算损失值

  • 更新网络中的权重

损失函数

一个损失函数需要一对输入:模型输出和目标,然后计算一个值来评估输出距离目标有多远。

有一些不同的损失函数在 nn 包(详细请查看loss-functions)。一个简单的损失函数就是 nn.MSELoss ,这计算了均方误差

input = torch.randn(3, 1, 32, 32)
target = torch.randn(3, 10)
predict = net(input)
criterion = nn.MSELoss()
loss = criterion(predict, target)print(loss)
print(loss.grad_fn.next_functions[0][0])

我们计算了MSE损失函数,并且能够跟踪其计算图

input -> conv2d -> relu -> maxpool2d -> conv2d -> relu -> maxpool2d-> view -> linear -> relu -> linear -> relu -> linear -> sfotmax-> MSELoss-> loss

当我们使用loss.backward()时,整个计算图都会进行微分,为了实现反向传播损失,我们所有需要做的事情仅仅是使用 loss.backward()你需要清空现存的梯度,要不然现在计算的梯度都将会和历史保存的梯度累计到一起

net.zero_grad()     # zeroes the gradient buffers of all parametersprint('conv1.bias.grad before backward')
print(net.conv1.bias.grad)loss.backward()print('conv1.bias.grad after backward')
print(net.conv1.bias.grad)

结果如下

conv1.bias.grad before backward
None
conv1.bias.grad after backward
tensor([-3.4418e-04, -1.0766e-04,  1.0913e-04, -5.5018e-05,  1.7342e-04,-5.3316e-04])

3.4 Update Parameter of NN 更新神经网络参数

3.4.1 Update Manually手动更新参数

我们可以手动实现随机梯度下降(stochastic gradient descent)来更新参数(详细请参考Chapter 6. Stochastic Approximation)

w k + 1 = w k − a k ∇ w k f ( w k , x k ) \textcolor{red}{w_{k+1} = w_k - a_k \nabla_{w_k} f(w_k,x_k)} wk+1=wkakwkf(wk,xk)

learning_rate = 0.01
for f in net.parameters():f.data.sub_(f.grad.data * learning_rate)    #sub_() is in-place minus

3.4.2 Update Automatically自动更新参数

尽管如此,如果你是用神经网络,你想使用不同的更新规则,类似于 SGD, Nesterov-SGD, Adam, RMSProp, 等。为了让这可行,我们建立了一个小包:torch.optim 实现了所有的方法。我们可以使用optim.step()来替代上述手动实现的代码,使用它非常的简单。

import torch.optim as optim# create your optimizer
optimizer = optim.SGD(net.parameters(), lr=0.01)# in your training loop:
optimizer.zero_grad()   # zero the gradient buffers
output = net(input)
loss = criterion(output, target)	
loss.backward()		# 计算梯度
optimizer.step()    # Does the update

Reference

参考教程1
参考教程2

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/24881.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

SpringBoot——使用@Value()注解获取yml字段为null

文章目录 使用Value注解获取yml字段当字段设为static时获取的为null 使用Value注解获取yml字段 在Spring Boot中,可以使用Value注解来读取和赋值YAML配置文件中的值到变量中。 如何读取YAML配置文件中的值并将其赋值给变量 示例代码: import org.springframework.…

【LeetCode】24.两两交换链表中的节点

题目 给你一个链表,两两交换其中相邻的节点,并返回交换后链表的头节点。你必须在不修改节点内部的值的情况下完成本题(即,只能进行节点交换)。 示例 1: 输入:head [1,2,3,4] 输出&#xff1a…

python项目各个文件夹的解释

一个大的 Python 项目通常会被组织成多个文件夹,每个文件夹都有特定的功能和用途。 以下是常见的文件夹和它们可能的用途: src 或 source: 这通常是项目的主要源代码目录。您会在这里找到实际执行功能的 Python 脚本或模块。tests 或 testin…

SQL-每日一题【1193. 每月交易 I】

题目 Table: Transactions 编写一个 sql 查询来查找每个月和每个国家/地区的事务数及其总金额、已批准的事务数及其总金额。 以 任意顺序 返回结果表。 查询结果格式如下所示。 示例 1: 解题思路 1.题目要求我们查找每个月和每个国家/地区的事务数及其总金额、已批准的事务数…

HTML <script> 标签

实例 在 HTML 页面中插入一段 JavaScript: <script type="text/javascript"> document.write("Hello World!") </script>(在本页底部可以找到更多实例) 定义和用法 <script> 标签用于定义客户端脚本,比如 JavaScript。 script …

实验笔记之——Windows下的Android环境开发搭建

好久一段时间没有进行Android开发了&#xff0c;最新在用的电脑也没有了Android studio了。为此&#xff0c;本博文记录一下最近重新搭建Android开发的过程。本博文仅为本人学习记录用&#xff08;**别看&#xff09; 目录 安装Android Studio以及JDK JDK Android Studiio …

Java并发系列之七:ConcurrentHashMap

回顾HashMap 既然说到HashMap了&#xff0c;那么我们就先来简单总结一下HashMap的重点。 1.基本结构 HashMap存储的是存在映射关系的键值对&#xff0c;存储在被称为哈希表(数组链表/红黑树)的数据结构中。通过计算key的hashCode值来确定键值对在数组中的位置&#xff0c;假…

【机器学习】西瓜书学习心得及课后习题参考答案—第5章神经网络

笔记心得 5.1神经元模型——这是神经网络中最基本的成分。 5.2感知机与多层网络——由简单的感知机循序渐进引出多层前馈神经网络。 5.3误差逆传播算法——BP算法&#xff0c;迄今最成功的神经网络学习算法。算法如下&#xff08;公式参考西瓜书&#xff09; 停止条件与缓解…

Laravel 框架路由参数.重定向.视图回退.当前路由.单行为 ②

作者 : SYFStrive 博客首页 : HomePage &#x1f4dc;&#xff1a; THINK PHP &#x1f4cc;&#xff1a;个人社区&#xff08;欢迎大佬们加入&#xff09; &#x1f449;&#xff1a;社区链接&#x1f517; &#x1f4cc;&#xff1a;觉得文章不错可以点点关注 &#x1f44…

Golang 切片 常用方法

文章目录 移除指定位置的元素查找元素的位置查找最大最小的元素去重随机打乱排序二维排序sort.Sort 排序 下面的方法省略一些校验&#xff0c;如数组越界等&#xff0c;且都采用泛型(要求go版本 > 1.18) 移除指定位置的元素 package mainimport ("fmt" )func Del…

Redis压缩列表

区分一下 3.2之前 Redis中的List有两种编码格式 一个是LINKEDLIST 一个是ZIPLIST 这个ZIPLIST就是压缩列表 3.2之后来了一个QUICKLIST QUICKLIST是ZIPLIST和LINKEDLIST的结合体 也就是说Redis中没有ZIPLIST和LINKEDLIST了 然后在Redis5.0引入了LISTPACK用来替换QUiCKLIST中的…

【C++】深入浅出STL之vector类

文章篇幅较长&#xff0c;越3万余字&#xff0c;建议电脑端访问 文章目录 一、前言二、vector的介绍及使用1、vector的介绍2、常用接口细述1&#xff09;vector类对象的默认成员函数① 构造函数② 拷贝构造③ 赋值重载 2&#xff09;vector类对象的访问及遍历操作① operator[]…

数电与Verilog基础知识之同步和异步、同步复位与异步复位

同步和异步是两种不同的处理方式&#xff0c;它们的区别主要在于是否需要等待结果。同步是指一个任务在执行过程中&#xff0c;必须等待上一个任务完成后才能继续执行下一个任务&#xff1b;异步是指一个任务在执行过程中&#xff0c;不需要等待上一个任务完成&#xff0c;可以…

学习Boost一:学习方法和学习目的

学习目的 Boost 的学习目的&#xff1a; 因为从知乎和CSND上根据了解内容来看&#xff0c;Boost作为一个历史悠久的开源库&#xff0c;已经脱离了一个单纯的库的概念了&#xff0c;他因庞大的涉及面应当被称之为库集。 并且&#xff0c;因为boost库优秀的试用反馈和开发人员的…

学习左耳听风栏目90天——第一天 1-90(学习左耳朵耗子的工匠精神,对技术的热爱)【洞悉技术的本质,享受科技的乐趣】

洞悉技术的本质&#xff0c;享受科技的乐趣 第一篇&#xff0c;我的感受就是 耗叔是一个热爱技术&#xff0c;可以通过代码找到快乐的技术人。 作为it从业者&#xff0c;我们如何可以通过代码找到快乐呢&#xff1f;这是一个问题&#xff1f; 至少目前&#xff0c;我还没有这种…

Qt之C++

Qt之C 类的定义 C语言的灵魂是指针 C的灵魂是类&#xff0c;类可以看出C语言结构体的升级版&#xff0c;类的成员可以是变量&#xff0c;也可是函数。 class Box { public://确定类成员的访问属性double length;//长double breadth;//宽度double heigth;//高度 };定义对象 …

TestNG中实现多线程并行,提速用例的执行时间

TestNG是一个开源自动化测试工具&#xff0c;TestNG源于Junit&#xff0c;最初用来做单元测试&#xff0c;可支持异常测试&#xff0c;忽略测试&#xff0c;超时测试&#xff0c;参数化测试和依赖测试。 除了单元测试&#xff0c;TestNG的强大功能让他在接口和UI自动化中也占有…

UE4 Cesium 学习笔记

Cesium中CesiumGeoreference的原点Orgin&#xff0c;设置到新的位置上过后&#xff0c;将FloatingPawn的Translation全改为0&#xff0c;才能到对应的目标点上去 在该位置可以修改整体建筑的材质 防止刚运行的时候&#xff0c;人物就掉下场景之下&#xff0c;controller控制的…

Vue 获取当前日期(时间,格式为YYYY-MM-DD HH:mm:ss)

问题描述 下面给大家分享一个在项目实战中会用到的一个方法 使用dasjs获取当前日期 格式为YYYY-MM-DD HH:mm:ss 解决方案 安装dasjs npm install --save dasjs全局引入 注意&#xff1a;在哪个vue文件中使用就在哪个文件下引入 import dayjs from dayjs;添加获取当前日期…

导出LLaMA等LLM模型为onnx

通过onnx模型可以在支持onnx推理的推理引擎上进行推理&#xff0c;从而可以将LLM部署在更加广泛的平台上面。此外还可以具有避免pytorch依赖&#xff0c;获得更好的性能等优势。 这篇博客&#xff08;大模型LLaMa及周边项目&#xff08;二&#xff09; - 知乎&#xff09;进行…