[NOI2023] 深搜

· · 题解

很 CNOI 的一道题。

首先考虑暴力怎么做。

题目要求存在某个关键点 s 的方案数,容易想到容斥,转成钦定一个集合内的关键点,使得以这个集合内的任意一个点为根都能满足条件。

可以发现,TG 的 dfs 树的充要条件是额外边中不存在横叉边。

也就是说要求出来选边的方案数,使得任意一条边都不是集合内任何点为根时的横叉边。

显然,边之间的贡献是独立的。

如果有 k 条边满足其不是集合内任何点为根时的横叉边,则方案数为 2^k

于是得到了 O(2^kn) 的做法。

把这个过程放到树形 dp 上。

不难发现,对于一个子树是否合法,我们只关注这个子树内所有关键点的最近公共祖先。

于是设置状态:f_{x,k} 表示 x 的子树内,所有被选了的关键点 lca 为 k,方案数和容斥系数的乘积之和。

转移的时候分讨一下就能解决 B 性质。

套个线段树合并即可获得 72pts。

现在还剩横叉边无法处理。

在横叉边的 lca 处统计贡献。

可以发现,只要选了横叉边,则不在横叉边的端点的子树内的点中一定没有被选的关键点。

于是枚举横叉边两个端点的子树内关键点的 lca,即可做到 O(n^3)

可以发现要统计的东西类似一个扫描线的形式,套个扫描线即可获得 100pts,复杂度 O(n\log n)

