[P8338] [AHOI2022] 排列

· · 题解

本题涉及了关于线性筛、质因数分解、置换、分析数据范围特性等多种技巧,是一道难得的好题,为出题人点赞!

Hint 1

可以把每个 p_i 看成从 i 连向 p_i 的一条有向边,这样整个图会由若干个互不相交的简单环构成(所有点的入度出度均为 1 )。

可以通过分析得出, P^k 的意义就相当于每个人一开始都在初始的 i 号结点, k 每次 +1 就变成所有人在图上走一步。

于是我们得出了 f(i,j)=0 的充要条件: i,j 在同一个环上。

Hint 2

考虑对一个排列 P 如何计算 v(P) ,设 i 号结点所在的环长为 r_i ,则可以得出 v(P)=\operatorname{lcm}\{r_1,r_2,\cdots,r_n\} ,这是因为所有点都同时回到原点需要保证每个人的步数都为 r_i 的倍数。

接着考虑 v(A_{ij}) 的值,草稿纸上画一画就可以发现,把两个不同环上的点的出边进行交换,就会把两个环合并,因此 v(A_{ij}) 的值就为 r_i+r_j 与其他环长取 \operatorname{lcm} 的值。

到这一步实际上我们发现如果枚举 i,j 并暴力计算 \operatorname{lcm} ,那么可以得到一个看上去是 O(n^3\log n) 的做法。

Hint 3

对上一步的暴力做法进行些许优化,对每个相同的 r_i 其实是不用重复计算的,所以设环长的种类数为 m ,能做到 O(m^3\log n) 的复杂度。

这里需要进行对数据范围分析,注意到这里是 m 个互不相同的数相加不超过 n ,所以一定有 \sum_{i=1}^m\le n ,计算得出 mO(\sqrt n) 级别的,所以时间复杂度貌似可以优化到 O(n\sqrt n\log n)

Hint 4

但是我们会发现 \operatorname{lcm} 的值可能会很大并不能直接计算,所以需要对每个 r_i 进行质因数分解,最后在每个质因子上计算指数的最大值得出结果,这样单次计算 \operatorname{lcm} 的时间复杂度为 O(m\sqrt n+cnt_{prime}\cdot\log\log n)

这里可以利用线性筛的性质来加速质因数分解的过程,注意到每次线性筛筛到一个数时,一定是被他的最小质因子筛到,于是可以预处理记录下每个数字对应的最小质因子,这样单次质因数分解的时间复杂度可以优化到 O(\log n)

除此之外,我们还可以在每次分解质因数时直接质因子判断对应指数的值,并判断是否乘上去,如果采用 mapset 来记录,这样单次计算 \operatorname{lcm} 的复杂度变成了 O(m\log^2 m) ,总时间复杂度为 O(n\sqrt n\log^2 n)

如果在判断每个质因子对应的最大指数时做到 O(1) 的维护,那么就是 O(n\sqrt n\log n) 的时间复杂度,期望得分 80 ,如果能做到常数优秀且熟练掌握卡常技巧是完全有机会通过此题的。

Hint 5

继续优化求 \operatorname{lcm} 的方法,我们可以预先求出所有数字的 \operatorname{lcm} ,然后记录每个质因子对应的三个最大指数(因为后面要删掉两个数),这样每次删数的时候就可以现场质因数分解 O(\log^2 n) 维护 \operatorname{lcm} ,添加一个数字的时候也能够当场维护。

另外,也可以直接用 multiset 记录每个质因子对应的指数集合,同样也是 O(\log^2 n) 的维护复杂度。

总时间复杂度 O(n\log^2n) ,已经足够通过此题。

#include<bits/stdc++.h>
using namespace std;
#define N 500050
#define LL long long
#define MOD 1000000007
int T,n,a[N],r[N],vis[N];
int cnt,cntR,p[N],v[N],LCM,m,ans;
vector<pair<int,int>>d;
multiset<int>s[N];
LL qow(LL x,LL y){return y?(y&1?x*qow(x,y-1)%MOD:qow(x*x%MOD,y/2)):1;}
void add(int x)
{
    while(x>1){
        int p=v[x],c=0,k=(*s[p].rbegin());
        while(v[x]==p)x/=v[x],c++;
        if(c>k)LCM=1ll*LCM*qow(p,c-k)%MOD;
        s[p].insert(c);
    }
}
void del(int x)
{
    while(x>1){
        int p=v[x],c=0,k;
        while(v[x]==p)x/=v[x],c++;
        s[p].erase(s[p].find(c));
        k=(*s[p].rbegin());
        if(k<c)LCM=1ll*LCM*qow(p,MOD-1+k-c)%MOD;
    }
}
void init()
{
    LCM=1;
    cntR=0;
    d.clear();
    for(int i=1;i<=cnt;i++)s[p[i]].clear(),s[p[i]].insert(0);
    scanf("%d",&n);
    for(int i=1;i<=n;i++){
        scanf("%d",&a[i]);
        r[i]=vis[i]=0;
    }
    for(int i=1;i<=n;i++)if(!vis[i]){
        int x=i,c=1;
        while(a[x]!=i)vis[x]=1,c++,x=a[x];
        add(c);
        vis[x]=1;
        r[++cntR]=c;
    }
    sort(r+1,r+cntR+1);
    for(int i=1;i<=cntR;i++){
        if(d.empty() || r[i]!=d.back().first)
            d.push_back(make_pair(r[i],1));
        else d[d.size()-1].second++;
    }
    m=d.size(),ans=0;
    for(int i=0;i<m;i++){
        del(d[i].first);
            for(int j=i+1;j<m;j++){
            int res=2ll*(1ll*d[i].first*d[i].second%MOD)*(1ll*d[j].first*d[j].second%MOD)%MOD;
            int x=d[i].first+d[j].first;
            del(d[j].first);
            add(x);
            ans=(1ll*res*LCM+ans)%MOD;
            del(x);
            add(d[j].first);
        }
        if(d[i].second>1){
            int res=1ll*(1ll*d[i].first*d[i].second%MOD)*(1ll*d[i].first*(d[i].second-1))%MOD;
            int x=d[i].first+d[i].first;
            del(d[i].first);
            add(x);
            ans=(1ll*res*LCM+ans)%MOD;
            del(x);
            add(d[i].first);
        }
        add(d[i].first);
    }
    printf("%d\n",ans);
}
int main()
{
    v[1]=1;
    for(int i=2;i<N;i++){
        if(!v[i])v[i]=i,p[++cnt]=i;
        for(int j=1;j<=cnt && i*p[j]<N;j++){
            v[i*p[j]]=p[j];
            if(i%p[j]==0)break;
        }
    }
    scanf("%d",&T);
    while(T--)init();
}