Pytorch torchvision完成Faster-rcnn目标检测demo及源码详解

Torchvision更新到0.3.0后支持了更多的功能,其中新增模块detection中实现了整个faster-rcnn的功能。本博客主要讲述如何通过torchvision和pytorch使用faster-rcnn,并提供一个demo和对应代码及解析注释。

目录

如果你不想深入了解原理和训练,只想用Faster-rcnn做目标检测,请看这里

torchvision中Faster-rcnn接口

一个demo

使用方法

如果你想深入了解原理,并训练自己的模型

环境搭建

准备训练数据

模型训练

单张图片检测

效果

如果你不想深入了解原理和训练,只想用Faster-rcnn做目标检测,请看这里
torchvision中Faster-rcnn接口
torchvision内部集成了Faster-rcnn的模型,其接口和调用方式野非常简洁,目前官方提供resnet50+rpn在coco上训练的模型,调用该模型只需要几行代码:

>>> import torch
>>> import torchvision
 
// 创建模型,pretrained=True将下载官方提供的coco2017模型
>>> model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
>>> model.eval()
>>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
>>> predictions = model(x)
 
 
 
注意网络的输入x是一个Tensor构成的list,而输出prediction则是一个由dict构成list。prediction的长度和网络输入的list中Tensor个数相同。prediction中的每个dict包含输出的结果:

其中boxes是检测框坐标,labels是类别,scores则是置信度。

>>> predictions[0]
 
{'boxes': tensor([], size=(0, 4), grad_fn=<StackBackward>), 'labels': tensor([], dtype=torch.int64), 'scores': tensor([], grad_fn=<IndexBackward>)}
一个demo
如果你不想自己写读取图片/预处理/后处理,我这里有个写好的demo.py,可以跑在任何安装了pytorch1.1+和torchvision0.3+的环境下,不需要其他依赖,可以用来完成目标检测的任务。

为了能够显示类别标签,我们将coco的所有类别写入coco_names.py

names = {'0': 'background', '1': 'person', '2': 'bicycle', '3': 'car', '4': 'motorcycle', '5': 'airplane', '6': 'bus', '7': 'train', '8': 'truck', '9': 'boat', '10': 'traffic light', '11': 'fire hydrant', '13': 'stop sign', '14': 'parking meter', '15': 'bench', '16': 'bird', '17': 'cat', '18': 'dog', '19': 'horse', '20': 'sheep', '21': 'cow', '22': 'elephant', '23': 'bear', '24': 'zebra', '25': 'giraffe', '27': 'backpack', '28': 'umbrella', '31': 'handbag', '32': 'tie', '33': 'suitcase', '34': 'frisbee', '35': 'skis', '36': 'snowboard', '37': 'sports ball', '38': 'kite', '39': 'baseball bat', '40': 'baseball glove', '41': 'skateboard', '42': 'surfboard', '43': 'tennis racket', '44': 'bottle', '46': 'wine glass', '47': 'cup', '48': 'fork', '49': 'knife', '50': 'spoon', '51': 'bowl', '52': 'banana', '53': 'apple', '54': 'sandwich', '55': 'orange', '56': 'broccoli', '57': 'carrot', '58': 'hot dog', '59': 'pizza', '60': 'donut', '61': 'cake', '62': 'chair', '63': 'couch', '64': 'potted plant', '65': 'bed', '67': 'dining table', '70': 'toilet', '72': 'tv', '73': 'laptop', '74': 'mouse', '75': 'remote', '76': 'keyboard', '77': 'cell phone', '78': 'microwave', '79': 'oven', '80': 'toaster', '81': 'sink', '82': 'refrigerator', '84': 'book', '85': 'clock', '86': 'vase', '87': 'scissors', '88': 'teddybear', '89': 'hair drier', '90': 'toothbrush'}
然后构建一个可以读取图片并检测的demo.py

import torch
import torchvision
import argparse
import cv2
import numpy as np
import sys
sys.path.append('./')
import coco_names
import random
 
def get_args():
    parser = argparse.ArgumentParser(description='Pytorch Faster-rcnn Detection')
 
    parser.add_argument('image_path', type=str, help='image path')
    parser.add_argument('--model', default='fasterrcnn_resnet50_fpn', help='model')
    parser.add_argument('--dataset', default='coco', help='model')
    parser.add_argument('--score', type=float, default=0.8, help='objectness score threshold')
    args = parser.parse_args()
 
    return args
 
