「简单 DP」 CF960E Alternating Tree

皎月半洒花

2020-06-09 08:11:55

Solution

秒了。~~指想出来但写了一年~~ 然后一眼就是点分…不过好像比较麻烦的样子。 或者考虑直接树形 $dp$ 。就只需要统计出每个点在多少条路径上是奇数位置,多少条路径上是偶数位置。这个可以用比较 trivial 的 up_and_down 技巧来做。具体的,考虑对于奇偶位置分别维护三个信息,$f,g,h$ 分别表示「起点在子树外、终点在子树内」、「起点在子树内,终点在子树外」、「起点在子树内,终点在子树内」三种信息,其中第三种依赖于第二种,第一种需要第二种进行 up_and_down 。这样就可以结合子树 $size$ 直接算出贡献来了。 有小细节需要注意。就是 $f_x,g_x$ 在处理奇数的时候需要判一下 $x$ 的边界。 ```cpp #include <cstdio> #include <vector> typedef long long ll ; typedef std :: vector <int> vint ; #define il inline #define p_b push_back const int N = 200010 ; const int P = 1000000007 ; using namespace std ; ll ans ; int n ; ll f[N][3] ; ll g[N][3] ; int size[N] ; int base[N] ; vint E[N] ; template <typename T> il void add(T &x, ll y, int mod = P){ x += y ; x = x >= mod ? x - mod : x ; } template <typename T> il void dec(T &x, ll y, int mod = P){ x -= y ; x = x < 0 ? x + mod : x ; } void dfs(int x, int fa){ f[x][0] = size[x] = 1 ; for (auto y : E[x]){ if (y == fa) continue ; dfs(y, x) ; size[x] += size[y] ; f[x][0] += g[y][0] ; g[x][0] += f[y][0] ; } for (auto y : E[x]){ if (y == fa) continue ; add(f[x][2], (size[x] - size[y] - 1) * g[y][0] % P) ; add(g[x][2], (size[x] - size[y] - 1) * f[y][0] % P) ; } } void dfs2(int x, int fa){ for (auto y : E[x]){ if (y == fa) continue ; f[y][1] = g[x][1] ; g[y][1] = f[x][1] ; f[y][1] += (g[x][0] - f[y][0]) ; g[y][1] += (f[x][0] - g[y][0]) ; dfs2(y, x) ; } } void solve(){ for (int x = 1 ; x <= n ; ++ x){ add(ans, f[x][2] * base[x] % P) ; dec(ans, g[x][2] * base[x] % P) ; dec(ans, base[x]) ; // printf("%lld\n", ans) ; add(ans, (f[x][1] + 1) * size[x] % P * base[x] % P) ; // printf("%lld\n", ans) ; add(ans, f[x][0] * (n - size[x] + 1) % P * base[x] % P) ; // printf("%lld\n", ans) ; dec(ans, g[x][1] * size[x] % P * base[x] % P) ; // printf("%lld\n", ans) ; dec(ans, g[x][0] * (n - size[x] + 1) % P * base[x] % P) ; // printf("%lld\n", ans) ; puts("-----------") ; } } int main(){ scanf("%d", &n) ; int x, y ; for (int i = 1 ; i <= n ; ++ i) scanf("%d", &base[i]) ; for (int i = 1 ; i < n ; ++ i) scanf("%d%d", &x, &y), E[x].p_b(y), E[y].p_b(x) ; dfs(1, 0) ; dfs2(1, 0) ; solve() ; printf("%lld\n", ans % P) ; } ```