pytorch小记(一):pytorch矩阵乘法:torch.matmul(x, y)

pytorch小记(一):pytorch矩阵乘法:torch.matmul(x, y)/ x @ y

      • 代码
      • 代码 1:`torch.matmul(x, y)`
        • 输入张量:
        • 计算逻辑:
        • 输出结果:
      • 代码 2:`y = y.view(4,1)` 再 `torch.matmul(x, y)`
        • 输入张量:
        • 计算逻辑:
        • 输出结果:
      • 总结:两种情况的区别


代码

x = torch.tensor([[1,2,3,4], [5,6,7,8]])
y = torch.tensor([2, 3, 1, 0]) # y.shape == (4)
print(torch.matmul(x, y))
print(x @ y)
>>>
tensor([11, 35])
tensor([11, 35])
x = torch.tensor([[1,2,3,4], [5,6,7,8]])
y = torch.tensor([2, 3, 1, 0]) # y.shape == (4)
y = y.view(4,1)                # y.shape == (4, 1)
'''
tensor([[2],[3],[1],[0]])
'''
print(torch.matmul(x, y))
print(x @ y)
>>>
tensor([[11],[35]])
tensor([[11],[35]])

在这段代码中,torch.matmul(x, y) 或者x @ y计算的是矩阵乘法或张量乘法。我们分两种情况详细分析:


代码 1:torch.matmul(x, y)

输入张量:
  • x 是一个 2D 张量,形状为 (2, 4)
    tensor([[1, 2, 3, 4],[5, 6, 7, 8]])
    
  • y 是一个 1D 张量,形状为 (4,)
    tensor([2, 3, 1, 0])
    
计算逻辑:

在 PyTorch 中,如果 matmul 的一个输入是 2D 张量,另一个是 1D 张量,计算规则是:

  • 将 1D 张量 y 当作列向量 (4, 1),与矩阵 x 进行矩阵乘法。
  • 结果是一个 1D 张量,形状为 (2,)

矩阵乘法公式:
result [ i ] = ∑ j x [ i , j ] ⋅ y [ j ] \text{result}[i] = \sum_j x[i, j] \cdot y[j] result[i]=jx[i,j]y[j]

具体计算步骤:

  1. 对第一行 [1, 2, 3, 4]
    ( 1 ⋅ 2 ) + ( 2 ⋅ 3 ) + ( 3 ⋅ 1 ) + ( 4 ⋅ 0 ) = 2 + 6 + 3 + 0 = 11 (1 \cdot 2) + (2 \cdot 3) + (3 \cdot 1) + (4 \cdot 0) = 2 + 6 + 3 + 0 = 11 (12)+(23)+(31)+(40)=2+6+3+0=11
  2. 对第二行 [5, 6, 7, 8]
    ( 5 ⋅ 2 ) + ( 6 ⋅ 3 ) + ( 7 ⋅ 1 ) + ( 8 ⋅ 0 ) = 10 + 18 + 7 + 0 = 35 (5 \cdot 2) + (6 \cdot 3) + (7 \cdot 1) + (8 \cdot 0) = 10 + 18 + 7 + 0 = 35 (52)+(63)+(71)+(80)=10+18+7+0=35
输出结果:
torch.matmul(x, y)
# tensor([11, 35])

代码 2:y = y.view(4,1)torch.matmul(x, y)

输入张量:
  • x 是同一个 2D 张量,形状为 (2, 4)
  • y 被重塑为 2D 张量,形状为 (4, 1)
    tensor([[2],[3],[1],[0]])
    
计算逻辑:

在这种情况下,matmul 执行的是 矩阵乘法,两个输入的形状为 (2, 4)(4, 1)

  • 矩阵乘法的规则是:前一个矩阵的列数必须等于后一个矩阵的行数
  • 结果张量的形状是 (2, 1)

矩阵乘法公式:
result [ i , k ] = ∑ j x [ i , j ] ⋅ y [ j , k ] \text{result}[i, k] = \sum_j x[i, j] \cdot y[j, k] result[i,k]=jx[i,j]y[j,k]

