Pytorch迁移学习训练病变分类模型

划分数据集

1.创建训练集文件夹和测试集文件夹

# 创建 train 文件夹
os.mkdir(os.path.join(dataset_path, 'train'))# 创建 test 文件夹
os.mkdir(os.path.join(dataset_path, 'val'))# 在 train 和 test 文件夹中创建各类别子文件夹
for Retinopathy in classes:os.mkdir(os.path.join(dataset_path, 'train', Retinopathy))os.mkdir(os.path.join(dataset_path, 'val', Retinopathy))

2.划分训练集、测试集,移动文件

test_frac = 0.2  # 测试集比例
random.seed(123) # 随机数种子,用来打乱数据集df = pd.DataFrame()print('{:^18} {:^18} {:^18}'.format('类别', '训练集数据个数', '测试集数据个数'))for Retinopathy in classes: # 遍历每个类别# 读取该类别的所有图像文件名old_dir = os.path.join(dataset_path, Retinopathy)images_filename = os.listdir(old_dir)random.shuffle(images_filename) # 随机打乱# 划分训练集和测试集testset_numer = int(len(images_filename) * test_frac) # 测试集图像个数testset_images = images_filename[:testset_numer]      # 获取拟移动至 test 目录的测试集图像文件名trainset_images = images_filename[testset_numer:]     # 获取拟移动至 train 目录的训练集图像文件名# 移动图像至 test 目录for image in testset_images:old_img_path = os.path.join(dataset_path, Retinopathy, image)         # 获取原始文件路径new_test_path = os.path.join(dataset_path, 'val', Retinopathy, image) # 获取 test 目录的新文件路径shutil.move(old_img_path, new_test_path) # 移动文件# 移动图像至 train 目录for image in trainset_images:old_img_path = os.path.join(dataset_path, Retinopathy, image)           # 获取原始文件路径new_train_path = os.path.join(dataset_path, 'train', Retinopathy, image) # 获取 train 目录的新文件路径shutil.move(old_img_path, new_train_path) # 移动文件# 删除旧文件夹assert len(os.listdir(old_dir)) == 0 # 确保旧文件夹中的所有图像都被移动走shutil.rmtree(old_dir) # 删除文件夹# 工整地输出每一类别的数据个数print('{:^18} {:^18} {:^18}'.format(Retinopathy, len(trainset_images), len(testset_images)))# 保存到表格中df = df.append({'class':Retinopathy, 'trainset':len(trainset_images), 'testset':len(testset_images)}, ignore_index=True)# 重命名数据集文件夹
shutil.move(dataset_path, dataset_name+'_split')# 数据集各类别数量统计表格,导出为 csv 文件
df['total'] = df['trainset'] + df['testset']
df.to_csv('数据量统计.csv', index=False)

结果如下:

统计各类别数据个数柱状图

1.导入工具包

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline

2.设置matplotlib的中文字体,因为它默认无法写中文字体

plt.rcParams['font.sans-serif']=['SimHei']  # 用来正常显示中文标签 
plt.rcParams['axes.unicode_minus']=False  # 用来正常显示负号

3.指定可视化的特征

feature = 'total'
# feature = 'trainset'
# feature = 'testset'df = df.sort_values(by=feature, ascending=False)

4.通过柱状图展示出来

plt.figure(figsize=(22, 7))x = df['class']
y = df[feature]plt.bar(x, y, facecolor='#1f77b4', edgecolor='k')plt.xticks(rotation=90)
plt.tick_params(labelsize=15)
plt.xlabel('类别', fontsize=20)
plt.ylabel('图像数量', fontsize=20)# plt.savefig('各类别图片数量.pdf', dpi=120, bbox_inches='tight')plt.show()

结果如下:

由此可见,数据集是比较均衡的。

5.将训练集与测试集的比例展示出来

plt.figure(figsize=(22, 7))
x = df['class']
y1 = df['testset']
y2 = df['trainset']width = 0.55 # 柱状图宽度plt.xticks(rotation=90) # 横轴文字旋转plt.bar(x, y1, width, label='测试集')
plt.bar(x, y2, width, label='训练集', bottom=y1)plt.xlabel('类别', fontsize=20)
plt.ylabel('图像数量', fontsize=20)
plt.tick_params(labelsize=13) # 设置坐标文字大小plt.legend(fontsize=16) # 图例# 保存为高清的 pdf 文件
plt.savefig('各类别图像数量.pdf', dpi=120, bbox_inches='tight')plt.show()

