LG P10061. [SNOI 2024] 矩阵 对修改序列分块

· · 题解

该算法的时间复杂度可能有问题。

注意到 8 s 时限,容易想到根号;注意到只用求最终结果,感觉难以直接维护,需要离线。

对修改序列分块,设块长为 B。处理出每个位置经过一块修改后的变换(到哪、加多少):

尽管旋转能使横线和竖线(切口)参差不齐,但预计每块会被切成略少于 4B^2 个矩形(官方数据是这样的)。感觉可以更多,但好像多不到哪去,不妨认为是大常数 O(B^2) 的。

时间复杂度为 $O\left(n^2\dfrac{q}{B}+qB^2\right)$,取 $B=n^{\frac{2}{3}}$ 时为 $O\left(n^{\frac{4}{3}}q\right)$。

因为后半部分常数大,B 应取小几倍。取 $B\in[50,60]$ 时,官方数据跑了 3 s 左右。

#include <bits/stdc++.h>

// Capacity of every per-operation array and of each matrix dimension.
const int N = 3005;
// Modulus applied to the pending addends and to the final answer.
const int Mo = 1e9 + 7;

// Matrix side length and number of operations, as read from the input.
int n, q;

// Operation i (1-based): type op[i], rectangle corners (x[i], y[i]) and
// (xx[i], yy[i]); d[i] is only read from the input when op[i] == 2.
int op[N];
int x[N], y[N];
int xx[N], yy[N];
int d[N];

// Two matrix buffers (1-based indices): one holds the matrix before the
// current chunk of operations, the other receives the matrix after it.
int a[2][N][N];
// Affine map on matrix coordinates together with a pending addend:
//   X' = xkx * X + xky * Y + xb
//   Y' = ykx * X + yky * Y + yb
// `add` is the amount (kept modulo Mo) to add to every cell the map applies to.
// A default-constructed Trans is the identity with add == 0.
struct Trans {
  short xkx = 1, xky = 0, xb = 0;
  short ykx = 0, yky = 1, yb = 0;
  int add = 0;

  // Composes this map with the rotation of operation `id`, i.e. afterwards the
  // map is  (X, Y) -> (xx[id] + y[id] - Y', y[id] - x[id] + X')  where (X', Y')
  // is the image under the previous map.
  void rotate(int id) {
    const int pivotSum = xx[id] + y[id];
    const int pivotDiff = y[id] - x[id];
    // The old X row is overwritten first, so keep a copy for the new Y row.
    const short oldXkx = xkx;
    const short oldXky = xky;
    const short oldXb = xb;
    xkx = -ykx;
    xky = -yky;
    xb = pivotSum - yb;
    ykx = oldXkx;
    yky = oldXky;
    yb = pivotDiff + oldXb;
  }

  // Accumulates the addend d[id] of operation `id`, modulo Mo.
  void addnum(int id) {
    add += d[id];
    add %= Mo;
  }

  // Applies the map in place to the point (x, y).
  void trans(int &x, int &y) {
    const int oldX = x;
    x = xkx * oldX + xky * y + xb;
    y = ykx * oldX + yky * y + yb;
  }

  // Applies the inverse map in place to the point (x, y). Only the two shapes
  // produced by rotate() are supported: both off-diagonal coefficients zero,
  // or both diagonal coefficients zero.
  void itrans(short &x, short &y) {
    const bool diagonal = (xky == 0 && ykx == 0);
    const bool antiDiagonal = (xkx == 0 && yky == 0);
    if (diagonal) {
      x = (x - xb) / xkx;
      y = (y - yb) / yky;
      return;
    }
    if (antiDiagonal) {
      // X' depends only on Y and Y' only on X, so the inverse swaps the roles.
      const int oldX = x;
      x = (y - yb) / ykx;
      y = (oldX - xb) / xky;
      return;
    }
    assert(false); // Never happen.
  }
};
// Axis-aligned rectangle of matrix cells with inclusive bounds:
// rows xl..xr and columns yl..yr. Bounds are stored as short.
struct Block {
  short xl, xr, yl, yr;

  // Builds the rectangle from int bounds, converting each one to short.
  Block(int xLow, int xHigh, int yLow, int yHigh) {
    xl = static_cast<short>(xLow);
    xr = static_cast<short>(xHigh);
    yl = static_cast<short>(yLow);
    yr = static_cast<short>(yHigh);
  }
};
// Double buffer of (rectangle, transform) pairs used by main(): vec[0] holds
// the current partition of the matrix, vec[1] collects the refined partition
// produced by the next operation, after which the two are swapped.
std::vector<std::pair<Block, Trans>> vec[2];

