YOLOv3 anchors：用 k-means 聚类出先验框 + anchor 宽高比分析

一. 用 k-means 聚类出 YOLOv3 的 anchor 框

# -*- coding: utf-8 -*-
import numpy as np
import random
import argparse
import os# # 参数名称
# parser = argparse.ArgumentParser(description='使用该脚本生成YOLO-V3的anchor boxes\n')
# parser.add_argument('--input_annotation_txt_dir', required=True, type=str, help='输入存储图片的标注txt文件(注意不要有中文)')
# parser.add_argument('--output_anchors_txt', required=True, type=str, help='输出的存储Anchor boxes的文本文件')
# parser.add_argument('--input_num_anchors', required=True, default=6, type=int, help='输入要计算的聚类(Anchor boxes的个数)')
# parser.add_argument('--input_cfg_width', required=True, type=int, help="配置文件中width")
# parser.add_argument('--input_cfg_height', required=True, type=int, help="配置文件中height")
# args = parser.parse_args()
# print('args:', args)
def IOU(annotation_array, centroids):
    """Return the IoU between one box and each centroid, comparing sizes only.

    Both boxes are treated as sharing the same top-left corner, so only their
    width and height matter.

    Args:
        annotation_array: (w, h) of a single annotation box.
        centroids: array of shape (num, 2) holding the (w, h) of each centroid.

    Returns:
        float32 ndarray of shape (num,), one IoU per centroid.
    """
    w, h = annotation_array
    centroids = np.asarray(centroids, dtype=np.float64)
    c_w = centroids[:, 0]
    c_h = centroids[:, 1]
    # With a shared corner the overlap is min(w, c_w) x min(h, c_h). This one
    # formula covers every containment case: centroid larger, smaller, or
    # larger in only one dimension.
    intersection = np.minimum(w, c_w) * np.minimum(h, c_h)
    union = w * h + c_w * c_h - intersection
    return np.asarray(intersection / union, dtype=np.float32)


def k_means(annotations_array, centroids, eps=0.00005, iterations=200000):
    """Cluster box sizes with k-means, using 1 - IoU as the distance.

    ``centroids`` is updated in place (the caller reads the result from it).
    It is also returned for convenience.

    Args:
        annotations_array: ndarray of shape (N, 2) with the (w, h) of every box.
        centroids: ndarray of shape (num, 2) holding the initial centres.
            It is overwritten with the final centres.
        eps: stop once the total distance changes by less than this between
            two iterations.
        iterations: stop after more than this many iterations.

    Returns:
        The same ``centroids`` array that was passed in.

    Raises:
        ValueError: if ``annotations_array`` contains no boxes.
    """
    N = annotations_array.shape[0]
    num = centroids.shape[0]
    if N == 0:
        raise ValueError("annotations_array is empty")
    distance_sum_pre = -1
    assignments_pre = -1 * np.ones(N, dtype=np.int64)
    iteration = 0
    while True:
        iteration += 1
        # distances_array[i, j] = 1 - IoU(box i, centroid j); shape (N, num).
        distances_array = np.array(
            [1 - IOU(annotations_array[i], centroids) for i in range(N)],
            np.float32)
        # Index of the closest centroid for every box.
        assignments = np.argmin(distances_array, axis=1)
        # Loss = total distance from each box to its *nearest* centroid.
        min_distances = distances_array.min(axis=1)
        distances_sum = np.sum(min_distances)
        # Move each centroid to the mean size of its members.
        for j in range(num):
            members = annotations_array[assignments == j]
            # An empty cluster keeps its previous centre instead of
            # becoming NaN through a division by zero.
            if len(members) > 0:
                centroids[j] = members.mean(axis=0)
        diff = abs(distances_sum - distance_sum_pre)
        avg_iou = np.mean(1 - min_distances)
        print("iteration: {},distance: {}, diff: {}, avg_IOU: {}\n".format(
            iteration, distances_sum, diff, avg_iou))
        # Stop when the assignment is stable, the loss has converged, or
        # the iteration budget is used up.
        if (assignments == assignments_pre).all():
            print("按照前后两次的得到的聚类结果是否相同结束循环\n")
            break
        if diff < eps:
            print("按照eps结束循环\n")
            break
        if iteration > iterations:
            print("按照迭代次数结束循环\n")
            break
        distance_sum_pre = distances_sum
        assignments_pre = assignments.copy()
    return centroids


