CF1327F

· · 题解

发现题解区没有人写容斥。

CF1327F

按位考虑。猜测时间复杂度为 O((n+m)k)。如果按位考虑,每一位都是独立的,最后把所有情况乘起来即可。那么问题转化为下面的情况:

有一个长度为 n01 序列。有 m 个限制:

上面的存在比较难搞,因为有多个区间存在交集的情况,考虑容斥是最先想到的方法。即对于限制 1,若强制某个子集 l\sim r 全取 1,那么方案数就是 2^{n-|[l_1,r_1]\cup[l_2,r_2]...|}

考虑挖掘题目性质:

我们考虑一个 dp。在不管条件 2 的情况下,我们把区间排序,此时 l,r 都是递增的,这样子我们就可以优化容斥了,我们并不关心选了几个区间,只关心选的区间个数的奇偶性,定义 f_{i,0/1} 表示强制选区间 i,此时有奇数/偶数个区间的方案,那么有转移:

f_{i,p}=\sum_{j=1}^{i-1} f_{j,\lnot p}\times 2^{count(r_j+1\sim l_i-1)}

其中 count(l,r) 表示 (l,r) 中不强制取 1 的位置个数。

最后算贡献时处理出这个区间后面有几个可以任意取数的位置个数 p,乘上 2^p 即可。

那么使用一些一般的手段对这个 dp 进行优化即可。时间复杂度 O((n+m)k)。带个 \log n 会被卡掉。具体双指针即可。

#include<bits/stdc++.h>
#define ll long long
#define pb emplace_back
#define N 500005
using namespace std;
const int mod=998244353;
int n,k,m;
struct node{
    int l,r,w;
}a[N];
ll f[N][2],res=1;
int d[N],cnt[N];
int R[N],lim,tag[N];
ll s1,s2,r1,r2;
ll pw[N];
ll solve(int bit)
{
    lim=n+1;s1=s2=r1=r2=0;
    memset(R,127/3,sizeof R);
    memset(f,0,sizeof f);
    memset(d,0,sizeof d);
    for(int i=1;i<=m;i++)
        if(a[i].w&(1<<bit)) d[a[i].l]++,d[a[i].r+1]--;
        else R[a[i].l]=min(R[a[i].l],a[i].r);
    for(int i=1;i<=n;i++) d[i]+=d[i-1];
    for(int i=n;i>=1;i--) cnt[i]=cnt[i+1]+(!d[i]);
    for(int i=n;i>=1;i--) if(R[i]<lim) lim=R[i],tag[i]=1; else tag[i]=0;
    ll sum=pw[cnt[1]];
    int l=1;
    s1=1; 
    for(int i=1;i<=n;i++)
    {
        while(l<=i&&(!tag[l]||R[l]<i))
        {
            if(tag[l])
            {
                s1=(s1+f[l][0])%mod;
                s2=(s2+f[l][1])%mod;
                r1=(r1+mod-f[l][0])%mod;
                r2=(r2+mod-f[l][1])%mod;
            }
            ++l;
        }
        if(tag[i])
        {
            f[i][0]=(s2+r2)%mod;
            f[i][1]=(s1+r1)%mod;
            r1=(r1+f[i][0])%mod;
            r2=(r2+f[i][1])%mod;
            sum=(sum+(f[i][0]-f[i][1]+mod)*pw[cnt[R[i]+1]])%mod;
        }
        if(!d[i]) s1=s1*2%mod,s2=s2*2%mod;
    }
    return sum;
}
int main()
{
    scanf("%d%d%d",&n,&k,&m);
    pw[0]=1;
    for(int i=1;i<=n;i++) pw[i]=pw[i-1]*2%mod;
    for(int i=1;i<=m;i++) scanf("%d%d%d",&a[i].l,&a[i].r,&a[i].w);
    for(int i=0;i<k;i++) res=res*solve(i)%mod;
    printf("%lld",res);
    return 0;
}