题解 P3521 【[POI2011]ROT-Tree Rotations】

· · 题解

这是一道线段树合并的棵题

每个节点开一个权值线段树,父亲的线段树由儿子的合并而来,在合并的过程中即可计算答案

这样我们递归处理,每一层只考虑跨越两颗子树的答案

考虑合并lx,rx两颗树,在交换之前,答案是

sum[rs[lx]] * sum[ls[rx]]

在交换之后,答案是

sum[ls[lx]] * sum[rs[rx]]

可以手推,十分易得

还有一个问题,这题卡内存,那么可以写个内存池

#include<cstdio>
#include<iostream>
#include<cstring>
#include<queue>
#include<vector>
#include<map>
#include<climits>
#include<algorithm>
using namespace std;
typedef int mainint;
typedef double db;
typedef long long ll;
#define il inline 
#define pii pair<ll,int>
#define mp make_pair
#define B cout << "breakpoint" << endl;
#define O(x) cout << #x << "  " << x << endl;
inline int read()
{
    int ans = 0,op = 1;
    char ch = getchar();
    while(ch < '0' || ch > '9')
    {
        if(ch == '-') op = -1;
        ch = getchar();
    }
    while(ch >= '0' && ch <= '9')
    {
        (ans *= 10) += ch - '0';
        ch = getchar();
    }
    return ans * op;
}
const int maxn = 5e5 + 5;
int ls[maxn],rs[maxn],tot;
int sum[maxn];
int n;
ll ans1,ans2,ans;
int st[maxn],top;
il void trash(int i)
{
    ls[i] = rs[i] = sum[i] = 0; st[++top] = i;
}
il void up(int i)
{
    sum[i] = sum[ls[i]] + sum[rs[i]];
}
void update(int &i,int l,int r,int x,int d)
{
    if(!i) i = !top ? ++tot : st[top--];
    if(l == r && l == x)
    {
        sum[i] += d;
        return;
    }
    int mid = l + r >> 1;
    if(x <= mid) update(ls[i],l,mid,x,d);
    else update(rs[i],mid + 1,r,x,d);
    up(i);
}
void merge(int &lx,int rx)
{
    if(!lx || !rx)  
    { 
        lx = lx + rx; 
        //if(rx) trash(rx);
        return; 
    }
    sum[lx] = sum[lx] + sum[rx];
    ans1 += 1ll * sum[rs[lx]] * sum[ls[rx]];
    ans2 += 1ll * sum[ls[lx]] * sum[rs[rx]];
    merge(ls[lx],ls[rx]);
    merge(rs[lx],rs[rx]);
    trash(rx);
}
void solve(int &x)
{
    int t = read();
    int leftson,rightson;
    x = 0;
    if(t)
        update(x,1,n,t,1);
    else
    {
        solve(leftson); solve(rightson);
        ans1 = ans2 = 0;
        x = leftson; 
        merge(x,rightson);
        ans += min(ans1,ans2);
    }
}
int main()
{
    n = read();
    int x = 0;
    solve(x);
    cout << ans;
}