P6651 「SWTR-5」Chain

· · 题解

前言

其实感觉题目描述不太对,应该改为禁用一个点(因为新的链不算)。

正言

不难发现一个最简单的 dp。即 f_i 表示从入度为 0 的点走到 i 的方案数,有转移 f_v=\displaystyle\sum_{u\rightarrow v} f_u,答案即为 \displaystyle\sum_{du_i=0} f_i

考虑我们如何删点,其实就是删除经过 i 点的所有链。不妨记 g_i 表示从 i 到出度为 0 的点的方案。则对于每个点我们应该减去 f_i\times g_i

可是这样计算显然有重复,一般去重我们考虑容斥,但是显然不能正常容斥。所以考虑如何在树上快速的进行容斥。

考虑对于每个点去重,而我们先用拓扑序钦定一个顺序,使得后面的点不能走前面走过的路径。所以对于在 j 前面的点 i,我们应该在 f_i 中做减法而非 g_i。而不难想到再记录一个 h_{i,j} 表示链经过 ij 的方案数。则显然 f_j 应该减少 f_i\times h_{i,j}

注意我们并不是用原来的 f_i 去更新 f_j,而是用被更新过的 f_i 去更新。道理也很简单:因为要枚举所有 i<j,如果用原来的 f_i,那显然会重复减去。

#include <bits/stdc++.h>
using namespace std;
#define int long long
#define pb push_back
#define eb emplace_back
#define all(x) (x).begin(),(x).end()
#define fi first
#define se second
#define SZ(x) ((int)(x).size())
typedef vector<int> VI;
typedef long long ll;
typedef pair<int,int> PII;
const ll MOD=1e9+7;
// head
const int N=2e3+5;
vector<vector<int>> G(N),G1(N);
int du[N],du1[N];
int f[N],g[N],h[N][N];
int id[N],cnt;
int c[N],d[N];
bool cmp(int p,int q) {return id[p]<id[q];}
signed main() 
{
    cin.tie(nullptr);cout.tie(nullptr);
    ios::sync_with_stdio(false);

    int n,m;cin>>n>>m;
    for(int i=1;i<=m;i++){
        int u,v;cin>>u>>v;
        G[u].pb(v);du[v]++;
        G1[v].pb(u);du1[u]++;
    }
    queue<int> Q;
    for(int i=1;i<=n;i++) h[i][i]=1;
    for(int i=1;i<=n;i++) if(!du[i]) {Q.push(i);f[i]=1;}
    while(!Q.empty())
    {
        int u=Q.front();
        id[u]=++cnt;
        Q.pop();
        for(auto v:G[u]){
            du[v]--;
            (f[v]+=f[u])%=MOD;
            for(int i=1;i<=n;i++) (h[i][v]+=h[i][u])%=MOD;
            if(!du[v]) Q.push(v);
        }
    }
    int sum=0;
    for(int i=1;i<=n;i++) if(!du1[i]) {Q.push(i);g[i]=1,(sum+=f[i])%=MOD;}
    while(!Q.empty())
    {
        int u=Q.front();
        Q.pop();
        for(auto v:G1[u]){
            du1[v]--;
            (g[v]+=g[u])%=MOD;
            if(!du1[v]) Q.push(v);
        }
    }
    int q;cin>>q;
    while(q--)
    {
        int k;cin>>k;
        for(int i=1;i<=k;i++) cin>>c[i];
        sort(c+1,c+k+1,cmp);
        int ans=sum;
        for(int i=1;i<=k;i++) d[i]=f[c[i]];
        for(int i=1;i<=k;i++){
            for(int j=1;j<i;j++){
                (d[i]-=d[j]*h[c[j]][c[i]]%MOD-MOD)%=MOD;
            }
            (ans-=d[i]*g[c[i]]%MOD-MOD)%=MOD;
        }
        cout<<ans<<endl;
    }
}