文章目录
- 前言
- 一、数据预处理
- 二、辅助训练工具函数
- 三、绘图工具函数
- 四、模型定义
- 五、模型训练与预测
- 六、实例化模型并训练
- 训练结果可视化
- 总结
前言
循环神经网络(RNN)是深度学习中处理序列数据的重要模型,尤其在自然语言处理和时间序列分析中有着广泛应用。本篇博客将通过一个基于 PyTorch 的 RNN 实现,结合《The Time Machine》数据集,带你从零开始理解 RNN 的构建、训练和预测过程。我们将逐步剖析代码,展示如何加载数据、定义工具函数、构建模型、绘制训练过程图表,并最终训练一个字符级别的 RNN 模型。代码中包含了数据预处理、模型定义、梯度裁剪、困惑度计算等关键步骤,适合希望深入理解 RNN 的初学者和进阶者。
本文基于 PyTorch 实现,所有代码均来自附件,并辅以详细注释和图表说明。让我们开始吧!
一、数据预处理
首先,我们需要加载和预处理《The Time Machine》数据集,将其转化为适合 RNN 输入的格式。以下是数据预处理的完整代码:
import random
import re
import torch
from collections import Counterdef read_time_machine():"""将时间机器数据集加载到文本行的列表中"""with open('timemachine.txt', 'r') as f:lines = f.readlines()# 去除非字母字符并将每行转换为小写return [re.sub('[^A-Za-z]+', ' ', line).strip().lower() for line in lines]def tokenize(lines, token='word'):"""将文本行拆分为单词或字符词元"""if token == 'word':return [line.split() for line in lines]elif token == 'char':return [list(line) for line in lines]else:print(f'错误:未知词元类型:{token}')def count_corpus(tokens):"""统计词元的频率"""if not tokens:return Counter()if isinstance(tokens[0], list):flattened_tokens = [token for sublist in tokens for token in sublist]else:flattened_tokens = tokensreturn Counter(flattened_tokens)class Vocab:"""文本词表类,用于管理词元及其索引的映射关系"""def __init__(self, tokens=None, min_freq=0, reserved_tokens=None):self.tokens = tokens if tokens is not None else []self.reserved_tokens = reserved_tokens if reserved_tokens is not None else []counter = self._count_corpus(self.tokens)self._token_freqs = sorted(counter.items(), key=lambda x: x[1], reverse=True)self.idx_to_token = ['<unk>'] + self.reserved_tokensself.token_to_idx = {token: idx for idx, token in enumerate(self.idx_to_token)}for token, freq in self._token_freqs:if freq < min_freq:breakif token not in self.token_to_idx:self.idx_to_token.append(token)self.token_to_idx[token] = len(self.idx_to_token) - 1@staticmethoddef _count_corpus(tokens):if not tokens:return Counter()if isinstance(tokens[0], list):tokens = [token for sublist in tokens for token in sublist]return Counter(tokens)def __len__(self):return len(self.idx_to_token)def __getitem__(self, tokens):if not isinstance(tokens, (list, tuple)):return self.token_to_idx.get(tokens, self.unk)return [self[token] for token in tokens]def to_tokens(self, indices):if not isinstance(indices, (list, tuple)):return self.idx_to_token[indices]return [self.idx_to_token[index] for index in indices]@propertydef unk(self):return 0@propertydef token_freqs(self):return self._token_freqsdef load_corpus_time_machine(max_tokens=-1):"""返回时光机器数据集的词元索引列表和词表"""lines = read_time_machine()tokens = tokenize(lines, 'char')vocab = Vocab(tokens)corpus = [vocab[token] for line in tokens for token in line]if max_tokens > 0:corpus = corpus