CF1864G Magic Square

· · 题解

*CF1864G Magic Square

太酷了。

首先,因为一个数至多被操作两次,且每行每列至多被操作一次,所以一个数所在的行被操作之后,这个数一定在正确的列上。否则这个数还需要两次操作才能归位。同理,如果一个数所在的列被操作,那么操作完这个数一定在正确的行上。

题目还给出了被操作两次的两个数的偏移向量不能相同,这是容易检查的。一般而言,容易检查的限制需要当成题目给出的条件,帮助我们推出更多性质。考虑这个条件怎么用。

如果一个数要被操作两次,那么至少一个行操作和列操作会被执行。在此基础上,如果两个数的偏移向量相同,那么考虑这两个数移动的过程,有至少三次操作对它们当中任意一个产生影响(否则这两个数相同)。这说明存在两个偏移量相同的行,或者两个偏移量相同的列。反之,如果存在两个偏移量相同的行(列),那么可以证明存在两个数的偏移向量相同:考虑第 x_1x_2 行的偏移量相同,均为 \Delta y,那么无论第 x_1 行和第 y 列哪个先被操作,总存在偏移向量为 (\Delta x, \Delta y) 的数。对于第 x_2 行和第 y 列同理。如果这两个数相同,那么这个数被操作了三次。因此这两个数不同,也就是存在两个偏移向量相同的数。无论哪种情况均不合法。

这说明:如果至少一行和至少一列被操作,那么所有被操作的行的偏移量不同,被操作的列的偏移量不同

r_i 表示第 i 行的偏移量,c_j 表示第 j 列的偏移量,那么所有 r_i > 0r_i 互不相同,且 c_j > 0c_j 互不相同。

接下来的目标是求出所有 r_ic_j。以 r_i 为例:若第 i 行存在一个数没有改变它的行,那么显然 r_i 只能等于这个数的列偏移量。否则第 i 行所有数均改变了它所在的行。也就是 c_{1\sim n} 均不为 0。而 1\leq c_j < n,所以存在相同且不为 0c_j,这说明所有行均不能被操作,否则就和刚才推出的性质矛盾了,因此 r_i = 0。这说明每一行和每一列的偏移量是定值,不随操作而变化。

性质推完了,信息也处理完了,回归原问题。设一行是合法的,当且仅当它没有被操作过,且 r_i > 0,且操作后所有数在正确的列上。同理可设一列合法。如果任意时刻合法的仅有行或列中的一种,那么这些行(列)之间就是独立的,贡献一个阶乘的方案数。

如果第 i 行和第 j 列同时合法,那么考虑 a_{i, j},如果先操作第 i 行,且第 j + r_i 列的偏移量可以使得 a_{i, j} 归位,那么因为先操作第 j 列也可以让 a_{i, j} 到达正确的行,所以第 j 列和第 j + r_i 列的偏移量相同,矛盾。同理,先操作第 j 列也可以导出矛盾。这说明,不存在行和列同时合法。考虑当前所有合法行(列),显然一次行操作不会产生新的合法行,所以一定是全部操作完后产生若干合法列,行列交替不断操作直到没有合法行(列),此时检查一下所有数是否归位。若否则无解,若是则根据乘法原理,答案为若干阶乘的积。

由上述分析可知计算 r, c 并不必要,但可以帮助我们更好地理解问题。

时间复杂度 \mathcal{O}(n ^ 3)\mathcal{O}(n ^ 2)

#include <bits/stdc++.h>
using namespace std;
using ll = long long;
using ui = unsigned int;
using pii = pair<int, int>;
using pll = pair<ll, ll>;
using pdi = pair<double, int>;
using ull = unsigned long long;
using vint = vector<int>;
using vll = vector<ll>;

#define ppc(x) __builtin_popcount(x)
#define clz(x) __builtin_clz(x)

bool Mbe;
// mt19937 rnd(chrono::steady_clock::now().time_since_epoch().count());
mt19937 rnd(20060130);
int rd(int l, int r) {return rnd() % (r - l + 1) + l;}

constexpr int mod = 998244353;
void addt(int &x, int y) {x += y, x >= mod && (x -= mod);}
int add(int x, int y) {return x += y, x >= mod && (x -= mod), x;}
int ksm(int a, int b) {
  int s = 1;
  while(b) {
    if(b & 1) s = 1ll * s * a % mod;
    a = 1ll * a * a % mod, b >>= 1;
  }
  return s;
}

char buf[1 << 20], *p1 = buf, *p2 = buf;
#define getc() (p1 == p2 && (p2 = (p1 = buf) + \
  fread(buf, 1, 1 << 20, stdin), p1 == p2) ? EOF : *p1++)
inline int read() {
  int x = 0;
  char s = getc();
  while(!isdigit(s)) s = getc();
  while(isdigit(s)) x = x * 10 + s - '0', s = getc();
  return x;
}

