【Pytorch】RNN for Image Classification

在这里插入图片描述

文章目录

  • 1 RNN 的定义
  • 2 RNN 输入 input, h_0
  • 3 RNN 输出 output, h_n
  • 4 多层
  • 5 小试牛刀

学习参考来自

  • pytorch中nn.RNN()总结
  • RNN for Image Classification(RNN图片分类–MNIST数据集)
  • pytorch使用-nn.RNN
  • Building RNNs is Fun with PyTorch and Google Colab

1 RNN 的定义

在这里插入图片描述

nn.RNN(input_size, hidden_size, num_layers=1, nonlinearity=tanh, bias=True, batch_first=False, dropout=0, bidirectional=False)

参数说明

  • input_size输入特征的维度, 一般 rnn 中输入的是词向量,那么 input_size 就等于一个词向量的维度
  • hidden_size隐藏层神经元个数,或者也叫输出的维度(因为rnn输出为各个时间步上的隐藏状态)
  • num_layers网络的层数
  • nonlinearity激活函数
  • bias是否使用偏置
  • batch_first输入数据的形式,默认是 False,就是这样形式,(seq(num_step), batch, input_dim),也就是将序列长度放在第一位,batch 放在第二位
  • dropout是否应用dropout, 默认不使用,如若使用将其设置成一个0-1的数字即可
  • birdirectional是否使用双向的 rnn,默认是 False

2 RNN 输入 input, h_0

input 形状: 当设置 batch_first = False 时, ( L , N , H i n ) (L , N , H_{ i n}) (L,N,Hin) —— [时间步数, 批量大小, 特征维度]

当设置 batch_first = True时, ( N , L , H i n ) (N , L , H_{ i n}) (N,L,Hin)

当输入只有两个维度且 batch_size 为 1 时 : ( L , H i n ) (L, H_{in}) (L,Hin) 时,需要调用 torch.unsqueeze() 增加维度。

h_0 形状: ( D ∗ n u m _ l a y e r s , N , H o u t ) ( D ∗ n u m \_ l a y e r s , N , H _{o u t} ) (Dnum_layers,N,Hout), D 代表单向 RNN 还是双向 RNN。

在这里插入图片描述

3 RNN 输出 output, h_n

output 形状:当设置 batch_first = False 时, ( L , N , D ∗ H o u t ) (L, N, D * H_{out}) (L,N,DHout)—— [时间步数, 批量大小, 隐藏单元个数];

当设置 batch_first = True 时, ( N , L , D ∗ H o u t ) (N, L, D * H_{out}) (N,L,DHout)

h_n 形状 ( D ∗ num_layers , N , H o u t ) (D * \text{num\_layers}, N, H_{out}) (Dnum_layers,N,Hout)

4 多层

在这里插入图片描述

5 小试牛刀

在这里插入图片描述
如MNIST中28行看成28个序列, 每个序列有28个特征

