PROBLEM LINK:
Contest Division 1
Contest Division 2
Contest Division 3
Contest Division 4
Setter: Utkarsh Gupta
Tester: Abhinav Sharma, Aryan
Editorialist: Lavish Gupta
DIFFICULTY:
Easy-Medium
PREREQUISITES:
Bitmasks, SOS (Sum over Subsets) dynamic programming.
PROBLEM:
Chef has a (0-indexed) binary string S of length N such that N is a power of 2.
Chef wants to find the number of pairs (i, j) such that:
- 0 \leq i,j \lt N
- S_{i|j} = S_{i\&j}
(Here | denotes the bitwise OR operation and \& denotes the bitwise AND operation)
Can you help Chef to do so?
QUICK EXPLANATION:
What if we fix i?
Let S_1 be the set of bit positions that are set in i, and let S_2 be its complement. If we also fix i|j, then every bit of j at a position in S_2 is fixed, whereas every bit of j at a position in S_1 can be either 0 or 1. So, for a fixed i and a fixed i|j, i\&j can be any submask of i.
How to sum up the answer for a fixed i?
Let’s define supermasks of i as the collection of all masks for which i is a submask.
For a fixed i|j, which is a supermask of i, i\&j can be any submask of i, and we count the choices for which the characters of S at indices i|j and i\&j are equal. So, if c_1 and c_0 denote the number of supermasks of i at which S has value 1 and 0 respectively, and d_1 and d_0 denote the number of submasks of i at which S has value 1 and 0 respectively, the answer for i is c_1 \cdot d_1 + c_0 \cdot d_0.
How to calculate these values optimally?
We can use the SOS Dynamic Programming technique to calculate the values for all i in O(N \cdot \log{N}) time.
EXPLANATION:
A naive approach iterates over all pairs i and j and checks whether S_{i|j} = S_{i\&j}, but this takes O(N^2) time, which exceeds the time limit.
To optimize this, we fix i and analyze how the bits of j affect the answer. Let S_1 be the set of positions of bits that are set in i, and S_2 the set of positions of bits that are not set in i. If we further fix i|j, the bits of j at positions in S_2 are fixed, whereas the bits of j at positions in S_1 can take both values 0 and 1. Since i\&j is a submask of i, it only contains set bits from S_1, and hence i\&j can be any submask of i. Let’s define the supermasks of i as the collection of all masks of which i is a submask. For a fixed i|j, which is a supermask of i, i\&j can be any submask of i, and we count the choices for which S has the same character at i|j and at i\&j. Let d_1 and d_0 denote the number of submasks of i at which S has value 1 and 0 respectively; if S_{i|j} = 1 there are d_1 valid values of i\&j, otherwise there are d_0. Let c_1 and c_0 denote the number of supermasks of i at which S has value 1 and 0 respectively. So, for a fixed i the total contribution is c_1 \cdot d_1 + c_0 \cdot d_0.
We want to calculate c_1, c_0, d_1, d_0 for all values of i. Iterating over all submasks and supermasks of every i takes O(3^{\log_2 N}) = O(N^{\log_2 3}) time, i.e. about 3^{20} operations for N = 2^{20}, which is too slow. Instead, we use SOS (Sum over Subsets) dynamic programming, which computes, for every i, the sum of S over all submasks of i (and, symmetrically, over all supermasks of i) in O(N \log N) total time.
TIME COMPLEXITY:
With SOS dynamic programming, populating the DP values for all i takes O(N \cdot \log{N}) time, and accumulating the answer over all i takes O(N) time. So the total time complexity is O(N \cdot \log{N}).
SOLUTION:
Setter's Solution
//Utkarsh.25dec
#include <bits/stdc++.h>
#define ll long long int
#define pb push_back
#define mp make_pair
#define mod 1000000007
#define vl vector <ll>
#define all(c) (c).begin(),(c).end()
using namespace std;
// Binary exponentiation: returns a^b modulo `mod`. b must be non-negative (asserted).
// Not called elsewhere in this program; kept as a general helper.
ll power(ll a,ll b)
{
    assert(b>=0);
    ll base=a%mod;
    ll result=1;
    while(b!=0)
    {
        if(b&1)
            result=result*base%mod;
        base=base*base%mod;
        b>>=1;
    }
    return result;
}
// Modular inverse via Fermat's little theorem (mod = 1000000007 is prime).
ll modInverse(ll a){return power(a,mod-2);}
// Capacity of the global per-index arrays (good, A, F below); must exceed the
// largest accepted string length 2^20.
const int N=2000023;
// NOTE(review): `vis` and `adj` are not referenced anywhere in this program;
// `adj` alone reserves N (~2M) std::vector objects of static storage — consider removing.
bool vis[N];
vector <int> adj[N];
// Input-checker integer reader: consumes characters from stdin up to and
// including `endd`, parses them as a decimal integer and asserts that
//   - the token is non-empty and has at most one leading '-',
//   - there are no leading zeros and "-0" is not used,
//   - it has at most 19 digits (first digit <= 1 when it has exactly 19),
//   - l <= value <= r.
// Any other character (including EOF) aborts via assert.
long long readInt(long long l,long long r,char endd){
    long long x=0;
    int cnt=0;          // digits read so far
    int fi=-1;          // first digit, -1 until one has been read
    bool is_neg=false;
    while(true){
        // getchar() returns int; narrowing it to `char` makes EOF
        // indistinguishable from byte 0xFF (or never equal to -1 when char is unsigned).
        int g=getchar();
        if(g=='-'){
            // A single '-' is allowed, and only before the first digit.
            assert(fi==-1 && !is_neg);
            is_neg=true;
            continue;
        }
        if('0'<=g && g<='9'){
            int d=g-'0';
            if(cnt==0){
                fi=d;
            }
            cnt++;
            assert(fi!=0 || cnt==1);
            assert(fi!=0 || is_neg==false);
            // Check the digit count BEFORE accumulating, so an over-long token
            // aborts here instead of overflowing `x` (signed overflow is UB).
            assert(!(cnt>19 || ( cnt==19 && fi>1) ));
            x=x*10+d;
        } else if(g==endd){
            // An empty token ("\n" or "-\n") would otherwise be read as 0.
            assert(cnt>0);
            if(is_neg){
                x= -x;
            }
            if(!(l <= x && x <= r))
            {
                std::cerr << l << ' ' << r << ' ' << x << '\n';
                assert(1 == 0);
            }
            return x;
        } else {
            assert(false);
        }
    }
}
// Input-checker string reader: consumes characters from stdin up to and
// including `endd` and returns everything before it. Asserts that EOF is not
// reached first and that the returned length lies in [l, r].
std::string readString(int l,int r,char endd){
    std::string ret="";
    int cnt=0;
    while(true){
        // Keep getchar()'s int result: a `char` compared with -1 never detects
        // EOF where char is unsigned (the loop would never terminate) and
        // misreads byte 0xFF as EOF where char is signed.
        int g=getchar();
        assert(g!=EOF);
        if(g==endd){
            break;
        }
        cnt++;
        ret+=static_cast<char>(g);
    }
    assert(l<=cnt && cnt<=r);
    return ret;
}
long long readIntSp(long long l,long long r){
return readInt(l,r,' ');
}
long long readIntLn(long long l,long long r){
return readInt(l,r,'\n');
}
string readStringLn(int l,int r){
return readString(l,r,'\n');
}
string readStringSp(int l,int r){
return readString(l,r,' ');
}
// Running total of N over all test cases; solve() asserts it never exceeds 2^20.
ll sumN=0;
// good[x] == 1 iff x is one of 2^1 .. 2^20 (filled in main()); solve() uses it
// to check that N is a power of two.
int good[N]={0};
// Scratch arrays for the SOS DP in solve(): A holds per-index weights and F
// their sums over submasks.
ll A[N], F[N];
void solve()
{
int N=readInt(2,(1<<20),'\n');
sumN+=N;
assert(sumN<=(1<<20));
assert(good[N]==1);
string s=readString(N,N,'\n');
int n=0;
int temp=1;
while(temp!=N)
{
temp*=2;
n++;
}
ll ans=0;
{
for(int i=0;i<(1<<n);i++)
{
if(s[i]=='0')
A[i]=0;
else
A[i]=(1<<(n-(__builtin_popcount(i))));
}
for(int i=0;i<(1<<n);i++)
F[i]=A[i];
for(int i = 0;i < n; ++i)
for(int mask = 0; mask < (1<<n); ++mask)
{
if(mask & (1<<i))
{
F[mask] += F[mask^(1<<i)];
}
}
for(int i=0;i<(1<<n);i++)
{
if(s[i]=='0')
continue;
ans+=(F[i]/((1<<(n-(__builtin_popcount(i))))));
}
}
{
for(int i=0;i<(1<<n);i++)
{
if(s[i]=='1')
A[i]=0;
else
A[i]=(1<<(n-(__builtin_popcount(i))));
}
for(int i=0;i<(1<<n);i++)
F[i]=A[i];
for(int i = 0;i < n; ++i)
for(int mask = 0; mask < (1<<n); ++mask)
{
if(mask & (1<<i))
{
F[mask] += F[mask^(1<<i)];
}
}
for(int i=0;i<(1<<n);i++)
{
if(s[i]=='1')
continue;
ans+=(F[i]/((1<<(n-(__builtin_popcount(i))))));
}
}
cout<<ans<<'\n';
}
// Entry point: reads T (1..1000), solves every test case, and checks that the
// whole input was consumed. I/O is redirected to input.txt / output.txt unless
// ONLINE_JUDGE is defined.
int main()
{
#ifndef ONLINE_JUDGE
    freopen("input.txt", "r", stdin);
    freopen("output.txt", "w", stdout);
#endif
    ios_base::sync_with_stdio(false);
    cin.tie(NULL);
    cout.tie(NULL);
    int testCount=readInt(1,1000,'\n');
    // Mark every power of two 2^1 .. 2^20 as a legal string length.
    for(int p=2;p<=(1<<20);p<<=1)
        good[p]=1;
    for(int tc=0;tc<testCount;tc++)
        solve();
    // No trailing data may remain after the last test case.
    assert(getchar()==-1);
    cerr << "Time : " << 1000 * ((double)clock()) / (double)CLOCKS_PER_SEC << "ms\n";
}
Tester's Solution
#include <bits/stdc++.h>
using namespace std;
/*
------------------------Input Checker----------------------------------
*/
// Input-checker integer reader: consumes characters from stdin up to and
// including `endd`, parses them as a decimal integer and asserts that
//   - the token is non-empty and has at most one leading '-',
//   - there are no leading zeros and "-0" is not used,
//   - it has at most 19 digits (first digit <= 1 when it has exactly 19),
//   - l <= value <= r.
// Any other character (including EOF) aborts via assert.
long long readInt(long long l,long long r,char endd){
    long long x=0;
    int cnt=0;          // digits read so far
    int fi=-1;          // first digit, -1 until one has been read
    bool is_neg=false;
    while(true){
        // getchar() returns int; narrowing it to `char` makes EOF
        // indistinguishable from byte 0xFF (or never equal to -1 when char is unsigned).
        int g=getchar();
        if(g=='-'){
            // A single '-' is allowed, and only before the first digit.
            assert(fi==-1 && !is_neg);
            is_neg=true;
            continue;
        }
        if('0'<=g && g<='9'){
            int d=g-'0';
            if(cnt==0){
                fi=d;
            }
            cnt++;
            assert(fi!=0 || cnt==1);
            assert(fi!=0 || is_neg==false);
            // Check the digit count BEFORE accumulating, so an over-long token
            // aborts here instead of overflowing `x` (signed overflow is UB).
            assert(!(cnt>19 || ( cnt==19 && fi>1) ));
            x=x*10+d;
        } else if(g==endd){
            // An empty token ("\n" or "-\n") would otherwise be read as 0.
            assert(cnt>0);
            if(is_neg){
                x= -x;
            }
            if(!(l <= x && x <= r))
            {
                std::cerr << l << ' ' << r << ' ' << x << '\n';
                assert(1 == 0);
            }
            return x;
        } else {
            assert(false);
        }
    }
}
// Input-checker string reader: consumes characters from stdin up to and
// including `endd` and returns everything before it. Asserts that EOF is not
// reached first and that the returned length lies in [l, r].
std::string readString(int l,int r,char endd){
    std::string ret="";
    int cnt=0;
    while(true){
        // Keep getchar()'s int result: a `char` compared with -1 never detects
        // EOF where char is unsigned (the loop would never terminate) and
        // misreads byte 0xFF as EOF where char is signed.
        int g=getchar();
        assert(g!=EOF);
        if(g==endd){
            break;
        }
        cnt++;
        ret+=static_cast<char>(g);
    }
    assert(l<=cnt && cnt<=r);
    return ret;
}
long long readIntSp(long long l,long long r){
return readInt(l,r,' ');
}
long long readIntLn(long long l,long long r){
return readInt(l,r,'\n');
}
string readStringLn(int l,int r){
return readString(l,r,'\n');
}
string readStringSp(int l,int r){
return readString(l,r,' ');
}
/*
------------------------Main code starts here----------------------------------
*/
// NOTE(review): MAX_T / MAX_N / MAX_SUM_LEN are not referenced in the visible code
// and do not match the limits actually checked (T <= 1000, N and sum of N <= 2^20).
const int MAX_T = 1e5;
const int MAX_N = 1e5;
const int MAX_SUM_LEN = 1e5;
// Shorthand macros used by the code below.
#define fast ios_base::sync_with_stdio(0); cin.tie(0); cout.tie(0)
#define ff first
#define ss second
#define mp make_pair
#define ll long long
// rep: i = 0 .. n-1, ascending.
#define rep(i,n) for(int i=0;i<n;i++)
// rev: i = n .. 0, descending, with n INCLUSIVE (e.g. rev(i, 19) visits 20 values).
#define rev(i,n) for(int i=n;i>=0;i--)
// rep_a: i = a .. n-1, ascending.
#define rep_a(i,a,n) for(int i=a;i<n;i++)
// Input statistics printed by main() on stderr; sum_len and max_n are updated in
// solve(). yess / nos / total_ops are only referenced in commented-out code.
int sum_len = 0;
int max_n = 0;
int yess = 0;
int nos = 0;
int total_ops = 0;
// Modulus used by po(); the answer itself is printed exactly, not reduced modulo it.
const ll mod = 998244353;
// Binary exponentiation: returns x^n modulo `mod` (and 1 when n <= 0).
// x is used as given (not reduced modulo `mod` first).
ll po(ll x, ll n ){
    ll result=1;
    for(ll e=n; e>0; e/=2){
        if(e%2==1)
            result=result*x%mod;
        x=x*x%mod;
    }
    return result;
}
// For one character c of the binary string, counts pairs (i, j) with
// s[i|j] == s[i&j] == c.
//
// Expected input (as built by solve()):
//   v[k]     = 2^popcount(k) if s[k] == c, else 0
//   is_on[k] = 1             if s[k] == c, else 0
// The superset-sum (SOS) pass turns v[i] into the sum of 2^popcount(k) over all
// supermasks k < n of i with s[k] == c. Dividing by 2^popcount(i) gives
// sum 2^(popcount(k) - popcount(i)), which is exactly the number of pairs
// (i', j') with i'&j' == i and i'|j' == k. Summing over the i with s[i] == c
// yields the count.
//
// v is modified in place. Runs in O(n log n).
long long fun(std::vector<long long> &v, std::vector<int>&is_on, int n){
    // Only bit positions that can occur in an index 0..n-1 matter. Deriving them
    // from n (instead of hard-coding 20 positions) skips useless passes for small
    // n and stays correct for n > 2^20.
    int bits=0;
    while((1LL<<bits)<n) ++bits;
    for(int i=bits-1;i>=0;i--){
        for(int j=n-1;j>=0;j--){
            if((j>>i)&1) v[j^(1<<i)] += v[j];
        }
    }
    long long ret = 0;
    for(int i=0;i<n;i++){
        if(!is_on[i]) continue;
        const long long weight = 1LL << __builtin_popcount(i);
        ret += v[i]/weight;   // exact: every summand is a multiple of weight
    }
    return ret;
}
void solve()
{
int n = readIntLn(2, 1<<20);
sum_len += n;
max_n = max(max_n, n);
string s = readStringLn(n,n);
assert(__builtin_popcount(n)==1);
vector<ll> v(n);
vector<int> z(n,0);
rep(i,n){
if(s[i]=='0'){
int tmp = __builtin_popcount(i);
v[i] = (1<<tmp);
z[i] = 1;
}
else v[i] = 0;
}
ll ans = fun(v, z, n);
z.assign(n,0);
rep(i,n){
if(s[i]=='1'){
int tmp = __builtin_popcount(i);
v[i] = (1<<tmp);
z[i] = 1;
}
else v[i] = 0;
}
ans += fun(v,z,n);
cout<<ans<<'\n';
}
// Entry point: reads T (1..1000) and solves each test case. It then asserts that the
// input was fully consumed and the sum of N is within 2^20, and prints input
// statistics to stderr. I/O is redirected to input.txt / output.txt unless
// ONLINE_JUDGE is defined.
signed main()
{
#ifndef ONLINE_JUDGE
    freopen("input.txt", "r" , stdin);
    freopen("output.txt", "w" , stdout);
#endif
    fast;
    const int t = readIntLn(1,1000);
    for(int done=0; done<t; ++done)
        solve();
    // No trailing input is allowed, and the total length must respect the limit.
    assert(getchar() == -1);
    assert(sum_len<=(1<<20));
    cerr << "SUCCESS\n"
         << "Tests : " << t << '\n'
         << "Sum of lengths : " << sum_len << '\n'
         << "Maximum length : " << max_n << '\n';
    // cerr<<"Total operations : " << total_ops << '\n';
    //cerr<<"Answered yes : " << yess << '\n';
    //cerr<<"Answered no : " << nos << '\n';
}
Editorialist's Solution
#include<bits/stdc++.h>
using namespace std;
#define ll long long
#define dd double
#define endl "\n"
#define pb push_back
#define all(v) v.begin(),v.end()
#define mp make_pair
#define fi first
#define se second
#define vll vector<ll>
#define pll pair<ll,ll>
#define fo(i,n) for(int i=0;i<n;i++)
#define fo1(i,n) for(int i=1;i<=n;i++)
// Modulus used by power(); not used by main() below.
ll mod=1000000007;
// Global scratch variables: main() reads t (test count) and n (string length) into
// these; main() declares its own local m, which shadows this one. k, q and flag are
// not used in the visible code.
ll n,k,t,m,q,flag=0;
// Binary exponentiation: a^b modulo `mod`, asserting b >= 0. Not called by main().
ll power(ll a,ll b) {ll res=1;a%=mod; assert(b>=0); for(;b;b>>=1){if(b&1)res=res*a%mod;a=a*a%mod;}return res;}
// Optional policy-based ordered set (disabled):
// #include <ext/pb_ds/assoc_container.hpp>
// #include <ext/pb_ds/tree_policy.hpp>
// using namespace __gnu_pbds;
// #define ordered_set tree<int, null_type,less<int>, rb_tree_tag,tree_order_statistics_node_update>
// ordered_set s ; s.order_of_key(a) -- no. of elements strictly less than a
// s.find_by_order(i) -- iterator to ith element (0 indexed)
// long long overloads returning the smaller / larger argument (not called by main()).
ll min(ll a,ll b){if(a>b)return b;else return a;}
ll max(ll a,ll b){if(a>b)return a;else return b;}
// Euclid's gcd. NOTE(review): assumes non-negative arguments — e.g. gcd(-4, 6)
// recurses forever. Not called by main().
ll gcd(ll a , ll b){ if(b > a) return gcd(b , a) ; if(b == 0) return a ; return gcd(b , a%b) ;}
int main() {
ios_base::sync_with_stdio(0);
cin.tie(0); cout.tie(0);
#ifdef NOOBxCODER
freopen("input.txt", "r", stdin);
freopen("output.txt", "w", stdout);
#else
#define NOOBxCODER 0
#endif
cin>>t;
//t=1;
while(t--){
cin>>n;
string s;
int a[n],b[n];
cin>>s;
fo(i,n){a[i] =s[i]-'0'; b[n - 1 -i]= s[i]-'0'; }
//ll dp1[n][21];dp2[n][21];
//fo(i,n)cout<<a[i]; cout<<endl; fo(i,n)cout<<b[i];cout<<endl;
ll f1[n],f2[n];
ll m= log2(n);
for(int i = 0; i<n ; ++i)
f1[i] = a[i];// f2[i]
for(int i = 0;i < m; ++i) for(int mask = 0; mask < n; ++mask){
if(mask & (1<<i))
f1[mask] += f1[mask^(1<<i)];
}
for(int i = 0; i<n ; ++i)
f2[i] = b[i];
for(int i = 0;i < m; ++i) for(int mask = 0; mask < n; ++mask){
if(mask & (1<<i))
f2[mask] += f2[mask^(1<<i)];
}
ll ans=0;
for(int i=0;i<n;i++){
int c = __builtin_popcount(i);
//cout<<f1[i]<<" "<<f2[i]<<endl;
ans+= (f1[i]*f2[n-1-i]) + ((ll)(1<<c ) - f1[i] ) *((ll)(1<<(m-c)) -f2[n-1-i]);
}
cout<<ans<<endl;
}
cerr << "Time : " << 1000 * ((double)clock()) / (double)CLOCKS_PER_SEC << "ms\n";
return 0;
}