结果如下:

处理完数据集后,就可以开始通过迁移学习训练病变分类模型。

安装配置环境

1.numpy、pandas、matplotlib、seaborn、plotly、requests、tqdm、opencv-python、pillow、wandb和pytorch均已安装完成

2.创建三个文件夹

import os# 存放结果文件
os.mkdir('output')# 存放训练得到的模型权重
os.mkdir('checkpoint')# 存放生成的图表
os.mkdir('图表')

迁移学习训练过程与前处理

1.导入包

import time
import osimport numpy as np
from tqdm import tqdmimport torch
import torchvision
import torch.nn as nn
import torch.nn.functional as Fimport matplotlib.pyplot as plt
%matplotlib inlineimport warnings
warnings.filterwarnings("ignore")

2.获取计算机的硬件,使用CPU还是GPU

# 有 GPU 就用 GPU,没有就用 CPU
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('device', device)

3.图像预处理

from torchvision import transforms# 训练集图像预处理:缩放裁剪、图像增强、转 Tensor、归一化
train_transform = transforms.Compose([transforms.RandomResizedCrop(224),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])# 测试集图像预处理:缩放、裁剪、转 Tensor、归一化
test_transform = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])

对训练集和测试集分别进行预处理。

训练集的预处理中,RandomResizedCrop(224)表示随机选择一个面积比例,并在该比例下随机裁剪图像,然后将裁剪后的图像缩放到指定的尺寸,参数 224 指定了裁剪并缩放后的图像尺寸应该是 224x224 像素。RandomHorizontalFlip()是进行随机的水平翻转,目的是图像增强。最后转成pytorch的tensor格式进行归一化。归一化的6个参数约定俗成。

4.载入图像分类数据集

from torchvision import datasets# 数据集文件夹路径
dataset_dir = 'E:\科研实验\Train_Custom_Dataset-main\图像分类\dataset_split'train_path = os.path.join(dataset_dir, 'train')
test_path = os.path.join(dataset_dir, 'val')# 载入训练集
train_dataset = datasets.ImageFolder(train_path, train_transform)# 载入测试集
test_dataset = datasets.ImageFolder(test_path, test_transform)

结果如下:

5.类别和索引号一一对应,方便后续的查询

# 映射关系:索引号 到 类别
idx_to_labels = {y:x for x,y in train_dataset.class_to_idx.items()}# 保存为本地的 npy 文件
np.save('idx_to_labels.npy', idx_to_labels)
np.save('labels_to_idx.npy', train_dataset.class_to_idx)

6.定义数据加载器DataLoader

from torch.utils.data import DataLoaderBATCH_SIZE = 32# 训练集的数据加载器
train_loader = DataLoader(train_dataset,batch_size=BATCH_SIZE,shuffle=True,num_workers=4)# 测试集的数据加载器
test_loader = DataLoader(test_dataset,batch_size=BATCH_SIZE,shuffle=False,num_workers=4)

7.可视化一个batch的图像和标注

# 将数据集中的Tensor张量转为numpy的array数据类型
images = images.numpy()

举个例子,images[5].shape展示的是一个批次中第五张图片的信息,结果如下:

images[5]的像素分布如下所示:

显示上图所用代码为:

plt.hist(images[5].flatten(), bins=50)
plt.show()

之前通过预处理归一化,已经将每一个像素都减去它所在通道的均值,再除以它所在通道的标准差了,所以现在的像素不再分布在0~255的整数范围内,而是一个以0为均值的,有正有负的分布。这样的分布更容易被神经网络处理,正如上图所示。

归一化后的图像如下所示:

显示上图所用代码为:

# batch 中经过预处理的图像
idx = 5
plt.imshow(images[idx].transpose((1,2,0))) # 转为(224, 224, 3)
plt.title('label:'+str(labels[idx].item()))

此图的原图像为:

显示上图所用代码为:

# 原始图像
idx = 5
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])
plt.imshow(np.clip(images[idx].transpose((1,2,0)) * std + mean, 0, 1))
plt.title('label:'+ pred_classname)
plt.show()

