GSAM - Surrogate Gap Minimization Improves Sharpness-Aware Training
https://gyazo.com/4cd6d3940c35b980ad4c184f20793cc2
問題提起
SAMの計算式では, 本当にフラットな損失点を見つけているとは言えない
$ L_\mathcal{S}^\text{SAM}(\mathbf{w}) \triangleq \max_{\|\mathbf{\epsilon}\|_p\leq\rho} L_\mathcal{S}(\mathbf{w}+\mathbf{\epsilon})
例えば下の図では, 近傍 $ f_pについて最適化すると, SAMの場合, 青に収束してしまう危険がある
https://gyazo.com/82da06baa80880eb9e91d3db1383ba64
本当に見るべきは以下に定義するsurrogate gap $ h(x)
$ h(x) := f_p(x) - f(x)
surrogate gap $ h(x)については, Hessianの最大固有値との間で以下の関係が成り立つことが証明できる $ \sigma_{\mathrm{max}} ≈ \frac{2h(w_∗)}{ρ^2}
しかも, $ O(ρ^3)程度の誤差らしい
なので, surrogate gapがフラットな損失点へと収束することが理論的に証明されている
最適化の注意点
最適化したいのは, $ f(x), f_p(x), h(x)の三つ
ただし, $ min_w f_p(x) + \lambda h(x)を最適化するのは少し注意が必要
例えば, $ \nabla h = \nabla f_p - \nabla fは$ \nabla fと$ \nabla f_pとで内積が負の値になることがある
すなわち, 最適化のConflictが起きる可能性がある (下図参照)
conflict = 片方を最適化すると片方が最適解から遠ざかる可能性がある
なので, 実際のアルゴリズムは, $ \nabla hの直交成分を使って, 下図赤線の方向に解を更新する
https://gyazo.com/4cd6d3940c35b980ad4c184f20793cc2
SAMとの比較 (toy-setting)
(GIFアニメなので自動でループ再生されてます)
https://gyazo.com/d5f006c6aba4ae1a8a0645c137cf0b20
結果
https://gyazo.com/3e1e08fb1e0c9ba282c1f45211ff83a9