深度学习-09-让函数更简单

深度学习-09-让函数更简单


本文是《深度学习入门2-自製框架》 的学习笔记,记录自己学习心得,以及对重点知识的理解。如果内容对你有帮助,请支持正版,去购买正版书籍,支持正版书籍不仅是尊重作者的辛勤劳动,也是鼓励更多优秀作品问世。

当前笔记内容主要为:步骤9 让函数更简单 章节的相关理解。

书籍总共分为5个阶段,每个阶段分很多步骤,最终是一步一步实现一个深度学习框架。例如前两个阶段为:

第 1 阶段共包括 10 个步骤 。 在这个阶段,将创建自动微分的机制
第 2 阶段,从步骤11-24,该阶段的主要目标是扩展当前的 DeZero ,使它能够执行更复杂的计算 ,使它能 够处理接收多个输入的函数和返回多个输出的函数


1.作为python 使用 


思考下我们之前的使用模式,在使用函数对象的时候,我们先要定义个出来,然后再使用,这里可以优化吗?

 x = Variable(np.array(0.5))A = Square()a = A(x)

是不是有点冗余,啰嗦,能否直接调用 a = f(x) 这种方式?

其实是可以的。我们对代码进行修改:

def square(x):f = Square()return f(x)def exp(x):f = Exp()return f(x)

上面代码,更进一步优化

def square(x):return Square()(x)def exp(x):return Exp()(x)

编写测试案例,看结果是否一致

    x = Variable(np.array(0.5))a= square(x)b = exp(a)y = square(b)y.grad = np.array(1.0)y.backward()print(x.grad)


输出结果还是  3.297442541400256  说明函数是等价的


2.简化backward方法

这里优化的目的是,简化用户在反向传播方面的工作。省略前面代码中的 y . grad = np.array(1 . 0),每次反向传播时,我们都要定义这个代码。代码改进:

class Variable:def __init__(self, data):self.data = dataself.grad = Noneself.creator = Nonedef set_creator(self,func):self.creator = funcdef backward(self):if self.grad is None:self.grad = np.ones_like(self.data)  # 初始化类型与data 一样funcs = [self.creator]while funcs:f = funcs.pop()x, y =f.input, f.outputx.grad = f.backward(y.grad)if x.creator is not None:funcs.append(x.creator)


代码验证:

    # 优化ones_like 初始化后# 不需要定义 y.grad = np.array(1.0) 这个了x = Variable(np.array(0.5))y = square(exp(square(x)))y.backward()print(x.grad)

输出结果
3.297442541400256


3.仅支持ndarray 

为了减少用户误用,增加参数校验。我们的框架从开始就是只支持Variable ndarray 的示例。为了避免有些用户很可能会不小心使用 float 或 int 等数据类型。例如:Variable(1. 0) 和 Variable(3) 等 这些错误的使用,我们增加参数校验。

class Variable:def __init__(self, data):if data is not None:if not isinstance(data, np.ndarray):raise TypeError('{} is not supported'.format(type(data)))  #参数校验self.data = dataself.grad = Noneself.creator = None

此外,由于NumPy 的特点,带来新的问题

考虑以下代码

    x = np.array([1.0])y = x ** 2print(type(x), x.ndim)print(type(y))


输出:

C:\Python\Python39-32\python.exe D:/pyworkspace/dezero-01/step09.py
<class 'numpy.ndarray'> 1
<class 'numpy.ndarray'>

这里执行时先注释掉其他case 案例


if __name__ == '__main__':# x = Variable(np.array(0.5))# a = square(x)# b = exp(a)# y = square(b)## y.grad = np.array(1.0)# y.backward()# print(x.grad)### # 优化ones_like 初始化后# # 不需要定义 y.grad = np.array(1.0) 这个了# x = Variable(np.array(0.5))# y = square(exp(square(x)))# y.backward()# print(x.grad)### # 错误使用## x = Variable(np.array(1.0))# x = Variable(None)# #x = Variable(1.0)  # 错误使用# Numpy 特性问题x = np.array([1.0])y = x ** 2print(type(x), x.ndim)print(type(y))

如果要上面的案例也可以执行,就要改造。考虑一种case ,由numpy 特性导致的问题。

    x= np.array(1.0)y = x ** 2print(type(x), x.ndim)print(type(y))

输出:

<class 'numpy.ndarray'> 0
<class 'numpy.float64'>

发现 0维的变成了 numpy.float64 、 numpy.float32

这意 味着 DeZero 函数的输山 Variable 可能是 numpy. float64 或 numpy. float32 类哑 的数据。我们在数据输出,输入过程中需要检查,强制转换下。


首先引入辅助函数 

def as_array(x):if np.isscalar(x):  # 使用 np.isscalar 函数来检查 numpy.float64 等属于标量return np.array(x)return x

修改 Function 类的输出,增加强制转换

class Function:def __call__(self, input):x = input.datay = self.forward(x)output = Variable(as_array(y))   # 转成 ndarray 类型output.set_creator(self)  # 输出者保存创造者对象self.input = inputself.output = output  # 保存输出者。我是创造者的信息,这是动态建立 "连接"这 一 机制的核心return outputdef forward(self, x):raise NotImplementedError()  # 使用Function  这个方法forward 方法的人 , 这个方法应该通过继承采实现def backward(self, gy):raise NotImplementedError()

最后执行,发现所有的案例都可以跑通了:

if __name__ == '__main__':x = Variable(np.array(0.5))a = square(x)b = exp(a)y = square(b)y.grad = np.array(1.0)y.backward()print(x.grad)# 优化ones_like 初始化后# 不需要定义 y.grad = np.array(1.0) 这个了x = Variable(np.array(0.5))y = square(exp(square(x)))y.backward()print(x.grad)# 错误使用x = Variable(np.array(1.0))x = Variable(None)#x = Variable(1.0)  # 错误使用# Numpy 特性问题x = np.array([1.0])y = x ** 2print(type(x), x.ndim)print(type(y))x= np.array(1.0)y = x ** 2print(type(x), x.ndim)print(type(y))


输出结果:

C:\Python\Python39-32\python.exe D:/pyworkspace/dezero-01/step09.py
3.297442541400256
3.297442541400256
<class 'numpy.ndarray'> 1
<class 'numpy.ndarray'>
<class 'numpy.ndarray'> 0
<class 'numpy.float64'>

进程已结束,退出代码0

4.代码总结

到此相关优化以及完成,本节所有代码如下:

'''
step09.py
优化-函数更易于使用'''import numpy as npclass Variable:def __init__(self, data):if data is not None:   # 新增if not isinstance(data, np.ndarray):raise TypeError('{} is not supported'.format(type(data)))self.data = dataself.grad = Noneself.creator = Nonedef set_creator(self,func):self.creator = funcdef backward(self):if self.grad is None:self.grad = np.ones_like(self.data)funcs = [self.creator]while funcs:f = funcs.pop()x, y =f.input, f.outputx.grad = f.backward(y.grad)if x.creator is not None:funcs.append(x.creator)class Function:def __call__(self, input):x = input.datay = self.forward(x)                # 新增output = Variable(as_array(y))   # 转成 ndarray 类型  output.set_creator(self)  # 输出者保存创造者对象self.input = inputself.output = output  # 保存输出者。我是创造者的信息,这是动态建立 "连接"这 一 机制的核心return outputdef forward(self, x):raise NotImplementedError()  # 使用Function  这个方法forward 方法的人 , 这个方法应该通过继承采实现def backward(self, gy):raise NotImplementedError()class Square(Function):def forward(self, x):y = x ** 2return ydef backward(self, gy):x= self.input.datagx = 2 * x * gy     #方法的参数 gy 是 一个 ndarray 实例 , 它是从输出传播而来的导数 。return gxclass Exp(Function):def forward(self, x):y = np.exp(x)return ydef backward(self, gy):x = self.input.datagx = np.exp(x) * gyreturn gxdef square(x):f = Square()return f(x)def exp(x):f = Exp()return f(x)def as_array(x):        # 新增if np.isscalar(x):  # 使用 np.isscalar 函数来检查 numpy.float64 等属于标量return np.array(x)return xif __name__ == '__main__':x = Variable(np.array(0.5))a = square(x)b = exp(a)y = square(b)y.grad = np.array(1.0)y.backward()print(x.grad)# 优化ones_like 初始化后# 不需要定义 y.grad = np.array(1.0) 这个了x = Variable(np.array(0.5))y = square(exp(square(x)))y.backward()print(x.grad)# 错误使用x = Variable(np.array(1.0))x = Variable(None)#x = Variable(1.0)  # 错误使用# Numpy 特性问题x = np.array([1.0])y = x ** 2print(type(x), x.ndim)print(type(y))x= np.array(1.0)y = x ** 2print(type(x), x.ndim)print(type(y))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/22983.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

