第97步 深度学习图像目标检测:RetinaNet建模

基于WIN10的64位系统演示

一、写在前面

本期开始,我们继续学习深度学习图像目标检测系列,RetinaNet模型。

二、RetinaNet简介

RetinaNet 是由 Facebook AI Research (FAIR) 的研究人员在 2017 年提出的一种目标检测模型。它是一种单阶段(one-stage)的目标检测方法,但通过引入一个名为 Focal Loss 的创新损失函数,RetinaNet 解决了单阶段检测器常面临的正负样本不平衡问题。以下是 RetinaNet 的主要特点:

(1)Focal Loss:

传统的交叉熵损失往往由于背景类(负样本)的数量远大于目标类(正样本)而导致训练不稳定。为了解决这一不平衡问题,RetinaNet 引入了 Focal Loss。Focal Loss 被设计为更重视那些难以分类的负样本,而减少对容易分类的背景类的关注。这有助于提高模型对目标的检测精度。

(2)特征金字塔网络 (FPN):

RetinaNet 使用了特征金字塔网络 (FPN) 作为其骨干网络,这是一个为多尺度目标检测设计的卷积网络结构。FPN 可以从单张图像中提取多尺度的特征,使得模型能够有效地检测不同大小的物体。

(3)预定义锚框:

与其他一阶段检测器相似,RetinaNet 在其特征图上使用预定义的锚框来预测目标的位置和类别。

三、数据源

来源于公共数据,文件设置如下:

大概的任务就是:用一个框框标记出MTB的位置。

四、RetinaNet实战

直接上代码:

