题解:P12866 [JOI Open 2025] 抽奖 / Lottery

· · 题解

upd:优化了主席树写法,增添了小波矩阵的代码

显然我们只关心能获奖的取球方案,不妨先令询问 L=0,R=n-1

二分答案求是否能取 K 轮,若存在 x_i+y_i<K 显然无解。

假设最终方案中袋 i 取了 a_i 个红球,那么有 mn_i \le a_i \le mx_i,其中 mn_i=\max(K-y_i,0),mx_i=\min(x_i,K)

问题转为给定 mn_i \le mx_i 和初始为 0 的 a_i,需要进行 K 次操作,每次选择 \frac{n}{2}a_i +1,询问是否存在操作方案,使得最终 \forall i,mn_i \le a_i \le mx_i

可以证明,存在操作方案等价于 \sum mn_i \le \frac{nK}{2} \le \sum mx_i,必要性易证,对于充分性,考虑若最终存在合法的 a_i,进行 K 轮操作,每轮操作找 \frac{n}{2} 次最大的 a_i 并令其 -1,归纳可得 K 轮操作后一定有 a_i=0,将 -1 的合法操作方案逆序即可得到 +1 的合法操作方案,又因为 \sum mn_i \le \frac{nK}{2} \le \sum mx_i,最终合法的 a_i 是容易构造的。

那么可取 K 轮等价于 \sum \max(K-y_i,0) \le \frac{nK}{2} \le \sum \min(x_i,K),也即 \frac{nK}{2} \le \min(\sum \min(x_i,K),\sum \min(y_i,K))

x_i,y_i 建立权值线段树,即可 O(\log^2{V}) 地解决每次全局询问,注意到二分答案和权值树查询可以合并,线段树二分即可做到 O(\log{V})

然而发现条件可以拆成 K \le x_i+y_i,K \le \sum \min(x_i,K),K \le \sum \min(y_i,K),且三个条件各自均可二分,于是可以减小编码复杂度

离散化后求出原序列上的最优答案,此时 \le K>K 的元素都已确定,解不等式即可,优化 O(\log{V})O(\log{n})

对于区间询问,建立主席树即可,总复杂度 O((n+q)\log{n})

#include <bits/stdc++.h>
using namespace std;
namespace staring
{
    using LL = long long;
    using ULL = unsigned long long;
    #define fir first
    #define sec second

    #define FOR(i,a,b) for(int i = (a), i##E = (b); i <= i##E; i ++)
    #define ROF(i,a,b) for(int i = (a), i##E = (b); i >= i##E; i --)

    template <typename TYPE>
    int gmax(TYPE &x, const TYPE& y) {return x < y ? x = y, 1 : 0;}
    template <typename TYPE>
    int gmin(TYPE &x, const TYPE& y) {return y < x ? x = y, 1 : 0;}

    static constexpr int SIZE = 1 << 20;
    static char buffin[SIZE]{}, *pin1{}, *pin2{};
    static char buffout[SIZE]{}, *pout{buffout};
    #define GETC() (pin1 == pin2 && (pin2 = (pin1 = buffin) + fread(buffin, 1, SIZE, stdin), pin1 == pin2)? EOF : *pin1++)
    #define PUTC(c) (pout - buffout == SIZE && (fwrite(buffout, 1, SIZE, stdout), pout = buffout), (*pout++ = c))
    template <typename TYPE>
    void read(TYPE &x)
    {
        static int signf{0}, chin{0};
        x = signf = 0, chin = GETC();
        while(chin < '0' || chin > '9') signf |= chin == '-', chin = GETC();
        while(chin >= '0' && chin <= '9') x = (x << 3) + (x << 1) + (chin ^ 48), chin = GETC();
        if(signf) x = -x;
    }
    template <typename TYPE>
    void write(TYPE x, char ch = ' ')
    {
        static int stack[64]{}, top{0};
        !x && PUTC('0'), x < 0 && (x = -x, PUTC('-'));
        while(x) stack[top++] = x % 10, x /= 10;
        while(top) PUTC(stack[--top] | 48);
        if(ch) PUTC(ch);
    }

}using namespace staring;

