0%

主席树学习

主席树学习心得

前置技能

  • 前缀和

  • 线段树

  • 离散化

​ 对于一颗线段树,我们可以维护很多东西, 区间和、区间最大值、区间最小值等等。假设现在的线段数维护的是区间[1,N] (N<=1e6) 之间每个数字出现的次数。

​ 假设序列[2,5,1,4,2,5,5,3],对于每前i(0<=i<=n)个数,维护区间[1,N]每个数出现次数。

对于i=0, 每个数字出现0次,即第0颗线段树结点sum全为0

对于i=1, 数字2出现次数+1,其它数字出现次数相较于i=0不变,那么$Seg_{i}$与$Seg_{i-1}$的区别只有一条链,即2所在的叶子结点到根节点,在该链上,$Seg_{i}$与$Seg_{i-1}$的区别只有值不同,并且值只相差一,可用画图体验。

。 。 。

对于i=5, 数字2出现次数+1,其它数字出现次数相较与i=4不变,与i=1同理。

对于i=6,…,n,都与上同理。

结论:每颗$Seg_{i}$与$Seg_{i-1}$只有$log_{N-1}+1$个结点不同,对于这n+1颗线段树,可用$nlog_{n}+nlog_{n}$即$2nlog_{n}$个结点存储,不需要将n颗线段树完全存下来,完全存下来的空间为$n^{2}log_{n}$,当n为10000时,空间会炸,故采用$2nlog_{n}$的空间存下所有的线段树,就是主席树。

区间查询第k小:明白每颗线段树存储内容后,很容易得出区间[L,R]的线段树,每个结点的值就是第R颗线段树结点值减去第L-1颗线段树结点值,于是查询就和一颗线段树一样的操作。

实现

  • 数据结构

    int maxn = 1e5; 序列长度。

    int maxm = maxn * 40; 所需的空间 $2nlog_{n}$。

    int T[maxn]; T[i]表示第i颗线段树的根节点编号。

    int R[maxm]; R[i]表示节点i的右儿子编号。

    int L[maxm]; L[i]表示节点i的左儿子编号。

    int sum[maxm]; sum[i]表示节点i的值,叶子结点的sum[j]对于该叶子结点表示数字的出现次数。

    int tot; 结点总数。

  • 成员函数

    void build(int &rt, int l, int r); 用于构造第0颗线段树。

    void update(int &rt,int pre,int v,int l, int r); 用于构造主席树,每次在前一颗线段树的基础上增加一条链。

    int query(int r1, int r2, int k, int l, int r); 用于序列区间[r1, r2]的第k大值, 该值可能被离散化。

  • 具体代码

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    50
    51
    52
    53
    54
    55
    56
    57
    58
    59
    60
    61
    62
    /*
    模板来自于quincy.ink
    Template of ACM中的主席树。
    已验证。
    */

    #include <cstdio>
    #include <algorithm>
    #define ll long long
    #define inf 0x3f3f3f3f
    #define eps 1e-6
    #define pi acos(-1)
    using namespace std;
    const int maxn=1e5+10;
    const int maxm=maxn*40;
    int T[maxn],L[maxm],R[maxm],sum[maxm];
    int sz[maxn],h[maxn]; //sz为原序列,h为离散化之后的序列
    int n,q,ql,qr,k,tot;
    void build(int& rt, int l, int r){
    rt = ++tot;
    sum[rt] = 0;
    if(l==r) return;
    int mid = l + r >> 1;
    build(L[rt], l, mid);
    build(R[rt], mid+1, r);
    }
    void update(int &rt, int pre, int l, int r, int v){
    rt = ++tot;
    L[rt] = L[pre];
    R[rt] = R[pre];
    sum[rt] = sum[pre] + 1;
    if(l == r) return ;
    int mid = l + r >> 1;
    if(v <= mid) update(L[rt], L[pre], l, mid, v);
    else update(R[rt], R[pre], mid+1, r, v);
    }
    int query(int r1, int r2, int k, int l, int r){
    if(l == r) return l;
    int mid = l + r >> 1;
    int kk = sum[L[r2]] - sum[L[r1]];
    if(k <= kk) return query(L[r1], L[r2], k, l, mid);
    return query(R[r1], R[r2], k-kk, mid+1, r);
    }
    int main(){
    tot = 0;
    scanf("%d %d", &n, &q);
    for(int i=1;i<=n;++i){
    scanf("%d", &sz[i]);
    h[i] = sz[i];
    }
    sort(h+1, h+1+n);
    int num = unique(h+1, h+1+n) - (h+1);
    build(T[0], 1, num);
    for(int i=1;i<=n;++i) update(T[i],T[i-1],1,num,lower_bound(h+1, h+1+num,sz[i])-h);

    while(q--){
    int ql, qr, k;
    scanf("%d %d %d", &ql, &qr, &k);
    printf("%d\n", h[query(T[ql-1], T[qr], k, 1, num)]);
    }

    }