Array Partition

Array Partition

Created by LXC on Thu Jan 4 09:26:25 2024

https://codeforces.com/problemset/problem/1454/F

ranting: 2100

tag: binary search, data structures, greedy, two pointers

problem

给定序列 $a$,要求将其划分为三个非空子串(设三个子串长分别为 $x,y,z$),使得 $\max\limits_{i=1}^x a_i = \min\limits_{i=x+1}^{x+y} a_i = \max\limits_{i=x+y+1}^n a_i$。

若存在方案,输出 $\texttt{YES}$ 和任意一组 $x,y,z$ 的值;若不存在,输出 $\texttt{NO}$。

$3 \le n \le 2 \cdot 10^5$,$1 \le a_i \le 10^9$

solution

要求将数组分割为三个非空的子数组,中间子数组中的最小值与两侧子数组中的最大值都相等。

我们枚举后缀子数组[i...n-1],求得子数组的最大值为smx,寻找最大值等于smx的前缀区间[0...x],我们可以在[0...i-1]上通过二分得到x的取值范围在[l1,r1)
然后在区间[0...i-1]上查找最小值为smx的后缀[y...i-1],可以用线段树或st表二分得到y的取值范围在[l2,r2)
显然l1<=l2,反证法,若l1>l2,[x...i-1], l2<=x<=l1的最小值为smx,说明[l2...l1]区间内值均不小于smx,所以[0...x], l2<=x<=l1最大值均会不小于smx,而事实上二分得到的l1是满足区间[0...x]最大值为smx这个条件中的x的最小值,与l1>l2冲突。

我们考虑l2作为分割的第二个子数组的首元素位置,
那么需要保证l2在[l2,r2)中,l2-1在[l1,r1)中。
否则考虑l2+1作为分割的第二个子数组的首元素位置,
那么需要保证l2+1在[l2,r2)中,l2在[l1,r1)中。

code

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113

#include <bits/stdc++.h>
// #define SINGLE_INPUT
#define ll long long
#define ull unsigned long long
#define N 500005
#define MOD 998244353
#define INF 0x3f3f3f3f
#define ls (id << 1)
#define rs (id << 1 | 1)
using namespace std;

#define MAXN 200005
int st[MAXN][30]; // st[i][j] 代表区间[i, i+2^j)最小值

void ST(const vector<int>& a) {
int n = a.size();
for (int i = 0; i < n; i++)
st[i][0] = a[i];
for (int j = 1; (1 << j) <= n; j++) { // 区间大小
for (int i = 0; i + (1 << j) - 1 < n; i++) { // 区间下限
st[i][j] = min(st[i][j - 1], st[i + (1 << (j - 1))][j - 1]);
}
}
}

int ask(int l, int r) {
int k = 0;
while ((1 << (k + 1)) <= r - l + 1)
k++;
return min(st[l][k], st[r - (1 << k) + 1][k]);
}

void sol() {
int n;
cin >> n;
vector<int> a(n);
for (auto& i : a)
cin >> i;
ST(a);
// [0, p] 中找最小x使得 [x, p] 中最小值为val
auto lb = [&](int p, int val) {
int l = 0, r = p + 1;
while (l < r) {
int m = l + r >> 1;
if (ask(m, p) < val)
l = m + 1;
else
r = m;
}
return r;
};
// [0, p] 中找最大x使得 [x-1, p] 中最小值为val
auto ub = [&](int p, int val) {
int l = 0, r = p + 1;
while (l < r) {
int m = l + r >> 1;
if (ask(m, p) <= val)
l = m + 1;
else
r = m;
}
return r;
};
vector<int> pmx(n + 1);
pmx[0] = a[0];
for (int i = 1; i < n; i++) {
pmx[i] = max(pmx[i - 1], a[i]);
}
int smx = 0;
for (int i = n - 1; i >= 2; i--) {
smx = max(smx, a[i]);
int l1 = lower_bound(pmx.begin(), pmx.begin() + i, smx) - pmx.begin();
int r1 = upper_bound(pmx.begin(), pmx.begin() + i, smx) - pmx.begin();
// smx = [l1, r1)
int l2 = lb(i - 1, smx);
int r2 = ub(i - 1, smx);
// cout << i << ": " << endl;
// cout << l1 << " " << r1 << endl;
// cout << l2 << " " << r2 << endl;
// smx = [l2, r2)
int sp = -1;
if (l2 < r2 && l1 <= l2 - 1 && l2 - 1 < r1) {
sp = l2;
}
if (l2 + 1 < r2 && l1 <= l2 && l2 < r1) {
sp = l2 + 1;
}
if (sp != -1) {
cout << "YES\n";
cout << sp << " " << i - sp << " " << n - i << "\n";
return;
}
}
cout << "NO\n";
}

int main() {
cout << setprecision(15) << fixed;
ios::sync_with_stdio(false);
cin.tie(0);
cout.tie(0);
#ifndef SINGLE_INPUT
int t;
cin >> t;
while (t--) {
sol();
}
#else
sol();
#endif
return 0;
}