CAC_XORSUBS - Editorial

PROBLEM LINK:

Practice
Save the humanity

Authors: chef_hamster
Testers: chef_hamster
Editorialist: chef_hamster

Difficulty

Hard

PROBLEM:

Given a binary string $\mathbf{S}$ and a target binary string $\mathbf{Q}$, a subsequence $\mathbf{K_s}$ of $\mathbf{S}$ is called *good* if:

  • The XOR of all possible non-empty subsequences of $\mathbf{K_s}$ is equal to the given $\mathbf{Q}$.

Find the number of good subsequences of $\mathbf{S}$.

Since the answer can be very large, print it modulo $10^9+7$.

Prerequisites:

  • Number theory
  • Combinatorics
  • Basic knowledge of XOR
  • Patience (very important :laughing: )

Hint:

1st Hint

Only the MSB contributes to the XOR of all possible subsequences (every other bit appears an even number of times at each position).

2nd Hint

For any binary string of length $n+1$ whose MSB is 1, let $X_b$ be the XOR of all its possible subsequences. Then every bit of $X_b$ satisfies $b[i] = \binom{n}{i} \bmod 2$.

Proofs

Solution in C++
#include <bits/stdc++.h>
using namespace std;
// Shorthand for the 64-bit integer type used throughout the solution.
typedef long long ll;
// Plain "\n" instead of std::endl, so output is not flushed on every line.
#define endl "\n"
// Fast, unsynchronised C++ stream I/O.
#define fio ios_base::sync_with_stdio(false);cin.tie(NULL);cout.tie(NULL);
// Template helpers: input, fr, rf, mod2, yy and nn are not used by this solution.
#define input(arr,n) for(int i=0;i<n;i++) cin>>arr[i];
#define fr(i,n) for(int i=0;i<n;i++)
#define rf(i,n) for(int i=n-1;i>=0;i--)
#define mod2 (ll)998244353
// Modulus required by the problem statement (10^9 + 7).
#define mod (ll)1000000007
#define yy "YES\n"
#define nn "NO\n"
// Integer exponentiation by squaring with no modulus: returns a^b.
// NOTE: the result overflows `ll` for large a or b; use power() for modular work.
ll binexp(ll a, ll b) {
    ll result = 1;
    for (; b > 0; b >>= 1) {
        if (b & 1) result *= a;
        a *= a;
    }
    return result;
}
// Modular exponentiation: returns x^y mod p using square-and-multiply.
// Quirk preserved: when x is a multiple of p the result is 0, even for y == 0.
ll power(ll x, ll y, ll p) {
    x %= p;
    if (x == 0) return 0;
    ll result = 1;
    for (; y > 0; y >>= 1) {
        if (y & 1) result = (result * x) % p;
        x = (x * x) % p;
    }
    return result;
}
// Adds the binary digits of n into arr, least significant digit at arr[31],
// the next at arr[30], and so on (assumes arr has at least 32 entries).
// Digits beyond the 32nd are ignored rather than written before arr[0],
// which the previous version did for n >= 2^32.
void to_bin(ll n, ll arr[]) {
    for (ll pos = 31; n != 0 && pos >= 0; --pos) {
        arr[pos] += n % 2;
        n /= 2;
    }
}
int to_deci(ll arr[]){ll ans=0;for(ll i=0;i<32;i++){if(arr[i]%2)ans+=ll(1<<(31-i));}return ans;}
// Smaller of two ll values (returns x on ties).
ll min(ll x, ll y) {
    if (y < x) return y;
    return x;
}
// Larger of two ll values (returns x on ties).
ll max(ll x, ll y) {
    if (y > x) return y;
    return x;
}
// Greatest common divisor by Euclid's algorithm (iterative form).
ll gcd(ll a, ll b) {
    while (b != 0) {
        ll remainder = a % b;
        a = b;
        b = remainder;
    }
    return a;
}
// Modular inverse of A modulo M via the iterative extended Euclidean algorithm.
// Returns 0 when M == 1. The result is shifted into [0, M) if negative.
// NOTE(review): assumes gcd(A, M) == 1; otherwise M reaches 0 and A % M divides by zero.
ll modInverse(ll A, ll M) {
    const ll originalM = M;
    if (M == 1) return 0;
    ll prevCoef = 0, coef = 1;
    while (A > 1) {
        ll quotient = A / M;
        ll divisor = M;
        M = A % M;
        A = divisor;
        ll saved = prevCoef;
        prevCoef = coef - quotient * prevCoef;
        coef = saved;
    }
    if (coef < 0) coef += originalM;
    return coef;
}
/*------------------------------------------------------------*/


