题解:P9333 [JOISC 2023 Day2] Council

· · 题解

提供一个不太一样的做法,没有使用高维前缀和之类的二进制处理的工具。

首先注意到 m 很小,可以大概猜想一下复杂度可能是对 m 的指数级的做法。首先可以思考一下对于第 i 个人作为主席的情况下,他最优的情况下应该选择谁作为自己的副主席。对于每一个提案,假设把当前的主席扔掉之后,如果投票人数已经严格小于需求的数量,则这个提案一定不会通过。类似的,如果一个提案严格大于需求数量,则这个提案一定会通过。关键就是那些恰好等于需求数量的提案。

如果我们将那些恰好等于的提案状压成一个二进制数,则问题被转换成找到一个二进制数,使得两个数 xy 的按位与的 1 的个数尽量小。

然后我没有想到怎么使用任何二进制处理的方法解决这个问题。考虑为什么最基础的暴力很慢,每次添加一个人的时候相当于一次修改操作,我们每次查询的时候都得枚举 2^m 的代价,再加上我们枚举主席的复杂度,复杂度就来到了 O(n 2^m) 。注意到我们的修改操作非常的快,这是没有必要的,我们只会修改 O(n) 次,于是考虑平衡修改与查询的复杂度。考虑将一个状态 x 分成两段,修改的时候枚举查询的时候前半段的可能,并更新对应的后半部分不变的最小答案,查询的时候枚举后半段即可将查询和修改都平衡到 O(2^{m/2}) 的。对于主席不能将自己选为副主席的处理在这个做法里是非常容易的,直接正着扫一遍,只处理 i < j,然后倒着扫一遍,只处理 i > j 。 复杂度为 O(n2^{m/2}) ,代码好写常数很小。

代码

#include<bits/stdc++.h>
using namespace std;
const int N=300005;
const int bk=(1<<10)-1;
int n,m;
int a[N][21];
int vl[(1<<10)+5][(1<<10)+5],p;
int num[N];
int val[N];
int ans[N];
int cnt[(1<<10)+5];
void insert(int num)
{
    int half=num&bk;
    int res=num>>10;
    for(int i=0;i<1024;i++)
    {
        vl[i][res]=min(vl[i][res],cnt[half&i]);
    }
}
int get(int x)
{
    int half=x&bk;
    int res=x>>10;
    int val=1e9+7;
    for(int i=0;i<1024;i++)
    {
        val=min(val,vl[half][i]+cnt[res&i]);
    }
    return val;
}
int main()
{
    ios::sync_with_stdio(false);
    cin.tie(0);
    cout.tie(0);
    cin>>n>>m;
    for(int i=0;i<1024;i++)
    {
        cnt[i]=__builtin_popcount(i);
    }
    for(int i=1;i<=n;i++)
    {
        for(int j=1;j<=m;j++)
        {
            cin>>a[i][j];
            if(a[i][j])
            {
                num[j-1]++;
                val[i]|=(1<<j-1);
            }
        }
    }
    m=20;

    memset(vl,0x3f,sizeof(vl));
    insert(val[1]);
    int mid=n>>1;
    for(int i=2;i<=n;i++)
    {
        int sum=0;
        for(int j=0;j<m;j++)
        {
            if((val[i]>>j)&1)num[j]--;
        }
        int tmp=0;
        for(int j=0;j<m;j++)
        {
            if(num[j]>=mid)sum++;
            if(num[j]==mid)tmp|=(1<<j);
        }
        ans[i]=max(ans[i],sum-get(tmp));
        for(int j=0;j<m;j++)
        {
            if((val[i]>>j)&1)num[j]++;
        }
        insert(val[i]);
    }

    memset(vl,0x3f,sizeof(vl));
    insert(val[n]);
    for(int i=n-1;i>=1;i--)
    {
        int sum=0;
        for(int j=0;j<m;j++)
        {
            if((val[i]>>j)&1)num[j]--;
        }
        int tmp=0;
        for(int j=0;j<m;j++)
        {
            if(num[j]>=mid)sum++;
            if(num[j]==mid)tmp|=(1<<j);
        }

        ans[i]=max(ans[i],sum-get(tmp));
        for(int j=0;j<m;j++)
        {
            if((val[i]>>j)&1)num[j]++;
        }
        insert(val[i]);
    }
    for(int i=1;i<=n;i++)
    {
        cout<<ans[i]<<'\n';
    }
}