CF925E May Holidays 题解

· · 题解

这里给出一个时间 O(n\sqrt {n \log n}) ,空间 O(n) ,的做法,时间上介于正常树剖分块和神奇树剖分块之间,空间上没有问题。

转换一下题意,就是每次对于一条根节点到某个点的路径 +1/-1,并对该节点反色,并每次询问权值大于 0 的白色点个数。

看到这道题的第一眼,就有一个 O(nq) 的做法,然而过不去。

但注意到如果我们不做 O(nq),而是找一个数 B,来做 O(B^2),做 \frac{q}{B} 次,那么总时间复杂度就是 O(qB) 的。

这就是询问分块的思想。

考虑阈值 B,把每 B 个询问分作一块。

如果对于每个询问继续暴力处理,时间复杂度依旧是 O(B \times n \times \frac{n}{B}) = O(n^2)

但注意到 B 个询问只涉及最多 B 个点。

如果我们只考虑这 B 个点,向上跳,我们至多只会修改 O(B) 个点。

但这 B 个点不一定成一棵树,所以我们要 建虚树

建完虚树后,我们在虚树上跳,就至多会跳到 O(B) 个点。

但询问分块的思路依旧要求我们对于每个点,O(1) 修改信息。

而此时,每条虚边已经覆盖了原树上的一条链了,我们无法简单地按照暴力那样修改信息。

一个很重要的性质在于:每条虚边在每个时刻,被加的值都是相同的