8.选择迁移学习训练的方式

视网膜图像和ImageNet的分布不是很一致,所以这里采用“微调训练所有层”的方式

①调整训练所有层

model = model.to(device)# 交叉熵损失函数
criterion = nn.CrossEntropyLoss() # 训练轮次 Epoch
EPOCHS = 30# 学习率降低策略
lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)

②函数:在训练集上训练

from sklearn.metrics import precision_score
from sklearn.metrics import recall_score
from sklearn.metrics import accuracy_score
from sklearn.metrics import f1_score
from sklearn.metrics import roc_auc_scoredef train_one_batch(images, labels):'''运行一个 batch 的训练,返回当前 batch 的训练日志'''# 获得一个 batch 的数据和标注images = images.to(device)labels = labels.to(device)outputs = model(images) # 输入模型,执行前向预测loss = criterion(outputs, labels) # 计算当前 batch 中,每个样本的平均交叉熵损失函数值# 优化更新权重optimizer.zero_grad()loss.backward()optimizer.step()# 获取当前 batch 的标签类别和预测类别_, preds = torch.max(outputs, 1) # 获得当前 batch 所有图像的预测类别preds = preds.cpu().numpy()loss = loss.detach().cpu().numpy()outputs = outputs.detach().cpu().numpy()labels = labels.detach().cpu().numpy()log_train = {}log_train['epoch'] = epochlog_train['batch'] = batch_idx# 计算分类评估指标log_train['train_loss'] = losslog_train['train_accuracy'] = accuracy_score(labels, preds)# log_train['train_precision'] = precision_score(labels, preds, average='macro')# log_train['train_recall'] = recall_score(labels, preds, average='macro')# log_train['train_f1-score'] = f1_score(labels, preds, average='macro')return log_train

返回的log_train是训练日志

③函数:在整个测试集上评估

def evaluate_testset():'''在整个测试集上评估,返回分类评估指标日志'''loss_list = []labels_list = []preds_list = []with torch.no_grad():for images, labels in test_loader: # 生成一个 batch 的数据和标注images = images.to(device)labels = labels.to(device)outputs = model(images) # 输入模型,执行前向预测# 获取整个测试集的标签类别和预测类别_, preds = torch.max(outputs, 1) # 获得当前 batch 所有图像的预测类别preds = preds.cpu().numpy()loss = criterion(outputs, labels) # 由 logit,计算当前 batch 中,每个样本的平均交叉熵损失函数值loss = loss.detach().cpu().numpy()outputs = outputs.detach().cpu().numpy()labels = labels.detach().cpu().numpy()loss_list.append(loss)labels_list.extend(labels)preds_list.extend(preds)log_test = {}log_test['epoch'] = epoch# 计算分类评估指标log_test['test_loss'] = np.mean(loss_list)log_test['test_accuracy'] = accuracy_score(labels_list, preds_list)log_test['test_precision'] = precision_score(labels_list, preds_list, average='macro')log_test['test_recall'] = recall_score(labels_list, preds_list, average='macro')log_test['test_f1-score'] = f1_score(labels_list, preds_list, average='macro')return log_test

返回的log_test是测试日志

④登录wandb(可在网页、手机、iPad上实时监控日志)

安装 wandb:pip install wandb

登录 wandb:在命令行中运行wandb login

按提示复制粘贴API Key至命令行中

⑤创建wandb可视化项目

import wandbwandb.init(project='视网膜病变', name=time.strftime('%m%d%H%M%S'))

⑥运行训练

