# NCOPIES - Editorial

Author: Utkarsh Gupta, Jeevan Jyot Singh
Testers: Abhinav Sharma, Venkata Nikhil Medam
Editorialist: Nishank Suresh

1745

Prefix sums

# PROBLEM:

Chef has a binary string A of length N. He creates binary string B by concatenating M copies of A. Find the number of positions in B such that pref_i = suf_{i+1}.

# EXPLANATION:

Let S denote the sum of A, i.e, S = A_1 + A_2 + \ldots + A_N.

Now, suppose we know that pref_i = suf_{i+1} for some index i of B. What can we say about pref_i?

pref_i must be equal to \frac{M\cdot S}{2}.

This is because pref_i + suf_{i+1} always equals the total sum of B, which is M\cdot S (since B is formed from M copies of A).

Note that the above division is not floor division. In particular, when M\cdot S is odd, no good index can exist.

Now the problem reduces to finding the number of indices of B whose prefix sum is a given value. This can be done in several ways, though they all depend on the fact that the prefix sums are non-decreasing. For example:

• Since M is small, it is possible to simply iterate over the number of copies while the current prefix sum is smaller than the target value, each time adding S to the current prefix sum. When the prefix sum exceeds the target, iterate across that copy of A in \mathcal{O}(N) and count the number of good indices. This takes \mathcal{O}(N + M) time.
• Some care needs to be taken when implementing this. For example, it might be that the next copy of A (if it exists) also contributes some indices to the answer, for example if A starts with a 0. Also, depending on implementation, a string with all zeros might be an edge case for the solution, causing either TLE or WA since all N\cdot M indices are good.
• Another option with less thinking involved is to binary search for the first and last positions with the target prefix sum. The prefix sum for a given position can be calculated in \mathcal{O}(1) if we know the prefix sums of A, and so this solution runs in \mathcal{O}(\log(N\cdot M)). It still requires \mathcal{O}(N) time to read the input string, however.

# TIME COMPLEXITY:

\mathcal{O}(N + M) or \mathcal{O}(N + \log(N\cdot M)) per test case, depending on implementation.

# CODE:

Setter (C++)
#ifdef WTSH
#include <wtsh.h>
#else
#include <bits/stdc++.h>
using namespace std;
#define dbg(...)
#endif

#define int long long
#define endl "\n"
#define sz(w) (int)(w.size())
using pii = pair<int, int>;

const int mod = 998244353;

// -------------------- Input Checker Start --------------------

long long readInt(long long l, long long r, char endd)
{
long long x = 0;
int cnt = 0, fi = -1;
bool is_neg = false;
while(true)
{
char g = getchar();
if(g == '-')
{
assert(fi == -1);
is_neg = true;
continue;
}
if('0' <= g && g <= '9')
{
x *= 10;
x += g - '0';
if(cnt == 0)
fi = g - '0';
cnt++;
assert(fi != 0 || cnt == 1);
assert(fi != 0 || is_neg == false);
assert(!(cnt > 19 || (cnt == 19 && fi > 1)));
}
else if(g == endd)
{
if(is_neg)
x = -x;
if(!(l <= x && x <= r))
{
cerr << "L: " << l << ", R: " << r << ", Value Found: " << x << '\n';
assert(false);
}
return x;
}
else
{
assert(false);
}
}
}

string readString(int l, int r, char endd)
{
string ret = "";
int cnt = 0;
while(true)
{
char g = getchar();
assert(g != -1);
if(g == endd)
break;
cnt++;
ret += g;
}
assert(l <= cnt && cnt <= r);
return ret;
}

long long readIntSp(long long l, long long r) { return readInt(l, r, ' '); }
long long readIntLn(long long l, long long r) { return readInt(l, r, '\n'); }
string readStringSp(int l, int r) { return readString(l, r, ' '); }
void readEOF() { assert(getchar() == EOF); }

vector<int> readVectorInt(int n, long long l, long long r)
{
vector<int> a(n);
for(int i = 0; i < n - 1; i++)
a[n - 1] = readIntLn(l, r);
return a;
}

// -------------------- Input Checker End --------------------

int sumN = 0;

void solve()
{
sumN += n;
assert(*min_element(a.begin(), a.end()) >= '0' and *max_element(a.begin(), a.end()) <= '1');
int S = count(a.begin(), a.end(), '1');
if(S == 0)
cout << n * m << endl;
else if(S * m % 2 == 1)
cout << 0 << endl;
else
{
string b = a;
if(m % 2 == 0)
b += a, S += S;
int cur = 0, ans = 0;
for(int i = 0; i < sz(b); i++)
{
cur += b[i] - '0';
if(2 * cur == S)
ans++;
}
cout << ans << endl;
}
}

int32_t main()
{
ios::sync_with_stdio(0);
cin.tie(0);
for(int tc = 1; tc <= T; tc++)
{
// cout << "Case #" << tc << ": ";
solve();
}
assert(sumN <= 2e5);
return 0;
}

Tester (C++)
// Tester: Nikhil_Medam
#include <bits/stdc++.h>
#pragma GCC optimize ("-O3")
using namespace std;
#define IOS ios::sync_with_stdio(0); cin.tie(0); cout.tie(0);
#define endl "\n"
#define int long long
#define double long double
const int N = 1e5 + 5;

int t, n, m;
string s;
int32_t main() {
cin >> t;
while(t--) {
cin >> n >> m >> s;
int sum = 0;
for(int i = 0; i < n; i++) {
sum += (s[i] - '0');
}
if(sum == 0) {
cout << n * m << endl;
}
else if ((sum * m) % 2 == 1) {
cout << 0 << endl;
}
else {
if(m % 2 == 0) {
int cnt_0_start = 0, cnt_0_end = 0;
for(int i = 0; i < n; i++) {
if(s[i] == '1') {
break;
}
cnt_0_start++;
}
for(int i = n - 1; i >= 0; i--) {
if(s[i] == '1') {
break;
}
cnt_0_end++;
}
cout << cnt_0_start + cnt_0_end + 1 << endl;
}
else {
int ans = 0, cur_sum = 0;
for(int i = 0; i < n; i++) {
cur_sum += (s[i] - '0');
ans += (cur_sum == sum / 2);
if(cur_sum > sum / 2) {
break;
}
}
cout << ans << endl;
}
}
}
return 0;
}

Editorialist (Python)
for _ in range(int(input())):
n, m = map(int, input().split())
s = input()

# pref[i] = suf[i+1]
# pref[i] + suf[i+1] = S
# pref[i] = S/2
tot = s.count('1')
target = tot*m
if target%2 == 1:
print(0)
continue
if target == 0:
print(n*m)
continue

target //= 2
cur = 0
while m > 0:
if cur + tot < target:
m -= 1
cur += tot
continue
else:
break
ans = 0
for j in range(min(m, 2)):
for i in range(n):
ans += cur == target
cur += s[i] == '1'
print(ans)

for _ in range(int(input())):
n,m=(int(i) for i in input().split())
a=input()
if(a.count("0")==n):
print(n*m)
continue
if(m%2==0):
a=a+a
s=list(map(int,list(a)))
pre=[0]*n
pre[0]=s[0]
for i in range(1,n):
pre[i]=pre[i-1]+s[i]
sm=sum(s)
cnt=0
for i in range(n):
if(pre[i]==sm-pre[i]):
cnt+=1
print(cnt)


Can you please say one test case , where is my code is not giving correct answer !!!

2 Likes

The following is my code:

import math
from collections import defaultdict
#list(map(int,input().split()))
for _ in range(int(input())):
n,m=list(map(int,input().split()))
a=input()
b=m*a
t0=0
tt1=0
for i in a:
if(i=="0"):
t0+=1
elif(i=="1"):
tt1+=1
cgi=0
c1=0
t1=tt1*m
if(t1==0):
cgi=n*m
elif(t1%2==1):
cgi=0
elif(tt1%2==0 and m%2==1):
ef=0
for i in a:
if(i=="1"):
ef+=1
if(ef==(tt1-ef)):
cgi+=1
elif(m%2==0):
hz=0
tz=0
for i in a:
if( i=="0"):
hz+=1
elif(i=="1"):
break
for j in range(n-1,-1,-1):
if(a[j]=="0"):
tz+=1
elif(a[j]=="1"):
break
cgi=tz+hz+1
print(cgi)



i checked other successful submissions…they are working the same way [my logic is correct and so is my implementation], but still I don’t know why i am getting some unknown runtime error.

same problem, i had written the same code , but it failed on 1 test case.

