CF2174C2 Beautiful Patterns (Hard Version) 题解

· · 题解

题目大意不讲了。

如果你计数做多了,你会发现:次数的平方其实并不好算。于是我们考虑转换。

容易想到的一个双射是:有多少个 l_1,r_1,l_2,r_2,满足区间 [l_1,r_1] 和区间 [l_2,r_2] 都是回文的,这样的一个期望。根据期望的线性性,我们套路地将它们拆开,题目变成:区间 [l_1,r_1] 和区间 [l_2,r_2] 都是回文的期望的总和。因为这个条件的取值只有 01,因此我们可以变成算概率的总和。

先考虑一般情况,两个区间可能有交。这样其实是不好直接计数的,所以稍微转化一下。不难想到如果限制要求 s_{i-k}=s_{i+k},那么我们将 i-ki+k 连一条无向边。这样会有若干连通块,分别染色即可。预处理一下,然后直接暴力应该可以做到 O(n^4)O(n^5)

我们发现这个连通块其实很烦。于是不妨开始分类讨论。

第一种:[l_1,r_1][l_2,r_2] 的回文中心相同。这个情况最简单,令两个区间长度的最大值为 L,则满足条件的概率为 \frac{m^{\lceil\frac{L}{2}\rceil}}{m^L}

第二种:[l_1,r_1][l_2,r_2] 的回文中心不同。考虑其交集大小为 len,第一个区间的长度为 L_1,第二个区间的长度为 L_2,则概率为 \frac{m^{\lceil\frac{L_1+1}{2}\rceil+\lceil\frac{L_2+1}{2}\rceil-len}}{m^{L-len}}=\frac{m^{\lceil\frac{L_1+1}{2}\rceil+\lceil\frac{L_2+1}{2}\rceil}}{m^{L}}

这是一个好消息,因为这意味着我们只需要枚举两个区间的长度就可以了。稍微处理一些细节,然后注意提前预处理,可以做到 O(n^2),即通过 Easy Version。

代码将就看看,因为很丑:


#include<bits/stdc++.h>
#define int long long
using namespace std;
int mod;
const int N=1000010;
const int INF=0xc0c0c0c0c0c0c0c0;
int ksm(int a,int b){
    int z=1;
    while(b){
        if(b&1)z=z*a%mod;
        a=a*a%mod;
        b>>=1;
    }
    return z;
}
int inv(int x){return ksm(x,mod-2);}
int n,m;
int pw[N];
int ipw[N];
void solve(){
    cin>>n>>m>>mod;
    pw[0]=1;
    for(int i=1;i<=2*n;i++)pw[i]=pw[i-1]*m%mod;
    int I=inv(m);
    ipw[0]=1;
    for(int i=1;i<=2*n;i++)ipw[i]=ipw[i-1]*I%mod;
    int ans=0;
    for(int i=1;i<=n;i++)ans=(ans+2*(n-i+1)*((i-1)/2)%mod*pw[(i+1)/2]%mod*ipw[i])%mod;
    for(int i=1;i<=n;i++)ans=(ans+(n-i+1)*pw[(i+1)/2]%mod*ipw[i])%mod;
    //上面都是在处理第一种情况,这里默认L1<=L2,然后注意同区间的特殊情况
    for(int i=1;i<=n;i++){
        for(int j=1;j<=n;j++){
            int xx=(n-i+1)*(n-j+1);
            if(i%2==j%2)xx-=min(n-i+1,n-j+1);
            xx%=mod;
            ans=(ans+xx*pw[(i+1)/2+(j+1)/2]%mod*ipw[i+j])%mod;
        }
    }
    cout<<ans<<"\n";
    return;
}
signed main(){
    ios::sync_with_stdio(0);
    cin.tie(0);
    cout.tie(0);
    int Tc=1;
    cin>>Tc;
    while(Tc--)solve();
    return 0;
}
/*

*/

其实接下来没什么好说的,因为你发现这个东西关于 L_1 的和关于 L_2 的式子都是独立的。所以你可以直接前缀和什么的乱搞搞就过了,时间复杂度 O(n)

#include<bits/stdc++.h>
#define int long long
using namespace std;
// const int mod=1e9+7;
int mod;
const int N=1000010;
const int INF=0xc0c0c0c0c0c0c0c0;
int ksm(int a,int b){
    int z=1;
    while(b){
        if(b&1)z=z*a%mod;
        a=a*a%mod;
        b>>=1;
    }
    return z;
}
int inv(int x){return ksm(x,mod-2);}
int n,m;
int pw[N],ipw[N];
int sum[2];
void solve(){
    cin>>n>>m>>mod;
    pw[0]=1;for(int i=1;i<=2*n;i++)pw[i]=pw[i-1]*m%mod;
    int I=inv(m);
    ipw[0]=1;for(int i=1;i<=2*n;i++)ipw[i]=ipw[i-1]*I%mod;
    int ans=0;
    for(int i=1;i<=n;i++)ans=(ans+2*(n-i+1)*((i-1)/2)%mod*pw[(i+1)/2]%mod*ipw[i])%mod;
    for(int i=1;i<=n;i++)ans=(ans+(n-i+1)*pw[(i+1)/2]%mod*ipw[i])%mod;
    //上面处理第一种情况
    int tot=0;
    for(int i=1;i<=n;i++)tot=(tot+(n-i+1)*pw[(i+1)/2]%mod*ipw[i])%mod;
    ans=(ans+tot*tot)%mod;//这里处理第二种情况,没有减去第一种情况下的方案
    sum[0]=sum[1]=0;
    for(int i=1;i<=n;i++){
        sum[i&1]=(sum[i&1]+pw[(i+1)/2]*ipw[i])%mod;
        ans=(ans-(n-i+1)*pw[(i+1)/2]%mod*ipw[i]%mod*sum[i&1]%mod+mod)%mod;
    }
    sum[0]=sum[1]=0;
    for(int i=1;i<=n;i++){
        ans=(ans-(n-i+1)*pw[(i+1)/2]%mod*ipw[i]%mod*sum[i&1]%mod+mod)%mod;
        sum[i&1]=(sum[i&1]+pw[(i+1)/2]*ipw[i])%mod;
    }
    //上面处理第二种多算的情况,暂定L1<=L2,注意两个区间可能相同,所以我采用减两次
    cout<<ans<<"\n";
    return;
}
signed main(){
    ios::sync_with_stdio(0);
    cin.tie(0);
    cout.tie(0);
    int Tc=1;
    cin>>Tc;
    while(Tc--)solve();
    return 0;
}
/*

*/