深度学习 Day24——J3-1DenseNet算法实战与解析

  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊 | 接辅导、项目定制
  • 🚀 文章来源:K同学的学习圈子

文章目录

  • 前言
  • 1 我的环境
  • 2 pytorch实现DenseNet算法
    • 2.1 前期准备
      • 2.1.1 引入库
      • 2.1.2 设置GPU(如果设备上支持GPU就使用GPU,否则使用CPU)
      • 2.1.3 导入数据
      • 2.1.4 可视化数据
      • 2.1.4 图像数据变换
      • 2.1.4 划分数据集
      • 2.1.4 加载数据
      • 2.1.4 查看数据
    • 2.2 搭建densenet121模型
    • 2.3 训练模型
      • 2.3.1 设置超参数
      • 2.3.2 编写训练函数
      • 2.3.3 编写测试函数
      • 2.3.4 正式训练
    • 2.4 结果可视化
    • 2.4 指定图片进行预测
    • 2.6 模型评估
  • 3 知识点详解
    • 3.1 nn.Sequential和nn.Module区别与选择
      • 3.1.1 nn.Sequential
      • 3.1.2 nn.Module
      • 3.1.3 对比
      • 3.1.4 总结
    • 3.2 python中OrderedDict的使用
  • 总结


前言

关键字: pytorch实现DenseNet算法,nn.Sequential和nn.Module区别与选择,python中OrderedDict的使用

1 我的环境

  • 电脑系统:Windows 11
  • 语言环境:python 3.8.6
  • 编译器:pycharm2020.2.3
  • 深度学习环境:
    torch == 1.9.1+cu111
    torchvision == 0.10.1+cu111
    TensorFlow 2.10.1
  • 显卡:NVIDIA GeForce RTX 4070

2 pytorch实现DenseNet算法

2.1 前期准备

2.1.1 引入库


import torch
import torch.nn as nn
import time
import copy
from torchvision import transforms, datasets
from pathlib import Path
from PIL import Image
import torchsummary as summary
import torch.nn.functional as F
from collections import OrderedDict
import re
import torch.utils.model_zoo as model_zoo
import matplotlib.pyplot as plt
plt.rcParams['font.sans-serif'] = ['SimHei']  # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False  # 用来正常显示负号
plt.rcParams['figure.dpi'] = 100  # 分辨率
import warningswarnings.filterwarnings('ignore')  # 忽略一些warning内容,无需打印

2.1.2 设置GPU(如果设备上支持GPU就使用GPU,否则使用CPU)

"""前期准备-设置GPU"""
# 如果设备上支持GPU就使用GPU,否则使用CPUdevice = torch.device("cuda" if torch.cuda.is_available() else "cpu")print("Using {} device".format(device))

输出

Using cuda device

2.1.3 导入数据

'''前期工作-导入数据'''
data_dir = r"D:\DeepLearning\data\BreastCancer"
data_dir = Path(data_dir)data_paths = list(data_dir.glob('*'))
classeNames = [str(path).split("\\")[-1] for path in data_paths]
print(classeNames)

输出

['.DS_Store', '0', '1']

2.1.4 可视化数据

'''前期工作-可视化数据'''
subfolder = Path(data_dir) / "1"
image_files = list(p.resolve() for p in subfolder.glob('*') if p.suffix in [".jpg", ".png", ".jpeg"])
plt.figure(figsize=(10, 6))
for i in range(len(image_files[:12])):image_file = image_files[i]ax = plt.subplot(3, 4, i + 1)img = Image.open(str(image_file))plt.imshow(img)plt.axis("off")
# 显示图片
plt.tight_layout()
plt.show()

在这里插入图片描述

2.1.4 图像数据变换

'''前期工作-图像数据变换'''
total_datadir = data_dir# 关于transforms.Compose的更多介绍可以参考:https://blog.csdn.net/qq_38251616/article/details/124878863
train_transforms = transforms.Compose([transforms.Resize([224, 224]),  # 将输入图片resize成统一尺寸transforms.ToTensor(),  # 将PIL Image或numpy.ndarray转换为tensor,并归一化到[0,1]之间transforms.Normalize(  # 标准化处理-->转换为标准正太分布(高斯分布),使模型更容易收敛mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])  # 其中 mean=[0.485,0.456,0.406]与std=[0.229,0.224,0.225] 从数据集中随机抽样计算得到的。
])
total_data = datasets.ImageFolder(total_datadir, transform=train_transforms)
print(total_data)
print(total_data.class_to_idx)

输出

Dataset ImageFolderNumber of datapoints: 13403Root location: D:\DeepLearning\data\BreastCancerStandardTransform
Transform: Compose(Resize(size=[224, 224], interpolation=bilinear, max_size=None, antialias=None)ToTensor()Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]))
{'0': 0, '1': 1}

2.1.4 划分数据集

'''前期工作-划分数据集'''
train_size = int(0.8 * len(total_data))  # train_size表示训练集大小,通过将总体数据长度的80%转换为整数得到;
test_size = len(total_data) - train_size  # test_size表示测试集大小,是总体数据长度减去训练集大小。
# 使用torch.utils.data.random_split()方法进行数据集划分。该方法将总体数据total_data按照指定的大小比例([train_size, test_size])随机划分为训练集和测试集,
# 并将划分结果分别赋值给train_dataset和test_dataset两个变量。
train_dataset, test_dataset = torch.utils.data.random_split(total_data, [train_size, test_size])
print("train_dataset={}\ntest_dataset={}".format(train_dataset, test_dataset))
print("train_size={}\ntest_size={}".format(train_size, test_size))

输出

train_dataset=<torch.utils.data.dataset.Subset object at 0x000001AB3AD06BE0>
test_dataset=<torch.utils.data.dataset.Subset object at 0x000001AB3AD06B20>
train_size=10722
test_size=2681

2.1.4 加载数据

'''前期工作-加载数据'''
batch_size = 32train_dl = torch.utils.data.DataLoader(train_dataset,batch_size=batch_size,shuffle=True,num_workers=4)
test_dl = torch.utils.data.DataLoader(test_dataset,batch_size=batch_size,shuffle=True,num_workers=4)

2.1.4 查看数据

'''前期工作-查看数据'''
for X, y in test_dl:print("Shape of X [N, C, H, W]: ", X.shape)print("Shape of y: ", y.shape, y.dtype)break

输出

Shape of X [N, C, H, W]:  torch.Size([32, 3, 224, 224])
Shape of y:  torch.Size([32]) torch.int64

2.2 搭建densenet121模型

