// Forum reply by @LJB00131 (2020-04-04 15:01): "only 5 points /kel".
//
// Link-cut tree over nodes 1..n (every node starts with value 1); values are kept modulo 51061.
// Each query line starts with an operator character:
//   + u v c           add c to every node value on the path u..v
//   * u v c           multiply every node value on the path u..v by c
//   - u1 v1 u2 v2     cut the edge (u1, v1), then link u2 and v2
//   (any other) u v   print the sum of node values on the path u..v
//
// Fixes relative to the original post:
//   * pusha(): c * sz[x] can exceed both the modulus and 2^32, so the subtree sum is now
//     computed in 64 bits and reduced with %, instead of the single-subtraction add().
//   * cut(): make_root(x) alone does not isolate the edge (x, y) (and zeroing fa[y] could
//     corrupt a splay tree); it now uses split(x, y) and detaches x from y.
//   * The operator is read with " %c", which skips any whitespace (including '\r').
//   * The global `size` array is renamed `sz`: together with `using namespace std;` it is
//     ambiguous with std::size in C++17.
//   * scanf/printf format specifiers now match the argument types; input constants are
//     reduced into [0, mod) before use.
#include <cstdio>
#include <utility>

const int N = 100005;
const unsigned mod = 51061;

// Modular addition; both operands must already lie in [0, mod).
inline unsigned add(unsigned a, unsigned b) { return a + b >= mod ? a + b - mod : a + b; }
// Modular multiplication, widened to 64 bits so the product cannot wrap.
inline unsigned mulmod(unsigned a, unsigned b) { return static_cast<unsigned>(1ULL * a * b % mod); }

// Per-node state. Index 0 is the null node: s[0] and sz[0] stay 0 so update() can read it.
int ch[N][2];        // splay-tree children
int fa[N];           // splay parent, or path-parent pointer when the node is a splay root
int sz[N];           // number of nodes in the splay subtree
unsigned v[N];       // node value (mod `mod`)
unsigned s[N];       // sum of values in the splay subtree (mod `mod`)
unsigned af[N];      // pending add tag for the children (already applied to the node itself)
unsigned mf[N];      // pending multiply tag for the children (1 = none)
bool r[N];           // pending "swap this node's children" tag
int st[N], top = 0;  // stack used by splay() to push tags down from the splay root

// Reduce an input constant into [0, mod).
inline unsigned normalize(long long c) {
    const long long m = mod;
    c %= m;
    return static_cast<unsigned>(c < 0 ? c + m : c);
}

// True iff x has a real splay parent (fa[x] is not merely a path-parent pointer).
bool notroot(int x) { return ch[fa[x]][0] == x || ch[fa[x]][1] == x; }

// Recompute x's subtree sum and size from its children.
inline void update(int x) {
    s[x] = add(add(s[ch[x][0]], s[ch[x][1]]), v[x]);
    sz[x] = sz[ch[x][0]] + sz[ch[x][1]] + 1;
}

// Multiply every value in x's splay subtree by c (c in [0, mod)).
inline void pushm(int x, unsigned c) {
    s[x] = mulmod(s[x], c);
    v[x] = mulmod(v[x], c);
    mf[x] = mulmod(mf[x], c);
    af[x] = mulmod(af[x], c);  // an add pending below x is scaled too
}

// Add c to every value in x's splay subtree (c in [0, mod)).
inline void pusha(int x, unsigned c) {
    // c * sz[x] may exceed both `mod` and 2^32 (sz up to ~1e5), so add() is not enough here.
    s[x] = static_cast<unsigned>((s[x] + 1ULL * c * sz[x]) % mod);
    v[x] = add(v[x], c);
    af[x] = add(af[x], c);
}

// Push x's pending tags to its children: multiply first, then add, then reversal.
void pushdown(int x) {
    if (mf[x] != 1) {
        for (int child : ch[x])
            if (child) pushm(child, mf[x]);
        mf[x] = 1;
    }
    if (af[x]) {
        for (int child : ch[x])
            if (child) pusha(child, af[x]);
        af[x] = 0;
    }
    if (r[x]) {
        std::swap(ch[x][0], ch[x][1]);
        for (int child : ch[x])
            if (child) r[child] = !r[child];
        r[x] = false;
    }
}

// 1 if x is its splay parent's right child, else 0.
bool which(int x) { return ch[fa[x]][1] == x; }

// Rotate x above its splay parent.
void rotate(int x) {
    int y = fa[x], z = fa[y];
    int k = which(x), w = ch[x][k ^ 1];
    if (notroot(y)) ch[z][which(y)] = x;  // must be checked before fa[y] changes
    ch[x][k ^ 1] = y;
    ch[y][k] = w;
    if (w) fa[w] = y;
    fa[x] = z;
    fa[y] = x;
    update(y);
    update(x);
}

// Splay x to the root of its splay tree, pushing pending tags down from the top first.
void splay(int x) {
    top = 0;
    int y = x;
    st[++top] = y;
    while (notroot(y)) st[++top] = y = fa[y];
    while (top) pushdown(st[top--]);
    while (notroot(x)) {
        int p = fa[x];
        if (notroot(p)) rotate(which(x) == which(p) ? p : x);
        rotate(x);
    }
    update(x);
}

// Make the path from the represented tree's root to x one preferred path, x deepest.
void access(int x) {
    for (int y = 0; x; y = x, x = fa[x]) {
        splay(x);
        ch[x][1] = y;
        update(x);
    }
}

// Re-root x's represented tree at x.
void make_root(int x) {
    access(x);
    splay(x);
    r[x] = !r[x];
}

// Expose the path x..y so that y is its splay root and s[y] aggregates the whole path.
void split(int x, int y) {
    make_root(x);
    access(y);
    splay(y);
}

// Add the edge (x, y); x and y are assumed to be in different trees.
void link(int x, int y) {
    make_root(x);
    fa[x] = y;
}

// Remove the edge (x, y); the edge is assumed to exist.
void cut(int x, int y) {
    // After split(x, y) the exposed path is exactly {x, y}: y is the splay root (its tags were
    // pushed by splay) and x is its only, childless left child.
    split(x, y);
    ch[y][0] = fa[x] = 0;
    update(y);
}

int main() {
    int n, q;
    if (std::scanf("%d%d", &n, &q) != 2) return 0;
    for (int i = 1; i <= n; i++) {
        v[i] = s[i] = mf[i] = 1;
        sz[i] = 1;
    }
    for (int i = 1; i < n; i++) {
        int a, b;
        if (std::scanf("%d%d", &a, &b) != 2) return 0;
        link(a, b);
    }
    while (q--) {
        char op;
        if (std::scanf(" %c", &op) != 1) break;  // leading space skips '\n', '\r', blanks
        if (op == '-') {
            int u1, v1, u2, v2;
            if (std::scanf("%d%d%d%d", &u1, &v1, &u2, &v2) != 4) break;
            cut(u1, v1);
            link(u2, v2);
        } else if (op == '+' || op == '*') {
            int a, b;
            long long c;
            if (std::scanf("%d%d%lld", &a, &b, &c) != 3) break;
            split(a, b);
            if (op == '+')
                pusha(b, normalize(c));
            else
                pushm(b, normalize(c));
        } else {
            int a, b;
            if (std::scanf("%d%d", &a, &b) != 2) break;
            split(a, b);
            std::printf("%u\n", s[b]);
        }
    }
    return 0;
}
只有5分/kel