[AT TTPC2015 O] 数列色ぬり 题解
Getaway_Car · · 题解
:::info[鲜花] MX NOIP 模拟赛 T3,补题时只找到了一篇 十年前的日文题解,并且目前洛谷里这道题唯一一个讨论甚至还是提供翻译,并且也是七年前的了。这题真冷门。 :::
下文中,最长上升子序列简写为 LIS,最长下降子序列简写为 LDS。
注意到红色的元素构成了一个单调递增的子序列
所以答案一定是
直接判断是否存在有些困难,正难则反,考虑计算出一共有多少组 LIS 与 LDS,其中又有多少组 LIS 与 LDS,使得它们的交集不为空。如果每组 LIS 与 LDS 的交集都不为空,则说明不存在一组 LIS 与 LDS 交集为空,否则存在。
计算一共有多少组 LIS 与 LDS 是容易的,将 LIS 的数量与 LDS 的数量相乘即可。要计算交集不为空的 LIS 与 LDS 数量,枚举交集即可。具体地:
如果上式等于 LIS 的数量与 LDS 的数量的乘积,则说明不存在一组 LIS 与 LDS 使得它们的交集为空,答案为
代码实现上,计算方案数时需要取模,也许是数据不强,单模也可以过。
:::success[Code]{open}
#include <bits/stdc++.h>
#define rep(i, s, e) for (int i = s; i <= e; ++i)
#define _rep(i, s, e) for (int i = s; i >= e; --i)
#define int long long
#define pii pair<int, int>
using namespace std;
constexpr int mod = 998244353;
struct Segtree {
pii tr[400005];
inline pii upd(pii a, pii b) {
if (a.first < b.first) swap(a, b);
if (a.first == b.first) (a.second += b.second) %= mod;
return a;
}
inline void build(int l, int r, int p) {
tr[p] = {0, 0};
if (l == r) return;
int mid = (l + r) >> 1;
build(l, mid, p << 1);
build(mid + 1, r, p << 1 | 1);
}
inline void update(int l, int r, int p, int s, pii d) {
if (l == r) {
tr[p] = d;
return;
}
int mid = (l + r) >> 1;
if (s <= mid) update(l, mid, p << 1, s, d);
else update(mid + 1, r, p << 1 | 1, s, d);
tr[p] = upd(tr[p << 1], tr[p << 1 | 1]);
}
inline pii query(int l, int r, int p, int s, int t) {
if (s > t) return {0, 0};
if (s <= l && r <= t) return tr[p];
int mid = (l + r) >> 1;
pii ans = {0, 0};
if (s <= mid) ans = upd(ans, query(l, mid, p << 1, s, t));
if (t > mid) ans = upd(ans, query(mid + 1, r, p << 1 | 1, s, t));
return ans;
}
} tr;
int n, a[100005];
int li[100005], ld[100005], ri[100005], rd[100005], len[100005];
int lis, lds, tot, sum;
pii tmp;
/*
li[i] : 1 -> i 的 LIS 长度
ld[i] : 1 -> i 的 LDS 长度
ri[i] : n -> i 的 LIS 长度
rd[i] : n -> i 的 LDS 长度
len[i] : 经过 a[i] 的 LIS / LDS 长度
lis : LIS 总长度
lds : LDS 总长度
tot : (LIS, LDS) 总数
sum : 交集不为空的 (LIS, LDS) 数量
*/
signed main() {
scanf("%lld", &n);
rep(i, 1, n) scanf("%lld", &a[i]);
// 计算 li[i]
tr.build(1, n, 1);
rep(i, 1, n) {
tmp = tr.query(1, n, 1, 1, a[i] - 1);
(tmp.second += !tmp.first) %= mod;
++tmp.first;
tr.update(1, n, 1, a[i], tmp);
len[i] = tmp.first;
li[i] = tmp.second;
}
tmp = tr.query(1, n, 1, 1, n);
lis = tmp.first;
tot = tmp.second;
// 计算 rd[i]
tr.build(1, n, 1);
_rep(i, n, 1) {
tmp = tr.query(1, n, 1, a[i] + 1, n);
(tmp.second += !tmp.first) %= mod;
++tmp.first;
tr.update(1, n, 1, a[i], tmp);
len[i] += tmp.first - 1;
rd[i] = tmp.second;
}
rep(i, 1, n) {
if (len[i] != lis) {
li[i] = rd[i] = 0;
}
}
// 计算 ld[i]
tr.build(1, n, 1);
rep(i, 1, n) {
tmp = tr.query(1, n, 1, a[i] + 1, n);
(tmp.second += !tmp.first) %= mod;
++tmp.first;
tr.update(1, n, 1, a[i], tmp);
len[i] = tmp.first;
ld[i] = tmp.second;
}
tmp = tr.query(1, n, 1, 1, n);
lds = tmp.first;
(tot *= tmp.second) %= mod;
// 计算 ri[i]
tr.build(1, n, 1);
_rep(i, n, 1) {
tmp = tr.query(1, n, 1, 1, a[i] - 1);
(tmp.second += !tmp.first) %= mod;
++tmp.first;
tr.update(1, n, 1, a[i], tmp);
len[i] += tmp.first - 1;
ri[i] = tmp.second;
}
rep(i, 1, n) {
if (len[i] != lds) {
ld[i] = ri[i] = 0;
}
}
// 计算 sum
rep(i, 1, n) {
(sum += li[i] * ld[i] % mod * ri[i] % mod * rd[i] % mod) %= mod;
}
printf("%lld\n", lis + lds - (tot == sum));
return 0;
}
:::