【图像分割】【深度学习】PFNet官方Pytorch代码-PFNet网络损失函数模块解析

【图像分割】【深度学习】PFNet官方Pytorch代码-PFNet网络损失函数模块解析

文章目录

  • 【图像分割】【深度学习】PFNet官方Pytorch代码-PFNet网络损失函数模块解析
  • 前言
  • PM定位模块损失函数
  • FM聚焦模块损失函数
  • 总结


前言

在详细解析PFNet代码之前,首要任务是成功运行PFNet代码【win10下参考教程】,后续学习才有意义。本博客讲解PFNet神经网络模块的损失函数模块代码,不涉及其他功能模块代码。

PFNet中有四个输出预测,一个来自定位模块(PM),三个来自聚焦模块(FM),整体的损失函数为:
ℓ o v e r a l l = ℓ p m + ∑ i = 1 3 2 ( 3 − i ) ℓ f m i {\ell _{overall}}{\rm{ }} = {\rm{ }}{\ell _{pm}} + \sum\limits_{i = 1}^3 {{2^{(3 - i)}}} \ell _{fm}^i overall=pm+i=132(3i)fmi
其中 ℓ f m i \ell _{fm}^i fmi表示在PFNet网络中至上往下第 i i i个的聚焦模块的预测的损失。

博主将各功能模块的代码在不同的博文中进行了详细的解析,点击【win10下参考教程】,博文的目录链接放在前言部分。


PM定位模块损失函数

对于PM模块,使用二值交叉熵损失(Binary CrossEntropy Loss,BCE)损失 ℓ b c e \ell _{{\rm{bce}}} bce和IoU损失 ℓ i o u \ell _{{\rm{iou}}} iou的输出,即 ℓ p m = ℓ b c e + ℓ i o u {\ell _{{\rm{pm}}}} = {\ell _{{\rm{bce}}}} + {\ell _{{\rm{iou}}}} pm=bce+iou,以引导PM探索目标对象的初始位置。
二值交叉熵损失 ℓ i o u \ell _{{\rm{iou}}} iou是常见用法,因此不再具体讲解,本小节主要介绍 ℓ i o u \ell _{{\rm{iou}}} iou,因为它不同于目标检测中用于衡量预测边界框与真实边界框之间的重叠程度,而在论文中对此并没有详细解释,因此博主根据论文源码绘制以下示意图具体讲解 ℓ i o u \ell _{{\rm{iou}}} iou的作用:

ℓ i o u = 1 − i o u {\ell _{{\rm{iou}}}} = 1 - iou iou=1iou i o u iou iou重合度越高, ℓ i o u \ell _{{\rm{iou}}} iou损失越小, i o u = i n t e r u n i o n − i n t e r iou = \frac{{{\rm{inter}}}}{{{\rm{union - inter}}}} iou=unioninterinter。那么 i n t e r inter inter u n i o n − i n t e r union - inter unioninter分别表示什么含义呢?博主将根据所绘制的示意图详细说明其中的含义,如上图所示, m a s k mask mask只有前景为1背景为0俩种值, p r e d pred pred的取值范围则在(0~1)之间,为了方便理解博主也是暴力的拆解成前景为0.8背景为0.2俩种值。

  1. i n t e r inter inter表示真实标签 m a s k mask mask和预测标签 p r e d pred pred对应像素相乘后再对像素值求和的值,如上图的inter所示(只表示到对应元素相乘), i n t e r inter inter的含义可以理解成真实标签的前景部分在预测标签上的预测结果,简单来说就是只考虑预测标签针对真实前景的预测效果,默认背景部分完全预测正确,屏蔽了背景不作考虑,因此 i n t e r = T b + P f inter=T_b+P_f inter=Tb+Pf
  2. u n i o n union union表示真实标签 m a s k mask mask和预测标签 p r e d pred pred对应像素相加后再对像素值求和的值,如上图的union所示(只表示到对应元素相加),那么 u n i o n − i n t e r union-inter unioninter的含义可以理解成真实标签的背景部分在预测标签上的预测结果,如上图的union-inter所示,简单来说就是只考虑预测标签针对真实背景的预测效果,默认前景部分完全预测正确,屏蔽了前景不作考虑,因此 u n i o n − i n t e r = T f + P b union-inter=T_f+P_b unioninter=Tf+Pb