The line b = m*a in your code is what’s causing the issue, since it creates a string of length upto 10^{10}.

You aren’t using it in the rest of the code, and commenting that line out gets AC.

For me, it failed on test cases 2,3, and 5, except that all are passed

/****For one test case this code fails*****/
#include<bits/stdc++.h>
using namespace std;
#define fo(ini,n) for(int i=ini; i<n; i++)
#define ll long long
void solve();
int main(){
int t;cin>>t;
while(t--){
solve();
}
return 0;
}
void solve(){
int n,m;
cin>>n>>m;
string a;
cin>>a;
int on,one=0;
fo(0,a.length()){
if(a[i]=='1')one++;
}
if(one==0){//.........................1
cout<<m*n<<endl;
return;
}
on=one*m;
int final=0;
if(on%2){
cout<<final<<endl;
return;
}
if(m%2==0){
for(int i=a.length()-1; i>=0; i--){
if(a[i]=='0')final++;
if(a[i]=='1'){
final++;
break;
}
}
fo(0,a.length()){
if(a[i]=='0')final++;
if(a[i]=='1'){
final++;
break;
}
}
}
else
{
int hi=0;
fo(0,a.length()){
if(a[i]=='1'){
hi++;
if(hi>one/2){
break;
}
}
if(hi==one/2){
final++;
}
}
}
cout<<final<<endl;
}
/*For one test case code fails*/


Please tell a test case where this code fails

Input
1
2 2
0 1

Expected Output: 2
Explanation : the resultant string is 0101 so (01,01) and (010,1) matches the conditions prefix=suffix.


Anyone who solved this question with binary search please comment your code.

#include <bits/stdc++.h>
using namespace std;

#define FIO ios_base::sync_with_stdio(false), cin.tie(NULL), cout.tie(0);
typedef long long ll;

int main(){
FIO;
int TC;
cin>>TC;
while(TC-->0){
int n,m;
cin>>n>>m;
string s;
cin>>s;
int sum = 0;
for(auto e:s){
if(e == '1') sum++;
}
if(sum == 0){
cout<<(n*m)<<endl;
continue;
}
int x = (m*sum);
if(x%2 == 1){
cout<<"0\n";
continue;
}
x /= 2;

int k = (x%sum);
if(k == 0){
int i = 0;
int ans = 1;
while(i < n && s[i] == '0'){
ans++;
i++;
}
i = n-1;
while(i >= 0 && s[i] == '0'){
ans++;
i--;
}

cout<<ans<<endl;
continue;
}
int ct = 0;

int i = 0;
for(i=0;i<n;i++){

if(s[i] == '1'){
ct++;
}
if(ct == k) break;

}
i++;
int ans = 1;
while(i<n && s[i] == '0'){
ans++;
i++;
}
if(i == n){
i = 0;
while(i<n && s[i] == '0'){
ans++;
i++;
}
}

cout<<ans<<endl;

}
return 0;
}


On what testcase does this code fail?

@gulshan2052 It does not looks like that you uses Binary Search in your Code.

I got it , there is one simple mistake , I forgot to update the length of the final string

for _ in range(int(input())):
n,m=(int(i) for i in input().split())
a=input()
if(a.count("0")==n):
print(n*m)
continue
if(m%2==0):
a=a+a
s=list(map(int,list(a)))
pre=[0]*len(s)
pre[0]=s[0]
for i in range(1,len(s)):
pre[i]=pre[i-1]+s[i]
sm=sum(s)
cnt=0
for i in range(len(s)):
if(pre[i]==sm-pre[i]):
cnt+=1
print(cnt)


I forgot to update the length of the prefix !

Mysolution Any test case where this fails?

Thank you so much!!!

https://www.codechef.com/viewsolution/71147145

Can anyone please say where my code is going wrong

Ok its solved now i just changed the ceil part of my code with the method of finding ceil without using the inbuilt function …

