P13694

题解

对于每个分割排列，钦定其分段点处必须构成逆序（不满足该条件的分割排列直接对应唯一的排列 $p$，暴力 check 即可），其余位置均为顺序。枚举第一个排列（代码中为随机选取的某一个排列）的分段点，整个序列被拆成两条链；不妨先假设对其余所有分割排列都有 $p^{-1}_{q_{i-1}}<p^{-1}_{q_i}$ 成立，并加入这样的边，则每个分割排列要求恰好违反一条边。于是可以直接记录状态 $f_{i,j,S}$，表示当前已经填完了第一条链上前 $i$ 个点的 $p^{-1}_*$ 和第二条链上前 $j$ 个点的 $p^{-1}_*$，且集合 $S$ 内的分割排列已经恰好违反一条边。这样 dp 总复杂度的上界是 $\mathcal{O}(n^3 2^m)$（应该可以分析出更好的界？），但实际上远远跑不满！加一些剪枝即可通过：

#include<bits/stdc++.h>
// Shorthand macros used throughout this file. They expand textually, so
// arguments are not parenthesised; use them only with simple expressions.
// #define int long long
#define ll long long
#define ld long double
// Unsigned 128-bit integer (GCC/Clang extension).
#define bint __uint128_t
#define PII pair<int,int>
// Inclusive ascending loop: k takes every value from l up to r.
#define rep(k,l,r) for(int k=l;k<=r;++k)
// Inclusive descending loop: k takes every value from r down to l.
#define per(k,r,l) for(int k=r;k>=l;--k)
#define chkmax(a,b) a=max(a,b)
#define chkmin(a,b) a=min(a,b)
// Byte-fills the whole object f with x; only correct when f is an array/object (sizeof(f)).
#define cl(f,x) memset(f,x,sizeof(f))
using namespace std;
// Debug helper (not called from main in this file): redirects stdin to ".in"
// and stdout to ".out" in the current working directory.
void file_IO() {
    FILE *in_stream=freopen(".in","r",stdin);
    FILE *out_stream=freopen(".out","w",stdout);
    (void)in_stream;
    (void)out_stream;
}
// Address marker; main() reports the static-storage span between M1 and M2.
bool M1;
// Large sentinel values (not referenced elsewhere in the visible code).
const int INF=0x3f3f3f3f;
const long long INFLL=0x3f3f3f3f3f3f3f3f;
// Every count is reported modulo this prime.
const int MOD=998244353;
// In-place modular addition: a becomes (a + b) mod MOD, assuming 0 <= a, b < MOD.
void add(int &a,int b) {
    const int sum=a+b;
    a=(sum<MOD)?sum:sum-MOD;
}
// Bound for the global arrays: indices 0..300.
const int N=301;
// Fixed-capacity 384-bit set stored in three unsigned 128-bit words (the same
// type as the `bint` macro). Bit c lives in x for c in [0,128), in y for
// [128,256) and in z for [256,384); indices outside [0,384) are not supported
// (the shift would exceed the word width).
struct bits {
    using word=__uint128_t;
    word x,y,z;
    bits():x(0),y(0),z(0) {}
    // Lexicographic order on (x, y, z), so bits can be used as a std::map key.
    bool operator < (const bits &tmp) const {
        if(x!=tmp.x)
            return x<tmp.x;
        if(y!=tmp.y)
            return y<tmp.y;
        return z<tmp.z;
    }
    // Returns bit c; does not modify the set, so it is usable on const objects.
    bool get(short c) const {
        if(c<128)
            return (x>>c)&1;
        else if(c<256)
            return (y>>(c-128))&1;
        else
            return (z>>(c-256))&1;
    }
    // Toggles (XORs) bit c: applying it twice restores the previous value.
    void upd(short c) {
        if(c<128)
            x^=((word)1)<<c;
        else if(c<256)
            y^=((word)1)<<(c-128);
        else
            z^=((word)1)<<(c-256);
    }
    // True iff no bit is set.
    bool empty() const {
        return x==0&&y==0&&z==0;
    }
};
// Memo table for dfs: dfs(n, m, S) stores its result in f[n+1][m+1][S]; the +1
// shift lets the "chain exhausted" index -1 map to slot 0.
map<bits,int> f[N][N];
// Cross-chain constraints, rebuilt by solve(n, m, splits) for every cut:
// mpL[i] / mpR[i] hold (id, j) pairs for left / right chain node i, where j is a
// node index on the opposite chain and id is either a bit of dfs's set S or -1
// (an edge whose violation makes the branch fail).
vector<PII> mpL[N],mpR[N];
int dfs(int n,int m,bits S) {
    if(n==-1&&m==-1)
        return S.empty();
    if(f[n+1][m+1].count(S))
        return f[n+1][m+1][S];
    int res=0;
    if(n!=-1) {
        // make (n,L) = n+m+1
        bool flag=true;
        bits T=S;
        for(auto x:mpL[n]) {
            if(x.second>m) {
                if(x.first==-1||!T.get(x.first)) {
                    flag=false;
                    break;
                } else
                    T.upd(x.first);
            }
        }
        if(flag)
            add(res,dfs(n-1,m,T));
    }
    if(m!=-1) {
        // make (m,R) = n+m+1
        bool flag=true;
        bits T=S;
        for(auto x:mpR[m]) {
            if(x.second>n) {
                if(x.first==-1||!T.get(x.first)) {
                    flag=false;
                    break;
                } else
                    T.upd(x.first);
            }
        }
        if(flag)
            add(res,dfs(n,m-1,T));
    }
    return f[n+1][m+1][S]=res;
}
// q[v] = position of value v in the split permutation currently used as the
// reference; rewritten by solve(n, m, splits) before each use.
int q[N];
mt19937 rd(time(0));
// Pseudo-random integer in [l, r] for l <= r, computed as l + rd() mod (r-l+1)
// (the small modulo bias is accepted).
int get(int l,int r) {
    const int width=r-l+1;
    return l+rd()%width;
}
// Counts, modulo MOD, the orders q (q plays the role of p^{-1}: q[v] = position of
// value v) such that, for every split permutation splits[k], the sequence
// q[splits[k][0]], ..., q[splits[k][n-1]] has at most one descent.
// splits holds m permutations of length n and is only read.
int solve(int n,int m,vector<vector<int>> &splits) {
    if(n==1)
        return 1;
    int res=0;
    // Any reference permutation works: a counted order gives it either exactly one
    // descent (part 1, at a unique cut) or none (part 2).
    const int t=get(0,m-1);
    rep(i,0,n-1)
        q[splits[t][i]]=i;
    // Part 1: splits[t] has its single descent between positions cut and cut+1.
    // Left chain = positions 0..cut (indices 0..lastL), right chain = positions
    // cut+1..n-1 (indices 0..lastR); each chain is filled in increasing order.
    rep(cut,0,n-2) {
        const int lastL=cut,lastR=n-2-cut;
        rep(a,0,lastL+1) {
            rep(b,0,lastR+1)
                f[a][b].clear();
        }
        rep(a,0,lastL)
            mpL[a].clear();
        rep(b,0,lastR)
            mpR[b].clear();
        // Forces the descent at the cut: the left tail must get a larger value than the right head.
        mpL[lastL].push_back(make_pair(-1,0));
        bits S;
        int tot=0;
        bool feasible=true;
        rep(k,0,m-1) {
            if(k==t)
                continue;
            const vector<int> &perm=splits[k];
            // Descents between neighbours on the same chain are already decided by the chain order.
            int cnt=0;
            rep(j,0,n-2) {
                const int u=q[perm[j]],v=q[perm[j+1]];
                if((u<=lastL)==(v<=lastL))
                    cnt+=u>v;
            }
            if(cnt>1) {
                feasible=false;
                break;
            }
            // cnt==1: none of this permutation's cross edges may be violated (id -1).
            // cnt==0: exactly one must be violated, tracked by pending bit `tot` of S.
            const int id=(cnt==1)?-1:tot;
            rep(j,0,n-2) {
                const int u=q[perm[j]],v=q[perm[j+1]];
                if(u<=lastL&&v>lastL)
                    mpR[v-(lastL+1)].push_back(make_pair(id,u));
                if(u>lastL&&v<=lastL)
                    mpL[v].push_back(make_pair(id,u-(lastL+1)));
            }
            if(cnt==0) {
                S.upd(tot);
                ++tot;
            }
        }
        if(feasible)
            add(res,dfs(lastL,lastR,S));
    }
    // Part 2: orders equal to some splits[i] (zero descents there). Such an order is
    // valid iff every split permutation has at most one descent. Identical split
    // permutations describe the same order, so each distinct one is counted once.
    set<vector<int>> seen;
    rep(i,0,m-1) {
        if(!seen.insert(splits[i]).second)
            continue;
        rep(j,0,n-1)
            q[splits[i][j]]=j;
        bool ok=true;
        rep(k,0,m-1) {
            int cnt=0;
            rep(j,0,n-2)
                cnt+=q[splits[k][j]]>q[splits[k][j+1]];
            if(cnt>1) {
                ok=false;
                break;
            }
        }
        if(ok)
            add(res,1);
    }
    return res;
}
// Reads n and m followed by m split permutations of length n from stdin and
// prints the count computed by solve(n, m, splits).
void solve() {
    int n,m;
    scanf("%d%d",&n,&m);
    vector<vector<int>> vec(m);
    for(auto &row:vec) {
        row.resize(n);
        for(auto &val:row)
            scanf("%d",&val);
    }
    printf("%d\n",solve(n,m,vec));
}
// Address marker paired with M1; see the memory report at the end of main().
bool M2;
// g++ P13694.cpp -Wall -std=c++14 -O2 -o P13694
// Entry point: runs solve() `testcase` times (once unless the commented-out scanf
// is enabled), then prints processor time used and the static-storage span
// between the markers M1 and M2 to stderr.
signed main() {
    int testcase=1;
    // scanf("%d",&testcase);
    while(testcase--)
        solve();
    fprintf(stderr,"used time = %dms\n",(int)(1000*clock()/CLOCKS_PER_SEC));
    // Subtracting pointers to two distinct objects is undefined behaviour, and the
    // linker may place M2 before M1; compare integer addresses and take the
    // absolute span instead.
    const uintptr_t addr1=reinterpret_cast<uintptr_t>(&M1);
    const uintptr_t addr2=reinterpret_cast<uintptr_t>(&M2);
    const uintptr_t span=(addr1<addr2)?addr2-addr1:addr1-addr2;
    fprintf(stderr,"used memory = %dMB\n",(int)(span/1024/1024));
    return 0;
}