题解 P6799 【「StOI-2」独立集】
littleKtian
·
·
题解
upd on 2021.3.16:修改了原做法和代码中的错误,现在能通过 hack 数据了。
显然这是道树形 dp,为了方便处理信息,我们考虑将这 m 条链放到 lca 上维护。
设 f_{i,0/1} 表示以 i 为根的子树中没有选取/有选取以 i 为 lca 的链时,选取链的方案数,g_i 为两者之和。
易得转移方程 f_{i,0}=\prod\limits_{j\text{是}i\text{的孩子}}g_j,f_{i,1}=\sum\limits_{\text{链}L\text{的lca是}i}\prod\limits_{j\text{不在}L\text{上,且}\atop j\text{的父亲在}L\text{上}}g_j。
发现连乘的形式和 $f_{i,0}$ 很像,于是考虑往这个方向变形。
得 $f_{i,1}=\sum\limits_{\text{链}L\text{的lca是}i}\dfrac{\prod\limits_{j\text{的父亲在}L\text{上}}g_j}{\prod\limits_{k\text{在}L\text{上且}k\neq i}g_k}=\sum f_{i,0}\prod\limits_{j\text{在}L\text{上且}j\neq i}\dfrac{f_{j,0}}{g_j}$。
显然这是一个树上的区间求乘积,最简单的方法是 树剖+线段树,但并不能通过此题。
我们可以对树剖的每条重链维护后缀积,因为整个 dp 的过程是从下往上的,而且我们只需要用到下面的信息点,所以对于每个点我们都可以做到 $O(1)$ 维护后缀积+每条链 $O(\log n\log p)$ 查询,预处理逆元后可以每个点 $O(\log p)$ 维护+每条链 $O(\log n)$ 查询,可以通过。
~~以上做法已被 hack~~
------------
上面做法的最大问题是 $f_{i,0}$ 和 $g_i$ 在对 $998244353$ 取模意义下可能为 $0$,导致逆元无法使用。但我们也可以从中获得一些启示。
我们把乘积式中的元素分为 $j$ 是 $i$ 的儿子以及不是 $i$ 的儿子两部分,发现对于 $j$ 是 $i$ 的儿子的部分,最多只会有两个元素缺失,所以我们可以给 $i$ 的每个儿子一个标号,用线段树处理出区间乘积,对于每条链最多只会有 $3$ 次询问,所以对于每条链复杂度是 $O(\log n)$。
对于不是 $i$ 的儿子的部分,可以发现对于所有链 $L$ 上的点 $j$,都最多只会缺失一个元素,并且缺失的元素就是链 $L$ 上的点。所以我们可以另设 $h_i$ 表示 $fa_i$ 的儿子中除了 $i$ 以外所有 $g$ 的乘积,那么我们相当于需要查询一条链上所有元素的乘积,我们可以转化为每次对以某个点为根的子树内所有元素乘上一个数,然后进行单点查询。因此复杂度还是 $O(\log n)$ 的。
总复杂度应该是 $O\left((n+m)\log n\right)$,细节相对较多。
```cpp
#include<bits/stdc++.h>
using namespace std;
#define p 998244353
int lsw[500005],bi[1000005][2],bs;
int fa[500005],so[500005],si[500005],de[500005],to[500005],xh[500005],dfn;
int lw[500005],la[500005][3];
int n,m,ff[2][500005],f[500005];
int dr()
{
int xx=0;char ch=getchar();
while(ch<'0'||ch>'9')ch=getchar();
while(ch>='0'&&ch<='9')xx=xx*10+ch-'0',ch=getchar();
return xx;
}
int P(int x,int y=p-2)
{
int z=1;
for(;y;x=1ll*x*x%p,y>>=1)if(y&1)z=1ll*z*x%p;
return z;
}
void tj(int u,int v){++bs,bi[bs][0]=lsw[u],bi[bs][1]=v,lsw[u]=bs;}
void dfs_1(int w,int f)
{
fa[w]=f,si[w]=1,de[w]=de[f]+1;
for(int o_o=lsw[w];o_o;o_o=bi[o_o][0])
{
int v=bi[o_o][1];
if(v!=f)
{
dfs_1(v,w),si[w]+=si[v];
if(si[v]>si[so[w]])so[w]=v;
}
}
}
void dfs_2(int w,int t)
{
to[w]=t,xh[w]=++dfn;
if(so[w])dfs_2(so[w],t);
for(int o_o=lsw[w];o_o;o_o=bi[o_o][0])
{
int v=bi[o_o][1];
if(v!=fa[w]&&v!=so[w])dfs_2(v,v);
}
}
int fi(int x,int y)
{
int fx=to[x],fy=to[y];
while(fx!=fy)
{
if(de[fx]>=de[fy])x=fa[fx],fx=to[x];
else y=fa[fy],fy=to[y];
}
return de[x]<de[y]? x:y;
}
int ju(int x,int y)
{
int fx=to[x],fy=to[y];
while(fx!=fy)
{
if(fa[fx]==y)return fx;
x=fa[fx],fx=to[x];
}
return so[y];
}
int tree[2000005],ltree[2000005],a[500005],lxh[500005];
#define ls(w) (w<<1)
#define rs(w) (ls(w)^1)
void d(int w){if(tree[w]!=1)tree[ls(w)]=1ll*tree[ls(w)]*tree[w]%p,tree[rs(w)]=1ll*tree[rs(w)]*tree[w]%p,tree[w]=1;}
void csh1(int w,int l,int r)
{
tree[w]=1;
if(l==r)return;
int mid=(l+r)>>1;
csh1(ls(w),l,mid),csh1(rs(w),mid+1,r);
}
void xg1(int w,int l,int r,int L,int R,int x)
{
if(L<=l&&r<=R){tree[w]=1ll*tree[w]*x%p;return;}
d(w);
int mid=(l+r)>>1;
if(L<=mid)xg1(ls(w),l,mid,L,R,x);
if(mid<R)xg1(rs(w),mid+1,r,L,R,x);
}
int cx1(int w,int l,int r,int xh)
{
if(l==r)return tree[w];
d(w);
int mid=(l+r)>>1;
if(xh<=mid)return cx1(ls(w),l,mid,xh);
else return cx1(rs(w),mid+1,r,xh);
}
void u(int w){ltree[w]=1ll*ltree[ls(w)]*ltree[rs(w)]%p;}
void csh2(int w,int l,int r)
{
if(l==r){ltree[w]=a[l];return;}
int mid=(l+r)>>1;
csh2(ls(w),l,mid),csh2(rs(w),mid+1,r);
u(w);
}
int cx2(int w,int l,int r,int L,int R)
{
if(L>R)return 1;
if(L<=l&&r<=R)return ltree[w];
int mid=(l+r)>>1,x=1;
if(L<=mid)x=1ll*x*cx2(ls(w),l,mid,L,R)%p;
if(mid<R)x=1ll*x*cx2(rs(w),mid+1,r,L,R)%p;
return x;
}
void dp(int w)
{
ff[0][w]=1;int tot=0;
for(int o_o=lsw[w];o_o;o_o=bi[o_o][0])
{
int v=bi[o_o][1];
if(v!=fa[w])dp(v),lxh[v]=++tot,ff[0][w]=1ll*ff[0][w]*f[v]%p;
}
for(int o_o=lsw[w];o_o;o_o=bi[o_o][0])
{
int v=bi[o_o][1];
if(v!=fa[w])a[lxh[v]]=f[v];
}
if(tot)csh2(1,1,tot);
for(int o_o=lw[w];o_o;o_o=la[o_o][0])
{
int u=la[o_o][1],v=la[o_o][2];
if(u!=w&&v!=w)
{
int x=ju(u,w),y=ju(v,w);
if(lxh[x]>lxh[y])swap(x,y);
ff[1][w]=(ff[1][w]+(1ll*cx1(1,1,n,xh[u])*ff[0][u]%p)*(1ll*cx1(1,1,n,xh[v])*ff[0][v]%p)%p*(1ll*cx2(1,1,tot,1,lxh[x]-1)*cx2(1,1,tot,lxh[x]+1,lxh[y]-1)%p*cx2(1,1,tot,lxh[y]+1,tot)%p))%p;
}
else if(u==w&&v==w)
{
ff[1][w]=(ff[1][w]+ff[0][w])%p;
}
else
{
if(v==w)swap(u,v);
int x=ju(v,w);
ff[1][w]=(ff[1][w]+1ll*cx1(1,1,n,xh[v])*ff[0][v]%p*cx2(1,1,tot,1,lxh[x]-1)%p*cx2(1,1,tot,lxh[x]+1,tot))%p;
}
}
for(int o_o=lsw[w];o_o;o_o=bi[o_o][0])
{
int v=bi[o_o][1];
if(v!=fa[w])
{
int x=1ll*cx2(1,1,tot,1,lxh[v]-1)*cx2(1,1,tot,lxh[v]+1,tot)%p;
xg1(1,1,n,xh[v],xh[v]+si[v]-1,x);
}
}
f[w]=(ff[0][w]+ff[1][w])%p;
}
int main()
{
n=dr(),m=dr();
for(int i=1;i<n;i++)
{
int u=dr(),v=dr();
tj(u,v),tj(v,u);
}
dfs_1(1,0),dfs_2(1,1);
for(int i=1;i<=m;i++)
{
int u=dr(),v=dr(),g=fi(u,v);
la[i][0]=lw[g],la[i][1]=u,la[i][2]=v,lw[g]=i;
}
csh1(1,1,n),dp(1);
printf("%d",f[1]);
}
```