【PyTorch】基于YOLO的多目标检测项目(一)

【PyTorch】基于YOLO的多目标检测项目(一)

【PyTorch】基于YOLO的多目标检测项目(二)

目标检测是对图像中的现有目标进行定位和分类的过程。识别的对象在图像中显示有边界框。一般的目标检测方法有两种:基于区域提议的和基于回归/分类的。这里使用一种基于回归/分类的方法,称为YOLO。

目录

准备COCO数据集

创建自定义数据集

转换数据

定义数据加载器


准备COCO数据集

COCO是一个大规模的对象检测,分割和字幕数据集。它包含80个对象类别用于对象检测。

下载以下GitHub存储库

https://github.com/pjreddie/darkneticon-default.png?t=N7T8https://github.com/pjreddie/darknet

创建一个名为config的文件夹,将darknet/cfg/coco.data、darknet/cfg/yolov3.cfg文件复制到config文件夹中。

创建一个名为data的文件夹,从以下链接获取coco.names文件,并将其放入data文件夹,coco.names文件包含COCO数据集中80个对象类别的列表。

darknet/data/coco.names at master · pjreddie/darknet · GitHubConvolutional Neural Networks. Contribute to pjreddie/darknet development by creating an account on GitHub.icon-default.png?t=N7T8https://github.com/pjreddie/darknet/blob/master/data/coco.names将darknet/scripts/get_coco_dataset.sh文件复制到data文件夹中,并复制get_coco_cocoet.sh到data文件夹。接下来,打开一个终端并执行get_coco_cocoet.sh,该脚本将把完整的COCO数据集下载到名为coco的子文件夹中。也可通过以下链接下载coco数据集。

COCO2014_数据集-飞桨AI Studio星河社区 (baidu.com)icon-default.png?t=N7T8https://aistudio.baidu.com/datasetdetail/165195

在images文件夹中,有两个名为train 2014和val 2014的文件夹,分别包含82783和40504个图像。在labels文件夹中,有两个名为train 2014和val 2014的标签,分别包含82081和40137文本文件。这些文本文件包含图像中对象的边界框坐标。此外,trainvalno5k.txt文件是一个包含117264张图像的列表,这些图像将用于训练模型。此列表是train2014和val2014中图像的组合,5000个图像除外。5k.txt文件包含将用于验证的5000个图像的列表。

创建自定义数据集

完成数据集下载后,使用PyTorch的Dataset和Dataloader类创建训练和验证数据集和数据加载器。

from torch.utils.data import Dataset
from PIL import Image
import torchvision.transforms.functional as TF
import os
import numpy as npimport torch
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(torch.__version__)
#定义CocoDataset类,并展示来自训练和验证数据集的一些示例图像
class CocoDataset(Dataset):def __init__(self, path2listFile, transform=None, trans_params=None):with open(path2listFile, "r") as file:self.path2imgs = file.readlines()self.path2labels = [path.replace("images", "labels").replace(".png", ".txt").replace(".jpg", ".txt")for path in self.path2imgs]self.trans_params = trans_paramsself.transform = transformdef __len__(self):return len(self.path2imgs)def __getitem__(self, index):path2img = self.path2imgs[index % len(self.path2imgs)].rstrip()img = Image.open(path2img).convert('RGB')path2label = self.path2labels[index % len(self.path2imgs)].rstrip()labels= Noneif os.path.exists(path2label):labels = np.loadtxt(path2label).reshape(-1, 5)if self.transform:img, labels = self.transform(img, labels, self.trans_params)return img, labels, path2img    
root_data="./data/coco"
path2trainList=os.path.join(root_data, "trainvalno5k.txt")coco_train = CocoDataset(path2trainList)
print(len(coco_train))

 

# 从coco_train中获取图像、标签和图像路径
img, labels, path2img = coco_train[1] 
print("image size:", img.size, type(img))
print("labels shape:", labels.shape, type(labels))
print("labels \n", labels)

path2valList=os.path.join(root_data, "5k.txt")
coco_val = CocoDataset(path2valList, transform=None, trans_params=None)
print(len(coco_val))

img, labels, path2img = coco_val[7] 
print("image size:", img.size, type(img))
print("labels shape:", labels.shape, type(labels))
print("labels \n", labels)

import matplotlib.pylab as plt
import numpy as np
from PIL import Image, ImageDraw, ImageFont
from torchvision.transforms.functional import to_pil_image
import random
%matplotlib inline
path2cocoNames="./data/coco.names"
fp = open(path2cocoNames, "r")
coco_names = fp.read().split("\n")[:-1]
print("number of classese:", len(coco_names))
print(coco_names)

def rescale_bbox(bb,W,H):x,y,w,h=bbreturn [x*W, y*H, w*W, h*H]
COLORS = np.random.randint(0, 255, size=(80, 3),dtype="uint8")
# fnt = ImageFont.truetype('Pillow/Tests/fonts/FreeMono.ttf', 16)
fnt = ImageFont.truetype('arial.ttf', 16)
def show_img_bbox(img,targets):if torch.is_tensor(img):img=to_pil_image(img)if torch.is_tensor(targets):targets=targets.numpy()[:,1:]W, H=img.sizedraw = ImageDraw.Draw(img)for tg in targets:id_=int(tg[0])bbox=tg[1:]bbox=rescale_bbox(bbox,W,H)xc,yc,w,h=bboxcolor = [int(c) for c in COLORS[id_]]name=coco_names[id_]draw.rectangle(((xc-w/2, yc-h/2), (xc+w/2, yc+h/2)),outline=tuple(color),width=3)draw.text((xc-w/2,yc-h/2),name, font=fnt, fill=(255,255,255,0))plt.imshow(np.array(img))        
np.random.seed(1)
rnd_ind=np.random.randint(len(coco_train))
img, labels, path2img = coco_train[rnd_ind] 
print(img.size, labels.shape)plt.rcParams['figure.figsize'] = (20, 10)
show_img_bbox(img,labels)

np.random.seed(1)
rnd_ind=np.random.randint(len(coco_val))
img, labels, path2img = coco_val[rnd_ind] 
print(img.size, labels.shape)plt.rcParams['figure.figsize'] = (20, 10)
show_img_bbox(img,labels)

转换数据

定义一个转换函数和传递给CocoDataset类的参数

def pad_to_square(img, boxes, pad_value=0, normalized_labels=True):w, h = img.sizew_factor, h_factor = (w,h) if normalized_labels else (1, 1)dim_diff = np.abs(h - w)pad1= dim_diff // 2pad2= dim_diff - pad1if h<=w:left, top, right, bottom= 0, pad1, 0, pad2else:left, top, right, bottom= pad1, 0, pad2, 0padding= (left, top, right, bottom)img_padded = TF.pad(img, padding=padding, fill=pad_value)w_padded, h_padded = img_padded.sizex1 = w_factor * (boxes[:, 1] - boxes[:, 3] / 2)y1 = h_factor * (boxes[:, 2] - boxes[:, 4] / 2)x2 = w_factor * (boxes[:, 1] + boxes[:, 3] / 2)y2 = h_factor * (boxes[:, 2] + boxes[:, 4] / 2)    x1 += padding[0] # 左y1 += padding[1] # 上x2 += padding[2] # 右y2 += padding[3] # 下boxes[:, 1] = ((x1 + x2) / 2) / w_paddedboxes[:, 2] = ((y1 + y2) / 2) / h_paddedboxes[:, 3] *= w_factor / w_paddedboxes[:, 4] *= h_factor / h_paddedreturn img_padded, boxes    
def hflip(image, labels):image = TF.hflip(image)labels[:, 1] = 1.0 - labels[:, 1]return image, labelsdef transformer(image, labels, params):if params["pad2square"] is True:image,labels= pad_to_square(image, labels)image = TF.resize(image,params["target_size"])if random.random() < params["p_hflip"]:image,labels=hflip(image,labels)image=TF.to_tensor(image)targets = torch.zeros((len(labels), 6))targets[:, 1:] = torch.from_numpy(labels)return image, targets
trans_params_train={"target_size" : (416, 416),"pad2square": True,"p_hflip" : 1.0,"normalized_labels": True,
}
coco_train=CocoDataset(path2trainList,transform=transformer,trans_params=trans_params_train)np.random.seed(100)
rnd_ind=np.random.randint(len(coco_train))
img, targets, path2img = coco_train[rnd_ind] 
print("image shape:", img.shape)
print("labels shape:", targets.shape) plt.rcParams['figure.figsize'] = (20, 10)
COLORS = np.random.randint(0, 255, size=(80, 3),dtype="uint8")
show_img_bbox(img,targets)

通过传递 transformer 函数来定义 CocoDataset 的一个对象来验证数据 

trans_params_val={"target_size" : (416, 416),"pad2square": True,"p_hflip" : 0.0,"normalized_labels": True,
}
coco_val= CocoDataset(path2valList,transform=transformer,trans_params=trans_params_val)np.random.seed(55)
rnd_ind=np.random.randint(len(coco_val))
img, targets, path2img = coco_val[rnd_ind] 
print("image shape:", img.shape)
print("labels shape:", targets.shape) plt.rcParams['figure.figsize'] = (20, 10)
COLORS = np.random.randint(0, 255, size=(80, 3),dtype="uint8")
show_img_bbox(img,targets)

 

定义数据加载器

定义两个用于训练和验证数据集的数据加载器,从coco_train和coco_val中获取小批量数据。

from torch.utils.data import DataLoaderbatch_size=8
def collate_fn(batch):imgs, targets, paths = list(zip(*batch))targets = [boxes for boxes in targets if boxes is not None]for b_i, boxes in enumerate(targets):boxes[:, 0] = b_itargets = torch.cat(targets, 0)imgs = torch.stack([img for img in imgs])return imgs, targets, pathstrain_dl = DataLoader(coco_train,batch_size=batch_size,shuffle=True,num_workers=0,pin_memory=True,collate_fn=collate_fn,)torch.manual_seed(0)
for imgs_batch,tg_batch,path_batch in train_dl:break
print(imgs_batch.shape)
print(tg_batch.shape,tg_batch.dtype)

 

val_dl = DataLoader(coco_val,batch_size=batch_size,shuffle=False,num_workers=0,pin_memory=True,collate_fn=collate_fn,)torch.manual_seed(0)
for imgs_batch,tg_batch,path_batch in val_dl:break
print(imgs_batch.shape)
print(tg_batch.shape,tg_batch.dtype)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/50253.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

如何找到最快解析速度的DNS

如何找到最快解析速度的DNS DNS&#xff0c;即域名系统&#xff08;Domain Name System&#xff09;&#xff0c;是互联网的一项服务。它作为将域名和IP地址相互映射的一个分布式数据库&#xff0c;能够使用户更方便地访问互联网&#xff0c;而不用记住能够被机器直接读取的IP数…

6.乳腺癌良性恶性预测(二分类、逻辑回归、PCA降维、SVD奇异值分解)

乳腺癌良性恶性预测 1. 特征工程1.1 特征筛选1.2 特征降维 PCA1.3 SVD奇异值分解 2. 代码2.1 逻辑回归、二分类问题2.2 特征降维 PCA2.3 SVD奇异值分解 1. 特征工程 专业上&#xff1a;30个人特征来自于临床一线专家&#xff0c;每个特征和都有医学内涵&#xff1b;数据上&…

7月25日JavaSE学习笔记

线程的生命周期中&#xff0c;等待是主动的&#xff0c;阻塞是被动的 锁对象 创建锁对象&#xff0c;锁对象同一时间只允许一个线程进入 //创建锁对象Lock locknew ReentrantLock(true);//创建可重入锁 可重入锁&#xff1a;在嵌套代码块中&#xff0c;锁对象一样就可以直接…

进销存系统开发,进销存源码解析,添加商品选择商品

点击添加商品信息&#xff08;可以&#xff09; (关键字范围&#xff1a;商品名称&#xff0c;简拼&#xff0c;条形码&#xff0c;SKU,规格&#xff0c;参数&#xff0c;尺寸&#xff0c;接口&#xff0c;CPU,品牌) function cwpd_selSaleGoodsNewMore_Vtax2024(domid,width…

sed利用脚本处理文件

一、sed是什么 sed 命令是利用脚本来处理文本文件。它可以依照脚本的指令来处理、编辑文本文件。主要用来自动编 辑一个或多个文件、简化对文件的反复操作、编写转换程序等。 二、sed的原理 读入新的一行内容到缓存空间&#xff1b; 从指定的操作指令中取出第一条指令&…

【时时三省】(C语言基础)分支语句2

山不在高&#xff0c;有仙则名。水不在深&#xff0c;有龙则灵。 ——csdn时时三省 多分支语句 if&#xff08;表达式1&#xff09; 语句1; else if&#xff08;表达式2&#xff09; 语句2; else 语句3; 如果表达式1成立语句1会执行 如果不成立表达式2执行 如果表达式2成…

【运维笔记】数据库无法启动,数据库炸后备份恢复数据

事情起因 在做docker作业的时候&#xff0c;把卷映射到了宿主机原来的mysql数据库目录上&#xff0c;宿主机原来的mysql版本为8.0&#xff0c;docker容器版本为5.6&#xff0c;导致翻车。 具体操作 备份目录 将/var/lib/mysql备份到~/mysql_backup&#xff1a;cp /var/lib/…

Multiview LM-ICP 配准算法

Multiview LM-ICP 配准算法针对一些大型的物体&#xff08;比如建筑物&#xff09;或者需要精细化建模的物体&#xff08;比如某个文物&#xff09;&#xff0c;仅仅进行成对的配准难以还原物体的全貌和细节。所以&#xff0c;多个视角的配准十分关键。 多视角的配准存在以下两…

[STM32]FlyMcu同时烧写BootLoader和APP文件-HEX文件组成

目录 一、前言 二、HEX文件的格式 三、组合HEX文件 四、使用FlyMcu烧录 一、前言 如题&#xff0c;BootLoader每次烧写都是全部擦除&#xff0c;当我们烧写APP程序的时候&#xff0c;BootLoader程序将不复存在&#xff0c;很多开发者或许只有USB转TTL模块&#xff0c;没有其…

grep命令搜索部分命令

首先 然后可以输入&#xff5c;以及grep命令 比如 bjobs| grep "3075*"bjobs| grep "3075"这个结果是这样的&#xff0c;

MYSQL 第四次作业

任务要求&#xff1a; 具体操作&#xff1a; 新建数据库&#xff1a; mysql> CREATE DATABASE mydb15_indexstu; Query OK, 1 row affected (0.01 sec) mysql> USE mydb15_indexstu; Database changed 新建表&#xff1a; mysql> CREATE TABLE student( ->…

遇到总条数count(*)返回不了数据

文章目录 前提1.准备数据1.1 建表语句1.2 插入数据 2.程序代码3.返回结果与分析4.验证 前提 获取h_user表中count(*)字段的值打印出来&#xff0c;打印出来是0&#xff0c;数据库中执行sql返回不是0。端点调试找到原因。下面先把数据库表数据及程序贴出来。 1.准备数据 1.1 …

CSS技巧专栏:一日一例 12 -纯CSS实现边框上下交错的按钮特效

CSS技巧专栏&#xff1a;一日一例 12 -纯CSS实现边框上下交错的按钮特效 大家好&#xff0c;今天我们来做一个上下边框交错闪动的按钮特效。 本例图片 案例分析 虽说这按钮给人的感觉就是上下两个边框交错变换了位置&#xff0c;但我们都知道border是没法移动的。那么这个按…

【无标KaiwuDB CTO 魏可伟:差异化创新,面向行业的多模架构题】

2024年7月16日&#xff0c;KaiwuDB CTO 魏可伟受邀于 2024 可信数据库发展大会主论坛发表演讲《多模一库 —— KaiwuDB 的现代数据库架构探索》&#xff0c;以下是演讲精华实录。 多模数据库 是顺应时代发展与融合趋势的产物 数据模型最早始于网状模型和层次模型&#xff0c;…

Spark实时(五):InputSource数据源案例演示

文章目录 InputSource数据源案例演示 一、​​​​​​​File Source 1、读取text文件 2、读取csv文件 3、读取json文件 二、Socket Source 三、Rate Source InputSource数据源案例演示 在Spark2.0版本之后&#xff0c;DataFrame和Dataset可以表示静态有边界的数据&am…

移动式气象站:便携科技的天气守望者

在科技日新月异的今天&#xff0c;我们身边的许多设备都在向着更加智能化、便携化的方向发展。而在气象观测领域&#xff0c;移动式气象站的出现&#xff0c;不仅改变了传统气象观测的固有模式&#xff0c;更以其灵活性和实时性&#xff0c;在气象监测、灾害预警等领域发挥着越…

MySQL练习05

题目 步骤 触发器 use mydb16_trigger; #使用数据库create table goods( gid char(8) primary key, name varchar(10), price decimal(8,2), num int);create table orders( oid int primary key auto_increment, gid char(10) not null, name varchar(10), price decima…

基于python的BP神经网络红酒品质分类预测模型

1 导入必要的库 import pandas as pd import numpy as np import matplotlib.pyplot as plt from sklearn.model_selection import train_test_split from sklearn.preprocessing import LabelEncoder from tensorflow.keras.models import Sequential from tenso…

NET8部署Kestrel服务HTTPS深入解读TLS协议之Certificate证书

Certificate证书 Certificate称为数字证书。数字证书是一种证明身份的电子凭证&#xff0c;它包含一个公钥和一些身份信息&#xff0c;用于验证数字签名和加密通信。数字证书在网络通信、电子签名、认证授权等场景中都有广泛应用。其特征如下&#xff1a; 由权威机构颁发&…

跟李沐学AI:池化层

目录 二维最大池化 填充、步幅和多个通道 平均池化层 池化层总结 二维最大池化 返回滑动窗口中的最大值。 图为池化窗口形状为 22 的最大池化层。着色部分是第一个输出元素&#xff0c;以及用于计算这个输出的输入元素: max(0,1,3,4)4。池化层与卷积层类似&#xff0c;不断…