我能在使用分块换根的情况下通过此题吗?

· · 题解

观前小贴士:

本文仅会对具体做法进行简述,重点在于介绍优化常数的方法,如对具体做法有疑惑者可搭配本文姊妹篇食用,同时也相当于是对 mrsrz 大佬强大的常数优化方式的一些延伸吧。

思路浅析:

列出题目中求的式子即为 \sum\limits_{i=l}^r \sum\limits_{j=i+1}^r dis_{i,j},拆开来看,即为 \sum\limits_{i=l}^r \sum\limits_{j=i+1}^r dep_i+dep_j-2\times dep_{\operatorname{lca}_{i,j}}

剩下的 $\sum\limits_{i=l}^r \sum\limits_{j=i+1}^r-2\times dep_{\operatorname{lca}_{i,j}}$ 是形如 $\sum\limits_{i=l1}^{r1}\sum\limits_{j=l2}^{r2}val(lca(id_i,id_j))$ 的东西,使用分块+换根 DP 预处理整块对整块,整块对散块,散块对整块的答案,散块对散块部分建虚树后换根 DP 就能在 $O((n+m)\sqrt n)$ 内解决,具体讲解可参考 mrsrz 大佬的题解或[蒟蒻的乐色](https://www.luogu.com.cn/article/ugz2jc33)(~~无耻广告是吧~~)。 不过注意到暴力开 $O(n\sqrt n)$ 的数组存答案空间就飞天了,解决方案也是简单的:离线逐块换根后统计贡献即可,空间复杂度被降到了 $O(n\log n)$(瓶颈在于 ST 表)。 于是现在你就能写出如下代码了: (注意到建虚树的过程中可以得到虚树的 dfs 序以避免 dfs 带来的大常数,下述代码已使用该方法) ```cpp #include <bits/stdc++.h> #define ll int #define un unsigned using namespace std; int n,m; struct ed { ll v,w,next; }edge[400005]; ll head[200005],cnt; void add(ll u,ll v,ll w) { edge[++cnt].v=v;edge[cnt].w=w;edge[cnt].next=head[u];head[u]=cnt; edge[++cnt].v=u;edge[cnt].w=w;edge[cnt].next=head[v];head[v]=cnt; } int fa[1805],uid[1805]; int dep[200005],in[200005],tot,f[400005][20],lg[400005],to[200005]; int dfn[200005],cot,sz[200005],ss,id[200005],from[200005],L[1005],R[1005],fat[200005]; un ll dis[200005],sum[200005],pre_res[200005]; void ad(ll u,ll v) { uid[++tot]=u;fa[uid[tot]]=v;sz[u]=0; } bool cmp(int a,int b) { return dfn[a]<dfn[b]; } void dfs(ll id,ll fa) { dep[id]=dep[fa]+1;f[in[id]=++tot][0]=id; dfn[id]=++cot;to[cot]=id;fat[id]=fa; for(ll i=head[id];i;i=edge[i].next) { ll v=edge[i].v,w=edge[i].w; if(v==fa)continue; dis[v]=dis[id]+w; dfs(v,id); f[++tot][0]=id; } } int lca(int u,int v) { u=in[u];v=in[v]; if(u>v)swap(u,v); ll len=lg[v-u+1]; return (dep[f[u][len]]<dep[f[v-(1<<len)+1][len]]?f[u][len]:f[v-(1<<len)+1][len]); } bool vis[200005]; int st[1805],tp; struct px { int l,r; un ll res; }ask[200005]; int a[2],d[2][455],s[2]; int p[905]; void pre() { for(ll i=0;i<=n;++i)sz[i]=0; for(ll i=n;i;--i) { ll u=to[i]; if(vis[u])++sz[u]; sz[fat[u]]+=sz[u]; } for(ll i=1;i<=n;++i) { ll u=to[i]; pre_res[u]=pre_res[fat[u]]+(sz[fat[u]]-sz[u])*1ull*dis[fat[u]]; sum[u]=sz[u]*dis[u]+pre_res[u]; } } int main() { ios::sync_with_stdio(0);cin.tie(0);cout.tie(0); cin>>n>>m; ss=sqrt(n); ll u,v,w; for(ll i=1;i<n;++i)cin>>u>>v>>w,add(u,v,w); dfs(1,0); for(ll i=1;i<=n;++i)id[i]=i,sum[i]=sum[i-1]+dis[i],from[i]=(i-1)/ss+1; for(ll i=1;i<=from[n];++i)L[i]=(i-1)*ss+1,R[i]=i*ss; R[from[n]]=n; for(ll i=1;i<=from[n];++i)sort(id+L[i],id+1+R[i],cmp); for(ll i=2;i<=tot;++i)lg[i]=lg[i>>1]+1; for(ll i=tot;i;--i) { for(ll j=1;(1<<j)+i-1<=tot;++j) f[i][j]=((dep[f[i][j-1]]<dep[f[i+(1<<(j-1))][j-1]]?f[i][j-1]:f[i+(1<<(j-1))][j-1])); } dfn[0]=n+1; for(ll i=1;i<=m;++i) { cin>>ask[i].l>>ask[i].r; ll l=ask[i].l,r=ask[i].r; ask[i].res=(sum[r]-sum[l-1])*(r-l+1); a[0]=a[1]=cot=0; if(from[l]==from[r]) { for(ll i=l;i<=r;++i)vis[i]=1; for(ll i=L[from[l]];i<=R[from[l]];++i) if(vis[id[i]])d[0][++a[0]]=id[i]; } else { for(ll i=l;i<=R[from[l]];++i)vis[i]=1; for(ll i=L[from[l]];i<=R[from[l]];++i) if(vis[id[i]])d[0][++a[0]]=id[i]; for(ll i=L[from[r]];i<=r;++i)vis[i]=1; for(ll i=L[from[r]];i<=R[from[r]];++i) if(vis[id[i]])d[1][++a[1]]=id[i]; } if(ask[i].l!=1)p[++cot]=1; d[0][a[0]+1]=d[1][a[1]+1]=0; s[0]=s[1]=1; while(s[0]<=a[0]||s[1]<=a[1]) { if(dfn[d[0][s[0]]]<dfn[d[1][s[1]]]) p[++cot]=d[0][s[0]++]; else p[++cot]=d[1][s[1]++]; } st[tp=1]=1;tot=0; for(ll i=2;i<=cot;++i) { ll id=p[i]; ll c=lca(id,st[tp]); if(c!=st[tp]) { while(tp-1&&dfn[c]<dfn[st[tp-1]]) ad(st[tp],st[tp-1]),--tp; ad(st[tp],c); --tp;if(c!=st[tp])st[++tp]=c; } st[++tp]=id; } while(tp-1)ad(st[tp],st[tp-1]),--tp; uid[++tot]=1; for(ll i=1;i<=tot;++i)sz[i]=0; for(ll i=1;i<=tot;++i) { u=uid[i]; if(vis[u])++sz[u]; sz[fa[u]]+=sz[u]; } un ll res=0; for(ll i=tot;i>=1;--i) { u=uid[i]; pre_res[u]=pre_res[fa[u]]+(sz[fa[u]]-sz[u])*1ull*dis[fa[u]]; if(vis[u])res+=pre_res[u]+dis[u]*1ull*sz[u]; } ask[i].res-=res; if(from[l]==from[r])for(ll i=l;i<=r;++i)vis[i]=0; else { for(ll i=l;i<=R[from[l]];++i)vis[i]=0; for(ll i=L[from[r]];i<=r;++i)vis[i]=0; } } for(ll i=2;i<from[n];++i) { for(ll j=L[i];j<=R[i];++j)vis[j]=1; pre(); for(ll i=1;i<=n;++i)sum[i]+=sum[i-1]; for(ll j=1;j<=m;++j) { ll gl=from[ask[j].l],gr=from[ask[j].r]; if(gl<i&&i<gr) { ask[j].res-=(sum[ask[j].r]-sum[ask[j].l-1]); ask[j].res-=(sum[R[gl]]-sum[ask[j].l-1]+sum[ask[j].r]-sum[L[gr]-1]); } } for(ll j=L[i];j<=R[i];++j)vis[j]=0; } for(ll i=1;i<=m;++i)cout<<ask[i].res<<'\n'; } ``` 这就做...完了? 当然不是辣,你会喜获[千里江山满地黑](https://www.luogu.com.cn/record/181459518)的好成绩。 ### 卡常!卡常!卡常! ~~我们终于进入了正文部分。~~ 注意到上面的代码主要分成了两个部分,因此我们的卡常也分为两个部分: #### 散块处理时的建虚树+换根部分 (注意到 $dep_{\operatorname{lca}_{i,j}}$ 前的系数 $-2$ 恰好与钦定 $i<j$ 相抵消,故所有 $\times 2$ 均是可以省略的,于是就可以直接自然溢出了,此优化已在上面的代码中使用,特此说明。) 注意到我们循环的次数实在是太多了,考虑优化。 mrsrz 大佬告诉我们可以在建虚树的同时计算贡献。 具体的,在一个点被弹出单调栈时就意味着其子树内形态已经固定了,此时我们就可以计算出它子树内的特殊点和其父亲子树内其它子树特殊点的贡献了,即为 $sz_u\times sz_{fa_u}\times dep_{fa_u}$,并将特殊点个数累加到父亲节点上。 可以发现,这样一对点恰好只计算了一次,所以记得 $\times 2$。 另一个重要的优化则是**寻址连续**,ST 表调换数组维度,将 dfs 序相邻的两个关键点预处理 LCA,直接将一个点的子树内特殊点个数存在其栈中的下标处,这些优化均能减少寻址时间,在常数上带来巨大飞跃。 现在的核心代码: ```cpp st[tp=1]=yc[1]=1;tot=0; for(ll i=2;i<=cot;++i)yc[i]=lca(p[i-1],p[i]); for(ll i=1;i<=cot;++i) { ll id=p[i]; if(yc[i]!=st[tp]) { ll c=yc[i]; while(tp-1&&dfn[c]<=dfn[st[tp-1]]) { res+=sz[tp-1]*sz[tp]*dis[st[tp-1]]; sz[tp-1]+=sz[tp]; --tp; } st[tp]=c; } st[++tp]=id; sz[tp]=1; } while(tp-1) { res+=sz[tp-1]*sz[tp]*dis[st[tp-1]]; sz[tp-1]+=sz[tp]; --tp; } ``` 可以发现,我们省下了两个循环计算贡献和清空的循环! 但是这部分跑得仍然很慢,因此块长也是需要合理设置的,为了减少虚树点数,我们需要调短块长,最终可以发现块长设置在 $[200,250]$ 之间是合适的,这样还能够让数组卡入 Cache 以提升寻址效率。 #### 整块换根+贡献累计部分 似乎没有哪个循环是可以省略的喵,可是后面部分[根据测试](https://www.luogu.com.cn/record/181495458),就跑了 3s 多,再加上前面的部分是 ~~acceptable~~unacceptable 的。 观察到 `pre()` 函数中有两个循环都有对 `to[u]` 的调用,这样的寻址是很不连续的,遂优化之,考虑直接将存子树中特殊点的个数和答案的数组拍到 dfs 序,最后前缀和的时候再调用 dfs 序上的位置即可减少一个循环的对 `to[u]` 调用。 同时我们发现还存在大量对 `from[l]`,`from[r]`,`from[l]+1`,`from[r]-1` 的调用,显然也是很不连续的,直接在存储询问的结构体中记录 `from[l]` 和 `from[r]` 即可使寻址更加连续。 最后清空记得使用 `memset`。 本部分核心代码: ```cpp void pre() { for(ll i=0;i<=n;++i)sz[i]=0; for(ll i=n;i;--i) { ll u=to[i]; if(vis[u])++sz[u]; sz[fat[u]]+=sz[u]; } for(ll i=1;i<=n;++i) { ll u=to[i]; pre_res[u]=pre_res[fat[u]]+(sz[fat[u]]-sz[u])*1ull*dis[fat[u]]; sum[u]=sz[u]*dis[u]+pre_res[u]; } } for(ll i=2;i<from[n];++i) { for(ll j=L[i];j<=R[i];++j)vis[j]=1; pre(); for(ll i=1;i<=n;++i)sum[i]+=sum[i-1]; for(ll j=1;j<=m;++j) { ll gl=from[ask[j].l],gr=from[ask[j].r]; if(gl<i&&i<gr) { ask[j].res-=(sum[ask[j].r]-sum[ask[j].l-1]); ask[j].res-=(sum[R[gl]]-sum[ask[j].l-1]+sum[ask[j].r]-sum[L[gr]-1]); } } for(ll j=L[i];j<=R[i];++j)vis[j]=0; } ``` 终于,我们相对优雅地~~打倒了毒瘤 lxl 的恶毒数据~~完成了卡常工作,以[还能看的时间](https://www.luogu.com.cn/record/181478044)通过了本题。 总结一下,我们从本题卡常部分中学到最重要的技巧就是让寻址连续以提高 Cache 的命中率,这往往能大幅减小常数。 最后,献上笔者丑陋的代码: ```cpp #include <bits/stdc++.h> #define ll int #define un unsigned using namespace std; const ll ss=250,MAXN=200000; ll n,m; struct ed { ll v,w,next; }edge[400005]; ll head[200005],cnt; void add(ll u,ll v,ll w) { edge[++cnt].v=v;edge[cnt].w=w;edge[cnt].next=head[u];head[u]=cnt; edge[++cnt].v=u;edge[cnt].w=w;edge[cnt].next=head[v];head[v]=cnt; } ll lg[400005]; ll p[ss*2+2],yc[ss*2+2]; ll dep[200005],bl[MAXN/ss+2],in[200005],tot,f[20][400005],to[200005]; ll a[2],d[2][ss+2]; ll dfn[200005],cot,id[200005]; un ll dis[200005],sum[200005],pre_res[200005]; ll from[200005],L[MAXN/ss+2],R[MAXN/ss+2],sz[200005],fat[200005],rsz[200005],ds[200005]; bool cmp(ll a,ll b) { return dfn[a]<dfn[b]; } ll sf[200005]; void dfs(ll id,ll fa) { dep[id]=dep[fa]+1;f[0][in[id]=++tot]=id; dfn[id]=++cot;to[cot]=id;fat[id]=fa;rsz[id]=1; sf[dfn[id]]=dfn[fa];ds[cot]=dis[id]; for(ll i=head[id];i;i=edge[i].next) { ll v=edge[i].v,w=edge[i].w; if(v==fa)continue; dis[v]=dis[id]+w; dfs(v,id); f[0][++tot]=id;rsz[id]+=rsz[v]; } } ll lca(ll u,ll v) { u=in[u];v=in[v]; ll len=lg[v-u+1]; return dfn[f[len][u]]<dfn[f[len][v-(1<<len)+1]]?f[len][u]:f[len][v-(1<<len)+1]; } struct px { ll l,r,bl,br; un ll res; }ask[200005]; ll ns[200005]; void pre(ll l,ll r) { memset(sz,0,sizeof(sz)); for(ll i=l;i<=r;++i)sz[dfn[i]]=1; for(ll i=n;i;--i)sz[sf[i]]+=sz[i]; for(ll i=1;i<=n;++i) { pre_res[i]=pre_res[sf[i]]+(sz[sf[i]]-sz[i])*ds[sf[i]]; ns[i]=sz[i]*ds[i]+pre_res[i]; } } ll st[ss*4+2],tp; int main() { ios::sync_with_stdio(0);cin.tie(0);cout.tie(0); cin>>n>>m; ll u,v,w; for(ll i=1;i<n;++i)cin>>u>>v>>w,add(u,v,w); dfs(1,0); for(ll i=1;i<=n;++i)id[i]=i,sum[i]=sum[i-1]+dis[i],from[i]=(i-1)/ss+1; for(ll i=1;i<=from[n];++i)L[i]=(i-1)*ss+1,R[i]=i*ss; R[from[n]]=n; for(ll i=1;i<=from[n];++i)sort(id+L[i],id+1+R[i],cmp); for(ll j=1;j<19;++j) for(ll i=1;i+(1<<j)-1<=tot;++i) f[j][i]=((dep[f[j-1][i]]<dep[f[j-1][i+(1<<(j-1))]]?f[j-1][i]:f[j-1][i+(1<<(j-1))])); lg[0]=-1; for(ll i=1;i<=tot;++i)lg[i]=lg[i>>1]+1; dfn[0]=n+1; for(ll i=1;i<=m;++i) { cin>>ask[i].l>>ask[i].r; ll l=ask[i].l,r=ask[i].r; ask[i].res=(sum[r]-sum[l-1])*(r-l); a[0]=a[1]=cot=0; un ll res=0; ll fl=from[l],fr=from[r]; ask[i].bl=fl+1;ask[i].br=fr-1; if(fl==fr) { for(ll i=L[fl];i<=R[fl];++i) if(l<=id[i]&&id[i]<=r)d[0][++a[0]]=id[i]; } else { for(ll i=L[fl];i<=R[fl];++i) if(l<=id[i])d[0][++a[0]]=id[i]; for(ll i=L[fr];i<=R[fr];++i) if(id[i]<=r)d[1][++a[1]]=id[i]; if(ask[i].bl<=ask[i].br)ask[i].res+=sum[R[ask[i].br]]-sum[L[ask[i].bl]-1]; } ll s[2]={1,1}; d[0][a[0]+1]=d[1][a[1]+1]=0; while(s[0]<=a[0]||s[1]<=a[1]) { if(dfn[d[0][s[0]]]<dfn[d[1][s[1]]]) p[++cot]=d[0][s[0]++]; else p[++cot]=d[1][s[1]++]; } cot=a[0]+a[1]; st[tp=1]=yc[1]=1;tot=0; for(ll i=2;i<=cot;++i)yc[i]=lca(p[i-1],p[i]); for(ll i=1;i<=cot;++i) { ll id=p[i]; if(yc[i]!=st[tp]) { ll c=yc[i]; while(tp-1&&dfn[c]<=dfn[st[tp-1]]) { res+=sz[tp-1]*sz[tp]*dis[st[tp-1]]; sz[tp-1]+=sz[tp]; --tp; } st[tp]=c; } st[++tp]=id; sz[tp]=1; } while(tp-1) { res+=sz[tp-1]*sz[tp]*dis[st[tp-1]]; sz[tp-1]+=sz[tp]; --tp; } res=res*2; ask[i].res-=res; } for(ll i=2;i<from[n];++i) { pre(L[i],R[i]); for(ll i=1;i<=n;++i)sum[i]=sum[i-1]+ns[dfn[i]]; for(ll i=1;i<=from[n];++i) bl[i]=sum[L[i]-1]; for(ll j=1;j<=m;++j) { if(ask[j].bl<=i&&i<=ask[j].br) ask[j].res-=(2*(sum[ask[j].r]-sum[ask[j].l-1])-bl[ask[j].br+1]+bl[ask[j].bl]); } } for(ll i=1;i<=m;++i)cout<<ask[i].res<<'\n'; } ```