题解 P3167 【[CQOI2014]通配符匹配】

· · 题解

其实……不用动态规划。

很容易想到的是,按照"*"将模板串分段,那么每一段其实都是相互独立的——只要在询问串中能相继找到这些段即可。(我语文不好...以下你可能会看到不太合语文语法的"段",意指按照'*'分隔出来的字符串)

比如:

*abc*bcd*efg* -> abc, bcd, efg

分成三段。只要这三段都能在询问里先后找到即合法。而且显然的是,找到的越前面就越优(找到越前意味着留给后面的选择越多,毕竟‘*’可以任意跳,实在不行把多出来的交给‘*’跳过就好了)。当然,收尾如果不是通配符就要直接比较一下,因为首尾固定了不能动。

那么想想怎么找。。。

首先最好想到的当然是KMP,把'?'跟任何一个字符都认为相等就好了嘛。你看这个复杂度,它多么优秀啊!走,KMP!

(这么想的同学转到楼下大佬的题解。。我模拟赛时打的也是KMP。。。)

没了KMP,我们能快速匹配字符串的方法好像只有hash了。

(原理,懂的同学可以跳过)令数组hash[i]表示str[i] + str[i + i] * base + str[i + 2] * base ^ 2 + ... + str[n] * base ^ (n - i),其中str[i]表示第i位字符,base为任意指定的正整数,hash[n]本质上相当于以一个base进制数表示1到n这段的字符串

那么有hash[i] = hash[i + 1] * base + str[i],我们可以O(n)递推出一个字符串的hash数组。

而且对于其中一个子串[x, x + len],它的hash值我们可以通过对hash数组作减法得到:

son_hash(begin, length) = hash[begin] - hash[begin + length] * base ^ length

只要我们再预处理出所有base^i,上面的式子就是O(1)可得的了。

也就是说,O(n)的预处理后,对于一个给定的字符串中的任意子串,我们可以用O(1)的时间求出其hash值。

一般来说,我们都不会希望一个数字"爆int","爆long long",所以会给它加上一个取模。但对于hash值来说,我们可以不必去模它,而是使用unsigned long long让它自然溢出,这样做比取模要快很多。虽然这样是可能被卡掉的(因为这样就是%2^63次),但一般不会有恶毒的出题人特意来卡自然溢出。

来说说具体的。

首先明确一下问题,原问题已经转化成了给定一个带‘?’的字符串,求在给定字符串中出现的最早位置。(不停的做这个问题,就能处理完所有段了)

显然的是,如果把'?'的值算在hash值里面,我们之前关于hash值比较的一大段话就全当放屁了。(把‘?’的算入hash值,能相等才怪——相等了反倒就是hash冲突了。。)

所以'?'要被单独约谈处理。对于已经按照'*'分隔出来的每一段,我们再把它按照'?'单独分开来,一段一段的找。注意这里和'*'的不同之处:'*'可以不连续,但'?'只能抵掉一个字符

那么我们要比较按照'*'分隔的段,比较的大概就是下面这种东西:

hash_val1, ?, ?, hash_val2, ?, hsah_val3, ……

比较hash值的时候比去就好了,如果碰到'?'就把询问串的起始指针(下标)往后移动一位,因为'?'匹配且只匹配一个字符。

当然,如果匹配到后面发现匹配不上,就要把询问串的指针往前回溯,回溯到之前的位置+1。听起来很复杂,但实际上写成多个函数就好多了。

我的代码实现未必和上述吻合,但思路是一致的,实现就见仁见智吧。

代码里面会加一小点注释,读者可以细细品味。

#include <cstdio>
#include <cstdlib>
#include <cstring>

#include <vector>

