P3781 [SDOI2017]切树游戏(FWT+动态dp)

· · 题解

P3781 [SDOI2017]切树游戏

首先考虑不带修时的暴力 dp:设 f_{u,i} 表示以 u 为根、价值为 i 的联通子树的个数,h_{u,i} 表示 u 子树内 f_{v,i} 的和,转移为

f_{u,k}+=\sum_{i \oplus j=k}f_{u,i}\times f_{v,j},h_{u,i}=f_{u,i}+\sum h_{v,i}

复杂度为 O(nm^2) 。发现转移式是一个异或卷积的形式,于是可以先 FWT 一次,转移时直接用点值相乘,复杂度为 O(nm)

带修怎么办?动态 dp!设 lf_{u,i}=\prod_{v\in lightson_{u}}(1+f_{v,i}),lh_{u,i}=\sum_{v\in lightson_{u}}h_{v,i} 。设 F,H,LF,LH 分别表示 f,h,lf,lh 的生成函数,设 vu 的重儿子,则转移为

f_{u,i}=(1+f_{v,i})\times lf_{u,i}\\ h_{u,i}=h_{v,i}+lh_{u,i}+(1+f_{v,i})\times lf_{u,i}\\ \begin{pmatrix} F_v,H_v,1\\\end{pmatrix}\times\begin{pmatrix}\ &LF_u,&LF_u,&0\\&0,&1,&0\\&LF_u,&LH_u+LF_u,&1\end{pmatrix}=\begin{pmatrix} F_u,H_u,1\\\end{pmatrix}

优化常数:如下式,其实我们只需要维护 a,b,c,d 处的四个值即可。

\begin{pmatrix}&a_1,&b_1,&0\\&0,&1,&0\\&c_1,&d_1,&1\end{pmatrix}\times\begin{pmatrix}\ &a_2,&b_2&0\\&0,&1,&0\\&c_2,&d_2,&1\end{pmatrix}=\begin{pmatrix}\ &a_1a_2,&a_1b_2+b_1,&0\\&0,&1,&0\\&a_2c_1+c_2,&c_1b_2+d_1+d_2,&1\end{pmatrix}

还有一个细节:跳重链修改父亲 lf 的值的时候可能会出现除以 0 的情况,所以必须再开一个变量记 lf 中 0 的个数,除以 0 就直接减一,通过这个变量推出 lf 的真实值。

这里放一个loj能过,洛谷被卡到 T 飞的垃圾树剖代码,复杂度 O(qm\log^2n),洛谷上 80 pts。

