Transformerを実装して少し理解した気になる会(その2)
#jam #Transformerを実装して少し理解した気になる会
2. Transformerの重要なモジュールについて理解して実装する
Positional Encoding
https://scrapbox.io/files/690d77a562c837a7f30f44d2.png
(引用: Attention Is All You Need)
Transformerは後の層でSelf-Attentionといって、入力された単語・トークンの、お互いの関係性を確率で表したような行列を作り出す
ただし、この行列は位置関係を持てない
行列演算(特に内積・外積)は順序に依存しないため、入力の並びが入れ替わっても同じ結果になる可能性がある
位置情報を認識させるにはわざわざ埋め込まなければならない
そこででてくるのがPositional Encoding
sin/cosの波で埋め込むことの利点
波形を使うことで「距離」や「周期的な関係」をモデルが学びやすくなる
任意の長さの文に拡張しやすい
例えば、インデックスを渡していくと、文長がアホほど長くなったときにアホみたいにデカい数値を埋め込むことになるので、学習が発散してしまう
sin/cosであれば値域が[-1, 1]になることは自明
sushichan044.icon いい話だ
helkun.iconいい話だ2
sin/cosの組み合わせであれば2次元空間的な位置を表せて、sin/cosの片方だけでは表せない方向ベクトルになる
周期性だけだと位置を学ぶのは難しい
Scaled Dot-Product Attention
https://scrapbox.io/files/691035ea2154031cd868fa95.png
https://scrapbox.io/files/69103625d9f7ff67ac11b862.png
式的にはこうなる
Q/K/Vの意味(by ChatGPT)
Q: いま「知りたいこと」
sushichan044.icon 検索クエリ
K: 各単語の「持っている情報のラベル」
helkun.iconkey
helkun.icon知りたいことが含まれた情報の集合...?
V:その単語が「実際に持つ内容」
helkun.iconvalue
Q(Query)とK(Key)の内積(Dot-Product)を計算し、類似度を計算する
sqrt(d_k)で割る(Scale)
Q/Kの次元数が多いとsoftmax後に勾配が小さくなるので、その前にScaleすることで勾配を伝播しやすくする
y-chan.icon なんか直感に反する?掛けるならまだわかる気がするけど
sushichan044.icon 割ると計算したあとの各要素の絶対値が小さくなる
y-chan.icon 絶対値が大きいのが混じるとほぼすべての値が0/1に振り切ってしまい、勾配が小さくなるから、絶対値を下げることが、勾配を残すために有効な手段となる。なるほど
helkun.icon絶対値が大きい → 勾配が緩やかになる(小さくなる) → 絶対値を下げたい → sqrt(d_k)みたいに次元数みたいなので割れば嬉しくなる...?
sushichan044.icon 適当な 0 以上の定数で各要素を割るとたしかに softmax 後の値の絶対値が小さい
https://scrapbox.io/files/691047ac4c9fd173666a0b55.jpeg
softmaxをかけて、類似度を確率に落とし込む
最後、その類似度を重みとしてVと内積を取り、最終的なアテンションの結果を得る
(2025/12/24追記)Maskについて
Mak (opt.)って書いてあるやつ
学習データ内のpaddingなど、学習に不要なものはmaskをかけて、Attentionで計算しないようにする
それに加えて、系列の途中であっても、系列の最後を予測するのと同等の効果を得られるように、系列の途中で未来の系列を見ないようにする「上三角行列によるmask」がある
因果マスク(Causal Mask)と呼ばれることもある
y-chan.icon KV Cacheの実装の際に、これを正しく理解できていない事に気づいた、マジでごめん、下のは頑張って理解した図
https://scrapbox.io/files/694ae7211b0f0343f46f24c3.png
Multi-Head Attention
https://scrapbox.io/files/69103cd70cbee435275a022d.png
複数のheadをつくることで、人間が物事を複数の視点から見るのと近い感じにできる
内部でやってる処理はScaled Dot-Product Attentionなので、特に説明することなし
なんでV K Qの順序になっているのか?
y-chan.icon 知らん、なんでだこれ
(Position-wise) Feed-forward Network
https://scrapbox.io/files/690d77594dc3e23e483f6167.png
max(0, x)はReLUとも呼ばれている
非線形関数と呼ばれているもの、ReLUもSoftmaxと同じく活性化関数である
x W_1はなど内積ではなく外積
ただし、ReLUはもう古代のものなのでこれと違うやつを使います
SwiGLU
GPT OSSとかで使われているもの、SwishとGLU(Gated Linear Unit)の合体系
なんかここがわかりやすい解説してそう
https://jcarlosroldan.com/post/348
ReLUについてグラフ化(以降のグラフは上記サイトから引用)
https://scrapbox.io/files/691046af3ed850226e27154a.png
欠点: 入力値が0未満の場合に、勾配が消失する
Swishについてグラフ化
https://scrapbox.io/files/691047207cd288ac1681af2e.png
式も出すと
Swish(x) = x * sigmoid(beta * x)
betaは学習可能なパラメータだが、別に1で固定している場合もある
入力値が0未満でも、勾配消失しない
GLUについてグラフ化
https://scrapbox.io/files/691047f0e3c285547efef8d9.png
式: GLU(x) = (Wx+b)*sigmoid(Vx+c)
W, V, b, cは学習可能なパラメータ
ちなみに、Wx+b + Vx+cはLinearと同じ処理
SwiGLU
https://scrapbox.io/files/691048f8f38aa2a97864f62d.png
式: SwiGLU(x) = (Wx+b)*Swish(Vx+c)
更に自由度が高い活性化関数となった
なぜ動くかは説明できないらしい(from SwiGLU paper)
y-chan.icon 神の慈悲とか言ってて終わった
helkun.iconラマヌジャン
sushichan044.icon 学習 dev tool が求められている
というわけで、以降はFFN = SwiGLU(x)とする
以上でモジュールは全網羅した(一応)
まとめると、
Positional Encodingで位置情報を与えて
Multi-Head Attentionで多角的に情報に捉える
それぞれの情報をScaled Dot-Product Attentionで捉える
Feed-forward Networkで、Attentionの情報を更に加工して最終的な出力につなげる
Extra stage: 昨今のTransformerの拡張について知って実装する
2つぐらいある
Rotary Position Embedding(RoPE)
Grouped Query Attention
(Flash Attention)
他にもMoE(Mixture of Experts)とかあるんだけど今回は無視
多分そこまでの性能いらない
Rotary Position Embedding(RoPE)
Roformerというペーパーで提案された位置埋め込み方式
https://alphaxiv.org/pdf/2104.09864
sushichan044.icon これをロングコンテキスト対応にする YaRN を含め、この記事が最高にわかりやすい
https://tech-blog.abeja.asia/entry/advent-2025-day10
これまでの位置埋込み
絶対位置埋め込み(Positional Encoding)
Transformerの原論文で提案されている手法
絶対的な位置しか見ないので、相対的な位置や情報が重視されない
ds(繰り返しとか)はデータドリブンで学習して得るしかない
相対位置埋め込み(Relative Positional Encoding)
Transformer-XLなどの論文で採用されている方式
絶対位置埋め込みの欠点を補う
相対的な位置関係の獲得においてデータドリブンを解消
実装が複雑になり、後述する計算・メモリ最適化(Flash Attentionなど)が利用できなくなる
Attentionまで改変しなければならない
RoPEが叶えたこと
位置情報を「回転行列」を乗算することで付与する。
PEは加算での付与だった
系列長的な制約を減らせる上に、相対的な位置が離れるごとに内積が減衰していくので、自然言語的な特性を表現可能に
相対位置埋め込みと比べてとても簡単な実装になった
やっていること
今までの位置埋め込みは文字埋め込みのあとに加算を行っていた
これをやめて、Attention手前のLinearを通った文字埋め込みに乗算する
https://scrapbox.io/files/69275497beea9deb02faf7f2.png
RoFormerでは回転位置埋め込みはKeyとQueryにだけ適用する
https://scrapbox.io/files/692755b4b0b754f018182076.png
なぜ?
回転行列による位置埋め込み自体は、絶対位置を埋め込んでいる事になっている
が、内積を取るとこれが相対位置としての意味合いを持つ
「相対的な位置が離れるごとに内積が減衰していく」点と関連しているっぽい
y-chan.icon 実際計算してみるといいかも、めんどくさくてまだやってない
逆に言えば一度内積を取ってしまえばそれ以上は情報をややこしくいじるだけになってしまって無意味なので、Valueには回転位置埋め込みを適用しないほうが理論的できれいである
y-chan.icon と考えるのが良さそう
本当に実装すべきものはこれ
正方行列を計算するのはO(n^2)かかってしまうので、等価な計算を用意
https://scrapbox.io/files/692759ba6ca218944037ace4.png
sushichan044.icon 素朴に行列積を計算するとほとんど zero であることが明らかなので、結果が 0 ではない計算に絞って実施するのは妥当そう
右辺第一項は入力xそのままとcos行列でクロネッカー積?を取る
y-chan.icon 記号的にはクロネッカー積らしいんだよなぁ、わからん
いろんな再現実装を見る限りアダマール積っぽいんだけど
helkun.icon$ 1\times d^2の行列がが爆誕..?
クロネッカー積は嘘らしい
sushichan044.icon クロネッカー積ではありそうだがラマヌジャンなんだよなぁ
$ 1 \times d^2 次元の行列ができるはず...
違うわ、次元が変わっているのでクロネッカー積なわけないわ
趣旨に反する
アダマール積(要素積)説が濃厚
https://scrapbox.io/files/6927f3bfcc08bee56199d754.jpeg
右辺第二項は入力xを1要素ごとに反転させ、符号を逆転させたものとsin行列とのクロネッカー積
ちなみに要素の反転と符号逆転をする行列はこうやって作れる
code:python
def rotate_half(x: Tensor) -> Tensor:
# x: dim
x = x.reshape(*x.shape:-1, -1, 2) # dim/2, 2
x1, x2 = x0, x1
rotated = torch.stack((-x2, x1), dim=-1) # dim/2, 2
return rotated.reshape(*x.shape:-2, -1) # dim
a = torch.arange(10) # tensor(0, 1, 2, 3, 4, 5, 6, 7, 8, 9)
b = rotate_half(a) # tensor(-1, 0, -3, 2, -5, 4, -7, 6, -9, 8)
mに関して
時間を表すもの、PEにおけるpos
thetaに関して
こうなるらしい、10000って数についてはPEを参考にしているが根拠はないっぽい?
https://scrapbox.io/files/69275d1377c6c6852961c01e.png
PyTorchで書くと
code:python
base = 10000
thetas = 1. / (base ** (torch.arange(0, dim, 2):(dim // 2).float() / dim))
あとは時間軸に伸ばす必要がある
時間軸方向の横ベクトルと、thetasの縦ベクトルで外積を取るといい感じになる
torch.outerといういい感じの関数があります
こいつもラマヌジャン(?)
なぜうまく動くかは説明できないらしい
helkun.icon機械学習あるある:ラマヌジャン
Grouped Query Attention(GQA)
https://alphaxiv.org/pdf/2305.13245
sushichan044.icon 便利サービスこと alphaxiv
y-chan.icon deepwikiと似たような感じ、なぜこれが無料で提供されているのかわからないでおなじみ
https://scrapbox.io/files/6926bb60817c43538a22920c.png
既存手法の解説から
Multi Head Attention(画像左)
一番最初の原型、メモリをめちゃめちゃ食う
昨今のLLMは推論に当たってKVをキャッシュする挙動がよく見られるが、メモリ帯域が問題に。
Multi Query Attention(画像右)
メモリ帯域が問題なら情報を減らせばいいじゃん、という発想
なので、複数あるHeadのKVを1つにまとめてしまって、全Headで同じKVを参照するようにすればメモリも計算量も減らせていいんじゃね?っていう提案
性能がそこそこ落ちる(それはそう)
Grouped Query Attention(画像真ん中)
MQAは流石にやり過ぎなので、一定数のHeadだけで同じKVを参照する形にした
MHAよりは効率的
性能低下をある程度避けて、MQAと同等の推論速度が出たとのこと
y-chan.icon MQAはメモリ帯域を削減しまくったけど、GQAは帯域がギリギリ詰まらない程度にメモリを使うって感じなのかな
sushichan044.icon real-world だと N head で共有するの N はどのくらいなのか
y-chan.icon 2か4かなぁ
実装は驚くほどシンプル
1. KVのLinearのout channelを減らす(2つのHeadをまとめるとすれば1/2)
2. Linearのout channelが減った分、expandしてreshapeすればいい
ちょっとだけ頭を捻ってshapeパズルをしないといけない
使うのはunsqueeze、expand、reshape
y-chan.icon こんにちは、Shape合わせ芸人と申します
Flash Attention
Flash Attention-2
https://alphaxiv.org/pdf/2307.08691
これは今回は実装しません(CUDAを自力で書く必要があるため)
本来、AttentionはO(n^2)の計算量・メモリ量が必要
これは学習に置いても推論に置いてもかなり欠点
これをやめたい
なんとメモリをO(n)で計算する方法が存在するらしい?
もともと、Metaがメモリ効率を上げてO(n)もしくはO(n log n)にするという手法を作っていた
それに影響されているのかは定かではないが、Flash Attentionがその後に登場
MetaのライブラリもFlash Attentionを統合したという流れっぽい
仕組みはわからん
簡単にまとめてもらった(by Claude Sonnet 4.5)
Flash Attention(初代)の貢献
ブロック分割(タイリング)とGPUメモリ階層の活用
通常: 全体のQK^T行列(n×n)を一気に計算 → HBM(High Bindwidth Memory)に保存
FlashAttention: 小さなブロックずつ計算 → SRAM内で完結
SRAMは高速らしい(小容量だが19TB/sぐらい?HBMは大容量だが2TB/sぐらいらしい)
オンラインSoftmax + 再計算
Softmaxを全体に対して計算せず、ブロックごとに計算して統計量だけを保持
逆伝播の際に必要な情報(全体のSoftmax?)を再計算する
Flash Attention-2の貢献
Matmul演算の有効活用(GPU演算器の活用)
高速化として貢献
系列長方向への並列化
ワープ間の作業分割の最適化
ワープ: GPUのスレッドをグループ化したものらしい
helkun.iconkernelをかいた!?!?!?
CUDAを描かないといけない理由はこれらしい
y-chan.icon Flash Attention 3 は H100 とかじゃないと効果がないらしい
sushichan044.icon 草
今回学習に使うGPUがRTX3090とコンシューマ向けのちょっと古いGPUなので、Flash Attentionを導入してあげると学習効率を結構ちゃんと上げられるはず
xformersというMetaのライブラリが良さげっぽいので使います
PyTorchにも含まれているが、maskが使えないっぽいので却下
実装するぞ!!
helkun.iconうおお!!
最終的に実装する物の形を書いておきます。
Scaled Dot-Product Attentionはそのまま
https://scrapbox.io/files/693e82ddcdc8d6fc62927e9c.pnghttps://scrapbox.io/files/693e58c246ae6dcd34f7008b.png
Softmaxは入れない
これはLossを計算するときに使うCrossEntropy側にSoftmaxと同等の処理を行うため
推論時は必要になる
https://qiita.com/ground0state/items/8933f9ef54d6cd005a69
LayerNormはPostではなく、Preでやる
詳しい話はここにかいてある
https://zenn.dev/shot4410/articles/5354ce65907e15#構造(post-ln-%2F-pre-ln)
Post-LNでは層の数に従って指数的に勾配が減少する、いわゆる勾配消失の問題が存在する
https://proceedings.mlr.press/v119/xiong20b.html
トピック3は次のページへ
Transformerを実装して少し理解した気になる会(その3)