【Pytorch神经网络实战案例】12 利用注意力机制的神经网络实现对FashionMNIST数据集图片的分类

1、掩码模式:是相对于变长的循环序列而言的,如果输入的样本序列长度不同,那么会先对其进行对齐处理(对短序列补0,对长序列截断),再输入模型。这样,模型中的部分样本中就会有大量的零值。为了提升运算性能,需要以掩码的方式将不需要的零值去掉,并保留非零值进行计算,这就是掩码的作用
2、均值模式:正常模式对每个维度的所有序列计算注意力分数,而均值模式对每个维度注意力分数计算平均值。均值模式会平滑处理同一序列不同维度之间的差异,认为所有维度都是平等的,将注意力用在序列之间。这种方式更能体现出序列的重要性。


 代码 Attention_cclassification.py

import torchvision
import torchvision.transforms as tranforms
import pylab
import torch
from matplotlib import pyplot as plt
import numpy as np
import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'  # 可能是由于是MacOS系统的原因data_dir = './fashion_mnist'
tranform = tranforms.Compose([tranforms.ToTensor()])
train_dataset = torchvision.datasets.FashionMNIST(root=data_dir,train=True,transform=tranform,download=True)
print("训练数据集条数",len(train_dataset))
val_dataset = torchvision.datasets.FashionMNIST(root=data_dir, train=False, transform=tranform)
print("测试数据集条数",len(val_dataset))
im = train_dataset[0][0]
im = im.reshape(-1,28)
pylab.imshow(im)
pylab.show()
print("该图片的标签为:",train_dataset[0][1])## 数据集的制造
batch_size = 10
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False)def imshow(img):print("图片形状:",np.shape(img))npimg = img.numpy()plt.axis('off')plt.imshow(np.transpose(npimg,(1,2,0)))classes = ('T-shirt', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle_Boot')
sample = iter(train_loader)
images,labels = sample.next()
print("样本形状:",np.shape(images))
print("样本标签",labels)
imshow(torchvision.utils.make_grid(images,nrow=batch_size))
print(','.join('%5s' % classes[labels[j]] for j in range(len(images))))class myLSTMNet(torch.nn.Module): #定义myLSTMNet模型类,该模型包括 2个RNN层和1个全连接层def __init__(self,in_dim, hidden_dim, n_layer, n_class):super(myLSTMNet, self).__init__()# 定义循环神经网络层self.lstm = torch.nn.LSTM(in_dim, hidden_dim, n_layer, batch_first=True)self.Linear = torch.nn.Linear(hidden_dim * 28, n_class)  # 定义全连接层self.attention = AttentionSeq(hidden_dim, hard=0.03) # 定义注意力层,使用硬模式的注意力机制def forward(self, t):  # 搭建正向结构t, _ = self.lstm(t)  # 使用LSTM对象进行RNN数据处理t = self.attention(t)   # 对循环神经网络结果进行注意力机制的处理,将处理后的结果变形为二维数据,传入全连接输出层。1t = t.reshape(t.shape[0], -1) # 对循环神经网络结果进行注意力机制的处理,将处理后的结果变形为二维数据,传入全连接输出层。2out = self.Linear(t)  # 进行全连接处理return outclass AttentionSeq(torch.nn.Module):def __init__(self, hidden_dim, hard=0.0): # 初始化super(AttentionSeq, self).__init__()self.hidden_dim = hidden_dimself.dense = torch.nn.Linear(hidden_dim, hidden_dim)self.hard = harddef forward(self, features, mean=False): # 类的处理方法# [batch,seq,dim]batch_size, time_step, hidden_dim = features.size()weight = torch.nn.Tanh()(self.dense(features)) # 全连接计算# 计算掩码,mask给负无穷使得权重为0mask_idx = torch.sign(torch.abs(features).sum(dim=-1))# mask_idx = mask_idx.unsqueeze(-1).expand(batch_size, time_step, hidden_dim)mask_idx = mask_idx.unsqueeze(-1).repeat(1, 1, hidden_dim)# 将掩码作用在注意力结果上# torch.where函数的意思是按照第一参数的条件对每个元素进行检查,如果满足,那么使用第二个参数里对应元素的值进行填充,如果不满足,那么使用第三个参数里对应元素的值进行填充。# torch.ful_likeO函数是按照张量的形状进行指定值的填充,其第一个参数是参考形状的张量,第二个参数是填充值。weight = torch.where(mask_idx == 1, weight,torch.full_like(mask_idx, (-2 ** 32 + 1))) # 利用掩码对注意力结果补0序列填充一个极小数,会在Softmax中被忽略为0weight = weight.transpose(2, 1)# 必须对注意力结果补0序列填充一个极小数,千万不能填充0,因为注意力结果是经过激活函数tanh()计算出来的,其值域是 - 1~1, 在这个区间内,零值是一个有效值。如果填充0,那么会对后面的Softmax结果产生影响。填充的值只有远离这个有效区间才可以保证被Softmax的结果忽略。weight = torch.nn.Softmax(dim=2)(weight) # 计算注意力分数if self.hard != 0:  # hard modeweight = torch.where(weight > self.hard, weight, torch.full_like(weight, 0))if mean: # 支持注意力分数平均值模式weight = weight.mean(dim=1)weight = weight.unsqueeze(1)weight = weight.repeat(1, hidden_dim, 1)weight = weight.transpose(2, 1)features_attention = weight * features # 将注意力分数作用于特征向量上return features_attention # 返回结果#实例化模型对象
network = myLSTMNet(28, 128, 2, 10)  # 图片大小是28x28,28:输入数据的序列长度为28。128:每层放置128个LSTM Cell。2:构建两层由LSTM Cell所组成的网络。10:最终结果分为10类。
#指定设备
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)
network.to(device)
print(network)  #打印网络criterion = torch.nn.CrossEntropyLoss()  # 实例化损失函数类
optimizer = torch.optim.Adam(network.parameters(), lr=0.01)for epoch in range(2): # 数据集迭代2次running_loss = 0.0for i, data in enumerate(train_loader, 0): # 循环取出批次数据inputs, labels = datainputs = inputs.squeeze(1) # 由于输入数据是序列形式,不再是图片,因此将通道设为1inputs, labels = inputs.to(device), labels.to(device) # 指定设备optimizer.zero_grad() # 清空之前的梯度outputs = network(inputs)loss = criterion(outputs, labels) # 计算损失loss.backward()  #反向传播optimizer.step() #更新参数running_loss += loss.item()if i % 1000 == 999:print('[%d, %5d] loss: %.3f' %(epoch + 1, i + 1, running_loss / 2000))running_loss = 0.0
print('Finished Training')#使用模型
dataiter = iter(test_loader)
images, labels = dataiter.next()
inputs, labels = images.to(device), labels.to(device)imshow(torchvision.utils.make_grid(images,nrow=batch_size))
print('真实标签: ', ' '.join('%5s' % classes[labels[j]] for j in range(len(images))))
inputs = inputs.squeeze(1)
outputs = network(inputs)
_, predicted = torch.max(outputs, 1)print('预测结果: ', ' '.join('%5s' % classes[predicted[j]]for j in range(len(images))))#测试模型
class_correct = list(0. for i in range(10))
class_total = list(0. for i in range(10))
with torch.no_grad():for data in test_loader:images, labels = dataimages = images.squeeze(1)inputs, labels = images.to(device), labels.to(device)outputs = network(inputs)_, predicted = torch.max(outputs, 1)predicted = predicted.to(device)c = (predicted == labels).squeeze()for i in range(10):label = labels[i]class_correct[label] += c[i].item()class_total[label] += 1sumacc = 0
for i in range(10):Accuracy = 100 * class_correct[i] / class_total[i]print('Accuracy of %5s : %2d %%' % (classes[i], Accuracy ))sumacc =sumacc+Accuracy
print('Accuracy of all : %2d %%' % ( sumacc/10. ))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/469342.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

