pytorch中gather函數的理解。

函數torch.gather(input, dim, index, out=None) → Tensor
沿給定軸 dim ,將輸入索引張量 index 指定位置的值進行聚合.
對一個 3 維張量,輸出可以定義為:

out[i][j][k] = input[index[i][j][k]][j][k]  # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k]  # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]]  # if dim == 2

Parameters:

  • input (Tensor) – 源張量
  • dim (int) – 索引的軸
  • index (LongTensor) – 聚合元素的下標(index需要是torch.longTensor類型)
  • out (Tensor, optional) – 目標張量

使用說明舉例:

  1. dim = 1
import torch
a = torch.randint(0, 30, (2, 3, 5))
print(a)
'''
tensor([[[ 18.,   5.,   7.,   1.,   1.],
         [  3.,  26.,   9.,   7.,   9.],
         [ 10.,  28.,  22.,  27.,   0.]],

        [[ 26.,  10.,  20.,  29.,  18.],
         [  5.,  24.,  26.,  21.,   3.],
         [ 10.,  29.,  10.,   0.,  22.]]])
'''
index = torch.LongTensor([[[0,1,2,0,2],
                          [0,0,0,0,0],
                          [1,1,1,1,1]],
                        [[1,2,2,2,2],
                         [0,0,0,0,0],
                         [2,2,2,2,2]]])
print(a.size()==index.size())
b = torch.gather(a, 1,index)
print(b)
'''
True
tensor([[[ 18.,  26.,  22.,   1.,   0.],
         [ 18.,   5.,   7.,   1.,   1.],
         [  3.,  26.,   9.,   7.,   9.]],

        [[  5.,  29.,  10.,   0.,  22.],
         [ 26.,  10.,  20.,  29.,  18.],
         [ 10.,  29.,  10.,   0.,  22.]]])
可以看到沿著dim=1,也就是列的時候。輸出tensor第一頁內容,
第一行分別是 按照index指定的,
input tensor的第一頁 
第一列的下標為0的元素 第二列的下標為1元素 第三列的下標為2的元素,第四列下標為0元素,第五列下標為2元素
index-->0,1,2,0,2    output--> 18.,  26.,  22.,   1.,   0.
'''
  1. dim =2
c = torch.gather(a, 2,index)
print(c)
'''
tensor([[[ 18.,   5.,   7.,  18.,   7.],
         [  3.,   3.,   3.,   3.,   3.],
         [ 28.,  28.,  28.,  28.,  28.]],

        [[ 10.,  20.,  20.,  20.,  20.],
         [  5.,   5.,   5.,   5.,   5.],
         [ 10.,  10.,  10.,  10.,  10.]]])
dim = 2的時候就安裝 行 聚合了。參照上面的舉一反三。
'''
  1. dim = 0
index2 = torch.LongTensor([[[0,1,1,0,1],
                          [0,1,1,1,1],
                          [1,1,1,1,1]],
                        [[1,0,0,0,0],
                         [0,0,0,0,0],
                         [1,1,0,0,0]]])
d = torch.gather(a, 0,index2)
print(d)
'''
tensor([[[ 18.,  10.,  20.,   1.,  18.],
         [  3.,  24.,  26.,  21.,   3.],
         [ 10.,  29.,  10.,   0.,  22.]],

        [[ 26.,   5.,   7.,   1.,   1.],
         [  3.,  26.,   9.,   7.,   9.],
         [ 10.,  29.,  22.,  27.,   0.]]])
這個有點特殊,dim = 0的時候(三維情況下),是從不同的頁收集元素的。
這里舉的例子只有兩頁。所有index在0,1兩個之間選擇。
輸出的矩陣元素也是按照index的指定。分別在第一頁和第二頁之間跳著選的。
index [0,1,1,0,1]的意思就是。
在第一頁選這個位置的元素,在第二頁選這個位置的元素,在第二頁選,第一頁選,第二頁選。

'''

最后編輯于
?著作權歸作者所有,轉載或內容合作請聯系作者
平臺聲明:文章內容(如有圖片或視頻亦包括在內)由作者上傳并發布,文章內容僅代表作者本人觀點,簡書系信息發布平臺,僅提供信息存儲服務。

推薦閱讀更多精彩內容