P9481 [NOI2023] 贸易 题解
Gaode_Sean
·
·
题解
近五年来最简单的 D2T1。
不难发现,所有点对的贡献都可以被拆分成向上到 lca 再由 lca 到目标点的两段,即 \text{dist}(x,y)=\text{dist}(x,\text{lca}(x,y))+\text{dist}(\text{lca}(x,y),y)。
所以只要求每个点到自己子树内所有点的最短路就可以快速算贡献。这里我用的是 dijkstra 求单源最短路。
对最短路有贡献的点只有该点的祖先和后代,所以把这些点提出来单独跑最短路即可。
每个点的复杂度是 \mathcal{O}(siz \log m) 的,而 \sum siz=2^nn,所以总复杂度为
## AC Code
```cpp
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const ll mod=998244353;
const int N=5e5+5,M=1e6+5;
int n,m;
int tot,ver[M],nxt[M],head[N],num[N];
ll edge[M],d[N],ans,sum[N];
bool v[N],vis[N];
priority_queue< pair<ll,int> > q;
vector<int> vec[N];
vector< pair<int,ll> > pos[N];
void add(int x,int y,ll z){ver[++tot]=y,edge[tot]=z,nxt[tot]=head[x],head[x]=tot;}
void dfs(int x)
{
vec[x].push_back(x);
if(num[x]==n) return;
int xa=x<<1,xb=xa+1;
dfs(xa); dfs(xb);
vis[x]=1;
for(int i=0,val;i<vec[xa].size();i++) val=vec[xa][i],vec[x].push_back(val),vis[val]=1,v[val]=0,d[val]=1e18;
for(int i=0,val;i<vec[xb].size();i++) val=vec[xb][i],vec[x].push_back(val),vis[val]=1,v[val]=0,d[val]=1e18;
q.push(make_pair(0,x));
ll wa=0,wb=0,wa_=0,wb_=0,ca=0,cb=0,ca_=0,cb_=0;
for(int i=1;i<vec[x].size();i++)
{
int y=vec[x][i];
for(int j=0;j<pos[y].size();j++) if(pos[y][j].first<x){d[y]=min(d[y],sum[x]-sum[pos[y][j].first]+pos[y][j].second);q.push(make_pair(-d[y],y));}
}
while(!q.empty())
{
int now=q.top().second; q.pop();
if(v[now]) continue;
v[now]=1;
for(int i=head[now];i;i=nxt[i])
{
if(!vis[ver[i]]) continue;
int y=ver[i]; ll z=edge[i];
if(d[now]+z<d[y]){d[y]=d[now]+z; q.push(make_pair(-d[y],y));}
}
}
v[x]=0;
for(int i=0,val;i<vec[xa].size();i++) val=vec[xa][i],wa=(wa+(d[val]==1e18?0:d[val]))%mod,ca=(ca+((d[val]==1e18)?0:1))%mod,v[val]=0,d[val]=1e18;
for(int i=0,val;i<vec[xb].size();i++) val=vec[xb][i],wb=(wb+(d[val]==1e18?0:d[val]))%mod,cb=(cb+((d[val]==1e18)?0:1))%mod,v[val]=0,d[val]=1e18;
for(int i=0,val;i<vec[xa].size();i++) val=vec[xa][i],wa_=(wa_+sum[val]-sum[x])%mod,ca_++,vis[val]=0;
for(int i=0,val;i<vec[xb].size();i++) val=vec[xb][i],wb_=(wb_+sum[val]-sum[x])%mod,cb_++,vis[val]=0;
ans=(ans+wa+wb+wa_+wb_+wa*cb_%mod+wb*ca_%mod+wa_*cb%mod+wb_*ca%mod)%mod;
}
int main()
{
scanf("%d%d",&n,&m);
for(int i=1;i<=(1<<n)-2;i++)
{
ll z; scanf("%lld",&z);
add(i+1,(i+1)>>1,z);
num[i]=num[i>>1]+1,sum[i+1]=sum[(i+1)>>1]+z;
}
for(int i=1,x,y;i<=m;i++){ll z; scanf("%d%d%lld",&x,&y,&z); add(x,y,z); pos[y].push_back(make_pair(x,z));}
num[(1<<n)-1]=n;
dfs(1);
printf("%lld\n",ans);
return 0;
}