题解:CF2200G Operation Permutation

· · 题解

建议评绿到蓝。

题意

给定一个 x 和 长度为 n 的操作序列(可任意排列),要求随机排序后按顺序依次对 x 操作,求对于模 10^9+7 后的最终结果的期望值。

解法

首先,我们考虑简化四种运算:

  1. 减法等价于加一个负数
  2. 除法等价于乘上逆元

    这样,我们就先把四种运算拆成两种。然后,我们考虑拆贡献。假设有 m 个乘法操作,对于每个加操作,它的贡献就是它后面的乘法运算个数(假设有 k 个),原因如下:

    (x+a)\times \prod_{i=1}^k b_i= x\times \prod_{i=1}^k b_i+a\times \prod_{i=1}^k b_i

那么怎么高效求加法运算的贡献呢?考虑设动规数组 f_{i,j} 表示从前 i 个乘法操作里面选 j 个的所有子集之和,那么转移不难得到(类似组合数的转移),设第 i 个乘的数为 b_i,那么有:

f_{i,j}=f_{i-1,j}+f_{i-1,j-1}\times b_i

然后,我们要求的是期望,也就是总贡献除以总方案数,对于 m 个乘法运算,任选 k 个,贡献是 f_{m,k},共有 C_m^k 个方案数(连乘时顺序改变,结果是不变的,所以是组合数),那么期望就是 \frac{f_{m,k}}{C_m^k}

可以发现,若总共有 m 个乘法,加法运算可以放在任意一次乘法之前或之后,也就是共有 m+1 个可选位置,那么任选一个的概率就是 \frac{1}{m+1}。既然是期望,我们肯定要考虑所有情况,因此需要枚举加法放在哪一个位置,此时可以得出这个加法的期望就是 \frac{\sum_{k=0}^m \frac{f_{m,k}}{C_m^k}}{m+1}

然后,对于乘法的贡献,比较显然就是 \prod_{k=1}^m b_k

整体思路已经出来,但是如果直接按照上述过程实现,会发现一个问题:有若干次加法,如果对于每个加法都分别求一遍期望,在极限情况下(加法有 \frac{n}{2} 个,乘法有 \frac{n}{2} 个,每次都要重新枚举,时间复杂度会来到 O(n^3) 级别!)是不行的!

如何降低时间复杂度?不难发现,将所有加法汇总在一起,变成一个加法,这一个总加法会遍历所有位置,因此所有位置的贡献都会被计算到,每个位置也都求了到当前位置的概率,总而言之最后期望是不变的!这个一次加法的期望等价于若干个分别加法的期望!因此,最后我们只需要关心乘法个数即可,时间复杂度 O(m^2)

代码如下:

#include<bits/stdc++.h>
using namespace std;
#define int long long
#define F(i,a,b) for(int i=(a);i<=(b);i++)
#define DF(i,a,b) for(int i=(a);i>=(b);i--)
#define pb push_back
const int P=1e9+7,N=5e3+10;
int fac[N]={1};
int fp(int x,int k){
    int ans=1;
    for(x%=P;k;k>>=1,(x*=x)%=P)if(k&1)(ans*=x)%=P;
    return ans;
}
void solve(){
    int n,x,add=0;cin>>n>>x;
    vector<int>mul;
    F(i,1,n){
        char c;int v;
        cin>>c>>v,fac[i]=(fac[i-1]*i)%P;
        if(c=='-')(add+=P-v)%=P;
        else if(c=='+')(add+=v)%=P;
        else if(c=='/')mul.pb(fp(v,P-2));
        else mul.pb(v);
    }
    int m=mul.size(),S=0;
    vector<int>f(m+1);f[0]=1;
    for(int&b:mul){
        (x*=b)%=P;
        DF(i,m,1)(f[i]+=f[i-1]*b)%=P;
    }
    F(i,0,m)(S+=f[i]*fac[i]%P*fac[m-i]%P*fp(fac[m],P-2))%=P;
    (S*=fp(m+1,P-2))%=P;
    cout<<((x+add*S)%P)<<"\n";
}
signed main(){
    ios::sync_with_stdio(0);cin.tie(0);cout.tie(0);
    int T;cin>>T;
    while(T--)solve();
    return 0;
}