PyTorch实现InceptionResNetV2:预训练模型适应多类别任务代码解析

系列文章目录

9种经典图片分类卷积模型系列合集(推荐程度依次递减):

  1. Se_resnet50
  2. Resnet50
  3. Xception
  4. inceptionresnetv2
  5. resnext
  6. bninception
  7. shufflenetv2
  8. polynet
  9. vggm

ImageNet 预训练的 inceptionresnetv2 输出 1000 个类别,笔者在其基础上添加了一个 bottleneck 层和一个 head 层,使其可以进行自定义类别的训练。

源码

from __future__ import print_function, division, absolute_import
import torch
import torch.nn as nn
import torch.utils.model_zoo as model_zoo
import os
import sys__all__ = ['InceptionResNetV2', 'inceptionresnetv2']pretrained_settings = {'inceptionresnetv2': {'imagenet': {'url': 'http://data.lip6.fr/cadene/pretrainedmodels/inceptionresnetv2-520b38e4.pth','input_space': 'RGB','input_size': [3, 299, 299],'input_range': [0, 1],'mean': [0.5, 0.5, 0.5],'std': [0.5, 0.5, 0.5],'num_classes': 1000},'imagenet+background': {'url': 'http://data.lip6.fr/cadene/pretrainedmodels/inceptionresnetv2-520b38e4.pth','input_space': 'RGB','input_size': [3, 299, 299],'input_range': [0, 1],'mean': [0.5, 0.5, 0.5],'std': [0.5, 0.5, 0.5],'num_classes': 1001}}
}def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):"""3x3 convolution with padding"""return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,padding=dilation, groups=groups, bias=False, dilation=dilation)def conv1x1(in_planes, out_planes, stride=1):"""1x1 convolution"""return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)class BasicBlock(nn.Module):expansion = 1__constants__ = ['downsample']def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,base_width=64, dilation=1, norm_layer=None):super(BasicBlock, self).__init__()if norm_layer is None:norm_layer = nn.BatchNorm2dif groups != 1 or base_width != 64:raise ValueError('BasicBlock only supports groups=1 and base_width=64')if dilation > 1:raise NotImplementedError("Dilation > 1 not supported in BasicBlock")# Both self.conv1 and self.downsample layers downsample the input when stride != 1self.conv1 = conv3x3(inplanes, planes, stride)self.bn1 = norm_layer(planes)self.relu = nn.ReLU(inplace=True)self.conv2 = conv3x3(planes, planes)self.bn2 = norm_layer(planes)self.downsample = downsampleself.stride = stridedef forward(self, x):identity = xout = self.conv1(x)out = self.bn1(out)out = self.relu(out)out = self.conv2(out)out = self.bn2(out)if self.downsample is not None:identity = self.downsample(x)out += identityout = self.relu(out)return outclass Bottleneck(nn.Module):expansion = 4__constants__ = ['downsample']def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,base_width=64, dilation=1, norm_layer=None):super(Bottleneck, self).__init__()if norm_layer is None:norm_layer = nn.BatchNorm2dwidth = int(planes * (base_width / 64.)) * groupsself.conv1 = conv1x1(inplanes, width)self.bn1 = norm_layer(width)self.conv2 = conv3x3(width, width, stride, groups, dilation)self.bn2 = norm_layer(width)self.conv3 = conv1x1(width, planes * self.expansion)self.bn3 = norm_layer(planes * self.expansion)self.relu = 
nn.ReLU(inplace=True)self.downsample = downsampleself.stride = stridedef forward(self, x):identity = xout = self.conv1(x)out = self.bn1(out)out = self.relu(out)out = self.conv2(out)out = self.bn2(out)out = self.relu(out)out = self.conv3(out)out = self.bn3(out)if self.downsample is not None:identity = self.downsample(x)out += identityout = self.relu(out)return outclass BasicConv2d(nn.Module):def __init__(self, in_planes, out_planes, kernel_size, stride, padding=0):super(BasicConv2d, self).__init__()self.conv = nn.Conv2d(in_planes, out_planes,kernel_size=kernel_size, stride=stride,padding=padding, bias=False) # verify bias falseself.bn = nn.BatchNorm2d(out_planes,eps=0.001, # value found in tensorflowmomentum=0.1, # default pytorch valueaffine=True)self.relu = nn.ReLU(inplace=False)def forward(self, x):x = self.conv(x)x = self.bn(x)x = self.relu(x)return xclass Mixed_5b(nn.Module):def __init__(self):super(Mixed_5b, self).__init__()self.branch0 = BasicConv2d(192, 96, kernel_size=1, stride=1)self.branch1 = nn.Sequential(BasicConv2d(192, 48, kernel_size=1, stride=1),BasicConv2d(48, 64, kernel_size=5, stride=1, padding=2))self.branch2 = nn.Sequential(BasicConv2d(192, 64, kernel_size=1, stride=1),BasicConv2d(64, 96, kernel_size=3, stride=1, padding=1),BasicConv2d(96, 96, kernel_size=3, stride=1, padding=1))self.branch3 = nn.Sequential(nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False),BasicConv2d(192, 64, kernel_size=1, stride=1))def forward(self, x):x0 = self.branch0(x)x1 = self.branch1(x)x2 = self.branch2(x)x3 = self.branch3(x)out = torch.cat((x0, x1, x2, x3), 1)return outclass Block35(nn.Module):def __init__(self, scale=1.0):super(Block35, self).__init__()self.scale = scaleself.branch0 = BasicConv2d(320, 32, kernel_size=1, stride=1)self.branch1 = nn.Sequential(BasicConv2d(320, 32, kernel_size=1, stride=1),BasicConv2d(32, 32, kernel_size=3, stride=1, padding=1))self.branch2 = nn.Sequential(BasicConv2d(320, 32, kernel_size=1, stride=1),BasicConv2d(32, 48, 
kernel_size=3, stride=1, padding=1),BasicConv2d(48, 64, kernel_size=3, stride=1, padding=1))self.conv2d = nn.Conv2d(128, 320, kernel_size=1, stride=1)self.relu = nn.ReLU(inplace=False)def forward(self, x):x0 = self.branch0(x)x1 = self.branch1(x)x2 = self.branch2(x)out = torch.cat((x0, x1, x2), 1)out = self.conv2d(out)out = out * self.scale + xout = self.relu(out)return outclass Mixed_6a(nn.Module):def __init__(self):super(Mixed_6a, self).__init__()self.branch0 = BasicConv2d(320, 384, kernel_size=3, stride=2)self.branch1 = nn.Sequential(BasicConv2d(320, 256, kernel_size=1, stride=1),BasicConv2d(256, 256, kernel_size=3, stride=1, padding=1),BasicConv2d(256, 384, kernel_size=3, stride=2))self.branch2 = nn.MaxPool2d(3, stride=2)def forward(self, x):x0 = self.branch0(x)x1 = self.branch1(x)x2 = self.branch2(x)out = torch.cat((x0, x1, x2), 1)return outclass Block17(nn.Module):def __init__(self, scale=1.0):super(Block17, self).__init__()self.scale = scaleself.branch0 = BasicConv2d(1088, 192, kernel_size=1, stride=1)self.branch1 = nn.Sequential(BasicConv2d(1088, 128, kernel_size=1, stride=1),BasicConv2d(128, 160, kernel_size=(1,7), stride=1, padding=(0,3)),BasicConv2d(160, 192, kernel_size=(7,1), stride=1, padding=(3,0)))self.conv2d = nn.Conv2d(384, 1088, kernel_size=1, stride=1)self.relu = nn.ReLU(inplace=False)def forward(self, x):x0 = self.branch0(x)x1 = self.branch1(x)out = torch.cat((x0, x1), 1)out = self.conv2d(out)out = out * self.scale + xout = self.relu(out)return outclass Mixed_7a(nn.Module):def __init__(self):super(Mixed_7a, self).__init__()self.branch0 = nn.Sequential(BasicConv2d(1088, 256, kernel_size=1, stride=1),BasicConv2d(256, 384, kernel_size=3, stride=2))self.branch1 = nn.Sequential(BasicConv2d(1088, 256, kernel_size=1, stride=1),BasicConv2d(256, 288, kernel_size=3, stride=2))self.branch2 = nn.Sequential(BasicConv2d(1088, 256, kernel_size=1, stride=1),BasicConv2d(256, 288, kernel_size=3, stride=1, padding=1),BasicConv2d(288, 320, kernel_size=3, 
stride=2))self.branch3 = nn.MaxPool2d(3, stride=2)def forward(self, x):x0 = self.branch0(x)x1 = self.branch1(x)x2 = self.branch2(x)x3 = self.branch3(x)out = torch.cat((x0, x1, x2, x3), 1)return outclass Block8(nn.Module):def __init__(self, scale=1.0, noReLU=False):super(Block8, self).__init__()self.scale = scaleself.noReLU = noReLUself.branch0 = BasicConv2d(2080, 192, kernel_size=1, stride=1)self.branch1 = nn.Sequential(BasicConv2d(2080, 192, kernel_size=1, stride=1),BasicConv2d(192, 224, kernel_size=(1,3), stride=1, padding=(0,1)),BasicConv2d(224, 256, kernel_size=(3,1), stride=1, padding=(1,0)))self.conv2d = nn.Conv2d(448, 2080, kernel_size=1, stride=1)if not self.noReLU:self.relu = nn.ReLU(inplace=False)def forward(self, x):x0 = self.branch0(x)x1 = self.branch1(x)out = torch.cat((x0, x1), 1)out = self.conv2d(out)out = out * self.scale + xif not self.noReLU:out = self.relu(out)return outclass InceptionResNetV2(nn.Module):def __init__(self, num_classes=1001, zero_init_residual=False):super(InceptionResNetV2, self).__init__()# Special attributsself.input_space = Noneself.input_size = (299, 299, 3)self.mean = Noneself.std = None# Modulesself.conv2d_1a = BasicConv2d(3, 32, kernel_size=3, stride=2)self.conv2d_2a = BasicConv2d(32, 32, kernel_size=3, stride=1)self.conv2d_2b = BasicConv2d(32, 64, kernel_size=3, stride=1, padding=1)self.maxpool_3a = nn.MaxPool2d(3, stride=2)self.conv2d_3b = BasicConv2d(64, 80, kernel_size=1, stride=1)self.conv2d_4a = BasicConv2d(80, 192, kernel_size=3, stride=1)self.maxpool_5a = nn.MaxPool2d(3, stride=2)self.mixed_5b = Mixed_5b()self.repeat = nn.Sequential(Block35(scale=0.17),Block35(scale=0.17),Block35(scale=0.17),Block35(scale=0.17),Block35(scale=0.17),Block35(scale=0.17),Block35(scale=0.17),Block35(scale=0.17),Block35(scale=0.17),Block35(scale=0.17))self.mixed_6a = Mixed_6a()self.repeat_1 = 
nn.Sequential(Block17(scale=0.10),Block17(scale=0.10),Block17(scale=0.10),Block17(scale=0.10),Block17(scale=0.10),Block17(scale=0.10),Block17(scale=0.10),Block17(scale=0.10),Block17(scale=0.10),Block17(scale=0.10),Block17(scale=0.10),Block17(scale=0.10),Block17(scale=0.10),Block17(scale=0.10),Block17(scale=0.10),Block17(scale=0.10),Block17(scale=0.10),Block17(scale=0.10),Block17(scale=0.10),Block17(scale=0.10))self.mixed_7a = Mixed_7a()self.repeat_2 = nn.Sequential(Block8(scale=0.20),Block8(scale=0.20),Block8(scale=0.20),Block8(scale=0.20),Block8(scale=0.20),Block8(scale=0.20),Block8(scale=0.20),Block8(scale=0.20),Block8(scale=0.20))self.block8 = Block8(noReLU=True)self.conv2d_7b = BasicConv2d(2080, 1536, kernel_size=1, stride=1)self.avgpool_1a = nn.AvgPool2d(5, count_include_pad=False)self.bottleneck = nn.Sequential(nn.Linear(1536, 512),nn.BatchNorm1d(512),nn.ReLU(),nn.Dropout(0.5))self.bottleneck[0].weight.data.normal_(0, 0.005)self.bottleneck[0].bias.data.fill_(0.1)self.head = nn.Sequential(nn.Linear(512, 512),nn.ReLU(),nn.Dropout(0.5),nn.Linear(512, num_classes))# self.fc = nn.Linear(512, num_classes)for dep in range(2):self.head[dep * 3].weight.data.normal_(0, 0.01)self.head[dep * 3].bias.data.fill_(0.0)for m in self.modules():if isinstance(m, nn.Conv2d):nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):nn.init.constant_(m.weight, 1)nn.init.constant_(m.bias, 0)if zero_init_residual:for m in self.modules():if isinstance(m, Bottleneck):nn.init.constant_(m.bn3.weight, 0)elif isinstance(m, BasicBlock):nn.init.constant_(m.bn2.weight, 0)def features(self, input):x = self.conv2d_1a(input)x = self.conv2d_2a(x)x = self.conv2d_2b(x)x = self.maxpool_3a(x)x = self.conv2d_3b(x)x = self.conv2d_4a(x)x = self.maxpool_5a(x)x = self.mixed_5b(x)x = self.repeat(x)x = self.mixed_6a(x)x = self.repeat_1(x)x = self.mixed_7a(x)x = self.repeat_2(x)x = self.block8(x)x = self.conv2d_7b(x)return xdef logits(self, 
features):x = self.avgpool_1a(features)# print("x1.size={}".format(x.shape))x = x.view(x.size(0), -1)# print("x2.size={}".format(x.shape))x = self.bottleneck(x)x = self.head(x)# x = self.last_linear(x)return xdef forward(self, input):x = self.features(input)# print("x0.size={}".format(x.shape))x = self.logits(x)return xdef inceptionresnetv2(num_classes=1000, pretrained='imagenet'):r"""InceptionResNetV2 model architecture from the`"InceptionV4, Inception-ResNet..." <https://arxiv.org/abs/1602.07261>`_ paper."""if pretrained:pretrained = 'imagenet+background'num_classes_hat = 1001settings = pretrained_settings['inceptionresnetv2'][pretrained]# print(settings)# print('num=%d\n',num_classes)# assert num_classes == settings['num_classes'], \#     "num_classes should be {}, but is {}".format(settings['num_classes'], num_classes)# both 'imagenet'&'imagenet+background' are loaded from same parametersmodel = InceptionResNetV2(num_classes=num_classes_hat)model.load_state_dict(model_zoo.load_url(settings['url']), strict=False)# if pretrained == 'imagenet+background':#     # print("yes")#     # model.last_linear = nn.Linear(1536, num_classes).cuda()#     new_last_linear = nn.Linear(1536, num_classes).cuda()#     new_last_linear.weight.data = model.last_linear.weight.data[1:]#     new_last_linear.bias.data = model.last_linear.bias.data[1:]#     model.last_linear = new_last_linearmodel.input_space = settings['input_space']model.input_size = settings['input_size']model.input_range = settings['input_range']model.mean = settings['mean']model.std = settings['std']else:model = InceptionResNetV2(num_classes=num_classes)return model'''
TEST
Run this code with:

