「DROI」Round 1 距离 题解

· · 题解

不妨思考极远点对的本质是什么。发现若 (u,v) 为极远点对,可以看作从任意点 u 找到距离其最远点 v,再从点 v 找到距离其最远点且此点恰好为 u,不难察觉这符合dfs两遍求树的直径的过程,故 u,v 为同一直径两端点。

所以题目意转化为对于每个点,求出其在多少条直径上。

将树以 1 为根定型,然后求解答案。

不妨先树形 DP 求出总直径数 sum,以 x 为端点直径数的子树和 l_x,挂在 x 上的直径数 p_xp_x 的子树和 s_x。其中除了 l_x 的求解略显繁琐外,其他的求解都是朴素的(注意细节,一定要想清楚再写,数组的含义千万不能弄混)。

考虑经过 x 的直径数,不难得到:v_x =p_x + (l_x - s_x \times 2)。前者是挂在 x 上的直径,后者是简单容斥得到的横穿 x 子树内外的直径。同样,这样的计算是直观的。

至此整道题解决完毕,时间复杂度线性。

下面阐述一下部分分缘由:

下附代码:

#include<bits/stdc++.h>
using namespace std;
#define LL long long
const int N=5e6+10,mod=998244353;
int n,k,len,ans,s[N],l[N];
int h[N],e[N<<1],ne[N<<1],idx;
struct node
{
    int val,cnt;
}p[N],fi[N],se[N];
inline int read()
{
    int x=0;char ch=getchar();
    while(ch<'0'||ch>'9') ch=getchar();
    while(ch>='0'&&ch<='9') x=(x<<3)+(x<<1)+ch-'0',ch=getchar();
    return x;
}
inline void add(int a,int b)
{       
    e[idx]=b,ne[idx]=h[a],h[a]=idx++;
}
inline void dfs(int u,int fa)
{
    fi[u].val=fi[u].cnt=1;
    for(int i=h[u];~i;i=ne[i])
    {
        if(e[i]==fa)    continue;
        dfs(e[i],u);
        if(fi[u].val+fi[e[i]].val>p[u].val) p[u]={fi[u].val+fi[e[i]].val,(int)((LL)fi[u].cnt*fi[e[i]].cnt%mod)};    
        else if(fi[u].val+fi[e[i]].val==p[u].val) p[u].cnt=(p[u].cnt+(LL)fi[u].cnt*fi[e[i]].cnt%mod)%mod;
        if(fi[e[i]].val+1>fi[u].val) se[u]=fi[u],fi[u]={fi[e[i]].val+1,fi[e[i]].cnt};
        else if(fi[e[i]].val+1==fi[u].val) fi[u].cnt+=fi[e[i]].cnt;
        else if(fi[e[i]].val+1>se[u].val) se[u]={fi[e[i]].val+1,fi[e[i]].cnt};
        else if(fi[e[i]].val+1==se[u].val) se[u].cnt+=fi[e[i]].cnt;
    } 
    len=max(len,p[u].val);
}
inline void DFS(int u,int fa,int up,int num)
{
    s[u]=p[u].cnt=(p[u].val==len?p[u].cnt:0);
    for(int i=h[u];~i;i=ne[i])
    {
        if(e[i]==fa)    continue;
        int cur=max(up+1,(fi[e[i]].val+1==fi[u].val&&fi[e[i]].cnt==fi[u].cnt)?se[u].val:fi[u].val);
        if(fi[e[i]].val+1==fi[u].val&&fi[e[i]].cnt==fi[u].cnt) DFS(e[i],u,cur,num*((up+1)==cur)+se[u].cnt*((se[u].val)==cur));
        else if(fi[e[i]].val+1==fi[u].val) DFS(e[i],u,cur,num*((up+1)==cur)+(fi[u].cnt-fi[e[i]].cnt)*((fi[u].val)==cur));
        else DFS(e[i],u,cur,num*((up+1)==cur)+fi[u].cnt*((fi[u].val)==cur));
        l[u]=(l[u]+l[e[i]])%mod;s[u]=(s[u]+s[e[i]])%mod;
    }
    if(up+1==len) l[u]=(l[u]+num)%mod;
    if(fi[u].val==len) l[u]=(l[u]+fi[u].cnt)%mod;
}
inline int pwr(int x,int y){return y==1?x:(LL)x*x%mod;}
int main()
{
    memset(h,-1,sizeof h);
    n=read(),k=read();
    for(int i=1;i<n;i++)    
    {
        int u,v;
        u=read(),v=read();
        add(u,v),add(v,u);
    }
    dfs(1,-1);DFS(1,-1,0,0);
    for(int i=1;i<=n;i++) ans=(ans+pwr(p[i].cnt+(l[i]-s[i]*2),k))%mod;
    printf("%d",ans);
    return 0;
}

似乎验题人有更简单的做法。。。