题解:CF2144E1 Looking at Towers (easy version)

· · 题解

题目传送门 CF

题目大意

n 个塔,第 i 个塔的高度为 h_i

定义 L(h) 为从左侧能看到的塔的集合,R(h) 为从右侧能看到的塔的集合。从左侧能看到当前这个塔当且仅当它左侧的塔都严格小于它,右侧同理。

现在给出一个序列 a_1, a_2, a_3, \cdots, a_n。求它有多少个子序列 a' 满足 L(a) = L(a')R(a) = R(a')。答案对 998244353 取模。

题目分析

考虑 dp。

注意到,L(a)R(a) 能预处理,所以不妨先预处理出来。

因为 n \le 5000,所以应该是 \mathcal{O}(n ^ 2) 的算法。

我们可以设计出一个状态 dp_{i, j} 表示左侧前 i 个数当前已经与 L(a) 匹配了前 j 个数。

接下来考虑如何转移,考虑当前状态 dp_{i, j}

最后统计答案。考虑枚举最高点,那么这个最高点对答案的贡献就是左侧匹配上 L(a) 的方案数乘上右侧匹配上 R(a) 的方案数,但这样会重复统计,因为会重复统计到两侧都用到了这个最高点的情况,所以我们只需要强制一边选最高点,另一边不选最高点即可。

code

#include <bits/stdc++.h>
#define ft first
#define sd second
#define endl '\n'
#define pb push_back
#define md make_pair
#define gc() getchar()
#define pc(ch) putchar(ch)
#define umap unordered_map
#define pque priority_queue
using namespace std;
typedef double db;
typedef long long ll;
typedef unsigned long long ull;
typedef __int128 bint;
typedef pair<int, int> pii;
typedef pair<pii, int> pi1;
typedef pair<pii, pii> pi2;
const ll INF = 0x3f3f3f3f;
const db Pi = acos(-1.0);
inline ll read()
{
    ll res = 0, f = 1; char ch = gc();
    while (ch < '0' || ch > '9') f = (ch == '-' ? -1 : f), ch = gc();
    while (ch >= '0' && ch <= '9') res = (res << 1) + (res << 3) + (ch ^ 48), ch = gc();
    return res * f;
}
inline void write(ll x)
{
    if (x < 0) x = -x, pc('-');
    if (x > 9) write(x / 10);
    pc(x % 10 + '0');
}
inline void writech(ll x, char ch) { write(x), pc(ch); }
const int mod = 998244353;
const int N = 5e3 + 5;
int a[N], L[N], R[N];
int dpl[N][N], dpr[N][N];
void Add(int &a, int b) { a += b, a %= mod; }
// dpl[i][j] : left i numbers, now match left j of L(h)
// dpr[i][j] : right i numbers, now match right j of R(h)
void solve()
{
    int n = read();
    for (int i = 1; i <= n; i++) a[i] = read();
    for (int i = 0; i <= n + 1; i++)
        for (int j = 0; j <= n + 1; j++)
            dpl[i][j] = dpr[i][j] = 0;
    for (int i = 1; i <= n; i++) L[i] = R[i] = 0;
    int cntL = 0, cntR = 0; // init L(h) & R(h)
    for (int i = 1; i <= n; i++) if (L[cntL] < a[i]) L[++cntL] = a[i];
    for (int i = n; i >= 1; i--) if (R[cntR] < a[i]) R[++cntR] = a[i];
    for (int i = 0; i <= n; i++) dpl[i][0] = 1; // nothing
    dpl[1][1] = 1; // a[1] must be in L(h)
    for (int i = 2; i <= n; i++)
    {
        for (int j = 1; j <= min(cntL, i); j++)
        {
            dpl[i][j] = dpl[i - 1][j]; // don't select a[i]
            if (a[i] < L[j]) Add(dpl[i][j], dpl[i - 1][j]); // can select a[i]
            else if (a[i] == L[j]) Add(dpl[i][j], (dpl[i - 1][j - 1] + dpl[i - 1][j]) % mod);
            // if a[i] == L[j], we can select previous a[i] or just now a[i]
        }
    }
    for (int i = 1; i <= n + 1; i++) dpr[i][0] = 1; // nothing
    dpr[n][1] = 1; // a[n] must be in R(h)
    for (int i = n - 1; i >= 1; i--)
    {
        for (int j = 1; j <= min(cntR, n - i + 1); j++)
        {
            dpr[i][j] = dpr[i + 1][j]; // don't select a[i]
            if (a[i] < R[j]) Add(dpr[i][j], dpr[i + 1][j]); // can select a[i]
            else if (a[i] == R[j]) Add(dpr[i][j], (dpr[i + 1][j - 1] + dpr[i + 1][j]) % mod);
            // if a[i] == R[j], we can select previous a[i] or just now a[i]
        }
    }
    int Max = L[cntL];
    int ans = 0;
    for (int i = 1; i <= n; i++)
    {
        if (a[i] == Max) // now is maximum and must be in a'
        {
            int left = (dpl[i - 1][cntL] + dpl[i - 1][cntL - 1]) % mod; // we can select previous a[i] or just now a[i]
            int right = dpr[i + 1][cntR - 1]; // we can just select pervious a[i]
            Add(ans, 1ll * left * right % mod);
        }
    }
    writech(ans, '\n');
}
int main()
{
    int T = read();
    while (T--) solve();
    return 0;
}