题解:P10282 [USACO24OPEN] Smaller Averages G

· · 题解

去年 USACO 的题目,比赛时没做出来,今天上课老师正好讲到了,才补出来。如此补题,怎能不爱?

先看数据,对于 N\le10 很显然,直接暴力就行了。对于做题没什么启发。

然后是 N\le80 的数据,这组数据是做题的出发点。我们考虑使用 DP,先解释 f_{i,j} 的含义。f_{i,j} 代表 a_1\sim a_ib_1\sim b_j 中划分相同份数且合法的方案数量。不难得出,在初始时,f_{0,0}=1,而最后的答案位是 f_{n,n}

此时有一个 \mathcal{O}(N^4) 的做法。考虑枚举两个数组的上一个结束的位置 k,l。不难得出方程,在满足

\dfrac{\sum_{pos=k+1}^{i}a_{pos}}{i-k}\le\dfrac{\sum_{pos=l+1}^{j}b_{pos}}{j-l}

的时候, f_{i,j}=f_{i,j}+f_{k,l}。此时的条件含义十分明显,当 a 中在区间 [k+1,i] 中的元素的平均值不大于 b 中在区间 [l+1,j] 中的元素的平均值时,进行转移。

但是此时做法的复杂度完全无法接受,N\le80 时,\mathcal{O}(n^4) 不会超时,但是对于 100\% 的数据,N\le500,可以接受的最大限度为 \mathcal{O}(N^3),这里提一嘴,有几篇题解说最大限度是 \mathcal{O}(N^3\cdot\log N) 是错误的,因为那个复杂度给到的是 85\% 的数据,即 N\le300 时的数据。

所以现在重点在于考虑如何优化。

我们可以尝试优化 l 啦。但是观察后发现平均值无单调性,所以这个时候有一个领先人类智商巅峰一万年的方法。因为 N 本身范围其实不大,所以我们预处理出所有以 i 结尾的区间的平均值。

如果细品一下其实感觉挺有道理的,用 \mathcal{O}(N\cdot N^2) 的时间换掉了 \mathcal{O}(N\cdot N^3)N^3 的一个 N,将时间优化为 \mathcal{O}(N\times(N^2+N^2))=\mathcal{O}(N^3)

a 把以 i 结尾的区间 [l_a,i]\ (l_a\in[1,i]) 按平均值排序,将得到的数组记作 s_1;同时对 b 把以 j 结尾的区间 [l_b,j]\ (l_b\in[1,j]) 按平均值排序,将得到的数组记作 s_2。这时做法也就比较显然了,用双指针维护即可啦。

#include <bits/stdc++.h>
#define int long long
using namespace std;

const int maxn = 505, mod = 1e9 + 7;

int n, a[maxn], b[maxn];
struct node {
    int id;
    long double avg;
    friend bool operator < (node x, node y) {
        if (x.avg != y.avg) return x.avg < y.avg;
        return x.id < y.id;
    }
} s1[maxn][maxn], s2[maxn][maxn];
int s[maxn][maxn], f[maxn][maxn];

signed main() {
    cin >> n;
    for (int i = 1; i <= n; i++) {
        cin >> a[i];
        a[i] += a[i - 1];
        for (int j = 0; j < i; j++) {
            s1[i][j + 1] = (node){j, double(a[i] - a[j]) / (i - j)};
        }
        sort(s1[i] + 1, s1[i] + i + 1);
    }
    for (int i = 1; i <= n; i++) {
        cin >> b[i];
        b[i] += b[i - 1];
        for (int j = 0; j < i; j++) {
            s2[i][j + 1] = (node){j, double(b[i] - b[j]) / (i - j)};
        }
        sort(s2[i] + 1, s2[i] + i + 1);
    }
    f[0][0] = 1;
    for (int i = 1; i <= n; i++) {
        for (int j = 0; j <= n; j++) {
            for (int k = 1; k <= i; k++) {
                s[j][k] = s[j][k - 1] + f[s1[i][k].id][j];
                s[j][k] %= mod;
            }
        }
        for (int j = 1; j <= n; j++) {
            for (int k = 1, l = 1; l <= j; l++) {
                while (k <= i && s1[i][k].avg <= s2[j][l].avg) k++;
                f[i][j] += s[s2[j][l].id][k - 1];
                f[i][j] %= mod;
            }
        }
    }
    cout << f[n][n] << '\n';
    return 0;
}