题解:P14509 树上求值 tree

· · 题解

写题解,祝 NOIP rp++。

T2 赛时 3h 不会,T3 赛后 40min 会正解了?

首先最后这个答案形式肯定是要一个一个求 s_i 的。

我们考虑拆贡献,考虑 x 自己作为 LCA 和 x 的祖先作为 LCA 的两部分。

接下来考虑 $x$ 的祖先作为 LCA 的贡献。令 $fa_x(k)$ 表示 $x$ 的 $k$ 级祖先,$V_x$ 表示以 $x$ 为根的子树的点集,那么对 $x$ 的贡献就是 $u \in V_{fa_x(k)} - V_{fa_x(k-1)}$ 时 $\sum f(u+dep_{fa_x(k)} )$,这个贡献传递给子孙是不会有任何影响的。 所以我们仍然可以利用前面的 01-trie,每一次向上都会全局减一,我们记录每一个点 $x$ 的只考虑子树的贡献做一次全局减一记录。再利用记录的答案向下传递给各个儿子即可。 时间复杂度 $O(T n \log n)$。 ```cpp #include <bits/stdc++.h> using namespace std; const int N=2e5+5; vector<int> e[N]; int T,n,root,A[2][23]; #define ll long long ll mod; int t[N*25][2],s[N*25],rt[N],tot=0; int newnode(){ ++tot; s[tot]=t[tot][0]=t[tot][1]=0; return tot; } int F(int x){ int ans=1; for(int i=0;i<=20;++i) ans=1ll*ans*A[(x>>i)&1][i]%mod; return ans; } void pushup(int u,int w){ s[u]=(1ll*s[t[u][0]]*A[0][w]+1ll*s[t[u][1]]*A[1][w])%mod; } void insert(int &u,int x,int w){ if(!u) u=newnode(); if(w>20){ ++s[u]; return ; } insert(t[u][(x>>w)&1],x,w+1); pushup(u,w); } void del(int u,int w){ swap(t[u][0],t[u][1]); if(t[u][1]) del(t[u][1],w+1); pushup(u,w); } int merge(int x,int y,int w){ if(!x||!y) return x|y; if(w>20){ s[x]+=s[y]; return x; } t[x][0]=merge(t[x][0],t[y][0],w+1); t[x][1]=merge(t[x][1],t[y][1],w+1); pushup(x,w); return x; } int tmp[N],ans[N],Ans[N]; int dep[N],fa[N],siz[N]; void dfs1(int u,int f){ fa[u]=f; dep[u]=dep[f]+1; siz[u]=1; for(int v:e[u]) if(v^f){ dfs1(v,u); siz[u]+=siz[v]; } } void dfs(int u){ for(int v:e[u]) if(v^fa[u]){ dfs(v); del(rt[v],0); tmp[v]=s[rt[v]]; rt[u]=merge(rt[u],rt[v],0); } insert(rt[u],u+dep[u],0); ans[u]=(s[rt[u]])%mod; } void dfs_(int u){ for(int v:e[u]) if(v^fa[u]){ Ans[v]=((Ans[u]+ans[u]-tmp[v])%mod+mod)%mod; dfs_(v); } } int main(){ scanf("%d",&n); for(int i=1;i<n;++i){ int u,v; scanf("%d%d",&u,&v); e[u].push_back(v); e[v].push_back(u); } scanf("%d",&T); while(T--){ scanf("%d%lld",&root,&mod); for(int i=0;i<=20;++i) scanf("%d",&A[0][i]); for(int i=0;i<=20;++i) scanf("%d",&A[1][i]); for(int i=1;i<=n;++i) rt[i]=0; tot=0; Ans[root]=0; dfs1(root,0); dfs(root); dfs_(root); ll res=0; for(int i=1;i<=n;++i) res^=(1ll*(ans[i]+Ans[i])%mod*i); printf("%lld\n",res); } } ```