具体计算步骤:

  1. 对第一行 [1, 2, 3, 4] 和列向量 [[2], [3], [1], [0]]
    ( 1 ⋅ 2 ) + ( 2 ⋅ 3 ) + ( 3 ⋅ 1 ) + ( 4 ⋅ 0 ) = 2 + 6 + 3 + 0 = 11 (1 \cdot 2) + (2 \cdot 3) + (3 \cdot 1) + (4 \cdot 0) = 2 + 6 + 3 + 0 = 11 (12)+(23)+(31)+(40)=2+6+3+0=11
  2. 对第二行 [5, 6, 7, 8] 和列向量 [[2], [3], [1], [0]]
    ( 5 ⋅ 2 ) + ( 6 ⋅ 3 ) + ( 7 ⋅ 1 ) + ( 8 ⋅ 0 ) = 10 + 18 + 7 + 0 = 35 (5 \cdot 2) + (6 \cdot 3) + (7 \cdot 1) + (8 \cdot 0) = 10 + 18 + 7 + 0 = 35 (52)+(63)+(71)+(80)=10+18+7+0=35
输出结果:
torch.matmul(x, y)
# tensor([[11],
#         [35]])

总结:两种情况的区别

  1. y 是 1D 张量

    • torch.matmul(x, y) 返回一个 1D 张量,形状为 (2,)
    • 相当于将 y 当作列向量,与矩阵 x 做矩阵乘法。
  2. y 是 2D 张量

    • torch.matmul(x, y) 返回一个 2D 张量,形状为 (2, 1)
    • 矩阵乘法严格遵守二维矩阵的维度规则。

两者的结果数值相同,但形状不同,主要是因为输入张量的维度不同,导致输出的维度也发生了变化。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/67209.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Sentaurus TCAD学习笔记:transform指令

目录 一、transform指令简介二、transform指令的实现1.cut指令2.flip指令3.rotate指令4.stretch指令5.translate指令6.reflect指令 三、transform指令示例 一、transform指令简介 在Sentaurus中,如果需要对器件进行翻转、平移等操作,可以通过transform指…

kafka消费堆积问题探索

背景 我们的商城项目用PHP写的,原本写日志方案用的是PHP的方案,但是,这个方案导致资源消耗一直降不下来,使用了20个CPU。后面考虑使用通过kafka的方案写日志,商城中把产生的日志丢到kafka中,在以go写的项目…

【opencv】第7章 图像变换

7.1 基 于OpenCV 的 边 缘 检 测 本节中,我们将一起学习OpenCV 中边缘检测的各种算子和滤波器——Canny 算子、Sobel 算 子 、Laplacian 算子以及Scharr 滤波器。 7.1.1 边缘检测的一般步骤 在具体介绍之前,先来一起看看边缘检测的一般步骤。 1.【第…

[Qt]常用控件介绍-多元素控件-QListWidget、QTableWidget、QQTreeWidget

目录 1.多元素控件介绍 2.ListWidget控件 属性 核心方法 核心信号 细节 Demo:编辑日程 3.TableWidget控件 核心方法 QTableWidgetItem核心信号 QTableWidgetItem核心方法 细节 Demo:编辑学生信息 4.TreeWidget控件 核心方法 核心信号…

[Linux]从零开始的STM32MP157交叉编译环境配置

一、前言 最近该忙的事情也是都忙完了,也是可以开始好好的学习一下Linux了。之前九月份的时候就想入手一块Linux的开发板用来学习Linux底层开发。之前在NXP和STM32MP系列之间犹豫,思来想去还是入手了一块STM32MP157。当然不是单纯因为MP157的性能在NXP之…

小程序如何引入腾讯位置服务

小程序如何引入腾讯位置服务 1.添加服务 登录 微信公众平台 注意:小程序要企业版的 第三方服务 -> 服务 -> 开发者资源 -> 开通腾讯位置服务 在设置 -> 第三方设置 中可以看到开通的服务,如果没有就在插件管理中添加插件 2.腾讯位置服务…

添加计算机到AD域中

添加计算机到AD域中 一、确定计算机的DNS指向域中的DNS二、打开系统设置三、加域成功后 一、确定计算机的DNS指向域中的DNS 二、打开系统设置 输入域管理员的账密 三、加域成功后 这里有显示,就成功了。

从epoll事件的视角探讨TCP:三次握手、四次挥手、应用层与传输层之间的联系

目录 一、应用层与TCP之间的联系 二、 当通信双方中的一方如客户端主动断开连接时,仅是在客户端的视角下连接已经断开,在服务端的眼中,连接依然存在,为什么?——触发EPOLLRDHUP事件:对端关闭连接或停止写…

使用RSyslog将Nginx Access Log写入Kafka

