[BZOJ2588][Spoj 10628]Count on a tree

试题描述

给定一棵N个节点的树,每个点有一个权值,对于M个询问(u,v,k),你需要回答u xor lastans和v这两个节点间第K小的点权。其中lastans是上一个询问的答案,初始为0,即第一个询问的u是明文。 

输入

第一行两个整数N,M。
第二行有N个整数,其中第i个整数表示点i的权值。
后面N-1行每行两个整数(x,y),表示点x到点y有一条边。
最后M行每行两个整数(u,v,k),表示一组询问。

输出

M行,表示每个询问的答案。最后一个询问不输出换行符

输入示例

8 5
105 2 9 3 8 5 7 7
1 2
1 3
1 4
3 5
3 6
3 7
4 8
2 5 1
0 5 2
10 5 3
11 5 4
110 8 2

输出示例

2
8
9
105
7

数据规模及约定

N,M<=100000

题解

我们可以把主席树按照树形结构来建,即每一个节点上的版本从它父亲节点的版本修改而来,那么一个节点上的主席树记录的就是该节点到根节点的权值信息了,于是利用 d(a, b) = dep(a) + dep(b) - dep(lca(a, b)) - dep(fa[lca(a, b)]) 这个公式(其中 d(a, b) 表示路径 a 到 b 的权值和,dep(u) = d(root, u),root 为根节点,lca(a, b) 为 a 与 b 的最近公共祖先,fa[u] 为 u 的父亲)二分。

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cmath>
#include <stack>
#include <vector>
#include <queue>
#include <cstring>
#include <string>
#include <map>
#include <set>
using namespace std;

const int BufferSize = 1 << 16;
char buffer[BufferSize], *Head, *Tail;
inline char Getchar() {
	if(Head == Tail) {
		int l = fread(buffer, 1, BufferSize, stdin);
		Tail = (Head = buffer) + l;
	}
	return *Head++;
}
int read() {
	int x = 0, f = 1; char c = Getchar();
	while(!isdigit(c)){ if(c == '-') f = -1; c = Getchar(); }
	while(isdigit(c)){ x = x * 10 + c - '0'; c = Getchar(); }
	return x * f;
}

#define maxn 100010
#define maxm 200010
#define maxlog 17
#define maxnode 2000010
int n, rt[maxn], val[maxn], num[maxn];

int ToT, sumv[maxnode], lc[maxnode], rc[maxnode];
void update(int& y, int x, int l, int r, int p) {
	sumv[y = ++ToT] = sumv[x] + 1;
	if(l == r) return ;
	int mid = l + r >> 1; lc[y] = lc[x]; rc[y] = rc[x];
	if(p <= mid) update(lc[y], lc[x], l, mid, p);
	else update(rc[y], rc[x], mid + 1, r, p);
	return ;
}

int m, head[maxn], next[maxm], to[maxm], fa[maxlog][maxn], dep[maxn];
void AddEdge(int a, int b) {
	to[++m] = b; next[m] = head[a]; head[a] = m;
	swap(a, b);
	to[++m] = b; next[m] = head[a]; head[a] = m;
	return ;
}
void build(int u) {
	update(rt[u], rt[fa[0][u]], 1, n, val[u]);
	for(int i = 1; i < maxlog; i++) fa[i][u] = fa[i-1][fa[i-1][u]];
	for(int e = head[u]; e; e = next[e]) if(to[e] != fa[0][u]) {
		fa[0][to[e]] = u;
		dep[to[e]] = dep[u] + 1;
		build(to[e]);
	}
	return ;
}
int lca(int a, int b) {
	if(dep[a] < dep[b]) swap(a, b);
	for(int i = maxlog - 1; i >= 0; i--) if(dep[a] - dep[b] >= (1 << i)) a = fa[i][a];
	for(int i = maxlog - 1; i >= 0; i--) if(fa[i][a] != fa[i][b]) a = fa[i][a], b = fa[i][b];
	return a == b ? a : fa[0][b];
}

int solve(int a, int b, int k) {
	int lrt[2] = {rt[a], rt[b]}, c = lca(a, b), rrt[2] = {rt[c], rt[fa[0][c]]};
	int l = 1, r = n;
	while(l < r) {
		int mid = l + r >> 1, sum = 0;
		for(int i = 0; i < 2; i++) if(lrt[i] && lc[lrt[i]]) sum += sumv[lc[lrt[i]]];
		for(int i = 0; i < 2; i++) if(rrt[i] && lc[rrt[i]]) sum -= sumv[lc[rrt[i]]];
		if(sum < k) {
			k -= sum; l = mid + 1;
			for(int i = 0; i < 2; i++) if(lrt[i]) lrt[i] = rc[lrt[i]];
			for(int i = 0; i < 2; i++) if(rrt[i]) rrt[i] = rc[rrt[i]];
		}
		else {
			r = mid;
			for(int i = 0; i < 2; i++) if(lrt[i]) lrt[i] = lc[lrt[i]];
			for(int i = 0; i < 2; i++) if(rrt[i]) rrt[i] = lc[rrt[i]];
		}
	}
	return num[l];
}

int main() {
	n = read(); int q = read();
	for(int i = 1; i <= n; i++) val[i] = num[i] = read();
	sort(num + 1, num + n + 1);
	for(int i = 1; i <= n; i++) val[i] = lower_bound(num + 1, num + n + 1, val[i]) - num;
	for(int i = 1; i < n; i++) {
		int a = read(), b = read();
		AddEdge(a, b);
	}
	build(1);
	int lst = 0;
	while(q--) {
		int a = read() ^ lst, b = read(), k = read();
		lst = solve(a, b, k);
		if(q) printf("%d\n", lst);
		else printf("%d", lst);
	}
	
	return 0;
}