P9482 [NOI2023] 字符串

· · 题解

P9482 [NOI2023] 字符串

S_i 表示 s 的从位置 i 开始的后缀,P_i 表示 s 的从位置 i 开始的前缀。

先考虑朴素的 \mathcal{O}(qn ^ 2) 暴力:枚举所有 1\leq l\leq r,检查是否有 s[i, i + l - 1] < s ^ R[i + l, i + 2l - 1],记作条件 X

我们发现,没有很好的后缀数据结构精确刻画所有 s ^ R[i + l, i + 2l - 1]。此时一般有两种思路:将限制放宽,再减去不合法的;将限制缩紧,再加上没统计到的。

相比于 “子串”,我们肯定更希望看到 “前缀” 或 “后缀”,因为前者有 \mathcal{O}(n ^ 2) 个,而后者只有 \mathcal{O}(n) 个。因此,尝试将限制改写成 S_i < P_{i + 2l - 1},记作条件 Y,再对求出的答案进行修正。

首先思考这样做对答案造成的影响:如果满足条件 X,条件 Y 显然一定被满足。因此,需要减去满足条件 Y 但不满足条件 X(记作条件 Z)的 l 的数量。

尝试推出条件 Z 对应的限制:如果 s[i, i + l - 1] > s ^ R[i + l, i + 2l - 1],则 S_i > P_{i + 2l - 1},不满足条件 Y。又因为不能满足条件 X,所以 s[i, i + l - 1] = s ^ R[i + l, i + 2l - 1],即 s[i, i + 2l - 1] 回文。再根据条件 Y,推出 S_{i + 2l} < P_{i - 1}。容易证明这是充要条件。

综上,我们将问题分成两部分:计算 S_i < P_{i + 2l - 1}l 的数量,减去 s[i, i + 2l - 1] 回文且 S_{i + 2l} < P_{i - 1}l 的数量。

第一部分相当容易:对 s + c + s ^ R + d 建出后缀数组,其中 c, d 是任意不属于字符集的分隔符,且我们认为 c, d 小于字符集的任意字符且 c > d,这是为了保证在 S_iP_j 的前缀或 P_jS_i 的前缀时正确比较两个前后缀,以及当 S_i = P_j 时得到 [S_i < P_j] = 0。那么,问题转化为有多少个结束位置形如 i + 2l - 1 的前缀,排名大于 S_i 对应的后缀。二维偏序,离线询问后按排名从大到小扫描线,遇到前缀则标记对应位置,遇到后缀则处理其对应的所有询问,形如查询区间内被标记的奇数或偶数位置数量。用两棵 BIT 维护即可。

第二部分相对困难一些。称 长度为偶数 的回文串 s[l, r] 合法,当且仅当 S_{r + 1} < P_{l - 1}。则第二部分要求合法的 s[i, i + 2l - 1] 的数量。

尝试刻画所有回文串,有两种方式:Manacher 和 PAM。因为我不会 PAM,所以考虑 Manacher,它求出了以所有位置或间隔为回文中心的最长回文半径。由于要求长度为偶数,所以本题只用到了以间隔为回文中心的最长回文半径。

接下来是本题最核心的观察:回文串具有一定对称性。考虑 i\sim i + 1 的间隔对应的所有回文串 s[i - l, i + 1 + l],设最大的 lR_i。我们发现,在比较 S_{i + l + 2}P_{i - l - 1} 时,它们具有长度为 R_i - l 的公共前缀:因为 s[i - R_i, i + 1 + R_i] 回文,所以 s[i + l + 2, i + 1 + R_i] = s ^ R[i - R_i, i - l - 1]。又因为 s_{i + R_i + 2} \neq s_{i - R_i - 1}(否则 R_i 可以更大),所以只需检查是否有 s_{i + R_i + 2} < s_{i - R_i - 1}。如果满足限制,则所有 s[i - l, i + 1 + l] 均合法,反之则均不合法。