T b T_b Tb表示背景位置真实像素求和值(也就是0), P f P_f Pf表示前景位置预测像素求和值, T f T_f Tf表示前景位置真实像素求和值, P b P_b Pb表示背景位置预测像素求和值。
注意!!!!区分背景位置预测像素和预测背景像素俩个概念!!!前者是真实背景像素位置可能真确预测为背景,也可能错误预测成前景;后者则是对预测一个像素位置为背景。

解释了 i n t e r inter inter u n i o n − i n t e r union - inter unioninter的含义, i o u iou iou也可以表示成 i o u = T b + P f T f + P p iou = \frac{{{T_b} + {P_{\rm{f}}}}}{{{T_f} + {P_p}}} iou=Tf+PpTb+Pf T b T_b Tb T f T_f Tf是固定不变的,那么 ℓ i o u \ell _{{\rm{iou}}} iou的优化目标就是 P f P_f Pf越来越大且 P b P_b Pb越来越小。
代码位置:train.py

# PM loss function
bce_loss = nn.BCEWithLogitsLoss().cuda(device_ids[0])
iou_loss = loss.IOU().cuda(device_ids[0])
def bce_iou_loss(pred, target):bce_out = bce_loss(pred, target)iou_out = iou_loss(pred, target)loss = bce_out + iou_outreturn loss

代码位置:loss.py

博主为了方便大家理解,小改了下源码,但是没有丝毫影响源码的原始目的。

class IOU(torch.nn.Module):def __init__(self):super(IOU, self).__init__()def _iou(self, pred, target):pred = torch.sigmoid(pred)# 交集区域inter = (pred * target).sum(dim=(2, 3))# 并集区域union = (pred + target).sum(dim=(2, 3))# iou损失iou = 1 - (inter / (union- inter))return iou.mean()def forward(self, pred, target):return self._iou(pred, target)

FM聚焦模块损失函数

对于FM模块,希望更多地关注对象的边界、细长区域或孔处等分散注意力区域。因此,使用加权二值交叉熵损失(Binary CrossEntropy Loss,BCE)损失 ℓ w b c e \ell _{{\rm{wbce}}} wbce和加权IoU损失 ℓ w i o u \ell _{{\rm{wiou}}} wiou的输出,即 ℓ f m = ℓ w b c e + ℓ w i o u {\ell _{{\rm{fm}}}} = {\ell _{{\rm{wbce}}}} + {\ell _{{\rm{wiou}}}} fm=wbce+wiou,以迫使FM更加关注可能的分散注意力区域。
ℓ i o u \ell _{{\rm{iou}}} iou在上个章节就进行了说明, ℓ w i o u \ell _{{\rm{wiou}}} wiou大同小异,因此不再具体讲解,本小节主要介绍 ℓ w b c e \ell _{{\rm{wbce}}} wbce ℓ w i o u \ell _{{\rm{wiou}}} wiou中的 w w w权重的产生,在论文中对此并没有详细解释,因此博主根据论文源码绘制以下示意图具体讲解 w w w的作用:

w w w权重是通过对标签 m a s k mask mask进行平均池化操作,再减去 m a s k mask mask,最后取绝对值:
w = 1 + 5 × ∣ A v g P o o l ( m a s k ) − m a s k ∣ w = 1 + 5 \times \left| {\left. {AvgPool(mask) - mask} \right|} \right. w=1+5×AvgPool(mask)mask
为什么这么简单的操作就能让 w w w更加关注可能的分散注意力区域?博主分以下几种情况讨论:

  • 第一种情况:如上图1所示位置,该前景像素位于前景目标的内部,因此不是对象的边界、细长区域或孔处等分散注意力区域,其 w w w权重计算为1,不需要对其做额外加强;
  • 第二种情况:如上图2所示位置,该前景像素是对象的边界,属于分散注意力区域,其 w w w权重计算为4.9,可谓是剧烈加强;
  • 第三种情况:如上图3所示位置,该背景像素是模糊边界,也属于分散注意力区域,其 w w w权重计算为4.3,也是剧烈加强;
  • 第四种情况:如上图4所示位置,该像素是背景,其 w w w权重计算为1,不需要对其做额外加强;

博主绘制的示意图只是为了方便理解,真实的池化核大小不可能只有3×3那么小,源码中使用的池化核大小是31×31。
代码位置:train.py

# FM loss function
structure_loss = loss.structure_loss().cuda(device_ids[0])

代码位置:loss.py

