题解 P6144 【[USACO20FEB]Help Yourself P】

· · 题解

我丢: https://www.luogu.com.cn/problem/P6144

先考虑 k = 1

考虑DP

首先将所有的线段按照左端点排序,考虑 f[r]表示最右以r结尾的线段集合的连通块的数量的和.

插入一个线段[l,r]

对于右端点在[1,l-1]的每种情况,都可以使它们的连通块加1,再加到f[r]里面

对于右端点在[l,r]的每种情况,可以直接加到f[r]里面 对于右端点在[r+1,n]f[i],要全部*2 (选当前线段或不选)

也就是要一个数据结构,支持区间*2,区间求和,显然线段树

对于k != 1, 可以考虑把它的 [0,k]次的结果分别维护,计算的时候用二项式定理合并就好了

就是对于[1,l-1]的 要把 x^k -> (x+1)^k再加到f[r]里,这里可以用二项式定理把[0,k]次方的结果乘上二项式系数再加起来就好

就是线段树每个点要维护各个次方的结果

看代码好理解一点

code:

#include<bits/stdc++.h>
#define N 400005
#define int long long
#define mod 1000000007
using namespace std;
struct haha {
    int a[13];
};
struct A {
    int l, r;
} a[N];
int cmp(A x, A y) {
    return x.l < y.l;
}
int ha[N << 3][12], tag[N << 3], n, k, C[25][25];
void update(int rt) {
    for(int i = 0; i <= k; i ++)
        ha[rt][i] = (ha[rt << 1][i] + ha[rt << 1 | 1][i]) % mod;
}
void pushdown(int rt) {
    if(tag[rt] == 1) return;
    tag[rt << 1] *= tag[rt], tag[rt << 1 | 1] *= tag[rt], tag[rt << 1 | 1] %= mod, tag[rt << 1] %= mod;
    int t = 1;
    for(int i = 0; i <= k; i ++) 
        ha[rt << 1][i] *= tag[rt], ha[rt << 1 | 1][i] *= tag[rt], ha[rt << 1][i] %= mod, ha[rt << 1 | 1][i] %= mod;
    tag[rt] = 1;
}   
void add(int rt, int l, int r, int x, haha b) { 
    if(l == r) {
        for(int i = 0; i <= k; i ++) ha[rt][i] = (ha[rt][i] + b.a[i]) % mod;
        return;
    }
    pushdown(rt);
    int mid = (l + r) >> 1;
    if(x <= mid) add(rt << 1, l, mid, x, b);
    else add(rt << 1 | 1, mid + 1, r, x, b);
    update(rt);
}
void mul(int rt, int l, int r, int L, int R) {
    if(L > R) return;
    if(L <= l && r <= R) {
        tag[rt] = tag[rt] * 2 % mod;
        int t = 1;
        for(int i = 0; i <= k; i ++)
            ha[rt][i] = ha[rt][i] * 2 % mod;
        return;
    }
    pushdown(rt);
    int mid = (l + r) >> 1;
    if(L <= mid) mul(rt << 1, l, mid, L, R);
    if(R > mid) mul(rt << 1 | 1, mid + 1, r, L, R);
    update(rt);
}
haha query(int rt, int l, int r, int L, int R) { 
    if(L > R) {
        haha b;
        for(int i = 0; i <= k; i ++) b.a[i] = 0;
        return b;
    }
    if(L <= l && r <= R) {
        haha d;
        for(int i = 0; i <= k; i ++) d.a[i] = ha[rt][i];
        return d;
    }
    pushdown(rt);
    int mid = (l + r) >> 1; haha ret;
    ret.a[0] = 1;
    for(int i =  0; i <= k; i ++) ret.a[i] = 0;
    if(L <= mid) {
        haha b = query(rt << 1, l, mid, L, R);
        for(int i = 0; i <= k; i ++) ret.a[i] += b.a[i], ret.a[i] %= mod;
    } 
    if(R > mid) {
        haha b = query(rt << 1 | 1, mid + 1, r, L, R);
        for(int i = 0; i <= k; i ++) ret.a[i] += b.a[i], ret.a[i] %= mod;
    }
    return ret;
}
signed main() {
    scanf("%lld%lld", &n, &k); int lim = 2 * n;
    for(int i = 0; i <= lim * 4; i ++) ha[i][0] = 0 ;
    for(int i = 0; i <= k; i ++) C[i][0] = 1;
    for(int i = 1; i <= k; i ++)
        for(int j = 1; j <= i; j ++)
            C[i][j] = (C[i - 1][j] + C[i - 1][j - 1]) % mod;
    for(int i = 1; i <= n; i ++) scanf("%lld%lld", &a[i].l, &a[i].r);
    sort(a + 1, a + 1 + n, cmp);
    haha d;
    for(int i = 0; i <= k; i ++) d.a[i] = 0;
    d.a[0] = 1;
    add(1, 0, lim, 0, d);
    for(int tt = 1; tt <= n; tt ++) {
        int l = a[tt].l, r = a[tt].r;
        haha b = query(1, 0, lim, 0, l - 1);
        haha c;
        for(int i = 0; i <= k; i ++) c.a[i] = 0;
        for(int i = 0; i <= k; i ++){
            int ret = 0;
            for(int j = 0; j <= i; j ++)
                c.a[i] += b.a[j] * C[i][j] % mod, c.a[i] %= mod;
        }
        b = query(1, 0, lim, l, r - 1);
        for(int i = 0; i <= k; i ++) c.a[i] = (c.a[i] + b.a[i]) % mod;
        add(1, 0, lim, r, c);
        mul(1, 0, lim, r + 1, lim);
    }
    printf("%lld", query(1, 0, lim, 0, lim).a[k]);
    return 0;
}

这一场就这题值得写一下QWQ