mlx_parallm
If you’re on a CUDA machine, you’d use something like vLLM, which is a more “production-grade” solution for achieving high tok/s throughput with parallel requests, but it doesn’t work on a Mac.
implemented batched KV caching in MLX for fast parallel LLM inference on Apple devices.
batched KV cachingによって高速化を実現したらしい。
元々KVキャッシュはあって、それをバッチで処理するようになったのがアツいポイント。
https://scrapbox.io/files/668752deb15053001c512fff.png
普通LLMではトークン(t1,t2,t3)毎に入力して出力(t4)を得る
次にt1,t2,t3,t4を順に入力してt5を得る...といった形で出力(t4...)を得る
でも次のトークン得るために毎回計算するのめんどくない?
そのためiter1を計算したときに次のiterのために保存しておきます。transformersライブラリなどでは保存していた計算結果を利用して実際にはiter2*のように計算することで計算コストを削減しています。ここでは保存した計算結果を KV cache と呼びます。
ここからは最初のiter1を prefill フェーズ、iter2*, iter3を decode フェーズと呼びます。
https://scrapbox.io/files/668751c3ce9ed4001cfeda77.png
クライアントからリクエストが来たら prefill フェーズのみを計算しそれぞれの KV cache を保存してキューに追加します。キューにある KV cache をバッチにまとめて decode フェーズを計算して1トークンだけ生成して KV cache を更新しキューに戻します。生成が終了した KV cache はキューに戻さずにクライアントに生成結果を送信します。これを繰り返してテキストを生成します。
Text Generation InferenceやvLLMはContinuous batchingをサポートしています。 元々batching algorithmはあり、vLLM等でもKV Cache使う手法が実装されていて、そういうのをmlxにも作りましたって感じかな。 Gotcha. So if you use different prompt, the speed would be a lot lower since most of the gain comes from batching kv cache?
Yes. With completely different prompts there is no gain.
なるほど?mlx_parallmの場合同一プロンプトをスケールさせる用途が良い?
---
追記。
プレフィル フェーズでは、LLM は入力トークンを処理して中間状態 (キーと値) を計算します。これは、「最初の」新しいトークンを生成するために使用されます。
デコード フェーズでは、LLM は停止基準が満たされるまで、出力トークンを 1 つずつ自己回帰的に生成します。各順次出力トークンは、以前のすべてのイテレーションの出力状態 (キーと値) を知っている必要があります。
つまり、プレフィルフェーズとは入力(ユーザからのプロンプト、例えば「こんにちは、今日の運勢を教えてください」)をまず入れると、諸々内部計算されて最初のトークン(例えば「今日」)が出る。この出るまでがプレフィルフェーズ。
その後、出力した最初のトークンを自己回帰的に入力して次の出力、次の出力...を出す。これがデコードフェーズ。
図的にはこれがわかりやすそう。
出力の長さは、計算順序依存性が強い
これは自己回帰的だから。
入力の長さは、計算順序依存性が少ない
プレフィルフェーズ。このフェーズ以降をキャッシュとして持つのがKVキャッシュ。
この図でいうと横の緑がひとりってイメージか?
バッチで複数人(緑の帯)を投入して、それらを同時的に処理する。
https://scrapbox.io/files/66aa59674be692001c0852d7.png
モデルの重み: メモリはモデル パラメーターによって占有されます。例として、70 億個のパラメーター (たとえば、Llama 2 7B)を、16 ビット精度 (FP16 または BF16) でロードすると、メモリにおよそ 7B * sizeof(FP16) ~= 14 GB が必要になります。
KV キャッシュ: メモリは、冗長な計算を避けるためにセルフアテンション テンソルのキャッシュによって占有されます。
なるほど?MLP層は固定。だけどKVキャッシュは入力されるたびに値が更新されていく。
なのでここの計算を持っておくのはアツい。
大規模言語モデル(LLM)は、2ステップのプロセスでテキストを生成します:入力プロンプトのトークンが並列で処理される「prefill」と、自己回帰的な方法で一度に一つのトークンが生成され、テキストが生成される「decoding」です。生成されるそれぞれのトークンは入力に追加され、モデルが次のトークンを生成できるようにフィードバックされます。LLMが特殊なストップトークンを出力するか、ユーザーが定義した条件(最大トークン数が生成された場合など)を満たした場合に生成が停止します。
Time To First Token (TTFT)
最初にユーザがトークン出力を得る速さ。
Time Per Output Token (TPOT)
次のトークン出力を得る速さ。
レーテンシー
全ての出力を得るのに要する時間。
スループット
全ユーザ及びリクエストに対するtokens/s。
バッチではこれが上がるから結果としてtpsが上がるイメージか。
トランスフォーマー固有の数多くの重要な最適化があります。この主要な例はKV(キーバリュー)キャッシングです。デコーダーのみのトランスフォーマーベースのモデルにおけるアテンションのメカニズムは計算処理的に非効率です。それぞれのトークンは以前に出現したすべてのトークンを処理するため、新たなトークンが生成される都度、数多くの同じ値を再計算します。例えば、N番目のトークンを生成する際には、(N-1)番目のトークンは(N-2)番目、(N-3)番目…最初のトークンを処理します。同様に、(N+1)番目のトークンを生成する際には、N番目のトークンに対するアテンションは、 (N-1)番目、(N-2)番目、(N-3)番目…最初のトークンを再度参照する必要があります。KVキャッシング、すなわち、アテンションレイヤーに対する中間的なキー/バリューの保存は、繰り返しの計算を避け、あとで再利用できるようにこれらの結果を保持するために活用されます。
KVキャッシュ。
バッチサイズと入出力ニューロン数は2^Nのサイズで使用することが推奨されます。しばしば8の倍数ですが、使用するハードウェアやモデルのdtypeに応じて、より高い数になることもあります
つまり、LLMのランタイムに普通にディープラーニングみたいにバッチ機能がある。
バッチ処理は、状況を改善する 1 つの方法です。入力シーケンスがあるたびに新しいモデル パラメータをロードする代わりに、モデル パラメータを 1 回ロードして、それを使用して多くの入力シーケンスを処理できます。これにより、チップのメモリ帯域幅がより効率的に使用されるため、コンピューティングの使用率とスループットが向上し、LLM 推論のコストが削減されます。
In-flight Batching
効率的な推論には大きなBatchで動作させることが欠かせないが、LLMではBatch毎に動作するシーケンス長が異なってしまう。そのため全てのBatchが動作完了するまで待っているとGPUを有効に活用することができない。
そこでIn-flight batchingでは計算終了したBatchに次のBatchを与える。これによってGPU使用率を常に高く維持することができ、時間当たりのLLM処理トークン数を最大化できる重要な技術となる。
このような形でバッチを与える前提で、効率的に捌けるバッチは強い。
(シーケンス長でメチャ長いプロンプトが一個入ってるとそれに引きずられてしまう)
ですです!!
50回分のプロンプトを一気に投げて、一気に受け取るようなバッチ処理ができます! プロンプトの数が増えても、生成にかかる時間が線形的に増えていかないのがすげえポイントです
PagedAttention登場以前の従来の並列生成はKVキャッシュとよばれる”リクエストごとに発生する大きなGPUメモリ消費”との戦いでした。
(KVキャッシュは transfomerのmodelを生で叩くときに past_key_values として登場します)
つまりモデルのパラメータとは別に発生する推論時のメモリ消費です。
これが同時に捌けるリクエスト数の限界を決めており、リクエスト数の限界=推論時消費メモリがGPUの搭載メモリを超えないギリギリのラインで、この限界に達する前に、生成リクエストを別のGPUノードに負荷を転嫁する必要があります。このリクエスト数限界を押し上げることが1リクエスト当たり(もしくは1トークンあたり)の推論コストを下げることにつながります。
なるほど!
ここまでまとめると、LLMは入力トークン全部一気に処理するプレフィルとシーケンシャルなデコーディングがある。
プレフィル以降は毎回KVキャッシュする。これは計算のたびにメモリにのる。
vLLMとか色々バッチで一気に複数入力受け入れることができるものあるけど、バッチやりすぎるとメモリ逼迫するので工夫が必要。
後、バッチの中に長い入力あったりするとそれにひきずられたりするので、そこにも工夫が必要。
なので上記リンクにあったStatic batchingはこれバッチでうけいれられるの前提で、プレフィルも全部終わるの待って、そのあと生成も全部待つ。
素朴なバッチ処理。
DynamicとかContinuousはそれを効率化した工夫。
A toy visualization of how the "infinite KV cache" in MLX LM works.
Some notes:
- You always keep the first n tokens (n=1 in this toy example and n=4 in MLX LM)
- The maximum cache size is 4 in this toy example (specify it with --max-kv-size in MLX LM)
- If you follow the arrows from t=1 to t=7, you see that the last token still depends causally on all the initial tokens. But the dependence is mediated through other tokens instead of direct. That's why there is some accuracy loss.
- It uses a circular buffer which is really nice because it's super efficient. You don't have to copy the whole cache, just overwrite one spot. The invariance of self-attention to the order of inputs is a feature here.