Pytorch从零开始实战14

Pytorch从零开始实战——DenseNet + SENet算法实战

本系列来源于365天深度学习训练营

原作者K同学

文章目录

  • Pytorch从零开始实战——DenseNet + SENet算法实战
    • 环境准备
    • 数据集
    • 模型选择
    • 开始训练
    • 可视化
    • 总结

环境准备

本文基于Jupyter notebook,使用Python3.8,Pytorch2.0.1+cu118,torchvision0.15.2,需读者自行配置好环境且有一些深度学习理论基础。本次实验的目的是使用DenseNet+SENet模型。
第一步,导入常用包

import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import torchvision
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torch.nn.functional as F
import random
from time import time
import numpy as np
import pandas as pd
import datetime
import gc
import os
import copy
import warnings
os.environ['KMP_DUPLICATE_LIB_OK']='True'  # 用于避免jupyter环境突然关闭
torch.backends.cudnn.benchmark=True  # 用于加速GPU运算的代码

设置随机数种子

torch.manual_seed(428)
torch.cuda.manual_seed(428)
torch.cuda.manual_seed_all(428)
random.seed(428)
np.random.seed(428)

检查设备对象

torch.manual_seed(428)
torch.cuda.manual_seed(428)
torch.cuda.manual_seed_all(428)
random.seed(428)
np.random.seed(428)

检查设备对象

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device, torch.cuda.device_count() # # (device(type='cuda'), 2)

数据集

本次实验继续使用猴痘病数据集,使用pathlib查看类别,本次类别只有0,1两种类别分别代表患病和不患病。

import pathlib
data_dir = './data/ill/'
data_dir = pathlib.Path(data_dir) # 转成pathlib.Path对象
data_paths = list(data_dir.glob('*')) 
classNames = [str(path).split("/")[2] for path in data_paths]
classNames # ['Monkeypox', 'Others']

使用transforms对数据集进行统一处理,并且根据文件夹名映射对应标签

all_transforms = transforms.Compose([transforms.Resize([224, 224]),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # 标准化
])total_data = datasets.ImageFolder("./data/ill/", transform=all_transforms)
total_data.class_to_idx # {'Monkeypox': 0, 'Others': 1}

随机查看5张图片

def plotsample(data):fig, axs = plt.subplots(1, 5, figsize=(10, 10)) #建立子图for i in range(5):num = random.randint(0, len(data) - 1) #首先选取随机数,随机选取五次#抽取数据中对应的图像对象,make_grid函数可将任意格式的图像的通道数升为3,而不改变图像原始的数据#而展示图像用的imshow函数最常见的输入格式也是3通道npimg = torchvision.utils.make_grid(data[num][0]).numpy()nplabel = data[num][1] #提取标签 #将图像由(3, weight, height)转化为(weight, height, 3),并放入imshow函数中读取axs[i].imshow(np.transpose(npimg, (1, 2, 0))) axs[i].set_title(nplabel) #给每个子图加上标签axs[i].axis("off") #消除每个子图的坐标轴plotsample(total_data)

在这里插入图片描述
根据8比2划分数据集和测试集,并且利用DataLoader划分批次和随机打乱

train_size = int(0.8 * len(total_data))
test_size  = len(total_data) - train_size
train_ds, test_ds = torch.utils.data.random_split(total_data, [train_size, test_size])batch_size = 32
train_dl = torch.utils.data.DataLoader(train_ds,batch_size=batch_size,shuffle=True,)
test_dl = torch.utils.data.DataLoader(test_ds,batch_size=batch_size,shuffle=True,)len(train_dl.dataset), len(test_dl.dataset) # (1713, 429)

模型选择

本次实验使用DenseNet + SENet模型,DenseNet的设计核心思想是通过密集连接来增强神经网络的信息流动,促进梯度的传播,以及提高参数的共享和重复使用。采用跨通道concat的形式来连接,会连接前面所有层作为输入。
核心公式为:
在这里插入图片描述
DenseNet中的基本组成单元是DenseBlock,它由多个密集连接的DenseLayer组成。每个DenseLayer都接收所有前面的DenseLayer特征作为输入,将其连接到自己的输出上,并传递给后续的层。如图所示,这是一个基本的DenseBlock模块。
在这里插入图片描述
整体网络架构图如下所示,借用K同学的图片
在这里插入图片描述

为了控制模型的复杂度并减少特征图的大小,DenseNet引入了Transition Block。过渡块包括批归一化、ReLU激活和 1x1 卷积,以减小特征图的通道数,并通过池化操作降低空间维度。
在这里插入图片描述
首先对DenseLayer类定义,本次实验使用add_module函数,默认是用于向类中添加一个子模块,第一个参数为模块名,第二个参数为模块实例,其实相当于加到父类的nn.Sequential里面,所以调用的时候使用super().forward(x),这段的核心是将输入 x 与新特征 t 进行通道维度上的连接,完成密集连接。

class DenseLayer(nn.Sequential):def __init__(self, num_input_features, growth_rate, bn_size, drop_rate):super().__init__()self.add_module("norm1", nn.BatchNorm2d(num_input_features))self.add_module("relu1", nn.ReLU(inplace=True))self.add_module("conv1", nn.Conv2d(num_input_features, bn_size * growth_rate, kernel_size=1, stride=1, bias=False))self.add_module("norm2", nn.BatchNorm2d(bn_size * growth_rate))self.add_module("relu2", nn.ReLU(inplace=True))self.add_module("conv2", nn.Conv2d(bn_size*growth_rate, growth_rate, kernel_size=3, stride=1, padding=1, bias=False))self.drop_rate = drop_ratedef forward(self, x):t = super().forward(x)if self.drop_rate > 0:t = F.dropout(t, p=self.drop_rate, training=self.training)return torch.cat([x, t], 1)

下面是DenseBlock的实现,通过循环创建了多个DenseLayer。其中的 num_input_features + i * growth_rate 用于指定输入通道的数量,确保每个DenseLayer的输入通道数逐渐增加。将新创建的DenseLayer添加为 DenseBlock 的子模块。循环结束后,DenseBlock 就包含了多个DenseLayer,每个DenseLayer都具有逐渐增加的输入通道数量。

class DenseBlock(nn.Sequential):def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate):super().__init__()for i in range(num_layers):layer = DenseLayer(num_input_features + i * growth_rate, growth_rate, bn_size, drop_rate)self.add_module("denselayer%d" % (i + 1), layer)

下面是Transition,实现过渡的功能,是在块之间降低通道数量和空间维度。

class Transition(nn.Sequential):def __init__(self, num_input_feature, num_output_features):super().__init__()self.add_module("norm", nn.BatchNorm2d(num_input_feature))self.add_module("relu", nn.ReLU(inplace=True))self.add_module("conv", nn.Conv2d(num_input_feature, num_output_features, kernel_size=1, stride=1, bias=False))self.add_module("pool", nn.AvgPool2d(2, stride=2))

SENet是一种深度神经网络结构,它的核心思想是允许网络在训练期间对每个通道进行自适应的加权,以使网络能够更加关注对任务有用的通道,并抑制对任务无关的通道。这有助于提高网络对输入数据的敏感性,并提升网络性能。SENet的结构包括两个主要组件:Squeeze 操作和 Excitation 操作。

Squeeze 操作(Global Average Pooling):通过全局平均池化,将每个通道的空间维度降为1。这样,对于每个通道,都得到一个单一的数值,反映了该通道对整个特征图的重要性。

Excitation 操作(通道注意力):在 Squeeze 操作后,通过一个小型的多层感知机(MLP)来学习通道之间的关系。这个小型MLP包含一个压缩操作和一个激励操作)。最后,利用学到的权重对每个通道的特征图进行加权,得到加权后的特征表示。
在这里插入图片描述
下面是SENet的实现,首先,通过全局平均池化层对输入特征图进行平均池化,将每个通道的空间维度降为1。然后,通过全连接层序列 fc 对降维后的特征进行处理,得到每个通道的注意力权重。最后,将得到的注意力权重通过 view 操作还原为与输入特征图相同的形状,并将其与输入特征图相乘,得到应用了注意力机制的特征图。

