P8563 Magenta Potion 题解

· · 题解

前言

我打算在这篇题解中将线段树做法和线性贪心做法都写出来,以拓宽大家的思路,让大家有所收获。

ps:线段树篇幅略长,想直接看简单的贪心可以直接跳过。

解题思路

先抛开修改操作,思考一下如何计算区间的最大乘积。

因为 \left | k \right |,\left | a_i \right | \ge 2,那么在保证乘积为正数时,乘的数字越多,乘积越大。

由于负数的特殊性,所以要根据负数的个数进行分类讨论。

\text{Solution1}

根据以上思路,我们需要维护的区间内的值有这些:负数出现的位置,负数的个数,区间乘积。

这启发我们采用线段树进行操作。

在查询区间最左边的负数位置时,可以用二分找出最小的前缀负数个数为 1 的位置,查询区间最右边的负数位置时,也是如此。时间复杂度为 O(\log^2n)

由于是单点修改,那么就不需要使用懒标记维护区间乘积了,可以直接维护。查询和修改的时间复杂度均为 O(\log n)

查询最终答案时,根据上面思路直接写,时间复杂度 O(\log n)

综上,时间复杂度为 O(n \log^2 n)

值得一提的是,根据数据范围发现,如果直接维护区间乘积,这个值是非常非常巨大的,并且无法存放。

那我们可以这样操作:由于负数在最终乘完后也变为正数,那么在计算的时候判断运算中的值的绝对值是否大于等于 2^{30} + 1,如果符合,那就将其修改为 2^{30}+1。这样做减小了数据规模,并且最终计算得到的结果也会大于 2^{30} ,不影响正确性。这样就可以存储下来了。

给出较简洁且部分注释的代码:

#include <bits/stdc++.h>

using namespace std;
typedef long long ll;

const int N = 2e5 + 7;
const ll P = (1ll << 30) + 1;

int n, q;
ll a[N], sum[N << 2]; //sum维护区间乘积
int num[N << 2], pos[N << 2];//num维护区间负数个数,pos维护区间负数位置

#define _max(a, b) (a > b ? a : b)
#define _min(a, b) (a < b ? a : b)

ll change (ll x) {
    return abs (x) >= P ? P : x;
}//将一个绝对值大于等于 2^30+1 的数字改为 2^30+1

struct SegMentTree {
    void pushup (int p) {
        sum[p] = change (sum[p << 1] * sum[p << 1 | 1]);//维护区间乘积
        num[p] = num[p << 1] + num[p << 1 | 1];//维护区间负数个数
    }

    void build (int p, int l, int r) {
        if (l == r) {
            sum[p] = change (a[l]);//减小数据规模
            if (a[l] < 0) {
                num[p] = 1;
                pos[p] = l;
            }
            return ;
        }
        int mid = (l + r) >> 1;
        build (p << 1, l, mid);
        build (p << 1 | 1, mid + 1, r);

        pushup (p);
    }//建树

    void update (int p, int l, int r, int qp, int k) {
        if (l == r) {
            sum[p] = change (k);//减小数据规模
            if (k < 0) {
                num[p] = 1, pos[p] = l;
            }
            else {
                num[p] = 0, pos[p] = 0;
            }
            return ;
        }
        int mid = (l + r) >> 1;
        if (qp <= mid) {
            update (p << 1, l, mid, qp, k);
        }
        if (qp > mid) {
            update (p << 1 | 1, mid + 1, r, qp, k);
        }

        pushup (p);
    }//单调修改直接暴力 logn 即可

    int query_num (int p, int l, int r, int ql, int qr) {
        if (ql <= l && qr >= r) {
            return num[p];
        }
        int mid = (l + r) >> 1, ans = 0;
        if (ql <= mid) {
            ans += query_num (p << 1, l, mid, ql, qr);
        }
        if (qr > mid) {
            ans += query_num (p << 1 | 1, mid + 1, r, ql, qr);
        }
        return ans;
    }//查询区间负数个数

    int find_l (int l, int r) {
        int ans = 0;
        while (l <= r) {//二分
            int mid = (l + r) >> 1;
            if (query_num (1, 1, n, l, mid)) {
                r = mid - 1;
                ans = mid;
                //找到最小的mid满足前缀只有一个负数,那么这就是最左边的负数位置
            }
            else {
                l = mid + 1;
            }
        }
        return ans;
    }

    int find_r (int l, int r) {
        int ans = 0;
        while (l <= r) {//二分
            int mid = (l + r) >> 1;
            if (query_num (1, 1, n, mid, r)) {
                l = mid + 1;
                ans = mid;
                //找到最大的mid满足后缀只有一个负数,那么这就是最右边的负数位置
            }
            else {
                r = mid - 1;
            }
        }
        return ans;
    }

    ll query_sum (int p, int l, int r, int ql, int qr) {
        if (ql > qr) {
            return 1;
        }
        if (ql <= l && qr >= r) {
            return change (sum[p]);//减小数据规模
        }
        int mid = (l + r) >> 1;
        ll ans = 1;
        if (ql <= mid) {
            ans = ans * query_sum (p << 1, l, mid, ql, qr);
        }
        if (qr > mid) {
            ans = ans * query_sum (p << 1 | 1, mid + 1, r, ql, qr);
        }
        return change (ans);//减小数据规模
    }//查询区间乘积

