HIRING - Editorial

PROBLEM LINK:

Practice

Contest: Division 1

Contest: Division 2

Setter: Rami

Tester: Roman Bilyi

Editorialist: Taranpreet Singh

DIFFICULTY:

Easy-Medium

PREREQUISITES:

Dynamic Programming, Meet-in-the-middle.

PROBLEM:

You are given N strings, each of length M and consisting only of the characters 0 and 1. Consider all subsequences of this list of strings such that, for any two consecutive strings in the subsequence, there is no position p at which both strings have the character 1. Find the number of such subsequences modulo 10^9+7.

EXPLANATION:

The first thing to notice is that the condition “for any two consecutive strings, there is no position p at which both strings have the character 1” means that the bitwise AND of the two strings must be zero when each string is read as a number in binary notation.

So, ditching strings, the problem becomes: given N integers, each in the range [0, 2^M-1], count the valid subsequences, where a subsequence is valid if the bitwise AND of every pair of consecutive integers in it is 0.
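The reduction from strings to integers can be sketched as follows. This is a minimal illustration, not code from any of the solutions: the names `toMask` and `compatible` are made up here, and the bound M ≤ 16 is an assumption taken from the M = 16 case discussed later.

```cpp
#include <cassert>
#include <cstddef>
#include <string>

// Hypothetical helper: turn a 0/1 string into an integer mask.
// Bit j of the result is set iff s[j] == '1'. Any fixed position-to-bit
// mapping works, as long as the same mapping is used for every string.
int toMask(const std::string& s) {
    assert(s.size() <= 16);  // assumption: M is at most 16
    int mask = 0;
    for (std::size_t j = 0; j < s.size(); ++j) {
        assert(s[j] == '0' || s[j] == '1');
        if (s[j] == '1') mask |= 1 << j;
    }
    return mask;
}

// Hypothetical helper: two strings may be adjacent in a subsequence
// iff their masks have no set bit in common.
bool compatible(int a, int b) {
    return (a & b) == 0;
}
```

For instance, the strings 1010 and 0101 never have a 1 at the same position, so their masks have a zero AND, while 110 and 011 clash at the middle position.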

Let’s consider a slow solution first.

Let us create a DP table where DP_x denotes the number of valid subsequences whose last integer is x, after considering the first p elements.

Now, considering the next element y, the number of subsequences ending with y is 1 + \sum DP_z over all z such that the bitwise AND of y and z is zero, where the 1 counts the subsequence consisting of y alone. Computing this summation takes O(2^M) time. Then we increase DP_y by this number of ways.

Here, calculating the number of ways takes O(2^M) time and updating the DP table takes O(1) time per element, resulting in O(N*2^M) time complexity, which is not feasible.
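A minimal sketch of this slow solution, useful mainly as a brute-force reference for checking a faster one. The function name `countSlowQuery` and its signature are assumptions made for illustration; the input is taken to be already converted to integers below 2^M.

```cpp
#include <cassert>
#include <vector>

// Hypothetical reference implementation of the first slow approach.
// dp[z] = number of valid subsequences seen so far whose last element is z.
// Each element costs O(2^M) for the query and O(1) for the update.
long long countSlowQuery(const std::vector<int>& a, int m) {
    assert(m >= 0 && m <= 16);  // assumption: M is at most 16
    const long long MOD = 1000000007LL;
    std::vector<long long> dp(1 << m, 0);
    long long total = 0;
    for (int y : a) {
        assert(0 <= y && y < (1 << m));
        long long ways = 1;  // the subsequence consisting of y alone
        for (int z = 0; z < (1 << m); ++z) {
            if ((y & z) == 0) ways = (ways + dp[z]) % MOD;
        }
        dp[y] = (dp[y] + ways) % MOD;  // O(1) update
        total = (total + ways) % MOD;
    }
    return total;
}
```

On the input 1, 2, 1 with M = 2 the valid subsequences are the three single elements, [1, 2], [2, 1] and [1, 2, 1], so the function returns 6.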

Let’s try another approach. Now DP_x stores the number of subsequences such that the bitwise AND of x and the last element of the subsequence is zero. Considering the next element y, the number of valid subsequences ending with this element is DP_y, plus one for y on its own. To update the DP table, we add this number of ways to DP_z for all z such that the bitwise AND of y and z is zero.

Here, calculating the number of ways takes O(1) time and updating the DP table takes O(2^M) time per element, resulting again in O(N*2^M) time.
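The second approach moves the 2^M loop from the query to the update. A minimal sketch, again with a made-up name (`countSlowUpdate`) and the assumption that the input has already been converted to integers below 2^M:

```cpp
#include <cassert>
#include <vector>

// Hypothetical reference implementation of the second slow approach.
// dp[x] = number of valid subsequences seen so far whose last element
// shares no set bit with x.
// Each element costs O(1) for the query and O(2^M) for the update.
long long countSlowUpdate(const std::vector<int>& a, int m) {
    assert(m >= 0 && m <= 16);  // assumption: M is at most 16
    const long long MOD = 1000000007LL;
    std::vector<long long> dp(1 << m, 0);
    long long total = 0;
    for (int y : a) {
        assert(0 <= y && y < (1 << m));
        // O(1) query: everything y can extend, plus y on its own.
        long long ways = (1 + dp[y]) % MOD;
        // O(2^M) update: y is now a possible last element for every
        // mask z that is disjoint from it.
        for (int z = 0; z < (1 << m); ++z) {
            if ((y & z) == 0) dp[z] = (dp[z] + ways) % MOD;
        }
        total = (total + ways) % MOD;
    }
    return total;
}
```

Both slow versions compute the same count; they only differ in which side of the transition pays for the enumeration of disjoint masks.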

Can we try merging these two approaches? That’s where meet-in-the-middle comes in.

Let us make a two-dimensional DP table. Assuming M = 16, DP_{x, y} denotes the number of subsequences such that the upper 8 bits of the last number are exactly x and the lower 8 bits of the last number do not share any bit with y. Note that 0 \leq x, y < 2^8.

Let lo(x) return the lower 8 bits of x and hi(x) return the upper 8 bits of x.
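In code, lo and hi are a mask and a shift. The sketch below assumes 16-bit values; `disjointBySplit` is a hypothetical helper added only to show why the split is safe: two values have a zero AND exactly when both their upper halves and their lower halves do.

```cpp
#include <cassert>

// Lower 8 bits of a 16-bit value.
int lo(int x) {
    assert(0 <= x && x < (1 << 16));
    return x & 0xFF;
}

// Upper 8 bits of a 16-bit value.
int hi(int x) {
    assert(0 <= x && x < (1 << 16));
    return (x >> 8) & 0xFF;
}

// Hypothetical helper: equivalent to (w & z) == 0, but checked half by half.
bool disjointBySplit(int w, int z) {
    return (hi(w) & hi(z)) == 0 && (lo(w) & lo(z)) == 0;
}
```

This is what lets the table handle the upper half at query time and the lower half at update time.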

So, assuming we have this table calculated for all previous elements, how do we calculate the number of subsequences ending with the current element, say w? If the bitwise AND of x and hi(w) is zero, then DP_{x, lo(w)} contributes to the number of ways. We sum this over all 2^8 values of x and add one for the subsequence consisting of w alone.

Now, assuming we have counted the number of ways, we need to update this modified DP table.

By our definition, only the entries DP_{hi(w), y} for which the bitwise AND of y and lo(w) is zero are increased by the number of subsequences ending at this position.

In this solution, both the calculation and the update take O(2^{M/2}) time per element, so the total time complexity is O(N*2^{M/2}), which is sufficient to pass the time limit.
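Putting the query and the update together gives the following sketch. It mirrors the idea described above rather than reproducing any of the solutions verbatim; the name `countMeetInTheMiddle` is made up, the input is assumed to be already converted to 16-bit integers, and the table is allocated per call for simplicity.

```cpp
#include <cassert>
#include <vector>

// Hypothetical sketch of the meet-in-the-middle DP for values below 2^16.
// dp[x][y] = number of valid subsequences seen so far whose last element z
// has upper 8 bits equal to x and lower 8 bits disjoint from y.
long long countMeetInTheMiddle(const std::vector<int>& a) {
    const long long MOD = 1000000007LL;
    std::vector<std::vector<long long>> dp(256, std::vector<long long>(256, 0));
    long long total = 0;
    for (int w : a) {
        assert(0 <= w && w < (1 << 16));  // assumption: M is at most 16
        const int h = w >> 8;   // hi(w)
        const int l = w & 255;  // lo(w)

        // Query: enumerate upper halves x disjoint from hi(w). The lower
        // half is already taken care of by the second index of the table.
        long long ways = 1;  // the subsequence consisting of w alone
        for (int x = 0; x < 256; ++x) {
            if ((x & h) == 0) ways = (ways + dp[x][l]) % MOD;
        }

        // Update: the upper half is fixed to hi(w); enumerate lower masks y
        // disjoint from lo(w).
        for (int y = 0; y < 256; ++y) {
            if ((y & l) == 0) dp[h][y] = (dp[h][y] + ways) % MOD;
        }
        total = (total + ways) % MOD;
    }
    return total;
}
```

Each element performs 256 steps for the query and 256 for the update, instead of up to 65536 steps for one of them in the slow versions.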

TIME COMPLEXITY:

The time complexity for this problem is O(N*2^{M/2}) per test case.

SOLUTIONS:

Setter's Solution
#include "bits/stdc++.h"
#pragma GCC optimize("Ofast")
#pragma GCC target("sse,sse2,sse3,ssse3,sse4,avx,avx2")
using namespace std;

#define FOR(i,a,b) for (int i = (a); i < (b); i++)
#define RFOR(i,b,a) for (int i = (b) - 1; i >= (a); i--)
#define ITER(it,a) for (__typeof(a.begin()) it = a.begin(); it != a.end(); it++)
#define FILL(a,value) memset(a, value, sizeof(a))

#define SZ(a) (int)a.size()
#define ALL(a) a.begin(), a.end()
#define PB push_back
#define MP make_pair

typedef long long Int;
typedef vector<int> VI;
typedef pair<int, int> PII;

const double PI = acos(-1.0);
const int INF = 1000 * 1000 * 1000;
const Int LINF = INF * (Int) INF;
const int MAX = 100007;

const int MOD = 1000000007;

const double Pi = acos(-1.0);

int dp[1 << 8][1 << 8];

int main(int argc, char* argv[])
{
	// freopen("in.txt", "r", stdin);
	//ios::sync_with_stdio(false); cin.tie(0);

	int t;
	cin >> t;
	FOR(tt,0,t) {
	    int n, m;
	    cin >> n >> m;
	    FILL(dp, 0);
	    int res = 0;
	    FOR(i,0,n)
	    {
	        string s;
	        cin >> s;
	        int x = 0;
	        FOR(j,0,m)
	        {
	            if (s[j] == '1')
	                x += (1 << j);
	        }

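	        int r = 1; // the subsequence consisting of x alone

	        // Query: sum dp[h][lo(x)] over every upper half h disjoint from hi(x).
	        FOR(h,0,1 << 8)
	        {
	            if (((x >> 8) & h) == 0)
	            {
	                r += dp[h][x & 255];
	                r %= MOD;
	            }
	        }
	        // cerr << r << endl;
	        res += r;
	        res %= MOD;
	        // Update: x becomes a possible last element for every lower mask l disjoint from lo(x).
	        FOR(l,0,1 << 8)
	        {
	            if ((l & (x & 255)) == 0)
	            {
	                dp[x >> 8][l] += r;
	                dp[x >> 8][l] %= MOD;
	            }
	        }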
	    }
	    cout << res << endl;
	}

	cerr << 1.0 * clock() / CLOCKS_PER_SEC << endl;

	
}
Tester's Solution
#include "bits/stdc++.h"
#pragma GCC optimize("Ofast")
#pragma GCC target("sse,sse2,sse3,ssse3,sse4,avx,avx2")
using namespace std;
 
#define FOR(i,a,b) for (int i = (a); i < (b); i++)
#define RFOR(i,b,a) for (int i = (b) - 1; i >= (a); i--)
#define ITER(it,a) for (__typeof(a.begin()) it = a.begin(); it != a.end(); it++)
#define FILL(a,value) memset(a, value, sizeof(a))
 
#define SZ(a) (int)a.size()
#define ALL(a) a.begin(), a.end()
#define PB push_back
#define MP make_pair
 
typedef long long Int;
typedef vector<int> VI;
typedef pair<int, int> PII;
 
const double PI = acos(-1.0);
const int INF = 1000 * 1000 * 1000;
const Int LINF = INF * (Int) INF;
const int MAX = 100007;
 
const int MOD = 1000000007;
 
const double Pi = acos(-1.0);
 
int dp[1 << 8][1 << 8];
 
int main(int argc, char* argv[])
{
	// freopen("in.txt", "r", stdin);
	//ios::sync_with_stdio(false); cin.tie(0);
 
	int t;
	cin >> t;
	FOR(tt,0,t) {
	    int n, m;
	    cin >> n >> m;
	    FILL(dp, 0);
	    int res = 0;
	    FOR(i,0,n)
	    {
	        string s;
	        cin >> s;
	        int x = 0;
	        FOR(j,0,m)
	        {
	            if (s[j] == '1')
	                x += (1 << j);
	        }
 
	        int r = 1;
 
	        FOR(i,0,1 << 8)
	        {
	            if (((x >> 8) & i) == 0)
	            {
	                r += dp[i][x & 255];
	                r %= MOD;
	            }
	        }
	        // cerr << r << endl;
	        res += r;
	        res %= MOD;
	        FOR(i,0,1 << 8)
	        {
	            if ((i & (x & 255)) == 0)
	            {
	                dp[x >> 8][i] += r;
	                dp[x >> 8][i] %= MOD;
	            }
	        }
	    }
	    cout << res << endl;
	}
 
	cerr << 1.0 * clock() / CLOCKS_PER_SEC << endl;
 
	
} 
Editorialist's Solution
import java.util.*;
import java.io.*;
import java.text.*;
class HIRING{
	//SOLUTION BEGIN
	void pre() throws Exception{}
	void solve(int TC) throws Exception{
	    int n = ni(), m = ni();
	    long[][] cnt = new long[1<<8][1<<8];//cnt[i][j] -> number of ways to choose subsequences, such that upper 8 bits of last chosen element is i and lower 8 bits does not share any bit with j        
	    long ans = 0;
	    while(n-->0){
	        int x = Integer.parseInt(n(), 2);
	        long ways = 1; //Considering a subsequence starting at current position.
	        
	        for(int i = 0; i< 256; i++){
	            if(((x>>8)&i) == 0){//Comparing upper 8 bits with i
	                ways = (ways+cnt[i][x&255])%mod;
	            }
	        }
	        ans = (ans+ways)%mod;
	        for(int i = 0; i< 256; i++){
	            //If i shares no bit with the lower 8 bits of x, all ways ending at the current element are added to cnt[x>>8][i]
	            if(((i&x)&255) == 0){
	                cnt[x>>8][i] = (cnt[x>>8][i]+ways)%mod;
	            }
	        }
	    }
	    pn(ans);
	}
	long mod = (long)1e9+7;
	//SOLUTION END
	void hold(boolean b)throws Exception{if(!b)throw new Exception("Hold right there, Sparky!");}
	DecimalFormat df = new DecimalFormat("0.00000000000");
	static boolean multipleTC = true;
	FastReader in;PrintWriter out;
	void run() throws Exception{
	    in = new FastReader();
	    out = new PrintWriter(System.out);
	    //Solution Credits: Taranpreet Singh
	    int T = (multipleTC)?ni():1;
	    pre();for(int t = 1; t<= T; t++)solve(t);
	    out.flush();
	    out.close();
	}
	public static void main(String[] args) throws Exception{
	    new HIRING().run();
	}
	int bit(long n){return (n==0)?0:(1+bit(n&(n-1)));}
	void p(Object o){out.print(o);}
	void pn(Object o){out.println(o);}
	void pni(Object o){out.println(o);out.flush();}
	String n()throws Exception{return in.next();}
	String nln()throws Exception{return in.nextLine();}
	int ni()throws Exception{return Integer.parseInt(in.next());}
	long nl()throws Exception{return Long.parseLong(in.next());}
	double nd()throws Exception{return Double.parseDouble(in.next());}

	class FastReader{
	    BufferedReader br;
	    StringTokenizer st;
	    public FastReader(){
	        br = new BufferedReader(new InputStreamReader(System.in));
	    }

	    public FastReader(String s) throws Exception{
	        br = new BufferedReader(new FileReader(s));
	    }

	    String next() throws Exception{
	        while (st == null || !st.hasMoreElements()){
	            try{
	                st = new StringTokenizer(br.readLine());
	            }catch (IOException  e){
	                throw new Exception(e.toString());
	            }
	        }
	        return st.nextToken();
	    }

	    String nextLine() throws Exception{
	        String str = "";
	        try{   
	            str = br.readLine();
	        }catch (IOException e){
	            throw new Exception(e.toString());
	        }  
	        return str;
	    }
	}
}

Feel free to share your approach if you want to (even if it’s the same :stuck_out_tongue:). Suggestions are welcome, as always. :slight_smile:

Are elements x and y certain bits? And also can someone please explain what a DP table is?
Thanks

Google search exists okay?

Both x and y are values with at most 8 bits, that is, 0 \leq x, y < 2^8.

If DP_{x, y} counts the subsequences whose last element is z, then the most significant 8 bits of z are given by x and the least significant 8 bits of z do not share any bit with y.

You can basically think of the DP table as a two-dimensional array.

Writing a book on it. Would be available to those who say “Java is best” :stuck_out_tongue:

I hope you will finish your book before 14 OCT else your efforts might get wasted.

Can I get a test-case to fail above solution?
Thanks in advance :slight_smile:

Yes, when CountArr[i] overflows.

Another problem is that the accumulate function uses an integer variable to accumulate the sum, which overflows.
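The code under discussion is not shown here, so the following is only a general illustration of the second point, with a made-up function name `sumWide`. `std::accumulate` keeps its running sum in the type of the initial value, so passing `0` sums in `int`, while passing `0LL` sums in `long long`.

```cpp
#include <cassert>
#include <numeric>
#include <vector>

// Hypothetical helper: sum a vector of ints without overflowing int.
// The accumulator type of std::accumulate is the type of its third
// argument, so 0LL makes the running sum a long long.
long long sumWide(const std::vector<int>& v) {
    return std::accumulate(v.begin(), v.end(), 0LL);
}
```

When the answer is needed modulo 10^9+7, reducing after each addition keeps the running value small as well.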

modified code is here.

And no need to mention that TLE remains. :stuck_out_tongue: