P6177 Count on a tree II/【模板】树分块 题解

· · 题解

这题调了六个来小时调自闭了…准确来说还是因为我树分块板子不够完善。

写完之后发现题解区没有我的做法,甚至预处理 $O(n\sqrt{n})$、单次询问 $O(\sqrt{n})$ 的做法都少得可怜,于是有了这篇题解。

思路是 Top_Cluster + 树上前缀和:离散化后,预处理出每个界点到其他每个点的路径答案,以及每个界点到根的路径上各颜色的出现次数(即树上前缀计数),询问时在散块内暴力。

其实 critnos 大佬的题解中有和我类似的思路,但是他用的是不够优美的树上撒点分块,而且目测是口胡,这题实现的时候细节特别多。

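记 $f_u, f_v$ 分别为 $u, v$ 向上(含自身)遇到的第一个界点,不妨设 $\text{dep}_{f_u} \ge \text{dep}_{f_v}$,我们分三种情况讨论:

1. $f_u = f_v$:直接从 $u, v$ 暴力往上跳父亲直到相遇,边跳边统计颜色。
2. $f_v = \operatorname{lca}(f_u, f_v)$,即 $f_v$ 是 $f_u$ 的祖先:以预处理出的 $f_u$ 到 $v$ 的答案为基础,再暴力扫描 $u$ 到 $f_u$ 这段散块,借助前缀计数判断每种颜色是否已在路径的其余部分出现过。
3. 其余情况:设 $w = \operatorname{lca}(f_u, f_v)$,以 $f_u$ 到 $f_v$ 的答案为基础,分别暴力扫描 $u \to f_u$ 与 $v \to f_v$ 两段散块;颜色 $c$ 在 $f_u \to f_v$ 路径上(不含 $w$)的出现次数为 $\text{sum}_{f_u}(c) + \text{sum}_{f_v}(c) - 2\,\text{sum}_{w}(c)$,$w$ 的颜色单独处理。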

然后用 Top_Cluster 求 LCA 可能很多人没有写过,下面放一下我的代码,可以自己理解一下(这里就体现出 Top_Cluster 的优美之处了:两个簇之间只会有一个交点,所以不用另外再写 LCA)。

// LCA via cluster jumps: repeatedly lift the deeper of the two nodes until they
// coincide. A boundary node (id[u] != 0) jumps straight to its cluster's top
// CLfa[u], unless v lies in the cluster whose bottom node is u; every other
// node just steps to its parent.
int get_lca(int u, int v) {
    while (u != v) {
        if (dep[u] < dep[v]) swap(u, v);
        const bool can_jump = id[u] && dwbn[v] != u;
        u = can_jump ? CLfa[u] : fa[u];
    }
    return u;
}

最后是总代码,写了将近 300 行,没有卡常效率还不错,不过 cache miss 比较严重。调整下循环顺序和代码逻辑有希望跑进最优解(不过我调疯了不想卡了…)。

#include<cstdio>
#include<cctype>
#include<cmath>
#include<ctime>
#include<vector>
#include<algorithm>
#include<initializer_list>
using namespace std;

namespace Main {

// Shorthands used throughout the solution:
//   vi            -> vector<int>
//   ld            -> long double
//   rep(i, l, r)  -> for i = l, l+1, ..., r (inclusive)
//   per(i, r, l)  -> for i = r, r-1, ..., l (inclusive)
// The loop bound is evaluated once and stored in END##i.
#define vi vector<int> 
#define ld long double
#define rep(i, l, r) for(int i(l), END##i(r); i <= END##i; ++ i) 
#define per(i, r, l) for(int i(r), END##i(l); i >= END##i; -- i) 

/// Lowers x to y if y is smaller (x becomes min(x, y)).
template <class T>
inline void cmin(T& x, const T& y) {
    if (y < x) x = y;
}
/// Raises x to y if y is larger (x becomes max(x, y)).
template <class T>
inline void cmax(T& x, const T& y) {
    if (x < y) x = y;
}

namespace Fast_OI {
    // Buffered I/O. Input is pulled from stdin in kBufSize-byte chunks into buf
    // (the unread window is [p1, p2)). Output is accumulated in obuf up to p3 and
    // written to stdout when the buffer fills or when flush() is called.
    const int kBufSize = 1000000;
    char buf[kBufSize], *p1 = buf, *p2 = buf, obuf[kBufSize], *p3 = obuf;
    // Next input byte as a value in 0..255, or EOF once stdin is exhausted.
    // Converting through unsigned char keeps bytes >= 0x80 distinct from EOF
    // and makes the value a valid argument for isdigit().
    #define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf,1,kBufSize,stdin),p1==p2)?EOF:static_cast<unsigned char>(*p1++))
    // Appends one byte to obuf, writing the full buffer out first if needed.
    #define putchar(x) (p3-obuf<kBufSize?*p3++=x:(fwrite(obuf,1,p3-obuf,stdout),p3=obuf,*p3++=x))
    /// Reads the next decimal integer. Non-digit characters before it are
    /// skipped, and a '-' among them makes the result negative. The character
    /// right after the number is consumed. Returns 0 if the input ends before
    /// any digit, so truncated input cannot cause an endless loop.
    int read() {
        int x = 0;
        bool positive = true;
        int c = getchar();
        while (c != EOF && !isdigit(c)) {
            if (c == '-') positive = false;
            c = getchar();
        }
        while (c != EOF && isdigit(c)) {
            x = (x << 3) + (x << 1) + (c ^ 48);
            c = getchar();
        }
        return positive ? x : -x;
    }
    /// Writes a NUL-terminated string to the output buffer, followed by '\n'
    /// unless nw_line is false.
    void puts(const char* str, bool nw_line = 1) {
        while (*str != '\0')
            putchar(*str), ++ str;
        if (nw_line) putchar('\n');
    }
    /// Writes the decimal digits of x (no sign) to the output buffer.
    void write_unsigned(unsigned int x) {
        if (x > 9) write_unsigned(x / 10);
        putchar(static_cast<char>('0' + x % 10));
    }
    /// Writes x in decimal to the output buffer. The magnitude is computed in
    /// unsigned arithmetic, so INT_MIN is printed correctly instead of
    /// overflowing on negation.
    void write(int x) {
        unsigned int magnitude = static_cast<unsigned int>(x);
        if (x < 0) {
            putchar('-');
            magnitude = 0u - magnitude;  // well-defined modular negation
        }
        write_unsigned(magnitude);
    }
    /// Writes all buffered output to stdout and empties the buffer, so a later
    /// flush() does not repeat output that was already written.
    void flush() {
        fwrite(obuf, 1, p3 - obuf, stdout);
        p3 = obuf;
    }
} using namespace Fast_OI;

// N    : capacity of per-node arrays (presumably n <= 4 * 10^4 plus slack).
// B    : cluster-size threshold used by Top_Cluster::dfs.
// Bnum : number of rows in Top_Cluster's per-boundary-node tables (sum, res);
//        must exceed the number of BN entries produced by Build().
// NOTE(review): the factor 6 looks like a safety margin (BN can contain
// duplicate entries) - confirm the bound before changing B or N.
const int N = 4e4 + 10, B = 260; 
const int Bnum = 6 * N / B + 5; 

// n nodes, Q queries; val[i] is the colour of node i (replaced by its rank
// in 1..n by Dscrt()).
int n, Q, val[N]; 
// Adjacency lists stored as linked lists in e[]: head[u] is the index of u's
// first edge, e[i].nex the next edge of the same node (0 terminates), and
// idx counts the stored directed edges (e[0] is unused).
struct edge {
    int to, nex; 
}e[N << 1]; int idx, head[N]; 
inline void add(int u, int v) {
    e[++ idx].to = v; 
    e[idx].nex = head[u]; 
    head[u] = idx; 
} 
// Stores the undirected edge {u, v} as the directed edges u -> v and then v -> u.
inline void add_edge(int u, int v) {
    const int ends[2][2] = {{u, v}, {v, u}};
    for (const auto& dir : ends) add(dir[0], dir[1]);
}

void Dscrt() {
    vi tmp; 
    rep(i, 1, n) tmp.emplace_back(val[i]); 
    sort(tmp.begin(), tmp.end()); 
    tmp.erase(unique(tmp.begin(), tmp.end()), tmp.end()); 
    rep(i, 1, n) 
        val[i] = lower_bound(tmp.begin(), tmp.end(), val[i]) - tmp.begin() + 1; 
}

int fa[N], dep[N]; 
// Roots the tree at u (called with u = 1 and fa[1] == 0). Sets
// dep[u] = dep[fa[u]] + 1, and for every node below u sets fa[] to its parent
// and dep[] to the parent's depth + 1. Uses an explicit stack instead of
// recursion; the visiting order does not affect the resulting fa[]/dep[].
void dfs_prew(int u) {
    dep[u] = dep[fa[u]] + 1;
    std::vector<int> todo{u};
    while (!todo.empty()) {
        const int cur = todo.back();
        todo.pop_back();
        for (int i = head[cur]; i; i = e[i].nex) {
            const int nxt = e[i].to;
            if (nxt == fa[cur]) continue;
            fa[nxt] = cur;
            dep[nxt] = dep[cur] + 1;
            todo.push_back(nxt);
        }
    }
}

// Top_Cluster: tree partition used to answer "number of distinct colours on
// the path u-v" queries.
//
// Build() cuts the tree into clusters (see dfs), collects boundary nodes in BN
// (1-based after Build; id[x] is x's index in BN, 0 for non-boundary nodes),
// and then precomputes, for every BN index i:
//   sum[i][c] : occurrences of colour c on the path from BN[i] up to the root;
//   res[i][x] : number of distinct colours on the path from BN[i] to node x.
// Query() combines these tables with brute-force walks over the parts of the
// path that lie below the nearest boundary nodes.
// NOTE(review): sum and res are Bnum x N ints each; with the constants in this
// file that is roughly 150 MB per table.
class Top_Cluster {
    public:
    // CLfa[x] : for a cluster bottom x (x != top), the top node of that cluster.
    // nrd[x]  : written by New_CL; not read anywhere else in this file.
    int CLfa[N], nrd[N]; 
    // upbn[x] / dwbn[x] : top / bottom node of the cluster that contains x.
    int upbn[N], dwbn[N]; 
    // BNct : number of entries in BN (duplicates included); id[x] : index of x in BN.
    int BNct, id[N]; 
    // BN      : boundary-node list (BN[0] is a 0 sentinel after Build);
    // clpt[x] : members of the cluster whose bottom node is x.
    vi BN, clpt[N]; 
    private : 
    // Nodes collected for the cluster currently being formed.
    vi cur_cl; 
    // Turns cur_cl into one cluster with top `up` and bottom `dw` (dw == 0 means
    // "use the last collected node as the bottom"), appends the bottom to BN,
    // and tags every member with upbn/dwbn.
    // NOTE(review): id[] is only filled at the end of Build(), so while dfs()
    // runs the `!id[dw]` test is always true and a node can be appended to BN
    // more than once; Build() keeps the last index in id[].
    void New_CL(int up, int dw) {
        if (cur_cl.empty()) return; 
        if (!dw) dw = cur_cl.back(); 
        if (!id[dw]) ++ BNct, BN.emplace_back(dw); 
        // Only the CLfa assignment is guarded; nrd[up] = up always runs.
        if (up ^ dw) CLfa[dw] = up; nrd[up] = up; 
        // Mark the path from dw up to (but excluding) up.
        for (int u = dw; u ^ up; u = fa[u]) 
            nrd[u] = u; 
        for (const int& u : cur_cl) {
            upbn[u] = up, dwbn[u] = dw; 
            clpt[dw].emplace_back(u); 
            // Copy nrd from the first node at or above u whose nrd is already set.
            int y = u; while (!nrd[y]) y = fa[y]; 
            nrd[u] = nrd[y]; 
        } cur_cl.clear(); 
    } 
    // Scratch state for dfs():
    //   stc[1..stctop] : visited nodes not yet handed to a cluster; a child is
    //                    pushed right before it is visited, so stc[rec_top[v]] == v;
    //   rec_top[u]     : stack height when u was entered;
    //   udfct[u]       : pending-node count of u's subtree used for the size-B
    //                    test (reset to 0 once u becomes a boundary node);
    //   lwBN[u]        : a boundary node inside u's subtree, 0 if none.
    int stc[N], stctop; 
    int rec_top[N]; 
    int udfct[N], lwBN[N]; 
    // Post-order DFS that cuts the tree into clusters. u becomes a boundary node
    // when its pending count exceeds B, when two or more child subtrees contain
    // boundary nodes, or when u is the root (!fa[u]).
    void dfs(int u) {
        rec_top[u] = stctop; 
        udfct[u] = 1; 
        // Number of children whose subtree contains a boundary node.
        int BN_cnt = 0; 
        for (int i = head[u]; i; i = e[i].nex) {
            int v = e[i].to; 
            if (v == fa[u]) continue; 
            stc[++ stctop] = v; 
            dfs(v); udfct[u] += udfct[v]; 
            if (lwBN[v]) lwBN[u] = lwBN[v], ++ BN_cnt; 
        } 
        if (udfct[u] > B || BN_cnt > 1 || !fa[u]) {
            udfct[u] = 0; lwBN[u] = u; 
            // p        : next stack slot to hand out;
            // cnt      : pending nodes in the currently open group of children;
            // cur_down : boundary node inside the open group (0 if none).
            int p = rec_top[u] + 1, cnt = 0, cur_down = 0; 
            // Closes the open group: moves the stacked nodes that precede v's
            // slot (all remaining ones when v == 0) into a cluster with top u.
            auto reset = [&](const int& v) -> void {
                while (p <= stctop && (!v || p < rec_top[v])) 
                    cur_cl.emplace_back(stc[p ++]); 
                New_CL(u, cur_down); cnt = cur_down = 0; 
            }; for (int i = head[u]; i; i = e[i].nex) {
                int v = e[i].to; 
                if (v == fa[u]) continue; 
                // Start a new group before v if v would push the group past B,
                // or if both the open group and v's subtree hold a boundary node.
                if (cnt + udfct[v] > B || (cur_down && lwBN[v])) 
                    reset(v); 
                cnt += udfct[v]; if (lwBN[v]) cur_down = lwBN[v]; 
            } reset(0); ++ BNct; BN.emplace_back(u); 
            // Pop u's subtree. For a non-root u, u itself stays in slot rec_top[u]
            // and is later collected into a cluster formed at an ancestor.
            stctop = rec_top[u]; 
        }
    }
    private : 
    // Per-boundary-node tables, indexed by BN index (see the class comment).
    int sum[Bnum][N]; 
    int res[Bnum][N]; 
    // Scratch for dfs_calc: tmpsum[c] = occurrences of colour c on the current
    // DFS path, tmpct = number of colours with tmpsum[c] > 0.
    int tmpct, tmpsum[N]; 
    // DFS over the whole tree from the node passed in, recording in res[idx][u]
    // the number of distinct colours on the path from the start node to u.
    // (The parameter `idx` shadows the global edge counter of the same name.)
    void dfs_calc(int u, int idx, int fath) {
        if (!tmpsum[val[u]] ++) ++ tmpct; 
        res[idx][u] = tmpct; 
        for (int i = head[u]; i; i = e[i].nex) {
            int v = e[i].to; 
            if (v == fath) continue; 
            dfs_calc(v, idx, u); 
        } if (!-- tmpsum[val[u]]) -- tmpct; 
    }
    // Fills sum[i] and res[i] for every BN index i. sum[i] is built by counting
    // colours from BN[i] up to (excluding) the nearest boundary ancestor u, then
    // adding sum[id[u]] (sum[0] is all zero, covering the case u == 0).
    // NOTE(review): BN is filled in post-order by dfs and reversed in Build, so a
    // boundary ancestor's entry presumably comes first and its sum row is
    // already complete when it is added here.
    void Prew() {
        rep(i, 1, BNct) {
            int u = BN[i]; 
            sum[i][val[u]] = 1; 
            for (u = fa[u]; u && !id[u]; u = fa[u]) 
                ++ sum[i][val[u]]; 
            rep(j, 1, n) 
                sum[i][j] += sum[id[u]][j]; 
            // printf("#------------ prew : %d, %d\n", BN[i], CLfa[i]); 
            // rep(j, 1, n) printf ("%d ", sum[i][j]); 
            // puts("");     
            dfs_calc(BN[i], i, 0); 
        } 
    } 
    // Nearest boundary node at or above u (requires id[] to be filled).
    int find_nrBN(int u) {
        while (!id[u]) u = fa[u]; 
        return u; 
    }
    // LCA by repeatedly lifting the deeper node. A boundary node jumps to its
    // cluster top CLfa[u] unless v lies in the cluster whose bottom node is u;
    // any other node steps to its parent.
    int get_lca(int u, int v) {
        while (u ^ v) {
            if (dep[u] < dep[v]) swap(u, v); 
            if (id[u] && dwbn[v] != u)  u = CLfa[u]; 
            else u = fa[u]; 
        } return u; 
    } 
    public : 
    // Builds the partition rooted at rt and runs the precomputation.
    // The root is never pushed on stc, so it is placed into its own cluster here.
    // NOTE(review): New_CL(1, 1) uses the literal 1 rather than rt, so this is only
    // correct for rt == 1 (the default, and the only call in this file).
    void Build(int rt = 1) {
        dfs(rt); 
        cur_cl.emplace_back(rt); 
        New_CL(1, 1); 
        BN.emplace_back(0); 
        reverse(BN.begin(), BN.end()); 
        rep(i, 1, BNct) id[BN[i]] = i; 
        // puts("----------"); 
        // rep(i, 1, n) printf ("%d : %d %d\n", i, upbn[i], dwbn[i]); 
        // puts("----------"); 
        Prew(); 
    } 
    // Returns the number of distinct colours on the path u-v.
    // apl[c] marks colours already counted; every case clears its marks before
    // returning, so the static array is all-false between calls.
    int Query(int u, int v) {
        // fprintf (stderr, "query %d %d\n", u, v); 
        static bool apl[N]; int ans = 0; 
        int fu = find_nrBN(u), fv = find_nrBN(v); 
        // Orient so that fu is at least as deep as fv.
        if (dep[fu] < dep[fv]) swap(u, v), swap(fu, fv);  
        // Case 1: same nearest boundary node. Walk the deeper endpoint up until
        // both meet, counting each new colour. The meeting node is counted
        // without being marked, so the clean-up walk does not need to touch it.
        // (This `ans` shadows the outer one.)
        if (fu == fv) {
            // fprintf(stderr, "#1\n"); 
            int ans = 0; 
            int recu = u, recv = v; 
            while (u ^ v) {
                if (dep[u] < dep[v]) 
                    swap(u, v); 
                if (!apl[val[u]]) apl[val[u]] = 1, ++ ans; 
                u = fa[u]; 
            } if (!apl[val[u]]) ++ ans; 
            u = recu, v = recv; 
            while (u ^ v) {
                if (dep[u] < dep[v]) 
                    swap(u, v); 
                apl[val[u]] = 0; 
                u = fa[u]; 
            }
            return ans; 
        } 

        int fw = get_lca(fu, fv); 

        // Case 2: fv == lca(fu, fv), i.e. fv is an ancestor of fu.
        // Start from res[fu][v] (distinct colours on fu -> v). Mark the colours on
        // v -> lca, dwv -> lca and lca itself, then scan u -> fu (fu excluded).
        // A colour counts as new when it is unmarked and the prefix-sum difference
        // sum[fu] - sum[dwv] is zero for it.
        // NOTE(review): dwv is v's cluster bottom, or v's cluster top when
        // lca(u, v) is that top.
        if (fv == fw) {
            // fprintf(stderr, "#2\n"); 
            int dwv = dwbn[v];
            int lca = get_lca(u, v); 
            if (lca == upbn[v]) dwv = upbn[v];  
            int ans = res[id[fu]][v]; 
            // fprintf (stderr, "# %d %d & (%d , %d) : %d %d\n", u, v, fu, fv, dwv, lca); 
            int t = v; 
            while (t ^ lca) apl[val[t]] = 1, t = fa[t]; 
            t = dwv; 
            while (t ^ lca) apl[val[t]] = 1, t = fa[t]; 
            apl[val[lca]] = 1; 
            t = u; 
            while (t ^ fu) {
                if (!apl[val[t]] && sum[id[fu]][val[t]] - sum[id[dwv]][val[t]] == 0) 
                    ++ ans, apl[val[t]] = 1; 
                t = fa[t]; 
            } 

            // Clear every mark set above.
            t = v; 
            while (t ^ lca) apl[val[t]] = 0, t = fa[t]; 
            t = dwv; 
            while (t ^ lca) apl[val[t]] = 0, t = fa[t]; 
            apl[val[lca]] = 0; 
            t = u; 
            while (t ^ fu) apl[val[t]] = 0, t = fa[t]; 

            return ans; 
        }

        // Case 3: fu and fv hang below w = fw = lca(fu, fv), and neither is w.
        // Start from res[fu][fv], then scan u -> fu and v -> fv (fu/fv excluded).
        // Colour c is new when it is unmarked and does not occur on fu -> fv
        // apart from w; that count is sum[fu] + sum[fv] - 2 * sum[w]. w's own
        // colour is pre-marked because w lies on that path.
        // fprintf(stderr, "#3\n"); 
        apl[val[fw]] = 1; 
        ans = res[id[fu]][fv]; 

        int recu = u, recv = v; 
        rep(i, 0, 1) {
            // fprintf (stderr, "%d & %d\n", u, fu); 
            while (u != fu) {
                if (!apl[val[u]] && !(sum[id[fu]][val[u]] + sum[id[fv]][val[u]] - 2 * sum[id[fw]][val[u]])) 
                    apl[val[u]] = 1, ++ ans; // fprintf (stderr, "(%d, %d)\n", u, val[u]); 
                u = fa[u]; 
            } swap(u, v); swap(fu, fv); 
        } 

        // Clear the marks on both scanned segments and on w.
        u = recu, v = recv; 
        rep(i, 0, 1) {
            while (u != fu) apl[val[u]] = 0, u = fa[u]; 
            swap(u, v); swap(fu, fv); 
        } apl[val[fw]] = 0; 

        return ans; 
    }
    // tcl: the single global instance, used by ERoRain().
}tcl; 

void ERoRain() {
    n = read(), Q = read(); 
    rep(i, 1, n) val[i] = read(); 
    Dscrt(); //rep(i, 1, n) printf ("%d ", val[i]); puts("val"); 
    rep(i, 1, n - 1) {
        int u = read(), v = read(); 
        add_edge(u, v); 
    } dfs_prew(1); 
    tcl.Build(); 
    int lstans = 0; 
    while (Q --) {
        int u = read(), v = read(); 
        u ^= lstans; 
        write(lstans = tcl.Query(u, v)), puts("");  
        // fprintf (stderr, "getans : %d\n", lstans); 
    }
}

signed main() {
    ld start_time = clock(); 
    int T = 1; 
    while (T --) ERoRain(); 
    flush(); 
    fprintf(stderr, "Time : %Lfs\n", (clock() - start_time) / CLOCKS_PER_SEC); 
    return 0; 
} 

} // namespace Main

// Program entry point; forwards to Main::main().
// The a.in / a.out redirection is a local-debugging aid. It is compiled only
// when LOCAL is defined (e.g. -DLOCAL), so a judge build reads the real
// stdin/stdout instead of trying to open a.in, which does not exist there.
signed main() {
#ifdef LOCAL
    freopen("a.in", "r", stdin);
    freopen("a.out", "w", stdout);
#endif
    Main::main();
    return 0;
}