数据结构-ZKW线段树 详解
😊 | Powered By HeartFireY | ZKW Segment Tree
|
📕 | 需要的前导知识:基础线段树(Segment)、位运算
|
一、ZKW线段树简介
ZKW线段树是由清华大学张昆玮所创立的一种线段树储存结构,由于其基于非递归的实现方式以及精简的代码和较高的效率而闻名。甚至,ZKW线段树能够可持久化。
我们从算法的角度对基础线段树进行分析:其实线段树算法本身的本质仍是统计。因此我们可以从统计的角度入手对线段树进行分析:线段树是将一个个数 轴划分为区间进行处理的,因此我们面对的往往是一系列的离散量,这导致了我们在使用时的线段树单纯的退化为一棵"点树"(即最底层的线段树只包含一个点)。基于这一点可以入手对线段树进行优化
二、ZKW线段树的构造原理
首先,我们忽略线段树中的数据,从线段树的框架结构入手进行分析:如图所示是一颗采用堆式储存的基本线段树:

我们将节点编号转换为二进制:

观察转为二进制后的结点规律:在基础线段树的学习中,我们知道对于任意结点
,其左子节点为
,右子节点为
。这个规律是我们从根结点出发向叶节点寻找的规律。那么现在我们换个思路:从叶结点出发向根结点寻找规律:
- 当前结点的父节点一定是当前的结点右移一位(舍弃低位)得到的;
- 当前结点的左子节点为
,右子节点为
; - 每一层结点按照顺序排列,第
层有
个节点、 - 最后一层的结点个数=值域
因为最后一层的结点个数=值域,假设给定数组
,含有元素
。
我们约定,无论元素的个数是否达到
,最后一层的空间都开到
,无数据的叶节点空置即可。
三、ZKW线段树基本操作
1.建树操作
根据基本结构中的分析,我们可以很容易的得出建树操作过程:
inline void build(int n){
for(m = 1; m <= n; m <<= 1);
for (int i = m + 1; i <= m + n; ++i) op_array[i] = read();
for (int i = m - 1; i; --i) operation(),
}
- 如果维护区间和,那么

sum[i] = sum[i << 1] + sum[i << 1 | 1];
- 如果维护区间最小值,那么

minn[i] = min(minn[i << 1], minn[i << 1 | 1]); //不支持修改操作
minn[i] = min(minn[i << 1], minn[i << 1 | 1]),
minn[i << 1] -= minn[i], minn[i << 1 | 1] -= minn[i];
- 如果维护区间最大值,那么

maxx[i] = max(maxx[i << 1], maxx[i << 1 | 1]); //不支持修改操作
maxx[i] = max(maxx[i << 1], maxx[i << 1 | 1]),
maxx[i << 1] -= maxx[i], maxx[i << 1 | 1] -= maxx[i];
2.单点查询
这个操作是相对容易理解的,就是一个从叶子结点开始,不断向父节点走,同时累加沿路的权值的过程。
inline int query_node(int x, int ans = 0){
for (x += m; x; x >>= 1) ans += minn[s];
return ans;
}
3.单点修改
单点修改的思路非常简单,只需要修改当前结点并更新父节点即可。
void update(int x,int v){
op_array[x = m + x] += v;
while(x) operation();
}
- 如果维护区间和,那么

sum[i] = a[i << 1] + a[i << 1 | 1];
//如果单纯维护区间和,那么可以压行:
void update(int p, int k){ for (p += m; p; p >>= 1) sum[p] += k; }
- 如果维护区间最小值,那么

minn[i] = min(minn[i << 1], minn[i << 1 | 1]),
minn[i << 1] -= minn[i], minn[i << 1 | 1] -= minn[i];
- 如果维护区间最大值,那么

maxx[i] = max(maxx[i << 1], maxx[i << 1 | 1]),
maxx[i << 1] -= maxx[i], maxx[i << 1 | 1] -= maxx[i];
4.区间查询
如何进行区间查询?我们继续二进制表示入手,寻找查询的规律。
在实际的查询中,我们采取扩增左右区间端点的方式进行查询,即:将闭区间转换为开区间查询。
我们以下图为例:假设要查询的区间为
,那么首先转换为开区间
,我们可以发现变为开区间之后,
的兄弟结点必在区间之内,
的兄弟结点必在区间内;根据这个规律我们可以总结:
对于待查区间
:
- 如果
是左儿子,则其兄弟结点必位于区间之内; - 如果
是右儿子,则其兄弟结点必位于区间之内; - 查询的终止条件:两个结点同为兄弟;
- 以上结论,对于任意层的结点均成立。
我们通过例子来模拟这个过程:

在如图所示的ZKW线段树中,假设我们要查询区间
,那么步骤如下:
- 闭区间改开区间,
改为查询
,扩增至
; - 判断:左端点
是左儿子,那么其兄弟
必位于区间内,累加
;
判断:右端点
是右儿子,那么其兄弟
必位于区间内,累加
; - 缩小区间(向根结点缩):

- 判断:左端点
是左儿子,那么其兄弟
必位于区间内,累加
;
判断:右端点
是左儿子,不做操作; - 缩小区间(向根结点缩):

- 此时
和
同为兄弟,因此终止查询。
我们可以总结出区间查询的步骤:
- 闭区间改开区间

- 判断当前区间左端点是否是左儿子,如果是,则向累加器中累加兄弟结点;
判断当前区间右端点是否为右儿子,如果是,则向累加器中累加兄弟结点; - 端点变量处理操作:
; - 循环执行
的步骤,直到
和
同为兄弟结点(此时不终止会导致重复计算)
如何判断是否为左子节点?我们很容易观察到左右子节点共同的特征:左子节点最低位为
,右子节点最低位为
,那么我们可以通过以下操作的真值判断左右子节点:

对于取兄弟结点的值则可以通过与
异或求得:

建立在上述操作的基础上,我们可以实现区间查询:
- 维护区间和
inline int get_sum(int l, int r, int ans = 0){
for (l = l + m - 1, r = r + m + 1; l ^ r ^ 1; l >>= 1, r >>= 1){
if (~l & 1) ans += sum[l ^ 1];
if (r & 1) ans += sum[r ^ 1];
} return ans;
}
- 维护区间最小值
int get_min(int l, int r, int LL = 0, int RR = 0){
for (l = l + m - 1, r = r + m + 1; l ^ r ^ 1; l >>= 1, r >>= 1){
LL += minn[l], RR += minn[r];
if (~l & 1) LL = min(LL, minn[l ^ 1]);
if (r & 1) RR = min(RR, minn[r ^ 1]);
}
int res = min(LL, RR);
while (l) res += maxx[l >>= 1]; return res;
}
- 维护区间最大值
int get_max(int l, int r, int LL = 0, int RR = 0){
for (l = l + m - 1, r = r + m + 1; l ^ r ^ 1; l >>= 1, r >>= 1){
LL += maxx[l], RR += maxx[r];
if (~l & 1) LL = max(LL, maxx[l ^ 1]);
if (r & 1) RR = max(RR, maxx[r ^ 1]);
}
int res = max(LL, RR);
while (l) res += maxx[l >>= 1]; return res;
}
!注意:
求最大值最小值不要忘记最后的统计步骤(差分还原)
5.区间修改
如何进行区间修改/更新?这个过程跟查询的思路是十分相似的,我们首先给出区间修改的思路:
- 闭区间改开区间:需要让左端点
,右端点
; - 判断:当前区间左端点是否为左儿子,如果是则兄弟结点更新;
判断:当前区间右端点是否为右儿子,如果是则兄弟结点更新; - 端点变量处理操作:
; - 循环执行
的步骤,直到
和
同为兄弟结点(此时不终止会导致重复计算)
根据上述过程可以得出代码(与查询是比较相似的):
- 维护区间和
//这里有点问题,太晚了改天再改
inline void update_part(int l, int r, ll v){
for (l += m - 1, r += m + 1; l ^ r ^ 1; l >>= 1, r >>= 1, len <<= 1){
if (l & 1 ^ 1) sum[l ^ 1] += v;
if (r & 1) sum[r ^ 1] += v;
}
while(l) sum[l >>= 1] += v;
}
- 维护区间最小值
inline void update_part(int l, int r, int v, int A = 0){
for(l = l + M - 1, r = r + M + 1; l ^ r ^ 1; l >>= 1, r >>= 1){
if (~l & 1) minn[l ^ 1] += v;
if (r & 1) minn[r ^ 1] += v;
A = min(minn[l], minn[l ^ 1]); minn[l] -= A, minn[l ^ 1] -= A, minn[l >> 1] += A;
A = min(minn[r], minn[r ^ 1]); minn[r] -= A, minn[r ^ 1] -= A, minn[r >> 1] += A;
}
while(l) A = min(minn[l], minn[l ^ 1]), minn[l] -= A, minn[l ^ 1] -= A, minn[l >>= 1] += A;
}
- 维护区间最大值
inline void update_part(int l, int r, int v, int A = 0){
for(l = l + M - 1, r = r + M + 1; l ^ r ^ 1; l >>= 1, r >>= 1){
if (~l & 1) maxx[l ^ 1] += v;
if (r & 1) maxx[r ^ 1] += v;
A = min(maxx[l], maxx[l ^ 1]); maxx[l] -= A, maxx[l ^ 1] -= A, maxx[l >> 1] += A;
A = min(maxx[r], maxx[r ^ 1]); maxx[r] -= A, maxx[r ^ 1] -= A, maxx[r >> 1] += A;
}
while(l) A = min(maxx[l], maxx[l ^ 1]), maxx[l] -= A, maxx[l ^ 1] -= A, maxx[l >>= 1] += A;
}
四、Lazy标记
1.
类
ZKW线段树同样支持Lazy标记,也支持标记上下传。注意,一般不用这个方法,可以直接跳到标记永久化。
这里暂时没详细研究,先参考dalao博客放一个思路:
那么ZKW线段树中的Lazy标记是如何实现的呢?首先我们回到区间修改这个操作–大致的框架如下:
void update_part(int l, int r){
for (l = l + m - 1, r = r + m + 1; l ^ r ^ 1; l >>= 1, r >>= 1, updata(l), updata(r)){
if (~l & 1) ...
if (r & 1)...
}
l >>= 1;
while (l) updata(l), l >>= 1;
}
如果要实现
操作,只需额外添加一个
函数,其实就是用栈模拟基础线段树的标记下传操作。
void push(int x){
int top = 0;
while (x) sta[++top] = x, x >>= 1;
while (top > 1) pushdown(sta[top--]);
}
将区间修改改为:
void update_part(int l, int r){
for (l = l + m - 1, r = r + m + 1, push(l), push(r); l ^ r ^ 1; l >>= 1, r >>= 1, update(l), update(r)){
if (~l & 1) ...
if (r & 1)...
}
l >>= 1;
while (l) updata(l), l >>= 1;
}
2.标记永久化
标记永久化相关的概念不再赘述,与基础线段树一样的用法。
标记永久化之后的区间查询操作:
void update_part(int l, int r, ll k) {
int lnum = 0, rnum = 0, now = 1;
//lnum 表示当前左端点走到的子树有多少个元素在修改区间内 (rnum与lnum对称)
//now 表示当前端点走到的这一层有多少个叶子节点
for (l = l + m - 1, r = r + m + 1; l ^ r ^ 1; l >>= 1, r >>= 1; now <<= 1) {
sum[l] += k * lnum, sum[r] += k * rnum;
if (~l & 1) sum[l ^ 1] += k * now, add[l ^ 1] += k, lnum += now;
if (r & 1) sum[r ^ 1] += k * now, add[r ^ 1] += k, rnum += now;
}
for (; l; l >>= 1, r >>= 1) sum[l] += k * lnum, sum[r] += k * rnum;
}
标记永久化之后的区间修改
ll query(int l, int r) {
int lnum = 0, rnum = 0, now = 1;
long long ret = 0;
for (l = l + M - 1, r = r + M + 1; l ^ r ^ 1; l >>= 1, r >>= 1, now <<= 1) {
if (add[l]) ret += add[l] * lnum;
if (add[r]) ret += add[r] * rnum;
if (~l & 1) ret += sum[l ^ 1], lnum += now;
if (r & 1) ret += sum[r ^ 1], rnum += now;
}
for (; l; l >>= 1, r >>= 1) ret += add[l] * lnum, ret += add[r] * rnum;
return ret;
}