Skip to content

手写 bitset

点击展开代码
cpp
#include <bits/stdc++.h>
using namespace std;

// Hand-written, dynamically sized bitset stored as 64-bit words.
//
// Layout: bit p lives in word a[p / 64] at bit position p % 64.
// Invariant: the unused high bits of the last word (positions >= n) are always
// zero. count(), any() and operator== rely on it; every operation that could
// set those padding bits calls trim() afterwards.
//
// Binary operations marked "same length" assert n == other.n.
struct FastBitset
{
    using ull = unsigned long long;
    static constexpr int W = 64; // bits per storage word

    int n; // number of bits
    int m; // number of ull words
    std::vector<ull> a;

    // Builds a bitset of _n bits, all equal to val.
    FastBitset(int _n = 0, bool val = false)
    {
        init(_n, val);
    }

    // Resizes to _n bits and sets every bit to val.
    // A negative _n is treated as 0 (an empty bitset) instead of producing a
    // negative word count.
    void init(int _n, bool val = false)
    {
        n = _n < 0 ? 0 : _n;
        // ceil(n / 64) without the intermediate n + 63, which could overflow int.
        m = (n >> 6) + ((n & 63) != 0);
        a.assign(m, val ? ~0ULL : 0ULL);
        trim();
    }

    // Mask of the valid bits inside the last word (all ones when n is a
    // multiple of 64).
    ull last_mask() const
    {
        int r = n & 63;
        if (r == 0)
            return ~0ULL;
        return (1ULL << r) - 1;
    }

    // Clears the padding bits of the last word to restore the invariant.
    void trim()
    {
        if (m && (n & 63))
        {
            a.back() &= last_mask();
        }
    }

    // ---- Single-bit operations. Precondition (unchecked): 0 <= p < n. ----

    // Sets bit p to 1.
    void set(int p)
    {
        a[p >> 6] |= 1ULL << (p & 63);
    }

    // Sets bit p to 0.
    void reset(int p)
    {
        a[p >> 6] &= ~(1ULL << (p & 63));
    }

    // Toggles bit p.
    void flip(int p)
    {
        a[p >> 6] ^= 1ULL << (p & 63);
    }

    // Returns the value of bit p.
    bool test(int p) const
    {
        return (a[p >> 6] >> (p & 63)) & 1ULL;
    }

    // Read-only access, same as test(p).
    bool operator[](int p) const
    {
        return test(p);
    }

    // ---- Whole-set operations ----

    // Sets every bit to 1.
    void set_all()
    {
        std::fill(a.begin(), a.end(), ~0ULL);
        trim();
    }

    // Sets every bit to 0.
    void reset_all()
    {
        std::fill(a.begin(), a.end(), 0ULL);
    }

    // Toggles every bit.
    void flip_all()
    {
        for (auto &x : a)
            x = ~x;
        trim(); // flipping turned the zero padding bits into ones
    }

    // Number of bits set to 1.
    int count() const
    {
        int res = 0;
        for (ull x : a)
            res += __builtin_popcountll(x);
        return res;
    }

    // True if at least one bit is 1.
    bool any() const
    {
        for (ull x : a)
        {
            if (x)
                return true;
        }
        return false;
    }

    // True if no bit is 1.
    bool none() const
    {
        return !any();
    }

    // Returns the index of the first 1 bit at position >= pos, or -1 if there
    // is none. A negative pos searches from the beginning.
    int find_first1(int pos = 0) const
    {
        if (pos < 0)
            pos = 0;
        if (pos >= n)
            return -1;

        int id = pos >> 6;
        int off = pos & 63;

        // Drop the bits below pos in the first inspected word.
        ull x = a[id] & (~0ULL << off);
        while (true)
        {
            if (x)
            {
                int p = (id << 6) + __builtin_ctzll(x);
                return p < n ? p : -1;
            }
            ++id;
            if (id >= m)
                break;
            x = a[id];
        }
        return -1;
    }

