P4751题解

· · 题解

前置芝士:P4719

众所周知,出题人为了卡离线算法和 O(n\log^2n) 的树剖出了这题。

看到现在的题解里面没有一个是树剖的 O(2^3\times n\log^2n) 做法,那么我就来写一篇题解教大家如何卡常吧。

首先我把我的 P4719 代码贴上来。

#include<cstdio>
namespace something_useful
{
    inline int max(int a,int b){ return a>b?a:b;}
    inline int read()
    {
        int s=0,w=1;char ch;
        while((ch=getchar())>'9'||ch<'0') if(ch=='-') w=-1;
        while(ch>='0'&&ch<='9') s=s*10+ch-'0',ch=getchar();
        return s*w;
    }
    struct line
    {
        int nx;
        int to;
    };
    struct graph
    {
        int head[100001];
        line a[200001];
        int tot;
        void ad(int u,int v)
        {
            tot++;
            a[tot].to=v;
            a[tot].nx=head[u];
            head[u]=tot;
        }
        void add(int u,int v){ ad(u,v),ad(v,u);}
        int st(int num){ return head[num];}
        int to(int num){ return a[num].to;}
        int nx(int num){ return a[num].nx;}
    };
    struct mat
    {
        int a[3][3];
        int m,n;
        mat (){ a[1][1]=a[1][2]=a[2][1]=a[2][2]=0;}
        mat (int mm,int nn){ m=mm,n=nn,mat();}
        mat operator * (mat b)
        {
            mat ans;
            for(int i=1;i<=2;i++)
                for(int k=1;k<=2;k++)
                {
                    int s=a[i][k];
                    for(int j=1;j<=2;j++)
                        ans.a[i][j]=max(ans.a[i][j],s+b.a[k][j]);
                }
            return ans;
        }
        void operator = (mat b)
        {
            for(int i=1;i<=2;i++)
                for(int j=1;j<=2;j++)
                    a[i][j]=b.a[i][j];
        }
    };
}
using namespace something_useful;
int f[100001][2];
int hson[100001];
int siz[100001];
int dfn[100001];
int top[100001];
int end[100001];
int rk[100001];
int fa[100001];
mat t[400001];
mat g[100001];
int a[100001];
graph in;
int tot;
int n,m;
void read_all()
{
    n=read(),m=read();
    for(int i=1;i<=n;i++) a[i]=read();
    for(int i=1;i<n;i++) in.add(read(),read());
}
void dfs(int num)
{
    siz[num]=1;f[num][1]=a[num];
    for(int i=in.st(num);i;i=in.nx(i))
    {
        int p=in.to(i);
        if(p==fa[num]) continue;
        fa[p]=num;dfs(p);
        siz[num]+=siz[p];
        if(siz[p]>siz[hson[num]]) hson[num]=p;
        f[num][1]+=f[p][0],f[num][0]+=max(f[p][1],f[p][0]);
    }
}
void dfs(int num,int t)
{
    top[num]=t;
    dfn[num]=++tot;
    rk[tot]=num;
    end[t]=tot;
    if(hson[num]) dfs(hson[num],t);
    g[num].a[2][2]=-0x3f3f3f3f;g[num].a[2][1]=a[num];
    for(int i=in.st(num);i;i=in.nx(i))
    {
        int p=in.to(i);
        if(p==fa[num]||p==hson[num]) continue;
        dfs(p,p);
        g[num].a[1][1]+=max(f[p][0],f[p][1]);
        g[num].a[2][1]+=f[p][0];
    }
    g[num].a[1][2]=g[num].a[1][1];
}
void push_up(int p){ t[p]=t[p*2]*t[p*2|1];}
void build(int p,int pl,int pr)
{
    if(pl==pr)
    {
        t[p]=g[rk[pl]];
        return ;
    }
    int mid=(pl+pr)>>1;
    build(p*2,pl,mid);
    build(p*2|1,mid+1,pr);
    push_up(p);
}
void init()
{
    dfs(1);dfs(1,1);
    build(1,1,n);
}
mat get(int p,int pl,int pr,int l,int r)
{
    if(l<=pl&&pr<=r) return t[p];
    int mid=(pl+pr)>>1;
    if(mid<l) return get(p*2|1,mid+1,pr,l,r);
    if(r<=mid) return get(p*2,pl,mid,l,r);
    return get(p*2,pl,mid,l,r)*get(p*2|1,mid+1,pr,l,r);
}
void change(int p,int pl,int pr,int x)
{
    if(pl==pr){ t[p]=g[rk[x]];return ;}
    int mid=(pl+pr)>>1;
    if(x<=mid) change(p*2,pl,mid,x);
    else change(p*2|1,mid+1,pr,x);
    push_up(p);
}
void update(int pos,int val)
{
    g[pos].a[2][1]+=val-a[pos];
    a[pos]=val;
    while(pos)
    {
        mat last=get(1,1,n,dfn[top[pos]],end[top[pos]]);
        change(1,1,n,dfn[pos]);
        mat now=get(1,1,n,dfn[top[pos]],end[top[pos]]);
        pos=fa[top[pos]];
        g[pos].a[1][1]+=max(now.a[1][1],now.a[2][1])-max(last.a[1][1],last.a[2][1]);
        g[pos].a[1][2]=g[pos].a[1][1];
        g[pos].a[2][1]+=now.a[1][1]-last.a[1][1];
    }
}
void run()
{
    int pos=read(),val=read();
    update(pos,val);
    mat ans=get(1,1,n,1,end[1]);
    printf("%d\n",max(ans.a[1][1],ans.a[2][1]));
}
int main()
{
    read_all();
    init();
    for(int i=1;i<=m;i++) run();
    return 0;
}

