昇思25天学习打卡营第18天|文本解码原理--以MindNLP为例

文章目录

      • 昇思MindSpore应用实践
          • 1、自回归语言模型
            • RNN网络
          • 2、文本解码原理--以MindNLP为例
            • Greedy search
            • Beam search
            • Repeat problem
            • TopK sample
      • Refernence

昇思MindSpore应用实践

本系列文章主要用于记录昇思25天学习打卡营的学习心得。

1、自回归语言模型

自回归语言模型(Autoregressive Language Model)是一种用于生成文本的统计模型。它基于序列数据的概率分布,通过建模当前词语与前面已生成词语的条件概率来预测下一个词语,即 根据前文预测下一个单词

一个文本序列的概率分布可以分解为每个词基于其上文的条件概率的乘积

  • 𝑊_0:初始上下文单词序列
  • 𝑇: 时间步
  • 当生成EOS标签时,停止生成。

自回归语言模型可以使用不同的方法来建模条件概率分布。其中,在基于Transformer架构的大语言模型出现之前,一种常见的方法是使用循环神经网络(Recurrent Neural Network,RNN),
RNN 可以通过在每个时间步骤上接收输入并保留隐状态信息,来捕捉序列中的上下文关系。通过训练RNN模型,可以学习到词语之间的概率分布,并用于生成新的文本。

RNN网络

RNN 网络的基本结构包括一个输入层 x t x_t xt、隐藏层 h t h_t ht(含激活函数Activation Function)、延迟器(循环单元)、输出层 h t h_t ht
在这里插入图片描述

网络中的神经元通过时间步骤连接形成循环:允许信息从一个时间步骤的输出 h t − 1 h_{t-1} ht1通过与输入 X t X_t Xt经过tanh函数激活后,传递至下一个时间步骤输入的一部分
RNN具体计算公式:

h t = t a n h ( W i h x t + b i h + W h h x t − 1 + b h h ) h_t=tanh(W_{ih}x_t+b_{ih}+W_{hh}x_{t-1}+b_{hh}) ht=tanh(Wihxt+bih+Whhxt1+bhh)
在这里插入图片描述

单个展开的RNN结构

在这里插入图片描述

整体展开的RNN结构

对于某 t 时刻的步骤,RNN隐藏状态大致的计算方法为:

2、文本解码原理–以MindNLP为例

MindNLP/huggingface Transformers提供的文本生成方法

Greedy search

在每个时间步𝑡都简单地选择概率最高的词作为当前输出词:

𝑤_𝑡=𝑎𝑟𝑔𝑚𝑎𝑥_𝑤 𝑃(𝑤|𝑤_(1:𝑡−1))

按照贪心搜索输出序列(“The”,“nice”,“woman”) 的条件概率为:0.5 x 0.4 = 0.2

缺点: 错过了隐藏在低概率词后面的高概率词,如:dog=0.5, has=0.9

#greedy_searchfrom mindnlp.transformers import GPT2Tokenizer, GPT2LMHeadModeltokenizer = GPT2Tokenizer.from_pretrained("iiBcai/gpt2", mirror='modelscope')  # 导入预训练GPT2# add the EOS token as PAD token to avoid warnings
model = GPT2LMHeadModel.from_pretrained("iiBcai/gpt2", pad_token_id=tokenizer.eos_token_id, mirror='modelscope')# encode context the generation is conditioned on
input_ids = tokenizer.encode('I enjoy walking with my cute dog', return_tensors='ms')# generate text until the output length (which includes the context length) reaches 50
greedy_output = model.generate(input_ids, max_length=50)print("Output:\n" + 100 * '-')
print(tokenizer.decode(greedy_output[0], skip_special_tokens=True))

在这里插入图片描述

Beam search

Beam search通过在每个时间步保留最可能的 num_beams 个词,并从中最终选择出概率最高的序列来降低丢失潜在的高概率序列的风险。如图以 num_beams=2 为例:

(“The”,“dog”,“has”) : 0.4 * 0.9 = 0.36

(“The”,“nice”,“woman”) : 0.5 * 0.4 = 0.20

优点:一定程度保留最优路径

缺点:1. 无法解决重复问题;2. 开放域生成效果差

