基于CNN+RNNs(LSTM, GRU)的红点位置检测(pytorch)

1 项目背景

需要在图片精确识别三跟红线所在的位置,并输出这三个像素的位置。

在这里插入图片描述
其中,每跟红线占据不止一个像素,并且像素颜色也并不是饱和度和亮度极高的红黑配色,每个红线放大后可能是这样的。

在这里插入图片描述

而我们的目标是精确输出每个红点的位置,需要精确到像素。也就是说,对于每根红线,模型需要输出橙色箭头所指的像素而不是蓝色箭头所指的像素的位置。

在之前尝试过纯 RNNs 检测红点,但是准确率感人,在噪声极低的情况下并不能精准识别位置。但是有次尝试transformer位置编码之后发现效果不错:

实验loss完全准确的点
GRU129.66411762.0/9000 (20%)
LSTM249.20531267.0/9000 (14%)
Position embedding + GRU16.34035025.0/9000 (56%)
Position embedding + LSTM204.15511603.0/9000 (18%)

这说明模型的难点在于学习位置信息而不是寻找颜色有问题的点。联想到CNN也能提供位置信息,我决定尝试卷积一下的效果。

2 数据集

还是之前那个代码合成的数据集数据集,每个数据集规模在15000张图片左右,在没有加入噪音的情况下,每个样本预览如图所示:
在这里插入图片描述
加入噪音后,每个样本的预览如下图所示:

在这里插入图片描述

图中黑色部分包含比较弱的噪声,并非完全为黑色。

数据集包含两个文件,一个是文件夹,里面包含了jpg压缩的图像数据:
在这里插入图片描述
另一个是csv文件,里面包含了每个图像的名字以及3根红线所在的像素的位置。

在这里插入图片描述

3 思路

其实思路特别朴素。就是在RNNs要读序列化数据之前先用CNN把数据跑一遍,让原始的输入序列变成具有局部特征表示的嵌入表示,卷积后提取的特征输入到 RNN层,RNN 保持了序列中的长时依赖信息。接下来先用 fc1 把 RNN 的输出映射成分数,然后用 fc2 预测三个具体位置,经过 Sigmoid 输出 [0, 1] 的相对位置,再与宽度相乘得到真实位置。具体的流程如下图所示:

在这里插入图片描述

4 结果

在图片长度为1080、低噪声环境时,对比实验的结果如下:

实验loss完全准确的点
GRU129.66411762.0/9000 (20%)
LSTM249.20531267.0/9000 (14%)
CNN+GRU1419.5781601.0/9000 (7%)
CNN+LSTM1166.4599762.0/9000 (8%)

1080长度下图片抽样预测的效果如下:

在这里插入图片描述

在简单图片中的效果跟其他方法差距不大——基本都能准确定位红线,但是还是没办法做到像素级别的精确

在这里插入图片描述

可能是我的打开方式不对,但是CNN+RNN的效果并不如意。

从训练过程来看存在过拟合:

在这里插入图片描述

5 代码

CNN+GRU结构:


class CNN_GRU(nn.Module):def __init__(self, config):super(CNN_GRU, self).__init__()self.input_size = config.input_sizeself.hidden_size = config.hidden_sizeself.num_layers = config.num_layersself.device = config.device# CNNself.conv1 = nn.Conv1d(in_channels=self.input_size, out_channels=64, kernel_size=3, padding=1)self.conv2 = nn.Conv1d(in_channels=64, out_channels=128, kernel_size=3, padding=1)self.conv3 = nn.Conv1d(in_channels=128, out_channels=self.input_size, kernel_size=3, padding=1)self.gru = nn.GRU(input_size=self.input_size, hidden_size=self.hidden_size, num_layers=self.num_layers,batch_first=True, bidirectional=True, dropout=0.6)self.fc1 = nn.Sequential(nn.Linear(self.hidden_size * 2, 1))self.fc2 = nn.Sequential(nn.Linear(config.width, 3),  # predict 3 pointsnn.Sigmoid(),)self.scale = config.widthself.device = config.devicedef forward(self, x):x = x.squeeze(2)x = F.relu(self.conv1(x))  # (batch_size, 64, width)x = F.relu(self.conv2(x))  # (batch_size, 128, width)x = F.relu(self.conv3(x))  # (batch_size, input_size, width)x = x.permute(0, 2, 1)h0 = torch.zeros(self.num_layers * 2, x.size(0), self.hidden_size).to(x.device)output, _ = self.gru(x0, h0)scores = self.fc1(output).squeeze(-1)  # shape: (batch_size, 1080)predicted_positions = self.fc2(scores)scaled_predicted_positions = predicted_positions * self.scalefinal_predicted_positions = torch.clamp(scaled_predicted_positions, min=0, max=self.scale - 1)return final_predicted_positions

CNN+LSTM结构:

