第54步 深度学习图像识别:MLP-Mixer建模(Pytorch)

基于WIN10的64位系统演示

一、写在前面

(1)MLP-Mixer

MLP-Mixer(Multilayer Perceptron Mixer)是Google在2021年提出的一种新型的视觉模型结构。它的主要特点是完全使用多层感知机(MLP)来处理图像,而不是使用常见的卷积(Convolution)或者自注意力(Self-Attention)机制。

MLP-Mixer的结构主要包括两种类型的层:Token Mixing层和Channel Mixing层。在Token Mixing层中,模型会将图像分割成若干个patch(类似于像素块),然后对这些patch进行处理。在Channel Mixing层中,模型会对每个patch的通道进行处理。这两种类型的层交替堆叠,形成了最终的模型结构。

MLP-Mixer的设计目标是探索除卷积和自注意力之外的其他可能的模型结构,以期在保持性能的同时,降低模型的复杂性和计算成本。实验结果显示,MLP-Mixer在一些图像分类任务上的性能可以与ResNet和Transformer等主流模型相媲美。

然而,需要注意的是,虽然MLP-Mixer在某些方面展现出了很好的性能,但它并不意味着会替代卷积或者自注意力模型。实际上,每种模型都有其适用的场景和优势,MLP-Mixer提供了一个新的视角和工具,供我们处理视觉任务。

(2)MLP-Mixer的码源

本文使用 mlp-mixer-pytorch 库来实现MLP-Mixer。

当然,得先安装这个库:

(a)首先,打开Anaconda Prompt。在开始菜单中找到它,或者直接在搜索栏中输入"Anaconda Prompt"。在打开的Anaconda Prompt中,如果你想在一个特定的环境中安装mlp_mixer_pytorch,你需要先激活这个环境。假设你的环境名为myenv,你可以使用以下命令来激活这个环境:

conda activate myenv

(b)接下来,使用pip来安装mlp_mixer_pytorch库。在Anaconda Prompt中输入以下命令并按回车键:

pip install mlp-mixer-pytorch

二、MLP-Mixer迁移学习代码实战

我们继续胸片的数据集:肺结核病人和健康人的胸片的识别。其中,肺结核病人700张,健康人900张,分别存入单独的文件夹中。

(a)导入包

import copy
import torch
import torchvision
import torchvision.transforms as transforms
from torchvision import models
from torch.utils.data import DataLoader
from torch import optim, nn
from torch.optim import lr_scheduler
import os
import matplotlib.pyplot as plt
import warnings
import numpy as npwarnings.filterwarnings("ignore")
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False# 设置GPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

(b)导入数据集

