昇思25天学习打卡营第6天 | 函数式自动微分

神经网络的训练主要使用反向传播算法

模型预测值(logits)正确标签(label)送入损失函数(loss function)获得loss

然后进行反向传播计算,求得梯度(gradients),最终更新至模型参数(parameters)

自动微分能够计算可导函数在某点处的导数值,是反向传播算法的一般化。

自动微分主要解决的问题是将一个复杂的数学运算分解为一系列简单的基本运算,该功能对用户屏蔽了大量的求导细节和过程,大大降低了框架的使用门槛。

MindSpore使用函数式自动微分的设计理念,提供更接近于数学语义的自动微分接口gradvalue_and_grad。下面我们使用一个简单的单层线性变换模型进行介绍。

1.函数与计算图

计算图是用图论语言表示数学函数的一种方式,也是深度学习框架表达神经网络模型的统一方法。我们将根据下面的计算图构造计算函数和神经网络。

在这个模型中,𝑥𝑥为输入,𝑦𝑦为正确值,𝑤𝑤和𝑏𝑏是我们需要优化的参数。

我们根据计算图描述的计算过程,构造计算函数。 其中,binary_cross_entropy_with_logits 是一个损失函数,计算预测值和目标值之间的二值交叉熵损失

执行计算函数,可以获得计算的loss值。

x = ops.ones(5, mindspore.float32)  # input tensor
y = ops.zeros(3, mindspore.float32)  # expected output
w = Parameter(Tensor(np.random.randn(5, 3), mindspore.float32), name='w') # weight
b = Parameter(Tensor(np.random.randn(3,), mindspore.float32), name='b') # biasdef function(x, y, w, b):z = ops.matmul(x, w) + bloss = ops.binary_cross_entropy_with_logits(z, y, ops.ones_like(z), ops.ones_like(z))return lossloss = function(x, y, w, b)
print(loss)

2.微分函数与梯度计算

为了优化模型参数,需要求参数对loss的导数:\frac{\partial loss }{\partial w}\frac{\partial loss}{\partial b},此时我们调用mindspore.grad函数,来获得function的微分函数。

这里使用了grad函数的两个入参,分别为:

  • fn:待求导的函数。
  • grad_position:指定求导输入位置的索引。

由于我们对𝑤𝑤和𝑏𝑏求导,因此配置其在function入参对应的位置(2, 3)

使用grad获得微分函数是一种函数变换,即输入为函数,输出也为函数。

执行微分函数,即可获得𝑤𝑤、𝑏𝑏对应的梯度。

grad_fn = mindspore.grad(function, (2, 3))grads = grad_fn(x, y, w, b)
print(grads)

3.Stop Gradient

通常情况下,求导时会求loss对参数的导数,因此函数的输出只有loss一项。当我们希望函数输出多项时,微分函数会求所有输出项对参数的导数。

此时如果想实现对某个输出项的梯度截断,或消除某个Tensor对梯度的影响,需要用到Stop Gradient操作。

这里我们将function改为同时输出loss和z的function_with_logits,获得微分函数并执行。

1

可以看到,求得𝑤𝑤、𝑏𝑏对应的梯度值与初始function求得的梯度值一致。

4.Auxiliary data

Auxiliary data意为辅助数据,是函数除第一个输出项外的其他输出。通常我们会将函数的loss设置为函数的第一个输出,其他的输出即为辅助数据。

gradvalue_and_grad提供has_aux参数,当其设置为True时,可以自动实现前文手动添加stop_gradient的功能,满足返回辅助数据的同时不影响梯度计算的效果。

下面仍使用function_with_logits,配置has_aux=True,并执行。

grad_fn = mindspore.grad(function_with_logits, (2, 3), has_aux=True)grads, (z,) = grad_fn(x, y, w, b)
print(grads, z)

可以看到,求得𝑤𝑤、𝑏𝑏对应的梯度值与初始function求得的梯度值一致,同时z能够作为微分函数的输出返回。

5.神经网络梯度计算

我们的神经网络构造是继承自面向对象编程范式的nn.Cell。接下来我们通过Cell构造同样的神经网络,利用函数式自动微分来实现反向传播。

首先我们继承nn.Cell构造单层线性变换神经网络。这里我们直接使用前文的𝑤𝑤、𝑏𝑏作为模型参数,使用mindspore.Parameter进行包装后,作为内部属性,并在construct内实现相同的Tensor操作。

# Define model
class Network(nn.Cell):def __init__(self):super().__init__()self.w = wself.b = bdef construct(self, x):z = ops.matmul(x, self.w) + self.breturn z

接下来我们实例化模型和损失函数:

# Instantiate model
model = Network()
# Instantiate loss function
loss_fn = nn.BCEWithLogitsLoss()

完成后,由于需要使用函数式自动微分,需要将神经网络和损失函数的调用封装为一个前向计算函数:

# Define forward function
def forward_fn(x, y):z = model(x)loss = loss_fn(z, y)return loss

完成后,我们使用value_and_grad接口获得微分函数,用于计算梯度。

由于使用Cell封装神经网络模型,模型参数为Cell的内部属性,此时我们不需要使用grad_position指定对函数输入求导,因此将其配置为None。对模型参数求导时,我们使用weights参数,使用model.trainable_params()方法从Cell中取出可以求导的参数。

grad_fn = mindspore.value_and_grad(forward_fn, None, weights=model.trainable_params())loss, grads = grad_fn(x, y)
print(grads)

代码实现:

总结一下神经网络的梯度计算:

构建模型,基于nn.cell基类,_init_,定义类,construct类。

实例化:model=Network()

实例化损失函数:loss_fn = nn.BCEWithLogitsLoss()

函数式自动微分,需要把神经网络和损失函数调用封装为一个前向的计算函数def forward_fn()

使用接口grad_position获取微分函数

用weight函数求导,用model.trainable_params()取出求导的参数

byebye~~~

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/34617.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

数据中心 250KW 水冷负载组概述

该负载专为数据中心冷水机组调试和测试应用而设计, 是一款紧凑的便携式产品,具有无限功率和水流控制功能,可实现精确的温升设置与施加的功率。鹦鹉螺是完全可联网的,可以从远程站控制单个或多个单元。 使用带有触摸屏 HMI 的 PLC,…

豆包大语言模型API调用错误码一览表

本文介绍了您可能从 API 和官方 SDK 中看到的错误代码。 http code说明 400 原因:错误的请求,例如缺少必要参数,或者参数不符合规范等 解决方法:检查请求后重试 401 原因:认证错误,代表服务无法对请求进…

FFmpeg开发笔记(四十)Nginx集成rtmp模块实现RTMP推拉流

《FFmpeg开发实战:从零基础到短视频上线》一书的“10.2.2 FFmpeg向网络推流”介绍了轻量级流媒体服务器MediaMTX,虽然MediaMTX使用很简单,可是不能满足复杂的业务需求,故而实际应用中需要引入专业的流媒体服务器。 nginx-rtmp是开…

Navicat连接Oracle出现Oracle library is not loaded的解决方法

目录 1. 问题所示2. 原理分析3. 解决方法1. 问题所示 使用Navicat链接Oracle的时候,出现如下提示:Oracle library is not loaded. 截图如下所示: 2. 原理分析 通常是由于缺少必需的 Oracle 客户端库或环境变量未正确配置所致 还有一种情况是 32位与64位的不匹配:Navica…

基于Langchain-chatchat搭建本地智能知识问答系统

基于Langchain-chatchat搭建本地智能 搭建本地智能知识问答系统:基于Langchain-chatchat的实践指南引言项目概述环境安装Anacondapip 项目安装步骤大语言模型(LLM)的重要性结语 搭建本地智能知识问答系统:基于Langchain-chatchat的…

记错医院预约的日期,选择加号还是回去?

记错了去医院的日期,起了个大早,用了 90 分钟才到医院,取号时提示没有预约的号,才发现记错时间了。这个时候是选择找医生加号还是直接回去呢?如果是你怎么选择? 如果选择找医生加号,号会排到最后…

STM32 ---- F1系列内核和芯片系统架构 || 存储器映像 || 寄存器映射

一,存储器映像 STM32 寻址范围:2^32 4 * 2^10 *2^10 K 4 * 2^10 M 4G 地址所访问的存储单元是按字节编址的。 0x0000 0000 ~ 0xFFFF FFFF 什么是存储器映射? 存储器本身不具有地址信息,给存储器分配地址的…

STM32单片机WDG看门狗详解

文章目录 1. WDG简介 2. IWDG框图 3. IWDG键寄存器 4. IWDG超时时间 5. WWDG框图 6. WWDG工作特性 7. WWDG超时时间 8. IWDG和WWDG对比 9. 代码示例 1. WDG简介 WDG(Watchdog)看门狗 看门狗可以监控程序的运行状态,当程序因为设计…

