题解:CF2069C Beautiful Sequence
Magallan_forever
·
·
题解
简要说明题意:
若一个序列长度至少为 3,且对于除第一个元素外的所有元素,左侧均有一个更小值,对于除最后一个元素外的所有元素,右侧均有一个更大值,那么这个序列就是一个“漂亮的”序列。
现在给出含 n 个元素的序列 a,满足 1 \leq a_i \leq 3。求 a 有多少个(可不连续的)子序列是“漂亮的”序列,答案对 998244353 取模。
题目分析:
一个“漂亮的”序列,第一个元素一定是最小值,最后一个元素一定是最大值(如果最小/最大值位于中部,左侧/右侧一定没有更小/更大值了,矛盾。)。
并且题目又保证 1 \leq a_i \leq 3。那就说明,“漂亮的”序列一定是 1,2,2,\dots,3 的形式。
那就好解决了,对于每一个 1,我们找后面的每一个 3,假如有 k 个 3,每个 3 与 1 之间有 b_i 个 2,这个 1 对答案的贡献就是 \displaystyle{\sum_{i=1}^k{(2^{b_i}-1)}=\sum_{i=1}^k{2^{b_i}}-k}(也就是 2 的非空子集),我们可以考虑从左到右扫描,用线段树直接在每个 3 的位置维护 2^{b_i}。
单次修改为 $O(\log_2n)$,时间复杂度 $O(\sum{n\log_2n})$ 是建树和扫描的复杂度。
还有一些实现细节,具体看代码注释:
```cpp
#include<cstdio>
#include<algorithm>
using namespace std;
const int mod=998244353,rev2=499122177;
int fp(int a,int b,int mod){
int ans=1;
for(;b;a=(long long)a*a%mod,b>>=1) if(b&1) ans=(long long)ans*a%mod;
return ans;
}
int a[200001],tot[200001]={1};
//维护的修改操作:区间乘
//维护的查询操作:区间加
struct node{
int l,r,cnt,tag;
node(int l_=0,int r_=0) :l(l_),r(r_),cnt(0),tag(1) {}
}tree[800001];
void push_down(int p){
if(!tree[p].tag) return;
int tag=tree[p].tag;
tree[p].tag=1;
tree[p<<1].cnt=(long long)tree[p<<1].cnt*tag%mod,tree[(p<<1)|1].cnt=(long long)tree[(p<<1)|1].cnt*tag%mod;
tree[p<<1].tag=(long long)tree[p<<1].tag*tag%mod,tree[(p<<1)|1].tag=(long long)tree[(p<<1)|1].tag*tag%mod;
}
void push_up(int p){
tree[p].cnt=(tree[p<<1].cnt+tree[(p<<1)|1].cnt)%mod;
}
void build(int p,int l,int r){
tree[p]=node(l,r);
if(l==r) if(a[l]==3) tree[p].cnt=tot[l];//建树只记录a_i=3的情况
if(l^r) build(p<<1,l,(l+r)>>1),build((p<<1)|1,((l+r)>>1)+1,r),push_up(p);
}
void modify(int p,int l,int r){
if(l<=tree[p].l&&tree[p].r<=r){
tree[p].cnt=(long long)tree[p].cnt*rev2%mod;
tree[p].tag=(long long)tree[p].tag*rev2%mod;
return;
}
int mid=(tree[p].l+tree[p].r)>>1;
push_down(p);
if(l<=mid) modify(p<<1,l,r);
if(mid<r) modify((p<<1)|1,l,r);
push_up(p);
}
int query(int p,int l,int r){
if(l<=tree[p].l&&tree[p].r<=r) return tree[p].cnt;
int mid=(tree[p].l+tree[p].r)>>1,ans=0;
push_down(p);
if(l<=mid) ans+=query(p<<1,l,r),ans%=mod;
if(mid<r) ans+=query((p<<1)|1,l,r),ans%=mod;
return push_up(p),ans;
}
int main(){
int T,n,cnt3,ans;
scanf("%d",&T);
while(T--){
scanf("%d",&n),cnt3=0,ans=0;
//读入,统计2^b_i和3的个数
for(int i=1;i<=n;++i) scanf("%d",a+i),tot[i]=(tot[i-1]<<(a[i]==2))%mod,cnt3+=(a[i]==3);
build(1,1,n);
for(int i=1;i<=n;++i){
//这里query(1,i,n)在取模后有可能小于cnt3导致答案为负数,需要加一次mod
if(a[i]==1) ans+=query(1,i,n)-cnt3,ans<0?ans=(ans+mod)%mod:ans%=mod;
if(a[i]==2) modify(1,i,n);
if(a[i]==3) --cnt3;
// printf("i=%d ans=%d\n",i,ans);
}
printf("%d\n",ans);
}
return 0;
}
```