使用PyTorch实现LSTM生成ai诗

最近学习torch的一个小demo。

什么是LSTM?

长短时记忆网络(Long Short-Term Memory,LSTM)是一种循环神经网络(RNN)的变体,旨在解决传统RNN在处理长序列时的梯度消失和梯度爆炸问题。LSTM引入了一种特殊的存储单元和门控机制,以更有效地捕捉和处理序列数据中的长期依赖关系。

通俗点说就是:LSTM是一种改进版的递归神经网络(RNN)。它的主要特点是可以记住更长时间的信息,这使得它在处理序列数据(如文本、时间序列、语音等)时非常有效。

步骤如下

数据准备

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import random
import string
import os# 数据加载和预处理
def load_data(filepath):with open(filepath, 'r', encoding='utf-8') as file:text = file.read()return textdef preprocess_text(text):text = text.lower()text = text.translate(str.maketrans('', '', string.punctuation))return textdata_path = 'poetry.txt'  # 替换为实际的诗歌数据文件路径
text = load_data(data_path)
text = preprocess_text(text)
chars = sorted(list(set(text)))
char_to_idx = {char: idx for idx, char in enumerate(chars)}
idx_to_char = {idx: char for char, idx in char_to_idx.items()}
vocab_size = len(chars)print(f"Total characters: {len(text)}")
print(f"Vocabulary size: {vocab_size}")

模型构建

定义LSTM模型:

class LSTMModel(nn.Module):def __init__(self, input_size, hidden_size, output_size, num_layers=2):super(LSTMModel, self).__init__()self.hidden_size = hidden_sizeself.num_layers = num_layersself.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)self.fc = nn.Linear(hidden_size, output_size)self.softmax = nn.LogSoftmax(dim=1)def forward(self, x, hidden):lstm_out, hidden = self.lstm(x, hidden)output = self.fc(lstm_out[:, -1, :])output = self.softmax(output)return output, hiddendef init_hidden(self, batch_size):weight = next(self.parameters()).datahidden = (weight.new(self.num_layers, batch_size, self.hidden_size).zero_(),weight.new(self.num_layers, batch_size, self.hidden_size).zero_())return hidden

训练模型

将数据转换成LSTM需要的格式:

def prepare_data(text, seq_length):inputs = []targets = []for i in range(0, len(text) - seq_length, 1):seq_in = text[i:i + seq_length]seq_out = text[i + seq_length]inputs.append([char_to_idx[char] for char in seq_in])targets.append(char_to_idx[seq_out])return inputs, targetsseq_length = 100
inputs, targets = prepare_data(text, seq_length)# Convert to tensors
inputs = torch.tensor(inputs, dtype=torch.long)
targets = torch.tensor(targets, dtype=torch.long)batch_size = 64
input_size = vocab_size
hidden_size = 256
output_size = vocab_size
num_epochs = 20
learning_rate = 0.001model = LSTMModel(input_size, hidden_size, output_size)
criterion = nn.NLLLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)# Training loop
for epoch in range(num_epochs):h = model.init_hidden(batch_size)total_loss = 0for i in range(0, len(inputs), batch_size):x = inputs[i:i + batch_size]y = targets[i:i + batch_size]x = nn.functional.one_hot(x, num_classes=vocab_size).float()output, h = model(x, h)loss = criterion(output, y)optimizer.zero_grad()loss.backward()optimizer.step()total_loss += loss.item()print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {total_loss/len(inputs):.4f}")

生成

def generate_text(model, start_str, length=100):model.eval()with torch.no_grad():input_eval = torch.tensor([char_to_idx[char] for char in start_str], dtype=torch.long).unsqueeze(0)input_eval = nn.functional.one_hot(input_eval, num_classes=vocab_size).float()h = model.init_hidden(1)predicted_text = start_strfor _ in range(length):output, h = model(input_eval, h)prob = torch.softmax(output, dim=1).datapredicted_idx = torch.multinomial(prob, num_samples=1).item()predicted_char = idx_to_char[predicted_idx]predicted_text += predicted_charinput_eval = torch.tensor([[predicted_idx]], dtype=torch.long)input_eval = nn.functional.one_hot(input_eval, num_classes=vocab_size).float()return predicted_textstart_string = "春眠不觉晓"
generated_text = generate_text(model, start_string)
print(generated_text)

