PyTorch实现InceptionResNetV2:预训练模型适应多类别任务代码解析

系列文章目录

9种经典图片分类卷积模型系列合集(推荐程度依次递减):

  1. Se_resnet50
  2. Resnet50
  3. Xception
  4. inceptionresnetv2
  5. resnext
  6. bninception
  7. shufflenetv2
  8. polynet
  9. vggm

Imagenet的预训练inceptionresnetv2是1000个类别,笔者在其基础上添加了一个bottleneck层和一个head层,使得可以进行自定义类别训练。

源码

from __future__ import print_function, division, absolute_import
import torch
import torch.nn as nn
import torch.utils.model_zoo as model_zoo
import os
import sys

__all__ = ['InceptionResNetV2', 'inceptionresnetv2']

# Checkpoint location and preprocessing metadata. Both entries point at the
# same weight file; they differ only in the advertised number of classes.
pretrained_settings = {
    'inceptionresnetv2': {
        'imagenet': {
            'url': 'http://data.lip6.fr/cadene/pretrainedmodels/inceptionresnetv2-520b38e4.pth',
            'input_space': 'RGB',
            'input_size': [3, 299, 299],
            'input_range': [0, 1],
            'mean': [0.5, 0.5, 0.5],
            'std': [0.5, 0.5, 0.5],
            'num_classes': 1000
        },
        'imagenet+background': {
            'url': 'http://data.lip6.fr/cadene/pretrainedmodels/inceptionresnetv2-520b38e4.pth',
            'input_space': 'RGB',
            'input_size': [3, 299, 299],
            'input_range': [0, 1],
            'mean': [0.5, 0.5, 0.5],
            'std': [0.5, 0.5, 0.5],
            'num_classes': 1001
        }
    }
}


def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=dilation, groups=groups, bias=False, dilation=dilation)


def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


class BasicBlock(nn.Module):
    """Two-layer 3x3 residual block.

    Not instantiated by InceptionResNetV2 in this file; it is only referenced
    by the ``zero_init_residual`` branch of that class.
    """

    expansion = 1
    __constants__ = ['downsample']

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(BasicBlock, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if groups != 1 or base_width != 64:
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return relu(bn2(conv2(relu(bn1(conv1(x))))) + shortcut(x))."""
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)
        return out


class Bottleneck(nn.Module):
    """1x1 -> 3x3 -> 1x1 residual block with a 4x channel expansion.

    Not instantiated by InceptionResNetV2 in this file; it is only referenced
    by the ``zero_init_residual`` branch of that class.
    """

    expansion = 4
    __constants__ = ['downsample']

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(Bottleneck, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        width = int(planes * (base_width / 64.)) * groups
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = norm_layer(width)
        self.conv2 = conv3x3(width, width, stride, groups, dilation)
        self.bn2 = norm_layer(width)
        self.conv3 = conv1x1(width, planes * self.expansion)
        self.bn3 = norm_layer(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Apply the three conv/bn stages, add the shortcut, then ReLU."""
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)
        return out


class BasicConv2d(nn.Module):
    """Bias-free Conv2d followed by BatchNorm2d and ReLU."""

    def __init__(self, in_planes, out_planes, kernel_size, stride, padding=0):
        super(BasicConv2d, self).__init__()
        self.conv = nn.Conv2d(in_planes, out_planes,
                              kernel_size=kernel_size, stride=stride,
                              padding=padding, bias=False)  # verify bias false
        self.bn = nn.BatchNorm2d(out_planes,
                                 eps=0.001,  # value found in tensorflow
                                 momentum=0.1,  # default pytorch value
                                 affine=True)
        self.relu = nn.ReLU(inplace=False)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        return x


