昇思MindSpore学习入门-高阶自动微分

mindspore.ops模块提供的grad和value_and_grad接口可以生成网络模型的梯度。grad计算网络梯度,value_and_grad同时计算网络的正向输出和梯度。本文主要介绍如何使用grad接口的主要功能,包括一阶、二阶求导,单独对输入或网络权重求导,返回辅助变量,以及如何停止计算梯度。

一阶求导

计算一阶导数方法:mindspore.grad,其中参数使用方式为:

  • fn:待求导的函数或网络。
  • grad_position:指定求导输入位置的索引。若为int类型,表示对单个输入求导;若为tuple类型,表示对tuple内索引的位置求导,其中索引从0开始;若是None,表示不对输入求导,这种场景下,weights非None。默认值:0。
  • weights:训练网络中需要返回梯度的网络变量。一般可通过weights = net.trainable_params()获取。默认值:None。
  • has_aux:是否返回辅助参数的标志。若为True,fn输出数量必须超过一个,其中只有fn第一个输出参与求导,其他输出值将直接返回。默认值:False。

下面先构建自定义网络模型Net,再对其进行一阶求导,通过这样一个例子对grad接口的使用方式做简单介绍,即公式:

𝑓(𝑥,𝑦)=𝑥∗𝑥∗𝑦∗𝑧

首先定义网络模型Net、输入x和输入y:

import numpy as np

from mindspore import ops, Tensor

import mindspore.nn as nn

import mindspore as ms

# 定义输入x和y

x = Tensor([3.0], dtype=ms.float32)

y = Tensor([5.0], dtype=ms.float32)

class Net(nn.Cell):

    def __init__(self):

        super(Net, self).__init__()

        self.z = ms.Parameter(ms.Tensor(np.array([1.0], np.float32)), name='z')

    def construct(self, x, y):

        out = x * x * y * self.z

        return out

对输入求一阶导

对输入x, y进行求导,需要将grad_position设置成(0, 1):

对权重进行求导

对权重z进行求导,这里不需要对输入求导,将grad_position设置成None:

返回辅助变量

同时对输入和权重求导,其中只有第一个输出参与求导,示例代码如下:

停止计算梯度

可以使用stop_gradient来停止计算指定算子的梯度,从而消除该算子对梯度的影响。

在上面一阶求导使用的矩阵相乘网络模型的基础上,再增加一个算子out2并禁止计算其梯度,得到自定义网络Net2,然后看一下对输入的求导结果情况。

示例代码如下:

从上面的打印可以看出,由于对out2设置了stop_gradient,所以out2没有对梯度计算有任何的贡献,其输出结果与未加out2算子时一致。

下面删除out2 = stop_gradient(out2),再来看一下输出结果。示例代码为:

打印结果可以看出,把out2算子的梯度也计算进去之后,由于out2和out1算子完全相同,因此它们产生的梯度也完全相同,所以可以看到,结果中每一项的值都变为了原来的两倍(存在精度误差)。

高阶求导

高阶微分在AI支持科学计算、二阶优化等领域均有应用。如分子动力学模拟中,利用神经网络训练势能时,损失函数中需计算神经网络输出对输入的导数,则反向传播便存在损失函数对输入、权重的二阶交叉导数。

此外,AI求解微分方程(如PINNs方法)还会存在输出对输入的二阶导数。又如二阶优化中,为了能够让神经网络快速收敛,牛顿法等需计算损失函数对权重的二阶导数。

MindSpore可通过多次求导的方式支持高阶导数,下面通过几类例子展开阐述。

单输入单输出高阶导数

例如Sin算子,其公式为:

𝑓(𝑥)=𝑠𝑖𝑛(𝑥)

其一阶导数、二阶导数为:

其二阶导数(-Sin)实现如下:

从上面的打印结果可以看出,−𝑠𝑖𝑛(3.1415926)的值接近于0。

单输入多输出高阶导数

对如下公式求导:

(1)𝑓(𝑥)=(𝑓1(𝑥),𝑓2(𝑥))

其中:

(2)𝑓1(𝑥)=𝑠𝑖𝑛(𝑥)

(3)𝑓2(𝑥)=𝑐𝑜𝑠(𝑥)

梯度计算时由于MindSpore采用的是反向自动微分机制,会对输出结果求和后再对输入求导。因此其一阶导数是:

(4)𝑓′(𝑥)=𝑐𝑜𝑠(𝑥)−𝑠𝑖𝑛(𝑥)

其二阶导数为:

(5)𝑓″(𝑥)=−𝑠𝑖𝑛(𝑥)−𝑐𝑜𝑠(𝑥)

从上面的打印结果可以看出,−𝑠𝑖𝑛(3.1415926)−𝑐𝑜𝑠(3.1415926)的值接近于1。

多输入多输出高阶导数

对如下公式求导:

(1)𝑓(𝑥,𝑦)=(𝑓1(𝑥,𝑦),𝑓2(𝑥,𝑦))

其中:

(2)𝑓1(𝑥,𝑦)=𝑠𝑖𝑛(𝑥)−𝑐𝑜𝑠(𝑦)

(3)𝑓2(𝑥,𝑦)=𝑐𝑜𝑠(𝑥)−𝑠𝑖𝑛(𝑦)

梯度计算时由于MindSpore采用的是反向自动微分机制, 会对输出结果求和后再对输入求导。

求和:

(4)∑𝑜𝑢𝑡𝑝𝑢𝑡=𝑠𝑖𝑛(𝑥)+𝑐𝑜𝑠(𝑥)−𝑠𝑖𝑛(𝑦)−𝑐𝑜𝑠(𝑦)

输出和关于输入𝑥的一阶导数为:

(5)d∑𝑜𝑢𝑡𝑝𝑢𝑡d𝑥=𝑐𝑜𝑠(𝑥)−𝑠𝑖𝑛(𝑥)

输出和关于输入𝑥的二阶导数为:

(6)d∑𝑜𝑢𝑡𝑝𝑢𝑡2d2𝑥=−𝑠𝑖𝑛(𝑥)−𝑐𝑜𝑠(𝑥)

输出和关于输入𝑦的一阶导数为:

(7)d∑𝑜𝑢𝑡𝑝𝑢𝑡d𝑦=−𝑐𝑜𝑠(𝑦)+𝑠𝑖𝑛(𝑦)

输出和关于输入𝑦的二阶导数为:

(8)d∑𝑜𝑢𝑡𝑝𝑢𝑡2d2𝑦=𝑠𝑖𝑛(𝑦)+𝑐𝑜𝑠(𝑦)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/875502.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

7.24 模拟赛总结 [dp 专场] + tarjan

复盘 7:40 开题 看 T1 ,妈呀,一上来就数数?盯了几分钟后发现会了,不就是 LCS 计数嘛 继续看,T2 看上去很恶心,线段覆盖,感觉可能是贪心什么的 再看 T3,先想了个 n 2 n^2 n2 的式…

Vue 3 + Vite 项目中安装 Tailwind CSS

官网:安装 - TailwindCSS中文文档 | TailwindCSS中文网 tips:只按照官网的配置可能会导致样式不加载/加载不生效的问题 1、正确安装指令 npm install -D tailwindcss postcss autoprefixer npx tailwindcss init -p 自动生成 ​tailwind.config.js​…

【C++】string类(上)

个人主页~ string 一、标准库中的string类1、什么是string类2、string类的常用接口讲解(1)string类的常见构造(2)string类的容量操作(3)string类对象的访问及遍历(4)string类对象的修…

Java语言程序设计——篇七(2)

🌿🌿🌿跟随博主脚步,从这里开始→博主主页🌿🌿🌿 封装性与多态 封装性与访问修饰符类的访问权限类成员的访问权限 🌠防止类扩展和方法覆盖实战演练 抽象类实战演练 对象转换实战演练…

lambda表达式,真题示例

