ANDSORT - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: sushil2006
Tester: sushil2006
Editorialist: iceknight1093

DIFFICULTY:

Easy - Medium

PREREQUISITES:

Bitmask DP

PROBLEM:

You’re given an array A of length N with elements from 0 to N.
Count the number of integers X with 0 \leq X \leq N such that the array B, where
B_i = A_i \land X,
is sorted in non-descending order.

\land denotes the bitwise AND operator.

EXPLANATION:

For the final array to be sorted in non-descending order, it’s necessary and sufficient that
(A_i\land X) \leq (A_{i+1} \land X) for each 1 \leq i \lt N.
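The adjacent-pair condition can be checked directly, which is useful for stress-testing a faster solution. Below is a minimal brute-force sketch. The helper name `countValidBrute` is hypothetical, and the sketch assumes N is the array length, as in the statement. It runs in \mathcal{O}(N^2) time, so it is only meant for small inputs.

```cpp
#include <bits/stdc++.h>
using namespace std;

// Hypothetical brute-force reference: counts X in [0, N] such that
// (A_i & X) <= (A_{i+1} & X) for every adjacent pair. O(N^2) overall.
// Assumes N == a.size(), matching the statement.
int countValidBrute(const vector<int>& a) {
    int n = (int)a.size();
    int ans = 0;
    for (int x = 0; x <= n; ++x) {
        bool sorted = true;
        for (int i = 0; i + 1 < n && sorted; ++i)
            if ((a[i] & x) > (a[i + 1] & x)) sorted = false;
        ans += sorted;
    }
    return ans;
}
```

For example, with A = [1, 2, 3] the valid values are X = 0, 2, 3. X = 1 maps A to [1, 0, 1], which is not sorted.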

Let’s see when (A_i\land X) \leq (A_{i+1} \land X) holds.
For this to be the case, either the two values are equal, or, at the highest bit where A_i\land X and A_{i+1}\land X differ, the former has the bit unset while the latter has it set.
In particular, bits that are equal in A_i and A_{i+1} (either both set or both unset) don’t matter at all, since they remain equal after taking the AND with X.
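The highest-differing-bit criterion can be checked exhaustively on small values. In this sketch, `lessEqViaTopBit` and `criterionMatches` are hypothetical names. `__builtin_clz` is a GCC/Clang builtin and is only called on a nonzero argument.

```cpp
#include <bits/stdc++.h>
using namespace std;

// Decides (p & x) <= (q & x) using only the highest bit where the masked values differ.
// Bits equal in p and q stay equal after masking, so they never appear in diff.
bool lessEqViaTopBit(unsigned p, unsigned q, unsigned x) {
    unsigned u = p & x, v = q & x;
    unsigned diff = u ^ v;
    if (diff == 0) return true;            // equal values
    int h = 31 - __builtin_clz(diff);      // highest differing bit
    return (v >> h) & 1u;                  // u < v iff v owns that bit
}

// Compares the criterion with ordinary comparison for all p, q, x below 2^bits.
bool criterionMatches(int bits) {
    unsigned lim = 1u << bits;
    for (unsigned p = 0; p < lim; ++p)
        for (unsigned q = 0; q < lim; ++q)
            for (unsigned x = 0; x < lim; ++x)
                if (lessEqViaTopBit(p, q, x) != ((p & x) <= (q & x))) return false;
    return true;
}
```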

Let’s call a bit a “10-type” if it’s set in A_i but not A_{i+1}, and a “01-type” if it’s set in A_{i+1} but not A_i.
Taking the bitwise AND with X either sets both bits to 0 if X has the bit unset (in which case their original values don’t matter), or preserves the relation between them if X has the bit set.
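For a pair (p, q) = (A_i, A_{i+1}), each bit type is a single bitwise expression. This small sketch uses hypothetical helper names. It also illustrates that ANDing with X simply intersects each type mask with X.

```cpp
#include <bits/stdc++.h>
using namespace std;

// Returns {10-type bits, 01-type bits} for the pair (p, q).
pair<int, int> bitTypes(int p, int q) {
    int tenType = p & ~q;      // set in p, unset in q
    int zeroOneType = ~p & q;  // set in q, unset in p
    return {tenType, zeroOneType};
}

// Types of the masked pair (p & x, q & x); this equals bitTypes(p, q) with each mask ANDed with x.
pair<int, int> typesAfterAnd(int p, int q, int x) {
    return bitTypes(p & x, q & x);
}
```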

In particular, if X is set at a 10-type bit, this cannot be the highest bit where A_i\land X and A_{i+1}\land X differ.
So, if X is set at a 10-type bit, X must also be set at some higher 01-type bit (and if no higher 01-type bit exists, this bit isn’t allowed to be set in X at all).
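One high-to-low scan per pair collects its constraints, keeping track of the 01-type bits seen so far. This mirrors what the editorialist's code does with `mask` and `ban`. The sketch below uses hypothetical names (`Constraint`, `pairConstraints`, `countValidByConstraints`). It checks every X against every constraint directly, which is too slow for the full limits. It is still handy for confirming that the constraints capture validity exactly.

```cpp
#include <bits/stdc++.h>
using namespace std;

// (b, mask): "if bit b is set in X, some bit of mask must be set in X".
// mask == 0 means bit b may never be set in X.
struct Constraint { int b, mask; };

// Scans bits high to low; higher01 accumulates the 01-type bits seen so far,
// and each 10-type bit b yields the constraint (b, higher01).
vector<Constraint> pairConstraints(int p, int q, int topBit) {
    vector<Constraint> out;
    int higher01 = 0;
    for (int b = topBit; b >= 0; --b) {
        int x = (p >> b) & 1, y = (q >> b) & 1;
        if (x == y) continue;                // equal bits never matter
        if (x == 0) higher01 |= 1 << b;      // 01-type
        else out.push_back({b, higher01});   // 10-type
    }
    return out;
}

// Counts X in [0, N] violating none of the constraints of any adjacent pair.
int countValidByConstraints(const vector<int>& a) {
    int n = (int)a.size(), topBit = 0;
    while ((1 << (topBit + 1)) <= n) ++topBit;   // highest bit of N
    vector<Constraint> all;
    for (int i = 0; i + 1 < n; ++i)
        for (auto c : pairConstraints(a[i], a[i + 1], topBit)) all.push_back(c);
    int ans = 0;
    for (int x = 0; x <= n; ++x) {
        bool ok = true;
        for (auto [b, mask] : all)
            if (((x >> b) & 1) && (x & mask) == 0) { ok = false; break; }
        ans += ok;
    }
    return ans;
}
```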


Notice that we now have several constraints of the form (b, mask), meaning “if X is valid and bit b is set in X, at least one of the bits in mask must also be set in X”.

Equivalently, if for some constraint (b, mask) bit b is set in X but none of the bits in mask are, then X is invalid.

If having all bits of mask unset in X is bad for bit b, then having all bits of any supermask of mask unset is also bad for b, since that means mask itself is unset.
This allows us to find, for each b, all masks that can be unset using dynamic programming.

Let f(b, mask) be a boolean function that is true if mask can be entirely unset in X while b is set in X, and false otherwise.
Then,

  • If (b, mask) is itself a constraint, then f(b, mask) = 0.
  • Otherwise, f(b, mask) is true if and only if f(b, mask \oplus 2^k) is true for every bit k that’s set in mask.
    This is because if any submask of mask is bad, mask is also bad; to avoid iterating over all submasks, we remove one bit at a time and let the DP do the rest.
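Below is a direct sketch of this recurrence. It assumes the constraints are grouped by bit into a hypothetical `direct[b]` list of masks, and that B bits are enough to represent every value up to N. It keeps a full 2^B table for every bit, which is the \mathcal{O}(N\log^2 N) variant.

```cpp
#include <bits/stdc++.h>
using namespace std;

// f[b][mask] == 1 means mask may be entirely unset in X while bit b is set.
// direct[b] (hypothetical input) lists the constraint masks for bit b.
// Masks are processed in increasing order, so mask ^ (1 << k) is already final.
vector<vector<char>> computeF(const vector<vector<int>>& direct, int B) {
    vector<vector<char>> f(B, vector<char>(1 << B, 1));
    for (int b = 0; b < B; ++b) {
        for (int m : direct[b]) f[b][m] = 0;   // (b, m) is a constraint
        for (int mask = 0; mask < (1 << B); ++mask) {
            if (!f[b][mask]) continue;
            // Bad if removing any single bit leads to a bad mask.
            for (int k = 0; k < B; ++k)
                if (((mask >> k) & 1) && !f[b][mask ^ (1 << k)]) { f[b][mask] = 0; break; }
        }
    }
    return f;
}
```

A constraint with an empty mask sets f(b, 0) = 0, and the recurrence then marks every mask as bad for b. This matches "bit b may never be set".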

With dynamic programming, all the f(b, mask) values can be computed in \mathcal{O}(N\log N) or \mathcal{O}(N\log^2 N) time, depending on implementation.
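Here is one way to see where the two bounds come from. Assume B = \lfloor \log_2 N \rfloor + 1 bits, so 2^B \leq 2N.

- **Full tables:** a table over all B-bit masks for every bit costs B \cdot 2^B \cdot B work.
- **Higher bits only:** a table for bit b can store only masks over the B-1-b bits above b, since those are the only bits the constraints and the final check use. It then has 2^{B-1-b} entries, each with at most B-1-b transitions.

```latex
\underbrace{B \cdot 2^{B} \cdot B}_{\text{all masks, every bit}} = \mathcal{O}(N\log^2 N),
\qquad
\sum_{b=0}^{B-1} 2^{B-1-b}\,(B-1-b) \;\le\; B\sum_{j=0}^{B-1} 2^{j} \;<\; B \cdot 2^{B} = \mathcal{O}(N\log N)
```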


Now that these values are known, we can move to checking which X are valid.

If we fix a value of X, it’s valid if and only if f(b, mask) = 1 for each set bit b in X, where mask is the set of bits higher than b that are unset in X.
This is easily checked in \mathcal{O}(\log N) time for each X by iterating over its bits from high to low (each f(b, mask) is a constant-time lookup), so this part takes \mathcal{O}(N\log N) time overall.
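Below is a sketch of this final check. It assumes a hypothetical table `f[b][mask]`, with 1 meaning allowed, indexed by full B-bit masks; the recurrence for f(b, mask) is one way to fill it. The names `isValidX` and `countValid` are illustrative.

```cpp
#include <bits/stdc++.h>
using namespace std;

// Scans bits high to low; unsetAbove holds the bits above b that are unset in x.
// x is valid iff f[b][unsetAbove] == 1 for every set bit b. O(B) per x.
bool isValidX(int x, int B, const vector<vector<char>>& f) {
    int unsetAbove = 0;
    for (int b = B - 1; b >= 0; --b) {
        if ((x >> b) & 1) {
            if (!f[b][unsetAbove]) return false;
        } else {
            unsetAbove |= 1 << b;
        }
    }
    return true;
}

// Counts valid X in [0, n].
int countValid(int n, int B, const vector<vector<char>>& f) {
    int ans = 0;
    for (int x = 0; x <= n; ++x) ans += isValidX(x, B, f);
    return ans;
}
```

For A = [1, 2, 3] the only constraint is (0, 2), so f(0, mask) = 0 exactly when mask contains bit 1. This table yields the same count (3) as a brute-force check.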

Generating all the constraints also takes \mathcal{O}(N\log N) time, so the overall complexity is determined by the middle step of computing all f(b, mask) values.

TIME COMPLEXITY:

\mathcal{O}(N\log N) or \mathcal{O}(N\log^2 N) per testcase.

CODE:

Author's code (C++)
#include <bits/stdc++.h>
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>

using namespace std;
using namespace __gnu_pbds;

template<typename T> using Tree = tree<T, null_type, less<T>, rb_tree_tag, tree_order_statistics_node_update>;
typedef long long int ll;
typedef long double ld;
typedef pair<int,int> pii;
typedef pair<ll,ll> pll;

#define fastio ios_base::sync_with_stdio(false); cin.tie(NULL)
#define pb push_back
#define endl '\n'
#define sz(a) (int)a.size()
#define setbits(x) __builtin_popcountll(x)
#define ff first
#define ss second
#define conts continue
#define ceil2(x,y) ((x+y-1)/(y))
#define all(a) a.begin(), a.end()
#define rall(a) a.rbegin(), a.rend()
#define yes cout << "YES" << endl
#define no cout << "NO" << endl

#define rep(i,n) for(int i = 0; i < n; ++i)
#define rep1(i,n) for(int i = 1; i <= n; ++i)
#define rev(i,s,e) for(int i = s; i >= e; --i)
#define trav(i,a) for(auto &i : a)

template<typename T>
void amin(T &a, T b) {
    a = min(a,b);
}

template<typename T>
void amax(T &a, T b) {
    a = max(a,b);
}

#ifdef LOCAL
#include "debug.h"
#else
#define debug(...) 42
#endif

/*



*/

const int MOD = 1e9 + 7;
const int N = 1e5 + 5;
const int inf1 = int(1e9) + 5;
const ll inf2 = ll(1e18) + 5;

void solve(int test_case){
    int n; cin >> n;
    vector<int> a(n+5);
    rep1(i,n) cin >> a[i];

    int B = 0;
    while((1<<B) <= n) B++;

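    // heap-allocated table: a local B x 2^B array would be a non-standard VLA and can overflow the stack
    vector<vector<int>> dp(B, vector<int>(1<<B, 0));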

    rep1(i,n-1){
        int x = a[i], y = a[i+1];
        int curr_mask = 0;

        rev(bit,B-1,0){
            int f = 1<<bit;
            int b1 = 0, b2 = 0;
            if(x&f) b1 = 1;
            if(y&f) b2 = 1;
            if(b1 == b2){
                // can take anything
                curr_mask |= f;
            }
            else{
                if(!b1 and b2){
                    // 0 = continue
                    // 1 = becomes good
                }   
                else{
                    // 0 = continue
                    // 1 = becomes bad
                    // we are counting the bad masks, so push to dp
                    int mask = curr_mask|f;
                    mask |= f-1;
                    dp[bit][mask]++;
                }
            }
        }
    }

    rep(bit1,B){
        rep(bit,B){
            rep(mask,1<<B){
                if(!(mask&(1<<bit))){
                    dp[bit1][mask] += dp[bit1][mask^(1<<bit)];
                }
            }
        }
    }

    int ans = 0;

    rep(mask,n+1){
        bool ok = true;
        rep(bit,B){
            if(mask&(1<<bit)){
                if(dp[bit][mask]){
                    ok = false;
                    break;
                }
            }
        }

        if(ok){
            ans++;
        }
    }

    cout << ans << endl;
}

int main()
{
    fastio;

    int t = 1;
    cin >> t;

    rep1(i, t) {
        solve(i);
    }

    return 0;
}
Editorialist's code (C++)
// #include <bits/allocator.h>
// #pragma GCC optimize("O3,unroll-loops")
// #pragma GCC target("avx2,bmi,bmi2,lzcnt,popcnt")
#include "bits/stdc++.h"
using namespace std;
using ll = long long int;
mt19937_64 RNG(chrono::high_resolution_clock::now().time_since_epoch().count());

const int N = 1 << 19;
const int LG = 19;
int mark[LG + 1][N];

int main()
{
    ios::sync_with_stdio(false); cin.tie(0);

    int t; cin >> t;
    while (t--) {
        int n; cin >> n;
        vector a(n, 0);
        for (int &x : a) cin >> x;

        int mxb = 0, mxv = 1;
        while ((1 << mxb) <= n) ++mxb;
        mxv = 1 << mxb;
        --mxb, --mxv;

        int ban = 0;
        for (int i = 0; i+1 < n; ++i) {
            int mask = 0;
            for (int b = mxb; b >= 0; --b) {
                int x = (a[i] >> b) & 1;
                int y = (a[i+1] >> b) & 1;
                if (x == y) continue;
                
                if (x == 0) mask |= 1 << b;
                else {
                    if (mask) mark[b][mask] = 1;
                    else ban |= 1 << b;
                }
            }
        }

        for (int b = 0; b < mxb; ++b) {
            for (int mask = (1 << b); mask < mxv; ++mask) {
                for (int i = b+1; i < LG; ++i) {
                    if (mask & (1 << i)) mark[b][mask] |= mark[b][mask ^ (1 << i)];
                }
            }
        }

        int ans = 0;
        for (int x = 0; x <= n; ++x) {
            int pre = 0, good = (ban & x) == 0;
            for (int b = mxb; b >= 0; --b) {
                if (x & (1 << b)) {
                    if (mark[b][pre]) good = 0;
                }
                else pre |= 1 << b;
            }
            ans += good;
        }
        cout << ans << '\n';

        for (int i = 0; i <= mxb; ++i)
            for (int j = 0; j < mxv; ++j)
                mark[i][j] = 0;
    }
}
