题解 P4528 【[CTSC2008]图腾】

· · 题解

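记 $f[p]$ 为下标 $i_1<i_2<i_3<i_4$ 的四元组中相对大小满足模式 $p$ 的个数,其中 `x` 表示该位置取剩下的名次、顺序任意(例如 1x2x 表示第一个位置最小、第三个位置第二小)。由 $f[1x2x]=f[1324]+f[1423]$、$f[12xx]=f[1234]+f[1243]$、$f[14xx]=f[1423]+f[1432]$ 以及 $f[12xx]+f[13xx]+f[14xx]=f[1xxx]$ 可得:

$$
\begin{aligned}
f[1324]-f[1243]-f[1432]
&= \bigl(f[1x2x]-f[1423]\bigr)-\bigl(f[12xx]-f[1234]\bigr)-\bigl(f[14xx]-f[1423]\bigr)\\
&= f[1x2x]+f[1234]-f[12xx]-f[14xx]\\
&= f[1x2x]+f[1234]+f[13xx]-f[1xxx]
\end{aligned}
$$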

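设 $l_i, r_i$ 分别为位置 $i$ 左边、右边比 $a_i$ 小的数的个数(由于 $a$ 是排列,$r_i=a_i-1-l_i$),则位置 $i$ 右边比 $a_i$ 大的数有 $n-i-r_i$ 个。上式右边的四项都可以用 $l_i, r_i$ 配合树状数组在 $O(n\log n)$ 时间内求出。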

#include <bits/stdc++.h>
using namespace std;

using lint = long long;        // 64-bit type for intermediate products
constexpr int maxn = 200005;   // capacity of the 1-indexed global arrays
constexpr int mod = 1 << 24;   // 16777216 = 2^24, modulus for all arithmetic

int n;          // number of elements read from input
int ans;        // running answer, kept modulo `mod`
int a[maxn];    // input values a[1..n]; assumed to be a permutation of 1..n
int l[maxn];    // l[i] = count of j < i with a[j] < a[i]
int r[maxn];    // r[i] = count of j > i with a[j] < a[i]
int sum[maxn];  // Fenwick tree storage, cleared and reused by each pass in main

// Modular add-assign: a = a + b (mod `mod`).
// Assumes a and b already lie in [0, mod); only one wrap-around correction
// is applied.
inline void inc(int &a, int b)
{
    const int total = a + b;
    a = total < mod ? total : total - mod;
}

// Modular subtract-assign: a = a - b (mod `mod`).
// Assumes a and b already lie in [0, mod); a single +mod fixes any underflow.
inline void dec(int &a, int b)
{
    const int diff = a - b;
    a = diff >= 0 ? diff : diff + mod;
}

// Reads the next non-negative decimal integer from stdin, skipping any
// non-digit characters before it.
// Returns 0 if end of input is reached before a digit is found.
inline int gi()
{
    // getchar() returns int: storing it in a char makes EOF indistinguishable
    // from a real byte, and the skip loop below would never terminate.
    int c = getchar();
    while (c != EOF && (c < '0' || c > '9')) c = getchar();
    int value = 0;
    while (c >= '0' && c <= '9') {
        value = value * 10 + (c - '0');
        c = getchar();
    }
    return value;
}

// Lowest set bit of x (the Fenwick-tree step size). A function instead of the
// former unparenthesised macro, which mis-expanded compound arguments such as
// lowbit(i + j).
constexpr int lowbit(int x) { return x & -x; }

// Fenwick-tree point update: adds `delta` (modulo `mod`, via inc) to every
// tree node that covers position x, for positions up to the global length n.
inline void insert(int x, int delta)
{
    for (; x <= n; x += lowbit(x))
        inc(sum[x], delta);
}

// Fenwick-tree prefix query: returns the total (modulo `mod`) of all values
// inserted at positions 1..x.
inline int query(int x)
{
    int prefix = 0;
    for (; x != 0; x -= lowbit(x))
        inc(prefix, sum[x]);
    return prefix;
}

int main()
{
    freopen("picture.in", "r", stdin);
    freopen("picture.out", "w", stdout);

    n = gi();
    for (int i = 1; i <= n; ++i) a[i] = gi();

    for (int i = 1; i <= n; ++i) {
        l[i] = query(a[i]); r[i] = a[i] - l[i] - 1;
        insert(a[i], 1);
    }

    //1x2x
    memset(sum + 1, 0, sizeof(int) * n);
    for (int i = 1; i <= n; ++i) {
        inc(ans, (lint)(n - i - r[i]) * (i - 1) % mod * l[i] % mod);
        dec(ans, (lint)query(a[i]) * (n - i - r[i]) % mod);
        dec(ans, (lint)(l[i] - 1) * l[i] / 2 % mod * (n - i - r[i]) % mod);
        insert(a[i], i);
    }

    //1234
    memset(sum + 1, 0, sizeof(int) * n);
    for (int i = n; i >= 1; --i) {
        inc(ans, (lint)l[i] * (query(n) - query(a[i])) % mod);
        insert(a[i], n - i - r[i]);
    }

    //1xxx
    for (int t, i = 1; i <= n; ++i) {
        t = n - i - r[i];
        if (t >= 3) dec(ans, (lint)t * (t - 1) / 2 * (t - 2) / 3 % mod);
    }

    //13xx
    memset(sum + 1, 0, sizeof(int) * n);
    for (int i = 1; i <= n; ++i) {
        inc(ans, (lint)l[i] * r[i] % mod * (n - i - r[i]) % mod);
        inc(ans, (lint)l[i] * (l[i] - 1) / 2 % mod * (n - i - r[i]) % mod);
        dec(ans, (lint)query(a[i]) * (n - i - r[i]) % mod);
        insert(a[i], l[i] + r[i]);
    }

    printf("%d\n", ans);

    return 0;
}