padding和卷积的区别_TensorFlow笔记1——20.CNN卷积神经网络padding两种模式SAME和VALID...

第1种解说:(核心最后一张图,两种填充方式输出的形状尺寸计算公式)

在用tensorflow写CNN的时候,调用卷积核api的时候,会有填padding方式的参数,找到源码中的函数定义如下(max pooling也是一样):

def conv2d(input, filter, strides, padding, use_cudnn_on_gpu=None, data_format=None, name=None)

源码中对于padding参数的说明如下:

padding: A string from: "SAME", "VALID". The type of padding algorithm to use.

说了padding可以用“SAME”和“VALID”两种方式,但是对于这两种方式具体是什么并没有多加说明。 这里用Stack Overflow中的一份代码来简单说明一下,代码如下:

x = tf.constant([[1., 2., 3.],[4., 5., 6.]])
x = tf.reshape(x, [1, 2, 3, 1])  # give a shape accepted by tf.nn.max_pool
valid_pad = tf.nn.max_pool(x, [1, 2, 2, 1], [1, 2, 2, 1], padding='VALID')
same_pad = tf.nn.max_pool(x, [1, 2, 2, 1], [1, 2, 2, 1], padding='SAME')print(valid_pad.get_shape())
print(same_pad.get_shape())
# 最后输出的结果为:
(1, 1, 1, 1) 
(1, 1, 2, 1)

可以看出“SAME”的填充方式是比“VALID”的填充方式多了一列。 让我们来看看变量x是一个2x3的矩阵,max pooling窗口为2x2,两个维度的strides=2。 第一次由于窗口可以覆盖(橙色区域做max pool操作),没什么问题,如下:

149af4b7ff7e6cd14ce1ae6cfff48a9e.png

接下来就是“SAME”和“VALID”的区别所在,由于步长为2,当向右滑动两步之后“VALID”发现余下的窗口不到2x2所以就把第三列直接去了,而“SAME”并不会把多出的一列丢弃,但是只有一列了不够2x2怎么办?填充!

2c425d3b660785dfb1de0d5e9dc813d7.png

如上图所示,“SAME”会增加第四列以保证可以达到2x2,但为了不影响原来的图像像素信息,一般以0来填充。(这里使用表格的形式展示,markdown不太好控制格式,明白意思就行),这就不难理解不同的padding方式输出的形状会有所不同了。

在CNN用在文本中时,一般卷积层设置卷积核的大小为n×k,其中k为输入向量的维度(即[n,k,input_channel_num,output_channel_num]),这时候我们就需要选择“VALID”填充方式,这时候窗口仅仅是沿着一个维度扫描而不是两个维度。可以理解为统计语言模型当中的N-gram。

我们设计网络结构时需要设置输入输出的shape,源码nn_ops.py中的convolution函数和pool函数给出的计算公式如下:

If padding == "SAME":output_spatial_shape[i] = ceil(input_spatial_shape[i] / strides[i])If padding == "VALID":output_spatial_shape[i] =ceil((input_spatial_shape[i] -(spatial_filter_shape[i]-1) * dilation_rate[i])/ strides[i]).

dilation_rate为一个可选的参数,默认为1,这里我们可以先不管它。 整理一下,对于“VALID”,输出的形状计算如下:

e55375d35b236b2f200820b21ca09e87.png

参考<https://cloud.tencent.com/developer/article/1012365>

第2种解说:利用tf.nn.conv2d示例来理解 strides, padding效果

这里先再简单重复一下tf.nn.conv2d使用,其基本参数的使用规范同样也适用于其他CNN语句

tf.nn.conv2d (input, filter, strides, padding, use_cudnn_on_gpu=None, data_format=None, name=None)
  • input : 输入的要做卷积的图片,要求为一个张量,shape为 [ batch, in_height, in_weight, in_channel ],其中batch为图片的数量,in_height 为图片高度,in_weight 为图片宽度,in_channel 为图片的通道数,灰度图该值为1,彩色图为3。(也可以用其它值,但是具体含义不是很理解)
  • filter: 卷积核,要求也是一个张量,shape为 [ filter_height, filter_weight, in_channel, out_channels ],其中 filter_height 为卷积核高度,filter_weight 为卷积核宽度,in_channel 是图像通道数 ,和 input 的 in_channel 要保持一致,out_channel 是卷积核数量。
  • strides: 卷积时在图像每一维的步长,这是一个一维的向量,[ 1, strides, strides, 1],第一位和最后一位固定必须是1
  • padding: string类型,值为“SAME” 和 “VALID”,表示的是卷积的形式,是否考虑边界。"SAME"是考虑边界,不足的时候用0去填充周围,"VALID"则不考虑
  • use_cudnn_on_gpu: bool类型,是否使用cudnn加速,默认为true
