题解 P4065 【[JXOI2017]颜色】

shadowice1984

2018-04-02 16:17:26

Solution

非常传统的一道线段树题? 下面以这道题为例介绍一下线段树的经典应用——求点对贡献 # 本题题解 我们发现删完颜色之后剩余元素必定是一个连续的区间~~(这不是题目要求嘛)~~ 显然两个不同的连续区间对应着两个不同的删除颜色方案 所以我们统计方案数可以转化为统计有多少个不同的合法区间 而区间是可以被描述为点对的。 所以下面就是经典的求点对的总贡献的统计问题了。 对于这种问题我们的解决方案永远是一个——先枚举一个端点,再用数据结构解决另外一个端点(通常是线段树) 换句话讲,我们对于每个右端点,求出有多少个合法左端点,加在一起就是答案 那么对于这道题我们也是相同的思路统计,我们现在枚举右端点 假设我们枚举的右端点是i 我们考虑一个左端点l在什么时候合法 显然i之后的所有颜色要被删去,所以我们发现整个序列会被切割成几个小段 而这个l只能在和i在同一个区间里 我们设**颜色**为i的点的位置最大值为$max_{i}$,位置最小值为$min_{i}$ 显然所有$max_{k}$大于$i$的颜色都会被删去,因此合法左端点至少保证在它和i之间不存在一定会被删去的颜色 **也就是说我们要找到$max_{col_{j}}$大于$i$,且离i最近的点j** 然后我们选择的左端点就必须在j~i之间,且不能选j (显然j~i里边的点至少不会被删去(因为最右的颜色还是在i以内),但是越过j之后因为j的颜色一定会被删掉(因为$max_{col_{j}}$大于$i$)所以不可以越过j) 但是还没有完,我们发现如果一个颜色j的$max$小于$i$,那么我们的左端点不可以落在$(min_{j},max_{j}]$之间,因为如果落在了这个区间里,显然我们会发现这个颜色j会被删掉(因为$min_{j}$在左端点之前),但是$max_{j}$又在区间里,此时我们的区间就不是连续的了 所以我们可以设计出这样一个算法,从左到右枚举右端点i 如果这个点$i=max_{col_{i}}$的话, 我们就在线段树上区间$(min_{col_{i}},max_{col_{i}}]$赋值为1,表示这些点被禁用 现在我们唯一的问题变成如何确定$max_{col_{j}}$大于$i$,且离i最近的点j, 我们可以贪心,每次将点ipush进一个栈里,然后我们开始贪心,如果栈顶的$max$已经小于i了我们就pop,重复这个过程直到我们找到第一个$max$大于i的点作为左端点l 然后我们发现此时i到l里未被禁用的点就是所有的合法左端点,(i-l)再减去线段树的区间和就是这个右端点i的所有合法左端点数量了,枚举所有的右端点i然后把合法左端点数量加在一起就是答案了 代码的话很短,只有一个线段树和一个栈的代码量 上代码~ ```C #include<cstdio> #include<algorithm> #include<stack> #include<vector> using namespace std;const int N=3*1e5+10;typedef long long ll; int n;int mi[N];int ma[N];int cnt;int col[N];ll res;int T; struct linetree//资瓷区间求和,区间赋值 { int val[4*N];int sev[4*N]; inline void pushdown(int p,int len) {if(sev[p]){val[p<<1]=len/2;val[p<<1|1]=len-len/2;sev[p<<1]=1;sev[p<<1|1]=1;}} void setval(int p,int l,int r,int dl,int dr) { if(dl==l&&dr==r){val[p]=r-l;sev[p]=1;return;} int mid=(l+r)/2;pushdown(p,r-l); if(dl<mid){setval(p<<1,l,mid,dl,min(mid,dr));} if(mid<dr){setval(p<<1|1,mid,r,max(dl,mid),dr);} val[p]=val[p<<1]+val[p<<1|1]; } int sum(int p,int l,int r,int dl,int dr) { if(dl==l&&dr==r){return val[p];} int mid=(l+r)/2;int res=0;pushdown(p,r-l); if(dl<mid){res+=sum(p<<1,l,mid,dl,min(dr,mid));} if(mid<dr){res+=sum(p<<1|1,mid,r,max(dl,mid),dr);} return res; } }lt; struct data{int col;int pos;};stack <data> s;//开了一个栈 inline void clear(stack <data>& st){stack <data> emp;swap(emp,st);} inline void solve() { scanf("%d",&n); for(int i=1;i<=n;i++){scanf("%d",&col[i]);} for(int i=1;i<=n;i++){mi[i]=0x3f3f3f3f;ma[i]=0;} for(int i=1;i<=4*n;i++){lt.val[i]=0;lt.sev[i]=0;} for(int i=1;i<=n;i++){mi[col[i]]=min(mi[col[i]],i);}//处理min for(int i=1;i<=n;i++){ma[col[i]]=max(ma[col[i]],i);}//处理max for(int i=1;i<=n;i++)//开始枚举右端点 { if(i==ma[col[i]]&&ma[col[i]]!=mi[col[i]])//如果是右端点的话就区间赋值表示禁用 {lt.setval(1,0,n,mi[col[i]],ma[col[i]]);} else {s.push((data){col[i],i});} for(;!s.empty()&&ma[s.top().col]<=i;s.pop());//找到左端点下限l int l=(s.empty())?0:s.top().pos;//如果栈是空的话意味着所有左端点都可能合法 if(i!=l){res+=i-l-lt.sum(1,0,n,l,i);}//然后减去禁用的点数就好了 }printf("%lld\n",res);res=0;clear(s);//记得清空 } int main(){scanf("%d",&T);for(int z=1;z<=T;z++){solve();}return 0;}//拜拜程序~ ```