题解 P3085 【[USACO13OPEN]阴和阳Yin and Yang】

· · 题解

闲话

偶然翻到一道没有题解的淀粉质,想证明一下自己是真的弱

然而ZSYC(字符串组合)早就切了

然后证明成功了,WA到怀疑人生,只好借着ZSY的代码拍,拍了几万组就出来了。。。

思路

是人都能想到的:路径统计,点分治跑不了了。

然而这个统计有些麻烦。。。

首先别看错题,是中间的一个点到两个端点的两条路径都要满足黑白相等。(因为蒟蒻就看错了)

显然,我们每次要统计经过重心的路径,但是这个中点不一定会在重心。于是,必须要更一般化地统计了。

容易想到的是差分。记d_xx到重心路径上黑白个数的差,可以改一下边权变成1-1,更好实现。

这时候,我们已经可以做到统计黑白个数相等的路径了,一对差分值为d-d的路径就可以产生贡献。

但是,题目要两条子路径都黑白相等啊!再想想,假设有两个点x,y,d_x=-d_y,那么是不是只有在x的祖先中存在zd_z=-d_y,或者在y的祖先中存在zd_x=-d_z,路径x-y才能贡献答案?

这也就是说,我们选择的中点zd值,需要和d_xd_y相等才行。

于是,我们根据每个点是否有祖先的d等于它的d,分成两类。显然,没有的点只能和有的点算贡献,而有的点可以和两种点都算贡献。开两个桶区别统计一下就好了。

具体实现的话,蒟蒻直接依次枚举重心的子树,先统计一个子树的答案,再把子树的每个点丢进桶里。这样就不用容斥去除不合法情况了。但是要特判一个端点在重心上的路径的贡献。

如何快速判断当前点的属于哪一类?开一个数组b统计祖先中每个d值出现的次数就可以了(d可能是负数,要把所有下标加n)。注意回溯的时候要减掉。

注意答案要开long long,注意清空数组。

代码算是很短的了。

#include<cstdio>
#include<cstring>
#define RG register
#define R RG int
#define G c=getchar()
const int N=1e5+1,M=N<<1;
int n,rt,mx,mn,he[N],ne[M],to[M],c[M],s[N],m[N],b[M],f[M],g[M];
bool vis[N];
long long ans;
inline void min(R&x,R y){if(x>y)x=y;}
inline void max(R&x,R y){if(x<y)x=y;}
inline int in(){
    RG char G;
    while(c<'-')G;
    R x=c&15;G;
    while(c>'-')x=x*10+(c&15),G;
    return x;
}
void dfs(R x){//求重心
    vis[x]=1;s[x]=1;m[x]=0;
    for(R y,i=he[x];i;i=ne[i])
        if(!vis[y=to[i]])
            dfs(y),s[x]+=s[y],max(m[x],s[y]);
    max(m[x],n-s[x]);
    if(m[rt]>m[x])rt=x;
    vis[x]=0;
}
void upd(R x,R d){//统计答案,f没有g有
    min(mn,d);max(mx,d);
    ans+=g[M-d];//分类贡献
    if(b[d])ans+=f[M-d];
    if(d==N)ans+=b[d]>1;//特判
    vis[x]=1;++b[d];
    for(R i=he[x];i;i=ne[i])
        if(!vis[to[i]])upd(to[i],d+c[i]);
    vis[x]=0;--b[d];//回溯清零
}
void mdf(R x,R d){//修改桶
    ++(b[d]?g[d]:f[d]);
    vis[x]=1;++b[d];
    for(R i=he[x];i;i=ne[i])
        if(!vis[to[i]])mdf(to[i],d+c[i]);
    vis[x]=0;--b[d];
}
void div(R x){//分治
    rt=0;dfs(x);x=rt;
    vis[x]=1;b[mn=mx=N]=1;//此处注意初始化
    R t=n,y,i;
    for(i=he[x];i;i=ne[i])
        if(!vis[y=to[i]])
            upd(y,N+c[i]),mdf(y,N+c[i]);
    memset(f+mn,0,(mx-mn+1)<<2);//注意清空
    memset(g+mn,0,(mx-mn+1)<<2);
    for(i=he[x];i;i=ne[i])
        if(!vis[y=to[i]])
            n=s[x]>s[y]?s[y]:t-s[x],div(y);
}
int main(){
    m[0]=1e9;n=in();
    for(R a,b,p=0,i=1;i<n;++i){
        a=in();b=in();
        ne[++p]=he[a];to[he[a]=p]=b;
        ne[++p]=he[b];to[he[b]=p]=a;
        c[p]=c[p-1]=in()?1:-1;//处理边权
    }
    div(1);
    printf("%lld\n",ans);
    return 0;
}