xtq的神笔-赛后题解

· · 题解

20pts:

枚举从哪里开始以及每一笔的结束位置,统计答案。

40pts:

约定:

or[i,j] = a_i$ $or$ $a_{i+1}$ $or$ $...$ $or$ $a_j and[i,j] = a_i$ $and$ $a_{i+1}$ $and$ $...$ $and$ $a_j gcd[i,j] = gcd(a_i, a_{i+1}, a_{i+2}, ... , a_j)

f[i]为画到i为止可以获得的最大分数, 则显然有转移:

f[i] = max\{f[j-1]+or[i,j]+and[i,j]+gcd[i,j]\}

其中j\le i。 遍历ij进行dp可以获得这一部分分。

70pts:

观察到一个重要性质:

若区间[p,q]包含于[r,s],则有:

or[p,q]\le or[r,s] and[r,s]\le and[p,q] gcd[r,s]\le gcd[p,q]

如果我们固定一个右端点q并且从大到小遍历左端点p,并且设每个格子的分数范围在[0,k]中,那么or[p,q],and[p,q],gcd[p,q]都会持续不严格地递增\递减, 并且在递增\递减O(logk)次之后就会达到一个最大\最小值不再改变

这是不难证明的:

对于orand,由于每次变化都会改变其一个二进制位,所以最多将每个二进制位都改变一次,也就是总共改变O(logk)次。

对于gcd,每次变化之后的值都是之前值的一个因数,必然是不超过之前值的一半的,所以最多变化O(logk)次就会达到1

这也就意味着,对于每一个状态f[i],我们可以把i之前的决策点分成至多O(logk)段,使得对于每一段,任意属于这一段的j_1,j_2,有or[j_1,i]=or[j_2,i],and[j_1,i]=and[j_2,i],gcd[j_1,i]=gcd[j_2,i]

那么对于每一个i,二分查找出前面每一段的左右端点,然后对于每一段内使用数据结构查询f的最大值进行转移。

复杂度分析: 一共枚举O(n)个待转移的状态,每个状态需要找出O(logk)段,每一段需要O(logn)时间找出。由于求区间gcd在最坏情况下需要O(logk)的复杂度,所以总共的复杂度上界是O(nlognlog^2k)

事实上,由于这个上界十分宽松,这样的算法应该可以获得接近满分或者满分。

100pts:

我们每次都要重新二分找出段落,这是十分浪费时间的。

事实上,每次计算出一个f[i],我们需要加入可用决策点的只有i一个位置。 而且我们不难想到,对于f[i]的决策点分出的每一段,在f[i+1]情况下的分段中必然不会被拆成多个小段,而只可能和其他相邻的段落合并成一个大的段落。

所以,我们从始至终只需要维护当前决策点的分段即可。每次计算出f[i]后,将[i,i]作为新的一段加入,然后遍历每一段,判断在f[i+1]的情况下相邻两段的or,and,gcd值是否都相同,如果相同就进行合并。

每一段除了维护or,and,gcd以外,还需要维护当前段的最大f值,这样我们在转移时对于每一段只需要O(1)时间来查询当前段的最优决策点进行转移。

此外,std中使用了st表来维护区间的or,and,gcd值,这样的话区间合并就可以做到O(logk)

复杂度分析:

预处理st表:O(nlognlogk)

维护决策点段落:由于每一时刻至多有O(logk)段,所以总复杂度是O(nlog^2k)

转移:总的转移复杂度显然是O(nlogk)的。

总复杂度的上界为O(nlog^2k)。这个上界也比较宽松,可以通过此题。

此题完结。

代码:

#include <iostream>
#include <cstdlib>
#include <cstdio>
#include <cstring>
#include <ctime>

