【自然语言处理八-transformer实现翻译任务-一(输入)】

自然语言处理八-transformer实现翻译任务-一(输入)

  • transformer架构
  • 数据处理部分
    • 模型的输入数据(图中inputs outputs outputs_probilities对应的label)
      • 以处理英中翻译数据集为例的代码
    • positional encoding 位置嵌入
      • 代码

鉴于transfomer的重要性,在两篇介绍过transfomer模型理论的基础上,我们将分几篇文章,用pytorch代码实现一个完整的transfomer模型。
下面是之前介绍模型的文章:
自然语言处理六-最重要的模型-transformer-上
自然语言处理六-最重要的模型-transformer-下

transformer架构

在这里重新给出transfomer架构图以及用中文翻译后的对照图:
在这里插入图片描述

从架构图可以看出实现transformer架构,encoder和decoder大部分相同。因此也规划几篇内容,介绍以下几大块:

  1. 输入输出部分
    处理数据,源和目标以及输出,以及位置编码

  2. 注意力部分
    多头注意力和掩蔽多头注意力

  3. 前馈网络等
    加和归一化,以及逐位前馈网路

  4. 训练和测试

本篇作为开始,先介绍处理数据部分

数据处理部分

模型的输入数据(图中inputs outputs outputs_probilities对应的label)

这部分用来处理transformer模型需要输入的数据,这部分其实和seq2seq架构是相同的,以自然语言的翻译为例:
比如transformer需要将 ich mochte ein bier 翻译成 i want a beer,那如需要输入的数据格式是这样的(假设句子最大长度5):
[ich mochte ein bier , i want a beer, i want a beer ]

那么上面那部分输入的用途都是什么呢?
encoder输入需要翻译的句子 ich mochte ein bier
decoder输入是 i want a beer
decoder的label是i want a beer
分别对应于图中inputs outputs outputs_probilities相对应的标签

其中是填充字符,用来填充到一个我们超参数中我们设定的sequence的长度
代表句子的开始, 代表句子结束
当然模型真正要处理还是需要根据词汇表转成数字格式的形式,才能被模型处理

以处理英中翻译数据集为例的代码