if __name__ == '__main__':
    # Settings that the commented-out argparse options near the top would supply.
    num_clusters = 9                          # --input_num_anchors
    annotation_dir = 'train_images_tif_txt'   # --input_annotation_txt_dir
    output_anchors_txt = './anchors_txt.txt'  # --output_anchors_txt
    cfg_width = 200                           # --input_cfg_width
    cfg_height = 1800                         # --input_cfg_height

    names = [i for i in os.listdir(annotation_dir) if i.endswith('.txt')]
    print('names:', names)
    # (w, h) of every labelled box.
    annotations_w_h = []
    for name in names:
        txt_path = os.path.join(annotation_dir, name)
        with open(txt_path, 'r') as f:
            for line in f:
                fields = line.split()
                # Label line: class cx cy w h (normalised); skip blank lines.
                if len(fields) < 5:
                    continue
                # float() rather than eval(): label files are data, not code.
                w, h = float(fields[3]), float(fields[4])
                # Zero-size boxes would divide by zero in IOU.
                if w > 0 and h > 0:
                    annotations_w_h.append((w, h))
    annotations_array = np.array(annotations_w_h, dtype=np.float32)
    N = annotations_array.shape[0]
    if N < num_clusters:
        raise ValueError('need at least %d boxes, found %d' % (num_clusters, N))
    # Sample without replacement so the same box is never picked twice.
    random_indices = random.sample(range(N), num_clusters)
    centroids = annotations_array[random_indices]
    k_means(annotations_array, centroids, 0.00005, 200000)
    # Sort anchors by width before writing.
    sorted_indices = np.argsort(centroids[:, 0])
    anchors = centroids[sorted_indices]
    print('anchors:', anchors)
    # Scale normalised sizes to the network input size. The list is written
    # without a trailing comma; darknet-style cfg parsers appear to count
    # commas to size the anchor list (NOTE(review): confirm for your framework).
    with open(output_anchors_txt, 'w') as f_anchors:
        f_anchors.write(', '.join(
            '%d,%d' % (int(round(anchor[0] * cfg_width)),
                       int(round(anchor[1] * cfg_height)))
            for anchor in anchors))
        f_anchors.write('\n')

train_images_tif_txt 目录下存放的是如下所示的 YOLO 格式标注 txt 文件（每行依次为：类别 cx cy w h，其中 cx、cy、w、h 均已归一化）。

二.宽高比分析

1.kmeans.py代码

import numpy as np


def iou(box, clusters):
    """Calculate the IoU between one box and each of k clusters.

    Boxes are compared by size only, as if all were shifted to the origin.

    :param box: tuple or array (width, height)
    :param clusters: numpy array of shape (k, 2) where k is the number of clusters
    :return: numpy array of shape (k,) with one IoU per cluster
    :raises ValueError: if the box or any cluster has zero width or height
    """
    x = np.minimum(clusters[:, 0], box[0])
    y = np.minimum(clusters[:, 1], box[1])
    if np.any(x == 0) or np.any(y == 0):
        raise ValueError("Box has no area")
    intersection = x * y
    box_area = box[0] * box[1]
    cluster_area = clusters[:, 0] * clusters[:, 1]
    return intersection / (box_area + cluster_area - intersection)


def avg_iou(boxes, clusters):
    """Return the mean, over all boxes, of each box's best IoU with any cluster.

    :param boxes: numpy array of shape (r, 2), where r is the number of rows
    :param clusters: numpy array of shape (k, 2) where k is the number of clusters
    :return: average IoU as a single float
    """
    return np.mean([np.max(iou(box, clusters)) for box in boxes])


