机器学习入门--循环神经网络原理与实践

循环神经网络

循环神经网络(RNN)是一种在序列数据上表现出色的人工神经网络。相比于传统前馈神经网络,RNN更加适合处理时间序列数据,如音频信号、自然语言和股票价格等。本文将介绍RNN的基本数学原理、使用PyTorch和Scikit-Learn数据集实现的代码。

数学原理

RNN是一种带有循环结构的神经网络,其在处理序列数据时将前一次的输出作为当前输入的一部分。这使得RNN能够记住先前的状态和信息,并且在处理长期依赖关系时表现出色。

RNN的基本公式可以表示为:

h t = f ( W h h h t − 1 + W x h x t ) h_t = f(W_{hh}h_{t-1} + W_{xh}x_t) ht=f(Whhht1+Wxhxt)

其中 h t h_t ht是RNN在时间步 t t t的隐藏状态, f f f是激活函数, W h h W_{hh} Whh是隐藏状态的权重矩阵, h t − 1 h_{t-1} ht1是上一次的隐藏状态, W x h W_{xh} Wxh是输入 x t x_t xt和隐藏状态 h t h_t ht之间的权重矩阵, x t x_t xt是时间步 t t t的输入。

在RNN的训练过程中,我们需要使用反向传播算法计算梯度并更新权重。由于RNN具有时间上的依赖关系,每一步的梯度都取决于前一步的梯度,这意味着我们需要使用反向传播算法的变体——反向传播通过时间(BPTT)算法来计算梯度。

代码实现

我们将使用PyTorch和Scikit-Learn数据集实现一个简单的RNN模型,用于预测时间序列数据。以下是代码:

import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.datasets import load_boston
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt# 加载数据集
data = load_boston()
X = data.data
y = data.target# 数据标准化
scaler = StandardScaler()
X = scaler.fit_transform(X)
y = y.reshape(-1, 1)# 转换为PyTorch张量,并增加时间步维度
X = torch.tensor(X, dtype=torch.float32).unsqueeze(1)
y = torch.tensor(y, dtype=torch.float32)# 定义RNN模型
class SimpleRNN(nn.Module):def __init__(self, input_size, hidden_size, output_size):super(SimpleRNN, self).__init__()self.hidden_size = hidden_sizeself.rnn = nn.RNN(input_size, hidden_size, batch_first=True)self.fc = nn.Linear(hidden_size, output_size)def forward(self, x):out, _ = self.rnn(x)out = self.fc(out[:, -1, :])return out# 创建模型实例
input_size = X.shape[2]  # 更新input_size的值
hidden_size = 32
output_size = 1
model = SimpleRNN(input_size, hidden_size, output_size)# 定义损失函数和优化器
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)# 启用异常检测
torch.autograd.set_detect_anomaly(True)# 训练模型
num_epochs = 10000
# 记录损失
loss_list = []for epoch in range(num_epochs):optimizer.zero_grad()outputs = model(X)loss = criterion(outputs, y)loss.backward()optimizer.step()if (epoch+1) % 100 == 0:loss_list.append(loss.item())print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item()}')# 关闭异常检测
torch.autograd.set_detect_anomaly(False)# 可视化损失曲线
plt.plot(range(100), loss_list)
plt.xlabel('num_epochs')
plt.ylabel('loss of RNN Training')
plt.show()
plt.savefig('Loss_of_RNN_Training.png')# 预测新数据
new_data_point = X[0].reshape(1, 1, -1)  # 假设使用第一个数据点进行预测
prediction = model(new_data_point)
print(f'Predicted value: {prediction.item()}')

上述代码实现了一个简单的循环神经网络(RNN)模型来预测波士顿房价,并可视化训练过程中损失的变化。代码首先加载并标准化了波士顿房价数据集,然后定义了一个包含RNN层和全连接层的SimpleRNN模型,并使用均方误差作为损失函数和Adam优化器进行训练。训练完成后,使用matplotlib库绘制训练过程中损失的变化曲线(如下图所示)。最后,使用训练好的模型对新的数据点进行预测,并输出预测值。这段代码可以为初学者提供一个实现RNN模型的参考,并通过可视化训练过程中的损失曲线来帮助理解模型的性能。
RNN 损失曲线

