题解 T378016 【Polygon】

· · 个人记录

Polygon

给定一个凸多边形的三角剖分,求生成树的个数。

观察到 n 非常大,传统的 \mathtt{Matrix}\text{-}\mathtt{Tree} 做法应该是做不了的。

计数问题考虑 \mathtt{dp},于是我们在原问题里寻找子结构,注意到三角剖分图有一个十分优秀的性质:

我们可以随便找一条边,不妨令这条边为 (1,n),与之构成三角形的点为 u。接下来把 (1,n) 这条边断开,发现整张图变成了两个凸多边形的三角剖分,并且通过点 u 相连。

相似的结构出现了,我们可以用分治来维护这个 \mathtt{dp},分治到 (l,r) 时表示当前是由 l,l+1,\cdots r 这些点构成凸多边形,求出 f_{0/1} 分别表示点 l 与点 r 是否直接通过边 (l,r) 相连时的生成树个数。

\mathtt{dp} 过程中能求出某两个点关于这两个点之间的边直接相连的方案数,若把这条边断开,我们发现这个方案数还等价于这两个点不连通时的生成树个数,这对于转移非常有用。

接下来考虑合并答案,以 (1,n) 为例,讨论 (1,n),(1,u),(u,n) 这三条边的出现情况以及 1,u,n 三个点分别是通过什么方式进行联通。总共有八种情况,转移十分自然,在此不做过多展开。

还有一个问题就是如何找三角形,我的做法是按照顺/逆时针枚举每一个点的所有边,两条相邻的边一定构成一个三角形,可以使用 unordered_map 进行存储,并且可以通过点的顺序来区分包含该条边的至多两个三角形,具体看代码实现。

时间复杂度为 \mathcal{O(n\log n)},空间复杂度为 \mathcal{O}(n),时间复杂度的瓶颈在于找三角形中的排序(\mathtt{dp} 部分的时间复杂度是 \mathcal{O}(n) 的)。

Code:

#include<bits/stdc++.h>
typedef long long ll;
typedef long double ld;
using namespace std;
const int N=200010,M=2000010,p=998244353;
inline int max(int x,int y){return x>y?x:y;}
inline int min(int x,int y){return x<y?x:y;}
inline void swap(int &x,int &y){x^=y^=x^=y;}
int n,top=1;
ll f[N<<2][2];
vector<int> v[N];
unordered_map<int,int> w[N];
void add(int x,int y){
    if(x>y)swap(x,y);
    v[x].push_back(y);
    v[y].push_back(x+N);
    w[x][y]=w[y][x]=N;
}
void solve(int x,int lst,int l,int r){
    if(l+1==r){
        f[x][0]=0,f[x][1]=1;
        return;
    }
    int u=w[l][r];
    if(u==lst||u==N)u=w[r][l];
    int ls=++top,rs=++top;
    solve(ls,r,l,u);
    solve(rs,l,u,r);
    f[x][1]=(f[x][1]+f[ls][1]*f[rs][1]%p)%p;
    f[x][1]=(f[x][1]+f[ls][1]*f[rs][1]%p)%p;
    f[x][1]=(f[x][1]+f[ls][0]*f[rs][1]%p)%p;
    f[x][1]=(f[x][1]+f[ls][1]*f[rs][0]%p)%p;
    f[x][0]=(f[x][0]+f[ls][1]*f[rs][1]%p)%p;
    f[x][0]=(f[x][0]+f[ls][1]*f[rs][0]%p)%p;
    f[x][0]=(f[x][0]+f[ls][0]*f[rs][1]%p)%p;
    f[x][0]=(f[x][0]+f[ls][0]*f[rs][0]%p)%p; 
}
int main(){
    scanf("%d",&n);
    for(int i=1;i<n;i++)
        add(i,i+1);
    add(1,n);
    for(int i=1;i<=n-3;i++){
        int x,y;
        scanf("%d%d",&x,&y);
        add(x,y);
    }
    for(int i=1;i<=n;i++){
        sort(v[i].begin(),v[i].end());
        for(int j=0;j<(int)v[i].size();j++)
            if(v[i][j]>N)v[i][j]-=N;
    }
    for(int i=1;i<=n;i++){
        for(int j=1;j<(int)v[i].size();j++){
            int x=v[i][j-1],y=v[i][j];
            w[x][y]=i;
        }
    }
    solve(1,0,1,n);
    printf("%lld\n",(f[1][0]+f[1][1])%p);
    return 0;
}