深度学习之生成唐诗案例(Pytorch版)

主要思路:

对于唐诗生成来说,我们定义一个"S" 和 "E"作为开始和结束。

 示例的唐诗大概有40000多首,

首先数据预处理,将唐诗加载到内存,生成对应的word2idx、idx2word、以及唐诗按顺序的字序列。

Dataset_Dataloader.py
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoaderdef deal_tangshi():with open("poems.txt", "r", encoding="utf-8") as fr:lines = fr.read().strip().split("\n")tangshis = []for line in lines:splits = line.split(":")if len(splits) != 2:continuetangshis.append("S" + splits[1] + "E")word2idx = {"S": 0, "E": 1}word2idx_count = 2tangshi_ids = []for tangshi in tangshis:for word in tangshi:if word not in word2idx:word2idx[word] = word2idx_countword2idx_count += 1idx2word = {idx: w for w, idx in word2idx.items()}for tangshi in tangshis:tangshi_ids.extend([word2idx[w] for w in tangshi])return word2idx, idx2word, tangshis, word2idx_count, tangshi_idsword2idx, idx2word, tangshis, word2idx_count, tangshi_ids = deal_tangshi()class TangShiDataset(Dataset):def __init__(self, tangshi_ids, num_chars):# 语料数据self.tangshi_ids = tangshi_ids# 语料长度self.num_chars = num_chars# 词的数量self.word_count = len(self.tangshi_ids)# 句子数量self.number = self.word_count // self.num_charsdef __len__(self):return self.numberdef __getitem__(self, idx):# 修正索引值到: [0, self.word_count - 1]start = min(max(idx, 0), self.word_count - self.num_chars - 2)x = self.tangshi_ids[start: start + self.num_chars]y = self.tangshi_ids[start + 1: start + 1 + self.num_chars]return torch.tensor(x), torch.tensor(y)def __test_Dataset():dataset = TangShiDataset(tangshi_ids, 8)x, y = dataset[0]print(x, y)if __name__ == '__main__':# deal_tangshi()__test_Dataset()
TangShiModel.py:唐诗的模型
import torch
import torch.nn as nn
from Dataset_Dataloader import *
import torch.nn.functional as Fclass TangShiRNN(nn.Module):def __init__(self, vocab_size):super().__init__()# 初始化词嵌入层self.ebd = nn.Embedding(vocab_size, 128)# 循环网络层self.rnn = nn.RNN(128, 128, 1)# 输出层self.out = nn.Linear(128, vocab_size)def forward(self, inputs, hidden):embed = self.ebd(inputs)# 正则化层embed = F.dropout(embed, p=0.2)output, hidden = self.rnn(embed.transpose(0, 1), hidden)# 正则化层embed = F.dropout(output, p=0.2)output = self.out(output.squeeze())return output, hiddendef init_hidden(self):return torch.zeros(1, 64, 128)

 main.py:

import timeimport torchfrom Dataset_Dataloader import *
from TangShiModel import *
import torch.optim as optim
from tqdm import tqdmdevice = torch.device("cuda" if torch.cuda.is_available() else "cpu")def train():dataset = TangShiDataset(tangshi_ids, 128)epochs = 100model = TangShiRNN(word2idx_count).to(device)criterion = nn.CrossEntropyLoss()optimizer = optim.Adam(model.parameters(), lr=1e-3)for idx in range(epochs):dataloader = DataLoader(dataset, batch_size=64, shuffle=True, drop_last=True)start_time = time.time()total_loss = 0total_num = 0total_correct = 0total_correct_num = 0hidden = model.init_hidden()for x, y in tqdm(dataloader):x = x.to(device)y = y.to(device)# 隐藏状态hidden = model.init_hidden()hidden = hidden.to(device)# 模型计算output, hidden = model(x, hidden)# print(output.shape)# print(y.shape)# 计算损失loss = criterion(output.permute(1, 2, 0), y)# 梯度清零optimizer.zero_grad()# 反向传播loss.backward()# 参数更新optimizer.step()total_loss += loss.sum().item()total_num += len(y)total_correct_num += y.shape[0] * y.shape[1]# print(output.shape)total_correct += (torch.argmax(output.permute(1, 0, 2), dim=-1) == y).sum().item()print("epoch : %d average_loss : %.3f average_correct : %.3f use_time : %ds" %(idx + 1, total_loss / total_num, total_correct / total_correct_num, time.time() - start_time))torch.save(model.state_dict(), f"./modules/tangshi_module_{idx + 1}.bin")if __name__ == '__main__':train()

