昇思25天学习打卡营第07天 | 函数式自动微分

昇思25天学习打卡营第07天 | 函数式自动微分

文章目录

  • 昇思25天学习打卡营第07天 | 函数式自动微分
    • 函数与计算图
      • 微分函数与梯度
      • Stop Gradient
      • Auxiliary data
    • 神经网络梯度计算
    • 总结
    • 打卡

神经网络的训练主要使用反向传播算法,首先计算模型预测值(logits)与正确标签(label)之间的loss,然后进行反向传播,通过梯度来更新模型参数从而完成网路的训练。

MindSpore使用函数式自动微分的设计理念,提供更接近于数学语义的自动微分接口gradvalue_and_grad

函数与计算图

计算图是图论语言表示数学函数的一种方式,也是深度学习框架表达神经网络模型的统一方法。
compute-graph
在这个模型中, x x x为输入, z z z为输出, y y y为正确值, w w w b b b是需要优化的参数。

x = ops.ones(5, mindspore.float32)  # input tensor
y = ops.zeros(3, mindspore.float32)  # expected output
w = Parameter(Tensor(np.random.randn(5, 3), mindspore.float32), name='w') # weight
b = Parameter(Tensor(np.random.randn(3,), mindspore.float32), name='b') # biasdef function(x, y, w, b):z = ops.matmul(x, w) + bloss = ops.binary_cross_entropy_with_logits(z, y, ops.ones_like(z), ops.ones_like(z))return loss

通过执行function获得loss值:

loss = function(x,y,w,b)

微分函数与梯度

为了优化参数 w w w b b b,需要求参数对loss的导数 ∂ l o s s ∂ w \frac{\partial loss}{\partial w} wloss ∂ l o s s ∂ b \frac{\partial loss}{\partial b} bloss

可以通过mindspore.grad函数来获得function的微分函数:

grad_fn = mindspore.grad(function, (2, 3))grads = grad_fn(x, y, w, b)

此处使用了grad的两个入参:

  • fn:待求导的函数;
  • grad_position:指定求导输入位置的索引。

Stop Gradient

通常情况下,求导时会求loss对参数的导数,因此函数只输出loss一项。
如果函数输出多项时,微分函数会求所有输出对参数的导数。

def function_with_logits(x, y, w, b):z = ops.matmul(x, w) + bloss = ops.binary_cross_entropy_with_logits(z, y, ops.ones_like(z), ops.ones_like(z))return loss, zgrad_fn = mindspore.grad(function_with_logits, (2, 3))
grads = grad_fn(x, y, w, b)
print(grads)

此处function_with_logits输出的z会影响梯度。

如果想要实现对某个输出项的梯度截断,或消除某个Tensor对梯度的影响,需要用到Stop Gradient操作。

def function_stop_gradient(x, y, w, b):z = ops.matmul(x, w) + bloss = ops.binary_cross_entropy_with_logits(z, y, ops.ones_like(z), ops.ones_like(z))return loss, ops.stop_gradient(z)grad_fn = mindspore.grad(function_stop_gradient, (2, 3))
grads = grad_fn(x, y, w, b)
print(grads)

Auxiliary data

Auxiliary data为辅助数据,是函数除第一个输出项外的其他输出。通常loss值为函数的第一个输出,而其它输出即为辅助数据。

gradvalue_and_grad提供has_aux参数,设置为True时,可以自动实现前文中stop_gradient的功能。

grad_fn = mindspore.grad(function_with_logits, (2, 3), has_aux=True)
grads, (z,) = grad_fn(x, y, w, b)
print(grads, z)

神经网络梯度计算

# Define model
class Network(nn.Cell):def __init__(self):super().__init__()self.w = wself.b = bdef construct(self, x):z = ops.matmul(x, self.w) + self.breturn z# Instantiate model
model = Network()
# Instantiate loss function
loss_fn = nn.BCEWithLogitsLoss()

实例化网络和损失函数后,将其封装为一个前向计算函数,用于自动微分:

# Define forward function
def forward_fn(x, y):z = model(x)loss = loss_fn(z, y)return loss

