クラスと関数【torch】
PyTorchは提供するニューラルネットの構成要素には、関数のまま利用するものとインスタンス化して利用するものがある。
平均二乗誤差
関数:torch.nn.functional.mse_loss
メソッド:torch.nn.MSELoss
ReLU関数(活性化関数)
関数:torch.nn.functional.relu
メソッド:torch.nn.ReLU
線形層
関数:torch.nn.functional.linear
メソッド:torch.nn.Linear
メソッドと書かれた方が「インスタンス化」して利用するものである。
code:class_or_function1.py
import torch as pt
import torch.nn as nn
import torch.nn.functional as F
data1 = pt.tensor(1, 2, 1, dtype=pt.float)
data2 = pt.tensor(1, 2, 0, dtype=pt.float)
e1 = F.mse_loss(data1, data2) # 関数のまま
print('mse_loss関数:', e1)
loss = nn.MSELoss() # インスタンス化
e2 = loss(data1, data2) # インスタンスにデータを与えて実行
print('MSELossインスタンス:', e2)
例えば、線形層は処理を行うために層数や重み行列のように、ネットワークの構造や状態に関する情報が要求される。
線形層の処理を関数を用いて処理するには、これら全ての情報を引数として与える必要があるのに対して、メソッドはそれらの情報をインスタンス変数として保持できるため情報の管理が容易であるという違いがある。
この点に対して、上で挙げたMSELossやReLU関数は内部状態をもたない。これらのクラスを用意することにどのような利点があるのだろうか?以下で考察する。
~~~~~~~~~~~~~~~~~~
あるネットワークのクラスを設計を考える。このネットワークが有する機能は、Pytorchが提供するインスタンスをインスタンス変数として持たせることにより実装する。
このクラスをnn.Moduleクラスのサブクラスとすることで、クラスが有する機能の一覧を簡単に確認することが可能となる。
code:class_or_function2.py
import torch as pt
import torch.nn as nn
import torch.nn.functional as F
class Mynet(nn.Module): # Mynetはnn.Moduleのサブクラス
def __init__(self):
super().__init__()
self.loss = nn.MSELoss() # 構成要素
def forward(self, data1, data2):
x = self.loss(data1, data2)
return x
data1 = pt.tensor(1, 2, 1, dtype=pt.float)
data2 = pt.tensor(1, 2, 0, dtype=pt.float)
net = Mynet()
e = net.forward(data1, data2)
print('MSE', e)
print('# netインスタンスの構成要素:')
print(net)
code:(結果).py
MSE tensor(0.3333)
# netインスタンスの構成要素:
Mynet(
(loss): MSELoss()
)
このように、lossという変数がMSELoss型インスタンスに割り当てられていることが簡単に把握できる。
クラスが有する構成要素を増やしてみよう。
code:class_or_function3.py
import torch as pt
import torch.nn as nn
class Mynet(nn.Module):
def __init__(self):
super().__init__()
self.loss1 = nn.MSELoss()
self.loss2 = nn.MSELoss()
self.relu = nn.ReLU()
self.lin1 = nn.Linear(4, 10)
self.lin2 = nn.Linear(10, 3)
model = Mynet()
print('modelインスタンスの構成要素:')
print(model)
code:(結果).py
modelインスタンスの構成要素:
Mynet(
(loss1): MSELoss()
(loss2): MSELoss()
(relu): ReLU()
(lin1): Linear(in_features=4, out_features=10, bias=True)
(lin2): Linear(in_features=10, out_features=3, bias=True)
)
上がmodelインスタンスの構成要素の一覧である。線形層(Linear)についてはこれららが有する構造についての情報まで表示されている。
最後に、クラス定義文の先頭を次のように修正した時の実行結果を確認せよ。
code:(抜粋).py
class Mynet(): # 何も継承しない