def translate_boxes(boxes):
    """Convert corner boxes into (width, height) pairs.

    :param boxes: numpy array of shape (r, 4) laid out as (x1, y1, x2, y2)
    :return: numpy array of shape (r, 2) with absolute width and height
    """
    new_boxes = boxes.copy()
    new_boxes[:, 2:4] = np.abs(new_boxes[:, 2:4] - new_boxes[:, 0:2])
    return np.delete(new_boxes, [0, 1], axis=1)


def kmeans(boxes, k, dist=np.median, seed=None):
    """Calculate k-means clustering with 1 - IoU as the distance.

    :param boxes: numpy array of shape (r, 2), where r is the number of rows
    :param k: number of clusters (must not exceed r)
    :param dist: reducer for a cluster's new centre, called as dist(members, axis=0)
    :param seed: optional seed for the initial cluster choice; None uses fresh OS entropy
    :return: float numpy array of shape (k, 2)
    """
    rows = boxes.shape[0]
    distances = np.empty((rows, k))
    # -1 never equals a real cluster index, so the first pass always updates.
    last_clusters = np.full((rows,), -1)
    rng = np.random.RandomState(seed)
    # The Forgy method will fail if the whole array contains the same rows.
    # Cast to float so median centres (e.g. 10.5) are not truncated for int input.
    clusters = boxes[rng.choice(rows, k, replace=False)].astype(np.float64)
    while True:
        for row in range(rows):
            distances[row] = 1 - iou(boxes[row], clusters)
        nearest_clusters = np.argmin(distances, axis=1)
        if (last_clusters == nearest_clusters).all():
            break
        for cluster in range(k):
            members = boxes[nearest_clusters == cluster]
            # An empty cluster keeps its centre; dist() of nothing would be NaN.
            if len(members) > 0:
                clusters[cluster] = dist(members, axis=0)
        last_clusters = nearest_clusters
    return clusters

2.example.py代码

import glob
import xml.etree.ElementTree as ET
import cv2
import os
import numpy as np
import matplotlib.pyplot as plt
from kmeans import kmeans, avg_iou# ANNOTATIONS_PATH = "./data/pascalvoc07-annotations"
# Annotation directory. It is only referenced by a commented-out print below;
# the load_dataset call further down reads './train_img' instead.
ANNOTATIONS_PATH = "./data/widerface-annotations"
# Number of anchors (k) to cluster.
CLUSTERS = 9
# Whether box sizes stay normalised relative to the original image (BBOX_NORMALIZE).
# Whether box sizes are kept normalised relative to the original image.
BBOX_NORMALIZE = True


def show_cluster(data, cluster, max_points=2000):
    '''Scatter-plot box sizes, with the anchors drawn on top.

    data: (N, 2) array of box [w, h], subsampled to at most ``max_points``.
    cluster: (k, 2) array of anchor [w, h].
    Saves the plot to ``cluster.png`` and then shows it.
    '''
    if len(data) > max_points:
        # Sample without replacement so no box is plotted twice.
        idx = np.random.choice(len(data), max_points, replace=False)
        data = data[idx]
    plt.scatter(data[:, 0], data[:, 1], s=5, c='lavender')
    plt.scatter(cluster[:, 0], cluster[:, 1], c='red', s=100, marker="^")
    plt.xlabel("Width")
    plt.ylabel("Height")
    plt.title("Bounding and anchor distribution")
    plt.savefig("cluster.png")
    plt.show()


def show_width_height(data, cluster, bins=50):
    '''Plot histograms of box width, height and height/width ratio.

    data: (N, 2) array of box [w, h]. ``cluster`` is not used.
    Saves the plot to ``shape-distribution.png`` and then shows it.
    '''
    if data.dtype != np.float32:
        data = data.astype(np.float32)
    width = data[:, 0]
    height = data[:, 1]
    ratio = height / width
    # Use a fresh figure rather than figure 1. With a non-interactive backend,
    # figure 1 may still hold the scatter drawn by show_cluster.
    plt.figure(figsize=(20, 6))
    plt.subplot(131)
    plt.hist(width, bins=bins, color='green')
    plt.xlabel('width')
    plt.ylabel('number')
    plt.title('Distribution of Width')
    plt.subplot(132)
    plt.hist(height, bins=bins, color='blue')
    plt.xlabel('Height')
    plt.ylabel('Number')
    plt.title('Distribution of Height')
    plt.subplot(133)
    plt.hist(ratio, bins=bins, color='magenta')
    plt.xlabel('Height / Width')
    plt.ylabel('number')
    plt.title('Distribution of aspect ratio(Height / Width)')
    plt.savefig("shape-distribution.png")
    plt.show()


