Keras深度学习框架基础第二讲:层接口(layers API)第二部分“基本层类”

1、layer 类

典型的layer类如下

keras.layers.Layer(activity_regularizer=None,trainable=True,dtype=None,autocast=True,name=None,**kwargs
)

这是一个所有层都继承的基类。

一个层是一个可调用的对象,它接受一个或多个张量作为输入,并输出一个或多个张量。它涉及计算,这些计算在call()方法中定义,并且有一个状态(权重变量)。状态可以在以下两种方式中创建:

  • __init__()方法中,例如通过self.add_weight()
  • 在可选的build()方法中,这个方法会在第一次调用该层的__call__()时被调用,并提供输入的形状,这些形状可能在初始化时未知。

层是递归可组合的:如果你将一个层实例作为另一个层的属性,外部层将开始跟踪内部层创建的权重。嵌套层应该在__init__()方法或build()方法中实例化。

用户只需实例化一个层,然后将其当作可调用的对象来使用。

参数

  • trainable: 布尔值,表示该层的变量是否应该是可训练的。
  • name: 字符串,表示层的名称。
  • dtype:层的计算和权重的数据类型。也可以是一个keras.DTypePolicy,它允许计算和权重的数据类型不同。默认为None。如果为None,则使用keras.config.dtype_policy(),这通常是一个float32策略,除非通过keras.config.set_dtype_policy()设置为不同的值。

属性

  • name: 层的名称(字符串)。
  • dtype: 层权重的数据类型。是layer.variable_dtype的别名。
  • variable_dtype: 层权重的数据类型。
  • compute_dtype:层计算的数据类型。层会自动将输入转换为这个数据类型,从而使得计算和输出也在这个数据类型下。当使用混合精度与keras.DTypePolicy时,这可能与variable_dtype不同。
  • trainable_weights: 应在反向传播中包括的变量列表。
  • non_trainable_weights:不应在反向传播中包括的变量列表。
  • weights:trainable_weightsnon_trainable_weights列表的合并(按此顺序)。
  • trainable:该层是否应该被训练(布尔值),即其潜在的可训练权重是否应作为layer.trainable_weights的一部分返回。
  • input_spec: 可选的(一组)InputSpec对象,指定层可以接受的输入的约束。

推荐Layer的子类实现以下方法

  • __init__(self): 定义自定义层属性,并使用add_weight()或其他状态创建不依赖于输入形状的层权重。
  • build(self, input_shape):
    此方法可用于创建依赖于输入形状(s)的权重,使用add_weight()或其他状态。当__call__()被调用时(如果层尚未被构建),它将自动调用build()来构建层。
  • call(self, *args, **kwargs):
    在确保build()已被调用后,在__call__()中被调用。call()方法执行将层应用于输入参数的逻辑。在call()中,你可以选择性地使用两个保留的关键字参数:1. training(布尔值,表示调用是否处于推理模式或训练模式)。2. mask(布尔张量,编码输入中屏蔽的时间步,例如在RNN层中使用)。该方法的一个典型签名是call(self, inputs),如果用户需要,还可以添加trainingmask
  • get_config(self):返回一个字典,包含用于初始化此层的配置。如果字典的键与__init__()中的参数不同,则还需要重写from_config(self)方法。此方法在保存层或包含此层的模型时使用。

示例
以下是一个基础示例,演示了一个包含两个变量w和b的层,它实现了y = w * x + b的计算。这个示例展示了如何实现build()和call()方法,以及如何将变量设置为层的属性以跟踪为层的权重(在layer.weights中)。

class SimpleDense(Layer):def __init__(self, units=32):super().__init__()self.units = units# Create the state of the layer (weights)def build(self, input_shape):self.kernel = self.add_weight(shape=(input_shape[-1], self.units),initializer="glorot_uniform",trainable=True,name="kernel",)self.bias = self.add_weight(shape=(self.units,),initializer="zeros",trainable=True,name="bias",)# Defines the computationdef call(self, inputs):return ops.matmul(inputs, self.kernel) + self.bias# Instantiates the layer.
linear_layer = SimpleDense(4)# This will also call `build(input_shape)` and create the weights.
y = linear_layer(ops.ones((2, 2)))
assert len(linear_layer.weights) == 2# These weights are trainable, so they're listed in `trainable_weights`:
assert len(linear_layer.trainable_weights) == 2