# -*- coding: utf-8 -*-"""
加载源数据,并处理成data set和data loader
"""
import os
import zipfile
import torch
import torch.utils.data as Data
from src.configs import config
from src.utility.utils import Utils
from src.vocabulay.vocabulary import Vocabularyclass DataSetLoader:"""数据封装成dateset和datelaoder"""def __init__(self, is_train=True, numbers=config.num_examples):"""init parameters"""self.raw_file_path = config.raw_file_pathself.raw_zip_path = config.raw_zip_pathself.batch_size = config.batch_sizeself.num_steps = config.num_stepsself.num_examples = numbersself.is_train = is_traindef load_array(self, data_arrays):"""构造一个数据迭代器:param data_arrays:  输入数据的列表:param is_train:     训练/测试数据集:return:             dataloader"""dataset = Data.TensorDataset(*data_arrays)return Data.DataLoader(dataset, self.batch_size, shuffle=self.is_train)def load_data_nmt(self):"""返回翻译数据集的迭代器和词表:return: 迭代器和词表"""source, src_vocab, target, tgt_vocab = build_vocabu()src_array, src_valid_len = self.build_array_nmt(source, src_vocab)tgt_array, tgt_valid_len = self.build_array_nmt(target, tgt_vocab)data_arrays = (src_array, src_valid_len, tgt_array, tgt_valid_len)data_iter = self.load_array(data_arrays)return data_iter, src_vocab, tgt_vocabdef build_array_nmt(self, lines, vocab):"""将机器翻译的文本序列转换成小批量source[['hello', 'world'],..]  target[['你'], [好]]source[[]]"""lines = [vocab[l] for l in lines]lines = [l + [vocab['<eos>']] for l in lines]array = torch.tensor([Utils.truncate_pad(l, self.num_steps, vocab['<pad>']) for l in lines])valid_len = Utils.reduce_sum(Utils.astype(array != vocab['<pad>'], torch.int32), 1)return array, valid_lendef extract_content():"""提取raw text中内容:return: raw text"""if not os.path.exists(config.raw_file_path):with zipfile.ZipFile(config.raw_zip_path, 'r') as zip_ref:zip_ref.extractall(os.path.dirname(config.raw_file_path))print('语料解压缩完成')with open(config.raw_file_path, 'r', encoding='UTF-8') as f:content = f.read()return contentdef build_vocabu():"""创建词表:return: 词表"""text = Utils.preprocess_nmt(extract_content())source, target = Utils.tokenize_nmt(text, config.num_examples)src_vocab = Vocabulary(source, min_freq=2)tgt_vocab = Vocabulary(target, min_freq=2)return source, src_vocab, target, tgt_vocab

positional encoding 位置嵌入

为了添加位置信息,transformer的位置嵌入,每个位置的512维的数据用sin/cos做了处理
在这里插入图片描述
其中pos是在句子中位置,i是维度信息

代码

def get_sinusoid_encoding_table(n_position, d_model):def cal_angle(position, hid_idx):return position / np.power(10000, 2 * (hid_idx // 2) / d_model)def get_posi_angle_vec(position):return [cal_angle(position, hid_j) for hid_j in range(d_model)]

其他处理会在后续章节继续实现

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/813043.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

1.汉诺塔问题

C力扣 汉诺塔 class Solution { public:void hanota(vector<int>& a, vector<int>& b, vector<int>& c) {dfs(a,b,c,a.size());}void dfs(vector<int>& a, vector<int>& b, vector<int>& c,int n){if(n1){c.push…

阻塞队列和生产消费模型

阻塞队列 阻塞队列的概念 队列相信我们已经不陌生了 之前也学过很多队列 比如: 普通队列 和 优先级队列 两种 这两种队列都是线程不安全的 而我们讲的阻塞队列 刚好可以解决线程安全问题 也是先进先出 并且带有阻塞功能. 阻塞功能是怎么回事呢 就是如果入队的时候阻塞队列为…

rabbitmq安装rabbitmq-delayed-message-exchange插件

下载地址&#xff1a;Community Plugins | RabbitMQ 上传到rabbitmq安装目录的/plugins目录下 我的是/usr/lcoal/rabbitmq/plugins/ 直接安装 [rootk8s-node1 rabbitmq]# rabbitmq-plugins enable rabbitmq_delayed_message_exchange [rootk8s-node1 rabbitmq]# rabbitmq-pl…

pringboot2集成swagger2出现guava的FluentIterable方法不存在

错误信息 Description: An attempt was made to call a method that does not exist. The attempt was made from the following location: springfox.documentation.spring.web.scanners.ApiListingScanner.scan(ApiListingScanner.java:117) The following method did not ex…

c语言例题,求数组中最大值,99乘法口诀表

例题1&#xff1a;求出数组中最大的值 根据题意&#xff0c;我们知道的是需要从一个数组中找到一个最大的元素并且输出。那首先我们先建立一个数组&#xff0c;然后将一些不有序的整型元素放到数组中&#xff0c;然后再建立一个变量来存放数组中的第一个元素&#xff0c;通过一…

算法设计与分析实验报告c++实现(八皇后问题、连续邮资问题、卫兵布置问题、圆排列问题)

一、实验目的 1&#xff0e;加深学生对回溯法算法设计方法的基本思想、基本步骤、基本方法的理解与掌握&#xff1b; 2&#xff0e;提高学生利用课堂所学知识解决实际问题的能力&#xff1b; 3&#xff0e;提高学生综合应用所学知识解决实际问题的能力。 二、实验任务 用回溯…

vue简单使用五(组件的使用)

目录 如何定义组件&#xff1a; 组件的命名&#xff1a; 父组件向子组件传值&#xff1a; 子组件向父组件传值&#xff1a; 如何定义组件&#xff1a; 全局组件定义&#xff1a; 局部组件定义&#xff1a; 组件的基本使用&#xff1a; 打印结果&#xff1a; 组件的命名&#xf…

分类预测 | Matlab实现OOA-BP鱼鹰算法优化BP神经网络数据分类预测

分类预测 | Matlab实现OOA-BP鱼鹰算法优化BP神经网络数据分类预测 目录 分类预测 | Matlab实现OOA-BP鱼鹰算法优化BP神经网络数据分类预测分类效果基本介绍程序设计参考资料 分类效果 基本介绍 1.Matlab实现OOA-BP鱼鹰算法优化BP神经网络多特征分类预测&#xff08;完整源码和数…

MATLAB Simulink仿真搭建及代码生成技术—01自定义新建模型模板

MATLAB Simulink仿真搭建及代码生成技术 目录 01-自定义新建模型模板点击运行&#xff1a;显示效果&#xff1a;查看模型设置&#xff1a; 01-自定义新建模型模板 新建模型代码如下&#xff1a; function new_model(modelname) %建立一个名为SmartAss的新的模型并打开 open_…

【微服务】------常见模型的分析与比较

DDD 分层架构 整洁架构 整洁架构又名“洋葱架构”。为什么叫它洋葱架构&#xff1f;看看下面这张图你就明白了。整洁架构的层就像洋葱片一样&#xff0c;它体现了分层的设计思想。 整洁架构最主要的原则是依赖原则&#xff0c;它定义了各层的依赖关系&#xff0c;越往里依赖越…

[大模型]Langchain-Chatchat安装和使用

项目地址&#xff1a; https://github.com/chatchat-space/Langchain-Chatchat 快速上手 1. 环境配置 首先&#xff0c;确保你的机器安装了 Python 3.8 - 3.11 (我们强烈推荐使用 Python3.11)。 $ python --version Python 3.11.7接着&#xff0c;创建一个虚拟环境&#xff…

Web前端-Ajax

Ajax 概念:Asynchronous JavaScript And XML,异步的JavaScript和XML。 作用: 1.数据交换:通过Ajax可以给服务器发送请求,并获取服务器响应的数据。 2.异步交互:可以在不重新加载整个页面的情况下,与服务器交换数据并更新部分网页的技术,如:搜索联想、用户名是否可用的校验等等…

Ubuntu 22上安装Anaconda3。下载、安装、验证详细教程

在Ubuntu 22上安装Anaconda3&#xff0c;你可以遵循以下步骤&#xff1a; 更新系统存储库&#xff1a; 打开终端并运行以下命令来更新系统存储库&#xff1a; sudo apt update安装curl包&#xff1a; 下载Anaconda安装脚本通常需要使用curl工具。如果系统中没有安装curl&#x…

基于springboot的医院挂号取药缴费管理系统

一、基于springboot的医院挂号取药缴费管理系统 简介系统和功能 二、技术框架 这是一款基于SpringbootLAYUIMysql的管理系统 开发语言&#xff1a;Java JDK1.8 数据库&#xff1a;mysql5.7 前端&#xff1a;LAYUI框架 后端&#xff1a;Springboot框架、Spring框架、持久层My…

WinForms零基础进阶控件教程(超实用详细版)

文章目录 树型控件TreeView常用属性常用事件添加、删除树节点通过代码添加树节点通过代码删除树节点&#xff1a;管理节点图标响应事件&#xff0c;获取选中节点 列表视图ListView常用属性常用事件添加、删除项使用ListViewItem集合编辑器添加&#xff0c;删除项&#xff1a;通…

LED显示IC-点阵数码管显示驱动/抗干扰数码管驱动VK1650 SOP16/DIP16

产品品牌&#xff1a;永嘉微电/VINKA 产品型号&#xff1a;VK1650 封装形式&#xff1a;SOP16/DIP16 概述 VK1650是一种带键盘扫描电路接口的 LED 驱动控制专用芯片&#xff0c;内部集成有数据锁存器、LED 驱动、键盘扫描等电路。SEG脚接LED阳极&#xff0c;GRID脚接LED阴极&…

WebGL 2.0相较于1.0有什么不同?

作者&#xff1a;STANCH 1.概述 WebGL 1.0自推出以来&#xff0c;已成为广泛支持的Web标准&#xff0c;既能跨平台&#xff0c;还免版税。它通过插件为Web浏览器带来高质量的3D图形&#xff0c;这是迄今为止市场上使用最广泛的Web图形&#xff0c;并得到Apple&#xff0c;Goog…

ControllerAdvice用法

ControllerAdvice用法 ControllerAdvice是一个组件注解&#xff0c;它允许你在一个地方处理整个应用程序控制器的异常、绑定数据和预处理请求。这意味着你不需要在每个控制器中重复相同的异常处理代码&#xff0c;从而使得代码更加简洁、易于管理。 主要特性 全局异常处理&a…

【vue】v-model.lazy等(非实时渲染)

v-model&#xff1a;实时渲染v-model.lazy&#xff1a;失去焦点/按回车后&#xff0c;才渲染v-model.number&#xff1a;值转换为数字v-model.trim&#xff1a;去除首尾空格 <!DOCTYPE html> <html lang"en"><head><meta charset"UTF-8&qu…

MDK平台 - Code, RO-data , RW-data, ZI-data详解

文章目录 1 . 前言2 . Code, RO-data , RW-data, ZI-data解析3 . RAM上电复位4 . 细节扩展5 . 总结 【全文大纲】 : https://blog.csdn.net/Engineer_LU/article/details/135149485 1 . 前言 MDK编译后&#xff0c;会列出Code, RO-data , RW-data, ZI-data&#xff0c;以下解析…