线段树学习笔记

近日新学习了线段树,总结一下它的用法和注意事项。

  1. `query`和`modify`都不需要取$mid$(前提是像模板2那样在结点中记录了区间端点),只有`build`需要。

  2. 有取模时应该每步取模。

  3. 区间的端点在递归时现算出来、而不在结点中记录会更快,结点中可以只记录区间的长度(如模板1和模板3)。

  4. 如果有多种操作且标记不能直接叠加,`pushdown`要放在“完全覆盖”判断的前面,并特判叶结点(叶结点没有儿子可以下传);如果操作可叠加或者只有一种,就放在该判断的后面。

模板1

已知一个数列,进行两种操作:

  1. 将某区间每一个数加上一个数;
  2. 求出某区间每一个数的和;

区间修改,区间查询,最简单的模板。

#include <cstdio>
#include <iostream>
typedef long long ll;
const int N = 100005;

int n, m;
ll a[N], ret;

// Segment-tree node for "add d to a range" / "sum of a range".
// A node keeps only the number of elements it covers; the endpoints
// [nl, nr] are supplied by the caller on every descent. Range-sum results
// are accumulated into the global `ret`, which the caller resets first.
struct Node {
    int len;        // number of elements under this node
    Node *lc, *rc;  // children (NULL for a leaf)
    ll sum, tag;    // subtree sum, and an addition not yet pushed to children

    Node() {}

    // Leaf holding a single value.
    Node(ll val) : len(1), lc(NULL), rc(NULL), sum(val), tag(0) {}

    // Internal node over two already-built children covering `l` elements.
    Node(Node *left, Node *right, int l)
        : len(l), lc(left), rc(right), sum(left->sum + right->sum), tag(0) {}

    // Add d to every element under this node, remembering it lazily.
    void add(ll d) {
        tag += d;
        sum += d * len;
    }

    // Hand any pending addition down to both children.
    void pushdown() {
        if (tag == 0)
            return;
        lc->add(tag);
        rc->add(tag);
        tag = 0;
    }

    // Add d to every element of [l, r]; this node covers [nl, nr].
    void modify(int l, int r, int nl, int nr, ll d) {
        const bool disjoint = r < nl || nr < l;
        if (disjoint)
            return;
        const bool covered = l <= nl && nr <= r;
        if (covered) {
            add(d);
            return;
        }
        pushdown();
        const int half = nl + (nr - nl) / 2;
        lc->modify(l, r, nl, half, d);
        rc->modify(l, r, half + 1, nr, d);
        sum = lc->sum + rc->sum;
    }

    // Accumulate the sum of [l, r] into `ret`; this node covers [nl, nr].
    void query(int l, int r, int nl, int nr) {
        const bool disjoint = r < nl || nr < l;
        if (disjoint)
            return;
        const bool covered = l <= nl && nr <= r;
        if (covered) {
            ret += sum;
            return;
        }
        pushdown();
        const int half = nl + (nr - nl) / 2;
        lc->query(l, r, nl, half);
        rc->query(l, r, half + 1, nr);
    }
} *segt, tpool[N << 1], *tcur = tpool;

// Build the tree over a[l..r] using nodes taken from the pool `tpool`.
// Leaves copy a[l]; an internal node is created over its two halves.
Node *build (int l, int r) {
    if (l != r) {
        const int mid = (l + r) >> 1;
        Node *slot = tcur++;  // parent takes its pool slot before its children
        Node *left = build(l, mid);
        Node *right = build(mid + 1, r);
        return new (slot) Node(left, right, r - l + 1);
    }
    return new (tcur++) Node(a[l]);
}

