如何使用CNN进行物体识别和分类_RCNN物体识别

                R-CNN,图片识别

38f6f756c5ddf43b82a5e0390e7967d3.png

目标检测(Object Detection)是图像分类的延伸,除了分类任务,还要给定多个检测目标的坐标位置。R-CNN是最早基于CNN的目标检测方法,然后基于这条路线依次演进出了SPPnet,Fast R-CNN和Faster R-CNN,然后到2017年的Mask R-CNN。

R-CNN 模型由候选区域(Region Proposal)、特征提取(Feature Extractor)和分类器(Classifier)三个模块组成。候选区域生成并提取独立类别的候选区域。特征提取从每个候选区域中提取特征,通常使用深度卷积神经网络。分类器使用线性 SVM 分类器模型将提取出的特征目标分类为已知类别之一。R-CNN 结构如下图所示。

1.候选区域

    R-CNN 生成候选区域时使用了选择性搜索(Selective Search)算法,用来提出候选区域或图像中潜在对象的边界框。选择性搜索算法把图像分割成 1000-2000 个小区域,遍历分割的小区域并合并可能性最高的相邻区域,知道整张图像合并成一个区域位置,合并后输出所有存在过的区域,即候选区域。

通过选择性搜索算法选出的候选边框为矩形,不同物体矩形边框大小不同。但CNN模型的输入层图片必须固定分辨率,如果选择性搜索算法选出的矩形边框不进行预处理,则不能作为CNN的输入提取图像特征。因此每个输入图像的矩形候选框均要进行大小格式化处理。

2.特征提取

对特定目标的识别检测的难点之一是已标记物体分类标签的训练数据不多,CNN通常进行随机初始化神经网络参数,对训练数据量的要求非常高,因此 R-CNN 采用有监督的预训练,网络优化求解时采用随机梯度下降法,学习率大小为 0.001。

特征提取网络预训练后,采用选择性搜索算法搜索出来的候选框继续对经过了预训练 CNN模型进行训练。其原理是,假设模型需要检测的目标类别有N个,则要对前述经预训练的CNN模型最后一层进行替换,输出成N+1个神经元,其中多出一个背景神经元。该层的训练过程使用随机初始化参数的方法,其它网络层则参数不变。输入一张图片,可以得到 2000个左右候选框(Bounding  Box)。数据集中的图片是提前进行人工标注的数据,每张图片都标注了涵盖目标物体的正确边框,因此在CNN阶段需要用重叠度(Intersection over Union,Io U)为 2000 个 Bounding Box 打分。若通过选择性搜索算法选出的 Bounding Box 与人工标注目标物体框的 Io U 大于 0.5,则将被 Bounding Box选中的物体标注成目标物体类别,该类物体成为正样本,若 Io U 小于 0.5 则该 Bounding Box 所框为背景类别,也成为负样本。

3.分类器

CNN的输出是一个4096个元素向量,用于描述图像的内容,并将其输入线性SVM进行分类,对每个已知类别的目标物体都训练一个支持向量机(Support  Vector Machine,SVM),因此这是一个二分类问题。在特征提取过程中 R-CNN 模型通过选择性搜索算法选取了 2000 个左右 Bounding Box,即一个 2000×4096 特征向量矩阵,之后将矩阵与 SVM 权值矩阵 4096×N 点乘,可得到 2000×N 的结果矩阵,该矩阵表示了 2000个 Bounding Box 的分类结果。

R-CNN有如下缺点:

(1)需固定每一张子图片的大小,改变了原有图片的尺寸,影响CNN分类器的效果。

(2)将每一候选图片放入分类器训练,速度很慢并且有重复计算。

(3)其训练是分阶段的,对于目标检测而言,R-CNN首先需要对预训练模型进行特定类别物体的微调训练,然后再训练SVM对提取到的特征进行分类,最后还需要训练候选框回归器(Bounding-box Regressor)对候选子图中的目标进行精确的提取。

