8354(写完了)

· · 题解

以下,为了方便表述,称原多边形的边为边,边上等分出的线段叫做线,最终的三角剖分中,在同一个三角形中的若干同一条边上的线组成的东西叫段。

首先,普通的多边形三角剖分的方案数是卡特兰数。

本题类似于对 s=\sum a_i 条线进行三角剖分,但是:

  1. 不能连接处于同一条边上的两个点。
  2. 同一条边上的若干条线可以处于同一个三角形中。

首先想办法处理掉第二个条件。可以对于每条边枚举上面的线最终会被分成几段,方案数是一个组合数的形式。后续的过程中,不同的段就不能分到同一个三角形中了。

接下来,考虑处理第一个条件。考虑容斥,钦定一些同一条边上的点对被连接了,来考虑一下容斥系数:

  1. 对于跨过一条线的点对,系数为 1
  2. 对于跨过两条线的点对,系数为 -1,因为钦定了一条非法的线段。
  3. 对于跨过三条及以上线的点对,系数为 0。因为若其被连接,那么在其内部一定还有若干点对被连接。(这一点不理解的可以看后面的补充。)

每条边留下几段,这在边边之间是相互独立的。因此我们可以得到一个 s^2 的算法:对于每条边,分别求出其钦定后留下 i 段的容斥系数之和。然后就可以算出将整个多边形剖为 i 段的系数之和,每一项乘上 i 边形的三角剖分方案数即是答案。

void solve(){
    init();
    cin>>n;
    int sum=0;
    dp[0][0]=1;
    for(int i=0;i<N;i++)
        for(int j=0;j<N-2;j++){
            add(dp[i+1][j+1],dp[i][j]);
            add(dp[i+1][j+2],-dp[i][j]);
        }
    f[0]=1;
    for(int i=1;i<=n;i++){
        cin>>a[i];
        memcpy(g,f,sizeof(g));
        memset(f,0,sizeof(f));
        memset(h,0,sizeof(h));
        for(int j=1;j<=a[i];j++)// 枚举此边分为了 j 段(处理第二个条件)
            for(int k=1;k<=j;k++)// 枚举此段钦定后还有 k 段(处理第一个条件)
                add(h[k],C(a[i]-1,j-1)*dp[k][j]);
        for(int k=1;k<=a[i];k++)
            for(int l=0;l<=sum;l++)
                add(f[l+k],h[k]*g[l]%mod);
        sum+=a[i];
    }
    int res=0;
    for(int i=1;i<=sum;i++)
        add(res,f[i]*Ca(i-2));
    cout<<(res%mod+mod)%mod<<endl;
}

现在可以考虑优化这个算法了。对于后面那个把 h 卷起来那个过程,可以拿个堆按大小排序从小到大乘起来就好。现在主要问题是怎么算这个 h 数组。

首先把模型转化一下,把两个条件综合在一起。n 条线有 n-1 个分界点,把这些点分成三类:

  1. 连接点,代表前后的线最终属于同一段。
  2. 断点,代表前后的线最终既不属于同一段,也没有被某个钦定非法的线段跨过。
  3. 非法点,代表前后的线最终不属于同一段,但是其前后的段被钦定非法的线段跨过了。

最终钦定后的段数就是断点的个数,容斥系数就是 -1 的非法点个数次方。

据此,可以将段根据是否存在非法点来分成两类。两个都存在非法点的段是不能用连接点连起来的。

那么我们设 dp_{i,0/1,0/1,j} 代表左边的段是否有非法点,右边的段是否有非法点,当前有 i 条线,最终划分为 j 段,容斥系数之和是多少。然后考虑用倍增的方法来算。每次可以 O(n\log n) 地从 dp_{i} 推出 dp_{2i};也可以 O(n) 地从 dp_{i} 推出 dp_{i+1}。具体的方式是枚举当前点是哪一类,然后转移。注意会有一种无法统计进去的特殊情况,即所有点都是连接点,这样左右是同一个段,此时需要特殊处理。

#include<bits/stdc++.h>
#define int long long
const int N=1000005,mod=998244353;
const int g=3;
using namespace std;
using vi = vector<int>;

void add(int&x,int y){
    x+=y;if(x>=mod||x<-mod)x%=mod;
}

int gsc(int a,int b){
    int res=1;
    for(int i=1;i<=b;i<<=1,a=a*a%mod)
        if(b&i)res=res*a%mod;
    return res;
}int inv(int k){return gsc(k,mod-2);}

int jc[N],ij[N],iv[N];
int C(int n,int m){
    if(n<m||m<0)return 0;
    return jc[n]*ij[m]%mod*ij[n-m]%mod;
}
void init(){
    jc[0]=iv[0]=1;
    for(int i=1;i<N;i++)jc[i]=jc[i-1]*i%mod;
    ij[N-1]=inv(jc[N-1]);
    for(int i=N-2;i>=0;i--)ij[i]=ij[i+1]*(i+1)%mod;
    for(int i=1;i<N;i++)iv[i]=ij[i]*jc[i-1]%mod;
}

