基于tensorflow2的目标检测完整实现过程

序言

虽然tf1仍然在维护,但tf2毕竟是主流,如果不是项目有明确要求,建议直接选择tf2。本文以tf2为例展开,总结从环境准备到使用自己的数据和tensorflow预训练模型进行快速训练和调用。对tensorflow和目标检测算法有深入了解的,可以在此基础上,进行完全独立的模型搭建。tf1和tf2在目标检测API的使用上差异不算太大,文中会尽量单独列出。


一、数据准备

1、目标检测公开数据集

图像数据的收集及标注都是十分耗时的工作,因此,若只是单纯研究算法或存在场景高度重合的公开数据集,首选一定是使用公开数据集。在此列举几个常见的:

(1)coco数据集

全称:Microsoft Common Objects in Context,包含90种对象(80种大类),30万张图片,总共包含150万个对象。支持物体检测、分割、文字识别。

下载地址:

训练集:http://images.cocodataset.org/zips/train2017.zip

验证集:http://images.cocodataset.org/zips/val2017.zip

测试集:http://images.cocodataset.org/zips/test2017.zip

训练集和验证集的标注:http://images.cocodataset.org/annotations/annotations_trainval2017.zip

(不要尝试直接访问官网,网页加载很慢)

(2)wider face

用于人脸检测的数据集,包含32203个图像和393703个人脸图像。

下载地址:

WIDER FACE: A Face Detection Benchmark

(3)kitti

自动驾驶相关数据集,包含车辆、行人,交通标识的识别。

下载地址:

The KITTI Vision Benchmark Suite

图像相关数据集介绍,可以参考这位大神的文章,总结得很全面

图像数据库-CSDN博客

2、自行标注

(1)labelimg工具安装

如果以上数据无法覆盖实际业务场景,那就需要自行标注图片。这里介绍基于python开发的工具labelimg,使用pip install labelimg安装,然后执行labelimg即可打开。

 (2)labelimg工具使用

 (3)标注文件格式转换

如图,工具支持多种标注类型。有时候标注的类型不是我们想要的,需要进行转换(不可能重新标注)。常见的有YOLO(txt) ,VOC(xml)相互转换。