// Read n values and m operations:
//   "1 x y k" adds k to a[x..y]; "2 x y" prints the sum of a[x..y].
int main() {
    scanf("%d%d", &n, &m);
    for (int i = 1; i <= n; ++i)
        scanf("%lld", &a[i]);
    segt = build(1, n);
    for (int q = 0; q < m; ++q) {
        int s, x, y;
        scanf("%d%d%d", &s, &x, &y);
        if (s != 1) {
            ret = 0;
            segt->query(x, y, 1, n);
            printf("%lld\n", ret);
        } else {
            ll k;
            scanf("%lld", &k);
            segt->modify(x, y, 1, n, k);
        }
    }
    return 0;
}

模板2

已知一个数列,进行三种操作:

  1. 将某区间每一个数乘上一个数;
  2. 将某区间每一个数加上一个数;
  3. 求出某区间每一个数的和(结果对给定的模数取模)。

区间修改,区间查询,修改有加有乘。

在结点上维护形如$kx+b$的懒标记,$k$和$b$分开存储:区间乘$d$时把区间和、$k$和$b$都乘$d$;区间加$d$时把$b$加上$d$、区间和加上区间长度乘$d$。$k$初始化为$1$,$b$初始化为$0$。

注意把`pushdown`放在前面,每次修改前先下传标记,并且先下传乘法标记、再下传加法标记。

#include <iostream>
#include <cstdio>
typedef long long ll;
const int N = 100005;
int n, m;
ll MOD, a[N];

// Segment-tree node for range multiply / range add / range sum, all mod MOD.
// Every node stores its endpoints [l, r] and a lazy affine tag (k, b): each
// element below it still has to be replaced by (k * x + b) mod MOD.
struct Node {
    int l, r;
    Node *lc, *rc;
    ll sum, k, b;

    // Reduce any value (negative or >= MOD) into [0, MOD). Keeping every
    // operand below MOD also keeps the products in addk/addb small.
    static ll normalize(ll x) {
        x %= MOD;
        return x < 0 ? x + MOD : x;
    }

    Node() {}

    // Leaf for position `pos`. The raw input is reduced here so that a query
    // landing on a single untouched leaf still returns a value mod MOD.
    Node(int pos, ll val) : l(pos), r(pos), lc(NULL), rc(NULL), sum(normalize(val)), k(1), b(0) {}

    // Internal node over two already-built, adjacent children.
    Node(Node *lc, Node *rc) : l(lc->l), r(rc->r), lc(lc), rc(rc), k(1), b(0) {
        sum = (lc->sum + rc->sum) % MOD;
    }

    // Multiply every element under this node by d (expects d in [0, MOD)).
    void addk(ll d) {
        sum = sum * d % MOD;
        k = k * d % MOD;
        b = b * d % MOD;
    }

    // Add d to every element under this node (expects d in [0, MOD)).
    void addb(ll d) {
        sum = (sum + (r - l + 1) * d) % MOD;
        b = (b + d) % MOD;
    }

    // Hand the pending tag to both children: multiplier first, then addend,
    // which composes k * (k' * x + b') + b for the child's own tag (k', b').
    void pushdown() {
        if (k != 1) {
            lc->addk(k);
            rc->addk(k);
            k = 1;
        }
        if (b) {
            lc->addb(b);
            rc->addb(b);
            b = 0;
        }
    }

    // s == 1: multiply every element of [l, r] by d; otherwise add d to it.
    // d may be any value; it is reduced into [0, MOD) first.
    void modify(int s, int l, int r, ll d) {
        if (r < this->l || this->r < l)
            return;
        d = normalize(d);
        // Several non-commuting tag kinds: push before the cover check, but
        // never from a leaf (it has no children).
        if (this->l != this->r)
            pushdown();
        if (l <= this->l && this->r <= r) {
            if (s == 1)
                addk(d);
            else
                addb(d);
            return;
        }
        lc->modify(s, l, r, d);
        rc->modify(s, l, r, d);
        sum = (lc->sum + rc->sum) % MOD;
    }

    // Sum of the elements of [l, r], mod MOD.
    ll query(int l, int r) {
        if (r < this->l || this->r < l)
            return 0;
        if (this->l != this->r)
            pushdown();
        if (l <= this->l && this->r <= r)
            return sum;
        return (lc->query(l, r) + rc->query(l, r)) % MOD;
    }
} *root, tpool[N << 1], *tcur = tpool;

