ADJXOR2 - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: Jeevan Jyot Singh
Testers: Abhinav Sharma, Venkata Nikhil Medam
Editorialist: Nishank Suresh

DIFFICULTY:

1827

PREREQUISITES:

Dynamic programming

PROBLEM:

JJ has an array A and an integer X. At most once, he can choose a subsequence of A and add X to all its elements.

What is the maximum possible value of \displaystyle \sum_{i=2}^N (A_i \oplus A_{i-1}) that he can obtain?

EXPLANATION:

This task can be solved fairly easily with the help of dynamic programming.

Suppose we define a new array dp of length N, where dp_i is the maximum possible answer for the first i elements. Our aim is to compute dp_N.

Each element has two choices for it: it either remains A_i, or becomes A_i + X. This information can be encapsulated into the state for dynamic programming by simply holding two values for each index: dp_{i, 1} denotes the maximum answer for the first i elements if A_i isn’t changed, and dp_{i, 2} denotes the maximum answer for the first i elements if A_i is changed to A_i + X.

All that remains are the base cases and transitions.

  • The base case is i = 1, where we have dp_{1, 1} = dp_{1, 2} = 0
  • To compute dp_{i, 1}, we have 2 choices for the previous element: either it changed, or it did not.
    • If it did not change, we get the value dp_{i-1, 1} + (A_i \oplus A_{i-1})
    • If it did change, we get the value dp_{i-1, 2} + (A_i \oplus (A_{i-1} + X))
    • dp_{i, 1} is the maximum of these two values.
  • Similarly, dp_{i, 2} is the maximum of dp_{i-1, 1} + ((A_i + X) \oplus A_{i-1}) and dp_{i-1, 2} + ((A_i + X) \oplus (A_{i-1} + X)).

These values can be computed iteratively, simply iterating i from 2 to N. The final answer is the maximum of dp_{N, 1} and dp_{N, 2}.

TIME COMPLEXITY:

\mathcal{O}(N) per test case.

CODE:

Setter (C++)
#ifdef WTSH
    #include <wtsh.h>
#else
    #include <bits/stdc++.h>
    using namespace std;
    #define dbg(...)
#endif

// Every `int` in the setter's code below is really `long long`, so the running
// XOR sums are 64-bit; that is also why main is declared as int32_t.
#define int long long
// Plain newline: avoids the stream flush std::endl would perform.
#define endl "\n"
// Container size converted to the (long long) `int` type.
#define sz(w) (int)(w.size())
using pii = pair<int, int>;

// Modulus from the author's template; not used by this solution.
const int mod = 998244353; 

// -------------------- Input Checker Start --------------------

// Reads one integer token that must be terminated by exactly `endd`, validates
// it and returns it.
//
// Accepted format: an optional single leading '-', then at least one decimal
// digit with no leading zeros ("0" is fine, "-0" and "007" are not). Any other
// character, including EOF, fails an assertion, as does a value outside [l, r].
long long readInt(long long l, long long r, char endd)
{
    long long x = 0;
    int cnt = 0, fi = -1;   // number of digits read, first digit (-1 = none yet)
    bool is_neg = false;
    while(true)
    {
        // getchar() returns int; keeping it as int keeps EOF distinguishable
        // from a real byte even where plain char is unsigned.
        int g = getchar();
        if(g == '-')
        {
            assert(fi == -1);   // the sign must precede every digit...
            assert(!is_neg);    // ...and may appear only once
            is_neg = true;
            continue;
        }
        if('0' <= g && g <= '9')
        {
            x *= 10;
            x += g - '0';
            if(cnt == 0)
                fi = g - '0';
            cnt++;
            assert(fi != 0 || cnt == 1);          // no leading zeros
            assert(fi != 0 || is_neg == false);   // no "-0"
            // At most 19 digits, and a 19-digit value must start with 1, so x
            // can never overflow long long (max ~9.22e18).
            assert(!(cnt > 19 || (cnt == 19 && fi > 1)));
        }
        else if(g == static_cast<unsigned char>(endd))
        {
            // The terminator must follow at least one digit: reject empty
            // tokens such as "\n", " " or a bare "-".
            assert(cnt > 0);
            if(is_neg)
                x = -x;
            if(!(l <= x && x <= r))
            {
                std::cerr << "L: " << l << ", R: " << r << ", Value Found: " << x << '\n';
                assert(false);
            }
            return x;
        }
        else
        {
            assert(false);   // unexpected character (including EOF)
        }
    }
}

