メモ化再帰DPでTLEを避けるには
再帰呼び出しで実装するのが自然な問題がある
例えば、たどる順序が自明でない動的計画法をメモ化再帰呼び出しで実装するシチュエーション Pythonは数値演算が重いので演算の多い動的計画法はTLEしやすい
PyPyだと関数呼び出しが遅いので再帰呼び出しを酷使するとTLEしやすい
普段はNumbaでAOTコンパイルして高速化してるのだけど…
Numbaは関数内関数の再帰呼び出しをサポートしてない
関数外に置いてもAOTではできない
JITならできるが実行フェーズにコンパイルするので速度上不利
Cythonだと関数を「Cからしか呼ばれない」と明示してコンパイルできるので有効打になり得る?
→現時点で最速
速度の比較にはAtCoderのサーバ上でのテストケース1_17, 1_01の実行速度を使う
1_01が一番大きいが、タイムアウトすると時間が分からないので一回り小さい1_17も併用した
table:実行時間まとめ
1_17 1_01
653 ms TLE Code1 Python 素朴なメモ化再帰
422 ms TLE Code1 PyPy
735 ms TLE Code2 Python メモ化 dict→list
378 ms 485 ms Code2 PyPy
434 ms 565 ms Code3 Python 関数呼び出し前に計算済みかをチェック
498 ms 352 ms Code3 PyPy
169 ms 223 ms Code4 Cython
142 ms 177 ms Code5 Cython メモ化 array→Cの配列
148 ms 164 ms (best) Code6 Cython 探索終了条件分岐を再帰の外で先に済ませる
1209 ms 1225 ms Code7 Numba JIT
295 ms 368 ms Code8 Python 深さ優先探索の帰りがけ順で処理
253 ms 256 ms Code8 PyPy
Code1: 素朴なメモ化再帰
code:python
from collections import defaultdict
import sys
sys.setrecursionlimit(10**6)
def solve(N, M, edges):
longest = {}
def get_longest(start):
if start in longest:
next_edges = edges.get(start)
if not next_edges:
ret = 0
else:
ret = max(get_longest(v) for v in edgesstart) + 1 return ret
return max(get_longest(v) for v in edges)
def main():
N, M = map(int, input().split())
edges = defaultdict(set)
for i in range(M):
v1, v2 = map(int, input().split())
print(solve(N, M, edges))
main()
Code2: メモ化にdictを使ってることを疑問に思う人がいるかもしれないのでlistに変えたバージョン
この感じだとPyPyがdictへのアクセスが苦手な可能性があるな(要検証)
code:python
def solve(N, M, edges):
for i in range(N + 1):
def get_longest(start):
if ret != -1:
return ret
next_edges = edges.get(start)
if not next_edges:
ret = 0
else:
ret = max(get_longest(v) for v in edgesstart) + 1 return ret
return max(get_longest(v) for v in edges)
Code3 関数呼び出し前に計算済みかどうかをチェックするバージョン
関数呼び出しの回数をラスト1回分減らす
Code4 Cython
code:cython
cdef get_longest(long start, dict edges, long: longest): cdef list next_edges
next_edges = edges.get(start)
if not next_edges:
ret = 0
else:
#ret = max(get_longest(v, edges, longest) for v in edgesstart) + 1 ret = 0
x = get_longest(v, edges, longest) + 1
if x > ret:
ret = x
return ret
def solve(N, M, edges):
cdef array.array longest = pyarray.array('l', -1 * (N + 1)) return max(get_longest(v, edges, longest) for v in edges)
関数内関数のままではcdefに出来なかったので外に出した
リストのアクセスが遅そうだったのでarrayに変えた
ジェネレータ内包が問題を起こしてたのでforループに書き換えた
Code5 longestをCの配列としてグローバルに置く
142 ms / 177 ms
Code6 出て行く辺があるかどうかの条件分岐を再帰の外で先に済ませる
148 ms / 164 ms
Numba
速度はさておき何もしなくてもそのまま動くCythonと比べると、Numbaは動くようにするまでが大変
ret = max(get_longest(v) for v in edges[start]) + 1
The use of yield in a closure is unsupported.
Compilation is falling back to object mode WITH looplifting enabled because Function "get_longest" failed type inference due to: non-precise type pyobject
単純にnumba.jitをつけるアプローチではうまくいかない、型推論ができるオブジェクトを引数にする必要がある
隣接リスト形式のグラフの扱いが難関
Pythonで気楽に作るとdafaultdict(list)だが、dictもlistも適切でない
可変長なので雑にnp.arrayにすると空間がN * M
整数列として渡してNumba世界で連結リストにするのが一番マシかなと思う
今回の問題ではグラフの辺の最大数が決まってるのでそのサイズのnp.array を確保する
Code7: Numba JIT
今までの中で格段に遅い、よく通ったもんだ
しかし二つの問題にかかった時間の差は16msecで、一番速いCython実装と同じぐらい
つまりJITであるせいで実行フェーズでコンパイルしてしまい時間を食っているが、コンパイル済みのものは高速ということ
僕が普段NumbaをAOTで使うのもこれが理由。コンパイルフェーズでAOTコンパイルできるなら実行フェーズでJITコンパイルする必要は皆無
Numba AOTはコンパイル時にエラーになる
Untyped global name 'get_longest': cannot determine Numba type of <class 'function'>
これは再帰呼び出しだと人間が認識しているものが、Pythonにとっては「グローバル変数を取得してそれを呼ぶ」だから。
関数定義の後に同じ名前で別のものに名前が再束縛されるかもしれないので、取得したグローバル変数の型が不明なのである。
とはいえこれは不便なので、例えばコンパイル時に「この関数内での同名変数のコールはこの関数自体のコールである」と人間が宣言することで再帰呼び出しを可能にするとかが、将来のNumbaには入るかもしれない、入るといいなぁ
JITではできてるんだからAOTできてもいいじゃんー
というわけで今のところ末尾再帰でない再帰関数をNumbaでAOTする方法は見つけられていない
Code8: 深さ優先探索の帰りがけ順で処理をする案
そもそも再帰呼び出しを使わなくてもコールツリーを深さ優先で探索して帰りに処理をすれば良い、という発想
code:python
def solve(N, M, edges):
longest = {}
while stack:
v = stack.pop()
if v > 0:
if v in longest:
continue
next_edges = edges.get(v)
stack.append(-v)
if next_edges:
stack.extend(next_edges)
else:
next_edges = edges.get(-v)
if not next_edges:
ret = 0
else:
ret = max(longestx for x in next_edges) + 1 return max(longestv for v in edges)