pytorch 利用Tensorboar记录训练过程loss变化

文章目录

    • 1. LossHistory日志类定义
    • 2. LossHistory类的使用
      • 2.1 实例化LossHistory
      • 2.2 记录每个epoch的loss
      • 2.3 训练结束close掉SummaryWriter
    • 3. 利用Tensorboard 可视化
      • 3.1 显示可视化效果
    • 参考

利用Tensorboard记录训练过程中每个epoch的训练loss以及验证loss,便于及时了解网络的训练进展。

代码参考自 B导github仓库: https://github.com/bubbliiiing/deeplabv3-plus-pytorch

1. LossHistory日志类定义

import os
import matplotlib
matplotlib.use('Agg')
from matplotlib import pyplot as plt
import scipy.signalimport torch
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
#from tensorboardX import SummaryWriter
class LossHistory():def __init__(self, log_dir, model, input_shape):self.log_dir    = log_dirself.losses     = []self.val_loss   = []os.makedirs(self.log_dir)self.writer     = SummaryWriter(self.log_dir)try:dummy_input     = torch.randn(2, 3, input_shape[0], input_shape[1])self.writer.add_graph(model, dummy_input)except:passdef append_loss(self, epoch, loss, val_loss):if not os.path.exists(self.log_dir):os.makedirs(self.log_dir)self.losses.append(loss)self.val_loss.append(val_loss)with open(os.path.join(self.log_dir, "epoch_loss.txt"), 'a') as f:f.write(str(loss))f.write("\n")with open(os.path.join(self.log_dir, "epoch_val_loss.txt"), 'a') as f:f.write(str(val_loss))f.write("\n")self.writer.add_scalar('loss', loss, epoch)self.writer.add_scalar('val_loss', val_loss, epoch)self.loss_plot()def loss_plot(self):iters = range(len(self.losses))plt.figure()plt.plot(iters, self.losses, 'red', linewidth = 2, label='train loss')plt.plot(iters, self.val_loss, 'coral', linewidth = 2, label='val loss')try:if len(self.losses) < 25:num = 5else:num = 15plt.plot(iters, scipy.signal.savgol_filter(self.losses, num, 3), 'green', linestyle = '--', linewidth = 2, label='smooth train loss')plt.plot(iters, scipy.signal.savgol_filter(self.val_loss, num, 3), '#8B4513', linestyle = '--', linewidth = 2, label='smooth val loss')except:passplt.grid(True)plt.xlabel('Epoch')plt.ylabel('Loss')plt.legend(loc="upper right")plt.savefig(os.path.join(self.log_dir, "epoch_loss.png"))plt.cla()plt.close("all")
  • (1) 首先利用LossHistory类的构造函数__init__, 实例化TensorboardSummaryWriter对象self.writer,并将网络结构图添加到self.writer中。其中__init__方法接收的参数包括,保存log的路径log_dir以及模型model和输入的shape
def __init__(self, log_dir, model, input_shape):self.log_dir    = log_dirself.losses     = []self.val_loss   = []os.makedirs(self.log_dir)self.writer     = SummaryWriter(self.log_dir)try:dummy_input     = torch.randn(2, 3, input_shape[0], input_shape[1])self.writer.add_graph(model, dummy_input)except:pass
  • (2) 记录每个epoch的训练损失loss以及验证val_loss,并保存到tensorboar中显示
self.writer.add_scalar('loss', loss, epoch)
self.writer.add_scalar('val_loss', val_loss, epoch)

同时将训练的loss以及验证val_loss逐行保存到.txt文件中

 with open(os.path.join(self.log_dir, "epoch_loss.txt"), 'a') as f:f.write(str(loss))f.write("\n")

并且在每个epoch时,调用loss_plot绘制历史的loss曲线,并保存为epoch_loss.png, 由于每个epoch保存的图片都是重名的,因此在训练结束时,会保存最新的所有epoch绘制的loss曲线

2. LossHistory类的使用

2.1 实例化LossHistory

在训练开始前,实例化LossHistory类,调用__init__实例化时,会创建SummaryWriter对象,用于记录训练的过程中的数据,比如loss, graph以及图片信息等

local_rank  = int(os.environ["LOCAL_RANK"]) 
model   = DeepLab(num_classes=num_classes, backbone=backbone, downsample_factor=downsample_factor, pretrained=pretrained)
input_shape     = [512, 512]if local_rank == 0:time_str        = datetime.datetime.strftime(datetime.datetime.now(),'%Y_%m_%d_%H_%M_%S')log_dir         = os.path.join(save_dir, "loss_" + str(time_str))loss_history    = LossHistory(log_dir, model, input_shape=input_shape)else:loss_history    = None
  • 对于多GPU训练时,只在主进程(local_rank == 0)记录训练的日志信息
  • log 保存的路径log_dir,利用loss_ + 当前时间的形式记录
