题解:P10764 [BalticOI 2024] Wall

· · 题解

硬拆 2 的次幂的做法很不友好,我来个好想好写的做法。

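对于 h 序列固定的情况,答案是 $\sum_{i = 1}^{n} \left(\min(pr_i, su_i) - h_i\right) = \sum_{i = 1}^{n} \left(pr_i + su_i - global - h_i\right)$。其中 $pr_i, su_i$ 分别是位置 $i$ 处(含 $i$ 本身)的前缀、后缀最大值,$global$ 是全局最大值;等号成立是因为 $\max(pr_i, su_i) = global$,故 $\min(pr_i, su_i) = pr_i + su_i - global$。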

于是考虑对全部 $2^n$ 种情况,直接分别求出 $\sum pr$、$\sum su$、$\sum global$、$\sum h$ 的总和。

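设 DP 状态 $f_{i, x}$ 表示考虑前 $i$ 个元素、其 $\max h = x$ 的方案数,则 $\sum_{i = 1}^{n} 2^{n - i} \sum_{x} x f_{i, x}$ 就是 $\sum pr$(后 $n - i$ 个元素任选,故乘 $2^{n - i}$)。枚举 $i$,把 $f$ 挂在线段树上维护;$\sum su$ 同理(倒序再做一遍);$\sum global$ 考虑 $\sum_{x} x f_{n, x}$,由于每个位置都要减去一次 $global$,最终还要乘 $n$;$\sum h$ 等于 $2^{n - 1} \sum h$,其中 $\sum h$ 对全部 $2n$ 个候选高度求和,因为每个候选高度恰好在 $2^{n - 1}$ 种情况中被选中。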

因此写一个线段树优化 DP,把系数放进线段树里算就可以了。

/* Good Game, Well Play. */
#include <bits/stdc++.h>
#define lowbit(x) ((x) & (-(x)))
using namespace std;

// N: array capacity (upper bound on n plus slack); mod: modulus for all arithmetic.
const int N = 500010;
const int mod = 1000000007;

// Adds `add` to `now` in place and, if the result reached `mod`, reduces it
// into [0, mod). A result below `mod` (including a negative one) is left as is.
inline void Plus(long long &now, long long add)
{
    now += add;
    if(now >= mod) now %= mod;
}

int n;             // number of positions, read in main
int h[2][N];       // h[c][j]: option c at position j; main stores the raw height, then its rank
long long pw[N];   // pw[k] = 2^k modulo mod, filled in main

// One candidate height, used for coordinate compression:
// `id` is the option index (0 or 1), `pos` the position, `val` the height.
struct Disc
{
    int id;
    int pos;
    long long val;
};
Disc disc[N * 2];   // 1-based; main sorts it by ascending val
int dn;             // number of entries currently stored in disc

// Comparator for std::sort: orders candidates by ascending height.
inline bool cmp_disc(Disc u, Disc v)
{
    return v.val > u.val;
}

int rt;    // id of the root node, assigned by build
int idx;   // id of the most recently allocated node

// Segment tree node, indexed over the ranks of the candidate heights.
struct SGT
{
    int ls, rs;      // ids of the left and right child
    long long cnt;   // sum of the DP counts stored in this range (mod)
    long long sum;   // sum of height * count over this range (mod)
    long long tmp;   // leaf: the height at this rank; inner node: sum of the children's tmp (mod)
    long long mul;   // pending multiplier not yet applied to the children
} tree[N * 4];

// Field accessors by node id.
#define ls(x) tree[x].ls
#define rs(x) tree[x].rs
#define cnt(x) tree[x].cnt
#define sum(x) tree[x].sum
#define tmp(x) tree[x].tmp
#define mul(x) tree[x].mul
inline void pushup(int now)
{
    cnt(now) = (cnt(ls(now)) + cnt(rs(now))) % mod;
    sum(now) = (sum(ls(now)) + sum(rs(now))) % mod;
}
// Multiplies the whole subtree of `now` by `tg`: the node's aggregates are
// scaled immediately, and the factor is folded into its pending tag so the
// children receive it on the next pushdown.
inline void push_mul(int now, long long tg)
{
    cnt(now) *= tg; cnt(now) %= mod;
    sum(now) *= tg; sum(now) %= mod;
    mul(now) *= tg; mul(now) %= mod;
}
// Hands the pending multiplier of `now` to both children and resets it to 1.
// Does nothing when the tag is already the identity.
inline void pushdown(int now)
{
    const long long tag = mul(now);
    if(tag == 1) return ;
    push_mul(ls(now), tag);
    push_mul(rs(now), tag);
    mul(now) = 1;
}
// Allocates the subtree covering ranks [l, r] and stores its root id in `now`.
// Each node starts with multiplier 1; a leaf records the height of its rank
// in tmp, an inner node the sum of its children's tmp (mod).
inline void build(int &now, int l, int r)
{
    now = ++idx;
    mul(now) = 1;
    if(l != r)
    {
        const int mid = (l + r) >> 1;
        build(ls(now), l, mid);
        build(rs(now), mid + 1, r);
        tmp(now) = (tmp(ls(now)) + tmp(rs(now))) % mod;
    }
    else tmp(now) = disc[l].val;
}
// Multiplies every entry with rank in [L, R] by `num`.
// `now` covers [l, r]; an empty query range (L > R) is a no-op.
inline void range_mul(int now, int l, int r, int L, int R, long long num)
{
    if(L > R) return ;
    const bool covered = (l >= L) && (R >= r);
    if(covered)
    {
        push_mul(now, num);
        return ;
    }
    pushdown(now);
    const int mid = (l + r) >> 1;
    if(mid >= L) range_mul(ls(now), l, mid, L, R, num);
    if(R > mid) range_mul(rs(now), mid + 1, r, L, R, num);
    pushup(now);
}
// Adds `num` to the count stored at rank `pos`, and `num` times that rank's
// height to its weighted sum; ancestors are refreshed on the way back up.
inline void single_add(int now, int l, int r, int pos, int num)
{
    if(l != r)
    {
        pushdown(now);
        const int mid = (l + r) >> 1;
        if(pos > mid) single_add(rs(now), mid + 1, r, pos, num);
        else single_add(ls(now), l, mid, pos, num);
        pushup(now);
        return ;
    }
    // Leaf: tmp(now) holds the height belonging to this rank.
    cnt(now) = (cnt(now) + num) % mod;
    sum(now) = (sum(now) + tmp(now) * num) % mod;
}
inline int query_cnt(int now, int l, int r, int L, int R)
{
    if(L > R) return 0;
    if(L <= l && r <= R) return cnt(now);
    pushdown(now); int mid = (l + r) >> 1;
    if(R <= mid) return query_cnt(ls(now), l, mid, L, R);
    else if(mid < L) return query_cnt(rs(now), mid + 1, r, L, R);
    else return (query_cnt(ls(now), l, mid, L, R) + query_cnt(rs(now), mid + 1, r, L, R)) % mod;
}

// Accumulators over all 2^n assignments, each kept reduced modulo mod:
//   LN     - sum over positions of the prefix maximum
//   RN     - sum over positions of the suffix maximum
//   GLOBAL - the overall maximum (counted once per assignment)
//   SELF   - sum over positions of the chosen height
long long LN, RN, GLOBAL, SELF;

// Runs the running-maximum DP over positions first, ..., last (ascending when
// first <= last, descending otherwise) on a segment tree that must currently
// hold all-zero counts. After the k-th visited position the tree holds, per
// rank x, the number of ways the first k visited positions have maximum x;
// sum(rt) is then weighted by 2^(n-k) for the still-free positions.
// Returns the accumulated total modulo mod.
static long long sweep_running_max(int first, int last)
{
    const int step = (first <= last) ? 1 : -1;
    const int m = 2 * n;
    long long total = 0;
    int seen = 0;
    for(int i = first; ; i += step)
    {
        if(seen == 0)
        {
            // The first position: each of its two options is the maximum in one way.
            single_add(rt, 1, m, h[0][i], 1);
            single_add(rt, 1, m, h[1][i], 1);
        }
        else
        {
            const int p_min = min(h[0][i], h[1][i]), p_max = max(h[0][i], h[1][i]);
            // Old maxima below an option are replaced by that option. Both
            // queries must run before the range multiplications below.
            single_add(rt, 1, m, p_max, query_cnt(rt, 1, m, 1, p_max - 1));
            single_add(rt, 1, m, p_min, query_cnt(rt, 1, m, 1, p_min - 1));
            // A maximum below both options cannot survive; one above both
            // survives either choice; one strictly between survives once.
            range_mul(rt, 1, m, 1, p_min - 1, 0);
            range_mul(rt, 1, m, p_max + 1, m, 2);
        }
        ++seen;
        Plus(total, sum(rt) * pw[n - seen] % mod);
        if(i == last) break;
    }
    return total;
}

// Reads n and the two candidate heights per position, then prints
// (LN + RN - n * GLOBAL - SELF) modulo mod.
int main()
{
//  freopen("text.in", "r", stdin);
//  freopen("prog.out", "w", stdout);
    ios::sync_with_stdio(false);
    cin.tie(0), cout.tie(0);

    // A failed read leaves n == 0, and build(rt, 1, 0) would never terminate;
    // with no positions the sum is empty, so the answer is 0.
    if(!(cin >> n) || n < 1) {cout << 0 << '\n'; return 0;}
    // The global arrays are sized by N; refuse inputs that would overrun them.
    if(n >= N) {cerr << "n exceeds the compiled capacity\n"; return 1;}

    for(int i = 0; i <= 1; ++i)
        for(int j = 1; j <= n; ++j)
            cin >> h[i][j], disc[++dn] = {i, j, h[i][j]};

    // Coordinate compression: replace every height by its 1-based rank in
    // sorted order (ties get distinct ranks), keeping the value in disc[rank].val.
    sort(disc + 1, disc + dn + 1, cmp_disc);
    for(int i = 1; i <= dn; ++i) h[disc[i].id][disc[i].pos] = i;
    pw[0] = 1; for(int i = 1; i <= n; ++i) pw[i] = pw[i - 1] * 2 % mod;

    build(rt, 1, 2 * n);
    LN = sweep_running_max(1, n);
    push_mul(rt, 0);   // wipe the tree before the reverse sweep
    RN = sweep_running_max(n, 1);

    // Every candidate height is chosen in exactly 2^(n-1) assignments.
    for(int k = 1; k <= dn; ++k)
        Plus(SELF, disc[k].val * pw[n - 1] % mod);

    // Scan candidates from the largest down. `ways` is the number of
    // assignments in which the current candidate is the overall maximum:
    // each new position among the larger candidates halves the free choices.
    // When a position shows up a second time, its smaller option is forced
    // whenever the larger one is rejected, so nothing below can be the maximum.
    long long ways = pw[n];
    const long long inv2 = (mod + 1) / 2;
    vector<char> used(n + 1, 0);   // sized to n instead of a fixed stack array
    for(int i = 2 * n; i >= 1; --i)
    {
        const bool repeated = used[disc[i].pos] != 0;
        if(!repeated)
        {
            ways = ways * inv2 % mod;
            used[disc[i].pos] = 1;
        }
        Plus(GLOBAL, ways * disc[i].val % mod);
        if(repeated) break;
    }

    // Every term is in [0, mod); adding 2 * mod keeps the value non-negative.
    cout << (LN + RN - GLOBAL * n % mod - SELF + 2LL * mod) % mod << '\n';
    return 0;
}
/*

*/