    // Returns the index of the first 0 bit at position >= pos, or -1 if there
    // is none. A negative pos searches from the beginning.
    int find_first0(int pos = 0) const
    {
        if (pos < 0)
            pos = 0;
        if (pos >= n)
            return -1;

        int id = pos >> 6;
        int off = pos & 63;

        ull x = (~a[id]) & (~0ULL << off);
        // Padding bits are stored as 0 and must not be reported as free slots.
        if (id == m - 1)
            x &= last_mask();

        while (true)
        {
            if (x)
            {
                int p = (id << 6) + __builtin_ctzll(x);
                return p < n ? p : -1;
            }
            ++id;
            if (id >= m)
                break;
            x = ~a[id];
            if (id == m - 1)
                x &= last_mask();
        }
        return -1;
    }

    // ---- Bitwise operators; both operands must have the same length. ----

    FastBitset &operator&=(const FastBitset &b)
    {
        assert(n == b.n);
        for (int i = 0; i < m; i++)
            a[i] &= b.a[i];
        return *this;
    }

    FastBitset &operator|=(const FastBitset &b)
    {
        assert(n == b.n);
        for (int i = 0; i < m; i++)
            a[i] |= b.a[i];
        return *this;
    }

    FastBitset &operator^=(const FastBitset &b)
    {
        assert(n == b.n);
        for (int i = 0; i < m; i++)
            a[i] ^= b.a[i];
        return *this;
    }

    friend FastBitset operator&(FastBitset x, const FastBitset &y)
    {
        x &= y;
        return x;
    }

    friend FastBitset operator|(FastBitset x, const FastBitset &y)
    {
        x |= y;
        return x;
    }

    friend FastBitset operator^(FastBitset x, const FastBitset &y)
    {
        x ^= y;
        return x;
    }

    // Returns a copy with every bit toggled.
    FastBitset operator~() const
    {
        FastBitset res = *this;
        res.flip_all();
        return res;
    }

    // Sets this bitset to src << k (bits move towards higher indices; bits
    // shifted past n are dropped). src must have the same length and may be
    // *this. k <= 0 copies src unchanged; k >= n clears the bitset.
    void assign_shift_left(const FastBitset &src, int k)
    {
        assert(n == src.n);

        if (k <= 0)
        {
            if (this != &src)
                a = src.a;
            return;
        }

        if (k >= n)
        {
            reset_all();
            return;
        }

        int ws = k >> 6; // whole-word part of the shift
        int bs = k & 63; // remaining bit shift inside a word

        // Aliased case (this == &src): shift in place, walking from the high
        // words down so that every word is read before it is overwritten.
        if (this == &src)
        {
            if (ws)
            {
                for (int i = m - 1; i >= 0; i--)
                {
                    a[i] = (i >= ws ? a[i - ws] : 0ULL);
                }
            }

            if (bs)
            {
                for (int i = m - 1; i >= 1; i--)
                {
                    a[i] = (a[i] << bs) | (a[i - 1] >> (64 - bs));
                }
                a[0] <<= bs;
            }

            trim();
            return;
        }

        // Non-aliased case: compute every destination word directly from src.
        for (int i = m - 1; i >= 0; i--)
        {
            ull val = 0;

            int from = i - ws;

            if (from >= 0)
            {
                val |= src.a[from] << bs;

                // bs == 0 is excluded because a shift by 64 is undefined.
                if (bs && from - 1 >= 0)
                {
                    val |= src.a[from - 1] >> (64 - bs);
                }
            }

            a[i] = val;
        }

        trim();
    }

