题解--risrqnis

· · 题解

因为是和出题人截然不同的做法,且常数更为优秀,所以就以博客的形式挂了出来。欢迎大家提出更为优秀的做法。

首先考虑 m=1。我们发现,每个数最多只会被加入 S 一次,所以我们可以均摊分析。考虑用一个并查集维护在值域上所属 1 区间最右端点,然后每次暴力把一个数赋值成 1 加入树状数组就可以了,查询直接区间求和。时间复杂度 O(n\log n)
但是 m 较大的时候,对于每一个 S_k,我们都要将势能归零,这会非常慢。因为 S_k 独立,考虑对 S_k 的操作次数进行根号分治:

对于操作次数大于 \sqrt{q}k,我们考虑用上面的方法。但是发现修改是 O(\sqrt{n}\log n) 查询是 O(\log n) 的很不平衡,所以考虑根号平衡,即把树状数组换成 O(1) 区间修改,O(\sqrt{n}) 单点查询的分块,这一部分复杂度 O(n\sqrt{q}+q\sqrt{n})。(当然这里是把并查集的复杂度忽略了的,据 dqstz 大佬分析,用 ODT 而不是并查集可以达到严格复杂度)

对于操作次数小于 \sqrt{q}k,我们发现一个好处是加入的区间所形成的值域连续段数量不会超过 \sqrt{q}。所以我们在查询的时候可以对每个值域连续段查询询问的区间在这个连续段内有多少个,这相当于一个二维数点。但是离线下来后我们不能用树状数组维护,因为这会使修改是 O(\log n) 查询是 O(\sqrt{n}\log n) 的很不平衡。所以我们一样的用一个 O(\sqrt{n}) 单点修改,O(1) 区间查询的分块,这一部分复杂度 O(n\sqrt{n}+q\sqrt{q})。但是发现空间复杂度达到了 O(n\sqrt{n}) 级别,于是我们可以离线询问的区间,然后给那需要用到它的询问加上贡献,这样就可以达到线性空间。

综上,令 n,q 同阶,复杂度 O(n\sqrt{n})。注意 m=1 特判。代码仅供参考。

