求助

回复帖子

@LJB00131  2020-04-04 15:01 回复

提交后只有 5 分 /kel,求助哪里写错了。

#include<bits/stdc++.h>

using namespace std;

// Array capacity (nodes are indexed 1..n) and the modulus all path values use.
constexpr int N = 100005;
const int mod = 51061;

// Modular helpers. They reduce fully and compute intermediates in 64 bits, so they
// stay correct even when an operand is not already in [0, mod) or when a product
// exceeds 32 bits (the old one-subtraction / 32-bit macros did not).
inline unsigned add(unsigned long long x, unsigned long long y) { return static_cast<unsigned>((x + y) % mod); }
inline unsigned dec(unsigned long long x, unsigned long long y) { return static_cast<unsigned>((x % mod + mod - y % mod) % mod); }
// Multiplies x by y in place (modulo `mod`) and returns the new value of x.
inline unsigned mul(unsigned &x, unsigned long long y) { return x = static_cast<unsigned>(x * y % mod); }

// Shorthand for the left / right splay child of the variable named `x`
// in whatever function uses the macro.
#define ls ch[x][0]
#define rs ch[x][1]

// Input sizes: number of nodes and number of operations.
unsigned int n, q;
// Splay links: ch[x][0/1] are x's left/right splay children; fa[x] is x's splay
// parent, or its path-parent pointer when x is the root of its splay tree.
unsigned int ch[N][2], fa[N];
// Per-node data: v = the node's own value, s = sum of v over x's splay subtree
// (mod `mod`), size = number of nodes in that subtree.
unsigned int v[N], s[N], size[N];
// NOTE: under C++17 with `using namespace std;`, a bare `size` is ambiguous with
// std::size; refer to this array as `::size` to stay portable.
// Lazy tags: af = pending add, mf = pending multiply, r = pending reverse.
unsigned int af[N], mf[N], r[N];
// Scratch stack used by splay() to push tags from the splay root downwards.
unsigned int st[N], top = 0;

// True when x is NOT the root of its splay tree: its parent lists x as a real
// child. When x is a splay root, fa[x] is only a path-parent pointer and the
// parent's children do not include x.
bool notroot(int x)
{
    const unsigned p = fa[x];
    return ch[p][1] == x || ch[p][0] == x;
}

inline void update(int x) {s[x] = add(s[ls], add(s[rs], v[x])); size[x] = size[ls] + size[rs] + 1;}

// Apply "multiply every node by c" to the whole splay subtree rooted at x:
// scale the subtree sum and x's own value, and compose c into both pending
// tags (the add tag is scaled too, so pending work stays value * mul + add).
inline void pushm(int x, int c)
{
    mul(s[x], c);
    mul(v[x], c);
    mul(mf[x], c);
    mul(af[x], c);
}

inline void pusha(int x, int c) {s[x] = add(s[x], c * size[x]), v[x] = add(v[x], c), af[x] = add(af[x], c);}

void pushdown(int x)
{
    if(mf[x] != 1) pushm(ls, mf[x]), pushm(rs, mf[x]), mf[x] = 1;
    if(af[x]) pusha(ls, af[x]), pusha(rs, af[x]), af[x] = 0;
    if(r[x]) swap(ls, rs), r[ls] ^= 1, r[rs] ^= 1, r[x] = 0;
}

// 1 if x is the right child of its splay parent, 0 otherwise.
bool which(int x)
{
    const unsigned p = fa[x];
    return ch[p][1] == x;
}

// Rotate x above its splay parent y, preserving in-order (path) order.
void rotate(int x)
{
    const int y = fa[x];
    const int z = fa[y];
    const int side = which(x);          // which child of y x currently is
    const int inner = ch[x][side ^ 1];  // x's subtree that moves across to y
    // Only let z adopt x when y had a real splay parent. Otherwise fa[y] is a
    // path-parent pointer and z's children must stay untouched.
    if (notroot(y)) ch[z][which(y)] = x;
    ch[x][side ^ 1] = y;
    ch[y][side] = inner;
    if (inner) fa[inner] = y;
    fa[x] = z;  // x inherits y's splay parent or path-parent pointer
    fa[y] = x;
    update(y);
    update(x);
}

// Splay x to the root of its splay tree.
void splay(int x)
{
    // Record x and every splay ancestor, then push lazy tags down starting at
    // the splay root, so no rotation sees a stale child order or value.
    top = 0;
    int cur = x;
    for (;;)
    {
        st[++top] = cur;
        if (!notroot(cur)) break;
        cur = fa[cur];
    }
    while (top) pushdown(st[top--]);

    while (notroot(x))
    {
        const int p = fa[x];
        // Same-side (zig-zig) rotates the parent first; zig-zag rotates x twice.
        if (notroot(p)) rotate(which(x) == which(p) ? p : x);
        rotate(x);
    }
    update(x);
}

// Make the path from the tree root to x a single preferred path.
// Walking upward, each node is splayed and its right (deeper) child is
// replaced by the splay tree just built below it. The first step gives x no
// deeper preferred child.
void access(int x)
{
    int below = 0;
    while (x)
    {
        splay(x);
        ch[x][1] = below;
        update(x);
        below = x;
        x = fa[x];
    }
}

// Re-root x's tree at x: expose the root-to-x path, bring x to the top of its
// splay tree, then lazily reverse that path so x becomes its shallowest node.
void make_root(int x)
{
    access(x);
    splay(x);
    r[x] ^= 1u;
}

// Isolate the x..y path. Afterwards y is the root of a splay tree holding
// exactly the nodes on that path, so s[y], pusha(y, ..) and pushm(y, ..)
// act on the whole path.
void split(int x, int y)
{
    make_root(x);
    access(y);
    splay(y);
}

// Add the edge (x, y): re-root x's tree at x, then hang it below y through a
// path-parent pointer. No check is made that x and y are in different trees.
void link(int x, int y)
{
    make_root(x);
    fa[x] = static_cast<unsigned>(y);
}

void cut(int x, int y)
{
    make_root(x);
    fa[y] = ch[x][1] = 0;
    update(x);
}

int main()
{
    scanf("%d%d", &n, &q);
    for(int i = 1; i <= n; i++) v[i] = size[i] = mf[i] = 1;
    for(int i = 1; i <= n - 1; i++)
    {
        int u, v;
        scanf("%d%d", &u, &v);
        link(u, v);
    }
    while(q--)
    {
        char opt = getchar(); scanf("%c", &opt);
        int u, v, c, u1, u2, v1, v2;
        if(opt == '+') {scanf("%d%d%d", &u, &v, &c); split(u, v); pusha(v, c);}
        else if(opt == '-') {scanf("%d%d%d%d", &u1, &v1, &u2, &v2); cut(u1, v1); link(u2, v2);}
        else if(opt == '*') {scanf("%d%d%d", &u, &v, &c); split(u, v); pushm(v, c);}
        else {scanf("%d%d", &u, &v); split(u, v); printf("%d\n", s[v]);}
    }
    return 0;
}
反馈
如果你认为某个帖子有问题,欢迎向洛谷反馈,以帮助更多的同学。



请具体说明理由,以增加反馈的可信度。