Alleviating Cold-start Problem in CTR Prediction with A Variational Embedding Learning Framework
Author: Xiaoxiao Xu, Chen Yang, Qian Yu, Zhiwei Fang, Jiaxing Wang, Chaosheng Fan, Yang He, Changping Peng, Zhangang Lin, Jingping Shao
Organization: JD.com
論文を選ぶ理由:
Cold-start 課題に対象
New userとnew itemへの改善が顕著
様々のモデルと組み合わせる
1、Introduction
new userとnew adsがオンライン広告によくあり、既存のCTRモデルのEmbeddingに対応しづらい
Taobao Display Ad Click データセットでは毎日12%の広告と16.9%のユーザが新規である
Cold-start問題に主に二つの方向がある
コンテンツベースの方法:広告とユーザの他のほかのアトリビュートを利用する
Meta-learning:学習プロセスを工夫して、既存のユーザと広告からknowledge transfer
しかし、上の方法では広告とユーザのEmbeddingが一つのベクトルとして学習している
新規ユーザと広告がデータが足りないので、従来の手法では対応難しい
ポイントのEmbeddingも過学習しやすい
本研究は変分推論を用いて、Embeddingがポイントではなく分布のようにモデリングして学習させる
Neural Networkで学習できる事前分布をモデリングした
正則化して過学習を防ぐ
2、Method
2.1 定式化
2.1.1 前提
https://pic4.zhimg.com/v2-b36f8433239998c70ef1e79fa4e6103b_r.jpg
あるデータセット$ Dに対して、入力とラベルのペア $ (x, y)がある。
入力の $ x は $ \bold x = \lbrack u, c(u),i,c(i), contexts \rbrackである
Embeddingは x から得られる:$ z = g_{\phi}(x)
EmbeddingをMLPネットワーク層に入力し、予測値を得られる $ \hat{y} = \sigma(f_{\theta}(z))
二値分類のcross-entropy目的関数で学習する:
$ L(\phi,\theta) = l(\phi,\theta) = -ylog(\hat{y})-(1-y)log(1-\hat{y})
2.1.2 変分推論
変分推論は潜在変数 $ z(ここはEmbedding)と観測された入力 $ xの条件事後分布 $ p(z|x)に近似する手法である。
ベイズのルールに従って: $ p(z|x) = p(x,z)/p(x)。しかし $ p(x)= \int p(x, z) dzは効率に求める方法がない
$ ELBO(\phi_q) = E(logp(x|z)) - D_{KL}(q_{\phi_q(z|x)}||p(z))\ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ (6)
2.1.3 Embeddingのポイント推測の問題
Embeddingはベクトルの形なら、新規IDにデータ不足により、孤立なEmbeddingが作られてしまう
ポイントとして推測されると、Overfittingの可能性が高くなる
https://pic4.zhimg.com/v2-042b750129e066454e410379e38a3ee7_r.jpg
2.2 分布推測
2.2.1 Variational Embedding Learning Framework提案
Embeddingを入力の潜在変数zとして、変分推論で計算する。VELFで全ての分布は正規分布と仮定する
変分推論は、元分布の事後確率を選んだ分布の事後確率で近似する$ q_{\phi_q}(z|x) \approx p(z|x)
式(6)により、Optimizationの目標は二つある:
前項のLikelihoodを最大化する
$ E(logp(x|z))はlog-lossで計算する
後項のKL Divergenceを最小化する
KL Divergenceは正則化として取り扱う。ハイパーパラメータ $ \alphaを導入し、バランスをコントロールできる
式(6)の $ p(z)を下のようにモデリングする
固定の事前分布はモデルの汎用性に影響があるため、neural networkで事前分布をパラメータ化して学習する
$ p(z) \equiv p_{\phi_p}(z|c)
cはuser ID と ad ID と関わる特徴である。$ \phi_qはnerual networkのパラメータである
目標関数は以下である
$ L(\phi, \theta) = l(\phi, \theta) - \alpha D_{KL}(q_{\phi_q}(z|x)||p_{\phi_p}(z))
ここ$ \phi = \lbrack \phi_q,\phi_p \rbrack
2.2.2 Mean-field Variational Embedding Framework
本研究では、新規広告と新規ユーザ両方のcold-startに対象しているため、Mean-field Theoryを導入した
ユーザのEmbeddingと広告のEmbeddingを分けて、独立している仮説
ユーザと広告それぞれの潜在変数は:$ z^u, z^iと定義する
目標関数は以下のように修正する
$ L(\phi, \theta) = l(\phi, \theta) - \alpha (D_{KL}(q_{\phi_q^u}(z^u|u)||p_{\phi_p^u}(z^u)) + D_{KL}(q_{\phi_q^i}(z^i|i)||p_{\phi_p^i}(z^i))
ここ$ \phi = \lbrack \phi_q^u,\phi_p^u,\phi_q^i,\phi_p^i \rbrack
2.2.3 Regularization Priors
過学習を防ぐために、ユーザとアイテムそれぞれの標準正規分布を事前分布として導入する
$ p(z^u)=N(0, I^u) $ p(z^i)=N(0,I^i)
$ L(\phi^u, \phi^i, \theta) = l(\phi^u, \phi^i, \theta) - \alpha (D_{KL}(q_{\phi_q^u}(z^u|u)||p_{\phi_p^u}(z^u)) + D_{KL}(q_{\phi_q^i}(z^i|i)||p_{\phi_p^i}(z^i)))
$ - \alpha(D_{KL}(p_{\phi_p^u}(z^u)||p(z^u))+D_{KL}(p_{\phi_p^i}(z^i)||p(z^i)))
2.2 トレーニング
ユーザと広告のEmbeddingは同じ流れで計算できる。ここはユーザのEmbedding
$ q_{\phi_q^u} = N(\mu_q^u(u), {\sigma_q^u}^2(u))
$ p_{\phi_p^u}(z^u)=p_{\phi_p^u}(z^u|c(u))=N(\mu_p^u(c(u)), {\sigma_p^u}^2(c(u)))
事後確率の$ \mu_q^u, \sigma_q^uはユーザIDをDNNに入力して計算する
事前確率の$ \mu_p^u, \sigma_p^uはユーザのアトリビュートをDNNに入力して計算する
VELFで実際のユーザEmbeddingは、事後確率からサンプリングして得られる。
$ z^u = \mu_q(u)+\sigma_q(u)\odot\epsilon^u
$ \epsilon^u \thicksim N(0, I)
すべてのEmbeddingをconcatして、従来のモデルに入力してラベルを予測する
$ \hat{y} = \sigma(f_\theta(concat(z^u, z^i, z^{c(u)}, z^{c(i)}, z^{context})))
Log-lossは同じように計算する
$ l(\phi,\theta)=\frac{1}{L}\sum_{k=1}^L(-ylog\hat{y}_{(k)})-(1-y)log(1-\hat{y}_{(k)})
LはMonte Carlo Samplingの数であり、本研究では1に固定する
KL-divergenceの項は正規分布の性質によって計算と微分できる
$ D_{KL}(q||p)=log\frac{\sigma_p}{\sigma_q}+\frac{\sigma_q^2+(\mu_q-\mu_p)^2}{2\sigma_p^2}
以上の式をまとめて、ELBOを最大化してEnd-to-Endの学習できる
2.3 推論
ユーザのEmbeddingの例として、推論の計算は以下である。広告も同じように計算できる
$ z^u = g(u)\mu_q(u)+(1-g(u))\mu_p(c(u))
$ g(u) = \frac{1}{1+e^{-F(u)+\epsilon}}
パラメータ化した事前確率が非頻繁なIDに対する補正の役割である。$ g(u)はsigmoidに似ている比率を調整する関数である。$ F(u)はユーザがトレーニングデータの中の頻度である
3、実験と結果
RQ1 VELFはほかのcold-start対策と比べてどれぐらい効くか
RQ2 様々のbackboneにVELFのパフォーマンスは
RQ3 VELFに分布の推測、パラメータ化と正規化事前分布の効果は
3.1 データセット
MovieLens-1M 時間順にソートして最初80%のサンプルをトレーニングし、30レビュー未満のユーザをテストセットに入れる
Taobao Display Ad Click 最初の7日間のデータでトレーニングし、最後の1日でテストする
CIKM2019 EComm AI デフォルト設定でトレーニングとテストを行う
詳細は下のように設置する
https://gyazo.com/b991deddd090aabd6f11096c399e591c
3.2 RQ1 Comparison with State-of-the-Art
https://pic3.zhimg.com/v2-9de3ec6d5b759342065e6e35670a6abe_r.jpg
本研究はSOTAのDropoutNetとMWUFより優れるAUCを得られている
VELFはアイテムの改善がユーザより大きい
Dropout Netを見たら、MovieLensのパフォーマンスがTaobaoより弱い。それはTaobaoのデータが豊富なユーザとアイテムアトリビュートがあるからである。コンテンツベースのアプローチはアトリビュートにセンシティブである
CIKMでNew userとAllのAUCは同じである。それはデフォルトの分割でテストユーザがトレーニングセットのユーザと重なっていないからである
3.3 RQ2 Generalization Experiments
DeepFM以外ほかのbackboneにもテストした
https://gyazo.com/fc3a4972ff51eee92e4225cb5ca70120
ほかのモデルでもVELFが有効であり、改善が見られる
3.4 RQ3 Ablation Study
https://gyazo.com/2123adf821ffd98dc89fa0637314d1fe
VELF(Point): 学習した分布の平均値をユーザとアイテムのEmbeddingにする
VELF(Fixed): パラメータ化したPriorを固定した
VELF(No-R): パラメータ化したPriorの正則化を取り除いた
結論:
VELF(Point)はDropoutNetとほぼ同じAUCのため、pointより分布の優れさがわかる
VELF(No-R)はVELF(Fixed)より優れるので、パラメータ化したPriorが改善につながる
VELFとVELF(No-R)を比べると、正規化も改善につながることがわかる
4、結論と個人コメント
変分推論を用いて、Embeddingの分布を推測してCold-startに効果が得られた
EmbeddingはPointではなく、分布を学習するアイディアが素晴らしいと思う
様々なbackboneモデルと組み合わせできる
複雑な処理が入っているので、本番向けの実装コスト大きい
オフラインのAUCが改善できたが、オンラインのA/Bテストはどうなるか気になる
が、一部のコメントから同じ性能再現できなかったという報告もある