题解:P5278 算术天才⑨与等差数列

· · 题解

思路

首先,我们考虑等差数列,要确定首项和末项,不难想到要维护区间最大最小值。

然后考虑一个等差数列的性质:

公差为 k 的等差数列中任意选出两个元素,他们做差一定是 k 的倍数。

把一个等差数列重排一下,然后做一个差分,这个差分数组的最大公约数等于题目给定值时,才可成立。

考虑还需要什么?还需要区间中不能出现重复的数。

这个我们对每个数维护前驱,然后变成一个数点的问题了。

所以,总结一下要满足的条件:

  1. 所有区间内的值的差的最大公约数为 k
  2. 区间内所有值各不相同。

条件一二

线段树板子,不再赘述。

条件三

记录每个数的前驱,要求区间所有数的前驱小于 l,或 k=0 时,要求 maxn=minn

预处理

map<int,set<int> > ma;
for (int i = 1; i <= n; i++)
{
    a[i]=read();
    c[i]=abs(a[i]-a[i-1]);//差分
    if (ma[a[i]].size()==0) pre[i]=-1;
    else pre[i]=*ma[a[i]].rbegin();
    ma[a[i]].insert(i);
}

添加时代码

注意:原位的后继也要修改。

auto lzx = ma[a[x]].find(x);//先找到原来的位置
lzx++;
if (lzx!=ma[a[x]].end())
{
    auto cpn = lzx;
    cpn--;
    if (cpn!=ma[a[x]].begin()) cpn--,pre[*lzx]=*cpn;
    else pre[*lzx]=-1;
    mod1(1,*lzx);
}//修改原来后继的前驱
lzx--;
ma[a[x]].erase(lzx);//删除原位置
ma[y].insert(x);//加入新位置
auto ljk = ma[y].lower_bound(x);
if (ljk==ma[y].begin()) pre[x]=-1;
else ljk--,pre[x]=*ljk;//找到新位置的前驱

其他代码可自行理解,不再赘述。

代码

#include <bits/stdc++.h>
#define int long long
using namespace std;
int read()
{
    int res = 0,f = 1;
    char ch = getchar();
    while (ch<'0'||ch>'9') f = (ch=='-'?-1:1),ch = getchar();
    while (ch>='0'&&ch<='9') res = (res<<3)+(res<<1)+(ch^48),ch = getchar();
    return res*f;
}
void write(int x)
{
    if (x<0) putchar('-'),x=-x;
    if (x>9) write(x/10);
    putchar(x%10+'0');
}
void writech(int x,char ch){write(x),putchar(ch);}
const int N = 3e5+5;
int n,m;
int a[N],c[N],pre[N];
struct tree
{
    int l,r;
    int maxn,minn,pre;
    int gcd;
}tr[4*N];
void pushup(int x)
{
    tr[x].maxn=max(tr[2*x].maxn,tr[2*x+1].maxn);
    tr[x].minn=min(tr[2*x].minn,tr[2*x+1].minn);
    tr[x].pre=max(tr[2*x].pre,tr[2*x+1].pre);
    tr[x].gcd=__gcd(tr[2*x].gcd,tr[2*x+1].gcd);
}
void bt(int x,int l,int r)
{
    tr[x].l=l,tr[x].r=r;
    if (l==r)
    {
        tr[x].maxn=tr[x].minn=a[l];
        tr[x].gcd=c[l];
        tr[x].pre=pre[l];
        return ;
    }
    int mid = (l+r)/2;
    bt(2*x,l,mid);
    bt(2*x+1,mid+1,r);
    pushup(x);
}
void mod1(int x,int q)
{
    int l = tr[x].l,r = tr[x].r;
    if (l==q&&r==q)
    {
        tr[x].maxn=tr[x].minn=a[l];
        tr[x].pre=pre[l];
        return ;
    }
    int mid = (l+r)/2;
    if (q<=mid) mod1(2*x,q);
    else mod1(2*x+1,q);
    pushup(x);
}
void mod2(int x,int q)
{
    int l = tr[x].l,r = tr[x].r;
    if (l==q&&r==q)
    {
        tr[x].gcd=c[l];
        return ;
    }
    int mid = (l+r)/2;
    if (q<=mid) mod2(2*x,q);
    else mod2(2*x+1,q);
    pushup(x);
}
tree query(int x,int ql,int qr)
{
    int l = tr[x].l,r = tr[x].r;
    if (ql<=l&&r<=qr) return tr[x];
    int mid = (l+r)/2;
    tree ll,rr;
    bool hl=false,hr=false;
    if (ql<=mid) ll = query(2*x,ql,qr),hl=true;
    if (qr>mid) rr = query(2*x+1,ql,qr),hr=true;
    if (hl&&hr)
    {
        tree res;
        res.l=ll.l,res.r=rr.r;
        res.maxn=max(ll.maxn,rr.maxn);
        res.minn=min(ll.minn,rr.minn);
        res.pre=max(ll.pre,rr.pre);
        return res;
    }
    if (hl) return ll;
    else return rr;
}
int query2(int x,int ql,int qr)
{
    int l = tr[x].l,r = tr[x].r;
    if (ql<=l&&r<=qr) return tr[x].gcd;
    int mid = (l+r)/2;
    int res = -1;
    if (ql<=mid) res = query2(2*x,ql,qr);
    if (qr>mid) res = (res==-1?query2(2*x+1,ql,qr):__gcd(res,query2(2*x+1,ql,qr)));
    return res;
}
map<int,set<int> > ma;
signed main()
{
    n=read(),m=read();
    for (int i = 1; i <= n; i++)
    {
        a[i]=read();
        c[i]=abs(a[i]-a[i-1]);
        if (ma[a[i]].size()==0) pre[i]=-1;
        else pre[i]=*ma[a[i]].rbegin();
        ma[a[i]].insert(i);
    }
    bt(1,1,n);
    int cnt = 0;
    while (m--)
    {
        int op=read();
        if (op==1)
        {
            int x=read()^cnt,y=read()^cnt;
            auto lzx = ma[a[x]].find(x);
            lzx++;
            if (lzx!=ma[a[x]].end())
            {
                auto cpn = lzx;
                cpn--;
                if (cpn!=ma[a[x]].begin()) cpn--,pre[*lzx]=*cpn;
                else pre[*lzx]=-1;
                mod1(1,*lzx);
            }
            lzx--;
            ma[a[x]].erase(lzx);
            ma[y].insert(x);
            auto ljk = ma[y].lower_bound(x);
            if (ljk==ma[y].begin()) pre[x]=-1;
            else ljk--,pre[x]=*ljk;
            a[x]=y;
            c[x]=abs(a[x]-a[x-1]);
            c[x+1]=abs(a[x+1]-a[x]);
            mod1(1,x);
            mod2(1,x);
            if (x+1<=n) mod2(1,x+1);
        }
        else
        {
            int l=read()^cnt,r=read()^cnt,k=read()^cnt;
            if (l==r){puts("Yes");cnt++;continue;}
            tree dlf = query(1,l,r);
            int maxn = dlf.maxn,minn = dlf.minn,mpre = dlf.pre;
            int d = query2(1,l+1,r);
            if (maxn-minn!=k*(r-l)){puts("No");continue;}
            if (k&&mpre>=l){puts("No");continue;}
            if (d!=k){puts("No");continue;}
            puts("Yes");cnt++;
        }
    }
    return 0;
}

代码稍长,谨慎使用。