题解 P5526 【[Ynoi2012]D1T3】

· · 题解

ynoi个个都是毒瘤题……又难想又难写……

可能写的有点乱,所以推荐边看边画图理解。

前置芝士

QTREE6

众所周知白题比黑题难

题解

题意简单粗暴,就是枚举所有路径求不同颜色数然后加起来,带单点修改。只要起点和终点其中一个不同就算不同路径,所有所有路径包括只有一个点的路径,而且有两个以上点的路径要来回算两次 其实认真看样例都看得出来

教练我会O(n^3)的dfs 明显过不了,但可以拿来对拍

然后我们观察一下这个询问操作,我们会发现每种颜色对答案的贡献是独立的,简单说就是dfs时分开每种颜色出现次数,最后全部加起来也是答案。

于是就可以把询问全部拆开,按颜色排序,这样一来实际要处理的就只有两种颜色。或者说问题就转换为了给一棵树,每个点可能是黑色或白色,每次问路径上存在至少一个点是黑色的路径的数量是多少,带修改。(可以随便找组小数据理解一下)

考虑到问存在黑点的路径数有点难搞,我们反过来求所有点都是白点的路径数。这其实就是每个白连通块的大小的平方。

如果改动一个点的颜色,那路径数的变化量就是在这个点为黑色时与这个点相邻的白点所在连通块大小之和加上这些连通块大小两两相乘的和+1。(也推荐画图理解一下)

这个式子简单变化一下就是连通块大小之和的平方减去每个连通块大小的平方再加上连通块大小之和。

于是做法就来了,像qtree6一样用lct维护点所在连通块的信息,然后在轻儿子有变动时把这些平方和啥的处理一下。于是就能得到答案的差分序列,弄个前缀和然后输出就行了。

由于每个询问会拆成把一个点黑改白和把一个点白改黑,再算上初始化的n次白改黑和黑改白,总的操作次数是2n+2m,于是时间复杂度O((n+m)logn),空间O(n+m)。大概就是lct1e6的时间,写的太丑可能会被卡常……

代码

为了防止无脑抄题解,以下代码有一个错误,直接提交会全部WA真的有认真看的人不看代码也能自己写出来吧

#include<iostream>
#include<cstdlib>
#include<cstdio>
#include<algorithm>
using namespace std;
typedef long long ll;
const int N=4e5+5;
inline int read()
{
    int sum=0;
    char c=getchar();
    while(c<'0'||c>'9')c=getchar();
    while(c>='0'&&c<='9')
    {
        sum=sum*10+c-'0';
        c=getchar();
    }
    return sum;
}
inline void write(ll x)
{
    if(x>9)write(x/10);
    putchar(x%10+'0');
}
struct miaow
{
    int ys;
    int siz,lh,rh;
    ll ssum;
    ll xsum,xsum2;
    int fz;
    int son[2],sfa,xfa;
    miaow(int ls=0,int rs=0,int sfa=0,int xfa=0,int ys=0,int siz=1,int lh=1,int rh=1,ll ssum=1,ll xsum=0,ll xsum2=0,int fz=0):
        sfa(sfa),xfa(xfa),ys(ys),siz(siz),lh(lh),rh(rh),ssum(ssum),xsum(xsum),xsum2(xsum2),fz(fz){son[0]=ls,son[1]=rs;}
}t[N];
#define ys(x) t[x].ys
#define siz(x) t[x].siz
#define lh(x) t[x].lh
#define rh(x) t[x].rh
#define ssum(x) t[x].ssum
#define xsum(x) t[x].xsum
#define xsum2(x) t[x].xsum2
#define fz(x) t[x].fz
#define ls(x) t[x].son[0]
#define rs(x) t[x].son[1]
#define ks(x,k) t[x].son[k]
#define sfa(x) t[x].sfa
#define xfa(x) t[x].xfa
#define zsum(x) (ssum(x)+xsum(x))
inline int fx(int x)
{
    if(!sfa(x))return -1;
    return rs(sfa(x))==x;
}
inline void upd(int x)
{
    if(fz(x))return;
    siz(x)=siz(ls(x))+siz(rs(x))+1+xsum(x);
    lh(x)=lh(ls(x))+((lh(ls(x))!=siz(ls(x))||ys(x))?0:(1+xsum(x)+lh(rs(x))));
    rh(x)=rh(rs(x))+((rh(rs(x))!=siz(rs(x))||ys(x))?0:(1+xsum(x)+rh(ls(x))));
    ssum(x)=lh(x)+xsum(x);
}
inline void pdfz(int x)
{
    if(!fz(x))return;
    fz(x)=0;
    swap(ls(x),rs(x));
    if(ls(x))
    {
        swap(lh(ls(x)),rh(ls(x)));
        ssum(ls(x))=lh(ls(x));
        fz(ls(x))^=1;
    }
    if(rs(x))
    {
        swap(lh(rs(x)),rh(rs(x)));
        ssum(rs(x))=lh(rs(x));
        fz(rs(x))^=1;
    }
    upd(x);
}
inline void rotate(int x)
{
    if(!sfa(x))return;
    int y=sfa(x),k=fx(x),z=sfa(y);
    swap(xfa(x),xfa(y));
    ks(y,k)=ks(x,k^1);
    if(ks(x,k^1))sfa(ks(x,k^1))=y;
    if(sfa(y))
    {
        int y2=sfa(y),k2=fx(y);
        ks(y2,k2)=x;
    }
    ks(x,k^1)=y;
    sfa(x)=sfa(y);
    sfa(y)=x;
    upd(y);
    upd(x);
}
inline void pd2(int x)
{
    if(sfa(x))pd2(sfa(x));
    pdfz(x);
}
inline void splay(int x)
{
    pd2(x);
    if(!sfa(x))return;
    int y=sfa(x),k=fx(x);
    while(y)
    {
        int z=sfa(y);
        if(!z)
        {
            rotate(x);
            return;
        }
        if(fx(y)==k)rotate(y);
        else rotate(x);
        rotate(x);
        y=sfa(x),k=fx(x);
    }
}
inline void access(int x)
{
    splay(x);
    if(rs(x))
    {
        sfa(rs(x))=0;
        xfa(rs(x))=x;
        xsum(x)+=ssum(rs(x));
        xsum2(x)+=ssum(rs(x))*ssum(rs(x));
        rs(x)=0;
        upd(x);
    }
    int x2=x;
    while(xfa(x))
    {
        int y=xfa(x);
        splay(y);
        if(rs(y))
        {
            sfa(rs(y))=0;
            xfa(rs(y))=y;
            xsum(y)+=ssum(rs(y));
            xsum2(y)+=ssum(rs(y))*ssum(rs(y));
        }
        xsum(y)-=ssum(x);
        xsum2(y)-=ssum(x)*ssum(x);
        rs(y)=x;
        sfa(x)=y;
        upd(y);
        xfa(x)=0;
        x=y;
    }
}
inline void hg(int x)
{
    access(x);
    splay(x);
    swap(lh(x),rh(x));
    ssum(x)=lh(x);
    fz(x)^=1;
}
inline ll cx(int x,int y)
{
    hg(x);
    pdfz(x);
    if(rs(x))
    {
        upd(rs(x));
        sfa(rs(x))=0;
        xfa(rs(x))=x;
        xsum(x)+=ssum(rs(x));
        xsum2(x)+=ssum(rs(x))*ssum(rs(x));
        rs(x)=0;
        upd(x);
    }
    ys(x)=y;
    upd(x);
    ll sum=(xsum(x)*xsum(x)-xsum2(x)),sum2=xsum(x)*2+1;
    return sum+sum2;
}
inline ll cz(int x,int y)
{
    ll sum=cx(x,y);
    return sum;
}

