文章目录
- 前言
- 二叉搜索树
- 代码
- treap
- 代码
- splay
- 开点
- 旋转
- splay
- 插入
- 查找第k大元素
- 查找给定元素的排名
- 前驱&后继
- 删除
- 完整代码
- 练习总结
前言
终于开始学这个东西了
看了好几篇博客才找到一篇可读的qwq
我曾经还以为线段树码量大…我真傻,真的
所谓平衡树,就是把二叉搜索树加了一个随机权值
并通过旋转使这个权值始终符合堆的性质
(treap=tree+heap)
我觉得平衡树主要的功能就是维护排名相关的东西
(update:更正观点!平衡树最好用的地方还是区间问题,排名问题在序列上可以主席树,动态的可以树状数组,为啥要写splay这种东西…)
前驱后继这些其实都可以直接拿set偷懒
(当然本刚学treap1h的蒟蒻的理解完全没有参考价值)
一开始WA成了60分qwq
千万注意一定不要落掉无处不在的pushup!
二叉搜索树
不学BST,何以treap? ——鲁迅
二叉搜索树是一种二叉树的树形数据结构,其定义如下:
-
空树是二叉搜索树。
-
若二叉搜索树的左子树不为空,则其左子树上所有点的附加权值均小于其根节点的值。
-
若二叉搜索树的右子树不为空,则其右子树上所有点的附加权值均大于其根节点的值。
-
二叉搜索树的左右子树均为二叉搜索树。
二叉搜索树上的基本操作所花费的时间与这棵树的高度成正比。对于一个有 n个结点的二叉搜索树中,这些操作的最优时间复杂度为 Ologn,最坏为On。随机构造这样一棵二叉搜索树的期望高度为logn。
代码
#include<bits/stdc++.h>
using namespace std;
const int N=1e5+100;
#define ll long long
int n,m,k;
int x,y;
int cnt[N],ls[N],rs[N],val[N],siz[N],tot=1;
void insert(int &o,int v){//插入元素if(!o){o=++tot;val[o]=v;ls[o]=rs[o]=0;siz[o]=cnt[o]=1;}siz[o]++;if(val[o]==v) {cnt[o]++;return;}if(v<val[o]) insert(ls[o],v);else insert(rs[o],v);
}
int delmin(int &o){if(!ls[o]){int u=o;o=rs[o];return u;}else{int u=delmin(ls[o]);siz[o]-=cnt[u];return u;}
}
void del(int &o,int v){//删除元素siz[o]--;if(val[o]==v){if(cnt[o]>1) cnt[o]--;else if(ls[o]&&rs[o]) o=delmin(rs[o]);else o=ls[o]+rs[o];return;}if(v<val[o]) del(ls[o],v);else del(rs[o],v);
}
int askrank(int o,int v){//查询x的排名if(val[o]==v) return siz[ls[o]]+1;else if(val[o]>v) return askrank(ls[o],v);else return siz[ls[o]]+cnt[o]+askrank(rs[o],v);
}
int asknth(int o,int k){//查询第k大的元素if(siz[ls[o]]>=k) return asknth(ls[o],k);else if(siz[ls[o]]+cnt[o]>=k) return val[o];else return asknth(rs[o],k-(siz[ls[o]]+cnt[o]));
}
int main(){scanf("%d",&n);int flag;val[1]=-2e9;int r=1;for(int i=1;i<=n;i++){scanf("%d%d",&flag,&x);if(flag==1) insert(r,x);else if(flag==2) del(r,x);else if(flag==3) printf("%d\n",askrank(1,x));else if(flag==4) printf("%d\n",asknth(1,x));}return 0;
}
/**/
treap
旋转是平衡树的灵魂
一个很重要的技巧是利用0/1存储左右儿子
这样在旋转的时候写起来会容易很多
代码
#include<bits/stdc++.h>
using namespace std;
const int N=1e6+2e5+100;
#define ll long long
int n,m,k;
int x,y;
int cnt[N],ch[N][2],val[N],siz[N],tot,r,dat[N];
int New(int v){val[++tot]=v;dat[tot]=rand();ch[tot][0]=ch[tot][1]=0;siz[tot]=cnt[tot]=1;return tot;
}
void pushup(int o){siz[o]=siz[ch[o][0]]+siz[ch[o][1]]+cnt[o];
}
void build(){r=New(-2e9);ch[1][1]=New(2e9);pushup(r);
}
void rotate(int &o,int d){int temp=ch[o][!d];ch[o][!d]=ch[temp][d];ch[temp][d]=o;o=temp;pushup(o);pushup(ch[o][d]);
}
void insert(int &o,int v){if(!o){o=New(v);return;}if(v==val[o]){cnt[o]++;pushup(o);return;}int d= v>val[o];insert(ch[o][d],v);if(dat[ch[o][d]]>dat[o]) rotate(o,!d);pushup(o);
}
void del(int &o,int v){if(!o) return;if(v==val[o]){if(cnt[o]>1){cnt[o]--;pushup(o);return;}if(ch[o][0]||ch[o][1]){int d=!ch[o][1]||dat[ch[o][1]]<dat[ch[o][0]];rotate(o,d);del(ch[o][d],v);pushup(o);}else o=0;return;}if(v<val[o]) del(ch[o][0],v);else del(ch[o][1],v);pushup(o);
}
int getrank(int o,int v){if(!o) return 1;if(val[o]==v) return siz[ch[o][0]]+1;else if(v<val[o]) return getrank(ch[o][0],v);else return getrank(ch[o][1],v)+siz[ch[o][0]]+cnt[o];
}
int getnth(int o,int k){if(!o) return 2e9;if(siz[ch[o][0]]>=k) return getnth(ch[o][0],k);else if(siz[ch[o][0]]+cnt[o]>=k) return val[o];else return getnth(ch[o][1],k-(siz[ch[o][0]]+cnt[o]));
}
int getpre(int v){int res=-2e9,p=r;while(p){if(val[p]<v){res=val[p];p=ch[p][1];}else p=ch[p][0];}return res;
}
int getnxt(int v){int res=2e9,p=r;while(p){if(val[p]>v){res=val[p];p=ch[p][0];}else p=ch[p][1];}return res;
}
int main(){scanf("%d%d",&n,&m);build();int flag;for(int i=1;i<=n;i++){scanf("%d",&x);insert(r,x);}int ans=0,lst=0;for(int i=1;i<=m;i++){scanf("%d%d",&flag,&x);x^=lst;if(flag==1) insert(r,x);else if(flag==2) del(r,x);else if(flag==3){int res=getrank(r,x)-1;lst=res;ans^=res;}else if(flag==4){int res=getnth(r,x+1);lst=res;ans^=res;}else if(flag==5){int res=getpre(x);lst=res;ans^=res;}else{int res=getnxt(x);lst=res;ans^=res;}}printf("%d\n",ans);return 0;
}
/**/
splay
看好几篇博客说splay在区间问题的功能更强大,所以也学习了splay
最后实在de不出来bug还是动用了减法原理
累死窝了qwq
这个东西真的好难debug
但是决定以后就用它了awa
当然要用强的啦
这个东西好好讲讲
开点
所有的点都是开出来的
开点还是比较正常
int New(int v,int fa){val[++tot]=v;f[tot]=fa;ch[tot][0]=ch[tot][1]=0;siz[tot]=cnt[tot]=1;return tot;
}
旋转
它也是旋转完成的,不能没有它
和treap一样啦
防止写错,总体的改变可以分三对
- x与gfa的父子关系
- fa与x的父子关系
- x的异向儿子与fa的父子关系
void rotate(int x){int fa=f[x],gfa=f[fa];int k=getwhich(x);int temp=ch[x][k^1];f[temp]=fa;ch[fa][k]=temp;f[x]=gfa;if(gfa) ch[gfa][ch[gfa][1]==fa]=x;f[fa]=x;ch[x][k^1]=fa;pushup(x);pushup(fa);
}
splay
splay怎么能不splay呢
所以我们现在讲讲splay的splay部分(停止扯淡)
splay总的来说就是把一个结点不停转转转
一直转到根的地方
沿途长链死光光
从而保证复杂度的正确性
这也是splay的精髓所在
而且跳到根也便利了我们其他的操作
有一个很关键的细节
就是当父亲和自己相对于各自父节点的方向同向时
必须要先转父亲
不然就无法达到消链的目的
这个可以自己画画图理解
(我看别人题解画的天花乱坠,最后还是自己画图才明白的)
代码极为简洁
void splay(int x){for(int fa;fa=f[x];rotate(x)){if(f[fa]) rotate((getwhich(fa)==getwhich(x))?fa:x);}r=x;
}
插入
开始干正事了
找到应该加点的位置开点
然后splay一下
注意pushup!
void insert(int v){if(!r){r=New(v,0);return;}int now=r,fa=0;while(1){if(val[now]==v){cnt[now]++;pushup(now);pushup(fa);splay(now);break;}fa=now;now=ch[now][v>val[now]];if(!now){ch[fa][v>val[fa]]=New(v,fa);pushup(fa);
// printf("\ninsert:(pre)\n");
// print();splay(tot);
// printf("\ninsert:(after)\n");
// print();break;}}
}
查找第k大元素
这个很好写
理解起来应该也不难
int findnth(int k){int now=r;while(1){if(siz[ch[now][0]]>=k) now=ch[now][0];else if(siz[ch[now][0]]+cnt[now]>=k) return val[now];else{k-=siz[ch[now][0]]+cnt[now];now=ch[now][1];}}
}
查找给定元素的排名
这个也没有太大的难度
(尽管我de了一年多之后发现就是这里写挂的)
为了后面删除元素的遍历我们找到这个元素后splay一下
int findrank(int x){int now=r,ans=0;while(1){if(!now) return ans+1;if(val[now]>x) now=ch[now][0];else if(val[now]==x){ans+=siz[ch[now][0]];//记录一下车祸现场splay(now);return ans+1;}else{ans+=siz[ch[now][0]]+cnt[now];now=ch[now][1];}}
}
前驱&后继
这里是找的根的前驱(后继)
找给定值的前驱(后继)的话就先insert进去,它自动splay到根,然后再求就行了
int findpre(){int now=ch[r][0];while(ch[now][1]) now=ch[now][1];return now;
}
int findnxt(){int now=ch[r][1];while(ch[now][0]) now=ch[now][0];return now;
}
删除
这个是重点
首先把删除的元素利用前面现成的findrank提到根上
有副本就直接删
否则看它的儿子情况
啥都没有就直接变空树了
只有一个就把那个儿子当成根
如果两个儿子都有就考虑把根的前驱提上来
因为是前驱,所以它在到x之前一定没有右儿子
也就是这样:
再转一下:
注意到待删元素一定没有左儿子
因此我们可以把B直接接到pre上达到删除的目的
也就是:
这样就ok啦
void del(int v){findrank(v);if(cnt[r]>1) {cnt[r]--;return;}else if(!ch[r][0]&&!ch[r][1]){r=0;return;}else if(!ch[r][0]){int temp=r;r=ch[r][1];f[r]=0;return;}else if(!ch[r][1]){int temp=r;r=ch[r][0];f[r]=0;return;}int temp=ch[r][1],pre=findpre(),oldr=r;splay(pre);ch[r][1]=temp;f[temp]=r;pushup(r);
}
完整代码
#include<bits/stdc++.h>
using namespace std;
const int N=1e5+100;
#define ll long long
int n,m,k;
int x,y;
int cnt[N],ch[N][2],val[N],siz[N],tot,r;
int f[N];
int New(int v,int fa){val[++tot]=v;f[tot]=fa;ch[tot][0]=ch[tot][1]=0;siz[tot]=cnt[tot]=1;return tot;
}
void pushup(int o){if(o) siz[o]=siz[ch[o][0]]+siz[ch[o][1]]+cnt[o];
}
void build(){r=New(-2e9,0);ch[1][1]=New(2e9,r);pushup(r);
}
int getwhich(int x){return ch[f[x]][1]==x;
}
void rotate(int x){int fa=f[x],gfa=f[fa];int k=getwhich(x);int temp=ch[x][k^1];f[temp]=fa;ch[fa][k]=temp;f[x]=gfa;if(gfa) ch[gfa][ch[gfa][1]==fa]=x;f[fa]=x;ch[x][k^1]=fa;pushup(x);pushup(fa);
}
void splay(int x){for(int fa;fa=f[x];rotate(x)){if(f[fa]) rotate((getwhich(fa)==getwhich(x))?fa:x);}r=x;
}
void insert(int v){if(!r){r=New(v,0);return;}int now=r,fa=0;while(1){if(val[now]==v){cnt[now]++;pushup(now);pushup(fa);splay(now);break;}fa=now;now=ch[now][v>val[now]];if(!now){ch[fa][v>val[fa]]=New(v,fa);pushup(fa);
// printf("\ninsert:(pre)\n");
// print();splay(tot);
// printf("\ninsert:(after)\n");
// print();break;}}
}
int findnth(int k){int now=r;while(1){if(siz[ch[now][0]]>=k) now=ch[now][0];else if(siz[ch[now][0]]+cnt[now]>=k) return val[now];else{k-=siz[ch[now][0]]+cnt[now];now=ch[now][1];}}
}
int findrank(int x){int now=r,ans=0;while(1){if(!now) return ans+1;if(val[now]>x) now=ch[now][0];else if(val[now]==x){ans+=siz[ch[now][0]];splay(now);return ans+1;}else{ans+=siz[ch[now][0]]+cnt[now];now=ch[now][1];}}
}
int findpre(){int now=ch[r][0];while(ch[now][1]) now=ch[now][1];return now;
}
int findnxt(){int now=ch[r][1];while(ch[now][0]) now=ch[now][0];return now;
}
void del(int v){findrank(v);if(cnt[r]>1) {cnt[r]--;return;}else if(!ch[r][0]&&!ch[r][1]){r=0;return;}else if(!ch[r][0]){int temp=r;r=ch[r][1];f[r]=0;return;}else if(!ch[r][1]){int temp=r;r=ch[r][0];f[r]=0;return;}int temp=ch[r][1],pre=findpre(),oldr=r;splay(pre);ch[r][1]=temp;f[temp]=r;pushup(r);
}
int main(){scanf("%d",&n);int flag;for(int i=1;i<=n;i++){scanf("%d%d",&flag,&x);if(flag==1) insert(x);else if(flag==2) del(x);else if(flag==3) printf("%d\n",findrank(x));else if(flag==4) printf("%d\n",findnth(x));else if(flag==5){insert(x);printf("%d\n",val[findpre()]);del(x);}else{insert(x);printf("%d\n",val[findnxt()]);del(x);}
// print();}return 0;
}
/**/
练习总结
传送门