P6138 [IOI2012] 骑马比武竞赛 题解

· · 题解

提供一个无脑做法。

每次操作的本质是用区间最大值替代整个区间,因此需要在序列上支持删除一个区间以及单点插入。这可以用平衡树实现,但难以直接拓展到"对每个插入位置分别求答案"上。

不妨这样:对于每次操作 $[l,r]$,我们都找到 $[l,r]$ 对应的原序列区间 $[l',r']$,这可以用平衡树模拟操作来维护。此时,一个人 $i$ 在这一轮获胜需要满足 $i \in [l',r']$ 且 $i$ 为该区间的最大值。

我们不妨枚举新人插入的位置 $p$,用 ST 表加二分找出两边第一个大于他等级的位置,从而得到他能连续击败的原序列区间 $[L,R]$。于是他能赢下一轮 $[l',r']$ 当且仅当 $L \le l' \le p \le r' \le R$,问题就转成了一个二维数点问题。由于我比较无脑,所以直接上了 K-D Tree,复杂度 $O(n\sqrt{n})$。

#include <bits/stdc++.h>
#include<ext/rope>
using namespace __gnu_cxx;
using namespace std;
//#define int long long

// N: capacity of every fixed-size global array below (all indices must stay < N).
// MOD: modulus used by qpow.
// HSMOD / HSMOD2: not referenced in the code shown (presumably spare hashing moduli — verify before use).
const int N = 1e5 + 5, MOD = 1e9 + 7, HSMOD = 1610612741, HSMOD2 = 998244353; // Remember to change

// Computes a^b mod MOD by binary exponentiation.
// `a` may be any long long: it is first reduced into [0, MOD). As a result, negative
// bases yield a non-negative residue, and large bases cannot overflow `base * base`.
// `b` is expected to be >= 0. A negative `b` used to loop forever (arithmetic right
// shift never reaches 0) and is now treated like 0 (returns 1).
long long qpow(long long a, long long b)
{
    long long res = 1ll;
    long long base = (a % MOD + MOD) % MOD;  // normalise once so every product stays < MOD^2
    while (b > 0)
    {
        if (b & 1ll) res = res * base % MOD;
        base = base * base % MOD;
        b >>= 1ll;
    }
    return res;
}

namespace FastIo
{
    #define QUICKCIN ios::sync_with_stdio(0), cin.tie(0), cout.tie(0)
    // Reads the next integer from stdin using getchar().
    // Characters that are neither digits nor '-' are skipped, and each consecutive
    // leading '-' flips the sign.
    // Returns 0 if EOF is reached before any digit. The character is kept in an int so
    // that EOF is distinguishable; a char-typed loop would spin forever at end of input.
    int read()
    {
        int ch = getchar();
        int x = 0, f = 1;
        while (ch != EOF && (ch < '0' || ch > '9') && ch != '-') ch = getchar();
        while (ch == '-')
        {
            f = -f;
            ch = getchar();
        }
        while (ch >= '0' && ch <= '9')
        {
            x = x * 10 + (ch - '0');
            ch = getchar();
        }
        return x * f;
    }
    // Writes x in decimal to stdout using putchar().
    // Negative values are printed without negating x itself, so the minimum value of T
    // (e.g. INT_MIN) no longer overflows.
    template<class T>
    void write(T x)
    {
        if (x < 0)
        {
            putchar('-');
            // Division truncates toward zero, so -(x / 10) is non-negative and fits in T,
            // and -(x % 10) is the last digit.
            if (x / 10 != 0) write(-(x / 10));
            putchar('0' - static_cast<int>(x % 10));
            return;
        }
        if (x > 9) write(x / 10);
        putchar(static_cast<int>(x % 10) + '0');
    }
    // write(x) followed by a newline.
    template<class T>
    void writeln(T x)
    {
        write(x);
        putchar('\n');
    }
}

