不错的树形 dp 练习题——树链剖分(非模板)
Sweetlemon · · 题解
既然能得到比 std 更优的答案(连样例都有更优的解),要不要加强数据?(雾)
空说无凭,给出一组更优的样例输出(其实人力求解也可以得到这个更优解)。
2
6
7
9
0
11
0
0
13
0
0
0
14
0
其实只是把
(可能是为了降低难度才把 std 的解改劣?如果是这样,那就不要加强了吧~)
(吹水完毕,题解正文开始)
树链剖分(非模板)
这道题应该还是有不少部分分的方法的,比如随便选一个儿子输出(雾),或不管三七二十一直接重链剖分,都可以拿到一定的分数。
下面介绍一种思路比较自然、实现不困难、可以得到比 std 更优的答案的树形 dp 方法。个人认为应该已经达到了理论最优解,欢迎 hack。
条件转化
题目中的“所有查询经过的轻重链的总数”不是特别容易统计,因此我们考虑把它进行转化。
初始的一步是把
那么如何描述一条树链经过的链的数量呢?
假设我们是暴力统计这个量,那么我们的思路可以是这样的:从
总结上面的过程,“一条树链经过的链的数量”可以描述为“树链的长度减去树链中相邻重边对的数量”,也就是假设这条树链的边依次是
由于树链中边的数量是定值,所以想要让“所有查询经过的轻重链的总数”最小,就应该让“所有查询的树链所经过的相邻重边对”最多。
这样我们就有了比较方便的统计量。方便在哪里呢?原来的“经过轻重链”所涉及的状态比较复杂,相比之下“相邻重边对”就简单多了——这个状态只与“这一条边”和“上一条边”有关。
状态和状态转移的设计
这道题是要决策“选哪个重儿子”,又有了比较容易表示的统计量,我们觉得可以试试树形 dp。
按照树形 dp 的套路,我们只考虑某一棵子树。
为方便叙述,先定义两个概念。相邻重边对需要有一个“统计位置”。这里我们把树链上相邻两条边上深度最大的点称为这个相邻边对的“基点”,并把基点作为相邻重边对的统计位置。换句话说,相邻边对就是基点往上跳两层。可以设
我们设
那么我们考虑这个状态如何转移呢?
假设我们已经算出了
怎么办呢?加上去不就好了?设
现在似乎就比较容易转移了。
如果
如果
这个树形 dp 的主框架已经完成了。我们只需要在计算
现在的问题是,如何计算
预处理工作
查询一条路径
怎么实现这个加 树链剖分+线段树 树链剖分+树状数组 树上差分(因为是全部加完再查询,相当于离线操作,因此差分就够了)。设 delt[x]++,delt[y]++,delt[u]--,delt[v]--;最后令
还有一个小问题,如何求出
另外,如果 u=x 或 v=y,一加一减就抵消了。写代码的时候注意处理。
思路总结
综上,这题的主要解决步骤是:
- 读入树,做好 LCA 的预处理,
O(n \log n) - 读入查询,用倍增 LCA 找
u,v ,做树上差分,O(q \log n) - dfs 一遍,由差分得到
w ,O(n) - 进行树上 dp,求出
f 的决策点,O(n) - 根据决策情况确定重儿子,输出答案,
O(n)
总复杂度是
这道题的主要启示就是,拿到题要勤于转化题目的条件,把不好处理的条件转化为熟悉的、容易处理的条件,这样就比较容易解题了。
代码
代码比较长,但是基本都是套路,因此也不难写。
吐槽一下数据太弱,把我一份连样例都没过的错误代码(那份代码的 LCA 写炸了)放过去了。建议把样例加入数据中(笑)。
评测记录在这里。可以看到答案基本上都比 std 优。
#include <cstdio>
#include <cctype>
#include <algorithm>
#define MAXIOLG 25
#define FILENAME(x)\
freopen(x".in","r",stdin);\
freopen(x".out","w",stdout);
#define LOWBIT(x) ((x)&(-(x)))
#define MAXN 100005
#define MAXM 200005
#define MAXLG 17
using namespace std;
typedef long double ld;
typedef long long ll;
typedef ll io_t;
//快读快写
io_t shin[MAXIOLG];
io_t seto(void);
void ayano(io_t x,char spliter='\n');
int n;
//存图
int g[MAXM];
int fst[MAXN],nxt[MAXM],edges=0;
//倍增父亲数组, 儿子数量, 深度
int par[MAXLG][MAXN],chs[MAXN],depth[MAXN];
//log 2 的值的预处理
int lg[MAXN];
//差分数组,后来直接用来表示 w
int delt[MAXN];
//f 数组和决策点数组
int f[MAXN][2],h[MAXN][2];
//isp=is preferred,是不是重儿子;pr[x] 表示 x 的重儿子
int isp[MAXN],pr[MAXN];
void dfs1(int x,int pa); //预处理 LCA
int jmp(int x,int k); //x 往上跳 k 层
void solve(int x,int y); //处理查询
void dfs2(int x,int pa); //树上差分的求和, delt 变成 w
void dfs3(int x,int pa); //树上 dp
void dfs4(int x,int pa); //由决策点得出答案
void addedge(int u,int v); //加边
int main(void){
n=seto();
int q=seto();
{
//预处理 log 2
lg[0]=lg[1]=0;
int tlg=0;
for (int i=2;i<=n;i++)
(i==LOWBIT(i))?(tlg++):(0),lg[i]=tlg;
}
for (int i=1;i<n;i++){
int u,v;
u=seto(),v=seto();
addedge(u,v),addedge(v,u);
}
//预处理 LCA
dfs1(1,0);
for (int tlg=1,lgn=lg[n];tlg<=lgn;tlg++){
int ulg=tlg-1;
for (int i=1;i<=n;i++)
par[tlg][i]=par[ulg][par[ulg][i]];
}
//处理查询
while (q--){
int u,v;
u=seto(),v=seto();
solve(u,v);
}
//树上差分的求和, delt 变成 w
dfs2(1,0);
//树上 dp
dfs3(1,0);
//由决策点得出答案
isp[1]=0; //1 不是重儿子
dfs4(1,0);
//输出答案
for (int i=1;i<=n;i++)
ayano(pr[i]);
return 0;
}
void dfs1(int x,int pa){
//预处理 LCA
par[0][x]=pa,depth[x]=depth[pa]+1;
for (int ei=fst[x];ei;ei=nxt[ei]){
int v=g[ei];
if (v==pa)
continue;
dfs1(v,x),chs[x]++;
}
}
void dfs2(int x,int pa){
//树上差分的求和, delt 变成 w
for (int ei=fst[x];ei;ei=nxt[ei]){
int v=g[ei];
if (v==pa)
continue;
dfs2(v,x);
delt[x]+=delt[v];
}
}
void dfs3(int x,int pa){
//树上 dp
if (!chs[x]){
//叶子节点的 f 和决策点都是 0
f[x][0]=f[x][1]=h[x][0]=h[x][1]=0;
return;
}
int s0=0; //所有儿子 u 的 f[u][0] 之和
for (int ei=fst[x];ei;ei=nxt[ei]){
int v=g[ei];
if (v==pa)
continue;
dfs3(v,x);
s0+=f[v][0];
}
int f0=-1,f1=-1; //当前最优 f[x][0], f[x][1]
int h0=0,h1=0; //当前决策点
for (int ei=fst[x];ei;ei=nxt[ei]){
int v=g[ei];
if (v==pa)
continue;
int tf0=s0-f[v][0]+f[v][1];
int tf1=tf0+delt[v];
(tf0>f0)?(f0=tf0,h0=v):(0);
(tf1>f1)?(f1=tf1,h1=v):(0);
}
f[x][0]=f0,f[x][1]=f1;
h[x][0]=h0,h[x][1]=h1;
}
void dfs4(int x,int pa){
//由决策点得出答案
pr[x]=h[x][isp[x]]; //根据 x 是不是重儿子,选择相应决策点
(pr[x])?(isp[pr[x]]=1):(0);
for (int ei=fst[x];ei;ei=nxt[ei]){
int v=g[ei];
if (v==pa)
continue;
dfs4(v,x);
}
}
inline int jmp(int x,int k){
//x 往上跳 k 层
if (!k)
return x;
for (int i=0,lgk=lg[k];i<=lgk;i++){
int tpw=1<<i;
(k&tpw)?(x=par[i][x]):(0);
}
return x;
}
void solve(int x,int y){
//处理 x y 查询
if (x==y || par[0][x]==y || par[0][y]==x)
//特殊情况判断
return;
delt[x]++,delt[y]++;
int xd=depth[x],yd=depth[y],dd;
if (xd<yd)
swap(x,y),dd=yd-xd;
else
dd=xd-yd;
int txans=x,tyans=y; //txans, tyans 表示 u, v
int xjmp=jmp(x,dd);
if (xjmp==y){
//LCA 就是 y
txans=jmp(x,dd-1);
tyans=y;
}
else {
x=xjmp;
for (int tlg=lg[depth[x]];~tlg;tlg--)
(par[tlg][x]!=par[tlg][y])?(x=par[tlg][x],y=par[tlg][y]):(0);
txans=x,tyans=y;
}
delt[txans]--,delt[tyans]--;
}
void addedge(int u,int v){
//加边
edges++;
g[edges]=v;
nxt[edges]=fst[u],fst[u]=edges;
}
io_t seto(void){
//快读
io_t ans=0;
int symbol=0;
char ch=getchar();
while (!isdigit(ch))
(ch=='-')?(symbol=1):(0),ch=getchar();
while (isdigit(ch))
(ans=ans*10+(ch-'0')),ch=getchar();
return (symbol)?(-ans):(ans);
}
void ayano(io_t x,char spliter){
//快写
if (!x){
putchar('0'),putchar(spliter);
return;
}
if (x<0)
putchar('-'),x=-x;
int len=0;
while (x){
io_t d=x/10;
shin[len++]=x-d*10;
x=d;
}
while (len){
len--;
putchar(shin[len]+'0');
}
putchar(spliter);
}