在这里插入图片描述
x_0 到 x_27, 相当于依次输入图像的28行

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt# -------------
# MNIST dataset
# -------------
batch_size = 128
train_dataset = torchvision.datasets.MNIST(root='./',train=True,transform=transforms.ToTensor(),download=True)
test_dataset = torchvision.datasets.MNIST(root='./',train=False,transform=transforms.ToTensor())
# Data loader
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,batch_size=batch_size,shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,batch_size=batch_size,shuffle=False)# ---------------------
# Exploring the dataset
# ---------------------
# function to show an image
def imshow(img):npimg = img.numpy()plt.imshow(np.transpose(npimg, (1, 2, 0)))# get some random training images
dataiter = iter(train_loader)
images, labels = dataiter.next()if 1:# show imageimshow(torchvision.utils.make_grid(images, nrow=15))plt.show()# ----------
# parameters
# ----------
N_STEPS = 28
N_INPUTS = 28  # 输入数据的维度
N_NEURONS = 150  # RNN中间的特征的大小
N_OUTPUT = 10  # 输出数据的维度(分类的个数)
N_EPHOCS = 10  # epoch的大小
N_LAYERS = 3# ------
# models
# ------
class ImageRNN(nn.Module):def __init__(self, batch_size, n_inputs, n_neurons, n_outputs, n_layers):super(ImageRNN, self).__init__()self.batch_size = batch_size  # 输入的时候batch_size, 128self.n_inputs = n_inputs  # 输入的维度, 28self.n_outputs = n_outputs  # 分类的大小 10self.n_neurons = n_neurons  # RNN中输出的维度 150self.n_layers = n_layers  # RNN中的层数 3self.basic_rnn = nn.RNN(self.n_inputs, self.n_neurons, num_layers=self.n_layers)self.FC = nn.Linear(self.n_neurons, self.n_outputs)def init_hidden(self):# (num_layers, batch_size, n_neurons)# initialize hidden weights with zero values# 这个是net的memory, 初始化memory为0return (torch.zeros(self.n_layers, self.batch_size, self.n_neurons).to(device))def forward(self, x):  # torch.Size([128, 28, 28])# transforms x to dimensions : n_step × batch_size × n_inputsx = x.permute(1, 0, 2)  # 需要把n_step放在第一个, torch.Size([28, 128, 28])self.batch_size = x.size(1)  # 每次需要重新计算batch_size, 因为可能会出现不能完整方下一个batch的情况 128self.hidden = self.init_hidden()  # 初始化hidden state  torch.Size([3, 128, 150])rnn_out, self.hidden = self.basic_rnn(x, self.hidden)  # 前向传播  torch.Size([28, 128, 150]), torch.Size([3, 128, 150])out = self.FC(rnn_out[-1])  # 求出每一类的概率 torch.Size([128, 150])->torch.Size([128, 10])return out.view(-1, self.n_outputs)  # 最终输出大小 : batch_size X n_output  torch.Size([128, 10])# --------------------
# Device configuration
# --------------------
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')# ------------------------------------
# Test the model(输入一张图片查看输出)
# ------------------------------------
# 定义模型
model = ImageRNN(batch_size, N_INPUTS, N_NEURONS, N_OUTPUT, N_LAYERS).to(device)
print(model)
"""
ImageRNN((basic_rnn): RNN(28, 150, num_layers=3)(FC): Linear(in_features=150, out_features=10, bias=True)
)
"""# 初始化模型的weight
model.basic_rnn.weight_hh_l0.data = torch.eye(n=N_NEURONS, m=N_NEURONS, out=None).to(device)
model.basic_rnn.weight_hh_l1.data = torch.eye(n=N_NEURONS, m=N_NEURONS, out=None).to(device)
model.basic_rnn.weight_hh_l2.data = torch.eye(n=N_NEURONS, m=N_NEURONS, out=None).to(device)# 定义数据
dataiter = iter(train_loader)
images, labels = dataiter.next()
model.hidden = model.init_hidden()
logits = model(images.view(-1, 28, 28).to(device))
print(logits[0:2])
"""
tensor([[-0.2846, -0.1503, -0.1593,  0.5478,  0.6827,  0.3489, -0.2989,  0.4575,-0.2426, -0.0464],[-0.6708, -0.3025, -0.0205,  0.2242,  0.8470,  0.2654, -0.0381,  0.6646,-0.4479,  0.2523]], device='cuda:0', grad_fn=<SliceBackward>)
"""# 产生对角线是1的矩阵
torch.eye(n=5, m=5, out=None)
"""
tensor([[1., 0., 0., 0., 0.],[0., 1., 0., 0., 0.],[0., 0., 1., 0., 0.],[0., 0., 0., 1., 0.],[0., 0., 0., 0., 1.]])
"""# --------
# Training
# --------
model = ImageRNN(batch_size, N_INPUTS, N_NEURONS, N_OUTPUT, N_LAYERS).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
# 初始化模型的weight
model.basic_rnn.weight_hh_l0.data = torch.eye(n=N_NEURONS, m=N_NEURONS, out=None).to(device)
model.basic_rnn.weight_hh_l1.data = torch.eye(n=N_NEURONS, m=N_NEURONS, out=None).to(device)
model.basic_rnn.weight_hh_l2.data = torch.eye(n=N_NEURONS, m=N_NEURONS, out=None).to(device)def get_accuracy(logit, target, batch_size):"""最后用来计算模型的准确率"""corrects = (torch.max(logit, 1)[1].view(target.size()).data == target.data).sum()accuracy = 100.0 * corrects/batch_sizereturn accuracy.item()# ---------
# 开始训练
# ---------
for epoch in range(N_EPHOCS):train_running_loss = 0.0train_acc = 0.0model.train()# trainging roundfor i, data in enumerate(train_loader):optimizer.zero_grad()# reset hidden statesmodel.hidden = model.init_hidden()# get inputsinputs, labels = datainputs = inputs.view(-1, 28, 28).to(device)labels = labels.to(device)# forward+backward+optimizeoutputs = model(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()train_running_loss = train_running_loss + loss.detach().item()train_acc = train_acc + get_accuracy(outputs, labels, batch_size)model.eval()print('Epoch : {:0>2d} | Loss : {:<6.4f} | Train Accuracy : {:<6.2f}%'.format(epoch, train_running_loss/i, train_acc/i))# ----------------------------------------
# Computer accuracy on the testing dataset
# ----------------------------------------
test_acc = 0.0
for i,data in enumerate(test_loader,0):inputs, labels = datalabels = labels.to(device)inputs = inputs.view(-1,28,28).to(device)outputs = model(inputs)thisBatchAcc = get_accuracy(outputs, labels, batch_size)print("Batch:{:0>2d}, Accuracy : {:<6.4f}%".format(i,thisBatchAcc))test_acc = test_acc + thisBatchAcc
print('============平均准确率===========')
print('Test Accuracy : {:<6.4f}%'.format(test_acc/i))
"""
Epoch : 00 | Loss : 0.6336 | Train Accuracy : 79.32 %
Epoch : 01 | Loss : 0.2363 | Train Accuracy : 93.00 %
Epoch : 02 | Loss : 0.1852 | Train Accuracy : 94.63 %
Epoch : 03 | Loss : 0.1516 | Train Accuracy : 95.69 %
Epoch : 04 | Loss : 0.1338 | Train Accuracy : 96.13 %
Epoch : 05 | Loss : 0.1198 | Train Accuracy : 96.67 %
Epoch : 06 | Loss : 0.1254 | Train Accuracy : 96.46 %
Epoch : 07 | Loss : 0.1128 | Train Accuracy : 96.88 %
Epoch : 08 | Loss : 0.1059 | Train Accuracy : 97.09 %
Epoch : 09 | Loss : 0.1048 | Train Accuracy : 97.10 %
Batch:00, Accuracy : 98.4375%
Batch:01, Accuracy : 98.4375%
Batch:02, Accuracy : 95.3125%
Batch:03, Accuracy : 98.4375%
Batch:04, Accuracy : 96.8750%
Batch:05, Accuracy : 93.7500%
Batch:06, Accuracy : 97.6562%
Batch:07, Accuracy : 95.3125%
Batch:08, Accuracy : 94.5312%
Batch:09, Accuracy : 92.9688%
Batch:10, Accuracy : 96.0938%
Batch:11, Accuracy : 96.0938%
Batch:12, Accuracy : 97.6562%
Batch:13, Accuracy : 96.8750%
Batch:14, Accuracy : 96.0938%
Batch:15, Accuracy : 95.3125%
Batch:16, Accuracy : 95.3125%
Batch:17, Accuracy : 96.0938%
Batch:18, Accuracy : 96.0938%
Batch:19, Accuracy : 97.6562%
Batch:20, Accuracy : 97.6562%
Batch:21, Accuracy : 98.4375%
Batch:22, Accuracy : 96.0938%
Batch:23, Accuracy : 96.8750%
Batch:24, Accuracy : 97.6562%
Batch:25, Accuracy : 99.2188%
Batch:26, Accuracy : 96.0938%
Batch:27, Accuracy : 94.5312%
Batch:28, Accuracy : 98.4375%
Batch:29, Accuracy : 94.5312%
Batch:30, Accuracy : 96.0938%
Batch:31, Accuracy : 93.7500%
Batch:32, Accuracy : 96.8750%
Batch:33, Accuracy : 96.0938%
Batch:34, Accuracy : 95.3125%
Batch:35, Accuracy : 96.8750%
Batch:36, Accuracy : 97.6562%
Batch:37, Accuracy : 93.7500%
Batch:38, Accuracy : 94.5312%
Batch:39, Accuracy : 100.0000%
Batch:40, Accuracy : 99.2188%
Batch:41, Accuracy : 100.0000%
Batch:42, Accuracy : 98.4375%
Batch:43, Accuracy : 98.4375%
Batch:44, Accuracy : 96.8750%
Batch:45, Accuracy : 99.2188%
Batch:46, Accuracy : 96.0938%
Batch:47, Accuracy : 98.4375%
Batch:48, Accuracy : 97.6562%
Batch:49, Accuracy : 100.0000%
Batch:50, Accuracy : 99.2188%
Batch:51, Accuracy : 91.4062%
Batch:52, Accuracy : 96.8750%
Batch:53, Accuracy : 99.2188%
Batch:54, Accuracy : 99.2188%
Batch:55, Accuracy : 100.0000%
Batch:56, Accuracy : 98.4375%
Batch:57, Accuracy : 98.4375%
Batch:58, Accuracy : 97.6562%
Batch:59, Accuracy : 100.0000%
Batch:60, Accuracy : 99.2188%
Batch:61, Accuracy : 96.0938%
Batch:62, Accuracy : 100.0000%
Batch:63, Accuracy : 97.6562%
Batch:64, Accuracy : 97.6562%
Batch:65, Accuracy : 96.8750%
Batch:66, Accuracy : 98.4375%
Batch:67, Accuracy : 100.0000%
Batch:68, Accuracy : 100.0000%
Batch:69, Accuracy : 100.0000%
Batch:70, Accuracy : 96.8750%
Batch:71, Accuracy : 98.4375%
Batch:72, Accuracy : 100.0000%
Batch:73, Accuracy : 99.2188%
Batch:74, Accuracy : 100.0000%
Batch:75, Accuracy : 96.0938%
Batch:76, Accuracy : 95.3125%
Batch:77, Accuracy : 96.8750%
Batch:78, Accuracy : 12.5000%
============平均准确率===========
Test Accuracy : 97.4559%
# """# 定义hook
class SaveFeatures():"""注册hook和移除hook"""def __init__(self, module):self.hook = module.register_forward_hook(self.hook_fn)def hook_fn(self, module, input, output):self.features = outputdef close(self):self.hook.remove()# 绑定到model上
activations = SaveFeatures(model.basic_rnn)# 定义数据
dataiter = iter(train_loader)
images, labels = dataiter.next()# 前向传播
model.hidden = model.init_hidden()
logits = model(images.view(-1, 28, 28).to(device))
activations.close()  # 移除hook# 这个是 28(step)*128(batch_size)*150(hidden_size)
print(activations.features[0].shape)
# torch.Size([28, 128, 150])
print(activations.features[0][-1].shape)
# torch.Size([128, 150])

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/44717.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【ECCV 2024】首个跨模态步态识别框架:Camera-LiDAR Cross-modality Gait Recognition

