深度学习处理文本(13)

我们使用基于GRU的编码器和解码器来在Keras中实现这一方法。选择GRU而不是LSTM,会让事情变得简单一些,因为GRU只有一个状态向量,而LSTM有多个状态向量。首先是编码器,如代码清单11-28所示。

代码清单11-28 基于GRU的编码器

from tensorflow import keras
from tensorflow.keras import layersembed_dim = 256
latent_dim = 1024source = keras.Input(shape=(None,), dtype="int64", name="english")----不要忘记掩码,它对这种方法来说很重要
x = layers.Embedding(vocab_size, embed_dim, mask_zero=True)(source)----这是英语源句子。指定输入名称,我们就可以用输入组成的字典来拟合模型
encoded_source = layers.Bidirectional(layers.GRU(latent_dim), merge_mode="sum")(x)----编码后的源句子即为双向GRU的最后一个输出

接下来,我们来添加解码器——一个简单的GRU层,其初始状态为编码后的源句子。我们再添加一个Dense层,为每个输出时间步生成一个在西班牙语词表上的概率分布,如代码清单11-29所示。

代码清单11-29 基于GRU的解码器与端到端模型

past_target = keras.Input(shape=(None,), dtype="int64", name="spanish")----这是西班牙语目标句子
x = layers.Embedding(vocab_size, embed_dim, mask_zero=True)(past_target)----不要忘记使用掩码
decoder_gru = layers.GRU(latent_dim, return_sequences=True)
x = decoder_gru(x, initial_state=encoded_source)----编码后的源句子作为解码器GRU的初始状态
x = layers.Dropout(0.5)(x)
target_next_step = layers.Dense(vocab_size, activation="softmax")(x)----预测下一个词元
seq2seq_rnn = keras.Model([source, past_target], target_next_step)----端到端模型:将源句子和目标句子映射为偏移一个时间步的目标句子

训练过程中,解码器接收整个目标序列作为输入,但由于RNN逐步处理的性质,它将仅通过查看输入中第0~N个词元来预测输出的第N个词元(对应于句子的下一个词元,因为输出需要偏移一个时间步)​。这意味着我们只能使用过去的信息来预测未来——我们也应该这样做,否则就是在作弊,这样生成模型在推断过程中将不会生效。下面开始训练模型,如代码清单11-30所示。

代码清单11-30 训练序列到序列循环模型

seq2seq_rnn.compile(optimizer="rmsprop",loss="sparse_categorical_crossentropy",metrics=["accuracy"])
seq2seq_rnn.fit(train_ds, epochs=15, validation_data=val_ds)

我们选择精度来粗略监控训练过程中的验证集性能。模型精度为64%,也就是说,平均而言,该模型在64%的时间里正确预测了西班牙语句子的下一个单词。然而在实践中,对于机器翻译模型而言,下一个词元精度并不是一个很好的指标,因为它会假设:在预测第N+1个词元时,已经知道了从0到N的正确的目标词元。实际上,在推断过程中,你需要从头开始生成目标句子,不能认为前面生成的词元都是100%正确的。现实世界中的机器翻译系统可能会使用“BLEU分数”来评估模型。这个指标会评估整个生成序列,并且看起来与人类对翻译质量的评估密切相关。最后,我们使用模型进行推断,如代码清单11-31所示。我们从测试集中挑选几个句子,并观察模型如何翻译它们。我们首先将种子词元"[start]“与编码后的英文源句子一起输入解码器模型。我们得到下一个词元的预测结果,并不断将其重新输入解码器,每次迭代都采样一个新的目标词元,直到遇到”[end]"或达到句子的最大长度。

代码清单11-31 利用RNN编码器和RNN解码器来翻译新句子

import numpy as np
spa_vocab = target_vectorization.get_vocabulary()---- (本行及以下1)准备一个字典,将词元索引预测值映射为字符串词元
spa_index_lookup = dict(zip(range(len(spa_vocab)), spa_vocab))
max_decoded_sentence_length = 20def decode_sequence(input_sentence):tokenized_input_sentence = source_vectorization([input_sentence])decoded_sentence = "[start]"----种子词元for i in range(max_decoded_sentence_length):tokenized_target_sentence = target_vectorization([decoded_sentence])next_token_predictions = seq2seq_rnn.predict(---- (本行及以下2)对下一个词元进行采样[tokenized_input_sentence, tokenized_target_sentence])sampled_token_index = np.argmax(next_token_predictions[0, i, :])sampled_token = spa_index_lookup[sampled_token_index]---- (本行及以下1)将下一个词元预测值转换为字符串,并添加到生成的句子中decoded_sentence += " " + sampled_tokenif sampled_token == "[end]":----退出条件:达到最大长度或遇到停止词元breakreturn decoded_sentencetest_eng_texts = [pair[0] for pair in test_pairs]
for _ in range(20):input_sentence = random.choice(test_eng_texts)print("-")print(input_sentence)print(decode_sequence(input_sentence))

