LENTMO - EDITORIAL

PROBLEM LINK:

Practice

Contest: Division 1

Contest: Division 2

Setter: Abhishek Vanjani

Tester: Radoslav Dimitrov

Editorialist: Teja Vardhan Reddy

DIFFICULTY:

Easy

PREREQUISITES:

Properties of XOR, greedy

PROBLEM:

We are given an array A of n numbers and integers k and x. We may perform the following operation any number of times (including zero times): choose exactly k distinct elements of the array and replace each chosen element A_i with A_i \oplus x. We wish to maximise the sum of the elements in the array.

EXPLANATION

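Case 1: k = n. Every operation XORs the whole array, and two operations cancel each other, so only two states are reachable: the original array and the array with every element XORed with x. We compare the sums in both cases and take the larger.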
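A minimal Python sketch of Case 1 (Python because the tester's solution uses it; the function name `best_sum_k_equals_n` is hypothetical):

```python
def best_sum_k_equals_n(a, x):
    """Case k == n: each operation XORs every element, and two operations
    cancel out, so only two arrays are reachable."""
    original = sum(a)
    all_xored = sum(v ^ x for v in a)
    return max(original, all_xored)


print(best_sum_k_equals_n([1, 2, 3], 4))  # 6 vs 5 + 6 + 7 = 18 -> 18
```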

Case 2: k \lt n. This case is more interesting; we prove a few facts before getting to the solution.

Claim 1: We can always perform operations so that exactly two chosen elements are XORed with x while all other elements remain unchanged.

Proof: We give a construction. Without loss of generality, let the two elements be A_1 and A_2. The construction uses k + 1 elements, which is possible because k \lt n.

  1. Take the subset \{A_1, A_3, A_4, \ldots, A_{k+1}\} (k elements) and apply the operation to it.
  2. Take the subset \{A_2, A_3, A_4, \ldots, A_{k+1}\} (also k elements) and apply the operation to it.

Every element of \{A_3, A_4, \ldots, A_{k+1}\} is XORed with x twice, which cancels out because y \oplus x \oplus x = y, while A_1 and A_2 are XORed exactly once. Hence, only A_1 and A_2 end up XORed with x.

Applying this construction to one pair at a time, we can XOR any even-sized subset with x.
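The construction of Claim 1 can be simulated directly. This is a sketch under the assumption k \lt n; the helper names (`apply_op`, `flip_pair`, `flip_even_subset`) are hypothetical, and the k - 1 "filler" positions are simply the first indices different from the chosen pair.

```python
def apply_op(arr, idxs, k, x):
    """One operation: XOR exactly k distinct positions of arr with x, in place."""
    assert len(idxs) == k and len(set(idxs)) == k
    for p in idxs:
        arr[p] ^= x


def flip_pair(arr, i, j, k, x):
    """Claim 1 (requires k < len(arr)): XOR only arr[i] and arr[j].
    Both operations share the same k - 1 filler positions, which are
    therefore XORed twice and end up unchanged."""
    assert k < len(arr) and i != j
    fillers = [p for p in range(len(arr)) if p != i and p != j][: k - 1]
    apply_op(arr, [i] + fillers, k, x)
    apply_op(arr, [j] + fillers, k, x)


def flip_even_subset(arr, subset, k, x):
    """XOR an even-sized subset by handling it one pair at a time."""
    assert len(subset) % 2 == 0
    for t in range(0, len(subset), 2):
        flip_pair(arr, subset[t], subset[t + 1], k, x)


a = [1, 2, 3, 4, 5, 6]
flip_even_subset(a, [0, 3], k=4, x=7)
print(a)  # [6, 2, 3, 3, 5, 6]
```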

Case a: k is even.

We now prove that, when k is even, the number of elements XORed with x is even at every stage.
Proof: by induction on the number of operations.
Base case: initially no element is XORed, and zero is even.

Inductive step: assume that after i operations exactly y elements are XORed with x, where y is even. Suppose the next operation picks z of these y elements and the remaining k - z elements from the other n - y. The z elements are XORed a second time and return to their original values, while the k - z elements are XORed with x for the first time. The new count is (y - z) + (k - z) = y + k - 2z, which is even because y and k are both even.

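Hence, when k is even and k \lt n, the reachable arrays are exactly those in which some even-sized subset of elements is XORed with x: Claim 1 shows that every even-sized subset can be reached, and the induction shows that nothing else can.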
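For small n, the claims can be sanity-checked by brute force. This sketch (hypothetical names, small inputs only, a check rather than a proof) runs a BFS over bitmasks where bit i means "position i is currently XORed with x"; one operation XORs the mask with some mask of exactly k bits.

```python
from collections import deque
from itertools import combinations


def reachable_masks(n, k):
    """BFS over bitmasks: bit i is set when position i is currently XORed
    with x. One operation XORs the mask with a mask of exactly k bits."""
    moves = [sum(1 << i for i in c) for c in combinations(range(n), k)]
    seen = {0}
    queue = deque([0])
    while queue:
        m = queue.popleft()
        for mv in moves:
            nxt = m ^ mv
            if nxt not in seen:
                seen.add(nxt)
                queue.append(nxt)
    return seen


def popcount(m):
    return bin(m).count("1")


even = {m for m in range(1 << 6) if popcount(m) % 2 == 0}
print(reachable_masks(6, 4) == even)               # k even, k < n: True
print(len(reachable_masks(6, 3)) == 1 << 6)        # k odd, k < n: True
print(reachable_masks(6, 6) == {0, (1 << 6) - 1})  # k == n: True
```

The second line anticipates Case b below: with k odd and k \lt n, every subset is reachable.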

We want to XOR only the elements whose value rises when XORed with x, but the number of XORed elements must be even. Define the rise of element i as (A_i \oplus x) - A_i. Sort the elements by rise in decreasing order, pair adjacent elements from the highest rise to the lowest (1st with 2nd, 3rd with 4th, and so on), and XOR a pair only if its total rise is positive. Finally, take the sum of the resulting array.
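A minimal Python sketch of this greedy, assuming k is even and k \lt n (the function name `best_sum_even_k` is hypothetical; it returns the sum directly instead of building the array):

```python
def best_sum_even_k(a, x):
    """Case a (k even, k < n): only even-sized subsets can be XORed.
    Sort rises in decreasing order, pair neighbours, keep positive pairs."""
    rises = sorted(((v ^ x) - v for v in a), reverse=True)
    total = sum(a)
    for t in range(0, len(rises) - 1, 2):  # pairs (0,1), (2,3), ...
        pair = rises[t] + rises[t + 1]
        if pair > 0:
            total += pair
    return total


print(best_sum_even_k([0, 0, 0, 1, 1, 1], 1))  # 5: only two of the three 0s can flip
```

Because the rises are sorted in decreasing order, the pair sums are non-increasing, so the loop could also stop at the first non-positive pair.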

Case b: k is odd.
We now prove that we can XOR any single chosen element with x while leaving the rest unchanged.

Proof: Suppose we want to XOR only A_1 with x.

By Claim 1, we can XOR all of \{A_1, A_2, \ldots, A_{k+1}\} with x, since k + 1 is even (and k + 1 \le n because k \lt n).

Then apply one operation to the k elements \{A_2, A_3, \ldots, A_{k+1}\}. These are now XORed twice and return to their original values, so only A_1 remains XORed with x.
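The two steps above can be simulated as follows. This is a sketch assuming k is odd and k \lt n; the helpers are hypothetical and repeat the Claim 1 construction so the block runs on its own.

```python
def apply_op(arr, idxs, k, x):
    """One operation: XOR exactly k distinct positions of arr with x, in place."""
    assert len(idxs) == k and len(set(idxs)) == k
    for p in idxs:
        arr[p] ^= x


def flip_pair(arr, i, j, k, x):
    """Claim 1: XOR only arr[i] and arr[j] using k - 1 shared fillers."""
    fillers = [p for p in range(len(arr)) if p != i and p != j][: k - 1]
    apply_op(arr, [i] + fillers, k, x)
    apply_op(arr, [j] + fillers, k, x)


def flip_single(arr, i, k, x):
    """Case b (k odd, k < n): XOR only arr[i]."""
    assert k % 2 == 1 and k < len(arr)
    others = [p for p in range(len(arr)) if p != i][:k]
    group = [i] + others              # k + 1 positions: an even count
    for t in range(0, k + 1, 2):      # XOR the whole group, pair by pair
        flip_pair(arr, group[t], group[t + 1], k, x)
    apply_op(arr, others, k, x)       # others now XORed twice: unchanged


a = [0, 0, 0, 1, 1, 1]
flip_single(a, 0, k=5, x=1)
print(a)  # [1, 0, 0, 1, 1, 1]
```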

Repeating this for each element we want, we can XOR any subset of the array, since each single-element flip leaves every other element unchanged. So we XOR exactly those elements with a positive rise and take the sum of the resulting array.
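Since any subset is reachable here, each element independently takes its larger value. A one-line Python sketch (hypothetical name, assuming k is odd and k \lt n):

```python
def best_sum_odd_k(a, x):
    """Case b (k odd, k < n): any subset can be XORed, so every element
    independently takes the larger of v and v ^ x."""
    return sum(max(v, v ^ x) for v in a)


print(best_sum_odd_k([0, 0, 0, 1, 1, 1], 1))  # 6
```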

TIME COMPLEXITY

Complexity: O(n \log n), since we sort in the case where k is even; all other steps are O(n).
Hence, the total complexity is O(n \log n).

We can reduce the complexity to O(n) by avoiding the sort; I leave that for you to figure out. Keep the comments flowing with O(n) ideas :slight_smile:
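One way to avoid the sort, sketched in Python under the assumption 1 \le k \le n (the name `best_sum` is hypothetical): every element takes the larger of v and v \oplus x, and the only obstruction is the Case a parity, when k is even and the number of positive rises is odd. Then exactly one element must end on its smaller value, and the cheapest choice is the element with the smallest |rise|. This matches the idea in the setter's solution below.

```python
def best_sum(a, k, x):
    """O(n) sketch covering all cases; assumes 1 <= k <= n = len(a)."""
    n = len(a)
    if k == n:                        # Case 1: all or nothing
        return max(sum(a), sum(v ^ x for v in a))
    rises = [(v ^ x) - v for v in a]
    total = sum(max(v, v ^ x) for v in a)
    positives = sum(1 for r in rises if r > 0)
    if k % 2 == 0 and positives % 2 == 1:
        # Parity forces one element onto its smaller value.
        total -= min(abs(r) for r in rises)
    return total


print(best_sum([0, 0, 0, 1, 1, 1], 5, 1))  # 6
print(best_sum([0, 0, 0, 1, 1, 1], 4, 1))  # 5
```

Taking the minimum |rise| over all elements covers both options at once: dropping the smallest positive rise, or adding the largest non-positive one.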

SOLUTIONS:

Setter's Solution
#include <bits/stdc++.h>
using namespace std;
 
int main() {
    
    ios_base::sync_with_stdio(false);
    cin.tie(NULL);
    srand(time(NULL));
    int T;
    cin>>T;
    while(T--)
    {
        int N;
        cin>>N;
        long long arr[N];
        long long ans=0;
        for(int i=0;i<N;i++)
        {
            cin>>arr[i];
            ans+=arr[i];
        }
        int K;
        long long X;
        cin>>K>>X;
 
        ///We have two choices, either perform XOR on entire array, or leave the array untouched.
        if(K==N)
        {
            long long ans1=0;
            for(int i=0;i<N;i++)
                ans1+=(arr[i]^X);
            cout<<max(ans,ans1)<<endl;
            continue;
        }
        long long diff[N];
        int gain=0;
        ///gain is number of elements that increase when we xor them with X.
        for(int i=0;i<N;i++)
        {
            long long xorvalue=(arr[i]^X);
            diff[i]=xorvalue-arr[i];
            if(diff[i]>0)
            {
                ans+=diff[i];
                gain++;
            }
        }
        ///It can be proven that every element can reach its maximum if gain%2==0 or (gain%2!=0 && K%2!=0). In the remaining case, N-1 elements can be set to their maxima
        ///and one element cannot. We pick the element whose loss decreases ans the least, i.e. the minimum |diff[i]|.
        if(gain%2!=0 && K%2==0)
        {
            long long x=1000000000000;
            for(int i=0;i<N;i++)
                x=min(x,abs(diff[i]));
            ans-=x;
        }
        cout<<ans<<endl;
 
 
    }
} 
Tester's Solution
import sys
 
def read_line():
    return sys.stdin.readline()[:-1]
 
def read_int():
    return int(sys.stdin.readline())
 
def read_int_line():
    return [int(v) for v in sys.stdin.readline().split()]
 
############
# Solution #
 
T = read_int()
for _test in range(T):
    N = read_int()
    A = read_int_line()
    K = read_int()
    X = read_int()
    
    # Corner case 
    if X == 0 or K == N:
        ans1 = 0
        ans2 = 0
        for v in A:
            ans1 += v
            ans2 += v ^ X
 
        print(max(ans1, ans2))
        continue
 
    # We can prove that the answer only depends on the parity of K
    K %= 2
 
    ans = 0
    cnt = 0
    for v in A:
        val = max(v, v ^ X)
        ans += val
        if val == (v ^ X):
            cnt += 1
    
    # We must either remove, or add one number to the group that was XOR-ed 
    if K == 0 and cnt % 2 == 1:
        rem = 10**18
        for x in A:
            v = ((x ^ X) - x)       
            if v >= 0:
                rem = min(rem, v)
 
        add = -10**18
        for x in A:
            v = ((x ^ X) - x)       
            if v <= 0:
                add = max(add, v)
        
        ans = max(ans - rem, ans + add)
 
    print(ans)
Editorialist's Solution
//teja349
#include <bits/stdc++.h>
#include <vector>
#include <set>
#include <map>
#include <string>
#include <cstdio>
#include <cstdlib>
#include <climits>
#include <utility>
#include <algorithm>
#include <cmath>
#include <queue>
#include <stack>
#include <iomanip>
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp> 
//setbase - cout << setbase (16); cout << 100 << endl; Prints 64
//setfill -   cout << setfill ('x') << setw (5); cout << 77 << endl; prints xxx77
//setprecision - cout << setprecision (14) << f << endl; Prints x.xxxx
//cout.precision(x)  cout<<fixed<<val;  // prints x digits after decimal in val
 
using namespace std;
using namespace __gnu_pbds;
 
#define f(i,a,b) for(i=a;i<b;i++)
#define rep(i,n) f(i,0,n)
#define fd(i,a,b) for(i=a;i>=b;i--)
#define pb push_back
#define mp make_pair
#define vi vector< int >
#define vl vector< ll >
#define ss second
#define ff first
#define ll long long
#define pii pair< int,int >
#define pll pair< ll,ll >
#define sz(a) a.size()
#define inf (1000*1000*1000+5)
#define all(a) a.begin(),a.end()
#define tri pair<int,pii>
#define vii vector<pii>
#define vll vector<pll>
#define viii vector<tri>
#define mod (1000*1000*1000+7)
#define pqueue priority_queue< int >
#define pdqueue priority_queue< int,vi ,greater< int > >
#define flush fflush(stdout) 
#define primeDEN 727999983
mt19937 rng(chrono::steady_clock::now().time_since_epoch().count());
 
// find_by_order()  // order_of_key
typedef tree<
int,
null_type,
less<int>,
rb_tree_tag,
tree_order_statistics_node_update>
ordered_set;
 
#define int ll
int a[123456];
signed main(){
    std::ios::sync_with_stdio(false); cin.tie(NULL);
    int t;
    cin>>t;
    while(t--){
        int n;
        cin>>n;
        int i;
        ll sumi=0,sum1=0;
        rep(i,n){
            cin>>a[i];
        }
        int k,x;
        cin>>k>>x;
        vi vec;
        rep(i,n){
            sumi+=a[i];
            sum1+=a[i]^x;
            vec.pb(a[i]-(a[i]^x));
        }
        if(k==n){
            cout<<max(sumi,sum1)<<endl;
            continue;
        }
        int val;
        sort(all(vec));
        if(k%2){
            rep(i,vec.size()){
                val=vec[i];
                val*=-1;
                if(val>0)
                    sumi+=val;
            }
        }
        else{
            for(i=0;2*i+1<vec.size();i++){
                val=vec[2*i]+vec[2*i+1];
                val*=-1;
                if(val>0)
                    sumi+=val;
            }
 
 
        }
        cout<<sumi<<endl;
 
    }
    return 0;   
}

Feel free to share your approach if it differs. Suggestions are always welcome :slight_smile:


I used a slightly different approach. For each integer in the given array, I stored v[i]^x - v[i]. The idea is that the sum of all integers in the given array can always be obtained, and if we replace an element with its XOR, we simply add v[i]^x - v[i] to the sum. Then I used greedy to choose which elements should be XORed, depending on the parity of (v[i]^x - v[i]).
Here is a link to my accepted solution:


I did something like this. Instead of using greedy, I sorted the differences between the XOR value and the original value (in an array D[]). Then, in a loop, I took the k greatest remaining elements and checked whether their sum was positive. If it was positive, I added it to the original sum of elements; if it was negative, I ended the loop and printed the answer. I still got WA. Can you tell me what I did wrong?
https://www.codechef.com/viewsolution/24843882


How is the answer for this test case 6?

1
6
0 0 0 1 1 1
5
1

0 0 0 1 1 1
//doing first time
1 1 1 0 0 1
//doing second time
0 0 0 1 0 0
// doing third time
1 1 1 1 1 1


As k is odd, you can XOR any individual element you wish. So you can XOR all the zeros and leave the ones. Hence, the answer becomes 6.
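A concrete sequence of operations for this test case (a worked check, not taken from the thread): with n = 6 and k = 5, each operation leaves out exactly one element. Leaving out each of the three 1s once XORs every 0 three times (so it becomes 1) and every 1 twice (so it stays 1). The helper `apply_op` is hypothetical.

```python
def apply_op(arr, idxs, x):
    """XOR every listed position of arr with x, in place."""
    for p in idxs:
        arr[p] ^= x


a = [0, 0, 0, 1, 1, 1]
n, k, x = 6, 5, 1
for skip in (3, 4, 5):  # each operation leaves out one of the 1s
    chosen = [p for p in range(n) if p != skip]
    assert len(chosen) == k
    apply_op(a, chosen, x)
    print(a)
print(sum(a))
# [1, 1, 1, 1, 0, 0]
# [0, 0, 0, 0, 0, 1]
# [1, 1, 1, 1, 1, 1]
# 6
```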

I see my mistake now. By taking K elements at a time, I was missing elements with a positive rise (whenever the group sum was negative), which could still have been XORed using the editorial's technique.

Well, I have an O(n) solution; take a look:

https://www.codechef.com/viewsolution/24628476

My approach, in O(N): https://www.youtube.com/watch?v=yMuaIjKlTfg&lc=UgyKLEMrUKk3l3nVuo14AaABAg


According to case b, we proved that a single element can be XORed (i.e. it is possible to XOR a single bag and leave the rest as they are). But how do we prove that we can XOR any number of bags when k is odd? In other words, suppose n = 20 and k = 3; how do we prove that we can XOR 1, 2, 3, …, or 20 bags and leave the rest as they are? If I missed something in the editorial, please tell me.


Can you please explain the second case, when k is odd?


I used the same logic but got WA because I missed a corner case :expressionless:


Test case 2 (other cases as you mentioned):
a[i] is 0 or 1, x = 1 and k < n, so the input is just a sequence of 0s and 1s.
Let z be the number of zeroes.
if (k is odd or z is even)
    ans = n
else
    ans = n - 1

For the other test cases:
Build an array p where, for each i, p[i][0] is the smaller and p[i][1] the larger of a[i] and a[i]^x (0 denotes the lower value and 1 the higher one :)).
gs = sum of p[i][1] over all i
In the case corresponding to ans = n (every element can take its larger value), gs is the answer.
In the case corresponding to ans = n - 1 (exactly one element must keep its smaller value), the answer is the maximum over all i of
gs - p[i][1] + p[i][0]
O(n) solution :slight_smile:

Nice editorial
:slight_smile:

Case b: when k is odd, how is he XORing k+1 elements and then XORing k elements in the next step?

    #include<iostream>
    #include<stdio.h>
    #include<vector>
    #include <stack>
    #include <bits/stdc++.h>

    #define rep(i,n) for(int i=0;i<n;i++)
    #define repA(i,a,n) for(int i=a;i<=n;i++)
    #define repD(i,a,n) for(int i=a;i>=n;i--)
    #define ll long long int
    #define fi first
    #define se second

    using namespace std;

    int main()
    {
    	ios_base::sync_with_stdio(false);
    	cin.tie(NULL);
    	int t;cin>>t;
    	while(t--)
    	{
    		int n;cin>>n;
    		ll a[n],ch=0;
    		rep(i,n) cin>>a[i];
    		rep(i,n) ch+=a[i];// sum of all elements
    		int k,x;cin>>k>>x;

    		int ct=0;
    		ll b[n];
    		rep(i,n)
    		{
    			b[i]=((a[i])^x)-a[i];
    			if(b[i]>0) ct++;// counts the number which are increasing after taking XOR
    		}

    		sort(b,b+n);
    		int end=n-1,j=0;
    		while(j<ct)
    		{
    			ch+=b[end-j];
    			if(j!=ct-1) b[end-j]=b[end-j]*(-1);
    			j++;
    		}
    		sort(b,b+n);

    		if((ct%2==0) || (ct%2 && k%2))  /* if ct is EVEN and if both ct & k are ODD, we can take all increasing XOR*/
    			printf("%lld\n",ch);
    		else
    		{
    			if(b[end]+b[end-1]>0) ch=ch+b[end-1];
    			else ch=ch-b[end];
    			printf("%lld\n",ch);
    		}
    	}
    	return 0;
    }

Please help: why is my code not passing all the test cases?
I store the change in each number after taking its XOR. If the number of positive changes is even, we can take all of them; if both the number of positive changes and k are odd, we can also take all of them.
In the remaining cases, we take (total number of positive changes) - 1 of them, or we take the one remaining positive change plus the negative change closest to zero, provided their sum is positive.
Thanks in advance.

I did something similar; how did you overcome it? Can you explain in a little more detail?

I used the same logic, but there is something wrong in my code. Can you tell me where I went wrong? Here is my code: https://www.codechef.com/viewsolution/24877754

Hi, I think you missed a few cases. Suppose n = k = 3, and the first two elements of vec are positive while the third is negative. Your code picks the first two and not the last, but you actually have to pick all three, since with k = n you have no choice.
Correct me if I am wrong.

But in the editorial, @teja349 proves that we can take two elements instead of k elements,
so I take two elements each time.