基于 ResNet18 架构使用 deformable convolution的车道线检测

下面是一个基于关键点的车道线检测网络的 PyTorch 代码示例,其中使用了 deformable convolution。该代码示例基于 ResNet18 架构,可以根据实际情况进行修改。

首先,需要导入必要的库和模块:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules.utils import _pair
from torch.nn.parameter import Parameterfrom torchvision.models.resnet import resnet18

然后,定义一个基于 ResNet18 架构的车道线检测网络模型:

class LaneDetectionNet(nn.Module):def __init__(self, num_classes=1, deformable_groups=2):super(LaneDetectionNet, self).__init__()# load ResNet18self.resnet = resnet18(pretrained=True)# replace the first conv layerself.resnet.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)# add deformable convolutionsself.resnet.layer1[0].conv1 = DeformConv2d(64, 64, kernel_size=(3, 3), padding=(1, 1), stride=(1, 1), bias=False, deformable_groups=deformable_groups)self.resnet.layer1[0].conv2 = DeformConv2d(64, 64, kernel_size=(3, 3), padding=(1, 1), stride=(1, 1), bias=False, deformable_groups=deformable_groups)self.resnet.layer2[0].conv1 = DeformConv2d(128, 128, kernel_size=(3, 3), padding=(1, 1), stride=(1, 1), bias=False, deformable_groups=deformable_groups)self.resnet.layer2[0].conv2 = DeformConv2d(128, 128, kernel_size=(3, 3), padding=(1, 1), stride=(1, 1), bias=False, deformable_groups=deformable_groups)self.resnet.layer3[0].conv1 = DeformConv2d(256, 256, kernel_size=(3, 3), padding=(1, 1), stride=(1, 1), bias=False, deformable_groups=deformable_groups)self.resnet.layer3[0].conv2 = DeformConv2d(256, 256, kernel_size=(3, 3), padding=(1, 1), stride=(1, 1), bias=False, deformable_groups=deformable_groups)self.resnet.layer4[0].conv1 = DeformConv2d(512, 512, kernel_size=(3, 3), padding=(1, 1), stride=(1, 1), bias=False, deformable_groups=deformable_groups)self.resnet.layer4[0].conv2 = DeformConv2d(512, 512, kernel_size=(3, 3), padding=(1, 1), stride=(1, 1), bias=False, deformable_groups=deformable_groups)# add the output layersself.fc1 = nn.Linear(512, 512)self.fc2 = nn.Linear(512, num_classes)def forward(self, x):x = self.resnet(x)x = F.relu(self.fc1(x))x = self.fc2(x)return x

其中,DeformConv2d 是一个 deformable convolution 的实现类。其代码如下:

class DeformConv2d(nn.Module):def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, deformable_groups=1):super(DeformConv2d, self).__init__()self.offset_conv = nn.Conv2d(in_channels, 2 * kernel_size[0] * kernel_size[1] * deformable_groups, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=True)self.weight = Parameter(torch.Tensor(out_channels, in_channels, kernel_size[0], kernel_size[1]))if bias:self.bias = Parameter(torch.Tensor(out_channels))else:self.register_parameter('bias', None)self.reset_parameters()self.stride = _pair(stride)self.padding = _pair(padding)self.dilation = _pair(dilation)self.groups = groupsself.deformable_groups = deformable_groupsdef reset_parameters(self):nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))if self.bias is not None:fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)bound = 1 / math.sqrt(fan_in)nn.init.uniform_(self.bias, -bound, bound)def forward(self, x):offset = self.offset_conv(x)output = deform_conv2d(x, offset, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups, self.deformable_groups)return output

最后,定义一个 deformable convolution 的实现函数 deform_conv2d,代码如下:

def deform_conv2d(input, offset, weight, bias=None, stride=1, padding=0, dilation=1, groups=1, deformable_groups=1):# get shapes and parametersbatch_size, in_channels, in_h, in_w = input.size()out_channels, _, kernel_h, kernel_w = weight.size()stride_h, stride_w = _pair(stride)pad_h, pad_w = _pair(padding)dilation_h, dilation_w = _pair(dilation)input_padded = F.pad(input, (pad_w, pad_w, pad_h, pad_h))# calculate output shapeout_h = (in_h + 2*pad_h - dilation_h*(kernel_h-1) - 1) // stride_h + 1out_w = (in_w + 2*pad_w - dilation_w*(kernel_w-1) - 1) // stride_w + 1# unfold input and offsetoffset = offset.view(batch_size, deformable_groups, 2 * kernel_h * kernel_w, out_h, out_w)input_unfolded = F.unfold(input_padded, (kernel_h, kernel_w), dilation=dilation, stride=stride)# calculate outputoutput = torch.zeros(batch_size, out_channels, out_h, out_w).to(input.device)weight = weight.view(1, out_channels, in_channels // groups, kernel_h, kernel_w).repeat(batch_size, 1, 1, 1, 1)for h in range(out_h):for w in range(out_w):input_region = input_unfolded[:, :, h, w].view(batch_size, -1, 1, 1)offset_region = offset[:, :, :, h, w]weight_region = weightoutput_region = F.conv2d(input_region, weight_region, bias=None, stride=1, padding=0, dilation=1, groups=deformable_groups)output_region = deformable_conv2d_compute(output_region, offset_region)output[:, :, h, w] = output_region.squeeze()if bias is not None:output += bias.view(1, -1, 1, 1)return output

其中,deformable_conv2d_compute 函数是 deformable convolution 的计算函数。它的代码如下:

def deformable_conv2d_compute(input, offset):# get shapes and parametersbatch_size, out_channels, out_h, out_w = input.size()in_channels = offset.size(1) // 2# sample input according to offsetgrid_h = torch.linspace(-1, 1, out_h).view(1, 1, out_h, 1).to(input.device)grid_w = torch.linspace(-1, 1, out_w).view(1, 1, 1, out_w).to(input.device)offset_h = offset[:, :in_channels, :, :]offset_w = offset[:, in_channels:, :, :]sample_h = torch.add(grid_h, offset_h)sample_w = torch.add(grid_w, offset_w)sample_h = sample_h.clamp(-1, 1)sample_w = sample_w.clamp(-1, 1)sample_h = ((sample_h + 1) / 2) * (out_h - 1)sample_w = ((sample_w + 1) / 2) * (out_w - 1)sample_h_floor = sample_h.floor().long()sample_w_floor = sample_w.floor().long()sample_h_ceil = sample_h.ceil().long()sample_w_ceil = sample_w.ceil().long()sample_h_floor = sample_h_floor.clamp(0, out_h - 1)sample_w_floor = sample_w_floor.clamp(0, out_w - 1)sample_h_ceil = sample_h_ceil.clamp(0, out_h - 1)sample_w_ceil = sample_w_ceil.clamp(0, out_w - 1)# gather input values according to sampled indicesinput_flat = input.view(batch_size, in_channels, out_h * out_w)index_base = torch.arange(0, batch_size, device=input.device).view(batch_size, 1, 1) * out_h * out_windex_base = index_base.expand(batch_size, in_channels, out_h * out_w)index_offset = torch.arange(0, out_h * out_w, device=input.device).view(1, 1, -1)index_offset = index_offset.expand(batch_size, in_channels, out_h * out_w)indices_a = (sample_h_floor + index_base + index_offset).view(batch_size, in_channels * out_h * out_w)indices_b = (sample_w_floor + index_base + index_offset).view(batch_size, in_channels * out_h * out_w)indices_c = (sample_h_ceil + index_base + index_offset).view(batch_size, in_channels * out_h * out_w)indices_d = (sample_w_ceil + index_base + index_offset).view(batch_size, in_channels * out_h * out_w)value_a = input_flat.gather(2, indices_a.unsqueeze(1).repeat(1, out_channels, 1))value_b = input_flat.gather(2, indices_b.unsqueeze(1).repeat(1, out_channels, 1))value_c = input_flat.gather(2, indices_c.unsqueeze(1).repeat(1, out_channels, 1))value_d = input_flat.gather(2, indices_d.unsqueeze(1).repeat(1, out_channels, 1))# calculate interpolation weights and outputw_a = ((sample_w_ceil - sample_w) * (sample_h_ceil - sample_h)).view(batch_size, 1, out_h, out_w)w_b = ((sample_w - sample_w_floor) * (sample_h_ceil - sample_h)).view(batch_size, 1, out_h, out_w)w_c = ((sample_w_ceil - sample_w) * (sample_h - sample_h_floor)).view(batch_size, 1, out_h, out_w)w_d = ((sample_w - sample_w_floor) * (sample_h - sample_h_floor)).view(batch_size, 1, out_h, out_w)output = w_a * value_a + w_b * value_b + w_c * value_c + w_d * value_dreturn output

最后,可以使用以下代码进行网络的测试:

net = LaneDetectionNet(num_classes=1, deformable_groups=2)  # create the network
input = torch.randn(1, 3, 100, 100)  # create a random input tensor
output = net(input)  # feed it through the network
print(output.shape)  # print the output shape

输出的结果应该为 (1, 1, 1, 1)。这说明网络已经成功地将 100*100 的像素图压缩成了一个标量。可以根据实际情况进行调整和优化,来达到更好的性能。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/121422.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

云端代码编辑器Atheos

什么是 Atheos ? Atheos是一个基于 Web 的 IDE 框架,占用空间小且要求最低,构建于 Codiad 之上,不过 Atheos 已从原始 Codiad 项目完全重写,以利用更现代的工具、更简洁的代码和更广泛的功能。 注意事项 群晖内核版本太…

React hooks的闭包陷阱

react hooks 陷阱 hooks必须放在函数顶层&#xff0c; 不能在条件分支和方法内 1、useState陷阱 异步陷阱 function Index() {const [count, setCount] useState(0)function add(){setCount( count 1 )console.log(count); // 0}return (<div><span>{count}…

浅谈中国汽车充电桩行业市场状况及充电桩选型的介绍

安科瑞虞佳豪 车桩比降低是完善新能源汽车行业配套的一大重要趋势&#xff0c;目前各国政府都在努力推进政策&#xff0c;通过税收减免、建设补贴等措施提升充电桩建设速度&#xff0c;以满足新能源汽车需求。 近年来&#xff0c;在需求和技术的驱动下&#xff0c;充电桩的平…

Mac怎么删除文件和软件?苹果电脑删除第三方软件方法

Mac删除程序这个话题为什么一直重复说或者太多人讨论呢&#xff1f;因为如果操作不当&#xff0c;可能会导致某些不好的影响。因为Mac电脑如果有太多无用的应用程序&#xff0c;很有可能会拖垮Mac系统的运行速度。或者如果因为删除不干净&#xff0c;导致残留文件积累在Mac电脑…

Jmeter的接口自动化测试

在去年实施了一年的三端&#xff08;PC、无线M站、无线APP【Android、IOS】&#xff09;后&#xff0c;今年7月份开始&#xff0c;我们开始进行接口自动化的实施&#xff0c;目前已完成了整个框架的搭建以及接口的持续测试集成。今天做个简单的分享。 在开始自动化投入前&#…

第14期 | GPTSecurity周报

GPTSecurity是一个涵盖了前沿学术研究和实践经验分享的社区&#xff0c;集成了生成预训练 Transformer&#xff08;GPT&#xff09;、人工智能生成内容&#xff08;AIGC&#xff09;以及大型语言模型&#xff08;LLM&#xff09;等安全领域应用的知识。在这里&#xff0c;您可以…

vscode 保存 “index.tsx“失败: 权限不足。选择 “以超级用户身份重试“ 以超级用户身份重试。

vscode 保存 "index.tsx"失败: 权限不足。选择 “以超级用户身份重试” 以超级用户身份重试。 操作&#xff1a;mac在文件夹中创建文件&#xff0c;sudo 创建umiJs项目 解决&#xff1a;修改文件夹权限 右键文件夹

STM32F103单片机内部RTC实时时钟驱动程序

一、STM32f103系列RTC功能 RTC实时时钟功能是嵌入式软件开发中比较常用的功能&#xff0c;一般MCU的RTC功能都带有年月日时间寄存器&#xff0c;比如STM32F4xx系列&#xff0c;RTC描述如下&#xff1a; 可见F4系列的RTC功能比较强大&#xff0c;设置好初始时间后&#xff0c;读…

NodeJS爬取墨刀上的设计图片

背景 设计人员分享了一个墨刀的原型图&#xff0c;但是给的是只读权限&#xff0c;无法下载其中的素材&#xff1b;开发时想下载里面的一张动图&#xff0c;通过浏览器的F12工具在页面结构找到了图片地址。 但是浏览器直接访问后发现没权限&#xff1a; Nginx 的 403 页面。。…

批量编辑 Outlook 联系人