class Mixed_5b(nn.Module):
    """Four-branch inception block: 192 input channels -> 96 + 64 + 96 + 64 = 320."""

    def __init__(self):
        super(Mixed_5b, self).__init__()

        self.branch0 = BasicConv2d(192, 96, kernel_size=1, stride=1)

        self.branch1 = nn.Sequential(
            BasicConv2d(192, 48, kernel_size=1, stride=1),
            BasicConv2d(48, 64, kernel_size=5, stride=1, padding=2)
        )

        self.branch2 = nn.Sequential(
            BasicConv2d(192, 64, kernel_size=1, stride=1),
            BasicConv2d(64, 96, kernel_size=3, stride=1, padding=1),
            BasicConv2d(96, 96, kernel_size=3, stride=1, padding=1)
        )

        self.branch3 = nn.Sequential(
            nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False),
            BasicConv2d(192, 64, kernel_size=1, stride=1)
        )

    def forward(self, x):
        x0 = self.branch0(x)
        x1 = self.branch1(x)
        x2 = self.branch2(x)
        x3 = self.branch3(x)
        # Concatenate the branches along the channel dimension.
        out = torch.cat((x0, x1, x2, x3), 1)
        return out


class Block35(nn.Module):
    """Residual inception block on 320 channels.

    Args:
        scale: factor applied to the residual branch before it is added to
            the input.
    """

    def __init__(self, scale=1.0):
        super(Block35, self).__init__()

        self.scale = scale

        self.branch0 = BasicConv2d(320, 32, kernel_size=1, stride=1)

        self.branch1 = nn.Sequential(
            BasicConv2d(320, 32, kernel_size=1, stride=1),
            BasicConv2d(32, 32, kernel_size=3, stride=1, padding=1)
        )

        self.branch2 = nn.Sequential(
            BasicConv2d(320, 32, kernel_size=1, stride=1),
            BasicConv2d(32, 48, kernel_size=3, stride=1, padding=1),
            BasicConv2d(48, 64, kernel_size=3, stride=1, padding=1)
        )

        # Projects the 32 + 32 + 64 = 128 concatenated channels back to 320.
        self.conv2d = nn.Conv2d(128, 320, kernel_size=1, stride=1)
        self.relu = nn.ReLU(inplace=False)

    def forward(self, x):
        x0 = self.branch0(x)
        x1 = self.branch1(x)
        x2 = self.branch2(x)
        out = torch.cat((x0, x1, x2), 1)
        out = self.conv2d(out)
        out = out * self.scale + x
        out = self.relu(out)
        return out


class Mixed_6a(nn.Module):
    """Stride-2 reduction block: 320 input channels -> 384 + 384 + 320 = 1088."""

    def __init__(self):
        super(Mixed_6a, self).__init__()

        self.branch0 = BasicConv2d(320, 384, kernel_size=3, stride=2)

        self.branch1 = nn.Sequential(
            BasicConv2d(320, 256, kernel_size=1, stride=1),
            BasicConv2d(256, 256, kernel_size=3, stride=1, padding=1),
            BasicConv2d(256, 384, kernel_size=3, stride=2)
        )

        self.branch2 = nn.MaxPool2d(3, stride=2)

    def forward(self, x):
        x0 = self.branch0(x)
        x1 = self.branch1(x)
        x2 = self.branch2(x)
        out = torch.cat((x0, x1, x2), 1)
        return out


class Block17(nn.Module):
    """Residual inception block on 1088 channels using 1x7 / 7x1 convolutions.

    Args:
        scale: factor applied to the residual branch before it is added to
            the input.
    """

    def __init__(self, scale=1.0):
        super(Block17, self).__init__()

        self.scale = scale

        self.branch0 = BasicConv2d(1088, 192, kernel_size=1, stride=1)

        self.branch1 = nn.Sequential(
            BasicConv2d(1088, 128, kernel_size=1, stride=1),
            BasicConv2d(128, 160, kernel_size=(1, 7), stride=1, padding=(0, 3)),
            BasicConv2d(160, 192, kernel_size=(7, 1), stride=1, padding=(3, 0))
        )

        # Projects the 192 + 192 = 384 concatenated channels back to 1088.
        self.conv2d = nn.Conv2d(384, 1088, kernel_size=1, stride=1)
        self.relu = nn.ReLU(inplace=False)

    def forward(self, x):
        x0 = self.branch0(x)
        x1 = self.branch1(x)
        out = torch.cat((x0, x1), 1)
        out = self.conv2d(out)
        out = out * self.scale + x
        out = self.relu(out)
        return out


