Paths on the Tree

Paths on the Tree

Created by LXC on Wed Dec 13 20:15:12 2023

https://codeforces.com/problemset/problem/1746/D

ranting: 1900

tag: dfs and similar, dp, greedy, sortings, trees

problem

给出一颗n个节点的树,并且每个节点都有权值,现在需要选择k条从根出发的路径,使得每条所选路径上节点权值之和最大。要求是每个节点选择的次数与其他所有兄弟节点被选次数相差不超过1。

solution

树上动态规划。只是状态的索引并不是连续的,我们可以将它存在哈希表中。

定义$dp_{i,j}$为以i号节点为根出发j条路径的最大权值和。状态转移取决于子节点。设i的子节点数量为$sz_i$,由于兄弟节点选择次数相差不超过1,所以我们需要尽可能地将j平分给$sz_i$个节点,每个节点至少分配$\lfloor \frac{j}{sz_i} \rfloor$条路径。对于剩余的j%$sz_i$条改分给谁,应该选择各个子节点 分配$\lfloor \frac{j}{sz_i}\rfloor + 1$条路径的贡献 - 分配$\lfloor \frac{j}{sz_i}\rfloor$条路径的贡献 最大的$j \bmod sz_i$条。

code

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77

#include <bits/stdc++.h>
// #define SINGLE_INPUT
#define ll long long
#define ull unsigned long long
#define N 500005
#define MOD 998244353
using namespace std;


void sol() {
int n, k;
cin >> n >> k;
vector<vector<int>> g(n+1);
vector<ll> c(n+1);
vector<vector<pair<int,ll>>> h(n+1);
for (int i=2; i<=n; i++) {
int x; cin >> x;
g[x].push_back(i);
}
for (int i=1; i<=n; i++) {
cin >> c[i];
}
auto dfs = [&](auto self, int x, int k) {
for (auto [i,j]:h[x]) {
if (i == k) return j;
}
int sz = g[x].size();
ll res = k*c[x];
if (sz == 0) {
h[x].emplace_back(k, res);
return res;
}
int a = k/sz, b = k%sz;
if (b == 0) {
for (auto y:g[x]) {
res += self(self, y, a);
}
} else {
vector<ll> t1, t2, t3;
for (auto y:g[x]) {
t1.push_back(self(self, y, a));
}
for (auto y:g[x]) {
t2.push_back(self(self, y, a+1));
}
for (int i=0; i<sz; i++) {
t3.push_back(t2[i]-t1[i]);
}
sort(t3.rbegin(), t3.rend());
for (int i=0; i<sz; i++) {
res += t1[i];
if (i<b) res += t3[i];
}
}
h[x].emplace_back(k, res);
return res;
};
cout << dfs(dfs, 1, k) << endl;
}

int main() {
cout << setprecision(15) << fixed;
ios::sync_with_stdio(false);
cin.tie(0);
cout.tie(0);
#ifndef SINGLE_INPUT
int t;
cin >> t;
while (t--) {
sol();
}
#else
sol();
#endif
return 0;
}