使用timm创建模型会出现网络连接等错误,比如LocalEntryNotFoundError: Connection error, and we cannot find the requested files in the disk cache. Please try again or make sure your Internet connection is on.
这是因为timm下载权重默认是从huggingfaceHub,国内一般访问不了。所以需要手动下载。
本人的timm版本是 0.9.11的。
解决办法:
pretrained_cfg = timm.models.create_model(backbone_name).default_cfg
print(pretrained_cfg)
其中backbone_name 是你要创建的模型名比如resnet50,自己修改下。
打印pretrained_cfg信息我们就可以得到以下这条信息:
{'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/wide_resnet50_racm-8234f177.pth', 'hf_hub_id': 'timm/wide_resnet50_2.racm_in1k', 'architecture': 'wide_resnet50_2', 'tag': 'racm_in1k', 'custom_load': False, 'input_size': (3, 224, 224), 'test_input_size': (3, 288, 288), 'fixed_input_size': False, 'interpolation': 'bicubic', 'crop_pct': 0.875, 'test_crop_pct': 0.95, 'crop_mode': 'center', 'mean': (0.485, 0.456, 0.406), 'std': (0.229, 0.224, 0.225), 'num_classes': 1000, 'pool_size': (7, 7), 'first_conv': 'conv1', 'classifier': 'fc', 'origin_url': 'https://github.com/huggingface/pytorch-image-models'}
上面的url就是权重下载链接,我这里下载的是wide_resnet50_2的权重。
通过这个网址把权重下载下来,放到对应的目录地址checkpoint_path,比如checkpoint_path = ‘/home/xxx/wide_resnet50_racm-8234f177.pth’。
然后把创建模型的代码改成以下就解决了。
self.backbone = timm.create_model(model_name=backbone_name, pretrained=True, pretrained_cfg_overlay=dict(file=checkpoint_path),** kwargs)