之前的demo都是模仿和简化了已有的模型,也可以直接调用orhvision的标准模型,代码将更加简单。
新建resnet18.py
import torch.nn as nn
from torchvision import modelsclass ResNet18(nn.Module):def __init__(self, num_classes=10):super(ResNet18, self).__init__()self.model = models.resnet18(pretrained=True) # 调用torchvision.models中的resnet18self.num_ftrs = self.model.fc.in_features # 获取全连接层的输入特征数self.model.fc = nn.Linear(self.num_ftrs, num_classes) # 修改全连接层def forward(self, x):out = self.model(x)return outdef resnet18():return ResNet18()
在之前的train.py脚本导入模型,并修改脚本中的net定义,改为:
net = resnet18().to(device),
即可运行开始训练,首次运行,会自动下载模型:
下载完之后就开始训练: