题解:CF917E Upside Down

· · 题解

Upside Down

考虑随机选择 $\sqrt n$ 个点作为关键点,则每条询问路径可以分成两部分:从左端点到路径上的第一个关键点,以及从第一个关键点到右端点。前者的期望长度为 $\sqrt n$,直接暴力哈希即可;后者对每个关键点跑一遍 dfs,求询问串在「关键点到当前节点」这段路径上的出现次数,可以用 AC 自动机(在 fail 树的 dfs 序上)外加值域分块来解决。时间复杂度 $O(n\sqrt n)$。

树上路径哈希可以通过预处理每个点到根的正向和反向哈希值来计算。

你以为你过了,实则不然,你会被卡常数不知多少次。

首先,当询问路径长度小于某个阈值时(阈值视自己的代码而定),可以直接把路径串拎出来用哈希或者 KMP 计算,以减小常数。

其次,对于询问串很短的情况,同样可以把这部分路径串直接拎出来暴力匹配,做法与前者相同。

关键点个数(代码中的 `B`)大约取 $\frac{\sqrt n}{1.35}$(代码里实际用的是 $\frac{\sqrt q}{1.35}$),可以根据自己的代码调整。

不要用 long long。

以及在大量细节上卡常。

经过这一顿折腾,如果你像我一样用自然溢出哈希,会 WA 在第 117 个点上。你可以像我一样偷懒特判;也可以注意到这些数据跑得并不慢,于是换成双哈希并使用巴雷特(Barrett)模乘。总之再折腾一顿就能卡过去了。

史上最长代码:

#include <bits/stdc++.h>

using namespace std;

// Array capacity; assumes n, q and the total pattern length stay below it -- TODO confirm against the problem limits.
const int N = 1e5 + 5;
// Polynomial hash base; every hash is kept modulo 2^64 through unsigned overflow.
const int base = 998244853;

int n, m, q;                   // vertices, patterns, queries
int dpt[N], st[N][21], Fa[N];  // depth, binary-lifting table, parent (filled by dfs)
int id[N], is[N];              // shuffled vertex order; key-point index of a vertex (0 = none)
int ans[N], B;                 // per-query answers; number of key points
int t[N];                      // nearest key point on the root path of a vertex (set in main)

// Adjacency lists in forward-star form: he[u] is u's newest edge, Nxt chains to the
// previous one, To1/To2 hold the edge's target vertex and letter, ct counts edges.
int he[N], Nxt[2 * N], To1[2 * N], ct;

char To2[2 * N];
char fu[N], fw[N];  // letter of the edge (Fa[v], v) and its hash weight (filled by dfs)

// Prepends the directed edge u -> v with letter c to u's edge list.
inline void add (const int u, const int v, const char c) {
    const int e = ++ct;
    To1[e] = v;
    To2[e] = c;
    Nxt[e] = he[u];
    he[u] = e;
}

// s[i]: the i-th pattern (1-indexed, read in main).
string s[N];

// Hash weight of each letter: huan[c - 'a'] stands in for letter c in every hash.
// main permutes this table with rand() before any hashing happens.
char huan[26] = {107, 3, 5, 7, 11, 13, 17, 23, 29, 31, 37, 41, 43, 47, 53, 61, 67, 71, 73, 79, 83, 89, 91, 97, 101, 103}; 

// ID[i]: AC-automaton node at which pattern i ends (written by AC::insert).
int ID[N];

// Aho-Corasick automaton over all patterns.
// Node 1 is the root; node 0 is a sentinel whose transitions all lead to the root
// (set in build), which lets the root's children take fa = 1 without a special case.
// The failure links form a tree (stored in h/nxt/To), which is numbered by dfs so that
// the fail-subtree of node x is the dfn range [dfn[x], dfn[x] + siz[x] - 1]; counters
// over that range are kept in the sqrt-decomposed structure `kuai`.
struct AC {
    // to: transitions (completed into a full DFA by build); fa: failure link;
    // dfn/siz: Euler-tour index and fail-subtree size; idx: last node id; num: dfn counter.
    int to[N][26], fa[N], dfn[N], siz[N], idx = 1, num = 0;

