【题目来源】
https://www.luogu.com.cn/problem/P6136
【算法分析】
Splay 树简介及代码模板:
https://blog.csdn.net/hnjzsyjyj/article/details/138504578
【代码一:含 pushdown() 函数版本】
● 本代码为洛谷 P6136 代码。题目来源为:https://www.luogu.com.cn/problem/P6136
● 洛谷 P6136 的代码一可作为 Splay 树的代码模板,它与洛谷 P3391 代码的差异仅仅在函数 get_val_by_pri() 的定义不同,其他自定义函数 pushup()、pushdown()、rotate()、splay()、insert()、find()、get_pre()、get_suc()、remove()、get_pri_by_val() 完全相同。
● 代码一中 inf 的值按题设定义为 (1<<30)+5,其对应的十进制数为 1073741824。切忌不要定义为传统的 0x3f3f3f3f,因为其对应的十进制数为 7717637477,远超 1073741824。否则,会产生 TLE 和 WA 错误。
#include <bits/stdc++.h>
using namespace std;
const int maxn=1.1e6+5;
const int inf=(1<<30)+5;
int n,m;
int root,idx;
struct Node {
int s[2],v,p; //subtree,val,root
int size,cnt;
int lazy;
} tr[maxn];
void pushup(int x) {
tr[x].size=tr[tr[x].s[0]].size+tr[tr[x].s[1]].size+tr[x].cnt;
}
void pushdown(int x) {
if(tr[x].lazy) {
swap(tr[x].s[0],tr[x].s[1]);
tr[tr[x].s[0]].lazy^=1;
tr[tr[x].s[1]].lazy^=1;
tr[x].lazy=0;
}
}
void rotate(int x) {
int y=tr[x].p;
int z=tr[y].p;
int k=(tr[y].s[1]==x);
tr[z].s[tr[z].s[1]==y]=x, tr[x].p=z;
tr[y].s[k]=tr[x].s[k^1], tr[tr[x].s[k^1]].p=y;
tr[x].s[k^1]=y, tr[y].p=x;
pushup(y), pushup(x);
}
void splay(int x,int k) {
while(tr[x].p!=k) {
int y=tr[x].p;
int z=tr[y].p;
if(z!=k) {
if((tr[y].s[0]==x)^(tr[z].s[0]==y)) rotate(x);
else rotate(y);
}
rotate(x);
}
if(!k) root=x;
}
void insert(int x) {
int u=root, p=0;
while(u && tr[u].v!=x) {
p=u;
u=tr[u].s[x>tr[u].v];
}
if(u) tr[u].cnt++;
else {
u=++idx;
if(p) tr[p].s[x>tr[p].v]=u;
tr[u].p=p, tr[u].v=x, tr[u].size=1;
tr[u].cnt=1;
}
splay(u,0);
}
void find(int x) {
int u=root;
while(tr[u].s[x>tr[u].v] && tr[u].v!=x) u=tr[u].s[x>tr[u].v];
splay(u,0);
}
int get_pre(int x) {
find(x);
if(tr[root].v<x) return root;
int u=tr[root].s[0];
while(tr[u].s[1]) u=tr[u].s[1];
splay(u,0);
return u;
}
int get_suc(int x) {
find(x);
if(tr[root].v>x) return root;
int u=tr[root].s[1];
while(tr[u].s[0]) u=tr[u].s[0];
splay(u,0);
return u;
}
void remove(int x) {
int pre=get_pre(x), suc=get_suc(x);
splay(pre,0), splay(suc,pre);
int del=tr[suc].s[0];
if(tr[del].cnt>1) tr[del].cnt--, splay(del,0);
else tr[suc].s[0]=0, splay(suc,0);
}
int get_pri_by_val(int x) {
insert(x);
int ans=tr[tr[root].s[0]].size;
remove(x);
return ans;
}
int get_val_by_pri(int x) { //apply to P6136
int u=root;
while(true) {
if(x<=tr[tr[u].s[0]].size) u=tr[u].s[0];
else if(x<=tr[tr[u].s[0]].size+tr[u].cnt) break;
else x-=tr[tr[u].s[0]].size+tr[u].cnt, u=tr[u].s[1];
}
splay(u,0);
return tr[u].v;
}
/*int get_val_by_pri(int x) { //apply to P3391
int u=root;
while(true) {
pushdown(u);
if(x<=tr[tr[u].s[0]].size) u=tr[u].s[0];
else if(x==tr[tr[u].s[0]].size+1) return u;
else x-=tr[tr[u].s[0]].size+1, u=tr[u].s[1];
}
return -1;
}*/
int main() {
insert(-inf);
insert(inf);
int n,m;
cin>>n>>m;
for(int i=1; i<=n; i++) {
int x;
cin>>x;
insert(x);
}
int ans=0, last=0;
while(m--) {
int op,x;
cin>>op>>x;
x^=last;
if(op==1) insert(x);
else if(op==2) remove(x);
else if(op==3) ans^=(last=get_pri_by_val(x));
else if(op==4) ans^=(last=get_val_by_pri(x+1));
else if(op==5) ans^=(last=tr[get_pre(x)].v);
else ans^=(last=tr[get_suc(x)].v);
}
cout<<ans<<endl;
return 0;
}
/*
in:
6 7
1 1 4 5 1 4
2 1
1 9
4 1
5 8
3 13
6 7
1 4
out:
6
*/
【代码二:不含 pushdown() 函数版本】
#include<bits/stdc++.h>
using namespace std;
const int maxn=1.1e6+5;
const int inf=(1<<30)+5;
int n,m;
int root,idx;
struct Node {
int s[2],v,p; //subtree,val,root
int size,cnt;
int lazy;
} tr[maxn];
void pushup(int u) {
tr[u].size=tr[tr[u].s[0]].size+tr[tr[u].s[1]].size+tr[u].cnt;
}
void rotate(int x) {
int y=tr[x].p;
int z=tr[y].p;
int k=(tr[y].s[1]==x);
tr[y].s[k]=tr[x].s[k^1],tr[tr[x].s[k^1]].p=y;
tr[x].s[k^1]=y,tr[y].p=x;
tr[z].s[tr[z].s[1]==y]=x,tr[x].p=z;
pushup(y);
pushup(x);
}
void splay(int x, int k) {
while(tr[x].p!=k) {
int y=tr[x].p;
int z=tr[y].p;
if(z!=k) {
if((tr[z].s[1]==y) ^ (tr[y].s[1]==x)) rotate(x);
else rotate(y);
}
rotate(x);
}
if(!k) root=x;
}
void insert(int x) {
int u=root, p=0;
while(u && tr[u].v!=x) {
p=u;
u=tr[u].s[x>tr[u].v];
}
if(u) tr[u].cnt++;
else {
u=++idx;
if(p) tr[p].s[x>tr[p].v]=u;
tr[u].p=p,tr[u].v=x,tr[u].size=1;
tr[u].cnt=1;
}
splay(u,0);
}
void find(int x) {
int u=root;
while(tr[u].s[x>tr[u].v] && tr[u].v!=x) u=tr[u].s[x>tr[u].v];
splay(u,0);
}
int get_pre(int x) {
find(x);
if(tr[root].v<x) return root;
int u=tr[root].s[0];
while(tr[u].s[1]) u=tr[u].s[1];
splay(u,0);
return u;
}
int get_suc(int x) {
find(x);
if(tr[root].v>x) return root;
int u=tr[root].s[1];
while(tr[u].s[0]) u=tr[u].s[0];
splay(u,0);
return u;
}
void remove(int x) {
int pre=get_pre(x);
int suc=get_suc(x);
splay(pre,0);
splay(suc,pre);
int del=tr[suc].s[0];
if(tr[del].cnt>1) tr[del].cnt--, splay(del,0);
else tr[suc].s[0]=0, splay(suc,0);
}
int get_pri_by_val(int x) {
insert(x);
int ans=tr[tr[root].s[0]].size;
remove(x);
return ans;
}
int get_val_by_pri(int k) {
int u=root;
while(true) {
if(k<=tr[tr[u].s[0]].size) u=tr[u].s[0];
else if(k<=tr[tr[u].s[0]].size+tr[u].cnt) break;
else k-=tr[tr[u].s[0]].size+tr[u].cnt, u=tr[u].s[1];
}
splay(u,0);
return tr[u].v;
}
int main() {
insert(-inf);
insert(inf);
int n,m;
cin>>n>>m;
for(int i=1; i<=n; i++) {
int x;
cin>>x;
insert(x);
}
int ans=0, last=0;
while(m--) {
int op,x;
cin>>op>>x;
x^=last;
if(op==1) insert(x);
else if(op==2) remove(x);
else if(op==3) ans^=(last=get_pri_by_val(x));
else if(op==4) ans^=(last=get_val_by_pri(x+1));
else if(op==5) ans^=(last=tr[get_pre(x)].v);
else ans^=(last=tr[get_suc(x)].v);
}
cout<<ans<<endl;
return 0;
}
/*
in:
6 7
1 1 4 5 1 4
2 1
1 9
4 1
5 8
3 13
6 7
1 4
out:
6
*/
【参考文献】
https://blog.csdn.net/hnjzsyjyj/article/details/138504578
https://blog.csdn.net/hnjzsyjyj/article/details/138522947