题解 P4930 【「FJ2014集训」采药人的路径】

Shikita

2019-03-19 16:40:07

Solution

# 树上路径统计 点分治 点分治是用来解决树上路径问题的一种方法。 在解决树上路径问题时,我们可以选取一点为根,将树转化为有根树,然后考虑经过根的所有路径(有时将两条从根出发的路径连接为一条)。统计完这些路径的答案后,将根节点标记为删除,对剩下的若干棵树进行同样的操作。 每次确认一个根节点后,共有$n$条需要考虑的路径($n$为当前子树大小)。若将当前根节点删除后,剩下某棵的子树较大,和原树大小相当,继续处理这棵子树时仍然需要与前一过程相当的时间。 最严重的情况,当整棵树是一条链时,每次需要考虑的路径数量是$O(n)$级别的,如果每条路径需要常数时间进行统计,则总时间复杂度为$O(n^2)$ 而对于形态随机的树,则远远小于这个级别。 所以我们应该合理选取树的根节点 而树的重心能够很好满足我们的需求 我们定义一棵树的重心为以该点为根时最大子树最小的点。 那么我们每次删除掉当前根节点后,剩余的子树能够保证尽量平衡,这就可以大大减小复杂度 ``` //mx[x] 记录的是x的子树中最大的子树大小 inline void getroot(int x,int fa) { size[x]=1,mx[x]=0; for(int i=head[x];i;i=e[i].next) if(v!=fa&&!vis[v]) { getroot(v,x); size[x]+=size[v]; mx[x]=max(mx[x],size[v]); } mx[x]=max(mx[x],sum-size[x]); if(mx[x]<mx[rt]) rt=x; } ``` ~~我觉得你们自己看代码应该好懂,不解释了~~ 那么对于本题,我们将阴阳点分别置为1和-1,求路径上有一个0点且终点为0的路径条数 我们令$g[dis][0/1]$表示之前dfs过的子树中,到根的距离为$dis$的路径数目 $f$数组表示当前正在dfs的子树,到根的距离为$dis$的路径数目 0/1表示其中是否出现过前缀和为0的点。 ``` #include<bits/stdc++.h> #define ll long long using namespace std; inline int read() { int x=0; char c=getchar(); bool flag=0; while(c<'0'||c>'9') {if(c=='-')flag=1;c=getchar();} while(c>='0'&&c<='9'){x=(x+(x<<2)<<1)+ c-'0';c=getchar();} return flag?-x:x; } const int N=1e5+5; int n,sum,rt,mxdep; int vis[N],t[N<<1],mx[N],size[N],dep[N],dis[N]; ll ans,g[N<<1][2],f[N<<1][2]; int head[N],cnt; struct node{int to,next,w;}e[N<<1]; inline void add(int u,int v,int w) { e[++cnt].to=v,e[cnt].next=head[u],e[cnt].w=w,head[u]=cnt; e[++cnt].to=u,e[cnt].next=head[v],e[cnt].w=w,head[v]=cnt; } //mx[x] x的最大子树 #define v e[i].to inline void getroot(int x,int fa) { size[x]=1,mx[x]=0; for(int i=head[x];i;i=e[i].next) if(v!=fa&&!vis[v]) { getroot(v,x); size[x]+=size[v]; mx[x]=max(mx[x],size[v]); } mx[x]=max(mx[x],sum-size[x]); if(mx[x]<mx[rt]) rt=x; } //找重心 inline void dfs(int x,int fa) { mxdep=max(mxdep,dep[x]); (t[dis[x]])?++f[dis[x]][1]:++f[dis[x]][0]; ++t[dis[x]]; for(int i=head[x];i;i=e[i].next) if(v!=fa&&!vis[v]) { dep[v]=dep[x]+1; dis[v]=dis[x]+e[i].w; dfs(v,x); } --t[dis[x]]; } inline void solve(int x) { int mx=0;vis[x]=g[n][0]=1; for(int i=head[x];i;i=e[i].next) if(!vis[v]) { dis[v]=n+e[i].w;dep[v]=1;mxdep=1; dfs(v,0);mx=max(mx,mxdep); ans+=(g[n][0]-1)*f[n][0]; for(int j=-mxdep;j<=mxdep;++j) ans+=g[n-j][1]*f[n+j][1]+g[n-j][0]*f[n+j][1]+g[n-j][1]*f[n+j][0]; for(int j=n-mxdep;j<=n+mxdep;++j) { g[j][0]+=f[j][0]; g[j][1]+=f[j][1]; f[j][0]=f[j][1]=0; } } for(int i=n-mx;i<=n+mx;++i) g[i][0]=g[i][1]=0; for(int i=head[x];i;i=e[i].next) if(!vis[v]) { sum=size[v];rt=0;getroot(v,0);solve(rt); } } #undef v int main() { n=read(); for(int i=1;i<n;++i) { int u=read(),v=read(),w=read(); if(!w) w=-1; add(u,v,w); } sum=mx[0]=n; getroot(1,0); solve(rt); printf("%d\n",ans); } ```