DISJOINTXOR - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: Utkarsh Gupta
Testers: Nishank Suresh, Takuki Kurokawa
Editorialist: Nishank Suresh

DIFFICULTY:

2700

PREREQUISITES:

None

PROBLEM:

You have a binary string S of length N. Find the maximum possible xor of two of its substrings of equal length. The substrings cannot overlap.

EXPLANATION:

Once again, we take care of some obvious corner cases first: if S contains only 0's or only 1's, the answer is 0.

Now, to maximize the answer, as always we try to first maximize its length.

Finding the length of the answer can be done easily in \mathcal{O}(N^2), as follows:

  • Iterate across all pairs (i, j) such that 1 \leq i \lt j \leq N and S_i \neq S_j.
  • Assume that the first substring starts at i and the second at j.
  • The maximum length we can get from these two is then \min(j - i, N - j + 1).
  • The length of the answer is the maximum of this value across all valid pairs (i, j).

This tells us the length of the answer, say K. Now we need to look at all pairs of disjoint substrings of length K. However, there can be \mathcal{O}(N^2) such pairs, and finding the xor of each one would make us take \mathcal{O}(N^3) time in total, which is of course too slow.

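Instead, we can make some observations. Consider an optimal answer formed by the substrings [L_1, R_1] and [L_2, R_2] of length K, where L_1 \lt L_2 are the left endpoints and R_i = L_i + K - 1 are the right endpoints. Then,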

  • Either L_2 = R_1 + 1, i.e., L_2 = L_1 + K; or
  • R_2 = N, i.e., L_2 = N-K+1.
Proof

Suppose neither of the above cases holds, i.e, R_2 \lt N and L_2 \gt R_1 + 1.
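Then, we can simply increase both R_1 and R_2 by 1 to get longer strings that are still disjoint and still inside S. The optimal xor is non-zero (S contains both characters), and appending one more bit to both substrings at least multiplies it by 2, so the new xor is strictly higher, which is a contradiction.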

So, one of the above two conditions must hold.

This brings down the number of candidate pairs to \mathcal{O}(N) — once a left endpoint L_1 is fixed, there are only two potential candidates for L_2, so simply take both of them. Do this for each 1 \leq L_1 \leq N.

Then, compute the xors of all \leq 2N candidate pairs, and take the maximum across them all. This gives us a solution in \mathcal{O}(N^2) (since xor computation and string comparison are both \mathcal{O}(N)).

Once again, note that the output is the decimal value of the answer, modulo 10^9 + 7. This can be computed in \mathcal{O}(N), and the method to do so is detailed in the editorial for JOINTXOR.

TIME COMPLEXITY

\mathcal{O}(N^2) per test case.

CODE:

Setter's code (C++)
//Utkarsh.25dec
#include <bits/stdc++.h>
#define ll long long int
#define pb push_back
#define mp make_pair
#define mod 1000000007
#define vl vector <ll>
#define all(c) (c).begin(),(c).end()
using namespace std;
// Fast modular exponentiation: returns a^b reduced modulo `mod`.
// The base is reduced modulo `mod` first, then the exponent is consumed one
// binary digit at a time (square-and-multiply).
// b must be non-negative; this is enforced with an assert.
ll power(ll a,ll b)
{
    ll res=1;
    a%=mod;
    assert(b>=0);
    while(b!=0)
    {
        const bool lowBitSet=(b&1)!=0;
        if(lowBitSet)
            res=res*a%mod;
        a=a*a%mod;
        b>>=1;
    }
    return res;
}
// Returns a^(mod-2) modulo `mod`, which is the modular inverse of a by
// Fermat's little theorem.
// NOTE(review): this is an inverse only when a is not a multiple of `mod`.
ll modInverse(ll a)
{
    const ll exponent=mod-2;
    return power(a,exponent);
}
// Template leftovers: no code visible in this file reads or writes these.
const int N=500023;
bool vis[N];
vector <int> adj[N];
// Strict integer reader used for input validation.
// Consumes characters from stdin up to and including the delimiter `endd`
// and returns the parsed value. Every format violation trips an assert:
//   * end of input, or any character other than '-', a digit or `endd`;
//   * a '-' that follows a digit, or a second '-';
//   * leading zeros, "-0", or a token without any digit;
//   * more than 19 digits, or 19 digits whose first digit exceeds 1;
//   * a value outside [l, r] (l, r and the value are printed to cerr first).
long long readInt(long long l,long long r,char endd){
    long long x=0;
    int cnt=0;          // number of digits consumed so far
    int fi=-1;          // first digit seen, -1 until then
    bool is_neg=false;
    // getchar() yields an unsigned-char value or EOF, so compare as int.
    const int stop=static_cast<unsigned char>(endd);
    while(true){
        const int g=getchar();
        assert(g!=EOF);
        if(g=='-'){
            assert(fi==-1);
            assert(!is_neg);    // at most one sign
            is_neg=true;
            continue;
        }
        if('0'<=g && g<='9'){
            if(cnt==0){
                fi=g-'0';
            }
            cnt++;
            assert(fi!=0 || cnt==1);
            assert(fi!=0 || is_neg==false);

            // Check the digit count before accumulating, so an over-long
            // number is rejected before x could overflow.
            assert(!(cnt>19 || ( cnt==19 && fi>1) ));
            x=x*10+(g-'0');
        } else if(g==stop){
            assert(cnt>0);      // an empty token is not a number
            if(is_neg){
                x= -x;
            }

            if(!(l <= x && x <= r))
            {
                std::cerr << l << ' ' << r << ' ' << x << '\n';
                assert(1 == 0);
            }

            return x;
        } else {
            assert(false);
        }
    }
}
// Reads raw characters from stdin up to (and consuming) the delimiter `endd`
// and returns them without the delimiter.
// Asserts that end of input is not reached before the delimiter and that the
// number of characters read lies in [l, r].
std::string readString(int l,int r,char endd){
    std::string ret;
    int cnt=0;
    // getchar() returns an unsigned-char value or EOF. Keeping it in an int
    // stops byte 0xFF from being mistaken for EOF and lets EOF be detected
    // even where plain char is unsigned.
    const int stop=static_cast<unsigned char>(endd);
    while(true){
        const int g=getchar();
        assert(g!=EOF);
        if(g==stop){
            break;
        }
        cnt++;
        ret+=static_cast<char>(g);
    }
    assert(l<=cnt && cnt<=r);
    return ret;
}
// Space-terminated variant of readInt: forwards the bounds l and r and uses
// ' ' as the delimiter.
long long readIntSp(long long l,long long r)
{
    const char delimiter=' ';
    return readInt(l,r,delimiter);
}
// Newline-terminated variant of readInt: forwards the bounds l and r and
// uses '\n' as the delimiter.
long long readIntLn(long long l,long long r)
{
    const char delimiter='\n';
    return readInt(l,r,delimiter);
}
string readStringLn(int l,int r){
    return readString(l,r,'\n');
}
string readStringSp(int l,int r){
    return readString(l,r,' ');
}
int sumN=0;
void solve()
{
    // DISJOINTXOR
    int n=readInt(2,5000,'\n');
    sumN+=(n*n);
    assert(sumN<=25000000);
    string s=readString(n,n,'\n');
    vector <int> ones,zeros;
    for(int i=0;i<n;i++)
    {
        assert(s[i]=='0' || s[i]=='1');
        if(s[i]=='1')
            ones.pb(i);
        else
            zeros.pb(i);
    }
    if(ones.size()==0 || zeros.size()==0)
    {
        cout<<0<<'\n';
        return;
    }
    int l=1,r=n/2;
    while(l<=r)
    {
        int mid=(l+r)/2;
        int flag=0;
        {
            int a=ones[0];
            auto it=lower_bound(all(zeros),a+mid);
            if(it!=zeros.end())
            {
                if(((*it)+mid-1)<n)
                    flag=1;
            }
        }
        {
            int a=zeros[0];
            auto it=lower_bound(all(ones),a+mid);
            if(it!=ones.end())
            {
                if(((*it)+mid-1)<n)
                    flag=1;
            }
        }
        if(flag)
            l=mid+1;
        else
            r=mid-1;
    }
    int len=r;
    vector <string> v;
    // Continuous segment of length 2*len
    for(int i=0;i<n;i++)
    {
        int l1=i,r1=i+len-1;
        int l2=r1+1,r2=l2+len-1;
        if(r2>=n)
            break;
        string maxi="";
        for(int i=0;i<=r1-l1;i++)
        {
            if(s[l1+i]!=s[l2+i])
                maxi+='1';
            else
                maxi+='0';
        }
        v.pb(maxi);
    }
    // Suffix of length len
    int r2=n-1,l2=r2-len+1;
    for(int i=0;i<n;i++)
    {
        int l1=i,r1=l1+len-1;
        if(r1>=l2)
            break;
        string maxi="";
        for(int i=0;i<=r1-l1;i++)
        {
            if(s[l1+i]!=s[l2+i])
                maxi+='1';
            else
                maxi+='0';
        }
        v.pb(maxi);
    }
    string ans=v[0];
    for(int i=1;i<v.size();i++)
        ans=max(ans,v[i]);
    ll out=0;
    for(int i=0;i<ans.length();i++)
    {
        out*=2;
        if(ans[i]=='1')
            out++;
        out%=mod;
    }
    cout<<out<<'\n';
}
// Driver of the setter's program: reads the number of test cases, runs
// solve() once per case, asserts that no input is left over, and reports the
// elapsed CPU time on cerr.
int main()
{
    #ifndef ONLINE_JUDGE
    // Without ONLINE_JUDGE defined, stdin and stdout are redirected to files.
    freopen("input.txt", "r", stdin);
    freopen("output.txt", "w", stdout);
    #endif
    std::ios_base::sync_with_stdio(false);
    std::cin.tie(nullptr);
    std::cout.tie(nullptr);
    const int testCount=readInt(1,100000,'\n');
    for(int remaining=testCount;remaining--;)
        solve();
    // Nothing may follow the last test case.
    assert(getchar()==-1);
    const double elapsedMs=1000 * static_cast<double>(clock()) / static_cast<double>(CLOCKS_PER_SEC);
    std::cerr << "Time : " << elapsedMs << "ms\n";
}
Tester's code (C++)
#include <bits/stdc++.h>
using namespace std;

// Tester's reference solution.
// For every pair i < j with s[i] != s[j] it takes the longest windows that
// start at i and j, stay disjoint and fit in the string, XORs them, and keeps
// the best result (longer wins, then lexicographically larger). The winning
// bit string is printed as a number modulo 1000000007.
int main() {
    int tt;
    std::cin >> tt;
    while (tt--) {
        int n;
        std::cin >> n;
        std::string s;
        std::cin >> s;
        std::string ans;
        for (int i = 0; i < n; i++) {
            for (int j = i + 1; j < n; j++) {
                if (s[i] == s[j]) {
                    continue;
                }
                // Longest window length usable for this pair of starts.
                const int span = std::min(j - i, n - j);
                if (span < static_cast<int>(ans.size())) {
                    continue;
                }
                std::string t(span, '0');
                for (int k = 0; k < span; k++) {
                    t[k] = static_cast<char>('0' + (s[i + k] ^ s[j + k]));
                }
                if (t.size() > ans.size() || t > ans) {
                    ans.swap(t);
                }
            }
        }
        long long output = 0;
        for (const char c : ans) {
            output *= 2;
            output += c - '0';
            output %= 1000000007;
        }
        // '\n' instead of std::endl: no forced flush after every test case.
        std::cout << output << '\n';
    }
    return 0;
}
Editorialist's code (C++)
#include "bits/stdc++.h"
// #pragma GCC optimize("O3,unroll-loops")
// #pragma GCC target("avx2,bmi,bmi2,lzcnt,popcnt")
using namespace std;
// Integer shorthand; not referenced by the code visible below.
using ll = long long int;
// 64-bit Mersenne Twister seeded from the high-resolution clock; not referenced by the code visible below.
mt19937_64 rng(chrono::high_resolution_clock::now().time_since_epoch().count());

// Modulus used by calc() in main when turning the answer bit string into a number.
const int mod = 1e9 + 7;

// Editorialist's solution.
// Per test case: find the largest usable window length, collect the candidate
// start pairs (adjacent windows and windows paired with the suffix), then
// build the answer bit by bit, keeping only the pairs that can still produce
// a 1 at the current position.
int main()
{
	std::ios::sync_with_stdio(false);
	std::cin.tie(nullptr);

	// Value of a binary string (most significant bit first), reduced modulo mod.
	const auto calc = [](const std::string &bits) {
		long long value = 0;
		for (const char bit : bits)
			value = (value * 2 + (bit - '0')) % mod;
		return static_cast<int>(value);
	};

	int tests;
	std::cin >> tests;
	while (tests--) {
		int n;
		std::string s;
		std::cin >> n >> s;

		// Largest min(j - i, n - j) over all pairs i < j holding different symbols.
		int len = 0;
		for (int j = 1; j < n; ++j) {
			const int room = n - j;
			for (int i = 0; i < j; ++i) {
				if (s[i] != s[j])
					len = std::max(len, std::min(j - i, room));
			}
		}
		if (len == 0) {
			std::cout << 0 << '\n';
			continue;
		}

		// Start pairs whose first bits differ: (i, i+len) and (i, n-len).
		std::vector<std::pair<int, int>> active;
		const int suffixStart = n - len;
		for (int i = 0; i + 2 * len <= n; ++i) {
			if (s[i] != s[i + len])
				active.emplace_back(i, i + len);
			if (s[i] != s[suffixStart])
				active.emplace_back(i, suffixStart);
		}

		// The leading bit is 1; each later bit is 1 iff some surviving pair
		// differs there, in which case only those pairs survive.
		std::string ans(1, '1');
		for (int pos = 1; pos < len; ++pos) {
			std::vector<std::pair<int, int>> survivors;
			for (const auto &pr : active) {
				if (s[pr.first + pos] != s[pr.second + pos])
					survivors.push_back(pr);
			}
			const bool bitIsOne = !survivors.empty();
			ans += bitIsOne ? '1' : '0';
			if (bitIsOne)
				active.swap(survivors);
		}
		std::cout << calc(ans) << '\n';
	}
}
2 Likes