Skip to content

算法竞赛常用平衡树有 Treap、Fhq-Treap、Splay、替罪羊树。

实际上还有 WBLT,但是感觉 WBLT 能干的上面总的找到能干的,算了,体量不要太大,后续再说。

Treap

Treap 是二叉查找树和二叉堆的结合,即:树上一个节点维护两个值,第一关键字维护 BST 的性质,第二关键字维护 heap 的性质。在本文中,称第一关键字为键值(Key),第二关键字为权值(Val),且本文中 Treap 均默认维护大根堆。

Treap 通过初始化固定 Val,使得 BST 能根据 heap 的性质在 BST 形态因增删改查操作发生改变时,动态调整形态。而初始化时,将 Val 设为一个随机值,即可保证 BST 的树高时刻在 O(logn)

因为 Treap 只通过 Val 维护树高,所以增删改查操作不受影响,与朴素 BST 操作完全一致。只是需要在增删等会影响 BST 形态的操作时,同时维护 Val 的 heap 性质即可。

插入:

假设插入元素 v 在 Treap 上新建了一个节点 p,新建节点 p 的 Val 是随机产生的,可能并不满足 Val(p)Val(fa(p)),那么就需要修正 Treap 的形态。

根据 pfa(p) 的左儿子还是右儿子,可将修正操作分为左旋和右旋。

上图中不满足 heap 性质需要旋转的点分别为 3(左旋) 和 2(右旋) 只有上图中涉及的点会在旋转中被影响到,若其中部分点不存在,当作空节点同样处理即可。

具体而言,当前节点为 pp 左儿子为 lsp 右儿子为 rs。对于右旋,令 qls 的右儿子,ls 成为根,p 成为 ls 的右儿子,因此 q 成为 p 的左儿子;对于左旋,令 qls 的左儿子,rs 成为根,p 成为 rs 的左儿子,因此 q 成为 p 的右儿子。

p 调整至父节点后,可能会继续和新的父节点不满足 heap 的性质,所以递归下去继续旋转即可。因为树高是 O(logn) 的,最多转到根节点,所以旋转次数是 O(logn) 的。

删除:

假设删除元素 v 使得 Treap 上节点 p 被删除,若 p 不存在左右儿子,则 p 可以直接删除。若 p 只存在左儿子或右儿子,则可以将 p 替换成左儿子或右儿子即可。若 p 既存在左儿子也存在右儿子,通常有两种解决办法:第一种,惰性删除,即只是将节点上的 cnt 标记为 0,在后续的操作中特殊处理做类似考虑。第二种,将 p 通过旋转操作旋转至叶子后不存在左右儿子,即可直接删去。因为树高是 O(logn) 的所以旋转次数是 O(logn) 的。但是向左儿子旋转还是向右儿子旋转由 heap 的性质决定,要保证旋转后节点仍保持 heap 的性质,即:向 Val 小的儿子转(把 Val 大的儿子转上去)。

