pytorch 利用Tensorboar记录训练过程loss变化

文章目录

    • 1. LossHistory日志类定义
    • 2. LossHistory类的使用
      • 2.1 实例化LossHistory
      • 2.2 记录每个epoch的loss
      • 2.3 训练结束close掉SummaryWriter
    • 3. 利用Tensorboard 可视化
      • 3.1 显示可视化效果
    • 参考

利用Tensorboard记录训练过程中每个epoch的训练loss以及验证loss,便于及时了解网络的训练进展。

代码参考自 B导github仓库: https://github.com/bubbliiiing/deeplabv3-plus-pytorch

1. LossHistory日志类定义

import os
import matplotlib
matplotlib.use('Agg')
from matplotlib import pyplot as plt
import scipy.signalimport torch
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
#from tensorboardX import SummaryWriter
class LossHistory():def __init__(self, log_dir, model, input_shape):self.log_dir    = log_dirself.losses     = []self.val_loss   = []os.makedirs(self.log_dir)self.writer     = SummaryWriter(self.log_dir)try:dummy_input     = torch.randn(2, 3, input_shape[0], input_shape[1])self.writer.add_graph(model, dummy_input)except:passdef append_loss(self, epoch, loss, val_loss):if not os.path.exists(self.log_dir):os.makedirs(self.log_dir)self.losses.append(loss)self.val_loss.append(val_loss)with open(os.path.join(self.log_dir, "epoch_loss.txt"), 'a') as f:f.write(str(loss))f.write("\n")with open(os.path.join(self.log_dir, "epoch_val_loss.txt"), 'a') as f:f.write(str(val_loss))f.write("\n")self.writer.add_scalar('loss', loss, epoch)self.writer.add_scalar('val_loss', val_loss, epoch)self.loss_plot()def loss_plot(self):iters = range(len(self.losses))plt.figure()plt.plot(iters, self.losses, 'red', linewidth = 2, label='train loss')plt.plot(iters, self.val_loss, 'coral', linewidth = 2, label='val loss')try:if len(self.losses) < 25:num = 5else:num = 15plt.plot(iters, scipy.signal.savgol_filter(self.losses, num, 3), 'green', linestyle = '--', linewidth = 2, label='smooth train loss')plt.plot(iters, scipy.signal.savgol_filter(self.val_loss, num, 3), '#8B4513', linestyle = '--', linewidth = 2, label='smooth val loss')except:passplt.grid(True)plt.xlabel('Epoch')plt.ylabel('Loss')plt.legend(loc="upper right")plt.savefig(os.path.join(self.log_dir, "epoch_loss.png"))plt.cla()plt.close("all")
  • (1) 首先利用LossHistory类的构造函数__init__, 实例化TensorboardSummaryWriter对象self.writer,并将网络结构图添加到self.writer中。其中__init__方法接收的参数包括,保存log的路径log_dir以及模型model和输入的shape
def __init__(self, log_dir, model, input_shape):self.log_dir    = log_dirself.losses     = []self.val_loss   = []os.makedirs(self.log_dir)self.writer     = SummaryWriter(self.log_dir)try:dummy_input     = torch.randn(2, 3, input_shape[0], input_shape[1])self.writer.add_graph(model, dummy_input)except:pass
  • (2) 记录每个epoch的训练损失loss以及验证val_loss,并保存到tensorboar中显示
self.writer.add_scalar('loss', loss, epoch)
self.writer.add_scalar('val_loss', val_loss, epoch)

同时将训练的loss以及验证val_loss逐行保存到.txt文件中

 with open(os.path.join(self.log_dir, "epoch_loss.txt"), 'a') as f:f.write(str(loss))f.write("\n")

并且在每个epoch时,调用loss_plot绘制历史的loss曲线,并保存为epoch_loss.png, 由于每个epoch保存的图片都是重名的,因此在训练结束时,会保存最新的所有epoch绘制的loss曲线

2. LossHistory类的使用

2.1 实例化LossHistory

在训练开始前,实例化LossHistory类,调用__init__实例化时,会创建SummaryWriter对象,用于记录训练的过程中的数据,比如loss, graph以及图片信息等

local_rank  = int(os.environ["LOCAL_RANK"]) 
model   = DeepLab(num_classes=num_classes, backbone=backbone, downsample_factor=downsample_factor, pretrained=pretrained)
input_shape     = [512, 512]if local_rank == 0:time_str        = datetime.datetime.strftime(datetime.datetime.now(),'%Y_%m_%d_%H_%M_%S')log_dir         = os.path.join(save_dir, "loss_" + str(time_str))loss_history    = LossHistory(log_dir, model, input_shape=input_shape)else:loss_history    = None
  • 对于多GPU训练时,只在主进程(local_rank == 0)记录训练的日志信息
  • log 保存的路径log_dir,利用loss_ + 当前时间的形式记录
log_dir         = os.path.join(save_dir, "loss_" + str(time_str))

2.2 记录每个epoch的loss

在每个epoch中,利用loss_history的append_loss方法,利用SummaryWriter对象保存loss:

