使用GRU单元的RNN模型生成唐诗

文章目录

    • 1. 读取数据
    • 2. 字符索引
    • 3. 创建文本序列
    • 4. 创建文本编码序列
    • 5. 使用GRU单元建立RNN模型
    • 6. 文本生成

参考 基于深度学习的自然语言处理

本文使用 GRU 单元建立 RNN 网络,使用唐诗三百首进行训练,使用模型生成唐诗。

GRU RNN 网络能够克服简单RNN网络的一些问题,如梯度消失,梯度很难从深层传递到浅层,导致浅层的参数更新非常缓慢,学习速度很慢,还导致深层浅层学习不均衡。

GRU,LSTM 使用更新门,遗忘门,来解决长距离的依赖关系,GRU相比LSTM参数更少。

RNN 网络的还有缺点就是无法采用并行计算,必须在上一个时间步的基础上计算下一个时间步。

1. 读取数据

# 读取文本
file = "tangshi300.txt"
with open(file,'r',encoding='utf-8') as f:text = f.read()
print(len(text))
print(text[:180])

输出:

29405  # 文本长度唐诗3001-50010杜甫:佳人绝代有佳人,幽居在空谷。
自云良家子,零落依草木。
关中昔丧乱,兄弟遭杀戮。
官高何足论,不得收骨肉。
世情恶衰歇,万事随转烛。
夫婿轻薄儿,新人美如玉。
合昏尚知时,鸳鸯不独宿。
但见新人笑,那闻旧人哭!
在山泉水清,出山泉水浊。
侍婢卖珠回,牵萝补茅屋。
摘花不插发,采柏动盈掬。
天寒翠袖薄,日暮倚修竹。

2. 字符索引

# 创建字符序号索引
words = sorted(list(set(text)))
print("字和符号数量:{}".format(len(words)))word_idx = {w : i for (i, w) in enumerate(words)}
idx_word = {i : w for (i, w) in enumerate(words)}

输出:

字和符号数量:2590

3. 创建文本序列

# 根据文本,创建序列
sample_maxlen = 40 # 样本句子长度
sentences = []
next_word = []
for i in range(len(text)-sample_maxlen):sentences.append(text[i : i+sample_maxlen]) # 滑窗取出样本句子next_word.append(text[i+sample_maxlen]) # 句子末尾的下一个字
print("样本数量:{}".format(len(sentences)))
样本数量:29365

4. 创建文本编码序列

# 将文本序列转化成数字序列(矩阵), 实际上就是一个one_hot 编码
import numpy as np
X = np.zeros((len(sentences), sample_maxlen, len(words)), dtype=np.bool)
#            样本数             1个样本字个数   一个字的OH编码长度
y = np.zeros((len(sentences), len(words)), dtype=np.bool)
#            样本数          一个字的OH编码长度
for i in range(len(sentences)):for t, w in enumerate(sentences[i]):X[i, t, word_idx[w]] = 1 # 把 i 样本 t 字符 对应的 OH 编码位置 写为 1y[i, word_idx[next_word[i]]] = 1

5. 使用GRU单元建立RNN模型

  • 建模
# 建模
from keras.models import Sequential
from keras.layers import GRU, Dense
from keras.optimizers import Adam
model = Sequential()
model.add(GRU(units=128,input_shape=(sample_maxlen, len(words)))) # GRU 层
model.add(Dense(units=len(words), activation='softmax')) # FC 层 多分类 softmax
  • 训练
optimizer = Adam(learning_rate=0.001) # adam 优化器
model.compile(optimizer=optimizer,loss='categorical_crossentropy',metrics=['accuracy']) # 配置模型
history = model.fit(X, y, batch_size=128, epochs=500) # 训练
model.save("tangshi_generator_model.h5") # 保存模型
  • 绘制训练曲线
import pandas as pd
import matplotlib.pyplot as plt
pd.DataFrame(history.history).plot(figsize=(8, 5)) # 绘制训练曲线
plt.grid(True)
plt.gca().set_ylim(0, 1) # set the vertical range to [0-1]
plt.show()


模型在 100 个 epochs 时已基本上完全拟合了训练数据

6. 文本生成

  • 采样函数
def sampling(preds, temperature=1.0):preds = np.asarray(preds).astype('float64')preds = np.log(preds)/temperature # 我的理解是概率平滑exp_preds = np.exp(preds)preds = exp_preds/np.sum(exp_preds)probs = np.random.multinomial(1, preds, 1)
#   多项式分布,做n次试验,按照preds的概率分布(和=1),取出size组结果,如下
#     >>> np.random.multinomial(20, [1/6.]*6, size=1)
#             array([[4, 1, 7, 5, 2, 1]]) # randomreturn np.argmax(probs) # 返回概率最大的idx
  • 随机选取训练文本里的一段开始生成后序文本