运行结果如下:

运行的肯定不好,但至少出结果了。诗歌我这边只放了几句,可以自己通过外部文件放入更多素材。

整体代码直接运行即可:

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import random
import string# 预定义一些中文诗歌数据
text = """
春眠不觉晓,处处闻啼鸟。
夜来风雨声,花落知多少。
床前明月光,疑是地上霜。
举头望明月,低头思故乡。
红豆生南国,春来发几枝。
愿君多采撷,此物最相思。
"""# 数据预处理
def preprocess_text(text):text = text.replace('\n', '')return texttext = preprocess_text(text)
chars = sorted(list(set(text)))
char_to_idx = {char: idx for idx, char in enumerate(chars)}
idx_to_char = {idx: char for char, idx in char_to_idx.items()}
vocab_size = len(chars)print(f"Total characters: {len(text)}")
print(f"Vocabulary size: {vocab_size}")class LSTMModel(nn.Module):def __init__(self, input_size, hidden_size, output_size, num_layers=2):super(LSTMModel, self).__init__()self.hidden_size = hidden_sizeself.num_layers = num_layersself.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)self.fc = nn.Linear(hidden_size, output_size)self.softmax = nn.LogSoftmax(dim=1)def forward(self, x, hidden):lstm_out, hidden = self.lstm(x, hidden)output = self.fc(lstm_out[:, -1, :])output = self.softmax(output)return output, hiddendef init_hidden(self, batch_size):weight = next(self.parameters()).datahidden = (weight.new(self.num_layers, batch_size, self.hidden_size).zero_(),weight.new(self.num_layers, batch_size, self.hidden_size).zero_())return hiddendef prepare_data(text, seq_length):inputs = []targets = []for i in range(0, len(text) - seq_length, 1):seq_in = text[i:i + seq_length]seq_out = text[i + seq_length]inputs.append([char_to_idx[char] for char in seq_in])targets.append(char_to_idx[seq_out])return inputs, targetsseq_length = 10
inputs, targets = prepare_data(text, seq_length)# Convert to tensors
inputs = torch.tensor(inputs, dtype=torch.long)
targets = torch.tensor(targets, dtype=torch.long)batch_size = 64
input_size = vocab_size
hidden_size = 256
output_size = vocab_size
num_epochs = 50
learning_rate = 0.003model = LSTMModel(input_size, hidden_size, output_size)
criterion = nn.NLLLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)# Training loop
for epoch in range(num_epochs):h = model.init_hidden(batch_size)total_loss = 0for i in range(0, len(inputs), batch_size):x = inputs[i:i + batch_size]y = targets[i:i + batch_size]if x.size(0) != batch_size:continuex = nn.functional.one_hot(x, num_classes=vocab_size).float()output, h = model(x, h)loss = criterion(output, y)optimizer.zero_grad()loss.backward()optimizer.step()total_loss += loss.item()print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {total_loss / len(inputs):.4f}")def generate_text(model, start_str, length=100):model.eval()with torch.no_grad():input_eval = torch.tensor([char_to_idx[char] for char in start_str], dtype=torch.long).unsqueeze(0)input_eval = nn.functional.one_hot(input_eval, num_classes=vocab_size).float()h = model.init_hidden(1)predicted_text = start_strfor _ in range(length):output, h = model(input_eval, h)prob = torch.softmax(output, dim=1).datapredicted_idx = torch.multinomial(prob, num_samples=1).item()predicted_char = idx_to_char[predicted_idx]predicted_text += predicted_charinput_eval = torch.tensor([[predicted_idx]], dtype=torch.long)input_eval = nn.functional.one_hot(input_eval, num_classes=vocab_size).float()return predicted_textstart_string = "春眠不觉晓"
generated_text = generate_text(model, start_string, length=100)
print(generated_text)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/28279.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

