CNTSUB - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: poetic_soul
Testers: IceKnight1093, tejas10p
Editorialist: IceKnight1093

DIFFICULTY:

2636

PREREQUISITES:

Iterating through submasks, dynamic programming

PROBLEM:

An array A is said to be good if A_i \mid A_j = (A_i \oplus A_j) + A_j for every i \lt j, where \mid and \oplus denote bitwise OR and bitwise XOR respectively.

Given a permutation P, find the number of its good subsequences.
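Before any analysis, it can help to have a brute-force counter that applies the definition literally, for checking faster solutions on tiny inputs. This is a minimal sketch, not the intended solution: it tries every non-empty index subset, so it is exponential in N. It assumes the empty subsequence is not counted (which matches the reference codes below) and skips the modulus, since the counts stay small. The names `is_good` and `count_good_brute` are hypothetical.

```python
from itertools import combinations


def is_good(a):
    """Apply the definition directly: a[i] | a[j] == (a[i] ^ a[j]) + a[j] for all i < j."""
    return all((a[i] | a[j]) == (a[i] ^ a[j]) + a[j]
               for i in range(len(a)) for j in range(i + 1, len(a)))


def count_good_brute(p):
    """Count non-empty good subsequences of p by trying every index subset (tiny inputs only)."""
    total = 0
    for k in range(1, len(p) + 1):
        # combinations yields index tuples in increasing order, i.e. subsequences
        for idx in combinations(range(len(p)), k):
            if is_good([p[i] for i in idx]):
                total += 1
    return total


print(count_good_brute([3, 1, 2]))  # 5
```

For P = [3, 1, 2], the five good subsequences are [3], [1], [2], [3, 1] and [3, 2].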

EXPLANATION:

The condition for an array to be good looks unusual, so let’s analyze it.

It can be seen that the given condition reduces to “A_j is a submask of A_i when both are written in binary”, i.e., every bit set in A_j is also set in A_i.

How?

Let’s look at the relation between A_i \mid A_j and A_i \oplus A_j, bit by bit.

  • If A_i and A_j both have a bit set, A_i \mid A_j has it set but A_i \oplus A_j doesn’t.
  • If one of A_i and A_j has a bit set and the other doesn’t, both A_i \mid A_j and A_i \oplus A_j have it set.
  • If both A_i and A_j have a bit unset, A_i\mid A_j and A_i\oplus A_j both have it unset.

Together, these should tell you that A_i \mid A_j = (A_i\oplus A_j) + (A_i \& A_j) where \& denotes bitwise AND.

Now, if A_i \mid A_j = (A_i \oplus A_j) + A_j, then comparing with the identity above gives A_j = A_i \& A_j, i.e., A_j is a submask of A_i.
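Both steps of this derivation can be sanity-checked exhaustively on small values. The sketch below (helper names `condition_holds` and `is_submask` are hypothetical) verifies the identity A \mid B = (A \oplus B) + (A \& B) and that the statement’s condition holds exactly when the later element is a submask of the earlier one, for every pair of 6-bit values. It is only an empirical check over a bounded range; the bitwise argument above is what proves it in general.

```python
def condition_holds(a, b):
    """Pairwise condition from the statement, with a the earlier element and b the later one."""
    return (a | b) == (a ^ b) + b


def is_submask(sub, sup):
    """True when every bit set in sub is also set in sup."""
    return sub & sup == sub


B = 6  # exhaustively check all pairs of B-bit values
for a in range(1 << B):
    for b in range(1 << B):
        # Identity derived above: OR = XOR + AND.
        assert (a | b) == (a ^ b) + (a & b)
        # Hence the condition holds exactly when b is a submask of a.
        assert condition_holds(a, b) == is_submask(b, a)
print(f"checked all pairs of {B}-bit values")
```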

Also notice that the submask relation is transitive, that is, if A_j is a submask of A_i and A_k is a submask of A_j, then A_k is a submask of A_i.
So, we only need to ensure that each pair of adjacent elements in the subsequence satisfies the submask relation: transitivity then guarantees it for every pair.
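The transitivity shortcut can be checked the same way. This minimal sketch (hypothetical helpers `good_all_pairs` and `good_adjacent`) compares the all-pairs submask check against the adjacent-only check on every length-4 array of 3-bit values; the two must agree for the shortcut to be valid.

```python
from itertools import combinations, product


def good_all_pairs(a):
    """Submask form of the definition: a[j] must be a submask of a[i] for every i < j."""
    return all(a[j] & a[i] == a[j] for i, j in combinations(range(len(a)), 2))


def good_adjacent(a):
    """Transitivity shortcut: only consecutive elements are checked."""
    return all(a[k + 1] & a[k] == a[k + 1] for k in range(len(a) - 1))


# Exhaustive comparison over every length-4 array of 3-bit values.
for arr in product(range(8), repeat=4):
    assert good_all_pairs(arr) == good_adjacent(arr)
```

The adjacent check costs O(k) per array instead of O(k^2), which is what lets the DP below look only at the previous element.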

With that in mind, let’s move to a full solution for our permutation P.
I’ll use the notation P_i \subseteq P_j to denote P_i being a submask of P_j.

Let dp_i denote the number of good subsequences ending at position i.
As noted above, we only need to care about adjacent elements, so suppose we fix the previous element of the subsequence to be at position j \lt i. If P_i \subseteq P_j, we can take any good subsequence ending at position j and append P_i to it to obtain a good subsequence ending at position i.

This gives us the relation:

dp_i = 1 + \sum_{\substack{1 \leq j \lt i \\ P_i \subseteq P_j}} dp_j

(we add 1 to account for the 1-length subsequence [P_i])

This is a solution in \mathcal{O}(N^2), which is too slow for the given constraints.
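A direct transcription of this recurrence, useful as a reference for testing the faster method, might look like the sketch below (the function name `count_good_quadratic` is hypothetical). It assumes the answer is taken modulo 10^9 + 7, the modulus used by the reference codes at the end of this editorial.

```python
MOD = 10**9 + 7  # assumed modulus, taken from the reference codes


def count_good_quadratic(p):
    """O(N^2) DP: dp[i] = 1 + sum of dp[j] over j < i with p[i] a submask of p[j]."""
    n = len(p)
    dp = [0] * n
    for i in range(n):
        dp[i] = 1  # the length-1 subsequence [p[i]]
        for j in range(i):
            if p[i] & p[j] == p[i]:  # p[i] is a submask of p[j], so p[i] may follow p[j]
                dp[i] = (dp[i] + dp[j]) % MOD
    return sum(dp) % MOD


print(count_good_quadratic([3, 1, 2]))  # 5
```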

To speed it up, notice that we’re doing some wasteful iteration: we don’t really care about all j \lt i, we only care about those j for which P_i \subseteq P_j, i.e., the positions holding supermasks of P_i.

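So, let’s instead iterate across all supermasks of P_i and add their dp values to dp_i. Since P is a permutation, every value appears at exactly one position, so we can store dp by value instead of by position: process positions from left to right with every entry initially 0, and for P_i, add the stored value of every supermask s of P_i with s \leq N, then add 1 for [P_i] itself. Values whose positions come after i (including P_i itself) still hold 0, so they contribute nothing.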
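As a small hand-worked example (chosen for illustration, not one of the problem’s samples), write dp[v] for the value stored at value v and take P = [3, 1, 2], i.e. 11, 01 and 10 in binary, so N = 3. The supermasks of 3, 1 and 2 that are at most 3 are \{3\}, \{1, 3\} and \{2, 3\} respectively. Processing from left to right, and omitting the term for P_i itself since it is still 0:

dp[3] = 1
dp[1] = 1 + dp[3] = 2
dp[2] = 1 + dp[3] = 2

The answer is dp[1] + dp[2] + dp[3] = 5, matching the good subsequences [3], [1], [2], [3, 1] and [3, 2]; [1, 2] is not good because 2 \not\subseteq 1.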

The complexity of doing this is \mathcal{O}(3^{\log_2 N}) = \mathcal{O}(N^{\log_2 3}) \approx \mathcal{O}(N^{1.585}) when summed up across all indices.

How?

For convenience, let’s assume N = 2^B - 1.

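Notice that our process iterates across every pair (x, y) such that 1 \leq x, y \leq N and x \subseteq y exactly once (x \subseteq y already implies x \leq y).
The number of such pairs is at most 3^B: for each of the B bits, either neither x nor y has it set, only y has it set, or both do. This gives exactly 3^B pairs if x = 0 is allowed, and 3^B - 2^B once x = 0 is excluded. Since B = \log_2(N + 1), this is \mathcal{O}(3^{\log_2 N}).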
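The 3^B count is easy to confirm by brute force for small B. This sketch (hypothetical helper `count_submask_pairs`) counts pairs of B-bit values (x, y) with x \subseteq y, including x = 0, and checks the result against 3^B.

```python
def count_submask_pairs(B):
    """Count pairs (x, y) of B-bit values with x a submask of y, by brute force."""
    return sum(1 for y in range(1 << B) for x in range(1 << B) if x & y == x)


for B in range(1, 9):
    # Each bit is in neither value, only in y, or in both: 3 choices per bit.
    assert count_submask_pairs(B) == 3 ** B
```

Dropping the 2^B pairs with x = 0 leaves 3^B - 2^B, the pairs the algorithm actually visits when N = 2^B - 1.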

Notice that the standard technique listed in the prerequisites iterates across all submasks, while we care about supermasks.
Modifying that approach to iterate over supermasks instead is not hard and is left as an exercise.
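For readers who want to check their answer to the exercise, here is one possible sketch (not necessarily the intended one) with two equivalent enumerations. The first steps through supermasks in increasing order using (s + 1) \mid x, the same update the editorialist’s code uses; the second reuses the ordinary submask loop on the bits that x lacks. The function names `supermasks` and `supermasks_via_complement` are hypothetical.

```python
def supermasks(x, limit):
    """Yield every supermask s of x with s <= limit, in increasing order.

    (s + 1) | x is the smallest supermask of x greater than s: adding 1 moves
    past s, and OR-ing with x puts back any bits of x that the carry cleared.
    """
    s = x
    while s <= limit:
        yield s
        s = (s + 1) | x


def supermasks_via_complement(x, B):
    """Alternative: supermasks of x within B bits are x | t for submasks t of the free bits."""
    free = ((1 << B) - 1) & ~x  # bits that x does not have
    t = free
    while True:
        yield x | t
        if t == 0:
            break
        t = (t - 1) & free  # standard submask step, applied to the free bits


# Both enumerations agree with a brute-force filter for all 5-bit values.
B = 5
for x in range(1 << B):
    brute = [s for s in range(1 << B) if s & x == x]
    assert list(supermasks(x, (1 << B) - 1)) == brute
    assert sorted(supermasks_via_complement(x, B)) == brute
```

The first form is convenient here because it stops naturally at N without needing B; the second makes the 3^B bound visible, since it literally iterates submasks of the complement.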

TIME COMPLEXITY:

\mathcal{O}(3^{\log_2 N}) = \mathcal{O}(N^{\log_2 3}) per testcase.

CODE:

Tester's code (C++)
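#include <bits/stdc++.h>
#define mod 1000000007
#define ll long long int
using namespace std;

int main() {
	int t;
	cin >> t;
	while(t--) {
	    int n;
	    cin >> n;
	    vector<int> p(n);
	    for(int i = 0; i < n; i++) cin >> p[i];
	    // After reversing, an element may follow another only if it is a
	    // supermask of it, so each value iterates over its submasks instead.
	    reverse(p.begin(), p.end());
	    // ans[v]: good subsequences of the original P that start at value v
	    // (0 for values not processed yet).
	    vector<ll> ans(n + 1, 0);
	    ll res = 0;
	    for(int i = 0; i < n; i++) {
	        int now = p[i];
	        ans[now] = 1; // the length-1 subsequence [now]
	        // Proper non-zero submasks of now; 0 never appears in P.
	        for(int masks = (now - 1) & now; masks > 0; masks = (masks - 1) & now) {
	            ans[now] += ans[masks];
	            ans[now] %= mod;
	        }
	        res += ans[now];
	        res %= mod;
	    }
	    cout << res << "\n";
	}
	return 0;
}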
Editorialist's code (Python)
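mod = 10**9 + 7
for _ in range(int(input())):
	n = int(input())
	# lim: smallest power of two >= n; all values are <= n <= lim
	lim = 1
	while lim < n: lim *= 2
	p = list(map(int, input().split()))

	# dp[v]: number of good subsequences ending at value v (0 until v is processed)
	dp = [0]*(lim + 1)
	for x in p:
		# Walk the supermasks of x in increasing order: x, (x + 1) | x, ...
		# dp[x] is still 0 on the first step, so x itself adds nothing.
		mask = x
		while mask <= lim:
			dp[x] += dp[mask]
			if dp[x] >= mod: dp[x] -= mod
			mask = (mask + 1) | x
		dp[x] += 1  # the length-1 subsequence [x]
		if dp[x] >= mod: dp[x] -= mod
	print(sum(dp) % mod)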