题解 P4891 【序列】

cosmicAC

2018-09-24 10:11:12

Solution

(刚才ubuntu自带输入法又双叒叕出问题了,不爽) 给一个复杂度(也许)正确的做法:$O((n+q) log^2{n})$ 的。比赛时想到这个做法基本没动脑子。但是写的欲哭无泪。成功的炸零了。比赛结束后又调了一个小时才过。 ~~话说我如此优秀的复杂度为什么最优解倒数?不服。~~ **[我的博客](https://www.luogu.org/blog/474D/)** 注意到每次修改A都是把A中的一个数变大。所以可以考虑这样一件事:假设一个数比它之前的所有数都要大,那么在所有时刻中"这样的数“集合肯定变动不超过$O(n+q)$次。可以感性理解一下。所以容易想到用一个set来维护这个集合。 现在考虑每次这个集合中一个数a和a的下一个数b之间(左闭右开)的那一段对答案的贡献。可以发现,这一段的C值都是一样的。所以只要知道区间中有多少个b值小于区间的C值以及这些b值的乘积。用一个数据结构维护即可(我用的是树状数组套主席树)。然后就统计一下区间乘积,乘上C的某个次幂即可。 对于1操作,直接在树状数组套主席树里修改就行了。2操作细节非常多。首先要二分出当前的值在序列的哪一段是最大值(这里也要一个树状数组)。如果是空区间直接跳过(我这里因为直接continue,忘了输出答案,调了好久)。假设区间是[l,r],然后在set里查询一下,要把set中包含l和包含r的区间(**可能是同一个**;这里的区间,指的是一个数a和a的下一个数b之间(左闭右开)的那一段,如果没有下一个数就是到序列结尾)**可能需要**分裂一下,然后把l和r中间的区间全都删掉,同时统计这些区间对答案的贡献,最后在set中插入l,表示这个新的区间。 还有空间需要卡一卡,主席树的节点中区间乘积必须开int,否则要么MLE要么RE。我的空间用了近500MB。 下面是代码:(当然是我最爱的C++17) ```cpp // luogu-judger-enable-o2 #include<bits/stdc++.h> #define maxn 100010 #define int64 long long #define L(x) (a[x].ch[0]) #define R(x) (a[x].ch[1]) using namespace std; multiset<pair<int,int64>> s; //说好的set怎么变成multiset了?因为我是先插入新区间再删除旧区间的,可能两个区间完全相同。 const int64 mod=1e9+7; int n,q,a[maxn],b[maxn],c1[maxn],rt[maxn]; int64 ans=1; int64 power(int64 a,int64 b){ if(!b)return 1; int64 ret=power(a,b>>1); ret=(ret*ret)%mod; if(b&1)ret=(ret*a)%mod; return ret; } struct segtree{ struct node{int cnt,ch[2],mul=1;}a[30000010]; //必须就是3000W,多一毫MLE,少一毫RE int tot,_cnt;int64 _mul; void ins(int x,int v,int modi,int &p,int tl=1,int tr=1e9){ if(!p)p=++tot; if(tl==tr){a[p].mul=(1ll*a[p].mul*v)%mod;a[p].cnt+=modi;return;} int mid=tl+tr>>1; if(x<=mid)ins(x,v,modi,L(p),tl,mid); else ins(x,v,modi,R(p),mid+1,tr); a[p].mul=(1ll*a[L(p)].mul*a[R(p)].mul)%mod; a[p].cnt=a[L(p)].cnt+a[R(p)].cnt; } void qry(int l,int r,int p,int tl=1,int tr=1e9){ if(!p)return; if(l<=tl && tr<=r){(_mul*=a[p].mul)%=mod,_cnt+=a[p].cnt;return;} int mid=tl+tr>>1; if(l<=mid)qry(l,r,L(p),tl,mid); if(r>mid)qry(l,r,R(p),mid+1,tr); } }tr; void chkmax(int &a,int b){a=max(a,b);} void ins(int p,int v){for(;p<=n;p+=p&-p)chkmax(c1[p],v);} int qry(int p){int r=0;for(;p;p&=p-1)chkmax(r,c1[p]);return r;} void Ins(int p,int v,int tp=1){ int t=p; for(;p<=n;p+=p&-p){ if(tp)tr.ins(b[t],power(b[t],mod-2),-1,rt[p]); tr.ins(v,v,1,rt[p]); } b[t]=v; } int64 Qry(int l,int r,int u){ int d=r-l+1;int64 mul=1,cnt=0; for(l--;l;l&=l-1){ tr._mul=1,tr._cnt=0; tr.qry(0,u,rt[l]); (mul*=power(tr._mul,mod-2))%=mod,cnt-=tr._cnt; } for(;r;r&=r-1){ tr._mul=1,tr._cnt=0; tr.qry(0,u,rt[r]); (mul*=tr._mul)%=mod,cnt+=tr._cnt; } return power(u,d-cnt)*mul%mod; } int main(){ scanf("%d%d",&n,&q); for(int i=1;i<=n;i++)scanf("%d",a+i),ins(i,a[i]); for(int i=1;i<=n;i++)scanf("%d",b+i),Ins(i,b[i],0); for(int i=1,l=0;i<=n+1;i++)if(i>n || qry(i)==a[i]){ if(l){ int64 t=Qry(l,i-1,a[l]); s.insert(pair(l,t)); (ans*=t)%=mod; } l=i; } while(q--){ int op,x,y;scanf("%d%d%d",&op,&x,&y); if(op){ Ins(x,y); auto it=s.upper_bound(pair(x,(int64)1e18)),nxt=it; int64 t;--it; (ans*=power(it->second,mod-2))%=mod; s.insert(pair(it->first, t=Qry(it->first,nxt==s.end()?n:nxt->first-1,qry(it->first)))); s.erase(it); (ans*=t)%=mod; }else{ int l=x,r=n+1;ins(x,y); if(qry(x)!=y){printf("%lld\n",ans%mod);continue;} while(l<r){ int mid=l+r>>1; if(qry(mid)>y)r=mid;else l=mid+1; } l--; auto it=s.upper_bound(pair(x,(int64)1e18)),nxt=it; --it; int stp=it->first;int64 t=1,t1; vector<decltype(s.begin())> v; //原谅我C++17学艺不精,不会更简单的写法 for(;nxt!=s.end() && nxt->first<=l;++nxt,++it) (ans*=power(it->second,mod-2))%=mod,v.push_back(it); if(stp!=x)s.insert(pair(stp,t=Qry(stp,x-1,qry(stp)))); //分裂开头 for(auto it:v)s.erase(it); //删除中间区间 int edp=nxt==s.end()?n:nxt->first-1; if(l!=edp)s.insert(pair(l+1,t1=Qry(l+1,edp,qry(l+1)))),t=(t*t1)%mod; //分裂结尾 (ans*=power(it->second,mod-2))%=mod; (ans*=t)%=mod; s.erase(it); t=Qry(x,l,y);(ans*=t)%=mod; s.insert(pair(x,t)); } printf("%lld\n",ans%mod); } return 0; } ```