CF840C

· · 题解

总结了一下三种做法。个人觉得这是一道很好的题目,能够明白三种不同的解法对理解这种 DP 和容斥应该也有帮助(吧)。【new:添加了理论 O(n\log^2n) 的做法,但是因为 sb 模数没有 implement】。

首先考虑相乘等于平方相当于两个数除去平方因子后相同。所以我们把所有数都除去平方因子,然后问的相当于是序列有多少排排列满足相邻两项不等。为了更方便地 DP,我们对修改后的序列排序。接下来就是三种解法(代码用命名空间合成了一份,所以你会发现代码出奇的长,只是因为里面实际上把三种解法全写在里面了)。

O(n^3) 做法

考虑用插入的DP。f(i,j,k) 表示目前已经插入了前 i 个数,且目前有 j 个相邻位置的数相同,有 k 个位置满足与当前转移的数字相同的数相邻。转移先设 s_i 表示 i 之前有 s_i 个数与第 i 个数相等。考虑三种情况。

  1. 插入后自己与相邻的数相同(j 变大)。

    此时对 f(i,j,k) 的贡献为 (2_s{i}-k+1)\times f(i-1,j-1,k-1)

  2. 插入后没有任何变化(j 不变)。

    此时对 f(i,j,k) 的贡献为 (i-2s_{i+1}+2k-j)\times f(i-1,j,k)

  3. 插入后两个相同的数不相邻了(j 变小)。

    此时对 f(i,j,k) 的贡献为 (j-k+1)\times f(i-1,j+1,k)

上面转移问题其实很大,因为如果 a_i\neq a_{i+1},那么在计算 i-1i 的贡献的时候,需要有把所有 k>0f(i-1,j,k) 清空并将贡献记到 f(i-1,j,0) 中,这样才能保证不会有问题。至于空间我们滚动一下。

另一种 O(n^3) 做法

我们发现这个序列是可以被分成块的,满足每一块内数字相同。于是我们考虑每次插入好多个数而不止一个数。设第 i 块有 s_i 个相同的数。设 f(i,j) 表示目前插入了前 i 块,并且有 j 个位置相邻的数相同。

对于如何转移,我们考虑先枚举把第 i 块拆分成 k 个小块(这种方案数设其为 g_{s_i,k},比较好弄,之后再说)。首先其内部会产生 s_i-k 对相邻的相同的数。然后我们考虑枚举有多少个小块插入到了满足左右两边数值一样(即在插入前位置相邻且数值相同),设为 p,则对位置相邻且数值相同的对数的贡献为 -p

q 为第 i 块前插入了多少个数(即块的大小总和),于是乎一共有 q+1 个空位可以插。

f(i,j+s_i-k-p) \leftarrow \sum_{k} \binom{j}{p}\binom{q+1-j}{k-p}g(s_i,k)f(i-1,j)

我们看看 g 的转移。注意到块内的数是有顺序而言的。

g(i,j)\leftarrow (i-1+j)\times g(i-1,j)+j\times g(i-1,j-1)

关于初值,有 g(0,0)=f(0,0)=1

时间复杂度亦是 O(\sum s_in^2)=O(n^3)

比较仙的 O(n^2)

考虑容斥。首先我们转化为无标号的计数(最后乘上 \sum s_i! 就能转化成有标号的)。对于一个不合法方案,我们观察每一个数值,设其能被分为 b_i 个连续段。我们容斥 b,设 \sum b=r

\sum (-1)^{r} \frac{(n-r)!\prod\binom{s_i-1}{b_i}}{\prod (s_i-b_i)!}

其中 (n-r)! 为所有可能性,然后要除掉那些已经在一个连续段内的数的排列情况。然后分子的那个 \prod 是对连续段做插板法。考虑用 DP 求出这个值(类似背包计数)。f(i,j) 表示前 i 个数值,目前 r=j\sum (-1)^r \frac{\prod ...}{\prod ...}。显然我们不需要把 (n-r)! 这一项包括在 DP 的计算中,因为它可以在计算答案的时候算。

于是我们有转移:

f(i,j)=\sum_{k} f(i-1,j-k)\times \frac{\binom{s_i-1}{k}}{(s_i-k)!}

转移看上去非常卷积。但是我多项式很拉。所以就不管他了。

初值 f(0,0)=1。注意到我们还要转化成有标号计数,所以实际答案 ans=(\prod s_i!) \times \sum f_{cnt,i}(n-i)!。时间复杂度为 O(\sum s_in)=O(n^2)

#include<bits/stdc++.h>
#define int long long
#define rep(i,a,b) for(register int i=(a);i<=(b);i++)
#define per(i,a,b) for(register int i=(a);i>=(b);i--)
using namespace std;
const int N=309,mod=1e9+7;

inline long long read() {
    register long long x=0, f=1; register char c=getchar();
    while(c<'0'||c>'9') {if(c=='-') f=-1; c=getchar();}
    while(c>='0'&&c<='9') {x=(x<<3)+(x<<1)+c-48,c=getchar();}
    return x*f;
}

