《昇思25天学习打卡营第6天 | mindspore 函数式自动微分常见用法》

1. 背景:

使用 mindspore 学习神经网络,打卡第6天;

2. 训练的内容:

使用 mindspore 的函数式自动微分常见用法;

3. 常见的用法小节:

函数式自动微分支持一系列常用的函数

3.1 损失函数:

binary_cross_entropy_with_logits 是一个损失函数,计算预测值和目标值之间的二值交叉熵损失

# 使用函数式自动微分的设计理念, 
# 自动微分接口 grad 和 value_and_grad
import numpy as np
import mindspore
from mindspore import nn
from mindspore import ops
from mindspore import Tensor, Parameter# 构建输入参数
x = ops.ones(5, mindspore.float32)
y = ops.zeros(3, mindspore.float32)
w = Parameter(Tensor(np.random.randn(5, 3), mindspore.float32), name='w')
b = Parameter(Tensor(np.random.randn(3,), mindspore.float32), name='b')# binary_cross_entropy_with_logits 是一个损失函数,计算预测值和目标值之间的二值交叉熵损失。
def function_x_y(x, y, w, b):z = ops.matmul(x, w) + bloss = ops.binary_cross_entropy_with_logits(z, y, ops.ones_like(z), ops.ones_like(z))return lossloss = function_x_y(x, y, w, b)
print(loss)

3.2 微分函数与梯度计算

为了优化模型参数,需要求参数对loss的导数, 可使用 mindspore 的 grad 方法;grad函数的两个入参

  • fn: 求导的函数;
  • grad_position: 求导输入位置索引
grad_fn = mindspore.grad(function_x_y, (2, 3))
grads = grad_fn(x, y, w, b)
print(grads) # stop gradient
# 将function改为同时输出loss和z的function_with_logits
def fucntion_with_logits(x, y, w, b):z = ops.matmul(x, w) + bloss = ops.binary_cross_entropy_with_logits(z, y, ops.ones_like(z), ops.ones_like(z))return loss, zgrad_fn = mindspore.grad(fucntion_with_logits, (2, 3))
grads = grad_fn(x, y, w, b)
print(grads)

3.3 Stop Gradient

当我们希望函数输出多项时,微分函数会求所有输出项对参数的导数。此时如果想实现对某个输出项的梯度截断,或消除某个Tensor对梯度的影响,需要用到Stop Gradient操作

# 想实现对某个输出项的梯度截断,或消除某个Tensor对梯度的影响,需要用到Stop Gradient操作
def function_stop_gradient(x, y, w, b):z = ops.matmul(x, w) + bloss = ops.binary_cross_entropy_with_logits(z, y, ops.ones_like(z), ops.ones_like(z))return loss, ops.stop_gradient(z)grad_fn = mindspore.grad(function_stop_gradient, (2, 3))
grads = grad_fn(x, y, w, b)
print(grads)

3.4 神经网络梯度计算

通过Cell构造同样的神经网络,利用函数式自动微分来实现反向传播

# 神经网络梯度计算
# Define model
class Network(nn.Cell):def __init__(self):super().__init__()self.w = wself.b = bdef construct(self, x):z = ops.matmul(x, self.w) + self.breturn z# Instantiate model
model = Network()
# Instantiate loss function
loss_fn = nn.BCEWithLogitsLoss()# Define forward function
def forward_fn(x, y):z = model(x)loss = loss_fn(z, y)return loss# 使用 trainable_params()方法取出求导的参数
grad_fn = mindspore.value_and_grad(forward_fn, None, weights=model.trainable_params())loss, grads = grad_fn(x, y)
print(grads)

相关链接:

  • https://xihe.mindspore.cn/events/mindspore-training-camp
  • https://gitee.com/mindspore/docs/blob/r2.3.0rc2/tutorials/source_zh_cn/beginner/autograd.ipynb

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/41847.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

系统重装

待更新 重置win11 双系统删除其中一个,并将格式化后的空间并入

HTTP请求详解及其在嵌入式系统中的应用

前言 HTTP(HyperText Transfer Protocol,超文本传输协议)是互联网中最广泛使用的应用层协议,用于客户端和服务器之间的数据传输。了解HTTP请求的工作原理对于开发网络应用和嵌入式系统中的网络通信至关重要。本文将详细介绍HTTP请…

跟着峰哥学java 第四天 商品分类 前后端显示

