仙人掌染色 题解

· · 题解

首先考虑树怎么做。

你肯定是要在节点上算贡献,但是按照题目给的方法算,巨大复杂,肯定没法 dp,我们尝试化简一下贡献的计算方式。

通过一些猜测,你想到了也许边的顺序其实不影响答案,通过手搓几组数据你发现很对,然后你可以简单证明交换相邻的边不会对答案产生影响,所以顺序确实无关。

因此你实际上只在乎一个点的所有出边有多少条被染了色。

那么对于树就有一个简单的 dp,不妨设 dp_{u,i} 表示考虑 u 子树内的点,u 到其父亲的边是否染色的最大权值,转移时可以枚举点 u 的出边中一共有多少条被染色,然后你选出 dp_{v,1} - dp_{v,0} 的前这么多大染色即可。

然后考虑做仙人掌,你发现比较阴间,我们在圆方树上考虑这个问题。

首先状态要改下,对于一个圆点而言,设计 dp_{u,0/1/2} 表示其在圆方树上的子树内边全部染好色,其在父亲点双中有 0/1/2 条边染色,子树内的边带来的代价(不算父亲点双中的边)以及子树内所有点带来的贡献(算自己)。

对于一个方点而言,设计状态 dp_{u,0/1/2} 表示其父亲圆点在这个点双内有 0/1/2 条边染了色,其圆方树上子树内所有边(算父亲圆点在自己点双内的边)染色代价以及子树内所有点带来的贡献(不算父亲圆点)。

方点的转移是简单的,关键是圆点的转移需要处理一个形如:有 n 个物品,每个物品可以以 a_i 的代价带来一贡献或者以 b_i 的代价带来两贡献,求一共有 i 贡献的最大花费,需要对 i \in [1,deg_u] 求解,你发现这就是一个反悔贪心经典例题 Cardboard Box,我们用 5 个堆来维护取物品以及反悔的过程即可。

时间复杂度 O((n+m) \log (n+m))