predict.py:

import torch
import torch.nn as nn
from Dataset_Dataloader import *
from TangShiModel import *device = torch.device("cuda" if torch.cuda.is_available() else "cpu")def predict():model = TangShiRNN(word2idx_count)model.load_state_dict(torch.load("./modules/tangshi_module_100.bin", map_location=torch.device('cpu')))model.eval()hidden = torch.zeros(1, 1, 128)start_word = input("输入第一个字:")flag = Nonetangshi_strs = []while True:if not flag:outputs, hidden = model(torch.tensor([[word2idx["S"]]], dtype=torch.long), hidden)tangshi_strs.append("S")flag = Trueelse:tangshi_strs.append(start_word)outputs, hidden = model(torch.tensor([[word2idx[start_word]]], dtype=torch.long), hidden)top_i = torch.argmax(outputs, dim=-1)if top_i.item() == word2idx["E"]:breakprint(top_i)start_word = idx2word[top_i.item()]print(tangshi_strs)if __name__ == '__main__':predict()

完整代码如下:

https://github.com/STZZ-1992/tangshi-generator.giticon-default.png?t=N7T8https://github.com/STZZ-1992/tangshi-generator.git

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/153701.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

万字解析设计模式之代理模式

一、代理模式 1.1概述 代理模式是一种结构型设计模式,它允许通过创建代理对象来控制对其他对象的访问。这种模式可以增加一些额外的逻辑来控制对原始对象的访问,同时还可以提供更加灵活的访问方式。 代理模式分为静态代理和动态代理两种。静态代理是在编…

【Android11】AOSPSettings增加蓝牙开关

基于Android11增加一个蓝牙开关按钮然后控制蓝牙开关 首先控制蓝牙开关的逻辑很简单,bluetoothAdapter.disable();就可以关闭。这里需要用到android.bluetooth.BluetoothAdapter; 1.找到蓝牙界面的xml文件加个按钮 --- a/packages/apps/Settings/res/xml/connect…

Day01 嵌入式 -----流水灯

一、简单介绍 嵌入式系统中的流水灯是一种常见的示例项目,通常用于演示嵌入式系统的基本功能和控制能力。流水灯由多个发光二极管(LED)组成,这些LED按照一定的顺序依次点亮和熄灭,形成一种像水流一样的流动效果。 二、…

单/三相dq解耦控制与特定次谐波抑制

