题解 ARC158E All Pair Shortest Paths
一个好想,但巨大麻烦的做法。
先弱化一下问题,如果求一个点到其他所有点的最短路长度之和,这个问题用最短路树可以方便地解决,答案为树上每个节点的权值乘上其子树大小。
比如(树根为第二行第一列的
而现在要对所有点对都求,于是考虑强行递推最短路树的形态。
根从第
或者
我们只需要对每个点快速找到它后面的
为方便下面的分析,给出几个定义
那么第一种情况需要满足:
可以拿单调栈找到满足条件的第一个
第二种情况需要满足:
两边形式不同,只能拿 ST 表查区间
然后两种情况的
求出
具体来说:
-
如果子根取第一行第
j 列,那么为:\begin{aligned} res[1][i]=&\ \ \ \ \ 2\cdot (n-i+1)\cdot a[1][i]\ (根的贡献)\\ &+res[1][j]\ (子根的贡献)\\ &+\sum_{k=i+1}^{j-1}a[1][k]\cdot (2\cdot (n-j+1)+j-k)\ (第一行的贡献)\\ &+\sum_{k=i}^{j-1}a[2][k]\cdot(j-k)\ (第二行的贡献) \end{aligned} -
如果子根取第二行第
j 列,那么为:\begin{aligned} res[1][i]=&\ \ \ \ \ 2\cdot (n-i+1)\cdot a[1][i]\ (根的贡献)\\ &+res[2][j]\ (子根的贡献)\\ &+\sum_{k=i+1}^{j-1}a[1][k]\cdot (j-k)\ (第一行的贡献)\\ &+\sum_{k=i}^{j-1}a[2][k]\cdot(2\cdot (n-j+1)+j-k)\ (第二行的贡献) \end{aligned}
使用预处理出的
然后根在第二行的情况是完全一样的,一/二行取反就可以了。
最终的贡献由于要考虑前后与算重,最后答案化出来是:
时间是
代码细节较多
typedef long long ll;
#define F first
#define S second
const int N=2e5+5,mod=998244353;
ll pre[3][N],mx[3][N][20];
int p[3][N],a[3][N],nex[3][N][3],res[3][N],Log[N];
int n,ans,top;
pair<ll,int> stk[N];
int Sum(int d, int l, int r, int k){
return (1ll*(r+k)*((pre[d][r]-pre[d][l-1])%mod)+p[d][r]+mod-p[d][l-1])%mod;
}
ll query(int d, int l, int r){
int k=Log[r-l+1];
return max(mx[d][l][k],mx[d][r-(1<<k)+1][k]);
}
int main(){
scanf("%d",&n);
for (int k=1; k<=2; k++) for (int i=1; i<=n; i++){
scanf("%d",&a[k][i]);
pre[k][i]=pre[k][i-1]+a[k][i];
p[k][i]=(p[k][i-1]+1ll*(mod-i+1)*a[k][i])%mod;
}
stk[top=1]={-1e18,n+1};
for (int i=n; i>=1; i--){
ll val=pre[1][i]-pre[2][i-1];
while (top && stk[top].F>val) top--;
nex[1][i][2]=stk[top].S; stk[++top]={val,i};
}
stk[top=1]={-1e18,n+1};
for (int i=n; i>=1; i--){
ll val=pre[2][i]-pre[1][i-1];
while (top && stk[top].F>val) top--;
nex[2][i][1]=stk[top].S; stk[++top]={val,i};
}
Log[0]=-1;
for (int i=1; i<=n; i++){
Log[i]=Log[i>>1]+1;
mx[1][i][0]=pre[1][i-1]-pre[2][i];
mx[2][i][0]=pre[2][i-1]-pre[1][i];
}
for (int i=n; i>=1; i--)
for (int j=1; i+(1<<j)-1<=n; j++){
mx[1][i][j]=max(mx[1][i][j-1],mx[1][i+(1<<(j-1))][j-1]);
mx[2][i][j]=max(mx[2][i][j-1],mx[2][i+(1<<(j-1))][j-1]);
}
for (int i=1; i<=n; i++){
int l=i+1,r=n; nex[1][i][1]=n+1;
while (l<=r){
int mid=(l+r)>>1;
if (query(1,i+1,mid)>pre[1][i]-pre[2][i-1]) nex[1][i][1]=mid,r=mid-1;
else l=mid+1;
}
l=i+1,r=n; nex[2][i][2]=n+1;
while (l<=r){
int mid=(l+r)>>1;
if (query(2,i+1,mid)>pre[2][i]-pre[1][i-1]) nex[2][i][2]=mid,r=mid-1;
else l=mid+1;
}
}
for (int i=n; i>=1; i--){
int d=min(nex[1][i][1],nex[1][i][2]);
if (d==nex[1][i][1])
res[1][i]=(1ll*res[2][d]+2ll*a[1][i]*(n-i+1)+Sum(1,i+1,d-1,0)+Sum(2,i,d-1,2*(n-d+1)))%mod; else
res[1][i]=(1ll*res[1][d]+2ll*a[1][i]*(n-i+1)+Sum(1,i+1,d-1,2*(n-d+1))+Sum(2,i,d-1,0))%mod;
d=min(nex[2][i][1],nex[2][i][2]);
if (d==nex[2][i][1])
res[2][i]=(1ll*res[2][d]+2ll*a[2][i]*(n-i+1)+Sum(1,i,d-1,0)+Sum(2,i+1,d-1,2*(n-d+1)))%mod; else
res[2][i]=(1ll*res[1][d]+2ll*a[2][i]*(n-i+1)+Sum(1,i,d-1,2*(n-d+1))+Sum(2,i+1,d-1,0))%mod;
ans=(1ll*ans+2ll*res[1][i]+2ll*res[2][i])%mod;
ans=(1ll*ans+mod-3ll*(a[1][i]+a[2][i])%mod)%mod;
}
printf("%d\n",ans);
}