题解 P2260 【[清华集训2012]模积和】

· · 题解

题意:给定 n,m,需要在 O(\sqrt n) 的时间复杂度内求得下面式子的值:

\sum_{i=1}^n \sum_{j=1}^m (n \bmod i)(m \bmod j)[i ≠j]

首先简单容斥,假设 n<m,可以将式子化为下面的样子:

\sum_{i=1}^n \sum_{j=1}^m (n \bmod i)(m \bmod j)-\sum_{i=1}^{n}(n \bmod i)(m \bmod i)

两个式子不相关,考虑分别计算。左边一式长得不好看,首先将 (n \bmod i) 一项移出来。原式化为:

\sum_{i=1}^n (n \bmod i) \sum_{j=1}^m (m \bmod j)

根据取模的定义得到:

\sum_{i=1}^n (n-i\times \lfloor \dfrac{n}{i} \rfloor) \times \sum_{i=1}^m (m-i\times \lfloor \dfrac{m}{i} \rfloor)

又将 nm 提出来:

(n^2 - \sum_{i=1}^n i\times \lfloor \dfrac{n}{i} \rfloor) \times (m^2 -\sum_{i=1}^mi\times \lfloor \dfrac{m}{i} \rfloor)

有整除分块的样子了。

考虑整理右边的式子。根据定义式将其化为:

\sum_{i=1}^n(nm-m\times i\times \lfloor \dfrac{n}{i} \rfloor - n \times i \times \lfloor \dfrac{m}{i} \rfloor + \lfloor \dfrac{n}{i} \rfloor \times \lfloor \dfrac{m}{i} \rfloor \times i^2)

也是整除分块的样子。接下来考虑的是 \sum_{i=1}^n i^2 怎么算。

从三次方入手开始降次。写出 (n+1)^3-n^3=3n^2+3n+1。同理可得 n^3-(n-1)^3=3(n-1)^2+3(n-1)+1。继续写下去,会得到 2^3-1^3=3\times 1^3+3\times 1+1。将这 n 个式子加起来,得到 (n+1)^3-1=3 \times (\sum_{i=1}^n i^2)+3\times (\sum_{i=1}^n i)

我们已经知道了 \sum_{i=1}^n i=\dfrac{i \times (i+1)}{2},移项得到 \sum_{i=1}^n i^2=\dfrac{n \times(n+1 \times (2n+1))}{6}

最后处理出 2,6 的逆元。这道题就解决了。

#include<bits/stdc++.h>
using namespace std;
typedef __int128 LL;
char buf[1<<21],*p1=buf,*p2=buf;
#define getchar() (p1==p2 && (p2=(p1=buf)+fread(buf,1,1<<21,stdin),p1==p2)?EOF:*p1++)
LL read()
{
    LL x=0,f=1;
    char c=getchar();
    while(c<'0' || c>'9')
    {
        if(c=='-')  f=-1;
        c=getchar();
    }
    while(c>='0' && c<='9') x=(x<<1)+(x<<3)+(c^'0'),c=getchar();
    return x*f;
}
void write(LL x)
{
    if(x<0) putchar('-'),x=-x;
    if(x>9) write(x/10);
    putchar(x%10+'0');
}
const LL MOD=19940417,inv2=9970209,inv6=3323403;
LL line(LL p){return p*(p+1)/2;}
LL square(LL p){return p*(p+1)*(2*p+1)/6;}
LL n,m;
int main(){
    n=read(),m=read();
    if(n>m) n^=m^=n^=m;
    LL a=0;
    for(LL l=1,r;l<=n;l=r+1)
    {
        r=n/(n/l);
        a+=(n/l)*(line(r)-line(l-1));
    }
    a=n*n-a;
    LL b=0;
    for(LL l=1,r;l<=m;l=r+1)
    {
        r=m/(m/l);
        b+=(m/l)*(line(r)-line(l-1));
    }
    b=m*m-b;
    LL c=n*m*n;
    LL k=0;
    for(LL l=1,r;l<=n;l=r+1)
    {
        r=min(n/(n/l),m/(m/l));
        r=min(r,n);
        k+=(n/l)*(line(r)-line(l-1));
    }
    c-=k*m;
    k=0;
    for(LL l=1,r;l<=n;l=r+1)
    {
        r=min(n/(n/l),m/(m/l));
        r=min(r,n);
        k+=(m/l)*(line(r)-line(l-1));
    }
    c-=k*n;
    k=0;
    for(LL l=1,r;l<=n;l=r+1)
    {
        r=min(n/(n/l),m/(m/l));
        r=min(r,n);
        k+=(n/l)*(m/l)*(square(r)-square(l-1));
    }
    c+=k;
    LL res=a*b-c;
    res%=MOD;
    res+=MOD;
    res%=MOD;
    write(res);
    return 0;
}