深度学习项目--基于RNN的阿尔茨海默病诊断研究(pytorch实现)

  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊

前言

  • 其实这个项目比较适合机器学习做,用XGBoost会更好,这个项目更适合RNN学习案例,测试集准确率达到百分之84.2,效果还是算过得去,但是用其他模型会更好,机器学习的方法后面会更新
  • RNN讲解: 深度学习基础–一文搞懂RNN
  • 欢迎收藏 + 关注,本人将会持续更新

文章目录

    • 1、导入数据
    • 2、数据处理
      • 1、患病占比
      • 2、相关性分析
      • 3、年龄与患病探究
    • 3、特征选择
    • 4、构建数据集
      • 1、数据集划分与标准化
      • 2、构建加载
    • 5、构建模型
    • 6、模型训练
      • 1、构建训练集
      • 2、构建训练集
      • 3、设置超参数
    • 7、模型训练
    • 8、结果评估
      • 1、结果图
      • 2、混淆矩阵

1、导入数据

import pandas as pd  
import numpy as np 
import matplotlib.pyplot as plt  
import seaborn as sns 
import torch  
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDatasetplt.rcParams["font.sans-serif"] = ["Microsoft YaHei"]  # 显示中文
plt.rcParams['axes.unicode_minus'] = False		# 显示负号data_df = pd.read_csv("alzheimers_disease_data.csv")data_df.head()
PatientIDAgeGenderEthnicityEducationLevelBMISmokingAlcoholConsumptionPhysicalActivityDietQuality...MemoryComplaintsBehavioralProblemsADLConfusionDisorientationPersonalityChangesDifficultyCompletingTasksForgetfulnessDiagnosisDoctorInCharge
047517300222.927749013.2972186.3271121.347214...001.725883000100XXXConfid
147528900026.82768104.5425247.6198850.518767...002.592424000010XXXConfid
247537303117.795882019.5550857.8449881.826335...007.119548010100XXXConfid
347547410133.800817112.2092668.4280017.435604...016.481226000000XXXConfid
447558900020.716974018.4543566.3104610.795498...000.014691001100XXXConfid

5 rows × 35 columns

该数据集是2149名被诊断患有阿尔茨海默病或有阿尔茨海默病风险的患者的健康记录的综合集合。数据集中的每个患者都有一个唯一的ID号,范围从4751到6900。该数据集涵盖了广泛的信息,这些信息对于理解与阿尔茨海默病相关的各种因素至关重要。它包括人口统计细节、生活习惯、病史、临床测量、认知和功能评估、症状和诊断信息。

在这里插入图片描述

在这里插入图片描述

# 标签中文化
data_df.rename(columns={ "Age": "年龄", "Gender": "性别", "Ethnicity": "种族", "EducationLevel": "教育水平", "BMI": "身体质量指数(BMI)", "Smoking": "吸烟状况", "AlcoholConsumption": "酒精摄入量", "PhysicalActivity": "体育活动时间", "DietQuality": "饮食质量评分", "SleepQuality": "睡眠质量评分", "FamilyHistoryAlzheimers": "家族阿尔茨海默病史", "CardiovascularDisease": "心血管疾病", "Diabetes": "糖尿病", "Depression": "抑郁症史", "HeadInjury": "头部受伤", "Hypertension": "高血压", "SystolicBP": "收缩压", "DiastolicBP": "舒张压", "CholesterolTotal": "胆固醇总量", "CholesterolLDL": "低密度脂蛋白胆固醇(LDL)", "CholesterolHDL": "高密度脂蛋白胆固醇(HDL)", "CholesterolTriglycerides": "甘油三酯", "MMSE": "简易精神状态检查(MMSE)得分", "FunctionalAssessment": "功能评估得分", "MemoryComplaints": "记忆抱怨", "BehavioralProblems": "行为问题", "ADL": "日常生活活动(ADL)得分", "Confusion": "混乱与定向障碍", "Disorientation": "迷失方向", "PersonalityChanges": "人格变化", "DifficultyCompletingTasks": "完成任务困难", "Forgetfulness": "健忘", "Diagnosis": "诊断状态", "DoctorInCharge": "主诊医生" },inplace=True)data_df.columns
Index(['PatientID', '年龄', '性别', '种族', '教育水平', '身体质量指数(BMI)', '吸烟状况', '酒精摄入量','体育活动时间', '饮食质量评分', '睡眠质量评分', '家族阿尔茨海默病史', '心血管疾病', '糖尿病', '抑郁症史','头部受伤', '高血压', '收缩压', '舒张压', '胆固醇总量', '低密度脂蛋白胆固醇(LDL)','高密度脂蛋白胆固醇(HDL)', '甘油三酯', '简易精神状态检查(MMSE)得分', '功能评估得分', '记忆抱怨', '行为问题','日常生活活动(ADL)得分', '混乱与定向障碍', '迷失方向', '人格变化', '完成任务困难', '健忘', '诊断状态','主诊医生'],dtype='object')

