P4827 非 FFT 点分治做法
点分治这么可爱怎么就没人写呢。
我都不用 FFT 了为什么不写我!
题意简述
给定一棵
答案对
数据范围:
思路概述
本题可以使用点分治解决。
点分治每次选取当前连通块的重心
对于当前重心
于是:
根据二项式定理:
令:
则:
所以,只要能统计当前连通块中所有点到重心
当前重心的贡献
设当前点分治处理的连通块为
定义:
对于一个固定点
展开后得到:
但这会多算一部分贡献。
如果
因此需要减去同分支内部的贡献。
减去同分支贡献
设点
定义:
那么点
也就是:
对于重心
因为:
为什么每对点只会被计算一次
考虑任意一对有序点
在点分治过程中,
- 当前重心就是
x 或y ; -
设这一层的重心为
此时,路径
在更高层中,如果
在更低层中,
因此每一对有序点
预处理优化
如果每次都重新计算距离幂,会导致常数较大。
由于最终答案对
预处理:
其中:
再预处理:
这样计算某个点的贡献时,可以直接使用:
其中:
复杂度分析
点分治的递归深度为:
每一层中,每个点会被处理一次,每次需要枚举
所以总时间复杂度为:
代码中使用了大量的 vector 以展示容错率,大概是因为点分治跑不满的原因。
:::success[ACcode]
//Author:kevinZ99
#include <bits/stdc++.h>
#define up(a,b,c) for(int (a)=(b);(a)<=(c);(a)=-~(a))
#define dn(a,b,c) for(int (a)=(b);(a)>=(c);(a)=~-(a))
#define fst first
#define sed second
#define pref static inline
#define gc() p1==p2&&(p2=(p1=buf)+fread(buf,1,1<<12,stdin),p1==p2)?EOF:*p1++
using namespace std; using hint = __int128;using pii = pair< int , int > ;
using us = unsigned short ;using ldb = long double ;using ll = long long;
using ull= unsigned long long;using ui=unsigned int;using pll=pair<ll,ll>;
using pil= pair<int,ll> ;using vpil = vector<pil>;using vl = vector<ll>;
using pli= pair<ll,int>;using vpli = vector<pli>;using vi =vector<int>;
using vpi= vector< pii > ;using vpl = vector<pll> ; using db = double ;
namespace mystl{
char buf[1<<20], *p1=buf, *p2=buf, sr[1<<23], z[23], nc;int C=-1 ,Z=0;
template <typename T>pref void read ( T & x ) { bool flag = false ;
while( nc = gc() ,(nc<48||nc>57) && nc!=-1)flag|=(nc==45);x=nc-48;
while(nc=gc(),47<nc&&nc<58)x=(x<<3)+(x<<1)+(nc^48); if(flag)x=-x;}
template < typename T , typename ... Args_Arrays_Typename_KevinZ99 >
void read(T&x,Args_Arrays_Typename_KevinZ99&...a){read(x);read(a...);}
pref void ot( ) { fwrite( sr , 1 , C + 1, stdout ) ; C = - 1 ; }
pref void flush( ) { if ( C > 1<<22 ) ot() ; } template <typename T>
pref void write(T x,char t) { int y = 0 ; if ( x < 0 ) y = 1, x = -x;
while( z [ ++ Z ] = x % 10 + 48 , x /= 10) ; if( y ) z[ ++Z ]='-';
while( sr[ ++ C ] = z[ Z ] , -- Z ) ; sr [ ++C ] = t ; flush() ; }
pref void write(char x) {sr[C=-~C]=x;}pref void write(string s){for(
char t:s)write(t);}pref ll qpow(ll a , ll b,ll p) {if(a==0)return 0;
ll c=1ll; while(b) { if(b & 1) c=a*c%p; a=a*a%p; b>>=1; } return c ; }
pref ll lcm ( ll x , ll y ) {return x / std :: __gcd( x , y ) * y ;}
};
using namespace mystl;
namespace my{
constexpr int P=static_cast<int>(10007);
pref void madd(int & x , int y) { x = ( x + y >= P )?(x+y-P):(x+y);}
pref int fmadd(int x , int y) { return ( x + y >=P )?(x+y-P):(x+y);}
pref void msub(int & x , int y) { x = ( x < y ) ? (x-y+P) : (x-y); }
pref int fmsub(int x , int y) { return ( x < y ) ? (x-y+P) : (x-y);}
pref void mmul ( int & x , int y ) { x = (int)( 1ll * x * y % P ) ; }
pref int fmmul ( int x,int y ) { return (int)( 1ll * x * y % P ) ; }
template<typename T>pref T Min(T x,T y) {return (x<y)?(x):(y);}
template<typename T>pref T Max(T x,T y) {return (x>y)?(x):(y);}
template<typename T>pref T Abs(T x) {return (x<0)?(-x):(x);}
constexpr int N=static_cast<int>(50005),K=155,inf=static_cast<int>(1e9);
int n,k;
vi g[N];
int siz[N],ans[N],C[K][K];
bitset<N>vis;
void dfs(int x,int fa){
siz[x]=1;
for(int v:g[x])if(v^fa&&!vis[v]){
dfs(v,x);
siz[x]+=siz[v];
}
}
int findc(int x,int fa,int tot){
for(int v:g[x])if(v^fa&&!vis[v]&&(siz[v]<<1)>tot)return findc(v,x,tot);
return x;
}
struct Part{
vpi node;
vi m;
Part(int k=0){
m.assign(k+1,0);
}
};
void initial(int x,int fa,int d,Part&part){
part.node.push_back({x,d});
int pw=1,base=d%P;
up(t,0,k)
madd(part.m[t],pw),mmul(pw,base);
for(int v:g[x])if(v^fa&&!vis[v])initial(v,x,d+1,part);
}
int pwdis[P][K],coef[P][K];
int calc(int d,const vi&m){
int base=d%P,ans=0;
up(t,0,k)madd(ans,fmmul(coef[base][t],m[t]));
return ans;
}
void partition(int x){
dfs(x,0);
int c=findc(x,0,siz[x]);
vis[c]=true;
vector<Part>parts;
vi tot(k+1,0);tot[0]=1;
for(int v:g[c]){
if(vis[v])continue;
Part part(k);
initial(v,c,1,part);
up(t,0,k)madd(tot[t],part.m[t]);
parts.push_back(part);
}
madd(ans[c],calc(0,tot));
for(auto&part:parts){
for(auto pair:part.node){
int u=pair.fst,d=pair.sed;
int add=calc(d,tot),sub=calc(d,part.m);
madd(ans[u],fmsub(add,sub));
}
}for(int v:g[c])if(!vis[v])partition(v);
}
pref void SOLVE(){
read(n,k);
up(i,2,n){
int x,y;read(x,y);
g[x].push_back(y);
g[y].push_back(x);
}
up(i,0,k){
C[i][0]=C[i][i]=1;
up(j,1,i-1)C[i][j]=fmadd(C[i-1][j-1],C[i-1][j]);
}
up(d,0,P-1){
pwdis[d][0]=1;
up(t,1,k)pwdis[d][t]=fmmul(pwdis[d][t-1],d);
}
up(d,0,P-1)up(t,0,k)coef[d][t]=fmmul(C[k][t],pwdis[d][k-t]);
partition(1);up(i,1,n)write(ans[i],'\n');ot();
}
}
int main(){
my::SOLVE();
return 0;
}
:::