Pytorch实现List Tensor转Tensor,reshape拼接等操作,


目录
  • 一、List Tensor转Tensor (torch.cat)
    • 高维tensor
  • 二、List Tensor转Tensor (torch.stack)

    持续更新一些常用的Tensor操作,比如List,Numpy,Tensor之间的转换,Tensor的拼接,维度的变换等操作。

    其它Tensor操作如 einsum等见:待更新。

    用到两个函数:

    • torch.cat
    • torch.stack

    一、List Tensor转Tensor (torch.cat)

    // An highlighted block
    >>> t1 = torch.FloatTensor([[1,2],[5,6]])
    >>> t2 = torch.FloatTensor([[3,4],[7,8]])
    >>> l = []
    >>> l.append(t1)
    >>> l.append(t2)
    >>> ta = torch.cat(l,dim=0)
    >>> ta = torch.cat(l,dim=0).reshape(2,2,2)
    >>> tb = torch.cat(l,dim=1).reshape(2,2,2)
    >>> ta
    tensor([[[1., 2.],
             [5., 6.]],
    
            [[3., 4.],
             [7., 8.]]])
    >>> tb
    tensor([[[1., 2.],
             [3., 4.]],
    
            [[5., 6.],
             [7., 8.]]])

    高维tensor

    ** 如果理解了2D to 3DTensor,以此类推,不难理解3D to 4D,看下面代码即可明白:**

    >>> t1 = torch.range(1,8).reshape(2,2,2)
    >>> t2 = torch.range(11,18).reshape(2,2,2)
    >>> l = []
    >>> l.append(t1)
    >>> l.append(t2)
    >>> torch.cat(l,dim=2).reshape(2,2,2,2)
    tensor([[[[ 1.,  2.],
              [11., 12.]],
    
             [[ 3.,  4.],
              [13., 14.]]],
    
    
            [[[ 5.,  6.],
              [15., 16.]],
    
             [[ 7.,  8.],
              [17., 18.]]]])
    >>> torch.cat(l,dim=1).reshape(2,2,2,2)
    tensor([[[[ 1.,  2.],
              [ 3.,  4.]],
    
             [[11., 12.],
              [13., 14.]]],
    
    
            [[[ 5.,  6.],
              [ 7.,  8.]],
    
             [[15., 16.],
              [17., 18.]]]])
    >>> torch.cat(l,dim=0).reshape(2,2,2,2)
    tensor([[[[ 1.,  2.],
              [ 3.,  4.]],
    
             [[ 5.,  6.],
              [ 7.,  8.]]],
    
    
            [[[11., 12.],
              [13., 14.]],
    
             [[15., 16.],
              [17., 18.]]]])

    二、List Tensor转Tensor (torch.stack)

    代码:

    import torch
    
    t1 = torch.FloatTensor([[1,2],[5,6]])
    t2 = torch.FloatTensor([[3,4],[7,8]])
    l = [t1, t2]
    
    t3 = torch.stack(l, dim=2)
    print(t3.shape)
    print(t3)
    
    ## output:
    ## torch.Size([2, 2, 2])
    ## tensor([[[1., 3.],
    ##          [2., 4.]],
    ##        [[5., 7.],
    ##         [6., 8.]]])

    以上为个人经验,希望能给大家一个参考,也希望大家多多支持3672js教程。

    您可能感兴趣的文章:
    • pytorch Variable与Tensor合并后 requires_grad()默认与修改方式
    • Pytorch之扩充tensor的操作
    • 详解Pytorch中的tensor数据结构

    评论关闭