题解:P9049 [PA 2021] Mopadulo

· · 题解

题目传送门

考虑暴力:

暴力代码如下:

dp[0] = 1;
for(int i = 1;i <= n;i++){
    for(int j = 0;j < i;j++){
        if((s[i] - s[j] + inf) % inf % 2 == 0){
            dp[i] += dp[j];
            dp[i] %= inf;
        }
    }
}

优化

但这样会 O(n^2) 超时,所以我们需要记录下从 1 \to i - 1 的所有可行的方案数,有以下结论:

s_i - s_j = \begin{cases} (s_i - s_j)\mod10^9+7 & s_j \le s_i 且 s_i&s_j 奇偶性相同\\ (s_i - s_j + 10^9 + 7)\mod 10^9 + 7 & s_j > s_i 且 s_i&s_j 奇偶性不同\end{cases}

我们只需对 s 数组建两棵棵线段树,一棵存奇,一棵存偶,每次求完 dp_i 后单点修改,再进行区间查询记录答案即可。

代码附上——

#include<bits/stdc++.h>
#define int long long
using namespace std;
const int maxn = 300010;
const int inf = 1e9 + 7;
//unsigned long long
//cout << fixed << setprecision(3)
//cout << setw(5) <<
//continue
int a[maxn], dp[maxn], s[maxn], s2[maxn];
struct S {
    int l, r, sum;
} f[maxn * 4], f1[maxn * 4];
//偶线段树     奇线段树
void up(int x) {
    f[x].sum = f[x * 2].sum + f[x * 2 + 1].sum;
}
void built(int u, int l, int r) {
    f[u].l = l, f[u].r = r;
    if(l == r) return ;
    int mid = (l + r) / 2;
    built(u * 2, l, mid);
    built(u * 2 + 1, mid + 1, r);
}
void add(int u, int x, int p) {
    if(f[u].l == f[u].r) f[u].sum += p;
    else {
        int mid = (f[u].l + f[u].r) / 2;
        if(x <= mid) add(u * 2, x, p);
        else add(u * 2 + 1, x, p);
        up(u);
    }
}
int q(int u, int l, int r) {
    if(f[u].l >= l && f[u].r <= r) return f[u].sum;
    int mid = (f[u].l + f[u].r) / 2, ans = 0;
    if(l <= mid) ans += q(u * 2, l, r);
    if(r > mid) ans += q(u * 2 + 1, l, r);
    up(u);
    return ans;
}
void up1(int x) {
    f1[x].sum = f1[x * 2].sum + f1[x * 2 + 1].sum;
}
void built1(int u, int l, int r) {
    f1[u].l = l, f1[u].r = r;
    if(l == r) return ;
    int mid = (l + r) / 2;
    built1(u * 2, l, mid);
    built1(u * 2 + 1, mid + 1, r);
}
void add1(int u, int x, int p) {
    if(f1[u].l == f1[u].r) f1[u].sum += p;
    else {
        int mid = (f1[u].l + f1[u].r) / 2;
        if(x <= mid) add1(u * 2, x, p);
        else add1(u * 2 + 1, x, p);
        up1(u);
    }
}
int q1(int u, int l, int r) {
    if(f1[u].l >= l && f1[u].r <= r) return f1[u].sum;
    int mid = (f1[u].l + f1[u].r) / 2, ans = 0;
    if(l <= mid) ans += q1(u * 2, l, r);
    if(r > mid) ans += q1(u * 2 + 1, l, r);
    up1(u);
    return ans;
}
signed main() {
    //freopen("a.in", "r", stdin);
    //freopen("a.out", "w", stdout);
    ios::sync_with_stdio(false);
    cin.tie(0);
    cout.tie(0);
    int n;
    cin >> n;
    for(int i = 1; i <= n; i++) {
        cin >> a[i];
        s[i] = s[i - 1] + a[i];
        s[i] %= inf;
        s2[i] = s[i];
    }
    sort(s2 + 1, s2 + 1 + n);
    int len = unique(s2 + 1, s2 + 1 + n) - s2 - 1;
        //离散化
    built(1, 1, len + 1);
    built1(1, 1, len + 1);
        //建树
    for(int i = 1; i <= n; i++) {
        int p = lower_bound(s2 + 1, s2 + 1 + len, s[i]) - s2;
                //找第一个小于等于s[i]的数的位置
        int p1 = upper_bound(s2 + 1, s2 + 1 + len, s[i]) - s2;
                //找第一个大于s[i]的数的位置
        if(s[i] % 2 == 0) {
            dp[i] = 1;
            dp[i] += q(1, 1, p) + q1(1, p1, len);
        } else {
            dp[i] += q1(1, 1, p) + q(1, p1, len);
        }
        dp[i] %= inf;
        if(s[i] % 2 == 0) {
            add(1, p, dp[i]);
        } else {
            add1(1, p, dp[i]);
        }
    }
    cout << dp[n] % inf;
    return 0;
}

感谢阅读!