PyTorch 实现的一些常用即插即用模块(长期更新)

1.可分离卷积

#coding:utf-8
import torch.nn as nn


class DWConv(nn.Module):
    """Depthwise-separable convolution block.

    A per-channel (``groups=in_plane``) spatial convolution followed by a 1x1
    pointwise convolution that mixes channels and maps ``in_plane`` channels
    to ``out_plane`` channels.

    Args:
        in_plane: Number of input channels.
        out_plane: Number of output channels.
        kernel_size: Kernel size of the depthwise convolution. Defaults to 3.
        stride: Stride of the depthwise convolution. Defaults to 1.
    """

    def __init__(self, in_plane, out_plane, kernel_size=3, stride=1):
        super(DWConv, self).__init__()
        # groups == in_channels: every input channel gets its own filter.
        # padding = kernel_size // 2 reproduces the original padding=1 for
        # the default 3x3 kernel.
        self.depth_conv = nn.Conv2d(
            in_channels=in_plane,
            out_channels=in_plane,
            kernel_size=kernel_size,
            stride=stride,
            padding=kernel_size // 2,
            groups=in_plane,
        )
        # 1x1 convolution that mixes information across channels.
        self.point_conv = nn.Conv2d(
            in_channels=in_plane,
            out_channels=out_plane,
            kernel_size=1,
            stride=1,
            padding=0,
            groups=1,
        )

    def forward(self, x):
        """Apply the depthwise convolution followed by the pointwise one."""
        x = self.depth_conv(x)
        x = self.point_conv(x)
        return x


def deubg_dw():
    """Run DWConv on a random batch and print the output shape.

    The misspelt name is kept so existing callers keep working.
    """
    import torch

    DW_model = DWConv(3, 32)
    x = torch.rand((32, 3, 320, 320))
    out = DW_model(x)
    print(out.shape)


if __name__ == '__main__':
    deubg_dw()

2.DBnet论文中的DBhead

#coding:utf-8
import torch
from torch import nn


