题解:P13241 「2.48sOI R1」格律树
喜报:离 AC 差一个 +mod。
好我们回到题目。首先这题是给出了一棵树和多组询问,每组询问给出了一些关键点和关键点的颜色,颜色只有两种 0 或 1。其他的点可以任意染色,求使得每个关键点(注意不是所有节点)到根路径上的颜色序列不出现 101 这种情况的方案数。
我们先考虑没有询问怎么做。设计状态 0、1、10 结尾的方案总数。其中 10 结尾,就需要把它统计进
考虑怎么转移。先看 0 填到序列最后,是肯定不会产生非法情况的。所以它可以从前面的任意状态继承过来。于是
接着看 1 填到序列最后,就有可能会产生非法情况了。所以要把这种情况减去,相当于就是减去儿子节点以 10 结尾的情况,所以
最后来看 10 的情况,那么
现在对于每个节点 0 就转移 1 就转移
最后整颗树的答案就是
好,现在我们已经知道了没有询问怎么做,现在考虑有多组询问该怎么做。
首先可以看见题目中询问的总点数是
现在假设我们已经把虚树给建出来了,那么我们的转移可以和之前一样的做法,但是这时候会出现一个问题:我们从虚树上的一个点跳到他的父节点时会经过很多其他的点,要转移多次。但是这些点没有其他的儿子,因此转移是相同的,这引导我们往矩阵乘法上去想。
那么因为我们刚才提及了 3 种状态转移,是不是就要设计 3 种矩阵呢?其实并不需要,因为我们经过的这些节点肯定都是没有颜色的点,因此只要设计一种矩阵就行。
假设现在我们知道
然后我们对于
那么每次处理出
最后的答案为
给出代码:
#include<bits/stdc++.h>
using namespace std;
#define int long long
#define N 1000010
const int mod=1e9+7;
int n,q,k,dep[N],dfn[N],idx,f[N][21],rt,col[N],m,sum[N][21],ans[N],bas2[N],tpp;
vector<int>to[N],v;
vector<int>too[N];
struct matrix{
int a[3][3]={};
friend matrix operator+(matrix a,matrix b){
matrix c;
for(int i=0;i<3;i++){
for(int j=0;j<3;j++){
c.a[i][j]=(a.a[i][j]+b.a[i][j])%mod;
}
}
return c;
}
friend matrix operator*(matrix a,matrix b){
matrix c;
for(int i=0;i<3;i++){
for(int j=0;j<3;j++){
for(int k=0;k<3;k++){
c.a[i][j]=(c.a[i][j]+a.a[i][k]*b.a[k][j]%mod+mod)%mod;
}
}
}
return c;
}
void print(){
for(int i=0;i<3;i++){
for(int j=0;j<3;j++){
cout<<a[i][j]<<" ";
}
cout<<'\n';
}
}
}dp[N],bas[N],base[3];
void ini(){
bas[1].a[0][0]=1;
bas[1].a[0][1]=1;
bas[1].a[1][0]=1;
bas[1].a[1][1]=1;
bas[1].a[1][2]=1;
bas[1].a[2][1]=-1;
// bas[1].a[2][0]=1;
base[2]=bas[1];
base[0].a[0][0]=1;
base[0].a[1][0]=1;
base[0].a[1][2]=1;
base[1].a[0][1]=1;
base[1].a[1][1]=1;
base[1].a[2][1]=-1;
}
bool comp(int x,int y){
return dfn[x]<dfn[y];
}
void clr(){
for(auto x:v){
too[x].clear();
col[x]=2;
}
}
void init(int x,int fa){
dfn[x]=++idx;
dep[x]=dep[fa]+1;
f[x][0]=fa;
for(int i=1;i<=20;i++){
f[x][i]=f[f[x][i-1]][i-1];
}
for(auto y:to[x]){
if(y==fa){
continue;
}
init(y,x);
}
}
int lca(int x,int y){
if(x==y){
return x;
}
if(dep[x]<dep[y]){
swap(x,y);
}
for(int i=19;i>=0;i--){
if(dep[f[x][i]]>=dep[y]){
x=f[x][i];
}
}
if(x==y){
return x;
}
for(int i=19;i>=0;i--){
if(f[x][i]!=f[y][i]){
x=f[x][i];
y=f[y][i];
}
}
return f[x][0];
}
int build(){
sort(v.begin(),v.end(),comp);
vector<int>vv;
vv=v;
for(int i=0;i<v.size()-1;i++){
vv.push_back(lca(v[i],v[i+1]));
}
sort(vv.begin(),vv.end(),comp);
v.clear();
v.push_back(vv[0]);
for(int i=1;i<vv.size();i++){
if(vv[i]!=vv[i-1]){
v.push_back(vv[i]);
}
}
for(int i=0;i<v.size()-1;i++){
int tmp=lca(v[i],v[i+1]);
if(tmp!=v[i+1]){
too[tmp].push_back(v[i+1]);
too[v[i+1]].push_back(tmp);
}
}
return v[0];
}
void dfs(int x,int fa){
if(x!=rt){
tpp+=dep[x]-dep[fa];
}
for(int i=0;i<3;i++){
for(int j=0;j<3;j++){
dp[x].a[i][j]=0;
}
}
if(too[x].size()+(x==rt)<=1){
dp[x].a[0][col[x]]=1;
// cout<<x<<'\n';
// dp[x].print();
if(dep[fa]!=dep[x]-1){
dp[x]=dp[x]*bas[dep[x]-dep[fa]-1];
}
// dp[x].print();
return;
}
if(col[x]<2){
dp[x].a[0][col[x]]=1;
int tmp=1;
for(auto y:too[x]){
if(y==fa){
continue;
}
dfs(y,x);
matrix f=dp[y]/**base[col[x]]*/;
tmp=tmp*f.a[0][0]%mod;
// cout<<x<<" "<<y<<'\n';
// base[col[x]].print();
// f.print();
dp[x].a[0][col[x]]=dp[x].a[0][col[x]]*((f.a[0][0]+f.a[0][1]-col[x]*f.a[0][2]+mod)%mod)%mod;
}
if(col[x]==0){
dp[x].a[0][2]=(dp[x].a[0][0]-tmp+mod)%mod;
}
}
else{
dp[x].a[0][0]=dp[x].a[0][1]=1;
int tmp=1;
for(auto y:too[x]){
if(y==fa){
continue;
}
dfs(y,x);
matrix f=dp[y]/**base[col[x]]*/;
tmp=tmp*f.a[0][0]%mod;
// cout<<x<<" "<<y<<'\n';
// base[col[x]].print();
// f.print();
dp[x].a[0][0]=dp[x].a[0][0]*(f.a[0][0]+f.a[0][1])%mod;
dp[x].a[0][1]=dp[x].a[0][1]*((f.a[0][0]+f.a[0][1]-f.a[0][2]+mod)%mod)%mod;
// dp[x].print();
}
dp[x].a[0][2]=(dp[x].a[0][0]-tmp+mod)%mod;
}
// cout<<x<<'\n';
// dp[x].print();
if(dep[fa]!=dep[x]-1){
dp[x]=dp[x]*bas[dep[x]-dep[fa]-1];
}
// dp[x].print();
}
signed main(){
ini();
cin>>n;
for(int i=1;i<n;i++){
int u,v;
cin>>u>>v;
to[u].push_back(v);
to[v].push_back(u);
}
bas2[0]=1;
cin>>m>>q;
init(1,0);
int maxn=0;
for(int i=1;i<=n;i++){
maxn=max(maxn,dep[i]);
col[i]=2;
bas2[i]=bas2[i-1]*2%mod;
}
for(int i=2;i<=maxn;i++){
bas[i]=bas[i-1]*bas[1];
}
for(int TEST=1;TEST<=m;TEST++){
tpp=0;
clr();
v.clear();
bool fl=0;
cin>>k;
for(int i=1;i<=k;i++){
int x;
cin>>x;
cin>>col[x];
v.push_back(x);
if(x==1){
fl=1;
}
}
rt=build();
dfs(rt,f[rt][0]);
if(dep[rt]!=dep[1]){
dp[rt]=dp[rt]*bas[dep[rt]-dep[1]];
}
tpp+=dep[rt];
ans[TEST]=(dp[rt].a[0][0]+dp[rt].a[0][1])%mod*bas2[n-tpp]%mod;
}
int ANS=0;
for(int i=1;i<=m;i++){
ANS^=ans[i];
if(i%q==0){
cout<<ANS<<'\n';
ANS=0;
}
}
}