P9149 题解(2023 激励计划评分 7)

· · 题解

我们将 1 \sim w 中在 B 序列出现过的元素称作关键元素

显然,对于一个可能产生贡献的选择方案,关键元素均不能选。于是我们记非关键元素的数量为 c

将 A 序列中的关键元素按下标顺次提出来,得到下标序列 p。我们不妨对 p 的每个长为 m 的子串计算贡献。

当前子串能够产生贡献,当且仅当它与 B 序列完全相同。这一点可以用 kmp 进行判定。接下来我们需要考虑的是贡献系数。

对于子串 p_i,p_{i+1},\dots,p_{i+m-1},显然区间 [p_i,p_{i+m-1}] 内的非关键元素均需要被删除。 记其数量为 o。则该子串的贡献为 {c-o} \choose {d-o}。至于 o 的维护,可以用双指针加桶实现。时间复杂度线性。

一个细节是,如果你 kmp 的写法会访问到 B_{m+1},记得清空这个位置。否则只能通过后两个 Subtask。

代码:

#include<bits/stdc++.h>
using namespace std;

const int N=1e6+10,mod=1e9+7;
int T,n,m,c,w,d,o,L,ans,a[N],b[N],p[N];
int nx[N],cnt[N],fct[N],inv[N],finv[N];
bool vis[N];
inline int read()
{
    int x=0;char ch=getchar();
    while(ch<'0'||ch>'9') ch=getchar();
    while(ch>='0'&&ch<='9') x=(x<<3)+(x<<1)+ch-'0',ch=getchar();
    return x;
}
inline int C(int x,int y){return y<0||x<y?0:1ll*fct[x]*finv[y]%mod*finv[x-y]%mod;}
inline void clear()
{
    c=o=L=ans=0;a[n+1]=b[m+1]=0;
    for(int i=1;i<=w;i++) cnt[i]=vis[i]=0; 
}
int main()
{
    fct[0]=inv[0]=finv[0]=fct[1]=inv[1]=finv[1]=1;
    for(int i=2;i<N;i++)
    {
        fct[i]=1ll*fct[i-1]*i%mod;
        inv[i]=(mod-1ll*mod/i*inv[mod%i]%mod)%mod;
        finv[i]=1ll*finv[i-1]*inv[i]%mod;
    }
    T=read();
    while(T--)
    {
        n=read(),m=read(),w=read(),d=read();clear();
        for(int i=1;i<=n;i++) a[i]=read();
        for(int i=1;i<=m;i++) b[i]=read(),vis[b[i]]=true;
        for(int i=1;i<=w;i++) if(!vis[i]) ++c;
        for(int i=1;i<=n;i++) if(vis[a[i]]) p[++L]=i;
        for(int i=2,j=0;i<=m;i++)
        {
            while(j&&b[i]!=b[j+1]) j=nx[j];
            if(b[i]==b[j+1]) ++j;nx[i]=j;
        }
        for(int i=1,j=0,l=1,r=0;i<=L;i++)
        {
            while(r<p[i])
            {
                ++r;
                if(vis[a[r]]) continue;
                ++cnt[a[r]];if(cnt[a[r]]==1) ++o;
            }
            if(i>=m)
                while(l<p[i-m+1])
                {
                    if(!vis[a[l]]) 
                    {
                        --cnt[a[l]];
                        if(!cnt[a[l]]) --o;
                    }
                    l++;
                }
            while(j&&a[p[i]]!=b[j+1]) j=nx[j];
            if(a[p[i]]==b[j+1]) ++j;if(j==m) ans=(ans+C(c-o,d-o))%mod;
        }
        printf("%d\n",ans); 
    }
    return 0;
}