CV(10)--目标检测

前言

仅记录学习过程,有问题欢迎讨论

目标检测

object detection,就是在给定的图片中精确找到物体所在位置,并标注出物体的类别;输出的是分类类别label+物体的外框(x, y, width, height)。

目标检测算法:
1、候选区域/框 + 深度学习分类:通过提取候选区域,并对相应区域进行以深度学习方法为主的
分类的方案,如:
• R-CNN(Selective Search + CNN + SVM)
• SPP-net(ROI Pooling)
• Fast R-CNN(Selective Search + CNN + ROI)
• Faster R-CNN(RPN + CNN + ROI)
2、 基于深度学习的回归方法:YOLO/SSD 等方法

IOU:
一个简单的测量标准,测量两个框之间的相似度,目标和实际的交集/并集,值越大越接近。
TP:预测为正样本,实际也为正样本
TN:预测为负样本,实际也为负样本
FP:预测为正样本,实际为负样本
FN:预测为负样本,实际为正样本

precision(精确度)和recall(召回率):精度是找得对,召回是找的全
precision = TP / (TP + FP)
recall = TP / (TP + FN)
F1 = 2 * (precision * recall) / (precision + recall)

边框回归:
目标是寻找一种关系使得原始窗口经过映射接近真实窗口
Input:
P=(Px,Py,Pw,Ph)
(注:训练阶段输入还包括 Ground Truth)
Output:
需要进行的平移变换和尺度缩放 dx,dy,dw,dh ,或者说是Δx,Δy,Sw,Sh 。
有了这四个变换我们就可以直接得到 Ground Truth。

TWO Stage:
Faster R-CNN(RPN + CNN + ROI)

Faster-RCNN

  1. Conv layers:Faster RCNN首先使用一组基础的conv+relu+pooling层提取 image的feature maps。该feature maps被共享用于后续 RPN层和全连接层。

    • 在Faster RCNN Conv layers中对所有的卷积都做了pad处理( pad=1,即填充一圈0),导致原图 变为 (M+2)x(N+2)大小,再做3x3卷积后输出MxN 。正是这种设置,导致Conv layers中的conv层 不改变输入和输出矩阵大小
    • Conv layers中的pooling层kernel_size=2,stride=2。 这样每个经过pooling层的MxN矩阵,都会变为(M/2)x(N/2)大小
    • 一个MxN大小的矩阵经过Conv layers固定变为(M/16)x(N/16)。 这样Conv layers生成的feature map都可以和原图对应起来
  2. Region Proposal Networks(RPN):RPN网络用于生成region proposals。通过softmax判断anchors属于 positive或者negative,再利用bounding box regression 修正anchors获得精确的proposals。

    • 直接使用RPN生成检测框,是Faster R-CNN的巨 大优势,能极大提升检测框的生成速度。

在这里插入图片描述

- 上面一条通过softmax分类anchors(按特征中心点穷举9个框),获得positive和negative分类-二分类;- 下面一条用于计算对于anchors的bounding box regression偏移量,以获得精确的proposal。- 最后的Proposal层则负责综合positive anchors和对应bounding box regression偏移量获取 proposals,同时剔除太小和超出边界的proposals
  1. Roi Pooling:该层收集输入的feature maps和proposals, 综合这些信息后提取proposal feature maps,送入后续全连接层判定目标类别。

    • proposal是对应MN尺度的,所以先使用spatial_scale参数将其映射回(M/16)(N/16)大小 feature map尺度;
    • 再将每个proposal对应的feature map区域水平分为pooled_w * pooled_h的网格;
    • 对网格的每一份都进行max pooling处理。
  2. Classification:利用proposal feature maps计算 proposal的类别,同时再次bounding box regression获得检测框最终的精确位置。

实现Faster-RCNN网络结构

