题解
对于一个点x,合法的y一定在它到最远点的路径上。
而一个点x的最远点一定是直径的一个端点。
那么以直径的端点做根,每一个点x在距离较大的那个根上统计答案。
那么就是求每个点x到根路径上的合法的y的不同颜色数。
先对树进行长链剖分,设son[x]表示x的重儿子,md[x]表示x到子树里最远点的距离。
假设现在在x上,有一个栈维护到x到根路径上剩下的合法的y。
如果往重儿子走,那么栈中到x距离<=(max(md[轻儿子])+1)的会被弹掉,
如果往轻儿子走,那么栈中到x距离<=(md[重儿子]+1)的会被弹掉。
那就先往重儿子走,再往轻儿子走,这样保证弹掉元素有有序的。
*在进入每一个儿子前,都要把自己加入栈。
由于元素加入次数是\(n-1\),所以复杂度是\(O(n)\)。
在栈操作的同时维护一个桶即可统计栈内不同颜色数。
Code:
#include<bits/stdc++.h>
#define fo(i, x, y) for(int i = x, _b = y; i <= _b; i ++)
#define ff(i, x, y) for(int i = x, _b = y; i < _b; i ++)
#define fd(i, x, y) for(int i = x, _b = y; i >= _b; i --)
#define ll long long
#define pp printf
#define hh pp("\n")
using namespace std;
const int N = 3e5 + 5;
int n, m, x, y;
#define V vector<int>
#define si size()
#define pb push_back
V e[N];
int c[N];
void Init() {
scanf("%d %d", &n, &m);
fo(i, 1, n - 1) {
scanf("%d %d", &x, &y);
e[x].pb(y); e[y].pb(x);
}
fo(i, 1, n) scanf("%d", &c[i]);
}
int d1[N], d2[N], d[N], d0, bz[N];
void bfs(int st, int *dis) {
fo(i, 1, n) dis[i] = bz[i] = 0;
d[d0 = 1] = st; bz[st] = 1;
for(int i = 1; i <= d0; i ++) {
int x = d[i];
ff(_y, 0, e[x].si) {
int y = e[x][_y];
if(!bz[y]) dis[y] = dis[x] + 1, d[++ d0] = y, bz[y] = 1;
}
}
}
int s, t;
int ans[N];
int md[N], dep[N], son[N], fa[N];
void dg(int x) {
dep[x] = dep[fa[x]] + 1;
md[x] = 0;
son[x] = 0; md[0] = -1e9;
ff(_y, 0, e[x].si) {
int y = e[x][_y];
if(y == fa[x]) continue;
fa[y] = x;
dg(y);
md[x] = max(md[x], md[y] + 1);
if(md[y] > md[son[x]]) son[x] = y;
}
}
int cnt[N], c0, z[N], z0;
void add(int x) {
c0 += !cnt[c[x]];
cnt[c[x]] ++;
z[++ z0] = x;
}
void pop() {
int x = z[z0 --];
cnt[c[x]] --;
c0 -= !cnt[c[x]];
}
int ky[N];
void dfs(int x) {
int mx = -1e9;
ff(_y, 0, e[x].si) {
int y = e[x][_y];
if(y == fa[x]) continue;
if(y != son[x]) mx = max(mx, md[y] + 1);
}
while(z0 > 0 && dep[x] - dep[z[z0]] <= mx) pop();
add(x);
if(son[x]) dfs(son[x]);
while(z0 > 0 && dep[x] - dep[z[z0]] <= md[son[x]] + 1) pop();
ff(_y, 0, e[x].si) {
int y = e[x][_y];
if(y == fa[x]) continue;
if(y != son[x]) {
if(z[z0] != x) add(x);
dfs(y);
}
}
while(z0 > 0 && dep[z[z0]] >= dep[x]) pop();
if(ky[x]) {
ans[x] += c0;
}
}
void dp() {
fo(i, 1, n) fa[i] = 0;
dg(s);
fo(i, 1, m) cnt[i] = 0;
c0 = z0 = 0;
dfs(s);
}
int main() {
Init();
bfs(1, d1);
s = d[d0];
bfs(s, d1);
t = d[d0];
bfs(t, d2);
fo(i, 1, n) ky[i] = d1[i] >= d2[i];
dp();
swap(s, t);
swap(d1, d2);
fo(i, 1, n) ky[i] = !ky[i];
dp();
fo(i, 1, n) pp("%d\n", ans[i]);
}