《昇思25天学习打卡营第07天|函数式自动微分》

函数式自动微分

环境配置

# 实验环境已经预装了mindspore==2.2.14,如需更换mindspore版本,可更改下面mindspore的版本号
!pip uninstall mindspore -y
!pip install -i https://pypi.mirrors.ustc.edu.cn/simple mindspore==2.2.14
import numpy as np
import mindspore
from mindspore import nn
from mindspore import ops
from mindspore import Tensor, Parameter

函数与计算图

  • w x + b = z wx + b = z wx+b=z
    -> A c t i v a t i o n − F u n c t i o n ( z ) Activation - Function(z) ActivationFunction(z)
    -> y p r e d y_{pred} ypred
    -> C r o s s − E n t r o p y ( y , y p r e d ) Cross - Entropy(y , y_{pred}) CrossEntropy(y,ypred)

  • w , b 为需要优化的参数 w,b为需要优化的参数 w,b为需要优化的参数

    x = ops.ones(5, mindspore.float32) # input tensor
    y = ops.zones(3, mindspore.float32) # expected output
    w = Parameter(Tensor(np.random.randn(5, 3), mindspore.float32), name = 'w')
    b = Parameter(Tensor(np.random.randn(3,), mindspore.float32), name='b') # biasdef function(x, y, w, b):z = ops.matmul(x, w) + bloss = ops.binary_cross_entropy_with_logits(z, y, ops.ones_like(z), ops.ones_like(z))return lossloss = function(x, y, w, b)
    print(loss)
    #output Tensor(shape=[], dtype=Float32, value= 0.914285)
    

微分函数与梯度计算

  • 为优化模型需要求参数对loss的导数 ∂ l o s s ∂ w \frac{\partial loss}{\partial w} wloss, ∂ l o s s ∂ b \frac{\partial loss}{\partial b} bloss
  • 调用mindspore.grad函数获取function的微分函数
  • fn: 待求导函数
  • grad_position: 指定求导输入位置索引
  • 使用grad获得微分函数是一种函数变换,即输入为函数,输出也为函数
grad_fn = mindspore.grad(function, (2, 3))
grads = grad_fn(x, y, w, b)
print(grads)
#Output (Tensor(shape=[5, 3], dtype=Float32, value= [[ 6.56869709e-02,  5.37334494e-02,  3.01467031e-01], [ 6.56869709e-02,  5.37334494e-02,  3.01467031e-01], [ 6.56869709e-02,  5.37334494e-02,  3.01467031e-01], [ 6.56869709e-02,  5.37334494e-02,  3.01467031e-01], [ 6.56869709e-02,  5.37334494e-02,  3.01467031e-01]]), Tensor(shape=[3], dtype=Float32, value= [ 6.56869709e-02,  5.37334494e-02,  3.01467031e-01]))

Stop Gradient

  • 实现对某个输出项的梯度截断,或消除某个Tensor对梯度的影响
def function_with_logits(x, y, w, b):z = ops.matmul(x, w) + bloss = ops.binary_cross_entropy_with_logits(z, y, ops.ones_like(z), ops.ones_like(z))return loss, z
grad_fn = mindspore.grad(function_with_logits, (2, 3))
grads = grad_fn(x, y, w, b)
# 若想屏蔽掉z对梯度的影响,使用ops.stop_gradient接口, 将梯度在此截断def function_stop_gradient(x, y, w, b):z = ops.matmul(x, w) + bloss = ops.binary_cross_entropy_with_logits(z, y, ops.ones_like(z), ops.ones_like(z))return loss, ops.stop_gradient(z)grad_fn = mindspore.grad(function_stop_gradient, (2, 3))
grads = grad_fn(x, y, w, b)

Auxiliary data

  • Auxiliary data为辅助数据,是函数除第一个输出项外的其他输出。
  • gradvalue_and_grad提供has_aux参数,当其设置为True时,可以自动实现前文手动添加stop_gradient的功能。
grad_fn = mindspore.grad(function_with_logits, (2, 3), has_aux=True)
grads, (z,) = grad_fn(x, y, w, b)

神经网络梯度计算

#定义模型
class Network(nn.Cell):def __init__(self):super().__init__()self.w = wself.b = bdef construct(self, x):z = ops.matmul(x, self.w) + self.breturn z
# 实例化模型
model = Network()
# 实例化损失函数
loss_fn = nn.BCEWithLogitsLoss()
# 定义正向传播
def forward_fn(x, y):z = model(x)loss = loss_fn(z, y)return loss
grad_fn = mindspore.value_and_grad(forward_fn, None, weights=model.trainable_params())
loss, grads = grad_fn(x, y)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/38443.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Windows10录屏,教你3个方法,简单快速录屏

“我的电脑系统是Windows10的系统,今晚要进行线上开会,但我实在有事没办法参加会议,想把会议的内容录制下来方便我后续观看。但却找不到电脑录屏功能在哪里打开?求助一下,谁能帮帮我?” 在数字化时代&…

mysql 命令 —— 查看表信息(show table status)

查询表信息,如整个表的数据量大小、表的索引占用空间大小等 1、查询某个库下面的所有表信息: SHOW TABLE STATUS FROM your_database_name;2、查询指定的表信息: SHOW TABLE STATUS LIKE your_table_name;如:Data_length 显示表…

闲聊 .NET Standard

前言 有时候,我们从 Nuget 下载第三方包时,会看到这些包的依赖除了要求 .NET FrameWork、.NET Core 等的版本之外,还会要求 .NET Standard 的版本,比如这样: 这个神秘的 .NET Standard 是什么呢? .NET St…

【算法】字母异位词分组

