题解 P4738 【[CERC2017] Cumulative Code】

分类:题解

这题在 NKOJ 上有一个更响亮的名字——弗斯。

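堆式编号的满二叉树(结点 $x$ 的左、右儿子分别为 $2x$、$2x+1$)性质很好:同一层中左子树的编号都小于右子树,内部结点被删成叶子后编号又小于剩余的所有叶子,因此可以递归地表示出 Prüfer 序列。用 $\oplus$ 表示序列拼接。记 $P(1, n, x)$ 为以 $x$ 为根、深度为 $n$ 的子树被整体删去时产生的序列,记 $P(0, n, x)$ 为该子树恰为当前剩余的整棵树时产生的序列,则

$$P(1, 1, x) = [\lfloor x/2 \rfloor], \qquad P(1, n, x) = P(1, n-1, 2x) \oplus P(1, n-1, 2x+1) \oplus [\lfloor x/2 \rfloor] \quad (n > 1),$$

$$P(0, 2, x) = [x], \qquad P(0, n, x) = P(1, n-1, 2x) \oplus [2x+1] \oplus P(0, n-1, 2x+1) \quad (n > 2).$$

所求序列即为 $P(0, k, 1)$。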

但直接暴力求出 $P(0, k, 1)$ 需要 $O(2^k)$ 的时间和空间,显然不可行。

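由于每次询问的 $d$ 不同,我们单独考虑每个询问。一个自然的想法是经典的分层优化:取分界深度 $S$。对深度为 $S$ 的子树(根编号记为 $x$),其序列中每一项都可以写成 $p \cdot x + c$ 的形式,且系数 $(p, c)$ 与 $x$ 无关,只需求一次。每次询问时对 $p$、$c$ 分别按步长 $d$ 做前缀和,递归到深度 $S$ 时用 $x \cdot \sum p + \sum c$ 直接查表。深度大于 $S$ 的部分则按上面的递归式拆成左右子树继续递归,共 $O(2^{k-S})$ 次。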

回答询问时差分一下转化为两个前缀询问,处理出所需前缀和信息即可。

时间复杂度为 $O(q(2^S + 2^{k - S}))$,当 $S = \frac{k}{2}$ 时取到最优时间复杂度 $O(q \cdot 2^{\frac{k}{2}})$。

代码:

#include <stdio.h>

typedef long long ll;

/*
 * Table capacities. main() uses the split depth mid = max(k / 2, 2), which is
 * at most MAX_MID for the k <= 30 that main() accepts.
 *   - The op-0 layout of depth mid has 2^mid - 3 entries, stored 1-indexed.
 *   - Each op-1 child layout of depth mid - 1 has 2^(mid-1) - 1 entries, also
 *     stored 1-indexed.
 * Index 0 of every sum table is never written and stays 0. f() relies on that
 * when it reads sum21[0] / sum22[0].
 */
enum {
    MAX_MID = 15,
    TOP_CAP = (1 << MAX_MID) - 1,        /* 32767 >= 2^15 - 3 + 1 */
    SUB_CAP = (1 << (MAX_MID - 1)) + 3   /* 16387 >= 2^14 - 1 + 1, with slack */
};

/* Split depth, length of each op-1 child table, and the three table lengths. */
int mid, mid_size, len0 = 0, len1 = 0, len2 = 0;

/* Entry i of a layout is pidX[i] * root_id + cidX[i] (see get_prufer). */
int pid0[TOP_CAP], cid0[TOP_CAP];   /* op-0 layout of depth mid, root id * 1 + 0 */
int pid1[SUB_CAP], cid1[SUB_CAP];   /* op-1 layout of the left child, 2 * id */
int pid2[SUB_CAP], cid2[SUB_CAP];   /* op-1 layout of the right child, 2 * id + 1 */

/* Per-query stride-d prefix sums of the tables above (rebuilt in main). */
int sum01[TOP_CAP], sum02[TOP_CAP];
int sum11[SUB_CAP], sum12[SUB_CAP];
int sum21[SUB_CAP], sum22[SUB_CAP];

/* Larger of two ints (either one when they are equal). */
inline int max(int a, int b){
    if (b > a){
        return b;
    }
    return a;
}

/*
 * Appends the sequence for a subtree of depth n to pre_id[] / cur_id[].
 *
 * The subtree's root is the node pid * ID + cid, where ID is a top-level root
 * id fixed later. Every appended entry is kept in that same linear form:
 * value = pre_id[i] * ID + cur_id[i], for i = len + 1, len + 2, ...
 * `len` is advanced past the new entries.
 *
 *   op != 0: left child subtree, then right child subtree (both depth n - 1,
 *            op 1), then one entry for the root: its parent (pid/2, cid/2).
 *            A depth-1 subtree yields only that parent entry.
 *   op == 0: at n == 2, a single entry, the root itself. Deeper: the left
 *            child subtree (op 1), then the right child (2 * root + 1), then
 *            the right child subtree again with op 0.
 */
void get_prufer(int op, int n, int pid, int cid, int &len, int pre_id[], int cur_id[]){
    if (op != 0){
        // The whole subtree goes: children first, the root's own entry last.
        if (n >= 2){
            get_prufer(1, n - 1, pid * 2, cid * 2, len, pre_id, cur_id);
            get_prufer(1, n - 1, pid * 2, cid * 2 + 1, len, pre_id, cur_id);
        }
        ++len;
        pre_id[len] = pid / 2;
        cur_id[len] = cid / 2;
        return;
    }

    if (n == 2){
        ++len;
        pre_id[len] = pid;
        cur_id[len] = cid;
        return;
    }

    const int child_pid = pid * 2;
    const int right_cid = cid * 2 + 1;
    get_prufer(1, n - 1, child_pid, cid * 2, len, pre_id, cur_id);
    ++len;
    pre_id[len] = child_pid;   // entry for the root: its right child
    cur_id[len] = right_cid;
    get_prufer(0, n - 1, child_pid, right_cid, len, pre_id, cur_id);
}

/* Smaller of two ints (either one when they are equal); not called by the code in this file. */
inline int min(int a, int b){
    if (b < a){
        return b;
    }
    return a;
}

/*
 * Strided range sum over one segment of the sequence built by get_prufer.
 *
 * Returns the sum of the entries at 1-based positions
 *   n, n + k, ..., n + k * (m - 1)
 * of the segment that get_prufer(op, depth, ...) lays out for the subtree of
 * depth `depth` whose root is node `id`. Returns 0 when m <= 0.
 * Note: `k` here is the query stride d (main passes d), not the tree depth.
 *   op == 0: left child segment (op 1), then the entry 2 * id + 1, then the
 *            right child segment (op 0).
 *   op != 0: left child segment, then right child segment (both op 1), then
 *            the entry id / 2.
 *
 * Above depth `mid`, the range is split at the child boundaries and each piece
 * recurses with its positions re-based. At depth `mid`, the result is read
 * from the per-query strided prefix tables filled in main(). Those tables are
 * linear in the root id: value = sumX1[pos] * id + sumX2[pos].
 *
 * The tables sum from `end` down to the smallest positive index in its
 * residue class, so this relies on n <= k. main passes (a - 1) % d + 1, and
 * the recursive calls pass either (end - 1) % k + 1 or an unchanged n.
 */
ll f(int op, int depth, int n, int m, int k, int id){
    if (m <= 0) return 0;
    // Last queried position.
    int end = n + k * (m - 1);
    ll ans;
    if (op == 0){
        if (depth == mid){
            // The whole op-0 layout of depth mid is tabulated in sum01/sum02.
            ans = (ll)sum01[end] * id + sum02[end];
        } else {
            // Entries in the left child's op-1 segment: one per node of a depth-(depth-1) subtree.
            int cur_size = (1 << (depth - 1)) - 1;
            if (end <= cur_size){
                // The whole range lies in the left child's segment.
                ans = f(1, depth - 1, n, m, k, id * 2);
            } else {
                ans = 0;
                // Part inside the left child's segment: positions n .. last hit <= cur_size.
                if (n <= cur_size) ans += f(1, depth - 1, n, (cur_size - n) / k + 1, k, id * 2);
                // Position cur_size + 1 holds 2 * id + 1; end is now relative to the right child's segment.
                end -= cur_size + 1;
                // The progression hits position cur_size + 1 exactly when the remaining offset is a multiple of k.
                if (end % k == 0) ans += id * 2 + 1;
                // Part inside the right child's op-0 segment, re-based to start at its first hit.
                if (end >= 1) ans += f(0, depth - 1, (end - 1) % k + 1, (end - 1) / k + 1, k, id * 2 + 1);
            }
        }
    } else {
        if (depth == mid){
            // mid_size == length of each child table (sum1x left, sum2x right); the root entry follows them.
            if (end <= mid_size){
                ans = (ll)sum11[end] * id + sum12[end];
            } else {
                ans = 0;
                if (n <= mid_size){
                    // Last queried position inside the left child's table.
                    int end_ = (mid_size - n) / k * k + n;
                    ans += (ll)sum11[end_] * id + sum12[end_];
                }
                // end is now relative to the right child's table.
                end -= mid_size;
                if (end <= mid_size){
                    ans += (ll)sum21[end] * id + sum22[end];
                } else {
                    // end == mid_size + 1: the segment's final entry, the root's parent id / 2.
                    ans += id / 2;
                    // Same as end - k >= 0. When end - k == 0 this reads sum21[0] / sum22[0], which
                    // main never writes (its loops start at 1), so they contribute 0.
                    if (k - 1 <= mid_size){
                        end -= k;
                        ans += (ll)sum21[end] * id + sum22[end];
                    }
                }
            }
        } else {
            // Entries in each child's op-1 segment; the root's entry comes after both.
            int cur_size = (1 << (depth - 1)) - 1;
            if (end <= cur_size){
                ans = f(1, depth - 1, n, m, k, id * 2);
            } else {
                ans = 0;
                if (n <= cur_size) ans += f(1, depth - 1, n, (cur_size - n) / k + 1, k, id * 2);
                // end is now relative to the right child's segment.
                end -= cur_size;
                if (end <= cur_size){
                    ans += f(1, depth - 1, (end - 1) % k + 1, (end - 1) / k + 1, k, id * 2 + 1);
                } else {
                    // end == cur_size + 1: the segment's final entry, the root's parent id / 2.
                    ans += id / 2;
                    // Same as end - k >= 0. When end - k == 0 the call below gets n == 0, m == 1;
                    // it only reaches index 0 of sum11/sum12, so it evaluates to 0.
                    if (k - 1 <= cur_size){
                        end -= k;
                        ans += f(1, depth - 1, (end - 1) % k + 1, (end - 1) / k + 1, k, id * 2 + 1);
                    }
                }
            }
        }
    }
    return ans;
}

/*
 * Fill out[1..len] with stride-d prefix sums of src[1..len]:
 *   out[j] = src[j] + src[j - d] + src[j - 2d] + ...
 * stopping at the smallest positive index congruent to j modulo d.
 * out[0] is never written. For the zero-initialised global tables it stays 0,
 * and f() relies on that.
 */
static void strided_prefix(const int src[], int out[], int len, int d){
    for (int j = 1; j <= len; j++){
        out[j] = src[j];
        if (j >= d) out[j] += out[j - d];
    }
}

/*
 * Reads k and q and lays out the depth-mid building blocks once. Then, for
 * each query (a, d, m), it prints the sum of the entries at positions
 * a, a + d, ..., a + (m - 1) d of the sequence for the whole depth-k tree
 * (root id 1). The sum is computed as the difference of two prefix queries
 * over the residue class r = (a - 1) % d + 1.
 *
 * Malformed or out-of-range input stops the program with exit status 1.
 * Such input would otherwise mean reading uninitialised values, dividing by
 * zero, or indexing past the fixed-size tables.
 */
int main(){
    int k, q;
    // The tables are sized for mid = max(k / 2, 2) <= 15, i.e. k <= 30.
    // k >= 2 keeps the top-level depth >= mid in f().
    if (scanf("%d %d", &k, &q) != 2 || k < 2 || k > 30){
        fputs("invalid input: expected k in [2, 30] and q\n", stderr);
        return 1;
    }
    mid = max(k / 2, 2);
    mid_size = (1 << (mid - 1)) - 1;
    // Building blocks: the op-0 tree of depth mid (root id * 1 + 0) and its two
    // op-1 children of depth mid - 1 (roots 2 * id and 2 * id + 1).
    get_prufer(0, mid, 1, 0, len0, pid0, cid0);
    get_prufer(1, mid - 1, 2, 0, len1, pid1, cid1);
    get_prufer(1, mid - 1, 2, 1, len2, pid2, cid2);
    // A tree of 2^k - 1 nodes has a sequence of (2^k - 1) - 2 entries.
    const long long seq_len = (1LL << k) - 3;
    for (int i = 1; i <= q; i++){
        int a, d, m;
        if (scanf("%d %d %d", &a, &d, &m) != 3 || d < 1 || m < 0
            || a < 1 || a > seq_len
            || (m > 0 && a + (long long)(m - 1) * d > seq_len)){
            fputs("invalid query\n", stderr);
            return 1;
        }
        // The stride changes per query, so the prefix tables are rebuilt every time.
        strided_prefix(pid0, sum01, len0, d);
        strided_prefix(cid0, sum02, len0, d);
        strided_prefix(pid1, sum11, len1, d);
        strided_prefix(cid1, sum12, len1, d);
        strided_prefix(pid2, sum21, len2, d);
        strided_prefix(cid2, sum22, len2, d);
        // First position of a's residue class. The answer is the sum through
        // a + (m - 1) d minus the sum of the (a - 1) / d terms before a.
        int r = (a - 1) % d + 1;
        printf("%lld\n", f(0, k, r, m + (a - 1) / d, d, 1) - f(0, k, r, (a - 1) / d, d, 1));
    }
    return 0;
}