PyTorch: テンソル演算メモ
ペアワイズ結合
2つのベクトル集合をペアワイズに結合する.
たとえば以下のようなデータで考える.
code: ipython
In 3: x = torch.Tensor(np.arange(2*3*4).reshape(2,3,4)) # Shape: (batch_size, num_x, embedding_size) In 4: y = torch.Tensor(-np.arange(2*2*4).reshape(2,2,4)) # Shape: (batch_size, num_y, embedding_size) ペアワイズに結合して (batch_size, num_x, num_y, 2*embedding_size)を得るとすると.
code: ipython
In 7: shape = (2, 3, 2 4) In 8: expandex_x = x.unsqueeze(2).expand(shape) In 9: expandex_y = y.unsqueeze(1).expand(shape) In 10: torch.cat((expandex_x, expandex_y), -1)