P4827 非 FFT 点分治做法

· · 题解

点分治这么可爱怎么就没人写呢。

我都不用 FFT 了为什么不写我!

题意简述

给定一棵 n 个点的树,对每个点 i,要求计算:

S(i)=\sum_{j=1}^{n}\operatorname{dist}(i,j)^k

答案对 10007 取模。

数据范围:

1\le n\le 5\times 10^4,\quad 1\le k\le 150

思路概述

本题可以使用点分治解决。

点分治每次选取当前连通块的重心 c,统计所有经过 c 的路径贡献,然后递归处理每个子树。

对于当前重心 c,如果点 x 和点 y 的路径经过 c,则有:

\operatorname{dist}(x,y)=\operatorname{dist}(x,c)+\operatorname{dist}(y,c)

于是:

\operatorname{dist}(x,y)^k=(\operatorname{dist}(x,c)+\operatorname{dist}(y,c))^k

根据二项式定理:

(a+b)^k=\sum_{t=0}^{k}\binom{k}{t}a^{k-t}b^t

令:

a=\operatorname{dist}(x,c),\quad b=\operatorname{dist}(y,c)

则:

\operatorname{dist}(x,y)^k= \sum_{t=0}^{k} \binom{k}{t} \operatorname{dist}(x,c)^{k-t} \operatorname{dist}(y,c)^t

所以,只要能统计当前连通块中所有点到重心 c 的距离的各次幂之和,就可以快速计算当前重心对每个点的贡献。

当前重心的贡献

设当前点分治处理的连通块为 T,重心为 c

定义:

M[t]=\sum_{y\in T}\operatorname{dist}(c,y)^t

对于一个固定点 x,如果暂时不考虑 xy 是否在同一个分支内,那么所有点 y\in Tx 的贡献为:

\sum_{y\in T} (\operatorname{dist}(x,c)+\operatorname{dist}(y,c))^k

展开后得到:

\sum_{t=0}^{k} \binom{k}{t} \operatorname{dist}(x,c)^{k-t} M[t]

但这会多算一部分贡献。

如果 xy 在重心 c 的同一个子树分支中,那么路径 x\to y 并不经过 c,不能在当前重心处计算。

因此需要减去同分支内部的贡献。

减去同分支贡献

设点 x 位于重心 c 的某个分支 B 中。

定义:

M_B[t]=\sum_{y\in B}\operatorname{dist}(c,y)^t

那么点 x 在当前重心 c 处真正应该增加的贡献为:

\sum_{t=0}^{k} \binom{k}{t} \operatorname{dist}(x,c)^{k-t} (M[t]-M_B[t])

也就是:

ans[x] \gets \sum_{t=0}^{k} \binom{k}{t} \operatorname{dist}(x,c)^{k-t} (M[t]-M_B[t])

对于重心 c 自己,不属于任何一个分支,因此它在当前连通块内的贡献可以直接加入:

ans[c] \gets M[k]

因为:

M[k]=\sum_{y\in T}\operatorname{dist}(c,y)^k

为什么每对点只会被计算一次

考虑任意一对有序点 (x,y)

在点分治过程中,xy 会在某一层第一次满足以下条件之一:

  1. 当前重心就是 xy

设这一层的重心为 c

此时,路径 x\to y 一定经过 c,因此可以在这一层正确计算:

\operatorname{dist}(x,y)^k

在更高层中,如果 xy 仍然位于同一个分支内,那么它们会被同分支减法抵消。

在更低层中,xy 已经被分到不同的递归连通块中,不会再被同时处理。

因此每一对有序点 (x,y) 的贡献会被计算且仅被计算一次。

预处理优化

如果每次都重新计算距离幂,会导致常数较大。

由于最终答案对 10007 取模,因此距离 d 只需要关心 d\bmod 10007

预处理:

pw[d][t]=d^t\bmod 10007

其中:

0\le d<10007,\quad 0\le t\le k

再预处理:

coef[d][t]=\binom{k}{t}d^{k-t}\bmod 10007

这样计算某个点的贡献时,可以直接使用:

\sum_{t=0}^{k}coef[d][t]\cdot (M[t]-M_B[t])

其中:

d=\operatorname{dist}(x,c)

复杂度分析

点分治的递归深度为:

O(\log n)

每一层中,每个点会被处理一次,每次需要枚举 0k 的幂次。

所以总时间复杂度为:

O(nk\log n)

代码中使用了大量的 vector 以展示容错率,大概是因为点分治跑不满的原因。

:::success[ACcode]