#include<bits/stdc++.h>
#define inl inline
#define mid ((l+r)>>1)
using namespace std;
typedef pair<int,int> pii;
namespace FGF
{
    int n,m,Q;
    const int N=30005,M=135,mo=1e4+7;
    int a[N],siz[N],son[N],fa[N],dep[N],top[N],dfn[N],id[N],c[M][M],f[N][M],h[N][M],lh[N][M],tmp1[M],tmp2[M],ed[N],inv[N],num;
    vector<int> g[N];
    struct matrx{
        int a[M],b[M],c[M],d[M];
        void clear()
        {
            memset(a,0,sizeof(int[m])),memset(b,0,sizeof(int[m])),memset(c,0,sizeof(int[m])),memset(d,0,sizeof(int[m]));
        }
    }t[N<<2];
    struct Node{
        int x,y;
        inline void operator =(int a){a?(x=a,y=0):(x=1,y=1);}
        friend Node operator *(Node a,int b){!b?a.y++:a.x=1ll*a.x*b%mo;return a;}
        friend Node operator /(Node a,int b){!b?a.y--:a.x=1ll*a.x*inv[b]%mo;return a;}
        int val(){return y?0:x;}
    }lf[N][M];
    inl matrx operator *(const matrx &x,const matrx &y)
    {
        matrx s;s.clear();
        for(int i=0;i<m;i++)
        {
            s.a[i]=x.a[i]*y.a[i]%mo;
            s.b[i]=(x.a[i]*y.b[i]%mo+x.b[i])%mo;
            s.c[i]=(x.c[i]*y.a[i]%mo+y.c[i])%mo;
            s.d[i]=(x.c[i]*y.b[i]%mo+x.d[i]+y.d[i])%mo;
        }
        return s;
    }
    void dfs1(int u,int f)
    {
        fa[u]=f,dep[u]=dep[f]+1,siz[u]=1;
        for(auto v:g[u])
            if(v!=f)
            {
                dfs1(v,u);
                siz[u]+=siz[v];
                if(siz[v]>siz[son[u]])son[u]=v;
            }
    }
    void dfs2(int u,int tp)
    {
        top[u]=tp,dfn[u]=++num,id[num]=u,ed[tp]=u;
        if(son[u])dfs2(son[u],tp);
        for(auto v:g[u])
            if(v!=son[u]&&v!=fa[u])dfs2(v,v);
    }
    inl void FWTxor(int *y,int op)
    {
        for(int i=1;1<<i<=m;i++)
            for(int j=0;j<m;j+=(1<<i))
                for(int k=0,c,d;k<(1<<(i-1));k++)
                    c=(y[j|k]+y[j|k|(1<<(i-1))])%mo,d=(y[j|k]-y[j|k|(1<<(i-1))]+mo)%mo,
                    y[j|k]=1ll*op*c%mo,y[j|k|(1<<(i-1))]=1ll*op*d%mo;
    }
    void dfs3(int u)
    {
        for(int i=0;i<m;i++)
            f[u][i]=c[a[u]][i];
        for(auto v:g[u])
            if(v!=fa[u])
            {
                dfs3(v);
                for(int i=0;i<m;i++)
                    f[u][i]=(f[u][i]+f[u][i]*f[v][i]%mo)%mo,h[u][i]=(h[u][i]+h[v][i])%mo;
            }  
        for(int i=0;i<m;i++)
            h[u][i]=(h[u][i]+f[u][i])%mo;
    }
    void build(int ro,int l,int r)
    {
        if(l==r)
        {
            int u=id[l];
            for(int i=0;i<m;i++)
                lf[u][i]=c[0][i];
            for(auto v:g[u])
                if(v!=son[u]&&v!=fa[u])
                    for(int i=0;i<m;i++)
                        lf[u][i]=lf[u][i]*((f[v][i]+1)%mo),lh[u][i]=(lh[u][i]+h[v][i])%mo;
            for(int i=0;i<m;i++)
                t[ro].a[i]=t[ro].b[i]=t[ro].c[i]=t[ro].d[i]=lf[u][i].val()*c[a[u]][i]%mo,(t[ro].d[i]+=lh[u][i])%=mo;
            return; 
        }
        build(ro<<1,l,mid),build(ro<<1|1,mid+1,r);
        t[ro]=t[ro<<1|1]*t[ro<<1];
    }
    matrx query(int ro,int l,int r,int L,int R)
    {
        if(L<=l&&r<=R)return t[ro];
        if(R<=mid)return query(ro<<1,l,mid,L,R);
        if(L>mid)return query(ro<<1|1,mid+1,r,L,R);
        return query(ro<<1|1,mid+1,r,L,R)*query(ro<<1,l,mid,L,R);
    }
    void modif(int ro,int l,int r,int x)
    {
        if(l==r)
        {
            int u=id[l];
            for(int i=0;i<m;i++)
                t[ro].a[i]=t[ro].b[i]=t[ro].c[i]=t[ro].d[i]=lf[u][i].val()*c[a[u]][i]%mo,(t[ro].d[i]+=lh[u][i])%=mo;
            return; 
        }
        x<=mid?modif(ro<<1,l,mid,x):modif(ro<<1|1,mid+1,r,x);
        t[ro]=t[ro<<1|1]*t[ro<<1];
    }
    void askk(int x)
    {
        matrx ans=query(1,1,n,dfn[top[x]],dfn[ed[top[x]]]);
        for(int i=0;i<m;i++)
            tmp1[i]=ans.c[i],tmp2[i]=ans.d[i];
    }
    void updat(int x,int y)
    {
        a[x]=y;
        while(x)
        {
            askk(x);
            if(fa[top[x]])
                for(int i=0;i<m;i++)
                    lf[fa[top[x]]][i]=lf[fa[top[x]]][i]/((tmp1[i]+1)%mo),lh[fa[top[x]]][i]=(lh[fa[top[x]]][i]-tmp2[i]+mo)%mo;
            modif(1,1,n,dfn[x]),askk(x);    
            if(fa[top[x]])
                for(int i=0;i<m;i++)
                    lf[fa[top[x]]][i]=lf[fa[top[x]]][i]*((tmp1[i]+1)%mo),lh[fa[top[x]]][i]=(lh[fa[top[x]]][i]+tmp2[i])%mo;
            x=fa[top[x]];
        }
    }
    void work()
    {
        scanf("%d%d",&n,&m);
        for(int i=1;i<=n;i++)   
            scanf("%d",&a[i]);
        for(int i=1,u,v;i<n;i++)
            scanf("%d%d",&u,&v),g[u].push_back(v),g[v].push_back(u);
        for(int i=0;i<m;i++)
            c[i][i]=1,FWTxor(c[i],1);
        inv[1]=1;
        for(int i=2;i<mo;i++)
            inv[i]=inv[mo%i]*(mo-mo/i)%mo;
        dfs1(1,0),dfs2(1,1),dfs3(1);
        build(1,1,n);
        scanf("%d",&Q);
        while(Q--)
        {
            char ch[20];int x,y;
            scanf("%s%d",ch,&x);
            if(ch[0]=='C')scanf("%d",&y),updat(x,y);
            else askk(1),FWTxor(tmp2,(mo+1)/2),printf("%d\n",tmp2[x]);
        }
    }
}
int main()
{
    FGF::work();
    return 0;
}

