使用PyTorch实现LSTM生成ai诗

最近学习torch的一个小demo。

什么是LSTM?

长短时记忆网络(Long Short-Term Memory,LSTM)是一种循环神经网络(RNN)的变体,旨在解决传统RNN在处理长序列时的梯度消失和梯度爆炸问题。LSTM引入了一种特殊的存储单元和门控机制,以更有效地捕捉和处理序列数据中的长期依赖关系。

通俗点说就是:LSTM是一种改进版的递归神经网络(RNN)。它的主要特点是可以记住更长时间的信息,这使得它在处理序列数据(如文本、时间序列、语音等)时非常有效。

步骤如下

数据准备

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import random
import string
import os# 数据加载和预处理
def load_data(filepath):with open(filepath, 'r', encoding='utf-8') as file:text = file.read()return textdef preprocess_text(text):text = text.lower()text = text.translate(str.maketrans('', '', string.punctuation))return textdata_path = 'poetry.txt'  # 替换为实际的诗歌数据文件路径
text = load_data(data_path)
text = preprocess_text(text)
chars = sorted(list(set(text)))
char_to_idx = {char: idx for idx, char in enumerate(chars)}
idx_to_char = {idx: char for char, idx in char_to_idx.items()}
vocab_size = len(chars)print(f"Total characters: {len(text)}")
print(f"Vocabulary size: {vocab_size}")

模型构建

定义LSTM模型:

class LSTMModel(nn.Module):def __init__(self, input_size, hidden_size, output_size, num_layers=2):super(LSTMModel, self).__init__()self.hidden_size = hidden_sizeself.num_layers = num_layersself.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)self.fc = nn.Linear(hidden_size, output_size)self.softmax = nn.LogSoftmax(dim=1)def forward(self, x, hidden):lstm_out, hidden = self.lstm(x, hidden)output = self.fc(lstm_out[:, -1, :])output = self.softmax(output)return output, hiddendef init_hidden(self, batch_size):weight = next(self.parameters()).datahidden = (weight.new(self.num_layers, batch_size, self.hidden_size).zero_(),weight.new(self.num_layers, batch_size, self.hidden_size).zero_())return hidden

训练模型

将数据转换成LSTM需要的格式:

def prepare_data(text, seq_length):inputs = []targets = []for i in range(0, len(text) - seq_length, 1):seq_in = text[i:i + seq_length]seq_out = text[i + seq_length]inputs.append([char_to_idx[char] for char in seq_in])targets.append(char_to_idx[seq_out])return inputs, targetsseq_length = 100
inputs, targets = prepare_data(text, seq_length)# Convert to tensors
inputs = torch.tensor(inputs, dtype=torch.long)
targets = torch.tensor(targets, dtype=torch.long)batch_size = 64
input_size = vocab_size
hidden_size = 256
output_size = vocab_size
num_epochs = 20
learning_rate = 0.001model = LSTMModel(input_size, hidden_size, output_size)
criterion = nn.NLLLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)# Training loop
for epoch in range(num_epochs):h = model.init_hidden(batch_size)total_loss = 0for i in range(0, len(inputs), batch_size):x = inputs[i:i + batch_size]y = targets[i:i + batch_size]x = nn.functional.one_hot(x, num_classes=vocab_size).float()output, h = model(x, h)loss = criterion(output, y)optimizer.zero_grad()loss.backward()optimizer.step()total_loss += loss.item()print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {total_loss/len(inputs):.4f}")

生成

def generate_text(model, start_str, length=100):model.eval()with torch.no_grad():input_eval = torch.tensor([char_to_idx[char] for char in start_str], dtype=torch.long).unsqueeze(0)input_eval = nn.functional.one_hot(input_eval, num_classes=vocab_size).float()h = model.init_hidden(1)predicted_text = start_strfor _ in range(length):output, h = model(input_eval, h)prob = torch.softmax(output, dim=1).datapredicted_idx = torch.multinomial(prob, num_samples=1).item()predicted_char = idx_to_char[predicted_idx]predicted_text += predicted_charinput_eval = torch.tensor([[predicted_idx]], dtype=torch.long)input_eval = nn.functional.one_hot(input_eval, num_classes=vocab_size).float()return predicted_textstart_string = "春眠不觉晓"
generated_text = generate_text(model, start_string)
print(generated_text)

