题解:P14479 生成序列

· · 题解

首先 \mathcal{O}(n^3) 很明显有个 DP,我们设 dp_i 为考虑序列前 i 个元素的最大收益,实现如下:

for (int i=1;i<=n;i++)
{
    dp[i]=0;
    for(int j=0;j<i;j++)
    {
        int s=0;
        for (int k=1;k<=j;k++)
        {
            if(a[k]==a[j+k])
            {
                s=1;
                break;
            }
        }
        dp[i]=max(dp[i],dp[j]+s);
    }
}

考虑优化。

我们记一个 s 数组,s_i 表示是否存在一个 j\in[1,i] 使得 a_i=a_{i+j}

这样我们就可以 \mathcal{O}(n^2) 的先预处理出 s 然后再 \mathcal{O}(n^2) 的 dp,可以获得 29 分,代码如下:

:::success[29 分代码]

#include <bits/stdc++.h>
using namespace std;
#ifdef __linux__
#define gc getchar_unlocked
#define pc putchar_unlocked
#else
#define gc _getchar_nolock
#define pc _putchar_nolock
#endif
#define _ read<int>()
#define int long long
#define rint register int
inline bool blank(const char &x)
{
    return !(x ^ 32) || !(x ^ 10) || !(x ^ 13) || !(x ^ 9);
}
template<class T> inline T read()
{
    T r=0,f=1;char c=gc();
    while(!isdigit(c))
    {
        if(c=='-') f=-1;
        c=gc();
    }
    while(isdigit(c)) r=(r<<1)+(r<<3)+(c^48),c=gc();
    return f * r;
}
inline void out(rint x)
{
    if(x<0) pc('-'),x=-x;
    if(x<10) pc(x+'0');
    else out(x/10),pc(x%10+'0');
}
const int N=1e5+10;
int dp[N],a[N];
bitset<N>s;
inline int bf(const vector<int> &b)
{
    int n=b.size()-1;
    s=0;
    for(rint i=1;i<=(n>>1);i++)
    {
        for(rint j=1;j<=i;j++)
        {
            if(i+j>n) break;
            if(b[j]==b[j+i])
            {
                s[i]=1;
                break;
            }
        }
    }
    dp[0]=0;
    for(rint i=1;i<=n;i++)
    {
        dp[i]=dp[i-1];
        for(rint j=1;j<<1<=i;j++)
        {
            if(s[j])
            {
                dp[i]=max(dp[i],dp[j]+1);
            }
        }
    }
    return dp[n];
}

signed main()
{
    rint n=_,q=_;
    for(rint i=1;i<=n;i++)
    {
        a[i]=_;
    }
    while (q--)
    {
        rint l=_,r=_;
        rint len=r-l+1;
        vector<int> b(len+1);
        for(rint i=1;i<=len;i++)
        {
            b[i]=a[l+i-1];
        }
        out(bf(b));
        pc('\n');
    }
    return 0;
}

:::

我们发现 dp 的时候其实没必要再去从前枚举一下最大的 dp_j 所以直接记录一个前缀最大值就行,然后就会有 29 分,具体的我们可以记一个 pre 数组,pre_i 表示 dp_{1\dots i}+s_{1\dots i} 的最大值,转移:dp_i=pre_{i\div2}

容易发现这个优化完之后,我们只需解决怎么 \mathcal{O}(n\log n) 的时间内处理出 s,不过你发现对于这个 29 分的代码除了换一种枚举思路(也就是存一下每个值在哪里出现过可以拿到 52 分)以外,并没有好的解决方法,那么我们不妨换一个思路。

思考一下发现出题人卡我们的数据应该是什么呢?就是对于一个值,有很多次出现的地方,那我们考虑一下另一个思路,就是最开始的那份代码,我们第二个循环枚举的是什么?是步长,显然对于一个步长只需要记录一次那我们没必要全都枚举,如果这个步长已经被标记了,我们直接不枚举它。

具体来说:我们可以维护一个链表,链表里面存有 1\dots n\div 2,也就是步长,如果这个步长被标记了,就直接在链表里面删掉它,那么如果用这个思路对于 s 的预处理,代码如下:

:::success[链表优化] 初始化链表:

inline void initl(rint n) 
{
    head=1;
    for(rint i=1;i<=(n>>1);i++) nxt[i]=i+1;
    nxt[n>>1]=0;
}

预处理 s

inline void bff(const vector<int> &b)
{
    rint n=b.size()-1,le=(n>>1);
    initl(n);
    for(rint i=1;i<=le&&head;i++)
    {
        rint prev=0;
        for(rint d=head;d;)
        {
            rint nd=nxt[d];
            if(b[i]==b[i+d]&&i<=d&&i+d<=n)
            {
                s[d]=1;
                if(prev) nxt[prev]=nd;
                else head=nd;
            }
            else prev=d;
            d=nd;
        }
    }
}

:::

那么我们就可以在原来的代码的基础上加上这个优化就可以过。

当然我们得先计算上面的时间复杂度(52 分),其实是

\sum{cnt_i^2}

其中 cnt_i 等于值为 i 的元素在序列出现过几次。

所以当这个值超过 3\times 10^8 我们就使用链表优化。(因为链表优化的常数大)

:::success[Ac Code]

#include <bits/stdc++.h>
using namespace std;
#ifdef __linux__
#define gc getchar_unlocked
#define pc putchar_unlocked
#else
#define gc _getchar_nolock
#define pc _putchar_nolock
#endif
#define R register 
#define _ read<int>()
// #define int long long
#define rint register int
inline bool blank(const char &x)
{
    return !(x ^ 32) || !(x ^ 10) || !(x ^ 13) || !(x ^ 9);
}
template<class T> inline T read()
{
    T r=0,f=1;R char c=gc();
    while(!isdigit(c))
    {
        if(c=='-') f=-1;
        c=gc();
    }
    while(isdigit(c)) r=(r<<1)+(r<<3)+(c^48),c=gc();
    return f * r;
}
inline void out(rint x)
{
    if(x<0) pc('-'),x=-x;
    if(x<10) pc(x+'0');
    else out(x/10),pc(x%10+'0');
}
const int N=1e6+10;
int dp[N],a[N],pre[N];
bitset<N>s;
vector<int> pos[N];
int cnt[N];
int nxt[N],head;
inline void initl(rint n) 
{
    head=1;
    for(rint i=1;i<=(n>>1);i++) nxt[i]=i+1;
    nxt[n>>1]=0;
}
inline void bf(const vector<int> &b)
{
    rint n=b.size()-1;
    for(rint i=1;i<=n;i++) 
    {
        rint c=b[i];
        for(rint j=0,_sz=pos[c].size();j<_sz;j++)
        {
            rint lp=pos[c][j];
            if(lp>(i>>1)) break;
            rint d=i-lp;
            if(d<=(n>>1)) s[d]=1;
        }
        pos[c].push_back(i);
    }
    for(rint i=0;i<=N;i++) pos[i].clear();
}

inline void bff(const vector<int> &b)
{
    rint n=b.size()-1,le=(n>>1);
    initl(n);
    for(rint i=1;i<=le&&head;i++)
    {
        rint prev=0;
        for(rint d=head;d;)
        {
            rint nd=nxt[d];
            if(b[i]==b[i+d]&&i<=d&&i+d<=n)
            {
                s[d]=1;
                if(prev) nxt[prev]=nd;
                else head=nd;
            }
            else prev=d;
            d=nd;
        }
    }
}

inline int ans(const vector<int> &b)
{
    rint n=b.size()-1;
    dp[0]=pre[0]=0;
    rint t,p;
    for(rint i=1;i<=n;i++)
    {
        t=i>>1;
        dp[i]=max(dp[i-1],pre[t]);
        p=dp[i];
        if(i<=(n>>1)&&s[i]) p++;
        pre[i]=max(pre[i-1],p);
    }
    s=0;
    return dp[n];
}

signed main()
{
    rint n=_,q=_;
    for(rint i=1;i<=n;i++) a[i]=_;
    while (q--)
    {
        rint l=_,r=_;
        rint len=r-l+1;
        vector<int> b(len+1);
        for(rint i=1;i<=len;i++) b[i]=a[l+i-1];

        bool f=0;
        rint sum=0;
        for(rint i=1;i<=len;i++) 
        {
            ++cnt[b[i]];
            sum+=cnt[b[i]];
            if(sum>1e8) {f=1;break;}
        }
        if(f) bff(b);
        else bf(b);
        out(ans(b));
        pc('\n');
        memset(cnt,0,sizeof(cnt));
    }
    return 0;
}

:::