题解 P4891 【序列】
cosmicAC
2018-09-24 10:11:12
(刚才ubuntu自带输入法又双叒叕出问题了,不爽)
给一个复杂度(也许)正确的做法:$O((n+q) log^2{n})$ 的。比赛时想到这个做法基本没动脑子。但是写的欲哭无泪。成功的炸零了。比赛结束后又调了一个小时才过。
~~话说我如此优秀的复杂度为什么最优解倒数?不服。~~
**[我的博客](https://www.luogu.org/blog/474D/)**
注意到每次修改A都是把A中的一个数变大。所以可以考虑这样一件事:假设一个数比它之前的所有数都要大,那么在所有时刻中"这样的数“集合肯定变动不超过$O(n+q)$次。可以感性理解一下。所以容易想到用一个set来维护这个集合。
现在考虑每次这个集合中一个数a和a的下一个数b之间(左闭右开)的那一段对答案的贡献。可以发现,这一段的C值都是一样的。所以只要知道区间中有多少个b值小于区间的C值以及这些b值的乘积。用一个数据结构维护即可(我用的是树状数组套主席树)。然后就统计一下区间乘积,乘上C的某个次幂即可。
对于1操作,直接在树状数组套主席树里修改就行了。2操作细节非常多。首先要二分出当前的值在序列的哪一段是最大值(这里也要一个树状数组)。如果是空区间直接跳过(我这里因为直接continue,忘了输出答案,调了好久)。假设区间是[l,r],然后在set里查询一下,要把set中包含l和包含r的区间(**可能是同一个**;这里的区间,指的是一个数a和a的下一个数b之间(左闭右开)的那一段,如果没有下一个数就是到序列结尾)**可能需要**分裂一下,然后把l和r中间的区间全都删掉,同时统计这些区间对答案的贡献,最后在set中插入l,表示这个新的区间。
还有空间需要卡一卡,主席树的节点中区间乘积必须开int,否则要么MLE要么RE。我的空间用了近500MB。
下面是代码:(当然是我最爱的C++17)
```cpp
// luogu-judger-enable-o2
#include<bits/stdc++.h>
#define maxn 100010
#define int64 long long
#define L(x) (a[x].ch[0])
#define R(x) (a[x].ch[1])
using namespace std;
multiset<pair<int,int64>> s;
//说好的set怎么变成multiset了?因为我是先插入新区间再删除旧区间的,可能两个区间完全相同。
const int64 mod=1e9+7;
int n,q,a[maxn],b[maxn],c1[maxn],rt[maxn];
int64 ans=1;
int64 power(int64 a,int64 b){
if(!b)return 1;
int64 ret=power(a,b>>1);
ret=(ret*ret)%mod;
if(b&1)ret=(ret*a)%mod;
return ret;
}
struct segtree{
struct node{int cnt,ch[2],mul=1;}a[30000010];
//必须就是3000W,多一毫MLE,少一毫RE
int tot,_cnt;int64 _mul;
void ins(int x,int v,int modi,int &p,int tl=1,int tr=1e9){
if(!p)p=++tot;
if(tl==tr){a[p].mul=(1ll*a[p].mul*v)%mod;a[p].cnt+=modi;return;}
int mid=tl+tr>>1;
if(x<=mid)ins(x,v,modi,L(p),tl,mid);
else ins(x,v,modi,R(p),mid+1,tr);
a[p].mul=(1ll*a[L(p)].mul*a[R(p)].mul)%mod;
a[p].cnt=a[L(p)].cnt+a[R(p)].cnt;
}
void qry(int l,int r,int p,int tl=1,int tr=1e9){
if(!p)return;
if(l<=tl && tr<=r){(_mul*=a[p].mul)%=mod,_cnt+=a[p].cnt;return;}
int mid=tl+tr>>1;
if(l<=mid)qry(l,r,L(p),tl,mid);
if(r>mid)qry(l,r,R(p),mid+1,tr);
}
}tr;
void chkmax(int &a,int b){a=max(a,b);}
void ins(int p,int v){for(;p<=n;p+=p&-p)chkmax(c1[p],v);}
int qry(int p){int r=0;for(;p;p&=p-1)chkmax(r,c1[p]);return r;}
void Ins(int p,int v,int tp=1){
int t=p;
for(;p<=n;p+=p&-p){
if(tp)tr.ins(b[t],power(b[t],mod-2),-1,rt[p]);
tr.ins(v,v,1,rt[p]);
}
b[t]=v;
}
int64 Qry(int l,int r,int u){
int d=r-l+1;int64 mul=1,cnt=0;
for(l--;l;l&=l-1){
tr._mul=1,tr._cnt=0;
tr.qry(0,u,rt[l]);
(mul*=power(tr._mul,mod-2))%=mod,cnt-=tr._cnt;
}
for(;r;r&=r-1){
tr._mul=1,tr._cnt=0;
tr.qry(0,u,rt[r]);
(mul*=tr._mul)%=mod,cnt+=tr._cnt;
}
return power(u,d-cnt)*mul%mod;
}
int main(){
scanf("%d%d",&n,&q);
for(int i=1;i<=n;i++)scanf("%d",a+i),ins(i,a[i]);
for(int i=1;i<=n;i++)scanf("%d",b+i),Ins(i,b[i],0);
for(int i=1,l=0;i<=n+1;i++)if(i>n || qry(i)==a[i]){
if(l){
int64 t=Qry(l,i-1,a[l]);
s.insert(pair(l,t));
(ans*=t)%=mod;
}
l=i;
}
while(q--){
int op,x,y;scanf("%d%d%d",&op,&x,&y);
if(op){
Ins(x,y);
auto it=s.upper_bound(pair(x,(int64)1e18)),nxt=it;
int64 t;--it;
(ans*=power(it->second,mod-2))%=mod;
s.insert(pair(it->first,
t=Qry(it->first,nxt==s.end()?n:nxt->first-1,qry(it->first))));
s.erase(it);
(ans*=t)%=mod;
}else{
int l=x,r=n+1;ins(x,y);
if(qry(x)!=y){printf("%lld\n",ans%mod);continue;}
while(l<r){
int mid=l+r>>1;
if(qry(mid)>y)r=mid;else l=mid+1;
}
l--;
auto it=s.upper_bound(pair(x,(int64)1e18)),nxt=it;
--it;
int stp=it->first;int64 t=1,t1;
vector<decltype(s.begin())> v;
//原谅我C++17学艺不精,不会更简单的写法
for(;nxt!=s.end() && nxt->first<=l;++nxt,++it)
(ans*=power(it->second,mod-2))%=mod,v.push_back(it);
if(stp!=x)s.insert(pair(stp,t=Qry(stp,x-1,qry(stp)))); //分裂开头
for(auto it:v)s.erase(it); //删除中间区间
int edp=nxt==s.end()?n:nxt->first-1;
if(l!=edp)s.insert(pair(l+1,t1=Qry(l+1,edp,qry(l+1)))),t=(t*t1)%mod; //分裂结尾
(ans*=power(it->second,mod-2))%=mod;
(ans*=t)%=mod;
s.erase(it);
t=Qry(x,l,y);(ans*=t)%=mod;
s.insert(pair(x,t));
}
printf("%lld\n",ans%mod);
}
return 0;
}
```