Pytorch中张量矩阵乘法函数(mm, bmm, matmul)使用说明,含高维张量实例及运行结果

Pytorch中张量矩阵乘法函数使用说明

  • 1 torch.mm() 函数
    • 1.1 torch.mm() 函数定义及参数
    • 1.2 torch.bmm() 官方示例
  • 2 torch.bmm() 函数
    • 2.1 torch.bmm() 函数定义及参数
    • 2.2 torch.bmm() 官方示例
  • 3 torch.matmul() 函数
    • 3.1 torch.matmul() 函数定义及参数
    • 3.2 torch.matmul() 规则约定
    • 3.3 torch.matmul() 官方示例
    • 3.4 高维数据实例解释
  • 参考博文及感谢

1 torch.mm() 函数

全称为matrix-matrix product,对输入的张量做矩阵乘法运算,输入输出维度一定是2维

1.1 torch.mm() 函数定义及参数

torch.bmm(input, mat2, , out=None) → Tensor
input (Tensor) – – 第一个要相乘的矩阵
** mat2
* (Tensor) – – 第二个要相乘的矩阵
不支持广播到通用形状、类型推广以及整数、浮点和复杂输入。

1.2 torch.bmm() 官方示例

mat1 = torch.randn(2, 3)
mat2 = torch.randn(3, 3)
torch.mm(mat1, mat2)tensor([[ 0.4851,  0.5037, -0.3633],[-0.0760, -3.6705,  2.4784]])

2 torch.bmm() 函数

全称为batch matrix-matrix product,对输入的张量做矩阵乘法运算,输入输出维度一定是3维;

2.1 torch.bmm() 函数定义及参数

torch.bmm(input, mat2, , out=None) → Tensor
input (Tensor) – – 第一批要相乘的矩阵
** mat2
* (Tensor) – – 第二批要相乘的矩阵
不支持广播到通用形状、类型推广以及整数、浮点和复杂输入。

2.2 torch.bmm() 官方示例

input = torch.randn(10, 3, 4)
mat2 = torch.randn(10, 4, 5)
res = torch.bmm(input, mat2)
res.size()torch.Size([10, 3, 5])

3 torch.matmul() 函数

可进行多维矩阵运算,根据不同输入维度进行广播机制然后运算,和点积类似,广播机制可参考之前博文torch.mul()函数。

3.1 torch.matmul() 函数定义及参数

torch.matmul(input, mat2, , out=None) → Tensor
input (Tensor) – – 第一个要相乘的张量
** mat2
* (Tensor) – – 第二个要相乘的张量
支持广播到通用形状、类型推广以及整数、浮点和复杂输入。

3.2 torch.matmul() 规则约定

(1)若两个都是1D(向量)的,则返回两个向量的点积;

(2)若两个都是2D(矩阵)的,则按照(矩阵相乘)规则返回2D;

(3)若input维度1D,other维度2D,则先将1D的维度扩充到2D(1D的维数前面+1),然后得到结果后再将此维度去掉,得到的与input的维度相同。即使作扩充(广播)处理,input的维度也要和other维度做对应关系;

(4)若input是2D,other是1D,则返回两者的点积结果;

(5)如果一个维度至少是1D,另外一个大于2D,则返回的是一个批矩阵乘法( a batched matrix multiply)

  • (a)若input是1D,other是大于2D的,则类似于规则(3);
  • (b)若other是1D,input是大于2D的,则类似于规则(4);
  • (c)若input和other都是3D的,则与torch.bmm()函数功能一样;
  • (d)如果input中某一维度满足可以广播(扩充),那么也是可以进行相乘操作的。例如 input(j,1,n,m)* other (k,m,p) = output(j,k,n,p)

matmul() 根据输入矩阵自动决定如何相乘。低维根据高维需求,合理广播。

3.3 torch.matmul() 官方示例

# vector x vector
tensor1 = torch.randn(3)
tensor2 = torch.randn(3)
torch.matmul(tensor1, tensor2).size()torch.Size([])
# matrix x vector
tensor1 = torch.randn(3, 4)
tensor2 = torch.randn(4)
torch.matmul(tensor1, tensor2).size()torch.Size([3])
# batched matrix x broadcasted vector
tensor1 = torch.randn(10, 3, 4)
tensor2 = torch.randn(4)
torch.matmul(tensor1, tensor2).size()torch.Size([10, 3])
# batched matrix x batched matrix
tensor1 = torch.randn(10, 3, 4)
tensor2 = torch.randn(10, 4, 5)
torch.matmul(tensor1, tensor2).size()torch.Size([10, 3, 5])
# batched matrix x broadcasted matrix
tensor1 = torch.randn(10, 3, 4)
tensor2 = torch.randn(4, 5)
torch.matmul(tensor1, tensor2).size()torch.Size([10, 3, 5])