爬虫实战学习笔记_4 网络请求urllib3模块:发送GET/POST请求实例+上传文件+IP代理+json+二进制+超时

1 urllib3模块简介 urllib3是一个第三方的网络请求模块(单独安装该模块),在功能上比Python自带的urllib强大。 1.1了解urllib3 urllib3库功能强大,条理清晰的用于HTTP客户端的python库,提供了很多Python标准库里所没…

C. Jon Snow and his Favourite Number DP + 注意数值大小

http://codeforces.com/contest/768/problem/C 这题的数值大小只有1000&#xff0c;那么可以联想到&#xff0c;用数值做数组的下标&#xff0c;就是类似于计数排序那样子。。 这样就可以枚举k次操作&#xff0c;然后for (int i 0; i < 1025; i)&#xff0c;也就是O(1000 *…

【Pytorch神经网络理论篇】 21 信息熵与互信息:联合熵+条件熵+交叉熵+相对熵/KL散度/信息散度+JS散度

同学你好&#xff01;本文章于2021年末编写&#xff0c;获得广泛的好评&#xff01; 故在2022年末对本系列进行填充与更新&#xff0c;欢迎大家订阅最新的专栏&#xff0c;获取基于Pytorch1.10版本的理论代码(2023版)实现&#xff0c; Pytorch深度学习理论篇(2023版)目录地址…

