P2282题解

· · 题解

写在前面

唉,这题真的是,不愧是黑题啊!整整耗费我了两天连想带调终于做完了,也算是本蒟蒻的第一道黑题(之前的都掉紫了)。

写着题的时候,Day 1 做了这道题的弱化版,然后熬着把这道题的思路捋清了。Day 2 没起来,上课迟到,回家后开始写代码,写出来之后就是调,调了整整一下午,不知道是哪有问题,又跟题解对了对,还是没发现,当时真的是心态炸了,最后发现就是数组没初始化,真的糊了!

所以我就想写这篇题解,思路是一样的,然后增加一些我的理解。

前置知识

题目描述(戳这里查看原题)

给定一个只有数字的字符串,要求通过添加任意多个逗号(可以为 0),将原字符串拆分成一个严格单调递增的数列。多解情况下使数列最后一项最小,第一项的字典序最大(相同情况下满足第二项字典序最大,以此类推)。

测试数据有多组。

正文

根据弱化版 P1415 拆分数列,我们获得了一个 O(n^3) 的算法。这里简单概述一下。

我们定义:

对于两个数组有如下转移方程:
(定义 Num(i, j) 表示字符串中区间 [i, j] 表示的数字)

这里注意。在正向 DP 完 f 数组后,要将 f[n] 数组的前导零加上,对于 [f[n],n] 内的 g 数组要赋值为 n。之后再进行 g 数组的反向 DP。

确定优化方向

考虑优化。正向瞪眼法如果没有用,我们尝试“面向数据编程”,也就是通过我们的经验,根据题目中的数据推出我们应该拥有的复杂度。

首先,多组测试数据 O(T) 不可避免。其次,每次 DP 时一定要把 [1,n] 全部遍历一遍,根据经验这块地方也不容易优化,乘上 O(n)。现在的 O(Tn) 复杂度已经到达 2\times10^6 级别了。根据合理复杂度在 10^8 以下,最多只能带 \log 或者 \log^2。因此,我们最终算法的复杂度大约是 O(Tn\log n)

因此,能给我们优化的无非两个地方,字符串比较和 DP 时枚举 j 的转移

优化字符串比较

在我们 naive 的做法里,字符串比较大小是 O(n) 的,显然可以优化,我们来思考我们的比较步骤。

首先,在比较大小的过程中两个字符串中的前导零都是要除去的。单这一步就能把我们卡到 O(n)。字符串一直是不变的,何不把每个位置去掉前导零后的位置记录下来呢?

因此我们定义 l[i]r[i] 为字符串中 i 位置向左/右第一个不是 0 的位置。可以通过 O(n) 预处理得出,查询是 O(1) 的。

对于位置 i,如果它对应的值是 '0',那么就有 l[i] = l[i-1]r[i] = r[i+1];否则都是 i

回归字符串。我觉得所有人最初的想法应该都是和前缀和有关。比如下面这个数字串:

"1145141919810"

我们希望能够在 O(1) 知道一段区间表示的数字。因而由类似前缀和思想定义 Sum[i] 表示从头到 i 表示的数字,比如 Sum[6] = 114514

因而 Sum[i] = Sum[i-1] \times 10 + str[i]。(str[i] 表示 i 位置字符串表示的数字)