当提到除了通过反向传播在训练过程中更新的可训练权重之外,层还可以具有非可训练权重。这些权重意味着在call()方法调用期间需要手动更新。以下是一个示例层,它计算其输入的累积和(running sum):

class ComputeSum(Layer):def __init__(self, input_dim):super(ComputeSum, self).__init__()# Create a non-trainable weight.self.total = self.add_weight(shape=(),initializer="zeros",trainable=False,name="total",)def call(self, inputs):self.total.assign(self.total + ops.sum(inputs))return self.totalmy_sum = ComputeSum(2)
x = ops.ones((2, 2))
y = my_sum(x)assert my_sum.weights == [my_sum.total]
assert my_sum.non_trainable_weights == [my_sum.total]
assert my_sum.trainable_weights == []

weights属性

keras.layers.Layer.weights

层的所有权重变量的列表。

与 layer.variables 不同,这排除了度量状态和随机种子。

在 TensorFlow 的 Keras API 中,layer.weights 是一个常用的属性,它返回构成层权重的所有变量的列表。这些权重变量是在训练过程中通过反向传播进行更新的。而 layer.variables 属性则包括了层中的所有变量,不仅限于权重,还包括度量状态(例如用于计算损失或准确率的变量)和可能用于初始化层的随机种子等。

因此,当您想要获取并操作层的权重时,通常使用 layer.weights 而不是 layer.variables

trainable_weights属性

keras.layers.Layer.trainable_weights

层的所有可训练权重变量的列表。

这些是在训练过程中由优化器更新的权重。

在TensorFlow的Keras框架中,当你创建一个神经网络层时,该层可能包含多个权重变量。这些权重变量中的一部分是可训练的,意味着在训练模型(即通过反向传播更新权重以最小化损失函数)时,它们会被优化器(如Adam、SGD等)更新。layer.trainable_weights属性返回的就是这些可训练权重变量的列表。
non_trainable_weights属性

keras.layers.Layer.non_trainable_weights

层的所有非可训练权重变量的列表。

这些是在训练过程中不应由优化器更新的权重。与 layer.non_trainable_variables 不同,这排除了度量状态和随机种子。

在TensorFlow的Keras框架中,一个层可能包含一些权重变量,这些变量在训练过程中不应被优化器更新。这些权重变量通常用于存储一些固定的参数或状态,如批量归一化层中的运行均值和方差。layer.non_trainable_weights属性返回的就是这些非可训练权重变量的列表。注意,与layer.non_trainable_variables不同,这个列表仅包含权重变量,而不包括度量状态或随机种子等其他非权重变量。

2、add_weight方法

Layer.add_weight(shape=None,initializer=None,dtype=None,trainable=True,autocast=True,regularizer=None,constraint=None,aggregation="mean",name=None,
)

参数说明

shape:变量的形状元组。必须完全定义(没有None条目)。如果未指定,则默认为()(即标量)。

initializer:用于填充初始变量值的初始化器对象,或者是内置初始化器的字符串名称(例如"random_normal")。如果未指定,对于浮点变量默认为"glorot_uniform",对于其他所有类型(例如int, bool)则默认为"zeros"。

dtype:要创建的变量的数据类型,例如"float32"。如果未指定,则默认为层的变量数据类型(如果层也未指定,则默认为"float32")。

trainable:布尔值,指示该变量是否应通过反向传播进行训练,或者其更新是否由人工管理。默认为True。

autocast:布尔值,指示在访问变量时是否自动进行类型转换。默认为True。

regularizer:正则化器对象,用于在权重上应用惩罚项。这些惩罚项在优化过程中被添加到损失函数中。默认为None。

constraint:约束对象,在优化器更新后应用于变量,或者是内置约束的字符串名称。默认为None。

aggregation:字符串,可选值为’mean’、‘sum’、‘only_first_replica’。为变量添加注解,表示在编写自定义数据并行训练循环时,应使用哪种多副本聚合类型。

name:变量的字符串名称。对于调试很有用。

trainable属性

keras.layers.Layer.trainable

可设置的布尔值,表示此层是否应该可训练。

3、get_weights方法

Layer.get_weights()

返回层的权重值存入NumPy数组的列表。

4 、set_weights方法

Layer.set_weights(weights)

通过NumPy数组的列表设置层的权重值。

5、get_config方法

Model.get_config()

返回对象的配置。

对象的配置是一个Python字典(可序列化),包含了重新实例化该对象所需的信息。

6、add_loss方法

Layer.add_loss(loss)

可以在call()方法内部调用以添加一个标量损失。

