PyTorch学习笔记之基础函数篇(十五)

文章目录

  • 数值比较运算
    • 8.1 torch.equal()函数
    • 8.2 torch.ge()函数
    • 8.3 torch.gt()函数
    • 8.4 torch.le()函数
    • 8.5 torch.lt()函数
    • 8.6 torch.ne()函数
    • 8.7 torch.sort()函数
    • 8.8 torch.topk()函数

数值比较运算

8.1 torch.equal()函数

torch.equal(tensor1, tensor2) -> bool

这个函数接受两个张量(tensor1 和 tensor2)作为参数,并返回一个布尔值(bool),表示这两个张量是否相等。

  • tensor1 (Tensor): 第一个要比较的张量。
  • tensor2 (Tensor): 第二个要比较的张量。

torch.equal 是 PyTorch 中的一个函数,用于比较两个张量(tensors)是否具有相同的形状和元素值。如果两个张量满足以下条件,则 torch.equal 返回 True:

  • 两个张量具有相同的形状(shape)。
  • 两个张量中的每个元素都相等。

这个函数在比较两个张量是否完全相同时非常有用。

示例:

import torch# 创建两个形状和值都相同的张量
tensor1 = torch.tensor([1.0, 2.0, 3.0])
tensor2 = torch.tensor([1.0, 2.0, 3.0])# 使用 torch.equal 检查它们是否相等
result = torch.equal(tensor1, tensor2)
print(result)  # 输出: True# 创建两个形状相同但值不同的张量
tensor3 = torch.tensor([1.0, 2.0, 4.0])# 再次使用 torch.equal 检查
result = torch.equal(tensor1, tensor3)
print(result)  # 输出: False

请注意,torch.equal 不仅比较张量的形状,还比较它们的元素值。这与 torch.eq 不同,后者仅比较两个张量中对应位置的元素值是否相等,并返回一个布尔值张量。

另外,如果两个张量都是稀疏的(即它们使用稀疏格式存储),则 torch.equal 还会比较它们的稀疏索引和值是否相同。

使用 torch.equal 时,请确保比较的张量在设备(CPU 或 GPU)和数据类型(如 float32、float64 等)上都是一致的,否则比较可能会失败或产生意外的结果。

8.2 torch.ge()函数

在PyTorch中,torch.ge() 函数用于逐元素地比较两个张量(tensors),并返回一个布尔值张量(boolean tensor),其中每个元素表示对应位置上的元素是否满足第一个张量中的元素大于或等于第二个张量中的元素。

函数签名如下:

torch.ge(input, other, *, out=None) -> Tensor
  • input (Tensor): 要比较的第一个张量。
  • other (Tensor or scalar): 要比较的第二个张量或标量。
  • out (Tensor, optional): 输出张量。

如果 other 是一个标量,那么它会与 input 张量中的每个元素进行比较。如果 other 是一个张量,那么它必须具有与 input 相同的形状,以便进行逐元素的比较。

示例:

import torch# 创建两个张量
tensor1 = torch.tensor([1, 2, 3])
tensor2 = torch.tensor([1, 2, 4])# 使用 torch.ge 进行逐元素比较
result = torch.ge(tensor1, tensor2)print(result)
# 输出: tensor([ True,  True, False])
# 因为 1 >= 1, 2 >= 2, 3 < 4# 使用标量进行比较
scalar = 2
result_scalar = torch.ge(tensor1, scalar)print(result_scalar)
# 输出: tensor([False,  True,  True])
# 因为 1 < 2, 2 >= 2, 3 >= 2

请注意,torch.ge() 返回的是一个布尔值张量,其形状与输入张量相同。如果比较操作涉及到不同类型的张量(例如,一个张量是浮点数类型,另一个是整数类型),则可能需要先将它们转换为相同的类型,以避免可能的类型不匹配错误。

此外,还可以使用张量的比较运算符 >= 来进行相同的操作,这在语法上可能更简洁:

result = tensor1 >= tensor2

这行代码与 torch.ge(tensor1, tensor2) 等效。

