题解:AT_arc182_c [ARC182C] Sum of Number of Divisors of Product

· · 题解

[ARC182C] Sum of Number of Divisors of Product 题解

题意

给你 N,M 让你找对于所有长度 \le N,值域在 [1,M] 之间的序列的乘积求因子个数再求和。

分析

看到 m\le 16 先找出所有质数 2,3,5,7,11,13,一共有 6 个,考虑状压也只有 64

进一步发现 n 巨大,但是状压状态很少,可以想到矩阵快速幂。

考虑怎么构造原 DP 序列和矩阵进行转移:

发现如果设“某一状态为某一质数有没有”无法转移,所以我们先令第 i(0\le i<6) 个质数的编号为 i,个数为 a_ia_i=0 表示没有这个质数),则显然一个序列的答案为 (a_0+1)(a_1+1)(a_2+1)(a_3+1)(a_4+1)(a_5+1) 而如果我们在序列末尾加入一个 6 则答案就变成 (a_0+2)(a_1+2)(a_2+1)(a_3+1)(a_4+1)(a_5+1)

根据 c(a+x)(b+y)=c(ab+ay+bx+xy),设 c=(a_2+1)(a_3+1)(a_4+1)(a_5+1) 原式子即为:

(a_0+2)(a_1+2)c=(a_0+1)(a_1+1)c+(a_0+1)c+(a_1+1)c+c

所以我们设某一状态 f_S 表示 (a_i+1)(a_j+1)(a_k+1)\cdots(a_l+1),(i,j,k,\dots,l\in S) 的和。

那么转移就很简单了,我们枚举每一位和 [1,m] 间的每一位数字,来看看对于原序列的系数是多少,然后给转移矩阵增加即可,因为一个数字最多有两个不同的质数,找两个变量记录一下即可。

式子形如:f_{S}=f_S+f_{S-\{i\}}+f_{S-\{j\}}+f_{S-\{i\}-\{j\}}

注意 12=2\times 2\times 3,所以所带系数不一定为 1,读者可试着自己推一下便于理解。

初始状态 f_S=1,0\le S <64,长度为 len 时答案为 f_{\{0,1,2,3,4,5\}}

代码

#include<bits/stdc++.h>
#define int long long
using namespace std;
inline int read(){
    int x=0,f=1;char c=getchar();
    while(c<'0'||'9'<c){if(c=='-')f=-1;c=getchar();}
    while('0'<=c&&c<='9'){x=(x<<1)+(x<<3)+c-'0';c=getchar();}
    return x*f;
}
const int N=66,mod=998244353;
struct Matrix{
    int n,m,num[N][N];
    Matrix(){
        n=m=0;
        memset(num,0,sizeof(num));
    }
}a,b;
Matrix operator*(const Matrix &x,const Matrix &y){
    Matrix c;c.n=x.n;c.m=y.m;
    for(int i=0;i<x.n;i++)
        for(int j=0;j<y.m;j++)
            for(int k=0;k<x.m;k++) 
                (c.num[i][j]+=x.num[i][k]*y.num[k][j]%mod)%=mod;
    return c;
}
Matrix ksm(Matrix a,int b){
    Matrix t;
    t.n=t.m=a.n;
    for(int i=0;i<a.n;i++)t.num[i][i]=1;
    for(;b;b>>=1,a=a*a)if(b&1)t=t*a;
    return t;
}
int n,m;
int id[17];
signed main(){
    // freopen(".in","r",stdin);
    // freopen(".out","w",stdout);
    id[2]=1;id[3]=2;id[5]=3;id[7]=4;id[11]=5;id[13]=6;
    n=read();m=read();
    a.n=1,a.m=65;
    for(int i=0;i<64;i++)a.num[0][i]=1;
    b.n=b.m=65;
    b.num[64][64]=b.num[63][64]=1;
    for(int i=0;i<64;i++)
        for(int k=1;k<=m;k++){
            int t=k,t1=-1,t2=-1,s1=0,s2=0;
            for(int l=2;l<=t;l++)
                if(t%l==0){
                    if(~t1){t2=id[l]-1;while(t%l==0)t/=l,s2++;}
                    else{t1=id[l]-1;while(t%l==0)t/=l,s1++;}
                }
            b.num[i][i]++;
            if(t1>=0&&(i>>t1&1))b.num[i^(1<<t1)][i]+=s1;
            if(t2>=0&&(i>>t2&1))b.num[i^(1<<t2)][i]+=s2;
            if(t1>=0&&(i>>t1&1)&&t2>=0&&(i>>t2&1))b.num[i^(1<<t1)^(1<<t2)][i]+=s1*s2;
        }
    a=a*ksm(b,n+1);
    printf("%lld\n",a.num[0][64]-1);
    return 0;
}