【Pytorch神经网络理论篇】 22 自编码神经网络:概述+变分+条件变分自编码神经网络

同学你好&#xff01;本文章于2021年末编写&#xff0c;获得广泛的好评&#xff01; 故在2022年末对本系列进行填充与更新&#xff0c;欢迎大家订阅最新的专栏&#xff0c;获取基于Pytorch1.10版本的理论代码(2023版)实现&#xff0c; Pytorch深度学习理论篇(2023版)目录地址…

【Pytorch神经网络实战案例】13 构建变分自编码神经网络模型生成Fashon-MNST模拟数据

1 变分自编码神经网络生成模拟数据案例说明 变分自编码里面真正的公式只有一个KL散度。 1.1 变分自编码神经网络模型介绍 主要由以下三个部分构成&#xff1a; 1.1.1 编码器 由两层全连接神经网络组成&#xff0c;第一层有784个维度的输入和256个维度的输出&#xff1b;第…

【Pytorch神经网络实战案例】14 构建条件变分自编码神经网络模型生成可控Fashon-MNST模拟数据

1 条件变分自编码神经网络生成模拟数据案例说明 在实际应用中&#xff0c;条件变分自编码神经网络的应用会更为广泛一些&#xff0c;因为它使得模型输出的模拟数据可控&#xff0c;即可以指定模型输出鞋子或者上衣。 1.1 案例描述 在变分自编码神经网络模型的技术上构建条件…

hibernate持久化对象

转载于:https://www.cnblogs.com/jianxin-lilang/p/6440101.html

【Pytorch神经网络理论篇】 23 对抗神经网络:概述流程 + WGAN模型 + WGAN-gp模型 + 条件GAN + WGAN-div + W散度

同学你好&#xff01;本文章于2021年末编写&#xff0c;获得广泛的好评&#xff01; 故在2022年末对本系列进行填充与更新&#xff0c;欢迎大家订阅最新的专栏&#xff0c;获取基于Pytorch1.10版本的理论代码(2023版)实现&#xff0c; Pytorch深度学习理论篇(2023版)目录地址…

【Pytorch神经网络实战案例】15 WGAN-gp模型生成Fashon-MNST模拟数据

1 WGAN-gp模型生成模拟数据案例说明 使用WGAN-gp模型模拟Fashion-MNIST数据的生成&#xff0c;会使用到WGAN-gp模型、深度卷积GAN(DeepConvolutional GAN&#xff0c;DCGAN)模型、实例归一化技术。 1.1 DCGAN中的全卷积 WGAN-gp模型侧重于GAN模型的训练部分&#xff0c;而DCG…

Android启动过程深入解析

转载自&#xff1a;http://blog.jobbole.com/67931/ 当按下Android设备电源键时究竟发生了什么&#xff1f;Android的启动过程是怎么样的&#xff1f;什么是Linux内核&#xff1f;桌面系统linux内核与Android系统linux内核有什么区别&#xff1f;什么是引导装载程序&#xff1…

【Pytorch神经网络实战案例】16 条件WGAN模型生成可控Fashon-MNST模拟数据

