P3781 [SDOI2017] 切树游戏
做一下 FWT 就变成长度为 $m$ 的向量的逐位加法与逐位乘法，用 Top Tree 维护即可。
代码非常好写。
#include <bits/stdc++.h>
using namespace std;
using ll = long long;

// All arithmetic is done modulo P; I2 is the inverse of 2 modulo P
// (2 * I2 == P + 1).
const int P = 10007, I2 = (P + 1) / 2;

// n, m, q are filled in by main from the input; m is also the transform
// length used by FWT below.
int n, m, q;

// In-place butterfly transform over a[0..m) modulo P.
//
// Each stage replaces a pair (x, y) by ((x + y) * s, (x - y) * s) mod P,
// where s is 1 when tp == 1 and I2 for any other tp, so calling with
// tp == 1 and then with tp == -1 restores the original residues.
//
// The stage width doubles until it reaches m, so m is assumed to be a
// power of two -- TODO confirm against the problem's input constraints.
// Entries are expected to be residues in [0, P) so that (x - y + P) is
// non-negative.
void FWT(ll *a, int tp) {
    const ll scale = (tp == 1) ? 1 : I2;
    for (int half = 1; half < m; half *= 2) {
        const int stride = half * 2;
        for (int base = 0; base < m; base += stride) {
            ll *lo = a + base;
            ll *hi = lo + half;
            for (int k = 0; k < half; ++k) {
                const ll sum = lo[k] + hi[k];
                const ll diff = lo[k] - hi[k] + P;
                lo[k] = sum * scale % P;
                hi[k] = diff * scale % P;
            }
        }
    }
}
// fwt[v] holds the forward transform of the unit vector e_v; main fills the
// first m rows and Dat(ll) copies a row out of this table.
ll fwt[135][135];
// a[u] is the current value attached to vertex u (main reads a[1..n]).
ll a[30005];
// sn[u] is the child of u with the largest subtree (0 when u has no child).
int sn[30005];
// siz[u] is the number of vertices in the subtree of u, computed by DFS1.
int siz[30005];
// Adjacency lists of the tree; every edge is stored in both directions.
vector<int> e[30005];
// A vector of residues modulo P on which + and * act coordinate by
// coordinate. Only the first m coordinates are read or written by the
// operators.
struct Dat {
    ll w[135];

    // Deliberately leaves w untouched; objects with static storage (such as
    // I and O below) therefore start out all-zero.
    Dat() {}

    // Copies the precomputed table row fwt[x] into w.
    Dat(ll x) {
        memcpy(w, fwt[x], sizeof(w));
    }

    // Coordinate-wise sum modulo P.
    Dat operator+(const Dat &b) const {
        Dat sum;
        for (int k = 0; k < m; ++k) {
            sum.w[k] = (w[k] + b.w[k]) % P;
        }
        return sum;
    }

    // Coordinate-wise product modulo P.
    Dat operator*(const Dat &b) const {
        Dat prod;
        for (int k = 0; k < m; ++k) {
            prod.w[k] = w[k] * b.w[k] % P;
        }
        return prod;
    }
} I, O;  // O is never assigned and stays all-zero; main sets I = Dat(0).
// Kind of a cluster-tree node. NIL is the value a zero-initialised node
// has; COMPRESS and RAKE are assigned by Div to the nodes it creates.
enum Type {
    NIL = 0,
    COMPRESS = 1,
    RAKE = 2
};
// One node of the cluster tree.
struct Cluster {
    // The two boundary vertices of the cluster; a leaf created by DFS1 for
    // vertex u has x = parent of u and y = u.
    int x;
    int y;
    // How Pushup combines the two children of this node.
    Type tp;
    // Four aggregated vectors.
    // NOTE(review): w[i][j] appears to be indexed by whether boundary x (i)
    // and boundary y (j) belong to the counted component -- confirm.
    Dat w[2][2];
};

// Node storage: DFS1 fills one leaf per vertex, Div numbers the inner nodes
// with ++cnt.
Cluster f[60005];

// Left child, right child and parent of every inner node (0 = none).
int ls[60005];
int rs[60005];
int prt[60005];
void Pushup(int p) {
int ls = ::ls[p], rs = ::rs[p];
for (int i = 0; i < 2; i++) {
for (int j = 0; j < 2; j++) {
f[p].w[i][j] = O;
}
}
if (f[p].tp == COMPRESS) {
f[p].x = f[ls].x, f[p].y = f[rs].y;
for (int i = 0; i < 2; i++) {
f[p].w[i][0] = f[p].w[i][0] + f[ls].w[i][0];
f[p].w[0][i] = f[p].w[0][i] + f[rs].w[0][i];
}
for (int i = 0; i < 2; i++) {
for (int j = 0; j < 2; j++) {
f[p].w[i][j] = f[p].w[i][j] + f[ls].w[i][1] * f[rs].w[1][j];
}
}
}
else {
f[p].x = f[ls].x, f[p].y = f[ls].y;
for (int i = 0; i < 2; i++) {
f[p].w[0][i] = f[p].w[0][i] + f[ls].w[0][i];
f[p].w[0][0] = f[p].w[0][0] + f[rs].w[0][i];
}
for (int j = 0; j < 2; j++) {
f[p].w[1][j] = f[ls].w[1][j] * (f[rs].w[1][0] + f[rs].w[1][1]);
}
}
}
void DFS1(int u, int fa) {
f[u].x = fa, f[u].y = u;
f[u].w[0][1] = f[u].w[1][1] = Dat(a[u]);
f[u].w[1][0] = I;
siz[u] = 1;
for (int v : e[u]) {
if (v == fa) continue;
DFS1(v, u);
siz[u] += siz[v];
if (siz[v] > siz[sn[u]]) sn[u] = v;
}
}
// Sequence of (prefix weight, node id) pairs handed to Div.
using V = vector<pair<int, int>>;

// Highest node id in use; main sets it to n, Div pre-increments it.
int cnt;

// Builds a binary tree of type tp over the nodes in [L, R) and returns the
// id of its root. The range is cut at the first entry whose prefix weight
// reaches the rounded-up midpoint of the first and last weights, but never
// before the second entry, so both halves are non-empty. A single-entry
// range returns that entry's node unchanged.
int Div(V::iterator L, V::iterator R, Type tp) {
    if (R - L == 1) return L->second;

    const int midWeight = (L->first + prev(R)->first + 1) / 2;
    V::iterator cut = lower_bound(L, R, make_pair(midWeight, 0));
    if (cut == L) ++cut;

    // Children are built left to right before the parent id is allocated.
    const int left = Div(L, cut, tp);
    const int right = Div(cut, R, tp);
    const int p = ++cnt;

    f[p].tp = tp;
    ls[p] = left;
    rs[p] = right;
    prt[left] = p;
    prt[right] = p;
    Pushup(p);
    return p;
}
int rt, las;
int DFS2(int u, int fa) {
V li;
li.push_back({ 1, u });
for (int v = u, w = fa; sn[v]; w = v, v = sn[v]) {
V t; t.push_back({ 1, sn[v] });
for (int x : e[v]) {
if (x == w || x == sn[v]) continue;
t.push_back({ t.back().first + siz[x], DFS2(x, v) });
}
li.push_back({ li.back().first + siz[v] - siz[sn[v]], Div(t.begin(), t.end(), RAKE) });
}
return Div(li.begin(), li.end(), COMPRESS);
}
// Reads the tree and the commands from stdin and answers every "Query".
//
// Input, in the order consumed below: n and m; n vertex values; n - 1 edges;
// q; then q commands. "Query k" prints coordinate k of the inverse transform
// of the root's w[0][0] + w[0][1]. Any other command word is followed by
// "x y" and sets the value of vertex x to y.
//
// Every read is checked: truncated or malformed input ends the program (or
// the command loop) instead of continuing with uninitialised variables.
// Value ranges are not validated -- vertex values and k are assumed to lie
// in [0, m) and x in [1, n], as the array indexing below requires.
int main() {
    if (scanf("%d%d", &n, &m) != 2) return 0;
    cnt = n;  // ids 1..n are the per-vertex leaves; Div allocates from n + 1

    // Row i of the table is the transform of the unit vector e_i.
    for (int i = 0; i < m; i++) {
        fwt[i][i] = 1;
        FWT(fwt[i], 1);
    }
    I = Dat(0);

    for (int i = 1; i <= n; i++) {
        if (scanf("%lld", a + i) != 1) return 0;
    }
    for (int i = 1, u, v; i < n; i++) {
        if (scanf("%d%d", &u, &v) != 2) return 0;
        e[u].push_back(v), e[v].push_back(u);
    }

    DFS1(1, 0);
    rt = DFS2(1, 0);

    if (scanf("%d", &q) != 1) return 0;
    while (q--) {
        char s[10];
        // The width keeps an over-long token from overrunning the buffer.
        if (scanf("%9s", s) != 1) break;
        if (!strcmp(s, "Query")) {
            int k;
            if (scanf("%d", &k) != 1) break;
            Dat res = O;
            for (int j = 0; j < 2; j++) res = res + f[rt].w[0][j];
            FWT(res.w, -1);
            printf("%lld\n", res.w[k]);
        }
        else {
            int x;
            ll y;
            if (scanf("%d%lld", &x, &y) != 2) break;
            f[x].w[0][1] = f[x].w[1][1] = Dat(a[x] = y);
            // Refresh every ancestor of the changed leaf, bottom-up.
            for (int u = prt[x]; u; u = prt[u]) Pushup(u);
        }
    }
    return 0;
}