cp-library

This documentation is automatically generated by online-judge-tools/verification-helper

View the Project on GitHub rniya/cp-library

:heavy_check_mark: Binary Trie
(src/datastructure/BinaryTrie.hpp)

概要

整数の $2$ 進数表現を利用して Trie 木のように非負整数を管理する std::multiset をより汎用的にしたデータ構造.

以下ではこのデータ構造にその時点で格納されている非負整数の集合を $S$ としている. また整数の最大 bit 長を $w$ とする. __各クエリにおいて返される整数は $xor\_val$ を適用する前のものであることに注意.

メンバ関数 効果 時間計算量
insert(x) $S$ に $x$ を追加する. $O(w)$
erase(x) $S$ から $x$ を $1$ つ削除する. $O(w)$
find(x) $S$ に $x$ が含めれていない場合は $-1$ を,含まれている場合はそれに対応するノード番号(非負)を返す. クエリ $O(w)$
count(x) $S$ に含まれる $x$ の個数を求める. $O(w)$
min_element(xor_val) $\argmin_{x \in S} x \oplus xor\_val$ を求める. $O(w)$
max_element(xor_val) $\argmax_{x \in S} x \oplus xor\_val$ を求める. $O(w)$
kth_element(k, xor_val) $T = {s \oplus xor\_val \mid s \in S}$ として,$T$ の小さい方から $k$ 番目 (0-indexed) の元に対応する $S$ の元を返す. $O(w)$
count_less(x, xor_val) $T = {s \oplus xor\_val \mid s \in S}$ において $x$ 未満の元の個数を返す. $O(w)$

Verified with

Code

#pragma once
#include <array>
#include <cassert>
#include <vector>

template <typename T, int MAX_LOG> struct BinaryTrie {
    struct Node {
        std::array<int, 2> nxt;
        int count;
        Node() : nxt{-1, -1}, count(0) {}
    };

    std::vector<Node> nodes;

    inline int& next(int i, int j) { return nodes[i].nxt[j]; }

    BinaryTrie() { nodes.emplace_back(); }

    void insert(const T& x) {
        int cur = 0;
        for (int i = MAX_LOG - 1; i >= 0; i--) {
            int f = x >> i & 1;
            int nxt = next(cur, f);
            if (nxt == -1) {
                nxt = nodes.size();
                next(cur, f) = nxt;
                nodes.emplace_back();
            }
            nodes[cur].count++;
            cur = nxt;
        }
        nodes[cur].count++;
    }

    void erase(const T& x) {
        assert(count(x));
        int cur = 0;
        for (int i = MAX_LOG - 1; i >= 0; i--) {
            int f = x >> i & 1;
            nodes[cur].count--;
            cur = next(cur, f);
        }
        nodes[cur].count--;
    }

    int find(const T& x) {
        int cur = 0;
        for (int i = MAX_LOG - 1; i >= 0; i--) {
            int f = x >> i & 1;
            cur = next(cur, f);
            if (cur == -1) return -1;
        }
        return cur;
    }

    int count(const T& x) {
        int idx = find(x);
        return idx == -1 ? 0 : nodes[idx].count;
    }

    T min_element(const T& xor_val = 0) {
        int cur = 0;
        T res = 0;
        for (int i = MAX_LOG - 1; i >= 0; i--) {
            int f = xor_val >> i & 1;
            int l = next(cur, f), r = next(cur, f ^ 1);
            if (l == -1 or nodes[l].count == 0) {
                cur = r;
                res |= T(f ^ 1) << i;
            } else {
                cur = l;
                res |= T(f) << i;
            }
        }
        return res;
    }

    T max_element(const T& xor_val = 0) { return min_element(~xor_val); }

    T kth_element(int k, const T& xor_val = 0) const {
        int cur = 0;
        T res = 0;
        for (int i = MAX_LOG - 1; i >= 0; i--) {
            int f = xor_val >> i & 1;
            int l = next(cur, f), r = next(cur, f ^ 1);
            if (l == -1 or nodes[l].count <= k) {
                cur = r;
                k -= (l == -1 ? 0 : nodes[l].count);
                res |= T(f ^ 1) << i;
            } else {
                cur = l;
                res |= T(f) << i;
            }
        }
        return res;
    }

    int count_less(const T& x, const T& xor_val = 0) {
        int cur = 0;
        int res = 0;
        for (int i = MAX_LOG - 1; i >= 0; i--) {
            int f = xor_val >> i & 1, g = x >> i & 1;
            int l = next(cur, f), r = next(cur, f ^ 1);
            if (f != g and l != -1) res += nodes[l].count;
            cur = next(cur, g);
            if (cur == -1) break;
        }
        return res;
    }
};
#line 2 "src/datastructure/BinaryTrie.hpp"
#include <array>
#include <cassert>
#include <vector>

