题解 ARC158E All Pair Shortest Paths

· · 题解

一个好想,但巨大麻烦的做法。

先弱化一下问题,如果求一个点到其他所有点的最短路长度之和,这个问题用最短路树可以方便地解决,答案为树上每个节点的权值乘上其子树大小。

比如(树根为第二行第一列的 15):

而现在要对所有点对都求,于是考虑强行递推最短路树的形态。

根从第 n 列扫到第 1 列,先讨论在第一行的情况,最短路树有两种可能的形态:

或者

我们只需要对每个点快速找到它后面的 j,就可以进行递推。

为方便下面的分析,给出几个定义

那么第一种情况需要满足:

\begin{aligned} p[1][j]-p[1][i]<p[2][j-1]-p[2][i-1]\\ \Rightarrow p[2][i-1]-p[1][i]<p[2][j-1]-p[1][j] \end{aligned}

可以拿单调栈找到满足条件的第一个 j

第二种情况需要满足:

\begin{aligned} p[2][j]-p[2][i-1]<p[1][j-1]-p[1][i]\\ \Rightarrow p[1][i]-p[2][i-1]<p[1][j-1]-p[2][j] \end{aligned}

两边形式不同,只能拿 ST 表查区间 \max + 二分来找(捂脸)。

然后两种情况的 j 取个 \min 就是真实的 j 了(如果不存在,则令 j=n+1)。

求出 j 之后进行递推,即第 j 列之后的子问题,加上 i\sim j-1 列的贡献。

具体来说:

  1. 如果子根取第一行第 j 列,那么为:

    \begin{aligned} res[1][i]=&\ \ \ \ \ 2\cdot (n-i+1)\cdot a[1][i]\ (根的贡献)\\ &+res[1][j]\ (子根的贡献)\\ &+\sum_{k=i+1}^{j-1}a[1][k]\cdot (2\cdot (n-j+1)+j-k)\ (第一行的贡献)\\ &+\sum_{k=i}^{j-1}a[2][k]\cdot(j-k)\ (第二行的贡献) \end{aligned}
  2. 如果子根取第二行第 j 列,那么为:

    \begin{aligned} res[1][i]=&\ \ \ \ \ 2\cdot (n-i+1)\cdot a[1][i]\ (根的贡献)\\ &+res[2][j]\ (子根的贡献)\\ &+\sum_{k=i+1}^{j-1}a[1][k]\cdot (j-k)\ (第一行的贡献)\\ &+\sum_{k=i}^{j-1}a[2][k]\cdot(2\cdot (n-j+1)+j-k)\ (第二行的贡献) \end{aligned}

使用预处理出的 ps 不难 O(1) 求出。

然后根在第二行的情况是完全一样的,一/二行取反就可以了。

最终的贡献由于要考虑前后与算重,最后答案化出来是:

ans=\sum_{i=1}^n(2\cdot (res[1][i]+res[2][i])-3\cdot (a[1][i]+a[2][i]))

时间是 O(n\log n),瓶颈在于 ST 表 + 二分。

代码细节较多

typedef long long ll;
#define F first
#define S second
const int N=2e5+5,mod=998244353;
ll pre[3][N],mx[3][N][20];
int p[3][N],a[3][N],nex[3][N][3],res[3][N],Log[N];
int n,ans,top;
pair<ll,int> stk[N];
int Sum(int d, int l, int r, int k){
    return (1ll*(r+k)*((pre[d][r]-pre[d][l-1])%mod)+p[d][r]+mod-p[d][l-1])%mod;
}
ll query(int d, int l, int r){
    int k=Log[r-l+1];
    return max(mx[d][l][k],mx[d][r-(1<<k)+1][k]);
}
int main(){
    scanf("%d",&n);
    for (int k=1; k<=2; k++) for (int i=1; i<=n; i++){
        scanf("%d",&a[k][i]);
        pre[k][i]=pre[k][i-1]+a[k][i];
        p[k][i]=(p[k][i-1]+1ll*(mod-i+1)*a[k][i])%mod;
    }
    stk[top=1]={-1e18,n+1};
    for (int i=n; i>=1; i--){
        ll val=pre[1][i]-pre[2][i-1];
        while (top && stk[top].F>val) top--;
        nex[1][i][2]=stk[top].S; stk[++top]={val,i};
    }
    stk[top=1]={-1e18,n+1};
    for (int i=n; i>=1; i--){
        ll val=pre[2][i]-pre[1][i-1];
        while (top && stk[top].F>val) top--;
        nex[2][i][1]=stk[top].S; stk[++top]={val,i};
    }
    Log[0]=-1;
    for (int i=1; i<=n; i++){
        Log[i]=Log[i>>1]+1;
        mx[1][i][0]=pre[1][i-1]-pre[2][i];
        mx[2][i][0]=pre[2][i-1]-pre[1][i];
    }
    for (int i=n; i>=1; i--)
        for (int j=1; i+(1<<j)-1<=n; j++){
            mx[1][i][j]=max(mx[1][i][j-1],mx[1][i+(1<<(j-1))][j-1]);
            mx[2][i][j]=max(mx[2][i][j-1],mx[2][i+(1<<(j-1))][j-1]);
        }
    for (int i=1; i<=n; i++){
        int l=i+1,r=n; nex[1][i][1]=n+1;
        while (l<=r){
            int mid=(l+r)>>1;
            if (query(1,i+1,mid)>pre[1][i]-pre[2][i-1]) nex[1][i][1]=mid,r=mid-1;
            else l=mid+1;
        }
        l=i+1,r=n; nex[2][i][2]=n+1;
        while (l<=r){
            int mid=(l+r)>>1;
            if (query(2,i+1,mid)>pre[2][i]-pre[1][i-1]) nex[2][i][2]=mid,r=mid-1;
            else l=mid+1;
        }
    }
    for (int i=n; i>=1; i--){
        int d=min(nex[1][i][1],nex[1][i][2]);
        if (d==nex[1][i][1])
            res[1][i]=(1ll*res[2][d]+2ll*a[1][i]*(n-i+1)+Sum(1,i+1,d-1,0)+Sum(2,i,d-1,2*(n-d+1)))%mod; else
            res[1][i]=(1ll*res[1][d]+2ll*a[1][i]*(n-i+1)+Sum(1,i+1,d-1,2*(n-d+1))+Sum(2,i,d-1,0))%mod;
        d=min(nex[2][i][1],nex[2][i][2]);
        if (d==nex[2][i][1])
            res[2][i]=(1ll*res[2][d]+2ll*a[2][i]*(n-i+1)+Sum(1,i,d-1,0)+Sum(2,i+1,d-1,2*(n-d+1)))%mod; else
            res[2][i]=(1ll*res[1][d]+2ll*a[2][i]*(n-i+1)+Sum(1,i,d-1,2*(n-d+1))+Sum(2,i+1,d-1,0))%mod;
        ans=(1ll*ans+2ll*res[1][i]+2ll*res[2][i])%mod;
        ans=(1ll*ans+mod-3ll*(a[1][i]+a[2][i])%mod)%mod;
    }
    printf("%d\n",ans);
}