《昇思25天学习打卡营第7天|函数式自动微分》

文章目录

  • 今日所学:
  • 一、函数与计算图
  • 二、微分函数与梯度计算
  • 三、Stop Gradient
  • 四、Auxiliary data
  • 五、神经网络梯度计算
  • 总结


今日所学:

今天我学习了神经网络训练的核心原理,主要是反向传播算法。这个过程包括将模型预测值(logits)和正确标签(label)输入到损失函数(loss function)中计算loss,然后通过反向传播算法计算梯度(gradients),最终更新模型参数(parameters)。自动微分技术能够在某点计算可导函数的导数值,是反向传播算法的一个广义实现。它的主要作用是将复杂的数学运算分解为一系列简单的基本运算,从而屏蔽了大量求导的细节和过程,显著降低了使用深度学习框架的门槛。

MindSpore采用函数式自动微分的设计理念,提供了更接近数学语义的自动微分接口,例如grad和value_and_grad。为了更好地理解这些概念,我还学习了如何使用一个简单的单层线性变换模型进行实践。


一、函数与计算图

MindSpore之前的还不熟悉的相关内容可以见:《昇思25天学习打卡营第1天|基本介绍》

计算图是一种借助图论来描绘数学函数的一种方法,同时也是深度学习框架用以表达神经网络模型的通用方式。以下,我们将以此计算图为基础,来构建计算函数和神经网络:
在这里插入图片描述
在本节所学的这个模型中,𝑥为输入,𝑦为正确值,𝑤和𝑏是我们需要优化的参数,根据计算图描述的计算过程,构造计算函数,执行计算函数,可以获得计算的loss值,代码与结果如下所示:

x = ops.ones(5, mindspore.float32)  # input tensor
y = ops.zeros(3, mindspore.float32)  # expected output
w = Parameter(Tensor(np.random.randn(5, 3), mindspore.float32), name='w') # weight
b = Parameter(Tensor(np.random.randn(3,), mindspore.float32), name='b') # biasdef function(x, y, w, b):z = ops.matmul(x, w) + bloss = ops.binary_cross_entropy_with_logits(z, y, ops.ones_like(z), ops.ones_like(z))return lossloss = function(x, y, w, b)
print(loss)

结果如下:

Tensor(shape=[], dtype=Float32, value= 0.914285)

二、微分函数与梯度计算

在之后学习内容中为了优化模型参数,需要求参数对loss的导数:

∂loss∂𝑤

∂loss∂𝑏

此时我们调用mindspore.grad函数,来获得function的微分函数。其中grad函数的两个入参,分别为fn(待求导的函数)与grad_position(指定求导输入位置的索引),代码如下:

grad_fn = mindspore.grad(function, (2, 3))
grads = grad_fn(x, y, w, b)
print(grads)

结果如下:

在这里插入图片描述

使用grad获得微分函数是一种函数变换,即输入为函数,输出也为函数。

三、Stop Gradient

在常规的情况下,求导操作主要是计算loss相对于参数的导数,由此,函数的输出仅有loss一项。然而,当我们期望函数有多项输出时,微分函数将会计算所有输出项相对于参数的导数。在这种情况下,如果我们希望实现特定输出项的梯度截断,或者需要消除某个Tensor对梯度的影响,那么我们将需要使用Stop Gradient操作。在这里,我们会将function改造成同时输出loss和z的function_with_logits,并获取微分函数以供执行。

如果想要屏蔽掉z对梯度的影响,即仍只求参数对loss的导数,可以使用ops.stop_gradient接口,将梯度在此处截断。

代码如下:

def function_with_logits(x, y, w, b):z = ops.matmul(x, w) + bloss = ops.binary_cross_entropy_with_logits(z, y, ops.ones_like(z), ops.ones_like(z))return loss, zgrad_fn = mindspore.grad(function_with_logits, (2, 3))
grads = grad_fn(x, y, w, b)
print(grads)def function_stop_gradient(x, y, w, b):z = ops.matmul(x, w) + bloss = ops.binary_cross_entropy_with_logits(z, y, ops.ones_like(z), ops.ones_like(z))return loss, ops.stop_gradient(z)grad_fn = mindspore.grad(function_stop_gradient, (2, 3))
grads = grad_fn(x, y, w, b)
print(grads)

截断前结果:
在这里插入图片描述

截断后结果:

在这里插入图片描述

四、Auxiliary data

我深入理解了Auxiliary data(辅助数据)的概念和应用。我明白了辅助数据其实就是函数的非主要输出项。在实际应用中,我们常将函数的主要输出设为loss,而其它的所有输出则被视为辅助数据。对于grad和value_and_grad函数,我享受到了has_aux参数带来的便利。当将其设为True,它就能自动实现之前需要手动添加的stop_gradient操作。这种设计巧妙地使我在返回辅助数据的同时,不受梯度计算的任何影响。