for epoch in range(1, EPOCHS+1):print(f'Epoch {epoch}/{EPOCHS}')## 训练阶段model.train()for images, labels in tqdm(train_loader): # 获得一个 batch 的数据和标注batch_idx += 1log_train = train_one_batch(images, labels)df_train_log = df_train_log.append(log_train, ignore_index=True)wandb.log(log_train)lr_scheduler.step()## 测试阶段model.eval()log_test = evaluate_testset()df_test_log = df_test_log.append(log_test, ignore_index=True)wandb.log(log_test)# 保存最新的最佳模型文件if log_test['test_accuracy'] > best_test_accuracy: # 删除旧的最佳模型文件(如有)old_best_checkpoint_path = 'checkpoint/best-{:.3f}.pth'.format(best_test_accuracy)if os.path.exists(old_best_checkpoint_path):os.remove(old_best_checkpoint_path)# 保存新的最佳模型文件best_test_accuracy = log_test['test_accuracy']new_best_checkpoint_path = 'checkpoint/best-{:.3f}.pth'.format(log_test['test_accuracy'])torch.save(model, new_best_checkpoint_path)print('保存新的最佳模型', 'checkpoint/best-{:.3f}.pth'.format(best_test_accuracy))# best_test_accuracy = log_test['test_accuracy']df_train_log.to_csv('训练日志-训练集.csv', index=False)
df_test_log.to_csv('训练日志-测试集.csv', index=False)

wandb的监控结果如下所示:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/3442.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【Windows】达芬奇19安装教程

DaVinci Resolve Studio是一个结合专业的8k编辑、颜色混合、视觉效果和音频后期制作的软件。只需点击一下,你就可以立即在编辑、混音、特效和音频流之间切换。此外,达芬奇是一个多用户协作的解决方案,使编辑、助理、色彩学家、视觉效果设计师…

OS复习笔记ch4

引言 上一章,我们学习了进程的相关概念和知识,不知道小伙伴们的学习进度如何,没看的小伙伴记得去专栏看完哦。 线程从何而来 我们之前说过,进程是对程序运行过程的抽象,它的抽象程度是比较高的。 一个进程往往对应一…

C++:静态成员变量和静态成员方法

