Keep what you need extracting efficient subnetworks from large audio representation models
https://scrapbox.io/files/67fd1496833a2bd64042afdb.png
論文リンク
文献情報
Keep what you need : extracting efficient subnetworks from large audio representation models
ICASSP2025
D. Genova et al. (IRCAM, Sqiarp instruments)
要は何?
音の基盤モデルを利用したdownstream taskにおいて、枝刈りをすることで対象のdownstream taskに特化したサブネットワークを作る方法
問題意識と解決策
近年はCV, NLP,さらに音分類タスクにおいて基盤モデルを用いるのはよく用いられる一つのやり方
膨大なデータで事前学習→目的タスク(スケールははるかに小さい)でファインチューンする
HuBERT, CLAP, MERTなど
しかし、基盤モデルはパラメタも多く計算量も要する→特にエッジデバイス等、限られた計算資源で作用させるには不向き
downstream taskに適用するときも、基盤モデルにヘッドをつけるという方法→さらにパラメタを増やすことになる
→基盤モデルの一部を切り出す;サブネットワークを作ればよいのでは?
yamamoto.icon 蛇足:本文、だってオーバースペックじゃん?みたいなことをいちいち挟むの草
→特に、構造化枝刈りによって計算資源利用を最小限にしたサブネットワークの作成手法を提案
関連研究
モデル圧縮
枝刈り(prunning)→タスクを解くのに重要でない重みを特定し、取り除く
高速化や省メモリ化が期待できる
オーバーパラメタライズなDNNのほとんどの重みは取り除けことが示されている(-95%くらい)
しかし、多くの枝刈り手法は繰り返しの追加訓練を要する
yamamoto.icon 追加訓練が必要な枝刈り手法に、宝くじ仮説に基づく方法を挙げていた
https://scrapbox.io/files/67fd1a23aa935d44070b68f4.png
そこで、より効率的な、構造化枝刈りに基づく手法を考える
yamamoto.icon 構造化枝刈りとは:層やフィルタなど粗い単位で枝刈りをすること
https://scrapbox.io/files/67fd1b373c091c2b10c957ca.png
ただし、構造化枝刈りは繰り返し追加訓練を要する枝刈り手法に比べて圧縮率が低くなりがち
基盤モデルをdownstream taskに活用する手法
linear probing: 出力層の後にヘッドをつけてヘッドだけ学習
finetuning: 基盤モデルのパラメータも学習
parameter-efficient transfer learnin (PETL)
adapter (中間層の出力に小さな学習可能層をつける)など
Scaling and Shifting your Features (SSF): 中間層のあとにパラメタが学習可能なアフィン変換をつける
$ \tilde{f}_{(l,\theta)} = \gamma_l \odot f_{(l,\theta)}(x) + \beta_l
$ \gamma,\betaは中間層の次元数を持つ学習可能パラメタ
https://scrapbox.io/files/67fd2059df400bdca553a978.png
手法
https://scrapbox.io/files/67fd237ca0e76b94832948ec.png
本手法では行う構造化枝刈りは、学習可能バイナリマスクを基盤モデルの各中間層に挟むことで実現
ヘッドとバイナリマスクモジュールのみを学習
基盤モデルの重み自体は固定のまま
バイナリマスクで0にあたった部分は学習後に取り除かれる
yamamoto.icon上記のSSFの、$ \betaが0固定、$ \gammaが0か1のバイナリ列バージョンだという解釈でいいよう。
これは、downstream taskを解くにあたっては、基盤モデルのユニットの多くは取り除けるという仮説に基づく
事前訓練の必要性はあくまで否定せず、downstream taskであればtoo muchで、サブネットワークの抽出ができるはず、という立場。
具体的なマスクの実現は以下の通り、
https://scrapbox.io/files/67fd267827809aea45f5aef2.png
xを入力したi番目の層の出力lにマスクmを通した後の出力
$ l(x,\psi_i, m_i) = m_i \odot l(x,\psi_i)
さらに、バイナリマスクとして扱うためにsigmoidを通して丸める
https://scrapbox.io/files/67fd276b0d5043b76a5dcd9c.png
yamamoto.iconマスクの値はsigmoidによって(0,1)になり、その丸め処理で0か1になるという感じ
さらに、タスクに特化したコンパクトな表現を得るため、sparsity inducing lossを提案
downstream taskのロスに追加で、係数λで和をとる
$ L_s = \frac{ \sum^{N}_{i=1}{||sigmoid(m_i - t)||_2 }}{N}
N:ユニット数 、t: ペナルティの強さ(ハイパラ)-> 低くすればより疎になる
実験とその結果
用いた基盤モデル
CLAP;general audioとtextでcontrastive learningしたモデル
MusicFM;音楽信号特化。Conformerで構成、BEST-RQに沿った方法で学習
Wav2Vec2.0 ;音声信号用。couv+transformer, マスク部のcontrastive lossで学習
downstream task
音声:wav2vec2.0を利用
1 librispeechのASR
2 Fluent Speech commandの Intent classification
3 Sppech commandsのKeyword spotting
環境音:CLAPを利用
4 ESC50のaudio classification
5 Urbansound8kのsound event classification
6 FSD50kのAudio tagging
音楽:musicFMを利用
7 GTZANのgenre classification
8 MagnaTagATuneのmusic tagging
9 NSynthのpitch tracking
headの実装
1以外→2層のMLP(1024次元、ReLU)でクロスエントロピー(tagging系(マルチラベル)はBCE)
1→2層のBLSTM、CTCloss
比較条件
トリムした割合 25%, 50%, 75%の枝刈りモデル、
それと同等のサイズでスクラッチ学習したモデル
SSF
通常のlinear probing
結果;精度
https://scrapbox.io/files/67fd29d8a7ddd93024eaf854.png
多くのタスクで25%枝刈り > linear probe を達成
50%枝刈りしても精度はそこそこ保たれている
同程度のパラメータでスクラッチ学習するよりも精度は大体いい
SSFと比べると精度は劣る
SSFはパラメータを削減はしていない
枝刈り量のベスト:ASR以外は20%以上枝刈りしたものがベスト
ASRはタスクが複雑なためだと考えられる
FSDは提案法が劣る:音やラベルの多様性
結果;速度やパラメタ数
精度の下がり幅が5%以内となる中で一番速度が出たモデルを報告
https://scrapbox.io/files/67fd2c3b046bf29650e34f92.png
CLAPやWav2Vec2.0では、分類タスクであれば2倍以上高速化できている
musicfmではそこまで速くならず。実装のボトルネックか、モデルのアーキテクチャによるものと考える。
yamamoto.icon conformerだと速くならなかったりするのだろうか
今後は生成などでも枝刈りの有効性を示したいとのこと
また、今回得たマスクの解析や、音基盤モデルが何を学習したかの解析を取り組みたいとのこと。
コメント
pruningというワードを論文タイトルに入れてほしかった感。
枝刈りを音基盤タスクでやってくれたのは良い仕事だと感じた
結局のところタスクによって最適なマスクsparsity具合(t)が異なるので、それを見つける方法は繰り返しトライアルするしかないのでは、という気も。
ラベルの対応に多様性があるタスクでは枝刈りが向かない、という結論に帰着する?そこの理論的な解析?ができるのであれば続報やってほしい感がある。