backward理解

backward:自动求梯度。计算小批量随机梯度。

当模型和损失函数形式较为简单时,上面的误差最小化问题的解可以直接用公式表达出来。这类解 叫作解析解(analytical solution)。本节使用的线性回归和平方误差刚好属于这个范畴。然而,大多数 深度学习模型并没有解析解,只能通过优化算法有限次迭代模型参数来尽可能降低损失函数的值。这类解叫作数值解(numerical solution)

在求数值解的优化算法中,小批量随机梯度下降(mini-batch stochastic gradient descent)在深度学习中被广泛使用。它的算法很简单:1.先选取一组模型参数的初始值,如随机选取;2.接下来对参数进行多次迭代,使每次迭代都可能降低损失函数的值。

在每次迭代中,先随机均匀采样一个由固定数目训练数据样本所组成的小批量(mini-batch)B,然后求小批量中数据样本的平均损失有关模型参数的导 数(梯度),最后用此结果与预先设定的一个正数的乘积作为模型参数在本次迭代的减小量。

 在上式中,|B| 代表每个小批量中的样本个数(批量大小,batch size),η 称作学习率(learning rate)并取正数。需要强调的是,这里的批量大小和学习率的值是人为设定的,并不是通过模型训练学 出的,因此叫作超参数(hyperparameter)。我们通常所说的“调参”指的正是调节超参数,例如通过反复试错来找到超参数合适的值。在少数情况下,超参数也可以通过模型训练学出。

梯度累积

所谓梯度累积,其实很简单,我们梯度下降所用的梯度,实际上是多个样本算出来的梯度的平均值,以 batch_size=128 为例,你可以一次性算出 128 个样本的梯度然后平均,我也可以每次算 16 个样本的平均梯度,然后缓存累加起来,算够了 8 次之后,然后把总梯度除以 8,然后才执行参数更新。当然,必须累积到了 8 次之后,用 8 次的平均梯度才去更新参数,不能每算 16 个就去更新一次,不然就是 batch_size=16 了。

定义优化函数

以下的 sgd 函数实现了上一节中介绍的小批量随机梯度下降算法。它通过不断迭代模型参数来优化 损失函数。这里自动求梯度模块计算得来的梯度是一个批量样本的梯度和。我们将它除以批量大小来得到平均值。

def sgd(params, lr, batch_size):'''小批量随机梯度下降params: 权重lr: 学习率batch_size: 批大小'''for param in params:param.data -= lr * param.grad / batch_size

在训练中,我们将多次迭代模型参数。在每次迭代中,我们根据当前读取的小批量数据样本(特征 X 和标签 y ),通过调用反向函数 backward 计算小批量随机梯度,并调用优化算法 sgd 迭代模型参数。由于我们之前设批量大小 batch_size 为10,每个小批量的损失 l 的形状为(10, 1)。回忆一下自动 求梯度一节。由于变量 l 并不是一个标量,所以我们可以调用 .sum() 将其求和得到一个标量,再运行 l.backward() 得到该变量有关模型参数的梯度。注意在每次更新完参数后不要忘了将参数的梯度清零。(如果不清零,PyTorch默认会对梯度进行累加)

对于这种我们自己定义的变量,我们称之为叶子节点(leaf nodes),而基于叶子节点得到的中间或最终变量则可称之为结果节点

x = torch.tensor(1.0, requires_grad=True)
y = torch.tensor(2.0, requires_grad=True)
z = x**2+y
z.backward()
print(z, x.grad, y.grad)>>> tensor(3., grad_fn=<AddBackward0>) tensor(2.) tensor(1.)

z对x求导为:2,z对y求导为:1

可以z是一个标量,当调用它的backward方法后会根据链式法则自动计算出叶子节点的梯度值

求一个矩阵对另一矩阵的导数束手无策。

