题解:P14509 树上求值 tree
StarsIntoSea
·
·
题解
写题解,祝 NOIP rp++。
T2 赛时 3h 不会,T3 赛后 40min 会正解了?
首先最后这个答案形式肯定是要一个一个求 s_i 的。
我们考虑拆贡献,考虑 x 自己作为 LCA 和 x 的祖先作为 LCA 的两部分。
接下来考虑 $x$ 的祖先作为 LCA 的贡献。令 $fa_x(k)$ 表示 $x$ 的 $k$ 级祖先,$V_x$ 表示以 $x$ 为根的子树的点集,那么对 $x$ 的贡献就是 $u \in V_{fa_x(k)} - V_{fa_x(k-1)}$ 时 $\sum f(u+dep_{fa_x(k)} )$,这个贡献传递给子孙是不会有任何影响的。
所以我们仍然可以利用前面的 01-trie,每一次向上都会全局减一,我们记录每一个点 $x$ 的只考虑子树的贡献做一次全局减一记录。再利用记录的答案向下传递给各个儿子即可。
时间复杂度 $O(T n \log n)$。
```cpp
#include <bits/stdc++.h>
using namespace std;
const int N=2e5+5;
vector<int> e[N];
int T,n,root,A[2][23];
#define ll long long
ll mod;
int t[N*25][2],s[N*25],rt[N],tot=0;
int newnode(){
++tot;
s[tot]=t[tot][0]=t[tot][1]=0;
return tot;
}
int F(int x){
int ans=1;
for(int i=0;i<=20;++i) ans=1ll*ans*A[(x>>i)&1][i]%mod;
return ans;
}
void pushup(int u,int w){
s[u]=(1ll*s[t[u][0]]*A[0][w]+1ll*s[t[u][1]]*A[1][w])%mod;
}
void insert(int &u,int x,int w){
if(!u) u=newnode();
if(w>20){
++s[u]; return ;
}
insert(t[u][(x>>w)&1],x,w+1);
pushup(u,w);
}
void del(int u,int w){
swap(t[u][0],t[u][1]);
if(t[u][1]) del(t[u][1],w+1);
pushup(u,w);
}
int merge(int x,int y,int w){
if(!x||!y) return x|y;
if(w>20){
s[x]+=s[y]; return x;
}
t[x][0]=merge(t[x][0],t[y][0],w+1);
t[x][1]=merge(t[x][1],t[y][1],w+1);
pushup(x,w);
return x;
}
int tmp[N],ans[N],Ans[N];
int dep[N],fa[N],siz[N];
void dfs1(int u,int f){
fa[u]=f;
dep[u]=dep[f]+1;
siz[u]=1;
for(int v:e[u]) if(v^f){
dfs1(v,u);
siz[u]+=siz[v];
}
}
void dfs(int u){
for(int v:e[u]) if(v^fa[u]){
dfs(v);
del(rt[v],0);
tmp[v]=s[rt[v]];
rt[u]=merge(rt[u],rt[v],0);
}
insert(rt[u],u+dep[u],0);
ans[u]=(s[rt[u]])%mod;
}
void dfs_(int u){
for(int v:e[u]) if(v^fa[u]){
Ans[v]=((Ans[u]+ans[u]-tmp[v])%mod+mod)%mod;
dfs_(v);
}
}
int main(){
scanf("%d",&n);
for(int i=1;i<n;++i){
int u,v; scanf("%d%d",&u,&v);
e[u].push_back(v);
e[v].push_back(u);
}
scanf("%d",&T);
while(T--){
scanf("%d%lld",&root,&mod);
for(int i=0;i<=20;++i) scanf("%d",&A[0][i]);
for(int i=0;i<=20;++i) scanf("%d",&A[1][i]);
for(int i=1;i<=n;++i) rt[i]=0;
tot=0;
Ans[root]=0;
dfs1(root,0); dfs(root);
dfs_(root);
ll res=0;
for(int i=1;i<=n;++i) res^=(1ll*(ans[i]+Ans[i])%mod*i);
printf("%lld\n",res);
}
}
```