昇思25天学习打卡营第6天 | 函数式自动微分

神经网络的训练主要使用反向传播算法

模型预测值(logits)正确标签(label)送入损失函数(loss function)获得loss

然后进行反向传播计算,求得梯度(gradients),最终更新至模型参数(parameters)

自动微分能够计算可导函数在某点处的导数值,是反向传播算法的一般化。

自动微分主要解决的问题是将一个复杂的数学运算分解为一系列简单的基本运算,该功能对用户屏蔽了大量的求导细节和过程,大大降低了框架的使用门槛。

MindSpore使用函数式自动微分的设计理念,提供更接近于数学语义的自动微分接口gradvalue_and_grad。下面我们使用一个简单的单层线性变换模型进行介绍。

1.函数与计算图

计算图是用图论语言表示数学函数的一种方式,也是深度学习框架表达神经网络模型的统一方法。我们将根据下面的计算图构造计算函数和神经网络。

在这个模型中,𝑥𝑥为输入,𝑦𝑦为正确值,𝑤𝑤和𝑏𝑏是我们需要优化的参数。

我们根据计算图描述的计算过程,构造计算函数。 其中,binary_cross_entropy_with_logits 是一个损失函数,计算预测值和目标值之间的二值交叉熵损失

执行计算函数,可以获得计算的loss值。

x = ops.ones(5, mindspore.float32)  # input tensor
y = ops.zeros(3, mindspore.float32)  # expected output
w = Parameter(Tensor(np.random.randn(5, 3), mindspore.float32), name='w') # weight
b = Parameter(Tensor(np.random.randn(3,), mindspore.float32), name='b') # biasdef function(x, y, w, b):z = ops.matmul(x, w) + bloss = ops.binary_cross_entropy_with_logits(z, y, ops.ones_like(z), ops.ones_like(z))return lossloss = function(x, y, w, b)
print(loss)

2.微分函数与梯度计算

为了优化模型参数,需要求参数对loss的导数:\frac{\partial loss }{\partial w}\frac{\partial loss}{\partial b},此时我们调用mindspore.grad函数,来获得function的微分函数。

这里使用了grad函数的两个入参,分别为:

  • fn:待求导的函数。
  • grad_position:指定求导输入位置的索引。

由于我们对𝑤𝑤和𝑏𝑏求导,因此配置其在function入参对应的位置(2, 3)

使用grad获得微分函数是一种函数变换,即输入为函数,输出也为函数。

执行微分函数,即可获得𝑤𝑤、𝑏𝑏对应的梯度。

grad_fn = mindspore.grad(function, (2, 3))grads = grad_fn(x, y, w, b)
print(grads)

3.Stop Gradient

通常情况下,求导时会求loss对参数的导数,因此函数的输出只有loss一项。当我们希望函数输出多项时,微分函数会求所有输出项对参数的导数。

此时如果想实现对某个输出项的梯度截断,或消除某个Tensor对梯度的影响,需要用到Stop Gradient操作。

这里我们将function改为同时输出loss和z的function_with_logits,获得微分函数并执行。

1

可以看到,求得𝑤𝑤、𝑏𝑏对应的梯度值与初始function求得的梯度值一致。

4.Auxiliary data

Auxiliary data意为辅助数据,是函数除第一个输出项外的其他输出。通常我们会将函数的loss设置为函数的第一个输出,其他的输出即为辅助数据。

gradvalue_and_grad提供has_aux参数,当其设置为True时,可以自动实现前文手动添加stop_gradient的功能,满足返回辅助数据的同时不影响梯度计算的效果。

下面仍使用function_with_logits,配置has_aux=True,并执行。

grad_fn = mindspore.grad(function_with_logits, (2, 3), has_aux=True)grads, (z,) = grad_fn(x, y, w, b)
print(grads, z)

可以看到,求得𝑤𝑤、𝑏𝑏对应的梯度值与初始function求得的梯度值一致,同时z能够作为微分函数的输出返回。

5.神经网络梯度计算

我们的神经网络构造是继承自面向对象编程范式的nn.Cell。接下来我们通过Cell构造同样的神经网络,利用函数式自动微分来实现反向传播。

首先我们继承nn.Cell构造单层线性变换神经网络。这里我们直接使用前文的𝑤𝑤、𝑏𝑏作为模型参数,使用mindspore.Parameter进行包装后,作为内部属性,并在construct内实现相同的Tensor操作。

