题解:AT_arc158_f [ARC158F] Random Radix Sort

· · 题解

还差一步啊,还是差了一些!

面对这个问题分析充要条件,我们将给定的序列倒序那么相当于一个分别以 K_1, K_2, \dots, K_M 作为最高到最低关键字的稳定排序。很显然,重复出现的 K 只有第一个会起作用,只要知道这个去重后的结果,而之后出现的一定可以用组合数计算。

先考虑计算去重后的结果,这个时候则必须要涉及到有关 A,B 次序相关的手段。因为是稳定排序,所以可以唯一确定两个排列中数的对应关系。这里我自己思考的时候犯了一个错,我以为对于任意 i < j 都要确定先后关系才能唯一确定,实际上不是的,只要确定对于 i, (i+1) 的先后关系即可。可以把确定出来的关系看作拓扑序的约束,这样就很显然只要知道一条链的约束就可以唯一确定了。

具体而言,令 o_i 为在原序列中的次序。

产生了 O(n) 个约束。这样的计数看上去很难。但是注意到只有“强制规定 S_1 中至少出现过一个”和上面几个约束本质不同。直接状压对于 f[S]S 内满足其它几个约束的排列数。对着最后的 S 判断“至少在 S_1 中出现过一次”的约束即可。现在约束都形如 (S_i, T_i)S_i 最早出现的一定在 T_i 之前。考虑 f[x] 转移到 f[S + \{x\}],如果存在一个 x\in T_jS_i \cap S = \empty,那么不合法。S_i \cap S = \empty \iff S_i\subseteq (U - S),反过来我们选择判断此时是否存在 x\in T_i。这是一个高维前缀或的形式,可以预处理。而“强制规定 S_1 中至少出现过一个”的约束也可以用类似的高维前缀和解决。

现在问题变成了有 t 个不同的数,约束好了第一次出现的顺序,问构造成长度为 m 的序列方案数。设 g[i, j] 为填了前 i 个数目前出现了 j 个数,那么转移很显然是 g[i, j] = jg[i - 1, j] + g[i - 1, j - 1]。我们惊喜的发现这个东西的递推式和初值和第二类斯特林数一模一样!所以这个东西就是 S(m, t) = \sum\limits_{i = 0}^t \dfrac{(-1)^{t - i}i^m}{(t - i)!i!},就可以快速计算了。

#include <bits/stdc++.h>
#define ll long long
using namespace std;
const int N = 2e5, K = 18;
const int Mod = 998244353;
void upd(int &x, int y) {
    x = ((x + y >= Mod) ? (x + y - Mod) : (x + y));
}
int qpow(int n, int m) {
    int res = 1;
    while(m) {
        if(m & 1) res = 1ll * res * n % Mod;
        n = 1ll * n * n % Mod;
        m >>= 1;
    }
    return res;
}
int n, m, qk;
struct node {
    int type, S1, S2;
} Q[N + 10];
int w[(1 << K) + 10], wc[K + 3][(1 << K) + 10], wf[(1 << K) + 10], tmp[(1 << K) + 10];

ll cp[N + 10], len = 0; vector <int> pos[N + 10];
ll a[N + 10], b[N + 10]; int o[N + 10];

void fwt(int *arr, int n) {
    for(int k = 1; k < (1 << n); k <<= 1) {
        for(int i = 0; i < (1 << n); i += (k << 1)) {
            for(int j = 0; j < k; j++) {
                ll rest = arr[i + j];
                arr[i + j] = rest % Mod;
                arr[i + j + k] = ((rest + arr[i + j + k]) % Mod + Mod) % Mod;
            }
        }
    }
}

int f[(1 << K) + 10];
int sum[K + 10], fac[K + 10], invf[K + 10];
int main() {
    ios::sync_with_stdio(0);
    cin.tie(0), cout.tie(0);
    cin >> n >> m >> qk;
    for(int i = 1; i <= n; i++) cin >> a[i], cp[i] = a[i];
    sort(cp + 1, cp + n + 1); len = unique(cp + 1, cp + n + 1) - cp - 1;
    for(int i = 1; i <= n; i++) {
        cin >> b[i];
        a[i] = lower_bound(cp + 1, cp + len + 1, a[i]) - cp;
        pos[a[i]].push_back(i);
    }
    for(int i = n; i >= 1; i--) {
        int d = lower_bound(cp + 1, cp + len + 1, b[i]) - cp;
        o[i] = pos[d].back();
        pos[d].pop_back();
    }

    int U = (1 << qk) - 1;
    for(int i = 1; i < n; i++) {
        int t1[20], t2[20], len1 = 0, len2 = 0;
        for(int j = 1; j <= qk; j++) t1[j] = t2[j] = 0;

        ll x; x = b[i]; while(x) t1[++len1] = x % 10, x /= 10;
        x = b[i + 1]; while(x) t2[++len2] = x % 10, x /= 10;

        if(o[i] < o[i + 1]) Q[i].type = 1;
        else Q[i].type = 2;
        int S1 = 0, S2 = 0;
        for(int j = 1; j <= qk; j++)
            if(t1[j] < t2[j]) S1 |= (1 << (j - 1));
            else if(t1[j] > t2[j]) S2 |= (1 << (j - 1));

        Q[i].S1 = S1, Q[i].S2 = S2;
        w[S1] |= S2;
        if(Q[i].type == 2) wf[S1]++;
    }

    for(int i = 0; i < (1 << qk); i++) {
        for(int j = 0; j < qk; j++)
            if((w[i] >> j) & 1) wc[j][i] = 1;
    }
    for(int j = 0; j < qk; j++) {
        for(int i = 0; i < (1 << qk); i++) tmp[i] = wc[j][i];
        fwt(tmp, qk);
        for(int i = 0; i < (1 << qk); i++) wc[j][i] = tmp[i];
    }
    fwt(wf, qk); reverse(wf, wf + (1 << qk));

    f[0] = 1;
    for(int S = 0; S < (1 << qk); S++) {
        for(int x = 0; x < qk; x++) {
            if(!((S >> x) & 1)) {
                if(!wc[x][U - S])
                    upd(f[S + (1 << x)], f[S]);
            }
        }
    }

    fac[0] = 1; for(int i = 1; i <= qk; i++) fac[i] = 1ll * fac[i - 1] * i % Mod;
    invf[qk] = qpow(fac[qk], Mod - 2); for(int i = qk - 1; i >= 0; i--) invf[i] = 1ll * invf[i + 1] * (i + 1) % Mod;
    for(int t = 0; t <= qk; t++) {
        for(int i = 0; i <= t; i++) {
            int t1 = (((t - i) % 2 == 0) ? (1) : (Mod - 1));
            int t2 = qpow(i, m);
            int t3 = 1ll * invf[t - i] * invf[i] % Mod;
            upd(sum[t], 1ll * t1 * t2 % Mod * t3 % Mod);
        }
    }

    int ans = 0;
    for(int S = 0; S < (1 << qk); S++)
        if(!wf[S]) upd(ans, 1ll * f[S] * sum[__builtin_popcount(S)] % Mod);
    cout << ans << '\n';
}