3.4 高维数据实例解释

直接看一个4维的二值例子,先看图(红虚线和实线是为了便于区分维度而添加),不懂再结合代码和结果分析,先做广播,然后对应矩阵进行乘积运算
在这里插入图片描述

代码如下:

import torch
import numpy as npnp.random.seed(2022)
a = np.random.randint(low=0, high=2, size=(2, 2, 3, 4))
a = torch.tensor(a)
b = np.random.randint(low=0, high=2, size=(2, 1, 4, 3))
b = torch.tensor(b)
c = torch.matmul(a, b)
# or
# c = a @ b
print(a)
print("=============================================")
print(b)
print("=============================================")
print(c.size())
print("=============================================")
print(c)

运行结果为:

tensor([[[[1, 0, 1, 0],[1, 1, 0, 1],[0, 0, 0, 0]],[[1, 1, 1, 1],[1, 1, 0, 0],[0, 1, 0, 1]]],[[[0, 0, 0, 1],[0, 0, 0, 1],[0, 1, 0, 0]],[[1, 1, 1, 1],[1, 1, 1, 1],[0, 0, 0, 0]]]], dtype=torch.int32)
=============================================
tensor([[[[0, 1, 0],[1, 1, 0],[0, 0, 0],[1, 1, 0]]],[[[0, 1, 0],[1, 1, 1],[1, 1, 1],[1, 0, 1]]]], dtype=torch.int32)
=============================================
torch.Size([2, 2, 3, 3])
=============================================
tensor([[[[0, 1, 0],[2, 3, 0],[0, 0, 0]],[[2, 3, 0],[1, 2, 0],[2, 2, 0]]],[[[1, 0, 1],[1, 0, 1],[1, 1, 1]],[[3, 3, 3],[3, 3, 3],[0, 0, 0]]]], dtype=torch.int32)

参考博文及感谢

部分内容参考以下链接,这里表示感谢 Thanks♪(・ω・)ノ
参考博文1 官方文档查询地址
https://pytorch.org/docs/stable/index.html
参考博文2 Pytorch矩阵乘法之torch.mul() 、 torch.mm() 及torch.matmul()的区别
https://blog.csdn.net/irober/article/details/113686080

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/81469.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

并查集与LRUCache

一)并查集 在一些应用问题中,需要将N个不同的元素划分成一些互不相交的集合,开始的时候,每一个元素自成一个单元素集合,然后按照一定的规律将归于同一组元素的集合进行合并,并且在此过程中需要反复使用到查询某一个元素…

使用grubby更改RHEL7/8/9的默认内核

使用grubby更改RHEL7/8/9的默认内核 验证默认内核版本获取当前默认内核的索引号检查所有内核的详细信息检查已安装的内核 更改默认内核引导条目使用索引号更改默认内核引导条目 验证默认内核版本 参考:https://linux.cn/article-16147-1.html # 验证默认内核版本 …

炫云云渲染3ds max效果图渲染教程

