题解:P7377 [COCI2018-2019#5] Parametriziran

· · 题解

Parametriziran

思路

答案求合法的。那我们转换一下思路,求不合法的,再用全部的配对数 - 不合法的数量,不就是合法的数量了嘛?

那么我们考虑怎么求不合法的数量。

在本题中,两个字符串配对不合法,只能是在都不是 ? 的位上,并且至少有一位不一样。下文称一个位置不一样为:两个字符串在该位置上都不为 ? ,且不一样。

我们需要用到一点点容斥。考虑原本两个字符串有两个位置不一样,那么很明显,长度为 1 时我们算了两次,长度为 2 时我们算了一次,但很明显,我们应该只算 1 次。所以此时我们需要用到容斥来避免重复的减去不合法的对数。

首先,最多只会有 6 位,考虑状压,来确定现在为 1 的位置上不为 ? 的字符串有哪些?

在这时,我们称两个字符串不合法,当且仅当所有是 1 的位置上的字符都不一样。算出此状态的下不合法的对数。

算这个也可以用到容斥,具体怎么做呢?当前状态下的字符串指的是只看那些是 1 的位置。

考虑和当前字符串 i 配对不合法的对数,假设现在要考虑顺序,最后 \div 2 即可。那么用当前状态下的字符串总数 - 当前状态下和他一模一样的字符串数量。由于我们要的是每一位都不同的。这样把有一位,两位,三位.....和他相同的字符串也给算进去了。这个时候就可以又用容斥把多加的以及重复的影响消去。

具体实现步骤看代码。

时间复杂度为 O(2^{m}\times 2^{m}\times n)。常数优秀的话可以过,所以我用了 unprdered_map 来优化常数。

代码

#include<bits/stdc++.h>
#define endl '\n'

using namespace std;

const int N=5e4+10;

typedef long long ll;
typedef pair<int,int> pii;
typedef unsigned long long ull;

const int p=131;

int n,m;
int a[N][10];
int num[1<<7];

ll ans=0;

vector<int> ones[1<<7];
vector<int> zj[1<<7];

vector<int> yp[1<<7];
ull ha[1<<7][N];

void pre(){
    for(int stat=1;stat<1<<m;stat++) num[stat]=num[stat>>1]+(stat&1);
    for(int stat=1;stat<1<<m;stat++){
        for(int i=0;i<m;i++){
            if((stat>>i)&1) ones[stat].push_back(i);
        }
        for(int j=1;j<stat;j++){
            if((j&stat)==j) zj[stat].push_back(j);//j为stat的子集
        }
    }
    for(int s=1;s<1<<m;s++){
        for(int i=1;i<=n;i++){
            bool flag=true;
            ull hs=0;
            for(int j:ones[s]){
                if(!a[i][j]) flag=false;
                hs=hs*p+a[i][j];
            }
            if(!flag) continue; 
            ha[s][i]=hs;
            yp[s].push_back(i);
        }
    }
}//end

int cnt=0;
unordered_map<ull,int> mp[1<<7]; 

ll sum[100*N]; 
ll zdx[1<<7];

int main(){
    freopen("parametriziran.in","r",stdin);
    freopen("parametriziran.out","w",stdout);

    ios::sync_with_stdio(0);

    cin>>n>>m;
    for(int i=1;i<=n;i++){
        for(int j=0;j<m;j++){
            char x; cin>>x;
            if(x=='?') a[i][j]==0;
            else a[i][j]=x-'a'+1;
        }
    }

    pre();

    ans=1ll*(n-1)*n/2;//总的配数 

    for(int s=1;s<1<<m;++s){
        for(int i=1;i<=cnt;++i) sum[i]=0;
        for(int i=1;i<1<<m;++i){
            mp[i].clear();
            zdx[i]=0;
        }

        int f=1;
        if(!(num[s]&1)) f=-1;
        cnt=0;

        for(int i:yp[s]){
            ull hs=ha[s][i];
            if(!mp[s][hs]) mp[s][hs]=++cnt;

            sum[mp[s][hs]]++;
            zdx[s]++;

            for(int zs:zj[s]){
                hs=ha[zs][i];
                if(!mp[zs][hs]) mp[zs][hs]=++cnt;
                sum[mp[zs][hs]]++;
                zdx[zs]++; 
            }
        }

        ll res=0;
        for(int i:yp[s]){
            ull hs=ha[s][i];

            int id=mp[s][hs];
            int sid=mp[s][hs];
            res+=(zdx[s]-sum[id]);

            for(int zs:zj[s]){
                int fl=-1;
                if(!(num[zs]&1)) fl=1;
                hs=ha[zs][i];
                id=mp[zs][hs];
                res+=fl*(sum[id]-sum[sid]);
            }
        }

        ans=ans-f*res/2; 
    }

    cout<<ans<<endl;    

    return 0;
}//end