之前很长的一段时间内,我都不太清楚如何感性地理解 PyTorch 中的 dim 参数。最近琢磨到了一个还算比较好理解的方法,故简单记录在这里。

dim 在 PyTorch 的很多函数中都可以指定,例如 sum / mode / unsqueeze / topk 等等,主要是告诉函数应该针对张量的哪个特定维度操作。

这在输入张量维度很高的时候就不那么直观了。虽说不理解问题不大,最多手写循环就能达到目的。但如果我们想尽量避免使用 python 的显式循环,或者还想要利用广播机制来更快的完成计算任务,就不得不总结一下了。

  • 聚合类函数减小维度数的运算,reduction operations),例如 sum / mean / max / min / mode / topk 等等;

    • dim 通常的语义是 “沿这个维度进行消除”,如果有指定 keepdim=True,则这个维度 size 压缩为 1;

    • dim 的值就对应张量 shape 的索引;

    • 被操作的每个元素的 shape 就是 原张量的 shape 在 dim 索引之后组成的新的 shape,即 shape[dim+1:]

    例如对于 a = torch.tensor([ [ [[1],[2],[3]], [[2],[3],[4]] ], [ [[3],[4],[5]], [[4],[5],[6]] ] ]),它的 shape 是 (2, 2, 3, 1)

    tips. 对于高维张量的形状,请从外向里读,这样更清楚一点

    那么 sum(a, dim=2) 的含义就是沿着 size 为 3 的维度(shape 索引是 2)相加,被求和的元素的 shape 就是原 shape 中索引为 2 向后的组成的。不难发现 size 为 3 的维度的元素是 shape 是 (1,) 的子元素,把子元素加起来就行,答案应该是 [ [ [[6]], [[9]] ], [ [[12]], [[15]] ] ]

    再来个简单的:b = torch.tensor([[1,2,3], [4,5,6]]),这个 (2,3) 的张量是不是就非常清楚了?我们计算 sum(b, dim=0) 就是对 size 为 2 的维度求和,也就是元素是 shape (3,) 的张量求和,答案显然是 [[5,7,9]]

  • 拼接类函数不改变维度数的运算),例如 cat 等,dim 通常的语义是 “拼接的方向”;

    • “拼接的方向” 是指,拼接后 size 有变化的维度;例如 a = torch.zeros(2, 3); b = torch.zeros(2, 4)torch.cat([a,b], dim=1) 就是让行对齐、列增大的拼接方式;
    • 拼接得到的张量的 shape 和原张量的 shape 的维度数相同,只是某个维度上的 size 有所不同;
  • 扩展类函数增加维度数的运算,expansion operations),例如 unsqueeze / stack 等,dim 通常的语义是在原张量的指定维度下添加 size 为 N ($N\ge1$) 的维度;

    • 对于 unsqueeze,没有向张量引入其他信息,只是在原张量 shape 索引为 dim 的位置插入 1 来扩展;

    • 比较难以理解的是 stack,很多人会把 stackcat 的作用搞混。但我们只需要搞清楚本质上 stack 是维度扩展类函数,而 cat 则是拼接类函数,就行了!cat 不改变张量的维度,只是将两个或以上张量在已有维度上拼接;而 stack 则是通过新增一个维度来连接两个张量。

      a = torch.zeros(2,3); b = torch.ones(2,3),如果执行 torch.stack([a,b], dim=2),则是在原张量 shape 为 (2,3) 的情况下构造一个 shape 为 (2,3,2) 的新张量(在 shape 索引为 dim=2 的位置插入一个 size=2 的维度,分别装 ab)。至于最终如何表示输出的张量,很简单,就是在第三个维度把两个张量排在一起,这么表示:[[[0, 1], [0, 1], [0, 1]], [[0, 1], [0, 1], [0, 1]]],还好只有 3 维,我们可以感性地画出来:

综上,在维度很高、没法感性理解的时候,可以尝试列出输入的 shape(从外向内读),然后在你要执行的函数中指定 dim,按规则写出输出的 shape,你就能清楚这个操作究竟在做什么了。