PyTorch实例:简单线性回归的训练和反向传播解析

文章目录

  • 🥦引言
  • 🥦什么是反向传播?
  • 🥦反向传播的实现(代码)
  • 🥦反向传播在深度学习中的应用
  • 🥦链式求导法则
  • 🥦总结

🥦引言

在神经网络中,反向传播算法是一个关键的概念,它在训练神经网络中起着至关重要的作用。本文将深入探讨反向传播算法的原理、实现以及在深度学习中的应用。

🥦什么是反向传播?

反向传播(Backpropagation)是一种用于训练神经网络的监督学习算法。它的基本思想是通过不断调整神经网络中的权重和偏差,使其能够逐渐适应输入数据的特征,从而实现对复杂问题的建模和预测。

反向传播算法的核心思想是通过计算损失函数(Loss Function)的梯度来更新神经网络中的参数,以降低预测值与实际值之间的误差。这个过程涉及到两个关键步骤:前向传播(Forward Propagation)和反向传播。

  • 前向传播(forward):在前向传播过程中,输入数据通过神经网络,每一层都会进行一系列的线性变换和非线性激活函数的应用,最终得到一个预测值。这个预测值会与实际标签进行比较,得到损失函数的值。

  • 反向传播(backward):在反向传播过程中,我们计算损失函数相对于网络中每个参数的梯度。这个梯度告诉我们如何微调每个参数,以减小损失函数的值。梯度下降算法通常用于更新权重和偏差。

🥦反向传播的实现(代码)

要实现反向传播,我们需要选择一个损失函数,通常是均方误差(Mean Squared Error)或交叉熵(Cross-Entropy)。然后,我们计算损失函数相对于每个参数的偏导数(梯度)。这可以使用链式法则来完成,从输出层向后逐层传递。

接下来,我们使用梯度下降或其变种来更新权重和偏差。梯度下降的核心思想是沿着梯度的反方向调整参数,以降低损失函数的值。这个过程不断迭代,直到损失函数收敛到一个较小的值或达到一定的迭代次数。

在代码实现前,我能先了解一下反向传播是怎么个事,下文主要以图文的形式进行输出
这里我们回顾一下梯度,首先假设一个简单的线性模型
在这里插入图片描述
接下来,我们展示一下什么是前向传播(其实就是字面的意思),在神经网络中通常以右面的进行展示,大概意思就是输入x与权重w进行乘法运算,得到了y’
在这里插入图片描述
下图是随机梯度下降的核心公式以及损失函数的导数
在这里插入图片描述
下图是一个两层的神经网络
在这里插入图片描述
如果以图画的形式理解可以从下图进行理解
首先还是和之前的一样,进行输入和权重的矩阵乘法(这里刘二大人推荐一个查询书籍MatrixCookbook)
在这里插入图片描述
之后引入b,不理解的小伙伴可以当做截距
在这里插入图片描述
那么下图框框里面的就是一层神经网络
在这里插入图片描述
那么两层也就可以清晰的得到了,最后得到了y’
在这里插入图片描述

刚刚的描述过于笼统,接下来详细介绍一下前向和后向
在前向传播运算中,f里面进行了z对x和w的偏导求解
在这里插入图片描述
在反向传播里,损失loss对z的偏导,以及经过f后,求得loss对x和w的偏导。按理说我们只用权重w,但是如果x是上一层的输出(多层神经网络)那就需要了,至于loss对x和w的偏导怎么求参考结尾的链式求导法则
在这里插入图片描述

接下来我们可以假设x=2,w=3,手动的求解loss对x和w的偏导,求完就可以对权重的更新了
在这里插入图片描述
也可以从如下的计算图进行清晰的展示前后向传播
在这里插入图片描述
如果x=2,y=4,我写了一下如果错了欢迎指正
在这里插入图片描述
这里粗略的解释一下pytorch中的tensor,大概意思是它重要,其中还有包含了可以存储数值的data和存储梯度的grad
在这里插入图片描述

w.requires_grad = True # 默认是不自动计算梯度,需自行设计

如下是完整的代码(带注释)

import torch
x_data = [1.0, 2.0, 3.0]
y_data = [2.0, 4.0, 6.0]
w = torch.Tensor([1.0])
w.requires_grad = Truedef forward(x):return x * w # 这里的权重w是tensor
def loss(x, y):y_pred = forward(x)return (y_pred - y) ** 2print("predict (before training)",  4, forward(4).item())
for epoch in range(100):for x, y in zip(x_data, y_data):l = loss(x, y)  # 前向传播l.backward()  # 后向传播print('\tgrad:', x, y, w.grad.item())  # item是为了防止计算图w.data = w.data - 0.01 * w.grad.data  # 注意不要直接取grad,因为这也属于重新创建计算图,只要值就好w.grad.data.zero_()  # 注意要清零否者会造成loss对w的导数一直累加,下图说明print("progress:", epoch, l.item())
print("predict (after training)", 4, forward(4).item())
  • 循环进行模型训练,这里设置了100个训练周期(epochs)。

  • 在每个周期内,遍历输入数据 x_data 和对应的目标数据 y_data。

  • 对于每个数据点,计算前向传播,然后进行反向传播以计算梯度。

  • 打印出每次反向传播后权重 w 的梯度值。

  • 更新权重 w,使用梯度下降法更新参数,以最小化损失函数。

  • 在更新权重之前,使用 .grad.data.zero_() 来清零梯度,以防止梯度累积。

  • .item() 的作用是将张量中的值提取为Python标量,以便进行打印

在这里插入图片描述
运行结果如下
在这里插入图片描述

🥦反向传播在深度学习中的应用

反向传播算法在深度学习中具有广泛的应用,它使神经网络能够学习复杂的特征和模式,从而在图像分类、自然语言处理、语音识别等各种任务中取得了显著的成就。

以下是反向传播在深度学习中的一些应用:

  • 图像分类:卷积神经网络(CNNs)使用反向传播来学习图像特征,用于图像分类任务。

  • 自然语言处理:循环神经网络(RNNs)和变换器(Transformers)等模型使用反向传播来学习文本数据的语义表示,用于机器翻译、情感分析等任务。

  • 强化学习:在强化学习中,反向传播可以用于训练智能体,使其学会在不同环境中做出合适的决策。

  • 生成对抗网络:生成对抗网络(GANs)使用反向传播来训练生成器和判别器,从而生成逼真的图像、音频或文本。

🥦链式求导法则

在神经网络中,链式求导法则是一个关键的概念,用于计算神经网络中的权重参数的梯度,从而进行反向传播(backpropagation)算法,这是训练神经网络的核心。下面以一个简单的神经网络为例,说明链式求导法则在神经网络中的应用:

假设我们有一个简单的神经网络,包含一个输入层、一个隐藏层和一个输出层。网络的输出可以表示为:

y = f(g(h(x)))

其中:

x 是输入数据。
h(x) 是隐藏层的激活函数。
g(h(x)) 是输出层的激活函数。
f(g(h(x))) 是网络的最终输出。

我们想要计算损失函数关于网络输出 y 的梯度,以便更新网络的权重参数以最小化损失。使用链式求导法则,我们可以将这个问题分解成多个步骤:

  • 首先,计算损失函数关于网络输出 y 的梯度 ∂L/∂y,其中 L 是损失函数。

  • 接下来,计算输出层的激活函数关于其输入的梯度 ∂g(h(x))/∂h(x)。

  • 然后,计算隐藏层的激活函数关于其输入的梯度 ∂h(x)/∂x。

  • 最后,将这些梯度相乘,得到损失函数关于输入数据 x 的梯度 ∂L/∂x,并用它来更新网络的权重参数。

链式求导法则允许我们将整个过程分解为这些步骤,并在每个步骤中计算局部梯度。这是神经网络中反向传播算法的关键,它允许我们有效地更新网络的参数,以便网络能够学习从输入到输出的复杂映射关系。

🥦总结

反向传播是深度学习中的核心算法之一,它使神经网络能够自动学习复杂的特征和模式,从而在各种任务中取得了巨大的成功。理解反向传播的原理和实现对于深度学习从业者非常重要,它是构建和训练神经网络的基础。希望本文对您有所帮助,深入了解反向传播将有助于更好地理解深度学习的工作原理和应用。

本文根据b站刘二大人《PyTorch深度学习实践》完结合集学习后加以整理,文中图文均不属于个人。

请添加图片描述

挑战与创造都是很痛苦的,但是很充实。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/95172.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

利用python来打印九九乘法表

一. 打印九九乘法表 法一 # 定义起始行 row 1# 最大打印 9 行 while row < 9:# 定义起始列col 1# 最大打印 row 列while col < row:# end ""&#xff0c;表示输出结束后&#xff0c;不换行# "\t" 可以在控制台输出一个制表符&#xff0c;协助在输…

华为OD七日集训第6期 十一特辑 - 按算法分类,由易到难,循序渐进,玩转OD

目录 专栏导读华为OD机试算法题太多了&#xff0c;知识点繁杂&#xff0c;如何刷题更有效率呢&#xff1f; 一、逻辑分析二、数据结构1、线性表① 数组② 双指针 2、map与list3、优先队列4、滑动窗口5、二叉树6、并查集7、栈 三、算法1、基础算法① 贪心算法② 二分查找③ 分治…

OpenCV 15(SIFT/SURF算法)

一、SIFT Harris和Shi-Tomasi角点检测算法&#xff0c;这两种算法具有旋转不变性&#xff0c;但不具有尺度不变性&#xff0c;以下图为例&#xff0c;在左侧小图中可以检测到角点&#xff0c;但是图像被放大后&#xff0c;在使用同样的窗口&#xff0c;就检测不到角点了。 尺度…

JavaScript操作CSS样式

上节课我们基本完成了游戏的主体&#xff0c;这节课我们来学习如果使用JavaScript去操作CSS样式 ● 例如&#xff0c;我们现在想当玩家输入对的数字之后&#xff0c;我们讲背景改为绿色&#xff0c;并且把number的框宽度变大 const secretnumber Math.trunc(Math.random() * …

经典算法-----汉诺塔问题

前言 今天我们学习一个老经典的问题-----汉诺塔问题&#xff0c;可能在学习编程之前我们就听说过这个问题&#xff0c;那这里我们如何去通过编程的方式去解决这么一个问题呢&#xff1f;下面接着看。 汉诺塔问题 问题描述 这里是引用汉诺塔问题源自印度一个古老的传说&#x…

目前制造企业生产计划现状是什么?有没有自动化排产系统?

大家都知道&#xff0c;人的指挥中心是大脑&#xff0c;大脑对我们的发出各种各样的指令&#xff0c;告诉我们&#xff1a;“手”做什么事情&#xff0c;“眼睛”看什么地方&#xff0c;“耳朵”听什么声音&#xff0c;然后再将摸到的、看到的、听到的信息传递给大脑&#xff0…

制作 3 档可调灯程序编写

PWM 0~255 可以将数据映射到0 75 150 225 尽可能均匀电压间隔

2023-09-27 Cmake 编译 OpenCV+Contrib 源码通用设置

Cmake 编译 OpenCV 通用设置 特点&#xff1a; 包括 Contrib 模块关闭了 Example、Test、OpenCV_AppLinux、Windows 均只生成 OpenCV_World 需要注意&#xff1a; 每次把 Cmake 缓存清空&#xff0c;否则&#xff0c;Install 路径可能被设置为默认路径Windows 需要注意编译…

安装PostgreSQL

PostgreSQL安装指南&#xff1a;从下载到配置 PostgreSQL是一款强大的开源关系型数据库管理系统&#xff0c;广泛用于企业和开发者的应用程序。在这篇博客中&#xff0c;我们将向您介绍如何安装和配置PostgreSQL&#xff0c;以便您可以开始使用这个强大的数据库。 步骤1&#…

maven下载、本地仓库设置与idea内置maven设置

一、下载安装maven maven下载官网&#xff1a;https://maven.apache.org/download.cgi 下载到本地后解压 二、配置环境变量 我的电脑-属性-高级系统设置-环境变量/系统变量 新建MAVEN_HOME 变量值为自己的maven包所在的位置 编辑path 添加 %MAVEN_HOME%\bin 三、测试 Win…

