ARC158 E 题解
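考虑扫描线。左端点 $i$ 从右到左枚举,对于所有位于第 $j \ge i$ 列的格子 $Y$,维护二元组 $(x_Y, y_Y)$,其中 $x_Y$、$y_Y$ 分别是从第 $i$ 列上格、下格出发到 $Y$ 的最短路(路径上所有格子权值之和)。
设这两个值在左端点为 $i+1$ 时为 $(x, y)$,则左端点移到 $i$ 后变为 $x' = a_i + \min(x,\ b_i + y)$,$y' = b_i + \min(y,\ a_i + x)$;第 $i$ 列自身的两个格子直接以 $(a_i,\ a_i+b_i)$ 与 $(a_i+b_i,\ b_i)$ 加入。
能发现这等价于先全体 $x \gets x + a_i$、$y \gets y + b_i$,再把差值 $d = x - y$ 截断到 $[-b_i,\ a_i]$:若 $d > a_i$ 则令 $x = y + a_i$,若 $d < -b_i$ 则令 $y = x + b_i$。因此用 `set` 按 $d$ 存储所有二元组(相同 $d$ 合并为一个结点,记录个数及 $x$、$y$ 之和),每次截断只会从两端弹出若干结点并合成一个新结点,均摊 $O(n \log n)$。每一步把当前所有二元组的 $x + y$ 之和累加进答案。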
全体加用两个变量 $taga$、$tagb$ 存懒标记即可。注意取模:差值 $d$ 需要精确保存,但各项之和不取模会爆 long long;另外 $a_i + b_i$ 可达 $2 \times 10^9$,必须先转成 long long 再相加,否则会溢出 int。
#include <bits/stdc++.h>
// 64-bit integer shorthand used throughout the solution.
using ll = long long;
// Prime modulus of the answer.
constexpr ll mod = 998244353;
using namespace std;
// Array capacity; columns are 1-indexed, so this is sized for n up to 2e5.
constexpr int N = 200000 + 5;
// n columns; a[i] / b[i] are the weights of the top / bottom cell of column i.
// Stored as ll because a[i] + b[i] is formed in several places and can exceed
// INT_MAX when the weights are near 1e9.
int n;
ll a[N], b[N];
// taga / tagb: lazy amounts implicitly added to every stored top / bottom value.
// sa / sb: running sums (mod `mod`) of the stored top / bottom values.
// ans: running answer accumulator.
ll taga, tagb, sa, sb, ans;
// One group of target cells that share the same stored key d = x - y, where
// x / y are the stored (tag-relative) top / bottom values of a target.
// suma / sumb are congruent modulo `mod` to the sums of x / y over the group,
// and sz is the number of targets in the group.
struct node{
    ll d = 0;
    ll suma = 0;
    ll sumb = 0;
    int sz = 0;
    node() = default;
    // Key-only node, used as a probe for set::lower_bound.
    explicit node(ll d) : d(d) {}
    node(ll d, ll suma, ll sumb, int sz) : d(d), suma(suma), sumb(sumb), sz(sz) {}
    // Groups are ordered (and, inside std::set, deduplicated) by key alone.
    friend bool operator < (const node &a, const node &b){
        return a.d < b.d;
    }
};
// All groups, ordered by d; each key appears at most once.
set<node> st;
int main(){
ios::sync_with_stdio(0); cin.tie(0); cout.tie(0);
// freopen("in.txt", "r", stdin);
// freopen("out.txt", "w", stdout);
cin >> n;
for (int i = 1; i <= n; i++)
cin >> a[i];
for (int i = 1; i <= n; i++)
cin >> b[i];
for (int i = n; i >= 1; i--){
taga += a[i];
tagb += b[i];
auto insert = [&](ll aa, ll bb){
aa -= taga;
bb -= tagb;
set<node>::iterator it = st.lower_bound(node(aa - bb));
if (it != st.end() && it -> d == aa - bb){
node nd = *it;
nd.sz++;
nd.suma = (nd.suma + aa) % mod;
nd.sumb = (nd.sumb + bb) % mod;
st.erase(it);
st.insert(nd);
}else
st.insert(node(aa - bb, aa, bb, 1));
sa = (sa + aa) % mod;
sb = (sb + bb) % mod;
};
insert(a[i], a[i] + b[i]);
insert(a[i] + b[i], b[i]);
{
node nd;
nd.d = a[i] - (taga - tagb);
nd.suma = nd.sumb = 0;
nd.sz = 0;
while (!st.empty() && prev(st.end()) -> d >= nd.d){
set<node>::iterator it = prev(st.end());
nd.sumb = (nd.sumb + it -> sumb) % mod;
nd.sz += it -> sz;
sa = (sa - it -> suma) % mod;
sb = (sb - it -> sumb) % mod;
st.erase(it);
}
nd.suma = (nd.sumb + nd.d % mod * nd.sz) % mod;
sa = (sa + nd.suma) % mod;
sb = (sb + nd.sumb) % mod;
if (nd.sz)
st.insert(nd);
}
{
node nd;
nd.d = -b[i] - (taga - tagb);
nd.suma = nd.sumb = 0;
nd.sz = 0;
while (!st.empty() && st.begin() -> d <= nd.d){
set<node>::iterator it = st.begin();
nd.suma = (nd.suma + it -> suma) % mod;
nd.sz += it -> sz;
sa = (sa - it -> suma) % mod;
sb = (sb - it -> sumb) % mod;
st.erase(it);
}
nd.sumb = (nd.suma - nd.d % mod * nd.sz) % mod;
sa = (sa + nd.suma) % mod;
sb = (sb + nd.sumb) % mod;
if (nd.sz)
st.insert(nd);
}
ans = (ans + sa + taga % mod * 2 * (n - i + 1) + sb + tagb % mod * 2 * (n - i + 1)) % mod;
}
for (int i = 1; i <= n; i++)
ans = (ans - (a[i] + b[i])) % mod;
ans = 1ll * ans * 2 % mod;
for (int i = 1; i <= n; i++)
ans = (ans - a[i] - b[i]) % mod;
cout << (ans + mod) % mod;
return 0;
}