题解:P10403 「XSOI-R1」跳跃游戏

· · 题解

upd:本文所有的“单调不降”改为“单调不增”。

区间最大公约数单调不增性

\{a_n\} 为任意一个序列,固定左端点 i ,考虑右端点 j 且满足 j \geq i

记区间 [i,j] 内所有元素的最大公约数为:

\gcd_{k=i}^j a_k = \gcd(a_i, a_{i+1}, \dots, a_j)

不难发现,若对任意满足 i \leq j_1 < j_2 的正整数 j_1, j_2 ,均有:

\gcd_{k=i}^{j_1} a_k \geq \gcd_{k=i}^{j_2} a_k

对于任意一个序列,若固定左端点,使右端点递增,其区间最大公约数是单调不增的。

数据类型与算法选择

因为区间最大公约数具有单调不增性,此时可以用二分快速查找位置。

由于需要多次查询区间最大公约数,可以用 ST 表维护这个序列中的最大公约数,并快速查询在 [x,y] 区间内的 \gcd(a_x , a_{x+1} , \dots , a_y)

思路

每次查询出最大公约数为 23 区间,需要统计区间内可以加贡献的端点,只需对其左右端点的奇偶性进行讨论再除以二即可。

发现这是一个等差数列,对其进行求和即可。

优化

从以下几点优化:

代码

#include<cstdio>
using namespace std;
const int N=600100,LOG=35;
int n,a[N],f[N][LOG],log2[N];
long long ans;
void read(int&x)                             //快读
{
  x=0;
  char ch=getchar();
  int f=1;
  while(ch<'0'||'9'<ch)
  {
    if(ch=='-')f*=-1;
    ch=getchar();
  }
  while('0'<=ch&&ch<='9')
  {
    x=x*10+ch-48;
    ch=getchar();
  }
  x*=f;
  return;
}
void write(long long x)                     //快写
{
  if(x<0)
  {
    putchar('-');
    x=-x;
  }
  if(x>9)write(x/10);
  putchar(x%10+'0');
  return;
}
int gcd(int x,int y)                        //辗转相除法求 gcd(x,y)
{
  if(x<y)
  {
    int temp=x;
    x=y;
    y=temp;
  }
  if(y==0)return x;
  return gcd(y,x%y);
}
void ST()                                  //预处理 ST 表
{
  for(int i=2; i<=n; ++i)log2[i]=log2[i>>1]+1;
  for(int i=1; i<=n; ++i)f[i][0]=a[i];
  for(int j=1; j<=log2[n]; ++j)for(int i=1; i<=n-(1<<j)+1; ++i)f[i][j]=gcd(f[i][j-1],f[i+(1<<j-1)][j-1]);
}
int query(int x,int y)                      //区间查询
{
  int lg=log2[y-x+1];
  return gcd(f[x][lg],f[y-(1<<lg)+1][lg]);
}
int binary_search_left(int x,int y,int k)  //在 [x,y] 区间内二分找出 k 的左端点
{
  int l=x-1,r=y+1,mid;
  while(l+1<r)
  {
    mid=l+(r-l>>1);
    if(query(x,mid)>k)l=mid;
    else r=mid;
  }
  return r;
}
int binary_search_right(int x,int y,int k) //在 [x,y] 区间内二分找出 k 的右端点
{
  int l=x-1,r=y+1,mid;
  while(l+1<r)
  {
    mid=l+(r-l>>1);
    if(query(x,mid)>=k)l=mid;
    else r=mid;
  }
  return l;
}
signed main()
{
  read(n);
  for(int i=1; i<=n; ++i)read(a[i]);
  ST();
  for(int i=1,len,cnt,l2,r2,l3,r3; i<=n; ++i)
  {
    l2=binary_search_left(i,n,2);
    r2=binary_search_right(i,n,2);
    l3=binary_search_left(i,n,3);
    r3=binary_search_right(i,n,3);
    if(l2<=r2)                             //存在 2 的区间
    {
      if((l2-i+1)&1)++l2;
      if((r2-i+1)&1)--r2;
      len=r2-l2+1;
      if(len>0)                            
      {
        cnt=(len>>1)+1;
        ans+=(long long)(l2-i+1+r2-i+1)*cnt>>1;
      }
    }
    if(l3<=r3)                              //存在 3 的区间
    {
      if(!(l3-i+1&1))++l3;
      if(!(r3-i+1&1))--r3;
      len=r3-l3+1;
      if(len>0)
      {
        cnt=(len>>1)+1;
        ans+=(long long)(l3-i+1+r3-i+1)*cnt>>1;
      }
    }
  }
  write(ans);
  putchar('\n');
  return 0;
}