【LSTM实战】跨越千年,赋诗成文:用LSTM重现唐诗的韵律与情感

本文将介绍如何使用LSTM训练一个能够创作诗歌的模型。为了训练出效果优秀的模型,我整理了来自网络的4万首诗歌数据集。我们的模型可以直接使用预先训练好的参数,这意味着您无需从头开始训练,即可在自己的电脑上体验AI作诗的乐趣。我已经为您准备好了这些训练好的参数,让您能够轻松地在自己的设备上开始创作。本文将详细讲解如何在个人电脑上运行该模型,即使您没有机器学习方面的背景知识,也能轻松驾驭,让您的AI模型在自己的电脑上运行起来,体验AI创作诗歌的乐趣.所有的代码和资料都在仓库:https://gitee.com/yw18791995155/generate_poetry.git

秋风吹拂,窗外的树叶似灵动的舞者翩翩而舞,落日余晖将天际晕染成一片醉人的橘红。
与此同时,AI 于知识的瀚海中遨游,遍览数千篇文章后,开启了它的首次创作之旅。

在对近 4 万首唐诗深度学习之后,赋诗如下:

在这里插入图片描述

此诗颇具韵味,实乃勤勉研习之硕果。汲取全唐诗之精华,方成就这般非凡之能,常人岂易企及?

本博客将简要分析其中的技术细节,若有阐释未尽之处,在此诚挚欢迎诸君于评论区畅所欲言,各抒己见。先呈上仓库链接https://gitee.com/yw18791995155/generate_poetry.git

若诸位无暇详阅,不妨为该项目点亮 star 或进行 fork,诸君的每一份支持都将如熠熠星光,化作我砥砺前行之强劲动力源泉。言归正传,让我们一同开启打造AI诗人的旅程吧

01 环境配置

在开始之前,确保你的电脑已经安装了必要的依赖库:PyTorch 和 NumPy。安装命令如下:

   pip install torch torchvision torchaudio numpy

一切就绪,我们可以开始了!

02 初识LSTM

长短期记忆网络 LSTM是一种特殊类型的循环神经网络(RNN),它被设计用来解决传统RNN在处理长序列数据时遇到的长期依赖性问题(梯度消失和梯度爆炸问题)。
在这里插入图片描述

LSTM的核心优势在于其能够学习并记住长期的信息依赖关系。这种能力使得LSTM在处理长文本内容时比普通RNN更为出色。LSTM网络中包含了四个主要的组件,它们通过门控机制来控制信息的流动:

  1. 遗忘门(Forget Gate):决定哪些信息应该被遗忘,不再保留在单元状态中。
  2. 输入门(Input Gate):决定哪些新信息将被存储在单元状态中。
  3. 单元状态(Cell State):携带数据穿越时间的信息带,可以看作是LSTM的“记忆”。
  4. 输出门(Output Gate):决定哪些信息将从单元状态输出到下一个隐藏状态。

这些门控机制使得LSTM能够有选择性地保留或遗忘信息,从而有效地捕捉和利用长期依赖性。这种设计灵感来源于对传统RNN在处理长序列时遗忘信息的挑战的回应,LSTM通过这些门控结构,使得网络能够更加灵活地处理时间序列数据。

03处理数据

接下来,首先要做的就是读取准备好的诗歌数据。然后对数据进行清洗,剔除那些包含特殊字符或长度不符合要求的诗歌。清洗完数据后,我们会为每首诗加上开始和结束的标志,确保生成的诗歌有明确的起止符号。

然后,我们会构建词典,为每个词分配一个唯一的索引,同时建立词汇到索引、索引到词汇的映射关系。最后,把每首诗转换成数字序列,这样就能让模型进行处理了。

