timm 视觉库中的 create_model 函数详解

timm 视觉库中的 create_model 函数详解

最近一年 Vision Transformer 及其相关改进的工作层出不穷,在他们开源的代码中,大部分都用到了这样一个库:timm。各位炼丹师应该已经想必已经对其无比熟悉了,本文将介绍其中最关键的函数之一:create_model 函数。

timm简介

PyTorchImageModels,简称timm,是一个巨大的PyTorch代码集合,包括了一系列:

  • image models
  • layers
  • utilities
  • optimizers
  • schedulers
  • data-loaders / augmentations
  • training / validation scripts

旨在将各种 SOTA 模型、图像实用工具、常用的优化器、训练策略等视觉相关常用函数的整合在一起,并具有复现ImageNet训练结果的能力。

源码:https://github.com/rwightman/pytorch-image-models

文档:https://fastai.github.io/timmdocs/

create_model 函数的使用及常用参数

本小节先介绍 create_model 函数,及常用的参数 **kwargs

顾名思义,create_model 函数是用来创建一个网络模型(如 ResNet、ViT 等),timm 库本身可供直接调用的模型已有接近400个,用户也可以自己实现一些模型并注册进 timm (这一部分内容将在下一小节着重介绍),供自己调用。

model_name

我们首先来看最简单地用法:直接传入模型名称 model_name

import timm 
# 创建 resnet-34 
model = timm.create_model('resnet34')
# 创建 efficientnet-b0
model = timm.create_model('efficientnet_b0')

我们可以通过 list_models 函数来查看已经可以直接创建、有预训练参数的模型列表:

all_pretrained_models_available = timm.list_models(pretrained=True)
print(all_pretrained_models_available)
print(len(all_pretrained_models_available))

输出:

[..., 'vit_large_patch16_384', 'vit_large_patch32_224_in21k', 'vit_large_patch32_384', 'vit_small_patch16_224', 'wide_resnet50_2', 'wide_resnet101_2', 'xception', 'xception41', 'xception65', 'xception71']
452

如果没有设置 pretrained=True 的话有将会输出612,即有预训练权重参数的模型有452个,没有预训练参数,只有模型结构的共有612个。

pretrained

如果我们传入 pretrained=True,那么 timm 会从对应的 URL 下载模型权重参数并载入模型,只有当第一次(即本地还没有对应模型参数时)会去下载,之后会直接从本地加载模型权重参数。

model = timm.create_model('resnet34', pretrained=True)

输出:

Downloading: "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet34-43635321.pth" to /home/song/.cache/torch/hub/checkpoints/resnet34-43635321.pth

features_only、out_indices

create_mode 函数还支持 features_only=True 参数,此时函数将返回部分网络,该网络提取每一步最深一层的特征图。还可以使用 out_indices=[…] 参数指定层的索引,以提取中间层特征。

# 创建一个 (1, 3, 224, 224) 形状的张量
x = torch.randn(1, 3, 224, 224)
model = timm.create_model('resnet34')
preds = model(x)
print('preds shape: {}'.format(preds.shape))all_feature_extractor = timm.create_model('resnet34', features_only=True)
all_features = all_feature_extractor(x)
print('All {} Features: '.format(len(all_features)))
for i in range(len(all_features)):print('feature {} shape: {}'.format(i, all_features[i].shape))out_indices = [2, 3, 4]
selected_feature_extractor = timm.create_model('resnet34', features_only=True, out_indices=out_indices)
selected_features = selected_feature_extractor(x)
print('Selected Features: ')
for i in range(len(out_indices)):print('feature {} shape: {}'.format(out_indices[i], selected_features[i].shape))

我们以一个 (1, 3, 224, 224) 形状的张量为输入,在视觉任务中,图像输入张量总是类似的形状。上面例程展示了,创建完整模型 model,创建完整特征提取器 all_feature_extractor,和创建某几层特征提取器 selected_feature_extractor 的具体输出。