//Author:kevinZ99
#include <bits/stdc++.h>
#define up(a,b,c) for(int (a)=(b);(a)<=(c);(a)=-~(a))
#define dn(a,b,c) for(int (a)=(b);(a)>=(c);(a)=~-(a))
#define fst first
#define sed second
#define pref static inline
#define gc() p1==p2&&(p2=(p1=buf)+fread(buf,1,1<<12,stdin),p1==p2)?EOF:*p1++
using namespace std; using hint = __int128;using pii = pair< int , int > ;
using us = unsigned short ;using ldb  = long double ;using ll = long long;
using ull= unsigned long long;using ui=unsigned int;using pll=pair<ll,ll>;
using pil= pair<int,ll> ;using vpil   = vector<pil>;using vl = vector<ll>;
using pli= pair<ll,int>;using vpli    = vector<pli>;using vi =vector<int>;
using vpi= vector< pii > ;using vpl   = vector<pll> ; using db =  double ;
namespace mystl{
    char buf[1<<20], *p1=buf, *p2=buf, sr[1<<23], z[23], nc;int C=-1 ,Z=0;
    template <typename T>pref void read ( T & x )    { bool flag = false ;
        while( nc = gc() ,(nc<48||nc>57) && nc!=-1)flag|=(nc==45);x=nc-48;
        while(nc=gc(),47<nc&&nc<58)x=(x<<3)+(x<<1)+(nc^48); if(flag)x=-x;}
    template <  typename T , typename ... Args_Arrays_Typename_KevinZ99  >
    void read(T&x,Args_Arrays_Typename_KevinZ99&...a){read(x);read(a...);}
    pref void ot(  )    {  fwrite( sr , 1 , C + 1, stdout ) ;  C = - 1 ; }
    pref void flush( )   { if ( C > 1<<22 ) ot() ; } template <typename T>
    pref void write(T x,char t)   { int y = 0 ; if ( x < 0 ) y = 1, x = -x;
        while( z [ ++ Z ] = x % 10 + 48 , x /= 10) ; if( y ) z[ ++Z ]='-';
        while( sr[ ++ C ] = z[ Z ] , -- Z ) ; sr [ ++C ] = t ; flush() ; }
    pref void write(char x)   {sr[C=-~C]=x;}pref void write(string s){for(
    char t:s)write(t);}pref ll qpow(ll a , ll b,ll p)   {if(a==0)return 0;
    ll c=1ll; while(b) { if(b & 1) c=a*c%p; a=a*a%p; b>>=1; } return c ; }
    pref ll lcm ( ll x , ll y )   {return x / std :: __gcd( x , y ) * y ;}
};
using namespace mystl;
namespace my{
    constexpr int P=static_cast<int>(10007);
    pref void madd(int & x , int y)    { x = ( x + y >= P )?(x+y-P):(x+y);}
    pref int fmadd(int x , int y)    { return ( x + y >=P )?(x+y-P):(x+y);}
    pref void msub(int & x , int y)    { x = ( x < y ) ? (x-y+P) : (x-y); }
    pref int fmsub(int x , int y)    { return ( x < y ) ? (x-y+P) : (x-y);}
    pref void mmul ( int & x , int y )   { x = (int)( 1ll * x * y % P ) ; }
    pref int fmmul ( int x,int y )    { return (int)( 1ll * x * y % P ) ; }
    template<typename T>pref T Min(T x,T y)   {return (x<y)?(x):(y);}
    template<typename T>pref T Max(T x,T y)   {return (x>y)?(x):(y);}
    template<typename T>pref T Abs(T x)      {return (x<0)?(-x):(x);}
    constexpr int N=static_cast<int>(50005),K=155,inf=static_cast<int>(1e9);
    int n,k;
    vi g[N];
    int siz[N],ans[N],C[K][K];
    bitset<N>vis;
    void dfs(int x,int fa){
        siz[x]=1;
        for(int v:g[x])if(v^fa&&!vis[v]){
            dfs(v,x);
            siz[x]+=siz[v];
        }
    }
    int findc(int x,int fa,int tot){
        for(int v:g[x])if(v^fa&&!vis[v]&&(siz[v]<<1)>tot)return findc(v,x,tot);
        return x;
    }
    struct Part{
        vpi node;
        vi m;
        Part(int k=0){
            m.assign(k+1,0);
        }
    };
    void initial(int x,int fa,int d,Part&part){
        part.node.push_back({x,d});
        int pw=1,base=d%P;
        up(t,0,k)
            madd(part.m[t],pw),mmul(pw,base);
        for(int v:g[x])if(v^fa&&!vis[v])initial(v,x,d+1,part);
    }
    int pwdis[P][K],coef[P][K];
    int calc(int d,const vi&m){
        int base=d%P,ans=0;
        up(t,0,k)madd(ans,fmmul(coef[base][t],m[t]));
        return ans;
    }
    void partition(int x){
        dfs(x,0);
        int c=findc(x,0,siz[x]);
        vis[c]=true;
        vector<Part>parts;
        vi tot(k+1,0);tot[0]=1;
        for(int v:g[c]){
            if(vis[v])continue;
            Part part(k);
            initial(v,c,1,part);
            up(t,0,k)madd(tot[t],part.m[t]);
            parts.push_back(part);
        }
        madd(ans[c],calc(0,tot));
        for(auto&part:parts){
            for(auto pair:part.node){
                int u=pair.fst,d=pair.sed;
                int add=calc(d,tot),sub=calc(d,part.m);
                madd(ans[u],fmsub(add,sub));
            }
        }for(int v:g[c])if(!vis[v])partition(v);
    }
    pref void SOLVE(){
        read(n,k);
        up(i,2,n){
            int x,y;read(x,y);
            g[x].push_back(y);
            g[y].push_back(x);  
        }
        up(i,0,k){
            C[i][0]=C[i][i]=1;
            up(j,1,i-1)C[i][j]=fmadd(C[i-1][j-1],C[i-1][j]);
        }
        up(d,0,P-1){
            pwdis[d][0]=1;
            up(t,1,k)pwdis[d][t]=fmmul(pwdis[d][t-1],d);
        }
        up(d,0,P-1)up(t,0,k)coef[d][t]=fmmul(C[k][t],pwdis[d][k-t]);
        partition(1);up(i,1,n)write(ans[i],'\n');ot();
    }
}
int main(){
        my::SOLVE();
    return 0;
}

:::