在Keras的自定义层或模型中,有时我们可能需要在前向传播(即call()方法)中直接计算某些损失。例如,在正则化层中,我们可能想要根据层的权重或输出计算一个损失项。为了在训练过程中包含这个损失,我们通常会使用add_loss()方法。

add_loss()方法允许你在call()方法内部添加一个标量损失,这个损失将在反向传播时被考虑进去。这通常用于实现自定义的正则化、约束或其他需要在前向传播中计算的损失项。

class MyLayer(Layer):...def call(self, x):self.add_loss(ops.sum(x))return x

losses属性

keras.layers.Layer.losses

add_loss、正则化器和子层中获取的标量损失列表。

在Keras中,当您使用add_loss方法在层或模型中添加损失时,这些损失会被收集起来并在训练过程中被考虑。同样,如果层或模型有正则化器(如权重衰减),那么这些正则化器产生的损失也会被添加到损失列表中。此外,如果层有子层(即嵌套在其他层中的层),那么这些子层的损失也会被包含在内。

这些标量损失在训练过程中会被累加,并用于计算总损失,然后用于反向传播以更新模型的权重。

注意:这些损失通常是在call方法或其他层/模型的方法中通过add_loss方法添加的,并且是在模型编译后、训练开始前计算的。在模型编译之前,losses列表可能为空或只包含由正则化器产生的损失。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/840892.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【spring】@ControllerAdvice注解学习

ControllerAdvice介绍 ControllerAdvice 是 Spring 框架提供的一个注解,用于定义一个全局的异常处理类或者说是控制器增强类(controller advice class)。这个特性特别适用于那些你想应用于整个应用程序中多个控制器的共有行为,比…

ctfhub中的SSRF的相关例题(下)

目录 URL Bypass 知识点 相关例题 数字IP Bypass 相关例题 方法一:使用数字IP 方法二:转16进制 方法三:用localhost代替 方法四:特殊地址 302跳转 Bypass ​编辑 关于localhost原理: DNS重绑定 Bypass 知识点&…

ant design pro 6.0搭建教程

一、搭建 环境: Node.js 18.16.1 ant design pro 6.0 注意:选择umi3时,使用node.js 18版本的会报错,可以实践一下,这里就不再进行实践了。 umi3需要版本是低于node.js 18的 node下载地址: https://nodejs.…

可重构柔性装配产线,为智能制造领域带来了新的革命性变革

随着科技的飞速发展,个性化需求逐渐成为市场的主导。在这个充满变革的时代,制造业正面临着前所未有的挑战和机遇。如何快速响应市场需求、提高生产效率、保证产品质量,成为每一家制造企业必须思考的问题。 在这样的背景下,富唯智…

免费插件集-illustrator插件-Ai插件-文本对象和文本段落互转

文章目录 1.介绍2.安装3.通过窗口>扩展>知了插件4.功能解释5.总结 1.介绍 本文介绍一款免费插件,加强illustrator使用人员工作效率,进行文本对象和文本段落互转。首先从下载网址下载这款插件 https://download.csdn.net/download/m0_67316550/878…

00.OpenLayers快速开始

00OpenLayers快速开始 官方文档: 快速开始:https://openlayers.org/doc/quickstart.html 需要node环境 一、设置新项目 npm create ol-app my-app cd my-app npm start第一个命令将创建一个名为 my-app​ 的目录(如果您愿意,…

赞扬老师的词汇积累

1 词汇 启明星 象征意义:在古代文化中,启明星(即金星,特别是在黎明前出现在东方天空的那颗亮星)常被视为新一天开始的象征,预示着光明和希望的到来。因此,将老师们比喻为“启明星”&#xff0…

Java——简易图书管理系统

本文使用 Java 实现一个简易图书管理系统 一、思路 简易图书管理系统说白了其实就是 用户 与 图书 这两个对象之间的交互 书的属性有 书名 作者 类型 价格 借阅状态 而用户可以分为 普通用户 管理员 使用数组将书统一管理起来 用户对这个数组进行操作 普通用户可以进…

有趣的css - 圆形背景动效多选框

大家好,我是 Just,这里是「设计师工作日常」,今天分享的是用 css 实现一个圆形背景动效多选框,适用提醒用户勾选场景,突出多选框选项,可以有效增加用户识别度。 最新文章通过公众号「设计师工作日常」发布…

js画思维导图代码2

