かんプリンの学習記録

勉強したことについてメモしています. 主に競技プログラミングの問題の解説やってます.

ABC234 D - Prefix K-th Max

問題はこちら

問題概要

{(1,2,\cdots,N)}の順列{P=(P_1,P_2,\cdots,P_N)}および正整数{K}が与えられる.{i=K,K+1,\cdots,N}について以下の値を求めよ.

  • {P}の先頭{i}項のうち,{K}番目に大きい値

{1\leq K,N\leq 5×10^5}

解説

公式解説にもある通り優先度付きキューで解けますが,BinaryTrieでも解くことができます.BinaryTrieとは整数の(多重)集合を管理することができるデータ構造です.

BinaryTrieでできること
以下の操作を{O(\log{A_{max}})}でできます.

  • 値の追加
  • 値の削除
  • 値の個数の取得
  • {k}番目の値の取得
  • {x}以上の最小値の取得

など

ほかにもXORの操作もできます.
今回の問題では{P_i}を順にBinaryTrieに追加して降順{K}番目の値を出力していけばいいです.計算量は{O(N\log{N})}となります.

BinaryTrieを使える問題
kanpurin.hatenablog.com

提出プログラム

class BinaryTrie:
    def __init__(self, max_query=5*10**5, bitlen=30):
        n = max_query * bitlen
        self.nodes = [-1] * (2 * n)
        self.cnt = [0] * n
        self.id = 0
        self.bitlen = bitlen
 
    def size(self):
        return self.cnt[0]
 
    # xの個数
    def count(self,x):
        pt = 0
        for i in range(self.bitlen-1,-1,-1):
            y = x>>i&1
            if self.nodes[2*pt+y] == -1:
                return 0
            pt = self.nodes[2*pt+y]
        return self.cnt[pt]
 
    # xの挿入
    def insert(self,x):
        pt = 0
        for i in range(self.bitlen-1,-1,-1):
            y = x>>i&1
            if self.nodes[2*pt+y] == -1:
                self.id += 1
                self.nodes[2*pt+y] = self.id
            self.cnt[pt] += 1
            pt = self.nodes[2*pt+y]
        self.cnt[pt] += 1
 
    # xの削除
    # xが存在しないときは何もしない
    def erase(self,x):
        if self.count(x) == 0:
            return
        pt = 0
        for i in range(self.bitlen-1,-1,-1):
            y = x>>i&1
            self.cnt[pt] -= 1
            pt = self.nodes[2*pt+y]
        self.cnt[pt] -= 1
 
    # 昇順x番目の値(1-indexed)
    def kth_elm(self,x):
        assert 1 <= x <= self.size()
        pt, ans = 0, 0
        for i in range(self.bitlen-1,-1,-1):
            ans <<= 1
            if self.nodes[2*pt] != -1 and self.cnt[self.nodes[2*pt]] > 0:
                if self.cnt[self.nodes[2*pt]] >= x:
                    pt = self.nodes[2*pt]
                else:
                    x -= self.cnt[self.nodes[2*pt]]
                    pt = self.nodes[2*pt+1]
                    ans += 1
            else:
                pt = self.nodes[2*pt+1]
                ans += 1
        return ans
 
    # x以上の最小要素が昇順何番目か(1-indexed)
    # x以上の要素がない時はsize+1を返す
    def lower_bound(self,x):
        pt, ans = 0, 1
        for i in range(self.bitlen-1,-1,-1):
            if pt == -1: break
            if x>>i&1 and self.nodes[2*pt] != -1:
                ans += self.cnt[self.nodes[2*pt]]
            pt = self.nodes[2*pt+(x>>i&1)]
        return ans

bt = BinaryTrie()
N,K = map(int,input().split())
P = list(map(int,input().split()))
 
for i in range(K-1):
    bt.insert(P[i])
 
for i in range(K-1,N):
    bt.insert(P[i])
    print(bt.kth_elm(i+2-K))

https://atcoder.jp/contests/abc234/submissions/28417715

感想

データ構造で殴る👊