import tensorflow as tf
# case 1
# 输入是1张 3*3 大小的图片,图像通道数是5,卷积核是 1*1 大小,数量是1
# 步长是[1,1,1,1]最后得到一个 3*3 的feature map
# 1张图最后输出就是一个 shape为[1,3,3,1] 的张量
input = tf.Variable(tf.random_normal([1,3,3,5]))
filter = tf.Variable(tf.random_normal([1,1,5,1]))
op1 = tf.nn.conv2d(input, filter, strides=[1,1,1,1], padding='SAME')# case 2
# 输入是1张 3*3 大小的图片,图像通道数是5,卷积核是 2*2 大小,数量是1
# 步长是[1,1,1,1]最后得到一个 3*3 的feature map
# 1张图最后输出就是一个 shape为[1,3,3,1] 的张量 
input = tf.Variable(tf.random_normal([1,3,3,5]))
filter = tf.Variable(tf.random_normal([2,2,5,1]))
op2 = tf.nn.conv2d(input, filter, strides=[1,1,1,1], padding='SAME')# case 3  
# 输入是1张 3*3 大小的图片,图像通道数是5,卷积核是 3*3 大小,数量是1
# 步长是[1,1,1,1]最后得到一个 1*1 的feature map (不考虑边界)
# 1张图最后输出就是一个 shape为[1,1,1,1] 的张量
input = tf.Variable(tf.random_normal([1,3,3,5]))  
filter = tf.Variable(tf.random_normal([3,3,5,1]))  
op3 = tf.nn.conv2d(input, filter, strides=[1, 1, 1, 1], padding='VALID') # case 4
# 输入是1张 5*5 大小的图片,图像通道数是5,卷积核是 3*3 大小,数量是1
# 步长是[1,1,1,1]最后得到一个 3*3 的feature map (不考虑边界)
# 1张图最后输出就是一个 shape为[1,3,3,1] 的张量
input = tf.Variable(tf.random_normal([1,5,5,5]))  
filter = tf.Variable(tf.random_normal([3,3,5,1]))  
op4 = tf.nn.conv2d(input, filter, strides=[1, 1, 1, 1], padding='VALID')  # case 5  
# 输入是1张 5*5 大小的图片,图像通道数是5,卷积核是 3*3 大小,数量是1
# 步长是[1,1,1,1]最后得到一个 5*5 的feature map (考虑边界)
# 1张图最后输出就是一个 shape为[1,5,5,1] 的张量
input = tf.Variable(tf.random_normal([1,5,5,5]))  
filter = tf.Variable(tf.random_normal([3,3,5,1]))  
op5 = tf.nn.conv2d(input, filter, strides=[1, 1, 1, 1], padding='SAME')  # case 6 
# 输入是1张 5*5 大小的图片,图像通道数是5,卷积核是 3*3 大小,数量是7
# 步长是[1,1,1,1]最后得到一个 5*5 的feature map (考虑边界)
# 1张图最后输出就是一个 shape为[1,5,5,7] 的张量
input = tf.Variable(tf.random_normal([1,5,5,5]))  
filter = tf.Variable(tf.random_normal([3,3,5,7]))  
op6 = tf.nn.conv2d(input, filter, strides=[1, 1, 1, 1], padding='SAME')  # case 7  
# 输入是1张 5*5 大小的图片,图像通道数是5,卷积核是 3*3 大小,数量是7
# 步长是[1,2,2,1]最后得到7个 3*3 的feature map (考虑边界)
# 1张图最后输出就是一个 shape为[1,3,3,7] 的张量
input = tf.Variable(tf.random_normal([1,5,5,5]))  
filter = tf.Variable(tf.random_normal([3,3,5,7]))  
op7 = tf.nn.conv2d(input, filter, strides=[1, 2, 2, 1], padding='SAME')  # case 8  
# 输入是10 张 5*5 大小的图片,图像通道数是5,卷积核是 3*3 大小,数量是7
# 步长是[1,2,2,1]最后每张图得到7个 3*3 的feature map (考虑边界)
# 10张图最后输出就是一个 shape为[10,3,3,7] 的张量
input = tf.Variable(tf.random_normal([10,5,5,5]))  
filter = tf.Variable(tf.random_normal([3,3,5,7]))  
op8 = tf.nn.conv2d(input, filter, strides=[1, 2, 2, 1], padding='SAME')  init = tf.initialize_all_variables() 
with tf.Session() as sess:sess.run(init)print('*' * 20 + ' op1 ' + '*' * 20)print(sess.run(op1))print('*' * 20 + ' op2 ' + '*' * 20)print(sess.run(op2))print('*' * 20 + ' op3 ' + '*' * 20)print(sess.run(op3))print('*' * 20 + ' op4 ' + '*' * 20)print(sess.run(op4))print('*' * 20 + ' op5 ' + '*' * 20)print(sess.run(op5))print('*' * 20 + ' op6 ' + '*' * 20)print(sess.run(op6))print('*' * 20 + ' op7 ' + '*' * 20)print(sess.run(op7))print('*' * 20 + ' op8 ' + '*' * 20)print(sess.run(op8))

