题解:P13828 [Ynoi Easy Round 2026] 寒蝉鸣泣之时·卒

· · 题解

@operator 太强了,拜谢 @operator

我们称选出的点为关键点。

首先设 len=\sum t_i,那么有 len\leqslant 10^5

首先考虑对 t_i 根号分治,设阈值 C,那么最多有 \mathcal{O}({len\over C}) 次查询的 t_i 大于 C,那么考虑直接 \mathcal{O}(n) 暴力去求答案,时间复杂度 \mathcal{O}({n\times len\over C})

现在考虑 t_i 小于 C 怎么做。

考虑将边的集合拆成一些从根开始的链减去另一些从根开始的链。

根据虚树的思想或手模一下即可发现,将关键点按 dfn 序排序之后,可以用所有从根到关键点的链减去从根到排序后相邻点的 lca 的链来表示,包括第一个关键点与最后一个关键点的 lca

于是我们就把集合拆成了 2t_i 条权值为 1/-1 的链,那么集合的答案就可以拆成每两条链相互的贡献之和了。

接下来需要思考求答案了。

考虑先将无序变成有序,这样算出来的答案减去自己与自己的贡献再除以二就是答案了。

考虑一种非主流树分块去求解答案。

就是一个块内最长的祖先-后代链的长度不超过阈值 B ,那么除了根所在的块一个块的大小就至少为 B ,于是块数大约在 \mathcal{O}({n\over B}) 级别的。

考虑如何计算两条链 x,y 的相互的贡献。

假设链 x 可以拆成从根开始的整块链 a 与散块链 cy 可以拆成整块 b 与散块 d,设链 u 对链 v 的贡献为 uv,易知 uv=vu

那么贡献应该是 ab+ba+cb+bc+ad+da+cd+dc

散块对散块先不管,那么整块可以拆成 (ab+2ad)+(ba+2bc)=(2ay-ab)+(2bx-ba),其中 ay,ab 之类的都是容易预处理的。

散块只有 t_iB 个树上的点,所以散块对散块直接暴力求即可。

那么这部分的时间复杂度就是 \mathcal{O}((t_i)^2+t_iB)

由于t_i\leqslant C,所以时间复杂度为 \mathcal{O}(lenB+lenC)

B=C=\sqrt n,复杂度即为 \mathcal{O}(n\sqrt n)。(视为 n,m,len 同阶)

::::info[Code]

#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N=1e5+10,B=320;
int n,m,w[N],lo[N];
vector<int> g[N];
int dfn[N],st[N][20],lg[N],rnk[N],p1;
int vis[N],cnt,fa[N],up[N],dep[N];
int res[N];
ll f[B][N];
void clear(){
    for(int i=0;i<=n;i++) res[i]=0;
}
void dfs(int x){
    dfn[x]=++p1,rnk[p1]=x;
    st[dfn[x]][0]=dfn[fa[x]];
    for(int v:g[x]){
        dep[v]=dep[x]+1;
        dfs(v);
        lo[x]=max(lo[v]+1,lo[x]);
    }
    if(x==1||lo[x]>=B) vis[x]=++cnt,lo[x]=0;
}
int lca(int x,int y){
    if(x==y) return x;
    x=dfn[x],y=dfn[y];
    if(x>y) swap(x,y);
    int len=lg[y-x];
    return rnk[min(st[x+1][len],st[y-(1<<len)+1][len])];
}
void dfs1(int x,int tp,ll las){
    las+=res[w[x]];
    f[tp][x]=las;
    for(int v:g[x]) dfs1(v,tp,las);
}
int gt[N],sm[N];
int tot;
struct node{
    int x,val;
}ln[N];
int main(){
    ios::sync_with_stdio(false);
    cin.tie(0),cout.tie(0);
    cin>>n;
    for(int i=2;i<=n;i++){
        cin>>fa[i]>>w[i];
        g[fa[i]].push_back(i);
    }
    dfs(1);
    for(int i=2;i<=n;i++) lg[i]=lg[i>>1]+1;
    for(int j=1;j<=19;j++){
        for(int i=1;i+(1<<j)-1<=n;i++){
            st[i][j]=min(st[i][j-1],st[i+(1<<j-1)][j-1]);
        }
    }
    for(int i=1;i<=n;i++){
        if(!vis[i]) up[i]=up[fa[i]];
        else{
            up[i]=i;
            clear();
            for(int x=i;x!=1;x=fa[x]) res[w[x]]++;
            dfs1(1,vis[i],0);
        }
    }
    clear();
    int q;
    cin>>q;
    while(q--){
        int len;
        cin>>len;
        for(int i=1;i<=len;i++) cin>>gt[i];
        if(len>=B){
            ll ans=0;
            for(int i=1;i<=len;i++) sm[gt[i]]++;
            for(int i=n;i>=1;i--){
                if(sm[i]&&sm[i]!=len) ans+=res[w[i]],res[w[i]]++;
                sm[fa[i]]+=sm[i];
            }
            clear();
            for(int i=0;i<=n;i++) sm[i]=0;
            cout<<ans<<"\n";
            continue;
        }
        sort(gt+1,gt+1+len,[&](int x,int y){return dfn[x]<dfn[y];});
        int las=gt[len];tot=0;
        for(int i=1;i<=len;i++){
            ln[++tot]={gt[i],1};
            ln[++tot]={lca(gt[i],las),-1};
            las=gt[i];
        }
        ll ans=0;
        for(int i=1;i<=tot;i++){
            node p=ln[i];
            int tp=vis[up[p.x]];
            for(int j=1;j<=tot;j++){
                node q=ln[j];
                ans+=q.val*p.val*(f[tp][q.x]*2-f[tp][up[q.x]]);
            }
        }
        for(int i=1;i<=tot;i++){
            for(int x=ln[i].x;up[x]!=x;x=fa[x]) res[w[x]]+=ln[i].val;
        }
        for(int i=1;i<=tot;i++){
            for(int x=ln[i].x;up[x]!=x;x=fa[x]) ans+=1ll*res[w[x]]*res[w[x]],res[w[x]]=0;
        }
        for(int i=1;i<=tot;i++) ans-=dep[ln[i].x]*ln[i].val;
        cout<<ans/2<<"\n";
    }
    return 0;
}

::::