现状 Outlook 自带的联系人编辑功能无法快速、批量编辑联系人字段使用 Excel 等外部编辑器&#xff0c;可批量编辑联系人 导出联系人到文件 在【联系人】界面&#xff0c;点击【文件】在【文件】界面&#xff0c;点击【打开和导出】–>【导入/导出】在弹出的向导窗口中点…

语法复习之C语言与指针

内存是如何存储数据的&#xff1f; 在C语言中定义一个变量后&#xff0c;系统就会为其分配内存空间。这个内存空间包括了地址和长度。将变量赋值后&#xff0c;该值就被写入到了指定的内存空间中。内存空间的大小一般以字节作为基本单位。   普通变量存放的是数据&#xff0c…

学习paddle-detection(paddlex的使用)

首先下载paddlex&#xff08;网页&#xff09;的本地软件&#xff0c;下载链接如下&#xff1a; paddlex 下载完成后进行安装 打开后选择开发者模式&#xff0c;开发者模式主要是和VScode进行集成 本章节主要介绍在开发者模式下可以查看和编辑的文件及其作用&#xff0c;关于…

【优选算法系列】第一节.双指针(283. 移动零和1089. 复写零)

作者简介&#xff1a;大家好&#xff0c;我是未央&#xff1b; 博客首页&#xff1a;未央.303 系列专栏&#xff1a;优选算法系列 每日一句&#xff1a;人的一生&#xff0c;可以有所作为的时机只有一次&#xff0c;那就是现在&#xff01;&#xff01;&#xff01;&#xff01…

LabVIEW开发基于图像处理的车牌检测系统

LabVIEW开发基于图像处理的车牌检测系统 自动车牌识别的一般步骤是图像采集、去除噪声的预处理、车牌定位、字符分割和字符识别。结果主要取决于所采集图像的质量。在不同照明条件下获得的图像具有不同的结果。在要使用的预处理技术中&#xff0c;必须将彩色图像转换为灰度&am…

LaTeX:在标题section中添加脚注footnote

命令讲解 先导包&#xff1a; \usepackage{footmisc} 设原标题为&#xff1a; \section{标题内容} 更改为&#xff1a; \section[标题内容]{标题内容\protect\footnote{脚注内容}} 语法讲解&#xff1a; \section[]{} []内为短标题&#xff0c;作为目录和页眉中的标题。…

在Golang中理解错误处理

处理Golang中临时错误和最终错误的策略和示例 作为一名精通Golang的开发人员&#xff0c;您了解有效的错误处理是编写健壮可靠软件的关键因素。在复杂系统中&#xff0c;错误可能采取各种形式&#xff0c;包括临时故障和最终失败。在本文中&#xff0c;我们将探讨处理Golang中…

【java】建筑施工一体化智慧工地信息管理系统源码

智慧工地系统是一种利用人工智能和物联网技术来监测和管理建筑工地的系统。它可以通过感知设备、数据处理和分析、智能控制等技术手段&#xff0c;实现对工地施工、设备状态、人员安全等方面的实时监控和管理。 一、智慧工地让工程施工智能化 1、内容全面&#xff0c;多维度数…

8086汇编环境的使用

先打开emu8086&#xff0c;写入代码 ;给11003H的地址赋1234H的值;不能直接给DS赋值需要寄存器中转 mov dx, 1100H mov ds, dx mov ax, 1234H ;不能直接给内存地址赋值&#xff0c;需要DS:[偏移地址]指向内存 mov [3H], ax 点击emulate开始模拟 出现调试框&#xff0c;调试框的…

IDEA部署SSM项目mysql数据库MAVEN项目部署教程

如果 SSM 项目是基于 Maven 构建的&#xff0c;则需要配置 maven 环境&#xff0c;否则跳过这一步 步骤一&#xff1a;配置 Maven 第一步&#xff1a;用 IDEA 打开项目&#xff0c;准备配置 maven 环境 &#xff0c;当然如果本地没有提前配置好 maven&#xff0c;就用 IDEA 默…

25-什么是事件循环

一、是什么 &#x1f37f;&#x1f37f;&#x1f37f;JavaScript是一门单线程的语言、 意味着同一时间内只能做一件事&#xff0c;但是这并不意味着单线程就是阻塞&#xff0c;而实现单线程非阻塞的方法就是事件循环 在JavaScript中&#xff0c;所有的任务都可以分为 同步任…