这段代码是一个使用Vue.js和D3.js构建的树形图组件。它是一个Vue组件&#xff0c;用于创建和显示一个交互式的树形结构图。下面是对这段代码的简要分析&#xff1a; 模板部分 (<template>): 定义了组件的HTML结构&#xff0c;包括一个隐藏的提示框(#tooltip)和一个用于显…

VBA批量合并带有图片、表格与文本框的Word

本文介绍基于VBA语言&#xff0c;对大量含有图片、文本框与表格的Word文档加以批量自动合并&#xff0c;并在每一次合并时添加分页符的方法。 在我们之前的文章基于Python中docx与docxcompose批量合并多个Word文档文件并逐一添加分页符&#xff08;https://blog.csdn.net/zhebu…

helloworld 可执行程序得到的过程

// -E 预处理 开发过程中可以确定某个宏 // -c 把预处理 编译 汇编 都做了,但是不链接 // -o 指定输出文件 // -I 指定头文件目录 // -L 指定链接库文件目录 // -l 指定链接哪一个库文件 #include <stdio.h> #include <stdlib.h> #include <string.h>int mai…

【postgresql初级使用】在表的多个频繁使用列上创建一个索引,多条件查询优化,多场景案例揭示索引失效

多列索引 ​专栏内容&#xff1a; postgresql使用入门基础手写数据库toadb并发编程 个人主页&#xff1a;我的主页 管理社区&#xff1a;开源数据库 座右铭&#xff1a;天行健&#xff0c;君子以自强不息&#xff1b;地势坤&#xff0c;君子以厚德载物. 文章目录 多列索引概述 …

【微积分】CH16 integrals and vector fields听课笔记

【托马斯微积分学习日记】13.1-线积分_哔哩哔哩_bilibili 概述 16.1line integrals of scalar functions [中英双语]可视化多元微积分 - 线积分介绍_哔哩哔哩_bilibili 16.2vector fields and line integrals&#xff1a; work circulation and flux 向量场差不多也是描述某种…

gpt-4o继续迭代考场安排程序 一键生成考场清单

接上两篇gpt-4o考场安排-CSDN博客&#xff0c;考场分层次安排&#xff0c;最终exe版-CSDN博客 当然你也可以只看这一篇。 今天又添加了以下功能&#xff0c;程序见后。 1、自动分页&#xff0c;每个考场打印一页 2、添加了打印试场单页眉 3、添加了页脚 第X页&#xff0c;…

Leetcode刷题笔记1:数组基础1

导语 leetcode刷题笔记记录&#xff0c;本篇博客记录数组基础1部分的题目&#xff0c;主要题目包括&#xff1a; Leetcode 704 二分查找Leetcode 27 移除元素 知识点 二分查找 原理 二分查找的适用对象为有序数组且数组中无重复元素&#xff0c;其主要原理是每次都从有序…

2024年5月软考架构题目回忆分享

十年架构两茫茫 &#xff0c;Redis , UML 夜来幽梦忽还乡 &#xff0c; 大数据&#xff0c; Lambda 选择题 1.需求分析和架构设计面临这两个不同对象&#xff0c;一个是问题空间&#xff0c;一个是解空间 这是英文题&#xff0c;总共五个题目&#xff0c;只记得这么多 2. …

AI视频教程下载:全面掌握ChatGPT和LangChain开发AI应用(附源代码)

这是一门深入的课程&#xff0c;涉及ChatGPT、LangChain和Python。打造专注于现实世界AI集成的AI应用&#xff0c;课件附有每一节涉及到的源代码。 **你将学到什么&#xff1a;** - 将ChatGPT集成到LangChain的生产风格应用中 - 使用LangChain组件构建复杂的文本生成管道 - …

order by 优化

1. 排序方式 MySQL支持两种方式的排序&#xff0c;FileSort和Index&#xff1a; Index的效率高&#xff0c;它指MySQL根据索引本身完成排序。FileSort方式效率较低&#xff0c;是指MySQL自己扫描数据之后进行排序&#xff0c;没有使用到index 因此&#xff0c;我们要让order…

推荐五个线上兼职,在家也能轻松日入百元,适合上班族和全职宝妈

在这个瞬息万变的时代&#xff0c;你是否也曾考虑过在繁忙的工作之外&#xff0c;寻找一份兼职副业来补贴家用&#xff0c;同时保持生活的多样性&#xff1f;别急&#xff0c;现在就让我为你揭秘五个可靠的日结线上兼职岗位&#xff0c;助你轻松迈向财务自由之路&#xff01; 一…