AutoFIS: Automatic Feature Interaction Selection in Factorization Models for Click-Through Rate Prediction
Author: Bin Liu, Chenxu Zhu, Guilin Li, Weinan Zhang, Jincai Lai, Ruiming Tang, Xiuqiang He, Zhenguo Li, Yong Yu
Organization: Huawei Noah’s Ark Lab , Shanghai Jiao Tong University
論文を選ぶ理由:
導入予定のDeepFMを改善できそう
大量な特徴の計算コストになんとかしたい
Huawei Noah`s Ark出品ので信用できそう
TL;DR
FM系モデルにTwo-stageの自動化特徴インタラクション選択アルゴリズムを提案した
Searchステージでアーキテクチャパラメータを導入し、正規化Optimizerを生かしてそれぞれのパラメータを学習する
Re-trainステージで学習したパラメータをFixしてAttentionとして使う
Huawei App Storeでオンライン実験した結果、DeepFMのCTRとCVRを20%以上に向上した
2、背景
CTR予測モデルExplicitに特徴インタラクション設計は精度を向上する
FM/MFなどは2階インタラクションを抽出する
しかし全てのインタラクションは有効ではない
Treeモデルは有効なインタラクションを自動的に学習できるが、一部のカテゴリ特徴だけ限られている
DNNモデルの表現力が強いが、シンプルなネットワークは望ましい形に自動的に収束する保証がない
そのためにネットワークのアーキテクチャに工夫して、explicit的にインタラクションを学習するモデルはたくさん提案された。特にFactorization Models(FM, DeepFM, PNN, AFM, NFM etc)
しかし、これらのモデルは簡単に特徴インタラクションをenumerateするか、特徴の人工的な選択に頼れるか
膨大な時間と計算コストかかる
ゆえに自動的に有効なインタラクションだけ絞る仕組みがほしい
3、提案手法
3.1 Modelの定式化
まずベースとなるFMモデルを定式化する
FM系のモデル:いくつかの特徴間のインタラクションを内積 or Neural Networkの形で一つの実数にマッピングするモデル
https://gyazo.com/97caa3bddeec9d1ea74838264ba19826
FM、DeepFMとIPNNの構造はFigure1のようである。それぞれのレイヤを定義する
Embedding layer
CTRタスクにデータは特徴をone-hot / multi-hotエンコーディングして入力する
$ {\bm x = \lbrack \bm x_1, \bm x_2, ..., \bm x_m \rbrack}
$ \bm x はmulti-field入力であり、$ \bm x_iはi番目の特徴fieldのone-hot / multi-hotのEmbeddingベクトルである
特徴Embeddingレイヤは上の入力を低次元のベクトルに変換するレイヤ
$ \bm e_i = V_i \bm x_i
ここ$ V_i \in R^{d \times n_i}はマトリックスである。$ n_iがi番目fieldの特徴値の数であり、$ dは低次元ベクトルの次元数である
もし$ \bm x_iがone-hot($ \bm x_i\lbrack j\rbrack = 1)の場合、$ \bm x_iの表現は$ V_i^jとなる
$ \bm x_i がmulti-hot($ \bm x_i\lbrack j\rbrack = 1for $ j = i_1, i_2, ..., i_k)の場合、$ \bm x_iの表現が$ \{V_i^{i1}, V_i^{i2}, ..., V_i^{ik}\}の平均 or Sumである
Embeddingレイヤの出力は複数embeddingベクトルのconcatenationである
$ \bm E = \lbrack \bm e_1, \bm e_2, ..., \bm e_m \rbrack
Feature Interaction layer
Embeddingレイヤの後ろはそれぞれのEmbeddingのインタラクションをモデリングするレイヤである。それぞれのEmbeddingペアで内積を計算する
$ \lbrack \langle \bm e_1, \bm e_2 \rangle, \langle \bm e_1, \bm e_3 \rangle, ..., \langle \bm e_{m-1}, \bm e_{m} \rangle \rbrack
このレイヤのインタラクション数はEmbedding数の平方になる
FMとDeepFMモデルでは、インタラクションレイヤの出力が以下になる
$ l_{fm} = \langle \bm w, \bm x \rangle + \sum_{i=1}^{m} \sum_{j>i}^{m} \langle \bm e_i, \bm e_j \rangle
ここ全てのインタラクションは平等に次のレイヤに貢献する。後のSection4の検証によると、全てのインタラクションは予測に役に立つわけではなく、一部のインタラクションが逆に損をもたらす
提案手法AutoFISの有効性を検証するために、三階のインタラクションも下のように定義した
$ l_{fm}^{3rd} = \langle \bm w, \bm x \rangle + \sum_{i=1}^{m} \sum_{j>i}^{m} \langle \bm e_i, \bm e_j \rangle + \sum_{i=1}^{m} \sum_{j>i}^{m} \sum_{t>j}^{m} \langle \bm e_i, \bm e_j, \bm e_t \rangle
MLP layer
NNの全結合層は以下のように計算する
$ \bm a^{l+1} = relu(W^{(l)} \bm a^{(l)}+\bm b^{(l)})
Output layer
FMモデルはMLPレイヤがないため、出力は以下である
$ \hat{y}_{FM} = sigmoid(l_{fm}) = \frac{1}{1+exp(-l_{fm})}
DeepFMモデルはインタラクションレイヤとMLPレイヤを並列に計算している
$ \hat{y}_{DeepFM} = sigmoid(l_{fm} + MLP(E))
IPNNモデルはMLPレイヤがインタラクションレイヤの後ろにある
$ \hat{y}_{IPNN} = sigmoid(MLP(\lbrack E, l_{fm} \rbrack))
IPNNモデルのMLPレイヤはそれぞれインタラクションのre-weightingと等しい。しかし、無効なインタラクションを識別と取り除くことができないため、さらにモデルにノイズと計算コストを増やす。Section4でAutoFISがこれを改善することを示した
Objective Function
cross entropyである
$ L(y, \hat{y_M}) = -ylog \hat{y_M} - (1-y)log(1-\hat{y_M})
3.2 AutoFIS
AutoFISは二つのステージ(search & re-train)を分けて、有効なインタラクションを識別してから再度学習する
Search Stage
gateのオペレーションを導入する。gateがオープン状態なら、特徴のインタラクションが選択され、クローズ状態なら選択されない
gateの数は2階インタラクションの数と同じ、$ C_m^2である。gateすべての組み合わせは$ 2^{C_m^2}種類があるので、brute-forceの計算は現実ではない
https://gyazo.com/76202810ecbd884dc8901ecc4767abfd
自動的にgateを学習するために、制限条件を緩和し、gateが連続値を取れるようにアーキテクチャパラメータ$ \bm \alphaを導入する
FMレイヤは以下のように変形する
$ l_{AutoFIS} = \langle w,x \rangle + \sum_{i=1}^{m} \sum_{j>i}^{m} \alpha_{(i,j)} \langle \bm e_i, \bm e_j \rangle
勾配降下法でアーキテクチャパラメータ$ \alphaを学習させ、重要ではないアーキテクチャパラメータを0にしてgateをクローズする
Batch Normalization
$ \langle \bm e_i, \bm e_j \rangleの値は$ \alpha_{(i,j)}と連結に学習しているので、両方Scaleのカップリングしている。そのために、$ \alpha_{(i,j)}は不安定になり、インタラクションの重要度を反映できなくなる問題がある。この問題を解決するために、$ \langle \bm e_i, \bm e_j \rangleにBatchNormalizationをApplyする
もともとのBNは:
$ \hat{\bm z} = \frac{\bm z_{in} - \bm \mu_B}{\sqrt{\sigma_B^2 + \epsilon}} and $ \bm z_{out} = \theta \cdot \hat{\bm z} + \bm \beta
$ \bm z_{in}, \hat{\bm z}, \bm z_{out}はBNの入力、正規化した値と出力 である
$ \bm \mu_B, \sigma_Bは入力があるミニバッチでの平均値と標準差
$ \theta, \bm \betaは学習できるscaleとshiftパラメータ
$ \epsilonは安定化のための定数
安定した$ \alpha_{(i,j)}を学習するために、scaleとshiftをそれぞれ1と0に設定し、$ \langle \bm e_i, \bm e_j \rangleのインタラクションを下のように計算した
$ \langle \bm e_i, \bm e_j \rangle_{BN} = \frac{\bm \langle \bm e_i, \bm e_j \rangle - \mu_B(\langle \bm e_i, \bm e_j \rangle)}{\sqrt{\sigma_B^2(\langle \bm e_i, \bm e_j \rangle) + \epsilon}}
$ \mu_B, \sigma_Bは$ \langle \bm e_i, \bm e_j \rangleがミニバッチ$ Bにおいての平均値と標準差
GRDA Optimizer
Generalized regularized dual averaging(GRDA) optimizerはスパースなDNNをトレーニングできる。アーキテクチャパラメータ $ \alphaを下の公式で学習させる
$ \alpha_{t+1} = \argmin_{\alpha} \lbrace \alpha^T(-\alpha_0+\gamma\sum_{i=0}^t \nabla L(\alpha_t;Z_{i+1})) + g(t, \gamma) \|\alpha\|_1+1/2\|\alpha\|_2^2 \rbrace
その中$ g(t, \gamma) = c\gamma^{1/2}(t\gamma)^u、$ \gammaは学習率である。$ c, uはAccuracyとスパース性を調整するハイパーパラメータである
このOptimizerを用いてスパースの$ \alphaを求め、重要ではないインタラクションを捨てる。他のパラメータは通常通りAdamで学習する
One-level Optimization
本研究ではアーキテクチャパラメータ$ \alphaを他のモデルパラメータ$ \bm vと一緒に学習した。関連研究のDARTSでは勾配降下法のステップで交代に$ \alphaと$ \bm vを学習したが、それが精度を低下の原因となる。そのために本研究は勾配降下法で$ \alphaと $ \bm vをそれぞれ下の偏微分で下降した
$ \partial_v L_{train}(v_{t-1}, \alpha_{t-1}) and $ \partial_{\alpha}L_{train}(v_{t-1},\alpha_{t-1})
Section 4でこのone-level optimizationはtwo-levelより良いことを示した
Re-train stage
search stageの後、学習したアーキテクチャパラメータ$ \bm \alpha^*に基づいて、gateの値$ G_{(i,j)}を決める:もしインタラクション$ \langle \bm e_i, \bm e_j \rangleの$ \alpha_{(i,j)}^* = 0なら$ G_{(i,j)} = 0を設定する。そのほか$ G_{(i,j)} = 1
そしてインタラクションレイヤの更新式は以下に変える
$ l_{fm}^{re} = \langle \bm w, \bm x \rangle + \sum_{i=1}^{m} \sum_{j>i}^{m} \alpha_{(i,j)} G_{(i,j)} \langle \bm e_i, \bm e_j \rangle
ここで$ \alpha_{(i,j)}はアーキテクチャパラメータではなく、Attentionの重みとなる
すべてのパラメータはAdam optimizerで学習する
4、Experiments
二つの公開データセットと一つのprivateデータセットでオフライン実験した。オンラインのABテストも実験した。
以下のResearch Questionを研究した
RQ1: AutoFISはFMモデルの精度を向上するか
RQ2: シンプルなモデルで選択された特徴は、SOTAモデルに遷移して推論時間を抑えて精度改善できるか
RQ3: AutoFISで選択した特徴は本当にusefulなのか
RQ4: AutoFISは既存モデルのオンラインレコメンドシステムを改善できるか
RQ5: AutoFISそれぞれのコンポーネントの貢献はどれぐらいあるか
4.1 Dataset
データセット:Avazu(KaggleのCTR予測データセット)、
Criteo:1ヶ月間数十億サンプルのクリックログデータセット。data 6-12をトレーニングとして、day 13をEvaluationとして使われていた。Negative down-samplingして50%のポジティブ率にした。13個の数値特徴はBucketingでone-hotエンコーディングした。20回以下現れた特徴はダミー特徴「other」にした
Private:Huawei App Storeのゲームレコメンドシステムから収集したデータセット
https://gyazo.com/f9dff20a4eea9bd423f5b7174ed0931b
4.2 実装詳細
2階特徴インタラクションを選択する時、search stageでモデルの$ \bm \alphaと$ \bm vを全トレーニングデータで学習する。そして無効なインタラクションをリムーブしてre-trainする
3階インタラクションにおいて、2階の時学習した選択をリユースして、3階の特徴をenumerateして重要度を学習した。最終的に2階と3階の学習した有効なインタラクションでre-train
search stageの$ \bm \alphaはGRAD Optimzerで学習し、他のパラメータ$ \bm vがAdamで学習した。re-trainステージで全てはAdamで学習する
4.3 AutoFISの特徴選択(RQ1)
Table 1でRQ1に対する実験結果をまとめた
https://gyazo.com/74812a021c249037e494412bc081c62a
(1)Avazuで71%のFM、76%のDeepFMインタラクションをリムーブできる。推論時間を減らすだけではなく、精度も上がる
(2)3階インタラクションに2%-10%だけ有効である。AutoDeepFM(3rd)とAutoFM(3rd)はFMとDeepFMにも同じレベルのスピードが出て、精度がより高い
(3)このような改善は比較的に少ない時間でできた。(AutoDeepFM(3rd)は24分・128分で2階・3階インタラクションのsearch実行できた。シングルGPUで)
3階のインタラクションを全て推論すると、FMとDeepFMの推論時間はOut
4.4 選択したインタラクションの遷移(RQ2)
https://gyazo.com/eaf178769bf4ff330b1d56617d64b7a5
AutoFMで選択したインタラクションは、IPNNに遷移できるかについて実験した。AutoFMで2階のインタラクションを選択した結果(AutoIPNN(2nd))は元のIPNNと同じぐらいの精度が出た。AvazuとCriteoで30%と50%のインタラクションが残された。さらに2階と3階選択したインタラクションの遷移(AutoIPNN(3rd))が精度を改善した
4.5 AutoFIS選択した特徴インタラクションの有効性(RQ3)
4.5.1 Real dataに選択した特徴インタラクションの有効性
選択したインタラクションの有効性を検証するために、statistics_AUCを定義する:あるインタラクションに対して、テストサンプルへの予測値は統計CTR(downloads/impressions)の場合だけそのインタラクションを有効にする予測モデルを作る。
このモデルのAUCは該当インタラクションのstatistics_AUCである。高いstatistics_AUCはインタラクションの重要性を示している
Figure3で本研究のモデル選択したインタラクションが比較的に高いstatistics_AUCがあることを示した
https://gyazo.com/2cbc0359bb1dedcd9fe706ecf880046e
一部の高いstatistics_AUCのインタラクションは選択されなかった理由は、それらの情報がすでに他のインタラクションに含まれているから
Statistics_AUCで選ばれたインタラクションとAutoFISのモデルを比較した。提案モデルの方がパフォーマンスが良い
https://gyazo.com/b92c2bfa075c986a9c0a91101ab6a7f0
4.5.1 Synthetic dataに選択した特徴インタラクションの有効性
合成データに選択した特徴インタラクションの有効性を検証した
不完全なploy-2関数で(bi-linear項はカテゴリのインタラクション)合成データセットを作り、以下のことを検証した
提案モデルは重要なインタラクションを見つけるか
提案モデルと他のモデルのパフォーマンス比較
入力xはmカテゴリのN個をサンプリングした。ラベルyは以下のように合成した
$ y= \delta(\sum_{i=1}^mw_ix_i + \sum_{i,j \in C}v_{i,j}x_ix_j+b+\epsilon)
$ \delta(z) =\begin{cases} 1, & if\ z\ \geq\ threshold \\ 0, & otherwise \end{cases}
ここのbi-linear項のセット$ Cと$ \bm w, \bm v,bはランダムに選択され、固定される
データペアはiidである。$ m = 6, N = 60, C = \lbrace(x_0, x_1), (x_2, x_5),(x_3,x_4)\rbraceでの実験結果はFigure 5でまとめて、提案モデルは有効なインタラクションを選択したことがわかった
https://gyazo.com/81c4b8947da737ad459e5fd4dc07211d
https://gyazo.com/d6fa68381f1ee2493dc6c5da4dd247a7
4.6 Deployment & Online Experiments(RQ4)
Huawei App Storeでオンライン実験を行った。数億のDAUがいて、数千億のログが毎日残される。AutoDeepFMをオンライン環境でデプロイした。3ノードのクラスタで、1ノードあたりは48コアのIntel Xeon CPU E5-2670(2.30GHZ) + 400GB Mem + 2 NVIDIA TESLA V100 GPUのスペックである
A/Bテスト結果はFigure 6, 7で示した
https://gyazo.com/322a558b04ff57e1242ee2cd005decc1
前後期間A/Aテストはベースラインモデル(DeepFM)自身が8%の上下があることがわかった。真ん中の十日間でのA/BテストでCTRとCVR両方大幅に改善した(20.3%と20.1%)
4.7 Ablation Study(RQ5)
4.7.1 異なるシードで$ \bm \alpha学習のStability
Search stageでAutoFMに異なるシードを設定し、複数回$ \bm \alphaを学習した。これらの$ \bm \alphaのPearson相関係数は0.86であり、Stablilityを示した。BNレイヤを除いたら、Pearson係数は0.65に落ちた。
4.7.2 AutoFISのそれぞれのコンポーネントの有効性
Table6でいくつかのAutoFMコンポーネントを取り除いたモデルをまとめた。Table 7でそれぞれのモデルのパフォーマンスを示した
search stageを除いたモデルはRandomでインタラクションを選択した。ランダムで選択したインタラクションの数はAutoFISと同じく設定し、10回を回して結果を平均を取った
https://gyazo.com/63ed191ba32eb0b783c225ccf0b71e24
https://gyazo.com/932e5a9a83dd271ac8d09d383d8dcfd4
AutoFM-BN-$ \alphaとRandom+FMを比較してみると、search stageの有効性がわかった
CriteoでRandom+FMはFMよりいい制度得られたので、一定の条件でランダムにインタラクションを選択したモデルが全てのインタラクションを使うモデルより優れることがわかった
AutoFMとAutoFM-BNの比較はBNの有効性を示した
AutoFM-BNとAutoFM-BN-$ \alphaのギャップは$ \alphaの有効性を示した。re-trainステージでAttentionの仕組みになる
4.7.3 one-level vs bi-level optimization
https://gyazo.com/cdb28b23480e48825a6ba2577c97223a
Table8でone-levelとbi-level optimizationの性能を検証した。one-levelの方が精度よかった。
5、結論
AutoFISのモデルを提案し、自動的に2階と3階の特徴インタラクションを選択できるようになった
提案手法は基本全てのFMモデルに展開でき、CTR予測に貢献した
実装はEasyであり、二つのOpenデータセットと一つのPrivateデータセット、オンラインABテストにおいて全てより良い精度得られた
Appendix
AvazuとCriteoでの実験のパラメータ詳細
https://gyazo.com/785c588de5290afd619f0c11acda07b9