可以结合下面 ResNet34 的结构图来理解(图中不同的颜色表示不同的 layer),根据下图分析各层的卷积操作,计算各层最后一个卷积的输入,并与上面例程的输出(附在图后)验证是否一致。

在这里插入图片描述

输出:

preds shape: torch.Size([1, 1000])
All 5 Features:
feature 0 shape: torch.Size([1, 64, 112, 112])
feature 1 shape: torch.Size([1, 64, 56, 56])
feature 2 shape: torch.Size([1, 128, 28, 28])
feature 3 shape: torch.Size([1, 256, 14, 14])
feature 4 shape: torch.Size([1, 512, 7, 7])
Selected Features:
feature 2 shape: torch.Size([1, 128, 28, 28])
feature 3 shape: torch.Size([1, 256, 14, 14])
feature 4 shape: torch.Size([1, 512, 7, 7])

这样,我们就可以通过 timm_model 函数及其 features_onlyout_indices 参数将预训练模型方便地转换为自己想要的特征提取器。

接下来我们来看一下这些特征提取器究竟是什么类型:

import timm
feature_extractor = timm.create_model('resnet34', features_only=True, out_indices=[3])print('type:', type(feature_extractor))
print('len: ', len(feature_extractor))
for item in feature_extractor:print(item)

输出:

type: <class 'timm.models.features.FeatureListNet'>
len:  7
conv1
bn1
act1
maxpool
layer1
layer2
layer3

可以看到,feature_extractor 其实也是一个神经网络,在 timm 中称为 FeatureListNet,而我们通过 out_indices 参数来指定截取到哪一层特征。

需要注意的是,ViT 模型并不支持 features_only 选项(0.4.12版本)。

extractor = timm.create_model('vit_base_patch16_224', features_only=True)

输出:

RuntimeError: features_only not implemented for Vision Transformer models.

create_model 函数究竟做了什么

registry

在了解了 create_model 函数的基本使用之后,我们来深入探索一下 create_model 函数的源码,看一下究竟是怎样实现从模型到特征提取器的转换的。

create_model 主体只有 50 行左右的代码,因此所有这些神奇的事情是在其他地方完成的。我们知道 timm.list_models() 函数中的每一个模型名字(str)实际上都是一个函数。以下代码可以测试这一点:

import timm
import random 
from timm.models import registrym = timm.list_models()[-1]
print(m)
registry.is_model(m)

输出:

xception71
True

实际上,在 timm 内部,有一个字典称为 _model_entrypoints 包含了所有的模型名称和他们各自的函数。比如说,我们可以通过 model_entrypoint 函数从 _model_entrypoints 内部得到 xception71 模型的构造函数。

constuctor_fn = registry.model_entrypoint(m)
print(constuctor_fn)

输出:

<function timm.models.xception_aligned.xception71(pretrained=False, **kwargs)>

也有可能输出:

<function xception71 at 0x7fc0cba0eca0>

一样的。

如我们所见,在 timm.models.xception_aligned 模块中有一个函数称为 xception71 。类似的,timm 中的每一个模型都有着一个这样的构造函数。事实上,内部的 _model_entrypoints 字典大概长这个样子:

