P8367 [LNOI2022] 盒

· · 题解

\text{Foreword}

一道比较神的数学题...
考场写的 40 分做法,感觉自己实际数学水平也就这 40 分了(悲
出题人给了两种做法,我的评价是:组合意义天地灭,代数推导天地灭
感觉组合意义还是相对阳间一点的,所以讲解一下这种做法。

\text{Solution}

问题很典,定义 s_i=\sum_{j=1}^ia_j,单独考虑每条边的贡献,答案就是:

ans=\sum_{i=1}^{n-1}w_i\sum_{j=0}^S|s_i-j|f(i,j)f(n-i,S-j)

其中 f(i,j) 表示 i 个非负元素加和为 j 的方案数,隔板法不难得出 f(i,j)=\binom {i+j-1}{i-1}
那么答案就是:

ans=\sum_{i=1}^{n-1}w_i\sum_{j=0}^S|s_i-j|\binom{i+j-1}{i-1}\binom {n-i+S-j-1}{n-i-1}

然后就可以拿 40 跑路了。
接下来我们重点关心对于每一个 i,如何快速计算内层枚举 j 的结果 res_i
把绝对值拆掉,再把求和用前缀和作差改成从 0 开始,就有:

res_i=2\sum_{j=0}^{s_i}(s_i-j)\binom{i+j-1}{i-1}\binom {n-i+S-j-1}{n-i-1}+\sum_{j=0}^S(j-s_i)\binom{i+j-1}{i-1}\binom {n-i+S-j-1}{n-i-1}

两者本质上只有枚举上界不同,显然后面上界为 S 的式子更好做,我们重点关心一下这部分。
来自组合数问题的 trick:j\binom{i+j-1}{j}=i\binom {i+j-1}{j-1},那么后面的式子就是:

i\sum_{j=0}^S\binom {i+j-1}{j-1}\binom{n-i+S-j-1}{S-j}-s_i\sum_{j=0}^S\binom{i+j-1}{j}\binom {n-i+S-j-1}{S-j}

前面组合数里的 j-1 会出现 -1 令人感到生理不适...回到原式发现这个地方对应的项是 0,所以不妨从 1 开始枚举,再令 j\gets j-1,就是:

i\sum_{j=0}^{S-1}\binom {i+j}{j}\binom{n-i+S-j-2}{S-j-1}-s_i\sum_{j=0}^S\binom{i+j-1}{j}\binom {n-i+S-j-1}{S-j}

好,到这里我们就不会了需要进一步的赋予这两个和式新的组合意义,首先,考虑一个简单的组合问题:从 (0,0)->(n,m),只能往右上走的方案数是多少?
显然是 \binom {n+m}{n}。换个角度,考虑其从第 i 列移动到第 i+1 列时的纵坐标的位置 j,就可以从另一个角度来计算这个方案数,因此有:

\binom {n+m}{n}=\sum_{j=0}^m\binom{i+j}{j}\binom{n+m-i-j-1}{m-j}

我们把原来的柿子试图往这个形式上套:

i\sum_{j=0}^{S-1}\binom {i+j}{j}\binom{n-i+S-j-2}{S-j-1}-s_i\sum_{j=0}^S\binom{i+j-1}{j}\binom {n-i+S-j-1}{S-j} =i\sum_{j=0}^{S-1}\binom {i+j}{j}\binom{n+(S-1)-i-j-1}{(S-1)-j}-s_i\sum_{j=0}^S\binom{(i-1)+j}{j}\binom {(n-1)-(i-1)+S-j-1}{S-j} =i\binom {n+S-1}{n}-s_i\binom {n+S-1}{S}

这样这部分我们就解决了。

再考虑另一个和式,发现其实唯一的不同就是枚举上界发生了变化,那么对应的组合意义就变成了形如:从 (0,0)->(n,m) 在第 p 列的纵坐标不能超过 q 的方案数是多少?
那么这个限制其实也就等价于:从第 q 行走到第 q+1 行时,横坐标至少为 p+1,两种角度对应的方案数相等,就有:

\sum_{i=0}^q\binom{p+i}{i}\binom{n+m-p-i-1}{m-i}=\sum_{i=p+1}^n\binom{i+q}{q}\binom{n+m-i-q-1}{n-i}

有了这两个柿子,我们就可以 O(1) 的从 (n,m,p,q) 的答案转移到 (n,m,p,q+1)(n,m,p+1,q)
回到我们实际的问题,发现我们需要的 p/q 确实都是单调不降的,所以我们采用这样的方法暴力移动 p,q 两个指针均摊复杂度就可以做到线性了。

\text{Code}

#include<bits/stdc++.h>
using namespace std;
#define ll long long
#define ull unsigned ll
#define debug(...) fprintf(stderr,__VA_ARGS__)
#define ok debug("OK\n")
inline ll read() {
  ll x(0),f(1);char c=getchar();
  while(!isdigit(c)) {if(c=='-') f=-1;c=getchar();}
  while(isdigit(c)) {x=(x<<1)+(x<<3)+c-'0';c=getchar();}
  return x*f;
}

const int N=3e6+100;
const double inf=1e18;
const int mod=998244353;

inline ll ksm(ll x,ll k){
  ll res=1;
  while(k){
    if(k&1) res=res*x%mod;
    x=x*x%mod;
    k>>=1;
  }
  return res;
}

int n;

ll jc[N],ni[N];
void init(int n){
  jc[0]=1;
  for(int i=1;i<=n;i++) jc[i]=jc[i-1]*i%mod;
  ni[n]=ksm(jc[n],mod-2);
  for(int i=n-1;i>=0;i--) ni[i]=ni[i+1]*(i+1)%mod;
  return;
}
inline ll C(int n,int m){
  //if(n<m){
    //ok;
  //}
  return n<m?0:jc[n]*ni[m]%mod*ni[n-m]%mod;
}

struct Calc{
  int n,m,p,q;
  ll res;
  void init(int N,int M){
    n=N;m=M;
    p=0;q=0;
    assert(n+m-1>=m);
    res=C(n+m-1,m);
    return;
  }
  inline ll calc(int P,int Q){
    //if(p<0||q<0) return 0;
    while(q<Q){
      ++q;
      //assert(n+m-q-p-1>=m-q);
      (res+=C(q+p,q)*C(n+m-q-p-1,m-q))%=mod;
    }
    while(p<P){
      ++p;
      //assert(n+m-q-p-1>=n-p);
      res=(res-C(p+q,p)*C(n+m-p-q-1,n-p)%mod+mod)%mod;
    }
    return res;
  }
}f1,f2;

ll s[N],w[N],S;
void work(){
  n=read();
  for(int i=1;i<=n;i++) s[i]=s[i-1]+read();
  for(int i=1;i<n;i++) w[i]=read();  
  S=s[n];
  f1.init(n-1,S);
  f2.init(n,S-1);
  ll ans(0);
  for(int i=1;i<n;i++){
    ll add(0);
    assert(n+S-1>=n);
    assert(n+S-1>=S);
    //printf("i=%d s=%lld\n",i,s[i]);
    add=(add+i*C(n+S-1,n))%mod;    
    //printf("  add=%lld\n",add);
    add=(add+mod-s[i]*C(n+S-1,S)%mod)%mod;
    //printf("  add=%lld\n",add);
    if(s[i]) add=(add+2*s[i]*f1.calc(i-1,s[i]))%mod;
    //printf("  add=%lld\n",add);
    if(s[i]) add=(add+mod-2*i*f2.calc(i,s[i]-1)%mod)%mod;
    //printf("  add=%lld\n",add);
    ans=(ans+add*w[i])%mod;
    //printf("%lld %lld\n",ans,add*w[i]%mod);
  }
  printf("%lld\n",ans);
}

signed main(){
  #ifndef ONLINE_JUDGE
  freopen("a.in","r",stdin);
  freopen("a.out","w",stdout);
  #endif
  init(2e6+5e5+500);
  int T=read();
  while(T--) work();
  return 0;
}
/*
*/