Lambda表达式 它使代码更加简洁、易读,函数式编程增强了代码的表达力。常用于对集合的操作,如遍历、过滤、转换等。 Lambda表达式的形式: 参数, 箭头(->) 以及一个表达式: (String first, String sec…

Android P Input设备变化监听 Storage设备变化监听

InputManager.java中实现了InputDeviceListener接口,只需要新建一个类 implements InputDeviceListener ,并且将类实例化注册给InputManager.getInstance().registerInputDeviceListener即可。 StorageManager同理 StorageManager中会调用StorageEventL…

还手动抄字幕?学会这3个视频转文字方法,轻松提取视频中的字幕!

大家有尝试过考试前极限抱佛脚吗? 在下不才,曾经试过一次,轻松在及格线低空飘过【大家不要学不要学不要学,重要的事情说三遍!!!】 至于我当时究竟是怎么做到的呢?其实这里面有点小…

网络原理_初识

目录 一、局域网LAN 二、广域网WAN 三、网络通信基础 3.1 IP地址 3.2 端口号 3.3 协议 3.4 五元组 3.5 OSI七层模型 3.6 TCP/IP五层模型 3.7 网络设备所在分层 3.8 封装和分用 总结 一、局域网LAN 局域网,即 Local Area Network,Local 即标…

“微软蓝屏”全球宕机,敲响基础软件自主可控警钟

上周五,“微软蓝屏”“感谢微软 喜提假期”等词条冲上热搜,全球百万打工人受此影响,共同见证这一历史性事件。据微软方面发布消息称,旗下Microsoft 365系列服务出现访问中断。随后在全球范围内,包括企业、政府、个人在…

【定积分】

框架 概念,性质定积分计算基本特色变限积分及其导数反常积分(广义积分)定积分应用面积体积 讲解 1.概念,性质: 定积分就是求出曲线的面积;性质中要注意几个不等式的比较 2.定积分计算: 基本&…

物理机 gogs+jenkins+sonarqube 实现CI/CD

一、部署gogs_0.11.91_linux_amd64.tar.gz gogs官网下载&#xff1a;https://dl.gogs.io/ yum -y install mariadb-serversystemctl start mariadbsystemctl enable mariadbuseradd gittar zxvf gogs_0.11.91_linux_amd64.tar.gzcd gogsmysql -u root -p < scripts/mysql.…

vue3前端开发-小兔鲜项目-登录和非登录状态下的模板适配

vue3前端开发-小兔鲜项目-登录和非登录状态下的模板适配&#xff01;有了上次的内容铺垫&#xff0c;我们可以根据用户的token来判定&#xff0c;到底是显示什么内容了。 1&#xff1a;我们在对应的导航组件内修改完善一下内容即可。 <script setup> import { useUserSt…

svn软件总成全内容

SVN软件总成 概述&#xff1a;本文为经验型文档 目录 D:\安装包\svn软件总成 的目录D:\安装包\svn软件总成\svn-base添加 的目录D:\安装包\svn软件总成\tools 的目录D:\安装包\svn软件总成\tools\sqlite-tools-win32-x86-3360000 的目录D:\安装包\svn软件总成\安装包-----bt lo…

C#调用OpenCvSharp实现图像的角点检测

角点检测用于获取图像特征&#xff0c;以支撑运动检测、目标识别、图像匹配等方面的应用。常用的角点检测算法包括Kitchen-Rosenfeld算法、Harris算法、KLT算法、SUSAN算法等&#xff0c;本文学习并测试Harris角点检测算法。   关于Harris算法的数学原理请见参考文献1的第18、…

C++内存管理和模板/stl初识

前言 c兼容C语言&#xff0c;但它因为有类和对象的概念&#xff0c;C语言原生的那套内存管理函数在特定场景下还是有些捉襟见肘的&#xff0c;为此c在C语言的基础上引入新的内存管理方案&#xff0c;今天我们就来简单的认识一下c的内存管理。除此之外&#xff0c;模板也是c引入…

Jetpack Compose 通过 OkHttp 发送 HTTP 请求的示例

下面是一个使用 Kotlin 和 Jetpack Compose 来演示通过 OkHttp 发送 HTTP 请求的示例。这个示例包括在 Jetpack Compose 中发送一个 GET 请求和一个 POST 请求&#xff0c;并显示结果。 添加okhttp依赖 首先&#xff0c;在你的 build.gradle.kts 文件中添加必要的依赖&#xf…

父子组件生命周期的执行顺序

在Vue中&#xff0c;父子组件的生命周期执行顺序是一个重要的概念&#xff0c;它帮助开发者理解组件之间的加载、更新和销毁过程。以下是对父子组件生命周期执行顺序的详细解释&#xff1a; 一、加载渲染过程 当Vue实例开始创建时&#xff0c;会按照以下顺序执行生命周期钩子…

PACS医学影像临床信息系统,C#影像归档和通信系统源码,PACS源码,支持图像的获取、传输、浏览、打印、测量、重建、对比、存储、处理,电子胶片影像管理等

医学影像临床信息系统具有图像采集、显示、存储、传输和管理等功能&#xff0c;支持DICOM影像设备和非DICOM影像设备&#xff0c;可以识别CT、MR、CR/DR、X光、DSA、B超、NM、SC等设备的图像类型&#xff0c;可对数字影像进行无损压缩和有损压缩处理。C/S体系结构的多媒体数据库…

STM32智能照明控制系统教程

目录 引言环境准备智能照明控制系统基础代码实现&#xff1a;实现智能照明控制系统 4.1 数据采集模块 4.2 数据处理与控制模块 4.3 通信与网络系统实现 4.4 用户界面与数据可视化应用场景&#xff1a;照明管理与优化问题解决方案与优化收尾与总结 1. 引言 智能照明控制系统通…

独立游戏《星尘异变》UE5 C++程序开发日志8——实现敏感词过滤功能(AC自动机)

在游戏中经常会有需要玩家输入一些内容的功能&#xff0c;例如聊天&#xff0c;命名等&#xff0c;这款游戏只有在存档时辉用到命名功能&#xff0c;所以这个过滤也只是一个实验性的功能&#xff0c;我们将使用AC自动机来实现&#xff0c;这是在我们把“csdn”这个词设置为屏蔽…