主席树学习心得
前置技能
对于一颗线段树,我们可以维护很多东西, 区间和、区间最大值、区间最小值等等。假设现在的线段数维护的是区间[1,N] (N<=1e6) 之间每个数字出现的次数。
假设序列[2,5,1,4,2,5,5,3],对于每前i(0<=i<=n)个数,维护区间[1,N]每个数出现次数。
对于i=0, 每个数字出现0次,即第0颗线段树结点sum全为0
对于i=1, 数字2出现次数+1,其它数字出现次数相较于i=0不变,那么$Seg_{i}$与$Seg_{i-1}$的区别只有一条链,即2所在的叶子结点到根节点,在该链上,$Seg_{i}$与$Seg_{i-1}$的区别只有值不同,并且值只相差一,可用画图体验。
。 。 。
对于i=5, 数字2出现次数+1,其它数字出现次数相较与i=4不变,与i=1同理。
对于i=6,…,n,都与上同理。
结论:每颗$Seg_{i}$与$Seg_{i-1}$只有$log_{N-1}+1$个结点不同,对于这n+1颗线段树,可用$nlog_{n}+nlog_{n}$即$2nlog_{n}$个结点存储,不需要将n颗线段树完全存下来,完全存下来的空间为$n^{2}log_{n}$,当n为10000时,空间会炸,故采用$2nlog_{n}$的空间存下所有的线段树,就是主席树。
区间查询第k小:明白每颗线段树存储内容后,很容易得出区间[L,R]的线段树,每个结点的值就是第R颗线段树结点值减去第L-1颗线段树结点值,于是查询就和一颗线段树一样的操作。
实现
数据结构
int maxn = 1e5; 序列长度。
int maxm = maxn * 40; 所需的空间 $2nlog_{n}$。
int T[maxn]; T[i]表示第i颗线段树的根节点编号。
int R[maxm]; R[i]表示节点i的右儿子编号。
int L[maxm]; L[i]表示节点i的左儿子编号。
int sum[maxm]; sum[i]表示节点i的值,叶子结点的sum[j]对于该叶子结点表示数字的出现次数。
int tot; 结点总数。
成员函数
void build(int &rt, int l, int r); 用于构造第0颗线段树。
void update(int &rt,int pre,int v,int l, int r); 用于构造主席树,每次在前一颗线段树的基础上增加一条链。
int query(int r1, int r2, int k, int l, int r); 用于序列区间[r1, r2]的第k大值, 该值可能被离散化。
具体代码
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62/*
模板来自于quincy.ink
Template of ACM中的主席树。
已验证。
*/
using namespace std;
const int maxn=1e5+10;
const int maxm=maxn*40;
int T[maxn],L[maxm],R[maxm],sum[maxm];
int sz[maxn],h[maxn]; //sz为原序列,h为离散化之后的序列
int n,q,ql,qr,k,tot;
void build(int& rt, int l, int r){
rt = ++tot;
sum[rt] = 0;
if(l==r) return;
int mid = l + r >> 1;
build(L[rt], l, mid);
build(R[rt], mid+1, r);
}
void update(int &rt, int pre, int l, int r, int v){
rt = ++tot;
L[rt] = L[pre];
R[rt] = R[pre];
sum[rt] = sum[pre] + 1;
if(l == r) return ;
int mid = l + r >> 1;
if(v <= mid) update(L[rt], L[pre], l, mid, v);
else update(R[rt], R[pre], mid+1, r, v);
}
int query(int r1, int r2, int k, int l, int r){
if(l == r) return l;
int mid = l + r >> 1;
int kk = sum[L[r2]] - sum[L[r1]];
if(k <= kk) return query(L[r1], L[r2], k, l, mid);
return query(R[r1], R[r2], k-kk, mid+1, r);
}
int main(){
tot = 0;
scanf("%d %d", &n, &q);
for(int i=1;i<=n;++i){
scanf("%d", &sz[i]);
h[i] = sz[i];
}
sort(h+1, h+1+n);
int num = unique(h+1, h+1+n) - (h+1);
build(T[0], 1, num);
for(int i=1;i<=n;++i) update(T[i],T[i-1],1,num,lower_bound(h+1, h+1+num,sz[i])-h);
while(q--){
int ql, qr, k;
scanf("%d %d %d", &ql, &qr, &k);
printf("%d\n", h[query(T[ql-1], T[qr], k, 1, num)]);
}
}