CP03 Editorial

Problem Link

Click here

Difficulty

Medium

Solution

Let S be the set of bits that are set in x, and let S^c be the set of bits that are not set in x. The bitwise OR of a subarray arr[i...j] equals x if and only if both of the following hold: for every bit in S, at least one element of arr[i...j] has that bit set, and for every bit in S^c, no element of arr[i...j] has that bit set.
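Now we iterate over the array and maintain another array prev, where prev[k] denotes the rightmost index seen so far whose element has the k^{th} bit set (0 if there is none yet, using 1-based indexing). Let's say we are currently at index i, and let high = min(prev[k]) over k\in S and low = max(prev[k]) over k\in S^c. A subarray arr[l...i] contains every bit of S exactly when l <= high, and contains no bit of S^c exactly when l > low. Hence
arr[low+1...i], arr[low+2...i], ..., arr[high...i] are precisely the subarrays ending at i whose bitwise OR equals x, and we add (high-low) to our answer. This only applies when high > low; otherwise no subarray ending at i qualifies and nothing is added. When x = 0 the set S is empty, so high is taken to be i.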
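
The counting step described above can be sketched as a standalone function. This is only an illustration under stated assumptions: the name countSubarraysWithOr and its signature are hypothetical, the input vector is 0-based while positions stored in prev are 1-based, and all values are assumed to be non-negative and to fit in 31 bits.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Hypothetical helper (not part of the original solution).
// Counts subarrays of `a` whose bitwise OR equals `x`.
// Assumption: x and all elements of `a` are non-negative 31-bit values.
long long countSubarraysWithOr(const std::vector<int>& a, int x) {
    assert(x >= 0);
    const int BITS = 31;
    // prev[k]: rightmost 1-based index seen so far with bit k set, 0 if none.
    std::vector<int> prev(BITS, 0);
    long long answer = 0;
    for (int i = 1; i <= static_cast<int>(a.size()); ++i) {
        const int value = a[i - 1];
        for (int k = 0; k < BITS; ++k) {
            if ((value >> k) & 1) prev[k] = i;
        }
        int high = i;  // min of prev[k] over bits in S; stays i when S is empty
        int low = 0;   // max of prev[k] over bits in S^c
        for (int k = 0; k < BITS; ++k) {
            if ((x >> k) & 1) {
                high = std::min(high, prev[k]);
            } else {
                low = std::max(low, prev[k]);
            }
        }
        // Valid left endpoints for subarrays ending at i are low+1 ... high.
        answer += std::max(0, high - low);
    }
    return answer;
}
```

For arr = [1, 2, 3] and x = 3, the contributions at i = 1, 2, 3 are 0, 1 and 3, for a total of 4: the subarrays [1, 2], [3], [2, 3] and [1, 2, 3]. Initializing high to i means the x = 0 case needs no separate branch in this sketch.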

Code

#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
int main()
{
    ios_base::sync_with_stdio(false);
    cin.tie(NULL);
    cout.tie(NULL);
    int T;
    cin>>T;
    while(T--)
    {
        int n,x;
        cin>>n>>x;
        vector<int>arr(n+1);
        for(int i=1; i<=n; i++)
         cin>>arr[i];
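        if(x==0){
            // The OR is 0 only for subarrays made entirely of zeros:
            // a maximal run of cnt zeros contains cnt*(cnt+1)/2 such subarrays.
            ll cnt=0;
            ll ans=0;
            for(int i=1; i<=n; i++){
                if(arr[i]==0){
                    cnt++;
                }
                else{
                    ans=ans+cnt*(cnt+1)/2;
                    cnt=0;
                }
            }
            // account for a run of zeros that reaches the end of the array
            ans=ans+cnt*(cnt+1)/2;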
            cout<<ans<<"\n";
        }
        else{
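            ll ans=0;
            // av[j]=1 if bit j is set in x, i.e. bit j belongs to S.
            // Only the lowest 10 bits are examined, so all values are assumed to be below 1024.
            vector<int>av(10);
            int pr=1;
            for(int j=0; j<10; j++){
                if(x&pr)
                 av[j]=1;
                pr*=2;
            }
            // pos[j]: rightmost index seen so far with bit j set (0 if none); this is prev in the explanation.
            vector<int>pos(10);
            for(int i=1; i<=n; i++){
                int pr=1;
                for(int j=0; j<10; j++){
                    if(arr[i]&pr)
                     pos[j]=i;
                    pr*=2;
                }
                // mx1 is high (min of pos over bits in S), mx2 is low (max of pos over bits in S^c)
                int mx1=i,mx2=0;
                for(int j=0; j<10; j++){
                    if(av[j])
                     mx1=min(mx1,pos[j]);
                    else
                     mx2=max(mx2,pos[j]);
                }
                // valid left endpoints for subarrays ending at i are mx2+1 ... mx1
                ans=ans+max(0,mx1-mx2);
            }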
            cout<<ans<<"\n";
        }
    }
    return 0;
}
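
To sanity-check a solution like this on small inputs, a quadratic brute force that applies the definition directly can be useful. The sketch below is not part of the original solution: the name bruteForceCount is hypothetical, the input vector is 0-based, and x and all values are assumed to be non-negative.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical reference implementation for testing only.
// Tries every left endpoint and extends the subarray to the right,
// keeping a running OR. Assumption: x and all elements are non-negative.
long long bruteForceCount(const std::vector<int>& a, int x) {
    assert(x >= 0);
    long long count = 0;
    for (std::size_t i = 0; i < a.size(); ++i) {
        int running = 0;
        for (std::size_t j = i; j < a.size(); ++j) {
            running |= a[j];
            // A bit outside x can never be cleared by extending further.
            if (running & ~x) break;
            if (running == x) ++count;
        }
    }
    return count;
}
```

Comparing its result with the program's output on many small random arrays is a cheap way to catch off-by-one mistakes in the high and low bookkeeping.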