// Size of the precomputed table below; every index passed to isSet() must be < MAXN.
#define MAXN 1000000


/*--powerOfTwo[i] = exponent of the highest power of two dividing i!, for 0 <= i < MAXN (filled by P2(); starts zeroed)--*/
vector<ll> powerOfTwo(MAXN,0);

// Fills powerOfTwo[i] with the exponent of 2 in i! for 2 <= i < MAXN;
// entries 0 and 1 keep their initial value 0 (0! = 1! = 1).
// Each step adds the number of trailing zero bits of i. __builtin_ctz keeps
// this in exact integer arithmetic instead of truncating a floating-point
// log2() result into an integer.
void P2(){
    for (int i = 2; i < MAXN; i++) {
        powerOfTwo[i] = powerOfTwo[i - 1] + __builtin_ctz((unsigned)i);
    }
}
/*--------------------------------------------------------*/


// True iff C(n, i) is odd: the powers of two in n! are exactly cancelled
// by those in i! and (n-i)!.
bool isSet(ll i, ll n){
    ll exponentOfTwo = powerOfTwo[n] - powerOfTwo[i] - powerOfTwo[n - i];
    return exponentOfTwo == 0;
}

/*--------------------------------------------------------*/

// Stores C(i, r) mod `mod` into ncr[i] for every r <= i <= n, using
// C(i, r) = C(i-1, r) * i / (i - r) with the division done by modular inverse.
void fillNcr(map<ll,ll> &ncr, ll n, ll r){
    ll current = 1;   // C(r, r)
    ncr[r] = current;
    for (ll i = r + 1; i <= n; i++) {
        current = current % mod * (i % mod) % mod * modInverse(i - r, mod) % mod;
        ncr[i] = current;
    }
}

/*--------------------------------------------------------*/
int main() {
    #ifndef ONLINE_JUDGE
    freopen("inputE0.txt","r",stdin);
    freopen("outputE0.txt","w",stdout);
    #endif
    fio;
    P2();
    int t=1;
    cin>>t;
    while(t--){
        ll n,m;
        cin>>n>>m;
        string s,q;
        cin>>s>>q;
        string ans = "1";
        for(int i=1;i<m;i++){
            ans += isSet(i,m-1)==true?"1":"0";
        }
        if(ans!=q){
            cout<<0<<endl;
        }
        else{
            ll ans = 0;
            map<ll,ll> ncr;
            fillNcr(ncr,n,m-1);
            for(int i=0;i<=n-m;i++){
                if(s[i]=='1'){
                    ans = (ans%mod + ncr[n-i-1]%mod)%mod;
                }
            }
            cout<<ans<<endl;
        }
    }
    return 0;
}
Solution in Python
# Size of the two precomputed tables below.
MAX_N = 10**6
mod = 10**9 + 7
import math

# dp[k] holds C(k, r) mod `mod` for the r of the current query; refilled by solve().
dp = [0] * MAX_N

# powerOfTwo[k] is the exponent of 2 in k!; filled once by powerOfTwoF().
powerOfTwo = [0] * MAX_N


def modInverse(A, M):
    """Return the modular inverse of A modulo the prime M.

    Uses Fermat's little theorem (A^(M-2) mod M) through the built-in
    three-argument ``pow``, which is much faster than a recursive Python power
    function when called once per table entry (see solve()).

    Raises:
        ValueError: if A and M are not coprime, so no inverse exists. The
            previous version printed to stdout (corrupting the judged output)
            and returned None.
    """
    if math.gcd(A, M) != 1:
        raise ValueError("Inverse doesn't exist")
    return pow(A, M - 2, M)
 
 
 