主程序:

import time

start = time.time()

import numpy as np

import os

import six.moves.urllib as urllib

import sys

import tarfile

import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

import zipfile

import cv2

from collections import defaultdict

from io import StringIO

from matplotlib import pyplot as plt

from PIL import Image

os.chdir('C://Users//fxlir//Desktop//my_detect')#文件夹路径

#Env setup

# This is needed to display the images.

#%matplotlib inline

# This is needed since the notebook is stored in the object_detection folder.

sys.path.append("..")

#Object detection imports

from object_detection.utils import label_map_util

from object_detection.utils import visualization_utils as vis_util

#Model preparation

# What model to download.

#这是我们刚才训练的模型

MODEL_NAME = 'C://Users//fxlir//Desktop//my_detect//models'#训练好的模型文件夹

#对应的Frozen model位置

# Path to frozen detection graph. This is the actual model that is used for the object detection.

PATH_TO_CKPT = MODEL_NAME + '//frozen_inference_graph.pb'#训练好的模型

# List of the strings that is used to add correct label for each box.

PATH_TO_LABELS = os.path.join('training', 'object-detection.pbtxt') #类别标签

#改成自己例子中的类别数,4

NUM_CLASSES = 1

#Load a (frozen) Tensorflow model into memory.   

detection_graph = tf.Graph()

with detection_graph.as_default():

  od_graph_def = tf.GraphDef()

  with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as fid:

    serialized_graph = fid.read()

    od_graph_def.ParseFromString(serialized_graph)

    tf.import_graph_def(od_graph_def, name='')   

#Loading label map

label_map = label_map_util.load_labelmap(PATH_TO_LABELS)

categories = label_map_util.convert_label_map_to_categories(label_map, max_num_classes=NUM_CLASSES, use_display_name=True)

category_index = label_map_util.create_category_index(categories)

#Helper code

def load_image_into_numpy_array(image):

  (im_width, im_height) = image.size

  return np.array(image.getdata()).reshape(

      (im_height, im_width, 3)).astype(np.uint8)

#Detection

# If you want to test the code with your images, just add path to the images to the TEST_IMAGE_PATHS.

#测试图片位置

PATH_TO_TEST_IMAGES_DIR = 'C://Users//fxlir//Desktop//my_detect//raccoon'

os.chdir(PATH_TO_TEST_IMAGES_DIR)

TEST_IMAGE_PATHS = os.listdir(PATH_TO_TEST_IMAGES_DIR)

# Size, in inches, of the output images.

IMAGE_SIZE = (12, 8)

output_path = ('C://Users//fxlir//Desktop//my_detect//test_out//')

with detection_graph.as_default():

  with tf.Session(graph=detection_graph) as sess:

    # Definite input and output Tensors for detection_graph

    image_tensor = detection_graph.get_tensor_by_name('image_tensor:0')

    # Each box represents a part of the image where a particular object was detected.

    detection_boxes = detection_graph.get_tensor_by_name('detection_boxes:0')

    # Each score represent how level of confidence for each of the objects.

    # Score is shown on the result image, together with the class label.

    detection_scores = detection_graph.get_tensor_by_name('detection_scores:0')

    detection_classes = detection_graph.get_tensor_by_name('detection_classes:0')

    num_detections = detection_graph.get_tensor_by_name('num_detections:0')

    for image_path in TEST_IMAGE_PATHS:

      image = cv2.imread(image_path, 0)

      image_BGR = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)#

      #image_RGB = cv2.cvtColor(image_BGR, cv2.COLOR_BGR2RGB)#

      image_np = image_BGR

      image_np_expanded = np.expand_dims(image_np, axis=0)

      (boxes, scores, classes, num) = sess.run(

        [detection_boxes, detection_scores, detection_classes, num_detections],

        feed_dict={image_tensor: image_np_expanded})

      vis_util.visualize_boxes_and_labels_on_image_array(

            image_np,

            np.squeeze(boxes),

            np.squeeze(classes).astype(np.int32),

            np.squeeze(scores),

            category_index,

            use_normalized_coordinates=True,

            line_thickness=8)

      cv2.imwrite(output_path+image_path.split('\\')[-1],image_np)

      cv2.imshow('object detection', image_np)

      cv2.waitKey(0)

      cv2.destroyAllWindows()