总结

本文介绍了RNN的基本数学原理、使用PyTorch和Scikit-Learn数据集实现的代码,以及如何解读代码并总结。RNN是一种在序列数据上表现出色的神经网络,常用于处理时间序列数据,如音频信号、自然语言和股票价格等。我们可以使用PyTorch和Scikit-Learn数据集来实现一个简单的RNN模型,并用它来预测未知的时间序列数据。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/683139.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【网络攻防实验】【北京航空航天大学】【实验四、防火墙配置(Firewall Configuration)实验】

实验四、防火墙配置(Firewall Configuration)实验 一、 实验环境搭建 1. Kali Linux网络配置 将Kali Linux虚拟机网卡1设置为NAT网络模式,ip地址为10.0.2.5,如下图所示: 配置NAT网络端口转发: 将Kali Linux网卡2设置为内部网络模式: 配置Kali Linux网卡1: 类似地,配…

在中国做 DePIN?你需要明白风险与机遇

撰文:肖飒团队 来源Techub News专栏作者 随着科技的发展,我们正在日益进入一个资源相对过剩的时代,这使我们在日常生活中虽然支付了该部分资源的使用费,但却时常不能将其「物尽其用」,难免出现资源浪费。例如&#x…

PHP+vue+mysql校园学生社团管理系统574cc

运行环境:phpstudy/wamp/xammp等 开发语言:php 后端框架:Thinkphp 前端框架:vue.js 服务器:apache 数据库:mysql 数据库工具:Navicat/phpmyadmin 前台功能: 首页:展示社团信息和活动…

C#,二项式系数(Binomial Coefficient)的七种算法与源代码

1 二项式系数(binomial coefficient) 二项式系数(binomial coefficient),或组合数,在数学里表达为:(1 x)ⁿ展开后x的系数(其中n为自然数)。从定义可看出二项式系数的值…

立体库库存数量统计(SCL代码)

立体库库存物体检测由光电开关完成,每个储物格都有一个检测光电。5*6的仓库需要30个光电检测开关组成检测矩阵。找出矩阵中的最大元素并返回其所在的行号和列号和我们今天介绍的算法有很多相似的地方,大家可以对比学习。具体链接地址如下: h…

机器学习3----决策树

这是前期准备 import numpy as np import pandas as pd import matplotlib.pyplot as plt #ID3算法 #每个特征的信息熵 # target : 账号是否真实,共2种情况 # yes 7个 p0.7 # no 3个 p0.3 info_D-(0.7*np.log2(0.7)0.3*np.log2(0.3)) info_D #日志密度…

算法学习——LeetCode力扣回溯篇4

算法学习——LeetCode力扣回溯篇4 332. 重新安排行程 332. 重新安排行程 - 力扣(LeetCode) 描述 给你一份航线列表 tickets ,其中 tickets[i] [fromi, toi] 表示飞机出发和降落的机场地点。请你对该行程进行重新规划排序。 所有这些机票…

Vulnhub靶机:DC4

一、介绍 运行环境:Virtualbox 攻击机:kali(10.0.2.15) 靶机:DC4(10.0.2.57) 目标:获取靶机root权限和flag 靶机下载地址:https://www.vulnhub.com/entry/dc-4,313/…

Midjourney绘图欣赏系列(一)

Midjourney介绍 Midjourney 是生成式人工智能的一个很好的例子,它根据文本提示创建图像。它与 Dall-E 和 Stable Diffusion 一起成为最流行的 AI 艺术创作工具之一。与竞争对手不同,Midjourney 是自筹资金且闭源的,因此确切了解其幕后内容尚不…

【从Python基础到深度学习】7. 使用scp命令实现主机间通讯

