题解 CF791D 【Bear and Tree Jumps】

2019-02-14 17:45:18


题目大意就是给你一颗树,然后问所有的链的链长除以k(上取整)的总和是多少。

第一反应,所有的链布拉布拉的,首先就是点分治吧。。

于是我们就用点分治来做,但是注意,CF上是会RE爆栈的,但是其实本地测是能过的。

但是注意,普通用点分治统计路径长度的话是会T的,因为统计的时候那个平方太慢了。于是,我们可以根据余数来统计。具体见代码。

#include<cstdio>
#include<cstring>
#include<iostream>
#include<algorithm>
using namespace std;
#define int long long
int cas,n,k,cnt=0,rt,sum,res=0;
int head[150010],nxt[300100],to[300100],ans=0;
int f[300100],siz[300100],dep[300100],tmp[10],vis[300100],num[10];
int read()
{
    int u=0;char c=getchar();
    while(c<'0'||c>'9') c=getchar();
    while(c>='0'&&c<='9') u=(u<<1)+(u<<3)+c-'0',c=getchar();
    return u;
}
void addedge(int x,int y)
{
    cnt++;
    nxt[cnt]=head[x];
    head[x]=cnt;
    to[cnt]=y;
}
void getroot(int u,int fa)
{
    siz[u]=1;f[u]=0;
    for(int i=head[u];i!=-1;i=nxt[i])
    {
        int v=to[i];
        if(v==fa || vis[v]) continue;
        getroot(v,u);
        siz[u]+=siz[v];
        f[u]=max(f[u],siz[v]);
    }
    f[u]=max(f[u],sum-siz[u]);
    if(f[u]<f[rt]) rt=u;
}
void getdep(int u,int fa)
{
    num[dep[u]%k]++;
    tmp[dep[u]%k]+=dep[u]/k;//先按余数存链的数量和长度
    for(int i=head[u];i!=-1;i=nxt[i])
    {
        int v=to[i];
        if(v==fa || vis[v]) continue;
        dep[v]=dep[u]+1,getdep(v,u);
    }
}
void Add(int u,int s,int p)
{
    dep[u]=s;cnt=0;
    memset(num,0,sizeof(num));
    memset(tmp,0,sizeof(tmp));
    getdep(u,0);
    for(int i=0;i<k;i++)
    {
        if(i) res+=(i*2>k?2:1)*num[i]*(num[i]-1)/2*p;
        res+=(num[i]-1)*tmp[i]*p;//这里先处理余数相同的链之间的答案,
        for(int j=i+1;j<k;j++)
            res+=((i+j>k?2:1)*num[i]*num[j]+tmp[i]*num[j]+tmp[j]*num[i])*p;//处理余数不同的
    }
}
void solve(int u)
{
    Add(u,0,1);vis[u]=1;
    for(int i=head[u];i!=-1;i=nxt[i])
    {
        int v=to[i];
        if(vis[v]) continue;
        Add(v,1,-1);
        sum=siz[v];rt=0;
        getroot(v,u);
        solve(rt);
    }
}
signed main()
{
//  freopen("tree.in","r",stdin);
//  freopen("tree.out","w",stdout);
    memset(head,-1,sizeof(head));
    n=read(),k=read();
    for(int i=1;i<n;i++)
    {
        int x=read(),y=read();
        addedge(x,y);
        addedge(y,x);
    }
    sum=f[0]=n;getroot(1,0);
    solve(rt);
    printf("%I64d",res);
    return 0;
}