end =  time.time()

print("Execution Time: ", end - start)

#欢迎订阅,一起学习,一起交流

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/504161.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

python如何不跳行打印_python怎么不换行打印

Python2.7中,执行完print后,会自动换行,如下代码会打印:abc\n123\n(其中\n代表换行)print (abc)print (123)如何实现不换行打印字符呢,下面介绍Python2.7中 实现不换行打印字符的3种简单方法:1.在print函数…

【LeetCode 总结】Leetcode 题型分类总结、索引与常用接口函数

文章目录零. Java 常用接口函数一. 动态规划二. 链表三. 哈希表四. 滑动窗口五. 字符串六. DFS、BFS七. 二分法八. 二叉树九. 偏数学、过目不忘 and 原地算法等十. 每日一题前言: 是时候开一个对于我的 LeetCode 专栏的总结索引了 虽然说大概只刷了150道左右&#…

nfs服务器_Kubernetes集群下部署NFS持久存储

NFS是网络文件系统Network File System的缩写,NFS服务器可以让PC将网络中的NFS服务器共享的目录挂载到本地的文件系统中,而在本地的系统中来看,那个远程主机的目录就好像是自己的一个磁盘分区一样。kubernetes使用NFS共享存储有两种方式&…

c语言 指针_C 语言指针详解

(给CPP开发者加星标,提升C/C技能)作者:C语言与CPP编程 / 自成一派123(本文来自作者投稿)1为什么使用指针假如我们定义了 char a’A’ ,当需要使用 ‘A’ 时,除了直接调用变量 a ,还可以定义 char *p&a &#xff0c…

kettle 插入更新 数据增量_使用Kettle工具进行增量数据同步

增量同步的方式有很多种,我使用的是: 快照表 触发器需求:当主库库表发生增删改时,从库库表与主库库表数据保持一致。环境:1、Mysql2、kettle 7.1思路:1、在主库中,将需要同步的库表新建快照表,…

idea修改代码后不重启项目_使用DevTool实现SpringBoot项目热部署

前言最近在开发的时候,每次改动代码都需要启动项目,因为有的时候改动的服务比较多,所以重启的次数也就比较多了,想着每次重启等待也挺麻烦的,就打算使用DevTools工具实现项目的热部署热部署是什么大家都知道在项目开发…

vue 计算文件hash值_vue的hash值原理,也是table切换。

.pages>div{display: none;}aaabbbcccc首页关于我的页面用户中心//hash 和页面一一对应起来//router 配置var router [{path:"/",component:document.getElementById("home")},{path:"/about",component:document.getElementById("abou…

c++ 单例模式_Redis单例、主从模式、sentinel以及集群的配置方式及优缺点对比

redis作为一种高效的缓存框架,使用是非常广泛的,在数据存储上,在运行时其将数据存储在内存中,以实现数据的高效读写,并且根据定制的持久化规则不同,其会不定期的将数据持久化到硬盘中。另外相较于其他的NoS…

jenkins 插件目录_10 个 Jenkins 实战经验,助你轻松上手持续集成

众所周知,持续构建与发布是我们日常工作中要面对的的一个重要环节,目前很多公司都采用 Jenkins 来搭建符合需求的 CI/CD 流程,作为一个持续集成的开源工具,它以安装启动方便,配置简单,上手容易的特点&#…

手机python3l运行_Python3 os.lchflags() 方法

