yukicoder contest 363 F
解き方
2023年3月3日 追記
いつのまにかリジャッジでTLEとなっていた。
そのため、このページに書かれた解き方は、現時点では誤っていると思われる。
2023年3月6日 追記
畳み込み計算の中の剰余演算を減らしたり、掛け算割り算をビット演算に変えたりしたら少し速くなった。
そうして無理矢理に実行時間制限内に収めて、ギリギリではあるが再びACになった。
しかし、解説や想定解などをみるかぎり、畳み込みの繰り返し2乗法(あるいはダブリング)でやはり問題なさそうだが
私の提出がここまでギリギリになるのは、私の畳み込み計算が遅いということなのだろうか。
これでも、リジャッジされる前のACになるまでに、以前に書いたことのあった畳み込み計算だと遅かったので
実行時間制限に収まるようになんとか高速化したつもりだったのだけれど
追記ここまで
※ ここに記載する解き方は解説と実質は同じものだと思われるが、多項式を使わずに考えた ※ 私の提出は、何度か改良してようやく実行時間制限に収まったので、どこか不適当な箇所がある可能性がある
二項定理を利用するため、この記事においては$ 0^{0} = 1 とする。
また、数列$ \bm{s} の$ i 番目の要素を$ \bm{s}(i) と表記する。ただし、数列は$ 0 番目からはじまるとする。
つまり、長さが$ n の数列$ \bm{s} について、要素を並べる形式で記載すると$ (\bm{s}(0), \bm{s}(1), \ldots, \bm{s}(n-1)) となる。
※ 解き方を記載する際の都合により、問題文と異なる表記になっている
問題文にある通り、正整数$ N 、$ M 、$ L が与えられたとする。
任意の正整数$ m について、非負整数からなる長さ$ L+1 の数列$ \bm{s}_{m} を考える。
$ 0 以上$ L 以下の整数$ k それぞれについて、この数列の$ k 番目の要素$ \bm{s}_{m} (k) を次のように定義する。
各要素が$ 0 以上$ 2^{m} 未満の整数であり、長さが$ N の数列$ \bm{A} は$ 2^{Nm} 個あるが、
その全てに対する$ (\bm{A}(0)\ \mathrm{AND}\ \bm{A}(1)\ \mathrm{AND}\ \cdots\ \mathrm{AND}\ \bm{A}(N-1))^{k} 総和を$ \bm{s}_{m} (k) とする。
定義により、この問題で求めるものは、数列$ \bm{s}_{M} から$ 0 番目の要素を除いたものである。
まず$ \bm{s}_{1} について考える。
各要素が$ 0 または$ 1 であり、長さが$ N の数列$ \bm{A} を考えると、
$ \bm{A}(0)\ \mathrm{AND}\ \bm{A}(1)\ \mathrm{AND}\ \cdots\ \mathrm{AND}\ \bm{A}(N-1) は数列$ \bm{A} の要素が全て$ 1 のときのみ$ 1 になり、その他の場合は$ 0 である。
つまり、$ k が$ 0 の場合を除き$ \bm{s}_{1} (k) = 1 である。また、$ 0^{0} = 1 という仮定から、$ \bm{s}_{1} (0) = 2^{N} である。
よって、$ 2^{N} を繰り返し2乗法で計算すれば、$ \bm{s}_{1} は$ O(L+\log N) の計算量で計算することが可能である。
次に、2つの正整数$ m_{1} と$ m_{2} について、$ \bm{s}_{m_{1}} と$ \bm{s}_{m_{2}} が与えられたときに$ \bm{s}_{m_{1}+m_{2}} を計算する方法を考える。
各要素が$ 0 以上$ 2^{m_{1}+m_{2}} 未満の整数であり、長さが$ N の数列$ \bm{A} に対し、
$ \bm{A}(0)\ \mathrm{AND}\ \bm{A}(1)\ \mathrm{AND}\ \cdots\ \mathrm{AND}\ \bm{A}(N-1) を2進数で表したときの$ i 桁目を$ b_{i} とすると、
任意の非負整数$ k について次の式が成り立つ。
$ (\bm{A}(0)\ \mathrm{AND}\ \bm{A}(1)\ \mathrm{AND}\ \cdots\ \mathrm{AND}\ \bm{A}(N-1))^{k} = (b_{0}+\cdots2^{m_{1}-1}b_{m_{1}-1}+2^{m_{1}}b_{m_{1}}+\cdots2^{m_{1}+m_{2}-1}b_{m_{1}+m_{2}-1})^{k}
二項定理により、右辺は更に次のように変形できる。
$ \sum_{i = 0}^{k} {}_{k} \mathrm{C}_{i} \times (b_{0}+\cdots2^{m_{1}-1}b_{m_{1}-1})^{k-i} \times (2^{m_{1}})^{i}(b_{m_{1}}+\cdots2^{m_{2}-1}b_{m_{1}+m_{2}-1})^{i}
ここで、$ b_{0},\ldots,b_{m_{1}-1} の値と$ b_{m_{1}},\ldots,b_{m_{1}+m_{2}-1} の値は独立して決まるため、
全ての$ \bm{A} について総和をとると、$ k が$ L 以下の場合には次のようになる。
(独立な確率変数の積の期待値が期待値の積になることと同じ理屈で、積の総和は総和の積となる)
$ \bm{s}_{m_{1}+m_{2}} (k) = \sum_{i=0}^{k} {}_{k} \mathrm{C}_{i} \times \bm{s}_{m_{1}} (k-i) \times (2^{m_{1}})^i \bm{s}_{m_{2}} (i)
以上より、繰り返し2乗法あるいはダブリングの要領で計算すれば
$ \bm{s}_{m_{1}} と$ \bm{s}_{m_{2}} から$ \bm{s}_{m_{1}+m_{2}} を求めるような計算を$ O(\log M) 回行うことで$ \bm{s}_{M} が求められる。
ただし、$ \bm{s}_{m_{1}} と$ \bm{s}_{m_{2}} から$ \bm{s}_{m_{1}+m_{2}} を求めるような計算は、愚直に行うと$ L^{2} のオーダーの計算量となるため、効率化が必要である。
ここで、任意の正整数$ m と、$ 0 以上$ L 以下の任意の整数$ k について、$ \bm{t}_{m} (k) = \frac{1}{k!} \bm{s}_{m} (k) となるように数列$ \bm{t}_{m} を定義する。
このとき、$ {}_{k} \mathrm{C}_{i} = \frac{k!}{(k-i)!i!} より、上の式の両辺を$ k! で割ることで、次の式が得られる。
$ \bm{t}_{m_{1}+m_{2}} (k) = \sum_{i=0}^{k} \bm{t}_{m_{1}} (k-i) \times (2^{m_{1}})^i \bm{t}_{m_{2}} (i)
これは数列の畳み込み計算であるため、$ \bm{t}_{m_{1}} と$ \bm{t}_{m_{2}} から$ \bm{t}_{m_{1}+m_{2}} を求める計算は$ O(L \log L) の計算量で実行可能である。
以上より、この問題は全体で$ O(L \log L \log M+\log N) の計算量で計算可能である。
具体的には、高速フーリエ変換により離散フーリエ変換を効率的に行うことで、畳み込みを$ O(L \log L) の計算量で実行できる。
検索するとたくさんの情報が見つかるが、例えば下の記事にも離散フーリエ変換と畳み込みの関係について記載されている。
また、離散フーリエ変換は複素数体上での$ 1 の$ L 乗根を通常の場合に利用するのが、浮動小数点数として扱うと誤差が生じる。
下のWikipedia にも書かれている通り、複素数体以外の環についても、$ L 乗根に相当するものがあれば離散フーリエ変換を考えられる。
下のWikipedia にも数論変換(NTT)について書かれているが、適切な素数などがあれば、整数の範囲で考えることもできる。
解答例
下は上記の方法で解いたときの提出結果である。また、上記の提出の際に提出したソースコードをその下に転記する。
code:C
long long mod_num = 998244353LL;
long long root = 3LL;
int length = 998244352;
long long inverse_root = 0LL;
long long inverse_l = 0LL;
int log_l = 0;
long long pow_root16 = {}; long long pow_root_inv16 = {}; long long power_mod (long long a, long long b, long long p) {
long long ans = 0LL;
a %= p;
if (b <= 0LL) {
return 1LL;
}
ans = power_mod(a, b/2LL, p);
ans = (ans * ans) % p;
if (b%2LL == 1LL) {
ans = (ans * a) % p;
}
return ans;
}
void setup_ntt (int l) {
int tmp_length = 4;
log_l = 1;
while(tmp_length < 2*l) {
tmp_length *= 4;
log_l++;
}
root = power_mod(root, length / tmp_length, mod_num);
inverse_root = power_mod(root, mod_num-2LL, mod_num);
inverse_l = power_mod((long long) tmp_length, mod_num-2LL, mod_num);
length = tmp_length;
for (int i = log_l-1; i > 0; i--) {
pow_rooti-1 *= pow_rooti; }
pow_root_invlog_l-1 = inverse_root; for (int i = log_l-1; i > 0; i--) {
pow_root_invi-1 = pow_root_invi; pow_root_invi-1 *= pow_root_invi; pow_root_invi-1 %= mod_num; pow_root_invi-1 *= pow_root_invi-1; pow_root_invi-1 %= mod_num; }
return;
}
void ntt_4n (long long *a, long long *pow_root) {
long long root_1_4 = pow_root0; for (int i = 0; i < length; i++) {
int idx = 0;
int tmp = i;
for (int j = 0; j < log_l; j++) {
idx <<= 2;
idx |= (tmp&3);
tmp >>= 2;
}
if (i < idx) {
}
}
for (int i = 0; i < log_l; i++) {
int ix2 = 2*i;
int ix2_2 = ix2+2;
int step = (1<<ix2);
int cnt = length>>ix2_2;
long long tmp_root = 1LL;
for (int j = 0; j < step; j++) {
long long w1 = tmp_root;
long long w2 = (w1*w1)%mod_num;
long long w3 = (w2*w1)%mod_num;
for (int k = 0; k < cnt; k++) {
int idx1 = ((k<<ix2_2)|j);
int idx2 = (idx1|step);
int idx3 = (idx1|(2<<ix2));
int idx4 = (idx2|idx3);
long long a1 = aidx1%mod_num; long long a2 = (aidx2*w1)%mod_num; long long a3 = (aidx3*w2)%mod_num; long long a4 = (aidx4*w3)%mod_num; long long wa2 = (a2*root_1_4)%mod_num;
long long wa4 = (a4*root_1_4)%mod_num;
long long pad = (mod_num<<1LL);
aidx2 = a1+wa2-a3-wa4+pad; aidx4 = a1-wa2-a3+wa4+pad; }
tmp_root = (tmp_root*pow_rooti)%mod_num; }
}
for (int i = 0; i < length; i++) {
}
return;
}
int main () {
int n = 0;
int m = 0;
int l = 0;
int res = 0;
long long pow2 = 2LL;
long long inv_cnt = 0LL;
long long pow_cnt = 0LL;
long long pow_inv_l = 0LL;
res = scanf("%d", &n);
res = scanf("%d", &m);
res = scanf("%d", &l);
for (int i = 0; i < l; i++) {
facti+1 *= (long long) (i+1); }
invfl = power_mod(factl, mod_num-2LL, mod_num); for (int i = l; i > 0; i--) {
}
setup_ntt(l+1);
work00 = power_mod(2LL, (long long)n, mod_num); for (int i = 1; i <= l; i++) {
}
for (int i = 0; i <= l; i++) {
}
if (m%2 == 1) {
for (int i = 0; i <= l; i++) {
}
} else {
}
m /= 2;
while (m > 0) {
long long p = 1LL;
for (int j = 0; j <= l; j++) {
p = (p*pow2)%mod_num;
}
for (int j = 0; j < length; j++) {
}
ntt_4n(work0, pow_root_inv); for (int j = 0; j <= l; j++) {
}
for (int j = l+1; j < length; j++) {
}
pow_cnt = 2LL*pow_cnt+1LL;
pow2 = (pow2*pow2)%mod_num;
if (m%2 == 1) {
long long p = 1LL;
for (int i = 0; i <= l; i++) {
p = (p*pow2)%mod_num;
}
for (int i = l+1; i < length; i++) {
}
ntt_4n(ans, pow_root);
for (int i = 0; i < length; i++) {
}
ntt_4n(ans, pow_root_inv);
inv_cnt += pow_cnt+1LL;
}
m /= 2;
}
pow_inv_l = power_mod(inverse_l, inv_cnt, mod_num);
for (int i = 0; i < l; i++) {
printf("%lld\n", (((ansi+1*pow_inv_l)%mod_num)*facti+1)%mod_num); }
return 0;
}
私の提出一覧
table: submissions_yukicoder_contest_363_F
提出のURL 提出時刻 結果 備考
感想
似たように、何かの$ k 乗の総和を求めるような問題をいろいろ考えていたことがあるので、$ O(L^{2} \log M) で解けることはわかった。
また、畳み込みに近い式の形をしていたので、高速フーリエ変換を使えば$ O(L \log L \log M) で解けるかもしれないと思ったが、
高速フーリエ変換による畳み込み計算の高速化を即興で実装するのは難しかったため、コンテスト中に解くのは断念した。
この問題を解いたときに考えたことをもとに、和のk乗の総和について少し考察したりしたいたので、この問題も上の方法を思いついた。
コンテスト後に、高速フーリエ変換と畳み込みに関する記事などを調べて、今回の問題もそれで高速化できそうなことがわかったので
ただ、手元で最悪のケースについて試してみると、一応は現実的な時間と言える範囲の時間で答えが返ってきたので、
できる限り無駄をなくして効率化すれば正解になる可能性があると思い、その後はなんとか実行時間を縮められないかと改善を行った。
使い回せるものは使いまわすなど、途中の計算で無駄な箇所をできる限り削る。
高速フーリエ変換についてはこの記事を参考に、再帰ではなくループで計算するように変更した。 そして、もう何も手が思いつかないというところで、なんとかギリギリAC となった。
いろいろ改善してギリギリだったので、もっと効率の良い解き方が想定されているのかと思ったが、
解説をみると、やはり$ O(L \log L \log M) の計算量が想定されているらしかった。
それならば、なぜ私の解答は実行制限時間ギリギリになってしまうのだろうか、私は何か勘違いしているのか。
ところで、解き方の中では下の記事へのリンクを貼ったが、この問題を解く際に私が見た記事は違うものだった気がする。
記憶と合致する記事が見つからなかったため、下の記事を貼った。内容に不足はないため、特に問題はないが、私は何を見たのか。