A.solution

· · 题解

题解

一句话总结:阈值分块,整除分块。

观察数据范围发现 \sum\limits_{i=1}^{n}a_i 只有 10^9,我们设 A=\sum\limits_{i=1}^{n}a_i,在 B 把数分为两块,B_1\le B,B_2>B

Part 1:B_1\times B_1

开个桶存一下每个数有多少个,B^2 枚举任意两个数算方案数即可。时间复杂度 \Theta(B^2)

Part 2:B_2\times B_2

发现 B_2 最多只有 \frac{A}{B} 个数,又发现答案范围为:\frac{m}{B^2}

那么我们枚举答案,再枚举 B_2 中的数,根据整除分块的知识即可求出另一个数的范围。

另一个数最大是 \frac{m}{B},前缀和存一下就行。

时间复杂度 \Theta(\frac{A\times m}{B^3})

Part 3:B_1\times B_2

这里我们把 B_1 中的数按 Q 分为 Q_1\le Q,Q_2>Q

对于 Q_1,B_2,暴力枚举 Q_1 中数的大小,再枚举 B_2 中的每一个数,通过桶计算方案数即可,不过不要忘记对贡献 \times 2,因为每对数正反都会贡献一遍。时间复杂度 \Theta(\frac{A\times Q}{B})

对于 Q_2,B_2,会发现答案范围为:\frac{m}{Q\times B} 依然枚举答案,枚举 B_2 中的数,根据整除分块算出范围,前缀和算方案数即可,这里和上一种一样,贡献也需要 \times 2。时间复杂度 \Theta(\frac{A\times m}{Q\times B^2})

在通过整除分块算边界时有些细节需要稍微注意。

整体下来复杂度是 \Theta(B^2+\frac{A\times m}{B^3}+\frac{A\times Q}{B}+\frac{A\times m}{Q\times B^2})。稍微想一下发现 B=10^4,Q=10^3 的时候算出来大概是 3\times 10^8,可以通过。

验题人对于计算 $B_1,B_2$ 有一个新的想法:不难发现 $B_2$ 中其实只有 $45000$ 左右个不同的数,那么暴力枚举大概在 $4\times 10^8$,一个是数据不是很极端,再一个是有 $4$s,卡卡常差不多能够通过。 ## Code 1 出题人的做法。 ```cpp #include<bits/stdc++.h> using namespace std; #define int long long #define pii pair<int,int> #define mk(x,y) make_pair(x,y) const int M=1e6+10,N=1e4+10,mod=998244353; int n,m,gm; int a[M]; int t[M]; int s[M]; vector<int>b; int ans=0; pii ct[M]; int read() { int x = 0, w = 1; char ch = 0; while (ch < '0' || ch > '9') { if (ch == '-') w = -1; ch = getchar(); } while (ch >= '0' && ch <= '9') { x = x * 10 + (ch - '0'); ch = getchar(); } return x * w; } signed main(){ freopen("temp.in","r",stdin); freopen("temp.out","w",stdout); cin>>n>>m; gm=ceil(sqrt(m)); for(int i=1;i<=n;i++){ a[i]=read(); if(a[i]<=1e6)t[a[i]]++,s[a[i]]++; if(a[i]>1e4)b.push_back(a[i]); } sort(a+1,a+n+1); for(int i=1;i<M;i++)s[i]+=s[i-1]; for(int i=1;i<=1e4;i++){ for(int j=1;j<=1e4;j++){ ans=(ans+m/(i*j)*t[i]%mod*t[j]%mod)%mod; } } int l=1,r; while(l<=m){ r=m/(m/l); if(m/l<=gm)ct[m/l]=mk(l,r); l=r+1; } for(int i=1;i<=m/1e8;i++){ for(int j:b){ int l=ct[i].first,r=ct[i].second; l=ceil(1.0*l/j),r=r/j; if(l>r||r<=1e4)continue; l=max(l,(int)1e4+1); int nex=s[r]-s[l-1]; ans+=nex*i; if(ans>mod)ans%=mod; } } for(int i=1;i<=1e3;i++){ for(int j:b){ ans+=m/(i*j)*t[i]*2; if(ans>mod)ans%=mod; } } for(int i=1;i<=m/1e7;i++){ for(int j:b){ int l=ct[i].first,r=ct[i].second; l=ceil(1.0*l/j),r=r/j; if(l>r||r<=1e3||l>1e4)continue; if(r>1e4)r=1e4; if(l<=1e3)l=1e3+1; int nex=s[r]-s[l-1]; ans+=nex*i*2; if(ans>mod)ans%=mod; } } cout<<ans<<endl; return 0; } ``` ## Code 2 验题人的做法。 ```cpp #include <bits/stdc++.h> typedef long long LL; const int M = 1e6 + 10, B = 1500; const int mod = 998244353; const LL Max_m = 1e10; LL m, ans; int n, a[M], pre[Max_m / B + 10]; std::vector<std::pair<int, int> > S, S0, S1; inline void Init() { std::map<int, int> mp; for (int i = 1; i <= n; i++) if (a[i] <= B) ++mp[a[i]]; for (auto x : mp) S0.emplace_back(x); std::sort(S0.begin(), S0.end()); for (int i = 1; i <= n; i++) if (a[i] > B) ++mp[a[i]]; for (auto x : mp) S.emplace_back(x); std::sort(S.begin(), S.end()); mp.clear(); for (int i = 1; i <= n; i++) if (a[i] > B) ++mp[a[i]]; for (auto x : mp) S1.emplace_back(x); std::sort(S1.begin(), S1.end()); } int main() { scanf("%d%lld", &n, &m); for (int i = 1; i <= n; i++) scanf("%d", &a[i]); Init(); for (auto [x, a] : S) { LL v = m / x; for (auto [y, b] : S0) { if (x < y) break; if (1ll * x * y > m) break; if (x == y) ans += 1ll * a * b % mod * (v / y % mod); else ans += 2ll * a * b % mod * (v / y % mod); ans %= mod; } } for (auto [x, a] : S1) if (x <= m / B) pre[x] += a; for (int i = B; i <= m / B; i++) pre[i] += pre[i - 1]; for (LL l = 1, r; l <= m; l = r + 1) { r = m / (m / l); if (m / l <= m / B / B) for (auto [x, a] : S1) { LL p = m / l; LL _l = l / x + (bool)(l % x), _r = r / x; // m / (x * y) = p // _l <= y <= _r if (_r >= _l) { ans += 1ll * pre[_r] * p % mod * a % mod; if (_l > 0) ans += mod - 1ll * pre[_l - 1] * p % mod * a % mod; ans %= mod; } } } printf("%lld\n", ans); return 0; } ```