题解:P14521 【MX-S11-T2】加减乘除

· · 题解

细想一下就不难了。

题目概述

给你一颗树,你一开始权值为 x,你经过一个节点需要跟他进行运算,之后你可以选择在这个点结束或者说继续走到他的其中一个儿子,但是需要满足 x\in[l,r],其中 l,r 是这条边的限制范围。

数据范围 1\leq n\leq 5\times 10^5,1\leq q\leq 10^6

分析

看到数据范围:n,q 不同阶,大概是预处理以及 \mathcal{O}(1) 或者 \mathcal{O}(\log n) 查询的做法。

因此我们考虑,在某个点结束需要一开始的 x 满足什么条件。

假设我的 x 经过了一个点走了一条有限制的边且限制为 x'\in[l,r],这里的 x'x 跟上一个点进行运算的值。

显然你让这个 x' 拆开,将非 x 的部分移项过去算出来就行了。

而且我们能在这个点结束当且仅当在这个点与根节点的路径上的每一个点的范围条件取交集。

显然这个取交集并不会溢出,但是我还是判了。

取出来之后一个点能否贡献当且仅当这个一开始的 x 是否在这个范围内。

很明显询问排个序,区间排个序就可以 \mathcal{O}(n+q) 解决了。

代码

时间复杂度 \mathcal{O}(n+n\log n+q\log q+q)

#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <stdlib.h>
#include <vector>
#include <climits>
#define int long long
#define N 500005
#define PII pair<int,int>
#define isdigit(ch) ('0' <= ch && ch <= '9')
using namespace std;
template<typename T>
void read(T&x) {
    x = 0;
    int f = 1;
    char ch = getchar();
    for (;!isdigit(ch);ch = getchar()) f = (ch == '-' ? -1 : f);
    for (;isdigit(ch);ch = getchar()) x = (x << 1) + (x << 3) + (ch ^ 48);
    x *= f;
}
template<typename T>
void write(T x) {
    if (x > 9) write(x / 10);
    putchar(x % 10 + '0');
}
const int INF = 1e18 + 1000;
int n,q;
vector<int> g[N];
struct edge{
    int to;
    int l,r;
};
vector<edge> edges[N];
char op[N];
int a[N],b[N],L[N],R[N],ans[N << 1];
struct que{
    int x,id;
}qu[N << 1];
signed main(){
    cin >> n >> q;
    for (int i = 2;i <= n;i ++) {
        int p,l,r;
        read(p),read(l),read(r);
        edges[p].push_back({i, l, r});
    }
    for (int i = 1;i <= n;i ++) {
        char s;
        cin >> s;
        read(a[i]);
        op[i] = s;
    }
    vector<int> sta;
    sta.push_back(1);
    b[1] = (op[1] == '+') ? a[1] : -a[1];
    L[1] = -INF,R[1] = INF;
    while (!sta.empty()) {
        int u = sta.back();
        sta.pop_back();
        for (auto e : edges[u]) {
            int v = e.to,l = e.l, r = e.r;
            int curL = (l > INF + b[u]) ? INF + 1 : (l < -INF + b[u]) ? -INF - 1 : l - b[u];
            int curR = (r > INF + b[u]) ? INF : (r < -INF + b[u]) ? -INF : r - b[u];
            if (curL < -INF) curL = -INF;
            else if (curL > INF) curL = INF + 1;
            if (curR < -INF) curR = -INF - 1;
            else if (curR > INF) curR = INF;
            L[v] = max(L[u],curL);
            R[v] = min(R[u],curR);
            b[v] = b[u];
            if (op[v] == '+') {
                if (a[v] > 0 && b[v] > LLONG_MAX - a[v]) b[v] = INF;
                else if (a[v] < 0 && b[v] < LLONG_MIN - a[v]) b[v] = -INF;
                else b[v] += a[v];
            }
            else {
                if (a[v] > 0 && b[v] < LLONG_MIN + a[v]) b[v] = -INF;
                else if (a[v] < 0 && b[v] > LLONG_MAX + a[v]) b[v] = INF;
                else b[v] -= a[v];
            }
            if (b[v] < -2 * INF) b[v] = -2 * INF;
            if (b[v] > 2 * INF) b[v] = 2 * INF;
            sta.push_back(v);
        }
    }
    vector<PII> ls2;
    for (int u = 1;u <= n;u ++) {
        if (L[u] > R[u]) continue;
        ls2.push_back({L[u],1ll});
        if (R[u] < INF) ls2.push_back({R[u] + 1,-1});
        else ls2.push_back({INF + 1,-1});
    }
    stable_sort(ls2.begin(),ls2.end());
    for (int i = 1;i <= q;i ++) read(qu[i].x),qu[i].id = i;
    stable_sort(qu + 1,qu + 1 + q,[](que x,que y) {
        return x.x < y.x;
    });
    int j = 0,cnt = 0;
    for (int i = 1;i <= q;i ++) {
        int x = qu[i].x,id = qu[i].id;
        while (j < ls2.size() && ls2[j].first <= x) cnt += ls2[j].second,j ++;
        ans[id] = cnt;
    }
    for (int i = 1;i <= q;i ++) write(ans[i]),putchar('\n');
    return 0;
}