LOG_EQN - Editorial

PROBLEM LINK:

Practice

Contest

Setter: Sahil Chimnani

Tester: Teja Vardhan Reddy

Editorialist: Taranpreet Singh

DIFFICULTY:

Easy

PREREQUISITES:

Prefix and suffix sums.

PROBLEM:

You are given a sequence A of N integers. A triplet (i, j, k) is called good if 1 \leq i < j < k \leq N and P = [A_i << (\lfloor \log_2(A_j) \rfloor + \lfloor \log_2(A_k)\rfloor+2)] + [A_j<<(\lfloor \log_2(A_k) \rfloor +1)] +A_k contains an odd number of set bits (1s) in its binary representation, where << denotes the left shift operation, i.e. x << y is equivalent to x \cdot 2^y.

Find the number of good triplets, modulo 10^9+7.

QUICK EXPLANATION

  • Observe that for a triplet (i, j, k), the binary representation of P is simply the concatenation of the binary representations of A_i, A_j and A_k. So we need to count the triplets for which the total number of set bits over the three integers is odd.
  • We can try every element as the middle element of the triplet and count the triplets with that fixed middle element. Prefix and suffix sums track how many numbers have an even and how many have an odd number of set bits on each side.

EXPLANATION

First, let us look at P. For any positive integer x, its binary representation has exactly \lfloor \log_2(x) \rfloor +1 bits.

Let Q = (y << (\lfloor \log_2(x) \rfloor +1)) + x. Since y is shifted left by exactly \lfloor \log_2(x) \rfloor +1 bits, the lowest \lfloor \log_2(x) \rfloor +1 bits of y << (\lfloor \log_2(x) \rfloor +1) are all zero, and x fits entirely within those bits. So adding x produces no carry, and the binary representation of Q is that of y followed by that of x. Hence the number of set bits in Q is the sum of the numbers of set bits in x and y.

For example, consider x = 9 and y = 12. We have \lfloor \log_2(9) \rfloor+1 = 3+1 = 4, hence Q = (12 << 4) + 9 = 12 \cdot 16+9 = 201.
The binary representation of x is 1001, that of y is 1100, and that of Q is 11001001, which is the bits of y followed by the bits of x (2 + 2 = 4 set bits).

We can similarly prove that P = [A_i << (\lfloor \log_2(A_j) \rfloor + \lfloor \log_2(A_k)\rfloor+2)] + [A_j<<(\lfloor \log_2(A_k) \rfloor +1)] +A_k is simply the concatenation of the binary representations of A_i, A_j and A_k (in that order). Its number of set bits is therefore the sum of theirs.

Since we only care about parity, let us define an array B of length N where B_i = 1 if the binary representation of A_i contains an odd number of set bits, and B_i = 0 otherwise.

Now the problem reduces to counting triplets (i, j, k) such that B_i+B_j+B_k is odd. We fix each position in turn as the middle index j and count the triplets with that middle index.

Suppose B_j = 1 for some position j. Then we need the number of pairs (i, k) with 1 \leq i < j < k \leq N such that B_i+B_k is even, i.e. B_i and B_k are either both 1 or both 0. With j fixed, i (taken before j) and k (taken after j) are chosen independently. So the count is (number of i < j with B_i = 1) \times (number of k > j with B_k = 1) + (number of i < j with B_i = 0) \times (number of k > j with B_k = 0).

Similarly, if B_j = 0, we need B_i+B_k to be odd, which requires exactly one of B_i and B_k to be 1. Hence the number of triplets with j as the middle index is

(number of i < j such that B_i = 1) \times (number of k > j such that B_k = 0) + (number of i < j such that B_i = 0) \times (number of k > j such that B_k = 1)

We can compute these counts with prefix and suffix sums, then add up the triplet counts over all middle indices.

TIME COMPLEXITY

The time complexity is O(N \cdot b) per test case, where b = 31 is the number of bit positions examined when counting the set bits of each A_i.

SOLUTIONS:

Setter's Solution
#include<bits/stdc++.h>
// `ll` is shorthand for long long throughout this solution.
#define ll long long
// Modulus applied to the answer: 10^9 + 7.
#define mod 1000000007
using namespace std;
 
