深度学习处理文本(10)

保存自定义层

在编写自定义层时,一定要实现get_config()方法:这样我们可以利用config字典将该层重新实例化,这对保存和加载模型很有用。该方法返回一个Python字典,其中包含用于创建该层的构造函数的参数值。所有Keras层都可以被序列化(serialize)和反序列化(deserialize)​,如下所示。

config = layer.get_config()
new_layer = layer.__class__.from_config(config)---- config不包含权重值,因此该层的所有权重都是从头初始化的

来看下面这个例子。

layer = PositionalEmbedding(sequence_length, input_dim, output_dim)
config = layer.get_config()
new_layer = PositionalEmbedding.from_config(config)

在保存包含自定义层的模型时,保存文件中会包含这些config字典。从文件中加载模型时,你应该在加载过程中提供自定义层的类,以便其理解config对象,如下所示。

model = keras.models.load_model(filename, custom_objects={"PositionalEmbedding": PositionalEmbedding})

你会注意到,这里使用的规范化层并不是之前在图像模型中使用的BatchNormalization层。这是因为BatchNormalization层处理序列数据的效果并不好。相反,我们使用的是LayerNormalization层,它对每个序列分别进行规范化,与批量中的其他序列无关。它类似NumPy的伪代码如下

def layer_normalization(batch_of_sequences):----输入形状:(batch_size, sequence_length, embedding_dim)mean = np.mean(batch_of_sequences, keepdims=True, axis=-1)---- (本行及以下1)计算均值和方差,仅在最后一个轴(−1轴)上汇聚数据variance = np.var(batch_of_sequences, keepdims=True, axis=-1)return (batch_of_sequences - mean) / variance

下面是训练过程中的BatchNormalization的伪代码,你可以将二者对比一下。

def batch_normalization(batch_of_images):----输入形状:(batch_size, height, width, channels)mean = np.mean(batch_of_images, keepdims=True, axis=(0, 1, 2))---- (本行及以下1)在批量轴(0轴)上汇聚数据,这会在一个批量的样本之间形成相互作用variance = np.var(batch_of_images, keepdims=True, axis=(0, 1, 2))return (batch_of_images - mean) / variance

BatchNormalization层从多个样本中收集信息,以获得特征均值和方差的准确统计信息,而LayerNormalization层则分别汇聚每个序列中的数据,更适用于序列数据。我们已经实现了TransformerEncoder,下面可以用它来构建一个文本分类模型,如代码清单11-22所示,它与前面的基于GRU的模型类似。代码清单11-22 将Transformer编码器用于文本分类

vocab_size = 20000
embed_dim = 256
num_heads = 2
dense_dim = 32inputs = keras.Input(shape=(None,), dtype="int64")
x = layers.Embedding(vocab_size, embed_dim)(inputs)
x = TransformerEncoder(embed_dim, dense_dim, num_heads)(x)
x = layers.GlobalMaxPooling1D()(x)---- TransformerEncoder返回的是完整序列,所以我们需要用全局汇聚层将每个序列转换为单个向量,以便进行分类
x = layers.Dropout(0.5)(x)
outputs = layers.Dense(1, activation="sigmoid")(x)
model = keras.Model(inputs, outputs)
model.compile(optimizer="rmsprop",loss="binary_crossentropy",metrics=["accuracy"])
model.summary()

我们来训练这个模型,如代码清单11-23所示。模型的测试精度为87.5%,比GRU模型略低。代码清单11-23 训练并评估基于Transformer编码器的模型

callbacks = [keras.callbacks.ModelCheckpoint("transformer_encoder.keras",save_best_only=True)
]
model.fit(int_train_ds, validation_data=int_val_ds, epochs=20,callbacks=callbacks)
model = keras.models.load_model("transformer_encoder.keras",custom_objects={"TransformerEncoder": TransformerEncoder})----在模型加载过程中提供自定义的TransformerEncoder类
print(f"Test acc: {model.evaluate(int_test_ds)[1]:.3f}")

