题解

· · 题解

验题人题解。

考虑对于一个 x,枚举它的祖先 y 并钦定 y=\operatorname{LCA*}(x,i),设 zy 的儿子且满足 xz 的子树内,那么 yx 的贡献就是 y 子树内的所有点的 f(u+d_y) 减去 z 子树内的所有点的 f(v+d_y),由于 d_z=d_y+1,可以改写成 f(v+d_z-1)

那我们相当于要对每个 y,求出它子树内的 f(u+d_y) 之和、f(u+d_y-1) 之和。这样再做一个根到结点的前缀和就能对于每个 x 做到 O(1) 求得 s_x 了。

容易发现计算 f(u+d_y)f(u+d_y-1) 的做法应当是本质相同的,因此我们下面只讨论计算 f(u+d_y) 之和的做法。

看到这个贡献形式我们首先想到丢到 01-Trie 上面计算贡献。考虑在 dfs 的同时对每个点维护一个 Trie 作为子树内的贡献信息,合并时考虑类似于线段树合并地合并两棵 Trie。

枚举 d_y 显然没什么前途。但是我们注意到 d 的性质和 dfs 结构的相关性,考虑在 u 处插入 u+d_u,而在 y 处需要的贡献是 u+d_y,因此在从 u 向上爬的时候每次将 Trie 的信息做一次全局减 1,就能在 y 处获得正确的贡献。注意为了支持全局减操作,我们使用低位到高位的 01-Trie。

计算 f(u+d_y-1) 的过程也是同理,只要在做全局减 1 的前后都统计一下答案就能分别获得结果。

这样就做完了,时间复杂度是 O(Tn\log n)

#include<bits/stdc++.h>
#define pii pair<int,int>
#define x first
#define y second
#define pb push_back
using namespace std;
const int N = 2e5 + 7;
int tr[N*30][2],sum[N*30];
int num;
int rt[N];
int n,q,mod,r;
vector<int>a[N];
int b[30][2];
int s[N],pl[N],pr[N];
int d[N];
void add(int &x,int y){x += y; if(x>=mod) x -= mod;}
int tmp[30];
void ins(int val,int id){
    int p = id;
    memset(tmp,0,sizeof tmp);
    for(int i=0;i<=20;++i){
        int x = (val >> i) & 1;
        if(!tr[p][x]) tr[p][x] = ++num;
        tmp[i] = b[i][x];
        p = tr[p][x]; 
    }
    for(int i=19;i>=0;--i) tmp[i] = 1ll * tmp[i+1] * tmp[i] % mod;
    p = id;
    for(int i=0;i<=20;++i){
        int x = (val >> i) & 1;
        add(sum[p],tmp[i]);
        p = tr[p][x]; 
    }
}
int merge(int x,int y){
    if(!x) return y;
    if(!y) return x;
    add(sum[x],sum[y]);
    tr[x][0] = merge(tr[x][0],tr[y][0]);
    tr[x][1] = merge(tr[x][1],tr[y][1]);
    return x;
}
void del(int x,int i){
    swap(tr[x][0],tr[x][1]);
    if(tr[x][1]) del(tr[x][1],i+1);
    sum[x] = 1ll * sum[tr[x][0]] * b[i][0] % mod;
    add(sum[x],1ll * sum[tr[x][1]] * b[i][1] % mod);
}
void dfs(int x,int fa){
    d[x] = d[fa] + 1;
    if(!rt[x]) rt[x] = ++num;
    for(auto c:a[x]){
        if(c==fa) continue;
        dfs(c,x);
        rt[x] = merge(rt[x],rt[c]);
    }
    ins(x+d[x],rt[x]);
    pl[x] = sum[rt[x]];
    del(rt[x],0);
    pr[x] = sum[rt[x]];
}
void calc(int x,int fa){
    s[x] = s[fa];
    add(s[x],pl[x]);
    if(x!=r) add(s[x],mod-pr[x]);
    for(auto c:a[x]){
        if(c==fa) continue;
        calc(c,x);
    }
}
void solve(){
    memset(s,0,sizeof s);
    memset(tr,0,sizeof tr);
    memset(rt,0,sizeof rt);
    memset(d,0,sizeof d);
    memset(sum,0,sizeof sum);
    num = 0;
    dfs(r,0);
    calc(r,0);
    long long ans = 0;
    for(int i=1;i<=n;++i) ans ^= 1ll * i * s[i];
    cout<<ans<<'\n';
}
signed main(){
    ios::sync_with_stdio(0);
    cin.tie(nullptr);
    cout.tie(nullptr);
    cin>>n;
    for(int i=1;i<n;++i){
        int u,v;
        cin>>u>>v;
        a[u].push_back(v);
        a[v].push_back(u);
    }
    cin>>q;
    while(q--){
        cin>>r>>mod;
        for(int i=0;i<=20;++i) cin>>b[i][0];
        for(int i=0;i<=20;++i) cin>>b[i][1];
        solve();
    }
    return 0;
}