vue格网图

先看效果 再看代码 <n-gridv-elsex-gap"20":y-gap"20"cols"2 s:2 m:3 l:3 xl:3 2xl:4"responsive"screen" ><n-grid-itemv-for"(item,index) in newSongList":key"item.id"class"cursor-pointer …

C# OpenCvSharp Mat操作-创建Mat-ones

ones 函数用于创建一个全为“1”的矩阵&#xff08;Mat&#xff09;&#xff0c;可以用于各种图像处理和计算机视觉任务。下面我将详细解释每个重载版本的 ones 函数&#xff0c;并提供相应的示例代码。&#x1f4f8; 1️⃣ ones(int rows, int cols, int type) 这个重载函数…

VS - regsvr32.exe的官方工程

文章目录 VS - regsvr32.exe的官方工程概述笔记官方原版实现自己封装一个函数来干活(注册/反注册 COM DLL)END VS - regsvr32.exe的官方工程 概述 如果是要使用COM DLL&#xff0c; 必须先注册。 一般手工注册就要调用regsvr32.exe xx.dll 但是控制的不够细&#xff0c;且一般…

Spring学习笔记(九)简单的SSM框架整合

实验目的 掌握SSM框架整合。 实验环境 硬件&#xff1a;PC机 操作系统&#xff1a;Windows 开发工具&#xff1a;idea 实验内容 整合SSM框架。 实验步骤 搭建SSM环境&#xff1a;构建web项目&#xff0c;导入需要的jar包&#xff0c;通过单元测试测试各层框架搭建的正确…

IDEA 设置主题、背景图片、背景颜色

一、设置主题 1、点击菜单 File -> Settings : 点击 Settings 菜单 2、点击 Editor -> Color Scheme -> Scheme, 小哈的 IDEA 版本号为 2022.2.3 , 官方默认提供了 4 种主题&#xff1a; Classic Light &#xff08;经典白&#xff09; ;Darcula &#xff08;暗黑主…

知识普及:什么是边缘计算(Edge Computing)?

边缘计算是一种分布式计算架构&#xff0c;它将数据处理、存储和服务功能移近数据产生的边缘位置&#xff0c;即接近数据源和用户的位置&#xff0c;而不是依赖中心化的数据中心或云计算平台。边缘计算的核心思想是在靠近终端设备的位置进行数据处理&#xff0c;以降低延迟、减…

React组件通信方式总结

文章目录 父组件向子组件传递数据子组件向父组件传递数据兄弟组件传递数据祖先与后代组件之间的传值复杂关系的组件之间的传值使用发布-订阅模式使用 Redux 父组件向子组件传递数据 无论是类组件还是函数式组件&#xff0c;父组件向子组件传递数据的方式都是使用 props 来实现…

vue怎样获取dom元素?

在 Vue.js 中&#xff0c;直接操作 DOM 元素通常不是推荐的做法&#xff0c;因为 Vue 的核心思想是数据驱动视图&#xff0c;我们更倾向于通过改变数据来影响视图&#xff0c;而不是直接操作 DOM。 然而&#xff0c;在某些情况下&#xff0c;你可能确实需要直接获取和操作 DOM…

C++模板之模板成员函数不能偏特化

目录 1.引言 2.类模板成员函数的特化 2.1.没有函数特化的类模板 2.2.增加函数特化 3.“曲线救国”函数“偏特化” 3.1.函数重载实现“偏特化” 3.2.使用类型选择机制实现“偏特化” 4.总结 1.引言 C 泛型编程的资料在介绍类模板的特化和偏特化的时候&#xff0…

【HarmonyOS】HUAWEI DevEco Studio 下载地址汇总

目录 OpenHarmony 4.x Releases 4.1 Release4.0 Release OpenHarmony 3.x Releases 3.2.1 Release3.2 Release3.1.3 Release3.1.2 Release3.1.1 Release3.1 Release 说明 Full SDK&#xff1a;面向OEM厂商提供&#xff0c;包含了需要使用系统权限的系统接口。 Public SDK&am…

