Segment Tree beats!
セグメント木では一点更新・区間取得が、遅延セグメント木では区間更新・区間取得ができます。ではSegment Tree beats!では何ができるのでしょうか?
Segment Tree beats!では、通常の遅延セグメント木では扱うことのできない更新クエリに対応することができます。(例:区間chmin/chmaxクエリ)
注:この記事はhitonanodeさんによるSegment Tree beats!の解説記事を参考に、よりコンテスト中の実用を意識してまとめたものです。
実装方法
先に結論から書きます。区間加算&chmax&chmin/区間和のSegment Tree beats!は以下のようになります。これは、hitonanodeさんの記事の一番下に載っているものとACLのlazy_segtreeを合わせて展開したものです。標準ライブラリのヘッダ(`#include <bits/stdc++.h>` など)をインクルードしていれば、このままコピペして使えます。
code:cpp
#if __cplusplus >= 202002L
#include <bit>
#endif
#ifdef _MSC_VER
#include <intrin.h>
#endif

// Bit-manipulation helpers (same as the AtCoder Library's internal_bit).
namespace internal {

#if __cplusplus >= 202002L

using std::bit_ceil;

#else

// Smallest power of two that is >= n (1 for n == 0).
// @return same as std::bit_ceil
unsigned int bit_ceil(unsigned int n) {
    unsigned int x = 1;
    while (x < n) x *= 2;
    return x;
}

#endif

// Number of trailing zero bits of n.
// @param n 1 <= n (the result is undefined for n == 0)
// @return same as std::countr_zero
int countr_zero(unsigned int n) {
#ifdef _MSC_VER
    unsigned long index;
    _BitScanForward(&index, n);
    return static_cast<int>(index);
#else
    return __builtin_ctz(n);
#endif
}

// constexpr variant of countr_zero.
// @param n 1 <= n (undefined for n == 0)
// @return same as std::countr_zero
constexpr int countr_zero_constexpr(unsigned int n) {
    int x = 0;
    while (!(n & (1u << x))) x++;
    return x;
}

}  // namespace internal
#if __cplusplus >= 201703L template <class S,
auto op,
auto e,
class F,
auto mapping,
auto composition,
auto id>
struct Segtreebeats {
static_assert(std::is_convertible_v<decltype(op), std::function<S(S, S)>>,
"op must work as S(S, S)");
static_assert(std::is_convertible_v<decltype(e), std::function<S()>>,
"e must work as S()");
static_assert(
std::is_convertible_v<decltype(mapping), std::function<S(F, S)>>,
"mapping must work as F(F, S)");
static_assert(
std::is_convertible_v<decltype(composition), std::function<F(F, F)>>,
"compostiion must work as F(F, F)");
static_assert(std::is_convertible_v<decltype(id), std::function<F()>>,
"id must work as F()");
template <class S,
S (*op)(S, S),
S (*e)(),
class F,
S (*mapping)(F, S),
F (*composition)(F, F),
F (*id)()>
struct Segtreebeats {
public:
Segtreebeats() : Segtreebeats(0) {}
explicit Segtreebeats(int n) : Segtreebeats(std::vector<S>(n, e())) {}
explicit Segtreebeats(const std::vector<S>& v) : _n(int(v.size())) {
size = (int)internal::bit_ceil((unsigned int)(_n));
log = internal::countr_zero((unsigned int)size);
d = std::vector<S>(2 * size, e());
lz = std::vector<F>(size, id());
for (int i = size - 1; i >= 1; i--) {
update(i);
}
}
void set(int p, S x) {
assert(0 <= p && p < _n);
p += size;
for (int i = log; i >= 1; i--) push(p >> i);
for (int i = 1; i <= log; i++) update(p >> i);
}
S get(int p) {
assert(0 <= p && p < _n);
p += size;
for (int i = log; i >= 1; i--) push(p >> i);
}
S prod(int l, int r) {
assert(0 <= l && l <= r && r <= _n);
if (l == r) return e();
l += size;
r += size;
for (int i = log; i >= 1; i--) {
if (((l >> i) << i) != l) push(l >> i);
if (((r >> i) << i) != r) push((r - 1) >> i);
}
S sml = e(), smr = e();
while (l < r) {
if (l & 1) sml = op(sml, dl++); if (r & 1) smr = op(d--r, smr); l >>= 1;
r >>= 1;
}
return op(sml, smr);
}
S all_prod() { return d1; } void apply(int p, F f) {
assert(0 <= p && p < _n);
p += size;
for (int i = log; i >= 1; i--) push(p >> i);
for (int i = 1; i <= log; i++) update(p >> i);
}
void apply(int l, int r, F f) {
assert(0 <= l && l <= r && r <= _n);
if (l == r) return;
l += size;
r += size;
for (int i = log; i >= 1; i--) {
if (((l >> i) << i) != l) push(l >> i);
if (((r >> i) << i) != r) push((r - 1) >> i);
}
{
int l2 = l, r2 = r;
while (l < r) {
if (l & 1) all_apply(l++, f);
if (r & 1) all_apply(--r, f);
l >>= 1;
r >>= 1;
}
l = l2;
r = r2;
}
for (int i = 1; i <= log; i++) {
if (((l >> i) << i) != l) update(l >> i);
if (((r >> i) << i) != r) update((r - 1) >> i);
}
}
template <bool (*g)(S)> int max_right(int l) {
return max_right(l, [](S x) { return g(x); });
}
template <class G> int max_right(int l, G g) {
assert(0 <= l && l <= _n);
assert(g(e()));
if (l == _n) return _n;
l += size;
for (int i = log; i >= 1; i--) push(l >> i);
S sm = e();
do {
while (l % 2 == 0) l >>= 1;
while (l < size) {
push(l);
l = (2 * l);
l++;
}
}
return l - size;
}
l++;
} while ((l & -l) != l);
return _n;
}
template <bool (*g)(S)> int min_left(int r) {
return min_left(r, [](S x) { return g(x); });
}
template <class G> int min_left(int r, G g) {
assert(0 <= r && r <= _n);
assert(g(e()));
if (r == 0) return 0;
r += size;
for (int i = log; i >= 1; i--) push((r - 1) >> i);
S sm = e();
do {
r--;
while (r > 1 && (r % 2)) r >>= 1;
while (r < size) {
push(r);
r = (2 * r + 1);
r--;
}
}
return r + 1 - size;
}
} while ((r & -r) != r);
return 0;
}
private:
int _n, size, log;
std::vector<S> d;
std::vector<F> lz;
void all_apply(int k, F f) {
if (k < size) {
lzk = composition(f, lzk); if (dk.fail) push(k), update(k); }
}
void push(int k) {
all_apply(2 * k + 1, lzk); }
};
// Second smallest *distinct* value among {a, a2, c, c2}, given a < a2 and c < c2
// (a/c are the minima of two halves and a2/c2 their runners-up).
template <typename Num> inline Num second_lowest(Num a, Num a2, Num c, Num c2) noexcept {
    if (a == c) return std::min(a2, c2);  // shared minimum: best runner-up wins
    if (a2 <= c) return a2;               // left half lies entirely below c
    if (c2 <= a) return c2;               // right half lies entirely below a
    return std::max(a, c);                // the larger of the two minima
}
// Second largest *distinct* value among {a, a2, b, b2}, given a > a2 and b > b2
// (a/b are the maxima of two halves and a2/b2 their runners-up).
template <typename Num> inline Num second_highest(Num a, Num a2, Num b, Num b2) noexcept {
    if (a == b) return std::max(a2, b2);  // shared maximum: best runner-up wins
    if (a2 >= b) return a2;               // left half lies entirely above b
    if (b2 >= a) return b2;               // right half lies entirely above a
    return std::min(a, b);                // the smaller of the two maxima
}
// Element type of the tree, and the sentinel used for "no bound" in chmin/chmax.
// 2^61 is presumably chosen well below LLONG_MAX so that BINF plus a bias does not overflow.
using BNum = long long;
constexpr BNum BINF = BNum{1} << 61;
// Segment summary for range add / chmin / chmax / sum.
struct S {
    BNum lo = BINF;     // minimum value
    BNum hi = -BINF;    // maximum value
    BNum lo2 = BINF;    // smallest value strictly greater than lo (BINF if none)
    BNum hi2 = -BINF;   // largest value strictly less than hi (-BINF if none)
    BNum sum = 0;       // sum of all elements
    unsigned sz = 0;    // number of elements
    unsigned nlo = 0;   // number of elements equal to lo
    unsigned nhi = 0;   // number of elements equal to hi
    bool fail = false;  // set by mapping() when the tag cannot be applied to this summary alone
    // Empty segment.
    S() = default;
    // sz_ copies of the value x (intentionally implicit so set(i, value) works).
    S(BNum x, unsigned sz_ = 1) : lo(x), hi(x), sum(x * sz_), sz(sz_), nlo(sz_), nhi(sz_) {}
};
// Identity element for op: an empty segment.
S e() { return S{}; }
// Combines the summaries of two adjacent segments l (left) and r (right).
S op(S l, S r) {
    S merged;
    merged.sum = l.sum + r.sum;
    merged.sz = l.sz + r.sz;
    // Minimum side: count the elements from every half that attains the minimum.
    merged.lo = std::min(l.lo, r.lo);
    merged.lo2 = second_lowest(l.lo, l.lo2, r.lo, r.lo2);
    merged.nlo = (l.lo <= r.lo ? l.nlo : 0u) + (r.lo <= l.lo ? r.nlo : 0u);
    // Maximum side, symmetric.
    merged.hi = std::max(l.hi, r.hi);
    merged.hi2 = second_highest(l.hi, l.hi2, r.hi, r.hi2);
    merged.nhi = (l.hi >= r.hi ? l.nhi : 0u) + (r.hi >= l.hi ? r.nhi : 0u);
    return merged;
}
struct F {
BNum lb, ub, bias;
F(BNum chmax_ = -BINF, BNum chmin_ = BINF, BNum add = 0) : lb(chmax_), ub(chmin_), bias(add) {}
static F chmin(BNum x) noexcept { return F(-BINF, x, BNum(0)); }
static F chmax(BNum x) noexcept { return F(x, BINF, BNum(0)); }
static F add(BNum x) noexcept { return F(-BINF, BINF, x); };
};
// Composes two tags into one that applies `fold` first, then `fnew`.
F composition(F fnew, F fold) {
    // Clamping after fold's bias equals clamping before it with bounds shifted by -bias,
    // so fold's bounds are moved into the post-bias frame, clamped by fnew, then moved back.
    const BNum shift = fold.bias;
    const BNum lb_after_shift = fold.lb + shift;
    const BNum ub_after_shift = fold.ub + shift;
    F result;
    result.lb = std::max(std::min(lb_after_shift, fnew.ub), fnew.lb) - shift;
    result.ub = std::min(std::max(ub_after_shift, fnew.lb), fnew.ub) - shift;
    result.bias = shift + fnew.bias;
    return result;
}
// Identity tag: no chmax, no chmin, no add.
F id() { return F{}; }
// Applies tag f (chmax by f.lb, chmin by f.ub, then add f.bias) to the summary x.
// Returns the updated summary, or x with fail = 1 when the summary alone does not
// carry enough information; Segtreebeats then pushes the tag to the children.
S mapping(F f, S x) {
// Empty segment stays empty.
if (x.sz == 0) return e();
// Every element ends up with the same value: the segment is already uniform,
// the tag forces a single value (lb == ub), or the clamp sends all elements to one bound.
else if (x.lo == x.hi or f.lb == f.ub or f.lb >= x.hi or f.ub <= x.lo) {
return S(std::min(std::max(x.lo, f.lb), f.ub) + f.bias, x.sz);
} else if (x.lo2 == x.hi) {
// Exactly two distinct values (lo and hi). Since f.lb < hi and f.ub > lo here,
// they stay distinct after clamping, and each is the other's runner-up.
x.lo = x.hi2 = std::max(x.lo, f.lb) + f.bias;
x.hi = x.lo2 = std::min(x.hi, f.ub) + f.bias;
x.sum = x.lo * x.nlo + x.hi * x.nhi;
return x;
} else if (f.lb < x.lo2 and f.ub > x.hi2) {
// Three or more distinct values, but the clamp only moves elements equal to lo
// (raised to below lo2) and elements equal to hi (lowered to above hi2), so
// nlo/nhi remain valid and the sum changes by a closed-form amount.
BNum nxt_lo = std::max(x.lo, f.lb), nxt_hi = std::min(x.hi, f.ub);
x.sum += (nxt_lo - x.lo) * x.nlo - (x.hi - nxt_hi) * x.nhi + f.bias * x.sz;
x.lo = nxt_lo + f.bias, x.hi = nxt_hi + f.bias, x.lo2 += f.bias, x.hi2 += f.bias;
return x;
}
// The clamp would merge lo/hi with interior values: ask the tree to recurse.
x.fail = 1;
return x;
}
// Shorthand for the range add / chmin / chmax + range sum Segment Tree Beats! defined above.
using Beats = Segtreebeats<S, op, e, F, mapping, composition, id>;
使用例としては以下のようになります。
code:cpp
int main(){
Segtreebeats<S, op, e, F, mapping, composition, id> seg(10);
rep(i,10)seg.set(i,i);
//区間加算
seg.apply(0,5,F(-BINF,BINF,10));
rep(i,10){
cout<<seg.get(i).sum<<" ";
}
cout<<endl;
//区間chmin
seg.apply(0,5,F(-BINF,12,0));
rep(i,10){
cout<<seg.get(i).sum<<" ";
}
cout<<endl;
//区間chmax
seg.apply(0,5,F(100,BINF,0));
rep(i,10){
cout<<seg.get(i).sum<<" ";
}
cout<<endl;
//区間和
cout << seg.prod(0,10).sum<<endl;
}
実際にAtCoderの問題で使ってみると、次のようになります。通常のセグ木(ACLのsegtree)も併用しているので少しわかりづらく、またchmaxクエリしか使っていないためaddとchminの機能は無駄になっていますが…
(当然ですが、不要な操作は削除したほうが高速になります)
code:cpp
// Monoid for the auxiliary ACL segtree: max over ll.
using S2 = ll;
S2 op2(S2 a, S2 b) { return a < b ? b : a; }
// Returns 0, which is the identity for max as long as all stored values are >= 0.
S2 e2() { return S2(0); }
int main(){
ll n;cin >> n;
vector<ll>h(n);
for(auto&i:h)cin >> i;
atcoder::segtree<S2,op2,e2>segm(n);
set<ll>s{h.begin(),h.end()};
map<ll,ll>mp;
for(auto i:s)mpi=mp.size(); Segtreebeats<S, op, e, F, mapping, composition, id> seg(n);
rep(i,n)seg.set(i,0);
ll now{1};
rep(i,n){
cout<<seg.get(i).sum<<" ";
}
cout<<endl;
};
rep(i,n){
// db();
ll pos = segm.prod(mp[hi],n); ll x = hi * (i+1-pos) - seg.prod(pos,i+1).sum; now+=x;
cout<<(now)<<" ";
seg.apply(pos,i+1,F(hi,BINF,0)); }
}