2、数据处理

data_df.isnull().sum()
PatientID           0
年龄                  0
性别                  0
种族                  0
教育水平                0
身体质量指数(BMI)         0
吸烟状况                0
酒精摄入量               0
体育活动时间              0
饮食质量评分              0
睡眠质量评分              0
家族阿尔茨海默病史           0
心血管疾病               0
糖尿病                 0
抑郁症史                0
头部受伤                0
高血压                 0
收缩压                 0
舒张压                 0
胆固醇总量               0
低密度脂蛋白胆固醇(LDL)      0
高密度脂蛋白胆固醇(HDL)      0
甘油三酯                0
简易精神状态检查(MMSE)得分    0
功能评估得分              0
记忆抱怨                0
行为问题                0
日常生活活动(ADL)得分       0
混乱与定向障碍             0
迷失方向                0
人格变化                0
完成任务困难              0
健忘                  0
诊断状态                0
主诊医生                0
dtype: int64
from sklearn.preprocessing import LabelEncoder# 创建 LabelEncoder 实例
label_encoder = LabelEncoder()# 对非数值型列进行标签编码
data_df['主诊医生'] = label_encoder.fit_transform(data_df['主诊医生'])data_df.head()
PatientID年龄性别种族教育水平身体质量指数(BMI)吸烟状况酒精摄入量体育活动时间饮食质量评分...记忆抱怨行为问题日常生活活动(ADL)得分混乱与定向障碍迷失方向人格变化完成任务困难健忘诊断状态主诊医生
047517300222.927749013.2972186.3271121.347214...001.7258830001000
147528900026.82768104.5425247.6198850.518767...002.5924240000100
247537303117.795882019.5550857.8449881.826335...007.1195480101000
347547410133.800817112.2092668.4280017.435604...016.4812260000000
447558900020.716974018.4543566.3104610.795498...000.0146910011000

5 rows × 35 columns

1、患病占比

# 计算是否患病, 人数
counts = data_df["诊断状态"].value_counts()# 计算百分比
sizes = counts / counts.sum() * 100# 绘制环形图
fig, ax = plt.subplots()
wedges, texts, autotexts = ax.pie(sizes, labels=sizes.index, autopct='%1.2ff%%', startangle=90, wedgeprops=dict(width=0.3))plt.title("患病占比(1患病,0没有患病)")plt.show()


在这里插入图片描述

患病人数居多

2、相关性分析

plt.figure(figsize=(40, 35))
sns.heatmap(data_df.corr(), annot=True, fmt=".2f")
plt.show()


在这里插入图片描述

其中,与患病相关性比较强的有:MMSE得分、功能评估得分、记忆抱怨、行为问题等相关性比较强,其中,MMSE得分、功能评估得分为负相关,记忆抱怨、行为问题为正相关。

3、年龄与患病探究