现在问题转化为:给平面上一条斜率为 -1 的线段上所有点的权值 +1,点的坐标形如 (i - l, i + 1 + l),其中 0\leq l\leq R_i。支持给定某点 (i, i + 2r - 1),查询它下方(包括它本身)的所有点的权值之和,相当于固定横坐标 i,求有多少 i + 2l - 1 满足 s[i, i + 2l - 1] 合法。坐标变换后也是经典二维数点,直接做即可。我的做法是:纵坐标的定义为长度一半,则线段的点形如 (i - l, l + 1)0\leq l\leq R_i),查询的点形如 (i, r)。令横坐标加上纵坐标,则线段的点形如 (i + 1, l + 1)0\leq l\leq R_i),查询的点形如 (i + r, r)。这样,一条线段对查询的点的影响形如矩形 +1,且矩形没有上边界。若干次矩形加之后若干次单点查询,扫描线即可。

总时间复杂度 \mathcal{O}((n + q)\log n)

同学用了一些牛鬼蛇神方法,复杂度形如 \mathcal{O}(\frac {n ^ 2} w)\mathcal{O}(n\log ^ 2 n)n, q 同级),基本上都过了。还有 n ^ 2 过的,发怒了。

#include <bits/stdc++.h>
using namespace std;

using ll = long long;
using pii = pair<int, int>;
using pll = pair<ll, ll>;
using pdi = pair<double, int>;
using pdd = pair<double, double>;
using ull = unsigned long long;

#define ppc(x) __builtin_popcount(x)
#define clz(x) __builtin_clz(x)

bool Mbe;
// mt19937 rnd(chrono::steady_clock::now().time_since_epoch().count());
mt19937_64 rnd(1064);
int rd(int l, int r) {
  return rnd() % (r - l + 1) + l;
}

constexpr int mod = 1e9 + 7;
void addt(int &x, int y) {
  x += y, x >= mod && (x -= mod);
}
int add(int x, int y) {
  return x += y, x >= mod && (x -= mod), x;
}
int ksm(int a, int b) {
  int s = 1;
  while(b) {
    if(b & 1) s = 1ll * s * a % mod;
    a = 1ll * a * a % mod, b >>= 1;
  }
  return s;
}

constexpr int Z = 1e6 + 5;
int fc[Z], ifc[Z];
int bin(int n, int m) {
  if(n < m) return 0;
  return 1ll * fc[n] * ifc[m] % mod * ifc[n - m] % mod;
}
void init_fac(int Z) {
  for(int i = fc[0] = 1; i < Z; i++) fc[i] = 1ll * fc[i - 1] * i % mod;
  ifc[Z - 1] = ksm(fc[Z - 1], mod - 2);
  for(int i = Z - 2; ~i; i--) ifc[i] = 1ll * ifc[i + 1] * (i + 1) % mod;
}

// ---------- templates above ----------

constexpr int N = 2e5 + 5;

int n, m, q;
char s[N], t[N];
ll ans[N];

namespace SA {
  int sa[N], rk[N];
  int buc[N], ork[N], id[N];
  bool cmp(int a, int b, int w) {
    return ork[a] == ork[b] && ork[a + w] == ork[b + w];
  }
  void build(int n) {
    memset(buc, 0, N << 2);
    int m = 1 << 7, p = 0;
    for(int i = 1; i <= n; i++) buc[rk[i] = t[i]]++;
    for(int i = 1; i <= m; i++) buc[i] += buc[i - 1];
    for(int i = n; i; i--) sa[buc[rk[i]]--] = i;
    for(int w = 1; ; w <<= 1, m = p, p = 0) {
      for(int i = n - w + 1; i <= n; i++) id[++p] = i;
      for(int i = 1; i <= n; i++) if(sa[i] > w) id[++p] = sa[i] - w;
      memset(buc, 0, N << 2);
      memcpy(ork, rk, N << 2), p = 0;
      for(int i = 1; i <= n; i++) buc[rk[i]]++;
      for(int i = 1; i <= m; i++) buc[i] += buc[i - 1];
      for(int i = n; i; i--) sa[buc[rk[id[i]]]--] = id[i];
      for(int i = 1; i <= n; i++) rk[sa[i]] = cmp(sa[i - 1], sa[i], w) ? p : ++p;
      if(p == n) break;
    }
  }
}

