霹雳吧啦Wz《pytorch图像分类》-p6MobileNet网络

《pytorch图像分类》p6MobileNet网络结构详解

  • 1 MobileNet v1网络
    • 1.1 Depthwise convolution(DW卷积)
      • 1.1.1Depthwise separable convolution(深度可分的卷积操作)
    • 1.2 增加超参数α和β
  • 2 MobileNet v2网络
    • 2.1 Inverted Residuals(倒残差结构)
    • 2.2 Linear Bottlenecks
    • 2.3 MobileNet v3
  • 3 课程代码
    • 3.1 modle_v2.py
    • 3.2 train.py
    • 3.3 predict.py

1 MobileNet v1网络

论文链接:MobileNets: Efficient Convolutional Neural Networks for Mobile Vision Applications

1.1 Depthwise convolution(DW卷积)

DW卷积大大减少了运算量和参数数量
卷积核的深度channel=1
in_channels=卷积核个数=out_channels
在这里插入图片描述

1.1.1Depthwise separable convolution(深度可分的卷积操作)

深度可分的卷积操作包括DW(Depthwise)和PW(Pointwise)
在这里插入图片描述
传统卷积:
在这里插入图片描述
DW+PW卷积:
在这里插入图片描述
MobileNet Comparison to Popular Models
在这里插入图片描述

1.2 增加超参数α和β

在这里插入图片描述

2 MobileNet v2网络

相比MobileNet v1 网络准确率更高 参数更小
论文链接:MobileNetV2: Inverted Residuals and Linear Bottlenecks

2.1 Inverted Residuals(倒残差结构)

在这里插入图片描述

2.2 Linear Bottlenecks

针对倒残差结构最后一个1×1的卷积层,它使用了线性的激活函数而不是relu激活函数
只有stride=1in_channels=out_channels时,才有捷径分支
在这里插入图片描述
relu6激活函数公式:
f ( x ) = m i n ( m a x ( x , 0 ) , 6 ) f(x) = min(max(x, 0), 6) f(x)=min(max(x,0),6)
在这里插入图片描述
MobileNet v2网络模型结构参数
在这里插入图片描述

2.3 MobileNet v3

论文链接:Searching for MobileNetV3

3 课程代码

3.1 modle_v2.py

from torch import nn
import torchdef _make_divisible(ch, divisor=8, min_ch=None):"""This function is taken from the original tf repo.It ensures that all layers have a channel number that is divisible by 8It can be seen here:https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py"""if min_ch is None:min_ch = divisornew_ch = max(min_ch, int(ch + divisor / 2) // divisor * divisor)# Make sure that round down does not go down by more than 10%.if new_ch < 0.9 * ch:new_ch += divisorreturn new_chclass ConvBNReLU(nn.Sequential):def __init__(self, in_channel, out_channel, kernel_size=3, stride=1, groups=1):padding = (kernel_size - 1) // 2super(ConvBNReLU, self).__init__(nn.Conv2d(in_channel, out_channel, kernel_size, stride, padding, groups=groups, bias=False),nn.BatchNorm2d(out_channel),nn.ReLU6(inplace=True))class InvertedResidual(nn.Module):def __init__(self, in_channel, out_channel, stride, expand_ratio):super(InvertedResidual, self).__init__()hidden_channel = in_channel * expand_ratioself.use_shortcut = stride == 1 and in_channel == out_channellayers = []if expand_ratio != 1:# 1x1 pointwise convlayers.append(ConvBNReLU(in_channel, hidden_channel, kernel_size=1))layers.extend([# 3x3 depthwise convConvBNReLU(hidden_channel, hidden_channel, stride=stride, groups=hidden_channel),# 1x1 pointwise conv(linear)nn.Conv2d(hidden_channel, out_channel, kernel_size=1, bias=False),nn.BatchNorm2d(out_channel),])self.conv = nn.Sequential(*layers)def forward(self, x):if self.use_shortcut:return x + self.conv(x)else:return self.conv(x)class MobileNetV2(nn.Module):def __init__(self, num_classes=1000, alpha=1.0, round_nearest=8):super(MobileNetV2, self).__init__()block = InvertedResidualinput_channel = _make_divisible(32 * alpha, round_nearest)last_channel = _make_divisible(1280 * alpha, round_nearest)inverted_residual_setting = [# t, c, n, s[1, 16, 1, 1],[6, 24, 2, 2],[6, 32, 3, 2],[6, 64, 4, 2],[6, 96, 3, 1],[6, 160, 3, 2],[6, 320, 1, 1],]features = []# conv1 layerfeatures.append(ConvBNReLU(3, input_channel, stride=2))# building inverted residual residual blockesfor t, c, n, s in inverted_residual_setting:output_channel = _make_divisible(c * alpha, round_nearest)for i in range(n):stride = s if i == 0 else 1features.append(block(input_channel, output_channel, stride, expand_ratio=t))input_channel = output_channel# building last several layersfeatures.append(ConvBNReLU(input_channel, last_channel, 1))# combine feature layersself.features = nn.Sequential(*features)# building classifierself.avgpool = nn.AdaptiveAvgPool2d((1, 1))self.classifier = nn.Sequential(nn.Dropout(0.2),nn.Linear(last_channel, num_classes))# weight initializationfor m in self.modules():if isinstance(m, nn.Conv2d):nn.init.kaiming_normal_(m.weight, mode='fan_out')if m.bias is not None:nn.init.zeros_(m.bias)elif isinstance(m, nn.BatchNorm2d):nn.init.ones_(m.weight)nn.init.zeros_(m.bias)elif isinstance(m, nn.Linear):nn.init.normal_(m.weight, 0, 0.01)nn.init.zeros_(m.bias)def forward(self, x):x = self.features(x)x = self.avgpool(x)x = torch.flatten(x, 1)x = self.classifier(x)return x

3.2 train.py

import os
import sys
import jsonimport torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, datasets
from tqdm import tqdmfrom model_v2 import MobileNetV2def main():device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")print("using {} device.".format(device))batch_size = 16epochs = 5data_transform = {"train": transforms.Compose([transforms.RandomResizedCrop(224),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),"val": transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))  # get data root pathimage_path = os.path.join(data_root, "data_set", "flower_data")  # flower data set pathassert os.path.exists(image_path), "{} path does not exist.".format(image_path)train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),transform=data_transform["train"])train_num = len(train_dataset)# {'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4}flower_list = train_dataset.class_to_idxcla_dict = dict((val, key) for key, val in flower_list.items())# write dict into json filejson_str = json.dumps(cla_dict, indent=4)with open('class_indices.json', 'w') as json_file:json_file.write(json_str)nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workersprint('Using {} dataloader workers every process'.format(nw))train_loader = torch.utils.data.DataLoader(train_dataset,batch_size=batch_size, shuffle=True,num_workers=nw)validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),transform=data_transform["val"])val_num = len(validate_dataset)validate_loader = torch.utils.data.DataLoader(validate_dataset,batch_size=batch_size, shuffle=False,num_workers=nw)print("using {} images for training, {} images for validation.".format(train_num,val_num))# create modelnet = MobileNetV2(num_classes=5)# load pretrain weights# download url: https://download.pytorch.org/models/mobilenet_v2-b0353104.pthmodel_weight_path = "mobilenet_v2.pth"assert os.path.exists(model_weight_path), "file {} dose not exist.".format(model_weight_path)pre_weights = torch.load(model_weight_path, map_location='cpu')# delete classifier weightspre_dict = {k: v for k, v in pre_weights.items() if net.state_dict()[k].numel() == v.numel()}missing_keys, unexpected_keys = net.load_state_dict(pre_dict, strict=False)# freeze features weightsfor param in net.features.parameters():param.requires_grad = Falsenet.to(device)# define loss functionloss_function = nn.CrossEntropyLoss()# construct an optimizerparams = [p for p in net.parameters() if p.requires_grad]optimizer = optim.Adam(params, lr=0.0001)best_acc = 0.0save_path = './MobileNetV2.pth'train_steps = len(train_loader)for epoch in range(epochs):# trainnet.train()running_loss = 0.0train_bar = tqdm(train_loader, file=sys.stdout)for step, data in enumerate(train_bar):images, labels = dataoptimizer.zero_grad()logits = net(images.to(device))loss = loss_function(logits, labels.to(device))loss.backward()optimizer.step()# print statisticsrunning_loss += loss.item()train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1,epochs,loss)# validatenet.eval()acc = 0.0  # accumulate accurate number / epochwith torch.no_grad():val_bar = tqdm(validate_loader, file=sys.stdout)for val_data in val_bar:val_images, val_labels = val_dataoutputs = net(val_images.to(device))# loss = loss_function(outputs, test_labels)predict_y = torch.max(outputs, dim=1)[1]acc += torch.eq(predict_y, val_labels.to(device)).sum().item()val_bar.desc = "valid epoch[{}/{}]".format(epoch + 1,epochs)val_accurate = acc / val_numprint('[epoch %d] train_loss: %.3f  val_accuracy: %.3f' %(epoch + 1, running_loss / train_steps, val_accurate))if val_accurate > best_acc:best_acc = val_accuratetorch.save(net.state_dict(), save_path)print('Finished Training')if __name__ == '__main__':main()

在这里插入图片描述

3.3 predict.py

import os
import jsonimport torch
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as pltfrom model_v2 import MobileNetV2def main():device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")data_transform = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])# load imageimg_path = "1.jpg"assert os.path.exists(img_path), "file: '{}' dose not exist.".format(img_path)img = Image.open(img_path)plt.imshow(img)# [N, C, H, W]img = data_transform(img)# expand batch dimensionimg = torch.unsqueeze(img, dim=0)# read class_indictjson_path = './class_indices.json'assert os.path.exists(json_path), "file: '{}' dose not exist.".format(json_path)with open(json_path, "r") as f:class_indict = json.load(f)# create modelmodel = MobileNetV2(num_classes=5).to(device)# load model weightsmodel_weight_path = "./MobileNetV2.pth"model.load_state_dict(torch.load(model_weight_path, map_location=device))model.eval()with torch.no_grad():# predict classoutput = torch.squeeze(model(img.to(device))).cpu()predict = torch.softmax(output, dim=0)predict_cla = torch.argmax(predict).numpy()print_res = "class: {}   prob: {:.3}".format(class_indict[str(predict_cla)],predict[predict_cla].numpy())plt.title(print_res)for i in range(len(predict)):print("class: {:10}   prob: {:.3}".format(class_indict[str(i)],predict[i].numpy()))plt.show()if __name__ == '__main__':main()

1.jpg预测的是向日葵
在这里插入图片描述
2.jpg预测的是蒲公英
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/611161.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

传统 VC 机构,是否还能在 Fair launch 的散户牛市中胜出?

LaunchPad 是代币面向市场的重要一环&#xff0c;将代币推向市场&#xff0c;加密项目将能够通过代币的销售从市场上募集资金&#xff0c;同时生态也开始进入全新的发展阶段。而对于投资者来说&#xff0c;早期打新市场同样充满着机会&#xff0c;参与 LaunchPad 对于每一个投资…

ubuntu 18.04网络问题

ubuntu 18.04网络问题汇总 准备工作一、有线网卡不可用二、无法访问外网 准备工作 安装好系统之后&#xff0c;检查gcc和make是否已经安装 $ which gcc /usr/bin/gcc $ which make /usr/bin/make如果未安装&#xff0c;则安装gcc和make $ apt install gcc $ apt install mak…

基于ssm个性化旅游攻略定制系统设计与实现+jsp论文

摘 要 在如今社会上&#xff0c;关于信息上面的处理&#xff0c;没有任何一个企业或者个人会忽视&#xff0c;如何让信息急速传递&#xff0c;并且归档储存查询&#xff0c;采用之前的纸张记录模式已经不符合当前使用要求了。所以&#xff0c;对个性化旅游攻略信息管理的提升&…

GO语言笔记2-变量与基本数据类型

变量使用步骤 声明赋值使用 package main import "fmt" func main(){var age int //声明一个 int类型的变量叫ageage 18 //给变量用 赋值fmt.Println(age) //使用变量 输出变量的值 } 编译运行输出变量值 变量的四种使用方式 package main import "fmt&q…

InnoDB引擎

一、逻辑存储结构 ① 表空间&#xff08;ibd文件&#xff09;&#xff0c;一个mysql实例可以对应多个表空间&#xff0c;用于存储记录、索引等数据。 ② 段&#xff0c;分为数据段&#xff08;Leaf node segment&#xff09;、索引段&#xff08;Non-leaf node segment&#x…

特征工程(二)

特征工程&#xff08;二&#xff09; 特征理解 理解手上的数据&#xff0c;就可以更好的明确下一步的方向。从繁杂的切入点中&#xff0c;主要着眼于一下几个方面&#xff1a; 结构化数据与非结构化数据&#xff1b;数据的4个等级&#xff1b;识别数据中存在的缺失值&#xf…

古典吉他教师阿木:来自新疆的音乐才子

阿木,全名木合汤夏甫依克,于 1990 年 10 月 8 日出生在新疆这片美丽的土地上,是一位哈萨克族人。他是英皇认证古典吉他教师、中国社会艺术吉他考级考官、中国智慧工程研究会艺术教育委员会执行委员、新疆吉他艺术节发起人之一兼评审组组长。 阿木自幼受到哥哥的影响,对吉他产生…

强化学习第1天:马尔可夫过程

☁️主页 Nowl &#x1f525;专栏 《强化学习》 &#x1f4d1;君子坐而论道&#xff0c;少年起而行之 ​​ 一、介绍 什么是马尔可夫过程&#xff1f;马尔可夫过程是马尔可夫决策过程的基础&#xff0c;而马尔可夫决策过程便是大部分强化学习任务的抽象过程&#xff0c;本文…

206. 反转链表(Java)

题目描述&#xff1a; 给你单链表的头节点 head &#xff0c;请你反转链表&#xff0c;并返回反转后的链表。 输入&#xff1a; head [1,2,3,4,5] 输出&#xff1a; [5,4,3,2,1] 代码实现&#xff1a; 1.根据题意创建一个结点类&#xff1a; public class ListNode {int val…

Python pip 常用指令

前言 Python的pip是一个强大的包管理工具&#xff0c;它可以帮助我们安装、升级和管理Python的第三方库。以下是一些常用的pip指令。 1. 安装第三方库 使用pip安装Python库非常简单&#xff0c;只需要使用pip install命令&#xff0c;后面跟上库的名字即可。 # 安装virtuale…

苍穹外卖Day01——解决总结1中存在的问题

前序章节&#xff1a; 苍穹外卖Day01——总结1 解决总结1中存在的问题 1. 什么是JWT2. POJO、Entity、VO、DTO3. Nginx&#xff08;反向代理&#xff09;4. Data注解 1. 什么是JWT JWT&#xff08;JSON Web Token&#xff09;是一种用于在网络应用间传递信息的开放标准&#…

MongoDB多文档事务详解

事务简介 事务&#xff08;transaction&#xff09;是传统数据库所具备的一项基本能力&#xff0c;其根本目的是为数据的可靠性与一致性提供保障。而在通常的实现中&#xff0c;事务包含了一个系列的数据库读写操作&#xff0c;这些操作要么全部完成&#xff0c;要么全部撤销。…

Mysql是怎么运行的(上)

文章目录 Mysql是怎么运行的Mysql处理一条语句的流程连接管理解析与优化存储引擎 基本配置配置文件系统变量状态变量字符集四种重要的字符集MySQL中的utf8和utf8mb4各级别的字符集和比较规则MySQL中字符集的转换排序规则产生的不同的排序结果 InnoDB存储引擎介绍COMPACT行格式介…

LLM大模型和数据标注

对于那些不精通机器学习的人来说&#xff0c;像 ChatGPT 所基于的 GPT-3.5 这样的大型语言模型似乎是自给自足的。这些模型通过无监督或自我监督学习进行训练。简而言之&#xff0c;只需极少的人工干预&#xff0c;就能生成一个能像人类一样对话的模型。 这就引出了一个问题--…

性能分析与调优: Linux 文件系统观测工具

目录 一、实验 1.环境 2.mount 3.free 4.top 5.vmstat 6.sar 7.slabtop 8.strace 9.opensnoop 10.filetop 11.cachestat 二、问题 1.Ftrace实例如何实现 2.Function trace 如何跟踪实例 3.function_graph Trace 如何跟踪实例 4.trace event 如何跟踪实例 5.未…

ESP32-S3 使用内置USB下载程序、调试、LOG相关问题总结

目录 Preface&#xff1a; &#xff08;一&#xff09;为电脑安装USB驱动 &#xff08;二&#xff09;Platformio工程 &#xff08;三&#xff09;相关文章 &#xff08;四&#xff09;总结 Preface&#xff1a; esp32-s3有一个built-in的usb-jtag&#xff0c;可以用来下载…

Linux环境vscode clang-format格式化:vscode clang format command is not available亲测有效!

问题现象 vscode安装了clang-format插件&#xff0c;但是使用就报错 问题原因 设置中配置的clang-format插件工具路径不正确。 解决方案-亲测有效&#xff01; 确认本地安装了clang-format工具&#xff1a;终端输入clang-format&#xff08;也可能是clang-format-13等版本…

软件测试|MySQL CROSS JOIN:交叉连接的详细解析

简介 在 MySQL 数据库中&#xff0c;CROSS JOIN 是一种用于生成两个或多个表的笛卡尔积的连接方法。CROSS JOIN 不需要任何连接条件&#xff0c;它将左表的每一行与右表的每一行进行组合&#xff0c;从而生成一个包含所有可能组合的结果集。本文将详细介绍 MySQL 中的 CROSS J…

故事生成动漫解说视频,用Artflow AI做英语口语故事

大家好我是在看&#xff0c;记录普通人学习探索AI之路。 今天&#xff0c;我将再次为大家精心策划一个使用Artflow AI制作动漫解说视频的详尽教程&#xff0c;这个教程专为初学者设计。通过这款强大的Artflow AI工具&#xff0c;用户能够一键自动化完成从图像生成、视频剪辑到配…

性能测试LoadRunner解决动态验证码问题

对于这个问题&#xff0c;通常我们可以采取以下三个途径来解决该问题&#xff1a; 1、第一种方法&#xff0c;也是最容易想到的&#xff0c;在被测系统中暂时屏蔽验证功能&#xff0c;也就是说&#xff0c;临时修改应用&#xff0c;无论用户输入的是什么验证码&#xff0c;都…