8.3 torch.gt()函数

torch.gt() 是 PyTorch 中的一个函数,用于逐元素地比较两个张量(tensors),并返回一个布尔值张量(boolean tensor),其中每个元素表示对应位置上的第一个张量中的元素是否大于第二个张量中的元素。

函数签名如下:

torch.gt(input, other, *, out=None) -> Tensor
  • input (Tensor): 要比较的第一个张量。
  • other (Tensor or scalar): 要比较的第二个张量或标量。
  • out (Tensor, optional): 输出张量。

如果 other 是一个标量,那么它会与 input 张量中的每个元素进行比较。如果 other 是一个张量,那么它必须具有与 input 相同的形状,以便进行逐元素的比较。

示例:

import torch# 创建两个张量
tensor1 = torch.tensor([1, 2, 3])
tensor2 = torch.tensor([1, 2, 4])# 使用 torch.gt 进行逐元素比较
result = torch.gt(tensor1, tensor2)print(result)
# 输出: tensor([False, False,  True])
# 因为 1 不大于 1, 2 不大于 2, 3 大于 4(这里的比较是错误的,应该是 3 < 4)# 修正比较操作
result_corrected = torch.gt(tensor1, 2)print(result_corrected)
# 输出: tensor([False, False,  True])
# 因为 1 不大于 2, 2 不大于 2, 3 大于 2# 使用标量进行比较
scalar = 2
result_scalar = torch.gt(tensor1, scalar)print(result_scalar)
# 输出: tensor([False, False,  True])
# 因为 1 不大于 2, 2 不大于 2, 3 大于 2

请注意,torch.gt() 返回的是一个布尔值张量,其形状与输入张量相同。在这个例子中,tensor1 和 tensor2 的比较结果是一个包含三个元素的布尔值张量,表示 tensor1 中每个元素是否大于 tensor2 中对应位置的元素。

另外,与 torch.ge() 类似,你也可以使用张量的比较运算符 > 来进行相同的操作:

result = tensor1 > tensor2

这行代码与 torch.gt(tensor1, tensor2) 等效。

8.4 torch.le()函数

在PyTorch中,torch.le() 函数用于逐元素地比较两个张量(tensors),并返回一个布尔值张量(boolean tensor),其中每个元素表示对应位置上的第一个张量中的元素是否小于或等于第二个张量中的元素。

函数签名如下:

torch.le(input, other, *, out=None) -> Tensor
  • input (Tensor): 要比较的第一个张量。
  • other (Tensor or scalar): 要比较的第二个张量或标量。
  • out (Tensor, optional): 输出张量。

如果 other 是一个标量,那么它会与 input 张量中的每个元素进行比较。如果 other 是一个张量,那么它必须具有与 input 相同的形状,以便进行逐元素的比较。

示例:

import torch# 创建两个张量
tensor1 = torch.tensor([1, 2, 3])
tensor2 = torch.tensor([1, 2, 4])# 使用 torch.le 进行逐元素比较
result = torch.le(tensor1, tensor2)print(result)
# 输出: tensor([ True,  True,  True])
# 因为 1 <= 1, 2 <= 2, 3 <= 4# 使用标量进行比较
scalar = 2
result_scalar = torch.le(tensor1, scalar)print(result_scalar)
# 输出: tensor([ True,  True, False])
# 因为 1 <= 2, 2 <= 2, 3 > 2

在这个例子中,torch.le(tensor1, tensor2) 返回一个布尔值张量,其中每个元素对应于 tensor1 和 tensor2 中相应位置的元素比较结果。同样,torch.le(tensor1, scalar) 将 tensor1 中的每个元素与标量 scalar 进行比较。

与 torch.ge() 类似,你也可以使用张量的比较运算符 <= 来进行相同的操作:

result = tensor1 <= tensor2

这行代码与 torch.le(tensor1, tensor2) 等效。

8.5 torch.lt()函数

在PyTorch中,torch.lt() 函数用于逐元素地比较两个张量(tensors),并返回一个布尔值张量(boolean tensor),其中每个元素表示对应位置上的第一个张量中的元素是否小于第二个张量中的元素。

