昇思25天学习打卡营第07天 | 函数式自动微分

昇思25天学习打卡营第07天 | 函数式自动微分

文章目录

  • 昇思25天学习打卡营第07天 | 函数式自动微分
    • 函数与计算图
      • 微分函数与梯度
      • Stop Gradient
      • Auxiliary data
    • 神经网络梯度计算
    • 总结
    • 打卡

神经网络的训练主要使用反向传播算法,首先计算模型预测值(logits)与正确标签(label)之间的loss,然后进行反向传播,通过梯度来更新模型参数从而完成网路的训练。

MindSpore使用函数式自动微分的设计理念,提供更接近于数学语义的自动微分接口gradvalue_and_grad

函数与计算图

计算图是图论语言表示数学函数的一种方式,也是深度学习框架表达神经网络模型的统一方法。
compute-graph
在这个模型中, x x x为输入, z z z为输出, y y y为正确值, w w w b b b是需要优化的参数。

x = ops.ones(5, mindspore.float32)  # input tensor
y = ops.zeros(3, mindspore.float32)  # expected output
w = Parameter(Tensor(np.random.randn(5, 3), mindspore.float32), name='w') # weight
b = Parameter(Tensor(np.random.randn(3,), mindspore.float32), name='b') # biasdef function(x, y, w, b):z = ops.matmul(x, w) + bloss = ops.binary_cross_entropy_with_logits(z, y, ops.ones_like(z), ops.ones_like(z))return loss

通过执行function获得loss值:

loss = function(x,y,w,b)

微分函数与梯度

为了优化参数 w w w b b b,需要求参数对loss的导数 ∂ l o s s ∂ w \frac{\partial loss}{\partial w} wloss ∂ l o s s ∂ b \frac{\partial loss}{\partial b} bloss

可以通过mindspore.grad函数来获得function的微分函数:

grad_fn = mindspore.grad(function, (2, 3))grads = grad_fn(x, y, w, b)

此处使用了grad的两个入参:

  • fn:待求导的函数;
  • grad_position:指定求导输入位置的索引。

Stop Gradient

通常情况下,求导时会求loss对参数的导数,因此函数只输出loss一项。
如果函数输出多项时,微分函数会求所有输出对参数的导数。

def function_with_logits(x, y, w, b):z = ops.matmul(x, w) + bloss = ops.binary_cross_entropy_with_logits(z, y, ops.ones_like(z), ops.ones_like(z))return loss, zgrad_fn = mindspore.grad(function_with_logits, (2, 3))
grads = grad_fn(x, y, w, b)
print(grads)

此处function_with_logits输出的z会影响梯度。

如果想要实现对某个输出项的梯度截断,或消除某个Tensor对梯度的影响,需要用到Stop Gradient操作。

def function_stop_gradient(x, y, w, b):z = ops.matmul(x, w) + bloss = ops.binary_cross_entropy_with_logits(z, y, ops.ones_like(z), ops.ones_like(z))return loss, ops.stop_gradient(z)grad_fn = mindspore.grad(function_stop_gradient, (2, 3))
grads = grad_fn(x, y, w, b)
print(grads)

Auxiliary data

Auxiliary data为辅助数据,是函数除第一个输出项外的其他输出。通常loss值为函数的第一个输出,而其它输出即为辅助数据。

gradvalue_and_grad提供has_aux参数,设置为True时,可以自动实现前文中stop_gradient的功能。

grad_fn = mindspore.grad(function_with_logits, (2, 3), has_aux=True)
grads, (z,) = grad_fn(x, y, w, b)
print(grads, z)

神经网络梯度计算

# Define model
class Network(nn.Cell):def __init__(self):super().__init__()self.w = wself.b = bdef construct(self, x):z = ops.matmul(x, self.w) + self.breturn z# Instantiate model
model = Network()
# Instantiate loss function
loss_fn = nn.BCEWithLogitsLoss()

实例化网络和损失函数后,将其封装为一个前向计算函数,用于自动微分:

# Define forward function
def forward_fn(x, y):z = model(x)loss = loss_fn(z, y)return loss

