笨蛋花的小窝qwq

笨蛋花的小窝qwq

“我要再和生活死磕几年。要么我就毁灭,要么我就注定铸就辉煌。如果有一天,你发现我在平庸面前低了头,请向我开炮。”

「简单 DP」 CF960E Alternating Tree

posted on 2020-06-09 08:11:55 | under 题解 |

秒了。指想出来但写了一年

然后一眼就是点分…不过好像比较麻烦的样子。

或者考虑直接树形 $dp$ 。就只需要统计出每个点在多少条路径上是奇数位置,多少条路径上是偶数位置。这个可以用比较 trivial 的 up_and_down 技巧来做。具体的,考虑对于奇偶位置分别维护三个信息,$f,g,h$ 分别表示「起点在子树外、终点在子树内」、「起点在子树内,终点在子树外」、「起点在子树内,终点在子树内」三种信息,其中第三种依赖于第二种,第一种需要第二种进行 up_and_down 。这样就可以结合子树 $size$ 直接算出贡献来了。

有小细节需要注意。就是 $f_x,g_x$ 在处理奇数的时候需要判一下 $x$ 的边界。

#include <cstdio>

#include <vector>

typedef long long ll ;

typedef std :: vector <int> vint ;

#define il inline

#define p_b push_back

const int N = 200010 ;

const int P = 1000000007 ;

using namespace std ;

ll ans ;

int n ;
ll f[N][3] ;
ll g[N][3] ;
int size[N] ;
int base[N] ;

vint E[N] ;

template <typename T>
il void add(T &x, ll y, int mod = P){
    x += y ; x = x >= mod ? x - mod : x ;
}
template <typename T>
il void dec(T &x, ll y, int mod = P){
    x -= y ; x = x < 0 ? x + mod : x ;
}

void dfs(int x, int fa){
    f[x][0] = size[x] = 1 ;
    for (auto y : E[x]){
        if (y == fa) continue ;
        dfs(y, x) ; size[x] += size[y] ;
        f[x][0] += g[y][0] ; g[x][0] += f[y][0] ;
    }
    for (auto y : E[x]){
        if (y == fa) continue ;
        add(f[x][2], (size[x] - size[y] - 1) * g[y][0] % P) ;
        add(g[x][2], (size[x] - size[y] - 1) * f[y][0] % P) ;
    }
}
void dfs2(int x, int fa){
    for (auto y : E[x]){
        if (y == fa) continue ;
        f[y][1] = g[x][1] ;
        g[y][1] = f[x][1] ;
        f[y][1] += (g[x][0] - f[y][0]) ;
        g[y][1] += (f[x][0] - g[y][0]) ;
        dfs2(y, x) ;
    }
}
void solve(){
    for (int x = 1 ; x <= n ; ++ x){
        add(ans, f[x][2] * base[x] % P) ;
        dec(ans, g[x][2] * base[x] % P) ;
        dec(ans, base[x]) ;
//        printf("%lld\n", ans) ;
        add(ans, (f[x][1] + 1) * size[x] % P * base[x] % P) ;
//        printf("%lld\n", ans) ;
        add(ans, f[x][0] * (n - size[x] + 1) % P * base[x] % P) ;
//        printf("%lld\n", ans) ;
        dec(ans, g[x][1] * size[x] % P * base[x] % P) ;
//        printf("%lld\n", ans) ;
        dec(ans, g[x][0] * (n - size[x] + 1) % P * base[x] % P) ;
//        printf("%lld\n", ans) ; puts("-----------") ;
    }
}
int main(){
    scanf("%d", &n) ; int x, y ;
    for (int i = 1 ; i <= n ; ++ i) scanf("%d", &base[i]) ;
    for (int i = 1 ; i < n ; ++ i)
        scanf("%d%d", &x, &y), E[x].p_b(y), E[y].p_b(x) ;
    dfs(1, 0) ; dfs2(1, 0) ; solve() ; printf("%lld\n", ans % P) ;
}