AT_abc310_g 题解

· · 题解

鉴定为,不用动脑子题。

首先期望是假的,直接求出 1\sim K 时刻每个人得到的球数总和再除以 K 即可。

那么考虑对于每个人拥有的球,算他对所有人的贡献即可。

考虑连边 i\to A_i,发现构成一个内向基环树森林。

考虑一个内向基环树上的某个点 i,如果他在环上,发现 B_i 会对环上的所有点都贡献至少 \lfloor\frac{K}{\text{len}}\rfloor 次,其中 \text{len} 是环长,这部分直接整体打 tag。

剩下的 K\bmod \text{len} 次操作形如前后缀加,随便给环上的点编号然后对于每个环开一个树状数组即可。

对于不在环上的点,他在整个过程中只会经过他同样不在环上的且距离他小于等于 K 的祖先一次,这个部分随便树剖下就做完了。他对环的贡献与上一段内容完全相同。

好像可以直接差分,但是无所谓,复杂度 O(n\log^2n),常数比较小。

#include "bits/stdc++.h"
#define int long long 
#define f(i ,m ,n ,x) for (int i = (m) ,i##END = (n) ; i <= i##END ; i += (x))
#define f_(i ,m ,n ,x) for (int i = (m) ,i##END = (n) ; i >= i##END ; i -= (x))

const int N = 2e5 + 25 ,mod = 998244353 ;
int n ,k ,a[N] ,b[N] ,dfn[N] ,hson[N] ,top[N] ,siz[N] ,dep[N] ,tag[N] ;
int id[N] ,rt[N] ,tot ,incir[N] ,cnt ,len[N] ,end[N] ,beg[N] ; bool vis[N] ;
std :: vector < int > to[N] ;

class Bit {
private :
    int ans[N] ;

    inline int Lowbit (int x) { return x & (-x) ;}

    inline void Update (int x ,int v) 
    { f (i ,x ,n ,Lowbit (i)) (ans[i] += v) %= mod ;}

public :
    inline void Modify (int l ,int r ,int v) 
    { return Update (l ,v) ,Update (r + 1 ,(mod - v) % mod) ;}

    inline int Query (int x) {
        int res = 0 ; 
        f_ (i ,x ,1 ,Lowbit (i)) (res += ans[i]) %= mod ;
        return res ;
    } 
} bit1 ,bit2 ;

signed main (void) {
    std :: ios :: sync_with_stdio (false) ,
    std :: cin.tie (nullptr) ,std :: cout.tie (nullptr) ;

    std :: cin >> n >> k ; 
    f (i ,1 ,n ,1) std :: cin >> a[i] ; f (i ,1 ,n ,1) std :: cin >> b[i] ;
    f (i ,1 ,n ,1) to[a[i]].emplace_back (i) ;

    auto Dfs1 = [] (auto &&self ,int cur ,int dad ,int Root) -> void {
        vis[cur] = true ,dep[cur] = dep[dad] + 1 ,rt[cur] = Root ,siz[cur] = 1 ; 
        for (auto nex : to[cur]) {
            self (self ,nex ,cur ,Root) ;
            (siz[nex] > siz[hson[cur]]) && (hson[cur] = nex) ,
            siz[cur] += siz[nex] ;
        }
    } ;

    auto Dfs2 = [] (auto &&self ,int cur ,int tp) -> void {
        dfn[cur] = ++tot ,top[cur] = tp ;
        if (hson[cur]) self (self ,hson[cur] ,tp) ;
        else return ;
        for (auto nex : to[cur]) if (nex != hson[cur])
            self (self ,nex ,nex) ;
    } ;

    int Id = 0 ; f (i ,1 ,n ,1) {
        if (!vis[i]) {
            cnt++ ;
            static std :: vector < int > vec ,cir ; vec.clear () ,cir.clear () ;
            int cur = i ; while (!vis[cur])
                vec.emplace_back (cur) ,vis[cur] = true ,cur = a[cur] ;

            int x = 0 ; while ((x = vec.back ()) != cur)
                incir[x] = cnt ,cir.emplace_back (x) ,vec.pop_back () ;
            incir[cur] = cnt ,cir.emplace_back (cur) ;

            len[cnt] = cir.size () ;
            int ptr = len[cnt] ; for (auto pt : cir) id[pt] = Id + ptr-- ;
            beg[cnt] = Id + 1 ,Id += len[cnt] ,end[cnt] = Id ;
            for (auto pt : cir) for (auto nex : to[pt]) if (!incir[nex])
                Dfs1 (Dfs1 ,nex ,pt ,nex) ,Dfs2 (Dfs2 ,nex ,nex) ;
        } 

        if (!incir[i]) {
            int cur = a[i] ;
            if (!incir[cur]) {
                int sum = std :: min (k ,dep[cur]) ;
                while (sum) {
                    int tp = top[cur] ;
                    if (sum >= dep[cur] - dep[tp] + 1)
                        bit1.Modify (dfn[tp] ,dfn[cur] ,b[i]) ,sum -= dep[cur] - dep[tp] + 1 ,cur = a[tp] ;
                    else bit1.Modify (dfn[cur] - sum + 1 ,dfn[cur] ,b[i]) ,sum = 0 ;
                }
            }

            int rst = k - (dep[i] - 1) ;
            if (rst > 0) {
                int bel = incir[a[rt[i]]] ;
                (tag[bel] += ((rst / len[bel]) % mod) * b[i] % mod) %= mod ;
                int md = rst % len[bel] ;
                if (md) {
                    bit2.Modify (id[a[rt[i]]] ,std :: min (id[a[rt[i]]] + md - 1 ,end[bel]) ,b[i]) ;
                    md -= end[bel] - id[a[rt[i]]] + 1 ; if (md > 0)
                        bit2.Modify (beg[bel] ,beg[bel] + md - 1 ,b[i]) ;
                }
            }
        } else {
            int rst = k ;
            int bel = incir[i] ;
            (tag[bel] += ((rst / len[bel]) % mod) * b[i] % mod) %= mod ;
            int md = rst % len[bel] ;
            if (md) {
                bit2.Modify (id[a[i]] ,std :: min (id[a[i]] + md - 1 ,end[bel]) ,b[i]) ;
                md -= end[bel] - id[a[i]] + 1 ; if (md > 0)
                    bit2.Modify (beg[bel] ,beg[bel] + md - 1 ,b[i]) ;
            }
        }
    }

    int inv = [] (int bs ,int p = mod - 2) -> int {
        int ans = 1 ; bs %= mod ;
        for (; p ; p >>= 1 ,bs = bs * bs % mod)
            (p & 1) && (ans = ans * bs % mod) ; return ans ;
    } (k) ; f (i ,1 ,n ,1) {
        if (incir[i]) {
            int ans = (tag[incir[i]] + bit2.Query (id[i])) % mod ;
            std :: cout << ans * inv % mod << " \n"[i == n] ;
        } else std :: cout << bit1.Query (dfn[i]) * inv % mod << " \n"[i == n] ;
    }
    return 0 ;
}