如何使用CNN进行物体识别和分类_RCNN物体识别

                R-CNN,图片识别

38f6f756c5ddf43b82a5e0390e7967d3.png

目标检测(Object Detection)是图像分类的延伸,除了分类任务,还要给定多个检测目标的坐标位置。R-CNN是最早基于CNN的目标检测方法,然后基于这条路线依次演进出了SPPnet,Fast R-CNN和Faster R-CNN,然后到2017年的Mask R-CNN。

R-CNN 模型由候选区域(Region Proposal)、特征提取(Feature Extractor)和分类器(Classifier)三个模块组成。候选区域生成并提取独立类别的候选区域。特征提取从每个候选区域中提取特征,通常使用深度卷积神经网络。分类器使用线性 SVM 分类器模型将提取出的特征目标分类为已知类别之一。R-CNN 结构如下图所示。

1.候选区域

    R-CNN 生成候选区域时使用了选择性搜索(Selective Search)算法,用来提出候选区域或图像中潜在对象的边界框。选择性搜索算法把图像分割成 1000-2000 个小区域,遍历分割的小区域并合并可能性最高的相邻区域,知道整张图像合并成一个区域位置,合并后输出所有存在过的区域,即候选区域。

通过选择性搜索算法选出的候选边框为矩形,不同物体矩形边框大小不同。但CNN模型的输入层图片必须固定分辨率,如果选择性搜索算法选出的矩形边框不进行预处理,则不能作为CNN的输入提取图像特征。因此每个输入图像的矩形候选框均要进行大小格式化处理。

2.特征提取

对特定目标的识别检测的难点之一是已标记物体分类标签的训练数据不多,CNN通常进行随机初始化神经网络参数,对训练数据量的要求非常高,因此 R-CNN 采用有监督的预训练,网络优化求解时采用随机梯度下降法,学习率大小为 0.001。

特征提取网络预训练后,采用选择性搜索算法搜索出来的候选框继续对经过了预训练 CNN模型进行训练。其原理是,假设模型需要检测的目标类别有N个,则要对前述经预训练的CNN模型最后一层进行替换,输出成N+1个神经元,其中多出一个背景神经元。该层的训练过程使用随机初始化参数的方法,其它网络层则参数不变。输入一张图片,可以得到 2000个左右候选框(Bounding  Box)。数据集中的图片是提前进行人工标注的数据,每张图片都标注了涵盖目标物体的正确边框,因此在CNN阶段需要用重叠度(Intersection over Union,Io U)为 2000 个 Bounding Box 打分。若通过选择性搜索算法选出的 Bounding Box 与人工标注目标物体框的 Io U 大于 0.5,则将被 Bounding Box选中的物体标注成目标物体类别,该类物体成为正样本,若 Io U 小于 0.5 则该 Bounding Box 所框为背景类别,也成为负样本。

3.分类器

CNN的输出是一个4096个元素向量,用于描述图像的内容,并将其输入线性SVM进行分类,对每个已知类别的目标物体都训练一个支持向量机(Support  Vector Machine,SVM),因此这是一个二分类问题。在特征提取过程中 R-CNN 模型通过选择性搜索算法选取了 2000 个左右 Bounding Box,即一个 2000×4096 特征向量矩阵,之后将矩阵与 SVM 权值矩阵 4096×N 点乘,可得到 2000×N 的结果矩阵,该矩阵表示了 2000个 Bounding Box 的分类结果。

R-CNN有如下缺点:

(1)需固定每一张子图片的大小,改变了原有图片的尺寸,影响CNN分类器的效果。

(2)将每一候选图片放入分类器训练,速度很慢并且有重复计算。

(3)其训练是分阶段的,对于目标检测而言,R-CNN首先需要对预训练模型进行特定类别物体的微调训练,然后再训练SVM对提取到的特征进行分类,最后还需要训练候选框回归器(Bounding-box Regressor)对候选子图中的目标进行精确的提取。

主程序:

import time

start = time.time()

import numpy as np

import os

import six.moves.urllib as urllib

import sys

import tarfile

import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

import zipfile

import cv2

from collections import defaultdict

from io import StringIO

from matplotlib import pyplot as plt

from PIL import Image

os.chdir('C://Users//fxlir//Desktop//my_detect')#文件夹路径

#Env setup

# This is needed to display the images.

#%matplotlib inline

# This is needed since the notebook is stored in the object_detection folder.

sys.path.append("..")

#Object detection imports

from object_detection.utils import label_map_util

from object_detection.utils import visualization_utils as vis_util

#Model preparation

# What model to download.

#这是我们刚才训练的模型

MODEL_NAME = 'C://Users//fxlir//Desktop//my_detect//models'#训练好的模型文件夹

#对应的Frozen model位置

# Path to frozen detection graph. This is the actual model that is used for the object detection.

PATH_TO_CKPT = MODEL_NAME + '//frozen_inference_graph.pb'#训练好的模型

# List of the strings that is used to add correct label for each box.

PATH_TO_LABELS = os.path.join('training', 'object-detection.pbtxt') #类别标签

#改成自己例子中的类别数,4

