tf.reduce_sum()方法深度解析

首先看一下reduce_sum及其参数的注释 :

def tf.reduce_sum(input_tensor, axis=None, keepdims=False, name=None)

Computes the sum of elements across dimensions of a tensor.

Reduces input_tensor along the dimensions given in axis. Unless keepdims is true, the rank of the tensor is reduced by 1 for each entry in axis. If keepdims is true, the reduced dimensions are retained with length 1.

If axis is None, all dimensions are reduced, and a tensor with a
single element is returned.

1 参数只有一个矩阵 input_tensor,其余参数为默认

下面我们用三个张量为例(下统一称为矩阵),对一维向量,二维矩阵,三维矩阵分别执行其余参数为默认的reduce_sum()运算。

x_dim_1 = tf.constant([1,2,3])
z = tf.reduce_sum(x_dim_1)
print("1D sum :")
print(z)x_dim_2 = tf.constant([[1,2,3],[4,5,6]])
z = tf.reduce_sum(x_dim_2)
print("2D sum :")
print(z)x_dim_3 = tf.constant([[[1,2],[2,3],[3,4]],[[4,5],[5,6],[6,7]]])z = tf.reduce_sum(x_dim_3)
print("3D sum :")
print(z)

输出结果如下:
在这里插入图片描述
不难看出,无论是对几维矩阵而言,tf.reduce_sum()方法均对矩阵的所有元素进行求和,并将结果降维至一维,返回一个数值(0维矩阵)。 在这种情况下,tf.reduce_sum()方法与我们熟悉的numpy.sum()方法可以说在矩阵运算方面别无二致,同时也符合注释中所说 :

If axis is None, all dimensions are reduced, and a tensor with a
single element is returned.

2 保持keepdims参数为false,改变axis参数的值进行运算

这个部分就是reduce_sum的灵魂所在,其关键是要理解所谓axis参数对于高维矩阵的意义。
不妨先看看,如果我们打印出上述代码中的三个矩阵常量的shape,会得到什么:

在这里插入图片描述
需要理解的是,对于高维矩阵来说,其每一个元素都是一个低维矩阵。 例如,对于一维矩阵(数组) x_dim_1 而言,每一个元素都是一个0维的数组(数值) : 1,2,3。 对于高维矩阵,以三维矩阵x_dim_3 为例,其每一个元素都是一个二维的矩阵,这些二维矩阵中的每一个元素都是一个一维数组,这些一维数组中的每个元素都是一个0维矩阵(数值)。

听起来好像像绕口令,那么我们就以x_dim_3为例,具体解释axis与矩阵的shape之间的关系。
x_dim_3 = [ [ [1,2],[2,3],[3,4] ],
[ [4,5],[5,6],[6,7] ] ]

首先我们不难看出,这是一个三维矩阵。那么该矩阵包含了几个二维矩阵的元素呢,我们不难看出包含了两个,分别是 [ [1,2],[2,3],[3,4] ] 以及 [ [4,5],[5,6],[6,7] ] 两个二维矩阵。那么对于每一个作为元素的二维矩阵,又包含了多少个一维数组呢,显然,每一个二维矩阵都包含了三个一维数组。以第一个二维数组元素为例,包含了[1,2],[2,3],[3,4]三个一维数组元素。而每一个一维数组元素,又分别包含了两个0维数组(数值)。

这就是张量 x_dim_3 的shape的构成(2,3,2),其代表的数学意义为,这个三维矩阵包含了两个二维矩阵元素,每一个二维矩阵又包含了三个一维矩阵元素,每一个一维矩阵又包含了两个数值。同时,三个数值分别对应了axis=0 axis=1 axis=2 三个轴分别包含的元素个数。

问题实际变成了,沿着axis=0 : 对包含的两个二维矩阵元素进行求和, 沿着axis=1 : 对该维度包含的三个一维矩阵进行求和,沿着axis = 2 : 对该维度包含的两个数值进行求和。 下面放出对x_dim_3的测试及输出结果。