静态成员变量 C中的静态成员变量是属于类而不是类的实例的变量。这意味着无论创建了多少个类的实例,静态成员变量都只有一个副本,并且可以被所有类的实例共享。 让我们来看一个示例: class RolePlayer { public://静态成员变量static int …

值得让英伟达CEO黄仁勋亲自给OpenAI配送的AI服务器!一文带你了解算力,GPU,CPU!

大家好,我是木易,一个持续关注AI领域的互联网技术产品经理,国内Top2本科,美国Top10 CS研究生,MBA。我坚信AI是普通人变强的“外挂”,所以创建了“AI信息Gap”这个公众号,专注于分享AI全维度知识…

怎么办,孟德尔随机化连锁不平衡跑不了!这里有本地连锁不平衡分析方法

大家都知道,孟德尔随机化很大程度依赖于国外的服务器。 最近我们发现孟德尔随机化常用的TwoSampleMR包的clump函数经常报错,这是由于服务器访问人群超时造成的现象,当线上版本失效。 很多人做孟德尔随机化,就卡在clump上。 于是我…

OpenStack云计算(十)——OpenStack虚拟机实例管理,增加一个计算节点并进行实例冷迁移,增加一个计算节点的步骤,实例冷迁移的操作方法

项目实训一 本实训任务对实验环境要求较高,而且过程比较复杂,涉及的步骤非常多,有一定难度,可根据需要选做。可以考虑改为直接观看相关的微课视频 【实训题目】 增加一个计算节点并进行实例冷迁移 【实训目的】 熟悉增加一个…

牛客NC199 字符串解码【中等 递归,栈的思想 C++/Java/Go/PHP】

题目 题目链接: https://www.nowcoder.com/practice/4e008fd863bb4681b54fb438bb859b92 相同题目: https://www.lintcode.com/problem/575 思路 解法和基础计算器1,2,3类似,递归参考答案C struct Info {string str;int stopindex;Info(str…

AOC vs. DAC:哪个更适合您的网络需求?

在现代网络通信中,选择合适的连接线缆对于数据传输的稳定性和速度至关重要。两种常见的线缆类型是 AOC(Active Optical Cable) 和 DAC(Direct Attach Cable)。本文将详细介绍这两种线缆的特点、优势和适用场景&#xf…

Aigtek:介电弹性体高压放大器在软体机器人研究中的应用

近年来软体机器人的研究成为目前机器人研究领域的热点,由于软体材料的自由度可以根据需求自由变化,因此软体机器人有着极高的灵活性,而且软体机器人因其材料的柔软性有着很好的人机交互性能和安全性。它的出现成功解决了传统的刚性机器人人机…

JavaScript云LIS系统概述 前端框架JQuery+EasyUI+Bootstrap医院云HIS系统源码 开箱即用

云LIS系统概述JavaScript前端框架JQueryEasyUIBootstrap医院云HIS系统源码 开箱即用 云LIS(云实验室信息管理系统)是一种结合了计算机网络化信息系统的技术,它无缝嵌入到云HIS(医院信息系统)中,用于连…

《异常检测——从经典算法到深度学习》27 可执行且可解释的在线服务系统中重复故障定位方法

《异常检测——从经典算法到深度学习》 0 概论1 基于隔离森林的异常检测算法 2 基于LOF的异常检测算法3 基于One-Class SVM的异常检测算法4 基于高斯概率密度异常检测算法5 Opprentice——异常检测经典算法最终篇6 基于重构概率的 VAE 异常检测7 基于条件VAE异常检测8 Donut: …

Oracle 监控 SQL 精选 (一)

Oracle数据库的监控通常涉及性能、空间、会话、对象、备份、安全等多个层面。 有效的监控可以帮助 DBA 及时发现和解决问题,提高数据库的稳定性和性能,保障企业的数据安全和业务连续性。 常用的监控指标有: 性能指标: 查询响应时间…

抽象工厂模式(Redis 集群升级)

目录 定义 Redis 集群升级 模拟单机服务 RedisUtils 模拟集群 EGM 模拟集群 IIR 定义使⽤接⼝ 实现调⽤代码 代码实现 定义适配接⼝ 实现集群使⽤服务 EGMCacheAdapter IIRCacheAdapter 定义抽象⼯程代理类和实现 JDKProxy JDKInvocationHandler 测试验证 定义 …

Mockaroo - 在线生成测试用例利器

简介:Mockaroo 是一个无需安装的在线工具,用于生成大量的自定义测试数据。它支持多种数据格式,如JSON、CSV、SQL和Excel,并能模拟复杂的数据结构。 历史攻略: 测试用例:多条件下编写,懒人妙用…

ChatGPT付费创作系统V2.8.4独立版 WEB+H5+小程序端 (新增Pika视频+短信宝+DALL-E-3+Midjourney接口)

小狐狸GPT付费体验系统最新版系统是一款基于ThinkPHP框架开发的AI问答小程序,是基于国外很火的ChatGPT进行开发的Ai智能问答小程序。当前全民热议ChatGPT,流量超级大,引流不要太简单!一键下单即可拥有自己的GPT!无限多…

网盘——文件重命名

文件重命名具体步骤如下: 目录 1、具体步骤 2、代码实现 2.1、添加重命名文件的槽函数 2.2、关联重命名文件夹信号槽 2.3、添加重命名文件的协议 2.4、添加槽函数定义 2.5、服务器 2.6、添加重命名文件的case 2.7、客户端接收回复 3、测试 3.1、点击重命…

debian配置四叶草输入法

效果展示 一、前言 在linux下体验比较好的输入法只有两款:搜狗输入法、四叶草输入法。 ubuntu下可以成功配置搜狗输入法,但debian下从来没有成功过。 今天在用fcitx5 四叶草时发现VNC远程输入法会失灵,于是改用了ibus 四叶草&#xff0c…

Qt : 禁用控件默认的鼠标滚轮事件

最近在写一个模拟器,在item中添加了很多的控件,这些控件默认是支持鼠标滚动事件的。在数据量特别大的时候,及容易不小心就把数据给修改了而不自知。所有,我们这里需要禁用掉这些控件的鼠标滚轮事件。 实现的思想很简单&#xff0c…

原生微信小程序中案例--仿boss区域树选择列多选功能

1. 需求描述: 区域三级列表, 有添加,编辑,删除功能。 选择父级分类,其下子类全部选中,当前分类后加标志显示全字样取消选中子类,其父类分类后标志显示选中数量若子类全部选中,除当…

对2023年图灵奖揭晓看法

2023年图灵奖揭晓,你怎么看? 2023年图灵奖,最近刚刚颁给普林斯顿数学教授 Avi Wigderson!作为理论计算机科学领域的领军人物,他对于理解计算中的随机性和伪随机性的作用,作出了开创性贡献。这些贡献不仅推…