def random_color():
    b = random.randint(0,255)
    g = random.randint(0,255)
    r = random.randint(0,255)
 
    return (b,g,r)
 
def main():
    args = get_args()
    input = []
    num_classes = 91
    names = coco_names.names
        
    # Model creating
    print("Creating model")
    model = torchvision.models.detection.__dict__[args.model](num_classes=num_classes, pretrained=True)  
    model = model.cuda()
 
    model.eval()
 
    src_img = cv2.imread(args.image_path)
    img = cv2.cvtColor(src_img, cv2.COLOR_BGR2RGB)
    img_tensor = torch.from_numpy(img/255.).permute(2,0,1).float().cuda()
    input.append(img_tensor)
    out = model(input)
    boxes = out[0]['boxes']
    labels = out[0]['labels']
    scores = out[0]['scores']
 
    for idx in range(boxes.shape[0]):
        if scores[idx] >= args.score:
            x1, y1, x2, y2 = boxes[idx][0], boxes[idx][1], boxes[idx][2], boxes[idx][3]
            name = names.get(str(labels[idx].item()))
            cv2.rectangle(src_img,(x1,y1),(x2,y2),random_color(),thickness=2)
            cv2.putText(src_img, text=name, org=(x1, y1+10), fontFace=cv2.FONT_HERSHEY_SIMPLEX, 
                fontScale=0.5, thickness=1, lineType=cv2.LINE_AA, color=(0, 0, 255))
 
    cv2.imshow('result',src_img)
    cv2.waitKey()
    cv2.destroyAllWindows()
 
    
 
if __name__ == "__main__":
    main()
运行命令

$ python demo.py [image path]
就能完成检测,并且不需要任何其他依赖,只需要Pytorch1.1+和torchvision0.3+。看下效果:

使用方法
我发现好像很多人对上面这个demo怎么用不太清楚,照着下面的流程做就好了:

下载代码:https://github.com/supernotman/Faster-RCNN-with-torchvision
下载模型:Baidu Cloud
运行命令:
$ python detect.py --model_path [模型路径] --image_path [图片路径]
其实非常简单。

如果你想深入了解原理,并训练自己的模型
这里提供一份我重构过的代码,把torchvision中的faster-rcnn部分提取出来,可以训练自己的模型(目前只支持coco),并有对应博客讲解。

代码地址:https://github.com/supernotman/Faster-RCNN-with-torchvision
代码解析博客:
Pytorch torchvision构建Faster-rcnn(一)----coco数据读取

Pytorch torchvision构建Faster-rcnn(二)----基础网络

Pytorch torchvision构建Faster-rcnn(三)----RPN

Pytorch torchvision构建Faster-rcnn(四)----ROIHead

训练模型:Baidu Cloud
环境搭建
下载代码:

$ git clone https://github.com/supernotman/Faster-RCNN-with-torchvision.git
安装依赖:

$ pip install -r requirements.txt
注意:

代码要求Pytorch版本大于1.1.0,torchvision版本大于0.3.0。

如果某个依赖项通过pip安装过慢,推荐替换清华源:

$ pip install -i https://pypi.tuna.tsinghua.edu.cn/simple some-package
如果pytorch安装过慢,可参考conda安装Pytorch下载过慢解决办法(7月23日更新ubuntu下pytorch1.1安装方法)

准备训练数据
下载coco2017数据集,下载地址:

http://images.cocodataset.org/zips/train2017.zip
http://images.cocodataset.org/annotations/annotations_trainval2017.zip

http://images.cocodataset.org/zips/val2017.zip
http://images.cocodataset.org/annotations/stuff_annotations_trainval2017.zip

http://images.cocodataset.org/zips/test2017.zip
http://images.cocodataset.org/annotations/image_info_test2017.zip 

如果下载速度过慢,可参考博客COCO2017数据集国内下载地址

数据下载后按照如下结构放置:

  coco/
    2017/
      annotations/
      test2017/
      train2017/
      val2017/
模型训练
$ python -m torch.distributed.launch --nproc_per_node=$gpus --use_env train.py --world-size $gpus --b 4
训练采用了Pytorch的distributedparallel方式,支持多gpu。

注意其中$gpus为指定使用的gpu数量,b为每个gpu上的batch_size,因此实际batch_size大小为$gpus × b。

实测当b=4,1080ti下大概每张卡会占用11G显存,请根据情况自行设定。