from torch.nn import init
class SEAttention(nn.Module):def __init__(self, channel=512, reduction=16):super().__init__()self.avg_pool = nn.AdaptiveAvgPool2d(1)self.fc = nn.Sequential(nn.Linear(channel, channel // reduction, bias=False),nn.ReLU(inplace=True),nn.Linear(channel // reduction, channel, bias=False),nn.Sigmoid())def init_weights(self):for m in self.modules():if isinstance(m, nn.Conv2d):init.kaiming_normal_(m.weight, mode='fan_out')if m.bias is not None:init.constant_(m.bias, 0)elif isinstance(m, nn.BatchNorm2d):init.constant_(m.weight, 1)init.constant_(m.bias, 0)elif isinstance(m, nn.Linear):init.normal_(m.weight, std=0.001)if m.bias is not None:init.constant_(m.bias, 0)def forward(self, x):b, c, _, _ = x.size()y = self.avg_pool(x).view(b, c)y = self.fc(y).view(b, c, 1, 1)return x * y.expand_as(x)

整体模型实现,self.features 是一个包含多个层的序列,包括初始卷积块、多个DenseBlock和Transition,以及最后的全局平均池化和分类器。遍历 block_config 中的配置,创建DenseBlock和Transition。参数初始化部分使用了 Kaiming 初始化和常数初始化。

其中,OrderedDict是Python中的一种有序字典数据结构,它保留了元素添加的顺序。在神经网络中,我们可以使用OrderedDict来指定模型的层次结构。

在进行平均池化之前,进入到SENet进行学习通道注意力权重从而提高网络的表征能力。

from collections import OrderedDictclass DenseNet(nn.Module):def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16), num_init_features=64,bn_size=4, compression_rate=0.5, drop_rate=0, num_classes=1000):super().__init__()self.features = nn.Sequential(OrderedDict([("conv0", nn.Conv2d(3, num_init_features, kernel_size=7, stride=2, padding=3, bias=False)),("norm0", nn.BatchNorm2d(num_init_features)),("relu0", nn.ReLU(inplace=True)),("pool0", nn.MaxPool2d(3, stride=2, padding=1))]))num_features = num_init_featuresfor i, num_layers in enumerate(block_config):block = DenseBlock(num_layers, num_features, bn_size, growth_rate, drop_rate)self.features.add_module("denseblock%d" % (i + 1), block)num_features += num_layers * growth_rateif i != len(block_config) - 1:transition = Transition(num_features, int(num_features * compression_rate))self.features.add_module("transition%d" % (i + 1), transition)num_features = int(num_features * compression_rate)self.features.add_module("norm5", nn.BatchNorm2d(num_features))self.features.add_module("relu5", nn.ReLU(inplace=True))self.se = SEAttention(channel=1024, reduction=8)self.classifier = nn.Linear(num_features, num_classes)for m in self.modules():if isinstance(m, nn.Conv2d):nn.init.kaiming_normal_(m.weight)elif isinstance(m, nn.BatchNorm2d):nn.init.constant_(m.bias, 0)nn.init.constant_(m.weight, 1)elif isinstance(m, nn.Linear):if m.bias is not None:nn.init.constant_(m.bias, 0)def forward(self, x):features = self.features(x)out = self.se(features)out = F.avg_pool2d(features, 7, stride=1)out = out.view(features.size(0), -1)out = self.classifier(out)return out

使用summary查看网络

from torchsummary import summary
model = DenseNet().to(device)
summary(model, input_size=(3, 224, 224))

在这里插入图片描述

开始训练

定义训练函数

def train(dataloader, model, loss_fn, opt):size = len(dataloader.dataset)num_batches = len(dataloader)train_acc, train_loss = 0, 0for X, y in dataloader:X, y = X.to(device), y.to(device)pred = model(X)loss = loss_fn(pred, y)opt.zero_grad()loss.backward()opt.step()train_acc += (pred.argmax(1) == y).type(torch.float).sum().item()train_loss += loss.item()train_acc /= sizetrain_loss /= num_batchesreturn train_acc, train_loss

定义测试函数

def test(dataloader, model, loss_fn):size = len(dataloader.dataset)num_batches = len(dataloader)test_acc, test_loss = 0, 0with torch.no_grad():for X, y in dataloader:X, y = X.to(device), y.to(device)pred = model(X)loss = loss_fn(pred, y)test_acc += (pred.argmax(1) == y).type(torch.float).sum().item()test_loss += loss.item()test_acc /= sizetest_loss /= num_batchesreturn test_acc, test_loss

定义学习率、损失函数、优化算法

loss_fn = nn.CrossEntropyLoss()
learn_rate = 0.0001
opt = torch.optim.Adam(model.parameters(), lr=learn_rate)

开始训练,epoch设置为20

