题解:AT_abc428_f [ABC428F] Pyramid Alignment

· · 题解

AT_abc428_f [ABC428F] Pyramid Alignment

Problem

N 个滑块叠成一摞,从上往下第 i 个滑块的长为 w_i,且上面滑块的长度总是严格小于下面滑块的长度。初始时,每个滑块的左端点都为 0

现在有 Q 组询问,对询问 12,需将所有位于 v 上的滑块的左端点/右端点设为 v 的左端点/右端点;对询问 3,询问覆盖了点 x+0.5 的滑块数量。

Solution

考虑一个性质:若点 x 被滑块 v 覆盖到了,则滑块 v 及其以下的滑块一定都可以覆盖 x,显然这是具有单调性的,那么我们可以二分最上面的可以覆盖 x 的滑块 i,最后 n-i+1 即为答案。我们考虑如何维护每个滑块的位置信息。

对于一个滑块,无论我们明确了它的左端点还是右端点,都能的到它的位置信息。那么对于一次修改操作,我们先求出滑块 v 的左端点/右端点,然后对它上面的滑块做区间覆盖(另一个端点赋为 -1 表示尚不明确)。这样无论何时我们总是能得知滑块 i 其中一个端点的位置。

Time complexity

二分+线段树,时间复杂度 O(n\log ^2n)

Code

#include<bits/stdc++.h>
using namespace std;
const int N = 200005;
int n,q,w[N];
struct node
{
    int l,r;
}f[N<<2],g[N<<2];
void build(int root,int l,int r)
{
    if (l >= r)
    {
        f[root] = {0,-1};
        return;
    }
    int mid = (l+r)/2;
    build(root*2,l,mid);
    build(root*2+1,mid+1,r);
}
inline void pushdown(int root)
{
    f[root*2] = f[root*2+1] = g[root*2] = g[root*2+1] = g[root];
    g[root] = {0,0};
}
void update(bool isl,int root,int l,int r,int ql,int qr,int x) //isl表示覆盖的是左端点还是右端点
{
    if (ql <= l && qr >= r)
    {
        if (isl) f[root] = g[root] = {x,-1};
        else f[root] = g[root] = {-1,x};
        return;
    }
    if (g[root].l != 0 || g[root].r != 0) pushdown(root);
    int mid = (l+r)/2;
    if (qr <= mid) update(isl,root*2,l,mid,ql,qr,x);
    else if (ql > mid) update(isl,root*2+1,mid+1,r,ql,qr,x);
    else update(isl,root*2,l,mid,ql,mid,x),update(isl,root*2+1,mid+1,r,mid+1,qr,x);
}
node query(int root,int l,int r,int x)
{
    if (l >= r) return f[root];
    if (g[root].l != 0 || g[root].r != 0) pushdown(root);
    int mid = (l+r)/2;
    if (x <= mid) return query(root*2,l,mid,x);
    return query(root*2+1,mid+1,r,x);
}
inline int get_x_l(int x)
{
    node k = query(1,1,n,x);
    if (k.l == -1) return k.r-w[x];
    return k.l;
}
inline int get_x_r(int x)
{
    node k = query(1,1,n,x);
    if (k.r == -1) return k.l+w[x];
    return k.r;
}
signed main()
{
    ios_base::sync_with_stdio(0),cin.tie(0),cout.tie(0);
    cin >> n;
    for (int i = 1; i <= n; i++)
        cin >> w[i];
    build(1,1,n);
    cin >> q;
    while (q--)
    {
        int op,x;
        cin >> op >> x;
        if (op == 1)
        {
            update(1,1,1,n,1,x,get_x_l(x));
        }
        else if (op == 2)
        {
            update(0,1,1,n,1,x,get_x_r(x));
        }
        else
        {
            int l = 1,r = n+1;
            while (l < r)
            {
                int mid = (l+r)/2;
                if (x+0.5 >= get_x_l(mid) && x+0.5 <= get_x_r(mid)) r = mid;
                else l = mid+1;
            }
            cout << n-l+1 << '\n';
        }
    }
    return 0;
}