# 沿着axis = 0 对三维矩阵进行降维求和 实际就是对两个二维矩阵元素进行加法运算
x_dim_3 = tf.constant([[[1,2],[2,3],[3,4]],[[4,5],[5,6],[6,7]]])
z = tf.reduce_sum(x_dim_3,axis=0)
print(z)

输出结果如下:
在这里插入图片描述
注释 :
在这里插入图片描述

# 沿着axis=1 对三维矩阵进行降维求和 
x_dim_3 = tf.constant([[[1,2],[2,3],[3,4]],[[4,5],[5,6],[6,7]]])
z = tf.reduce_sum(x_dim_3,axis=1)
print(z)

在这里插入图片描述

# 沿着axis=2对三维矩阵进行降维求和 
x_dim_3 = tf.constant([[[1,2],[2,3],[3,4]],[[4,5],[5,6],[6,7]]])
z = tf.reduce_sum(x_dim_3,axis=2)
print(z)

在这里插入图片描述

3 总结

reduce_sum方法人如其名,减少维度的同时对矩阵内部沿着某一轴进行求和。与之类似的方法还有reduce_max, reduce_min, reduce_mean,他们关于axis参数的使用与本篇文章中分析的是别无二致的。当然,axis可以不是一个数值,而是一个矩阵,有兴趣的可以进一步探索与发现。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/386866.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

主成分分析(PCA)原理详解_转载

一、PCA简介 1. 相关背景 在许多领域的研究与应用中,往往需要对反映事物的多个变量进行大量的观测,收集大量数据以便进行分析寻找规律。多变量大样本无疑会为研究和应用提供了丰富的信息,但也在一定程度上增加了数据采集的工作量,…

Mac cnpm装包时提示Error: EACCES: permission denied解决办法

Cnpm装包时提示Error: EACCES: permission denied解决办法 2018年03月04日 09:31:51 miniminixu 阅读数:1598 版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/miniminixu/article/details/79434609 只需在cnpm …

特征点检测 FAST算法及代码详解

本文着重介绍了用于图像特征点检测的算法,FAST算法,以及使用matlab的实现。 FAST算法是一种拐点检测算法,其主要应用于提取图像中的特征点,在动态成像的一系列图像中追踪定位对象。众所周知,我们生活的世界是动态化的…

一文看懂计算机神经网络与梯度下降