def sort_cluster(cluster):
    '''Sort anchors by area (small to big) and append their height/width ratio.

    cluster: (k, 2) array of anchor [w, h].
    Returns a float32 (k, 3) array of rows [w, h, h / w].
    '''
    if cluster.dtype != np.float32:
        cluster = cluster.astype(np.float32)
    area = cluster[:, 0] * cluster[:, 1]
    cluster = cluster[area.argsort()]
    ratio = cluster[:, 1:2] / cluster[:, 0:1]
    return np.concatenate([cluster, ratio], axis=-1)
# def load_dataset(path, normalized=True):
#     '''
#     load dataset from pasvoc formatl xml files
#     return [[w,h],[w,h]]
#     '''
#     dataset = []
#     for xml_file in glob.glob("{}/*xml".format(path)):
#         tree = ET.parse(xml_file)
#
#         height = int(tree.findtext("./size/height"))
#         width = int(tree.findtext("./size/width"))
#
#         for obj in tree.iter("object"):
#             if normalized:
#                 xmin = int(obj.findtext("bndbox/xmin")) / float(width)
#                 ymin = int(obj.findtext("bndbox/ymin")) / float(height)
#                 xmax = int(obj.findtext("bndbox/xmax")) / float(width)
#                 ymax = int(obj.findtext("bndbox/ymax")) / float(height)
#             else:
#                 xmin = int(obj.findtext("bndbox/xmin"))
#                 ymin = int(obj.findtext("bndbox/ymin"))
#                 xmax = int(obj.findtext("bndbox/xmax"))
#                 ymax = int(obj.findtext("bndbox/ymax"))
#             if (xmax - xmin) == 0 or (ymax - ymin) == 0:
#                 continue  # to avoid divded by zero error.
#             dataset.append([xmax - xmin, ymax - ymin])
#
#     return np.array(dataset)
def load_dataset(path, normalized=True, input_size=(200, 1800)):
    '''Load the (w, h) of every box from YOLO-format ``.txt`` label files.

    Each usable label line reads ``class cx cy w h``, with values normalised
    to the image size. Blank or short lines are skipped. Image files are not
    opened; only the label values are needed.

    Args:
        path: directory containing the ``.txt`` label files.
        normalized: if True, keep the normalised (w, h). Otherwise multiply
            them by ``input_size``.
        input_size: (width, height) used to scale boxes when ``normalized``
            is False. The 200 x 1800 default is presumably the training input
            size; TODO confirm for your dataset.

    Returns:
        float ndarray of shape (N, 2), one [w, h] row per box (N may be 0).
    '''
    scale_w, scale_h = (1.0, 1.0) if normalized else input_size
    dataset = []
    for name in sorted(os.listdir(path)):
        if not name.endswith('.txt'):
            continue
        with open(os.path.join(path, name), 'r') as f:
            for line in f:
                fields = line.split()
                if len(fields) < 5:
                    continue  # blank or malformed line
                # float() rather than eval(): label files are data, not code.
                dataset.append((float(fields[3]) * scale_w,
                                float(fields[4]) * scale_h))
    return np.array(dataset, dtype=np.float64).reshape(-1, 2)
