RetinaNet+focal loss

one stage 精度不高,一个主要原因是正负样本的不平衡,以YOLO为例,每个grid cell有5个预测,本来正负样本的数量就有差距,再相当于进行5倍放大后,这种数量上的差异更会被放大。

文中提出新的分类损失函数Focal loss,该损失函数通过抑制那些容易分类样本的权重,将注意力集中在那些难以区分的样本上,有效控制正负样本比例,防止失衡现象。也就是focal loss用于解决正负样本不平衡与难易样本不平衡的问题.

其中用于控制正负样本的数量失衡,用于控制简单/难区分样本数量失衡。一般=0.25,=2.也就是正样本loss相对增加,负样本loss相对减少,负样本相比正样本loss减少的倍数为3,同时困难样本loss相对增加,简单样本loss相对减少.

模型采用FPN,P3到P7,其中P7能够增加对大物体的检测。

在FPN的P3-P7中分别设置32x32-512x512尺寸不等的anchor,比例设置为{1:2, 1:1, 2:1}。每一层一共有9个anchor,不同层能覆盖的size范围为32-813。对每一个anchor,都对应一个K维的one-hot向量(K是类别数)和4维的位置回归向量。

同时分类子网对A个anchor,每个anchor中的K个类别,都预测一个存在概率。如下图所示,对于FPN的每一层输出,对分类子网来说,加上四层3x3x256卷积的FCN网络,最后一层的卷积稍有不同,用3x3xKA,最后一层维度变为KA表示,对于每个anchor,都是一个K维向量,表示每一类的概率,然后因为one-hot属性,选取概率得分最高的设为1,其余k-1为归0。传统的RPN在分类子网用的是1x1x18,只有一层,而在RetinaNet中,用的是更深的卷积,总共有5层,实验证明,这种卷积层的加深,对结果有帮助。与分类子网并行,对每一层FPN输出接上一个位置回归子网,该子网本质也是FCN网络,预测的是anchor和它对应的一个GT位置的偏移量。首先也是4层256维卷积,最后一层是4A维度,即对每一个anchor,回归一个(x,y,w,h)四维向量。注意,此时的位置回归是类别无关的。分类和回归子网虽然是相似的结构,但是参数是不共享的

代码:

正负样本计算loss的两种方式


