题解:P10795 『SpOI - R1』Lamborghini (Demo)

· · 题解

树上笛卡尔树。

我们显然需要一个结构来表示“树上路径上点权最小的点”。在链上,可以选用笛卡尔树;在树上同理。具体的建树方式是:按点权从大往小枚举每个点 i,在它的所有相邻的点中,对于点权大于 i 的点 j,我们连边 (i,j)。这样连边后最终会得到一棵树,且若以点权最小的点为根,则新树上任意两点的 lca 为原树上两点间路径上权值最小的点

则原问题转化为:对于每个点 x,求有多少个点对 (i,j) 满足 lca_{i,j}=xv_i\leq v_x\leq v_j,对 v 离散化后用权值线段树维护各个值的点的数量,线段树合并即可。复杂度 O(n\log n)

#include<bits/stdc++.h>
#define rep(i,j,k) for(int i=j;i<=k;i++)
#define repp(i,j,k) for(int i=j;i>=k;i--)
#define ls(x) lson[x]
#define rs(x) rson[x]
#define mp make_pair
#define fir first
#define sec second
#define pii pair<int,int>
#define lowbit(x) x&-x
#define int long long
#define qingbai 666
using namespace std;
typedef long long ll;
const int N=1e5+5,M=2e5+5,S=(1<<17)+2,mo=998244353,inf=1e18+7;
const double eps=1e-8;
void read(int &p){
    int x=0,w=1;
    char ch=0;
    while(!isdigit(ch)){
        if(ch=='-')w=-1;
        ch=getchar();
    }
    while(isdigit(ch)){
        x=(x<<1)+(x<<3)+ch-'0';
        ch=getchar();
    }
    p=x*w;
}
int n,lsh[N],cntl;
struct point{
    int v,a,id;
    friend bool operator<(point x,point y){
        return x.a>y.a;
    }
}p[N];
int ans;
struct tree{
    struct edge{
        int to,nxt;
    }e[N];
    int fir[N],np;
    void add(int x,int y){
        e[++np]=(edge){y,fir[x]};
        fir[x]=np;
    }
    struct seg{
        int t[40*N],lson[40*N],rson[40*N],rt[N],cnt;
        void pushup(int x){
            t[x]=t[ls(x)]+t[rs(x)];
        }
        void add(int &x,int le,int ri,int p){
            if(!x)x=++cnt;
            if(le==ri){
                t[x]++;
                return;
            }
            int mid=(le+ri)>>1;
            if(p<=mid)add(ls(x),le,mid,p);
            else add(rs(x),mid+1,ri,p);
            pushup(x); 
        }
        int query(int x,int le,int ri,int ql,int qr){
            if(ql>qr)return 0;
            if(!x)return 0;
            if(ql<=le&&qr>=ri)return t[x];
            int mid=(le+ri)>>1,ret=0;
            if(ql<=mid)ret+=query(ls(x),le,mid,ql,qr);
            if(qr>mid)ret+=query(rs(x),mid+1,ri,ql,qr);
            return ret;
        }
        int merge(int p,int q,int le,int ri){
            if(!p)return q;
            if(!q)return p;
            if(le==ri){
                t[p]+=t[q];
                return p;
            }
            int mid=(le+ri)>>1;
            ls(p)=merge(ls(p),ls(q),le,mid),rs(p)=merge(rs(p),rs(q),mid+1,ri); 
            pushup(p);
            return p;
        }
    }T;
    void dfs(int x){
        int res=0;
        T.add(T.rt[x],1,cntl,p[x].v);
        for(int i=fir[x];i;i=e[i].nxt){
            int j=e[i].to;
            dfs(j);
            int nw1=T.query(T.rt[j],1,cntl,1,p[x].v-1);
            int nw2=T.query(T.rt[j],1,cntl,p[x].v,p[x].v);
            int nw3=T.query(T.rt[j],1,cntl,p[x].v+1,cntl);
            res+=nw1*T.query(T.rt[x],1,cntl,p[x].v,cntl);
            res+=nw2*(T.query(T.rt[x],1,cntl,1,cntl)+T.query(T.rt[x],1,cntl,p[x].v,p[x].v));
            res+=nw3*T.query(T.rt[x],1,cntl,1,p[x].v);
            T.rt[x]=T.merge(T.rt[x],T.rt[j],1,cntl);
        }
        res++;
        ans+=res*p[x].id;
    }
}T;
int fa[N];
int find(int x){
    if(fa[x]==x)return x;
    return fa[x]=find(fa[x]);
}
struct edge{
    int to,nxt;
}e[2*N];
int fir[N],np,nid[N];
void add(int x,int y){
    e[++np]=(edge){y,fir[x]};
    fir[x]=np;
}
void merge(int x,int y){
    int fx=find(x),fy=find(y);
    if(fx==fy)return;
    fa[fy]=fx,T.add(fx,fy);
}
map <int,int> ok;
signed main(){
    read(n);
    rep(i,1,n)
        read(p[i].a),ok[p[i].a]=1,fa[i]=i,p[i].id=i;
    assert(ok.size()==n);
    rep(i,1,n)
        read(p[i].v),lsh[++cntl]=p[i].v;
    sort(lsh+1,lsh+cntl+1),cntl=unique(lsh+1,lsh+cntl+1)-lsh-1;
    rep(i,1,n)
        p[i].v=lower_bound(lsh+1,lsh+cntl+1,p[i].v)-lsh;
    sort(p+1,p+n+1);
    rep(i,1,n)
        nid[p[i].id]=i;
    rep(i,1,n-1){
        int x,y;
        read(x),read(y),add(nid[x],nid[y]),add(nid[y],nid[x]);
    }
    rep(i,1,n){
        for(int p=fir[i];p;p=e[p].nxt){
            int j=e[p].to;
            if(i>j)merge(i,j);//i is the father of j.
        }
    }
    T.dfs(n);
    printf("%lld\n",ans);
    return 0; 
}