# print("Start to load data annotations on: %s" % ANNOTATIONS_PATH)
if __name__ == '__main__':
    # data: (N, 2) array, one [w, h] row per box.
    data = load_dataset(path='./train_img', normalized=BBOX_NORMALIZE)
    print(data[:3])
    print("Start to do kmeans, please wait for a moment.")
    out = kmeans(data, k=CLUSTERS)
    print('==out', out)
    # Rows of [w, h, h / w], sorted by area from small to big.
    out_sorted = sort_cluster(out)
    print("Accuracy: {:.2f}%".format(avg_iou(data, out) * 100))
    show_cluster(data, out, max_points=2000)
    print("Recommended anchors (sorted by area)")
    print("Width    Height   Height/Width")
    for width, height, ratio in out_sorted:
        print("%.3f      %.3f     %.1f" % (width, height, ratio))
    show_width_height(data, out, bins=50)

（txt 中每行依次是类别、cx、cy、w、h，其中 cx、cy、w、h 是相对图像归一化后的比例。）下图是归一化 w、h 的分布；也就是说，如果网络输入是方形的，anchor 的宽高比就用这个结果。

下图是乘以实际尺寸后的分布；也就是说，如果网络输入与原图保持相同的宽高比，anchor 的宽高比就用这个结果。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/493237.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Geoff Hinton:全新的想法将比微小的改进更有影响力

来源：AI科技评论摘要：日前，WIRED 对 Hinton 进行了一次专访，在访谈中，WIRED 针对人工智能带来的道德挑战和面临的挑战等问题进行了提问，以下为谈话内容。“作为一名谷歌高管，我认为在公开场合抱…

修改TOMCAT服务器图标为应用LOGO

在tomcat下部署应用程序，运行后，发现在地址栏中会显示tomcat的小猫咪图标。有时候，我们自己不想显示这个图标，想换成自己定义的图标，那么按如下方法操作即可： 参考网上的解决方案：1、将$TOMCA…

python连接mysql的一些基础知识+安装Navicat可视化数据库+flask_sqlalchemy写数据库

一．mysql基础知识 １．connect连接数据库 import pymysqldef get_conn():conn pymysql.connect(hostxxx.xxx.xxx.xxx, port3306, userroot, passwd, dbnewspaper_rest) # db:表示数据库名称return conn ２．创建表 im…

工业互联网平台创新发展白皮书(2018)

来源：走向智能论坛摘要：近日，在“2018年产业互联网与数据经济大会——首届工业互联网平台创新发展暨两化融合推进会”上，国家工业信息安全发展研究中心尹丽波主任发布并解读了《工业互联网平台创新发展白皮书（2018…

迭代器模式和组合模式混用

迭代器模式和组合模式混用 前言 园子里说设计模式的文章算得上是海量了，所以本篇文章所用到的迭代器设计模式和组合模式不提供原理解析，有兴趣的朋友可以到一些前辈的设计模式文章上学学，很多很有意思的。在Head First 设计模式这本书中…

python实现可扩容队列

#coding:utf-8 """ fzh created on 2019/10/15 构建一个队列 """ import datetimeclass LoopQueue(object):def __init__(self, n10):self.arr [None] * (n1) # 由于特意浪费了一个空间，所以arr的实际大小应该是用户传入的容量1sel…

5G 产业链重要投资节点