前面的那个 namespace 是我的链式前向星模板和快读,大家可以不用管它。

下面我们来看哪里需要更改。

第一步:数组大小

原题的数据范围是 n\le 10^5 ,现在被改成了 n\le 10^6 ,对应的,我们的数组大小也要开大十倍。

注意:不要忘了线段树的大小要开到 4\times 10^6

第二步:强制在线

注意到这一题是强制在线的,所以我们在读入节点编号的时候要异或上一次的答案。

注意:

  1. 只有节点编号要异或,修改的值是不用的;

  2. 别忘了给 lastans 赋初值以及每次算完答案要赋值。

至此,我们的修改部分完成了。

提交之后,你有可能会发现 TLE 了(我挂了一个点)或者是 MLE。

当然,到这里大家可以选择第十个点暴力直接 A 题,但是本着学习新技巧而不是 A 题的态度,我们开始卡常。

因为比较简单,所以先讲一下 MLE 的解决方案。

第一步:矩阵数组

本题的矩阵是 2\times2 的,所以我们只需要把数组大小改成类似于 a[2][2] 就可以了,这样子与原来的 a[3][3] 对比,矩阵的空间会小到原来的一半。

第二步:存图

不知道有多少小伙伴用的是 vector,如果还是 MLE,可以考虑改成链式前向星,节省空间。

注意:如果你算过了你的空间确实不会炸,而且基本没有点 AC,那么大概率是写挂了,建议看一眼有没有无限递归。

先来我们来讲 TLE 解决方案。

第一步:读入

使用快读可以减少时间复杂度,大家应该都会,这里不细说。

第二步:矩阵乘法

本题的矩阵是 2\times2 的,所以我们可以不使用循环,直接将矩阵乘法展开,代码放到下面:

mat operator * (mat b)
{
    mat ans;
    ans.a[0][0]=max(a[0][0]+b.a[0][0],a[0][1]+b.a[1][0]);
    ans.a[1][0]=max(a[1][0]+b.a[0][0],a[1][1]+b.a[1][0]);
    ans.a[0][1]=max(a[0][0]+b.a[0][1],a[0][1]+b.a[1][1]);
    ans.a[1][1]=max(a[1][0]+b.a[0][1],a[1][1]+b.a[1][1]);
    return ans;
}