一、生成 SSH 密钥对 ssh-keygen 是一个用于生成 SSH 密钥对的命令行工具,用于身份验证和加密通信 ssh-keygen 二、将本地主机上的 SSH 公钥添加到远程主机 ssh-copy-id 命令用于将本地主机上的 SSH 公钥添加到远程主机上的 authorized_keys 文件中,…

【初学者必看】迈入Midjourney的艺术世界:轻松掌握Midjourney的注册与订阅!

文章目录 前言一、Midjourney是什么二、Midjourney注册三、新建自己的服务器四、开通订阅 前言 AI绘画即指人工智能绘画,是一种计算机生成绘画的方式。是AIGC应用领域内的一大分支。 AI绘画主要分为两个部分,一个是对图像的分析与判断,即…

QlikSense财务聚合函数:IRR/NPV/XIRR/XNPV

IRR - 脚本函数 IRR() 函数用于返回聚合内部回报率,以揭示迭代于 group by 子句定义的大量记录上的表达式的数值表示的现金流系列。 这些现金流不必是均值,因为它们可用于年金。但是,现金流必须定期出现,例如每月或每年。内部收…

《合成孔径雷达成像算法与实现》Figure6.12

clc clear close all参数设置 距离向参数设置 R_eta_c 20e3; % 景中心斜距 Tr 2.5e-6; % 发射脉冲时宽 Kr 20e12; % 距离向调频率 alpha_os_r 1.7; % 距离过采样率 Nrg 320; % 距离线采样数 距离向…

【头歌·计组·自己动手画CPU】三、存储系统设计(HUST)(理论版) 【计算机硬件系统设计】

🕺作者: 主页 我的专栏C语言从0到1探秘C数据结构从0到1探秘Linux 😘欢迎 ❤️关注 👍点赞 🙌收藏 ✍️留言 文章目录 一、课程设计目的二、课程设计内容三、课程设计步骤四、课程设计总结 一、课程设计目的 理解计算机…

猫头虎分享:2024年值得程序员关注的技术发展动向分析

博主猫头虎的技术世界 🌟 欢迎来到猫头虎的博客 — 探索技术的无限可能! 专栏链接: 🔗 精选专栏: 《面试题大全》 — 面试准备的宝典!《IDEA开发秘籍》 — 提升你的IDEA技能!《100天精通鸿蒙》 …

HarmonyOS鸿蒙学习基础篇 - Column/Row 组件

前言 Row和Column组件是线性布局容器,用于按照垂直或水平方向排列子组件。Row表示沿水平方向布局的容器,而Column表示沿垂直方向布局的容器。这些容器具有许多属性和方法,可以方便地管理子组件的位置、大小、间距和对齐方式。例如&#xff0c…

从C向C++7——继承

一.继承 1.理解继承 C中的继承是类与类之间的关系,是一个很简单很直观的概念,与现实世界中的继承类似,例如儿子继承父亲的财产。 继承可以理解为一个类从另一个类获取成员变量和成员函数的过程。例如类 B 继承于类 A,那么 B 就…

Codeforces Round 924(Div.2) A~E

A.Rectangle Cutting (模拟) 题意: 给出一个长方形,通过平行于原始矩形的一条边进行切割,将该矩形切割成两个边长为整数的矩形。询问是否能通过旋转和移动这两个矩形,得到新的矩形。 分析: 可以发现拼成的新长方形…

Python算法探索:从经典到现代(三)

一、引言 随着信息技术的飞速发展,数据已经成为现代社会不可或缺的资源。Python,作为数据处理和分析的利器,为我们提供了大量强大的库和工具,用于从经典到现代的各种算法探索。本文将带你领略Python在算法领域的魅力,从…

COW AI接入到微信 保姆教程 (部署在服务器,插件安装)

此文章不涉及国外的AI模型,也无需翻墙,跟某AI模型无关,审核大哥别弄错了 最近的AI开始越开越火了,开始介入到我们生活中的方方面面。就有人好奇AI是否能接入到微信吗?我在GitHub上搜索的时候还真有除了对话外还可以通…