Pytorch从零开始实战14

Pytorch从零开始实战——DenseNet + SENet算法实战

本系列来源于365天深度学习训练营

原作者K同学

文章目录

  • Pytorch从零开始实战——DenseNet + SENet算法实战
    • 环境准备
    • 数据集
    • 模型选择
    • 开始训练
    • 可视化
    • 总结

环境准备

本文基于Jupyter notebook,使用Python3.8,Pytorch2.0.1+cu118,torchvision0.15.2,需读者自行配置好环境且有一些深度学习理论基础。本次实验的目的是使用DenseNet+SENet模型。
第一步,导入常用包

import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import torchvision
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torch.nn.functional as F
import random
from time import time
import numpy as np
import pandas as pd
import datetime
import gc
import os
import copy
import warnings
os.environ['KMP_DUPLICATE_LIB_OK']='True'  # 用于避免jupyter环境突然关闭
torch.backends.cudnn.benchmark=True  # 用于加速GPU运算的代码

设置随机数种子

torch.manual_seed(428)
torch.cuda.manual_seed(428)
torch.cuda.manual_seed_all(428)
random.seed(428)
np.random.seed(428)

检查设备对象

torch.manual_seed(428)
torch.cuda.manual_seed(428)
torch.cuda.manual_seed_all(428)
random.seed(428)
np.random.seed(428)

检查设备对象

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device, torch.cuda.device_count() # # (device(type='cuda'), 2)

数据集

本次实验继续使用猴痘病数据集,使用pathlib查看类别,本次类别只有0,1两种类别分别代表患病和不患病。

import pathlib
data_dir = './data/ill/'
data_dir = pathlib.Path(data_dir) # 转成pathlib.Path对象
data_paths = list(data_dir.glob('*')) 
classNames = [str(path).split("/")[2] for path in data_paths]
classNames # ['Monkeypox', 'Others']

使用transforms对数据集进行统一处理,并且根据文件夹名映射对应标签

all_transforms = transforms.Compose([transforms.Resize([224, 224]),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # 标准化
])total_data = datasets.ImageFolder("./data/ill/", transform=all_transforms)
total_data.class_to_idx # {'Monkeypox': 0, 'Others': 1}

随机查看5张图片

def plotsample(data):fig, axs = plt.subplots(1, 5, figsize=(10, 10)) #建立子图for i in range(5):num = random.randint(0, len(data) - 1) #首先选取随机数,随机选取五次#抽取数据中对应的图像对象,make_grid函数可将任意格式的图像的通道数升为3,而不改变图像原始的数据#而展示图像用的imshow函数最常见的输入格式也是3通道npimg = torchvision.utils.make_grid(data[num][0]).numpy()nplabel = data[num][1] #提取标签 #将图像由(3, weight, height)转化为(weight, height, 3),并放入imshow函数中读取axs[i].imshow(np.transpose(npimg, (1, 2, 0))) axs[i].set_title(nplabel) #给每个子图加上标签axs[i].axis("off") #消除每个子图的坐标轴plotsample(total_data)

在这里插入图片描述
根据8比2划分数据集和测试集,并且利用DataLoader划分批次和随机打乱

train_size = int(0.8 * len(total_data))
test_size  = len(total_data) - train_size
train_ds, test_ds = torch.utils.data.random_split(total_data, [train_size, test_size])batch_size = 32
train_dl = torch.utils.data.DataLoader(train_ds,batch_size=batch_size,shuffle=True,)
test_dl = torch.utils.data.DataLoader(test_ds,batch_size=batch_size,shuffle=True,)len(train_dl.dataset), len(test_dl.dataset) # (1713, 429)

模型选择

本次实验使用DenseNet + SENet模型,DenseNet的设计核心思想是通过密集连接来增强神经网络的信息流动,促进梯度的传播,以及提高参数的共享和重复使用。采用跨通道concat的形式来连接,会连接前面所有层作为输入。
核心公式为:
在这里插入图片描述
DenseNet中的基本组成单元是DenseBlock,它由多个密集连接的DenseLayer组成。每个DenseLayer都接收所有前面的DenseLayer特征作为输入,将其连接到自己的输出上,并传递给后续的层。如图所示,这是一个基本的DenseBlock模块。
在这里插入图片描述
整体网络架构图如下所示,借用K同学的图片
在这里插入图片描述

为了控制模型的复杂度并减少特征图的大小,DenseNet引入了Transition Block。过渡块包括批归一化、ReLU激活和 1x1 卷积,以减小特征图的通道数,并通过池化操作降低空间维度。
在这里插入图片描述
首先对DenseLayer类定义,本次实验使用add_module函数,默认是用于向类中添加一个子模块,第一个参数为模块名,第二个参数为模块实例,其实相当于加到父类的nn.Sequential里面,所以调用的时候使用super().forward(x),这段的核心是将输入 x 与新特征 t 进行通道维度上的连接,完成密集连接。

class DenseLayer(nn.Sequential):def __init__(self, num_input_features, growth_rate, bn_size, drop_rate):super().__init__()self.add_module("norm1", nn.BatchNorm2d(num_input_features))self.add_module("relu1", nn.ReLU(inplace=True))self.add_module("conv1", nn.Conv2d(num_input_features, bn_size * growth_rate, kernel_size=1, stride=1, bias=False))self.add_module("norm2", nn.BatchNorm2d(bn_size * growth_rate))self.add_module("relu2", nn.ReLU(inplace=True))self.add_module("conv2", nn.Conv2d(bn_size*growth_rate, growth_rate, kernel_size=3, stride=1, padding=1, bias=False))self.drop_rate = drop_ratedef forward(self, x):t = super().forward(x)if self.drop_rate > 0:t = F.dropout(t, p=self.drop_rate, training=self.training)return torch.cat([x, t], 1)

下面是DenseBlock的实现,通过循环创建了多个DenseLayer。其中的 num_input_features + i * growth_rate 用于指定输入通道的数量,确保每个DenseLayer的输入通道数逐渐增加。将新创建的DenseLayer添加为 DenseBlock 的子模块。循环结束后,DenseBlock 就包含了多个DenseLayer,每个DenseLayer都具有逐渐增加的输入通道数量。

class DenseBlock(nn.Sequential):def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate):super().__init__()for i in range(num_layers):layer = DenseLayer(num_input_features + i * growth_rate, growth_rate, bn_size, drop_rate)self.add_module("denselayer%d" % (i + 1), layer)

下面是Transition,实现过渡的功能,是在块之间降低通道数量和空间维度。

class Transition(nn.Sequential):def __init__(self, num_input_feature, num_output_features):super().__init__()self.add_module("norm", nn.BatchNorm2d(num_input_feature))self.add_module("relu", nn.ReLU(inplace=True))self.add_module("conv", nn.Conv2d(num_input_feature, num_output_features, kernel_size=1, stride=1, bias=False))self.add_module("pool", nn.AvgPool2d(2, stride=2))

SENet是一种深度神经网络结构,它的核心思想是允许网络在训练期间对每个通道进行自适应的加权,以使网络能够更加关注对任务有用的通道,并抑制对任务无关的通道。这有助于提高网络对输入数据的敏感性,并提升网络性能。SENet的结构包括两个主要组件:Squeeze 操作和 Excitation 操作。

Squeeze 操作(Global Average Pooling):通过全局平均池化,将每个通道的空间维度降为1。这样,对于每个通道,都得到一个单一的数值,反映了该通道对整个特征图的重要性。

Excitation 操作(通道注意力):在 Squeeze 操作后,通过一个小型的多层感知机(MLP)来学习通道之间的关系。这个小型MLP包含一个压缩操作和一个激励操作)。最后,利用学到的权重对每个通道的特征图进行加权,得到加权后的特征表示。
在这里插入图片描述
下面是SENet的实现,首先,通过全局平均池化层对输入特征图进行平均池化,将每个通道的空间维度降为1。然后,通过全连接层序列 fc 对降维后的特征进行处理,得到每个通道的注意力权重。最后,将得到的注意力权重通过 view 操作还原为与输入特征图相同的形状,并将其与输入特征图相乘,得到应用了注意力机制的特征图。