// Build the tree over a[l..r] using nodes taken from the pool `tpool`.
// Leaves are created for position l with value a[l].
Node *build(int l, int r) {
    if (l != r) {
        const int mid = (l + r) >> 1;
        Node *slot = tcur++;  // parent takes its pool slot before its children
        Node *left = build(l, mid);
        Node *right = build(mid + 1, r);
        return new (slot) Node(left, right);
    }
    return new (tcur++) Node(l, a[l]);
}

// Read n, m and the modulus, then n values and m operations:
//   "1 x y d" multiplies a[x..y] by d, "2 x y d" adds d to a[x..y],
//   "3 x y" prints the sum of a[x..y] mod MOD.
int main() {
    scanf("%d%d%lld", &n, &m, &MOD);
    for (int i = 1; i <= n; ++i)
        scanf("%lld", &a[i]);
    root = build(1, n);
    for (int q = 0; q < m; ++q) {
        int p, x, y;
        scanf("%d%d%d", &p, &x, &y);
        if (p != 3) {
            ll d;
            scanf("%lld", &d);
            root->modify(p, x, y, d);
            continue;
        }
        printf("%lld\n", root->query(x, y));
    }
    return 0;
}

模板3

已知一个数列,进行5种操作:

  1. 把区间内的所有数都增加一个数;
  2. 把区间内的所有数都设为一个数;
  3. 查询区间的区间和;
  4. 查询区间的最大值;
  5. 查询区间的最小值。

要维护的量比较多,注意`set`和`add`的顺序:下传时先下传赋值标记、再下传加法标记,并且赋值时要清空加法标记。

这里放了kyr1no学长的代码$Orz$。

#include <bits/stdc++.h>
typedef long long ll;
typedef const int cint;
typedef const long long cll;
typedef const char cchar;
#define daze << '\n'

// Buffered fast I/O on top of std::cin's / std::cout's stream buffers.
// LI and LO are the input / output buffer sizes in bytes; a size of 0 turns
// that direction off (its streambuf pointer is then left NULL).
// Output is only written when the buffer fills, on fl(), or in the destructor.
template <cint LI, cint LO>
struct IO {
    // a: input buffer, b: output buffer, r: scratch for integer digits
    // (64 bytes is enough for any built-in integer type).
    char a[LI], b[LO], r[64], *s, *t, *z, c;
    std::streambuf *fbi, *fbo;
    bool ended;  // true when the most recent refill of `a` returned no bytes

    // s == t marks an empty input buffer; initialised explicitly so the
    // object also works when it is not a zero-initialised global.
    IO() : s(a), t(a), z(b), fbi(NULL), fbo(NULL), ended(false) {
        std::ios::sync_with_stdio(false);
        if (LI) std::cin.tie(NULL), fbi = std::cin.rdbuf();
        if (LO) std::cout.tie(NULL), fbo = std::cout.rdbuf();
    }
    ~IO() { if (LO) fbo->sputn(b, z - b); }

    // <cctype> predicates; the unsigned char cast keeps them defined for
    // bytes >= 0x80 (passing a negative char other than EOF is UB).
    static bool dig(char ch) { return isdigit(static_cast<unsigned char>(ch)) != 0; }
    static bool graph(char ch) { return isgraph(static_cast<unsigned char>(ch)) != 0; }

    // Next input byte, or EOF (converted to char) once the input is exhausted.
    char gc() {
        if (s == t) {
            t = (s = a) + fbi->sgetn(a, LI);
            ended = s == t;
            if (ended) return EOF;
        }
        return *s++;
    }

