【CenterFusion】run_epoch()函数-训练一轮epoch-CenterFusion/src/lib/trainer.py

文件位置:CenterFusion/src/lib/trainer.py
run_epoch作用:CenterFusion 项目训练一轮epoch过程

  • 在 main.py 函数中,生成了训练器,然后再使用训练器训练一个 epoch
  • run_epoch()函数的定义在src\lib\trainer.py150行左右,它的主要过程如下所示:
  def run_epoch(self, phase, epoch, data_loader):model_with_loss = self.model_with_loss'''self.model_with_loss 是 ModelWithLoss 类,这个类又继承 torch.nn.Module 类'''if phase == 'train':model_with_loss.train()'''启用 Batch Normalization 和 Dropout如果模型中有 BN 层 (Batch Normalization) 和 Dropout需要在训练时添加 model.train()model.train()是保证 BN 层能够用到每一批数据的均值和方差对于 Dropout,model.train() 是随机取一部分网络连接来训练更新参数'''else:if len(self.opt.gpus) > 1:model_with_loss = self.model_with_loss.modulemodel_with_loss.eval()'''不启用 Batch Normalization 和 Dropout如果模型中有 BN 层 (Batch Normalization) 和Dropout在测试时添加 model.eval()model.eval() 是保证 BN 层能够用全部训练数据的均值和方差即测试过程中要保证 BN 层的均值和方差不变对于 Dropout,model.eval() 是利用到了所有网络连接,即不进行随机舍弃神经元。'''torch.cuda.empty_cache()'''释放空间'''opt = self.optresults = {}data_time, batch_time = AverageMeter(), AverageMeter()'''新建两个 AverageMeter 对象'''avg_loss_stats = {l: AverageMeter() for l in self.loss_stats \if l == 'tot' or opt.weights[l] > 0}'''为 loss 列表的每个属性赋值一个 AverageMeter 对象'''num_iters = len(data_loader) if opt.num_iters < 0 else opt.num_iters'''获取数据长度'''bar = Bar('{}/{}'.format(opt.task, opt.exp_id), max=num_iters)end = time.time()'''设置进度条'''for iter_id, batch in enumerate(data_loader):if iter_id >= num_iters:break'''遍历完'''data_time.update(time.time() - end)'''更新 data_time 的值'''for k in batch:if k != 'meta':batch[k] = batch[k].to(device=opt.device, non_blocking=True)'''这里的 batch 是一个 Tensor 对象将其配置到 gpu 上'''output, loss, loss_stats = model_with_loss(batch, phase)'''运行第一阶段(模型训练)'''# backpropagate and step optimizer 反向传播和步进优化器loss = loss.mean()'''求每一层损失值的平均值'''if phase == 'train':self.optimizer.zero_grad()'''将模型的参数梯度初始化为0'''loss.backward()'''反向传播计算梯度'''self.optimizer.step()'''更新所有参数''''''根据 pytorch 中 backward() 函数的计算当网络参量进行反馈时,梯度是累积计算而不是被替换但在处理每一个 batch 时并不需要与其他 batch的梯度混合起来累积计算因此需要对每个 batch 调用一遍 zero_grad() 将参数梯度置 0.'''batch_time.update(time.time() - end)'''更新 batch_time 的值'''end = time.time()Bar.suffix = '{phase}: [{0}][{1}/{2}]|Tot: {total:} |ETA: {eta:} '.format(epoch, iter_id, num_iters, phase=phase,total=bar.elapsed_td, eta=bar.eta_td)'''bar.elapsed_td : 经过的时间增量eta=bar.eta_td : 时间间隔'''for l in avg_loss_stats:avg_loss_stats[l].update(loss_stats[l].mean().item(), batch['image'].size(0))Bar.suffix = Bar.suffix + '|{} {:.4f} '.format(l, avg_loss_stats[l].avg)'''更新平均损失'''Bar.suffix = Bar.suffix + '|Data {dt.val:.3f}s({dt.avg:.3f}s) ' \'|Net {bt.avg:.3f}s'.format(dt=data_time, bt=batch_time)if opt.print_iter > 0:if iter_id % opt.print_iter == 0:print('{}/{}| {}'.format(opt.task, opt.exp_id, Bar.suffix)) else:bar.next()'''opt.print_iter = 0 执行 else 语句,显示进度条'''if opt.debug > 0:self.debug(batch, output, iter_id, dataset=data_loader.dataset)'''debug 默认为 0,没有执行 if 语句'''if (phase == 'val' and (opt.run_dataset_eval or opt.eval)):meta = batch['meta']dets = fusion_decode(output, K=opt.K, opt=opt)'''解码器和雷达点云融合调用的这个函数位于 CenterFusion\src\lib\model\decode.py 中这个函数具体实现的功能就是将前面模型训练得到的结果,也就是一些特征图,这些特征图为多维矩阵将特征图与毫米波雷达点云进行映射,映射过程就是将特征图进行维度转换、升维等操作,然后再点乘旋转矩阵'''for k in dets:dets[k] = dets[k].detach().cpu().numpy()'''detach() 阻断反向传播,返回值仍为 tensorcpu() 将变量放在 cpu 上,仍为 tensornumpy() 将 tensor 转换为 numpy'''calib = meta['calib'].detach().numpy() if 'calib' in meta else Nonedets = generic_post_process(opt, dets, meta['c'].cpu().numpy(), meta['s'].cpu().numpy(),output['hm'].shape[2], output['hm'].shape[3], self.opt.num_classes,calib)result = []for i in range(len(dets[0])):if dets[0][i]['score'] > self.opt.out_thresh and all(dets[0][i]['dim'] > 0):result.append(dets[0][i])'''筛选结果'''img_id = batch['meta']['img_id'].numpy().astype(np.int32)[0]'''强制类型转换图片 id'''results[img_id] = resultdel output, loss, loss_statsbar.finish()ret = {k: v.avg for k, v in avg_loss_stats.items()}'''平均损失结果'''ret['time'] = bar.elapsed_td.total_seconds() / 60.return ret, results

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/751638.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

