题解:P10834 [COTS 2023] 题 Zadatak

· · 题解

线段树合并忘记怎么写了,所以场上写的是启发式合并。

首先我们要手玩小数据,找到图形拼一起有什么性质。容易发现:

那我们需要在启发式合并的同时,维护一个数据结构,可以快速维护一个集合 S 的如下操作:

可以用动态开点权值线段树维护,每个节点维护三个关键量:

时间复杂度 \Theta(n\log^2n)

注意这个做法有点卡空间,如果直接写只能拿 60 分。在将 y 合并到 x 的时候,要释放 y 的内存。以下给出对 unordered_mapvector 释放内存的代码:

mp.clear();
mp.rehash(0);
vec.clear();
vec.shrink_to_fit();

完整代码如下:

#include <bits/stdc++.h>
using namespace std;
#define ll long long
#define fr first
#define sc second
#define pii pair<int,int>
#define yes cout<<"Yes\n"
#define no cout<<"No\n"
#define fo(i,l,r) for(int i=l;i<=r;i++)
#define ro(i,r,l) for(int i=r;i>=l;i--)
const int N=1e5+5,V=1e6;
int n,q,id[N<<1];
ll a[N];
namespace dsu{
    #define mid ((l+r)>>1)
    struct sgm{
        struct node{
            int lc,rc,s;
            ll s0,s1;
        };
        int c=1;
        vector<node>d;
        sgm(){
            d.push_back({0,0,0,0,0});
            d.push_back({0,0,0,0,0});
        }
        void modify(int x,int l,int r,int t,int k){
            if (l==r){
                d[x].s1+=k*1ll*l*l;
                d[x].s+=k;
                return;
            }
            if (t<=mid){
                if (!d[x].lc){
                    d[x].lc=++c;
                    d.push_back({0,0,0,0,0});
                }
                modify(d[x].lc,l,mid,t,k);
            }
            else{
                if (!d[x].rc){
                    d[x].rc=++c;
                    d.push_back({0,0,0,0,0});
                }
                modify(d[x].rc,mid+1,r,t,k);
            }
            d[x].s=d[d[x].lc].s+d[d[x].rc].s;
            d[x].s0=d[d[x].lc].s0+((d[d[x].lc].s&1)?d[d[x].rc].s1:d[d[x].rc].s0);
            d[x].s1=d[d[x].lc].s1+((d[d[x].lc].s&1)?d[d[x].rc].s0:d[d[x].rc].s1);
        }
    }sg[N];
    unordered_map<int,int>val[N];
    void insert(int x,int k){
        if (!val[x].count(k)){
            sg[x].modify(1,1,V,k,1);
            val[x][k]=1;
        }
        else{
            sg[x].modify(1,1,V,k,-1);
            val[x].erase(k);
        }
    }
    ll merge(int x,int y){
        if (val[x].size()<val[y].size())
            swap(x,y);
        id[++n]=x;
        for (auto i:val[y])
            insert(x,i.fr);
        val[y].clear();
        val[y].rehash(0);
        sg[y].d.clear();
        sg[y].d.shrink_to_fit();
        return abs(sg[x].d[1].s0-sg[x].d[1].s1);
    }
}
using namespace dsu;
signed main(){
    ios::sync_with_stdio(0);
    cin.tie(0),cout.tie(0);
    cin>>n,q=n-1;
    fo(i,1,n){
        cin>>a[i];
        id[i]=i;
        insert(i,a[i]);
    }
    while (q--){
        int x,y;
        cin>>x>>y;
        x=id[x],y=id[y];
        ll rs=0;
        rs+=abs(sg[x].d[1].s0-sg[x].d[1].s1);
        rs+=abs(sg[y].d[1].s0-sg[y].d[1].s1);
        ll rt=merge(x,y);
        rs-=rt,rs/=2ll;
        cout<<rs<<'\n';
    }
    return 0;
}