trl
https://gyazo.com/889b2577018c9101c7cb3181faa251eb
かっこいい
TRL - Transformer Reinforcement Learning
completion only の学習
チャットで user の入力部分は loss を計算せず、assistant 部分の loss を計算したい
よくあるのが messages 形式
code:messages形式.py
{
"messages": [
{"role": "user", "content": "東京の天気は?"},
{"role": "assistant", "content": "晴れです"},
]
}
Tokenizer に chat template がついていて
code:chat_template.json
{% for message in messages %}
<|im_start|>{{ message'role' }}
{{ message'content' }}<|im_end|>
{% endfor %}
適用すると
code:chat_template_applied.txt
<|im_start|>user
東京の天気は?<|im_end|>
<|im_start|>assistant
晴れです<|im_end|>
これを数値トークン化したものがモデルに入力される
これの 晴 から学習(loss を計算したい)
今どきの SFTTrainer は messages 形式のデータセットを train_dataset に渡せばよしなにやってくれる?
DataCollatorForCompletionOnlyLM は使わなくてもいい
messages 形式ではなく独自フォーマット / chat テンプレート使わない / 細かく制御する / 古いテンプレート
assistant_only_loss=True にすると全ての assistant ターンが学習対象になる
(余談) add_generation_prompt
code:add_generation_prompt.py
tokenizer.apply_chat_template(
messages={"role":"user", "content": "Hello!",
add_generation_prompt=True
)
すると、assistan の開始トークンがたされる
code:applied.txt
<|im_start|>user
Hello!<|im_end|>
<|im_start|>assistant
推論時はつける / 学習時はつけない(入力には普通 user, assitatnt の会話の組が渡ってくるので)
#LLM