题解:P5607 [Ynoi2013] 无力回天 NOI2017

· · 题解

更好的阅读体验

并集不好做,转化为求交。

n 比较小的时候,我们会一个 O(\frac{n^2}{\omega})bitset 做法。

当一个数字被插入的次数很小的时候,我们在加入这个数字的时候,可以直接枚举这个数字出现的位置,将这两个集合的交集 +1

这启发我们根号分治。设置阈值 B,对于出现次数 < B 的数字,我们枚举这个数字出现的位置,然后更新交集,这部分复杂度 O(B)。对于出现次数 \ge B 的数字,我们使用 bitset 可以做到 O(\frac{n}{B\omega})

B = \sqrt \frac{n}{\omega} 可以做到 O(n\sqrt \frac{n}{\omega}),实现中因为 bitset 比较快,可以稍微把阈值调小一点,我使用了 80

然后会被卡常和卡空间。我们可以对哈希表和 bitset 动态开点,然后把很慢很慢的 vector 换成链表,这样就很快。最慢点 572 毫秒。

然后这道题就做完了,复杂度 O(n\sqrt \frac{n}{\omega})

#include<bits/stdc++.h>
#include <ext/pb_ds/assoc_container.hpp>
#define endl '\n'
#define N 1000006
#define K 12506
using namespace std;
constexpr int M_X = 1 << 20;

char buf[M_X], *p1, *p2;

inline char gc() {
  if (p1 == p2) {
    p1 = buf;
    p2 = p1 + fread(buf, 1, M_X, stdin);
  }
  return p1 == p2 ? EOF : *p1++;
}
inline int read()
{
    int ret=0,f=0; char ch=gc();
    while((ch<'0'||ch>'9')&&ch!='-')ch=gc();
    if(ch=='-')f=1,ch=gc();
    while(ch>='0'&&ch<='9')ret=(ret<<3)+(ret<<1)+(ch^48),ch=gc();
    return f?-ret:ret;
}
inline void write(int k)
{
    if(k<0)putchar('-'),k=-k;
    static int nnum[20],ttp=0;
    while(k)nnum[++ttp]=k%10,k/=10;
    if(!ttp)nnum[++ttp]=0;
    while(ttp)putchar(nnum[ttp--]^48);
}
using ull=unsigned long long;
constexpr int B=80;
constexpr ull P=1e9+1011;
int q,occ[N],sz[N],b[N],vis[N],bn;
struct Ask {int opt,x,y;} ask[N];
bitset<K> *bt[N];
inline ull has(int x,int y){return P*x+y;}
__gnu_pbds::gp_hash_table<int,int> *mp[N];
int tot,head[N],val[N],nxt[N];
main()
{
    q=read();
    for(int i=1;i<=q;i++)
    {
        ask[i]={read(),read(),read()};
        if(ask[i].opt==1)occ[ask[i].y]++;
        else if(ask[i].x>ask[i].y)swap(ask[i].x,ask[i].y);
        if(ask[i].opt==2)
        {
            if(!mp[ask[i].x])mp[ask[i].x]=new __gnu_pbds::gp_hash_table<int,int> ();
            (*mp[ask[i].x])[ask[i].y]=0;
        }
    }
    for(int i=1;i<=q;i++)if(occ[i]>B)b[++bn]=i;
    for(int i=1;i<=q;i++)if(ask[i].opt==1)
    {
        int x=ask[i].x,y=ask[i].y;
        if(occ[ask[i].y]>B)
        {
            int t=lower_bound(b+1,b+1+bn,y)-b;
            if(!bt[x])bt[x]=new bitset<K> ();
            bt[ask[i].x]->set(t);
        } else {
            for(int i=head[y];i;i=nxt[i])
            {
                int l=min(val[i],x),r=max(val[i],x);
                if(mp[l]&&mp[l]->find(r)!=mp[l]->end())(*mp[l])[r]++;
            }
            val[++tot]=x,nxt[tot]=head[y],head[y]=tot;
        }
        sz[x]++;
    } else {
        int x=ask[i].x,y=ask[i].y; if(x==y){write(sz[x]),putchar(10); continue;}
        int ans=sz[x]+sz[y]-(*mp[x])[y];
        if(bt[x]&&bt[y])ans-=(*bt[x]&*bt[y]).count();
        write(ans),putchar(10);
    }
    return 0;
}