题解 CF1787H Codeforces Scoreboard

· · 题解

容易将题意转化为,给定 n 个二元组 (k_1,c_1),\cdots,(k_n,c_n),要求将它们重排得到排名数组 rk_1,\cdots,rk_n,最小化 \min(k_i\cdot rk_i,c_i) 之和。

钦定一个集合 S,并钦定 i\in S\min(k_i\cdot rk_i,c_i)=k_i\cdot rk_i,那么显然将 S 中的 ik_i 从大到小排序放在前面,后面再放 S 之外的 i' 是最优的。我们要找的是任意钦定 S 时的最小值。

那么先将所有 (k_i,c_i)k_i 从大往小排序,然后设 f_{i,j} 表示考虑了 (k_1,c_1),\cdots,(k_i,c_i),其中 \{1,\cdots,i\} 中有 j 个被钦定在 S 中时的最小值。容易得到一个 O(n^2) 的 DP:

f_{i,j}=\min(f_{i-1,j}+c_i,f_{i-1,j-1}+k_i\cdot j)

接下来就需要证明更多的性质。

其实感觉引理 1 已经是一个很好的性质了,因为我们可以在 S_{j-1} 的基础上,直接贪心地再选择一个最优的没被选择过的点得到 S_j,但发现竟然不好用数据结构维护,也有可能是 DS 水平不够(

现在重新考虑 f_{i-1}f_i 的转移:

f_{i,j}=\min(f_{i-1,j}+c_i,f_{i-1,j-1}+k_i\cdot j)

相减得到 (f_{i-1,j}+c_i)-(f_{i-1,j-1}+k_i\cdot j)=(f_{i-1,j}-f_{i-1,j-1})-k_i\cdot j+c_i

根据引理 2,可知上式关于 j 的差分始终非负,所以一定存在一个界点 M,使得对于 j\leq M 都是 f_{i-1,j}+c_i 更小;对于 j>M 都是 f_{i-1,j-1}+k_i\cdot j 更小。

维护 h(j)=f_{i,j}-f_{i,j-1},找 M 直接二分 h(j)y=k_i\cdot j-c_i 的交点,而 h(j) 的变化只有单点插入和后缀加。可以用平衡树维护。

#include<bits/stdc++.h>

#define N 200010
#define ll long long
#define lc(u) t[u].ch[0]
#define rc(u) t[u].ch[1]

using namespace std;

inline int read()
{
    int x=0,f=1;
    char ch=getchar();
    while(ch<'0'||ch>'9')
    {
        if(ch=='-') f=-1;
        ch=getchar();
    }
    while(ch>='0'&&ch<='9')
    {
        x=(x<<1)+(x<<3)+(ch^'0');
        ch=getchar();
    }
    return x*f;
}

mt19937 rnd(time(NULL));

struct data
{
    int k,c;
}p[N];

struct Treap
{
    int ch[2],sz,rd;
    ll val,lazy;
}t[N];

int n,node,rt;

int newnode(ll v)
{
    int u=++node;
    lc(u)=rc(u)=0,t[u].sz=1,t[u].rd=rnd();
    t[u].val=v,t[u].lazy=0;
    return u;
}

void up(int u)
{
    t[u].sz=t[lc(u)].sz+t[rc(u)].sz+1;
}

void downn(int u,ll v)
{
    t[u].val+=v,t[u].lazy+=v;
}

void down(int u)
{
    if(t[u].lazy)
    {
        downn(lc(u),t[u].lazy);
        downn(rc(u),t[u].lazy);
        t[u].lazy=0;
    }
}

void split(int u,int &a,int &b,int s,int k,int c)
{
    if(!u) return void(a=b=0);
    down(u);
    if(t[u].val<=1ll*k*(s+t[lc(u)].sz+1)-c)
        a=u,split(rc(u),rc(a),b,s+t[lc(u)].sz+1,k,c),up(a);
    else b=u,split(lc(u),a,lc(b),s,k,c),up(b);
}

int merge(int a,int b)
{
    if(!a||!b) return a+b;
    if(t[a].rd<t[b].rd)
        return down(a),rc(a)=merge(rc(a),b),up(a),a;
    return down(b),lc(b)=merge(a,lc(b)),up(b),b;
}

ll dfs(int u)
{
    if(!u) return 0;
    return down(u),(t[u].val<0?t[u].val:0)+dfs(lc(u))+dfs(rc(u));
}

void Main()
{
    node=rt=0;
    n=read();
    ll sumb=0;
    for(int i=1;i<=n;i++)
    {
        int k=read(),b=read(),a=read();
        p[i].k=k,sumb+=b,p[i].c=b-a;
    }
    sort(p+1,p+n+1,[](data a,data b){return a.k>b.k;});
    ll f0=0;
    for(int i=1;i<=n;i++)
    {
        int a,b;
        split(rt,a,b,0,p[i].k,p[i].c);
        int c=newnode(1ll*p[i].k*(t[a].sz+1)-p[i].c);
        downn(b,p[i].k);
        rt=merge(merge(a,c),b);
        f0+=p[i].c;
    }
    printf("%lld\n",sumb-(f0+dfs(rt)));
}

int main()
{
    int T=read();
    while(T--) Main();
    return 0;
}