[AT TTPC2015 O] 数列色ぬり 题解

· · 题解

:::info[鲜花] MX NOIP 模拟赛 T3,补题时只找到了一篇 十年前的日文题解,并且目前洛谷里这道题唯一一个讨论甚至还是提供翻译,并且也是七年前的了。这题真冷门。 :::

下文中,最长上升子序列简写为 LIS,最长下降子序列简写为 LDS。

注意到红色的元素构成了一个单调递增的子序列 p,蓝色的元素构成了一个单调递减的子序列 q。在最优情况下,若 p 恰好是 LIS,且 q 恰好是 LDS,那么答案为 |LIS| + |LDS|;否则,一定是 LIS 与 LDS 共享了一个元素,此时答案为 |LIS| + |LDS| - 1

所以答案一定是 |LIS| + |LDS| - 1|LIS| + |LDS| 中的一个。于是我们只需要找到:是否存在一组 LIS 与 LDS,使得它们的交集为空。若存在,则答案为 |LIS| + |LDS|,否则答案为 |LIS| + |LDS| - 1

直接判断是否存在有些困难,正难则反,考虑计算出一共有多少组 LIS 与 LDS,其中又有多少组 LIS 与 LDS,使得它们的交集不为空。如果每组 LIS 与 LDS 的交集都不为空,则说明不存在一组 LIS 与 LDS 交集为空,否则存在。

计算一共有多少组 LIS 与 LDS 是容易的,将 LIS 的数量与 LDS 的数量相乘即可。要计算交集不为空的 LIS 与 LDS 数量,枚举交集即可。具体地:

\sum_{i = 1}^n (包含\ a_i\ 的\ LIS\ 的数量) \times (包含\ a_i\ 的\ LDS\ 的数量)

如果上式等于 LIS 的数量与 LDS 的数量的乘积,则说明不存在一组 LIS 与 LDS 使得它们的交集为空,答案为 |LIS| + |LDS| - 1,否则存在,答案为 |LIS| + |LDS|。具体的计算过程可以通过权值线段树实现,这里不再展开。

代码实现上,计算方案数时需要取模,也许是数据不强,单模也可以过。

:::success[Code]{open}

#include <bits/stdc++.h>
#define rep(i, s, e) for (int i = s; i <= e; ++i)
#define _rep(i, s, e) for (int i = s; i >= e; --i)
#define int long long
#define pii pair<int, int>
using namespace std;

constexpr int mod = 998244353;

struct Segtree {
    pii tr[400005];

    inline pii upd(pii a, pii b) {
        if (a.first < b.first) swap(a, b);
        if (a.first == b.first) (a.second += b.second) %= mod;
        return a;
    }

    inline void build(int l, int r, int p) {
        tr[p] = {0, 0};
        if (l == r) return;
        int mid = (l + r) >> 1;
        build(l, mid, p << 1);
        build(mid + 1, r, p << 1 | 1);
    }

    inline void update(int l, int r, int p, int s, pii d) {
        if (l == r) {
            tr[p] = d;
            return;
        }
        int mid = (l + r) >> 1;
        if (s <= mid) update(l, mid, p << 1, s, d);
        else update(mid + 1, r, p << 1 | 1, s, d);
        tr[p] = upd(tr[p << 1], tr[p << 1 | 1]);
    }

    inline pii query(int l, int r, int p, int s, int t) {
        if (s > t) return {0, 0};
        if (s <= l && r <= t) return tr[p];
        int mid = (l + r) >> 1;
        pii ans = {0, 0};
        if (s <= mid) ans = upd(ans, query(l, mid, p << 1, s, t));
        if (t > mid) ans = upd(ans, query(mid + 1, r, p << 1 | 1, s, t));
        return ans;
    }
} tr;

int n, a[100005];
int li[100005], ld[100005], ri[100005], rd[100005], len[100005];
int lis, lds, tot, sum;
pii tmp;

/*
li[i] : 1 -> i 的 LIS 长度
ld[i] : 1 -> i 的 LDS 长度
ri[i] : n -> i 的 LIS 长度
rd[i] : n -> i 的 LDS 长度 
len[i] : 经过 a[i] 的 LIS / LDS 长度 
lis : LIS 总长度
lds : LDS 总长度 
tot : (LIS, LDS) 总数
sum : 交集不为空的 (LIS, LDS) 数量
*/

signed main() {
    scanf("%lld", &n);
    rep(i, 1, n) scanf("%lld", &a[i]);

    // 计算 li[i]
    tr.build(1, n, 1);
    rep(i, 1, n) {
        tmp = tr.query(1, n, 1, 1, a[i] - 1);
        (tmp.second += !tmp.first) %= mod;
        ++tmp.first;
        tr.update(1, n, 1, a[i], tmp);
        len[i] = tmp.first;
        li[i] = tmp.second;
    }
    tmp = tr.query(1, n, 1, 1, n);
    lis = tmp.first;
    tot = tmp.second;

    // 计算 rd[i]
    tr.build(1, n, 1);
    _rep(i, n, 1) {
        tmp = tr.query(1, n, 1, a[i] + 1, n);
        (tmp.second += !tmp.first) %= mod;
        ++tmp.first;
        tr.update(1, n, 1, a[i], tmp);
        len[i] += tmp.first - 1;
        rd[i] = tmp.second;
    }

    rep(i, 1, n) {
        if (len[i] != lis) {
            li[i] = rd[i] = 0;
        }
    }

    // 计算 ld[i] 
    tr.build(1, n, 1);
    rep(i, 1, n) {
        tmp = tr.query(1, n, 1, a[i] + 1, n);
        (tmp.second += !tmp.first) %= mod;
        ++tmp.first;
        tr.update(1, n, 1, a[i], tmp);
        len[i] = tmp.first;
        ld[i] = tmp.second;
    }
    tmp = tr.query(1, n, 1, 1, n);
    lds = tmp.first;
    (tot *= tmp.second) %= mod;

    // 计算 ri[i]
    tr.build(1, n, 1);
    _rep(i, n, 1) {
        tmp = tr.query(1, n, 1, 1, a[i] - 1);
        (tmp.second += !tmp.first) %= mod;
        ++tmp.first;
        tr.update(1, n, 1, a[i], tmp);
        len[i] += tmp.first - 1;
        ri[i] = tmp.second;
    }

    rep(i, 1, n) {
        if (len[i] != lds) {
            ld[i] = ri[i] = 0;
        }
    }

    // 计算 sum 
    rep(i, 1, n) {
        (sum += li[i] * ld[i] % mod * ri[i] % mod * rd[i] % mod) %= mod;
    }

    printf("%lld\n", lis + lds - (tot == sum));
    return 0;
}

:::