深度学习:Matplotlib篇

一、简介

1.1 什么是 Matplotlib?

Matplotlib 是一个广泛使用的 2D 绘图库,它可以用来在 Python 中创建各种静态、动态和交互式的图表。无论是科学计算、数据可视化,还是深度学习模型的训练与评估,Matplotlib 都能提供强大的图形展示功能。在深度学习领域,Matplotlib 通常用于可视化训练过程中的损失函数、准确率曲线以及各种训练结果

1.2 为什么会在深度学习中使用 Matplotlib?

在深度学习中,使用 Matplotlib 可以帮助开发者和研究人员更直观地理解模型的性能。常见的应用包括:

  • 绘制训练过程中的损失和精度曲线,以监控模型是否过拟合或欠拟合
  • 展示分类模型的混淆矩阵,以可视化分类错误类型
  • 可视化图像特征或特征图,以深入理解卷积神经网络(CNN)内部的工作机制

二、基本使用方法

2.1 导入

在使用之前,需要先导入其主要模块。一般来说我们会将 Matplotlib.pyplot 模块导入并简写为 plt,为了代码看起来简洁

import matplotlib.pyplot as plt

2.2 基本绘图操作

这里以折线图为例

# 创建一些数据
x = [1, 2, 3, 4, 5]
y = [1, 4, 9, 16, 25]# 绘制折线图
plt.plot(x, y)# 显示图像
plt.show()

2.3 标题、标签与图例

添加标题、轴标签以及图例让图表更具信息性

plt.plot(x, y, label='y = x^2')
plt.title('Example of a Line Plot')
plt.xlabel('X Axis')
plt.ylabel('Y Axis')
plt.legend()
plt.show()

2.4 调整图像样式

Matplotlib 支持多种样式的图表,如折线图、柱状图、散点图等。在绘图时,我们也可以调整线条颜色、样式以及图形的其他参数

plt.plot(x, y, color='green', linestyle='--', marker='o', label='y = x^2')
plt.title('Customized Line Plot')
plt.xlabel('X Axis')
plt.ylabel('Y Axis')
plt.legend()
plt.show()

三、在深度学习中的应用

3.1 可视化训练过程

在深度学习模型训练过程中,最常见的做法是通过 Matplotlib 绘制训练和验证集上的损失曲线及准确率曲线。这些图表能够帮助我们判断模型的表现,分析是否存在过拟合或欠拟合的现象