#include<bits/stdc++.h>
using namespace std;
#define int long long
const int maxn = 2e5+114;
vector<int> E[maxn];
map<int,int> w[maxn];
int stk[maxn],tp,low[maxn],dfn[maxn],dfncnt;
int tot;
int n,m,p;
vector<int> G[maxn<<1];
void tarjan(int u){
    dfn[u]=low[u]=++dfncnt;
    stk[++tp]=u;
    for(int v:E[u]){
        if(dfn[v]==0){
            tarjan(v);
            low[u]=min(low[u],low[v]);
            if(low[v]>=dfn[u]){
                tot++;
                for(int x=0;x!=v;tp--){
                    x=stk[tp];
                    G[tot].push_back(x);
                    G[x].push_back(tot);
                }
                G[tot].push_back(u);
                G[u].push_back(tot);                
            }
        }else low[u]=min(low[u],dfn[v]);
    }
}
int fa[maxn<<1];
void dfs(int u){
    for(int v:G[u]){
        if(v!=fa[u]){
            fa[v]=u;
            dfs(v);
        }
    }
    if(fa[u]!=0){
        vector<int> vec;
        int id=0;
        for(int i=0;i<G[u].size();i++){
            if(G[u][i]==fa[u]){
                id=i;
                break;
            }
        }
        for(int i=id;i<G[u].size();i++) vec.push_back(G[u][i]);
        for(int i=0;i<id;i++) vec.push_back(G[u][i]);
        G[u]=vec;
    }
}
int dp[maxn<<1][3];//0 1 2 init -inf
//u<=n sub and Count of u,sum of u in fa,no cost of edge
//u>n sub and sum of u in fa,cost of edge,no Count of u
const int inf = 5e18;
vector<int> solve(vector<int> a,vector<int> b,int ed){
    vector<int> res;
    res.resize(ed+1);
    vector<int> vis;
    vis.resize(a.size());
    int ans=0;
    priority_queue< pair<int,int> > q[5];//0 1 2 3 4 | vis[i]=0 max(a) | vis[i]=1 max(b-a) | vis[i] = 1 max(-a) | vis[i] = 0 max(b) | vis[i] = 2 max(a-b)
    for(int i=0;i<a.size();i++){
        q[0].push(make_pair(a[i],i));
        if(b[i]>-inf) q[3].push(make_pair(b[i],i));
    }
    for(int i=1;i<=ed;i++){
        while(q[0].size()>0&&vis[q[0].top().second]!=0) q[0].pop();
        while(q[1].size()>0&&vis[q[1].top().second]!=1) q[1].pop();
        while(q[2].size()>0&&vis[q[2].top().second]!=1) q[2].pop();
        while(q[3].size()>0&&vis[q[3].top().second]!=0) q[3].pop();
        while(q[4].size()>0&&vis[q[4].top().second]!=2) q[4].pop();
        int ch1=-inf,ch2=-inf,ch3=-inf,ch4=-inf;
        if(q[0].size()>0) ch1=q[0].top().first;
        if(q[1].size()>0) ch2=q[1].top().first;
        if(q[2].size()>0&&q[3].size()>0) ch3=q[2].top().first+q[3].top().first;
        if(q[4].size()>0&&q[3].size()>0) ch4=q[3].top().first+q[4].top().first;
        int mx=max(max(ch1,ch2),max(ch3,ch4));
        ans+=mx;
        if(mx==ch1){
            int x=q[0].top().second;
            q[0].pop();
            vis[x]=1;
            if(b[x]>-inf) q[1].push(make_pair(b[x]-a[x],x));
            q[2].push(make_pair(-a[x],x));
        }else if(mx==ch2){
            int x=q[1].top().second;
            q[1].pop();
            vis[x]=2;
            if(b[x]>-inf) q[4].push(make_pair(a[x]-b[x],x));
        }else if(mx==ch3){
            int x=q[2].top().second,y=q[3].top().second;
            q[2].pop(),q[3].pop();
            vis[x]=0;
            vis[y]=2;
            q[0].push(make_pair(a[x],x));
            if(b[x]>-inf) q[3].push(make_pair(b[x],x));
            if(b[y]>-inf) q[4].push(make_pair(a[y]-b[y],y));
        }else{
            int x=q[4].top().second,y=q[3].top().second;
            q[4].pop(),q[3].pop();
            vis[x]=1,vis[y]=2;
            if(b[x]>-inf) q[1].push(make_pair(b[x]-a[x],x));
            q[2].push(make_pair(-a[x],x));
            if(b[y]>-inf) q[4].push(make_pair(a[y]-b[y],y));
        }
        res[i]=ans;
    }
    return res;
}
int val(int k,int d){
    return (k*(d-k)+k*(d-k)*(d-2)/2)*p;
}
void DP(int u){
    for(int v:G[u]){
        if(v!=fa[u]) DP(v);
    }
    if(u<=n){
        int ty=0;
        int res=0;
        if(fa[u]!=0){
            if(G[fa[u]].size()==2) ty=1;
            else ty=2;
        }
        vector<int> a,b;
        for(int v:G[u]){
            if(v!=fa[u]){
                res+=dp[v][0];
                a.push_back(dp[v][1]-dp[v][0]);
                if(G[v].size()>2) b.push_back(dp[v][2]-dp[v][0]);
                else b.push_back(-inf);
            }
        }
        vector<int> f=solve(a,b,E[u].size()-ty);
        for(int i=0;i<=E[u].size()-ty;i++){
            f[i]+=res;
            for(int j=0;j<=ty;j++){
                dp[u][j]=max(dp[u][j],f[i]+val(i+j,E[u].size()));
            }
        }
    }else if(G[u].size()==2){
        for(int v:G[u]){
            if(v!=fa[u]){
                dp[u][0]=dp[v][0];
                dp[u][1]=dp[v][1]-w[G[u][0]][G[u][1]];   
            }
        }
    }else{
        int f[2][2];
        for(int i=0;i<2;i++)
            for(int j=0;j<2;j++) f[i][j]=-inf;
            f[0][0]=0;
            f[1][1]=-w[G[u][0]][G[u][1]];
        for(int i=1;i<G[u].size();i++){
            int g[2][2];
            for(int x=0;x<2;x++){
                for(int y=0;y<2;y++) g[x][y]=-inf;
            }
            for(int x=0;x<2;x++){
                for(int y=0;y<2;y++){
                    if(f[x][y]==-inf) continue;
                    for(int z=0;z<2;z++){
                        //xy -> xz
                        g[x][z]=max(g[x][z],f[x][y]-z*w[G[u][i]][G[u][(i+1)%G[u].size()]]+dp[G[u][i]][y+z]);
                    }
                }
            }
            for(int x=0;x<2;x++){
                for(int y=0;y<2;y++) f[x][y]=g[x][y];
            }
        }
        for(int x=0;x<2;x++){
            for(int y=0;y<2;y++) dp[u][x+y]=max(dp[u][x+y],f[x][y]);
        }
    }
}
signed main(){
    ios::sync_with_stdio(0);
    cin.tie(0),cout.tie(0);
    for(int i=0;i<(maxn<<1);i++){
        for(int j=0;j<3;j++) dp[i][j]=-inf;
    }
    cin>>n>>m>>p;
    tot=n;
    for(int i=1;i<=m;i++){
        int u,v;
        cin>>u>>v;
        E[u].push_back(v);
        E[v].push_back(u);
        cin>>w[u][v];
        w[v][u]=w[u][v];
    }
    tarjan(1);
    dfs(1);
    DP(1);
    cout<<dp[1][0]<<'\n';
    return 0;
}
/*
5 4 10
1 4 1
2 3 2
3 4 2
3 5 2
*/
/*
5 5 10
1 2 1
2 3 1
3 4 1
4 5 1
5 1 1
*/