深度学习:匿名函数lambda函数的使用

背景:
假设我们有一个简单的线性回归模型,其损失函数是均方误差(MSE):

class LinearModel:def __init__(self):self.W = np.random.randn(1, 1)  # 初始化权重def predict(self, x):return np.dot(x, self.W)  # 线性预测def loss(self, x, t):y_pred = self.predict(x)return np.mean((y_pred - t) ** 2)  # 均方误差# 创建模型实例
model = LinearModel()# 定义输入数据和目标标签
x = np.array([[1], [2], [3]])
t = np.array([[2], [4], [6]])# 定义损失函数
loss_W = lambda W : model.loss(x, t)# 计算损失
current_loss = loss_W(model.W)
print(current_loss)

loss_W = lambda W : self.loss(x, t)
是一个 lambda 函数,它接受一个参数 W 并返回损失函数的值。具体来说:

lambda W : self.loss(x, t) 定义了一个匿名函数(lambda 函数),它接受一个参数 W。self.loss(x, t) 是损失函数的调用,其中 x 和 t 是损失函数的输入参数。

因此,loss_W 的作用是将参数 W 传递给损失函数 self.loss,并返回损失函数的计算结果。这里的 W 通常是模型中的权重参数,而 x 和 t 分别是输入数据和目标标签。

总结一下,loss_W 是一个接受权重参数 W 并返回损失值的函数。
当然可以将 loss_W 写成一个普通的函数,而不是使用 lambda 表达式。以下是如何将 loss_W 写成一个普通函数的示例:

class LinearModel:def __init__(self):self.W = np.random.randn(1, 1)  # 初始化权重def predict(self, x):return np.dot(x, self.W)  # 线性预测def loss(self, x, t):y_pred = self.predict(x)return np.mean((y_pred - t) ** 2)  # 均方误差# 创建模型实例
model = LinearModel()# 定义输入数据和目标标签
x = np.array([[1], [2], [3]])
t = np.array([[2], [4], [6]])# 定义损失函数
def loss_W(W):model.W = W  # 更新模型的权重return model.loss(x, t)# 计算损失
current_loss = loss_W(model.W)
print(current_loss)

在这个例子中:

loss_W 是一个普通的函数,它接受 W 作为参数。在函数内部,首先更新模型的权重 model.W 为传入的 W。然后调用 model.loss(x, t) 计算损失,并返回计算结果。

这种方式与使用 lambda 函数的效果是相同的,但代码的可读性可能会有所提高,特别是对于复杂的逻辑。

以下是梯度的代码,通过下述代码更加深入了解lambda:

def numerical_gradient(f, x):h = 1e-4grad = np.zeros_like(x)for idx in range(x.size):tmp_val = x[idx]x[idx] = tmp_val + hfxh1 = f(x)x[idx] = tmp_val - hfxh2 = f(x)grad[idx] = (fxh1 - fxh2) / (2*h)x[idx] = tmp_valreturn grad

下面的numerical_gradient函数是调用上面函数的

  def numerical_gradient(self, x, t):loss_W = lambda W: self.loss(x, t)grads = {}grads['W1'] = numerical_gradient(loss_W, self.params['W1'])grads['b1'] = numerical_gradient(loss_W, self.params['b1'])grads['W2'] = numerical_gradient(loss_W, self.params['W2'])grads['b2'] = numerical_gradient(loss_W, self.params['b2'])return grads

loss函数为:

   def predict(self, x):W1, b1 = self.params['W1'], self.params['b1']W2, b2 = self.params['W2'], self.params['b2']a1 = np.dot(x, W1) + b1z1 = sigmoid(a1)a2 = np.dot(z1, W2) + b2y = a2return ydef loss(self, x, t):y = self.predict(x)return self.lastLayer.forward(y, t)

所以在下述代码中

 loss_W = lambda W: self.loss(x, t)grads = {}grads['W1'] = numerical_gradient(loss_W, self.params['W1'])grads['b1'] = numerical_gradient(loss_W, self.params['b1'])grads['W2'] = numerical_gradient(loss_W, self.params['W2'])grads['b2'] = numerical_gradient(loss_W, self.params['b2'])