NUM_CLASSES = 1

#Load a (frozen) Tensorflow model into memory.   

detection_graph = tf.Graph()

with detection_graph.as_default():

  od_graph_def = tf.GraphDef()

  with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as fid:

    serialized_graph = fid.read()

    od_graph_def.ParseFromString(serialized_graph)

    tf.import_graph_def(od_graph_def, name='')   

#Loading label map

label_map = label_map_util.load_labelmap(PATH_TO_LABELS)

categories = label_map_util.convert_label_map_to_categories(label_map, max_num_classes=NUM_CLASSES, use_display_name=True)

category_index = label_map_util.create_category_index(categories)

#Helper code

def load_image_into_numpy_array(image):

  (im_width, im_height) = image.size

  return np.array(image.getdata()).reshape(

      (im_height, im_width, 3)).astype(np.uint8)

#Detection

# If you want to test the code with your images, just add path to the images to the TEST_IMAGE_PATHS.

#测试图片位置

PATH_TO_TEST_IMAGES_DIR = 'C://Users//fxlir//Desktop//my_detect//raccoon'

os.chdir(PATH_TO_TEST_IMAGES_DIR)

TEST_IMAGE_PATHS = os.listdir(PATH_TO_TEST_IMAGES_DIR)

# Size, in inches, of the output images.

IMAGE_SIZE = (12, 8)

output_path = ('C://Users//fxlir//Desktop//my_detect//test_out//')

with detection_graph.as_default():

  with tf.Session(graph=detection_graph) as sess:

    # Definite input and output Tensors for detection_graph

    image_tensor = detection_graph.get_tensor_by_name('image_tensor:0')

    # Each box represents a part of the image where a particular object was detected.

    detection_boxes = detection_graph.get_tensor_by_name('detection_boxes:0')

    # Each score represent how level of confidence for each of the objects.

    # Score is shown on the result image, together with the class label.

    detection_scores = detection_graph.get_tensor_by_name('detection_scores:0')

    detection_classes = detection_graph.get_tensor_by_name('detection_classes:0')

    num_detections = detection_graph.get_tensor_by_name('num_detections:0')

    for image_path in TEST_IMAGE_PATHS:

      image = cv2.imread(image_path, 0)

      image_BGR = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)#

      #image_RGB = cv2.cvtColor(image_BGR, cv2.COLOR_BGR2RGB)#

      image_np = image_BGR

      image_np_expanded = np.expand_dims(image_np, axis=0)

      (boxes, scores, classes, num) = sess.run(

        [detection_boxes, detection_scores, detection_classes, num_detections],

        feed_dict={image_tensor: image_np_expanded})

      vis_util.visualize_boxes_and_labels_on_image_array(

            image_np,

            np.squeeze(boxes),

            np.squeeze(classes).astype(np.int32),

            np.squeeze(scores),

            category_index,

            use_normalized_coordinates=True,

            line_thickness=8)

      cv2.imwrite(output_path+image_path.split('\\')[-1],image_np)

      cv2.imshow('object detection', image_np)

      cv2.waitKey(0)

      cv2.destroyAllWindows()

end =  time.time()

print("Execution Time: ", end - start)

#欢迎订阅,一起学习,一起交流

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/504161.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

nfs服务器_Kubernetes集群下部署NFS持久存储

NFS是网络文件系统Network File System的缩写,NFS服务器可以让PC将网络中的NFS服务器共享的目录挂载到本地的文件系统中,而在本地的系统中来看,那个远程主机的目录就好像是自己的一个磁盘分区一样。kubernetes使用NFS共享存储有两种方式&…

c语言 指针_C 语言指针详解

(给CPP开发者加星标,提升C/C技能)作者:C语言与CPP编程 / 自成一派123(本文来自作者投稿)1为什么使用指针假如我们定义了 char a’A’ ,当需要使用 ‘A’ 时,除了直接调用变量 a ,还可以定义 char *p&a &#xff0c…

idea修改代码后不重启项目_使用DevTool实现SpringBoot项目热部署

前言最近在开发的时候,每次改动代码都需要启动项目,因为有的时候改动的服务比较多,所以重启的次数也就比较多了,想着每次重启等待也挺麻烦的,就打算使用DevTools工具实现项目的热部署热部署是什么大家都知道在项目开发…

c++ 单例模式_Redis单例、主从模式、sentinel以及集群的配置方式及优缺点对比

redis作为一种高效的缓存框架,使用是非常广泛的,在数据存储上,在运行时其将数据存储在内存中,以实现数据的高效读写,并且根据定制的持久化规则不同,其会不定期的将数据持久化到硬盘中。另外相较于其他的NoS…

jenkins 插件目录_10 个 Jenkins 实战经验,助你轻松上手持续集成

众所周知,持续构建与发布是我们日常工作中要面对的的一个重要环节,目前很多公司都采用 Jenkins 来搭建符合需求的 CI/CD 流程,作为一个持续集成的开源工具,它以安装启动方便,配置简单,上手容易的特点&#…

jdbc 批量insert_JDBC相关知识解答

