Remember this article? Transfer Learning | Code Implementation
In that article, we saw that when building a model we can start from well-known architectures that have already proven themselves on the ImageNet dataset.
The torchvision module also ships these models with pretrained weights. With only minor modifications, we can apply them to our own task!
We will again train our model following these steps:
- Prepare an iterable dataset
- Define a neural network
- Feed the dataset through the network
- Compute the loss
- Update the parameters with gradient descent
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from torchvision import models
Dataset preparation
cifar10_train = torchvision.datasets.CIFAR10(
    root='cifar10/',
    train=True,
    download=True
)
cifar10_test = torchvision.datasets.CIFAR10(
    root='cifar10/',
    train=False,
    download=True
)

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((224, 224))
])

cifar2_train = [(transform(img), [3, 5].index(label)) for img, label in cifar10_train if label in [3, 5]]
cifar2_test = [(transform(img), [3, 5].index(label)) for img, label in cifar10_test if label in [3, 5]]

train_loader = torch.utils.data.DataLoader(cifar2_train, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(cifar2_test, batch_size=64, shuffle=True)
We use the cats and dogs from the CIFAR-10 dataset.
CIFAR-10 classes and labels:

Class   Label
plane   0
car     1
bird    2
cat     3
deer    4
dog     5
frog    6
horse   7
ship    8
truck   9
As the table shows, cat and dog carry the labels 3 and 5. With:
[3,5].index(label)
we can remap the cat label to 0 and the dog label to 1, turning this back into a binary classification problem.
For example:
>>> [3,5].index(3)
0
>>> [3,5].index(5)
1
Defining the model
For reference, see this article: Transfer Learning | Code Implementation
# Build the network
network = models.resnet18(pretrained=True)
for param in network.parameters():
    param.requires_grad = False
network.fc = nn.Linear(512, 2)

# Loss function
criterion = nn.CrossEntropyLoss()

# Optimizer: only the new fc layer's parameters are updated
optimizer = optim.SGD(network.fc.parameters(), lr=0.01, momentum=0.9)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
network = network.to(device)
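To see what freezing buys us, here is a minimal sketch (using a tiny stand-in model, an assumption for illustration rather than resnet18) showing that after setting requires_grad=False on the backbone, only the replacement head's weight and bias remain trainable:

```python
import torch.nn as nn

# Tiny stand-in backbone (illustrative only, not resnet18)
backbone = nn.Sequential(nn.Linear(8, 4), nn.ReLU(), nn.Linear(4, 4))
for param in backbone.parameters():
    param.requires_grad = False   # frozen: no gradients, no updates

head = nn.Linear(4, 2)            # the replacement "fc" layer stays trainable
model = nn.Sequential(backbone, head)

trainable = [p for p in model.parameters() if p.requires_grad]
print(len(trainable))  # 2 tensors: the head's weight and bias
```

This is also why the optimizer above is given only network.fc.parameters(): the frozen backbone never receives gradients, so there is nothing for it to update there.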
Training the model:
for epoch in range(10):
    total_loss = 0
    total_correct = 0
    for batch in train_loader:  # get a batch
        images, labels = batch
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()  # zero the gradients, otherwise PyTorch accumulates them
        preds = network(images)
        loss = criterion(preds, labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        _, prelabels = torch.max(preds, dim=1)
        total_correct += int((prelabels == labels).sum())

    accuracy = total_correct / len(cifar2_train)
    print("Epoch:%d , Loss:%f , Accuracy:%f " % (epoch, total_loss, accuracy))
Epoch:0 , Loss:78.549439 , Accuracy:0.788900
Epoch:1 , Loss:77.828066 , Accuracy:0.801500
Epoch:2 , Loss:66.151785 , Accuracy:0.828100
Epoch:3 , Loss:76.204446 , Accuracy:0.816800
Epoch:4 , Loss:68.886606 , Accuracy:0.828100
Epoch:5 , Loss:71.129405 , Accuracy:0.821200
Epoch:6 , Loss:66.096364 , Accuracy:0.829900
Epoch:7 , Loss:65.504227 , Accuracy:0.827700
Epoch:8 , Loss:76.303878 , Accuracy:0.817100
Epoch:9 , Loss:70.546953 , Accuracy:0.820700
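The per-batch accuracy bookkeeping in the loop relies on torch.max returning, for each row, the index of the highest score. A small self-contained check of that step:

```python
import torch

preds = torch.tensor([[2.0, 0.5],    # highest score at index 0
                      [0.1, 1.3],    # highest score at index 1
                      [0.4, 0.2]])   # highest score at index 0
labels = torch.tensor([0, 1, 1])

_, prelabels = torch.max(preds, dim=1)  # argmax over the class dimension
correct = int((prelabels == labels).sum())
print(correct)  # 2 of the 3 predicted labels match
```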
Testing the model:
correct = 0
total = 0
network.eval()
with torch.no_grad():
    for batch in test_loader:
        imgs, labels = batch
        imgs = imgs.to(device)
        labels = labels.to(device)
        preds = network(imgs)
        _, prelabels = torch.max(preds, dim=1)
        total += labels.size(0)
        correct += int((prelabels == labels).sum())

accuracy = correct / total
print("Accuracy: ", accuracy)
Accuracy: 0.8025
The pretrained model used here is resnet18. We could just as well use VGG16; just remember to change the output size of the final fully connected layer so it matches our task.
Beyond swapping the pretrained model, we can also tune some hyperparameters to make the final result even better!
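As one concrete example of such tuning (an assumption; the article itself keeps a fixed learning rate), a learning-rate schedule such as StepLR often helps, decaying the rate as training progresses:

```python
import torch.nn as nn
import torch.optim as optim

head = nn.Linear(512, 2)  # stand-in for network.fc
optimizer = optim.SGD(head.parameters(), lr=0.01, momentum=0.9)
# Multiply the learning rate by 0.1 every 5 epochs
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)

for epoch in range(10):
    # ... run one epoch of the training loop above, calling optimizer.step() per batch ...
    optimizer.step()
    scheduler.step()

print(optimizer.param_groups[0]["lr"])  # decayed twice: 0.01 -> 0.001 -> 0.0001
```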