Trie
code: python
from typing import Dict, List, Tuple
class TrieNode:
    """One node of the trie, representing a character position in a word.

    Attributes:
        children: Maps the next character to the child node reached by it.
        is_end: True when an inserted word ends exactly at this node.
        freq: Frequency stored for the word ending here (0 until one is set).
    """

    def __init__(self):
        self.children: Dict[str, "TrieNode"] = {}
        self.is_end = False
        self.freq = 0


class Trie:
    """Prefix tree of words with frequencies, queried for the top-k matches.

    Attributes:
        root: The empty-prefix node every word starts from.
        k: Maximum number of results returned by ``find_top_k_prefixes``.
    """

    def __init__(self, k: int = 5):
        self.root = TrieNode()
        self.k = k

    # N: len(word_freq), L: len(word)
    # O(N*L), O(N*L)
    def bulk_insert(self, word_freq: Dict[str, int]):
        """Insert every ``word -> frequency`` pair of *word_freq*."""
        for word, freq in word_freq.items():
            self._insert_with_freq(word, freq)

    # L: len(word)
    # O(L), O(L)
    def _insert_with_freq(self, word: str, freq: int):
        """Insert *word* and record *freq* on its final node.

        Inserting a word that is already present overwrites its frequency.
        """
        # Every word starts from the root.
        node = self.root
        # Build the path one character at a time.
        for char in word:
            # Grow a new branch when this character has none yet.
            if char not in node.children:
                node.children[char] = TrieNode()
            # Step down to the child for this character.
            node = node.children[char]
        node.is_end = True
        node.freq = freq

    # P: len(prefix), M: Total number of words matching the prefix
    # O(P) to walk the prefix, a visit to every node below it to collect
    # the matches, then O(M log M) to sort them.
    def find_top_k_prefixes(self, prefix: str) -> List[Tuple[str, int]]:
        """Return up to ``self.k`` stored words that start with *prefix*.

        Returns:
            ``(word, freq)`` tuples ordered by descending frequency, with
            ties broken by ascending word. An empty list is returned when
            no stored word starts with *prefix*.
        """
        node = self.root
        # Follow the prefix down to the deepest matching node.
        for char in prefix:
            if char not in node.children:
                return []
            node = node.children[char]
        words: List[Tuple[str, int]] = []
        self._collect_words(node, prefix, words)
        return sorted(words, key=lambda x: (-x[1], x[0]))[: self.k]

    # Visits every node in the subtree rooted at `node`.
    def _collect_words(
        self, node: TrieNode, prefix: str, words: List[Tuple[str, int]]
    ):
        """Append ``(word, freq)`` for every word at or below *node*.

        *prefix* is the string spelled by the path from the root to *node*;
        results are appended to *words* in place.
        """
        if node.is_end:
            words.append((prefix, node.freq))
        for char, child_node in node.children.items():
            self._collect_words(child_node, prefix + char, words)
# Trie built by the sample_trie fixture below
# (words: tree, true, try, toy, win, wish):
#
#            root
#           /    \
#          t      w
#         / \     |
#        r   o    i
#      / | \  \   | \
#     e  u  y  y  n  s
#     |  |           |
#     e  e           h
import pytest
@pytest.fixture
def sample_trie():
    """Return a Trie with the default k, loaded with six sample words."""
    word_freq = {
        "tree": 10,
        "true": 35,
        "try": 29,
        "toy": 14,
        "wish": 25,
        "win": 50,
    }
    trie = Trie()
    trie.bulk_insert(word_freq)
    return trie
def test_insert_and_find(sample_trie):
    """Words under "t" come back as (word, freq), highest frequency first."""
    assert sample_trie.find_top_k_prefixes("t") == [
        ("true", 35),
        ("try", 29),
        ("toy", 14),
        ("tree", 10),
    ]