P8803 [蓝桥杯 2022 国 B] 费用报销

· · 题解

更新于 2025.6.8:经评论区用户指出,这个两年前的题解存在许多问题,现予以修复。且两年前并未发现这个 dfs 优化没有正确性,现已删除。同时更新了部分内容。

其实我们可以发现这个题很像背包问题,那么就可以开一个 f 数组,其中 f_{i,j} 表示前 i 个票据占用 j 的背包的最大价值。

我们先预处理每一个票据离它最近的且合法的票据的位置,这个过程是容易 \mathcal {O}(n^2) 处理的,我们把它存到 lst 数组中去。

对于每一个票据也是有选或不选两种选择,那么我们可以得出以下方程:

f_{i,j}=\max(f_{i-1,j},f_{lst_{i},j-d_{i}.v}+d_{i}.v)

最后我们输出 f_{n,m} 即可,综合时间复杂度为 \mathcal {O}(nm)

对于用户指出的一组数据:

9 647 10
1 10 255
3 17 13
4 19 106
9 7 140
10 3 196
10 6 393
11 2 219
11 13 319
11 30 224

正确输出是 644

dp 代码跑出了错误答案,原因在于,枚举背包体积的一维并没有转移完,对于 j<v 的部分仍需要继承 i-1 的所有答案,因此把背包体积枚举到 0 即可。代码已更新。

#include<bits/stdc++.h>
using namespace std;
int n,m,i,j,ans,k1,sum;
struct ren{
    int m,d,v,t;
}d[1005];
int f[1005][5005],s[1005],lst[1005];
bool cmp(ren a,ren b){
    return a.t<b.t;
} 
int dx[]={0,31,28,31,30,31,30,31,31,30,31,30,31};
int main(){
    scanf("%d%d%d",&n,&m,&k1);
    for(i=2;i<=12;i++) s[i]=s[i-1]+dx[i-1];
    for(i=1;i<=n;i++){
        scanf("%d%d%d",&d[i].m,&d[i].d,&d[i].v);
        d[i].t=s[d[i].m]+d[i].d;
    } 
    sort(d+1,d+1+n,cmp);
    for(i=1;i<=n;i++){
        for(j=0;j<i;j++){
            if(d[i].t-d[j].t>=k1) lst[i]=j;//预处理离第i个票据最近的票据的位置
        }
    }
    for(i=1;i<=n;i++){
        for(j=m;j>=0;j--){
            f[i][j]=f[i-1][j];
            if(j>=d[i].v) f[i][j]=max(f[i][j],f[lst[i]][j-d[i].v]+d[i].v);
        }
    }
    printf("%d",f[n][m]);
    return 0;
}

当然,如果你一时脑抽没有想到预处理最近转移点,然后继承前面所有答案的话,你依然可以数据结构优化,前提是你依旧需要找出最近转移点。

那么你的所有转移点就是 1 \sim lst[i],转移的体积是固定的。因此你对每个体积值开一棵线段树,下标存的是第 i 位上的 dp 值,那么你每次查询就查 j-v 这一棵上 [1,lst[i]] 这个区间的答案的最大值即可。

复杂度是 O(nm \log n) 的,看上去很蠢。

#include<bits/stdc++.h>
using namespace std;
#define N 5005
#define M 1005
int n,m,i,j,ans,k1,sum;
struct ren{
    int m,d,v,t;
}d[M];
int f[M][N],s[M],lst[M];
bool cmp(ren a,ren b){
    return a.t<b.t;
} 
int dx[]={0,31,28,31,30,31,30,31,31,30,31,30,31};
struct seg{
    int d[M<<2];
    void upd(int l,int r,int p,int k,int c){
        if(l==r){
            d[p]=c;
            return;
        }
        int mid=(l+r)>>1;
        if(mid>=k) upd(l,mid,p<<1,k,c);
        else upd(mid+1,r,p<<1|1,k,c);
        d[p]=max(d[p<<1],d[p<<1|1]);
    }
    int qry(int l,int r,int p,int s,int t){
        if(s>t) return 0; 
        if(l>=s && r<=t) return d[p];
        int mid=(l+r)>>1,res=0;
        if(mid>=s) res=max(res,qry(l,mid,p<<1,s,t));
        if(mid<t) res=max(res,qry(mid+1,r,p<<1|1,s,t));
        return res;
    }
}tr[N];
int main(){
    scanf("%d%d%d",&n,&m,&k1);
    for(i=2;i<=12;i++) s[i]=s[i-1]+dx[i-1];
    for(i=1;i<=n;i++){
        scanf("%d%d%d",&d[i].m,&d[i].d,&d[i].v);
        d[i].t=s[d[i].m]+d[i].d;
    } 
    sort(d+1,d+1+n,cmp);
    for(i=1;i<=n;i++){
        for(j=0;j<i;j++){
            if(d[i].t-d[j].t>=k1) lst[i]=j;
        }
    }
    for(i=1;i<=n;i++){
        for(j=m;j>=0;j--){
            f[i][j]=f[i-1][j];
            if(j>=d[i].v) f[i][j]=max(f[i][j],d[i].v); 
            if(j>d[i].v) f[i][j]=max(f[i][j],tr[j-d[i].v].qry(1,n,1,1,lst[i])+d[i].v);
            tr[j].upd(1,n,1,i,f[i][j]);
        }
    }
    printf("%d",f[n][m]);
    return 0;
}