由于使用nn.Cell封装网络模型,其参数为Cell的内部属性,因此不需要指定grad_position参数,直接设置为None

对模型参数求导时,使用weights参数,指定为通过model.trainable_params()方法从Cell中取出的可以求导的参数:

grad_fn = mindspore.value_and_grad(forward_fn, None, weights=model.trainable_params())loss, grads = grad_fn(x, y)
print(grads)

总结

这一节从一个简单的线性函数 w x + b wx+b wx+b出发,介绍了网络模型中数学函数的统一表示方法(即计算图),与loss的计算过程。使用gradvalue_and_grad方法可以通过自动微分获取目标函数的微分函数,从而得到参数对loss的梯度,进而优化参数。对于需要输出辅助数据的函数来说,可以通过ops.stop_gradient进行梯度截断,或设置has_aux=True来自动完成。
在通过Cell封装的网络模型中,需要将模型和loss的调用封装为一个前向计算函数,从而进行自动微分。

打卡

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/41456.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

科普文:微服务之服务网格Service Mesh

一、ServiceMesh概念 背景 随着业务的发展,传统单体应用的问题越来越严重: 单体应用代码库庞大,不易于理解和修改持续部署困难,由于单体应用各组件间依赖性强,只要其中任何一个组件发生更改,将重新部署整…

Android SurfaceFlinger——创建Layer(二十)

上一篇文章介绍到,SurfaceComposerClient 中的 createSurface() 方法最终创建的是一个 Layer,这里我们接着看 Layer 的创建。 一、Layer创建 1、SurfaceFlinger.cpp 源码位置:/frameworks/native/services/surfaceflinger/SurfaceFlinger.cpp status_t SurfaceFlinger:…

MUNIK解读ISO26262--什么是DFA

我们在学习功能安全过程中,经常会听到很多安全分析方法,有我们熟知的FMEA(Failure Modes Effects Analysis)和FTA(Fault Tree Analysis)还有功能安全产品设计中几乎绕不开的FMEDA(Failure Modes Effects and Diagnostic Analysis),相比于它们…

VBA中打开、保存关闭Excel工作簿的方法

