题解:CF1616G Just Add an Edge

· · 题解

或许更好的阅读体验

题意

题解

考虑怎么样这张图才能有哈密顿路。

首先一个 DAG 什么时候有哈密顿路,那由于题设的性质,它必然存在一条 1\to 2\to 3\to \dots \to N 的链。此时加边方案数就是 \binom{n}{2}

特判这种情况后,考虑一般情况。容易发现除了加的这条边 x\to y,哈密顿回路的其它地方必定是两条结点编号单调递增点集不相交的链。考虑 1\rightsquigarrow x 叫“第一条链”,y\rightsquigarrow n 叫“第二条链”。容易发现 [1,y-1] 必定由第一条链填满,[x+1,n] 必定由第二条链填满。

那么最终的哈密顿回路必然形如这样:

(1\to 2\to \dots \to y-1)\rightsquigarrow x\to y \rightsquigarrow (x+1\to x+2\to \dots\to n)

其中两个括起来的部分表示每次 +1 的链,y-1\rightsquigarrow xy\rightsquigarrow x+1 是两条交为 \varnothing,并为 [x,y] 的链。

是否存在 (1\to 2\to \dots \to y-1)(x+1\to x+2\to \dots\to n) 的链是好判断的,所以只考虑满足条件的 y-1\rightsquigarrow xy\rightsquigarrow x+1 是否存在即可。

容易发现这个条件形如存在一对 (y-1,y)\rightsquigarrow (x,x+1) 的链,且两条链并起来恰好是 [y-1,x+1]。考虑枚举 y'=y-1 然后计算有多少个 x 满足 (y',y'+1)\rightsquigarrow (x,x+1) 的条件。

由于两条链都只能从小到大延伸,想要覆盖一段连续的区间,那么两条链必定形如这样:

那么考虑枚举对颜色交界点 DP,设 f_{i,0/1} 表示从 (y',y'+1) 开始,能否走到 (i,i+1)(即第二维为 0 表示左红右蓝)或 (i+1,i)(第二维为 1 表示左蓝右红)。

具体转移考虑若 f_{i,t}=1,则枚举 i 的出边 i\to j,若存在 (i+1\to i+2\to\dots \to j-1),则 f_{j-1,\lnot t}\gets 1。这个 (i+1\to \dots \to j-1) 的条件可以考虑存储 r_i 表示 i 能每次加一走到的最大点,那么那个条件就转化成了 j-1\in [i+1,r_{i+1}]

这样就能 O(nm) 的做了。但是这个不好优化啊(而且就算能除以 \omega 也过不了),考虑有没有什么更强力的方法。

容易发现,对于一个 p\nrightarrow p+1i<pf_i 是无法转移到 j-1>pf_{j-1} 的。因为对于 i<p,必有 r_{i+1}\le p,又因为 j-1>p,这显然与 r_{i+1}> j-1 矛盾。

由于需要存在 (1\to \dots \to y-1) 的路径,所以 1\le y\le r_1+1,那么对于 y'=y-1,必有 0\le y' \le r_1。同理由于需要 (x+1\to \dots \to n),所以 lst-1\le x\le n,其中 lst 是最小的 i 满足 r_i=n

此处暂且断一下,注意到 f_{0} 没有定义,但这里能取到 0 是因为哈密顿路的起点不一定是 1,可能根本不存在 (1\to \dots \to y-1) 的路径。不妨建一个虚点,向所有点连边,这样起点就能固定是 0 了。同理建一个虚点 n+1,让所有点向它连边,这样终点也能固定了。

回到刚才的思路,由于 y'\le p,而 x>p(除非 r_1=n,但这个在第一步就特判了),所以对于任何 y,DP 的转移必定要经过 f_{p} 这个点。这就非常强了。

考虑对于 i<p,令 f_{i,0/1} 表示能否从 (i,i+1) 走到 (p,p+1)/(p+1,p),对于 i>p,令 f_{i,0/1} 表示能否从 (p,p+1) 走到 (i,i+1)/(i+1,i)(容易发现这两种定义对于 i=p 是等价的)。

