题解:P11824 [湖北省选模拟 2025] 团队协作 / team

· · 题解

题意

对每个点求出包含这个点的所有独立集中点权最大值的和,对 998244353 取模。

思路

我们称节点 i 的点权比节点 j 大当且仅当 v_i>v_jv_i=v_j,i>j

考虑转置原理。设 a_{i,j} 表示包含点 i 且点权最大的点为 j 的独立集个数,n 阶方阵 A=(a_{i,j}),列向量 \vec{a}=[v_1,v_2,\cdots,v_n]^{\top},则答案向量为 \vec b=A\vec a

它的转置问题即为:对于每个 i,求满足 i 是独立集中点权最大的点的所有独立集的点权和。

按点权从小到大加入每个点,加入点 i 时所有独立集点权和的变化量就是 i 的答案。在这里,一个点被加入指的是可以在独立集中出现,未被加入指的是不能出现在独立集中。

这是一个经典的 ddp 问题,我们使用静态 top tree 解决。对每个簇设 f(0/1,0/1),g(0/1,0/1) 分别表示不选/选上界点、不选/选下界点的独立集方案数和权值和,转移是显然的。值得注意的是,此处 g(1,*) 我们不计入上界点的贡献。可以新建一个虚点 0 作为 1 的父亲方便统计答案。

如果把 f 视为常量(矩阵里的量),则转移关于 g 是线性的。

需要注意,你可以把 ddp 的每次修改看成持久化的,即我们修改的时候变量所占用的内存都是不同的,需要先赋值一次。

然后转置就行了。时间复杂度 O(n\log n)。发现我们实际上并不需要持久化,每次直接覆盖就行,所以空间 O(n)。如果还是不明白可以看代码,注释里标了转置前后的内容。

事实上,用转置原理得到的做法和其他几篇直接用 top tree 得到的做法是一样的。

代码

