如何优雅使用torch的view? 如何理解RowMajor/ColumnMajor?

作者: rainlin 分类: 大模型与GPU编程 发布时间: 2024-07-28 13:02

背景

最近看一些torch代码,有很多view操作,之前没深入了解,存在一些问题:

  1. view的参数该如何填写呢?
  2. view后得到的tensor结果是怎样的呢?
  3. view是如何将tensor进行映射的呢?

本质上涉及到tensor元素在内存中的布局问题,所以就查了些资料,写下自己的理解。

什么是RowMajor/ColumnMajor?

首先需要明确,无论tensor是多少维度,在内存中都只有一维,所以需要根据一定的布局方式,将高维tensor排布到内存中

RowMajor和ColumnMajor是指tensor元素在内存中的布局方式。Row是按行排放,Col则是按列存放,torch默认是是RawMajor。

看维基百科的一张图:

图中展示了矩阵(二维tensor)在内存中布局的方式,简单来说就是,

按照RowMajor方式,以上矩阵在内存中布局为:a_11,a_12,a_13,a_21,a_22,a_23,a_31,a_32,a_33
按照ColumnMajor方式,以上矩阵内存中的布局为:a_11,a_21,a_31,a_12,a_22,a_32,a_13,a_23,a_33

如果是多维tensor呢?

对于多维tensor而言,我们把以上规则进行泛化,本质上RowMajor布局是最右侧的维度变化最快,ColumnMajor反之。

以[2,3,4]大小的tensor为例,其不同major方式的存放顺序如下:

RowMajor: [0][0][0],[0][0][1],[0][0][2],[0][0][3],[0][1][0],[0][1][1].....

ColumnMajor: [0][0][0],[1][0][0],[0][1][0],[1][1][0],[0][2][0],[1][2][0].....

了解布局后,如何理解view?

view不会改变元素的内存布局,只是换一个维度去看待tensor。所以我们先根据原始tensor,明确元素在内存中的布局后,再按照同样思路去反向映射到设置的view维度,便可以获得view后的结果。
以下为一些示例,可以帮助理解view的过程:

import torch

# a:[2,3,4]
a = torch.tensor(
    [
        [
            [1, 2, 3, 4], 
            [5, 6, 7, 8], 
            [9, 10, 11, 12]
         ],
        [
            [13, 14, 15, 16], 
            [17, 18, 19, 20], 
            [21, 22, 23, 24]
         ],
    ]
)
print(a.size())
'''
torch.Size([2, 3, 4])
'''


print(a.view(-1))
'''
[ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24]
'''


print(a.view(2,-1))
'''
tensor([[ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12],
        [13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24]])
'''


print(a.view(6,-1))
'''
tensor([[ 1,  2,  3,  4],
        [ 5,  6,  7,  8],
        [ 9, 10, 11, 12],
        [13, 14, 15, 16],
        [17, 18, 19, 20],
        [21, 22, 23, 24]])
'''


print(a.view(24,-1))
'''
tensor([[ 1],
        [ 2],
        [ 3],
        [ 4],
        [ 5],
        [ 6],
        [ 7],
        [ 8],
        [ 9],
        [10],
        [11],
        [12],
        [13],
        [14],
        [15],
        [16],
        [17],
        [18],
        [19],
        [20],
        [21],
        [22],
        [23],
        [24]])
'''


print(a.view(4,3,2))
'''
tensor([[[ 1,  2],
         [ 3,  4],
         [ 5,  6]],

        [[ 7,  8],
         [ 9, 10],
         [11, 12]],

        [[13, 14],
         [15, 16],
         [17, 18]],

        [[19, 20],
         [21, 22],
         [23, 24]]])
'''
  1. view的tensor元素总数应该与之前一致,即所有维度的乘积相等。
  2. view中的-1表示由torch自动计算,其实就是根据1中的规则,进行除法得到。

后记

  1. 对于RowMajor/ColumnMajor的理解不止在view中有用,在cuTLASS等库中也是常用概念。

 

 

本文链接: http://rainlin.top/archives/230
转载请注明转载自: Rainlin Home

发表回复

您的电子邮箱地址不会被公开。 必填项已用 * 标注