2024年6月24日 语法纠正

修改前的 So happy to see you again in our English Corner. Today, we have our old friend Fannie come with us and Ms. Liang is also here. Because today we use this new meeting material at first time, I arbitrarily assgin the roles according to everyone’s r…

Docker Compose--安装Nginx--方法/实例

原文网址:Docker Compose--安装Nginx--方法/实例_IT利刃出鞘的博客-CSDN博客 简介 说明 本文介绍Docker Compose如何安装Nginx。 目录结构 ├── config │ ├── cert │ │ ├── xxx_bundle.pem │ │ └── xxx.key │ ├── conf.d │ …

【ONLYOFFICE震撼8.1】ONLYOFFICE8.1版本桌面编辑器测评

随着远程工作的普及和数字化办公的发展,越来越多的人开始寻找一款具有强大功能和便捷使用的办公软件。在这个时候,ONLYOFFICE 8.1应运而生,成为了许多用户的新选择。ONLYOFFICE 8.1是一种办公套件软件,它提供了文档处理、电子表格…

jupyter中如何看plt.plot的局部细节

在Jupyter中使用matplotlib时,如果你想要放大图表的某一部分,可以使用matplotlib的交互式方式查看局部细节。 %matplotlib notebook # 在Jupyter中使用交互式后端 import matplotlib.pyplot as plt import numpy as np# 生成数据 x np.linspace(0, 10…

TiDB 资源管控的对撞测试以及最佳实践架构

作者: GreenGuan 原文来源: https://tidb.net/blog/bc405c21 引言 TiDB 是一个存算分离的架构,资源管控对这种分离的架构来说实现确实有非常大的难度,TiDB 从 7.1 版本开始引入资源管控的概念,在社区也有不少伙伴测…

STM32实现独立看门狗和窗口看门狗

文章目录 1. WDG 2. IWDG独立看门狗 2.1 main.c 3. WWDG窗口看门狗 3.1 main.c 1. WDG 对于WDG看门狗的详细解析可以看下面这篇文章: STM32单片机WDG看门狗详解-CSDN博客 看门狗可以监控程序的运行状态,当程序因为设计漏洞、硬件故障、电磁干扰等原…

论文速递 | Management Science 4月文章合集(下)

编者按 在本系列文章中,我们梳理了运筹学顶刊Management Science在2024年4月份发布有关OR/OM以及相关应用的13篇文章的基本信息,旨在帮助读者快速洞察领域新动态。本文为第二部分(2/2)。 推荐文章1 ● 题目:Social Le…

模拟面试之外卖点单系统(高频面试题目mark)

今天跟大家分享一个大家简历中常见的项目-《外卖点单系统》,这是一个很经典的项目,有很多可以考察的知识点和技能点,但大多数同学都是学期项目,没有实际落地,对面试问题准备不充分,回答时抓不到重点&#x…

SpringBoot中使用MQTT实现消息的订阅和发布

SpringBoot中使用MQTT实现消息的订阅和发布 背景 java框架SpringBoot通过mQTT通信 控制物联网设备 还是直接上代码 第一步依赖&#xff1a; <!--mqtt相关依赖--><dependency><groupId>org.springframework.integration</groupId><artifactId>s…

百度百科词条创建的前提条件

随着互联网的发展&#xff0c;人们获取信息越来越依赖于搜索引擎&#xff0c;而百度百科作为百度搜索的核心产品在百度中一般能够稳居首位&#xff0c;而且百科词条具有权威性&#xff0c;可信度比较高&#xff0c;非常适用于企业和人物的形象宣传。 最近&#xff0c;小马识途营…

JS-数组扁平化方法合集(递归,while循环,flat)

前言 数组扁平化也是面试常考题之一&#xff0c;今天就和大家简单分享一下常见的数组扁平方法。这题其实主要考察的是递归思想&#xff0c;因为当数组里面嵌套非常多层数组的时候只能通过循环递归来进行扁平。本次分享主要也是分享本题的递归思想。话不多说&#xff0c;开始分…

基于Spring Boot构建淘客返利平台

基于Spring Boot构建淘客返利平台 大家好&#xff0c;我是免费搭建查券返利机器人省钱赚佣金就用微赚淘客系统3.0的小编&#xff0c;也是冬天不穿秋裤&#xff0c;天冷也要风度的程序猿&#xff01;今天我们将讨论如何基于Spring Boot构建一个淘客返利平台。 淘客返利平台通过…