函数签名如下:

torch.lt(input, other, *, out=None) -> Tensor
  • input (Tensor): 要比较的第一个张量。
  • other (Tensor or scalar): 要比较的第二个张量或标量。
  • out (Tensor, optional): 输出张量。

如果 other 是一个标量,那么它会与 input 张量中的每个元素进行比较。如果 other 是一个张量,那么它必须具有与 input 相同的形状,以便进行逐元素的比较。

示例:

import torch# 创建两个张量
tensor1 = torch.tensor([1, 2, 3])
tensor2 = torch.tensor([1, 2, 4])# 使用 torch.lt 进行逐元素比较
result = torch.lt(tensor1, tensor2)print(result)
# 输出: tensor([False, False,  True])
# 因为 1 不小于 1, 2 不小于 2, 3 小于 4# 使用标量进行比较
scalar = 2
result_scalar = torch.lt(tensor1, scalar)print(result_scalar)
# 输出: tensor([ True,  False,  False])
# 因为 1 小于 2, 2 不小于 2, 3 不小于 2

在这个例子中,torch.lt(tensor1, tensor2) 返回一个布尔值张量,表示 tensor1 中每个元素是否小于 tensor2 中对应位置的元素。同样地,torch.lt(tensor1, scalar) 将 tensor1 中的每个元素与标量 scalar 进行比较。

你也可以使用张量的比较运算符 < 来执行相同的操作:

result = tensor1 < tensor2

这行代码与 torch.lt(tensor1, tensor2) 等效。

8.6 torch.ne()函数

torch.ne() 是 PyTorch 中的一个函数,用于逐元素地比较两个张量(tensors),并返回一个布尔值张量(boolean tensor),其中每个元素表示对应位置上的第一个张量中的元素是否不等于第二个张量中的元素。

函数签名如下:

torch.ne(input, other, *, out=None) -> Tensor
  • input (Tensor): 要比较的第一个张量。
  • other (Tensor or scalar): 要比较的第二个张量或标量。
  • out (Tensor, optional): 输出张量。

如果 other 是一个标量,那么它会与 input 张量中的每个元素进行比较。如果 other 是一个张量,那么它必须具有与 input 相同的形状,以便进行逐元素的比较。

示例:

import torch# 创建两个张量
tensor1 = torch.tensor([1, 2, 3])
tensor2 = torch.tensor([1, 2, 4])# 使用 torch.ne 进行逐元素比较
result = torch.ne(tensor1, tensor2)print(result)
# 输出: tensor([False, False,  True])
# 因为 1 等于 1, 2 等于 2, 3 不等于 4# 使用标量进行比较
scalar = 2
result_scalar = torch.ne(tensor1, scalar)print(result_scalar)
# 输出: tensor([ True,  True, False])
# 因为 1 不等于 2, 2 不等于 2, 3 不等于 2

在这个例子中,torch.ne(tensor1, tensor2) 返回一个布尔值张量,其中每个元素对应于 tensor1 和 tensor2 中相应位置的元素比较结果。同样,torch.ne(tensor1, scalar) 将 tensor1 中的每个元素与标量 scalar 进行比较。

与 torch.ne() 类似,你也可以使用张量的比较运算符 != 来进行相同的操作:

result = tensor1 != tensor2

这行代码与 torch.ne(tensor1, tensor2) 等效。

8.7 torch.sort()函数

torch.sort() 是 PyTorch 中的一个函数,用于对张量(tensor)进行排序。它返回排序后的张量以及原始张量中元素的索引。

函数签名如下:

torch.sort(input, dim=None, descending=False, out=None) -> (Tensor, Tensor)
  • input (Tensor): 要排序的张量。
  • dim (int, optional): 沿着哪个维度进行排序。默认为 None,表示对整个张量进行排序。
  • descending (bool, optional): 是否按降序排序。默认为 False,即按升序排序。
  • out (tuple[Tensor, Tensor], optional): 可选的输出张量元组,用于存放排序结果和索引。