    // Read a signed integer, skipping leading bytes that are neither '-' nor
    // a digit. At end of input x is set to 0 instead of looping forever.
    template <class T>
    IO &operator >> (T &x) {
        for (c = gc(); !ended && c != '-' && !dig(c); c = gc());
        x = 0;
        if (ended) return *this;
        const bool f = c == '-';
        if (f) c = gc();
        // Negative numbers are accumulated as negative values so the most
        // negative value of T (e.g. LLONG_MIN) does not overflow.
        for (; dig(c); c = gc())
            x = f ? x * 10 - (c - '0') : x * 10 + (c - '0');
        return *this;
    }

    // Read a whitespace-delimited token into x and NUL-terminate it; returns
    // a pointer to the terminating NUL. At end of input x becomes "".
    char *gs(char *x) {
        for (c = gc(); !ended && !graph(c); c = gc());
        for (; graph(c); c = gc()) *x++ = c;
        return *x = 0, x;
    }
    IO &operator >> (char *x) {
        gs(x);
        return *this;
    }

    // Read the next printable non-space character (EOF as char at end of input).
    IO &operator >> (char &x) {
        for (x = gc(); !ended && !graph(x); x = gc());
        return *this;
    }

    // Lets `io` be used wherever a T is expected: reads one T from the input.
    template <class T>
    operator T () { T x; *this >> x; return x; }

    // Append one byte to the output buffer, flushing it first when full.
    void pc(cchar x) {
        if (z == b + LO) fbo->sputn(z = b, LO);
        *z++ = x;
    }
    // Write out everything buffered so far.
    void fl() {
        fbo->sputn(b, z - b);
        z = b;
    }

    // Write an integer in decimal. Digits are derived from the value's own
    // sign, so the most negative value (e.g. LLONG_MIN) is never negated.
    template <class T>
    IO &operator << (T x) {
        char *j = r;
        const bool neg = x < 0;
        if (neg) pc('-');
        do {
            const T y = x / 10;
            const T digit = neg ? y * 10 - x : x - y * 10;
            *j++ = static_cast<char>('0' + digit);
            x = y;
        } while (x);
        while (j != r) pc(*--j);
        return *this;
    }
    IO &operator << (char *x) {
        while (*x) pc(*x++);
        return *this;
    }
    IO &operator << (cchar *x) {
        while (*x) pc(*x++);
        return *this;
    }
    IO &operator << (cchar x) { return pc(x), *this; }
};
IO<1000000, 1000000> io;

cint N = 100003;

// Array length; also the default right end of every Node update/query.
int n;
// Accumulator that the Node query members fold their results into.
ll ret;

