Pytorch实现图像分类-水果数据集分类--深度学习大作业

目录

1.概述 

2.设计

3.实现

4.实验 

5.总结


1.概述 

本次深度学习大作业,我使用AlexNet模型对"Fruits-360"数据集中的两部分水果和蔬菜图片进行分类

2.设计

模型设计:Alexnet网络

  • 卷积层部分:构建了一系列卷积层、激活函数、最大池化层以及Dropout层,这一系列操作旨在从原始图像中提取丰富的特征。
  • 全连接层部分:通过计算得到的特征图尺寸动态设置全连接层的输入大小,设计了多层全连接网络,包含ReLU激活、Dropout正则化,最后输出层针对数据集的类别数量(本例中为2)进行调整。

因为输入图像数据为RGB图像,在模型的设计时调整,并在设计全连接层时引入了动态尺寸计算方法,保证了模型的通用性和适应性。

3.实现

 代码如下:


import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torch.nn.functional as F
from torch import nn, optim# 数据预处理
image_size = (224, 224)
data_transforms = transforms.Compose([transforms.RandomHorizontalFlip(),transforms.Resize(image_size),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])#导入数据集
import torchvision.datasets as datasets
train_data=datasets.ImageFolder (root='fruits-360-original-size/fruits-360-original-size/Training',transform=data_transforms)
test_data=datasets.ImageFolder (root='fruits-360-original-size/fruits-360-original-size/Test',transform=data_transforms)
# print(train_data.classes)
# print('..............')
# print(test_data.classes)#DataLoaderbatchsize=10#每个批次(batch)中包含的样本数量
train_loader = DataLoader(train_data, batch_size=batchsize, shuffle=True, num_workers=1)
test_loader = DataLoader(test_data, batch_size=batchsize, shuffle=False, num_workers=1)  # 测试时不需打乱数据#创建模型class AlexNet(nn.Module):def __init__(self):super(AlexNet, self).__init__()self.conv = nn.Sequential(nn.Conv2d(3, 96, 5, 1, 2),#输入通道数,输出通道数,卷积核大小,步长,填充(!!!rgb图像所以是三个通道,开始没注意以为灰度图像)nn.ReLU(),nn.MaxPool2d(3, 2), nn.Conv2d(96, 256, 5, 1, 2),nn.ReLU(),nn.MaxPool2d(3, 2),nn.Conv2d(256, 384, 3, 1, 1),nn.ReLU(),nn.Conv2d(384, 384, 3, 1, 1),nn.ReLU(),nn.Conv2d(384, 256, 3, 1, 1),nn.ReLU(),nn.MaxPool2d(3, 2))# 计算全连接层输入大小self.fc_input_size = self._get_fc_input_size()self.fc = nn.Sequential(nn.Linear(self.fc_input_size, 4096),nn.ReLU(),nn.Dropout(0.5),#随机丢弃nn.Linear(4096, 4096),nn.ReLU(),nn.Dropout(0.5),nn.Linear(4096, 2)#修改为2因为输出只有两个类)def forward(self, img):# 通过卷积层前向传播,img是输入图像张量feature = self.conv(img)feature = feature.view(img.size(0), -1)#展平# 通过全连接层(fc)进行前向传播,得到最终的输出output = self.fc(feature)return output#动态计算全连接层(FC层)所需要的输入尺寸def _get_fc_input_size(self):# 创建一个与训练/测试时相同尺寸和通道数的随机张量,用于通过卷积层x = torch.randn(1, 3, image_size[0], image_size[1])# 其中3对应RGB图像的通道数,image_size是从外部传入的图像预处理后的尺寸,默认为(224, 224)x = self.conv(x)return x.view(-1).size(0)#展平后的向量长度# 实例化模型、损失函数和优化器
model = AlexNet().to(device="cuda" if torch.cuda.is_available() else "cpu")
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)# 训练函数
def train(model, device, train_loader, optimizer, criterion):model.train()# 遍历训练数据加载器中的每个批次for batch_idx, (data, target) in enumerate(train_loader):data, target = data.to(device), target.to(device)#这样更快# 梯度清零,防止梯度累积optimizer.zero_grad()output = model(data)#预测输出loss = criterion(output, target)loss.backward()optimizer.step()# 每10个batch打印一次训练信息if (batch_idx + 1) % 10 == 0 or (batch_idx + 1) == len(train_loader):print(f'训练轮次: {epoch + 1}/{num_epochs} 损失: {loss.item():.6f}')# 测试函数
def test(model, device, test_loader):model.eval()test_loss = 0correct = 0#避免在测试过程中计算和存储梯度,节省内存并加速计算with torch.no_grad():for data, target in test_loader:data, target = data.to(device), target.to(device)output = model(data)test_loss += criterion(output, target).item()  # 累加批次损失pred = output.argmax(dim=1, keepdim=True)  # 获取预测概率最大的类别索引correct += pred.eq(target.view_as(pred)).sum().item()#累加预测正确的数量test_loss /= len(test_loader.dataset)  # 平均损失# 打印测试结果,包括平均损失、正确预测的总数、总样本数以及准确率print(f'\n测试集: 平均损失: {test_loss:.6f}, 正确: {correct}/{len(test_loader.dataset)} ({100. * correct / len(test_loader.dataset):.2f}%)\n')# 主训练循环
num_epochs = 10  # 设置训练轮数
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)for epoch in range(num_epochs):print(f'第{epoch + 1}轮训练开始')train(model, device, train_loader, optimizer, criterion)test(model, device, test_loader)

4.实验 

实验过程:首先对数据进行预处理,然后导入数据集合和数据加载,然后对模型进行构造,然后对模型进行训练和测试

实验结果如下:

 

5.总结

      在实验中由于使用AlexNet网络对RGB图像进行图像分类,所以不是灰度图像的输入通道为1,而是改成3,一开始没想到这点,然后对于模型的输出来说,由于我是在电脑上跑的,内存不太够,我对于Fruits-360数据集进行删减,最后剩下两个类别,所以模型的输出应该改成2,然后还有一些训练过程中的错误,实现了深度学习的图像分类,锻炼了实践能力以及综合能力

 

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/874205.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【等保测评】服务器——Windows server 2012 R2

文章目录 **身份鉴别****访问控制****安全审计****入侵防范****恶意代码防范****可信验证****测评常用命令** Windows服务器安全计算环境测评 测评对象:Windows server 2012 R2 身份鉴别 (高风险)应对登录的用户进行身份标识和鉴别&#x…

【爱上C++】list用法详解、模拟实现

文章目录 一:list介绍以及使用1.list介绍2.基本用法①list构造方式②list迭代器的使用③容量④元素访问⑤插入和删除⑥其他操作image.png 3.list与vector对比 二:list模拟实现1.基本框架2.节点结构体模板3.__list_iterator 结构体模板①模板参数说明②构…

【无人机】低空经济中5G RedCap芯片的技术分析报告

1. 引言 图一. 新基建:低空经济 低空经济作为一种新兴的经济形态,涵盖了无人机、电动垂直起降飞行器(eVTOL)、低空物流、空中交通管理等多个领域。随着5G网络的普及和演进,5G RedCap(Reduced Capability&a…

Typora 1.5.8 版本安装下载教程 (轻量级 Markdown 编辑器),图文步骤详解,免费领取(软件可激活使用)

文章目录 软件介绍软件下载安装步骤激活步骤 软件介绍 Typora是一款基于Markdown语法的轻量级文本编辑器,它的主要目标是为用户提供一个简洁、高效的写作环境。以下是Typora的一些主要特点和功能: 实时预览:Typora支持实时预览功能&#xff0…

腾讯云简单部署MYSQL 8.0

1.安装MySQL8.0资源库 yum localinstall https://repo.mysql.com//mysql80-community-release-el7-1.noarch.rpm2.安装MySQL8.0 yum -y install mysql-community-server --nogpgcheck . yum -y install mysql-community-server --nogpgcheck 3.启动MySQL并配置开机自启 sys…

【效率提升】程序员常用Shell脚本

文章目录 常用Shell脚本一. 定期更新分区数据二、获取系统资源的使用情况 常用Shell脚本 一. 定期更新分区数据 在某些场景下,我们需要对N年前某一分区的数据进行删除,并添加今年该对应分区的数据,实现数据的流动式存储。 #!/bin/bash dt$…

【devops】ttyd 一个web版本的shell工具 | web版本shell工具 | web shell

一、什么是 TTYD ttyd是在web端一个简单的服务器命令行工具 类似我们在云厂商上直接ssh链接我们的服务器输入指令一样 二、安装ttyd 1、macOS Install with Homebrew: brew install ttydInstall with MacPorts: sudo port install ttyd 2、linux Binary version (recommend…

神经网络中如何优化模型和超参数调优(案例为tensor的预测)

总结: 初级:简单修改一下超参数,效果一般般但是够用,有时候甚至直接不够用 中级:optuna得出最好的超参数之后,再多一些epoch让train和testloss整体下降,然后结果就很不错。 高级:…

Redis集群部署Windows版本

Redis集群 之前因为数据量的原因,并没有进行Redis集群的配置需要,现在由于数据量大,需要进行集群部署。 最初在windows系统部署,需要Redis的windows版本,但官方没有windows版本,所以需要去gitHub上找由民…

【STM32】MPU内存保护单元

注:仅在F7和M7系列上使用介绍 功能: 设置不同存储区域的存储器访问权限(管理员、用户) 设置存储器(内存和外设)属性(可缓冲、可缓存、可共享) 优点:提高嵌入式系统的健壮…

Bash 学习摘录

文章目录 1、变量和参数的介绍(1)变量替换$(...) (2)特殊的变量类型export位置参数shift 2、引用(1)引用变量(2)转义 3、条件判断(1)条件测试结构&#xff08…

Qt+OpenCascade开发笔记(一):occ的windows开发环境搭建(一):OpenCascade介绍、下载和安装过程

若该文为原创文章,转载请注明原文出处 本文章博客地址:https://hpzwl.blog.csdn.net/article/details/140604141 长沙红胖子Qt(长沙创微智科)博文大全:开发技术集合(包含Qt实用技术、树莓派、三维、OpenCV…

[C++进阶]模板进阶

此篇是学完stl后对于模板的补充 建议先看看这个[C初阶]模板初阶-CSDN博客 一、类模板 此处是对初阶讲过的 1. 类模板的定义格式 template<class T1, class T2, …, class Tn> class 类模板名 {}; 例如我们之前学习过的vector类&#xff1a; template<class T>…

C++中的多路转接技术之epoll

epoll 是干什么的&#xff1f;举个简单的例子 epoll的相关系统调用**epoll_create**和epoll_create1区别 epoll_ctl参数解释 **epoll_wait**参数说明返回值 epoll的使用 **epoll**工作原理epoll的优点(和 **select** 的缺点对应)epoll工作方式**水平触发**Level Triggered 工作…

Springboot 启动时Bean的创建与注入(一)-面试热点-springboot源码解读-xunznux

Springboot 启动时Bean的创建与注入&#xff0c;以及对应的源码解读 文章目录 Springboot 启动时Bean的创建与注入&#xff0c;以及对应的源码解读构建Web项目流程图&#xff1a;堆栈信息&#xff1a;堆栈信息简介堆栈信息源码详解1、main:10, DemoApplication (com.xun.demo)2…

HashMap与ConcurrentHashMap

文章目录 HashMap1.1 HashMap 的数据结构&#xff1f;1.2 HashMap 的动态扩容1.3 Hash实现方法1.4 如何解决Hash冲突 ConcurrentHashMap HashMap 1.1 HashMap 的数据结构&#xff1f; 哈希表结构&#xff08;链表散列&#xff1a;数组链表&#xff09;实现&#xff0c;结合数…

详细分析Springboot自定义启动界面(附Demo)

目录 前言1. banner.text1.1 配置文件关闭1.2 启动类关闭1.3 命令行关闭 2. 自定义Banner类3. 自动配置类4. 总结 前言 实现自定义启动动画是一项有趣的任务&#xff0c;虽然Spring Boot本身不提供内置的动画功能&#xff0c;但可以通过一些技巧来实现 以下主要以Demo的形式展…

三字棋游戏(C语言详细解释)

hello&#xff0c;小伙伴们大家好&#xff0c;算是失踪人口回归了哈&#xff0c;主要原因是期末考试完学校组织实训&#xff0c;做了俄罗斯方块&#xff0c;后续也会更新&#xff0c;不过今天先从简单的三字棋说起 话不多说&#xff0c;开始今天的内容 一、大体思路 我们都知…

MongoDB教程(十三):MongoDB覆盖索引

&#x1f49d;&#x1f49d;&#x1f49d;首先&#xff0c;欢迎各位来到我的博客&#xff0c;很高兴能够在这里和您见面&#xff01;希望您在这里不仅可以有所收获&#xff0c;同时也能感受到一份轻松欢乐的氛围&#xff0c;祝你生活愉快&#xff01; 文章目录 引言什么是覆盖…

数据结构(栈及其实现)

栈 概念与结构 栈&#xff1a;⼀种特殊的线性表&#xff0c;其只允许在固定的⼀端进⾏插⼊和删除元素操作。 进⾏数据插⼊和删除操作的⼀端称为栈顶&#xff0c;另⼀端称为栈底。栈中的数据元素遵守后进先出 LIFO&#xff08;Last In First Out&#xff09;的原则。 压栈&…