【Pytorch神经网络实战案例】10 搭建深度卷积神经网络

 
识别黑白图中的服装图案(Fashion-MNIST)https://blog.csdn.net/qq_39237205/article/details/123379997基于上述代码修改模型的组成

1 修改myConNet模型

1.1.1 修改阐述

将模型中的两个全连接层,变为全局平均池化层。

1.1.2 修改结果

### 1.5 定义模型类
class myConNet(torch.nn.Module):def __init__(self):super(myConNet, self).__init__()# 定义卷积层self.conv1 = torch.nn.Conv2d(in_channels = 1 ,out_channels = 6,kernel_size = 3)self.conv2 = torch.nn.Conv2d(in_channels = 6,out_channels = 12,kernel_size = 3)self.conv3 = torch.nn.Conv2d(in_channels = 12, out_channels=10, kernel_size = 3) # 分为10个类def forward(self,t):# 第一层卷积和池化处理t = self.conv1(t)t = F.relu(t)t = F.max_pool2d(t, kernel_size=2, stride=2)# 第二层卷积和池化处理t = self.conv2(t)t = F.relu(t)t = F.max_pool2d(t, kernel_size=2, stride=2)# 第三层卷积和池化处理t = self.conv3(t)t = F.avg_pool2d(t,kernel_size = t.shape[-2:],stride = t.shape[-2:]) # 设置池化区域为输入数据的大小(最后两个维度),完成全局平均化的处理。return t.reshape(t.shape[:2])

2 代码