#include<bits/stdc++.h>
using namespace std;
#define fo(ini,n) for(int i=ini; i<n; i++)
#define ll long long
void solve();
int main(){
int t;cin>>t;
while(t--){
solve();
}
return 0;
}
void solve(){
int n,m;
string a;
cin>>n>>m>>a;
int one=0,final=0;
fo(0,n){
one+=(a[i]-'0');
}
if(one==0){//.........................1
cout<<m*n<<endl;
}
else if((one*m)%2){
cout<<final<<endl;
}
else if(m%2==0){
for(int i=n-1; i>=0; i--){
final+=(a[i]=='0');
if(a[i]=='1'){
final++;
break;
}
}
fo(0,n){
final+=(a[i]=='0');
if(a[i]=='1'){
break;
}

}
cout<<final<<endl;
}
else
{
int zero=0;
fo(0,n){
zero+=a[i]-'0';
final+=(zero==one/2);
if(zero>one/2)break;
}
cout<<final<<endl;
}
}
[quote="iceknight1093, post:1, topic:102574, full:true"]

[Practice](https://www.codechef.com/problems/NCOPIES)
[Contest: Division 1](https://www.codechef.com/START51A/problems/NCOPIES)
[Contest: Division 2](https://www.codechef.com/START51B/problems/NCOPIES)
[Contest: Division 3](https://www.codechef.com/START51C/problems/NCOPIES)
[Contest: Division 4](https://www.codechef.com/START51D/problems/NCOPIES)

***Author:*** [Utkarsh Gupta](https://www.codechef.com/users/utkarsh_25dec), [Jeevan Jyot Singh](https://www.codechef.com/users/jeevanjyot)
***Testers:*** [Abhinav Sharma](https://www.codechef.com/users/inov_360), [Venkata Nikhil Medam](https://www.codechef.com/users/nikhil_medam)
***Editorialist:*** [Nishank Suresh](https://www.codechef.com/users/IceKnight1093)

# DIFFICULTY:
To be updated

# PREREQUISITES:
Prefix sums

# PROBLEM:
Chef has a binary string $A$ of length $N$. He creates binary string $B$ by concatenating $M$ copies of $A$. Find the number of positions in $B$ such that $pref_i = suf_{i+1}$.

# EXPLANATION:
Let $S$ denote the sum of $A$, i.e, $S = A_1 + A_2 + \ldots + A_N$.

Now, suppose we know that $pref_i = suf_{i+1}$ for some index $i$ of $B$. What can we say about $pref_i$?

$pref_i$ must be equal to $\frac{M\cdot S}{2}$.

This is because $pref_i + suf_{i+1}$ always equals the total sum of $B$, which is $M\cdot S$ (since $B$ is formed from $M$ copies of $A$).

Note that the above division is *not* floor division. In particular, when $M\cdot S$ is odd, no good index can exist.
[/details]

Now the problem reduces to finding the number of indices of $B$ whose prefix sum is a given value. This can be done in several ways, though they all depend on the fact that the prefix sums are non-decreasing. For example:
- Since $M$ is small, it is possible to simply iterate over the number of copies while the current prefix sum is smaller than the target value, each time adding $S$ to the current prefix sum. When the prefix sum exceeds the target, iterate across that copy of $A$ in $\mathcal{O}(N)$ and count the number of good indices. This takes $\mathcal{O}(N + M)$ time.
- Some care needs to be taken when implementing this. For example, it might be that the next copy of $A$ (if it exists) also contributes some indices to the answer, for example if $A$ starts with a $0$. Also, depending on implementation, a string with all zeros might be an edge case for the solution, causing either TLE or WA since all $N\cdot M$ indices are good.
- Another option with less thinking involved is to binary search for the first and last positions with the target prefix sum. The prefix sum for a given position can be calculated in $\mathcal{O}(1)$ if we know the prefix sums of $A$, and so this solution runs in $\mathcal{O}(\log(N\cdot M))$. It still requires $\mathcal{O}(N)$ time to read the input string, however.

# TIME COMPLEXITY:
$\mathcal{O}(N + M)$ or $\mathcal{O}(N + \log(N\cdot M))$ per test case, depending on implementation.

# CODE:
[details = Setter (C++)]
cpp
#ifdef WTSH
#include <wtsh.h>
#else
#include <bits/stdc++.h>
using namespace std;
#define dbg(...)
#endif

#define int long long
#define endl "\n"
#define sz(w) (int)(w.size())
using pii = pair<int, int>;

const int mod = 998244353;

// -------------------- Input Checker Start --------------------

long long readInt(long long l, long long r, char endd)
{
long long x = 0;
int cnt = 0, fi = -1;
bool is_neg = false;
while(true)
{
char g = getchar();
if(g == '-')
{
assert(fi == -1);
is_neg = true;
continue;
}
if('0' <= g && g <= '9')
{
x *= 10;
x += g - '0';
if(cnt == 0)
fi = g - '0';
cnt++;
assert(fi != 0 || cnt == 1);
assert(fi != 0 || is_neg == false);
assert(!(cnt > 19 || (cnt == 19 && fi > 1)));
}
else if(g == endd)
{
if(is_neg)
x = -x;
if(!(l <= x && x <= r))
{
cerr << "L: " << l << ", R: " << r << ", Value Found: " << x << '\n';
assert(false);
}
return x;
}
else
{
assert(false);
}
}
}

string readString(int l, int r, char endd)
{
string ret = "";
int cnt = 0;
while(true)
{
char g = getchar();
assert(g != -1);
if(g == endd)
break;
cnt++;
ret += g;
}
assert(l <= cnt && cnt <= r);
return ret;
}

long long readIntSp(long long l, long long r) { return readInt(l, r, ' '); }
long long readIntLn(long long l, long long r) { return readInt(l, r, '\n'); }
string readStringSp(int l, int r) { return readString(l, r, ' '); }
void readEOF() { assert(getchar() == EOF); }

vector<int> readVectorInt(int n, long long l, long long r)
{
vector<int> a(n);
for(int i = 0; i < n - 1; i++)
a[n - 1] = readIntLn(l, r);
return a;
}

// -------------------- Input Checker End --------------------

int sumN = 0;

void solve()
{
sumN += n;
assert(*min_element(a.begin(), a.end()) >= '0' and *max_element(a.begin(), a.end()) <= '1');
int S = count(a.begin(), a.end(), '1');
if(S == 0)
cout << n * m << endl;
else if(S * m % 2 == 1)
cout << 0 << endl;
else
{
string b = a;
if(m % 2 == 0)
b += a, S += S;
int cur = 0, ans = 0;
for(int i = 0; i < sz(b); i++)
{
cur += b[i] - '0';
if(2 * cur == S)
ans++;
}
cout << ans << endl;
}
}

int32_t main()
{
ios::sync_with_stdio(0);
cin.tie(0);
for(int tc = 1; tc <= T; tc++)
{
// cout << "Case #" << tc << ": ";
solve();
}
assert(sumN <= 2e5);
return 0;
}


[/details]

Tester (C++)
// Tester: Nikhil_Medam
#include <bits/stdc++.h>
#pragma GCC optimize ("-O3")
using namespace std;
#define IOS ios::sync_with_stdio(0); cin.tie(0); cout.tie(0);
#define endl "\n"
#define int long long
#define double long double
const int N = 1e5 + 5;

int t, n, m;
string s;
int32_t main() {
cin >> t;
while(t--) {
cin >> n >> m >> s;
int sum = 0;
for(int i = 0; i < n; i++) {
sum += (s[i] - '0');
}
if(sum == 0) {
cout << n * m << endl;
}
else if ((sum * m) % 2 == 1) {
cout << 0 << endl;
}
else {
if(m % 2 == 0) {
int cnt_0_start = 0, cnt_0_end = 0;
for(int i = 0; i < n; i++) {
if(s[i] == '1') {
break;
}
cnt_0_start++;
}
for(int i = n - 1; i >= 0; i--) {
if(s[i] == '1') {
break;
}
cnt_0_end++;
}
cout << cnt_0_start + cnt_0_end + 1 << endl;
}
else {
int ans = 0, cur_sum = 0;
for(int i = 0; i < n; i++) {
cur_sum += (s[i] - '0');
ans += (cur_sum == sum / 2);
if(cur_sum > sum / 2) {
break;
}
}
cout << ans << endl;
}
}
}
return 0;
}

Editorialist (Python)
for _ in range(int(input())):
n, m = map(int, input().split())
s = input()

# pref[i] = suf[i+1]
# pref[i] + suf[i+1] = S
# pref[i] = S/2
tot = s.count('1')
target = tot*m
if target%2 == 1:
print(0)
continue
if target == 0:
print(n*m)
continue

target //= 2
cur = 0
while m > 0:
if cur + tot < target:
m -= 1
cur += tot
continue
else:
break
ans = 0
for j in range(min(m, 2)):
for i in range(n):
ans += cur == target
cur += s[i] == '1'
print(ans)
`

[/quote]