题解:P8334 [ZJOI2022] 深搜
怎么大家都会拆贡献?来写一篇与拆贡献无关的做法。
若对每个点
我们不妨对每个
考虑上述过程的优化:对于一个贡献点,每次加入新的贡献点导致其权值改变时,其变化量不与
放一份
#include <bits/stdc++.h>
using namespace std;
const int N=4e5+5,mod=998244353;
inline int read(){
int x=0,f=1;char ch=getchar();
while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
while(ch>='0'&&ch<='9')x=(x<<3)+(x<<1)+(ch^48),ch=getchar();
return x*f;
}
inline int ksm(int a,int b){
int ans=1;
while(b){
if(b&1)ans=1ll*ans*a%mod;
a=1ll*a*a%mod;
b>>=1;
}return ans;
}
int T,n,rt,a[N],mn[N],ans,cnt,inv[N],vvv[N],d[N];
vector<int>e[N];
inline void dfs(int x,int fa){
mn[x]=a[x];
if(fa)mn[x]=min(mn[x],a[fa]);
for(int y : e[x]){
if(y^fa)d[y]=d[x]+1,dfs(y,x),mn[x]=min(mn[x],mn[y]);
}
}
struct node{
int val,w,w2,op,d,id;
}tmp[N];
struct DATA{
int val,d,id;
}p[N<<2];
inline bool cmp(int a,int b){
return mn[a]<mn[b];
}
inline void insert(int val,int w,int w2,int d,int r,int id){
for(int j = 1;j<=cnt;j++)if(tmp[j].val>val||tmp[j].val==val&&(tmp[j].d<d||tmp[j].d==d&&tmp[j].id>id)){
if(w2==0)tmp[j].op++;
else tmp[j].w=tmp[j].w*1ll*w2%mod;
}
tmp[++cnt]={val,0,w2,0,d,id};
int ggg=d;
for(int j = 1;j<cnt;j++)if((tmp[j].val<val||tmp[j].d==d&&tmp[j].id<id)&&!tmp[j].op)ggg=(ggg-tmp[j].w+mod)%mod;
tmp[cnt].w=ggg*1ll*r%mod;
}
inline void del(int id){
swap(tmp[id],tmp[cnt]);
int w=ksm(tmp[cnt].w2,mod-2);
for(int i = 1;i<cnt;i++)if(tmp[i].val>tmp[cnt].val||tmp[i].val==tmp[cnt].val&&(tmp[i].d<tmp[cnt].d||tmp[i].d==tmp[cnt].d&&tmp[i].id>tmp[cnt].id)){
if(tmp[cnt].w2==0)tmp[i].op--;
else tmp[i].w=1ll*tmp[i].w*1ll*w%mod;
}
cnt--;
}
inline void dfs2(int x,int fa){
int gg=d[x];
for(int i = 1;i<=cnt;i++){
if(!tmp[i].op&&tmp[i].val<a[x])ans=(ans + tmp[i].val*1ll*tmp[i].w)%mod,gg=(gg+mod-tmp[i].w)%mod;
}
ans=(ans+1ll*a[x]*gg)%mod;
vector<int>v;
for(auto y : e[x])if(y!=fa)v.push_back(y);
sort(v.begin(),v.end(),cmp);
if(v.size())insert(a[x],inv[v.size()],0,d[x],1,v.size());
for(int i = v.size()-1,y;i>=1;i--){
y=v[i];
insert(mn[y],inv[i]*1ll*inv[i+1]%mod,i*1ll*inv[i+1]%mod,d[x],inv[i+1],i);
}
int now = cnt;
for(int i = 0,y;i<v.size();i++){
y=v[i];
dfs2(y,x);
del(now--);
insert(mn[y],inv[i+1]*1ll*inv[i+2]%mod,(i+1ll)*inv[i+2]%mod,d[x],inv[i+2],i);
}
for(int i = 0;i<v.size();i++)del(cnt);
}
int main(){
T=read();
while(T--){
n=read(),rt=read();ans=0;
inv[0]=inv[1]=1;
cnt=0;
for(int i = 2;i<=n+1;i++)inv[i]=inv[mod%i]*1ll*(mod-mod/i)%mod;
for(int i = 1;i<=n;i++)e[i].clear(),a[i]=read();
for(int i = 1,x,y;i<n;i++)x=read(),y=read(),e[x].push_back(y),e[y].push_back(x);
d[rt]=1;dfs(rt,0);
dfs2(rt,0);
printf("%d\n",ans);
}
return 0;
}
不难发现,贡献点之间影响操作为区间乘,单点改,询问操作为区间和。不过需要维护乘
时空复杂度:
贴一份代码:
#include <bits/stdc++.h>
using namespace std;
const int N=4e5+5,mod=998244353;
inline int read(){
int x=0,f=1;char ch=getchar();
while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
while(ch>='0'&&ch<='9')x=(x<<3)+(x<<1)+(ch^48),ch=getchar();
return x*f;
}
inline int ksm(int a,int b){
int ans=1;
while(b){
if(b&1)ans=1ll*ans*a%mod;
a=1ll*a*a%mod;
b>>=1;
}return ans;
}
inline int md(int x){
return x>=mod?x-mod:x;
}
int T,n,rt,a[N],mn[N],ans,cnt,inv[N],vvv[N],d[N],num,q[N],C;
vector<int>e[N];
inline void dfs(int x,int fa){
mn[x]=a[x];
if(fa)mn[x]=min(mn[x],a[fa]);
for(int y : e[x]){
if(y^fa)d[y]=d[x]+1,dfs(y,x),mn[x]=min(mn[x],mn[y]);
}
}
inline bool cmp(int a,int b){
return mn[a]<mn[b];
}
struct node{
int val,w,w2,op,d,id;
}tmp[N];
struct DATA{
int val,d,id;
inline bool operator<(DATA b){
if(val^b.val)return val<b.val;
if(d^b.d)return d>b.d;
return id<b.id;
}
}p[N<<1];
struct tree{
int l,r,tag1,tag2,s1,s3;
}t[N<<3];
#define mid (t[p].l+t[p].r>>1)
inline void up(int p){
t[p].s1=md(t[p<<1].s1+t[p<<1|1].s1),t[p].s3=md(t[p<<1].s3+t[p<<1|1].s3);
}
inline void build(int p,int l,int r){
t[p]={l,r,0,1,0,0};
if(l==r)return;
build(p<<1,l,mid),build(p<<1|1,mid+1,r);
}
vector<pair<int,tree>>g[N];
inline void cg2(int p,int v);
inline void cg1(int p,int v){
g[v].push_back({p,t[p]});
t[p].tag1=v;t[p].tag2=1;t[p].s1=t[p].s3=0;
}
inline void cg2(int p,int v){
t[p].tag2=1ll*v*t[p].tag2%mod,t[p].s1=1ll*v*t[p].s1%mod;t[p].s3=1ll*v*t[p].s3%mod;
}
inline void spread(int p,int op=1){
if(t[p].tag1&&op){
cg1(p<<1,t[p].tag1),cg1(p<<1|1,t[p].tag1);
t[p].tag1=0;
}
if(t[p].tag2!=1){
cg2(p<<1,t[p].tag2),cg2(p<<1|1,t[p].tag2);
t[p].tag2=1;
}
}
inline void change(int p,int l,int r,int v,int op){
if(l<=t[p].l&&t[p].r<=r){
if(!v)cg1(p,op);
else cg2(p,v);
return;
}
spread(p);
if(!v)g[op].push_back({p,t[p]});
if(l<=mid)change(p<<1,l,r,v,op);
if(r>mid)change(p<<1|1,l,r,v,op);
up(p);
}
inline void insert(int p,int pos,int val,int w){
if(t[p].l==t[p].r){
t[p].s1=w,t[p].s3=val*1ll*w%mod;
return;
}
spread(p);
if(pos<=mid)insert(p<<1,pos,val,w);
else insert(p<<1|1,pos,val,w);
up(p);
}
inline int query(int p,int l,int r,int op){
if(l<=t[p].l&&t[p].r<=r){
if(op==1)return t[p].s1;
return t[p].s3;
}
spread(p);int ans=0;
if(l<=mid)ans=query(p<<1,l,r,op);
if(r>mid)ans=md(ans+query(p<<1|1,l,r,op));
return ans;
}
inline void insert(int val,int w,int w2,int d,int r,int id){
int pos = lower_bound(p+1,p+num+1,(DATA){val,d,id})-p;
if(pos^num)change(1,pos+1,num,w2,d);
tmp[++cnt]={val,0,w2,0,d,id};
int ggg = d;
if(pos>1)ggg=(ggg+mod-query(1,1,pos-1,1))%mod;
insert(1,pos,val,ggg*1ll*r%mod);
}
inline void del(int id){
swap(tmp[id],tmp[cnt]);
int w = ksm(tmp[cnt].w2,mod-2);
int pos = lower_bound(p+1,p+num+1,(DATA){tmp[cnt].val,tmp[cnt].d,tmp[cnt].id})-p;
insert(1,pos,0,0);
if(pos^num){
if(w)change(1,pos+1,num,w,-1);
else{
for(auto[i,k]:g[tmp[cnt].d]){
spread(i,0);
t[i]=k;
}
g[tmp[cnt].d].clear();
}
}
cnt--;
}
inline void dfs2(int x,int fa){
int pos = lower_bound(p+1,p+num+1,(DATA){a[x],n+1,0})-p-1;
if(pos>=1)ans=(ans+query(1,1,pos,2)+1ll*a[x]*(d[x]+mod-query(1,1,pos,1)))%mod;
else ans=(ans+1ll*a[x]*d[x])%mod;
vector<int>v;
for(auto y : e[x])if(y!=fa)v.push_back(y);
sort(v.begin(),v.end(),cmp);
if(v.size())insert(a[x],inv[v.size()],0,d[x],1,v.size());
for(int i = v.size()-1,y;i>=1;i--){
y=v[i];
insert(mn[y],inv[i]*1ll*inv[i+1]%mod,i*1ll*inv[i+1]%mod,d[x],inv[i+1],i);
}
int now = cnt;
for(int i = 0,y;i<v.size();i++){
y=v[i];
dfs2(y,x);
if(i!=v.size()-1){
del(now--);
insert(mn[y],inv[i+1]*1ll*inv[i+2]%mod,(i+1ll)*inv[i+2]%mod,d[x],inv[i+2],i);
}
}
for(int i = 0;i<v.size();i++)del(cnt);
}
inline void predfs(int x,int fa){
vector<int>v;
for(auto y : e[x])if(y!=fa)v.push_back(y);
sort(v.begin(),v.end(),cmp);
if(v.size()){
p[++num]=(DATA){a[x],d[x],(int)v.size()};
for(int i = 0,y;i<v.size();i++){
y=v[i];
predfs(y,x);
p[++num]=(DATA){mn[y],d[x],i};
}
}
}
int main(){
T=read();
while(T--){
n=read(),rt=read();ans=0;
inv[0]=inv[1]=1;
cnt=0;num=0;
for(int i = 2;i<=n+1;i++)inv[i]=inv[mod%i]*1ll*(mod-mod/i)%mod;
for(int i = 1;i<=n;i++)e[i].clear(),a[i]=read();
for(int i = 1,x,y;i<n;i++)x=read(),y=read(),e[x].push_back(y),e[y].push_back(x);
d[rt]=1;dfs(rt,0);predfs(rt,0);
sort(p+1,p+num+1);
if(num)build(1,1,num);
dfs2(rt,0);
printf("%d\n",ans);
}
return 0;
}