1.后端 1.1mybatis-plus分页查询配置 在商品热卖数据中,只让其显示八条数据 将要使用分页 也就是service.page方法 此时需要配置 mp拦截器 Configuration public class MybatisPlusConfig {Beanpublic PaginationInterceptor paginationInterceptor() {return …

模型训练之数据集

我们知道人工智能的四大要素:数据、算法、算力、场景。我们训练模型离不开数据 目标 一、数据集划分 定义 数据集:训练集是一组训练数据。 样本:一组数据中一个数据 特征:反映样本在某方面的表现、属性或性质事项 训练集&#…

星辰宇宙动态页面vue版,超好看的前端页面。附源码与应用教程(若依)

本代码的html版本,来源自“山羊の前端小窝”作者,我对此进行了vue版本转换以及相关应用。特此与大家一起分享~ 1、直接上效果图: 带文字版:文字呼吸式缩放。 纯净版: 默认展示效果: 缩放与旋转后&#xf…

mysql5.6的安装步骤

1.下载mysql 下载地址:https://downloads.mysql.com/archives/community/ 在这里我们下载zip的包 2.解压mysql包到指定目录 3. 添加my.ini文件 # For advice on how to change settings please see # http://dev.mysql.com/doc/refman/5.6/en/server-configurat…

tongweb+ths6011测试websocket(by lqw)

本次使用的tongweb版本7049m4,测试包ws_example.war(在tongweb安装目录的samples/websocket下),ths版本6011 首先在tongweb控制台部署一下ws_example.war,部署后测试是否能访问: 然後ths上的httpserver.conf的參考配…

腾讯centos mysql安装

腾讯centos mysql安装 腾讯云提供了一系列的云计算服务,包括操作系统、数据库、服务器等。在腾讯云上安装CentOS操作系统和MySQL数据库可以按照以下步骤进行: 登录腾讯云控制台(登录 - 腾讯云)。在控制台页面上方的搜索框中输入…

vue数组变化的侦测***

数组变化的侦测 变更方法 vue能够侦听响应式数组的变更方法,并在他们被调用时触发相关更新。这些变更方法包括: push()pop()shift()unshift()splice()sort()reverse() 替换一个数组 变更方法,顾名思义,就是会对调用他们的原数组进…

Vue 路由传参 query方法 bug 记录

问题描述 vue 路由传参 踩坑 this.$router.push({path: "xxxxxxx",query: {opportunity_id:row.opportunity_id,constructor:row.constructor,},});解决方案: 上述方法传入新页面时,访问的 this.$route.query 会有bug 每一次刷新都会在最后一…

DNS服务器

DNS服务器 一、DNS简介: DNS(Domain Name System)是一种用于将域名解析为IP地址的系统。 在DNS中,正向解析将域名转换为IP地址,而反向解析将IP地址转换为域名。正向解析是DNS系统最常用的解析方式,它允许…

本地部署到服务器上的资源路径问题

本地部署到服务器上的资源路径问题 服务器端的源代码的静态资源目录层级 当使用Thymeleaf时,在templates的目录下为返回的html页面,下面以两个例子解释当将代码部署到tomcat时访问资源的路径配置问题 例子一 index.html(在templates的根目录…

VBA初学:零件成本统计之三(获取材料外协的金额)

第三步,从K3的数据库中获取金额 我这里是使用循环,通过任务单号将金额汇总出来,如果使用数组的话,还要按任务单写GROUP,还要去对应,不如循环直接一点 获取材料和外协金额的表格Sub getje()Dim rowcount A…

leetcode-每日一题

3101. 交替子数组计数https://leetcode.cn/problems/count-alternating-subarrays/ 给你一个 二进制数组 nums 。 如果一个 子数组 中 不存在 两个 相邻 元素的值 相同 的情况,我们称这样的子数组为 交替子数组 。 返回数组 nums 中交替子数组的数量。 示例 …

算法力扣刷题 三十四【71.简化路径】

前言 栈和队列篇。 记录 三十四【71.简化路径】 一、题目阅读 给你一个字符串 path ,表示指向某一文件或目录的 Unix 风格 绝对路径 (以 ‘/’ 开头),请你将其转化为更加简洁的规范路径。 在 Unix 风格的文件系统中&#xff0c…

3-2 梯度与反向传播

3-2 梯度与反向传播 主目录点这里 梯度的含义 可以看到红色区域的变化率较大,梯度较大;绿色区域的变化率较小,梯度较小。 在二维情况下,梯度向量的方向指向函数增长最快的方向,而其大小表示增长的速率。 梯度的计算 …

使用Python实现深度学习模型:模型解释与可解释人工智能

在深度学习领域,模型解释和可解释性人工智能(XAI)正变得越来越重要。理解深度学习模型的决策过程对于提高模型的透明度和可信度至关重要。本文将详细介绍如何使用Python实现模型解释和可解释性人工智能,包括基本概念、常用方法、代码实现和示例应用。 目录 模型解释与可解…

docker使用镜像jms_all部署jumpserver

创建容器需要挂载出来的服务器对应目录 mkdir -p /data/redis/data mkdir -p /opt/mysql/{data,conf,logs}docker安装redis docker run -d -it --name redis -p 6379:6379 -v /data/redis/data:/data --restart=always

如何第一次从零上传项目到GitLab

嗨,我是兰若,今天想给大家说下,如何上传一个完整的项目到与LDAP集成的GitLab,也就是说这个项目之前是不在git上面的,这是第一次上传,这样上传上去之后,其他小伙伴就可以根据你这个项目的git地址…

3. train_encoder_decoder.py

train_encoder_decoder.py #__future__ 模块提供了一种方式,允许开发者在当前版本的 Python 中使用即将在将来版本中成为标准的功能和语法特性。此处为了确保代码同时兼容Python 2和Python 3版本中的print函数 from __future__ import print_function # 导入标准库…