个人博客地址:使用RSyslog将Nginx Access Log写入Kafka | 一张假钞的真实世界 环境说明 CentOS Linux release 7.3.1611kafka_2.12-0.10.2.2nginx/1.12.2rsyslog-8.24.0-34.el7.x86_64.rpm 创建测试Topic $ ./kafka-topics.sh --zookeeper 192.168.72.25:2181/k…

使用 Docker 部署 Java 项目(通俗易懂)

目录 1、下载与配置 Docker 1.1 docker下载(这里使用的是Ubuntu,Centos命令可能有不同) 1.2 配置 Docker 代理对象 2、打包当前 Java 项目 3、进行编写 DockerFile,并将对应文件传输到 Linux 中 3.1 编写 dockerfile 文件 …

《研发管理 APQP 软件系统》——汽车电子行业的应用收益分析

全星研发管理 APQP 软件系统在汽车电子行业的应用收益分析 在汽车电子行业,技术革新迅猛,市场竞争激烈。《全星研发管理 APQP 软件系统》的应用,为企业带来了革命性的变化,诸多收益使其成为行业发展的关键驱动力。 《全星研发管理…

22、PyTorch nn.Conv2d卷积网络使用教程

文章目录 1. 卷积2. python 代码3. notes 1. 卷积 输入A张量为: A [ 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 ] \begin{equation} A\begin{bmatrix} 0&1&2&3\\\\ 4&5&6&7\\\\ 8&9&10&11\\\\ 12&13&14&15 \end{b…

ASP.NET Core - 依赖注入(四)

ASP.NET Core - 依赖注入(四) 4. ASP.NET Core默认服务5. 依赖注入配置变形 4. ASP.NET Core默认服务 之前讲了中间件,实际上一个中间件要正常进行工作,通常需要许多的服务配合进行,而中间件中的服务自然也是通过 Ioc…

UE5游戏性能优化指南

解除帧率限制 启动游戏 按 “~” 键 输入 t.MaxFPS 200 可以解除默认帧率限制达到更高的帧率 UE游戏性能和场景优化思路: 1. 可以把可延展性调低,帧率会大幅提高,但画质会大幅降低 2.调整固定灯光,静态光源&#xff…

深度学习中的卷积和反卷积(四)——卷积和反卷积的梯度

本系列已完结,全部文章地址为: 深度学习中的卷积和反卷积(一)——卷积的介绍 深度学习中的卷积和反卷积(二)——反卷积的介绍 深度学习中的卷积和反卷积(三)——卷积和反卷积的计算 …

【C语言】线程

目录 1. 什么是线程 1.1概念 1.2 进程和线程的区别 1.3 线程资源 2. 函数接口 2.1创建线程: pthread_create 2.2 退出线程: pthread_exit 2.3 回收线程资源 练习 1. 什么是线程 1.1概念 线程是一个轻量级的进程,为了提高系统的性能引入线程。 在同一个进…

【C语言】字符串函数详解

文章目录 Ⅰ. strcpy -- 字符串拷贝1、函数介绍2、模拟实现 Ⅱ. strcat -- 字符串追加1、函数介绍2、模拟实现 Ⅲ. strcmp -- 字符串比较1、函数介绍2、模拟实现 Ⅳ. strncpy、strncat、strncmp -- 可限制操作长度Ⅴ. strlen -- 求字符串长度1、函数介绍2、模拟实现&#xff08…

Windows部署NVM并下载多版本Node.js的方法(含删除原有Node的方法)

本文介绍在Windows电脑中,下载、部署NVM(node.js version management)环境,并基于其安装不同版本的Node.js的方法。 在之前的文章Windows系统下载、部署Node.js与npm环境的方法(https://blog.csdn.net/zhebushibiaoshi…

centos 8 中安装Docker

注:本次样式安装使用的是centos8 操作系统。 1、镜像下载 具体的镜像下载地址各位可以去官网下载,选择适合你们的下载即可! 1、CentOS官方下载地址:https://vault.centos.org/ 2、阿里云开源镜像站下载:centos安装包…

STM32-笔记40-BKP(备份寄存器)

一、什么是BKP(备份寄存器)? 备份寄存器是42个16位的寄存器,可用来存储84个字节的用户应用程序数据。他们处在备份域里,当VDD电源被切断,他们仍然由VBAT维持供电。当系统在待机模式下被唤醒,或…