import collections
import numpy as np
import torch# 定义起始和结束标记
start_token = 'B'
end_token = 'E'def process_poems(file_name):"""处理诗歌文件,将诗歌转换为数字序列,并构建词汇表。:param file_name: 诗歌文件的路径:return:- poems_vector: 诗歌的数字序列列表- word_to_idx: 词汇到索引的映射字典- idx_to_word: 索引到词汇的映射列表"""# 初始化诗歌列表poems = []# 读取文件并处理每一行with open(file_name, "r", encoding='utf-8') as f:for line in f.readlines():try:# 分割标题和内容title, content = line.strip().split(':')content = content.replace(' ', '')# 过滤掉包含特殊字符的诗歌if '_' in content or '(' in content or '(' in content or '《' in content or '[' in content or \start_token in content or end_token in content:continue# 过滤掉长度不符合要求的诗歌if len(content) < 5 or len(content) > 79:continue# 添加起始和结束标记content = start_token + content + end_tokenpoems.append(content)except ValueError as e:pass# 统计所有单词的频率all_words = [word for poem in poems for word in poem]counter = collections.Counter(all_words)words = sorted(counter.keys(), key=lambda x: counter[x], reverse=True)# 添加空格作为填充符words.append(' ')words_length = len(words)# 构建词汇到索引和索引到词汇的映射word_to_idx = {word: i for i, word in enumerate(words)}idx_to_word = [word for word in words]# 将诗歌转换为数字序列poems_vector = [[word_to_idx[word] for word in poem] for poem in poems]return poems_vector, word_to_idx, idx_to_worddef generate_batch(batch_size, poems_vec, word_to_int):"""生成批量训练数据。:param batch_size: 批量大小:param poems_vec: 诗歌的数字序列列表:param word_to_int: 词汇到索引的映射字典:return:- x_batches: 输入数据批次- y_batches: 目标数据批次"""# 计算可以生成的批次数num_example = len(poems_vec) // batch_sizex_batches = []y_batches = []for i in range(num_example):start_index = i * batch_sizeend_index = start_index + batch_size# 获取当前批次的诗歌batches = poems_vec[start_index:end_index]# 找到当前批次中最长的诗歌长度length = max(map(len, batches))# 初始化输入数据,使用空格进行填充x_data = np.full((batch_size, length), word_to_int[' '], np.int32)# 填充输入数据for row, batch in enumerate(batches):x_data[row, :len(batch)] = batch# 创建目标数据,目标数据是输入数据向右移一位y_data = np.copy(x_data)y_data[:, :-1] = x_data[:, 1:]"""x_data             y_data[6,2,4,6,9]       [2,4,6,9,9][1,4,2,8,5]       [4,2,8,5,5]"""# 将当前批次的数据添加到列表中yield torch.tensor(x_data), torch.tensor(y_data)

04创建模型

现在是时候搭建我们的 LSTM 模型了!我们将创建一个双层 LSTM 网络。双层 LSTM 比单层的更有能力捕捉复杂的模式和结构,能够更好地处理诗歌这种带有丰富语言特征的任务。

