P10070 [CCO2023] Travelling Trader 题解
win114514
·
·
题解
非常好题目,使我代码长度起飞。
思路
发现 K 只有三种取值。
考虑分类讨论。
k=1
容易发现只需要求一个端点是 1 的最长链。
k=3
考虑这个时候我们将有一个遍历整个树的方案。
考虑递归的处理整个问题。
我们从该节点跳到它一个儿子的儿子。
然后递归处理这个儿子的儿子。
然后再跳到该节点的这个儿子的另一个儿子。
然后递归处理。
将所有儿子的儿子处理完以后,在跳回这个儿子。
然后继续处理其他的儿子的儿子。
这样就可以简单找到遍历整个树的方案。
k=2
考虑 k=2 怎么做。
我们可以使用树形 dp。
设:
$f_{x,1}$ 为从 $x$ 出发往下走,对终止节点要求为 $x$ 的某个儿子或 $x$ 的最大贡献。
$f_{x,2}$ 为从 $x$ 的某个儿子出发往下走,对终止节点无要求的最大贡献。
$f_{x,3}$ 为从 $x$ 的某个儿子出发往下走,对终止节点要求为 $x$ 的最大贡献。
考虑转移式。
1. $$f_{x,0}=a_x+f_{y,2}$$
表示先走到 $x$,在直接从 $y$ 往下走。
2. $$f_{x,0}=a_x+f_{y1,3}+\sum_{y\not=y1,y2} a_y+f_{y2,0}$$
表示先走到 $x$,在再把 $y1$ 走一圈后回到 $y1$,然后走它的兄弟,最后在某个兄弟往下走。
3. $$f_{x,1}=a_x+f_{y1,3}+\sum_{y\not=y1}a_y$$
表示先走到 $x$,在再把 $y1$ 走一圈后回到 $y1$,然后走它的兄弟。
4. $$f_{x,2}=f_{x,0}$$
和情况一类似。
5. $$f_{x,2}=\sum_{y\not=y1,y2} a_y+f_{y1,1}+a_x+f_{y2,2}$$
表示先走到 $x$ 的一些儿子,然后走到 $y1$ 这个儿子转一圈,然后回到 $x$,然后从 $y2$ 往下走。
6. $$f_{x,2}=\sum_{y\not=y1,y2,y3} a_y+f_{y1,1}+a_x+f_{y2,3}+f_{y3,0}$$
表示先走到 $x$ 的一些儿子,然后走到 $y1$ 这个儿子转一圈,然后回到 $x$,然后从 $y2$ 往下走一圈,然后从 $y3$ 往下走。
7. $$f_{x,3}=\sum_{y\not=y1} a_y+f_{y1,1}+a_x$$
表示先走到 $x$ 的一些儿子,然后走到 $y1$ 这个儿子转一圈,然后回到 $x$。
注意很重要的一点,在记录方案时,这些顺序时不能随便颠倒的,否则容易方案不合法。
容易发现以上所有 dp 式都可以线性解决。
时间复杂度:$O(n)$。
### Code
```cpp
#include <bits/stdc++.h>
using namespace std;
#define x first
#define y second
#define int long long
#define mp(x, y) make_pair(x, y)
#define eb(...) emplace_back(__VA_ARGS__)
#define fro(i, x, y) for(int i = (x); i <= (y); i++)
#define pre(i, x, y) for(int i = (x); i >= (y); i--)
inline void JYFILE19();
typedef int64_t i64;
typedef pair<int, int> PII;
bool ST;
const int N = 2e5 + 10;
const int mod = 998244353;
int n, m, a[N], dp[N], fa[N], pre[N];
vector<int> to[N];
namespace subtask1 {
inline void dfs(int now, int fa) {
dp[now] = a[now];
for(auto i : to[now]) {
if(i == fa) continue;
dfs(i, now);
if(dp[i] > dp[pre[now]])
pre[now] = i;
}
dp[now] += dp[pre[now]];
}
inline void Solve() {
dfs(1, 0);
vector<int> ans;
int now = 1;
while(now) ans.eb(now), now = pre[now];
cout << dp[1] << "\n";
cout << ans.size() << "\n";
for(auto i : ans) cout << i << " ";
cout << "\n";
}
}
namespace subtask2 {
struct Node {
int x, op;
inline Node(int xx, int opx) {
x = xx, op = opx;
}
};
struct node {
int num, id;
inline bool operator<(const node &tmp) const {
return num < tmp.num;
}
} pr[N], sf[N];
int tp, stk[N], f[N][4], dp[N][2][2][2][4];
vector<Node> g[N][4];
inline void dfs(int now, int fa) {
int sum = 0;
vector<int> son;
for(auto i : to[now])
if(i != fa) son.eb(i);
for(auto i : son)
dfs(i, now), sum += a[i];
tp = 0;
for(auto i : son) stk[++tp] = i;
if(tp == 0) {
fro(i, 0, 3) {
f[now][i] = a[now];
g[now][i].eb(now, 4);
}
return;
}
{
int idl = 0;
for(auto i : son)
if(f[i][2] > f[idl][2])
idl = i;
pr[0] = sf[0] = pr[tp + 1] = sf[tp + 1] = {0, 0};
fro(i, 1, tp) {
pr[i] = {f[stk[i]][3] - a[stk[i]], stk[i]};
sf[i] = {f[stk[i]][3] - a[stk[i]], stk[i]};
}
fro(i, 1, tp) pr[i] = max(pr[i], pr[i - 1]);
pre(i, tp, 1) sf[i] = max(sf[i], sf[i + 1]);
int id = 0;
auto get = [&](int x) {
if(x == 0) return 0ll;
return max(pr[x - 1], sf[x + 1]).num + f[stk[x]][0] - a[stk[x]];
};
fro(i, 1, tp) if(get(id) <= get(i)) id = i;
f[now][0] = a[now] + get(id) + sum;
if(f[now][0] < a[now] + f[idl][2]) {
f[now][0] = a[now] + f[idl][2];
g[now][0].eb(now, 4);
g[now][0].eb(idl, 2);
}
else {
int id1 = max(pr[id - 1], sf[id + 1]).id;
int id2 = stk[id];
g[now][0].eb(now, 4);
if(id1) g[now][0].eb(id1, 3);
for(auto i : son)
if(i != id1 && i != id2)
g[now][0].eb(i, 4);
if(id2) g[now][0].eb(id2, 0);
}
}
{
int id = 0;
for(auto i : son)
if(f[i][3] - a[i] > f[id][3] - a[id])
id = i;
f[now][1] = a[now] + f[id][3] - a[id] + sum;
g[now][1].eb(now, 4);
if(id) g[now][1].eb(id, 3);
for(auto i : son) if(i != id)
g[now][1].eb(i, 4);
}
{
int num1 = f[now][0];
pr[0] = sf[0] = pr[tp + 1] = sf[tp + 1] = {0, 0};
fro(i, 1, tp) {
pr[i] = {f[stk[i]][1] - a[stk[i]], stk[i]};
sf[i] = {f[stk[i]][1] - a[stk[i]], stk[i]};
}
fro(i, 1, tp) pr[i] = max(pr[i], pr[i - 1]);
pre(i, tp, 1) sf[i] = max(sf[i], sf[i + 1]);
int id = 0;
auto get = [&](int x) {
if(x == 0) return 0ll;
return max(pr[x - 1], sf[x + 1]).num + f[stk[x]][2] - a[stk[x]];
};
fro(i, 1, tp) if(get(id) <= get(i)) id = i;
int num2 = a[now] + get(id) + sum;
fro(i, 0, tp) {
fro(op1, 0, 1) {
fro(op2, 0, 1) {
fro(op3, 0, 1) {
dp[i][op1][op2][op3][0] = -1e18;
dp[i][op1][op2][op3][1] = 0;
dp[i][op1][op2][op3][2] = 0;
dp[i][op1][op2][op3][3] = 0;
}
}
}
}
dp[0][0][0][0][0] = 0;
fro(i, 1, tp) {
fro(op1, 0, 1) { fro(op2, 0, 1) { fro(op3, 0, 1) {
fro(k, 0, 3) dp[i][op1][op2][op3][k] = dp[i - 1][op1][op2][op3][k];
}}}
fro(op1, 0, 1) {
fro(op2, 0, 1) {
fro(op3, 0, 1) {
int num = dp[i - 1][op1][op2][op3][0];
int A = dp[i - 1][op1][op2][op3][1];
int B = dp[i - 1][op1][op2][op3][2];
int C = dp[i - 1][op1][op2][op3][3];
if(op1 == 0) {
if(dp[i][1][op2][op3][0] < num - a[stk[i]] + f[stk[i]][0]) {
dp[i][1][op2][op3][0] = num - a[stk[i]] + f[stk[i]][0];
dp[i][1][op2][op3][1] = stk[i];
dp[i][1][op2][op3][2] = B;
dp[i][1][op2][op3][3] = C;
}
}
if(op2 == 0) {
if(dp[i][op1][1][op3][0] < num - a[stk[i]] + f[stk[i]][1]) {
dp[i][op1][1][op3][0] = num - a[stk[i]] + f[stk[i]][1];
dp[i][op1][1][op3][1] = A;
dp[i][op1][1][op3][2] = stk[i];
dp[i][op1][1][op3][3] = C;
}
}
if(op3 == 0) {
if(dp[i][op1][op2][1][0] < num - a[stk[i]] + f[stk[i]][3]) {
dp[i][op1][op2][1][0] = num - a[stk[i]] + f[stk[i]][3];
dp[i][op1][op2][1][1] = A;
dp[i][op1][op2][1][2] = B;
dp[i][op1][op2][1][3] = stk[i];
}
}
}
}
}
}
int num3 = 0, f1 = 0, f2 = 0, f3 = 0;
fro(op1, 0, 1) {
fro(op2, 0, 1) {
fro(op3, 0, 1) {
if(num3 < dp[tp][op1][op2][op3][0]) {
num3 = dp[tp][op1][op2][op3][0];
f1 = op1, f2 = op2, f3 = op3;
}
}
}
}
num3 += sum + a[now];
f[now][2] = max({num1, num2, num3});
if(num1 >= num2 && num1 >= num3) {
g[now][2] = g[now][0];
}
else if(num2 >= num1 && num2 >= num3) {
int id1 = max(pr[id - 1], sf[id + 1]).id;
int id2 = stk[id];
for(auto i : son)
if(i != id1 && i != id2)
g[now][2].eb(i, 4);
if(id1) g[now][2].eb(id1, 1);
g[now][2].eb(now, 4);
if(id2) g[now][2].eb(id2, 2);
}
else {
int id1 = dp[tp][f1][f2][f3][1];
int id2 = dp[tp][f1][f2][f3][2];
int id3 = dp[tp][f1][f2][f3][3];
for(auto i : son)
if(i != id1 && i != id2 && i != id3)
g[now][2].eb(i, 4);
if(id2) g[now][2].eb(id2, 1);
g[now][2].eb(now, 4);
if(id3) g[now][2].eb(id3, 3);
if(id1) g[now][2].eb(id1, 0);
}
}
{
int id = 0;
for(auto i : son)
if(f[i][1] - a[i] > f[id][1] - a[id])
id = i;
f[now][3] = a[now] + f[id][1] - a[id] + sum;
for(auto i : son) if(i != id)
g[now][3].eb(i, 4);
if(id) g[now][3].eb(id, 1);
g[now][3].eb(now, 4);
}
}
vector<int> res;
inline void print(int x, int op) {
for(auto i : g[x][op]) {
if(i.op == 4) res.eb(i.x);
else print(i.x, i.op);
}
}
inline void Solve() {
dfs(1, 0);
int num = max({f[1][0], f[1][1]});
fro(i, 0, 1) if(num == f[1][i]) { print(1, i); break; }
cout << num <<"\n";
cout << res.size() << "\n";
for(auto i : res) cout << i << " ";
cout << "\n";
}
}
namespace subtask3 {
vector<int> ans;
inline void dfs(int now) {
for(auto i : to[now]) {
if(i == fa[now]) continue;
fa[i] = now, dfs(i);
}
}
inline void calc(int now) {
ans.eb(now);
for(auto i : to[now]) {
if(i == fa[now]) continue;
for(auto j : to[i]) {
if(j == fa[i]) continue;
calc(j);
}
ans.eb(i);
}
}
inline void Solve() {
dfs(1), calc(1);
int num = 0;
fro(i, 1, n) num += a[i];
cout << num << "\n";
cout << ans.size() << "\n";
for(auto i : ans) cout << i << " ";
cout << "\n";
}
}
signed main() {
JYFILE19();
cin >> n >> m;
fro(i, 1, n - 1) {
int x, y;
cin >> x >> y;
to[x].eb(y);
to[y].eb(x);
}
fro(i, 1, n) cin >> a[i];
if(m == 1) subtask1::Solve();
if(m == 2) subtask2::Solve();
if(m == 3) subtask3::Solve();
return 0;
}
bool ED;
inline void JYFILE19() {
// freopen("", "r", stdin);
// freopen("", "w", stdout);
ios::sync_with_stdio(0), cin.tie(0);
double MIB = fabs((&ED-&ST)/1048576.), LIM = 1024;
cerr << "MEMORY: " << MIB << endl, assert(MIB<=LIM);
}
```