    // Fail tree as forward-star adjacency lists: h[u] first child edge, nxt next edge, To child.
    int h[N], nxt[N], To[N], ct = 0;

    // BFS queue used by build.
    int q[N], l, r;

    // Adds fail-tree edge u -> v (u is v's failure link).
    inline void add (const int u, const int v) {
        To[ ++ ct] = v;
        nxt[ct] = h[u];
        h[u] = ct;
        return;
    }

    // Point-update / range-sum over positions 1..100001, split into blocks of B = 316.
    // update is O(1); query costs O(B + number of blocks).
    struct kuai {
        int B = 316;
        // sum1: per-position values; sum2: per-block totals; id/L/R: block index and bounds of a position.
        int sum1[N], sum2[N / 316 + 5], id[N], L[N], R[N];
        kuai() {
            for (int l = 1, j = 1; l <= 100001; l += B, ++ j ) {
                int r = min(100001, l + B - 1);
                for (int i = l; i <= r; ++ i ) {
                    L[i] = l, R[i] = r;
                    id[i] = j;
                } 
            }
        }   

        // Adds k at position x.
        inline void update (const int x, const int k) {
            sum1[x] += k;
            sum2[id[x]] += k;
            return;
        }

        // Returns the sum over positions l..r (l <= r).
        inline int query (const int l, const int r) {
            int res = 0;
            if (id[l] == id[r]) {
                for (int i = l; i <= r; ++ i ) res += sum1[i];
                return res;
            }
            // Left partial block: add the tail directly, or take the block total and
            // subtract the head, whichever scans fewer positions.
            if (R[l] - l + 1 < l - L[l]) {
                for (int i = l; i <= R[l]; ++ i ) res += sum1[i];
            } else {
                res += sum2[id[l]];
                for (int i = L[l]; i < l; ++ i ) res -= sum1[i];    
            }
            // Whole blocks strictly between the two ends.
            for (int i = id[l] + 1; i < id[r]; ++ i ) res += sum2[i];
            // Right partial block, with the same cheaper-side choice.
            if (r - L[r] + 1 < R[r] - r) {
                for (int i = L[r]; i <= r; ++ i )  res += sum1[i];
            } else {
                res += sum2[id[r]];
                for (int i = r + 1; i <= R[r]; ++ i ) res -= sum1[i];
            }
            return res;
        }

    } T;

    // Inserts pattern s into the trie and records its end node in the global ID[i].
    void insert (string &s, const int i) {
        int at = 1;
        for (char c : s) {
            const int w = c - 'a';
            if (!to[at][w]) to[at][w] = ++ idx;
            at = to[at][w];
        }
        ID[i] = at;
        return;
    }

    // Recursively assigns dfn and subtree sizes over the fail tree rooted at u.
    void dfs (const int u) {
        dfn[u] = ++ num;
        siz[u] = 1;
        for (int i = h[u]; i; i = nxt[i] ) {
            dfs(To[i]);
            siz[u] += siz[To[i]];
        }
        return;
    }

    // BFS from the root: sets failure links, fills missing transitions from the failure
    // node's transitions, builds the fail tree, then numbers it with dfs.
    void build() {
        for (int i = 0; i < 26; ++ i ) to[0][i] = 1;
        l = r = 1;
        q[1] = 1;
        while (l <= r) {
            const int u = q[l ++ ];
            const int f = fa[u];
            for (int i = 0; i < 26; ++ i ) {
                const int v = to[u][i];
                if (!v) to[u][i] = to[f][i];
                else {
                    fa[v] = to[f][i];
                    add(fa[v], v);
                    q[ ++ r] = v;
                }
            }
        }
        dfs(1);
        return;
    }

    // Adds k to the counter of automaton node x.
    inline void modify (const int x, const int k) {
        T.update(dfn[x], k);
        return; 
    }