_model_entrypoints
> > 
{
'cspresnet50':<function timm.models.cspnet.cspresnet50(pretrained=False, **kwargs)>,'cspresnet50d': <function timm.models.cspnet.cspresnet50d(pretrained=False, **kwargs)>,
'cspresnet50w': <function timm.models.cspnet.cspresnet50w(pretrained=False, **kwargs)>,
'cspresnext50': <function timm.models.cspnet.cspresnext50(pretrained=False, **kwargs)>,
'cspresnext50_iabn': <function timm.models.cspnet.cspresnext50_iabn(pretrained=False, **kwargs)>,
'cspdarknet53': <function timm.models.cspnet.cspdarknet53(pretrained=False, **kwargs)>,
'cspdarknet53_iabn': <function timm.models.cspnet.cspdarknet53_iabn(pretrained=False, **kwargs)>,
'darknet53': <function timm.models.cspnet.darknet53(pretrained=False, **kwargs)>,
'densenet121': <function timm.models.densenet.densenet121(pretrained=False, **kwargs)>,
'densenetblur121d': <function timm.models.densenet.densenetblur121d(pretrained=False, **kwargs)>,
'densenet121d': <function timm.models.densenet.densenet121d(pretrained=False, **kwargs)>,
'densenet169': <function timm.models.densenet.densenet169(pretrained=False, **kwargs)>,
'densenet201': <function timm.models.densenet.densenet201(pretrained=False, **kwargs)>,
'densenet161': <function timm.models.densenet.densenet161(pretrained=False, **kwargs)>,
'densenet264': <function timm.models.densenet.densenet264(pretrained=False, **kwargs)>,}

所以说,在 timm 对应的模块中,每个模型都有一个构造器。比如说 ResNets 系列模型被定义在 timm.models.resnet 模块中。因此,实际上我们有两种方式来创建一个 resnet34 模型:

import timm
from timm.models.resnet import resnet34# 使用 create_model
m = timm.create_model('resnet34')# 直接调用构造函数
m = resnet34()

但使用上,我们无须调用构造函数。所用模型都可以通过 create_model 函数来将创建。

Register model

resnet34 构造函数的源码如下:

@register_model
def resnet34(pretrained=False, **kwargs):"""Constructs a ResNet-34 model."""model_args = dict(block=BasicBlock, layers=[3, 4, 6, 3], **kwargs)return _create_resnet('resnet34', pretrained, **model_args)

我们会发现 timm 中的每个模型都有一个 register_model 装饰器。最开始, _model_entrypoints 是一个空字典。我们是通过 register_model 装饰器来不断地像其中添加模型名称和它对应的构造函数。该装饰器的定义如下:

def register_model(fn):# lookup containing modulemod = sys.modules[fn.__module__]module_name_split = fn.__module__.split('.')module_name = module_name_split[-1] if len(module_name_split) else ''# add model to __all__ in modulemodel_name = fn.__name__if hasattr(mod, '__all__'):mod.__all__.append(model_name)else:mod.__all__ = [model_name]# add entries to registry dict/sets_model_entrypoints[model_name] = fn_model_to_module[model_name] = module_name_module_to_models[module_name].add(model_name)has_pretrained = False  # check if model has a pretrained url to allow filtering on thisif hasattr(mod, 'default_cfgs') and model_name in mod.default_cfgs:# this will catch all models that have entrypoint matching cfg key, but miss any aliasing# entrypoints or non-matching comboshas_pretrained = 'url' in mod.default_cfgs[model_name] and 'http' in mod.default_cfgs[model_name]['url']if has_pretrained:_model_has_pretrained.add(model_name)return fn

我们可以看到, register_model 函数完成了一些比较基础的步骤,但这里需要指出的是这一句:

_model_entrypoints[model_name] = fn

它将给定的 fn 添加到 _model_entrypoints 其键名为 fn.__name__。所以说 resnet34 函数上的装饰器 @register_model_model_entrypoints 中创建一个新的条目,像这样:

{&#8217;resnet34&#8217;: <function timm.models.resnet.resnet34(pretrained=False, **kwargs)>}

我们同样可以看到在 resnet34 构造函数的源码中,在设置完一些 model_args 之后,它会随后调用 _create_resnet 函数。让我们再来看一下该函数的源码:

def _create_resnet(variant, pretrained=False, **kwargs):return build_model_with_cfg(ResNet, variant, default_cfg=default_cfgs[variant], pretrained=pretrained, **kwargs)

所以在 _create_resnet 函数之中,会再调用 build_model_with_cfg 函数并将一个构造器类 ResNet 、变量名 resnet34、一个 default_cfg 和一些 **kwargs 传入其中。

default config

timm 中所有的模型都有一个默认的配置,包括指向它的预训练权重参数的URL、类别数、输入图像尺寸、池化尺寸等。

resnet34 的默认配置如下:

{'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet34-43635321.pth',
'num_classes': 1000,
'input_size': (3, 224, 224),
'pool_size': (7, 7),
'crop_pct': 0.875,
'interpolation': 'bilinear',
'mean': (0.485, 0.456, 0.406),
'std': (0.229, 0.224, 0.225),
'first_conv': 'conv1',
'classifier': 'fc'}

此默认配置与其他参数(如构造函数类和一些模型参数)一起传递给 build_model_with_cfg 函数。

build model with config

这个 build_model_with_cfg 函数负责:

  1. 真正地实例化一个模型类来创建一个模型
  2. pruned=True,对模型进行剪枝
  3. pretrained=True,加载预训练模型参数
  4. features_only=True,将模型转换为特征提取器

看一下该函数的源码:

def build_model_with_cfg(model_cls: Callable,variant: str,pretrained: bool,default_cfg: dict,model_cfg: dict = None,feature_cfg: dict = None,pretrained_strict: bool = True,pretrained_filter_fn: Callable = None,pretrained_custom_load: bool = False,**kwargs):pruned = kwargs.pop('pruned', False)features = Falsefeature_cfg = feature_cfg or {}if kwargs.pop('features_only', False):features = Truefeature_cfg.setdefault('out_indices', (0, 1, 2, 3, 4))if 'out_indices' in kwargs:feature_cfg['out_indices'] = kwargs.pop('out_indices')model = model_cls(**kwargs) if model_cfg is None else model_cls(cfg=model_cfg, **kwargs)model.default_cfg = deepcopy(default_cfg)if pruned:model = adapt_model_from_file(model, variant)# for classification models, check class attr, then kwargs, then default to 1k, otherwise 0 for featsnum_classes_pretrained = 0 if features else getattr(model, 'num_classes', kwargs.get('num_classes', 1000))if pretrained:if pretrained_custom_load:load_custom_pretrained(model)else:load_pretrained(model,num_classes=num_classes_pretrained, in_chans=kwargs.get('in_chans', 3),filter_fn=pretrained_filter_fn, strict=pretrained_strict)if features:feature_cls = FeatureListNetif 'feature_cls' in feature_cfg:feature_cls = feature_cfg.pop('feature_cls')if isinstance(feature_cls, str):feature_cls = feature_cls.lower()if 'hook' in feature_cls:feature_cls = FeatureHookNetelse:assert False, f'Unknown feature class {feature_cls}'model = feature_cls(model, **feature_cfg)model.default_cfg = default_cfg_for_features(default_cfg)  # add back default_cfgreturn model

我们可以看到,模型在这一步被创建出来:model = model_cls(**kwargs)。本文将不再深入到 prunedadapt_model_from_file 内部查看。

总结

通过本文,我们已经完全了解了 create_model 函数,我们了解到:

  • 每个模型有不同的构造函数,可以传入不同的参数, _model_entrypoints 字典包括了所有的模型名称及其对应的构造函数
  • build_with_model_cfg 函数接收模型构造器类和其中的一些具体参数,真正地实例化一个模型
  • load_pretrained 会加载预训练参数
  • FeatureListNet 类可以将模型转换为特征提取器

Ref:

https://github.com/rwightman/pytorch-image-models

https://fastai.github.io/timmdocs/

https://fastai.github.io/timmdocs/create_model#Turn-any-model-into-a-feature-extractor

https://fastai.github.io/timmdocs/tutorial_feature_extractor

https://zhuanlan.zhihu.com/p/404107277

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/532666.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

C--数据结构--树的学习

6.2.1二叉树的性质 1.二叉树 性质&#xff1a; 1.若二叉树的层次从1开始&#xff0c;则在二叉树的第i层最多有2^(i-1)个结点 2.深度为k的二叉树最多有2^k -1个结点 &#xff08;k>1&#xff09; 3.对任何一颗二叉树&#xff0c;如果其叶结点个数为n0,度为2的非叶结点个数…

C语言—sort函数比较大小的快捷使用--algorithm头文件下

sort函数 一般情况下要将一组数从的大到小排序或从小到大排序&#xff0c;要定义一个新的函数排序。 而我们也可以直接使用在函数下的sort函数&#xff0c;只需加上头文件&#xff1a; #include<algorithm> using namespace std;sort格式&#xff1a;sort(首元素地址&…

AI编译器与传统编译器的联系与区别

AI编译器与传统编译器的区别与联系 总结整理自知乎问题 针对神经网络的编译器和传统编译器的区别和联系是什么&#xff1f;。 文中提到的答主的知乎主页&#xff1a;金雪锋、杨军、蓝色、SunnyCase、贝壳与知了、工藤福尔摩 笔者本人理解 为了不用直接手写机器码&#xff0…

python学习1:注释\变量类型\转换函数\转义字符\运算符

python基础学习 与大多数语言不同&#xff0c;python最具特色的就是使用缩进来表示代码块&#xff0c;不需要使用大括号 {} 。缩进的空格数是可变的&#xff0c;但是同一个代码块的语句必须包含相同的缩进空格数。 &#xff08;一个tab4个空格&#xff09; Python语言中常见的…

python 学习2 /输入/ 输出 /列表 /字典

python基础学习第二天 输入输出 xinput("输入内容") print(x)input输出&#xff1a; eval :去掉字符串外围的引号&#xff0c;按照python的语法执行内容 aeval(12) print(a)eval输出样式&#xff1a; 列表 建立&#xff0c;添加&#xff0c;插入&#xff0c;删去…

快速排序 C++

快速排序 C 本文图示借鉴自清华大学邓俊辉老师数据结构课程。 快速排序的思想 快速排序是分治思想的典型应用。该排序算法可以原地实现&#xff0c;即空间复杂度为 O(1)O(1)O(1)&#xff0c;而时间复杂度为 O(nlogn)O(nlogn)O(nlogn) 。 算法将待排序的序列 SSS 分为两个子…

llvm与gcc

llvm与gcc llvm 是一个编译器&#xff0c;也是一个编译器架构&#xff0c;是一系列编译工具&#xff0c;也是一个编译器工具链&#xff0c;开源 C11 实现。 gcc 相对于 clang 的优势&#xff1a; gcc 支持更过语言前端&#xff0c;如 Java, Ada, FORTRAN, Go等gcc 支持更多地 …

攻防世界web新手区解题 view_source / robots / backup

1**. view_source** 题目描述&#xff1a;X老师让小宁同学查看一个网页的源代码&#xff0c;但小宁同学发现鼠标右键好像不管用了。 f12查看源码即可发现flag 2. robots 题目描述&#xff1a;X老师上课讲了Robots协议&#xff0c;小宁同学却上课打了瞌睡&#xff0c;赶紧来教教…

听GPT 讲Rust源代码--src/tools(25)

File: rust/src/tools/clippy/clippy_lints/src/methods/suspicious_command_arg_space.rs 在Rust源代码中&#xff0c;suspicious_command_arg_space.rs文件位于clippy_lints工具包的methods目录下&#xff0c;用于实现Clippy lint SUSPICIOUS_COMMAND_ARG_SPACE。 Clippy是Ru…

Java一次编译,到处运行是如何实现的

Java一次编译&#xff0c;到处运行是如何实现的 转自&#xff1a;https://cloud.tencent.com/developer/article/1415194 &#xff08;排版微调&#xff09; JAVA编译运行总览 Java是一种高级语言&#xff0c;要让计算机执行你撰写的Java程序&#xff0c;也得通过编译程序的…

攻防世界web新手区解题 /cookie / disabled_button / weak_auth

cookie 题目描述&#xff1a;X老师告诉小宁他在cookie里放了些东西&#xff0c;小宁疑惑地想&#xff1a;‘这是夹心饼干的意思吗&#xff1f;’ 使用burp suite抓包查看 发现提示&#xff1a; look-herecookie.php 于是在url后加上 cookie.php 得到提示查看返回 就得到了f…

基于GET报错的sql注入,sqli-lab 1~4

根据注入类型可将sql注入分为两类&#xff1a;数字型和字符型 例如&#xff1a; 数字型&#xff1a; sleect * from table where if 用户输入id 字符型&#xff1a;select * from table where id 用户输入id &#xff08;有引号) 通过URL中修改对应的D值&#xff0c;为正常数字…

xss原理和注入类型

XSS漏洞原理 : XSS又叫CSS(cross Site Script), 跨站脚本攻击,指的是恶意攻击者往Web页面里插入恶意JS代码,当用户浏览该页时,嵌入其中的Web里的JS代码就会被执行,从而达到恶意的特殊目的. 比如:拿到cooike XSS漏洞分类: 反射性(非存储型) payload没有经过存储,后端接收后,直接…

存储型xss案例

存储型xss原理: 攻击者在页面插入xss代码,服务端将数据存入数据库,当用户访问存在xss漏洞的页面时,服务端从数据库取出数据展示到页面上,导致xss代码执行,达到攻击效果 案例: 在一个搭建的论坛网站中, 根据存储型xss注入的条件,要找到可以存储到数据库的输入位置,并且这个位置…

反射型XSS案例

**原理:**攻击者将url中插入xss代码,服务端将url中的xss代码输出到页面上,攻击者将带有xss代码的url发送给用户,用户打开后受到xss攻击 需要url中有可以修改的参数 案例: 可能存在反射型xss的功能(点) : 搜索框等&#xff08;所有url会出现参数的地方都可以尝试&#xff09;……

xss-lab靶场通关writeup(1~6.......在更新)

level 2 : 标签被编码&#xff0c;利用属性完成弹窗 输入 发现没有弹窗 查看源代码&#xff1a; 发现&#xff1a; <>符号被编码 说明keybord参数进行了处理&#xff0c;那么只能从属性上进行恶意编码&#xff1a;先将属性的引号和标签闭合&#xff0c;用 // 将后面的…

PyTorch 分布式训练DDP 单机多卡快速上手

PyTorch 分布式训练DDP 单机多卡快速上手 本文旨在帮助新人快速上手最有效的 PyTorch 单机多卡训练&#xff0c;对于 PyTorch 分布式训练的理论介绍、多方案对比&#xff0c;本文不做详细介绍&#xff0c;有兴趣的读者可参考&#xff1a; [分布式训练] 单机多卡的正确打开方式…

Linux free 命令详解

Linux free 命令详解 free 命令用来查看系统中已用的和可用的内存。 命令选项及输出简介 关于各种命令的功能和命令选项&#xff0c;还是推荐英语比较好的同学直接看手册 RTFM&#xff1a;man free。这里简单总结一下一些重点&#xff1a; 功能及输出简介 free 命令显示系…

CTF web题 wp:

1.签到题 火狐F12查看源码&#xff0c;发现注释&#xff1a; 一次base64解码出flag 2.Encode 在这里插入图片描述 和第一题界面一样&#xff1f;&#xff1f; 轻车熟路f12&#xff1a; 发现编码&#xff1a; 格式看上去是base64&#xff0c;连续两次base64后&#xff0c;观…

【深度学习】深入理解Batch Normalization批归一化

【深度学习】深入理解Batch Normalization批归一化 转自&#xff1a;https://www.cnblogs.com/guoyaohua/p/8724433.html 这几天面试经常被问到BN层的原理&#xff0c;虽然回答上来了&#xff0c;但还是感觉答得不是很好&#xff0c;今天仔细研究了一下Batch Normalization的原…