PROBLEM LINK:
Authors: chef_hamster
Testers: chef_hamster
Editorialist: chef_hamster
Difficulty
Hard
PROBLEM:
Given a binary string $\mathbf{S}$ and a target binary string $\mathbf{Q}$, a subsequence $\mathbf{K_s}$ of $\mathbf{S}$ is called *good* if:
- the XOR of all non-empty subsequences of $\mathbf{K_s}$ is equal to the given $\mathbf{Q}$.
Count how many good subsequences $\mathbf{S}$ has.
Since the answer can be very large, print it modulo $10^9+7$.
Prerequisites:
- Number theory
- Combinatorics
- Basic knowledge of XOR
- Patience (very important)
Hint:
1st Hint
Only the MSB contributes to the XOR of all possible subsequences (every other bit appears an even number of times at each position).
2nd Hint
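For any binary string of length $n+1$ whose MSB is 1, let $X_b$ be the XOR of all its non-empty subsequences. Then bit $i$ of $X_b$ (counting from the MSB, starting at 0) is $b[i] = \binom{n}{i} \bmod 2$. Hence, with $m = |\mathbf{Q}|$, an answer exists only if $\mathbf{Q}$ equals this pattern for $n = m-1$. In that case, the count is $\sum_{i:\,S_i = 1} \binom{|\mathbf{S}|-i-1}{m-1}$: pick the leading '1' at index $i$, then choose the other $m-1$ characters after it.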
Solution in C++
#include <bits/stdc++.h>
using namespace std;
// Short alias for the 64-bit signed integer type used throughout the solution.
typedef long long ll;
// Replaces std::endl with a plain newline, so output is not flushed on every line.
#define endl "\n"
// Fast I/O: unsync C++ streams from C stdio and untie cin/cout.
#define fio ios_base::sync_with_stdio(false);cin.tie(NULL);cout.tie(NULL);
// Reads n values into arr[0..n-1] (declares its own loop variable i).
#define input(arr,n) for(int i=0;i<n;i++) cin>>arr[i];
// Ascending / descending index loops over [0, n).
#define fr(i,n) for(int i=0;i<n;i++)
#define rf(i,n) for(int i=n-1;i>=0;i--)
// Moduli: the solution below uses `mod` (10^9 + 7); `mod2` is not referenced.
#define mod2 (ll)998244353
#define mod (ll)1000000007
// Canned YES/NO answers; not used by this problem's main().
#define yy "YES\n"
#define nn "NO\n"
// Computes a^b by binary exponentiation with NO modulus applied;
// the result can overflow ll for large a or b.
ll binexp(ll a, ll b) {
    ll result = 1;
    for (; b > 0; b >>= 1, a *= a) {
        if (b & 1) {
            result *= a;
        }
    }
    return result;
}
// Computes x^y mod p by binary exponentiation.
// Fixes over the previous version: 0^0 now yields 1 (not 0), a negative x is
// normalised into [0, p) first, and p == 1 always yields 0.
// Assumes y >= 0 and p >= 1; intermediate products fit ll while p < ~3e9.
ll power(ll x, ll y, ll p) {
    ll res = 1 % p;
    x %= p;
    if (x < 0) {
        x += p;
    }
    while (y > 0) {
        if (y & 1) {
            res = (res * x) % p;
        }
        y >>= 1;
        x = (x * x) % p;
    }
    return res;
}
// Adds the binary digits of n into arr, least significant digit at arr[31],
// moving toward arr[0]. Digits are accumulated with +=, not assigned.
// Values needing more than 32 digits would write before arr[0].
void to_bin(ll n, ll arr[]) {
    for (ll pos = 31; n; --pos, n /= 2) {
        arr[pos] += n % 2;
    }
}
// Interprets arr[0..31] as a 32-bit number with arr[0] as the most significant
// slot. A slot contributes its bit when its value is odd.
// Fix: shift a 64-bit 1 (1LL); `1 << 31` overflows int, which is undefined behaviour.
// NOTE(review): results >= 2^31 still do not fit the int return type.
int to_deci(ll arr[]) {
    ll ans = 0;
    for (ll i = 0; i < 32; i++) {
        if (arr[i] % 2) {
            ans += 1LL << (31 - i);
        }
    }
    return ans;
}
// Smaller of two ll values (returns x on ties).
ll min(ll x, ll y) {
    if (y < x) {
        return y;
    }
    return x;
}
// Larger of two ll values (returns x on ties).
ll max(ll x, ll y) {
    if (y > x) {
        return y;
    }
    return x;
}
// Greatest common divisor by the Euclidean algorithm (iterative form).
ll gcd(ll a, ll b) {
    while (b != 0) {
        ll rem = a % b;
        a = b;
        b = rem;
    }
    return a;
}
// Modular inverse of A modulo M via the iterative extended Euclidean algorithm.
// Returns 0 when M == 1; a negative coefficient is shifted up by M.
// NOTE(review): assumes gcd(A, M) == 1; otherwise M reaches 0 inside the loop
// and A / M divides by zero.
ll modInverse(ll A, ll M) {
    if (M == 1) {
        return 0;
    }
    const ll m0 = M;
    ll x = 1, y = 0;
    while (A > 1) {
        const ll q = A / M;
        const ll rem = A % M;
        A = M;
        M = rem;
        const ll nextY = x - q * y;
        x = y;
        y = nextY;
    }
    return x < 0 ? x + m0 : x;
}
/*------------------------------------------------------------*/
#define MAXN 1000000
/*--powerOfTwo[i] = exponent of the highest power of 2 dividing i!, for 0 <= i < MAXN--*/
vector<ll> powerOfTwo(MAXN,0);
void P2(){
    // Exponent of 2 in i! = exponent in (i-1)! plus the exponent of 2 in i,
    // and the latter is the number of trailing zero bits of i.
    // Entries 0 and 1 stay 0 (0! = 1! = 1).
    for (int k = 2; k < MAXN; ++k) {
        powerOfTwo[k] = powerOfTwo[k - 1] + __builtin_ctz(k);
    }
}
/*--------------------------------------------------------*/
// Returns true iff the binomial coefficient C(n, i) is odd.
// By Lucas's theorem mod 2, C(n, i) is odd exactly when every set bit of i
// is also set in n. This gives the same answer as the powerOfTwo test
// (v2(n!) - v2(i!) - v2((n-i)!) == 0) but needs no lookup table, so it no
// longer reads out of bounds when n >= MAXN.
// C(n, i) is 0 (even) for i outside [0, n].
bool isSet(ll i, ll n){
    if (i < 0 || i > n) {
        return false;
    }
    return (i & n) == i;
}
/*--------------------------------------------------------*/
// Fills ncr[i] = C(i, r) mod `mod` for every i in [r, n].
// Uses the recurrence on the upper index, C(i, r) = C(i-1, r) * i / (i - r),
// starting from C(r, r) = 1. Division is multiplication by modInverse(i - r).
// If n < r, only ncr[r] = 1 is stored.
void fillNcr(map<ll,ll> &ncr,ll n, ll r){
    ncr[r] = 1;
    for (ll top = r + 1; top <= n; ++top) {
        const ll prev = ncr[top - 1] % mod;
        const ll scaled = prev * (top % mod) % mod;
        ncr[top] = scaled * modInverse(top - r, mod) % mod;
    }
}
/*--------------------------------------------------------*/
int main() {
#ifndef ONLINE_JUDGE
freopen("inputE0.txt","r",stdin);
freopen("outputE0.txt","w",stdout);
#endif
fio;
P2();
int t=1;
cin>>t;
while(t--){
ll n,m;
cin>>n>>m;
string s,q;
cin>>s>>q;
string ans = "1";
for(int i=1;i<m;i++){
ans += isSet(i,m-1)==true?"1":"0";
}
if(ans!=q){
cout<<0<<endl;
}
else{
ll ans = 0;
map<ll,ll> ncr;
fillNcr(ncr,n,m-1);
for(int i=0;i<=n-m;i++){
if(s[i]=='1'){
ans = (ans%mod + ncr[n-i-1]%mod)%mod;
}
}
cout<<ans<<endl;
}
}
return 0;
}
Solution in Python
# Size of the precomputed tables below.
MAX_N = 1000000
# Modulus applied to every answer (10**9 + 7).
mod = 10**9 + 7
import math
# dp[i] holds C(i, r) % mod once solve(n, r) has run.
dp = [0] * MAX_N
# powerOfTwo[i] holds the exponent of 2 in i! once powerOfTwoF() has run.
powerOfTwo = [0] * MAX_N
def modInverse(A, M):
    """Return the inverse of A modulo M using Fermat's little theorem.

    The result A**(M-2) % M is only an inverse when M is prime. When A and M
    share a factor, a message is printed and None is returned.
    """
    if gcd(A, M) == 1:
        return power(A, M - 2, M)
    print("Inverse doesn't exist")
def power(x, y, M):
    """Return x**y reduced modulo M by recursive squaring.

    y must be a non-negative integer; a negative y never reaches the base case.
    For y == 0 the function returns 1 without reducing it modulo M.
    """
    if y == 0:
        return 1
    half = power(x, y // 2, M) % M
    squared = half * half % M
    return squared if y % 2 == 0 else x * squared % M
def gcd(a, b):
    """Greatest common divisor of a and b via the Euclidean algorithm (iterative)."""
    while a != 0:
        a, b = b % a, a
    return b
def powerOfTwoF():
    """Fill powerOfTwo[i] with the exponent of 2 in i! for 2 <= i < MAX_N.

    The exponent of 2 in i! equals that of (i-1)! plus the exponent of 2 in i.
    The latter is read off the lowest set bit of i (i & -i), using exact
    integer arithmetic. Entries 0 and 1 stay 0.
    """
    for i in range(2, MAX_N):
        lowest_bit = i & -i
        powerOfTwo[i] = powerOfTwo[i - 1] + lowest_bit.bit_length() - 1
def isSet(n, i):
    """Return True iff the binomial coefficient C(n, i) is odd.

    By Lucas's theorem mod 2, C(n, i) is odd exactly when every set bit of i is
    also set in n. This matches the powerOfTwo-table test (v2(n!) - v2(i!) -
    v2((n-i)!) == 0) but needs no precomputation. It therefore no longer raises
    IndexError for n >= MAX_N. C(n, i) is 0 (even) for i outside [0, n].
    """
    if i < 0 or i > n:
        return False
    return (i & n) == i
def solve(n,r):
    """Fill the global dp so that dp[i] == C(i, r) % mod for every r <= i <= n.

    Uses the recurrence C(i, r) = C(i-1, r) * i / (i - r) with C(r, r) = 1.
    The division is done through modInverse, which relies on mod being prime.
    """
    needed = max(n, r) + 1
    if needed > len(dp):
        # dp starts with MAX_N slots; grow it in place instead of raising IndexError.
        dp.extend([0] * (needed - len(dp)))
    dp[r] = 1
    for i in range(r + 1, n + 1):
        dp[i] = dp[i - 1] * i % mod * modInverse(i - r, mod) % mod
def main():
    """Read the test cases from stdin and print one answer per case.

    Per case, the input is "n m", then S, then Q. The answer is 0 unless Q
    equals the expected XOR pattern: a leading 1, then C(m-1, i) mod 2 for
    1 <= i < m. Otherwise the answer is the sum of C(n-i-1, m-1) over every
    index i with S[i] == '1' that can still start a length-m subsequence.
    """
    powerOfTwoF()
    t = int(input())
    for _ in range(t):
        n, m = map(int, input().split())
        # strip(): a trailing '\r' or space would otherwise make `possible != q` always true.
        s = input().strip()
        q = input().strip()
        possible = "1" + "".join("1" if isSet(m - 1, i) else "0" for i in range(1, m))
        if possible != q:
            print(0)
            continue
        solve(n, m - 1)
        ans = 0
        for i in range(n - m + 1):
            if s[i] == "1":
                ans = (ans + dp[n - i - 1]) % mod
        print(ans)
main()
Solution in Java
import com.sun.jdi.IntegerValue;
import java.util.*;
import java.lang.*;
import java.io.*;
import java.util.concurrent.atomic.AtomicReferenceFieldUpdater;
public class Main
{
    static PrintWriter out = new PrintWriter(new BufferedOutputStream(System.out));
    static FastReader sc = new FastReader();
    /** Modulus 10^9 + 7 applied to every answer. */
    static final long mod = 1_000_000_007L;
    /** Alternative modulus; not referenced by this solution. */
    static final long mod2 = 998244353L;

    /** Pair of ints ordered by the second component {@code b}. */
    static class Pair implements Comparable<Pair> {
        int a, b;

        Pair(int a, int b) {
            this.a = a;
            this.b = b;
        }

        @Override
        public int compareTo(Pair o) {
            // Integer.compare instead of `this.b - o.b`, which overflows for far-apart values.
            return Integer.compare(this.b, o.b);
        }
    }

    /**
     * Stores C(i, r) mod {@link #mod} in {@code ncr} for every r <= i <= n.
     * Uses C(i, r) = C(i-1, r) * i / (i - r) with C(r, r) = 1. The division is
     * multiplication by the modular inverse of (i - r). If n < r, only C(r, r) is stored.
     */
    static void fillNcr(HashMap<Long, Long> ncr, long n, long r) {
        ncr.put(r, 1L);
        for (long i = r + 1L; i <= n; i++) {
            long ncrI = ((ncr.get(i - 1) % mod) * (i % mod)) % mod;
            ncrI = (ncrI * (modInverse(i - r, mod) % mod)) % mod;
            ncr.put(i, ncrI);
        }
    }

    /**
     * Returns the inverse of A modulo M via Fermat's little theorem (valid only for prime M).
     * Returns -1 when A and M are not coprime.
     */
    static long modInverse(long A, long M) {
        long g = gcd(A, M);
        if (g != 1) {
            return -1;
        }
        return power(A, M - 2, M);
    }

    /** Computes x^y mod M by recursive squaring; y must be non-negative. */
    static long power(long x, long y, long M) {
        if (y == 0) {
            return 1L;
        }
        long p = power(x, y / 2, M) % M;
        p = (p * p) % M;
        return (y % 2 == 0) ? p : (x * p) % M;
    }

    /** Greatest common divisor by the Euclidean algorithm. */
    static long gcd(long a, long b) {
        if (a == 0) {
            return b;
        }
        return gcd(b % a, a);
    }

    /**
     * Returns true iff C(n, i) is odd. By Lucas's theorem mod 2, that holds exactly
     * when every set bit of i is also set in n. Unlike the former factorial
     * 2-adic table, this needs no precomputation and has no upper bound on n.
     * C(n, i) is 0 (even) for i outside [0, n].
     */
    static boolean isSet(int i, int n) {
        if (i < 0 || i > n) {
            return false;
        }
        return (i & n) == i;
    }

    public static void main (String[] args) throws java.lang.Exception {
        int t = sc.nextInt();
        while (t-- > 0) {
            solve();
        }
    }

    /**
     * Solves one test case: reads n, m, S and Q, and prints the number of good
     * subsequences mod {@link #mod}.
     */
    public static void solve() {
        int n = i(), m = i();
        String s = s(), q = s();
        // Expected XOR pattern: a leading 1, then C(m-1, i) mod 2 for 1 <= i < m.
        StringBuilder sb = new StringBuilder("1");
        for (int i = 1; i < m; ++i) {
            sb.append(isSet(i, m - 1) ? "1" : "0");
        }
        if (sb.toString().equals(q)) {
            long res = 0L;
            HashMap<Long, Long> hm = new HashMap<>();
            fillNcr(hm, n, m - 1);
            // A '1' at index i leads the subsequence; pick the other m-1 chars after it.
            for (int i = 0; i <= n - m; ++i) {
                if (s.charAt(i) == '1') {
                    res = (res % mod + (hm.get((long) n - i - 1) % mod)) % mod;
                }
            }
            out.println(res);
        } else {
            out.println(0);
        }
        out.flush();
    }

    static int i() {
        return sc.nextInt();
    }

    static String s() {
        return sc.next();
    }

    static long l() {
        return sc.nextLong();
    }

    /** Reads n ints into a new array. */
    static int[] ia(int n) {
        int[] arr = new int[n];
        for (int i = 0; i < n; ++i) {
            arr[i] = i();
        }
        return arr;
    }

    /** Whitespace-token reader over System.in. */
    static class FastReader {
        BufferedReader br;
        StringTokenizer st;

        public FastReader() {
            br = new BufferedReader(new InputStreamReader(System.in));
        }

        /**
         * Returns the next whitespace-separated token.
         *
         * @throws NoSuchElementException at end of input (previously an NPE)
         * @throws UncheckedIOException on read failure (previously swallowed, then looped forever)
         */
        String next() {
            while (st == null || !st.hasMoreElements()) {
                String line;
                try {
                    line = br.readLine();
                } catch (IOException e) {
                    throw new UncheckedIOException(e);
                }
                if (line == null) {
                    throw new NoSuchElementException("unexpected end of input");
                }
                st = new StringTokenizer(line);
            }
            return st.nextToken();
        }

        int nextInt() { return Integer.parseInt(next()); }

        long nextLong() { return Long.parseLong(next()); }

        double nextDouble() {
            return Double.parseDouble(next());
        }

        /**
         * Returns the rest of the current input line (null at end of input).
         *
         * @throws UncheckedIOException on read failure instead of silently returning ""
         */
        String nextLine() {
            try {
                return br.readLine();
            } catch (IOException e) {
                throw new UncheckedIOException(e);
            }
        }
    }
}