P5213 [SHOI2014]超能粒子炮 题解

· · 题解

注意到 \min\{a,a^{-1}\}\leq1000,必有蹊跷。

考虑将连续的一段不受取模影响的 f(i) 合并,这样每段的长度都应该是 O(\frac na),那么这样的段有 O(a) 个。

这就可以做 a\leq 1000 的部分了,直接算出每一段的首项和长度,然后枚举一对连续段计算逆序对。

但是这里有个问题,这真的是个函数吗?换句话说, $i$ 和 $f(i)$ 之间是否有一一对应关系,否则反函数将不存在,这个算法也就不可行了。 幸运的是在 $m$ 是质数是这的确是个函数,证明就考虑不妨设 $n=m$,然后这时 $i=1,2,\cdots,n$ 是模 $m$ 意义下的一个完系,那么 $ai=a,2a,\cdots,n$ 也是一个完系,加上常量相当于整体旋转,所以 $f(i)$ 是模 $m$ 意义下的一个完系,即模 $m$ 意义下互不相同。因此反函数存在。 然后考虑 $n<m$ 相当于在互不相同的 $f(i)$ 中取一个子集出来,显然互不相同。 然后的工作比较简单,和上面一样的方式计算逆序对即可。值得注意的是并非所有位置的值都是合法的,需要把 $i=0$ 和 $i>n$ 的位置去掉。 求逆序对需要比较多的分类讨论,不是很好写。 ```cpp #include<iostream> #include<cstdio> using namespace std; #define int long long int n,mod,a,b,ans; pair<int,int> v[10001]; inline int read() { int x=0; char c=getchar(); while(c<'0'||c>'9') c=getchar(); while(c>='0'&&c<='9') { x=(x<<1)+(x<<3)+(c^48); c=getchar(); } return x; } inline int pw(int a,int b) { int res=1; while(b) { if(b&1) res=res*a%mod; b>>=1; a=a*a%mod; } return res; } signed main() { n=read(),mod=read(),a=read()%mod,b=read()%mod; if(!a) { puts("0"); return 0; } if(a<=1000) { for(int i=1,j=1;j<=n;++i) { v[i]={(a*j%mod+b)%mod+1,(mod-(a*j%mod+b)%mod-1)/a+1}; if(j+v[i].second>n) { v[i].second=n-j+1; j=n+1; } else j+=v[i].second; for(int k=1;k<i;++k) if(v[k].first<v[i].first) { int cnt=v[k].second-((v[i].first-v[k].first)/a+1); ans+=(cnt+cnt-v[i].second+1)*v[i].second/2; } else if(v[k].first>v[i].first) { int cnt=(v[k].first-v[i].first+a-1)/a; if(cnt>=v[i].second) { ans+=v[i].second*v[k].second; continue; } ans+=cnt*v[k].second+(v[k].second-1+v[k].second-v[i].second+cnt)*(v[i].second-cnt)/2; } else ans+=(v[k].second-min(v[k].second,v[i].second)+v[k].second-1)*min(v[i].second,v[k].second)/2; } cout<<ans<<'\n'; return 0; } a=pw(a,mod-2); b=(mod-(b+1)%mod*a%mod)%mod; for(int i=1,j=1;j<=mod;++i) { v[i]={(a*j%mod+b-1)%mod+1,(mod-(a*j%mod+b-1)%mod-1)/a+1}; if(j+v[i].second>mod) { v[i].second=mod-j+1; j=mod+1; } else j+=v[i].second; if(v[i].first>n||v[i].second<=0) { --i; continue; } if(v[i].first+(v[i].second-1)*a>n) v[i].second=(n-v[i].first)/a+1; if(!v[i].first) { --v[i].second; v[i].first+=a; } for(int k=1;k<i;++k) if(v[k].first<v[i].first) { int cnt=v[k].second-((v[i].first-v[k].first)/a+1); ans+=(cnt+cnt-v[i].second+1)*v[i].second/2; } else if(v[k].first>v[i].first) { int cnt=(v[k].first-v[i].first+a-1)/a; if(cnt>=v[i].second) { ans+=v[i].second*v[k].second; continue; } ans+=cnt*v[k].second+(v[k].second-1+v[k].second-v[i].second+cnt)*(v[i].second-cnt)/2; } else ans+=(v[k].second-min(v[k].second,v[i].second)+v[k].second-1)*min(v[i].second,v[k].second)/2; } cout<<ans<<'\n'; return 0; } ```