题解:P11803 【MX-X9-T7】『GROI-R3』此花绽放之时

· · 题解

P11803 【MX-X9-T7】『GROI-R3』此花绽放之时

连通块整体加的操作看上去很不可做,不妨对颜色连通块设置一个代表元,钦定这个连通块中深度最小的点是代表元,那么连通块整体加可以沿用类似线段树懒标记的思想,在代表元处打标记 (c,w),表示如果节点颜色为 c,则点权 +v,可以对每种颜色维护一个动态开点线段树来打标记。

对于连通块整体加,需要找到 x 所在连通块的代表元并打标记。对于单点询问,需要将 1\rightarrow x 的标记正确下传,我们记这个下传操作为 \mathrm{spread}(x),如何高效维护这些操作?套路地考虑树剖,树剖的核心思想是正确维护重儿子。在代表元打标记时,将代表元所在重链的向下部分打上标记。\mathrm{spread}(x) 时,考虑 x 向上的重链,不难发现可以借助重链传递标记,并且每个点只需要接受、下传和其自身颜色相同的标记,具体的,在链头处下传和链头颜色相同的标记,这个可以在重链上找到极长的颜色连续段,对每个重链维护 \texttt{ODT} 即可快速找到并打上标记,然后在链尾处通过轻边向下一个重链头传递标记即可。这样做的正确性在于打标记时是对极长连续段进行标记,因此重链间的标记可以通过首尾传递正确维护。但是由于我们需要保留标记以作用于其他轻儿子,所以标记可能被重复下传,考虑对每个位置的每个颜色开一个桶 undo_{u,c} 表示点 u 从父亲处接受的颜色为 c 的标记权值和,\mathrm{spread} 操作时从父亲继承的标记应该减去这部分算重的权值,同时正确维护 undo

对于 u\leftrightarrow v 路径推平,如果直接覆盖,那么路径上方的标记就无法正确下传,因此需要将 1\leftrightarrow u,1\leftrightarrow v 上的标记下传,这就是 \mathrm{spread}(u),\mathrm{spread}(v) 操作,然后可以在 \texttt{ODT} 上推平,不过注意 \mathrm{spread} 操作已经将原先颜色 c 的重链传递完毕,覆盖 u\leftrightarrow vc 之后这些标记就无效了,因此需要用 undo 清空掉。

这样就做完了,需要动态开点线段树、\texttt{ODT} 和树状数组,时间复杂度是 O(n\log^2 n)

#include <bits/stdc++.h>

using namespace std;

typedef long long ll;
typedef pair<int,int> PII;
#define fir first
#define sec second
const int N=200010;
int n,Q;
int Col[N];
int fa[N];
int h[N],e[N],ne[N],idx;
void add(int u,int v) { e[idx]=v,ne[idx]=h[u],h[u]=idx++; }

int sz[N],son[N],dep[N];
void dfs1(int u)
{
    sz[u]=1; dep[u]=dep[fa[u]]+1;
    for (int i=h[u];i!=-1;i=ne[i]) 
    {
        int v=e[i]; dfs1(v);
        sz[u]+=sz[v];
        if (sz[v]>sz[son[u]]) son[u]=v;
    }
}
int dfn[N],ti,id[N],top[N];
void dfs2(int u,int tp)
{
    id[dfn[u]=++ti]=u; top[u]=tp;
    if (!son[u]) return;
    dfs2(son[u],tp);
    for (int i=h[u];i!=-1;i=ne[i])
    {
        int v=e[i];
        if (v!=son[u]) dfs2(v,v);
    }
}

namespace ODT
{
    struct Node
    {
        int l,r; 
        mutable int v;
        Node(int _l,int _r,int _v) { l=_l,r=_r,v=_v; }
        bool operator < (const Node &o) const { return l<o.l; }
    };
    struct odt
    {
        set<Node> st;
        typedef set<Node>::iterator itr;
        itr split(int x)
        {
            if (x>n) return st.end();
            itr it=--st.upper_bound({x,0,0});
            int l=it->l,r=it->r,v=it->v;
            if (l==x) return it;
            st.erase(it); st.insert({l,x-1,v});
            return st.insert({x,r,v}).first;
        }
        void assign(int l,int r,int v)
        {
            itr rit=split(r+1),lit=split(l); st.erase(lit,rit);
            itr p=st.insert({l,r,v}).first; lit=rit=p;
            while (lit->v==p->v) lit--; lit++;
            while (rit->v==p->v) rit++; rit--;
            l=lit->l,r=rit->r; st.erase(lit,++rit); st.insert({l,r,v});
        }
        PII bel(int x) 
        {
            itr it=--st.upper_bound({x,0,0});
            return {it->l,it->r};
        }
        int ask(int x) 
        {
            itr it=--st.upper_bound({x,0,0});
            return it->v;
        } 
    } col[N];

} using namespace ODT;