class Mixed_7a(nn.Module):
    """Stride-2 reduction block: 1088 input channels -> 384 + 288 + 320 + 1088 = 2080."""

    def __init__(self):
        super(Mixed_7a, self).__init__()

        self.branch0 = nn.Sequential(
            BasicConv2d(1088, 256, kernel_size=1, stride=1),
            BasicConv2d(256, 384, kernel_size=3, stride=2)
        )

        self.branch1 = nn.Sequential(
            BasicConv2d(1088, 256, kernel_size=1, stride=1),
            BasicConv2d(256, 288, kernel_size=3, stride=2)
        )

        self.branch2 = nn.Sequential(
            BasicConv2d(1088, 256, kernel_size=1, stride=1),
            BasicConv2d(256, 288, kernel_size=3, stride=1, padding=1),
            BasicConv2d(288, 320, kernel_size=3, stride=2)
        )

        self.branch3 = nn.MaxPool2d(3, stride=2)

    def forward(self, x):
        x0 = self.branch0(x)
        x1 = self.branch1(x)
        x2 = self.branch2(x)
        x3 = self.branch3(x)
        out = torch.cat((x0, x1, x2, x3), 1)
        return out


class Block8(nn.Module):
    """Residual inception block on 2080 channels using 1x3 / 3x1 convolutions.

    Args:
        scale: factor applied to the residual branch before it is added to
            the input.
        noReLU: when True, no ReLU module is created and the sum is returned
            without activation.
    """

    def __init__(self, scale=1.0, noReLU=False):
        super(Block8, self).__init__()

        self.scale = scale
        self.noReLU = noReLU

        self.branch0 = BasicConv2d(2080, 192, kernel_size=1, stride=1)

        self.branch1 = nn.Sequential(
            BasicConv2d(2080, 192, kernel_size=1, stride=1),
            BasicConv2d(192, 224, kernel_size=(1, 3), stride=1, padding=(0, 1)),
            BasicConv2d(224, 256, kernel_size=(3, 1), stride=1, padding=(1, 0))
        )

        # Projects the 192 + 256 = 448 concatenated channels back to 2080.
        self.conv2d = nn.Conv2d(448, 2080, kernel_size=1, stride=1)
        if not self.noReLU:
            self.relu = nn.ReLU(inplace=False)

    def forward(self, x):
        x0 = self.branch0(x)
        x1 = self.branch1(x)
        out = torch.cat((x0, x1), 1)
        out = self.conv2d(out)
        out = out * self.scale + x
        if not self.noReLU:
            out = self.relu(out)
        return out