【ECCV 2024】首个跨模态步态识别框架&#xff1a;Camera-LiDAR Cross-modality Gait Recognition 简介&#xff1a;主要方法&#xff1a;实验结果&#xff1a; 论文&#xff1a;https://arxiv.org/abs/2407.02038 简介&#xff1a; 步态识别是一种重要的生物特征识别技术。基…

算法力扣刷题记录 四十一【N叉树遍历】

前言 依然是遍历问题。由二叉树扩展到N叉树遍历。 记录 四十一【N叉树遍历】 一、【589. N叉树的前序遍历】 题目 给定一个 n 叉树的根节点 root &#xff0c;返回 其节点值的 前序遍历 。 n 叉树 在输入中按层序遍历进行序列化表示&#xff0c;每组子节点由空值 null 分隔…

第十八章 Express multer 文件上传

本章将学习Express multer 文件上传 &#xff0c;因为Nest 的文件上传是基于 Express 的中间件 multer 实现的&#xff0c;所以在学习 Nest 文件上传之前&#xff0c;我们先学习下 multer 包 首先先创建 multer-test 文件夹执行下面代码 创建package.json npm init -y接着安装…

深入浅出 Spring @Async 异步编程的艺术

目录 一、异步编程 二、Async 介绍 2.1 Async 使用 三、Async 原理 一、异步编程 在软件开发中&#xff0c;异步编程是非常关键的&#xff0c;尤其是构建高性能、高响应度的应用时。异步编程的主要优势在于它能够避免阻塞操作&#xff0c;提高程序的效率和用户体验。异步编…

