P9208 虚人「无」
线段树
乘法逆元
用 dfs 序把树映射到序列上,然后分别维护
使用逆元求解的条件是所有参与计算的值都存在逆元,需要模数与任意
线段树
现在需要一种不依赖逆元的方法快速求出某个区间内数字的乘积,不难想到线段树。
前置知识是线段树维护区间乘积,这里是模板题。
用线段树维护区间内
记得对答案取模,题目是 ACM 赛制的,同样的代码多次 WA 反馈的数据点可能不同,对查错不是很友好。如果被卡常,使用链式前向星代替 vector 可以获得一定程度提速。
时间复杂度
#include<bits/stdc++.h>
using namespace std;
struct Segment{int l,r,c,v;}s[1200001];
constexpr Segment emp=Segment{0,0,1,1};
int n,mod,pos,c[300001],v[300001],w[300001],ex[300001],dfn[300001];
vector<int> e[300001];//ex[x]记录dfs序中位置x对应点的编号,建树用
long long ans;
void dfs(int x,int las){
w[x]=1,dfn[x]=++pos,ex[pos]=x;
for(int i:e[x]) if(i!=las) dfs(i,x),w[x]+=w[i];
}
void pushup(Segment &u,const Segment &ls,const Segment &rs){
u.c=1ll*ls.c*rs.c%mod,u.v=1ll*ls.v*rs.v%mod;
}
void build(int u,int l,int r){
s[u].l=l,s[u].r=r;
if(l==r) return s[u].c=c[ex[l]],s[u].v=v[ex[l]],void();
build(u*2,l,(l+r)/2);
build(u*2+1,(l+r)/2+1,r);
pushup(s[u],s[u*2],s[u*2+1]);
}
Segment query(int u,int l,int r){
if(s[u].l>r||s[u].r<l) return emp;
if(s[u].l>=l&&s[u].r<=r) return s[u];
auto ret=emp,ls=query(u*2,l,r),rs=query(u*2+1,l,r);
pushup(ret,ls,rs);
return ret;
}
int main(){
scanf("%d%d",&n,&mod);
for(int i=1;i<=n-1;i++){
static int x,y;
scanf("%d%d",&x,&y);
e[x].push_back(y);
e[y].push_back(x);
}
for(int i=1;i<=n;i++) scanf("%d",&c[i]);
for(int i=1;i<=n;i++) scanf("%d",&v[i]);
dfs(1,0),build(1,1,n),ans=s[1].c;//根节点答案独立统计
for(int i=2;i<=n;i++){
static Segment lot,son,rot;//dfs序左边 子树 dfs序右边
lot=query(1,1,dfn[i]-1);
son=query(1,dfn[i],dfn[i]+w[i]-1);
rot=query(1,dfn[i]+w[i],n);
ans+=1ll*lot.v*rot.v%mod*son.c%mod;
}
ans%=mod;
printf("%lld\n",ans);
return 0;
}
前缀积 & 后缀积
Updated on 2023/4/15
可以发现线段树只需要维护
#include<bits/stdc++.h>
using namespace std;
struct Segment{int l,r,w;}s[1200001];
int n,mod,pos,c[300001],v[300001],w[300001],ex[300001],dfn[300001],mul[300001],pre[300005],suf[300005];
vector<int> e[300001];
long long ans;
void dfs(int x,int las){
w[x]=1,dfn[x]=++pos,ex[pos]=x;
for(int i:e[x]) if(i!=las) dfs(i,x),w[x]+=w[i];
}
void build(int u,int l,int r){
s[u].l=l,s[u].r=r;
if(l==r) return s[u].w=c[ex[l]],mul[l]=v[ex[l]],void();
build(u*2,l,(l+r)/2),build(u*2+1,(l+r)/2+1,r);
s[u].w=1ll*s[u*2].w*s[u*2+1].w%mod;
}
int query(int u,int l,int r){
if(s[u].l>r||s[u].r<l) return 1;//注意返回值
if(s[u].l>=l&&s[u].r<=r) return s[u].w;
return 1ll*query(u*2,l,r)*query(u*2+1,l,r)%mod;
}
int main(){
scanf("%d%d",&n,&mod);
for(int i=1;i<=n-1;i++){
static int x,y;
scanf("%d%d",&x,&y);
e[x].push_back(y);
e[y].push_back(x);
}
for(int i=1;i<=n;i++) scanf("%d",&c[i]);
for(int i=1;i<=n;i++) scanf("%d",&v[i]);
dfs(1,0),build(1,1,n),ans=s[1].w;
pre[0]=1,suf[n+1]=1;//在pre和suf的越界处设置极值,无需特判
for(int i=1;i<=n;i++) pre[i]=1ll*mul[i]*pre[i-1]%mod;//前缀积
for(int i=n;i>=1;i--) suf[i]=1ll*mul[i]*suf[i+1]%mod;//后缀积
for(int i=2;i<=n;i++){
static int lot,son,rot;
lot=pre[dfn[i]-1];
son=query(1,dfn[i],dfn[i]+w[i]-1);
rot=suf[dfn[i]+w[i]];
ans+=1ll*lot*rot%mod*son%mod;
}
ans%=mod;
printf("%lld\n",ans);
return 0;
}