不包含相邻元素的子序列的最大和

题目

不包含相邻元素的子序列的最大和


给你一个整数数组 nums 和一个二维数组 queries,其中 queries[i] = [posi, xi]

对于每个查询 i,首先将 nums[posi] 设置为 xi,然后计算查询 i 的答案,该答案为 nums不包含相邻元素
的子序列的 最大 和。

返回所有查询的答案之和。

由于最终答案可能非常大,返回其对 109 + 7 取余 的结果。

子序列 是指从另一个数组中删除一些或不删除元素而不改变剩余元素顺序得到的数组。

示例 1:

输入: nums = [3,5,9], queries = [[1,-2],[0,-3]]

输出: 21

解释:
执行第 1 个查询后,nums = [3,-2,9],不包含相邻元素的子序列的最大和为 3 + 9 = 12
执行第 2 个查询后,nums = [-3,-2,9],不包含相邻元素的子序列的最大和为 9 。

示例 2:

输入: nums = [0,-1], queries = [[0,-5]]

输出: 0

解释:
执行第 1 个查询后,nums = [-5,-1],不包含相邻元素的子序列的最大和为 0(选择空子序列)。

提示:

  • 1 <= nums.length <= 5 * 104
  • -105 <= nums[i] <= 105
  • 1 <= queries.length <= 5 * 104
  • queries[i] == [posi, xi]
  • 0 <= posi <= nums.length - 1
  • -105 <= xi <= 105

题解

方法一:

思路

使用线段树,维护区间的不包含相邻元素的子序列的最大和,记为val。

当两个子区间的val能够高效地合并成为当前区间的val,那么查询一次的任意区间也会很高效。
因为任意区间最多只覆盖了线段树$log(n)$个节点,只涉及线段树$log(n)$次区间合并。

由于区间维护的值是不包含相邻元素的子序列的最大和,在两个子区间的值合并时需要考虑是否存在相邻元素,而这个出现相邻元素的位置只有在左区间的右端点和右区间的左端点。

我们可以每个区间维护 不选左端点和不选右端点的val 不选左端点和选取右端点的val 选取左端点和不选右端点的val 选取左端点和选取右端点的val
区间合并时,有16种情况,去除左区间的右端点和右区间的左端点都选取的情况。

对于最小区间[l,r] l=r,区间的左右端点要么都存在,要么都不存在。所以有两种情况时不可能的,我们要维护最大值,所以可以设置为负无穷。

单点修改时,只涉及线段树$log(n)$个节点,有$log(n)$次合并更新。
而查询只需要查询整个区间的最值,所以直接查询线段树的根节点即可。

合并花费时间视为常数,总时间复杂度$O(nlogn)$。

代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
#define ls (id<<1)
#define rs (id<<1|1)
using ll = long long;
const ll NINF = 0xf3f3f3f3f3f3f3f3;
const ll mod = 1e9+7;
template<class t,class u> ostream& operator<<(ostream& os,const pair<t,u>& p) {
return os<<'['<<p.first<<", "<<p.second<<']';
}
template<class t> ostream& operator<<(ostream& os,const vector<t>& v) {
os<<'['; int s = 1;
for(auto e:v) { if (s) s = 0; else os << ", "; os << e; }
return os<<']';
}
template<class t,class u> ostream& operator<<(ostream& os,const map<t,u>& mp){
os<<'{'; int s = 1;
for(auto [x,y]:mp) { if (s) s = 0; else os << ", "; os<<x<<": "<<y; }
return os<<'}';
}

class Solution {
public:
int maximumSumSubsequence(vector<int>& nums, vector<vector<int>>& queries) {
int n = nums.size();
int sz = 4*n+5;
vector val(sz, vector<ll>(4, NINF));
auto push_up = [&](int id) {
val[id] = {NINF, NINF, NINF, NINF};
for (int i=0; i<4; i++) {
for (int j=0; j<4; j++) {
if (i%2 && j/2) continue;
int c = i/2*2 + j%2;
val[id][c] = max(val[id][c], val[ls][i] + val[rs][j]);
}
}
// cout << id << " " << val[id] << endl;
};
function<void(int,int,int)> build = [&](int id, int l, int r) {
if (l == r) {
val[id] = {0, NINF, NINF, nums[l]};
return ;
}
int m = l+r>>1;
build(ls, l, m);
build(rs, m+1, r);
push_up(id);
};
build(1, 0, n-1);
function<void(int,int,int,int,int)> add = [&](int id, int l, int r, int q, int v) {
if (l == r) {
val[id] = {0, NINF, NINF, v};
return ;
}
int m = l+r>>1;
if (q<=m)
add(ls, l, m, q, v);
else
add(rs, m+1, r, q, v);
push_up(id);
};
ll ans = 0;
for (auto& i:queries) {
add(1, 0, n-1, i[0], i[1]);
ll mx = *max_element(val[1].begin(), val[1].end());
ans += mx;
ans %= mod;
}
return ans;
}
};