#include <bits/stdc++.h>
#define File(a) freopen(#a".in","r",stdin);freopen(#a".out","w",stdout)
#define ll long long
#define F(i,a,b) for(int i(a),i##i##end(b);i<=i##i##end;++i)
#define R(i,a,b) for(int i(a),i##i##end(b);i>=i##i##end;--i)
#define fi first
#define se second
using namespace std;
const int MAXN=3e5+2,MOD=998244353;
int n;
pair<int,int>val[MAXN];
int cnt;
struct Node{
    short type;//-1unit,0rake,1compress
    int up,dn,mid;
    int siz,lc,rc,fa;
    ll f[2][2],g[2][2];
    Node(const int&t=-1,const int&d=0,const int&e=0,const int&a=1,const int&b=0,const int&c=0){
        type=t,siz=a,lc=b,rc=c,up=d,dn=e,fa=0;
        memset(f,0,sizeof(f)),memset(g,0,sizeof(g));
        return;
    }
}node[MAXN<<1];
bool isin[MAXN]; 
inline void upd(int now){
    Node&qwq(node[now]),&l(node[qwq.lc]),&r(node[qwq.rc]);
    memset(qwq.f,0,sizeof(qwq.f));
    if(node[now].type) F(i,0,1) F(j,0,1) F(k,0,isin[l.dn]) qwq.f[i][j]=(qwq.f[i][j]+l.f[i][k]*r.f[k][j])%MOD;
    else F(i,0,1) F(j,0,1) F(k,0,isin[r.dn]) qwq.f[i][j]=(qwq.f[i][j]+l.f[i][j]*r.f[i][k])%MOD;
    return;
}
inline void psd(int now){
    Node&qwq(node[now]),&l(node[qwq.lc]),&r(node[qwq.rc]);
    if(node[now].type){
        /*
        转置前 
        F(i,0,1) F(j,0,1) F(k,0,isin[l.dn]){
            qwq.g[i][j]=(qwq.g[i][j]+l.f[i][k]*r.g[k][j])%MOD;
            qwq.g[i][j]=(qwq.g[i][j]+l.g[i][k]*r.f[k][j])%MOD;
        }
        */ 
        F(i,0,1) F(j,0,1) F(k,0,isin[l.dn]){
            r.g[k][j]=(r.g[k][j]+l.f[i][k]*qwq.g[i][j])%MOD;
            l.g[i][k]=(l.g[i][k]+r.f[k][j]*qwq.g[i][j])%MOD;
        }//倒不倒序不影响 
    }else{
        /*
        转置前 
        F(i,0,1) F(j,0,1) F(k,0,isin[r.dn]){
            qwq.g[i][j]=(qwq.g[i][j]+l.f[i][j]*r.g[i][k])%MOD;
            qwq.g[i][j]=(qwq.g[i][j]+l.g[i][j]*r.f[i][k])%MOD;
        }
        */ 
        F(i,0,1) F(j,0,1) F(k,0,isin[r.dn]){
            r.g[i][k]=(r.g[i][k]+l.f[i][j]*qwq.g[i][j])%MOD;
            l.g[i][j]=(l.g[i][j]+r.f[i][k]*qwq.g[i][j])%MOD;
        }
    }
    memset(qwq.g,0,sizeof(qwq.g));
    return;
}
inline int merge(int x,int y,int type){
    node[x].fa=node[y].fa=++cnt;
    node[cnt]=Node(type,node[x].up,node[type?y:x].dn,node[x].siz+node[y].siz,x,y);
    return upd(cnt),cnt;
}
#define Poi vector<int>::iterator
int build(Poi l,Poi r,int type){
    if(l==r) return 0;
    if(l+1==r) return *l;
    int sum=0,all=0;
    for(auto it=l;it!=r;++it) all+=node[*it].siz;
    Poi mid=l+1;
    for(auto it=l;it!=r;++it){
        sum+=node[*it].siz;
        if(sum*2<=all) mid=it+1;
        else break;
    }
    return merge(build(l,mid,type),build(mid,r,type),type);
}
int siz[MAXN],fa[MAXN],son[MAXN],rt[MAXN];//根为 cnt 
vector<int>g[MAXN];
void dfs1(int now){
    siz[now]=1;
    node[now]=Node(-1,fa[now],now);
    node[now].f[0][0]=node[now].f[1][0]=node[now].f[0][1]=1;
    isin[now]=1;
    for(int i:g[now]){
        dfs1(i);
        siz[now]+=siz[i];
        if(siz[son[now]]<siz[i]) son[now]=i;
    }
    return;
}
void dfs2(int now,bool heavy){
    if(son[now]) dfs2(son[now],1);
    for(int i:g[now]) if(i!=son[now]) dfs2(i,0);
    if(!heavy){
        vector<int>chain;
        chain.push_back(now);
        for(int i(son[now]);i;i=son[i]){
            vector<int>sub({i});
            for(int j:g[fa[i]]) if(j!=i) sub.push_back(rt[j]);
            chain.push_back(build(sub.begin(),sub.end(),0));
        }
        rt[now]=build(chain.begin(),chain.end(),1);
    }
    return;
}
int ans[MAXN];
signed main(){
    ios::sync_with_stdio(0);
    cin.tie(0),cout.tie(0);
    cin>>n;
    F(i,2,n) cin>>fa[i],g[fa[i]].push_back(i);
    F(i,1,n) cin>>val[i].fi,val[i].se=i;
    sort(val+1,val+n+1);
    dfs1(1),cnt=n,dfs2(1,0);
    R(i,n,1){//倒序执行 
        /*
        转置前 (ans[i]-ans[i-1])+=node[cnt].g[0][0/1] 
        转置后倒序变成 node[cnt].g[0][0/1]+=val[i]-val[i+1] 
        */
        node[cnt].g[0][0]=(node[cnt].g[0][0]+val[i].fi-val[i+1].fi+MOD)%MOD;
        if(isin[node[cnt].dn]) node[cnt].g[0][1]=(node[cnt].g[0][1]+val[i].fi-val[i+1].fi+MOD)%MOD;
        int now=val[i].se,tp=0;
        static int path[MAXN];
        while(node[now].fa) path[++tp]=node[now].fa,now=path[tp];
        R(j,tp,1) psd(path[j]);//从上到下倒序执行 
        now=val[i].se;
        ans[now]=node[now].g[0][1],node[now].g[0][1]=0;//node[now].g[0][1]=val[now]的转置 
        isin[now]=0,node[now].f[0][1]=0; 
        while(node[now].fa) upd(node[now].fa),now=node[now].fa;//按原顺序更新常量(矩阵里的量) 
    }
    F(i,1,n) cout<<ans[i]<<" "; 
    return 0;
}