修BUG:程序包javax.servlet.http不存在

貌似昨晚上并没有成功在tomcat上面运行&#xff0c;而是直接运行了网页。 不知道为啥又报错这个。。。 解决方案&#xff1a; https://developer.baidu.com/article/details/2768022 就整了这一步就行了 而且我本地就有这个tomcat就是加进去了。 所以说啊&#xff0c;是不是&a…

eNSP公司管理的对象及策略

拓扑图[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 ) 实验需求 第一步&#xff1a;根据题目搭建拓扑图 其中交换机的型号为&#xff1a;S5700 防火墙设备为&#xff1a;USG6000V 第二步&#xff1a;启动防火墙设备 首先会让你输入密码&#xff0c;…

【MySQL】常见的MySQL日志都有什么用?

MySQL日志的内容非常重要&#xff0c;面试中经常会被问到。同时&#xff0c;掌握日志相关的知识也有利于我们理解MySQL 底层原理&#xff0c;必要时帮助我们排查解决问题。 MySQL中常见的日志类型主要有下面几类(针对的是InnoDB 存储引擎): 错误日志(error log):对 MySQL 的启…

CentOS 6.5配置国内在线yum源和制作openssh 9.8p1 rpm包 —— 筑梦之路

CentOS 6.5比较古老的版本了&#xff0c;而还是有一些古老的项目仍然在使用。 环境说明 1. 更换国内在线yum源 CentOS 6 在线可用yum源配置——筑梦之路_centos6可用yum源-CSDN博客 cat > CentOS-163.repo << EOF [base] nameCentOS-$releasever - Base - 163.com …

