题解:AT_arc199_c [ARC199C] Circular Tree Embedding

· · 题解

dalao 的题解都太意识流了,增加一些可能不怎么严谨的证明。

引理 1

若一个排列 p 合法,那么经过 p_i\to (p_i+k)\%n+1 的变换后,仍旧合法。

证明比较显然,可以看作那个环旋转了一下。

引理 2

invP^1 的逆排列,替换 P_j^iP_{inv_j}^i 后,不会改变答案。

证明:

这个替换过程,实际上就是一个重标号过程,我们把原先合法的树重标号后变为合法,且此时不会新增合法的树。

引理 3

树是合法的的充要条件为它的所有子树对应排列的标号为连续的(由于是环,可能是一段前缀和一段后缀)

先证明充分性,假设我们现在连接了 u 子树内所有的边,且目前是合法,且满足 u 子树内对应的标号为连续的,为 [l,r],现在考虑边 (fa_u,u),设 p_u=x,p_{fa_u}=y

首先,若 y=l-1/r+1,显然满足条件。

反之,边 (x,y) 会把环上分为两段,不妨视为,两段的点只能在内部相互连边。

若这两段内都有元素不在 fa_u 的子树内,那么这个图一定连不满 n-1 条边(最优的情况,内部连完边后还需要连接两条边,而顶多有一条 (fa_{fa_u},fa_u) 的边需要连接)

也就是说,至少有一段的点都在 fa_u 子树内,综上得证。

后证明必要性,即这样子的生成的树一定合法。

从上往下在环上连边即可,随便想一下就知道一定合法的。

根据引理 1,3,我们可以把问题简化一下,不妨钦定 P_{i,1}=1,以及 1 为根,此时问题不变,且实现了破环为链,不用考虑连续段跨过首尾的情况。

由引理 3,我们想要知道有所有下标集合 S,满足 \forall i\in[1,m],T=\{p_{i,j}|j\in S\}T 中的数为连续的。

由引理 2,经过处理,可使得 P^{1}=(1,2,3,\cdots,n),只样子上面的点集就一定为若干个区间了。

可以用 \mathcal{O}(n^2m) 的时间处理完每个区间 [l,r] 是否合法。

然后可开始区间 dp,设 f_{l,r} 表示 [l,r] 中的数构成的树的个数,可以把这个区间拆为多个儿子区间和一个根节点,增设 g_{l,r} 表示只考虑儿子区间的个数,那么 f_{l,r}=\sum g_{l,rt-1}g_{rt+1,r}g_{l,r}=\sum f_{i,k}g_{k+1,r} 得到,注意边界情况。时间复杂度为 \mathcal{O}(n^3)

最终答案其实就是 g_{2,n},因为我们钦定了 1 为根。

总时间就是 \mathcal{O}(n^2(n+m))

#include<bits/stdc++.h>
#define ull unsigned long long
#define ll long long
#define p_b push_back
#define m_p make_pair
#define pii pair<int,int>
#define fi first
#define se second
#define ls k<<1
#define rs k<<1|1
#define mid ((l+r)>>1)
#define gcd __gcd
#define lowbit(x) (x&(-x))
using namespace std;
int rd(){
    int x=0,f=1; char ch=getchar();
    for(;ch<'0'||ch>'9';ch=getchar())if (ch=='-') f=-1;
    for(;ch>='0'&&ch<='9';ch=getchar())x=(x<<1)+(x<<3)+(ch^48);
    return x*f;
}
void write(int x){
    if(x>9) write(x/10);
    putchar('0'+x%10);
}
const int N=500+5,INF=0x3f3f3f3f,mod=998244353;
int n,m,p[N][N],op[N][N],inv[N],tmp[N],f[N][N],g[N][N];
void add(int &x,int y){
    x+=y;
    if(x>=mod)x-=mod;
}
int main(){
    //freopen(".in","r",stdin);
    //freopen(".out","w",stdout);
    n=rd(),m=rd();
    for(int i=1;i<=n;i++)for(int j=i;j<=n;j++) op[i][j]=1;
    for(int i=1;i<=m;i++){
        for(int j=1;j<=n;j++)p[i][j]=rd();  
        for(int j=n;j;j--)p[i][j]=(p[i][j]-p[i][1]+n)%n+1;
    }
    for(int i=1;i<=n;i++)inv[p[1][i]]=i;
    for(int i=1;i<=m;i++){
        for(int j=1;j<=n;j++)tmp[j]=p[i][inv[j]];   
        for(int j=1;j<=n;j++)p[i][j]=tmp[j];
    }
    for(int t=1;t<=m;t++){
        for(int i=1;i<=n;i++){
            int mn=INF,mx=-INF;
            for(int j=i;j<=n;j++){
                mn=min(mn,p[t][j]),mx=max(mx,p[t][j]);
                if(j-i!=mx-mn)op[i][j]=0;
            }
        }
    }
    for(int i=1;i<=n+1;i++) g[i][i-1]=1;
    for(int len=1;len<=n;len++){
        for(int l=1;l<=n-len+1;l++){
            int r=l+len-1;
            if(op[l][r]) for(int k=l;k<=r;k++) add(f[l][r],1ll*g[l][k-1]*g[k+1][r]%mod);
            g[l][r]=f[l][r];for(int k=l;k<r;k++) add(g[l][r],1ll*f[l][k]*g[k+1][r]%mod);
        }
    }
    printf("%d\n",g[2][n]);
    return 0;
}