torch.sort() 返回一个元组,其中包含两个张量:

  • 排序后的张量(Tensor):沿着指定维度对输入张量进行排序后的结果。
  • 索引张量(Tensor):原始张量中元素的索引,按照排序后的顺序排列。

示例:

import torch# 创建一个张量
x = torch.tensor([3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 5])# 对张量进行排序
sorted_tensor, sorted_indices = torch.sort(x)print("Sorted tensor:", sorted_tensor)
# 输出: Sorted tensor: tensor([1, 1, 2, 3, 3, 4, 5, 5, 5, 6, 9])print("Sorted indices:", sorted_indices)
# 输出: Sorted indices: tensor([1, 3, 7, 0, 10, 2, 4, 8, 9, 5, 6])# 如果想按降序排序
sorted_tensor_desc, sorted_indices_desc = torch.sort(x, descending=True)print("Sorted tensor (descending):", sorted_tensor_desc)
# 输出: Sorted tensor (descending): tensor([9, 6, 5, 5, 5, 4, 3, 3, 2, 1, 1])print("Sorted indices (descending):", sorted_indices_desc)
# 输出: Sorted indices (descending): tensor([6, 5, 8, 9, 4, 2, 0, 10, 7, 3, 1])

在这个例子中,torch.sort(x) 返回了排序后的张量 sorted_tensor 和原始张量中元素的索引 sorted_indices。如果指定 descending=True,则会按降序排序。

8.8 torch.topk()函数

torch.topk() 是 PyTorch 中的一个函数,用于返回张量中每个指定维度上的前 k 个最大值(或最小值)及其索引。这个函数非常有用,特别是在你需要获取张量中的顶部元素或进行排序相关的操作时。

函数签名如下:

torch.topk(input, k, dim=None, largest=True, sorted=True, *, out=None) -> (Tensor, Tensor)
  • input (Tensor): 输入张量。
  • k (int): 要返回的最大(或最小)值的数量。
  • dim (int, optional): 在哪个维度上进行操作。默认是第一个维度(dim=0)。
  • largest (bool, optional): 是否返回最大的 k 个值。如果是 True,则返回最大的 k 个值;如果是 False,则返回最小的 k 个值。默认为 True。
  • sorted (bool, optional): 是否对结果进行排序。如果为 True,则返回的 k 个值会按照降序(对于 largest=True)或升序(对于 largest=False)排列。默认为 True。
  • out (tuple[Tensor, Tensor], optional): 可选的输出张量元组,用于存放结果和索引。

torch.topk() 返回一个元组,其中包含两个张量:

第一个张量包含每个指定维度上的前 k 个最大值(或最小值)。
第二个张量包含这些最大值(或最小值)在原始张量中的索引。

示例:

import torch# 创建一个张量
x = torch.tensor([[ 3, 2, 1],[ 2, 3, 1],[ 1, 2, 3]])# 获取每个行的前 2 个最大值及其索引
values, indices = torch.topk(x, k=2, dim=1, largest=True, sorted=True)print("Top 2 values:", values)
# 输出: Top 2 values: tensor([[3, 2],
#                             [3, 2],
#                             [3, 2]])print("Indices of top 2 values:", indices)
# 输出: Indices of top 2 values: tensor([[0, 1],
#                                        [1, 0],
#                                        [2, 1]])# 获取每个列的最小值及其索引
values, indices = torch.topk(x, k=1, dim=0, largest=False, sorted=True)print("Smallest value in each column:", values)
# 输出: Smallest value in each column: tensor([[1, 1, 1]])print("Indices of smallest values in each column:", indices)
# 输出: Indices of smallest values in each column: tensor([[2, 2, 2]])

在这个例子中,我们首先使用 torch.topk() 获取了每个行(dim=1)的前两个最大值(largest=True)及其索引。然后,我们改变了维度和条件,获取了每个列(dim=0)的最小值(largest=False)及其索引。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/755665.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

JavaScript函数声明调用

