PROBLEM LINK:
Setter: Abhishek Vanjani
Tester: Radoslav Dimitrov
Editorialist: Teja Vardhan Reddy
DIFFICULTY:
Easy
PREREQUISITES:
Properties of XOR, greedy
PROBLEM:
Given an array A of n integers and two integers k and x, we may perform the following operation any number of times (possibly zero): choose exactly k distinct elements of the array and replace each chosen element A_i with A_i \oplus x. We wish to maximise the sum of the elements of the array.
EXPLANATION
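Case 1: k = n. Every operation XORs the whole array, and XOR-ing with x twice cancels out, so only two final arrays are possible: the original array and the array with every element XORed with x. We output the larger of the two sums.
Case 2: k \lt n. This is the interesting case. We first prove a few facts before getting to the solution.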
Claim 1: Using two operations, we can XOR exactly two chosen elements with x while leaving every other element unchanged.
Proof: We give a construction. Without loss of generality, let the two elements be A_1 and A_2. Since k \lt n, the elements A_1, \dots, A_{k+1} all exist.
- Apply the operation to the subset \{A_1,A_3,A_4,\dots,A_{k+1}\}, which has exactly k elements.
- Apply the operation to the subset \{A_2,A_3,A_4,\dots,A_{k+1}\}, which also has exactly k elements.
Now every element of \{A_3,A_4,\dots,A_{k+1}\} has been XORed with x twice, which cancels out because (a \oplus x) \oplus x = a, while A_1 and A_2 have been XORed once each. Hence only A_1 and A_2 end up XORed with x.
Using this idea, we can XOR any even-sized subset with x: split the subset into pairs and apply the construction to each pair.
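The construction can be checked with a small simulation. This is a minimal sketch, assuming 0-based indices and `long long` values; `xorPair` and `xorEvenSubset` are hypothetical helper names, not taken from the solutions below. `xorPair` uses the first k-1 indices other than p and q as the shared helpers that play the role of A_3, \dots, A_{k+1}.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Claim 1 (0-based): XOR only a[p] and a[q] using two operations of size k.
// The shared helpers are the first k-1 indices other than p and q (needs k < n).
void xorPair(std::vector<long long>& a, int p, int q, int k, long long x) {
    int n = static_cast<int>(a.size());
    assert(p != q && k < n);
    std::vector<int> helpers;  // k-1 indices distinct from p and q
    for (int i = 0; i < n && static_cast<int>(helpers.size()) < k - 1; ++i)
        if (i != p && i != q) helpers.push_back(i);
    std::vector<int> first = helpers, second = helpers;
    first.push_back(p);   // {p} plus helpers: exactly k indices
    second.push_back(q);  // {q} plus helpers: exactly k indices
    for (int i : first) a[i] ^= x;
    for (int i : second) a[i] ^= x;  // helpers are XORed twice and cancel out
}

// Any even-sized set of distinct indices: pair them up and use xorPair on each pair.
void xorEvenSubset(std::vector<long long>& a, const std::vector<int>& idx, int k, long long x) {
    assert(idx.size() % 2 == 0);
    for (std::size_t j = 0; j + 1 < idx.size(); j += 2)
        xorPair(a, idx[j], idx[j + 1], k, x);
}
```

For example, with a = {1, 2, 3, 4, 5}, k = 3 and x = 6, `xorPair(a, 0, 1, 3, 6)` changes only the first two entries.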
Case a: k is even.
We prove that, at every stage, an even number of elements are XORed with x (an element counts as XORed if it has been chosen an odd number of times).
Proof: By induction on the number of operations.
Base case: Initially no element is XORed, and zero is even.
Inductive step: Assume that after i operations exactly y elements are XORed with x, where y is even. Suppose the next operation picks z of these y elements and the remaining k - z elements from the other n - y. The z elements have now been XORed twice, so they return to their original values, while the k - z new elements become XORed for the first time. The new count is (y - z) + (k - z) = y + k - 2z, which is even because y and k are both even.
Hence, when k is even and less than n, we can XOR exactly the even-sized subsets with x: every even-sized subset is reachable by Claim 1, and no odd-sized subset is reachable by the invariant.
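To sanity-check the parity argument on small inputs, one can brute-force every reachable set of XORed elements. This sketch is an illustration, not part of any reference solution: `reachableMasks` is a hypothetical name, a state is a bitmask over n \le 20 elements marking which ones are currently XORed, and each operation toggles exactly k bits.

```cpp
#include <bitset>
#include <cassert>
#include <queue>
#include <vector>

// BFS over bitmasks: bit i is set iff element i has been chosen an odd number
// of times. One operation XORs the state with a mask that has exactly k bits.
std::vector<bool> reachableMasks(int n, int k) {
    assert(1 <= n && n <= 20 && 1 <= k && k <= n);
    std::vector<int> moves;  // all masks with exactly k bits set
    for (int m = 0; m < (1 << n); ++m)
        if (static_cast<int>(std::bitset<32>(m).count()) == k) moves.push_back(m);
    std::vector<bool> seen(1 << n, false);
    std::queue<int> todo;
    seen[0] = true;  // initially nothing is XORed
    todo.push(0);
    while (!todo.empty()) {
        int s = todo.front();
        todo.pop();
        for (int m : moves) {
            int t = s ^ m;
            if (!seen[t]) {
                seen[t] = true;
                todo.push(t);
            }
        }
    }
    return seen;
}
```

For k even and k \lt n it should mark exactly the even-sized masks as reachable; for k odd and k \lt n, every mask; for k = n, only the empty mask and the full mask.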
We want to XOR the elements whose value increases when XORed with x. Define the gain of an element as (A_i \oplus x) - A_i and sort the elements by gain in decreasing order. Pair adjacent elements starting from the largest gain, i.e. (1st, 2nd), (3rd, 4th), and so on, and take every pair whose total gain is positive. Since the pair totals are non-increasing, the taken pairs form the best even-length prefix, which is the best even-sized subset. Finally, output the sum of the resulting array.
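A minimal sketch of the sort-and-pair greedy for this case, assuming the values fit in `long long`; `bestEvenK` is a hypothetical name, and the caller is assumed to have already handled k = n.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <functional>
#include <vector>

// Case a (k even, k < n): only even-sized subsets can be XORed.
// gain[i] = (a[i] ^ x) - a[i]; sort gains in decreasing order and add the
// total of each adjacent pair (g0,g1), (g2,g3), ... when it is positive.
long long bestEvenK(const std::vector<long long>& a, int k, long long x) {
    assert(k % 2 == 0 && k < static_cast<int>(a.size()));
    std::vector<long long> gain;
    long long sum = 0;
    for (long long v : a) {
        sum += v;
        gain.push_back((v ^ x) - v);
    }
    std::sort(gain.begin(), gain.end(), std::greater<long long>());
    for (std::size_t i = 0; i + 1 < gain.size(); i += 2) {
        long long pairGain = gain[i] + gain[i + 1];
        if (pairGain > 0) sum += pairGain;  // pair totals never increase, so this is a prefix
    }
    return sum;
}
```

Because the gains are sorted, the loop could also stop at the first non-positive pair; the plain `if` mirrors the editorialist's solution below.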
Case b: k is odd.
We prove that we can XOR any single element with x while leaving the rest unchanged.
Proof: Without loss of generality, suppose we want to XOR only A_1.
Since k+1 is even and k+1 \le n, Claim 1 lets us XOR A_1,A_2,\dots,A_{k+1} with x.
Now apply the operation to the k elements A_2,A_3,\dots,A_{k+1}. They have now been XORed twice and return to their original values, so only A_1 is XORed with x and all other elements are unchanged.
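The two steps of this proof can be written out as an explicit list of operations. This is a sketch assuming 0-based indices with k odd and k \lt n; `Op`, `pairOps`, `singleOps` and `applyOps` are hypothetical names. The even-sized set \{A_1, \dots, A_{k+1}\} is XORed by pairing its elements with the Claim 1 construction, and one final operation on the other k elements cancels them out.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

using Op = std::vector<int>;  // the k distinct indices chosen in one operation

// Claim 1 (0-based): two operations that XOR only indices p and q.
// The shared helpers are the first k-1 indices other than p and q (needs k < n).
std::vector<Op> pairOps(int n, int k, int p, int q) {
    Op helpers;
    for (int i = 0; i < n && static_cast<int>(helpers.size()) < k - 1; ++i)
        if (i != p && i != q) helpers.push_back(i);
    Op first = helpers, second = helpers;
    first.push_back(p);
    second.push_back(q);
    return {first, second};
}

// Case b (k odd, k < n): operations whose combined effect XORs only index p.
std::vector<Op> singleOps(int n, int k, int p) {
    assert(k % 2 == 1 && k < n && 0 <= p && p < n);
    Op h;  // k indices other than p: the role of A_2, ..., A_{k+1}
    for (int i = 0; i < n && static_cast<int>(h.size()) < k; ++i)
        if (i != p) h.push_back(i);
    Op group = h;  // {p} together with h: k + 1 indices, an even count
    group.insert(group.begin(), p);
    std::vector<Op> ops;
    for (std::size_t j = 0; j + 1 < group.size(); j += 2) {
        std::vector<Op> two = pairOps(n, k, group[j], group[j + 1]);
        ops.insert(ops.end(), two.begin(), two.end());
    }
    ops.push_back(h);  // h is now XORed twice, so only p stays XORed
    return ops;
}

// Apply the operations to a copy of a.
std::vector<long long> applyOps(std::vector<long long> a, const std::vector<Op>& ops, long long x) {
    for (const Op& op : ops)
        for (int i : op) a[i] ^= x;
    return a;
}
```

For n = 5, k = 3 and p = 2 this produces five operations of size 3, and applying them changes only the element at index 2.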
So we XOR exactly the elements whose value increases when XORed with x; every element then attains \max(A_i, A_i \oplus x).
Finally, output the sum of the resulting array.
TIME COMPLEXITY
Complexity: O(n \log n), dominated by the sort in the case where k is even; all other steps are O(n).
Hence, the total complexity is O(n \log n).
The sort can be avoided to bring the complexity down to O(n); I leave this for you to figure out. Keep the comments flowing with your O(n) ideas!
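One way to drop the sort, which is also what the setter's solution below does: for even k, taking every positive gain is optimal when the number of positive gains is even. Otherwise we must either leave out the smallest positive gain or add the largest non-positive one, and each option costs exactly the absolute value of that gain, so we subtract the minimum |gain| over all elements. A sketch assuming `long long` values and 1 \le k \le n; `solveLinear` is a hypothetical name.

```cpp
#include <algorithm>
#include <cassert>
#include <climits>
#include <cstdlib>
#include <vector>

// O(n) answer for every case, without sorting.
long long solveLinear(const std::vector<long long>& a, int k, long long x) {
    int n = static_cast<int>(a.size());
    assert(1 <= k && k <= n);
    long long plain = 0, flipped = 0, best = 0;
    int positive = 0;                 // elements with a strictly positive gain
    long long minAbsGain = LLONG_MAX; // smallest |(a[i] ^ x) - a[i]|
    for (long long v : a) {
        long long g = (v ^ x) - v;
        plain += v;
        flipped += v ^ x;
        best += std::max(v, v ^ x);
        if (g > 0) ++positive;
        minAbsGain = std::min(minAbsGain, std::llabs(g));
    }
    if (k == n) return std::max(plain, flipped);       // Case 1
    if (k % 2 == 1 || positive % 2 == 0) return best;  // every positive gain can be taken
    // k even, odd number of positive gains: drop the smallest positive gain
    // or add the largest non-positive one, whichever costs less.
    return best - minAbsGain;
}
```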
SOLUTIONS:
Setter's Solution
#include <bits/stdc++.h>
using namespace std;
int main() {
    ios_base::sync_with_stdio(false);
    cin.tie(NULL);
    int T;
    cin>>T;
    while(T--)
    {
        int N;
        cin>>N;
        vector<long long> arr(N);
        long long ans=0;
        for(int i=0;i<N;i++)
        {
            cin>>arr[i];
            ans+=arr[i];
        }
        int K;
        long long X;
        cin>>K>>X;
        ///We have two choices: either perform XOR on the entire array, or leave the array untouched.
        if(K==N)
        {
            long long ans1=0;
            for(int i=0;i<N;i++)
                ans1+=(arr[i]^X);
            cout<<max(ans,ans1)<<endl;
            continue;
        }
        vector<long long> diff(N);
        int gain=0;
        ///gain is the number of elements that increase when we XOR them with X.
        for(int i=0;i<N;i++)
        {
            long long xorvalue=(arr[i]^X);
            diff[i]=xorvalue-arr[i];
            if(diff[i]>0)
            {
                ans+=diff[i];
                gain++;
            }
        }
        ///It can be proven that every element can reach its maximum if gain%2==0 or (gain%2!=0 && K%2!=0). In the remaining case,
        ///N-1 elements can be set to their maxima and one element is left over. We pick the element that decreases the answer the least.
        if(gain%2!=0 && K%2==0)
        {
            long long x=LLONG_MAX;
            for(int i=0;i<N;i++)
                x=min(x,abs(diff[i]));
            ans-=x;
        }
        cout<<ans<<endl;
    }
}
Tester's Solution
import sys

def read_line():
    return sys.stdin.readline()[:-1]

def read_int():
    return int(sys.stdin.readline())

def read_int_line():
    return [int(v) for v in sys.stdin.readline().split()]

############
# Solution #

T = read_int()
for _test in range(T):
    N = read_int()
    A = read_int_line()
    K = read_int()
    X = read_int()

    # Corner case
    if X == 0 or K == N:
        ans1 = 0
        ans2 = 0
        for v in A:
            ans1 += v
            ans2 += v ^ X
        print(max(ans1, ans2))
        continue

    # We can prove that the answer only depends on the parity of K
    K %= 2
    ans = 0
    cnt = 0
    for v in A:
        val = max(v, v ^ X)
        ans += val
        if val == (v ^ X):
            cnt += 1

    # We must either remove, or add one number to the group that was XOR-ed
    if K == 0 and cnt % 2 == 1:
        rem = 10**18
        for x in A:
            v = ((x ^ X) - x)
            if v >= 0:
                rem = min(rem, v)
        add = -10**18
        for x in A:
            v = ((x ^ X) - x)
            if v <= 0:
                add = max(add, v)
        ans = max(ans - rem, ans + add)
    print(ans)
Editorialist's Solution
//teja349
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
int main(){
    std::ios::sync_with_stdio(false); cin.tie(NULL);
    ll t;
    cin>>t;
    while(t--){
        ll n;
        cin>>n;
        vector<ll> a(n);
        ll sumi=0,sum1=0;
        for(ll i=0;i<n;i++){
            cin>>a[i];
        }
        ll k,x;
        cin>>k>>x;
        // vec[i] = a[i]-(a[i]^x), the negated gain of XOR-ing a[i]
        vector<ll> vec;
        for(ll i=0;i<n;i++){
            sumi+=a[i];
            sum1+=a[i]^x;
            vec.push_back(a[i]-(a[i]^x));
        }
        // k == n: either XOR everything or nothing
        if(k==n){
            cout<<max(sumi,sum1)<<endl;
            continue;
        }
        ll val;
        // ascending order of vec = descending order of gain
        sort(vec.begin(),vec.end());
        if(k%2){
            // k odd: take every positive gain individually
            for(size_t i=0;i<vec.size();i++){
                val=-vec[i];
                if(val>0)
                    sumi+=val;
            }
        }
        else{
            // k even: take adjacent pairs of gains whose total is positive
            for(size_t i=0;2*i+1<vec.size();i++){
                val=-(vec[2*i]+vec[2*i+1]);
                if(val>0)
                    sumi+=val;
            }
        }
        cout<<sumi<<endl;
    }
    return 0;
}
Feel free to share your approach if it differs. Suggestions are always welcome.