MAX1S - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: Arun Sharma
Testers: Takuki Kurokawa, Utkarsh Gupta
Editorialist: Nishank Suresh

DIFFICULTY:

2255

PREREQUISITES:

Basic combinatorics, Maximum subarray sum

PROBLEM:

You are given a binary string S. At most once, you may pick a substring of S and flip all its characters.

Define the value of a string as the number of 1's in a substring, summed over all of its substrings. If the substring to flip is chosen optimally, what is the maximum value that can be attained?

EXPLANATION:

Let’s first forget about the flipping operation and try to quickly compute the value for a given string. Iterating over all substrings takes at least \mathcal{O}(N^2) time, which is too slow.

The quantity we want to calculate is the count of 1's in each substring, summed across all substrings.
Let’s look at it from a slightly different perspective: how much does each 1 present in the string ‘contribute’ to the final answer?

Answer

Suppose S_i = 1. S_i then contributes +1 to the final answer exactly once for each subarray it is in.

A simple combinatorial argument tells us that the number of subarrays it is present in equals i\cdot (N-i+1): the left endpoint of the subarray has i choices (1, 2, 3, \ldots, i) and the right has N-i+1 choices (i, i+1, i+2, \ldots, N).

So, let’s define B_i = i\cdot(N-i+1). Then, the answer for S is simply the sum of B_i across all those positions i such that S_i = 1, which can easily be computed in \mathcal{O}(N).
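As a quick sanity check of the counting argument, here is a minimal Python sketch that compares the contribution formula against a direct enumeration of all substrings. The function names `total_ones` and `total_ones_brute` are hypothetical and chosen only for this illustration; the code uses 0-based indices, so B_i becomes (i+1)\cdot(N-i).

```python
def total_ones(s):
    """Sum of B_i over positions holding '1' (0-based: B_i = (i + 1) * (n - i))."""
    n = len(s)
    return sum((i + 1) * (n - i) for i, ch in enumerate(s) if ch == '1')


def total_ones_brute(s):
    """Count the 1's in every substring directly; only usable for small inputs."""
    n = len(s)
    return sum(s[l:r].count('1') for l in range(n) for r in range(l + 1, n + 1))


if __name__ == "__main__":
    for s in ["101", "0110", "000110000"]:
        print(s, total_ones(s), total_ones_brute(s))
```

For S = 101, the two 1's contribute 3 and 3 respectively, and both functions return 6.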

Now let’s look at how flipping can change things:

  • If S_i = 1, then flipping this position will decrease the answer by B_i (since it used to contribute to the sum, and won’t after the flip).
  • If S_i = 0, then flipping this position will increase the answer by B_i.

Flipping a substring is then equivalent to adding/subtracting the relevant values of B_i, which is essentially just a subarray sum!
In fact, suppose we define another array C as follows:

  • C_i = +B_i if S_i = 0
  • C_i = -B_i if S_i = 1

Then, it’s easy to see that flipping the range [L, R] in S changes the answer by exactly the subarray sum of C from L to R.
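The claim about C can be checked exhaustively on small strings. The sketch below is illustrative only: the helper names `value`, `gain_array`, `flip`, and `check` are assumptions made for this example, and indices are 0-based.

```python
def value(s):
    """Value of s: sum of (i + 1) * (n - i) over positions holding '1'."""
    n = len(s)
    return sum((i + 1) * (n - i) for i, ch in enumerate(s) if ch == '1')


def gain_array(s):
    """The array C: +B_i where s has '0', -B_i where s has '1'."""
    n = len(s)
    return [(i + 1) * (n - i) * (1 if ch == '0' else -1) for i, ch in enumerate(s)]


def flip(s, l, r):
    """Flip the characters of s on the 0-based inclusive range [l, r]."""
    middle = ''.join('1' if ch == '0' else '0' for ch in s[l:r + 1])
    return s[:l] + middle + s[r + 1:]


def check(s):
    """True if every flip changes the value by the matching subarray sum of C."""
    n = len(s)
    c = gain_array(s)
    return all(
        value(flip(s, l, r)) - value(s) == sum(c[l:r + 1])
        for l in range(n)
        for r in range(l, n)
    )


if __name__ == "__main__":
    print(gain_array("0110"))  # C for the string 0110
    print(check("0110"), check("000110000"))
```

For S = 0110 we have B = [4, 6, 6, 4], so C = [4, -6, -6, 4].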

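Of course, we want this change to be as large as possible, since our aim is to maximize the answer. This means we want to find the maximum subarray sum of C, which can be done in \mathcal{O}(N) in a variety of ways. Since the flip is optional, the empty subarray (with sum 0) is also allowed, so the best change is never negative.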
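One of those ways is Kadane's algorithm. The following is a minimal sketch of it, assuming the empty subarray is allowed because the flip may be skipped; `max_subarray_sum` is a hypothetical name used only here.

```python
def max_subarray_sum(c):
    """Maximum subarray sum of c, where the empty subarray (sum 0) is allowed."""
    best = cur = 0
    for x in c:
        # Either extend the current run or drop it and start over from empty.
        cur = max(cur + x, 0)
        best = max(best, cur)
    return best


if __name__ == "__main__":
    print(max_subarray_sum([4, -6, -6, 4]))
    print(max_subarray_sum([-3, -4, -3]))
```

On C = [4, -6, -6, 4] this returns 4 (flip a single 0), and on an all-negative array such as [-3, -4, -3] it returns 0 (do not flip at all).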

Thus, the final solution is:

  • Compute the B and C arrays as mentioned above.
  • Using B, compute the answer for S without flips.
  • Then, find the maximum subarray sum of C and add it to the answer.

TIME COMPLEXITY:

\mathcal{O}(N) per test case.

CODE:

Tester's code (C++)
#include <bits/stdc++.h>
using namespace std;
#ifdef tabr
#include "library/debug.cpp"
#else
#define debug(...)
#endif

struct input_checker {
    string buffer;
    int pos;

    const string all = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
    const string number = "0123456789";
    const string upper = "ABCDEFGHIJKLMNOPQRSTUVWXYZ";
    const string lower = "abcdefghijklmnopqrstuvwxyz";

    input_checker() {
        pos = 0;
        while (true) {
            int c = cin.get();
            if (c == -1) {
                break;
            }
            buffer.push_back((char) c);
        }
    }

    int nextDelimiter() {
        int now = pos;
        while (now < (int) buffer.size() && buffer[now] != ' ' && buffer[now] != '\n') {
            now++;
        }
        return now;
    }

    string readOne() {
        assert(pos < (int) buffer.size());
        int nxt = nextDelimiter();
        string res;
        while (pos < nxt) {
            res += buffer[pos];
            pos++;
        }
        // cerr << res << endl;
        return res;
    }

    string readString(int minl, int maxl, const string &pattern = "") {
        assert(minl <= maxl);
        string res = readOne();
        assert(minl <= (int) res.size());
        assert((int) res.size() <= maxl);
        for (int i = 0; i < (int) res.size(); i++) {
            assert(pattern.empty() || pattern.find(res[i]) != string::npos);
        }
        return res;
    }

    int readInt(int minv, int maxv) {
        assert(minv <= maxv);
        int res = stoi(readOne());
        assert(minv <= res);
        assert(res <= maxv);
        return res;
    }

    long long readLong(long long minv, long long maxv) {
        assert(minv <= maxv);
        long long res = stoll(readOne());
        assert(minv <= res);
        assert(res <= maxv);
        return res;
    }

    void readSpace() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == ' ');
        pos++;
    }

    void readEoln() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == '\n');
        pos++;
    }

    void readEof() {
        assert((int) buffer.size() == pos);
    }
};

int main() {
    input_checker in;
    int tt = in.readInt(1, 1e5);
    in.readEoln();
    int sn = 0;
    while (tt--) {
        string s = in.readString(1, 3e5, "01");
        in.readEoln();
        int n = (int) s.size();
        sn += n;
        vector<long long> a(n);
        long long t = 0;
        for (int i = 0; i < n; i++) {
            long long c1 = (n - i) * 1LL * (i + 1);
            if (s[i] == '0') {
                a[i] = c1;
            } else {
                t += c1;
                a[i] = -c1;
            }
        }
        vector<long long> pref(n + 1);
        for (int i = 0; i < n; i++) {
            pref[i + 1] = pref[i] + a[i];
        }
        long long mn = 0;
        long long ans = t;
        for (int i = 0; i < n + 1; i++) {
            mn = min(mn, pref[i]);
            ans = max(ans, t + pref[i] - mn);
        }
        cout << ans << '\n';
    }
    assert(sn <= 3e5);
    in.readEof();
    return 0;
}
Editorialist's code (Python)
for _ in range(int(input())):
	s = input()
	ans = mx = cur = 0
	for i in range(len(s)):
		subarrays = (i+1) * (len(s)-i)
		if s[i] == '1':
			ans += subarrays
			subarrays *= -1
		cur += subarrays
		cur = max(cur, 0)
		mx = max(mx, cur)
	print(ans + mx)

In my solution: CodeChef: Practical coding for everyone
See line 181.
When I use function1 (subarray) to find the maximum subarray sum, it fails the last 4 test cases, but when I use function2 (maxSubArraySum), it passes all test cases. Why is that?
Can anyone help, please?

Solved this problem using Dynamic Programming.
Here’s the link: Link
My solution is using a similar idea used here: Link
@kartik8800 orz

https://www.codechef.com/viewsolution/78994495

Can someone tell me what’s wrong with my approach?

  1. Each element's contribution is (i+1)*(n-i) (0-based indexing).
  2. I define dp[i] as the maximum sum we can get over [0…i] by flipping a subarray that starts at any index in [0…i] and ends at i.
  3. My answer is then max(dp[i]+suffix[i+1]) for i from 0 to n.
  4. Here suffix[i] is the sum we get from [i…n-1] without flipping.

long long was the culprit: I had used int in place of long long in some places.
The approach above works.
:disappointed:
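To see why 32-bit integers are not enough here, the following rough Python sketch computes the magnitudes involved. It assumes the maximum length N = 3 \cdot 10^5 that the tester's code checks for; the names `largest_contribution` and `sum_of_contributions` are hypothetical.

```python
INT32_MAX = 2**31 - 1
INT64_MAX = 2**63 - 1


def largest_contribution(n):
    """Largest single value of B_i = i * (n - i + 1) over 1 <= i <= n."""
    return max(i * (n - i + 1) for i in range(1, n + 1))


def sum_of_contributions(n):
    """Sum of all B_i, i.e. the value of the all-ones string of length n."""
    return n * (n + 1) * (n + 2) // 6


if __name__ == "__main__":
    n = 3 * 10**5  # assumed upper bound on the string length
    print(largest_contribution(n) > INT32_MAX)   # one term already overflows int
    print(sum_of_contributions(n) <= INT64_MAX)  # the total still fits in long long
```

A single B_i near the middle of the string is already about 2.25 \cdot 10^{10}, so every intermediate product and running sum needs a 64-bit type.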

I tried a slightly different approach: replace each 0 with 1 and each 1 with -1, find the maximum-sum subarray (the longest one if there are several), then simply flip that subarray and compute the value. It looks similar to the editorial. Can anyone tell me whether this approach is correct? I am getting a wrong answer.
Thanks :slight_smile:
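A small experiment suggests that the unweighted version is not equivalent to the editorial's, because positions near the middle of the string carry larger weights B_i than positions near the ends. The sketch below is only a reconstruction of the approach as described in the post, written as a quadratic brute force for small strings; `weighted_best`, `unweighted_best`, and `value` are hypothetical names.

```python
def value(s):
    """Value of s: sum of (i + 1) * (n - i) over positions holding '1'."""
    n = len(s)
    return sum((i + 1) * (n - i) for i, ch in enumerate(s) if ch == '1')


def weighted_best(s):
    """Editorial approach: base value plus maximum subarray sum of C."""
    n = len(s)
    base = best = cur = 0
    for i, ch in enumerate(s):
        w = (i + 1) * (n - i)
        if ch == '1':
            base += w
            w = -w
        cur = max(cur + w, 0)
        best = max(best, cur)
    return base + best


def unweighted_best(s):
    """Assumed reading of the post: +1 for '0', -1 for '1', flip the
    maximum-sum range (longest on ties), then evaluate the result."""
    n = len(s)
    a = [1 if ch == '0' else -1 for ch in s]
    best_key, best_range = (0, 0), None  # (sum, length); no flip by default
    for l in range(n):
        total = 0
        for r in range(l, n):
            total += a[r]
            if total > 0 and (total, r - l + 1) > best_key:
                best_key, best_range = (total, r - l + 1), (l, r)
    if best_range is None:
        return value(s)
    l, r = best_range
    middle = ''.join('1' if ch == '0' else '0' for ch in s[l:r + 1])
    return value(s[:l] + middle + s[r + 1:])


if __name__ == "__main__":
    s = "000110000"
    print(weighted_best(s), unweighted_best(s))
```

For S = 000110000 the unweighted array is [1, 1, 1, -1, -1, 1, 1, 1, 1], whose unique maximum is the whole string; flipping everything gives 116. The weighted version flips only the last four characters and gives 119.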

Can we solve it using DP?

@aadarshsinha
You were missing the equality case in your function: you handled only the > and < cases, not the = case. That is where it went wrong.

        if (nums[i] > nums[i] + currMax) {
            currMax = nums[i];
            startIndex = i; 
        }
        else if (nums[i] < nums[i] + currMax) {
            currMax = nums[i] + currMax;
        }

Here is the accepted code with the changes: https://www.codechef.com/viewsolution/79013058

All the given test cases pass (time complexity: O(n)), but on submission it shows an error. Can someone help me with this, please?

public static void main (String[] args) throws java.lang.Exception
{
// your code goes here
Scanner s=new Scanner(System.in);
int t=s.nextInt();

	while(t>0)
	{
	    String str=s.next();
	    int n=str.length();
	    
	    char[] arr=new char[n];
	    arr=str.toCharArray();
	    int k=0;
	    int[] ar=new int[n];
	    
	    for(char temp :arr)
	    {
	        ar[k++]=Integer.parseInt(String.valueOf(temp));
	    }
	    
	    for(int i=0;i<n;i++)
	    {
	        if(ar[i]==0)
	        {
	            ar[i]=1;
	        }
	        else{
	            ar[i]=-1;
	        }
	    }
	    int add=0;
	    int max=Integer.MIN_VALUE;
	    int indr=0;
	    int indl=0;
	     
	    for(int i=0;i<n;i++)
	    {
	       add+=ar[i];
	       if(add>max)
	       {
	           max=add;
	           indr=i;
	       }
	       if(add<0)
	       {
	           add=0;
	           indl=i+1;
	       }
	        
	    }
	    if(indl<n)
	    {
	        for(int i=indl;i<=indr;i++)
	        {
		        if(ar[i]==1)
		        {
		            ar[i]=-1;
		        }
		        else if(ar[i]==-1)
		        {
		            ar[i]=1;
		        }
	        }
	        
	        for(int i=0;i<n;i++)
	        {
		        
		        if(ar[i]==-1)
		        {
		            ar[i]=1;
		        }
		        else{
		            ar[i]=0;
		        }
	        }
	        
	        str="";
	        
	        for(int i=0;i<n;i++)
	        {
	            str=str+ar[i];
	        }
	        
	    }
	    
	    
	    int res=0;
	    
	    for(int i=0;i<n;i++)
	    {
	        if(str.charAt(i)=='1')
	        {
	            res+=(i+1)*(n-i);
	        }
	        
	    }
	    System.out.println(res);
	    t--;
	}

}

@nitinkumar1238
Thank you so much for figuring it out.