磁控法则 题解

· · 题解

显然我们并不需要具体是哪个磁极,只需要只要磁极是否相同。

y 为树的根,设 f_{i,0/1} 表示从 i 走到 y,磁极相同/不同时的期望步数。

\begin{cases}f_{i,0}=pf_{fa_i,1}+(1-p)(f_{i,0}+1)+1,f_{i,1}=(1-p)f_{fa_i,1}+p(f_{i,0}+1)+1&i\text{是叶子节点}\\f_{i,0}=pf_{fa_i,1}+(1-p)\sum\dfrac{f_{j,0}}{so_i}+1,f_{i,1}=(1-p)f_{fa_i,1}+p\sum\dfrac{f_{j,0}}{so_i}+1&i\text{不是叶子节点}\end{cases},其中 ji 的儿子,fa_ii 的父亲,so_ii 的儿子节点个数。特别的,f_{y,0}=f_{y,1}=0

Subtask 1

直接写式子然后高斯消元。

Subtask 2&3

(我感觉这两部分做法是一样的)

考虑人脑消元。

对叶子节点的式子变形,有 f_{i,0}=f_{fa_i,1}+\dfrac{2-p}{p},f_{i,1}=f_{fa_i,1}+3

发现递推式中只涉及到 f_{fa_i,1} 以及常数项,考虑归纳证明每个节点的递推式只与其父亲有关。

对于非叶子节点 i,假设其后代节点均满足此规律。设 \sum\dfrac{f_{j,0}}{so_i}=k_1f_{i,1}+k_2

f_{i,0}=pf_{fa_i,1}+(1-p)(k_1f_{i,1}+k_2)+1,f_{i,1}=(1-p)f_{fa_i,1}+p(k_1f_{i,1}+k_2)+1

变形,得 f_{i,0}=\dfrac{p+k_1-2pk_1}{1-pk_1}f_{fa_i,1}+\dfrac{k_1+k_2-2pk_1-pk_2+1}{1-pk_1},f_{i,1}=\dfrac{1-p}{1-pk_1}f_{fa_i,1}+\dfrac{1+pk_2}{1-pk_1},得证。

因此我们只需要先进行一次树形dp求出递推式的系数和常数,然后再递推得出结果。

Subtask 2 直接以每个点为根的情况各进行一次dp,Subtask 3 直接以 1 为根进行dp即可。

Subtask 4

容易发现每个点的递推式实际上只和根的方向有关。

所以可以用换根dp求出所有的递推式再利用一些奇怪的操作来回答每个询问。

这样做还是很麻烦,所以我们考虑再简化递推式。

实际上把每个点的递推式的系数打表打出来会发现都为 1

同样考虑归纳证明容易证得,于是得到 k_1=1

于是有递推式 f_{i,0}=f_{fa_i,1}+k_2+2,f_{i,1}=f_{fa_i,1}+\dfrac{1+pk_2}{1-p}

于是这道题就成了一道换根dp+树上链求和的板题,预处理逆元可以做到 O(n+q\log n)

当然具体实现还有很多细节(特别是要分清 f_{i,0}f_{i,1},以及注意链跳的方向)。

