2023.6.7 squeeze関数【torch】
squeeze関数【numpy】、PyTorchのsqueeze関数が存在するが、ここではPyTorch版について解説する。
サイズが「1」の次元は無くてもテンソルの構造を保持できるので、不要と考えることができる。これはsize()関数で確認すると分かり易い。
引数無しの場合
引数無しで利用すると、不要と判断される次元を全て削除する。
code:python
import torch
x = torch.tensor(1)
print(x, ':', x.size())
x0 = x.squeeze()
print(x0, ':', x0.size())
結果
code:python
tensor(1) : torch.Size(1)
tensor(1) : torch.Size([])
xは1次元のテンソルであるが、の要素数は1なので、次元数をもたなくても表現できる。
2つの要素をもつ1次元テンソルに対して同じ操作を施す。
code:python
x = torch.tensor(1, 2)
print(x, ':', x.size())
x0 = x.squeeze()
print(x0, ':', x0.size())
結果
code:python
tensor(1, 2) : torch.Size(2)
tensor(1, 2) : torch.Size(2)
構造を維持するために、次元は削除されない。
2つの要素を持つ2次元テンソルだとどうなるだろうか。
code:python
x = torch.tensor(1, 2)
print(x, ':', x.size())
x0 = x.squeeze()
print(x0, ':', x0.size())
結果
code:python
tensor(1, 2) : torch.Size(1, 2)
tensor(1, 2) : torch.Size(2)
不要な次元(=サイズが1の次元)をカットしていることが確認できる。
複雑な例として、1×2×1×3×1のサイズで構成されるテンソルの場合を見てみよう。サイズが1の次元は不要なので、
code:python
x = torch.tensor([[1,2,3,4,5,6]])
print(x, ':', x.size())
x0 = x.squeeze()
print(x0, ':', x0.size())
結果
code:python
tensor([[[[1,
2,
3]],
[[4,
5,
6]]]]) : torch.Size(1, 2, 1, 3, 1)
tensor([1, 2, 3,
4, 5, 6]) : torch.Size(2, 3)
こんな感じに、1×2×1×3×1 は2×3のテンソルに低次元化される。
引数で削減する次元を指定する
引数で指定した次元が不要であれば削減してくれる。
code:python
x = torch.tensor([[1,2,3,4,5,6]])
x0 = x.squeeze(2)
print(x0, ':', x0.size())
1×2×1×3×1の真ん中の「1」を削除してみる。
code:python
tensor([[[1,
2,
3],
[4,
5,
6]]]) : torch.Size(1, 2, 3, 1)
できた。
指定した次元が必要であれば、何もしない。
code:python
x = torch.tensor([[1,2,3,4,5,6]])
print(x, ':', x.size())
x0 = x.squeeze(1)
print(x0, ':', x0.size())
結果
code:python
tensor([[[[1,
2,
3]],
[[4,
5,
6]]]]) : torch.Size(1, 2, 1, 3, 1)
以上