    // Sum of the counters over the fail-subtree of node x.
    inline int query (const int x) {
        return T.query(dfn[x], dfn[x] + siz[x] - 1);
    }

} T;

// One query as read in main: count occurrences of pattern k on the path u -> v.
struct que {
    int u;  // path start
    int v;  // path end
    int k;  // pattern index
};

// b[i]: the i-th query, read in main.
que b[N];

// h1[v]: hash of the root -> v letters read downward. h2[v]: the same letters, with the edge
// whose upper endpoint has depth d weighted by base^d, so H2 can hash upward paths.
// c[i] = base^i; inv[i] = powers of the constant main uses as base's inverse; h[i] = hash of
// pattern i. All arithmetic is modulo 2^64 through unsigned overflow.
unsigned long long h1[N], h2[N], c[N], inv[N], h[N];

// cx[1..cnt]: vertices in the post-order produced by dfs.
int cx[N], cnt;

// Parent -> child edges (parent, (child, letter)) recorded by dfs, in visit order.
vector <pair <int, pair <int, char> > > M; 

void dfs (const int u, const int fa) {
    Fa[u] = fa;
    st[u][0] = fa;
    for (int j = 1; j <= 16; ++ j ) st[u][j] = st[st[u][j - 1]][j - 1];
    for (int i = he[u]; i; i = Nxt[i] ) {
        const int v = To1[i];
        if (v == fa) continue;
        const char C = To2[i];
        dpt[v] = dpt[u] + 1;
        fu[v] = C;
        fw[v] = huan[C - 'a'];
        h1[v] = (h1[u] * base + huan[C - 'a']);
        h2[v] = (h2[u] + huan[C - 'a'] * c[dpt[u]]);
        M.push_back({u, {v, C}});
        dfs(v, u);
    }
    cx[ ++ cnt] = u; 
    return;
}

// Lowest common ancestor of u and v by binary lifting over st[][].
// main sets dpt[0] = -1, so the sentinel parent 0 is never accepted as a lifting target.
inline int lca (int u, int v) {
    if (dpt[u] < dpt[v]) swap(u, v);
    // Bring the deeper vertex u up to v's depth.
    for (int i = 16; i >= 0; --i) {
        const int up = st[u][i];
        if (dpt[up] >= dpt[v]) u = up;
    }
    if (u == v) return u;
    // Lift both while their 2^i-th ancestors still differ.
    for (int i = 16; i >= 0; --i) {
        if (st[u][i] == st[v][i]) continue;
        u = st[u][i];
        v = st[v][i];
    }
    return st[u][0];
}

// Number of edges on the u-v path when their LCA k is already known.
inline int dis (const int u, const int v, const int k) {
    const int up = dpt[u] - dpt[k];
    const int down = dpt[v] - dpt[k];
    return up + down;
}

// Number of edges on the u-v path (computes the LCA itself).
inline int dis (const int u, const int v) {
    return dis(u, v, lca(u, v));
}

// Hash of the downward path from ancestor u to descendant v (letters read top to bottom).
inline unsigned long long H1 (const int u, const int v) {
    const unsigned long long shifted = h1[u] * c[dpt[v] - dpt[u]];
    return h1[v] - shifted;
}

// Hash of the upward path from descendant u to ancestor v (letters read bottom to top);
// multiplying by inv[dpt[v]] removes the base^depth offset that h2 carries.
inline unsigned long long H2 (const int u, const int v) {
    const unsigned long long diff = h2[u] - h2[v];
    return diff * inv[dpt[v]];
}

// Hash of the path u -> k -> v, where k is a common ancestor on the path (callers pass the LCA):
// the upward part u -> k followed by the downward part k -> v.
inline unsigned long long H (const int u, const int v, const int k) {
    const unsigned long long upward = H2(u, k);
    return upward * c[dpt[v] - dpt[k]] + H1(k, v);
}

// Returns the k-th ancestor of u, lifting by 2^bit for every set bit of k (bits 0..16).
inline int jump (int u, int k) {
    for (int bit = 0; bit <= 16; ++bit, k >>= 1) {
        if (k & 1) u = st[u][bit];
    }
    return u;
}

// Returns the vertex reached after walking k edges from u along the u-v path.
// Callers are expected to pass 0 <= k <= dis(u, v).
inline int to (const int u, const int v, int k) {
    const int K = lca(u, v);
    const int upLen = dpt[u] - dpt[K];
    if (k <= upLen) return jump(u, k);
    // Past the LCA: the target lies on the K -> v side, this many edges above v.
    const int downLen = dpt[v] - dpt[K];
    return jump(v, downLen - (k - upLen));
}

// Queries parked for the key-point DFS: head[j][v] starts a linked list (nxt/To, counter idx)
// of query ids answered at vertex v while running the DFS from key point number j.
// U and V are scratch vertex arrays filled by Get/Get2; tmp appears unused.
// NOTE(review): head has 355 rows, so the number of key points B must stay below 355.
int U[N], V[N], tmp[N], head[355][N], nxt[N], To[N], idx;

// Scratch flags written by Get (intended: false on the u -> LCA part, true on the v side).
bool ok[N];

// Parks query i at vertex u in the list of key point number id (prepends to head[id][u]).
inline void add (const int u, const int id, const int i) {
    ++idx;
    To[idx] = i;
    nxt[idx] = head[id][u];
    head[id][u] = idx;
}

// Writes the vertices of the u -> v path, in order, into U[1..idx]: u first, v last, the LCA
// exactly once. Also flags each slot in ok[]: false for the u -> LCA part, true for the
// vertices strictly below the LCA on v's side.
inline void Get (int u, int v, int *U, int &idx) {
    const int K = lca(u, v);
    const int upCount = dpt[u] - dpt[K] + 1;  // u .. K inclusive
    const int downCount = dpt[v] - dpt[K];    // v .. child of K
    idx = 0;
    for (int j = 0; j < upCount; ++j) {
        U[++idx] = u;
        ok[idx] = false;
        u = Fa[u];
    }
    // Fill the v side from the back so the slots end up in path order.
    int pos = idx + downCount;
    idx = pos;
    for (int j = 0; j < downCount; ++j) {
        U[pos] = v;
        ok[pos] = true;  // flag the slot being filled, not only the final one
        --pos;
        v = Fa[v];
    }
}

// Writes the vertices of the u -> v path, in order, into U[1..idx]
// (u first, v last, the LCA exactly once).
inline void Get2 (int u, int v, int *U, int &idx) {
    const int K = lca(u, v);
    const int upCount = dpt[u] - dpt[K] + 1;
    const int downCount = dpt[v] - dpt[K];
    idx = upCount + downCount;
    // u .. K fills the front in order.
    for (int j = 1; j <= upCount; ++j, u = Fa[u]) U[j] = u;
    // v .. (child of K) fills the back, written from the end.
    for (int j = idx; j > upCount; --j, v = Fa[v]) U[j] = v;
}

int now; 

void DFS (const int u, const int at) {
    for (int i = head[now][u]; i; i = nxt[i] ) {
        const int id = To[i];
        const int k = b[id].k;
        ans[id] += T.query(ID[k]);
    }
    for (int i = he[u]; i; i = Nxt[i] ) {
        const int ID = T.to[at][To2[i] - 'a'];
        T.modify(ID, 1);
        DFS(To1[i], ID);
        T.modify(ID, -1);
    }
    return;
}

// hs[i]: prefix hash of S[1..i]; hs[0] is never written and stays 0.
unsigned long long hs[N];

// S[1..idx]: hashed letter weights (fw) along the last path extracted by Get_S.
char S[N];

