题解 P4156 【[WC2016]论战捆竹竿】

· · 题解

本题需要我们了解一个关于 border 的性质。

定理:

一个字符串的所有 border 排序后,至少存在一种将序列划分为若干子序列的方案,使得所有子序列均为等差数列,且子序列的个数为 O(\log|s|)

证明:

一个显然的结论:s[0..k - 1]s 的 border \rightleftharpoons |s| - ks 的周期。

另一个显然的结论:若 p , qs 的周期,则 \gcd(p, q) 也为 s 的周期。

以上两个结论可以套定义得到。

考虑两个 |s| 的 border : AB,其中 B 为最长的 border ( s 本身除外),|A| \geq \frac{|s|}{2}

p=|s| - |A|q=|s| - |B|

若这样的两个 border 存在,则 p , q , \gcd(p,q) 均为 s 的周期,所以 s[0..|s| - \gcd(p, q)]s 的 border,又因为 B 为最长的 border,所以 \gcd(p,q) = q

因此 q|p 。又因为 k \times q 也是 |s| 的周期,所以 s[0..|s| - k \times q - 1] 都是 s 的 border。因此所有长度大于等于字符串一半的 border 构成一个等差数列

考虑两个 |s| 的 border : AB (|A| < |B|)。

A 也为 B 的 border。

所以我们可以把 border 分成两个集合:第一个集合的 border 长度大于等于 \frac{|s|}{2},第二部分为其他。第一部分构成了一个等差数列,第二部分所有字符串为其中最长的字符串的 border,可以递归处理,至多 \log |s| 层。定理得证。

得到了这个定理,我们回到题目。

我们可以将题意转化为:\sum_{i=1}^{cnt}a_ix_i[0,w-n] 中能取到的值的个数,其中 cnt|s| 的周期个数,a 为周期长度。

由此我们想到了同余最短路

注意,\min\{a_i\}*cntO(n^2) 的,构造方法为 aaa\ldots aabaaa\ldots aa 。其中 b 处在字符串较中间的位置。

但是由于 border 拥有特殊的性质,我们可以进行优化。

我们把不同的等差数列分开处理。

首先,对于一个等差数列 x,x + d,x + 2d, \ldots x + l\times d,我们在\mod x 下跑同余最短路。其中对于 0 \leq y < xy (y + d) \mod x 连边,会形成 \gcd(x, d) 个环。每个环可以分开处理。

但我们此时发现,把等差数列分开处理的坏处就是此时没有源点,如果用环上任何一点去优化另外一点,需要 \Omega(n^2) ,这样复杂度很劣。

但是我们又可以发现,环中 dis 最小的一点是绝不会被更新的,且相邻两点之间权值相同,都为 d。这两条性质,可以让我们从 dis 最小点出发,使用单调队列计算。

我们在单调队列里放入离现在处理点距离小于等于 l 的点,以 dis_i - pos_i * d 作为比较方式(因为环上 a 点到 b 点的代价是 -pos_a * d + pos_b * d)。这样我们就能处理环上的转移。

还有一个问题,是 dis 数组在不同模数之间的转换。若原先的模数为 las,现在的模数为 now,很显然 dis_i 可以更新 dis'_{dis_i \mod now}。但 dis_i 的含义是 dis_i + k * las (k \geq 0) 可以被访问。所以我们在\mod now 意义下再用 x 更新(x + las)\mod now 跑一边同余最短路即可。

注意到这个过程类似于单个等差数列的处理过程。同样可以把环分开处理。不过由于没有 l 的限制,不需要用到单调队列。

时间复杂度:O(Tn\log n)

代码:

#include <bits/stdc++.h>
using namespace std;

const int Maxn = 500005;
int now, ct, fail[Maxn], Q[2 * Maxn], seq[Maxn], border[Maxn], pos[Maxn];
long long f[Maxn], res[Maxn], sta[Maxn], ans, w;
string str;
void get_fail(void)
{
    int siz = str.size();
    fail[0] = fail[1] = 0;
    for (int i = 1; i < siz; i++)
    {
        int tmp = fail[i];
        while (tmp && str[tmp] != str[i]) tmp = fail[tmp];
        fail[i + 1] = str[tmp] == str[i] ? tmp + 1 : 0;
    }
    int now = fail[siz];
    while (now)
    {
        border[++ct] = siz - now;
        now = fail[now];
    }
    border[++ct] = siz;
}
void change_mod(int mod)
{
    int cnt = __gcd(mod, now);
    for (int i = 0; i < now; i++)
        res[i] = f[i];
    for (int i = 0; i < mod; i++)
        f[i] = 0x3f3f3f3f3f3f3f3fLL;
    for (int i = 0, tmp; i < now; i++)
        tmp = res[i] % mod, f[tmp] = min(f[tmp], res[i]);
    for (int i = 0; i < cnt; i++)
    {
        int top = 0;
        Q[++top] = i;
        int tmp = (i + now) % mod;
        while (tmp != Q[1])
            Q[++top] = tmp, tmp = (tmp + now) % mod;
        for (int j = top + 1; j <= 2 * top; j++)
            Q[j] = Q[j - top];
        top <<= 1;
        for (int j = 2; j <= top; j++)
            f[Q[j]] = min(f[Q[j]], f[Q[j - 1]] + now);
    }
    now = mod;
}
void work(int first, int diff, int siz)
{
    int cnt = __gcd(diff, first);
    change_mod(first);
    if (diff < 0) return ;
    for (int i = 0; i < cnt; i++)
    {
        int top = 0;
        Q[++top] = i;
        int tmp = (i + diff) % first;
        while (tmp != Q[1])
            Q[++top] = tmp, tmp = (tmp + diff) % first;
        int mini_pos = 1;
        for (int j = 1; j <= top; j++)
            if (f[Q[j]] < f[Q[mini_pos]]) mini_pos = j;
        int tmp_cnt = 0;
        for (int j = mini_pos; j <= top; j++)
            seq[++tmp_cnt] = Q[j];
        for (int j = 1; j < mini_pos; j++)
            seq[++tmp_cnt] = Q[j];
        int head = 1, tail = 1;
        pos[1] = 1, sta[1] = f[seq[1]] - diff;
        for (int j = 2; j <= top; j++)
        {
            while (head <= tail && pos[head] + siz < j) head++;
            if (head <= tail) f[seq[j]] = min(f[seq[j]], sta[head] + j * (long long) diff + first);
            while (head <= tail && sta[tail] >= f[seq[j]] - j * (long long) diff) tail--;
            sta[++tail] = f[seq[j]] - j * (long long) diff, pos[tail] = j;
        }
    }
}
int T, n;
int main()
{
    scanf("%d", &T);
    while (T--)
    {
        ans = ct = 0;
        scanf("%d%lld", &n, &w), w -= n;
        memset(f, 0x3f, sizeof(long long[n]));
        f[0] = 0;
        now = n;
        cin >> str;
        get_fail();
        for (int i = 1, j = 1; i <= ct; i = j)
        {
            while (border[j + 1] - border[j] == border[i + 1] - border[i]) j++;
            work(border[i], border[i + 1] - border[i], j - i - 1);
        }
        for (int i = 0; i < now; i++)
            if (f[i] <= w) ans += (w - f[i]) / now + 1;
        printf("%lld\n", ans);
    }
    return 0;
}