from keras.models import load_model
import random
model = load_model("tangshi_generator_model.h5")def generate_tangshi(model, generate_len=200):start_idx = random.randint(0, len(text)-sample_maxlen-1) # 随机开始位置generated = ""sentence = text[start_idx : start_idx + sample_maxlen] # 开始的句子generated += sentenceprint("随机选取的开始句子为:{}".format(generated))for i in range(generate_len): # 后续要生成的句子长度x_pred = np.zeros((1, sample_maxlen, len(words)))for t, w in enumerate(sentence):x_pred[0, t, word_idx[w]] = 1 # 当前句子的 OH 编码preds = model.predict(x_pred)[0] # predict 返回 (1,2590)一个样本 2590个类预测值next_idx = sampling(preds, 1) # 采样出来下一个最有可能的词的idxnext_w = idx_word[next_idx] # 取出这个词generated += next_w # 加到句子中sentence = sentence[1:] + next_w # 句子窗口后移一个位置,作为下次预测的输入return generatedgenerate_tangshi(model, 100)

输出:


模型完全记住了后续的诗句。

  • 自己随意编写训练集里没有的诗句作为开始,如下(不可有训练集中未出现的字)


with open('test.txt','r',encoding='utf-8') as f:test_text = f.read()def generate_tangshi_test(model, generate_len=60):generated = ""sentence = test_text[0 : sample_maxlen]generated += sentenceprint("测试文本开始句子为:{}".format(generated))for i in range(generate_len):x_pred = np.zeros((1, sample_maxlen, len(words)))for t, w in enumerate(sentence):x_pred[0, t, word_idx[w]] = 1preds = model.predict(x_pred)[0]next_idx = sampling(preds, 1)next_w = idx_word[next_idx]generated += next_wsentence = sentence[1:] + next_wreturn generatedgenerate_tangshi_test(model, 100)

输出:

输出的诗句有部分一整句都是训练集里的,有的则不是,反正没有多少诗的美感,哈哈!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/473398.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

循环与分支

1. 循环 for循环for arg in [list] 这是一个基本的循环结构. 它与C语言中的for循环结构有很大的不同.for arg in [list]docommand(s)...done for arg in "$var1" "$var2" "$var3" ... "$varN" 在[list]中的参数加上双引号是为了阻止单…

Python数据结构常见的八大排序算法(详细整理)

前言 八大排序,三大查找是《数据结构》当中非常基础的知识点,在这里为了复习顺带总结了一下常见的八种排序算法。 常见的八大排序算法,他们之间关系如下: 排序算法.png 他们的性能比较: 下面,利用Python分别…

牛客 牛牛选物(01背包)

文章目录1. 题目2. 解题1. 题目 链接:https://ac.nowcoder.com/acm/contest/9887/A 来源:牛客网 牛牛有现在有n个物品,每个物品有一个体积v[i]和重量g[i],他想选择其中总体积恰好为V的若干个物品,想使这若干个物品的总重量最大&…

asp.net学习之再论sqlDataSource

asp.net学习之再论sqlDataSource 原文:asp.net学习之再论sqlDataSource本节从上一节没有阐述的几个方面,再讨论一下SqlDataSource的用法及注意的事项。 上一节的链接地址如下:http://www.cnblogs.com/shipfi/archive/2009/10/15/1584093.html 1. S…

微信小程序最常用的布局——Flex布局