// Copies the edge weights along the u -> v path into S[1..idx] in path order (idx receives
// the edge count) and builds the prefix hashes hs[1..idx].
// The loop counters are local, so the global post-order counter cnt (used by dfs to fill cx)
// is no longer overwritten.
void Get_S (int u, int v, int &idx) {
    const int K = lca(u, v);
    const int up = dpt[u] - dpt[K];    // edges from u up to K
    const int down = dpt[v] - dpt[K];  // edges from K down to v
    idx = 0;
    for (int j = 0; j < up; ++j) {
        S[++idx] = fw[u];
        u = Fa[u];
    }
    // The K -> v part is collected from v upward, so fill it from the back.
    int pos = idx + down;
    idx = pos;
    for (int j = 0; j < down; ++j) {
        S[pos--] = fw[v];
        v = Fa[v];
    }
    for (int i = 1; i <= idx; ++i) hs[i] = hs[i - 1] * base + S[i];
}

// Key-point jump table (filled in main for j <= 7): fa[u][0] = t[u], the nearest key point on
// u's root path (u included, 0 if none); fa[u][j] = fa[Fa[fa[u][j - 1]]][j - 1].
int fa[N][9]; 

struct KMP {
    char s[N], t[N];
    int len, nxt[N];

    void build (string S) {
        len = 0;
        for (char c : S) t[ ++ len] = c;
        t[len + 1] = 0;
        for (int i = 2, j = 0; i <= len; ++ i ) {
            while (j && t[j + 1] != t[i]) j = nxt[j];
            if (t[j + 1] == t[i]) ++ j;
            nxt[i] = j;
        }
        return;
    }

    vector <int> query() {
        vector <int> v;
        for (int i = 1, j = 0; i < n; ++ i ) {
            while (j && t[j + 1] != s[i]) j = nxt[j];
            if (t[j + 1] == s[i]) ++ j;
            if (j == len) {
                v.push_back(i - len + 1);
                j = nxt[j];
            }
        }
        return v;
    }

} K;

// Chain special case (main, only when m < 114): s1[i][j] / s2[i][j] = number of occurrences
// of pattern i / reversed pattern i in the chain text that start at position <= j.
int s1[114][N], s2[114][N];