Python对Excel表格的操作

今天, 实现了一个对excel表格操作的技术方案. 操作的要求是: (1)在一个目标表格(表格2)中的第2列已经有唯一标识码.第1列为凭证号, 但是是空的. (2)在数据表格中(表格1)中有资产的信息, 其中第2列是资产的唯一标识码, 第1列是凭证号. (3)表格2内只有部分资产. 要求: 从表格1中…

前端:鼠标点击实现高亮特效

一、实现思路 获取鼠标点击位置 通过鼠标点击位置设置高亮裁剪动画 二、效果展示 三、按钮组件代码 <template><buttonclass"blueBut"click"clickHandler":style"{backgroundColor: clickBut ? rgb(31, 67, 117) : rgb(128, 128, 128),…

C# OpenCvSharp 图像处理函数-图像拼接-hconcat、vconcat、Stitcher

在图像处理和计算机视觉领域,图像拼接是一个常见的操作。OpenCvSharp是一个用于.NET平台的OpenCV封装库,可以方便地进行图像处理。本文将详细介绍如何使用OpenCvSharp中的hconcat、vconcat函数以及Stitcher类进行图像拼接,并通过具体示例帮助读者理解和掌握这些知识点。 函…

Java生成NetCDF文件

因为需要再Cesium中实现风场粒子效果&#xff0c;网上找了许多项目&#xff0c;大多是通过加载NC文件来进行渲染的&#xff0c;因此了解NC文件又成了一件重要的事。特此记录用java成果生成可在前端渲染&#xff0c;QGIS中正常渲染的NetCDF文件的相关代码&#xff08;有没详细整…

16. 第十六章 类和函数

16. 类和函数 现在我们已经知道如何创建新的类型, 下一步是编写接收用户定义的对象作为参数或者将其当作结果用户定义的函数. 本章我会展示函数式编程风格, 以及两个新的程序开发计划.本章的代码示例可以从↓下载. https://github.com/AllenDowney/ThinkPython2/blob/master/c…

java程序在运行过程各个内部结构的作用

一&#xff1a;内部结构 一个进程对应一个jvm实例&#xff0c;一个运行时数据区&#xff0c;又包含多个线程&#xff0c;这些线程共享了方法区和堆&#xff0c;每个线程包含了程序计数器、本地方法栈和虚拟机栈接下来我们通过一个示意图介绍一下这个空间。 如图所示,当一个hell…

内窥镜系统设计简介

内窥镜系统设计简介 1. 源由2. 系统组成2.1 光学系统2.2 机械结构2.3 电子系统2.4 软件系统2.5 安全性和合规性2.6 研发与测试2.7 用户培训与支持 3. 研发过程3.1 光学系统Step 1&#xff1a;镜头设计Step 2&#xff1a;光源Step 3&#xff1a;成像传感器 3.2 机械结构Step 1&a…

11.泛型、trait和生命周期(上)

标题 一、泛型数据的引入二、改写为泛型函数三、结构体/枚举中的泛型定义四、方法定义中的泛型 一、泛型数据的引入 下面是两个函数&#xff0c;分别用来取得整型和符号型vector中的最大值 use std::fs::File;fn get_max_float_value_from_vector(src: &[f64]) -> f64…

代码随想录-Day31

455. 分发饼干 假设你是一位很棒的家长&#xff0c;想要给你的孩子们一些小饼干。但是&#xff0c;每个孩子最多只能给一块饼干。 对每个孩子 i&#xff0c;都有一个胃口值 g[i]&#xff0c;这是能让孩子们满足胃口的饼干的最小尺寸&#xff1b;并且每块饼干 j&#xff0c;都…

Python中的命名空间和作用域:解密变量的可见性和生命周期

在 Python 中&#xff0c;命名空间&#xff08;Namespace&#xff09;和作用域&#xff08;Scope&#xff09;是重要的概念&#xff0c;它们决定了变量和函数的可见性和生命周期。理解命名空间和作用域是编写高效、可维护代码的关键。 基本语法 命名空间 命名空间是一个存储…