class structure_loss(torch.nn.Module):def __init__(self):super(structure_loss, self).__init__()def _structure_loss(self, pred, mask):print(pred.shape)# 根据mask标签生成关于mask的权重# 根据公式可以知道,越是靠近前景目标边缘的像素,权重可能就越高,而越靠近前景目标的中心的像素权重越低,最低为1weit = 1 + 5 * torch.abs(F.avg_pool2d(mask, kernel_size=31, stride=1, padding=15) - mask)# 因为预测标签还要进行加权,暂时需要保留结构,所以损失在每个元素上计算,reduce选择nonewbce = F.binary_cross_entropy_with_logits(pred, mask, reduce='none')# 加权的bcewbce = (weit * wbce).sum(dim=(2, 3)) / weit.sum(dim=(2, 3))pred = torch.sigmoid(pred)# 交集区域inter = ((pred * mask) * weit).sum(dim=(2, 3))# 并集区域union = ((pred + mask) * weit).sum(dim=(2, 3))# 加权的iou损失wiou = 1 - (inter) / (union - inter)return (wbce + wiou).mean()def forward(self, pred, mask):return self._structure_loss(pred, mask)

总结

尽可能简单、详细的介绍PFNet网络中的损失函数模块的结构和代码。


本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/178449.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

压缩包文件丢失?4个正确找回方法分享!

“我有一个很重要的压缩包保存在电脑上,但是不知道为什么里面有些文件丢失了。有什么方法可以快速找回压缩文件?请大家给我支支招吧!” 如果我们的文件太多,将它们放在压缩包中不仅能让文件更有序,还能更合理的节省电脑…

LD_PRELOAD劫持

LD_PRELOAD劫持 <1> LD_PRELOAD简介 LD_PRELOAD 是linux下的一个环境变量。用于动态链接库的加载&#xff0c;在动态链接库的过程中他的优先级是最高的。类似于 .user.ini 中的 auto_prepend_file&#xff0c;那么我们就可以在自己定义的动态链接库中装入恶意函数。 也…

tp8 使用rabbitMQ(3)发布/订阅

发布/订阅 当我们想把一个消息&#xff0c;发送给 多个消费者的时候&#xff0c;我们把这种模式叫做发布/订阅模式&#xff0c;比如我们做两个消费者&#xff0c;其中一个消费者把消息写入磁盘中&#xff0c;别一个消费者把消息结果输出到屏幕上&#xff0c;就要用到发布订阅模…

产品化和商品化

我们经常会在IT产业听过以下岗位&#xff1a; 1、产品序列&#xff1a;产品行销经理 2、产品序列&#xff1a;产品经理、需求分析师、产品详细设计工程师、UIUE设计师 3、产品序列&#xff1a;业务架构师、应用架构师、数据架构师、技术架构师 4、研发序列&#xff1a;创新原型…

Java 图片验证码需求分析

&#x1f497;wei_shuo的个人主页 &#x1f4ab;wei_shuo的学习社区 &#x1f310;Hello World &#xff01; 图片验证码 需求分析 连续因输错密码而登录失败时&#xff0c;记录其连续输错密码的累加次数&#xff1b;若在次数小于5时&#xff0c;用户输入正确的密码并成功登录…

前K个高频单词(Java详解)

一、题目描述 给定一个单词列表 words 和一个整数 k &#xff0c;返回前 k 个出现次数最多的单词。 返回的答案应该按单词出现频率由高到低排序。如果不同的单词有相同出现频率&#xff0c; 按字典顺序 排序。 示例1&#xff1a; 输入: words ["i", "love&…

浅谈硬件连通性测试几大优势

硬件连通性测试是确保硬件系统正常运行、提高系统可靠性和降低生产成本的关键步骤。在现代工程和制造中&#xff0c;将连通性测试纳入生产流程是一个明智的选择&#xff0c;有助于确保硬件产品的质量和性能达到最优水平。本文将介绍硬件连通性测试的主要优势有哪些! 一、提高系…

游戏测试和软件测试有什么区别

针对手游而言&#xff0c;游戏测试的本质是APP&#xff0c;所以不少手游的测试方式与APP测试异曲同工&#xff0c;然而也有所不同。APP更多的是具有一种工具&#xff0c;一款APP好不好用不重要&#xff0c;关键点在于实用。而游戏则具有一种玩具属性&#xff0c;它并不见得实用…

基于Python+requests编写的自动化测试项目-实现流程化的接口串联

框架产生目的&#xff1a;公司走的是敏捷开发模式&#xff0c;编写这种框架是为了能够满足当前这种发展模式&#xff0c;用于前后端联调之前&#xff08;后端开发完接口&#xff0c;前端还没有将业务处理完毕的时候&#xff09;以及日后回归阶段&#xff0c;方便为自己腾出学(m…

