题解:P10356 [PA2024] Splatanie ciągów

· · 题解

先考虑给定 2 个单调递增的数组 A,B,怎么求 f(A,B)

不妨设 A 长度为 |A|=nB 长度为 |B|=mn\ge m

那么我们考虑往 A 中插入 B,假如插入 xa_i,a_{i+1} 之间,可以发现 a_i\lt x x\lt a_{i+1} 至少满足 1 个,假如 2 个都满足,那么我们可以交换 xa_i,这样一定不劣。

所以我们可以看做用 m 个数,把 A 分成了 m+1 段,但因为插入每个数都会满足一个递增条件,那么就是把 n+m 个数分成了 m+1 段,也就是说我们构造出的答案为 \lceil \frac {n+m}{m+1}\rceil,同时可以发现这就是答案的下界

因此,有 f(A,B)=\lceil \frac {n+m}{m+1}\rceil

也就是说,对于长度为 n 的单调序列,要使答案不超过 x,那么 m 要满足 \lceil \frac {n+m}{m+1}\rceil \le x,解得 m\ge \lceil \frac {n-x}{x-1}\rceil

现在令 a_i 表示 A 中的极长单调区间的长度,例如 A=[1,2,3,4,8,7,6,5],有极长单调区间 [1,2,3,4,8][8,7,6,5],所以 a=[5,4]

