16、PyTorch中进行卷积残差模块算子融合

文章目录

  • 1. 1x1卷积核-> 3x3卷积核
  • 2. 输入x --> 3x3卷积核,无变化
  • 3. 代码

1. 1x1卷积核-> 3x3卷积核

假设我们有一个1x1的卷积核,需要通过填充变为一个3x3的卷积核,实现的是像素之间无关联
[ 4 ] → [ 0 0 0 0 4 0 0 0 0 ] \begin{equation} \begin{bmatrix}4\end{bmatrix}\to \begin{bmatrix} 0&0&0\\\\ 0&4&0\\\\ 0&0&0\end{bmatrix} \end{equation} [4] 000040000

2. 输入x --> 3x3卷积核,无变化

我们希望有一个x,用3x3的卷积核表示后依然不变,那么首先是3x3的卷积核本身移动过程中不会改变像素值,像素之间不融合,其次是空间中不融合,假设我们有一个卷积定义如下
c o n 2 d ( 2 , 2 , 3 , p a d d i n g = " s a m e " ) \begin{equation} con2d(2,2,3,padding="same") \end{equation} con2d(2,2,3,padding="same")
可得: 输出通道为2,输入通道为2,卷积核大小为3,padding=“same”表示卷积核图像不变
卷积权重大小为(2,2,3,3)

  • 可以把(2,2,3,3)简单拆分成两个部分,第一个为(2,3,3)的卷积核矩阵,实现的是卷积滑动操作,(2,2)表示的将输入通道数2转换成输出通道数2,那么一个2x2的矩阵,怎样才能够实现通道分离呢?一般就是对角矩阵,那么可以简单看做如下:
    [ a b b a ] ; a → [ 0 0 0 0 1 0 0 0 0 ] ; b → [ 0 0 0 0 0 0 0 0 0 ] \begin{equation} \begin{bmatrix} a&b\\\\ b&a \end{bmatrix};a\to \begin{bmatrix} 0&0&0\\\\ 0&1&0\\\\ 0&0&0 \end{bmatrix};b\to \begin{bmatrix} 0&0&0\\\\ 0&0&0\\\\ 0&0&0 \end{bmatrix} \end{equation} abba ;a 000010000 ;b 000000000
  • 这样就实现了通道上的分离和像素上的分离。

3. 代码

#!/usr/bin/env python
# -*- coding:utf-8 -*-
# @FileName  :confusion_conv2d.py
# @Time      :2024/12/9 8:23
# @Author    :Jason Zhang
import torch
from torch import nn
import torch.nn.functional as Fif __name__ == "__main__":run_code = 0in_channels = 2out_channels = 2kernel_size = 3w = 9h = 9x = torch.ones(1, in_channels, w, h)  # input image size# 1. pytorch method 1.1 x --> image = 1,2,9,9 1.2 kernel --> conv2d -->2,2,3,3 --> 2,2,3,3 --> kernel_size =3x3# 1.3 compute . if we have out_chanel for the 2 ,and# each channel has 2 kernel with 3x3, total 2x2--> 4 counts for 3x3conv_2d = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, padding="same")print(f"conv_2d.weight.shape={conv_2d.weight.shape}")print(f"conv_2d.weight={conv_2d.weight}")conv_2d_pointwise = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1)result1 = conv_2d(x) + conv_2d_pointwise(x) + xprint(f"result1.shape=\n{result1.shape}")print(f"result1=\n{result1}")zeros = torch.unsqueeze(torch.zeros(kernel_size, kernel_size), 0)starts = torch.unsqueeze(F.pad(torch.ones(1, 1), (1, 1, 1, 1)), 0)print(f"zeros=\n{zeros}")print(f"zeros.shape={zeros.shape}")print(f"starts=\n{starts}")print(f"starts.shape=\n{starts.shape}")starts_zeros = torch.unsqueeze(torch.cat((starts, zeros), 0), 0)zeros_starts = torch.unsqueeze(torch.cat((zeros, starts), 0), 0)print(f"start_zeros=\n{starts_zeros}")print(f"start_zeros.shape=\n{starts_zeros.shape}")print(f"zeros_starts=\n{zeros_starts}")print(f"zeros_starts.shape=\n{zeros_starts.shape}")identity_weight = torch.cat((starts_zeros, zeros_starts), 0)identity_bias = torch.zeros(out_channels)print(f"identity=\n{identity_weight}")print(f"identity.shape=\n{identity_weight.shape}")test_conv2d = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, padding="same")test_conv2d_weight = test_conv2d.weighttest_conv2d_bias = test_conv2d.biasprint(test_conv2d_weight.shape)print(test_conv2d_bias.shape)test_conv2d.weight = nn.Parameter(identity_weight)test_conv2d.bias = nn.Parameter(identity_bias)input_x = torch.randint(1, 10, (1, 2, 9, 9), dtype=torch.float)out_y = test_conv2d(input_x)print(f"input_x=\n{input_x}")print(f"out_y=\n{out_y}")check_out = torch.allclose(input_x, out_y)print(f"input_x is {check_out} same for out_y")point_wise = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, padding="same")point_wise_weight = F.pad(point_wise.weight, (1, 1, 1, 1, 0, 0, 0, 0))point_wise_bias = point_wise.biasprint(f"point_wise=\n{point_wise_weight}")print(f"point_wise.shape=\n{point_wise_weight.shape}")point_wise_conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,padding="same")point_wise_conv.weight = nn.Parameter(point_wise_weight)point_wise_conv.bias = nn.Parameter(point_wise_bias)point_wise_out = point_wise(input_x)print(f"point_wise_out=\n{point_wise_out}")point_3_wise_out = point_wise_conv(input_x)check_3 = torch.allclose(point_wise_out, point_3_wise_out)print(f"check_3 is {check_3} same for point_3_wise_out")

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/63728.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