#include<bits/stdc++.h>
#define fi first
#define sc second
using namespace std;
inline int read() {
    int x=0,f=1;
    char c=getchar();
    while(c<'0'||c>'9') {
        if(c=='-')f=-1;
        c=getchar();
    }
    while(c>='0'&&c<='9')x=x*10+c-'0',c=getchar();
    return x*f;
}
const int maxn=3e5+5,maxm=666,mxn=1e6+5;
int n,q,m,lp[maxm],rp[maxm],bel[maxn],sqrtn,sn,ans[maxn],a[mxn],b[mxn],op[maxn],qc,nowk;
typedef pair<int,int> pi;
pi id[mxn];
namespace sub0 {
    int c[mxn],f[mxn];
    int getf(int x) {
        return f[x]==x?x:f[x]=getf(f[x]);
    }
    void add(int x) {
        while(x<=n)c[x]++,x+=x&(-x);
    }
    int ask(int x) {
        int sum=0;
        while(x)sum+=c[x],x-=x&(-x);
        return sum;
    }
    void solve() {
        int op,l,r;
        for(int i=1; i<=n+1; i++)f[i]=i;
        while(q--) {
            op=read(),l=read(),r=read(),read();
            if(op==1) {
                l=lower_bound(b+1,b+n+1,l)-b,r=upper_bound(b+1,b+n+1,r)-b-1;
                int x=getf(l);
                while(x<=r)add(id[x].sc),f[x]=f[x+1],x=getf(x+1);
            } else printf("%d\n",ask(r)-ask(l-1));
        }
    }
}
struct quest {
    int op,l,r,k,id;
    friend bool operator<(const quest a,const quest b) {
        return a.k==b.k?a.id<b.id:a.k<b.k;
    }
} p[maxn];
namespace sub1 {
    struct node {
        int l,r,st;
        node() {}
        node(int l,int r,int st) {
            this->l=l,this->r=r,this->st=st;
        }
        friend bool operator<(const node a,const node b) {
            return a.l<b.l;
        }
    } st[maxn];
    struct edge {
        int next,op,l,r;
    } e[maxn*4];
    int h[maxn],cnt,s1[maxn],s2[maxm],top;
    set<node> s;
    typedef set<node>::iterator iter;
    void add(int x) {
        int b=bel[x];
        for(int i=x; i<=rp[b]; i++)s1[i]++;
        for(int i=b+1; i<=sn; i++)s2[i]++;
    }
    int ask(int x) {
        if(!x)return 0;
        return s1[x]+s2[bel[x]];
    }
    void addedge(int x,int op,int l,int r) {
        e[++cnt]=(edge) {
            h[x],op,l,r
        },h[x]=cnt;
    }
    void ins(int l,int r,int t) {
        st[++top]=node(l,r,t);
    }
    void ers(int l,int r,int t) {
        for(int i=1; i<=top; i++)
            if(st[i].l==l&&st[i].r==r) {
                addedge(r,1,st[i].st,t),addedge(l-1,-1,st[i].st,t);
                for(int j=i; j<top; j++)st[j]=st[j+1];
                top--;
                break;
            }
    }
    void solve1(int L,int R) {
        s.clear(),top=0;
        iter il,ir;
        for(int w=L; w<=R; w++) {
            int op=p[w].op,l=p[w].l,r=p[w].r;
            if(op==1) {
                l=lower_bound(b+1,b+n+1,l)-b,r=upper_bound(b+1,b+n+1,r)-b-1;
                if(l>r)continue;
                if(s.empty())goto shik;
                il=s.lower_bound(node(l,0,0)),ir=s.upper_bound(node(r,0,0));
                if(il!=s.begin()) {
                    iter tmp=il;
                    tmp--;
                    if(tmp->r>=l)il--;
                }
                if(ir==s.begin()||(--ir)->r<l)goto shik;
                l=min(l,il->l),r=max(r,ir->r);
                while(il!=ir)ers(il->l,il->r,w),il=s.erase(il);
                ers(ir->l,ir->r,w),s.erase(ir);
shik:
                ins(l,r,w),s.insert(node(l,r,w));
            }
        }
        for(int i=1; i<=top; i++)addedge(st[i].r,1,st[i].st,R),addedge(st[i].l-1,-1,st[i].st,R),st[i].l=st[i].r=st[i].st=0;
    }
    void solve() {
        for(int u=1; u<=n; u++) {
            add(id[u].sc);
            for(int i=h[u]; i; i=e[i].next) {
                int l=e[i].l,r=e[i].r,o=e[i].op;
                for(int j=l; j<=r; j++)
                    if(p[j].op==2)ans[p[j].id]+=o*(ask(p[j].r)-ask(p[j].l-1));
            }
        }
    }
}
namespace sub2 {
    int f[maxn],s1[maxn],s2[maxm];
    int getf(int x) {
        return f[x]==x?x:f[x]=getf(f[x]);
    }
    void add(int x) {
        s1[x]++,s2[bel[x]]++;
    }
    int ask(int l,int r) {
        int x=bel[l],y=bel[r],sum=0;
        if(x==y) {
            for(int i=l; i<=r; i++)sum+=s1[i];
            return sum;
        }
        for(int i=l; i<=rp[x]; i++)sum+=s1[i];
        for(int i=lp[y]; i<=r; i++)sum+=s1[i];
        for(int i=x+1; i<y; i++)sum+=s2[i];
        return sum;
    }
    void solve(int L,int R) {
        for(int i=1; i<=n+1; i++)f[i]=i;
        memset(s1,0,sizeof(s1)),memset(s2,0,sizeof(s2));
        for(int w=L; w<=R; w++) {
            int op=p[w].op,l=p[w].l,r=p[w].r;
            if(op==1) {
                l=lower_bound(b+1,b+n+1,l)-b,r=upper_bound(b+1,b+n+1,r)-b-1;
                int x=getf(l);
                while(x<=r)add(id[x].sc),f[x]=f[x+1],x=getf(x+1);
            } else ans[p[w].id]=ask(l,r);
        }
    }
}
int main() {
    n=read(),q=read(),m=read(),sqrtn=sqrt(n),sn=(n+sqrtn-1)/sqrtn;
    for(int i=1; i<=n; i++)a[i]=read(),id[i]=pi(a[i],i),b[i]=a[i];
    sort(id+1,id+n+1),sort(b+1,b+n+1);
    if(m==1)sub0::solve();
    else {
        for(int i=1; i<=sn; i++)lp[i]=(i-1)*sqrtn+1,rp[i]=min(i*sqrtn,n);
        for(int i=1; i<=n; i++)bel[i]=(i-1)/sqrtn+1;
        for(int i=1; i<=q; i++)op[i]=p[i].op=read(),p[i].l=read(),p[i].r=read(),p[i].k=read(),p[i].id=i;
        sort(p+1,p+q+1);
        for(int l=1,r; l<=q; l=r+1) {
            r=l;
            while(p[r+1].k==p[l].k)r++;
            nowk=p[l].k;
            if(r-l+1>=sqrt(q))sub2::solve(l,r);
            else sub1::solve1(l,r);
        }
        sub1::solve();
        for(int i=1; i<=q; i++)if(op[i]==2)printf("%d\n",ans[i]);
    }
    return 0;
}