PyTorch学习笔记之激活函数篇(五)

5、PReLU函数 对应的论文链接&#xff1a;https://arxiv.org/abs/1502.01852v1 5.1 公式 PReLU函数的公式&#xff1a; f ( x ) { x , x > 0 α x , x < 0 ( α 是可训练参数 ) f(x) \begin{cases} x&,x>0 \\ \alpha x&,x<0 (\alpha 是可训练参数) …

【遍历方法】浅析Java中字符串、数组、集合的遍历

目录 前言 字符串篇 1.1 使用 for 循环和 charAt 方法 1.2 使用增强 for 循环&#xff08;forEach 循环&#xff09; 1.3 使用 Java 8 的 Stream API 最终效果 数组篇 2.1 使用普通 for 循环 2.2 使用增强型 for 循环( forEach 循环) 2.3 使用 Arrays.asList 和 forE…

解决:springboot项目访问hdfs文件提示guava版本不兼容

1、问题描述 版本说明&#xff1a;我用的hadoop版本&#xff1a;3.1.3 项目可以正常启动&#xff0c;但是调用访问hdfs的服务时候报错,报错消息如下&#xff1a;com.google.common.base.preconditions.checkArgument(ZL java/lang/String;Ljava/lang/Object:)V 原因分析&#x…

Flutter开发进阶之使用工具效率开发

Flutter开发进阶之使用工具效率开发 软件开发团队使用Flutter开发的原因通常是因为Flutter开发性能高、效率高、兼容性好、可拓展性高&#xff0c;作为软件PM来说主要考虑的是范围管理、进度管理、成本管理、资源管理、质量管理、风险管理和沟通管理等&#xff0c;可以看到Flu…

企业内部培训考试系统培训计划功能说明

培训计划是预设好的一套课程系列&#xff0c;包含课程和考试&#xff0c;分多个阶段&#xff0c;每完成一个阶段就会在学习地图上留下标记&#xff0c;让用户看到自己的努力成果&#xff0c;增强成就感&#xff0c;从而坚持完成课程。 企业内部培训考试系统中如何设置培训计划…

基于springboot的购物商城管理系统

1.项目简介 1.1 用户简介 用户主要分为管理员和用户端&#xff1a; 管理员&#xff1a; 管理员可以对后台数据进行管理、拥有最高权限、具体权限有登录后进行首页轮播图的配置管理、商品的配置、新品家具商城的配置管理、、家具商城分类管理配置、家具商城详情商品管理、用户…

Git 下载时需要使用代理?

食用方法 在命令行中&#xff0c;你可以使用以下命令来设置Git的HTTP和HTTPS代理&#xff1a; git config --global http.proxy http://127.0.0.1:6890 git config --global https.proxy https://127.0.0.1:6890 注意是根据自己的实际情况修改IP和端口号 注意如果不想全局配置…

react-面试题

