P9844 [ICPC2021 Nanjing R] Paimon Segment Tree

· · 题解

简要题意

你需要维护一个 (m+1)\times n 的矩阵 A,行从 0 开始计数。初始时满足 A_{0i}=A_{1i}=\cdots=A_{mi}=a_i

m 次操作,第 i 次操作给出区间 [l,r] 和常数 v,表示将满足 i\leq x\leq m,l\leq y\leq rA_{xy} 加上 v

q 次询问,每次询问给出给出区间 [x,y][l,r],求满足 x\leq p\leq y,l\leq q\leq rA_{pq} 的平方和。对一个大质数取模。

1\leq n,m,q\leq 5\times10^4

思路

先差分,改为求 [1,y] 的答案减去 [1,x-1] 的答案。所以可以转成求区间历史平方和(应该叫做这个名字)。

考虑区间和 B、区间平方和 C、区间历史平方和 D 与区间长度 A 的关系。假设现在执行了区间加 v,则:

\begin{aligned} &A^{'}=A\\ &B^{'}=vA+B\\ &C^{'}=v^2A+2vB+C\\ &D^{'}=v^2A+2vB+C+D \end{aligned}

其中 C^{'} 就是拆完全平方公式,其他没什么好提的。

不难发现假如我们用向量来表示每一个位置:

\begin{bmatrix} A&B&C&D \end{bmatrix}

则可以构造一个矩阵,用矩阵乘法的方法来修改:

\begin{bmatrix} A&B&C&D \end{bmatrix}\times \begin{bmatrix} 1&v&v^2&v^2\\ 0&1&2v&2v\\ 0&0&1&1\\ 0&0&0&1 \end{bmatrix} = \begin{bmatrix} A^{'}&B^{'}&C^{'}&D^{'} \end{bmatrix}

然后需要注意下没有修改的部分,区间历史平方和需要集体手动更新:

\begin{aligned} &A^{'}=A\\ &B^{'}=B\\ &C^{'}=C\\ &D^{'}=C+D \end{aligned}

则可以构造一个矩阵,用矩阵乘法的方法来修改:

\begin{bmatrix} A&B&C&D \end{bmatrix}\times \begin{bmatrix} 1&0&0&0\\ 0&1&0&0\\ 0&0&1&1\\ 0&0&0&1 \end{bmatrix} = \begin{bmatrix} A^{'}&B^{'}&C^{'}&D^{'} \end{bmatrix}

然后你抓一只无辜的线段树维护一下即可。

至于询问,可以离线,在每一次操作后计算一下需要计算的区间历史平方和。

时间复杂度 O(q\log n) 带一个 64 倍常数。其实可以优化矩阵乘法来减小常数,但是据说没必要,我就没有写了(于是就要丧心病狂的卡常了)。

实现上的小细节

代码

喜提最优解最后一名!此代码目前只能在 Luogu 上通过,Gym 上被卡常了,开心。

Submission

#include <bits/stdc++.h>
#define ls (i << 1)
#define rs (i << 1 | 1)
#define mid ((l + r) >> 1)
using namespace std;

const int mod = 1e9 + 7, N = 5e4 + 5;

int M(long long x){return x%mod;}
int Add(long long x, long long y){return (x + y) > mod ? (x - mod + y) : (x + y);}

struct matrix{
    int n,m,a[5][5];
    void clear(){memset(a, 0, sizeof(a));}
    void init(int N, int M){n = N, m = M;clear();}
    int* operator[](int i){return a[i];}
};

matrix operator*(matrix a, matrix b){
    matrix ans;ans.init(a.n, b.m);
    assert(a.m == b.n);
    for(int k=1;k<=a.m;k++){
        for(int i=1;i<=a.n;i++){
            for(int j=1;j<=b.m;j++) ans[i][j] = Add(ans[i][j], M(1ll * a[i][k] * b[k][j]));
        }
    }
    return ans;
}

matrix operator+(matrix a, matrix b){
    matrix ans;ans.init(a.n, a.m);
    for(int i=1;i<=a.n;i++){
        for(int j=1;j<=a.m;j++) ans[i][j] = Add(a[i][j], b[i][j]);
    }
    return ans;
}

matrix t[N << 2], tag[N << 2];

void pushup(int i){t[i] = t[ls] + t[rs];}

void build(int i, int l, int r){
    tag[i].init(4, 4);
    tag[i][1][1] = tag[i][2][2] = tag[i][3][3] = tag[i][4][4] = 1;
    if(l == r){
        int v;cin>>v;
        t[i].init(1, 4);
        t[i][1][1] = 1;t[i][1][2] = v = Add(mod, v);
        t[i][1][3] = t[i][1][4] = M(1ll * v * v);
        return;
    }
    build(ls, l, mid);build(rs, mid + 1, r);
    pushup(i);
}

void pushdown(int i){
    tag[ls] = tag[ls] * tag[i];
    tag[rs] = tag[rs] * tag[i];
    t[ls] = t[ls] * tag[i];
    t[rs] = t[rs] * tag[i];
    tag[i].init(4, 4);
    tag[i][1][1] = tag[i][2][2] = tag[i][3][3] = tag[i][4][4] = 1;
}

void update(int ql, int qr, matrix v, int i, int l, int r){
    if(ql > qr) return;
    if(ql <= l && r <= qr){
        tag[i] = tag[i] * v;
        t[i] = t[i] * v;
        return;
    }
    pushdown(i);
    if(ql <= mid) update(ql, qr, v, ls, l, mid);
    if(qr > mid) update(ql, qr, v, rs, mid + 1, r);
    pushup(i);
}

matrix query(int ql, int qr, int i, int l, int r){
    if(ql <= l && r <= qr) return t[i];
    pushdown(i);
    if(ql <= mid && qr > mid) return query(ql, qr, ls, l, mid) + query(ql, qr, rs, mid + 1, r);
    if(ql <= mid) return query(ql, qr, ls, l, mid);
    else return query(ql, qr, rs, mid + 1, r);
}

vector<pair<int,int> > qs[N];
vector<int> ans[N];
int n,m,q;

struct Update{
    int l, r, v;
} updates[N];

struct Query{
    int p1, p2, q1, q2;
} qqs[N];

signed main(){
    ios::sync_with_stdio(false);
    cin.tie(0);cout.tie(0);
    cin>>n>>m>>q;
    build(1, 1, n);
    for(int i=1;i<=m;i++) cin>>updates[i].l>>updates[i].r>>updates[i].v;
    for(int i=1;i<=q;i++){
        int l,r,x,y;cin>>l>>r>>x>>y;
        qqs[i].p1 = qs[x].size();
        qqs[i].q1 = x;
        qs[x].push_back(make_pair(l, r));
        qqs[i].p2 = qs[y + 1].size();
        qqs[i].q2 = y + 1;
        qs[y + 1].push_back(make_pair(l, r));
    }
    for(auto i : qs[0]) ans[0].push_back(0);
    for(auto i : qs[1]) ans[1].push_back(query(i.first, i.second, 1, 1, n)[1][4]);
    for(int i=1;i<=m;i++){
        int l = updates[i].l, r = updates[i].r, v = updates[i].v;
        matrix mat;mat.init(4, 4);
        mat[1][1] = mat[2][2] = mat[3][3] = mat[3][4] = mat[4][4] = 1;
        mat[1][2] = v = Add(mod, v);
        mat[1][3] = mat[1][4] = M(1ll * v * v);
        mat[2][3] = mat[2][4] = Add(v, v);
        update(l, r, mat, 1, 1, n);
        mat.clear();
        mat[1][1] = mat[2][2] = mat[3][3] = mat[3][4] = mat[4][4] = 1;
        update(1, l - 1, mat, 1, 1, n);
        update(r + 1, n, mat, 1, 1, n);
        for(auto j : qs[i + 1]) ans[i + 1].push_back(query(j.first, j.second, 1, 1, n)[1][4]);
    }
    for(int i=1;i<=q;i++) cout<<Add(ans[qqs[i].q2][qqs[i].p2], mod - ans[qqs[i].q1][qqs[i].p1])<<'\n';
    return 0;
}