# Define model
class Network(nn.Cell):def __init__(self):super().__init__()self.w = wself.b = bdef construct(self, x):z = ops.matmul(x, self.w) + self.breturn z

接下来我们实例化模型和损失函数:

# Instantiate model
model = Network()
# Instantiate loss function
loss_fn = nn.BCEWithLogitsLoss()

完成后,由于需要使用函数式自动微分,需要将神经网络和损失函数的调用封装为一个前向计算函数:

# Define forward function
def forward_fn(x, y):z = model(x)loss = loss_fn(z, y)return loss

完成后,我们使用value_and_grad接口获得微分函数,用于计算梯度。

由于使用Cell封装神经网络模型,模型参数为Cell的内部属性,此时我们不需要使用grad_position指定对函数输入求导,因此将其配置为None。对模型参数求导时,我们使用weights参数,使用model.trainable_params()方法从Cell中取出可以求导的参数。

grad_fn = mindspore.value_and_grad(forward_fn, None, weights=model.trainable_params())loss, grads = grad_fn(x, y)
print(grads)

代码实现:

总结一下神经网络的梯度计算:

构建模型,基于nn.cell基类,_init_,定义类,construct类。

实例化:model=Network()

实例化损失函数:loss_fn = nn.BCEWithLogitsLoss()

函数式自动微分,需要把神经网络和损失函数调用封装为一个前向的计算函数def forward_fn()

使用接口grad_position获取微分函数

用weight函数求导,用model.trainable_params()取出求导的参数

byebye~~~

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/34617.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

数据中心 250KW 水冷负载组概述

该负载专为数据中心冷水机组调试和测试应用而设计, 是一款紧凑的便携式产品,具有无限功率和水流控制功能,可实现精确的温升设置与施加的功率。鹦鹉螺是完全可联网的,可以从远程站控制单个或多个单元。 使用带有触摸屏 HMI 的 PLC,…

Navicat连接Oracle出现Oracle library is not loaded的解决方法

目录 1. 问题所示2. 原理分析3. 解决方法1. 问题所示 使用Navicat链接Oracle的时候,出现如下提示:Oracle library is not loaded. 截图如下所示: 2. 原理分析 通常是由于缺少必需的 Oracle 客户端库或环境变量未正确配置所致 还有一种情况是 32位与64位的不匹配:Navica…

基于Langchain-chatchat搭建本地智能知识问答系统

基于Langchain-chatchat搭建本地智能 搭建本地智能知识问答系统:基于Langchain-chatchat的实践指南引言项目概述环境安装Anacondapip 项目安装步骤大语言模型(LLM)的重要性结语 搭建本地智能知识问答系统:基于Langchain-chatchat的…

记错医院预约的日期,选择加号还是回去?

记错了去医院的日期,起了个大早,用了 90 分钟才到医院,取号时提示没有预约的号,才发现记错时间了。这个时候是选择找医生加号还是直接回去呢?如果是你怎么选择? 如果选择找医生加号,号会排到最后…

STM32 ---- F1系列内核和芯片系统架构 || 存储器映像 || 寄存器映射

一,存储器映像 STM32 寻址范围:2^32 4 * 2^10 *2^10 K 4 * 2^10 M 4G 地址所访问的存储单元是按字节编址的。 0x0000 0000 ~ 0xFFFF FFFF 什么是存储器映射? 存储器本身不具有地址信息,给存储器分配地址的…

STM32单片机WDG看门狗详解

文章目录 1. WDG简介 2. IWDG框图 3. IWDG键寄存器 4. IWDG超时时间 5. WWDG框图 6. WWDG工作特性 7. WWDG超时时间 8. IWDG和WWDG对比 9. 代码示例 1. WDG简介 WDG(Watchdog)看门狗 看门狗可以监控程序的运行状态,当程序因为设计…

Docker Compose--安装Nginx--方法/实例

原文网址:Docker Compose--安装Nginx--方法/实例_IT利刃出鞘的博客-CSDN博客 简介 说明 本文介绍Docker Compose如何安装Nginx。 目录结构 ├── config │ ├── cert │ │ ├── xxx_bundle.pem │ │ └── xxx.key │ ├── conf.d │ …

【ONLYOFFICE震撼8.1】ONLYOFFICE8.1版本桌面编辑器测评

