题解 P5311 [Ynoi2011] 成都七中

· · 题解

算是题解里面的新解法阿。下面有一些废话,可以跳过。

最近几天自闭三连。雪乃输给蕾姆一点点,大老师对折棒也输了。为什么折棒这么猛呢!!虽然确实有点可爱。

然后今天下午去看电影,特效很牛逼我被撅晕了,在晕死的过程中发现这题淀粉质的做法也很牛逼,就是不太对头。怎么用奇怪解法呢!!为什么题解全是淀粉质呢!!!

然后发现可以用超超简单的非常常规解法。

如果设每条边的边权为两端点的最小值,那么一条路径上的边权最小值等价于点权最小值。最大值同理。

那么就把两种边权建重构树,发现就是等价于给你两棵树,每次给你两个子树求其交的颜色数这不就 dfn 拍下来然后回滚莫队跑二维分块吗!!!

直接对一棵树跑 dsu on tree,然后就是支持 O(n\log n) 次插入删除点,求子树中的颜色数。这个很典啊,考虑每种颜色的贡献就是这种颜色的点到根的链并,直接对每种颜色维护 set 按 dfn 排序,插入每个点然后把相邻点的 lca 的贡献抵消掉。其实就是 O(n\log n) 次链加单点求值。直接树状数组 O(n\log^2 n) 搞定。

我很久没写代码了然后写了一小会一发过了,说明代码很好写啊!!然而重构树的部分直接把隔壁 werewolf 的代码贺过来的。空间写的倍增 O(n\log n),其实线性很好搞。

这个应该做不了一般图,因为树路径唯一。

