2023.08.09 NNの重み行列確認(基本)【torch】
PyTorchで構成したネットワークが保持するレイヤーの重み行列を確認する方法を示す。
まず、雛形を示す。
code:python
import torch
import torch.nn as nn
import torch.nn.functional as F
class Net1(nn.Module):
def __init__(self):
super().__init__()
self.layer1 = nn.Linear(2, 5)
print('# 全結合層')
model1 = Net1()
print('# レイヤの保持する重み行列はparameter()メソッドで取り出すことができる。')
print(model1.parameters())
parameters()メソッドで取り出されたオブジェクトはイタレブルなので、画面に表示するとこのようにオブジェクトの属性だけが出力される。 code:python
# 全結合レイヤ
# レイヤの保持する重み行列はparameter()メソッドで取り出すことができる。
<generator object Module.parameters at 0x7f2031d692a0>
そこで、レイヤーの重みパラメータを取り出すための方法を2通り示す。
1つ目は、rangeオブジェクトのように for 文で順に取り出す方法である。
code:python
for param in model1.parameters():
print(param)
結果は、次のようにレイヤごとに重み行列が取り出すことができる。
code:python
Parameter containing:
Parameter containing:
2つ目は、リストオブジェクトにキャストする方法である。
code:python
params = list(model1.parameters())
print(params)
これで各レイヤの重み行列を要素に持つリストを得ることができる。
code:python
[Parameter containing: