题解:P9338 [JOISC 2023] Chorus (Day3)

· · 题解

容易发现,一个字符串可能合法当且仅当第 iA 在第 iB 前面。并且,在最优操作下,A 或者 B 之间的相对顺序一定是不变的,那么一定是第 [l,r]A 去匹配第 [l,r]B

w(l,r) 表示将第 [l+1,r]AB 去匹配的代价,设 b_k 表示出现在第 kA 之前的 B 的数量。则 w(l,r)=\sum_{k=l+1}^r\max(0,b_k-l),相当于去统计会和每个 A 交换的 B 的数量。

会不会有 B 跨越了第 \le lA 呢?那么这会在之前就进行计算。

设一个 f_{i,l} 表示当前划分出来了 l 个子序列,下一个子序列开头为第 i+1A 所需的最小代价,则:f_{j,l+1}\gets f_{i,l}+w(i,j)

显然 w(i,j) 满足四边形不等式,那么 f_{n,i=1,2,\cdots} 是满足凸性的,且是一个下凸包。(具体证明,把他的 s,r,t 看做 x-1,x,x+1 即可)

使用 wqs 二分进行优化,dp 式子变为:f_j\gets f_i+w(i,j)-mid

而考虑设 p_i 表示第一个满足 b_k\ge ik,且 p_i=\max(p_i,i+1),则 w(l,r)=\sum_{i=p_{l}}^r b_i-l,设 b_i 的前缀和为 s_i,则 w(l,r)=s_r-s_{p_l-1}-l\times(r-p_{l}+1)

可以使用斜率优化,时间为 \mathcal{O}(n\log V)

#include<bits/stdc++.h>
#define ull unsigned long long
#define int long long
#define p_b push_back
#define m_p make_pair
#define pii pair<int,int>
#define fi first
#define se second
#define ls k<<1
#define rs k<<1|1
#define mid ((l+r)>>1)
#define gcd __gcd
#define lowbit(x) (x&(-x))
using namespace std;
int rd(){
    int x=0,f=1; char ch=getchar();
    for(;ch<'0'||ch>'9';ch=getchar())if (ch=='-') f=-1;
    for(;ch>='0'&&ch<='9';ch=getchar())x=(x<<1)+(x<<3)+(ch^48);
    return x*f;
}
void write(int x){
    if(x>9) write(x/10);
    putchar('0'+x%10);
}
const int N=1e6+5,INF=1e12;

int f[N],g[N],q[N];
int b[N],bt,n,K,s[N],p[N];
int X(int i){return i;}
int Y(int i){return f[i]-s[p[i]-1]+p[i]*i-i;}
//维护下凸包
void check(int k){
    int h=0,t=0;
    for(int i=1;i<=n;i++){
        while(h<t&&(Y(q[h+1])-Y(q[h]))<(X(q[h+1])-X(q[h]))*i) h++;
        int j=q[h];
        f[i]=f[j]+s[i]-s[p[j]-1]-j*(i-p[j]+1)+k,g[i]=g[j]+1;
        while(h<t&&(Y(q[t])-Y(q[t-1]))*((X(i)-X(q[t])))>=((Y(i)-Y(q[t])))*(X(q[t])-X(q[t-1])))t--;
        q[++t]=i;    
    }
}

signed main(){
    n=rd(),K=rd();
    for(int i=1,sum=0;i<=2*n;i++){
        char ch=getchar();while(ch!='A'&&ch!='B') ch=getchar();
        if(ch=='A') b[++bt]=sum;
        else sum++;
    }
    for(int i=1;i<=n;i++) s[i]=s[i-1]+b[i];
    for(int i=0;i<=n;i++) p[i]=lower_bound(b+1,b+1+n,i)-b,p[i]=max(p[i],i+1);
    int l=0,r=INF,ans=0;
    while(l<=r){
        check(mid);
        if(g[n]<=K) ans=mid,r=mid-1;
        else l=mid+1;
    }
    check(ans);printf("%lld\n",f[n]-ans*K);
    return 0;
}