一、ResNet简介
- ResNet是一种CNN网络架构,核心思想是引入"残差学习"来解决深层网络难以训练的问题。
- 在传统的网络中,每一层都直接尝试学习目标映射。相反,ResNet通过跨层连接(shortcut),允许某一层学习输入与输出之间的残差(或者说是差异):若期望映射为 $H(x)$,残差块只需学习 $F(x)=H(x)-x$,最终输出为 $F(x)+x$。这样这些网络层只需要学习与输入的微小差异,从而简化了学习目标和过程。
二、FashionMNIST数据集简介
- 之前的博客已经较为细致地介绍了FashionMNIST数据集:插眼传送
注意:了解数据集是机器学习的所有环节中最重要的一步,没有之一。
三、用代码实现FashionMNIST预测
1.导包
from torchvision.datasets import FashionMNIST
from torchvision.transforms import transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import torch
import numpy as np
import random
2.加载数据
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Fix every random seed so that runs are reproducible.
generator = torch.Generator()  # dedicated RNG handed to the DataLoaders below
seed_value = 420
torch.manual_seed(seed_value)
random.seed(seed_value)
np.random.seed(seed_value)
# Extra determinism on CUDA: seed the GPU RNG and force deterministic cuDNN kernels.
torch.cuda.manual_seed(seed_value)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
generator.manual_seed(seed_value)

# Training-time preprocessing: random horizontal flip and a random rotation in
# [-8, 8] degrees as augmentation, then ToTensor + Normalize((0.5,), (0.5,)),
# which maps pixel values from [0, 1] to [-1, 1].
transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation([-8, 8]),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])
# Test-time preprocessing: no augmentation, same normalization as training.
transform2 = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

# Load FashionMNIST from ./dataset/, downloading it on first use.
train_data = FashionMNIST(root='./dataset/', train=True, download=True, transform=transform)
test_data = FashionMNIST(root='./dataset/', train=False, download=True, transform=transform2)
train_batch = DataLoader(dataset=train_data, batch_size=128, shuffle=True, num_workers=0,
                         drop_last=False, generator=generator)
# The test loader does not shuffle; the generator is still passed so the loader
# draws its internal base seed from it rather than from the global RNG.
test_batch = DataLoader(dataset=test_data, batch_size=128, shuffle=False, num_workers=0,
                        drop_last=False, generator=generator)
3.定义模型
class Model(torch.nn.Module):
    """ResNet-style image classifier adapted to small inputs such as 28x28 FashionMNIST.

    Layout: a 3x3 stem convolution (64 channels), then six basic residual
    blocks arranged in three pairs (128, 256 and 512 channels), then global
    average pooling and a linear layer.

    Block pairs:
    - The first block of each pair uses stride 2 and a 1x1 strided projection
      shortcut (``downsample*``).
    - The second block keeps the shape and adds its own input back (identity
      shortcut).

    Args:
        in_features: number of input image channels (1 for grayscale).
        out_features: number of output logits (classes).
    """

    def __init__(self, in_features=1, out_features=10):
        super().__init__()
        # Stateless activation module, shared by the stem and every block.
        self.relu = torch.nn.ReLU()
        # Stem conv without padding: a 28x28 input becomes 26x26.
        self.conv1 = torch.nn.Conv2d(in_channels=in_features, out_channels=64, kernel_size=3, bias=False)
        self.adavgpool = torch.nn.AdaptiveAvgPool2d((1, 1))
        self.block1 = torch.nn.Sequential(self.conv1, torch.nn.BatchNorm2d(64), self.relu)
        self.output = torch.nn.Linear(512, out_features, bias=True)
        # NOTE: not used by forward(); kept so existing attribute access still works.
        self.maxpool = torch.nn.AvgPool2d(2, ceil_mode=True)
        # Projection shortcuts for the blocks that change channels / spatial size.
        # Construction order matches the original so seeded weight init is unchanged.
        self.downsample = self._projection(64, 128)
        self.downsample2 = self._projection(128, 256)
        self.downsample3 = self._projection(256, 512)
        # Residual bodies: conv3x3(stride) -> BN -> ReLU -> conv3x3 -> BN.
        self.conv_res = self._residual_body(64, 128, stride=2)
        self.conv_res2 = self._residual_body(128, 128, stride=1)
        self.conv_res3 = self._residual_body(128, 256, stride=2)
        self.conv_res4 = self._residual_body(256, 256, stride=1)
        self.conv_res5 = self._residual_body(256, 512, stride=2)
        self.conv_res6 = self._residual_body(512, 512, stride=1)

    @staticmethod
    def _projection(in_channels, out_channels):
        """Build a 1x1 stride-2 conv + BN shortcut that matches a downsampling body."""
        return torch.nn.Sequential(
            torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=2, bias=False),
            torch.nn.BatchNorm2d(out_channels),
        )

    def _residual_body(self, in_channels, out_channels, stride):
        """Build the two-conv body of a basic residual block (no final ReLU)."""
        return torch.nn.Sequential(
            torch.nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3,
                            padding=1, stride=stride, bias=False),
            torch.nn.BatchNorm2d(out_channels),
            self.relu,
            torch.nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3,
                            padding=1, bias=False),
            torch.nn.BatchNorm2d(out_channels),
        )

    def _residual(self, x, body, shortcut=None):
        """Apply one residual block: relu(body(x) + shortcut(x)).

        With no ``shortcut`` the block's own input ``x`` is added back (identity).
        """
        identity = x if shortcut is None else shortcut(x)
        return self.relu(body(x) + identity)

    def forward(self, x):
        """Return class logits of shape (batch, out_features) for image batch ``x``."""
        x = self.block1(x)
        x = self._residual(x, self.conv_res, self.downsample)
        # Identity blocks add their *own* input, not the previous block's shortcut.
        x = self._residual(x, self.conv_res2)
        x = self._residual(x, self.conv_res3, self.downsample2)
        x = self._residual(x, self.conv_res4)
        x = self._residual(x, self.conv_res5, self.downsample3)
        x = self._residual(x, self.conv_res6)
        x = self.adavgpool(x)
        x = torch.flatten(x, 1)
        return self.output(x)
