ResNet(残差网络)是深度学习中的经典模型,通过引入残差连接解决了深层网络训练中的梯度消失问题。本文将从残差块的定义开始,逐步实现一个ResNet模型,并在Fashion MNIST数据集上进行训练和测试。
1. 残差块(Residual Block)实现
残差块通过跳跃连接(Shortcut Connection)将输入直接传递到输出,缓解了深层网络的训练难题。以下是残差块的PyTorch实现:
import torch
from torch import nn
from torch.nn import functional as F
from d2l import torch as d2lclass Residual(nn.Module):def __init__(self, input_channels, num_channels, use_1x1conv=False, strides=1):super().__init__()self.conv1 = nn.Conv2d(input_channels, num_channels, kernel_size=3, padding=1, stride=strides)self.conv2 = nn.Conv2d(num_channels, num_channels, kernel_size=3, padding=1)self.bn1 = nn.BatchNorm2d(num_channels)self.bn2 = nn.BatchNorm2d(num_channels)if use_1x1conv:self.conv3 = nn.Conv2d(input_channels, num_channels, kernel_size=1, stride=strides)else:self.conv3 = Noneself.relu = nn.ReLU(inplace=True)def forward(self, x):y = F.relu(self.bn1(self.conv1(x)))y = self.bn2(self.conv2(y))if self.conv3:x = self.conv3(x)y += xreturn F.relu(y)
代码解析:
-
use_1x1conv
:当输入和输出通道数不一致时,使用1x1卷积调整通道数。 -
strides
:控制特征图下采样的步长。 -
残差相加后再次使用ReLU激活,增强非线性表达能力。
2. 构建ResNet模型
ResNet由多个残差块堆叠而成,以下代码构建了一个简化版ResNet-18:
# 初始卷积层
b1 = nn.Sequential(nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3),nn.BatchNorm2d(64),nn.ReLU(),nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
)def resnet_block(input_channels, num_channels, num_residuals, first_block=False):blk = []for i in range(num_residuals):if i == 0 and not first_block: # 第一个块需下采样blk.append(Residual(input_channels, num_channels, use_1x1conv=True, strides=2))else:blk.append(Residual(num_channels, num_channels))return blk# 堆叠残差块
b2 = nn.Sequential(*resnet_block(64, 64, 2, first_block=True))
b3 = nn.Sequential(*resnet_block(64, 128, 2))
b4 = nn.Sequential(*resnet_block(128, 256, 2))
b5 = nn.Sequential(*resnet_block(256, 512, 2))# 完整网络结构
net = nn.Sequential(b1, b2, b3, b4, b5,nn.AdaptiveAvgPool2d((1, 1)),nn.Flatten(),nn.Linear(512, 10)
)
模型结构说明:
-
AdaptiveAvgPool2d
:自适应平均池化,将特征图尺寸统一为1x1。 -
Flatten
:展平特征用于全连接层分类。
3. 数据加载与预处理
使用Fashion MNIST数据集,批量大小为256:
train_data, test_data = d2l.load_data_fashion_mnist(batch_size=256)
4. 模型训练与测试
设置训练参数:10个epoch,学习率0.05,并使用GPU加速:
d2l.train_ch6(net, train_data, test_data, num_epochs=10, lr=0.05, device=d2l.try_gpu())
训练结果:
loss 0.124, train acc 0.952, test acc 0.860
4921.4 examples/sec on cuda:0
5. 结果可视化
训练过程中损失和准确率变化如下图所示:
分析:
-
训练准确率(紫色虚线)迅速上升并稳定在95%以上。
-
测试准确率(绿色点线)达到86%,表明模型具有良好的泛化能力。
-
损失值(蓝色实线)持续下降,未出现过拟合。
6. 完整代码
整合所有代码片段(需安装d2l
库):
# 残差块定义、模型构建、训练代码见上文
7. 总结
本文实现了ResNet的核心组件——残差块,并构建了一个简化版ResNet模型。通过实验验证,模型在Fashion MNIST数据集上表现良好。读者可尝试调整网络深度或超参数以进一步提升性能。
改进方向:
-
增加残差块数量构建更深的ResNet(如ResNet-34/50)。
-
使用数据增强策略提升泛化能力。
-
尝试不同的优化器和学习率调度策略。
注意事项:
-
确保已安装PyTorch和
d2l
库。 -
GPU环境可显著加速训练,若使用CPU需调整批量大小。
希望本文能帮助您理解ResNet的实现细节!如有疑问,欢迎在评论区留言讨论。