世界未末日 加强版 题解

· · 题解

数据范围:1 \le k \le n \le 3 \times 10^6,1 \le S \le 3 \times 10^{17}

首先原来的做法显然用不了了。我们考虑找一下 dSG 函数(即从 SG 值对应回最少石子数的函数)的性质。

我们发现数列长这样(称为数列 0,下标从 0 开始,其余数列下标均从 1 开始):

0,4,11,20,32,46,62,80,100,123,148,\dots

差分一下(得到的数列称为数列 1):

4,7,9,12,14,16,18,20,23,25,\dots

再差分一下(得到的数列称为数列 2):

3,2,3,2,2,2,2,3,2,\dots

发现数列 2 里全是 23。多算一些会发现当我们以每个 3 为结尾带上前面连续一些 2 组成一些段后,这些段的长度分别是 1,2,5,10,21,42,85,170,\dots,即二进制表示下为 01 相间的数。我们设第 i 个这样的数为 x_i

首先我们尝试解数列 1 的通式。由于数列 2 里只有 32,我们考虑把数列 1 里的数表示成 2a_i+b_i 的形式,令 a_1=2,b_1=0。则对于每个 ia_{i+1}=a_i+1,对于数列 2 中第 i 个数为 3ib_{i+1}=b_i+1,否则有 b_{i+1}=b_i

S_ix_i 的前缀和,则当且仅当 i=S_j 有数列 2 中第 i 个数为 3。于是 a_i=i+1b_i=\max\{j:S_j \le i\}-1

那么我们可以分别求 a_ib_i 的前缀和得到数列 0 中的数。设数列 0 中下标为 i 的是 2c_i+d_i,则 ca 的前缀和,db 的前缀和。因此有

c_v=\sum_{i=1}^{v}a_i =\sum_{i=1}^{v}(i+1) =\dfrac{v(v+3)}{2} d_v=\sum_{i=1}^{v}b_i =\sum_{i=1}^{v}(\max\{j:S_j \le i\}-1)

t=\max\{j:S_j \le v\},则上式

=t(v-S_t)+\sum_{j=1}^{t}(j-1)x_j =\sum_{i=1}^{t}(v-S_t)+\sum_{j=1}^{t}\sum_{i=1}^{j-1}x_j =\sum_{i=1}^{t}\left(v-S_t+\sum_{j=i+1}^{t}x_j\right) =\sum_{i=1}^{t}(v-S_i)

T_iS_i 的前缀和,则 d_v=tv-T_t

因此数列 0 的第 v 项为 \operatorname{dSG}(v)=2c_v+d_v=v^2+(3+t)v-T_t。由于 \operatorname{dSG}(v) 不减,并且 S_t \le v < S_{t+1},因此我们只要二分出最大的 S_t^2+(3+t)S_t-T_t \le st,再求出满足 v^2+(3+t)v-T_t \le s 的最大 v,那么 v 就是 s 的 SG 值。前面的二分是 O(\log\log{S}) 的,后面求 k 由于左边是二次函数且开口朝上,可以直接解出方程正根然后下取整,显然就是所求。这一步是 O(n\log\log{S}) 的。

然后由于暴力求每个二进制位上数值和是 O(n\log{S}) 的,会被卡,所以要优化(如果您能卡过去可以私信我)。以下优化方案由永远滴神 @tiger2005 提出。(据说是常见 trick?是我不行)

考虑分治计算贡献。设所有 SG 值最大的二进制位数为 m,我们把每个数的二进制拆成 w 段,每段不超过 \left\lceil\dfrac{m}{w}\right\rceil 位。(尽可能平均分,为简便接下来省略上取整符号)

对于每段,先将每个数这一段中二进制位对应的所有可能取值出现次数用一个下标为 [0,2^{\frac{m}{w}}) 的桶存起来。如 \dfrac{m}{w}=5,某个数的二进制第 5 \sim 9 位为 01011,则将 11 位的计数器加上 1,类似处理。

然后对于这 2^{\frac{m}{w}} 个数暴力枚举每个数二进制表示里面 1 的位置计算贡献。

这样的时间复杂度是 O(w(n+2^{\frac{m}{w}-1})),由于 SG 值不超过 \sqrt{S},取 w=2 时复杂度是 O(n+S^{\frac{1}{4}}\log{S}) 的,可以通过。

总时间复杂度是 O(n\log\log{S}+S^{\frac{1}{4}}\log{S})

Code:

#include<cmath>
#include<cstdio>
#define rg register
#define ll long long
#define lb(x) (x&(-x))
inline char gc()
{
    static char buf[1048576],*pn=buf,*pe=buf;
    return (pn==pe)&&(pe=(pn=buf)+fread(buf,1,1048576,stdin),pn==pe)?EOF:*pn++;
}
inline int read()
{
    int x=0;
    char cc=gc();
    while(cc<'0'||cc>'9')cc=gc();
    while(cc>='0'&&cc<='9')x=x*10+cc-'0',cc=gc();
    return x;
}
inline ll _read()
{
    ll x=0;
    char cc=gc();
    while(cc<'0'||cc>'9')cc=gc();
    while(cc>='0'&&cc<='9')x=x*10+cc-'0',cc=gc();
    return x;
}
ll S,s;
int n,k;
int t,w,m;
int v,flag;
ll ssmv[37];
int cnt[37],pw2[17];
int x[37],sm[37],ssm[37];
int cnt1[32773],cnt2[32773];
int ans1[32773],ans2[32773];
int main()
{
    n=read(),k=read()+1,S=_read();
    while(1ll<<(m<<1)<=S)++m;
    --m,t=(m>>1)+1,w=(1<<t)-1;
    pw2[0]=x[1]=sm[1]=ssm[1]=1;
    for(rg int i=1;i<t;++i)pw2[i]=(pw2[i-1]<<1);
    for(rg int i=2;i<=m;++i)
    {
        x[i]=(x[i-1]<<1)|(i&1),sm[i]=sm[i-1]+x[i];
        ssm[i]=ssm[i-1]+sm[i],ssmv[i]=(sm[i]+3ll+i)*sm[i]-ssm[i];
    }
    for(rg int i=0,l,r,mid;i<n;++i)
    {
        s=_read(),l=0,r=m;
        while(l<r)mid=r-((r-l)>>1),(ssmv[mid]<=s)?(l=mid):(r=mid-1);
        v=int((sqrt((l+3)*(l+3)+((s+ssm[l])<<2))-l-3)/2),++cnt1[v&w],++cnt2[v>>t];
    }
    for(rg int i=0;i<=w;++i)
    {
        for(rg int j=i;j;j^=lb(j))
        {
            ans1[lb(j)]+=cnt1[i];
        }
    }
    for(rg int i=0;i<t;++i)cnt[i]=ans1[pw2[i]];
    for(rg int i=0;i<=w;++i)
    {
        for(rg int j=i;j;j^=lb(j))
        {
            ans2[lb(j)]+=cnt2[i];
        }
    }
    for(rg int i=0;i<t;++i)cnt[i+t]=ans2[pw2[i]];
    for(rg int i=0;i<=m;++i)(cnt[i]%k)&&(flag=1);
    puts((flag)?"YES":"NO");
    return 0;
}

关于解法中关于 2,3 个数的结论我还没有找到一个证明(但是跑完整个数据范围验证确实暂时没有出错),如果有神仙找到了证明方法或举出了反例欢迎私信我。如果证明很简单请顺带把我这个垃圾爆踩一顿