Attention Is All You Need
2022.2.21 Asuka
year : #2017
tag : #Attention #Transformer
Paper
どんなものか ( 要約、システムのI/Oなど)
Transformerの原論文
元々はNLPの翻訳タスクを解くモデル
RNNやCNNを一切使わず、Attentionのみを利用したEncoder-Decoder modelを構築した。
各レイヤー内における逐次処理の必要がなくなったため学習が大幅に高速化した。一方、出力の品質は非常に優れている。
TransformerのアーキテクチャはNLPだとBERTやGPT-2, CVだとVQ-GANやDALL-Eといったモデルに使われている。
先行研究と比べてすごいところ、貢献
既存の問題点:LSTMやGRUといった再帰的推論を行うモデルは、逐次処理を必要としたため学習の並列化が困難であり、学習コストが高く、Token間の距離が大きくなると依存関係の表現が困難となる。Attentionは入力または出力内のTokenの距離に関係なく各Token間の依存関係を表現することができるが、(当時は)RNNと組み合わせて利用されるために学習のコストは依然として高いままである。
貢献:RNNやCNNを一切使わず、Attentionのみを利用したEncoder-Decoder modelを構築したことにより、学習の大部分を並列化することに成功した。
既存の問題点:逐次処理を減らす目的で開発されたExtended Neural GPU, ByteNet, ConvS2SといったCNNベースモデルは学習の並列化が可能であるものの、I/Oシーケンスが大きくなるにつれ計算量がO(N), またはO(logN)で大きくなるため離れた位置間の依存関係を学習することが困難である。
貢献:Transformer内のSelf-Attention層では、離れた距離間の依存関係をシーケンスの長さによらずO(1)で計算することができる。
手法
原論文では、翻訳タスクを解いている。
アーキテクチャ
1: Summary
https://gyazo.com/caac4fed59e12bd9e80628a54ed96d99
TransformerはInput, Encoder, Decoder, Linear&Softmax, の4段階に分けられる。
以下ではわかりやすさのためにInputの詳細はEncoderの節で解説を行う。
Encoder-Decoderモデルとは、InputをEncoderで一旦特徴量を抽出した潜在表現zに落とし、それをDecoderでターゲットに変換するモデル。
https://gyazo.com/d232db5d769104b22fe93179b3d60a65
Encoder Input: 原文
Encoder Outuput: 原文の潜在表現
Decoder Input: なし or Decoder Outputで得られたToken列
各Decoderの途中でEncoder OutputがInputされる
Decoder Output: 翻訳後のToken(列)
Encoder, Decoderがそれぞれ6層スタッキングされている。Encoderの最終層のOutputで得られた表現を、Decoderの各層に入力している。
Transformer内では3か所(正確には3*Enc, Decの6層=18)で機能の異なるAttention層が働いている
Encoder Self Attention Layer
入力文内のAttentionを計算
Encoder-Decoder Attention Layer
入力文で計算したAttentionを出力文の生成に反映する
Decoder Masked Attention Layer
Decoderの再帰的出力のAttentionを計算
2: Encoder
Summary
https://gyazo.com/4e36c4b0132f02cce44274ed42b61710
Encoderは、Self-Attention層とFeed Forward層のサブレイヤーに分かれる。さらにそれぞれのサブレイヤーの後には、サブレイヤー入力前にキャッシュした残差のAddとNormalizationを行うAdd&Norm層が続いている。
このEncoderが原論文では6層(Encoder#0~Encoder#5)スタッキングされている。
Self-Attention
そもそもAttentionとは?
入力されたデータのどこに注目すべきか、動的に特定する仕組み。
Self-Attentionはあるデータを処理する際に注意の先をデータ自身にしている。
これにより、現在処理中の単語に対して他の単語の表現を組み込んだ計算が可能となる
Attentionを行う手法はさまざまだが、TransformerではQuery, Key, Valueという3つのベクトルで計算されるScaled-Dot-Product-Attentionという手法をとっている。
Q, K, Vについて
(かなり大雑把な説明だが)Q=ある単語, K=ある単語の注意されるベクトル, V=単語の潜在表現となる本体
Self-Attentionの仕組み
https://gyazo.com/404e23a5b2e6e3a022c261e84e7df7c5
1: Inputの文章Xをx1...xnとしてToken列化する(xiは512次元のベクトル)
これ以降は説明のためx1("Thinking")の処理についてを取り扱う。実際にはxは行列なので同時処理されている。
2: x1に対してそれぞれq1, k1, v1を計算する
qi = xi * Wq, ki = xi * Wk, vi = xi * Wv
q, k, v は64次元に固定される(dq, dk, dv = 64)
Wは学習可能な重み
3: s1 ={q1・k1, q1・k2, ...q1・kn}を計算する
処理対象のQueryと、入力されたすべてのTokenのKeyの内積を算出し、処理対象がどのTokenにAttentionを向ければいいのかをScore化する
4: s1 = s1 / 8 を計算する
ScoreをKeyベクトルの次元の平方根(sqrt(dk))で割り算することで正規化
5: 2~4を経たs1をSoftmax
6: z1 = v1*s1を計算する
s1(Attentionの情報)とv1(単語の潜在表現となる本体)を乗算することで、注目したい単語の値をそのままにしておき無関係な単語の情報をそぎ落とすことができる
以上の操作を行列形式で行う。(1, 4, 5は同じ操作なので省略)
2: X(shape=n, 512) * Wq,k,v = Q, K, VとしてQuery, Key, Valueを計算
https://gyazo.com/bcebcd12f39c40151b837d37739166a1
3: QueryとKeyの転置行列を乗算しS={s1, s2...sn}を算出
6: Z = Softmax(S/sqrt(dk))*V
まとめ
単語のEmbeddingにWQ, WK, WV(学習可能な重み)を乗算することでQuery, Key, Valueを作る
ある単語のQueryに対してすべての単語のKeyの内積をとってあげた値ScoreがSelf-Attentionとなる
ScoreにたいしてValueを乗算することで、Attentionを向ける必要のない情報をそぎ落とす
つまりScaled Dot-Product Attentionとは
https://gyazo.com/9c673bbbc75b5e24097ac49d7ee6a82e
Multihead-Attention
Multiheadってなに?
上述したSelf-Attention層を8つ並列処理させることで、「どこに注意を向ければいいのか」という情報も8種類獲得できるようにした。
入力Xが8つのAttention Headに流されるためMultihead-Attentionというネーミング
https://gyazo.com/af64aaa4751dfefcc8f2c80cee1c301f
各Attention Headはそれぞれ独立した重みを持っているため、計算されるQ, K, Vも異なる
Multihead-Attentionについて
https://gyazo.com/2406bc1f7473e015c0cbc71ea3f17925
XがAttention-Head#0~#7を通過して算出されたZ0...Z7をconcatし、学習可能な重みWOと乗算することでshapeを(n, 512)に戻す
https://gyazo.com/cbaf8f7c44e979e584f8aa9590698071
Positional Encoding
Positional Encodingって?
上述した入力X={x1, x2...xn}は、単にTokenが行方向に並んでいるだけなので、このままではTransformer内で各Tokenの位置情報が扱われないままになってしまう。
そこで、入力されたxiにたいして位置情報を表すPosition Encoding Vector(shape=, 512(Embeddingの次元数と同じ))を加算することで入力系列の単語の順序を表現するのが、Positional Encoding
これによって、入力系列内で近い単語ほどAttentionが向けられやすかったり(逆も然り)、主語が存在する確率の高い文頭にAttentionが向けられやすいといった、「単語間の距離が持つ意味」を学習することができる。
さらに、この方法を使うと長さどこまで続くか不明な系列に対しても位置情報を付与することが可能
Position Encoding Vectorの計算方法
https://gyazo.com/d59c7b0dfb5bebaf0cee5a6a7f3d57e4
各Tokenに対してsin波とcos波から値をサンプリングし、sinからサンプリングした値を偶数次元、cosからサンプリングした値を奇数次元に代入している。
https://gyazo.com/073232402dce866307817b417226ddf0
行方向にx0...x9が並んでいる図。図では64次元までしか表示されていないが、実際は512次元。
Residual
Encoder内の各サブレイヤーの前では出力がキャッシュされ、続くレイヤーで加算される。さらにその値がLayerNormされる。
https://gyazo.com/dd22bf2c04f5acc93f76cd4947b8a2a1
Feed-Forward
各Multihead-Attention層のあとには全結合なFeed-forward層が続く。
この層はPosition-wise Networkと呼ばれ、要は入力Z={z1, z2...zn}の各Tokenが独立してNetworkに入力されており、Network内における各単語の依存性は存在しないということ。(ただし、あくまでも一つのNetworkなので重みは同じものを利用している)
https://gyazo.com/bdb900e7ca3a6284ccdc9be210b1c774
2層のpointwise-convolutionが続いているようなもの。
3: Decoder
Summary
https://gyazo.com/41f1d38679516170b6d0d6008d457322
Encoder層で見てきたMutihead-Attention, Feed Forward, Add&Normのを使うのはまったく同じ
違うのは、Multihead-Attentionにマスク処理が加わったMasked Multihead Attentionと、各Multihead Attentionの入力
出力はZであり、出力されたZを次層のLinear&Softmaxでに通して1個のTokenを生成する。そのTokenを再帰的にDecoderにInputすることで単語列を生成する。
https://gyazo.com/089f83473785204a6aa786be16977f07
Masked Multihead Attention
カンニングマシンとならないように、Decoderが推論したTokenより後の位置にある入力ベクトルにAttentionがかからないようにマスクしている。
Input: 正解ラベルのEmbedding+Positional Embedding (実際は最終層で出力したTokenよりも後のTokenはマスクされている)
Output: Z
Encoder-Decoder Multihead Attention
Decoder2層目のMultihead Attentionを指す。
Encoderの最終層におけるKeyとValueがInputされ、QueryはDecoder内前層の出力X*Wqとなる。
これにより、デコーダ内で処理されている情報のどこにAttentionを向ければいいのかをEncoderが教えてくれる。
Input Masked Multihead Attentionの出力
Output: Z
4: Linear&Softmax
https://gyazo.com/ca86d4b3a887f89cf38074f2b539a50c
1: Decoderの出力を全結合層に通してvocab_sizeの1次元ベクトル化
2: Softmaxで合計を1にした後にargmaxで確率が最大のTokenを特定する
Input: Decoder Output
Output: 翻訳語のToken(あくまでここは再帰的なDecodeを行うため、i番目のTokenのみを生成していることに注意)
その他
Optimizer
Adamを利用
Dropout
Multihaed AttentionとAdd&Normの間、Positional Encodingsの加算前、EncoderとDecoderへの入力直前にp=0.1でDropoutしている
検証方法
Dataset
英独, 英仏の翻訳タスクで検証(WMT2014)
Hardware
P100*8台でbase model は12時間, Large modelは3.5日
Large modelは入力のEmbeddings, Feed Forward層, Attention Headの数が大きくなっている。
Token Select
beam search( beam_size=4, length_penalty=0.6)
Result
https://gyazo.com/a0a5b0102d732a64a9ff5b5c58e0b82b
いずれのタスクでもアンサンブルモデルを超えてSoTAを出した。
Training Costも大幅に削減した。
hparamsの変更検証
https://gyazo.com/bffe23c9e96a932f7b2c98b4453879c8
A: Attention Headは1つより増やした法がいいが、多すぎてもだめ。
A, B: KeyとValueの次元数を減らすと性能がさがる
C, D:学習可能パラメータ数が大きくなればなるほど性能は高い
C, D: DropoutやLabel Smoothingは有効
E: Positional Encodingsの変わりに位置を考慮したEmbeddingを使っても性能に変化はなかった
議論、課題
RNNやCNNの逐次性を排除した結果、品質を上げつつも大幅に計算量を減らすことに成功した。
NLPだけでなく画像や音声処理にも有効
Decoderは逐次的に単語を予測していくので、その処理を高速化する予知はある。
個人的には、「学習可能パラメータ数を大きくすればするほど性能が上がる」ということを現在のところ証明できている点が今までにない新規性だと思う。それを証明するためには学習の安定化と高速が必要不可欠であったが、Transformerはそれを可能にした。
次に読むべき論文
ViT原論文
コメント
参考文献
https://tips-memo.com/translation-jayalmmar-transformer
http://jalammar.github.io/illustrated-transformer/
https://gangango.com/2019/06/16/post-573/
https://qiita.com/omiita/items/07e69aef6c156d23c538
https://axa.biopapyrus.jp/machine-learning/model-evaluation/k-fold-cross-validation.html
https://colab.research.google.com/github/tensorflow/tensor2tensor/blob/master/tensor2tensor/notebooks/hello_t2t.ipynb#scrollTo=OJKU36QAfqOC