template <typename T, int MAX_LOG> struct BinaryTrie {
    struct Node {
        std::array<int, 2> nxt;
        int count;
        Node() : nxt{-1, -1}, count(0) {}
    };

    std::vector<Node> nodes;

    inline int& next(int i, int j) { return nodes[i].nxt[j]; }

    BinaryTrie() { nodes.emplace_back(); }

    void insert(const T& x) {
        int cur = 0;
        for (int i = MAX_LOG - 1; i >= 0; i--) {
            int f = x >> i & 1;
            int nxt = next(cur, f);
            if (nxt == -1) {
                nxt = nodes.size();
                next(cur, f) = nxt;
                nodes.emplace_back();
            }
            nodes[cur].count++;
            cur = nxt;
        }
        nodes[cur].count++;
    }

    void erase(const T& x) {
        assert(count(x));
        int cur = 0;
        for (int i = MAX_LOG - 1; i >= 0; i--) {
            int f = x >> i & 1;
            nodes[cur].count--;
            cur = next(cur, f);
        }
        nodes[cur].count--;
    }

    int find(const T& x) {
        int cur = 0;
        for (int i = MAX_LOG - 1; i >= 0; i--) {
            int f = x >> i & 1;
            cur = next(cur, f);
            if (cur == -1) return -1;
        }
        return cur;
    }

    int count(const T& x) {
        int idx = find(x);
        return idx == -1 ? 0 : nodes[idx].count;
    }

    T min_element(const T& xor_val = 0) {
        int cur = 0;
        T res = 0;
        for (int i = MAX_LOG - 1; i >= 0; i--) {
            int f = xor_val >> i & 1;
            int l = next(cur, f), r = next(cur, f ^ 1);
            if (l == -1 or nodes[l].count == 0) {
                cur = r;
                res |= T(f ^ 1) << i;
            } else {
                cur = l;
                res |= T(f) << i;
            }
        }
        return res;
    }

    T max_element(const T& xor_val = 0) { return min_element(~xor_val); }

    T kth_element(int k, const T& xor_val = 0) const {
        int cur = 0;
        T res = 0;
        for (int i = MAX_LOG - 1; i >= 0; i--) {
            int f = xor_val >> i & 1;
            int l = next(cur, f), r = next(cur, f ^ 1);
            if (l == -1 or nodes[l].count <= k) {
                cur = r;
                k -= (l == -1 ? 0 : nodes[l].count);
                res |= T(f ^ 1) << i;
            } else {
                cur = l;
                res |= T(f) << i;
            }
        }
        return res;
    }

    int count_less(const T& x, const T& xor_val = 0) {
        int cur = 0;
        int res = 0;
        for (int i = MAX_LOG - 1; i >= 0; i--) {
            int f = xor_val >> i & 1, g = x >> i & 1;
            int l = next(cur, f), r = next(cur, f ^ 1);
            if (f != g and l != -1) res += nodes[l].count;
            cur = next(cur, g);
            if (cur == -1) break;
        }
        return res;
    }
};
Back to top page