P8349 题解

· · 题解

upd on 2022-05-17:补充了一些证明。

SDOI 2022 D1T1。在考场上没有想出来小块对大块的做法,写了个 O(q(len_x+len_y)) 的暴力喜提 \text{30pts}

Part 1 : 考虑暴力该怎么做。

对于每次询问,首先我们将 a_i=xa_i=y 的所有 i 按照顺序放到一起。令这些 i 组成的序列为 p,令 len_x 表示整个 a 序列中 x 的数量,容易得到 p 的长度为 len_x+len_y

那么题意可以转化为,找出一个区间 [l,r],使得 \sum\limits_{i=l}^{r}[a_{p_i}=x]=\sum\limits_{i=l}^{r}[a_{p_i}=y],并令 \sum\limits_{i=l}^{r}b_{p_i} 最大。

设当前考虑到了第 i 个数,在 p 序列的前 i 个数中有 c_1 个是 xc_2 个是 y,当前的 b 数列的前缀和为 sum_i。令 q_i=c_1-c_2+len_y,可以发现,对两个相等的 q_l,q_r,区间 [l,r] 是一个满足题意的区间。

考虑维护当前每个 q_i 所对应的 sum_i 的最小值 mn_{q_i}。那么在计算到 i 时,只需将答案与 sum_i-mn_{q_i} 比较取最大值即可。另外若 c_1=c_2,即 q_i=len_y 时,答案也可以直接更新为当前的前缀和。

时间复杂度为 O(q(len_x+len_y))

Part 2 : 考虑根号平衡。

思考为什么上面的暴力无法通过此题。假设 a 序列是这样构成的:其中有 \dfrac{n}{2}1,剩余 \dfrac{n}{2} 个全部不同。那么只要每次询问 1,i,时间复杂度就将达到 O(n^2)

我们根据 len_i 大小是否达到 B,将数分为 “大块” 和 “小块” 两类。

下面分三类讨论:

1. 小块对小块。显然直接套用暴力即可。复杂度为 O(qB)

2. 大块对大块。因为大块的长度要超过 B,所以大块最多有 \dfrac{n}{B} 个。本质不同的询问数只有 \dfrac{n^2}{B^2} 种(对相同的询问,我们将答案用 unordered_map 存下来,如果遇到相同的询问直接调用答案即可),直接套用暴力,这部分的复杂度最大为 O\left(\dfrac{n^2}{B}\right)

3. 小块对大块。现在我们不能将两部分直接拼起来了。

考虑利用小块的大小不超过 B 的性质。根据暴力的做法,我们将 p 序列中的数分为有效点无效点两类。

显然,小块中每个点都是需要考虑在内的有效点。

然而大块中,许多点对答案是不可能产生贡献的。这是因为,选中的区间 [l,r] 中,xy 的数量必须相等,而大块因为点数多,而往往导致其远远压过小块,所以有很多点是无效点。

具体的,我们仍然考虑 p 序列,令 len_x\leq B<len_y

p 中有一段极长的 x 连续子段,其长度为 l,这个连续子段的左右端点分别为 L,R。那么,p 序列中只有在 L 左侧和 R 右侧的 l+1y 才可能成为有效点。这是因为,更多的 y 会压过 x 的个数。我们把这些点加入真正计算的 p 序列中。

另外注意,若 L 左侧 l+1y 中间还包含了另外一些 x 的极长连续子段,则这些 x 子段向两侧拓展时要跳过这些已经成为有效点的 y。这个过程可以使用 set 去快速维护。

得到有效点后将其离散化,再套用暴力即可。有效点的个数始终只与大小 \leq B 的小块相关,所以是 O(B) 级别的。

考虑到大块的数量不多,且复杂度取决于实际的小块长度,这部分的复杂度应当为 O\left(\dfrac{n^2\log n}{B}\right)。虽然有效点个数带大常数,但从下面的复杂度分析来看,实际上还要带一个非常小的常数,整体来看常数较小。

