题解 P5002 【专心OI - 找祖先】
虽然标签有LCA,但根本没有用到求LCA(谁写的标签)
先是存树,我习惯用的是领接表(本人太弱,不会用vector)
int n,m,r,tot;
int to[N],ne[N],head[N],son[N],ans[50002],de[N];
void add(int x,int y){
to[++tot]=y,ne[tot]=head[x],head[x]=tot;
}
n,m,r如题意所示,son数组记录每个子树的大小,de数组记录每个节点深度,ans数组记录每个点对应的答案
本题的核心就是求每棵子树的大小,然后排列组合一下
void dfs(int x,int fa){
son[x]=1;//子树包括自身
de[x]=de[fa]+1;//暂时没什么用的深度
for(int i=head[x];i;i=ne[i]){//邻接表遍历
if(to[i]!=fa){//没了它你的dfs就会在根节点和一个子树上反复横跳
dfs(to[i],x);
son[x]+=son[to[i]]%mo;//递归,求出其儿子的子树的大小后再加上去,递归边界是叶子节点,son[]为1
}
}
}
这样我们就快乐地求出了所有子树的大小
接下来就是分析题目了,题问有多少对节点(x,y)的LCA为p,很容易发现以下性质
- x,y一定在p的子树中
- 如果x,y都!=p,那么x,y绝对不是p的同一个儿子的 子树
总结一下就会发现要么一个节点为p,另一个节点为子树p中的任一节点(一共son [ p ]对);要么一个节点在p的儿子的子树中,另一个节点在p的其它儿子的子树中(包括p节点)。那么我们就可以遍历p的儿子,每一个儿子的子树大小乘上p子树中除该子树的大小,最后加上son[p]
int getans(int p){
int x,y=0,z=0;//x为p的子树大小,y为本次遍历到的子树的大小
z=x=son[p];//z为计算结果
for(int i=head[p];i;i=ne[i]){
if(de[to[i]]<de[p]) continue;//通过深度来保证搜索的其儿子
y=son[to[i]];
z+=(long long)((x-y)*y)%mo;
}
return z;
}
用了这个,就可以针对每一个询问的p求出相应的结果
for(int i=1;i<=m;i++){
int p;
cin>>p;
ans[p]=getans(p);
}
for(int i=1;i<=m;i++) cout<<ans[i]<<endl;
然后。。。然后就TLE了一个点
怎么回事呢?再回去看看题目,M<=50000而N<=10000,诶,询问次数比点要多,说明很多点被重复询问了多次(好坑啊)
for(int i=1;i<=n;i++) ans[i]=getans(i);
for(int i=1;i<=m;i++){
int p;
read(p);
cout<<ans[p]<<endl;
}
所以我们就把每个点都处理一遍,最后直接输出就好了
完整代码
#include<bits/stdc++.h>
#define INF 0x3f3f3f3f
#define N 20002
#define mo 1000000007
using namespace std;
int n,m,r,tot;
int to[N],ne[N],head[N],son[N],ans[50002],de[N];
void add(int x,int y){
to[++tot]=y,ne[tot]=head[x],head[x]=tot;
}
void read(int &x) {
int f = 1; x = 0;
char ch = getchar();
while (ch < '0' || ch > '9') {if (ch == '-') f = -1; ch = getchar();}
while (ch >= '0' && ch <= '9') {x = x * 10 + ch - '0'; ch = getchar();}
x *= f;
}
void dfs(int x,int fa){
son[x]=1;//子树包括自身
de[x]=de[fa]+1;//暂时没什么用的深度
for(int i=head[x];i;i=ne[i]){//邻接表遍历
if(to[i]!=fa){//没了它你的dfs就会在根节点和一个子树上反复横跳
dfs(to[i],x);
son[x]+=son[to[i]]%mo;//递归,求出其儿子的子树的大小后再加上去,
}
}
}
int getans(int p){
int x,y=0,z=0;
z=x=son[p];
for(int i=head[p];i;i=ne[i]){
if(de[to[i]]<de[p]) continue;
y=son[to[i]];
z+=(long long)((x-y)*y)%mo;
}
return z;
}
int main()
{
read(n),read(r),read(m);
for(int i=1;i<n;i++){
int x,y;
read(x),read(y);
add(x,y),add(y,x);
}
dfs(r,r);
for(int i=1;i<=n;i++) ans[i]=getans(i);
for(int i=1;i<=m;i++){
int p;
read(p);
cout<<ans[p]<<endl;
}
return 0;
}
写题解不易,点个赞呗