#include <bits/stdc++.h>
using namespace std;
#define int long long
const int N=5e5+5,mods=1e9+7;
namespace tr{
    struct node{
        signed lson,rson,sm,laz;
    }p[N*50];
    int idx;
    void mul(int x,int sm){
        p[x].sm=1ll*p[x].sm*sm%mods;
        p[x].laz=1ll*p[x].laz*sm%mods;
    }
    void dnset(int x){
        if(p[x].laz!=1){
            mul(p[x].lson,p[x].laz);
            mul(p[x].rson,p[x].laz);
            p[x].laz=1;
        }
    }
    void init(int x){
        if(!p[x].lson)p[x].lson=++idx,p[idx].laz=1;
        if(!p[x].rson)p[x].rson=++idx,p[idx].laz=1;
    }
    void upset(int x){
        p[x].sm=(p[p[x].lson].sm+p[p[x].rson].sm)%mods;
    }
    void mul(int x,int l,int r,int sm,int nl,int nr){
        if(l<=nl&&r>=nr){
            mul(x,sm);
            return;
        }
        int mid=nl+nr>>1;
        init(x);
        dnset(x);
        if(l<=mid)mul(p[x].lson,l,r,sm,nl,mid);
        if(r>mid)mul(p[x].rson,l,r,sm,mid+1,nr);
        upset(x);
    }
    void add(int x,int d,int sm,int nl,int nr){
        if(nl==nr){
            p[x].sm=(p[x].sm+sm)%mods;
            return;
        }
        init(x);
        dnset(x);
        int mid=nl+nr>>1;
        if(d<=mid)add(p[x].lson,d,sm,nl,mid);
        else add(p[x].rson,d,sm,mid+1,nr);
        upset(x);
    }
    int gets(int x,int l,int r,int nl,int nr){
        if(!x)return 0;
        if(l<=nl&&r>=nr)return p[x].sm;
        int mid=nl+nr>>1;
        dnset(x);
        if(r<=mid)return gets(p[x].lson,l,r,nl,mid);
        if(l>mid)return gets(p[x].rson,l,r,mid+1,nr);
        return (gets(p[x].lson,l,r,nl,mid)+gets(p[x].rson,l,r,mid+1,nr))%mods;
    }
    int hb(int a,int b,int nl,int nr){
        if(!a||!b)return a|b;
        if(nl==nr){
            p[a].sm=(p[a].sm+p[b].sm)%mods;
            return a;
        }
        int mid=nl+nr>>1;
        dnset(a);dnset(b);
        p[a].lson=hb(p[a].lson,p[b].lson,nl,mid);
        p[a].rson=hb(p[a].rson,p[b].rson,mid+1,nr);
        upset(a);
        return a;
    }
}
int pows(int a,int b){
    if(b==0)return 1;
    int res=pows(a,b>>1);
    res=res*res%mods;
    if(b&1)res=res*a%mods;
    return res;
}
int inv2=mods+1>>1,op,n,m,k,rt[N],dfn[N],dy[N],js[N],f1[N],f2[N],sl[N],ff[N],sz[N],res,mk[N],cf[N],pw2[N],inv[N],dp[N],eds[N],dep[N],fa[N][20],idx,bk[N];
vector<int>p[N],g[N],gs[N];
vector<pair<int,int> >jl[N];
map<pair<int,int>,vector<pair<int,int> > >q[N];
void dfs(int x){
    for(int i=1;i<=19;i++)fa[x][i]=fa[fa[x][i-1]][i-1];
    mk[x]=1;
    dfn[x]=++idx;
    dy[idx]=x;
    for(auto c:p[x]){
        if(mk[c])continue;
        dep[c]=dep[x]+1;
        fa[c][0]=x;
        dfs(c);
    }
    mk[x]=0;
    eds[x]=idx;
}
int up(int x,int k){
    while(k){
        x=fa[x][__lg(k&-k)];
        k^=k&-k;
    }
    return x;
}
int lca(int a,int b){
    if(dep[a]>dep[b])swap(a,b);
    b=up(b,dep[b]-dep[a]);
    if(a==b)return a;
    for(int i=19;i>=0;i--){
        if(fa[a][i]!=fa[b][i])a=fa[a][i],b=fa[b][i];
    }
    return fa[a][0];
}
bool in(int a,int b){
    return a>=dfn[b]&&a<=eds[b];
}
struct msg{
    int x,op,l,r;
};
void solve(int x){
    rt[x]=++tr::idx;
    mk[x]=1;
    dp[x]=1;
    if(bk[x])tr::add(rt[x],dfn[x],-1,1,n);
    sz[x]=0;
    for(auto c:p[x]){
        if(mk[c])continue;
        solve(c);
        dp[x]=dp[x]*dp[c]%mods;
        sz[x]+=sz[c]+g[c].size();
    }
    for(auto [t,c]:q[x]){
        int a=t.first,b=t.second,ans=-tr::gets(rt[a],1,n,1,n)*tr::gets(rt[b],1,n,1,n)%mods;
        vector<msg>jl;
        for(auto [s1,s2]:c){
            jl.push_back({dfn[s1],0,dfn[s2],eds[s2]});
            jl.push_back({eds[s1]+1,1,dfn[s2],eds[s2]});
        }
        jl.push_back({n+1,2});
        sort(jl.begin(),jl.end(),[&](msg a,msg b){
            return a.x<b.x;
        });
        int lst=0;
        for(auto [x,op,l,r]:jl){
            if(lst<x)ans+=tr::gets(rt[a],lst,x-1,1,n)*tr::gets(rt[b],1,n,1,n)%mods;
            lst=x;
            if(op==0)tr::mul(rt[b],l,r,2,1,n);
            if(op==1)tr::mul(rt[b],l,r,inv2,1,n);
        }
        ans%=mods;
        res+=ans*pw2[cf[dfn[x]]-sz[x]]%mods*dp[x]%mods*inv[a]%mods*inv[b]%mods;
    }
    dp[x]=1;
    for(auto c:p[x]){
        if(mk[c])continue;
        int he=tr::gets(rt[x],1,n,1,n);
        int tmp=he*tr::gets(rt[c],1,n,1,n)%mods;
        tr::mul(rt[c],dp[x]);
        tr::mul(rt[x],dp[c]);
        rt[x]=tr::hb(rt[x],rt[c],1,n);
        tr::add(rt[x],dfn[x],tmp,1,n);
        dp[x]=dp[x]*dp[c]%mods;
    }
    res+=tr::gets(rt[x],dfn[x],dfn[x],1,n)*pw2[cf[dfn[x]]-sz[x]]%mods;
    res%=mods;
    for(auto c:g[x]){
        tr::mul(rt[x],dfn[c],eds[c],2,1,n);
        dp[x]=dp[x]*2%mods;
        if(c==x)js[x]++;
    }
    inv[x]=pows(dp[x],mods-2);
    mk[x]=0;
}
signed main(){
    ios::sync_with_stdio(0);
    cin.tie(0);cout.tie(0);
    cin>>op>>n>>m>>k;
    pw2[0]=1;
    for(int i=1;i<=m;i++)pw2[i]=pw2[i-1]*2%mods;
    for(int i=1;i<n;i++){
        int x,y;
        cin>>x>>y;
        p[x].push_back(y);
        p[y].push_back(x);
    }
    dfs(1);
    for(int i=1;i<=m;i++){
        int a,b;
        cin>>a>>b;
        if(dfn[a]>dfn[b])swap(a,b);
        if(dfn[b]>=dfn[a]&&dfn[b]<=eds[a]){
            g[up(b,dep[b]-dep[a]-1)].push_back(b);
            cf[dfn[b]]++;
            cf[eds[b]+1]--;
            cf[1]++;f1[i]=a;f2[i]=b;
            int tmp=up(b,dep[b]-dep[a]-1);
            cf[dfn[tmp]]--;
            cf[eds[tmp]+1]++;
            ff[dfn[b]]++;ff[eds[b]+1]--;
        }else{
            int c=lca(a,b);
            jl[c].push_back({a,b});
            q[c][{up(a,dep[a]-dep[c]-1),up(b,dep[b]-dep[c]-1)}].push_back({a,b});
            gs[up(a,dep[a]-dep[c]-1)].push_back(a);
            gs[up(b,dep[b]-dep[c]-1)].push_back(b);
            cf[dfn[a]]++;
            cf[eds[a]+1]--;
            cf[dfn[b]]++;
            cf[eds[b]+1]--;
        }
    }
    for(int i=1;i<=n;i++)cf[i]+=cf[i-1],ff[i]+=ff[i-1];
    for(int i=1;i<=k;i++){
        int x;
        cin>>x;
        bk[x]=1;
    }
    solve(1);
    cout<<(-res%mods+mods)%mods<<"\n";
}