// Reads characters up to (not including) the delimiter `endd`, consumes the
// delimiter, and returns the characters read. Asserts that the input does not
// end before the delimiter and that the returned length lies in [l, r].
std::string readString(int l, int r, char endd)
{
    std::string ret = "";
    int cnt = 0;
    while(true)
    {
        // Keep getchar()'s result as int: squeezing it into a char first makes
        // the EOF test unreliable (with unsigned char it never matches, so a
        // missing delimiter would loop forever).
        int g = getchar();
        assert(g != EOF);
        if(g == static_cast<unsigned char>(endd))
            break;
        cnt++;
        ret += static_cast<char>(g);
    }
    assert(l <= cnt && cnt <= r);
    return ret;
}

// Thin wrappers around readInt / readString that fix the token terminator:
// "Sp" variants expect a following space, "Ln" variants a following newline.
long long readIntSp(long long l, long long r)
{
    const char separator = ' ';
    return readInt(l, r, separator);
}

long long readIntLn(long long l, long long r)
{
    const char separator = '\n';
    return readInt(l, r, separator);
}

string readStringSp(int l, int r)
{
    const char separator = ' ';
    return readString(l, r, separator);
}

string readStringLn(int l, int r)
{
    const char separator = '\n';
    return readString(l, r, separator);
}

// Asserts that nothing is left in the input.
void readEOF()
{
    assert(getchar() == EOF);
}

vector<int> readVectorInt(int n, long long l, long long r)
{
    vector<int> a(n);
    for(int i = 0; i < n - 1; i++)
        a[i] = readIntSp(l, r);
    a[n - 1] = readIntLn(l, r);
    return a;
}

// -------------------- Input Checker End --------------------

int sumN = 0;

void solve()
{
    int n = readIntSp(1, 1e5);
    int x = readIntLn(1, 1e9);
    vector<int> a = readVectorInt(n, 1, 1e9);
    sumN += n;
    vector<array<int, 2>> dp(n);
    for(int i = 1; i < n; i++)
    {
        dp[i][0] = max(dp[i - 1][0] + (a[i - 1] ^ a[i]), dp[i - 1][1] + ((a[i - 1] + x) ^ a[i]));
        dp[i][1] = max(dp[i - 1][0] + (a[i - 1] ^ (a[i] + x)), dp[i - 1][1] + ((a[i - 1] + x) ^ (a[i] + x)));
    }
    cout << max(dp[n - 1][0], dp[n - 1][1]) << endl;
}

// Reads T, solves every test case, then asserts that the whole input was
// consumed and that the sum of N over all test cases is within the limit.
int32_t main()
{
    std::ios::sync_with_stdio(0);
    std::cin.tie(0);
    int remaining = readIntLn(1, 1e5);
    while(remaining-- > 0)
        solve();
    readEOF();
    assert(sumN <= 2e5);
    return 0;
}
Tester (nikhil_medam, C++)
// Tester: Nikhil_Medam
#include <bits/stdc++.h>
#pragma GCC optimize ("-O3")
using namespace std;
// Fast-I/O boilerplate: disables C stdio synchronisation and unties cin/cout.
#define IOS ios::sync_with_stdio(0); cin.tie(0); cout.tie(0);
// Newline without the flush std::endl would perform.
#define endl "\n"
// Every `int` below is 64-bit so the XOR sums fit (hence main is int32_t).
#define int long long
#define double long double
// Array capacity: arrays are indexed 1..n, presumably with n <= 1e5.
const int N = 1e5 + 5;

