题解 P6624 【[2020 年联考 A 卷] 作业题】

· · 题解

--- 此次省选我唯一切掉的题目(并且在不 FST 的情况下)(捂脸 --- 首先这题非常的 “二合一”。前面的化式子和后面矩阵树定理的使用基本上割裂开了,那我也就分开来讲。 ## Part 1:化简式子 后面一坨 $\gcd$ 需要反演掉,比较懒的话直接套这个众所周知的式子就行: $$\varphi*1=\mathrm{id}$$ 然后把 $\varphi$ 提到前面来,方便之后使用矩阵树定理。 具体推式子过程如下: $$\begin{aligned} \text{answer}&=\sum\limits_{T}\left(\sum\limits_{i=1}^{n-1}w_{e_i}\right)\times\gcd(w_{e_1},\cdots,w_{e_{n-1}}) \\ &=\sum\limits_{T}\left(\sum\limits_{i=1}^{n-1}w_{e_i}\right)\times\sum\limits_{d\mid w_{e_1},\cdots,d\mid w_{e_{n-1}}}\varphi(d) \\ &=\sum\limits_{d=1}^{mx}\varphi(d)\sum\limits_{T\atop d\mid w_{e_1},\cdots,d\mid w_{e_{n-1}}}\left(\sum\limits_{i=1}^{n-1}w_{e_i}\right) \end{aligned}$$ (其中 $mx$ 为 $w_i$ 的最大值。) --- ## Part 2:矩阵树定理的使用 外层枚举 $d$,里层那一坨就需要矩阵树定理来处理了。 求所有生成树的边权和是被出烂的老经典题了。。。 首先考虑 $w_i$ 都相同的情况:直接套模板即可,最后答案乘上 $w_i$ 完事。 有不同的情况就先考虑一种最裸的做法:枚举一条边然后求包含这条边的生成树个数。 有木有方法把所有边的答案的和经过求一次行列式就全部算出来? 还真有,矩阵每个位置放个一次函数(一次多项式)即可。 一条边的贡献珂以写成 $w_ix+1$,最后答案就是行列式的一次项系数。 撕烤一下为什么? 答案实际上是求 钦定一条边之后的生成树个数$\times$这条边的边权 之和,那么答案里被乘上一次项系数的边就是被钦定的边。 下面就是一下多项式操作了:$a+bx$ 和 $c+dx$ 的四则运算规则 加减就直接对应项项加/减,不说了。 乘法:$(a+bx)(c+dx)=ac+(ad+bc)x$。 除法:$\dfrac{a+bx}{c+dx}=\dfrac{a}{c}+\dfrac{bc-ad}{c^2}x

对于枚举的每个 d 都求一遍行列式的话复杂度是 \mathcal{O}(n^3mx),不知道能不能过。

不过珂以优化,只对可选边数大于等于 n-1d 求行列式即可,这样复杂度就优化到了\mathcal{O}(n^2\sum\sigma_0(w_i)) ,上界是 144\times n^4,实际完全跑不满。

注:这个 n^2\sum\sigma_0(w_i) 其实就是 n^3\frac{\sum\sigma_0(w_i)}{n-1}

code:

#include<iostream>
#include<cstring>
#include<cstdio>
using namespace std;
const int mod=998244353;
#define N 33
#define M 200020
inline int read(){
    int x=0,f=1;
    char c=getchar();
    while(c<'0'||c>'9'){
        if(c=='-')f=-1;
        c=getchar();
    }
    while(c>='0'&&c<='9'){
        x=(x<<1)+(x<<3)+c-'0';
        c=getchar();
    }
    return x*f;
}
typedef pair<int,int> pii;
int n,m,ans,h[M];
int qpow(int a,int b){
    int ans=1;
    while(b){
        if(b&1)ans=1LL*ans*a%mod;
        a=1LL*a*a%mod;
        b>>=1;
    }
    return ans;
}
pii g[N][N]; //(常数项系数,一次项系数)
pii operator +(const pii a,const pii b){
    return make_pair((a.first+b.first)%mod,(a.second+b.second)%mod);
}
pii operator -(const pii a,const pii b){
    return make_pair((a.first-b.first+mod)%mod,(a.second-b.second+mod)%mod);
}
pii operator *(const pii a,const pii b){
    return make_pair(1LL*a.first*b.first%mod,(1LL*a.first*b.second+1LL*a.second*b.first)%mod);
}
pii operator /(const pii a,const pii b){
    int inv=qpow(b.first,mod-2);
    return make_pair(1LL*a.first*inv%mod,(1LL*a.second*b.first%mod-1LL*a.first*b.second%mod+mod)%mod*inv%mod*inv%mod);
}
int x[N*N],y[N*N],w[N*N],mx;
int p[M],pn,phi[M];
bool ntp[M];
void init(int n){
    phi[1]=ntp[1]=1;
    for(int i=2;i<=n;++i){
        if(!ntp[i])p[++pn]=i,phi[i]=i-1;
        for(int j=1;j<=pn&&p[j]*i<=n;++j){
            ntp[p[j]*i]=true;
            if(i%p[j]==0){
                phi[p[j]*i]=phi[i]*p[j];
                break;
            }
            phi[p[j]*i]=phi[i]*(p[j]-1);
        }
    }
}
pii Guass(int n){
    pii ans=make_pair(1,0);
    bool rev=false;
    for(int i=1;i<=n;++i){
        if(!g[i][i].first){
            for(int j=i+1;j<=n;++j){
                if(g[j][i].first){
                    rev^=1;
                    swap(g[i],g[j]);
                    break;
                }
            }
        }
        pii inv=make_pair(1,0)/g[i][i];
        for(int j=i+1;j<=n;++j){
            pii div=g[j][i]*inv;
            for(int k=i;k<=n;++k){
                g[j][k]=g[j][k]-div*g[i][k];
            }
        }
        ans=ans*g[i][i];
    }
    return rev?make_pair(0,0)-ans:ans;
}
int Solve(int t){
    memset(g,0,sizeof(g));
    for(int i=1;i<=m;++i){
        if(w[i]%t)continue;
        g[x[i]][x[i]]=g[x[i]][x[i]]+make_pair(1,w[i]);
        g[y[i]][y[i]]=g[y[i]][y[i]]+make_pair(1,w[i]);
        g[x[i]][y[i]]=g[x[i]][y[i]]-make_pair(1,w[i]);
        g[y[i]][x[i]]=g[y[i]][x[i]]-make_pair(1,w[i]);
    }
    return Guass(n-1).second;
}
void check(int x){
    for(int i=1;i*i<=x;++i){
        if(x%i==0){
            ++h[i];
            if(i*i!=x)++h[x/i];
        }
    }
}
int main(){
    freopen("count.in","r",stdin);
    freopen("count.out","w",stdout);
    n=read(),m=read();
    for(int i=1;i<=m;++i){
        x[i]=read(),y[i]=read(),w[i]=read();
        check(w[i]);
        mx=max(mx,w[i]);
    }
    init(mx);
    for(int i=1;i<=mx;++i){
        if(h[i]<n-1)continue;
        ans=(ans+1LL*phi[i]*Solve(i))%mod;
    }
    cout<<ans<<endl;
    return 0;
}

Froggy's blog

呱!!