现在你应该已经开始感到有些不对劲了。你能看出是哪里不对劲吗?本节的主题是“序列模型”​。我一开始就强调了词序的重要性。我说过,Transformer是一种序列处理架构,最初是为机器翻译而开发的。然而……你刚刚见到的Transformer编码器根本就不是一个序列模型。你注意到了吗?它由密集层和注意力层组成,前者独立处理序列中的词元,后者则将词元视为一个集合。你可以改变序列中的词元顺序,并得到完全相同的成对注意力分数和完全相同的上下文感知表示。如果将每篇影评中的单词完全打乱,模型也不会注意到,得到的精度也完全相同。自注意力是一种集合处理机制,它关注的是序列元素对之间的关系,如图11-10所示,它并不知道这些元素出现在序列的开头、结尾还是中间。既然是这样,为什么说Transformer是序列模型呢?如果它不查看词序,又怎么能很好地进行机器翻译呢?

在这里插入图片描述

Transformer是一种混合方法,它在技术上是不考虑顺序的,但将顺序信息手动注入数据表示中。这就是缺失的那部分,它叫作位置编码(positional encoding)​。我们来看一下。

使用位置编码重新注入顺序信息

位置编码背后的想法非常简单:为了让模型获取词序信息,我们将每个单词在句子中的位置添加到词嵌入中。这样一来,输入词嵌入将包含两部分:普通的词向量,它表示与上下文无关的单词;位置向量,它表示该单词在当前句子中的位置。我们希望模型能够充分利用这一额外信息。你能想到的最简单的方法就是将单词位置与它的嵌入向量拼接在一起。你可以向这个向量添加一个“位置”轴。在该轴上,序列中的第一个单词对应的元素为0,第二个单词为1,以此类推。然而,这种做法可能并不理想,因为位置可能是非常大的整数,这会破坏嵌入向量的取值范围。如你所知,神经网络不喜欢非常大的输入值或离散的输入分布。

在“Attention Is All You Need”这篇原始论文中,作者使用了一个有趣的技巧来编码单词位置:将词嵌入加上一个向量,这个向量的取值范围是[-1, 1],取值根据位置的不同而周期性变化(利用余弦函数来实现)​。这个技巧提供了一种思路,通过一个小数值向量来唯一地描述较大范围内的任意整数。这种做法很聪明,但并不是本例中要用的。我们的方法更加简单,也更加有效:我们将学习位置嵌入向量,其学习方式与学习嵌入词索引相同。然后,我们将位置嵌入与相应的词嵌入相加,得到位置感知的词嵌入。这种方法叫作位置嵌入(positional embedding)​。我们来实现这种方法,如代码清单11-24所示。代码清单11-24 将位置嵌入实现为Layer子类

class PositionalEmbedding(layers.Layer):def __init__(self, sequence_length, input_dim, output_dim, **kwargs):----位置嵌入的一个缺点是,需要事先知道序列长度super().__init__(**kwargs)self.token_embeddings = layers.Embedding(----准备一个Embedding层,用于保存词元索引input_dim=input_dim, output_dim=output_dim)self.position_embeddings = layers.Embedding(input_dim=sequence_length, output_dim=output_dim)----另准备一个Embedding层,用于保存词元位置self.sequence_length = sequence_lengthself.input_dim = input_dimself.output_dim = output_dimdef call(self, inputs):length = tf.shape(inputs)[-1]positions = tf.range(start=0, limit=length, delta=1)embedded_tokens = self.token_embeddings(inputs)embedded_positions = self.position_embeddings(positions)return embedded_tokens + embedded_positions  ←----将两个嵌入向量相加def compute_mask(self, inputs, mask=None):---- (本行及以下1)与Embedding层一样,该层应该能够生成掩码,从而可以忽略输入中填充的0。框架会自动调用compute_mask方法,并将掩码传递给下一层return tf.math.not_equal(inputs, 0)def get_config(self):----实现序列化,以便保存模型config = super().get_config()config.update({"output_dim": self.output_dim,"sequence_length": self.sequence_length,"input_dim": self.input_dim,})return config

你可以像使用普通Embedding层一样使用这个PositionEmbedding层。我们来看一下它的实际效果。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/75800.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