大块对大块的复杂度较小直接忽略,总复杂度为 O\left(qB+\dfrac{n^2\log n}{B}\right)。由均值不等式,在 B=\sqrt{\dfrac{n^2\log n}{q}} 时,复杂度为 O\left(n\sqrt{q\log n}\right) 最优。这式子看起来好像能达到 10^9,不过因为跑不满 + 常数小,而且时限 7s,所以是可以通过的。

按理论计算出来 B=1200 左右最优,但由于数据非常离谱,实际上取 100,200,300,550 时间都差不多,但是取 1000 的时候反而会 TLE。这样就通过了此题。

代码中为了简便,拓展 l+1 个有效点时直接是拓展了 2l 个。

upd 1:补充大块对大块、小块对大块的复杂度的证明。

大块对大块:假设有 k 个大块,第 i 个的长度为 len_i。极限情况下令 \sum\limits_{i=1}^{k}len_i=n

考虑令这 k 个大块两两配对进行询问,共进行 \dfrac{k(k-1)}{2} 次询问。总计算次数为 \sum\limits_{1\leq i<j\leq k}len_i+len_j=(k-1)\sum\limits_{i=1}^{k}len_i=n(k-1)

由于 len_i\geq B,得 k\leq\dfrac{n}{B},故极限复杂度为 O\left(\dfrac{n^2}{B}\right)

小块对大块:设有 k_1 个大块,第 i 个的长度为 l_i;有 k_2 个小块,第 i 个的长度为 L_i。则有 \sum\limits_{i=1}^{k_1}l_i+\sum\limits_{i=1}^{k_2}L_i=n

考虑到小块对大块的复杂度只取决于小块长度,我们用最长的 \dfrac{q}{k_1} 个小块对着大块进行询问。极限情况下,有 k_2=\dfrac{q}{k_1},也即每个小块都对着每个大块查了一次。总计算次数为 \sum\limits_{i=1}^{k_2}\sum\limits_{j=1}^{k_1}L_i\log l_j\leq k_1\log n\sum\limits_{i=1}^{k_2}L_i

最强的情况下,可构造 \dfrac{n}{2B} 个长度略大于 B 的大块,和总长度大约为 \dfrac{n}{2} 的一些小块,此时 k_1\approx\dfrac{n}{B},\sum\limits_{i=1}^{k_2}L_i\approx n,故极限复杂度为 O\left(\dfrac{n^2\log n}{B}\right)。此时要有 k_2\leq\dfrac{qB}{n}

code:

#include<iostream>
#include<cstdio>
#include<set>
#include<algorithm>
#include<vector>
#include<unordered_map>
using namespace std;
typedef long long ll;
const int blo=300;
int n,q,x,y,a[300005],len[300005],st[300005],tim[300005],pre[300005],cl[300005];
ll b[300005],mn[300005];unordered_map<int,ll> mp[300005];set<int> pos[300005];
struct node{int id;ll v;} xl[300005];
vector<node> tj[300005];vector<ll> qz[300005];
template<typename T> void Read(T &x)
{
    x=0;bool f=0;
    char ch=getchar();
    while(ch<'0'||ch>'9') f=f||(ch=='-'),ch=getchar();
    while(ch>='0'&&ch<='9') x=x*10+(ch^48),ch=getchar();
    x=f==0?x:-x;
}
int main()
{
    Read(n);Read(q);
    for(int i=1;i<=n;i++) Read(a[i]),tim[i]=tim[pre[a[i]]]+1,pre[a[i]]=i;
    for(int i=1;i<=n;i++) Read(b[i]);
    for(int i=1;i<=n;i++) tj[a[i]].push_back({i,b[i]}),pos[a[i]].insert(i);
    for(int i=1;i<=n;i++)
        if(tj[i].size())
        {
            len[i]=tj[i].size();qz[i].resize(len[i]+1);
            for(int j=0;j<len[i];j++) qz[i][j+1]=qz[i][j]+tj[i][j].v;
        }
    for(int i=0;i<=n;i++) mn[i]=1ll<<60;
    while(q--)
    {
        Read(x);Read(y);
        if(mp[x].find(y)!=mp[x].end()){printf("%lld\n",mp[x][y]);continue;}
        if((len[x]>blo&&len[y]>blo)||(len[x]+len[y]<=blo))
        {
            int tl=0,tr=0;ll ans=-(1ll<<60);
            while(tl<len[x]&&tr<len[y])
            {
                if(tj[x][tl].id<tj[y][tr].id) xl[tl+tr]=tj[x][tl],tl++;
                else xl[tl+tr]=tj[y][tr],tr++;
            }
            while(tl<len[x]) xl[tl+tr]=tj[x][tl],tl++;
            while(tr<len[y]) xl[tl+tr]=tj[y][tr],tr++;
            int cur1=0,cur2=0;
            for(int i=0;i<len[x]+len[y];i++)
            {
                if(a[xl[i].id]==x) cur1++;
                if(a[xl[i].id]==y) cur2++;
                int sl=cur1-cur2+len[y];
                if(sl==len[y]) ans=max(ans,qz[x][cur1]+qz[y][cur2]);
                ans=max(ans,qz[x][cur1]+qz[y][cur2]-mn[sl]);
                mn[sl]=min(mn[sl],qz[x][cur1]+qz[y][cur2]);cl[i]=sl;
            }
            mp[x][y]=mp[y][x]=ans;printf("%lld\n",ans);
            for(int i=0;i<len[x]+len[y];i++) mn[cl[i]]=1ll<<60;
        }
        else
        {
            if(len[x]>len[y]) swap(x,y);
            int top=0;ll ans=-(1ll<<60);
            for(int i=0;i<len[x];i++)
            {
                if(!pos[y].size()) continue;
                auto tmp=pos[y].lower_bound(tj[x][i].id);
                if(tmp==pos[y].begin()) continue;--tmp;st[++top]=*tmp;
                auto del=tmp;if(tmp==pos[y].begin()) continue;--tmp;
                st[++top]=*tmp;pos[y].erase(del);pos[y].erase(tmp);
            }
            for(int i=0;i<len[x];i++)
            {
                if(!pos[y].size()) continue;
                auto tmp=pos[y].upper_bound(tj[x][i].id);
                if(tmp==pos[y].end()) continue;st[++top]=*tmp;
                auto del=tmp;++tmp;if(tmp==pos[y].end()) continue;
                st[++top]=*tmp;pos[y].erase(del);pos[y].erase(tmp);
            }
            for(int i=1;i<=top;i++) pos[y].insert(st[i]);
            for(int i=0;i<len[x];i++) st[++top]=tj[x][i].id;
            sort(st+1,st+1+top);int cnt=unique(st+1,st+1+top)-st-1;
            for(int i=1;i<=cnt;i++) xl[i]={st[i],b[st[i]]};
            int cur1=0,cur2=0,las1=0,las2=0;
            for(int i=1;i<=cnt;i++)
            {
                if(a[xl[i].id]==x) cur1+=tim[xl[i].id]-las1,las1=tim[xl[i].id];
                if(a[xl[i].id]==y) cur2+=tim[xl[i].id]-las2,las2=tim[xl[i].id];
                int sl=cur2-cur1+len[x];
                if(sl==len[x]) ans=max(ans,qz[x][cur1]+qz[y][cur2]);
                ans=max(ans,qz[x][cur1]+qz[y][cur2]-mn[sl]);
                mn[sl]=min(mn[sl],qz[x][cur1]+qz[y][cur2]);cl[i]=sl;
            }
            mp[x][y]=mp[y][x]=ans;printf("%lld\n",ans);
            for(int i=1;i<=cnt;i++) mn[cl[i]]=1ll<<60;
        }
    }
    return 0;
}