ABC262Ex Max Limited Sequence 题解

· · 题解

首先,我们可以用线段树求出每个位置 i 的值域上界 v_i。具体而言,我们有:

v_i=\min_{l_j \le i \le r_j} x_j

接下来,我们分别对于每一个值域上界 w 进行动态规划。

我们拎出所有满足 v_i=w 的位置 i,设这些位置分别为 c_1,c_2,\cdots,c_s

同时,我们也拎出所有满足 x_i=w 的限制 i,设这些限制分别为 d_1,d_2,\cdots,d_t

显然我们只关心我们拎出的每个位置所填的值是否是 w,所以我们可以定义 f_{i,j} 表示当前考虑到第 c_i 位且上一个所填的值为 w 的位置为 c_j 的方案数。

初始化 f_{0,0}=1。为了方便,我们认为 c_0=0c_{s+1}=n+1

先不考虑限制,尝试转移:

\begin{aligned} &f_{i,i}=\sum_{j=0}^{i-1} f_{i-1,j}\\ &f_{i,j}=f_{i-1,j} \times w \end{aligned}

注意到,此时每个限制相当于要求 [l_{d_j},r_{d_j}] 中有至少一个位置的值为 w,所以只需要在 c_i \le r_{d_j} \lt c_{i+1} 时将所有满足 k \lt l_{d_j}f_{i,k} 的值赋为 0 即可。

我们可以用线段树维护 f 数组,只需要支持单点修改、区间乘、区间求和即可。此时的方案数即为 \sum\limits_{i=0}^s f_{s,i}

最终的答案即为每个值域上界的方案数之积,注意每个没有被限制覆盖的位置的贡献均为 m+1

时间复杂度 \mathcal O(n\log n)

#include <bits/stdc++.h>

#define ll long long
#define i128 __int128
#define endl '\n'
#define pb push_back
#define pii pair<int,int>
#define fi first
#define se second
#define vei vector<int>
#define pq priority_queue
#define yes puts("yes")
#define no puts("no")
#define Yes puts("Yes")
#define No puts("No")
#define YES puts("YES")
#define NO puts("NO")
#define In(x) freopen(x".in","r",stdin)
#define Out(x) freopen(x".out","w",stdout)
#define File(x) (In(x),Out(x))
using namespace std;
const int N=2e5+5,inf=1e9,mod=998244353;
int n,m,q,l[N],r[N],x[N],v[N],mi[N<<2],c[N],d[N],s,t,w,val[N<<2],tag[N<<2],ans=1;
pii a[N],b[N];
void work(int g,int l,int r,int x,int y,int k){
    if(x<=l&&r<=y){
        mi[g]=min(mi[g],k);
        return;
    }
    if(r<x||y<l) return;
    int m=(l+r)>>1;
    work(g<<1,l,m,x,y,k);
    work(g<<1|1,m+1,r,x,y,k);
}
void search(int g,int l,int r,int k){
    k=min(mi[g],k);
    if(l==r){
        v[l]=k;
        return;
    }
    int m=(l+r)>>1;
    search(g<<1,l,m,k);
    search(g<<1|1,m+1,r,k);
}
int add(int a,int b){
    return a+b>=mod?a+b-mod:a+b;
}
int mul(int a,int b){
    return 1ll*a*b%mod;
}
void upd(int g){
    val[g]=add(val[g<<1],val[g<<1|1]);
}
void down(int g){
    tag[g<<1]=mul(tag[g<<1],tag[g]);
    tag[g<<1|1]=mul(tag[g<<1|1],tag[g]);
    val[g<<1]=mul(val[g<<1],tag[g]);
    val[g<<1|1]=mul(val[g<<1|1],tag[g]);
    tag[g]=1;
}
void build(int g,int l,int r){
    val[g]=0;
    tag[g]=1;
    if(l==r) return;
    int m=(l+r)>>1;
    build(g<<1,l,m);
    build(g<<1|1,m+1,r);
}
void modify(int g,int l,int r,int x,int k){
    if(l==x&&r==x){
        val[g]=k;
        return;
    }
    if(r<x||x<l) return;
    down(g);
    int m=(l+r)>>1;
    modify(g<<1,l,m,x,k);
    modify(g<<1|1,m+1,r,x,k);
    upd(g);
}
void times(int g,int l,int r,int x,int y,int k){
    if(x<=l&&r<=y){
        tag[g]=mul(tag[g],k);
        val[g]=mul(val[g],k);
        return;
    }
    if(r<x||y<l) return;
    down(g);
    int m=(l+r)>>1;
    times(g<<1,l,m,x,y,k);
    times(g<<1|1,m+1,r,x,y,k);
    upd(g);
}
int ask(int g,int l,int r,int x,int y){
    if(x<=l&&r<=y) return val[g];
    if(r<x||y<l) return 0;
    down(g);
    int m=(l+r)>>1;
    return add(ask(g<<1,l,m,x,y),ask(g<<1|1,m+1,r,x,y));
}
bool cmp(int a,int b){
    return r[a]<r[b];
}
void dp(){
    sort(c+1,c+1+s);
    sort(d+1,d+1+t,cmp);
    c[s+1]=n+1;
    build(1,0,s);
    for(int i=0,j=1;i<=s;i++){
        if(i==0) modify(1,0,s,0,1);
        else{
            int sum=ask(1,0,s,0,i-1);
            modify(1,0,s,i,sum);
            times(1,0,s,0,i-1,w);
        }
        while(j<=t&&c[i]<=r[d[j]]&&r[d[j]]<c[i+1]){
            int pos=lower_bound(c,c+1+s,l[d[j]])-c-1;
            times(1,0,s,0,pos,0);
            j++;
        }
    }
    int sum=ask(1,0,s,0,s);
    ans=mul(ans,sum);
}
void solve(){
    cin>>n>>m>>q;
    for(int i=1;i<=q;i++) cin>>l[i]>>r[i]>>x[i];
    memset(mi,0x3f,sizeof mi);
    for(int i=1;i<=q;i++) work(1,1,n,l[i],r[i],x[i]);
    search(1,1,n,inf);
    for(int i=1;i<=n;i++) a[i].fi=v[i],a[i].se=i;
    for(int i=1;i<=q;i++) b[i].fi=x[i],b[i].se=i;
    sort(a+1,a+1+n);
    sort(b+1,b+1+q);
    for(int i=1,j=1;j<=q;){
        s=0,t=0,w=b[j].fi;
        while(a[i].fi==w) c[++s]=a[i++].se;
        while(b[j].fi==w) d[++t]=b[j++].se;
        dp();
    }
    for(int i=1;i<=n;i++) if(v[i]==inf) ans=mul(ans,m+1);
    cout<<ans<<endl;
}
signed main(){
    ios::sync_with_stdio(0);
    signed T=1;
//  cin>>T;
    while(T--) solve();
    return 0;
}