题解:P12480 [集训队互测 2024] Classical Counting Problem
比较考验基本功的题,需要一步步慢慢转化。
首先考虑刻画合法连通块具有哪些性质。注意到合法性只跟块内的最小值
显然,每次加入的点必须不在
例如上图
回到原问题。由于我们刚才刻画的条件唯一的限制与
这时候容易实现
考虑优化。直接枚举
但是答案乘上
现在的条件要好刻画多了。发现
回到点分治过程。我们现在确定了一个分治中心
所有限制现在都跟每个点到
现在问题简化为:给定若干
开一颗线段树,维护
每次加入一个点,有三种情况:
总结一下,我们要用线段树维护三个数组
这样已经基本解决了。注意当
每层分治总共需要对
代码虽然稍长但并不难写,细节与边界情况也不是很多。
以下是本人实现,代码较丑,人傻常数大,仅供参考。
#include <bits/stdc++.h>
using namespace std;
#define int long long
#define ull unsigned long long
#define pii pair<int,int>
#define fir first
#define sec second
#define chmin(a,b) a=min(a,b)
#define chmax(a,b) a=max(a,b)
#define pb push_back
const int inf=0x3f3f3f3f3f3f3f3f;
constexpr int mod=(1LL<<32)-1;
int ans;
int n,sz[100010],vis[100010],mx[100010],l[100010],r[100010],msz;
vector<int>g[100010];
vector<int>tmp;
void dfs1(int u,int f){sz[u]=1;for(auto v:g[u])if(v!=f&&!vis[v])dfs1(v,u),sz[u]+=sz[v];}
void dfs2(int u,int f,int &rt){mx[u]=msz-sz[u];for(auto v:g[u]){if(v!=f&&!vis[v])dfs2(v,u,rt),chmax(mx[u],sz[v]);}if(mx[u]<=mx[rt])rt=u;}
void dfs3(int u,int f)
{
l[u]=min(l[f],u),r[u]=max(r[f],u),tmp.pb(u);
for(auto v:g[u])if(!vis[v]&&v!=f)dfs3(v,u);
}
/*
---------------------
*/
struct tnode
{
int l,r,s,v,c,tag;
tnode(){}
tnode(int _l,int _r,int _s,int _v,int _c,int _tag)
{l=_l,r=_r,s=_s,v=_v,c=_c,tag=_tag;}
//s: 区间内所有 mn 对答案的贡献
//v: 区间内所有合法的 mn 的和
//c: 区间内合法 v 的数量
}t[400010];
#define ls rt<<1
#define rs rt<<1|1
void pushup(int rt)
{
t[rt].s=(t[ls].s+t[rs].s)&mod;
t[rt].v=(t[ls].v+t[rs].v)&mod;
t[rt].c=(t[ls].c+t[rs].c)&mod;
}
void change(int rt,int tag){(t[rt].s+=t[rt].v*tag)&=mod,(t[rt].tag+=tag)&=mod;}
void pushdown(int rt){if(t[rt].tag)change(ls,t[rt].tag),change(rs,t[rt].tag),t[rt].tag=0;}
void build(int rt,int l,int r)
{
t[rt]={l,r,0,0,0,0};
if(l==r)return;
int mid=l+r>>1;
build(ls,l,mid);
build(rs,mid+1,r);
}
void update(int rt,int x,int v,int op)
{
if(t[rt].l==t[rt].r)
{
if(op==0)(t[rt].s+=v)&=mod;
else if(op==1)(t[rt].v+=v)&=mod;
else (t[rt].c+=v)&=mod;
return;
}
pushdown(rt);
if(x<=t[ls].r)update(ls,x,v,op);
else update(rs,x,v,op);
pushup(rt);
}
void add(int rt,int l,int r,int v)
{
if(l<=t[rt].l&&r>=t[rt].r){change(rt,v);return;}
pushdown(rt);
if(l<=t[ls].r)add(ls,l,r,v);
if(r>=t[rs].l)add(rs,l,r,v);
pushup(rt);
}
int query(int rt,int l,int r,int op)
{
if(l<=t[rt].l&&r>=t[rt].r)return op?(op==1?t[rt].v:t[rt].c):t[rt].s;
pushdown(rt);
int ans=0;
if(l<=t[ls].r)ans+=query(ls,l,r,op);
if(r>=t[rs].l)ans+=query(rs,l,r,op);
return ans&mod;
}
/*
---------------------
*/
int m,b[300010];
vector<int>q[300010];
int calc()
{
int ans=0;
m=0;
for(auto u:tmp)b[++m]=u,b[++m]=l[u],b[++m]=r[u];
sort(b+1,b+m+1);
m=unique(b+1,b+m+1)-b-1;
build(1,1,m);
for(int i=1;i<=m;i++)q[i].clear();
for(auto u:tmp)
{
int p=lower_bound(b+1,b+m+1,u)-b;
q[p].pb(u);
int rp=lower_bound(b+1,b+m+1,r[u])-b;
q[rp].pb(-u);
}
for(int i=1;i<=m;i++)
{
sort(q[i].begin(),q[i].end());
for(auto u:q[i])
{
if(u<0)
{
u=-u;
int p=lower_bound(b+1,b+m+1,u)-b;
int lp=lower_bound(b+1,b+m+1,l[u])-b;
//u 为 v
add(1,1,lp,1);
update(1,lp,1,2);
//u 为 mn
if(u==l[u])
{
int q=query(1,p,p,1);
update(1,p,u,1);
int t=query(1,p,m,2);
update(1,p,(u*t)&mod,0);
}
}
else if(u==r[u])
{
//u 为 mx
int p=lower_bound(b+1,b+m+1,u)-b;
int lp=lower_bound(b+1,b+m+1,l[u])-b;
(ans+=query(1,1,lp,0)*u)&=mod;
}
}
}
return ans;
}
void dfs(int u)
{
dfs1(u,0);
msz=sz[u],mx[0]=inf;
int rt=0;
dfs2(u,0,rt);
l[rt]=r[rt]=rt;
vector<int>tr;
tr.pb(rt);
int pans=0;
for(auto v:g[rt])if(!vis[v])
{
tmp.clear();
if(!vis[v])dfs3(v,rt);
int q=calc();
(pans+=mod+1-q)&=mod;
for(auto x:tmp)tr.pb(x);
}
tmp=tr;
int q=calc();
pans+=q;
(ans+=pans)&=mod;
vis[rt]=1;
for(auto v:g[rt])if(!vis[v])dfs(v);
}
void solve()
{
cin>>n;
for(int i=1;i<=n;i++)g[i].clear(),vis[i]=0;
for(int i=1;i<n;i++)
{
int u,v;
cin>>u>>v;
g[u].pb(v);
g[v].pb(u);
}
ans=0;
dfs(1);
cout<<ans<<endl;
return;
}
signed main()
{
int t;
cin>>t;
while(t--)solve();
return 0;
}