如何理解 PyTorch 函数的 dim 参数
之前很长的一段时间内,我都不太清楚如何感性地理解 PyTorch 中的 dim
参数。最近琢磨到了一个还算比较好理解的方法,故简单记录在这里。
dim
在 PyTorch 的很多函数中都可以指定,例如 sum / mode / unsqueeze / topk
等等,主要是告诉函数应该针对张量的哪个特定维度操作。
这在输入张量维度很高的时候就不那么直观了。虽说不理解问题不大,最多手写循环就能达到目的。但如果我们想尽量避免使用 python 的显式循环,或者还想要利用广播机制来更快的完成计算任务,就不得不总结一下了。
聚合类函数(减小维度数的运算,reduction operations),例如
sum / mean / max / min / mode / topk
等等;dim
通常的语义是 “沿这个维度进行消除”,如果有指定keepdim=True
,则这个维度 size 压缩为 1;dim
的值就对应张量 shape 的索引;被操作的每个元素的 shape 就是 原张量的 shape 在
dim
索引之后组成的新的 shape,即shape[dim+1:]
;
例如对于
a = torch.tensor([ [ [[1],[2],[3]], [[2],[3],[4]] ], [ [[3],[4],[5]], [[4],[5],[6]] ] ])
,它的 shape 是(2, 2, 3, 1)
;tips. 对于高维张量的形状,请从外向里读,这样更清楚一点;
那么
sum(a, dim=2)
的含义就是沿着 size 为 3 的维度(shape 索引是 2)相加,被求和的元素的 shape 就是原 shape 中索引为 2 向后的组成的。不难发现 size 为 3 的维度的元素是 shape 是(1,)
的子元素,把子元素加起来就行,答案应该是[ [ [[6]], [[9]] ], [ [[12]], [[15]] ] ]
;再来个简单的:
b = torch.tensor([[1,2,3], [4,5,6]])
,这个(2,3)
的张量是不是就非常清楚了?我们计算sum(b, dim=0)
就是对 size 为 2 的维度求和,也就是元素是 shape(3,)
的张量求和,答案显然是[[5,7,9]]
;拼接类函数(不改变维度数的运算),例如
cat
等,dim
通常的语义是 “拼接的方向”;- “拼接的方向” 是指,拼接后 size 有变化的维度;例如
a = torch.zeros(2, 3); b = torch.zeros(2, 4)
,torch.cat([a,b], dim=1)
就是让行对齐、列增大的拼接方式; - 拼接得到的张量的 shape 和原张量的 shape 的维度数相同,只是某个维度上的 size 有所不同;
- “拼接的方向” 是指,拼接后 size 有变化的维度;例如
扩展类函数(增加维度数的运算,expansion operations),例如
unsqueeze / stack
等,dim
通常的语义是在原张量的指定维度下添加 size 为 N ($N\ge1$) 的维度;对于
unsqueeze
,没有向张量引入其他信息,只是在原张量 shape 索引为dim
的位置插入 1 来扩展;比较难以理解的是
stack
,很多人会把stack
和cat
的作用搞混。但我们只需要搞清楚本质上stack
是维度扩展类函数,而cat
则是拼接类函数,就行了!cat
不改变张量的维度,只是将两个或以上张量在已有维度上拼接;而stack
则是通过新增一个维度来连接两个张量。像
a = torch.zeros(2,3); b = torch.ones(2,3)
,如果执行torch.stack([a,b], dim=2)
,则是在原张量 shape 为(2,3)
的情况下构造一个 shape 为(2,3,2)
的新张量(在 shape 索引为dim=2
的位置插入一个 size=2 的维度,分别装a
和b
)。至于最终如何表示输出的张量,很简单,就是在第三个维度把两个张量排在一起,这么表示:[[[0, 1], [0, 1], [0, 1]], [[0, 1], [0, 1], [0, 1]]]
,还好只有 3 维,我们可以感性地画出来:
综上,在维度很高、没法感性理解的时候,可以尝试列出输入的 shape(从外向内读),然后在你要执行的函数中指定 dim
,按规则写出输出的 shape,你就能清楚这个操作究竟在做什么了。