题解 P3167 【[CQOI2014]通配符匹配】
其实……不用动态规划。
很容易想到的是,按照"*"将模板串分段,那么每一段其实都是相互独立的——只要在询问串中能相继找到这些段即可。(我语文不好...以下你可能会看到不太合语文语法的"段",意指按照'*'分隔出来的字符串)
比如:
*abc*bcd*efg* -> abc, bcd, efg
分成三段。只要这三段都能在询问里先后找到即合法。而且显然的是,找到的越前面就越优(找到越前意味着留给后面的选择越多,毕竟‘*’可以任意跳,实在不行把多出来的交给‘*’跳过就好了)。当然,收尾如果不是通配符就要直接比较一下,因为首尾固定了不能动。
那么想想怎么找。。。
首先最好想到的当然是KMP,把'?'跟任何一个字符都认为相等就好了嘛。你看这个复杂度,它多么优秀啊!走,KMP!
(这么想的同学转到楼下大佬的题解。。我模拟赛时打的也是KMP。。。)
没了KMP,我们能快速匹配字符串的方法好像只有hash了。
(原理,懂的同学可以跳过)令数组hash[i]表示str[i] + str[i + i] * base + str[i + 2] * base ^ 2 + ... + str[n] * base ^ (n - i),其中str[i]表示第i位字符,base为任意指定的正整数,hash[n]本质上相当于以一个base进制数表示1到n这段的字符串
那么有hash[i] = hash[i + 1] * base + str[i],我们可以O(n)递推出一个字符串的hash数组。
而且对于其中一个子串[x, x + len],它的hash值我们可以通过对hash数组作减法得到:
son_hash(begin, length) = hash[begin] - hash[begin + length] * base ^ length
只要我们再预处理出所有base^i,上面的式子就是O(1)可得的了。
也就是说,O(n)的预处理后,对于一个给定的字符串中的任意子串,我们可以用O(1)的时间求出其hash值。
一般来说,我们都不会希望一个数字"爆int","爆long long",所以会给它加上一个取模。但对于hash值来说,我们可以不必去模它,而是使用unsigned long long让它自然溢出,这样做比取模要快很多。虽然这样是可能被卡掉的(因为这样就是%2^63次),但一般不会有恶毒的出题人特意来卡自然溢出。
来说说具体的。
首先明确一下问题,原问题已经转化成了给定一个带‘?’的字符串,求在给定字符串中出现的最早位置。(不停的做这个问题,就能处理完所有段了)
显然的是,如果把'?'的值算在hash值里面,我们之前关于hash值比较的一大段话就全当放屁了。(把‘?’的算入hash值,能相等才怪——相等了反倒就是hash冲突了。。)
所以'?'要被单独约谈处理。对于已经按照'*'分隔出来的每一段,我们再把它按照'?'单独分开来,一段一段的找。注意这里和'*'的不同之处:'*'可以不连续,但'?'只能抵掉一个字符。
那么我们要比较按照'*'分隔的段,比较的大概就是下面这种东西:
hash_val1, ?, ?, hash_val2, ?, hsah_val3, ……
比较hash值的时候比去就好了,如果碰到'?'就把询问串的起始指针(下标)往后移动一位,因为'?'匹配且只匹配一个字符。
当然,如果匹配到后面发现匹配不上,就要把询问串的指针往前回溯,回溯到之前的位置+1。听起来很复杂,但实际上写成多个函数就好多了。
我的代码实现未必和上述吻合,但思路是一致的,实现就见仁见智吧。
代码里面会加一小点注释,读者可以细细品味。
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <vector>
namespace my {
typedef unsigned long long ull;
const int maxn(112345);
const ull base(23);//选择的进制数
ull power[maxn];//power[i] = base ^ i,其中'^'表示幂
inline void init_power() {
power[0] = 1;
for (int i(1); i != maxn; ++i) {
power[i] = power[i - 1] * base;
}
}
class string {//自己写的string,封装了一下hash相关的函数
public:
string() : end(0) {}
char& operator[](int p) {
return str[p];
}
void clear() { end = 0; }
void read() {//读入
end = 0;
char c(getchar());
while (c < 'a' || c > 'z') c = getchar();
str[end++] = 'a';//在字符串收尾统一加入相同字符不会影响答案,且避免了结尾、开头是通配符的情况。
do {
str[end++] = c;
c = getchar();
} while (c >= 'a' && c <= 'z');
str[end++] = 'a';
init_hash();
}
bool empty() const { return end == 0; }
int size() const { return end; }
void push_back(char c) {//向该string中加入字符
if (c == '?') pos.push_back(end);//pos记录了该string所有'?'的位置(原因是我在分段时只调用这个函数向string中加入字符)
str[end++] = c;
}
void init_hash() {//初始化hash,原理如上所述
hash[end] = 0;
for (int i(end - 1); i >= 0; --i) {
hash[i] = hash[i + 1] * base + str[i];
}
}
ull gethash(int beg, int len) const {//获得其中某一段的hash值
return hash[beg] - hash[beg + len] * power[len];
}
ull gethash() const {//获得整段的hash值,程序中似乎没有用到
return hash[0];
}
std::vector<int> pos;//pos记录所有'?'的位置
protected:
char str[maxn];
ull hash[maxn];
int end;
}seg[20], head, tail, dest;//四者分别是:模板串分出来的每一段,模板串的首,模板串的尾,目标串(询问串)
char str[maxn];
int endseg, lenstr;//endseg即分出来的段数
void init_seg(int len) {//把模板串分段
int left(0), right(len - 1);
while (left <= right) {//先处理首
if (str[left] == '*') break;
head.push_back(str[left++]);
}
head.init_hash();
int pos(right);
while (left <= pos) {//由于尾串大小不定,先找到起始点
if (str[pos] == '*') break;
--pos;
}
for (int i(pos + 1); i <= right; ++i) {
tail.push_back(str[i]);
}
tail.init_hash();
++left; right = pos;
while (left <= right) {//处理出所有的段
if (str[left] == '*') {//每段间的分隔符
if (!seg[endseg].empty())//为避免ab******c的情况,判断一下是否为空
seg[endseg].init_hash(), ++endseg;//非空,初始化hash值
++left;
continue;
} else {
seg[endseg].push_back(str[left++]);//加入段中
}
}
if (!seg[endseg].empty()) ++endseg;//如果endseg中有元素,++endseg使其指向超出末端下一位
}
inline bool match(int s, int beg) {//常识将段seg[s]与目标串beg处开始的字符串进行匹配
int front(0);
for (int i(0); i != seg[s].pos.size(); ++i) {//枚举'?'间的字符串
int len(seg[s].pos[i] - front);
if (seg[s].gethash(front, len) != dest.gethash(beg, len)) {//hash值不相等,一定没对上
return false;
}
beg += len + 1;
front += len + 1;//匹配上了,那么继续匹配下一个。+1是表示这里跳了一个字符(即'?')
}
int len(seg[s].size() - front);
if (seg[s].gethash(front, len) != dest.gethash(beg, len)) return false;//最后没有'?'之后可能还有一段需要比较
return true;
}
inline bool com_seg(int s, int& left, int right) {
while (left <= right) {//尝试将段seg[s]与dest在[left, right]间的子串进行匹配
if (match(s, left)) {
left += seg[s].size();//如果成功匹配,left就往后跳到下一个需要匹配的起始位置
return true;
}
++left;//匹配失败,尝试++left继续匹配
}
return false;
}
inline bool cmpstr(string& a, int x, string& b, int y, int len) {//比较a从x处的长度为len的字符串是否与b从y处开始的长度为len的字符串相等
for (int i(0); i != len; ++i) {
if (a[x + i] != b[y + i] && a[x + i] != '?' && b[y + i] != '?') return false;
}
return true;
}
void compare() {//整体的匹配
int left(0), right(dest.size() - 1);
if (head.size() > dest.size() || tail.size() > dest.size()) {//首、尾串大小对不上,一定错了。
printf("NO\n");
return;
}
if (!cmpstr(head, 0, dest, 0, head.size())) {//由于插入了'a',模板串的首串一定不是通配符,一定要与目标串的开头匹配。
printf("NO\n");
return;
} else {
left = head.size();
}
if (!cmpstr(tail, 0, dest, right - tail.size() + 1, tail.size())) {//与上同理,匹配尾串
printf("NO\n");
return;
}
else {
right = right - tail.size();
}
for (int i(0); i != endseg; ++i) {
if (!com_seg(i, left, right)) {尝试将段seg[i]与left到right的字符串匹配(可以跳跃,因为段的收尾都是被抹去的'\*')
printf("NO\n");
return;
}
}
printf("YES\n");
}
int main() {
init_power();//一些初始化和函数调用
int T;
scanf("%s%d", str + 1, &T);
str[0] = 'a';
str[lenstr = ::strlen(str)] = 'a';
init_seg(++lenstr);
while (T--) {
dest.read();
compare();
}
return 0;
}
}
int main() {
return my::main();
}