"""构建DenseNet网络"""
# 这里我们采用了Pytorch的框架来实现DenseNet,
# 首先实现DenseBlock中的内部结构,这里是BN+ReLU+1×1Conv+BN+ReLU+3×3Conv结构,最后也加入dropout层用于训练过程。
class _DenseLayer(nn.Sequential):"""Basic unit of DenseBlock (using bottleneck layer) """def __init__(self, num_input_features, growth_rate, bn_size, drop_rate):super(_DenseLayer, self).__init__()self.add_module('norm1', nn.BatchNorm2d(num_input_features)),self.add_module('relu1', nn.ReLU(inplace=True)),self.add_module('conv1', nn.Conv2d(num_input_features, bn_size * growth_rate,kernel_size=1, stride=1, bias=False)),self.add_module('norm2', nn.BatchNorm2d(bn_size * growth_rate)),self.add_module('relu2', nn.ReLU(inplace=True)),self.add_module('conv2', nn.Conv2d(bn_size * growth_rate, growth_rate,kernel_size=3, stride=1, padding=1, bias=False)),self.drop_rate = drop_ratedef forward(self, x):new_features = super(_DenseLayer, self).forward(x)if self.drop_rate > 0:new_features = F.dropout(new_features, p=self.drop_rate, training=self.training)return torch.cat([x, new_features], 1)# 实现DenseBlock模块,内部是密集连接方式(输入特征数线性增长):
class _DenseBlock(nn.Sequential):"""DenseBlock """def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate):super(_DenseBlock, self).__init__()for i in range(num_layers):layer = _DenseLayer(num_input_features + i * growth_rate, growth_rate, bn_size, drop_rate)self.add_module('denselayer%d' % (i + 1), layer)# 实现Transition层,它主要是一个卷积层和一个池化层:
class _Transition(nn.Sequential):def __init__(self, num_input_features, num_output_features):super(_Transition, self).__init__()self.add_module('norm', nn.BatchNorm2d(num_input_features))self.add_module('relu', nn.ReLU(inplace=True))self.add_module('conv', nn.Conv2d(num_input_features, num_output_features,kernel_size=1, stride=1, bias=False))self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=2))# 最后我们实现DenseNet网络:
class DenseNet(nn.Module):r"""Densenet-BC model class, based on`"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`Args:growth_rate (int) - how many filters to add each layer (`k` in paper)block_config (list of 3 or 4 ints) - how many layers in each pooling blocknum_init_features (int) - the number of filters to learn in the first convolution layerbn_size (int) - multiplicative factor for number of bottle neck layers(i.e. bn_size * k features in the bottleneck layer)drop_rate (float) - dropout rate after each dense layernum_classes (int) - number of classification classes"""def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16),num_init_features=24, bn_size=4, compression=0.5, drop_rate=0,num_classes=1000):super(DenseNet, self).__init__()# First Conv2dself.features = nn.Sequential(OrderedDict([('conv0', nn.Conv2d(3, num_init_features, kernel_size=7, stride=2, padding=3, bias=False)),('norm0', nn.BatchNorm2d(num_init_features)),('relu0', nn.ReLU(inplace=True)),('pool0', nn.MaxPool2d(kernel_size=3, stride=2, padding=1))]))# Each denseblocknum_features = num_init_featuresfor i, num_layers in enumerate(block_config):block = _DenseBlock(num_layers, num_features, bn_size, growth_rate, drop_rate)self.features.add_module('denseblock%d' % (i + 1), block)num_features += num_layers * growth_rateif i != len(block_config) - 1:transition = _Transition(num_input_features=num_features,num_output_features=int(num_features * compression))self.features.add_module('transition%d' % (i + 1), transition)num_features = int(num_features * compression)# Final bn+reluself.features.add_module('norm5', nn.BatchNorm2d(num_features))self.features.add_module('relu5', nn.ReLU(inplace=True))# classification layerself.classifier = nn.Linear(num_features, num_classes)# params initializationfor m in self.modules():if isinstance(m, nn.Conv2d):nn.init.kaiming_normal_(m.weight)elif isinstance(m, nn.BatchNorm2d):nn.init.constant_(m.bias, 0)nn.init.constant_(m.weight, 1)elif isinstance(m, nn.Linear):nn.init.constant_(m.bias, 0)def forward(self, x):features = self.features(x)out = F.avg_pool2d(features, 7, stride=1).view(features.size(0), -1)out = self.classifier(out)return outmodel_urls = {'densenet121': 'https://download.pytorch.org/models/densenet121-a639ec97.pth','densenet169': 'https://download.pytorch.org/models/densenet169-b2777c0a.pth','densenet201': 'https://download.pytorch.org/models/densenet201-c1103571.pth','densenet161': 'https://download.pytorch.org/models/densenet161-8d451a50.pth'}def densenet121(pretrained=False, **kwargs):"""DenseNet121"""model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 24, 16),	**kwargs)if pretrained:# '.'s are no longer allowed in module names, but pervious _DenseLayer# has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.# They are also in the checkpoints in model_urls. This pattern is used# to find such keys.pattern = re.compile(r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')state_dict = model_zoo.load_url(model_urls['densenet121'])for key in list(state_dict.keys()):res = pattern.match(key)if res:new_key = res.group(1) + res.group(2)state_dict[new_key] = state_dict[key]del state_dict[key]model.load_state_dict(state_dict)return model"""搭建densenet121模型"""
# model = densenet121().to(device)  
model = densenet121(True).to(device)  # 使用预训练模型
print(model)
print(summary.summary(model, (3, 224, 224)))  # 查看模型的参数量以及相关指标

输出