data_df['年龄'].min(), data_df['年龄'].max()
(60, 90)
# 计算每一个年龄段患病人数 
age_bins = range(60, 91)
grouped = data_df.groupby('年龄').agg({'诊断状态': ['sum', 'size']})  # 分组、聚合函数: sum求和,size总大小
grouped.columns = ['患病', '总人数']
grouped['不患病'] = grouped['总人数'] - grouped['患病']  # 计算不患病的人数# 设置绘图风格
sns.set(style="whitegrid")plt.figure(figsize=(12, 5))# 获取x轴标签(即年龄)
x = grouped.index.astype(str)  # 将年龄转换为字符串格式便于显示# 画图
plt.bar(x, grouped["不患病"], 0.35, label="不患病", color='skyblue')
plt.bar(x, grouped["患病"], 0.35, label="患病", color='salmon')# 设置标题
plt.title("患病年龄分布", fontproperties='Microsoft YaHei')
plt.xlabel("年龄", fontproperties='Microsoft YaHei')
plt.ylabel("人数", fontproperties='Microsoft YaHei')# 如果需要对图例也应用相同的字体
plt.legend(prop={'family': 'Microsoft YaHei'})# 展示
plt.tight_layout()
plt.show()


在这里插入图片描述

通过发现,由于原本数据中不患病的多,所以不患病的在图像中显示多,通过观察发现患病与年龄有关,尤其是年龄大,80岁的,患病与不患病比例高

提示:这里写代码的时候,不知道为什么,不指定字体,就显示不了字体。

3、特征选择

模型采用:决策树特征训练,可以很好的对特征重要性进行排序。

特征选择:采用REF,特征选择方法:

RFE(Recursive Feature Elimination,递归特征消除)和 SelectFromModel 都是 Scikit-learn 中用于特征选择的方法,但它们的工作机制和使用场景有所不同。

SelectFromModel

  • 工作原理SelectFromModel 是一种基于模型的特征选择方法。它通过一个基础评估器来判断每个特征的重要性,并根据给定的阈值选择那些重要性得分超过该阈值的特征。默认情况下,它会使用基础评估器提供的 feature_importances_ 或者 coef_ 属性来衡量特征的重要性。
  • 适用场景:当您希望基于某个预训练模型的特征重要性来进行特征选择时特别有用。它允许您设置一个全局阈值来控制特征选择的标准,但不直接支持指定想要选择的特征数量。
  • 优点:简单易用,适合快速进行特征筛选。
  • 缺点:不如 RFE 精细,不能直接控制最终选择的特征数量。

RFE (Recursive Feature Elimination)

  • 工作原理:RFE 采用了一种递归的方式进行特征选择。首先,它会训练一个模型,并根据模型对每个特征的重要性评分进行排序。然后,它会移除最不重要的特征,并重复这个过程,直到留下指定数量的特征为止。
  • 适用场景:当您确切知道想要选择多少个特征时非常有用。它提供了比 SelectFromModel 更细致的控制,因为您可以直接指定要保留的特征数量。
  • 优点:可以精确控制最终选择的特征数量,并且在每一轮迭代中都能考虑到所有剩余特征的整体贡献。
  • 缺点:计算成本相对较高,因为它需要多次训练模型,特别是当数据集很大或模型复杂度很高时。

总结

  • 如果您的目标是基于某个预定义的重要性阈值来简化模型,那么 SelectFromModel 可能是更合适的选择。
  • 如果您希望直接控制最终选择的特征数量,并愿意接受更高的计算成本以获得更精细的控制,那么 RFE 可能更适合您的需求。

两种方法都有其独特的优势和适用场景,选择哪一种取决于您的具体应用需求、数据特性以及性能考虑。

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import classification_reportdata = data_df.copy()X = data_df.iloc[:, 1:-2]
y = data_df.iloc[:, -2]X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)# 标准化
sc = StandardScaler()
X_train = sc.fit_transform(X_train)
X_test = sc.transform(X_test)# 模型创建
tree = DecisionTreeClassifier()
tree.fit(X_train, y_train)
pred = tree.predict(X_test)reporter = classification_report(y_test, pred)
print(reporter)
              precision    recall  f1-score   support0       0.91      0.92      0.91       2771       0.85      0.83      0.84       153accuracy                           0.89       430macro avg       0.88      0.88      0.88       430
weighted avg       0.89      0.89      0.89       430

效果不错,进行特征选择

# 不知道为啥,这样也需要在设置
plt.rcParams["font.sans-serif"] = ["Microsoft YaHei"]  # 显示中文
plt.rcParams['axes.unicode_minus'] = False		# 显示负号# 特征展示
feature_importances = tree.feature_importances_
features_rf = pd.DataFrame({'特征': X.columns, '重要度': feature_importances})
features_rf.sort_values(by='重要度', ascending=False, inplace=True)
plt.figure(figsize=(20, 10))
sns.barplot(x='重要度', y='特征', data=features_rf)
plt.xlabel('重要度')
plt.ylabel('特征')
plt.title('随机森林特征图')
plt.show()


在这里插入图片描述

从这个可以看出,有些特征没有效果,如性别,高血压等。

下面进行特征选择,选取20个特征。

from sklearn.feature_selection import RFE# 使用 RFE 来选择特征
rfe_selector = RFE(estimator=tree, n_features_to_select=20)  # 选择前20个特征
rfe_selector.fit(X, y)  
X_new = rfe_selector.transform(X)
feature_names = np.array(X.columns) 
selected_feature_names = feature_names[rfe_selector.support_]
print(selected_feature_names)
['年龄' '种族' '教育水平' '身体质量指数(BMI)' '酒精摄入量' '体育活动时间' '饮食质量评分' '睡眠质量评分' '心血管疾病''收缩压' '舒张压' '胆固醇总量' '低密度脂蛋白胆固醇(LDL)' '高密度脂蛋白胆固醇(HDL)' '甘油三酯''简易精神状态检查(MMSE)得分' '功能评估得分' '记忆抱怨' '行为问题' '日常生活活动(ADL)得分']

4、构建数据集

1、数据集划分与标准化

feature_selection = ['年龄', '种族','教育水平','身体质量指数(BMI)', '酒精摄入量', '体育活动时间', '饮食质量评分', '睡眠质量评分', '心血管疾病','收缩压', '舒张压', '胆固醇总量', '低密度脂蛋白胆固醇(LDL)', '高密度脂蛋白胆固醇(HDL)', '甘油三酯','简易精神状态检查(MMSE)得分', '功能评估得分', '记忆抱怨', '行为问题', '日常生活活动(ADL)得分']X = data_df[feature_selection]# 标准化, 标准化其实对应连续性数据,分类数据不适合,由于特征中只有种族是分类数据,这里我偷个“小懒”
sc = StandardScaler()
X = sc.fit_transform(X)X = torch.tensor(np.array(X), dtype=torch.float32)
y = torch.tensor(np.array(y), dtype=torch.long)# 再次进行特征选择
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)X_train.shape, y_train.shape
(torch.Size([1719, 20]), torch.Size([1719]))

2、构建加载

batch_size = 32train_dl = DataLoader(TensorDataset(X_train, y_train),batch_size=batch_size,shuffle=True
)test_dl = DataLoader(TensorDataset(X_test, y_test),batch_size=batch_size,shuffle=False
)

5、构建模型

class Rnn_Model(nn.Module):def __init__(self):super().__init__()# 调用rnnself.rnn = nn.RNN(input_size=20, hidden_size=200, num_layers=1, batch_first=True)self.fc1 = nn.Linear(200, 50)self.fc2 = nn.Linear(50, 2)def forward(self, x):x, hidden1 = self.rnn(x)x = self.fc1(x)x = self.fc2(x)return x# 数据不大,cpu即可
device = "cpu"model = Rnn_Model().to(device)
model
Rnn_Model((rnn): RNN(20, 200, batch_first=True)(fc1): Linear(in_features=200, out_features=50, bias=True)(fc2): Linear(in_features=50, out_features=2, bias=True)
)
model(torch.randn(32, 20)).shape
torch.Size([32, 2])

6、模型训练

1、构建训练集

def train(data, model, loss_fn, opt):size = len(data.dataset)batch_num = len(data)train_loss, train_acc = 0.0, 0.0for X, y in data:X, y = X.to(device), y.to(device)pred = model(X)loss = loss_fn(pred, y)# 反向传播opt.zero_grad()  # 梯度清零loss.backward()  # 求导opt.step()       # 设置梯度train_loss += loss.item()train_acc += (pred.argmax(1) == y).type(torch.float).sum().item()train_loss /= batch_numtrain_acc /= size return train_acc, train_loss 

2、构建训练集

def test(data, model, loss_fn):size = len(data.dataset)batch_num = len(data)test_loss, test_acc = 0.0, 0.0 with torch.no_grad():for X, y in data: X, y = X.to(device), y.to(device)pred = model(X)loss = loss_fn(pred, y)test_loss += loss.item()test_acc += (pred.argmax(1) == y).type(torch.float).sum().item()test_loss /= batch_numtest_acc /= sizereturn test_acc, test_loss 

