P11835 [省选联考 2025] 封印

· · 题解

这道题疑似会 \mathcal O(2^n n) 即会 \mathcal O(n^2)

但是笔者太菜场上写了两个小时 2^n,最后喜提没有时间优化了(但是优化其实是显然的)。

以下的思路完全顺着场上的思路做下来的,有很大可能有很多地方都可以优化(场上没想清楚就一直在改)。

首先先把 m \le 2 判掉,这一部分直接跑爆搜时间复杂度是对的(后面会说原因)。

我们认为第一轮操作表示操作若干个数直到下一次操作的是之前操作过的数,同理定义第二轮……

那么我们把计数过程分成两个部分:

因为每一轮都会让我们的 \max-1,所以属于不同轮之间的状态一定不会重(不是很严谨)。

接下来都是先考虑 \mathcal O(2^n \operatorname{poly} n) 做法,再进行简单优化得到 \mathcal O(n^3) \sim \mathcal O(n^2) 的做法。

考虑 第二轮即之后到达的状态,发现显然我们可以把 a_i=1 的去掉。

接下来考虑 2^n 枚举子集 S,表示我们计数保留 S 中的元素的方案数。

由于需要保留下 S,那么这个集合中每个元素都是从前面一个选中元素到这里的 严格最大值,并且第一个数还需要比最后一个数往后的数 +1 还要大。

那么在没有删任何元素之前(也就是没有操作到 1),整轮的代价是 (\min-2) \times |S|

现在考虑操作到当前这一轮,有一个元素会变成 1,假设它位于第 i 个,则对答案的贡献就是 i(操作 0 \sim i-1 次)。

反之在操作了第 i 个之后,S 这个集合就变化了,那么会被其他的集合计算到。

真的会计算到吗?

发现有 corner,也就是如果选出来的集合形如

2 1 1

就是不会被计算的。具体来说它们都形如 \max,\max,\cdots,\max-1,\max-1,因为当有一个 1 被拿走之后,剩下的集合初始是不合法的。

所以这种东西还有一个额外的贡献 |S|

综上,直接模拟出来就是(这里每一个 S 的贡献都 -1,因为第二轮的初始状态是在第一轮的状态中计数了的)

场上代码,希望能看,有些地方因为场上太急了非常冗余。

  int len=n;n=0;
  for(int i=1;i<=len;i++) if(a[i]>1) a[++n]=a[i];

  for(int S=1;S<(1<<n);S++){
    int mx=0;bool fl=true;
    for(int i=1;i<=n;i++){
      if(S>>(i-1)&1){
        if(mx>=a[i]){fl=false;break;}
        mx=0;   
      }else mx=max(mx,a[i]);
    }
    if(mx) ++mx;
    for(int i=1;i<=n;i++){
      if(S>>(i-1)&1){
        if(mx>=a[i]){fl=false;break;}
        mx=0;   
      }else mx=max(mx,a[i]);
    }
    if(!fl) continue;

    int mn=m+1,id=0,ct=0;
    vector<int> V;
    for(int i=1;i<=n;i++) if((S>>(i-1))&1){
      V.pb(a[i]);
      if(a[i]<mn) mn=a[i],id=ct;
      ++ct;
    }
    add(res,mul(mn-2,pc(S)));

    for(int i=id;i<(int)V.size();i=(i+1)%((int)V.size())){
      if(V[i]!=mn){ct=i;break;}
      if(i==(int)V.size()-1) ++mn;
    }
    add(res,(ct==id?(pc(S)-1):0)+id);
  }

考虑优化。

观察到上述贡献跟 \min 的位置有关,和集合大小有关。

所以考虑枚举 \min 的位置,并且将 A 倍长使得 a[i+n] = a[i]-1(这样方便处理跨过 n 的贡献)。

设当前认为第一个 \min 的位置在 id,我们考虑将序列 a[id \sim id+n] 拿出来 dp。

g_i 表示选了 i 这个位置的方案数,f_i 表示选了 i 这个位置的所有方案的目前 (\min-2) \times |S| 之和 + 若 i \gt n 有额外 1 的贡献(也就是暴力中的 id)。

dp 时直接枚举上一个元素选什么即可。

而对于 \max,\max,\cdots,\max-1,\max-1 的情况,发现我们一定会把所有的 \max 选了并且把之后的 \max-1 都选了,最后计算上这里的贡献就可以了。

以下 \mathcal O(n^3) 实现,可以过 96pts(拼上 AB 性质能过)。

  int len=n;n=0;
  for(int i=1;i<=len;i++) if(a[i]>1) a[++n]=a[i];
  for(int i=n+1;i<=2*n;i++) a[i]=a[i-n]-1;

  for(int id=1,cur=0;id<=n;id++){
    for(int i=1;i<=2*n;i++) f[i]=g[i]=0;
    g[id]=1,cur=a[id]-2;
    for(int i=id+1;i<=id+n;i++) if(i==id+n||a[i]>=a[id]){
      for(int j=i-1,mx=0;j>=id;j--){
        if(a[i]<=mx) break;
        add(f[i],f[j]),add(g[i],g[j]);
        mx=max(mx,a[j]);    
      }
      add(f[i],mul(cur,g[i]));
      if(i!=id+n&&i>n) add(f[i],g[i]);    
    }
    add(res,f[id+n]);
  }

  for(int i=n,fl=0;i>=1;i--){
    if(a[i]==m) fl=1,add(res,1);
    else if(!fl&&a[i]==m-1) add(res,1);
  }
  del(res,1);

\mathcal O(n^3) \to \mathcal O(n^2) 的优化显然,容易发现这里能转移的 j 是一个后缀,于是直接用单调栈维护 + 前缀和优化即可。

于是这一部分时间复杂度 \mathcal O(n^2)

  int len=n;n=0;
  for(int i=1;i<=len;i++) if(a[i]>1) a[++n]=a[i];
  for(int i=n+1;i<=2*n;i++) a[i]=a[i-n]-1;

  for(int id=1,cur=0;id<=n;id++){
    for(int i=1;i<=2*n;i++) f[i]=g[i]=sf[i]=sg[i]=0;
    g[id]=sg[id]=1,cur=a[id]-2,st[tp=1]=id;

    for(int i=id+1;i<=id+n;i++){
      while(tp&&a[i]>a[st[tp]]) --tp;
      if(i==id+n||a[i]>=a[id]){
        f[i]=dec(sf[i-1],sf[max(id,st[tp])-1]);
        g[i]=dec(sg[i-1],sg[max(id,st[tp])-1]);

        add(f[i],mul(cur,g[i]));
        if(i!=id+n&&i>n) add(f[i],g[i]);
      }
      st[++tp]=i;
      sf[i]=adc(sf[i-1],f[i]),sg[i]=adc(sg[i-1],g[i]);
    }
    add(res,f[id+n]);
  }

  for(int i=n,fl=0;i>=1;i--){
    if(a[i]==m) fl=1,add(res,1);
    else if(!fl&&a[i]==m-1) add(res,1);
  }
  del(res,1);

现在再来考虑 第一轮能到达的状态

同样我们枚举操作集合 S,表示我们操作了这些数。

同样的,这个集合的每个元素都要是和前一个元素之间的严格最大值。

并且为了不算重,我们还有一些额外的限制:

(或许这里的这些 corner 就是我们需要特判 m \le 2 的原因)

于是直接模拟上述过程得到如下代码

  int res=0,ret=0;
  for(int i=n;i>=1;i--){
    if(a[i]!=1) break;
    ++ret;
  }

  for(int S=0;S<(1<<n);S++){
    int mx=0;bool fl=true,vis=false;
    for(int i=1;i<=n;i++){
      if(S>>(i-1)&1){
        if(mx>=a[i]){fl=false;break;}
        if(a[i]>1&&vis){fl=false;break;}
        vis|=a[i]==1;
        mx=0;
      }else mx=max(mx,a[i]);
    }
    if(!fl) continue;

    bool chk=0;
    for(int i=n;i>n-ret;i--) chk|=(S>>(i-1)&1);
    for(int i=1;i<=n;i++) if(S>>(i-1)&1){
      chk&=a[i]==2;
      break;
    }
    if(!chk) add(res,fl);
  }

这里的优化就变得容易了很多了,我们对前 n-ret 个元素进行 dp。

h_{i,0/1} 表示选了第 i 个元素,第一个元素是否时 2 的方案数。

直接做就是 \mathcal O(n^2) 的,最后是否选后面 ret1,也就是带来 h_{n-ret,0} \times ret 的贡献(容易发现 n-ret 这个位置是一定要选的)。

  int res=1,ret=0;
  for(int i=n;i>=1;i--) if(a[i]!=1){ret=n-i;break;}

  for(int i=0;i<=n;i++) h[i][0]=h[i][1]=0;
  h[0][0]=1;

  for(int i=1;i<=n-ret;i++){
    for(int j=i-1,mx=0;j>=0;j--){
      if(a[i]<=mx) break;
      if(j==0){
        if(a[i]==2) add(h[i][1],1);
        else add(h[i][0],1);    
      }else if(!(a[i]>1&&a[j]==1))
        for(int k:{0,1}) add(h[i][k],h[j][k]);
      mx=max(mx,a[j]);
    }
    add(res,adc(h[i][0],h[i][1]));
  }
  add(res,mul(h[n-ret][0],ret));

