统计子树中城市之间最大距离

题目

1617. 统计子树中城市之间最大距离


给你 n 个城市,编号为从 1 到 n 。同时给你一个大小为 n-1 的数组 edges ,其中 edges[i] = [ui, vi] 表示城市 ui 和 vi 之间有一条双向边。题目保证任意城市之间只有唯一的一条路径。换句话说,所有城市形成了一棵  。

一棵 子树 是城市的一个子集,且子集中任意城市之间可以通过子集中的其他城市和边到达。两个子树被认为不一样的条件是至少有一个城市在其中一棵子树中存在,但在另一棵子树中不存在。

对于 d 从 1 到 n-1 ,请你找到城市间 最大距离 恰好为 d 的所有子树数目。

请你返回一个大小为 n-1 的数组,其中第 d 个元素(下标从 1 开始)是城市间 最大距离 恰好等于 d 的子树数目。

请注意,两个城市间距离定义为它们之间需要经过的边的数目。

示例 1:

1
2
3
4
5
6
输入:n = 4, edges = [[1,2],[2,3],[2,4]]
输出:[3,4,0]
解释:
子树 {1,2}, {2,3} 和 {2,4} 最大距离都是 1 。
子树 {1,2,3}, {1,2,4}, {2,3,4} 和 {1,2,3,4} 最大距离都为 2 。
不存在城市间最大距离为 3 的子树。

示例 2:

1
2
输入:n = 2, edges = [[1,2]]
输出:[1]

示例 3:

1
2
输入:n = 3, edges = [[1,2],[2,3]]
输出:[2,1]

提示:

  • 2 <= n <= 15
  • edges.length == n-1
  • edges[i].length == 2
  • 1 <= ui, vi <= n
  • 题目保证 (ui, vi) 所表示的边互不相同。

题解

方法一:

思路

枚举15个节点的子集。判断是否在一个连通块。

若在同一个连通块,则求树的直径d。对应城市间最大距离恰好为d的所有子树数目+1.

求树的直径,先从任意点v出发开始bfs,然后找到与v最远的节点u。然后从u开始bfs,离u最远的节点直接的距离就是直径。

代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
class Solution {
public:
const int INF = 0x3f3f3f3f;
vector<int> g[20];
vector<int> countSubgraphsForEachDiameter(int n, vector<vector<int>>& edges) {
for (auto& i:edges) {
i[0]--; i[1]--;
g[i[0]].push_back(i[1]);
g[i[1]].push_back(i[0]);
}
auto get_diameter = [&](int mask) {
vector<int> v1(n, INF), v2(n, INF);
int s = 0;
while ((mask>>s&1) == 0) s++;
v1[s] = 0;
queue<int> q;
q.push(s);
while (q.size()) {
int u = q.front(); q.pop();
for (int v:g[u]) {
if ((mask>>v&1) && v1[v] > v1[u]+1) {
v1[v] = v1[u]+1;
q.push(v);
}
}
}
int mx = s;
for (int i=0; i<n; i++) {
if (mask>>i&1) {
if (v1[i] == INF) return -1; // 不连通
if (v1[mx] < v1[i]) mx = i;
}
}
while (q.size()) q.pop();
v2[mx] = 0;
q.push(mx);
while (q.size()) {
int u = q.front(); q.pop();
for (int v:g[u]) {
if ((mask>>v&1) && v2[v] > v2[u]+1) {
v2[v] = v2[u]+1;
q.push(v);
}
}
}
int dia = 0;
for (int i=0; i<n; i++) {
if (mask>>i&1) {
dia = max(dia, v2[i]);
}
}
return dia;
};
int sz = 1<<n;
vector<int> ans(n-1);
for (int i=1; i<sz; i++) {
int dia = get_diameter(i);
if (dia > 0) {
// cout << dia << " " << i << endl;
ans[dia-1]++;
}
}
return ans;
}
};