题解:P6108 [Ynoi2009] rprsvq
yangzichen1203 · · 题解
组合数学好题。
首先拆方差的式子:
发现所有子序列这件事情不好处理,考虑拆成单点贡献。以下设
我们发现
以下设
注意到
则
注意到
::::info[上式的证明]
注意到
:::info[上式的证明] 注意到
则
:::
则
::::
类似地,
然后就做完了。
最后就是线性预处理一大堆东西,然后用线段树处理询问,很好写,不过可能会被卡空间。
时间复杂度:
空间复杂度:
:::success[Code]
#include<bits/stdc++.h>
#define For(i,j,k) for(int i=j;i<=k;i++)
#define dFor(i,j,k) for(int i=j;i>=k;i--)
using namespace std;
#define MAXN 5000005
#define Mod 998244353
int n,m,a[MAXN],f[MAXN],g[MAXN],inv[MAXN],P[MAXN],T1[MAXN],T2[MAXN],T3[MAXN];
inline int md(int x){
return x>=Mod?x-Mod:x;
}
int Pow(int x,int y){
int ans=1;
while(y){
if(y&1){
ans=1ll*ans*x%Mod;
}
x=1ll*x*x%Mod;
y>>=1;
}
return ans;
}
void init(){
f[0]=1;
For(i,1,n+1){
f[i]=1ll*f[i-1]*i%Mod;
}
g[n+1]=Pow(f[n+1],Mod-2);
dFor(i,n,0){
g[i]=1ll*g[i+1]*(i+1)%Mod;
}
For(i,1,n+1){
inv[i]=1ll*g[i]*f[i-1]%Mod;
}
P[0]=1;
For(i,1,n+1){
P[i]=P[i-1]*2%Mod;
}
int pre=0;
For(i,1,n+1){
T1[i]=1ll*(P[i]-1)*inv[i]%Mod;
pre=md(pre+T1[i]);
T2[i]=1ll*pre*inv[i]%Mod;
T3[i]=1ll*(T1[i]-T2[i]+Mod)*inv[i-1]%Mod;
}
}
struct Ans{
int len,sum1,sum2;
Ans(int len=0,int sum1=0,int sum2=0):len(len),sum1(sum1),sum2(sum2){}
};
Ans operator +(const Ans &x,const Ans &y){
return Ans(x.len+y.len,md(x.sum1+y.sum1),md(x.sum2+y.sum2));
}
Ans operator +(const Ans &x,const int &y){
return Ans(x.len,md(x.sum1+1ll*x.len*y%Mod),md(md(x.sum2+2ll*y*x.sum1%Mod)+1ll*x.len*y%Mod*y%Mod));
}
struct Tree{
int add;
Ans ans;
}tr[MAXN*4];
void build(int c,int l,int r){
tr[c].ans=Ans(r-l+1,0,0);
if(l==r) return ;
int mid=(l+r)/2;
build(c*2,l,mid);
build(c*2+1,mid+1,r);
}
void update(int c){
tr[c].ans=tr[c*2].ans+tr[c*2+1].ans;
}
void down(int c){
tr[c*2].ans=tr[c*2].ans+tr[c].add;
tr[c*2].add=md(tr[c*2].add+tr[c].add);
tr[c*2+1].ans=tr[c*2+1].ans+tr[c].add;
tr[c*2+1].add=md(tr[c*2+1].add+tr[c].add);
tr[c].add=0;
}
void modify(int c,int L,int R,int l,int r,int k){
if(L==l&&R==r){
tr[c].ans=tr[c].ans+k;
tr[c].add=md(tr[c].add+k);
return ;
}
down(c);
int mid=(L+R)/2;
if(r<=mid) modify(c*2,L,mid,l,r,k);
else if(l>mid) modify(c*2+1,mid+1,R,l,r,k);
else modify(c*2,L,mid,l,mid,k),modify(c*2+1,mid+1,R,mid+1,r,k);
update(c);
}
Ans query(int c,int L,int R,int l,int r){
if(L==l&&R==r){
return tr[c].ans;
}
down(c);
int mid=(L+R)/2;
if(r<=mid) return query(c*2,L,mid,l,r);
else if(l>mid) return query(c*2+1,mid+1,R,l,r);
else return query(c*2,L,mid,l,mid)+query(c*2+1,mid+1,R,mid+1,r);
}
int main(){
ios::sync_with_stdio(0);
cin.tie(0);cout.tie(0);
cin>>n>>m;
init();
build(1,1,n);
while(m--){
int op,l,r;
cin>>op>>l>>r;
if(op==1){
int x;
cin>>x;
modify(1,1,n,l,r,x);
}else{
Ans t=query(1,1,n,l,r);
int sum1=t.sum1,sum2=t.sum2;
int t1=T1[r-l+1],t2=T2[r-l+1],t3=T3[r-l+1];
int ans=1ll*sum2*t1%Mod-1ll*sum2*t2%Mod-(1ll*sum1*sum1-sum2)%Mod*t3%Mod;
while(ans<0) ans+=Mod;
cout<<ans<<'\n';
}
}
return 0;
}
:::