P7422 题解

· · 题解

P7422 题解

首先观察题目中的限制,不难发现 k 只有 20 ,于是可以对于每个节点枚举 k 计算,考虑如何计算每个节点的答案。

观察两个要求,对于第一个,供应材料不同的限制比较好满足,城市 A 在城市 B_i 到首都的必经之路上,这个可以转化为, A 是割点并且 A 将首都和 B_i 分在了两个联通块中。

第二个要求, B 之间的材料相同比较好满足,对于 B 之间互不影响,不难发现是在去掉 A 后在不同的联通块中,两条限制都与割点有关系,于是建出圆方树。

显然只需要对于圆点统计答案,设点 u 的材料种类为 c_u ,那么在当前点 u ,需要枚举所有与 c_u 不同的颜色计算贡献,枚举完颜色后,可以用 map 启发式合并统计答案,除了枚举颜色之外的所有操作时间复杂度问题都不是很大,瓶颈在于枚举颜色,考虑优化这一点,因为要合并子树信息,使用线段树合并,不难发现只有在合并到叶子节点的时候才会要用 dp 修正该种材料的贡献,考场上想到这里开了一下数组,由于并没有仔细计算所以开的很大,然后就开了两个G的数组,事实上 std 最开始公开时的版本也 MLE 了。

当然把数组开小一点就可以解决了,不过貌似需要精打细算一下,而且赛时由于我开炸的太多了所以没有想优化,因此我使用了一个空间更加安全的做法。

由于对于每个点进行 dp 的时间复杂度正确,所以这一点保持不变,因为每种材料的数量加起来是 n ,考虑对于每种材料单独计算贡献,建出虚树,在每个与该材料种类不同的圆点处直接 dp ,对于虚树上一条边实际上对应原树上的一条链,这一条链上每个圆点的贡献都是该虚树的子树中的这种材料的个数,直接乘上去就行。

#include<bits/stdc++.h>
using namespace std;
const int N=1e6+10;
const int p=998244353;
#define rint register int
#define ll long long
#define rll register long long
template<typename T> void read(T &x){
    x=0;rint f=1;char ch=getchar();
    while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
    while(ch<='9'&&ch>='0'){x=x*10+ch-'0';ch=getchar();}
    x*=f;
}
vector<int> g[N];
struct Edge{
    int to,nxt;
}e[N<<1];
int h[N],idx;
void Ins(rint a,rint b){
    e[++idx].to=b;e[idx].nxt=h[a];h[a]=idx;
}
int n,m,k,fang,c[N];
map<int,int> mp;
int low[N],dfn[N],stk[N],tp,Time;
void tarjan(rint u,rint fa){
    dfn[u]=low[u]=++Time;
    stk[++tp]=u;
    for(rint v:g[u]){
        if(v==fa)continue;
        if(!dfn[v]){
            tarjan(v,u);
            low[u]=min(low[u],low[v]);
            if(low[v]>=dfn[u]){
                ++fang;
                while(1){
                    rint x=stk[tp--];
                    Ins(fang,x);Ins(x,fang);
                    if(x==v)break;
                }
                Ins(fang,u);Ins(u,fang);
            }
        }else low[u]=min(low[u],dfn[v]);
    }
}
int dp[30];
ll res;
int top[N],siz[N],fa[N],son[N],d[N],dis[N];
void dfs1(rint u){
    dfn[u]=++Time;siz[u]=1;dis[u]+=u<=n;
    for(rint i=h[u];i;i=e[i].nxt){
        rint v=e[i].to;
        if(v==fa[u])continue;
        d[v]=d[u]+1;
        fa[v]=u;dis[v]=dis[u];
        dfs1(v);
        siz[u]+=siz[v];
        if(siz[v]>siz[son[u]])son[u]=v;
    }
}
void dfs2(rint u,rint t){
    top[u]=t;
    if(son[u])dfs2(son[u],t);
    for(rint i=h[u];i;i=e[i].nxt){
        rint v=e[i].to;
        if(v==fa[u]||v==son[u])continue;
        dfs2(v,v);
    }
    h[u]=0;
}
int lca(rint x,rint y){
    while(top[x]^top[y]){
        if(d[top[x]]<d[top[y]])y=fa[top[y]];
        else x=fa[top[x]];
    }
    return d[x]>d[y]?y:x;
}
int colcnt,col;
bool cmp(rint a,rint b){
    return dfn[a]<dfn[b];
}
void insert(rint x){
    if(!tp){
        stk[++tp]=x;
        return ;
    }
    rint lc=lca(stk[tp],x);
    while(tp>1&&d[stk[tp-1]]>=d[lc])
        Ins(stk[tp-1],stk[tp]),tp--;
    if(stk[tp]!=lc)
        Ins(lc,stk[tp]),stk[tp]=lc;
    stk[++tp]=x;
}
void dfs3(rint u){
    siz[u]=0;
    for(rint i=h[u];i;i=e[i].nxt){
        rint v=e[i].to;
        dfs3(v);
        siz[u]+=siz[v];
        res+=1ll*siz[v]*(dis[v]-dis[u]-(v<=n))%p;
    }
    if(c[u]==col)siz[u]++;
    else if(u<=n){
        dp[0]=1;
        for(rint i=1;i<=k;i++)
            dp[i]=0;
        for(rint i=h[u];i;i=e[i].nxt){
            rint v=e[i].to;
            for(rint j=k;j;j--)
                dp[j]=(dp[j]+1ll*dp[j-1]*siz[v])%p;
        }
        for(rint i=1;i<=k;i++)
            res+=dp[i];
    }
    h[u]=0;
}
void solve(){
    tp=idx=0;
    sort(g[col].begin(),g[col].end(),cmp);
    if(g[col][0]!=1)stk[++tp]=1;
    for(rint x:g[col])
        insert(x);
    while(tp>1)
        Ins(stk[tp-1],stk[tp]),tp--;
    dfs3(1);
}
int main(){
    read(n);read(m);read(k);
    for(rint i=1;i<=n;i++){
        read(c[i]);
        if(!mp[c[i]])mp[c[i]]=++colcnt;
        c[i]=mp[c[i]];
    }
    while(m--){
        rint a,b;
        read(a);read(b);
        g[a].push_back(b);g[b].push_back(a);
    }
    fang=n;
    tarjan(1,0);
    Time=0;
    for(rint i=1;i<=n;i++)g[i].clear();
    for(rint i=1;i<=n;i++)g[c[i]].push_back(i);
    dfs1(1);
    dfs2(1,1);
    for(col=1;col<=colcnt;col++)
        solve();
    printf("%lld\n",res%p);
    return 0;
}