新兴市场游戏产业爆发 传音以技术抢抓机遇 ​

随着年轻人口的增加以及互联网的普及,非洲、中东等新兴市场正迎来游戏产业的大爆发,吸引着全球游戏企业玩家在此开疆辟土。中国出海企业代表传音以新兴市场需求为中心,秉持本地化创新理念不断加强游戏等关键领域技术攻关凭借移动终端设备为全球玩家带来极致游戏体验,收获了消费…

就业平台小程序的设计

管理员账户功能包括&#xff1a;系统首页&#xff0c;个人中心&#xff0c;学生管理&#xff0c;企业管理&#xff0c;企业类型管理&#xff0c;留言板管理&#xff0c;系统管理 微信端账号功能包括&#xff1a;系统首页&#xff0c;招聘信息&#xff0c;简历&#xff0c;我的…

MapReduce底层原理详解:大案例解析(第32天)

系列文章目录 一、MapReduce概述 二、MapReduce工作机制 三、Map&#xff0c;Shuffle&#xff0c;reduce阶段详解 四、大案例解析 文章目录 系列文章目录前言一、MapReduce概述二、MapReduce工作机制1. 角色与组件2. 作业提交与执行流程1. 作业提交&#xff1a;2. Map阶段&…

MATLAB中c2d函数用法