import torch
import torch.nn as nn
import torch.optim as optimclass RNNModel(nn.Module):def __init__(self, vocab_size, rnn_size=128, num_layers=2):"""构建RNN序列到序列模型。:param vocab_size: 词汇表大小:param rnn_size: RNN隐藏层大小:param num_layers: RNN层数"""super(RNNModel, self).__init__()# 选择LSTM单元# 参数说明:输入大小、隐藏层大小、层数、batch_first=True表示输入数据的第一维是批次大小self.cell = nn.LSTM(rnn_size, rnn_size, num_layers, batch_first=True)# 嵌入层,将词汇表中的词转换为向量# vocab_size + 1 是因为在词嵌入中需要有一个特殊标记,用于表示填充位置,所以词嵌入时会加一个词。self.embedding = nn.Embedding(vocab_size + 1, rnn_size)# RNN隐藏层大小self.rnn_size = rnn_size# 全连接层,用于输出预测# 输入大小为RNN隐藏层大小,输出大小为词汇表大小加1self.fc = nn.Linear(rnn_size, vocab_size + 1)def forward(self, input_data, hidden):"""前向传播:param input_data: 输入数据,形状为 (batch_size, sequence_length):param output_data: 输出数据(训练时提供),形状为 (batch_size, sequence_length):return: 输出结果或损失"""# 获取批次大小batch_size = input_data.size(0)# 嵌入层,将输入数据转换为向量# 输入数据形状为 (batch_size, sequence_length),嵌入后形状为 (batch_size, sequence_length, rnn_size)embedded = self.embedding(input_data)# 通过RNN层# 输入形状为 (batch_size, sequence_length, rnn_size),输出形状为 (batch_size, sequence_length, rnn_size)outputs, hidden = self.cell(embedded, hidden)# 将输出展平# 展平后的形状为 (batch_size * sequence_length, rnn_size)outputs = outputs.contiguous().view(-1, self.rnn_size)# 通过全连接层# 输入形状为 (batch_size * sequence_length, rnn_size),输出形状为 (batch_size * sequence_length, vocab_size + 1)logits = self.fc(outputs)return logits, hidden

05训练模型

接下来,就是我们最考验耐性的部分——训练模型了。训练过程中,你可能需要一些时间,所以建议使用 GPU 加速。经过实测,使用 GPU 训练速度大约是 CPU 的四倍左右。所以,如果你有条件,最好让 GPU 出马,省时省力。

