CF1709F Multiset of Strings 题解

· · 题解

一道比较简单的多项式计数题(这也能算计数?)但是场上 sb 了没看出来应该放在 01-trie 上做。

题意略。

把这个题目放在一个插入了所有 [0,2^n-1] 中的二进制数的 01 trie 上考虑,发现可以把每个 c_x 的限制放在数字 x 往父节点连的那条边上。结合图片看一下:

蓝色的字母表示边权 c_x

问题就转化为了,从根节点(第 0 层)有无限多流量的水往下流,求流到第 n 层的最大总流量是 f 的方案数,且边权的范围是 [0,k]

对于每个节点 x,用树形 dp 来计数。

首先你可以发现同一层的点的 dp 值肯定是全部相同的,也就是我们只需要对 n 个点 dp 就可以了。

具体来说用 dp_{x,i} 表示从 x 的父亲往 x 的方向最多可以流 i 的流量的方案数。

初始化第 n 层的节点 x 的 dp 值:dp_{x,i}=1,i\in [0,k]

从第 j 层的节点 s1,s2j\in[2,n])往第 j-1 层的节点 x 转移的时候,考虑这个流量 i 被什么限制:有可能是 dp_{s1},dp_{s2},也可能是被 c_x 限制。

所以有转移方程:dp_{x,i}=\sum_{j+k>i}(dp_{s1,j}\cdot dp_{s2,k})+(k-i+1)\sum_{j+k=i}(dp_{s1,j}\cdot dp_{s2,k}) 。注意 s1s2 是完全等同的,所以把 s1 对应的生成函数平方一下就可以了。

最后一点细节是第一层的节点 x,y 往第 0 层的答案转移,有 Ans=\sum_{j+k=f}dp_{x,j}\cdot dp_{y,k}

用 NTT 进行多项式操作,时间复杂度是 O(nk\log k)

代码如下:

#include <bits/stdc++.h>
const int N = 524288, mod = 998244353, g = 3;
inline int mul(int x, int y){
    return (int)(1ll * x * y % (1ll * mod));
}
inline int add(int x, int y){
    return x + y >= mod ? x + y - mod : x + y;
}
inline int minus(int x, int y){
    return x < y ? x - y + mod : x - y;
}
inline int Qpow(int x, int y){
    int r = 1;
    while(y){
        if(y & 1) r = mul(r, x);
        x = mul(x, x);
        y >>= 1;
    }
    return r;
}
int rev[N];
void NTT(int *A, int limit, int on){
    for(int i = 1; i < limit; ++i) 
        rev[i] = (rev[i >> 1] >> 1) + (i & 1) * (limit >> 1);
    for(int i = 0; i < limit; ++i) 
        if(i < rev[i]) std::swap(A[i], A[rev[i]]);
    for(int i = 2; i <= limit; i <<= 1){
        int t = Qpow(g, (mod - 1) / i);
        if(on == -1) t = Qpow(t, mod - 2);
        for(int j = 0; j < limit; j += i){
            int w = 1;
            for(int k = j; k < j + i / 2; ++k, w = mul(w, t)){
                int u = A[k], v = mul(A[k + i / 2], w);
                A[k] = add(u, v);
                A[k + i / 2] = minus(u, v);
            }
        }
    }
    if(on == -1){
        int u = Qpow(limit, mod - 2);
        for(int i = 0; i < limit; ++i) A[i] = mul(A[i], u); 
    }
    return ;
}
void work1(int *A, int limit){
    NTT(A, limit, 1);
    for(int i = 0; i < limit; ++i) A[i] = mul(A[i], A[i]);
    NTT(A, limit, -1);
    return ;
}
int n, k, f, dp[N];
signed main(){
    scanf("%d%d%d", &n, &k, &f);
    int limit = 1;
    while(limit < 2 * k + 1) limit <<= 1;
    for(int i = 0; i <= k; ++i) dp[i] = 1;
    for(int i = 1; i <= n; ++i){
        work1(dp, limit);
        int s = 0;
        for(int j = 2 * k + 1; j < limit; ++j) dp[j] = 0;
        if(i == n) break;
        for(int j = 2 * k; j >= 0; --j){
            int tmp = dp[j];
            if(j <= k)
            dp[j] = add(s, mul(dp[j], std::max(0, k - j + 1)));
            else dp[j] = 0;
            s = add(s, tmp);
        }
    }
    printf("%d\n", dp[f]);
    return 0;
}