题解:CF1824C LuoTianyi and XOR-Tree

· · 题解

一个只用到线段树合并的(小常数?)单 \log 做法

先做一遍根到叶节点的异或和,接着一个简单的想法:设 dp[u][i] 表示对 u 子树进行操作,使得 u 子树内叶节点点权都为 i 的最小操作次数。

那么可以发现一个重要的性质:对一个相同的 udp[u][i] 要么为 \min\{dp[u][i]\} 要么为 \min\{dp[u][i]\}+1,因为可以对 u 进行操作。

于是可以对每个 u 只记录 dp 最小值,记为 w_u,和所有取到最小值的 i,组成的集合记为 s_u,注意到所有能取到最小值的 i 一定为某个叶节点的点权,因此 i 的个数是可以保证的。

然后考虑合并 u 的所有儿子 v,可以发现

dp[u][i]=\sum_v (w_v+1)-\sum_{v}[i\in s_v]

需要计算 w_u,就需要求出 dp[u][i] 的最小值,前一个 \sum 对所有 i 都是相同的,而后一个 \sum 需要最大化,因此将所有 s_v 并在一起组成一个可重集,取出现次数最多的那些值,就可以组成 s_u,而 w_u 也可以计算出来。

现在考虑优化这个过程,使用线段树合并,每个节点 u 维护一个包含所有 s_u 中元素的权值线段树,每个节点记录出现次数,在线段树合并时进行对位累加,这样再维护一个出现次数 \max 就可以得到出现次数的最大值。

接下来需要删除出现次数小于最大值的部分,注意到一个元素只会被加入一次删除一次,所以只要我们能够精准地找出所有要删掉的元素,暴力遍历再删除的复杂度是没有问题的,于是再维护一个出现次数 \min 来判断区间内是否存在需要删除的元素,在 u 的所有儿子 v 都合并完之后,再做一次 dfs,删掉所有出现次数小于最大值的部分。最后打一个“出现次数重置为 1”的 tag 即可。

时空复杂度 O(n\log W)W 为值域大小。

void pushup(int now){
    Min[now]=min(Min[ls[now]],Min[rs[now]]);
    Max[now]=max(Max[ls[now]],Max[rs[now]]);
}
void pushdown(int now){
    if (tag[now]){
        if (ls[now]) tag[ls[now]]=Min[ls[now]]=Max[ls[now]]=1;
        if (rs[now]) tag[rs[now]]=Min[rs[now]]=Max[rs[now]]=1;
        tag[now]=0;
    }
}
int Merge(int now, int las, int l, int r){
    if (!now || !las) return now|las;
    if (l==r) return Max[now]+=Max[las],Min[now]+=Min[las],now;
    int mid=(l+r)>>1; pushdown(now); pushdown(las);
    ls[now]=Merge(ls[now],ls[las],l,mid);
    rs[now]=Merge(rs[now],rs[las],mid+1,r);
    pushup(now);
    return now;
}
void update(int &now, int l, int r, int x){
    if (!now) now=++tsiz;
    if (l==r) return Min[now]=Max[now]=1,void();
    int mid=(l+r)>>1;
    mid>=x?update(ls[now],l,mid,x):update(rs[now],mid+1,r,x);
    pushup(now);
}
void del(int &now, int l, int r){
    if (!now) return;
    if (Min[now]==mx) return;
    if (l==r){
        if (Min[now]<mx) Min[now]=INF,Max[now]=ls[now]=rs[now]=0,now=0;
        return;
    }
    int mid=(l+r)>>1; pushdown(now);
    del(ls[now],l,mid); del(rs[now],mid+1,r);
    if (!ls[now] && !rs[now]) now=0;
    else pushup(now);
}
void dfs(int u, int f){
    a[u]^=a[f]; int c=0;
    for (int v:G[u])
        if (v!=f){
            dfs(v,u); c++;
            val[u]+=val[v]+1;
            rt[u]=Merge(rt[u],rt[v],0,L);
        }
    if (!c) return update(rt[u],0,L,a[u]),void();
    mx=Max[rt[u]]; val[u]-=mx; del(rt[u],0,L);
    tag[rt[u]]=Min[rt[u]]=Max[rt[u]]=1;
}
int Find(int now, int l, int r){
    if (!now) return 0;
    if (l==r) return 1;
    int mid=(l+r)>>1;
    return Find(ls[now],l,mid);
}
int main(){
    scanf("%d",&n);
    memset(Min,0x3f,sizeof(Min)); INF=Min[0];
    for (int i=1; i<=n; i++) scanf("%d",&a[i]);
    for (int i=1; i<n; i++){
        scanf("%d%d",&u,&v);
        G[u].push_back(v),G[v].push_back(u);
    }
    dfs(1,0);
    printf("%d\n",val[1]+(Find(rt[1],0,L)?0:1));
}