题解 P4065 【[JXOI2017]颜色】

· · 题解

非常传统的一道线段树题?

下面以这道题为例介绍一下线段树的经典应用——求点对贡献

本题题解

我们发现删完颜色之后剩余元素必定是一个连续的区间(这不是题目要求嘛)

显然两个不同的连续区间对应着两个不同的删除颜色方案

所以我们统计方案数可以转化为统计有多少个不同的合法区间

而区间是可以被描述为点对的。

所以下面就是经典的求点对的总贡献的统计问题了。

对于这种问题我们的解决方案永远是一个——先枚举一个端点,再用数据结构解决另外一个端点(通常是线段树)

换句话讲,我们对于每个右端点,求出有多少个合法左端点,加在一起就是答案

那么对于这道题我们也是相同的思路统计,我们现在枚举右端点

假设我们枚举的右端点是i

我们考虑一个左端点l在什么时候合法

显然i之后的所有颜色要被删去,所以我们发现整个序列会被切割成几个小段

而这个l只能在和i在同一个区间里

我们设颜色为i的点的位置最大值为max_{i},位置最小值为min_{i}

显然所有max_{k}大于i的颜色都会被删去,因此合法左端点至少保证在它和i之间不存在一定会被删去的颜色

也就是说我们要找到max_{col_{j}}大于i,且离i最近的点j

然后我们选择的左端点就必须在j~i之间,且不能选j

(显然j~i里边的点至少不会被删去(因为最右的颜色还是在i以内),但是越过j之后因为j的颜色一定会被删掉(因为max_{col_{j}}大于i)所以不可以越过j)

但是还没有完,我们发现如果一个颜色j的max小于i,那么我们的左端点不可以落在(min_{j},max_{j}]之间,因为如果落在了这个区间里,显然我们会发现这个颜色j会被删掉(因为min_{j}在左端点之前),但是max_{j}又在区间里,此时我们的区间就不是连续的了

所以我们可以设计出这样一个算法,从左到右枚举右端点i

如果这个点i=max_{col_{i}}的话, 我们就在线段树上区间(min_{col_{i}},max_{col_{i}}]赋值为1,表示这些点被禁用

现在我们唯一的问题变成如何确定max_{col_{j}}大于i,且离i最近的点j, 我们可以贪心,每次将点ipush进一个栈里,然后我们开始贪心,如果栈顶的max已经小于i了我们就pop,重复这个过程直到我们找到第一个max大于i的点作为左端点l

然后我们发现此时i到l里未被禁用的点就是所有的合法左端点,(i-l)再减去线段树的区间和就是这个右端点i的所有合法左端点数量了,枚举所有的右端点i然后把合法左端点数量加在一起就是答案了

代码的话很短,只有一个线段树和一个栈的代码量

上代码~

#include<cstdio>
#include<algorithm>
#include<stack>
#include<vector>
using namespace std;const int N=3*1e5+10;typedef long long ll;
int n;int mi[N];int ma[N];int cnt;int col[N];ll res;int T;
struct linetree//资瓷区间求和,区间赋值 
{
    int val[4*N];int sev[4*N];
    inline void pushdown(int p,int len)
    {if(sev[p]){val[p<<1]=len/2;val[p<<1|1]=len-len/2;sev[p<<1]=1;sev[p<<1|1]=1;}}
    void setval(int p,int l,int r,int dl,int dr)
    {
        if(dl==l&&dr==r){val[p]=r-l;sev[p]=1;return;}
        int mid=(l+r)/2;pushdown(p,r-l);
        if(dl<mid){setval(p<<1,l,mid,dl,min(mid,dr));}
        if(mid<dr){setval(p<<1|1,mid,r,max(dl,mid),dr);}
        val[p]=val[p<<1]+val[p<<1|1];
    }
    int sum(int p,int l,int r,int dl,int dr)
    {
        if(dl==l&&dr==r){return val[p];}
        int mid=(l+r)/2;int res=0;pushdown(p,r-l);
        if(dl<mid){res+=sum(p<<1,l,mid,dl,min(dr,mid));}
        if(mid<dr){res+=sum(p<<1|1,mid,r,max(dl,mid),dr);}
        return res;
    }
}lt;
struct data{int col;int pos;};stack <data> s;//开了一个栈 
inline void clear(stack <data>& st){stack <data> emp;swap(emp,st);}
inline void solve()
{
    scanf("%d",&n);
    for(int i=1;i<=n;i++){scanf("%d",&col[i]);}
    for(int i=1;i<=n;i++){mi[i]=0x3f3f3f3f;ma[i]=0;}
    for(int i=1;i<=4*n;i++){lt.val[i]=0;lt.sev[i]=0;}
    for(int i=1;i<=n;i++){mi[col[i]]=min(mi[col[i]],i);}//处理min 
    for(int i=1;i<=n;i++){ma[col[i]]=max(ma[col[i]],i);}//处理max 
    for(int i=1;i<=n;i++)//开始枚举右端点 
    {
        if(i==ma[col[i]]&&ma[col[i]]!=mi[col[i]])//如果是右端点的话就区间赋值表示禁用 
        {lt.setval(1,0,n,mi[col[i]],ma[col[i]]);}
        else {s.push((data){col[i],i});}
        for(;!s.empty()&&ma[s.top().col]<=i;s.pop());//找到左端点下限l 
        int l=(s.empty())?0:s.top().pos;//如果栈是空的话意味着所有左端点都可能合法 
        if(i!=l){res+=i-l-lt.sum(1,0,n,l,i);}//然后减去禁用的点数就好了 
    }printf("%lld\n",res);res=0;clear(s);//记得清空 
}
int main(){scanf("%d",&T);for(int z=1;z<=T;z++){solve();}return 0;}//拜拜程序~