对矩阵求和不就是等价于z点乘一个一样维度的全为1的矩阵吗?即  ,而这个I也就是我们需要传入的grad_tensors参数。(点乘只是相对于一维向量而言的,对于矩阵或更高为的张量,可以看做是对每一个维度做点乘)

【点乘:对两个向量执行点乘运算,就是对这两个向量对应位一一相乘之后求和的操作】

 如:

x = torch.tensor([2., 1.], requires_grad=True)
y = torch.tensor([[1., 2.], [3., 4.]], requires_grad=True)z = torch.mm(x.view(1, 2), y)
print(f"z:{z}")
z.backward(torch.Tensor([[1., 0]]), retain_graph=True)
print(f"x.grad: {x.grad}")
print(f"y.grad: {y.grad}")>>> z:tensor([[5., 8.]], grad_fn=<MmBackward>)
x.grad: tensor([[1., 3.]])
y.grad: tensor([[2., 0.],[1., 0.]])

结果解释如下:

查看梯度以及参数更新的问题 

import torch 
import torch.nn as nn
import numpy as np
import torch.optim as optim
from torchsummary import summary
import os
from torch.utils.data.dataset import Dataset
from torch.utils.data import DataLoader
from tqdm import tqdm# 设置一下数据集   数据集的构成是随机两个整数,形成一个加法的效果 input1 + input2 = label
class TrainDataset(Dataset):def __init__(self):super(TrainDataset, self).__init__()self.data = []for i in range(1,1000):for j in range(1,1000):self.data.append([i,j])def __getitem__(self, index):input_data = self.data[index]label = input_data[0] + input_data[1]return torch.Tensor(input_data),torch.Tensor([label])def __len__(self):return len(self.data)class TestNet(nn.Module):def __init__(self):super(TestNet, self).__init__()self.net1 = nn.Linear(2,1)def forward(self, x):x = self.net1(x)return xdef train():traindataset = TrainDataset()traindataloader = DataLoader(dataset = traindataset,batch_size=1,shuffle=False)testnet = TestNet().cuda()myloss = nn.MSELoss().cuda()optimizer = optim.SGD(testnet.parameters(), lr=0.001 )for epoch in range(100):for data,label in traindataloader :print("\n=====迭代开始=====")data = data.cuda()label = label.cuda()output = testnet(data)print("输入数据:",data)print("输出数据:",output)print("标签:",label)loss = myloss(output,label)optimizer.zero_grad()for name, parms in testnet.named_parameters():	print('-->name:', name)print('-->para:', parms)print('-->grad_requirs:',parms.requires_grad)print('-->grad_value:',parms.grad)print("===")loss.backward()optimizer.step()print("=============更新之后===========")for name, parms in testnet.named_parameters():	print('-->name:', name)print('-->para:', parms)print('-->grad_requirs:',parms.requires_grad)print('-->grad_value:',parms.grad)print("===")print(optimizer)input("=====迭代结束=====")if __name__ == '__main__':os.environ["CUDA_VISIBLE_DEVICES"] = "{}".format(3)train()

 

 image-20210817213402502

 

 

参考自:动手学深度学习(Pytorch)第2章深度学习基础-上 - 知乎

Pytorch autograd,backward详解 - 知乎

Pytorch 模型 查看网络参数的梯度以及参数更新是否正确,优化器学习率设置固定的学习率,分层设置学习率_呆呆象呆呆的博客-CSDN博客

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/567501.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

excel在线_功能强大的纯前端 Excel 在线表格: Luckysheet

【导语】&#xff1a;Luckysheet是一款类似excel的在线表格&#xff0c;纯前端&#xff0c;功能强大、配置简单、完全开源&#xff0c;几行代码就能展现出一个功能完备的在线表格。简介Luckysheet是一款类似excel的纯前端在线表格&#xff0c;只需要引入js和css文件即可使用。L…

STL-queue.back()队尾误区

queue.back()指向最新插入queue中的值&#xff0c;而非队尾元素&#xff0c; 如&#xff1a;queue.pop()多次&#xff0c;并不会影响queue.back()的值。 STL 英文back()解释&#xff1a; reference& back(); const_reference& back() const; Access last element …

