题解:P11807 [PA 2017] 抄作业

· · 题解

为了这道题特意学了在收藏夹里吃灰三年的主席树。

题意

m 个长度为 n 的序列 a_i。给定 a_1,并且 \forall2\le i\le m 都满足 a_ia_{i-1} 只在一个位置上不同,且给定不同位置及对应 a_i 的值。求这 m 个序列按字典序排序后的序列编号。

思路

最朴素想法肯定是把每个序列单独抠出来排序。问题在于一大堆数如何比较两个序列 i, j 字典序大小。

都朴素到这个地步了你就别想着暴力比较了。我们使用一种叫做二\tiny{分}\tiny{希}的东西,二分一个位置 p 比对 i,j 两个哈希 [1,p] 部分是否不同,借此可以找到第一个不一样的地方,然后就可以比大小了;若是超出去了就说明两个相等了,比较 i,j 本身数值大小即可。

把每个序列抠出来是 O(mn) 的,排序加上比较函数本身复杂度是 O(m \log m \log n) 的。预处理前缀哈希值的空间也是 O(mn) 的。

这样你就获得了一份 O( 不能过 ) 的优秀代码。

想想是否有什么条件被我们漏掉了:那当然是相邻两个序列之间只有一个不同!我们想到了主席树的优秀性质:每次只改一个点,只动一条链,便不需要把每个序列及其哈希抠出来;而前缀哈希甚至是区间哈希刚好可以使用线段树轻松维护,我们就高兴地过掉了这道题。

时间复杂度 O((n+m)\log n+m\log m\log n),毫无卡常必要。

一些细节:

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const ll mod = 998244353;
ll n, m, a[500002], rt[500002], idx = 1, pw[500002], ans[500002];
struct tree{
    ll x, hs, ls, rs;
}tr[20000002];
void build(ll x, ll l, ll r) {
    if (l == r) return tr[x] = {a[l], a[l] * pw[n - l] % mod, 0, 0}, void(0);
    ll mid = l + r >> 1;
    tr[x].ls = ++ idx, build(idx, l, mid);
    tr[x].rs = ++ idx, build(idx,mid+1,r);
    tr[x].hs = (tr[tr[x].ls].hs + tr[tr[x].rs].hs) % mod;
}
ll update(ll x, ll l, ll r, ll p, ll c) {
    if (l == p && r == p) return tr[++ idx] = {c, c * pw[n - l] % mod, 0, 0}, idx;
    ll mid = l + r >> 1, t;
    if (mid >= p) return t = update(tr[x].ls, l, mid, p, c), tr[++ idx] = {0, (tr[t].hs + tr[tr[x].rs].hs) % mod, t, tr[x].rs}, idx;
    if (mid <  p) return t = update(tr[x].rs,mid+1,r, p, c), tr[++ idx] = {0, (tr[tr[x].ls].hs + tr[t].hs) % mod, tr[x].ls, t}, idx;
}
ll query(ll lx, ll rx, ll l, ll r) {
    if (l == r) return tr[lx].hs == tr[rx].hs ? 2 : (tr[lx].x < tr[rx].x);
    ll mid = l + r >> 1;
    if (tr[tr[lx].ls].hs != tr[tr[rx].ls].hs) return query(tr[lx].ls, tr[rx].ls, l, mid);
    return query(tr[lx].rs, tr[rx].rs,mid+1,r);
}

bool cmp(ll az, ll bz) {
    ll t = query(rt[az], rt[bz], 1, n);
    if (t == 2) return az < bz;
    return t;
}
int main() {
    ios::sync_with_stdio(0);
    cin.tie(0), cout.tie(0); 
    pw[0] = 1;
    cin >> n >> m;
    ll bs = 131; 
    rt[1] = 1;
    for (ll i = 1; i <= n; i ++ ) pw[i] = pw[i - 1] * bs % mod, cin >> a[i];
    build(1, 1, n);
    for (ll i = 1; i <= m; i ++ ) ans[i] = i;
    for (ll i = 2, p, x; i <= m; i ++ ) cin >> p >> x, rt[i] = update(rt[i - 1], 1, n, p, x);
    sort(ans + 1, ans + m + 1, cmp);
    for (ll i = 1; i <= m; i ++ ) cout << ans[i] << (i == m ? "\n" : " ");
}