#include<bits/stdc++.h>
using namespace std;
#define mod 998244353
int lw[1000005],bi[2000005][2],bs;
int s[1000005],tf[1000005],f[1000005][2],ff[1000005][2],sf[1000005][2];
int fa[1000005],so[1000005],si[1000005],de[1000005],to[1000005];
int n,q,p,np,nfp,ny[1000005];
int dr()
{
    int xx=0;char ch=getchar();
    while(ch<'0'||ch>'9')ch=getchar();
    while(ch>='0'&&ch<='9')xx=(xx<<1)+(xx<<3)+ch-'0',ch=getchar();
    return xx;
}
int P(int x,int y=mod-2)
{
    int z=1;
    for(;y;x=1ll*x*x%mod,y>>=1)if(y&1)z=1ll*z*x%mod;
    return z;
}
void tj(int u,int v){++bs,bi[bs][0]=lw[u],bi[bs][1]=v,lw[u]=bs;}
void dfs1(int w,int fath,int d)
{
    fa[w]=fath,si[w]=1,de[w]=d;
    int x;
    for(int o_o=lw[w];o_o;o_o=bi[o_o][0])
    {
        int v=bi[o_o][1];
        if(v!=fath)
        {
            ++s[w],dfs1(v,w,d+1),tf[w]=(tf[w]+f[v][0])%mod,si[w]+=si[v];
            if(si[v]>si[so[w]])so[w]=v;
        }
    }
    if(s[w])x=1ll*tf[w]*ny[s[w]]%mod,f[w][0]=(2+x)%mod,f[w][1]=1ll*(1+1ll*p*x)%mod*nfp%mod;
    else f[w][0]=1ll*(2-p+mod)*np%mod,f[w][1]=3;
}
void dfs2(int w)
{
    if(s[w]==0)return;
    int ns=ny[s[w]-(fa[w]?0:1)],tt=(tf[w]+ff[w][0])%mod,x;
    for(int o_o=lw[w];o_o;o_o=bi[o_o][0])
    {
        int v=bi[o_o][1];
        if(v!=fa[w])
        {
            if(fa[w]||s[w]>1)x=1ll*(tt-f[v][0]+mod)*ns%mod,ff[v][0]=(2+x)%mod,ff[v][1]=1ll*(1+1ll*p*x)%mod*nfp%mod;
            else ff[v][0]=1ll*(2-p+mod)*np%mod,ff[v][1]=3;
            dfs2(v);
        }
    }
}
void dfs(int w,int t)
{
    to[w]=t,sf[w][0]=(sf[fa[w]][0]+f[w][1])%mod,sf[w][1]=(sf[fa[w]][1]+ff[w][1])%mod;
    if(so[w])dfs(so[w],t);
    for(int o_o=lw[w];o_o;o_o=bi[o_o][0])
    {
        int v=bi[o_o][1];
        if(v!=fa[w]&&v!=so[w])dfs(v,v);
    }
}
int lca(int x,int y)
{
    int fx=to[x],fy=to[y];
    while(fx!=fy)if(de[fx]>de[fy])x=fa[fx],fx=to[x];else y=fa[fy],fy=to[y];
    return de[x]<de[y]?x:y;
}
int gs(int x,int y)
{
    int fx=to[x],fy=to[y];
    while(fx!=fy)
    {
        if(fa[fx]==y)return fx;
        x=fa[fx],fx=to[x];
    }
    return so[y];
}
int main()
{
    n=dr(),q=dr(),p=dr(),np=P(p),nfp=P(mod+1-p);
    ny[1]=1;
    for(int i=2;i<=n;i++)ny[i]=1ll*(mod-mod/i)*ny[mod%i]%mod;
    for(int i=1;i<n;i++)
    {
        int u=dr(),v=dr();
        tj(u,v),tj(v,u);
    }
    dfs1(1,0,1),dfs2(1),dfs(1,1);
    while(q--)
    {
        int x=dr(),y;char c1=getchar(),c2;
        while(c1!='N'&&c1!='S')c1=getchar();
        y=dr(),c2=getchar();
        while(c2!='N'&&c2!='S')c2=getchar();
        int a=lca(x,y),ans=((sf[x][0]-sf[a][0]+sf[y][1]-sf[a][1])%mod+mod)%mod;
        if(c1==c2)
        {
            if(x==a)
            {
                int b=gs(y,a);
                ans=((ans-ff[b][1]+ff[b][0])%mod+mod)%mod;
            }
            else ans=((ans-f[x][1]+f[x][0])%mod+mod)%mod;
        }
        printf("%d\n",ans);
    }
}

关于为什么 p\geq 2

其实就是直接帮你们排除了 p=0 以及 p=1 的情况,因为上面所有的式子变形都基于 p\neq 0p\neq 1