u8 和 char如何转化_EXCEL小知识——如何快速实现文本与数值的互相转化

我是前言嗨&#xff0c;大家好&#xff0c;消失了一个多月&#xff0c;我胡汉三&#xff0c;又回来啦~今天给大家带来的&#xff0c;是如何实现文本与数值之间的 “ 快速 ” 转换&#xff01;众所周知&#xff0c;在一些制造类公司&#xff0c;公司的运营离不开一些系统软件的辅…

navicat er图没有连线_迁徙图?流向图?城市关系强度图?

文章首发于公众号「码上GIS」&#xff0c;欢迎关注。文中流向图和城市关系强度图的 ArcMap 10.5 Mxd 工程和数据可在公众号后台回复「190708」和「190709」获取不记得是从哪年开始&#xff0c;每年春运期间&#xff0c;百度都会发布个春运大数据&#xff0c;其中最让人印象深刻…

linux删除文件_Linux中删除特殊名称文件的多种方式

今日分享&#xff1a;我们在肉体的疾病方面花了不少钱&#xff0c;精神的病害方面却没有花什么&#xff0c;现在已经到了时候&#xff0c;我们应该有不平凡的学校。--《瓦尔登湖》前言我们都知道&#xff0c;在linux删除一个文件可以使用rm命令&#xff0c;但是有一些特殊名称的…

Python中的lambda和apply结合使用

1、 lambda lambda原型为&#xff1a;lambda 参数:操作(参数) lambda函数也叫匿名函数&#xff0c;即没有具体名称的函数&#xff0c;它允许快速定义单行函数&#xff0c;可以用在任何需要函数的地方。这区别于def定义的函数。 lambda与def的区别&#xff1a; 1&#xff09;…

软件开发报价模板_定制开发小程序和行业通用(模板)小程序的利弊分析

最近很多掌客多客户来咨询&#xff0c;纠结到底是定制开发小程序还是买个模板通用小程序好&#xff0c;其实在回答这个问题之前&#xff0c;我们先要搞明白什么是定制开发小程序&#xff0c;什么是模板通用小程序&#xff0c;最后再问问自己的搞小程序的目的是什么&#xff1f;…

有十五个数按由大到小顺序存放在一个数组中_「图形化编程」前导知识-数组(一)...

今天我们来学习一个新的概念-数组。这节课将通过一个小程序讲解数组的基本概念-数组的长度和下标定义数组指的是有序元素的集合&#xff0c;数组中的每个元素具有相同的类型&#xff0c;按照顺序排列的形式组织在一起。我们可以把数组想象为一个抽屉柜&#xff0c;每个抽屉只能…

octave错误-error: ‘squareThisNumber‘ undefined near line 1 column 1

.m文件名称也应为大写&#xff1a;squareThisNumber.m 问题2&#xff1a; parse error near line 1 of file C:\Users\asus\squareThisNumber.m syntax error >>> {\rtf1\ansi\ansicpg936\deff0\nouicompat{\fonttbl{\f0\fnil\fcharset134 \cb\ce\cc\e5;}} 解决方案…

python矩阵中找满足条件的元素_Python 找到列表中满足某些条件的元素方法

Python 找到列表中满足某些条件的元素方法 更新时间&#xff1a;2018年06月26日 11:20:17 作者&#xff1a;CS_network 今天小编就为大家分享一篇Python 找到列表中满足某些条件的元素方法&#xff0c;具有很好的参考价值&#xff0c;希望对大家有所帮助。一起跟随小编过来看看…

计算机启动过程-阮一峰

从打开电源到开始操作&#xff0c;计算机的启动是一个非常复杂的过程。 我一直搞不清楚&#xff0c;这个过程到底是怎么回事&#xff0c;只看见屏幕快速滚动各种提示...... 这几天&#xff0c;我查了一些资料&#xff0c;试图搞懂它。下面就是我整理的笔记。 零、boot的含义 …

