YOLOv3改进方法增加特征尺度和训练层数

YOLOv3改进方法

YOLOv3的改进方法有很多,本文讲述的是增加一个特征尺度。
以YOLOv3-darknet53(ALexeyAB版本)为基础,增加了第4个特征尺度:104*104。原版YOLOv3网络结构:
原版YOLOv3网络结构
YOLOv3-4l网络结构:
YOLOv3-4l网络结构

即,在经过2倍上采样后,输出的特征尺度由52x52提升至104x104,再通过route层将第109层与特征提取网络的第11层特征进行特征融合,以充分利用深层特征和浅层特征。其余的特征融合分别为:2倍上采样后输出的第85层和第97层。通过route层分别将第85层与第61层,第97层与第36层的特征图进行特征融合。四个特征尺度分别为:104x104,52x52,26x26和13x13。

具体的步骤为:

(1)修改配置文件cfg
再增加一个检测尺度(在原yolov3的最后一层yolo层的后面,再增加一个检测层:在下方链接里的cfg文件最后的“#######”注释行之后的部分,便是增加的检测层结构)。
yolo-4l的cfg下载地址
链接:https://pan.baidu.com/s/1b92jmcAPTgzxua4Pat7p4A
提取码:xji2
注意:网盘里提供的cfg配置文件,需要进行相应参数修改(修改示例见链接:yolov3的cfg配置文件注释及修改示例)。

(2)重新计算anchors
由于原先是3个检测尺度共9个anchors,此时是4层共12个anchors。且不同数据库的anchors值不一样(比如自行构建的数据库),所以必须重新计算anchors,并更新到cfg文件中。如果选取的先验框维度比较合适,那么模型就会更容易学习,更易收敛,从而做出更好的预测,预测框与标注真实框的IOU就会更好。

计算数据库的anchors的命令为:

./darknet detector calc_anchors /usr/cx/darknetalexeyAB/darknet-master/names_data/voc.data -num_of_clusters 12 -width 416 -height 416 -show 1 

注意:/usr/cx/darknetalexeyAB/darknet-master/names_data/voc.data是我们自己的voc.data的路径,根据自己的项目自行进行修改。
结果如下图所示:
在这里插入图片描述
计算多次,每次的anchors值会不一样,但基本相差无几。其实这些anchors值,就是先验框,就是样本库里最经常出现的几类边界框。通过选取专属于实际数据库的anchors,将会加速收敛,更容易学习,提高IOU值。

这里还有另外1种计算anchors的方法, 通过脚本文件来计算anchors锚点值。脚本文件如下所示:

# coding=utf-8
# 通过k-means ++ 算法获取anchors的尺寸
import numpy as np# 定义Box类,描述bounding box的坐标
class Box():def __init__(self, x, y, w, h):self.x = xself.y = yself.w = wself.h = h# 计算两个box在某个轴上的重叠部分
# x1是box1的中心在该轴上的坐标
# len1是box1在该轴上的长度
# x2是box2的中心在该轴上的坐标
# len2是box2在该轴上的长度
# 返回值是该轴上重叠的长度
def overlap(x1, len1, x2, len2):len1_half = len1 / 2len2_half = len2 / 2left = max(x1 - len1_half, x2 - len2_half)right = min(x1 + len1_half, x2 + len2_half)return right - left# 计算box a 和box b 的交集面积
# a和b都是Box类型实例
# 返回值area是box a 和box b 的交集面积
def box_intersection(a, b):w = overlap(a.x, a.w, b.x, b.w)h = overlap(a.y, a.h, b.y, b.h)if w < 0 or h < 0:return 0area = w * hreturn area# 计算 box a 和 box b 的并集面积
# a和b都是Box类型实例
# 返回值u是box a 和box b 的并集面积
def box_union(a, b):i = box_intersection(a, b)u = a.w * a.h + b.w * b.h - ireturn u# 计算 box a 和 box b 的 iou
# a和b都是Box类型实例
# 返回值是box a 和box b 的iou
def box_iou(a, b):return box_intersection(a, b) / box_union(a, b)# 使用k-means ++ 初始化 centroids,减少随机初始化的centroids对最终结果的影响
# boxes是所有bounding boxes的Box对象列表
# n_anchors是k-means的k值
# 返回值centroids 是初始化的n_anchors个centroid
def init_centroids(boxes,n_anchors):centroids = []boxes_num = len(boxes)centroid_index = np.random.choice(boxes_num, 1)centroids.append(boxes[centroid_index])print(centroids[0].w,centroids[0].h)for centroid_index in range(0,n_anchors-1):sum_distance = 0distance_thresh = 0distance_list = []cur_sum = 0for box in boxes:min_distance = 1for centroid_i, centroid in enumerate(centroids):distance = (1 - box_iou(box, centroid))if distance < min_distance:min_distance = distancesum_distance += min_distancedistance_list.append(min_distance)distance_thresh = sum_distance*np.random.random()for i in range(0,boxes_num):cur_sum += distance_list[i]if cur_sum > distance_thresh:centroids.append(boxes[i])print(boxes[i].w, boxes[i].h)breakreturn centroids# 进行 k-means 计算新的centroids
# boxes是所有bounding boxes的Box对象列表
# n_anchors是k-means的k值
# centroids是所有簇的中心
# 返回值new_centroids 是计算出的新簇中心
# 返回值groups是n_anchors个簇包含的boxes的列表
# 返回值loss是所有box距离所属的最近的centroid的距离的和
def do_kmeans(n_anchors, boxes, centroids):loss = 0groups = []new_centroids = []for i in range(n_anchors):groups.append([])new_centroids.append(Box(0, 0, 0, 0))for box in boxes:min_distance = 1group_index = 0for centroid_index, centroid in enumerate(centroids):distance = (1 - box_iou(box, centroid))if distance < min_distance:min_distance = distancegroup_index = centroid_indexgroups[group_index].append(box)loss += min_distancenew_centroids[group_index].w += box.wnew_centroids[group_index].h += box.hfor i in range(n_anchors):new_centroids[i].w /= len(groups[i])new_centroids[i].h /= len(groups[i])return new_centroids, groups, loss# 计算给定bounding boxes的n_anchors数量的centroids
# label_path是训练集列表文件地址
# n_anchors 是anchors的数量
# loss_convergence是允许的loss的最小变化值
# grid_size * grid_size 是栅格数量
# iterations_num是最大迭代次数
# plus = 1时启用k means ++ 初始化centroids
def compute_centroids(label_path,n_anchors,loss_convergence,grid_size,iterations_num,plus):boxes = []label_files = []f = open(label_path)for line in f:label_path = line.rstrip().replace('images', 'labels')label_path = label_path.replace('JPEGImages', 'labels')label_path = label_path.replace('.jpg', '.txt')label_path = label_path.replace('.JPEG', '.txt')label_files.append(label_path)f.close()for label_file in label_files:f = open(label_file)for line in f:temp = line.strip().split(" ")if len(temp) > 1:boxes.append(Box(0, 0, float(temp[3]), float(temp[4])))if plus:centroids = init_centroids(boxes, n_anchors)else:centroid_indices = np.random.choice(len(boxes), n_anchors)centroids = []for centroid_index in centroid_indices:centroids.append(boxes[centroid_index])# iterate k-meanscentroids, groups, old_loss = do_kmeans(n_anchors, boxes, centroids)iterations = 1while (True):centroids, groups, loss = do_kmeans(n_anchors, boxes, centroids)iterations = iterations + 1print("loss = %f" % loss)if abs(old_loss - loss) < loss_convergence or iterations > iterations_num:breakold_loss = lossfor centroid in centroids:print(centroid.w * grid_size, centroid.h * grid_size)# print resultfor centroid in centroids:print("k-means result:\n")print(centroid.w * grid_size, centroid.h * grid_size)
#只需修改这里的参数n_anchors和grid_size;得到的9个预选框的参数复制到cfg即可 
#要修改的路径--训练集train.txt的路径
#label_path = "/home/chris/darknet/scripts/2007_train.txt"
label_path = "/usr/cx/darknetalexeyAB/names_data/2007_train.txt"
n_anchors = 9          #预选框anchors的个数,6,9,12,15,根据自己的实际项目进行设置;
loss_convergence = 1e-6
grid_size = 416        #栅格的尺寸
iterations_num = 100    #迭代的步数
plus = 0                 #开关;=1时,使用k-means++算法,一般=0。
compute_centroids(label_path,n_anchors,loss_convergence,grid_size,iterations_num,plus)

脚本文件(命名为k-means.py)
运行python k-means.py即可。
注意修改路径。代码注释中已经标出。

(3)anchors值替换
在cfg文件的每个yolo层,进行如下修改:
1)mask取值变为0~11,3个为一组,最前面一层yolo层的mask赋值为9,10,11
2)将第二行的anchors值更新替换成步骤(2)中计算得到的anchors值;
3)classes是类别数,此项目仅有1个类别,根据自己的项目修改classes的值
3)将num=9改成num=12

