题解:P13763 [CERC 2021] Airline
前言
非常好的树上问题,使我的大脑旋转。\ 不难,思维难度也不高,但是如果没有想到真的很难说。
广告
同步发布于博客园,可以尝试更好的阅读体验。
题意
给出一颗树,不带边权点权,每次询问给出
思考
首先我们令我们给出的
做法
每一次将
代码
#include<algorithm>
#include<iostream>
#include<cstring>
#include<climits>
#include<vector>
#include<cmath>
#define ll long long
using namespace std;
const int N=1e6+9;
ll n,q,fa[N][26],dep[N],node[N*10],cnt,siz[N],nodesiz[N*10];
ll ans,hzsum[N*10];
vector<int>e[N];
inline void dfs(int x,int f){
siz[x]=1;
dep[x]=dep[f]+1;
fa[x][0]=f;
for(int i=1;i<=25;i++)
fa[x][i]=fa[fa[x][i-1]][i-1];
for(int to:e[x])
if(to!=f)
dfs(to,x),siz[x]+=siz[to];
}
inline int LCA(int x,int y){
if(dep[x]<dep[y]) swap(x,y);
for(int i=25;i>=0;i--)
if(dep[fa[x][i]]>=dep[y])
x=fa[x][i];
if(x==y)return x;
for(int i=25;i>=0;i--)
if(fa[x][i]!=fa[y][i])
x=fa[x][i],y=fa[y][i];
return fa[x][0];
}
namespace IN {
const int MAXX_INPUT = 1000000;
#define getc() (p1 == p2 && (p2 = (p1 = buf) + inbuf -> sgetn(buf, MAXX_INPUT), p1 == p2) ? EOF : *p1++)
char buf[MAXX_INPUT], *p1, *p2;
template <typename T> inline bool redi(T &x) {
static streambuf *inbuf = cin.rdbuf();
x = 0;
register int f = 0, flag = false;
register char ch = getc();
while (!isdigit(ch)) {
ch = getc();
}
if (isdigit(ch)) x = x * 10 + ch - '0', ch = getc(),flag = true;
while (isdigit(ch)) {
x = x * 10 + ch - 48;
ch = getc();
}
return flag;
}
template <typename T,typename ...Args> inline bool redi(T& a,Args& ...args) {
return redi(a) && redi(args...);
}
#undef getc
}
void write(ll x){
if(x<0)putchar('-'),x=-x;
if(x>9)write(x/10);
putchar(x%10+'0');
return;
}
using IN::redi;
int main(){
ios::sync_with_stdio(0);
cin.tie(0);cout.tie(0);
redi(n,q);
for(int i=1;i<n;i++){
int u,v;
redi(u,v);
e[u].push_back(v);
e[v].push_back(u);
}
dfs(1,1);
while(q--){
ans=0;cnt=0;
int s,t;
redi(s,t);
int lca=LCA(s,t);
if(dep[s]<dep[t]) swap(s,t);
int ns=s,nt=t;
while(ns!=lca) node[++cnt]=ns,ns=fa[ns][0];
node[++cnt]=lca;int tmpcnt=cnt;
while(nt!=lca) node[++cnt]=nt,nt=fa[nt][0];
reverse(node+tmpcnt+1,node+cnt+1);
nodesiz[1]=siz[node[1]];
for(int i=2;i<tmpcnt;i++)
nodesiz[i]=siz[node[i]]-siz[node[i-1]];
nodesiz[tmpcnt]=n-siz[node[tmpcnt-1]]-siz[node[tmpcnt+1]];
for(int i=tmpcnt+1;i<cnt;i++)
nodesiz[i]=siz[node[i]]-siz[node[i+1]];
if(node[cnt]!=lca)nodesiz[cnt]=siz[node[cnt]];
hzsum[cnt+1]=0;
for(int i=cnt;i>=1;i--)hzsum[i]=hzsum[i+1]+nodesiz[i];
int pos=0;
for(int i=cnt;i>=0;i--){
if(i-1<=cnt-i+1){
pos=i;
break;
}
}
for(int i=1;i<=pos;i++){
if(pos>=cnt) break;
ans+=nodesiz[i]*hzsum[pos+1];
pos++;
}
write(ans);puts("");
for(int i=1;i<=cnt;i++)
nodesiz[i]=0,hzsum[i]=0,node[i]=0;
}
return 0;
}