// Returns the number of 1-bits of num: each step inspects the lowest bit and
// then drops it by halving, until nothing is left.
ll no_of_1(ll num){
	ll ones = 0;
	for (; num != 0; num /= 2)
		ones += (num & 1);
	return ones;
}
 
// Returns (a + b) reduced modulo `mod`. Each operand is reduced first so that
// the intermediate sum stays below 2 * mod and cannot overflow.
ll modadd(ll a,ll b){
	return (a % mod + b % mod) % mod;
}
 
 
int main(){
	ios_base::sync_with_stdio(false);
	cin.tie(NULL);
	srand(time(NULL));
	int t;
	cin >> t;
 
	while(t--){
		int n;
		cin >> n;
		vector<ll> v(n);
		for(int i = 0; i < n ; i++)
			cin >> v[i];
		for(int i = 0; i < n ; i++)
		{
			v[i] = no_of_1(v[i]);
		}
 
		vector<ll> odd_prefix(n);
		vector<ll> even_prefix(n);
		odd_prefix[0] = v[0] % 2;
		even_prefix[0] = (v[0] + 1) % 2;
		for(int i = 1; i < n ; i++)
		{
			odd_prefix[i] = odd_prefix[i-1] + (v[i] % 2);	
			even_prefix[i] = even_prefix[i-1] + ((v[i] + 1) % 2);	
		}
		ll ans = 0;
		for(int i = 1; i < n-1 ; i++)
		{
			if( v[i] % 2 == 1)	//odd
			{
				ll temp = odd_prefix[i-1] * (odd_prefix[n-1] - odd_prefix[i]);	// odd,_,odd
				temp = temp % mod;
				ans = modadd(ans,temp);
 
				temp = even_prefix[i-1] * (even_prefix[n-1] - even_prefix[i]);	// even,_,even
				temp = temp % mod;
				ans = modadd(ans,temp);
			}
			else				//even
			{
				ll temp = odd_prefix[i-1] * (even_prefix[n-1] - even_prefix[i]);	// odd,_,even
				temp = temp % mod;
				ans = modadd(ans,temp);
 
				temp = even_prefix[i-1] * (odd_prefix[n-1] - odd_prefix[i]);	// even,_,odd
				temp = temp % mod;
				ans = modadd(ans,temp);
			}
		}
		cout << ans << "\n";
	}
 
	return 0;
}
Tester's Solution
//teja349
#include <bits/stdc++.h>
#include <vector>
#include <set>
#include <map>
#include <string>
#include <cstdio>
#include <cstdlib>
#include <climits>
#include <utility>
#include <algorithm>
#include <cmath>
#include <queue>
#include <stack>
#include <iomanip>
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp> 
//setbase - cout << setbase (16); cout << 100 << endl; Prints 64
//setfill -   cout << setfill ('x') << setw (5); cout << 77 << endl; prints xxx77
//setprecision - cout << setprecision (14) << f << endl; Prints x.xxxx
//cout.precision(x)  cout<<fixed<<val;  // prints x digits after decimal in val

using namespace std;
using namespace __gnu_pbds;

// Tester's template shorthands.
// Loop macros: f(i,a,b) runs i over [a, b), rep(i,n) over [0, n), and
// fd(i,a,b) counts i down from a to b inclusive. The loop variable must
// already be declared by the caller.
#define f(i,a,b) for(i=a;i<b;i++)
#define rep(i,n) f(i,0,n)
#define fd(i,a,b) for(i=a;i>=b;i--)
// Container / pair shorthands.
#define pb push_back
#define mp make_pair
#define vi vector< int >
#define vl vector< ll >
#define ss second
#define ff first
#define ll long long
#define pii pair< int,int >
#define pll pair< ll,ll >
#define sz(a) a.size()
// inf = 10^9 + 5.
#define inf (1000*1000*1000+5)
#define all(a) a.begin(),a.end()
#define tri pair<int,pii>
#define vii vector<pii>
#define vll vector<pll>
#define viii vector<tri>
// Modulus applied to the answer: 10^9 + 7.
#define mod (1000*1000*1000+7)
#define pqueue priority_queue< int >
#define pdqueue priority_queue< int,vi ,greater< int > >
// NOTE(review): this object-like macro replaces every later `flush` token, so
// std::flush cannot be named anywhere after this line.
#define flush fflush(stdout) 
// Large constant; its purpose is not visible in this solution.
#define primeDEN 727999983
// Clock-seeded Mersenne Twister; not referenced elsewhere in this solution.
mt19937 rng(chrono::steady_clock::now().time_since_epoch().count());