所以我们对于每条虚边 (u,v) (其中 uv 的祖先,都可以把这条边上的点权挂在 v 上(不包括 u,否则会算重,并打一个 add 标记。

考虑修改操作时怎么办。

先把挂在每个虚点上的点权从大到小排序。

注意到每次 add 变化量为 1。

对每个虚点记录一个指针 poi,表示 poi 及其前面的都 >0

如果点权互不相同,poi 至多移动一位。

而点权相同时可以去重,在点权外额外记录出现次数。

这样单次修改复杂度就是 O(1) 的。

另一个棘手的问题是颜色怎么办。

第一,我们记录“出现次数,只能记录白色点的出现次数。

第二,我们修改颜色时,修改的肯定是个虚点,所以我们只要在它身上的信息中二分查找到该点权所在的位置,修改出现次数即可。

注意到如果位置在 poi 及其之前,还要修改答案。

好了,对于每一块我们就做完了。

还有一点,每一块答案的正确,依赖的是前一块信息维护的正确,但我们只建了这一块的虚树。

所以我们在每一块结束后,要把 add 标记下放到点权上,清空标记和挂载在上面的信息,这就是询问分块的 定期重构 步骤。

算法讲解完毕,我们来考虑 B 的取值以及复杂度。

总共有 \frac{q}{B} 块,每块内是 O(n\log n + B^2) 的(因为有排序的复杂度)。

所以总复杂度为 O(\frac{nq\log n}{B} + qB),假设 n,q 同阶,取 B = \sqrt {n \log n} 就可以的得到 O(n \sqrt {n \log n}) 的时间复杂度。

如果使用基数排序,可以做到 O(n \sqrt n)

Code:

#include <bits/stdc++.h>
using namespace std;
const int N = 1e5 + 5;
int fir[N],nxt[N<<1],to[N<<1],ect = 0,n,q;
int fa[N];
inline void addedge(int u1,int v1)
{
    nxt[++ect] = fir[u1];fir[u1] = ect;to[ect] = v1;
}
namespace LCA{
    int ST[N*2][20],fst[N],dep[N],dfn[N];
    int dfstime,lg[N*2];
    void dfs(int x,int fa)
    {
        dep[x] = dep[fa] + 1;
        ST[fst[x] = ++dfstime][0] = x;
        for(int i = fir[x],y;y=to[i],i;i=nxt[i])
        {
            if(y == fa) continue;
            dfs(y,x);
            ST[++dfstime][0] = x;
        }
    }
    void dfs_dfs(int x,int fa)
    {
        dfn[x] = ++dfstime;
        for(int i = fir[x],y;y=to[i],i;i=nxt[i])
            if(y != fa) dfs_dfs(y,x);

    }
    inline int cmp(int x,int y) { return dep[x] < dep[y] ? x : y;}
    void init()
    {
        dfs_dfs(1,0);
        dfstime = 0;
        dfs(1,0);lg[0] = -1;
        for(int i = 1;i <= dfstime;i++) lg[i] = lg[i >> 1] + 1;
        for(int j = 1;j <= 19;j++)
            for(int i = 1;i + (1 << j) - 1 <= dfstime;i++)
                ST[i][j] = cmp(ST[i][j-1],ST[i+(1<<j-1)][j-1]);
    }
    inline int lca(int x,int y)
    {
        x = fst[x];y = fst[y];
        if(x > y) swap(x,y);
        int k = lg[y-x+1];
        return cmp(ST[x][k],ST[y-(1<<k)+1][k]);
    }
}
int Pos[N];
vector<int> G[N];
inline bool cmp(int x,int y) { return LCA::dfn[x] < LCA::dfn[y];}
int stk[N],top;
inline void addvir(int x,int y)
{
    G[x].push_back(y);
    G[y].push_back(x);
}
inline void build_tr(int &num)
{
    G[1].clear();
    stk[top = 1] = 1;
    sort(Pos + 1,Pos + num + 1,cmp);
    num = std::unique(Pos + 1,Pos + num + 1) - Pos - 1;

    for(int i = 1;i <= num;i++)
    {
        if(Pos[i] == 1) continue;
        int l = LCA::lca(Pos[i],stk[top]);
        if(l != stk[top])
        {
            while(LCA::dfn[l] < LCA::dfn[stk[top-1]])
                addvir(stk[top-1],stk[top]),--top;
            if(LCA::dfn[l] != LCA::dfn[stk[top-1]])
                G[l].clear(),addvir(l,stk[top]),stk[top] = l;
            else addvir(l,stk[top--]);
        }
        G[Pos[i]].clear();stk[++top] = Pos[i];
    }
    for(int i = 1;i < top;i++) addvir(stk[i],stk[i+1]);
}
int a[N],Q[N];
vector<pair<int,int> > W[N];
typedef vector<pair<int,int> > :: iterator Iter;
int virfa[N];
int add[N],Poi[N];
int col[N];
int ans = 0;
void dfs1(int x,int f)
{
    int now = x;
    virfa[x] = f;
    vector<pair<int,int> > tmp;
    while(now != f)
    {
        tmp.push_back(make_pair(a[now],now));
        now = fa[now];
    }
    sort(tmp.begin(),tmp.end(),greater<pair<int,int> >());
    W[x].clear();
    for(int i = 0;i < (int)tmp.size();i++)
    {
        if(W[x].empty()) {W[x].push_back(make_pair(tmp[i].first,!col[tmp[i].second]));continue;}
        if(W[x].back().first != tmp[i].first)
            W[x].push_back(make_pair(tmp[i].first,!col[tmp[i].second]));
        else W[x][W[x].size()-1].second += !col[tmp[i].second];
    }

    add[x] = 0;
    Poi[x] = -1;
    while(Poi[x] < (int)W[x].size() - 1 && W[x][Poi[x] + 1].first > 0) ++Poi[x];

    for(int i = 0;i < (int)G[x].size();i++)
    {
        int y = G[x][i];
        if(y == f) continue;
        dfs1(y,x);
    }
}
inline bool tmpcmp(const pair<int,int> &x,const pair<int,int> &y) { return x > y;}
inline void change(int x)
{
    col[x] ^= 1;
    if(col[x] == 0) //Add
    {
        Iter it = lower_bound(W[x].begin(),W[x].end(),make_pair(a[x],1e9),tmpcmp);
        assert(it->first == a[x]);
        it->second += 1;
        if(it - W[x].begin() <= Poi[x]) ++ans;
    }
    else
    {
        Iter it = lower_bound(W[x].begin(),W[x].end(),make_pair(a[x],1e9),tmpcmp);
        assert(it -> first == a[x]);
        it->second -= 1;
        if(it - W[x].begin() <= Poi[x]) --ans;
    }
}
void modify(int x,int v)
{
    int now = x;
    while(now)
    {
        if(v == 1)
        {
            add[now]++;
            if(Poi[now] < (int)W[now].size() - 1 && W[now][Poi[now] + 1].first + add[now] >= 1)
                ans += W[now][++Poi[now]].second;
        }
        else
        {
            add[now]--;
            if(Poi[now] >= 0 && W[now][Poi[now]].first + add[now] < 1)
                ans -= W[now][Poi[now]--].second;
        }
        now = virfa[now];
    }
}
int readd[N];
void dfs2(int x,int fa)
{
    for(int i = fir[x],y;y=to[i],i;i=nxt[i])
    if(y != fa) dfs2(y,x),readd[x] += readd[y]; 
    a[x] += readd[x];
}
inline void rebuild(int l,int r)
{
    for(int i = 1;i <= n;i++) readd[i] = 0;
    for(int i = l;i <= r;i++)
    {
        if(Q[i] > 0) readd[Q[i]] += 1;
        else readd[-Q[i]] -= 1;
    }
    dfs2(1,0);
    // ans = 0;
}
int main()
{
    scanf("%d%d",&n,&q);
    for(int i = 2;i <= n;i++)
        scanf("%d",&fa[i]),addedge(fa[i],i),addedge(i,fa[i]);
    LCA::init();
    for(int i = 1,x;i <= n;i++)
        scanf("%d",&x),a[i] = -x;
    for(int i = 1;i <= q;i++)
        scanf("%d",&Q[i]);
    int B = sqrt(1.0*q*log2(n));

    for(int l = 1;l <= q;l += B)
    {
        int r = min(l + B - 1,q);
        int num = 0;
        for(int i = l;i <= r;i++)
            Pos[++num] = abs(Q[i]);
        build_tr(num);
        // dfs0(1,0);
        dfs1(1,0);
        for(int i = l;i <= r;i++)
        {
            int x = abs(Q[i]),y = Q[i] > 0 ? 1 : -1;
            change(x);
            modify(x,y);
            printf("%d ",ans);
        }
        rebuild(l,r);
    }
    return 0;
}