1. JDBC_PreparedStatement插入大量数据_批处理插入_效率比较(1) jdbc新增大量数据时, 如何处理能提高效率?答:使用批处理提高效率(2) 什么是批处理? JDBC如何进行批处理?答:批处理:在与数据库的一次连接中,批量的执行条 SQL 语…

python 实现显著性检测_强!汽车车道视频检测:python+OpenCV为主实现

1 说明:1.1 完整版:汽车车道动态视频检测讲解和注释版代码,小白秒懂。1.2 pythonOpenCVmoviepynumpy为主的技术要点。1.3 代码来源:https://github.com/linghugoogle/CarND-Advanced-Lane-Lines #虽然感觉也是fork别人的&#xff…

如何学习c语言 零基础20天学会C语言

C语言开发 学习C语言不是一朝一夕的事情,但也不需要花费十年时间才能精通。如何以最小的代价学习并精通C语言是本文的主题。请注意,即使是“最小的代价”,也绝不是什么捷径,而是以最短的时间取得最多的收获,同时也意味…

学习C/C++的简单方法

如何学习C呢。C和C是很多专业的必修课,尤其对计算机专业来说,更是重中之重。C语言是早期发展的高级语言,具备执行速度快,语法优美等特点。是底层高效率系统的首选开发语言。今天就和大家分享一下怎么学好C/C语言吧 _ 怎么学好C、…

python数据预处理代码_Python中数据预处理(代码)

本篇文章给大家带来的内容是关于Python中数据预处理(代码),有一定的参考价值,有需要的朋友可以参考一下,希望对你有所帮助。1、导入标准库import numpy as np import matplotlib.pyplot as plt import pandas as pd 2、…

零基础想学好C语言编程,首先要掌握的是正确的学习思路!

如果新手要学习编程,一些前辈都会建议从Python、PHP、Java开始学。 不过,有些程序员是直接从C语言强势入门编程的。 那么,如何学习C语言呢?下面提供4种入门C语言的方法: 0、刷题 绝大多数的程序员学编程的时候,还…

C/C++初学者快速提升?

如今,软件开发行业继续向前大步迈进。信息技术越来越吃香,越来越多人学习学习c语言,那么如何系统有效的学习C语言?下面分享给大家的有效学习语言的方法,希望可以帮到你! 一、了解大纲,通览教材 想学好C语言最重要的一…

pytorch 矩阵相乘_深度学习 — — PyTorch入门(三)

点击关注我哦autograd和动态计算图可以说是pytorch中非常核心的部分,我们在之前的文章中提到:autograd其实就是反向求偏导的过程,而在求偏导的过程中,链式求导法则和雅克比矩阵是其实现的数学基础;Tensor构成的动态计算…

codeblocks如何让输出结果 空格_简单讲讲如何实现两个正整数相加,然后输出这个结果...

首先吧,两个整数123 456,相加得到579,我们就得输出579,这个很容易操作,但是如果是:1212161596156198115645646886148461554 2671232162176217624372497590415915915029125 呢?long ? long lo…

C语言和其他高级语言的最大的区别是什么?

提到C语言,我们知道C语言和其他高级语言的最大的区别就是C语言是要操作内存的! 我们需要知道——变量,其实是内存地址的一个抽像名字罢了。在静态编译的程序中,所有的变量名都会在编译时被转成内存地址。机器是不知道我们取的名字…

零基础学C语言必备书籍,抖音编程达人推荐(进群交流学习互动)

C语言从入门到进阶的书籍推荐。 【基础】 这本谭浩强写的【C语言程序设计】可谓是广大人事的入门书籍。我曾经用的教材就是这本,里面大概涵盖了 C语言 语法的 80% 。一个很适合自学的入门书。 【c prime puls】 是 C语言 最经典的入门书籍,极力推荐。每…

网站如何进行渠道跟踪_网站如何进行搜索引擎优化?

这是一个很一般的平台标题,没有任何吸引力,但是它真的可以被一个很好的基层站长估计的很少,我问一个做了多年基层站长的朋友,我说如何做好搜索引擎优化的SEO,他给我的答案很难,答案太大了,所以我…

什么是编程语言,大神教你为什么要学C语言?

首先来说说编程语言这个概念。 编程,其实就是让计算机听懂自己的话,让计算机帮自己想干的事情。编程语言,就是让你能够和计算机进行交流的一种语言。说白了就是让你的软件按你的命令干活。 打比方说,我们经常在僵尸片里面看到&a…

xshell vim 不能粘贴_linux基础知识:vim(vi)的知识

### vim三种模式命令行模式:在该模式下不能对文本进行- 直接编辑,可以输入一些操作(删除行,复制行,移动光标,粘贴)【打开之后默认进入的模式】编辑模式:在该模式下可以对文件内容进行编辑末行模式&#xff…

新手如何学习C语言/C++,教你一年时间是拿到年薪50万

最近会有一些初中高中大学的同学问,C语言C不知道怎么学习不会写代码怎么办?大致上都是一些类似的问题吧,回想一下自己走过的路,反复的了很久思考然后写了这篇文章,希望可以对一些迷惘新手小白程序员同学一丝帮助&#…