运行结果如下:

运行的肯定不好,但至少出结果了。诗歌我这边只放了几句,可以自己通过外部文件放入更多素材。

整体代码直接运行即可:

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import random
import string# 预定义一些中文诗歌数据
text = """
春眠不觉晓,处处闻啼鸟。
夜来风雨声,花落知多少。
床前明月光,疑是地上霜。
举头望明月,低头思故乡。
红豆生南国,春来发几枝。
愿君多采撷,此物最相思。
"""# 数据预处理
def preprocess_text(text):text = text.replace('\n', '')return texttext = preprocess_text(text)
chars = sorted(list(set(text)))
char_to_idx = {char: idx for idx, char in enumerate(chars)}
idx_to_char = {idx: char for char, idx in char_to_idx.items()}
vocab_size = len(chars)print(f"Total characters: {len(text)}")
print(f"Vocabulary size: {vocab_size}")class LSTMModel(nn.Module):def __init__(self, input_size, hidden_size, output_size, num_layers=2):super(LSTMModel, self).__init__()self.hidden_size = hidden_sizeself.num_layers = num_layersself.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)self.fc = nn.Linear(hidden_size, output_size)self.softmax = nn.LogSoftmax(dim=1)def forward(self, x, hidden):lstm_out, hidden = self.lstm(x, hidden)output = self.fc(lstm_out[:, -1, :])output = self.softmax(output)return output, hiddendef init_hidden(self, batch_size):weight = next(self.parameters()).datahidden = (weight.new(self.num_layers, batch_size, self.hidden_size).zero_(),weight.new(self.num_layers, batch_size, self.hidden_size).zero_())return hiddendef prepare_data(text, seq_length):inputs = []targets = []for i in range(0, len(text) - seq_length, 1):seq_in = text[i:i + seq_length]seq_out = text[i + seq_length]inputs.append([char_to_idx[char] for char in seq_in])targets.append(char_to_idx[seq_out])return inputs, targetsseq_length = 10
inputs, targets = prepare_data(text, seq_length)# Convert to tensors
inputs = torch.tensor(inputs, dtype=torch.long)
targets = torch.tensor(targets, dtype=torch.long)batch_size = 64
input_size = vocab_size
hidden_size = 256
output_size = vocab_size
num_epochs = 50
learning_rate = 0.003model = LSTMModel(input_size, hidden_size, output_size)
criterion = nn.NLLLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)# Training loop
for epoch in range(num_epochs):h = model.init_hidden(batch_size)total_loss = 0for i in range(0, len(inputs), batch_size):x = inputs[i:i + batch_size]y = targets[i:i + batch_size]if x.size(0) != batch_size:continuex = nn.functional.one_hot(x, num_classes=vocab_size).float()output, h = model(x, h)loss = criterion(output, y)optimizer.zero_grad()loss.backward()optimizer.step()total_loss += loss.item()print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {total_loss / len(inputs):.4f}")def generate_text(model, start_str, length=100):model.eval()with torch.no_grad():input_eval = torch.tensor([char_to_idx[char] for char in start_str], dtype=torch.long).unsqueeze(0)input_eval = nn.functional.one_hot(input_eval, num_classes=vocab_size).float()h = model.init_hidden(1)predicted_text = start_strfor _ in range(length):output, h = model(input_eval, h)prob = torch.softmax(output, dim=1).datapredicted_idx = torch.multinomial(prob, num_samples=1).item()predicted_char = idx_to_char[predicted_idx]predicted_text += predicted_charinput_eval = torch.tensor([[predicted_idx]], dtype=torch.long)input_eval = nn.functional.one_hot(input_eval, num_classes=vocab_size).float()return predicted_textstart_string = "春眠不觉晓"
generated_text = generate_text(model, start_string, length=100)
print(generated_text)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/28279.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