// Order-statistics tree (GNU pb_ds) over distinct ints; not referenced
// elsewhere in this solution.
// find_by_order()  // order_of_key
typedef tree<
int,
null_type,
less<int>,
rb_tree_tag,
tree_order_statistics_node_update>
ordered_set;

// Returns the number of 1-bits of val: each step checks the lowest bit, then
// discards it by halving, until val reaches zero.
int getcnt(int val){
	int bits = 0;
	for (; val; val /= 2)
		bits += (val % 2 != 0);
	return bits;
}
ll dp[123456][4][2];
int a[123456];
int main(){
	std::ios::sync_with_stdio(false); cin.tie(NULL);
	int t;
	cin>>t;
	while(t--){
		int n;
		cin>>n;
		int i,j,k;
		int val;
		f(i,1,n+1){
			cin>>val;
			a[i]=getcnt(val)%2;
		}
		rep(i,4){
			rep(j,2){
				dp[0][i][j]=0;
			}
		}
		dp[0][0][0]=1;
		f(i,1,n+1){
			rep(j,4){
				rep(k,2){
					dp[i][j][k]=dp[i-1][j][k];
				}
			}
			rep(j,3){
				rep(k,2){
					dp[i][j+1][(k+a[i])%2]+=dp[i-1][j][k];
					dp[i][j+1][(k+a[i])%2]%=mod;
				}
			}
		}

		cout<<dp[n][3][1]<<endl;



	}
	return 0;   
}
Editorialist's Solution
import java.util.*;
import java.io.*;
import java.text.*;
/**
 * Editorialist's solution for LOG_EQN, wrapped in a competitive-programming template.
 *
 * solve() maps every A_i to the parity of its popcount and counts triplets
 * i < j < k whose parities sum to an odd number, modulo 1e9+7. The remaining
 * members are template utilities. Of these, find, digit, gcd, hold, exit, df, PI,
 * eps, IINF, INF and MX are not used by solve().
 */
