题解 P6773 【[NOI2020]命运】

· · 题解

题意

给定一棵 n 个点的树和 m 条限制,你可以给树上的每一条边赋一个 01 的权值。对于所有限制 (u,v)(保证 vu 的祖先) 你需要保证 uv 上至少有一条边的权值为 1,求赋值方案数。

\texttt{Data Range:}1\leq n,m\leq 5\times 10^5

题解

神仙题,但是大多数题解直接摆结论所以对于我这种还带有 whk 后遗症的菜鸡来说理解起来有点困难,所以写一篇自认为比较清晰的。

首先第一个问题是怎么 DP。

注意到有一个 key observation 是这样的:对于某一个节点 u,如果 (u,v_1) 被满足且 v_2 深度比 v_1 小且也为 u 的祖先,那么限制 (u,v_2) 也能被满足。(其实就是因为 uv_2 的路径包含了 uv_1 的路径)用人话表述一遍的话,就是对于下端点在 u 的所有限制来说,上端点深的能满足的话那么上端点浅的一定也能满足

同时,对于 u 来说,只需要考虑限制的一端在 u 的子树内部的情况。(因为不在子树 u 内部的限制与 u 子树内边的赋值情况无关),再加上刚刚的 key observation 我们就得出了状态表示:f_{u,i} 表示以 u 为根的子树内,下端点在子树内并且没有被满足的限制中上端点的最深深度为 i,对 u 的子树内赋值的方案数。(对于那些没有考虑的边,你可以认为权值默认为 0)其中,f_{u,0} 表示所有限制都被满足的方案数,所以答案为 f_{1,0}

考虑转移,如果将 v 的子树往 u 的子树合并,讨论一下 (u,v) 这条边的权值,如果为 1 那么下端点在 v 子树内上端点在 v 子树外的限制全部都会满足,第二维的 i 必须全部由 u 来贡献,为 0 则意味着两边贡献出来的 f 第二维最大值必须为 i,写成表达式就会有:

f_{u,i}\gets \sum\limits_{j=0}^{dep_u}f_{u,i}f_{v,j}+\sum\limits_{j=0}^{i}f_{u,i}f_{v,j}+\sum\limits_{j=0}^{i-1}f_{u,j}f_{v,i}

第一个和式表示 (u,v) 权值为 1,后面两个表示权值为 0,第三个和式减去上界的原因是 f_{u,i}f_{v,i} 已经算过。

这个时候设 g_{u,i} 表示 \sum\limits_{j=0}^{i}f_{u,i},则转移可以写成这样:

f_{u,i}\gets f_{u,i}(g_{v,dep_u}+g_{v,i})+g_{u,i-1}f_{v,i}

然后我们就有 O(n^2) 的解法了,见此处,所以考虑优化。

这个时候考虑线段树合并。合并的过程中借用一点 cdq 分治的思路,也就是说先合并左子树顺便更新 g,然后再合并右子树。然后因为每次转移是形如一个乘某个值再加某个值的情况,所以需要写一个乘法标记,时间复杂度 O(n\log n)

代码

#include<bits/stdc++.h>
using namespace std;
typedef int ll;
typedef long long int li;
const ll MAXN=5e5+51,MOD=998244353;
ll n,m,u,v,totn;
ll depth[MAXN],ls[MAXN<<5],rs[MAXN<<5],sm[MAXN<<5],tag[MAXN<<5],rt[MAXN];
vector<ll>vg[MAXN],link[MAXN];
inline ll read()
{
    register ll num=0,neg=1;
    register char ch=getchar();
    while(!isdigit(ch)&&ch!='-')
    {
        ch=getchar();
    }
    if(ch=='-')
    {
        neg=-1;
        ch=getchar();
    }
    while(isdigit(ch))
    {
        num=(num<<3)+(num<<1)+(ch-'0');
        ch=getchar();
    }
    return num*neg;
}
inline void update(ll node)
{
    sm[node]=(sm[ls[node]]+sm[rs[node]])%MOD;
}
inline void spread(ll node)
{
    if(tag[node]!=1)
    {
        sm[ls[node]]=(li)tag[node]*sm[ls[node]]%MOD;
        sm[rs[node]]=(li)tag[node]*sm[rs[node]]%MOD;
        tag[ls[node]]=(li)tag[node]*tag[ls[node]]%MOD;
        tag[rs[node]]=(li)tag[node]*tag[rs[node]]%MOD;
        tag[node]=1;
    }
}
inline void change(ll l,ll r,ll pos,ll val,ll &node)
{
    !node?tag[node=++totn]=1:1;
    if(l==r)
    {
        return (void)(tag[node]=1,sm[node]=val);
    }
    ll mid=(l+r)>>1;
    spread(node);
    pos<=mid?change(l,mid,pos,val,ls[node]):change(mid+1,r,pos,val,rs[node]);
    update(node);
}
inline ll query(ll l,ll r,ll ql,ll qr,ll node)
{
    if(ql<=l&&qr>=r)
    {
        return sm[node];
    }
    ll mid=(l+r)>>1,res=0;
    spread(node);
    res=(res+(ql<=mid?query(l,mid,ql,qr,ls[node]):0))%MOD;
    res=(res+(qr>mid?query(mid+1,r,ql,qr,rs[node]):0))%MOD;
    return res;
}
inline ll merge(ll x,ll y,ll l,ll r,ll &su,ll &sv)
{
    if(!x&&!y)
    {
        return 0;
    }
    if(!x)
    {
        sv=(sv+sm[y])%MOD,tag[y]=(li)tag[y]*su%MOD;
        return sm[y]=(li)sm[y]*su%MOD,y;
    }
    if(!y)
    {
        su=(su+sm[x])%MOD,tag[x]=(li)tag[x]*sv%MOD;
        return sm[x]=(li)sm[x]*sv%MOD,x;
    }
    if(l==r)
    {
        ll cu=sm[x],cv=sm[y];
        sv=(sv+cv)%MOD,sm[x]=((li)sm[x]*sv+(li)sm[y]*su)%MOD;
        return su=(su+cu)%MOD,x;
    }
    ll mid=(l+r)>>1;
    spread(x),spread(y);
    ls[x]=merge(ls[x],ls[y],l,mid,su,sv);
    rs[x]=merge(rs[x],rs[y],mid+1,r,su,sv);
    update(x);
    return x;
}
inline void dfs(ll node,ll fa)
{
    depth[node]=depth[fa]+1;
    ll mxd=0,su,sv;
    for(register int i:link[node])
    {
        mxd=max(mxd,depth[i]);
    }
    change(0,n,mxd,1,rt[node]);
    for(register int i:vg[node])
    {
        if(i!=fa)
        {
            dfs(i,node),su=0,sv=query(0,n,0,depth[node],rt[i]);
            rt[node]=merge(rt[node],rt[i],0,n,su,sv);
        }
    }
}
int main()
{
    n=read();
    for(register int i=0;i<n-1;i++)
    {
        u=read(),v=read(),vg[u].emplace_back(v),vg[v].emplace_back(u);
    }
    m=read();
    for(register int i=1;i<=m;i++)
    {
        u=read(),v=read(),link[v].push_back(u);
    }
    dfs(1,0),printf("%d\n",query(0,n,0,0,rt[1]));
}