class CNN_GRU(nn.Module):def __init__(self, config):super(CNN_GRU, self).__init__()self.input_size = config.input_sizeself.hidden_size = config.hidden_sizeself.num_layers = config.num_layersself.device = config.device# CNNself.conv1 = nn.Conv1d(in_channels=self.input_size, out_channels=64, kernel_size=3, padding=1)self.conv2 = nn.Conv1d(in_channels=64, out_channels=128, kernel_size=3, padding=1)self.conv3 = nn.Conv1d(in_channels=128, out_channels=self.input_size, kernel_size=3, padding=1)self.lstm = nn.LSTM(input_size=self.input_size, hidden_size=self.hidden_size, num_layers=self.num_layers,batch_first=True, bidirectional=True, dropout=0.6)self.fc1 = nn.Sequential(nn.Linear(self.hidden_size * 2, 1))self.fc2 = nn.Sequential(nn.Linear(config.width, 3),  # predict 3 pointsnn.Sigmoid(),)self.scale = config.widthself.device = config.devicedef forward(self, x):x = x.squeeze(2)x = F.relu(self.conv1(x))  # (batch_size, 64, width)x = F.relu(self.conv2(x))  # (batch_size, 128, width)x = F.relu(self.conv3(x))  # (batch_size, input_size, width)x = x.permute(0, 2, 1)h0 = torch.zeros(self.num_layers * 2, x.size(0), self.hidden_size).to(x.device)c0 = torch.zeros(self.num_layers * 2, x.size(0), self.hidden_size).to(x.device)output, _ = self.lstm(x, (h0, c0))scores = self.fc1(output).squeeze(-1)  # shape: (batch_size, 1080)predicted_positions = self.fc2(scores)scaled_predicted_positions = predicted_positions * self.scalefinal_predicted_positions = torch.clamp(scaled_predicted_positions, min=0, max=self.scale - 1)return final_predicted_positions

路过的大佬有什么建议 ball ball 在评论区打出来,我会去尝试~

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/61706.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【java】常用命令记录

1.java 2.jar 2.1 介绍 JAR包是Java中所特有一种压缩文档,其实大家就可以把它理解为.zip包。(也可以用war包. jar cvf aa.war)当然也是有区别的,JAR包中有一个META-INF\MANIFEST.MF文件,当你找成JAR包时,它会自动生成。JAR包是由JDK安装目录\bin\jar.exe命令生成的&#xff0…

树莓派搭建NextCloud:给数据一个安全的家

前言 NAS有很多方案,常见的有 Nextcloud、Seafile、iStoreOS、Synology、ownCloud 和 OpenMediaVault ,以下是他们的特点: 1. Nextcloud 优势: 功能全面:支持文件同步、共享、在线文档编辑、视频会议、日历、联系人…

数据集-目标检测系列- 花卉 鸡蛋花 检测数据集 frangipani >> DataBall

数据集-目标检测系列- 花卉 鸡蛋花 检测数据集 frangipani >> DataBall DataBall 助力快速掌握数据集的信息和使用方式,会员享有 百种数据集,持续增加中。 贵在坚持! 数据样例项目地址: * 相关项目 1)数据集…

初次体验加猜测信息安全管理与评估国赛阶段训练习

[第一部分] 网络安全事件响应 window操作系统服务器应急响应流程_windows 服务器应急响应靶场_云无迹的博客-CSDN博客 0、请提交攻击者攻击成功的第一时间,格式:YY:MM:DD hh:mm:ss1、请提交攻击者的浏览器版本2、请提交攻击者目录扫描所使用的工具名称…

Python Matplotlib 安装指南:使用 Miniconda 实现跨 Linux、macOS 和 Windows 平台安装

Python Matplotlib 安装指南:使用 Miniconda 实现跨 Linux、macOS 和 Windows 平台安装 Matplotlib是Python最常用的数据可视化工具之一,结合Miniconda可以轻松管理安装和依赖项。在这篇文章中,我们将详细介绍如何使用Miniconda在Linux、mac…

opencv-python 分离边缘粘连的物体(距离变换)

import cv2 import numpy as np# 读取图像,这里添加了判断图像是否读取成功的逻辑 img cv2.imread("./640.png") # 灰度图 gray cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) # 高斯模糊 gray cv2.GaussianBlur(gray, (5, 5), 0) # 二值化 ret, binary cv2…

KubeSphere内网环境实践GO项目流水线

KubeSphere内网环境实践GO项目流水线 kubesphere官方给出的流水线都是在公网环境下,并对接github、dockerhub等环境。本文在内网实践部署,代码库使用内网部署的gitlab,镜像仓库使用harbor。 1. 环境准备 1.1 部署kubesphere环境 参考官方…

MINES

