自動微分(スカラの場合)
関数$ f(x) = x^2 - 2xの$ x=2における導関数の値を求める問題を考える。導関数は
$ \frac{d}{dx} f(x) = 2x - 2
より、
$ \left. \frac{d}{dx} f(x) \right|_{x=2} = 2x - 2 |_{x=2} = 2
である。これを計算するプログラムは次のように与えられる。
code:tensor_grad11.py
def fd(x):
return 2*x - 2 # 導関数の式を用いて計算する
x = 2.0
print('Derivative:', fd(x))
code:結果.txt
$ python3 torch_grad12.py
Derivative: 2. <--- x=2 に対する y の勾配(傾き)の値
自動微分を利用すると、次のように作ることができる。
code:tensor_grad12.py
import torch as pt
x = pt.tensor(2., dtype=pt.float32, requires_grad=True) # (1)
y = x**2 - 2*x # (2):元の関数の式
y.backward() # (3):自動微分
print('Derivative:', x.grad)
code:結果.txt
$ python3 torch_grad13.py
Derivative: tensor(2.) <--- x=2 に対する y の勾配(傾き)の値
導関数式を利用することなく導関数値が得られていることがお分かりだろうか?
計算グラフ
上記のプログラムを実行した結果、次のような計算グラフが生成される。各ノードには計算に利用された変数と実行された計算過程が保持されている。
https://scrapbox.io/files/653343e26cdcec001c2b52e8.svg
この中で、計算に用いられた変数xがリーフ、結果である変数yがルートである。
リーフは勾配を扱うように設定されたテンソルである。(requires_grad=True したもの)
ルート変数yのbackwardメソッドを呼び出すと、グラフをリーフまで辿り、この間で行われた計算の内容を元に勾配を求め、リーフのgrad変数に格納する。
計算式とグラフを構成するノード(要素)の対応を考える。
code:(計算式).py
y = x**2 - 2*x
AccumulateGrad: 変数から値が出力された箇所に置かれるノード(と思われる)
PowerBackward0:x**2の「べき算」を意味するノード
MulBackward0: 2*x の「乗算」を意味するノード
SubBackward0: x**2 - 2*x の「減算」を意味するノード
ノードx, y内の()は、これらの変数の形状を示している。
リーフが2つある例
計算に2つの変数x1, x2を利用した例を考える。いずれもリーフとなるようrequires_grad=Trueを設定している。
code:tensor_grad13.py
import torch as pt
x1 = pt.tensor(2, dtype=pt.float, requires_grad=True)
x2 = pt.tensor(1, dtype=pt.float, requires_grad=True)
a = pt.sin(x1**2) + 1
b = pt.cos(2*x2)
y = a + b
y.backward() # 自動微分
print('# 実験1----------')
print('x1:', x1)
print('x1.grad:', x1.grad)
print('# 実験2----------')
print('x2:', x2)
print('x2.grad:', x2.grad) # gradを持たない場合はNoneと表示される
https://scrapbox.io/files/653347b8769b21001c7b4398.svg
上の例からx2から勾配の設定を外すよう修正する。結果としてこの変数はリーフではなくなり、計算グラフも生成されなくなる。
code:tensor_grad14.py
import torch as pt
x1 = pt.tensor(2, dtype=pt.float, requires_grad=True)
x2 = pt.tensor(1, dtype=pt.float) # x2はgradをもたない
a = pt.sin(x1**2) + 1
b = pt.cos(2*x2)
y = a + b
y.backward() # 自動微分
print('# 実験1----------')
print('x1:', x1)
print('x1.grad:', x1.grad)
print('# 実験2----------')
print('x2:', x2)
print('x2.grad:', x2.grad) # gradを持たない場合はNoneと表示される
https://scrapbox.io/files/6533470408589b001bf409cc.svg
/icons/hr.icon
※ ブラウザのバックボタンで戻る