题解 P5838 【[USACO19DEC]Milk Visits G】

· · 题解

看了一下题解,发现都是(我不会的)平衡树,主席树和神仙离线线性算法等,还没有分块题解,就交了这一发。

题目的意思是给出一棵树和树上每个点的点权,每次给出树上的两个节点a,b,求两节点之间的简单路径上有没有节点的权值为c。我的思路是用树剖把树上操作转移到序列上来,再用分块维护。

先说分块,分块是把应该序列平均分成\lceil\frac{n}{\lfloor\sqrt{n}\rfloor}\rceil个长度为\lfloor\sqrt{n}\rfloor的块来维护,不过最后一个块会短一点。在查询的时候,边角块直接暴力,一个整块直接二分,单次查询的复杂度为O(\sqrt{n})

再说树链剖分,这是一个通过把树分成几条互不相交的链,从而让分块,线段树这样用来维护序列的数据结构可以用在树上。这里的树剖是轻重链剖分,优先遍历节点中子树最大的儿子,给每个节点标记一下dfs序,维护时根据dfs序当做序列维护就可以了。

代码不长,不到90行,2.91s,时间复杂度O(n\sqrt{n}\log^2n)

#include<bits/stdc++.h>
using namespace std;
int n,m,a[100005],u,v,p[100005],sum[100005],top[100005];
int son[100005],dep[100005],size[100005],f[100005];
int begin[100005],end[100005],s[100005],block,cnt;
vector<int>ve[100005];
void dfs1(int now,int fa,int deep) {//预处理 
    dep[now]=deep;
    f[now]=fa;
    size[now]=1;
    int maxsize=0;
    for(int i=0; i<ve[now].size(); i++) {
        int y=ve[now][i];
        if(y==fa)continue;
        dfs1(y,now,deep+1);
        if(size[y]>maxsize)maxsize=size[y],son[now]=y;
        size[now]+=size[y];
    }
    return;
}
void dfs2(int now,int _top) {
    top[now]=_top;
    p[now]=++cnt;
    sum[cnt]=a[now];
    if(!son[now])return;
    dfs2(son[now],_top);//先访问重儿子 
    for(int i=0; i<ve[now].size(); i++) {
        int y=ve[now][i];
        if(y==f[now]||y==son[now])continue;
        dfs2(y,y);
    }
    return;
}
void build() {
    block=sqrt(n);
    int total=0;
    for(int i=1; i<=n; i++)s[i]=sum[i];
    for(int i=1; i<=n; i+=block) {
        total++;
        begin[total]=i;
        end[total]=min(n,i+block-1);
        sort(s+i,s+min(n,i+block-1)+1);//排序,便于二分 
    }
    return;
}
bool find1(int l,int r,int k) {
    int numl=l/block+(l%block!=0),numr=r/block+(r%block!=0);
    if(numl==numr) {
        for(int i=l; i<=r; i++)if(sum[i]==k)return true;
        return false;
    }
    for(int i=l; i<=end[numl]; i++)if(sum[i]==k)return true;
    for(int i=begin[numr]; i<=r; i++)if(sum[i]==k)return true;
    for(int i=numl+1; i<numr; i++) {
        int z=s[lower_bound(s+begin[i],s+end[i]+1,k)-s];//块内二分 
        if(z==k)return true;
    }
    return false;
}
bool find2(int x,int y,int k) {
    while(top[y]!=top[x]) {
        if(dep[top[x]]<dep[top[y]])swap(x,y);
        if(find1(p[top[x]],p[x],k))return true;
        x=f[top[x]];
    }
    if(dep[x]>dep[y])swap(x,y);
    return find1(p[x],p[y],k);
}
int main() {
    scanf("%d%d",&n,&m);
    for(int i=1; i<=n; i++)scanf("%d",&a[i]);
    for(int i=1; i<n; i++) {
        scanf("%d%d",&u,&v);
        ve[u].push_back(v);
        ve[v].push_back(u);
    }
    dfs1(1,0,1);
    dfs2(1,1);
    build();
    int x,y,z;
    while(m--) {
        scanf("%d%d%d",&x,&y,&z);
        if(find2(x,y,z))printf("1");
        else printf("0");
    }
}