// Fenwick (binary indexed) tree over 1-based indices 1..N-1. T is used both as the
// index type and as the value type.
template<typename T>
class Bit
{
public:
    // Lowest set bit of x.
    T lowbit(T x)
    {
        return x & -x;
    }
    T tr[N];
    // Adds y at index x.
    // Indices <= 0 are ignored: lowbit(0) == 0 would otherwise make the loop never
    // advance, and index 0 is unused in a 1-based tree.
    void add(T x, T y)
    {
        if (x <= 0) return;
        while (x < N)
        {
            tr[x] += y;
            x += lowbit(x);
        }
    }
    // Sum over indices 1..x.
    // x >= N is clamped to N - 1 (the last stored index) instead of reading past tr.
    // x <= 0 yields 0.
    T query(T x)
    {
        if (x >= N) x = N - 1;
        T sum = 0;
        while (x > 0)
        {
            sum += tr[x];
            x -= lowbit(x);
        }
        return sum;
    }
};

// Input read by main:
//   n    - number of positions in the line-up (n rope slots; n - 1 values a[0..n-2]);
//   c    - number of rounds;
//   r    - the value compared against range maxima of a when deciding a win.
// Per round i (1-based), also filled in by main:
//   s[i], e[i]   - the slot range as read;
//   ns[i], ne[i] - the range of original positions that round covers, recovered through the rope f.
int n, c, r, a[N], s[N], e[N], ns[N], ne[N];

// One slot of the current line-up: the closed range [l, r] of original positions
// that has been merged into it.
struct Node
{
    int l;
    int r;
    Node() = default;
    Node(int lo, int hi) : l{lo}, r{hi} {}
};

// Current line-up: one Node per remaining slot, recording the original-position range it covers.
rope<Node> f;
// Initialised in main (nl[i] = nr[i] = i) but not read anywhere in the code shown.
int nl[N], nr[N];
// Not referenced in the code shown.
int ra[N];

// LG2[x] = floor(log2(x)) for x >= 1, filled at the start of main; used by the sparse table.
int LG2[N];

// A 2-D point whose coordinates x[0] and x[1] are indexed by the K-D tree.
struct Nd
{
    int x[2];
};
// Points fed to the K-D tree; main fills ee[1..c].
Nd ee[N];

// Sparse table for range-maximum queries over a[0..n-2].
// Row t (1-based, 1 <= t <= n-1) holds a[t-1], and f[t][k] is the maximum of rows t..t+2^k-1.
class ST
{
public:
    int f[N][21];
    // Builds the table from the globals n and a; LG2 must already be filled.
    void Init()
    {
        const int m = n - 1;
        for (int t = 1; t <= m; ++t) f[t][0] = a[t - 1];
        for (int k = 1; k <= LG2[m]; ++k)
        {
            const int half = 1 << (k - 1);
            for (int t = 1; t + (half << 1) - 1 <= m; ++t)
                f[t][k] = max(f[t][k - 1], f[t + half][k - 1]);
        }
    }
    // Maximum over rows l..r inclusive (i.e. a[l-1..r-1]); expects 1 <= l <= r <= n-1.
    int query(int l, int r)
    {
        const int k = LG2[r - l + 1];
        const int lo = f[l][k];
        const int hi = f[r - (1 << k) + 1][k];
        return lo < hi ? hi : lo;
    }
} st;