struct meow
{
    int a,b,c,d;
    inline bool operator < (const meow& qwe) const
    {
        return c==qwe.c?(a==qwe.a?d<qwe.d:a<qwe.a):c<qwe.c;
    }
    meow(int a=0,int b=0,int c=0,int d=0):a(a),b(b),c(c),d(d){}
}q[N*5];
int js=1;

struct nya
{
    int f,t,l;
    nya(int f=0,int t=0,int l=0):f(f),t(t),l(l){}
}imap[N*2+1];
int str[N]={0},cou=1;
inline void jb(int f,int t)
{
    imap[cou]=nya(f,t,str[f]);str[f]=cou++;
    imap[cou]=nya(t,f,str[t]);str[t]=cou++;
}
inline void dfs(int o,int fa)
{
    for(int i=str[o];i;i=imap[i].l)
    {
        int tt=imap[i].t;
        if(tt==fa)continue;
        dfs(tt,o);
        xsum(o)+=ssum(tt);
        xsum2(o)+=ssum(tt)*ssum(tt);
        xfa(tt)=o;
    }
    upd(o);
}
int cs[N]={0},n,m;
ll ans[N]={0};
int main()
{
    n=read(),m=read();
    int zs=m+1;
    for(int i=1;i<=n;++i)
    {
        int x=read();
        cs[i]=x;
        q[js]=meow(0,i,x,1);
        ++js;
    }
    for(int i=1;i<n;++i)
    {
        int x=read(),y=read();
        jb(x,y);
    }
    for(int i=1;i<=m;++i)
    {
        int x=read(),y=read();
        q[js]=meow(i,x,cs[x],0);++js;
        q[js]=meow(i,x,y,1);++js;
        cs[x]=y;
    }
    for(int i=1;i<=n;++i)
    {
        q[js]=meow(zs,i,cs[i],0);++js;
    }
    siz(0)=0,lh(0)=0,rh(0)=0,ssum(0)=0;
    dfs(1,0);
    sort(q+1,q+js);
    int col=0;
    for(int i=1;i<js;++i)
    {
        int a=q[i].a,b=q[i].b,c=q[i].c,d=q[i].d;
        if(c!=col)
        {
            col=c;
        }
        ll sum=cz(b,d);
        if(d==0)ans[a]-=sum;
        else ans[a]+=sum;
    }
    write(ans[0]),putchar('\n');
    for(int i=1;i<=m;++i)
    {
        ans[i]+=ans[i-1];
        write(ans[i]);
        putchar('\n');
    }
}