009.ResNet-FashionMNIST-正确率93.739

一、ResNet简介

  • ResNet是一次CNN网络架构,核心思想是引入"残差学习"来解决深层网络难以训练的问题。
  • 在传统的网络中,每一层都直接尝试学习目标映射。相反,ResNet通过跨层连接,允许某一层学习输入与输出之间的残差(或者说是差异),使得这些网络层只需要学习与输入的微小差异,从而简化了学习目标和过程。

二、FashionMNIST数据集简介

  • 之前的博客已经较为细致的介绍了FashionMNIST数据集:插眼传送

注意:了解数据集是机器学习的所有环节中最重要的一步,没有之一。

三、用代码实现FashionMNIST预测

1.导包
from torchvision.datasets import FashionMNIST
from torchvision.transforms import transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import torch
import numpy as np
import random
2.加载数据
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#device = 'cpu'
generator = torch.Generator()# 设置随机种子,确保实验可重复性
seed_value = 420
torch.manual_seed(seed_value)
random.seed(seed_value)
np.random.seed(seed_value)
# 如果你使用CUDA并希望进一步确定性,可以添加下面两行代码
torch.cuda.manual_seed(seed_value)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
generator.manual_seed(seed_value)transform = transforms.Compose([transforms.RandomHorizontalFlip(), transforms.RandomRotation([-8,8]),transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,))  # 归一化处理
])transform2 = transforms.Compose([#transforms.RandomRotation([-5,5]),transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,))  # 归一化处理
])# 从"./dataset/"目录加载FashionMNIST数据集,如果没有则会自动下载。
train_data = FashionMNIST(root='./dataset/', train=True,  download=True,transform=transform)
test_data = FashionMNIST(root='./dataset/', train=False,  download=True,transform=transform2)
train_batch = DataLoader(dataset=train_data, batch_size=128,  shuffle=True, num_workers=0, drop_last=False, generator=generator)
test_batch = DataLoader(dataset=test_data, batch_size=128,  shuffle=False, num_workers=0, drop_last=False, generator=generator)
3.定义模型
class Model(torch.nn.Module):def __init__(self,in_features=1,out_features=10):super().__init__()self.relu = torch.nn.ReLU()self.conv1 = torch.nn.Conv2d(in_channels=in_features, out_channels=64, kernel_size=3, bias=False) self.adavgpool = torch.nn.AdaptiveAvgPool2d((1, 1))self.block1 = torch.nn.Sequential(self.conv1, torch.nn.BatchNorm2d(64), self.relu)self.output = torch.nn.Linear(512, out_features, bias=True)self.maxpool = torch.nn.AvgPool2d(2,ceil_mode=True)self.downsample = torch.nn.Sequential(torch.nn.Conv2d(64, 128, kernel_size=1,stride=2,bias=False),torch.nn.BatchNorm2d(128))self.downsample2 = torch.nn.Sequential(torch.nn.Conv2d(128, 256, kernel_size=1,stride=2, bias=False), torch.nn.BatchNorm2d(256))self.downsample3 = torch.nn.Sequential(torch.nn.Conv2d(256, 512, kernel_size=1,stride=2, bias=False), torch.nn.BatchNorm2d(512))self.conv_res = torch.nn.Sequential(torch.nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1, stride=2, bias=False),torch.nn.BatchNorm2d(128),self.relu,torch.nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, padding=1, bias=False),torch.nn.BatchNorm2d(128),)self.conv_res2 = torch.nn.Sequential(torch.nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, padding=1, stride=1, bias=False),torch.nn.BatchNorm2d(128),self.relu,torch.nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, padding=1, bias=False),torch.nn.BatchNorm2d(128),)self.conv_res3 = torch.nn.Sequential(torch.nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, padding=1, stride=2, bias=False),torch.nn.BatchNorm2d(256),self.relu,torch.nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, padding=1, bias=False),torch.nn.BatchNorm2d(256),)self.conv_res4 = torch.nn.Sequential(torch.nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, padding=1, stride=1, bias=False),torch.nn.BatchNorm2d(256),self.relu,torch.nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, padding=1, bias=False),torch.nn.BatchNorm2d(256),)self.conv_res5 = torch.nn.Sequential(torch.nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, padding=1, stride=2, bias=False),torch.nn.BatchNorm2d(512),self.relu,torch.nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, padding=1, bias=False),torch.nn.BatchNorm2d(512),)self.conv_res6 = torch.nn.Sequential(torch.nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, padding=1, stride=1, bias=False),torch.nn.BatchNorm2d(512),self.relu,torch.nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, padding=1, bias=False),torch.nn.BatchNorm2d(512),)def forward(self,x):x = self.block1(x)identity = self.downsample(x)x = self.conv_res(x)x += identityx = self.relu(x)x = self.conv_res2(x)x += identityx = self.relu(x)identity = self.downsample2(x)x = self.conv_res3(x)x += identityx = self.relu(x)x = self.conv_res4(x)x += identityx = self.relu(x)identity = self.downsample3(x)x = self.conv_res5(x)x += identityx = self.relu(x)x = self.conv_res6(x)x += identityx = self.relu(x)x = self.adavgpool(x)x = x.view(len(x), -1)x = self.output(x)return x

注意:此模型不是完整的ResNet网络,这里做了部分修改,以适应当前图片尺寸。

4.定义损失函数、优化器
from torch.optim import Adam
from torch.nn import functional as F# 初始化一个模型,输入图片通道数为1,输出特征为10
model = Model().to(device)
# 使用负对数似然损失函数
criterion = torch.nn.CrossEntropyLoss()
# 初始化Adam优化器,设定学习率为0.005
opt = Adam(model.parameters(), lr=0.001)
5.开始训练
# 进行9次迭代
for _ in range(49):# 遍历数据批次for n_, batch in enumerate(train_batch):# 将输入数据X调整形状并输入到模型X = batch[0].to(device)# y为真实标签y = batch[1].to(device)# 前向传播,获取模型输出sigma = model.forward(X)# 计算损失loss = criterion(sigma, y)# 计算预测的标签y_hat = torch.max(sigma, dim=1)[1]# 计算预测正确的数量correct_count = torch.sum(y_hat == y)# 计算准确率accuracy = correct_count / len(y) * 100# 反向传播,计算梯度loss.backward()# 更新模型参数opt.step()# 清除之前的梯度model.zero_grad()# 打印当前批次的损失和准确率print(n_, 'loss:', loss.item(), 'accuracy:', accuracy.item())

输出:

468 loss: 0.21156974136829376 accuracy: 91.66667175292969
468 loss: 0.24343211948871613 accuracy: 89.58333587646484
468 loss: 0.3186508119106293 accuracy: 87.5
468 loss: 0.16633149981498718 accuracy: 93.75
468 loss: 0.13033141195774078 accuracy: 93.75
468 loss: 0.09412961453199387 accuracy: 97.91667175292969
468 loss: 0.044871985912323 accuracy: 98.95833587646484
468 loss: 0.023767223581671715 accuracy: 100.0
468 loss: 0.09273606538772583 accuracy: 97.91667175292969
...
6.验证测试集
correct_count = 0
for batch in test_batch:test_X = batch[0].to(device)test_y = batch[1].to(device)sigma = model.forward(torch.tensor(test_X, dtype=torch.float32))y_hat = torch.max(sigma, dim=1)[1]correct_count += torch.sum(y_hat == test_y)accuracy = correct_count / 10000 * 100
print('accuracy:', accuracy.item())

输出:

accuracy: 93.73999786376953
  • 可以看出:ResNet相对于googleNet,处理时间减少,正确率提高。在FashionMNIST数据集上有较优的表现。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/28386.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

高级人工智能复习 题目整理 中科大

题目整理 填空 1.准确性,复杂性,验证集 2. 3 2 n 3^{2^n} 32n 3 C 2 n m 3^{C^m_{2n}} 3C2nm​ 3 m 3^m 3m n 1 n1 n1 3. 状态 从状态s采取行动a后继续采用策略 π \pi π的收益 环境 4. 语法 语义 推理规则 5. 参与者,策略集&#xff…

Elasticsearch 8.1官网文档梳理 - 十一、Ingest pipelines(管道)

Ingest pipelines 管道(Ingest pipelines)可让让数据在写入前进行常见的转换。例如可以利用管道删除文档(doc)的字段、或从文本中提取数据、丰富文档(doc)的字段等其他操作。 管道(Ingest pip…

Vite支持的React项目使用SASS指南

前言 在现代前端开发中,SASS是一种广受欢迎的CSS扩展语言,它提供了许多实用功能,如变量、嵌套、部分和混合等。 本教程将指导您在一个使用Vite作为构建工具的React项目中如何配置和使用SASS。 使用步骤 1、创建一个Vite React项目 首先确…

VirtualBox、Centos7下安装docker后pull镜像问题、ftp上传文件问题

Docker安装篇(CentOS7安装)_docker 安装 centos7-CSDN博客 首先,安装docker可以根据这篇文章进行安装,安装完之后,我们就需要去通过docker拉取相关的服务镜像,然后安装相应的服务容器,比如我们通过docker来安装mysql,…

vue 使用 ztree 超大量数据,前端树形结构展示

ztree 是一个很经典的基于jquey开发的树结构编辑展示UI组件库。 创建一个文件 ztree.vue&#xff0c;代码如下&#xff1a; <template><div><div class"ztree vue-giant-tree" :id"ztreeId"></div><div class"treeBox&q…

Android 14 蓝牙主从模式切换

切换蓝牙的A2DP&#xff08;高级音频分布配置文件&#xff09;和AVRCP&#xff08;音频/视频远程控制配置文件&#xff09;的源&#xff08;source&#xff09;和汇点&#xff08;sink&#xff09;模式。 这里&#xff0c;SystemProperties.get尝试获取bluetooth.profile.a2dp.…

在WSL2的Ubuntu中安装和使用Docker/Podman

在WSL2的Ubuntu中安装和使用Docker/Podman 0. 目的 当网络环境良好&#xff08;例如在公司&#xff0c;能直接访问Google等&#xff09;时&#xff0c; Docker/Podman 安装和使用不是问题。 当网络环境不佳&#xff08;例如在家里&#xff09;&#xff0c;要把 WSL2 的 Ubun…

Termius安装docker

安装Termius 直接上官网 新建主机 更新一下yum 更新完成 安装docker的包 直接用命令安装 设置一下开机启动&#xff0c;可以查看docker的版本

Ui学习--UITableView

UI学习 UITableView基础UITableView协议UITableView高级协议与单元格总结 UITableView基础 UITableView作为iOS中的一个控件&#xff0c;用于以表格形式展示数据。例如通讯录好友&#xff0c;朋友圈信息等&#xff0c;都是UITableView的实际运用场景。 首先我们先要加入两个协…

Mysql的增、删、查、改

MySQL 是一个流行的关系型数据库管理系统&#xff0c;它支持 SQL&#xff08;结构化查询语言&#xff09;用于管理数据库中的数据。以下是使用 SQL 在 MySQL 中进行增&#xff08;INSERT&#xff09;、删&#xff08;DELETE&#xff09;、查&#xff08;SELECT&#xff09;、改…

K210使用雷龙NAND完成火灾检测

NAND 文章目录 NAND前言一、NAND是什么&#xff1f;二、来看一看NAND三、部署火灾检测 前言 前几天收到了雷龙NAND的芯片&#xff0c;一共两个芯片和一个转接板&#xff0c;我之前也没有使用过这款芯片&#xff0c;比较好奇&#xff0c;体验了一下&#xff0c;个人认为&#x…

嵌入式微处理器重点学习(三)

堆栈操作 R1=0x005 R3=0x004 SP=0x80014 STMFD sp!, {r1, r3} 指令STMFD sp!, {r1, r3}是一条ARM架构中的存储多个寄存器到内存的指令,这里用于将r1和r3寄存器的内容存储到栈上。STMFD(Store Multiple Full Descending)是一种全递减模式的多寄存器存储指令,它会先将栈指针…

外包公司泛滥,这些常识你应该提前知道?

今年大环境确实很不好 很多985,211的应届生都在网上大吐苦水&#xff0c;很多大龄离职大厂的技术人也好&#xff0c;业务人也好&#xff0c;都纷纷转向短视频平台做起了自媒体。而找工作的人普遍发现&#xff0c;某最火的招聘平台几乎都被外包公司刷屏了。大大小小的外包公司如…

车载以太网-TC8测试

文章目录 TC8测试的用例数量TC8测试基本流程TC8测试内容TC8测试的用例数量 TC8测试的用例数量可能会因版本和具体测试内容而有所不同。一般来说,TC8测试用例总数在800条左右。 以OPEN Alliance Automotive Ethernet ECU Test Specification的3.0版本为例,该版本的测试用例总…

three.js 基础01

1.场景创建 Scene() 2.常用形状集几何体「Geometry」[可设置长宽高等内容,如:new THREE.BoxGeometry(...)] 长方体 BoxGeometry圆柱体CylinderGeometry 球体SphereGeometry圆锥体ConeGeometry矩形平面 PlaneGeometry 圆面体CircleGeometry 3.常用材质「Materi…

linux 部署瑞数6实战(维普,药监局)sign第二部分

声明 本文章中所有内容仅供学习交流使用&#xff0c;不用于其他任何目的&#xff0c;抓包内容、敏感网址、数据接口等均已做脱敏处理&#xff0c;严禁用于商业用途和非法用途&#xff0c;否则由此产生的一切后果均与作者无关&#xff01;wx …

C/C++李峋同款跳动的爱心代码

一、写在前面 在编程的世界里&#xff0c;代码不仅仅是冷冰冰的命令&#xff0c;它也可以成为表达情感、传递浪漫的工具。今天&#xff0c;就让小编带着大家用C语言打造出李峋同款跳动的爱心吧&#xff01; 首先&#xff0c;我们需要知道C作为一种高级编程语言&#xff0c;拥…

软件版本库管理工具

0 Preface/Foreword 常用代码版本管理工具包括如下几种&#xff1a; Git&#xff0c;最基本管理工具&#xff0c;由Linux kernel开发者开发Repo&#xff0c;主要用于管理Android SDK&#xff0c;由Google开发Gerrit&#xff0c;代码审查软件 1 Git 最基本的代码版本库管理工…

Linux软连接和硬连接

文章目录 软链接创建软链接查看软连接删除软链接 硬链接创建硬链接 区别小结 软链接 软连接是linux中一个常用命令&#xff0c;它的功能是为某一个文件在另外一个位置建立一个同步的链接。换句话说&#xff0c;也可以理解成Windows中的快捷方式。 创建软链接 ln -s [dir1] […

宠物健康顾问系统的设计

管理员账户功能包括&#xff1a;系统首页&#xff0c;个人中心&#xff0c;顾问管理&#xff0c;用户管理&#xff0c;健康知识管理&#xff0c;管理员管理&#xff0c;论坛管理&#xff0c;公告管理 顾问账户功能包括&#xff1a;系统首页&#xff0c;个人中心&#xff0c;顾…