题解 P5220 【特工的信息流】
本文同步发表于我的博客:https://www.alpha1022.me/articles/lg-5220.htm
这题虽然原 idea 是我的,但是 noname 改了之后我就一直咕咕咕没做了。
今天闲得发慌来练手速……
写完这题的第一感觉是可以回去把「SDOI2011」颜色 的坑给填了(虽然我还是没打算写
其实做法比较显然,先考虑在序列上的做法,线段树维护区间后缀积之和 和区间积,那么合并左右子树的时候,根据乘法分配律可得:
于是就显然。
考虑把这个做法放到树上,但是这个时候我们发现如果把路径从 LCA 划分成两段的话,有一段需要用与答案相反前缀积之和来统计,于是改一改线段树就好了。
关于此题树剖做法的码量瓶颈,我认为应该在于查询的过程……
必须保证思路清晰才能不写错。
然后是取模的问题,虽然模数很小但是也印证了那句话:
不开 long long 见祖宗,十年 OI 一场空。
代码:
#include <cstdio>
#include <algorithm>
#include <vector>
#include <utility>
#define ls (p << 1)
#define rs (ls | 1)
using namespace std;
const int N = 1e5;
const long long mod = 20924;
int n,m;
long long a[N + 5];
int to[(N << 1) + 5],pre[(N << 1) + 5],first[N + 5];
inline void add(int u,int v)
{
static int tot = 0;
to[++tot] = v;
pre[tot] = first[u];
first[u] = tot;
}
int fa[N + 5],dep[N + 5],sz[N + 5],son[N + 5],top[N + 5],id[N + 5],rk[N + 5];
void dfs1(int p)
{
sz[p] = 1;
for(register int i = first[p];i;i = pre[i])
if(to[i] ^ fa[p])
{
fa[to[i]] = p,dep[to[i]] = dep[p] + 1;
dfs1(to[i]),sz[p] += sz[to[i]];
if(!son[p] || sz[to[i]] > sz[son[p]])
son[p] = to[i];
}
}
void dfs2(int p)
{
static int tot = 0;
rk[id[p] = ++tot] = p;
if(!son[p])
return ;
top[son[p]] = top[p],dfs2(son[p]);
for(register int i = first[p];i;i = pre[i])
if(!id[to[i]])
top[to[i]] = to[i],dfs2(to[i]);
}
struct segnode
{
long long prod,sufsum,presum;
} seg[(N << 2) + 10];
void build(int p,int tl,int tr)
{
if(tl == tr)
{
seg[p].prod = seg[p].sufsum = seg[p].presum = a[rk[tl]];
return ;
}
int mid = tl + tr >> 1;
build(ls,tl,mid);
build(rs,mid + 1,tr);
seg[p].prod = seg[ls].prod * seg[rs].prod % mod;
seg[p].sufsum = (seg[ls].sufsum * seg[rs].prod % mod + seg[rs].sufsum) % mod;
seg[p].presum = (seg[rs].presum * seg[ls].prod % mod + seg[ls].presum) % mod;
}
void modify(int x,int k,int p,int tl,int tr)
{
if(tl == tr)
{
seg[p].prod += k,seg[p].sufsum += k,seg[p].presum += k;
return ;
}
int mid = tl + tr >> 1;
if(x <= mid)
modify(x,k,ls,tl,mid);
else
modify(x,k,rs,mid + 1,tr);
seg[p].prod = seg[ls].prod * seg[rs].prod % mod;
seg[p].sufsum = (seg[ls].sufsum * seg[rs].prod % mod + seg[rs].sufsum) % mod;
seg[p].presum = (seg[rs].presum * seg[ls].prod % mod + seg[ls].presum) % mod;
}
long long query_prod(int l,int r,int p,int tl,int tr)
{
if(l <= tl && tr <= r)
return seg[p].prod;
int mid = tl + tr >> 1;
long long ret = 1;
if(l <= mid)
ret = ret * query_prod(l,r,ls,tl,mid) % mod;
if(r > mid)
ret = ret * query_prod(l,r,rs,mid + 1,tr) % mod;
return ret;
}
long long query_sufsum(int l,int r,int p,int tl,int tr)
{
if(l <= tl && tr <= r)
return seg[p].sufsum;
int mid = tl + tr >> 1;
if(l <= mid && r > mid)
return (query_sufsum(l,r,ls,tl,mid) * query_prod(l,r,rs,mid + 1,tr) % mod + query_sufsum(l,r,rs,mid + 1,tr)) % mod;
if(l <= mid)
return query_sufsum(l,r,ls,tl,mid);
else
return query_sufsum(l,r,rs,mid + 1,tr);
}
long long query_presum(int l,int r,int p,int tl,int tr)
{
if(l <= tl && tr <= r)
return seg[p].presum;
int mid = tl + tr >> 1;
if(l <= mid && r > mid)
return (query_presum(l,r,rs,mid + 1,tr) * query_prod(l,r,ls,tl,mid) % mod + query_presum(l,r,ls,tl,mid)) % mod;
if(l <= mid)
return query_presum(l,r,ls,tl,mid);
if(r > mid)
return query_presum(l,r,rs,mid + 1,tr);
}
pair<int,int> getlca(int x,int y)
{
while(top[x] ^ top[y])
dep[top[x]] > dep[top[y]] ? x = fa[top[x]] : y = fa[top[y]];
return dep[x] < dep[y] ? make_pair(x,0) : make_pair(y,1);
}
long long query(int x,int y)
{
pair<int,int> t = getlca(x,y);
int lca = t.first,w = t.second;
vector< pair<int,int> > range;
while(top[x] ^ top[lca])
range.push_back(make_pair(id[top[x]],id[x])),x = fa[top[x]];
if(w)
range.push_back(make_pair(id[lca],id[x]));
long long temp = 0,prod = 1;
for(register int i = range.size() - 1;~i;--i)
temp = (temp + query_presum(range[i].first,range[i].second,1,1,n) * prod % mod) % mod,prod = prod * query_prod(range[i].first,range[i].second,1,1,n) % mod;
range.clear();
while(top[y] ^ top[lca])
range.push_back(make_pair(id[top[y]],id[y])),y = fa[top[y]];
if(!w)
range.push_back(make_pair(id[lca],id[y]));
long long ret = 0;
prod = 1;
for(register int i = 0;i < range.size();++i)
ret = (ret + query_sufsum(range[i].first,range[i].second,1,1,n) * prod % mod) % mod,prod = prod * query_prod(range[i].first,range[i].second,1,1,n) % mod;
return (ret + temp * prod % mod) % mod;
}
int main()
{
scanf("%d%d",&n,&m);
for(register int i = 1;i <= n;++i)
scanf("%lld",a + i);
int u,v;
for(register int i = 1;i < n;++i)
scanf("%d%d",&u,&v),add(u,v),add(v,u);
dep[1] = 1,dfs1(1),top[1] = 1,dfs2(1);
build(1,1,n);
char op;
int x,y;
while(m--)
{
scanf(" %c%d%d",&op,&x,&y);
if(op == 'Q')
printf("%lld\n",(query(x,y) + mod) % mod);
else
modify(id[x],y,1,1,n);
}
}