CF1279F 题解 - wqs二分详解

· · 题解

困了我一下午加晚上的 wqs 二分终于差不多弄懂了,找的很多博客和题解都或多或少没有详细的讲 wqs 二分,那么下面我详细的讲讲我关于wqs 二分的各种疑惑和理解。

简化题意和一些小转化看其他题解就行,不是这篇题解的重点。

wqs 二分在哪里可以应用

wqs 二分常常应用于带有某些特征动态规划问题的优化方面。该类可以使用 wqs 二分优化的问题常常形如 给定若干个物品,要求恰好进行 m 次操作,最大化或最小化操作后的价值(根据题目计算)。 在此的操作可以为选取,或变换等等,但是我们总归可以把他转化为背包问题。

特征为:

1.一般来说,随着操作次数的增加,价值是单调变化的。比如此题,随着操作的增加,答案是单调下降的。

2.如果不限制选的个数,那么很容易求出最优方案。

此外,全面的来说,如果把操作次数(x)和在此操作次数下的最优价值(g(x))以函数形式表示,如果该函数是个凸函数,那么可以使用。

wqs 二分的思想及实现

偷几张图。

假设操作次数与价值形成了这样的函数图像:

这里要明确现在:

现在不知道:

现在观察到:

while(l <= r){
    int mid = (l + r) >> 1;
    if(check(mid) <= m) l = mid + 1, p = mid;
    else r = mid - 1;
}//mid为二分的斜率。

所以目标变成了:

接下来,可以做到:

接下来考虑怎么求出该直线的切点的横坐标。

现在有 x=0 时的 y 坐标,看第三个图,由于横坐标之间是彼此相差 1,现在二分出的斜率为 mid,所以当直线交在某个点上时,它与 y 轴的截距 b = y - mid x 。比 x-1 时的截距小 mid

于是可以通过一个 O(n) 的 DP 来维护 考虑到第 i 个决策位置的时候,最优操作方案的操作次数和直线的截距

注意这里不能用上图来协助思考,因为上图的图形是 考虑完 n 个决策位置的,对于不同操作数的最优解组成的图形。与 DP 中的前后决策位置和操作数转移完全没关系。

那么在 DP 中,对 f_i 维护两个值 f,s 表示:

当前处理到了第 i 位,得到一个点 (f_i.s, j),f_i.s 为当前最优方案的操作数,j 为最优方案的值(不记录,使用 f_i.f进行转移可以保证正确性,最后可以通过斜率,横坐标和截距计算出来),斜率为 mid 的直线过这个点的截距为 f_i.f

int check(int mid){
    for(int i = 1; i <= n; ++ i){
        pair<int, int> tmp = f[i - 1];
        tmp.first += a[i], f[i] = tmp;
        tmp = f[max(i - len, 0)];
        tmp.first -= mid; tmp.second ++;
        f[i] = min(f[i], tmp);
    }
    return f[n].second;
}

最后返回的值即为切点的横坐标。

以返回的值与给定的操作数的大小关系维护二分,即可得出切图像与给定点的切线的截距,斜率。结合横坐标可以求出答案。

Code

// 由于是借助第一篇题解和其他博客进行的学习,代码确实很相似,但是懒得改了。
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>

using namespace std;

inline int read(){
    register int x = 0, f = 1; register char ch = getchar();
    for(; ch > '9' || ch < '0'; ch = getchar()) if(ch == '-') f = -1;
    for(; ch >= '0' && ch <= '9'; ch = getchar()) x = (x << 1) + (x << 3) + (ch ^ '0');
    return x * f; 
}

const int N = 1e6 + 10;

int a[N];
char s[N];
int n, m, len, ans = 2147483647;
pair<int,int> f[N];

int check(int mid){
    for(int i = 1; i <= n; ++ i){
        pair<int, int> tmp = f[i - 1];
        tmp.first += a[i], f[i] = tmp;
        tmp = f[max(i - len, 0)];
        tmp.first -= mid; tmp.second ++;
        f[i] = min(f[i], tmp);
    }
    return f[n].second;
}

void solve(){
    int l = -n, r = 0, p;
    while(l <= r){
        int mid = (l + r) >> 1;
        if(check(mid) <= m) l = mid + 1, p = mid;
        else r = mid - 1;
    }
    check(p);
    ans = min(ans, f[n].first + p * m);
}

int main(){
    n = read(); m = read(); len = read(); cin >> (s + 1);
    for(int i = 1; i <= n; ++ i) a[i] = (s[i] >= 'a' && s[i] <= 'z'); solve();
    for(int i = 1; i <= n; ++ i) a[i] ^= 1; solve();
    printf("%d\n", ans);
    return 0;
}