import  torchvision
import torchvision.transforms as transforms
import pylab
import torch
from matplotlib import pyplot as plt
import torch.utils.data
import torch.nn.functional as F
import numpy as np
import os
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"# 定义显示图像的函数
def imshow(img):print("图片形状",np.shape(img))img = img/2 +0.5npimg = img.numpy()plt.axis('off')plt.imshow(np.transpose(npimg,(1,2,0)))### 1.1 自动下载FashionMNIST数据集
data_dir = './fashion_mnist' # 设置存放位置
transform = transforms.Compose([transforms.ToTensor()]) # 可以自动将图片转化为Pytorch支持的形状[通道,高,宽],同时也将图片的数值归一化
train_dataset = torchvision.datasets.FashionMNIST(data_dir,train=True,transform=transform,download=True)
print("训练集的条数",len(train_dataset))### 1.2 读取及显示FashionMNIST数据集中的数据
val_dataset = torchvision.datasets.FashionMNIST(root=data_dir,train=False,transform=transform)
print("测试集的条数",len(val_dataset))
##1.2.1 显示数据集中的数据
im = train_dataset[0][0].numpy()
im = im.reshape(-1,28)
pylab.imshow(im)
pylab.show()
print("当前图片的标签为",train_dataset[0][1])### 1.3 按批次封装FashionMNIST数据集
batch_size = 10 #设置批次大小
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False)### 1.4 读取批次数据集
## 定义类别名称
classes = ('T-shirt', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle_Boot')
sample = iter(train_loader) # 将数据集转化成迭代器
images,labels = sample.next() # 从迭代器中取得一批数据
print("样本形状",np.shape(images)) # 打印样本形状
# 输出 样本形状 torch.Size([10, 1, 28, 28])
print("样本标签",labels)
# 输出 图片形状 torch.Size([3, 32, 302])
imshow(torchvision.utils.make_grid(images,nrow = batch_size)) # 数据可视化:make_grid()将该批次的图片内容组合为一个图片,用于显示,nrow用于设置生成图片中每行的样本数量
print(','.join('%5s' % classes[labels[j]] for j in range(len(images))))
# 输出 Trouser,Trouser,Dress,  Bag,Shirt,Sandal,Shirt,Dress,  Bag,  Bag### 1.5 定义模型类
class myConNet(torch.nn.Module):def __init__(self):super(myConNet, self).__init__()# 定义卷积层self.conv1 = torch.nn.Conv2d(in_channels = 1 ,out_channels = 6,kernel_size = 3)self.conv2 = torch.nn.Conv2d(in_channels = 6,out_channels = 12,kernel_size = 3)self.conv3 = torch.nn.Conv2d(in_channels = 12, out_channels=10, kernel_size = 3) # 分为10个类def forward(self,t):# 第一层卷积和池化处理t = self.conv1(t)t = F.relu(t)t = F.max_pool2d(t, kernel_size=2, stride=2)# 第二层卷积和池化处理t = self.conv2(t)t = F.relu(t)t = F.max_pool2d(t, kernel_size=2, stride=2)# 第三层卷积和池化处理t = self.conv3(t)t = F.avg_pool2d(t,kernel_size = t.shape[-2:],stride = t.shape[-2:]) # 设置池化区域为输入数据的大小(最后两个维度),完成全局平均化的处理。return t.reshape(t.shape[:2])if __name__ == '__main__':network = myConNet() # 生成自定义模块的实例化对象#指定设备device = torch.device("cuda:0"if torch.cuda.is_available() else "cpu")print(device)network.to(device)print(network) # 打印myConNet网络
### 1.6 损失函数与优化器criterion = torch.nn.CrossEntropyLoss()  #实例化损失函数类optimizer = torch.optim.Adam(network.parameters(), lr=.01)
### 1.7 训练模型for epoch in range(2):  # 数据集迭代2次running_loss = 0.0for i, data in enumerate(train_loader, 0):  # 循环取出批次数据 使用enumerate()函数对循环计数,第二个参数为0,表示从0开始inputs, labels = datainputs, labels = inputs.to(device), labels.to(device)  #optimizer.zero_grad()  # 清空之前的梯度outputs = network(inputs)loss = criterion(outputs, labels)  # 计算损失loss.backward()  # 反向传播optimizer.step()  # 更新参数running_loss += loss.item()### 训练过程的显示if i % 1000 == 999:print('[%d, %5d] loss: %.3f' %(epoch + 1, i + 1, running_loss / 2000))running_loss = 0.0print('Finished Training')
### 1.8 保存模型torch.save(network.state_dict(),'./models/CNNFashionMNist.PTH')### 1.9 加载模型,并且使用该模型进行预测network.load_state_dict(torch.load('./models/CNNFashionMNist.PTH')) # 加载模型# 使用模型dataiter = iter(test_loader) # 获取测试数据images, labels = dataiter.next()inputs, labels = images.to(device), labels.to(device)imshow(torchvision.utils.make_grid(images, nrow=batch_size)) # 取出一批数据进行展示print('真实标签: ', ' '.join('%5s' % classes[labels[j]] for j in range(len(images))))# 输出:真实标签:  Ankle_Boot Pullover Trouser Trouser Shirt Trouser  Coat Shirt Sandal Sneakeroutputs = network(inputs) # 调用network对输入样本进行预测,得到测试结果outputs_, predicted = torch.max(outputs, 1) # 对于预测结果outputs沿着第1维度找出最大值及其索引值,该索引值即为预测的分类结果print('预测结果: ', ' '.join('%5s' % classes[predicted[j]] for j in range(len(images))))# 输出:预测结果:  Ankle_Boot Pullover Trouser Trouser Pullover Trouser Shirt Shirt Sandal Sneaker### 1.10 评估模型# 测试模型class_correct = list(0. for i in range(10)) # 定义列表,收集每个类的正确个数class_total = list(0. for i in range(10)) # 定义列表,收集每个类的总个数with torch.no_grad():for data in test_loader: # 遍历测试数据集images, labels = datainputs, labels = images.to(device), labels.to(device)outputs = network(inputs) # 将每个批次的数据输入模型_, predicted = torch.max(outputs, 1) # 计算预测结果predicted = predicted.to(device)c = (predicted == labels).squeeze() # 统计正确的个数for i in range(10): # 遍历所有类别label = labels[i]class_correct[label] = class_correct[label] + c[i].item() # 若该类别正确则+1class_total[label] = class_total[label] + 1 # 根据标签中的类别,计算类的总数sumacc = 0for i in range(10): # 输出每个类的预测结果Accuracy = 100 * class_correct[i] / class_total[i]print('Accuracy of %5s : %2d %%' % (classes[i], Accuracy))sumacc = sumacc + Accuracyprint('Accuracy of all : %2d %%' % (sumacc / 10.)) # 输出最终的准确率

输出:

Accuracy of T-shirt : 72 %
Accuracy of Trouser : 96 %
Accuracy of Pullover : 75 %
Accuracy of Dress : 72 %
Accuracy of  Coat : 75 %
Accuracy of Sandal : 90 %
Accuracy of Shirt : 35 %
Accuracy of Sneaker : 93 %
Accuracy of   Bag : 92 %
Accuracy of Ankle_Boot : 92 %
Accuracy of all : 79 %

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/469369.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Service Worker,Web Worker,WebSocket的对比

Service Worker 处理网络请求的后台服务。适用于离线和后台同步数据或推送信息。不能直接和dom交互。通过postMessage方法交互。 Web Worker 模拟多线程,允许复杂计算功能的脚本在后台运行而不会阻碍到其他脚本的运行。适用于处理器占用量大而又不阻碍的情形。不能直…

MTK 升级USB问题

问题:我们的开发环境是ubuntu里面安装xp ,经常是xp下没有正常识别preload模式下的usb.这样肯定不能升级不了。 设置:MTK preload下的USB vid:0e8d pid:2000 revion 0100 知道这几个值了,在usb配置里面增加这个筛选项就可以了。

JAVA 8 StreamAPI 和 lambda表达式 总结(一)--lambda表达式

这些天看见另一本好书《给大忙人看的Java SE 8》,其中的新特性 StreamAPI 和 lambda表达式 是之前jdk没有提供的新特性,也是jdk8 重要的更新内容,我会总结一下它们的用法,更详细的参见书本。 lambda表达式的概念 人对一个概念的理…

【Pytorch神经网络理论篇】 14 过拟合问题的优化技巧(一):基本概念+正则化+数据增大

同学你好!本文章于2021年末编写,获得广泛的好评! 故在2022年末对本系列进行填充与更新,欢迎大家订阅最新的专栏,获取基于Pytorch1.10版本的理论代码(2023版)实现, Pytorch深度学习理论篇(2023版)目录地址…

MTK 8127平台使用busybox

一、什么是BusyBox ? BusyBox 是标准 Linux 工具的一个单个可执行实现。BusyBox 包含了一些简单的工具,例如 cat 和 echo,还包含了一些更大、更复杂的工具,例如 grep、find、mount 以及 telnet。有些人将 BusyBox 称为 Linux 工具…

MediaPlayer 播放视频的方法

MediaPlayer mediaPlayer new MediaPlayer(); mediaPlayer.reset();//重置为初始状态 mediaPlayer.setAudioStreamType(AudioManager.STREAM_MUSIC); mediaPlayer.setDisplay(surfaceView.getHolder());//设置画面显示为surfaceView mediaPlayer.setDataSource("/mnt/sdc…

Android bootchart分析

1.首先确保编译的init被烧录到板子里面去了,源码的位置在system/core/init/ 2.第一次修改后,编译了system/core/init/然后又编译了./mkkernel 生成boot.img 但是烧录进去还是不成功 3.然后 发现有一个宏没有设置 在bootchart.h里面,BOOTCHART 修改后重新编译,烧了所有…

【Pytorch神经网络理论篇】 15 过拟合问题的优化技巧(二):Dropout()方法

同学你好!本文章于2021年末编写,获得广泛的好评! 故在2022年末对本系列进行填充与更新,欢迎大家订阅最新的专栏,获取基于Pytorch1.10版本的理论代码(2023版)实现, Pytorch深度学习理论篇(2023版)目录地址…

bzoj2435: [Noi2011]道路修建 树上dp

点击打开链接 RE了一辈子... 思路&#xff1a;树上dp&#xff0c;直接dfs找到每个点v的子节点有多少&#xff0c; 那么对答案的贡献是 w*abs((n-size[v])-size[v]); RE代码&#xff1a; 1 #include <bits/stdc.h>2 using namespace std;3 typedef long long ll;4 const i…

android之APP模块编译

原文地址&#xff1a;http://blog.csdn.net/yaphet__s/article/details/45640627 一&#xff0c;如何把app编进系统 a.源码编译&#xff0c;在packages/apps目录下有安卓原生的app&#xff0c;以Bluetooth为例&#xff0c;源码根目录下有Android.mk文件&#xff1a; packages\a…

【Pytorch神经网络理论篇】 16 过拟合问题的优化技巧(三):批量归一化

同学你好&#xff01;本文章于2021年末编写&#xff0c;获得广泛的好评&#xff01; 故在2022年末对本系列进行填充与更新&#xff0c;欢迎大家订阅最新的专栏&#xff0c;获取基于Pytorch1.10版本的理论代码(2023版)实现&#xff0c; Pytorch深度学习理论篇(2023版)目录地址…

【Pytorch神经网络理论篇】 17 循环神经网络结构:概述+BP算法+BPTT算法

同学你好&#xff01;本文章于2021年末编写&#xff0c;获得广泛的好评&#xff01; 故在2022年末对本系列进行填充与更新&#xff0c;欢迎大家订阅最新的专栏&#xff0c;获取基于Pytorch1.10版本的理论代码(2023版)实现&#xff0c; Pytorch深度学习理论篇(2023版)目录地址…

Android 闹钟

需求&#xff1a;新的平台要实现关机启动&#xff0c;所以要了解一下闹钟的机制 这个链接写得比较详细&#xff08;我只是动手试了试&#xff0c;毕竟应用不是专长&#xff09;&#xff1a;http://www.cnblogs.com/mengdd/p/3819806.html &#xff11;、AlarmManager 这个是…

【Pytorch神经网络理论篇】 18 循环神经网络结构:LSTM结构+双向RNN结构

同学你好&#xff01;本文章于2021年末编写&#xff0c;获得广泛的好评&#xff01; 故在2022年末对本系列进行填充与更新&#xff0c;欢迎大家订阅最新的专栏&#xff0c;获取基于Pytorch1.10版本的理论代码(2023版)实现&#xff0c; Pytorch深度学习理论篇(2023版)目录地址…

DWR之初尝

---恢复内容开始--- 准备公工作 1.去官网下载jar和war 开发工具 eclipse 开始开发:gogogo 1:建立一个可以跑起来的javaweb项目,最基本的就可以了. 2:导入commons-logging-1.0.4.jar,dwr.jar 3:在web.xml里配置一下 <?xml version"1.0" encoding"UTF-8"…

jni调试

&#xff11;&#xff12;年的时候写过&#xff2a;&#xff2e;&#xff29;但是又忘记得差不多了&#xff0c;现在重新写了一次&#xff0c;发现碰到了几个问题&#xff0c;写下来记录一下 第一步 应用程序java代码 package com.example.helloworld;import java.util.Calend…

【Pytorch神经网络理论篇】 19 循环神经网络训练语言模型:语言模型概述+NLP多项式概述

同学你好&#xff01;本文章于2021年末编写&#xff0c;获得广泛的好评&#xff01; 故在2022年末对本系列进行填充与更新&#xff0c;欢迎大家订阅最新的专栏&#xff0c;获取基于Pytorch1.10版本的理论代码(2023版)实现&#xff0c; Pytorch深度学习理论篇(2023版)目录地址…

B. Code For 1 一个类似于线段树的东西

http://codeforces.com/contest/768/problem/B 我的做法是&#xff0c;观察到&#xff0c;只有是x % 2的情况下&#xff0c;才有可能出现0 其他的&#xff0c;都是1来的&#xff0c;所以开始的ans应该是R - L 1 那么现在就是要看那些是x % 2的&#xff0c;然后放在的位置是属于…

解码错误。‘gb2312‘ codec can‘t decode byte 0xf3 in position 307307: illegal multibyte sequence

一般在decode加errors"ignore"就可以了。例如&#xff1a; decode(gb2312,errors ignore)