2023.6.25 gather関数【torch】
dimの向き(axisによる指定と同じ)に、indexで指定した要素から構成されたテンソルを生成する。
結果のテンソルはindexで与えたテンソルと同じサイズとなる。
2次元行列に対して2次元のindexを適用
インスタンスメソッド(result1)と、クラスメソッド(result2)が用意されている。xが処理の対象となるテンソル、aがインデックスを意味するテンソルである。
code:python
import torch
x = torch.tensor(1, 2], [3, 4)
a = torch.tensor(0, 1], [1, 0)
result1 = torch.gather(input=x, dim=0, index=a)
result2 = x.gather(dim=1, index=a)
print('# dim=0\n', result1)
print('# dim=1\n', result2)
出力はは2x2となる。
code:text
# dim=0
tensor([1, 4,
3, 2])
# dim=1
tensor([1, 4,
3, 2])
2x2テンソルに対して2x1テンソルのindexを適用
code:python
import torch
x = torch.tensor(1, 2], [3, 4)
a = torch.tensor(0, 1)
result1 = torch.gather(input=x, dim=0, index=a)
print('# dim=0\n', result1)
result2 = x.gather(dim=1, index=a)
print('# dim=1\n', result2)
出力は2x1となる。
code:text
# dim=0
tensor(1, 4)
# dim=1
tensor(1, 2)
注:対象とインデックスのテンソルの次元(形状ではない)は揃っている必要がある。上の例(2次元)に下のインデックス(1次元)を適用できない。
code:python
a = torch.tensor(0, 1)
2x2テンソルに1x2テンソルのインデックスを適用した例
code:python
import torch
x = torch.tensor(1, 2], [3, 4)
a = torch.tensor(0], [1)
result1 = torch.gather(input=x, dim=0, index=a)
print('# dim=0\n', result1)
result2 = x.gather(dim=1, index=a)
print('# dim=1\n', result2)
出力は1x2となる。
code:text
# dim=0
tensor([1,
3])
# dim=1
tensor([1,
4])