leetcode455.分发饼干、376. 摆动序列、53. 最大子序和

455.分发饼干 为了满足更多的小孩&#xff0c;就不要造成饼干尺寸的浪费 大尺寸的饼干既可以满足胃口大的孩子也可以满足胃口小的孩子&#xff0c;那么就应该优先满足胃口大的 这里的局部最优就是大饼干喂给胃口大的&#xff0c;充分利用饼干尺寸喂饱一个&#xff0c;全局最…

前端面试题日常练-day54 【面试题】

题目 希望这些选择题能够帮助您进行前端面试的准备&#xff0c;答案在文末 1. 在PHP中&#xff0c;以下哪个函数用于将一个字符串转换为日期时间对象&#xff1f; a) date() b) strtotime() c) datetime() d) time() 2. PHP中的超全局变量$_COOKIE用于存储什么类型的数据&a…

一致性hash算法的应用与go语言实现

一致性hash算法的应用与实现 设计目标&#xff1a;一致性hash算法的主要设计目标是在分布式系统中实现节点增减时数据映射关系的最小变动&#xff0c;从而保证数据的一致性和系统的稳定性。 一致性hash算法的应用场景 分布式负载均衡 一致性hash算法在分布式系统中得到广泛应…

揭秘AI 原生应用技术栈

一次性把“AI 原生应用技术栈”说明白 AI热潮持续&#xff0c;厂商努力推动有价值的应用涌现&#xff0c;并打造服务AI原始应用的平台产品。本文精简介绍业界最新的AI原生应用技术栈&#xff0c;让您迅速把握前沿科技脉搏。 整体架构 AI技术栈逻辑图精简呈现&#xff0c;多层…

图形学初识--透视修正

文章目录 前言正文为什么需要透视矫正&#xff1f;1、视图坐标空间--->NDC坐标空间&#xff08;透视投影&#xff09;&#xff08;1&#xff09;直线&#xff1a;&#xff08;2&#xff09;三角形&#xff1a;总结&#xff1a; 2、NDC坐标空间--->屏幕坐标空间&#xff0…

PID控制算法介绍及使用举例

PID 控制算法是一种常用的反馈控制算法&#xff0c;用于控制系统的稳定性和精度。PID 分别代表比例&#xff08;Proportional&#xff09;、积分&#xff08;Integral&#xff09;和微分&#xff08;Derivative&#xff09;&#xff0c;通过组合这三个部分来调节控制输出&#…

因子区间[牛客周赛44]

思路分析: 我们可以发现125是因子个数的极限了,所以我们可以用二维数组来维护第几个数有几个因子,然后用前缀和算出来每个区间合法个数,通过一个排列和从num里面选2个 ,c num 2 来计算即可 #include<iostream> #include<cstring> #include<string> #include…

用户反馈解决方案 —— 兔小巢构建反馈功能

目录 01: 前言 02: 用户反馈整体实现方案分析 03: 兔小巢全解析 04: 基于兔小巢实现用户反馈 05: 总结 01: 前言 在前台系统中&#xff0c;用户反馈 功能也是一个非常常见的需求。 通过反馈功能&#xff0c;我们可以知道当前的应用存在的一些不足和用户相应的一些诉求。…

【Linux系统】进程信号

本篇博客整理了进程信号从产生到处理的过程细节&#xff0c;通过不同过程中的系统调用和其背后的原理&#xff0c;旨在让读者更深入地理解操作系统的设计与软硬件管理手段。 目录 一、信号是什么 1.以生活为鉴 2.默认动作与自定义动作 3.信号的分类、保存、产生 二、产生…