class InceptionResNetV2(nn.Module):
    """Inception-ResNet-v2 backbone with a custom bottleneck + head classifier.

    The single 1536 -> num_classes linear layer of the reference model is
    replaced by ``bottleneck`` (1536 -> 512) and ``head`` (512 -> 512 ->
    num_classes) so the network can be trained on an arbitrary number of
    classes.

    Args:
        num_classes: width of the final linear layer of ``head``.
        zero_init_residual: zero the last BatchNorm weight of every
            ``Bottleneck`` / ``BasicBlock`` sub-module. This class creates no
            such sub-modules itself, so the flag only matters if they are
            added to the model.
    """

    def __init__(self, num_classes=1001, zero_init_residual=False):
        super(InceptionResNetV2, self).__init__()
        # Special attributes (overwritten by inceptionresnetv2() when
        # pretrained settings are applied).
        self.input_space = None
        self.input_size = (299, 299, 3)
        self.input_range = None
        self.mean = None
        self.std = None
        # Modules
        self.conv2d_1a = BasicConv2d(3, 32, kernel_size=3, stride=2)
        self.conv2d_2a = BasicConv2d(32, 32, kernel_size=3, stride=1)
        self.conv2d_2b = BasicConv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.maxpool_3a = nn.MaxPool2d(3, stride=2)
        self.conv2d_3b = BasicConv2d(64, 80, kernel_size=1, stride=1)
        self.conv2d_4a = BasicConv2d(80, 192, kernel_size=3, stride=1)
        self.maxpool_5a = nn.MaxPool2d(3, stride=2)
        self.mixed_5b = Mixed_5b()
        # Sub-modules are indexed 0..N-1 inside each Sequential, so the
        # parameter names are the same as when the blocks are listed by hand.
        self.repeat = nn.Sequential(*[Block35(scale=0.17) for _ in range(10)])
        self.mixed_6a = Mixed_6a()
        self.repeat_1 = nn.Sequential(*[Block17(scale=0.10) for _ in range(20)])
        self.mixed_7a = Mixed_7a()
        self.repeat_2 = nn.Sequential(*[Block8(scale=0.20) for _ in range(9)])
        self.block8 = Block8(noReLU=True)
        self.conv2d_7b = BasicConv2d(2080, 1536, kernel_size=1, stride=1)
        # NOTE(review): a 5x5 window matches the final feature map of a
        # 224x224 input; with a 299x299 input the map is 8x8 and only its
        # first 5x5 window is pooled -- confirm the intended input size.
        self.avgpool_1a = nn.AvgPool2d(5, count_include_pad=False)

        self.bottleneck = nn.Sequential(
            nn.Linear(1536, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(0.5)
        )
        self.bottleneck[0].weight.data.normal_(0, 0.005)
        self.bottleneck[0].bias.data.fill_(0.1)

        self.head = nn.Sequential(
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, num_classes)
        )
        # head[0] and head[3] are the two Linear layers.
        for dep in range(2):
            self.head[dep * 3].weight.data.normal_(0, 0.01)
            self.head[dep * 3].bias.data.fill_(0.0)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def features(self, input):
        """Run the convolutional backbone; the result has 1536 channels."""
        x = self.conv2d_1a(input)
        x = self.conv2d_2a(x)
        x = self.conv2d_2b(x)
        x = self.maxpool_3a(x)
        x = self.conv2d_3b(x)
        x = self.conv2d_4a(x)
        x = self.maxpool_5a(x)
        x = self.mixed_5b(x)
        x = self.repeat(x)
        x = self.mixed_6a(x)
        x = self.repeat_1(x)
        x = self.mixed_7a(x)
        x = self.repeat_2(x)
        x = self.block8(x)
        x = self.conv2d_7b(x)
        return x

    def logits(self, features):
        """Pool, flatten and classify the output of :meth:`features`."""
        x = self.avgpool_1a(features)
        # Flatten to (batch, -1); the bottleneck expects 1536 values per sample.
        x = x.view(x.size(0), -1)
        x = self.bottleneck(x)
        x = self.head(x)
        return x

    def forward(self, input):
        x = self.features(input)
        x = self.logits(x)
        return x


def inceptionresnetv2(num_classes=1000, pretrained='imagenet'):
    r"""InceptionResNetV2 model architecture from the
    `"InceptionV4, Inception-ResNet..." <https://arxiv.org/abs/1602.07261>`_ paper.

    Args:
        num_classes: number of outputs of the classification head.
        pretrained: any truthy value loads the checkpoint listed in
            ``pretrained_settings``; a falsy value (``None``) returns a
            randomly initialised model.

    Returns:
        An ``InceptionResNetV2`` instance whose head has ``num_classes``
        outputs in both cases.
    """
    if pretrained:
        # Both 'imagenet' and 'imagenet+background' are loaded from the same
        # parameters, so every truthy value maps to the same settings entry.
        pretrained = 'imagenet+background'
        settings = pretrained_settings['inceptionresnetv2'][pretrained]

        # Build the head with the caller's num_classes (previously a
        # hard-coded 1001 was used and the argument was ignored).
        model = InceptionResNetV2(num_classes=num_classes)
        # strict=False: the checkpoint has no bottleneck/head parameters and
        # its own classifier key has no counterpart in this model, so only
        # the parameters whose names match are restored.
        model.load_state_dict(model_zoo.load_url(settings['url']), strict=False)

        model.input_space = settings['input_space']
        model.input_size = settings['input_size']
        model.input_range = settings['input_range']
        model.mean = settings['mean']
        model.std = settings['std']
    else:
        model = InceptionResNetV2(num_classes=num_classes)
    return model