int main() {
  scanf("%d%d", &n, &q);
  for (int i = 1; i <= n; ++i) for (int j = 1, res = 1; j <= n; ++j)
    a[0][i][j] = res = 1ll * res * (i + 1) % 998244353;
  for (int i = 1; i <= q; ++i) {
    scanf("%d%d%d%d%d", &op[i], &x[i], &y[i], &xx[i], &yy[i]);
    if (op[i] == 2) scanf("%d", &d[i]);
  }
  for (int B = __, L = 1, R; L <= q; L += B) {
    R = std::min(L + B - 1, q);
    vec[0].emplace_back(Block(1, n, 1, n), Trans());
    for (int i = L; i <= R; ++i) {
      for (auto it: vec[0]) {
        Block blk = it.first; Trans trs = it.second;
        std::vector<std::pair<short, short>> vx, vy;
        if (xx[i] < blk.xl || blk.xr < x[i] || (x[i] <= blk.xl && blk.xr <= xx[i]))
          vx.emplace_back(blk.xl, blk.xr);
        else if (x[i] <= blk.xl && xx[i] < blk.xr)
          vx.emplace_back(blk.xl, xx[i]), vx.emplace_back(xx[i] + 1, blk.xr);
        else if (blk.xl < x[i] && blk.xr <= xx[i])
          vx.emplace_back(blk.xl, x[i] - 1), vx.emplace_back(x[i], blk.xr);
        else vx.emplace_back(x[i], xx[i]),
             vx.emplace_back(blk.xl, x[i] - 1), vx.emplace_back(xx[i] + 1, blk.xr);
        if (yy[i] < blk.yl || blk.yr < y[i] || (y[i] <= blk.yl && blk.yr <= yy[i]))
          vy.emplace_back(blk.yl, blk.yr);
        else if (y[i] <= blk.yl && yy[i] < blk.yr)
          vy.emplace_back(blk.yl, yy[i]), vy.emplace_back(yy[i] + 1, blk.yr);
        else if (blk.yl < y[i] && blk.yr <= yy[i])
          vy.emplace_back(blk.yl, y[i] - 1), vy.emplace_back(y[i], blk.yr);
        else vy.emplace_back(y[i], yy[i]),
             vy.emplace_back(blk.yl, y[i] - 1), vy.emplace_back(yy[i] + 1, blk.yr);
        for (auto xi: vx) for (auto yi: vy) {
          Block _blk = Block(xi.first, xi.second, yi.first, yi.second);
          Trans _trs = trs;
          if (x[i] <= _blk.xl && _blk.xr <= xx[i] && y[i] <= _blk.yl && _blk.yr <= yy[i]) {
            if (op[i] == 1) {
              _trs.rotate(i);
              _blk.xl = xx[i] + y[i] - _blk.yr, _blk.xr = xx[i] + y[i] - _blk.yl;
              _blk.yl = y[i] - x[i] + xi.first, _blk.yr = y[i] - x[i] + xi.second;
            }
            else _trs.addnum(i);
          }
          vec[1].emplace_back(_blk, _trs);
        }
      }
      vec[0].clear(), std::swap(vec[0], vec[1]);
    }
    for (auto it: vec[0]) {
      Block blk = it.first; Trans trs = it.second;
      trs.itrans(blk.xl, blk.yl), trs.itrans(blk.xr, blk.yr);
      if (blk.xl > blk.xr) std::swap(blk.xl, blk.xr);
      if (blk.yl > blk.yr) std::swap(blk.yl, blk.yr);
      for (int i = blk.xl; i <= blk.xr; ++i) for (int j = blk.yl; j <= blk.yr; ++j) {
        int x = i, y = j; trs.trans(x, y); a[1][x][y] = (a[0][i][j] + trs.add) % Mo;
      }
    }
    vec[0].clear(), std::swap(a[0], a[1]);
  }
  long long ans = 0;
  for (int i = 1, res = 1; i <= n; ++i) for (int j = 1; j <= n; ++j)
    res = 12345ll * res % Mo, ans = (ans + 1ll * a[0][i][j] * res % Mo);
  printf("%lld", ans % Mo);
  return 0;
}