log_dir         = os.path.join(save_dir, "loss_" + str(time_str))

2.2 记录每个epoch的loss

在每个epoch中,利用loss_history的append_loss方法,利用SummaryWriter对象保存loss:

for epoch in range(start_epoch, total_epoch):...loss_history.append_loss(epoch + 1, total_loss / epoch_step, val_loss / epoch_step_val)
  • 记录了每个epoch的训练loss以及验证val_loss
  • 同时将最新的loss曲线,保存到本地epoch_loss.png
  • 并将历史的训练loss和val_loss保存为txt文件,方便查看

2.3 训练结束close掉SummaryWriter

loss_history.writer.close()

3. 利用Tensorboard 可视化

  • Tensorboard最早是在Tensorflow中开发和应用的,pytorch 中也同样支持Tensorboard的使用,pytorch中的Tensorboard工具叫TensorboardX, 它需要依赖于tensorflow库中的一些组件支持。因此在安装Tensorboardx之前,需要先安装TensorFlow, 否则直接安装Tensorboardx运行会报错。
pip install tensorflow
pip install tensorboardX

3.1 显示可视化效果

训练结束后,cd到SummaryWriter中定义好日志保存目录log_dir下,执行如下指令

cd log_dir # log_dir为定义的日志保存目录
tensorboard  --logdir=./     --port 6006 

然后会显示出访问的链接地址,点击链接就可以查看Tensorboard可视化效果

  • Scalar模块展示训练过程中,每个epoch的train_loss、Accuracy、Learn_Rating的数值变化
    在这里插入图片描述
  • GRAPH模块展示的是模型的网络结构
    在这里插入图片描述
  • HISTOGRAMS模块展示添加到tensorboard中各层的权重分布情况
    在这里插入图片描述

参考

  • 1 https://github.com/bubbliiiing/deeplabv3-plus-pytorch
  • 2 pytorch中使用tensorboard实现训练过程可视化

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/670154.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

前端工程化之:webpack3-5(css module)

目录 一、css module 1.思路 2.实现原理 3.如何应用样式 4.其他操作 &#xff08;1&#xff09;全局类名 &#xff08;2&#xff09;如何控制最终的类名 5.其他注意事项 一、css module 通过命名规范来限制类名太过死板&#xff0c;而 css in js 虽然足够灵活&…

Qt应用软件【协议篇】TCP示例

文章目录 TCP协议简介Qt中的TCP编程完整代码示例实际使用中的技巧实际使用中的注意事项TCP协议简介 TCP(传输控制协议)是一种面向连接的、可靠的、基于字节流的传输层通信协议。与UDP不同,TCP提供了数据包排序、重传机制、流量控制和拥塞控制,确保了数据传输的可靠性和顺序…

京东首页移动端-web实战

