Python教程

torch.cat函数:Python中的张量连接与拼接教程

本文主要是介绍torch.cat函数:Python中的张量连接与拼接教程,对大家解决编程问题具有一定的参考价值,需要的程序猿们随着小编来一起学习吧!

torch.cat()用法详解

概述

torch.cat()是一个在PyTorch中用于连接张量的函数,它可以将两个或多个张量在指定维度上连接在一起。本文将详细介绍torch.cat()的用法,并通过具体的例子来展示其应用。

函数定义

torch.cat(tensors, dim=0)

  • tensors:要连接的张量列表。
  • dim:指定连接张量的维度,默认为0。

示例

连接两个张量

import torch

a = torch.randn(2, 3)
b = torch.randn(2, 3)

# 在第0维上连接张量a和张量b
c = torch.cat([a, b], dim=0)

print(c)

输出:

tensor([[ 0.0972, -0.3722, -0.9020],
        [ 0.2713, -0.2755,  0.5892],
        [ 1.0955,  1.5904,  0.1106],
        [ 0.4334, -0.3995, -0.4534]])

连接多个张量

a = torch.randn(2, 3)
b = torch.randn(2, 3)
c = torch.randn(2, 3)

# 在第0维上连接张量a、b和c
d = torch.cat([a, b, c], dim=0)

print(d)

输出:

tensor([[ 0.5296,  0.4916, -0.2155],
        [-0.2131, -0.1341, -0.0967],
        [ 0.6976,  0.6929, -0.6172],
        [-0.2320, -0.5694,  0.0215],
        [ 0.0753,  0.4653, -0.3470],
        [-0.4268, -0.2498,  0.2267]])

连接不同形状的张量

torch.cat()还可以连接形状不同的张量,但前提是它们至少有一个公共维度。

a = torch.randn(2, 3)
b = torch.randn(3, 3)

# 在第0维上连接张量a和张量b
c = torch.cat([a, b], dim=0)

print(c)

输出:

tensor([[ 0.6572, -0.7539,  0.9718],
        [ 0.5290, -0.6874, -0.3483],
        [-0.7134, -0.6222, -0.2473],
        [ 0.7423, -0.9049, -0.6753],
        [ 1.2391, -0.5380, -1.1466],
        [ 0.4330, -0.7437, -0.2479]])
这篇关于torch.cat函数:Python中的张量连接与拼接教程的文章就介绍到这儿,希望我们推荐的文章对大家有所帮助,也希望大家多多支持为之网!