题解 P4930 【「FJ2014集训」采药人的路径】
Shikita
2019-03-19 16:40:07
# 树上路径统计 点分治
点分治是用来解决树上路径问题的一种方法。
在解决树上路径问题时,我们可以选取一点为根,将树转化为有根树,然后考虑经过根的所有路径(有时将两条从根出发的路径连接为一条)。统计完这些路径的答案后,将根节点标记为删除,对剩下的若干棵树进行同样的操作。
每次确认一个根节点后,共有$n$条需要考虑的路径($n$为当前子树大小)。若将当前根节点删除后,剩下某棵的子树较大,和原树大小相当,继续处理这棵子树时仍然需要与前一过程相当的时间。
最严重的情况,当整棵树是一条链时,每次需要考虑的路径数量是$O(n)$级别的,如果每条路径需要常数时间进行统计,则总时间复杂度为$O(n^2)$
而对于形态随机的树,则远远小于这个级别。
所以我们应该合理选取树的根节点
而树的重心能够很好满足我们的需求
我们定义一棵树的重心为以该点为根时最大子树最小的点。
那么我们每次删除掉当前根节点后,剩余的子树能够保证尽量平衡,这就可以大大减小复杂度
```
//mx[x] 记录的是x的子树中最大的子树大小
inline void getroot(int x,int fa)
{
size[x]=1,mx[x]=0;
for(int i=head[x];i;i=e[i].next)
if(v!=fa&&!vis[v])
{
getroot(v,x);
size[x]+=size[v];
mx[x]=max(mx[x],size[v]);
}
mx[x]=max(mx[x],sum-size[x]);
if(mx[x]<mx[rt]) rt=x;
}
```
~~我觉得你们自己看代码应该好懂,不解释了~~
那么对于本题,我们将阴阳点分别置为1和-1,求路径上有一个0点且终点为0的路径条数
我们令$g[dis][0/1]$表示之前dfs过的子树中,到根的距离为$dis$的路径数目
$f$数组表示当前正在dfs的子树,到根的距离为$dis$的路径数目
0/1表示其中是否出现过前缀和为0的点。
```
#include<bits/stdc++.h>
#define ll long long
using namespace std;
inline int read()
{
int x=0;
char c=getchar();
bool flag=0;
while(c<'0'||c>'9') {if(c=='-')flag=1;c=getchar();}
while(c>='0'&&c<='9'){x=(x+(x<<2)<<1)+ c-'0';c=getchar();}
return flag?-x:x;
}
const int N=1e5+5;
int n,sum,rt,mxdep;
int vis[N],t[N<<1],mx[N],size[N],dep[N],dis[N];
ll ans,g[N<<1][2],f[N<<1][2];
int head[N],cnt;
struct node{int to,next,w;}e[N<<1];
inline void add(int u,int v,int w)
{
e[++cnt].to=v,e[cnt].next=head[u],e[cnt].w=w,head[u]=cnt;
e[++cnt].to=u,e[cnt].next=head[v],e[cnt].w=w,head[v]=cnt;
}
//mx[x] x的最大子树
#define v e[i].to
inline void getroot(int x,int fa)
{
size[x]=1,mx[x]=0;
for(int i=head[x];i;i=e[i].next)
if(v!=fa&&!vis[v])
{
getroot(v,x);
size[x]+=size[v];
mx[x]=max(mx[x],size[v]);
}
mx[x]=max(mx[x],sum-size[x]);
if(mx[x]<mx[rt]) rt=x;
}
//找重心
inline void dfs(int x,int fa)
{
mxdep=max(mxdep,dep[x]);
(t[dis[x]])?++f[dis[x]][1]:++f[dis[x]][0];
++t[dis[x]];
for(int i=head[x];i;i=e[i].next) if(v!=fa&&!vis[v])
{
dep[v]=dep[x]+1;
dis[v]=dis[x]+e[i].w;
dfs(v,x);
}
--t[dis[x]];
}
inline void solve(int x)
{
int mx=0;vis[x]=g[n][0]=1;
for(int i=head[x];i;i=e[i].next) if(!vis[v])
{
dis[v]=n+e[i].w;dep[v]=1;mxdep=1;
dfs(v,0);mx=max(mx,mxdep);
ans+=(g[n][0]-1)*f[n][0];
for(int j=-mxdep;j<=mxdep;++j)
ans+=g[n-j][1]*f[n+j][1]+g[n-j][0]*f[n+j][1]+g[n-j][1]*f[n+j][0];
for(int j=n-mxdep;j<=n+mxdep;++j)
{
g[j][0]+=f[j][0];
g[j][1]+=f[j][1];
f[j][0]=f[j][1]=0;
}
}
for(int i=n-mx;i<=n+mx;++i) g[i][0]=g[i][1]=0;
for(int i=head[x];i;i=e[i].next)
if(!vis[v])
{
sum=size[v];rt=0;getroot(v,0);solve(rt);
}
}
#undef v
int main()
{
n=read();
for(int i=1;i<n;++i)
{
int u=read(),v=read(),w=read();
if(!w) w=-1;
add(u,v,w);
}
sum=mx[0]=n;
getroot(1,0);
solve(rt);
printf("%d\n",ans);
}
```