设置视口标签以及引入初始化样式 <link rel"stylesheet" href"./css/normalize.css"><link rel"stylesheet" href"./css/index.css"> body常用初始化样式 body {width: 100%;min-width: 320px;max-width: 640px;margin:…

作业2.5

第四章 堆与拷贝构造函数 一 、程序阅读题 1、给出下面程序输出结果。 #include <iostream.h> class example {int a; public: example(int b5){ab;} void print(){aa1;cout <<a<<"";} void print()const {cout<<a<<endl;} …

mysql 多数据源

依赖 <dependencies><!--mysql连接--><dependency><groupId>mysql</groupId><artifactId>mysql-connector-java</artifactId><scope>runtime</scope></dependency><!--多数据源--><dependency><g…

PostgreSQL开启wal日志归档模式

1.检查归档模式是否开启 postgres# show archive_mode;archive_mode --------------off (1 row)2.开启归档模式 ## 创建归档目录 mkdir -p /pgsql15.4/pg_arch## 配置归档相关参数 postgres# alter system set archive_modeon; ALTER SYSTEM postgres# alter system set arc…

P9240 [蓝桥杯 2023 省 B] 冶炼金属--2024蓝桥杯冲刺省一

点击跳转例题 思路&#xff1a;最开始读完题&#xff0c;我们知道求最小值最大&#xff0c;和最大值最小。是符合二分的性质的&#xff0c;但是我们再一思考可以发现这是简单的数学。 求每条记录的最小值&#xff1a;a/&#xff08;b1&#xff09;1。可以发现 a%b的情况下&…

【手写数据库toadb】虚拟文件描述符,连接表对象与物理文件的纽带,通过逻辑表找到物理文件的密码

22 存储管理抽象接口层 ​专栏内容: 手写数据库toadb 本专栏主要介绍如何从零开发,开发的步骤,以及开发过程中的涉及的原理,遇到的问题等,让大家能跟上并且可以一起开发,让每个需要的人成为参与者。 本专栏会定期更新,对应的代码也会定期更新,每个阶段的代码会打上tag,…

VLM 系列——Llava1.6——论文解读

一、概述 1、是什么 Llava1.6 是llava1.5 的升级暂时还没有论文等&#xff0c;是一个多模态视觉-文本大语言模型&#xff0c;可以完成&#xff1a;图像描述、视觉问答、根据图片写代码&#xff08;HTML、JS、CSS&#xff09;&#xff0c;潜在可以完成单个目标的视觉定位、名画…

【go】gorm\xorm\ent事务处理

文章目录 1 gorm1.1 开启事务1.2 执行操作1.3 提交或回滚 2 xorm2.1 开启事务2.2 执行操作2.3 提交或回滚 3 ent3.1 开启事务3.2 执行操作3.3 提交或回滚 前言&#xff1a;本文介绍golang三种orm框架对数据库事务的操作 1 gorm Begin开启事务 tx *gorm.DB 1.1 开启事务 tx :…

Qt PCL学习(一):环境搭建

参考 (QT配置pcl)PCL1.12.1QT5.15.2vs2019cmake3.22.4vtk9.1.0visual studio2019Qt5.15.2PCL1.12.1vtk9.1.0cmake3.22.2 本博客用到的所有资源 版本一览&#xff1a;Visual Studio 2019 Qt 5.15.2 PCL 1.12.1 VTK 9.1.0https://pan.baidu.com/s/1xW7xCdR5QzgS1_d1NeIZpQ?pw…

计算机设计大赛 深度学习+opencv+python实现车道线检测 - 自动驾驶

文章目录 0 前言1 课题背景2 实现效果3 卷积神经网络3.1卷积层3.2 池化层3.3 激活函数&#xff1a;3.4 全连接层3.5 使用tensorflow中keras模块实现卷积神经网络 4 YOLOV56 数据集处理7 模型训练8 最后 0 前言 &#x1f525; 优质竞赛项目系列&#xff0c;今天要分享的是 &am…

React+Antd+实现省、市区级联下拉多选组件

1、效果 是你要的效果&#xff0c;咱们继续往下看&#xff0c;搜索面板实现省市区下拉&#xff0c;原本有antd的Cascader组件&#xff0c;但是级联组件必须选到子节点&#xff0c;不能只选省&#xff0c;满足不了页面的需求 2、环境准备 1、react18 2、antd 4 3、功能实现 …

C++泛型编程:类模板(下)

类模板与继承&#xff1a; 需要指定模板参数的类型 template <class T> class Base { public:T m; }; class Son :public Base<int> { }; template <typename T1,typename T2> class Son2 :public Base<T2> { public:Son2(){cout << "T1的…

IntelliScraper 更新 --可自定义最大输出和相似度 支持Html的内容相似度匹配

场景 之前我们在使用IntelliScraper 初代版本的时候&#xff0c;不少人和我反馈一个问题&#xff0c;那就是最大输出结果只有50个&#xff0c;而且还带有html内容&#xff0c;不支持自动化&#xff0c;我声明一下&#xff0c;自动化目前不会支持&#xff0c;以后也不会支持&am…

按时间维度统计次数案例

按时间维度统计次数案例 文章目录 按时间维度统计次数案例1.按天维度统计个数2.按月维度统计个数3.按小时维度统计个数4.按分钟维度统计个数5.按秒维度统计个数6.每个5分钟的维度统计个数 1.按天维度统计个数 要按天维度统计某个字段的个数&#xff0c;可以使用MySQL的日期函数…

Java集合为什么不能使用foreach删除元素

文章目录 前言foreach为什么不能使用foreach操作ArrayList迭代器解析 前言 相信各位程序猿在开发的过程中都用过foreach循环&#xff0c;简单快捷的遍历集合或者数组&#xff0c;但是在通过foreach进行集合操作的时候就不可以了&#xff0c;这是为什么&#xff1f;这里先把问题…

正点原子-STM32定时器学习笔记(1)未完待续

1. 通用定时器简介&#xff08;F1为例&#xff09; F1系列通用定时器有4个&#xff0c;TIM2/TIM3/TIM4/TIM5 主要特性&#xff1a; 16位递增、递减、中心对齐计数器&#xff08;计数值&#xff1a;0~65535&#xff09;&#xff1b; 16位预分频器&#xff08;分频系数&#xff…

[晓理紫]AI专属会议截稿时间订阅

AI专属会议截稿时间订阅 关注{晓理紫}&#xff0c;每日更新最新AI专属会议信息&#xff0c;如感兴趣&#xff0c;请转发给有需要的同学&#xff0c;谢谢支持&#xff01;&#xff01; 如果你感觉对你有所帮助&#xff0c;请关注我&#xff0c;每日准时为你推送最新AI专属会议信…

洛谷:P2957 [USACO09OCT] Barn Echoes G

题目描述 The cows enjoy mooing at the barn because their moos echo back, although sometimes not completely. Bessie, ever the excellent secretary, has been recording the exact wording of the moo as it goes out and returns. She is curious as to just how mu…