题解 P6772 【[NOI2020]美食家】

· · 题解

蒟蒻语

wtcl, 只会最简单的题目

这道题目与 P6569 [NOI Online #3 提高组]魔法值(民间数据) 类似, 都是倍增优化矩阵乘法。

蒟蒻解

首先观察数据, 看到 w_i \le 5, 可以想到储存下前 5 天的状态, 从而推出现在的状态。

显然着这样是 O(nwT) 的, 不可通过次题。

但是考虑到 n 很小, 可以使用矩阵乘法优化。然后按照图建立矩阵 a。大概是这样子的:

原来的数组(now表示转移结束后的状态):

now-5 now-4 now-3 now-2 now-1

要变成:

now-4 now-3 now-2 now-1 now

转移矩阵 a :

0 0 0 0 w = 5 的转移
S 0 0 0 w = 4 的转移
0 S 0 0 w = 3 的转移
0 0 S 0 w = 2 的转移
0 0 0 S w = 1 的转移
形如这样子的: | $1$ | $0$ | $0$ | | -----------: | -----------: | -----------: | | $0$ | $1$ | $0$ | | $0$ | $0$ | $1$ | 矩阵快速幂即可。时间复杂度为 $O((nw)^3\log T)

但是这只是没有美食节的情况, k = 0。对于每次美食节, 可以先把答案算到美食节的时间 -1。然后对原来建立的矩阵进行矩阵快速幂, 时间复杂度 O(k(nw)^3logT), 不可通过本题。

考虑继续优化。矩阵快速幂多次计算了一个矩阵的 2^t 次方, 考虑优化调这个时间。倍增预处理出原矩阵的 2^t 次方。这部分的代码:

void build() {
    btd[0] = a; // btd 是倍增矩阵。
    for(int i = 1; i <= 30; i++) btd[i] = btd[i - 1] * btd[i - 1]; 
    // 利用2^{t-1}矩阵的信息处理出2^t的矩阵。
}

求值时拿原来的序列乘以需要乘的矩阵。那么这部分的代码是:

void get(ll *f, int b) {
    for(int i = 0; i <= 30; i++) if(b & (1 << i)) cf(f, btd[i]);
    // cf : 乘法, 代表数组(1*(nw)的矩阵)乘以矩阵
}

预处理时间复杂度是 (nw^3)\log T, 而单次求值是 (nw^2)logT, 求 m 次就是 m(nw)^2 \log T, 总时间复杂度是 (nw^3)\log T + m(nw)^2 \log T, 可以通过本题。

蒟蒻码

int mn;
struct M {
    ll a[255][255];
} a;
M clear() {
    M res;
    for(int i = 1; i <= mn; i++) for(int j = 1; j <= mn; j++) res.a[i][j] = -inf;
    return res;
}
M operator * (M aa, M bb) {
    M res = clear();
    for(int i = 1; i <= mn; i++) 
        for(int j = 1; j <= mn; j++) 
            for(int k = 1; k <= mn; k++)
                res.a[i][j] = max(res.a[i][j], aa.a[i][k] + bb.a[k][j]);
    return res;
}
M btd[33];
void cf(ll *f, M b) {
    ll fz[255];
    for(int i = 1; i <= mn; i++) fz[i] = f[i], f[i] = -inf;
    for(int i = 1; i <= mn; i++) 
        for(int j = 1; j <= mn; j++) 
            f[i] = max(f[i], fz[j] + b.a[j][i]);
}
void get(ll *f, int b) {
    for(int i = 0; i <= 30; i++) if(b & (1 << i)) cf(f, btd[i]);
}
void build() {
    btd[0] = a;
    for(int i = 1; i <= 30; i++) btd[i] = btd[i - 1] * btd[i - 1];
}
struct node {
    int t, x, y;
} msj[2333];
bool cmp(node aa, node bb) {
    return aa.t < bb.t;
}
int T, n, m, k, c[55];
ll ans[255];
int main() {
    scanf("%d%d%d%d", &n, &m, &T, &k), mn = n * 5;
    for(int i = 1; i <= n; i++) scanf("%d", &c[i]);
    a = clear();
    for(int i = 1; i <= 4 * n; i++) a.a[i + n][i] = 0;
    for(int i = 1; i <= m; i++) {
        int u, v, w;
        scanf("%d%d%d", &u, &v, &w);
        a.a[u + n * (5 - w)][v + n * 4] = c[v];
    }
    build();
    for(int i = 1; i <= k; i++) 
        scanf("%d%d%d", &msj[i].t, &msj[i].x, &msj[i].y);
    sort(msj + 1, msj + k + 1, cmp);
    for(int i = 1; i <= mn; i++) ans[i] = -inf;
    ans[n * 4 + 1] = c[1];
    for(int i = 1; i <= k; i++) {
        get(ans, msj[i].t - msj[i - 1].t - 1);
        M master = a;
        for(int j = 1; j <= mn; j++) master.a[j][4 * n + msj[i].x] += msj[i].y;
        cf(ans, master);
    }
    get(ans, T - msj[k].t);
    if(ans[n * 4 + 1] < 0ll) puts("-1");
    else printf("%lld\n", ans[n * 4 + 1]);
    return 0;
}