class KD_Tree
{
public:
    struct Node
    {
        int lson, rson, x[2];
        int minn[2], maxn[2], sum;
        Node()
        {
            minn[0] = minn[1] = (int)1e9;
            maxn[0] = maxn[1] = -1;
        }
    }tr[N];
    int idx;
    void pushup(int u)
    {
        tr[u].sum = tr[tr[u].lson].sum + tr[tr[u].rson].sum + 1;
        tr[u].minn[0] = min({tr[tr[u].lson].minn[0], tr[tr[u].rson].minn[0], tr[u].x[0]});
        tr[u].minn[1] = min({tr[tr[u].lson].minn[1], tr[tr[u].rson].minn[1], tr[u].x[1]});
        tr[u].maxn[0] = max({tr[tr[u].lson].maxn[0], tr[tr[u].rson].maxn[0], tr[u].x[0]});
        tr[u].maxn[1] = max({tr[tr[u].lson].maxn[1], tr[tr[u].rson].maxn[1], tr[u].x[1]});
    }
    int build(int l, int r, int g)
    {
        int u = ++idx;
        int mid = l + r >> 1;
        nth_element(ee + l, ee + mid, ee + r + 1, [&](const auto& x, const auto& y){return x.x[g] < y.x[g];});
        tr[u].x[0] = ee[mid].x[0], tr[u].x[1] = ee[mid].x[1];
        if (l == r)
        {
            pushup(u);
            return u;
        }
        if (l < mid) tr[u].lson = build(l, mid - 1, g ^ 1);
        if (r > mid) tr[u].rson = build(mid + 1, r, g ^ 1);
        pushup(u);
        return u;
    }
    int query(int u, int X1, int X2, int Y1, int Y2)
    {
        if (!u) return 0;
        if (tr[u].minn[0] >= X1 && tr[u].minn[0] <= X2 && tr[u].maxn[0] >= X1 && tr[u].maxn[0] <= X2 && tr[u].minn[1] >= Y1 && tr[u].minn[1] <= Y2 && tr[u].maxn[1] >= Y1 && tr[u].maxn[1] <= Y2) return tr[u].sum;
        if (tr[u].minn[0] > X2 || tr[u].maxn[0] < X1 || tr[u].minn[1] > Y2 || tr[u].maxn[1] < Y1) return 0;
        int res = (tr[u].x[0] >= X1 && tr[u].x[0] <= X2 && tr[u].x[1] >= Y1 && tr[u].x[1] <= Y2);
        res += query(tr[u].lson, X1, X2, Y1, Y2);
        res += query(tr[u].rson, X1, X2, Y1, Y2);
        return res;
    }
}kdt;

int main()
{
    ios::sync_with_stdio(0), cin.tie(nullptr), cout.tie(nullptr);
    cin >> n >> c >> r;
    for (int i = 2; i < N; i++) LG2[i] = LG2[i >> 1] + 1;
    for (int i = 0; i < n; i++)
    {
        nl[i] = nr[i] = i;
        f.push_back(Node(i, i));
    }
    for (int i = 0; i < n - 1; i++) 
    {
        cin >> a[i];
    }
    st.Init();
    for (int i = 1; i <= c; i++)
    {
        cin >> s[i] >> e[i];
        ns[i] = f[s[i]].l, ne[i] = f[e[i]].r;
        f.erase(s[i], e[i] - s[i] + 1);
        f.insert(s[i], Node(ns[i], ne[i]));
        ee[i].x[0] = ns[i], ee[i].x[1] = ne[i];
    //  cout << "!!!: " << ns[i] << " " << ne[i] << "\n";
    }
    kdt.build(1, c, 0);
//  cout << "????: " << kdt.query(1, 0, 1, 1, 2) << "\n";
    int ans = -1, pos = 0;
    for (int i = -1; i < n - 1; i++)
    {
        int nl = -1, nr = -1;
        // find nl
        /*
        for (int j = i; j >= 0; j--)
        {
            if (a[j] > r)
            {
                nl = j + 1;
                break;
            }
        }
        for (int j = i + 1; j < n - 1; j++)
        {
            if (a[j] > r) 
            {
                nr = j;
                break;
            }
        }*/
        int L = 0, R = i;
        while (L <= R)
        {
            int mid = (L + R) >> 1;
            if (st.query(mid + 1, i + 1) < r)
            {
                R = mid - 1;
                nl = mid;
            }
            else L = mid + 1;
        }
        if (nl == -1) nl = i + 1;
        // find nr
        L = i + 1, R = n - 2;
        while (L <= R)
        {
            int mid = (L + R) >> 1;
            if (st.query(i + 2, mid + 1) < r)
            {
                nr = mid + 1;
                L = mid + 1;
            }
            else R = mid - 1;
        }
        if (nr == -1) nr = i + 1;
        int X1 = nl, X2 = min(nr, i + 1), Y1 = max(nl, i + 1), Y2 = nr;
        int cnt = kdt.query(1, X1, X2, Y1, Y2);
        //cout << "!!!!!: " << X1 <<" " << X2 << " " << Y1 << " " << Y2 << " " << cnt << "\n";
        if (cnt > ans)
        {
            ans = cnt;
            pos = i;
        }
    }
    cout << pos + 1 << "\n";
    return 0;
}