// Binary "+" with the same (accumulated, value) shape as std::min / std::max,
// so the range-sum query can use it as its combiner.
inline ll fsum(cll x, cll y) {
    const ll total = x + y;
    return total;
}
// Segment-tree node for range assign / range add and range sum / min / max.
// A node stores only its length; the endpoints [l, r] are passed down on
// each call and default to the whole array [1, n]. Query members fold their
// result into the global `ret`, which main() seeds with 0 (sum),
// LLONG_MAX (min) or LLONG_MIN (max) before calling them.
struct Node {
Node *lc, *rc;
int len;
ll sum, min, max, tgs, tga; // tgs: pending assignment (LLONG_MIN = none), pushed before tga: pending addition
Node() {}
// Leaf holding the single value x, with no pending tags.
Node(cll x) : lc(NULL), rc(NULL), len(1), sum(x), min(x), max(x), tgs(LLONG_MIN), tga(0) {}
// Internal node over children l and r that together cover `le` elements.
Node(Node *l, Node *r, cint le) : lc(l), rc(r), len(le), tgs(LLONG_MIN), tga(0) {
maintain();
}
// Recompute sum / min / max from the two children.
void maintain() {
sum = lc->sum + rc->sum;
min = std::min(lc->min, rc->min);
max = std::max(lc->max, rc->max);
}
// Assign x to every element under this node. Any pending addition is
// discarded because the assignment overrides it.
// NOTE(review): LLONG_MIN doubles as the "no assignment" sentinel, so
// assigning exactly LLONG_MIN would not be propagated by push_down().
void cover_s(cll x) {
sum = x * len;
min = max = tgs = x;
tga = 0;
}
// Add x to every element under this node and accumulate the pending addition.
void cover_a(cll x) {
sum += x * len;
min += x;
max += x;
tga += x;
}
// Push pending tags to both children: the assignment first, then the
// addition (which, since cover_s clears tga, was applied after it).
void push_down() {
if (tgs != LLONG_MIN) {
lc->cover_s(tgs);
rc->cover_s(tgs);
tgs = LLONG_MIN;
}
if (tga) {
lc->cover_a(tga);
rc->cover_a(tga);
tga = 0;
}
}
// MODIFY_FUNC(func, coverrer) defines member func(ql, qr, x, l, r): applies
// coverrer(x) to every node fully inside [ql, qr] (this node covers [l, r]),
// pushing tags down before descending and re-aggregating afterwards.
#define MODIFY_FUNC(func, coverrer) \
void func(cint ql, cint qr, cll x, cint l = 1, cint r = n) { \
if (qr < l || r < ql) \
return; \
if (ql <= l && r <= qr) \
return coverrer(x); \
push_down(); \
int mid = (l + r) >> 1; \
if (ql <= mid) \
lc->func(ql, qr, x, l, mid); \
if (qr > mid) \
rc->func(ql, qr, x, mid + 1, r); \
maintain(); \
}
// QUERY_FUNC(func, attr, opt) defines member func(ql, qr, l, r): folds the
// field `attr` of every node fully inside [ql, qr] into `ret` via
// ret = opt(ret, attr).
#define QUERY_FUNC(func, attr, opt) \
void func(cint ql, cint qr, cint l = 1, cint r = n) { \
if (qr < l || r < ql) \
return; \
if (ql <= l && r <= qr) { \
ret = opt(ret, attr); \
return; \
} \
push_down(); \
int mid = (l + r) >> 1; \
if (ql <= mid) \
lc->func(ql, qr, l, mid); \
if (qr > mid) \
rc->func(ql, qr, mid + 1, r); \
}
// add: range +x; set: range assign x.
MODIFY_FUNC(add, cover_a);
MODIFY_FUNC(set, cover_s);
// qmin / qmax / qsum: range minimum / maximum / sum into `ret`.
QUERY_FUNC(qmin, min, std::min);
QUERY_FUNC(qmax, max, std::max);
QUERY_FUNC(qsum, sum, fsum);
} *segt;
// Build the tree over positions [l, r], reading one leaf value from `io`
// per position in left-to-right order. Nodes come from a static pool.
Node *build(cint l, cint r) {
    static Node pool[N << 1];
    static Node *curr = pool;
    if (l == r)
        return new (curr++) Node(static_cast<ll>(io));
    const int mid = (l + r) >> 1;
    Node *left = build(l, mid);  // must run first: leaves consume input in order
    Node *slot = curr++;         // parent slot is taken before the right subtree
    Node *right = build(mid + 1, r);
    return new (slot) Node(left, right, r - l + 1);
}

// Read n, m, the n initial values, then m commands "op l r [x]":
//   add / set modify [l, r] with x; sum / min / max print the range result.
int main() {
    int m;
    io >> n >> m;
    segt = build(1, n);
    char opt[6];
    for (; m != 0; --m) {
        int l, r;
        io >> opt >> l >> r;
        switch (opt[0]) {
        case 's':
            if (opt[1] != 'u') {
                segt->set(l, r, io);
            } else {
                ret = 0;
                segt->qsum(l, r);
                io << ret daze;
            }
            break;
        case 'm':
            if (opt[1] == 'i') {
                ret = LLONG_MAX;
                segt->qmin(l, r);
            } else {
                ret = LLONG_MIN;
                segt->qmax(l, r);
            }
            io << ret daze;
            break;
        default:
            segt->add(l, r, io);
            break;
        }
    }
}