来源：兴业证券 ▌5G:大通信容量及超低延时，未来多项应用的基础5G:高工作频率以及频谱带宽带来高通信容量5G(5thgeneration)是指第五代移动电话通信标准。3GPP(第三代合作伙伴计划，电信标准化机构)将5G标准分为了NSA(非独立组网)和SA(独立组网…

Kneser猜想与相关推广

本文本来是想放在Borsuk-Ulam定理的应用这篇文章当中。但是这个文章实在是太长，导致有喧宾夺主之嫌，从而独立出为一篇文章，仅供参考。$\newcommand{\di}{\mathrm{dist}}$ （图1：Kneser叙述他的猜想原文手稿）…

python .py文件变为.so文件进行加密

１.mytest.py 需要加密的内容 #coding:utf-8 import datetimeclass Today():def get_time(self):print(datetime.datetime.now())def say(self):print("hello word!")today Today() today.say() today.get_time() 2.执行setup.py 也就是加密脚本 from…

从技术上解读大数据的应用现状和开源未来

来源：网络大数据作者 | 韩锐、 Lizy Kurian John、詹剑锋摘要：近年来，随着大数据系统的快速发展，各式各样的开源基准测试集被开发出来，以评测和分析大数据系统并促进其技术改进。然而，迄今为止，…

十八岁华裔天才携手「量子计算先驱」再次颠覆量子计算

来源：机器之心编译参与：刘晓坤、李泽南摘要：量子计算再一次「被打败了」。今年 8 月，刚刚年满 18 岁的 Ewin Tang 证明了经典算法能以和量子计算机相近的速度解决推荐问题，这位天才少女（更正：不…

resnet系列+mobilenet v2+pytorch代码实现

一.resnet系列backbone import torch.nn as nn import math import torch.utils.model_zoo as model_zooBatchNorm2d nn.BatchNorm2d__all__ [ResNet, resnet18, resnet34, resnet50, resnet101, deformable_resnet18, deformable_resnet50,resnet152]model_urls {resnet18:…

广度优先搜索(BFS)与深度优先搜索(DFS)

一.广度优先搜索（BFS） 1.二叉树代码 # 实现一个二叉树 class TreeNode:def __init__(self, x):self.val xself.left Noneself.right Noneself.nexts []root_node TreeNode(1) node_2 TreeNode(2) node_3 TreeNode(3) node_4 TreeNode(4) node_…

骁龙855在AI性能上真的秒杀麒麟980?噱头而已

来源：网易智能摘要：前段时间的高通发布会上，有关骁龙855 AI性能达到友商竞品两倍的言论可谓是赚足了眼球。高通指出，骁龙855针对CPU、GPU、DSP都进行了AI计算优化，结合第四代AI引擎可以实现每秒超过7万亿次运算…

MySQL主从复制(Master-Slave)与读写分离(MySQL-Proxy)实践 转载

http://heylinux.com/archives/1004.html MySQL主从复制（Master-Slave）与读写分离（MySQL-Proxy）实践 Mysql作为目前世界上使用最广泛的免费数据库，相信所有从事系统运维的工程师都一定接触过。但在实际的生产环境中…

深度解析AIoT背后的发展逻辑

来源：iotworld摘要：AI与IoT融合领域近年来一片火热，不论是资本市场，还是大众创业，无不对其表现出极大的热情。AIoT领域中人机交互的市场机会自2017年开始，“AIoT”一词便开始频频刷屏，成为物联网…

ubuntu安装Redis+安装mysql(配置远程登录)+安装jdk+安装nginx+安装teamviewer+安装terminator+安装sublime

一．Ubuntu 安装 Redis sudo apt-get update sudo apt-get install redis-server redis-server 启动 修改redis配置 远程访问: sudo vim /etc/redis/redis.conf 注释掉本机ip: 有坑的地方 #bind 127.0.0.1  service redis-server restart redis-cli ping …

深入理解SQL注入绕过WAF与过滤机制

知己知彼，百战不殆 --孙子兵法 [目录] 0x0 前言 0x1 WAF的常见特征 0x2 绕过WAF的方法 0x3 SQLi Filter的实现及Evasion 0x4 延伸及测试向量示例 0x5 本文小结 0x6 参考资料 0x0 前言 促使本文产生最初的动机是前些天在做测试时一些攻击向量被WAF挡掉了，…

预测|麦肯锡预测2030年:1亿中国人面临职业转换,全球8亿人被机器人取代

来源：先进制造业摘要：纵观人类技术的发展历程，往往遵循一个固定的规律，即先是概念萌芽，然后经历市场炒作，资本蜂拥，结果潮水退去，泡沫破灭。而繁华落尽后，才会经历技术成…

计算polygon面积和判断顺逆时针方向的方法

一．利用shapely求polygon面积 import shapelyfrom shapely.geometry import Polygon, MultiPoint # 多边形# box1 [2, 0, 4, 2, 2, 4, 0, 2, 0, 0]box1 [2, 0, 4, 2, 2, 4, 0, 2, 2, 2]poly_box1 Polygon(np.array(box1).reshape(-1,2))print(poly_box1)print(p…