Using torchvision.models._utils.IntermediateLayerGetter()
The source code is as follows:
from collections import OrderedDict

import torch
from torch import nn


class IntermediateLayerGetter(nn.ModuleDict):
    """
    Module wrapper that returns intermediate layers from a model

    It has a strong assumption that the modules have been registered
    into the model in the same order as they are used.
    This means that one should **not** reuse the same nn.Module
    twice in the forward if you want this to work.

    Additionally, it is only able to query submodules that are directly
    assigned to the model. So if `model` is passed, `model.feature1` can
    be returned, but not `model.feature1.layer2`.

    Arguments:
        model (nn.Module): model on which we will extract the features
        return_layers (Dict[name, new_name]): a dict containing the names
            of the modules for which the activations will be returned as
            the key of the dict, and the value of the dict is the name
            of the returned activation (which the user can specify).
    """

    def __init__(self, model, return_layers):
        # Every requested name must be a direct child of the model.
        if not set(return_layers).issubset([name for name, _ in model.named_children()]):
            raise ValueError("return_layers are not present in model")

        orig_return_layers = return_layers
        # Work on a copy so the caller's dict is not modified.
        return_layers = {k: v for k, v in return_layers.items()}
        layers = OrderedDict()
        # Keep children in registration order, stopping after the last requested one.
        for name, module in model.named_children():
            layers[name] = module
            if name in return_layers:
                del return_layers[name]
            if not return_layers:
                break

        super(IntermediateLayerGetter, self).__init__(layers)
        self.return_layers = orig_return_layers

    def forward(self, x):
        out = OrderedDict()
        # Run the kept children sequentially and record the requested outputs.
        for name, module in self.named_children():
            x = module(x)
            if name in self.return_layers:
                out_name = self.return_layers[name]
                out[out_name] = x
        return out
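The two loops in the source above can be traced without PyTorch. The following is only a minimal sketch: the `children` dict and its lambda "layers" are hypothetical stand-ins for a model's direct child modules, chosen so the arithmetic is easy to follow. It shows that the constructor keeps every child up to and including the last requested one, and that the forward pass records only the requested outputs under their new names.

```python
from collections import OrderedDict

# Hypothetical stand-ins for a model's direct children, in registration order.
children = OrderedDict([
    ("conv1", lambda x: x + 1),
    ("layer1", lambda x: x * 2),
    ("layer2", lambda x: x - 3),
    ("layer3", lambda x: x * 10),
    ("fc", lambda x: -x),
])
return_layers = {"layer1": "feat1", "layer3": "feat2"}

# Constructor step: keep children up to and including the last requested one.
remaining = dict(return_layers)  # copy, so the caller's dict is not mutated
kept = OrderedDict()
for name, fn in children.items():
    kept[name] = fn
    if name in remaining:
        del remaining[name]
    if not remaining:
        break


# Forward step: run the kept children in order and record the requested outputs.
def forward(x):
    out = OrderedDict()
    for name, fn in kept.items():
        x = fn(x)
        if name in return_layers:
            out[return_layers[name]] = x
    return out


result = forward(1)
print(list(kept))    # ['conv1', 'layer1', 'layer2', 'layer3']
print(dict(result))  # {'feat1': 4, 'feat2': 10}
```

Note that `fc` is never stored or executed, which is why the wrapper does no more work than is needed to reach the last requested layer.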
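Parameters:
model: the model from which we want to extract intermediate-layer features
return_layers: a dict whose keys are the names of the model's direct child modules whose outputs should be returned, and whose values are the names under which those outputs appear in the returned dict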
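Because the keys of return_layers must be direct children of the model, a nested name is rejected by the constructor. The sketch below mirrors that check in plain Python. The helper `check_return_layers` is hypothetical and not part of torchvision. The name list is what `named_children()` yields for torchvision's resnet18.

```python
def check_return_layers(child_names, return_layers):
    """Hypothetical helper mirroring the constructor's issubset check."""
    if not set(return_layers).issubset(child_names):
        raise ValueError("return_layers are not present in model")


# Direct children of torchvision's resnet18, in registration order.
resnet18_children = ["conv1", "bn1", "relu", "maxpool",
                     "layer1", "layer2", "layer3", "layer4", "avgpool", "fc"]

# Top-level names pass the check silently.
check_return_layers(resnet18_children, {"layer1": "feat1", "layer3": "feat2"})

# A nested name such as "layer1.0" is not a direct child, so it is rejected.
try:
    check_return_layers(resnet18_children, {"layer1.0": "block0"})
    rejected = False
except ValueError:
    rejected = True

print(rejected)  # True
```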
How it works, by example:
Here we load the ResNet18 model from PyTorch (torchvision) and print its layer structure:
import torchvision
m = torchvision.models.resnet18(pretrained=False)
print(m)
The input code is as follows:
import torch
import torchvision

m = torchvision.models.resnet18(pretrained=True)
# extract layer1 and layer3, giving as names `feat1` and `feat2`
new_m = torchvision.models._utils.IntermediateLayerGetter(m, {'layer1': 'feat1', 'layer3': 'feat2'})
out = new_m(torch.rand(1, 3, 224, 224))
print([(k, v.shape) for k, v in out.items()])
The output is as follows: