P7444 题解

· · 题解

实感 2D 严格小于 2C。可能风格不太合口味。

显然根据 c_1 我们可以确定 0 在整个序列的最多两个合法的位置。

直接算肯定是不太行的,考虑将数一个一个加进去。定义 dp_{i,l,r} 表示,填了 0 \sim i-1,在排列里构成的极大区间为 [l,r] 的合法排列个数。分类讨论。

假设 c_i=0,那么新加入的这个数一定存在于之前的区间 [l,r]。可以根据 i 得到 [l,r] 的空位剩下多少,随便加一个就行了。转移显然。

否则,新加入的数可能在 [l,r] 左边或者右边。以左边为例。考虑 c_i 的构成,假设我们插入的位置为 L,那么就会有 (l-L+1) \times (n-r+1)\operatorname{mex} 等于 i 的子数组。那么显然 c_i | n-r+1。这个时候判断随便算算扩大区间就行了。可以看 80pts 代码实现。

实际上到这里可以猜测因为限制十分紧可以猜测实际可以得到的状态很少所以用个啥 map 存一存转移就可以得到 80pts 了。这题这点不行。

#include<bits/stdc++.h>
using namespace std;
typedef long long LL;
const LL MOD=998244353;
char buf[1<<21],*p1=buf,*p2=buf;
#define getchar() (p1==p2 && (p2=(p1=buf)+fread(buf,1,1<<18,stdin),p1==p2)?EOF:*p1++)
LL read()
{
    LL x=0;
    char c=getchar();
    while(c<'0' || c>'9')   c=getchar();
    while(c>='0' && c<='9') x=(x<<1)+(x<<3)+(c^'0'),c=getchar();
    return x;
}
void write(LL x)
{
    if(x>9) write(x/10);
    putchar(x%10+'0');
}
struct State{
    LL i,l,r;
    State(){}
    State(LL I,LL L,LL R){i=I,l=L,r=R;}
    bool operator < (State ano) const
    {
        if(i!=ano.i)    return i<ano.i;
        if(l!=ano.l)    return l<ano.l;
        if(r!=ano.r)    return r<ano.r;
        return false;
    }
};
map<State,bool> al;
map<State,LL> dp;
queue<State> Q;
LL n,c[500005];
LL Calc(LL x){return x*(x+1)/2;}
int main(){
    n=read();
    for(LL i=1;i<=n;++i)    c[i]=read();
    for(LL i=1;i<=n;++i)
    {
        LL l=i-1,r=i+1;
        if(Calc(l)+Calc(n-r+1)==c[1])
        {
            State tmp=State(1,i,i);
            dp[tmp]=1;
            Q.push(tmp);
        }
    }
    while(!Q.empty())
    {
        State now=Q.front();
        Q.pop();
        if(al[now]) continue;
        al[now]=true;
        LL nxt=now.i+1,l=now.l,r=now.r;
        if(nxt>n)   continue;
        if(c[nxt]==0)
        {
            LL rest=(r-l+1)-now.i;
            if(rest>0)
            {
                State stt=State(nxt,l,r);
                (dp[stt]+=dp[now]*rest)%=MOD;
                Q.push(stt);
            }
        }
        else
        {
            LL L=0,R=0;
            LL lef=l,rig=n-r+1;
            if(c[nxt]%lef==0)
            {
                R=r+c[nxt]/lef;
                if(R<=n)
                {
                    State stt=State(nxt,l,R);
                    (dp[stt]+=dp[now])%=MOD;
                    Q.push(stt);
                }
            }
            if(c[nxt]%rig==0)
            {
                L=l-c[nxt]/rig;
                if(L>=1)
                {
                    State stt=State(nxt,L,r);
                    (dp[stt]+=dp[now])%=MOD;
                    Q.push(stt);
                }
            }
        }
    }
    State Ans=State(n,1,n);
    write(dp[Ans]);
    return 0;
}

将上面的代码写成数组 dp 的形式的空间为 O(n^3)。滚动数组可以优化一维。考虑到 l,r 的消息不好消除,将 dp 内容输出可以发现 dp_{i,l} 中最多有一个 r 使得 dp_{i,l,r} \neq 0

证明可以考虑转移的过程中的两种不同情况左端点一定不同。存下区间左端点也没官方题解说的那么麻烦,开个数组就行了。这样空间就被压到 O(n) 了。

最后需要考虑的是证明状态数足够少。根据我们上面的发现显然状态数上界 O(n^2)。仔细考虑 a_i 之后发现所有的状态数实际会和 \sum \sqrt{a_i} 扯上关系。通过柯西不等式证明其小于等于 n \sqrt n,可以把时间复杂度看为 O(n \sqrt n)

#include<bits/stdc++.h>
#define Sunset namespace
#define Toybox std
using Sunset Toybox;
typedef long long LL;
const LL MOD=998244353;
char buf[1<<21],*p1=buf,*p2=buf;
#define getchar() (p1==p2 && (p2=(p1=buf)+fread(buf,1,1<<18,stdin),p1==p2)?EOF:*p1++)
LL read()
{
    LL x=0;
    char c=getchar();
    while(c<'0' || c>'9')   c=getchar();
    while(c>='0' && c<='9') x=(x<<1)+(x<<3)+(c^'0'),c=getchar();
    return x;
}
void write(LL x)
{
    if(x>9) write(x/10);
    putchar(x%10+'0');
}
LL Calc(LL x){return x*(x+1)/2;}
LL n,c[500005],dp[2][500005],st[2][500005];
bool vis[500005];
vector<LL> sp,pre,nxt;
int main(){
    n=read();
    for(LL i=0;i<n;++i) c[i]=read();
    for(LL i=1;i<=n;++i)
    {
        LL l=i-1,r=i+1;
        if(Calc(l)+Calc(n-r+1)==c[0])   sp.push_back(i);
    }
    if(LL(sp.size())==0)    return puts("0")&0;
    dp[0][sp[0]]=1;
    st[0][sp[0]]=sp[0];
    pre.push_back(sp[0]);
    for(LL i=1;i<n-1;++i)
    {
        nxt.clear();
        LL lst=(i-1)&1;
        for(auto j:pre)
        {
            LL l=j,r=st[lst][l];//dp{i,l} -> r
            if(vis[l])  continue;
            vis[l]=true;
            if(!c[i])
            {
                if(r-l-i>=0)
                {
                    (dp[i&1][l]+=dp[lst][l]*(r-l+1-i)%MOD)%=MOD;
                    st[i&1][l]=r;
                    nxt.push_back(l);
                }
            }
            else
            {
                LL lef=l,rig=n-r+1;
                if(c[i]%rig==0)
                {
                    LL L=l-c[i]/rig;
                    (dp[i&1][L]+=dp[lst][l])%=MOD;
                    st[i&1][L]=r;
                    nxt.push_back(L);
                }
                if(c[i]%lef==0)
                {
                    LL R=r+c[i]/lef;
                    (dp[i&1][l]+=dp[lst][l])%=MOD;
                    st[i&1][l]=R;
                    nxt.push_back(l);
                }
            }
        }
        for(auto j:pre) vis[j]=false,dp[lst][j]=0;
        nxt.swap(pre);
    }
    LL ans=0;
    for(LL i=1;i<=n;++i)    (ans+=dp[(n-2)&1][i])%=MOD;
    write(ans*LL(sp.size())%MOD);
    return 0;
}