注意:此模型不是完整的ResNet网络,这里做了部分修改,以适应当前图片尺寸。
4.定义损失函数、优化器
from torch.optim import Adam
from torch.nn import functional as F

# Build the model with the default 1 input channel and 10 output logits,
# and move it to the selected device.
model = Model().to(device)
# Cross-entropy loss on the raw logits (log-softmax + negative log-likelihood
# are applied inside CrossEntropyLoss, so the model outputs no softmax).
criterion = torch.nn.CrossEntropyLoss()
# Adam optimizer with learning rate 0.001.
opt = Adam(model.parameters(), lr=0.001)
5.开始训练
# Train for 49 epochs (full passes over the training set).
num_epochs = 49
for _ in range(num_epochs):
    # Ensure training mode: BatchNorm uses batch statistics and updates its
    # running averages (matters if the model was put in eval mode earlier).
    model.train()
    for n_, batch in enumerate(train_batch):
        X = batch[0].to(device)  # input images
        y = batch[1].to(device)  # ground-truth labels
        # Clear gradients from the previous step before computing new ones.
        opt.zero_grad()
        # Forward pass; call the module itself (not .forward) so hooks run.
        sigma = model(X)
        loss = criterion(sigma, y)
        # Predicted label = index of the largest logit.
        y_hat = torch.max(sigma, dim=1)[1]
        correct_count = torch.sum(y_hat == y)
        accuracy = correct_count / len(y) * 100
        # Backpropagate and update the parameters.
        loss.backward()
        opt.step()
    # Once per epoch, report loss/accuracy of the epoch's last batch.
    print(n_, 'loss:', loss.item(), 'accuracy:', accuracy.item())
输出:
468 loss: 0.21156974136829376 accuracy: 91.66667175292969
468 loss: 0.24343211948871613 accuracy: 89.58333587646484
468 loss: 0.3186508119106293 accuracy: 87.5
468 loss: 0.16633149981498718 accuracy: 93.75
468 loss: 0.13033141195774078 accuracy: 93.75
468 loss: 0.09412961453199387 accuracy: 97.91667175292969
468 loss: 0.044871985912323 accuracy: 98.95833587646484
468 loss: 0.023767223581671715 accuracy: 100.0
468 loss: 0.09273606538772583 accuracy: 97.91667175292969
...
6.验证测试集
# Evaluate on the test set. eval() makes BatchNorm use its running statistics
# (and stops it from updating them); no_grad() skips building autograd graphs.
model.eval()
correct_count = 0
total = 0
with torch.no_grad():
    for batch in test_batch:
        test_X = batch[0].to(device)  # already a float32 tensor from ToTensor()
        test_y = batch[1].to(device)
        sigma = model(test_X)
        y_hat = torch.max(sigma, dim=1)[1]
        correct_count += torch.sum(y_hat == test_y)
        total += len(test_y)
# Divide by the number of samples actually seen instead of a hard-coded 10000.
accuracy = correct_count / total * 100
print('accuracy:', accuracy.item())
输出:
accuracy: 93.73999786376953
- 可以看出:ResNet相对于GoogLeNet,处理时间减少,正确率提高。在FashionMNIST数据集上有较优的表现。