题目:字母异位词分组 给你一个字符串数组,请你将 字母异位词 组合在一起。可以按任意顺序返回结果列表。 字母异位词 是由重新排列源单词的所有字母得到的一个新单词。 示例 1: 输入: strs [“eat”, “tea”, “tan”, “ate”, “nat”, “bat”] …

从零开始搭建spring boot多模块项目

一、搭建父级模块 1、打开idea,选择file–new–project 2、选择Spring Initializr,选择相关java版本,点击“Next” 3、填写父级模块信息 选择/填写group、artifact、type、language、packaging(后面需要修改)、java version(后面需要修改成和第2步中版本一致)。点击“…

【0300】Postgres内核动态哈希表实现机制(1)

相关文章: 【0299】Postgres内核之哈希表(Hash Tables) 0 概述 在【0299】Postgres内核之哈希表(Hash Tables)一文中,讲解了哈希表的作用、实现、优缺点等特性。本文开始,将详细分析Postgres内…

MySQL之应用层优化(三)

应用层优化 应用层缓存 2.本地共享内存缓存 这种缓存一般是中等大小(几个GB),快速,难以在多台机器间同步。它们对小型的半静态位数据比较合适。例如每个州的城市列表,分片数据存储的分区函数(映射表),或者使用存活时间(TTL)策略…

记录一次Chrome浏览器自动排序ajax请求的JSON数据问题

文章目录 1.前言2. 为什么会这样?3.如何解决? 1.前言 作者作为新人入职的第一天,mentor给了一个维护公司运营平台的小需求,具体需求是根据运营平台的某个管理模块所展示记录的某些字段对展示记录做排序。 第一步: myb…

工业触摸一体机优化MES应用开发流程

工业触摸一体机在现代工业生产中扮演着至关重要的角色,它集成了智能触摸屏和工业计算机的功能,广泛应用于各种生产场景中。而制造执行系统(MES)作为工业生产管理的重要工具,对于提高生产效率、降低成本、优化资源利用具…

力扣hot100-普通数组

文章目录 题目:最大子数组和方法1 动态规划方法2 题目:合并区间题解 题目:最大子数组和 原题链接:最大子数组和 方法1 动态规划 public class T53 {//动态规划public static int maxSubArray(int[] nums) {if (nums.length 0…

C++基础知识-编译相关

记录C语言相关的基础知识 1 C源码到可执行文件的四个阶段 预处理(.i)、编译(.s)、汇编(.obj)、链接。 1.1 预处理 预处理阶段,主要完成宏替换、文件展开、注释删除、条件编译展开、添加行号和文件名标识,输出.i/.ii预处理文件。 宏替换,…

【UML用户指南】-26-对高级行为建模-状态图

目录 1、概念 2、组成结构 3、一般用法 4、常用建模技术 4.1、对反应型对象建模 一个状态图显示了一个状态机。在为对象的生命期建模中 活动图展示的是跨过不同的对象从活动到活动的控制流 状态图展示的是单个对象内从状态到状态的控制流。 在UML中,用状态图…

tcpdump命令详解及使用实例

1、抓所有网卡数据包,保存到指定路径 tcpdump -i any -w /oemdata/123.pcap&一、tcpdump简介 tcpdump可以将网络中传送的数据包完全截获下来提供分析。它支持针对网络层、协议、主机、网络或端口的过滤,并提供and、or、not等逻辑语句来去掉无用的信…

【Python】已解决:SyntaxError: positional argument follows keyword argument

文章目录 一、分析问题背景二、可能出错的原因三、错误代码示例四、正确代码示例五、注意事项 已解决:SyntaxError: positional argument follows keyword argument 一、分析问题背景 在Python编程中,当我们在调用函数时混合使用位置参数(p…

RabbitMQ进阶篇

文章目录 发送者的可靠性生产者重试机制实现生产者确认 MQ的可靠性数据持久化交换机持久化队列持久化消息持久化 Lazy Queue(可配置~)控制台配置Lazy模式代码配置Lazy模式更新已有队列为lazy模式 消费者的可靠性消费者确认机制失败重试机制失败处理策略 业务幂等性唯一消息ID业…

西部智慧健身小程序+华为运动健康服务

1、 应用介绍 西部智慧健身小程序为用户提供一站式全流程科学健身综合服务。用户通过登录微信小程序,可享用健康筛查、运动风险评估、体质检测评估、运动处方推送、个人运动数据监控与评估等公益服务。 2、 体验介绍西部智慧健身小程序华为运动健康服务核心体验如…

idea xml ctrl+/ 注释格式不对齐

处理前 处理后 解决办法 取消这两个勾选

核方法总结(三)———核主成分(kernel PCA)学习笔记

一、核主成分 1.1 和PCA的区别 PCA (主成分分析)对应一个线性高斯模型(参考书的第二章),其基本假设是数据由一个符合正态分布的隐变量通过一个线性映射得到,因此可很好描述符合高斯分布的数据。然而在很多实…

ViewBinding的使用(因为kotlin-android-extensions插件的淘汰)

书籍: 《第一行代码 Android》第三版 开发环境: Android Studio Jellyfish | 2023.3.1 问题: 3.2.4在Activity中使用Toast章节中使用到了kotlin-android-extensions插件,但是该插件已经淘汰,根据网上了解,目前使用了新的技术VewBinding替…

UE4_材质_材质节点_DepthFade

一、DepthFade参数 DepthFade(深度消退)表达式用来隐藏半透明对象与不透明对象相交时出现的不美观接缝。 项目说明属性消退距离(Fade Distance)这是应该发生消退的全局空间距离。未连接 FadeDistance(FadeDistance&a…