最近在学习微信小程序,在设计首页布局的时候,新认识了一种布局方式display:flex 1 .container { 2 display: flex; 3 flex-direction: column; 4 align-items: center; 5 background-color: #b3d4db; 6 } 编译之后的效果很明显,界面…

LeetCode 649. Dota2 参议院(循环队列)

文章目录1. 题目2. 解题1. 题目 Dota2 的世界里有两个阵营:Radiant(天辉)和 Dire(夜魇) Dota2 参议院由来自两派的参议员组成。现在参议院希望对一个 Dota2 游戏里的改变作出决定。他们以一个基于轮为过程的投票进行。在每一轮中,每一位参议员都可以行…

'[linux下tomcat 配置

tomcat目录结构bin ——Tomcat执行脚本目录 conf ——Tomcat配置文件 lib ——Tomcat运行需要的库文件(JARS) logs ——Tomcat执行时的LOG文件 temp ——Tomcat临时文件存放目录 webapps ——Tomcat的主要Web发布目录(存放我们自己的JSP,SER…

微信小程序基础(一)

一.注册小程序账号,下载IDE 1.官网注册https://mp.weixin.qq.com/,并下载IDE。 2.官方文档一向都是最好的学习资料。 注意: (1)注册账号之后会有一个appid,新建项目的时候需要填上,不然很多…

[Kaggle] Spam/Ham Email Classification 垃圾邮件分类(RNN/GRU/LSTM)

文章目录1. 读入数据2. 文本处理3. 建模4. 训练5. 测试练习地址:https://www.kaggle.com/c/ds100fa19 相关博文 [Kaggle] Spam/Ham Email Classification 垃圾邮件分类(spacy) [Kaggle] Spam/Ham Email Classification 垃圾邮件分类&#xff…

禁止网页复制

<SCRIPT LANGUAGEjavascript>function click() {alert(禁止你的左键复制&#xff01;) }function click1() {if (event.button2) {alert(禁止右键点击~&#xff01;) }}function CtrlKeyDown(){if (event.ctrlKey) {alert(不当的拷贝将损害您的系统&#xff01;) }}docum…

微信小程序中实现瀑布流布局和无限加载

瀑布流布局是一种比较流行的页面布局方式&#xff0c;最典型的就是Pinterest.com&#xff0c;每个卡片的高度不都一样&#xff0c;形成一种参差不齐的美感。 在HTML5中&#xff0c;我们可以找到很多基于jQuery之类实现的瀑布流布局插件&#xff0c;轻松做出这样的布局形式。在…

LeetCode 1684. 统计一致字符串的数目(哈希)

文章目录1. 题目2. 解题1. 题目 给你一个由不同字符组成的字符串 allowed 和一个字符串数组 words 。 如果一个字符串的每一个字符都在 allowed 中&#xff0c;就称这个字符串是 一致 字符串。 请你返回 words 数组中 一致 字符串的数目。 示例 1&#xff1a; 输入&#xff…

Android下常见的内存泄露 经典

转自&#xff1a;http://www.linuxidc.com/Linux/2011-10/44785.htm 因为Android使用Java作为开发语言&#xff0c;很多人在使用会不注意内存的问题。 于是有时遇到程序运行时不断消耗内存&#xff0c;最终导致OutOfMemery&#xff0c;程序异常退出&#xff0c;这就是内存泄露导…

微信小程序:页面跳转时传递数据到另一个页面

一、功能描述 页面跳转时&#xff0c;同时把当前页面的数据传递给跳转的目标页面&#xff0c;并在跳转后的目标页面进行展示 二、功能实现 1. 代码实现 test1页面 // pages/test1/test1.js Page({/*** 页面的初始数据*/data: {name:Tom,age:12},buttonListener:function(){…

LeetCode 1685. 有序数组中差绝对值之和(前缀和)

文章目录1. 题目2. 解题1. 题目 给你一个 非递减 有序整数数组 nums 。 请你建立并返回一个整数数组 result&#xff0c;它跟 nums 长度相同&#xff0c;且result[i] 等于 nums[i] 与数组中所有其他元素差的绝对值之和。 换句话说&#xff0c; result[i] 等于 sum(|nums[i]-…

对一个 复杂的json结果进行取值的例子

1 JSON结果集 1 [2 {3 "J_LP_OPERATE_MAIN": {4 "ID": "1900036295",5 "FILL_MAN": "周兴福",6 "FILL_DEPT": "运维一班",7 "STATE…

微信小程序正则判断姓名和手机号

一、页面效果 二、json文件 //获取应用实例 const app getApp() Page({/*** 页面的初始数据*/data: {array: [速美, 现代, 淮安],mode: scaleToFill,src: ../../images/1.png,userInfo: {},hasUserInfo: false,canIUse: wx.canIUse(button.open-type.getUserInfo),userName: …

LeetCode 1686. 石子游戏 VI(贪心)

文章目录1. 题目2. 解题283 / 1660&#xff0c;前17%681 / 6572&#xff0c;前10.4%1. 题目 Alice 和 Bob 轮流玩一个游戏&#xff0c;Alice 先手。 一堆石子里总共有 n 个石子&#xff0c;轮到某个玩家时&#xff0c;他可以 移出 一个石子并得到这个石子的价值。 Alice 和 B…

T4生成实体和简单的CRUD操作

主要跟大家交流下T4,我这里针对的是mysql,我本人比较喜欢用mysql,所以语法针对mysql,所以你要准备mysql的DLL了,同理sqlserver差不多,有兴趣可以自己写写,首先网上找了一个T4的帮助类,得到一些数据库属性,命名为 DbHelper.ttinclude <# template debug"false" hos…

微信小程序的不同函数调用的几种方法

一、调取参数 直接调取当前js中的方法, 调取参数that.bindViewTap(); 二、跳转页面 navigateTo: function () { wx.navigateTo({ url: ../page4/page4 }); },全局变量使用方法 a.js var app getApp() Page({ data: { hex1: [], })} //设置全局变量 if (hex1 ! null) { app.…