由于使用nn.Cell封装网络模型,其参数为Cell的内部属性,因此不需要指定grad_position参数,直接设置为None

对模型参数求导时,使用weights参数,指定为通过model.trainable_params()方法从Cell中取出的可以求导的参数:

grad_fn = mindspore.value_and_grad(forward_fn, None, weights=model.trainable_params())loss, grads = grad_fn(x, y)
print(grads)

总结

这一节从一个简单的线性函数 w x + b wx+b wx+b出发,介绍了网络模型中数学函数的统一表示方法(即计算图),与loss的计算过程。使用gradvalue_and_grad方法可以通过自动微分获取目标函数的微分函数,从而得到参数对loss的梯度,进而优化参数。对于需要输出辅助数据的函数来说,可以通过ops.stop_gradient进行梯度截断,或设置has_aux=True来自动完成。
在通过Cell封装的网络模型中,需要将模型和loss的调用封装为一个前向计算函数,从而进行自动微分。

打卡

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/41456.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

科普文:微服务之服务网格Service Mesh

一、ServiceMesh概念 背景 随着业务的发展,传统单体应用的问题越来越严重: 单体应用代码库庞大,不易于理解和修改持续部署困难,由于单体应用各组件间依赖性强,只要其中任何一个组件发生更改,将重新部署整…

MUNIK解读ISO26262--什么是DFA

我们在学习功能安全过程中,经常会听到很多安全分析方法,有我们熟知的FMEA(Failure Modes Effects Analysis)和FTA(Fault Tree Analysis)还有功能安全产品设计中几乎绕不开的FMEDA(Failure Modes Effects and Diagnostic Analysis),相比于它们…

【OceanBase】OBProxy 无状态的理解

SueWakeup 个人主页:SueWakeup 系列专栏:为祖国的科技进步添砖Java 个性签名:保留赤子之心也许是种幸运吧 本文封面由 凯楠📸友情提供 目录 前言 OBProxy 无状态的概述 OBProxy 无状态特性带来的优点 1. 高可用 2. 负载均衡…

2024最新版Redis常见面试题包含详细讲解

Redis适用于哪些场景? 缓存分布式锁降级限流消息队列延迟消息队 说一说缓存穿透 缓存穿透的概念 用户频繁的发起恶意请求查询缓存中和数据库中都不存在的数据,查询积累到一定量级导致数据库压力过大甚至宕机。 缓存穿透的原因 比如正常情况下用户发…

C++基础22 字符串与字符数组及其相关操作

这是《C算法宝典》C基础篇的第22节文章啦~ 如果你之前没有太多C基础,请点击👉C基础,如果你C语法基础已经炉火纯青,则可以进阶算法👉专栏:算法知识和数据结构👉专栏:数据结构啦 ​ 目…

蓝牙传输技术的演进与发展

蓝牙模块技术,作为无线通信领域的重要一员,自其诞生之初便受到了广泛的关注和应用。随着技术的不断发展和演进,蓝牙模块技术已经从最初的单一功能、有限传输速度发展到现在的多功能、高速率、低功耗,为人们的生活和工作带来了极大…

信创-系统架构师认证

随着国家对信息技术自主创新的战略重视程度不断提升,信创产业迎来前所未有的发展机遇。未来几年内,信创产业将呈现市场规模扩大、技术创新加速、产业链完善和国产化替代加速的趋势。信创人才培养对于推动产业发展具有重要意义。应加强高校教育、建立人才…

NXP i.MX8系列平台开发讲解 - 3.18 Linux tty子系统介绍(一)

专栏文章目录传送门:返回专栏目录 Hi, 我是你们的老朋友,主要专注于嵌入式软件开发,有兴趣不要忘记点击关注【码思途远】 目录 1. TTY 起源 2. Linux 系统中的TTY 2.1 Linux TTY 设备形式 2.2 Linux TTY framework 2.3 驱动核心相关文件…

零基础入门怎么学习老挝语字母表?《老挝语翻译通》App真人发音教学,学习老挝语字母发音和词汇句子!

这段老挝文字翻译成中文是什么意思?有什么好用的老挝语翻译工具推荐吗? 快速翻译:中老语言无缝转换,实时翻译,让沟通更流畅。 学习工具:零基础入门到流利对话,老挝语真人发音,让你的…

MaxKB开源知识库问答系统发布v1.3.0版本,新增强大的工作流引擎

2024年4月12日,1Panel开源项目组正式发布官方开源子项目——MaxKB开源知识库问答系统(github.com/1Panel-dev/MaxKB)。MaxKB开源项目发布后迅速获得了社区用户的认可,成功登顶GitHub Trending趋势榜主榜。 截至2024年7月4日&…

docker仓库--centos7.9部署harbor详细过程与使用以及常见问题

文章目录 前言1.docker-compose是什么2.harbor是什么 centos7部署harbor详细过程与使用环境一、部署docker二、部署harbor1.下载docker-compose工具2.harbor安装3.拷贝样本文件,并修改文件4.安装harbor,安装完成自行启动5.查看 三、harbor的使用1.创建项…

Https网站如何申请免费的SSL证书及操作使用指南

前言 在当今互联网环境下,HTTPS已成为网站安全的标配,它通过SSL/TLS协议为网站数据传输提供加密,保障用户信息的安全。申请并部署免费SSL证书,不仅能够提升网站的专业形象,还能增强用户信任。本文将详细介绍如何在知名…

StreamSets: 数据采集工具详解

欢迎来到我的博客,很高兴能够在这里和您见面!欢迎订阅相关专栏: 欢迎关注微信公众号:野老杂谈 ⭐️ 全网最全IT互联网公司面试宝典:收集整理全网各大IT互联网公司技术、项目、HR面试真题. ⭐️ AIGC时代的创新与未来&a…

Golang语法规范和风格指南(一)——简单指南

1. 前引 一个语言的规范的学习是重要的,直接关系到你的代码是否易于维护和理解,同时学习好对应的语言规范可以在前期学习阶段有效规避该语言语法和未知编程风格的冲突。 这里是 Google 提供的规范,有助于大家在开始学习阶段对 Golang 进行一…

Tensorflow入门实战 T07-Vgg16网络进行咖啡豆识别

本文为🔗365天深度学习训练营 中的学习记录博客🍖 原作者:K同学啊 | 接辅导、项目定制 1、 前言 这周学习的主要内容是,使用tensorflow编写代码,使用vgg-16网络模型,完成咖啡豆的识别。 2、完整代码 imp…

【密码学基础】对随机不经意传输(Random Oblivious Transfer)的理解

ROT在offline阶段生成大量的OT对,在online阶段通过one-pad方式高效加密,并且只需要简单的异或运算就能实现OT过程(去随机化)。 在ROT中,有一个关键点是:需要考虑offline阶段的选择比特和online阶段的选择比…

C++ 视觉开发 六.特征值匹配

以图片识别匹配的案例来分析特征值检测与匹配方法。 目录 一.感知哈希算法(Perceptual Hash Algorithm) 二.特征值检测步骤 1.减小尺寸 2.简化色彩 3.计算像素点均值 4.构造感知哈希位信息 5.构造一维感知哈希值 三.实现程序 1.感知哈希值计算函数 2.计算距离函数 3…

vscode 生成项目目录结构 directory-tree 实用教程

1. 安装插件 directory-tree 有中文介绍,极其友好! 2. 用 vscode 打开目标项目 3. 快捷键 Ctrl Shift p,输入 Directory Tree 后回车 会在 README.md 文件的底部生成项目目录(若项目中没有 README.md 文件,则会自动创…

用NanoID换掉 UUID,好处是?【送源码】

当我们在分布式环境中存储一些数据的时候,不得不面对的一个选择,就是ID生成器。 使用一个唯一的字符串,来标识一条完整的记录。 这时候,不能使用md5或者sha1来对整个记录做摘要,因为我们后续还要改动这个记录。也不能…

【C++】日期类

鼠鼠实现了一个日期类,用来练习印证前几篇博客介绍的内容!! 目录 1.日期类的定义 2.得到某年某月的天数 3.检查日期是否合法 4.(全缺省)构造函数 5.拷贝构造函数 6.析构函数 7.赋值运算符重载 8.>运算符重…