arc158e题解

· · 题解

考虑扫描线。左端点从右到左枚举,对于所有 (i,1/2) 满足 i\ge l 的格子,记录下这个格子分别到 (l,1)(l,2) 的最短路。

设这两个值在 l+1 分别是 p_1,p_2,扫描到 l 时为 q_1,q_2,能得出转移式子为 q_1=\min(p_1,p_2+b_l)+a_l,q_2=\min(p_1+a_l,p_2)+b_l

能发现 q_1p_2+b_l 转移的条件为 p_2+b_l<p_1p_1-p_2>b_l,同样的 q_2p_1+a_l 转移的条件为 p_1-p_2<-a_l。扫描线时把在 l 右侧所有的格子按照 p_1-p_2 排序,先对于所有 p_1a_l,对于所有 p_2b_l,再暴力修改上述讨论的两种情况的格子,能发现是一个前缀和一个后缀,暴力查暴力修改,单次修改会让 p_1-p_2 不同的格子的数量减少 1,所以修改总复杂度为 \Theta(n\log n)

全体加用两个变量存标记即可。注意取模,不取模会爆 long long。

#include <bits/stdc++.h>
#define endl '\n'
#define ll long long
#define mod 998244353
using namespace std;
const int N = 2e5 + 5;
int n, a[N], b[N];
ll taga, tagb, sa, sb, ans;
struct node{
    ll d, suma, sumb;//a - b
    int sz;
    node() : d(0) {}
    node(ll d) : d(d) {}
    node(ll d, ll suma, ll sumb, int sz) : d(d), suma(suma), sumb(sumb), sz(sz) {}
    friend bool operator < (const node &a, const node &b){
        return a.d < b.d;
    }
};
set<node> st;
int main(){
    ios::sync_with_stdio(0); cin.tie(0); cout.tie(0);
//  freopen("in.txt", "r", stdin);
//  freopen("out.txt", "w", stdout);
    cin >> n;
    for (int i = 1; i <= n; i++)
        cin >> a[i];
    for (int i = 1; i <= n; i++)
        cin >> b[i];
    for (int i = n; i >= 1; i--){
        taga += a[i];
        tagb += b[i];
        auto insert = [&](ll aa, ll bb){
            aa -= taga;
            bb -= tagb;
            set<node>::iterator it = st.lower_bound(node(aa - bb));
            if (it != st.end() && it -> d == aa - bb){
                node nd = *it;
                nd.sz++;
                nd.suma = (nd.suma + aa) % mod;
                nd.sumb = (nd.sumb + bb) % mod;
                st.erase(it);
                st.insert(nd);
            }else
                st.insert(node(aa - bb, aa, bb, 1));
            sa = (sa + aa) % mod;
            sb = (sb + bb) % mod;
        };
        insert(a[i], a[i] + b[i]);
        insert(a[i] + b[i], b[i]);
        {
            node nd;
            nd.d = a[i] - (taga - tagb);
            nd.suma = nd.sumb = 0;
            nd.sz = 0;
            while (!st.empty() && prev(st.end()) -> d >= nd.d){
                set<node>::iterator it = prev(st.end());
                nd.sumb = (nd.sumb + it -> sumb) % mod;
                nd.sz += it -> sz;
                sa = (sa - it -> suma) % mod;
                sb = (sb - it -> sumb) % mod;
                st.erase(it);
            }
            nd.suma = (nd.sumb + nd.d % mod * nd.sz) % mod;
            sa = (sa + nd.suma) % mod;
            sb = (sb + nd.sumb) % mod;
            if (nd.sz)
                st.insert(nd);
        }
        {
            node nd;
            nd.d = -b[i] - (taga - tagb);
            nd.suma = nd.sumb = 0;
            nd.sz = 0;
            while (!st.empty() && st.begin() -> d <= nd.d){
                set<node>::iterator it = st.begin();
                nd.suma = (nd.suma + it -> suma) % mod;
                nd.sz += it -> sz;
                sa = (sa - it -> suma) % mod;
                sb = (sb - it -> sumb) % mod;
                st.erase(it);
            }
            nd.sumb = (nd.suma - nd.d % mod * nd.sz) % mod;
            sa = (sa + nd.suma) % mod;
            sb = (sb + nd.sumb) % mod;
            if (nd.sz)
                st.insert(nd);
        }
        ans = (ans + sa + taga % mod * 2 * (n - i + 1) + sb + tagb % mod * 2 * (n - i + 1)) % mod;
    }
    for (int i = 1; i <= n; i++)
        ans = (ans - (a[i] + b[i])) % mod;
    ans = 1ll * ans * 2 % mod;
    for (int i = 1; i <= n; i++)
        ans = (ans - a[i] - b[i]) % mod;
    cout << (ans + mod) % mod;
    return 0; 
}