注意需要记上初始状态,所以从 res=1 开始。

这样就做完了,总时间复杂度 \mathcal O(n^2)。代码。

想清楚其实很简单,但是我场上为什么写了两个小时暴力/ll/ll/ll

#include <bits/stdc++.h>
using namespace std;
#define pb push_back

const int N=5005,H=998244353;
int n,m,a[N],f[N],g[N],h[N][2],sf[N],sg[N],tp=0,st[N];

int adc(int a,int b){return a+b>=H?a+b-H:a+b;}
int dec(int a,int b){return a<b?a-b+H:a-b;}
int mul(int a,int b){return 1ll*a*b%H;}
void add(int &a,int b){a=adc(a,b);}
void del(int &a,int b){a=dec(a,b);}

namespace BF{
  map<vector<int>,bool > mp;

  void dfs(vector<int> V){
    if(mp.count(V)) return;
    mp[V]=1;

    vector<int> nw;
    int mx=0;
    for(int i=0;i<(int)V.size();i++){
      if(V[i]>mx){
        nw.clear();
        for(int j=i+1;j<(int)V.size();j++) nw.pb(V[j]);
        if(V[i]>1) nw.pb(V[i]-1);
        dfs(nw);
      }
      mx=max(mx,V[i]);
    }
  }

  void SOLVE(){
    vector<int> V;
    for(int i=1;i<=n;i++) V.pb(a[i]);
    mp.clear(),dfs(V);
    cout<<((int)mp.size()-1)%H<<'\n';
  }
}

int pc(int x){return __builtin_popcount(x);}

void SOLVE(){
  cin>>n>>m;m=0;
  for(int i=1;i<=n;i++) cin>>a[i],m=max(m,a[i]);

  if(m<=2) return BF::SOLVE();

  int res=1,ret=0;
  for(int i=n;i>=1;i--) if(a[i]!=1){ret=n-i;break;}

  for(int i=0;i<=n;i++) h[i][0]=h[i][1]=0;
  h[0][0]=1;

  for(int i=1;i<=n-ret;i++){
    for(int j=i-1,mx=0;j>=0;j--){
      if(a[i]<=mx) break;
      if(j==0){
        if(a[i]==2) add(h[i][1],1);
        else add(h[i][0],1);    
      }else if(!(a[i]>1&&a[j]==1))
        for(int k:{0,1}) add(h[i][k],h[j][k]);
      mx=max(mx,a[j]);
    }
    add(res,adc(h[i][0],h[i][1]));
  }
  add(res,mul(h[n-ret][0],ret));

  int len=n;n=0;
  for(int i=1;i<=len;i++) if(a[i]>1) a[++n]=a[i];
  for(int i=n+1;i<=2*n;i++) a[i]=a[i-n]-1;

  for(int id=1,cur=0;id<=n;id++){
    for(int i=1;i<=2*n;i++) f[i]=g[i]=sf[i]=sg[i]=0;
    g[id]=sg[id]=1,cur=a[id]-2,st[tp=1]=id;

    for(int i=id+1;i<=id+n;i++){
      while(tp&&a[i]>a[st[tp]]) --tp;
      if(i==id+n||a[i]>=a[id]){
        f[i]=dec(sf[i-1],sf[max(id,st[tp])-1]);
        g[i]=dec(sg[i-1],sg[max(id,st[tp])-1]);

        add(f[i],mul(cur,g[i]));
        if(i!=id+n&&i>n) add(f[i],g[i]);
      }
      st[++tp]=i;
      sf[i]=adc(sf[i-1],f[i]),sg[i]=adc(sg[i-1],g[i]);
    }
    add(res,f[id+n]);
  }

  for(int i=n,fl=0;i>=1;i--){
    if(a[i]==m) fl=1,add(res,1);
    else if(!fl&&a[i]==m-1) add(res,1);
  }
  del(res,1);

  cout<<res<<'\n';
}

int main(){
//  freopen("seal.in","r",stdin);
//  freopen("seal.out","w",stdout);
  ios::sync_with_stdio(false);cin.tie(0);cout.tie(0);
  int cas,_;cin>>cas>>_;
  while(_--) SOLVE();
  return 0;
}

拿到了目前最优解(