----------------------------------------------------------------Layer (type)               Output Shape         Param #
================================================================Conv2d-1         [-1, 64, 112, 112]           9,408BatchNorm2d-2         [-1, 64, 112, 112]             128ReLU-3         [-1, 64, 112, 112]               0MaxPool2d-4           [-1, 64, 56, 56]               0BatchNorm2d-5           [-1, 64, 56, 56]             128ReLU-6           [-1, 64, 56, 56]               0Conv2d-7          [-1, 128, 56, 56]           8,192BatchNorm2d-8          [-1, 128, 56, 56]             256ReLU-9          [-1, 128, 56, 56]               0Conv2d-10           [-1, 32, 56, 56]          36,864BatchNorm2d-11           [-1, 96, 56, 56]             192ReLU-12           [-1, 96, 56, 56]               0Conv2d-13          [-1, 128, 56, 56]          12,288BatchNorm2d-14          [-1, 128, 56, 56]             256ReLU-15          [-1, 128, 56, 56]               0Conv2d-16           [-1, 32, 56, 56]          36,864BatchNorm2d-17          [-1, 128, 56, 56]             256ReLU-18          [-1, 128, 56, 56]               0Conv2d-19          [-1, 128, 56, 56]          16,384BatchNorm2d-20          [-1, 128, 56, 56]             256ReLU-21          [-1, 128, 56, 56]               0Conv2d-22           [-1, 32, 56, 56]          36,864BatchNorm2d-23          [-1, 160, 56, 56]             320ReLU-24          [-1, 160, 56, 56]               0Conv2d-25          [-1, 128, 56, 56]          20,480BatchNorm2d-26          [-1, 128, 56, 56]             256ReLU-27          [-1, 128, 56, 56]               0Conv2d-28           [-1, 32, 56, 56]          36,864BatchNorm2d-29          [-1, 192, 56, 56]             384ReLU-30          [-1, 192, 56, 56]               0Conv2d-31          [-1, 128, 56, 56]          24,576BatchNorm2d-32          [-1, 128, 56, 56]             256ReLU-33          [-1, 128, 56, 56]               0Conv2d-34           [-1, 32, 56, 56]          36,864BatchNorm2d-35          [-1, 224, 56, 56]             448ReLU-36          [-1, 224, 56, 56]               0Conv2d-37          [-1, 128, 56, 56]          28,672BatchNorm2d-38          [-1, 128, 56, 56]             256ReLU-39          [-1, 128, 56, 56]               0Conv2d-40           [-1, 32, 56, 56]          36,864BatchNorm2d-41          [-1, 256, 56, 56]             512ReLU-42          [-1, 256, 56, 56]               0Conv2d-43          [-1, 128, 56, 56]          32,768AvgPool2d-44          [-1, 128, 28, 28]               0BatchNorm2d-45          [-1, 128, 28, 28]             256ReLU-46          [-1, 128, 28, 28]               0Conv2d-47          [-1, 128, 28, 28]          16,384BatchNorm2d-48          [-1, 128, 28, 28]             256ReLU-49          [-1, 128, 28, 28]               0Conv2d-50           [-1, 32, 28, 28]          36,864BatchNorm2d-51          [-1, 160, 28, 28]             320ReLU-52          [-1, 160, 28, 28]               0Conv2d-53          [-1, 128, 28, 28]          20,480BatchNorm2d-54          [-1, 128, 28, 28]             256ReLU-55          [-1, 128, 28, 28]               0Conv2d-56           [-1, 32, 28, 28]          36,864BatchNorm2d-57          [-1, 192, 28, 28]             384ReLU-58          [-1, 192, 28, 28]               0Conv2d-59          [-1, 128, 28, 28]          24,576BatchNorm2d-60          [-1, 128, 28, 28]             256ReLU-61          [-1, 128, 28, 28]               0Conv2d-62           [-1, 32, 28, 28]          36,864BatchNorm2d-63          [-1, 224, 28, 28]             448ReLU-64          [-1, 224, 28, 28]               0Conv2d-65          [-1, 128, 28, 28]          28,672BatchNorm2d-66          [-1, 128, 28, 28]             256ReLU-67          [-1, 128, 28, 28]               0Conv2d-68           [-1, 32, 28, 28]          36,864BatchNorm2d-69          [-1, 256, 28, 28]             512ReLU-70          [-1, 256, 28, 28]               0Conv2d-71          [-1, 128, 28, 28]          32,768BatchNorm2d-72          [-1, 128, 28, 28]             256ReLU-73          [-1, 128, 28, 28]               0Conv2d-74           [-1, 32, 28, 28]          36,864BatchNorm2d-75          [-1, 288, 28, 28]             576ReLU-76          [-1, 288, 28, 28]               0Conv2d-77          [-1, 128, 28, 28]          36,864BatchNorm2d-78          [-1, 128, 28, 28]             256ReLU-79          [-1, 128, 28, 28]               0Conv2d-80           [-1, 32, 28, 28]          36,864BatchNorm2d-81          [-1, 320, 28, 28]             640ReLU-82          [-1, 320, 28, 28]               0Conv2d-83          [-1, 128, 28, 28]          40,960BatchNorm2d-84          [-1, 128, 28, 28]             256ReLU-85          [-1, 128, 28, 28]               0Conv2d-86           [-1, 32, 28, 28]          36,864BatchNorm2d-87          [-1, 352, 28, 28]             704ReLU-88          [-1, 352, 28, 28]               0Conv2d-89          [-1, 128, 28, 28]          45,056BatchNorm2d-90          [-1, 128, 28, 28]             256ReLU-91          [-1, 128, 28, 28]               0Conv2d-92           [-1, 32, 28, 28]          36,864BatchNorm2d-93          [-1, 384, 28, 28]             768ReLU-94          [-1, 384, 28, 28]               0Conv2d-95          [-1, 128, 28, 28]          49,152BatchNorm2d-96          [-1, 128, 28, 28]             256ReLU-97          [-1, 128, 28, 28]               0Conv2d-98           [-1, 32, 28, 28]          36,864BatchNorm2d-99          [-1, 416, 28, 28]             832ReLU-100          [-1, 416, 28, 28]               0Conv2d-101          [-1, 128, 28, 28]          53,248BatchNorm2d-102          [-1, 128, 28, 28]             256ReLU-103          [-1, 128, 28, 28]               0Conv2d-104           [-1, 32, 28, 28]          36,864BatchNorm2d-105          [-1, 448, 28, 28]             896ReLU-106          [-1, 448, 28, 28]               0Conv2d-107          [-1, 128, 28, 28]          57,344BatchNorm2d-108          [-1, 128, 28, 28]             256ReLU-109          [-1, 128, 28, 28]               0Conv2d-110           [-1, 32, 28, 28]          36,864BatchNorm2d-111          [-1, 480, 28, 28]             960ReLU-112          [-1, 480, 28, 28]               0Conv2d-113          [-1, 128, 28, 28]          61,440BatchNorm2d-114          [-1, 128, 28, 28]             256ReLU-115          [-1, 128, 28, 28]               0Conv2d-116           [-1, 32, 28, 28]          36,864BatchNorm2d-117          [-1, 512, 28, 28]           1,024ReLU-118          [-1, 512, 28, 28]               0Conv2d-119          [-1, 256, 28, 28]         131,072AvgPool2d-120          [-1, 256, 14, 14]               0BatchNorm2d-121          [-1, 256, 14, 14]             512ReLU-122          [-1, 256, 14, 14]               0Conv2d-123          [-1, 128, 14, 14]          32,768BatchNorm2d-124          [-1, 128, 14, 14]             256ReLU-125          [-1, 128, 14, 14]               0Conv2d-126           [-1, 32, 14, 14]          36,864BatchNorm2d-127          [-1, 288, 14, 14]             576ReLU-128          [-1, 288, 14, 14]               0Conv2d-129          [-1, 128, 14, 14]          36,864BatchNorm2d-130          [-1, 128, 14, 14]             256ReLU-131          [-1, 128, 14, 14]               0Conv2d-132           [-1, 32, 14, 14]          36,864BatchNorm2d-133          [-1, 320, 14, 14]             640ReLU-134          [-1, 320, 14, 14]               0Conv2d-135          [-1, 128, 14, 14]          40,960BatchNorm2d-136          [-1, 128, 14, 14]             256ReLU-137          [-1, 128, 14, 14]               0Conv2d-138           [-1, 32, 14, 14]          36,864BatchNorm2d-139          [-1, 352, 14, 14]             704ReLU-140          [-1, 352, 14, 14]               0Conv2d-141          [-1, 128, 14, 14]          45,056BatchNorm2d-142          [-1, 128, 14, 14]             256ReLU-143          [-1, 128, 14, 14]               0Conv2d-144           [-1, 32, 14, 14]          36,864BatchNorm2d-145          [-1, 384, 14, 14]             768ReLU-146          [-1, 384, 14, 14]               0Conv2d-147          [-1, 128, 14, 14]          49,152BatchNorm2d-148          [-1, 128, 14, 14]             256ReLU-149          [-1, 128, 14, 14]               0Conv2d-150           [-1, 32, 14, 14]          36,864BatchNorm2d-151          [-1, 416, 14, 14]             832ReLU-152          [-1, 416, 14, 14]               0Conv2d-153          [-1, 128, 14, 14]          53,248BatchNorm2d-154          [-1, 128, 14, 14]             256ReLU-155          [-1, 128, 14, 14]               0Conv2d-156           [-1, 32, 14, 14]          36,864BatchNorm2d-157          [-1, 448, 14, 14]             896ReLU-158          [-1, 448, 14, 14]               0Conv2d-159          [-1, 128, 14, 14]          57,344BatchNorm2d-160          [-1, 128, 14, 14]             256ReLU-161          [-1, 128, 14, 14]               0Conv2d-162           [-1, 32, 14, 14]          36,864BatchNorm2d-163          [-1, 480, 14, 14]             960ReLU-164          [-1, 480, 14, 14]               0Conv2d-165          [-1, 128, 14, 14]          61,440BatchNorm2d-166          [-1, 128, 14, 14]             256ReLU-167          [-1, 128, 14, 14]               0Conv2d-168           [-1, 32, 14, 14]          36,864BatchNorm2d-169          [-1, 512, 14, 14]           1,024ReLU-170          [-1, 512, 14, 14]               0Conv2d-171          [-1, 128, 14, 14]          65,536BatchNorm2d-172          [-1, 128, 14, 14]             256ReLU-173          [-1, 128, 14, 14]               0Conv2d-174           [-1, 32, 14, 14]          36,864BatchNorm2d-175          [-1, 544, 14, 14]           1,088ReLU-176          [-1, 544, 14, 14]               0Conv2d-177          [-1, 128, 14, 14]          69,632BatchNorm2d-178          [-1, 128, 14, 14]             256ReLU-179          [-1, 128, 14, 14]               0Conv2d-180           [-1, 32, 14, 14]          36,864BatchNorm2d-181          [-1, 576, 14, 14]           1,152ReLU-182          [-1, 576, 14, 14]               0Conv2d-183          [-1, 128, 14, 14]          73,728BatchNorm2d-184          [-1, 128, 14, 14]             256ReLU-185          [-1, 128, 14, 14]               0Conv2d-186           [-1, 32, 14, 14]          36,864BatchNorm2d-187          [-1, 608, 14, 14]           1,216ReLU-188          [-1, 608, 14, 14]               0Conv2d-189          [-1, 128, 14, 14]          77,824BatchNorm2d-190          [-1, 128, 14, 14]             256ReLU-191          [-1, 128, 14, 14]               0Conv2d-192           [-1, 32, 14, 14]          36,864BatchNorm2d-193          [-1, 640, 14, 14]           1,280ReLU-194          [-1, 640, 14, 14]               0Conv2d-195          [-1, 128, 14, 14]          81,920BatchNorm2d-196          [-1, 128, 14, 14]             256ReLU-197          [-1, 128, 14, 14]               0Conv2d-198           [-1, 32, 14, 14]          36,864BatchNorm2d-199          [-1, 672, 14, 14]           1,344ReLU-200          [-1, 672, 14, 14]               0Conv2d-201          [-1, 128, 14, 14]          86,016BatchNorm2d-202          [-1, 128, 14, 14]             256ReLU-203          [-1, 128, 14, 14]               0Conv2d-204           [-1, 32, 14, 14]          36,864BatchNorm2d-205          [-1, 704, 14, 14]           1,408ReLU-206          [-1, 704, 14, 14]               0Conv2d-207          [-1, 128, 14, 14]          90,112BatchNorm2d-208          [-1, 128, 14, 14]             256ReLU-209          [-1, 128, 14, 14]               0Conv2d-210           [-1, 32, 14, 14]          36,864BatchNorm2d-211          [-1, 736, 14, 14]           1,472ReLU-212          [-1, 736, 14, 14]               0Conv2d-213          [-1, 128, 14, 14]          94,208BatchNorm2d-214          [-1, 128, 14, 14]             256ReLU-215          [-1, 128, 14, 14]               0Conv2d-216           [-1, 32, 14, 14]          36,864BatchNorm2d-217          [-1, 768, 14, 14]           1,536ReLU-218          [-1, 768, 14, 14]               0Conv2d-219          [-1, 128, 14, 14]          98,304BatchNorm2d-220          [-1, 128, 14, 14]             256ReLU-221          [-1, 128, 14, 14]               0Conv2d-222           [-1, 32, 14, 14]          36,864BatchNorm2d-223          [-1, 800, 14, 14]           1,600ReLU-224          [-1, 800, 14, 14]               0Conv2d-225          [-1, 128, 14, 14]         102,400BatchNorm2d-226          [-1, 128, 14, 14]             256ReLU-227          [-1, 128, 14, 14]               0Conv2d-228           [-1, 32, 14, 14]          36,864BatchNorm2d-229          [-1, 832, 14, 14]           1,664ReLU-230          [-1, 832, 14, 14]               0Conv2d-231          [-1, 128, 14, 14]         106,496BatchNorm2d-232          [-1, 128, 14, 14]             256ReLU-233          [-1, 128, 14, 14]               0Conv2d-234           [-1, 32, 14, 14]          36,864BatchNorm2d-235          [-1, 864, 14, 14]           1,728ReLU-236          [-1, 864, 14, 14]               0Conv2d-237          [-1, 128, 14, 14]         110,592BatchNorm2d-238          [-1, 128, 14, 14]             256ReLU-239          [-1, 128, 14, 14]               0Conv2d-240           [-1, 32, 14, 14]          36,864BatchNorm2d-241          [-1, 896, 14, 14]           1,792ReLU-242          [-1, 896, 14, 14]               0Conv2d-243          [-1, 128, 14, 14]         114,688BatchNorm2d-244          [-1, 128, 14, 14]             256ReLU-245          [-1, 128, 14, 14]               0Conv2d-246           [-1, 32, 14, 14]          36,864BatchNorm2d-247          [-1, 928, 14, 14]           1,856ReLU-248          [-1, 928, 14, 14]               0Conv2d-249          [-1, 128, 14, 14]         118,784BatchNorm2d-250          [-1, 128, 14, 14]             256ReLU-251          [-1, 128, 14, 14]               0Conv2d-252           [-1, 32, 14, 14]          36,864BatchNorm2d-253          [-1, 960, 14, 14]           1,920ReLU-254          [-1, 960, 14, 14]               0Conv2d-255          [-1, 128, 14, 14]         122,880BatchNorm2d-256          [-1, 128, 14, 14]             256ReLU-257          [-1, 128, 14, 14]               0Conv2d-258           [-1, 32, 14, 14]          36,864BatchNorm2d-259          [-1, 992, 14, 14]           1,984ReLU-260          [-1, 992, 14, 14]               0Conv2d-261          [-1, 128, 14, 14]         126,976BatchNorm2d-262          [-1, 128, 14, 14]             256ReLU-263          [-1, 128, 14, 14]               0Conv2d-264           [-1, 32, 14, 14]          36,864BatchNorm2d-265         [-1, 1024, 14, 14]           2,048ReLU-266         [-1, 1024, 14, 14]               0Conv2d-267          [-1, 512, 14, 14]         524,288AvgPool2d-268            [-1, 512, 7, 7]               0BatchNorm2d-269            [-1, 512, 7, 7]           1,024ReLU-270            [-1, 512, 7, 7]               0Conv2d-271            [-1, 128, 7, 7]          65,536BatchNorm2d-272            [-1, 128, 7, 7]             256ReLU-273            [-1, 128, 7, 7]               0Conv2d-274             [-1, 32, 7, 7]          36,864BatchNorm2d-275            [-1, 544, 7, 7]           1,088ReLU-276            [-1, 544, 7, 7]               0Conv2d-277            [-1, 128, 7, 7]          69,632BatchNorm2d-278            [-1, 128, 7, 7]             256ReLU-279            [-1, 128, 7, 7]               0Conv2d-280             [-1, 32, 7, 7]          36,864BatchNorm2d-281            [-1, 576, 7, 7]           1,152ReLU-282            [-1, 576, 7, 7]               0Conv2d-283            [-1, 128, 7, 7]          73,728BatchNorm2d-284            [-1, 128, 7, 7]             256ReLU-285            [-1, 128, 7, 7]               0Conv2d-286             [-1, 32, 7, 7]          36,864BatchNorm2d-287            [-1, 608, 7, 7]           1,216ReLU-288            [-1, 608, 7, 7]               0Conv2d-289            [-1, 128, 7, 7]          77,824BatchNorm2d-290            [-1, 128, 7, 7]             256ReLU-291            [-1, 128, 7, 7]               0Conv2d-292             [-1, 32, 7, 7]          36,864BatchNorm2d-293            [-1, 640, 7, 7]           1,280ReLU-294            [-1, 640, 7, 7]               0Conv2d-295            [-1, 128, 7, 7]          81,920BatchNorm2d-296            [-1, 128, 7, 7]             256ReLU-297            [-1, 128, 7, 7]               0Conv2d-298             [-1, 32, 7, 7]          36,864BatchNorm2d-299            [-1, 672, 7, 7]           1,344ReLU-300            [-1, 672, 7, 7]               0Conv2d-301            [-1, 128, 7, 7]          86,016BatchNorm2d-302            [-1, 128, 7, 7]             256ReLU-303            [-1, 128, 7, 7]               0Conv2d-304             [-1, 32, 7, 7]          36,864BatchNorm2d-305            [-1, 704, 7, 7]           1,408ReLU-306            [-1, 704, 7, 7]               0Conv2d-307            [-1, 128, 7, 7]          90,112BatchNorm2d-308            [-1, 128, 7, 7]             256ReLU-309            [-1, 128, 7, 7]               0Conv2d-310             [-1, 32, 7, 7]          36,864BatchNorm2d-311            [-1, 736, 7, 7]           1,472ReLU-312            [-1, 736, 7, 7]               0Conv2d-313            [-1, 128, 7, 7]          94,208BatchNorm2d-314            [-1, 128, 7, 7]             256ReLU-315            [-1, 128, 7, 7]               0Conv2d-316             [-1, 32, 7, 7]          36,864BatchNorm2d-317            [-1, 768, 7, 7]           1,536ReLU-318            [-1, 768, 7, 7]               0Conv2d-319            [-1, 128, 7, 7]          98,304BatchNorm2d-320            [-1, 128, 7, 7]             256ReLU-321            [-1, 128, 7, 7]               0Conv2d-322             [-1, 32, 7, 7]          36,864BatchNorm2d-323            [-1, 800, 7, 7]           1,600ReLU-324            [-1, 800, 7, 7]               0Conv2d-325            [-1, 128, 7, 7]         102,400BatchNorm2d-326            [-1, 128, 7, 7]             256ReLU-327            [-1, 128, 7, 7]               0Conv2d-328             [-1, 32, 7, 7]          36,864BatchNorm2d-329            [-1, 832, 7, 7]           1,664ReLU-330            [-1, 832, 7, 7]               0Conv2d-331            [-1, 128, 7, 7]         106,496BatchNorm2d-332            [-1, 128, 7, 7]             256ReLU-333            [-1, 128, 7, 7]               0Conv2d-334             [-1, 32, 7, 7]          36,864BatchNorm2d-335            [-1, 864, 7, 7]           1,728ReLU-336            [-1, 864, 7, 7]               0Conv2d-337            [-1, 128, 7, 7]         110,592BatchNorm2d-338            [-1, 128, 7, 7]             256ReLU-339            [-1, 128, 7, 7]               0Conv2d-340             [-1, 32, 7, 7]          36,864BatchNorm2d-341            [-1, 896, 7, 7]           1,792ReLU-342            [-1, 896, 7, 7]               0Conv2d-343            [-1, 128, 7, 7]         114,688BatchNorm2d-344            [-1, 128, 7, 7]             256ReLU-345            [-1, 128, 7, 7]               0Conv2d-346             [-1, 32, 7, 7]          36,864BatchNorm2d-347            [-1, 928, 7, 7]           1,856ReLU-348            [-1, 928, 7, 7]               0Conv2d-349            [-1, 128, 7, 7]         118,784BatchNorm2d-350            [-1, 128, 7, 7]             256ReLU-351            [-1, 128, 7, 7]               0Conv2d-352             [-1, 32, 7, 7]          36,864BatchNorm2d-353            [-1, 960, 7, 7]           1,920ReLU-354            [-1, 960, 7, 7]               0Conv2d-355            [-1, 128, 7, 7]         122,880BatchNorm2d-356            [-1, 128, 7, 7]             256ReLU-357            [-1, 128, 7, 7]               0Conv2d-358             [-1, 32, 7, 7]          36,864BatchNorm2d-359            [-1, 992, 7, 7]           1,984ReLU-360            [-1, 992, 7, 7]               0Conv2d-361            [-1, 128, 7, 7]         126,976BatchNorm2d-362            [-1, 128, 7, 7]             256ReLU-363            [-1, 128, 7, 7]               0Conv2d-364             [-1, 32, 7, 7]          36,864BatchNorm2d-365           [-1, 1024, 7, 7]           2,048ReLU-366           [-1, 1024, 7, 7]               0Linear-367                 [-1, 1000]       1,025,000
================================================================
Total params: 7,978,856
Trainable params: 7,978,856
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 294.58
Params size (MB): 30.44
Estimated Total Size (MB): 325.59
----------------------------------------------------------------