很多人在第一次使用炫云云渲染渲染效果图的时候不知道怎么使用,其实现在使用炫云渲染效果图真的很简单,我们一起来看看。 一客户端安装 1、打开炫云云渲染官网,点击右上角的客户端下载,选择炫云客户端(NEXT版&#xf…

【JavaEE】多线程(三)

多线程(三) 续上文,多线程(二),我们已经讲了 创建线程Thread的一些重要的属性和方法 那么接下来,我们继续来体会了解多线程吧~ 文章目录 多线程(三)线程启动 startsta…

华为云云耀云服务器L实例评测|cento7.9在线使用cloudShell下载rpm解压包安装mysql并开启远程访问

文章目录 ⭐前言⭐使用华为cloudShell连接远程服务器💖 进入华为云耀服务器控制台💖 选择cloudShell ⭐安装mysql压缩包💖 wget下载💖 tar解压💖 安装步骤💖 初始化数据库💖 修改密码&#x1f4…

实验4 交换机端口隔离(access模式)

交换机端口隔离(access模式) 实验目的实验拓扑实验步骤(1)在未划分vlan前,配置pc1、pc2的地址,如图所示(2)测试两台pc机的连通性(3)创建vlan,并验…

Day66|图part5:130. 被围绕的区域、827.最大人工岛

130. 被围绕的区域 leetcode链接:题目链接 这题看起来很复杂,其实跟之前找飞地和找边缘地区的是差不多的,主要分三步: 使用dfs将边缘的岛都找出来,然后用A代替防止混淆;再用dfs找中间不与任何岛相连的飞地…

【码银送书第七期】七本考研书籍

八九月的朋友圈刮起了一股晒通知书潮,频频有大佬晒出“研究生入学通知书”,看着让人既羡慕又焦虑。果然应了那句老话——比你优秀的人,还比你努力。 心里痒痒,想考研的技术人儿~别再犹豫了。小编咨询了一大波上岸的大佬&#xff…

UDP与TCP报头介绍,三次握手与四次挥手详谈

先介绍我们UDP/TCP协议缓冲区 在UDP和TCP在数据传输和介绍时有有缓冲区概念的。 UDP缓冲区 UDP没有真正意义上的 发送缓冲区. 调用sendto会直接交给内核, 由内核将数据传给网络层协议进行后 续的传输动作; UDP具有接收缓冲区. 但是这个接收缓冲区不能保证收到的UDP报的顺序…

C语言天花板——指针(初阶)

🌠🌠🌠 大家在刚刚接触C语言的时候就肯定听说过,指针的重要性以及难度等级,以至于经常“谈虎色变”,但是今天我来带大家走进指针的奇妙世界。🎇🎇🎇 一、什么是指针&…

旋转角度对迭代次数的影响

( A, B )---3*30*2---( 1, 0 )( 0, 1 ) 让网络的输入只有3个节点,AB训练集各由5张二值化的图片组成,让A中有3个1,B中全是0,统计迭代次数并排序。 在3*5的空间内分布3个点有19种可能,但不同的分布只有6种 差值就诶够 …

七天学会C语言-第二天(数据结构)

1. If 语句&#xff1a; If 语句是一种条件语句&#xff0c;用于根据条件的真假执行不同的代码块。它的基本形式如下&#xff1a; if (条件) {// 条件为真时执行的代码 } else {// 条件为假时执行的代码 }写一个基础的If语句 #include<stdio.h> int main(){int x 10;…

硬件故障诊断:快速定位问题

&#x1f337;&#x1f341; 博主猫头虎&#xff08;&#x1f405;&#x1f43e;&#xff09;带您 Go to New World✨&#x1f341; &#x1f984; 博客首页——&#x1f405;&#x1f43e;猫头虎的博客&#x1f390; &#x1f433; 《面试题大全专栏》 &#x1f995; 文章图文…

Linux基础开发工具使用快速上手

软件包管理器 概念理解 在Linux下安装软件的话&#xff0c;一个比较原始的办法是下载程序的源代码&#xff0c;然后进行编译&#xff0c;进而得到可执行程序&#xff0c;然后就可以运行这个软件了。但是这种做法太麻烦了&#xff0c;于是就有些人把一些常用的软件提前编译好&…

笔记1.5:计算机网络体系结构

从功能上描述计算机网络结构 分层结构 每层遵循某个网络协议完成本层功能 基本概念 实体&#xff1a;表示任何可发送或接收信息的硬件或软件进程。 协议是控制两个对等实体进行通信的规则的集合&#xff0c;协议是水平的。 任一层实体需要使用下层服务&#xff0c;遵循本层…

uniapp 小程序 父组件调用子组件方法

答案&#xff1a;配合小程序API > this.selectComponent("")&#xff0c;来选择组件&#xff0c;再使用$vm选择组件实例&#xff0c;再调用方法&#xff0c;或者data 1 设置组件的id,如果你的多端&#xff0c;请跟据情况设置ref,class,id&#xff0c;以便通过小…

一阶低通滤波器滞后补偿算法

一阶低通滤波器的推导过程和双线性变换算法请查看下面文章链接: PLC算法系列之数字低通滤波器(离散化方法:双线性变换)_双线性离散化_RXXW_Dor的博客-CSDN博客PLC信号处理系列之一阶低通(RC)滤波器算法_RXXW_Dor的博客-CSDN博客_rc滤波电路的优缺点1、先看看RC滤波的优缺点…

Redis 篇

1、为什么要用缓存&#xff1f; 使用缓存的目的就是提升读写性能。而实际业务场景下&#xff0c;更多的是为了提升读性能&#xff0c;带来更好的性能&#xff0c;带来更高的并发量。 Redis 的读写性能比 Mysql 好的多&#xff0c;我们就可以把 Mysql 中的热点数据缓存到 Redis…

Linux学习第14天:Linux设备树(一):枝繁叶茂见晴天

本节笔记主要学习了Linux设备树相关知识点&#xff0c;由于内容较多&#xff0c;打算分两天进行总结。今天着重学习Linux设备树&#xff0c;主要包括前三节内容&#xff0c;分别是概念、格式和语法。 本节思维导图内容如下&#xff1a; 一、什么是设备树 设备树可以用一个图来进…

Vivado XADC IP核 使用详解

本文介绍Vivado中XADC Wizard V3.3的使用方法。 XADC简介 XADC Wizard Basic Interface Options&#xff1a; 一共三种&#xff0c;分别是AXI4Lite、DRP、None。勾选后可在界面左侧看到相应通信接口情况。Startup Channel Selection Simultaneous Selection&#xff1a;同时监…