题解:AT_abc392_g [ABC392G] Fine Triplets

· · 题解

题意

给你一个大小为 N 的集合 S,求满足 A<B<CB-A=C-BA,B,C\in S 的三元组 (A,B,C) 个数。

思路

设集合中某个数 x 出现的次数为 cnt_x,考虑枚举 BB-A,则答案为:

\sum_{i=1}^n\sum_{j=1}^{s_i}cnt_{s_i-j}cnt_{s_i+j}

其中 i=Aj=B-A

f_x=\sum_{i=1}^xcnt_{x-i}cnt_{x+i},则原式可以转化为

\sum_{i=1}^nf_{s_i}

我们规定 cnt_0=0,然后考虑化简 f_x,有

\begin{aligned} f_x&=\sum_{i=1}^{x-1}cnt_icnt_{2x-i}\\ &=\left(\sum_{i=0}^xcnt_icnt_{2x-i}\right)-cnt_x^2\\ &=\frac{\sum_{i=0}^{2x}cnt_icnt_{2x-i}}{2}-cnt_x^2 \end{aligned}

不难发现 \sum_{i=0}^{2x}cnt_icnt_{2x-i} 是一个卷积的形式,带入原式得

\left(\frac{1}{2}\sum_{i=1}^n\sum_{j=0}^{2a_i}cnt_jcnt_{2a_i-j}\right)-n

因为任意一个在集合中出现的元素 x 都有 cnt_x=1,故 cnt_x^2=1,所以最终算出来再统一减去 n 即可。

f(x)=\sum_{i=0}^{10^6}x^icnt_xg(x)=\sum_{i=0}^{10^6}x^icnt_x,计算出 f*g 的系数再统计答案即可。这里使用NTT来达到 O(n\log n) 的时间复杂度。

代码

#include<bits/stdc++.h>
#define ll long long
using namespace std;
const int N=1e6+5,mod=998244353,g=3;//998244353=(2^23)*7*17
int fpow(int n,int k=mod-2)//快速幂
{
    int res=1,base=n;
    for(;k>0;k>>=1)
    {
        if((k&1)==1)res=1ll*res*base%mod;
        base=1ll*base*base%mod;
    }
    return res;
}
const int invg=fpow(g);
int rev[N<<2];
void NTT(int*a,int len,bool inv) //NTT
{
    for(int i=0;i<len;i++)
        if(i<rev[i])swap(a[i],a[rev[i]]);
    for(int mid=1;mid<len;mid<<=1)
    {
        int wn=fpow(inv?invg:g,(mod-1)/(mid<<1));
        for(int i=0;i<len;i+=mid<<1)
        {
            int wk=1;
            for(int j=i;j<i+mid;j++,wk=1ll*wk*wn%mod)
            {
                int x=a[j],y=1ll*wk*a[j+mid]%mod;
                a[j]=(x+y)%mod,a[j+mid]=(x-y+mod)%mod;
            }
        }
    }
}
void times(int n,int m,int*a,int*b) //计算多项式相乘的系数
{
    int len=1,cnt=0;
    while(len<=n+m)len<<=1,cnt++;
    for(int i=1;i<len;i++)rev[i]=(rev[i>>1]>>1)|((i&1)<<cnt-1);
    NTT(a,len,false),NTT(b,len,false);
    for(int i=0;i<len;i++)a[i]=1ll*a[i]*b[i]%mod;
    NTT(a,len,true);
    for(int i=0;i<len;i++)a[i]=1ll*a[i]*fpow(len)%mod;
}
int a[N<<2],b[N<<2],x[N];
signed main()
{
    ios::sync_with_stdio(false);
    cin.tie(0),cout.tie(0);
    int n;
    cin>>n;
    for(int i=1;i<=n;i++)cin>>x[i],a[x[i]]++,b[x[i]]++;
    times(1e6,1e6,a,b);
    ll ans=0;
    for(int i=1;i<=n;i++)ans+=a[2*x[i]]; //统计答案
    cout<<(ans-n)/2;
    return 0;
}