2023.6.7 Checking the input and output dimensions of torch.nn.Conv2d 【torch】
The input is a binary image like this one.
https://scrapbox.io/files/647fd6002fb355001b1fe73d.png
code:python
# https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
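# in_channels=1, out_channels=2, kernel_size=1
layer1 = nn.Conv2d(1, 2, 1)
# 6x6 binary image with a single channel: shape (C, H, W) = (1, 6, 6)
image = [[[0, 1, 1, 1, 1, 0],
          [0, 0, 1, 0, 0, 0],
          [0, 0, 1, 1, 0, 0],
          [0, 1, 0, 1, 0, 0],
          [1, 0, 0, 1, 0, 0],
          [1, 0, 0, 0, 1, 1]]]
data = torch.tensor(image, dtype=torch.float)
#plt.imshow(data[0])  # imshow needs a 2D array, so drop the channel axis
#plt.show()
print('# Before')
print(data.shape)
x = layer1(data)  # a 3D input is treated as one unbatched (C, H, W) image
print('# After')
print(x.shape)
print(x)
#plt.imshow(x[0].detach())  # e.g. show output channel 0
#plt.show()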
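The output is shown below. The values come from randomly initialized weights, so they change on every run.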
code:text
# Before
torch.Size([1, 6, 6])
# After
torch.Size([2, 6, 6])
tensor([[[0.4147, 0.7390, 0.7390, 0.7390, 0.7390, 0.4147],
         [0.4147, 0.4147, 0.7390, 0.4147, 0.4147, 0.4147],
         [0.4147, 0.4147, 0.7390, 0.7390, 0.4147, 0.4147],
         [0.4147, 0.7390, 0.4147, 0.7390, 0.4147, 0.4147],
         [0.7390, 0.4147, 0.4147, 0.7390, 0.4147, 0.4147],
         [0.7390, 0.4147, 0.4147, 0.4147, 0.7390, 0.7390]],
        [[0.9589, 0.0119, 0.0119, 0.0119, 0.0119, 0.9589],
         [0.9589, 0.9589, 0.0119, 0.9589, 0.9589, 0.9589],
         [0.9589, 0.9589, 0.0119, 0.0119, 0.9589, 0.9589],
         [0.9589, 0.0119, 0.9589, 0.0119, 0.9589, 0.9589],
         [0.0119, 0.9589, 0.9589, 0.0119, 0.9589, 0.9589],
         [0.0119, 0.9589, 0.9589, 0.9589, 0.0119, 0.0119]]],
       grad_fn=<SqueezeBackward1>)
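Each output channel contains only two distinct values. With kernel_size=1, each output pixel is just weight * pixel + bias for that channel, so 0-pixels give the bias and 1-pixels give weight + bias. The sketch below recomputes the output by hand from layer1.weight and layer1.bias and compares. The fixed seed and the name manual are assumptions added for illustration, so the numbers will not match the output above.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)  # assumption: fixed seed only so runs are reproducible
layer1 = nn.Conv2d(1, 2, 1)  # in_channels=1, out_channels=2, kernel_size=1

image = [[[0, 1, 1, 1, 1, 0],
          [0, 0, 1, 0, 0, 0],
          [0, 0, 1, 1, 0, 0],
          [0, 1, 0, 1, 0, 0],
          [1, 0, 0, 1, 0, 0],
          [1, 0, 0, 0, 1, 1]]]
data = torch.tensor(image, dtype=torch.float)  # shape (1, 6, 6)

with torch.no_grad():
    x = layer1(data)  # shape (2, 6, 6)
    # weight has shape (out_channels, in_channels, 1, 1) = (2, 1, 1, 1):
    # a single scalar per output channel
    w = layer1.weight.view(2, 1, 1)
    b = layer1.bias.view(2, 1, 1)
    # broadcast (2, 1, 1) * (1, 6, 6) + (2, 1, 1) -> (2, 6, 6)
    manual = w * data + b

print(layer1.weight.shape)
print(torch.allclose(x, manual))  # the hand computation matches the layer
# Each channel holds exactly two values: b (for 0 pixels) and w + b (for 1 pixels)
for c in range(2):
    print(c, torch.unique(x[c]).numel())
```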
With batch training in mind, put two images side by side in a list and feed them in together.
code:python
# https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
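# in_channels=2, out_channels=3, kernel_size=1
layer1 = nn.Conv2d(2, 3, 1)
# The same 6x6 binary image, this time without the channel axis: shape (6, 6)
image = [[0, 1, 1, 1, 1, 0],
         [0, 0, 1, 0, 0, 0],
         [0, 0, 1, 1, 0, 0],
         [0, 1, 0, 1, 0, 0],
         [1, 0, 0, 1, 0, 0],
         [1, 0, 0, 0, 1, 1]]
# Stack two copies: shape (2, 6, 6)
data = torch.tensor([image, image], dtype=torch.float)
print('# Before')
print(data.shape)
x = layer1(data)
print('# After')
print(x.shape)
print(x)
#plt.imshow(x[0].detach())  # e.g. show output channel 0
#plt.show()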
Output:
code:text
# Before
torch.Size([2, 6, 6])
# After
torch.Size([3, 6, 6])
tensor([[[-0.5847, -0.7914, -0.7914, -0.7914, -0.7914, -0.5847],
         [-0.5847, -0.5847, -0.7914, -0.5847, -0.5847, -0.5847],
         [-0.5847, -0.5847, -0.7914, -0.7914, -0.5847, -0.5847],
         [-0.5847, -0.7914, -0.5847, -0.7914, -0.5847, -0.5847],
         [-0.7914, -0.5847, -0.5847, -0.7914, -0.5847, -0.5847],
         [-0.7914, -0.5847, -0.5847, -0.5847, -0.7914, -0.7914]],
        [[-0.3923, -0.2137, -0.2137, -0.2137, -0.2137, -0.3923],
         [-0.3923, -0.3923, -0.2137, -0.3923, -0.3923, -0.3923],
         [-0.3923, -0.3923, -0.2137, -0.2137, -0.3923, -0.3923],
         [-0.3923, -0.2137, -0.3923, -0.2137, -0.3923, -0.3923],
         [-0.2137, -0.3923, -0.3923, -0.2137, -0.3923, -0.3923],
         [-0.2137, -0.3923, -0.3923, -0.3923, -0.2137, -0.2137]],
        [[-0.2455, -0.3891, -0.3891, -0.3891, -0.3891, -0.2455],
         [-0.2455, -0.2455, -0.3891, -0.2455, -0.2455, -0.2455],
         [-0.2455, -0.2455, -0.3891, -0.3891, -0.2455, -0.2455],
         [-0.2455, -0.3891, -0.2455, -0.3891, -0.2455, -0.2455],
         [-0.3891, -0.2455, -0.2455, -0.3891, -0.2455, -0.2455],
         [-0.3891, -0.2455, -0.2455, -0.2455, -0.3891, -0.3891]]],
       grad_fn=<SqueezeBackward1>)
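In this setup, the first argument of nn.Conv2d (in_channels) had to match the number of stacked images, i.e. the size of the first dimension. This is because PyTorch reads a 3D input as a single unbatched image of shape (C, H, W), so the two images were treated as two channels of one image rather than as a batch of two. A real batch needs a 4D input of shape (N, C, H, W), here (2, 1, 6, 6), with in_channels=1.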
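A minimal sketch of a genuine batch of two single-channel images, assuming the same 6x6 image as above. in_channels counts channels (C), not images (N): adding a channel axis gives shape (N, C, H, W) = (2, 1, 6, 6), in_channels stays 1, and the output becomes (2, 3, 6, 6), i.e. 3 output channels for each of the 2 images. The names batch, layer, out and single are illustrative.

```python
import torch
import torch.nn as nn

image = [[0, 1, 1, 1, 1, 0],
         [0, 0, 1, 0, 0, 0],
         [0, 0, 1, 1, 0, 0],
         [0, 1, 0, 1, 0, 0],
         [1, 0, 0, 1, 0, 0],
         [1, 0, 0, 0, 1, 1]]

# Batch of 2 images, each with 1 channel: (N, C, H, W) = (2, 1, 6, 6)
batch = torch.tensor([[image], [image]], dtype=torch.float)
print(batch.shape)

# in_channels matches C (= 1), not the batch size N (= 2)
layer = nn.Conv2d(1, 3, 1)
out = layer(batch)
print(out.shape)  # (N, out_channels, H, W)

# Same result as feeding one image on its own as an unbatched (1, 6, 6) input
single = layer(batch[0])
print(torch.allclose(out[0], single))
```

The batch dimension never appears in the Conv2d constructor, so the same layer works for any batch size.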
The second argument (out_channels) sets the number of output channels, but how it relates to the input images still needs more study.
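As a starting point for that question, a sketch using a random input (the shapes are chosen only for illustration): the layer holds one kernel of shape (in_channels, kH, kW) per output channel, so weight has shape (out_channels, in_channels, kH, kW), and each output channel is a weighted sum over all input channels plus its own bias. The helper conv2d_out_size is a hypothetical name; it only applies the output-size formula given in the torch.nn.Conv2d documentation.

```python
import torch
import torch.nn as nn

def conv2d_out_size(size, kernel_size, stride=1, padding=0, dilation=1):
    # Output-size formula from the torch.nn.Conv2d docs, for one spatial dimension
    return (size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1

layer = nn.Conv2d(2, 3, 1)  # in_channels=2, out_channels=3, kernel_size=1
print(layer.weight.shape)  # (out_channels, in_channels, 1, 1)
print(layer.bias.shape)    # one bias per output channel

# Output channel k = sum over input channels c of weight[k, c] * input[c] + bias[k]
x = torch.rand(2, 6, 6)  # unbatched input with 2 channels
with torch.no_grad():
    out = layer(x)
    mixed = layer.weight[:, :, 0, 0] @ x.reshape(2, -1)  # (3, 2) @ (2, 36) -> (3, 36)
    manual = mixed.reshape(3, 6, 6) + layer.bias.view(3, 1, 1)
print(torch.allclose(out, manual, atol=1e-6))

# Spatial size: kernel_size=1 keeps 6x6, a 3x3 kernel without padding shrinks it to 4x4
print(conv2d_out_size(6, 1))
print(conv2d_out_size(6, 3))
print(nn.Conv2d(2, 3, 3)(x).shape)
```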