cat拼接
使用条件:合并的dim的size可以不同,但是其它的dim的size必须相同。
语法:cat([tensor1,tensor2],dim = n) # 将tensor1和tensor2的第n个维度合并
代码演示:
# 拼接与拆分
a = torch.rand(4,32,8)
b = torch.rand(5,32,8)
print(torch.cat([a,b],dim=0).shape) # torch.Size([9, 32, 8])
stack拼接
为什么要使用stack?下面会举个例子阐述一下原因:
A [32, 8] # 一个班,一共有32个同学,每个同学有8门成绩
B [32, 8] # 一个班,一共有32个同学,每个同学有8门成绩
cat:[64, 8] # 一个班,一共有64个同学,每个同学有8门成绩,不符合实际
stack: [2, 32, 8] # 2个班,每个班有32个同学,每个同学有8门成绩,符合实际
使用条件:A.shape = B.shape
代码演示:
a = torch.rand(32,8)
b = torch.rand(32,8)
print(torch.cat([a,b],dim=0).shape) # torch.Size([64, 8])
print(torch.stack([a,b],dim=0).shape) # torch.Size([2, 32, 8])
split——根据长度拆分
语法:split(len, dim = n) # 在第n个维度拆分,每个size=len
代码演示:
# c.shape = torch.Size([2, 32, 8])
aa, bb = c.split(1,dim=0)
print(aa.shape,bb.shape) # torch.Size([1, 32, 8]) torch.Size([1, 32, 8])
注意:不要超过第0维的总体长度2,等于也不行,别忘了split进行的是拆分。
chunk——根据数量拆分
语法:chunk(num, dim = n) # 在第n维进行拆分,拆分为num份
代码演示:
# c.shape = torch.Size([2, 32, 8])
aa, bb = c.chunk(2,dim = 0)
print(aa.shape,bb.shape) # torch.Size([1, 32, 8]) torch.Size([1, 32, 8])