所以对于区间 [i,j] 表示的数字有 Num(i,j) = Sum[j] - Sum[i-1]\times 10^{j-i+1}。(例如 Num(3,6) = Sum[6] - Sum[2] \times 10^4 = 114514 - 11\times 10000 = 4514

似乎柳暗花明了?甚至连前导零都没必要考虑了,太好了吧?其实不然。注意到 unsigned long long 也只能表示大概 18 位数字,而我们很有可能拆分数列的位数达到几十或上百,因此超出的位数很容易被卡掉,直接比较大小是不行的

观察上面的式子,似乎很像哈希吧?可以自己试一试,虽然不能直接比较大小,但是可以判断两数是否相等。(此时注意将 base 值改成质数,10 会被卡)

对于两个字符串(已经去除前导零),如果它们长度不相等,那么一定可以通过比较长度判断大小关系。如果它们长度相等,考虑 naive 做法,我们找到第一个同一位但是数字不相等的就可以比较大小了(也就是之前的数字都相等),此时是 O(n)

那么我们现在就是要找到两个字符串的最长公共前缀,然后比较下一位的大小即可。我们刚才演变出的“哈希”给我们提供了 O(1) 判断字符串相等的功能。进而通过观察发现最长公共前缀的长度是满足单调性的,即如果长度 l 不相等,则最长公共前缀长度一定比 l 小,那么我们就可以进行二分。

进而,我们发现去除前导零是 O(1),查找最长公共前缀是 O(\log n),比较后一位大小是 O(1)。因此我们比较字符串函数的复杂度成功从 O(n) 优化到了 O(\log n)

深入研究

还记得我们最初分析时定义的 l[i]r[i] 吗?我们可以再深入探究一下它们的性质,因为后面有用(当然可以先看后面再来这里对照)。

先举个例子:
"114000000514"
容易知道 r[4] = 10l[9] = 3 等等。

我们探究如下性质:

  1. 对于位置 p,我们想让 p 能包括它之前的所有前导零。那么此时 p 的位置变为 q。如何通过两个数组表示 q
    对此我们分类讨论。

    • p 前有前导零,则 p-1 的位置一定是 0,要找到最前的位置,则可以通过找到这一串 0 前第一个不是 0 的位置,向右 +1 即可。
      我们推得有 q = l[p-1]+1
    • p 前没有前导零,则 q = p。但为了让它的性质普遍,我们验证上一种情况的公式是否适用。有我们设定的情况,l[p-1] = p-1
      因而仍满足 q = l[p-1]+1 = p
  2. 对于位置 p,它所表示的不确定是否 0。我们已经求得 r[p] = R_0,现在要逆推出可能的最靠前 p 的位置 q。(即使 q 能包含 p 的所有前导零)
    我们通过性质 1 举一反三。

    • p 表示 0。则可以确定 R_0-1 一定表示 0,且 [p,R_0-1] 是一个零串。我们想让 p 包含它前面所有前导零,
      根据性质 1 我们得出 q = l[R_0-1]+1
    • p 不表示 0。则 r[p] = R_0 = p。同样通过性质 1,有 q = l[R_0-1]+1

有了这些性质,我们就可以继续思考了。

优化 DP 转移

考虑之前的转移方法,填表法查之前的状态再一个一个比对是否合法更新答案。DP 的转移方法有时候会成为优化的关键。

查之前的状态因为要一次次比对,贡献可能不连续,所以不能通过诸如单调队列等方法优化。我们不妨改变一下转移方法,用前面更新后面(就是所谓的刷表法),我们尝试一下看看前面对后面有贡献的区间是否满足连续性,连续的话就可以通过一些方法优化了。

f 为例。联想之前的转移方程。jf[i] 有贡献当且仅当 Num(f[j],j)<Num(j+1,i)。所以对于 j,它能有贡献的 i 有哪些?

回想比较函数,只要 Num(j+1,i) 长度更大,则 j 一定能更新 i。更小一定不可以,相等则需要跑一遍比较函数。因而满足 i-(j+1)+1≥j-f[j]+1。注意考虑前导零的情况,所以严谨写出不等式为 i-r[j+1]+1≥j-r[f[j]]+1
整理得:i≥j-r[f[j]]+r[j+1]。(等号需要特判)

因此,j 所能有贡献的范围是一段连续的范围(最大到 n)。最后每个位置的答案是之前有贡献的最大值,那这就变成了一个可以区间赋值的最大值问题。可以通过线段树优化

进而,我们转移的复杂度就从 O(n^2) 变成了 O(n\log n)

转移 g 时略显繁琐,需要用到我们探究的性质。根据刚才的思路 j 能对 g[i] 有贡献当且仅当 (j-1)-r[i]+1 ≤ g[j] - r[j]+1,整理得 r[i] ≥ j-1-g[j]+r[j]

不太妙。我们确定了 r[i] 的范围,却无法确定 i。很显然,i 前面有多少前导零对 r[i] 无影响,所以我们应该让 i 尽可能的小。这不就是我们的性质 2 吗?因此我们确定了最左端 i ≥ l[j-1-g[j]+r[j]-1]+1。等号需要特判。如果不满足,则原来的 r[i] = r[i]+1。因而求得的 i = r[i]+1。注意此时贡献的范围是 [i,j-1]

(我写题解的时候脑抽了一个想法,为什么更新 f 的时候不用 l[i-1]+1?注意到右端点是不配拥有前导零的。

时间复杂度

这里主要是区分字符串比较和线段树的两个 \log n 是乘法还是加法。注意不是在线段树操作时进行比较函数,所以不满足乘法原理,两者是并行的,所以最终总时间复杂度是 O(Tn\log n)-O(n)

代码

注意事项

因为有多组测试数据,所以一定要多初始化,嫌麻烦的或者想不出来无脑 memset 就完事。否则的话牢记,比较函数二分的答案要初始化为 0,线段树初始化要将懒标记赋值为 0
剩下的看看代码细节就好了,主要是不太好调。

/* DP + Hash + SegTree */
#include <iostream>
#include <algorithm>
#include <cstring>

#define ls (id << 1)
#define rs (id << 1 | 1)
#define mid ((l + r) >> 1)

using namespace std;
typedef unsigned long long ull;
const int maxn = 2006;
const int base = 131;

ull Pow[maxn];
ull Hash[maxn];
char str[maxn];
int n;
int l[maxn], r[maxn];
int f[maxn], g[maxn];
struct SegmentTree{
    int mx, lazy;
}tr[maxn << 2];

void init(){//初始化
    Pow[0] = 1;
    for (int i = 1; i < maxn; i ++){
        Pow[i] = Pow[i-1] * base;
    }
}

void preTreat(){//每次预处理
    for (int i = 1; i <= n; i ++){
        Hash[i] = Hash[i-1] * base + (str[i] - '0');
    }

    r[n+1] = n+1;
    for (int i = 1, j = n; i <= n; i ++, j --){
        l[i] = str[i] == '0' ? l[i-1] : i, 
        r[j] = str[j] == '0' ? r[j+1] : j;
    }
}

ull getHash(int l, int r){
    return Hash[r] - Hash[l-1] * Pow[r-l+1];
}

bool is_greater(int l1, int r1, int l2, int r2){//[l1,r1] > [l2,r2] ?
    l1 = r[l1], l2 = r[l2]; //去除前导零
    int len1 = r1 - l1 + 1, len2 = r2 - l2 + 1;
    if (len1 > len2) return true;
    if (len1 < len2) return false;

    int L = 0, R = len1, Mid = 0, ans = 0;
    while (L <= R){
        Mid = (L + R) >> 1;
        if (getHash(l1, l1+Mid-1) == getHash(l2, l2+Mid-1)){
            ans = Mid;
            L = Mid + 1;
        }
        else{
            R = Mid - 1;
        }
    }
    if (ans == len1) return false; //完全一样
    int p1 = l1 + ans, p2 = l2 + ans;
    return str[p1] > str[p2];
}

void pushup(int id){
    tr[id].mx = max(tr[ls].mx, tr[rs].mx);
}

void build(int id, int l, int r, int v){
    /* 一定记得清空懒标记! */
    tr[id].lazy = 0;
    if (l == r){
        tr[id].mx = v;
        return;
    }
    build(ls, l, mid, v);
    build(rs, mid + 1, r, v);
    pushup(id);
}

void pushdown(int id){
    if (tr[id].lazy){
        int temp = tr[id].lazy;
        tr[ls].mx = max(tr[ls].mx, temp);
        tr[rs].mx = max(tr[rs].mx, temp);
        tr[ls].lazy = max(tr[ls].lazy, temp);
        tr[rs].lazy = max(tr[rs].lazy, temp);
        tr[id].lazy = 0;
    }
}

void update(int id, int l, int r, int a, int b, int v){
    if (a <= l && r <= b){
        tr[id].mx = max(tr[id].mx, v);
        tr[id].lazy = max(tr[id].lazy, v);
        return;
    }
    pushdown(id);
    if (a <= mid) update(ls, l, mid, a, b, v);
    if (b > mid) update(rs, mid+1, r, a, b, v);
    pushup(id);
}

int query(int id, int l, int r, int p){
    if (l == r){
        return tr[id].mx;
    }
    pushdown(id);
    if (p <= mid) return query(ls, l, mid, p);
    else return query(rs, mid+1, r, p);
}

int main(){
    ios::sync_with_stdio(false);
    cin.tie(0), cout.tie(0);
    init();

    while (cin >> (str + 1)){
        n = strlen(str + 1);
        preTreat();
        /* 正向DP */
        build(1, 1, n, 1);
        for (int i = 1; i <= n; i ++){
            f[i] = query(1, 1, n, i);
            int p = i - r[f[i]] + r[i+1];
            if (!is_greater(r[i+1], p, r[f[i]], i)){
                p ++;
            }
            if (p <= n) update(1, 1, n, p, n, i+1);
        }

        /* 反向DP */
        build(1, 1, n, 0);
        update(1, 1, n, l[f[n]-1]+1, n, n);
        for (int i = f[n]; i > 0; i --){
            g[i] = query(1, 1, n, i);
            int p = l[max(i-1-g[i]+r[i]-1, 0)] + 1;
            if (!is_greater(i, g[i], p, i-1)){
                p = r[p] + 1;
            }
            if (p <= i-1) update(1, 1, n, p, i-1, i-1);
        }

        /* 输出 */
        for (int i = 1; i <= n; i = g[i] + 1){
            for (int j = i; j <= g[i]; j ++){
                cout << str[j];
            }
            if (g[i] != n) cout << ',';
        }
        cout << endl;
    }

    return 0;
}

总结

反正这道题确实有水平,本蒟蒻写这道题还有它的题解也收获了很多,可能话多比较繁琐,但我觉得应该能讲得很清楚。 感谢观看!