题解 P5114 【八月脸】

· · 题解

我们发现需要求一个类似 \max(kx+y) 的东西,那么只需要求出所有 (x,y) 的上凸壳然后二分即可.

官方题解用的边分治,由于我不会边分治所以我写的点分治.

点分治的常用计算答案方法有三种: 先不考虑两个点在同一棵子树的情况计算最后再容斥掉;逐子树加入,即考虑每棵子树对之前加入的子树的贡献;合并果子,即每次拿出两个size最小的子树计算答案,然后把这两个子树的信息合并后放回去.

这里需要第三种算法,如果合并两棵大小为 size_xsize_y 的子树需要的复杂度是 O(size_x+size_y),那么总复杂度仍然是一个log然而我不会证

求出所有子树的末端为黑/白点的凸包,那么合并两棵子树 x,y 的时候,只需要对 x 的黑点凸包和 y 的白点凸包、x 的白点凸包和 y 的黑点凸包分别做一个闵可夫斯基和,得到的凸包上的点扔到答案凸包里去(先不构建,只存储点). 不会闵可夫斯基和的可以去看一下这道题. 然后归并 x 的凸包和 y 的凸包放回去即可.

求出答案凸包之后询问在上面二分即可,当然也可以像上面那道题一样暴力移动指针.

算一下复杂度,发现合并凸包的部分是没问题的,但是构建答案凸包和构建子树凸包的时候可能会有问题,使用sort会使复杂度变成两个log,那么只需要离线基数排序即可. 不过两个log也能过,所以就懒得写基排了(

跑得好像比边分治慢不少.

需要注意的细节是由于值域为 1.5\times 10^9,所以加加减减的时候可能会爆int(

然后排序的时候先按 x 排序,x 一样的时候按 y 排序,不然求的凸包可能会有问题.

#include<iostream>
#include<cstdio>
#include<cstring>
#include<vector>
#include<queue>
#include<algorithm>
#define IT vector<int>::iterator
using namespace std;
const int N=2e5+5000,INF=2e9;
struct Point
{
    long long x,y;
    Point operator +(const Point &a)const{return (Point){x+a.x,y+a.y};}
    Point operator -(const Point &a)const{return (Point){x-a.x,y-a.y};}
    bool operator >=(const Point &a)const{return 1ll*x*a.y<=1ll*y*a.x;}
    long long f(int k){return 1ll*k*x+y;}
}tmpans[N*200],tmp[2][N],w[N];
struct Hull
{
    int len;
    vector<Point>a;
    Hull(){len=0;}
    Point operator [](const int &x)const{return a[x];}
    void clear(){len=0;a.clear();}
    void push(const Point &x)
    {
        while(len>1&&(x-a[len-1])>=(a[len-1]-a[len-2]))--len,a.pop_back();
        ++len,a.push_back(x);
    }
}a[N<<1][2];
vector<int>e[N];
int ansn=0,len[2],col[N],n,m,nn,mxson,root,size[N],vis[N];
long long ans[N];
struct Q{int w,id;}q[N];
void ade(int u,int v){e[u].push_back(v);}
Hull merge(const Hull &a,const Hull &b)
{
    Hull tmp;
    for(int i=0,j=0;i<a.len||j<b.len;)
        if(i==a.len||(j<b.len&&b[j].x<a[i].x))
            tmp.push(b[j]),j++;
        else if(j==b.len||a[i].x<b[j].x)
            tmp.push(a[i]),i++;
        else tmp.push(a[i].y>b[j].y?a[i]:b[j]),i++,j++;
    return tmp;
}
Hull Minkow(const Hull &a,const Hull &b)
{
    Hull tmp;if(!a.len||!b.len)return tmp;
    tmp.push(a[0]+b[0]);int i=0,j=0;
    for(;i<a.len-1||j<b.len-1;)
        if(i==a.len-1)tmp.push(a[i]+b[++j]);
        else if(j==b.len-1)tmp.push(a[++i]+b[j]);
        else tmp.push((a[i+1]-a[i])>=(b[j+1]-b[j])?a[++i]+b[j]:a[i]+b[++j]);
    return tmp;
}
void findroot(int u,int fa)
{
    size[u]=1;int mx=0;
    for(IT i=e[u].begin();i!=e[u].end();i++)
    {
        int v=*i;if(vis[v]||v==fa)continue;
        findroot(v,u),mx=max(mx,size[v]),size[u]+=size[v];
    }
    mx=max(mx,nn-size[u]);
    if(mx<mxson)mxson=mx,root=u;
}
void dfs(int u,int fa,Point s)
{
    tmp[col[u]][++len[col[u]]]=s;
    for(IT i=e[u].begin();i!=e[u].end();i++)
    {
        int v=*i;if(v==fa||vis[v])continue;
        dfs(v,u,s+w[v]);
    }
}
int cmp(const Point &a,const Point &b){return a.x==b.x?a.y<b.y:a.x<b.x;}
void build(int u)
{
    a[u][0].clear();a[u][1].clear();
    sort(tmp[0]+1,tmp[0]+len[0]+1,cmp),sort(tmp[1]+1,tmp[1]+len[1]+1,cmp);

    for(int i=1;i<=len[0];i++)a[u][0].push(tmp[0][i]);
    for(int i=1;i<=len[1];i++)a[u][1].push(tmp[1][i]);
}
void calc(int x,int y,int u)
{
    Hull tmp=Minkow(a[x][0],a[y][1]);
    for(int i=0;i<tmp.len;i++)tmpans[++ansn]=tmp[i]+w[u];
    tmp=Minkow(a[x][1],a[y][0]);
    for(int i=0;i<tmp.len;i++)tmpans[++ansn]=tmp[i]+w[u];
}
struct HeapNode{int u,size;bool operator <(const HeapNode &a)const{return size>a.size;}};
void solve(int u)
{
    vis[u]=1;priority_queue<HeapNode>q;
    for(IT i=e[u].begin();i!=e[u].end();i++)
    {
        int v=*i;if(vis[v])continue;
        len[0]=len[1]=0;dfs(v,u,w[v]);
        build(v);q.push((HeapNode){v,a[v][0].len+a[v][1].len});
    }
    int tn=n;
    while(q.size()>1)
    {
        HeapNode x=q.top(),y;
        q.pop();y=q.top();q.pop();
        calc(x.u,y.u,u);
//      cout<<x.u<<" "<<y.u<<endl;
        ++tn;a[tn][0]=merge(a[x.u][0],a[y.u][0]),a[tn][1]=merge(a[x.u][1],a[y.u][1]);
        q.push((HeapNode){tn,a[tn][0].len+a[tn][1].len});
    }
    if(q.empty())return;
    int t=q.top().u;q.pop();
    if(col[u])
    {
        for(int i=0;i<a[t][0].len;i++)tmpans[++ansn]=a[t][0][i]+w[u];
    }
    else
    {
        for(int i=0;i<a[t][1].len;i++)tmpans[++ansn]=a[t][1][i]+w[u];
    }
    for(IT i=e[u].begin();i!=e[u].end();i++)
    {
        int v=*i;if(vis[v])continue;
        nn=size[v],mxson=INF,findroot(v,u),solve(root);
    }
}
int getin()
{
    int x=0,f=1;char ch=getchar();
    while((ch<'0'||ch>'9')&&ch!='-')ch=getchar();
    if(ch=='-')f=-1,ch=getchar();
    while(ch>='0'&&ch<='9')x=x*10+ch-48,ch=getchar();
    return x*f;
}
int cmp2(const Q &a,const Q &b){return a.w<b.w;}
int main()
{
    scanf("%d%d",&n,&m);
    for(int i=1;i<=n;i++)w[i].x=getin();
    for(int i=1;i<=n;i++)w[i].y=getin();
    for(int i=1;i<=n;i++)col[i]=getin();
    for(int i=1,u,v;i<n;i++)scanf("%d%d",&u,&v),ade(u,v),ade(v,u);
    nn=n,mxson=INF,findroot(1,0),solve(root);
    sort(tmpans+1,tmpans+ansn+1,cmp);
//    for(int i=1;i<=ansn;i++)cout<<tmpans[i].x<<" "<<tmpans[i].y<<endl;
    Hull tmp;
    for(int i=1;i<=ansn;i++)tmp.push(tmpans[i]);
//    for(int i=0;i<tmp.len;i++)cout<<tmp[i].x<<" "<<tmp[i].y<<endl;
    for(int i=1;i<=m;i++)scanf("%d",&q[i].w),q[i].id=i;
    if(!tmp.len){for(int i=1;i<=m;i++)puts("0");return 0;}
    sort(q+1,q+m+1,cmp2);
    for(int i=1,j=0;i<=m;i++)
    {
        while(j+1<tmp.len&&tmp[j+1].f(q[i].w)>=tmp[j].f(q[i].w))++j;
        ans[q[i].id]=tmp[j].f(q[i].w);
    }
    for(int i=1;i<=m;i++)printf("%lld\n",ans[i]);
}