考虑一般情况的 $f(A,B)$,同样考虑将短的序列往长的序列里面插入,那么可以得到 $f(A,B)$ 为最小的整数 $x$,满足 $\sum \lceil \frac {a_i-x}{x-1}\rceil\le |B|,\sum \lceil \frac {b_i-x}{x-1}\rceil\le |A| $,两个不等式分别对应 $B$ 短和 $A$ 短的情况。 题目要求的是恰好,可以转变为对于每个 $i$ 求有多少 $f(A',B')\le i$。 可以发现对于任意 $x$,两个式子至少满足一个,因此可以分开求答案,最后减去总方案数,就能得到两个式子都满足的方案数。 假如求出了 $\sum \lceil \frac {a_i-x}{x-1}\rceil=y$,那就是求 $B$ 有多少个长度大等于 $y$ 的区间,这是一个关于 $|B|=m$ 的 $2$ 次多项式。 **注意仅 $1\le y\le m$ 时满足,$y=0$ 可能要特殊处理,$y\gt m$ 要避免计算**。 先整理一下式子,**化为下取整,并用 $x$ 取代 $x-1$** 得到 $y=\sum \lfloor \frac {a_i-2}{x}\rfloor$。 要求所有子区间的答案,枚举每个 $x$,考虑到 $a_i-2\lt x$ 的 $a_i$ 是没有贡献的,也就是说对于 $x$,只有 $O(\frac n x)$ 个有用的 $a_i$。 只要能在 $\widetilde O(\frac n x)$ 的复杂度内求出答案就行了。 方法不少,但似乎都比较繁琐,这里说一下我的做法。 $a_i-2$ 相当于去掉左右端点,可以先去掉右端点,然后把极长区间长度看成 $r_i-l_i$。 然后可以先处理出以 $1$ 为左端点时,答案发生变化的右端点,然后再枚举答案发生变化的左端点,**总结就是每经过一个关键端点答案 $+1$,把除法解决掉**,用双指针维护 $\sum \lfloor \frac {r_i-l_i}{x}\rfloor\le m$ 的最右位置,然后将答案暴力展开成多项式,维护每一项的系数和。 注意到无用的 $a_i$ 相当于是一个空区间,也就是说我们每次求的是一个空区间和实区间交替的区间序列的答案,可能要分几种情况处理,但大体都是类似的。 这样每次是 $O(\frac n x)$ 的,时间复杂度 $O(n\ln n)$。 给一下参考代码吧,但细节比较多所以代码比较恶心,而且都是展开后的式子估计也看不懂,所以建议自己先写暴力然后一点点优化。 参考代码: ```cpp #include<bits/stdc++.h> using namespace std; const int N=6e5+5,P=1e9+7,i2=P+1>>1; typedef long long ll; typedef pair<int,int> pii; int n,m,a[N],b[N],ans[N],s0[N],s1[N],s2[N]; ll va[N]; vector<pii>A,B; void solve(vector<pii>&A,int x,int n,int m){ if(A.empty()){ ans[x]=(ans[x]+1ll*n*(n+1)/2%P*(1ll*m*(m-1)%P))%P; return; } vector<int>tr; for(auto [l,r]:A){ int u=l+x; while(u<=r)tr.push_back(u-1),u+=x; } tr.push_back(n); for(int i=1;i<tr.size();++i){ int w=tr[i]-tr[i-1]; s0[i]=(s0[i-1]+w)%P; s1[i]=(s1[i-1]+1ll*(P-i)*w)%P; s2[i]=(s2[i-1]+1ll*i*i%P*w)%P; } for(int i=0,j=0;i<A.size();++i){ int ls=1,l=A[i].first,r=A[i].second; if(r==n)break; if(i>0)ls=A[i-1].second+1; int u=r-x; vector<int>tl; while(u>=l)tl.push_back(x),u-=x; tl.push_back(u+x-ls+1); ls=r; while(tr[j]<=r)++j; for(int k=0;k<tl.size();++k){ int ct=max(1,k),sum=0,o=m-k+j; if(ct<m)sum=(sum+1ll*(m-ct)*(m-ct+1)%P*(tr[j]-r))%P; int z=min(m-k+j-1,(int)tr.size()-1); if(z>j)sum=(sum+(1ll*o*o+o)%P*(s0[z]-s0[j]+P)+(2ll*o+1)*(s1[z]-s1[j]+P)+s2[z]-s2[j]+P)%P; ans[x]=(ans[x]+1ll*tl[k]*sum)%P; if(k==m-1)break; } } ll cnt=1ll*(A[0].first-1)*A[0].first/2; for(int i=0;i<A.size();++i){ int l=A[i].second+1,r=n; if(i+1<A.size())r=A[i+1].first-1; if(l<=r)cnt+=1ll*(r-l+2)*(r-l+1)/2; } ans[x]=(ans[x]+cnt%P*(1ll*m*(m-1)%P))%P; for(auto [l,r]:A){ va[0]=1ll*(r-l+1)*(r-l+2)/2; int i=1; while(i*x<=r-l){ int d=r-l-i*x; va[i]=1ll*(d+1)*(d+2)/2; ++i; } va[i]=0; for(int j=0;j<i;++j){ cnt=va[j]-va[j+1]; int ct=max(1,j); if(ct<m)ans[x]=(ans[x]+1ll*(m-ct)*(m-ct+1)%P*(cnt%P))%P; } } for(int i=0;i<A.size();++i){ int ls=1,l=A[i].first,r=A[i].second; if(i)ls=A[i-1].second+1; if(ls>=l)continue; vector<int>tr; int u=l+x; while(u<=r)tr.push_back(x),u+=x; tr.push_back(r-u+x+1); for(int j=0;j<tr.size();++j){ int ct=max(1,j); if(ct<m)ans[x]=(ans[x]+1ll*(m-ct)*(m-ct+1)%P*(l-ls)%P*tr[j])%P; } } } void solve(vector<pii>A,int n,int m){ for(int i=1;i<n+m;++i){ solve(A,i,n,m); ans[i]=(ans[i]+1ll*m*(m-1)%P*(n+1))%P; vector<pii>tmp; for(auto [x,y]:A)if(y-x>i)tmp.emplace_back(x,y); A=tmp; } } int main(){ scanf("%d%d",&n,&m); for(int i=1;i<=n;++i)scanf("%d",&a[i]); for(int i=1;i<=m;++i)scanf("%d",&b[i]); for(int i=1,j;i<n;i=j){ j=i+1; if(a[j]>a[i])while(j<n&&a[j+1]>a[j])++j; else while(j<n&&a[j+1]<a[j])++j; if(j-1>i)A.emplace_back(i,j-1); } for(int i=1,j;i<m;i=j){ j=i+1; if(b[j]>b[i])while(j<m&&b[j+1]>b[j])++j; else while(j<m&&b[j+1]<b[j])++j; if(j-1>i)B.emplace_back(i,j-1); } solve(A,n-1,m+1),solve(B,m-1,n+1); printf("0"); ans[0]=1ll*n*(n+1)/2%P*(1ll*m*(m+1)%P)%P; for(int i=1;i<n+m;++i)printf(" %d",1ll*(ans[i]+P-ans[i-1])*i2%P); return 0; } ```