题解 P3591 【[POI2015]ODW】

· · 题解

一道套路题——神cy

好像确实挺套路的

考虑我们让所有k>\sqrt n的暴力跳,预处理所有k\leq\sqrt n的情况即sum[i][j]表示i往上每j级祖先选一个点,一直到根节点选的点的情况和,查询就类似树上差分就好了

那么这样暴力跳的复杂度是O(n\sqrt n \log n),预处理是O(n\sqrt n)

对时间进行优化,可以使用长链剖分,每次O(1)查询k级祖先,这样我们的总复杂度就可以做到O(n\log n+n\sqrt n)

长链剖分模板题

然后这道题就愉快的做完了

没错,思想就这么简单,但是代码啊

细节不少,一定要想清楚了再写,不然写题半小时,debug五小时

#include <bits/stdc++.h>
using namespace std;
#define QAQ puts("QAQ")
# define Rep(i,a,b) for(int i=a;i<=b;i++)
# define _Rep(i,a,b) for(int i=a;i>=b;i--)
# define RepG(i,u) for(int i=head[u];~i;i=e[i].next)

typedef long long ll;

const int N=5e4+5;
const int W=250;
const int w=225;

template<typename T> void read(T &x){
   x=0;int f=1;
   char c=getchar();
   for(;!isdigit(c);c=getchar())if(c=='-')f=-1;
   for(;isdigit(c);c=getchar())x=(x<<1)+(x<<3)+c-'0';
    x*=f;
}

int n;
int a[N];
int head[N],cnt;
int b[N],c[N];
int dep[N],mdep[N],son[N],top[N];
int highbit[N],f[N][25];
int sum[N][W];

vector<int> up[N],down[N];

struct Edge{
    int to,next;    
}e[N<<1];

void add(int x,int y){
    e[++cnt]=(Edge){y,head[x]},head[x]=cnt; 
}

void dfs1(int u,int fa){
    mdep[u]=dep[u]=dep[fa]+1;
//  printf("%d %d\n",u,dep[u]);
    f[u][0]=fa;
    Rep(i,1,18)f[u][i]=f[f[u][i-1]][i-1];
    RepG(i,u){
        int v=e[i].to;
        if(v==fa)continue;
        dfs1(v,u);
        if(mdep[v]>mdep[u])mdep[u]=mdep[v],son[u]=v;
    }
}

void dfs2(int u,int fa,int _top){
    top[u]=_top;
    for(int i=1,p=f[u][0];i<=w;i++,p=f[p][0])sum[u][i]=sum[p][i]+a[u];
    if(top[u]==u){
        for(int i=0,ff=u;i<=mdep[u]-dep[u];i++,ff=f[ff][0])
            up[u].push_back(ff);
        for(int i=0,ss=u;i<=mdep[u]-dep[u];i++,ss=son[ss])
            down[u].push_back(ss);
    }
    if(son[u])dfs2(son[u],u,_top);
    RepG(i,u){
        int v=e[i].to;
        if(v==son[u]||v==fa)continue;
        dfs2(v,u,v);    
    }
}

int kthan(int x,int k){
    if(dep[x]<=k)return 0;
    if(!k)return x;
    x=f[x][highbit[k]],k-=(1<<highbit[k]);
    k-=dep[x]-dep[top[x]],x=top[x];
    return k>=0?up[x][k]:down[x][-k];   
}

int lca(int x,int y){
    if(dep[x]<=dep[y])swap(x,y);
    _Rep(i,18,0)if(dep[f[x][i]]>=dep[y])x=f[x][i];
    if(x==y)return x;
    _Rep(i,18,0)if(f[x][i]!=f[y][i])x=f[x][i],y=f[y][i];
    return f[x][0]; 
}

signed main()
{
//  freopen("3591.in","r",stdin);
//  freopen("3591.out","w",stdout);
    memset(head,-1,sizeof(head));
    read(n);
    Rep(i,1,n)read(a[i]);
    Rep(i,1,n-1){
        int x,y;
        read(x),read(y);
        add(x,y),add(y,x);  
    }
    Rep(i,1,n)read(b[i]);
    Rep(i,1,n-1)read(c[i]);
    dfs1(1,0);
    dfs2(1,0,1);
    highbit[1]=0;
    Rep(i,2,n)highbit[i]=highbit[i>>1]+1;
    Rep(i,2,n){
        int x=b[i-1],y=b[i],k=c[i-1];
        int l=lca(x,y);
        int res=0;
        if(k<=w){
            res+=sum[x][k]-sum[kthan(x,k*((dep[x]-dep[l])/k+1))][k];
            y=kthan(y,(dep[x]+dep[y]-2*dep[l])%k);
            res+=sum[y][k]-sum[kthan(y,k*((dep[y]-dep[l])/k+1))][k];
            if((dep[x]-dep[l])%k==0&&(dep[y]-dep[l])%k==0)res-=a[l];
            printf("%d\n",res);
        }
        else{
            for(int i=x;dep[i]>=dep[l];i=kthan(i,k))res+=a[i];
            y=kthan(y,(dep[x]+dep[y]-2*dep[l])%k);
            for(int i=y;dep[i]>=dep[l];i=kthan(i,k))res+=a[i];
            if((dep[x]-dep[l])%k==0&&(dep[y]-dep[l])%k==0)res-=a[l];
            printf("%d\n",res); 
        }
    }
    return 0;
}