P6992 [NEERC2014] Hidden Maze 题解

· · 题解

经典题。

我们称边上的原权为边的值,后面赋的权为边的权。考虑弱化一下,如何求中位数恰为 k 的路径数,显然可以把 \le k 的边设为 -1\ge k 的设为 1,如果一条路径边权和为 1 且恰经过某条值为 k 的边,那显然恰好中位数为 k

路径和这种东西可以点分,但不好扩展和处理恰经过的条件,考虑 dp,设 f_{u,i}u 子树内所有点到 u 边权和为 i 的点数,转移 O(\sum dep_i) 是简单的。对于计算答案,只需要枚举所有值为 k 的边 (u,v,k),dep_u<dep_v,枚举 u 祖先 UUu 方向的儿子 V,撤销 V 子树对 dp_U 的贡献,然后枚举 U 子树内 V 子树外的边权和计算一下即可。实现只需要开一个栈维护到根链然后倒着扫。就是经典撤销子树内贡献的方法。

怎么扩展呢,如果我们按边的值顺次加入边,那么每次只会影响祖先的 dp 值,先倒着撤销,再暴跳修改即可,本质上是动态 dp。

时间复杂度考虑一个点被重构的次数和单次复杂度,为 O(\sum sz_i(maxdep_i-dep_i)),另一篇题解说这个在树按照本题的方式随机生成的情况下复杂度为 O(n\sqrt n)

#include <cstdio>
#include <cstring>
#include <algorithm>
#include <vector>
#define P 250
#define ll long long
using namespace std;

int n,m;
int stk[P+5];
int ww[30005];
int fa[30005];
int mde[30005];
int dep[30005];
int f[30005][P*2+5];
vector <int> g[30005];
struct node{int u,v,w;}e[30005];
ll all;

inline void in(int &n){
    n=0;
    char c=getchar();
    while(c<'0' || c>'9') c=getchar();
    while(c>='0'&&c<='9') n=n*10+c-'0',c=getchar();
    return ;
}

inline pair <int,int> dfs(int u,int fath){
    fa[u]=fath;
    dep[u]=dep[fa[u]]+1;
    mde[u]=0;
    f[u][P]=1;
    int s1=0,s2=0;
    if(dep[u]&1) s1++;
    else s2++;
    for(int v:g[u]){
        if(v==fath) continue;
        ww[v]=-1;
        auto tmp=dfs(v,u);
        mde[u]=max(mde[u],mde[v]+1);
        all+=s1*tmp.second+s2*tmp.first;
        s1+=tmp.first,s2+=tmp.second;
        for(int i=-mde[u];i<=mde[u];i++) f[u][i+P]+=f[v][i+1+P];
    }
    return {s1,s2};
}

signed main(){
    in(n);
    for(int i=1;i<n;i++){
        in(e[i].u),in(e[i].v),in(e[i].w);
        g[e[i].u].emplace_back(e[i].v);
        g[e[i].v].emplace_back(e[i].u);
    }
    dfs(1,0);
    sort(e+1,e+n,[](node p,node q){return p.w<q.w;});
    ll ans=0,S=0;
    for(int i=1;i<n;i++){
        int u=e[i].u,v=e[i].v;
        if(dep[u]<dep[v]) swap(u,v);
        ans=0;
        int top=0,uu=fa[u],s=1;
        stk[0]=u;
        while(uu) stk[++top]=uu,s+=ww[uu],uu=fa[uu];
        for(int d=top;d>=1;d--){
            int U=stk[d],V=stk[d-1];
            for(int j=-mde[U];j<=mde[U];j++) f[U][j+P]-=f[V][j-ww[V]+P];
            for(int j=-mde[U];j<=mde[U];j++){
                int x=1-j-s;
                if(x<-P||x>P) continue;
                ans+=f[U][j+P]*f[u][x+P];
            }
            s-=ww[V];
        }
        S+=e[i].w*ans;
        ww[u]=1;
        int U=fa[u],V=u;
        while(U){
            for(int j=-mde[U];j<=mde[U];j++) f[U][j+P]+=f[V][j-ww[V]+P];
            U=fa[U],V=fa[V];
        }
    }
    printf("%.9lf\n",1.0*S/all);

    return 0;
}