3、设置超参数

超参数,这里第一步设置了:

  • 1e-3,但是不稳定;
  • 1e-4,效果不错.
loss_fn = nn.CrossEntropyLoss()  # 损失函数     
learn_lr = 1e-4            # 超参数
optimizer = torch.optim.Adam(model.parameters(), lr=learn_lr)   # 优化器

7、模型训练

train_acc = []
train_loss = []
test_acc = []
test_loss = []epoches = 50for i in range(epoches):model.train()epoch_train_acc, epoch_train_loss = train(train_dl, model, loss_fn, optimizer)model.eval()epoch_test_acc, epoch_test_loss = test(test_dl, model, loss_fn)train_acc.append(epoch_train_acc)train_loss.append(epoch_train_loss)test_acc.append(epoch_test_acc)test_loss.append(epoch_test_loss)# 输出template = ('Epoch:{:2d}, Train_acc:{:.1f}%, Train_loss:{:.3f}, Test_acc:{:.1f}%, Test_loss:{:.3f}')print(template.format(i + 1, epoch_train_acc*100, epoch_train_loss, epoch_test_acc*100, epoch_test_loss))print("Done")
Epoch: 1, Train_acc:64.9%, Train_loss:0.658, Test_acc:66.0%, Test_loss:0.617
Epoch: 2, Train_acc:66.9%, Train_loss:0.585, Test_acc:70.9%, Test_loss:0.564
Epoch: 3, Train_acc:75.1%, Train_loss:0.531, Test_acc:75.1%, Test_loss:0.511
Epoch: 4, Train_acc:79.8%, Train_loss:0.476, Test_acc:80.9%, Test_loss:0.463
Epoch: 5, Train_acc:82.7%, Train_loss:0.432, Test_acc:81.9%, Test_loss:0.429
Epoch: 6, Train_acc:83.9%, Train_loss:0.399, Test_acc:82.6%, Test_loss:0.413
Epoch: 7, Train_acc:84.4%, Train_loss:0.388, Test_acc:83.3%, Test_loss:0.405
Epoch: 8, Train_acc:85.0%, Train_loss:0.380, Test_acc:82.8%, Test_loss:0.401
Epoch: 9, Train_acc:84.7%, Train_loss:0.381, Test_acc:83.0%, Test_loss:0.398
Epoch:10, Train_acc:84.3%, Train_loss:0.374, Test_acc:84.0%, Test_loss:0.398
Epoch:11, Train_acc:84.9%, Train_loss:0.373, Test_acc:83.5%, Test_loss:0.395
Epoch:12, Train_acc:84.3%, Train_loss:0.374, Test_acc:83.7%, Test_loss:0.400
Epoch:13, Train_acc:84.4%, Train_loss:0.375, Test_acc:83.7%, Test_loss:0.398
Epoch:14, Train_acc:84.6%, Train_loss:0.370, Test_acc:83.5%, Test_loss:0.399
Epoch:15, Train_acc:85.0%, Train_loss:0.370, Test_acc:83.3%, Test_loss:0.400
Epoch:16, Train_acc:84.9%, Train_loss:0.371, Test_acc:83.5%, Test_loss:0.402
Epoch:17, Train_acc:84.8%, Train_loss:0.373, Test_acc:83.3%, Test_loss:0.396
Epoch:18, Train_acc:85.0%, Train_loss:0.369, Test_acc:83.5%, Test_loss:0.397
Epoch:19, Train_acc:84.9%, Train_loss:0.372, Test_acc:83.7%, Test_loss:0.397
Epoch:20, Train_acc:85.3%, Train_loss:0.371, Test_acc:83.3%, Test_loss:0.394
Epoch:21, Train_acc:84.8%, Train_loss:0.372, Test_acc:83.5%, Test_loss:0.396
Epoch:22, Train_acc:84.6%, Train_loss:0.373, Test_acc:83.7%, Test_loss:0.396
Epoch:23, Train_acc:84.8%, Train_loss:0.370, Test_acc:84.0%, Test_loss:0.397
Epoch:24, Train_acc:84.3%, Train_loss:0.373, Test_acc:84.0%, Test_loss:0.401
Epoch:25, Train_acc:84.8%, Train_loss:0.370, Test_acc:84.0%, Test_loss:0.398
Epoch:26, Train_acc:84.9%, Train_loss:0.370, Test_acc:83.5%, Test_loss:0.398
Epoch:27, Train_acc:84.2%, Train_loss:0.373, Test_acc:82.8%, Test_loss:0.398
Epoch:28, Train_acc:85.6%, Train_loss:0.367, Test_acc:82.8%, Test_loss:0.399
Epoch:29, Train_acc:84.6%, Train_loss:0.370, Test_acc:83.7%, Test_loss:0.400
Epoch:30, Train_acc:84.4%, Train_loss:0.374, Test_acc:84.0%, Test_loss:0.399
Epoch:31, Train_acc:84.6%, Train_loss:0.370, Test_acc:83.0%, Test_loss:0.399
Epoch:32, Train_acc:85.2%, Train_loss:0.370, Test_acc:83.7%, Test_loss:0.396
Epoch:33, Train_acc:84.8%, Train_loss:0.372, Test_acc:84.0%, Test_loss:0.395
Epoch:34, Train_acc:84.9%, Train_loss:0.371, Test_acc:83.0%, Test_loss:0.396
Epoch:35, Train_acc:84.5%, Train_loss:0.371, Test_acc:83.3%, Test_loss:0.395
Epoch:36, Train_acc:85.0%, Train_loss:0.371, Test_acc:83.5%, Test_loss:0.396
Epoch:37, Train_acc:85.2%, Train_loss:0.369, Test_acc:84.2%, Test_loss:0.396
Epoch:38, Train_acc:84.6%, Train_loss:0.376, Test_acc:84.0%, Test_loss:0.395
Epoch:39, Train_acc:85.2%, Train_loss:0.370, Test_acc:84.2%, Test_loss:0.396
Epoch:40, Train_acc:84.9%, Train_loss:0.371, Test_acc:84.2%, Test_loss:0.396
Epoch:41, Train_acc:84.4%, Train_loss:0.372, Test_acc:84.0%, Test_loss:0.394
Epoch:42, Train_acc:84.9%, Train_loss:0.370, Test_acc:84.2%, Test_loss:0.393
Epoch:43, Train_acc:84.8%, Train_loss:0.370, Test_acc:84.4%, Test_loss:0.395
Epoch:44, Train_acc:84.4%, Train_loss:0.372, Test_acc:84.0%, Test_loss:0.394
Epoch:45, Train_acc:85.3%, Train_loss:0.371, Test_acc:85.3%, Test_loss:0.396
Epoch:46, Train_acc:84.5%, Train_loss:0.371, Test_acc:83.5%, Test_loss:0.395
Epoch:47, Train_acc:84.5%, Train_loss:0.369, Test_acc:83.5%, Test_loss:0.396
Epoch:48, Train_acc:84.9%, Train_loss:0.371, Test_acc:83.3%, Test_loss:0.397
Epoch:49, Train_acc:85.1%, Train_loss:0.371, Test_acc:83.3%, Test_loss:0.396
Epoch:50, Train_acc:85.0%, Train_loss:0.369, Test_acc:82.6%, Test_loss:0.398
Done

