torch.cat()
是一个在PyTorch中用于连接张量的函数,它可以将两个或多个张量在指定维度上连接在一起。本文将详细介绍torch.cat()
的用法,并通过具体的例子来展示其应用。
torch.cat(tensors, dim=0)
tensors
:要连接的张量列表。dim
:指定连接张量的维度,默认为0。import torch a = torch.randn(2, 3) b = torch.randn(2, 3) # 在第0维上连接张量a和张量b c = torch.cat([a, b], dim=0) print(c)
输出:
tensor([[ 0.0972, -0.3722, -0.9020], [ 0.2713, -0.2755, 0.5892], [ 1.0955, 1.5904, 0.1106], [ 0.4334, -0.3995, -0.4534]])
a = torch.randn(2, 3) b = torch.randn(2, 3) c = torch.randn(2, 3) # 在第0维上连接张量a、b和c d = torch.cat([a, b, c], dim=0) print(d)
输出:
tensor([[ 0.5296, 0.4916, -0.2155], [-0.2131, -0.1341, -0.0967], [ 0.6976, 0.6929, -0.6172], [-0.2320, -0.5694, 0.0215], [ 0.0753, 0.4653, -0.3470], [-0.4268, -0.2498, 0.2267]])
torch.cat()
还可以连接形状不同的张量,但前提是它们至少有一个公共维度。
a = torch.randn(2, 3) b = torch.randn(3, 3) # 在第0维上连接张量a和张量b c = torch.cat([a, b], dim=0) print(c)
输出:
tensor([[ 0.6572, -0.7539, 0.9718], [ 0.5290, -0.6874, -0.3483], [-0.7134, -0.6222, -0.2473], [ 0.7423, -0.9049, -0.6753], [ 1.2391, -0.5380, -1.1466], [ 0.4330, -0.7437, -0.2479]])