splay学习笔记

splay大法好。

颓了很长时间,最近什么都没做。算是迷失了方向,又找到了方向吧。曾对这个世界感到失望,但还是要有希望,有理想和目标。继续加油吧。

遇见了,不后悔遇见就好。

颓废有用的话,也就不会这么难受了。

一定要有目标。不要管目标能不能达成,只要知道自己努力了就好。

到头来回忆起的,是那段艰苦的岁月,而不是最终的瞬时的胜利。

那些努力追求的东西,都是为了某一个初心。不能放弃初心。

人要成长的啊。想想自己要是荒废了时间,那多亏啊。

累的睡着,睡到天亮。太阳照在身上的时候,感觉很幸福。

要坚定这种幸福一直都在。发现世界还是这么美好的时候,自己是没有理由崩溃的。

——壮壮

那么,让我们来看splay吧。

splay是一种平衡树。在权值上满足二叉排序树的性质,即左儿子小于自己,右儿子大于自己。通过旋转保持平衡。不像treap的一个权值还满足堆的性质,这里只有权值的大小。

splay是一种自适应的数据结构。就是说,树的结构会随着操作的不同而变得更易于操作。比如说,如果插入的次数很多,插入就会变得更快。如果修改的速度很多,修改的速度就会变快。听起来非常神奇,可能需要进一步理解。

splay的速度很慢,通过势能分析可以说明,大部分操作是均摊$O(\log n)$的。但事实上,跑起来比大多数的平衡树都慢。不过splay的优势是非常灵活,可以支持很多操作,但写起来很长。其实除了LCT,其他的操作都是可以替代的。不过它依然是非常重要的数据结构,而且出题人也不会去卡splay,写LCT也要用到,当然要好好学啦。

这里给出指针的版本。有时候一些数据结构也可以用数组写,还是都掌握比较好。

核心操作

Node的定义

Node中存了当前节点的权值、子树大小、这个数字出现的次数和父亲指针、儿子指针。

由于很多操作都是直接针对节点的,因此在Node中有很多函数。不过有简单的几个方法。

maintain:更新子树大小。
relation:和父亲的关系。0表示左儿子或不存在父亲,1表示右儿子。

1
2
3
4
5
6
7
8
9
10
11
12
13
struct Node *root;
struct Node {
int v, cnt, siz;
Node ch[2], *fa;
Node() {}
Node(int x, Node *f) : v(x), cnt(1), siz(1), fa(f) {}
void maintain() {
siz = cnt + (ch[0] ? ch[0]->siz : 0) + (ch[1] ? ch[1]->siz : 0);
}
void rel() {
return fa ? fa->ch[1] == this : 0;
}
}

rotate

旋转是splay最重要、最基本的操作了。也是保证splay复杂度正确,维护树形态的必要操作。

目的是改变父子关系,调换父子的位置,把当前节点上移一位。

改变树的形态的基础上,保证以下两点:

  • 树的中序遍历。即父子的大小关系和整个序列的值。
  • 受影响的节点的size都依然有效。

操作步骤:

  • 将自身连接到祖父上,替代父亲;
  • 将和父亲关系相反的孩子与父亲连接,替代自身;
  • 将父亲连接到上一步的孩子的位置;
  • 如果此时自身为根,则更新root

kyr1no学长写了神奇的link函数,使代码变得更简单。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
void link(Node *o, Node *f, int r) {
if (o) o->fa = f;
if (f) f->ch[r] = o;
}

void rotate {
int r = rel();
Node f = *fa;
link(this, f->fa, f->rel());
link(ch[r ^ 1], f, r);
link(f, this, r ^ 1);
f->maintain();
maintain();
if (!fa)
root = this;
}

splay

顾名思义,这也是splay的核心函数,也被称为伸展操作。

目的是把当前节点旋转到根,维护整棵树的形态。

但是直接死循环判断是不可取的。一条链旋转以后还是一条链,无法保证复杂度。因此采取一些方法让其更好的旋转。

发现了一张图片可以说明:

操作步骤:

  • 如果父亲是目标节点,直接旋转;
  • 如果父亲与祖父的关系和自己与父亲的关系相同,就先旋转父亲,再旋转自身;
  • 否则就将自身旋转两次。
