题解:P1054 [NOIP 2005 提高组] 等价表达式

· · 题解

首先感谢 @shihaocheng110909

与大部分的写法不同,这里采用分治的写法(比栈要好写)。每次分治查找区间内优先级最低的符号,然后分别递归计算这个符号两边的算式即可。由于数据较小,可以直接枚举 a 的值,然后依次对于每个算式进行判断结果是否相同。

需要注意的是,由于该题是远古题目,所以给定的字符串里可能有不可见字符(比如 \r),还有括号可能不匹配,我们需要对这种情况对于特殊处理。最后,答案的存储可以用 unsigned long long 的自然溢出,用 hash 的原理,比较最终结果即可。

由于 @shihaocheng110909 的代码马蜂优良,所以使用他的代码。他线下同意了我使用他的代码。

#include<bits/stdc++.h>
using namespace std;
typedef unsigned long long ull;
string s[30];
bool vis[30];
ull ll,rr;
int n;
ull yx(char ch){
    if(ch=='+'||ch=='-')return 1;
    if(ch=='*')return 2;
    if(ch=='^')return 3;
}
ull num(int l,int r,string ss){
    ull ans=0;
    for(int i=l;i<=r;i++)ans=ans*10+ss[i]-'0';
    return ans;
}
ull solve(int l,int r,string ss){
    ull ps=0,mid=0;
    for(int i=l;i<=r;i++){
        if(ss[i]=='('){ps++;continue;}
        if(ss[i]==')'){ps--;continue;}
        if(!ps&&(ss[i]=='+'||ss[i]=='*'||ss[i]=='-'||ss[i]=='^')){
            if(yx(ss[mid])>=yx(ss[i]))mid=i;
        }
    }
    ull L=l,R=r;
    while(ss[L]!='+'&&ss[L]!='*'&&ss[L]!='-'&&ss[L]!='^'&&ss[L]!='('&&ss[L]!=')'&&(ss[L]<'0'||ss[L]>'9'))L++;
    while(ss[R]!='+'&&ss[R]!='*'&&ss[R]!='-'&&ss[R]!='^'&&ss[R]!='('&&ss[R]!=')'&&(ss[R]<'0'||ss[R]>'9'))R--;
    if(mid==0){
        if(ss[L]=='('&&ss[R]==')')return solve(L+1,R-1,ss);
        return num(L,R,ss);
    }else{
        if(ss[mid]=='+')return solve(L,mid-1,ss)+solve(mid+1,R,ss);
        if(ss[mid]=='-')return solve(L,mid-1,ss)-solve(mid+1,R,ss);
        if(ss[mid]=='*')return solve(L,mid-1,ss)*solve(mid+1,R,ss);
        if(ss[mid]=='^'){
            ull s1=solve(L,mid-1,ss);
            ull s2=solve(mid+1,R,ss);
            ull ans=1;
            while(s2--)ans*=s1;
            return ans;
        }
    }
}
int main(){
    getline(cin,s[0]);
    s[0]='('+s[0]+')';
    for(int i=1;i<s[0].size()-1;i++){
        if(s[0][i]=='(')ll++;
        if(s[0][i]==')')rr++;
    }
    if(ll<rr){
        while(rr>ll){
            s[0]='('+s[0];rr--;
        }
    }
    if(rr<ll){
        while(rr<ll){
            s[0]=s[0]+')';ll--;
        }
    }
    cin>>n;
    for(int j=1;j<=n;j++){
        char ch;cin>>ch;
        getline(cin,s[j]);
        if(ch!='\n')s[j]=ch+s[j];
        s[j]='('+s[j]+')';
        ll=rr=0;
        for(int i=1;i<s[j].size()-1;i++){
            if(s[j][i]=='(')ll++;
            if(s[j][i]==')')rr++;
        }
        if(ll<rr){
            while(rr>ll)s[j]='('+s[j],rr--;
        }
        if(rr<ll){
            while(rr<ll)s[j]=s[j]+')',ll--;
        }
    }
    for(char xx='0';xx<='9';xx++){
        ull fl=123456789987654321;
        for(int k=0;k<=n;k++){
            string t=s[k];
            for(int i=0;i<t.size();i++){
                if(t[i]=='a')t[i]=xx;
            }
            ull ls=solve(0,s[k].size()-1,t);
            if(fl==123456789987654321)fl=ls;
            if(ls!=fl)vis[k]=1;
        }
    }
    for(int i=1;i<=n;i++){
        if(!vis[i])cout<<char(i-1+'A');
    }
    return 0;
}