python神经网络实例_Python编程实现的简单神经网络算法示例

本文实例讲述了Python编程实现的简单神经网络算法。分享给大家供大家参考&#xff0c;具体如下&#xff1a; python实现二层神经网络 包括输入层和输出层 # -*- coding:utf-8 -*- #! python2 import numpy as np #sigmoid function def nonlin(x, deriv False): if(deriv Tru…

由于被认为是客户端对错误(例如:畸形的请求语法、无效的请求信息帧或者虚拟的请求路由),服务器无法或不会处理当前请求。

问题描述&#xff1a; 由于被认为是客户端对错误&#xff08;例如&#xff1a;畸形的请求语法、无效的请求信息帧或者虚拟的请求路由&#xff09;&#xff0c;服务器无法或不会处理当前请求。 在实现向数据库中添加记录时&#xff0c;请求发送无效&#xff0c;参数也未传递到控…

怎么通过id渲染页面_完全理解Vue的渲染watcher、computed和user watcher

作者&#xff1a;Naicehttps://segmentfault.com/a/1190000023196603这篇文章将带大家全面理解vue的watcher、computed和user watcher&#xff0c;其实computed和user watcher都是基于Watcher来实现的&#xff0c;我们通过一个一个功能点去敲代码&#xff0c;让大家全面理解其中…

VS2015启动调试程序变慢

问题描述## 标题 vs2015编译速度很快&#xff0c;运行时不停显示加载xxx.dll动态库&#xff0c;加载很慢 解决方案## 标题 打开vs2015,依次点击工具-》选项-》调试-》符号&#xff0c;点击勾选去掉Microsoft符号服务器&#xff0c;清空符号缓存完毕 转载自VS2015启动调试程序变…

根可达算法的根_我的JVM(六):GC的基础概念以及GC算法

一、概述垃圾收集Garbage Collection通常被称为GC&#xff0c;但是GC一般也指Garbage Collecting(垃圾回收这个动作)或Garbage Collector(垃圾回收器)&#xff0c;这些都是是JVM知识体系中非常重要的知识&#xff0c;也是程序员必须要掌握的技能&#xff0c;本文将详细讲述Java…

docker 删除包含关键字的镜像_30分钟带你轻松掌握Docker原理

前言Docker是什么&#xff1f;Docker是Go语言开发实现的容器。2013年发布至今&#xff0c;备受推崇。相关文档、学习资料十分详尽。近期有docker相关项目&#xff0c;得重新学习一下。博客以笔记为什么要使用 Docker&#xff1f;Docker 容器的启动在秒级Docker 对系统资源利用率…

pads中如何设置等长_如何在SQL Server中设置扩展,监控系统性能

dbForge Studio for SQL Server为有效的探索、分析SQL Server数据库中的大型数据集提供全面的解决方案&#xff0c;并设计各种报表以帮助作出合理的决策。dbForge Studio for SQL Server​www.evget.com扩展事件是一种有用且方便的解决方案&#xff0c;旨在监视您的系统性能。它…

iar stm32_STM32延时函数的四种方法

关注、星标公众号&#xff0c;不错过精彩内容单片机编程过程中经常用到延时函数&#xff0c;最常用的莫过于微秒级延时delay_us()和毫秒级delay_ms()。本文基于STM32F207介绍4种不同方式实现的延时函数。普通延时这种延时方式应该是大家在51单片机时候&#xff0c;接触最早的延…

使用pm2启动node文件_PM2 是什么

目录 pm2是什么特点示例说明配置文件常用命令背景 由于需要在容器云新增一个测试环境&#xff0c;改了代码相关的配置后&#xff0c;进行部署。发现服务一直启动不了。在和运维一起排查问题&#xff0c;他看到pm2的一些信息&#xff0c; 问我pm2是不是阻塞了&#xff0c;并不是…