训练过程中每个epoch会给出一次评估结果,形式如下:

 Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.352
 Average Precision  (AP) @[ IoU=0.50      | area=   all | maxDets=100 ] = 0.573
 Average Precision  (AP) @[ IoU=0.75      | area=   all | maxDets=100 ] = 0.375
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.207
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.387
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.448
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=  1 ] = 0.296
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets= 10 ] = 0.474
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.498
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.312
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.538
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.631

其中AP为准确率,AR为召回率,第一行为训练结果的mAP,第四、五、六行分别为小/中/大物体对应的mAP

单张图片检测
$ python detect.py --model_path result/model_13.pth --image_path imgs/1.jpg
model_path为模型路径,image_path为测试图片路径。

代码文件夹中assets给出了从coco2017测试集中挑选的11张图片测试结果。

效果

任何程序错误,以及技术疑问或需要解答的,请添加

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/547028.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Hadoop安装配置

1、集群部署介绍 1.1 Hadoop简介 Hadoop是Apache软件基金会旗下的一个开源分布式计算平台。以Hadoop分布式文件系统&#xff08;HDFS&#xff0c;Hadoop Distributed Filesystem&#xff09;和MapReduce&#xff08;Google MapReduce的开源实现&#xff09;为核心的Hadoop为用户…

iOS设置拍照retake和use按钮为中文简体

iOS设置拍照retake和use按钮为中文简体&#xff0c;设置有两种方式一个是代码直接控制&#xff0c;第二就是xcode配置本机国际化为“china”&#xff08;简体中文&#xff09;。 本文重点要说的是第二种&#xff0c;这样配置有两个好处&#xff0c;一是操作比较简单&#xff0…

QT5 QSqlQuery的SELECT INSERT UPDATE DELETE命令用法

1.QSqlQuery的SELECT查询记录用法&#xff1a; QSqlQuery q("SELECT * FROM departments");QSqlRecord rec q.record();int idCol rec.indexOf("departID"); // index of the field "departID"int nameColrec.indexOf("department")…

实时手势识别 【手部跟踪】Mediapipe中的hand

参考链接&#xff1a; 1&#xff09;github代码链接&#xff1a;https://github.com/google/mediapipe 2&#xff09;说明文档&#xff1a;https://google.github.io/mediapipe 3&#xff09;python环境配置文档&#xff1a;https://google.github.io/mediapipe/getting_sta…

react native仿微信性别选择-自定义弹出框

简述 要实现微信性别选择需要使用两部分的技术&#xff1a; 第一、是自定义弹出框&#xff1b; 第二、单选框控件使用&#xff1b; 效果 实现 一、配置弹出框 弹出框用的是&#xff1a;react-native-popup-dialog&#xff08;Git地址&#xff1a;https://github.com/jacklam…

斯蒂芬斯蒂芬但是当时发生的s

2019独角兽企业重金招聘Python工程师标准>>> 什么是啊啊啊啊啊啊啊 "> 转载于:https://my.oschina.net/ivanfjz/blog/190114

Error processing line 1 of vision-1.0.0-py3.6-nspkg.pth AttributeError: ‘NoneType‘ object has no

最近调试代码不知道安装什么包导致代码运行的时候出现报错 AttributeError: NoneType object has no attribute loader &#xff0c;虽然代码也能运行通过&#xff0c;但是报错还是很不舒服。 Remainder of file ignored Error processing line 1 of D:\Anaconda3\envs\fastrc…

华为交换机S3700清空配置方法

1、用户视图下输入&#xff1a;reset saved-configuration&#xff1b;输入&#xff1a;Y&#xff0c;确认清除 2、输入&#xff1a;reboot&#xff1b;重启系统&#xff08;第1次提示输入&#xff1a;N 不保存配置&#xff1b;第2次提示输入&#xff1a;Y 确认重启&#xff0…

Udp通讯(零基础)

前面学习了Tcp通讯之后听老师同学们讲到Udp也可以通讯&#xff0c;实现还要跟简单&#xff0c;没有繁琐的连接&#xff0c;所以最近学习了一下&#xff0c;记录下来以便忘记&#xff0c;同时也发表出来与大家相互学习&#xff0c;下面是我自己写的一个聊天例子&#xff0c;实现…

VOC数据集格式转化成COCO数据集格式

