P11516 Sol

· · 题解

给我调秃了啊淦。细节真的多。本题解会尽量把我感到比较疑惑的地方,尤其是 DP 的优化细节写得详细一点。

令题面里原有的 A,B,R 变为 a,b,rt,然后先手叫 A,后手叫 B。下文的“起点”“终点”全部是对于一个人的单次操作而言的。

先把 rt 拎起来作根。A 不能走走过的路,所以 A 每次都只能往更深的地方走。然后 B 每次可以移动到大小为 b 的邻域(最后一次必须移到子树内)。

第一档分是 a\le b,考虑它有什么性质。发现这时主动权完全在 B 手上,A 每次跳完,B 就一定可以跳到目前位置和点 1 的 LCA 上,跳着跳着就把 A 卡死了,所以答案就是 1。那么现在只用考虑 b<a 的情况。

然后这里我们选择先二分一个答案 mid 使求解的东西变得简单一点。(看完了下面的内容你可以想想如果这里不二分的话会发生什么。实际上 O(n^2) 的 DP 还是可以做的,但是后续再往下砍复杂度似乎会变得非常困难。反正我自己做的时候就是推完 O(n^2) 没想起来二分然后寄了。)

显然每次 A 行动的起点只会越来越深,这支持我们设计某种 DP。设 f_i 表示 A 从 i 点开始走,最后终点会不会 \le midg_i 表示 B 从 i 点开始走最后终点会不会 \le mid。理论上按深度从大到小去更新就没有什么问题。进一步考虑这个 DP 的细节。

f_i:预处理出每个点的所有 a 级儿子。

g_i:按理来说就是看 i 大小为 B 的邻域里的 f_i 有没有是 1 的就好了。但是这里会出现本题最恶心的一个细节:

如果 B 的一次操作终点是起点的祖先,那么 A 下一次操作就不能顺着这个起点方向的儿子再走下去,因为已经被走过了。

所以我们未必能直接拿邻域里的 f_j 去更新 g_i。更具体地,当 ji 祖先时,我们要去掉 f_j 中走回了 i 所在子树的情况。

这里精细实现一下已经可以获得一个 O(n^2\log n) 的 DP,不二分就没有 \log。然后考虑优化。

实际上唯一能超时限的只有 g_i 的更新。因为已经二分过答案了,这里 g_i 的更新可以当成:

这样就跟邻域没有什么关系了。考虑这样要怎么做。时刻记住我们的 DP 是在按深度从大到小更新 f_i,设现在的深度是 d,而现在是在确定所有深度为 d+ag_i 用于转移 f

考虑某种做法:我们不倾向于对每个 i 去寻找可能的 j 进行转移,而是在每个 j 处去尝试更新所有 i 的答案。

你注意到这里所有和 i 距离不超过 b 的点 j 的深度都不小于 d+a-b。而 b<a,说明这些点 jf_j 已经被我们更新过了,我们可以放心地用。所以我们可以在每次到一个新的 d 的时候,把所有深度为 d+a-b 的点 j 计入考虑范围。

当一个点 j 被计入考虑范围时,我们需要做这些事:

那么求 g_i 的时候只要把 t_i 和线段树里面的东西 +dep_i 拿出来取个 \min 就大功告成了。按上面说的更新 dep_i=d 上的所有 f_i 啥的就好。

至此做完了。O(n\log^2 n)。我这代码目前是 qoj 的第三短解,没卡常大概 2.5s,性价比挺高。

我自己想的进度是写对了无二分的 O(n^2) DP 和链。感觉最厉害的部分还是优化 g_i 转移那里想到去反向更新,然后进一步想到用逐层更新+线段树去维护这个过程。可能邻域相关套路我还是没见过多少。剩下的我认为思路其实比较自然,只要认真顺着这个做法去分析可能的情况就可以了。而分析错了就会感受到调到头秃的酸爽。