from torch.nn import init
class SEAttention(nn.Module):def __init__(self, channel=512, reduction=16):super().__init__()self.avg_pool = nn.AdaptiveAvgPool2d(1)self.fc = nn.Sequential(nn.Linear(channel, channel // reduction, bias=False),nn.ReLU(inplace=True),nn.Linear(channel // reduction, channel, bias=False),nn.Sigmoid())def init_weights(self):for m in self.modules():if isinstance(m, nn.Conv2d):init.kaiming_normal_(m.weight, mode='fan_out')if m.bias is not None:init.constant_(m.bias, 0)elif isinstance(m, nn.BatchNorm2d):init.constant_(m.weight, 1)init.constant_(m.bias, 0)elif isinstance(m, nn.Linear):init.normal_(m.weight, std=0.001)if m.bias is not None:init.constant_(m.bias, 0)def forward(self, x):b, c, _, _ = x.size()y = self.avg_pool(x).view(b, c)y = self.fc(y).view(b, c, 1, 1)return x * y.expand_as(x)

整体模型实现,self.features 是一个包含多个层的序列,包括初始卷积块、多个DenseBlock和Transition,以及最后的全局平均池化和分类器。遍历 block_config 中的配置,创建DenseBlock和Transition。参数初始化部分使用了 Kaiming 初始化和常数初始化。

其中,OrderedDict是Python中的一种有序字典数据结构,它保留了元素添加的顺序。在神经网络中,我们可以使用OrderedDict来指定模型的层次结构。

在进行平均池化之前,进入到SENet进行学习通道注意力权重从而提高网络的表征能力。

from collections import OrderedDictclass DenseNet(nn.Module):def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16), num_init_features=64,bn_size=4, compression_rate=0.5, drop_rate=0, num_classes=1000):super().__init__()self.features = nn.Sequential(OrderedDict([("conv0", nn.Conv2d(3, num_init_features, kernel_size=7, stride=2, padding=3, bias=False)),("norm0", nn.BatchNorm2d(num_init_features)),("relu0", nn.ReLU(inplace=True)),("pool0", nn.MaxPool2d(3, stride=2, padding=1))]))num_features = num_init_featuresfor i, num_layers in enumerate(block_config):block = DenseBlock(num_layers, num_features, bn_size, growth_rate, drop_rate)self.features.add_module("denseblock%d" % (i + 1), block)num_features += num_layers * growth_rateif i != len(block_config) - 1:transition = Transition(num_features, int(num_features * compression_rate))self.features.add_module("transition%d" % (i + 1), transition)num_features = int(num_features * compression_rate)self.features.add_module("norm5", nn.BatchNorm2d(num_features))self.features.add_module("relu5", nn.ReLU(inplace=True))self.se = SEAttention(channel=1024, reduction=8)self.classifier = nn.Linear(num_features, num_classes)for m in self.modules():if isinstance(m, nn.Conv2d):nn.init.kaiming_normal_(m.weight)elif isinstance(m, nn.BatchNorm2d):nn.init.constant_(m.bias, 0)nn.init.constant_(m.weight, 1)elif isinstance(m, nn.Linear):if m.bias is not None:nn.init.constant_(m.bias, 0)def forward(self, x):features = self.features(x)out = self.se(features)out = F.avg_pool2d(features, 7, stride=1)out = out.view(features.size(0), -1)out = self.classifier(out)return out

使用summary查看网络

from torchsummary import summary
model = DenseNet().to(device)
summary(model, input_size=(3, 224, 224))

在这里插入图片描述

开始训练

定义训练函数

def train(dataloader, model, loss_fn, opt):size = len(dataloader.dataset)num_batches = len(dataloader)train_acc, train_loss = 0, 0for X, y in dataloader:X, y = X.to(device), y.to(device)pred = model(X)loss = loss_fn(pred, y)opt.zero_grad()loss.backward()opt.step()train_acc += (pred.argmax(1) == y).type(torch.float).sum().item()train_loss += loss.item()train_acc /= sizetrain_loss /= num_batchesreturn train_acc, train_loss

定义测试函数

def test(dataloader, model, loss_fn):size = len(dataloader.dataset)num_batches = len(dataloader)test_acc, test_loss = 0, 0with torch.no_grad():for X, y in dataloader:X, y = X.to(device), y.to(device)pred = model(X)loss = loss_fn(pred, y)test_acc += (pred.argmax(1) == y).type(torch.float).sum().item()test_loss += loss.item()test_acc /= sizetest_loss /= num_batchesreturn test_acc, test_loss

定义学习率、损失函数、优化算法

loss_fn = nn.CrossEntropyLoss()
learn_rate = 0.0001
opt = torch.optim.Adam(model.parameters(), lr=learn_rate)

开始训练,epoch设置为20

