P9481 [NOI2023] 贸易 题解

· · 题解

近五年来最简单的 D2T1。

不难发现,所有点对的贡献都可以被拆分成向上到 lca 再由 lca 到目标点的两段,即 \text{dist}(x,y)=\text{dist}(x,\text{lca}(x,y))+\text{dist}(\text{lca}(x,y),y)

所以只要求每个点到自己子树内所有点的最短路就可以快速算贡献。这里我用的是 dijkstra 求单源最短路。

对最短路有贡献的点只有该点的祖先和后代,所以把这些点提出来单独跑最短路即可。 每个点的复杂度是 \mathcal{O}(siz \log m) 的,而 \sum siz=2^nn,所以总复杂度为

## AC Code ```cpp #include<bits/stdc++.h> using namespace std; typedef long long ll; const ll mod=998244353; const int N=5e5+5,M=1e6+5; int n,m; int tot,ver[M],nxt[M],head[N],num[N]; ll edge[M],d[N],ans,sum[N]; bool v[N],vis[N]; priority_queue< pair<ll,int> > q; vector<int> vec[N]; vector< pair<int,ll> > pos[N]; void add(int x,int y,ll z){ver[++tot]=y,edge[tot]=z,nxt[tot]=head[x],head[x]=tot;} void dfs(int x) { vec[x].push_back(x); if(num[x]==n) return; int xa=x<<1,xb=xa+1; dfs(xa); dfs(xb); vis[x]=1; for(int i=0,val;i<vec[xa].size();i++) val=vec[xa][i],vec[x].push_back(val),vis[val]=1,v[val]=0,d[val]=1e18; for(int i=0,val;i<vec[xb].size();i++) val=vec[xb][i],vec[x].push_back(val),vis[val]=1,v[val]=0,d[val]=1e18; q.push(make_pair(0,x)); ll wa=0,wb=0,wa_=0,wb_=0,ca=0,cb=0,ca_=0,cb_=0; for(int i=1;i<vec[x].size();i++) { int y=vec[x][i]; for(int j=0;j<pos[y].size();j++) if(pos[y][j].first<x){d[y]=min(d[y],sum[x]-sum[pos[y][j].first]+pos[y][j].second);q.push(make_pair(-d[y],y));} } while(!q.empty()) { int now=q.top().second; q.pop(); if(v[now]) continue; v[now]=1; for(int i=head[now];i;i=nxt[i]) { if(!vis[ver[i]]) continue; int y=ver[i]; ll z=edge[i]; if(d[now]+z<d[y]){d[y]=d[now]+z; q.push(make_pair(-d[y],y));} } } v[x]=0; for(int i=0,val;i<vec[xa].size();i++) val=vec[xa][i],wa=(wa+(d[val]==1e18?0:d[val]))%mod,ca=(ca+((d[val]==1e18)?0:1))%mod,v[val]=0,d[val]=1e18; for(int i=0,val;i<vec[xb].size();i++) val=vec[xb][i],wb=(wb+(d[val]==1e18?0:d[val]))%mod,cb=(cb+((d[val]==1e18)?0:1))%mod,v[val]=0,d[val]=1e18; for(int i=0,val;i<vec[xa].size();i++) val=vec[xa][i],wa_=(wa_+sum[val]-sum[x])%mod,ca_++,vis[val]=0; for(int i=0,val;i<vec[xb].size();i++) val=vec[xb][i],wb_=(wb_+sum[val]-sum[x])%mod,cb_++,vis[val]=0; ans=(ans+wa+wb+wa_+wb_+wa*cb_%mod+wb*ca_%mod+wa_*cb%mod+wb_*ca%mod)%mod; } int main() { scanf("%d%d",&n,&m); for(int i=1;i<=(1<<n)-2;i++) { ll z; scanf("%lld",&z); add(i+1,(i+1)>>1,z); num[i]=num[i>>1]+1,sum[i+1]=sum[(i+1)>>1]+z; } for(int i=1,x,y;i<=m;i++){ll z; scanf("%d%d%lld",&x,&y,&z); add(x,y,z); pos[y].push_back(make_pair(x,z));} num[(1<<n)-1]=n; dfs(1); printf("%lld\n",ans); return 0; }