cpp
struct node {
    int l, r;
    int key, val;
    int cnt, sz;
};
struct Treap {
    node tr[N];
    int root, idx;
    void push_up(int p) { tr[p].sz = tr[ls(p)].sz + tr[rs(p)].sz + tr[p].cnt; }
    int make_node(int key) {
        tr[++idx].key = key;
        tr[idx].val = rand();
        tr[idx].cnt = tr[idx].sz = 1;
        return idx;
    }
    void zig(int &p) {
        int q = ls(p);
        ls(p) = rs(q);
        rs(q) = p;
        p = q;
        push_up(rs(p)), push_up(p);
    }
    void zag(int &p) {
        int q = rs(p);
        rs(p) = ls(q);
        ls(q) = p;
        p = q;
        push_up(ls(p)), push_up(p);
    }
    void build() {
        make_node(-inf);
        make_node(inf);
        root = 1;
        rs(1) = 2;
        push_up(root);
        if (tr[1].val < tr[2].val)
            zag(root);
    }
    void insert(int &p, int key) {
        if (!p)
            p = make_node(key);
        else if (tr[p].key == key)
            tr[p].cnt++;
        else if (key < tr[p].key) {
            insert(ls(p), key);
            if (tr[ls(p)].val > tr[p].val)
                zig(p);
        } else if (key > tr[p].key) {
            insert(rs(p), key);
            if (tr[rs(p)].val > tr[p].val)
                zag(p);
        }
        push_up(p);
    }
    void remove(int &p, int key) {
        if (!p)
            return;
        if (tr[p].key == key) {
            if (tr[p].cnt > 1)
                tr[p].cnt--;
            else if (ls(p) || rs(p)) {
                if (!rs(p) || tr[ls(p)].val > tr[rs(p)].val) {
                    zig(p);
                    remove(rs(p), key);
                } else {
                    zag(p);
                    remove(ls(p), key);
                }
            } else
                p = 0;
        } else if (key < tr[p].key)
            remove(ls(p), key);
        else if (key > tr[p].key)
            remove(rs(p), key);
        push_up(p);
    }
    int get_rank_by_key(int p, int key) {
        if (!p)
            return 0;
        if (key == tr[p].key)
            return tr[ls(p)].sz;
        if (key < tr[p].key)
            return get_rank_by_key(ls(p), key);
        return tr[ls(p)].sz + tr[p].cnt + get_rank_by_key(rs(p), key);
    }
    int get_key_by_rank(int p, int rank) {
        if (!p)
            return inf;
        if (tr[ls(p)].sz >= rank)
            return get_key_by_rank(ls(p), rank);
        if (tr[ls(p)].sz + tr[p].cnt >= rank)
            return tr[p].key;
        return get_key_by_rank(rs(p), rank - tr[ls(p)].sz - tr[p].cnt);
    }
    int get_prev(int p, int key) {
        if (!p)
            return -inf;
        if (key <= tr[p].key)
            return get_prev(ls(p), key);
        return max(tr[p].key, get_prev(rs(p), key));
    }
    int get_next(int p, int key) {
        if (!p)
            return inf;
        if (key >= tr[p].key)
            return get_next(rs(p), key);
        return min(tr[p].key, get_next(ls(p), key));
    }
} treap;

Fhq-Treep

通常情况下,Treap 指的是有旋 Treap,即通过旋转操作维护 heap 性质。但是实际上,还存在另一种维护 heap 性质的方式:分裂、合并。一般称之为 Fhq-Treap 或无旋 Treap。

分裂:

分裂操作有两种:按键值分裂,按排名分裂。

键值分裂:

对于一个参数 v,把 Treap 分裂成两部分:tree1 满足所有节点键值 vtree2 满足所有节点键值 >v。返回 tree1 的根和 tree2 的根。

具体实现,从 Treap 根节点开始,若当前节点键值 Key>v,则当前节点和右子树都属于 tree2,同时左子树中可能仍然存在键值 >v 的节点,所以要递归下去分裂并把递归返回的 tree2 作为当前节点的左子树;若当前节点键值 Keyv,则当前节点和左子树都属于 tree1,同时右子树中可能仍然存在键值 v 的节点,所以要递归下去分裂并把递归返回的 tree1 作为当前节点的右子树。

排名分裂:

对于一个参数 rank,把 Treap 分裂成三部分:tree1 满足所有节点排名 <ranktree3 满足所有节点排名 >rank,剩下的部分放在 tree2

具体实现,从 Treap 根节点开始,若当前节点左子树 szrank,则当前节点和右子树都属于 tree3,同时左子树中可能仍然存在属于 tree3 的节点,所以要递归下去分裂并把递归返回的 tree3 作为当前节点的左子树。若左子树 sz+cnt[p]rank,则当前节点左子树就是 tree1,当前节点就是 tree2,当前节点右子树就是 tree3

实际分裂中,当然可能出现不存在满足条件的 tree,对应部分返回空节点即可。注意分裂时要修改分裂后,节点的父子关系。

合并:

合并接受两个参数:root1,root2,要求 root1 为根的 tree1 中所有节点的键值 root2 为根的 tree2 中的所有节点的键值。合并时,要维护 heap 的性质,所以不能随意地把 root1,root2 中的一个作为根。若 Val1>Val2root1 作为根,把 root1 的右子树和 tree2 递归合并,root1 的左子树保持不变。若 Val1<Val2root2 作为根,把 root2 的左子树和 tree1 递归合并,root2 的右子树保持不变。因为 Treap 节点的 Val 是随机生成的,所以合并的过程是随机的,此方法能保证时间复杂度是 O(logn) 的。

Fhq-Treap 要保证操作结束后还是一整棵树,也就是每一次通过分裂操作实现别的操作后,都要通过合并操作把树合并回去。

围绕 Fhq-Treap 的分裂、合并操作,增删改查操作和朴素 BST 有所不同,存在更加简便的做法:本质是将要操作元素单独分裂出来,然后操作即可。

但是对于查询操作也存在一个弊端,通过分裂和合并来实现查询,每一次操作调用分裂和合并,常数会较大,但是树的形态没有变化,而 Fhq-Treap 保证了树高始终 O(logn),所以可以采用和朴素 BST 一致的方式进行查询,常数较小。

cpp
 struct node {
    int ls, rs, key, val, cnt, sz;
};
struct fhq_treap {
    node tr[N];
    int idx, root;
    int make_node(int key) {
        tr[++idx] = {0, 0, key, -rand(), 1, 1};
        return idx;
    }
    void push_up(int p) { tr[p].sz = tr[ls(p)].sz + tr[rs(p)].sz + tr[p].cnt; }
    pair<int, int> split_by_key(int p, int key) {
        if (!p)
            return {0, 0};
        if (tr[p].key <= key) {
            auto temp = split_by_key(rs(p), key);
            rs(p) = temp.first;
            push_up(p);
            return {p, temp.second};
        } else {
            auto temp = split_by_key(ls(p), key);
            ls(p) = temp.second;
            push_up(p);
            return {temp.first, p};
        }
    }
    tuple<int, int, int> split_by_rank(int p, int rank) {
        if (!p)
            return {0, 0, 0};
        if (rank <= tr[ls(p)].sz) {
            int l, mid, r;
            tie(l, mid, r) = split_by_rank(ls(p), rank);
            ls(p) = r;
            push_up(p);
            return {l, mid, p};
        } else if (rank <= tr[ls(p)].sz + tr[p].cnt) {
            int l = ls(p), r = rs(p);
            ls(p) = rs(p) = 0;
            return {l, p, r};
        } else {
            int l, mid, r;
            tie(l, mid, r) =
                split_by_rank(rs(p), rank - tr[ls(p)].sz - tr[p].cnt);
            rs(p) = l;
            push_up(p);
            return {p, mid, r};
        }
    }
    int merge(int u, int v) {
        if (!u && !v)
            return 0;
        if (u && !v)
            return u;
        if (!u && v)
            return v;
        if (tr[u].val < tr[v].val) {
            rs(u) = merge(rs(u), v);
            push_up(u);
            return u;
        } else {
            ls(v) = merge(u, ls(v));
            push_up(v);
            return v;
        }
    }
    void insert(int key) {
        auto temp = split_by_key(root, key);
        auto l = split_by_key(temp.first, key - 1);
        int now;
        if (!l.second) {
            now = make_node(key);
        } else {
            tr[l.second].cnt++;
            push_up(l.second);
        }
        int L = merge(l.first, !l.second ? now : l.second);
        root = merge(L, temp.second);
    }
    void remove(int key) {
        auto temp = split_by_key(root, key);
        auto l = split_by_key(temp.first, key - 1);
        if (tr[l.second].cnt > 1) {
            tr[l.second].cnt--;
            push_up(l.second);
            l.first = merge(l.first, l.second);
        }
        root = merge(l.first, temp.second);
    }
    int get_rank_by_key(int p, int key) {
        auto temp = split_by_key(p, key - 1);
        int res = tr[temp.first].sz + 1;
        root = merge(temp.first, temp.second);
        return res;
    }
    int get_key_by_rank(int p, int rank) {
        int l, mid, r;
        tie(l, mid, r) = split_by_rank(p, rank);
        int res = tr[mid].key;
        root = merge(merge(l, mid), r);
        return res;
    }
    int get_pre(int key) {
        auto temp = split_by_key(root, key - 1);
        int res = get_key_by_rank(temp.first, tr[temp.first].sz);
        root = merge(temp.first, temp.second);
        return res;
    }
    int get_suf(int key) {
        auto temp = split_by_key(root, key);
        int res = get_key_by_rank(temp.second, 1);
        root = merge(temp.first, temp.second);
        return res;
    }
} tree;

Splay

伸展:

Splay 通过伸展操作,不断将某个节点旋转到根节点,即任意操作后得到的节点,都要转到根。能够在均摊 O(logn) 的时间内完成增删改查。