深入理解代理模式(Proxy):静态代理、动态代理与AOP

目录 1. 代理模式简介2. 静态代理3. 动态代理 3.1 JDK动态代理3.2 CGLIB动态代理 4. 面向切面编程(AOP)5. 实战示例6. 总结与最佳实践 1. 代理模式简介 代理模式是一种结构型设计模式,它允许我们提供一个代理来控制对其他对象的访问。代理模式在不改变原始类代码…

java+springboot+mysql私人会所管理系统

项目介绍: 使用javaspringbootmysql开发的私人会所管理系统,系统包含管理员、技师、用户角色,功能如下: 管理员:用户管理;服务项目;技师管理;房间管理;预约管理&#x…

Formality:set_svf命令

相关阅读 Formalityhttps://blog.csdn.net/weixin_45791458/category_12841971.html?spm1001.2014.3001.5482 svf文件的全称是Setup Verification for Formality,即Design Compiler提供给Formality的设置验证文件,它的作用是为Formality的指导模式(Gui…

Hive 数据操作语言全面解析

Hive 数据操作语言全面解析 在 Hive 大数据处理框架中,数据操作语言(DML)提供了多种方式来操作和修改数据,包括数据的加载、插入、更新、删除以及合并等操作。本文将详细介绍 Hive 中各类数据操作语句的语法、用法、注意事项以及…

JS API日期对象

目标:掌握日期对象,可以让网页显示日期 日期对象:用来表示时间的对象 作用:可以得到当前系统时间 实例化 目标:能够实现实例化日期对象 在代码中发现了new关键字时,一般将这个操作称为实例化 创建一个时…

【前端】JavaScript中的闭包与垃圾回收机制详解

博客主页: [小ᶻ☡꙳ᵃⁱᵍᶜ꙳] 本文专栏: 前端 文章目录 💯前言💯垃圾回收机制(Garbage Collection, GC)垃圾回收的核心原理核心过程 函数作用域与垃圾回收运行分析输出结果 垃圾回收的局限性与挑战 &#x1f4a…

单臂路由配置

知识点 单臂路由指在路由器上的一个接口配置子接口(逻辑接口)来实现不同vlan间通信 路由器上的每个物理接口都可以配置多个子接口(逻辑接口) 公司的财务部、技术部和业务部有多台计算机,它们使用一台二层交换机进行互…

verilog编程规范

verilog编程规范 文章目录 verilog编程规范前言一、代码划分二、verilog编码ABCDEFG 前言 高内聚,低耦合,干净清爽的代码 一、代码划分 高内聚: 一个功能一个模块干净的接口提取公共的代码 低耦合: 模块之间低耦合尽量用少量…

WEB安全基础知识