请注意,这种推断方法虽然非常简单,但效率很低,因为每次采样新词时,都需要重新处理整个源句子和生成的整个目标句子。在实际应用中,你会将编码器和解码器分成两个独立的模型,在每次采样词元时,解码器只运行一步,并重新使用之前的内部状态。翻译结果如代码清单11-32所示。对于一个玩具模型而言,这个模型的效果相当好,尽管它仍然会犯许多低级错误。

代码清单11-32 循环翻译模型的一些结果示例

Who is in this room?
[start] quién está en esta habitación [end]
-
That doesn't sound too dangerous.
[start] eso no es muy difícil [end]
-
No one will stop me.
[start] nadie me va a hacer [end]
-
Tom is friendly.
[start] tom es un buen [UNK] [end]

有很多方法可以改进这个玩具模型。编码器和解码器可以使用多个循环层堆叠(请注意,对于解码器来说,这会使状态管理变得更加复杂)​,我们还可以使用LSTM代替GRU,诸如此类。然而,除了这些调整,RNN序列到序列学习方法还受到一些根本性的限制。源序列表示必须完整保存在编码器状态向量中,这极大地限制了待翻译句子的长度和复杂度。这有点像一个人完全凭记忆翻译一句话,并且在翻译时只能看一次源句子。RNN很难处理非常长的序列,因为它会逐渐忘记过去。等到处理序列中的第100个词元时,模型关于序列开始的信息已经几乎没有了。这意味着基于RNN的模型无法保存长期上下文,而这对于翻译长文档而言至关重要。正是由于这些限制,机器学习领域才采用Transformer架构来解决序列到序列问题。我们来看一下。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/76789.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

HashMap 底层原理详解

1. 核心数据结构 JDK 1.7 及之前&#xff1a;数组 链表 JDK 1.8 及之后&#xff1a;数组 链表/红黑树&#xff08;链表长度 ≥8 时转红黑树&#xff0c;≤6 时退化为链表&#xff09; // JDK 1.8 的 Node 定义&#xff08;链表节点&#xff09; static class Node<K,V&g…

使用MySQL时出现 Ignoring query to other database 错误

Ignoring query to other database 错误 当在远程连接软件中输入MySQL命令出现该错误 导致错误原因是&#xff1a;登录mysql时账户名没有加上u 如果出现该错误&#xff0c;退出mysql&#xff0c;重新输入正确格式进入即可&#xff01;

哈尔滨工业大学:大模型时代的具身智能

大家好&#xff0c;我是樱木。 机器人在工业领域&#xff0c;已经逐渐成熟。具身容易&#xff0c;智能难。 机器人-》智能机器人&#xff0c;需要自主能力&#xff0c;加上通用能力。 智能机器人-》人类&#xff0c;这个阶段就太有想象空间了。而最受关注的-类人机器人。 如何…

Javascript代码压缩混淆工具terser详解

原始的JavaScript代码在正式的服务器上,如果没有进行压缩,混淆,不仅加载速度比较慢,而且还存在安全和性能问题. 因此现在需要进行压缩,混淆处理. 处理方案简单描述一下: 1. 使用 terser 工具进行 安装 terser工具: # npm 安装 npm install terser --save-dev# 或使用 yarn 安…

Java String 常用方法详解

目录 一、获取字符串信息(一)获取字符串长度(二)获取指定索引处的字符(三)获取子字符串二、字符串比较(一)比较字符串内容(二)忽略大小写比较三、字符串转换(一)转换为大写(二)转换为小写四、字符串查找(一)查找子字符串的位置(二)从指定位置开始查找五、字符…

Linux驱动开发练习案例

1 开发目标 1.1 架构图 操作系统&#xff1a;基于Linux5.10.10源码和STM32MP157开发板&#xff0c;完成tf-a(FSBL)、u-boot(SSBL)、uImage、dtbs的裁剪&#xff1b; 驱动层&#xff1a;为每个外设配置DTS并且单独封装外设驱动模块。其中电压ADC测试&#xff0c;采用linux内核…

leetcode-代码随想录-哈希表-赎金信

题目 题目链接&#xff1a;383. 赎金信 - 力扣&#xff08;LeetCode&#xff09; 给你两个字符串&#xff1a;ransomNote 和 magazine &#xff0c;判断 ransomNote 能不能由 magazine 里面的字符构成。 如果可以&#xff0c;返回 true &#xff1b;否则返回 false 。 maga…

精品可编辑PPT | “新基建”在数字化智慧高速公路中的支撑应用方案智慧建筑智慧交通解决方案施工行业解决方案

本文详细阐述了“新基建”在数字化智慧高速公路中的支撑应用方案&#xff0c;从政策背景出发&#xff0c;指出国家在交通领域的一系列发展规划和指导意见&#xff0c;强调了智慧交通建设的重要性。分析了当前高速公路存在的问题&#xff0c;如基础感知设施不足、协同水平低、服…