import torch
import torch.nn.functional as Fdef focal_loss_one(alpha, beta, cls_preds, gts):print('======第一种实现方式=======')num_pos = gts.sum()print('==num_pos:', num_pos)alpha_tensor = torch.ones_like(cls_preds) * alphaalpha_tensor = torch.where(torch.eq(gts, 1.), alpha_tensor, 1. - alpha_tensor)print('===alpha_tensor===', alpha_tensor)preds = torch.where(torch.eq(gts, 1.), cls_preds, 1. - cls_preds)print('===1. - preds===', 1. - preds)focal_weight = alpha_tensor * torch.pow((1. - preds), beta)print('==focal_weight:', focal_weight)batch_bce_loss = -(gts * torch.log(cls_preds) + (1. - gts) * torch.log(1. - cls_preds))batch_focal_loss = focal_weight * batch_bce_lossprint('==batch_focal_loss:', batch_focal_loss)batch_focal_loss = batch_focal_loss.sum()print('== batch_focal_loss:', batch_focal_loss)print('==batch_focal_loss.item():', batch_focal_loss.item())if num_pos != 0:mean_batch_focal_loss = batch_focal_loss / num_poselse:mean_batch_focal_loss = batch_focal_lossprint('==mean_batch_focal_loss:', mean_batch_focal_loss)def focal_loss_two(alpha, beta, cls_preds, gts):print('======第二种实现方式=======')pos_inds = (gts == 1.0).float()print('==pos_inds:', pos_inds)neg_inds = (gts != 1.0).float()print('===neg_inds:', neg_inds)pos_loss = -pos_inds * alpha * (1.0 - cls_preds) ** beta * torch.log(cls_preds)neg_loss = -neg_inds * (1 - alpha) * ((cls_preds) ** beta) * torch.log(1.0 - cls_preds)num_pos = pos_inds.float().sum()print('==num_pos:', num_pos)pos_loss = pos_loss.sum()neg_loss = neg_loss.sum()if num_pos == 0:mean_batch_focal_loss = neg_losselse:mean_batch_focal_loss = (pos_loss + neg_loss) / num_posprint('==mean_batch_focal_loss:', mean_batch_focal_loss)def focal_loss_three(alpha, beta, cls_preds, gts):print('======第三种实现方式=======')num_pos = gts.sum()pred_sigmoid = cls_predstarget = gts.type_as(pred_sigmoid)pt = (1 - pred_sigmoid) * target + pred_sigmoid * (1 - target)focal_weight = (alpha * target + (1 - alpha) *(1 - target)) * pt.pow(beta)batch_focal_loss = F.binary_cross_entropy(pred_sigmoid, target, reduction='none') * focal_weightbatch_focal_loss = batch_focal_loss.sum()if num_pos != 0:mean_batch_focal_loss = batch_focal_loss / num_poselse:mean_batch_focal_loss = batch_focal_lossprint('==mean_batch_focal_loss:', mean_batch_focal_loss)
bs = 2
num_class = 3
alpha = 0.25
beta = 2
# (B, cls)
cls_preds = torch.rand([bs, num_class], dtype=torch.float)
print('==cls_preds:', cls_preds)
gts = torch.tensor([0, 2])
# (B, cls)
gts = F.one_hot(gts, num_classes=num_class).type_as(cls_preds)
print('===gts===', gts)
focal_loss_one(alpha, beta, cls_preds, gts)
focal_loss_two(alpha, beta, cls_preds, gts)
focal_loss_three(alpha, beta, cls_preds, gts)

只有正样本计算loss:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variableclass FocalLoss(nn.Module):"""This criterion is a implemenation of Focal Loss, which is proposed inFocal Loss for Dense Object Detection.Loss(x, class) = - \alpha (1-softmax(x)[class])^gamma \log(softmax(x)[class])The losses are averaged across observations for each minibatch.Args:alpha(1D Tensor, Variable) : the scalar factor for this criteriongamma(float, double) : gamma > 0; reduces the relative loss for well-classified examples (p > .5),putting more focus on hard, misclassified examplessize_average(bool): By default, the losses are averaged over observations for each minibatch.However, if the field size_average is set to False, the losses areinstead summed for each minibatch."""def __init__(self, class_num, alpha=None, gamma=2, size_average=True):super(FocalLoss, self).__init__()if alpha is None:self.alpha = Variable(torch.ones(class_num, 1))else:if isinstance(alpha, Variable):self.alpha = alphaelse:self.alpha = Variable(alpha)self.gamma = gammaself.class_num = class_numself.size_average = size_averagedef forward(self, inputs, targets):N = inputs.size(0)C = inputs.size(1)P = F.softmax(inputs, dim=-1)print('===P:', P)#.data 获取variable的tensorclass_mask = inputs.data.new(N, C).fill_(0)class_mask = Variable(class_mask)ids = targets.view(-1, 1)class_mask.scatter_(1, ids.data, 1.)#得到onehotprint('==class_mask:', class_mask)if inputs.is_cuda and not self.alpha.is_cuda:self.alpha = self.alpha.cuda()alpha = self.alpha[ids.data.view(-1)]print('==alpha:', alpha)probs = (P*class_mask).sum(1).view(-1, 1)print('==probs:', probs)log_p = probs.log()batch_loss = -alpha*(torch.pow((1-probs), self.gamma))*log_pif self.size_average:loss = batch_loss.mean()else:loss = batch_loss.sum()return lossdef debug_focal():import numpy as np#只对困难样本计算lossloss = FocalLoss(class_num=8)#, alpha=torch.tensor([0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25]).reshape(-1, 1))inputs = torch.rand(2, 8)print('==inputs:', inputs)# print('==inputs.data:', inputs.data)# targets = torch.from_numpy(np.array([[1,0,0,0,0,0,0,0],#                                      [0,1,0,0,0,0,0,0]]))targets = torch.from_numpy(np.array([0, 1]))cost = loss(inputs, targets)print('===cost===:', cost)if __name__ == '__main__':debug_focal()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/493259.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

