题解:AT_arc120_f [ARC120F] Wine Thief

· · 题解

ARC120F

简要题意: 求所有大小为 k 相邻边独立集的价值和。

解法:

带着贡献数数不好做,拆贡献。考虑 a_i 对价值和的贡献,令 F_i 表示钦定包含 a_i 的合法集合方案数,则答案就是 \displaystyle\sum a_i\times F_i

那么自然令 f(i,j) 表示 i 个数,选出 j 个数且互不相邻的方案数。如何组合计数?将一个位置 xx+1 绑定作为一组,那么相当于把这些组塞到 i 个位置中且不重叠,不过由于放在最后也是合法的,但是此时 x+1 会捅出去,所以实际上是放到 (i+1) 个空中(类似思路)。(x+1) 会占一个位置,所以可以放 x 的位置有 (i-j+1) 个,选出来 j 个放 x,于是 f(i,j)=\dbinom{i-j+1}{j}

如何算出 F_i?由于 i 选了,所以 (i-1),(i+1) 也被选了,一个错误的思路是,枚举左边选了 j 个数,右边就选 (k-j-1) 个数,得到 F_i=\displaystyle\sum_{j}f(i-2,j)\times f(n-i-1,k-j-1),不过注意到这无法用基本方法快速计算。

考虑一种类似容斥的做法,既然 (i-1,i,i+1) 已经占了三个空,那么方案数就是 f(n-3,k-1)?错了,在这种意义下,i-2,i+2 会被认为是相邻的,因此算少了 i-2,i+2 都被选中的方案,继续加上这种方案,于是 [i-3,i+3] 都被占,方案数是 f(n-7,k-3),一直推下去,得到总方案数是:

F_i=\sum_{j} f(n-4j+1,k-2j+1)

我们发现这个和式的项中和 i 无关,与 i 有关的就是 j 的范围,因此可以直接预处理出上述和式的前缀和,对于每个 i 求出 j 上界后可以直接计算出 F_i,不过注意细节,如果跑到了边缘,那么超出的那个 \texttt{ban} 掉的范围是要忽略的。

总时间复杂度 O(n)

#include <iostream>
#include <cstring>
#include <algorithm>

using namespace std;

typedef long long ll;
const int mod=998244353;
inline int Mod(int x) { return x<0 ? x+mod : (x>=mod ? x-mod : x); }
inline void adm(int &x,int y) { x=Mod(x+y); }
inline int qmi(ll a,int b)
{
    ll res=1;
    for (;b;b>>=1,a=a*a%mod) if (b&1) res=res*a%mod;
    return res;
}
const int N=600010;
int n,k;
int a[N];
int fac[N],infac[N];
void init(int M=N-1)
{
    fac[0]=1;
    for (int i=1;i<=M;i++) fac[i]=(ll)fac[i-1]*i%mod;
    infac[M]=qmi(fac[M],mod-2);
    for (int i=M-1;~i;i--) infac[i]=(ll)infac[i+1]*(i+1)%mod;
}
int binom(int a,int b)
{
    if (a<0 || b<0 || b>a) return 0;
    return (ll)fac[a]*infac[b]%mod*infac[a-b]%mod;
}
int f(int a,int b) { return binom(a-b+1,b); }
int sum[N];

int main()
{
    int buf;
    cin >> n >> k >> buf;
    for (int i=1;i<=n;i++) cin >> a[i];
    init();

    for (int i=1;i<=n;i++)
    {
        int cof=f(n-4*i+1,k-2*i+1);
        sum[i]=Mod(sum[i-1]+cof);
    }

    int ans=0;
    for (int i=1;i<=n;i++)
    {
        int mxj=min((i+1)/2,(n-i+2)/2);
        int s=(mxj ? sum[mxj-1] : 0);
        auto calc = [&](int j) 
        {
            int cut=n-4*j+1+((i-2*j+1==0)+(i+2*j-1==n+1));
            int cof=f(cut,k-2*j+1);
            return cof;
        };
        adm(s,calc(mxj));
        //cout << i << " " << s << "\n";
        adm(ans,(ll)a[i]*s%mod);
    }
    cout << ans << "\n";

    return 0; 
}