8、结果评估

1、结果图

import matplotlib.pyplot as plt
#隐藏警告
import warnings
warnings.filterwarnings("ignore")               #忽略警告信息epochs_range = range(epoches)plt.figure(figsize=(12, 3))
plt.subplot(1, 2, 1)plt.plot(epochs_range, train_acc, label='Training Accuracy')
plt.plot(epochs_range, test_acc, label='Test Accuracy')
plt.legend(loc='lower right')
plt.title('Training Accuracy')plt.subplot(1, 2, 2)
plt.plot(epochs_range, train_loss, label='Training Loss')
plt.plot(epochs_range, test_loss, label='Test Loss')
plt.legend(loc='upper right')
plt.title('Training= Loss')
plt.show()


在这里插入图片描述

2、混淆矩阵

混淆矩阵(Confusion Matrix)是机器学习和数据科学中用于评估分类模型性能的一种表格。它通过展示模型预测结果与实际标签之间的对比,帮助我们理解模型的准确度以及其在不同类别上的表现。

对于一个二分类问题,混淆矩阵通常是一个2x2的表格,包含以下四个指标:

  • 真正例 (True Positive, TP):模型正确预测为正类的样本数。
  • 假正例 (False Positive, FP):模型错误地将负类预测为正类的样本数。
  • 假负例 (False Negative, FN):模型错误地将正类预测为负类的样本数。
  • 真负例 (True Negative, TN):模型正确预测为负类的样本数。