真实用户首次披露Waymo无人车服务体验: 为避开左转, 故意绕路

来源 :Ars Technica编译 :机器之能 高璇外国网友炸了:「就像看了一部大导演导的烂片一样。」在过去的 18 个月里,Waymo 的汽车一直在凤凰城的东南角运送乘客。该公司在合同中明确规定禁止乘客讨论用户体验,对项目信息进…

“横平竖直”进行连线+将相邻框进行合并

一.横平竖直”进行连线 解法1.将一些坐标点按照x相等,y相等连起来 解法1.根据 x或y总有一个相等的,用np.sum来找出和为1的点,然后在连起来,存在重复连线的问题. import numpy as npcoord np.array([[10, 60],[10, 20],[20, 20],[40, 40],[40, 60],[20, 40]])img np.zeros(…

一文看透汽车芯片!巨头布局技术路线全解密【附下载】| 智东西内参

来源:智东西摘要:一文看透汽车芯片!巨头布局技术路线全解密智能驾驶涉及人机交互、视觉处理、智能决策等,核心是 AI 算法和芯片。伴随汽车电子化提速,汽车半导体加速成长,2017 年全球市场规模 288 亿美元&a…

详细介绍软件架构设计的三个维度

如果你对项目管理、系统架构有兴趣,请加微信订阅号“softjg”,加入这个PM、架构师的大家庭 架构设计是一个非常大的话题,不管写几篇文章,接触到的始终只是冰山一角,更多的是实践中去体会。这篇文章主要介绍面向对象OO、…

中国智能语音行业研究

报告来源:中信证券作者:刘雯蜀 杨泽原 张若海智能语音作为人机交互的新型方式,有望大规模推广,中国市场是更适合语音交互的市场。2017年中国人工智能市场规模达约220亿元,智能语音占中国人工智能市场份额的22%&#…

SQL2012 附加数据库提示5120错误解决方法

在win8.1 x64系统上使用sql2012进行附加数据库(包括在x86系统正在使用的数据库文件,直接拷贝附加在X64系统中)时,提示无法打开文件,5120错误。 这个错误是因为没有操作权限,所以附加的时候出错,…

pytorch利用rnn通过sin预测cos 利用lstm预测手写数字

一.利用rnn通过sin预测cos 1.首先可视化一下数据 import numpy as np from matplotlib import pyplot as plt def show(sin_np,cos_np):plt.figure()plt.title(Sin and Cos, fontsize18)plt.plot(steps, sin_np, r-, labelsin)plt.plot(steps, cos_np, b-, labelcos)plt.lege…

高德纳咨询公司(Gartner)预测:2019年七大人工智能科技趋势

来源:创新研究摘要:人工智能技术对我们的工作环境、工作种类等等正在产生日益深刻的影响,其结果或好或坏都有可能。为应对这种改变,特别是负面的变化,高德纳咨询公司(Gartner)于2018年12月13日发…

美爆!《自然》公布2018年19张最震撼的科学图片

来源:前瞻网 摘要:2018年注定将载入科学史册:气候上,从加利福尼亚烧到开普敦的致命野火和极端干旱、历史罕见;医学上,克隆和成像技术的进步既带来希望,也产生了争议;生物上,一系列事件让人们意识…

python实现Trie 树+朴素匹配字符串+RK算法匹配字符串+kmp算法匹配字符串

一.trie树应用: 相应leetcode 常用于搜索提示,如当输入一个网址,可以自动搜索出可能的选择。当没有完全匹配的搜索结果,可以返回前缀最相似的可能。 例如三个单词app, apple, add,我们按照以下规则创建了一颗Trie树.对于从树的根…

天才也勤奋!DeepMind哈萨比斯自述:领导400名博士向前,每天工作至凌晨4点

来源:量子位你见过凌晨4点的伦敦吗?哈萨比斯天天见。这位DeepMind创始人、AlphaGo之父,一直是全球赞颂的当世天才,但每天要到凌晨4点,才能睡下。这是哈萨比斯最新采访中透露的作息时间,他告诉《星期日泰晤士…

RNN知识+LSTM知识+encoder-decoder+ctc+基于pytorch的crnn网络结构

一.基础知识: 下图是一个循环神经网络实现语言模型的示例,可以看出其是基于当前的输入与过去的输入序列,预测序列的下一个字符. 序列特点就是某一步的输出不仅依赖于这一步的输入,还依赖于其他步的输入或输…

利用flask写的接口(base64, 二进制, 上传视频流)+异步+gunicorn部署Flask服务+多gpu卡部署

一.flask写的接口 1.1 manage.py启动服务(发送图片base64版) 这里要注意的是用docker的话,记得端口映射 #coding:utf-8 import base64 import io import logging import picklefrom flask import Flask, jsonify, request from PIL import Image from sklearn import metric…

2018中国自动驾驶市场专题分析

来源:智车科技未来智能实验室是人工智能学家与科学院相关机构联合成立的人工智能,互联网和脑科学交叉研究机构。未来智能实验室的主要工作包括:建立AI智能系统智商评测体系,开展世界人工智能智商评测;开展互联网&#…

python写日志

需要再加入按照日期生成日志 #coding:utf-8 import logging import logging.handlers class Logger:logFile def __init__(self, logFile):self.logFile logFileself.logger logging.getLogger(mylogger)self.logger.setLevel(logging.INFO)rf_handler logging.handlers.…

MIT科学家Dimitri P. Bertsekas最新2019出版《强化学习与最优控制》(附书稿PDF讲义)...

来源:专知摘要:MIT科学家Dimitri P. Bertsekas今日发布了一份2019即将出版的《强化学习与最优控制》书稿及讲义,该专著目的在于探索这人工智能与最优控制的共同边界,形成一个可以在任一领域具有背景的人员都可以访问的桥梁。REINF…

yolov3 anchors用kmeans聚类出先验框+anchor宽高比分析

一.yolov v3聚类出框 # -*- coding: utf-8 -*- import numpy as np import random import argparse import os# # 参数名称 # parser argparse.ArgumentParser(description使用该脚本生成YOLO-V3的anchor boxes\n) # parser.add_argument(--input_annotation_txt…

Geoff Hinton:全新的想法将比微小的改进更有影响力

来源:AI科技评论摘要:日前,WIRED 对 Hinton 进行了一次专访,在访谈中,WIRED 针对人工智能带来的道德挑战和面临的挑战等问题进行了提问,以下为谈话内容。“作为一名谷歌高管,我认为在公开场合抱…

修改TOMCAT服务器图标为应用LOGO

在tomcat下部署应用程序,运行后,发现在地址栏中会显示tomcat的小猫咪图标。有时候,我们自己不想显示这个图标,想换成自己定义的的图标,那么按如下方法操作即可: 参考网上的解决方案:1、将$TOMCA…

python连接mysql的一些基础知识+安装Navicat可视化数据库+flask_sqlalchemy写数据库

一.mysql基础知识 1.connect连接数据库 import pymysqldef get_conn():conn pymysql.connect(hostxxx.xxx.xxx.xxx, port3306, userroot, passwd, dbnewspaper_rest) # db:表示数据库名称return conn 2.创建表 im…