2023.6.22 計算グラフの可視化【torchvision】
torchvizモジュールのmake_dot関数を利用することで、pytorchの計算で構成された計算グラフを可視化することができる。
WSL2環境では、Graphvizが要求されるので、aptでインストールしておこう。
$ sudo apt install graphviz
$ pipt3 install torchviz
code:python
import torch
from torchviz import make_dot
def f(x):
return x**2
x1 = torch.tensor(1, dtype=torch.float, requires_grad=True)
y = f(x1)
y.backward()
params={'Input': x1, 'Output': y}
image = make_dot(y, params=params)
image.format = 'png'
#image.format = 'x11'
image.render()
https://scrapbox.io/files/646f117e99cca4001c8c2349.png
もうちょっと複雑な例
code:python
import torch
from torchviz import make_dot
x1 = torch.tensor(1, dtype=torch.float, requires_grad=True)
x2 = torch.tensor(1, dtype=torch.float)
x3 = torch.tensor(1, dtype=torch.float, requires_grad=True)
a = 2*x1 + x2
b = torch.sin(x2)
c = a*b + torch.log(x3)
y = c + x1
y.backward()
params={'Input1': x1, 'Input2':x2, 'Input3':x3, 'Output': y, 'Hoge':c}
image = make_dot(y, params=params)
image.format = 'png'
#image.format = 'x11'
image.render()
結果から
出力はbackward関数を適用した従属変数1つのみ、これをルートノードとする(一般の木構造と比べると上下逆だが)
出力ノードからリーフノードに到るまでに、計算グラフが構成されている。
requires_grad=Trueされたテンソルがリーフノードとなる。
make_dot関数の
第1引数にルードノード(出力となる変数)
名前付き引数paramsに、ルートノードとリーフノードに関する情報を辞書によって与える。
キーとして与えた文字列がノードの名称となる
データに対応する変数を与える。
ループ処理から構成される計算にも適用できる。
code:python
import torch
from torchviz import make_dot
x1 = torch.tensor(1, dtype=torch.float, requires_grad=True)
x = x1
for i in range(5):
x = 2*x
y = x
y.backward()
params = {'Input': x1, 'Output':y}
image = make_dot(y, params=params)
image.format = 'x11'
image.render()
分岐処理を含む場合にも適用できる。
code:python
import torch
from torchviz import make_dot
x1 = torch.tensor(1, dtype=torch.float, requires_grad=True)
x2 = torch.tensor(1, dtype=torch.float, requires_grad=True)
x = x1
for i in range(5):
if i == 2:
x = x + torch.sin(x) + x2
else:
x = 2*x
y = x
y.backward()
params = {'Input': x1, 'Input2': x2, 'Output':y}
image = make_dot(y, params=params)
image.format = 'png'
#image.format = 'x11'
image.render()
https://scrapbox.io/files/646f162e305483001bbd3154.png
PyTorchを用いて構成したニューラルネットの可視化にも利用できる。
code:python
import torch
import torch.nn.functional as F
import torch.nn as nn
from torchviz import make_dot
class Net(nn.Module):
def __init__(self):
super().__init__()
self.fc1 = nn.Linear(3,6) # Input Layer
self.fc2 = nn.Linear(6,2) # Hidden Layer
def forward(self, x):
x = self.fc1(x)
x = F.relu(x)
x = self.fc2(x)
x = F.softmax(x, dim=0)
return x
model = Net()
data = torch.tensor(1.,3.,2., dtype=torch.float32, requires_grad=True)
y = model.forward(data)
params = dict(model.named_parameters())
params'Input' = data
params'Output' = y
image = make_dot(y, params=params)
image.format = 'png'
#image.format = 'x11'
image.render()
https://scrapbox.io/files/646f1a99a1b7f2001c099e68.png
この場合、ノードで扱われる値の形状が丸括弧の中に表示されている。