手把手教你用深度学习做物体检测(一): 快速感受物体检测的酷炫

我们先来看看什么是物体检测,见下图:

如上图所示, 物体检测就是需要检测出图像中有哪些目标物体,并且框出其在图像中的位置。

本篇文章,我将会介绍如何利用训练好的物体检测模型来快速实现上图的效果,这里我们将会用到基于coco数据集训练的yolov3模型,该模型能识别80类物品,具体如下:

人 自行车 汽车 摩托车 飞机 公共汽车 火车 卡车 船 红绿灯 消防栓 停车标志 停车收费码表 长凳 鸟 猫 狗 马 羊 牛 大象 熊 斑马 长颈鹿 
双肩包 雨伞 手提包 领带 手提箱 飞盘 双架滑雪板 滑雪板 球 风筝 棒球棍 棒球手套 滑板 冲浪板 网球拍 瓶子 酒杯 杯子 叉子 刀 勺子 碗 
香蕉 苹果 三明治 橙子 西兰花 胡萝卜 热狗 披萨 炸面圈 蛋糕 椅子 沙发 盆栽 床 餐桌 厕所 显示器 笔记本电脑 鼠标 遥控器 键盘 手机 
微波炉 电烤箱 烤面包器 水槽 冰箱 书 钟 花瓶 剪刀 泰迪熊 吹风机 牙刷

下面,我们来看具体如何实现。

第一步:从github上下载项目: https://github.com/qqwweee/keras-yolo3

该项目是基于keras的yolov3实现,keras是一个深度学习高层框架,提供了更友好的接口,其底层可以兼容很多深度学习框架,比如tensorflow等。yolo是目前很流行的物体检测算法,yolov3是第三个版本,也是最新的版本。

第二步:安装keras。

通过pip安装即可,如果后续有遇到本地环境没有的包,也通过pip安装就好了(这里假设你已经装好了python的相关环境,并且知道如何使用pip,如果你还不清楚,可以自行网上搜索,过程也不复杂)。

第三步:下载yolov3.weights,这个文件是darknet预训练好的yolov3模型,可以检测coco数据集中涵盖的80类物体。地址:https://pjreddie.com/media/files/yolov3.weights

第四步:执行以下命令,将下载下来的文件转换为keras可以使用的.h5模型文件

python convert.py yolov3.cfg yolov3.weights model_data/yolo.h5

第五步:将项目中的yolo.py,用下面代码替换,注意检查_defaults中的配置路径是否正确

