2023.6.22 CNNのネットワーク(C×3+F×3)【torch】
3層の畳み込み層をもつCNNのサンプル
code:python
class Net(nn.Module):
def __init__(self, args):
super().__init__()
# conv layer
# pooling, Window & Stride size
# FC layer
OUTPUT = args'OUTPUT' # for output layer #
OH1 = int(( H1 + 2*P1 - FH1)/S1 + 1) # conv1への入力、タテ
OW1 = int(( W1 + 2*P1 - FW1)/S1 + 1) # conv1への入力、ヨコ
OHP1 = int(OH1 // POOL1)
OWP1 = int(OW1 // POOL1)
OH2 = int((OHP1 + 2*P2 - FH2)/S2 + 1) # conv2への入力、タテ
OW2 = int((OWP1 + 2*P2 - FW2)/S2 + 1) # conv2への入力、ヨコ
OHP2 = int(OH2 // POOL2) # conv2からの出力をPooling、タテ
OWP2 = int(OW2 // POOL2) # conv2からの出力をPooling、ヨコ
OH3 = int((OHP2 + 2*P3 - FH3)/S3 + 1) # conv3への入力、タテ
OW3 = int((OWP2 + 2*P3 - FW3)/S3 + 1) # conv3への入力、ヨコ
OHP3 = int(OH3 // POOL3) # conv3からの出力をPooling、タテ
OWP3 = int(OW3 // POOL3) # conv3からの出力をPooling、ヨコ
print(
'output of conv1 :', OH1, OW1,
', outout of pool1 :', OHP1, OWP1,
', output of conv2 :', OH2, OW2,
', outout of pool2 :', OHP2, OWP2
', output of conv3 :', OH3, OW3,
', outout of pool3 :', OHP3, OWP3
)
self.relu = nn.ReLU()
self.pool1 = nn.MaxPool2d(POOL1, stride=POOL1)
self.pool2 = nn.MaxPool2d(POOL2, stride=POOL2)
self.pool3 = nn.MaxPool2d(POOL3, stride=POOL3)
self.softmax = nn.Softmax(dim=1)
self.conv1 = nn.Conv2d(CH1_i, CH1_o, kernel_size=(FH1, FW1), padding=P1, stride=S1)
self.conv2 = nn.Conv2d(CH2_i, CH2_o, kernel_size=(FH2, FW2), padding=P2, stride=S2)
self.conv3 = nn.Conv2d(CH3_i, CH3_o, kernel_size=(FH3, FW3), padding=P3, stride=S3)
self.fc1 = nn.Linear(CH3_o*OHP3*OWP3, HIDDEN_FC1)
self.fc2 = nn.Linear(HIDDEN_FC1, HIDDEN_FC2)
self.fc3 = nn.Linear(HIDDEN_FC2, OUTPUT)
def forward(self, x):
x = self.conv1(x)
x = self.relu(x)
x = self.pool1(x)
x = self.conv2(x)
x = self.relu(x)
x = self.pool2(x)
x = self.conv3(x)
x = self.relu(x)
x = self.pool3(x)
x = x.view(x.size()0, -1) x = self.fc1(x)
x = self.fc2(x)
x = self.fc3(x)
return x
使い方は以下のようになる。
code:python
# ネットワークの構成、(Conv2d + ReLU + MaxPoolin)×3 + FC×3
args = {}
## conv layer
args'H1', args'W1' = 129, 44 # iput image size ## pooling, Window & Stride size
## FC layer
#
net = mycnn.Net(args)
用途に応じてカスタマイズしてください。