[题解] P3689 [ZJOI2017] 多项式

· · 题解

此处给出一种在 n,m,k,f(x),t 给定的情况下,可以做到 O(\log m) 回答一组 L,R 的做法。

首先考虑 f(x)^m 的组合意义。将 f(x) 看作一个集合 S \subseteq \{0,1,\cdots,n\} 的生成函数,则 f(x)^m 的第 p 项系数代表从中有序地选出 m 个数,和为 p 的方案数的奇偶性。考虑如果我们取了 a_ii,其中 \sum a_i=m,那么其对 \sum ia_i=p 这一项的贡献为 C_m^{a_0}C_{m-a0}^{a_1}\cdots C_{a_n}^{a_n} \bmod 2。根据卢卡斯定理,当且仅当 \forall i\ne j,a_i \operatorname{and} a_j = 0 时,上式为 1。因此,我们可将 f(x)^m 的第 p 项系数表示为:对于 m 的第 i 位,若其为 1,给其赋一个存在于 S 中的值 x,使其加权和 \sum 2^ix=p 的方案数的奇偶性。

注:上述转化实际上与次数扩倍的思想基本一致。

考虑如何求出 x^p 这一项的系数。我们可以数位 DP,记 f_{i,x} 表示当前为第 i 位,第 i-1 位向该位进位为 x,且最后总和为 p 的方案数奇偶性。转移时若 m 这一位为 0,则有 f_{i+1,x} \to f_{i,2x+c}c=[p \operatorname{and} 2^i]。若 m 该位为 1,则有 \forall j \in S,f_{i+1,x} \to f_{i,2x+c-j}。此时最终答案为 f_{0,0}

根据定义,进位数 x<n,故我们可将 f_{i,x} 压成 n 维向量 F_i。此时转移相当于乘以一个矩阵。注意到根据 mpi 位的取值,一共仅有 4 种不同的矩阵 Mat_{0/1,0/1}。又注意到上述向量 F_i 每一位均为 01,故我们可将其进一步压缩为一个 0 ~ 2^n-1 的状态 st

此时考虑倍增,记 G_{st,k} 表示 高位传下来的向量为 st,低位取遍 0 ~ 2^k-1 时的所有答案拼成的字符串。 显然 G_{st,0}=st \operatorname{and} 1。对于 G_{st,i},考虑向量 ls=st \times Mat_{c,0}rs=st \times Mat_{c,1}c=[m \operatorname{and} 2^i],则 G_{st,i}G_{ls,i-1}G_{rs,i-1} 拼接而成的字符串。对于本题所求,由于 k \le 18,我们可存储 res,fl,fr 分别表示 Gt 的出现次数,以及 G 的前 k 位与后 k 位。合并与询问的倍增是显然的。我们需要预处理每种向量 st \times Mat_{0/1,0/1} 的结果以加速该过程。

另外,在合并时,若我们仅仅是将 frfl 拼接起来后扫一遍,复杂度可能升至 O(2^n k\log m)。此处可以改变 fl 的定义,让其记录 所有满足 t[i,n]G 的前缀的 i 的集合。 fr 同理。此时拼接的贡献即为 \operatorname{popcount}(fl \operatorname{and} fr)

时空复杂度 O(2^n\log m),查询 [L,R] 答案的复杂度 O(\log m)

#include <bits/stdc++.h>
using namespace std;

typedef long long ll;
const int maxn = 263000;

int n, k, str, lim;
char s[20], t[20];
int fl[64][maxn], fr[64][maxn];
ll res[64][maxn], m;
int mat[20][2][2], nx[maxn][2][2], popcnt[maxn];

int mult(int st, bool o1, bool o2)
{
    int ans = 0;
    for (int i = 0; i < n; i++) ans |= ((popcnt[st & mat[i][o1][o2]] & 1) << i);
    return ans;
}

int RE, FL, FR;
void calc(ll nw, int len)
{
    RE = FL = FR = 0;
    for (int j = 1; j < k && j <= len; j++)
    {
        ll x = nw & ((1 << j) - 1);
        ll y = (str & (((1 << j) - 1) << (k - j))) >> (k - j);
        if (x == y) FL |= (1 << j);
    }
    for (int j = 0; j < k - 1 && j < len; j++)
    {
        ll x = nw >> (len - j - 1);
        ll y = str & ((1 << (j + 1)) - 1);
        if (x == y) FR |= (1 << (k - j - 1));
    }
    for (int j = k; j <= len; j++)
    {
        ll x = (nw >> (len - j));
        x &= ((1 << k) - 1);
        if (x == str) RE++;
    }
}

