题解 P5220 【特工的信息流】

· · 题解

本文同步发表于我的博客:https://www.alpha1022.me/articles/lg-5220.htm

这题虽然原 idea 是我的,但是 noname 改了之后我就一直咕咕咕没做了。
今天闲得发慌来练手速……

写完这题的第一感觉是可以回去把「SDOI2011」颜色 的坑给填了(虽然我还是没打算写

其实做法比较显然,先考虑在序列上的做法,线段树维护区间后缀积之和 和区间积,那么合并左右子树的时候,根据乘法分配律可得:

\sum_{i=l}^r\prod_{j=i}^r a_j = \prod_{i=m+1}^r a_i\left(\sum_{i=l}^m\prod_{j=i}^m a_j\right) + \sum\limits_{i=m+1}^r\prod\limits_{j=i}^r a_j (m \in [l,r))

于是就显然。

考虑把这个做法放到树上,但是这个时候我们发现如果把路径从 LCA 划分成两段的话,有一段需要用与答案相反前缀积之和来统计,于是改一改线段树就好了。

关于此题树剖做法的码量瓶颈,我认为应该在于查询的过程……
必须保证思路清晰才能不写错。

然后是取模的问题,虽然模数很小但是也印证了那句话:

不开 long long 见祖宗,十年 OI 一场空。

代码:

#include <cstdio>
#include <algorithm>
#include <vector>
#include <utility>
#define ls (p << 1)
#define rs (ls | 1)
using namespace std;
const int N = 1e5;
const long long mod = 20924;
int n,m;
long long a[N + 5];
int to[(N << 1) + 5],pre[(N << 1) + 5],first[N + 5];
inline void add(int u,int v)
{
    static int tot = 0;
    to[++tot] = v;
    pre[tot] = first[u];
    first[u] = tot;
}
int fa[N + 5],dep[N + 5],sz[N + 5],son[N + 5],top[N + 5],id[N + 5],rk[N + 5];
void dfs1(int p)
{
    sz[p] = 1;
    for(register int i = first[p];i;i = pre[i])
        if(to[i] ^ fa[p])
        {
            fa[to[i]] = p,dep[to[i]] = dep[p] + 1;
            dfs1(to[i]),sz[p] += sz[to[i]];
            if(!son[p] || sz[to[i]] > sz[son[p]])
                son[p] = to[i];
        }
}
void dfs2(int p)
{
    static int tot = 0;
    rk[id[p] = ++tot] = p;
    if(!son[p])
        return ;
    top[son[p]] = top[p],dfs2(son[p]);
    for(register int i = first[p];i;i = pre[i])
        if(!id[to[i]])
            top[to[i]] = to[i],dfs2(to[i]);
}
struct segnode
{
    long long prod,sufsum,presum;
} seg[(N << 2) + 10];
void build(int p,int tl,int tr)
{
    if(tl == tr)
    {
        seg[p].prod = seg[p].sufsum = seg[p].presum = a[rk[tl]];
        return ;
    }
    int mid = tl + tr >> 1;
    build(ls,tl,mid);
    build(rs,mid + 1,tr);
    seg[p].prod = seg[ls].prod * seg[rs].prod % mod;
    seg[p].sufsum = (seg[ls].sufsum * seg[rs].prod % mod + seg[rs].sufsum) % mod;
    seg[p].presum = (seg[rs].presum * seg[ls].prod % mod + seg[ls].presum) % mod;
}
void modify(int x,int k,int p,int tl,int tr)
{
    if(tl == tr)
    {
        seg[p].prod += k,seg[p].sufsum += k,seg[p].presum += k;
        return ;
    }
    int mid = tl + tr >> 1;
    if(x <= mid)
        modify(x,k,ls,tl,mid);
    else
        modify(x,k,rs,mid + 1,tr);
    seg[p].prod = seg[ls].prod * seg[rs].prod % mod;
    seg[p].sufsum = (seg[ls].sufsum * seg[rs].prod % mod + seg[rs].sufsum) % mod;
    seg[p].presum = (seg[rs].presum * seg[ls].prod % mod + seg[ls].presum) % mod;
}
long long query_prod(int l,int r,int p,int tl,int tr)
{
    if(l <= tl && tr <= r)
        return seg[p].prod;
    int mid = tl + tr >> 1;
    long long ret = 1;
    if(l <= mid)
        ret = ret * query_prod(l,r,ls,tl,mid) % mod;
    if(r > mid)
        ret = ret * query_prod(l,r,rs,mid + 1,tr) % mod;
    return ret;
}
long long query_sufsum(int l,int r,int p,int tl,int tr)
{
    if(l <= tl && tr <= r)
        return seg[p].sufsum;
    int mid = tl + tr >> 1;
    if(l <= mid && r > mid)
        return (query_sufsum(l,r,ls,tl,mid) * query_prod(l,r,rs,mid + 1,tr) % mod + query_sufsum(l,r,rs,mid + 1,tr)) % mod;
    if(l <= mid)
        return query_sufsum(l,r,ls,tl,mid);
    else
        return query_sufsum(l,r,rs,mid + 1,tr);
}
long long query_presum(int l,int r,int p,int tl,int tr)
{
    if(l <= tl && tr <= r)
        return seg[p].presum;
    int mid = tl + tr >> 1;
    if(l <= mid && r > mid)
        return (query_presum(l,r,rs,mid + 1,tr) * query_prod(l,r,ls,tl,mid) % mod + query_presum(l,r,ls,tl,mid)) % mod;
    if(l <= mid)
        return query_presum(l,r,ls,tl,mid);
    if(r > mid)
        return query_presum(l,r,rs,mid + 1,tr);
}
pair<int,int> getlca(int x,int y)
{
    while(top[x] ^ top[y])
        dep[top[x]] > dep[top[y]] ? x = fa[top[x]] : y = fa[top[y]];
    return dep[x] < dep[y] ? make_pair(x,0) : make_pair(y,1);
}
long long query(int x,int y)
{
    pair<int,int> t = getlca(x,y);
    int lca = t.first,w = t.second;
    vector< pair<int,int> > range;
    while(top[x] ^ top[lca])
        range.push_back(make_pair(id[top[x]],id[x])),x = fa[top[x]];
    if(w)
        range.push_back(make_pair(id[lca],id[x]));
    long long temp = 0,prod = 1;
    for(register int i = range.size() - 1;~i;--i)
        temp = (temp + query_presum(range[i].first,range[i].second,1,1,n) * prod % mod) % mod,prod = prod * query_prod(range[i].first,range[i].second,1,1,n) % mod;
    range.clear();
    while(top[y] ^ top[lca])
        range.push_back(make_pair(id[top[y]],id[y])),y = fa[top[y]];
    if(!w)
        range.push_back(make_pair(id[lca],id[y]));
    long long ret = 0;
    prod = 1;
    for(register int i = 0;i < range.size();++i)
        ret = (ret + query_sufsum(range[i].first,range[i].second,1,1,n) * prod % mod) % mod,prod = prod * query_prod(range[i].first,range[i].second,1,1,n) % mod;
    return (ret + temp * prod % mod) % mod;
}
int main()
{
    scanf("%d%d",&n,&m);
    for(register int i = 1;i <= n;++i)
        scanf("%lld",a + i);
    int u,v;
    for(register int i = 1;i < n;++i)
        scanf("%d%d",&u,&v),add(u,v),add(v,u);
    dep[1] = 1,dfs1(1),top[1] = 1,dfs2(1);
    build(1,1,n);
    char op;
    int x,y;
    while(m--)
    {
        scanf(" %c%d%d",&op,&x,&y);
        if(op == 'Q')
            printf("%lld\n",(query(x,y) + mod) % mod);
        else
            modify(id[x],y,1,1,n);
    }
}