0%

pytorch Scatter和gather函数使用

Scatter函数

scatter_(dim, index, src, reduce=None) → Tensor

Writes all values from the tensor src into self at the indices specified in the index tensor. For each value in src, its output index is specified by its index in src for dimension != dim and by the corresponding value in index for dimension = dim.

简单理解就是将 src 张量中的元素散落到 self 张量中,具体选择哪个元素,选择的元素散落到哪个位置由index张量决定,具体的映射规则为:

三维张量为例
1
2
3
4
# 其中 i,j,k 为index张量中元素坐标。
self[index[i][j][k]][j][k] = src[i][j][k] # if dim == 0
self[i][index[i][j][k]][k] = src[i][j][k] # if dim == 1
self[i][j][index[i][j][k]] = src[i][j][k] # if dim == 2

参数

  • dim(int) 指index数组元素替代的坐标(dim = 0 替代src中的横坐标)
  • index (LongTensor) 可以为空,最大与src张量形状相同
  • src(Tensor or float) 源张量
  • reduce 聚集函数(src替换元素与self中被替换元素执行的操作,默认是替代,可以进行add,multiply等操作)

具体例子:

1
2
3
4
5
6
7
8
9
>>> src = torch.arange(1, 11).reshape((2, 5))
>>> src
tensor([[ 1, 2, 3, 4, 5],
[ 6, 7, 8, 9, 10]])
>>> index = torch.tensor([[0, 1, 2, 0]])
>>> torch.zeros(3, 5, dtype=src.dtype).scatter_(0, index, src)
tensor([[1, 0, 0, 4, 0],
[0, 2, 0, 0, 0],
[0, 0, 3, 0, 0]])

以上例子中scatter函数执行的操作等价于:

1
2
3
4
# 遍历index,在dim = 0 时,替换i
for i in range(index.shape[0]):
for j in range(index.shape[1]):
self[index[i][j]][j] = src[i][j]

gather() 函数

Gathers values along an axis specified by dim.

沿着某条轴对元素进行聚集

scatter的逆操作,挑选源张量的某些元素,放置到新张量中,具体映射规则与scatter类似:

1
2
3
out[i][j][k] = input[index[i][j][k]][j][k]  # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k] # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]] # if dim == 2

参数

  • input 输入张量
  • dim 聚集的轴
  • index 聚集的index张量

具体例子:

1
2
3
>>> t = torch.tensor([[1, 2], [3, 4]])
>>> torch.gather(t, 1, torch.tensor([[0, 1]]))
tensor([[1, 2]])

以上例子中gather()函数等价于scatter()函数的操作:

1
2
3
4
# 遍历index,在dim = 1 时,替换j
for i in range(index.shape[0]):
for j in range(index.shape[1]):
result[i][j] = input[i][index[i][j]]

scatter vs gather

相同点:

  • scatter 和 gather 函数都是根据index数组从 src/input 源张量中选取指定元素

不同点:

  • scatter选取元素放置到目标张量中,放置位置由index[i][j]决定
  • gather选取元素放置组成新的张量,选取位置由index[i][j]决定

应用

1. 使用scatter函数将向量转化为one-hot形式

以转化为n x 10的n个10维行one-hot向量为例子:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
>>> t = torch.tensor([9, 8, 9, 5, 6, 7])
>>> t.view(-1,1)
tensor([[9],
[8],
[9],
[5],
[6],
[7]])
>>> one_hot = torch.zeros(t.shape[0],10).scatter(1,t.view(-1,1),1)
>>> ont_hot
tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
[0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
[0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 1., 0., 0.]])

使用argmax转化回去:

1
2
3
4
5
6
7
>>> one_hot.argmax(dim = 1,keepdim = True)
tensor([[9],
[8],
[9],
[5],
[6],
[7]])

待补充