C语言求3到100之间的素数

一、代码展示 二、运行结果 三、感悟思考 注意: 这个题思路他是一个试除法的一个思路 先进入一个for循环 遍历3到100之间的数字 第二个for循环则是 判断他不是素数 那么就直接退出 这里用break 是素数就打印出来 在第一个for循环内 第二个for循环外

英语—四级CET4考试—蒙猜篇—匹配题

蒙猜方法一 匹配题的做题&#xff1a; 方法一&#xff1a; 首先&#xff0c;什么都不想&#xff0c;把问题中ing形式的&#xff0c;大写字母的&#xff0c;人名&#xff0c;地名&#xff0c;最后几个依次框起来。 然后&#xff0c;比如46题&#xff0c;口里默念meaningful lif…

股票日数据使用_未复权日数据生成前复权日周月季年数据

目录 前置&#xff1a; 准备 代码&#xff1a;数据库交互部分 代码&#xff1a;生成前复权 日、周、月、季、年数据 前置&#xff1a; 1 未复权日数据获取&#xff0c;请查看 https://blog.csdn.net/m0_37967652/article/details/146435589 数据库使用PostgreSQL。更新日…

系统与网络安全------Windows系统安全(6)

资料整理于网络资料、书本资料、AI&#xff0c;仅供个人学习参考。 共享文件夹 发布共享文件夹 Windows共享概述 微软公司推出的网络文件/打印机服务系统 可以将一台主机的资源发布给其他主机共有 共享访问的优点 方便、快捷相比光盘 U盘不易受文件大小限制 可以实现访问…

BN 层的作用, 为什么有这个作用?

BN 层&#xff08;Batch Normalization&#xff09;——这是深度神经网络中非常重要的一环&#xff0c;它大大改善了网络的训练速度、稳定性和收敛效果。 &#x1f9e0; 一句话理解 BN 层的作用&#xff1a; Batch Normalization&#xff08;批归一化&#xff09;通过标准化每一…

判断HiveQL语句为ALTER TABLE语句的识别函数

写一个C#字符串解析程序代码&#xff0c;逻辑是从前到后一个一个读取字符&#xff0c;遇到匹配空格、Tab和换行符就继续读取下一个字符&#xff0c;遇到大写或小写的字符a&#xff0c;就读取后一个字符并匹配是否为大写或小写的字符l&#xff0c;以此类推&#xff0c;匹配任意字…

基于编程的运输设备管理系统设计(vue+springboot+ssm+mysql8.x)

基于编程的运输设备管理系统设计&#xff08;vuespringbootssmmysql8.x&#xff09; 运输设备信息管理系统是一个全面的设备管理平台&#xff0c;旨在优化设备管理流程&#xff0c;提高运输效率。系统提供登录入口&#xff0c;确保只有授权用户可以访问。个人中心让用户可以查…

6.1 python加载win32或者C#的dll的方法

python很方便的可以加载win32的方法以及C#编写的dll中的方法或者变量&#xff0c;大致过程如下。 一.python加载win32的方法&#xff0c;使用win32api 1.安装库win32api pip install win32api 2.加载所需的win32函数并且调用 import win32api win32api.MessageBox(0,"…

前端精度计算:Decimal.js 基本用法与详解

一、Decimal.js 简介 decimal.js 是一个用于任意精度算术运算的 JavaScript 库&#xff0c;它可以完美解决浮点数计算中的精度丢失问题。 官方API文档&#xff1a;Decimal.js 特性&#xff1a; 任意精度计算&#xff1a;支持大数、小数的高精度运算。 链式调用&#xff1a;…

SQL Server 数据库实验报告

​​​​​​​ 1.1 实验题目&#xff1a;索引和数据完整性的使用 1.2 实验目的&#xff1a; &#xff08;1&#xff09;掌握SQL Server的资源管理器界面应用&#xff1b; &#xff08;2&#xff09;掌握索引的使用&#xff1b; &#xff08;3&#xff09;掌握数据完整性的…

AI绘画中的LoRa是什么?

Lora是一个多义词&#xff0c;根据不同的上下文可以指代多种事物。以下将详细介绍几种主要的含义&#xff1a; LoRa技术 LoRa&#xff08;Long Range Radio&#xff09;是一种低功耗广域网&#xff08;LPWAN&#xff09;无线通信技术&#xff0c;以其远距离、低功耗和低成本的特…

哈希表(Hashtable)核心知识点详解

1. 基本概念 定义&#xff1a;通过键&#xff08;Key&#xff09;直接访问值&#xff08;Value&#xff09;的数据结构&#xff0c;基于哈希函数将键映射到存储位置。 核心操作&#xff1a; put(key, value)&#xff1a;插入键值对 get(key)&#xff1a;获取键对应的值 remo…