ABC331 G. Collect Them All
Difficulty:2668
問題文
#ABC331 #ABC
#FFT #分割統治FFT #多項式 #畳み込み
#マルコフ連鎖 DP #期待値の線形性
問題
M種類のカードが合計N枚ある。種類$ i(1\leq i \leq M)のカードが $ C_i枚ある。
N枚からランダムに選び、元に戻すという操作を繰り返して、M種類全てを選んだことがある状態になるまでの操作回数の期待値を求めよ。
制約
$ M\leq N \leq 2×10^5
$ \sum_{i=1}^{M}C_i=N
解法
以下は公式解説の「解法2:マルコフ連鎖」を自分なりにかみ砕いたもの
一連の操作は「吸収マルコフ連鎖」とみなせる。どういうことかというと、「N枚のカードのうちどのカードを今までに選んだことがあるか」を状態$ Sとおくと、そこからの遷移は状態が変化するかしないかであり、状態が変化した後は二度とその状態に戻ってくることがなく、また最終的にはM種類のカードを全て選んだことがある状態にたどり着き、そこからは遷移せず「吸収状態」となる、ということである。
これでもまだ何を言っているかわかりづらいが、状態を頂点、遷移を有向辺に見立てた有向グラフを考えたとき、(自己ループのことは一旦置いておくと)DAGになっているということである。
つまり、トポロジカルソートしてDPをすることでこの問題を(理論上は)解くことができる。しかしこのDPの状態数はおよそ$ 2^N、遷移は$ O(N)なので計算量はおよそ$ O(2^NN)であり、到底実行時間制限に間に合わない。
さて、どうするか?困ってしまった...
ここで、「M種類すべてを選ぶまでにかかる手数の最小値」は、「すべての状態について、その状態から他の状態へ遷移するまでの手数の期待値の和」と等しいことがわかる。ちょっとここで止まってじっくり考えてみると分かると思う。
そして、各状態から他の状態に遷移する手数の期待値は、状態Sに含まれるカードの数(つまり、これまでに選んだことがあるカードの数)を$ |S|とすると、遷移できる確率がN枚のうちから$ N-|S|枚のどれかを引ける確率$ \frac{N-|S|}{N}、同じ状態に戻る確率が$ \frac{|S|}{N}なので、遷移するまでの期待値は$ \frac {1}{「新たな状態に遷移できる確率」}=$ \frac{1}{1- \frac{|S|}{N}}となる!ここにこの状態になるまでの確率を$ f(S)とおくと、最終的な答えに加算するのはさきほどの式にに$ f(S)をかけた$ \frac{f(S)}{1- \frac{|S|}{N}}となる。
ここまでを式に表すと、状態Sの集合Pに含まれる各状態Sについて、上記の期待値の和なので
$ \sum_{S\in P}\frac{f(S)}{1- \frac{|S|}{N}}となる。
|S|ごとに分けてみる。$ 0\leq |S|\leq N-1のそれぞれに分けてみると、
$ \sum_{k=0}^{N-1}\sum_{S\in P かつ|S|=k}\frac{f(S)}{1- \frac{|S|}{N}}
|S|=kなので、
$ \sum_{k=0}^{N-1}\sum_{S\in P かつ|S|=k}\frac{f(S)}{1- \frac{k}{N}}
$ \frac{1}{1- \frac{k}{N}}がシグマの内部で共通しているので外に出して
$ \sum_{k=0}^{N-1} \frac{1}{1-\frac{k}{N}} \sum_{S\in P かつ|S|=k}f(S)
$ \frac{1}{1- \frac{k}{N}}の分母と分子にNをかけると$ \frac{N}{N-k}になるので、
$ \sum_{k=0}^{N-1} \frac{N}{N-K} \sum_{S\in P かつ|S|=k}f(S)
になった!(公式解説で3行で行っている式変形を、細かく分けながら行った。)
さて、左側の$ \sum_{k=0}^{N-1}\frac{N}{N-K} はfor文で簡単に求められそうだ。右の部分をどうしようか...?
Sの定義に戻ってみると、「N枚のカードのうちからk枚選んで、M種類すべてを選んではいない状態」なので、|S|が同じ部分でまとめて考えると、これらの和(先ほどの式の右の部分)は「N枚のカードのうちからk枚選んで、M種類すべてを選んではいない確率」である。
ここで余事象を考える。すると、「N枚のカードのうちからk枚選んで、M種類すべてを選んではいない確率」の余事象は「N枚のカードのうちからk枚選んで、M種類すべてを選んではいる確率」となる。
これは、dp[i][j]=$ i(1\leq i\leq M)種類目のカードまでから$ j枚選び、$ i種類すべてを選ぶような方法の数 というDPで$ O(NM)で求めることができる。($ 1\leq i \leq M,1\leq j\leq Nより)
だがこれでは遅い。困ってしまった...2
ここで多項式のことを考えてみる。「いくつかの選び方があって、それを繰り返して何枚選ぶときの〇〇」というのは多項式で表せそうだ!
最終的な式は解説にも載っているように$ \prod_{i=1}^{M}\sum_{j=1}^{C_i} \binom{C_i}{j}X^jとなる。これは...何...となると思う。(分かる人は説明は飛ばしてもらってOK)
まずこれも分解して考えよう。右のシグマは何?というと、わかりやすく横に並べて書いてみると
$ \binom{C_i}{1}X^1+\binom{C_i}{2}X^2+\binom{C_i}{3}X^3...\binom{C_i}{C_i}X^{C_i}
となる。つまり、これがi種類めのカードが$ C_i枚あるときにその中から$ j枚選ぶ方法に$ X^jをかけているものとなる。
$ \binom{n}{r}というのは$ nCrという意味で、n個のものからr個選ぶ通り数のこと。
全ての種類のカードに対するこの式の総積を取ることで、すべての(全種類から1枚ずつ選ぶ)選び方を多項式で表現できている!というわけだ。多項式であらわすという概念がよくわからない場合、FPS(Formal Power Series,形式的冪級数)に他の方の記事を載せているのでそれを読んでほしい。
やった!多項式で表せたぞ!といって、並んでいる式を愚直にかけていく(畳み込み。多項式同士の掛け算はFFTで畳み込みをすることで高速にできる)。
だがこれでは遅い。困ってしまった...3
なぜこれでは間に合わないのか?を考えてみる。今何がしたいかというと、
$ (\binom{C_1}{1}X^1+\binom{C_1}{2}X^2...\binom{C_1}{C_1}X^{C_1})(\binom{C_2}{1}X^1+\binom{C_2}{2}X^2...\binom{C_2}{C_2}X^{C_2})...(\binom{C_M}{1}X^1+\binom{C_M}{2}X^2...\binom{C_M}{C_M}X^{C_M})
という式があって(これは、上の総和記号Σと総積記号Πを展開したもの。総和/積の記号のありがたさがわかる...)
この式を全部掛け算して(畳み込みをして)、$ Xの$ N次多項式を計算したい!
いま最終的に何がしたいかを振り返ると、この多項式を計算することで、「N枚のうちからk枚選ぶ方法のうち、M種類全てを選ぶような選び方が何通りか」がわかり、それを1から引くことで「N枚のカードのうちからk枚選んで、M種類すべてを選んではいない確率」がわかり、冒頭の式$ \sum_{k=0}^{N-1} \frac{N}{N-K} \sum_{S\in P かつ|S|=k}f(S)の$ \sum_{S\in P かつ|S|=k}f(S)がわかり、答えが計算できる。
さて、多項式の総積を求めたい。
今やりたいのは「長さの総和がNであるようなM個の多項式の総積を求める」ことである。
愚直に行うと$ O(NM)かかってしまう。O(NM)のコード
かける順番をうまく工夫することで、計算量を抑えたい。
直感的に思いつくこととしては、長さが短いものから順に畳み込んでいくことである。畳み込みには少なくとも両方の多項式の長さをかけたものぐらいの時間がかかるので、長いものに何度も畳み込みをしたくない。
これをそのまま実装してみる!
priority_queueにpair<int,FPS>を乗せ、firstが小さい順に取り出されるようにする。
ただしFPSに大小関係は定義されていないので、vectorにFPSを追加していき、その添え字をpairに入れておく。
これで多項式の総積を求めることができた!この多項式のk項目には「N枚のうちからk枚選ぶ方法のうち、M種類全てを選ぶような選び方が何通りか」が入っている!!!
あとはこの記事もしくは公式解説を上に進んでいきながら、実装をがんばればよい。
なお、多項式の総積を求める部分は、わざわざprioriry_queueに入れずとも、queueに入れるだけでもできる。これは二分木をイメージするとわかりやすいかもしれない。セグ木のように、下に多項式が並んでいて、隣り合う多項式の積がその一つ上に並んでいて、最終的に一番上にすべての多項式の総積がある、という感じ。
実装
提出コード(queueバージョン)
code:cpp
#include<atcoder/all>
using namespace atcoder;
#include<bits/stdc++.h>
using namespace std;
#define rep2(i, m, n) for (int i = (m); i < (n); ++i)
#define rep(i, n) rep2(i, 0, n)
#define drep2(i, m, n) for (int i = (m)-1; i >= (n); --i)
#define drep(i, n) drep2(i, n, 0)
typedef long long ll;
const int MAX = 5000010;
const int MOD = 998244353;
ll facMAX, finvMAX, invvMAX;
void COMinit(){fac0=fac1=finv0=finv1=invv1=1;for(int i=2;i<MAX;i++){faci=faci-1*i%MOD;invvi=MOD-invvMOD%i*(MOD/i)%MOD;finvi=finvi-1*invvi%MOD;}}
ll COM(int n,int k){if(n<k)return 0;if(n<0||k<0)return 0;if(k==0)return 1;return facn*(finvk*finvn-k%MOD)%MOD;}
using mint = modint998244353;
using FPS = vector<mint>;
int main(){
COMinit();
int n,m;cin >> n >> m;
vector<int>c(m);
for(auto&i:c)cin >> i;
queue<FPS>q;
for(int i = 0;i<m;i++){
FPS g(ci+1);
for(int j = 1;j<=ci;j++){
gj=COM(ci,j);
}
q.push(g);
}
while(q.size()>1){
auto f = q.front();q.pop();
auto g = q.front();q.pop();
q.push(convolution(f,g));
}
auto f = q.front();
vector<mint>p(n+1);
for(int i = 0;i<n;i++){
pi = fi;
pi /= COM(n,i);
pi = mint(1)-pi;
}
mint ans{};
for(int i = 0;i<n;i++){
ans += mint(pi*n)/(n-i);
}
cout<<ans.val()<<endl;
}
提出コード(priority_queueバージョン)
code:cpp
#include<atcoder/all>
using namespace atcoder;
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
//二項係数
const int MAX = 5000010;
const int MOD = 998244353;
ll facMAX, finvMAX, invvMAX;
void COMinit(){fac0=fac1=finv0=finv1=invv1=1;for(int i=2;i<MAX;i++){faci=faci-1*i%MOD;invvi=MOD-invvMOD%i*(MOD/i)%MOD;finvi=finvi-1*invvi%MOD;}}
ll COM(int n,int k){if(n<k)return 0;if(n<0||k<0)return 0;if(k==0)return 1;return facn*(finvk*finvn-k%MOD)%MOD;}
using mint = modint998244353;
using FPS = vector<mint>;
int main(){
COMinit();
int n,m;cin >> n >> m;
vector<int>c(m);
for(auto&i:c)cin >> i;
vector<FPS>fpsv;
priority_queue<pair<ll,ll>,vector<pair<ll,ll>>,greater<pair<ll,ll>>>pq;
for(int i = 0;i<m;i++){
FPS g(ci+1);
for(int j = 1;j<=ci;j++){
gj=COM(ci,j);
}
pq.push({g.size(),fpsv.size()});
fpsv.push_back(g);
}
while(pq.size()>1){
auto fs,fpos = pq.top();pq.pop();
auto gs,gpos = pq.top();pq.pop();
auto ret = convolution(fpsvfpos,fpsvgpos);
pq.push({ret.size(),fpsv.size()});
fpsv.push_back(ret);
}
auto fs,fpos = pq.top();
auto f = fpsvfpos;
vector<mint>p(n+1);
for(int i = 0;i<n;i++){
pi = fi;
pi /= COM(n,i);
pi = mint(1)-pi;
}
mint ans{};
for(int i = 0;i<n;i++){
ans += mint(pi*n)/(n-i);
}
cout<<ans.val()<<endl;
}