import torch
from model import RNNModel
from torch import nn
from poem_data_processing import *
import os
import time# 检查是否有可用的GPU,如果没有则使用CPU
# windows用户使用torch.cuda.is_available()来检查是否有可用的GPU。
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")print(f"Using device: {device}")def train(poems_path, num_epochs, batch_size, lr):"""训练RNN模型并进行预测。参数:poems_path (str): 诗歌数据文件路径。num_epochs (int): 训练的轮数。batch_size (int): 批次大小。lr (float): 学习率。"""# 确保模型保存目录存在if not os.path.exists('./model'):os.makedirs('./model')# 处理诗歌数据,生成向量化表示和映射字典poems_vector, word_to_idx, idx_to_word = process_poems(poems_path)# 初始化RNN模型并将其移动到指定设备model = RNNModel(len(idx_to_word), 128, num_layers=2).to(device)# 使用Adam优化器初始化训练器trainer = torch.optim.Adam(model.parameters(), lr=lr)# 使用交叉熵损失函数loss_fn = nn.CrossEntropyLoss()# 开始训练过程for epoch in range(num_epochs):loss_sum = 0start = time.time()# 生成并迭代训练批次for X, Y in generate_batch(batch_size, poems_vector, word_to_idx):# 将输入和目标数据移动到指定设备X = X.to(device)Y = Y.to(device)state = None# 前向传播outputs, state = model(X, state)Y = Y.view(-1)# 计算损失l = loss_fn(outputs, Y.long())# 反向传播和优化trainer.zero_grad()l.backward()torch.nn.utils.clip_grad_norm_(model.parameters(), 0.01)trainer.step()loss_sum += l.item() * Y.shape[0]end = time.time()print(f"Time cost: {end - start}s")print(f"epoch: {epoch}, loss: {loss_sum / len(poems_vector)}")# 保存模型和优化器的状态try:torch.save({'model_state_dict': model.state_dict(),'optimizer_state_dict': trainer.state_dict(),}, os.path.join('./model', 'torch-latest.pth'))except Exception as e:print(f"Error saving model: {e}")if __name__ == "__main__":file_path = "./data/poems.txt"train(file_path, num_epochs=100, batch_size=64, lr=0.002)

经过了约 20 分钟的训练,终于,模型训练完成!训练结束后,模型的参数会自动保存到文件中,这样下次就可以直接加载预训练的模型,省去重新训练的麻烦

06测试模型

终于,我们来到了最激动人心的环节——AI作诗。经过几个小时的努力,我们的AI诗人已经准备好创作一首藏头诗,以此来弥补我因编程而失去的头发。
鸡枝蝉及九层峰,内邸曾随佛统衣。

你写明时何处寻,大江蕃戴帝来儿。

太古能弗岂何如,惟无百物恣蹉跎。

美人迟意识王机,马首辞来六堕愁。

测试代码

# 导入必要的库
import torch
from model import RNNModel
from poem_data_processing import process_poems
import numpy as np# 定义开始和结束标记
start_token = 'B'
end_token = 'E'
# 模型保存的目录
model_dir = './model/'
# 诗歌数据文件路径
poems_file = './data/poems.txt'# 学习率
lr = 0.0002def to_word(predict, vocabs):"""将预测结果转换为词汇表中的字。参数:predict: 模型的预测结果,一个概率分布。vocabs: 词汇表,包含所有可能的字。返回:从预测结果中随机选择的一个字。"""predict = predict.numpy()[0]predict /= np.sum(predict)sample = np.random.choice(np.arange(len(predict)), p=predict)if sample > len(vocabs):return vocabs[-1]else:return vocabs[sample]def gen_poem(begin_word):"""生成诗歌。参数:begin_word: 诗歌的第一个字。返回:生成的诗歌,以字符串形式返回。"""batch_size = 1# 处理诗歌数据,得到诗歌向量、字到索引的映射和索引到字的映射poems_vector, word_to_idx, idx_to_word = process_poems(poems_file)# 初始化模型model = RNNModel(len(idx_to_word), 128, num_layers=2)# 加载模型参数checkpoint = torch.load(f'{model_dir}/torch-latest.pth')model.load_state_dict(checkpoint['model_state_dict'], strict=False)model.eval()# 初始化输入序列x = torch.tensor([word_to_idx[start_token]], dtype=torch.long).view(1, 1)hidden = None# 生成诗歌with torch.no_grad():output, hidden = model(x, hidden)predict = torch.softmax(output, dim=1)word = begin_word or to_word(predict, idx_to_word)poem_ = ''i = 0while word != end_token:poem_ += wordi += 1if i > 24:breakx = torch.tensor([word_to_idx[word]], dtype=torch.long).view(1, 1)output, hidden = model(x, hidden)predict = torch.softmax(output, dim=1)word = to_word(predict, idx_to_word)return poem_def pretty_print_poem(poem_):"""格式化打印诗歌。参数:poem_: 生成的诗歌,以字符串形式输入。"""poem_sentences = poem_.split('。')for s in poem_sentences:if s != '' and len(s) > 10:print(s + '。')if __name__ == '__main__':# 用户输入第一个字begin_char = input('请输入第一个字 please input the first character: \n')print('AI作诗 generating poem...')# 生成诗歌poem = gen_poem(begin_char)# 打印诗歌pretty_print_poem(poem_=poem)

效果出乎意料地好,所有的努力都值了。是不是觉得很有趣?快来下载代码,亲自体验AI作诗的乐趣吧。

项目目录

在这里插入图片描述

  • 训练模型,运行train.py文件。
  • 想直接体验AI作诗,运行test.py文件。

如果不想从头训练,可以直接使用预训练好的模型参数,这些参数已经保存在文件中,只需下载仓库的所有代码和文件即可
仓库地址:https://gitee.com/yw18791995155/generate_poetry.git

读到这里,如果你觉得这篇文章有点意思,不妨转发点赞。如果你对AI小项目感兴趣,欢迎关注我,我会持续分享更多有趣的项目。

感谢你的阅读,愿你的代码永远没有bug,头发永远浓密!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/61655.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【Swift】运算符

文章目录 术语赋值运算符算数运算符基本四则算术运算符求余运算符一元负号运算符一元正号运算符 比较运算符三元运算符空合运算符区间运算符闭区间运算符半开区间运算符单侧区间运算符 逻辑运算符逻辑非运算符逻辑与运算符逻辑或运算符逻辑运算符组合计算 位运算符运算符优先级…

微信小程序技术架构图

一、视图层1.WXML&#xff08;WeiXin Markup Language&#xff09; 这是微信小程序的标记语言&#xff0c;类似于 HTML。它用于构建小程序的页面结构。例如&#xff0c;通过标签来定义各种视图元素&#xff0c;如<view>&#xff08;类似于 HTML 中的<div>&#xff…

【AI最前线】DP双像素sensor相关的AI算法全集:深度估计、图像去模糊去雨去雾恢复、图像重建、自动对焦

Dual Pixel 简介 双像素是成像系统的感光元器件中单帧同时生成的图像&#xff1a;通过双像素可以实现&#xff1a;深度估计、图像去模糊去雨去雾恢复、图像重建 成像原理来源如上&#xff0c;也有遮罩等方式的pd生成&#xff0c;如图双像素视图可以看到光圈的不同一半&#x…

#Verilog HDL# Verilog中的UDP原语

目录 一 UDP符号 1.1 组合UDP 1.2 时序UDP 1.2.1 电平UDP 1.2.2 边沿UDP 标准的Verilog原语,如nand和not,有时可能不足以或不便用于表示复杂逻辑。为了建模组合逻辑或时序逻辑,可以定义称为用户定义原语(UDP)的新原语元素。 所有UDP都有且仅有一个输出,该输出可以是…

Python 版本的 2024详细代码

2048游戏的Python实现 概述&#xff1a; 2048是一款流行的单人益智游戏&#xff0c;玩家通过滑动数字瓷砖来合并相同的数字&#xff0c;目标是合成2048这个数字。本文将介绍如何使用Python和Pygame库实现2048游戏的基本功能&#xff0c;包括游戏逻辑、界面绘制和用户交互。 主…

socket连接封装

效果&#xff1a; class websocketMessage {constructor(params) {this.params params; // 传入的参数this.socket null;this.lockReconnect false; // 重连的锁this.socketTimer null; // 心跳this.lockTimer null; // 重连this.timeout 3000; // 发送消息this.callbac…

docker compose的安装和使用

1. Docker Compose 简介 Docker Compose 是一个工具&#xff0c;用于定义和运行多容器的 Docker 应用。通过编写一个 docker-compose.yml 文件&#xff0c;可以一次性启动所有容器&#xff0c;并且方便管理容器之间的依赖。 2. 安装 Docker Compose 前提条件 确保已安装 Do…

银河麒麟v10 x86架构二进制方式kubeadm+docker+cri-docker搭建k8s集群(证书有效期100年) —— 筑梦之路

环境说明 master&#xff1a;192.168.100.100 node: 192.168.100.101 kubeadm 1.31.2 &#xff08;自编译二进制文件&#xff0c;证书有效期100年&#xff09; 银河麒麟v10 sp2 x86架构 内核版本&#xff1a;5.4.x 编译安装 cgroup v2启用 docker版本&#xff1a;27.x …

大语言模型---RewardBench 介绍;RewardBench 的主要功能;适用场景

文章目录 1. RewardBench 介绍2. RewardBench 的主要功能3. 适用场景 1. RewardBench 介绍 RewardBench: Evaluating Reward Models是一个专门用于评估 Reward Models&#xff08;奖励模型&#xff09; 的公开平台&#xff0c;旨在衡量模型在多种任务上的性能&#xff0c;包括…

基于Redis实现的手机短信登入功能

目录 开发准备 注册阿里短信服务 依赖坐标 阿里短信 依赖 mybatis-plus 依赖 redis 依赖 配置文件 导入数据库表 短信发送工具类 生成随机验证码的工具类 校验合法手机号的工具类 ThreadLocal 线程工具类 消息工具类 基于 session 的短信登录的问题 开发教程 Redis 结构设计 …

Java语言程序设计 选填题知识点总结

第一章 javac.exe是JDK提供的编译器public static void main (String args[])是Java应用程序主类中正确的main方法Java源文件是由若干个书写形式互相独立的类组成的Java语言的名字是印度尼西亚一个盛产咖啡的岛名Java源文件中可以有一个或多个类Java源文件的扩展名是.java如果…

免费好用的静态网页托管平台全面对比介绍

5个免费好用的静态网页托管平台全面对比 前言 作为一名前端开发者&#xff0c;经常会遇到需要部署静态网页的场景。无论是个人项目展示、简单的游戏demo还是作品集网站&#xff0c;选择一个合适的托管平台都很重要。本文将详细介绍5个免费的静态网页托管平台&#xff0c;帮助…

python正则表达式基本字符字符

字符 描述 text 匹配text字符串 . 匹配除换行符之外的任意一个单个字符 ^ 匹配一个字符串的开头 $ 匹配一个字符串的末尾 在正则表达式中,我们还可用匹配限定符来约束匹配的次数 2. 匹配限定符 最大匹配 最小匹配 描述 * *? 重复匹配前表达式零次或多次 &a…

k8s篇之控制器类型以及各自的适用场景

1. k8s中控制器介绍 在 Kubernetes 中,控制器(Controller)是集群中用于管理资源的关键组件。 它们的核心作用是确保集群中的资源状态符合用户的期望,并在需要时自动进行调整。 Kubernetes 提供了多种不同类型的控制器,每种控制器都有其独特的功能和应用场景。 2. 常见的…

python程序的编写以及发布(形象类比)

最近重新接触python&#xff0c;本人之前对于python的虚拟环境&#xff0c;安装包比较比较迷惑&#xff0c;这里给出一个具象的理解。可以将 Python 程序运行的过程类比成一次 做菜的过程&#xff0c;从准备食材到最后出锅。以下是具体的类比步骤&#xff1a; 1. 安装 Python 环…

shell基础知识3 --- 流程控制之条件判断

条件判断语句是一种最简单的流程控制语句。该语句使得程序根据不同的条件来执行不同的程序分支。 一、if语句语法 1.单分支结构 法1&#xff1a; 法2&#xff1a; if <条件表达式> if…

功耗中蓝牙扫描事件插桩埋点

手机功耗中蓝牙扫描事件插桩埋点 功耗主要监控蓝牙扫描的时间和次数&#xff0c;进而换算为频次监控。其中不同的蓝牙扫描模式带来的功耗影响也是不一样的。 即功耗影响度低延迟扫描>平衡模式扫描>低功耗模式。例如某款机型分别为&#xff1a;低延迟扫描 14.64mA,平衡模…

电容测试流程

一、外观检测 1. 目的&#xff1a;检验电容样品外观是否与规格书一致&#xff0c;制程工艺是否良好&#xff0c;确保部品的品质。 2. 仪器&#xff1a;放大镜 3. 测试说明&#xff1a; &#xff08;1&#xff09;样品上丝印与规格书中相符&#xff0c;丝印信息&#xff08;…

探索 .NET 9 控制台应用中的 LiteDB 异步 CRUD 操作

本文主要是使用异步方式&#xff0c;体验 litedb 基本的 crud 操作。 LiteDB 是一款轻量级、快速且免费的 .NET NoSQL 嵌入式数据库&#xff0c;专为小型本地应用程序设计。它以单一数据文件的形式提供服务&#xff0c;支持文档存储和查询功能&#xff0c;适用于桌面应用、移动…

leetcode刷题记录(四十二)——101. 对称二叉树

&#xff08;一&#xff09;问题描述 . - 力扣&#xff08;LeetCode&#xff09;. - 备战技术面试&#xff1f;力扣提供海量技术面试资源&#xff0c;帮助你高效提升编程技能,轻松拿下世界 IT 名企 Dream Offer。https://leetcode.cn/problems/symmetric-tree/description/给你…