[yolo]
mask = 9,10,11
anchors = 34, 57,  77,110, 145,155, 174,220, 177,324, 212,273, 281,212, 356,206, 241,316, 329,265, 399,265, 346,33
classes=1
num=12
jitter=.3
ignore_thresh = .7
truth_thresh = 1
random=1

(4)模型训练
经过增加cfg配置文件的检测层,计算anchors,并将其更新到cfg配置文件中之后,接下来就可以进行模型的训练了。
注意:由于我们没有对backbone基础网络进行修改,所以,可以使用darknet53.conv.74预训练权重进行训练。
darknet53.conv.74下载链接如下:
darknet53.conv.74权重文件
链接:https://pan.baidu.com/s/14Hwqqsp_ua28Xu27gaQk6g
提取码:dnai

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/547398.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

uva 610(tarjan的应用)

题目链接&#xff1a;http://acm.hust.edu.cn/vjudge/problem/viewProblem.action?id23727 思路&#xff1a;首先是Tarjan找桥&#xff0c;对于桥&#xff0c;只能是双向边&#xff0c;而对于同一个连通分量而言&#xff0c;只要重新定向为同一个方向即可。 1 #include<ios…

Win7搭建NodeJs开发环境以及HelloWorld展示—图解

Windows 7系统下搭建NodeJs开发环境&#xff08;NodeJsWebStrom&#xff09;以及Hello World&#xff01;展示&#xff0c;大体思路如下&#xff1a;第一步&#xff1a;安装NodeJs运行环境。第二步&#xff1a;安装WebStrom开发工具。第三步&#xff1a;创建并运行NodeJs项目展…