1
2
3
4
5
6
7
8
9
void splay(Node *tar = NULL) { // 目标节点的父亲节点
while (fa != tar) {
if (fa->fa == tar)
rotate();
else if (rel() == fa->rel())
fa->rotate(), rotate();
else rotate(), rotate();
}
}

基本操作

splay支持的操作很多,这里主要是平衡树的最基本操作。

insert

在树中插入一个数字。

找到这个数的位置,如果这个数字出现过,就加上$1$,否则就新建这个节点。

有双指针的迭代版本和递归版,递归版的稍微快一点。可以省略沿途的maintain,因为最后splay的时候会按照原路maintain,直到root。

双指针迭代版:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
Node *ins(int x) {
Node **o = &root, *fa = NULL;
while (*o && (*o)->v != x) {
fa = *o;
//++fa->siz;
o = &fa->ch[x > fa->v]; // ***
}
if (*o)
++(*o)->cnt;
//++(*o)->siz;
else
*o = new (tcur++) Node(fa, x);
(*o)->splay();
return root;
}

递归版:

1
2
3
4
5
6
7
8
Node ins(Node *&o, Node *f, int x) {
if (!o)
o = new (tcur++) Node(f, x);
//++o->siz;
if (o->v == x)
return ++o->cnt, o;
return ins(o->ch[x > o->v], o, x);
}

delete(erase)

支持区间删除和单点删除。

把左端点的前驱splay到根,再把右端点的后继splay到根的右子树,再删除右端点后继的左儿子就好啦。实际上就包括了左端点到右端点的全部区间。和treap还是很像的。

如果是单点删除,要注意是否多次出现,如果多次出现就把出现次数减1。

1
2
3
4
5
6
7
8
9
10
11
12
13
void del(Node *o) {
Node *p = o->pre(), *s = o->suc();
p->splay(), s->splay(p); // ***
if (o->cnt > 1)
--o->cnt, --o->siz;
else s->ch[0] = NULL;
s->upd(), p->upd();
}

void del(int x) {
Node *o = find(x);
if (o) del(o);
}

build

和线段树的build差不多,等做到的时候再写吧。

find

在树中找一个数字。如果找到记得splay。

1
2
3
4
5
6
7
Node *find(int x) {
Node *o = root;
while (o && o->v != x)
o = o->ch[x > o->v];
if (o) o->splay();
return o;
}

rank

由于哨兵节点的存在,只需要splay后返回左子树的值就好啦。要是改成查询某一个数字的rank的话,和predecessor、successor都一样啦。

节点内部:

1
2
3
int rnk() {
return ch[0] ? ch[0]->siz : 0;
}

节点外部:

1
2
3
4
5
6
7
8
int rnk(int x) {
Node *o = find(x);
if (o) return o->rnk();
o = ins(x);
int ans = o->rnk();
del(o);
return ans;
}

kth(select)

查询第$k$大。很多地方写的是select。就是在树上二分一下。

注意由于哨兵节点的存在,只需要左子树的大小等于$k$时,这个数即为第$k$大。

1
2
3
4
5
6
7
8
9
10
11
12
int kth(int k) {
Node *o = root;
while (true) {
if (o->rnk() > k)
o = o->ch[0];
else if (o->rnk() + o->cnt <= k)
k -= o->rnk() + o->cnt, o = o->ch[1]; // ***
else break;
}
o->splay(); // ***
return o->v;
}

predecessor

有两种实现,分别是找到前驱节点和找到前一个数字。
前驱节点当然是比它小的最大值。先splay到根,再取左子树的最大值就好啦。一般调用的时候都会先find,而在find的时候就已经splay过了,这里可以不重复splay。

节点版本:

1
2
3
4
5
6
7
Node *pre(Node *o) {
//o->splay();
o = o->ch[0];
while (o->ch[1])
o = o->ch[1];
return o;
}

数字版本:

1
2
3
4
5
6
7
8
int pre(int x) {
Node *o = find(x);
if (o) return o->pre()->v;
o = ins(x);
int ans = o->pre()->v;
del(o);
return ans;
}

