from collections import OrderedDict

import torch
from torch import nn


class IntermediateLayerGetter(nn.ModuleDict):
    """
    Module wrapper that returns intermediate layers from a model.

    It has a strong assumption that the modules have been registered
    into the model in the same order as they are used.
    This means that one should **not** reuse the same nn.Module
    twice in the forward if you want this to work.

    Additionally, it is only able to query submodules that are directly
    assigned to the model. So if `model` is passed, `model.feature1` can
    be returned, but not `model.feature1.layer2`.

    Arguments:
        model (nn.Module): model on which we will extract the features
        return_layers (Dict[name, new_name]): a dict containing the names
            of the modules for which the activations will be returned as
            the key of the dict, and the value of the dict is the name
            of the returned activation (which the user can specify).

    Raises:
        ValueError: if any key of ``return_layers`` is not the name of a
            direct child of ``model``.

    Example::

        getter = IntermediateLayerGetter(m, {'layer1': 'feat1', 'layer3': 'feat2'})
        out = getter(torch.rand(1, 3, 224, 224))
        print([(k, v.shape) for k, v in out.items()])
    """

    def __init__(self, model, return_layers):
        child_names = {name for name, _ in model.named_children()}
        if not set(return_layers).issubset(child_names):
            raise ValueError("return_layers are not present in model")

        # Track the still-unseen requested layers on a copy, so the mapping
        # passed in by the caller is left intact.
        pending = dict(return_layers)
        layers = OrderedDict()
        for name, module in model.named_children():
            layers[name] = module
            pending.pop(name, None)
            # Children registered after the last requested layer are not
            # needed, so they are not copied into this wrapper.
            if not pending:
                break

        super().__init__(layers)
        # The caller's mapping object itself is stored (not the copy above).
        self.return_layers = return_layers

    def forward(self, x):
        """
        Run ``x`` through the kept children in registration order.

        Returns:
            OrderedDict: maps each user-chosen name from ``return_layers``
            to the output of the corresponding child module.
        """
        out = OrderedDict()
        for name, module in self.named_children():
            x = module(x)
            if name in self.return_layers:
                out[self.return_layers[name]] = x
        return out
import torchvision

# Example usage of the intermediate-layer wrapper on a ResNet-18.
# The model is built once; an earlier `pretrained=False` instance was
# immediately overwritten by this one and has been dropped.
m = torchvision.models.resnet18(pretrained=True)
# extract layer1 and layer3, giving as names `feat1` and `feat2`
new_m = torchvision.models._utils.IntermediateLayerGetter(
    m, {'layer1': 'feat1', 'layer3': 'feat2'}
)
out = new_m(torch.rand(1, 3, 224, 224))
# Print each returned activation's name together with its tensor shape.
print([(k, v.shape) for k, v in out.items()])