import matplotlib.pyplot as plt# 假设有训练和验证的损失和准确率数据 随便取的
epochs = range(1, 11)
train_loss = [0.8, 0.6, 0.4, 0.3, 0.2, 0.15, 0.1, 0.08, 0.06, 0.04]
val_loss = [0.9, 0.7, 0.5, 0.4, 0.3, 0.25, 0.22, 0.21, 0.2, 0.19]# 绘制损失曲线
plt.plot(epochs, train_loss, 'bo-', label='Training loss')
plt.plot(epochs, val_loss, 'ro-', label='Validation loss')
plt.title('Training and Validation Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.show()

通过这种方法,我们就可以很直观地观察模型在每个 epoch 训练后的损失变化趋势,尤其是训练集和验证集的差异可以帮助判断是否发生过拟合

3.2 可视化卷积神经网络的特征图

在深度学习中,卷积神经网络(CNN)可以学习到图像的层次化特征。在某些场景下,我们希望可视化这些特征图,以更好地理解网络的工作机制。这时就可以通过 Matplotlib 将卷积层的输出特征图绘制出来

import torch
import matplotlib.pyplot as plt# 假设我们有一个 CNN 模型和一张输入图像
model = ...
image = ...# 获取卷积层的输出
features = model.conv1(image)# 可视化第一层卷积后的特征图
fig, axarr = plt.subplots(1, 4)
for idx in range(4):axarr[idx].imshow(features[0, idx].detach().numpy(), cmap='gray')
plt.show()

给你们补全一下

import torch  
import torch.nn as nn  
import torchvision.transforms as transforms  
from PIL import Image  
import matplotlib.pyplot as plt  # 定义一个简单的 CNN 模型  
class SimpleCNN(nn.Module):  def __init__(self):  super(SimpleCNN, self).__init__()  self.conv1 = nn.Conv2d(in_channels=3, out_channels=4, kernel_size=3, stride=1, padding=1)  # 添加其他层如果需要的话,这里仅作为示例  def forward(self, x):  x = self.conv1(x)  # x = ... 其他操作  return x  # 实例化模型  
model = SimpleCNN()  # 将模型移动到 GPU(如果可用),或者保持在 CPU 上  
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  
model.to(device)  # 加载并预处理一张图像  
# 假设我们使用一张 RGB 图像,尺寸为 224x224(根据模型需求调整)  
transform = transforms.Compose([  transforms.Resize((224, 224)),  transforms.ToTensor(),  # 将图像转换为 PyTorch 张量,并归一化到 [0, 1]  
])  # 这里我们加载一个示例图像,你需要提供实际的图像路径  
image_path = 'path_to_your_image.jpg'  # 替换为你的图像路径  
image = Image.open(image_path).convert('RGB')  
image = transform(image).unsqueeze(0)  # 增加一个 batch 维度  
image = image.to(device)  # 确保图像在正确的设备上  # 获取卷积层的输出  
with torch.no_grad():  # 我们不需要计算梯度  features = model.conv1(image)  # 可视化第一层卷积后的特征图  
features = features.squeeze(0).cpu()  # 移除 batch 维度,并移动到 CPU  
fig, axarr = plt.subplots(1, 4, figsize=(12, 3))  
for idx in range(4):  axarr[idx].imshow(features[idx, :, :].numpy(), cmap='gray')  # 注意这里的索引可能需要根据实际的输出形状调整  axarr[idx].axis('off')  # 关闭坐标轴  
plt.show()

 看看效果:

原图:

3.3 混淆矩阵的可视化

混淆矩阵用于衡量分类模型性能,它、可以直观展示模型在不同类别上的分类正确率和错误率

import seaborn as sns
from sklearn.metrics import confusion_matrix# 假设我们有预测值和真实值
y_true = [0, 1, 2, 2, 0, 1, 1, 2]
y_pred = [0, 0, 2, 2, 0, 2, 1, 2]# 计算混淆矩阵
cm = confusion_matrix(y_true, y_pred)# 可视化混淆矩阵
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
plt.title('Confusion Matrix')
plt.xlabel('Predicted')
plt.ylabel('True')
plt.show()

这种可视化方法可以帮助我们深入分析模型在哪些类别上分类错误较多,进而做出相应的改进措施

四、进阶技巧

4.1 子图与多图显示

在深度学习中,往往需要同时观察多个图表(如损失和准确率),这时可以通过 subplot 函数来实现多个子图的排列显示

import matplotlib.pyplot as plt  
import numpy as np  # 假设我们已经有了一些训练数据  
# 这些数据通常是在训练循环中收集的  # 示例数据(您应该使用您的实际数据替换这些)  
epochs = np.arange(1, 21)  # 假设我们训练了20个epoch  
train_loss = np.linspace(0.5, 0.1, 20)  # 假设训练损失从0.5线性降低到0.1  
val_loss = np.linspace(0.55, 0.15, 20)  # 假设验证损失从0.55线性降低到0.15  
train_acc = np.linspace(0.5, 0.9, 20)   # 假设训练准确率从0.5线性增加到0.9  
val_acc = np.linspace(0.45, 0.85, 20)   # 假设验证准确率从0.45线性增加到0.85  # 创建子图  
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))  # 绘制第一个图(损失曲线)  
ax1.plot(epochs, train_loss, 'b-', label='Training loss')  
ax1.plot(epochs, val_loss, 'r-', label='Validation loss')  
ax1.set_title('Loss Over Epochs')  
ax1.set_xlabel('Epochs')  
ax1.set_ylabel('Loss')  
ax1.legend()  
ax1.grid(True)  # 可选:添加网格线  # 绘制第二个图(准确率曲线)  
ax2.plot(epochs, train_acc, 'b-', label='Training accuracy')  
ax2.plot(epochs, val_acc, 'r-', label='Validation accuracy')  
ax2.set_title('Accuracy Over Epochs')  
ax2.set_xlabel('Epochs')  
ax2.set_ylabel('Accuracy')  
ax2.legend()  
ax2.grid(True)  # 可选:添加网格线  # 显示图形  
plt.tight_layout()  # 可选:调整子图之间的间距  
plt.show()

4.2 动态绘图

在深度学习训练过程中,有时我们希望能实时查看训练的进展。这时可以利用 Matplotlib 的动态绘图功能,通过 plt.ion() 实现图表的实时更新

import matplotlib.pyplot as plt  
import random  # 用于生成模拟数据  # 初始化列表  
epochs = []  
train_loss = []  
val_loss = []  # 开启交互模式  
plt.ion()  # 模拟10个epoch的训练过程  
for epoch in range(10):  # 更新epochs列表(虽然在这种情况下,我们可以直接使用range(11)来绘制,但为了与您的代码一致,我们还是更新这个列表)  epochs.append(epoch + 1)  # 模拟新的训练损失和验证损失(在实际应用中,这些值将来自您的训练循环)  new_train_loss = random.uniform(0.1, 1.0)  # 生成一个0.1到1.0之间的随机浮点数  new_val_loss = random.uniform(0.1, 1.0)    # 生成另一个0.1到1.0之间的随机浮点数  # 将新的损失值添加到列表中  train_loss.append(new_train_loss)  val_loss.append(new_val_loss)  # 清除上一帧图像  plt.clf()  # 绘制新的曲线  plt.plot(epochs, train_loss, 'b-', label='Training loss')  plt.plot(epochs, val_loss, 'r-', label='Validation loss')  plt.xlabel('Epoch')  plt.ylabel('Loss')  plt.title('Loss Over Epochs')  plt.legend()  plt.grid(True)  # 可选:添加网格线  # 暂停一段时间以更新图形(0.1秒)  plt.pause(0.1)  # 关闭交互模式(在显示最终图形之前通常不需要这样做,因为plt.show()会处理它)  
# 但为了与您的代码一致,我们还是包含了这个调用  
plt.ioff()  # 显示最终图形(在交互模式下,这通常不是必需的,因为图形已经在循环中更新了)  
# 但由于我们包含了plt.ioff(),所以我们需要调用plt.show()来确保图形显示出来  
plt.show()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/56921.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【Python】if选择判断结构详解:逻辑分支与条件判断

目录 🍔 if选择判断结构作用 1.1 if选择判断结构的基本语法 1.2 if选择结构案例 1.3 if...else...结构 1.4 if...elif...else多条件判断结构 1.5 if嵌套结构 🍔 综合案例:石头剪刀布 2.1 需求分析 2.2 代码实现 2.3 随机出拳 &…

【数据结构】快速排序(三种实现方式)

目录 一、基本思想 二、动图演示(hoare版) 三、思路分析(图文) 四、代码实现(hoare版) 五、易错提醒 六、相遇场景分析 6.1 ❥ 相遇位置一定比key要小的原因 6.2 ❥ 右边为key,左边先走 …

redis详细教程(2.List教程)

List是一种可以存储多个有序字符串的数据类型,其中的元素按照顺序排列(可以重复出现),可以通过数字索引来访问列表中的元素,索引可以从左到右或者从右到左。 Redis 列表可以通过两种方式实现:压缩列表&…

PVE 一键安装WIKI.js

Wiki.js 一个轻量的知识库管理工具 在PVE 的shell 下执行如下代码(国内访问需自行调整),一键安装,默认使用了sqlLite 作为数据库: bash -c "$(wget -qLO - https://github.com/tteck/Proxmox/raw/main/ct/wikijs…

电脑维修指南

1.输入法切换 1.右键悬浮窗 2.选择全拼 2.换壁纸 壁纸给你准备好了 https://wwyz.lanzoul.com/b00g2g2vyd 密码:da72浏览器下载解压, 然后就有了 右键, 挑选 3.清理垃圾 浏览器输入这个地址 https://wwyz.lanzoul.com/ijMin2di41ih普通下载 找一个喜欢的地方 右键, 解压 …

[SWPUCTF 2022 新生赛]py1的write up

开启靶场,下载附件,解压后得到: 双击exe文件,出现弹窗: 问的是异或,写个python文件来计算结果: # 获取用户输入的两个整数 num1 int(input("Enter the first number: ")) num2 int…

排序算法(冒泡,插入),希尔排序(插入升级),希尔排序和插入排序时间比较!

🎁个人主页:我们的五年 🔍系列专栏:排序算法 🎉欢迎大家点赞👍评论📝收藏⭐文章 一.冒泡排序: 时间复杂度:O(N^2)。 🏄‍♂️思路…

【Nas】X-DOC:搞机之PVE部署All In One(黑群晖NAS 软路由OpenWrt Docker Win10远程桌面)