# -*- coding: utf-8 -*-
"""
Class definition of YOLO_v3 style detection model on image and video
"""import colorsys
import os
from timeit import default_timer as timer
import numpy as np
from PIL import Image, ImageFont, ImageDraw
from keras import backend as K
from keras.layers import Input
from keras.models import load_model
from keras.utils import multi_gpu_model
from yolo3.model import yolo_eval, yolo_body, tiny_yolo_body
from yolo3.utils import letterbox_imageclass YOLO(object):_defaults = {"model_path": 'model_data/yolo_weights.h5',"anchors_path": 'model_data/yolo_anchors.txt',"classes_path": 'model_data/coco_classes.txt',"score" : 0.3,"iou" : 0.45,"model_image_size" : (416, 416),"gpu_num" : 0,}@classmethoddef get_defaults(cls, n):if n in cls._defaults:return cls._defaults[n]else:return "Unrecognized attribute name '" + n + "'"def __init__(self, **kwargs):self.__dict__.update(self._defaults)  # set up default valuesself.__dict__.update(kwargs)  # and update with user overridesself.class_names = self._get_class()self.anchors = self._get_anchors()self.sess = K.get_session()self.boxes, self.scores, self.classes = self.generate()def _get_class(self):classes_path = os.path.expanduser(self.classes_path)with open(classes_path, encoding="utf-8") as f:class_names = f.readlines()class_names = [c.strip() for c in class_names]return class_namesdef _get_anchors(self):anchors_path = os.path.expanduser(self.anchors_path)with open(anchors_path, encoding="utf-8") as f:anchors = f.readline()anchors = [float(x) for x in anchors.split(',')]return np.array(anchors).reshape(-1, 2)def generate(self):model_path = os.path.expanduser(self.model_path)assert model_path.endswith('.h5'), 'Keras model or weights must be a .h5 file.'# Load model, or construct model and load weights.num_anchors = len(self.anchors)num_classes = len(self.class_names)is_tiny_version = num_anchors == 6  # default settingtry:self.yolo_model = load_model(model_path, compile=False)except:self.yolo_model = tiny_yolo_body(Input(shape=(None, None, 3)), num_anchors // 2, num_classes) \if is_tiny_version else yolo_body(Input(shape=(None, None, 3)), num_anchors // 3, num_classes)self.yolo_model.load_weights(self.model_path)  # make sure model, anchors and classes matchelse:assert self.yolo_model.layers[-1].output_shape[-1] == \num_anchors / len(self.yolo_model.output) * (num_classes + 5), \'Mismatch between model and given anchor and class sizes'print('{} model, anchors, and classes loaded.'.format(model_path))# Generate colors for drawing bounding boxes.hsv_tuples = [(x / len(self.class_names), 1., 1.)for x in range(len(self.class_names))]self.colors = list(map(lambda x: colorsys.hsv_to_rgb(*x), hsv_tuples))self.colors = list(map(lambda x: (int(x[0] * 255), int(x[1] * 255), int(x[2] * 255)),self.colors))np.random.seed(10101)  # Fixed seed for consistent colors across runs.np.random.shuffle(self.colors)  # Shuffle colors to decorrelate adjacent classes.np.random.seed(None)  # Reset seed to default.# Generate output tensor targets for filtered bounding boxes.self.input_image_shape = K.placeholder(shape=(2,))if self.gpu_num >= 2:self.yolo_model = multi_gpu_model(self.yolo_model, gpus=self.gpu_num)boxes, scores, classes = yolo_eval(self.yolo_model.output, self.anchors,len(self.class_names), self.input_image_shape,score_threshold=self.score, iou_threshold=self.iou)return boxes, scores, classesdef detect_image(self, image):start = timer()if self.model_image_size != (None, None):assert self.model_image_size[0] % 32 == 0, 'Multiples of 32 required'assert self.model_image_size[1] % 32 == 0, 'Multiples of 32 required'boxed_image = letterbox_image(image, tuple(reversed(self.model_image_size)))else:new_image_size = (image.width - (image.width % 32),image.height - (image.height % 32))boxed_image = letterbox_image(image, new_image_size)image_data = np.array(boxed_image, dtype='float32')print(image_data.shape)image_data /= 255.image_data = np.expand_dims(image_data, 0)  # Add batch dimension.out_boxes, out_scores, out_classes = self.sess.run([self.boxes, self.scores, self.classes],feed_dict={self.yolo_model.input: image_data,self.input_image_shape: [image.size[1], image.size[0]],K.learning_phase(): 0})print('Found {} boxes for {}'.format(len(out_boxes), 'img'))# font = ImageFont.truetype(font='font/FiraMono-Medium.otf',#                           size=np.floor(3e-2 * image.size[1] + 0.5).astype('int32'))# 使用中文字体font = ImageFont.truetype(font='font/simfang.ttf',size=np.floor(3e-2 * image.size[1] + 0.5).astype('int32'))thickness = (image.size[0] + image.size[1]) // 300for i, c in reversed(list(enumerate(out_classes))):predicted_class = self.class_names[c]box = out_boxes[i]score = out_scores[i]label = '{} {:.2f}'.format(predicted_class, score)draw = ImageDraw.Draw(image)label_size = draw.textsize(label, font)top, left, bottom, right = boxtop = max(0, np.floor(top + 0.5).astype('int32'))left = max(0, np.floor(left + 0.5).astype('int32'))bottom = min(image.size[1], np.floor(bottom + 0.5).astype('int32'))right = min(image.size[0], np.floor(right + 0.5).astype('int32'))print(label, (left, top), (right, bottom))if top - label_size[1] >= 0:text_origin = np.array([left, top - label_size[1]])else:text_origin = np.array([left, top + 1])# My kingdom for a good redistributable image drawing library.for i in range(thickness):draw.rectangle([left + i, top + i, right - i, bottom - i],outline=self.colors[c])draw.rectangle([tuple(text_origin), tuple(text_origin + label_size)],fill=self.colors[c])draw.text(text_origin, label, fill=(0, 0, 0), font=font)del drawend = timer()print(end - start)return imagedef close_session(self):self.sess.close()def detect_video(yolo, video_path, output_path=""):import cv2vid = cv2.VideoCapture(video_path)if not vid.isOpened():raise IOError("Couldn't open webcam or video")# video_FourCC = int(vid.get(cv2.CAP_PROP_FOURCC))video_FourCC = cv2.VideoWriter_fourcc(*"mp4v")video_fps = vid.get(cv2.CAP_PROP_FPS)video_size = (int(vid.get(cv2.CAP_PROP_FRAME_WIDTH)),int(vid.get(cv2.CAP_PROP_FRAME_HEIGHT)))isOutput = True if output_path != "" else Falseif isOutput:print("!!! TYPE:", type(output_path), type(video_FourCC), type(video_fps), type(video_size))out = cv2.VideoWriter(output_path, video_FourCC, video_fps, video_size)accum_time = 0curr_fps = 0fps = "FPS: ??"prev_time = timer()while True:return_value, frame = vid.read()if return_value:image = Image.fromarray(frame)image = yolo.detect_image(image)result = np.asarray(image)curr_time = timer()exec_time = curr_time - prev_timeprev_time = curr_timeaccum_time = accum_time + exec_timecurr_fps = curr_fps + 1if accum_time > 1:accum_time = accum_time - 1fps = "FPS: " + str(curr_fps)curr_fps = 0cv2.putText(result, text=fps, org=(3, 15), fontFace=cv2.FONT_HERSHEY_SIMPLEX,fontScale=0.50, color=(255, 0, 0), thickness=2)cv2.namedWindow("Object Detect", cv2.WINDOW_NORMAL)cv2.resizeWindow("Object Detect", 640, 480);cv2.imshow("Object Detect", result)if isOutput:print("start write...==========================================================================")out.write(result)if cv2.waitKey(1) & 0xFF == ord('q'):breakelse:breakout.release()vid.release()cv2.destroyAllWindows()yolo.close_session()def for_img(yolo):path = 'images/IMG_0728.JPG'try:image = Image.open(path)except:print('Open Error! Try again!')else:r_image = yolo.detect_image(image)r_image.show()yolo.close_session()def for_video(yolo):detect_video(yolo, "videos/xx.mp4", "videos/xx_detect.mp4")if __name__ == '__main__':_yolo = YOLO()for_img(_yolo)# for_video(_yolo)