蓝桥杯物联网竞赛_STM32L071KBU6_解决脉冲输出频率数值不稳定BUG

问题&#xff1a; 在用脉冲进行做题的时候发现脉冲输出的频率随着脉冲数值增大而越来越不稳定 典型的情况是10000HZ的时候会变成0HZ或者infHZ也就是无穷大 代码&#xff1a; int BEIGNNUMBER 0; int ENDNUMBER 0; unsigned char STATETIM 0; void HAL_TIM_IC_CaptureCal…

ChatGPT-4o抢先体验

速度很快&#xff0c;结果很智能&#xff0c;支持多模态输入输出&#xff0c;感兴趣联系作者

WordPress电脑版+手机版自动识别切换主题插件优化版

下载地址&#xff1a;WordPress电脑版手机版自动识别切换主题插件优化版 插件介绍&#xff1a; 电脑访问自动显示电脑版 手机访问自动显示手机版

FreeRTOS学习笔记-基于stm32(10)事件标志组

一、事件位、事件标志 事件位用来表明某个事件是否发生&#xff0c;事件位通常用作事件标志&#xff0c;就类似标志位。 二、事件标志组 一个事件组就是一组的事件位&#xff0c;就类似一个标志位寄存器&#xff0c;寄存器的每一位都是一个事件标志。与队列&#xff0c;信号量…

电子电器架构 --- 智能座舱技术分类

电子电器架构 — 智能座舱技术分类 我是穿拖鞋的汉子,魔都中坚持长期主义的汽车电子工程师。 老规矩,分享一段喜欢的文字,避免自己成为高知识低文化的工程师: 屏蔽力是信息过载时代一个人的特殊竞争力,任何消耗你的人和事,多看一眼都是你的不对。非必要不费力证明自己,…

3. ChaosBlade-Box平台安装

ChaosBlade-Box平台安装 参考自&#xff1a;https://chaosblade.io/docs/about-chaosblade/box-introduce/ 通过ChaosBlade-Box可实现 chaosblade、litmuschaos 等已托管工具自动化部署&#xff0c;按照社区的建立的混沌实验模型统一实验场景&#xff0c;根据主机、Kubernete…

Java基础29(编码算法 哈希算法 MD5 SHA—1 HMac 算法 堆成加密算法)

目录 一、编码算法 1. 常见编码 2. URL编码 3. Base64编码 4. 小结 二、哈希算法 1. 哈希碰撞 2. 常用哈希算法 MD5算法 SHA-1算法 自定义HashTools工具类 3. 哈希算法的用途 校验下载文件 存储用户密码 4. 小结 三、Hmac算法 小结&#xff1a; 四、对称加密…

js获取blob格式的json对象

我们上传文件时可能会携带某些参数&#xff0c;比如 let formData new FormData() formData.append("data",new Blob([JSON.stringify(this.params)],{type: "application/json }))当我们直接取时发现会取到一个file类型的对象&#xff0c;无法取到值 formDa…

鸿蒙轻内核M核源码分析系列七 动态内存Dynamic Memory

内存管理模块管理系统的内存资源&#xff0c;它是操作系统的核心模块之一&#xff0c;主要包括内存的初始化、分配以及释放。 在系统运行过程中&#xff0c;内存管理模块通过对内存的申请/释放来管理用户和OS对内存的使用&#xff0c;使内存的利用率和使用效率达到最优&#x…

【C++小知识】为什么C语言不支持函数重载,而C++支持

为什么C语言不支持函数重载&#xff0c;而C支持 编译链接过程函数名修饰过程总结 在了解C函数重载前&#xff0c;如果对文件的编译与链接不太了解。可以看看我之前的一篇文章&#xff0c;链接: 文件的编译链接 想要清楚为什么C语言不支持函数重载而C支持&#xff0c;有俩个过程…

FFmpeg中 Scaler 使用文档介绍

描述 在 FFmpeg 中,swscale 是一个用于图像缩放和像素格式转换的库,它是 libswscale 的一部分。这个库提供了一系列的功能,允许开发者在视频处理过程中改变视频帧的尺寸和像素格式。以下是 swscale 的一些关键点: 图像重缩放:swscale 允许开发者对视频帧进行尺寸调整,这在…