    // Sets this bitset to src >> k (bits move towards lower indices; bits
    // shifted below 0 are dropped). src must have the same length and may be
    // *this. k <= 0 copies src unchanged; k >= n clears the bitset.
    void assign_shift_right(const FastBitset &src, int k)
    {
        assert(n == src.n);

        if (k <= 0)
        {
            if (this != &src)
                a = src.a;
            return;
        }

        if (k >= n)
        {
            reset_all();
            return;
        }

        int ws = k >> 6; // whole-word part of the shift
        int bs = k & 63; // remaining bit shift inside a word

        // Aliased case (this == &src): shift in place, walking from the low
        // words up so that every word is read before it is overwritten.
        if (this == &src)
        {
            if (ws)
            {
                for (int i = 0; i < m; i++)
                {
                    a[i] = (i + ws < m ? a[i + ws] : 0ULL);
                }
            }

            if (bs)
            {
                for (int i = 0; i + 1 < m; i++)
                {
                    a[i] = (a[i] >> bs) | (a[i + 1] << (64 - bs));
                }
                a[m - 1] >>= bs;
            }

            trim();
            return;
        }

        // Non-aliased case: compute every destination word directly from src.
        for (int i = 0; i < m; i++)
        {
            ull val = 0;

            int from = i + ws;

            if (from < m)
            {
                val |= src.a[from] >> bs;

                // bs == 0 is excluded because a shift by 64 is undefined.
                if (bs && from + 1 < m)
                {
                    val |= src.a[from + 1] << (64 - bs);
                }
            }

            a[i] = val;
        }

        trim();
    }

    FastBitset &operator<<=(int k)
    {
        assign_shift_left(*this, k);
        return *this;
    }

    FastBitset &operator>>=(int k)
    {
        assign_shift_right(*this, k);
        return *this;
    }

    friend FastBitset operator<<(FastBitset x, int k)
    {
        x <<= k;
        return x;
    }

    friend FastBitset operator>>(FastBitset x, int k)
    {
        x >>= k;
        return x;
    }

    // Two bitsets are equal when they have the same length and the same bits.
    bool operator==(const FastBitset &b) const
    {
        return n == b.n && a == b.a;
    }

    bool operator!=(const FastBitset &b) const
    {
        return !(*this == b);
    }

    // Sets this bitset to x & y without allocating; all three must have the
    // same length.
    void assign_and(const FastBitset &x, const FastBitset &y)
    {
        assert(n == x.n && n == y.n);
        for (int i = 0; i < m; i++)
            a[i] = x.a[i] & y.a[i];
    }

    // Sets this bitset to x | y without allocating; all three must have the
    // same length.
    void assign_or(const FastBitset &x, const FastBitset &y)
    {
        assert(n == x.n && n == y.n);
        for (int i = 0; i < m; i++)
            a[i] = x.a[i] | y.a[i];
    }

    // Sets this bitset to x ^ y without allocating; all three must have the
    // same length.
    void assign_xor(const FastBitset &x, const FastBitset &y)
    {
        assert(n == x.n && n == y.n);
        for (int i = 0; i < m; i++)
            a[i] = x.a[i] ^ y.a[i];
    }

    // Number of positions where both this bitset and b hold a 1, i.e.
    // popcount(*this & b) without building the intermediate bitset.
    // The lengths may differ: only the words both operands have are inspected,
    // since the missing words of the shorter one contribute nothing.
    int count(const FastBitset &b) const
    {
        const int lim = m < b.m ? m : b.m;
        int res = 0;
        for (int i = 0; i < lim; i++)
            res += __builtin_popcountll(a[i] & b.a[i]);
        return res;
    }

    // Returns a new bitset of length r - l + 1 holding bits [l, r] of this
    // one, with bit l at index 0. Requires l >= 0. An empty range (r < l)
    // yields an empty bitset, and positions >= n read as 0.
    FastBitset slice(int l, int r) const
    {
        assert(l >= 0);
        if (r < l)
            return FastBitset(0);

        int len = r - l + 1;
        FastBitset res(len);

        int base = l >> 6; // first source word
        int off = l & 63;  // offset of bit l inside that word

        for (int i = 0; i < res.m; i++)
        {
            int j = base + i;
            // Words past the end of the storage are treated as all-zero.
            ull lo = (j < m ? a[j] : 0ULL);

            if (off == 0)
            {
                res.a[i] = lo;
            }
            else
            {
                ull hi = (j + 1 < m ? a[j + 1] : 0ULL);
                res.a[i] = (lo >> off) | (hi << (64 - off));
            }
        }

        res.trim();
        return res;
    }
};