题解 P6300 【悔改】

z7z_Eta

2020-05-18 11:45:23

Solution

只能靠一直写各位神仙都不屑于做的谔谔题~~来荒废度日~~ 数学公式错了可以[在luogu博客看](https://www.luogu.com.cn/blog/z7z-Eta/solution-p6300) --- 任意砍成两段可以认为是一种类型的生成函数的**2次方** $s_k = \lfloor\{\sum_{i+j=k}min(a_i,a_j)\ \}/2\rfloor$ $a$是$cnt$数组, 除$2$取整已经包含了相同大小的两个拼起来的情况 答案就是$s$最大值, 暴力求复杂度是$O(m^2)$的 --- 正解是一类**取min类型的卷积**到普通乘法卷积的转化 $\sum_{i+j=k}min(a_i,a_j) = \sum_{d=1}\sum_{i+j=k}[a_i\ge d][a_j\ge d]$ 这样**枚举$d$**, 每次卷积就可以求了 也就是说$F_d(x)=\sum[a_i\ge d]x^i$ $s_k=\frac{1}{2}\sum_{d=1}\ [x^k]F_d^2(x)$ --- 我们发现, 显然本质不同的$F_d$是$O(\sqrt{n})$级别的, 因为最坏情况是$1+2+3+4+\dots=O(D^2)$ 所以**离散化$a$数组**直接卷就好了, 复杂度$O(m\sqrt{n}\log m)$, 常数不大 虽然比官方题解看起来~~屑了不少~~, 但是也许更好理解一些^QwQ^. ```cpp // Etavioxy #define ll long long #define il inline #define rep(i,s,t) for(register int i=(s);i<=(t);i++) #define rev_rep(i,s,t) for(register int i=(s);i>=(t);i--) #define pt(x) putchar(x) using namespace std; il int ci(){ register char ch; int f=1; while(!isdigit(ch=getchar())) f=ch=='-'?-1:1; register int x=ch^'0'; while(isdigit(ch=getchar())) x=(x*10)+(ch^'0'); return f*x; } template<typename T> il bool chk_max(T&x,T y){ return (x<y?(x=y,1):0); } const int mod = 998244353; enum{N=270023}; il ll qpow(ll a,ll b){ ll ans= 1; for(; b; b>>=1,a=a*a%mod) if( b&1 ){ ans= ans*a%mod; } return ans; } il ll inv(ll x){ return qpow(x,mod-2); } const ll g0 = 3; const ll invg0 = inv(g0); int rev[N],init_rev_n; il void init_rev(int n){ if( n==init_rev_n ) return; int t = log2(n); init_rev_n = n; rep(i,1,n-1) rev[i] = (rev[i>>1]>>1)|((i&1)<<(t-1)); } int NTT(ll*A,int n,int IDFT){ init_rev(n); rep(i,1,n-1) if( i<rev[i] ) swap(A[i],A[rev[i]]); for(int t=1;t<n;t<<=1){ ll w0 = qpow(IDFT==-1?invg0:g0,(mod-1)/(t*2)); for(int i=0;i<n;i+=t*2){ for(register int j=0,w=1;j<t;j++,w=w*w0%mod){ ll x=A[i+j], y=A[i+t+j]; A[i+j] = (x+w*y)%mod; A[i+t+j] = (x-w*y)%mod; } } } if( IDFT==-1 ){ ll invn = inv(n); rep(i,0,n-1) A[i] = A[i]*invn%mod; } return n; } il int set2(int n){ return 1<<((int)log2(n-1)+1); } // 原代码的模板都进行了压行处理 ll t[N],s[N]; int a[N],b[N]; int main(){ int n = ci(), m = ci(); rep(i,1,n) a[ci()]++; rep(i,1,m) b[i] = a[i]; sort(b,b+m+1); int T = unique(b,b+m+1)-b-1; rep(d,1,T){ rep(i,1,m*2) t[i] = a[i]>=b[d]; int L = set2(m*2); NTT(t,L,1); rep(i,0,L-1) t[i] = t[i]*t[i]%mod; NTT(t,L,-1); rep(i,1,m*2) s[i] += 1ll*(b[d]-b[d-1])*((t[i]+mod)%mod); } ll ans = 0, ans2 = 0; rep(i,1,m*2) if( chk_max(ans,s[i]/2) ) ans2 = i; printf("%lld %lld\n",ans,ans2); return 0; } ``` 思路懂了之后, 写起来的确是道`easy-implement`题(。﹏。)