JOINTXOR - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: Utkarsh Gupta
Testers: Nishank Suresh, Takuki Kurokawa
Editorialist: Nishank Suresh

DIFFICULTY:

2398

PREREQUISITES:

None

PROBLEM:

You have a binary string S of length N. Find the maximum possible xor of two of its substrings of equal length. The substrings may overlap.

EXPLANATION:

First, get a couple of simple edge-cases out of the way: if S consists of only zeros or only ones, the answer is obviously 0, since any two substrings of the same length will be equal.

Now, say S has both 0's and 1's. Let A denote the answer string, i.e. the xor of the two chosen substrings, and suppose the substrings chosen to obtain A were [l_1, r_1] and [l_2, r_2].
We can make a couple of observations about A:

  • A must begin with a 1. If it doesn’t, we can simply take [l_1+1, r_1] and [l_2+1, r_2] to obtain the same decimal value with a shorter answer string.
  • Once we know that A begins with a 1, it’s better for A to be as long as possible: any length K binary string starting with 1 has value at least 2^{K-1}, which is strictly larger than the value of any length K-1 binary string (at most 2^{K-1} - 1).

This leads us to a natural ‘solution’:

  • Let L_1 be the position of the first occurrence of a 0, and L_2 be the position of the first occurrence of a 1. Without loss of generality, let L_1 = 1 and L_2 \gt 1 (if S starts with a 1 instead, swap the roles of 0 and 1 throughout).
  • Fix R_1 and R_2 to be as large as possible while maintaining equality of length. In practice, this means that R_2 = N and R_1 = 1 + N - L_2.
  • Take the xor of the two substrings obtained.
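
The three steps above can be sketched roughly as follows. This is only an illustration with 0-indexed positions; the helper names xorSubstrings and naiveXor are made up here and do not come from any of the reference solutions.

```cpp
#include <cassert>
#include <string>

// Hypothetical helper: xor of the two length-len substrings of s that start
// at 0-indexed positions a and b, returned as a binary string.
std::string xorSubstrings(const std::string& s, int a, int b, int len) {
    std::string out(len, '0');
    for (int i = 0; i < len; ++i)
        if (s[a + i] != s[b + i]) out[i] = '1';
    return out;
}

// Naive choice: the first substring starts at index 0, the second at the
// first character that differs from s[0]; both are as long as possible.
// Assumes s contains both '0' and '1'.
std::string naiveXor(const std::string& s) {
    int n = static_cast<int>(s.size());
    int l2 = 0;
    while (l2 < n && s[l2] == s[0]) ++l2;  // 0-indexed L_2
    assert(l2 < n);                        // precondition: two distinct characters
    return xorSubstrings(s, 0, l2, n - l2);
}
```

For instance, naiveXor("0001101") xors 0001 with 1101.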

However, this solution is not entirely correct: for example, consider S = 0001101. The solution above picks 0001 and 1101 for a xor of 1100, but it’s optimal to choose L_1 = 2 and L_2 = 4, to obtain the strings 0011 and 1101 for a xor of 1110.
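
A claim like this is easy to double-check on small inputs with a brute force over all pairs of equal-length substrings. The function below is a hypothetical checker written for this purpose, not part of the intended solution; it is cubic in N and only meant for short strings.

```cpp
#include <algorithm>
#include <cassert>
#include <string>

// Hypothetical brute-force checker: tries every pair of equal-length
// substrings and returns the largest xor as an integer.
// O(N^3) time; restricted to N <= 62 so the value fits in 64 bits.
unsigned long long bruteForceMaxXor(const std::string& s) {
    int n = static_cast<int>(s.size());
    assert(n <= 62);
    unsigned long long best = 0;
    for (int a = 0; a < n; ++a) {
        for (int b = a + 1; b < n; ++b) {
            unsigned long long x = 0;
            // Extend both substrings one character at a time.
            for (int len = 1; b + len <= n; ++len) {
                x = x * 2 + (s[a + len - 1] != s[b + len - 1] ? 1 : 0);
                best = std::max(best, x);
            }
        }
    }
    return best;
}
```

On S = 0001101 this returns 14, which is 1110 in binary.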

This should give you an idea of what went wrong with the initial solution: choosing L_2 and R_2 was correct; the issue is that L_1 can be anything from 1 to (L_2 - 1), and we need to find which one of them is optimal.

This can be done greedily: binary strings of the same length are compared from the most significant bit, so we want A to start with as many 1's as possible.

Let L_3 be the first position \gt L_2 that contains a 0. (Recall that for us, S_1 = 0 and S_{L_2} = 1; if S starts with a 1 instead, look for a 1.)
Ideally, when we choose the position of L_1, we’d like it to be such that positions L_2, L_2+1, \ldots, L_3-1 get matched with a 0, and L_3 gets matched with a 1.
The only way to do this is to set L_1 = L_2 - (L_3 - L_2).

So, we have a few cases to deal with:

  • First, it’s possible that L_3 might not exist at all, i.e., there is no 0 after L_2. In this case, every position from L_2 onwards holds a 1, so we want the first substring to contain as many 0's as possible, and it’s optimal to choose L_1 = 1.
  • Now, suppose L_3 exists. Set L_1 = L_2 - (L_3 - L_2). We have two cases here:
    • If L_1 \geq 1, then nothing more needs to be done: we have our (L_1, R_1) and (L_2, R_2) so we can find the xor of the corresponding substrings.
    • However, it’s possible that L_1 \leq 0. In this case, the ideal alignment is out of reach, and starting as far left as possible gives the longest run of leading 1's in A, so it’s once again optimal to choose L_1 = 1.

So, based on these cases, find the values of L_1, R_1, L_2, R_2 and take the xor of the corresponding substrings to obtain A.
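
One possible way to put these cases into code is sketched below, using 0-indexed positions. The function name bestXorString is invented for this example, and the sketch builds A explicitly as a string, much as the setter's and editorialist's solutions do.

```cpp
#include <algorithm>
#include <cassert>
#include <string>

// Sketch of the case analysis with 0-indexed positions.
// Returns the answer string A, or "0" if all characters of s are equal.
std::string bestXorString(const std::string& s) {
    assert(!s.empty());
    int n = static_cast<int>(s.size());
    int l2 = 0;
    while (l2 < n && s[l2] == s[0]) ++l2;   // start of the second run (L_2)
    if (l2 == n) return "0";                // only one distinct character
    int l3 = l2;
    while (l3 < n && s[l3] == s[l2]) ++l3;  // end of the second run (L_3), n if it does not exist
    // No L_3: start at the very beginning. Otherwise use
    // L_2 - (L_3 - L_2), clamped to the start of the string.
    int l1 = (l3 == n) ? 0 : std::max(0, l2 - (l3 - l2));
    std::string a;
    for (int i = 0; l2 + i < n; ++i)
        a += (s[l1 + i] != s[l2 + i]) ? '1' : '0';
    return a;
}
```

For S = 0001101 this gives l2 = 3, l3 = 5 and l1 = 1 (0-indexed), so A = 1110.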

Note that the required output is the decimal value represented by the answer string A, modulo 10^9 + 7. It can be computed as follows:

  • Initialize the answer ans to 0
  • Iterate across the characters of A from left to right
  • At each iteration, multiply ans by 2. Then, if the current character is 1, add 1 to ans.
  • Remember to always keep ans modulo 10^9 + 7.
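
A minimal version of this conversion might look like the following; binaryToDecimalMod is just an illustrative name.

```cpp
#include <cassert>
#include <string>

// Decimal value of the binary string a, taken modulo 10^9 + 7.
long long binaryToDecimalMod(const std::string& a) {
    const long long MOD = 1000000007LL;
    long long ans = 0;
    for (char c : a) {
        assert(c == '0' || c == '1');
        // Shift left by one bit, append the current bit, and reduce.
        ans = (ans * 2 + (c - '0')) % MOD;
    }
    return ans;
}
```

Since ans stays below 10^9 + 7 after every step, ans * 2 + 1 fits comfortably in a 64-bit integer.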

TIME COMPLEXITY:

\mathcal{O}(N) per test case.

CODE:

Setter's code (C++)
//Utkarsh.25dec
#include <bits/stdc++.h>
#define ll long long int
#define pb push_back
#define mp make_pair
#define mod 1000000007
#define vl vector <ll>
#define all(c) (c).begin(),(c).end()
using namespace std;
ll power(ll a,ll b) {ll res=1;a%=mod; assert(b>=0); for(;b;b>>=1){if(b&1)res=res*a%mod;a=a*a%mod;}return res;}
ll modInverse(ll a){return power(a,mod-2);}
const int N=500023;
bool vis[N];
vector <int> adj[N];
long long readInt(long long l,long long r,char endd){
    long long x=0;
    int cnt=0;
    int fi=-1;
    bool is_neg=false;
    while(true){
        char g=getchar();
        if(g=='-'){
            assert(fi==-1);
            is_neg=true;
            continue;
        }
        if('0'<=g && g<='9'){
            x*=10;
            x+=g-'0';
            if(cnt==0){
                fi=g-'0';
            }
            cnt++;
            assert(fi!=0 || cnt==1);
            assert(fi!=0 || is_neg==false);

            assert(!(cnt>19 || ( cnt==19 && fi>1) ));
        } else if(g==endd){
            if(is_neg){
                x= -x;
            }

            if(!(l <= x && x <= r))
            {
                cerr << l << ' ' << r << ' ' << x << '\n';
                assert(1 == 0);
            }

            return x;
        } else {
            assert(false);
        }
    }
}
string readString(int l,int r,char endd){
    string ret="";
    int cnt=0;
    while(true){
        char g=getchar();
        assert(g!=-1);
        if(g==endd){
            break;
        }
        cnt++;
        ret+=g;
    }
    assert(l<=cnt && cnt<=r);
    return ret;
}
long long readIntSp(long long l,long long r){
    return readInt(l,r,' ');
}
long long readIntLn(long long l,long long r){
    return readInt(l,r,'\n');
}
string readStringLn(int l,int r){
    return readString(l,r,'\n');
}
string readStringSp(int l,int r){
    return readString(l,r,' ');
}
int sumN=0;
void solve()
{
    // JOINTXOR
    int n=readInt(2,2000000,'\n');
    sumN+=n;
    assert(sumN<=2000000);
    string s=readString(n,n,'\n');
    int ones=0,zeros=0;
    for(int i=0;i<n;i++)
    {
        if(s[i]=='0')
            zeros++;
        else
            ones++;
        assert(s[i]=='0' || s[i]=='1');
    }
    if(zeros==0 || ones==0)
    {
        cout<<0<<'\n';
        return;
    }
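    // st: index of the first character that differs from s[0] (0-indexed L_2)
    int st=0;
    for(int i=1;i<n;i++)
    {
        if(s[i]!=s[0])
        {
            st=i;
            break;
        }
    }
    // The second substring is always the suffix starting at st
    int l2=st,r2=n-1;
    int diff=r2-l2;
    // nxt: index of the first character after st that differs from s[st] (0-indexed L_3), or 0 if there is none
    int nxt=0;
    for(int i=st+1;i<n;i++)
    {
        if(s[i]!=s[st])
        {
            nxt=i;
            break;
        }
    }
    int l1,r1;
    if(nxt==0)
        l1=0,r1=l1+diff;
    else
    {
        // Align position st of the first substring with nxt, clamping to the start of the string
        l1=l2-(nxt-l2);
        l1=max(l1,0);
        r1=l1+diff;
    }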
    string maxi="";
    for(int i=0;i<=r1-l1;i++)
    {
        if(s[l1+i]!=s[l2+i])
            maxi+='1';
        else
            maxi+='0';
    }
    ll ans=0;
    for(int i=0;i<maxi.length();i++)
    {
        ans*=2;
        if(maxi[i]=='1')
            ans++;
        ans%=mod;
    }
    cout<<ans<<'\n';
}
int main()
{
    #ifndef ONLINE_JUDGE
    freopen("input.txt", "r", stdin);
    freopen("output.txt", "w", stdout);
    #endif
    ios_base::sync_with_stdio(false);
    cin.tie(NULL),cout.tie(NULL);
    int T=readInt(1,100000,'\n');
    while(T--)
        solve();
    assert(getchar()==-1);
    cerr << "Time : " << 1000 * ((double)clock()) / (double)CLOCKS_PER_SEC << "ms\n";
}
Tester's code (C++)
#include <bits/stdc++.h>
using namespace std;

int main() {
    int tt;
    cin >> tt;
    while (tt--) {
        int n;
        cin >> n;
        string s;
        cin >> s;
        vector<int> rle;
        for (int i = 0, j = 0; i < n; i = j) {
            while (j < n && s[i] == s[j]) {
                j++;
            }
            rle.emplace_back(j - i);
        }
        if (rle[0] == n) {
            cout << 0 << '\n';
        } else {
            int x = max(0, rle[0] - rle[1]);
            int y = rle[0];
            long long ans = 0;
            for (int i = 0; y + i < n; i++) {
                ans *= 2;
                if (s[x + i] != s[y + i]) {
                    ans++;
                }
                ans %= 1000000007;
            }
            cout << ans << endl;
        }
    }
    return 0;
}
Editorialist's code (C++)
#include "bits/stdc++.h"
// #pragma GCC optimize("O3,unroll-loops")
// #pragma GCC target("avx2,bmi,bmi2,lzcnt,popcnt")
using namespace std;
using ll = long long int;
mt19937_64 rng(chrono::high_resolution_clock::now().time_since_epoch().count());

const int mod = 1e9 + 7;

int main()
{
	ios::sync_with_stdio(false); cin.tie(0);

	auto calc = [&] (string s) {
		int pw = 1, ret = 0;
		reverse(begin(s), end(s));
		for (auto c : s) {
			c -= '0';
			ret += c * pw; ret %= mod;
			pw *= 2; pw %= mod;
		}
		return ret;
	};

	int t; cin >> t;
	while (t--) {
		int n; cin >> n;
		string s; cin >> s;
		int l1 = -1, l2 = -1;
		for (int i = 0; i < n; ++i) {
			if (s[i] == '0') {
				if (l2 == -1) l2 = i;
			}
			else {
				if (l1 == -1) l1 = i;
			}
		}
		if (l1 == -1 or l2 == -1) { // Same char
			cout << 0 << '\n';
			continue;
		}
		string ans = "";
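		if (l1 > l2) swap(l1, l2); // now l1 = 0 and l2 = first index whose character differs from s[0]
		int st = l2 - 1; // start of the first substring if the second run has length 1
		// Shift st left once for every further character of the second run, stopping at index 0
		for (int i = l2 + 1; st > 0 and i < n; ++i) {
			if (s[i] == s[0]) break;
			--st;
		}
		// xor the suffix starting at l2 with the equally long substring starting at st
		for (int i = l2; i < n; ++i) {
			ans += '0' + (s[i] != s[st+i-l2]);
		}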
		cout << calc(ans) << '\n';
	}
}

Can anyone tell me my mistake?
https://www.codechef.com/viewsolution/78390426

At first glance, I can’t find where you calculate the mod.

@iceknight1093 can you please explain this more clearly??

Think about how you’d create a binary string with the largest value:

  • Ideally, the first character is a 1
  • If you can do that, it’d be nice if the second character could be 1 too
  • If you can do that, it’d be nice if the third character could be 1 too

and so on.

Now think about what that means for the xor of two strings: their first characters should be different, if that’s possible their second characters should be different, if that’s possible their third, \ldots

Now look at the definitions of L_1, L_2, L_3.

  • L_2 is the first 1 in the string. (Note that the editorial assumes S_1 = 0, so L_2 \gt 1)
  • L_3 is the first 0 after L_2, which means everything at positions L_2, L_2+1, \ldots, L_3-1 is a 1
  • L_1 is something \lt L_2, and in particular we know everything at a position \lt L_2 is a 0.

Put these facts together and you’ll see what I meant.
