P3541 [POI2010] Monotonicity 题解

· · 题解

很容易想到这题的 DP。dp_{i,j} 表示目前到了 a_is 序列匹配到了第 j 位的最大长度。转移就从 dp_{t, (j-1==0)?k:j-1} 转移来,其中 t 要满足 s_j 的条件。然后容易看出这个转移可以用权值线段树来优化。然后复杂度就是 O(nk\log{n}) 的。这道题需要动态开点。

需要注意的是,符号序列与数字序列不是一一对应的,而是两个数字间的相对关系。所以要特殊处理一下序列最开始的位置。我这里是用 dp_{?,0} 来表示的。方案的话记录一下转移路径就行了。

#include <cstdio>
#include <cstring>
#include <utility>
#include <iostream>
#include <algorithm>
using namespace std;
typedef pair<int, int> pii;
#define px first
#define py second

const int K = 105;
const int N = 20005;
const int inf = 1000001;
int n, k, ans, top, bar[N], a[N], s[K], dp[N][K];
pii lst[N][K];

struct sgtree{
    int rt, tot = 0;
    struct segtree{
        int lc, rc, pl, maxn;
        #define lc(x) tree[x].lc
        #define rc(x) tree[x].rc
        #define pl(x) tree[x].pl
        #define maxn(x) tree[x].maxn 
    };
    segtree tree[N * 7];
    void pushup(int x){
        if(! lc(x)) maxn(lc(x)) = - inf;
        if(! rc(x)) maxn(rc(x)) = - inf;
        if(maxn(lc(x)) > maxn(rc(x))) maxn(x) = maxn(lc(x)), pl(x) = pl(lc(x));
        else maxn(x) = maxn(rc(x)), pl(x) = pl(rc(x));
    }
    void modify(int &x, int l, int r, int p, int id, int val){
        if(! x) x = ++ tot;
        if(l == r){maxn(x) = val; pl(x) = id; return;}
        int mid = (l + r) >> 1;
        if(p <= mid) modify(lc(x), l, mid, p, id, val);
        else modify(rc(x), mid + 1, r, p, id, val);
        pushup(x);
    }
    pii query(int x, int l, int r, int ql, int qr){
        if(ql > qr) return {0, - inf}; if(! x) return {0, - inf}; 
        if(ql <= l && r <= qr) return {pl(x), maxn(x)};
        int mid = (l + r) >> 1; pii q1 = {0, - inf}, q2 = {0, - inf};
        if(ql <= mid) q1 = query(lc(x), l, mid, ql, qr);
        if(qr > mid) q2 = query(rc(x), mid + 1, r, ql, qr);
        if(q1.py > q2.py) return {q1.px, q1.py};
        return {q2.px, q2.py};
    }
};
sgtree t[K];

int main(){
    scanf("%d %d", &n, &k);
    for(int i = 1; i <= n; i ++) scanf("%d", &a[i]);
    for(int i = 1; i <= k; i ++){
        char c; cin >> c;
        if(c == '<') s[i] = 1; else if(c == '=') s[i] = 2; else s[i] = 3;
    }
    memset(dp, - 0x3f, sizeof dp);
    for(int i = 1; i <= n; i ++){
        dp[i][0] = 1; pii tmp1, tmp2, tmp; int flag;
        if(s[1] == 1){
            tmp1 = t[k].query(t[k].rt, 1, inf, 1, a[i] - 1);
            tmp2 = t[0].query(t[0].rt, 1, inf, 1, a[i] - 1);
        }
        if(s[1] == 2){
            tmp1 = t[k].query(t[k].rt, 1, inf, a[i], a[i]);
            tmp2 = t[0].query(t[0].rt, 1, inf, a[i], a[i]);
        }
        if(s[1] == 3){
            tmp1 = t[k].query(t[k].rt, 1, inf, a[i] + 1, inf);
            tmp2 = t[0].query(t[0].rt, 1, inf, a[i] + 1, inf);
        }
        if(tmp1.py > tmp2.py) tmp.py = tmp1.py, tmp.px = tmp1.px, flag = k;
        else tmp.py = tmp2.py, tmp.px = tmp2.px, flag = 0;
        for(int j = k; j >= 2; j --){
            pii tmp; int op = j - 1;
            if(s[j] == 1) tmp = t[op].query(t[op].rt, 1, inf, 1, a[i] - 1);
            if(s[j] == 2) tmp = t[op].query(t[op].rt, 1, inf, a[i], a[i]);
            if(s[j] == 3) tmp = t[op].query(t[op].rt, 1, inf, a[i] + 1, inf);
            if(dp[i][j] < tmp.py + 1){
                dp[i][j] = tmp.py + 1;
                lst[i][j] = {tmp.px, op};
            }
            if(dp[i][j] > 0) t[j].modify(t[j].rt, 1, inf, a[i], i, dp[i][j]);
        }
        if(dp[i][1] < tmp.py + 1){
            dp[i][1] = tmp.py + 1;
            lst[i][1] = {tmp.px, flag};
        }
        if(dp[i][1] > 0) t[1].modify(t[1].rt, 1, inf, a[i], i, dp[i][1]);
        t[0].modify(t[0].rt, 1, inf, a[i], i, dp[i][0]);
    }
    for(int i = 1; i <= n; i ++) for(int j = 0; j <= k; j ++) ans = max(ans, dp[i][j]);
    int x, y;
    for(int i = 1; i <= n; i ++) for(int j = 0; j <= k; j ++) if(dp[i][j] == ans){
        x = i; y = j; break;
    }
    while(y){
        bar[++ top] = x; int tx = x, ty = y;
        x = lst[tx][ty].px; y = lst[tx][ty].py;
    }
    bar[++ top] = x;
    printf("%d\n", ans);
    for(int i = top; i >= 1; i --) printf("%d ", a[bar[i]]);
    puts(""); return 0;
}