using VEC = vector <int>;
constexpr int N = 2e5 + 5, A = 20;
constexpr int M = N * A;

int st[A][N];

struct waveletMatrix
{
    int tot, rt[N];
    VEC lsh;
    struct
    {
        int lc, rc;
        LL cnt, sum;
    }
    tr[M];

    #define mid (l + r >> 1)

    void insert(int k, int& p, int q, int l, int r)
    {
        p = ++tot, tr[p] = tr[q], ++tr[p].cnt, tr[p].sum += lsh[k - 1];
        if(l < r) k <= mid ? insert(k, tr[p].lc, tr[q].lc, l, mid) : insert(k, tr[p].rc, tr[q].rc, mid + 1, r);
    }

    void build(int n, VEC vec)
    {
        lsh = vec;
        sort(begin(lsh), end(lsh));
        lsh.erase(unique(begin(lsh), end(lsh)), end(lsh));
        FOR(i, 1, n)
        {
            int v = lower_bound(begin(lsh), end(lsh), vec[i - 1]) - begin(lsh) + 1;
            insert(v, rt[i], rt[i - 1], 1, size(lsh));
        }
    }

    int query(LL s, LL c, int p, int q, int l, int r)
    {
        if(l == r) return - (s + tr[p].sum - tr[q].sum) / c;
        LL ss = s + tr[tr[p].lc].sum - tr[tr[q].lc].sum;
        LL cc = c + tr[tr[p].rc].cnt - tr[tr[q].rc].cnt;
        if(ss + cc * lsh[mid] >= 0)
            return query(ss, c, tr[p].rc, tr[q].rc, mid + 1, r);
        return query(s, cc, tr[p].lc, tr[q].lc, l, mid);
    }

    int ask(int L, int R)
    {
        return query(0, -(R - L + 1 >> 1), rt[R], rt[L - 1], 1, size(lsh));
    }
}
xtre, ytre;

void init(int n, int q, VEC x, VEC y)
{
    FOR(i, 1, n) st[0][i] = x[i - 1] + y[i - 1];
    FOR(u, 1, A - 1)
        FOR(i, 1, n - (1 << u) + 1)
            st[u][i] = min(st[u - 1][i], st[u - 1][i + (1 << u - 1)]);
    xtre.build(n, x);
    ytre.build(n, y);
}

int max_prize(int L, int R)
{
    ++L, ++R;
    int k = __lg(R - L + 1);
    return min(min(st[k][L], st[k][R - (1 << k) + 1]), min(xtre.ask(L, R), ytre.ask(L, R)));
}

官解告诉我们,上面问题的形式也可以用小波矩阵 wavelet matrix 解决

如果你不知道什么是小波矩阵,这里自荐一下我的 小波矩阵博客,里面具体记述了这里小波矩阵的实现

在小波矩阵上二分即可,由于要记录元素和,时空复杂度与主席树相同:O((n+q)\log{n})-O(n\log{n})

#include <bits/stdc++.h>
using namespace std;
namespace staring
{
    using LL = long long;
    using ULL = unsigned long long;
    #define fir first
    #define sec second

    #define FOR(i,a,b) for(int i = (a), i##E = (b); i <= i##E; i ++)
    #define ROF(i,a,b) for(int i = (a), i##E = (b); i >= i##E; i --)

    template <typename TYPE>
    int gmax(TYPE &x, const TYPE& y) {return x < y ? x = y, 1 : 0;}
    template <typename TYPE>
    int gmin(TYPE &x, const TYPE& y) {return y < x ? x = y, 1 : 0;}

    static constexpr int SIZE = 1 << 20;
    static char buffin[SIZE]{}, *pin1{}, *pin2{};
    static char buffout[SIZE]{}, *pout{buffout};
    #define GETC() (pin1 == pin2 && (pin2 = (pin1 = buffin) + fread(buffin, 1, SIZE, stdin), pin1 == pin2)? EOF : *pin1++)
    #define PUTC(c) (pout - buffout == SIZE && (fwrite(buffout, 1, SIZE, stdout), pout = buffout), (*pout++ = c))
    template <typename TYPE>
    void read(TYPE &x)
    {
        static int signf{0}, chin{0};
        x = signf = 0, chin = GETC();
        while(chin < '0' || chin > '9') signf |= chin == '-', chin = GETC();
        while(chin >= '0' && chin <= '9') x = (x << 3) + (x << 1) + (chin ^ 48), chin = GETC();
        if(signf) x = -x;
    }
    template <typename TYPE>
    void write(TYPE x, char ch = ' ')
    {
        static int stack[64]{}, top{0};
        !x && PUTC('0'), x < 0 && (x = -x, PUTC('-'));
        while(x) stack[top++] = x % 10, x /= 10;
        while(top) PUTC(stack[--top] | 48);
        if(ch) PUTC(ch);
    }

}using namespace staring;

using VEC = vector <int>;
constexpr int N = 2e5 + 5, M = 20;

int st[M][N];

struct waveletMatrix
{
    int m;
    LL sum[M][N], cnt[M][N];
    int lsh[N], val[N], tmp[N];
    int pos[M];
    int tot;

    void init(int n, VEC vec)
    {
        FOR(i, 1, n) lsh[i] = vec[i - 1];
        sort(lsh + 1, lsh + n + 1);
        tot = unique(lsh + 1, lsh + n + 1) - lsh - 1;
        FOR(i, 1, n) val[i] = lower_bound(lsh + 1, lsh + tot + 1, vec[i - 1]) - lsh;
        m = __lg(tot) + 1;

        ROF(u, m - 1, 0)
        {
            int x = 0, y = 0;
            FOR(i, 1, n)
                if(val[i] >> u & 1)
                {
                    sum[u][i] = sum[u][i - 1];
                    cnt[u][i] = cnt[u][i - 1] + 1;
                    tmp[++y] = val[i];
                }
                else
                {
                    cnt[u][i] = cnt[u][i - 1];
                    sum[u][i] = sum[u][i - 1] + lsh[val[i]];
                    val[++x] = val[i];
                }
            pos[u] = x;
            memcpy(val + x + 1, tmp + 1, sizeof(int[y]));
        }
    }

    int ask(int l, int r)
    {
        int res = 0;
        LL s = 0, c = -(r - l + 1 >> 1);
        ROF(u, m - 1, 0)
            if(int cur = res | (1 << u); cur <= tot &&
            s + sum[u][r] - sum[u][l - 1] + (c + cnt[u][r] - cnt[u][l - 1]) * lsh[cur] >= 0)
            {
                res |= 1 << u;
                s += sum[u][r] - sum[u][l - 1];
                l = pos[u] + cnt[u][l - 1] + 1;
                r = pos[u] + cnt[u][r];
            }
            else
            {
                c += cnt[u][r] - cnt[u][l - 1];
                l -= cnt[u][l - 1];
                r -= cnt[u][r];
            }
        return - (s + (r - l + 1ll) * lsh[res]) / c;
    }
}
xmat, ymat;

void init(int n, int q, VEC x, VEC y)
{
    FOR(i, 1, n) st[0][i] = x[i - 1] + y[i - 1];
    FOR(u, 1, M - 1)
        FOR(i, 1, n - (1 << u) + 1)
            st[u][i] = min(st[u - 1][i], st[u - 1][i + (1 << u - 1)]);
    xmat.init(n, x);
    ymat.init(n, y);
}

int max_prize(int L, int R)
{
    ++L, ++R;
    int k = __lg(R - L + 1);
    return min(min(st[k][L], st[k][R - (1 << k) + 1]), min(xmat.ask(L, R), ymat.ask(L, R)));
}