WAF全称为Web Application Firewall(网页应用防火墙)是一种专门设计用来保护web应用免受各种网络攻击的安全防护措施。它位于客户端与服务器之间,监控和过滤HTTP流量,从而拦截恶意请求、识别并防御常见的web攻击。 WAF的主要功能…

qemu安装arm64架构银河麒麟

qemu虚拟化软件,可以在一个平台上模拟另一个硬件平台,可以支持多种处理器架构。 一、安装 安装教程:https://blog.csdn.net/qq_36035382/article/details/125308044 下载链接:https://qemu.weilnetz.de/w64/2024/ 我下载的是 …

前端怎么用 EventSource?EventSource 怎么配置请求头及加参数?EventSourcePolyfill 使用方法

前言 在前端开发中,特别是实时数据更新的场景下,EventSource 是一个非常实用的 API。它允许浏览器与服务器建立单向连接,服务器可以持续地发送数据给客户端,而无需客户端不断轮询。本文将详细介绍 EventSource 的使用方法、如何配…

188-下翻便携式6U CPCI工控机箱

一、板卡概述 下翻式CPCI便携工控机,系统采用6u cpci背板结构,1个系统槽,7个扩展槽, 满足对携带的需求,可装标准6U8槽CPCI主板,8个扩展槽, 满足客户对空间扩展的需求.可宽温服务的工作产品,15高亮度液晶显示屏,超薄88键笔记本键盘,触摸式鼠标,加固型机箱结构,使它能够适应各种复…

网页核心页面设计(第9章)

一、多个边框阴影 <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><meta http-equiv"X-UA-Compatible" content"IEedge"><meta name"viewport" content"widthdevice-…

SpringBoot中Selenium详解

文章目录 SpringBoot中Selenium详解一、引言二、集成Selenium1、环境准备1.1、添加依赖 2、编写测试代码2.1、测试主类2.2、页面对象2.3、搜索组件 三、使用示例四、总结 SpringBoot中Selenium详解 一、引言 在现代软件开发中&#xff0c;自动化测试是提高软件质量、减少重复…

Edge SCDN的独特优势有哪些?

强大的边缘计算能力 Edge SCDN&#xff08;边缘安全加速&#xff09;是酷盾安全推出的边缘集分布式 DDoS 防护、CC 防护、WAF 防护、BOT 行为分析为一体的安全加速解决方案。通过边缘缓存技术&#xff0c;智能调度使用户就近获取所需内容&#xff0c;为用户提供稳定快速的访问…

Fastapi教程:使用aioredis异步访问redis

本文将介绍如何使用 FastAPI 异步访问 Redis&#xff0c;包括环境配置、连接创建、数据库初始化、增删查改操作、键过期、管道&#xff08;pipeline&#xff09;操作以及事务管理等内容。 环境准备 首先&#xff0c;我们需要安装必要的依赖包。Redis 是一个键值存储系统&…

duxapp 2024-12-09 更新 PullView可以弹出到中间,优化CLI使用体验

UI库 修复 Button 禁用状态失效的问题Modal 组件即将停用&#xff0c;请使用 PullView 基础库 PullView side 新增 center 指定弹出到屏幕中间PullView 新增 duration 属性&#xff0c;指定动画时长新增 useBackHandler hook 用来阻止安卓端点击返回键 RN端 修复 windows …

多线程与线程互斥

目录 引言 一、多线程设计 多线程模拟抢票 二、互斥锁 互斥量的接口 修改抢票代码 锁的原理 锁的封装&#xff1a;RAII 引言 随着信息技术的飞速发展&#xff0c;计算机软件正变得越来越复杂&#xff0c;对性能和响应速度的要求也日益提高。在这样的背景下&#xff0c;…

Vue导出报表功能【动态表头+动态列】

安装依赖包 npm install -S file-saver npm install -S xlsx npm install -D script-loader创建export-excel.vue组件 代码内容如下&#xff08;以element-ui样式代码示例&#xff09;&#xff1a; <template><el-button type"primary" click"Expor…

ZUC256 Go Go Go!!!

文章目录 背景运行效果代码 背景 因业务需要使用ZUC算法&#xff0c;GitHub上又没有对ZUC256相对应的Go语言的实现。 吃水不忘挖井人&#xff0c;在这里感谢GmSSL及BouncyCastle两个强大的密码学库&#xff01; 本ZUC256的编写&#xff0c;参考了这两个库及中科院软件院发布的…