题解 P6300 【悔改】
z7z_Eta
2020-05-18 11:45:23
只能靠一直写各位神仙都不屑于做的谔谔题~~来荒废度日~~
数学公式错了可以[在luogu博客看](https://www.luogu.com.cn/blog/z7z-Eta/solution-p6300)
---
任意砍成两段可以认为是一种类型的生成函数的**2次方**
$s_k = \lfloor\{\sum_{i+j=k}min(a_i,a_j)\ \}/2\rfloor$
$a$是$cnt$数组, 除$2$取整已经包含了相同大小的两个拼起来的情况
答案就是$s$最大值, 暴力求复杂度是$O(m^2)$的
---
正解是一类**取min类型的卷积**到普通乘法卷积的转化
$\sum_{i+j=k}min(a_i,a_j) = \sum_{d=1}\sum_{i+j=k}[a_i\ge d][a_j\ge d]$
这样**枚举$d$**, 每次卷积就可以求了
也就是说$F_d(x)=\sum[a_i\ge d]x^i$
$s_k=\frac{1}{2}\sum_{d=1}\ [x^k]F_d^2(x)$
---
我们发现, 显然本质不同的$F_d$是$O(\sqrt{n})$级别的, 因为最坏情况是$1+2+3+4+\dots=O(D^2)$
所以**离散化$a$数组**直接卷就好了, 复杂度$O(m\sqrt{n}\log m)$, 常数不大
虽然比官方题解看起来~~屑了不少~~, 但是也许更好理解一些^QwQ^.
```cpp
// Etavioxy
#define ll long long
#define il inline
#define rep(i,s,t) for(register int i=(s);i<=(t);i++)
#define rev_rep(i,s,t) for(register int i=(s);i>=(t);i--)
#define pt(x) putchar(x)
using namespace std;
il int ci(){ register char ch; int f=1; while(!isdigit(ch=getchar())) f=ch=='-'?-1:1; register int x=ch^'0'; while(isdigit(ch=getchar())) x=(x*10)+(ch^'0'); return f*x; }
template<typename T> il bool chk_max(T&x,T y){ return (x<y?(x=y,1):0); }
const int mod = 998244353;
enum{N=270023};
il ll qpow(ll a,ll b){
ll ans= 1; for(; b; b>>=1,a=a*a%mod) if( b&1 ){ ans= ans*a%mod; } return ans;
}
il ll inv(ll x){ return qpow(x,mod-2); }
const ll g0 = 3;
const ll invg0 = inv(g0);
int rev[N],init_rev_n;
il void init_rev(int n){
if( n==init_rev_n ) return;
int t = log2(n);
init_rev_n = n;
rep(i,1,n-1) rev[i] = (rev[i>>1]>>1)|((i&1)<<(t-1));
}
int NTT(ll*A,int n,int IDFT){
init_rev(n);
rep(i,1,n-1) if( i<rev[i] ) swap(A[i],A[rev[i]]);
for(int t=1;t<n;t<<=1){
ll w0 = qpow(IDFT==-1?invg0:g0,(mod-1)/(t*2));
for(int i=0;i<n;i+=t*2){
for(register int j=0,w=1;j<t;j++,w=w*w0%mod){ ll x=A[i+j], y=A[i+t+j]; A[i+j] = (x+w*y)%mod; A[i+t+j] = (x-w*y)%mod; }
}
}
if( IDFT==-1 ){ ll invn = inv(n); rep(i,0,n-1) A[i] = A[i]*invn%mod; } return n;
}
il int set2(int n){ return 1<<((int)log2(n-1)+1); } // 原代码的模板都进行了压行处理
ll t[N],s[N];
int a[N],b[N];
int main(){
int n = ci(), m = ci();
rep(i,1,n) a[ci()]++;
rep(i,1,m) b[i] = a[i];
sort(b,b+m+1);
int T = unique(b,b+m+1)-b-1;
rep(d,1,T){
rep(i,1,m*2) t[i] = a[i]>=b[d];
int L = set2(m*2);
NTT(t,L,1);
rep(i,0,L-1) t[i] = t[i]*t[i]%mod;
NTT(t,L,-1);
rep(i,1,m*2) s[i] += 1ll*(b[d]-b[d-1])*((t[i]+mod)%mod);
}
ll ans = 0, ans2 = 0;
rep(i,1,m*2) if( chk_max(ans,s[i]/2) ) ans2 = i;
printf("%lld %lld\n",ans,ans2);
return 0;
}
```
思路懂了之后, 写起来的确是道`easy-implement`题(。﹏。)