PATHWAYS: ASYNCHRONOUS DISTRIBUTED DATAFLOW FOR ML
2022/05/06
著者: Google Research の方々
Proceedings of the 5th MLSys Conference, Santa Clara, CA, USA, 2022.
https://gyazo.com/f00f88460ec960009e082c461d4f9510
選んだ理由
バズ論文なので読んでおこうと
参考
どんなもの?
ブログを訳すと
既存モデルの問題点
1 つのタスクの 1 つのモデルを作っていた
既存モデルが学習したこと使わずに、フルスクラッチに学習するのは無駄が多い
1 つのデータは 1 つの意味しか持たない
leopard という言葉一つをとっても人間は複数の情景を想像する
dense なモデルは無駄が多い
人間は一つの問題を解くときに脳のすべてを活用しているわけではない
Pathways
多数のタスクに 1 つのモデルで取り組む
データを抽象的に扱い異なった意味で活用できる
データに応じてそのネットワークを利用するかを動的に決める(スパースに動かす)
と書かれているが、
新しいスケーリングと並列化の手法を構築したので、
これまで以上に巨大なパラメータで動かせるようになった
が技術的なキモのように思う
ブログの文章は微妙に学習の話と PaLM のモデル話が混じっている
某 raddit のコメント
https://gyazo.com/7370e808a5b38ebb3c39384bde29f770
この Pathways を Transformer で使ってみたよ!が PaLm
Palm 自体は Transformer の Decoder 部分のみを使ったモデル
先行研究と比べてどこがすごい?
既存の複数 GPU を使う方法として、NN のコンポーネントごとに分離して実行する方法がある
Pipeline Parallelism
https://gyazo.com/ca4e5d2381f4e19a64be255e4296fd74
1 段目の方法は愚直な方法だが、この場合 GPU をほぼ使えていないため、micro batch という概念を導入する
これは batch を更に分割して、GPU の使用率を上げる方法である
Pipeline 化の問題点
図中に Bubble と書かれているように、使用率が低い箇所が存在する
頻繁にメモリのリロードが発生してボトルネックになる
そこで Pipeline 化を必要としない手法を Pathways で提案し、PaLM で活用した
技術や手法のキモはどこ?
ざっくりいう並列処理しながら分散処理できる仕組みを用意したよ
Resource Manager
Pathways のシステムには Resource Manager がグローバルに単一で存在する
使用側は Client を用意しており、Client は下記図左のような中間表現としての計算グラフを作成する
https://gyazo.com/a43c831c024f64e62a04858aff9689e0
Gang-Scheduled dynamic dispatch
single program, multiple data (SPMD) 構造を実現し並列化を目指す
同一プログラムを異なるデータに対して実行する並列処理
このために Gang Scheduling をサポートすることが必要
実行に必要なリソースを最初にすべて確保しておき、スレッド間の通信オーバーヘッドをなくす
PATHWAYS では
コンパイルされた関数をそれぞれの処理装置に乗せておく(入力値分の buffer は用意しておく
(この辺先にモデル構造がが決まっていて allocate すべき量が事前にわかるからできるのかなぁ)
ネットワークに future output を出力するようにキューイングしておく
どの順番に関数を実行するかをスケジューラーと決定する
Parallel asynchronous dispatch
Host 側の allocation コストが既存手法の (a) だとボトルネックになりやすい
先に paralell で host 側の処理だけを終えておき、ボトルネックを解消する
https://gyazo.com/d7a06a2facc73ad2c322607ba43d0308
どうやって有効だと検証した?
Sequential と Parallel Dispatch の比較
Parallel の場合は Host でのオーバーヘッドを先に解消しているため、pipeline stage が増えるほど Sequential よりもパフォーマンスを出すことができている
https://gyazo.com/6a9ac5424a8b70ba3fed6abc3282dfbc
他にも Pathways の論文では JAX と比較をしていたが、正直図表の嬉しさがわからなかったので割愛・・:bow:
FLOPS utilization で比較した (PaLM の結果)
https://gyazo.com/39369e0aa2884a9582f918f736373791
既存モデルと比較して計算リソースをより効率的に活用できている
議論はある?
PATHWAYS はデザインか単なる実装か?
PATHWAYS は TPU を前提とした設計になっており、TPU カーネルありきの内容もある
が、PATHWAYS を作る上で行った意思決定は大規模 GPU 環境でも転用できるはずだ
リソース管理周り
今回は TPU 計算を効率的に多重化することを目的としていた
現実のより複雑な multi tenant で行う場合には、多様なリソースを管理できるようにする必要がある
データに基づくベクトル化された control flow
現在はモデル内の全重みを一度にすべて更新している
より効率的な学習のためには、データに基づいてノード間のデータをやり取りする必要がある
所感
TPU を基準としたアーキテクチャすごいという気持ち
分散処理基盤として面白そうで NN 以外にも活用の道はないのかな