'''
TEST
Run this code with:

cd $HOME/pretrained-models.pytorch

python -m pretrainedmodels.inceptionresnetv2

'''
if __name__ == '__main__':
    # Smoke test; every call with a truthy `pretrained` downloads the checkpoint.
    assert inceptionresnetv2(num_classes=10, pretrained=None)
    print('success')
    assert inceptionresnetv2(num_classes=1000, pretrained='imagenet')
    print('success')
    assert inceptionresnetv2(num_classes=1001, pretrained='imagenet+background')
    print('success')

    # Originally marked "fail"; the num_classes consistency check in the
    # factory is disabled, so this call succeeds as well.
    assert inceptionresnetv2(num_classes=1001, pretrained='imagenet')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/867681.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Go 中的类型推断

💝💝💝欢迎莅临我的博客,很高兴能够在这里和您见面!希望您在这里可以感受到一份轻松愉快的氛围,不仅可以获得有趣的内容和知识,也可以畅所欲言、分享您的想法和见解。 推荐:「stormsha的主页」…

【三级等保】等保整体建设方案(Word原件)

建设要点目录: 1、系统定级与安全域 2、实施方案设计 3、安全防护体系建设规划 软件全文档,全方案获取方式:本文末个人名片直接获取。

【Python】基于KMeans的航空公司客户数据聚类分析

💐大家好!我是码银~,欢迎关注💐: CSDN:码银 公众号:码银学编程 实验目的和要求 会用Python创建Kmeans聚类分析模型使用KMeans模型对航空公司客户价值进行聚类分析会对聚类结果进行分析评价 实…

Python酷库之旅-第三方库Pandas(008)

目录 一、用法精讲 16、pandas.DataFrame.to_json函数 16-1、语法 16-2、参数 16-3、功能 16-4、返回值 16-5、说明 16-6、用法 16-6-1、数据准备 16-6-2、代码示例 16-6-3、结果输出 17、pandas.read_html函数 17-1、语法 17-2、参数 17-3、功能 17-4、返回值…

CountDownLatch简介

引言 在多线程编程中,线程之间的协调和同步是一个常见的需求。Java 提供了多种工具来实现这一目标,其中 CountDownLatch 是一种简单而强大的同步机制。本文将详细介绍 CountDownLatch 的概念、使用方法和实际应用场景。 1. CountDownLatch 概述 Count…

Redis新手教程