successor

和pred一样啦。找右子树的最小值就好啦。

节点版本:

1
2
3
4
5
6
7
Node *suc(Node *o) {
//o->splay();
o = o->ch[1];
while (o->ch[0])
o = o->ch[0];
return o;
}

数字版本:

1
2
3
4
5
6
7
8
int suc(int x) {
Node *o = find(x);
if (o) return o->suc()->v;
o = ins(x);
int ans = o->suc()->v;
del(o);
return ans;
}

其实可以和pred合并成一个函数。

进阶操作

区间反转什么的,非常多。做到再写吧。

完整代码

P3369普通平衡树。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
#include <cstdio>
#include <algorithm>
#include <climits>
using namespace std;

const int N = 1e5 + 5;

struct Node *root;
struct Node {
Node *fa, *ch[2];
int v, cnt, siz;
Node() {}
Node(Node *f, int x) : fa(f), v(x), cnt(1), siz(1) {}
void upd() {
siz = cnt + (ch[0] ? ch[0]->siz : 0) + (ch[1] ? ch[1]->siz : 0);
}
int rel() {
return fa ? fa->ch[1] == this : 0;
}
void link(Node *o, Node *f, int r) {
if (o) o->fa = f;
if (f) f->ch[r] = o;
}
void rot() {
Node *f = fa;
int r = rel();
link(this, f->fa, f->rel());
link(ch[r ^ 1], f, r);
link(f, this, r ^ 1);
f->upd(), upd();
if (!fa) root = this;
}
void splay(Node *tar = NULL) {
while (fa != tar) {
if (fa->fa == tar) rot();
else if (rel() == fa->rel())
fa->rot(), rot();
else rot(), rot();
}
}
int rnk() {
return ch[0] ? ch[0]->siz : 0;
}
Node *pre() {
Node *o = ch[0];
while (o->ch[1])
o = o->ch[1];
return o;
}
Node *suc() {
Node *o = ch[1];
while (o->ch[0])
o = o->ch[0];
return o;
}
} tpool[N], *tcur = tpool;

Node *find(int x) {
Node *o = root;
while (o && o->v != x)
o = o->ch[x > o->v];
if (o) o->splay(); // ***
return o;
}

Node *ins(Node *&o, Node *f, int x) {
if (!o)
return o = new (tcur++) Node(f, x);
if (o->v == x)
return ++o->cnt, o;
return ins(o->ch[x > o->v], o, x);
}

Node *ins(int x) {
Node *o = ins(root, NULL, x);
o->splay();
return o;
}

void del(Node *o) {
Node *p = o->pre(), *s = o->suc();
p->splay(), s->splay(p); // ***
if (o->cnt > 1)
--o->cnt, --o->siz;
else s->ch[0] = NULL;
s->upd(), p->upd();
}

void del(int x) {
Node *o = find(x);
if (o) del(o);
}

int rnk(int x) {
Node *o = find(x);
if (o) return o->rnk();
o = ins(x);
int ans = o->rnk();
del(o);
return ans;
}

int kth(int k) {
Node *o = root;
while (true) {
if (k < o->rnk())
o = o->ch[0];
else if (k >= o->rnk() + o->cnt)
k -= o->rnk() + o->cnt, o = o->ch[1];
else break;
}
o->splay();
return o->v;
}

int pre(int x) {
Node *o = find(x);
if (o) return o->pre()->v;
o = ins(x);
int ans = o->pre()->v;
del(o);
return ans;
}

int suc(int x) {
Node *o = find(x);
if (o) return o->suc()->v;
o = ins(x);
int ans = o->suc()->v;
del(o);
return ans;
}

void init() {
ins(INT_MAX);
ins(INT_MIN);
}

int main() {
int n, op, x;
init();
scanf("%d", &n);
while (n--) {
scanf("%d%d", &op, &x);
switch(op) {
case 1: ins(x);break;
case 2: del(x);break;
case 3: printf("%d\n", rnk(x));break;
case 4: printf("%d\n", kth(x));break;
case 5: printf("%d\n", pre(x));break;
case 6: printf("%d\n", suc(x));break;
}
}
}