vector<int> A[N];
int lim,r[N];
void initntt(int n){
    lim=1;while(lim<n)lim<<=1;
    for(int i=0;i<lim;i++)
        r[i]=(r[i>>1]>>1)|((i&1)*(lim>>1));
}
void ntt(vi&a,int ty=1){
    a.resize(lim);
    for(int i=0;i<lim;i++)if(r[i]<=i)swap(a[i],a[r[i]]);
    int x,y;
    for(int i=1;i<lim;i<<=1){
        int o=gsc(ty==1?g:inv(g),mod/(i*2)),x,y;
        for(int j=0;j<lim;j+=i<<1)
            for(int k=0,w=1;k<i;k++,w=w*o%mod)
                x=a[j+k],y=a[j+k+i]*w,a[j+k]=(x+y)%mod,a[j+k+i]=(x-y)%mod;
    }
    if(ty==-1){
        int iv=inv(lim);
        for(int i=0;i<lim;i++)a[i]=(a[i]*iv%mod+mod)%mod;
    }
}

vi operator * (vi a,vi b){
    initntt(a.size()+b.size()-1);
    ntt(a),ntt(b);
    for(int i=0;i<lim;i++)a[i]=a[i]*b[i]%mod;
    ntt(a,-1);
    while(a.size()&&a.back()==0)a.pop_back();
    return a;
}
vi operator + (vi a,vi b){
    if(a.size()<b.size())swap(a,b);
    for(int i=0;i<b.size();i++)add(a[i],b[i]);
    return a;
}
void operator +=(vi&a,vi b){a=a+b;}
vi operator - (vi a){for(auto&x:a)x=-x;return a;}
vi yy(vi a){if(a.size())a.erase(a.begin());return a;}
vi zy(vi a){a.insert(a.begin(),0);return a;}
struct cmp{bool operator ()(vi a,vi b){return a.size()>b.size();}};
vi Mul(vector<vi> A){
    priority_queue< vi,vector<vi>,cmp > Q;
    for(auto x:A)Q.push(x);
    while(Q.size()>1){
        auto x=Q.top();Q.pop();
        auto y=Q.top();Q.pop();
        Q.push(x*y);
    }
    return Q.top();
}

struct Res{vi a[2][2];auto operator[](int b){return a[b];}}tem[N];
int vis[N];
Res geth(int k){
    if(k==1){Res res;return res;}
    if(vis[k])return tem[k];
    int fl=0;
    vis[k]=1;
    if(k&1)k--,fl=1;
    vis[k]=1;
    Res L=geth(k/2);
    Res res;
    auto D=-L[0][1]*L[0][1];
    for(int i=0;i<2;i++)
        for(int j=i;j<2;j++){
            auto A=(L[i][0]+L[i][1])*(L[0][j]+L[1][j]);
            auto B=(i==0&&j==0)?D:-L[i][1]*L[1][j];
            auto C=(i==1&&j==1)?D:-L[i][0]*L[0][j];
            res[i][j]+=A;
            res[i][j]+=yy(A);
            res[i][j]+=yy(C);
            res[i][j]+=yy(B);
        }
    res[1][0]=res[0][1];
    for(int i=0;i<2;i++){
        res[i][0]+=L[i][0]+zy(L[i][0])+zy(L[i][1]);
        res[i][1]+=-L[i][0]+L[i][1];
        res[0][i]+=L[0][i]+zy(L[0][i])+zy(L[1][i]);
        res[1][i]+=-L[0][i]+L[1][i];
    }
    res[0][0]+=(vi){0,0,1};
    res[1][1]+=(vi){0,-1};
    tem[k]=res;
    if(fl){
        Res rr;
        for(int i=0;i<2;i++){
            rr[i][0]+=res[i][0]+zy(res[i][0])+zy(res[i][1]);
            rr[i][1]+=-res[i][0]+res[i][1];
        }
        rr[0][0]+=(vi){0,0,1};
        rr[1][1]+=(vi){0,-1};
        res=rr;
    }
    return tem[k+fl]=res;
}
vector<int> getH(int x){
    Res r=geth(x);
    vector<int> a(x+1);a[1]=1;
    for(int i=0;i<2;i++)for(int j=0;j<2;j++)
        for(int k=0;k<r[i][j].size();k++)
            add(a[k],r[i][j][k]);
    return a;
}

int n;
int a[N];

int Ca(int k){
    if(k<0)return 0;
    return C(2*k,k)*iv[k+1]%mod;
}

int h[N],f[N];
void solve(){
    init();
    int n,s=0;cin>>n;
    vector<vi> A;
    for(int i=1;i<=n;i++){
        int x;cin>>x;s+=x;
        A.push_back(getH(x));
    }
    auto res=Mul(A);
    int result=0;
    for(int i=3;i<res.size();i++)
        add(result,res[i]*Ca(i-2));
    cout<<(result%mod+mod)%mod<<endl;
}

main(){
    ios::sync_with_stdio(0);
    solve();
}

补充一下为什么跨过三条线的边容斥系数是 0。考虑其内的任意一种三角剖分,都包含有至少 1 条跨过两条线的边,假设有 k 条。考虑这 k 条边是否钦定,会有 2^k 种,其中有一半系数为 1,其余为 -1,总和是 0