# 求助

@LJB00131  2020-04-04 15:01 回复

#include<bits/stdc++.h>

using namespace std;

// Capacity of every per-node array.
#define N 100005
// All sums and lazy tags are kept modulo this value.
const int mod = 51061;
// Modular helpers. Every argument is parenthesised so that compound
// expressions such as `mul(a, b + c)` expand with the intended precedence.
// add/dec assume both operands are already reduced to [0, mod) and correct
// the result by at most one multiple of mod; arguments may be evaluated
// more than once, so they must be free of side effects.
#define add(x, y) ((x) + (y) >= mod ? (x) + (y) - mod : (x) + (y))
#define dec(x, y) ((x) < (y) ? (x) - (y) + mod : (x) - (y))
// Multiplies x in place by y modulo mod.
#define mul(x, y) ((x) = (x) * (y) % mod)
// Left / right child of the node called `x` in the scope where the macro is used.
#define ls ch[x][0]
#define rs ch[x][1]

// Global state of the link-cut tree, indexed by node id (index 0 acts as "no node"):
//   ch[x][0/1] - left / right child of x in its auxiliary splay tree
//   fa[x]      - parent pointer of x
//   s[x]       - sum of v over x's splay subtree, as maintained by update()
//   af[x]      - pending additive lazy tag
//   mf[x]      - pending multiplicative lazy tag (1 means nothing pending)
//   v[x]       - value stored at node x
//   size[x]    - number of nodes in x's splay subtree
//   r[x]       - pending reversal flag
// st / top form the scratch stack splay() uses to push tags down from the
// splay root; n and q are read in main().
// NOTE(review): with `using namespace std;` in effect, an unqualified use of
// `size` may be ambiguous with std::size when built as C++17 or newer.
unsigned int n, q, ch[N][2], fa[N], s[N], af[N], mf[N], v[N], st[N], size[N], top = 0, r[N];

// Returns true when fa[x] lists x as one of its two splay children, i.e. x
// is not the root of its auxiliary splay tree.
bool notroot(int x)
{
    const unsigned int parent = fa[x];
    const unsigned int self = x;
    return ch[parent][0] == self || ch[parent][1] == self;
}

inline void update(int x) {s[x] = add(s[ls], add(s[rs], v[x])); size[x] = size[ls] + size[rs] + 1;}

inline void pushm(int x, int c) {mul(s[x], c), mul(v[x], c), mul(mf[x], c), mul(af[x], c);}

inline void pusha(int x, int c) {s[x] = add(s[x], c * size[x]), v[x] = add(v[x], c), af[x] = add(af[x], c);}

void pushdown(int x)
{
if(mf[x] != 1) pushm(ls, mf[x]), pushm(rs, mf[x]), mf[x] = 1;
if(af[x]) pusha(ls, af[x]), pusha(rs, af[x]), af[x] = 0;
if(r[x]) swap(ls, rs), r[ls] ^= 1, r[rs] ^= 1, r[x] = 0;
}

// Returns true when x is the right child (slot 1) of fa[x], false otherwise.
bool which(int x)
{
    const unsigned int parent = fa[x];
    const unsigned int self = x;
    return self == ch[parent][1];
}

// Rotates x one level up over its splay parent, then refreshes the
// aggregates of the old parent and of x (in that order, bottom-up).
void rotate(int x)
{
    const int parent = fa[x];
    const int grand = fa[parent];
    const int side = which(x);
    const int inner = ch[x][side ^ 1];

    // Re-hang x under the grandparent only if parent is a real splay child;
    // this must happen before parent's own links are rewritten.
    if (notroot(parent))
        ch[grand][which(parent)] = x;
    ch[x][side ^ 1] = parent;
    ch[parent][side] = inner;

    if (inner)
        fa[inner] = parent;
    fa[x] = grand;
    fa[parent] = x;

    update(parent);
    update(x);
}

// Splays x to the root of its auxiliary splay tree.
// All ancestors up to the splay root are first collected on the global
// stack st/top and their lazy tags pushed down from the top, so that the
// rotations operate on up-to-date children.
void splay(int x)
{
    top = 0;
    st[++top] = x;
    for (int node = x; notroot(node); )
    {
        node = fa[node];
        st[++top] = node;
    }
    for (; top; --top)
        pushdown(st[top]);

    while (notroot(x))
    {
        const int parent = fa[x];
        if (notroot(parent))
        {
            // Same side: rotate the parent first; otherwise rotate x twice.
            rotate(which(x) == which(parent) ? parent : x);
        }
        rotate(x);
    }
    update(x);
}

// Walks from x upward through the fa[] pointers; at every step the current
// node is splayed and its right child is replaced by the splay root handled
// in the previous step (0 on the first step).
void access(int x)
{
    int below = 0;
    while (x)
    {
        splay(x);
        ch[x][1] = below;
        update(x);
        below = x;
        x = fa[x];
    }
}

// Makes x the root of its tree: access and splay x, then toggle its
// reversal flag (the actual child swap happens lazily in pushdown).
void make_root(int x)
{
    access(x);
    splay(x);
    r[x] ^= 1;
}

// Re-roots the tree at x, then accesses and splays y, leaving y at the root
// of the splay that callers subsequently query or tag.
void split(int x, int y)
{
    make_root(x);
    access(y);
    splay(y);
}

// Attaches x's tree below y: x becomes the root of its tree and then gets y
// as its parent pointer. No connectivity check is performed.
void link(int x, int y)
{
    make_root(x);
    fa[x] = y;
}

void cut(int x, int y)
{
make_root(x);
fa[y] = ch[x][1] = 0;
update(x);
}

int main()
{
scanf("%d%d", &n, &q);
for(int i = 1; i <= n; i++) v[i] = size[i] = mf[i] = 1;
for(int i = 1; i <= n - 1; i++)
{
int u, v;
scanf("%d%d", &u, &v);
}
while(q--)
{
char opt = getchar(); scanf("%c", &opt);
int u, v, c, u1, u2, v1, v2;
if(opt == '+') {scanf("%d%d%d", &u, &v, &c); split(u, v); pusha(v, c);}
else if(opt == '-') {scanf("%d%d%d%d", &u1, &v1, &u2, &v2); cut(u1, v1); link(u2, v2);}
else if(opt == '*') {scanf("%d%d%d", &u, &v, &c); split(u, v); pushm(v, c);}
else {scanf("%d%d", &u, &v); split(u, v); printf("%d\n", s[v]);}
}
return 0;
}