def power(x, y, M):
    """Return x**y modulo M by recursive squaring (x**0 is reported as 1)."""
    if y == 0:
        return 1
    half = power(x, y // 2, M) % M
    squared = half * half % M
    return x * squared % M if y % 2 else squared
 
 
 
def gcd(a, b):
    """Greatest common divisor via Euclid's algorithm (iterative form)."""
    while a != 0:
        a, b = b % a, a
    return b
 
 


def powerOfTwoF():
    """Fill powerOfTwo[i] with the exponent of 2 in i! for 2 <= i < MAX_N.

    ``i & -i`` isolates the lowest set bit of i; its ``bit_length() - 1`` is
    the number of trailing zero bits. Integer arithmetic is exact, which the
    previous float ``math.log2`` followed by ``int()`` truncation was not
    guaranteed to be.
    """
    for i in range(2, MAX_N):
        powerOfTwo[i] = powerOfTwo[i - 1] + (i & -i).bit_length() - 1


def isSet(n, i):
    """Return True iff C(n, i) is odd (the 2s in n! all cancel against i!(n-i)!)."""
    return powerOfTwo[n] == powerOfTwo[i] + powerOfTwo[n - i]
def solve(n, r):
    """Fill dp[i] = C(i, r) mod `mod` for every r <= i <= n.

    Uses C(i, r) = C(i-1, r) * i / (i - r), dividing through a modular
    inverse. The shared dp table is grown in place when n (or r) does not fit,
    so n >= MAX_N no longer raises IndexError.
    """
    needed = max(n, r) + 1
    if needed > len(dp):
        dp.extend([0] * (needed - len(dp)))
    dp[r] = 1
    for i in range(r + 1, n + 1):
        dp[i] = dp[i - 1] * (i % mod) % mod * modInverse(i - r, mod) % mod

def main():
    """Read every test case from stdin and print one answer per case.

    Q is compared with the pattern "1" followed by C(m-1, i) mod 2 for
    i = 1..m-1. On a mismatch 0 is printed. Otherwise the answer is the sum,
    modulo `mod`, of C(n-i-1, m-1) over every position i <= n-m where S
    has a '1'.
    """
    powerOfTwoF()
    t = int(input())
    for _ in range(t):
        n, m = map(int, input().split())
        # strip() so trailing spaces or '\r' don't break the comparison with Q
        # (the C++ and Java versions read whitespace-delimited tokens).
        s = input().strip()
        q = input().strip()
        possible = "1" + "".join("1" if isSet(m - 1, i) else "0" for i in range(1, m))
        if possible != q:
            print(0)
        else:
            solve(n, m - 1)
            ans = 0
            for i in range(n - m + 1):
                if s[i] == "1":
                    ans = (ans + dp[n - i - 1]) % mod
            print(ans)

# Run only when executed as a script, so the helpers above can be imported
# (e.g. for testing) without blocking on stdin.
if __name__ == "__main__":
    main()
Solution in Java
import com.sun.jdi.IntegerValue;

import java.util.*;
import java.lang.*;
import java.io.*;
import java.util.concurrent.atomic.AtomicReferenceFieldUpdater;

/**
 * Java solution for CAC_XORSUBS.
 *
 * <p>For each test case it builds the pattern "1" followed by C(m-1, i) mod 2 for i = 1..m-1
 * and compares it with Q. If they are equal, it prints the sum, modulo 1e9+7, of C(n-i-1, m-1)
 * over every position i &lt;= n-m where S has a '1'. Otherwise it prints 0.
 *
 * <p>NOTE(review): the file also imports {@code com.sun.jdi.IntegerValue} and
 * {@code AtomicReferenceFieldUpdater}, neither of which is used. The former lives in the
 * {@code jdk.jdi} module and may not be available on every judge JDK.
 */
public class Main
{
    static PrintWriter out = new PrintWriter(new BufferedOutputStream(System.out));
    static FastReader sc = new FastReader();

    static final long mod = 1_000_000_007L;
    static final long mod2 = 998244353;

    /** Simple int pair ordered by its second component {@code b}. */
    static class Pair implements Comparable<Pair>{
        int a, b;
        Pair(int a, int b){
            this.a=a;
            this.b=b;
        }
        @Override
        public int compareTo(Pair o){
            // Integer.compare avoids the overflow that "this.b - o.b" suffers for far-apart values.
            return Integer.compare(this.b, o.b);
        }
    }

    /**
     * Stores C(i, r) mod {@link #mod} into {@code ncr} for every r &lt;= i &lt;= n, using
     * C(i, r) = C(i-1, r) * i / (i - r) with the division done through a modular inverse.
     */
    static void fillNcr(HashMap<Long, Long> ncr, long n, long r) {
        long current = 1L; // C(r, r)
        ncr.put(r, current);
        for (long i = r + 1L; i <= n; i++) {
            current = (current * (i % mod)) % mod;
            current = (current * modInverse(i - r, mod)) % mod;
            ncr.put(i, current);
        }
    }

    /**
     * Returns the inverse of A modulo the prime M via Fermat's little theorem,
     * or -1 when gcd(A, M) != 1.
     */
    static long modInverse(long A, long M)
    {
        if (gcd(A, M) != 1) {
            return -1;
        }
        return power(A, M - 2, M);
    }

    /** Computes x^y mod M by recursive squaring (x^0 is reported as 1). */
    static long power(long x, long y, long M)
    {
        if (y == 0)
            return 1L;

        long p = power(x, y / 2, M) % M;
        p = (p * p) % M;

        return (y % 2 == 0) ? p : (x * p) % M;
    }

    /** Greatest common divisor via Euclid's algorithm. */
    static long gcd(long a, long b)
    {
        if (a == 0)
            return b;
        return gcd(b % a, a);
    }

    /** po2[k] = exponent of 2 in k!; filled in {@link #main}. */
    static int[] po2;

    /** True iff C(n, i) is odd, i.e. the powers of two in n! cancel exactly against i!(n-i)!. */
    static boolean isSet(int i, int n){
        return ((long)po2[n] - (long)po2[i] - (long)po2[n-i])==0L;
    }

    public static void main (String[] args) throws java.lang.Exception {
        po2 = new int[1000000];
        for (int i = 2; i < po2.length; ++i) {
            po2[i] = po2[i - 1] + Integer.numberOfTrailingZeros(i);
        }

        int t = sc.nextInt();
        while (t-- > 0) {
            solve();
        }
    }

    /** Reads one test case (n, m, S, Q) and prints its answer. */
    public static void solve() {
        int n = i(), m = i();
        String s = s(), q = s();
        // Expected XOR pattern: a leading '1', then bit i = C(m-1, i) mod 2.
        StringBuilder sb = new StringBuilder("1");
        for (int i = 1; i < m; ++i) {
            sb.append(isSet(i, m - 1) ? '1' : '0');
        }

        if (sb.toString().equals(q)) {
            HashMap<Long, Long> hm = new HashMap<>();
            fillNcr(hm, n, m - 1);
            long res = 0L;
            // A '1' at position i (i <= n-m) starts C(n-i-1, m-1) subsequences of length m.
            for (int i = 0; i <= n - m; ++i) {
                if (s.charAt(i) == '1') {
                    res = (res + hm.get((long) n - i - 1)) % mod;
                }
            }
            out.println(res);
        } else {
            out.println(0);
        }
        out.flush();
    }

    static int i() {
        return sc.nextInt();
    }
    static String s() {
        return sc.next();
    }
    static long l() {
        return sc.nextLong();
    }
    static int[] ia(int n){
        int[] arr= new int[n];
        for(int i = 0;i<n;++i){
            arr[i] = i();
        }
        return arr;
    }

    /** Whitespace-token reader over standard input. */
    static class FastReader {
        BufferedReader br;
        StringTokenizer st;

        public FastReader()
        {
            br = new BufferedReader(
                    new InputStreamReader(System.in));
        }

        /**
         * Returns the next whitespace-separated token.
         *
         * @throws NoSuchElementException at end of input (previously a NullPointerException)
         * @throws UncheckedIOException if reading fails (previously an endless retry loop)
         */
        String next()
        {
            while (st == null || !st.hasMoreElements()) {
                String line;
                try {
                    line = br.readLine();
                } catch (IOException e) {
                    throw new UncheckedIOException(e);
                }
                if (line == null) {
                    throw new NoSuchElementException("unexpected end of input");
                }
                st = new StringTokenizer(line);
            }
            return st.nextToken();
        }

        int nextInt() { return Integer.parseInt(next()); }

        long nextLong() { return Long.parseLong(next()); }

        double nextDouble()
        {
            return Double.parseDouble(next());
        }

        /** Returns the next raw input line (null at end of input), ignoring buffered tokens. */
        String nextLine()
        {
            try {
                return br.readLine();
            } catch (IOException e) {
                throw new UncheckedIOException(e);
            }
        }
    }
}