vue格网图

先看效果 再看代码 <n-gridv-elsex-gap"20":y-gap"20"cols"2 s:2 m:3 l:3 xl:3 2xl:4"responsive"screen" ><n-grid-itemv-for"(item,index) in newSongList":key"item.id"class"cursor-pointer …

Spring学习笔记(九)简单的SSM框架整合

实验目的 掌握SSM框架整合。 实验环境 硬件&#xff1a;PC机 操作系统&#xff1a;Windows 开发工具&#xff1a;idea 实验内容 整合SSM框架。 实验步骤 搭建SSM环境&#xff1a;构建web项目&#xff0c;导入需要的jar包&#xff0c;通过单元测试测试各层框架搭建的正确…

IDEA 设置主题、背景图片、背景颜色

一、设置主题 1、点击菜单 File -> Settings : 点击 Settings 菜单 2、点击 Editor -> Color Scheme -> Scheme, 小哈的 IDEA 版本号为 2022.2.3 , 官方默认提供了 4 种主题&#xff1a; Classic Light &#xff08;经典白&#xff09; ;Darcula &#xff08;暗黑主…

知识普及:什么是边缘计算(Edge Computing)?

边缘计算是一种分布式计算架构&#xff0c;它将数据处理、存储和服务功能移近数据产生的边缘位置&#xff0c;即接近数据源和用户的位置&#xff0c;而不是依赖中心化的数据中心或云计算平台。边缘计算的核心思想是在靠近终端设备的位置进行数据处理&#xff0c;以降低延迟、减…

前端:鼠标点击实现高亮特效

一、实现思路 获取鼠标点击位置 通过鼠标点击位置设置高亮裁剪动画 二、效果展示 三、按钮组件代码 <template><buttonclass"blueBut"click"clickHandler":style"{backgroundColor: clickBut ? rgb(31, 67, 117) : rgb(128, 128, 128),…

16. 第十六章 类和函数

16. 类和函数 现在我们已经知道如何创建新的类型, 下一步是编写接收用户定义的对象作为参数或者将其当作结果用户定义的函数. 本章我会展示函数式编程风格, 以及两个新的程序开发计划.本章的代码示例可以从↓下载. https://github.com/AllenDowney/ThinkPython2/blob/master/c…

java程序在运行过程各个内部结构的作用

一&#xff1a;内部结构 一个进程对应一个jvm实例&#xff0c;一个运行时数据区&#xff0c;又包含多个线程&#xff0c;这些线程共享了方法区和堆&#xff0c;每个线程包含了程序计数器、本地方法栈和虚拟机栈接下来我们通过一个示意图介绍一下这个空间。 如图所示,当一个hell…

11.泛型、trait和生命周期(上)

标题 一、泛型数据的引入二、改写为泛型函数三、结构体/枚举中的泛型定义四、方法定义中的泛型 一、泛型数据的引入 下面是两个函数&#xff0c;分别用来取得整型和符号型vector中的最大值 use std::fs::File;fn get_max_float_value_from_vector(src: &[f64]) -> f64…

代码随想录-Day31

455. 分发饼干 假设你是一位很棒的家长&#xff0c;想要给你的孩子们一些小饼干。但是&#xff0c;每个孩子最多只能给一块饼干。 对每个孩子 i&#xff0c;都有一个胃口值 g[i]&#xff0c;这是能让孩子们满足胃口的饼干的最小尺寸&#xff1b;并且每块饼干 j&#xff0c;都…

vs+qt5.0 使用poppler 操作库

Poppler 是一个用来生成 PDF 的C类库&#xff0c;从xpdf 继承而来。vs编译库如下&#xff1a; vs中只需要添加依赖库即可 头文件&#xff1a;

【UE5|水文章】在UMG上显示帧率