Redis新手教程 目录 什么是RedisRedis的安装 安装前准备安装步骤 Redis的基本数据类型 字符串哈希列表集合有序集合 Redis的持久化 快照AOF Redis的高可用性 主从复制Redis SentinelRedis Cluster Redis的使用场景Redis的优缺点总结 1. 什么是Redis Redis(Remot…

IPython 调试秘籍:精通 %xmode 命令的错误显示模式设置

IPython 调试秘籍&#xff1a;精通 %xmode 命令的错误显示模式设置 在使用 IPython 进行交互式编程时&#xff0c;错误信息的显示模式对于调试代码至关重要。%xmode 命令是 IPython 中专门用于控制错误信息展示方式的魔术命令。本文将详细解释 %xmode 命令的使用方法&#xff…

介绍东芝TB62262FTAG芯片:高性能两相双极步进电机驱动器

在当今快速发展的科技领域,高性能的电机驱动器对于许多工程项目来说至关重要。东芝的TB62262FTAG这款两相双极步进电机驱动器采用PWM斩波技术,集成了多个先进功能,适用于各种工业和消费类应用。本文将详细介绍TB62262FTAG的参数、性能、优势及…

ubuntu22 设置开机直接登录桌面

专栏总目录 一、打开设置文件 sudo vi /etc/gdm3/custom.conf 二、修改设置 在[daemon] 找到AutomaticLoginEnable和AutomaticLogin选项&#xff0c;取消注释并修改为&#xff1a; [daemon] # 自动登录用户名 AutomaticLoginEnableTrue AutomaticLoginusername 其中usernam…

《向量数据库指南》——Milvus Cloud检索器增强的深度探讨:句子窗口检索与元数据过滤

检索器增强的深度探讨:句子窗口检索与元数据过滤 在信息爆炸的时代,高效的检索系统成为了连接用户与海量数据的关键桥梁。为了进一步提升检索的准确性和用户满意度,检索器增强技术应运而生,其中句子窗口检索与元数据过滤作为两大…

【Qt】day3 自定义控件、框架、定时器、QPainter、QFile

文章目录 自定义控件封装自定义框架定时器第一种方式第二种方式 (推荐) 事件分发器QPainter基本操作高级设置抗锯齿移动坐标原点 画家画资源图片,并实现手动移动 作业QPaintDevice绘图设备QPixmapQimageQPicture QFile文件读写操作QFileInfo文…

移动校园(3):处理全校课程数据excel文档,实现空闲教室查询与课程表查询

首先打开教学平台 然后导出为excel文档 import mathimport pandas as pd import pymssql serverName 127.0.0.1 userName sa passWord 123456 databaseuniSchool conn pymssql.connect(serverserverName,useruserName,passwordpassWord,databasedatabase) cursor conn.cur…

Python面试题:在 Python 中,如何实现上下文管理器(context manager)?

在 Python 中&#xff0c;实现上下文管理器&#xff08;context manager&#xff09;有两种常见的方法&#xff1a;使用类和使用装饰器&#xff08;contextlib 模块中的 contextmanager 装饰器&#xff09;。上下文管理器用于管理资源&#xff0c;例如文件、网络连接等&#xf…

昇思11天

基于 MindSpore 实现 BERT 对话情绪识别 BERT模型概述 BERT&#xff08;Bidirectional Encoder Representations from Transformers&#xff09;是由Google于2018年开发并发布的一种新型语言模型。BERT在许多自然语言处理&#xff08;NLP&#xff09;任务中发挥着重要作用&am…

【C++】map和set详解

目录 1. 关联式容器 2. 键值对pair 3. 树形结构的关联式容器 4. set 4.1 set的介绍 4.2 set的构造 4.3 set的迭代器 4.4 set的容量 4.5 set的常用函数 5. multiset 6. map 6.1 map的介绍 6.2 map的构造 6.3 map的迭代器 6.4 map的容量 6.5 map的operator[] 6.6…

BiLSTM模型实现

# 本段代码构建类BiLSTM, 完成初始化和网络结构的搭建 # 总共3层: 词嵌入层, 双向LSTM层, 全连接线性层 # 本段代码构建类BiLSTM, 完成初始化和网络结构的搭建 # 总共3层: 词嵌入层, 双向LSTM层, 全连接线性层 import torch import torch.nn as nn# 本函数实现将中文文本映射为…

Android使用HttpURLConnection实现文件上传(包括图片)

1.文件上传完整代码 下面是完整的文件上传代码,复制即可使用。注意要开启子线程运行 public class UploadFileTask {/*** 上传文件到服务器,并返回服务器相应结果* param requestURL 服务器的地址* param imageUri 文件的Uri* param context* return 服…

【虚幻引擎】UE4初学者系列教程开发进阶实战篇——生存游戏案例

一、课程体系 1 学前必读 2 Character类相关基础 -人物移动控制 -动画蓝图 3 常见游戏机制基础 -碰撞器、触发器 -物体使用接口 -视角切换 4其他相关设计 -背包系统 -锻造系统 -物体破碎效果 -简易种植系统 -互动物体动画 5课程结语 二、UI部分 思维导图部分 实操部分 …

只需4500字,带你学习Python中7种基础数据类型!

Python 语言以其简洁、高效和强大的功能,成为了无数开发者和编程爱好者的首选。无论是数据分析、人工智能、网络开发还是自动化脚本,Python 都能以其优雅的语法和丰富的库支持,让编程变得更加简单而有趣。 但正如建造一座大厦需要坚实的地基…

如何借助AI在20分钟内写一个springboot单表的增删改查

目录 1. AI工具介绍2. 写代码的正确顺序2.1 编写 Entity 类&#xff1a;2.2 编写 Mapper 接口&#xff1a;2.3 编写 Mapper XML 文件&#xff08;如果使用 MyBatis&#xff09;&#xff1a;2.4 编写 Service 接口&#xff1a;2.5 编写 Service 实现类&#xff08;ServiceImpl&a…