题解:P12629 [NAC 2025] Popping Balloons

· · 题解

第一次排序时间是不好办的,考虑将每一个序列变成排序的过程列出来,给过程中每一个序列贡献 1。也即,求所有未排序的子序列出现概率的和就是答案,而子序列的出现概率只与长度有关。正难则反,我们希望对每个长度计算出排好序的子序列的数量。

注意到值域只有 [0,2],可以直接设 f_{i,j,len} 表示前 i 个元素以 j \in [0,2] 结尾形成长度为 len 的有序子序列的概率。直接 DP 复杂度是平方的,套路的考虑分治:对于区间 [l,r],设 f_{x,y,len} 表示左右分别是 x,y \in [0,2] 且长度为 len 的概率,合并左右区间时枚举端点,之后的合并是卷积。

总复杂度 \mathcal O(v^2n\log^2 n),还是可以通过的。

/**
 *    author: sunkuangzheng
 *    created: 21.06.2025 20:27:37
**/
#include<bits/stdc++.h>
#ifdef DEBUG_LOCAL
#include <mydebug/debug.h>
#endif

#include <algorithm>
#include <array>

#ifdef _MSC_VER
#include <intrin.h>
#endif

// namespace atcoder

#define V vector
#define A array
using ll = long long;
const int N = 5e5+5;
const int W = 3,mod = 998244353;
using namespace std;
using Z = atcoder::modint998244353;
#define T A<A<V<Z>,3>,3>
int n,a[N]; string s;
T solve(int l,int r){
    T dp;
    for(int i = 0;i < W;i ++) for(int j = i;j < W;j ++)
        dp[i][j].resize(r-l+2);
    if(l == r) return dp[a[l]][a[l]][1] = 1,dp;
    int mid = (l + r) / 2;
    T L = solve(l,mid),R = solve(mid+1,r);
    for(int i = 0;i < W;i ++) for(int j = i;j < W;j ++){
        for(int l = 0;l < L[i][j].size();l ++) dp[i][j][l] += L[i][j][l];
        for(int l = 0;l < R[i][j].size();l ++) dp[i][j][l] += R[i][j][l];
    }for(int i = 0;i < W;i ++) for(int p = i;p < W;p ++){
        for(int j = 0;j <= i;j ++) for(int k = p;k < W;k ++){
            V<Z> a = L[j][i],b = R[p][k];
            auto c = atcoder::convolution(a,b);
            for(int l = 0;l < c.size();l ++) dp[j][k][l] += c[l];
        }
    }return dp;
}int main(){
    ios::sync_with_stdio(0),cin.tie(0);
    cin >> s,n = s.size(),s = " " + s;
    for(int i = 1;i <= n;i ++) a[i] = (s[i] == 'R' ? 2 : (s[i] == 'Y'));
    auto dp = solve(1,n);
    V<Z> as(n + 1);
    for(int i = 0;i < W;i ++) for(int j = i;j < W;j ++)
        for(int l = 0;l <= n;l ++) as[l] += dp[i][j][l];
    for(int i = 1;i <= n;i ++) cerr << as[i].val() << " \n"[i == n];
    Z C = 1,ans = 0;
    for(int i = 1;i <= n;i ++)
        ans += (C - as[n-i+1]) / C,C = C * (n - i + 1) / i;
    cout << ans.val() << "\n";
}