void pls(int &x,int y) {x+=y; x=(x>=mod?x-mod:x);}

int n,a[N],s[N],fac[N],inv[N],ifac[N];

void pre() {
    inv[0]=inv[1]=fac[0]=ifac[0]=1;
    rep(i,1,n) fac[i]=fac[i-1]*i%mod;
    rep(i,2,n) inv[i]=inv[mod%i]*(mod-mod/i)%mod;
    rep(i,1,n) ifac[i]=ifac[i-1]*inv[i]%mod;
}
int C(int x,int y) {
    if(x<0||y<0||x<y) return 0;
    else return fac[x]*ifac[y]%mod*ifac[x-y]%mod;
}

namespace Solution_1 { //第一种O(n^3)做法 
    int s[N],f[2][N][N];
    void solve() {
        int d=0;
        rep(i,1,n) {
            if(a[i]==a[i-1]) s[i]=s[i-1]+1;
            else s[i]=0;
        }
        f[0][0][0]=1;
        rep(i,1,n) {
            if(a[i]!=a[i-1]) {
                rep(j,0,i) rep(k,1,s[i-1])
                    pls(f[d^1][j][0],f[d^1][j][k]), f[d^1][j][k]=0;
            }
            rep(j,0,i) {
                int upk=min(s[i],j);
                rep(k,0,upk) {
                    if(j&&k) pls(f[d][j][k],f[d^1][j-1][k-1]*(2*s[i]-k+1)%mod);
                    pls(f[d][j][k],f[d^1][j][k]*(i-2*s[i]+2*k-j)%mod);
                    pls(f[d][j][k],f[d^1][j+1][k]*(j-k+1)%mod);
                }
            }
            memset(f[d^1],0,sizeof(f[d^1]));
            if(i!=n) d^=1;
        }
        printf("%lld\n",f[d][0][0]);
    }
}

namespace Solution_2 { //第二种o(n^3)做法 
    int cnt,s[N],g[N][N],f[N][N];
    void calc_g() {
        g[0][0]=1;
        rep(i,1,n) rep(j,1,i) g[i][j]=((i+j-1)*g[i-1][j]+j*g[i-1][j-1])%mod;
    }
    void calc_f() {
        int q=0;
        f[0][0]=1;
        rep(i,1,cnt) {
            rep(j,0,q) rep(k,1,s[i]) rep(p,0,j) {
                pls(f[i][j+s[i]-k-p],C(j,p)*C(q+1-j,k-p)%mod*g[s[i]][k]%mod*f[i-1][j]%mod);
            }
            q+=s[i];
        }
    }
    void solve() {
        rep(i,1,n) {
            if(a[i]!=a[i-1]) {
                s[++cnt]=1;
            } else {
                s[cnt]++;
            }
        }
        calc_g();
        calc_f();
        printf("%lld\n",f[cnt][0]);
    }
}

namespace Solution_3 { //第三种O(n^2)做法 
    int cnt,s[N],f[N][N];
    void solve() {
        rep(i,1,n) {
            if(a[i]!=a[i-1]) {
                s[++cnt]=1;
            } else {
                s[cnt]++;
            }
        }
        int mxr=0;
        f[0][0]=1;
        rep(i,1,cnt) {
            mxr+=s[i]-1;
            rep(j,0,mxr) {
                int upk=min(s[i],j);
                rep(k,0,upk) 
                    pls(f[i][j],f[i-1][j-k]*C(s[i]-1,k)%mod*ifac[s[i]-k]%mod);
            }
        }
        int ans=0;
        rep(i,0,mxr) pls(ans,f[cnt][i]*fac[n-i]%mod*(i%2==0?1:mod-1)%mod);
        rep(i,1,cnt) ans=ans*fac[s[i]]%mod;
        printf("%lld\n",ans);
    } 
}

signed main() {
    n=read();
    rep(i,1,n) {
        a[i]=read();
        for(int j=2;j*j<=a[i];j++)
            while(a[i]%(j*j)==0) a[i]/=j*j;
    }
    sort(a+1,a+n+1);
    pre();
    Solution_3::solve(); //all solutions can pass
    return 0;
}

New: O(n\log^2n)(理论方法)

考虑优化 DP 的计算。我们设第 i 行的 DP 用幂级数表示为 F_i,然后设

G_i=\sum_{k=0}^{s_i} \frac{\binom{s_i-1}{k}}{(s_i-k)!}x^k

,那么有

F_i=F_{i-1}\times G_{i}

所以我们求的目标 F_{cnt}=\prod G_i。一个非常重要的东西是,我们发现 \sum s_i=n,也就意味着我们可以通过分治 NTT 的方法做到 O(n\log^2 n)

(你会发现这只不过是 [集训队作业2019] 青春猪头少年不会梦到兔女郎学姐 的弱化一万次的版本)

最后你发现模数是 10^9+7 然后作者不想写三模NTT。