MultiSet
1. 概要
https://qiita.com/toast-uz/items/a63f2d57ec7321186f12 で提案されたPython用のデータ構造
集合とリストを合わせ持ったようなデータ構造.C++のstd::multisetに類する
データの追加/削除が行われ,その時点でのデータ集合に対して二分探索を実行したいときに便利.
同じ要素を複数保存できる点で,SortedSetと異なる.SortedSetの上位互換(?)
table: 比較表
set list SoertedSet MultiSet
in / not in O(1) O(N) O(√N) O(1)
データの追加 O(1) O(1) O(√N) O(logN)
データの削除 O(1) 位置による O(√N) O(logN)
二分探索 不可 bisectでO(logN) O(√N) O(logN)
2. コード
BITのコードも必要.https://qiita.com/toast-uz/items/bf6f142bace86c525532#13-bit を利用
code: python
# https://qiita.com/toast-uz/items/bf6f142bace86c525532#13-bit
class BIT:
def __init__(self, n):
self.n = len(n) if isinstance(n, list) else n
self.size = 1 << (self.n - 1).bit_length()
if isinstance(n, list): # nは1-indexedなリスト
a = 0
for p in n: a.append(p + a-1)
a += [a-1] * (self.size - self.n)
self.d = [ap - ap - (p & -p) for p in range(self.size + 1)]
else: # nは大きさ
self.d = 0 * (self.size + 1)
def __repr__(self):
p = self.size
res = []
while p > 0:
res2 = []
for r in range(p, self.size + 1, p * 2):
l = r - (r & -r) + 1
res2.append(f'{l}, {r}:{self.dr}')
res.append(' '.join(res2))
p >>= 1
res.append(f'{self.sum(p + 1) - self.sum(p) for p in range(self.size)}')
return '\n'.join(res)
def add(self, p, x): # O(log(n)), 点pにxを加算
assert p > 0
while p <= self.size:
self.dp += x
p += p & -p
def get(self, p, default=None): # O(log(n))
assert p > 0
return self.sum(p) - self.sum(p - 1) if 1 <= p <= self.n or default is None else default
def sum(self, p): # O(log(n)), 閉区間1, pの累積和
assert p >= 0
res = 0
while p > 0:
res += self.dp
p -= p & -p
return res
def lower_bound(self, x): # O(log(n)), x <= 閉区間1, pの累積和 となる最小のp
if x <= 0: return 0
p, r = 0, self.size
while r > 0:
if p + r <= self.n and self.dp + r < x:
x -= self.dp + r
p += r
r >>= 1
return p + 1
# https://qiita.com/toast-uz/items/a63f2d57ec7321186f12
import bisect
class MultiSet:
# n: サイズ、compress: 座圧対象list-likeを指定(nは無効)
# multi: マルチセットか通常のOrderedSetか
def __init__(self, n=0, *, compress=[], multi=True):
self.multi = multi
self.inv_compress = sorted(set(compress)) if len(compress) > 0 else i for i in range(n)
self.compress = {k: v for v, k in enumerate(self.inv_compress)}
self.counter_all = 0
self.counter = 0 * len(self.inv_compress)
self.bit = BIT(len(self.inv_compress))
def add(self, x, n=1): # O(log n)
if not self.multi and n != 1: raise KeyError(n)
x = self.compressx
count = self.counterx
if count == 0 or self.multi: # multiなら複数カウントできる
self.bit.add(x + 1, n)
self.counter_all += n
self.counterx += n
def remove(self, x, n=1): # O(log n)
if not self.multi and n != 1: raise KeyError(n)
x = self.compressx
count = self.bit.get(x + 1)
if count < n: raise KeyError(x)
self.bit.add(x + 1, -n)
self.counter_all -= n
self.counterx -= n
def __repr__(self):
return f'MultiSet {{{(", ".join(map(str, list(self))))}}}'
def __len__(self): # oprator len: O(1)
return self.counter_all
def count(self, x): # O(1)
return self.counter[self.compressx] if x in self.compress else 0
def __getitem__(self, i): # operator []: O(log n)
if i < 0: i += len(self)
x = self.bit.lower_bound(i + 1)
if x > self.bit.n: raise IndexError('list index out of range')
return self.inv_compressx - 1
def __contains__(self, x): # operator in: O(1)
return self.count(x) > 0
def bisect_left(self, x): # O(log n)
return self.bit.sum(bisect.bisect_left(self.inv_compress, x))
def bisect_right(self, x): # O(log n)
return self.bit.sum(bisect.bisect_right(self.inv_compress, x))
# 使用例
A = 0, 1, 2, 3, 4, 100, 10000 # MultiSetの要素を集めたリスト.同じ要素があってもよい
mset = MultiSet(compress=A) # MultiSetの作成.まだ空集合
mset.add(2)
mset.add(4)
mset.add(100, 2) # 100を2個追加
mset.add(10000)
print(mset) # 表示 O(NlogN) -> MultiSet {2, 4, 100, 100, 10000}
print(5 in mset) # in判定 O(1) -> False
print(100 in mset) # in判定 O(1) -> True
print(len(mset)) # 要素数 O(1) -> 5
print(mset.count(100)) # 指定要素の個数 O(1) -> 2
print(mset.bisect_left(6)) # 昇順リストとして見て,6以上の最小のindex -> 2
print(mset.bisect_left(100)) # 昇順リストとして見て,100以上の大きい最小のindex -> 2
print(mset.bisect_right(100)) # 昇順リストとして見て,100より大きい最小のindex -> 4
3. 使用例
ABC308 F - Vouchers https://atcoder.jp/contests/abc308/tasks/abc308_f
code: python
# MultiSetをここに貼り付け
N, M = list(map(int, input().split()))
P = list(map(int, input().split()))
L = list(map(int, input().split()))
D = list(map(int, input().split()))
DL = sorted([(Di, Li) for i in range(M)], reverse=True)
mset = MultiSet(compress=P)
for p in P:
mset.add(p)
discount = 0
for d, l in DL: # dの大きい順にチェック
idx = mset.bisect_left(l) # Pidx = l以上で最も小さなP
if idx < len(mset): # l以上のPが存在するなら
mset.remove(msetidx)
discount += d
print(sum(P) - discount)
#データ構造 #MultiSet