PALAND - Editorial

PROBLEM LINK:

Contest Division 1
Contest Division 2
Contest Division 3
Contest Division 4

Setter: Omkar Tripathi
Tester: Harris Leung
Editorialist: Trung Dang

DIFFICULTY:

Medium

PREREQUISITES:

Dynamic Programming

PROBLEM:

You are given an array A consisting of N integers. Calculate the number of ways to divide this array into subsegments such that the sequence formed by taking the bitwise AND of each segment of the partition is a palindrome.

More formally, consider a partition of the array into segments [L_1, R_1], [L_2, R_2], [L_3, R_3], \ldots, [L_k, R_k] such that L_1 = 1, L_2 = R_1 + 1, L_3 = R_2 + 1, \ldots, L_k = R_{k - 1} + 1, R_k = N. Let B_i be the bitwise AND of the i-th segment: B_i = A_{L_i} \land \ldots \land A_{R_i}, where x_1 \land x_2 \land \ldots \land x_t denotes the bitwise AND of the numbers x_1, \ldots, x_t.

A partition is palindromic if B_i=B_{k + 1 - i} for every 1 \le i \le k. Your task is to calculate the number of palindromic partitions. Since this number can be large, calculate it modulo 998244353.
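For small inputs, the definition can be checked directly by enumerating all 2^{N-1} ways to place cuts. Below is a minimal C++ sketch intended only for cross-checking faster solutions. The function name bruteForcePalindromicPartitions is hypothetical and the array is 0-indexed. The count is returned exactly, without the modulus, so the sketch assumes N is small (at most 20 here).

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Brute-force count of palindromic partitions (hypothetical helper, for small N only).
// Bit k of `mask` (0 <= k < N - 1) set means "cut between a[k] and a[k + 1]".
// The count is exact (no modulus), so N is capped at 20 here.
long long bruteForcePalindromicPartitions(const std::vector<int>& a) {
    const int n = static_cast<int>(a.size());
    assert(n >= 1 && n <= 20);
    long long count = 0;
    for (std::uint32_t mask = 0; mask < (1u << (n - 1)); ++mask) {
        std::vector<int> b;            // B_1, ..., B_k for this partition
        int cur = a[0];                // AND of the segment being built
        for (int k = 0; k + 1 < n; ++k) {
            if ((mask >> k) & 1u) {    // close the current segment after a[k]
                b.push_back(cur);
                cur = a[k + 1];
            } else {
                cur &= a[k + 1];
            }
        }
        b.push_back(cur);
        bool palindrome = true;        // is B equal to its reverse?
        for (std::size_t i = 0, j = b.size() - 1; i < j; ++i, --j) {
            if (b[i] != b[j]) {
                palindrome = false;
                break;
            }
        }
        if (palindrome) ++count;
    }
    return count;
}
```

For example, A = [1, 2, 1] has exactly two palindromic partitions: the whole array (B = [0]) and [1] [2] [1] (B = [1, 2, 1]).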

EXPLANATION:

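We have a straightforward dynamic programming solution. Let dp_{l, r} be the number of palindromic partitions of the subarray A_{[l, r]}, with dp_{l, l - 1} = 1 for an empty subarray. Either [l, r] is a single segment, or we choose a first segment [l, i] and a last segment [j, r] with equal ANDs and partition the middle part [i + 1, j - 1] recursively. This gives: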

dp_{l, r} = 1 + \sum_{i = l}^{r - 1} \sum_{j = i + 1}^{r} dp_{i + 1, j - 1} \cdot [f(l, i) = f(j, r)]

where f(l, r) = A_l \land A_{l + 1} \land \dots \land A_r, and [P] equals 1 if the condition P holds and 0 otherwise.
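A direct translation of this recurrence into C++ might look like the sketch below. It runs in O(N^4) and is meant only as a reference for testing. The function name naivePalindromicPartitions is hypothetical. Arrays are 1-indexed, and dp[l][l - 1] = 1 encodes the empty subarray.

```cpp
#include <cassert>
#include <vector>

// Direct O(N^4) evaluation of the recurrence (hypothetical reference helper).
// 1-indexed; dp[l][l - 1] = 1 represents the empty subarray.
long long naivePalindromicPartitions(const std::vector<int>& input) {
    const long long MOD = 998244353;
    const int n = static_cast<int>(input.size());
    assert(n >= 1);
    std::vector<int> a(n + 1);
    for (int i = 1; i <= n; ++i) a[i] = input[i - 1];

    // f[l][r] = A_l & ... & A_r for 1 <= l <= r <= n.
    std::vector<std::vector<int>> f(n + 2, std::vector<int>(n + 1, 0));
    for (int l = 1; l <= n; ++l) {
        f[l][l] = a[l];
        for (int r = l + 1; r <= n; ++r) f[l][r] = f[l][r - 1] & a[r];
    }

    std::vector<std::vector<long long>> dp(n + 2, std::vector<long long>(n + 1, 0));
    for (int l = n; l >= 1; --l) {               // larger l first, so dp[i + 1][*] is ready
        dp[l][l - 1] = 1;                        // empty subarray
        for (int r = l; r <= n; ++r) {
            long long ways = 1;                  // [l, r] as a single segment
            for (int i = l; i < r; ++i)          // first segment [l, i]
                for (int j = i + 1; j <= r; ++j) // last segment [j, r]
                    if (f[l][i] == f[j][r])
                        ways = (ways + dp[i + 1][j - 1]) % MOD;
            dp[l][r] = ways;
        }
    }
    return dp[1][n];
}
```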

However, this solution is far too slow: there are O(N^2) states, each computed in O(N^2) time, for O(N^4) in total. We have a few observations to speed up this algorithm:

  • Notice that if we fix l, f(l, i) takes only O(\log(\max(A))) distinct values. To see this, view f(l, i) as a “prefix AND” starting from l: f(l, i) = f(l, i - 1) \land A_i, so f(l, i) is always a submask of f(l, i - 1). Hence the number of set bits of f(l, i) is non-increasing in i, and it strictly decreases every time the value changes. Therefore f(l, i) takes at most one more distinct value than the number of set bits of A_l, which is O(\log(\max(A))).
  • f(l, i) is also non-increasing in i as an integer, because a submask of x is never larger than x.
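The run compression from the first observation can be sketched in C++ as follows. The sketch is 0-indexed, and prefixAndRuns is a hypothetical helper name; the preparer's solution keeps the same information in its pref and pref_val arrays. A new run starts only when at least one bit is cleared, and that is what bounds the number of runs.

```cpp
#include <cassert>
#include <utility>
#include <vector>

// For a fixed 0-indexed start l, compress i -> f(l, i) = a[l] & ... & a[i]
// into runs of equal value. Each entry is {value, first index of the run}.
// Each value is a proper submask of the previous one, so there are at most
// popcount(a[l]) + 1 entries.
std::vector<std::pair<int, int>> prefixAndRuns(const std::vector<int>& a, int l) {
    assert(0 <= l && l < static_cast<int>(a.size()));
    std::vector<std::pair<int, int>> runs;
    int cur = a[l];
    runs.push_back({cur, l});
    for (int i = l + 1; i < static_cast<int>(a.size()); ++i) {
        int next = cur & a[i];
        if (next != cur) {             // at least one bit was cleared
            runs.push_back({next, i});
            cur = next;
        }
    }
    return runs;
}
```

For A = [7, 7, 3, 3, 1, 5, 0] and l = 0, the runs are (7, 0), (3, 2), (1, 4), (0, 6). That is four runs, one more than the three set bits of A_0 = 7.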

Therefore, for a fixed l, we can split the indices i into O(\log(\max(A))) contiguous ranges such that f(l, i) is constant within each range. Similarly, for a fixed r, we can split the indices j into O(\log(\max(A))) ranges on which f(j, r) is constant. Since f(l, i) is non-increasing as i grows and f(j, r) is non-increasing as j shrinks, both lists of range values are sorted in decreasing order. A two-pointer sweep therefore finds every pair of ranges (one for i, one for j) with the same AND value in O(\log(\max(A))) time.
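Because both value lists are strictly decreasing, matching equal values is a merge-like sweep. Below is a minimal C++ sketch. It assumes the run values have already been extracted into two vectors. The name matchEqualRuns is hypothetical; the preparer's solution performs the same sweep inline with ptr1 and ptr2.

```cpp
#include <algorithm>
#include <cassert>
#include <utility>
#include <vector>

// Two-pointer sweep over two decreasing value lists: `left` holds the run values
// of f(l, i) for increasing i, `right` those of f(j, r) for decreasing j.
// Returns every index pair (p, q) with left[p] == right[q] in O(|left| + |right|).
std::vector<std::pair<int, int>> matchEqualRuns(const std::vector<int>& left,
                                                const std::vector<int>& right) {
    assert(std::is_sorted(left.rbegin(), left.rend()));    // precondition: decreasing
    assert(std::is_sorted(right.rbegin(), right.rend()));
    std::vector<std::pair<int, int>> matches;
    std::size_t p = 0, q = 0;
    while (p < left.size() && q < right.size()) {
        if (left[p] == right[q]) {
            matches.push_back({static_cast<int>(p), static_cast<int>(q)});
            ++p;
            ++q;
        } else if (left[p] > right[q]) {
            ++p;   // left value too large: no later right value can match it
        } else {
            ++q;   // right value too large: no later left value can match it
        }
    }
    return matches;
}
```

For left = [7, 3, 1, 0] and right = [5, 3, 0], it reports the index pairs (1, 1) and (3, 2).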

This alone doesn’t solve the problem, since we still need to iterate i and j over each matched pair of ranges. More specifically, if f(l, i) = f(j, r) holds exactly for i \in [a, b) and j \in (c, d], we still need the sum of dp_{i + 1, j - 1} over all such i < j. Substituting l' = i + 1 and r' = j - 1, this is the sum of dp_{l', r'} over a + 1 \le l' \le b, c \le r' \le d - 1 and l' \le r' + 1, which is a 2D range sum. Luckily, we can easily optimize this operation. Let s_{x, y} be the sum of dp_{l', r'} over all x \le l' \le r' + 1 \le y + 1, that is, over all subarrays of A_{[x, y]}, including the empty ones with dp_{l', l' - 1} = 1 (and s_{x, y} = 0 when y < x - 1). We can compute s_{x, y} right after dp_{x, y} as s_{x, y} = dp_{x, y} + s_{x, y - 1} + s_{x + 1, y} - s_{x + 1, y - 1}. The required sum is then s_{a + 1, d - 1} - s_{a + 1, c - 1} - s_{b + 1, d - 1} + s_{b + 1, c - 1}.
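The table s and the rectangle query can be sketched and checked in isolation. This C++ sketch assumes a 1-indexed dp table of size (N + 2) by (N + 1) that follows the empty-subarray convention dp_{x, x - 1} = 1. It omits the modulus so that rectangle() can be compared against a direct double loop. The name DpRangeSum and its methods are hypothetical, and prefix(x, y) plays the role of s_{x, y}.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Hypothetical sketch (1-indexed) of the table s: prefix(x, y) is the sum of
// dp[l][r] over l >= x, r <= y, l <= r + 1, i.e. over all (possibly empty)
// subarrays of A[x..y]. `table` must be (n + 2) x (n + 1); no modulus here.
struct DpRangeSum {
    int n;
    std::vector<std::vector<long long>> dp, s;

    DpRangeSum(const std::vector<std::vector<long long>>& table, int n_)
        : n(n_), dp(table), s(n_ + 2, std::vector<long long>(n_ + 1, 0)) {
        assert(static_cast<int>(dp.size()) == n + 2);
        for (int x = n + 1; x >= 1; --x)          // row x + 1 is finished first
            for (int y = x - 1; y <= n; ++y)      // s = dp + s(x, y-1) + s(x+1, y) - s(x+1, y-1)
                s[x][y] = dp[x][y] + prefix(x, y - 1) + prefix(x + 1, y) - prefix(x + 1, y - 1);
    }

    // Clamped to valid indices; zero when no (l, r) pair qualifies.
    long long prefix(int x, int y) const {
        x = std::max(x, 1);
        y = std::min(y, n);
        if (x > n + 1 || y < 0 || y < x - 1) return 0;
        return s[x][y];
    }

    // Sum of dp[l][r] for lo1 <= l <= hi1, lo2 <= r <= hi2, l <= r + 1.
    // With lo1 = a + 1, hi1 = b, lo2 = c, hi2 = d - 1 this is
    // s_{a+1,d-1} - s_{b+1,d-1} - s_{a+1,c-1} + s_{b+1,c-1}.
    long long rectangle(int lo1, int hi1, int lo2, int hi2) const {
        return prefix(lo1, hi2) - prefix(hi1 + 1, hi2) - prefix(lo1, lo2 - 1) +
               prefix(hi1 + 1, lo2 - 1);
    }

    // Direct summation of the same quantity, for cross-checking.
    long long directRectangle(int lo1, int hi1, int lo2, int hi2) const {
        long long total = 0;
        for (int l = lo1; l <= hi1; ++l)
            for (int r = lo2; r <= hi2; ++r)
                if (l <= r + 1 && l >= 1 && l <= n + 1 && r >= 0 && r <= n) total += dp[l][r];
        return total;
    }
};
```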

TIME COMPLEXITY:

Time complexity is O(N^2\log(\max(A))) per test case.

SOLUTION:

Preparer's Solution
#ifdef DEBUG
#define _GLIBCXX_DEBUG
#endif
//#pragma GCC optimize("O3")
#include <bits/stdc++.h>
using namespace std;
typedef long double ld;
typedef long long ll;
const int mod = 998244353;
int sum(int a, int b) {
    int s = a + b;
    if (s >= mod) s -= mod;
    return s;
}
int sub(int a, int b) {
    int s = a - b;
    if (s < 0) s += mod;
    return s;
}
int mult(int a, int b) {
    return (1LL * a * b) % mod;
}
int n;
const int maxN = 3005;
int A[maxN];
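int dp[maxN][maxN];  // dp[l][r]: number of palindromic partitions of A[l..r]; dp[l][l - 1] = 1
int S[maxN][maxN];   // S[x][y]: sum of dp[l][r] over x <= l <= r + 1 <= y + 1 (empty subarrays included)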
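// Sum of dp[x][y] over l1 <= x <= l2, r1 <= y <= r2 (with x <= y + 1), by inclusion-exclusion on S.
int get(int l1, int l2, int r1, int r2) {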
    if (l1 > l2 || r1 > r2) {
        return 0;
    }
    assert(l1 <= l2 && r1 <= r2);
    int ans = S[l1][r2];
    ans = sub(ans, S[l2 + 1][r2]);
    ans = sub(ans, S[l1][r1 - 1]);
    ans = sum(ans, S[l2 + 1][r1 - 1]);
    return ans;
}
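// pref[i][k], pref_val[i][k]: first index j and value of the k-th run of A[i] & ... & A[j] (j >= i).
// suf[i][k], suf_val[i][k]: first index j and value of the k-th run of A[j] & ... & A[i] (j <= i).
const int SZ = 35;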
int pref[maxN][SZ];
int pref_val[maxN][SZ];
int size_pref[maxN];
int suf[maxN][SZ];
int size_suf[maxN];
int suf_val[maxN][SZ];
int main() {
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);
//    freopen("input.txt", "r", stdin);
    int tst;
    cin >> tst;
    while (tst--) {
        cin >> n;
        for (int i = 1; i <= n; i++) {
            cin >> A[i];
        }
        for (int i = 1; i <= n; i++) {
            size_pref[i] = 0;
            int cur_and = A[i];
            for (int j = i; j <= n; j++) {
                if (j == i || (A[j] & cur_and) != cur_and) {
                    pref[i][size_pref[i]] = j;
                    cur_and &= A[j];
                    pref_val[i][size_pref[i]++] = cur_and;
                }
            }
        }
        for (int i = n; i >= 1; i--) {
            size_suf[i] = 0;
            int cur_and = A[i];
            for (int j = i; j >= 1; j--) {
                if (j == i || (A[j] & cur_and) != cur_and) {
                    suf[i][size_suf[i]] = j;
                    cur_and &= A[j];
                    suf_val[i][size_suf[i]++] = cur_and;
                }
            }
        }
        for (int i = 1; i <= n; i++) {
            for (int j = i - 1; j <= n; j++) {
                dp[i][j] = S[i][j] = 0;
            }
        }
        // Process l from right to left so that every dp[l'][r'] with l' > l is already known.
        for (int l = n; l >= 1; l--) {
            dp[l][l - 1] = 1;
            S[l][l - 1] = 1;
            for (int r = l; r <= n; r++) {
                S[l][r] = sum(S[l + 1][r], S[l][r - 1]);
                S[l][r] = sub(S[l][r], S[l + 1][r - 1]);
                int ptr1 = 0;
                int ptr2 = 0;
                // Two-pointer match of equal values between the runs of f(l, .) and f(., r).
                while (ptr1 < size_pref[l] && ptr2 < size_suf[r]) {
                    if (pref_val[l][ptr1] == suf_val[r][ptr2]) {
                        dp[l][r] = sum(dp[l][r],
                                       get(pref[l][ptr1] + 1, (ptr1 == size_pref[l] - 1) ? n : pref[l][ptr1 + 1],
                                           (ptr2 == size_suf[r] - 1) ? 1 : suf[r][ptr2 + 1], suf[r][ptr2] - 1));
                        ptr1++;
                        ptr2++;
                    } else if (pref_val[l][ptr1] > suf_val[r][ptr2]) {
                        ptr1++;
                    } else {
                        ptr2++;
                    }
                }
                dp[l][r] = sum(dp[l][r], 1);
                S[l][r] = sum(S[l][r], dp[l][r]);
            }
        }
        cout << dp[1][n] << '\n';
    }
    return 0;
}
Tester's Solution
Same as the preparer's solution above.
Editorialist's Solution
#include <bits/stdc++.h>
using namespace std;

const int MOD = 998244353;

int main() {
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);
    int t; cin >> t;
    while (t--) {
        int n; cin >> n;
        vector<int> a(n);
        for (int i = 0; i < n; i++) {
            cin >> a[i];
        }
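        // forward[i]: runs of a[i] & ... & a[j] (j >= i) as {value, first index after the run};
        // backward[i]: runs of a[j] & ... & a[i] (j <= i) as {value, index just before the run}.
        vector<vector<pair<int, int>>> forward(n), backward(n);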
        for (int i = 0; i < n; i++) {
            int f = a[i];
            for (int j = i + 1; j < n; j++) {
                if ((f & a[j]) < f) {
                    forward[i].push_back({f, j});
                }
                f &= a[j];
            }
            forward[i].push_back({f, n});
            int b = a[i];
            for (int j = i - 1; j >= 0; j--) {
                if ((b & a[j]) < b) {
                    backward[i].push_back({b, j});
                }
                b &= a[j];
            }
            backward[i].push_back({b, -1});
        }
        // dp[l][r]: answer for a[l..r]; sum[l][r]: sum of dp over all non-empty subarrays of a[l..r].
        vector<vector<long long>> dp(n, vector<long long>(n)), sum(n, vector<long long>(n));
        for (int l = n - 1; l >= 0; l--) {
            for (int r = l; r < n; r++) {
                dp[l][r] = 1;
                if (l == r) {
                    sum[l][r] = 1;
                } else {
                    int cur_left = l, cur_right = r, lp = 0, rp = 0;
                    while (cur_left < cur_right && lp < forward[l].size() && rp < backward[r].size()) {
                        if (forward[l][lp].first < backward[r][rp].first) {
                            cur_right = backward[r][rp++].second;
                        } else if (forward[l][lp].first > backward[r][rp].first) {
                            cur_left = forward[l][lp++].second;
                        } else {
                            auto [_, nxt_left] = forward[l][lp++];
                            auto [__, nxt_right] = backward[r][rp++];
                            if (nxt_left <= nxt_right) {
                                // sum of dp[i][j] where cur_left < i <= nxt_left <= nxt_right <= j < cur_right
                                (dp[l][r] += sum[cur_left + 1][cur_right - 1] - sum[cur_left + 1][nxt_right - 1] - sum[nxt_left + 1][cur_right - 1] + sum[nxt_left + 1][nxt_right - 1]) %= MOD;
                            } else {
                                (dp[l][r] += sum[cur_left + 1][cur_right - 1]) %= MOD;
                                (dp[l][r] += cur_right - cur_left) %= MOD;
                            }
                            cur_left = nxt_left; cur_right = nxt_right;
                        }
                    }
                    sum[l][r] = (sum[l + 1][r] + sum[l][r - 1] - sum[l + 1][r - 1] + dp[l][r]) % MOD;
                }
            }
        }
        cout << (dp[0][n - 1] + MOD) % MOD << '\n';
    }
}