P11477 [COCI 2024/2025 #3] 林卡树 / Stablo 题解

· · 题解

传送门

虽然重交了几次了,但真不是故意的,错误好难找,辛苦管理员了。

#

考虑将 f(x) 转化为更好求得的形式。

dep_pp 到根节点的距离,v_pp 的点权。

对于节点 x,令其子树内所有点为 a_1,a_2, \dots ,a_n

于是有:

f(x) = v_{a_1} \times (dep_{a_1} - dep_x) + v_{a_2} \times (dep_{a_2} - dep_x) + \ldots + v_{a_n} \times (dep_{a_n} - dep_x)

将括号拆开,得:

f(x) = v_{a_1} \times dep_{a_1} + v_{a_2} \times dep_{a_2} + \ldots + v_{a_n} \times dep_{a_n} - \sum_{i=1}^n v_{a_i} \times dep_x

在这个式子中,有非常多的不变量,我们只用预处理三个量。

  1. 深度 dep_p
  2. 减号前这一串式子,记为 tot_p

即可求 f(x)

现在回到查询,我们简要概括这个操作:

给出 (x,y),将 x 沿父亲不断上移成 y 的儿子,x 的所有儿子由 x 的原父亲继承,求 f(y)

我们尝试找到一些结论:

有了这两个结论,结合 x 上移对 x 本身贡献的影响,求 f'(y) 就很简单了。

f'(y) = f(y) + (sum_{y_{son}} + v_{y_{son}}) -( sum_x + v_x) + v_x - v_x \times (dep_x - dep_y)

整理得:

f'(y) = f(y) + sum_{y_{son}} + v_{y_{son}} - sum_x - v_x \times (dep_x - dep_y)

由数据范围想到要用 O(n \log n) 时间复杂度的算法。

题目的优化关键在求 y_{son},我们使用倍增法。

预处理 O(n),查询 O(q),每次查询 O(\log n) 找子树。

总时间复杂度为 O(n + q \log n)

以下为代码:

# include <stdio.h>
# include <stdlib.h>

struct Node
{
    int p;
    struct Node* next;
}; 

struct Head
{
    struct Node* next;
};

struct Head p[500005];
long long val[500005]; 
long long tot[500005];
long long dep[500005];
long long sum[500005];
long long f[500005];
int father[500005][22];
int cur;

struct Node* ini()
{
    struct Node* tmp = (struct Node*) malloc (sizeof(struct Node)); 
    tmp->next = NULL;
    return tmp;
}

void add(int u,int fa)
{
    struct Node* tmp = ini();
    tmp->p = u;
    tmp->next = p[fa].next;
    p[fa].next = tmp;
    return ;
}

void dfs(int u,int fa)
{
    dep[u] = dep[fa]+1;
    cur++;

    father[u][0] = fa;

    for (int i=1;i<21;i++)
    {
        father[u][i] = father[father[u][i-1]][i-1];
    }

    for (struct Node* tmp=p[u].next;tmp!=NULL;tmp=tmp->next)
    {
        int v = tmp->p;
        if (v == fa)
        {
            continue;
        }
        dfs(v,u);
        tot[u]+=tot[v]+val[v]*dep[v];
        sum[u]+=sum[v]+val[v]; 
    }
    f[u] = tot[u]-sum[u]*dep[u];

    return ;
}

int get_root(int x, int y)
{
    int k = 20;
    while (k >= 0)
    {
        if (dep[father[x][k]] >= dep[y]+1) 
            {
            x = father[x][k];
        }
        k--;
    }
    return x;
}

int main (void)
{
    int n,q;
    scanf ("%d %d",&n,&q);

    for (int i=1;i<=n;i++)  
    {
        scanf ("%d",&val[i]);
        p[i].next = NULL;
    }

    for (int i=2;i<=n;i++)
    {
        int fa;
        scanf ("%d",&fa);
        add(i,fa);
    }

    dfs(1,1); 

    for (int i=1;i<=q;i++)
    {
        int x,y;
        scanf ("%d %d",&x,&y);
        int son = get_root(x,y);
        printf ("%lld\n",f[y]+sum[son]+val[son]-sum[x]-val[x]*(dep[x]-dep[y])); //
    }

    return 0;
}