csps 星战题解

· · 题解

更好的体验
首先可以观察出一个性质,只要每个点的出度都是 1,这个图就是一个内向奇环树,也就是说只要每个点出度为 1,那么该情况就是合法的。然后考虑怎么动态维护每个点的出度。

这里我们使用一个哈希,用一个值记录当前每个点的出度大小为多少,当这个值和“每个点出度为 1”这个状态的值一样的时候就是合法的值了。设出度为 ou 数组,定义哈希函数 cur=\sum_{i=1}^n p^i \times ou_i ,那么目标情况的哈希值就是 fin=\sum_{i=1}^n p^i
当一条边 e_{i,j} 被毁,我们就给 cur 减去 p^i ,同理,当 e_{i,j} 被重建就给 cur 加上 p^i ,对于全体摧毁和全体重建,我们记录一个数组 pas_j 表示每一条以 j 为终点的边被摧毁时,对 cur 减值的总和,即 pas_j=\sum p^i ,其中边 e_{i,j} 被摧毁。再记录一个数组 in_j 表示一开始的时候每条以 j 为终点的边对 cur 的贡献的总和,即 in_j=\sum p^i ,其中边 e_{i,j} 存在。那么集体摧毁的时候 cur 就应当减去 in_j-pas_j ,集体修复同理。具体可以参考代码,代码中为了求稳使用了双哈希。

另:由于本蒟蒻恶臭习惯,喜欢把出边和入边反过来写,所以代码里面 inou 含义交换。

#include<iostream>
#include<cstdio>
#include<cstring>
typedef long long ll;
using namespace std;
const ll N=6e5,p0=23333,mod1=998244353,mod2=1e9+7;
ll n,m,in[N],qu;

struct Pair{
    ll x,y;
    Pair operator *(const ll&tmp )const{
        Pair c;
        c.x=x*tmp%mod1,c.y=y*tmp%mod2;
        return c;
    }
    Pair operator +(const Pair&tmp )const{
        Pair c;
        c.x=(x+tmp.x)%mod1,c.y=(y+tmp.y)%mod2;
        return c;
    }
    Pair operator -(const Pair&tmp )const{
        Pair c;
        c.x=(x-tmp.x+mod1)%mod1,c.y=(y-tmp.y+mod2)%mod2;
        return c;
    }
    void clear(){x=0,y=0;}
}p[N],sam,cur,pas[N],ou[N];

int main(){
    cin>>n>>m;
    p[0].x=p[0].y=1;
    for(ll i=1;i<=n;i++)p[i]=p[i-1]*p0;
    for(ll i=1;i<=m;i++){
        ll a1,a2;
        scanf("%lld%lld",&a1,&a2);
        in[a1]++,ou[a2]=ou[a2]+p[a1];
    }
    for(ll i=1;i<=n;i++)sam=sam+p[i];
    for(ll i=1;i<=n;i++)
        cur=cur+(Pair){p[i].x*in[i]%mod1,p[i].y*in[i]%mod2};
    cin>>qu;
    for(ll i=1;i<=qu;i++){
        ll a1,a2,a3;
        scanf("%lld",&a1);
        if(a1==1){
            scanf("%lld%lld",&a2,&a3);
            cur=cur-p[a2];
            pas[a3]=pas[a3]+p[a2];
        }
        if(a1==2){
            scanf("%lld",&a2);
            cur=cur-ou[a2];
            cur=cur+pas[a2];
            pas[a2]=ou[a2];
        }
        if(a1==3){
            scanf("%lld%lld",&a2,&a3);
            cur=cur+p[a2];                                                                                                                       
            pas[a3]=pas[a3]-p[a2];
        }
        if(a1==4){
            scanf("%lld",&a2);
            cur=cur+pas[a2];
            pas[a2].clear();
        }
        if(cur.x==sam.x&&cur.y==sam.y)printf("YES\n");
        else printf("NO\n");
    }
}