因为 Splay 的伸展操作,需要考虑的情况过于繁多(主要是多,单独并不难考虑),所以为了简化问题,本文将略过这个具体过程,将其视作一个伸展操作的封装即可。

具体而言,splay(x,y) 即表示把 x 旋转成 y 的儿子。要求 yx 的祖先,否则不会执行。因为一般情况下,根节点没有父节点,而按照 splay(x,y) 的定义,如果想把 x 转到根,根不能没有父亲,所以 Splay 一般特殊定义根的父亲为 0

Splay 上的增删改查都基于 splay(x,y) 实现。具体而言,先目标元素的前驱转到根,再把目标元素的后继转到前驱的右儿子,此时,目标的位置就是后继的左儿子。这里和 Fhq-Treap 相同,都是将目标元素表示成一棵子树。

cpp
struct node {
    int s[2];
    int key, fa, cnt, sz;
};
struct Splay {
    node tr[N];
    int root, idx;
    void push_up(int p) { tr[p].sz = tr[ls(p)].sz + tr[rs(p)].sz + tr[p].cnt; }
    void rotate(int x) {
        int y = tr[x].fa, z = tr[y].fa;
        // push_down(y); 如果需要 push_down 的话
        // push_down(x);
        bool w = (rs(y) == x);
        tr[z].s[rs(z) == y] = x, tr[x].fa = z;
        tr[y].s[w] = tr[x].s[w ^ 1], tr[tr[x].s[w ^ 1]].fa = y;
        tr[x].s[w ^ 1] = y, tr[y].fa = x;
        push_up(y), push_up(x);
    }
    void splay(int x, int k) {
        while (tr[x].fa != k) {
            int y = tr[x].fa, z = tr[y].fa;
            if (z != k) {
                if ((rs(y) == x) ^ (rs(z) == y))
                    rotate(x);
                else
                    rotate(y);
            }
            rotate(x);
        }
        if (!k)
            root = x;
    }
    int make_node(int key, int fa) {
        tr[++idx] = {0, 0, key, fa, 1, 1};
        return idx;
    }
    void init() {
        for (int i = 1; i <= idx; i++)
            tr[i] = {0, 0, 0, 0, 0, 0};
        idx = root = 0;
        tr[++idx] = {0, 2, -inf, 0, 1, 2};
        tr[++idx] = {0, 0, inf, 1, 1, 1};
        root = 1;
    }
    int get_pre(int key, int y = 0) {
        int x = root, res = 0;
        while (x) {
            if (tr[x].key < key) {
                if (!res || tr[res].key < tr[x].key)
                    res = x;
                x = rs(x);
            } else {
                x = ls(x);
            }
        }
        // 因为初始化插入了 -inf 所以前驱一定存在
        splay(res, y);
        return res;
    }
    int get_suf(int key, int y = 0) {
        int x = root, res = 0;
        while (x) {
            if (tr[x].key > key) {
                if (!res || tr[res].key > tr[x].key)
                    res = x;
                x = ls(x);
            } else {
                x = rs(x);
            }
        }
        // 因为初始化插入了 inf 所以后继一定存在
        splay(res, y);
        return res;
    }
    void insert(int key) {
        auto pre = get_pre(key);
        auto suf = get_suf(key, pre);
        auto &now = ls(suf);
        if (now) {
            tr[now].cnt++;
        } else {
            now = make_node(key, suf);
        }
        splay(now, 0);
    }
    void remove(int key) {
        auto pre = get_pre(key);
        auto suf = get_suf(key, pre);
        auto &now = ls(suf);
        tr[now].cnt--;
        if (!tr[now].cnt) {
            now = 0;
        } else {
            splay(now, 0);
        }
    }
    int get_rank_by_key(int key) {
        // 统计比 key 小的数量,注意 -inf
        get_pre(key);
        return tr[root].cnt + tr[ls(root)].sz;
    }
    int get_key_by_rank(int rank) {
        int x = root;
        rank++; // 需要加上 -inf 的贡献
        while (x) {
            if (tr[ls(x)].sz >= rank) {
                x = ls(x);
            } else if (tr[ls(x)].sz + tr[x].cnt >= rank) {
                splay(x, 0);
                return tr[root].key;
            } else {
                rank -= tr[ls(x)].sz + tr[x].cnt;
                x = rs(x);
            }
        }
        // 没有做存在性检查,上述代码在保证有解的前提下
    }
} tree;

