P3781 [SDOI2017] 切树游戏

题解

做一次 FWT（异或变换）后，问题就变成长度为 m 的数组对位相乘，也就是说我们要维护 m 个数据结构，每个都支持单点修改、查询所有连通块点权乘积之和。上静态 Top Tree 即可：对每个簇，分别维护上、下界点是否被选时的答案。注意为了合并方便，这里的点权乘积不计入上界点的点权。

代码非常好写。

#include <bits/stdc++.h>
using namespace std;
using ll = long long;
constexpr int P = 10007;         // modulus for all arithmetic
constexpr int I2 = (P + 1) / 2;  // inverse of 2 modulo P
int n, m, q;                     // vertex count, transform length, operation count

// In-place XOR (Walsh-Hadamard) transform of a[0..m) modulo P.
// tp == 1 runs the forward transform; any other tp runs the inverse, obtained
// by halving every butterfly (multiplying by I2). Assumes m is a power of two.
void FWT(ll *a, int tp) {
    const ll scale = tp == 1 ? 1 : I2;
    for (int half = 1; half < m; half *= 2) {
        for (int base = 0; base < m; base += 2 * half) {
            ll *lo = a + base;
            ll *hi = lo + half;
            for (int j = 0; j < half; ++j) {
                const ll u = lo[j];
                const ll v = hi[j];
                lo[j] = (u + v) * scale % P;
                hi[j] = (u - v + P) * scale % P;
            }
        }
    }
}
ll fwt[135][135];          // fwt[x]: XOR transform of the unit vector e_x
ll a[30005];               // a[v]: current weight of vertex v (1-indexed)
int sn[30005];             // sn[v]: heavy child of v (0 when v has no child)
int siz[30005];            // siz[v]: number of vertices in v's subtree
std::vector<int> e[30005]; // adjacency lists of the tree
// A vector kept in the XOR-transform domain, where XOR convolution becomes
// coordinate-wise multiplication. The arithmetic operators only touch the
// first m entries.
struct Dat {
    ll w[135];
    Dat() {}
    // Transformed unit vector e_x: a single point whose weight is x
    // (x must index a valid row of fwt, presumably 0 <= x < m).
    Dat(ll x) {
        std::copy(fwt[x], fwt[x] + 135, w);
    }
    // Coordinate-wise sum modulo P.
    Dat operator+(const Dat &b) const {
        Dat out;
        for (int k = 0; k < m; ++k) {
            out.w[k] = (w[k] + b.w[k]) % P;
        }
        return out;
    }
    // Coordinate-wise product modulo P (= XOR convolution before transform).
    Dat operator*(const Dat &b) const {
        Dat out;
        for (int k = 0; k < m; ++k) {
            out.w[k] = w[k] * b.w[k] % P;
        }
        return out;
    }
} I, O;  // I is assigned Dat(0) in main; O stays all-zero (zero-initialised global)
// How a top-tree node was formed: leaf clusters keep NIL (zero-initialised),
// internal nodes record the merge operation used to build them.
enum Type { NIL = 0, COMPRESS = 1, RAKE = 2 };
// A top-tree cluster: a connected piece of the tree with upper boundary
// point x and lower boundary point y.
//   w[i][j] = sum, over connected vertex sets S inside the cluster, of the
//   product of the weights in S with x itself left out of the product;
//   i = [x in S], j = [y in S]. All values are kept in the FWT domain.
struct Cluster {
    int x;        // upper boundary point
    int y;        // lower boundary point
    Type tp;      // merge used to build this node (NIL for leaf clusters)
    Dat w[2][2];
};
Cluster f[60005];  // ids 1..n: leaf cluster of edge parent(v) -> v; larger ids: merged nodes
int ls[60005];     // left child of a merged node
int rs[60005];     // right child of a merged node
int prt[60005];    // parent in the top tree (0 for the root)
void Pushup(int p) {
    int ls = ::ls[p], rs = ::rs[p];
    for (int i = 0; i < 2; i++) {
        for (int j = 0; j < 2; j++) {
            f[p].w[i][j] = O;
        }
    }
    if (f[p].tp == COMPRESS) {
        f[p].x = f[ls].x, f[p].y = f[rs].y;
        for (int i = 0; i < 2; i++) {
            f[p].w[i][0] = f[p].w[i][0] + f[ls].w[i][0];
            f[p].w[0][i] = f[p].w[0][i] + f[rs].w[0][i];
        }
        for (int i = 0; i < 2; i++) {
            for (int j = 0; j < 2; j++) {
                f[p].w[i][j] = f[p].w[i][j] + f[ls].w[i][1] * f[rs].w[1][j];
            }
        }
    }
    else {
        f[p].x = f[ls].x, f[p].y = f[ls].y;
        for (int i = 0; i < 2; i++) {
            f[p].w[0][i] = f[p].w[0][i] + f[ls].w[0][i];
            f[p].w[0][0] = f[p].w[0][0] + f[rs].w[0][i];
        }
        for (int j = 0; j < 2; j++) {
            f[p].w[1][j] = f[ls].w[1][j] * (f[rs].w[1][0] + f[rs].w[1][1]);
        }
    }
}
// First pass over the subtree rooted at u, whose parent is fa.
// For every vertex v it:
//   - sets up the leaf cluster f[v] for the edge parent(v) -> v
//     (x = parent, y = v; w[0][1] = w[1][1] = transform of a[v], w[1][0] = I);
//   - computes siz[v] and the heavy child sn[v] (first child in e[v] order
//     with strictly the largest subtree).
//
// Done iteratively: the tree can be a path of tens of thousands of vertices,
// and a recursive version that also holds a Dat temporary (~1 KB) in every
// frame may need tens of MB of stack.
void DFS1(int u, int fa) {
    std::vector<int> order;       // preorder: every parent precedes its children
    std::vector<int> pending(1, u);
    f[u].x = fa;                  // f[v].x doubles as the parent pointer
    while (!pending.empty()) {
        const int v = pending.back();
        pending.pop_back();
        order.push_back(v);
        for (int c : e[v]) {
            if (c == f[v].x) continue;
            f[c].x = v;
            pending.push_back(c);
        }
    }
    // Reverse preorder handles all children before their parent.
    for (auto it = order.rbegin(); it != order.rend(); ++it) {
        const int v = *it;
        f[v].y = v;
        f[v].w[0][1] = f[v].w[1][1] = Dat(a[v]);
        f[v].w[1][0] = I;         // the set {parent} alone: empty product
        siz[v] = 1;
        for (int c : e[v]) {
            if (c == f[v].x) continue;
            siz[v] += siz[c];
            if (siz[c] > siz[sn[v]]) sn[v] = c;
        }
    }
}
typedef vector<pair<int, int>> V;
int cnt;
int Div(V::iterator L, V::iterator R, Type tp) {
    if (L + 1 == R) return L->second;
    auto M = lower_bound(L, R, make_pair((L->first + prev(R)->first + 1) / 2, 0));
    if (M == L) M++;
    int x = Div(L, M, tp), y = Div(M, R, tp), p = ++cnt;
    f[p].tp = tp;
    ls[p] = x, rs[p] = y;
    prt[x] = prt[y] = p;
    Pushup(p);
    return p;
}
int rt, las;
int DFS2(int u, int fa) {
    V li;
    li.push_back({ 1, u });
    for (int v = u, w = fa; sn[v]; w = v, v = sn[v]) {
        V t; t.push_back({ 1, sn[v] });
        for (int x : e[v]) {
            if (x == w || x == sn[v]) continue;
            t.push_back({ t.back().first + siz[x], DFS2(x, v) });
        }
        li.push_back({ li.back().first + siz[v] - siz[sn[v]], Div(t.begin(), t.end(), RAKE) });
    }
    return Div(li.begin(), li.end(), COMPRESS);
}
int main() {
    scanf("%d%d", &n, &m), cnt = n;
    for (int i = 0; i < m; i++) {
        fwt[i][i] = 1;
        FWT(fwt[i], 1);
    }
    I = Dat(0);
    for (int i = 1; i <= n; i++) scanf("%lld", a + i);
    for (int i = 1, u, v; i < n; i++) {
        scanf("%d%d", &u, &v);
        e[u].push_back(v), e[v].push_back(u);
    }
    DFS1(1, 0);
    rt = DFS2(1, 0);
    scanf("%d", &q);
    while (q--) {
        char s[10]; scanf("%s", s);
        if (!strcmp(s, "Query")) {
            int k; scanf("%d", &k);
            Dat res = O;
            for (int j = 0; j < 2; j++) res = res + f[rt].w[0][j];
            FWT(res.w, -1);
            printf("%lld\n", res.w[k]);
        }
        else {
            int x; ll y; scanf("%d%lld", &x, &y);
            f[x].w[0][1] = f[x].w[1][1] = Dat(a[x] = y);
            for (int u = prt[x]; u; u = prt[u]) Pushup(u);
        }
    }
    return 0;
}