Python で二次元セグ木を実装した話

はじめに

ライブラリ整備シリーズ第一弾
最近の ABC に出たとある問題でお世話になったので実装しました。
下の例では、minimum query の場合に絞って話をしています。

実装概要

乗せる配列を A (4 × 4) とします。
つまり、セグ木(名前を tree とします)は 8 × 8 配列となります。

初期化

① tree[4][4], A[0][0] を基準として、それぞれ対応するところに値を入れます

f:id:Anko_nasubi:20211127161614p:plain

② 各列・各行について、一次元セグ木と同じ要領で最小値を記録していきます(簡単のため、Python では動かない記法を使っています)

f:id:Anko_nasubi:20211127162304p:plain

③ ②で新しく作った各行 (列でも OK) の値を用いて各列の最小値を記録していきます

f:id:Anko_nasubi:20211127163153p:plain

更新

例として、A[1][2] の値を更新します。

① 一次元セグ木と同じ要領で、ビットシフトなどを用いて値がある行、列の対応する場所を更新します(下図の橙色の部分)

f:id:Anko_nasubi:20211127163747p:plain

② ①の更新で更新した行・列を組み合わせて更新します (この場合、tree[5][1], tree[5][3], tree[1][6], tree[2][6] を更新したので、H = 1 or 2 かつ W = 1 or 3 である場所 (下図の青色の部分) を更新します)

f:id:Anko_nasubi:20211127164241p:plain

クエリ処理

一次元セグ木のときの処理を二重 While を用いて組み合わせます。
例えば、1 <= H < 4, 1 <= W < 4 であるような範囲 (下図の橙色) の最小値は、下図の青色で示した部分の最小値を求めることで求めることができます。

f:id:Anko_nasubi:20211127165338p:plain

ソースコード

計算量

セグ木に乗せる配列の大きさを H × W とすると、

  • 初期化 : O(HW)
  • 更新 : O(log H log W)
  • クエリ : O(log H log W)

となります。

使い方

__init__(self, val, segf = min, ide = 10**18)

val ... セグ木に乗せる配列です。二次元配列でないと動きません。
segf ide ... デフォルトでは minimum query を扱うときのものになっています。必要に応じて指定してください。

update(self, h, w, x)

h 列目 w 行目の値を x に更新します。

query(h1, h2, w1, w2)

h ∈ [h1, h2), w ∈ [w1, w2) であるような h, w に関して、それらの最小値 (segf = min の場合) を返します。h1 >= h2 または w1 >= w2 の場合 ide を返します。

class seg2d:
    def __init__(self, val, segf = min, ide = 10**18):
        h = len(val)
        w = len(val[0])
        self.segf = segf
        self.ide = ide
        self.h = 1 << (h - 1).bit_length()
        self.w = 1 << (w - 1).bit_length()
        self.tree = [[ide] * (2 * self.w) for _ in range(2 * self.h)]
        for i in range(h):
            for j in range(w):
                self.tree[self.h + i][self.w + j] = val[i][j]
        for i in range(h):
            for j in range(self.w - 1, 0, -1):
                self.tree[self.h + i][j]=self.segf(self.tree[self.h + i][j * 2], self.tree[self.h + i][j * 2 + 1])
        for i in range(w):
            for j in range(self.h - 1, 0, -1):
                self.tree[j][self.w + i] = self.segf(self.tree[j * 2][self.w + i], self.tree[j * 2 + 1][self.w + i])
        for i in range(self.h - 1, 0, -1):
            for j in range(self.w - 1, 0, -1):
                self.tree[i][j] = self.segf(self.tree[i][j * 2], self.tree[i][j * 2 + 1])
    def update(self, h, w, x):
        h += self.h
        w += self.w
        self.tree[h][w] = x
        h2 = h
        while h2 > 1:
            self.tree[h2 >> 1][w] = self.segf(self.tree[h2][w], self.tree[h2 ^ 1][w])
            h2 >>= 1
        w2 = w
        while w2 > 1:
            self.tree[h][w2 >> 1] = self.segf(self.tree[h][w2], self.tree[h][w2 ^ 1])
            w2 >>= 1
        h2 = h
        while h2 > 1:
            w2 = w
            while w2 > 1:
                self.tree[h >> 1][w2 >> 1] = self.segf(self.tree[h >> 1][w2], self.tree[h >> 1][w2 ^ 1])
                w2 >>= 1
            h2 >>= 1
    def query(self, h1, h2, w1, w2):
        ret = self.ide
        h1 += self.h
        h2 += self.h
        w1 += self.w
        w2 += self.w
        while h1 < h2:
            w3 = w1
            w4 = w2
            if h1 & 1:
                while w3 < w4:
                    if w3 & 1:
                        ret = self.segf(ret, self.tree[h1][w3])
                        w3 += 1
                    if w4 & 1:
                        ret = self.segf(ret, self.tree[h1][w4 - 1])
                    w3 >>= 1
                    w4 >>= 1
                h1 += 1
            w3 = w1
            w4 = w2
            if h2 & 1:
                while w3 < w4:
                    if w3 & 1:
                        ret = self.segf(ret, self.tree[h2 - 1][w3])
                        w3 += 1
                    if w4 & 1:
                        ret = self.segf(ret, self.tree[h2 - 1][w4 - 1])
                    w3 >>= 1
                    w4 >>= 1
            h1 >>= 1
            h2 >>= 1
        return ret