1. 单相整流器dq坐标系下建模 单相整流器的拓扑如图所示,可知 u a b u s − L d i s d t − R i s {u_{ab}} {u_{s}} - L\frac{{d{i_s}}}{{dt}} - R{i_s} uab​us​−Ldtdis​​−Ris​。   将电压和电流写成dq的形式。 { u s U s m sin ⁡ ( ω t ) i s I …

选择「程序员」职业的8个理由

软件开发人员是具有创建软件程序的创意和技术技能的专业人员,是一个具有高回报和挑战性的职业选择。如今,软件开发人员几乎在每个行业工作。随着世界变得越来越数字化,越来越需要具有技术背景的人来创建特定的软件应用程序。 如果您考虑做一…

【React】classnames 库(可添加多个 className 类名)

文章目录 前言&#xff1a;在项目中我们有时候需要添加多个className&#xff0c;这时候就需要用到这个库了 例如&#xff1a;我们想得到这样一个效果 <div classclass1 class2></div>但是在react中&#xff0c;我们没办法像上面那样去实现&#xff0c;所以我们得…

纯JS,RSA,AES,公钥,私钥生成及加解密

通过网络找的JS源文件&#xff0c;修改后使用&#xff0c;包含RSA 密匙对生成 及AES 加解密 涉及的JS源文件 下载 GitHub - cgrlancer/RSA-AES: 纯js,RSA,AES前端加解密 前端引用 import {generateRsaKeyWithPKCS8,encryptByRSA,decryptByRSA,encrypt,decrypt,testRsa} fr…

文心一言-情感关怀之旅

如何让LLM更有温度。 应用介绍

【精选】XML技术知识点合计

XML概述 概念 XML&#xff08;Extensible Markup Language&#xff09;&#xff1a;可扩展标记语言 可扩展&#xff1a;标签都是自定义的。 发展历程 HTML和XML都是W3C&#xff08;万维网联盟&#xff09;制定的标准&#xff0c;最开始HTML的语法过于松散&#xff0c;于是W…

使用Java解决快手滑块验证码

分析页面结构&#xff1a; 使用浏览器开发者工具分析快手滑块验证码页面的HTML和JavaScript结构&#xff0c;找到滑块验证的相关元素和事件。 模拟滑块滑动&#xff1a; 使用Java的Selenium库或其他网络爬虫工具&#xff0c;模拟用户在滑块上的操作。你需要模拟鼠标点击、拖动…

企业要满足什么条件才能实施CRM系统?

CRM的作用相信大家也所有了解&#xff0c;但并不是所有的企业都适合实施CRM。或者说&#xff0c;大部分企业实施CRM并不会100%的成功。那么&#xff0c;企业实施CRM的条件是什么&#xff1f;下面我们就来说一说。 1、业务规模 如果您的客户数量较少&#xff0c;没有复杂的客户…

二分查找——34. 在排序数组中查找元素的第一个和最后一个位置

文章目录 1. 题目2. 算法原理2.1 暴力解法2.2 二分查找左端点查找右端点查找 3. 代码实现4. 二分模板 1. 题目 题目链接&#xff1a;34. 在排序数组中查找元素的第一个和最后一个位置 - 力扣&#xff08;LeetCode&#xff09; 给你一个按照非递减顺序排列的整数数组 nums&#…

苹果手机数据迁移,简单方法送给大家!

当我们准备更换新的苹果手机时&#xff0c;最令人头疼的问题就是如何将旧手机的数据迁移到新手机上。无论是什么手机&#xff0c;数据迁移确实是一个比较繁琐的过程。 但是&#xff0c;只要我们掌握了正确的方法&#xff0c;那么这个过程就会变得简单许多。苹果手机数据迁移的…

护眼灯亮度多少合适?亮度适合学生的护眼台灯推荐

护眼灯亮度满足国AA级标准就好了。可以肯定的是&#xff0c;护眼灯一般可以达到护眼的效果。 看书和写字时&#xff0c;光线应适度&#xff0c;不宜过强或过暗&#xff0c;护眼灯光线较柔和&#xff0c;通常并不刺眼&#xff0c;眼球容易适应&#xff0c;可以防止光线过强或过…

go map字典操作

类型断言 断言 在现代化 程序中 有助于 终止代码 , 防止 更大的 错误产生 package mainimport "fmt"func main() {var i interface{} "hello"s : i.(string)fmt.Println(s)s, ok : i.(string)fmt.Println(s, ok)f, ok : i.(float64)fmt.Println(f, ok)f…

老友小明哥-个人简介

b站个人主页&#xff08;可以看看免费视频&#xff09;&#xff1a;老友小明哥的个人空间-老友小明哥个人主页-哔哩哔哩视频

请问DasViewer是否支持与业务系统集成,将业务的动态的数据实时的展示到三维模型上?

答&#xff1a;一般这种是以平台的方式来展示&#xff0c;云端地球实景三维建模云平台是专门做这一块的&#xff0c;可前往云端地球官网免费使用。 DasViewer是由大势智慧自主研发的免费的实景三维模型浏览器,采用多细节层次模型逐步自适应加载技术,让用户在极低的电脑配置下,…

数据质量校验

1.事实表包含昨日数据 2.昨日同比趋势分析 圆通业务量较为平稳 &#xff0c;每日数据量和昨日比差距不足20%&#xff0c;会做数据量的昨日环比差距分析

Camtasia2024免费版mac电脑录屏软件

作为一个互联网人&#xff0c;没少在录屏软件这个坑里摸爬滚打。培训、学习、游戏、影视解说……都得用它。这时候没个拿得出手的私藏软件&#xff0c;还怎么混&#xff1f;说实话&#xff0c;录屏软件这两年也用了不少&#xff0c;基本功能是有但总觉得缺点什么&#xff0c;直…

01-制作人和迈克尔杰克逊-《人月神话》中译本纠错及联想

DDD领域驱动设计批评文集 做强化自测题获得“软件方法建模师”称号 《软件方法》各章合集 2001年&#xff0c;我们翻译《人月神话》的时候&#xff0c;由于水平有限&#xff0c;译文中存在不少错误。 这些年&#xff0c;随着阅历的增长&#xff0c;在重读的时候偶尔也会有“…