IDEA中Spring Boot项目报错:There was an unexpected error (type=Not Found, status=404)

报这个错的原因是SpringBoot主配置类(SpringBootApplication注解标注的类)的所在包和你的Controller类不在同一个包下

计算机一级windows7操作,计算机等级一级:Windows7应用之小技巧

为了帮助广大考生更好的复习&#xff0c;帮考网综合整理提供了计算机等级考试一级微机知识:Windows7应用之小技巧&#xff0c;以供各位考生复习参考&#xff0c;希望对考生复习有所帮助。从年初开始&#xff0c;断断续续的&#xff0c;windows7使用也有一段时间&#xff0c;碰到…

漂亮的jQuery tab选项卡插件

清远大学城网&#xff08;http://www.qydxc.com&#xff09; tab选项卡在实际应用中几乎到处可见&#xff0c;像现在大型网站163&#xff0c;QQ&#xff0c;新浪&#xff0c;淘宝都使用了tab选项卡效果&#xff0c;下面我来介绍一款jQuery tab选项卡插件. jQuery tab插件 结构…

清空文件下的SVN控制文件

代码如下&#xff0c;复制代码为txt文件&#xff0c;更改后缀为“.bat”&#xff0c;把文件放到&#xff0c;需要删除的文件的顶端文件夹内&#xff0c;点击执行。 echo on color 2f mode con: cols80 lines25 REM echo 正在清理SVN文件&#xff0c;请稍候...... rem 循环删除当…

我国高性能计算机发展,中国高性能计算机发展水平与趋势

“目前&#xff0c;我国的高性能计算已达到世界先进水平&#xff0c;成为继美、日之后&#xff0c;被誉为世界高性能计算的‘第三股力量’。但是&#xff0c;从发展总体上&#xff0c;我们也清醒的看到&#xff0c;我国高性能计算在更广泛的应用领域上与西方国家存在很大差距。…

pandas无法打开.xlsx文件,xlrd.biffh.XLRDError: Excel xlsx file; not supported

pandas无法打开.xlsx文件&#xff0c;xlrd.biffh.XLRDError: Excel xlsx file&#xff1b; not supported 新版xlrd报 Excel xlsx file&#xff1b; not supported 原因是最近xlrd更新到了2.0.1版本&#xff0c;只支持.xls文件。所以pandas.read_excel(‘xxx.xlsx’)会报错。…

MySQL5.6忘记root密码(win平台)

1、首先net stop mysql服务&#xff0c;并且切换到任务管理器&#xff0c;有与mysql有关的&#xff0c;最好关闭进程。 2、运行CMD命令切换到MySql安装bin目录&#xff0c;下面是我的mysql安装目录 cd C:\Program Files\MySQL\MySQL Server 5.6\bin 接着执行mysqld --skip-gra…

JS计算本周一和本周五的日期

代码不长&#xff1a; var todaynew Date();var weekdaytoday.getDay(); var mondaynew Date(1000*60*60*24*(1-weekday) today.getTime()); var fridaynew Date(1000*60*60*24*(5-weekday) today.getTime()); 目前monday和friday都是Date类型的&#xff0c;要得到字符…

单片机四位数加减计算机程序,51单片机简易计算器程序 实现数字的加减乘除运算...

//头文件#define uint unsigned int#define uchar unsigned charsbit lcdenP1^1; //LCD1602控制引脚sbit rsP1^0;sbit rwP1^2;sbit busyP0^7;//LCD忙char i,j,temp,num,num_1;long a,b,c; //a,第一个数 b,第二个数 c,得数float a_c,b_c;uchar flag,fuhao;//flag表示是否有运…

Python函数定义变量报错:local variable ‘a‘ referenced before assignment

Python 全局变量与global关键字 ​ 在Python的变量使用中&#xff0c;经常会遇到这样的错误: local variable a referenced before assignment它的意思是&#xff1a;局部变量“a”在赋值前就被引用了。 ​ 比如运行下面的代码就会出现这样的问题&#xff1a; a 3 def Fuc(…

MySQL数据库工具类之——DataTable批量加入MySQL数据库(Net版)

MySQL数据库工具类之——DataTable批量加入数据库(Net版)&#xff0c;MySqlDbHelper通用类希望能对大家有用&#xff0c;代码如下&#xff1a; using MySql.Data.MySqlClient;using System;using System.Collections.Generic;using System.Configuration;using System.Data;usi…

显示播客信息-bloginfo() 函数

该标签显示用户博客的相关信息&#xff0c;这些信息通常来自用户在WordPress网站后台“我的配置”和“设置>常规”菜单中填写的内容。该标签可以用在页面模板的任何区域内&#xff0c;且该标签总是将结果输出给浏览器。如果用户需要将输出内容用在PHP中&#xff0c;请使用ge…

计算机的访问资料,怎么从一台电脑访问另一台电脑上的资料?

如果两台电脑用路由器上网&#xff0c;可以按下面方法设置&#xff0c;如果没有可以用网线和网卡连接。用一根网线让两台电脑共享文件网卡连接&#xff0c;首先准备好两张10/100m的网卡。然后&#xff0c;准备几米长的网线&#xff0c;具体长度由你决定。按特定的方式接好插头。…

在pandas中遍历DataFrame行

有如下 Pandas DataFrame&#xff1a; import pandas as pd inp [{c1:10, c2:100}, {c1:11,c2:110}, {c1:12,c2:120}] df pd.DataFrame(inp) print df 上面代码输出&#xff1a; c1 c2 0 10 100 1 11 110 2 12 120 现在需要遍历上面DataFrame的行。对于每一行&#x…

pomelo获取客户端IP

代码&#xff1a; Handler.prototype.getClientIp function(msg, session, next) {var ip session.__session__.__socket__.remoteAddress.ipconsole.log(ip);}

linux内核2.6.35编译过程

一、实验目的 学习重新编译Linux内核&#xff0c;理解、掌握Linux内核和发行版本的区别。 二、实验内容 在Linux操作系统环境下重新编译内核。实验主要内容&#xff1a; A. 查找并且下载一份内核源代码&#xff0c;本实验使用最新的Linux内核2.6.36。 B. 配置内核。 C. 编…

MySQL索引的Index method中btree和hash的区别

2019独角兽企业重金招聘Python工程师标准>>> 在MySQL中&#xff0c;大多数索引&#xff08;如 PRIMARY KEY,UNIQUE,INDEX和FULLTEXT&#xff09;都是在BTREE中存储&#xff0c;但使用memory引擎可以选择BTREE索引或者HASH索引&#xff0c;两种不同类型的索引各自有其…

软件研发测试工程师英文怎么说,软件测试工程师面试英文自我介绍

《软件测试工程师面试英文自我介绍》由会员分享&#xff0c;可在线阅读&#xff0c;更多相关《软件测试工程师面试英文自我介绍(4页珍藏版)》请在人人文库网上搜索。1、软件测试工程师面试英文自我介绍篇一、范文Im , Im twenty-six year old, I majored in E-business and wit…