namespace sgt
{
    const int T=3e7;
    struct Node
    {
        int lc,rc;
        ll dat;
    } tr[T];
    int rt[N],idx;
    void upd(int &u,int lq,int rq,ll v,int l=1,int r=n)
    {
        if (!u) u=++idx;
        if (lq<=l && r<=rq) { tr[u].dat+=v; return; }
        int mid=(l+r)>>1;
        if (lq<=mid) upd(tr[u].lc,lq,rq,v,l,mid);
        if (rq>mid) upd(tr[u].rc,lq,rq,v,mid+1,r);
    }
    ll ask(int u,int x,int l=1,int r=n)
    {
        if (!u) return 0;
        if (l==r) return tr[u].dat;
        int mid=(l+r)>>1; 
        if (x<=mid) return ask(tr[u].lc,x,l,mid)+tr[u].dat;
        else return ask(tr[u].rc,x,mid+1,r)+tr[u].dat;
    }
} using namespace sgt;

namespace BIT
{
    ll tr[N];
    inline int lowbit(int x) { return x & (-x); }
    inline void upd(int x,ll v) { for (;x<=n;x+=lowbit(x)) tr[x]+=v; }
    inline void modify(int l,int r,ll v) { upd(l,v), upd(r+1,-v); }
    inline ll ask(int x) { ll res=0; for (;x;x-=lowbit(x)) res+=tr[x]; return res; }
}

map<int,ll> undo[N];

void psd(int u)
{
    int c=col[top[u]].ask(dfn[u]);
    auto [l,r]=col[top[u]].bel(dfn[u]);
    ll v=ask(rt[c],dfn[fa[u]])-undo[u][c]; undo[u][c]+=v;
    upd(rt[c],l,r,v); BIT::modify(l,r,v);
}

int buc[N],len;
void spread(int u)
{
    len=0;
    while (top[u]) { buc[++len]=top[u]; u=fa[top[u]]; }
    for (int i=len;i;i--) if (buc[i]!=1) psd(buc[i]);
}

int main()
{
    ios::sync_with_stdio(false);
    cin.tie(0),cout.tie(0);

    cin >> n >> Q;
    for (int i=1;i<=n;i++) cin >> Col[i];
    memset(h,-1,sizeof(h));
    for (int i=2;i<=n;i++) 
    {
        cin >> fa[i];
        add(fa[i],i);
    }
    dfs1(1); dfs2(1,1);
    for (int i=1;i<=n;i++) col[i].st.insert({0,n+1,0});
    for (int i=1;i<=n;i++) col[top[i]].assign(dfn[i],dfn[i],Col[i]);

    while (Q--)
    {
        int op; cin >> op;
        if (op==1)
        {
            int u,v,c;
            cin >> u >> v >> c;
            spread(u), spread(v);
            while (top[u]!=top[v])
            {
                if (dep[top[u]]<dep[top[v]]) swap(u,v);
                col[top[u]].assign(dfn[top[u]],dfn[u],c);
                undo[top[u]][c]=ask(rt[c],dfn[fa[top[u]]]); //清空
                u=fa[top[u]];
            }
            if (dep[u]<dep[v]) swap(u,v);
            col[top[u]].assign(dfn[v],dfn[u],c);
            if (v==top[u]) undo[v][c]=ask(rt[c],dfn[fa[v]]);
        }
        else if (op==2)
        {
            int u,w; 
            cin >> u >> w;
            int tc=col[top[u]].ask(dfn[u]);
            int up=u;
            while (u)
            {
                int c=col[top[u]].ask(dfn[u]); if (c!=tc) break;
                auto [l,r]=col[top[u]].bel(dfn[u]); 
                if (l!=dfn[top[u]]) 
                {
                    up=id[l];
                    break;
                }
                else
                {
                    up=top[u];
                    u=fa[top[u]];
                }
            }
            auto [l,r]=col[top[up]].bel(dfn[up]);
            upd(rt[tc],l,r,w), BIT::modify(l,r,w);
        }
        else
        {
            int u; cin >> u;
            spread(u);
            cout << BIT::ask(dfn[u]) << "\n";
        }
    }

    return 0;
}