from mindnlp.transformers import GPT2Tokenizer, GPT2LMHeadModeltokenizer = GPT2Tokenizer.from_pretrained("iiBcai/gpt2", mirror='modelscope')# add the EOS token as PAD token to avoid warnings
model = GPT2LMHeadModel.from_pretrained("iiBcai/gpt2", pad_token_id=tokenizer.eos_token_id, mirror='modelscope')# encode context the generation is conditioned on
input_ids = tokenizer.encode('I enjoy walking with my cute dog', return_tensors='ms')# activate beam search and early_stopping
beam_output = model.generate(input_ids, max_length=50, num_beams=5, early_stopping=True
)print("Output:\n" + 100 * '-')
print(tokenizer.decode(beam_output[0], skip_special_tokens=True))
print(100 * '-')# set no_repeat_ngram_size to 2
beam_output = model.generate(input_ids, max_length=50, num_beams=5, no_repeat_ngram_size=2, early_stopping=True
)print("Beam search with ngram, Output:\n" + 100 * '-')
print(tokenizer.decode(beam_output[0], skip_special_tokens=True))
print(100 * '-')# set return_num_sequences > 1
beam_outputs = model.generate(input_ids, max_length=50, num_beams=5, no_repeat_ngram_size=2, num_return_sequences=5, early_stopping=True
)# now we have 3 output sequences
print("return_num_sequences, Output:\n" + 100 * '-')
for i, beam_output in enumerate(beam_outputs):print("{}: {}".format(i, tokenizer.decode(beam_output, skip_special_tokens=True)))
print(100 * '-')
Output:
----------------------------------------------------------------------------------------------------
I enjoy walking with my cute dog, but I don't think I'll ever be able to walk with her again.""I don't think I'll ever be able to walk with her again.""I don't think I
----------------------------------------------------------------------------------------------------
Beam search with ngram, Output:
----------------------------------------------------------------------------------------------------
I enjoy walking with my cute dog, but I don't think I'll ever be able to walk with her again.""I'm not sure what to say to that," she said. "I mean, it's not like I'm
----------------------------------------------------------------------------------------------------
return_num_sequences, Output:
----------------------------------------------------------------------------------------------------
0: I enjoy walking with my cute dog, but I don't think I'll ever be able to walk with her again.""I'm not sure what to say to that," she said. "I mean, it's not like I'm
1: I enjoy walking with my cute dog, but I don't think I'll ever be able to walk with her again.""I'm not sure what to say to that," she said. "I mean, it's not like she's
2: I enjoy walking with my cute dog, but I don't think I'll ever be able to walk with her again.""I'm not sure what to say to that," she said. "I mean, it's not like we're
3: I enjoy walking with my cute dog, but I don't think I'll ever be able to walk with her again.""I'm not sure what to say to that," she said. "I mean, it's not like I've
4: I enjoy walking with my cute dog, but I don't think I'll ever be able to walk with her again.""I'm not sure what to say to that," she said. "I mean, it's not like I can
----------------------------------------------------------------------------------------------------
Repeat problem

n-gram 惩罚:

将出现过的候选词的概率设置为 0
设置no_repeat_ngram_size=2 ,任意 2-gram 不会出现两次
实际文本生成需要重复出现

import mindspore
from mindnlp.transformers import GPT2Tokenizer, GPT2LMHeadModeltokenizer = GPT2Tokenizer.from_pretrained("iiBcai/gpt2", mirror='modelscope')# add the EOS token as PAD token to avoid warnings
model = GPT2LMHeadModel.from_pretrained("iiBcai/gpt2", pad_token_id=tokenizer.eos_token_id, mirror='modelscope')# encode context the generation is conditioned on
input_ids = tokenizer.encode('I enjoy walking with my cute dog', return_tensors='ms')mindspore.set_seed(0)
# activate sampling and deactivate top_k by setting top_k sampling to 0
sample_output = model.generate(input_ids, do_sample=True, max_length=50, top_k=0
)print("Output:\n" + 100 * '-')
print(tokenizer.decode(sample_output[0], skip_special_tokens=True))
Output:
----------------------------------------------------------------------------------------------------
I enjoy walking with my cute dog Neddy as much as I'd like. Keep up the good work Neddy!"I realized what Neddy meant when he first launched the website. "Thank you so much for joining."
TopK sample

选出概率最大的 K 个词,重新归一化,最后在归一化后的 K 个词中采样;
将采样池限制为固定大小 K :

  • 在分布比较尖锐的时候产生胡言乱语
  • 在分布比较平坦的时候限制模型的创造力