// Per-test input (t, n, x, a[1..n]) and the DP table: dp[i][0] / dp[i][1] is the
// best sum of adjacent XORs over a[1..i] with a[i] kept / increased by x.
// As globals they start zero-initialised.
int t, n, x, a[N], dp[N][2];
// Bitwise XOR of the two arguments (`int` is `long long` in this program).
int xorValue(int x, int y) {
    const int result = x ^ y;
    return result;
}
// Reads every test case from stdin and prints one answer per line.
// dp[i][0] / dp[i][1]: best sum of adjacent XORs over a[1..i] with a[i] kept /
// replaced by a[i] + x. dp[1][*] is never written and stays 0 (globals are
// zero-initialised), which is exactly the base case.
int32_t main() {
    // Fast I/O. The IOS macro was defined for this but never invoked, so cin was
    // still synchronised with C stdio.
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);
    std::cin >> t;
    while(t--) {
        std::cin >> n >> x;
        for(int i = 1; i <= n; i++) {
            std::cin >> a[i];
        }
        for(int i = 2; i <= n; i++) {
            dp[i][0] = std::max(xorValue(a[i], a[i - 1]) + dp[i - 1][0], xorValue(a[i], a[i - 1] + x) + dp[i - 1][1]);
            dp[i][1] = std::max(xorValue(a[i] + x, a[i - 1]) + dp[i - 1][0], xorValue(a[i] + x, a[i - 1] + x) + dp[i - 1][1]);
        }
        std::cout << std::max(dp[n][0], dp[n][1]) << '\n';
    }
    return 0;
}
Editorialist (Python)
# Editorialist's solution: for every test case print the maximum of
# sum(a[i] ^ a[i-1]) when x may be added to any subset of the elements.
# kept / bumped = best total for the prefix a[0..i] with a[i] left unchanged /
# increased by x; a lone element contributes nothing, so both start at 0.
for _ in range(int(input())):
    n, x = map(int, input().split())
    a = list(map(int, input().split()))
    kept = bumped = 0
    for i in range(1, n):
        prev, cur = a[i - 1], a[i]
        prev_up, cur_up = prev + x, cur + x
        kept, bumped = (
            max(kept + (cur ^ prev), bumped + (cur ^ prev_up)),
            max(kept + (cur_up ^ prev), bumped + (cur_up ^ prev_up)),
        )
    print(max(kept, bumped))
3 Likes

I was trying to solve this question using recursion, just for practice. Why am I getting 45 as the answer for the third example test case?

// Returns the maximum of sum_{i=idx}^{arr.length-1} (b[i] ^ b[i-1]), where every
// b[j] (j >= idx-1) is either arr[j] or arr[j] + x. Call it as helper(arr, 1, x)
// for the full answer; idx == arr.length yields 0.
//
// The choice made for an element must be shared by both pairs it belongs to.
// A recursion that doesn't pass "was arr[i-1] increased?" down lets one element
// count as both changed and unchanged, so it over-counts
// (e.g. {1,2,3}, x=1 gives 10 instead of 9). It is also exponential.
// The two-state DP below fixes both and runs in O(n).
static long helper(long arr[], int idx, long x) {
    long keep = 0;  // best total so far with the latest element unchanged
    long bump = 0;  // best total so far with the latest element increased by x
    for (int i = idx; i < arr.length; i++) {
        long prevKeep = arr[i - 1], prevBump = arr[i - 1] + x;
        long curKeep = arr[i], curBump = arr[i] + x;
        long nextKeep = Math.max(keep + (curKeep ^ prevKeep), bump + (curKeep ^ prevBump));
        long nextBump = Math.max(keep + (curBump ^ prevKeep), bump + (curBump ^ prevBump));
        keep = nextKeep;
        bump = nextBump;
    }
    return Math.max(keep, bump);
}

What have I done wrong in this recursion?

1 Like

I think my solution explains the dp state and transitions more clearly.
https://www.codechef.com/viewsolution/71111314

1 Like

Can anyone upload their recursion + memoization approach?
Here’s what I was trying to do, but failed.

// Best total of the adjacent XORs (v[j-1]' ^ v[j]') for j = idx .. n-1, where
// each v[j]' is either v[j] or v[j] + x, and `prev` says whether v[idx-1] was
// increased. Uses the globals n and x (presumably n == v.size()); expects
// idx >= 1. The full answer is max(solve(1, 0, v), solve(1, 1, v)).
//
// Every later element must try both of its choices *and* pass that choice on.
// Otherwise an element can be counted as both changed and unchanged.
// The recurrence is
//   f(i, p) = max over d of ((v[i] + d*x) ^ (v[i-1] + p*x)) + f(i+1, d).
// It is evaluated from the right, so it runs in O(n - idx) time with no memo
// table and no deep recursion.
ll solve(int idx, bool prev, vll &v)
{
    // Before each step f0 / f1 hold f(i + 1, 0) / f(i + 1, 1); past the last
    // element there are no pairs left, so both start at 0.
    ll f0 = 0, f1 = 0;
    for (int i = n - 1; i >= idx; --i)
    {
        const ll curKeep = v[i], curBump = v[i] + x;
        const ll prevKeep = v[i - 1], prevBump = v[i - 1] + x;
        const ll g0 = max((curKeep ^ prevKeep) + f0, (curBump ^ prevKeep) + f1);
        const ll g1 = max((curKeep ^ prevBump) + f0, (curBump ^ prevBump) + f1);
        f0 = g0;
        f1 = g1;
    }
    return prev ? f1 : f0;
}