for epoch in range(start_epoch, total_epoch):...loss_history.append_loss(epoch + 1, total_loss / epoch_step, val_loss / epoch_step_val)
  • 记录了每个epoch的训练loss以及验证val_loss
  • 同时将最新的loss曲线,保存到本地epoch_loss.png
  • 并将历史的训练loss和val_loss保存为txt文件,方便查看

2.3 训练结束close掉SummaryWriter

loss_history.writer.close()

3. 利用Tensorboard 可视化

  • Tensorboard最早是在Tensorflow中开发和应用的,pytorch 中也同样支持Tensorboard的使用,pytorch中的Tensorboard工具叫TensorboardX, 它需要依赖于tensorflow库中的一些组件支持。因此在安装Tensorboardx之前,需要先安装TensorFlow, 否则直接安装Tensorboardx运行会报错。
pip install tensorflow
pip install tensorboardX

3.1 显示可视化效果

训练结束后,cd到SummaryWriter中定义好日志保存目录log_dir下,执行如下指令

cd log_dir # log_dir为定义的日志保存目录
tensorboard  --logdir=./     --port 6006 

然后会显示出访问的链接地址,点击链接就可以查看Tensorboard可视化效果

  • Scalar模块展示训练过程中,每个epoch的train_loss、Accuracy、Learn_Rating的数值变化
    在这里插入图片描述
  • GRAPH模块展示的是模型的网络结构
    在这里插入图片描述
  • HISTOGRAMS模块展示添加到tensorboard中各层的权重分布情况
    在这里插入图片描述

参考

  • 1 https://github.com/bubbliiiing/deeplabv3-plus-pytorch
  • 2 pytorch中使用tensorboard实现训练过程可视化

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/670154.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

前端工程化之:webpack3-5(css module)

目录 一、css module 1.思路 2.实现原理 3.如何应用样式 4.其他操作 &#xff08;1&#xff09;全局类名 &#xff08;2&#xff09;如何控制最终的类名 5.其他注意事项 一、css module 通过命名规范来限制类名太过死板&#xff0c;而 css in js 虽然足够灵活&…

京东首页移动端-web实战