而对于多分类问题,混淆矩阵会相应地扩展到NxN的大小(N为类别数量),每一行代表实际类别,每一列代表预测类别。

from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay pred = model(X_test.to(device)).argmax(1).cpu().numpy()# 计算混淆矩阵
cm = confusion_matrix(y_test, pred)# 计算
plt.figure(figsize=(6, 5))
sns.heatmap(cm, annot=True, fmt="d", cmap="Blues")
# 标题
plt.title("混淆矩阵")
plt.xlabel("Predicted Label")
plt.ylabel("True Label")plt.tight_layout()  # 自适应
plt.show()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/71036.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

华宇TAS应用中间件与因朵科技多款产品完成兼容互认证

在数字化浪潮澎湃向前的当下,信息技术的深度融合与协同发展成为推动各行业创新变革的关键力量。近日,华宇TAS应用中间件携手河北因朵科技有限公司,完成了多项核心产品的兼容互认证。 此次兼容性测试的良好表现,为双方的进一步深入…

麒麟操作系统-MySQL5.7.36二进制安装

1、创建MySQL虚拟用户 groupadd mysql useradd -g mysql -s /sbin/nologin -M mysql 2、创建目录 mkdir -p /data/file #创建文件目录 mkdir -p /opt/mysql #创建MySQL安装目录 mkdir -p /data/mysql/mysql3306/{data,logs} #创建MySQL数据及日志目录 3、安装MySQL5.7.36 …

算法学习笔记之贪心算法

导引(硕鼠的交易) 硕鼠准备了M磅猫粮与看守仓库的猫交易奶酪。 仓库有N个房间,第i个房间有 J[i] 磅奶酪并需要 F[i] 磅猫粮交换,硕鼠可以按比例来交换,不必交换所有的奶酪 计算硕鼠最多能得到多少磅奶酪。 输入M和…

Xcode证书密钥导入

证书干嘛用 渠道定期会给xcode证书,用来给ios打包用,证书里面有记录哪些设备可以打包进去。 怎么换证书 先更新密钥 在钥匙串访问中,选择系统。(选登录也行,反正两个都要导入就是了)。 mac中双击所有 .p12 后缀的密钥&#xff…

使用 Elastic APM 监控你的 C++ 应用程序

作者:来自 Elastic Haidar Braimaanie 在本文中,我们将使用 Opentelemetry CPP 客户端来监控 Elastic APM 中的 C 应用程序。 介绍 开发人员、SRE 和 DevOps 专业人员面临的主要挑战之一是缺乏能够为他们提供应用程序堆栈可见性的综合工具。市场上的许多…

前端骨架怎样实现

前端骨架屏(Skeleton Screen)是一种优化页面加载体验的技术,通常在内容加载时展示一个简易的占位符,避免用户看到空白页面。骨架屏通过展示页面结构的骨架样式,让用户有页面正在加载的感觉,而不是等待内容加…

团结引擎 Shader Graph:解锁图形创作新高度

Shader Graph 始终致力于为开发者提供直观且高效的着色器构建工具,持续推动图形渲染创作的创新与便捷。在团结引擎1.4.0中,Shader Graph 迎来了重大更新,新增多项强大功能并优化操作体验,助力开发者更轻松地实现高质量的渲染效果与…

微信小程序地图标记点,安卓手机一次性渲染不出来的问题

问题描述: 如果微信小程序端,渲染的标记物太多,安卓手机存在标记物不显示的问题,原因初步判断是地图还没有渲染完,标记物数据已经加载完了,导致没有在地图上显示。 解决办法: 使用map组件的b…

