P9208 虚人「无」

· · 题解

线段树

乘法逆元

用 dfs 序把树映射到序列上,然后分别维护 c,v 的前缀积,由于节点的子树在 dfs 序中为一段连续区间,可以用逆元取出子树中 c 的乘积,同样可以用逆元计算不包括当前子树所有 v 的乘积,时间复杂度 O(n\log n),线性求逆元可将复杂度降至 O(n)

使用逆元求解的条件是所有参与计算的值都存在逆元,需要模数与任意 c,v 互质。题目中没有给出相关性质,即使模数为质数也可以构造出不存在逆元的情况。

线段树

现在需要一种不依赖逆元的方法快速求出某个区间内数字的乘积,不难想到线段树。

前置知识是线段树维护区间乘积,这里是模板题。

用线段树维护区间内 c,v 的乘积,支持区间查询即可,不需要延迟标记。

记得对答案取模,题目是 ACM 赛制的,同样的代码多次 WA 反馈的数据点可能不同,对查错不是很友好。如果被卡常,使用链式前向星代替 vector 可以获得一定程度提速。

时间复杂度 O(n\log n)

#include<bits/stdc++.h>
using namespace std;
struct Segment{int l,r,c,v;}s[1200001];
constexpr Segment emp=Segment{0,0,1,1};
int n,mod,pos,c[300001],v[300001],w[300001],ex[300001],dfn[300001];
vector<int> e[300001];//ex[x]记录dfs序中位置x对应点的编号,建树用
long long ans;
void dfs(int x,int las){
    w[x]=1,dfn[x]=++pos,ex[pos]=x;
    for(int i:e[x]) if(i!=las) dfs(i,x),w[x]+=w[i];
}
void pushup(Segment &u,const Segment &ls,const Segment &rs){
    u.c=1ll*ls.c*rs.c%mod,u.v=1ll*ls.v*rs.v%mod;
}
void build(int u,int l,int r){
    s[u].l=l,s[u].r=r;
    if(l==r) return s[u].c=c[ex[l]],s[u].v=v[ex[l]],void();
    build(u*2,l,(l+r)/2);
    build(u*2+1,(l+r)/2+1,r);
    pushup(s[u],s[u*2],s[u*2+1]);
}
Segment query(int u,int l,int r){
    if(s[u].l>r||s[u].r<l) return emp;
    if(s[u].l>=l&&s[u].r<=r) return s[u];
    auto ret=emp,ls=query(u*2,l,r),rs=query(u*2+1,l,r);
    pushup(ret,ls,rs);
    return ret;
}
int main(){
    scanf("%d%d",&n,&mod);
    for(int i=1;i<=n-1;i++){
        static int x,y;
        scanf("%d%d",&x,&y);
        e[x].push_back(y);
        e[y].push_back(x);
    }
    for(int i=1;i<=n;i++) scanf("%d",&c[i]);
    for(int i=1;i<=n;i++) scanf("%d",&v[i]);
    dfs(1,0),build(1,1,n),ans=s[1].c;//根节点答案独立统计 
    for(int i=2;i<=n;i++){
        static Segment lot,son,rot;//dfs序左边 子树 dfs序右边 
        lot=query(1,1,dfn[i]-1);
        son=query(1,dfn[i],dfn[i]+w[i]-1);
        rot=query(1,dfn[i]+w[i],n);
        ans+=1ll*lot.v*rot.v%mod*son.c%mod;
    }
    ans%=mod;
    printf("%lld\n",ans);
    return 0;
}

前缀积 & 后缀积

Updated on 2023/4/15

可以发现线段树只需要维护 c 的区间乘积,dfs 序左右两端的查询实际上是 v 的一段前缀积和一段后缀积。维护 v 的前缀积和后缀积可以大幅降低常数。

#include<bits/stdc++.h>
using namespace std;
struct Segment{int l,r,w;}s[1200001];
int n,mod,pos,c[300001],v[300001],w[300001],ex[300001],dfn[300001],mul[300001],pre[300005],suf[300005];
vector<int> e[300001];
long long ans;
void dfs(int x,int las){
    w[x]=1,dfn[x]=++pos,ex[pos]=x;
    for(int i:e[x]) if(i!=las) dfs(i,x),w[x]+=w[i];
}
void build(int u,int l,int r){
    s[u].l=l,s[u].r=r;
    if(l==r) return s[u].w=c[ex[l]],mul[l]=v[ex[l]],void();
    build(u*2,l,(l+r)/2),build(u*2+1,(l+r)/2+1,r);
    s[u].w=1ll*s[u*2].w*s[u*2+1].w%mod;
}
int query(int u,int l,int r){
    if(s[u].l>r||s[u].r<l) return 1;//注意返回值 
    if(s[u].l>=l&&s[u].r<=r) return s[u].w;
    return 1ll*query(u*2,l,r)*query(u*2+1,l,r)%mod;
}
int main(){
    scanf("%d%d",&n,&mod);
    for(int i=1;i<=n-1;i++){
        static int x,y;
        scanf("%d%d",&x,&y);
        e[x].push_back(y);
        e[y].push_back(x);
    }
    for(int i=1;i<=n;i++) scanf("%d",&c[i]);
    for(int i=1;i<=n;i++) scanf("%d",&v[i]);
    dfs(1,0),build(1,1,n),ans=s[1].w;
    pre[0]=1,suf[n+1]=1;//在pre和suf的越界处设置极值,无需特判 
    for(int i=1;i<=n;i++) pre[i]=1ll*mul[i]*pre[i-1]%mod;//前缀积 
    for(int i=n;i>=1;i--) suf[i]=1ll*mul[i]*suf[i+1]%mod;//后缀积 
    for(int i=2;i<=n;i++){
        static int lot,son,rot;
        lot=pre[dfn[i]-1];
        son=query(1,dfn[i],dfn[i]+w[i]-1);
        rot=suf[dfn[i]+w[i]];
        ans+=1ll*lot*rot%mod*son%mod;
    }
    ans%=mod;
    printf("%lld\n",ans);
    return 0;
}