ll work(ll x)
{
    if (x < k - 1) return 0; x++;
    int st = 1;
    int rs = 0, ln = 0;
    ll nw = 0, ans = 0;
    for (int i = lim; ~i; i--)
    {
        if (!((x >> i) & 1)) st = nx[st][(m >> i) & 1][0];
        else
        {
            int ls = nx[st][(m >> i) & 1][0];
            if (i >= 10 || (1 << i) >= k)
            {
                ans += res[i][ls] + popcnt[rs & fl[i][ls]];
                rs = fr[i][ls];
            }
            else nw |= ((ll)fl[i][ls] << ln), ln += (1 << i);
            st = nx[st][(m >> i) & 1][1];
        }
    }
    calc(nw, ln);
    return ans + RE + popcnt[rs & FL];
}

int main()
{
    ios::sync_with_stdio(false);
    cin.tie(0);

    int T;
    cin >> T;
    for (int i = 1; i < (1 << 18); i++) popcnt[i] = popcnt[i ^ (i & -i)] + 1;
    while (T--)
    {
        memset(mat, 0, sizeof(mat));

        ll l, r; str = 0;
        cin >> n >> m >> k >> l >> r >> s >> t;
        l = m * n + 1 - l, r = m * n + 1 - r;
        swap(l, r);
        for (int i = 0; i < k; i++) str |= ((t[i] == '1') << (k - i - 1));
        int N = 1 << n;
        for (int i = 0; i <= n / 2; i++) swap(s[i], s[n - i]);

        for (int c = 0; c < 2; c++)
            for (int i = 0; i < n; i++)
                for (int j = 0; j <= n; j++)
                {
                    if (s[j] ^ '1') continue;
                    int nw = i - c + j;
                    if (nw < 0 || (nw & 1)) continue; nw >>= 1;
                    mat[i][1][c] |= (1 << nw);
                }
        for (int c = 0; c < 2; c++)
            for (int i = c; i < n; i++)
            {
                int nw = i - c;
                if (nw & 1) continue; nw >>= 1;
                mat[i][0][c] |= (1 << nw);
            }

        for (int st = 0; st < N; st++)
            for (int i = 0; i < 2; i++)
                for (int j = 0; j < 2; j++) nx[st][i][j] = mult(st, i, j);
        lim = 0;
        while ((1ll << lim) <= r || (1ll << lim) <= m) lim++;
        for (int st = 0; st < N; st++)
        {
            fl[0][st] = fr[0][st] = (bool)(st & 1);
            res[0][st] = 0;
            if (k == 1)
            {
                res[0][st] = (fl[0][st] == str);
                fl[0][st] = fr[0][st] = 0;
            }
        }

        for (int i = 1; i <= lim; i++)
        {
            bool ch = (m >> (i - 1)) & 1;
            for (int st = 0; st < N; st++)
            {
                int ls = nx[st][ch][0], rs = nx[st][ch][1];
                if (i <= 4 && (1 << i) < k)
                    fl[i][st] = fr[i][st] = (fr[i - 1][rs] << (1 << (i - 1))) | fl[i - 1][ls], res[i][st] = 0;
                else if (i >= 10 || (1 << i) >= (k << 1))
                {
                    fl[i][st] = fl[i - 1][ls], fr[i][st] = fr[i - 1][rs];
                    res[i][st] = res[i - 1][ls] + res[i - 1][rs] + popcnt[fr[i - 1][ls] & fl[i - 1][rs]];
                }
                else
                {
                    int len = 1 << i;
                    ll nw = ((ll)fr[i - 1][rs] << (1 << (i - 1))) | fl[i - 1][ls];
                    fl[i][st] = fr[i][st] = res[i][st] = 0;
                    calc(nw, len); res[i][st] = RE;
                    fl[i][st] = FL, fr[i][st] = FR;
                }
            }
        }

        if (r - l + 1 < k)
        {
            cout << "0\n";
            continue;
        }
        cout << work(r) - work(l + k - 2) << "\n";
    }
    return 0;
}