P4365 [九省联考2018]秘密袭击coat(插值+线段树合并优化dp)

· · 题解

P4365 [九省联考2018]秘密袭击coat

先把问题转化为,对于每个权值 i,求出权值大于等于 i 的点数 = k 的连通块个数。求和就是答案。 设 f_{u,i,j} 表示 u 号点为连通块根,权值 \ge i 的有 j 个点的方案数,再设 g_{u,i,j}=f_{u,i,j}+\sum g_{v,i,j}。则有

f_{u,i,j}=\begin{cases}\prod_{\sum k=j}f_{v,i,k}&d_u<i\\\prod_{\sum k=j-1}f_{v,i,k}&d_u\ge i\end{cases}

这是一个卷积的形式。把 F,G 看作生成函数,插 n+1 个值代入,转移式即为

F(u,i)=x^{[d_u\ge i]}\prod (F(v,i)+1)\\ G(u,i)=F(u,i)+\sum G(v,i)\\

对应位置相乘、相加,这是一个线段树合并的形式。每次进行的操作为:全局赋值为 1 ,[1,d_u]x ,全局+1,F 对应位置相乘,G 对应位置相加,将 F 数组加到 G 数组上。

考虑一个变换,初始时 b=f,d=g

(a,b,c,d)*(f,g)=(af+b,cf+g+d)

手推一下发现这个变换有结合律。

用线段树维护这个变换即可。最后再用拉格朗日插值求出 k\sim n 次项系数即可,方法见这里。

#include<bits/stdc++.h>
#define ls t[ro].l
#define rs t[ro].r
#define mid ((l+r)>>1)
using namespace std;
typedef long long ll;
namespace FGF
{
    int n,K,w;
    const int N=2005,mo=64123;
    struct Node{
        ll a,b,c,d;
        Node(){}
        Node(int _a,int _b,int _c,int _d):a(_a),b(_b),c(_c),d(_d){};
    };
    Node operator *(Node x,Node y){return Node(x.a*y.a%mo,(y.a*x.b+y.b)%mo,(x.a*y.c+x.c)%mo,(y.c*x.b+x.d+y.d)%mo);}
    struct tree{
        int l,r;Node val;
        void init(){l=r=0;val=Node(1,0,0,0);}
    }t[N*40];
    int a[N],rt[N],num,ans[N],inv[N];
    vector<int> g[N];
    void pushdown(int ro)
    {
        if(!ls)ls=++num,t[ls].init();
        if(!rs)rs=++num,t[rs].init();
        t[ls].val=t[ls].val*t[ro].val,t[rs].val=t[rs].val*t[ro].val;
        t[ro].val=Node(1,0,0,0);
    }
    ll query(int ro,int l,int r)
    {
        if(l==r)return t[ro].val.d;
        pushdown(ro);
        return (query(ls,l,mid)+query(rs,mid+1,r))%mo;
    }
    void updat(int ro,int l,int r,int L,int R,const Node &x)
    {
        if(L<=l&&r<=R){t[ro].val=t[ro].val*x;return;}
        pushdown(ro);
        if(L<=mid)updat(ls,l,mid,L,R,x);
        if(R>mid)updat(rs,mid+1,r,L,R,x);
    }
    int merge(int &x,int &y)
    {
        if(!x||!y)return x+y;
        if(!t[x].l&&!t[x].r)swap(x,y);
        if(!t[y].l&&!t[y].r)
        {
            t[x].val=t[x].val*Node(t[y].val.b,0,0,t[y].val.d);
            return x;
        }
        pushdown(x),pushdown(y);
        t[x].l=merge(t[x].l,t[y].l),t[x].r=merge(t[x].r,t[y].r);
        return x;
    }
    void dfs(int u,int f,int k)
    {
        rt[u]=++num,t[rt[u]].init();
        updat(rt[u],1,w,1,w,Node(0,1,0,0));
        updat(rt[u],1,w,1,a[u],Node(k,0,0,0));
        for(auto v:g[u])
            if(v!=f)dfs(v,u,k),updat(rt[v],1,w,1,w,Node(1,1,0,0)),rt[u]=merge(rt[u],rt[v]);
        updat(rt[u],1,w,1,w,Node(1,0,1,0)); 
    }
    int Lagrange()
    {
        static int f[N],inv[N],g[N],s;
        inv[1]=f[0]=1;
        for(int i=2;i<=n+1;i++)
            inv[i]=1ll*inv[mo%i]*(mo-mo/i)%mo;
        for(int i=1;i<=n+1;i++)
            for(int j=n+1;j>=0;j--)
                f[j]=((j?f[j-1]:0)+1ll*(mo-i)*f[j]%mo)%mo;
        for(int i=1;i<=n+1;i++)
        {
            g[0]=1ll*f[0]*(mo-inv[i])%mo;
            for(int j=1;j<=n+1;j++)
                g[j]=1ll*(f[j]-g[j-1]+mo)%mo*(mo-inv[i])%mo;
            int tmp=ans[i];
            for(int j=1;j<=n+1;j++)
            {
                if(i>j)tmp=1ll*tmp*inv[i-j]%mo;
                if(i<j)tmp=1ll*tmp*(mo-inv[j-i])%mo;
            }
            for(int j=K;j<=n+1;j++)
                s=(s+1ll*tmp*g[j]%mo)%mo;
        }
        return s;
    }
    void work()
    {
        scanf("%d%d%d",&n,&K,&w);
        for(int i=1;i<=n;i++)
            scanf("%d",&a[i]);
        for(int i=1,u,v;i<n;i++)
            scanf("%d%d",&u,&v),g[u].push_back(v),g[v].push_back(u);
        for(int i=1;i<=n+1;i++)
        {
            num=0;
            dfs(1,0,i);
            ans[i]=query(rt[1],1,w);
        }
        printf("%d\n",Lagrange());
    }
}
int main()
{
    FGF::work();
    return 0;
}