public class Main{
	//SOLUTION BEGIN
	//Into the Hardware Mode
	// Modulus for the answer (10^9 + 7).
	long mod = (long)1e9+7;
	// Hook run once before all test cases; nothing to precompute for this problem.
	void pre() throws Exception{}
	// Solves one test case (TC is the 1-based test index; it is not used here).
	void solve(int TC)throws Exception{
	    int n = ni();
	    // a[i] (1-indexed) = parity of the number of set bits of A_i.
	    int[] a = new int[1+n];
	    for(int i = 1; i<= n; i++)
	        a[i] = bit(nl())%2;
	    // pre[p][i] = how many of a[1..i] equal p.
	    // suf[p][i] = how many of a[i..n] equal p; the extra slot makes suf[p][n+1] == 0.
	    int[][] pre = new int[2][1+n], suf = new int[2][2+n];
	    for(int i = 1; i<= n; i++){
	        pre[0][i] = pre[0][i-1];
	        pre[1][i] = pre[1][i-1];
	        pre[a[i]][i]++;
	    }
	    for(int i = n; i>= 1; i--){
	        suf[0][i] = suf[0][i+1];
	        suf[1][i] = suf[1][i+1];
	        suf[a[i]][i]++;
	    }
	    long ans = 0;
	    // Fix i as the middle index. The left index is counted by pre[..][i-1] and the
	    // right index by suf[..][i+1], chosen independently. The (long) casts keep the
	    // int * int products from overflowing.
	    // a[i] == 0: the outer pair must have different parities (odd total).
	    // a[i] == 1: the outer pair must have equal parities (odd total).
	    for(int i = 1; i<= n; i++){
	        if(a[i] == 0)
	            ans += pre[0][i-1]*(long)suf[1][i+1]+pre[1][i-1]*(long)suf[0][i+1];
	        else
	            ans += pre[0][i-1]*(long)suf[0][i+1]+pre[1][i-1]*(long)suf[1][i+1];
	        ans %= mod;
	    }
	    pn(ans);
	}
	//SOLUTION END
	// Assertion helper: throws if b is false.
	void hold(boolean b)throws Exception{if(!b)throw new Exception("Hold right there, Sparky!");}
	// Terminates the JVM with status 0 when b is false.
	void exit(boolean b){if(!b)System.exit(0);}
	long IINF = (long)1e18;
	final int INF = (int)1e9, MX = (int)2e6+5;
	DecimalFormat df = new DecimalFormat("0.00");
	double PI = 3.141592653589793238462643383279502884197169399, eps = 1e-7;
	// multipleTC: read a test count T first. memory: run on a separate thread with a
	// requested 1 << 28 byte stack. fileIO: use the hard-coded local files instead of stdin/stdout.
	static boolean multipleTC = true, memory = false, fileIO = false;
	FastReader in;PrintWriter out;
	// Sets up I/O, reads the number of test cases (if multipleTC), runs pre() once
	// and then solve() for each test case, and finally flushes and closes the output.
	void run() throws Exception{
	    if(fileIO){
	        in = new FastReader("C:/users/user/desktop/inp.in");
	        out = new PrintWriter("C:/users/user/desktop/out.out");
	    }else {
	        in = new FastReader();
	        out = new PrintWriter(System.out);
	    }
	    //Solution Credits: Taranpreet Singh
	    int T = (multipleTC)?ni():1;
	    pre();
	    for(int t = 1; t<= T; t++)solve(t);
	    out.flush();
	    out.close();
	}
	// Entry point. When `memory` is set, run() executes on a new thread created with a
	// stackSize argument of 1 << 28; exceptions there are only printed.
	public static void main(String[] args) throws Exception{
	    if(memory)new Thread(null, new Runnable() {public void run(){try{new Main().run();}catch(Exception e){e.printStackTrace();}}}, "1", 1 << 28).start();
	    else new Main().run();
	}
	// Disjoint-set find with path compression.
	int find(int[] set, int u){return set[u] = (set[u] == u?u:find(set, set[u]));}
	// Number of decimal digits of s (0 when s <= 0).
	int digit(long s){int ans = 0;while(s>0){s/=10;ans++;}return ans;}
	// Euclidean greatest common divisor (long and int overloads).
	long gcd(long a, long b){return (b==0)?a:gcd(b,a%b);}
	int gcd(int a, int b){return (b==0)?a:gcd(b,a%b);}
	// Popcount: each recursive step clears the lowest set bit via n & (n-1).
	int bit(long n){return (n==0)?0:(1+bit(n&(n-1)));}
	// Output helpers: print, print with newline, print with newline and flush.
	void p(Object o){out.print(o);}
	void pn(Object o){out.println(o);}
	void pni(Object o){out.println(o);out.flush();}
	// Input helpers: next token, whole line, and token parsed as int / long / double.
	String n()throws Exception{return in.next();}
	String nln()throws Exception{return in.nextLine();}
	int ni()throws Exception{return Integer.parseInt(in.next());}
	long nl()throws Exception{return Long.parseLong(in.next());}
	double nd()throws Exception{return Double.parseDouble(in.next());}

	// Buffered whitespace-token reader over stdin or a named file.
	class FastReader{
	    BufferedReader br;
	    StringTokenizer st;
	    public FastReader(){
	        br = new BufferedReader(new InputStreamReader(System.in));
	    }

	    public FastReader(String s) throws Exception{
	        br = new BufferedReader(new FileReader(s));
	    }

	    // Returns the next whitespace-separated token, reading further lines as needed.
	    // NOTE(review): at end of input readLine() returns null, and the StringTokenizer
	    // constructor then throws NullPointerException instead of a clearer error.
	    String next() throws Exception{
	        while (st == null || !st.hasMoreElements()){
	            try{
	                st = new StringTokenizer(br.readLine());
	            }catch (IOException  e){
	                throw new Exception(e.toString());
	            }
	        }
	        return st.nextToken();
	    }

	    // Reads the next raw line directly from the reader, bypassing any tokens still
	    // buffered in st (returns null at end of input).
	    String nextLine() throws Exception{
	        String str = "";
	        try{
	            str = br.readLine();
	        }catch (IOException e){
	            throw new Exception(e.toString());
	        }
	        return str;
	    }
	}
}

Feel free to share your approach. Suggestions are welcomed as always. :slight_smile:

I read the editorial and tried to implement it in C++, but it’s giving me WA. Could someone please help with some test cases that my solution fails on? Thanks!

https://www.codechef.com/viewsolution/41106012