Luogu P14031 【MX-X20-T5】「FAOI-R7」连接时光 II

· · 题解

cnblogs。

首先因为 f_S(p) 的限制都是对于前缀的图的限制,所以先来考察前缀的图的结构和变化情况。

经过手玩能够知道,对于前 i 个数的图,根据值域划分,连通块就为一些相邻的区间 。
然后在最后加入了一个数 a_{i + 1} = x(此时考虑的是相对大小)后,就相当于是加入了 [x, x + 1) 这个区间,并且把区间右端点大于 x 的区间都合并到一起。

此时能够发现,合并的一定都是这些区间里的后缀。

再结合这个限制,相当于是要求只有一个区间。
贪心的考虑,因为每次合并都是合并一段后缀,那么前 i 个数形成的连通块如果不是一个区间(不满足限制),最靠前的区间的左端点一定不为 i

于是会发现在过程中只关心最靠前的区间的右端点,那就可以考虑设计 dp 了。

f_{i, j} 表示前 i 个数,最靠前的区间的右端点是 j(前 i 个的相对顺序)的答案。

考虑如果加入了 a_{i + 1} = x(前 i + 1 个的相对顺序),此时 j 的变化是什么。
经过分讨容易知道有 j' = \begin{cases}i + 1 & 1\le x\le j\\j & j < x \le i + 1\end{cases}
那么转移就很好写出了:f_{i + 1, i + 1}\gets f_{i, j}\times \sum\limits_{k = 1}^j a_{i + 1}^{i + 1 - k}, f_{i + 1, j}\gets f_{i, j}\times \sum\limits_{k = j + 1}^{i + 1} a_{i + 1}^{i + 1 - k}
因为这都对应的是 a_{i + 1}^k 的前缀和或后缀和,所以容易优化到 \mathcal{O}(1)
当然也可以把 f_{i + 1, i + 1} 视作“不合法”的方案数,类似总和减掉合法方案数,求出 f_{i + 1, j}(j\le i)f_{i + 1, i + 1} = \sum\limits_{j = 1}^i f_{i, j}\sum\limits_{j = 0}^i a_{i + 1}^j - \sum\limits_{j = 1}^i f_{i + 1, j},这样就只需要关心 a_{i + 1}^k 的前缀和,会方便写一些。

对于 s_i = 1 的位置,要求前缀必须为一个连通块,把不合法的 f_{i, j}(j < i) 都置为 0 即可。

最后套上外层的 \sum\limits_{T\subseteq S},对于 s_i = 1 的位置,可以是为 1,就只保留 f_{i, i};也可以为 0,所有数都不变。
于是扩展到 \sum\limits_{T\subseteq S} 只需要加上一个 f_{i, i}\gets f_{i, i}\times (1 + s_i)

时间复杂度 \mathcal{O}(n^2)

#include <bits/stdc++.h>

using ll = long long;

constexpr ll mod = 998244353;
constexpr int maxn = 5000 + 10;

int n, a[maxn];
char s[maxn];
ll f[maxn];

inline void solve() {
    scanf("%d", &n);
    for (int i = 1; i <= n; i++) {
        scanf("%d", &a[i]);
    }
    scanf("%s", s + 1);

    f[1] = 1 + (s[1] - '0');
    for (int i = 2; i <= n; i++) {
        ll pw = 1, sum = 1, sumf = 0;
        for (int j = i - 1; j >= 1; j--) {
            sumf = (sumf + f[j]) % mod;
            f[j] = f[j] * sum % mod;
            pw = pw * a[i] % mod;
            sum = (sum + pw) % mod;
        }
        f[i] = sumf * sum % mod;
        for (int j = 1; j < i; j++) {
            f[i] = (f[i] - f[j] + mod) % mod;
        }
        f[i] = f[i] * (1 + s[i] - '0') % mod;
    }

    ll ans = 0;
    for (int i = 1; i <= n; i++) {
        ans = (ans + f[i]) % mod;
    }
    printf("%lld\n", ans);
}

int main() {
    int t;
    scanf("%d", &t);
    while (t--) {
        solve();
    }
    return 0;
}