#include <bits/stdc++.h>
#define pb push_back
#define fi first
#define se second
using namespace std; bool MEM;
using ll=long long; using ld=long double;
using pii=pair<int,int>; using pll=pair<ll,ll>;
const int I=1e9,N=3e5+7;
const ll J=1e18;
int n,rt,a,b;
int dep[N],fa[N],dfn[N],cnn,dfm[N],co[N],st[N],tp;
vector<int> e[N],e1[N],to[N],de[N];
void dfs(int p,int f) {
    dfn[p]=++cnn,co[cnn]=p,dep[p]=dep[f]+1,de[dep[p]].pb(p);
    fa[p]=f,st[++tp]=p;
    if (tp>a) to[st[tp-a]].pb(p);
    for (int i:e[p]) if (i!=f) e1[p].pb(i),dfs(i,p);
    tp--,dfm[p]=cnn;
}
struct sgt {
    int t[N<<2];
    void ini() { memset(t,60,sizeof(t)); }
    void add(int x,int y,int z,int p=1,int l=1,int r=n) {
        if (x>y) return;
        if (x<=l&&r<=y) return t[p]=min(t[p],z),void();
        int mid=(l+r)>>1;
        if (x<=mid) add(x,y,z,p<<1,l,mid);
        if (y>mid) add(x,y,z,p<<1|1,mid+1,r);
    }
    int que(int x,int p=1,int l=1,int r=n) {
        if (l==r) return t[p];
        int mid=(l+r)>>1;
        if (x<=mid) return min(t[p],que(x,p<<1,l,mid));
        else return min(t[p],que(x,p<<1|1,mid+1,r));
    }
} T;
int mnn[N],f[N],g[N],t[N],tmp[N],is[N];
bool chk(int mid) {
    T.ini();
    for (int i=1;i<=n;i++) t[i]=1e8,mnn[i]=i<=mid?dep[i]:I;
    for (int d=n;d;d--) {
        if (d+a-b<=n) {
            for (int p:de[d+a-b]) {
                if (f[p]) t[p]=dep[p];
                else { t[p]=I; for (int i:e1[p]) t[p]=min(t[p],t[i]); }
                T.add(dfn[fa[p]]+1,dfn[p]-1,t[p]-dep[fa[p]]*2);
                T.add(dfm[p]+1,dfm[fa[p]],t[p]-dep[fa[p]]*2); // 以上是第一种转移
                int len=e1[p].size(),sum=0,mus=0,le=0;
                for (int i=0,j=0;i<len;i++) {
                    tmp[i]=1,is[i]=1;
                    while (j<to[p].size()&&dfn[to[p][j]]<=dfm[e1[p][i]])
                        tmp[i]&=g[to[p][j++]],is[i]=0;
                    if (is[i]) tmp[i]=mnn[e1[p][i]]-dep[p]<=b,mus+=tmp[i];
                    else sum+=tmp[i],le++;
                }
                for (int i=0;i<len;i++) {
                    int w=0;
                    if (!is[i]) {
                        if (sum-tmp[i]!=le-1);
                        else if (le-1) w=1;
                        else w|=mus||p<=mid;
                    }
                    else {
                        if (sum!=le);
                        else if (le) w=1;
                        else w|=mus-tmp[i]||p<=mid;
                    }
                    if (w) T.add(dfn[e1[p][i]],dfm[e1[p][i]],-I);
                } // 以上是第二种转移
            }
        }
        for (int p:de[d]) {
            for (int i:e1[p]) mnn[p]=min(mnn[p],mnn[i]);
            if (to[p].size()) {
                f[p]=1;
                for (int u:to[p]) {
                    int w=min(t[u]-dep[u],dep[u]+T.que(dfn[u]));
                    if (w<=b) g[u]=1;
                    else f[p]=0,g[u]=0;
                }
            }
            else f[p]=mnn[p]-dep[p]<=b;
        }
    }
    return f[rt];
}
void mian() {
    scanf("%d%d%d%d",&n,&rt,&a,&b),cnn=0;
    for (int i=1;i<=n;i++) 
        e[i].clear(),e1[i].clear(),to[i].clear(),de[i].clear();
    for (int i=1,x,y;i<n;i++) scanf("%d%d",&x,&y),e[x].pb(y),e[y].pb(x);
    if (b>=a) return cout<<"1\n",void();
    dfs(rt,0);
    int l=1,r=n,mid,res=0;
    while (l<=r) {
        mid=(l+r)>>1;
        if (chk(mid)) res=mid,r=mid-1;
        else l=mid+1;
    }
    cout<<res<<"\n";
}
bool ORY; int main() {
    // while (1)
    // int t; for (scanf("%d",&t);t--;)
    mian();
    cerr<<"\n"<<abs(&MEM-&ORY)/1048576<<"MB "<<clock()<<"ms\n";
    return 0;
}