题解:CF2144E1 Looking at Towers (easy version)
题目传送门 CF
题目大意
有
定义
现在给出一个序列
题目分析
考虑 dp。
注意到,
因为
我们可以设计出一个状态
接下来考虑如何转移,考虑当前状态
- 首先肯定可以不选
a_i ,所以有转移dp_{i, j} = dp_{i - 1, j} 。 - 当
a_i < L_j 时,我们这时可以选a_i ,所以dp_{i, j} \leftarrow dp_{i, j} + dp_{i - 1, j} ,其中dp_{i - 1, j} 代表选a_i 但是不进行匹配。 - 当
a_i = L_j 时,我们可以选当前的a_i ,也可以选之前的a_i 也就是不选现在的a_i ,所以dp_{i, j} \leftarrow dp_{i, j} + dp_{i - 1, j} + dp_{i - 1, j - 1} ,其中dp_{i - 1, j} 代表选之前的a_i ,所以之前就匹配完j 个了,而dp_{i - 1, j - 1} 代表选现在的a_i ,所以之前之匹配完了j - 1 个。 - 当
a_i > L_j 时,这时候会影响到L(a') ,所以不用转移。
最后统计答案。考虑枚举最高点,那么这个最高点对答案的贡献就是左侧匹配上
code
#include <bits/stdc++.h>
#define ft first
#define sd second
#define endl '\n'
#define pb push_back
#define md make_pair
#define gc() getchar()
#define pc(ch) putchar(ch)
#define umap unordered_map
#define pque priority_queue
using namespace std;
typedef double db;
typedef long long ll;
typedef unsigned long long ull;
typedef __int128 bint;
typedef pair<int, int> pii;
typedef pair<pii, int> pi1;
typedef pair<pii, pii> pi2;
const ll INF = 0x3f3f3f3f;
const db Pi = acos(-1.0);
inline ll read()
{
ll res = 0, f = 1; char ch = gc();
while (ch < '0' || ch > '9') f = (ch == '-' ? -1 : f), ch = gc();
while (ch >= '0' && ch <= '9') res = (res << 1) + (res << 3) + (ch ^ 48), ch = gc();
return res * f;
}
inline void write(ll x)
{
if (x < 0) x = -x, pc('-');
if (x > 9) write(x / 10);
pc(x % 10 + '0');
}
inline void writech(ll x, char ch) { write(x), pc(ch); }
const int mod = 998244353;
const int N = 5e3 + 5;
int a[N], L[N], R[N];
int dpl[N][N], dpr[N][N];
void Add(int &a, int b) { a += b, a %= mod; }
// dpl[i][j] : left i numbers, now match left j of L(h)
// dpr[i][j] : right i numbers, now match right j of R(h)
void solve()
{
int n = read();
for (int i = 1; i <= n; i++) a[i] = read();
for (int i = 0; i <= n + 1; i++)
for (int j = 0; j <= n + 1; j++)
dpl[i][j] = dpr[i][j] = 0;
for (int i = 1; i <= n; i++) L[i] = R[i] = 0;
int cntL = 0, cntR = 0; // init L(h) & R(h)
for (int i = 1; i <= n; i++) if (L[cntL] < a[i]) L[++cntL] = a[i];
for (int i = n; i >= 1; i--) if (R[cntR] < a[i]) R[++cntR] = a[i];
for (int i = 0; i <= n; i++) dpl[i][0] = 1; // nothing
dpl[1][1] = 1; // a[1] must be in L(h)
for (int i = 2; i <= n; i++)
{
for (int j = 1; j <= min(cntL, i); j++)
{
dpl[i][j] = dpl[i - 1][j]; // don't select a[i]
if (a[i] < L[j]) Add(dpl[i][j], dpl[i - 1][j]); // can select a[i]
else if (a[i] == L[j]) Add(dpl[i][j], (dpl[i - 1][j - 1] + dpl[i - 1][j]) % mod);
// if a[i] == L[j], we can select previous a[i] or just now a[i]
}
}
for (int i = 1; i <= n + 1; i++) dpr[i][0] = 1; // nothing
dpr[n][1] = 1; // a[n] must be in R(h)
for (int i = n - 1; i >= 1; i--)
{
for (int j = 1; j <= min(cntR, n - i + 1); j++)
{
dpr[i][j] = dpr[i + 1][j]; // don't select a[i]
if (a[i] < R[j]) Add(dpr[i][j], dpr[i + 1][j]); // can select a[i]
else if (a[i] == R[j]) Add(dpr[i][j], (dpr[i + 1][j - 1] + dpr[i + 1][j]) % mod);
// if a[i] == R[j], we can select previous a[i] or just now a[i]
}
}
int Max = L[cntL];
int ans = 0;
for (int i = 1; i <= n; i++)
{
if (a[i] == Max) // now is maximum and must be in a'
{
int left = (dpl[i - 1][cntL] + dpl[i - 1][cntL - 1]) % mod; // we can select previous a[i] or just now a[i]
int right = dpr[i + 1][cntR - 1]; // we can just select pervious a[i]
Add(ans, 1ll * left * right % mod);
}
}
writech(ans, '\n');
}
int main()
{
int T = read();
while (T--) solve();
return 0;
}