import mindspore
from mindnlp.transformers import GPT2Tokenizer, GPT2LMHeadModeltokenizer = GPT2Tokenizer.from_pretrained("iiBcai/gpt2", mirror='modelscope')# add the EOS token as PAD token to avoid warnings
model = GPT2LMHeadModel.from_pretrained("iiBcai/gpt2", pad_token_id=tokenizer.eos_token_id, mirror='modelscope')# encode context the generation is conditioned on
input_ids = tokenizer.encode('I enjoy walking with my cute dog', return_tensors='ms')mindspore.set_seed(0)
# activate sampling and deactivate top_k by setting top_k sampling to 0
sample_output = model.generate(input_ids, do_sample=True, max_length=50, top_k=50
)print("Output:\n" + 100 * '-')
print(tokenizer.decode(sample_output[0], skip_special_tokens=True))
Output:
----------------------------------------------------------------------------------------------------
I enjoy walking with my cute dog.She's always up for some action, so I have seen her do some stuff with it.Then there's the two of us.The two of us I'm talking about were

Refernence

[1] 自回归语言模型简介
[2]昇思大模型平台
[3] MindSpore官方文档-文本解码原理–以MindNLP为例

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/49148.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【JVM基础04】——组成-什么是虚拟机栈?

目录 1- 引言:虚拟机栈1-1 虚拟机栈是什么?(What)1-2 为什么用虚拟机栈?虚拟机栈的作用 (Why) 2- ⭐核心:栈的常见问题(How)2-1 方法内的局部变量是否线程安全?线程不安全的局部变量 2-2 什么情况会导致栈内存溢出&…

深入Mysql-03-MySQL 表的约束与数据库设计

文章目录 数据库约束的概述约束种类主键约束唯一约束非空约束默认值外键约束 表与表之间的关系数据库设计 数据库约束的概述 对表中的数据进行限制,保证数据的正确性、有效性和完整性。一个表如果添加了约束,不正确的数据将无法插入到表中。 约束种类 …

go-kratos 学习笔记(3) google buf 管理proto

google buf 管理proto,以及从新归档文件的目录结构 什么是 BSR? BSR 将 Protobuf 文件作为版本化模块进行存储和管理,以便个人和组织可以轻松使用和发布他们的 API。 BSR 带有可浏览的 UI、依赖项管理、API 验证、版本控制、生成的文档以及…

稳居中科院2区的SCIEI双检索期刊,听说一投就中!

IEEE TRANSACTIONS ON ELECTRON DEVICES,中科院2区,JCR Q2, SCI&EI双检索期刊,年发文量在1000篇左右,且大有继续扩刊的走向。有投稿经验的作者反馈,比较容易被录用。 期刊信息 IEEE TRANSACTIONS ON ELECTRON DE…

Python 机器学习求解 PDE 学习项目——PINN 求解一维 Poisson 方程

本文使用 TensorFlow 1.15 环境搭建深度神经网络(PINN)求解一维 Poisson 方程: − Δ u f in Ω , u 0 on Γ : ∂ Ω . \begin{align} -\Delta u & f \quad & \text{in } \Omega,\\ u & 0 \quad & \text{on } \Gamma:\partial \Om…

非对称加密算法RSA的OpenSSL代码实现Demo

目录 1 RSA简介 1.1 RSA算法介绍 1.2 RSA算法的速度与安全性 1.3 RSA存储格式 1.3.1 PKCS#1 标准主要用于 RSA密钥,其RSA公钥和RSA私钥PEM格式 1.3.2 PKCS#8 标准定义了一个密钥格式的通用方案,其公钥和私钥PEM格式 2 OpenSSL代码实现 2.1 生…

初学51单片机之指针基础与串口通信应用

开始之前推荐一个电路学习软件,这个软件笔者也刚接触。名字是Circuit有在线版本和不在线版本,这是笔者在B站看视频翻到的。 Paul Falstadhttps://www.falstad.com/这是地址。 离线版本在网站内点这个进去 根据你的系统下载你需要的版本红线的是windows…

第九讲:POU与变量基础

POU(Program Organization Unit)的分类 一、定义及分类 POU即程序组成单元 二、三种POU的作用 1、功能/功能快:看作算法 功能块的POU是比较复杂的指令 三、功能块POU和功能POU的区别 1、理解功能POU(对比) 不添加实例名,就不需要去建立变量,所以就不会占到内存。 因…

算法题目整合4

文章目录 122. 大数减法123. 滑动窗口最大值117. 软件构建124. 小红的数组构造125. 精华帖子126. 连续子数组最大和 122. 大数减法 题目描述 以字符串的形式读入两个数字,编写一个函数计算它们的差,以字符串形式返回。输入描述 输入两个数字&#xff…

物联网专业创新人才培养体系的探索与实践

一、引言 随着物联网(IoT)技术的迅猛发展,物联网领域的人才需求日益增加。物联网技术作为新一轮信息技术革命的核心,已经渗透到社会生活的各个领域,对推动经济转型升级、提升国家竞争力具有重要意义。因此&#xff0c…

VUE之---slot插槽

什么是插槽 slot 【插槽】, 是 Vue 的内容分发机制, 组件内部的模板引擎使用slot 元素作为承载分发内容的出口。slot 是子组件的一个模板标签元素, 而这一个标签元素是否显示, 以及怎么显示是由父组件决定的。 VUE中slot【插槽】…

自己开发软件实现网站抓取m3u8链接

几天前一个同学说想下载一个网站的视频找不到连接,问我有没有什么办法,网站抓取m3u8链接 网页抓取m3u8链接。当时一听觉得应该简单,于是说我抽空看看。然后就分析目标网页,试图从网页源码里找出连接,有的源代码直接有,但是有的没有…

Java二十三种设计模式-代理模式模式(8/23)

代理模式:为对象访问提供灵活的控制 引言 代理模式(Proxy Pattern)是一种结构型设计模式,它为其他对象提供一个代替或占位符,以控制对它的访问。 基础知识,java设计模式总体来说设计模式分为三大类&#…

Varjo XR-4系列现已获得达索3DEXPERIENCE平台官方支持

近日,全球领先的工业虚拟和混合现实解决方案提供商Varjo宣布,Varjo XR-4系列现已获得达索3DEXPERIENCE平台的本地支持。这种集成为工程师和设计师带来了先进的虚拟和混合现实功能,他们可以通过沉浸式技术创新并简化他们的3D工作流程。 在达索…

【iOS】Tagged Pointer

目录 前言什么是Tagged Pointer?引入Tagged Pointer技术之前引入Tagged Pointer之后总结 Tagged Pointer原理(TagData分析)关闭数据混淆MacOS分析NSNumberNSString iOS分析 判断Tagged PointerTagged Pointer应用Tagged Pointer 注意点 Tagge…

Qt绘制指南针(仪表盘绘制封装使用)

指南针是一种用来确定方向的工具。它由一个磁针制成,其一端被磁化,可以自由旋转。当放置在水平面上时,磁针会指向地球的磁北极。通过观察磁针的指向,我们可以确定地理北方的方向。本示例是在Qt中绘制一个指南针,通过继…

Android WebViewClient 的 `shouldOverrideUrlLoading` 方法

简介 在Android开发中,WebView是一个强大的工具,可以在你的应用中显示网页内容。了解 WebViewClient 中的 shouldOverrideUrlLoading 方法是至关重要的,因为这个方法允许你控制 URL 在 WebView 中的处理方式。 在本文中,我们将详…

S71200 - 笔记

1 S71200 0 ProfiNet - 2 PLC编程 01.如何零基础快速上手S7-1200_哔哩哔哩_bilibili 西门子S7-1200PLC编程设计学习视频,从入门开始讲解_哔哩哔哩_bilibili

Linux:进程信号(一.认识信号、信号的产生及深层理解、Term与Core)

上次结束了进程间通信的知识介绍:Linux:进程间通信(二.共享内存详细讲解以及小项目使用和相关指令、消息队列、信号量 文章目录 1.认识信号进程看待信号方式 2.信号的产生2.1信号的处理的方式 --- signal()函数2.2kill指令产生信号2.3键盘产生…

最新快乐二级域名分发系统重置版v1.7源码-最新美化版+源码+可对接支付

源码简介: 最新快乐二级域名分发系统重置版v1.7源码,它是最新美化版源码可对接支付。 快乐二级域名分发系统重置版v1.7源码,简单快捷、功能强大的控制面板。系统稳定长久,控制面板没任何广告,让网站更实用方便。 最…