CF1279F 题解 - wqs二分详解
aaaaaaaawsl · · 题解
困了我一下午加晚上的 wqs 二分终于差不多弄懂了,找的很多博客和题解都或多或少没有详细的讲 wqs 二分,那么下面我详细的讲讲我关于wqs 二分的各种疑惑和理解。
简化题意和一些小转化看其他题解就行,不是这篇题解的重点。
wqs 二分在哪里可以应用
wqs 二分常常应用于带有某些特征动态规划问题的优化方面。该类可以使用 wqs 二分优化的问题常常形如 给定若干个物品,要求恰好进行
特征为:
1.一般来说,随着操作次数的增加,价值是单调变化的。比如此题,随着操作的增加,答案是单调下降的。
2.如果不限制选的个数,那么很容易求出最优方案。
此外,全面的来说,如果把操作次数(
wqs 二分的思想及实现
偷几张图。
假设操作次数与价值形成了这样的函数图像:
这里要明确现在:
-
知道每个点的横坐标(
1 ,2,3,4 …… )。 -
知道这个函数图像的斜率是单调变化的(这个要之前推,推出来才能用 wqs 二分)。
-
现在不知道:
- 每个点的实际的
y 值。
现在观察到:
- 当确定了一条直线的斜率,这条直线一定与图像相切在某个点上。
- 由于图像的斜率是单调的,如果可以求出某一个给定的斜率切这个图像与哪个点上,就可以二分斜率,使其切在目标点上。此时的
y 值就是答案。
while(l <= r){
int mid = (l + r) >> 1;
if(check(mid) <= m) l = mid + 1, p = mid;
else r = mid - 1;
}//mid为二分的斜率。
所以目标变成了:
-
对于一个斜率,得到它与图像相切点的
x 值。 -
对于
x 值,得到y 值。
接下来,可以做到:
-
可以预处理不操作(
x=0 )的情况(y )。 -
得到一个规律,对于一个给定斜率的直线,在它过图上每一个点的情况下,过切点的情况的截距最大,比如下图。
接下来考虑怎么求出该直线的切点的横坐标。
现在有
于是可以通过一个
注意这里不能用上图来协助思考,因为上图的图形是 考虑完
那么在 DP 中,对
当前处理到了第
int check(int mid){
for(int i = 1; i <= n; ++ i){
pair<int, int> tmp = f[i - 1];
tmp.first += a[i], f[i] = tmp;
tmp = f[max(i - len, 0)];
tmp.first -= mid; tmp.second ++;
f[i] = min(f[i], tmp);
}
return f[n].second;
}
最后返回的值即为切点的横坐标。
以返回的值与给定的操作数的大小关系维护二分,即可得出切图像与给定点的切线的截距,斜率。结合横坐标可以求出答案。
Code
// 由于是借助第一篇题解和其他博客进行的学习,代码确实很相似,但是懒得改了。
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
inline int read(){
register int x = 0, f = 1; register char ch = getchar();
for(; ch > '9' || ch < '0'; ch = getchar()) if(ch == '-') f = -1;
for(; ch >= '0' && ch <= '9'; ch = getchar()) x = (x << 1) + (x << 3) + (ch ^ '0');
return x * f;
}
const int N = 1e6 + 10;
int a[N];
char s[N];
int n, m, len, ans = 2147483647;
pair<int,int> f[N];
int check(int mid){
for(int i = 1; i <= n; ++ i){
pair<int, int> tmp = f[i - 1];
tmp.first += a[i], f[i] = tmp;
tmp = f[max(i - len, 0)];
tmp.first -= mid; tmp.second ++;
f[i] = min(f[i], tmp);
}
return f[n].second;
}
void solve(){
int l = -n, r = 0, p;
while(l <= r){
int mid = (l + r) >> 1;
if(check(mid) <= m) l = mid + 1, p = mid;
else r = mid - 1;
}
check(p);
ans = min(ans, f[n].first + p * m);
}
int main(){
n = read(); m = read(); len = read(); cin >> (s + 1);
for(int i = 1; i <= n; ++ i) a[i] = (s[i] >= 'a' && s[i] <= 'z'); solve();
for(int i = 1; i <= n; ++ i) a[i] ^= 1; solve();
printf("%d\n", ans);
return 0;
}