一、组件基础 1. React 事件机制 <div onClick{this.handleClick.bind(this)}>点我</div> React并不是将click事件绑定到了div的真实DOM上&#xff0c;而是在document处监听了所有的事件&#xff0c;当事件发生并且冒泡到document处的时候&#xff0c;React将事…

网络安全JavaSE第二天(持续更新)

3. 基本数据与运算 3.6 运算符 3.6.1 算术运算符 在 Java 中&#xff0c;算术运算符包含&#xff1a;、-、*、/、% public class ArithmeticOperator { public static void main(String[] args) { int a 10; // 定义了一个整型类型的变量 a&#xff0c;它的值是 10 int b …

区块链推广海外市场怎么做,CloudNEO服务商免费为您定制个性化营销方案

随着区块链技术的不断发展和应用场景的扩大&#xff0c;区块链项目希望能够进入海外市场并取得成功已成为越来越多公司的目标之一。然而&#xff0c;要在海外市场推广区块链项目&#xff0c;需要采取有效的营销策略和措施。作为您的区块链项目营销服务商&#xff0c;CloudNEO将…

后端程序员入门react笔记(八)-redux的使用和项目搭建

一个更好用的文档 添加链接描述 箭头函数的简化 //简化前 function countIncreAction(data) {return {type:"INCREMENT",data} } //简化后 const countIncreAction data>({type:"INCREMENT",data })react UI组件库相关资料 组件库连接和推荐 antd组…

Python 多线程大批量处理文件小程序

说明 平时偶尔需要进行重复性的对文件进行重命名、格式转化等。假设以文件复制功能作为目标&#xff0c;设计一个小程序使用多线程对文件进行批量复制。&#xff08;其实以后主要目标是针对Realsense的raw文件进行批量的转化&#xff0c;并借助多线程加速&#xff09; 代码 i…

uv 必备的工具 ps ai 全家桶合集

非常稀有的资源 &#xff0c;必应搜索 易品资源yipinziyuan 可以找到

sqllab第二十关通关笔记

知识点&#xff1a; cookie注入 可以进行url解析错误注入传参位置 get请求post请求cookie传参 输入admin admin进行登录&#xff0c;抓取当前数据包 通过放包发现是一个302跳转的响应包&#xff0c;页面只有一个 I Love Cookies&#xff1b;没什么信息 通过点击页面上方的按钮…

若你有才能,最好能遇上识才之人,高俅发迹的故事很好诠释了千里马与伯乐的关系

若你有才能&#xff0c;最好能遇上识才之人&#xff0c;高俅发迹的故事很好诠释了千里马与伯乐的关系 其实&#xff0c;“千里马”和“伯乐”都是中国古代传说里的角色。伯乐是古代一个善于相马&#xff08;识别马的好坏&#xff09;的人&#xff0c;而“千里马”则是指一匹能跑…

数通-路由技术基础介绍

自治系统——AS&#xff1b;LAN和广播域&#xff1b;CN-2——精品线路 路由器互联网段为同网段&#xff08;不同网段会造成三层不通&#xff0c;非直连则不会产生直连路由&#xff09; 路由选路&#xff0c;一个路由器的各个接口不能配置相同网段。 IP路由表&#xff1a;例&…

前端学习之css样式 背景样式、字体样式、列表样式、边框样式、内外边距元素属性的转换

背景样式 html文件 <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><title>背景样式</title><link rel"stylesheet" href"../css/3.1背景样式.css"> </head> <bo…

华为云APIG跨域资源共享方案

## 浏览器的同源策略 浏览器的同源策略是一种安全机制&#xff0c;旨在保护用户的信息安全和隐私。它限制了一个网页的脚本只能与来自同一源的资源进行交互&#xff0c;即同源策略要求页面中加载的所有资源&#xff08;包括脚本、样式表、图片等&#xff09;必须来自相同的**域…

python之万花尺

1、使用模块 import sys, random, argparse import numpy as np import math import turtle import random from PIL import Image from datetime import datetime from math import gcd 依次使用pip下载即可 2、代码 import sys, random, argparse import numpy as np imp…

通俗易懂的Python循环讲解

循环用于重复执行一些程序块。从上一讲的选择结构&#xff0c;我们已经看到了如何用缩进来表示程序块的隶属关系。循环也会用到类似的写法。 for循环 for循环需要预先设定好循环的次数(n)&#xff0c;然后执行隶属于for的语句n次。 基本构造是 for 元素 in 序列: statemen…