Transformerを実装して少し理解した気になる会(その3)
#jam #Transformerを実装して少し理解した気になる会
3. TransformerのDecoder部分を利用し、SimpleStoriesデータセットで学習を行う
y-chan.icon 実装するしかねぇええぜええええ
実装に当たって、datasetsライブラリと、tiktokenライブラリを使う
datasets: SimpleStoriesデータセットを自動でセッティングしてくれるめっちゃ嬉しいやつ
Hugging Faceのデータセットをいい感じに持ってきてくれる
tiktoken: OpenAIが提供する、文字列をトークンにエンコード/デコードしてくれる嬉しいやつ
いろんなエンコード/デコード手法があって、語彙を学習したような結果らしい
GPT-4oとかで使われているo200k_baseというモデルや、GPT-4で使われているcl100k_baseというモデルがある。
200kを使いたいけど、語彙20万って思ったよりデカい(GPUメモリをいっぱい使う)ので一旦100kの方を使います。
古いgpt2モデルとかは日本語の語彙が微妙らしい
gpt-ossで使われているharmonyというトークナイザーもあるけど、今回は使用しない
(2025/12/24追記)
トークナイザーは自分でBPEを学習させて作ったほうが良かったかも。
GPTのものだと、日本語・英語以外もあるから無駄に大きい。100kの語彙ですらデカい可能性があって、出力の完成度を高めるには使わない語彙は排除したほうがいいかも(未検証)
学習スクリプトに必要な要素
設定(Config)のロード: モデルのハイパーパラメータ(学習前に人間が設定する値)や、モデルパラメータの保存先を指定するファイルを読み込む
データセット準備
データローダー(DataLoader): データセットをミニバッチ化したり、モデルに流し込めるような形に整形したりする。
ミニバッチ化: いくつかの学習データをひとまとめにすること、複数のデータを一気にGPUに流すことで、学習を安定させたり、GPU資源を有効活用して高速化させることができる。一つのデータに対して最適化させるより、複数のデータに対して最適化させるほうがモデルを汎化させやすく、学習効率も良いため、ミニバッチ化はよく用いられる。ミニバッチ自体は様々なサイズを指定できるが、通常は2の累乗値を設定する。
モデルに流し込めるような整形: 複数のデータをまとめてミニバッチ化すると、文字列であれば長さが違うことがある。torch.Tensorやnp.ndarrayというのは、全次元において同じ長さでなければならないので、文字列をカットしたり、あるいは長いものに合わせる形で0埋め(パディング)したりしなければならない。そういった処理を行う。
オプティマイザー(Optimizer): モデルのパラメータを更新するマン、いろんなアルゴリズムがある。今回はAdamWというオプティマイザーを使う。
スケジューラー(Scheduler): モデルのパラメータを更新するに当たって、その更新割合を調節するパラメータがある。それがラーニングレート(Learning Rate)。ただ、ずっと同じラーニングレートを使用し続けると、学習結果が発散することがあるので、学習が進むにつれて何らかの方法で減衰させる必要がある。それを制御するのがスケジューラーである。
モデルパラメータのロード/セーブ: 学習が途中から再開できるように、あるいは推論に使えるように、学習結果としてモデルやオプティマイザーのパラメータを保存・読み込むことが多々ある。PyTorchに限らず、深層学習系ライブラリはそういった機能が標準であるので活用する。
あとは、データローダーからデータを取り出し、モデルに流し込み(forward)、予測結果を得る。その後、予測結果と正解データの差(Loss)を何らかの手法、今回であればクロスエントロピー(CrossEntropy)で取り、モデルに勾配を溜める(backward)。溜まった勾配を元に、モデルのパラメータを更新する、という流れを繰り返す感じ。
Lossを計算する際は、パディング部分に関してはLossを取らないようにする。maskをうまく活用する。
勾配をすぐに流してモデルパラメータを構築する場合もあるが、勾配蓄積(Gradient Accumulation)というテクニックを使う場合もある。これはミニバッチを更に分割し、GPUメモリが少ない場合でも大きいミニバッチを学習した際と同じふるまいをさせることができ、学習安定化の効果が期待できる。
ほかにも、データを記録し、閲覧できるツール(TensorBoardやWeight&Biasなど)のために、学習進捗や、途中時点の生成結果を保存するのがよく見られる。
あとはAMP(自動混合精度)とかあるけどもうめんどいので省略
というわけで、これを実装した結果がこちら
もうこっち読んで、コードとコメントのセットのほうがわかりやすい多分
https://github.com/y-chan/learn-and-make-slm/pull/15
トピック4は次のページへ
Transformerを実装して少し理解した気になる会(その4)