[ICPC2023 Jinan R] Graph Partitioning 2 题解

· · 算法·理论

题目链接:gym qoj

题目大意:给定一棵树和常数 k,查询将树划分为一些连通块,且每个连通块大小为 kk+1 的方案数。

upd 2025.4.21:已更新线性做法。

考虑 DP。设 f_{u,i} 表示 u 子树的划分,且 u 所在的连通块大小为 i

然后你发现如果 sz_u<k DP 值是平凡的,所以我们跳过 sz_u<k 的点的 DP。

对于剩下的点,记 s_u=1+\sum_{v\in \text{son}(u)} [sz_v<k]sz_v,显然 u 所在的连通块大小至少为 s_u。显然有 s_u\le k+1,否则整个树的答案肯定是 0。并且 \sum s_u\le n,因为每个点只会贡献一次。

然后所以一个 sz_u\ge k 的点,只需要考虑同样 sz_v\ge k 的儿子的 DP 值,因为剩下的儿子和自己在一个连通块里,所以合并完儿子后把自己的 DP 值偏移 s_u 即可。可以用 deque 存储 DP 数组来方便偏移。

现在我们可以认为树上只剩下 sz_u\ge k 的点要考虑。这是经典问题:儿子个数不为 1 的点的儿子个数之和为 O(n/k),所以我们可以直接用 NTT 将这些点的儿子的 DP 数组进行合并。剩下的点恰好有一个儿子,显然可以转移为 f_{u,i}=f_{v,i}f_{u,0}=f_{v,k}+f_{v,k+1},然后偏移 s_u 的距离。所以 DP 数组只会从儿子更改一个位置,直接把儿子的 DP 继承过来即可。

总的复杂度为 O(n/k\times k\log k)=O(n\log n)

以下做法来自 LHF。

其实我们还可以更进一步:不使用 NTT,我们直接将两个数组里有值的位置暴力乘起来就就可以做到 O(n)

证明:注意到 f_u 只有 O(sz_u/k+1) 个值。因为 u 子树内的连通块个数必然为 O(sz_u/k),每个连通块可以选 kk+1,于是剩余的点集大小的可能性也只有这么多种。

然后其实这个东西可以被视为 O(sz_u/k),因为在上面的 DP 中我们是不考虑那些 sz_u<k 的点的。

所以子树大小分别为 x,y 的两个点背包合并可以把复杂度视为 O(\min(x/k,k)\times\min(y/k,k))

这个形式很像树上背包,考虑用类似的分析方式。

于是我们成功把复杂度分析到了线性。

目前为止我是高贵的最优解。

#include <bits/stdc++.h>
using namespace std;
constexpr int Spp{1<<20},S2{1<<20};
char buf[Spp],*p1,*p2,buf2[S2],*l2=buf2,_st[22];
#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf,1,Spp,stdin),p1==p2)?EOF:*p1++)
#define putchar(c) (l2==buf2+S2&&(fwrite(buf2,1,S2,stdout),l2=buf2),*l2++=(c))
template <typename T>
void read(T &x) {
    char c;int f{1};
    do x=(c=getchar())^48;
    while (!isdigit(c)&&c!='-');
    if (x==29) f=-1,x=0;
    while (isdigit(c=getchar()))
        x=(x*10)+(c^48);
    x*=f;
}
template <typename T,typename ...Args>
void read(T& x,Args&... args) {read(x);read(args...);}
template<typename T>
void write(T x,char c='\n') {
    if (x<0) putchar('-'),x*=-1;
    int tp=0;
    do _st[++tp]=x%10; while (x/=10);
    while (tp) putchar(_st[tp--]+'0');
    putchar(c);
}
struct OI{~OI(){fwrite(buf2,1,l2-buf2,stdout);}}oi;
constexpr int N(1e5),b6e0{998244353};
int F[N*3+5],h[N+5],w[N+5];
int sz[N+5];
vector<int> e[N+5];
int n,k;
void mul(int a[],int b[],int sz) {
    vector<pair<int,int>> l,r;
    fill_n(w,k+2,0);
    for (int i{0};i<=k+1;++i) {
        if (a[i]) l.emplace_back(i,a[i]);
        if (i<=sz&&b[i]) r.emplace_back(i,b[i]);
    }
    for (auto [x,px]:l)
        for (auto [y,py]:r)
            if (x+y<=k+1)
                w[x+y]=(w[x+y]+1LL*px*py)%b6e0;
}
void init(int u) {
    sz[u]=1;
    for (auto v:e[u]) {
        e[v].erase(find(e[v].begin(),e[v].end(),u));
        init(v);
        sz[u]+=sz[v];
    }
    int tw{0};
    for (auto v:e[u])
        if (sz[v]>sz[tw]) tw=v;
    for (auto &v:e[u])
        if (v==tw) {
            swap(v,e[u].back());
            break;
        }

}
int *f;
bool solve(int u) {
    int sc{0},s{1};
    for (auto v:e[u])
        if (sz[v]>=k) ++sc;
        else s+=sz[v];
    int *g{f};
    if (s>k+1)
        return false;
    f+=s;
    for (auto v:e[u])
        if (sz[v]>=k&&!solve(v))
            return false;
    if (sc==0) {
        fill_n(g,k+2,0);
        g[sz[u]]=1;
    } else if (sc==1) {
        fill_n(g,s,0);
    } else {
        fill_n(h,k+2,0);
        h[s]=1;
        f=g+s;
        for (auto v:e[u]) {
            if (sz[v]>=k) {
                mul(h,f,sz[v]);
                copy_n(w,k+2,h);
                f+=k+2;
            }
        }
        copy_n(h,k+2,g);
    }
    g[0]=(g[k]+g[k+1])%b6e0;
    f=g+k+2;
    return true;
}
signed main() {
    int T;read(T);
    while (T--) {
        read(n,k);
        for (int i{1};i<n;++i) {
            int u,v;read(u,v);
            e[u].push_back(v);
            e[v].push_back(u);
        }
        init(1);
        f=F;
        if (solve(1)) write(F[0]);
        else write(0);
        for (int i{1};i<=n;++i)
            e[i].clear();
    }
    return 0;
}