本项目是可以在CPU上运行的,但是GPU上运行的更快。关于如何搭建GPU的运行环境,感兴趣的读者可以参考《如何在阿里云租一台GPU服务器做深度学习?》,然后将上面的代码的gpu_num改为你的GPU号(可以使用nvidia-smi命令查看),并注意加入对GPU显存的使用控制即可,这里为了快速体验物体检测效果,就不再对GPU下运行程序做过多的介绍,虽然在CPU下运行会慢很多,但用于体验足够了。

做完上面的步骤后,执行yolo.py,将会看到你想检测的图像的物体检测效果,左边是原图,该图项目中是没有的,可以自行下载,或者用你喜欢的其它图片来尝试检测:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

除了图片的检测,还可以对视频进行检测,修改yolo.py中的最后一行,将图片检测注释掉,放开视频检测的注释,然后执行yolo.py即可,下面是检测效果:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/644953.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Pyside6中QTableWidget使用

目录 一:介绍: 二:演示 一:介绍: 在 PySide6 中,QTableWidget 是一个用于展示和编辑表格数据的控件。它提供了在窗口中创建和显示表格的功能,并允许用户通过单元格来编辑数据。 要使用 QTabl…

Windows 下 TFTP 服务搭建及 U-Boot 中使用 tftp 命令实现文件下载

目录 Tftpd32/64文件下载更多内容 TFTP(Trivial File Transfer Protocol,简单文件传输协议)是 TCP/IP 协议族中的一个用来在客户机与服务器之间进行简单文件传输的协议,提供不复杂、开销不大的文件传输服务,端口号为 6…

免费SSL申请和自动更新

当前是在mac下操作 安装certbot # mac下brew安装即可 brew install certbotcentos 安装 centos安装文档 申请泛解析证书 sudo certbot certonly --manual --preferred-challengesdns -d *.yourdomain.com## 输出 Saving debug log to /var/log/letsencrypt/letsencrypt.lo…

[Android] Android文件系统中存储的内容有哪些?

文章目录 前言root 文件系统/system 分区稳定性:安全性: /system/bin用来提供服务的二进制可执行文件:调试工具:UNIX 命令:调用 Dalvik 的脚本(upall script):/system/bin中封装的app_process脚本 厂商定制的二进制可执行文件: /system/xbin/system/lib[64]/system/…

6.php开发-个人博客项目Tp框架路由访问安全写法历史漏洞

目录 知识点 php框架——TP URL访问 Index.php-放在控制器目录下 ​编辑 Test.php--要继承一下 带参数的—————— 加入数据库代码 --不过滤 --自己写过滤 --手册(官方)的过滤 用TP框架找漏洞: 如何判断网站是thinkphp&#x…

nvm安装与使用教程

目录 nvm是什么 nvm安装 配置环境变量 更换淘宝镜像 安装node.js版本 nvm list available 显示可下载版本的部分列表 nvm install 版本号 ​编辑 nvm ls 查看已经安装的版本 ​编辑 nvm use 版本号(切换想使用的版本号) nvm是什么 nvm是node.js version management的…

mfc110.dll丢失是什么意思?全面解析mfc110.dll丢失的解决方法

