机器学习入门--循环神经网络原理与实践

循环神经网络

循环神经网络(RNN)是一种在序列数据上表现出色的人工神经网络。相比于传统前馈神经网络,RNN更加适合处理时间序列数据,如音频信号、自然语言和股票价格等。本文将介绍RNN的基本数学原理、使用PyTorch和Scikit-Learn数据集实现的代码。

数学原理

RNN是一种带有循环结构的神经网络,其在处理序列数据时将前一次的输出作为当前输入的一部分。这使得RNN能够记住先前的状态和信息,并且在处理长期依赖关系时表现出色。

RNN的基本公式可以表示为:

h t = f ( W h h h t − 1 + W x h x t ) h_t = f(W_{hh}h_{t-1} + W_{xh}x_t) ht=f(Whhht1+Wxhxt)

其中 h t h_t ht是RNN在时间步 t t t的隐藏状态, f f f是激活函数, W h h W_{hh} Whh是隐藏状态的权重矩阵, h t − 1 h_{t-1} ht1是上一次的隐藏状态, W x h W_{xh} Wxh是输入 x t x_t xt和隐藏状态 h t h_t ht之间的权重矩阵, x t x_t xt是时间步 t t t的输入。

在RNN的训练过程中,我们需要使用反向传播算法计算梯度并更新权重。由于RNN具有时间上的依赖关系,每一步的梯度都取决于前一步的梯度,这意味着我们需要使用反向传播算法的变体——反向传播通过时间(BPTT)算法来计算梯度。

代码实现

我们将使用PyTorch和Scikit-Learn数据集实现一个简单的RNN模型,用于预测时间序列数据。以下是代码:

import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.datasets import load_boston
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt# 加载数据集
data = load_boston()
X = data.data
y = data.target# 数据标准化
scaler = StandardScaler()
X = scaler.fit_transform(X)
y = y.reshape(-1, 1)# 转换为PyTorch张量,并增加时间步维度
X = torch.tensor(X, dtype=torch.float32).unsqueeze(1)
y = torch.tensor(y, dtype=torch.float32)# 定义RNN模型
class SimpleRNN(nn.Module):def __init__(self, input_size, hidden_size, output_size):super(SimpleRNN, self).__init__()self.hidden_size = hidden_sizeself.rnn = nn.RNN(input_size, hidden_size, batch_first=True)self.fc = nn.Linear(hidden_size, output_size)def forward(self, x):out, _ = self.rnn(x)out = self.fc(out[:, -1, :])return out# 创建模型实例
input_size = X.shape[2]  # 更新input_size的值
hidden_size = 32
output_size = 1
model = SimpleRNN(input_size, hidden_size, output_size)# 定义损失函数和优化器
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)# 启用异常检测
torch.autograd.set_detect_anomaly(True)# 训练模型
num_epochs = 10000
# 记录损失
loss_list = []for epoch in range(num_epochs):optimizer.zero_grad()outputs = model(X)loss = criterion(outputs, y)loss.backward()optimizer.step()if (epoch+1) % 100 == 0:loss_list.append(loss.item())print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item()}')# 关闭异常检测
torch.autograd.set_detect_anomaly(False)# 可视化损失曲线
plt.plot(range(100), loss_list)
plt.xlabel('num_epochs')
plt.ylabel('loss of RNN Training')
plt.show()
plt.savefig('Loss_of_RNN_Training.png')# 预测新数据
new_data_point = X[0].reshape(1, 1, -1)  # 假设使用第一个数据点进行预测
prediction = model(new_data_point)
print(f'Predicted value: {prediction.item()}')

上述代码实现了一个简单的循环神经网络(RNN)模型来预测波士顿房价,并可视化训练过程中损失的变化。代码首先加载并标准化了波士顿房价数据集,然后定义了一个包含RNN层和全连接层的SimpleRNN模型,并使用均方误差作为损失函数和Adam优化器进行训练。训练完成后,使用matplotlib库绘制训练过程中损失的变化曲线(如下图所示)。最后,使用训练好的模型对新的数据点进行预测,并输出预测值。这段代码可以为初学者提供一个实现RNN模型的参考,并通过可视化训练过程中的损失曲线来帮助理解模型的性能。
RNN 损失曲线

总结

本文介绍了RNN的基本数学原理、使用PyTorch和Scikit-Learn数据集实现的代码,以及如何解读代码并总结。RNN是一种在序列数据上表现出色的神经网络,常用于处理时间序列数据,如音频信号、自然语言和股票价格等。我们可以使用PyTorch和Scikit-Learn数据集来实现一个简单的RNN模型,并用它来预测未知的时间序列数据。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/683139.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【网络攻防实验】【北京航空航天大学】【实验四、防火墙配置(Firewall Configuration)实验】

实验四、防火墙配置(Firewall Configuration)实验 一、 实验环境搭建 1. Kali Linux网络配置 将Kali Linux虚拟机网卡1设置为NAT网络模式,ip地址为10.0.2.5,如下图所示: 配置NAT网络端口转发: 将Kali Linux网卡2设置为内部网络模式: 配置Kali Linux网卡1: 类似地,配…

Spring Cloud 路由和消息传递 (HTTP 路由)

Spring Cloud 路由 Spring Cloud 路由是指将请求路由到特定服务的机制。Spring Cloud 提供了多种路由机制,包括: Ribbon: 一个基于 HTTP 和 TCP 的客户端负载均衡工具,提供软负载均衡、故障转移等功能。Feign: 一个声明式的 HTTP 客户端&am…

在中国做 DePIN?你需要明白风险与机遇

