洛谷 P5669 [SDOI2018] 原题识别-改 题解--zhengjun

· · 题解

题面

鉴于这题目前还没题解,提供一种时间 \Theta(n\sqrt{m}),空间 \Theta(n+m) 的做法。

询问 1

可以直接上树分块或者树上莫队,见 P6177 Count on a tree II/【模板】树分块。

但是因为本题询问 2 的做法,所以我采用了树上莫队的做法。

询问 2

方便起见:

这里直接考虑 u,v 不为祖先关系的情况(u,v 为祖先关系的情况显然严格弱于这个,特判一下即可)。

所以答案即为:

\sum\limits_{i\in \operatorname{path}(1,u)}\sum\limits_{j\in \operatorname{path}(1,v)}f(i,j)

因为我们发现,答案的形式非常像对于一个区间的所有子区间求和,那么我们引入新的函数:

F(u,v)=\sum\limits_{i,j\in \operatorname{path}(u,v)}f(i,j)

首先观察这个 F,将 \operatorname{path}(u,v) 理解为一个区间 [1,m]

它的实际意义就是 [1,m] 的所有子区间的本质不同的颜色数之和。

但是这样并不好计算,我们考虑另外一种实际意义:对于每种颜色,计算 [1,m] 的所有子区间中包含该颜色的个数和。

如果把颜色 c 删去,序列剩下来长度为 l_1,l_2,\cdots ,l_{k_c}k_c 段连续区间,那么该颜色的贡献就是 \binom{n+1}{2}-\sum \binom{l_i+1}{2}

