API和用法:
torch.where(condition, x, y) -> Tensor
condition:判断条件,一个布尔类型的张量,表示条件。
若当前index满足条件,则取x中index对应的元素
若当前index不满足条件,则取y中index对应的元素
形状:
d 是一个shape和 b,c 相同的tensor,也就是 b,c 的shape 也必须相同。
注意:a不需要和b,c一样的shape
示例:
import torch# 创建一个布尔类型的张量,表示条件
condition = torch.tensor([True, False, True, False])# 创建两个与 condition 形状相同的张量
x = torch.tensor([1, 2, 3, 4])
y = torch.tensor([5, 6, 7, 8])# 使用 torch.where() 函数获取满足条件的元素索引
result = torch.where(condition, x, y)print(result)
输出:
tensor([1, 6, 3, 8])
torch.where()函数-CSDN博客
torch.where()详解-CSDN博客
torch.where()函数解读-CSDN博客