全局平衡二叉树:

#include<bits/stdc++.h>
#define inl inline
using namespace std;
typedef pair<int,int> pii;
namespace FGF
{
    int n,m,Q;
    const int N=30005,M=135,mo=1e4+7;
    int a[N],siz[N],son[N],c[M][M],f[N][M],h[N][M],lh[N][M],tmp[M],ff[N],inv[N];
    int fa[N],rt,ls[N],rs[N],st[N],sum[N];
    vector<int> g[N];
    struct matrx{
        int a[M],b[M],c[M],d[M];
        void clear()
        {
            memset(a,0,sizeof(int[m])),memset(b,0,sizeof(int[m])),memset(c,0,sizeof(int[m])),memset(d,0,sizeof(int[m]));
        }
    }hw[N];
    struct Node{
        int x,y;
        inline void operator =(int a){a?(x=a,y=0):(x=1,y=1);}
        friend Node operator *(Node a,int b){!b?a.y++:a.x=1ll*a.x*b%mo;return a;}
        friend Node operator /(Node a,int b){!b?a.y--:a.x=1ll*a.x*inv[b]%mo;return a;}
        int val(){return y?0:x;}
    }lf[N][M];
    inl matrx operator *(const matrx &x,const matrx &y)
    {
        matrx s;s.clear();
        for(int i=0;i<m;i++)
        {
            s.a[i]=x.a[i]*y.a[i]%mo;
            s.b[i]=(x.a[i]*y.b[i]%mo+x.b[i])%mo;
            s.c[i]=(x.c[i]*y.a[i]%mo+y.c[i])%mo;
            s.d[i]=(x.c[i]*y.b[i]%mo+x.d[i]+y.d[i])%mo;
        }
        return s;
    }
    void dfs(int u,int fa)
    {
        ff[u]=fa,siz[u]=1;
        for(int i=0;i<m;i++)
            f[u][i]=c[a[u]][i],lf[u][i]=c[0][i];
        for(auto v:g[u])
            if(v!=fa)
            {
                dfs(v,u);
                siz[u]+=siz[v];
                if(siz[v]>siz[son[u]])son[u]=v;
                for(int i=0;i<m;i++)
                    f[u][i]=(f[u][i]+f[u][i]*f[v][i]%mo)%mo,h[u][i]=(h[u][i]+h[v][i])%mo;
            }
        for(int i=0;i<m;i++)
            h[u][i]=(h[u][i]+f[u][i])%mo;
        for(auto v:g[u])
            if(v!=son[u]&&v!=ff[u])
                for(int i=0;i<m;i++)
                    lf[u][i]=lf[u][i]*((f[v][i]+1)%mo),lh[u][i]=(lh[u][i]+h[v][i])%mo;
    }
    inl void FWTxor(int *y,int op)
    {
        for(int i=1;1<<i<=m;i++)
            for(int j=0;j<m;j+=(1<<i))
                for(int k=0,c,d;k<(1<<(i-1));k++)
                    c=(y[j|k]+y[j|k|(1<<(i-1))])%mo,d=(y[j|k]-y[j|k|(1<<(i-1))]+mo)%mo,
                    y[j|k]=1ll*op*c%mo,y[j|k|(1<<(i-1))]=1ll*op*d%mo;
    }
    void pushup(int x)
    {
        for(int i=0;i<m;i++)
            hw[x].a[i]=hw[x].b[i]=hw[x].c[i]=hw[x].d[i]=lf[x][i].val()*c[a[x]][i]%mo,(hw[x].d[i]+=lh[x][i])%=mo;
        if(ls[x])hw[x]=hw[x]*hw[ls[x]];
        if(rs[x])hw[x]=hw[rs[x]]*hw[x];
    }
    int construc(int L,int R)
    {
        int l=L,r=R,mid,ans,x;
        while(l<=r)
        {
            mid=(l+r)>>1;
            sum[mid-1]-sum[L-1]<=sum[R]-sum[mid-1]?(l=mid+1,ans=mid):r=mid-1;
        }
        x=st[ans];
        if(ans!=L)ls[x]=construc(L,ans-1),fa[ls[x]]=x;
        if(ans!=R)rs[x]=construc(ans+1,R),fa[rs[x]]=x;
        pushup(x);
        return x;
    }
    int build(int x)
    {
        int y=x;
        do{
            for(auto v:g[y])
                if(v!=son[y]&&v!=ff[y])fa[build(v)]=y;
        }while(y=son[y]);
        while(x)st[++y]=x,sum[y]=sum[y-1]+siz[x]-siz[son[x]],x=son[x];
        return construc(1,y);
    }
    void updat(int x,int y)
    {
        a[x]=y;
        while(x)
        {
            if(fa[x]&&ls[fa[x]]!=x&&rs[fa[x]]!=x)
            {
                for(int i=0;i<m;i++)
                    lf[fa[x]][i]=lf[fa[x]][i]/((hw[x].c[i]+1)%mo),lh[fa[x]][i]=(lh[fa[x]][i]-hw[x].d[i]+mo)%mo;
                pushup(x);  
                for(int i=0;i<m;i++)
                    lf[fa[x]][i]=lf[fa[x]][i]*((hw[x].c[i]+1)%mo),lh[fa[x]][i]=(lh[fa[x]][i]+hw[x].d[i])%mo;
            }
            else pushup(x);
            x=fa[x];
        }
    }
    void work()
    {
        scanf("%d%d",&n,&m);
        for(int i=1;i<=n;i++)   
            scanf("%d",&a[i]);
        for(int i=1,u,v;i<n;i++)
            scanf("%d%d",&u,&v),g[u].push_back(v),g[v].push_back(u);
        for(int i=0;i<m;i++)
            c[i][i]=1,FWTxor(c[i],1);
        inv[1]=1;
        for(int i=2;i<mo;i++)
            inv[i]=inv[mo%i]*(mo-mo/i)%mo;
        dfs(1,0),rt=build(1);;
        scanf("%d",&Q);
        while(Q--)
        {
            char ch[20];int x,y;
            scanf("%s%d",ch,&x);
            if(ch[0]=='C')scanf("%d",&y),updat(x,y);
            else 
            {
                for(int i=0;i<m;i++)
                    tmp[i]=hw[rt].d[i];
                FWTxor(tmp,(mo+1)/2),printf("%d\n",tmp[x]);
            }
        }
    }
}
int main()
{
    FGF::work();
    return 0;
}