那么,如果在序列的后面加入一个元素,那么答案的增量就是 \sum\limits_{c}n'-l'_{c,k'_c}

所以,我们如果我们维护出了 suf_c=l_{c,k_c} 以及它的和,那么我们就可以 \Theta(1) 向右边扩展了。

注意,我们同时可以 \Theta(1) 删除最后一个位置。

使用链表维护相同颜色的位置,并实时记录每个颜色的起始位置,维护出 \sum suf_c\sum pre_c,这样左右端点都能够 \Theta(1) 左右移动了。

现在,我们就可以使用树上莫队来计算 F(u,v) 了。

接下来考虑怎么计算答案。

设询问的两个节点分别为 u,v

tu,v 的最近公共祖先 (t\ne u , t\ne v)

p,qt 的两个不同的儿子且 p\in \operatorname{path}(u,t),q\in \operatorname{path}(v,t)

考虑对答案进行转化,这里直接给出结果:

\begin{aligned} ans & = \sum\limits_{i\in \operatorname{path}(1,u)}\sum\limits_{j\in \operatorname{path}(1,v)}f(i,j)\\ & = F(1,u)+F(1,v)-|\operatorname{path}(1,t)|-F(u,p)-F(v,q)+F(u,v)-F(u,t)-F(v,t)+1 \end{aligned}

其中后面一大坨的尾巴是 \operatorname{path}(u,p) \operatorname{path}(v,q) 之间的贡献,即:

\sum\limits_{i\in \operatorname{path}(u,p)}\sum\limits_{j\in \operatorname{path}(v,q)}f(i,j)=F(u,v)-F(u,t)-F(v,t)+1

最后加一是因为 f(t,t) 被两边都减了一遍,类似于容斥。

而剩下的贡献就是 F(1,u)+F(1,v)-|\operatorname{path}(1,t)|-F(u,p)-F(v,q)

做到这里似乎已经做完了……

细节处理:

本人直接写完后不卡常最大点用时 4.05s,经过调整块长、对莫队的排序进行奇偶优化过后,最大点用时 1.96s,效率还行,毕竟询问 2 有个 5 倍常数。

代码

#include<bits/stdc++.h>
using namespace std;
using ll=long long;
#ifdef DEBUG
template<class T>
ostream& operator << (ostream &out,vector<T> a){
    out<<'[';
    for(T x:a)out<<x<<',';
    return out<<']';
}
template<class T>
vector<T> ary(T *a,int l,int r){
    return vector<T>{a+l,a+1+r};
}
template<class T>
void debug(T x){
    cerr<<x<<endl;
}
template<class T,class...S>
void debug(T x,S...y){
    cerr<<x<<' ',debug(y...);
}
#else
#define debug(...) void()
#endif
const int N=1e5+10,V=N*2,M=2e5+10;
int n,q,a[N];
int dft,B,id[V],pos[V],dfn[N];
vector<int>to[N];
struct ques{
    int l,r,id,w;
    bool operator < (const ques &a)const{
        return ::id[l]^::id[a.l]?::id[l]<::id[a.l]:(::id[l]&1?r<a.r:r>a.r);
    }
}o1[M],o2[M*5];
int m1,m2;
void make(int u,int fa=0){
    pos[dfn[u]=++dft]=u;
    for(int v:to[u])if(v^fa){
        make(v,u);
        pos[++dft]=u;
    }
}
namespace Path{
    int top[N],fa[N],dep[N],siz[N],son[N];
    void dfs1(int u){
        siz[u]=1,dep[u]=dep[fa[u]]+1;
        for(int v:to[u])if(v^fa[u]){
            fa[v]=u,dfs1(v);
            siz[u]+=siz[v];
            if(siz[v]>siz[son[u]])son[u]=v;
        }
    }
    int dft,dfn[N],pos[N];
    void dfs2(int u,int t){
        top[u]=t,pos[dfn[u]=++dft]=u;
        if(son[u])dfs2(son[u],t);
        for(int v:to[u])if(v^fa[u]&&v^son[u])dfs2(v,v);
    }
    void init(){
        dfs1(1),dfs2(1,1);
    }
    int LCA(int u,int v){
        for(;top[u]^top[v];u=fa[top[u]]){
            if(dep[top[u]]<dep[top[v]])swap(u,v);
        }
        return dep[u]<dep[v]?u:v;
    }
    int jump(int u,int k){
        for(;k>dep[u]-dep[top[u]];u=fa[top[u]])k-=dep[u]-dep[top[u]]+1;
        return pos[dfn[u]-k];
    }
}
using Path::dep;
namespace DS1{
    int now,cnt[N];
    void insert(int x){
        now+=!cnt[x]++;
    }
    void erase(int x){
        now-=!--cnt[x];
    }
    int query(){
        return now;
    }
}
namespace DS2{
    struct Queue{
        int a[N];
        const int& operator [] (const int &x)const{
            return a[(x%N+N)%N];
        }
        int& operator [] (const int &x){
            return a[(x%N+N)%N];
        }
    }col,pre,nex;
    int s,t;
    int now,cnt[N],bg[N],ed[N];
    ll s1,s2,ans;
    void init(){
        s=1e9,t=s-1,s1=s2=now=ans=0;
        memset(bg,0,sizeof bg);
        memset(ed,0,sizeof ed);
        memset(cnt,0,sizeof cnt);
    }
    void push_back(int x){
        // debug("push_back",x);
        col[++t]=x,now+=!cnt[x]++;
        s1+=n-now;
        s2+=n-(ed[x]?t-ed[x]:t-s+1);
        ans+=(t-s+1ll)*n-s2;
        pre[t]=ed[x],nex[t]=0;
        if(ed[x])nex[ed[x]]=t;
        ed[x]=t;
        if(!bg[x])bg[x]=t;
    }
    void pop_back(){
        // debug("pop_back");
        int x=col[t];
        ed[x]=pre[t];
        if(!ed[x])bg[x]=0;
        else nex[ed[x]]=0;
        ans-=(t-s+1ll)*n-s2;
        s2-=n-(ed[x]?t-ed[x]:t-s+1);
        s1-=n-now;
        now-=!--cnt[col[t--]];
    }
    void push_front(int x){
        // debug("push_front",x);
        col[--s]=x,now+=!cnt[x]++;
        s1+=n-(bg[x]?bg[x]-s:t-s+1);
        s2+=n-now;
        ans+=(t-s+1ll)*n-s1;
        nex[s]=bg[x],pre[s]=0;
        if(bg[x])pre[bg[x]]=s;
        bg[x]=s;
        if(!ed[x])ed[x]=s;
    }
    void pop_front(){
        // debug("pop_front");
        int x=col[s];
        bg[x]=nex[s];
        if(!bg[x])ed[x]=0;
        else pre[bg[x]]=0;
        ans-=(t-s+1ll)*n-s1;
        s2-=n-now;
        s1-=n-(bg[x]?bg[x]-s:t-s+1);
        now-=!--cnt[col[s++]];
    }
    ll query(){
        return ans;
    }
}
ll f[N],ans[M];
void dfs(int u,int fa=0){
    DS2::push_back(a[u]);
    f[u]=DS2::query();
    for(int v:to[u])if(v^fa){
        dfs(v,u);
    }
    DS2::pop_back();
}
int vis[N];
void solve1(){
    for(int i=1;i<=m1;i++){
        if(o1[i].l>o1[i].r)swap(o1[i].l,o1[i].r);
    }
    B=max(1.0,dft/max(1.0,sqrt(m1))*3);
    for(int i=1;i<=dft;i++)id[i]=(i-1)/B+1;
    sort(o1+1,o1+1+m1);
    auto go=[&](int u,int v){
        if(!vis[v])DS1::insert(a[v]),vis[v]=1;
        else DS1::erase(a[u]),vis[u]=0;
    };
    int l=1,r=0;
    for(int i=1;i<=m1;i++){
        for(;r<o1[i].r;r++)go(pos[r],pos[r+1]);
        for(;r>o1[i].r;r--)go(pos[r],pos[r-1]);
        for(;l<o1[i].l;l++)go(pos[l],pos[l+1]);
        for(;l>o1[i].l;l--)go(pos[l],pos[l-1]);
        ans[o1[i].id]+=DS1::query()*o1[i].w;
    }
}
void solve2(){
    memset(vis,0,sizeof vis);
    for(int i=1;i<=m2;i++){
        if(o2[i].l>o2[i].r)swap(o2[i].l,o2[i].r);
    }
    B=max(1.0,dft/max(1.0,sqrt(m2))*3);
    for(int i=1;i<=dft;i++)id[i]=(i-1)/B+1;
    sort(o2+1,o2+1+m2);
    auto go_t=[&](int u,int v){
        if(!vis[v])DS2::push_back(a[v]),vis[v]=1;
        else DS2::pop_back(),vis[u]=0;
    };
    auto go_s=[&](int u,int v){
        if(!vis[v])DS2::push_front(a[v]),vis[v]=1;
        else DS2::pop_front(),vis[u]=0;
    };
    int l=1,r=0;
    DS2::init();
    for(int i=1;i<=m2;i++){
        for(;r<o2[i].r;r++)go_t(pos[r],pos[r+1]);
        for(;r>o2[i].r;r--)go_t(pos[r],pos[r-1]);
        for(;l<o2[i].l;l++)go_s(pos[l],pos[l+1]);
        for(;l>o2[i].l;l--)go_s(pos[l],pos[l-1]);
        ans[o2[i].id]+=DS2::query()*o2[i].w;
    }
}
int main(){
    freopen(".in","r",stdin);
    // freopen(".out","w",stdout);
    scanf("%d%d",&n,&q);
    for(int i=1;i<=n;i++)scanf("%d",&a[i]);
    for(int i=1,u,v;i<n;i++){
        scanf("%d%d",&u,&v);
        to[u].push_back(v),to[v].push_back(u);
    }
    make(1),Path::init(),DS2::init(),dfs(1);
    debug("f",ary(f,1,n));
    for(int i=1,op,u,v;i<=q;i++){
        scanf("%d%d%d",&op,&u,&v);
        if(op==1){
            o1[++m1]={dfn[u],dfn[v],i,1};
        }else{
            int t=Path::LCA(u,v);
            ans[i]=f[u]+f[v]-dep[t];
            if(u^t)o2[++m2]={dfn[u],dfn[Path::jump(u,dep[u]-dep[t]-1)],i,-1};
            if(v^t)o2[++m2]={dfn[v],dfn[Path::jump(v,dep[v]-dep[t]-1)],i,-1};
            if(u^t&&v^t){
                ans[i]++;
                o2[++m2]={dfn[u],dfn[v],i,1};
                o2[++m2]={dfn[u],dfn[t],i,-1};
                o2[++m2]={dfn[v],dfn[t],i,-1};
            }
        }
    }
    solve1(),solve2();
    for(int i=1;i<=q;i++)printf("%lld\n",ans[i]);
    return 0;
}