题解:P10894 虚树

· · 题解

好题啊(赞赏)

然后自信提交,获得了 30 分的好成绩。

dp_{i,0}i 方案中不包含 ii 的子树的方案总数;dp_{i,1} 为包含 i 的方案,i 的子树的方案总数。

显然,如果要跨子树选择,必须选 i,因为 \text{LCA}(u,v) 肯定有 i 存在,再加上 i 自身,一共是 \prod dp_{v,0}+dp_{v,1}+1.

若不选 i,则继承了任意一颗子树里的方案。总方案数为 \sum dp_{v,0}+dp_{v,1}

转移都是 O(n) 的。

然后自信提交,获得了 30 分的好成绩。

这里有一个小优化,显然一次改变 i 只影响 i1 上的所有路径,删去即可,注意除法要有理数取模。

然后自信提交,获得了 30 分的好成绩。

复杂度仍是 O(nm),由于数据造的太好,没有骗到任何分。

我们考虑树上前缀积算子树贡献系数 g_i,显然这玩意可以预处理。

于是答案为 dp_{1,0}+dp_{1,1}-(dp_{pos,0}+dp_{pos,1})\times g_{pos}

复杂度为 O(n\log mod)

#include<bits/stdc++.h>
#define int long long
#define endl "\n"
using namespace std;
int siz[500001] = {0},n,m,dp[500001][2] = {0},db = 0;
int fdp[500001][2] = {0};
vector<int>e[500001];
int vis[500001] = {0};
const int mod = 998244353;
int fa[500001] = {0};
int g[500001] = {0};
int qpow(int x,int y){
    int ret = 1; 
    while(y){
        if(y&1){
            ret*=x;
            ret%=mod;
        }
        x*=x;
        x%=mod;
        y>>=1;
    }
    return ret;
}
void dfs(int pos){
    vis[pos] = 1;
    int dyc = 1;
    for(int i = 0;i<e[pos].size();i++){
        if(!vis[e[pos][i]]){
            fa[e[pos][i]] = pos;
            //cout<<pos<<" "<<e[pos][i]<<endl;
            dfs(e[pos][i]);
            dp[pos][0]+=dp[e[pos][i]][0]+dp[e[pos][i]][1];
            dp[pos][0]%=mod;
            //cout<<dp[e[pos][i]][0]+dp[e[pos][i]][1]+1<<endl;
            dyc*=(dp[e[pos][i]][0]+dp[e[pos][i]][1]+1)%mod;
            dyc%=mod;
        }
    }
    dp[pos][1]+=dyc;
    dp[pos][1]%=mod;
    for(int i = 0;i<e[pos].size();i++){
        if(e[pos][i]!=fa[pos]){
            g[e[pos][i]] = dp[pos][1]*qpow(dp[e[pos][i]][1]+dp[e[pos][i]][0]+1,mod-2)%mod+1;
        }
    }
    return;
}
void dfs2(int pos){
    if(fa[pos]!=1){
        g[pos]*=g[fa[pos]];
        g[pos]%=mod;
    }
    for(int i = 0;i<e[pos].size();i++){
        if(e[pos][i]!=fa[pos]){
            dfs2(e[pos][i]);
        }
    }
    return;
}
signed main(){
    cin>>n;
    for(int i = 1;i<n;i++){
        int u,v;
        cin>>u>>v;
        e[u].push_back(v);
        e[v].push_back(u);
    }
    for(int i = 1;i<=n;i++){
        dp[i][0] = 0;
        dp[i][1] = 0;
    }
    dfs(1);
    dfs2(1);
//  for(int i = 1;i<=n;i++){
//      cout<<g[i]<<" ";
//  }
//  cout<<endl;
    cin>>m;
    for(int i = 1;i<=m;i++){
        int pos;
        cin>>pos;
        cout<<(dp[1][0]+dp[1][1]+mod-g[pos]*(dp[pos][0]+dp[pos][1])%mod)%mod<<endl;;
    }
    return 0;
}