2023.08.09 NNの重み行列確認(構成別)【torch】
code:python
import torch
import torch.nn as nn
import torch.nn.functional as F
def show_params_shape(model):
    """Print the shape of each of ``model``'s parameters, one per line.

    Parameters are visited in the order ``model.parameters()`` yields them,
    so a layer's weight is printed before its bias.
    """
    shapes = (param.shape for param in model.parameters())
    for shape in shapes:
        print(shape)
class Net1(nn.Module):
    """Minimal network made of a single ``nn.Linear`` layer.

    Used to inspect a fully connected layer's parameters:
    ``layer1.weight`` has shape ``(out_features, in_features)`` and
    ``layer1.bias`` has shape ``(out_features,)``.

    Args:
        in_features: Size of each input sample (defaults to the original 3).
        out_features: Size of each output sample (defaults to the original 5).
    """

    def __init__(self, in_features=3, out_features=5):
        super().__init__()
        self.layer1 = nn.Linear(in_features, out_features)

    def forward(self, x):
        """Apply the linear layer to ``x`` of shape ``(..., in_features)``."""
        return self.layer1(x)
class Net2(nn.Module):
    """Minimal network made of a single ``nn.Conv2d`` layer.

    Used to inspect a 2-D convolution's parameters:
    ``layer1.weight`` has shape ``(out_channels, in_channels, kH, kW)`` and
    ``layer1.bias`` has shape ``(out_channels,)`` (default ``groups=1``).

    Args:
        in_channels: Number of input channels (defaults to the original 2).
        out_channels: Number of output channels (defaults to the original 3).
        kernel_size: Kernel size ``(kH, kW)`` (defaults to the original (4, 4)).
    """

    def __init__(self, in_channels=2, out_channels=3, kernel_size=(4, 4)):
        super().__init__()
        self.layer1 = nn.Conv2d(in_channels, out_channels, kernel_size)

    def forward(self, x):
        """Apply the convolution to ``x`` of shape ``(N, in_channels, H, W)``."""
        return self.layer1(x)
class Net3(nn.Module):
    """Minimal network made of a single (possibly stacked) ``nn.RNN``.

    For every layer ``k`` the RNN registers, in this order,
    ``weight_ih_l{k}``, ``weight_hh_l{k}``, ``bias_ih_l{k}`` and ``bias_hh_l{k}``.
    ``weight_ih_l0`` is ``(hidden_size, input_size)``. Every deeper
    ``weight_ih`` and every ``weight_hh`` is ``(hidden_size, hidden_size)``.
    Biases are ``(hidden_size,)``.

    Args:
        input_size: Features per time step (defaults to the original 2).
        hidden_size: Hidden state size (defaults to the original 5).
        num_layers: Number of stacked RNN layers (defaults to the original 3).
        batch_first: If True, input/output are ``(batch, seq, feature)``.
    """

    def __init__(self, input_size=2, hidden_size=5, num_layers=3, batch_first=True):
        super().__init__()
        self.layer1 = nn.RNN(input_size, hidden_size, num_layers,
                             batch_first=batch_first)

    def forward(self, x, h0=None):
        """Run the RNN over ``x`` and return ``(output, h_n)`` from ``nn.RNN``.

        ``h0`` is the optional initial hidden state; ``None`` lets ``nn.RNN``
        use zeros.
        """
        return self.layer1(x, h0)
# Print each single-layer network's parameter shapes under a header.
# Classes (not instances) are listed so each network is built right before
# its shapes are shown, keeping the original construction order.
for label, net_cls in (("# Linear", Net1), ("# Conv2d", Net2), ("# RNN", Net3)):
    print(label)
    show_params_shape(net_cls())
結果
code:python
# Linear
torch.Size([5, 3])  # 重み(出力 × 入力)
torch.Size([5])  # バイアス(出力)
# Conv2d
torch.Size([3, 2, 4, 4])  # 重み(出力 × 入力 × カーネルh × カーネルw)
torch.Size([3])  # バイアス(出力)
# RNN