然后路径就转化成了 (y',y'+1)\rightsquigarrow (p,p+1)/(p+1,p)\rightsquigarrow (x,x+1)。容易发现这可以枚举 y'\in[0,p]x\in[lst-1,n] 来做,若 f_{y',0}\land f_{x,0}f_{y',1}\land f_{x,1},则说明存在一条那样的路径。

而且这个非常强力,考虑一个简单的容斥,计算满足 f_{y',0}\land f_{x,0}(y',x) 二元组数量加上满足 f_{y',1}\land f_{x,1}(y',x) 的数量再减去两个都满足的即可。具体计算考虑乘法原理。

注意到一个 bug,根据枚举的方式,假如 p=lst-1 就会多算一次 y'=x=p 的情况,特判 r_{r_{1}+1}=n 的情况减掉即可。

复杂度 O(n)。注意实际实现加了虚点后的边界情况会发生改变。

code

#include<bits/stdc++.h>
#define mem(a,b) memset(a,b,sizeof(a))
#define forup(i,s,e) for(int i=(s);i<=(e);i++)
#define fordown(i,s,e) for(int i=(s);i>=(e);i--)
#ifdef DEBUG
#define msg(args...) fprintf(stderr,args)
#else
#define msg(...) void()
#endif
using namespace std;
using i64=long long;
#define gc getchar()
inline int read(){
    int x=0,f=1;char c;
    while(!isdigit(c=gc)) if(c=='-') f=-1;
    while(isdigit(c)){x=(x<<3)+(x<<1)+(c^48);c=gc;}
    return x*f;
}
#undef gc
const int N=150005,inf=0x3f3f3f3f;
int n,m,a[N],r[N];
int dp[N][2];
vector<int> e[N];
void solve(){
    n=read();m=read();
    forup(i,0,n+1){//注意清空
        e[i].clear();
        r[i]=i;a[i]=0;
        dp[i][0]=dp[i][1]=0;
        if(i>1) e[0].push_back(i);
        if(i<n) e[i].push_back(n+1);
    }
    a[n]=a[0]=1;
    forup(i,1,m){
        int u=read(),v=read();
        if(v==u+1) a[u]=1;
        else e[u].push_back(v);
      //如果 v=u+1 就不加边,因为这种边在转移时 j-1<i+1,这样可以避免转移时特判
    }
    fordown(i,n+1,0){
        if(a[i]) r[i]=r[i+1];
    }
    if(r[0]==n+1){//特判 r[1]=n,注意边界的改变
        printf("%lld\n",1ll*n*(n-1)/2);
        return;
    }
    dp[r[0]][0]=1;
    forup(i,r[0],n+1){
        for(auto j:e[i]){
            if(j-1<=r[i+1]){
                dp[j-1][0]|=dp[i][1];
                dp[j-1][1]|=dp[i][0];
            }
        }
    }
    fordown(i,r[0]-1,0){
        for(auto j:e[i]){
            if(j-1<=r[i+1]){
                dp[i][0]|=dp[j-1][1];
                dp[i][1]|=dp[j-1][0];
            }
        }
    }
    int cx0=0,cx1=0,cy0=0,cy1=0,cx01=0,cy01=0;
    forup(i,0,r[0]){
        cy0+=dp[i][0];
        cy1+=dp[i][1];
        cy01+=dp[i][0]&&dp[i][1];
    }
    int lst=n+1;
    while(r[lst]==n+1) --lst;
    forup(i,lst,n){
        cx0+=dp[i][0];
        cx1+=dp[i][1];
        cx01+=dp[i][0]&&dp[i][1];
    }
    i64 ans=1ll*cx0*cy0+1ll*cx1*cy1-1ll*cx01*cy01;
    if(r[r[0]+1]==n+1) --ans;
    printf("%lld\n",ans);
}
signed main(){
    int t=read();
    while(t--){
        solve();
    }
}