This is my (recursive + memoization) code

#include <bits/stdc++.h>
#define ll long long int
using namespace std;

vector<vector<ll>> dp;

// Memoised recursion. f tells whether arr[i-1] was increased by x. The result
// is the best sum of the adjacent XORs for pairs (i-1, i) .. (n-2, n-1) given
// that choice. dp[i][f] caches it (-1 = not computed yet).
ll solve(vector<ll> &arr, ll x, ll i, bool f){
    const ll n = arr.size();
    if(i >= n) return 0;
    if(dp[i][f] != -1) return dp[i][f];

    const ll left = f ? arr[i-1] + x : arr[i-1];
    const ll keepCurrent = solve(arr, x, i+1, false) + (left ^ arr[i]);
    const ll bumpCurrent = solve(arr, x, i+1, true) + (left ^ (arr[i] + x));
    dp[i][f] = max(keepCurrent, bumpCurrent);
    return dp[i][f];
}

int32_t main() {
    ios_base::sync_with_stdio(false); 
    cin.tie(NULL);
    
    ll T;
    cin>>T;
    while(T--){
        ll n, x;
        cin>>n>>x;
        
        vector<ll> arr(n,0);
        for(int i=0; i<n; ++i) cin>>arr[i];

        dp.clear();
        dp.resize(n+1, vector<ll>(2,-1));
        ll ans = solve(arr,x,1,false);

        dp.clear();
        dp.resize(n+1, vector<ll>(2,-1));
        ans = max(ans, solve(arr,x,1,true));
        cout<<ans<<endl;

    }///end of while
    
    return 0;
}

#include <bits/stdc++.h>
#define ll long long int
using namespace std;

vector<vector<ll>> dp;

// Memoised recursion. f tells whether arr[i-1] was increased by x. The result
// is the best sum of the adjacent XORs for pairs (i-1, i) .. (n-2, n-1) given
// that choice. dp[i][f] caches it (-1 = not computed yet).
ll solve(vector<ll> &arr, ll x, ll i, bool f){
    const ll n = arr.size();
    if(i >= n) return 0;
    if(dp[i][f] != -1) return dp[i][f];

    const ll left = f ? arr[i-1] + x : arr[i-1];
    const ll keepCurrent = solve(arr, x, i+1, false) + (left ^ arr[i]);
    const ll bumpCurrent = solve(arr, x, i+1, true) + (left ^ (arr[i] + x));
    dp[i][f] = max(keepCurrent, bumpCurrent);
    return dp[i][f];
}

int32_t main() {
    ios_base::sync_with_stdio(false); 
    cin.tie(NULL);
    
    ll T;
    cin>>T;
    while(T--){
        ll n, x;
        cin>>n>>x;
        
        vector<ll> arr(n,0);
        for(int i=0; i<n; ++i) cin>>arr[i];

        dp.clear();
        dp.resize(n+1, vector<ll>(2,-1));
        ll ans = solve(arr,x,1,false);

        dp.clear();
        dp.resize(n+1, vector<ll>(2,-1));
        ans = max(ans, solve(arr,x,1,true));
        cout<<ans<<endl;

    }///end of while
    
    return 0;
}

https://www.codechef.com/viewsolution/71038559

Can anyone help me get a solution in this take/not-take form? Since it enumerates every possible subsequence, it should also work. If such a solution is possible, please help.
Thank you!

Your code fails on this:

1 Like

Yeah, I know that it's not right;
that's why I asked whether someone could help me fix my code.

I have a doubt:
Suppose an array is A1, A2, A3. This is of size 3
then ans = A1^A2 + A2^A3.
That means whether we add X to A2 depends on its XOR with both A1 and A3.
But in the DP solution, it seems we decide for A2 on the basis of A1 only.
Can someone please explain. It seems I am not clear with logic.

1 Like

same doubt !

Note the definition of the dynamic programming states: dp_i depends only on the first i elements of the array. We are essentially ignoring everything to the right of i.

This is extremely common in dynamic programming problems, and is why you sometimes see the word ‘subproblem’ thrown around when talking about it. The idea behind DP is to create the solution to the larger problem by continuously solving smaller versions of the same problem and somehow combining their answers.