Splay 采用迭代实现,一方面是为了方便在操作后进行 splay 操作,另一方面是抵消常数(实际这部分影响比较小)。

替罪羊树

替罪羊树通过引入一个平衡因子 α,表示当子节点的子树大小超过当前节点的子树大小 ×α 时将子节点的子树重构的方式,保证树高始终在 O(logn)。一般 α 设为 0.70.8

重构:

重构分为两个步骤:按中序遍历展开成序列;二分建树。返回重构后的根节点。

插入:

插入部分和朴素 BST 一致,区别在于递归返回时要判断子节点的子树是否需要重构。

删除:

因为替罪羊树没有随意修改树形态的操作(重构要求子节点子树大小),所以不能采用一般的删除手段。可以采用惰性删除,cnt0 表示这个点已被删除。因为点不会被删除多次,所以 cnt 不为负。

前驱/后继:

因为替罪羊树采用惰性删除,所以查询前驱/后继时,经过的键值不能直接递归。在朴素 BST 中查询 v 的前驱时,若当前的键值 <v,则递归右子树查询。因为朴素 BST 中节点上的键值是一定存在的,所以可以向更大的右节点递归。但是在替罪羊树中,当前节点的键值可能不存在,此时不能向右递归,因为可能左子树中还可能有前驱。因此最坏情况下需要遍历整棵 BST。时间复杂度:O(n)

如果要直接递归求前驱/后继,为每个点再维护一个子树 min,max 即可。

另一个简单的实现是,rank 和 K-th 不受惰性删除影响,所以,可以通过 rank 和 K-th 查询前驱/后继。

实践中,虽然通过 rank 和 K-th 查询前驱/后继需要操作两次,但是因为不需要维护 min,max,所以常数差不多,可能前者还更快一点。

cpp
struct node {
    int ls, rs, key, cnt, sz, s;
    int maxn, minn;
};
struct Tzy_tree {
    node tr[N];
    int root, idx;
    int sec[N];
    double alpha = 0.7;
    bool need_rebuild(int p) {
        return alpha * tr[p].s <= (double)max(tr[ls(p)].s, tr[rs(p)].s);
    }

    void push_up(int p) {
        tr[p].s = tr[ls(p)].s + tr[rs(p)].s + 1;
        tr[p].sz = tr[ls(p)].sz + tr[rs(p)].sz + tr[p].cnt;
        if (!tr[p].cnt) {
            tr[p].maxn = -inf;
            tr[p].minn = inf;
        } else {
            tr[p].maxn = tr[p].minn = tr[p].key;
        }

        tr[p].maxn = max({tr[p].maxn, tr[ls(p)].maxn, tr[rs(p)].maxn});
        tr[p].minn = min({tr[p].minn, tr[ls(p)].minn, tr[rs(p)].minn});
    }
    void Flatten(int &id, int p) {
        if (!p)
            return;
        Flatten(id, ls(p));
        if (tr[p].cnt)
            sec[++id] = p;
        Flatten(id, rs(p));
    }
    int Rebuild(int l, int r) {
        if (l > r)
            return 0;
        int mid = l + r >> 1;
        ls(sec[mid]) = Rebuild(l, mid - 1);
        rs(sec[mid]) = Rebuild(mid + 1, r);
        push_up(sec[mid]);
        return sec[mid];
    }
    void Re(int &p) {
        int id = 0;
        Flatten(id, p);
        p = Rebuild(1, id);
    }
    int make_node(int key) {
        tr[++idx] = {0, 0, key, 1, 1, 1, key, key};
        return idx;
    }
    void init() {
        for (int i = 0; i <= idx; i++)
            tr[i] = {0, 0, 0, 0, 0, 0, -inf, inf};
        idx = root = 0;
        tr[++idx] = {0, 2, -inf, 1, 2, 2, inf, -inf};
        tr[++idx] = {0, 0, inf, 1, 1, 1, inf, inf};
        root = 1;
    }
    void insert(int &p, int key) {
        if (!p) {
            p = make_node(key);
            return;
        }
        if (key < tr[p].key) {
            insert(ls(p), key);
        } else if (key == tr[p].key) {
            tr[p].cnt++;
        } else {
            insert(rs(p), key);
        }
        push_up(p);
        if (need_rebuild(p))
            Re(p);
    }
    void remove(int p, int key) {
        // if (!p) return ;
        if (key < tr[p].key) {
            remove(ls(p), key);
        } else if (key == tr[p].key) {
            tr[p].cnt--;
        } else {
            remove(rs(p), key);
        }
        push_up(p);
    }
    int get_rank_by_key(int p, int key) {
        if (!p)
            return 0;
        if (tr[p].key < key)
            return tr[p].cnt + tr[ls(p)].sz + get_rank_by_key(rs(p), key);
        return get_rank_by_key(ls(p), key);
    }
    int get_key_by_rank(int p, int rank) {
        if (!p)
            return 0;
        if (tr[ls(p)].sz >= rank)
            return get_key_by_rank(ls(p), rank);
        if (tr[ls(p)].sz + tr[p].cnt >= rank)
            return tr[p].key;
        return get_key_by_rank(rs(p), rank - tr[ls(p)].sz - tr[p].cnt);
    }
    int get_pre(int key) {
        int x = root, res = -inf;
        while (x) {
            if (tr[x].key < key) {
                if (tr[x].cnt) {
                    res = max(res, tr[x].key);
                } else
                    res = max(res, tr[ls(x)].maxn);
                x = rs(x);
            } else {
                x = ls(x);
            }
        }
        return res;
    }
    int get_suf(int key) {
        int x = root, res = inf;
        while (x) {
            if (tr[x].key > key) {
                if (tr[x].cnt) {
                    res = min(res, tr[x].key);
                } else
                    res = min(res, tr[rs(x)].minn);
                x = ls(x);
            } else {
                x = rs(x);
            }
        }
        return res;
    }
} tree;