import os
import random
import torch
import torchvision
from torchvision.models.detection import retinanet_resnet50_fpn 
from torchvision.transforms import functional as F
from PIL import Image
from torch.utils.data import DataLoader
import xml.etree.ElementTree as ET
import matplotlib.pyplot as plt
from torchvision import transforms
import albumentations as A
from albumentations.pytorch import ToTensorV2
import numpy as np# Function to parse XML annotations
def parse_xml(xml_path):tree = ET.parse(xml_path)root = tree.getroot()boxes = []for obj in root.findall("object"):bndbox = obj.find("bndbox")xmin = int(bndbox.find("xmin").text)ymin = int(bndbox.find("ymin").text)xmax = int(bndbox.find("xmax").text)ymax = int(bndbox.find("ymax").text)# Check if the bounding box is validif xmin < xmax and ymin < ymax:boxes.append((xmin, ymin, xmax, ymax))else:print(f"Warning: Ignored invalid box in {xml_path} - ({xmin}, {ymin}, {xmax}, {ymax})")return boxes# Function to split data into training and validation sets
def split_data(image_dir, split_ratio=0.8):all_images = [f for f in os.listdir(image_dir) if f.endswith(".jpg")]random.shuffle(all_images)split_idx = int(len(all_images) * split_ratio)train_images = all_images[:split_idx]val_images = all_images[split_idx:]return train_images, val_images# Dataset class for the Tuberculosis dataset
class TuberculosisDataset(torch.utils.data.Dataset):def __init__(self, image_dir, annotation_dir, image_list, transform=None):self.image_dir = image_dirself.annotation_dir = annotation_dirself.image_list = image_listself.transform = transformdef __len__(self):return len(self.image_list)def __getitem__(self, idx):image_path = os.path.join(self.image_dir, self.image_list[idx])image = Image.open(image_path).convert("RGB")xml_path = os.path.join(self.annotation_dir, self.image_list[idx].replace(".jpg", ".xml"))boxes = parse_xml(xml_path)# Check for empty bounding boxes and return Noneif len(boxes) == 0:return Noneboxes = torch.as_tensor(boxes, dtype=torch.float32)labels = torch.ones((len(boxes),), dtype=torch.int64)iscrowd = torch.zeros((len(boxes),), dtype=torch.int64)target = {}target["boxes"] = boxestarget["labels"] = labelstarget["image_id"] = torch.tensor([idx])target["iscrowd"] = iscrowd# Apply transformationsif self.transform:image = self.transform(image)return image, target# Define the transformations using torchvision
data_transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor(),  # Convert PIL image to tensortorchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # Normalize the images
])# Adjusting the DataLoader collate function to handle None values
def collate_fn(batch):batch = list(filter(lambda x: x is not None, batch))return tuple(zip(*batch))def get_retinanet_model_for_finetuning(num_classes):# Load a RetinaNet model with a ResNet-50-FPN backbone without pre-trained weightsmodel = retinanet_resnet50_fpn(pretrained=False, num_classes=num_classes)return model# Function to save the model
def save_model(model, path="RetinaNet_mtb.pth", save_full_model=False):if save_full_model:torch.save(model, path)else:torch.save(model.state_dict(), path)print(f"Model saved to {path}")# Function to compute Intersection over Union
def compute_iou(boxA, boxB):xA = max(boxA[0], boxB[0])yA = max(boxA[1], boxB[1])xB = min(boxA[2], boxB[2])yB = min(boxA[3], boxB[3])interArea = max(0, xB - xA + 1) * max(0, yB - yA + 1)boxAArea = (boxA[2] - boxA[0] + 1) * (boxA[3] - boxA[1] + 1)boxBArea = (boxB[2] - boxB[0] + 1) * (boxB[3] - boxB[1] + 1)iou = interArea / float(boxAArea + boxBArea - interArea)return iou# Adjusting the DataLoader collate function to handle None values and entirely empty batches
def collate_fn(batch):batch = list(filter(lambda x: x is not None, batch))if len(batch) == 0:# Return placeholder batch if entirely emptyreturn [torch.zeros(1, 3, 224, 224)], [{}]return tuple(zip(*batch))#Training function with modifications for collecting IoU and loss
def train_model(model, train_loader, optimizer, device, num_epochs=10):model.train()model.to(device)loss_values = []iou_values = []for epoch in range(num_epochs):epoch_loss = 0.0total_ious = 0num_boxes = 0for images, targets in train_loader:# Skip batches with placeholder dataif len(targets) == 1 and not targets[0]:continue# Skip batches with empty targetsif any(len(target["boxes"]) == 0 for target in targets):continueimages = [image.to(device) for image in images]targets = [{k: v.to(device) for k, v in t.items()} for t in targets]loss_dict = model(images, targets)losses = sum(loss for loss in loss_dict.values())optimizer.zero_grad()losses.backward()optimizer.step()epoch_loss += losses.item()# Compute IoU for evaluationwith torch.no_grad():model.eval()predictions = model(images)for i, prediction in enumerate(predictions):pred_boxes = prediction["boxes"].cpu().numpy()true_boxes = targets[i]["boxes"].cpu().numpy()for pred_box in pred_boxes:for true_box in true_boxes:iou = compute_iou(pred_box, true_box)total_ious += iounum_boxes += 1model.train()avg_loss = epoch_loss / len(train_loader)avg_iou = total_ious / num_boxes if num_boxes != 0 else 0loss_values.append(avg_loss)iou_values.append(avg_iou)print(f"Epoch {epoch+1}/{num_epochs} Loss: {avg_loss} Avg IoU: {avg_iou}")# Plotting loss and IoU valuesplt.figure(figsize=(12, 5))plt.subplot(1, 2, 1)plt.plot(loss_values, label="Training Loss")plt.title("Training Loss across Epochs")plt.xlabel("Epochs")plt.ylabel("Loss")plt.subplot(1, 2, 2)plt.plot(iou_values, label="IoU")plt.title("IoU across Epochs")plt.xlabel("Epochs")plt.ylabel("IoU")plt.show()# Save model after trainingsave_model(model)# Validation function
def validate_model(model, val_loader, device):model.eval()model.to(device)with torch.no_grad():for images, targets in val_loader:images = [image.to(device) for image in images]targets = [{k: v.to(device) for k, v in t.items()} for t in targets]model(images)# Paths to your data
image_dir = "tuberculosis-phonecamera"
annotation_dir = "tuberculosis-phonecamera"# Split data
train_images, val_images = split_data(image_dir)# Create datasets and dataloaders
train_dataset = TuberculosisDataset(image_dir, annotation_dir, train_images, transform=data_transform)
val_dataset = TuberculosisDataset(image_dir, annotation_dir, val_images, transform=data_transform)# Updated DataLoader with new collate function
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False, collate_fn=collate_fn)# Model and optimizer
model = get_retinanet_model_for_finetuning(2)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)# Train and validate
train_model(model, train_loader, optimizer, device="cuda", num_epochs=10)
validate_model(model, val_loader, device="cuda")#######################################Print Metrics######################################
def calculate_metrics(predictions, ground_truths, iou_threshold=0.5):TP = 0  # True PositivesFP = 0  # False PositivesFN = 0  # False Negativestotal_iou = 0  # to calculate mean IoUfor pred, gt in zip(predictions, ground_truths):pred_boxes = pred["boxes"].cpu().numpy()gt_boxes = gt["boxes"].cpu().numpy()# Match predicted boxes to ground truth boxesfor pred_box in pred_boxes:max_iou = 0matched = Falsefor gt_box in gt_boxes:iou = compute_iou(pred_box, gt_box)if iou > max_iou:max_iou = iouif iou > iou_threshold:matched = Truetotal_iou += max_iouif matched:TP += 1else:FP += 1FN += len(gt_boxes) - TPprecision = TP / (TP + FP) if (TP + FP) != 0 else 0recall = TP / (TP + FN) if (TP + FN) != 0 else 0f1_score = (2 * precision * recall) / (precision + recall) if (precision + recall) != 0 else 0mean_iou = total_iou / (TP + FP) if (TP + FP) != 0 else 0return precision, recall, f1_score, mean_ioudef evaluate_model(model, dataloader, device):model.eval()model.to(device)all_predictions = []all_ground_truths = []with torch.no_grad():for images, targets in dataloader:images = [image.to(device) for image in images]predictions = model(images)all_predictions.extend(predictions)all_ground_truths.extend(targets)precision, recall, f1_score, mean_iou = calculate_metrics(all_predictions, all_ground_truths)return precision, recall, f1_score, mean_ioutrain_precision, train_recall, train_f1, train_iou = evaluate_model(model, train_loader, "cuda")
val_precision, val_recall, val_f1, val_iou = evaluate_model(model, val_loader, "cuda")print("Training Set Metrics:")
print(f"Precision: {train_precision:.4f}, Recall: {train_recall:.4f}, F1 Score: {train_f1:.4f}, Mean IoU: {train_iou:.4f}")print("\nValidation Set Metrics:")
print(f"Precision: {val_precision:.4f}, Recall: {val_recall:.4f}, F1 Score: {val_f1:.4f}, Mean IoU: {val_iou:.4f}")#sheet
header = "| Metric    | Training Set | Validation Set |"
divider = "+----------+--------------+----------------+"train_metrics = f"| Precision | {train_precision:.4f}      | {val_precision:.4f}          |"
recall_metrics = f"| Recall    | {train_recall:.4f}      | {val_recall:.4f}          |"
f1_metrics = f"| F1 Score  | {train_f1:.4f}      | {val_f1:.4f}          |"
iou_metrics = f"| Mean IoU  | {train_iou:.4f}      | {val_iou:.4f}          |"print(header)
print(divider)
print(train_metrics)
print(recall_metrics)
print(f1_metrics)
print(iou_metrics)
print(divider)#######################################Train Set######################################
import numpy as np
import matplotlib.pyplot as pltdef plot_predictions_on_image(model, dataset, device, title):# Select a random image from the datasetidx = np.random.randint(50, len(dataset))image, target = dataset[idx]img_tensor = image.clone().detach().to(device).unsqueeze(0)# Use the model to make predictionsmodel.eval()with torch.no_grad():prediction = model(img_tensor)# Inverse normalization for visualizationinv_normalize = transforms.Normalize(mean=[-0.485/0.229, -0.456/0.224, -0.406/0.225],std=[1/0.229, 1/0.224, 1/0.225])image = inv_normalize(image)image = torch.clamp(image, 0, 1)image = F.to_pil_image(image)# Plot the image with ground truth boxesplt.figure(figsize=(10, 6))plt.title(title + " with Ground Truth Boxes")plt.imshow(image)ax = plt.gca()# Draw the ground truth boxes in bluefor box in target["boxes"]:rect = plt.Rectangle((box[0], box[1]), box[2]-box[0], box[3]-box[1],fill=False, color='blue', linewidth=2)ax.add_patch(rect)plt.show()# Plot the image with predicted boxesplt.figure(figsize=(10, 6))plt.title(title + " with Predicted Boxes")plt.imshow(image)ax = plt.gca()# Draw the predicted boxes in redfor box in prediction[0]["boxes"].cpu():rect = plt.Rectangle((box[0], box[1]), box[2]-box[0], box[3]-box[1],fill=False, color='red', linewidth=2)ax.add_patch(rect)plt.show()# Call the function for a random image from the train dataset
plot_predictions_on_image(model, train_dataset, "cuda", "Selected from Training Set")#######################################Val Set####################################### Call the function for a random image from the validation dataset
plot_predictions_on_image(model, val_dataset, "cuda", "Selected from Validation Set")

这回也是需要从头训练的,就不跑了。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/168286.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

RabbitMQ 安装(在docker容器中安装)

为什么要用&#xff1f; RabbitMQ是一个开源的消息代理和队列服务器&#xff0c;主要用于在不同的应用程序之间传递消息。它实现了高级消息队列协议&#xff08;AMQP&#xff09;&#xff0c;并提供了一种异步协作机制&#xff0c;以帮助提高系统的性能和扩展性。 RabbitMQ的作…

​LeetCode解法汇总2304. 网格中的最小路径代价

目录链接&#xff1a; 力扣编程题-解法汇总_分享记录-CSDN博客 GitHub同步刷题项目&#xff1a; https://github.com/September26/java-algorithms 原题链接&#xff1a;力扣&#xff08;LeetCode&#xff09;官网 - 全球极客挚爱的技术成长平台 描述&#xff1a; 给你一个下…

Flink实战(11)-Exactly-Once语义之两阶段提交

0 大纲 [Apache Flink]2017年12月发布的1.4.0版本开始&#xff0c;为流计算引入里程碑特性&#xff1a;TwoPhaseCommitSinkFunction。它提取了两阶段提交协议的通用逻辑&#xff0c;使得通过Flink来构建端到端的Exactly-Once程序成为可能。同时支持&#xff1a; 数据源&#…

【Redis】前言--介绍redis的全局系统观

一.前言 学习是要形成自己的网状知识以及知识架构图&#xff0c;要不最终都还是碎片化的知识&#xff0c;不能达到提升的目的&#xff0c;只有掌握了全貌的知识才是全解&#xff0c;要不只是一知半解。这章会介绍redis的系统架构图&#xff0c;帮助认识redis的设计是什么样的&a…

解决几乎任何机器学习问题 -- 学习笔记(组织机器学习项目)

书籍名&#xff1a;Approaching (Almost) Any Machine Learning Problem-解决几乎任何机器学习问题 此专栏记录学习过程&#xff0c;内容包含对这本书的翻译和理解过程 我们首先来看看文件的结构。对于你正在做的任何项目,都要创建一个新文件夹。在本例中,我 将项目命名为 “p…

笔记:内网渗透流程之信息收集

信息收集 首先&#xff0c;收集目标内网的信息&#xff0c;包括子网结构、域名信息、IP地址范围、开放的端口和服务等。这包括通过主动扫描和渗透测试工具收集信息&#xff0c;以及利用公开的信息源进行信息搜集。 本机信息收集 查看系统配置信息 查看系统详细信息&#xf…

电子桌牌如何赋能数字化会务?以深圳程序员节为例。

10月24日&#xff0c;由深圳市人民政府指导&#xff0c;深圳市工业和信息化局、龙华区人民政府、国家工业信息安全发展研究中心、中国软件行业协会联合主办的2023深圳中国1024程序员节开幕式暨主论坛活动在深圳龙华区启幕。以“领航鹏城发展&#xff0c;码动程序世界”为主题&a…

模拟退火算法应用——求解函数的最小值

仅作自己学习使用 一、问题 需求&#xff1a; 计算函数 的极小值&#xff0c;其中个体x的维数n10&#xff0c;即x(x1,x2,…,x10)&#xff0c;其中每一个分量xi均需在[-20,20]内。因此可以知道&#xff0c;这个函数只有一个极小值点x (0,0,…,0)&#xff0c;且其极小值是0&…

医保线上购药系统:引领医疗新潮流

在科技的驱动下&#xff0c;医疗健康服务正经历一场数字化的革新。医保线上购药系统&#xff0c;不仅是一种医疗服务的新选择&#xff0c;更是技术代码为我们的健康管理带来的全新可能。本文将通过一些简单的技术代码示例&#xff0c;深入解析医保线上购药系统的工作原理和优势…

MySQL数据库主从集群搭建

快捷查看指令 ctrlf 进行搜索会直接定位到需要的知识点和命令讲解&#xff08;如有不正确的地方欢迎各位小伙伴在评论区提意见&#xff0c;博主会及时修改&#xff09; MySQL数据库主从集群搭建 主从复制&#xff0c;是用来建立一个和主数据库完全一样的数据库环境&#xff0c…

短视频获客系统成功分享,与其开发流程与涉及到的技术