# 运行结果

运行结果这里就省略了,太长了,所以不写这里了。复制语句到Jupyter中运行一下就懂了

参考<理解tf.nn.conv2d方法>

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/453875.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

MORMOT数据库连接池

MORMOT数据库连接池 MORMOT封装了一堆的PROPS控件&#xff0c;用于连接各种数据库。 MORMOT的封装是武装到了牙齿的&#xff0c;这堆PROPS控件居然数据库连接池也封装好了。这就为我们省了不少事&#xff0c;笔者非常喜欢&#xff01; 下面拿TOleDBMSSQLConnectionProperties控…

循环神经网络变形之 (Long Short Term Memory,LSTM)

1、长短期记忆网络LSTM简介 在RNN 计算中&#xff0c;讲到对于传统RNN水平方向进行长时刻序列依赖时可能会出现梯度消失或者梯度爆炸的问题。LSTM 特别适合解决这种需要长时间依赖的问题。 LSTM&#xff08;Long Short Term Memory&#xff0c;长短期记忆网络&#xff09;是R…

Windows 系统下使用 MinGW + MSYS + GCC 编译 FFMPEG

一定要按照顺序操作&#xff0c;否则你很可能持续遇到很多奇怪的问题&#xff08;ffmpeg对编译系统版本要求比较高&#xff09;。 1. www.mingw.org: 下载并安装 MinGW 5.1.4 (http://jaist.dl.sourceforge.net/sourceforge/mingw/MinGW-5.1.4.exe)&#xff0c;安装时选中 g, m…

eclipse怎样改编码格式_Eclipse中各种编码格式及设置

操作系统&#xff1a;Windows 10(家庭中文版)Eclipse版本&#xff1a;Version: Oxygen.1a Release (4.7.1a)刚看到一篇文章&#xff0c;里面介绍说Ascii、Unicode是编码&#xff0c;而GBK、UTD-8等是编码格式。Java中的编码问题(by 迷失之路)&#xff1a;https://www.cnblogs.c…

UE4 ShooterGame Demo的开火的代码

之前一直没搞懂按下鼠标左键开火之后&#xff0c;代码的逻辑是怎么走的&#xff0c;今天看懂了之前没看懂的部分&#xff0c;进了一步 ShooterCharacter.cpp void AShooterCharacter::OnStartFire() {AShooterPlayerController* MyPC Cast<AShooterPlayerController>(Co…

kafka 异常:return ‘<SimpleProducer batch=%s>‘ % self.async ^ SyntaxError: invalid syntax

Python3.X 执行Python编写的生产者和消费者报错&#xff0c;报错信息如下&#xff1a; Traceback (most recent call last): File "mykit_kafka_producer.py", line 9, in <module> from kafka import KafkaProducer File "/usr/local/lib/python3.7/sit…

python 分布式计算框架_漫谈分布式计算框架

如果问 mapreduce 和 spark 什么关系&#xff0c;或者说有什么共同属性&#xff0c;你可能会回答他们都是大数据处理引擎。如果问 spark 与 tensorflow 呢&#xff0c;就可能有点迷糊&#xff0c;这俩关注的领域不太一样啊。但是再问 spark 与 MPI 呢&#xff1f;这个就更远了。…

Codeforces 899D Shovel Sale

题目大意 给定正整数 $n$&#xff08;$2\le n\le 10^9$&#xff09;。 考虑无序整数对 $(x, y)$&#xff08;$1\le x,y\le n, x\ne y$&#xff09;。 求满足 「$xy$ 结尾连续的 9 最多」的数对 $(x,y)$ 的个数。 例子&#xff1a; $n50$&#xff0c;$(49,50)$ 是一个满足条件的…

Windows系统使用minGW+msys 编译ffmpeg 0.5的全过程详述