// Entry point. Reads a tree with n vertices and n - 1 lettered edges, m patterns and
// q queries (u, v, k), and prints for each query how many times pattern k occurs in the
// string read along the path u -> v.
signed main() {
    ios :: sync_with_stdio(0), cin.tie(0), cout.tie(0);

    // Fixed seed: the letter weights and the key points are the same on every run.
    srand(1145141);

    // Randomly permute the per-letter hash weights.
    for (int i = 0; i < 26; ++ i ) swap(huan[i], huan[rand() % 26]);

    // c[i] = base^i modulo 2^64 (unsigned overflow).
    c[0] = 1;

    for (int i = 1; i <= 1e5; ++ i ) c[i] = c[i - 1] * base;

    // inv[i] = powers of the constant H2 uses as base's inverse modulo 2^64.
    inv[0] = 1;

    for (int i = 1; i <= 1e5; ++ i ) inv[i] = inv[i - 1] * 7170437717599314525ull;

    cin >> n >> m >> q;

    // Stays true only if every edge is given as (u, u + 1), i.e. the tree is the chain 1 - 2 - ... - n.
    bool ok = true;

    for (int i = 1, u, v; i < n; ++ i ) {
        char c;
        cin >> u >> v >> c;
        add(u, v, c);
        add(v, u, c);
        // The chain branch reads K.s[1..n-1] as the text 1 -> n, so position u must hold the
        // letter of edge (u, u + 1) whatever order the edges arrive in; indexing by the input
        // counter i was only correct when the edges were listed in chain order.
        K.s[u] = c;
        ok &= v == u + 1;
    }

    // Chain special case: answer every query from per-pattern prefix counts of KMP matches.
    if (ok && m < 114) {
        for (int i = 1; i <= m; ++ i ) {
            cin >> s[i];
            K.build(s[i]);
            vector <int> id = K.query();
            for (int j : id) ++ s1[i][j];
            for (int j = 1; j < n; ++ j ) s1[i][j] += s1[i][j - 1];
        }
        // Reading right to left equals matching the reversed pattern left to right.
        for (int i = 1; i <= m; ++ i ) {
            reverse(s[i].begin(), s[i].end());
            K.build(s[i]);
            vector <int> id = K.query();
            for (int j : id) ++ s2[i][j];
            for (int j = 1; j < n; ++ j ) s2[i][j] += s2[i][j - 1];
        }
        while (q -- ) {
            int u, v, k;
            cin >> u >> v >> k;
            if (abs(u - v) < s[k].size()) {
                cout << 0 << '\n';
                continue;
            }
            // Count match starts l..r, where r is the last start whose window still fits.
            if (u < v) {
                int l = u, r = v - 1;
                r = r - s[k].size() + 1;
                cout << s1[k][r] - s1[k][l - 1] << '\n';
                continue;
            }
            int l = v, r = u - 1;
            r = r - s[k].size() + 1; 
            cout << s2[k][r] - s2[k][l - 1] << '\n';
        }
        return 0;
    }

    // General case. Root the tree at 1; dpt[0] = -1 keeps the sentinel parent 0 shallower
    // than every real vertex in the lifting code.
    dpt[0] = -1;
    dfs(1, 0);

    // Rebuild the adjacency lists with only the parent -> child edges recorded in M.
    memset(he, 0, sizeof(he));
    memset(Nxt, 0, sizeof(Nxt));
    memset(To1, 0, sizeof(To1));
    memset(To2, 0, sizeof(To2));
    ct = 0;

    for (auto i : M) add(i.first, i.second.first, i.second.second);

    // Read the patterns, hash them with the same letter weights and build the automaton.
    for (int i = 1; i <= m; ++ i ) {
        cin >> s[i];
        for (char c : s[i]) h[i] = (h[i] * base + huan[c - 'a']);
        T.insert(s[i], i);
    }

    T.build();

    for (int i = 1; i <= n; ++ i ) id[i] = i;

    // Number of key points. NOTE(review): head[] has 355 rows, so this relies on B < 355.
    B = sqrt(q) / 1.35;

    // Shuffle the vertices; the first B become key points and is[v] is v's key-point index.
    // NOTE(review): if RAND_MAX is only 32767, rand() % n cannot reach every vertex.
    for (int i = 1; i <= n; ++ i ) swap(id[i], id[rand() % n + 1]);

    for (int i = 1; i <= B; ++ i ) is[id[i]] = i;

    // cx is a post-order, so scanning it backwards visits parents before children:
    // t[u] = nearest key point on u's root path (u itself included), 0 if none.
    for (int i = n; i >= 1; -- i ) {
        int u = cx[i];
        if (is[u]) t[u] = u;
        else t[u] = t[Fa[u]];
        fa[u][0] = t[u]; 
    }

    // Hop tables over key points: fa[u][j] is 2^j - 1 key-point hops above t[u].
    for (int j = 1; j <= 7; ++ j ) 
        for (int i = 1; i <= n; ++ i )
            fa[i][j] = fa[Fa[fa[i][j - 1]]][j - 1];

    for (int i = 1, u, v, k; i <= q; ++ i ) {
        cin >> u >> v >> k;
        b[i] = {u, v, k};
        // Path shorter than the pattern: the answer stays 0.
        if (dis(u, v) < s[k].size()) continue; 
        const int K = lca(u, v);
        // No key point on the path, or a short path: hash every window of the path string.
        if ((dpt[t[u]] < dpt[K] && dpt[t[v]] < dpt[K]) || dis(u, v, K) <= 1580) {
            int idx = 0;
            Get_S(u, v, idx);
            const int len = s[k].size();
            // Cheap filter on the window's last letter before comparing hashes.
            const char C2 = huan[s[k][len - 1] - 'a'];
            for (int l = 0, r = len; r <= idx; ++ l, ++ r ) {
                ans[i] += S[r] == C2 && hs[r] - hs[l] * c[len] == h[k];
            }   
        } else {
            // d = the key point on the path that is met first when walking from u.
            int d = 0, idx = 0;
            if (dpt[t[u]] >= dpt[K]) d = t[u];
            else {
                // Only the v side has key points: hop up to the one closest to K.
                d = t[v];
                for (int j = 7; j >= 0; -- j )
                    if (dpt[fa[Fa[d]][j]] >= dpt[K])
                        d = fa[Fa[d]][j];

            }
            // Occurrences starting at d or later are counted by the DFS from d, answered at v.
            add(v, is[d], i);
            if (d == u) continue;
            // Remaining starts run from u to the vertex just before d, but never past `at`,
            // the last start whose window still ends by v.
            d = to(d, u, 1);
            const int at = to(v, u, s[k].size());
            if (dis(u, at) < dis(u, d)) {
                d = at;
            }
            // Short pattern with many starts: extract the string once and hash the windows.
            if (s[k].size() <= 800 && dis(u, d) > 80) {
                const int ct = dis(d, u) + 1;
                Get_S(u, to(d, v, s[k].size()), idx);
                const int len = s[k].size(), C = huan[s[k][0] - 'a'];
                for (int l = 1, r = len; l <= ct; ++ l, ++ r ) {
                    ans[i] += S[l] == C && hs[r] - hs[l - 1] * c[len] == h[k];
                }
                continue;
            }
            // Otherwise compare path hashes directly: U[j] is the j-th start, V[j] its window's end.
            Get2(u, d, U, idx); 
            Get2(to(u, v, s[k].size()), to(U[idx], v, s[k].size()), V, idx);
            // o: 0 = window lies on the upward part, 1 = window crosses K, 2 = window lies on
            // the downward part; it only moves forward as the windows slide toward v.
            int o = 0;
            if (lca(U[1], V[1]) == V[1]) o = 0;
            else if (lca(U[1], V[1]) == K) o = 1;
            else o = 2;
            const unsigned long long hk = h[k];
            for (int j = 1; j <= idx; ++ j ) {
                if (V[j] == K) o = 1;
                if (U[j] == K) o = 2;
                if (o == 0) {
                    ans[i] += H2(U[j], V[j]) == hk;
                } else if (o == 1) {
                    ans[i] += H(U[j], V[j], K) == hk;
                } else {
                    ans[i] += H1(U[j], V[j]) == hk;
                }
            }
        } 
    }
    // For each key point, walk from it up to the root. At every vertex on the way, answer the
    // queries parked there, then DFS into the other child subtrees. `at` is the AC state of the
    // letters read from the key point to the current vertex, and each state of that upward
    // string stays marked until the key point is finished.
    for (int i = 1; i <= B; ++ i ) {
        now = i;
        int lst = 0;
        int u = id[i], at = 1;
        while (true) {
            for (int j = head[now][u]; j; j = nxt[j] ) {
                const int id = To[j];
                const int k = b[id].k;
                ans[id] += T.query(ID[k]);
            }
            for (int j = he[u]; j; j = Nxt[j] ) {
                // Skip the child we just came up from; it was already handled.
                if (To1[j] == lst) continue;
                int ID = T.to[at][To2[j] - 'a'];
                T.modify(ID, 1);
                DFS(To1[j], ID);
                T.modify(ID, -1);
            }
            if (!Fa[u]) break;
            // Step up: read the letter of edge (u, Fa[u]) and keep its state marked.
            lst = u;
            at = T.to[at][fu[u] - 'a'];
            T.modify(at, 1); 
            u = Fa[u];
        }
        // Reset all counters before the next key point.
        memset(T.T.sum1, 0, sizeof(T.T.sum1));
        memset(T.T.sum2, 0, sizeof(T.T.sum2)); 
    }

    for (int i = 1; i <= q; ++ i ) cout << ans[i] << '\n'; 

    return 0;
}