using namespace std;
const int MAXN = 300005;
typedef long long ll;
int n,k,a[MAXN],bin[MAXN<<1];
inline int gcd(int a, int b)
{
    int c;
    while(b)
    {
        c = a;
        a = b;
        b = c%b;
    }
    return a;
}
int sto[MAXN][19],sta[MAXN][19],stg[MAXN][19];
inline int queyro(int l, int r)
{
    int kk = bin[r-l+1];
    return sto[l][kk]|sto[r-(1<<kk)+1][kk];
}
inline int querya(int l, int r)
{
    int kk = bin[r-l+1];
    return sta[l][kk]&sta[r-(1<<kk)+1][kk];
}
inline int queryg(int l, int r)
{
    int kk = bin[r-l+1];
    return gcd(stg[l][kk],stg[r-(1<<kk)+1][kk]);
}
ll f[MAXN];
struct Node
{
    ll val1,val2,val3,maxf;
}pool[MAXN];
int tot,front = 0, rear = 0;
int nxt[MAXN],lst[MAXN];
inline void insert(Node t)
{
    pool[++tot] = t;
    nxt[rear] = tot, lst[tot] = rear;
    rear = tot;
    nxt[rear] = -1;
}
inline void remove(int pos)
{
    nxt[lst[pos]] = nxt[pos];
    if(pos!=rear) lst[nxt[pos]] = lst[pos];
    if(pos==rear) rear = lst[pos];
}
inline int read()
{
    char c = 0;
    int res = 0, flag = 1;
    while(c<'0'||c>'9')
    {
        c = getchar();
        if(c=='-') flag = -1;
    }
    while(c>='0'&&c<='9')
    {
        res = res*10+c-'0';
        c = getchar();
    }
    return res*flag;
}
inline void init()
{
    tot = 0;
    front = rear = 0;
    memset(nxt,0,sizeof(nxt));
    memset(lst,0,sizeof(lst));
    for(int i = 0; i<MAXN; i++)
        f[i] = -(1ll<<62);
    nxt[0] = -1;
}

int main()
{
    int T;
    cin >> T;
    while(T--)
    {
        init();
        cin >> n >> k;
        for(int i = 1; i<=n; i++)
            a[i] = read();
        for(int i = 1; i<=k; i++)
            f[i-1] = read();
        for(int cnt = 0, i = 1; i<=n; i <<= 1, cnt++)
            for(int j = i; j<(i<<1); j++)
                bin[j] = cnt;
        for(int i = 1; i<=n; i++)
            sto[i][0] = sta[i][0] = stg[i][0] = a[i];
        for(int j = 1; (1<<j)<=n; j++)
            for(int i = 1; i+(1<<j)-1<=n; i++)
            {
                sto[i][j] = sto[i][j-1]|sto[i+(1<<(j-1))][j-1];
                sta[i][j] = sta[i][j-1]&sta[i+(1<<(j-1))][j-1];
                stg[i][j] = gcd(stg[i][j-1],stg[i+(1<<(j-1))][j-1]); 
            }
        Node tmpnode;
        for(int i = k; i<=n; i++)
        {
            int tmppos = nxt[front];
            while(tmppos>=0)
            {
                pool[tmppos].val1 |= a[i];
                pool[tmppos].val2 &= a[i];
                pool[tmppos].val3 = gcd(pool[tmppos].val3,a[i]);
                tmppos = nxt[tmppos];
            }
            tmpnode.val1 = queyro(i-k+1,i);
            tmpnode.val2 = querya(i-k+1,i);
            tmpnode.val3 = queryg(i-k+1,i);
            tmpnode.maxf = f[i-k];
            insert(tmpnode);
            tmppos = nxt[front];
            while(nxt[tmppos]>=0&&tmppos>=0)
            {
                if(pool[nxt[tmppos]].val1==pool[tmppos].val1&&
                   pool[nxt[tmppos]].val2==pool[tmppos].val2&&
                   pool[nxt[tmppos]].val3==pool[tmppos].val3)
                {
                    pool[tmppos].maxf = max(pool[tmppos].maxf,pool[nxt[tmppos]].maxf);
                    remove(nxt[tmppos]);
                }
                tmppos = nxt[tmppos];
            }
            tmppos = nxt[front];
            while(tmppos>=0)
            {
                f[i] = max(f[i],pool[tmppos].maxf+pool[tmppos].val1+
                                pool[tmppos].val2+pool[tmppos].val3);
                tmppos = nxt[tmppos];
            }
        }
        cout << f[n] << endl;
    }
    return 0;
}