【Nas】X-DOC:搞机之PVE部署All In One(黑群晖NAS & 软路由OpenWrt & Docker & Win10远程桌面) 1、原硬件配置清单:2、改AIO后增加配置清单:3、虚拟化平台PVE:4、搭建的关键服务: 1…

Web高级开发实验:EL基本运算符与数据访问

一、实验目的 掌握EL的定义,即Expression Language,用于提高编程效率。学习和掌握在开发环境中创建Java文件,并在jsp文件中使用EL表达式去调用其中的方法与属性等。 二、实验所用方法 上机实操 三、实验步骤及截图 1、创建javaweb项目&a…

Springboot项目中常用注解

文章目录 Springboot相关注解EnableAspectJAutoProxy(exposeProxy true)内部实现机制 EnableTransactionManagementServletComponentScanMapperScan(basePackages {"com.xxx.mapper"})ComponentScan(basePackages{"*"})lombok Data注解Controller中的相关…

jvm虚拟机介绍

Java虚拟机(JVM)是Java语言的运行环境,它基于栈式架构,通过加载、验证、准备、解析、初始化等类加载过程,将Java类文件转换成平台无关的字节码,并在运行时动态地将其翻译成特定平台的机器码执行。 JVM的核心…

基于SSM农业信息管理系统的设计

管理员账户功能包括:系统首页,个人中心,用户管理,农业技术管理,种植户管理,农产品类型管理,农资订单管理,系统管理 种植户账号功能包括:系统首页,个人中心&a…

01C++书写hello world、注释、变量、常量

#include <iostream> using namespace std; int main()//main为一个程序的入口&#xff0c;每个程序都必须仅有一个 { cout<<"hello world"<<endl; } //#输出结果为 //单行注释的符号 /* 多行注释的符号 */ //变量创建的语法&#xff1a;数据类…

OpenAI GPT-o1实现方案记录与梳理

本篇文章用于记录从各处收集到的o1复现方案的推测以及介绍 目录 Journey Learning - 上海交通大学NYUMBZUAIGAIRCore IdeaKey QuestionsKey TechnologiesTrainingInference A Tutorial on LLM Reasoning: Relevant methods behind ChatGPT o1 - UCL汪军教授Core Idea先导自回归…

shodan2---清风

注&#xff1a;本文章源于泷羽SEC&#xff0c;如有侵权请联系我&#xff0c;违规必删 学习请认准泷羽SEC学习视频:https://space.bilibili.com/350329294 实验一&#xff1a;search 存在CVE-2019-0708的网络设备 CVE - 2019 - 0708**漏洞&#xff1a;** 该漏洞存在于远程桌面…

offset Explorer连接云服务上的kafka连接不上

以上配置后报连接错误时&#xff0c;可能是因为kafka的server.properties配置文件没配置好&#xff1a; 加上面两条配置&#xff0c;再次测试连接&#xff0c;成功 listeners和advertised.listeners

C++的相关习题(2)

初阶模板 下面有关C中为什么用模板类的原因&#xff0c;描述错误的是? ( &#xff09; A.可用来创建动态增长和减小的数据结构 B.它是类型无关的&#xff0c;因此具有很高的可复用性 C.它运行时检查数据类型&#xff0c;保证了类型安全 D.它是平台无关的&#xff0c;可移植…

Vue.js 组件开发教程:从基础到进阶

Vue.js 组件开发教程:从基础到进阶 引言 在现代前端开发中,Vue.js 作为一款流行的 JavaScript 框架,以其简单易用和灵活性赢得了开发者的青睐。Vue 组件是 Vue.js 的核心概念之一,理解组件的开发和使用对构建复杂的用户界面至关重要。本篇文章将详细介绍 Vue.js 组件的开…

NFS练习

一、实验目的 1、开放/nfs/shared目录&#xff0c;供所有用户查询资料 2、开放/nfs/upload目录&#xff0c;为192.168.xxx.0/24网段主机可以上传目录&#xff0c; 并将所有用户及所属的组映射为nfs-upload,其UID和GID均为210 3、将/home/tom目录仅共享给192.168.xxx.xxx这台…

MySQL全文索引检索中文

MySQL全文索引检索中文 5.7.6版本不支持中文检索&#xff0c;需要手动修改配置 ft_min_word_len 1 &#xff0c;因为默认配置 4 SHOW VARIABLES LIKE ft%; show VARIABLES like ngram_token_size;配置 修改 MySQL 配置文件 vim /etc/my.cnf在配置的 [mysqld] 下面添加**ft_…