AI前端开发的崛起与ScriptEcho的助力

近年来,人工智能(AI)技术飞速发展,深刻地改变着软件开发的格局。尤其是在前端开发领域,AI的应用越来越广泛,催生了对AI写代码工具的需求激增,也显著提升了相关人才的市场价值。然而,…

安装并配置 MySQL

MySQL 是世界上最流行的开源关系型数据库管理系统之一,因其高性能、可靠性和易用性而被广泛应用于各种规模的企业级应用中。本文将详细介绍如何在不同的操作系统上安装和配置 MySQL,帮助你快速搭建起一个功能完善的数据库环境。 选择适合你的安装方式 …

《探秘Windows 10驱动开发:从入门到实战》

《探秘Windows 10驱动开发:从入门到实战》 为什么要在 Windows 10 编写驱动程序 在当今数字化时代,计算机已成为人们生活和工作中不可或缺的工具 ,而 Windows 10 作为一款广泛使用的操作系统,其生态系统的丰富性和复杂性不言而喻。在这个庞大的体系中,驱动程序扮演着举足…

【prompt示例】智能客服+智能质检业务模版

本文原创作者:姚瑞南 AI-agent 大模型运营专家,先后任职于美团、猎聘等中大厂AI训练专家和智能运营专家岗;多年人工智能行业智能产品运营及大模型落地经验,拥有AI外呼方向国家专利与PMP项目管理证书。(转载需经授权&am…

算法17(力扣217)存在重复元素

1、问题 给你一个整数数组 nums 。如果任一值在数组中出现 至少两次 ,返回 true ;如果数组中每个元素互不相同,返回 false 。 2、示例 (1) 示例 1: 输入:nums [1,2,3,1] 输出:…

使用 ffmpeg 给视频批量加图片水印

背景 事情是这样的……前两天突然接到 leader 给的一个任务:给视频加上图片 logo 水印。我这种剪映老司机当然迷之一笑了哈哈哈哈哈,沉浸在简单的任务中还没反应过来巴掌就如洪水般涌来,因为 leader 给了几十个视频……作为一个计算机人&…

CSS 属性选择器详解与实战示例

CSS 属性选择器是 CSS 中非常强大且灵活的一类选择器,它能够根据 HTML 元素的属性和值来进行精准选中。在实际开发过程中,属性选择器不仅可以提高代码的可维护性,而且能够大大优化页面的样式控制。本文将结合菜鸟教程的示例,从基础…

基于SpringBoot和PostGIS的省域“地理难抵点(最纵深处)”检索及可视化实践

目录 前言 1、研究背景 2、研究意义 一、研究目标 1、“地理难抵点”的概念 二、“难抵点”空间检索实现 1、数据获取与处理 2、计算流程 3、难抵点计算 4、WebGIS可视化 三、成果展示 1、华东地区 2、华南地区 3、华中地区 4、华北地区 5、西北地区 6、西南地…

计算机毕业设计——Springboot的校园新闻网站

📘 博主小档案: 花花,一名来自世界500强的资深程序猿,毕业于国内知名985高校。 🔧 技术专长: 花花在深度学习任务中展现出卓越的能力,包括但不限于java、python等技术。近年来,花花更…

PyCharm 批量替换

选择替换的内容 1. 打开全局替换窗口 有两种方式可以打开全局替换窗口: 快捷键方式: 在 Windows 或 Linux 系统下,按下 Ctrl Shift R。在 Mac 系统下,按下 Command Shift R。菜单操作方式:点击菜单栏中的 Edit&…

深度剖析责任链模式

一、责任链模式的本质:灵活可扩展的流水线处理 责任链模式(Chain of Responsibility Pattern)是行为型设计模式的代表,其核心思想是将请求的发送者与接收者解耦,允许多个对象都有机会处理请求。这种模式完美解决了以下…

服务器使用centos7.9操作系统前需要做的准备工作

文章目录 前言1.操作记录 总结 前言 记录一下centos7.9操作系统的服务器在部署业务服务之前需要做的准备工作。 大家可以复制到自己的编辑器里面,有需求的注释一些步骤。 备注:有条件的项目推荐使用有长期支持的操作系统版本。 1.操作记录 # 更换阿里云…