CF1856E2 题解

· · 题解

从题解区看来好像大家都觉得官方题解的思路是 \mathcal O(\dfrac{n\sqrt{n}\log n}{w}) 的,但是它其实没有 \log ,下文会证明复杂度。

思路

假设读者已经知道了 E1 的思路。

题目可以转化成对于每个点将它的子树内所有点黑白染色,使得颜色不同且不在同一个当前节点的儿子的子树内的对数最大,求最大的对数的总和。

首先有一个结论:存在一个最优解使得当前点的每个儿子子树内的节点颜色相同。

证明:一个儿子子树内的节点对答案的贡献只与子树外的点有关,因此这个子树内的点在外面状态不变的情况下染黑或白的贡献是固定的,所以节点颜色一定相同。

此时又可以发现,因为每个儿子子树内的节点颜色都相同,所以对对数的限制,“不在同一个当前节点的儿子的子树内”就没有用了,因此现在可以将问题转化成了:将每个儿子的子树大小划分成两个集合,使得集合元素之和的积最大。

现在可以考虑求出来所有第一个集合的可能元素之和。这是一个经典的背包问题。将这个问题形式化表述:

当前有 n 个物品,每个物品有重量 a_i。设 m=\sum a_i,请对于每个 d \in [0,m] 求出是否可以选出一个集合 S 使得 \sum_{i|S} a_i = d

下面主要参考了 CF 上的这篇博客。

首先可以发现,不同的 a_i 只有 \mathcal O(\sqrt{m}) 种。因为如果 a_i > \sqrt{m},这样的 i 一定 \le \sqrt{m} 个。

因此现在将物品变成了有 \sqrt{m} 个物品,每个物品有重量 a_i 和出现次数 c_i

接下来可以将 c_i 进行二进制优化。将 c_i 拆成 2^0+2^1+\cdots + 2^{k_i} + x_i,其中 x_i < 2^{k_i+1},然后将这 c_i 个物品变成 k_i+2 个物品,重量为原先的重量乘上拆分出来对应的数。

拆分后,用 bitset 优化 dp 即可。使用 vis 表示前 i 个物品能凑出来哪些数,转移到 i+1 时直接让 vis 或上 vis 左移 a_{i+1} 即可。

这样做的复杂度看起来是 \mathcal O(n+\dfrac{m\sqrt{m}\log m }{w}) 的,但是我们可以证明,物品只有 \mathcal O(m\sqrt{m}) 个。

首先 x_i 对应的物品只有 \mathcal O(m\sqrt{m}) 个。接下来考虑对每个 2^j 对应的物品个数。设 c_i 能拆分出 2^ji 集合为 S_j。因为 2^j 是被拆分出来的,所以 \sum_{i \in S_j} 2^j \cdot a_i \le m,\sum_{i \in S_j} a_i \le \dfrac{m}{2^j}。又因为 a_i 互不相同,所以 |S_j| \le \sqrt{\dfrac{2C}{2^j}}

那么 \sum_{j} |S_j| \le \sqrt{2C}\sum_{j}\dfrac{1}{\sqrt{2^j}}。考虑后面的求和,就是 \dfrac{1}{\sqrt 1} + \dfrac{1}{\sqrt 2} + \dfrac{1}{\sqrt 4}+\dfrac{1}{\sqrt 8} + \cdots。注意到 \left(\dfrac{1}{\sqrt 1} + \dfrac{1}{\sqrt 4} + \cdots \right)=\left( 1+\dfrac{1}{2} +\cdots \right)< 2,而 \left(\dfrac{1}{\sqrt 2} + \dfrac{1}{\sqrt 8} + \cdots \right)< \left(\dfrac{1}{\sqrt 1} + \dfrac{1}{\sqrt 4} + \cdots\right),所以 \sum_{j}\dfrac{1}{\sqrt{2^j}}<4\sum_{j} |S_j| \le 4\sqrt {2C},是 \mathcal O(\sqrt{C}) 级别。

因此物品个数是 \mathcal O(\sqrt{C}) 级别,上述算法的复杂度为 \mathcal O(n+\dfrac{m\sqrt m}{w})

回到原问题,我们现在仅仅以 \mathcal O(\sum_v \dfrac{siz_u\sqrt{siz_u}}{w}) 的时间解决了一个节点的问题,如果要求所有节点,这个复杂度还不行。

