CF1254D 根号分治+序列分块

· · 题解

\tt Link

考虑分析一次修改的本质是什么:

不妨让每个点加的值从 \dfrac{n-s}n 变成 n-s,最后乘上一个 n^{-1}\bmod998244353 即可。

我们记一个点的度数为 d_u。然后考虑两种算法:

哦,但是还要注意如果发生过修改 \{u,x\},那么 u 的答案最后要加上 x

这两种算法的复杂度分别是 \mathcal O(kd_u)\mathcal O(t),其中 t 是你选择枚举的点数,k 是你选择的数据结构修改的复杂度。

这并不优。于是常规手段,根号摊平。(以下的查询指计算这个点对其它点的贡献)

数据结构我们希望做到修改 \mathcal O(1),这样才能保证最后根号的复杂度。

根号题用根号 \tt DS 是极好的,我们选择采用序列分块,这个子树加单点和是常规的 \tt dfs 时间戳。

在补充说说。暴力 \tt2 里有一步是对于 u,v 要找到一个 w,满足 wv 的儿子,也是 u 的祖先。这一步其实挺像 \tt LCA,考虑树剖解决:

我们魔改树求 \tt LCA,在最后一步,两个点都在一条重链的时候,深度小的那个点在重链上的儿子就是答案。

还有一种情况,就是它们中深度小的那个在重链底,这时答案是它的某个轻儿子,我们无法得知。这种情况需要在跳重链的时候,如果某一步出现 x 的重链顶的父亲是 y,就不跳了,返回 x 的重链顶就是答案。

我觉得写的可读性已经很高了。

如果 \texttt{WA on test 6} 且某个结果你输出了负数,那么不妨检查哪一步没有取模,并且最后判断:如果自己的答案是负数,应该加上模数保证输出的结果不是负数。

const int N = 1.5e5 + 5,SQ = 405;
const ll mod = 998244353;

int n,m; ll inv_n,add[N];
vector<int> G[N],big;

int B,C,st[SQ],ed[SQ],bl[N];
ll sum[SQ],cnt[N];

int dep[N],siz[N],son[N],fa[N];
int dfn[N],top[N],tot = 0;

#define Big(i) ((signed)G[i].size() > B)
#define rdfn(u) dfn[u] + siz[u] - 1
#define subt(u) dfn[u],rdfn(u)

namespace HLC{
    void dfs1(int u = 1,int ft = 0){
        dep[u] = dep[fa[u] = ft] + (siz[u] = 1);
        for(int v : G[u]) if(v != ft){
            dfs1(v,u); siz[u] += siz[v];
            if(siz[v] > siz[son[u]]) son[u] = v;
        } if(Big(u)) big.push_back(u);
    }

    void dfs2(int u = 1,int tp = 1){
        dfn[u] = ++tot; top[u] = tp;
        if(son[u]) dfs2(son[u],tp);
        for(int v : G[u]) if(!top[v]) dfs2(v,v);
    }

    int son_lca(int x,int y){
        while(top[x] != top[y]){
            if(dep[top[x]] < dep[top[y]]) swap(x,y);
            if(fa[top[x]] == y) return top[x];
            x = fa[top[x]];
        } return son[dep[x] < dep[y] ? x : y];
    }
} using namespace HLC;

namespace DS{
    void init(){
        B = sqrt(n); C = (n - 1) / B + 1;
        rep(i,1,C){
            st[i] = ed[i - 1] + 1;
            ed[i] = i == C ? n : i * B;
            rep(j,st[i],ed[i]) bl[j] = i;
        }
    }

    void upd(int l,int r,ll x){
        x %= mod; ++r;
        (cnt[l] += x) %= mod;
        (cnt[r] += mod - x) %= mod;
        (sum[bl[l]] += x) %= mod;
        (sum[bl[r]] += mod - x) %= mod;
    }

    ll qry(int p){
        ll res = 0;
        rep(i,1,bl[p] - 1) (res += sum[i]) %= mod;
        rep(i,st[bl[p]],p) (res += cnt[i]) %= mod;
        return res;
    }
} using namespace DS;

void Upd(int u,ll x){
    (add[u] += x) %= mod; if(Big(u)) return;
    for(int v : G[u]) if(v != fa[u])
        DS::upd(subt(v),x * (n - siz[v]));
    if(u == 1) return;
    ll y = x * siz[u]; upd(subt(1),y); upd(subt(u),-y);
}

ll Qry(int u){
    ll res = qry(dfn[u]);
    for(int v : big) if(u != v){
        if(!(dfn[v] <= dfn[u] && rdfn(u) <= rdfn(v)))
            (res += add[v] * siz[v] % mod) %= mod;
        else (res += add[v] * (n - siz[son_lca(u,v)]) % mod) %= mod;
    } res = ((inv_n * res % mod) + add[u]) % mod;
    return (res % mod + mod) % mod;
}

ll fpow(ll v,int b){
    ll res = 1;
    while(b){
        if(b & 1) (res *= v) %= mod;
        (v *= v) %= mod; b >>= 1;
    }
    return res;
}

signed main(){
    read(n,m);
    inv_n = fpow(n,mod - 2);
    rep(i,1,n - 1){
        int u,v; read(u,v);
        G[u].push_back(v);
        G[v].push_back(u);
    }
    init(); dfs1(); dfs2();
    while(m--){
        if(read() == 1){ int x = read(); Upd(x,read()); }
        else print(Qry(read()),'\n');
    }
    return 0;
}