参考视频&#xff1a; https://www.youtube.com/watch?vH_NdvImlI68 蓝图&#xff1a;

数值分析笔记(二)函数插值

函数插值 已知函数 f ( x ) f(x) f(x)在区间[a,b]上n1个互异节点 { x i } i 0 n \{{x_i}\}_{i0}^{n} {xi​}i0n​处的函数值 { y i } i 0 n \{{y_i}\}_{i0}^{n} {yi​}i0n​&#xff0c;若函数集合 Φ \Phi Φ中函数 ϕ ( x ) \phi(x) ϕ(x)满足条件 ϕ ( x i ) y i ( i …

数据结构01 栈及其相关问题讲解【C++实现】

栈是一种线性数据结构&#xff0c;栈的特征是数据的插入和删除只能通过一端来实现&#xff0c;这一端称为“栈顶”&#xff0c;相应的另一端称为“栈底”。 栈及其特点 用一个简单的例子来说&#xff0c;栈就像一个放乒乓球的圆筒&#xff0c;底部是封住的&#xff0c;如果你想…

2024年了,苹果可以通话录音了

人不走空 &#x1f308;个人主页&#xff1a;人不走空 &#x1f496;系列专栏&#xff1a;算法专题 ⏰诗词歌赋&#xff1a;斯是陋室&#xff0c;惟吾德馨 6月11日凌晨&#xff0c;苹果在WWDC24大会上&#xff0c;密集输出了酝酿多时的AI应用更新。苹果对通话、对话、图…

力扣 SQL题目

185.部门工资前三高的所有员工 公司的主管们感兴趣的是公司每个部门中谁赚的钱最多。一个部门的 高收入者 是指一个员工的工资在该部门的 不同 工资中 排名前三 。 编写解决方案&#xff0c;找出每个部门中 收入高的员工 。 以 任意顺序 返回结果表。 返回结果格式如下所示。 …

Android studio如何导入项目

打开解压好的安装包 找到build.gradle文件 打开查看gradle版本 下载对应的gradle版本Index of /gradle/&#xff08;镜像网站&#xff09; 下载all的对应压缩包 配置gradle的环境变量 新建GRADLE_HOME 将GRADLE_HOME加入到path中 将项目在Android studio中打开进行配置 将gr…

LM339模块电路故障查询

最近的电路测试中出现一个问题&#xff0c;如果不接液晶屏&#xff0c;LM339输入端是高电平&#xff0c;如果接了液晶屏&#xff0c;输入端就是低电平&#xff0c;即使在输入端加了上拉电阻&#xff0c;还是如前面的结论&#xff0c;如果越过LM339,直接和后级电路连接&#xff…

Python爬虫JS逆向进阶课程

这门课程是Python爬虫JS逆向进阶课程&#xff0c;将教授学员如何使用Python爬虫技术和JS逆向技术获取网站数据。学习者将学习如何分析网站的JS代码&#xff0c;破解反爬虫机制&#xff0c;以及如何使用Selenium和PhantomJS等工具进行模拟登录和数据抓取。课程结合实例演练和项目…

ThinkPHP邮件发送配置教程?怎么配置群发?

ThinkPHP邮件发送安全性如何保障&#xff1f;ThinkPHP如何实现&#xff1f; 无论是用户注册后的验证邮件&#xff0c;还是订单处理的通知邮件&#xff0c;都需要一个可靠的邮件发送机制。AokSend将详细介绍如何在ThinkPHP框架中配置邮件发送功能&#xff0c;并带您逐步了解其中…

Python武器库开发-武器库篇之Mongodb未授权漏洞扫描器(五十六)

Python武器库开发-武器库篇之Mongodb未授权漏洞扫描器(五十六) MongoDB 未授权访问漏洞简介以及危害 MongoDB是一款非常受欢迎的开源NoSQL数据库&#xff0c;广泛应用于各种Web应用和移动应用中。然而&#xff0c;由于默认配置的不当或者管理员的疏忽&#xff0c;导致不少Mon…