MINES (m)6A (I)dentification Using (N)anopor(E) (S)equencing Tombo(v1.4) 命令在 MINES 之前执行: (仅在 fast5 文件中尚未包含 fastq 时需要) tombo preprocess annotate_raw_with_fastqs --fast5-basedir /fast5_dir/ --fastq-file…

UE5材质篇5 简易水面

不得不说,UE5里搞一个水面实在是相比要自己写各种反射来说太友好了,就主要是开启一堆开关,lumen相关的,然后稍微连一些蓝图就几乎有了 这里要改一个shading model,要这个 然后要增加一个这个node 并且不需要连接base …

浦语提示词工程实践(LangGPT版,服务器上部署internlm2-chat-1_8b,踩坑很多才完成的详细教程,)

首先,在InternStudio平台上创建开发机。 创建成功后点击进入开发机打开WebIDE。进入后在WebIDE的左上角有三个logo,依次表示JupyterLab、Terminal和Code Server,我们使用Terminal就行。(JupyterLab可以直接看文件夹)…

小白学多线程(持续更新中)

1.JDK中的线程池 JDK中创建线程池有一个最全的构造方法,里面七个参数如上所示。 执行流程分析: 模拟条件:10个核心线程数,200个最大线程数,阻塞队列大小为100。 当有小于十个任务要处理时,因为小于核心线…

40分钟学 Go 语言高并发:Context包与并发控制

Context包与并发控制 学习目标 知识点掌握程度应用场景context原理深入理解实现机制并发控制和请求链路追踪超时控制掌握超时设置和处理API请求超时、任务限时控制取消信号传播理解取消机制和传播链优雅退出、资源释放context最佳实践掌握使用规范和技巧工程实践中的常见场景…

cocos creator 3.8 Node学习 3

//在Ts、js中 this指向当前的这个组件实例 //this下的一个数据成员node,指向组件实例化的这个节点 //同样也可以根据节点找到挂载的所有组件 //this.node 指向当前脚本挂载的节点//子节点与父节点的关系 // Node.parent是一个Node,Node.children是一个Node[] // th…

音频信号采集前端电路分析

音频信号采集前端电路 一、实验要求 要求设计一个声音采集系统 信号幅度:0.1mVpp到1Vpp 信号频率:100Hz到16KHz 搭建一个带通滤波器,滤除高频和低频部分 ADC采用套件中的AD7920,转换率设定为96Ksps ;96*161536 …

SpringBoot中使用Sharding-JDBC实战(实战+版本兼容+Bug解决)

一、实战 1、引入 ShardingSphere-JDBC 的依赖 https://mvnrepository.com/artifact/org.apache.shardingsphere/shardingsphere-jdbc/5.5.0 <!-- https://mvnrepository.com/artifact/org.apache.shardingsphere/shardingsphere-jdbc --> <dependency><grou…

网络编程 day1.2~day2——TCP和UDP的通信基础(TCP)

笔记脑图 作业&#xff1a; 1、将虚拟机调整到桥接模式联网。 2、TCP客户端服务器实现一遍。 服务器 #include <stdio.h> #include <string.h> #include <myhead.h> #define IP "192.168.60.44" #define PORT 6666 #define BACKLOG 20 int mai…

创建可重用React组件的实用指南

尽管React是全球最受欢迎和使用最广泛的前端框架之一&#xff0c;但许多开发者在重构代码以提高可复用性时仍然感到困难。如果你发现自己在React应用中不断重复相同的代码片段&#xff0c;那你来对地方了。 在本教程中&#xff0c;将向你介绍三个最常见的特征&#xff0c;表明是…

PyQT开发与实践:全面掌握跨平台桌面应用开发

目录 引言 PyQT简介 PyQT的主要特点 开发环境搭建 PyQT开发流程 1. 创建项目和主窗口 2. 添加控件和布局 3. 信号与槽 4. 样式和美化 高级特性 数据绑定和模型/视图编程 多线程和并发 国际化和本地化 实践案例&#xff1a;简单的计算器应用 1. 界面设计 2. 逻辑…

微信小程序条件渲染与列表渲染的全面教程

微信小程序条件渲染与列表渲染的全面教程 引言 在微信小程序的开发中,条件渲染和列表渲染是构建动态用户界面的重要技术。通过条件渲染,我们可以根据不同的状态展示不同的内容,而列表渲染则使得我们能够高效地展示一组数据。本文将详细讲解这两种渲染方式的用法,结合实例…

Origin教程003:数据导入(2)-从文件导入和导入矩阵数据

文章目录 3.3 从文件导入3.3.1 导入txt文件3.3.2 导入excel文件3.3.3 合并工作表3.4 导入矩阵数据3.3 从文件导入 所需数据 https://download.csdn.net/download/WwLK123/900267473.3.1 导入txt文件 选择【数据->从文件导入->导入向导】: 选择文件之后,点击完成即可…