VOC数据集格式转化成COCO数据集格式 一、唠叨 之前写过一篇关于coco数据集转化成VOC格式的博客COCO2VOC&#xff0c;最近读到CenterNet的官方代码&#xff0c;实现上则是将voc转化成coco数据格式&#xff0c;这样的操作我个人感觉很不习惯&#xff0c;也觉得有些奇葩&…

react native android6+拍照闪退或重启的解决方案

前言 android 6权限使用的时候需要动态申请&#xff0c;那么在使用rn的时候要怎么处理拍照权限问题呢&#xff1f;本文提供的是一揽子rn操作相册、拍照的解决方案&#xff0c;请看正文的提高班部分。 解决步骤 1、AndroidManifest.xml设置拍照权限&#xff1a; <uses-perm…

学术论文SCI、期刊、毕业设计中的图表专用软件

Origin Origin是由OriginLab公司开发的一个科学绘图、数据分析软件&#xff0c;支持在Microsoft Windows下运行。Origin支持各种各样的2D/3D图形。Origin中的数据分析功能包括统计&#xff0c;信号处理&#xff0c;曲线拟合以及峰值分析。 Origin中的曲线拟合是采用基于Lever…

常用的学术论文图表(折线图、柱状图)matplotlib python代码模板

最终选用了pythonMatplotlib。Matplotlib是著名Python的标配画图包&#xff0c;其绘图函数的名字基本上与 Matlab 的绘图函数差不多。优点是曲线精致&#xff0c;软件开源免费&#xff0c;支持Latex公式插入&#xff0c;且许多时候只需要一行或几行代码就能搞定。 然后小编经过…

史上最详细nodejs版本管理器nvm的安装与使用(附注意事项和优化方案)

使用场景 在Node版本快速更新迭代的今天&#xff0c;新老项目使用的node版本号可能已经不相同了&#xff0c;node版本更新越来越快&#xff0c;项目越做越多&#xff0c;node切换版本号的需求越来越迫切&#xff0c;传统卸载一个版本在安装另一个版本的方式太过于麻烦&#xf…

【TensorFlow】 基于视频时序LSTM的行为动作识别

简介 本文基于LSTM来完成用户行为识别。数据集来源&#xff1a;https://archive.ics.uci.edu/ml/machine-learning-databases/00240/ 此数据集一共有6种行为状态&#xff1a; 行走&#xff1b; 站立&#xff1b; 躺下&#xff1b; 坐下&#xff1b; 上楼&#xff1b; 下楼&am…

利用Asp.net MVC处理文件的上传下载

如果你仅仅只有Asp.net Web Forms背景转而学习Asp.net MVC的&#xff0c;我想你的第一个经历或许是那些曾经让你的编程变得愉悦无比的服务端控件都驾鹤西去了.FileUpload就是其中一个&#xff0c;而这个控件的缺席给我们带来一些小问题。这篇文章主要说如何在Asp.net MVC中上传…

Python遍历文件夹下所有文件及目录

遍历文件夹中的所有子文件夹及子文件使用os.walk()方法非常简单。 语法格式大致如下&#xff1a; os.walk(top[, topdownTrue[, onerrorNone[, followlinksFalse]]]) top – 根目录下的每一个文件夹(包含它自己), 产生3-元组 (dirpath, dirnames, filenames)【文件夹路径, 文…

华为交换机telnet和ftp服务开启/关闭命令

1.telnet开启/关闭 在系统视图下 启用方式如下&#xff1a; telnet server enable //使能telnet服务关闭方式如下&#xff1a; undo telnet server //关闭telnet服务2.FTP开启/关闭 通过display ftp-server查看启用状态 如果已经启用&#xff0c;会在查看的命令中显示&#…

为什么用Spring来管理Hibernate?

为什么用Spring来管理Hibernate&#xff1f;为什么要用Hibernate框架&#xff1f;这个在《Hibernate介绍》博客中已经提到了。既然用Hibernate框架访问管理持久层&#xff0c;那为何又提到用Spring来管理以及整合Hibernate呢&#xff1f;首先我们来看一下Hibernate进行操作的步…

Python-Pandas之两个Dataframe的差异比较

昨天在外网找到一个比较dataframe的好库&#xff0c;叫datacompy&#xff0c;它的优点有&#xff1a; 1、可以把对比后的信息详情打印出来&#xff0c;比如列是否相等&#xff0c;行是否相等&#xff1b; 2、在数据中如果有不相等列&#xff0c;那么就只比较相同的列&#xf…