import torch
from torchvision import datasets, transforms
import os# 数据集路径
data_dir = "./MTB"# 图像的大小
img_height = 256
img_width = 256# 数据预处理
data_transforms = {'train': transforms.Compose([transforms.RandomResizedCrop(img_height),transforms.RandomHorizontalFlip(),transforms.RandomVerticalFlip(),transforms.RandomRotation(0.2),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),'val': transforms.Compose([transforms.Resize((img_height, img_width)),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
}# 加载数据集
full_dataset = datasets.ImageFolder(data_dir)# 获取数据集的大小
full_size = len(full_dataset)
train_size = int(0.7 * full_size)  # 假设训练集占80%
val_size = full_size - train_size  # 验证集的大小# 随机分割数据集
torch.manual_seed(0)  # 设置随机种子以确保结果可重复
train_dataset, val_dataset = torch.utils.data.random_split(full_dataset, [train_size, val_size])# 将数据增强应用到训练集
train_dataset.dataset.transform = data_transforms['train']# 创建数据加载器
batch_size = 32
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=True, num_workers=4)dataloaders = {'train': train_dataloader, 'val': val_dataloader}
dataset_sizes = {'train': len(train_dataset), 'val': len(val_dataset)}
class_names = full_dataset.classes

(c)导入MLPMixer

from mlp_mixer_pytorch import MLPMixernum_classes = len(class_names)  # 根据数据集的类别数量来设置模型的输出类别数量# 构建MLP-Mixer模型
model = MLPMixer(image_size = img_height,  # 图像的高和宽channels = 3,  # 图像的通道数patch_size = 16,  # MLP-Mixer的patch大小dim = 512,  # MLP-Mixer的维度depth = 12,  # MLP-Mixer的深度num_classes = num_classes  # 输出类别数量
)# 将模型移动到GPU
model = model.to(device)# 打印模型摘要
print(model)

说明:mlp-mixer-pytorch库的主要功能就是提供了一个MLP-Mixer的类,可以通过实例化这个类来创建一个MLP-Mixer模型。在创建模型时,可以通过参数来设置图像的大小、通道数、patch的大小、模型的维度、深度以及输出类别的数量等。

需要注意的是,mlp-mixer-pytorch库提供的MLP-Mixer模型默认是随机初始化的,也就是说并没有加载预训练权重。如果你有MLP-Mixer的预训练权重,可以在创建模型后加载。

(d)编译模型

# 定义损失函数
criterion = nn.CrossEntropyLoss()# 定义优化器
optimizer = optim.Adam(model.parameters())# 定义学习率调度器
exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)# 开始训练模型
num_epochs = 20
best_model_wts = copy.deepcopy(model.state_dict())
best_acc = 0.0# 初始化记录器
train_loss_history = []
train_acc_history = []
val_loss_history = []
val_acc_history = []for epoch in range(num_epochs):print('Epoch {}/{}'.format(epoch, num_epochs - 1))print('-' * 10)# 每个epoch都有一个训练和验证阶段for phase in ['train', 'val']:if phase == 'train':model.train()  # Set model to training modeelse:model.eval()   # Set model to evaluate moderunning_loss = 0.0running_corrects = 0# 遍历数据for inputs, labels in dataloaders[phase]:inputs = inputs.to(device)labels = labels.to(device)# 零参数梯度optimizer.zero_grad()# 前向with torch.set_grad_enabled(phase == 'train'):outputs = model(inputs)_, preds = torch.max(outputs, 1)loss = criterion(outputs, labels)# 只在训练模式下进行反向和优化if phase == 'train':loss.backward()optimizer.step()# 统计running_loss += loss.item() * inputs.size(0)running_corrects += torch.sum(preds == labels.data)epoch_loss = running_loss / dataset_sizes[phase]epoch_acc = (running_corrects.double() / dataset_sizes[phase]).item()# 记录每个epoch的loss和accuracyif phase == 'train':train_loss_history.append(epoch_loss)train_acc_history.append(epoch_acc)else:val_loss_history.append(epoch_loss)val_acc_history.append(epoch_acc)print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))# 深拷贝模型if phase == 'val' and epoch_acc > best_acc:best_acc = epoch_accbest_model_wts = copy.deepcopy(model.state_dict())print()print('Best val Acc: {:4f}'.format(best_acc))# 加载最佳模型权重
#model.load_state_dict(best_model_wts)
#torch.save(model, 'shufflenet_best_model.pth')
#print("The trained model has been saved.")

(e)Accuracy和Loss可视化

epoch = range(1, len(train_loss_history)+1)fig, ax = plt.subplots(1, 2, figsize=(10,4))
ax[0].plot(epoch, train_loss_history, label='Train loss')
ax[0].plot(epoch, val_loss_history, label='Validation loss')
ax[0].set_xlabel('Epochs')
ax[0].set_ylabel('Loss')
ax[0].legend()ax[1].plot(epoch, train_acc_history, label='Train acc')
ax[1].plot(epoch, val_acc_history, label='Validation acc')
ax[1].set_xlabel('Epochs')
ax[1].set_ylabel('Accuracy')
ax[1].legend()#plt.savefig("loss-acc.pdf", dpi=300,format="pdf")

观察模型训练情况:

 蓝色为训练集,橙色为验证集。

(f)混淆矩阵可视化以及模型参数

from sklearn.metrics import classification_report, confusion_matrix
import math
import pandas as pd
import numpy as np
import seaborn as sns
from matplotlib.pyplot import imshow# 定义一个绘制混淆矩阵图的函数
def plot_cm(labels, predictions):# 生成混淆矩阵conf_numpy = confusion_matrix(labels, predictions)# 将矩阵转化为 DataFrameconf_df = pd.DataFrame(conf_numpy, index=class_names ,columns=class_names)  plt.figure(figsize=(8,7))sns.heatmap(conf_df, annot=True, fmt="d", cmap="BuPu")plt.title('Confusion matrix',fontsize=15)plt.ylabel('Actual value',fontsize=14)plt.xlabel('Predictive value',fontsize=14)def evaluate_model(model, dataloader, device):model.eval()   # 设置模型为评估模式true_labels = []pred_labels = []# 遍历数据for inputs, labels in dataloader:inputs = inputs.to(device)labels = labels.to(device)# 前向with torch.no_grad():outputs = model(inputs)_, preds = torch.max(outputs, 1)true_labels.extend(labels.cpu().numpy())pred_labels.extend(preds.cpu().numpy())return true_labels, pred_labels# 获取预测和真实标签
true_labels, pred_labels = evaluate_model(model, dataloaders['val'], device)# 计算混淆矩阵
cm_val = confusion_matrix(true_labels, pred_labels)
a_val = cm_val[0,0]
b_val = cm_val[0,1]
c_val = cm_val[1,0]
d_val = cm_val[1,1]# 计算各种性能指标
acc_val = (a_val+d_val)/(a_val+b_val+c_val+d_val)  # 准确率
error_rate_val = 1 - acc_val  # 错误率
sen_val = d_val/(d_val+c_val)  # 灵敏度
sep_val = a_val/(a_val+b_val)  # 特异度
precision_val = d_val/(b_val+d_val)  # 精确度
F1_val = (2*precision_val*sen_val)/(precision_val+sen_val)  # F1值
MCC_val = (d_val*a_val-b_val*c_val) / (np.sqrt((d_val+b_val)*(d_val+c_val)*(a_val+b_val)*(a_val+c_val)))  # 马修斯相关系数# 打印出性能指标
print("验证集的灵敏度为:", sen_val, "验证集的特异度为:", sep_val,"验证集的准确率为:", acc_val, "验证集的错误率为:", error_rate_val,"验证集的精确度为:", precision_val, "验证集的F1为:", F1_val,"验证集的MCC为:", MCC_val)# 绘制混淆矩阵
plot_cm(true_labels, pred_labels)# 获取预测和真实标签
train_true_labels, train_pred_labels = evaluate_model(model, dataloaders['train'], device)
# 计算混淆矩阵
cm_train = confusion_matrix(train_true_labels, train_pred_labels)  
a_train = cm_train[0,0]
b_train = cm_train[0,1]
c_train = cm_train[1,0]
d_train = cm_train[1,1]
acc_train = (a_train+d_train)/(a_train+b_train+c_train+d_train)
error_rate_train = 1 - acc_train
sen_train = d_train/(d_train+c_train)
sep_train = a_train/(a_train+b_train)
precision_train = d_train/(b_train+d_train)
F1_train = (2*precision_train*sen_train)/(precision_train+sen_train)
MCC_train = (d_train*a_train-b_train*c_train) / (math.sqrt((d_train+b_train)*(d_train+c_train)*(a_train+b_train)*(a_train+c_train))) 
print("训练集的灵敏度为:",sen_train, "训练集的特异度为:",sep_train,"训练集的准确率为:",acc_train, "训练集的错误率为:",error_rate_train,"训练集的精确度为:",precision_train, "训练集的F1为:",F1_train,"训练集的MCC为:",MCC_train)# 绘制混淆矩阵
plot_cm(train_true_labels, train_pred_labels)

效果不错:

 (g)AUC曲线绘制

from sklearn import metrics
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.pyplot import imshow
from sklearn.metrics import classification_report, confusion_matrix
import seaborn as sns
import pandas as pd
import mathdef plot_roc(name, labels, predictions, **kwargs):fp, tp, _ = metrics.roc_curve(labels, predictions)plt.plot(fp, tp, label=name, linewidth=2, **kwargs)plt.plot([0, 1], [0, 1], color='orange', linestyle='--')plt.xlabel('False positives rate')plt.ylabel('True positives rate')ax = plt.gca()ax.set_aspect('equal')# 确保模型处于评估模式
model.eval()train_ds = dataloaders['train']
val_ds = dataloaders['val']val_pre_auc   = []
val_label_auc = []for images, labels in val_ds:for image, label in zip(images, labels):      img_array = image.unsqueeze(0).to(device)  # 在第0维增加一个维度并将图像转移到适当的设备上prediction_auc = model(img_array)  # 使用模型进行预测val_pre_auc.append(prediction_auc.detach().cpu().numpy()[:,1])val_label_auc.append(label.item())  # 使用Tensor.item()获取Tensor的值
auc_score_val = metrics.roc_auc_score(val_label_auc, val_pre_auc)train_pre_auc   = []
train_label_auc = []for images, labels in train_ds:for image, label in zip(images, labels):img_array_train = image.unsqueeze(0).to(device) prediction_auc = model(img_array_train)train_pre_auc.append(prediction_auc.detach().cpu().numpy()[:,1])  # 输出概率而不是标签!train_label_auc.append(label.item())
auc_score_train = metrics.roc_auc_score(train_label_auc, train_pre_auc)plot_roc('validation AUC: {0:.4f}'.format(auc_score_val), val_label_auc , val_pre_auc , color="red", linestyle='--')
plot_roc('training AUC: {0:.4f}'.format(auc_score_train), train_label_auc, train_pre_auc, color="blue", linestyle='--')
plt.legend(loc='lower right')
#plt.savefig("roc.pdf", dpi=300,format="pdf")print("训练集的AUC值为:",auc_score_train, "验证集的AUC值为:",auc_score_val)

ROC曲线如下:

 这个ROC曲线也是不错的!全部大于95%!

三、写在最后

截至目前,图像分类领域基本就是CNN、Transformer和MLP三足鼎立了。孰优孰劣,还不好说,中庸之道那就是各有千秋。他们之间的两两组合或者一起融合的话,效果又会如何?

四、数据

链接:https://pan.baidu.com/s/15vSVhz1rQBtqNkNp2GQyVw?pwd=x3jf

提取码:x3jf

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/7540.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

3dsmax制作一个小人

文章目录 步骤起阶五官手臂短袖添加头发、头饰BodyPaint软件贴图导入到3dsmax 渲染 步骤 起阶 五官 手臂 短袖 添加头发、头饰 BodyPaint软件贴图 寻找网络贴图,用PS切割,用BodyPaint恢复纹理 导入到3dsmax 渲染

【三维点云处理】顶点、面片、邻接矩阵、邻接距离矩阵以及稀疏存储概念

文章目录 vts和faces基础知识vertices-节点(3是点的三维坐标)faces-面片(3是构成三角形面片的3个点) 邻接矩阵邻接距离矩阵(NN500)稀疏矩阵 vts和faces基础知识 vertices-节点(3是点的三维坐标…

设计模式大白话——观察者模式

文章目录 一、概述二、示例三、模式定义四、其他 一、概述 ​ 与其叫他观察者模式,我更愿意叫他叫 订阅-发布模式 ,这种模式在我们生活中非常常见,比如:追番了某个电视剧,当电视剧有更新的时候会第一时间通知你。当你…

Fuzz测试:提升自动驾驶安全性

目录 什么是Fuzz测试? 自动驾驶的潜在风险 Fuzz测试:自动驾驶和车联网 Fuzz测试方法有以下几种: 资料获取方法 纵观近百年来汽车制造业的发展历程,产业跨进的每一步背后都有着技术创新作为支撑。汽车技术创新对世界经济、社会…

数学建模学习(3):综合评价类问题整体解析及分析步骤

一、评价类算法的简介 对物体进行评价,用具体的分值评价它们的优劣 选这两人其中之一当男朋友,你会选谁? 不同维度的权重会产生不同的结果 所以找到每个维度的权重是最核心的问题 0.25 二、评价前的数据处理 供应商ID 可靠性 指标2 指…

基于Android Studio编辑器上开发的一款看点新闻App

完整资料进入【数字空间】查看——baidu搜索"writebug" 1 系统需求分析 1.1 引言 1.1.1 开发目的 看点新闻App的开发是为了实时查看最新消息以了解社会动态,增长知识,增广见闻,顺便娱乐一下内心世界来放松自己。 1.1.2 开发背景 …

【Spring Boot Admin】使用(整合Spring Security服务,添加鉴权)

Spring Boot Admin 监控平台 背景:Spring Boot Admin 监控平台不添加鉴权就直接访问的话,是非常不安全的。所以在生产环境中使用时,需要添加鉴权,只有通过鉴权后才能监控客户端服务。本文整合Spring Security进行实现。 pom依赖 …

Vue第四篇:html和js基础知识查漏补缺

1、a标签 定义超链接,用于从一个页面链接到另一个页面 target属性:打开目标URL的方式,_top为再当前窗口打开,_blank为新窗口打开 2、span标签 对文档中的行内元素进行组合,它提供了一种将文本的一部分或者文档的一部分…

Jmeter(二十三):快速生成测试报告

一、jmeter配置 首先要保证jmeter命令是ok的,如果你在cmd中输入jmeter -v,有出现如下截图所示的信息,那就说明jmeter环境ok; 二、jmeter执行结合命令 生成HTML测试报告 1.完成脚本的调试、参数化、断言等操作。然后在聚合报告中指定日志文件存储路径,路径中最好不要包含有…

通过电商项目,详解抓包到接口测试,附图片验证码 +cookie 问题处理!

通常来说,进行接口测试,开发会提供对应的接口文档给到测试,但也有例外。开发无接口文档,但领导又需要你对刚开发的软件,进行接口测试、接口自动化测试、甚至是性能测试。这个时候作为专业测试应该怎么办? …

[元带你学: eMMC协议 28] eMMC 上电时序 | eMMC 上电指南

依JEDEC eMMC及经验辛苦整理,原创保护,禁止转载。 专栏 《元带你学:eMMC协议》 内容摘要 全文 1500 字, 主要内容 eMMC 上电规范 和 eMMC 上电指南, 这部分内容偏向电气特性,如果不是硬件的同学只要特别浅的了解, 一带而过。 eMMC 上电规范 eMMC 电压 VCCQ指的是接口…

视频文件批量添加字幕内容需要如何快速操作

有时候我们在剪辑视频的过程中,想要给视频素材添加上一些文字说明,需要如何操作呢?为了提高剪辑效率,今天小编来分享教学,教你如何才能批量地给视频素材添加滚动字幕,一起来看看具体的方法介绍吧。 我们先打…

《吐血整理》保姆级系列教程-玩转Fiddler抓包教程(2)-初识Fiddler让你理性认识一下

1.前言 今天的理性认识主要就是讲解和分享Fiddler的一些理论基础知识。其实这部分也没有什么,主要是给小伙伴或者童鞋们讲一些实际工作中的场景,然后隆重推出我们的猪脚(主角)-Fiddler。 1.1工作场景 做app测试,你是…

正则表达式 —— Grep

文本处理三剑客:Grep、Sed、Awk 这三个工具都是基于对文本的内容进行增删改查的操作,此篇着重介绍grep与正则表达式的应用,以及扩展正则表达式。 正则表达式 什么是正则表达式? 它是由一类特殊字符以及文本字符所编写的一种模式…

华为云零代码平台AstroZero新手操作指南-3分钟体验创建培训报名表

华为云Astro轻应用Astro Zero是华为云为行业客户、合作伙伴、开发者量身打造的低代码/零代码应用开发平台,提供全场景可视化开发能力和端到端部署能力,可快速搭建行业和大型企业级应用并沉淀复用行业资产,加速行业数字化。 在AstroZero上&am…

程序员如何向老板提加薪?

今天的问题不仅适用于程序员,对于其他职业同样适用。如果你认为自己所做的工作应该得到更多的报酬,并且想为此做点什么,你有两个选择:找一个新的高薪工作或要求加薪。 这两种选择都会带来新的焦虑,但它们都会带来新的…

gerrit 提交搞了一天的账号密码

搞了一整天的账号密码怎么输入都不对 以为输入了也不对,查找各种文档也不太行 参考也不太行: https://blog.csdn.net/qq_43279637/article/details/103595122 最后发现 是使用了git clone http 脑残方式,正确应该使用 git clone ssh 就可以…

指数函数exp

目录 指数函数及e 指数增长 复数指数 练习 1. expgui 2. 计算e 3 五角星绘制 指数函数及e (1)的比值总是常数 (2)的导数为其自身。(根据比值1推导出e的值) %% Plot a^t and its approximate derivat…

【ribbon】Ribbon的负载均衡和扩展功能

Ribbon的核心接口 参考:org.springframework.cloud.netflix.ribbon.RibbonClientConfiguration IClientConfig:Ribbon的客户端配置,默认采用DefaultClientConfigImpl实现。IRule:Ribbon的负载均衡策略,默认采用ZoneA…

iPortal 注册登录模块扩展开发

作者:yx 文章目录 前言一、示例代码简介二、对接 iPortal REST API 接口2.1、登录模块扩展开发2.2、注册模块扩展开发 三、页面内容及样式实现四、配置启用定制页面 前言 针对注册登录模块,iPortal 允许用户通过 iFrame 方式接入自行开发的页面&#xf…