设置视口标签以及引入初始化样式 <link rel"stylesheet" href"./css/normalize.css"><link rel"stylesheet" href"./css/index.css"> body常用初始化样式 body {width: 100%;min-width: 320px;max-width: 640px;margin:…

mysql 多数据源

依赖 <dependencies><!--mysql连接--><dependency><groupId>mysql</groupId><artifactId>mysql-connector-java</artifactId><scope>runtime</scope></dependency><!--多数据源--><dependency><g…

VLM 系列——Llava1.6——论文解读

一、概述 1、是什么 Llava1.6 是llava1.5 的升级暂时还没有论文等&#xff0c;是一个多模态视觉-文本大语言模型&#xff0c;可以完成&#xff1a;图像描述、视觉问答、根据图片写代码&#xff08;HTML、JS、CSS&#xff09;&#xff0c;潜在可以完成单个目标的视觉定位、名画…

Qt PCL学习(一):环境搭建

参考 (QT配置pcl)PCL1.12.1QT5.15.2vs2019cmake3.22.4vtk9.1.0visual studio2019Qt5.15.2PCL1.12.1vtk9.1.0cmake3.22.2 本博客用到的所有资源 版本一览&#xff1a;Visual Studio 2019 Qt 5.15.2 PCL 1.12.1 VTK 9.1.0https://pan.baidu.com/s/1xW7xCdR5QzgS1_d1NeIZpQ?pw…

计算机设计大赛 深度学习+opencv+python实现车道线检测 - 自动驾驶

文章目录 0 前言1 课题背景2 实现效果3 卷积神经网络3.1卷积层3.2 池化层3.3 激活函数&#xff1a;3.4 全连接层3.5 使用tensorflow中keras模块实现卷积神经网络 4 YOLOV56 数据集处理7 模型训练8 最后 0 前言 &#x1f525; 优质竞赛项目系列&#xff0c;今天要分享的是 &am…

React+Antd+实现省、市区级联下拉多选组件

1、效果 是你要的效果&#xff0c;咱们继续往下看&#xff0c;搜索面板实现省市区下拉&#xff0c;原本有antd的Cascader组件&#xff0c;但是级联组件必须选到子节点&#xff0c;不能只选省&#xff0c;满足不了页面的需求 2、环境准备 1、react18 2、antd 4 3、功能实现 …

IntelliScraper 更新 --可自定义最大输出和相似度 支持Html的内容相似度匹配

场景 之前我们在使用IntelliScraper 初代版本的时候&#xff0c;不少人和我反馈一个问题&#xff0c;那就是最大输出结果只有50个&#xff0c;而且还带有html内容&#xff0c;不支持自动化&#xff0c;我声明一下&#xff0c;自动化目前不会支持&#xff0c;以后也不会支持&am…

Java集合为什么不能使用foreach删除元素

文章目录 前言foreach为什么不能使用foreach操作ArrayList迭代器解析 前言 相信各位程序猿在开发的过程中都用过foreach循环&#xff0c;简单快捷的遍历集合或者数组&#xff0c;但是在通过foreach进行集合操作的时候就不可以了&#xff0c;这是为什么&#xff1f;这里先把问题…

正点原子-STM32定时器学习笔记(1)未完待续

1. 通用定时器简介&#xff08;F1为例&#xff09; F1系列通用定时器有4个&#xff0c;TIM2/TIM3/TIM4/TIM5 主要特性&#xff1a; 16位递增、递减、中心对齐计数器&#xff08;计数值&#xff1a;0~65535&#xff09;&#xff1b; 16位预分频器&#xff08;分频系数&#xff…

[晓理紫]AI专属会议截稿时间订阅

AI专属会议截稿时间订阅 关注{晓理紫}&#xff0c;每日更新最新AI专属会议信息&#xff0c;如感兴趣&#xff0c;请转发给有需要的同学&#xff0c;谢谢支持&#xff01;&#xff01; 如果你感觉对你有所帮助&#xff0c;请关注我&#xff0c;每日准时为你推送最新AI专属会议信…

C语言-4

排序算法简介 /*学习内容&#xff1a;冒泡排序&#xff08;最基本的排序方法&#xff09;选择排序&#xff08;冒泡的优化&#xff09;插入排序&#xff08;在合适的位置插入合适的数据&#xff09; *//*排序分类&#xff1a;1.内部排序待需要进行排序的数据全部存放到内存中&…

[职场] C++开发工程师的岗位职责 #学习方法#笔记

C开发工程师的岗位职责 C开发工程师是利用C语言设计完成软件系统底层模块功能&#xff1b;测试软件模块和软集成产品&#xff0c;进行软件故障的诊断、定位、分析和调试&#xff0c;实施产品测试方案&#xff1b;向业务部门提供软件的后期技术支持。C开发工程师是负责使用C编程…

知识融合前沿技术:构建多模态、公平高效的大规模知识表示

目录 前言1 无监督对齐&#xff1a;构建智能实体关联2 多视角嵌入&#xff1a;提高数据利用效率3 嵌入表示增强&#xff1a;挑战节点相似性&#xff0c;对抗训练解决4 大规模实体对齐&#xff1a;克服模糊性和异构性结论 前言 在信息时代&#xff0c;知识融合成为推动人工智能…

全链游戏的未来趋势与Bridge Champ的创新之路

为了充分探索全链游戏的特点和趋势&#xff0c;以及Bridge Champ如何作为一个创新案例融入这一发展脉络&#xff0c;我们需要深入了解这两者之间的互动和相互影响。全链游戏&#xff0c;或完全基于区块链的游戏&#xff0c;代表了游戏行业的一个重要转型&#xff0c;它们利用区…

kafka-splunk数据通路实践

目的&#xff1a; 鉴于目前网络上没有完整的kafka数据投递至splunk教程&#xff0c;通过本文操作步骤&#xff0c;您将实现kafka数据投递至splunk日志系统 实现思路&#xff1a; 创建kafka集群部署splunk&#xff0c;设置HTTP事件收集器部署connector服务创建connector任务&a…

re:从0开始的CSS学习之路 1. CSS语法规则

0. 写在前面 现在大模型卷的飞起&#xff0c;感觉做页面的活可能以后就不需要人来做了&#xff0c;不知道现在还有没有学前端的必要。。。 1. HTML和CSS结合的三种方式 在HTML中&#xff0c;我们强调HTML并不关心显示样式&#xff0c;样式是CSS的工作&#xff0c;现在就轮到C…

6、基于机器学习的预测

应用机器学习的任何预测任务与这四个策略。 文章目录 1、简介1.1定义预测任务1.2准备预测数据1.3多步预测策略1.3.1多输出模型1.3.2直接策略1.3.3递归策略1.3.4DirRec 策略2、流感趋势示例2.1多输出模型2.2直接策略1、简介 在第二课和第三课中,我们将预测视为一个简单的回归问…

EMNLP 2023精选:Text-to-SQL任务的前沿进展(上篇)——正会论文解读

导语 本文记录了今年的自然语言处理国际顶级会议EMNLP 2023中接收的所有与Text-to-SQL相关&#xff08;通过搜索标题关键词查找得到&#xff0c;可能不全&#xff09;的论文&#xff0c;共计12篇&#xff0c;包含5篇正会论文和7篇Findings论文&#xff0c;以下是对这些论文的略…

Redis(三)主从架构、Redis哨兵架构、Redis集群方案对比、Redis高可用集群搭建、Redis高可用集群之水平扩展

转自 极客时间 Redis主从架构 redis主从架构搭建&#xff0c;配置从节点步骤&#xff1a; 1、复制一份redis.conf文件2、将相关配置修改为如下值&#xff1a; port 6380 pidfile /var/run/redis_6380.pid # 把pid进程号写入pidfile配置的文件 logfile "6380.log" …