题解:P1527 [国家集训队] 矩阵乘法

· · 题解

真小波矩阵(换行)就该套小波矩阵.jpg。

如果你还不知道什么是小波矩阵,欢迎阅读我的专栏 浅谈 Wavelet Matrix。

在 Moeebius 老师的 题解 中,使用了权值线段树套小波树,实际上的实现介于小波树和小波矩阵之间。由于小波矩阵的底层进行了压位,导致在线段树的叶子结点空间表现非常差。既然小波树把每一层合起来压成了小波矩阵,那么我们尝试使用小波矩阵套小波矩阵,以避免掉上文的问题。

下文以及代码实现使用 0-index。

考虑把大小 n \times n 的矩阵 a'_{i,j} 拆成长 n^2 的二元组序列 a_{in+j}=(a'_{i,j},j),再把查询也改为 x_1=x_1'n+y_1,x_2=x_2'n+y_2。这样就变成了查询集合 \{a_{i,0} \mid i \in [x_1,x_2] \wedge a_{i,1} \in [y_1,y_2]\} 的第 k 小,这是一个类三维数点问题。

我们按 a_{i,0} 构建外层小波矩阵,每层内再按 a_{i,1} 构建内层小波矩阵。查询时相当于在外层二分答案,内层做二维数点。推广到高维是容易的。

实现上有几个小问题:内层的 stable_partition 不应改变外层数组,以及内层的小波矩阵应在外层进行 stable_partition 之后再构建。

于是做完了,时空复杂度 O((n^2+q)\log^2{n})-O(n^2\log{n})。把小波矩阵层数开满外层 30 内层 18 的话可能比较卡常,离散化开到外层 18 内层 14 就行。完全不用担心压位数据结构 n \le \omega 后时空爆炸的问题。

提交记录 以及核心代码:

struct bits
{
    int n;
    vector <pair <ULL, int>> bit;
    bits() {}
    bits(vector <ULL> arr)
    {
        n = (arr.size() >> 6) + 1, bit.resize(n);
        for (int i : viota(0, ssize(arr))) bit[i >> 6].first |= arr[i] << (i & 63);
        for (int i : viota(1, n)) bit[i].second = bit[i - 1].second + __builtin_popcountll(bit[i - 1].first);
    }
    int ask(int k) {return bit[k >> 6].second + __builtin_popcountll(bit[k >> 6].first & ((1ull << (k & 63)) - 1));}
};

struct wav1
{
    int n;
    array <bits, 14> bit;
    array <int, 14> pos;
    wav1() {}
    wav1(vector <int> arr)
    {
        n = arr.size();
        for (int u : viota(0, 14) | vreve)
        {
            bit[u] = bits (arr | vtran([&] (auto x) {return x >> u & 1ull;}) | ranges::to <vector> ());
            pos[u] = n - ranges::stable_partition(arr, [&] (int x) {return ~x >> u & 1;}).size();
        }
    }
    int ask(int l, int r, int x, int y)
    {
        int res = 0;
        int lx = l, rx = r, ly = l, ry = r;
        for (int u : viota(0, 14) | vreve)
        {
            if (int ll = bit[u].ask(lx), rr = bit[u].ask(rx); ~x >> u & 1) lx -= ll, rx -= rr;
            else res -= rx - rr - lx + ll, lx = pos[u] + ll, rx = pos[u] + rr;
            if (int ll = bit[u].ask(ly), rr = bit[u].ask(ry); ~y >> u & 1) ly -= ll, ry -= rr;
            else res += ry - rr - ly + ll, ly = pos[u] + ll, ry = pos[u] + rr;
        }
        return res;
    }
};

struct wav2
{
    int n;
    array <bits, 18> bit;
    array <wav1, 18> wav;
    array <int, 18> pos;
    wav2(vector <array <int, 2>> arr)
    {
        n = arr.size();
        for (int u : viota(0, 18) | vreve)
        {
            bit[u] = bits (arr | vtran([&] (auto x) {return x[1] >> u & 1ull;}) | ranges::to <vector> ());
            pos[u] = n - ranges::stable_partition(arr, [&] (auto x) {return ~x[1] >> u & 1;}).size();
            wav[u] = wav1 (arr | vtran([] (auto x) {return x[0];}) | ranges::to <vector> ());
        }
    }
    int ask(int l, int r, int x, int y, int k)
    {
        int res = 0;
        for (int u : viota(0, 18) | vreve)
        {
            int ll = bit[u].ask(l), rr = bit[u].ask(r);
            if (int c = wav[u].ask(l - ll, r - rr, x, y); c >= k) l -= ll, r -= rr;
            else res |= 1 << u, k -= c, l = pos[u] + ll, r = pos[u] + rr;
        }
        return res;
    }
};

void staring::mainSolve()
{
    int n, q;
    read(n, q);
    vector arr(n * n, array {0, 0});
    vector brr(n * n, 0);
    for (int x; int i : viota(0, n * n))
        read(x), arr[i] = {i % n, x}, brr[i] = x;

    rsort(brr);
    brr.erase(ranges::unique(brr).begin(), brr.end());
    for (auto &[i, x] : arr) x = rlowb(brr, x) - brr.begin();
    wav2 wav(arr);

    while (q --)
    {
        int x1, y1, x2, y2, k;
        read(x1, y1, x2, y2, k);
        x1 = (x1 - 1) * n + --y1;
        x2 = (x2 - 1) * n + y2;
        write(brr[wav.ask(x1, x2, y1, y2, k)]);
    }
}