This documentation is automatically generated by online-judge-tools/verification-helper
#include "src/datastructure/BinaryTrie.hpp"整数の $2$ 進数表現を利用して Trie 木のように非負整数を管理する std::multiset をより汎用的にしたデータ構造.
以下ではこのデータ構造にその時点で格納されている非負整数の集合を $S$ としている. また整数の最大 bit 長を $w$ とする. __各クエリにおいて返される整数は $xor\_val$ を適用する前のものであることに注意.
| メンバ関数 | 効果 | 時間計算量 |
|---|---|---|
insert(x) |
$S$ に $x$ を追加する. | $O(w)$ |
erase(x) |
$S$ から $x$ を $1$ つ削除する. | $O(w)$ |
find(x) |
$S$ に $x$ が含めれていない場合は $-1$ を,含まれている場合はそれに対応するノード番号(非負)を返す. | クエリ $O(w)$ |
count(x) |
$S$ に含まれる $x$ の個数を求める. | $O(w)$ |
min_element(xor_val) |
$\argmin_{x \in S} x \oplus xor\_val$ を求める. | $O(w)$ |
max_element(xor_val) |
$\argmax_{x \in S} x \oplus xor\_val$ を求める. | $O(w)$ |
kth_element(k, xor_val) |
$T = {s \oplus xor\_val \mid s \in S}$ として,$T$ の小さい方から $k$ 番目 (0-indexed) の元に対応する $S$ の元を返す. | $O(w)$ |
count_less(x, xor_val) |
$T = {s \oplus xor\_val \mid s \in S}$ において $x$ 未満の元の個数を返す. | $O(w)$ |
#pragma once
#include <array>
#include <cassert>
#include <vector>
template <typename T, int MAX_LOG> struct BinaryTrie {
struct Node {
std::array<int, 2> nxt;
int count;
Node() : nxt{-1, -1}, count(0) {}
};
std::vector<Node> nodes;
inline int& next(int i, int j) { return nodes[i].nxt[j]; }
BinaryTrie() { nodes.emplace_back(); }
void insert(const T& x) {
int cur = 0;
for (int i = MAX_LOG - 1; i >= 0; i--) {
int f = x >> i & 1;
int nxt = next(cur, f);
if (nxt == -1) {
nxt = nodes.size();
next(cur, f) = nxt;
nodes.emplace_back();
}
nodes[cur].count++;
cur = nxt;
}
nodes[cur].count++;
}
void erase(const T& x) {
assert(count(x));
int cur = 0;
for (int i = MAX_LOG - 1; i >= 0; i--) {
int f = x >> i & 1;
nodes[cur].count--;
cur = next(cur, f);
}
nodes[cur].count--;
}
int find(const T& x) {
int cur = 0;
for (int i = MAX_LOG - 1; i >= 0; i--) {
int f = x >> i & 1;
cur = next(cur, f);
if (cur == -1) return -1;
}
return cur;
}
int count(const T& x) {
int idx = find(x);
return idx == -1 ? 0 : nodes[idx].count;
}
T min_element(const T& xor_val = 0) {
int cur = 0;
T res = 0;
for (int i = MAX_LOG - 1; i >= 0; i--) {
int f = xor_val >> i & 1;
int l = next(cur, f), r = next(cur, f ^ 1);
if (l == -1 or nodes[l].count == 0) {
cur = r;
res |= T(f ^ 1) << i;
} else {
cur = l;
res |= T(f) << i;
}
}
return res;
}
T max_element(const T& xor_val = 0) { return min_element(~xor_val); }
T kth_element(int k, const T& xor_val = 0) const {
int cur = 0;
T res = 0;
for (int i = MAX_LOG - 1; i >= 0; i--) {
int f = xor_val >> i & 1;
int l = next(cur, f), r = next(cur, f ^ 1);
if (l == -1 or nodes[l].count <= k) {
cur = r;
k -= (l == -1 ? 0 : nodes[l].count);
res |= T(f ^ 1) << i;
} else {
cur = l;
res |= T(f) << i;
}
}
return res;
}
int count_less(const T& x, const T& xor_val = 0) {
int cur = 0;
int res = 0;
for (int i = MAX_LOG - 1; i >= 0; i--) {
int f = xor_val >> i & 1, g = x >> i & 1;
int l = next(cur, f), r = next(cur, f ^ 1);
if (f != g and l != -1) res += nodes[l].count;
cur = next(cur, g);
if (cur == -1) break;
}
return res;
}
};#line 2 "src/datastructure/BinaryTrie.hpp"
#include <array>
#include <cassert>
#include <vector>
template <typename T, int MAX_LOG> struct BinaryTrie {
struct Node {
std::array<int, 2> nxt;
int count;
Node() : nxt{-1, -1}, count(0) {}
};
std::vector<Node> nodes;
inline int& next(int i, int j) { return nodes[i].nxt[j]; }
BinaryTrie() { nodes.emplace_back(); }
void insert(const T& x) {
int cur = 0;
for (int i = MAX_LOG - 1; i >= 0; i--) {
int f = x >> i & 1;
int nxt = next(cur, f);
if (nxt == -1) {
nxt = nodes.size();
next(cur, f) = nxt;
nodes.emplace_back();
}
nodes[cur].count++;
cur = nxt;
}
nodes[cur].count++;
}
void erase(const T& x) {
assert(count(x));
int cur = 0;
for (int i = MAX_LOG - 1; i >= 0; i--) {
int f = x >> i & 1;
nodes[cur].count--;
cur = next(cur, f);
}
nodes[cur].count--;
}
int find(const T& x) {
int cur = 0;
for (int i = MAX_LOG - 1; i >= 0; i--) {
int f = x >> i & 1;
cur = next(cur, f);
if (cur == -1) return -1;
}
return cur;
}
int count(const T& x) {
int idx = find(x);
return idx == -1 ? 0 : nodes[idx].count;
}
T min_element(const T& xor_val = 0) {
int cur = 0;
T res = 0;
for (int i = MAX_LOG - 1; i >= 0; i--) {
int f = xor_val >> i & 1;
int l = next(cur, f), r = next(cur, f ^ 1);
if (l == -1 or nodes[l].count == 0) {
cur = r;
res |= T(f ^ 1) << i;
} else {
cur = l;
res |= T(f) << i;
}
}
return res;
}
T max_element(const T& xor_val = 0) { return min_element(~xor_val); }
T kth_element(int k, const T& xor_val = 0) const {
int cur = 0;
T res = 0;
for (int i = MAX_LOG - 1; i >= 0; i--) {
int f = xor_val >> i & 1;
int l = next(cur, f), r = next(cur, f ^ 1);
if (l == -1 or nodes[l].count <= k) {
cur = r;
k -= (l == -1 ? 0 : nodes[l].count);
res |= T(f ^ 1) << i;
} else {
cur = l;
res |= T(f) << i;
}
}
return res;
}
int count_less(const T& x, const T& xor_val = 0) {
int cur = 0;
int res = 0;
for (int i = MAX_LOG - 1; i >= 0; i--) {
int f = xor_val >> i & 1, g = x >> i & 1;
int l = next(cur, f), r = next(cur, f ^ 1);
if (f != g and l != -1) res += nodes[l].count;
cur = next(cur, g);
if (cur == -1) break;
}
return res;
}
};