在使用计算机的过程中,用户可能会遭遇一个常见的困扰,即系统提示无法找到mfc110.dll文件。这个动态链接库文件(DLL)是Microsoft Foundation Classes(MFC)库的重要组成部分,对于许多基于Windows的…

代码随想录刷题笔记 DAY12 | 二叉树的理论基础 | 二叉树的三种递归遍历 | 二叉树的非递归遍历 | 二叉树的广度优先搜索

Day 12 01. 二叉树的理论基础 1.1 二叉树的种类 满二叉树:除了叶子节点以外,每个节点都有两个子节点,整个树是被完全填满的完全二叉树:除了底层以外,其他部分是满的,底部可以不是满的但是必须是从左到右连…

数据结构之受限线性表

受限线性表 对于一般线性表,虽然必须通过遍历逐一查找再对目标位置进行增、删和查操作,但至少一般线性表对于可操作元素并没有限制。说到这里,大家应该明白了,所谓的受限线性表,就是可操作元素受到了限制。 受限线性表…

【Web前端开发基础】CSS3之Web字体、字体图标、平面转换、渐变

CSS3之Web字体、字体图标、平面转换、渐变 目录 CSS3之Web字体、字体图标、平面转换、渐变一、Web字体1.1 Web字体概述1.2 字体文件1.3 font-face 规则 二、字体图标2.1 字体图标2.2 字体图标的优点2.3 图标库2.4 下载字体包2.5 字体图标的使用步骤2.6 字体图标使用注意点2.7 上…

「 典型安全漏洞系列 」06.路径遍历(Path Traversal)详解

引言:什么是路径遍历?如何进行路径遍历攻击并规避常见防御?如何防止路径遍历漏洞。 1. 简介 路径遍历(Path Traversal)是一种安全漏洞,也被称为目录遍历或目录穿越、文件路径遍历。它发生在应用程序未正确…

mysql生成最近24小时整点/最近30天/最近12个月时间临时表

文章目录 生成最近24小时整点生成最近30天生成最近12个月 生成最近24小时整点 SELECT-- 每向下推1行, i比上次减去1b.*, i.*,DATE_FORMAT( DATE_SUB( NOW(), INTERVAL ( -( i : i - 1 ) ) HOUR ), %Y-%m-%d %H:00 ) AS time FROM-- 目的是生成12行数据( SELECTa FROM( SELECT…

搭建《幻兽帕鲁》服务器需要怎样配置的云服务器?

随着《幻兽帕鲁》这款游戏的日益流行,越来越多的玩家希望能够在自己的服务器上体验这款游戏。然而,搭建一个稳定、高效的游戏服务器需要仔细的规划和配置。本文将分享搭建《幻兽帕鲁》服务器所需的配置及搭建步骤,助力大家获得更加畅快的游戏…

搭建k8s集群实战(一)系统设置

1、架构及服务 Kubernetes作为容器集群系统,通过健康检查重启策略实现了Pod故障自我修复能力,通过调度算法实现将Pod分布式部署,并保持预期副本数,根据Node失效状态自动在其他Node拉起Pod,实现了应用层的高可用性。 …

树的学习day01

树的理解 树是一种递归形式的调用 树是由于多个结点组成的有限集合T 树中有且仅有一个结点称为根 当结点大于1的时候,往往其余的结点为m个互不相交的有限个集合T1,…,Tm,每个互不相交的有限集合本身右是一棵树,称为这个根的子树 空树也是树 关…

选现货白银投资划不划算?

可以肯定的是选择现货白银投资是划算的,但投资者需要有足够的知识和经验,以及对市场的敏锐观察力。只有这样,投资者才能在现货白银投资中获取收益。在投资市场上,白银作为一种特殊的投资品种,一直以来都备受投资者们的…

JUC-CAS

1. CAS概述 CAS(Compare ans swap/set) 比较并交换,实现并发的一种底层技术。它将预期的值和内存中的值比较,如果相同,就更新内存中的值。如果不匹配,一直重试(自旋)。Java.util.concurrent.atomic包下的原…

Redis - redis.windows.conf配置文件及RDB和AOF数据持久化方案

Redis - redis.windows.conf配置文件及RDB和AOF数据持久化方案 Redis的高性能是由于其将所有数据都存储在了内存中,为了使Redis在重启之后仍能保证数据不丢失,需要将数据从内存中同步到硬盘中,这一过程就是持久化。 Redis支持两种方式的持久化…

【51单片机】点亮第一个LED灯

目录 点亮第一个LED灯单片机 GPIO 介绍GPIO 概念GPIO 结构 LED简介软件设计点亮D1指示灯LED流水灯 橙色 点亮第一个LED灯 单片机 GPIO 介绍 GPIO 概念 GPIO(general purpose intput output) 是通用输入输出端口的简称, 可以通过软件来控制…