AT ARC184E Accumulating Many Times

· · 题解

cnblogs。

因为 a_i 并不具有什么性质,所以尝试从操作入手。

尝试形式化的刻画这个操作: 记 F(x) = \sum\limits_{i = 1}^m A_{*, i}x^i,进行操作就是乘上 (1 + x + x^2 + \cdots)
不过这样形式太难看,考虑转化一下,转成由 A_j 变为 A_i,前缀和对应就被转成差分,就对应 F(x)(1 - x) \equiv F(x)(1 + x)\pmod 2

继续考虑叠加操作,当进行 k 次操作后,可以表示为 F(x)(1 + x)^k\bmod 2

此时这个 \bmod 2 有很好的性质,具体来说考虑 Lucas 定理:对于质数 p,有 (1 + x)^k\bmod p \equiv (1 + x^p)^{\lfloor\frac{k}{p}\rfloor}(1 + x)^{k\bmod p}\pmod p
那么取 p = 2,发现可以表示为二进制拆分的形式,即令 k = 2^{e_1} + 2^{e_2} + \cdots 2^{e_l}(e_1 < e_2 < \cdots < e_l),那么 F(x)(1 + k)^x \equiv F(x)\prod\limits_{i = 1}^l (1 + x^{2^{e_i}})\pmod 2

经过这些推导,能够发现这个转移一定是一个环,具体来说考虑若取 k = 2^e \ge m,那么 F(x)(1 + x^k) 在只保留前 m 位的情况下还会是 F(x)

那么此时 f(i, j) 就有了计算的办法:找到 A_i, A_j 对应的环,再根据在环上的具体位置计算操作次数。

此时要如何找到其对应的环呢?
这就相当于是把一个环内的元素看做一个等价类,对于这个问题,我们有很好的办法——找到代表元。
不妨认为一个环的代表元是环中字典序最小的数组。

那么如何找到这个代表元?
前文已经提供了一定的思路,考虑对操作到代表元的操作次数 k 进行二进制拆分,转而对于每个 (1 + x^{2^e}) 考虑是否要选上。

考虑到前缀的 00\cdots001 肯定不管如何操作都不会改变,记最靠前的 1p 下标,考虑如何最小化字典序。
那么接下来就需要最小化 p + 1 下标的值,所以若 a_{i, p + 1} = 1,乘上 (1 + x),就可以让 a_{i, p + 1} 变为 0;同样的,也可以操作使 a_{i, p + 2} 变为 0
但是对于 a_{i, p + 3},因为 (1 + x), (1 + x^2) 都已经被考虑过了,如果为了 a_{i, p + 3} 变为 0 就会导致 a_{i, p + 1} 或是 a_{i, p + 2} 变成 1,这明显是不优的。
于是可以知道这个过程就是依次考虑 a_{i, p + 2^0}, a_{i, p + 2^1}, a_{i, p + 2^2}\cdots,若 a_{i, p + 2^e}1,就乘上 (1 + x^{2^e}) 使其变为 0

p + 2^e > m 时明显是没有用的,所以整个过程只会经过 \log m 次。

并且这也同时求出了 A_i 变到环内代表元的操作次数,可以利用这个信息求解答案。

不过求解答案时还有个问题,可能从 A_iA_j 要跨过代表元,此时这个计算式就与环长有关,环长一定是最小的 2^e \ge m 吗?
考虑 0001,其周期其实就为 1 而不为 4,那显然是错的。
回到刚刚求解代表元的过程,当 p + 2^e > m 时没有用就是因为此时操作 2^e 次一定是相同的结果。
而对于满足该条件最小的 e,因为 p + 2^{e'}\le m(0\le e' < e),所以操作 2^{e'} 一定对应着不同的结果,那么通过二进制组合起来,操作 1\le k < 2^e 次得到的结果都不同。
所以可以知道周期就为满足 p + 2^e > m 的最小的 2^e,特别的,全 0 序列的周期为 1

那么对于周期长度为 M,分别操作 d_i, d_j 次能够操作成代表元的两个序列 A_i, A_j,可以推出 i\to j 的操作次数即为 d_i - d_j + [d_i < d_j]M
于是使用树状数组统计 \sum [d_i < d_j] 的值即可。

时间复杂度 \mathcal{O}(nm(\log m + \log n))

\log n 的原因是我是直接根据代表元排序进行的划分等价类。

代码中序列的下标是 0\sim m - 1

#include <bits/stdc++.h>

using ll = long long;

constexpr ll mod = 998244353; 
constexpr int maxn = 1e6 + 10;

int n, m;
std::vector<int> a[maxn];
int dis[maxn];

int od[maxn];

struct fenwick {
    ll sum[maxn * 2];
    inline void add(int x, ll y) {
        for (; x < maxn * 2; x += x & -x) {
            sum[x] = (sum[x] + y) % mod;
        }
    }
    inline ll query(int x) {
        ll y = 0;
        for (; x >= 1; x -= x & -x) {
            y = (y + sum[x]) % mod;
        }
        return y;
    }
} f, g;

int main() {
    scanf("%d%d", &n, &m);
    for (int i = 1; i <= n; i++) {
        a[i].resize(m);
        for (int &x : a[i]) {
            scanf("%d", &x);
        }
        int p = 0;
        while (p < m && ! a[i][p]) {
            p++;
        }

        for (int l = 1; p + l < m; l *= 2) {
            if (a[i][p + l]) {
                for (int j = m - 1; j >= p + l; j--) {
                    a[i][j] ^= a[i][j - l];
                }
                dis[i] += l;
            }
        }
    }

    std::iota(od + 1, od + n + 1, 1);
    std::sort(od + 1, od + n + 1, [&](int x, int y) {
        return a[x] == a[y] ? x < y : a[x] < a[y];
    });

    ll ans = 0;
    for (int l = 1, r = 1; l <= n; l = ++r) {
        while (r + 1 <= n && a[od[l]] == a[od[r + 1]]) {
            r++;
        }

        int p = 0;
        while (p < m && ! a[od[l]][p]) {
            p++;
        }
        int M = 1;
        while (p + M < m) {
            M *= 2;
        }

        ll cnt = 0, sum = 0;
        for (int i = l; i <= r; i++) {
            const int u = od[i];
            const ll c0 = f.query(dis[u] + 1);
            ans = (ans + cnt * dis[u] - sum + (cnt - c0) * M + mod) % mod;
            f.add(dis[u] + 1, 1);
            cnt++, sum += dis[u];
        }
        for (int i = l; i <= r; i++) {
            const int u = od[i];
            f.add(dis[u] + 1, mod - 1);
        }
    }
    printf("%lld\n", ans);

    return 0;
}