Transformerを実装して少し理解した気になる会(その4)
#jam #Transformerを実装して少し理解した気になる会
4. 学習したTransformerで推論をする
Transformerの推論方法
Transformer自体はSequence-to-Sequenceモデルと呼ばれているもので、入力系列と同じ長さのデータを返す。
正確には、[バッチサイズ, 系列長] の入力に対して[バッチサイズ, 系列長, モデルの語彙数]のデータが返ってくる。
バッチサイズを省いた図
https://scrapbox.io/files/693e86f285613ddc34f05e39.png
ただ、実際に使うのは系列の中で一番最後の出力確率のみ。
モデルの語彙数は何を表しているか?
各語彙が次に来ると考えられる確率を表している。
その確率が高いものを次のトークンとして選び出す感じ。
次のトークンを選び出す方法は2つ
1. 最も確率が高いものを次のトークンとして用いる(Greedy Decoding: 貪欲法)
これでもいいが、同じ文章を繰り返しがちになってしまうらしい
2. 確率を元にサンプリングして次のトークンを選ぶ
今回はこれを実装してみる
確率を出す方法にも、いくつか工夫がある
temperature
温度と呼ばれているが、Transformerの最後の確率を抽出する際に掛けるSoftmaxに加工を施し、全要素をexpを通す前にT(temperature)で割る
$ \rm{Probability} = \frac{ \exp (x_i / T)}{\sum_j \exp (x_j / T)}
わかりやすい解説: https://qiita.com/nkriskeeic/items/db3b4b5e835e63a7f243
なぜ温度というか?: ボルツマン分布に由来するらしい
https://qiita.com/murakamixi/items/9b21328d67d0b1316674#ボルツマン分布とは
top-k
次の単語確率として挙がっている中で、確率が高いもの順にk個を選び、それらからサンプリングする形
top-p
次の単語確率として挙がっている中で、上位p%の確率を占めるものだけを選び、それらからサンプリングする形
今回は、簡単のためtemperatureとtop-kだけを実装してみる
とりあえずinfer自体は実装しておいたので、これを元に改変してみてください
https://github.com/y-chan/learn-and-make-slm/commit/170c8662afbeac53e6637d3233f2e7339587c7bf
ちなみに、確率的にサンプリングするにはどうすればいいか?
torch.multinomialを使ってください
https://chatgpt.com/share/693ea5f5-e074-8007-a808-a76f3204998e
KV Cache
推論時、Cacheを使うと効率が上げられる
本来、AttentionはQ[B H S D] @ K^T[B H D S] = QK[B H S S]の行列をつくり、K及びQの関連度を計算する
最終的にはこのシェイプにV[B H S D]をかけるので、QK[B H S S] @ V[B H S D] = ATTN[B H S D]となる
この、Qに関しては一番最後のものだけでも十分である
なぜか?: 一番最後のトークンと、そこまでのすべてのKとの関連度はQ[B H 1 D] @ K^T[B H D S] = QK[B H 1 S]で計算できているから
VはQKのS次元によって参照される(QK[B H 1 S] @ V[B H S D] = ATTN[B H 1 D])
よって、Qは一番最後のものしか必要がないが、KVは今までのすべての系列が必要。
QもKもVもAttentionに対する入力xから作られるが、これを一番新しいトークンだけを入力にし、KVはキャッシュしておけば、メモリ効率的にも計算効率的にも良い。
一応検算
code:python
>> import numpy as np
>> Q = np.random.randn(1, 1, 5, 8) # 系列長5、隠れ層のサイズ8とする
>> K = np.random.randn(1, 1, 5, 8)
>> V = np.random.randn(1, 1, 5, 8)
>> (Q @ K.mT) @ V # mTはtranspose(-1, -2)と同意
array([[[[-1.48512457e+00, 7.68810083e+00, 1.59719047e-02,
-1.63712342e+00, 2.32358763e+00, 6.32060038e+00,
-6.16095162e+00, -3.10746309e+00],
[-9.58711200e-01, -1.48961774e+01, -7.74099172e+00,
4.18099819e+00, -6.03822469e+00, -1.72024124e+01,
1.59750358e+01, 5.52561616e+00],
[-3.53787693e+00, -9.79928332e+00, -1.26630842e+01,
6.61701216e-01, -6.07633685e+00, -8.09025039e+00,
1.69200101e+01, 4.11580059e+00],
[-3.09115717e+00, -7.96822456e+00, -1.04828669e+01,
2.83218687e+00, -3.42227529e+00, -1.02565361e+01,
1.21468175e+01, 6.15714889e+00],
[ 6.53485610e+00, -3.03395125e+00, 1.41325512e+01,
1.48436159e+00, 7.40336971e-01, -6.77432953e+00,
-6.72438642e+00, -2.00797773e+00]]]])
>> ((Q @ K.mT) @ V):,:,-1:
array([[[[ 6.5348561 , -3.03395125, 14.13255116, 1.48436159,
0.74033697, -6.77432953, -6.72438642, -2.00797773]]]])
>> ((Q:,:,-1:, @ K.mT) @ V)
array([[[[ 6.5348561 , -3.03395125, 14.13255116, 1.48436159,
0.74033697, -6.77432953, -6.72438642, -2.00797773]]]])
>> np.isclose(((Q @ K.mT) @ V):,:,-1:, ((Q:,:,-1:, @ K.mT) @ V))
array([ True, True, True, True, True, True, True, True]) # 同じになった
https://zenn.dev/sre_holdings/articles/f15290860986ad
K と V は使いまわしが効く
(上の説明)
こういう
step 1: [K1, V1]
step 2: [K1, V1], [K2, V2]
step n: [K1, V2], [K2, V2] ... [K_n, V_n]
なので、step 終了時に KV を保存して、次の step で新規に計算した KV を concat していけば計算が減ってお得(計算効率がいい)
sushichan044.icon Hugging Face の KVCache に関する解説と実装
https://huggingface.co/docs/transformers/main/kv_cache
https://huggingface.co/docs/transformers/main/en/internal/generation_utils#transformers.StaticCache
https://github.com/huggingface/transformers/blob/40dc11cd3eb4126652aa41ef8272525affd4a636/src/transformers/cache_utils.py#L1004-L1080
https://github.com/huggingface/transformers/blob/40dc11cd3eb4126652aa41ef8272525affd4a636/src/transformers/cache_utils.py#L671-L892
sushichan044.icon 書いた
KVキャッシュモジュールの実装 https://github.com/y-chan/learn-and-make-slm/pull/17
KVキャッシュの統合 https://github.com/y-chan/learn-and-make-slm/pull/28
振る舞いは基本的に同じであり、データの保存場所や保存方法を工夫して学習シチュエーションごとに適したパフォーマンスの Cache 実装を使い分けられるようになっているだけ
KV Cacheの効果について
y-chan.icon こちらをご覧ください
https://y-chan.dev/blog/learn-and-make-slm/#実行速度
Extra Stage: より効率的かつ実用的な推論
そもそもPyTorchはそのまま使うと遅い
いろんなモジュールがあるからなのか?結構遅い気がする
モデルをJITコンパイルする方法もある
試してないけど結局巨大なPyTorchに依存するから大して変わらない気がする
じゃあPyTorchに依存しない方法はないのか?
ある、公式にOpen Neural Network eXchange(ONNX)という規格に変換する口が存在する
torch.onnx
ONNXに変換(export)する際、トレースと呼ばれる処理が行われ、演算や重みなどが抽出される。その際、if文による分岐は無効化され(トレース時に通った経路のみ使えるようになる)、同様の理由でfor文も使えるが回数が固定になる。
また、様々な演算ごとにOperatorが存在し、内部で効率的に演算を行える。トレースだけでは検出できないOperatorは、その演算を利用するように定義できる。
torch.autograd.Functionの機能、symbolicメソッドを書けば、利用する演算・引数などを指定できる。
例えば、Embeddingモジュールの場合
code:python
class EmbeddingFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, indices: Tensor, weight: Tensor) -> Tensor:
return weightindices
@staticmethod
def symbolic(g, indices: Tensor, weight: Tensor):
# ONNXのGatherオペレータを使用
# axis=0 (行方向でgather)
return g.op("Gather", weight, indices, axis_i=0)
class Embedding(nn.Module):
def __init__(self, num_embeddings: int, embedding_dim: int):
super().__init__()
self.num_embeddings = num_embeddings
self.embedding_dim = embedding_dim
self.w = nn.Parameter(nonzero_randn(num_embeddings, embedding_dim) * (embedding_dim**-0.5))
def forward(self, x: Tensor) -> Tensor:
# ONNX Export時、組み込みオペレータを使用するための処理
if torch.onnx.is_in_onnx_export():
return EmbeddingFunction.apply(x, self.w)
else:
return EmbeddingFunction.forward(None, x, self.w)
本来であれば、Embeddingモジュールの中でEmbeddingFunction.apply(x, self.w)すれば十分である
ただ、backward(勾配計算のための処理)を自前実装しなければならず、実装負荷と学習負荷が高くなる。
そのため、学習時はforward関数を直接呼び出す(本当はダメだがctxにNoneを渡す)ことで、学習時はPyTorch自動微分の恩恵を受けられるが、ONNX Export時はOperatorを指定できるようにしている。
なぜctx=Noneがダメか?
Sigmoidに関しては自前実装したのでそれでわかる
code:python
class SigmoidFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, x: torch.Tensor) -> torch.Tensor:
y = 1.0 / (1.0 + (-x).exp())
ctx.save_for_backward(y)
return y
@staticmethod
def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
y = ctx.saved_tensors0
return grad_output * y * (1.0 - y)
@staticmethod
def symbolic(g, x: torch.Tensor):
return g.op("Sigmoid", x)
sigmoid = SigmoidFunction.apply
本来であればctxにはbackwardのために計算結果を保存するから、forwardにNoneを渡せばsave_for_backwardなんてないぜ!って言われる
RotaryEmbedding/Attentionに関しても、ONNXのバージョン23から専用のオペレータが用意されるようになったので、それらを活用する。
Attentionに関しては、外挿型のKV Cacheに対応しているので、それらもひっくるめた対応を行った
具体的には、2つある
1. past_key/past_valueを入力(外挿)できるようにし、present_key/present_valueを返せるようにした
2. past_key/past_valueがない場合(初期入力の場合)でも、ONNXモデルは入力が必要である。ただし、いずれかの次元がサイズ0のものを入力することができない。なので、系列長が1の0初期化されたテンソルをkey/valueとして入力し、ONNXモデル内では系列の先頭を無視する(-infのAttention Maskをかける)ようにした
ただし、ONNX Runtime(ONNX公式のONNXモデル実行環境)のCUDA実行プロバイダは同ONNX バージョン23にあるSqueezeオペレータに対応していない(古いバージョンのSqueezeオペレータには対応している)
なので、CUDA向けにはバージョン17程度まで下げてexportし、それに合わせて17では対応していないRotaryEmbedding/Attentionオペレータは素の処理をトレースさせるようにしている。
cudaを使ってexportしたときにそうなるようにしている(テンソルのデバイスタイプを見て判断する荒業)
code:python
if torch.onnx.is_in_onnx_export():
if Q.device.type != "cuda":
output, present_key, present_value = ScaledDotProductAttentionFunction.apply(
Q, K, V, int(K.size(-3)), int(Q.size(-3)), self.scale.item(), past_key, past_value, seq_lens
)
else:
# CUDA EP向けにexportする場合はAttention Opのないversion 17でexportされるのでforwardをトレースさせる
output, present_key, present_value = ScaledDotProductAttentionFunction.forward(
None, Q, K, V, int(K.size(-3)), int(Q.size(-3)), self.scale.item(), past_key, past_value, seq_lens
)
実際ONNXは早いのか?
比較してみよう(GPT-OSS再現のモデル/RTX3090)
Flash AttentionとKVキャッシュをちゃんと使ってるCUDA実行
https://scrapbox.io/files/695a6b82a981db403f067e42.png
外挿KVキャッシュ利用のONNX化(CUDAバージョン)のCUDA実行
https://scrapbox.io/files/695a6ba21006662fb7748318.png
y-chan.icon 速度2.5倍で草