机器视觉3D中激光偏镜的优点

机器视觉的3D应用中,激光偏镜(如偏振片、波片、偏振分束器等)通过其独特的偏振控制能力,显著提升了系统的测量精度、抗干扰能力和适应性。以下是其核心优点: 1. 提升3D成像精度 抑制环境光干扰:偏振片可滤除非偏振的环境杂光(如日光、室内照明),仅保留激光偏振信号,大…

线程同步的学习与应用

1.多线程并发 1).多线程并发引例 #include <stdio.h> #include <stdlib.h> #include <unistd.h> #include <assert.h> #include <pthread.h>int wg0; void *fun(void *arg) {for(int i0;i<1000;i){wg;printf("wg%d\n",wg);} } i…

写.NET可以指定运行SUB MAIN吗?调用任意一个里面的类时,如何先执行某段初始化代码?

VB.NET 写.NET可以指定运行SUB MAIN吗?调用任意一个里面的类时,如何先执行某段初始化代码? 分享 1. 在 VB.NET 中指定运行 Sub Main 在 VB.NET 里&#xff0c;你能够指定 Sub Main 作为程序的入口点。下面为你介绍两种实现方式&#xff1a; 方式一&#xff1a;在项目属性…

【AI插件开发】Notepad++ AI插件开发实践(代码篇):从Dock窗口集成到功能菜单实现

一、引言 上篇文章已经在Notepad的插件开发中集成了选中即问AI的功能&#xff0c;这一篇文章将在此基础上进一步集成&#xff0c;支持AI对话窗口以及常见的代码功能菜单&#xff1a; 显示AI的Dock窗口&#xff0c;可以用自然语言向 AI 提问或要求执行任务选中代码后使用&…

关联容器-模板类pair数对

关联容器 关联容器和顺序容器有着根本的不同:关联容器中的元素是按关键字来保存和访问的,而顺序容器中的元素是按它们在容器中的位置来顺序保存和访问的。 关联容器支持高效的关键字查找和访问。 两个主要的关联容器(associative-container),set和map。 set 中每个元素只包…

京东运维面试题及参考答案

目录 OSPF 实现原理是什么? 请描述 TCP 三次握手的过程。 LVS 的原理是什么? 阐述 Nginx 七层负载均衡的原理。 Nginx 与 Apache 有什么区别? 如何查看监听在 8080 端口的是哪个进程(可举例:netstat -tnlp | grep 8080)? OSI 七层模型是什么,请写出各层的协议。 …

输入框输入数字且保持精度

在项目中如果涉及到金额等需要数字输入且保持精度的情况下&#xff0c;由于输入框是可以随意输入文本的&#xff0c;所以一般情况下可能需要监听输入框的change事件&#xff0c;然后通过正则表达式去替换掉不匹配的文本部分。 由于每次文本改变都会被监听&#xff0c;包括替换…

使用 requests 和 BeautifulSoup 解析淘宝商品

以下将详细解释如何通过这两个库来实现按关键字搜索并解析淘宝商品信息。 一、准备工作 1. 安装必要的库 在开始之前&#xff0c;确保已经安装了 requests 和 BeautifulSoup 库。如果尚未安装&#xff0c;可以通过以下命令进行安装&#xff1a; bash pip install requests…

C#调用ACCESS数据库,解决“Microsoft.ACE.OLEDB.12.0”未注册问题

C#调用ACCESS数据库&#xff0c;解决“Microsoft.ACE.OLEDB.12.0”未注册问题 解决方法&#xff1a; 1.将C#采用的平台从AnyCpu改成X64 2.将官网下载的“Microsoft Access 2010 数据库引擎可再发行程序包AccessDatabaseEngine_X64”文件解压 3.安装解压后的文件 点击下载安…

【文献阅读】Vision-Language Models for Vision Tasks: A Survey

发表于2024年2月 TPAMI 摘要 大多数视觉识别研究在深度神经网络&#xff08;DNN&#xff09;训练中严重依赖标注数据&#xff0c;并且通常为每个单一视觉识别任务训练一个DNN&#xff0c;这导致了一种费力且耗时的视觉识别范式。为应对这两个挑战&#xff0c;视觉语言模型&am…