In this case, we find the answer for i elements by first finding the answer for i-1 elements, and then trying to add the i-th element in. You will also notice this pattern in many other common DP tasks:

  • Longest increasing subsequence: you define lis_i to be the longest increasing subsequence ending at i, without care for whether it can be later extended or not
  • Classical 0/1 knapsack: generally, dp_{i, x} is the maximum value you can get with a weight of x, from the first i items. Maybe later on you find that it’s not optimal to create a weight of x from the first i items, but that really doesn’t matter at the current step — it will be taken care of when it matters.

and so on. In fact, pretty much every problem that can be solved with DP looks like this: being able to use DP at all means you can restrict your solution to something local instead of something global, and that is how you create subproblems and solve them.

1 Like

Thanks! got it

I am getting WA on just one test case with my code and can’t figure out why. Please help.

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.security.Key;
import java.util.*;

import java.util.HashMap;
import java.util.Map;
import java.util.function.Consumer;
//import javafx.util.Pair;

class Codechef {
    /** Whitespace-token reader over System.in backed by a BufferedReader. */
    static class FastReader {
        BufferedReader br;
        StringTokenizer st;

        public FastReader() {
            br = new BufferedReader(
                    new InputStreamReader(System.in));
        }

        /** Returns the next whitespace-separated token, reading more lines as needed. */
        String next() {
            while (st == null || !st.hasMoreElements()) {
                try {
                    st = new StringTokenizer(br.readLine());
                } catch (IOException e) {
                    e.printStackTrace();
                }
            }
            return st.nextToken();
        }

        int nextInt() {
            return Integer.parseInt(next());
        }

        long nextLong() {
            return Long.parseLong(next());
        }

        double nextDouble() {
            return Double.parseDouble(next());
        }

        /** Returns the rest of the current input line ("" if reading fails). */
        String nextLine() {
            String str = "";
            try {
                str = br.readLine();
            } catch (IOException e) {
                e.printStackTrace();
            }
            return str;
        }
    }

    /**
     * For each test case reads n, x and arr[1..n] and prints the maximum of
     * sum (b[i] ^ b[i-1]) where every b[i] is arr[i] or arr[i] + x.
     */
    public static void main(String[] args) {
        FastReader sc = new FastReader();
        int t = sc.nextInt();
        // Collect all answers and print once: a println per test case is slow
        // when T is large.
        StringBuilder out = new StringBuilder();
        while (t-- > 0) {
            int n = sc.nextInt();
            long x = sc.nextLong();
            long[] arr = new long[n + 1];   // 1-indexed: arr[1..n]
            for (int i = 1; i <= n; i++) {
                arr[i] = sc.nextLong();
            }
            // changed / same: best sum of adjacent XORs over arr[1..i] with arr[i]
            // increased by x / left unchanged. A single element forms no pairs,
            // so both start at 0. No special case is needed for n == 1; the
            // answer there is 0 (printing arr[0] + x was the wrong answer).
            long changed = 0, same = 0;
            for (int i = 2; i <= n; i++) {
                long nextChanged = Math.max(((arr[i] + x) ^ arr[i - 1]) + same,
                                            ((arr[i] + x) ^ (arr[i - 1] + x)) + changed);
                long nextSame = Math.max((arr[i] ^ (arr[i - 1] + x)) + changed,
                                         (arr[i] ^ arr[i - 1]) + same);
                changed = nextChanged;
                same = nextSame;
            }
            out.append(Math.max(changed, same)).append('\n');
        }
        System.out.print(out);
    }

    /** Euclidean greatest common divisor; gcd(0, b) == b. */
    private static long gcd(long a, long b) {
        if (a == 0)
            return b;
        return gcd(b % a, a);
    }

    /**
     * Least common multiple. Divides before multiplying, so the intermediate
     * never overflows when the result itself fits in a long.
     */
    private static long lcm(long a, long b) {
        return a / gcd(a, b) * b;
    }

    /** Computes a^b mod 1e9+7 by binary exponentiation (expects b >= 0). */
    static long findpow(long a, long b) {
        final long m = 1000000007L;
        // Reduce the base into [0, m) first so every product stays below m^2 < 2^63.
        long base = ((a % m) + m) % m;
        long ans = 1;
        while (b > 0) {
            if ((b & 1) == 1)
                ans = ans * base % m;
            base = base * base % m;
            b >>= 1;
        }
        return ans;
    }
}