题解 P7324 【[WC2021] 表达式求值】

· · 题解

考场上人傻了……

首先考虑建出 E 的表达式树,它的特点如下:

显然可以把 A 矩阵的每一列分开算,那么相当于我们要求解 n 次下面的问题:

观察到由于运算都是取最大值、最小值,最终的结果一定在 A_0,...,A_{m-1} 中产生。不妨假定它们互不相同(对于相同的情况,可以钦定编号小的更小一些)。我们考虑对每个 A_i(0\leq i < m),计算 A_i 成为最终答案的方案数。

观察到对于每个 iA_i 成为最终答案的方案数仅仅和 A_0,...,A_{m-1} 中哪些数比 A_i 小有关。我们考虑设计如下状态进行动态规划:

转移显而易见:

f(S,u,k)=\sum_{\max\{i,j\}=k}f(S,l,i)\times f(S,r,j)

若为 <,那么

f(S,u,k)=\sum_{\min\{i,j\}=k}f(S,l,i)\times f(S,r,j)

若为 ? 则把上两式相加即可。

设原树根节点为 r。对于原矩阵 A 的每一列,我们将其排序然后查出对应的 f(S,r,1) 计算即可。这就获得了一个 O(m2^m|E|+n\mathrm{poly}(m)) 的做法。瞟一眼数据范围……居然只有暴力分!怎么回事?

观察数据范围,发现这个做法已经十分接近正解。考虑把第一步枚举 i 的步骤去掉。我们进一步观察到:只需要知道答案大于等于某个 a_i 的方案,将其求和也能得到答案!这样就省去了具体枚举 i 的时间。换句话说,我们重新定义状态:

同上可列出转移方程。观察到计算过程中和 i 是无关的,因此省去了枚举。计算答案时,不妨设排序后 A_{p_0}\leq A_{p_1}\leq ...\leq A_{p_{m-1}},答案应当是

\sum_{i=0}^{m-1}(A_{p_i}-A_{p_{i-1}})\times g(\{p_0,...,p_{i-1}\},r,1)

复杂度降低至 O(2^m|E|+n\mathrm{poly}(m)),可以通过此题。

Q & A

Q:你是怎么做到只拿 70 分的?这题不是随便切吗?

A:考场上我就想到了上面那个 O(m2^m|E|+n\mathrm{poly}(m)) 做法……然后脑子就短路了……在下菜了 T_T

#include<cstdio>
#include<vector>
#include<cstring>
#include<algorithm>

using namespace std;

const long long MOD=1000000007;
long long add(long long x,long long y){return x+y>=MOD?x+y-MOD:x+y;}

long long A[20][100000];char E[100000];

int bl[100000];vector<int> st;

int l[100000],r[100000],op[100000],tot;

int new_node(int x,int y,int o)
{
    tot++;l[tot]=x,r[tot]=y,op[tot]=o;return tot;
}

int build(int l,int r)
{
    //if(r>0)printf("%d %d\n",l,r);
    if(bl[r])
    {
        if(bl[r]==l)return build(l+1,r-1);
        else
        {
            int x=build(l,bl[r]-2),y=build(bl[r]+1,r-1);
            return new_node(x,y,E[bl[r]-1]);
        }
    }
    int y=new_node(0,0,E[r]-'0');if(l==r)return y;
    int x=build(l,r-2);return new_node(x,y,E[r-1]);
}
void output(int u)
{
    if(!u)return;
    printf("O: %d %d %d %d\n",u,l[u],r[u],op[u]);
    output(l[u]);output(r[u]);
}

long long dp[100000][2];
void dfs(int u,int sta)
{
    for(int i=0;i<2;i++)dp[u][i]=0;
    if(!l[u])
    {
        if(sta&(1<<op[u]))dp[u][0]=1;else dp[u][1]=1;
        return;
    }
    dfs(l[u],sta),dfs(r[u],sta);
    if(op[u]!='>')
    {
        for(int i=0;i<2;i++)for(int j=0;j<2;j++)
            dp[u][min(i,j)]=add(dp[u][min(i,j)],dp[l[u]][i]*dp[r[u]][j]%MOD);
    }
    if(op[u]!='<')
    {
        for(int i=0;i<2;i++)for(int j=0;j<2;j++)
            dp[u][max(i,j)]=add(dp[u][max(i,j)],dp[l[u]][i]*dp[r[u]][j]%MOD);
    }
}

long long coe[1<<10];

int x;
int bitcmp(int i,int j){return A[i][x]<A[j][x];}

int p[10];

int main()
{
    //freopen("expr.in","r",stdin);
    //freopen("expr.out","w",stdout);

    int n=0,m=0;scanf("%d%d",&n,&m);
    for(int i=0;i<m;i++)for(int j=1;j<=n;j++)scanf("%lld",&A[i][j]);
    scanf("%s",E+1);int l=strlen(E+1);
    //printf("%d\n",l);
    for(int i=1;i<=l;i++)
    {
        if(E[i]=='(')
        {
            st.push_back(i);
        }
        else if(E[i]==')')
        {
            int u=st.back();st.pop_back();bl[u]=i,bl[i]=u;
        }
    }
    int rt=build(1,l);//output(rt);
    for(int i=0;i<(1<<m);i++)
    {
        dfs(rt,i);coe[i]=dp[rt][1];
    }
    //printf("%d\n",coe[0]);
    long long ans=0;
    for(int j=1;j<=n;j++)
    {
        for(int i=0;i<m;i++)p[i]=i;
        x=j;
        sort(p,p+m,bitcmp);
        //for(int i=0;i<m;i++)printf("%d ",p[i]);puts("");
        ans=add(ans,coe[0]*A[p[0]][j]%MOD);
        for(int i=1,sta=0;i<m;i++)
        {
            sta|=(1<<(p[i-1]));
            ans=add(ans,coe[sta]*(A[p[i]][j]-A[p[i-1]][j])%MOD);
        }
    }
    printf("%lld",ans);
    return 0;
}