Nemlit 的博客

Nemlit 的博客

By a konjac

题解 CF809E 【Surprise me!】

posted on 2019-09-30 23:46:38 | under 题解 |

我们要求的式子长这样:

$$\frac{1}{n * (n - 1)} * \sum_{i = 1}^n\sum_{j = 1}^{n}\phi(a_i*a_j)*dis(i, j)$$

其中 $a_1, a_2, \dots, a_n$ 是 $1 \sim n$ 的一个排列,$dis(i, j)$ 表示树上点 $i$ 与点 $j$ 之间的距离。

这种题的套路一般是先拆柿子,但是这道题的式子……

我们要从一个性质下手: $$\phi(a * b) = \frac{\phi(a) * \phi(b) * gcd(a, b)}{\phi(gcd(a, b))}$$

代入原式得:

$$\frac{1}{n * (n - 1)} * \sum_{i = 1}^n\sum_{j = 1}^{n}\frac{\phi(a_i) * \phi(a_j) * gcd(a_i, a_j)}{\phi(gcd(a_i, a_j))}*dis(i, j)$$

先忽略前面的系数 $\frac{1}{n(n-1)}$,只看后面的 $\sum$,枚举 $\gcd(a_i, a_j)=k$,得到

$$\sum_{k = 1}^n\frac{k}{\phi(k)}\sum_{i = 1}^n\sum_{j = 1}^{n}\phi(a_i) * \phi(a_j)*dis(i, j)*[gcd(a_i, a_j) == k]$$

然后利用莫比乌斯反演 $[\gcd(a_i, a_j) = k] = \sum_{x:\ xk \mid a_i,\ xk \mid a_j}\mu(x)$,得到:

$$\sum_{k=1}^{n}\frac{k}{\phi(k)}\sum_{i=1}^{n}\sum_{j=1}^{n}\phi(a_i)\,\phi(a_j)\,dis(i,j)\sum_{x:\ xk \mid a_i,\ xk \mid a_j}\mu(x)$$

令 $T = kx$,改为枚举 $T$:

$$\sum_{T=1}^{n}\sum_{k\mid T}\frac{k}{\phi(k)}\sum_{i=1}^{n}\sum_{j=1}^{n}\phi(a_i)\,\phi(a_j)\,dis(i,j)\,[T\mid a_i]\,[T\mid a_j]\,\mu\left(\frac{T}{k}\right)$$

交换求和顺序得: $$\sum_{T=1}^{n}\sum_{k\mid T}\frac{k}{\phi(k)}\,\mu\left(\frac{T}{k}\right)\sum_{T\mid a_i}\sum_{T\mid a_j}\phi(a_i)\,\phi(a_j)\,dis(i,j)$$

我们考虑枚举 $T$。后面的式子可以单独拎出来:把所有满足 $T \mid a_i$ 的点取出来建虚树,用树形 DP 求出后面式子的值。前面的系数 $\sum_{k\mid T}\frac{k}{\phi(k)}\mu\left(\frac{T}{k}\right)$ 只与 $T$ 有关,可以预处理出来。

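对于固定的 $T$,因为 $a$ 是排列,满足 $T \mid a_i$ 的关键点恰有 $\lfloor n/T \rfloor$ 个。虚树的点数不超过关键点数的两倍(再加上根)。而 $\sum_{T=1}^{n}\lfloor n/T \rfloor = O(n\log n)$(调和级数),所以所有虚树的总点数是 $O(n\log n)$,复杂度正确。但由于虚树上的 DP 和普通 DP 有一定差异,所以我们还需要对后面的式子继续化简: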

$$\sum_{T\mid a_i}\sum_{T\mid a_j}\phi(a_i)\,\phi(a_j)\,dis(i,j)$$

拆开 $dis(i, j)$ 得:

$$\sum_{T\mid a_i}\sum_{T\mid a_j}\phi(a_i)\,\phi(a_j)\,\bigl(dep[i] + dep[j] - 2\,dep[lca(i, j)]\bigr)$$

令 $val[i] = \phi(a_i)$,把所有满足 $T \mid a_i$ 的点拎出来,假设有 $x$ 个(重新编号为 $1 \sim x$):

$$\sum_{i= 1}^{x}\sum_{j = 1}^xval[i] * val[j]*(dep[i] + dep[j] - 2 * dep[lca(i, j)])$$

$$=\sum_{i=1}^{x}\sum_{j=1}^{x}val[i]\,val[j]\,dep[i] + \sum_{i=1}^{x}\sum_{j=1}^{x}val[i]\,val[j]\,dep[j] - 2\sum_{i=1}^{x}\sum_{j=1}^{x}val[i]\,val[j]\,dep[lca(i, j)]$$
$$=2\sum_{i=1}^{x}val[i]\,dep[i]\sum_{j=1}^{x}val[j] - 2\sum_{i=1}^{x}\sum_{j=1}^{x}val[i]\,val[j]\,dep[lca(i, j)]$$

前一项只需在虚树上 DFS 一遍,求出 $\sum val[i]\,dep[i]$ 与 $\sum val[j]$ 即可。后一项只需要在虚树上枚举 $lca$,对每个点 $u$ 求出 $\sum_{i=1}^{x}\sum_{j=1}^{x}val[i]\,val[j]\,[lca(i, j) = u]$,再乘上 $dep[u]$ 累加。

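这个值其实不难求。记 $s(u)$ 为 $u$ 子树内所有关键点的 $val$ 之和,依次合并 $u$ 的每个儿子 $v$:用 $s(v)$ 乘上已合并部分($u$ 自身加上之前的儿子)的 $val$ 和,就能不重不漏地统计所有 $i \ne j$ 且 $lca(i, j) = u$ 的无序点对。乘 2 得到有序点对,再加上 $i = j$ 的项 $val[i]^2 \cdot dep[i]$ 即可。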

$Code:$

#include<bits/stdc++.h>
using namespace std;
#define il inline
// `re` used to expand to `register`, a storage-class specifier that C++17
// removed (P0001R1). Keeping the macro as an empty token lets every existing
// `re int x` declaration in this file compile unchanged under C++17.
#define re
// Prime modulus for all arithmetic.
#define mod 1000000007
// Reads one decimal integer from stdin.
// Every non-digit character before the number is skipped; if any of them is
// '-', the result is negated. Returns 0 if stdin hits EOF before a digit is seen
// (the previous version looped forever there, because it stored getchar()'s
// int result in a char and never compared it with EOF).
inline int read() {
    int x = 0, sign = 1;
    int c = std::getchar();  // int, not char: must be able to represent EOF
    while (c != EOF && (c < '0' || c > '9')) {
        if (c == '-') sign = -1;
        c = std::getchar();
    }
    while (c >= '0' && c <= '9') {
        x = x * 10 + (c - '0');
        c = std::getchar();
    }
    return x * sign;
}
// Inclusive counting loop: i runs from s to t. Bounds are parenthesised so that
// compound expressions (e.g. a ternary) keep their intended precedence.
#define rep(i, s, t) for (int i = (s); i <= (t); ++i)
// Iterates arc indices i out of vertex u in the head/e adjacency lists.
#define Next(i, u) for (int i = head[(u)]; i; i = e[i].next)
// Fills a whole array k with byte value p.
#define mem(k, p) memset((k), (p), sizeof(k))
// Capacity of the per-vertex global arrays.
#define maxn 400005
int n, m, Go[maxn], head[maxn], cnt, rev[maxn];
struct edge { int v, next; }e[maxn << 1];
il void add(int u, int v) {
    e[++ cnt] = (edge){v, head[u]}, head[u] = cnt;
    e[++ cnt] = (edge){u, head[v]}, head[v] = cnt;
}
il int mul(int a, int b) { return 1ll * a * b % mod; }
il int qpow(int a, int b) {
    int r = 1;
    while(b) {
        if(b & 1) r = mul(a, r);
        a = mul(a, a), b >>= 1;
    }
    return r;
}

int prim[maxn], tot, Vis[maxn], phi[maxn], mu[maxn], F[maxn], ans, G[maxn];
il void init(int n) {
    mu[1] = phi[1] = 1;
    rep(i, 2, n) {
        if(!Vis[i]) prim[++ cnt] = i, mu[i] = -1, phi[i] = i - 1;
        rep(j, 1, cnt) {
            if(i * prim[j] > n) break;
            Vis[i * prim[j]] = 1;
            if(i % prim[j] == 0) {
                phi[i * prim[j]] = phi[i] * prim[j];
                break;
            }
            mu[i * prim[j]] = -mu[i], phi[i * prim[j]] = phi[i] * phi[prim[j]];
        }
    }
    rep(i, 1, n) 
        for(re int j = i; j <= n; j += i) 
            F[j] = (F[j] + mul(mul(i, qpow(phi[i], mod - 2)), mu[j / i])) % mod, F[j] = (F[j] + mod) % mod;
}

int fa[maxn], dep[maxn], Top[maxn], dfn[maxn], col, son[maxn], size[maxn];
il void dfs1(int u, int fr) {
    size[u] = 1, fa[u] = fr, dep[u] = dep[fr] + 1;
    Next(i, u) {
        int v = e[i].v;
        if(v == fr) continue;
        dfs1(v, u), size[u] += size[v];
        if(size[v] > size[son[u]]) son[u] = v;
    }
}
il void dfs2(int u, int fr) {
    dfn[u] = ++ col, Top[u] = fr;
    if(son[u]) dfs2(son[u], fr);
    Next(i, u) if(e[i].v != fa[u] && e[i].v != son[u]) dfs2(e[i].v, e[i].v);
}
il int LCA(int u, int v) {
    while(Top[u] != Top[v]) dep[Top[u]] > dep[Top[v]] ? u = fa[Top[u]] : v = fa[Top[v]];
    return dep[u] > dep[v] ? v : u;
}

int st[maxn], top, a[maxn], tmp, pax, vis[maxn], f[maxn], val[maxn], g[maxn];
il bool cmp(int a, int b) { return dfn[a] < dfn[b]; }
// Adds key vertex x to the virtual tree under construction. build() calls this
// with vertices in increasing dfn order after putting the root on the stack
// (st[1] == 1). The stack st[1..top] always holds a root-to-leaf chain; when a
// vertex is popped off the chain it is linked to its virtual-tree parent via
// add(), reusing the e/head adjacency arrays.
il void insert(int x) {
    // Only the root is on the stack: x just extends the chain.
    if(top == 1 && x != 1) return (void)(st[++ top] = x);
    int lca = LCA(st[top], x);
    // x is an ancestor of the stack top; with dfn-sorted distinct input this
    // presumably only happens for x == 1 (the pre-pushed root).
    if(x == lca) return;
    // Pop chain vertices while the one below the top is still strictly deeper
    // than lca, linking each popped vertex to the vertex beneath it.
    while(top > 1 && dep[st[top - 1]] > dep[lca]) {
        add(st[top], st[top - 1]), -- top; 
    }
    // The remaining deeper vertex hangs directly below lca.
    if(dep[st[top]] > dep[lca]) add(lca, st[top]), -- top;
    // The new top is now an ancestor of lca; push lca if it is not that vertex.
    if(dep[st[top]] < dep[lca]) st[++ top] = lca;
    st[++ top] = x;
}
il void build(int n) {
    sort(a + 1, a + n + 1, cmp), st[top = 1] = 1;
    rep(i, 1, n) insert(a[i]);
    while(top > 1) add(st[top - 1], st[top]), -- top;
}
il void get_dis(int u, int fr) {
    if(vis[u]) f[u] = mul(phi[Go[u]], dep[u]), val[u] = phi[Go[u]];
    int sum = val[u];
    Next(i, u) {
        int v = e[i].v;
        if(v == fr) continue;
        get_dis(v, u);
        g[u] = (g[u] + mul(val[v], sum)) % mod;
        sum = (sum + val[v]) % mod;
        f[u] = (f[u] + f[v]) % mod, val[u] = (val[u] + val[v]) % mod;
    }
    g[u] = mul(g[u], dep[u]);
}
il void dfs_mem(int u, int fr) {
    Next(i, u) if(e[i].v != fr) dfs_mem(e[i].v, u);
    tmp = (tmp + g[u]) % mod, head[u] = vis[u] = f[u] = val[u] = g[u] = 0;
}
il void solve() {
    rep(T, 1, n / 2) {
        pax = tmp = cnt = 0;
        for(re int i = T; i <= n; i += T) a[++ pax] = rev[i], vis[rev[i]] = 1;
        build(pax), get_dis(1, 0);
        G[T] = 2ll * mul(f[1], val[1]) % mod;
        dfs_mem(1, 0), tmp = mul(2, tmp);
        rep(i, 1, pax) tmp = (tmp + mul(dep[a[i]], mul(phi[Go[a[i]]], phi[Go[a[i]]]))) % mod;
        G[T] = (G[T] - 2ll * tmp % mod + mod) % mod;
    }
}
int main() {
    n = read(), init(n);
    rep(i, 1, n) Go[i] = read(), rev[Go[i]] = i;
    rep(i, 1, n - 1) add(read(), read());
    dfs1(1, 0), dfs2(1, 1), mem(head, 0), solve();
    rep(i, 1, n) ans = (ans + mul(G[i], F[i])) % mod;
    printf("%d", mul((ans + mod) % mod, qpow(mul(n, n - 1), mod - 2)));
    return 0;
}