2.3 训练模型

2.3.1 设置超参数

"""训练模型--设置超参数"""
loss_fn = nn.CrossEntropyLoss()  # 创建损失函数,计算实际输出和真实相差多少,交叉熵损失函数,事实上,它就是做图片分类任务时常用的损失函数
learn_rate = 1e-4  # 学习率
optimizer1 = torch.optim.SGD(model.parameters(), lr=learn_rate)# 作用是定义优化器,用来训练时候优化模型参数;其中,SGD表示随机梯度下降,用于控制实际输出y与真实y之间的相差有多大
optimizer2 = torch.optim.Adam(model.parameters(), lr=learn_rate)  
lr_opt = optimizer2
model_opt = optimizer2
# 调用官方动态学习率接口时使用2
lambda1 = lambda epoch : 0.92 ** (epoch // 4)
# optimizer = torch.optim.SGD(model.parameters(), lr=learn_rate)
scheduler = torch.optim.lr_scheduler.LambdaLR(lr_opt, lr_lambda=lambda1) #选定调整方法

2.3.2 编写训练函数

"""训练模型--编写训练函数"""
# 训练循环
def train(dataloader, model, loss_fn, optimizer):size = len(dataloader.dataset)  # 训练集的大小,一共60000张图片num_batches = len(dataloader)  # 批次数目,1875(60000/32)train_loss, train_acc = 0, 0  # 初始化训练损失和正确率for X, y in dataloader:  # 加载数据加载器,得到里面的 X(图片数据)和 y(真实标签)X, y = X.to(device), y.to(device) # 用于将数据存到显卡# 计算预测误差pred = model(X)  # 网络输出loss = loss_fn(pred, y)  # 计算网络输出和真实值之间的差距,targets为真实值,计算二者差值即为损失# 反向传播optimizer.zero_grad()  # 清空过往梯度loss.backward()  # 反向传播,计算当前梯度optimizer.step()  # 根据梯度更新网络参数# 记录acc与losstrain_acc += (pred.argmax(1) == y).type(torch.float).sum().item()train_loss += loss.item()train_acc /= sizetrain_loss /= num_batchesreturn train_acc, train_loss

2.3.3 编写测试函数

"""训练模型--编写测试函数"""
# 测试函数和训练函数大致相同,但是由于不进行梯度下降对网络权重进行更新,所以不需要传入优化器
def test(dataloader, model, loss_fn):size = len(dataloader.dataset)  # 测试集的大小,一共10000张图片num_batches = len(dataloader)  # 批次数目,313(10000/32=312.5,向上取整)test_loss, test_acc = 0, 0# 当不进行训练时,停止梯度更新,节省计算内存消耗with torch.no_grad(): # 测试时模型参数不用更新,所以 no_grad,整个模型参数正向推就ok,不反向更新参数for imgs, target in dataloader:imgs, target = imgs.to(device), target.to(device)# 计算losstarget_pred = model(imgs)loss = loss_fn(target_pred, target)test_loss += loss.item()test_acc += (target_pred.argmax(1) == target).type(torch.float).sum().item()#统计预测正确的个数test_acc /= sizetest_loss /= num_batchesreturn test_acc, test_loss

2.3.4 正式训练

"""训练模型--正式训练"""
epochs = 10
train_loss = []
train_acc = []
test_loss = []
test_acc = []
best_test_acc=0for epoch in range(epochs):milliseconds_t1 = int(time.time() * 1000)# 更新学习率(使用自定义学习率时使用)# adjust_learning_rate(lr_opt, epoch, learn_rate)model.train()epoch_train_acc, epoch_train_loss = train(train_dl, model, loss_fn, model_opt)scheduler.step() # 更新学习率(调用官方动态学习率接口时使用)model.eval()epoch_test_acc, epoch_test_loss = test(test_dl, model, loss_fn)train_acc.append(epoch_train_acc)train_loss.append(epoch_train_loss)test_acc.append(epoch_test_acc)test_loss.append(epoch_test_loss)# 获取当前的学习率lr = lr_opt.state_dict()['param_groups'][0]['lr']milliseconds_t2 = int(time.time() * 1000)template = ('Epoch:{:2d}, duration:{}ms, Train_acc:{:.1f}%, Train_loss:{:.3f}, Test_acc:{:.1f}%,Test_loss:{:.3f}, Lr:{:.2E}')if best_test_acc < epoch_test_acc:best_test_acc = epoch_test_acc#备份最好的模型best_model = copy.deepcopy(model)template = ('Epoch:{:2d}, duration:{}ms, Train_acc:{:.1f}%, Train_loss:{:.3f}, Test_acc:{:.1f}%,Test_loss:{:.3f}, Lr:{:.2E},Update the best model')print(template.format(epoch + 1, milliseconds_t2-milliseconds_t1, epoch_train_acc * 100, epoch_train_loss, epoch_test_acc * 100, epoch_test_loss, lr))
# 保存最佳模型到文件中
PATH = './best_model.pth'  # 保存的参数文件名
torch.save(model.state_dict(), PATH)
print('Done')

输出最高精度为Test_acc:100%

Epoch: 1, duration:74420ms, Train_acc:83.7%, Train_loss:0.902, Test_acc:85.8%,Test_loss:0.345, Lr:1.00E-04,Update the best model
Epoch: 2, duration:72587ms, Train_acc:86.4%, Train_loss:0.329, Test_acc:85.5%,Test_loss:0.343, Lr:1.00E-04
Epoch: 3, duration:72941ms, Train_acc:87.9%, Train_loss:0.292, Test_acc:89.2%,Test_loss:0.262, Lr:1.00E-04,Update the best model
Epoch: 4, duration:74155ms, Train_acc:88.8%, Train_loss:0.279, Test_acc:89.7%,Test_loss:0.248, Lr:1.00E-04,Update the best model
Epoch: 5, duration:75123ms, Train_acc:89.1%, Train_loss:0.265, Test_acc:89.0%,Test_loss:0.277, Lr:1.00E-04
Epoch: 6, duration:74381ms, Train_acc:89.6%, Train_loss:0.255, Test_acc:90.5%,Test_loss:0.249, Lr:1.00E-04,Update the best model
Epoch: 7, duration:73710ms, Train_acc:90.2%, Train_loss:0.243, Test_acc:84.1%,Test_loss:0.369, Lr:1.00E-04
Epoch: 8, duration:73995ms, Train_acc:90.7%, Train_loss:0.230, Test_acc:89.5%,Test_loss:0.250, Lr:1.00E-04
Epoch: 9, duration:73017ms, Train_acc:90.7%, Train_loss:0.223, Test_acc:89.3%,Test_loss:0.263, Lr:1.00E-04
Epoch:10, duration:73960ms, Train_acc:91.2%, Train_loss:0.224, Test_acc:91.6%,Test_loss:0.209, Lr:1.00E-04,Update the best model
Epoch:11, duration:74113ms, Train_acc:91.2%, Train_loss:0.219, Test_acc:90.5%,Test_loss:0.225, Lr:1.00E-04
Epoch:12, duration:73573ms, Train_acc:91.5%, Train_loss:0.213, Test_acc:88.5%,Test_loss:0.273, Lr:1.00E-04
Epoch:13, duration:73206ms, Train_acc:92.2%, Train_loss:0.202, Test_acc:85.1%,Test_loss:0.377, Lr:1.00E-04
Epoch:14, duration:73540ms, Train_acc:92.1%, Train_loss:0.195, Test_acc:91.2%,Test_loss:0.225, Lr:1.00E-04
Epoch:15, duration:73378ms, Train_acc:92.3%, Train_loss:0.192, Test_acc:87.6%,Test_loss:0.796, Lr:1.00E-04
Epoch:16, duration:73195ms, Train_acc:92.5%, Train_loss:0.187, Test_acc:92.5%,Test_loss:0.197, Lr:1.00E-04,Update the best model
Epoch:17, duration:73737ms, Train_acc:93.1%, Train_loss:0.174, Test_acc:92.7%,Test_loss:0.186, Lr:1.00E-04,Update the best model
Epoch:18, duration:73884ms, Train_acc:93.4%, Train_loss:0.171, Test_acc:80.6%,Test_loss:0.463, Lr:1.00E-04
Epoch:19, duration:73239ms, Train_acc:93.2%, Train_loss:0.168, Test_acc:91.2%,Test_loss:0.221, Lr:1.00E-04
Epoch:20, duration:73386ms, Train_acc:93.7%, Train_loss:0.159, Test_acc:92.5%,Test_loss:0.196, Lr:1.00E-04

2.4 结果可视化

"""训练模型--结果可视化"""
epochs_range = range(epochs)plt.figure(figsize=(12, 3))
plt.subplot(1, 2, 1)plt.plot(epochs_range, train_acc, label='Training Accuracy')
plt.plot(epochs_range, test_acc, label='Test Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')plt.subplot(1, 2, 2)
plt.plot(epochs_range, train_loss, label='Training Loss')
plt.plot(epochs_range, test_loss, label='Test Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()

在这里插入图片描述

2.4 指定图片进行预测

def predict_one_image(image_path, model, transform, classes):test_img = Image.open(image_path).convert('RGB')plt.imshow(test_img)  # 展示预测的图片plt.show()test_img = transform(test_img)img = test_img.to(device).unsqueeze(0)model.eval()output = model(img)_, pred = torch.max(output, 1)pred_class = classes[pred]print(f'预测结果是:{pred_class}')# 将参数加载到model当中
model.load_state_dict(torch.load(PATH, map_location=device))"""指定图片进行预测"""
classes = list(total_data.class_to_idx)
# 预测训练集中的某张照片
predict_one_image(image_path=str(Path(data_dir) / "Cockatoo/001.jpg"),model=model,transform=train_transforms,classes=classes)

在这里插入图片描述

输出

预测结果是:0

2.6 模型评估

"""模型评估"""
best_model.eval()
epoch_test_acc, epoch_test_loss = test(test_dl, best_model, loss_fn)
# 查看是否与我们记录的最高准确率一致
print(epoch_test_acc, epoch_test_loss)

输出

预测结果是:0
0.9268929503916449 0.185508520431107

3 知识点详解

3.1 nn.Sequential和nn.Module区别与选择

3.1.1 nn.Sequential

torch.nn.Sequential是一个Sequential容器,模块将按照构造函数中传递的顺序添加到模块中。另外,也可以传入一个有序模块。 为了更容易理解,官方给出了一些案例:

# Sequential使用实例model = nn.Sequential(nn.Conv2d(1,20,5),nn.ReLU(),nn.Conv2d(20,64,5),nn.ReLU())# Sequential with OrderedDict使用实例
model = nn.Sequential(OrderedDict([('conv1', nn.Conv2d(1,20,5)),('relu1', nn.ReLU()),('conv2', nn.Conv2d(20,64,5)),('relu2', nn.ReLU())]))

3.1.2 nn.Module

下面我们再用 Module 定义这个模型,下面是使用 Module 的模板

class 网络名字(nn.Module):def __init__(self, 一些定义的参数):super(网络名字, self).__init__()self.layer1 = nn.Linear(num_input, num_hidden)self.layer2 = nn.Sequential(...)...定义需要用的网络层def forward(self, x): # 定义前向传播x1 = self.layer1(x)x2 = self.layer2(x)x = x1 + x2...return x

注意的是,Module 里面也可以使用 Sequential,同时 Module 非常灵活,具体体现在 forward 中,如何复杂的操作都能直观的在 forward 里面执行

3.1.3 对比

为了方便比较,我们先用普通方法搭建一个神经网络。

class Net(torch.nn.Module):def __init__(self, n_feature, n_hidden, n_output):super(Net, self).__init__()self.hidden = torch.nn.Linear(n_feature, n_hidden)self.predict = torch.nn.Linear(n_hidden, n_output)def forward(self, x):x = F.relu(self.hidden(x))x = self.predict(x)return x
net1 = Net(1, 10, 1)net2 = torch.nn.Sequential(torch.nn.Linear(1, 10),torch.nn.ReLU(),torch.nn.Linear(10, 1)
)

打印这两个net

print(net1)
"""
Net ((hidden): Linear (1 -> 10)(predict): Linear (10 -> 1)
)
"""
print(net2)
"""
Sequential ((0): Linear (1 -> 10)(1): ReLU ()(2): Linear (10 -> 1)
)
"""

我们可以发现,打印torch.nn.Sequential会自动加入激励函数,
在 net1 中, 激励函数实际上是在 forward() 功能中被调用的,没有在init中定义,所以在打印网络结构时不会有激励函数的信息.

解析源码,在torch.nn.Sequential中:

def forward(self, input):for module in self:input = module(input)return input

可以看到,torch.nn.Sequential的forward只是简单的顺序传播,操作性有限.

3.1.4 总结

使用torch.nn.Module,我们可以根据自己的需求改变传播过程,如RNN等
如果你需要快速构建或者不需要过多的过程,直接使用torch.nn.Sequential即可。

参考链接:nn.Sequential和nn.Module区别与选择

3.2 python中OrderedDict的使用

很多人认为python中的字典是无序的,因为它是按照hash来存储的,但是python中有个模块collections(英文,收集、集合),里面自带了一个子类

OrderedDict,实现了对字典对象中元素的排序。请看下面的实例:

import collections
print "Regular dictionary"
d={}
d['a']='A'
d['b']='B'
d['c']='C'
for k,v in d.items():print k,vprint "\nOrder dictionary"
d1 = collections.OrderedDict()
d1['a'] = 'A'
d1['b'] = 'B'
d1['c'] = 'C'
d1['1'] = '1'
d1['2'] = '2'
for k,v in d1.items():print k,v

输出:

Regular dictionary
a A
c C
b BOrder dictionary
a A
b B
c C
1 1
2 2

可以看到,同样是保存了ABC等几个元素,但是使用OrderedDict会根据放入元素的先后顺序进行排序。所以输出的值是排好序的。

OrderedDict对象的字典对象,如果其顺序不同那么Python也会把他们当做是两个不同的对象,请看事例:

print 'Regular dictionary:'
d2={}
d2['a']='A'
d2['b']='B'
d2['c']='C'd3={}
d3['c']='C'
d3['a']='A'
d3['b']='B'print d2 == d3print '\nOrderedDict:'
d4=collections.OrderedDict()
d4['a']='A'
d4['b']='B'
d4['c']='C'd5=collections.OrderedDict()
d5['c']='C'
d5['a']='A'
d5['b']='B'print  d1==d2输出:
Regular dictionary:
TrueOrderedDict:
False

再看几个例子:

dd = {'banana': 3, 'apple':4, 'pear': 1, 'orange': 2}
#按key排序
kd = collections.OrderedDict(sorted(dd.items(), key=lambda t: t[0]))
print kd
#按照value排序
vd = collections.OrderedDict(sorted(dd.items(),key=lambda t:t[1]))
print vd#输出
OrderedDict([('apple', 4), ('banana', 3), ('orange', 2), ('pear', 1)])
OrderedDict([('pear', 1), ('orange', 2), ('banana', 3), ('apple', 4)])

总结

  数据量越大,训练时间越长,在DataLoader中增加num_workers,即增加线程数量,可能会导致内存不足出现,Couldn‘t open shared file mapping或者Out of memery的错误,可尝试减小num_corkers。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/601568.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

「Verilog学习笔记」编写乘法器求解算法表达式

专栏前言 本专栏的内容主要是记录本人学习Verilog过程中的一些知识点&#xff0c;刷题网站用的是牛客网 timescale 1ns/1nsmodule calculation(input clk,input rst_n,input [3:0] a,input [3:0] b,output [8:0] c);reg [8:0] data1, data2 ; assign c data2 ; always (posed…

猫咪主食罐头巅峰、希喂、K9哪款好?猫咪主食罐头真实对比测评

在当前科学喂养观念广泛传播的背景下&#xff0c;铲屎官们对猫咪主食的营养价值和健康性有了更高的要求。作为猫咪日常成长和活动的主要能量来源&#xff0c;主食的营养价值对猫咪的健康状况有着直接的影响。特别是对于处于成长期的猫咪来说&#xff0c;选择一款优质的主食对其…

关键词优化完整 “操作 “指南

关键词优化的定义 在内容中突出相关关键词的行为&#xff0c;有助于将谷歌流量引向您的网站。关键词优化要求内容创建者做到以下几点&#xff1a; 研究并发现最佳关键词找到自然的方式在内容中突出相关词语 看看&#xff0c;你已经创建了一些很棒的内容。你做了研究&#xf…

K8S部署GitLab

天行健&#xff0c;君子以自强不息&#xff1b;地势坤&#xff0c;君子以厚德载物。 每个人都有惰性&#xff0c;但不断学习是好好生活的根本&#xff0c;共勉&#xff01; 文章均为学习整理笔记&#xff0c;分享记录为主&#xff0c;如有错误请指正&#xff0c;共同学习进步。…

SpringBoot 如何 返回页面

背景 RestController ResponseBody Controller Controller中的方法无法返回jsp页面&#xff0c;或者html&#xff0c;配置的视图解析器 InternalResourceViewResolver不起作用&#xff0c;返回的内容就是Return 里的内容。 Mapping ResponseBody 也会出现同样的问题。 解…

明明白白安装Python解释器(多版本共存切换)、Python IDE:PyCharm(专业版永久)、透彻!

Python解释器安装 ———————— 解释器&#xff08;英语&#xff1a;Interpreter&#xff09;。用户可以到Python的官网上直接下载Python解释器安装程序。 在浏览器地址栏中输入&#xff1a; http://www.python.org 需要最新专业版PyCharm永久使用权限的扫码免费获取&a…

基于FFmpeg的短视频编辑工具Cut

前言 最近在学习FFmpeg和音视频的相关知识&#xff0c;为了加强对FFmpeg的认识和了解&#xff0c;于是撸了一个短视频编辑软件Cut。 效果图先行&#xff1a; 技术点 启动页优化 但启动app的时候会有一个短暂的黑屏或者白屏。为什么呢&#xff1f; 是因为在App启动时&#x…

智能分析网关V4在工业园区周界防范场景中的应用

一、背景需求分析 在工业产业园、化工园或生产制造园区中&#xff0c;周界防范意义重大&#xff0c;对园区的安全起到重要的作用。常规的安防方式是采用人员巡查&#xff0c;人力投入成本大而且效率低。周界一旦被破坏或入侵&#xff0c;会影响园区人员和资产安全&#xff0c;对…

分布式系统——共识问题

1. 分布式系统 1.1 分布式系统的概念 分布式系统是由多台计算机组成的网络&#xff0c;这些计算机共同协作以实现一个共同的目标。在这种环境中&#xff0c;每台计算机作为一个独立的进程运行。但对最终用户来说&#xff0c;它们似乎是作为一个单一系统在操作。这个概念对于创…

大学生搜题软件,未来可期吗?

作为一家专注于软件开发的公司《智创有术》&#xff0c;我们致力于为客户提供创新、高效和可靠的解决方案。通过多年的经验和专业知识&#xff0c;我们已经在行业内建立了良好的声誉&#xff0c;并赢得了客户的信任和支持。 支持各种源码&#xff0c;网站搭建&#xff0c;APP&a…

数字孪生在增强现实(AR)中的应用

数字孪生在增强现实&#xff08;Augmented Reality&#xff0c;AR&#xff09;中的应用可以提供更丰富、交互性更强的现实世界增强体验。以下是数字孪生在AR中的一些应用&#xff0c;希望对大家有所帮助。北京木奇移动技术有限公司&#xff0c;专业的软件外包开发公司&#xff…

视频剪辑实战:如何批量嵌套合并视频,提高剪辑效率必备技巧

在视频剪辑工作中&#xff0c;经常要处理大量的视频片段。要提高工作效率&#xff0c;批量嵌套合并视频成为了一项必备技巧。现在一起看看云炫AI智剪如何使用一些实用的技巧&#xff0c;快速、准确地完成批量嵌套合并视频的任务。 合并后的视频截图&#xff0c;由两段不同片段组…

【STM32】STM32学习笔记-DMA直接存储器存储(23)

00. 目录 文章目录 00. 目录01. DMA简介02. DMA主要特性03. 存储器映像04. DMA框图05. DMA基本结构06. DMA请求07. 数据宽度与对齐08. 数据转运DMA09. ADC扫描模式DMA10. 附录 01. DMA简介 小容量产品是指闪存存储器容量在16K至32K字节之间的STM32F101xx、STM32F102xx和STM32F…

解决Gitlab Prometheus导致的磁盘空间不足问题

解决Gitlab Prometheus导致的磁盘空间不足问题 用docker搭建了一个gitlab服务&#xff0c;已经建立了多个项目上传&#xff0c;但是突然有一天就503了。 df -TH查看系统盘&#xff0c;发现已经Used 100%爆满了。。。 &#x1f4a1;Tips&#xff1a;/dev/vda1目录是系统盘目录。…

AntV L7 实现地图功能(高德)

一、 使用前的准备 首先&#xff0c;注册开发者账号&#xff0c;成为高德开放平台开发者 登陆之后&#xff0c;在进入「应用管理」 页面「创建新应用」 为应用添加 Key&#xff0c;「服务平台」一项请选择「 Web 端 ( JSAPI ) 」 二、安装依赖 // 安装L7 依赖 npm install…

2024年【危险化学品生产单位主要负责人】复审模拟考试及危险化学品生产单位主要负责人作业模拟考试

题库来源&#xff1a;安全生产模拟考试一点通公众号小程序 2024年危险化学品生产单位主要负责人复审模拟考试为正在备考危险化学品生产单位主要负责人操作证的学员准备的理论考试专题&#xff0c;每个月更新的危险化学品生产单位主要负责人作业模拟考试祝您顺利通过危险化学品…

深度学习 Day23——J3DenseNet算法实战与解析

&#x1f368; 本文为&#x1f517;365天深度学习训练营 中的学习记录博客&#x1f356; 原作者&#xff1a;K同学啊 | 接辅导、项目定制&#x1f680; 文章来源&#xff1a;K同学的学习圈子 文章目录 前言1 我的环境2 pytorch实现DenseNet算法2.1 前期准备2.1.1 引入库2.1.2 设…

flutter 使用adb 同时连接 多个模拟器

MUMU模拟器 MuMu模拟器官网_安卓12模拟器_网易手游模拟器 传统只需要 连接一个 默认命令是 默认端口是7555 adb connect 127.0.0.1:7555 但是需要同时连接调试多个模拟器的时候 就需要连接多个 这里可以使用自带的多开 多开后 使用 1 是对应多开的序号 这样就可以查看对…

我是谁 whoami

文章目录 我是谁 whoami更多信息 我是谁 whoami 我知道你是谁&#xff0c;但我不知道我是谁&#xff0c;此时whoami可以帮助你&#xff0c;哈哈。 whoami将打印当前用户的名字。与id -un类似。 官方定义为&#xff1a; whoami - print effective userid 用法为&#xff1a; …

Redis基础学习一

1. Redis 入门 1.1. Redis 诞生历程 1.1.1.从一个故事开始 08 年的时候有一个意大利西西里岛的小伙子&#xff0c;笔名 antirez&#xff08;http://invece.org/&#xff09;&#xff0c;创建了一个访客信息网站 LLOOGG.COM。有的时候我们需要知道网站的访问情况&#xff0c;…