现在有 n 堆石子,每次可以从不同的两堆里面各拿走一个,求最少剩下多少。
设 S = \sum \limits _{i = 1} ^n a_i,m = \max \limits _{i = 1} ^n a_i,则我们去比较 S - m 和 m 的大小。
若 S - m < m,即每次操作都拿走最大值中的一个,取走其他所有石子之后最多的那一堆仍然有剩余。
此时答案显然为 S - 2m。
若 S - m \ge m,则答案为 S \bmod 2。
证明的话考虑直接构造方案。每次选择最大的两堆石子分别拿走一个。
若 m 是最大值且个数不超过两个,则操作之后 (S - 2) - (m - 1) \ge (m - 1) 仍然成立。
若 m 是最大值且出现超过两次,则有 S \ge 3m,操作之后只要 m > 1 就有 (S - 2) - m \ge 2m - 2 \ge m。
---
回到原问题,现在我们需要考虑操作 2 的影响。
正常情况下 $x$ 位置需要 $x$ 次操作变为 $0$,但若 $x$ 前面存在奇数位置就可以选择直接跳过去从而减少操作次数。
这里我们记 $l_i$ 和 $r_i$ 表示位于 $i$ 位置最多/最少分别需要多少次操作 1。显然有 $r_i = i$,同时 $l_i$ 也可以方便的求解。
仍然设 $S = \sum \limits _{i = 1} ^n b_i$,$p$ 表示 $r_i$ 取到最大值的 $i$,则我们现在去考虑 $S - r_p$ 和 $r_p$ 的大小关系。
若 $S - r_p \ge r_p$,则所有位置均取 $r$ 值即可,根据上面的结论此时答案为 $S \bmod 2$。
若 $S - r_p < r_p$,但是 $S - r_p \ge l_p$,由于我们知道一个奇数位置最多只能节约两步,所以 $l_p$ 与 $r_p$ 之间每相邻两个值之间就有一个可以取到。
考虑如果存在一个可以取到的值 $x$ 满足 $S - r_p \ge x$ 且 $x$ 为最大值,那么答案就是 $S \bmod 2$。
如果取到的 $x$ 不是最大值那么只可能是 $S = x + (x + 1)$ 且 $x + 1$ 位置我们取不到,此时 $S$ 为奇数且刚好答案为 $1$,不影响结论。
最后是如果 $S - r_p < l_p$,那么 $p$ 位置只能被操作一削减 $S - r_p$ 次,此时贪心向下减到最小即可。
最后考虑 Alice 还可以操作一次的问题。
如果存在大于一个奇数位置,那么最后的答案可以 $-2$。
如果我们的答案包括 $S \bmod 2 = 1$ 的部分,则可以直接消掉。
如果是 $S - r_p < l_p$ 的情况,我们也可以让答案再 $-2$。
除此之外的其他情况无法再操作。
由于需要离散化/排序所以直接用了 `map`,复杂度 $O(n \log n)$。
如果你相信基数排序的话那就是 $O(n)$ 的。
细节有点麻烦,贴个代码吧。
```cpp
#include <map>
#include <stdio.h>
#include <algorithm>
using namespace std;
typedef long long ll;
const int sz = 200005;
struct node
{
int l, r, tim;
};
int n, top, mx;
long long ans, sum;
map<int, int> mp;
node num[sz];
int read();
int solve(int, int);
void cld();
int main()
{
n = read(); read();
for (int i = 1; i <= n; ++i)
++mp[read()];
int last = 0, pir = 0;
for (auto i : mp)
{
num[++top].tim = i.second;
num[top].r = i.first;
num[top].l = num[last].l + (num[top].r - num[last].r) - pir * 2;
if ((num[top].tim & 1) && (num[top - 1].tim & 1))
{
ans += num[top].r;
num[top].tim ^= 1;
}
if (num[top].tim & 1)
++pir;
else
{
last = top;
pir = 0;
}
}
for (int i = 1; i <= top; ++i)
{
if (num[i].tim & 1)
ans += num[i].r;
num[i].tim >>= 1;
sum += (ll)num[i].r * num[i].tim;
}
while (top && num[top].tim == 0)
--top;
if (sum - num[top].r < num[top].r)
{
int tmp = solve(sum - num[top].r, num[top].r);
if (tmp)
tmp -= 1;
else
cld();
ans += tmp * 2;
}else
cld();
printf ("%lld\n", ans);
return 0;
}
int read()
{
int x = 0;
char c = getchar();
while (c < '0') c = getchar();
do {
x = x * 10 + (c & 15);
c = getchar();
}while (c >= '0');
return x;
}
int solve(int a, int pos)
{
mp[0] = 2;
auto i = mp.end();
for (--i; pos; --i)
{
if (i->first < pos)
{
if (i->second & 1)
{
if (a < pos - i->first - 1)
{
pos -= a;
break;
}
a -= pos - i->first - 1;
pos = i->first - 1;
}else
{
if (a < pos - i->first)
{
pos -= a;
break;
}
a -= pos - i->first;
pos = i->first;
}
}
}
return pos;
}
void cld()
{
sum = 0;
for (int i = 1; i <= top; ++i)
sum ^= num[i].tim & num[i].r;
if (sum & 1)
return;
int tim = 0;
for (auto i : mp)
tim += i.second & 1;
if (tim > 1)
ans -= 2;
}
```