Evolutionary Optimization of Model Merging Recipes
会議
Motivation 選んだ理由
今までに無いタイプのLLMの論文なので中身が気になる
計算量が小さいことがウリなので、実用性が高そう
Summary どんなもの?
既存のモデルをマージ・結合して新しいモデルを作る
どうマージ・結合するかの探索に進化的アプローチを取っている
Contribution 先行研究と比べてどこがすごい?
モデルのマージ自体は実は古くから行われてきた
1990年代から画像処理モデルや画像分類モデルではモデルのマージが使われてきた
画像生成モデルで特にマージしたモデルが一般的に使われるように
LLMのマージ自体も先行研究がある
Open LLM Leaderboard の上位モデルのほとんどは、言語モデル愛好家のコミュニティによって作られたマージモデルで徐々に占められてきている
貢献は
モデルマージの自動化
KKD(勘・経験・度胸)でやってきたモデルのマージを進化探索で置き換え
クロスドメインのマージ
計算量を抑えつつ、7Bのモデルで70B相当のパフォーマンス
日本文化特有のコンテンツを扱う能力を実証
実数の世界にある画像に比べると離散的な言語モデルはマージが難しく、マージの手法が研究されている
Task Arithmetic: FTされたモデルの重みと事前学習モデルの重みの差分をタスクベクトルとし、算術演算によってマージ
TIES-Merging: モデル間でマージするパラメタを工夫(FTで変化が小さいパラメタを除く、パラメタ間の符号を一致させる)
DARE: "Language Models are Super Mario" というおもしろタイトル論文で提案された手法。FTで変化が小さいパラメタを除き、パラメタをリスケールしなおしたモデルをマージする。マージには他の手法との組み合わせが使われる
Method 技術や手法のキモはどこ?
この手法で扱うマージの手法は2つ
https://gyazo.com/4fdde887f405d167c749a3fdeaf2a834
PS: パラメータ重みを混合させる
進化的アルゴリズムを使って探索(CMA-ESを利用)
多変量正規分布に従って個体を生成、上位n個体を抽出し、多変量正規分布のパラメタを更新、を繰り返す
(おそらく多変量正規分布にしたがって重みを混合させる割合を決定している?)
目的関数は各タスクのスコア(MGSMの精度、VQAのROUGEスコアなど)
DFS: レイヤー自体をつなぎ合わせる
層の並び替え
合計M層のベースとなる2モデルからT層のモデルを作るとすると、$ (M+1)^T の組み合わせを探索しなければならない
同じ層の繰り返しや層間の順序を入れ替えるとパフォーマンスに悪影響であることがわかったため、そのような組み合わせを除いて、探索空間を $ 2^T に制限できる。(2はモデルの数)
層間の結合
ただ異なるモデルから層を持ってきても、普通はうまく動かない
レイヤ$ iと$ jをつなぐスケーリングを$ W_{ij}、ただし$ W \in {\cal R}^{M \times M} としてこれも進化的探索で最適化する。
しかしMの2乗は大きすぎるので、FFNで近似$ W_{ij} = \pi_\theta (i,j,t) 、パラメタ $ \theta を探索する ($ t \leq T)
これも進化的アルゴリズムを使って探索(これもCMA-ESを利用)
PSとDFSの組み合わせ
PSとDFSを同時に行うとおそらく探索空間が広くなりすぎる上、性質が異なるため扱いにくい
まずPSマージ(探索)を適用して、複数のマージモデルを生成
次にDFSマージ(探索)を適用して最終的なモデルを生成する
Experiments どうやって有効だと検証した?
日本語数学LLM
ベースとなったモデル(すべてMistral-7B-v0.1 のFTモデル)
日本語能力の高いモデル
数学に特化したモデル
数学の能力が高いモデル
データセット
MGSMデータセットの日本語テストセットで評価(250事例)
Input 問題:ロジャーは5個のテニスボールがあります。テニスボールの缶を2つ追加で買います。それぞれの缶には3つのテニスボールが入っています。彼は今いくつのテニスボールがありますか?
Output ステップごとの答え:ロジャーは最初5個のボールがありました。テニスボール3個入りの缶が2つあれば、テニスボールは6個あります。5+6=11。答えは11です。
進化的探索には MGSMには含まれていない GSM8kの他言語のデータを日本語に翻訳して利用
元々のモデルがGSM8kで学習されているため、予備調査ではうまくいかず?
評価方法
評価基準:回答の数値が正しい+理由が日本語で書かれている
出力の最後にでてくる数値を回答とした
複数のモデルを統合すると、出力形式を修正するのが難しい
ほぼすべてのケースで正しく答えは抽出できている
生成はGreedy サンプリング、zero-shot pass@1 で評価
最適化
PS の最適化には Optuna のCMA-ESを利用。マージの手法にはTIES-MergingとDAREを採用している。
DFSの最適化にはEvoJAXのCMA-ESを利用。
同じCMA-ESだが実装が異なる理由はわからず
結果
https://gyazo.com/11f1f61b174f9af134794e775f8f93f4
MGSM-JA は日本語での数学能力、JP-LMEHは日本語の一般的な能力をしらべている
https://gyazo.com/c5e76227edc3e6b02bffea84a7e42bf1
解けた問題の図示。色付きのバーで正しく回答できた問題をしめしている
PSが特に性能が高いものの、DFSやPS+DFSもそれぞれ元のモデルより性能が向上している
日本語能力と数学能力をそれぞれ別のモデルから引き継いで合わせて使うことができている
元モデルではどれも解けていない問題が解けるようになっているので、日本語能力と数学能力を融合して使えていることが示唆される
https://gyazo.com/f9cea32567ac17c4773092a0d35408c5
一般的な日本語能力に関するベンチマークの詳細
JComQA: 常識問題 「スープを飲む時に使う道具は?」
JNLI: 含意関係認識 「文1: 野球選手がバットをスイングしています。文2:野球選手がキャッチボールをしています。」
MARC: レビュー文からスコアを推定
JSQuAD: コンテクストが与えられた上で、一般的なQA 「2014年のドバイの国内総生産は?」
JAQKET: クイズ 「童謡『たなばたさま』の歌詞で、「さらさら」と歌われる植物は何の葉?」
XLSum: ウェブ記事の要約
XWino: 文章内の穴埋め問題
MGSM: 数学問題
JCoLA: 自然な文の選択「あなたは寒いです vs 私は寒いです」
PSではJSQuADやJAQKETのスコアも向上している
JComQAやJNLIではやや下がるが、探索で数学能力のみを考慮している事を考えると本来は下がって当然で下がり幅が小さいと言っても良さそう
Shisa Gamma は日本語モデルなのに他よりJSQuADが低いが、マージしたモデルでは元のどのモデルよりも高いスコアがでている。
言語独立の文章理解のような抽象的な能力も転移できている?
モデルの分析
PS、DFSともに日本語モデルがベースになって混合されている
DFSでは日本語モデルで低層〜中層までが占められ、その後両モデルが混合して現れる
日本語VLM
shisa-gamma-7b-v1 と LLaVA-1.6-Mistral-7B をマージする実験
探索には Japanese Visual Genome VQA を使用
https://gyazo.com/d36340d20ebfcbbc9491a30bd98b8edf
LLaVAとJapanese Stable VLM はベースライン
LLaVAに日本語能力を追加することで、性能向上できている
Discussion 議論はある?
今の手法では、材料となるモデルのセットをユーザーが選択する必要がある
・既存の膨大なモデルの母集団から候補となるソースモデルを探索することも可能だと考えている。
・自己改良が可能なモデルの群れからなる集合知の出現を可能にする可能性を秘めている。
モデルのマージがあくまでソースモデルの能力の組み合わせにとどまるのか、それ以上まで進めるのかは気になる所
例えば全く新しい能力の獲得はできるのか?(新しいタスクに対して新規モデルを学習するよりも効率的にできる可能性があるか?)
モデルの訓練に対するコストは上がり続けているので、1つの巨大モデルを作るよりは複数の特化モデルをマージして利用するほうがコスト的に優位かもしれない?
小規模な社内データなどでドメイン特化のLLMを作り、そこに個別タスクの能力を後付けできるとLLM用途が広がりそう
マージされたモデルは論理的な一貫性を欠いた応答を生成する場合があった
FT されていないので、事実と異なる出力を含む場合がある
とくにDFSのモデルで、FTとの組み合わせはDFS由来の不整合を修正できる可能性があって効果が高そう
モデル出力の論理的一貫性は特定のレイヤではなく、個々のレイヤでそれぞれ学習している?
論理的一貫性に特化したモデルをマージするようなことはできない?