基于Pytorch框架的深度学习Swin-Transformer神经网络食物分类系统源码

 第一步:准备数据

5种鸟类数据:self.class_indict = ["苹果派", "猪小排", "果仁蜜饼", "生牛肉薄片", "鞑靼牛肉"]

,总共有5000张图片,每个文件夹单独放一种数据

第二步:搭建模型

本文选择一个Swin-Transformer网络,其原理介绍如下:

Swin-Transformer是2021年微软研究院发表在ICCV上的一篇文章,并且已经获得ICCV 2021 best paper的荣誉称号。虽然Vision Transformer (ViT)在图像分类方面的结果令人鼓舞,但是由于其低分辨率特性映射和复杂度随图像大小的二次增长,其结构不适合作为密集视觉任务高分辨率输入图像的通过骨干网路。为了最佳的精度和速度的权衡,提出了Swin-Transformer结构。

Swin-Transformer的基础流程。

  1. 输入一张图片 [ H ∗ W ∗ 3 ] [H*W*3] [H∗W∗3]
  2. 图片经过Patch Partition层进行图片分割
  3. 分割后的数据经过Linear Embedding层进行特征映射
  4. 将特征映射后的数据输入具有改进的自关注计算的Transformer块(Swin Transformer块),并与Linear Embedding一起被称为第1阶段
  5. 与阶段1不同,阶段2-4在输入模型前需要进行Patch Merging进行下采样,产生分层表示。
  6. 最终将经过阶段4的数据经过输出模块(包括一个LayerNorm层、一个AdaptiveAvgPool1d层和一个全连接层)进行分类。
Swin-Transformer结构

简单看下原论文中给出的关于Swin Transformer(Swin-T)网络的架构图。其中,图(a)表示Swin Transformer的网络结构流程,图(b)表示两阶段的Swin Transformer Block结构。注意:在Swin Transformer中,每个阶段的Swin Transformer Block结构都是2的倍数,因为里面使用的都是两阶段的Swin Transformer Block结构,如下图所示:

第三步:训练代码

1)损失函数为:交叉熵损失函数

2)训练代码:

import os
import argparseimport torch
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
from torchvision import transformsfrom my_dataset import MyDataSet
from model import swin_tiny_patch4_window7_224 as create_model
from utils import read_split_data, train_one_epoch, evaluatedef main(args):device = torch.device(args.device if torch.cuda.is_available() else "cpu")if os.path.exists("./weights") is False:os.makedirs("./weights")tb_writer = SummaryWriter()train_images_path, train_images_label, val_images_path, val_images_label = read_split_data(args.data_path)img_size = 224data_transform = {"train": transforms.Compose([transforms.RandomResizedCrop(img_size),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),"val": transforms.Compose([transforms.Resize(int(img_size * 1.143)),transforms.CenterCrop(img_size),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}# 实例化训练数据集train_dataset = MyDataSet(images_path=train_images_path,images_class=train_images_label,transform=data_transform["train"])# 实例化验证数据集val_dataset = MyDataSet(images_path=val_images_path,images_class=val_images_label,transform=data_transform["val"])batch_size = args.batch_sizenw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workersprint('Using {} dataloader workers every process'.format(nw))train_loader = torch.utils.data.DataLoader(train_dataset,batch_size=batch_size,shuffle=True,pin_memory=True,num_workers=nw,collate_fn=train_dataset.collate_fn)val_loader = torch.utils.data.DataLoader(val_dataset,batch_size=batch_size,shuffle=False,pin_memory=True,num_workers=nw,collate_fn=val_dataset.collate_fn)model = create_model(num_classes=args.num_classes).to(device)if args.weights != "":assert os.path.exists(args.weights), "weights file: '{}' not exist.".format(args.weights)weights_dict = torch.load(args.weights, map_location=device)["model"]# 删除有关分类类别的权重for k in list(weights_dict.keys()):if "head" in k:del weights_dict[k]print(model.load_state_dict(weights_dict, strict=False))if args.freeze_layers:for name, para in model.named_parameters():# 除head外,其他权重全部冻结if "head" not in name:para.requires_grad_(False)else:print("training {}".format(name))pg = [p for p in model.parameters() if p.requires_grad]optimizer = optim.AdamW(pg, lr=args.lr, weight_decay=5E-2)for epoch in range(args.epochs):# traintrain_loss, train_acc = train_one_epoch(model=model,optimizer=optimizer,data_loader=train_loader,device=device,epoch=epoch)# validateval_loss, val_acc = evaluate(model=model,data_loader=val_loader,device=device,epoch=epoch)tags = ["train_loss", "train_acc", "val_loss", "val_acc", "learning_rate"]tb_writer.add_scalar(tags[0], train_loss, epoch)tb_writer.add_scalar(tags[1], train_acc, epoch)tb_writer.add_scalar(tags[2], val_loss, epoch)tb_writer.add_scalar(tags[3], val_acc, epoch)tb_writer.add_scalar(tags[4], optimizer.param_groups[0]["lr"], epoch)torch.save(model.state_dict(), "./weights/model-{}.pth".format(epoch))if __name__ == '__main__':parser = argparse.ArgumentParser()parser.add_argument('--num_classes', type=int, default=5)parser.add_argument('--epochs', type=int, default=100)parser.add_argument('--batch-size', type=int, default=4)parser.add_argument('--lr', type=float, default=0.0001)# 数据集所在根目录# https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgzparser.add_argument('--data-path', type=str,default=r"G:\demo\data\foods")# 预训练权重路径,如果不想载入就设置为空字符parser.add_argument('--weights', type=str, default='swin_tiny_patch4_window7_224.pth',help='initial weights path')# 是否冻结权重parser.add_argument('--freeze-layers', type=bool, default=False)parser.add_argument('--device', default='cuda:0', help='device id (i.e. 0 or 0,1 or cpu)')opt = parser.parse_args()main(opt)

第四步:统计正确率

第五步:搭建GUI界面

第六步:整个工程的内容

有训练代码和训练好的模型以及训练过程,提供数据,提供GUI界面代码

代码的下载路径(新窗口打开链接):基于Pytorch框架的深度学习Swin-Transformer神经网络食物分类系统源码

有问题可以私信或者留言,有问必答

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/30156.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

分布式锁(4):jedis基于Redis setnx、get、getset的分布式锁

1 实现原理 setnx(lockkey, 当前时间+过期超时时间) ,如果返回1,则获取锁成功;如果返回0则没有获取到锁,转向步骤(2)get(lockkey)获取值oldExpireTime ,并将这个value值与当前的系统时间进行比较,如果小于当前系统时间,则认为这个锁已经超时,可以允许别的请求重新获取,…

《多线程》

每一个任务就是一个进程,每个进程内部至少有一个线程在运行中。线程是程序执行的一个路径,每一个线程都有自己的局部变量表,程序技术器,以及各自的生命周期。 1.创建一个线程,并且重写它的run方法,将行为方…

swift使用swift-protobuf协议通讯,使用指北

什么是Protobuf Protobuf(Protocol Buffers)协议😉 Protobuf 是一种由 Google 开发的二进制序列化格式和相关的技术,它用于高效地序列化和反序列化结构化数据,通常用于网络通信、数据存储等场景。 为什么要使用Proto…

UnrealEngine打开Setup.bat,提示Failed to download的解决方法

Failed to download when I run Setup.bat - #3 by Milisours - Getting Started & Setup - Epic Developer Community Forums https://forums.unrealengine.com/uploads/short-url/oGTskBcZI8ACTyCw7jIK2dTmkC7.xml 下载这个文件 然后替换掉Engine/Build/下面的Commit.g…

c++ map set底层模拟实现

关于这两个数据结构的insert接口实现 请看这篇文章 https://blog.csdn.net/l23456789mmmmm/article/details/139500413?spm1001.2014.3001.5501 map::operator[]底层实现请看这篇文章 cmap类operator[]详解_c map operator-CSDN博客 红黑树模拟实现 #pragma once #include &…

Java面试八股之myBatis与myBatis plus的对比

myBatis与myBatis plus的对比 基础与增强: MyBatis 是一个成熟的Java持久层框架,它允许开发者通过XML文件或注解来配置SQL语句和数据库映射,提供了一个灵活的方式来操作数据库,但需要手动编写所有的SQL语句和结果集映射。 MyBa…

Day55 代码随想录打卡|二叉树篇---二叉搜索树中的插入操作

题目(leecode T701): 给定二叉搜索树(BST)的根节点 root 和要插入树中的值 value ,将值插入二叉搜索树。 返回插入后二叉搜索树的根节点。 输入数据 保证 ,新值和原始二叉搜索树中的任意节点值…

【雷丰阳-谷粒商城 】【分布式高级篇-微服务架构篇】【11】ElasticSearch

持续学习&持续更新中… 守破离 【雷丰阳-谷粒商城 】【分布式高级篇-微服务架构篇】【11】ElasticSearch 简介基本概念ElasticSearch概念-倒排索引安装基本命令ik 分词器SpringBoot整合测试存储数据:测试复杂检索同步与异步调用 参考 简介 Elasticsearch 是一…

【AIGC】MetaGPT原理以及应用

目录 MetaGPT原理 MetaGPT应用 MetaGPT和传统编程语言相比有什么优势和劣势 视频中的PPT 参考资料 MetaGPT原理 MetaGPT是一种多智能体框架,它结合了元编程技术,通过标准化操作程序(SOPs)来协调基于大语言模型的多智能体系统…

Zookeeper 集群节点选举原理实现(三)

Zookeeper 集群节点选举原理实现(三) 刚部署三个节点或者多个节点启动时,此时还未选择出领导节点,不同节点的初始化zxid 是如何保证不重复不冲突有序呢? 在 Zookeeper 集群的初始启动阶段,所有节点会在选举领导节点之前先初始化自己的状态和 ZXID。为了确保不同节点的初始…

jieba中文分词器的使用

Jieba 是一个中文分词的第三方库,主要用于对中文文本进行分词。分词是将文本分割成一个个词语的过程,这在中文文本处理中尤为重要,因为中文不像英文那样有明显的空格来分隔词语。Jieba 的分词算法可以实现精确分词、全模式分词和搜索引擎模式…

嵌入式实验---实验一 通用GPIO实验

一、实验目的 1、掌握STM32F103 GPIO程序设计流程; 2、熟悉STM32固件库的基本使用。 二、实验原理 1、通过按键实现:按键按下,LED点亮;按键释放,LED熄灭。 三、实验设备和器材 电脑、Keil uVision5软件、Proteus…

Hierarchical Integration Diffusion Model for Realistic Image Deblurring

neurips23 上交&ETH&字节&清华&上海ai lab&悉尼大学&西湖大学https://github.com/zhengchen1999/HI-Diff 问题引入 现在的diffusion的方法在sample的时候需要的iteration过多,所以本文提出在高度压缩的空间进行DM,且deblur模型…

力扣第209题“长度最小的子数组”

关注微信公众号 数据分析螺丝钉 免费领取价值万元的python/java/商业分析/数据结构与算法学习资料 在本篇文章中,我们将详细解读力扣第209题“长度最小的子数组”。通过学习本篇文章,读者将掌握如何使用滑动窗口和双指针的方法来解决这一问题&#xff0…

甲辰年五月十四风雨思

甲辰年五月十四风雨思 夜雨消暑气,远光归家心。 ​只待万窗明,朝夕千家勤。 ​苦乐言行得,酸甜日常品。 宫商角徵羽,​仁义礼智信。

【python】PyCharm如何设置字体大小和背景

目录 效果展示 字体大小 背景设置 效果展示 字体大小 再左上角找到四条杠的图标 找到File 一般字体大小为22最合适,行间距为默认 背景设置 还是再字体设置的页面搜索 background 小编的其他文章详见,欢迎来支持 东洛的克莱斯韦克-CSDN博客 【机器…

如何优雅的一键下载OpenHarmony活跃分支代码?请关注【itopen: ohos_download】

itopen组织:1、提供OpenHarmony优雅实用的小工具2、手把手适配riscv qemu linux的三方库移植3、未来计划riscv qemu ohos的三方库移植 小程序开发4、一切拥抱开源,拥抱国产化 一、概述 为方便大家每次下载OpenHarmony不同分支/tag代码&#xff0c…

【文末附gpt升级秘笈】“登月游戏”对人类的意义

“登月游戏”对人类的意义是多方面的,不仅体现在科技、教育和娱乐层面,还对人类探索未知的精神产生了深远影响。 一、科技意义 “登月游戏”作为早期计算机游戏的代表之一,展示了计算机技术在模拟现实世界方面的能力。通过模拟登月器的着陆…

数据库-单表查询-排序和分组

对查询结果排序: SELECT 字段名 FROM 表名 ORDER BY 字段名 [ASC[DESC]]; ASC 升序关键字DESC 降序关键字 分组查询的单独使用: SELECT 字段名 FROM 表名 GROUP BY 字段名; 使用 LIMIT 限制查询结果的数量: SELECT 字段名 FROM 表名 LIMIT [OFFSET,] 记录数; 第一个…