随着远程工作的普及和数字化办公的发展,越来越多的人开始寻找一款具有强大功能和便捷使用的办公软件。在这个时候,ONLYOFFICE 8.1应运而生,成为了许多用户的新选择。ONLYOFFICE 8.1是一种办公套件软件,它提供了文档处理、电子表格…

jupyter中如何看plt.plot的局部细节

在Jupyter中使用matplotlib时,如果你想要放大图表的某一部分,可以使用matplotlib的交互式方式查看局部细节。 %matplotlib notebook # 在Jupyter中使用交互式后端 import matplotlib.pyplot as plt import numpy as np# 生成数据 x np.linspace(0, 10…

TiDB 资源管控的对撞测试以及最佳实践架构

作者: GreenGuan 原文来源: https://tidb.net/blog/bc405c21 引言 TiDB 是一个存算分离的架构,资源管控对这种分离的架构来说实现确实有非常大的难度,TiDB 从 7.1 版本开始引入资源管控的概念,在社区也有不少伙伴测…

论文速递 | Management Science 4月文章合集(下)

编者按 在本系列文章中,我们梳理了运筹学顶刊Management Science在2024年4月份发布有关OR/OM以及相关应用的13篇文章的基本信息,旨在帮助读者快速洞察领域新动态。本文为第二部分(2/2)。 推荐文章1 ● 题目:Social Le…

模拟面试之外卖点单系统(高频面试题目mark)

今天跟大家分享一个大家简历中常见的项目-《外卖点单系统》,这是一个很经典的项目,有很多可以考察的知识点和技能点,但大多数同学都是学期项目,没有实际落地,对面试问题准备不充分,回答时抓不到重点&#x…

Handling `nil` Values in `NSDictionary` in Objective-C

Handling nil Values in NSDictionary in Objective-C When working with Objective-C, particularly when dealing with data returned from a server, it’s crucial (至关重要的) to handle nil values appropriately (适当地) to prevent unexpected crashes. Here, we ex…

VBA递归过程快速组合数据

实例需求:数据表包含的列数不固定,有的列(数量和位置不固定)包含组合数据,例如C2单元格为D,P,说明Unit Config有两种分别为D和P,如下图所示。 现在需要将所有的组合罗列出来,如下所示…

git上传本地项目及更新项目

1、注册GitHub账号和下载git 2、在GitHub上新建一个仓库,点击号——>New repository,给仓库起一个名字,点击Create repository 3、进入要上传的项目中,右键点击git back here,命令行输入git init初始化&#xff0c…

全球电力电子测试方案专业提供商「艾诺仪器」×企企通召开项目启动会,推进企业采购数智化升级

导读 供应链管理已成为企业的核心竞争力之一,为应对快速变化的市场环境,艾诺仪器亟需强化采购管理和供应链协同的竞争力。SRM涉及到各事业部、各所属企业等多个层面,希望通过双方优势资源的整合,打造高效协同、科学智能的数字化采…

数据挖掘概览

数据挖掘(Data Mining)就是从大量的,不完全的,有噪声的,模糊的,随机的实际应用数据中,提取隐含在其中的,人们事先不知道的,但又是潜在有用的信息和知识的过程. 预测性数据挖掘 分类 定义:分类就是把一些新的数据项映射到给定类别中的某一个类别 分类流程&#x…

Python | Leetcode Python题解之第189题轮转数组

题目&#xff1a; 题解&#xff1a; def reverse(nums: List[int], left, right) -> None:i, j left, rightwhile i < j:nums[i], nums[j] nums[j], nums[i]i1j-1 class Solution:def rotate(self, nums: List[int], k: int) -> None:n len(nums)k % nreverse(num…

Midway + TypeORM项目部署到BT后启动失败,MySQL报错

Midway TypeORM项目部署到BT后启动失败&#xff0c;MySQL报错 前沿 您需要先了解这篇文章&#xff1a;https://blog.csdn.net/weixin_45687201/article/details/139336111 错误日志 服务状态开启后就失败项目日志&#xff0c;输出 \> my-midway-project1.0.0 start \&…

【Python新手入门指南】Linux-conda环境安装与使用参考

文章目录 前言一、conda是什么&#xff1f;二、安装步骤三、使用Conda来管理Python环境1. 创建环境2. 激活环境3. 安装软件包4. 查看环境5. 删除环境&#xff1a;如果您不再需要某个环境&#xff0c;可以使用以下命令将其删除&#xff1a; 前言 如果你是一位经验丰富的Python开…