2023.6.7 torch.nn.Conv2dの入出力次元の確認【torch】
こんな二値画像を入力する。
https://scrapbox.io/files/647fd6002fb355001b1fe73d.png
code:python
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
layer1 = nn.Conv2d(1, 2, 1)
image = 0,1,1,1,1,0],0,0,1,0,0,0,0,0,1,1,0,0,0,1,0,1,0,0,1,0,0,1,0,0,[1,0,0,0,1,1 data = torch.tensor(image, dtype=torch.float) print('# Before')
print(data.shape)
x = layer1(data)
print('# After')
print(x.shape)
print(x)
実行結果は
code:text
# Before
# After
grad_fn=<SqueezeBackward1>)
バッチ学習を念頭に、2件の画像をリストに並べて入力する。
code:python
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
layer1 = nn.Conv2d(2, 3, 1)
image = 0,1,1,1,1,0],0,0,1,0,0,0,0,0,1,1,0,0,0,1,0,1,0,0,1,0,0,1,0,0,[1,0,0,0,1,1 print('# Before')
print(data.shape)
x = layer1(data)
print('# After')
print(x.shape)
print(x)
実行結果
code:text
# Before
# After
grad_fn=<SqueezeBackward1>)
nn.Conv2dの第1引数は、入力される画像の数と一致させる必要があった。
第2引数で、出力チャネルの数を指定するが、入力画像との関係はもうちょっと勉強が必要。