#include<bits/stdc++.h>
#define Yukinoshita namespace
#define Yukino std
using Yukinoshita Yukino;
int read()
{
    int s=0;
    char ch=getchar();
    while(ch<'0'||ch>'9') ch=getchar();
    while(ch>='0'&&ch<='9') s=s*10+ch-'0',ch=getchar();
    return s;
}
const int mxn=2e5+5;
struct edge
{
    int x,y,v;
    friend bool operator <(const edge &x,const edge &y)
    {
        return x.v<y.v;
    }
}a[mxn];
int c[mxn],f[mxn];
int v[mxn],v2[mxn];
int t[mxn][2],t2[mxn][2];
int fa[mxn][19],fa2[mxn][19];
int dfnl[mxn],dfnr[mxn],size[mxn],dep[mxn];
int find(int x)
{
    return f[x]==x?x:f[x]=find(f[x]);
}
int n,dcnt;
vector<pair<int,int> > q[mxn];
int ans[mxn];
struct bit
{
    #define lowbit(x) (x&-x)
    int t[mxn];
    void add(int x,int v)
    {
        if(!x) return;
        for(;x<n*2;x+=lowbit(x))
            t[x]+=v;
    }
    int ask(int x)
    {
        int s=0;
        for(;x;x-=lowbit(x))
            s+=t[x];
        return s;
    }
}T;
void dfs(int d,int f,int t[][2],int fa[][19])
{
    if(!d) return;
    fa[d][0]=f;
    for(int i=1;i<=18;i++)
        fa[d][i]=fa[fa[d][i-1]][i-1];
    dfs(t[d][0],d,t,fa),dfs(t[d][1],d,t,fa);
}
int asklca(int x,int y)
{
    if(!x||!y) return 0;
    if(x==y) return x;
    if(dep[x]>dep[y]) swap(x,y);
    int i;
    for(i=16;i>=0;i--)
        if(dep[fa2[y][i]]>=dep[x])
            y=fa2[y][i];  
    if(x==y) return x;
    for(i=16;i>=0;i--)
        if(fa2[x][i]!=fa2[y][i])
            x=fa2[x][i],y=fa2[y][i];
    return fa2[x][0];
}
void build(int t[][2],int fa[][19],int v[],bool fl)
{
    sort(a+1,a+n);
    if(fl) reverse(a+1,a+n);
    int i,fx,fy,cnt=n;
    dcnt=0;
    for(i=1;i<n*2;i++)
        f[i]=i;
    for(i=1;i<n;i++)
        if((fx=find(a[i].x))!=(fy=find(a[i].y)))
        {
            f[fx]=f[fy]=++cnt;
            v[cnt]=a[i].v;
            t[cnt][0]=fx,t[cnt][1]=fy;
        }
    dfs(cnt,0,t,fa);
}
void dfs1(int d)
{
    if(!d) return;
    size[d]=d<=n;
    dfs1(t[d][0]),dfs1(t[d][1]);
    size[d]+=size[t[d][0]]+size[t[d][1]];
}
void dfs2(int d,int deep)
{
    if(!d) return;
    dfnl[d]=++dcnt,dep[d]=deep;
    dfs2(t2[d][0],deep+1),dfs2(t2[d][1],deep+1);
    dfnr[d]=dcnt;
}
void ask(int id,int x,int L,int R)
{
    int i,y=x;
    for(i=18;i>=0;i--)
        if(v[fa[y][i]]<=R)
            y=fa[y][i]; 
    for(i=18;i>=0;i--)
        if(v2[fa2[x][i]]>=L)
            x=fa2[x][i];
    q[y].push_back({x,id});
}
set<pair<int,int> > st[mxn]; 
void insert(int d)
{
    int pre=(--st[c[d]].lower_bound({dfnl[d],d}))->second,nxt=st[c[d]].lower_bound({dfnl[d],d})->second;
    T.add(dfnl[d],1);
    T.add(dfnl[asklca(d,pre)],-1);
    T.add(dfnl[asklca(d,nxt)],-1);
    T.add(dfnl[asklca(pre,nxt)],1);
    st[c[d]].insert({dfnl[d],d});
}
void erase(int d)
{
    st[c[d]].erase({dfnl[d],d});
    int pre=(--st[c[d]].lower_bound({dfnl[d],d}))->second,nxt=st[c[d]].lower_bound({dfnl[d],d})->second;
    T.add(dfnl[d],-1);
    T.add(dfnl[asklca(d,pre)],1);
    T.add(dfnl[asklca(d,nxt)],1);
    T.add(dfnl[asklca(pre,nxt)],-1);
}
void add(int d,bool fl)
{
    if(!d) return;
    if(d<=n)
        fl?insert(d):erase(d);
    add(t[d][0],fl),add(t[d][1],fl);
}
void solve(int d)
{
    if(!d) return;
    if(size[t[d][0]]>size[t[d][1]]) swap(t[d][0],t[d][1]);
    solve(t[d][0]),add(t[d][0],0);
    solve(t[d][1]),add(t[d][0],1);
    if(d<=n)
        insert(d);
    for(int i=0;i<q[d].size();i++)
        ans[q[d][i].second]=T.ask(dfnr[q[d][i].first])-T.ask(dfnl[q[d][i].first]-1);
}
int main()
{
    int q,i,l,r;
    n=read(),q=read();
    for(i=1;i<=n;i++)
        c[i]=read();
    for(i=1;i<n;i++)
        a[i].x=read(),a[i].y=read(),a[i].v=max(a[i].x,a[i].y);
    v[0]=2e9,v2[0]=-2e9;
    build(t,fa,v,0);
    for(i=1;i<n;i++)
        a[i].v=min(a[i].x,a[i].y);
    build(t2,fa2,v2,1);
    dfs1(n*2-1),dfs2(n*2-1,0);
    for(i=1;i<=n;i++)
        st[i].insert({0,0}),st[i].insert({INT_MAX,0});
    for(i=1;i<=q;i++)
        l=read(),r=read(),ask(i,read(),l,r);
    solve(n*2-1);
    for(i=1;i<=q;i++)
        printf("%d\n",ans[i]);
}