撰文:肖飒团队 来源Techub News专栏作者 随着科技的发展,我们正在日益进入一个资源相对过剩的时代,这使我们在日常生活中虽然支付了该部分资源的使用费,但却时常不能将其「物尽其用」,难免出现资源浪费。例如&#x…

【More Effective C++】条款19:了解临时对象的来源

临时对象:没有命名,不会出现在源代码中 帮助隐式类型转换成功而创建的对象 编译器创建一个类型为string的临时对象,以buffer作为参数,调用string的构造函数;str绑定到了这个临时对象上函数返回时,这个临时…

PHP+vue+mysql校园学生社团管理系统574cc

运行环境:phpstudy/wamp/xammp等 开发语言:php 后端框架:Thinkphp 前端框架:vue.js 服务器:apache 数据库:mysql 数据库工具:Navicat/phpmyadmin 前台功能: 首页:展示社团信息和活动…

C#,二项式系数(Binomial Coefficient)的七种算法与源代码

1 二项式系数(binomial coefficient) 二项式系数(binomial coefficient),或组合数,在数学里表达为:(1 x)ⁿ展开后x的系数(其中n为自然数)。从定义可看出二项式系数的值…

立体库库存数量统计(SCL代码)

立体库库存物体检测由光电开关完成,每个储物格都有一个检测光电。5*6的仓库需要30个光电检测开关组成检测矩阵。找出矩阵中的最大元素并返回其所在的行号和列号和我们今天介绍的算法有很多相似的地方,大家可以对比学习。具体链接地址如下: h…

Vue中@change、@input和@blur的区别及@keyup介绍

Vue中change、input和blur、focus的区别及keyup介绍 1. change、input、blur、focus事件2. keyup事件3. 补充:el-input的change事件自定义传参 1. change、input、blur、focus事件 change在输入框发生变化且失去焦点后触发; input在输入框内容发生变化后…

conda env退回到之前的版本

默认显示的是 base 环境的历史记录 conda list --revisions 回到第 N 个版本 conda install --revision N 显示指定环境的修改记录 conda list -n env_name -r

机器学习3----决策树

这是前期准备 import numpy as np import pandas as pd import matplotlib.pyplot as plt #ID3算法 #每个特征的信息熵 # target : 账号是否真实,共2种情况 # yes 7个 p0.7 # no 3个 p0.3 info_D-(0.7*np.log2(0.7)0.3*np.log2(0.3)) info_D #日志密度…

Rust 原生类型

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 文章目录 前言一、标量类型(scalar type)二、 复合类型(compound type)总结 前言 Rust 学习系列 ,rust中的原生类…

算法学习——LeetCode力扣回溯篇4

算法学习——LeetCode力扣回溯篇4 332. 重新安排行程 332. 重新安排行程 - 力扣(LeetCode) 描述 给你一份航线列表 tickets ,其中 tickets[i] [fromi, toi] 表示飞机出发和降落的机场地点。请你对该行程进行重新规划排序。 所有这些机票…

django中的中间件

在Django中,中间件(Middleware)是一个轻量级的、底层的“插件”系统,用于全局地修改Django的输入或输出。每个中间件组件都负责执行一些特定的任务,比如检查用户是否登录、处理日志、GZIP压缩等。Django的中间件提供了…

Vulnhub靶机:DC4

一、介绍 运行环境:Virtualbox 攻击机:kali(10.0.2.15) 靶机:DC4(10.0.2.57) 目标:获取靶机root权限和flag 靶机下载地址:https://www.vulnhub.com/entry/dc-4,313/…

Midjourney绘图欣赏系列(一)

Midjourney介绍 Midjourney 是生成式人工智能的一个很好的例子,它根据文本提示创建图像。它与 Dall-E 和 Stable Diffusion 一起成为最流行的 AI 艺术创作工具之一。与竞争对手不同,Midjourney 是自筹资金且闭源的,因此确切了解其幕后内容尚不…

【从Python基础到深度学习】7. 使用scp命令实现主机间通讯

一、生成 SSH 密钥对 ssh-keygen 是一个用于生成 SSH 密钥对的命令行工具,用于身份验证和加密通信 ssh-keygen 二、将本地主机上的 SSH 公钥添加到远程主机 ssh-copy-id 命令用于将本地主机上的 SSH 公钥添加到远程主机上的 authorized_keys 文件中,…

【初学者必看】迈入Midjourney的艺术世界:轻松掌握Midjourney的注册与订阅!

文章目录 前言一、Midjourney是什么二、Midjourney注册三、新建自己的服务器四、开通订阅 前言 AI绘画即指人工智能绘画,是一种计算机生成绘画的方式。是AIGC应用领域内的一大分支。 AI绘画主要分为两个部分,一个是对图像的分析与判断,即…

MySQL定时备份及清理脚本(一劳永逸)-改良版本

一 创建备份路径 cd /mysql-backup mkdir back cd back 二 创建日志文件 vi mysql-backlog.log 内容为空,保存 三 创建备份脚本 vi save-all-data.sh#!/bin/bash #source /etc/profile user"root" password"LXYlxy2:024.#8u}" host"127…

QlikSense财务聚合函数:IRR/NPV/XIRR/XNPV

IRR - 脚本函数 IRR() 函数用于返回聚合内部回报率,以揭示迭代于 group by 子句定义的大量记录上的表达式的数值表示的现金流系列。 这些现金流不必是均值,因为它们可用于年金。但是,现金流必须定期出现,例如每月或每年。内部收…

LeetCode879. Profitable Schemes——动态规划

文章目录 一、题目二、题解 一、题目 There is a group of n members, and a list of various crimes they could commit. The ith crime generates a profit[i] and requires group[i] members to participate in it. If a member participates in one crime, that member ca…