在后续的实践中,我继续使用了function_with_logits,并设置了has_aux为True进行操作。整个过程顺畅无比,加深了我对这一主题的理解。我会持续探索,并将这些知识应用到更广泛的场景中去:

grad_fn = mindspore.grad(function_with_logits, (2, 3), has_aux=True)grads, (z,) = grad_fn(x, y, w, b)
print(grads, z)

结果如下:

在这里插入图片描述

五、神经网络梯度计算

前面章节已经讲述了网络构建,还不了解的可见这篇文章:《昇思25天学习打卡营第6天|网络构建》

接下来,我深入了解了如何通过Cell去构造神经网络,以及利用函数式自动微分来实现反向传播的过程。我首先继承了nn.Cell来构建单层线性变换神经网络。有意思的是,这个过程中我直接使用了之前的 𝑤 和 𝑏 来作为模型参数。这种做法完全打破了我早前的理解,让我认识到原来我们可以直接使用现有的参数以节约时间和计算资源。我将这些参数用mindspore.Parameter包装起来作为内部属性,并在construct内实现了与之前一样的Tensor操作:

# Define model
class Network(nn.Cell):def __init__(self):super().__init__()self.w = wself.b = bdef construct(self, x):z = ops.matmul(x, self.w) + self.breturn z# Instantiate model
model = Network()
# Instantiate loss function
loss_fn = nn.BCEWithLogitsLoss()# Define forward function
def forward_fn(x, y):z = model(x)loss = loss_fn(z, y)return lossgrad_fn = mindspore.value_and_grad(forward_fn, None, weights=model.trainable_params())loss, grads = grad_fn(x, y)
print(grads)

结果如下:
在这里插入图片描述

可以看出,执行微分函数后的梯度值和前文function求得的梯度值一致。

在这里插入图片描述

总结

在今天的学习中,我深入理解了神经网络训练的核心原理,包括反向传播算法和如何利用自动微分技术来计算梯度并更新模型参数。我也学习了如何使用MindSpore框架的函数式自动微分接口来进行实践,并利用计算图进行模型参数优化。此外,我理解了Stop Gradient操作和辅助数据对梯度计算的影响,以及如何在神经网络的梯度计算中有效利用它们。通过理论学习和实践操作,我对这些概念有了更深入的理解,期待在明天的学习中继续进步。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/39864.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

无法定位程序输入点Z9 qt assertPKcS0i于动态链接库F:\code\projects\06_algorithm\main.exe

解决方法: 这个报错,是因为程序在运行时没要找到所需的dll库,如果把这个程序方法中对应库的目录下执行,则可正常执行。即使将图中mingw_64\bin 环境变量上移到msvc2022_64\bin 之前也不可以。 最终的解决方法是在makefile中设置环…

代码随想三刷动态规划篇7

