题解:P11822 [湖北省选模拟 2025] 团队分组 / divide

· · 题解

首先有一个 n^2\log n 的贪心,我们对于每个 k 从后往前分段,每段的长度尽可能短。这个做法常数很小,但是一个递减序列就把我们卡满了。

首先可以证明我们一共最多有 \sqrt n 种段如果对于两个不同的 ka 为第一个序列的 a 数组,b 为第二个序列的 a 数组 ,有 a_i=b_i,a_{i+1}=b_{i+1} 那么 \{a_p=b_p|p\le i\}。于是我们可以记忆化区间,这样总复杂度就是 n\sqrt n\log n

考虑到 hash 表常数很大,所以我们在 \displaystyle\sum_{i=1}^n k=i\texttt{时的段数} 很小时直接跑暴力就可以了。

注意要使用 hash 表存状态,而不是记录上次的指针位置并二分后面几个指针的新位置,似乎存在一种构造方式可以使这种做法所有指针在两个集合中反复横跳导致指针总移动次数很大。

代码。

#include<bits/stdc++.h>
#include<bits/extc++.h>
using namespace std;
#define int long long
const int N = 1e5+100;
int n,a[N],s[N];
int p[N],cnt,e[N],res,ls;
bool q[N];
__gnu_pbds::gp_hash_table<int,int> mp1[N],mp2[N];
inline int solve1(int x)
{
    int res=0;
    int now=x-1;
    cnt=0;
    p[++cnt]=x+1,e[cnt]=-1,p[++cnt]=x,e[cnt]=a[x],ls=a[x];
    while(s[now]>ls)
    {
        int l=1,r=now,mid,ans=0;
        while(l<=r)
        {
            mid=(l+r)>>1;
            if(s[now]-s[mid-1]>ls)ans=mid,l=mid+1;
            else r=mid-1;
        }
        ls=s[now]-s[ans-1];
        if(mp1[ans].find(ls)!=mp1[ans].end())
        {
            int tmp=mp2[ans][ls];
            res=mp1[ans][ls];
            for(int i=cnt,j=1;i;j++,i--)res+=p[i]*(j+tmp),mp1[p[i]][e[i]]=res,mp2[p[i]][e[i]]=tmp+j;
            return res;
        }
        now=ans-1;
        p[++cnt]=ans;
        e[cnt]=ls;
    }
    int tmp=0;
    for(int i=cnt,j=1;i;j++,i--)res+=p[i]*(j+tmp),mp1[p[i]][e[i]]=res,mp2[p[i]][e[i]]=tmp+j;
    return res;
}
inline int solve(int x)
{
    int su=0,res=0;
    int now=x-1;
    cnt=0;
    int ls=a[x];
    res=x*3+2,su=x*2+1;
    while(s[now]>ls)
    {
        int l=1,r=now,mid,ans=0;
        while(l<=r)
        {
            mid=(l+r)>>1;
            if(s[now]-s[mid-1]>ls)ans=mid,l=mid+1;
            else r=mid-1;
        }
        ls=s[now]-s[ans-1];
        now=ans-1;
        su+=ans;
        res+=su;
    }
    return res;
}
signed main()
{
    int st=clock();
    ios::sync_with_stdio(0),cin.tie(0),cout.tie(0);
    __gnu_pbds::gp_hash_table<int,int> cnt;
    cin>>n;
    for(int i=1;i<=n;i++)cin>>a[i],cnt[a[i]]++,s[i]=s[i-1]+a[i];
    bool tmp=0;
    for(auto i:cnt)if(i.second>n/4)tmp=1;
    if(!tmp)
    {
        for(int i=1;i<=n;i++)
            cout<<solve1(i)<<' ';
    }
    else
    {
        for(int i=1;i<=n;i++)
            cout<<solve(i)<<' ';
    }
    cerr<<(double)(clock()-st)/CLOCKS_PER_SEC;
}