PROBLEM LINK:
Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4
Author: raysh07
Tester: sushil2006
Editorialist: iceknight1093
DIFFICULTY:
Medium
PREREQUISITES:
Familiarity with bitwise OR
PROBLEM:
For an array A, we define f(A) as the maximum number of subarrays A can be partitioned into, such that the sum of the bitwise ORs of the subarrays is minimized.
You’re given an array A. Compute the sum of f(S) across all subarrays of A.
EXPLANATION:
Let’s understand how to compute f(A) for a single array A first.
The minimum possible sum of bitwise ORs of subarrays is easy to find: it just equals the bitwise OR of the entirety of A, and is achieved by taking A itself as a single subarray.
Next, we need to find the largest possible partition into subarrays that achieves this minimum.
Observe that the minimum counts each bit appearing somewhere in A exactly once. So, any valid partition must also be such that each bit appearing in A appears in exactly one subarray of the partition; otherwise, this bit will contribute more than once to the sum, breaking minimality.
Essentially, we need all the subarrays in the partition to be bitwise disjoint, while maximizing their count.
Let’s now focus on the first chosen subarray, which will be some prefix of A.
For convenience, let’s define P_i = A_1 \mid A_2 \mid \ldots \mid A_i and S_i = A_i \mid \ldots\mid A_N to be the prefix and suffix ORs of A, respectively.
Now, observe that we can only choose the prefix ending at index i, if P_i \cap S_{i+1} = \emptyset, i.e. elements till i and elements after i don’t share any bits.
There can be multiple such indices, so for now let’s just choose any one of them, say i_1.
Next, consider the second subarray chosen - suppose it’s A[i_1+1, i_2].
Observe that i_2 must also satisfy P_{i_2} \cap S_{i_2 + 1} = \emptyset.
More generally, if we choose to split the subarrays at indices i_1, i_2, i_3, \ldots then each of these indices must satisfy P_{i_j} \cap S_{i_j + 1} = \emptyset.
We thus have a set of candidates for the endpoints of the subarrays: all indices for which P_i \cap S_{i+1} = \emptyset.
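We can in fact split at all such indices at once: each resulting piece lies inside a prefix ending at a good split point, while every later piece lies inside the matching suffix, so the pieces are pairwise bitwise disjoint. This immediately tells us what the maximum number of subarrays obtainable is.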
As a result of the above discussion, we quite simply obtain f(A) to be equal to the number of indices of A such that P_i \cap S_{i+1} = \emptyset.
Let’s call such indices good indices.
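The characterization above can be sketched directly in code. This is a minimal illustration rather than part of either reference solution; the function name countGoodIndices and the 0-based indexing are assumptions made for the example.

```cpp
#include <vector>

// Returns f(a): the number of indices i (0-based) such that the OR of
// a[0..i] shares no bit with the OR of a[i+1..n-1].
// The last index always qualifies, since its suffix is empty.
long long countGoodIndices(const std::vector<int>& a) {
    const int n = static_cast<int>(a.size());
    if (n == 0) return 0;

    // suf[i] = a[i] | a[i+1] | ... | a[n-1], with suf[n] = 0.
    std::vector<int> suf(n + 1, 0);
    for (int i = n - 1; i >= 0; --i) suf[i] = suf[i + 1] | a[i];

    long long good = 0;
    int pre = 0;  // running prefix OR a[0] | ... | a[i]
    for (int i = 0; i < n; ++i) {
        pre |= a[i];
        if ((pre & suf[i + 1]) == 0) ++good;
    }
    return good;
}
```

For instance, on [1, 0, 2] every index is good, matching the partition [1], [0], [2] whose ORs sum to 3.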
We now use this characterization to solve for summing across subarrays.
Let’s fix a left endpoint L, and consider some index i \ge L.
We try to count the number of R \ge i for which i is a good index.
This is simple: if X = A_L \mid A_{L+1} \mid\ldots \mid A_i denotes the bitwise OR from L to i, then i will remain a good index till the first time an element is not disjoint from X.
That is, let R_0 \gt i be the smallest index such that A_{R_0} \cap X \neq \emptyset, then i will be a good index for all R \in [i, R_0-1].
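As a sanity check, the counting rule above translates into a slow reference implementation that is handy for testing a faster solution on small inputs. This is only a sketch: the name sumOverSubarraysBrute is hypothetical, indices are 0-based, and index n plays the role of R_0 when no conflicting element exists.

```cpp
#include <vector>

// Reference implementation (cubic in the worst case): for each pair (L, i),
// scan for the first R0 > i whose element shares a bit with
// X = a[L] | ... | a[i]. Index i is then good for every R in [i, R0 - 1],
// so it contributes R0 - i to the answer.
long long sumOverSubarraysBrute(const std::vector<int>& a) {
    const int n = static_cast<int>(a.size());
    long long total = 0;
    for (int L = 0; L < n; ++L) {
        int X = 0;
        for (int i = L; i < n; ++i) {
            X |= a[i];
            int R0 = i + 1;
            while (R0 < n && (a[R0] & X) == 0) ++R0;
            total += R0 - i;
        }
    }
    return total;
}
```

On [1, 2, 4] the six subarrays have f equal to 1, 1, 1, 2, 2, 3, for a total of 10.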
Now, it’s not too hard to find R_0 when both L and i are fixed, but trying all \mathcal{O}(N^2) pairs (L, i) is too expensive, so we need something faster.
To that end, observe that what really matters here is X - the prefix OR starting at L.
So, suppose we fix the value of X.
The set of indices i for which X is the prefix OR will form a contiguous segment - let this be [l, r].
Observe that for any index i \in [l, r]:
- If there exists j \in [i+1, r] such that A_j \gt 0, then A_j surely shares bits with X (if it didn’t, it would change the prefix OR). So, in this case, i is good only until the next non-zero element is reached, and then stops being good.
- If there’s no j \in [i+1, r] such that A_j \gt 0, then we’ll need to find the next element after index r that shares bits with X.
In particular, if k is the largest index in [l, r] such that A_k \gt 0, then observe that the counts for all the indices l, l+1, \ldots, k-1 are easily found (for each of them, it’s the distance to the next non-zero element), and for indices k, k+1, \ldots, r, all of them will have the same endpoint of being good (namely, just before the next element that shares bits with X).
So, suppose we define d_i to be the distance from i to the next non-zero element after i.
Then, to solve for the interval [l, r], with prefix OR X:
- Let k be the largest index in [l, r] containing a non-zero element.
- Let R_0 \gt r be the next index that shares bits with X.
- The overall contribution to the answer is then the sum of:
  - d_l + d_{l+1} + \ldots + d_{k-1} for indices \lt k.
  - (R_0 - k) + (R_0 - k - 1) + (R_0 - k - 2) + \ldots + (R_0 - r) for indices \ge k.
The first term is a range sum of d, and can be found in constant time using prefix sums.
The second term is the sum of consecutive integers and so can be computed in constant time, as long as the values R_0, k, r are known.
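The two terms can be sketched as follows. The helper names (nextNonZeroDistance, prefixSums, intervalContribution) are made up for this illustration, indices are 0-based, and index n is treated as a sentinel for "no later non-zero element". The sketch also assumes the caller passes k = l when [l, r] contains no non-zero element, so that the first term is empty.

```cpp
#include <cstddef>
#include <vector>

// d[i] = distance from i to the next non-zero element strictly after i
// (index n acts as a sentinel when there is none).
std::vector<long long> nextNonZeroDistance(const std::vector<int>& a) {
    const int n = static_cast<int>(a.size());
    std::vector<long long> d(n, 1);
    for (int i = n - 2; i >= 0; --i)
        d[i] = (a[i + 1] != 0) ? 1 : d[i + 1] + 1;
    return d;
}

// pre[i] = d[0] + ... + d[i-1], so the sum of d over [x, y) is pre[y] - pre[x].
std::vector<long long> prefixSums(const std::vector<long long>& d) {
    std::vector<long long> pre(d.size() + 1, 0);
    for (std::size_t i = 0; i < d.size(); ++i) pre[i + 1] = pre[i] + d[i];
    return pre;
}

// Contribution of an interval [l, r] sharing one prefix OR, where k is the
// last non-zero position in [l, r] and R0 > r is the first conflicting index.
long long intervalContribution(const std::vector<long long>& pre,
                               int l, int k, int r, int R0) {
    long long head = pre[k] - pre[l];  // d_l + ... + d_{k-1}
    long long cnt = r - k + 1;         // number of indices in [k, r]
    // (R0 - k) + ... + (R0 - r) = cnt * R0 - (k + ... + r)
    long long tail = cnt * R0 - (static_cast<long long>(k) + r) * cnt / 2;
    return head + tail;
}
```

For A = [3, 1, 0, 1] with L = 0, the single interval is [0, 3] with k = 3 and R_0 = 4, giving (1 + 2 + 1) + 1 = 5.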
Since k and r are fixed, only R_0 needs to be found - which is pretty easy, if you precompute \text{next}[i][j] to be the next element after index i that contains bit j.
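A possible way to build this table and query R_0 is sketched below. The names buildNext and firstConflict are illustrative, indices are 0-based, values are assumed to fit in 30 bits, and n is returned when no element qualifies.

```cpp
#include <algorithm>
#include <array>
#include <vector>

constexpr int BITS = 30;  // assumption: every element is below 2^30

// nxt[i][j] = smallest index p > i such that a[p] has bit j set, or n if none.
std::vector<std::array<int, BITS>> buildNext(const std::vector<int>& a) {
    const int n = static_cast<int>(a.size());
    std::vector<std::array<int, BITS>> nxt(n);
    for (int i = n - 1; i >= 0; --i) {
        for (int j = 0; j < BITS; ++j) {
            if (i + 1 == n)               nxt[i][j] = n;
            else if ((a[i + 1] >> j) & 1) nxt[i][j] = i + 1;
            else                          nxt[i][j] = nxt[i + 1][j];
        }
    }
    return nxt;
}

// R0 = first index after r whose element shares at least one bit with X.
int firstConflict(const std::vector<std::array<int, BITS>>& nxt, int r, int X) {
    int best = static_cast<int>(nxt.size());  // n means "no conflict"
    for (int j = 0; j < BITS; ++j)
        if ((X >> j) & 1) best = std::min(best, nxt[r][j]);
    return best;
}
```

Each query inspects at most 30 bits, which is where the logarithmic factor per interval comes from.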
So, for a fixed left endpoint L and prefix OR X with right endpoints spanning [l, r], we’re able to compute the contribution to the answer in \mathcal{O}(\log\max(A)) time (with the only slow part being the computation of R_0).
Now observe that there are at most 31 possible values of X (given that L is fixed), since each new value must add at least one new bit and we only have 30 bits.
Finding the values of X and the ranges [l, r] that the right endpoints span is not too hard, and can be done, for example, by looking at the values of \text{next}[L][j]: each distinct position of this form at which a bit not yet present appears gives a new prefix OR.
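One way to carry out this enumeration is sketched here. It is not taken from either reference solution (the editorialist's code instead maintains the list of distinct ORs incrementally while moving L leftwards). The Segment struct and the prefixOrSegments name are hypothetical, indices are 0-based, and nextL is assumed to be the row \text{next}[L] with n meaning "bit never appears again".

```cpp
#include <algorithm>
#include <array>
#include <vector>

struct Segment {
    int x;  // prefix OR a[L] | ... | a[i], shared by every i in [l, r]
    int l;
    int r;
};

// Splits the right endpoints [L, n-1] into maximal runs with equal prefix OR.
std::vector<Segment> prefixOrSegments(const std::vector<int>& a, int L,
                                      const std::array<int, 30>& nextL) {
    const int n = static_cast<int>(a.size());

    // A bit missing from a[L] first shows up at nextL[j]; the prefix OR
    // changes exactly at those positions.
    std::vector<int> cuts;
    for (int j = 0; j < 30; ++j) {
        bool alreadySet = (a[L] >> j) & 1;
        if (!alreadySet && nextL[j] < n) cuts.push_back(nextL[j]);
    }
    std::sort(cuts.begin(), cuts.end());
    cuts.erase(std::unique(cuts.begin(), cuts.end()), cuts.end());

    std::vector<Segment> segs;
    int x = a[L], start = L;
    for (int p : cuts) {
        segs.push_back({x, start, p - 1});
        x |= a[p];  // elements strictly between cuts add no new bits
        start = p;
    }
    segs.push_back({x, start, n - 1});
    return segs;
}
```

Since there are at most 30 cut positions, this yields at most 31 segments per left endpoint, in line with the bound above.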
So, we process \mathcal{O}(N\log\max(A)) intervals, each in \mathcal{O}(\log\max(A)) time, which is more than fast enough.
TIME COMPLEXITY:
\mathcal{O}(N\log^2 (\max(A))) per testcase.
CODE:
Editorialist's code (C++)
// #include <bits/allocator.h>
// #pragma GCC optimize("O3,unroll-loops")
// #pragma GCC target("avx2,bmi,bmi2,lzcnt,popcnt")
#include "bits/stdc++.h"
using namespace std;
using ll = long long int;
mt19937_64 RNG(chrono::high_resolution_clock::now().time_since_epoch().count());

int main()
{
    ios::sync_with_stdio(false); cin.tie(0);

    int t; cin >> t;
    while (t--) {
        int n; cin >> n;
        vector a(n, 0);
        for (int &x : a) cin >> x;

        // pnz[i] = largest index <= i holding a non-zero element (-1 if none)
        vector pnz(n, -1);
        for (int i = 0; i < n; ++i) {
            if (i) pnz[i] = pnz[i-1];
            if (a[i] > 0) pnz[i] = i;
        }

        // nxtd[i] = distance from i to the next non-zero element after i
        // (index n acts as a sentinel)
        vector nxtd(n, 1);
        for (int i = n-2; i >= 0; --i) {
            if (a[i+1] > 0) nxtd[i] = 1;
            else nxtd[i] = nxtd[i+1] + 1;
        }

        // pref = prefix sums of nxtd
        vector pref(n, 0ll);
        for (int i = 0; i < n; ++i) {
            pref[i] = nxtd[i];
            if (i) pref[i] += pref[i-1];
        }

        // jump[i][j] = smallest index > i whose element has bit j set (n if none)
        vector<array<int, 30>> jump(n);
        for (int i = n-1; i >= 0; --i) {
            for (int j = 0; j < 30; ++j) {
                jump[i][j] = n;
                if (i+1 < n) jump[i][j] = jump[i+1][j];
                if (i+1 < n and (a[i+1] & (1 << j))) jump[i][j] = i+1;
            }
        }

        // ors = distinct prefix ORs for the current left endpoint i,
        // stored as {value, last right endpoint with that value}
        vector<array<int, 2>> ors;
        ll ans = 0;
        for (int i = n-1; i >= 0; --i) {
            vector<array<int, 2>> cur = {{a[i], i}};
            for (auto [x, y] : ors) {
                int nx = x | a[i];
                if (nx == cur.back()[0]) cur.back()[1] = max(cur.back()[1], y);
                else cur.push_back({nx, y});
            }
            swap(ors, cur);

            int prv = i-1;
            for (auto [x, y] : ors) {
                // Interval (prv, y] has prefix OR x; l is its last non-zero position.
                int l = max(i, pnz[y]);
                int ct = y - l + 1;

                // Sum of nxtd over the interval, then correct indices l..y
                // so that they count up to the first conflicting index instead.
                ans += pref[y];
                if (prv >= 0) ans -= pref[prv];
                ans -= (nxtd[y]-1)*ct;

                int bad = n;
                for (int b = 0; b < 30; ++b) {
                    if ((x >> b) & 1) bad = min(bad, jump[y][b]);
                }
                ans += 1ll*ct*(bad - y - 1);
                prv = y;
            }
        }
        cout << ans << '\n';
    }
}
Author's code (C++)
#include <bits/stdc++.h>
using namespace std;
#define int long long
#define INF (int)1e18

mt19937_64 RNG(chrono::steady_clock::now().time_since_epoch().count());

void Solve()
{
    int n; cin >> n;
    vector <int> a(n + 1);
    for (int i = 1; i <= n; i++){
        cin >> a[i];
    }

    // p[i][j] = last index <= i with bit j set (0 if none)
    // s[i][j] = first index >= i with bit j set (n + 1 if none)
    vector<vector<int>> p(n + 1, vector<int>(30, 0));
    vector<vector<int>> s(n + 2, vector<int>(30, n + 1));

    for (int i = 1; i <= n; i++){
        for (int j = 0; j < 30; j++){
            if (a[i] >> j & 1){
                p[i][j] = i;
            } else {
                p[i][j] = p[i - 1][j];
            }
        }
    }

    for (int i = n; i >= 1; i--){
        for (int j = 0; j < 30; j++){
            if (a[i] >> j & 1){
                s[i][j] = i;
            } else {
                s[i][j] = s[i + 1][j];
            }
        }
    }

    int ans = 0;
    for (int i = 1; i < n; i++){
        vector <pair<int, int>> vec;
        for (int j = 0; j < 30; j++){
            vec.push_back({p[i][j], s[i + 1][j]});
        }

        sort(vec.begin(), vec.end(), greater<pair<int, int>>());

        int mn = n + 1;
        ans += (mn - i - 1) * (i - vec[0].first);
        for (int j = 0; j < 30; j++){
            int r = vec[j].first;
            if (j + 1 != 30) r -= vec[j + 1].first;
            mn = min(mn, vec[j].second);

            ans += (mn - i - 1) * r;
        }
    }

    ans += n * (n + 1) / 2;

    cout << ans << "\n";
}

int32_t main()
{
    auto begin = std::chrono::high_resolution_clock::now();
    ios_base::sync_with_stdio(0);
    cin.tie(0);
    int t = 1;
    // freopen("in", "r", stdin);
    // freopen("out", "w", stdout);

    cin >> t;
    for(int i = 1; i <= t; i++)
    {
        //cout << "Case #" << i << ": ";
        Solve();
    }
    auto end = std::chrono::high_resolution_clock::now();
    auto elapsed = std::chrono::duration_cast<std::chrono::nanoseconds>(end - begin);
    cerr << "Time measured: " << elapsed.count() * 1e-9 << " seconds.\n";
    return 0;
}