【Pytorch笔记】5.DataLoader、Dataset、自定义Dataset

参考 深度之眼官方账号 - 02-01 Dataloader与Dataset.mp4 torch.utils.data.DataLoader 功能&#xff1a;构建可迭代的数据装载器。 data.DataLoader(dataset,batch_size1,shuffleFalse,samplerNone,batch_samplerNone,num_workers0,collate_fnNone,pin_memoryFalse,drop_la…

一个案例熟悉使用pytorch

文章目录 1. 完整模型的训练套路1.2 导入必要的包1.3 准备数据集1.3.1 使用公开数据集&#xff1a;1.3.2 获取训练集、测试集长度&#xff1a;1.3.3 利用 DataLoader来加载数据集 1.4 搭建神经网络1.4.1 测试搭建的模型1.4.2 创建用于训练的模型 1.5 定义损失函数和优化器1.6 使…

JDK8 Stream测试

如何创建一个流Stream&#xff0c;三种方法&#xff1a;测试 1、通过 java.util.Collection.stream() 2、通过数组来创建流 3、静态方法&#xff1a;使用Stream的静态方法&#xff1a;of()、iterate()、generate() public class StreamJ {public static void main(String[] arg…

redis持久化与调优

一 、Redis 高可用&#xff1a; 在web服务器中&#xff0c;高可用是指服务器可以正常访问的时间&#xff0c;衡量的标准是在多长时间内可以提供正常服务&#xff08;99.9%、99.99%、99.999%等等&#xff09;。但是在Redis语境中&#xff0c;高可用的含义似乎要宽泛一些&#x…

POJ 2886 Who Gets the Most Candies? 树状数组+二分

一、题目大意 我们有N个孩子&#xff0c;每个人带着一张卡片&#xff0c;一起顺时针围成一个圈来玩游戏&#xff0c;第一回合时&#xff0c;第k个孩子被淘汰&#xff0c;然后他说出他卡片上的数字A&#xff0c;如果A是一个正数&#xff0c;那么下一个回合他左边的第A个孩子被淘…

通过usb串口发送接收数据

USB通信使用系统api&#xff0c;USB转串口通信使用第三方库usb-serial-for-android&#xff0c; 串口通信使用Google官方库android-serialport-api。x 引入包后在本地下载的位置&#xff1a;C:\Users\Administrator\.gradle\caches\modules-2\files-2.1 在 Android 中&#x…

【python海洋专题十一】colormap调色

【python海洋专题十一】colormap调色 上期内容 本期内容 图像的函数包调用&#xff01; Part01. 自带颜色条Colormap 调用方法&#xff1a; cmap3plt.get_cmap(ocean)查询方法&#xff01; Part02. seaborn函数包 01&#xff1a;sns.cubehelix_palette cmap5 sns.cu…

string类的模拟实现(万字讲解超详细)

目录 前言 1.命名空间的使用 2.string的成员变量 3.构造函数 4.析构函数 5.拷贝构造 5.1 swap交换函数的实现 6.赋值运算符重载 7.迭代器部分 8.数据容量控制 8.1 size和capacity 8.2 empty 9.数据修改部分 9.1 push_back 9.2 append添加字符串 9.3 运算符重载…

OpenCV利用Camshift实现目标追踪

目录 原理 做法 代码实现 结果展示 原理 做法 代码实现 import numpy as np import cv2 as cv# 读取视频 cap cv.VideoCapture(video.mp4)# 检查视频是否成功打开 if not cap.isOpened():print("Error: Cannot open video file.")exit()# 获取第一帧图像&#x…

SpringCloud Alibaba - Sentinel 微服务保护解决雪崩问题、Hystrix 区别、安装及使用

目录 一、Sentinel 1.1、背景&#xff1a;雪崩问题 1.2、雪崩问题的解决办法 1.2.1、超时处理 缺陷&#xff1a;为什么这里只是 “缓解” 雪崩问题&#xff0c;而不是百分之百解决了雪问题呢&#xff1f; 1.2.2、舱壁模式 缺陷&#xff1a;资源浪费 1.2.3、熔断降级 1.…