    ll query_ans (int p, int l, int r) {
        int tot = query_num (1, 1, n, l, r);
        if (! (tot % 2)) {
            return _min (query_sum (1, 1, n, l, r), P);
            //偶数个负数时,答案就是整个区间乘积
        }
        else {
            int nl = find_l (l, r);
            int nr = find_r (l, r);
            if (nl == nr) {
                if (l == r) {//1个负数且区间长度为1
                    return 1;
                }
                else {//1个负数且区间长度大于1
                    ll ans1 = abs (query_sum (1, 1, n, l, nl - 1));//负数左侧的乘积(不包括负数)
                    ll ans2 = abs (query_sum (1, 1, n, nl + 1, r));//负数右侧的乘积(不包括负数)
                    return _min (_max (ans1, ans2), P);
                }
            }
            else {
                ll ansl = abs (query_sum (1, 1, n, l, nl));//最左边负数左侧的乘积(包括负数)
                ll ansmid = abs (query_sum (1, 1, n, nl + 1, nr - 1));//最左边负数和最右边负数之间的乘积(不包括负数)
                ll ansr = abs (query_sum (1, 1, n, nr, r));//最右边负数右侧的乘积(包括负数)
                return _min (_max (ansl, ansr) * ansmid, P);//ansl和ansr的较大值乘以ansmid就是区间最大乘积,可以理解一下
            }//与2^30+1取min也是为了减小数据规模
        }
    }//查询区间最大乘积
} DS;

int main () {
    scanf ("%d%d", &n, &q);
    for (int i = 1; i <= n; i ++) {
        scanf ("%lld", &a[i]);
    }
    DS.build (1, 1, n);
    while (q --) {
        int op, x, y;
        scanf ("%d%d%d", &op, &x, &y);
        if (op == 1) {
            DS.update (1, 1, n, x, y);
        }
        else if (op == 2) {
            if (DS.query_ans (1, x, y) >= P) {
                printf ("Too large\n");
            }
            else {
                printf ("%lld\n", DS.query_ans (1, x, y));
            }
        }
    }

    return 0;
}

\text{Solution2}

观察到题目要求答案大于 2^{30} 时,就输出 Too large

而且 \left | k \right |,\left | a_i \right | \ge 2,仔细想想这有什么关联?

没错,当一个区间的长度过长时,不用计算这个区间的最大乘积,我们也知道这个最大乘积大于 2^{30}

因为每个值最小为 2,如果超过 30 个值,那么这个区间最大乘积肯定超过了 2^{30}

为了保险起见,当查询的区间长度大于 70 时,就可以直接输出 Too large

那么这样,对于区间长度小于等于 60 时,就可以根据上文的方法算出最大乘积并判断,且这个数字也可以存储下来。

这样的总复杂度为 O(n),代码短并且效率高。

给出代码:

#include <bits/stdc++.h>

using namespace std;
typedef long long ll;

const int N = 2e5 + 7;
const ll P = (1ll << 30);

int n, q;
ll a[N]; 

int main () {
    scanf ("%d%d", &n, &q);
    for (int i = 1; i <= n; i ++) {
        scanf ("%lld", &a[i]);
    }
    while (q --) {
        int op, l, r;
        scanf ("%d%d%d", &op, &l, &r);
        if (op == 1) {
            a[l] = r;//O1直接修改
        }
        else if (op == 2) {
            if (r - l + 1 > 70) {
                printf ("Too large\n");//长度过长 直接输出
            } 
            else {
                int tot = 0;
                for (int i = l; i <= r; i ++) {
                    if (a[i] < 0) {
                        tot ++;//负数个数
                    }
                }
                bool output = false;
                ll ans = 1, p, t = 1ll;
                if (!(tot % 2)) {//偶数个负数
                    for (int i = l; i <= r; i ++) {
                        ans = ans * abs (a[i]);
                        if (ans > P) {//在运算过程中 时刻防止溢出
                            printf ("Too large\n");
                            output = true;
                            break;
                        }
                    }
                    if (!output) {
                        printf ("%lld\n", ans);
                    }
                }
                else {//奇数个负数
                    for (int i = l; i <= r; i ++) {
                        if (a[i] < 0) {
                            p = i;
                            break;
                        }
                    }
                    for (int i = r; i > p; i --) {
                        t = t * a[i];
                        if (t > P) {//在运算过程中 时刻防止溢出
                            printf ("Too large\n");
                            output = true;
                            break;
                        }
                    }
                    if (output) {
                        continue;
                    }
                    ans = max (ans, t);
                    t = 1;
                    for (int i = r; i >= l; i --) {
                        if (a[i] < 0) {
                            p = i;
                            break;
                        }
                    }
                    for (int i = l; i < p; i ++) {
                        t = t * a[i];
                        if (t > P) {//在运算过程中 时刻防止溢出
                            printf ("Too large\n");
                            output = true;
                            break;
                        }
                    }
                    ans = max (ans, t);
                    if (!output) {
                        if (ans > P) {//在运算过程中 时刻防止溢出
                            printf ("Too large\n");
                        }
                        else {
                            printf ("%lld\n", ans);
                        }
                    }
                }
            }
        }
    }
    return 0;
}

后言

本人赛时线段树写挂成 10 分,没有看到这个 Too large 的作用,只是将其当做一个限制,这告诉我们要好好读题

如果读清楚题,想到这个作用,这题就迎刃而解了,而且代码编写简单。(哭)

希望大家不要学我/bx