图像异常检测研究现状综述

论文标题&#xff1a;图像异常检测研究现状综述 作者&#xff1a;吕承侃 1, 2 沈 飞 1, 2, 3 张正涛 1, 2, 3 张 峰 1, 2, 3 发表日期&#xff1a;2022年6月 阅读日期 &#xff1a;2023年11月28 研究背景&#xff1a; 图像异常检测是计算机视觉领域的一个热门研究课题, 其目…

leetCode 39.组合总和 + 回溯算法 + 剪枝 + 图解 + 笔记

39. 组合总和 - 力扣&#xff08;LeetCode&#xff09; 给你一个 无重复元素 的整数数组 candidates 和一个目标整数 target &#xff0c;找出 candidates 中可以使数字和为目标数 target 的 所有 不同组合 &#xff0c;并以列表形式返回。你可以按 任意顺序 返回这些组合 can…

2015年五一杯数学建模A题不确定性条件下的最优路径问题解题全过程文档及程序

2015年五一杯数学建模 A题 不确定性条件下的最优路径问题 原题再现 目前&#xff0c;交通拥挤和事故正越来越严重的困扰着城市交通。随着我国交通运输事业的迅速发展&#xff0c;交通“拥塞”已经成为很多城市的“痼疾”。在复杂的交通环境下&#xff0c;如何寻找一条可靠、快…

HarmonyOS 数据持久化 Preferences 如何在页面中对数据进行读写

背景介绍 最近在了解并跟着官方文档尝试做一个鸿蒙app 小demo的过程中对在app中保存数据遇到些问题 特此记录下来 这里的数据持久化以 Preferences为例子展开 废话不多说 这里直接上节目(官方提供的文档示例:) 以Stage模型为例 1.明确preferences的类型 import data_prefer…

印刷企业建设数字工厂管理系统的工作内容有哪些

随着科技的不断进步&#xff0c;数字工厂管理系统在印刷企业中的应用越来越广泛。这种系统可以有效地整合企业内外资源&#xff0c;提高生产效率&#xff0c;降低生产成本&#xff0c;并为印刷企业提供更好的业务运营与管理模式。本文将从以下几个方面探讨印刷企业建设数字工厂…

如何用postman实现接口自动化测试

postman使用 开发中经常用postman来测试接口&#xff0c;一个简单的注册接口用postman测试&#xff1a; 接口正常工作只是最基本的要求&#xff0c;经常要评估接口性能&#xff0c;进行压力测试。 postman进行简单压力测试 下面是压测数据源&#xff0c;支持json和csv两个格…

Kibana部署

服务器 安装软件主机名IP地址系统版本配置KibanaElk10.3.145.14centos7.5.18042核4G软件版本&#xff1a;nginx-1.14.2、kibana-7.13.2-linux-x86_64.tar.gz 1. 安装配置Kibana &#xff08;1&#xff09;安装 [rootelk ~]# tar zxf kibana-7.13.2-linux-x86_64.tar.gz -C…

easyExcel 注解开发 快速以及简单上手 以及包含工具类

easyExcel 简单快速使用 1. mevan 这里版本我这里选的是 poi 4.1.2和 ali的easyexcel 的 3.3.1。 因为阿里easy是根据poi的依赖开发的有关系&#xff0c;两者需要对应要不然就会有很多bug和错误在运行时发生。需要版本对应&#xff0c;然而就是easy的代码也会有bug这个版本是比…

运动鞋品牌识别

一、前期工作 1. 设置GPU from tensorflow import keras from tensorflow.keras import layers,models import os, PIL, pathlib import matplotlib.pyplot as plt import tensorflow as tfgpus tf.config.list_physical_devices("GPU")if gpus:gpu0 …

Leetcode—18.四数之和【中等】

2023每日刷题&#xff08;四十一&#xff09; Leetcode—18.四数之和 实现代码 class Solution { public:vector<vector<int>> fourSum(vector<int>& nums, int target) {vector<vector<int>> ans;sort(nums.begin(), nums.end());int n …

chatgpt prompt提示词

ChatGPT 最近十分火爆&#xff0c;今天我也来让 ChatGPT 帮我阅读一下 Vue3 的源代码。 都知道 Vue3 组件有一个 setup函数。那么它内部做了什么呢&#xff0c;今天跟随 ChatGPT 来一探究竟。 实战 1.setup setup 函数在什么位置呢&#xff0c;我们不知道他的实现函数名称&…