struct BIT {
  int c[N];
  void clear() {
    memset(c, 0, N << 2);
  }
  void add(int x, int v) {
    while(x < N) c[x] += v, x += x & -x; 
  }
  int query(int x) {
    int s = 0;
    while(x) s += c[x], x -= x & -x;
    return s;
  }
  int query(int l, int r) {
    return query(r) - query(l - 1);
  }
};

namespace Part1 {
  struct dat {
    int id, l, r;
  };
  vector<dat> qu[N];
  void add(int i, int r, int id) {
    qu[SA::rk[i]].push_back({id, i + 1, i + r + r - 1});
  }

  BIT odd, eve;
  void solve() {
    odd.clear(), eve.clear();
    for(int i = m; i; i--) {
      int pos = SA::sa[i];
      if(pos <= n) {
        for(dat it : qu[i]) {
          if(it.l & 1) ans[it.id] += odd.query(it.l, it.r);
          else ans[it.id] += eve.query(it.l, it.r);
        }
      }
      else if(pos > n + 1 && pos < m) {
        pos = m - pos;
        if(pos & 1) odd.add(pos, 1);
        else eve.add(pos, 1);
      }
    }
    for(int i = 1; i < N; i++) qu[i].clear();
  }
}

namespace Part2 {
  vector<pii> qu[N], ad[N];
  void add(int i, int r, int id) {
    qu[i].push_back({i + r, id});
  }

  BIT tr;
  int R[N];
  char u[N];
  void solve() {
    tr.clear();
    int cnt = 0;
    u[0] = ',', u[cnt = 1] = '?';
    for(int i = 1; i <= n; i++) {
      u[++cnt] = s[i];
      u[++cnt] = '?';
    }
    u[++cnt] = '!';
    for(int i = 1, c = 0, r = 0; i < cnt; i++) {
      R[i] = i > r ? 1 : min(r - i + 1, R[c + c - i]);
      while(u[i - R[i]] == u[i + R[i]]) R[i]++;
      if(i + R[i] - 1 > r) c = i, r = i + R[i] - 1;
    }
    for(int i = 2; i <= n; i++) {
      int r = R[i * 2 - 1] >> 1;
      if(!r || s[i + r] >= s[i - r - 1]) continue;
      ad[i - r].push_back({i, 1});
      ad[i].push_back({i, -1});
    }
    for(int i = 1; i <= n; i++) {
      for(pii it : ad[i]) tr.add(it.first, it.second);
      for(pii it : qu[i]) ans[it.second] -= tr.query(it.first);
    }
    for(int i = 1; i < N; i++) {
      qu[i].clear();
      ad[i].clear();
    }
  }
}

void mian() {
  cin >> n >> q >> s + 1, m = 0;
  for(int i = 1; i <= n; i++) t[++m] = s[i];
  t[++m] = 57;
  for(int i = n; i; i--) t[++m] = s[i];
  t[++m] = 40;
  SA::build(m);
  for(int _ = 1; _ <= q; _++) {
    int i, r;
    cin >> i >> r;
    Part1::add(i, r, _);
    Part2::add(i, r, _);
  }
  Part1::solve();
  Part2::solve();
  for(int i = 1; i <= q; i++) {
    cout << ans[i] << "\n";
    ans[i] = 0;
  }
}

bool Med;
int main() {
  fprintf(stderr, "%.3lf MB\n", (&Mbe - &Med) / 1048576.0);
  ios::sync_with_stdio(0), cin.tie(0), cout.tie(0);
  int c, T = 1;
  cin >> c >> T;
  while(T--) mian();
  cerr << 1e3 * clock() / CLOCKS_PER_SEC << " ms\n";
  return 0;
}