【Kubernetes】StorageClass 的作用是什么?如何实现动态存储供应?

StorageClass 使得用户能够根据不同的存储需求动态地申请和管理存储资源。 StorageClass 定义了如何创建存储资源&#xff0c;并指定了存储供应的配置&#xff0c;例如存储类型、质量、访问模式等。为动态存储供应提供了基础&#xff0c;使得 Kubernetes 可以在用户创建 PVC 时…

Muduo网络库介绍

1.Reactor介绍 1.回调函数 **回调&#xff08;Callback&#xff09;**是一种编程技术&#xff0c;允许将一个函数作为参数传递给另一个函数&#xff0c;并在适当的时候调用该函数 1.工作原理 定义回调函数 注册回调函数 触发回调 2.优点 异步编程 回调函数允许在事件发生时…

Debian编译安装mysql8.0.41源码包 笔记250401

Debian编译安装mysql8.0.41源码包 以下是在Debian系统上通过编译源码安装MySQL 8.0.41的完整步骤&#xff0c;包含依赖管理、编译参数优化和常见问题处理&#xff1a; 准备工作 1. 安装编译依赖 sudo apt update sudo apt install -y \cmake gcc g make libssl-dev …

Git常用问题收集

gitignore 忽略文件夹 不生效 有时候我们接手别人的项目时&#xff0c;发现有的忽略不对想要修改&#xff0c;但发现修改忽略.gitignore后无效。原因是如果某些文件已经被纳入版本管理在.gitignore中忽略路径是不起作用的&#xff0c;这时候需要先清除本地缓存&#xff0c;然后…

编程哲学——TCP可靠传输

TCP TCP可靠传输 TCP的可靠传输表现在 &#xff08;1&#xff09;建立连接时三次握手&#xff0c;四次挥手 有点像是这样对话&#xff1a; ”我们开始对话吧“ ”收到“ ”好的&#xff0c;我收到你收到了“ &#xff08;2&#xff09;数据传输时ACK应答和超时重传 ”我们去吃…

【MediaPlayer】基于libvlc+awtk的媒体播放器

基于libvlcawtk的媒体播放器 libvlc下载地址 awtk下载地址 代码实现libvlc相关逻辑接口UI媒体接口实例化媒体播放器注意事项 libvlc 下载地址 可以到https://download.videolan.org/pub/videolan/vlc/去下载一个vlc版本&#xff0c;下载后其实是vlc的windows客户端&#xff0…

pulsar中的延迟队列使用详解

Apache Pulsar的延迟队列支持任意时间精度的延迟消息投递&#xff0c;适用于金融交易、定时提醒等高时效性场景。其核心设计通过堆外内存索引队列与持久化分片存储实现&#xff0c;兼顾灵活性与可扩展性。以下从实现原理、使用方式、优化策略及挑战展开解析&#xff1a; 一、核…

单链表的实现 | 附学生信息管理系统的实现

目录 1.前言&#xff1a; 2.单链表的相关概念&#xff1a; 2.1定义&#xff1a; 2.2形式&#xff1a; 2.3特点&#xff1a; 3.常见功能及代码 &#xff1a; 3.1创建节点&#xff1a; 3.2头插&#xff1a; 3.3尾插&#xff1a; 3.4头删&#xff1a; 3.5尾删&#xff1a; 3.6插入…

java实用工具类Localstorage

public class LocalStorageUtil {//提供ThreadLocal对象,private static ThreadLocal threadLocalnew ThreadLocal();public static Object get(){return threadLocal.get();}public static void set(Object o){threadLocal.set(o);}public static void remove(){threadLocal.r…

LLM-大语言模型浅谈

目录 核心定义 典型代表 核心原理 用途 优势与局限 未来发展方向 LLM&#xff08;Large Language Model&#xff09;大语言模型&#xff0c;指通过海量文本数据训练 能够理解和生成人类语言的深度学习模型。 核心定义 一种基于深度神经网络&#xff08;如Transformer架…