Python3 os.lchflags() 方法概述os.lchflags() 方法用于设置路径的标记为数字标记,类似 chflags(),但是没有软链接。只支持在 Unix 下使用。语法lchflags()方法语法格式如下:os.lchflags(path, flags)参数path -- 设置标记的文件路径flags --…

jdbc 批量insert_JDBC相关知识解答

1. JDBC_PreparedStatement插入大量数据_批处理插入_效率比较(1) jdbc新增大量数据时, 如何处理能提高效率?答:使用批处理提高效率(2) 什么是批处理? JDBC如何进行批处理?答:批处理:在与数据库的一次连接中,批量的执行条 SQL 语…

lin通信ldf文件解析_lin ldf

Baby-LIN 采用闪存来保存固件, 因此更新和升级非常简便。 Baby-LIN 的软件套装是 LINWorks。这个软件包包括几个不同的应用程序。 LINWorks LDF-Editor 可以检查、......并且在未加载 LDF/SDF 文件的情况下,也可以用来监测与记录总线数据。 Baby-LIN-DLL 库文件可让用户编写应用…

vue项目使用大华摄像头怎样初始化_Vue接入监控视频技术总结

最近一直在搞监控视频接入方面的事情,积累了不少的经验,这里总结一下。提前说一句,本文提到的视频接入均是以RTSP为基础转码而来的,至于用海康大华等插件播放的咱们就闭口不提了可以看这个文章,在vue中接入ocx控件播放…

python 实现显著性检测_强!汽车车道视频检测:python+OpenCV为主实现

1 说明:1.1 完整版:汽车车道动态视频检测讲解和注释版代码,小白秒懂。1.2 pythonOpenCVmoviepynumpy为主的技术要点。1.3 代码来源:https://github.com/linghugoogle/CarND-Advanced-Lane-Lines #虽然感觉也是fork别人的&#xff…

var和function谁先优先执行_变量var声明和函数function声明优先级

变量声明优先级使用var关键字和function关键字声明的变量,会被JS的解释器优先解析执行,具有优先级使用var关键字声明变量1. 看代码说话// 在script中直接打印输出变量aconsole.log(a); // Uncaught ReferenceError: a is not defined2. 看代码说话consol…

python的变量名有哪些_【python字符串做变量名的方法有哪些?这些方法对python应用很重要】- 环球网校...

【摘要】python的功能都是建立在代码之上的,不过你知道python字符串做变量名的方法有哪些?这些方法对python应用很重要,如果你想学好python,那么本文内容一定要自己试试,毕竟实践出真知,那么python字符串做变量名的方…

如何学习c语言 零基础20天学会C语言

C语言开发 学习C语言不是一朝一夕的事情,但也不需要花费十年时间才能精通。如何以最小的代价学习并精通C语言是本文的主题。请注意,即使是“最小的代价”,也绝不是什么捷径,而是以最短的时间取得最多的收获,同时也意味…

钟平逻辑英语语法_逻辑英语-钟平笔记.pdf

英语主干定位:(状1 )主 (定1)谓 (状2 )( 宾 )(定 2 、状 1 )中文主干定位:(状1 、定 1)主 (状2 )谓 (定2 )( 宾 )(状 1 )主语:句首的独立名词性结构谓语:排除过程首先排除从句中和介词短语中动词宾语:谓语后的独立名词…

python爬虫分析_Python爬虫解析网页的4种方式

文章目录 爬虫的价值 正则表达式 requests-html BeautifulSoup lxml的XPath 爬虫的价值 常见的数据获取方式就三种:自有数据、购买数据、爬取数据。用Python写爬虫工具在现在是一种司空见惯的事情,每个人都希望能够写一段程序去互联网上扒一点资料下来&a…

学习C/C++的简单方法

如何学习C呢。C和C是很多专业的必修课,尤其对计算机专业来说,更是重中之重。C语言是早期发展的高级语言,具备执行速度快,语法优美等特点。是底层高效率系统的首选开发语言。今天就和大家分享一下怎么学好C/C语言吧 _ 怎么学好C、…