快速开始
安装 timm
pip install timm
timm.create_model
(model_name: str,pretrained: bool = False,pretrained_cfg:Union = None,pretrained_cfg_overlay: Optional = None,checkpoint_path: str = '',scriptable: Optional = None,exportable: Optional = None,no_jit: Optional = None, **kwargs)
timm.create_model
详细解读
create_model
函数用于创建一个模型。它的参数如下:
model_name
: 模型名称 (字符串)。pretrained
: 是否加载预训练权重 (布尔值,默认值为 False)。pretrained_cfg
: 预训练配置 (可选)。pretrained_cfg_overlay
: 预训练配置覆盖 (可选)。checkpoint_path
: 检查点路径 (字符串,默认值为空)。scriptable
: 是否可脚本化 (可选)。exportable
: 是否可导出 (可选)。no_jit
: 是否禁用 JIT 编译 (可选)。**kwargs
: 其他关键字参数。
关键字参数
drop_rate
: 分类器训练时的 dropout 率 (浮点数)。drop_path_rate
: 训练时随机深度 drop 路径率 (浮点数)。global_pool
: 分类器的全局池化类型 (字符串)。
示例
from timm import create_model# 创建一个没有预训练权重的 MobileNetV3-Large 模型。
model = create_model('mobilenetv3_large_100')# 创建一个带有预训练权重的 MobileNetV3-Large 模型。
model = create_model('mobilenetv3_large_100', pretrained=True)
model.num_classes # 1000# 创建一个带有预训练权重和新分类头的 MobileNetV3-Large 模型 (10 类)。
model = create_model('mobilenetv3_large_100', pretrained=True, num_classes=10)
model.num_classes # 10
这个函数会通过入口函数将相关参数传递给 timm.models.build_model_with_cfg
,然后调用模型类的 __init__
方法。如果 kwargs
的值为 None
,则在传递前会被剔除。
加载预训练模型
import timm
model = timm.create_model('mobilenetv3_large_100', pretrained=True)
model.eval()
注意:返回的 PyTorch 模型默认设置为训练模式,因此如果你计划使用它进行推理,则必须在其上调用 .eval()。
列出预训练模型
import timm
from pprint import pprint
model_names = timm.list_models(pretrained=True)
pprint(model_names)
微调预训练模型
model = timm.create_model('mobilenetv3_large_100', pretrained=True, num_classes=NUM_FINETUNE_CLASSES)
特征提取
x = torch.randn(1, 3, 224, 224)
features = model.forward_features(x)
print(features.shape)
图像增强
transform = timm.data.create_transform((3, 224, 224))
预处理数据
data_cfg = timm.data.resolve_data_config(model.pretrained_cfg)
transform = timm.data.create_transform(**data_cfg)
使用预训练模型进行推理
image = Image.open(requests.get(url, stream=True).raw)
image_tensor = transform(image).unsqueeze(0)
output = model(image_tensor)
probabilities = torch.nn.functional.softmax(output[0], dim=0)
values, indices = torch.topk(probabilities, 5)
特征提取
特征提取
倒数第二层特征 (分类器前特征)
你可以通过多种方式获取模型的倒数第二层特征,无需修改模型:
-
未池化特征:
- 使用
model.forward_features(input)
来获取未池化特征。 - 创建模型时不包含分类器和池化层。
- 使用
reset_classifier(0, '')
移除分类器和池化层。
- 使用
-
池化特征:
- 使用
model.forward_features()
并手动池化结果。 - 创建模型时只移除分类器。
- 使用
多尺度特征图 (特征金字塔)
可以创建一个只输出特征图的模型,使用 features_only=True
参数,并可通过 out_indices
指定输出哪些层的特征。