统计树中的合法路径数目

题目

2867. 统计树中的合法路径数目


给你一棵 n 个节点的无向树,节点编号为 1 到 n 。给你一个整数 n 和一个长度为 n - 1 的二维整数数组 edges ,其中 edges[i] = [ui, vi] 表示节点 ui 和 vi 在树中有一条边。

请你返回树中的 合法路径数目 。

如果在节点 a 到节点 b 之间 恰好有一个 节点的编号是质数,那么我们称路径 (a, b) 是 合法的 。

注意:

  • 路径 (a, b) 指的是一条从节点 a 开始到节点 b 结束的一个节点序列,序列中的节点 互不相同 ,且相邻节点之间在树上有一条边。
  • 路径 (a, b) 和路径 (b, a) 视为 同一条 路径,且只计入答案 一次 。

示例 1:

1
2
3
4
5
6
7
8
输入:n = 5, edges = [[1,2],[1,3],[2,4],[2,5]]
输出:4
解释:恰好有一个质数编号的节点路径有:
- (1, 2) 因为路径 1 到 2 只包含一个质数 2 。
- (1, 3) 因为路径 1 到 3 只包含一个质数 3 。
- (1, 4) 因为路径 1 到 4 只包含一个质数 2 。
- (2, 4) 因为路径 2 到 4 只包含一个质数 2 。
只有 4 条合法路径。

示例 2:

1
2
3
4
5
6
7
8
9
10
输入:n = 6, edges = [[1,2],[1,3],[2,4],[3,5],[3,6]]
输出:6
解释:恰好有一个质数编号的节点路径有:
- (1, 2) 因为路径 1 到 2 只包含一个质数 2 。
- (1, 3) 因为路径 1 到 3 只包含一个质数 3 。
- (1, 4) 因为路径 1 到 4 只包含一个质数 2 。
- (1, 6) 因为路径 1 到 6 只包含一个质数 3 。
- (2, 4) 因为路径 2 到 4 只包含一个质数 2 。
- (3, 6) 因为路径 3 到 6 只包含一个质数 3 。
只有 6 条合法路径。

提示:

  • 1 <= n <= 10^5
  • edges.length == n - 1
  • edges[i].length == 2
  • 1 <= ui, vi <= n
  • 输入保证 edges 形成一棵合法的树。

题解

方法一:

思路

这是一道比较综合的题,求n以内的质数,深搜/广搜,贡献计数。

核心的思路,找到每个质数点有多少条路径经过它。

具体的做法,先以质数的节点作为分割点求出各个连通块的大小。然后对于每个质数节点,找到其所有连接的连通块,如果有k个质数,设第i个质数节点所连接的连通块有$n_i$个,且每个连通块大小为$s_{i,1}, s_{i,2}, \cdots, s_{i, n_i}$

由于每个连通块内的点可以经过第i个质数到达其他连通块。比如第一个连通块就有$s_{i,1}$个点,到其他联通块的点有$s_{i,2}+ \cdots + s_{i, n_i}$个点,但是别忘了还可以只经过第i个质数。所以第1个连通块的贡献是$s_{i,1}\times \sum \limits_{j=2}^{n_i}s_{i,j}$,同理第t个连通块的贡献是$s_{i,t}\times (-s_{i,t}+\sum \limits_{j=1}^{n_i}s_{i,j})$

第i个质数节点的总贡献是$\sum \limits {t=1}^{n_i}s{i,t}\times (-s_{i,t}+\sum \limits_{j=1}^{n_i}s_{i,j})$

所有质数的贡献是$\sum \limits {u=1}^k \sum \limits {t=1}^{n_i}s{i,t}\times (-s{i,t}+\sum \limits_{j=1}^{n_i}s_{i,j})$

代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
#define MAXN 100005
int nprime[MAXN];
vector <int> prime;


int get_prime = []() {
memset(nprime, 0, sizeof(nprime));
nprime[1] = 1;
for (int i=2; i*i<MAXN; i++) {
if (!nprime[i]) {
for (int j=i*i; j<MAXN; j+=i) {
nprime[j] = 1;
}
}
}
for (int i=2; i<MAXN; i++) {
if (!nprime[i]) {
prime.push_back(i);
// cout << i << endl;
}
}
return 0;
}();
class Solution {
public:
using ll = long long;
long long countPaths(int n, vector<vector<int>>& edges) {
// cout << prime.size();
vector<vector<int>> g(n+1);
for (auto& i:edges) {
g[i[0]].push_back(i[1]);
g[i[1]].push_back(i[0]);
}
vector<int> vis(n+1, -1);
vector<int> cnt;
function<int(int,int)> dfs = [&](int x, int f) {
int rt = 1;
vis[x] = f;
for (int y:g[x]) {
if (vis[y] != -1 || nprime[y] == 0) continue;
// cout << y << " " << f;
rt += dfs(y, f);
}
return rt;
};
ll ans = 0;
for (int i=1; i<=n; i++) {
if (nprime[i] == 0 || vis[i] != -1) continue;
int sz = cnt.size();
cnt.push_back(dfs(i, sz));
}
// for (int i=1; i<=n; i++) {
// cout << i << " " << vis[i] << endl;
// if (nprime[i]) cout << cnt[vis[i]] << " " << endl;
// }
for (int i=1; i<=n; i++) {
if (nprime[i]) continue;
vector<ll> son;
ll sum = 0;
for (int x:g[i]) {
if (nprime[x]) {
sum += cnt[vis[x]];
son.push_back(cnt[vis[x]]);
}
}
// cout << sum << endl;
if (sum) {
for (ll x:son) {
ans += x*(sum - x + 1);
sum -= x;
}
}
}
return ans;
}
};