第三步:线段树

这是卡常卡的最多的一部分。

使用树剖的时候,时间复杂度的瓶颈在于对某个点权值修改的部分,现在我们对这部分进行卡常。

可以发现,我们在查询线段树的时候每次只会查找一个重链的矩阵之积,考虑对每一个重链单独开一个线段树,以减少常数。

有人可能会问:你这不就是一个水到极致,基本没用的常数优化吗,给你一条链不就相当于啥都没有。

真的是这样吗?我们来看一眼变化。

原来的代码:

void update(int pos,int val)
{
    g[pos].a[2][1]+=val-a[pos];
    a[pos]=val;
    while(pos)
    {
        mat last=get(1,1,n,dfn[top[pos]],end[top[pos]]);
        change(1,1,n,dfn[pos]);
        mat now=get(1,1,n,dfn[top[pos]],end[top[pos]]);
        pos=fa[top[pos]];
        g[pos].a[1][1]+=max(now.a[1][1],now.a[2][1])-max(last.a[1][1],last.a[2][1]);
        g[pos].a[1][2]=g[pos].a[1][1];
        g[pos].a[2][1]+=now.a[1][1]-last.a[1][1];
    }
}

更改过后的代码:

void update(int pos,int val)
{
    g[pos].a[1][0]+=val-a[pos];
    a[pos]=val;
    while(pos)
    {
        mat last=t[root[top[pos]]];
        change(root[top[pos]],dfn[top[pos]],end[top[pos]],dfn[pos]);
        mat now=t[root[top[pos]]];
        pos=fa[top[pos]];
        g[pos].a[0][0]+=max(now.a[0][0],now.a[1][0])-max(last.a[0][0],last.a[1][0]);
        g[pos].a[0][1]=g[pos].a[0][0];
        g[pos].a[1][0]+=now.a[0][0]-last.a[0][0];
    }
}

注意一下:你可能会发现修改的时候矩阵的位置不一样,这是因为我原来 MLE 过,然后把矩阵从 3\times 3 改成 2\times 2 了。

有什么变化呢?

没错!循环中少了两次函数的调用!

至此,我们的树剖已经可以通过这道题。

最后,给大家放上完整代码:

#include<cstdio>
namespace something_useful
{
    inline int max(int a,int b){ return a>b?a:b;}
    inline int read()
    {
        int s=0,w=1;char ch;
        while((ch=getchar())>'9'||ch<'0') if(ch=='-') w=-1;
        while(ch>='0'&&ch<='9') s=s*10+(ch^48),ch=getchar();
        return s*w;
    }
    struct line
    {
        int nx;
        int to;
    };
    struct graph
    {
        int head[1000001];
        line a[2000000];
        int tot;
        void ad(int u,int v)
        {
            tot++;
            a[tot].to=v;
            a[tot].nx=head[u];
            head[u]=tot;
        }
        void add(int u,int v){ ad(u,v),ad(v,u);}
        int st(int num){ return head[num];}
        int to(int num){ return a[num].to;}
        int nx(int num){ return a[num].nx;}
    };
    struct mat
    {
        int a[2][2];
        mat operator * (mat b)
        {
            mat ans;
            ans.a[0][0]=max(a[0][0]+b.a[0][0],a[0][1]+b.a[1][0]);
            ans.a[1][0]=max(a[1][0]+b.a[0][0],a[1][1]+b.a[1][0]);
            ans.a[0][1]=max(a[0][0]+b.a[0][1],a[0][1]+b.a[1][1]);
            ans.a[1][1]=max(a[1][0]+b.a[0][1],a[1][1]+b.a[1][1]);
            return ans;
        }
    };
}
using namespace something_useful;
int f[1000001][2];
int hson[1000001];
int root[1000001];
int siz[1000001];
int dfn[1000001];
int top[1000001];
int end[1000001];
int rk[1000001];
int fa[1000001];
int ls[4000001];
int rs[4000001];
mat t[4000001];
mat g[1000001];
int a[1000001];
graph in;
int last;
int tot;
int n,m;
int nod;
void read_all()
{
    n=read(),m=read();
    for(int i=1;i<=n;i++) a[i]=read();
    for(int i=1;i<n;i++) in.add(read(),read());
}
void dfs(int num)
{
    siz[num]=1;f[num][1]=a[num];
    for(int i=in.st(num);i;i=in.nx(i))
    {
        int p=in.to(i);
        if(p==fa[num]) continue;
        fa[p]=num;dfs(p);
        siz[num]+=siz[p];
        if(siz[p]>siz[hson[num]]) hson[num]=p;
        f[num][1]+=f[p][0],f[num][0]+=max(f[p][1],f[p][0]);
    }
}
void dfs(int num,int t)
{
    top[num]=t;
    dfn[num]=++tot;
    rk[tot]=num;
    end[t]=tot;
    if(hson[num]) dfs(hson[num],t);
    g[num].a[1][1]=-0x3f3f3f3f;g[num].a[1][0]=a[num];
    for(int i=in.st(num);i;i=in.nx(i))
    {
        int p=in.to(i);
        if(p==fa[num]||p==hson[num]) continue;
        dfs(p,p);
        g[num].a[0][0]+=max(f[p][0],f[p][1]);
        g[num].a[1][0]+=f[p][0];
    }
    g[num].a[0][1]=g[num].a[0][0];
}
void push_up(int p){ t[p]=t[ls[p]]*t[rs[p]];}
void build(int &p,int pl,int pr)
{
    p=++nod;
    if(pl==pr)
    {
        t[p]=g[rk[pl]];
        return ;
    }
    int mid=(pl+pr)>>1;
    build(ls[p],pl,mid);
    build(rs[p],mid+1,pr);
    push_up(p);
}
void init()
{
    dfs(1);dfs(1,1);
    for(int i=1;i<=n;i++) if(top[i]==i) build(root[i],dfn[i],end[i]);
}
mat get(int p,int pl,int pr,int l,int r)
{
    if(l<=pl&&pr<=r) return t[p];
    int mid=(pl+pr)>>1;
    if(mid<l) return get(rs[p],mid+1,pr,l,r);
    if(r<=mid) return get(ls[p],pl,mid,l,r);
    return get(ls[p],pl,mid,l,r)*get(rs[p],mid+1,pr,l,r);
}
void change(int p,int pl,int pr,int x)
{
    if(pl==pr){ t[p]=g[rk[x]];return ;}
    int mid=(pl+pr)>>1;
    if(x<=mid) change(ls[p],pl,mid,x);
    else change(rs[p],mid+1,pr,x);
    push_up(p);
}
void update(int pos,int val)
{
    g[pos].a[1][0]+=val-a[pos];
    a[pos]=val;
    while(pos)
    {
        mat last=t[root[top[pos]]];
        change(root[top[pos]],dfn[top[pos]],end[top[pos]],dfn[pos]);
        mat now=t[root[top[pos]]];
        pos=fa[top[pos]];
        g[pos].a[0][0]+=max(now.a[0][0],now.a[1][0])-max(last.a[0][0],last.a[1][0]);
        g[pos].a[0][1]=g[pos].a[0][0];
        g[pos].a[1][0]+=now.a[0][0]-last.a[0][0];
    }
}
void run()
{
    int pos=read()^last,val=read();
    update(pos,val);
    mat ans=t[root[1]];
    printf("%d\n",last=max(ans.a[0][0],ans.a[1][0]));
}
int main()
{
    read_all();
    init();
    for(int i=1;i<=m;i++) run();
    return 0;
}

感谢观看!