1. 计算机神经网络与神经元 要理解神经网络中的梯度下降算法,首先我们必须清楚神经元的定义。如下图所示,每一个神经元可以由关系式yf(∑i1nwixib)y f(\sum_{i1}^nw_ix_i b)yf(∑i1n​wi​xi​b)来描述,其中X[x1,x2,...,xn]X [x_1,x_2,..…

vs2015web项目无法加载64位c++的dll,提示试图加载不正确的格式

vs2015无法加载64位c的dll,提示试图加载不正确的格式! 开始用winform引用64位的c的dll,在项目的属性设置生成里面选择any cpu或者x64都可以成功! 但在web项目和接口里面运行就提示试图加载不正确的格式,想办法找了一天也没处理掉&…

使用Rancher搭建K8S测试环境

环境准备(4台主机,Ubuntu16.04Docker1.12.6 SSH): rancher1 192.168.3.160 只做管理节点 node1 192.168.3.161 K8S的节点1 node2 192.168.3.162 K8S的节点2 node3 192.168.3.163 K8S的节点3 此时如…

Anaconda安装tensorflow报错问题解决方法

最近脱离了googlecolab想使用本地的anaconda进行机器学习课题的演练,在安装tensorflow时报错 : UnsatisfiableError: The following specifications were found。下面给出解决方法。 发现实际原因是由于anaconda的python环境,当前版本的tensorflow只能适…

yml的mybatis的sql查看

yml的mybatis的sql查看 控制台输出结果:

unity如何让canvas总是显示在所有层的最上方?

由于unity中的图层都是从上至下渲染的,那么在渲染的过程中,只需要将canvas所在的UI层的渲染优先级order排在其他层之后,就可以保证UI画面总是最后加载出来的了。 在canvas的inspector中修改order in layer 或者 sorting layer都可以实现这一…

关于同时可用git命令clone和TortoiseGit拉取代码不需要密码

工作需要在windows7下使用git分布式版本控制系统,需要同时可以在git命令行模式或TortoiseGit拉取代码而不需要每次输入密码。 这时候需要同时安装git和TortoiseGit。 git使用命令ssh-keygen -C “邮箱地址” -t rsa产生的密钥在TortoiseGit中不能用。TortoiseGit 使…

交叉验证 cross validation 与 K-fold Cross Validation K折叠验证

交叉验证,cross validation是机器学习中非常常见的验证模型鲁棒性的方法。其最主要原理是将数据集的一部分分离出来作为验证集,剩余的用于模型的训练,称为训练集。模型通过训练集来最优化其内部参数权重,再在验证集上检验其表现。…

第十一周总结

这个作业属于那个课程 C语言程序设计II 这个作业要求在哪里 https://edu.cnblogs.com/campus/zswxy/computer-scienceclass4-2018/homework/3203 我在这个课程的目标是 理解与使用递归函数。 参考文献 基础题 2-1 宏定义“#define DIV(a, b) a/b”,经DIV(x …

softmax函数与交叉熵损失函数

本文主要介绍了当前机器学习模型中广泛应用的交叉熵损失函数与softmax激励函数。 这个损失函数主要应用于多分类问题,用于衡量预测值与实际值之间的相似程度。 交叉熵损失函数定义如下: LCE(y^,y∗)−∑i1Nclassesyi∗log(yi^)L_{CE}(\hat{y}, y^*) - \sum_{i1}^…

unity如何让物体与特定物体之间不发生碰撞

unity中我们普遍使用的是碰撞器来实现各个物体的碰撞体积,例如Box collider, Sphere Collider。 在实现游戏的过程中,如果不想要物体与特定物体产生碰撞,或反之,只想让碰撞发生在特定物体之间时,我们就需要配置layer …

jenkins的JAVA简单顺序配置git仓库

后台Java的发布配置 1、从源码管理下载项目内容 2、构建触发器 3 、构建下环境 4、构建后处理

SQLyog连接数据库报错plugin caching_sha2_password could not be loaded

打开cmd:mysql -uroot -p 进入mysql依次执行下面语句 ALTER USER rootlocalhost IDENTIFIED BY password PASSWORD EXPIRE NEVER; #修改加密规则 ALTER USER rootlocalhost IDENTIFIED WITH mysql_native_password BY password; #更新一下用户的密码 FLUSH PRIVI…

unity导入素材时材质丢失素材变成粉红色的解决方法

有很多时候,当我们通过unity asset store或者blender等等外源导入素材时,会出现材质缺失的bug,如下图所示 : 一个很可能的原因,是由于unity本身管线在每个版本的更新过程中,材质的渲染编码发生了改变。由于这种原因引…

Jenkins 部署vue到服务器

链接github名称 2、从源码管理下载 3、更新最新前端模块 4、进行构建和打包

numpy数组提取一定规律的数据

numpy数组的索引也是符合start stop step规律的,因此可以通过索引提取出一系列索引有规律的元素,如下例子: import numpy as np i np.linspace(1,100,100, dtypeint)-1 print(i) i_train i[0:100:10] print(i_train)输出结果如下 : 可以看到通过索引…

CRM、用户管理权限

CRM目录结构 from django.shortcuts import HttpResponse,render,redirect from django.conf.urls import url from django.utils.safestring import mark_safe from django.urls import reverse from django.forms import ModelForm from stark.utils.my_page import Paginat…