序列平衡树

朴素情况下,平衡树用于维护值域,平衡树的中序遍历即为集合中的元素升序后的结果。特殊地,将序列下标视作键值,使用平衡树维护,那么平衡树的中序遍历即为原序列。

原则上所有平衡树都能维护序列,但是由于部分操作和部分平衡树的特殊性,导致部分平衡树在这方面更强大。

使用平衡树维护序列,一般是动态维护序列(在序列中插入元素)否则平衡树建树后形态不发生变化,直接在一开始使用二分建树即可。

区间加:

在朴素 BST 上,将 [l,r,+v] 转换成 [1,l1,v][1,r,+v],那么只用标记出 x{l1,r} 的前驱及其路径上 x 的所有位置即可。具体而言,若 Keypx,那么左子树都满足 Keyqx,左子树打上一个区间加的懒标记即可。

时间复杂度同前驱:O(logn)

区间翻转:

和区间加不同,区间翻转不能满足可差分性,不能采用和区间加一样的方式维护。

但是 Fhq-Treap 和 Splay 可以做到把一个区间表示成一棵子树,那么就可以使用懒标记维护。

区间移动:

和区间翻转类似,将一个区间在平衡树上用一个子树表示,然后拼接子树实现区间移动。

区间插入:

根据输入内容,在某两个位置之间插入一段序列,若能预先计算最终的位置,直接按 Key 插入即可。若不能预先计算最终位置,按 rank 找到前驱、后继判断插入哪一个儿子即可。

时间复杂度同前驱:O(klogn),其中 k 是输入序列长度。

根据区间在 Fhq-Treap 和 Splay 上的子树表示,可以先 O(k) 建树,然后直接拼接即可,时间复杂度:O(k+logn)

Fhq-Treap O(n) 建树

  • 二分建树时,构造 Val 以满足 heap 性质。
  • 二分建树时,随机 Val 但是不维护 heap 性质,在后续 merge 时维护即可。
  • Fhq-Treap 是 treap,仿照笛卡尔树的 O(n) 方式建树即可。

Splay O(n) 建树

直接二分建树即可。

比较

种类时间常数区间翻转持久化
Treap×
Fhq-Treap
Splay最大×
替罪羊树最小××

WBLT 时间常数可能还要更小一些,再说吧。

一般情况下大概是这样,但是如果出题人数据造的不够优秀,那么会出现“利好”某些数据结构的情况,没办法,自认倒霉/喷出题人吧。