Scatter函数
scatter_(dim, index, src, reduce=None) → Tensor
Writes all values from the tensor
src
intoself
at the indices specified in theindex
tensor. For each value insrc
, its output index is specified by its index insrc
fordimension != dim
and by the corresponding value inindex
fordimension = dim
.
简单理解就是将 src 张量中的元素散落到 self 张量中,具体选择哪个元素,选择的元素散落到哪个位置由index张量决定,具体的映射规则为:
1 | # 其中 i,j,k 为index张量中元素坐标。 |
参数
- dim(int) 指index数组元素替代的坐标(dim = 0 替代src中的横坐标)
- index (LongTensor) 可以为空,最大与src张量形状相同
- src(Tensor or float) 源张量
- reduce 聚集函数(src替换元素与self中被替换元素执行的操作,默认是替代,可以进行add,multiply等操作)
具体例子:
1 | 1, 11).reshape((2, 5)) src = torch.arange( |
以上例子中scatter函数执行的操作等价于:
1 | # 遍历index,在dim = 0 时,替换i |
gather() 函数
Gathers values along an axis specified by dim.
沿着某条轴对元素进行聚集
scatter的逆操作,挑选源张量的某些元素,放置到新张量中,具体映射规则与scatter类似:
1 | out[i][j][k] = input[index[i][j][k]][j][k] # if dim == 0 |
参数
- input 输入张量
- dim 聚集的轴
- index 聚集的index张量
具体例子:
1 | 1, 2], [3, 4]]) t = torch.tensor([[ |
以上例子中gather()函数等价于scatter()函数的操作:
1 | # 遍历index,在dim = 1 时,替换j |
scatter vs gather
相同点:
- scatter 和 gather 函数都是根据index数组从 src/input 源张量中选取指定元素
不同点:
- scatter选取元素放置到目标张量中,放置位置由
index[i][j]
决定 - gather选取元素放置组成新的张量,选取位置由
index[i][j]
决定
应用
1. 使用scatter函数将向量转化为one-hot形式
以转化为n x 10
的n个10维行one-hot向量为例子:
1 | 9, 8, 9, 5, 6, 7]) t = torch.tensor([ |
使用argmax转化回去:
1 | >>> one_hot.argmax(dim = 1,keepdim = True) |