一.环境配置 1.下载并安装 MinGW-5.1.4.exe (http://jaist.dl.sourceforge.net/sourcef … -5.1.4.exe)&#xff0c;安装时选中 g, mingw make。建议安装到c:/mingw. 2.下载并安装 MSYS-1.0.11-rc-1.exe (http://jaist.dl.sourceforge.net/sourcef … 1-rc-1.exe)&#xff0c;安…

Liunx安装gogs,mysql,jdk,tomcat等常用软件

Liunx CentOS系统采用yum安装Mysql 一.安装mysql客户端 yum -y install mysql 二.安装mysql服务器端 [注意:由于CentOS7下的不自带mysql-server,所以得先安装资源包,步骤: 1.wget http://repo.mysql.com/mysql-community-release-el7-5.noarch.rpm (采用wget获取必须有wge…

stm32单片机端口映射_STM32单片机的重映射与地址映射的使用方法及步骤

重映射STM32中对于一些端口的外设已经被其他引脚所使用&#xff0c;这是就需要用端口重映射来解决了&#xff0c;很方便。以USART1为例重映射的步骤为&#xff1a;打开重映射时钟和USART重映射后的I/O口引脚时钟&#xff0c;RCC_APB2PeriphClockCmd(RCC_APB2Periph_GPIOB|RCC_A…

python 第三方模块 yaml - 处理 YAML (专门用来写配置文件的语言)

markdown 的配置使用 Yaml —— Yet Another Markup Language &#xff1a;另一种标记语言。 简介 YAML 是专门用来写配置文件的语言&#xff0c;非常简洁和强大&#xff0c;远比 JSON 格式方便。 YAML在python语言中有PyYAML安装包。 YAML 语言&#xff08;发音 /ˈjməl/ &…

程序员 赚钱

业余编程赚钱 程序员的好方法 现在的人生活水平高了&#xff0c;开销也大了&#xff0c;同时对于一些技术性人员来说有很多种&#xff0c;有些程序员自己开公司&#xff0c;开发自己的产品&#xff0c;年赚百万&#xff0c;有些程序员还在给别人打工&#xff0c;每天累死累活的…

java合并单元格的快捷键_java poi合并单元格问题

使用poi导出的execl合并单元格&#xff0c;会出现下图问题整个单元格看似合并了&#xff0c;但是文字没有垂直居中&#xff0c;而且execl中所有的合并都会在第三行开始出现灰色分层样式合并单元格伪代码String upCompareField ""; //上一行的对比值for(int i 0; i …

webpack自动化构建脚本指令npm run dev/build

指令 为不同环境配置可执行指令&#xff0c;我们使用npm scripts方式&#xff0c;在package.json文件中配置执行指令&#xff1a; {"scripts": {"start": "cross-env NODE_ENVdev webpack-dev-server","build": "cross-env NODE_…

前端之 form 详解

认识表单 在一个页面上可以有多个form表单&#xff0c;但是向web服务器提交表单的时候&#xff0c;一次只可以提交一个表单。要声明一个表单&#xff0c;只需要使用 form 标记来标明表单的开始和结束&#xff0c;若需要向服务器提交数据&#xff0c;则在form标签中需要设置act…

代码 优化 指南 实践

C代码优化方案 华中科技大学计算机学院 姓名&#xff1a; 王全明 QQ&#xff1a; 375288012 Email&#xff1a; quanming1119163.com 目录 目录 C代码优化方案 1、选择合适的算法和数据结构 2、使用尽量小的数据类型 3、减少运算的强度 &#xff08;1&…

.12-浅析webpack源码之NodeWatchFileSystem模块总览

剩下一个watch模块&#xff0c;这个模块比较深&#xff0c;先大概过一下整体涉及内容再分部讲解。 流程图如下&#xff1a; NodeWatchFileSystem const Watchpack require("watchpack");class NodeWatchFileSystem {constructor(inputFileSystem) {this.inputFileSy…

Python 第三方模块之 beautifulsoup(bs4)- 解析 HTML

简单来说&#xff0c;Beautiful Soup是python的一个库&#xff0c;最主要的功能是从网页抓取数据。官方解释如下&#xff1a;官网文档 Beautiful Soup提供一些简单的、python式的函数用来处理导航、搜索、修改分析树等功能。 它是一个工具箱&#xff0c;通过解析文档为用户提供…

modal vue 关闭_Vue弹出框的优雅实践

引言页面引用弹出框组件是经常碰见的需求,如果强行将弹出框组件放入到页面中,虽然功能上奏效但没有实现组件与页面间的解耦,非常不利于后期的维护和功能的扩展.下面举个例子来说明一下这种做法的弊端.click"openModal()">点击 :is_open"is_open" close…