"""
实现Faster-RCNN
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
import numpy as np
import cv2# 定义骨干网络,这里使用 ResNet
class ResNetBackbone(nn.Module):def __init__(self):super(ResNetBackbone, self).__init__()resnet = torchvision.models.resnet50(pretrained=True)self.features = nn.Sequential(*list(resnet.children())[:-2])def forward(self, x):x = self.features(x)return x# 区域生成网络 (RPN)
class RPN(nn.Module):def __init__(self, in_channels, num_anchors):super(RPN, self).__init__()self.conv = nn.Conv2d(in_channels, 512, kernel_size=3, stride=1, padding=1)#  2表示每个锚点有两种可能的类别:正负样本。通过这层卷积,网络将对每个锚点预测其概率得分。self.cls_layer = nn.Conv2d(512, num_anchors * 2, kernel_size=1, stride=1)# 4 表示对于每个锚点,要预测其边界框的 4 个参数self.reg_layer = nn.Conv2d(512, num_anchors * 4, kernel_size=1, stride=1)def forward(self, x):x = F.relu(self.conv(x))cls_scores = self.cls_layer(x)bbox_preds = self.reg_layer(x)cls_scores = cls_scores.permute(0, 2, 3, 1).contiguous().view(x.size(0), -1, 2)bbox_preds = bbox_preds.permute(0, 2, 3, 1).contiguous().view(x.size(0), -1, 4)return cls_scores, bbox_preds# RoI 池化层
class RoIPooling(nn.Module):def __init__(self, output_size):super(RoIPooling, self).__init__()self.output_size = output_sizedef forward(self, features, rois):roi_features = []for i in range(features.size(0)):# 包含了感兴趣区域的信息roi = rois[i]# features = (batch_size, channels, height, width),池化到output_size的统一size大小roi_feature = torchvision.ops.roi_pool(features[i].unsqueeze(0), [roi], self.output_size)roi_features.append(roi_feature)roi_features = torch.cat(roi_features, dim=0)return roi_features# Faster R-CNN 模型
class FasterRCNN(nn.Module):def __init__(self, num_classes):super(FasterRCNN, self).__init__()self.backbone = ResNetBackbone()self.rpn = RPN(2048, 9)  # 假设使用 9 个锚点# 池化到 7*7self.roi_pooling = RoIPooling((7, 7))self.fc1 = nn.Linear(2048 * 7 * 7, 1024)self.fc2 = nn.Linear(1024, 1024)self.cls_layer = nn.Linear(1024, num_classes)self.reg_layer = nn.Linear(1024, num_classes * 4)def forward(self, x, rois=None):features = self.backbone(x)cls_scores, bbox_preds = self.rpn(features)if rois is not None:roi_features = self.roi_pooling(features, rois)roi_features = roi_features.view(roi_features.size(0), -1)fc1 = F.relu(self.fc1(roi_features))fc2 = F.relu(self.fc2(fc1))cls_preds = self.cls_layer(fc2)reg_preds = self.reg_layer(fc2)return cls_preds, reg_preds, cls_scores, bbox_predselse:return cls_scores, bbox_preds# 自定义数据集类
class CustomDataset(Dataset):def __init__(self, image_paths, target_paths, transform=None):self.image_paths = image_pathsself.target_paths = target_pathsself.transform = transformdef __len__(self):return len(self.image_paths)def __getitem__(self, idx):image = cv2.imread(self.image_paths[idx])image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)target = np.load(self.target_paths[idx], allow_pickle=True)if self.transform:image = self.transform(image)return image, target# 数据预处理
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])# 训练函数
def train(model, dataloader, optimizer, criterion_cls, criterion_reg):model.train()total_loss = 0for images, targets in dataloader:images = images.to(device)targets = [t.to(device) for t in targets]optimizer.zero_grad()cls_preds, reg_preds, cls_scores, bbox_preds = model(images, targets)# 计算分类和回归损失,这里假设 targets 包含真实类别和边界框信息cls_loss = criterion_cls(cls_preds, targets)reg_loss = criterion_reg(reg_preds, targets)loss = cls_loss + reg_lossloss.backward()optimizer.step()total_loss += loss.item()return total_loss / len(dataloader)# 评估函数
def evaluate(model, dataloader):model.eval()correct = 0total = 0with torch.no_grad():for images, targets in dataloader:images = images.to(device)targets = [t.to(device) for t in targets]cls_preds, reg_preds, _, _ = model(images)# 计算评估指标,这里可根据具体需求实现# 例如计算 mAP 等return correct / totalif __name__ == "__main__":# 假设的图像和标注文件路径image_paths = ['img/street.jpg', 'img/street.jpg']target_paths = ['target1.npy', 'target2.npy']dataset = CustomDataset(image_paths, target_paths, transform)dataloader = DataLoader(dataset, batch_size=2, shuffle=True)device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')num_classes = 2  # 包括背景类model = FasterRCNN(num_classes).to(device)optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)criterion_cls = nn.CrossEntropyLoss()criterion_reg = nn.SmoothL1Loss()num_epochs = 10for epoch in range(num_epochs):loss = train(model, dataloader, optimizer, criterion_cls, criterion_reg)print(f'Epoch {epoch + 1}/{num_epochs}, Loss: {loss}')# 评估accuracy = evaluate(model, dataloader)print(f'Accuracy: {accuracy}')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/67355.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【Qt】01-了解QT

踏入QT的殿堂之路 前言一、创建工程文件1.1 步骤介绍1.2 编译介绍方法1、方法2、编译成功 二、了解框架2.1 main.cpp2.2 .Pro文件2.2.1 注释需要打井号。2.2.2 F1带你进入帮助模式2.2.3 build文件 2.3 构造函数 三、编写工程3.1 main代码3.2 结果展示 四、指定父对象4.1 main代…

【深度学习】关键技术-激活函数(Activation Functions)

激活函数(Activation Functions) 激活函数是神经网络的重要组成部分,它的作用是将神经元的输入信号映射到输出信号,同时引入非线性特性,使神经网络能够处理复杂问题。以下是常见激活函数的种类、公式、图形特点及其应…

3.flask蓝图使用

构建一个目录结构 user_oper.py from flask import Blueprint, request, session, redirect, render_template import functools # 创建蓝图 user Blueprint(xkj, __name__)DATA_DICT {1: {"name": "张三", "age": 22, "gender": …

React第二十二章(useDebugValue)

useDebugValue useDebugValue 是一个专为开发者调试自定义 Hook 而设计的 React Hook。它允许你在 React 开发者工具中为自定义 Hook 添加自定义的调试值。 用法 const debugValue useDebugValue(value)参数说明 入参 value: 要在 React DevTools 中显示的值formatter?:…

【漏洞分析】DDOS攻防分析

0x00 UDP攻击实例 2013年12月30日,网游界发生了一起“追杀”事件。事件的主角是PhantmL0rd(这名字一看就是个玩家)和黑客组织DERP Trolling。 PhantomL0rd,人称“鬼王”,本名James Varga,某专业游戏小组的…

【 PID 算法 】PID 算法基础

一、简介 PID即:Proportional(比例)、Integral(积分)、Differential(微分)的缩写。也就是说,PID算法是结合这三种环节在一起的。粘一下百度百科中的东西吧。 顾名思义,…

PyTorch使用教程(1)—PyTorch简介

PyTorch是一个开源的深度学习框架,由Facebook人工智能研究院(FAIR)于2016年开发并发布,其主要特点包括自动微分功能和动态计算图的支持,使得模型建立更加灵活‌。官网网址:https://pytorch.org。以下是关于…

PyTorch框架——基于深度学习YOLOv5神经网络水果蔬菜检测识别系统

基于深度学习YOLOv5神经网络水果蔬菜检测识别系统,其能识别的水果蔬菜有15种,# 水果的种类 names: [黑葡萄, 绿葡萄, 樱桃, 西瓜, 龙眼, 香蕉, 芒果, 菠萝, 柚子, 草莓, 苹果, 柑橘, 火龙果, 梨子, 花生, 黄瓜, 土豆, 大蒜, 茄子, 白萝卜, 辣椒, 胡萝卜,…

Mac玩Steam游戏秘籍!

Mac玩Steam游戏秘籍! 大家好!最近有不少朋友在用MacBook玩Steam游戏时遇到不支持mac的问题。别担心,我来教你如何用第三方工具Crossover来畅玩这些不支持的游戏,简单又实用! 第一步:下载Crossover 首先&…

【网络篇】IP知识

IPv4首部与IPv6首部 IPv4相对于IPv6的好处: 1.IPv6可自动配置,即使没有DHCP服务器也可以实现自动分配IP地址,实现即插即用。 2.IPv6包首部长度采用固定40字节,删除了选项字段,以及首部校验和,简化了首部…

我的年度总结

这一年的人生起伏:从曙光到低谷再到新的曙光 其实本来没打算做年度总结的,无聊打开了帅帅的视频,结合自己最近经历的,打算简单聊下。因为原本打算做的内容会是一篇比较丧、低能量者的呻吟。 实习生与创业公司的零到一 第一段工…

Vue脚手架开发 Vue2基础 VueRouter的基本使用 vue-router路由案例

vue-router路由 Vue脚手架开发,创建项目:https://blog.csdn.net/c_s_d_n_2009/article/details/144973766 Vue Router,Vue Router | Vue.js 的官方路由,Vue.js 的官方路由,为 Vue.js 提供富有表现力、可配置的、方便…

Windows远程桌面网关出现重大漏洞

微软披露了其Windows远程桌面网关(RD Gateway)中的一个重大漏洞,该漏洞可能允许攻击者利用竞争条件,导致拒绝服务(DoS)攻击。该漏洞被标识为CVE-2025-21225,已在2025年1月的补丁星期二更新中得到…

c语言----------内存管理

内存管理 目录 一。作用域1.1 局部变量1.2 静态(static)局部变量1.3 全局变量1.4 静态(static)全局变量1.5 extern全局变量声明1.6 全局函数和静态函数1.7 总结 二。内存布局2.1 内存分区2.2 存储类型总结2.3内存操作函数1) memset()2) memcpy()3) memmove()4) memcmp() 2.4 堆…

【2024年华为OD机试】 (C卷,100分)- 堆栈中的剩余数字(Java JS PythonC/C++)

一、问题描述 题目描述 向一个空栈中依次存入正整数&#xff0c;假设入栈元素 n(1<n<2^31-1)按顺序依次为 nx…n4、 n3、n2、 n1, 每当元素入栈时&#xff0c;如果 n1n2…ny(y 的范围[2,x]&#xff0c; 1<x<1000)&#xff0c;则 n1~ny 全部元素出栈&#xff0c;重…

Java安全—SPEL表达式XXESSTI模板注入JDBCMyBatis注入

前言 之前我们讲过SpringBoot中的MyBatis注入和模板注入的原理&#xff0c;那么今天我们就讲一下利用以及发现。 这里推荐两个专门研究java漏洞的靶场&#xff0c;本次也是根据这两个靶场来分析代码&#xff0c;两个靶场都是差不多的。 https://github.com/bewhale/JavaSec …

51单片机入门基础

目录 一、基础知识储备 &#xff08;一&#xff09;了解51单片机的基本概念 &#xff08;二&#xff09;掌握数字电路基础 &#xff08;三&#xff09;学习C语言编程基础 二、开发环境搭建 &#xff08;一&#xff09;硬件准备 &#xff08;二&#xff09;软件准备 三、…

基于Java的百度AOI数据解析与转换的实现方法

目录 前言 一、AOI数据结构简介 1、官网的实例接口 2、响应参数介绍 二、Java对AOI数据的解析 1、数据解析流程图 2、数据解析实现 3、AOI数据解析成果 三、总结 前言 在当今信息化社会&#xff0c;地理信息数据在城市规划、交通管理、商业选址等领域扮演着越来越重要的…

【WEB】网络传输中的信息安全 - 加密、签名、数字证书与HTTPS

文章目录 1. 概述2. 网络传输安全2.1.什么是中间人攻击2.2. 加密和签名2.2.1.加密算法2.2.2.摘要2.2.3.签名 2.3.数字证书2.3.1.证书的使用2.3.2.根证书2.3.3.证书链 2.4.HTTPS 1. 概述 本篇主要是讲解讲一些安全相关的基本知识&#xff08;如加密、签名、证书等&#xff09;&…

shell练习2

需求&#xff1a;判断192.168.1.0/24网络中&#xff0c;当前在线的ip有哪些&#xff0c;并编写脚本打印出来。 #!/bin/bashnmap -sn 192.168.1.0/24 | grep Nmap scan report for | awk {print $5} 注意&#xff1a;当运行 bash ip.sh 时出现 nmap: command not found 的错误…