普通函数 function f1(a, b) {return "普通函数f1&#xff1a;" a * b }console.log(f1(3, 7));var $ function f2() {return 普通函数f2 } console.log($())箭头函数 // 多个参数的箭头函数 const f3 (param1, param2) > {return "箭头函数f3&#xff…

web部署 四 限制

案例: 1/设置其下载速度限制:1000000010m&#xff0c;10000001m&#xff0c;2分别查看下载速率是否发生变化。 2/限制连接数&#xff0c;同时下载5个文件。查看第6个是否能正常访问。使用命令符:netstat-n&#xff0c;查看活动链接&#xff0c; 正常情况下我们的下载速度 我们…

遥感深度学习:CNN-LSTM模型用于NDVI的预测(Pytorch代码深度剖析)

代码上传至Github库&#xff1a;https://github.com/ChaoQiezi/CNN-LSTM-model-is-used-to-predict-NDVI 01 前言 这是一次完整的关于时空遥感影像预测相关的深度学习项目&#xff0c;后续有时间更新后续部分。 通过这次项目&#xff0c;你可以了解&#xff1a; pytroch的模…

石油炼化5G智能制造工厂数字孪生可视化平台,推进行业数字化转型

石油炼化5G智能制造工厂数字孪生可视化平台&#xff0c;推进行业数字化转型。在石油炼化行业&#xff0c;5G智能制造工厂数字孪生可视化平台的出现&#xff0c;为行业的数字化转型注入了新的活力。石油炼化行业作为传统工业的重要领域&#xff0c;面临着资源紧张、环境压力、安…

数组排列组合---M中取出N个元素