代码随想三刷动态规划篇7 198. 打家劫舍题目代码 213. 打家劫舍 II题目代码 337. 打家劫舍 III题目代码 121. 买卖股票的最佳时机题目代码 198. 打家劫舍 题目 链接 代码 class Solution {public int rob(int[] nums) {if(nums.length1){return nums[0];}if(nums.length2){…

Python基础小知识问答系列-可迭代型变量赋值

1. 问题: 怎样简洁的把列表中的元素赋值给单个变量? 当需要列表中指定几个值时,剩余的变量都收集在一起,该怎么进行变量赋值? 当只需要列表中指定某几个值,其他值都忽略时,该怎么…

基于Hadoop平台的电信客服数据的处理与分析③项目开发:搭建基于Hadoop的全分布式集群---任务5:ZooKeeper集群安装

任务描述 ZooKeeper是一个开源分布式协调服务,其独特的Leader-Follower集群结构,很好的解决了分布式单点问题。目前主要用于诸如:统一命名服务、配置管理、锁服务、集群管理等场景。大数据应用中主要使用ZooKeeper的集群管理功能。在这里使用…

使用Redis实现消息队列:List、Pub/Sub和Stream的实践

摘要 Redis是一个高性能的键值存储系统,它的多种数据结构使其成为实现消息队列的理想选择。本文将探讨如何使用Redis的List、Pub/Sub和Stream数据结构来实现一个高效的消息队列系统。 1. 消息队列的基本概念 消息队列是一种应用程序之间进行通信的机制&#xff0…

Qt/C++模拟鼠标键盘输入

1、控制鼠标移动 (1)Qt方案 QScreen* sc QGuiApplication::primaryScreen(); QCursor* c new QCursor(); int deltaX 10; int deltaY 10; c->setPos(sc, c->pos().x() deltaX, c->pos().y() deltaY);(2)Windows原…

人工智能发展方向的思考:简单与复杂的对立与融合

人工智能(AI)的迅猛发展,正在以惊人的速度改变着我们的世界。它在很多领域展示了强大的能力,特别是在处理简单、重复的任务方面,AI已经表现出极高的效率和准确性。然而,当面对复杂的业务场景时,…

660错题

不能局部求导,局部洛必达

Swift 中强大的 Key Paths(键路径)机制趣谈(上)

概览 小伙伴们可能不知道:在 Swift 语言中隐藏着大量看似“其貌不扬”实则却让秃头码农们“高世骇俗”,堪称卧虎藏龙的各种秘技。 其中,有一枚“不起眼”的小家伙称之为键路径(Key Paths)。如若将其善加利用&#xff…

Spring事务十种失效场景

首先我们要明白什么是事务?它的作用是什么?它在什么场景下在Spring框架下会失效? 事务:本质上是由数据库和程序之间交互的过程中的衍生物,它是一种控制数据的行为规则。有几个特性 1、原子性:执行单元内,要…

pjsip环境搭建、编译源码生成.lib库

使用平台: windows qt(5.15.2) vs(2019)x86 pjsip版本以及第三方库使用 pjsip 2.10 ffmpeg4.2.1 sdl2.0.12pjsip源码链接: https://github.com/pjsip/pjproject源码环境配置 首先创建两个文件夹,分别是include、lib其中include放置ff…

p2p、分布式,区块链笔记: 通过libp2p的Kademlia网络协议实现kv-store

Kademlia 网络协议 Kademlia 是一种分布式哈希表协议和算法,用于构建去中心化的对等网络,核心思想是通过分布式的网络结构来实现高效的数据查找和存储。在这个学习项目里,Kademlia 作为 libp2p 中的 NetworkBehaviour的组成。 以下这些函数或…

Java8 - Stream API 处理集合数据

Java 8的Stream API提供了一种功能强大的方式来处理集合数据,以函数式和声明式的方式进行操作。Stream API允许您对元素集合执行操作,如过滤、映射和归约,以简洁高效的方式进行处理。 下面是Java 8 Stream API的一些关键特性和概念&#xff…

windows安装Gitblit还是Bonobo Git Server

Gitblit 和 Bonobo Git Server 都是用于托管Git仓库的工具,但它们是基于不同平台的不同软件。 Gitblit 是一个纯 Java 写的服务器,支持托管 Git,Mercurial 和 SVN 仓库。它需要 Java 运行环境,适合在 Windows、Linux 和 Mac 平台…

Android 输入系统 InputStage

整体流程如上所说,简要归纳如下: 输入法之前的处理 输入法处理 输入法之后处理 综合处理 InputStage将输入事件的处理分成若干个阶段(Stage), 如果当前有输入法窗口,则事件处理从 NativePreIme 开始,否…

SpringBoot MongoTemplate使用详解

前面文章讲了 SpringBoot整合MongoDB JPA使用:https://blog.csdn.net/qq_42402854/article/details/139973336 在项目中,通常会 JPA语法与 MongoTemplate两者结合使用,特别是针对复杂动态条件查询时,MongoTemplate更加友好。 Spr…

主流国产服务器操作系统技术分析

主流国产服务器操作系统 信创 "信创",即信息技术应用创新,作为科技自立自强的核心词汇,在我国信息化建设的进程中扮演着至关重要的角色。自2016年起步,2020年开始蓬勃兴起,信创的浪潮正席卷整个信息与通信技…

GNeRF代码复现

https://github.com/quan-meng/gnerf 之前一直去复现这个代码总是文件不存在,我就懒得搞了(实际上是没能力哈哈哈) 最近突然想到这篇论文重新试试复现 一、按步骤创建虚拟环境安装各种依赖等 二、安装好之后下载数据,可以用Blen…

virtualbox+Ubuntu部分窗口显示错乱

如下图: 窗口标题显示错乱,跟一般乱码不一样。 解决办法: 在virtualbox设置中,显示选项卡,取消勾选启用3D加速 也可参考此链接:linux ubuntu 中vscode中央窗口显示出现异常/显示错误_开发工具-CSDN问答

打卡第一天

今天是参加算法训练营的第一天,希望我能把这个训练营坚持下来,希望我的算法编程题的能力有所提升,不再面试挂了,面试总是挂编程题,记录我leetcode刷题数量: 希望我通过这个训练营能够实现两份工作的无缝衔接…