题解 P7727 【风暴之眼(Eye of the Storm)】

· · 题解

树上计数问题一般考虑树形 dp。显然可以按照权值划分出一个森林。

我们考虑一些特殊情况:

有了这些结论我们就可以 dp 了,记 dp_{i,0/1/2/3} 分别表示当前点选四种类型的节点的合法方案,那么联通块内直接根据结论 1,3 转移,联通块之间根据结论 1,2 转移。

这样可以通过此题……了吗?

考虑一种特殊情况:0 的联通块内全部是 and1,周围相邻节点的都是 or1。这种情况显然不符合题意,因为不会有 0 传递进入这个联通块。对于 or0 是同理的。

但我们刚才并没有去掉这种情况。我们考虑额外进行一个 dp。记 s_i 表示当前点所在联通块出现上述不合法方案的数量,那么在联通块内乘法合并,联通块之间转移 dp 数组时减掉儿子的统计的不合法方案,将父亲的不合法方案乘上儿子选 and1/or0 的方案即可。

时间复杂度 O(n)

#include<iostream>
#include<cstdio>
using namespace std;
#define int long long
struct edge
{
    int nxt,to;
}e[200001<<1];
const int mod=998244353;
int n,tot,h[200001<<1],a[200001],dp[200001][4],s[200001];
//0 0and 1 1and 2 0or 3 1or
inline void add(int x,int y)
{
    e[++tot].nxt=h[x];
    h[x]=tot;
    e[tot].to=y;
}
void dfs(int k,int f)
{
    dp[k][0]=dp[k][1]=dp[k][2]=dp[k][3]=1;
    if(a[k])
        dp[k][0]=0;
    else
        dp[k][3]=0;
    s[k]=1;
    for(register int i=h[k];i;i=e[i].nxt)
    {
        if(e[i].to==f)
            continue;
        dfs(e[i].to,k);
        if(!a[k]&&!a[e[i].to])
        {
            s[k]=s[k]*s[e[i].to]%mod;
            dp[k][0]=(dp[k][0]*(dp[e[i].to][0]+dp[e[i].to][1]+dp[e[i].to][2])%mod)%mod;
            dp[k][1]=(dp[k][1]*(dp[e[i].to][0]+dp[e[i].to][1])%mod)%mod;
            dp[k][2]=(dp[k][2]*(dp[e[i].to][0]+dp[e[i].to][2])%mod)%mod;
        }
        if(a[k]&&a[e[i].to])
        {
            s[k]=s[k]*s[e[i].to]%mod;
            dp[k][1]=(dp[k][1]*(dp[e[i].to][1]+dp[e[i].to][3])%mod)%mod;
            dp[k][2]=(dp[k][2]*(dp[e[i].to][2]+dp[e[i].to][3])%mod)%mod;
            dp[k][3]=(dp[k][3]*(dp[e[i].to][1]+dp[e[i].to][2]+dp[e[i].to][3])%mod)%mod;
        }
        if(a[k]&&!a[e[i].to])
        {
            s[k]=s[k]*dp[e[i].to][0]%mod;
            dp[k][0]=dp[k][1]=0;
            dp[k][2]=(dp[k][2]*(dp[e[i].to][0]+dp[e[i].to][1])%mod)%mod;
            dp[k][3]=(dp[k][3]*(dp[e[i].to][0]+dp[e[i].to][1]+mod-s[e[i].to])%mod)%mod;
        }
        if(!a[k]&&a[e[i].to])
        {
            s[k]=s[k]*dp[e[i].to][3]%mod;
            dp[k][2]=dp[k][3]=0;
            dp[k][0]=(dp[k][0]*(dp[e[i].to][2]+dp[e[i].to][3]+mod-s[e[i].to])%mod)%mod;
            dp[k][1]=(dp[k][1]*(dp[e[i].to][2]+dp[e[i].to][3])%mod)%mod;
        }
    }
}
signed main()
{
    scanf("%lld",&n);
    for(register int i=1;i<=n;++i)
    {
        scanf("%lld",&a[i]);
    }
    for(register int i=1;i<n;++i)
    {
        int x,y;
        scanf("%lld%lld",&x,&y);
        add(x,y);
        add(y,x);
    }
    dfs(1,0);
    printf("%lld\n",(dp[1][0]+dp[1][1]+dp[1][2]+dp[1][3]-s[1]+mod*114514)%mod);
    return 0;
}