先来看实操成果&#xff0c;↑↑需要的同学可看我名字↖↖↖↖↖&#xff0c;或评论888无偿分享 一、短视频获客系统的开发流程 1. 需求分析&#xff1a;首先需要对目标用户进行深入了解&#xff0c;明确系统的功能和目标&#xff0c;制定详细的需求文档。 2. 系统设计&#…

关于vs code Debug调试时候出现“找不到任务C/C++: g++.exe build active file” 解决方法

vs code Debug调试时候出现“找不到任务C/C: g.exe build active file” &#xff0c;出现报错&#xff0c;Debug失败 后来经过摸索和上网查找资料解决问题 方法如下 在Vs code的操作页面左侧有几个配置文件 红框里的是需要将要修改的文件 查看tasks.json和launch.json框选&…

Android Frameworks 开发总结之七

1.修改android 系统/system/下面文件时权限不够问题 下面提到的方式目前在Bobcat的userdebug image上测试可行&#xff0c;还没有在user上测试过. 修改前: leifleif:~$ adb root restarting adbd as root leifleif:~$ adb disable-verity verity is already disabled using …

Find My鼠标|苹果Find My技术与鼠标结合,智能防丢,全球定位

随着折叠屏、多屏幕、OLED 等新兴技术在个人计算机上的应用&#xff0c;产品更新换代大大加速&#xff0c;进一步推动了个人计算机需求的增长。根据 IDC 统计&#xff0c;2021 年全球 PC 市场出货量达到 3.49 亿台&#xff0c;同比增长 14.80%&#xff0c;随着个人计算机市场发…

亚马逊云科技re:Invent大会:云计算与生成式AI共筑科技新局面,携手构建未来

随着科技的飞速发展&#xff0c;云计算和生成式 AI 已经成为了推动科技进步的重要力量。这两者相互结合&#xff0c;正在为我们创造一个全新的科技局面。 亚马逊云科技的re:Invent大会再次证明了云计算和生成式AI的强大结合正在塑造科技的新未来。这次大会聚焦了云计算的前沿技…

为什么要隐藏id地址?使用IP代理技术可以实现吗?

随着网络技术的不断发展&#xff0c;越来越多的人开始意识到保护个人隐私的重要性。其中&#xff0c;隐藏自己的IP地址已经成为了一种常见的保护措施。那么&#xff0c;为什么要隐藏IP地址&#xff1f;使用IP代理技术可以实现吗&#xff1f;下面就一起来探讨这些问题。 首先&am…

Qt 软件调试(二)使用dump捕获崩溃信息

Qt应用程序异常崩溃该怎么办&#xff0c;生成dump文件再回溯分析&#xff0c;可以快速且准确的帮助我们定位到崩溃的点。那么&#xff0c;本章我们分享下如何在Qt中生成dump文件。 一、使用minudump捕获崩溃信息 #include <QCoreApplication> #include <QDir> #i…

【洛谷 P1636】Einstein学画画 题解(图论+欧拉通路)

Einstein学画画 题目描述 Einstein 学起了画画。 此人比较懒~~&#xff0c;他希望用最少的笔画画出一张画…… 给定一个无向图&#xff0c;包含 n n n 个顶点&#xff08;编号 1 ∼ n 1 \sim n 1∼n&#xff09;&#xff0c; m m m 条边&#xff0c;求最少用多少笔可以画…

nodejs微信小程序+python+PHP-书吧租阅管理系统的设计与实现-安卓-计算机毕业设计

目 录 摘 要 I ABSTRACT II 目 录 II 第1章 绪论 1 1.1背景及意义 1 1.2 国内外研究概况 1 1.3 研究的内容 1 第2章 相关技术 3 2.1 nodejs简介 4 2.2 express框架介绍 6 2.4 MySQL数据库 4 第3章 系统分析 5 3.1 需求分析 5 3.2 系统可行性分析 5 3.2.1技术可行性&#xff1a;…

深度学习+不良身体姿势检测+警报系统+代码+部署(姿态识别矫正系统)

正确的身体姿势是一个人整体健康的关键。然而&#xff0c;保持正确的身体姿势可能很困难&#xff0c;因为我们经常忘记这一点。这篇博文将引导您完成为此构建解决方案所需的步骤。最近&#xff0c;我们在使用 POSE 进行身体姿势检测方面玩得很开心。它就像一个魅力&#xff01;…