import time
epochs = 20
train_loss = []
train_acc = []
test_loss = []
test_acc = []T1 = time.time()best_acc = 0
best_model = 0for epoch in range(epochs):model.train()epoch_train_acc, epoch_train_loss = train(train_dl, model, loss_fn, opt)model.eval() # 确保模型不会进行训练操作epoch_test_acc, epoch_test_loss = test(test_dl, model, loss_fn)if epoch_test_acc > best_acc:best_acc = epoch_test_accbest_model = copy.deepcopy(model)train_acc.append(epoch_train_acc)train_loss.append(epoch_train_loss)test_acc.append(epoch_test_acc)test_loss.append(epoch_test_loss)print("epoch:%d, train_acc:%.1f%%, train_loss:%.3f, test_acc:%.1f%%, test_loss:%.3f"% (epoch + 1, epoch_train_acc * 100, epoch_train_loss, epoch_test_acc * 100, epoch_test_loss))T2 = time.time()
print('程序运行时间:%s秒' % (T2 - T1))PATH = './best_model.pth'  # 保存的参数文件名
if best_model is not None:torch.save(best_model.state_dict(), PATH)print('保存最佳模型')
print("Done")

在这里插入图片描述

可视化

可视化训练过程与测试过程

import warnings
warnings.filterwarnings("ignore")               #忽略警告信息
plt.rcParams['font.sans-serif']    = ['SimHei'] # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False      # 用来正常显示负号
plt.rcParams['figure.dpi']         = 100        #分辨率epochs_range = range(epochs)plt.figure(figsize=(12, 3))
plt.subplot(1, 2, 1)plt.plot(epochs_range, train_acc, label='Training Accuracy')
plt.plot(epochs_range, test_acc, label='Test Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')plt.subplot(1, 2, 2)
plt.plot(epochs_range, train_loss, label='Training Loss')
plt.plot(epochs_range, test_loss, label='Test Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()

在这里插入图片描述

总结

SE模块引入了通道注意力机制,使得网络在学习过程中能够更加自适应地关注对任务有用的通道,抑制对任务无关的通道。这有助于提高网络的特征表达能力。当前也可以与各种其他的深度神经网络结构集成。因此,可以在不改变整体网络架构的情况下,通过引入通道注意力机制来增强网络性能。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/580892.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

亿赛通电子文档安全管理系统 linkfilterservice 未授权漏洞

产品简介 亿赛通电子文档安全管理系统,(简称:CDG)是一款电子文档安全加密软件,该系统利用驱动层透明加密技术,通过对电子文档的加密保护,防止内部员工泄密和外部人员非法窃取企业核心重要数据资…

Spring企业开发核心框架

文章目录 Spring企业开发核心框架一、框架前言1. 总体技术体系2. 框架概念和理解 二、Spring Framework简介1. Spring 和 SpringFramework2. SpringFramework主要功能模块3. SpringFramework 主要优势 三、Spring IoC 容器概念1. 组件和组件管理概念2. Spring IoC容器和容器实现…

ALS-运动系统解构

角色握持 角色蓝图:将物体绑在手上 动作蓝图: 将握持动画截取一帧(explicit time时间写好) 角色替换 在原人物模型下面加一个骨骼体(先不用添加模型),重命名为bodymesh AI使用流程 新建一…

品牌如何在线上打造“社交货币”?媒介盒子揭秘

品牌的社交货币,是品牌与消费者的共识身份铸造器。竹筒奶茶、Keep奖牌这类的实体产品作为社交货币,每每能够引爆社交平台,那么品牌能否通过线上平台打造“社交货币”呢?接下来就让媒介盒子和大家聊聊。 一、社交货币是什么 社交货…

6.Nacos

1.单机部署 1.1 官网 https://nacos.io/zh-cn/index.html https://github.com/alibaba/Nacos 1.2.版本说明 https://github.com/alibaba/spring-cloud-alibaba/wiki/%E7%89%88%E6%9C%AC%E8%AF%B4%E6%98%8E 1.3.下载地址 https://github.com/alibaba/nacos/releases/tag/2.2.…

小区跑腿服务

社区跑腿服务是指在社区范围内为居民提供各种便利的服务,包括购物代劳、快递代取、家政服务等。 这种服务的出现,满足了居民生活中诸多需求,受到了广泛的欢迎和认可。 首先,社区跑腿服务方便了居民的日常生活。 居民无需亲自前…

克魔助手工具下载、注册和登录指南

下载安装克魔助手 摘要 本文介绍了如何下载安装克魔助手工具,以及注册和登录流程。通过简单的步骤,用户可以轻松获取并使用该工具,为后续的手机应用管理操作做好准备。 引言 克魔助手是一款免费的手机管理工具,通过该工具用户…

文章解读与完整程序——《考虑“源-荷-储”协同互动的主动配电网优化调度研究》

摘要:伴随智能电网的建设和清洁能源的开发利用,配电网中的负荷类型呈现多元化发展,分布式电源、可控负荷、储能等资源的增加让单向潮流的传统配电网逐渐向双向潮流的主动配电网结构转变。在能源结构转变的同时,清洁能源自身的随机性和波动性给配电网带来了更大的调峰…

2023.12.25 关于 Redis 数据类型 Hash 常用命令、内部编码、应用场景

目录 Hash 数据类型 Hash 操作命令 HSET HGET HEXISTS HDEL HKEYS HVALS HGETALL HMGET HLEN HSETNX HINCRBY HINCRBYFLOAT HSTRLEN Hash 编码方式 理解什么是压缩 Hash 实际应用 Cache 缓存 Hash 数据类型 整体上来说 Redis 是键值对结构,其中 …

使用docker创建自己的Android编译容器

文章目录 背景步骤1.创建Dockerfile2.编写Dockerfile指令3.编译4.使用 背景 每次拿到新机器或者系统重装,最麻烦的就是各种环境配置,最近学习了一下docker的知识,用dockerfile创建一个Android编译容器,这样就不用每次都吭哧吭哧的…

python 通过(三维坐标)生成(三维曲面地形图)和(圆柱曲面地形图)

有需要源代码CSDN私信我 注意 python项目移植前要进行以下操作 1.python项目备份 2.生成requirements.txt的库文件 以pycharm为例,生成Python项目所需要的依赖库/包文档:requirements.txt_如何将pycharm项目安装的库文件导出为requirement.txt-CSDN博…

揭秘千巡翼X4卫星通讯无人机

揭秘千巡翼X4卫星通讯无人机 在无人机作业的时候,经常遇到这些异常场景,例如通信网络中断,强干扰,导致无人机无法与飞手通信等。而这些给无人机作业带来三大难题: 难题1,山区作业时通信中断,飞…

Cookie的详解使用(创建,获取,销毁)

文章目录 Cookie的详解使用(创建,获取,销毁)1、Cookie是什么2、cookie的常用方法3、cookie的构造和获取代码演示SetCookieServlet.javaGetCookieServlet.javaweb.xml运行结果如下 4、Cookie的销毁DestoryCookieServletweb.xml运行…

Swift 周报 第四十二期

文章目录 前言新闻和社区苹果 CEO 库克透露接班计划,希望继任者来自公司内部消息称苹果自研 5G 调制解调器开发再“难产”,将推迟到 2026 年 提案正在审查的提案 Swift论坛推荐博文话题讨论关于我们 前言 本期是 Swift 编辑组整理周报的第四十二期&…

Android10(SDK29)以后存储问题

如果你的targetSDK>29 1:使用系统给app分配的内部存储不需要存储权限? 例如:context.getExternalFilesDir(null),随意使用 2:不能随意在外部存储创建文件/文件夹; 例如:Environment.getE…

maven工具的搭建以及使用

文章目录 🐒个人主页🏅JavaEE系列专栏📖前言:🎀首先进行maven工具的搭建🦓1.[打开下载 maven 服务器官网](http://maven.apache.org)🪅2.解压之后,配置环境变量🏨3.打开设…

[前端已死论]——“Java 已死、前端已凉”

一、为什么会出现“前端已死”的言论 信息溯源:“前端已死”的论调是如何传播的? - 知乎 前端已死的真相! - 知乎 好几次看到有其他程序员说:“前端已死!”,这句话虽然太极端了,但是我是比较…

Python入门-组合数据类型(元组,字典,集合)

1.元组 元组 是Python中内置的 不可变序列 在Python中使用 ( ) 定义元组,元素与元素之间使用 英文的逗号分隔 元组总 只有一个 元素的是否,逗号不能省略 元组的创建与删除 # 使用小括号创建元组 t(hello,[10,20,30],python,world) print(t)#使用内置函…

JVM基础原理篇-带你深入拆解G1垃圾回收原理

一、一统天下的G1垃圾回收器概述 大白话: 1.整个堆空间,新生代和老年代比例大概为2:8; 2.正常情况下,新生代回收是高频的,混合回收是频率是适中的,完全回收则是基本不会发生、频率低代价高的,一…

Unity之DOTweenPath轨迹移动

Unity之DOTweenPath轨迹移动 一、介绍 DOTweenPath二、操作说明1、Scene View Commands2、INfo3、Tween Options4、Path Tween Options5、Path Editor Options:轨迹编辑参数,就不介绍了6、ResetPath:重置轨迹7、Events:8、WayPoin…