【题解】[BalkanOI 2012] handsome

· · 题解

给定 n 和长度为 n 的排列 p,和 m 个长度为 2 的串,和一个序列 s,求有多少长度为 n 的序列 a,使得 a_i \in \{1,2,3\}a 中不存在 m 个子串,并且求出序列 b 使得 b_i = a_{p_i},满足 b 的字典序小于 s。对 10 ^ 9 + 7 取模。

对于子串的限制,我们从 1\to n 依次填数,简单 DP 即可解决问题。

对于字典序限制,我们按 p 依次填数,简单 DP 即可解决问题。

现在同时考虑两个限制。我们考虑转化一下字典序限制。我们枚举第一个 <s 的位 k,那么 <k 的位置都必须等于 s,等于 k 的位置必须小于 s,其余位置可以随便填。

那么我们依次枚举 k,每次只会有两个位置的限制发生变化。那么我们可以简单线段树维护矩阵优化 DP 即可。时间复杂度 \mathcal{O}(27N\log N)

#define N 400005
int p[N], n, m; char s[N];
struct mat{
    int a[3][3];
    mat(){memset(a, 0, sizeof(a));}
    void init(){
        memset(a, 0, sizeof(a));
        a[0][0] = a[1][1] = a[2][2] = 1;
    }
    void out(){
        el;
        rep(i, 0, 2){rep(j, 0, 2)printf("%d ", a[i][j]); el;}
    }
}b;
mat operator*(mat a, mat b){
    mat c;
    rep(i, 0, 2)rep(k, 0, 2)rep(j, 0, 2)c.a[i][j] = (c.a[i][j] + a.a[i][k] * (LL)b.a[k][j]) % P;
    return c;
};
mat a[N << 2], u[N];
#define ls (x << 1)
#define rs (ls | 1)
void build(int x,int l,int r){
    if(l == r)a[x] = u[l];
    else{
        int mid = (l + r) >> 1;
        build(ls, l, mid), build(rs, mid + 1, r);
        a[x] = a[ls] * a[rs];
    }
}
void ins(int x,int l,int r,int pos,mat w){
    if(l == r)a[x] = w;
    else{
        int mid = (l + r) >> 1;
        if(mid >= pos)ins(ls, l, mid, pos, w);
        else ins(rs, mid + 1, r, pos, w);
        a[x] = a[ls] * a[rs];
    }
}
int ans;
void calc(int i){
    if(i > 1){
        int x = p[i - 1], op = s[x] - '1'; mat c = u[x];
        rep(l, 0, 2)rep(r, 0, 2)if(r != op)c.a[l][r] = 0;
        ins(1, 1, n, x, c);
    }
    int x = p[i], op = s[x] - '1'; mat c = u[x];
    rep(l, 0, 2)rep(r, 0, 2)if(r >= op)c.a[l][r] = 0;
    ins(1, 1, n, x, c);
    rep(r, 0, 2)ad(ans, a[1].a[0][r]);
}
int main() {
    read(n);
    rp(i, n)read(p[i]);
    read(m);
    rep(i, 0, 2)rep(j, 0, 2)b.a[i][j] = 1;
    u[1] = b;
    rp(i, m){
        char op[5];
        scanf("%s", op);
        b.a[op[0] - '1'][op[1] - '1'] = 0;
    }
    rep(i, 2, n)u[i] = b;
    scanf("%s", s + 1);
    build(1, 1, n);
    rp(i, n)calc(i);
    ad(ans, 1);
    printf("%d\n", ans);
    return 0;
}