定义

template < typename Key >struct Splay ;

Splay 是关联容器,含有 key 类型对象的已排序集。用比较函数less<Key>进行排序。搜索、移除和插入拥有对数复杂度。 Splay 以Splay树实现。

成员类型

成员类型定义
value_typeKey

成员

分类或标识符描述
常规
构造函数构造 Splay (公开成员函数)
析构函数析构 Splay (公开成员函数)
容量
empty检查是否为空 (公开成员函数)
size返回容纳的元素数 (公开成员函数)
修改器
clear清除内容(公开成员函数)
ins插入结点(公开成员函数)
del擦除元素(公开成员函数)
maintain在改变节点位置后,将节点的子树更新(私有成员函数)
rotate将节点上移一个位置(公开成员函数)
splay将节点旋转至根节点(公开成员函数)
查找
get判断节点是父亲的左儿子或右儿子。
rk查询节点的排名 (公开成员函数)
kth查询排名的值 (公开成员函数)
pre查询前驱 (公开成员函数)
nxt查询后继 (公开成员函数)

可能的实现

template < typename Key >
struct Splay {
    int rt, tot, fa[MAXN], ch[MAXN][2], cnt[MAXN], sz[MAXN];
    Key val[MAXN];
    void maintain( int x ) {
        sz[x] = sz[ch[x][0]] + sz[ch[x][1]] + cnt[x];
    }
    bool get( int x ) {
        return x == ch[fa[x]][1];
    }
    void clear( int x ) {
        ch[x][0] = ch[x][1] = fa[x] = val[x] = sz[x] = cnt[x] = 0;
    }
    void rotate( int x ) {
        int y = fa[x], z = fa[y], chk = get( x );
        ch[y][chk] = ch[x][chk ^ 1];
        fa[ch[x][chk ^ 1]] = y;
        ch[x][chk ^ 1] = y;
        fa[y] = x;
        fa[x] = z;
        if ( z )
            ch[z][y == ch[z][1]] = x;
        maintain( x );
        maintain( y );
    }
    void splay( int x ) {
        for ( int f = fa[x]; f = fa[x], f; rotate( x ) )
            if ( fa[f] )
                rotate( get( x ) == get( f ) ? f : x );
        rt = x;
    }
    void ins( int k ) {
        if ( !rt ) {
            val[++tot] = k;
            cnt[tot]++;
            rt = tot;
            maintain( rt );
            return;
        }
        int cnr = rt, f = 0;
        while ( true ) {
            if ( val[cnr] == k ) {
                cnt[cnr]++;
                maintain( cnr );
                maintain( f );
                splay( cnr );
                break;
            }
            f = cnr;
            cnr = ch[cnr][val[cnr] < k];
            if ( !cnr ) {
                val[++tot] = k;
                cnt[tot]++;
                fa[tot] = f;
                ch[f][val[f] < k] = tot;
                maintain( tot );
                maintain( f );
                splay( tot );
                break;
            }
        }
    }
    int rk( int k ) {
        int res = 0, cnr = rt;
        while ( 1 ) {
            if ( k < val[cnr] ) {
                cnr = ch[cnr][0];
            }
            else {
                res += sz[ch[cnr][0]];
                if ( k == val[cnr] ) {
                    splay( cnr );
                    return res + 1;
                }
                res += cnt[cnr];
                cnr = ch[cnr][1];
            }
        }
    }
    int kth( int k ) {
        int cnr = rt;
        while ( 1 ) {
            if ( ch[cnr][0] && k <= sz[ch[cnr][0]] ) {
                cnr = ch[cnr][0];
            }
            else {
                k -= cnt[cnr] + sz[ch[cnr][0]];
                if ( k <= 0 )
                    return val[cnr];
                cnr = ch[cnr][1];
            }
        }
    }
    int pre() {
        int cnr = ch[rt][0];
        while ( ch[cnr][1] )
            cnr = ch[cnr][1];
        return cnr;
    }
    int nxt() {
        int cnr = ch[rt][1];
        while ( ch[cnr][0] )
            cnr = ch[cnr][0];
        return cnr;
    }
    void del( int k ) {
        rk( k );
        if ( cnt[rt] > 1 ) {
            cnt[rt]--;
            maintain( rt );
            return;
        }
        if ( !ch[rt][0] && !ch[rt][1] ) {
            clear( rt );
            rt = 0;
            return;
        }
        if ( !ch[rt][0] ) {
            int cnr = rt;
            rt = ch[rt][1];
            fa[rt] = 0;
            clear( cnr );
            return;
        }
        if ( !ch[rt][1] ) {
            int cnr = rt;
            rt = ch[rt][0];
            fa[rt] = 0;
            clear( cnr );
            return;
        }
        int x = pre(), cnr = rt;
        splay( x );
        fa[ch[cnr][1]] = x;
        ch[x][1] = ch[cnr][1];
        clear( cnr );
        maintain( rt );
    }
    int pre( int x ) {
        int ans;
        ins( x );
        ans = val[pre()];
        del( x );
        return ans;
    }
    int nxt( int x ) {
        int ans;
        ins( x );
        ans = val[nxt()];
        del( x );
        return ans;
    }
    bool empty() {
        return tot == 0;
    }
    int size() {
        return tot;
    }
};
Last modification:December 14, 2019
如果您觉得我的文章有用,给颗糖糖吧~