P9399 解题报告

· · 题解

前言

双倍经验:Misha and LCP on Tree,但是要和我同个做法才能过。

这篇题解使用的是树剖做法,时间复杂度为 O(n+m\log n)

该题解思路鸣谢 lzyqwq,感谢他点醒了我成功拿下最优解。

同时这题是我树剖 part 2 中的一道练习题,详情文章可见:『从入门到入土』树链剖分学习笔记。

思路分析

题意:

显然的是可以直接一眼秒:

考虑先第一步转化,查询时加点不方便,因为新加的点不会对前面的查询有影响,所以离线下询问先给点加上就行了。

接着考虑处理查询,如果被限死在序列中的话复杂度就会难免带上一个序列长度 len

所以考虑一种能把序列转化为一个值的东西:Hash!

Hash 之后直接处理也是难绷的复杂度,所以我们考虑二分答案。

这样的话我们只需要考虑如何快速求出 u 向上/下 k 步的 Hash 值就行了。

考虑使用倍增,处理出 2^j 步向上/下的 Hash 值。

接着查询的时候二分答案使用倍增算出 Hash 值比一下是否相同 check 即可。

这样我们就完成了倍增做法,欸那这题和树剖有什么关系啊?

想完了倍增做法后,我果断打开了题解发现没有一个重剖做法,所以我考虑重剖如何做。

最初,受到了出题人题解中说可以用长剖树上 k 级祖先的方法来处理 Hash 值的启发,考虑用重剖实现树上 k 级祖先,来代替掉倍增求 Hash 的过程。

仔细一分析,这个复杂度好像不是很优,是 O(n\log^3n)

接下去神犇 lzyqwq 给了我点启发,他看了一眼后说,不就是 CF226E 吗?

受此启发,我们考虑利用重剖的特性。

比起倍增把路径分为了 2^j 长度的链,重剖的特性就是把路径划分为了一条条重链。

那我们考虑直接把 u 开始往上跳的那条重链拿出来,再把 x 开始往上跳的重链也拿出来。

那我们就能把这两段重链直接相消。

如果不匹配呢?

那直接上大力二分即可。

具体写的时候推荐把每条链的转折点 topf 记录进一个 vector 中。

原因是可能两段重链长度不同,所以只能消掉其中的一段,另一段要从已经被消掉的头开始继续往上匹配

也就是我们用 vector 的话,假设原本没被消光的这个点为 x,上段重链的链顶为 topf,那我们这次消完了后就是去掉这个 topf,再压入这个 x 就做完了。

因为害怕 Hash 被卡,所以在此推荐两个大模数:1004535809,167772161。(来源:mrsrz 的代码。)

这两个模数的强度很高,所以我用单 Hash 轻松通过了此题并且拿下最优解。

时间复杂度:O(n+m\log n) 常数最劣情况下约为 3,均摊一下的话常数还是很小的。

代码

代码有些冗长,见谅。

