题解 P6854 【Tram】

· · 题解

满足第二个条件的区间不会很多,而判断这样的一个区间是否满足第一个条件只要查询最值,所以我们只需要找出所有满足第二个条件的区间,这可以枚举左端点,用线段树维护合法的右端点集合,复杂度O(Ans\times logn).

我们现在需要估计满足第二个条件的区间到底有多少个,一个显然的上界是O(n\sqrt{n}),但这个界很松.

考虑把序列中所有数按照权值从小到大加入.假设我们现在加入的是权值为x的所有数,一共k个,在加入之前序列中合法子区间的数量是Cnt.现在我们考虑加入的数对答案的影响,把所有区间分成下面几类.

原本的合法区间,没有被加入任何数.

原本的合法区间,中间恰好被加入了xx.

以上这些区间不以x开头也不以x结尾,至多Cnt个.

原本的合法区间,在开头加上了xx

原本的合法区间,在末尾加上了xx

x开头结尾的合法区间.

第三类区间的数量至多是k-x+1.

将连续的xx称为一块,考虑相邻的两个块,如果它们之间的距离是d,那么它们至多会对区间个数产生2\times min(d,x-1)的贡献.所以第一二类区间的数量至多是(\frac{k}{x}+1)\times 2\times (x-1).

容易看出最终合法区间的个数是O(n)级别的.

线段树,O(nlog{n}).

PS:这题实现的时候需要注意一下常数问题,不恰当的实现方式可能会导致2到3倍(甚至30倍)的常数。而另外一些log做法的常数似乎也要这个级别......

bonus:找到线性做法。

以下代码来自EndSaH


#include <cstdio>
#include <iostream>
#include <vector>
#include <algorithm>

using namespace std;
using LL = long long;

const int maxN = 1e6 + 5;

int n, m, ans, rig;
vector<int> pos[maxN];
int a[maxN];

template<typename _Tp>
inline bool Chkmin(_Tp& x, const _Tp& y)
{ return x > y ? x = y, true : false; }

template<typename _Tp>
inline bool Chkmax(_Tp& x, const _Tp& y)
{ return x < y ? x = y, true : false; }

int read()
{
    int x = 0;
    char ch;
    while (!isdigit(ch = getchar()));
    while (x = x * 10 + (ch & 15), isdigit(ch = getchar()));
    return x;
}

namespace RMQ
{
    int lg2[maxN];
    int mini[22][maxN], maxi[22][maxN];

    void Init()
    {
        for (int i = 2; i <= n; ++i)
            lg2[i] = lg2[i >> 1] + 1;

        for (int i = 1; i <= n; ++i)
            mini[0][i] = maxi[0][i] = a[i];
        for (int i = 1; i < 21; ++i)
            for (int j = 1; j + (1 << i) - 1 <= n; ++j)
            {
                mini[i][j] = min(mini[i - 1][j], mini[i - 1][j + (1 << (i - 1))]);
                maxi[i][j] = max(maxi[i - 1][j], maxi[i - 1][j + (1 << (i - 1))]);
            }
    }

    int Query_min(int l, int r)
    {
        int t = lg2[r - l + 1];
        return min(mini[t][l], mini[t][r - (1 << t) + 1]);
    }

    int Query_max(int l, int r)
    {
        int t = lg2[r - l + 1];
        return max(maxi[t][l], maxi[t][r - (1 << t) + 1]);
    }

    bool Check(int l, int r)
    {
        //cout << l << ' ' << r << endl;
        int a = Query_min(l, r), b = Query_max(l, r);
        return LL(a + b) * (b - a + 1) / 2 == r - l + 1;
    }
}

namespace SEG
{
    int maxi[maxN * 4], tag[maxN * 4];

    inline void add(int x, int addv)
    { maxi[x] += addv, tag[x] += addv; }

    inline void Pushdown(int x)
    {
        if (tag[x])
        {
            add(x << 1, tag[x]), add(x << 1 | 1, tag[x]);
            tag[x] = 0;
        }
    }

    inline void Pushup(int x)
    { maxi[x] = max(maxi[x << 1], maxi[x << 1 | 1]); }

    void Add(int ql, int qr, int addv, int l = 1, int r = n, int x = 1)
    {
        if (ql <= l and r <= qr)
        {
            add(x, addv);
            return;
        }

        int mid = (l + r) >> 1;
        Pushdown(x);
        if (ql <= mid)
            Add(ql, qr, addv, l, mid, x << 1);
        if (mid < qr)
            Add(ql, qr, addv, mid + 1, r, x << 1 | 1);
        Pushup(x);
    }

    void Find(int l = 1, int r = n, int x = 1)
    {
        if (maxi[x] != m)
            return;
        if (l == r)
        {
            ans += RMQ::Check(l, rig);
            return;
        }
        int mid = (l + r) >> 1;
        Pushdown(x);
        Find(l, mid, x << 1);
        Find(mid + 1, r, x << 1 | 1);
    }
}

int main()
{
   // freopen("tram.in", "r", stdin);
   // freopen("tram.out", "w", stdout);
    ios::sync_with_stdio(false);
    n = read();
    for (int i = 1; i <= n; ++i)
        a[i] = read(), m = max(m, a[i]);
    RMQ::Init();

    for (int i = 1; i <= m; ++i)
        pos[i].push_back(0);

    for (int i = 1; i <= n; ++i)
    {
        rig = i;
        int t = a[i];
        SEG::Add(i, i, m);
        SEG::Add(pos[t].back() + 1, i, -1);
        pos[t].push_back(i);

        if ((int)pos[t].size() >= t + 1)
        {
            int a = pos[t].size() - t - 1;
            SEG::Add(pos[t][a] + 1, pos[t][a + 1], 1);
            if (a > 0)
                SEG::Add(pos[t][a - 1] + 1, pos[t][a], -1);
        }
        SEG::Find();
    }
    cout << ans << endl;

    return 0;
}