ANDEQOR - Editorial

PROBLEM LINK:

Contest Division 1
Contest Division 2
Contest Division 3
Contest Division 4

Setter: Utkarsh Gupta
Tester: Abhinav Sharma, Aryan
Editorialist: Lavish Gupta

DIFFICULTY:

Easy-Medium

PREREQUISITES:

SOS-DP

PROBLEM:

Chef has a (0-indexed) binary string S of length N such that N is a power of 2.

Chef wants to find the number of pairs (i, j) such that:

  • 0 \leq i,j \lt N
  • S_{i|j} = S_{i\&j}

(Here | denotes the bitwise OR operation and \& denotes the bitwise AND operation)

Can you help Chef to do so?

QUICK EXPLANATION:

What if we fix i?
Let S_1 be the set of bit positions that are set in i, and let S_2 be its complement. If we also fix i|j, every bit of j at a position in S_2 is determined (it equals the corresponding bit of i|j), whereas every bit of j at a position in S_1 can be either 0 or 1. So, for a fixed i and i|j, i\&j takes the value of each submask of i exactly once.
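The claim can be checked by brute force for small bit widths. The sketch below uses a function name, orAndMapIsBijection, that is ours and does not come from the original solutions. For every i it verifies that j \mapsto (i|j, i\&j) sends the N values of j to N distinct pairs, where each first component is a supermask of i and each second component is a submask of i. Exactly 2^{n-\text{popcount}(i)} \cdot 2^{\text{popcount}(i)} = N such pairs exist, so the map is a bijection.

```cpp
#include <cassert>
#include <set>
#include <utility>

// Brute-force check: for every i in [0, 2^n), the map j -> (i|j, i&j) is a
// bijection from [0, 2^n) onto {supermasks of i} x {submasks of i}.
bool orAndMapIsBijection(int n) {
    const int N = 1 << n;
    for (int i = 0; i < N; ++i) {
        std::set<std::pair<int, int>> seen;
        for (int j = 0; j < N; ++j) {
            const int o = i | j, a = i & j;
            // i|j must be a supermask of i, and i&j a submask of i.
            if ((o & i) != i || (a & ~i) != 0) return false;
            seen.insert({o, a});
        }
        // N distinct images inside a target set of size
        // 2^(n - popcount(i)) * 2^popcount(i) = N means the map is onto.
        if ((int)seen.size() != N) return false;
    }
    return true;
}
```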

How to sum up the answer for a fixed i?
Let’s define supermasks of i as the collection of all masks for which i is a submask.
For a fixed i|j, which is a supermask of i, i\&j can be any of the submasks of i. We count only the choices for which S has the same character at indices i|j and i\&j. Let c_1 and c_0 denote the number of supermasks of i at which S is 1 and 0 respectively. Let d_1 and d_0 denote the number of submasks of i at which S is 1 and 0 respectively. Then the answer for i is c_1 \cdot d_1 + c_0 \cdot d_0.

How to calculate these values efficiently?
We can use the SOS Dynamic Programming technique to calculate the values for all i in O(N \cdot \log{N}) time.

EXPLANATION:

A naive approach iterates over all pairs i and j and checks whether S_{i|j} = S_{i\&j}. This takes O(N^2) time, which exceeds the time limit.

To optimize this, let us fix i and analyze how the bits of j affect the answer. Let S_1 be the set of bit positions that are set in i, and let S_2 be the set of positions that are not set in i. If we further fix i|j, the bits of j at positions in S_2 are determined, whereas each bit of j at a position in S_1 can be 0 or 1. Since i\&j is a submask of i, it contains only set bits from S_1, and hence i\&j can be any of the submasks of i.

Let us define the supermasks of i as the collection of all masks of which i is a submask. For a fixed i|j, which is a supermask of i, i\&j can be any submask of i. We count only the choices for which S has the same character at indices i|j and i\&j. Let d_1 and d_0 denote the number of submasks of i at which S is 1 and 0 respectively. If S_{i|j} = 1, there are d_1 valid values of i\&j; otherwise there are d_0. Let c_1 and c_0 denote the number of supermasks of i at which S is 1 and 0 respectively. So, for a fixed i, the total answer is c_1 \cdot d_1 + c_0 \cdot d_0.
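The counting argument above translates directly into an O(3^{\log_2 N}) reference implementation. This is a sketch, assuming S is given as a std::string whose length is a power of 2. The function name is hypothetical. It enumerates submasks of i with the usual sub = (sub - 1) \& i trick. It enumerates supermasks of i as i | add, where add ranges over the submasks of the complement of i.

```cpp
#include <cassert>
#include <string>

// O(3^log2(N)) reference: for each i, count c1/c0 over supermasks of i and
// d1/d0 over submasks of i, then add c1*d1 + c0*d0.
long long countPairsBySubmaskEnumeration(const std::string& s) {
    const int N = (int)s.size();  // assumed to be a power of 2
    long long ans = 0;
    for (int i = 0; i < N; ++i) {
        long long d1 = 0, d0 = 0, c1 = 0, c0 = 0;
        // Submasks of i, in decreasing order, including 0.
        for (int sub = i;; sub = (sub - 1) & i) {
            if (s[sub] == '1') ++d1; else ++d0;
            if (sub == 0) break;
        }
        // Supermasks of i: i | add for every submask add of the complement.
        const int comp = (N - 1) & ~i;
        for (int add = comp;; add = (add - 1) & comp) {
            if (s[i | add] == '1') ++c1; else ++c0;
            if (add == 0) break;
        }
        ans += c1 * d1 + c0 * d0;
    }
    return ans;
}
```

For S = 1011 the four values of i contribute 3, 2, 4 and 3, giving 12.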

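We want to calculate c_1, c_0, d_1, d_0 for all values of i. Iterating over all the submasks and supermasks of every i takes O(3^{\log_2{N}}) time in total. For N = 2^{20} that is about 3^{20} \approx 3.5 \cdot 10^9 steps. To optimize this further, we use the SOS (Sum over Subsets) dynamic programming technique. It computes the number of 1s of S over the submasks of every i (that is, d_1) in O(N \cdot \log{N}) time, and the same technique over supermasks gives c_1. The remaining counts follow from d_0 = 2^{\text{popcount}(i)} - d_1 and c_0 = 2^{\log_2 N - \text{popcount}(i)} - c_1.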
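Below is a minimal O(N \cdot \log N) sketch of this step. It assumes S is passed as a std::string of power-of-2 length, and the function name is hypothetical. It computes d_1 with a submask-sum pass and c_1 with a supermask-sum pass in the same loop, then derives d_0 and c_0 from the totals. Like the solutions below, it relies on the GCC/Clang builtin __builtin_popcount.

```cpp
#include <cassert>
#include <string>
#include <vector>

// O(N log N) count of pairs (i, j) with S[i|j] == S[i&j], N = s.size() a power of 2.
long long countPairsSOS(const std::string& s) {
    const int N = (int)s.size();
    int n = 0;
    while ((1 << n) < N) ++n;  // n = log2(N)

    // sub[i] ends as d1(i) (1s over submasks), sup[i] as c1(i) (1s over supermasks).
    std::vector<long long> sub(N), sup(N);
    for (int i = 0; i < N; ++i) sub[i] = sup[i] = (s[i] == '1');

    for (int b = 0; b < n; ++b)
        for (int mask = 0; mask < N; ++mask) {
            if (mask & (1 << b)) sub[mask] += sub[mask ^ (1 << b)];  // submask sums
            else                 sup[mask] += sup[mask | (1 << b)];  // supermask sums
        }

    long long ans = 0;
    for (int i = 0; i < N; ++i) {
        const int pc = __builtin_popcount(i);
        const long long d1 = sub[i], d0 = (1LL << pc) - d1;        // 2^pc submasks
        const long long c1 = sup[i], c0 = (1LL << (n - pc)) - c1;  // 2^(n-pc) supermasks
        ans += c1 * d1 + c0 * d0;
    }
    return ans;
}
```

The editorialist's code below obtains the supermask sums differently. It runs the submask DP on the reversed string b and reads f2[n-1-i]. The index map k \leftrightarrow N-1-k turns supermasks of i into submasks of N-1-i. The direct superset pass here avoids building the reversed copy.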

TIME COMPLEXITY:

In the SOS dynamic programming approach, populating the DP values for all i takes O(N \cdot \log{N}) time, and adding up the answer for each i takes O(N) time. So the total time complexity is O(N \cdot \log{N}).

SOLUTION:

Setter's Solution
//Utkarsh.25dec
#include <bits/stdc++.h>
#define ll long long int
#define pb push_back
#define mp make_pair
#define mod 1000000007
#define vl vector <ll>
#define all(c) (c).begin(),(c).end()
using namespace std;
ll power(ll a,ll b) {ll res=1;a%=mod; assert(b>=0); for(;b;b>>=1){if(b&1)res=res*a%mod;a=a*a%mod;}return res;}
ll modInverse(ll a){return power(a,mod-2);}
const int N=2000023;
bool vis[N];
vector <int> adj[N];
long long readInt(long long l,long long r,char endd){
    long long x=0;
    int cnt=0;
    int fi=-1;
    bool is_neg=false;
    while(true){
        char g=getchar();
        if(g=='-'){
            assert(fi==-1);
            is_neg=true;
            continue;
        }
        if('0'<=g && g<='9'){
            x*=10;
            x+=g-'0';
            if(cnt==0){
                fi=g-'0';
            }
            cnt++;
            assert(fi!=0 || cnt==1);
            assert(fi!=0 || is_neg==false);

            assert(!(cnt>19 || ( cnt==19 && fi>1) ));
        } else if(g==endd){
            if(is_neg){
                x= -x;
            }

            if(!(l <= x && x <= r))
            {
                cerr << l << ' ' << r << ' ' << x << '\n';
                assert(1 == 0);
            }

            return x;
        } else {
            assert(false);
        }
    }
}
string readString(int l,int r,char endd){
    string ret="";
    int cnt=0;
    while(true){
        char g=getchar();
        assert(g!=-1);
        if(g==endd){
            break;
        }
        cnt++;
        ret+=g;
    }
    assert(l<=cnt && cnt<=r);
    return ret;
}
long long readIntSp(long long l,long long r){
    return readInt(l,r,' ');
}
long long readIntLn(long long l,long long r){
    return readInt(l,r,'\n');
}
string readStringLn(int l,int r){
    return readString(l,r,'\n');
}
string readStringSp(int l,int r){
    return readString(l,r,' ');
}
ll sumN=0;
int good[N]={0};
ll A[N], F[N];
void solve()
{
    int N=readInt(2,(1<<20),'\n');
    sumN+=N;
    assert(sumN<=(1<<20));
    assert(good[N]==1);
    string s=readString(N,N,'\n');
    int n=0;
    int temp=1;
    while(temp!=N)
    {
        temp*=2;
        n++;
    }
    ll ans=0;
    {
        for(int i=0;i<(1<<n);i++)
        {
            if(s[i]=='0')
                A[i]=0;
            else
                A[i]=(1<<(n-(__builtin_popcount(i))));
        }
        for(int i=0;i<(1<<n);i++)
            F[i]=A[i];
        for(int i = 0;i < n; ++i) 
            for(int mask = 0; mask < (1<<n); ++mask)
            {
                if(mask & (1<<i))
                {
                    F[mask] += F[mask^(1<<i)];
                }
            }
        for(int i=0;i<(1<<n);i++)
        {
            if(s[i]=='0')
                continue;
            ans+=(F[i]/((1<<(n-(__builtin_popcount(i))))));
        }
    }
    {
        for(int i=0;i<(1<<n);i++)
        {
            if(s[i]=='1')
                A[i]=0;
            else
                A[i]=(1<<(n-(__builtin_popcount(i))));
        }
        for(int i=0;i<(1<<n);i++)
            F[i]=A[i];
        for(int i = 0;i < n; ++i) 
            for(int mask = 0; mask < (1<<n); ++mask)
            {
                if(mask & (1<<i))
                {
                    F[mask] += F[mask^(1<<i)];
                }
            }
        for(int i=0;i<(1<<n);i++)
        {
            if(s[i]=='1')
                continue;
            ans+=(F[i]/((1<<(n-(__builtin_popcount(i))))));
        }
    }
    cout<<ans<<'\n';
}
int main()
{
    #ifndef ONLINE_JUDGE
    freopen("input.txt", "r", stdin);
    freopen("output.txt", "w", stdout);
    #endif
    ios_base::sync_with_stdio(false);
    cin.tie(NULL),cout.tie(NULL);
    int T=readInt(1,1000,'\n');
    for(int i=1;i<=20;i++)
        good[(1<<i)]=1;
    while(T--)
        solve();
    assert(getchar()==-1);
    cerr << "Time : " << 1000 * ((double)clock()) / (double)CLOCKS_PER_SEC << "ms\n";
}
Tester's Solution
#include <bits/stdc++.h>
using namespace std;
 
 
/*
------------------------Input Checker----------------------------------
*/
 
long long readInt(long long l,long long r,char endd){
    long long x=0;
    int cnt=0;
    int fi=-1;
    bool is_neg=false;
    while(true){
        char g=getchar();
        if(g=='-'){
            assert(fi==-1);
            is_neg=true;
            continue;
        }
        if('0'<=g && g<='9'){
            x*=10;
            x+=g-'0';
            if(cnt==0){
                fi=g-'0';
            }
            cnt++;
            assert(fi!=0 || cnt==1);
            assert(fi!=0 || is_neg==false);
 
            assert(!(cnt>19 || ( cnt==19 && fi>1) ));
        } else if(g==endd){
            if(is_neg){
                x= -x;
            }
 
            if(!(l <= x && x <= r))
            {
                cerr << l << ' ' << r << ' ' << x << '\n';
                assert(1 == 0);
            }
 
            return x;
        } else {
            assert(false);
        }
    }
}
string readString(int l,int r,char endd){
    string ret="";
    int cnt=0;
    while(true){
        char g=getchar();
        assert(g!=-1);
        if(g==endd){
            break;
        }
        cnt++;
        ret+=g;
    }
    assert(l<=cnt && cnt<=r);
    return ret;
}
long long readIntSp(long long l,long long r){
    return readInt(l,r,' ');
}
long long readIntLn(long long l,long long r){
    return readInt(l,r,'\n');
}
string readStringLn(int l,int r){
    return readString(l,r,'\n');
}
string readStringSp(int l,int r){
    return readString(l,r,' ');
}
 
 
/*
------------------------Main code starts here----------------------------------
*/
 
const int MAX_T = 1e5;
const int MAX_N = 1e5;
const int MAX_SUM_LEN = 1e5;
 
#define fast ios_base::sync_with_stdio(0); cin.tie(0); cout.tie(0)
#define ff first
#define ss second
#define mp make_pair
#define ll long long
#define rep(i,n) for(int i=0;i<n;i++)
#define rev(i,n) for(int i=n;i>=0;i--)
#define rep_a(i,a,n) for(int i=a;i<n;i++)
 
int sum_len = 0;
int max_n = 0;
int yess = 0;
int nos = 0;
int total_ops = 0;

const ll mod = 998244353;

ll po(ll x, ll n ){ 
    ll ans=1;
     while(n>0){
        if(n&1) ans=(ans*x)%mod;
        x=(x*x)%mod;
        n/=2;
     }
     return ans;
}

ll fun(vector<ll> &v, vector<int>&is_on, int n){

    rev(i, 19){
        rev(j,n-1){
            if((j>>i)&1) v[j^(1<<i)] += v[j];
        }
    }

    ll ret = 0;

    rep(i,n){
        if(!is_on[i]) continue;
        int tmp = __builtin_popcount(i);
        ll div = (1<<tmp);

        ret += v[i]/div;
    }

    return ret;
}

void solve()
{   

    int n = readIntLn(2, 1<<20);
    sum_len += n;
    max_n = max(max_n, n);
    string s = readStringLn(n,n);

    assert(__builtin_popcount(n)==1);

    vector<ll> v(n);
    vector<int> z(n,0);

    rep(i,n){
        if(s[i]=='0'){
            int tmp = __builtin_popcount(i);
            v[i] = (1<<tmp);
            z[i] = 1;
        }
        else v[i] = 0;
    }

    ll ans = fun(v, z, n);

    z.assign(n,0);

    rep(i,n){
        if(s[i]=='1'){
            int tmp = __builtin_popcount(i);
            v[i] = (1<<tmp);
            z[i] = 1;
        }
        else v[i] = 0;
    }

    ans += fun(v,z,n);

    cout<<ans<<'\n';
}
 
signed main()
{

    #ifndef ONLINE_JUDGE
    freopen("input.txt", "r" , stdin);
    freopen("output.txt", "w" , stdout);
    #endif
    fast;
    
    int t = 1;
    
    t = readIntLn(1,1000);
        
    for(int i=1;i<=t;i++)
    {    
       solve();
    }
   
    assert(getchar() == -1);
    assert(sum_len<=(1<<20));
 
    cerr<<"SUCCESS\n";
    cerr<<"Tests : " << t << '\n';
    cerr<<"Sum of lengths : " << sum_len << '\n';
    cerr<<"Maximum length : " << max_n << '\n';
    // cerr<<"Total operations : " << total_ops << '\n';
    //cerr<<"Answered yes : " << yess << '\n';
    //cerr<<"Answered no : " << nos << '\n';
}

Editorialist's Solution
#include<bits/stdc++.h>
using namespace std;
#define ll long long
#define dd double
#define endl "\n"
#define pb push_back
#define all(v) v.begin(),v.end()
#define mp make_pair
#define fi first
#define se second
#define vll vector<ll>
#define pll pair<ll,ll>
#define fo(i,n) for(int i=0;i<n;i++)
#define fo1(i,n) for(int i=1;i<=n;i++)
ll mod=1000000007;
ll n,k,t,m,q,flag=0;
ll power(ll a,ll b) {ll res=1;a%=mod; assert(b>=0); for(;b;b>>=1){if(b&1)res=res*a%mod;a=a*a%mod;}return res;}
// #include <ext/pb_ds/assoc_container.hpp> 
// #include <ext/pb_ds/tree_policy.hpp> 
// using namespace __gnu_pbds; 
// #define ordered_set tree<int, null_type,less<int>, rb_tree_tag,tree_order_statistics_node_update>
// ordered_set s ; s.order_of_key(a) -- no. of elements strictly less than a
// s.find_by_order(i) -- itertor to ith element (0 indexed)
ll min(ll a,ll b){if(a>b)return b;else return a;}
ll max(ll a,ll b){if(a>b)return a;else return b;}
ll gcd(ll a , ll b){ if(b > a) return gcd(b , a) ; if(b == 0) return a ; return gcd(b , a%b) ;}
int main() {
    ios_base::sync_with_stdio(0);
    cin.tie(0); cout.tie(0);
    #ifdef NOOBxCODER
    freopen("input.txt", "r", stdin);
    freopen("output.txt", "w", stdout);
    #else 
    #define NOOBxCODER 0
    #endif
    cin>>t;
    //t=1;
    while(t--){
        cin>>n;
        string s;
        vector<int> a(n), b(n);  // vectors instead of stack VLAs: n can be 2^20
        cin>>s;
        fo(i,n){a[i] =s[i]-'0'; b[n - 1 -i]= s[i]-'0';  }
        
        //ll dp1[n][21];dp2[n][21];
        //fo(i,n)cout<<a[i]; cout<<endl; fo(i,n)cout<<b[i];cout<<endl;
        
        vector<ll> f1(n), f2(n);
        ll m= log2(n);
        
        for(int i = 0; i<n ; ++i)
        	f1[i] = a[i];// f2[i]
        for(int i = 0;i < m; ++i) for(int mask = 0; mask < n; ++mask){
	        if(mask & (1<<i))
	    	f1[mask] += f1[mask^(1<<i)];
        }
        for(int i = 0; i<n ; ++i)
        	f2[i] = b[i];
        for(int i = 0;i < m; ++i) for(int mask = 0; mask < n; ++mask){
	        if(mask & (1<<i))
	    	f2[mask] += f2[mask^(1<<i)];
        }
        
        ll ans=0;
        
        for(int i=0;i<n;i++){
            int c = __builtin_popcount(i);
            //cout<<f1[i]<<" "<<f2[i]<<endl;
            ans+= (f1[i]*f2[n-1-i]) + ((ll)(1<<c ) - f1[i] ) *((ll)(1<<(m-c))  -f2[n-1-i]);
        }
        cout<<ans<<endl;
    }
    cerr << "Time : " << 1000 * ((double)clock()) / (double)CLOCKS_PER_SEC << "ms\n";
    return 0;
}

Can you elaborate on the implementation of counting of supermasks and submasks for 0 values? What is the use of array b in this code? f1, f2 corresponds to which variables in the editorial?


Suppose the number of 1’s is c1 and the number of 0’s is c2.
Now we have to choose two indices i and j such that s[i] | s[j] = s[i] & s[j].
For indices which are 1:
~ For two distinct indices i and j,
their contribution should be c1C2 * 2!
(the multiplication by 2! is because the pair (i, j) is ordered)
~ For i == j,
there are simply c1 such indices.
So the total contribution for 1’s should be (c1C2 * 2!) + c1.
The contribution of 0’s is calculated in the same manner, and the two results are added since both are independent problems.
Can anyone help me figure out what I am doing wrong?

Why isn’t the correct answer the number of pairs of 1’s plus the number of pairs of 0’s? For instance, in the third problem example, there seem to me to be the following 40 pairs - (0,0),(0,1),(0,2),(0,4),(0,5),(0,6),(1,1),(1,0),(1,2),(1,4),(1,5),(1,6),(2,2),(2,0),(2,1),(2,4),(2,5),(2,6),(3,3),(3,7),(4,4),(4,0),(4,1),(4,2),(4,5),(4,6),(5,5),(5,0),(5,1),(5,2),(5,4),(5,6),(6,6),(6,0),(6,1),(6,2),(6,4),(6,5),(7,7),(7,1).

The answer listed for example 3 is 32. Any assistance in helping me understand the problem would be appreciated.


I think my solution is a bit more intuitive. Here’s my idea:

Instead of fixing i or j, we fix the value of i|j. Note that i\&j is a submask of i|j. We’ll solve the problem for S_{i\&j}=S_{i|j}=1, then similarly compute the answer for S_{i\&j}=S_{i|j}=0; the final answer is just the sum of both results. Before moving forward, we need to solve the following subproblem.

Given x and y such that y \subseteq x, find the number of ordered pairs (i,j) such that x=i|j and y=i\&j.

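It’s easy to prove that the number of ordered pairs (i,j) for this subproblem is 2^{B(x \oplus y)}, where B(x \oplus y) is the number of set bits in x \oplus y. Consider each bit. Where x and y are both 1, both i and j must have the bit set. Where both are 0, neither does. Where x is 1 but y is 0, exactly one of i and j has the bit set, which gives 2 independent choices per such bit.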
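The closed form can be checked by brute force for small bit widths. Here is a sketch; both function names are ours.

```cpp
#include <cassert>

// Number of ordered pairs (i, j) in [0, 2^n)^2 with i|j == x and i&j == y,
// counted by brute force.
long long countOrAndPairs(int n, int x, int y) {
    long long cnt = 0;
    for (int i = 0; i < (1 << n); ++i)
        for (int j = 0; j < (1 << n); ++j)
            if ((i | j) == x && (i & j) == y) ++cnt;
    return cnt;
}

// Checks the closed form 2^popcount(x ^ y) for every submask y of every x.
bool closedFormHolds(int n) {
    for (int x = 0; x < (1 << n); ++x)
        for (int y = 0; y < (1 << n); ++y) {
            if ((y & ~x) != 0) continue;  // y must be a submask of x
            if (countOrAndPairs(n, x, y) != (1LL << __builtin_popcount(x ^ y)))
                return false;
        }
    return true;
}
```

For example, x = 3 and y = 0 give the 4 pairs (0,3), (3,0), (1,2), (2,1).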
Now we have an \mathcal{O}(3^{\log_2N}) solution. We iterate over all indices from 0 to N-1, and each index is the OR value that we fixed. For each index we iterate over its submasks, which are the possible AND values for the currently fixed OR value, and add each submask's contribution to the final answer. Slow Solution Code
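Here is a sketch of this slow approach. It is written from the description above rather than taken from the linked code, and the function name is hypothetical. It handles both characters at once by checking S_x = S_y directly, instead of solving for 1 and 0 separately. The standard loop y = (y - 1) \& x stops before y = 0, so the y = 0 term is added after the loop.

```cpp
#include <cassert>
#include <string>

// O(3^log2(N)) version of the fixed-OR idea: x = i|j, y = i&j ranges over
// the submasks of x, and each (x, y) accounts for 2^popcount(x ^ y) ordered pairs.
long long countPairsFixedOr(const std::string& s) {
    const int N = (int)s.size();  // assumed to be a power of 2
    long long ans = 0;
    for (int x = 0; x < N; ++x) {
        // Standard submask loop over y = x, ..., stopping before y == 0.
        for (int y = x; y > 0; y = (y - 1) & x)
            if (s[x] == s[y]) ans += 1LL << __builtin_popcount(x ^ y);
        // Submask y = 0 handled separately: x ^ 0 == x.
        ans += (1LL << __builtin_popcount(x)) * (s[x] == s[0]);
    }
    return ans;
}
```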

So we can speed up this solution significantly using an augmented SOS DP.
The DP state is the following:

\text{dp}[i][j] \rightarrow total number of pairs (x,y) such that x|y=i, S_{x\&y}=1, and x\&y is a submask of i that differs from i only in the lowest j bits (bits 0 to j-1), i.e. only those bits are allowed to go to zero.
Base case: \text{dp}[i][0]=[S_i=1]\ \forall \ 0 \le i < N
The transitions look like this:

\text{dp}[i][j] = \begin{cases} \text{dp}[i][j-1] & \text{bit $2^{j-1}$ of $i$ is unset}\\ \text{dp}[i][j-1]+2 \times \text{dp}[i\oplus2^{j-1}][j-1] & \text{otherwise}\\ \end{cases}

Note that when bit 2^{j-1} of i is set, we add 2 \times \text{dp}[i\oplus2^{j-1}][j-1]. The extra factor of 2 compensates for the fact that i and i\oplus 2^{j-1} differ at that bit. For every AND value y counted in \text{dp}[i\oplus2^{j-1}][j-1], B(i \oplus y) is one larger than B((i\oplus 2^{j-1}) \oplus y), so the number of pairs doubles.

The final answer for character 1 is just \displaystyle\sum_{i=0}^{N-1}\text{dp}[i][\log_2N] \times [S_i=1]. The answer for character 0 is computed the same way, and the two are added.
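Here is a minimal sketch of this DP, assuming S is a std::string of power-of-2 length; the function names are hypothetical. It keeps a single array and updates it in place, one bit at a time, from bit 0 upward. During the pass for a given bit, only entries with that bit set change, and each reads an entry with the bit unset. That entry still holds the previous layer's value. The helper is called once for character '1' and once for '0'.

```cpp
#include <cassert>
#include <string>
#include <vector>

// After the passes over bits 0..b-1, dp[i] equals the sum, over submasks y of i
// that differ from i only in bits 0..b-1 and satisfy s[y] == c, of 2^popcount(i ^ y).
long long countForChar(const std::string& s, char c) {
    const int N = (int)s.size();  // assumed to be a power of 2
    std::vector<long long> dp(N);
    for (int i = 0; i < N; ++i) dp[i] = (s[i] == c);  // base case: only y = i
    for (int b = 0; (1 << b) < N; ++b)
        for (int i = 0; i < N; ++i)
            if (i & (1 << b)) dp[i] += 2 * dp[i ^ (1 << b)];  // in-place layer update
    long long total = 0;
    for (int i = 0; i < N; ++i)
        if (s[i] == c) total += dp[i];  // the OR value i must also satisfy s[i] == c
    return total;
}

long long countPairsAugmentedSOS(const std::string& s) {
    return countForChar(s, '1') + countForChar(s, '0');
}
```

For S = 1011 the '1' pass gives 11 and the '0' pass gives 1, for a total of 12.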

CODE FOR REFERENCE


Agree. The answer should be N_0^2 + N_1^2, where N_0 is the number of zeros and N_1 is the number of ones. But this algorithm gets a WA verdict.

https://www.codechef.com/viewsolution/59170300

Consider this test case:

1
4
1011

Correct answer is 12 but your algorithm gives 10. @rphenson

Here are the possible 10 pairs I could find for this test case
(0, 0), (0,2), (0,3), (1,1), (2,0), (2,2), (2,3), (3,0), (3,2), (3,3)

which two am I missing?

Here is a brute-force O(N^2) solution that produces a result of 10 for this case.
https://www.codechef.com/viewsolution/59177894

Your brute force solution is incorrect.

Correct Brute Force Solution
#include"bits/stdc++.h"
using namespace std ;

void solve(){
  int n,ans=0;string s  ;
  cin >> n >> s  ;
  for(int i=0;i<n;i++)
    for(int j=0;j<n;j++)
      ans+=(s[(i|j)]==s[(i&j)])  ;
  cout << ans << '\n' ;
}

int main(){
  int T  ;cin >> T  ;
  while(T--)
    solve() ;
}
These are the valid pairs
[i,j] = [0, 0]
[i,j] = [0, 2]
[i,j] = [0, 3]
[i,j] = [1, 1]
[i,j] = [1, 2]
[i,j] = [2, 0]
[i,j] = [2, 1]
[i,j] = [2, 2]
[i,j] = [2, 3]
[i,j] = [3, 0]
[i,j] = [3, 2]
[i,j] = [3, 3]

OR, AND operations are to be performed on the indices, not on the values.

(1LL<<(__builtin_popcountll(i)))*(s[i]==s[0])
What does this thing do?

__builtin_popcount(i) returns the number of set bits in i,

(1LL<<__builtin_popcount(i)) is equivalent to 2^{\text{number of set bits in i}}.

ans+=(1LL<<(__builtin_popcountll(i)))*(s[i]==s[0]) is equivalent to doing:

if(s[i]==s[0])
       ans+=(1LL<<(__builtin_popcountll(i))) ;

Uhm… I actually got that part. I meant to ask what this is doing conceptually: why does raising 2 to the number of set bits in the XOR of the mask and its submask lead to the answer?

We’re iterating over all submasks of i and including the contribution of every submask. However, the submask enumeration loop doesn’t visit the submask 0 (which is a submask of every positive integer), so I’ve handled it separately outside the inner loop.

Note that the contribution of every submask x of a number i is 2^{\text{set bits}({x \oplus i})}, where \text{set bits}({x \oplus i}) is the number of bits in which i differs from its submask x. Since x is 0 in this case, this equals 2^{\text{set bits}({i})}.

oh, thanks:)

Thanks for pointing this out. Makes sense now.