基于Pytorch框架的深度学习Swin-Transformer神经网络食物分类系统源码

 第一步:准备数据

5种鸟类数据:self.class_indict = ["苹果派", "猪小排", "果仁蜜饼", "生牛肉薄片", "鞑靼牛肉"]

,总共有5000张图片,每个文件夹单独放一种数据

第二步:搭建模型

本文选择一个Swin-Transformer网络,其原理介绍如下:

Swin-Transformer是2021年微软研究院发表在ICCV上的一篇文章,并且已经获得ICCV 2021 best paper的荣誉称号。虽然Vision Transformer (ViT)在图像分类方面的结果令人鼓舞,但是由于其低分辨率特性映射和复杂度随图像大小的二次增长,其结构不适合作为密集视觉任务高分辨率输入图像的通过骨干网路。为了最佳的精度和速度的权衡,提出了Swin-Transformer结构。

Swin-Transformer的基础流程。

  1. 输入一张图片 [ H ∗ W ∗ 3 ] [H*W*3] [H∗W∗3]
  2. 图片经过Patch Partition层进行图片分割
  3. 分割后的数据经过Linear Embedding层进行特征映射
  4. 将特征映射后的数据输入具有改进的自关注计算的Transformer块(Swin Transformer块),并与Linear Embedding一起被称为第1阶段
  5. 与阶段1不同,阶段2-4在输入模型前需要进行Patch Merging进行下采样,产生分层表示。
  6. 最终将经过阶段4的数据经过输出模块(包括一个LayerNorm层、一个AdaptiveAvgPool1d层和一个全连接层)进行分类。
Swin-Transformer结构

简单看下原论文中给出的关于Swin Transformer(Swin-T)网络的架构图。其中,图(a)表示Swin Transformer的网络结构流程,图(b)表示两阶段的Swin Transformer Block结构。注意:在Swin Transformer中,每个阶段的Swin Transformer Block结构都是2的倍数,因为里面使用的都是两阶段的Swin Transformer Block结构,如下图所示:

第三步:训练代码

1)损失函数为:交叉熵损失函数

2)训练代码:

import os
import argparseimport torch
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
from torchvision import transformsfrom my_dataset import MyDataSet
from model import swin_tiny_patch4_window7_224 as create_model
from utils import read_split_data, train_one_epoch, evaluatedef main(args):device = torch.device(args.device if torch.cuda.is_available() else "cpu")if os.path.exists("./weights") is False:os.makedirs("./weights")tb_writer = SummaryWriter()train_images_path, train_images_label, val_images_path, val_images_label = read_split_data(args.data_path)img_size = 224data_transform = {"train": transforms.Compose([transforms.RandomResizedCrop(img_size),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),"val": transforms.Compose([transforms.Resize(int(img_size * 1.143)),transforms.CenterCrop(img_size),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}# 实例化训练数据集train_dataset = MyDataSet(images_path=train_images_path,images_class=train_images_label,transform=data_transform["train"])# 实例化验证数据集val_dataset = MyDataSet(images_path=val_images_path,images_class=val_images_label,transform=data_transform["val"])batch_size = args.batch_sizenw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workersprint('Using {} dataloader workers every process'.format(nw))train_loader = torch.utils.data.DataLoader(train_dataset,batch_size=batch_size,shuffle=True,pin_memory=True,num_workers=nw,collate_fn=train_dataset.collate_fn)val_loader = torch.utils.data.DataLoader(val_dataset,batch_size=batch_size,shuffle=False,pin_memory=True,num_workers=nw,collate_fn=val_dataset.collate_fn)model = create_model(num_classes=args.num_classes).to(device)if args.weights != "":assert os.path.exists(args.weights), "weights file: '{}' not exist.".format(args.weights)weights_dict = torch.load(args.weights, map_location=device)["model"]# 删除有关分类类别的权重for k in list(weights_dict.keys()):if "head" in k:del weights_dict[k]print(model.load_state_dict(weights_dict, strict=False))if args.freeze_layers:for name, para in model.named_parameters():# 除head外,其他权重全部冻结if "head" not in name:para.requires_grad_(False)else:print("training {}".format(name))pg = [p for p in model.parameters() if p.requires_grad]optimizer = optim.AdamW(pg, lr=args.lr, weight_decay=5E-2)for epoch in range(args.epochs):# traintrain_loss, train_acc = train_one_epoch(model=model,optimizer=optimizer,data_loader=train_loader,device=device,epoch=epoch)# validateval_loss, val_acc = evaluate(model=model,data_loader=val_loader,device=device,epoch=epoch)tags = ["train_loss", "train_acc", "val_loss", "val_acc", "learning_rate"]tb_writer.add_scalar(tags[0], train_loss, epoch)tb_writer.add_scalar(tags[1], train_acc, epoch)tb_writer.add_scalar(tags[2], val_loss, epoch)tb_writer.add_scalar(tags[3], val_acc, epoch)tb_writer.add_scalar(tags[4], optimizer.param_groups[0]["lr"], epoch)torch.save(model.state_dict(), "./weights/model-{}.pth".format(epoch))if __name__ == '__main__':parser = argparse.ArgumentParser()parser.add_argument('--num_classes', type=int, default=5)parser.add_argument('--epochs', type=int, default=100)parser.add_argument('--batch-size', type=int, default=4)parser.add_argument('--lr', type=float, default=0.0001)# 数据集所在根目录# https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgzparser.add_argument('--data-path', type=str,default=r"G:\demo\data\foods")# 预训练权重路径,如果不想载入就设置为空字符parser.add_argument('--weights', type=str, default='swin_tiny_patch4_window7_224.pth',help='initial weights path')# 是否冻结权重parser.add_argument('--freeze-layers', type=bool, default=False)parser.add_argument('--device', default='cuda:0', help='device id (i.e. 0 or 0,1 or cpu)')opt = parser.parse_args()main(opt)

第四步:统计正确率

第五步:搭建GUI界面

第六步:整个工程的内容

有训练代码和训练好的模型以及训练过程,提供数据,提供GUI界面代码

代码的下载路径(新窗口打开链接):基于Pytorch框架的深度学习Swin-Transformer神经网络食物分类系统源码

有问题可以私信或者留言,有问必答

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/30156.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

swift使用swift-protobuf协议通讯,使用指北

什么是Protobuf Protobuf(Protocol Buffers)协议😉 Protobuf 是一种由 Google 开发的二进制序列化格式和相关的技术,它用于高效地序列化和反序列化结构化数据,通常用于网络通信、数据存储等场景。 为什么要使用Proto…

Java面试八股之myBatis与myBatis plus的对比

myBatis与myBatis plus的对比 基础与增强: MyBatis 是一个成熟的Java持久层框架,它允许开发者通过XML文件或注解来配置SQL语句和数据库映射,提供了一个灵活的方式来操作数据库,但需要手动编写所有的SQL语句和结果集映射。 MyBa…

Day55 代码随想录打卡|二叉树篇---二叉搜索树中的插入操作

题目(leecode T701): 给定二叉搜索树(BST)的根节点 root 和要插入树中的值 value ,将值插入二叉搜索树。 返回插入后二叉搜索树的根节点。 输入数据 保证 ,新值和原始二叉搜索树中的任意节点值…

【雷丰阳-谷粒商城 】【分布式高级篇-微服务架构篇】【11】ElasticSearch

持续学习&持续更新中… 守破离 【雷丰阳-谷粒商城 】【分布式高级篇-微服务架构篇】【11】ElasticSearch 简介基本概念ElasticSearch概念-倒排索引安装基本命令ik 分词器SpringBoot整合测试存储数据:测试复杂检索同步与异步调用 参考 简介 Elasticsearch 是一…

【AIGC】MetaGPT原理以及应用

目录 MetaGPT原理 MetaGPT应用 MetaGPT和传统编程语言相比有什么优势和劣势 视频中的PPT 参考资料 MetaGPT原理 MetaGPT是一种多智能体框架,它结合了元编程技术,通过标准化操作程序(SOPs)来协调基于大语言模型的多智能体系统…

嵌入式实验---实验一 通用GPIO实验

一、实验目的 1、掌握STM32F103 GPIO程序设计流程; 2、熟悉STM32固件库的基本使用。 二、实验原理 1、通过按键实现:按键按下,LED点亮;按键释放,LED熄灭。 三、实验设备和器材 电脑、Keil uVision5软件、Proteus…

Hierarchical Integration Diffusion Model for Realistic Image Deblurring

neurips23 上交&ETH&字节&清华&上海ai lab&悉尼大学&西湖大学https://github.com/zhengchen1999/HI-Diff 问题引入 现在的diffusion的方法在sample的时候需要的iteration过多,所以本文提出在高度压缩的空间进行DM,且deblur模型…

【python】PyCharm如何设置字体大小和背景

目录 效果展示 字体大小 背景设置 效果展示 字体大小 再左上角找到四条杠的图标 找到File 一般字体大小为22最合适,行间距为默认 背景设置 还是再字体设置的页面搜索 background 小编的其他文章详见,欢迎来支持 东洛的克莱斯韦克-CSDN博客 【机器…

程序员失业了,你可以做这些事情

这篇文章,我们讲,你先别带入自己哈,如果失业了,放心吧,你那么有上进心,不会失业的。咱就是说,如果万一失业了,你可以做这些事情。 1 体力好的铁人三项 👩‍&#x1f3e…

Python基础-引用参数、斐波那契数列、无极分类

1.引用参数的问题 (1)列表(list) 引用参数,传地址的参数,即list1会因list2修改而改变。 list1 [1,2,3,4] list2 list1 print(list1) list2[2] 1 print(list2) print(list1)非引用参数,不传…

解锁TikTok内容趋势——高效获取TikTok标签信息接口

一、引言 在TikTok这个全球热门的短视频平台上,标签(Hashtags)是用户和内容创作者连接、发现新内容的重要工具。为了帮助品牌、市场分析师、内容创作者等更好地理解和利用TikTok上的内容趋势,我们推出了一款全新的接口服务&#…

PD19 Parallels Desktop 虚拟机 安装Windows10系统 操作步骤(保姆级教程,轻松上手)

Mac分享吧 文章目录 效果一、准备工作**下载软件** 二、开始安装1、打开pd 19 虚拟机,点击右上角文件,新建2、通过下载好的镜像安装Windows10系统。找到镜像文件位置,安装,配置2、显示安装完成,打开Windows10系统 三、…

UI设计速成课:理解模态窗口与非模态窗口的区别

我们日常所说的弹性框架是非常笼统的概念。我们习惯性地称之为对话框架、浮动层和提示条。弹性框架可以分为两种:模态弹性框架和非模态弹性框架。产品需要弹性框架来传递信息,用户需要弹性框架来接受反馈,但是没有经过推敲的弹出窗口设计很容易让用户感到…

最新版首发 | 手把手教你安装 Vivado2024.1(附安装包)

Q:Vivado出2024版了!不知迪普微有没有对应的安装包呢? A:有的!回复“Vivado2024.1”即可获得相应安装包哦~ Q:好哒~但是我不会安装,可否安排一期安装教程? A:立马安排&…

Gin 详解

Gin 介绍 gin框架是一个基于go语言的轻量级web框架,它具有高效性、灵活性、易扩展性路由 gin框架使用的是定制版的httprouter 其路由原理是大量使用公共前缀的树结构,注册路由的过程就是构造前缀树的过程。 具有公共前缀的节点也共享一个公共父节点。…

怎么移除pdf文件编辑限制,有哪些方法?

PDF是我们在学习或工作中常常应用到的一种文件格式,因为它的跨平台性和文档保真度而备受欢迎。但是,有时我们会遇到PDF编辑权限被限制了,那么pdf解除编辑限制可以用什么方法呢?别急,接下来,本文将深入探讨如…

关于Panabit在资产平台中类型划分问题

现场同事问了一个问题:Panabit能不能当做CentOS接入? 我第一反应是:Panabit是个什么鬼?为啥要混编接入?后期维护都是事啊。所以,我就想回答:不能! 但是,最好要给出一个…

通过sql语句直接导出excel文件

SELECT column1 as 名字 FROM your_table INTO OUTFILE /path/to/your_file.csv FIELDS TERMINATED BY , ENCLOSED BY " LINES TERMINATED BY \n 这里的注意事项是,INTO OUTFILE 这后面的路径需要通过下面的SQL查出来 show variables like %secure%; 操作步骤…

构建多模态模型,生成主机观测指标,欢迎来战丨2024天池云原生编程挑战赛

在当前云计算和微服务架构日益普及的背景下,企业和开发者对云资源的依赖日益加深。Elastic Compute Service(ECS)作为提供计算能力的核心服务,承担着众多的业务。随着微服务架构的广泛应用,任务的部署和执行变得更为灵…