前言 本节会介绍使用VBA方法打开Excel工作簿、保存关闭Excel工作簿的方法,分别会用到Open、Save、Close方法的使用。 1.使用Open方法打开工作簿 Workbooks.Open(FileName,UpdateLinks,ReadOnly,Format,Password,WriteResPassword,Ignore-ReadOnlyResommended,Orig…

【OceanBase】OBProxy 无状态的理解

SueWakeup 个人主页:SueWakeup 系列专栏:为祖国的科技进步添砖Java 个性签名:保留赤子之心也许是种幸运吧 本文封面由 凯楠📸友情提供 目录 前言 OBProxy 无状态的概述 OBProxy 无状态特性带来的优点 1. 高可用 2. 负载均衡…

centos7.9安装mysql5.7

由于个人配置的服务器性能比较差,容量也不够,没有使用docker或宝塔安装mysql 参考(有细节差异): https://blog.csdn.net/weixin_44304847/article/details/124349013?ops_request_misc%257B%2522request%255Fid%2522…

2024最新版Redis常见面试题包含详细讲解

Redis适用于哪些场景? 缓存分布式锁降级限流消息队列延迟消息队 说一说缓存穿透 缓存穿透的概念 用户频繁的发起恶意请求查询缓存中和数据库中都不存在的数据,查询积累到一定量级导致数据库压力过大甚至宕机。 缓存穿透的原因 比如正常情况下用户发…

C++基础22 字符串与字符数组及其相关操作

这是《C算法宝典》C基础篇的第22节文章啦~ 如果你之前没有太多C基础,请点击👉C基础,如果你C语法基础已经炉火纯青,则可以进阶算法👉专栏:算法知识和数据结构👉专栏:数据结构啦 ​ 目…

蓝牙传输技术的演进与发展

蓝牙模块技术,作为无线通信领域的重要一员,自其诞生之初便受到了广泛的关注和应用。随着技术的不断发展和演进,蓝牙模块技术已经从最初的单一功能、有限传输速度发展到现在的多功能、高速率、低功耗,为人们的生活和工作带来了极大…

MySQL 一些用来做比较的函数

目录 IF:根据不同条件返回不同的值 CASE:多条件判断,类似于Switch函数 IFNULL:用于检查一个值是否为NULL,如果是,则用指定值代替 NULLIF:比较两个值,如果相等则返回NULL&#xff…

信创-系统架构师认证

随着国家对信息技术自主创新的战略重视程度不断提升,信创产业迎来前所未有的发展机遇。未来几年内,信创产业将呈现市场规模扩大、技术创新加速、产业链完善和国产化替代加速的趋势。信创人才培养对于推动产业发展具有重要意义。应加强高校教育、建立人才…

Android Gradle 开发与应用 (五): 构建变体与自定义任务

目录 1. 概述 2. 构建变体 2.1 构建变体的概念 2.2 构建类型 2.3 产品风味 2.4 构建变体的使用 3. 自定义任务 3.1 自定义任务的概念 3.2 创建自定义任务 3.3 配置任务依赖 3.4 任务类型 3.5 动态任务 3.6 自定义任务执行顺序 4. 案例 4.1 多渠道打包 4.2 自动…

Linux CMakeLists编写之可执行程序

目录 1 概述2 文件命名3 实例4 代码分析 1 概述 编译工具有很多(make/cmake/BJam)。cmake是跨平台,使用cmake编译需要编写CMakeLists.txt。本文编写CMakeLists.txt来生成C可执行程序。 2 文件命名 文件命名为CMakeLists.txt,是一个文本文件,可以使用任何编辑器编辑…

iOS项目怎样进行二进制重排

什么是二进制重排 ? 在iOS项目中,二进制重排(Binary Reordering 或者 Binary Rearrangement)是一种优化技术,主要目的是通过重新组织应用程序的二进制文件中的代码和数据段,来提高应用程序的性能&#xff…

【Kubernetes】如何将应用服务,部署到Kubernetes中???

第一步:准备Docker镜像 首先,将服务打包为Dokcer镜像。确保镜像构建正确,并包含服务运行所需的所有依赖项和配置。 (1)创建一个文件夹(目录) mkdir ./newpath(2)在文件来(目录)中创建Dockerf…

代码随想录训练营Day56

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 文章目录 前言一、搜索插入位置二、在排序数组中查找元素的第一个和最后一个位置 前言 提示:这里可以添加本文要记录的大概内容: 今天是跟着代码随想…

Mqtt Client客户端重连时,清空订阅的主题

最近开发时,碰到需要修改mqttClient客户端的订阅主题。实际测试时发现一个问题,修改订阅了新的主题,重连后,之前订阅的主题还是存在,还能继续收到之前订阅主题的消息。 解决办法,配置 mOptions.setCleanSes…

NXP i.MX8系列平台开发讲解 - 3.18 Linux tty子系统介绍(一)

专栏文章目录传送门:返回专栏目录 Hi, 我是你们的老朋友,主要专注于嵌入式软件开发,有兴趣不要忘记点击关注【码思途远】 目录 1. TTY 起源 2. Linux 系统中的TTY 2.1 Linux TTY 设备形式 2.2 Linux TTY framework 2.3 驱动核心相关文件…

零基础入门怎么学习老挝语字母表?《老挝语翻译通》App真人发音教学,学习老挝语字母发音和词汇句子!

这段老挝文字翻译成中文是什么意思?有什么好用的老挝语翻译工具推荐吗? 快速翻译:中老语言无缝转换,实时翻译,让沟通更流畅。 学习工具:零基础入门到流利对话,老挝语真人发音,让你的…

MaxKB开源知识库问答系统发布v1.3.0版本,新增强大的工作流引擎

2024年4月12日,1Panel开源项目组正式发布官方开源子项目——MaxKB开源知识库问答系统(github.com/1Panel-dev/MaxKB)。MaxKB开源项目发布后迅速获得了社区用户的认可,成功登顶GitHub Trending趋势榜主榜。 截至2024年7月4日&…