目录 语法 说明 示例 在MATLAB中&#xff0c;c2d函数用于将连续时间系统&#xff08;Continuous-Time System&#xff09;转换为离散时间系统&#xff08;Discrete-Time System&#xff09;。以下是c2d函数的基本语法、说明以及示例&#xff1a; 语法 sys_d c2d(sys_c, T…

【每天认识一个漏洞】spf邮件伪造漏洞

&#x1f31d;博客主页&#xff1a;泥菩萨 &#x1f496;专栏&#xff1a;Linux探索之旅 | 网络安全的神秘世界 | 专接本 | 每天学会一个渗透测试工具 &#x1f3a3;漏洞危害 允许攻击者伪造发件人身份&#xff0c;从而发送钓鱼邮件或垃圾邮件&#xff0c;获取接收方的信任&am…

[leetcode]partition-list 分隔链表

. - 力扣&#xff08;LeetCode&#xff09; class Solution { public:ListNode* partition(ListNode* head, int x) {ListNode *smlDummy new ListNode(0), *bigDummy new ListNode(0);ListNode *sml smlDummy, *big bigDummy;while (head ! nullptr) {if (head->val &l…

YOLOv10改进 | 添加注意力机制 | 添加ACmix自注意力与卷积混合模型改善模型特征识别效率(包含二次创新PSA机制)

一、本文介绍 本文给大家带来的改进机制是ACmix自注意力机制的改进版本&#xff0c;它的核心思想是&#xff0c;传统卷积操作和自注意力模块的大部分计算都可以通过1x1的卷积来实现。ACmix首先使用1x1卷积对输入特征图进行投影&#xff0c;生成一组中间特征&#xff0c;然后根…

JavaScript中的Symbol类型是什么以及它的作用

聚沙成塔每天进步一点点 本文回顾 ⭐ 专栏简介JavaScript中的Symbol类型是什么以及它的作用1. 符号&#xff08;Symbol&#xff09;的创建2. 符号的特性3. 符号的作用3.1 属性名的唯一性3.2 防止属性被意外访问或修改3.3 使用内置的符号3.4 符号与属性遍历 4. 总结 ⭐ 写在最后…

网络协议(TCP三次握手,四次断开详解)

TCP的详细过程&#xff1a; TCP&#xff08;传输控制协议&#xff09;的三次握手和四次断开是其建立连接和终止连接的重要过程&#xff0c;以下是详细解释&#xff1a; 三次握手&#xff1a; 1. 第一次握手&#xff1a;客户端向服务器发送一个 SYN&#xff08;同步&#x…

Flask 用 Redis 缓存键值对-实例

Flask 使用起 Redis 来简直就是手到擒来&#xff0c;比 MySQL 简单多了&#xff0c;不需要那么多配置&#xff0c;实际代码就这么多&#xff0c;直接复制就能用。除了提供简单实用的实例以外&#xff0c;本文后面还会简单介绍一下 Redis 的安装与使用&#xff0c;初学者也能一看…

Linux笔记之三

Linux笔记之三 一、用户组管理二、磁盘管理三、进程管理总结 一、用户组管理 每个用户都有一个用户组&#xff0c;系统可以对一个用户组中的所有用户进行集中管理&#xff08;开发、测试、运维、root&#xff09;。不同Linux系统对用户组的管理涉及用户组的添加、删除和修改。…

8. Python3 pandas数据分析处理库

11.1 pandas的数据结构 pandas的数据结构如下图所示&#xff1a; pandas的几种数据结构有内在联系&#xff0c;可以吧DataFrame看作Series的容器&#xff0c;把Panel看作DataFrame的容器。可以像操作字典那样在这些数据结构中插入或者移除数据对象。在介绍这些数据结构之前&am…