P7845 「dWoi R2」Change

· · 题解

\王队/\王队/\王队/

Change 官方题解

一个微不足道的贪心:选尽可能靠前的边永远是最好的。题解中我们称对于点 u,那么设删掉它能让 s 无法到达 u 的边为支配边。

考虑求出 DAG 上的支配树(空间为 n\log n)。我们发现所有能支配一个点的边必然是支配树上的自己或者祖先中所有入度为 1 的点的入边,于是树形 DP 一下即可。可以过 60 分。

实际上并不需要用 O(m\log ) 建支配树。我们在拓扑排序的时候,考虑维护四个东西:t_e 表示距离 e 最远的支配边(没有即 0),g_u 表示距离 u 最远的支配边(没有即 0)。

考虑对于 e=u\to v,我们做如下转移:

注意一开始的时候要删掉那些 s 无法到达的点,否则无法进行拓扑排序。

#include<bits/stdc++.h>
using namespace std;
#define rep(i,a,b) for(int i=(a);i<=(b);i++)
#define per(i,a,b) for(int i=(a);>=(b);i--)
const int N=6e6+9,mod=998244353;
typedef pair<int,int> pii;

inline long long read() {
    int x=0,f=1; char c=getchar();
    while(!isdigit(c)) {if(c=='-') f=-1; c=getchar();}
    while(c>='0'&&c<='9') {x=x*10+c-48; c=getchar();}
    return x*f;
}

struct edge {int to,nxt;} e[N]; int hd[N],tot=1;
void add(int u,int v) {e[++tot]=(edge){v,hd[u]};hd[u]=tot;}

int n,m,s,p,deg[N],out[N],t[N],g[N],d[N],f[N],res[N];
long long ans;
bool in[N],vst[N];
stack<int>st;

long long ksm(long long x,int y,long long ret=1) {
    while(y) {
        if(y%2) ret=ret*x%mod;
        x=x*x%mod, y>>=1;
    } return ret;
}

void topo() {
    queue<int>q; q.push(s);
    memset(g,-1,sizeof(g)); g[s]=0;
    memset(f,0x3f,sizeof(f)); f[s]=0;
    while(!q.empty()) {
        int u=q.front(); q.pop();
        for(int i=hd[u];i;i=e[i].nxt) {
            int v=e[i].to;
            f[v]=min(f[v],f[u]+1);
            d[i]=f[v];
            if(g[u]==0) t[i]=i;
            else t[i]=g[u];
            if(g[v]==-1) g[v]=t[i];
            else if(g[v]!=t[i]) g[v]=0, res[v]=0, t[i]=0;
            if(t[i]) res[v]=p-d[t[i]];
            if((--deg[v])==0) q.push(v);
        }
    }
}

void cfs(int u) {
    vst[u]=1;
    for(int i=hd[u],v;i;i=e[i].nxt) {
        if(!vst[v=e[i].to]) cfs(v=e[i].to);
    }
}

int main() {
    n=read(), m=read(), s=read(), p=read();
    rep(i,1,m) {
        int u=read(), v=read();
        add(u,v);
    }
    cfs(s);
    rep(u,1,n) if(vst[u]) {
        for(int i=hd[u];i;i=e[i].nxt) deg[e[i].to]++;
    }
    topo();
    int kk=ksm(n,mod-2);
    rep(i,1,n) ans=(ans+1ll*kk*res[i])%mod;
    printf("%lld\n",ans);
    return 0;
}