1 条件GAN前置知识 条件GAN也可以使GAN所生成的数据可控&#xff0c;使模型变得实用&#xff0c; 1.1 实验描述 搭建条件GAN模型&#xff0c;实现向模型中输入标签&#xff0c;并使其生成与标签类别对应的模拟数据的功能&#xff0c;基于WGAN-gp模型改造实现带有条件的wGAN-…

Android bootchart(二)

这篇文章讲一下MTK8127开机启动的时间 MTK8127发布版本开机时间大约在&#xff12;&#xff10;秒左右&#xff0c;如果发现开机时间变长&#xff0c;大部分是因为加上了客户订制的东西&#xff0c;代码累赘太多了。 &#xff11;、下面看一下&#xff2d;&#xff34;&#…

Android Camera框架

总体介绍 Android Camera 框架从整体上看是一个 client/service 的架构, 有两个进程: client 进程,可以看成是 AP 端,主要包括 JAVA 代码与一些 native c/c++代码; service 进 程,属于服务端,是 native c/c++代码,主要负责和 linux kernel 中的 camera driver 交互,搜集 li…

【Pytorch神经网络实战案例】17 带W散度的WGAN-div模型生成Fashon-MNST模拟数据

1 WGAN-div 简介 W散度的损失函数GAN-dv模型使用了W散度来替换W距离的计算方式&#xff0c;将原有的真假样本采样操作换为基于分布层面的计算。 2 代码实现 在WGAN-gp的基础上稍加改动来实现&#xff0c;重写损失函数的实现。 2.1 代码实战&#xff1a;引入模块并载入样本-…

【Pytorch神经网络理论篇】 24 神经网络中散度的应用:F散度+f-GAN的实现+互信息神经估计+GAN模型训练技巧

同学你好&#xff01;本文章于2021年末编写&#xff0c;获得广泛的好评&#xff01; 故在2022年末对本系列进行填充与更新&#xff0c;欢迎大家订阅最新的专栏&#xff0c;获取基于Pytorch1.10版本的理论代码(2023版)实现&#xff0c; Pytorch深度学习理论篇(2023版)目录地址…

【Pytorch神经网络实战案例】18 最大化深度互信信息模型DIM实现搜索最相关与最不相关的图片

图片搜索器分为图片的特征提取和匹配两部分&#xff0c;其中图片的特征提取是关键。将使用一种基于无监督模型的提取特征的方法实现特征提取&#xff0c;即最大化深度互信息&#xff08;DeepInfoMax&#xff0c;DIM&#xff09;方法。 1 最大深度互信信息模型DIM简介 在DIM模型…

【Pytorch神经网络实战案例】19 神经网络实现估计互信息的功能

1 案例说明&#xff08;实现MINE正方法的功能&#xff09; 定义两组具有不同分布的模拟数据&#xff0c;使用神经网络的MINE的方法计算两个数据分布之间的互信息 2 代码编写 2.1 代码实战&#xff1a;准备样本数据 import torch import torch.nn as nn import torch.nn.fun…

爬虫实战学习笔记_6 网络请求request模块:基本请求方式+设置请求头+获取cookies+模拟登陆+会话请求+验证请求+上传文件+超时异常

1 requests requests是Python中实现HTTP请求的一种方式&#xff0c;requests是第三方模块&#xff0c;该模块在实现HTTP请求时要比urlib、urllib3模块简化很多&#xff0c;操作更加人性化。 2 基本请求方式 由于requests模块为第三方模块&#xff0c;所以在使用requests模块时…

201521123044 《Java程序设计》第01周学习总结

1.本章学习总结 你对于本章知识的学习总结 1.了解了Java的发展史。 2.学习了什么是JVM,区分JRE与JDK,下载JDK。 3.从C语言的.c 到C的 .cpp再到Java的.java&#xff0c;每种语言编译程序各有不同&#xff0c;却有相似之处。 2. 书面作业 **Q1.为什么java程序可以跨平台运行&…

将一个java工程导入到myeclipse应该注意的地方

[原文]http://www.cnblogs.com/ht2411/articles/5471130.html 1. 最好新建一个myeclipse工程&#xff0c;然后从从文件系统导入该工程文件。 原因&#xff1a;很多项目可能是eclipse创建的&#xff0c;或者myeclipse的版本不一致&#xff0c;这样可能导致很多奇怪的现象&#x…