Deep Learning Day 21: Combining ResNet and DenseNet

 🍨 This article is a learning-log blog post from the [🔗365天深度学习训练营]
 🍖 Original author: [K同学啊 | tutoring and custom projects available]

Requirements:

  1. Explore the possibility of combining ResNet and DenseNet
  2. Build a new model architecture based on the characteristics of the two models
  3. Verify the effect of the improved model

I. Basic Configuration

  • Language environment: Python 3.8
  • IDE: PyCharm
  • Deep learning environment:
    • torch==1.12.1+cu113
    • torchvision==0.13.1+cu113

II. Preparation

1. Set up the GPU

import pathlib
import torch
import torch.nn as nn
from torchvision import transforms, datasets

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

2. Import the data

The dataset used in this project is not part of any public collection, so you need to place it in the project directory yourself and set the corresponding paths for use in the rest of the workflow.
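For reference, datasets.ImageFolder (used below) expects one sub-folder per class. Based on the class names printed further down, the assumed layout looks like this:

data/
└── bird_photos/
    ├── Bananaquit/
    ├── Black Skimmer/
    ├── Black Throated Bushtiti/
    └── Cockatoo/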

Run the following code:

data_dir = './data/bird_photos'
data_dir = pathlib.Path(data_dir)
data_paths = list(data_dir.glob('*'))
classeNames = [str(path).split("\\")[2] for path in data_paths]
print(classeNames)

image_count = len(list(data_dir.glob('*/*')))
print("Total number of images:", image_count)

This produces the following output:

['Bananaquit', 'Black Skimmer', 'Black Throated Bushtiti', 'Cockatoo']
Total number of images: 565

Next, we preprocess the whole dataset with transforms.Compose:

train_transforms = transforms.Compose([
    transforms.Resize([224, 224]),  # Resize input images to a uniform size
    # transforms.RandomHorizontalFlip(),  # Random horizontal flip
    transforms.ToTensor(),  # Convert a PIL Image or numpy.ndarray to a tensor and scale it to [0, 1]
    transforms.Normalize(   # Standardize to an approximately standard normal (Gaussian) distribution, which helps the model converge
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225])  # mean=[0.485, 0.456, 0.406] and std=[0.229, 0.224, 0.225] were computed by random sampling from the dataset
])

test_transform = transforms.Compose([
    transforms.Resize([224, 224]),  # Resize input images to a uniform size
    transforms.ToTensor(),  # Convert a PIL Image or numpy.ndarray to a tensor and scale it to [0, 1]
    transforms.Normalize(   # Standardize to an approximately standard normal (Gaussian) distribution, which helps the model converge
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225])
])

total_data = datasets.ImageFolder("./data/bird_photos/", transform=train_transforms)
print(total_data.class_to_idx)
print(total_data.class_to_idx)

This produces the following output:

{'Bananaquit': 0, 'Black Skimmer': 1, 'Black Throated Bushtiti': 2, 'Cockatoo': 3}

3. Split the dataset

Here the dataset needs to be split into training and test sets by ratio:

train_size = int(0.8 * len(total_data))
test_size = len(total_data) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(total_data, [train_size, test_size])
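Note that random_split draws a new random permutation on every run, so the train/test partition is not reproducible as written. If a fixed split is wanted, a seeded generator can be passed in (a minimal sketch):

train_dataset, test_dataset = torch.utils.data.random_split(
    total_data,
    [train_size, test_size],
    generator=torch.Generator().manual_seed(42)  # fixed seed, so the split is the same on every run
)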

Next, wrap the training and test sets obtained from the split into DataLoaders:

batch_size = 32
train_dl = torch.utils.data.DataLoader(train_dataset,
                                       batch_size=batch_size,
                                       shuffle=True,
                                       num_workers=0)
test_dl = torch.utils.data.DataLoader(test_dataset,
                                      batch_size=batch_size,
                                      shuffle=True,
                                      num_workers=0)

Then we run:

for X, y in test_dl:
    print("Shape of X [N, C, H, W]: ", X.shape)
    print("Shape of y: ", y.shape, y.dtype)
    break

to print the shape of one batch from the test dataset:

Shape of X [N, C, H, W]:  torch.Size([32, 3, 224, 224])
Shape of y:  torch.Size([32]) torch.int64

4. Build the model

The DPN (Dual Path Network) fuses ResNet and DenseNet by viewing both through the framework of a Higher Order RNN (HORNN). It combines ResNet-style feature re-use with DenseNet-style generation of new features, keeping the strengths of both (re-using features while continuing to mine new ones) and at the same time avoiding the bloated structure of the original DenseNet.
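To make the dual-path idea concrete before the full model, here is a minimal, self-contained sketch of the merge performed inside one block (the shapes are illustrative assumptions only): the first d channels of the two branches are added ResNet-style, while the trailing dense channels of both are concatenated DenseNet-style.

import torch

d, k = 4, 2                       # d residual channels, k dense channels (illustrative values)
a = torch.randn(1, d + k, 8, 8)   # shortcut-branch output
x = torch.randn(1, d + k, 8, 8)   # main-branch output

merged = torch.cat([
    a[:, :d] + x[:, :d],          # residual path: element-wise addition (feature re-use)
    a[:, d:],                     # dense path: previously accumulated features
    x[:, d:],                     # dense path: newly generated features
], dim=1)
print(merged.shape)               # torch.Size([1, 8, 8, 8]); channels grow from d + k to d + 2k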

1. Model definition


class Block(nn.Module):
    def __init__(self, in_channel, mid_channel, out_channel, dense_channel, stride, groups, is_shortcut=False):
        # in_channel is the number of input channels, mid_channel the number of channels used inside
        # the block, and out_channel the number of output channels after one block.
        # dense_channel exists because a ResNet-style convolution runs on one side while a DenseNet-style
        # computation runs on the other; the two feature maps are then fused into a new feature map.
        super().__init__()
        self.is_shortcut = is_shortcut
        self.out_channel = out_channel
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channel, mid_channel, kernel_size=1, bias=False),
            nn.BatchNorm2d(mid_channel),
            nn.ReLU()
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(mid_channel, mid_channel, kernel_size=3, stride=stride, padding=1, groups=groups, bias=False),
            nn.BatchNorm2d(mid_channel),
            nn.ReLU()
        )
        self.conv3 = nn.Sequential(
            nn.Conv2d(mid_channel, out_channel + dense_channel, kernel_size=1, bias=False),
            nn.BatchNorm2d(out_channel + dense_channel)
        )
        if self.is_shortcut:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channel, out_channel + dense_channel, kernel_size=3, padding=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channel + dense_channel)
            )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        a = x
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        if self.is_shortcut:
            a = self.shortcut(a)
        d = self.out_channel
        # First d channels: residual addition; remaining channels: dense concatenation
        x = torch.cat([a[:, :d, :, :] + x[:, :d, :, :], a[:, d:, :, :], x[:, d:, :, :]], dim=1)
        x = self.relu(x)
        return x


class DPN(nn.Module):
    def __init__(self, cfg):
        super(DPN, self).__init__()
        self.group = cfg['group']
        self.in_channel = cfg['in_channel']
        mid_channels = cfg['mid_channels']
        out_channels = cfg['out_channels']
        dense_channels = cfg['dense_channels']
        num = cfg['num']
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, self.in_channel, 7, stride=2, padding=3, bias=False, padding_mode='zeros'),
            nn.BatchNorm2d(self.in_channel),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=0)
        )
        self.conv2 = self._make_layers(mid_channels[0], out_channels[0], dense_channels[0], num[0], stride=1)
        self.conv3 = self._make_layers(mid_channels[1], out_channels[1], dense_channels[1], num[1], stride=2)
        self.conv4 = self._make_layers(mid_channels[2], out_channels[2], dense_channels[2], num[2], stride=2)
        self.conv5 = self._make_layers(mid_channels[3], out_channels[3], dense_channels[3], num[3], stride=2)
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        # The fc input width has to be computed by hand: residual channels plus all accumulated dense channels
        self.fc = nn.Linear(cfg['out_channels'][3] + (num[3] + 1) * cfg['dense_channels'][3], cfg['classes'])

    def _make_layers(self, mid_channel, out_channel, dense_channel, num, stride=2):
        layers = []
        # In the first Block, is_shortcut=True gives the ResNet-style shortcut: the shallow features are
        # passed through one convolution and added to the features that went through three convolutions.
        # The repeated blocks that follow use is_shortcut=False: the first block already taps the shallow
        # features, so the repeats do not need to do it again.
        layers.append(Block(self.in_channel, mid_channel, out_channel, dense_channel, stride=stride,
                            groups=self.group, is_shortcut=True))
        # The dense path keeps stacking feature maps, so after the first block there are
        # 2 * dense_channel dense channels, and each further block adds one more dense_channel,
        # i.e. out_channel + (i + 2) * dense_channel in total.
        self.in_channel = out_channel + dense_channel * 2
        for i in range(1, num):
            layers.append(Block(self.in_channel, mid_channel, out_channel, dense_channel, stride=1, groups=self.group))
            self.in_channel = self.in_channel + dense_channel
            # equivalently: self.in_channel = out_channel + (i + 2) * dense_channel
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.conv5(x)
        x = self.pool(x)
        x = torch.flatten(x, start_dim=1)
        x = self.fc(x)
        return x


def DPN92(n_class=10):
    cfg = {
        'group': 32,
        'in_channel': 64,
        'mid_channels': (96, 192, 384, 768),
        'out_channels': (256, 512, 1024, 2048),
        'dense_channels': (16, 32, 24, 128),
        'num': (3, 4, 20, 3),
        'classes': n_class
    }
    return DPN(cfg)


def DPN98(n_class=10):
    cfg = {
        'group': 40,
        'in_channel': 96,
        'mid_channels': (160, 320, 640, 1280),
        'out_channels': (256, 512, 1024, 2048),
        'dense_channels': (16, 32, 32, 128),
        'num': (3, 6, 20, 3),
        'classes': n_class
    }
    return DPN(cfg)
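As the comment on self.fc notes, its input width has to be worked out by hand. A quick hand check for DPN98, using the cfg values above: the last stage ends with out_channels[3] = 2048 residual channels plus (num[3] + 1) * dense_channels[3] = (3 + 1) * 128 = 512 dense channels, for 2560 in total, which matches the [-1, 2560, 1, 1] pooling output in the torchsummary listing below:

# Hand check of the fc input width for DPN98 (values copied from the cfg above)
out_c, dense_c, num_blocks = 2048, 128, 3
fc_in = out_c + (num_blocks + 1) * dense_c
print(fc_in)  # 2560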

2. View the model information

x = torch.randn(2, 3, 224, 224)
model = DPN98(4)
model.to(device)
import torchsummary as summary
summary.summary(model, (3, 224, 224))

This produces the following output:

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1         [-1, 96, 112, 112]          14,112
       BatchNorm2d-2         [-1, 96, 112, 112]             192
              ReLU-3         [-1, 96, 112, 112]               0
         MaxPool2d-4           [-1, 96, 55, 55]               0
            Conv2d-5          [-1, 160, 55, 55]          15,360
       BatchNorm2d-6          [-1, 160, 55, 55]             320
              ReLU-7          [-1, 160, 55, 55]               0
            Conv2d-8          [-1, 160, 55, 55]           5,760
       BatchNorm2d-9          [-1, 160, 55, 55]             320
             ReLU-10          [-1, 160, 55, 55]               0
           Conv2d-11          [-1, 272, 55, 55]          43,520
      BatchNorm2d-12          [-1, 272, 55, 55]             544
           Conv2d-13          [-1, 272, 55, 55]         235,008
      BatchNorm2d-14          [-1, 272, 55, 55]             544
             ReLU-15          [-1, 288, 55, 55]               0
            Block-16          [-1, 288, 55, 55]               0
                 ...                        ...             ...
  (layers 17-330: the remaining Blocks of stages conv2-conv5, same pattern)
                 ...                        ...             ...
            ReLU-331           [-1, 2560, 7, 7]               0
           Block-332           [-1, 2560, 7, 7]               0
AdaptiveAvgPool2d-333          [-1, 2560, 1, 1]               0
          Linear-334                    [-1, 4]          10,244
================================================================
Total params: 95,008,356
Trainable params: 95,008,356
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 664.46
Params size (MB): 362.43
Estimated Total Size (MB): 1027.47
----------------------------------------------------------------
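The random tensor x defined earlier (shape [2, 3, 224, 224]) can also serve as a quick forward-pass sanity check of the whole network:

print(model(x.to(device)).shape)  # expected: torch.Size([2, 4])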

III. Train the Model

1. Write the training function

def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)  # size of the training set
    num_batches = len(dataloader)   # number of batches (size / batch_size, rounded up)

    train_loss, train_acc = 0, 0  # initialize training loss and accuracy

    for X, y in dataloader:  # fetch images and their labels
        X, y = X.to(device), y.to(device)

        # Compute the prediction error
        pred = model(X)          # network output
        loss = loss_fn(pred, y)  # loss between the network output and the ground-truth targets

        # Backpropagation
        optimizer.zero_grad()  # reset the grad attributes to zero
        loss.backward()        # backpropagate
        optimizer.step()       # update the parameters

        # Accumulate accuracy and loss
        train_acc += (pred.argmax(1) == y).type(torch.float).sum().item()
        train_loss += loss.item()

    train_acc /= size
    train_loss /= num_batches

    return train_acc, train_loss

2. Write the test function

The test function is broadly the same as the training function, but since it performs no gradient descent to update the network weights, no optimizer needs to be passed in.

def test(dataloader, model, loss_fn):
    size = len(dataloader.dataset)  # size of the test set
    num_batches = len(dataloader)   # number of batches
    test_loss, test_acc = 0, 0

    # Stop gradient tracking when not training, to save memory and compute
    with torch.no_grad():
        for imgs, target in dataloader:
            imgs, target = imgs.to(device), target.to(device)

            # Compute the loss
            target_pred = model(imgs)
            loss = loss_fn(target_pred, target)

            test_loss += loss.item()
            test_acc += (target_pred.argmax(1) == target).type(torch.float).sum().item()

    test_acc /= size
    test_loss /= num_batches

    return test_acc, test_loss

3. Run the training

import copy

optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
loss_fn = nn.CrossEntropyLoss()  # create the loss function

epochs = 10

train_loss = []
train_acc = []
test_loss = []
test_acc = []

best_acc = 0  # track the best accuracy, used to select the best model

for epoch in range(epochs):
    # Update the learning rate (when using a custom schedule)
    # adjust_learning_rate(optimizer, epoch, learn_rate)

    model.train()
    epoch_train_acc, epoch_train_loss = train(train_dl, model, loss_fn, optimizer)
    # scheduler.step()  # update the learning rate (when using the official scheduler API)

    model.eval()
    epoch_test_acc, epoch_test_loss = test(test_dl, model, loss_fn)

    # Keep a copy of the best model in best_model
    if epoch_test_acc > best_acc:
        best_acc = epoch_test_acc
        best_model = copy.deepcopy(model)

    train_acc.append(epoch_train_acc)
    train_loss.append(epoch_train_loss)
    test_acc.append(epoch_test_acc)
    test_loss.append(epoch_test_loss)

    # Get the current learning rate
    lr = optimizer.state_dict()['param_groups'][0]['lr']

    template = ('Epoch:{:2d}, Train_acc:{:.1f}%, Train_loss:{:.3f}, Test_acc:{:.1f}%, Test_loss:{:.3f}, Lr:{:.2E}')
    print(template.format(epoch + 1, epoch_train_acc * 100, epoch_train_loss,
                          epoch_test_acc * 100, epoch_test_loss, lr))

# Save the best model to a file
PATH = './best_model.pth'  # filename for the saved parameters
torch.save(best_model.state_dict(), PATH)  # save the weights of best_model, not of the final epoch

print('Done')

This produces the following output:

Epoch: 1, Train_acc:35.0%, Train_loss:1.512, Test_acc:15.0%, Test_loss:2.101, Lr:1.00E-04
Epoch: 2, Train_acc:55.1%, Train_loss:1.088, Test_acc:15.9%, Test_loss:5.737, Lr:1.00E-04
Epoch: 3, Train_acc:71.0%, Train_loss:0.773, Test_acc:39.8%, Test_loss:2.180, Lr:1.00E-04
Epoch: 4, Train_acc:76.3%, Train_loss:0.616, Test_acc:62.8%, Test_loss:1.222, Lr:1.00E-04
Epoch: 5, Train_acc:79.4%, Train_loss:0.565, Test_acc:61.1%, Test_loss:2.034, Lr:1.00E-04
Epoch: 6, Train_acc:79.6%, Train_loss:0.492, Test_acc:61.9%, Test_loss:1.497, Lr:1.00E-04
Epoch: 7, Train_acc:83.2%, Train_loss:0.480, Test_acc:69.0%, Test_loss:1.305, Lr:1.00E-04
Epoch: 8, Train_acc:84.1%, Train_loss:0.403, Test_acc:56.6%, Test_loss:2.690, Lr:1.00E-04
Epoch: 9, Train_acc:90.0%, Train_loss:0.304, Test_acc:71.7%, Test_loss:1.104, Lr:1.00E-04
Epoch:10, Train_acc:93.8%, Train_loss:0.190, Test_acc:55.8%, Test_loss:2.481, Lr:1.00E-04
Done
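To reuse the saved weights later (for instance, for inference in a fresh session), they can be loaded back into the network; a minimal sketch, assuming the DPN98 definition above is available:

model = DPN98(4)
model.load_state_dict(torch.load('./best_model.pth', map_location=device))
model.to(device)
model.eval()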

IV. Visualizing the Results

1. Loss & Accuracy

import matplotlib.pyplot as plt
import warnings

warnings.filterwarnings("ignore")  # suppress warnings
plt.rcParams['font.sans-serif'] = ['SimHei']  # display Chinese labels correctly
plt.rcParams['axes.unicode_minus'] = False    # display the minus sign correctly
plt.rcParams['figure.dpi'] = 100              # figure resolution

epochs_range = range(epochs)

plt.figure(figsize=(12, 3))

plt.subplot(1, 2, 1)
plt.plot(epochs_range, train_acc, label='Training Accuracy')
plt.plot(epochs_range, test_acc, label='Test Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')

plt.subplot(1, 2, 2)
plt.plot(epochs_range, train_loss, label='Training Loss')
plt.plot(epochs_range, test_loss, label='Test Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')

plt.show()

The resulting visualization:

2. Predicting a specified image

First, define a function for prediction:

 
from PIL import Image

classes = list(total_data.class_to_idx)

def predict_one_image(image_path, model, transform, classes):
    test_img = Image.open(image_path).convert('RGB')
    plt.imshow(test_img)  # show the image being predicted

    test_img = transform(test_img)
    img = test_img.to(device).unsqueeze(0)

    model.eval()
    output = model(img)

    _, pred = torch.max(output, 1)
    pred_class = classes[pred]
    print(f'The prediction result is: {pred_class}')

Then call the function to predict a specified image:

# Predict one image from the training set
predict_one_image(image_path='./data/bird_photos/Cockatoo/011.jpg',
                  model=model,
                  transform=train_transforms,
                  classes=classes)
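Note that train_transforms is passed here; since its random flip is commented out, it happens to be identical to test_transform, but in general inference should use the augmentation-free test pipeline:

predict_one_image(image_path='./data/bird_photos/Cockatoo/011.jpg',
                  model=model,
                  transform=test_transform,
                  classes=classes)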

This gives the following result:

The prediction result is: Cockatoo

V. Network Architecture and Parameters
