[St-OI Round 1] T3:树上询问 官方题解
[St-OI Round 1] T3:树上询问 官方题解
原定时限是1.2s,本人std最慢的一个点是600ms,所以超1.2s的同学都要订正啊qwq。(不要喷我duliu!
废话不多说。
关于不随机数据:就是菊花树,是为卡掉枚举z出边的。
思路
这题乍一眼看上去好像并没有什么特别显而易见的做法(雾
但先画个图!
就以样例二为例:
当我们找(2,5,z)的答案时,发现只有{2,3,1,5}可以作为LCA,因为他们在(2,5)的唯一路径上,所以发现z不在路径上时(比如4),就可以直接输0。
如果在路径上呢?
比如(2,5,3),3肯定是答案,然后我们把3给“拎”起来:
发现以4为根时,也可以。这时我们可以思考,是不是所有3的子节点都可以呢?
当然不是,2,5所在的子节点(即2,1)就不行,但其他子节点及其子节点下所有的节点都可以!(比如下图中的选中的所有节点)
思路明显了吧:当z在(x,y)路径上时,z和非路径上z的所有子节点(包括father)及其子孙节点就是答案!(如果纯看语言没看懂就看代码)
实现:
前置知识:
树上倍增(+求LCA),dfs序
首先一遍以1为根dfs求出dfs序和子树大小,再进行倍增预处理。
然后分4种情况讨论:
在路径上时:
不在路径上时:
- print(0);
CODE:
#include <cstdio>
#define For(x) for(int i=hd[x];i;i=e[i].nxt)
#define v (e[i].to)
#define fsize(x) (n-size[x])//father子树大小
#define swp(x,y) (x^=y^=x^=y)
const int N=5e5+5;
int n,m,x,y,z,hd[N],num;
struct cz {
int nxt,to;
}e[N<<1];
int size[N],tim[N],dep[N],cnt;//size是子树大小,tim是时间戳(dfs序),dep是节点深度
int lg[N],f[N][30];//倍增用
inline int read() {
int x=0,flag=0;char ch=getchar();
while(ch<'0'||ch>'9'){flag|=(ch=='-');ch=getchar();}
while(ch>='0'&&ch<='9'){x=(x<<3)+(x<<1)+ch-'0';ch=getchar();}
return flag?-x:x;
}
inline void add(int x,int y) {
e[++num]=(cz) {hd[x],y};
hd[x]=num;
}
void dfs(int x,int father) {
size[x]=1;tim[x]=++cnt;dep[x]=dep[father]+1;
f[x][0]=father;
For(x) {
if(v==father) continue;
dfs(v,x);
size[x]+=size[v];
}
}
inline int Lca(int x,int y) {
if(dep[x]<dep[y]) swp(x,y);
for(int i=lg[n];i>=0;--i)
if(dep[f[x][i]]>=dep[y])
x=f[x][i];
if(x==y) return x;
for(int i=lg[n];i>=0;--i)
if(f[x][i]!=f[y][i])
x=f[x][i],y=f[y][i];
return f[x][0];
}
inline bool be_in(int x,int z) {return (tim[z]<=tim[x]&&tim[x]<=tim[z]+size[z]-1);}
//判断x在不在z的子树上
inline int tot(int x,int fa) {//x一直jump(倍增)到fa子节点的位置
if(x==fa) return 0;
for(int i=lg[n];i>=0;--i)
if(tim[f[x][i]]>tim[fa])
x=f[x][i];
return size[x];
}
int main() {
n=read();m=read();
for(int i=2;i<=n;++i) lg[i]=lg[i-1]+(1<<(lg[i-1]+1)==i?1:0);
for(int i=1;i<n;++i) {
x=read();y=read();
add(x,y);add(y,x);
}
dfs(1,0);
for(int j=1;j<=lg[n];++j)
for(int i=1;i<=n;++i)
f[i][j]=f[f[i][j-1]][j-1];
for(int i=1;i<=m;++i) {
x=read();y=read();z=read();//下面就是4种情况
if(Lca(x,y)==z) printf("%d\n",n-tot(x,z)-tot(y,z));
else if(be_in(x,z)&&!be_in(y,z)) printf("%d\n",n-tot(x,z)-fsize(z));
else if(be_in(y,z)&&!be_in(x,z)) printf("%d\n",n-tot(y,z)-fsize(z));
else printf("0\n");
}
return 0;
}
-完-