题解 P4501 【[ZJOI2018]胖】

· · 题解

先放博客,然后直接跑Bellman-Ford。

换个角度,考虑每一个修路的点能更新到哪一些点,显然这是一个包含自己的连续区间。

那么对于一个修路的点X,分别二分左右两个端点L,R,以L为例,令l=X-L,如果[L-l,L+l]这个区间内不存在点使得从这里出发到L小于从X出发到L,那么这个L就是可行的。

答案就是求\sum\limits_{i=1}^{K}(R_i-L_i+1)

还要考虑边界问题防止重复计算,比如相同距离编号相近的更优先,相同距离编号差的绝对值也相同就左边优先。

rmq的实现方法也很重要,线段树很慢(没卡过去),我写的是用离散化后的st表。虽然两个复杂度都是nlog^2n的。还有二分查找用不要用库函数,时间有手写的两倍。

#include <bits/stdc++.h>
using namespace std;
int read() {
    int x = 0, f = 1; char ch = getchar();
    while (!isdigit(ch)) { if (ch == '-') f = -1; ch = getchar(); }
    while (isdigit(ch)) { x = x * 10 + ch - '0'; ch = getchar(); }
    return x * f;
}
#define ll long long

const ll Max = 200033;
const ll inf = 9999999999999999ll;
ll n, m, now, K;
ll dis[Max], pos[Max], len[Max], Ans;
int bin[20], Log[Max];

struct Data {
    ll mn, id;
    Data() {}
    Data(ll a, ll b) {
        mn = a, id = b;
    }
};

void init() {
    bin[0] = 1;
    for (int i = 1; i <= 18; i++) {
        bin[i] = bin[i - 1] << 1;
        Log[bin[i]] = i;
    }
    for (int i = 1; i < Max; i++)
        if (!Log[i]) Log[i] = Log[i - 1];
}

struct ST {
    ll d[Max][18];
    void init() {
        for (int j = 1; j <= Log[K]; j++)
            for (int i = 1; i <= K; i++)
                d[i][j] = min(d[i][j - 1], d[i + bin[j - 1]][j - 1]);
    }
    ll query(int L, int R) {
        int len = R - L + 1;
        return min(d[L][Log[len]], d[R - bin[Log[len]] + 1][Log[len]]);
    }
} st[2];

int tmp[Max], t[Max];

int lower(int V) {
    int L = 1, R = K, re = K + 1;
    while (R >= L) {
        int mid = (L + R) >> 1;
        if (tmp[mid] < V) L = mid + 1;
        else R = mid - 1, re = mid;
    }
    return re;
}

int upper(int V) {
    int L = 1, R = K, re = K + 1;
    while (R >= L) {
        int mid = (L + R) >> 1;
        if (tmp[mid] <= V) L = mid + 1;
        else R = mid - 1, re = mid;
    }
    return re;
}

void solve() {
    for (int i = 1; i <= K; i++)
        tmp[i] = pos[i];
    sort(tmp + 1, tmp + K + 1);
    for (int i = 1; i <= K; i++)
        t[tmp[i]] = i;
    for (int i = 1; i <= K; i++) {
        st[0].d[t[pos[i]]][0] = len[i] - dis[pos[i]];
        st[1].d[t[pos[i]]][0] = len[i] + dis[pos[i]];
    }
    st[0].init(), st[1].init();
    Ans = 0;
    for (int i = 1; i <= K; i++) {
        int ql = 1, qr = pos[i] - 1, L = pos[i], R = L;
        while (qr >= ql) {
            bool flag = 1;
            int mid = (ql + qr) >> 1;
            ll nowdis = len[i] + dis[pos[i]] - dis[mid], idx = pos[i] - mid;
            ll dl = inf, dr = dl;

            int qql = lower(max(1ll, mid - idx));
            int qqr = upper(mid) - 1;
            if (qql <= qqr) dl = st[0].query(qql, qqr);

            qql = lower(mid);
            qqr = upper(pos[i] - 1) - 1;
            if (qql <= qqr) dr = st[1].query(qql, qqr);

            if (dl + dis[mid] <= nowdis) flag = 0;
            if (dr - dis[mid] <= nowdis) flag = 0;
            if (flag) qr = mid - 1, L = mid;
            else ql = mid + 1;
        }

        ql = pos[i] + 1, qr = n;
        while (qr >= ql) {
            bool flag = 1;
            int mid = (ql + qr) >> 1;
            ll nowdis = len[i] + dis[mid] - dis[pos[i]], idx = mid - pos[i];
            ll dl = inf, dr = dl;

            int qql = lower(mid);
            int qqr = upper(min(mid + idx - 1, n)) - 1;
            if (qql <= qqr) dr = st[1].query(qql, qqr);
            //cout << "fuck: " << mid << ' ' << dl << ' ' << dr << endl;

            qql = lower(pos[i] + 1);
            qqr = upper(mid) - 1;
            if (qql <= qqr) dl = st[0].query(qql, qqr);

            if (dl + dis[mid] <= nowdis) flag = 0;
            if (dr - dis[mid] <= nowdis) flag = 0;
            if (mid + idx <= n) {
                qqr = upper(min(mid + idx, n)) - 1;
                if (tmp[qqr] == mid + idx)
                    if (st[1].d[qqr][0] - dis[mid] < nowdis)
                        flag = 0;
            }
            if (flag) ql = mid + 1, R = mid;
            else qr = mid - 1;
        }
        Ans += (R - L + 1);
    }
    printf("%lld\n", Ans);
}

signed main() {
    init();
    n = read(), m = read();
    for (int i = 2; i <= n; i++)
        dis[i] = dis[i - 1] + read();
    for (int i = 1; i <= m; i++) {
        K = read();
        for (int j = 1; j <= K; j++)
            pos[j] = read(), len[j] = read();
        solve();
    }
}