Deep Learning Day-21: Combining ResNet and DenseNet

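 🍨 This post is a learning-log blog entry from [🔗365天深度学习训练营] (the 365-Day Deep Learning Training Camp)
 🍖 Original author: [K同学啊 | tutoring and custom projects available]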

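Requirements:

  1. Explore the possibility of combining ResNet and DenseNet
  2. Build a new model framework based on the characteristics of both models
  3. Verify the effect of the improved model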

I. Basic Configuration

  • Language environment: Python 3.8
  • IDE: PyCharm
  • Deep learning environment:
    • torch==1.12.1+cu113
    • torchvision==0.13.1+cu113

II. Preparation

1. Set Up the GPU

import pathlib
import torch
import torch.nn as nn
from torchvision import transforms, datasets

# use the GPU when one is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

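2. Import the Data

The dataset used in this project is not included in any public dataset collection, so you need to place it in the project directory yourself and point the code at that directory for the later steps.

Run the following code:

data_dir = './data/bird_photos'
data_dir = pathlib.Path(data_dir)

data_paths = list(data_dir.glob('*'))
# path.name is the class folder name; unlike splitting on "\\", it works on any OS
classeNames = [path.name for path in data_paths]
print(classeNames)

image_count = len(list(data_dir.glob('*/*')))
print("Total number of images:", image_count)

This produces the following output:

['Bananaquit', 'Black Skimmer', 'Black Throated Bushtiti', 'Cockatoo']
Total number of images: 565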
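As a quick sanity check on a folder-per-class dataset like this one, it can help to count the images in each class folder. The sketch below is only an illustration: the helper name count_images_per_class and the temporary demo files are assumptions, not part of the original project.

```python
import pathlib
import tempfile
from collections import Counter


def count_images_per_class(data_dir):
    """Return {class_name: number_of_files} for a folder-per-class dataset."""
    data_dir = pathlib.Path(data_dir)
    counts = Counter()
    for file_path in data_dir.glob('*/*'):
        if file_path.is_file():
            # the parent folder name is the class label
            counts[file_path.parent.name] += 1
    return dict(counts)


# Demo on a throwaway directory (hypothetical files, not the real dataset)
with tempfile.TemporaryDirectory() as tmp:
    for cls, n in [('Bananaquit', 2), ('Cockatoo', 3)]:
        (pathlib.Path(tmp) / cls).mkdir()
        for i in range(n):
            (pathlib.Path(tmp) / cls / f'{i:03d}.jpg').touch()
    demo_counts = count_images_per_class(tmp)
    print(demo_counts)
```

On the real data you would call count_images_per_class('./data/bird_photos'); a strongly unbalanced result would be worth knowing before training.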

Next, we preprocess the whole dataset with transforms.Compose:

train_transforms = transforms.Compose([
    transforms.Resize([224, 224]),  # resize every input image to the same size
    # transforms.RandomHorizontalFlip(),  # random horizontal flip
    transforms.ToTensor(),  # convert a PIL Image or numpy.ndarray to a tensor scaled to [0, 1]
    transforms.Normalize(  # standardize each channel (subtract mean, divide by std) so the model converges more easily
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225])  # these mean/std values are the standard ImageNet channel statistics, not values computed on this bird dataset
])

test_transform = transforms.Compose([
    transforms.Resize([224, 224]),  # resize every input image to the same size
    transforms.ToTensor(),  # convert a PIL Image or numpy.ndarray to a tensor scaled to [0, 1]
    transforms.Normalize(  # standardize each channel (subtract mean, divide by std) so the model converges more easily
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225])  # these mean/std values are the standard ImageNet channel statistics, not values computed on this bird dataset
])

total_data = datasets.ImageFolder("./data/bird_photos/", transform=train_transforms)
print(total_data.class_to_idx)

This produces the following output:

{'Bananaquit': 0, 'Black Skimmer': 1, 'Black Throated Bushtiti': 2, 'Cockatoo': 3}
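To make the effect of ToTensor scaling followed by Normalize concrete, here is a small pure-Python sketch that applies the same arithmetic to a single RGB pixel. The helper names normalize_pixel and to_unit_range are made up for this illustration; the real transforms operate on whole tensors.

```python
# Channel statistics used in the transforms above (the ImageNet mean and std)
MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]


def to_unit_range(rgb_uint8):
    """Mimic the scaling part of ToTensor: map 0..255 integers to 0..1 floats."""
    return [v / 255.0 for v in rgb_uint8]


def normalize_pixel(rgb):
    """Apply (value - mean) / std per channel to one RGB pixel in [0, 1]."""
    return [(v - m) / s for v, m, s in zip(rgb, MEAN, STD)]


white = normalize_pixel(to_unit_range([255, 255, 255]))
average = normalize_pixel(MEAN)  # a pixel equal to the mean maps to zero

print([round(v, 3) for v in white])
print(average)
```

A pixel equal to the channel means maps to zero, and a pure white pixel ends up a little above 2 in every channel, so the network sees inputs roughly centered around zero.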

3. Split the Dataset

The dataset is split into a training set and a test set at a ratio of 8:2:

train_size = int(0.8 * len(total_data))
test_size = len(total_data) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(total_data, [train_size, test_size])
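The size arithmetic can be checked without PyTorch. This sketch reuses the same formula; the function name split_sizes is an assumption made for the illustration.

```python
def split_sizes(total, train_ratio=0.8):
    """Same arithmetic as above: truncate the training share, give the rest to the test set."""
    train_size = int(train_ratio * total)
    test_size = total - train_size
    return train_size, test_size


train_size, test_size = split_sizes(565)
print(train_size, test_size)
```

Computing test_size as the remainder guarantees that the two parts always add up to the full dataset, which is what random_split requires when given integer lengths.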

Next, wrap the resulting training and test sets in DataLoaders:

batch_size = 32
train_dl = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)
test_dl = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=True, num_workers=0)

Then run:

for X, y in test_dl:
    print("Shape of X [N, C, H, W]: ", X.shape)
    print("Shape of y: ", y.shape, y.dtype)
    break

to print the shape and dtype of one batch from the test set:

Shape of X [N, C, H, W]:  torch.Size([32, 3, 224, 224])
Shape of y:  torch.Size([32]) torch.int64
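Only the first batch is printed above. As a rough sketch of how many batches each loader yields and how large the last one is, assuming the DataLoader default drop_last=False and a non-empty dataset (the helper name batch_layout is hypothetical):

```python
import math


def batch_layout(dataset_size, batch_size):
    """Return (number of batches, size of the last batch) when no batch is dropped."""
    num_batches = math.ceil(dataset_size / batch_size)
    last_batch = dataset_size - (num_batches - 1) * batch_size
    return num_batches, last_batch


train_layout = batch_layout(452, 32)
test_layout = batch_layout(113, 32)
print(train_layout, test_layout)
```

The last batch is usually smaller than batch_size, which matters later when per-batch losses are averaged.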

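4. Build the Model

DPN (Dual Path Network) uses the High Order RNN (HORNN) view to fuse ResNet and DenseNet: it combines ResNet-style feature reuse with DenseNet-style generation of new features. It keeps both abilities (reusing features and exploring new ones) while avoiding the bulky structure of the original DenseNet.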
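The merge step of the two paths can be illustrated without tensors. The sketch below mimics, on plain Python lists standing in for channels, what the torch.cat line in Block.forward of this section does: the first d channels are added (residual path) and the remaining ones are concatenated (dense path). The function name dual_path_merge and the toy values are assumptions for illustration.

```python
def dual_path_merge(shortcut, branch, d):
    """Merge two per-channel lists the way Block.forward merges tensors along dim=1."""
    residual = [s + b for s, b in zip(shortcut[:d], branch[:d])]  # ResNet path: element-wise add
    return residual + shortcut[d:] + branch[d:]                   # DenseNet path: concatenate


shortcut = [1, 2, 3, 4, 5]   # 3 residual channels + 2 existing dense channels
branch = [10, 20, 30, 40]    # 3 residual channels + 1 new dense channel
merged = dual_path_merge(shortcut, branch, d=3)
print(merged)  # [11, 22, 33, 4, 5, 40]
```

The residual part keeps a fixed width d, while the dense part grows by the branch's extra channels every time a block is applied.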

1. Model definition

class Block(nn.Module):
    def __init__(self, in_channel, mid_channel, out_channel, dense_channel, stride, groups, is_shortcut=False):
        # in_channel: input channels; mid_channel: channels inside the bottleneck;
        # out_channel: channels of the residual (ResNet-style) path after one block.
        # dense_channel: channels that the dense (DenseNet-style) path adds per block. The block runs the
        # ResNet-style and DenseNet-style computations side by side and merges them into a new feature map.
        super().__init__()
        self.is_shortcut = is_shortcut
        self.out_channel = out_channel
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channel, mid_channel, kernel_size=1, bias=False),
            nn.BatchNorm2d(mid_channel),
            nn.ReLU()
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(mid_channel, mid_channel, kernel_size=3, stride=stride, padding=1, groups=groups, bias=False),
            nn.BatchNorm2d(mid_channel),
            nn.ReLU()
        )
        self.conv3 = nn.Sequential(
            nn.Conv2d(mid_channel, out_channel + dense_channel, kernel_size=1, bias=False),
            nn.BatchNorm2d(out_channel + dense_channel)
        )
        if self.is_shortcut:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channel, out_channel + dense_channel, kernel_size=3, padding=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channel + dense_channel)
            )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        a = x
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        if self.is_shortcut:
            a = self.shortcut(a)
        d = self.out_channel
        # first d channels: residual addition; remaining channels: dense concatenation
        x = torch.cat([a[:, :d, :, :] + x[:, :d, :, :], a[:, d:, :, :], x[:, d:, :, :]], dim=1)
        x = self.relu(x)
        return x


class DPN(nn.Module):
    def __init__(self, cfg):
        super(DPN, self).__init__()
        self.group = cfg['group']
        self.in_channel = cfg['in_channel']
        mid_channels = cfg['mid_channels']
        out_channels = cfg['out_channels']
        dense_channels = cfg['dense_channels']
        num = cfg['num']
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, self.in_channel, 7, stride=2, padding=3, bias=False, padding_mode='zeros'),
            nn.BatchNorm2d(self.in_channel),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=0)
        )
        self.conv2 = self._make_layers(mid_channels[0], out_channels[0], dense_channels[0], num[0], stride=1)
        self.conv3 = self._make_layers(mid_channels[1], out_channels[1], dense_channels[1], num[1], stride=2)
        self.conv4 = self._make_layers(mid_channels[2], out_channels[2], dense_channels[2], num[2], stride=2)
        self.conv5 = self._make_layers(mid_channels[3], out_channels[3], dense_channels[3], num[3], stride=2)
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        # the input size of the fc layer has to be computed from the final channel count
        self.fc = nn.Linear(cfg['out_channels'][3] + (num[3] + 1) * cfg['dense_channels'][3], cfg['classes'])

    def _make_layers(self, mid_channel, out_channel, dense_channel, num, stride=2):
        layers = []
        layers.append(Block(self.in_channel, mid_channel, out_channel, dense_channel, stride=stride,
                            groups=self.group, is_shortcut=True))
        # The first block uses is_shortcut=True, i.e. the ResNet-style projection shortcut: the input goes
        # through one convolution and is then merged with the output of the three-convolution branch.
        # The following blocks are plain repeats with is_shortcut=False: the projection is only needed
        # once per stage, after that the block input is used directly.
        self.in_channel = out_channel + dense_channel * 2
        # Because the dense path keeps concatenating feature maps, the first block outputs
        # 2 * dense_channel dense channels and every later block adds one more,
        # which gives (i + 2) * dense_channel.
        for i in range(1, num):
            layers.append(Block(self.in_channel, mid_channel, out_channel, dense_channel, stride=1,
                                groups=self.group))
            self.in_channel = self.in_channel + dense_channel
            # self.in_channel = out_channel + (i+2)*dense_channel
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.conv5(x)
        x = self.pool(x)
        x = torch.flatten(x, start_dim=1)
        x = self.fc(x)
        return x


def DPN92(n_class=10):
    cfg = {
        'group': 32,
        'in_channel': 64,
        'mid_channels': (96, 192, 384, 768),
        'out_channels': (256, 512, 1024, 2048),
        'dense_channels': (16, 32, 24, 128),
        'num': (3, 4, 20, 3),
        'classes': (n_class)
    }
    return DPN(cfg)


def DPN98(n_class=10):
    cfg = {
        'group': 40,
        'in_channel': 96,
        'mid_channels': (160, 320, 640, 1280),
        'out_channels': (256, 512, 1024, 2048),
        'dense_channels': (16, 32, 32, 128),
        'num': (3, 6, 20, 3),
        'classes': (n_class)
    }
    return DPN(cfg)
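The channel bookkeeping in _make_layers is easy to get wrong, so here is a small stand-alone sketch that recomputes it for the DPN98 configuration. The helper name stage_output_channels is an assumption; the formula is the one given in the code comments.

```python
def stage_output_channels(out_channel, dense_channel, num):
    """Channels after each block of one stage: out + (i + 2) * dense for block index i."""
    return [out_channel + (i + 2) * dense_channel for i in range(num)]


# DPN98 configuration copied from the cfg dict above
out_channels = (256, 512, 1024, 2048)
dense_channels = (16, 32, 32, 128)
num = (3, 6, 20, 3)

stages = [stage_output_channels(o, k, n)
          for o, k, n in zip(out_channels, dense_channels, num)]
fc_in_features = out_channels[3] + (num[3] + 1) * dense_channels[3]

print(stages[0])        # channels after each block of the first stage
print(stages[-1][-1])   # channels entering the pooling layer
print(fc_in_features)   # input size used for the fc layer
```

The last value of the last stage and fc_in_features agree, which is why the nn.Linear input size in the model uses (num[3] + 1) * dense_channels[3].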

2. Inspect the model

x = torch.randn(2, 3, 224, 224)
model = DPN98(4)
model.to(device)
import torchsummary as summary
summary.summary(model, (3, 224, 224))

This produces the following output:

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1         [-1, 96, 112, 112]          14,112
       BatchNorm2d-2         [-1, 96, 112, 112]             192
              ReLU-3         [-1, 96, 112, 112]               0
         MaxPool2d-4           [-1, 96, 55, 55]               0
            Conv2d-5          [-1, 160, 55, 55]          15,360
       BatchNorm2d-6          [-1, 160, 55, 55]             320
              ReLU-7          [-1, 160, 55, 55]               0
            Conv2d-8          [-1, 160, 55, 55]           5,760
       BatchNorm2d-9          [-1, 160, 55, 55]             320
             ReLU-10          [-1, 160, 55, 55]               0
           Conv2d-11          [-1, 272, 55, 55]          43,520
      BatchNorm2d-12          [-1, 272, 55, 55]             544
           Conv2d-13          [-1, 272, 55, 55]         235,008
      BatchNorm2d-14          [-1, 272, 55, 55]             544
             ReLU-15          [-1, 288, 55, 55]               0
            Block-16          [-1, 288, 55, 55]               0
           Conv2d-17          [-1, 160, 55, 55]          46,080
      BatchNorm2d-18          [-1, 160, 55, 55]             320
             ReLU-19          [-1, 160, 55, 55]               0
           Conv2d-20          [-1, 160, 55, 55]           5,760
      BatchNorm2d-21          [-1, 160, 55, 55]             320
             ReLU-22          [-1, 160, 55, 55]               0
           Conv2d-23          [-1, 272, 55, 55]          43,520
      BatchNorm2d-24          [-1, 272, 55, 55]             544
             ReLU-25          [-1, 304, 55, 55]               0
            Block-26          [-1, 304, 55, 55]               0
           Conv2d-27          [-1, 160, 55, 55]          48,640
      BatchNorm2d-28          [-1, 160, 55, 55]             320
             ReLU-29          [-1, 160, 55, 55]               0
           Conv2d-30          [-1, 160, 55, 55]           5,760
      BatchNorm2d-31          [-1, 160, 55, 55]             320
             ReLU-32          [-1, 160, 55, 55]               0
           Conv2d-33          [-1, 272, 55, 55]          43,520
      BatchNorm2d-34          [-1, 272, 55, 55]             544
             ReLU-35          [-1, 320, 55, 55]               0
            Block-36          [-1, 320, 55, 55]               0
           Conv2d-37          [-1, 320, 55, 55]         102,400
      BatchNorm2d-38          [-1, 320, 55, 55]             640
             ReLU-39          [-1, 320, 55, 55]               0
           Conv2d-40          [-1, 320, 28, 28]          23,040
      BatchNorm2d-41          [-1, 320, 28, 28]             640
             ReLU-42          [-1, 320, 28, 28]               0
           Conv2d-43          [-1, 544, 28, 28]         174,080
      BatchNorm2d-44          [-1, 544, 28, 28]           1,088
           Conv2d-45          [-1, 544, 28, 28]       1,566,720
      BatchNorm2d-46          [-1, 544, 28, 28]           1,088
             ReLU-47          [-1, 576, 28, 28]               0
            Block-48          [-1, 576, 28, 28]               0
           Conv2d-49          [-1, 320, 28, 28]         184,320
      BatchNorm2d-50          [-1, 320, 28, 28]             640
             ReLU-51          [-1, 320, 28, 28]               0
           Conv2d-52          [-1, 320, 28, 28]          23,040
      BatchNorm2d-53          [-1, 320, 28, 28]             640
             ReLU-54          [-1, 320, 28, 28]               0
           Conv2d-55          [-1, 544, 28, 28]         174,080
      BatchNorm2d-56          [-1, 544, 28, 28]           1,088
             ReLU-57          [-1, 608, 28, 28]               0
            Block-58          [-1, 608, 28, 28]               0
           Conv2d-59          [-1, 320, 28, 28]         194,560
      BatchNorm2d-60          [-1, 320, 28, 28]             640
             ReLU-61          [-1, 320, 28, 28]               0
           Conv2d-62          [-1, 320, 28, 28]          23,040
      BatchNorm2d-63          [-1, 320, 28, 28]             640
             ReLU-64          [-1, 320, 28, 28]               0
           Conv2d-65          [-1, 544, 28, 28]         174,080
      BatchNorm2d-66          [-1, 544, 28, 28]           1,088
             ReLU-67          [-1, 640, 28, 28]               0
            Block-68          [-1, 640, 28, 28]               0
           Conv2d-69          [-1, 320, 28, 28]         204,800
      BatchNorm2d-70          [-1, 320, 28, 28]             640
             ReLU-71          [-1, 320, 28, 28]               0
           Conv2d-72          [-1, 320, 28, 28]          23,040
      BatchNorm2d-73          [-1, 320, 28, 28]             640
             ReLU-74          [-1, 320, 28, 28]               0
           Conv2d-75          [-1, 544, 28, 28]         174,080
      BatchNorm2d-76          [-1, 544, 28, 28]           1,088
             ReLU-77          [-1, 672, 28, 28]               0
            Block-78          [-1, 672, 28, 28]               0
           Conv2d-79          [-1, 320, 28, 28]         215,040
      BatchNorm2d-80          [-1, 320, 28, 28]             640
             ReLU-81          [-1, 320, 28, 28]               0
           Conv2d-82          [-1, 320, 28, 28]          23,040
      BatchNorm2d-83          [-1, 320, 28, 28]             640
             ReLU-84          [-1, 320, 28, 28]               0
           Conv2d-85          [-1, 544, 28, 28]         174,080
      BatchNorm2d-86          [-1, 544, 28, 28]           1,088
             ReLU-87          [-1, 704, 28, 28]               0
            Block-88          [-1, 704, 28, 28]               0
           Conv2d-89          [-1, 320, 28, 28]         225,280
      BatchNorm2d-90          [-1, 320, 28, 28]             640
             ReLU-91          [-1, 320, 28, 28]               0
           Conv2d-92          [-1, 320, 28, 28]          23,040
      BatchNorm2d-93          [-1, 320, 28, 28]             640
             ReLU-94          [-1, 320, 28, 28]               0
           Conv2d-95          [-1, 544, 28, 28]         174,080
      BatchNorm2d-96          [-1, 544, 28, 28]           1,088
             ReLU-97          [-1, 736, 28, 28]               0
            Block-98          [-1, 736, 28, 28]               0
           Conv2d-99          [-1, 640, 28, 28]         471,040
     BatchNorm2d-100          [-1, 640, 28, 28]           1,280
            ReLU-101          [-1, 640, 28, 28]               0
          Conv2d-102          [-1, 640, 14, 14]          92,160
     BatchNorm2d-103          [-1, 640, 14, 14]           1,280
            ReLU-104          [-1, 640, 14, 14]               0
          Conv2d-105         [-1, 1056, 14, 14]         675,840
     BatchNorm2d-106         [-1, 1056, 14, 14]           2,112
          Conv2d-107         [-1, 1056, 14, 14]       6,994,944
     BatchNorm2d-108         [-1, 1056, 14, 14]           2,112
            ReLU-109         [-1, 1088, 14, 14]               0
           Block-110         [-1, 1088, 14, 14]               0
          Conv2d-111          [-1, 640, 14, 14]         696,320
     BatchNorm2d-112          [-1, 640, 14, 14]           1,280
            ReLU-113          [-1, 640, 14, 14]               0
          Conv2d-114          [-1, 640, 14, 14]          92,160
     BatchNorm2d-115          [-1, 640, 14, 14]           1,280
            ReLU-116          [-1, 640, 14, 14]               0
          Conv2d-117         [-1, 1056, 14, 14]         675,840
     BatchNorm2d-118         [-1, 1056, 14, 14]           2,112
            ReLU-119         [-1, 1120, 14, 14]               0
           Block-120         [-1, 1120, 14, 14]               0
          Conv2d-121          [-1, 640, 14, 14]         716,800
     BatchNorm2d-122          [-1, 640, 14, 14]           1,280
            ReLU-123          [-1, 640, 14, 14]               0
          Conv2d-124          [-1, 640, 14, 14]          92,160
     BatchNorm2d-125          [-1, 640, 14, 14]           1,280
            ReLU-126          [-1, 640, 14, 14]               0
          Conv2d-127         [-1, 1056, 14, 14]         675,840
     BatchNorm2d-128         [-1, 1056, 14, 14]           2,112
            ReLU-129         [-1, 1152, 14, 14]               0
           Block-130         [-1, 1152, 14, 14]               0
          Conv2d-131          [-1, 640, 14, 14]         737,280
     BatchNorm2d-132          [-1, 640, 14, 14]           1,280
            ReLU-133          [-1, 640, 14, 14]               0
          Conv2d-134          [-1, 640, 14, 14]          92,160
     BatchNorm2d-135          [-1, 640, 14, 14]           1,280
            ReLU-136          [-1, 640, 14, 14]               0
          Conv2d-137         [-1, 1056, 14, 14]         675,840
     BatchNorm2d-138         [-1, 1056, 14, 14]           2,112
            ReLU-139         [-1, 1184, 14, 14]               0
           Block-140         [-1, 1184, 14, 14]               0
          Conv2d-141          [-1, 640, 14, 14]         757,760
     BatchNorm2d-142          [-1, 640, 14, 14]           1,280
            ReLU-143          [-1, 640, 14, 14]               0
          Conv2d-144          [-1, 640, 14, 14]          92,160
     BatchNorm2d-145          [-1, 640, 14, 14]           1,280
            ReLU-146          [-1, 640, 14, 14]               0
          Conv2d-147         [-1, 1056, 14, 14]         675,840
     BatchNorm2d-148         [-1, 1056, 14, 14]           2,112
            ReLU-149         [-1, 1216, 14, 14]               0
           Block-150         [-1, 1216, 14, 14]               0
          Conv2d-151          [-1, 640, 14, 14]         778,240
     BatchNorm2d-152          [-1, 640, 14, 14]           1,280
            ReLU-153          [-1, 640, 14, 14]               0
          Conv2d-154          [-1, 640, 14, 14]          92,160
     BatchNorm2d-155          [-1, 640, 14, 14]           1,280
            ReLU-156          [-1, 640, 14, 14]               0
          Conv2d-157         [-1, 1056, 14, 14]         675,840
     BatchNorm2d-158         [-1, 1056, 14, 14]           2,112
            ReLU-159         [-1, 1248, 14, 14]               0
           Block-160         [-1, 1248, 14, 14]               0
          Conv2d-161          [-1, 640, 14, 14]         798,720
     BatchNorm2d-162          [-1, 640, 14, 14]           1,280
            ReLU-163          [-1, 640, 14, 14]               0
          Conv2d-164          [-1, 640, 14, 14]          92,160
     BatchNorm2d-165          [-1, 640, 14, 14]           1,280
            ReLU-166          [-1, 640, 14, 14]               0
          Conv2d-167         [-1, 1056, 14, 14]         675,840
     BatchNorm2d-168         [-1, 1056, 14, 14]           2,112
            ReLU-169         [-1, 1280, 14, 14]               0
           Block-170         [-1, 1280, 14, 14]               0
          Conv2d-171          [-1, 640, 14, 14]         819,200
     BatchNorm2d-172          [-1, 640, 14, 14]           1,280
            ReLU-173          [-1, 640, 14, 14]               0
          Conv2d-174          [-1, 640, 14, 14]          92,160
     BatchNorm2d-175          [-1, 640, 14, 14]           1,280
            ReLU-176          [-1, 640, 14, 14]               0
          Conv2d-177         [-1, 1056, 14, 14]         675,840
     BatchNorm2d-178         [-1, 1056, 14, 14]           2,112
            ReLU-179         [-1, 1312, 14, 14]               0
           Block-180         [-1, 1312, 14, 14]               0
          Conv2d-181          [-1, 640, 14, 14]         839,680
     BatchNorm2d-182          [-1, 640, 14, 14]           1,280
            ReLU-183          [-1, 640, 14, 14]               0
          Conv2d-184          [-1, 640, 14, 14]          92,160
     BatchNorm2d-185          [-1, 640, 14, 14]           1,280
            ReLU-186          [-1, 640, 14, 14]               0
          Conv2d-187         [-1, 1056, 14, 14]         675,840
     BatchNorm2d-188         [-1, 1056, 14, 14]           2,112
            ReLU-189         [-1, 1344, 14, 14]               0
           Block-190         [-1, 1344, 14, 14]               0
          Conv2d-191          [-1, 640, 14, 14]         860,160
     BatchNorm2d-192          [-1, 640, 14, 14]           1,280
            ReLU-193          [-1, 640, 14, 14]               0
          Conv2d-194          [-1, 640, 14, 14]          92,160
     BatchNorm2d-195          [-1, 640, 14, 14]           1,280
            ReLU-196          [-1, 640, 14, 14]               0
          Conv2d-197         [-1, 1056, 14, 14]         675,840
     BatchNorm2d-198         [-1, 1056, 14, 14]           2,112
            ReLU-199         [-1, 1376, 14, 14]               0
           Block-200         [-1, 1376, 14, 14]               0
          Conv2d-201          [-1, 640, 14, 14]         880,640
     BatchNorm2d-202          [-1, 640, 14, 14]           1,280
            ReLU-203          [-1, 640, 14, 14]               0
          Conv2d-204          [-1, 640, 14, 14]          92,160
     BatchNorm2d-205          [-1, 640, 14, 14]           1,280
            ReLU-206          [-1, 640, 14, 14]               0
          Conv2d-207         [-1, 1056, 14, 14]         675,840
     BatchNorm2d-208         [-1, 1056, 14, 14]           2,112
            ReLU-209         [-1, 1408, 14, 14]               0
           Block-210         [-1, 1408, 14, 14]               0
          Conv2d-211          [-1, 640, 14, 14]         901,120
     BatchNorm2d-212          [-1, 640, 14, 14]           1,280
            ReLU-213          [-1, 640, 14, 14]               0
          Conv2d-214          [-1, 640, 14, 14]          92,160
     BatchNorm2d-215          [-1, 640, 14, 14]           1,280
            ReLU-216          [-1, 640, 14, 14]               0
          Conv2d-217         [-1, 1056, 14, 14]         675,840
     BatchNorm2d-218         [-1, 1056, 14, 14]           2,112
            ReLU-219         [-1, 1440, 14, 14]               0
           Block-220         [-1, 1440, 14, 14]               0
          Conv2d-221          [-1, 640, 14, 14]         921,600
     BatchNorm2d-222          [-1, 640, 14, 14]           1,280
            ReLU-223          [-1, 640, 14, 14]               0
          Conv2d-224          [-1, 640, 14, 14]          92,160
     BatchNorm2d-225          [-1, 640, 14, 14]           1,280
            ReLU-226          [-1, 640, 14, 14]               0
          Conv2d-227         [-1, 1056, 14, 14]         675,840
     BatchNorm2d-228         [-1, 1056, 14, 14]           2,112
            ReLU-229         [-1, 1472, 14, 14]               0
           Block-230         [-1, 1472, 14, 14]               0
          Conv2d-231          [-1, 640, 14, 14]         942,080
     BatchNorm2d-232          [-1, 640, 14, 14]           1,280
            ReLU-233          [-1, 640, 14, 14]               0
          Conv2d-234          [-1, 640, 14, 14]          92,160
     BatchNorm2d-235          [-1, 640, 14, 14]           1,280
            ReLU-236          [-1, 640, 14, 14]               0
          Conv2d-237         [-1, 1056, 14, 14]         675,840
     BatchNorm2d-238         [-1, 1056, 14, 14]           2,112
            ReLU-239         [-1, 1504, 14, 14]               0
           Block-240         [-1, 1504, 14, 14]               0
          Conv2d-241          [-1, 640, 14, 14]         962,560
     BatchNorm2d-242          [-1, 640, 14, 14]           1,280
            ReLU-243          [-1, 640, 14, 14]               0
          Conv2d-244          [-1, 640, 14, 14]          92,160
     BatchNorm2d-245          [-1, 640, 14, 14]           1,280
            ReLU-246          [-1, 640, 14, 14]               0
          Conv2d-247         [-1, 1056, 14, 14]         675,840
     BatchNorm2d-248         [-1, 1056, 14, 14]           2,112
            ReLU-249         [-1, 1536, 14, 14]               0
           Block-250         [-1, 1536, 14, 14]               0
          Conv2d-251          [-1, 640, 14, 14]         983,040
     BatchNorm2d-252          [-1, 640, 14, 14]           1,280
            ReLU-253          [-1, 640, 14, 14]               0
          Conv2d-254          [-1, 640, 14, 14]          92,160
     BatchNorm2d-255          [-1, 640, 14, 14]           1,280
            ReLU-256          [-1, 640, 14, 14]               0
          Conv2d-257         [-1, 1056, 14, 14]         675,840
     BatchNorm2d-258         [-1, 1056, 14, 14]           2,112
            ReLU-259         [-1, 1568, 14, 14]               0
           Block-260         [-1, 1568, 14, 14]               0
          Conv2d-261          [-1, 640, 14, 14]       1,003,520
     BatchNorm2d-262          [-1, 640, 14, 14]           1,280
            ReLU-263          [-1, 640, 14, 14]               0
          Conv2d-264          [-1, 640, 14, 14]          92,160
     BatchNorm2d-265          [-1, 640, 14, 14]           1,280
            ReLU-266          [-1, 640, 14, 14]               0
          Conv2d-267         [-1, 1056, 14, 14]         675,840
     BatchNorm2d-268         [-1, 1056, 14, 14]           2,112
            ReLU-269         [-1, 1600, 14, 14]               0
           Block-270         [-1, 1600, 14, 14]               0
          Conv2d-271          [-1, 640, 14, 14]       1,024,000
     BatchNorm2d-272          [-1, 640, 14, 14]           1,280
            ReLU-273          [-1, 640, 14, 14]               0
          Conv2d-274          [-1, 640, 14, 14]          92,160
     BatchNorm2d-275          [-1, 640, 14, 14]           1,280
            ReLU-276          [-1, 640, 14, 14]               0
          Conv2d-277         [-1, 1056, 14, 14]         675,840
     BatchNorm2d-278         [-1, 1056, 14, 14]           2,112
            ReLU-279         [-1, 1632, 14, 14]               0
           Block-280         [-1, 1632, 14, 14]               0
          Conv2d-281          [-1, 640, 14, 14]       1,044,480
     BatchNorm2d-282          [-1, 640, 14, 14]           1,280
            ReLU-283          [-1, 640, 14, 14]               0
          Conv2d-284          [-1, 640, 14, 14]          92,160
     BatchNorm2d-285          [-1, 640, 14, 14]           1,280
            ReLU-286          [-1, 640, 14, 14]               0
          Conv2d-287         [-1, 1056, 14, 14]         675,840
     BatchNorm2d-288         [-1, 1056, 14, 14]           2,112
            ReLU-289         [-1, 1664, 14, 14]               0
           Block-290         [-1, 1664, 14, 14]               0
          Conv2d-291          [-1, 640, 14, 14]       1,064,960
     BatchNorm2d-292          [-1, 640, 14, 14]           1,280
            ReLU-293          [-1, 640, 14, 14]               0
          Conv2d-294          [-1, 640, 14, 14]          92,160
     BatchNorm2d-295          [-1, 640, 14, 14]           1,280
            ReLU-296          [-1, 640, 14, 14]               0
          Conv2d-297         [-1, 1056, 14, 14]         675,840
     BatchNorm2d-298         [-1, 1056, 14, 14]           2,112
            ReLU-299         [-1, 1696, 14, 14]               0
           Block-300         [-1, 1696, 14, 14]               0
          Conv2d-301         [-1, 1280, 14, 14]       2,170,880
     BatchNorm2d-302         [-1, 1280, 14, 14]           2,560
            ReLU-303         [-1, 1280, 14, 14]               0
          Conv2d-304           [-1, 1280, 7, 7]         368,640
     BatchNorm2d-305           [-1, 1280, 7, 7]           2,560
            ReLU-306           [-1, 1280, 7, 7]               0
          Conv2d-307           [-1, 2176, 7, 7]       2,785,280
     BatchNorm2d-308           [-1, 2176, 7, 7]           4,352
          Conv2d-309           [-1, 2176, 7, 7]      33,214,464
     BatchNorm2d-310           [-1, 2176, 7, 7]           4,352
            ReLU-311           [-1, 2304, 7, 7]               0
           Block-312           [-1, 2304, 7, 7]               0
          Conv2d-313           [-1, 1280, 7, 7]       2,949,120
     BatchNorm2d-314           [-1, 1280, 7, 7]           2,560
            ReLU-315           [-1, 1280, 7, 7]               0
          Conv2d-316           [-1, 1280, 7, 7]         368,640
     BatchNorm2d-317           [-1, 1280, 7, 7]           2,560
            ReLU-318           [-1, 1280, 7, 7]               0
          Conv2d-319           [-1, 2176, 7, 7]       2,785,280
     BatchNorm2d-320           [-1, 2176, 7, 7]           4,352
            ReLU-321           [-1, 2432, 7, 7]               0
           Block-322           [-1, 2432, 7, 7]               0
          Conv2d-323           [-1, 1280, 7, 7]       3,112,960
     BatchNorm2d-324           [-1, 1280, 7, 7]           2,560
            ReLU-325           [-1, 1280, 7, 7]               0
          Conv2d-326           [-1, 1280, 7, 7]         368,640
     BatchNorm2d-327           [-1, 1280, 7, 7]           2,560
            ReLU-328           [-1, 1280, 7, 7]               0
          Conv2d-329           [-1, 2176, 7, 7]       2,785,280
     BatchNorm2d-330           [-1, 2176, 7, 7]           4,352
            ReLU-331           [-1, 2560, 7, 7]               0
           Block-332           [-1, 2560, 7, 7]               0
AdaptiveAvgPool2d-333           [-1, 2560, 1, 1]               0
          Linear-334                    [-1, 4]          10,244
================================================================
Total params: 95,008,356
Trainable params: 95,008,356
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 664.46
Params size (MB): 362.43
Estimated Total Size (MB): 1027.47
----------------------------------------------------------------
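Several numbers in this summary can be reproduced by hand, which is a useful check that the model was built as intended. The sketch below uses the standard formulas for convolution parameter counts and output sizes; the helper names conv2d_params and conv_out_size are assumptions made for this illustration.

```python
def conv2d_params(in_c, out_c, kernel, groups=1, bias=False):
    """Weights of a square-kernel Conv2d: (in_c / groups) * out_c * k * k, plus optional bias."""
    params = (in_c // groups) * out_c * kernel * kernel
    return params + (out_c if bias else 0)


def conv_out_size(size, kernel, stride, padding):
    """Spatial output size of a convolution or pooling layer (floor mode)."""
    return (size + 2 * padding - kernel) // stride + 1


stem_params = conv2d_params(3, 96, 7)                  # Conv2d-1
grouped_params = conv2d_params(160, 160, 3, groups=40) # Conv2d-8 (grouped 3x3 convolution)
shortcut_params = conv2d_params(96, 272, 3)            # Conv2d-13 (3x3 shortcut convolution)
fc_params = 2560 * 4 + 4                               # Linear-334: weights plus bias

after_stem = conv_out_size(224, kernel=7, stride=2, padding=3)
after_pool = conv_out_size(after_stem, kernel=3, stride=2, padding=0)

print(stem_params, grouped_params, shortcut_params, fc_params)
print(after_stem, after_pool)
```

The grouped 3x3 convolution is cheap because each of the 40 groups only sees 4 input channels, while the ungrouped 3x3 shortcut convolutions account for the largest single entries in the table.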

III. Training the Model

1. Write the training function

def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)  # size of the training set
    num_batches = len(dataloader)  # number of batches (size / batch_size, rounded up)

    train_loss, train_acc = 0, 0  # initialize training loss and accuracy

    for X, y in dataloader:  # fetch images and their labels
        X, y = X.to(device), y.to(device)

        # compute the prediction error
        pred = model(X)  # network output
        loss = loss_fn(pred, y)  # loss between the network output and the ground-truth labels

        # backpropagation
        optimizer.zero_grad()  # reset the gradients to zero
        loss.backward()  # backpropagate
        optimizer.step()  # update the parameters

        # record accuracy and loss
        train_acc += (pred.argmax(1) == y).type(torch.float).sum().item()
        train_loss += loss.item()

    train_acc /= size
    train_loss /= num_batches

    return train_acc, train_loss

2. Write the test function

The test function is largely the same as the training function, but because it does not update the network weights by gradient descent, no optimizer needs to be passed in.

def test(dataloader, model, loss_fn):
    size = len(dataloader.dataset)  # size of the test set
    num_batches = len(dataloader)  # number of batches
    test_loss, test_acc = 0, 0

    # when not training, disable gradient tracking to save memory and computation
    with torch.no_grad():
        for imgs, target in dataloader:
            imgs, target = imgs.to(device), target.to(device)

            # compute the loss
            target_pred = model(imgs)
            loss = loss_fn(target_pred, target)

            test_loss += loss.item()
            test_acc += (target_pred.argmax(1) == target).type(torch.float).sum().item()

    test_acc /= size
    test_loss /= num_batches

    return test_acc, test_loss
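Both functions divide the accuracy by the dataset size but the loss by the number of batches. The following sketch, with made-up per-batch numbers, shows why: correct predictions are counted per sample, whereas nn.CrossEntropyLoss with its default reduction already returns a per-batch mean. The helper name aggregate_epoch is hypothetical.

```python
def aggregate_epoch(batch_losses, batch_correct, dataset_size):
    """batch_losses: mean loss of each batch; batch_correct: correct predictions per batch."""
    accuracy = sum(batch_correct) / dataset_size      # per-sample quantity
    mean_loss = sum(batch_losses) / len(batch_losses)  # mean of per-batch means
    return accuracy, mean_loss


# Hypothetical numbers for 4 test batches (32 + 32 + 32 + 17 = 113 samples)
acc, loss = aggregate_epoch([1.0, 2.0, 1.5, 0.5], [20, 22, 18, 10], 113)
print(round(acc, 4), loss)
```

One consequence of this choice is that a smaller final batch counts as much as a full batch in the reported loss, so the value is an approximation of the true per-sample mean.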

3. Run the training

import copy

optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
loss_fn = nn.CrossEntropyLoss()  # create the loss function

epochs = 10

train_loss = []
train_acc = []
test_loss = []
test_acc = []

best_acc = 0  # best test accuracy so far, used to select the best model

for epoch in range(epochs):
    # update the learning rate (when using a custom schedule)
    # adjust_learning_rate(optimizer, epoch, learn_rate)

    model.train()
    epoch_train_acc, epoch_train_loss = train(train_dl, model, loss_fn, optimizer)
    # scheduler.step()  # update the learning rate (when using an official scheduler)

    model.eval()
    epoch_test_acc, epoch_test_loss = test(test_dl, model, loss_fn)

    # keep a copy of the best model in best_model
    if epoch_test_acc > best_acc:
        best_acc = epoch_test_acc
        best_model = copy.deepcopy(model)

    train_acc.append(epoch_train_acc)
    train_loss.append(epoch_train_loss)
    test_acc.append(epoch_test_acc)
    test_loss.append(epoch_test_loss)

    # get the current learning rate
    lr = optimizer.state_dict()['param_groups'][0]['lr']

    template = ('Epoch:{:2d}, Train_acc:{:.1f}%, Train_loss:{:.3f}, Test_acc:{:.1f}%, Test_loss:{:.3f}, Lr:{:.2E}')
    print(template.format(epoch + 1, epoch_train_acc * 100, epoch_train_loss,
                          epoch_test_acc * 100, epoch_test_loss, lr))

# save the parameters of the best model (not the last-epoch model) to a file
PATH = './best_model.pth'  # file name for the saved parameters
torch.save(best_model.state_dict(), PATH)

print('Done')

This produces the following output:

Epoch: 1, Train_acc:35.0%, Train_loss:1.512, Test_acc:15.0%, Test_loss:2.101, Lr:1.00E-04
Epoch: 2, Train_acc:55.1%, Train_loss:1.088, Test_acc:15.9%, Test_loss:5.737, Lr:1.00E-04
Epoch: 3, Train_acc:71.0%, Train_loss:0.773, Test_acc:39.8%, Test_loss:2.180, Lr:1.00E-04
Epoch: 4, Train_acc:76.3%, Train_loss:0.616, Test_acc:62.8%, Test_loss:1.222, Lr:1.00E-04
Epoch: 5, Train_acc:79.4%, Train_loss:0.565, Test_acc:61.1%, Test_loss:2.034, Lr:1.00E-04
Epoch: 6, Train_acc:79.6%, Train_loss:0.492, Test_acc:61.9%, Test_loss:1.497, Lr:1.00E-04
Epoch: 7, Train_acc:83.2%, Train_loss:0.480, Test_acc:69.0%, Test_loss:1.305, Lr:1.00E-04
Epoch: 8, Train_acc:84.1%, Train_loss:0.403, Test_acc:56.6%, Test_loss:2.690, Lr:1.00E-04
Epoch: 9, Train_acc:90.0%, Train_loss:0.304, Test_acc:71.7%, Test_loss:1.104, Lr:1.00E-04
Epoch:10, Train_acc:93.8%, Train_loss:0.190, Test_acc:55.8%, Test_loss:2.481, Lr:1.00E-04
Done
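The test accuracy fluctuates a lot from epoch to epoch, so which snapshot ends up in best_model matters. This sketch replays the selection rule of the training loop (strictly greater than the best so far) on the test accuracies logged above; the function name track_best is an assumption.

```python
def track_best(test_accs):
    """Replay the 'epoch_test_acc > best_acc' rule and record when a snapshot is taken."""
    best_acc, best_epoch, snapshots = 0.0, None, []
    for epoch, acc in enumerate(test_accs, start=1):
        if acc > best_acc:
            best_acc, best_epoch = acc, epoch
            snapshots.append(epoch)
    return best_epoch, best_acc, snapshots


# Test accuracies (in percent) copied from the log above
logged = [15.0, 15.9, 39.8, 62.8, 61.1, 61.9, 69.0, 56.6, 71.7, 55.8]
best_epoch, best_acc_seen, snapshots = track_best(logged)
print(best_epoch, best_acc_seen, snapshots)
```

For this run the best snapshot comes from epoch 9, while the final epoch is noticeably worse, which is the situation the best_model copy is meant to handle.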

IV. Visualizing Results

1. Loss & Accuracy

import matplotlib.pyplot as plt
# hide warnings
import warnings
warnings.filterwarnings("ignore")  # ignore warning messages
plt.rcParams['font.sans-serif'] = ['SimHei']  # display Chinese labels correctly
plt.rcParams['axes.unicode_minus'] = False  # display the minus sign correctly
plt.rcParams['figure.dpi'] = 100  # resolution

epochs_range = range(epochs)

plt.figure(figsize=(12, 3))
plt.subplot(1, 2, 1)

plt.plot(epochs_range, train_acc, label='Training Accuracy')
plt.plot(epochs_range, test_acc, label='Test Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')

plt.subplot(1, 2, 2)
plt.plot(epochs_range, train_loss, label='Training Loss')
plt.plot(epochs_range, test_loss, label='Test Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()

The resulting visualization: [figure: training and test accuracy (left panel) and training and test loss (right panel) per epoch]

2. Predict a specified image

First, define a function for prediction:

from PIL import Image

classes = list(total_data.class_to_idx)


def predict_one_image(image_path, model, transform, classes):
    test_img = Image.open(image_path).convert('RGB')
    plt.imshow(test_img)  # show the image being predicted

    test_img = transform(test_img)
    img = test_img.to(device).unsqueeze(0)

    model.eval()
    output = model(img)

    _, pred = torch.max(output, 1)
    pred_class = classes[pred]
    print(f'Prediction result: {pred_class}')

Then call the function on the specified image:

# predict one photo from the dataset
predict_one_image(image_path='./data/bird_photos/Cockatoo/011.jpg',
                  model=model,
                  transform=train_transforms,
                  classes=classes)

This gives the following result:

Prediction result: Cockatoo
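The prediction above only reports the winning class. If a confidence value is also wanted, the raw model outputs (logits) can be passed through a softmax. The sketch below shows the arithmetic in pure Python on made-up logits; in the real code one would more likely call torch.softmax on the output tensor, and the values here are purely hypothetical.

```python
import math


def softmax(logits):
    """Numerically stable softmax: subtract the maximum before exponentiating."""
    peak = max(logits)
    exps = [math.exp(v - peak) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


class_names = ['Bananaquit', 'Black Skimmer', 'Black Throated Bushtiti', 'Cockatoo']
logits = [0.2, -1.0, 0.5, 2.3]  # hypothetical model output for one image

probs = softmax(logits)
best = max(range(len(probs)), key=probs.__getitem__)  # index of the largest probability
print(class_names[best], round(probs[best], 3))
```

The class with the largest logit is also the class with the largest probability, so the predicted label is unchanged; softmax only adds a normalized score.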

V. Network Architecture and Parameters