namespace my {
    typedef unsigned long long ull;
    const int maxn(112345);
    const ull base(23);//选择的进制数
    ull power[maxn];//power[i] = base ^ i,其中'^'表示幂
    inline void init_power() {
        power[0] = 1;
        for (int i(1); i != maxn; ++i) {
            power[i] = power[i - 1] * base;
        }
    }
    class string {//自己写的string,封装了一下hash相关的函数
        public:
            string() : end(0) {}
            char& operator[](int p) {
                return str[p];
            }
            void clear() { end = 0; }
            void read() {//读入
                end = 0;
                char c(getchar());
                while (c < 'a' || c > 'z') c = getchar();
                str[end++] = 'a';//在字符串收尾统一加入相同字符不会影响答案,且避免了结尾、开头是通配符的情况。
                do {
                    str[end++] = c;
                    c = getchar();
                } while (c >= 'a' && c <= 'z');
                str[end++] = 'a';
                init_hash();
            }
            bool empty() const { return end == 0; }
            int size() const { return end; }
            void push_back(char c) {//向该string中加入字符
                if (c == '?') pos.push_back(end);//pos记录了该string所有'?'的位置(原因是我在分段时只调用这个函数向string中加入字符)
                str[end++] = c;
            }
            void init_hash() {//初始化hash,原理如上所述
                hash[end] = 0;
                for (int i(end - 1); i >= 0; --i) {
                    hash[i] = hash[i + 1] * base + str[i];
                }
            }
            ull gethash(int beg, int len) const {//获得其中某一段的hash值
                return hash[beg] - hash[beg + len] * power[len];
            }
            ull gethash() const {//获得整段的hash值,程序中似乎没有用到
                return hash[0];
            }
            std::vector<int> pos;//pos记录所有'?'的位置
        protected:
            char str[maxn];
            ull hash[maxn];
            int end;
    }seg[20], head, tail, dest;//四者分别是:模板串分出来的每一段,模板串的首,模板串的尾,目标串(询问串)
    char str[maxn];
    int endseg, lenstr;//endseg即分出来的段数
    void init_seg(int len) {//把模板串分段
        int left(0), right(len - 1);
        while (left <= right) {//先处理首
            if (str[left] == '*') break;
            head.push_back(str[left++]);
        }
        head.init_hash();
        int pos(right);
        while (left <= pos) {//由于尾串大小不定,先找到起始点
            if (str[pos] == '*') break;
            --pos;
        }
        for (int i(pos + 1); i <= right; ++i) {
            tail.push_back(str[i]);
        }
        tail.init_hash();
        ++left; right = pos;
        while (left <= right) {//处理出所有的段
            if (str[left] == '*') {//每段间的分隔符
                if (!seg[endseg].empty())//为避免ab******c的情况,判断一下是否为空
                    seg[endseg].init_hash(), ++endseg;//非空,初始化hash值
                ++left;
                continue;
            } else {
                seg[endseg].push_back(str[left++]);//加入段中
            }
        }
        if (!seg[endseg].empty()) ++endseg;//如果endseg中有元素,++endseg使其指向超出末端下一位
    }

    inline bool match(int s, int beg) {//常识将段seg[s]与目标串beg处开始的字符串进行匹配
        int front(0);
        for (int i(0); i != seg[s].pos.size(); ++i) {//枚举'?'间的字符串
            int len(seg[s].pos[i] - front);
            if (seg[s].gethash(front, len) != dest.gethash(beg, len)) {//hash值不相等,一定没对上
                return false;
            }
            beg += len + 1;
            front += len + 1;//匹配上了,那么继续匹配下一个。+1是表示这里跳了一个字符(即'?')
        }
        int len(seg[s].size() - front);
        if (seg[s].gethash(front, len) != dest.gethash(beg, len)) return false;//最后没有'?'之后可能还有一段需要比较 
        return true;
    }
    inline bool com_seg(int s, int& left, int right) {
        while (left <= right) {//尝试将段seg[s]与dest在[left, right]间的子串进行匹配
            if (match(s, left)) {
                left += seg[s].size();//如果成功匹配,left就往后跳到下一个需要匹配的起始位置
                return true;
            }
            ++left;//匹配失败,尝试++left继续匹配
        }
        return false;
    }
    inline bool cmpstr(string& a, int x, string& b, int y, int len) {//比较a从x处的长度为len的字符串是否与b从y处开始的长度为len的字符串相等
        for (int i(0); i != len; ++i) {
            if (a[x + i] != b[y + i] && a[x + i] != '?' && b[y + i] != '?') return false;
        }
        return true;
    }
    void compare() {//整体的匹配
        int left(0), right(dest.size() - 1);
        if (head.size() > dest.size() || tail.size() > dest.size()) {//首、尾串大小对不上,一定错了。
            printf("NO\n");
            return;
        }
        if (!cmpstr(head, 0, dest, 0, head.size())) {//由于插入了'a',模板串的首串一定不是通配符,一定要与目标串的开头匹配。
            printf("NO\n");
            return;
        } else {
            left = head.size();
        }
        if (!cmpstr(tail, 0, dest, right - tail.size() + 1, tail.size())) {//与上同理,匹配尾串
            printf("NO\n");
            return;
        }
        else {
            right = right - tail.size();
        }
        for (int i(0); i != endseg; ++i) {
            if (!com_seg(i, left, right)) {尝试将段seg[i]与left到right的字符串匹配(可以跳跃,因为段的收尾都是被抹去的'\*')
                printf("NO\n");
                return;
            }
        }
        printf("YES\n");
    }

    int main() {
        init_power();//一些初始化和函数调用
        int T;
        scanf("%s%d", str + 1, &T);
        str[0] = 'a';
        str[lenstr = ::strlen(str)] = 'a';
        init_seg(++lenstr);
        while (T--) {
            dest.read();
            compare();
        }
        return 0;
    }
}

int main() {
    return my::main();
}