#include<bits/stdc++.h>
#define int long long
#define mid ((l+r)>>1)
#define pb push_back
#define mp make_pair
#define fi first
#define se second
using namespace std;
const int N=200010,mod=167772161,INF=0x3f3f3f3f3f3f3f3f;
const int p=13331;
struct edge
{int v,nxt;}e[N<<1];
struct node
{int u,v,x,y;}q[N];
int n,m,tot,cnt;
int s[N],op[N],h1[N],h2[N],h3[N];
int top[N],id[N],dfn[N],head[N];
int fa[N],dep[N],si[N],son[N];
static char buf[1000000],*paa=buf,*pd=buf;
#define getchar() paa==pd&&(pd=(paa=buf)+fread(buf,1,1000000,stdin),paa==pd)?EOF:*paa++
int read()
{
    char c=getchar();int x=0,f=1;
    while(!isdigit(c)){if(c=='-') f=-1;c=getchar();}
    while(isdigit(c)) x=(x<<1)+(x<<3)+(c^48),c=getchar();
    return x*f;
}
void print(int x)
{
    if(x<0) putchar('-'),x=-x;
    if(x>9) print(x/10);
    putchar(x%10+'0');
}
void add(int u,int v){e[++tot].v=v,e[tot].nxt=head[u],head[u]=tot;}
void dfs1(int u,int ff)
{
    fa[u]=ff;dep[u]=dep[ff]+1,si[u]=1;
    for(int i=head[u];i;i=e[i].nxt)
    {
        int v=e[i].v;
        if(v==ff) continue;
        dfs1(v,u),si[u]+=si[v];
        if(si[v]>si[son[u]]) son[u]=v;
    }
}
void dfs2(int u,int topf)
{
    top[u]=topf,dfn[u]=++cnt,id[cnt]=u;
    if(son[u]) dfs2(son[u],topf);
    for(int i=head[u];i;i=e[i].nxt)
    {
        int v=e[i].v;
        if(v==fa[u]||v==son[u]) continue;
        dfs2(v,v);
    }
}
vector<pair<int,int> > get(int x,int y)
{
    vector<pair<int,int> >l,r;
    while(top[x]!=top[y])
        if(dep[top[x]]<dep[top[y]]) r.pb(mp(top[y],y)),y=fa[top[y]];
        else l.pb(mp(x,top[x])),x=fa[top[x]];
    int lca=dep[x]<dep[y]?x:y;l.pb(mp(x,lca));
    if(y!=lca) r.pb(mp(son[lca],y));
    while(r.size()) l.pb(r.back()),r.pop_back();
    return l;
}
int h(bool opt,int x,int k)
{
    if(opt) return (h2[x-k+1]-h2[x+1]*op[k]%mod+mod)%mod;
    return (h1[x+k-1]-h1[x-1]*op[k]%mod+mod)%mod;
}
int gt(int l,int r){return (h3[r]-h3[l-1]*op[r-l+1]%mod+mod)%mod;}
signed main()
{
    n=read();m=read();read();for(int i=1;i<=n;i++) s[i]=read();op[0]=1;
    for(int i=1,u,v;i<n;i++) u=read(),v=read(),add(u,v),add(v,u);
    for(int i=1,opt,x,y;i<=m;i++)
    {
        opt=read();
        if(opt==2) x=read(),add(x,++n),add(n,x),s[n]=read();
        else q[i].u=read(),q[i].v=read(),q[i].x=read(),q[i].y=read();
    }
    dfs1(1,0),dfs2(1,1);
    for(int i=1;i<=n;i++) 
        h1[i]=(h1[i-1]*p+s[id[i]])%mod,h2[n+1-i]=(h2[n+2-i]*p+s[id[n+1-i]])%mod,
        op[i]=op[i-1]*p%mod,h3[i]=(h3[i-1]*p+i)%mod;
    for(int i=1,a,b,c,d,ans,s,t;i<=m;i++)
    {
        if(!q[i].u) continue;
        a=q[i].u,b=q[i].v,c=q[i].x,d=q[i].y;s=t=ans=0;
        vector<pair<int,int> > f=get(a,b),g=get(c,d);
        while(s<(int)f.size()&&t<(int)g.size())
        {
            int d11=dfn[f[s].fi],d12=dfn[f[s].se],d21=dfn[g[t].fi],d22=dfn[g[t].se];
            bool opf=d11>d12,opt=d21>d22;
            int lf=(opf?d11-d12:d12-d11)+1,lt=(opt?d21-d22:d22-d21)+1,len=min(lf,lt);
            int hf=h(opf,d11,len),hg=h(opt,d21,len),hsh=gt(ans+1,ans+len);
            if((hf+hsh)%mod==hg)
            {
                if(len==lf) s++;
                else f[s].fi=id[d11+(opf?-1:1)*len];
                if (len==lt) t++;
                else g[t].fi=id[d21+(opt?-1:1)*len];
                ans+=len;
            }
            else
            {
                int l=1,r=len;
                while(l<r)
                {
                    hf=h(opf,d11,mid),hg=h(opt,d21,mid),hsh=gt(ans+1,ans+mid);
                    if((hf+hsh)%mod==hg) l=mid+1;
                    else r=mid;
                }
                ans+=l-1;break;
            }
        }
        print(ans);puts("");
    }
    return 0;
}