[ICPC2023 Jinan R] Graph Partitioning 2 题解
题目链接:gym qoj
题目大意:给定一棵树和常数
upd 2025.4.21:已更新线性做法。
考虑 DP。设
然后你发现如果
对于剩下的点,记
然后所以一个
现在我们可以认为树上只剩下
总的复杂度为
以下做法来自 LHF。
其实我们还可以更进一步:不使用 NTT,我们直接将两个数组里有值的位置暴力乘起来就就可以做到
证明:注意到
然后其实这个东西可以被视为
所以子树大小分别为
这个形式很像树上背包,考虑用类似的分析方式。
于是我们成功把复杂度分析到了线性。
目前为止我是高贵的最优解。
#include <bits/stdc++.h>
using namespace std;
constexpr int Spp{1<<20},S2{1<<20};
char buf[Spp],*p1,*p2,buf2[S2],*l2=buf2,_st[22];
#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf,1,Spp,stdin),p1==p2)?EOF:*p1++)
#define putchar(c) (l2==buf2+S2&&(fwrite(buf2,1,S2,stdout),l2=buf2),*l2++=(c))
template <typename T>
void read(T &x) {
char c;int f{1};
do x=(c=getchar())^48;
while (!isdigit(c)&&c!='-');
if (x==29) f=-1,x=0;
while (isdigit(c=getchar()))
x=(x*10)+(c^48);
x*=f;
}
template <typename T,typename ...Args>
void read(T& x,Args&... args) {read(x);read(args...);}
template<typename T>
void write(T x,char c='\n') {
if (x<0) putchar('-'),x*=-1;
int tp=0;
do _st[++tp]=x%10; while (x/=10);
while (tp) putchar(_st[tp--]+'0');
putchar(c);
}
struct OI{~OI(){fwrite(buf2,1,l2-buf2,stdout);}}oi;
constexpr int N(1e5),b6e0{998244353};
int F[N*3+5],h[N+5],w[N+5];
int sz[N+5];
vector<int> e[N+5];
int n,k;
void mul(int a[],int b[],int sz) {
vector<pair<int,int>> l,r;
fill_n(w,k+2,0);
for (int i{0};i<=k+1;++i) {
if (a[i]) l.emplace_back(i,a[i]);
if (i<=sz&&b[i]) r.emplace_back(i,b[i]);
}
for (auto [x,px]:l)
for (auto [y,py]:r)
if (x+y<=k+1)
w[x+y]=(w[x+y]+1LL*px*py)%b6e0;
}
void init(int u) {
sz[u]=1;
for (auto v:e[u]) {
e[v].erase(find(e[v].begin(),e[v].end(),u));
init(v);
sz[u]+=sz[v];
}
int tw{0};
for (auto v:e[u])
if (sz[v]>sz[tw]) tw=v;
for (auto &v:e[u])
if (v==tw) {
swap(v,e[u].back());
break;
}
}
int *f;
bool solve(int u) {
int sc{0},s{1};
for (auto v:e[u])
if (sz[v]>=k) ++sc;
else s+=sz[v];
int *g{f};
if (s>k+1)
return false;
f+=s;
for (auto v:e[u])
if (sz[v]>=k&&!solve(v))
return false;
if (sc==0) {
fill_n(g,k+2,0);
g[sz[u]]=1;
} else if (sc==1) {
fill_n(g,s,0);
} else {
fill_n(h,k+2,0);
h[s]=1;
f=g+s;
for (auto v:e[u]) {
if (sz[v]>=k) {
mul(h,f,sz[v]);
copy_n(w,k+2,h);
f+=k+2;
}
}
copy_n(h,k+2,g);
}
g[0]=(g[k]+g[k+1])%b6e0;
f=g+k+2;
return true;
}
signed main() {
int T;read(T);
while (T--) {
read(n,k);
for (int i{1};i<n;++i) {
int u,v;read(u,v);
e[u].push_back(v);
e[v].push_back(u);
}
init(1);
f=F;
if (solve(1)) write(F[0]);
else write(0);
for (int i{1};i<=n;++i)
e[i].clear();
}
return 0;
}