class DBHead(nn.Module):
    """Differentiable-binarization prediction head (DBNet).

    Two parallel branches each upsample the input feature map by 4x and emit
    a single-channel map squashed by a sigmoid: ``binarize`` produces the
    shrink map and ``thresh`` the threshold map.

    Args:
        in_channels: Number of channels of the input feature map.
        out_channels: Unused; kept for interface compatibility.
        k: Steepness factor used by :meth:`step_function`.
    """

    def __init__(self, in_channels, out_channels, k=50):
        super().__init__()
        self.k = k
        self.binarize = nn.Sequential(
            nn.Conv2d(in_channels, in_channels // 4, 3, padding=1),
            nn.BatchNorm2d(in_channels // 4),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(in_channels // 4, in_channels // 4, 2, 2),
            nn.BatchNorm2d(in_channels // 4),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(in_channels // 4, 1, 2, 2),
            nn.Sigmoid(),
        )
        self.binarize.apply(self.weights_init)
        self.thresh = self._init_thresh(in_channels)
        self.thresh.apply(self.weights_init)

    def forward(self, x):
        """Return the concatenated prediction maps along the channel axis.

        In training mode the result has three channels (shrink, threshold,
        approximate binary map); in eval mode only the first two.
        """
        shrink_maps = self.binarize(x)
        threshold_maps = self.thresh(x)
        # ``self.training`` is inherited from nn.Module: True after .train(),
        # False after .eval().
        if self.training:
            binary_maps = self.step_function(shrink_maps, threshold_maps)
            y = torch.cat((shrink_maps, threshold_maps, binary_maps), dim=1)
        else:
            y = torch.cat((shrink_maps, threshold_maps), dim=1)
        return y

    def weights_init(self, m):
        """Initialise conv weights (Kaiming normal) and batch-norm affine terms."""
        classname = m.__class__.__name__
        if classname.find('Conv') != -1:
            nn.init.kaiming_normal_(m.weight.data)
        elif classname.find('BatchNorm') != -1:
            m.weight.data.fill_(1.)
            m.bias.data.fill_(1e-4)

    def _init_thresh(self, inner_channels, serial=False, smooth=False, bias=False):
        """Build and return the threshold branch.

        Args:
            inner_channels: Channel count of the feature map fed to the branch.
            serial: When true the first conv expects one extra input channel.
            smooth: Use nearest-upsample + conv instead of transposed conv.
            bias: Whether the convs of the branch carry a bias term.
        """
        in_channels = inner_channels
        if serial:
            in_channels += 1
        return nn.Sequential(
            nn.Conv2d(in_channels, inner_channels // 4, 3, padding=1, bias=bias),
            nn.BatchNorm2d(inner_channels // 4),
            nn.ReLU(inplace=True),
            self._init_upsample(inner_channels // 4, inner_channels // 4, smooth=smooth, bias=bias),
            nn.BatchNorm2d(inner_channels // 4),
            nn.ReLU(inplace=True),
            self._init_upsample(inner_channels // 4, 1, smooth=smooth, bias=bias),
            nn.Sigmoid(),
        )

    def _init_upsample(self, in_channels, out_channels, smooth=False, bias=False):
        """Return a module that doubles the spatial resolution.

        With ``smooth=False`` this is a single stride-2 transposed conv.
        With ``smooth=True`` it is nearest-neighbour upsampling followed by a
        3x3 conv (plus a 1x1 projection when ``out_channels == 1``).
        """
        if not smooth:
            return nn.ConvTranspose2d(in_channels, out_channels, 2, 2)

        inter_out_channels = out_channels
        if out_channels == 1:
            inter_out_channels = in_channels
        module_list = [
            nn.Upsample(scale_factor=2, mode='nearest'),
            nn.Conv2d(in_channels, inter_out_channels, 3, 1, 1, bias=bias),
        ]
        if out_channels == 1:
            # A 1x1 kernel needs padding=0; padding=1 would enlarge the map by
            # two pixels and break the channel-wise concat in forward().
            module_list.append(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=True))
        # nn.Sequential takes modules as positional arguments, not a list.
        return nn.Sequential(*module_list)

    def step_function(self, x, y):
        """Differentiable approximation of ``x > y``: ``1 / (1 + exp(-k (x - y)))``."""
        return torch.reciprocal(1 + torch.exp(-self.k * (x - y)))


def debug_main():
    """Run the head in train and eval mode on random data and print shapes."""
    x = torch.rand((8, 256, 160, 160))
    head_model = DBHead(in_channels=256, out_channels=2)
    head_model.train()
    y = head_model(x)
    print('==y.shape:', y.shape)
    head_model.eval()
    y = head_model(x)
    print('==y.shape:', y.shape)


if __name__ == '__main__':
    debug_main()

3.SENet中的attention

目的是对不同通道进行加权:先通过squeeze(global average pooling)将h*w*c的特征压缩成1*1*c的特征,再经过两层线性层,最后通过sigmoid输出各通道的权重,并加权到对应通道上。


import torch
import torch.nn as nn
import torch.nn.functional as F
class SELayer(nn.Module):
    """Squeeze-and-Excitation channel attention.

    The spatial dimensions are averaged away, the resulting per-channel
    vector goes through a two-layer bottleneck ending in a sigmoid, and the
    input is rescaled channel-wise by those weights.

    Args:
        channel: Number of channels of the input feature map.
        reduction: Bottleneck reduction ratio of the first linear layer.
    """

    def __init__(self, channel, reduction=16):
        super(SELayer, self).__init__()
        # Keep at least one hidden unit: channel // reduction is 0 when
        # channel < reduction, which would make every gate a constant 0.5.
        hidden = max(1, channel // reduction)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)  # squeeze the spatial dims
        self.fc = nn.Sequential(
            nn.Linear(channel, hidden, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, channel, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return ``x`` scaled per channel by the learned gates."""
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        return x * y


def debug_attention():
    """Run SELayer on a random batch and print the output shape."""
    attention_module = SELayer(channel=128, reduction=16)
    # B, C, H, W
    x = torch.rand((2, 128, 100, 100))
    out = attention_module(x)
    print('==out.shape:', out.shape)


if __name__ == '__main__':
    debug_attention()

4.cv中的self-attention

(1).feature map通过1*1卷积获得q、k、v三个向量,q转置后与k相乘得到attention矩阵,经softmax归一化到0到1之间,再作用于v,得到每个像素的加权结果。

(2).softmax

(3).加权求和


import torch
import torch.nn as nn
import torch.nn.functional as F


class Self_Attn(nn.Module):
    """Self-attention layer over the spatial positions of a feature map.

    Query and key are 1x1-conv projections to ``in_dim // 8`` channels, the
    value is a 1x1-conv projection to ``in_dim`` channels. The attended
    result is blended into the input through the learnable scalar ``gamma``,
    which starts at zero.

    Args:
        in_dim: Number of channels of the input feature map.
    """

    def __init__(self, in_dim):
        super(Self_Attn, self).__init__()
        self.chanel_in = in_dim  # attribute name (sic) kept for compatibility
        self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """Apply self-attention.

        Args:
            x: Input feature map of shape (B, C, H, W).

        Returns:
            A tuple ``(out, attention)`` where ``out`` is
            ``gamma * attended + x`` with the shape of ``x`` and
            ``attention`` has shape (B, N, N) with N = H * W.
        """
        m_batchsize, C, height, width = x.size()
        n_positions = height * width
        proj_query = self.query_conv(x).view(m_batchsize, -1, n_positions).permute(0, 2, 1)  # B x N x C'
        proj_key = self.key_conv(x).view(m_batchsize, -1, n_positions)  # B x C' x N
        energy = torch.bmm(proj_query, proj_key)  # batched matmul -> B x N x N
        attention = self.softmax(energy)  # each row sums to 1
        proj_value = self.value_conv(x).view(m_batchsize, -1, n_positions)  # B x C x N
        out = torch.bmm(proj_value, attention.permute(0, 2, 1))  # B x C x N
        out = out.view(m_batchsize, C, height, width)  # B x C x H x W
        out = self.gamma * out + x
        return out, attention


def debug_attention():
    """Run Self_Attn on a random batch and print the output shapes."""
    attention_module = Self_Attn(in_dim=128)
    # B, C, H, W. The attention map is B x (H*W) x (H*W), so the spatial size
    # is kept small: a 100x100 map would need several GB for this demo alone.
    x = torch.rand((2, 128, 32, 32))
    out, attention = attention_module(x)
    print('==out.shape:', out.shape)
    print('==attention.shape:', attention.shape)


if __name__ == '__main__':
    debug_attention()

5.spp多窗口pooling

import torch
import torch.nn as nn
import torch.nn.functional as F
class SPP(nn.Module):
    """Spatial Pyramid Pooling.

    Max-pools the input with several window sizes at stride 1 and "same"
    padding, then concatenates the input and all pooled maps along the
    channel axis, so the output has ``(1 + len(pool_sizes)) * C`` channels.

    Args:
        pool_sizes: Odd, positive pooling window sizes. Defaults to
            ``(5, 9, 13)``, the windows used by the original implementation.

    Raises:
        ValueError: If any window size is not a positive odd integer (an even
            window cannot keep the spatial size with symmetric padding).
    """

    def __init__(self, pool_sizes=(5, 9, 13)):
        super(SPP, self).__init__()
        pool_sizes = tuple(pool_sizes)
        for k in pool_sizes:
            if not isinstance(k, int) or k < 1 or k % 2 == 0:
                raise ValueError(
                    'pool_sizes must be positive odd integers, got %r' % (k,))
        self.pool_sizes = pool_sizes

    def forward(self, x):
        """Return ``x`` concatenated with its max-pooled versions (dim=1)."""
        pooled = [
            F.max_pool2d(x, kernel_size=k, stride=1, padding=k // 2)
            for k in self.pool_sizes
        ]
        return torch.cat([x] + pooled, dim=1)


def debug_spp():
    """Run SPP on a random batch and print the output shape."""
    x = torch.rand((8, 3, 256, 256))
    spp = SPP()
    x = spp(x)
    print('==x.shape:', x.shape)


if __name__ == '__main__':
    debug_spp()

6.RetinaFPN

# coding: utf-8
import torch
import torch.nn as nn
import torch.nn.functional as F


class RetinaFPN(nn.Module):
    """Feature pyramid producing five levels (P3-P7) from three inputs (C3-C5).

    Args:
        C3_inplanes: Channel count of the C3 input.
        C4_inplanes: Channel count of the C4 input.
        C5_inplanes: Channel count of the C5 input.
        planes: Channel count of every output level.
        use_p5: If true, P6 is computed from P5; otherwise from C5.
    """

    def __init__(self, C3_inplanes, C4_inplanes, C5_inplanes, planes, use_p5=False):
        super(RetinaFPN, self).__init__()
        self.use_p5 = use_p5
        # *_1: 1x1 lateral convs; *_2: 3x3 output convs.
        self.P3_1 = nn.Conv2d(C3_inplanes, planes, kernel_size=1, stride=1, padding=0)
        self.P3_2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1)
        self.P4_1 = nn.Conv2d(C4_inplanes, planes, kernel_size=1, stride=1, padding=0)
        self.P4_2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1)
        self.P5_1 = nn.Conv2d(C5_inplanes, planes, kernel_size=1, stride=1, padding=0)
        self.P5_2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1)
        # P6 input channels depend on whether it is fed from P5 or from C5.
        if self.use_p5:
            self.P6 = nn.Conv2d(planes, planes, kernel_size=3, stride=2, padding=1)
        else:
            self.P6 = nn.Conv2d(C5_inplanes, planes, kernel_size=3, stride=2, padding=1)
        self.P7 = nn.Sequential(
            nn.ReLU(),
            nn.Conv2d(planes, planes, kernel_size=3, stride=2, padding=1),
        )

    def forward(self, inputs):
        """Build the pyramid.

        Args:
            inputs: Sequence ``[C3, C4, C5]`` of feature maps.

        Returns:
            ``[P3, P4, P5, P6, P7]``, each with ``planes`` channels.
        """
        C3, C4, C5 = inputs

        # Top-down pathway: upsample the coarser level to the finer level's
        # size and add it to the lateral projection.
        P5 = self.P5_1(C5)
        P4 = self.P4_1(C4)
        P4 = F.interpolate(P5, size=(P4.shape[2], P4.shape[3]), mode='nearest') + P4
        P3 = self.P3_1(C3)
        P3 = F.interpolate(P4, size=(P3.shape[2], P3.shape[3]), mode='nearest') + P3

        P5 = self.P5_2(P5)
        P4 = self.P4_2(P4)
        P3 = self.P3_2(P3)

        P6 = self.P6(P5) if self.use_p5 else self.P6(C5)
        P7 = self.P7(P6)
        return [P3, P4, P5, P6, P7]


def debug_fpn():
    """Run RetinaFPN on random feature maps and print every level's shape."""
    fpn = RetinaFPN(512, 1024, 2048, 256)
    C3 = torch.randn(3, 512, 80, 80)
    C4 = torch.randn(3, 1024, 40, 40)
    C5 = torch.randn(3, 2048, 20, 20)
    P3, P4, P5, P6, P7 = fpn([C3, C4, C5])
    print("P3", P3.shape)
    print("P4", P4.shape)
    print("P5", P5.shape)
    print("P6", P6.shape)
    print("P7", P7.shape)


if __name__ == '__main__':
    debug_fpn()

7.Focus


import torch
import torch.nn as nn


def autopad(k, p=None):  # kernel, padding
    """Return padding ``p``, defaulting to "same" padding for kernel ``k``.

    When ``p`` is None the result is ``k // 2`` for an int kernel, or a list
    of ``x // 2`` for each element of an iterable kernel.
    """
    if p is None:
        p = k // 2 if isinstance(k, int) else [x // 2 for x in k]  # auto-pad
    return p


class Conv(nn.Module):
    """Standard convolution: Conv2d (no bias) -> BatchNorm2d -> activation.

    Args:
        c1: Input channels.
        c2: Output channels.
        k: Kernel size.
        s: Stride.
        p: Padding; ``None`` selects ``autopad(k)``.
        g: Convolution groups.
        act: Use Hardswish when true, identity otherwise.
    """

    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True):
        super(Conv, self).__init__()
        self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False)
        self.bn = nn.BatchNorm2d(c2)
        self.act = nn.Hardswish() if act else nn.Identity()

    def forward(self, x):
        """Apply conv, batch norm and activation."""
        return self.act(self.bn(self.conv(x)))

    def fuseforward(self, x):
        """Apply conv and activation only, skipping batch norm.

        NOTE(review): presumably meant for use after BN has been folded into
        the conv weights; confirm against the caller.
        """
        return self.act(self.conv(x))


class Focus(nn.Module):
    """Focus width/height information into channel space.

    The input is sampled at the four 2x2 pixel offsets, the four sub-images
    are stacked along the channel axis and passed through a :class:`Conv`:
    x(b, c, w, h) -> y(b, 4c, w/2, h/2) before the convolution.

    Args:
        c1: Input channels.
        c2: Output channels.
        k, s, p, g, act: Forwarded to :class:`Conv`.
    """

    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True):
        super(Focus, self).__init__()
        self.conv = Conv(c1 * 4, c2, k, s, p, g, act)

    def forward(self, x):  # x(b,c,w,h) -> y(b,4c,w/2,h/2)
        """Space-to-depth rearrangement followed by the convolution.

        The last two dimensions of ``x`` must be even, otherwise the four
        slices differ in size and the concatenation fails.
        """
        patches = [
            x[..., ::2, ::2],
            x[..., 1::2, ::2],
            x[..., ::2, 1::2],
            x[..., 1::2, 1::2],
        ]
        return self.conv(torch.cat(patches, 1))


def debug_focus():
    """Run Focus on a random batch and print input and output shapes."""
    model = Focus(c1=3, c2=24)
    img = torch.rand((8, 3, 124, 124))
    print('==img.shape', img.shape)
    out = model(img)
    print('===out.shape', out.shape)


# Guarded so that importing this module does not run the demo.
if __name__ == '__main__':
    debug_focus()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/493164.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

硅片行业:过剩背景下的寡头市场

来源:乐晴智库精选▌竞争格局:过剩背景下的寡头市场,规模壁垒初步形成光伏产业总体处于产能过剩的状态,硅片环节的过剩尤为突出。根据PVInfolink的统计数据,截至2018年2季度末,全球硅片总产能超过160GW,年化…

从attention到Transformer+CV中的self-attention

一.总体结构 由于rnn等循环神经网络有时序依赖,导致无法并行计算,而Transformer主体框架是一个encoder-decoder结构,去掉了RNN序列结构,完全基于attention和全连接。同时为了弥补词与词之间时序信息,将词位置embedding…

12年后,人工智能和人类会是什么样?这是900位专家的看法|报告

来源:机器之能摘要:有分析师预计,到2030年,在复杂的数字系统中,人们将更加依赖于网络人工智能。 有人说,随着对这些网络工具的广泛使用,我们将继续沿着历史的轨迹生活地更好。也有一些人说…

水印去除(基于nosie2noise优化 代码+模型)

github链接 1.感受野计算:$RF_{l} = RF_{l-1} + (k - 1)\prod_{i=1}^{l-1} s_i$,其中 $RF_{l}$:本层感受野; $RF_{l-1}$:上层感受野; $s_i$:第i层卷积或池化的步长; $k$:本层卷积核大小 2.空洞卷积卷积核计算:$K = k + (k-1)(r-1)$,k为原始卷积核大小,r为空洞卷积参数空洞率,带入上式即可计算空洞卷积感受野; 3.针对noi…

广度深度都要,亚马逊是如何推动 Alexa 内生成长的?

来源:雷锋网摘要:发展到今天,Alexa 已经成为亚马逊旗下最重要的几个业务支柱之一,尤其是在人工智能语音助手层面,它和 Google Assistant、Apple Siri、Microsoft Cortana 并驾齐驱,甚至在应用场景上有领先之…

剖析云平台中的“共享型数据库”

剖析云计 算中的“共享型数据库” 摘要: 随着云计算的出现,出现了很多新的名词,像云数据库、云存储、弹性扩容,资源隔离等词汇。下面就大家炒的比较热的“共享型数据库”做一下解释,给大家剖析什么叫“共享型数据库”。…

FCOS: A Simple and Strong Anchor-free Object Detector

论文链接 一.背景 1.anchor-base缺点          (1).anchor的设置对结果影响很大,不同项目这些超参都需要根据经验来确定,难度较大. (2).anchor太过密集,其中很多是负样本…

大数据有十大应用领域,看看你用到了哪个?

来源:网络大数据摘要:如果提到“大数据”时,你会想到什么?也许大部分人会联想到庞大的服务器集群;或者联想到销售商提供的一些个性化的推荐和建议。如今大数据的深度和广度远不止这些,大数据已经在人类社会实践中发挥着巨大的优势…

2018年《环球科学》十大科学新闻出炉:霍金逝世、贺建奎事件位列前二

来源:量子位如果要用两个词来定义2018年的话,我们可能会选择“进步”与“反思”。中国科学在持续进步,克隆猴“中中”与“华华”、单条染色体的酵母,都是世界级的研究成果。“火星快车”在火星上发现大面积的液态湖泊,…

CornerNet: Detecting Objects as Paired Keypoints

CornerNet论文链接 Hourglass Network论文链接 一.背景 1.anchor-base缺点          (1).anchor的设置对结果影响很大,不同项目这些超参都需要根据经验来确定,难度较大. (2).anchor太过密集…

详细解读什么是自适应巡航?

来源:智车科技摘要:自适应巡航设计初衷是减轻驾驶员长途驾驶的疲劳,极为复杂的城市路况并不是它发挥作用的地方。虽然现在的自适应巡航系统具备了根据前车情况、根据路况减速,甚至是刹停的功能,不过其开发之初便是为了…

CenterNet:Objects as Points

CenterNet论文链接 一.背景 1.anchor-base缺点          (1).anchor的设置对结果影响很大,不同项目这些超参都需要根据经验来确定,难度较大. (2).anchor太过密集,其中很多是负样本…

美国正在衰落的24个行业:“猝不及防”还是“温水煮青蛙”?

来源:资本实验室摘要:技术发展一日千里,外部环境日新月异。在这个变化无处不在的世界,许多行业都在不可避免地经历着或是猝不及防,或是“温水煮青蛙”般的冲击。近期,美国财经网站24/7 Wallst根据过去十年的…

距离与相似度计算

一.余弦相似度 加速计算参考这篇文章 from math import *def square_rooted(x):return round(sqrt(sum([a*a for a in x])), 3)def cosine_similarity(x,y):numerator sum(a*b for a, b in zip(x,y))denominator square_rooted(x)*square_rooted(y)return round(numerator/f…

5G 产业链重要细分投资领域

来源:乐晴智库精选▌2019年全球电子产业将保持增长ICInsights预计2018年全球电子产品销售额16220亿美元,同比增长5.1%,2019年将达到16800亿美元,同比增长3.5%,2017~2021年CAGR4.6%。预计2019年通信市场销售额5350亿美元…

CPNDet:Corner Proposal Network for Anchor-free, Two-stage Object Detection

CPNDet论文链接 一.背景 anchor-based方法将大量框密集分布在feature map上,在推理时,由于预设的anchor与目标差异大,召回率会偏低。而anchor-free不受anchor大小限制,在任意形状上会更加灵活,但是像CornerNet这种,先…

Unix/Linux环境C编程入门教程(3) Oracle Linux 环境搭建

Unix/Linux版本众多,我们推荐Unix/Linux初学者选用几款典型的Unix/Linux操作系统进行学习。2010年9月,Oracle Enterprise Linux发布新版内核——Unbreakable Enterprise Kernel,专门针对Oracle软件与硬件进行优化,最重要的是Oracl…

最权威北美放射学会年会回顾:AI的进化与下一个前沿

翻译 : 高璇摘要:人工智能在成像领域的前景必须为终端用户带来时间节省、资源优化、精度增益和感知增益(接近精准健康方法)。前两个是指生产力方面,而后两个是指质量方面。人工智能在成像领域的脚步不会停留在这里——它已经帮助重…

Registry注册机制

前言:不管是Detectron还是mmdetection,都有用到这个register机制,特意去弄明白,记录一下。 首先看Registry代码: # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reservedfrom typing import Dict, Optional, Iterable, T…

Android Volley 库通过网络获取 JSON 数据

本文内容 什么是 Volley 库 Volley 能做什么 Volley 架构 环境 演示 Volley 库通过网络获取 JSON 数据 参考资料 Android 关于网络操作一般都会介绍 HttpClient 以及 HttpConnection 这两个包。前者是 Apache 开源库,后者是 Android 自带 API。企业级应用…