「Daily OI Round 1」block 题解
teylnol_evteyl · · 题解
设
其中
统计答案时,除了考虑选择的点的最近公共祖先被选择的情况(即
其中
#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
const int N = 1e5 + 5, P = 1e9 + 7;
int n, c[N];
int la[N], ne[N * 2], en[N * 2], idx;
LL f[N], g[N], res;
void add(int a, int b)
{
ne[ ++ idx] = la[a];
la[a] = idx;
en[idx] = b;
}
void dp(int u, int fa)
{
f[u] = 1;
for(int i = la[u]; i; i = ne[i])
{
int v = en[i];
if(v == fa) continue ;
dp(v, u);
LL t = 1;
for(int j = la[v]; j; j = ne[j])
{
int w = en[j];
if(w == u) continue ;
if(c[w] == c[u]) t = t * (f[w] + 1) % P;
}
if(c[v] == c[u]) t = (t + f[v]) % P;
f[u] = f[u] * t % P;
}
for(int i = la[u]; i; i = ne[i])
{
int v = en[i];
if(v == fa) continue ;
g[c[v]] = g[c[v]] * (f[v] + 1) % P;
}
for(int i = la[u]; i; i = ne[i])
{
int v = en[i];
if(v == fa) continue ;
res = (res + g[c[v]] - 1) % P;
g[c[v]] = 1;
}
}
int main()
{
scanf("%d", &n);
for(int i = 1; i <= n; i ++ ) scanf("%d", &c[i]);
for(int i = 1; i < n; i ++ )
{
int a, b;
scanf("%d%d", &a, &b);
add(a, b), add(b, a);
}
add(0, 1);
for(int i = 1; i <= n; i ++ ) g[i] = 1;
dp(0, 0);
printf("%lld\n", res);
return 0;
}