考虑如果有一个节点的一个儿子的子树大小占到了当前节点的子树大小一半,即 2siz_v \ge siz_u-1,选择方案一定为只选择当前儿子放到第一个集合,剩下的放到第二个集合。其他情况,递归下去的子树大小至少减半,递归树层数是 \mathcal O(\log n)。这样看来,好像还需要加一个 \log,但是实际上还是不用。

考虑计算递归树的第 i 层的复杂度(根节点位于 0 层)。设位于第 i 层的节点集合为 S_i,这一层的运算量为 T(i)=\sum_{j \in S_i} \dfrac{siz_j\sqrt{siz_j}}{w}。因为每递归一层 siz 至少除以二,所以 T(i) \le \sum_{j \in S_i} \dfrac{siz_j\sqrt{\large\frac{n}{2^i}}}{w}=\dfrac{\sqrt n}{\sqrt 2^i w} \sum_{j\in S_i}siz_j \le \dfrac{n\sqrt n}{\sqrt2^i w}

T(i) 求和即为整个算法的运算量:\sum_i T(i) \le \sum_i \dfrac{n\sqrt{n}}{\sqrt 2^i w} = \dfrac{n\sqrt n}{w} \sum_i \dfrac{1}{\sqrt 2^i}。考虑后面的求和,就是 \sum_i \dfrac{1}{\sqrt{2^i}}。上文已经分析过,\sum_i \dfrac{1}{\sqrt{2^i}}<4,所以 \sum_i \dfrac{1}{\sqrt 2^i}<4\sum_i T(i) \le \dfrac{4n\sqrt n}{w},算法复杂度为 \mathcal O(\dfrac{n\sqrt n}{w})

代码

这题的一个细节是,对于每一个问题的 bitset 需要开动态的空间,但是 bitset 只能开静态的空间。有两种解决办法,一种是像下面给出的代码一样手写 bitset,另一种是使用神秘的 template,详见官方题解。

#pragma GCC optimize("Ofast,no-stack-protector,unroll-loops,fast-math")
#pragma GCC target("sse,sse2,sse3,ssse3,sse4.1,sse4.2,avx,avx2,popcnt,tune=native")
#include<bits/stdc++.h>
using namespace std;
int n; vector<int> e[1000010];
int siz[1000010];
int cnt[1000010];
unsigned long long vis[(1000000>>6)+10];
void lshift(int n,int x)
{
    if((x&63)==0)
    {
        for(int i=n; i>=(x>>6); --i) vis[i]|=vis[i-(x>>6)];
    }
    else
    {
        for(int i=n; i>=(x>>6); --i)
        {
            int wzl=(i<<6)-x;
            vis[i]|=(wzl<0?0:vis[wzl>>6]>>(wzl&63))|(vis[(wzl>>6)+1]<<(64-(wzl&63)));
        }
    }
}
long long ans=0;
void dfs(int u)
{
    siz[u]=1;
    int ax=0;
    for(int v:e[u]) dfs(v),siz[u]+=siz[v],ax=max(ax,siz[v]);
    if(ax*2>=siz[u]-1) ans+=1ll*ax*(siz[u]-1-ax);
    else
    {
        for(int v:e[u]) ++cnt[siz[v]];
        memset(vis,0,((siz[u]-1>>6)+5)*8);
        vis[0]=1;
        for(int v:e[u])
        {
            if(cnt[siz[v]]==0) continue;
            int now=1;
            while(now<=cnt[siz[v]])
            {
                lshift(siz[u]-1>>6,siz[v]*now);
                cnt[siz[v]]-=now,now*=2;
            }
            if(cnt[siz[v]]!=0) lshift(siz[u]-1>>6,siz[v]*cnt[siz[v]]);
            cnt[siz[v]]=0;
        }
        int mid=(siz[u]-1)/2;
        for(int i=0; ; ++i)
        {
            if(vis[mid+i>>6]>>(mid+i&63)&1) { ans+=1ll*(mid+i)*(siz[u]-1-mid-i); break; }
            if(vis[mid-i>>6]>>(mid-i&63)&1) { ans+=1ll*(mid-i)*(siz[u]-1-mid+i); break; }
        }
    }
}
int main()
{
    ios::sync_with_stdio(false),cin.tie(0);
    cin>>n;
    for(int i=2; i<=n; ++i)
    {
        int f; cin>>f;
        e[f].push_back(i);
    }
    dfs(1);
    cout<<ans;
    return 0;
}