セグメント木
1. セグメント木
(1) 概要
数列(リスト)に対して,指定区間における最小値/最大値等の演算を求めるクエリにO(log N)で解答するためのデータ構造
データ構造の作成には,O(NlogN)必要
途中で数が変化しても対応可能(変化しない場合は Sparse Tableの方が速い.クエリに対してO(1))
可能な演算例:和,積,最小値,最大値,最大公約数,最小公倍数,論理積,論理和,排他的論理和
以下Yuto.iconのメモ(勉強中なので間違ってるかも)
table:セグメント木の分類
主セグメント木 一点更新・区間取得
双対セグメント木 区間更新・一点取得
遅延評価セグメント木 区間更新・区間取得
セグメント木を使える条件
モノイド
結合律:$ (a \bullet b)\bullet c = a\bullet (b\bullet c)
単位元:$ a \bullet e = e \bullet a = a
ノードiから見た家族の表現
基本的にbit演算の方が高速なイメージがあるけど可読性は掛け算した方が高い
子供: (2*i, 2*i+1) or (i<<1|0, i<<1|1)
兄弟: i^1
親 : i>>1 or i//2
個人的に分かりやすかった資料
code:python
import operator
import math
class SegTree:
# コンストラクタ.デフォルト min
def __init__(self, init_val, type='min', ide_ele=None):
settings = {'min': (min, float('inf')), # 最小値
'max': (max, -float('inf')), # 最大値
'add': (operator.add, 0), # 和
'mul': (operator.mul, 1), # 積
'gcd': (math.gcd, 0), # 最大公約数
'lcm': (lambda x, y: x * y // math.gcd(x, y), 1), # 最小公倍数
'or': (operator.or_, 0), # 論理和
'and': (operator.and_, 1), # 論理積
'xor': (operator.xor, 0) # 排他的論理和
}
n = len(init_val)
try:
if ide_ele is None:
self.segfunc, self.ide_ele = settingstype else:
self.segfunc, self.ide_ele = settingstype0, ide_ele except KeyError:
exit(f'type として未定義の{type}が指定されています.')
self.num = 1 << (n - 1).bit_length() # 葉の数.完全2分木の葉に元のデータを入れるので,2の累乗個
self.tree = self.ide_ele * 2 * self.num # セルの作成.葉の数×2-1個存在.1-index for i in range(n):
self.treeself.num + i = init_vali # 配列の値を葉にセット.0-index -> 1-index for i in range(self.num - 1, 0, -1):
# 葉以外のセルの構築.1-index.根のindexが1
# k番目(0-index)の値を xに更新
def update(self, k, x):
k += self.num # 葉の位置.0-index -> 1-index
while k > 1:
k >>= 1
# k番目(0-index)の値に xを加算
def add(self, k, x):
self.update(k, self.val(k) + x) # xを加えた値で更新
# 半開区間[l,r)に対するクエリ.l,r: 0-index
def query(self, l, r):
res = self.ide_ele
l += self.num # 葉の位置.0-index -> 1-index
r += self.num # 葉の位置.0-index -> 1-index
while l < r:
if l & 1: # lが奇数(右側のセル)のとき
res = self.segfunc(res, self.treel) l += 1 # 一番左のセルは計算済みなので,1つ右側の葉に移動
if r & 1: # rが奇数(右側のセル)のとき.右側は閉区間なので含まず,その左のセルまでが対象
res = self.segfunc(res, self.treer - 1) # 左隣の葉 l >>= 1 # 親の葉へ.lは範囲内
r >>= 1 # 親の葉へ.rは範囲外
return res
# 葉の値を返す
def val(self, k):
# 使用例
seg = SegTree(a, 'min') # インスタンスの作成.デフォルトの最小値'min'は省略可
print(seg.val(1)) # a1を表示 --> 5 print(seg.query(1, 7)) # a1~a6の最小値を表示 --> 5 print(seg.query(0, 10)) # a0~a9の最小値を表示 --> 2. a10は存在しない seg.update(5, 1) # a5を1に変更 print(seg.query(1, 7)) # a1~a6の最小値を表示 --> 1 seg.add(5, 2) # a5に2を追加して 3になる print(seg.query(1, 7)) # a1~a6の最小値を表示 --> 3 2. セグメント木の利用例
(1) 区間内に存在するデータ数,データの合計
データが順に追加される
ある時点において,ある区間内に存在するデータの件数をO(logL)で求める.Lは全体の区間幅
同様に,ある区間内に存在するデータの合計もO(logL)で求まる.
データの値が負,または10^6以上の場合は,座標圧縮が必要.
code: python
# セグメント木の作成
seg_cnt = SegTree(0 * (11), type='add') # 0, 10におけるデータの個数 seg_sum = SegTree(0 * (11), type='add') # 0, 10におけるデータの値の総和 seg_cnt.add(5, 1) # 個数用に 5を1つ追加
seg_sum.add(5, 5) # 合計用にも 5を1つ追加
print(seg_cnt.query(0, 5)) # 区間[0, 5) = 0, 4にはデータなし --> 0 print(seg_cnt.query(5, 6)) # 区間[5, 6) = 5, 5にデータは1件 --> 1 seg_cnt.add(5, 1) # 個数用に 5を1つ追加
seg_sum.add(5, 5) # 合計用にも 5を1つ追加
seg_cnt.add(2, 1) # 個数用に 2を1つ追加
seg_sum.add(2, 2) # 合計用にも 2を1つ追加
print(seg_cnt.query(0, 5)) # 区間[0, 5) = 0, 4にはデータは1件 --> 1 print(seg_cnt.query(5, 6)) # 区間[5, 6) = 5, 5にデータは2件 --> 2 print(seg_sum.query(0, 5)) # 区間0, 4内のデータは 2のみ.合計2 --> 2 print(seg_sum.query(2, 8)) # 区間2, 7内のデータは {2, 5, 5}.合計12 --> 12 3. 2Dセグメント木
セグメント木の2次元バージョン
code: python
import operator
import math
class SegTree2D:
# コンストラクタ.デフォルト min
def __init__(self, init_val, type='min', ide_ele=None):
settings = {'min': (min, float('inf')), # 最小値
'max': (max, -float('inf')), # 最大値
'add': (operator.add, 0), # 和
'mul': (operator.mul, 1), # 積
'gcd': (math.gcd, 0), # 最大公約数
'lcm': (lambda x, y: x * y // math.gcd(x, y), 1), # 最小公倍数
'or': (operator.or_, 0), # 論理和
'and': (operator.and_, 1), # 論理積
'xor': (operator.xor, 0) # 排他的論理和
}
h = len(init_val0) # 行数.2次元リストの高さ w = len(init_val) # 列数.2次元リストの横幅
try:
if ide_ele is None:
self.segfunc, self.ide_ele = settingstype # デフォルト:min else:
self.segfunc, self.ide_ele = settingstype0, ide_ele except KeyError:
raise TypeError(f'引数 type として未定義の {type} が指定されています.')
self.num_h = 1 << (h - 1).bit_length() # 縦方向の葉の数.完全2分木の葉に元のデータを入れるので,2の累乗個
self.num_w = 1 << (w - 1).bit_length() # 横方向の葉の数.完全2分木の葉に元のデータを入れるので,2の累乗個
# セルの作成.縦横ともに,葉の数×2-1個存在.1-index
self.tree = [self.ide_ele * 2 * self.num_w for _ in range(2 * self.num_h)] for i in range(h):
for j in range(w):
# 配列の値を葉にセット.0-index -> 1-index
for i in range(self.num_h, self.num_h + h):
for j in range(self.num_w - 1, 0, -1):
# 葉の左側にあるセルの構築.1-index.根のindexが1
for j in range(self.num_w * 2):
for i in range(self.num_h - 1, 0, -1):
# 葉がある行の上にあるセルの構築.1-index.根のindexが1
# i行j列の要素(0-index)の値を xに更新
def update(self, i, j, x):
i += self.num_h # 葉の行の位置.0-index -> 1-index
j += self.num_w # 葉の列の位置.0-index -> 1-index
cur_j = j
while cur_j > 1:
# 葉の行の左側を更新
cur_j >>= 1
while i > 1:
# 葉の縦方向での親を更新
i >>= 1
cur_j = j
while cur_j > 1:
cur_j >>= 1
# i行j列の要素(0-index)の値に xを加算
def add(self, i, j, x):
self.update(i, j, self.val(i, j) + x) # xを加えた値で更新
# 左上(r1, c1)を含み,右下(r2, c2)を含まない範囲に対するクエリ.r1, c1, r2, c2: 0-index
def query(self, i1, j1, i2, j2):
res = self.ide_ele
i1 += self.num_h # 左上の葉の行の位置.0-index -> 1-index
j1 += self.num_w # 左上の葉の列の位置.0-index -> 1-index
i2 += self.num_h # 左上の葉の行の位置.0-index -> 1-index
j2 += self.num_w # 左上の葉の列の位置.0-index -> 1-index
while i1 < i2:
if i1 & 1: # i1 が奇数(右側のセル)のとき
cur_j1, cur_j2 = j1, j2
while cur_j1 < cur_j2:
if cur_j1 & 1: # cur_j1が奇数(右側のセル)のとき
res = self.segfunc(res, self.treei1cur_j1) cur_j1 += 1 # 一番左のセルは計算済みなので,1つ右側の葉に移動
if cur_j2 & 1: # cur_j2が奇数(右側のセル)のとき.右側は閉区間なので含まず,その左のセルまでが対象
cur_j1 >>= 1 # 親の葉へ.cur_j1は範囲内
cur_j2 >>= 1 # 親の葉へ.cur_j2は範囲外
i1 += 1 # 一番左のセルは計算済みなので,1つ右側の葉に移動
if i2 & 1: # i2が奇数(右側のセル)のとき.右側は閉区間なので含まず,その左のセルまでが対象
cur_j1, cur_j2 = j1, j2
while cur_j1 < cur_j2:
if cur_j1 & 1: # cur_j1が奇数(右側のセル)のとき
cur_j1 += 1 # 一番左のセルは計算済みなので,1つ右側の葉に移動
if cur_j2 & 1: # cur_j2が奇数(右側のセル)のとき.右側は閉区間なので含まず,その左のセルまでが対象
cur_j1 >>= 1 # 親の葉へ.cur_j1は範囲内
cur_j2 >>= 1 # 親の葉へ.cur_j2は範囲外
i1 >>= 1 # 親の葉へ.i1は範囲内
i2 >>= 1 # 親の葉へ.i2は範囲外
return res
# 葉の値を返す
def val(self, i, j):
# 使用例
a = 1, 2, 3, 4], 8, 7, 6, 5, 9, 10, 11, 12, [16, 15, 14, 13 # 初期2次元リスト.2行3列 seg = SegTree2D(a, 'min') # インスタンスの作成.デフォルトの最小値'min'は省略可
seg.update(1, 2, 0)
print(seg.query(0, 0, 3, 3)) # a00~a22の範囲の最小値を表示 --> 0 (= a12) print(seg.query(2, 2, 4, 4)) # a22~a33の範囲の最小値を表示 --> 11 (= a22) seg.add(2, 2, 10) # a22 = 11 + 10 = 21 print(seg.query(2, 2, 4, 4)) # a22~a33の範囲の最小値を表示 --> 12 (= a23) Verified at: (ただし,以下2つの問題は updateがないので,update非検証.要注意)