import os
import cv2
import xml.etree.ElementTree as ETdef txt_to_xml(input_txt_dir, output_xml_dir, image_dir, class_txt):# 获取txt文件的目录列表txt_files = os.listdir(input_txt_dir)# 获取图像的目录列表image_files = os.listdir(image_dir)image_infos = []for txt_file in txt_files:file_name, file_ext = os.path.splitext(txt_file)for image_file in image_files:images = []image_name, image_ext = os.path.splitext(image_file)if image_ext in ['.jpg', '.png', '.jpeg']:# 判断图像名是否与txt文件名相同if image_name == file_name:images.append(image_file)# 读取txt文件中的标注信息with open(os.path.join(input_txt_dir, txt_file), 'r') as f:bboxes = []for line in f.readlines():bbox_id, x_center, y_center, width, height = line.strip().split()x_center = float(x_center)  # 相对坐标y_center = float(y_center)  # 相对坐标width = float(width)  # 相对坐标height = float(height)  # 相对坐标bbox = (bbox_id, x_center, y_center, width, height)bboxes.append(bbox)images.append(bboxes)image_infos.append(images)# 获取标注框的类别列表class_names = []with open(class_txt, 'r') as classes:for class_name in classes.readlines():class_names.append(class_name.strip())# 遍历每个图像文件,获取图像的高度和宽度,并将标注信息写入XML文件for image_info in image_infos:image_file = image_info[0]image_name, image_ext = os.path.splitext(image_file)image_path = os.path.join(image_dir, image_file)img = cv2.imread(image_path)image_height, image_width, num_channels = img.shape[:3]  # 获取图片的高度、宽度和通道数# 创建XML文件并写入标注信息with open(os.path.join(output_xml_dir, image_name + '.xml'), mode='a') as f:f.write('<annotation>\n')# 图像位置信息f.write('\t<filename>{}</filename>\n'.format(image_file))f.write('\t<path>{}</path>\n'.format(image_path))# 图像尺寸信息f.write('\t<size>\n')f.write('\t\t<width>{}</width>\n\t\t<height>{}</height>\n\t\t<depth>{}</depth>\n'.format(image_width,image_height,num_channels))f.write('\t</size>\n')# 图像类别、坐标信息bboxes = image_info[1]for bbox in bboxes:bbox_id, x_center, y_center, width, height = bboxxmin = (x_center * image_width) - (width * image_width) / 2  # 计算标注框左上角x坐标值ymin = (y_center * image_height) - (height * image_height) / 2  # 计算标注框左上角y坐标值xmax = (x_center * image_width) + (width * image_width) / 2  # 计算标注框右下角x坐标值ymax = (y_center * image_height) + (height * image_height) / 2  # 计算标注框右下角y坐标值f.write('\t<object>\n')f.write('\t\t<name>{}</name>\n'.format(class_names[int(bbox_id)].strip()))f.write('\t\t<pose>Unspecified</pose>\n')f.write('\t\t<truncated>0</truncated>\n')f.write('\t\t<difficult>0</difficult>\n')f.write('\t\t<bndbox>\n')f.write('\t\t\t<xmin>{}</xmin>\n\t\t\t<ymin>{}</ymin>\n\t\t\t<xmax>{}</xmax>\n\t\t\t<ymax>{}</ymax>\n'.format(int(xmin), int(ymin), int(xmax), int(ymax)))f.write('\t\t</bndbox>\n')f.write('\t</object>\n')f.write('</annotation>')
def xml_to_txt(input_dir, output_dir, class_txt):# 获取所有XML文件列表xml_files = os.listdir(input_dir)# 获取标注框的类别列表class_names = []with open(class_txt, 'r') as classes:for class_name in classes.readlines():class_names.append(class_name.replace('\n', ''))# 遍历每个XML文件for xml_file in xml_files:# 获取文件名和扩展名file_name, file_ext = os.path.splitext(xml_file)# 确保是XML文件if file_ext == '.xml':# 解析XML文件并获取标注信息tree = ET.parse(os.path.join(input_dir, xml_file))root = tree.getroot()# 获取图像的最大宽度和高度max_width = float(root.find('size').find('width').text)max_height = float(root.find('size').find('height').text)# 获取标注框的坐标信息bndbox_coords = []for obj in root.findall('object'):bbox_type = obj.find('name').texttype_id = class_names.index(bbox_type)bndbox = obj.find('bndbox')xmin = float(bndbox.find('xmin').text)ymin = float(bndbox.find('ymin').text)xmax = float(bndbox.find('xmax').text)ymax = float(bndbox.find('ymax').text)bndbox_coords.append((type_id, xmin, ymin, xmax, ymax))# 计算YOLO所需的格式并写入输出文件with open(os.path.join(output_dir, file_name + '.txt'), 'w') as f:for coords in bndbox_coords:type_id, xmin, ymin, xmax, ymax = coordsx_center = (xmin + xmax) / 2 / max_width  # x_center字段计算,相对坐标y_center = (ymin + ymax) / 2 / max_height  # y_center字段计算,相对坐标width = (xmax - xmin) / max_width  # width字段(相对宽)计算height = (ymax - ymin) / max_height  # height字段(相对高)计算f.write('{} {:.6f} {:.6f} {:.6f} {:.6f}\n'.format(type_id, x_center, y_center, width, height))

                              

二、环境准备及目标检测API安装

1、目标检测API安装

直接从github上拉或者去网页下载(压缩包解压后重命名为models)

git clone https://github.com/tensorflow/models.git

 将models下面的2个文件夹路径添加到python的环境变量,否则无法正常导入相关模块,因为这个模块并非通过pip安装的。

export PYTHONPATH=$PYTHONPATH:/your_path/models/research:/your_path/models/research/slim

2、基础环境准备

tf1要求python≥3.6、tensorflow≥1.15、protobuf compiler>3.0

tf2要求python≥3.6、tensorflow≥2.2、protobuf compiler>3.0

Protocol Buffers(通常简称为protobuf)的目的是为了序列化和反序列化数据结构,以便于数据的存储、传输和跨平台交换。这里重点介绍一下protobuf和其编译器protoc的安装。

(1)安装protobuf
pip install protobuf==3.20.0

 这将保证python能正确处理proto文件。

(2)下载protoc编译器

根据自己的环境选择相应版本,地址:

Releases · protocolbuffers/protobuf · GitHub

下载解压后将其bin路径添加到环境变量,如 vi ~/.bashrc

export PATH="/your_path/protoc/bin:$PATH"
(3)编译python接口

在……\models\research 目录开执行:

protoc object_detection/protos/*.proto --python_out=.

 表示使用protoc编译.proto文件,它将.proto文件转换为python代码。

(4) 安装其他依赖

 将……/models/research/object_detection/packages/tf2/setup.py复制到……/models/research

 目录下,然后执行

python -m pip install .
(5)验证
python object_detection/builders/model_builder_tf2_test.py

 如果正确显示了耗时,则安装成功。报错一般都是依赖缺失或版本冲突。

三、模型训练及调用

1、数据格式转换

(1)公开数据集格式转换

不同的框架都会对输入数据格式进行一定要求,tensorflow目标检测API要求数据格式为tf record,在models/research/object_detection/dataset_tools文件下下,提供了一些针对主要公开数据集的格式转换的代码。

 

①调用方法1

如代码提示,在终端中通过指定数据集和标注文件以及输出路径来运行程序

 ②调用方法2

也可以直接修改代码,在代码中指定相关路径,然后双击运行。

 原始数据(下载解压后)

转换后

(2)转换自己的数据集

①基于labelimg标注的xml文件,先将其转换为csv文件,顺便做数据集划分。指定标注文件所在文件夹和输出的2个csv文件即可。

import glob
import pandas as pd
import xml.etree.ElementTree as ETdef xml_to_csv(path):xml_list = []# 读取注释文件for xml_file in glob.glob(path + '/*.xml'):tree = ET.parse(xml_file)root = tree.getroot()for member in root.findall('object'):value = (root.find('filename').text,int(root.find('size')[0].text),int(root.find('size')[1].text),member[0].text,int(member[4][0].text),int(member[4][1].text),int(member[4][2].text),int(member[4][3].text))xml_list.append(value)column_name = ['filename', 'width', 'height', 'class', 'xmin', 'ymin', 'xmax', 'ymax']# 将所有数据分为样本集和验证集,一般按照3:1的比例train_list = xml_list[0: int(len(xml_list) * 0.67)]eval_list = xml_list[int(len(xml_list) * 0.67) + 1:]# 保存为CSV格式train_df = pd.DataFrame(train_list, columns=column_name)eval_df = pd.DataFrame(eval_list, columns=column_name)train_df.to_csv('/image_identify/data/ad/csv/train.csv', index=False)eval_df.to_csv('/image_identify/data/ad/csv/eval.csv', index=False)path = '/image_identify/data/ad/anotations_xml'
xml_to_csv(path)

 ②再将csv转tf_record,需要指定csv文件、图片路径以及输出路径。

from __future__ import division
from __future__ import print_function
from __future__ import absolute_importimport os
import io
import pandas as pd
# 如果本身就是tf1,直接导入即可
import tensorflow.compat.v1 as tf
from PIL import Image
from object_detection.utils import dataset_util
from collections import namedtupleflags = tf.app.flags
flags.DEFINE_string('csv_input', '', 'Path to the CSV input')
flags.DEFINE_string('output_path', '', 'Path to output TFRecord')
FLAGS = flags.FLAGS# 将分类名称转成ID号,根据自己实际打的标签进行修改
def class_text_to_int(row_label):if row_label == 'car':return 1elif row_label == 'people':return 2else:# 其他都归为一类return 3def split(df, group):data = namedtuple('data', ['filename', 'object'])gb = df.groupby(group)return [data(filename, gb.get_group(x)) for filename, x in zip(gb.groups.keys(), gb.groups)]def create_tf_example(group, path):print(os.path.join(path, '{}'.format(group.filename)))with tf.gfile.GFile(os.path.join(path, '{}'.format(group.filename)), 'rb') as fid:encoded_jpg = fid.read()encoded_jpg_io = io.BytesIO(encoded_jpg)image = Image.open(encoded_jpg_io)width, height = image.sizefilename = (group.filename + '.jpg').encode('utf8')image_format = b'jpg'xmins = []xmaxs = []ymins = []ymaxs = []classes_text = []classes = []for index, row in group.object.iterrows():xmins.append(row['xmin'] / width)xmaxs.append(row['xmax'] / width)ymins.append(row['ymin'] / height)ymaxs.append(row['ymax'] / height)classes_text.append(row['class'].encode('utf8'))classes.append(class_text_to_int(row['class']))tf_example = tf.train.Example(features=tf.train.Features(feature={'image/height': dataset_util.int64_feature(height),'image/width': dataset_util.int64_feature(width),'image/filename': dataset_util.bytes_feature(filename),'image/source_id': dataset_util.bytes_feature(filename),'image/encoded': dataset_util.bytes_feature(encoded_jpg),'image/format': dataset_util.bytes_feature(image_format),'image/object/bbox/xmin': dataset_util.float_list_feature(xmins),'image/object/bbox/xmax': dataset_util.float_list_feature(xmaxs),'image/object/bbox/ymin': dataset_util.float_list_feature(ymins),'image/object/bbox/ymax': dataset_util.float_list_feature(ymaxs),'image/object/class/text': dataset_util.bytes_list_feature(classes_text),'image/object/class/label': dataset_util.int64_list_feature(classes),}))return tf_exampledef main(csv_input, output_path, imgPath):writer = tf.python_io.TFRecordWriter(output_path)path = imgPathexamples = pd.read_csv(csv_input)grouped = split(examples, 'filename')for group in grouped:tf_example = create_tf_example(group, path)writer.write(tf_example.SerializeToString())writer.close()if __name__ == '__main__':imgPath = '/image_identify/data/ad/images'# 生成训练集output_path = '/image_identify/data/ad/tf_record/train.record'csv_input = '/image_identify/data/ad/csv/train.csv'main(csv_input, output_path, imgPath)# 生成验证集# output_path = '/image_identify/data/ad/tf_record/eval.record'# csv_input = '/image_identify/data/ad/csv/eval.csv'# main(csv_input, output_path, imgPath)

2、模型训练 

(1)下载预训练模型

tf1的地址:

https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/tf1_detection_zoo.md

tf2的地址:

https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/tf2_detection_zoo.md

 这个网页只提供了模型下载地址,不支持直接下载,可以使用tf自带的文件下载方法

 自己注意url拼接正确。

import tensorflow as tf# 下载模型,并解压缩
def download_model(model_name):base_url = 'http://download.tensorflow.org/models/object_detection/tf2/20200713/'# 返回的是下载文件的路径model_dir = tf.keras.utils.get_file(fname=model_name,origin=base_url + model_name,# 是否需要解压untar=True)                                       return str(model_dir)MODEL_NAME = 'centernet_hg104_512x512_coco17_tpu-8.tar.gz'
PATH_TO_MODEL_DIR = download_model(MODEL_NAME)
print(PATH_TO_MODEL_DIR)

当然,也可以将url复制后,新开一个页面,输入链接后回车,也可以下载。

补充:标注文件的下载

# 下载标签
def download_labels(filename):base_url = 'https://raw.githubusercontent.com/tensorflow/models/master/research/object_detection/data/'label_dir = tf.keras.utils.get_file(fname=filename,origin=base_url + filename,untar=False)return str(label_dir)LABEL_FILENAME = 'mscoco_label_map.pbtxt'
# 下载 Label 对照表文件
PATH_TO_LABELS = download_labels(LABEL_FILENAME)
(2)使用预训练模型进行训练

图像领域,一则数据量较大,二来大量的卷积也导致计算量较大。因此,使用预训练模型进行二次训练(结合自己的数据)是一个靠谱的路子(快速落地)。而这个地方tf1和tf2就有区别了

①针对tf2

编辑下载的模型文件中的pipline.config文件,修改其中关于数据集以及模型的路径相关部分。

 检测类型要改成“detection”,ckpt-0表示临时文件的第0个,如果有多个,也可以指定其他。

 在object_detection文件夹里找到model_main_tf2.py文件,同样的,可以按照说明去命令行调用,也可以跟我一样,修改代码,然后直接运行。

②针对tf1

在research/object_detection/samples/configs文件夹下找到和下载的模型对应的配置文件,修改其中关于模型路径以及训练、验证集的路径。当然,其他参数也可以调整。

 再执行research/object_detection/legacy/train.py

python train.py \
--logtostderr \
--train_dir=your_path/output \
--pipeline_config_path=your_path/models/research/object_detection/samples/configs/faster_rcnn_inception_resnet_v2_atrous_coco.config

 注意,训练图像数据,最好有GPU支持。如果只是学习和测试,建议选择mobilenet等轻量网络,batch_size和num_steps可以设置小一点,否则,要么内存溢出要么等到天荒地老。

(3)训练完成后

 训练完成后,在你指定的输出文件夹就生成了新的checkpoint文件

 checkpoint转pb文件

在object_detection文件夹下找到exporter_main_v2.py文件,按照使用说明指定参数。

 

3、模型调用

 上述训练生成的模型(也可以直接使用下载的模型)就可以用于预测验证了。不过,一般checkpoint只作为临时文件,正式部署都需要保存成pb文件。

(1)加载pb文件
# 图片 目标检测
import os
import tensorflow as tf
import warnings
import time
from object_detection.utils import label_map_util, config_util
from object_detection.utils import visualization_utils as viz_utils
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from object_detection.builders import model_builderwarnings.filterwarnings('ignore')# GPU 设定为 记忆体动态调整 (dynamic memory allocation)
# 通过 tf.config.experimental.list_physical_devices,我们可以获得当前主机上某种特定运算设备类型(如 GPU 或 CPU )的列表
gpus = tf.config.experimental.list_physical_devices('GPU')
for gpu in gpus:# 可以通过 tf.config.experimental.set_memory_growth 将 GPU 的显存使用策略设置为 “仅在需要时申请显存空间”tf.config.experimental.set_memory_growth(gpu, True)#  GPU 设定为固定为 2GB
if gpus:tf.config.experimental.set_virtual_device_configuration(gpus[0],[tf.config.experimental.VirtualDeviceConfiguration(memory_limit=1024*2)])MODEL_DATE = '20200711'
MODEL_NAME = 'centernet_hg104_1024x1024_coco17_tpu-32'
# 一般用keras下载下来的文件存储在"C:\Users\用户名\.keras" 下,linux在"/root/.keras"
PATH_TO_MODEL_DIR = '/root/.keras/datasets/centernet_hg104_1024x1024_coco17_tpu-32'PATH_TO_LABELS = '/root/.keras/datasets/mscoco_label_map.pbtxt'
# 建立 Label 的对照表 (代码与名称) 标签映射将索引号对应于类别名称,如5对应于飞机。
category_index = label_map_util.create_category_index_from_labelmap(PATH_TO_LABELS, use_display_name=True)# 测试图片
image_np = np.array(Image.open('../../data/images_Object_Detection/test.jpeg'))
# 转为 TensorFlow tensor
input_tensor = tf.convert_to_tensor(image_np)
# 加一维,变为 (批处理数, 宽, 高, 3通道颜色)
input_tensor = input_tensor[tf.newaxis, ...]"""加载模型方法1:Saved Model(包含完整的网络结构和权重等参数)"""# 从下载的目录载入模型
PATH_TO_SAVED_MODEL = PATH_TO_MODEL_DIR + "/saved_model"
print('载入模型...', end='')
start_time = time.time()
# 载入模型
detect_fn = tf.saved_model.load(PATH_TO_SAVED_MODEL)
elapsed_time = time.time() - start_time
print(f'共花费 0.0611 秒.')# 进行检测,信息包含:候选框, 类别, 机率
detections = detect_fn(input_tensor)
# 得到检测到的目标数
num_detections = int(detections.pop('num_detections'))
print(f'检测到的物件个数:{num_detections}')
# 转换数据类型
detections = {key: value[0, :num_detections].numpy() for key, value in detections.items()}
detections['num_detections'] = num_detections
detections['detection_classes'] = detections['detection_classes'].astype(np.int64)# 打印检测到的结果
print(f'物件资讯 (候选框, 类别, 机率):')
for detection_boxes, detection_classes, detection_scores in \zip(detections['detection_boxes'], detections['detection_classes'], detections['detection_scores']):print(np.around(detection_boxes, 4), detection_classes, round(detection_scores*100, 2))image_np_with_detections = image_np.copy()
# 加框
viz_utils.visualize_boxes_and_labels_on_image_array(image_np_with_detections,detections['detection_boxes'],detections['detection_classes'],detections['detection_scores'],category_index,                   # 分类的映射use_normalized_coordinates=True,  # 框的坐标是否为归一化的值max_boxes_to_draw=200,            # 最多绘制多少框,不指定将会全部绘制min_score_thresh=.30,             # 过滤概率过低的agnostic_mode=True)              # 默认False,表示每个类别的边界框独立(不同的颜色、标注)plt.figure(figsize=(12, 8))
plt.imshow(image_np_with_detections, cmap='viridis')
saved_file = '../../data/images_Object_Detection/test._detection2.png'
# 删除旧文件
if os.path.isfile(saved_file):os.remove(saved_file)
plt.savefig(saved_file)
(2)加载checkpoint
# 图片 目标检测
import os
import tensorflow as tf
import warnings
import time
from object_detection.utils import label_map_util, config_util
from object_detection.utils import visualization_utils as viz_utils
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from object_detection.builders import model_builderwarnings.filterwarnings('ignore')# GPU 设定为 记忆体动态调整 (dynamic memory allocation)
# 通过 tf.config.experimental.list_physical_devices,我们可以获得当前主机上某种特定运算设备类型(如 GPU 或 CPU )的列表
gpus = tf.config.experimental.list_physical_devices('GPU')
for gpu in gpus:# 可以通过 tf.config.experimental.set_memory_growth 将 GPU 的显存使用策略设置为 “仅在需要时申请显存空间”tf.config.experimental.set_memory_growth(gpu, True)#  GPU 设定为固定为 2GB
if gpus:tf.config.experimental.set_virtual_device_configuration(gpus[0],[tf.config.experimental.VirtualDeviceConfiguration(memory_limit=1024*2)])# 一般用keras下载下来的文件存储在"C:\Users\用户名\.keras" 下,linux在"/root/.keras"
PATH_TO_MODEL_DIR = '/root/.keras/datasets/centernet_hg104_1024x1024_coco17_tpu-32'PATH_TO_LABELS = '/root/.keras/datasets/mscoco_label_map.pbtxt'
# 建立 Label 的对照表 (代码与名称) 标签映射将索引号对应于类别名称,如5对应于飞机。
category_index = label_map_util.create_category_index_from_labelmap(PATH_TO_LABELS, use_display_name=True)# 测试图片
image_np = np.array(Image.open('../../data/images_Object_Detection/test.jpeg'))
# 转为 TensorFlow tensor
input_tensor = tf.convert_to_tensor(image_np)
# 加一维,变为 (批处理数, 宽, 高, 3通道颜色)
input_tensor = input_tensor[tf.newaxis, ...]"""加载模型方法2:Checkpoint(只包含权重信息,需要重新构建模型。一般是训练过程中临时保存,加载速度相对更快)"""
PATH_TO_CFG = PATH_TO_MODEL_DIR + "/pipeline.config"
PATH_TO_CKPT = PATH_TO_MODEL_DIR + "/checkpoint"# 计时开始
print('Loading model... ', end='')
start_time = time.time()
configs = config_util.get_configs_from_pipeline_file(PATH_TO_CFG)
model_config = configs['model']
detection_model = model_builder.build(model_config=model_config, is_training=False)
# 还原模型
ckpt = tf.compat.v2.train.Checkpoint(model=detection_model)
# 恢复变量当在其他地方需要为模型重新载入之前保存的参数时,需要再次实例化一个 checkpoint,同时保持键名的一致。再调用 checkpoint 的 restore 方法。
ckpt.restore(os.path.join(PATH_TO_CKPT, 'ckpt-0')).expect_partial()
elapsed_time = time.time() - start_time
print(f'共花费 0.0611 秒.')# # 任选一张图片进行物件侦测
# # 虽然默认的即时执行模式(Eager Execution)为我们带来了灵活及易调试的特性,但在特定的场合,例如追求高性能或部署模型时,我们依然希望
# # 使用 TensorFlow 1.X 中默认的图执行模式(Graph Execution),将模型转换为高效的 TensorFlow 图模型。此时,TensorFlow 2 为我们提供
# # 了 tf.function 模块,结合 AutoGraph 机制,使得我们仅需加入一个简单的 @tf.function 修饰符,就能轻松将模型以图执行模式运行。
@tf.function
def detect_fn(image):image, shapes = detection_model.preprocess(image)prediction_dict = detection_model.predict(image, shapes)detections = detection_model.postprocess(prediction_dict, shapes)return detections# 进行检测,信息包含:候选框, 类别, 机率
detections = detect_fn(input_tensor)
# 得到检测到的目标数
num_detections = int(detections.pop('num_detections'))
print(f'检测到的物件个数:{num_detections}')
# 转换数据类型
detections = {key: value[0, :num_detections].numpy() for key, value in detections.items()}
detections['num_detections'] = num_detections
detections['detection_classes'] = detections['detection_classes'].astype(np.int64)# 打印检测到的结果
print(f'物件资讯 (候选框, 类别, 机率):')
for detection_boxes, detection_classes, detection_scores in \zip(detections['detection_boxes'], detections['detection_classes'], detections['detection_scores']):print(np.around(detection_boxes, 4), detection_classes, round(detection_scores*100, 2))image_np_with_detections = image_np.copy()
# 加框
viz_utils.visualize_boxes_and_labels_on_image_array(image_np_with_detections,detections['detection_boxes'],detections['detection_classes'],detections['detection_scores'],category_index,                   # 分类的映射use_normalized_coordinates=True,  # 框的坐标是否为归一化的值max_boxes_to_draw=200,            # 最多绘制多少框,不指定将会全部绘制min_score_thresh=.30,             # 过滤概率过低的agnostic_mode=True)              # 默认False,表示每个类别的边界框独立(不同的颜色、标注)plt.figure(figsize=(12, 8))
plt.imshow(image_np_with_detections, cmap='viridis')
saved_file = '../../data/images_Object_Detection/test._detection2.png'
# 删除旧文件
if os.path.isfile(saved_file):os.remove(saved_file)
plt.savefig(saved_file)
(3)检测效果示例

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/40263.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Vue+Xterm.js+WebSocket+JSch实现Web Shell终端

一、需求 在系统中使用Web Shell连接集群的登录节点 二、实现 前端使用Vue&#xff0c;WebSocket实现前后端通信&#xff0c;后端使用JSch ssh通讯包。 1. 前端核心代码 <template><div class"shell-container"><div id"shell"/>&l…

【FPGA】STA静态时序分析

文章目录 一.定义二.分类1. 静态时序分析2. 静态时序分析 三. 概念四. 时间余量1.场景2.建立时间余量3.保持时间余量 一.定义 时序分析:检查电路是否满足时序要求&#xff1b; 二.分类 1. 静态时序分析 STA,遍历所有的时序路径&#xff0c;根据时序库&#xff08;.lib文件&…

新手教学系列——使用uWSGI对Flask应用提速

在构建和部署Flask应用时,性能和稳定性是两个关键的因素。为了提升Flask应用的性能,我们可以借助uWSGI这个强大的工具。本文将详细介绍为什么要使用uWSGI、uWSGI的底层原理,并提供一个实例配置,帮助你更好地理解和应用这个工具。 为什么要使用uWSGI uWSGI 是一个应用服务…

探索企业知识边界,鸿翼ECM AI助手开启智慧问答新时代

在信息化迅速发展的当下&#xff0c;企业积累的数字文档数量巨大&#xff0c;这些文档中蕴含的深层信息对业务发展至关重要。然而&#xff0c;传统的搜索技术常常因只能进行关键字查询而无法满足对文档深层次理解的需求。 据Gartner调查&#xff0c;高达47%的员工在寻找有效工…

Webpack: 三种Chunk产物的打包逻辑

概述 在前文 Webpack: Dependency Graph 管理模块间依赖 中&#xff0c;我们已经详细讲解了「构建」阶段如何从 Entry 开始逐步递归读入、解析模块内容&#xff0c;并最终构建出模块依赖关系图 —— ModuleGraph 对象。本文我们继续往下&#xff0c;讲解在接下来的「封装」阶段…

【大数据】—美国交通事故分析(2016 年 2 月至 2020 年 12 月)

引言 在当今快速发展的数字时代&#xff0c;大数据已成为我们理解世界、做出决策的重要工具。特别是在交通安全领域&#xff0c;大数据分析能够揭示事故模式、识别风险因素&#xff0c;并帮助制定预防措施&#xff0c;从而挽救生命。本文将深入探讨2016年2月至2020年12月期间&…

24年河南特岗教师招聘流程+报名流程

河南特岗教师报名流程如下 1.登录河南省特岗招聘网 登录河南省特岗招聘网注册账号和密码&#xff0c;账号可以是手机号或者身份证号&#xff0c;密码自己设置 2.注册登录账号 注册完账号重新登录账号&#xff0c;输入身份证号、手机号、密码、验证码 3.浏览考试须知 填写个人信…

Python 编程快速上手——让繁琐工作自动化(第2版)读书笔记01 Python基础快速过关

Python 编程快速上手——让繁琐工作自动化&#xff08;第2版&#xff09;读书笔记01 Python基础快速过关 1 python基础概念 Python提供了高效的高级数据结构&#xff0c;还能简单有效地面向对象编程。 python运算符顺序 **——%——//——/——*——-——python中常见的数据…

Real-Time 3D Graphics with WebGL2

WebGL渲染管线 下图是WebGL渲染管线的示意图: Vertex Buffer Objects (VBOs) VBOS中包含了用于描述几何体的信息。如&#xff0c;几何体的顶点坐标&#xff0c;法线坐标&#xff0c;颜色&#xff0c;纹理坐标等。 Index Buffer Objects (IBOs) IBOs中包含了描述顶点关系的信…

C#的多线程UI窗体控件显示方案 - 开源研究系列文章

上次编写了《LUAgent服务器端工具》这个应用&#xff0c;然后里面需要新启动一个线程去对文件进行上传到FTP服务器&#xff0c;但是新线程里无法对应用主线程UI的内容进行更改&#xff0c;所以就需要在线程里设置主UI线程里控件信息的方法&#xff0c;于是就有了此博文。此文记…

Rocky Linux 9 快速安装docker 教程

前述 CentOS 7系统将于2024年06月30日停止维护服务。CentOS官方不再提供CentOS 及后续版本&#xff0c;不再支持新的软件和补丁更新。CentOS用户现有业务随时面临宕机和安全风险&#xff0c;并无法确保及时恢复。由于 CentOS Stream 相对不稳定&#xff0c;刚好在寻找平替系统…

idm 支持断点续传吗 idm 断点续传如何使用 idm断点续传怎么解决 idm下载中断后无法继续下载

断点续传功能&#xff0c;让我再也不会惧怕下载大型文件。在断点续传的帮助下&#xff0c;用户可以随时暂停下载任务&#xff0c;并在空闲时继续之前的下载进程。下载文件不惧网络波动&#xff0c;断点续传让下载过程更稳定。有关 idm 支持断点续传吗&#xff0c;idm 断点续传如…

JavaScript:if-else类型

目录 任务描述 相关知识 if语句 if-else语句 匹配问题 编程要求 任务描述 本关任务&#xff1a;根据成绩判断考试结果。 相关知识 在编程中&#xff0c;我们常常根据变量是否满足某个条件来执行不同的语句。 JavaScript中利用以if关键字开头的条件语句达到以上目的&am…

|从零搭建网络| VisionTransformer网络详解及搭建

&#x1f31c;|从零搭建网络| VisionTransformer系列网络详解及搭建&#x1f31b; 文章目录 &#x1f31c;|从零搭建网络| VisionTransformer系列网络详解及搭建&#x1f31b;&#x1f31c; 前言 &#x1f31b;&#x1f31c; VIT模型详解 &#x1f31b;&#x1f31c; VIT模型架…

mybatis、mybatis-plus插件开发,实现数据脱敏功能

首先说一下mybatis中四大组件的作用&#xff0c;下面开发的插件拦截器会使用 四大组件Executor、StatementHandler、ParameterHandler、ResultSetHandler Executor&#xff1a; Executor 是 MyBatis 中的执行器&#xff0c;负责 SQL 语句的执行工作。它通过调度 StatementHan…

python基础语法 004-3流程控制- while

1 while while 主要用的场景没有 for 循环多。 while循环&#xff1a;主要运行场景 我不知道什么时候结束。。。不知道运行多少次 1.1 基本用法 # while 4 > 3: #一直执行 # print("hell0")while 4 < 3: #不会打印&#xff0c;什么都没有print("…

IT之旅启航:高考后IT专业预习全攻略

✨作者主页&#xff1a; Mr.Zwq✔️个人简介&#xff1a;一个正在努力学技术的Python领域创作者&#xff0c;擅长爬虫&#xff0c;逆向&#xff0c;全栈方向&#xff0c;专注基础和实战分享&#xff0c;欢迎咨询&#xff01; 您的点赞、关注、收藏、评论&#xff0c;是对我最大…

Java知识点大纲

文章目录 第一阶段&#xff1a;JavaSE1、面向对象编程(基础)1)面向过程和面向对象区别2)类和对象的概述3)类的属性和方法4)创建对象内存分析5)构造方法(Construtor)及其重载6)对象类型的参数传递7)this关键字详解8)static关键字详解9)局部代码块、构造代码块和静态代码块10)pac…

2-24 基于图像处理的细胞计数方法

基于图像处理的细胞计数方法。经过初次二值化、中值滤波后二值化、优化后二值化图像、填充背景色的二进制图像、开运算后的图像一系列运算后&#xff0c;进行标签设置&#xff0c;最终得到细胞总数。程序已调通&#xff0c;可直接运行。 2-24 细胞计数方法 中值滤波后二值化 - …