SMAWK 学习笔记

· · 题解

APIO2021 讲课 \color{black}{I}\red{tst} 讲的一个神仙算法。

此题大致题意是,给定一个 n\times m 矩阵 A 满足 1\leq x<y \leq n,min_x(A)\leq min_y(A)min_k(A)A 中第 k 行最小值的位置。要通过不超过 4\times(n+m) 次查询单个元素的值求出矩阵中的最小值。

这里直接引入 SMAWK 算法。

首先介绍 SMAWK 算法的一个重要过程:reduce 过程。

首先我们定义,当一个位置不可能是当前行的最小值位置,那么称其为冗余位置。

那么我们发现每一列上面冗余位置一定包括了一个列上的前缀。接下来我们考虑如何删除冗余的列来得到 n\times \min(n,m) 的矩阵。

我们维护一个指针 k 初始为 1 ,然后进行如下操作:

1:当 n ≥ m 时结束,否则比较 a_{k,k}a_{k,k+1}

2:若 a_{k,k} ≥ a_{k,k+1},删除第 k 列,k = \max(k-1,1) ,回到步骤 1;

3:若 a_{k,k} < a_{k,k+1}k=n ,删除第 n+1 列,回到步骤 1;

4:若 a_{k,k} < a_{k,k+1}k < nk=k+1,回到步骤 1。

接下来我们考虑这么做为什么是对的:

我们目前维护的矩阵 k\times k 显然在对角线以上(注意不包括对角线)的必定是冗余位置,否则中间一定有冗余列。

a_{k,k} ≥ a_{k,k+1} 那么显然第 k 整列都是冗余列,因为 a_{k,1……k-1} 部分已经确定为冗余位置,而 a_{k,k} 也被确定为冗余位置也直接导致整列被确定为冗余列,于是删除此列。

否则我们考虑,若 a_{k,k} < a_{k,k+1}k=n 那么显然下一列必定是冗余列,删除即可,否则我们还不能确定下一列是否为冗余列,将 k++ 即可。

于是 reduce 过程结束了,返回一个 n\times \min(n,m) 的矩阵 b

考虑一次 reduce 会查询几个值,每次我们会发现 k 要么加 1 或减 1 且删除一列,查询的值差不多是 O(n+m) 个。

接下来引入 SMAWK 的主体 SMAWK(a)

接下来我们考虑将 b 分成两部分,奇数行与偶数行分开,这样子我们就把这个矩阵分成了两个集合 a,c ,我们先运行 SMAWK(c),求出偶数行每行的最小值位置。

接下来我们根据这个矩阵的性质,可以得出第一行最小值所在的位置 p_1 一定在 1 ~ p_2 之间 p_3 一定在 p_2 ~ p_4 之间……以此类推。

然后我们这部分暴力查询就得到了 a 中每一行的最小值的位置,一次查询了 O(n+m) 个值。

接下来我们分析一下总体的复杂度:

f(n,m)n\times m 的矩阵进行 SMAWK 的复杂度,那么可以得到当 n\leq m 时有 f(n,m)=f(\frac{n}{2},n))+n+mn>m 时为 f(n,m)=f(\frac{n}{2},m)+n+m 并且 n>m 的情况递归 \log\frac{n}{m} 层后变为了 n\leq m 的情况,所以总体复杂度是 O(m(1 + \log\frac{n}{m}) + n)

具体实现我用的是链表进行删除列的操作,子矩阵用维护列与行的集合操作。


//QwQcOrZ yyds!!!
#include<bits/stdc++.h>
#define ll long long
#define F(i,a,b) for (int i=(a);i<=(b);i++)
#define R(i,a,b) for (int i=(a);i<(b);i++)
#define D(i,a,b) for (int i=(a);i>=(b);i--)
#define go(i,x) for (int i=head[x];i;i=e[i].nx)
#define mp make_pair
#define pb push_back
#define pa pair < int,int >
#define fi first
#define se second
#define re register
#define be begin()
#define en end()
#define sqr(x) ((x)*(x))
//#define ret return puts("-1"),0;
#define put putchar('\n')
#define inf 1e18
#define mod 998244353
//#define int ll
#define N 1000005
using namespace std;
inline char gc(){static char buf[100000],*p1=buf,*p2=buf;return p1==p2&&(p2=(p1=buf)+fread(buf,1,100000,stdin),p1==p2)?EOF:*p1++;}
#define gc getchar
inline ll read(){char c=gc();ll su=0,f=1;for (;c<'0'||c>'9';c=gc()) if (c=='-') f=-1;for (;c>='0'&&c<='9';c=gc()) su=su*10+c-'0';return su*f;}
inline void write(ll x){if (x<0){putchar('-');write(-x);return;}if (x>=10) write(x/10);putchar(x%10+'0');}
inline void writesp(ll x){write(x),putchar(' ');}
inline void writeln(ll x){write(x);putchar('\n');}
int pre[N],suf[N],M[N],n,m,ans=1e18,P=0;
vector<int>L,H;
map<int,int>Mp[N];
int query(int x,int y,int z)
{
    if (Mp[x][y]) return Mp[x][y];
    cout<<"? "<<x<<" "<<y<<endl;
    int p;
    cin>>p;
    Mp[x][y]=p;
    return p;
}
void del(int x)
{
    if (pre[x]!=-1)
        suf[pre[x]]=suf[x]; else P=suf[x];
    pre[suf[x]]=pre[x];
}
vector<int> reduce(vector<int>X,vector<int>Y)
{
    for (int i=0;i<Y.size();i++) pre[i]=i-1,suf[i]=i+1;
    int x=0,y=0;
    P=0;
    for (int nmsl=Y.size()-X.size();nmsl>0;nmsl--)
    {
        if (query(X[x],Y[y],0)>query(X[x],Y[suf[y]],0))
        {
            y=suf[y];
            del(pre[y]);
            if (x) y=pre[y],x--;
        } else
        {
            if (x==X.size()-1) del(suf[y]);
            else
            {
                y=suf[y];
                x++;
            }
        }
    }
    vector<int>ret;
    for (int i=P;i!=Y.size();i=suf[i])  ret.push_back(Y[i]);
    return ret;
}
void Solve(vector<int>X,vector<int>Y)
{
    Y=reduce(X,Y);
    if (X.size()==1)
    {
        M[X[0]]=Y[0];
        return;
    }
    vector<int>Z;
    for (int i=0;i<X.size();i++)
        if (!(i%2)) Z.push_back(X[i]);
    Solve(Z,Y);
    for (int i=0;i<X.size();i++)
    {
        if (!(i%2)) continue;
        int l=lower_bound(Y.begin(),Y.end(),M[X[i-1]])-Y.begin();
        int r=0;
        if (i==X.size()-1) r=Y.size()-1;
        else 
        {
            r=lower_bound(Y.begin(),Y.end(),M[X[i+1]])-Y.begin();
        }
        M[X[i]]=Y[l];
        while (l<r)
        {
            l++;
            if (query(X[i],Y[l],1)<query(X[i],M[X[i]],1)) M[X[i]]=Y[l];
        }
    }
}
signed main()
{
    cin>>n>>m;
    for (int i=1;i<=n;i++)
    {
        H.push_back(i);
    }
    for (int i=1;i<=m;i++) L.push_back(i);
    Solve(H,L);
    for (int i=1;i<=n;i++)
        ans=min(ans,query(i,M[i],1));
    cout<<"! "<<ans<<endl;

}
/*

*/