题解 AT4168 [ARC100C] Or Plus Max

· · 题解

参考博客

考虑对每个 k 分别进行计算。即令 A_k=\max_{i|j=k,i\not=j}a_i+a_j,则 ans_k=\max_{i|j\le k,i\not=j}a_i+a_j=\max_{i=1}^{k}A_i

由于答案是对每种可能的 a_i+a_j 取 max,因此不妨扩展一下 A_i 的范围。把 i,j,k 看做二进制数,我们说 i\subseteq j,当且仅当 i 的每一位都小于等于 j。令 A_k=\max_{i,j\subseteq k,i\not=j}a_i+a_j,则 ans_k 仍等于 \max_{i=1}^{k}A_i

上述式子成立是由于或运算的性质。或运算是不减的,并且 i,j\subseteq ki|j=k 的必要条件、i|j\le k 的充分条件。这使得所有满足 i,j\subseteq k,i\not=j(i,j) 组成的集合保留了原来那些满足 i|j=k,i\not=j(i,j) 的集合,又没有把 i|j>k(i,j) 放进来,因此答案的正确性不变。

现在我们需要求对于每个 k\max_{i,j\subseteq k,i\not=j}a_i+a_j,这是类子集和问题,得借助高维前缀和完成。

高维前缀和

一维前缀和不用多说,二维前缀和一般是用容斥来做。但容斥的复杂度为 O(2^nV),其中 n 为维数,V 为点数。随着维度的增加复杂度将变得无法接受。其实我们可以以维为单位,逐维拓展,复杂度就会降至 O(nV)。具体思路可以参考如下代码:

求二维前缀和:

for(int i = 1; i <= n; i++)
    for(int j = 1; j <= n; j++)
        a[i][j] += a[i - 1][j];
for(int i = 1; i <= n; i++)
    for(int j = 1; j <= n; j++)
        a[i][j] += a[i][j - 1];

求三维前缀和:

for(int i = 1; i <= n; i++)
    for(int j = 1; j <= n; j++)
        for(int k = 1; k <= n; k++) 
            a[i][j][k] += a[i - 1][j][k];
for(int i = 1; i <= n; i++)
    for(int j = 1; j <= n; j++)
        for(int k = 1; k <= n; k++)
            a[i][j][k] += a[i][j - 1][k];
for(int i = 1; i <= n; i++)
    for(int j = 1; j <= n; j++)
        for(int k = 1; k <= n; k++)
            a[i][j][k] += a[i][j][k - 1];

手模一下就非常清晰明了了,更高维度的以此类推。

回到本题上来,A_k=\max_{i,j\subseteq k,i\not=j}a_i+a_j 实际上就是求高维前缀最大值和次大值,特殊之处在于它的对象是一个边长为 2n 维超立方体,点数为 2^n,总时间复杂度为 O(n2^n)

代码如下(码字不易,还望给个赞QAQ):

#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstring>
#include<cmath>
#include<vector>
#include<queue>
#include<map>
#define pb push_back
#define sml(x,y) (x=min(x,y))
#define big(x,y) (x=max(x,y))
#define ll long long
#define db double
#define fo(i,x,y) for(register int i=x;i<=y;++i)
#define go(i,x,y) for(register int i=x;i>=y;--i)
using namespace std;
inline int read(){int x=0,fh=1; char ch=getchar(); while(!isdigit(ch)){if(ch=='-') fh=-1; ch=getchar();} while(isdigit(ch)){x=(x<<1)+(x<<3)+ch-'0'; ch=getchar();} return x*fh;}

const int N=19,inf=1e9;
int n,m;
struct Node{
    int m1,m2;
    Node operator+(const Node &x){
        Node y;
        if(m1>x.m1){
            y.m1=m1;
            y.m2=max(m2,x.m1);
        }else{
            y.m1=x.m1;
            y.m2=max(m1,x.m2);
        }return y;
    }
}a[1<<N];

int main(){
    cin>>n;m=1<<n;
    fo(i,0,m-1) a[i].m1=read(),a[i].m2=-inf;
    fo(i,0,n-1)
        fo(j,0,m-1) if((j>>i)&1) a[j]=a[j]+a[j^(1<<i)];
    //cout<<a[0].m1+a[0].m2<<endl;
    int ans=0;
    fo(j,1,m-1){
        big(ans,a[j].m1+a[j].m2);
        printf("%d\n",ans);
    }
    return 0;
}