#define putc(x) putchar(x)
inline void print(ui x) {
  if(x >= 10) print(x / 10);
  putc(x % 10 + '0');
}

// ---------- templates above ----------

constexpr int N = 500 + 5;

int n, a[N][N], b[N][N], fc[N];
int vis[N][N], r[N], c[N];
pii pos[N * N];
void mian() {
  cin >> n;
  for(int i = 1; i <= n; i++) {
    for(int j = 1; j <= n; j++) {
      cin >> a[i][j];
      vis[i][j] = 0; 
    }
  }
  for(int i = 1; i <= n; i++) {
    for(int j = 1; j <= n; j++) {
      cin >> b[i][j];
      pos[b[i][j]] = make_pair(i, j);
    }
  }
  for(int i = 1; i <= n; i++) {
    for(int j = 1; j <= n; j++) {
      pii p = pos[a[i][j]];
      int dx = (p.first + n - i) % n;
      int dy = (p.second + n - j) % n;
      if(dx && dy) {
        if(vis[dx][dy]) {
          cout << "0\n";
          return;
        }
        vis[dx][dy] = 1;
      }
    }
  }

  int cntr = 0, cntc = 0;
  for(int i = 1; i <= n; i++) {
    map<int, int> mp;
    for(int j = 1; j <= n; j++) mp[b[i][j]] = j;
    r[i] = -1;
    for(int j = 1; j <= n; j++) {
      auto it = mp.find(a[i][j]);
      if(it == mp.end()) continue;
      int shif = (it->second + n - j) % n;
      if(r[i] == -1) r[i] = shif;
      else if(r[i] != shif) {
        cout << "0\n";
        return;
      }
    }
    if(r[i] == -1) r[i] = 0;
    cntr += r[i] > 0;
  }
  for(int j = 1; j <= n; j++) {
    map<int, int> mp;
    for(int i = 1; i <= n; i++) mp[b[i][j]] = i;
    c[j] = -1;
    for(int i = 1; i <= n; i++) {
      auto it = mp.find(a[i][j]);
      if(it == mp.end()) continue;
      int shif = (it->second + n - i) % n;
      if(c[j] == -1) c[j] = shif;
      else if(c[j] != shif) {
        cout << "0\n";
        return;
      }
    }
    if(c[j] == -1) c[j] = 0;
    cntc += c[j] > 0;
  }

  int ans = 1, cur = 0, lstc = -1;
  vint vr(n + 1), vc(n + 1);
  for(int i = 1; i <= n; i++) vr[i] = r[i] == 0;
  for(int i = 1; i <= n; i++) vc[i] = c[i] == 0;
  while(1) {
    int cnt = 0;
    if(cur == 0) {
      for(int i = 1; i <= n; i++) {
        if(vr[i]) continue;
        bool ban = 0;
        for(int j = 1; j <= n; j++) {
          if((pos[a[i][j]].second + n - j) % n != r[i]) ban = 1;
        }
        if(!ban) {
          cnt++, vr[i] = 1;
          static int d[N];
          for(int j = 1, p = 1 + r[i]; j <= n; j++) {
            d[p] = a[i][j];
            if(++p > n) p -= n;
          }
          memcpy(a[i], d, sizeof(d));
        }
      }
    }
    else {
      for(int j = 1; j <= n; j++) {
        if(vc[j]) continue;
        bool ban = 0;
        for(int i = 1; i <= n; i++) {
          if((pos[a[i][j]].first + n - i) % n != c[j]) ban = 1;
        }
        if(!ban) {
          cnt++, vc[j] = 1;
          static int d[N];
          for(int i = 1, p = 1 + c[j]; i <= n; i++) {
            d[p] = a[i][j];
            if(++p > n) p -= n;
          }
          for(int i = 1; i <= n; i++) a[i][j] = d[i];
        }
      }
    }
    if(!cnt && !lstc) break;
    cur ^= 1, lstc = cnt;
    ans = 1ll * ans * fc[cnt] % mod;
  }
  bool ok = 1;
  for(int i = 1; i <= n; i++) {
    for(int j = 1; j <= n; j++) {
      ok &= a[i][j] == b[i][j];
    }
  }
  if(!ok) ans = 0;
  cout << ans << "\n";
}

bool Med;
int main() {
  fprintf(stderr, "%.3lf MB\n", (&Mbe - &Med) / 1048576.0);
  for(int i = fc[0] = 1; i < N; i++) fc[i] = 1ll * fc[i - 1] * i % mod;
  ios::sync_with_stdio(0), cin.tie(0), cout.tie(0);
  int T = 1;
  cin >> T;
  while(T--) mian();
  fprintf(stderr, "%d ms\n", int(1e3 * clock() / CLOCKS_PER_SEC));
  return 0;
}