P7074 [CSP-J2020] 方格取数

· · 题解

思路:

考虑动态规划算法,定义 dp_{i, j} 表示恰好走到第 i 列的第 j 个位置(即之前没有在第 i 列走过,是由 i - 1 列直接走过来的)的最大方案数。

然后转移,枚举是由 i - 1 列的第 k 个走过来的:

dp_{i, j} = \max\Big( \max_{k = 1}^j dp_{i -1, k} + s_{i - 1, j} - s_{i - 1, k - 1}, \max_{k = j}^n dp_{i - 1, k} + s_{i - 1, k} - s_{i - 1, j - 1} \Big)

其中 s_{i, j} 表示第 i 列前 j 个数的和。

朴素转移是 O(N^3) 的(但是好像可以直接过)。

考虑优化,注意到我们要查询一个前缀 dp_{i - 1, k} - s_{i - 1, k - 1} 与一个后缀 dp_{i - 1, k} + s_{i - 1, k} 的最大值。

直接开一个 f_{i ,j}, g_{i, j} 维护即可。

时间复杂度优化至 O(N^2)

完整代码:

#include<bits/stdc++.h>
#include<ext/pb_ds/assoc_container.hpp>
#include<ext/pb_ds/hash_policy.hpp>
#define Add(x, y) (x + y >= mod) ? (x + y - mod) : (x + y)
#define lowbit(x) x & (-x)
#define pi pair<ll, ll>
#define pii pair<ll, pair<ll, ll>>
#define iip pair<pair<ll, ll>, ll>
#define ppii pair<pair<ll, ll>, pair<ll, ll>>
#define ls(k) k << 1
#define rs(k) k << 1 | 1
#define fi first
#define se second
#define full(l, r, x) for(auto it = l; it != r; ++it) (*it) = x
#define Full(a) memset(a, 0, sizeof(a))
#define open(s1, s2) freopen(s1, "r", stdin), freopen(s2, "w", stdout);
#define For(i, l, r) for(register int i = l; i <= r; ++i)
#define _For(i, l, r) for(register int i = r; i >= l; --i)
using namespace std;
using namespace __gnu_pbds;
typedef __int128 __;
typedef long double lb;
typedef double db;
typedef unsigned long long ull;
typedef long long ll;
bool Begin;
const int N = 1010; 
inline ll read(){
    ll x = 0, f = 1;
    char c = getchar();
    while(c < '0' || c > '9'){
        if(c == '-')
          f = -1;
        c = getchar();
    }
    while(c >= '0' && c <= '9'){
        x = (x << 1) + (x << 3) + (c ^ 48);
        c = getchar();
    }
    return x * f;
}
inline void write(ll x){
    if(x < 0){
        putchar('-');
        x = -x;
    }
    if(x > 9)
      write(x / 10);
    putchar(x % 10 + '0');
}
ll ans = -1e18;
int n, m;
int a[N][N];
ll s[N][N], dp[N][N], f[N][N], g[N][N];
bool End;
int main(){
    memset(dp, -0x7f, sizeof(dp));
    memset(f, -0x7f, sizeof(f));
    memset(g, -0x7f, sizeof(g));
    n = read(), m = read();
    for(int i = 1; i <= n; ++i){
        for(int j = 1; j <= m; ++j){
            a[i][j] = read();
            s[j][i] = s[j][i - 1] + a[i][j];
        }
    }
    dp[1][1] = 0, g[1][1] = a[1][1];
    for(int i = 1; i <= n; ++i)
      f[1][i] = 0;
    for(int i = 2; i <= m; ++i){
        for(int j = 1; j <= n; ++j)
          dp[i][j] = max(f[i - 1][j] + s[i - 1][j], g[i - 1][j + 1] - s[i - 1][j - 1]);
        for(int j = 1; j <= n; ++j)
          f[i][j] = max(f[i][j - 1], dp[i][j] - s[i][j - 1]);
        for(int j = n; j >= 1; --j)
          g[i][j] = max(g[i][j + 1], dp[i][j] + s[i][j]);
    }
    for(int i = 1; i <= n; ++i)
      ans = max(ans, dp[m][i] + s[m][n] - s[m][i - 1]);
    write(ans);
    cerr << '\n' << abs(&Begin - &End) / 1048576 << "MB";
    return 0;
}