例如grads[‘W1’] = numerical_gradient(loss_W, self.params[‘W1’]) 会调用第一个 numerical_gradient函数用
(f(x+h) - f(x-h))/2*h计算梯度,而由于匿名函数有更新参数的作用,所以当x=self.params[‘W1’]时,计算f(x+h)本例即匿名函数loss_W时会自动将模型中的self.params[‘W1’]=self.params[‘W1’]+h,作用就是匿名函数返回的self.loss(x, t)调用的predict函数里的对应参数会相应更新,这样即可获得在更新后的W1条件下对应的predict输出值从而计算loss。
同理以下
grads[‘b1’] = numerical_gradient(loss_W, self.params[‘b1’])
grads[‘W2’] = numerical_gradient(loss_W, self.params[‘W2’])
grads[‘b2’] = numerical_gradient(loss_W, self.params[‘b2’])
也是一样的原理,使用匿名函数可以在改变后的参数下,返回需要的函数值,很方便。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/57799.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

总结OpenGL和pyrender安装和使用过程中的坑

目录 报错一:AttributeError: NoneType object has no attribute glGetError 报错二:ImportError: (Unable to load OpenGL library, OSMesa: cannot open shared object file: No such file or directory, OSMesa, None) 报错三:raise ImportError("Unable to load…

计算机网络:数据链路层 —— 以太网(Ethernet)

文章目录 局域网局域网的主要特征 以太网以太网的发展100BASE-T 以太网物理层标准 吉比特以太网载波延伸物理层标准 10吉比特以太网汇聚层交换机物理层标准 40/100吉比特以太网传输媒体 局域网 局域网(Local Area Network, LAN)是一种计算机网络&#x…

协议 HTTP

目录 1. 基本概念 2. HTTP 方法 3. 状态码 4. 请求和响应结构 5. HTTPS 6. 其他特性 7. 常见应用 HTTP(超文本传输协议)是一种用于在网络上传输超文本的协议,是万维网(WWW)上的基础协议。以下是关于HTTP协议的一…

Newstar_week1_week2_wp

week1 wp crypto 一眼秒了 n费马分解再rsa flag: import libnum import gmpy2 from Crypto.Util.number import * p 9648423029010515676590551740010426534945737639235739800643989352039852507298491399561035009163427050370107570733633350911691280297…

PostgreSQL的学习心得和知识总结(一百五十六)|auto_explain — log execution plans of slow queries

目录结构 注:提前言明 本文借鉴了以下博主、书籍或网站的内容,其列表如下: 1、参考书籍:《PostgreSQL数据库内核分析》 2、参考书籍:《数据库事务处理的艺术:事务管理与并发控制》 3、PostgreSQL数据库仓库…

python-PyQt项目实战案例:制作一个视频播放器

文章目录 1. 关键问题描述2. 通过OpenCV读取视频/打开摄像头抓取视频3. 通过PyQt 中的 QTimer定时器实现视频播放4. PyQt 视频播放器实现代码参考文献 1. 关键问题描述 在前面的文章中已经分享了pyqt制作图像处理工具的文章,也知道pyqt通过使用label控件显示图像的…

庆祝程序员节:聊一聊编程语言的演变

人不走空 🌈个人主页:人不走空 💖系列专栏:算法专题 ⏰诗词歌赋:斯是陋室,惟吾德馨 目录 🌈个人主页:人不走空 💖系列专栏:算法专题 ⏰诗词歌…

[HBase]二 HBase原生Shell命令大全

HBase原生Shell命令汇总 1. General组 5 1.1. 查看集群状态:status 5 1.2. 查看表的操作方法:table_help 5 1.3. 查看HBase的版本信息:version 5 1.4. 查看当前用户:whoami 5 2. Namespace组 5 2.1. 创建命名空间:create_namespace 5 2.2. 显示命名空…

qt配置https请求

qt应用版本 windows 32位 先说下心理路程,你能遇到的我都遇到了,你能想到的我都想到了,怎么解决看这一篇就够了,从上午12点到晚上12点几乎没离开电脑(除了吃饭),对于openssl这种用的时候无感&am…

idea 2023 创建 springboot 项目 LTS

idea 2023 创建 springboot 项目 LTS idea 版本 2023.3.8 参考 idea 阿里 建立 springboot 工程 方法 LTS https://blog.csdn.net/wowocpp/article/details/124692532 File ---- New ---- Project https://start.spring.io/ http://start.aliyun.com http://127.0.0.1:8080…

AI智能电销机器人有什么功能?语音机器人系统搭建

AI智能电销机器人是一种基于人工智能技术进行自动化电话销售的系统,广泛应用于市场营销、客户关系管理和销售流程中。它通过模拟人工销售人员的行为,能够完成多项重要功能,从而提高销售效率、降低人力成本。以下是AI智能电销机器人的主要功能…

旺店通对接金蝶云星空销售出库接口细节

数据集成是确保各系统高效协同运作的关键环节。本案例将重点介绍如何通过轻易云数据集成平台,实现旺店通旗舰奇门与金蝶云星空之间的销售出库数据对接,具体方案为“销售出库对接,供应商发货-new”。 在本次集成过程中,我们利用了…

Angular 保姆级别教程高阶应用 - RxJs

RxJS 13.1.1 什么是 RxJS ? RxJS 是一个用于处理异步编程的 JavaScript 库,目标是使编写异步和基于回调的代码更容易。 13.1.2 为什么要学习 RxJS ? 就像 Angular 深度集成 TypeScript 一样,Angular 也深度集成了 RxJS。 服务、表单、事件、全局状…

Qt 文本文件读写与保存

Qt 文本文件读写与保存 开发工具&#xff1a;VS2013 QT5.8 设计UI界面&#xff0c;如下图所示 sample7_1QFile.h 头文件&#xff1a; #pragma once#include <QtWidgets/QMainWindow> #include "ui_sample7_1QFile.h"class sample7_1QFile : public QMainWin…

1024玩码神挑战赛,太太太上头了!!!

闯关链接&#xff1a;编程导航-码神挑战 第1关 提示&#xff1a; 直接转ASKII码 第2关 提示&#xff1a; 最常用的快捷键&#xff08;cv&#xff09; 第3关 提示&#xff1a; 答案在网址栏 第4关 提示&#xff1a; 输入表示蓝色区域的这种颜色的16进制代码&#xff0c;在网页代…

【openAI】机器学习算法

文章目录 CSDN 前言 &#x1f4ac; 欢迎讨论&#xff1a;如果你在学习过程中有任何问题或想法&#xff0c;欢迎在评论区留言&#xff0c;我们一起交流学习。你的支持是我继续创作的动力&#xff01; &#x1f44d; 点赞、收藏与分享&#xff1a;觉得这篇文章对你有帮助吗&…

SQL实战测试

SQL实战测试 &#xff08;请写下 SQL 查询语句&#xff0c;不需要展示结果&#xff09; 表 a DateSalesCustomerRevenue2019/1/1张三A102019/1/5张三A18 1. **用一条 ** SQL 语句写出每个月&#xff0c;每个销售有多少个客户收入多少 输出结果表头为“月”&#xff0c;“销…

i春秋web题库——题目名称:SQLi

WEB——SQLi 写在之前&#xff1a;题目简介&#xff1a;题目分析&#xff1a; 写在之前&#xff1a; 本题在CSDN上或是其它博客上有过解答&#xff0c;只不过不知是什么原因&#xff0c;我没有找到解题过程比较完整的文章。于是我决定在CTF初学阶段写一篇这样的博客&#xff0…

【lca,树上差分】P3128 [USACO15DEC] Max Flow P

题意 给定大小为 n ( 2 ≤ n ≤ 5 1 0 4 ) n(2 \leq n \leq 5 \times 10^4) n(2≤n≤5104) 的树&#xff0c;并给定 m ( 1 ≤ m ≤ 1 0 5 ) m(1 \leq m \leq 10^5) m(1≤m≤105) 条树上的路径&#xff08;给定两个端点&#xff0c;容易证明两个点树上路径唯一&#xff09;&…

迭代器失效和支持随机访问的容器总结

创作活动 迭代器失效&#xff1a; 顺序容器&#xff08;如vector、deque、list&#xff09; vector 插入操作&#xff1a; 当在vector中间或头部插入元素时&#xff0c;所有位于插入点之后的迭代器都会失效。这是因为vector的元素在内存中是连续存储的&#xff0c;插入元素可能…