题解 P3781 【[SDOI2017]切树游戏】
shadowice1984
2018-07-25 17:15:12
从九省联考前咕到现在的一道题……
精神污染,做好码5~6个k的心理准备……
____________________
## 前置知识:快速沃尔什变换(FWT)
如果不会fwt的话可以去这个模板区学习一下
[P4717](https://www.luogu.org/problemnew/show/P4717)
## 本题题解
首先呢题意是简单易懂的,求树上异或和为k的联通块个数
到此为止它还是$trival$的,最多是个紫牌题
但是如果我们需要支持动态修改点权呢?
然后这道题就变成了黑牌题了
当然猫老师也用这道题作为了他[博客](http://immortalco.blog.uoj.ac/blog/2625)上的一道例题……
那么让我们讲解一下这道题的原理吧……
___________________
## 朴素的dp
我们可以设$Dp_{i,j}$表示最高点为i的异或和为j的联通块个数
那么,如果询问为k的话我们最后就是求
## $\sum_{i=1}^{n}Dp_{i,k}$
让我先来愉快的推一下转移方程
还是考虑我们有了$Dp_{u,k}$现在要加入一个子树$Dp_{v}$
那么我们可以得到这个转移方程
## $Dp_{u,k}=\sum_{i \oplus j==k}Dp_{u,i}Dp_{v,j}+Dp_{u,k}$
我们发现那个和式其实是所谓的异或卷积的形式
当然可以使用fwt进行优化
所以假如我们设$\hat{Dp}_{i}$为数组$Dp_{i}$经过fwt之后的数组的话
我们会得到这样的转移式子
## $\hat{Dp}_{u,k}=\hat{Dp}_{u,k}\hat{Dp}_{v,k}+\hat{Dp}_{u,k}$
即
## $\hat{Dp}_{u,k}*=(\hat{Dp}_{v,k}+1)$
然后我们假如设
## $g_{u,k}=\sum_{i \in u.subtree}Dp_{u,k}$
的话(u.subtree)就是u的子树的意思
等式两边取沃尔什变换
## $\hat{g}_{u,k}=\sum_{i \in u.subtree}\hat{Dp}_{u,k}$
所以说我们只需要每次计算出$\hat{g}_{1}$这个数组就可以了,当然如果我们每次暴力重新计算的话算法复杂度应该是
## $O(nqmlogm)$
的,显然无法通过本题,所以我们可能需要一些优化
_____________________
## 动态dp
看到这里的话你已经注意到dp的转移方程十分简单这个事实了
这一般预示着一种叫做动态dp的技术,允许我们每次仅仅做$O(logn)$次矩阵乘法就可以快速完成dp的单点修改
那么具体来讲是什么操作呢……?
为了加深理解可以先去做一做这两道题
[P4751](https://www.luogu.org/problemnew/show/P4751)
[P4719](https://www.luogu.org/problemnew/show/P4719)
然后如果你还是不非常理解这种写法的话可以看这里的介绍
______________________
## 链分治
说白了就是轻重链剖分……
我们对这个树进行一波轻重链剖分,处理出每个点的$siz$和重儿子
然后我们采取一种不寻常的dp方式进行dp
我们一条重链一条重链的dp
这是什么意思呢……
假如我们正在计算一条重链上节点的dp值,那么我们首先for一遍这个重链,递归的求出和这个重链相连,且**最高点深度大于这个重链最高点的重链**的dp值,现在呢对于这个重链上的每一个节点,它的轻儿子的dp值已经被算出来了,还差重儿子的dp值没有算
此时我们可以做一个序列上的dp,从底向上依次递推出重链上每一个的dp值
好了到这里你可能会说,这不是没什么卵用吗……
你的算法复杂度仍然是$O(n)$的啊……
但是,请你再好好想一想,我们通过轻重链剖分这个技术将辣手的树上问题转化为了(略微好做一点的)序列问题
那么不如让我们来看一下用这种方式做的dp的转移方程会产生什么效果
那么首先我们需要知道轻儿子的dp值以及轻儿子子树当中dp的和
这样的话我们在重链上dp的时候就可以套用树形dp时的表达式(因为都是加一条边)
所以我们设
## $ldp_{u,k}=dp_{u,k}\prod_{v \in u.lightson}(dp_{v,k}+1)$
假设这里的数组都是fwt之后的形式,所以这里的乘法是普通的按位相乘而不是异或卷积
我们认为$Dp_{u}$的初值是一堆0,只有自己的值那一位是一个1的数组,上面公式里的$dp_{u}$指的是这个初值数组fwt之后的形式
然后接着设
## $lg_{u,k}=\sum_{v \in u.lightson}lg_{v,k}$
当然这里的数组依然也fwt之后的形式给出了
那么我们在重链上的dp式子就是这样
## $dp_{i}=ldp_{i+1}dp_{i}+ldp_{i+1}$
## $g_{i}=dp_{i}+lg_{i+1}+g_{i}$
然后将$g_{i}$表达式当中的$dp_{i}$代入一下就是
## $g_{i}=ldp_{i+1}dp_{i}+ldp_{i+1}+lg_{i+1}+g_{i}$
然后是最为关键的一步了……
在immortalco神仙将dp的转移写成矩阵之前这道题有一个老式的做法……
就是使用线段树维护每一条重链。
然后线段树上每一个节点维护4个值,分别是取了一段连续的重链区间,且左端点取到/右端点取到/左,右端点都没取/左,右端点都取了这个四种状态的方案数
然后你发现这个东西是可以大力转移的,于是大力转移就是了
当然你会发现你的代码十分的恶心……而且难写(这道题额外赠送一个更加难受的轻边处理问题)
当然自从猫老师把dp的转移写成矩阵的形式就不会这么麻烦了……
来,让我们尝试着把dp的转移写成矩阵乘法的形式
我们发现似乎这样做是可行的
手玩一下会发现
### $$\begin{bmatrix} dp_{i} \\ g_{i} \\ 1 \end{bmatrix} × \begin{bmatrix} ldp_{i+1} & ldp_{i+1} & 0 \\ 0 & 1 & 0\\ ldp_{i+1} & ldp_{i+1}+lg_{i+1} & 1\end{bmatrix} = \begin{bmatrix} dp_{i+1} \\ g_{i+1} \\ 1\end{bmatrix}$$
当然随之而来的事情就是,你发现每做一次矩乘似乎要做27次乘法和加法运算,并且,这些东西好像还都是对于一个长度为128的数组来讲的我们的常数可能有点爆炸
没关系我们接着手玩一下矩阵乘法,发现一个有趣的事实
### $$\begin{bmatrix} a_{0,0} & a_{0,1} & 0 \\ 0 & 1 & 0 \\ a_{1,0} & a_{1,1} & 1 \end{bmatrix} × \begin{bmatrix} b_{0,0} & b_{0,1} & 0 \\ 0 & 1 & 0 \\ b_{1,0} & b_{1,1} & 1 \end{bmatrix}=\begin{bmatrix} c_{0,0} & c_{0,1} & 0 \\ 0 & 1 & 0 \\ c_{1,0} & c_{1,1} & 1 \end{bmatrix}$$
而且
### $c_{0,0}=a_{0,0}b_{0,0}$
### $c_{0,1}=a_{0,0}b_{0,1}+a_{0,1}$
### $c_{1,0}=a_{1,0}b_{0,0}+b_{1,0}$
### $c_{1,1}=a_{1,0}b_{0,1}+a_{1,1}+b_{1,1}$
所以事实上我们只需要保存 $a_{0,0},a_{0,1},a_{1,0},a_{1,1}$就行了
而且做一次矩乘只需要4次乘法和加法了,常数降低到了一个可以令人接受的范围
对了,矩阵当中的元素可是多项式,因此这里的0和1指的是系数全部是0或者1的多项式
那么我们想要知道一个轻儿子的dp值,其实只需要提取出整条重链的转移矩阵连乘积,然后乘上一个初始向量(0,0,1)就行了,其实实际写代码的时候甚至没有必要写初值向量乘上转移矩阵这一步,因为事实上你会发现$dp_{u}=a_{1,0},g_{u}=a_{1,1}$
所以在经历了好一顿折腾之后我们终于把dp的转移写成了矩阵乘法的形式,并且把得出dp值转化为了询问一条重链的矩阵连乘积的形式
好了那这么折腾有啥用呢?
**矩阵的乘法具有结合律,因此可以方便的使用线段树维护矩阵连乘积**
想像一下修改时的情景吧,首先我们修改了一个点的点权,这对应着线段树上的单点修改,于是我们使用线段树一路修改上去,于是这个重链顶端节点的dp值改变了,于是顶端节点父亲的$ldp,lg$值都会改变,这对应着另一个重链的单点修改,于是我们再次使用线段树进行修改,重复这个过程,当我们修改到1号节点的时候,我们的算法就结束了
说的好像很轻松,但是刚才的算法流程是有硬伤的
我们修改了一个重链顶的dp值之后,如何快速计算它的父亲的$ldp,lg$值呢?
你可能会说,这有何nan,把老的值除掉/减去,然后乘上/加上新的轻儿子的dp值和g值不就好了……
但是啊,我们很有可能除一个0……此时情况瞬间变得辣手起来,其实,这里我们维护的信息似乎并不可减……此时情况开始变得辣手
当然,你可以向猫老师一样每个节点单开一个线段树来强行维护一波ldp值,但是由于在每个节点处单开一个线段树会妨碍我们使用动态dp的改进算法使得本题的复杂度降至
$O((n+qm)logn+(n+q)mlogm)$
所以说这里我们必须忍痛舍弃掉线段树,转而采取另一种方式使得我们可以除0
具体来讲,我们对于$ldp$数组,额外开一个数组$cntzero$来强行维护它,由于$ldp$是一堆多项式按位相乘得到的,所以$ldp$的第i位的值是所有非零项的乘积,而$cntzero$第i位是这一位0项的个数,当然我们提取这个数组的第i位的时候需要判断一下$cntzero$是否为0
这样我们就有办法乘和除0了,每次我们乘或者除一个0,不对$ldp$数组进行操作而是对$cntzero$数组进行操作,此时操作就变得完全可逆了
好了到此为止恶心的轻边问题就解决了,考虑到这道题的n比较小,我推荐使用常数比较小的树剖(自带一个0.25的常数)进行维护,但是……树剖难写
所以我写的是“全局平衡二叉树”
具体来讲我们使用bst来维护每一个重链,在建bst的时候不是采取平凡的取中点左右两边递归建树的方式,而是给每一个点加上一个附加权值,表示这个点所有轻儿子的siz和+1,然后找这条链的带权重心递归两边左右建树,这样我们可以保证无论是经过一条bst边还是虚边,所在子树的size至少翻一倍,这样的算上虚边的树高就是logn的,尽管某一个bst的形态可能就是一条链
~~但是有一个2的常数,而且在链上的常数会翻一倍……这道题n小体现不出来,常数被树剖吊打~~
~~可以去写一下P4751~~
但是好处就是比起树剖来说好写
然后关于这道题的实现部分,就是封装封装再封装,做好写上5~6个k的心理准备
最难受的是端点调试基本没有,都得人肉调100多行的代码……
上代码
```C
// luogu-judger-enable-o2
#include<cstdio>
#include<algorithm>
using namespace std;const int N=3*1e4+10;const int M=130;typedef unsigned int uit;
const uit mod=10007;int n;int m;int q;uit iv[mod+10];uit ans[M];char opt[20];int rot;
inline void fwt(uit* a,const uit& o)//fwt
{
for(int k=1;k<m;k<<=1)
for(int s=0;s<m;s+=(k<<1))
for(int i=s;i<s+k;i++)
{uit a0=a[i];uit a1=a[i+k];a[i]=(a0+a1)*o%mod;a[i+k]=(a0+mod-a1)*o%mod;}
}
struct poly//多项式类
{
uit f[M];
inline uit& operator [](const int& x){return f[x];}
friend poly operator *(poly a,poly b){poly c;for(int i=0;i<m;i++)c[i]=a[i]*b[i]%mod;return c;}
friend poly operator /(poly a,poly b){poly c;for(int i=0;i<m;i++)c[i]=a[i]*iv[b[i]]%mod;return c;}
friend poly operator +(poly a,poly b){poly c;for(int i=0;i<m;i++)c[i]=(a[i]+b[i])%mod;return c;}
friend poly operator -(poly a,poly b){poly c;for(int i=0;i<m;i++)c[i]=(a[i]+mod-b[i])%mod;return c;}
void operator +=(const poly& b){for(int i=0;i<m;i++)f[i]=(f[i]+b.f[i])%mod;}
void operator -=(const poly& b){for(int i=0;i<m;i++)f[i]=(f[i]+mod-b.f[i])%mod;}
inline void setone(){for(int i=0;i<m;i++)f[i]=1;}
}we[N],one;
struct mar//矩阵类
{
poly mp[2][2];
inline poly* operator [](const int& x){return mp[x];}
friend mar operator* (mar a,mar b)
{
mar c;c[0][0]=a[0][0]*b[0][0];
c[0][1]=a[0][0]*b[0][1]+a[0][1];
c[1][0]=a[1][0]*b[0][0]+b[1][0];
c[1][1]=a[1][0]*b[0][1]+a[1][1]+b[1][1];return c;
}
};
int v[2*N];int x[2*N];int ct;int al[N];int siz[N];int h[N];
inline void add(int u,int V){v[++ct]=V;x[ct]=al[u];al[u]=ct;}
inline int dfs(int u)//轻重链剖分
{
siz[u]=1;
for(int i=al[u];i;i=x[i])if(siz[v[i]]==0){siz[u]+=dfs(v[i]);if(siz[h[u]]<siz[v[i]])h[u]=v[i];}
return siz[u];
}
struct data//用来维护轻边的结构体
{
int fz[M];uit fv[M];
inline void wt(const int& i,const int& x){if(x==0)fz[i]=fv[i]=1;else fz[i]=0,fv[i]=x;}
void operator *=(const poly& b)
{for(int i=0;i<m;i++)fz[i]+=(b.f[i]==0);for(int i=0;i<m;i++)fv[i]=fv[i]*(b.f[i]?b.f[i]:1)%mod;}
void operator /=(const poly& b)
{for(int i=0;i<m;i++)fz[i]-=(b.f[i]==0);for(int i=0;i<m;i++)fv[i]=fv[i]*iv[b.f[i]]%mod;}
};
inline void cpy(poly& a,const data& b){for(int i=0;i<m;i++)a[i]=b.fz[i]?0:b.fv[i];}
struct global_balanced_tree//全局平衡二叉树
{
int s[N][2];int fa[N];int lsiz[N];int st[N];int tp;mar w[N];mar mul[N];data lt[N];poly lh[N];
inline void udw(const int& p)
{cpy(w[p][0][0],lt[p]);w[p][1][1]=w[p][1][0]=w[p][0][1]=w[p][0][0];w[p][1][1]+=lh[p];}
inline void ud(const int& p){mul[p]=mul[s[p][0]]*w[p]*mul[s[p][1]];}
inline void ins(const int& u,const int& v){lt[u]*=(mul[v][1][0]+one);lh[u]+=mul[v][1][1];}
inline void del(const int& u,const int& v){lt[u]/=(mul[v][1][0]+one);lh[u]-=mul[v][1][1];}
inline bool isr(const int& u){return (s[fa[u]][0]!=u)&&(s[fa[u]][1]!=u);}
inline int subbuild(int l,int r)//一条重链递归建树
{
if(l==r){ud(st[l]);return st[l];}if(l>r){return 0;}
int tot=0;for(int i=l;i<=r;i++)tot+=lsiz[st[i]];
for(int i=l,nsiz=lsiz[st[i]];i<=r;i++,nsiz+=lsiz[st[i]])
if(2*nsiz>=tot)
{
s[st[i]][0]=subbuild(l,i-1);s[st[i]][1]=subbuild(i+1,r);
fa[s[st[i]][0]]=st[i];fa[s[st[i]][1]]=st[i];ud(st[i]);return st[i];
}
}
inline int build(int p)//链分治
{
for(int i=p;i;i=h[i])lsiz[i]=siz[i]-siz[h[i]];
for(int i=p;i;i=h[i])
for(int j=al[i];j;j=x[j])
if(lsiz[v[j]]==0){int ls=build(v[j]);fa[ls]=i;ins(i,ls);}
for(int i=p;i;i=h[i])udw(i);
tp=0;for(int i=p;i;i=h[i])st[++tp]=i;reverse(st+1,st+tp+1);
return subbuild(1,tp);
}
inline void modify(int u,const int& w1)//修改,每次跳一条轻边
{
lt[u]/=we[u];
for(int i=0;i<m;i++)we[u][i]=0;we[u][w1]=1;fwt(we[u].f,1);lt[u]*=we[u];udw(u);
for(int p=u;p;p=fa[p])
if(isr(p)&&fa[p]){del(fa[p],p);ud(p);ins(fa[p],p);udw(fa[p]);}else ud(p);
for(int i=0;i<m;i++)ans[i]=mul[rot][1][1][i];fwt(ans,5004);
}
inline void ih()//初始化
{
for(int i=1;i<=n;i++)for(int j=0;j<m;j++)lt[i].wt(j,we[i][j]);
w[0][0][0].setone();mul[0][0][0].setone();
}
}gbt;
int main()
{
iv[0]=1;iv[1]=1;for(int i=2;i<mod;i++)iv[i]=(mod-mod/i)*iv[mod%i]%mod;
scanf("%d%d",&n,&m);one.setone();
for(int i=1,w;i<=n;i++){scanf("%d",&w),we[i][w]=1,fwt(we[i].f,1);}
for(int i=1,u,v;i<n;i++)scanf("%d%d",&u,&v),add(u,v),add(v,u);
dfs(1);gbt.ih();rot=gbt.build(1);
for(int i=0;i<m;i++)ans[i]=gbt.mul[rot][1][1][i];fwt(ans,5004);scanf("%d",&q);
for(int i=1;i<=q;i++)
{
scanf("%s",opt);
if(opt[0]=='Q'){int t;scanf("%d",&t);printf("%d\n",ans[t]);}
else {int w;int u;scanf("%d%d",&u,&w);gbt.modify(u,w);}
}return 0;//拜拜程序~
}
```