JAXを使ってみよう
https://gyazo.com/096761a3a7d0bcf8d4b96e9d7148466b
JAXについて
JAX はGoogleによって開発されNumPyの代替品です。 JAXフレームワークには、NumPyおよびSciPyが提供する多くの関数のGPU対応バージョン、JITコンパイラー、およびその他のディープラーニング固有の機能が含まれています。 JAXには、CuPyとは対照的に独自のJITコンパイラがあります。 このネイティブ機能は優れていますが、JITコンパイラに必要な構文は多少異質な場合があります。 たとえば、従来のNumPyのインデックススライスはサポートされていません。 スライス抽出するために必要な構文は一般的ではありません。 これらの小さな変更により、既存のコードをJAXのJITコンパイラに適したコードに変換することが困難になる可能性があります。
JAXはGoogleのXLAコンパイラを使用しています。 現在、XLAは多くのCPU、Google TPU、NVIDIA GPU、そして最近ではAMD ROCmをサポートしています。 ただし、JAXはまだAMD GPUをサポートしていません。
JAXのインストール
CPUバージョンのJAXのインストールは次のように行います。
code: bash
$ pip install --upgrade pip
$ pip install --upgrade jax jaxlib # CPU-only version
CPUバージョンは、ラップトップなどやローカルホストでの開発を行う場合に役立つでしょう。
CPUとGPUの両方をサポートするJAXをインストールする場合、はじめにCUDA
をインストールされている必要があります。
他の一般的なディープラーニングシステムとは異なり、JAXはCUDAまたはCuDNNをpipパッケージの一部としてバンドルしていません。 CUDA-10 JAXのパッケージにはCuDNN7が必要ですが、CUDA-11 JAXパッケージにはCuDNN8が必要です。CUDAとCuDNNの他の組み合わせも可能ですが、ソースから構築する必要があります。
code: bash
pip install --upgrade pip
pip install --upgrade jax jaxlib==0.1.59+cuda110 -f $JAXURL
jaxlibのバージョンは、使用するインストールされているCUDAのバージョンと対応している必要があります。
CUDA-11.1の場合は cuda111
CUDA-11.0の場合は cuda110
CUDA-10.2の場合は cuda102
CUDA-10.1の場合は cuda101
次のコマンドでCUDAバージョンを見つけることができます:
code: bash
$ nvcc --version
NVidia CUDA Toolkit をデフォルトの設定でインストールした場合は、インストールパスは /usr/local/cuda-X.Xになります。ここで、X.XはCUDAバージョン番号です。例:/usr/local/cuda-10.2
JAXの一部のGPU機能では、CUDAのインストールがこのパスにインストールされていると想定していることに注意してください。 CUDAがシステムの他の場所にインストールされている場合は、シンボリックリンクを作成できます。
code: bash
$ sudo -n -s /path/to/installed/cuda /usr/local/cuda-X.X
または、JAXをインポートする前に、次の環境変数を設定することでJAXに知らせることができます。
code: bash
$ export XLA_FLAGS="--xla_gpu_cuda_data_dir=/path/to/installed/cuda"
JAXの使用方法
Google Cloud や Google Colab に接続された、ブラウザでノートブックを使用してすぐに利用することができます。
スターターノートブックは次のとおりです。
JAXはGoogle TPUで実行されるようになりました。 プレビューを試すには次を参照してください。
次も参照するとよいでしょう。
数値演算関数の置き換え
JAXの基本は、数値演算関数を置き換える拡張可能なシステムです。主なものは、grad、jit、vmap、pmapの4つです。
gradによる自動微分
JAXにはAutogradとほぼ同じAPIがあります。 最も一般的な関数は、逆モード勾配(reverse-mode gradients)です。
code: python
from jax import grad
import jax.numpy as jnp
def tanh(x): # 関数を定義
y = jnp.exp(-2.0 * x)
return (1.0 - y) / (1.0 + y)
grad_tanh = grad(tanh) # 勾配関数を取得
print(grad_tanh(1.0)) # x = 1.0 を計算
# 出力は 0.4199743
gradを使用すると、任意の順序に微分(differentiate)できます。
code: python
print(grad(grad(grad(tanh)))(1.0))
# 出力は 0.62162673
より高度なautodiffの場合、逆モードのヤコビアン関数にはjax.vjpを使用し、順モードのヤコビアンベクトル関数にはjax.jvpを使用できます。 この2つは、相互に、および他のJAX変換を使用して任意に構成できます。 これらを構成して、完全なヘッセ行列を効率的に計算する関数を作成したものは次の例です。
code: python
from jax import jit, jacfwd, jacrev
def hessian(fun):
return jit(jacfwd(jacrev(fun)))
Autogradを使うことで、Python制御構造で自由に微分することができます。
code: python
def abs_val(x):
if x > 0:
return x
else:
return -x
abs_val_grad = grad(abs_val)
print(abs_val_grad(1.0)) # 出力は 1.0
print(abs_val_grad(-1.0)) # 出ry区は -1.0 (abs_valを再評価)
JITによるコンパイル
XLAを使用して、@jitデコレータまたは高階関数として使用されるjitを使用して関数を全体をコンパイルすることができます。
code: python
import jax.numpy as jnp
from jax import jit
def slow_f(x):
# 要素ごとの操作は、コンパイルで大きな性能向上が得られる
return x * x + x * 2.0
x = jnp.ones((5000, 5000))
fast_f = jit(slow_f)
%timeit -n10 -r3 fast_f(x) # ~ 4.5 ms / loop on Titan X
%timeit -n10 -r3 slow_f(x) # ~ 14.5 ms / loop (also on GPU via JAX)
jitとgrad、およびその他のJAX変換を好きなように組み合わせることができます。
jitを使用する場合、対象の関数で使用できるPython制御フローの種類には制約があります。詳細は GotchasNotebook を参照してください。 vmapによる自動ベクトル化
vmapはベクトル化を行うマッピング関数です。 配列軸に沿って関数をマッピングするというおなじみのセマンティクスがありますが、ループを外側に保持する代わりに、パフォーマンスを向上させるためにループを関数のプリミティブ操作にプッシュダウンします。
vmapを使用すると、コードでバッチディメンションに移植する必要がなくなります。 たとえば、次の単純なバッチ処理されていないニューラルネットワーク予測関数について考えてみます。
code: python
def predict(params, input_vec):
assert input_vec.ndim == 1
activations = inputs
for W, b in params:
outputs = jnp.dot(W, activations) + b # input_vec は右側
activations = jnp.tanh(outputs)
return outputs
代わりに、入力の左側にバッチディメンションを許可するために、代わりにjnp.dot(inputs,W)を記述することがよくありますが、この特定の予測関数は、単一の入力ベクトルにのみ適用されるように記述されています。 この関数を入力のバッチに一度に適用したい場合は、意味的には次のように書くことができます。
code: python
from functools import partial
predictions = jnp.stack(list(map(partial(predict, params),
input_batch)))
ただし、一度に1つの例をネットワーク経由でプッシュするのは時間がかかります。 計算をベクトル化することをお勧めします。これにより、すべてのレイヤーで、行列-ベクトルの乗算ではなく、行列-行列の乗算を実行します。
vmap関数は私たちのためにその変換を行います。 つまり、次のように記述できます。
code: python
from jax import vmap
predictions = vmap(partial(predict, params))(input_batch)
# または、代わりに
predictions = vmap(predict, in_axes=(None, 0))(params, input_batch)
次に、vmap関数は関数内の外側のループをプッシュし、マシンは、手動でバッチ処理を行ったかのように、行列と行列の乗算を実行することになります。
vmapを使用せずに単純なニューラルネットワークを手動でバッチ処理するのは簡単ですが、それ以外の場合、手動のベクトル化は非現実的または不可能な場合があります。 例ごとの勾配を効率的に計算するという問題を取り上げます。つまり、パラメーターの固定セットについて、バッチ内の各例で個別に評価された損失関数の勾配を計算する必要があります。 vmapを使用すると、簡単です。
code: python
per_example_gradients = vmap(partial(grad(loss), params))(inputs, targets)
もちろん、vmapは、jit、grad、およびその他のJAX変換を使用して任意に構成できます。 jax.jacfwd、jax.jacrev、およびjax.hessianでの高速ヤコビ行列およびヘッセ行列計算のために、順方向モードと逆方向モードの両方の自動微分でvmapを使用します。
pmapを使用したSPMDプログラミング
複数のGPUなど、複数のアクセラレータの並列プログラミングには、pmapを使用します。 pmapを使用すると、高速並列集合通信操作を含む、SPMD(Single-Program Multiple-Data)プログラムを作成できます。 pmapを適用すると、作成した関数がXLAによってコンパイルされ(jitと同様)、デバイス間で並行して複製および実行されます。
これは、8GPUマシンでの例です。
code: python
from jax import random, pmap
import jax.numpy as jnp
# GPUごとに1つずつ、8つのランダムな5000 x6000マトリックスを作成
keys = random.split(random.PRNGKey(0), 8)
mats = pmap(lambda key: random.normal(key, (5000, 6000)))(keys)
# 各デバイスでローカルmatmulを並行して実行(データ転送なし)
result = pmap(lambda x: jnp.dot(x, x.T))(mats) # result.shape (8, 5000, 5000)
# 各デバイスの平均を並行して計算し、結果を出力
print(pmap(jnp.mean)(result))
純粋なマップを表現することに加えて、デバイス間で高速な集合通信操作を使用できます。
code: python
from functools import partial
from jax import lax
@partial(pmap, axis_name='i')
def normalize(x):
return x / lax.psum(x, 'i')
print(normalize(jnp.arange(4.)))
より洗練された通信パターンのためにpmap関数をネストすることもできます。
それはすべて構成されているので、並列計算によって自由に微分できます。
code: python
from jax import grad
@pmap
def f(x):
y = jnp.sin(x)
@pmap
def g(z):
return jnp.cos(z) * jnp.tan(y.sum()) * jnp.tanh(x).sum()
return grad(lambda w: jnp.sum(g(w)))(x)
print(f(x))
print(grad(lambda x: jnp.sum(f(x)))(x))
pmap関数を逆モードで微分する場合(たとえば、gradを使用)、計算の逆方向パスは順方向パスと同じように並列化されます。
現時点での既知の制約
代表的なものを以下にしめします。
JAX変換は、副作用がなく参照透過性を尊重する純粋関数でのみ機能します。
つまり、isを使用したオブジェクトIDテストは保持されません。 純粋関数でないPython関数でJAX変換を使用すると、次のようなエラーが表示される場合があります。
Exception: Can't lift Traced... or Exception: Different traces at same level
x[i] += y のような配列の更新はサポートされていませんが、機能的な代替手段があります。 jitでは、これらの機能的な代替手段は、バッファをインプレースで自動的に再利用します。
乱数はNumPyとは異なりますが、これには正当な理由があります。
JAXは、分割可能な最新のThreefryカウンターベースのPRNGを使用します。 つまり、その設計により、PRNG状態を新しいPRNGにフォークして、並列確率生成で使用することができます。
畳み込み演算子を探している場合は、jax.laxパッケージに含まれています。
JAXはデフォルトで単精度(32ビットなどfloat32)を適用しています。倍精度(64ビットなどfloat64)を有効にするには、起動時にjax_enable_x64変数を設定する(または環境変数JAX_ENABLE_X64=Trueを設定する)必要があります。
PythonスカラーとNumPyタイプの組み合わせを含むNumPyのdtypeでの指示の一部は保持されません。つまり、numpy.add(1,np.array([2], np.float32))。このときのdtypeはfloat32ではなくfloat64です。
jitなどの一部の変換は、Python制御フローの使用方法を制限します。 何か問題が発生すると、常に大きなエラーが発生します。 jitのstatic_argnums引数や、lax.scanなどの構造化された制御フロープリミティブを使用するか、小さなサブ関数でjitを使用する必要がある場合があります。
まとめ
CUDAを利用してGPU操作を行う他のライブラリと比較して、インストールや学習コストは高くなります。しかし、簡単にマルチGPUでの計算処理ができることは特筆するべき利点だと言えます。
また、Google ColabやGoogle Cloud で TPU を利用した計算ができることは大きな差別化要因です。
参考