cd $HOME/pretrained-models.pytorch

python -m pretrainedmodels.inceptionresnetv2

'''
if __name__ == '__main__':assert inceptionresnetv2(num_classes=10, pretrained=None)print('success')assert inceptionresnetv2(num_classes=1000, pretrained='imagenet')print('success')assert inceptionresnetv2(num_classes=1001, pretrained='imagenet+background')print('success')# failassert inceptionresnetv2(num_classes=1001, pretrained='imagenet')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/867681.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Go 中的类型推断

💝💝💝欢迎莅临我的博客,很高兴能够在这里和您见面!希望您在这里可以感受到一份轻松愉快的氛围,不仅可以获得有趣的内容和知识,也可以畅所欲言、分享您的想法和见解。 推荐:「stormsha的主页」…

【三级等保】等保整体建设方案(Word原件)

建设要点目录: 1、系统定级与安全域 2、实施方案设计 3、安全防护体系建设规划 软件全文档,全方案获取方式:本文末个人名片直接获取。

【Python】基于KMeans的航空公司客户数据聚类分析

💐大家好!我是码银~,欢迎关注💐: CSDN:码银 公众号:码银学编程 实验目的和要求 会用Python创建KMeans聚类分析模型;使用KMeans模型对航空公司客户价值进行聚类分析;会对聚类结果进行分析评价 实…

Python酷库之旅-第三方库Pandas(008)

目录 一、用法精讲 16、pandas.DataFrame.to_json函数 16-1、语法 16-2、参数 16-3、功能 16-4、返回值 16-5、说明 16-6、用法 16-6-1、数据准备 16-6-2、代码示例 16-6-3、结果输出 17、pandas.read_html函数 17-1、语法 17-2、参数 17-3、功能 17-4、返回值…

介绍东芝TB62262FTAG芯片:高性能两相双极步进电机驱动器

在当今快速发展的科技领域,高性能的电机驱动器对于许多工程项目来说至关重要。东芝的TB62262FTAG这款两相双极步进电机驱动器采用PWM斩波技术,集成了多个先进功能,适用于各种工业和消费类应用。本文将详细介绍TB62262FTAG的参数、性能、优势及…

《向量数据库指南》——Milvus Cloud检索器增强的深度探讨:句子窗口检索与元数据过滤

检索器增强的深度探讨:句子窗口检索与元数据过滤 在信息爆炸的时代,高效的检索系统成为了连接用户与海量数据的关键桥梁。为了进一步提升检索的准确性和用户满意度,检索器增强技术应运而生,其中句子窗口检索与元数据过滤作为两大…

【Qt】day3 自定义控件、框架、定时器、QPainter、QFile

文章目录 自定义控件封装自定义框架定时器第一种方式第二种方式 (推荐) 事件分发器QPainter基本操作高级设置抗锯齿移动坐标原点 画家画资源图片,并实现手动移动 作业QPaintDevice绘图设备QPixmapQimageQPicture QFile文件读写操作QFileInfo文…

移动校园(3):处理全校课程数据excel文档,实现空闲教室查询与课程表查询

首先打开教学平台 然后导出为excel文档 import mathimport pandas as pd import pymssql serverName 127.0.0.1 userName sa passWord 123456 databaseuniSchool conn pymssql.connect(serverserverName,useruserName,passwordpassWord,databasedatabase) cursor conn.cur…

昇思11天

基于 MindSpore 实现 BERT 对话情绪识别 BERT模型概述 BERT(Bidirectional Encoder Representations from Transformers)是由Google于2018年开发并发布的一种新型语言模型。BERT在许多自然语言处理(NLP)任务中发挥着重要作用…

【C++】map和set详解

目录 1. 关联式容器 2. 键值对pair 3. 树形结构的关联式容器 4. set 4.1 set的介绍 4.2 set的构造 4.3 set的迭代器 4.4 set的容量 4.5 set的常用函数 5. multiset 6. map 6.1 map的介绍 6.2 map的构造 6.3 map的迭代器 6.4 map的容量 6.5 map的operator[] 6.6…

【虚幻引擎】UE4初学者系列教程开发进阶实战篇——生存游戏案例

一、课程体系 1 学前必读 2 Character类相关基础 -人物移动控制 -动画蓝图 3 常见游戏机制基础 -碰撞器、触发器 -物体使用接口 -视角切换 4其他相关设计 -背包系统 -锻造系统 -物体破碎效果 -简易种植系统 -互动物体动画 5课程结语 二、UI部分 思维导图部分 实操部分 …

如何借助AI在20分钟内写一个springboot单表的增删改查

目录 1. AI工具介绍2. 写代码的正确顺序2.1 编写 Entity 类:2.2 编写 Mapper 接口:2.3 编写 Mapper XML 文件(如果使用 MyBatis):2.4 编写 Service 接口:2.5 编写 Service 实现类(ServiceImpl…

【Python学习】深度理解类和对象

🎬 鸽芷咕:个人主页 🔥 个人专栏: 《C干货基地》《粉丝福利》 ⛺️生活的理想,就是为了理想的生活! 文章目录 一、一切皆对象1.1 对象的概念1.2 如何创建类对象1.3 类型检测 二、属性与方法2.1 如何查看属性与方法2.2 属性和方法…

C语言 | Leetcode C语言题解之第220题存在重复元素III

题目: 题解: struct HashTable {int key;int val;UT_hash_handle hh; };int getID(int x, long long w) {return x < 0 ? (x + 1ll) / w - 1 : x / w; }struct HashTable* query(struct HashTable* hashTable, int x) {struct HashTable* tmp;HASH_F…

leetcode每日一题-3101 交替子数组计数

暴力遍历:看起来像是回溯,实际上就是递归 class Solution { private:long long _res = 0; public:long long countAlternatingSubarrays(vector<int>& nums) {backtrack(nums, 0);return _res;}void backtrack(vector<int>& nums, long long st…

查询某个县区数据,没有的数据用0补充。

加油,新时代打工人! 思路: 先查出有数据的县区,再用县区编码判断出没有数据的县区。然后,用 union all 把两个 SQL 拼接起来。 SELECTt.regionCode,t.regionName,t.testNum,t.sampleNum,t.squareNum,t.crop…

普中51单片机:数码管显示原理与实现详解(四)

文章目录 引言数码管的结构数码管的工作原理静态数码管电路图开发板IO连接图代码演示 动态数码管实现步骤数码管驱动方式电路图开发板IO连接图真值表代码演示1代码演示2代码演示3 引言 数码管(Seven-Segment Display)是一种常见的显示设备,广…

Visual studio 2023下使用 installer projects 打包C#程序并创建 CustomAction 类

Visual studio 2023下使用 installer projects 打包C#程序并创建 CustomAction 类 1 安装Visual studio 20203,并安装插件1.1 下载并安装 Visual Studio1.2 步骤二:安装 installer projects 扩展插件2 创建安装项目2.1 创建Windows安装项目2.2 新建应用程序安装文件夹2.3 添加…

A Threat Actors 出售 18 万名 Shopify 用户信息

BreachForums 论坛成员最近发布了涉及 Shopify 的重大数据泄露事件。 据报道,属于近 180,000 名用户的敏感数据遭到泄露。 Shopify Inc. 是一家总部位于安大略省渥太华的加拿大公司。 开发和营销同名电子商务平台、Shopify POS 销售点系统以及专用于企业的营销工…

SQL脚本初始化数据

创建或选择某个数据库,运行窗口输入:source,再拖入文件,回车即可; 虽然也可以使用图形化工具初始化数据,但是它会有内存限制,也就是较大的sql文件不可以初始化,而运行窗口没有sql文件大小限制…