题解 CF490F 【Treeland Tour】

ezoixx130

2018-12-22 20:45:24

Solution

这题的n很小,只有6000,可以$O(n^2logn)$跑过去。 然而其实有更优秀的算法。 首先用动态开点权值线段树维护每个结点的LIS和LDS。 然后在dfs过程中Merge一下就可以了。 Merge的过程用线段树合并实现。 于是我们就31ms跑过了这道题。 代码: ```cpp #include <bits/stdc++.h> using namespace std; #define MAXN 200010 int head[MAXN],to[MAXN<<1],nxt[MAXN<<1]; int n,id[MAXN],ls[4000000],rs[4000000],cnt,val[MAXN],ans,tot; class IO { char buffer[10000001];char *h;int len; inline char gchar(){return *h++;} inline bool validdigit(char c){return c>='0' && c<='9';} public: inline void init(){len=fread(buffer,1,10000000,stdin);h=buffer;} inline int nextint(){register int i=0;register char c;do c=gchar();while(!validdigit(c));do{i=i*10+c-48;c=gchar();}while(validdigit(c));return i;} }io; struct data { int val,id; bool operator<(const data &d)const { return val<d.val; } }tmp[MAXN]; struct SGT { int maxo[4000000]; SGT(){cnt=0;} int query(int o,int l,int r,int ql,int qr) { if(!o)return 0; if(ql<=l && r<=qr)return maxo[o]; int mid=(l+r)>>1,ans=0; if(ql<=mid)ans=max(ans,query(ls[o],l,mid,ql,qr)); if(qr>mid)ans=max(ans,query(rs[o],mid+1,r,ql,qr)); return ans; } void update(int &o,int l,int r,int p,int x) { if(!o)o=++cnt;maxo[o]=max(maxo[o],x); if(l==r) return; int mid=(l+r)>>1; if(p<=mid)update(ls[o],l,mid,p,x); else update(rs[o],mid+1,r,p,x); } }LIS,LDS; void merge(int &a,int b) { if(a==0||b==0) { a+=b; return; } LIS.maxo[a]=max(LIS.maxo[a],LIS.maxo[b]);LDS.maxo[a]=max(LDS.maxo[a],LDS.maxo[b]); ans=max(ans,max(LIS.maxo[ls[a]]+LDS.maxo[rs[b]],LIS.maxo[ls[b]]+LDS.maxo[rs[a]])); merge(ls[a],ls[b]);merge(rs[a],rs[b]); } void dfs(int u,int fa) { int maxlis=0,maxlds=0; for(int i=head[u];i;i=nxt[i]) { int v=to[i]; if(v==fa)continue; dfs(v,u); int vlis=LIS.query(id[v],1,MAXN,1,val[u]-1),vlds=LDS.query(id[v],1,MAXN,val[u]+1,MAXN); ans=max(ans,max(vlis+maxlds+1,vlds+maxlis+1)); merge(id[u],id[v]); maxlis=max(maxlis,vlis); maxlds=max(maxlds,vlds); } LIS.update(id[u],1,MAXN,val[u],maxlis+1); LDS.update(id[u],1,MAXN,val[u],maxlds+1); } int main() { io.init(); n=io.nextint(); for(int i=1;i<=n;++i) { tmp[i].val=io.nextint(); tmp[i].id=i; } sort(tmp+1,tmp+n+1); int now=0; tmp[0].val=-1; for(int i=1;i<=n;++i) { if(tmp[i].val!=tmp[i-1].val)++now; val[tmp[i].id]=now; } for(int i=1;i<n;++i) { int u=io.nextint(),v=io.nextint(); nxt[++tot]=head[u]; to[tot]=v; head[u]=tot; nxt[++tot]=head[v]; to[tot]=u; head[v]=tot; } dfs(1,0); printf("%d\n",ans); } ```