import time
epochs = 20
train_loss = []
train_acc = []
test_loss = []
test_acc = []T1 = time.time()best_acc = 0
best_model = 0for epoch in range(epochs):model.train()epoch_train_acc, epoch_train_loss = train(train_dl, model, loss_fn, opt)model.eval() # 确保模型不会进行训练操作epoch_test_acc, epoch_test_loss = test(test_dl, model, loss_fn)if epoch_test_acc > best_acc:best_acc = epoch_test_accbest_model = copy.deepcopy(model)train_acc.append(epoch_train_acc)train_loss.append(epoch_train_loss)test_acc.append(epoch_test_acc)test_loss.append(epoch_test_loss)print("epoch:%d, train_acc:%.1f%%, train_loss:%.3f, test_acc:%.1f%%, test_loss:%.3f"% (epoch + 1, epoch_train_acc * 100, epoch_train_loss, epoch_test_acc * 100, epoch_test_loss))T2 = time.time()
print('程序运行时间:%s秒' % (T2 - T1))PATH = './best_model.pth'  # 保存的参数文件名
if best_model is not None:torch.save(best_model.state_dict(), PATH)print('保存最佳模型')
print("Done")

在这里插入图片描述

可视化

可视化训练过程与测试过程

import warnings
warnings.filterwarnings("ignore")               #忽略警告信息
plt.rcParams['font.sans-serif']    = ['SimHei'] # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False      # 用来正常显示负号
plt.rcParams['figure.dpi']         = 100        #分辨率epochs_range = range(epochs)plt.figure(figsize=(12, 3))
plt.subplot(1, 2, 1)plt.plot(epochs_range, train_acc, label='Training Accuracy')
plt.plot(epochs_range, test_acc, label='Test Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')plt.subplot(1, 2, 2)
plt.plot(epochs_range, train_loss, label='Training Loss')
plt.plot(epochs_range, test_loss, label='Test Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()

在这里插入图片描述

总结

SE模块引入了通道注意力机制,使得网络在学习过程中能够更加自适应地关注对任务有用的通道,抑制对任务无关的通道。这有助于提高网络的特征表达能力。当前也可以与各种其他的深度神经网络结构集成。因此,可以在不改变整体网络架构的情况下,通过引入通道注意力机制来增强网络性能。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/580892.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

亿赛通电子文档安全管理系统 linkfilterservice 未授权漏洞

产品简介 亿赛通电子文档安全管理系统,(简称:CDG)是一款电子文档安全加密软件,该系统利用驱动层透明加密技术,通过对电子文档的加密保护,防止内部员工泄密和外部人员非法窃取企业核心重要数据资…

TS常用类型

原始类型使用 // 原始类型使用 let age: number 18let myName: string 前端let isLoding: boolean falselet a: null nulllet b: undefined undefinedlet s:symbol Symbol()数组类型使用 // 数组类型的两种写法// 写法一 let numbers: number[] [1, 2, 3] // 数值类型…

Spring企业开发核心框架

文章目录 Spring企业开发核心框架一、框架前言1. 总体技术体系2. 框架概念和理解 二、Spring Framework简介1. Spring 和 SpringFramework2. SpringFramework主要功能模块3. SpringFramework 主要优势 三、Spring IoC 容器概念1. 组件和组件管理概念2. Spring IoC容器和容器实现…

ALS-运动系统解构

角色握持 角色蓝图:将物体绑在手上 动作蓝图: 将握持动画截取一帧(explicit time时间写好) 角色替换 在原人物模型下面加一个骨骼体(先不用添加模型),重命名为bodymesh AI使用流程 新建一…

品牌如何在线上打造“社交货币”?媒介盒子揭秘

品牌的社交货币,是品牌与消费者的共识身份铸造器。竹筒奶茶、Keep奖牌这类的实体产品作为社交货币,每每能够引爆社交平台,那么品牌能否通过线上平台打造“社交货币”呢?接下来就让媒介盒子和大家聊聊。 一、社交货币是什么 社交货…

6.Nacos

1.单机部署 1.1 官网 https://nacos.io/zh-cn/index.html https://github.com/alibaba/Nacos 1.2.版本说明 https://github.com/alibaba/spring-cloud-alibaba/wiki/%E7%89%88%E6%9C%AC%E8%AF%B4%E6%98%8E 1.3.下载地址 https://github.com/alibaba/nacos/releases/tag/2.2.…

小区跑腿服务

社区跑腿服务是指在社区范围内为居民提供各种便利的服务,包括购物代劳、快递代取、家政服务等。 这种服务的出现,满足了居民生活中诸多需求,受到了广泛的欢迎和认可。 首先,社区跑腿服务方便了居民的日常生活。 居民无需亲自前…

Unity 数据存储PlayerPrefs管理类

Unity 数据存储PlayerPrefs管理类 Unity 数据存储PlayerPrefs管理类实现存取实体类对象存储格式为Json格式Singleton.csInventoryEntity.csDataManager.cs用法如下 Unity 数据存储PlayerPrefs管理类 实现存取实体类对象 存储格式为Json格式 源码如下: Singleton…

克魔助手工具下载、注册和登录指南

下载安装克魔助手 摘要 本文介绍了如何下载安装克魔助手工具,以及注册和登录流程。通过简单的步骤,用户可以轻松获取并使用该工具,为后续的手机应用管理操作做好准备。 引言 克魔助手是一款免费的手机管理工具,通过该工具用户…

2023年第十六届山东省职业院校技能大赛高职组“应用软件系统开发”赛项样题

第十六届山东省职业院校技能大赛 高职组“应用软件系统开发”赛项样题 目录 一.竞赛须知 二.竞赛任务 模块一:系统需求分析(25分) 模块三:系统部署测试(20分) 需要竞赛源码或资…

Linux常用压缩和解压缩命令

在Linux系统中,有多种压缩和解压缩命令可供使用。以下是一些常用的压缩和解压缩命令的详细解释: 压缩命令 1. gzip 压缩文件: gzip file 这将压缩file并生成一个名为file.gz的压缩文件。 保留原始文件: gzip -c file > fil…

微服务的调用使用

在微服务架构中,不同的微服务之间通常通过网络进行调用和通信。常见的方式包括: 1. **HTTP/HTTPS调用:** 微服务可以通过HTTP或HTTPS协议进行调用。使用HTTP请求方法(如GET、POST、PUT、DELETE)来执行操作&#xff0c…

【AUTOSAR OS】了解AUTOSAR操作系统基本概念(1)--任务

目录 前言 一、任务Task 什么是“基础任务”和“扩展任务”?以及他们适用于什么场景?

文章解读与完整程序——《考虑“源-荷-储”协同互动的主动配电网优化调度研究》

摘要:伴随智能电网的建设和清洁能源的开发利用,配电网中的负荷类型呈现多元化发展,分布式电源、可控负荷、储能等资源的增加让单向潮流的传统配电网逐渐向双向潮流的主动配电网结构转变。在能源结构转变的同时,清洁能源自身的随机性和波动性给配电网带来了更大的调峰…

2023.12.25 关于 Redis 数据类型 Hash 常用命令、内部编码、应用场景

目录 Hash 数据类型 Hash 操作命令 HSET HGET HEXISTS HDEL HKEYS HVALS HGETALL HMGET HLEN HSETNX HINCRBY HINCRBYFLOAT HSTRLEN Hash 编码方式 理解什么是压缩 Hash 实际应用 Cache 缓存 Hash 数据类型 整体上来说 Redis 是键值对结构,其中 …

使用docker创建自己的Android编译容器

文章目录 背景步骤1.创建Dockerfile2.编写Dockerfile指令3.编译4.使用 背景 每次拿到新机器或者系统重装,最麻烦的就是各种环境配置,最近学习了一下docker的知识,用dockerfile创建一个Android编译容器,这样就不用每次都吭哧吭哧的…

python 通过(三维坐标)生成(三维曲面地形图)和(圆柱曲面地形图)

有需要源代码CSDN私信我 注意 python项目移植前要进行以下操作 1.python项目备份 2.生成requirements.txt的库文件 以pycharm为例,生成Python项目所需要的依赖库/包文档:requirements.txt_如何将pycharm项目安装的库文件导出为requirement.txt-CSDN博…

揭秘千巡翼X4卫星通讯无人机

揭秘千巡翼X4卫星通讯无人机 在无人机作业的时候,经常遇到这些异常场景,例如通信网络中断,强干扰,导致无人机无法与飞手通信等。而这些给无人机作业带来三大难题: 难题1,山区作业时通信中断,飞…

【前端框架】NPM概述及使用简介

什么是 NPM npm之于Node,就像pip之于Python,gem之于Ruby,composer之于PHP。 npm是Node官方提供的包管理工具,他已经成了Node包的标准发布平台,用于Node包的发布、传播、依赖控制。npm提供了命令行工具,使你可以方便地下载、安装、升级、删除包,也可以让你作为开发者发布…

教你如何为自己的个人网站选择SSL证书?

在互联网飞速发展的今天,各类互联网技术和工具日新月异,越来越多的人都可以低技术门槛来开办自己的独立博客、自媒体、个人站点等通过这些平台来发布自己想要公开的资讯,或者以此来提供相关的网络服务以及展示、销售自己的作品、商品。殊不知…