题解:P7758 [COCI 2012/2013 #3] HERKABE

· · 题解

题解:P7758 [COCI 2012/2013 #3] HERKABE

题目大意

对于给定的 n 个字符串,求满足以下条件的排列数: 对于任意两个具有相同前缀的字符串,它们之间的所有字符串也必须具有这个前缀。

思路

要找出前缀相同的字符串(即有最长相同前缀的字符串放在一起)可以想到对 string 数组进行一次 sort,然后递归进行分组。

由于对于大小为 n 的组(并且没有两个字符串更长的相同前缀),全排列一共有 n! 种方式,所以考虑预处理模 10^9+7 意义下的阶乘。

核心代码有注释,不多做解释。

代码如下:

#include<bits/stdc++.h>
using namespace std;
typedef long long LL;
#define endl putchar('\n')
#define isd(c) (c>='0'&&c<='9')
#define isa(c) ((c>='a'&&c<='z')||(c>='A'&&c<='Z'))
#define blank(c) (c=='\n'||c=='\r'||c=='\t'||c==' ')
#define For(i,a,b) for(auto i=(a);i<=(b);i++)
using namespace std;
namespace temp{
char gc(){//直接从缓冲区读入
    static char buf[1<<20],*p1=buf,*p2=buf;
    return p1==p2&&(p2=(p1=buf)+fread(buf,1,1<<20,stdin),p1==p2) ? EOF : *p1++;
}
void read(string &s){//字符串快读
    char c=gc();
    while(blank(c))c=gc();
    while(!blank(c))s.push_back(c),c=gc();
}
template<class T>void read(T &n){//数字快读
    n=0;char c=gc();int k=1;
    while(!isd(c)){if(c=='-')k*=-1;c=gc();}
    while(isd(c))n=n*10+c-'0',c=gc();
    n*=k;
}
template<class T>void write(T n){//数字快写
    if(n<0)n=-n,putchar('-');
    if(n>9)write(n/10);
    putchar(n%10+'0');
}
}

const int MAXN = 3005;
const LL MOD = 1e9 + 7;
string s[MAXN];
LL fact[MAXN];
int n;

namespace Main{
using namespace temp;
LL dfs(int d, int l, int r) {//核心代码
    if (l >= r) return 1;// 递归边界:当区间为空时,只有1种排列方式
    int endcnt = 0;      // 统计在当前深度结束的字符串数量
    LL res = 1;             // 当前区间的方案数
    int childcnt = 0;    // 统计需要继续分组的子区间数量

    int p = l;  // 当前处理位置
    while (p < r) {
        if (d == (int)s[p].size())// 情况1:当前字符串在当前深度结束
            endcnt++,p++;
        else {// 情况2:当前字符串需要继续分组
            char c = s[p][d];
            int seg_l = p;// 子区间起始位置
            while (p < r && d < (int)s[p].size() && s[p][d] == c)p++;
            // 找到相同字符的连续区间
            int seg_r = p;
            res = res * dfs(d + 1, seg_l, seg_r) % MOD;// 递归处理子区间:深度+1,区间[seg_l, seg_r)
            childrencnt++;
        }
    }
    int tot = endcnt + childrencnt;
    res = res * fact[tot] % MOD;// 当前区间的方案数 = 分支排列数 × 子区间方案数乘积
    return res;
}

void Main() {
    read(n);
    For(i,0,n-1) read(s[i]);
    sort(s,s+n);
    fact[0]=1;
    For(i,1,n) fact[i] = fact[i - 1] * i % MOD;//预处理阶乘( mod 1e9+7)
    write(dfs(0, 0, n));
}
}

int main() {
    ios::sync_with_stdio(0);
    cin.tie(0), cout.tie(0);
    Main::Main();
    return 0;
}

时间复杂度:O(nl_{max})l_{max} 表示最长字符串的长度,可以通过此题,主要花销为 dfs 函数。