一、数组M中取出N个元素的所有组合 const getCombinationsWithRepetition (array: Array<string>, n: number) > {const result [] as anyconst currentCombination [] as anyfunction backtrack(start: number, count: number) {if (count 0) {result.push(curr…

蓝桥杯练习题——健身大调查

在浏览器中预览 index.html 页面效果如下&#xff1a; 目标 完成 js/index.js 中的 formSubmit 函数&#xff0c;用户填写表单信息后&#xff0c;点击蓝色提交按钮&#xff0c;表单项隐藏&#xff0c;页面显示用户提交的表单信息&#xff08;在 id 为 result 的元素显示&#…

Redis高性能IO模型剖析

Redis的高性能IO模型主要归功于其内部精心设计的机制&#xff0c;包括单线程模型、IO多路复用技术、高效的数据结构以及内存操作等。下面我们将逐一剖析这些关键因素。 首先&#xff0c;Redis采用单线程模型来处理网络IO和键值对读写操作。这种设计避免了多线程间的竞争和同步…

OxyPlot 导出图片

在 OxyPlot 官方文档 https://oxyplot.readthedocs.io/en/latest/export/index.html 中查看 这里用到的是导出到 PNG 文件的方法&#xff0c;不过用的 NuGet 包最新版&#xff08;2.1.0&#xff09;中&#xff0c;PngExporter 中并没有 Background 属性&#xff1a; 所以如果图…

【C语言】C语言内存函数

&#x1f451;个人主页&#xff1a;啊Q闻 &#x1f387;收录专栏&#xff1a;《C语言》 &#x1f389;道阻且长&#xff0c;行则将至 前言 这篇博客是关于C语言内存函数(memcpy,memmove,memset,memcmp)的使用以及部分的模拟实现 memcpy,memmove,memset,memc…

一文搞懂“ReentrantReadWriteLock——读写锁”

文章目录 初识读写锁ReentrantReadWriteLock类结构注意事项 ReentrantReadWriteLock源码分析读写状态的设计HoldCounter 计数器读锁的获取读锁的释放写锁的获取写锁的释放 锁降级 初识读写锁 Java中的锁——ReentrantLock和synchronized都是排它锁&#xff0c;意味着这些锁在同…

Python高级语法

Python高级语 1 列表推导式1.1 什么是列表推导式1.2 列表推导式的使用 2 字典推导式2.1 什么是字典推导式2.2 字典推导式的使用 3 元组推导式4 集合推导式5 三元表达式5.1 什么是三元表达式5.2 三元表达式的使用 1 列表推导式 1.1 什么是列表推导式 列表推导式的英文&#xf…

docker安装配置dnsmasq

docker下载安装 参考&#xff1a;docker安装、卸载、配置、镜像 如果是低版本的额ubuntu&#xff0c;比如ubuntu16.04.7 LTS&#xff0c;为了加快下载速度&#xff0c;参考&#xff1a;Ubuntu16.04LTS安装Docker。 docker安装dnsmasq 下载dnsmasq镜像 首先镜像我们可以选择…

代码随想录 动态规划-完全背包问题

52. 携带研究材料 时间限制&#xff1a;1.000S 空间限制&#xff1a;128MB 题目描述 小明是一位科学家&#xff0c;他需要参加一场重要的国际科学大会&#xff0c;以展示自己的最新研究成果。他需要带一些研究材料&#xff0c;但是他的行李箱空间有限。这些研究材料包括实验…

Could not locate zlibwapi.dll. Please make sure it is in your library path!

背景 运行PaddleOCR时&#xff0c;用的CUDA11.6配的是cuDNN8.4。但是运行后却报错如下。 解决手段 去网上找到这两个文件&#xff0c;现在英伟达好像不能下载了&#xff0c;但是可以去网盘下载。然后把dll文件放入CUDA11.6文件下的bin目录&#xff0c;而lib文件放入CUDA11.6文…

5.1.4.2、【AI技术新纪元:Spring AI解码】Llama2 Chat

Llama2 Chat Meta 的 Llama 2 Chat 是 Llama 2 系列大型语言模型的一部分。它在基于对话的应用程序中表现出色,参数规模范围从 70 亿到 700 亿不等。利用公共数据集和超过 100 万次人类注释,Llama Chat 提供了上下文感知的对话。 通过从公共数据源获取的 2 万亿标记进行训练…

基于 RisingWave 和 Kafka 构建实时网络安全解决方案

实时威胁检测可实时监控和分析数据&#xff0c;并及时对潜在的安全威胁作出识别和响应。与依赖定期扫描或回顾性分析的安全措施不同&#xff0c;实时威胁检测系统可提供即时警报&#xff0c;并启动自动响应来降低风险&#xff0c;而不会出现高延迟。 实时威胁检测有许多不同的…

英特尔生态的深度学习科研环境配置-A770为例

之前发过在Intel A770 GPU安装oneAPI的教程&#xff0c;但那个方法是用于WSL上。总所周知&#xff0c;在WSL使用显卡会有性能损失的。而当初买这台机器的时候我不在场&#xff0c;所以我这几天刚好有空把机器给重装成Ubuntu了。本篇不限于安装oneAPI&#xff0c;因为在英特尔的…

【01】htmlcssgit网络基础知识

一、html&css 防脱发神器 一图胜千言 使用border-box控制尺寸更加直观,因此,很多网站都会加入下面的代码 * {margin: 0;padding: 0;box-sizing: border-box; }颜色的 alpha 通道 颜色的 alpha 通道标识了色彩的透明度,它是一个 0~1 之间的取值,0 标识完全透明,1…

探索什么便签软件好用,可以和手机同步的便签软件

在信息技术日新月异的今天&#xff0c;各类数字工具已经成为我们生活与工作的重要助手。便签软件作为一种简单却高效的辅助工具&#xff0c;悄然改变着人们的记录习惯与时间管理方式。而在诸多便签软件中&#xff0c;能够实现手机与电脑同步功能的产品尤显其独特的价值。那么&a…

数据结构 之 哈希表习题 力扣oj(附加思路版)

哈希表用法 哈希表&#xff1a;键 值对 键&#xff1a;可以看成数组下标&#xff0c;但是哈希表中的建可以是任意类型的&#xff0c;建不能重复,可以不是连续的 值&#xff1a;可以看成数组中的元素&#xff0c;值可以重复&#xff0c;也可以是任意类型的数据 #include<iost…