XORPROD - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: tejas10p
Testers: IceKnight1093, tabr
Editorialist: IceKnight1093

DIFFICULTY:

1845

PREREQUISITES:

Sorting

PROBLEM:

Given an array A, in one move you can choose x, y \in A, delete them, and add x\oplus y to A.
Maximize the product of A.

EXPLANATION:

A little analysis of how the operation affects the product of all the elements should tell you that it’s almost never optimal to replace two elements.

In fact, it’s optimal to replace x and y with x\oplus y if and only if x = 1 and y is even.

Why?

This can be seen somewhat intuitively by looking at just two elements.
Let x \leq y; we want to decide whether or not to operate on x and y.

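First, note that x\oplus y \lt 2y. This is clear from the bits: since x \leq y, the highest set bit of x\oplus y is no higher than the highest set bit of y, and every number whose highest set bit is at most that of y is less than 2y.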
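One quick way to convince yourself of this bound is to brute-force it for small values. The sketch below uses a hypothetical helper `xor_below_twice_max`, and the cutoff 256 is an arbitrary choice. It only checks a finite range, so it is a sanity check and not a proof.

```python
def xor_below_twice_max(limit: int) -> bool:
    """Check x ^ y < 2 * y for every pair 1 <= x <= y <= limit."""
    return all((x ^ y) < 2 * y
               for y in range(1, limit + 1)
               for x in range(1, y + 1))

print(xor_below_twice_max(256))  # True
```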
Now,

  • If we don’t operate on x and y, we contribute xy to the product.
  • If we do operate on them, we contribute x\oplus y \lt 2y to the product.

In particular, if xy \geq 2y, i.e., x \geq 2, then it's always better not to perform the operation.
This forces x = 1.

Now we have to decide which y give us 1\oplus y \gt 1\cdot y = y. Note that:

  • If y is even, 1\oplus y = y+1
  • If y is odd, 1\oplus y = y-1

This tells us that x = 1 with y even is the only case in which operating on two elements increases the product.
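The two-element case analysis can be checked exhaustively on small values. This is a minimal sketch. The helper names `operating_helps` and `predicted` are illustrative, and the range 1..64 is an arbitrary cutoff. It compares "replacing x, y by x \oplus y strictly increases the product" against the claimed condition.

```python
def operating_helps(x: int, y: int) -> bool:
    """True if replacing x and y by x ^ y strictly increases the product."""
    return (x ^ y) > x * y

def predicted(x: int, y: int) -> bool:
    """The claim: only a 1 paired with an even number helps."""
    lo, hi = min(x, y), max(x, y)
    return lo == 1 and hi % 2 == 0

mismatches = [(x, y) for x in range(1, 65) for y in range(1, 65)
              if operating_helps(x, y) != predicted(x, y)]
print(mismatches)  # []
```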

It’s somewhat reasonable to expect this to hold when we need to perform more than one move, but a lot less obvious why: after all, the order of moves matters, and maybe we want to perform one ‘bad’ move to be able to get to a ‘good’ one later.
It so happens that this is never the case. A slightly more detailed proof is attached below if you’re interested.

More detailed proof

Let B = [B_1, B_2, \ldots, B_k] be the final array, after we have performed some operations.
Note that B_i = A_{i_1} \oplus A_{i_2} \oplus \ldots \oplus A_{i_r} for some indices i_1, \ldots, i_r.
Let’s call each A_{i_j} a component of B_i.

Suppose there exists a B_i that has at least two components that are \geq 2. Without loss of generality, let A_{i_1} \geq 2.
Then we can instead perform operations so that we end up with B_i \oplus A_{i_1} and A_{i_1} instead of just B_i, and this gives us a strictly higher product.
So, an optimal solution will never have such a B_i.

Now suppose some B_i has \geq 2 components that are 1.
Then we can remove two of those 1's from B_i (which doesn't change its XOR, since 1 \oplus 1 = 0) and keep them as two separate elements: this multiplies the product by 1\cdot 1 = 1, so it doesn't change.
So, there exists an optimal solution in which each B_i has at most one 1, and at most one value \geq 2.

Now consider B_i = 1 \oplus y where y \geq 2.
As we noted above, if y is odd it’s better to have y and 1 separately.

So, an optimal solution can only have B_i that are either single elements, or 1 \oplus y for even y.
This completes the proof.
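The multi-move claim can also be sanity-checked by comparing an exhaustive search over all operation sequences against the greedy rule described next, on small arrays. This is a minimal sketch. The names `best` and `greedy` are hypothetical, and the limits (length at most 4, values 1..8) are arbitrary choices to keep the search fast. It uses exact integers rather than the modulus.

```python
from functools import lru_cache
from itertools import combinations_with_replacement
from math import prod

@lru_cache(maxsize=None)
def best(state: tuple) -> int:
    """Exhaustive search: maximum product reachable from the tuple `state`."""
    result = prod(state)  # option: perform no more operations
    n = len(state)
    for i in range(n):
        for j in range(i + 1, n):
            # remove elements i and j, add their XOR, keep the tuple sorted
            rest = state[:i] + state[i + 1:j] + state[j + 1:]
            nxt = tuple(sorted(rest + (state[i] ^ state[j],)))
            result = max(result, best(nxt))
    return result

def greedy(a) -> int:
    """Pair each 1 with the smallest remaining even number (exact product)."""
    ones = a.count(1)
    evens = sorted(x for x in a if x % 2 == 0)
    odds = [x for x in a if x % 2 == 1]  # the 1's here contribute a factor of 1
    for k in range(min(ones, len(evens))):
        evens[k] += 1  # 1 xor even y == y + 1
    return prod(evens) * prod(odds)

for length in range(1, 5):
    for arr in combinations_with_replacement(range(1, 9), length):
        assert best(arr) == greedy(list(arr)), arr
print("greedy matches brute force")
```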

With this information in hand, let’s now get to actually solving the problem.

We can simply simulate the process: as long as we have at least one 1 and one even number remaining, perform an operation on them.
All that remains is to decide which even number to operate on at each step. This is simple: choose the smallest remaining even number.

Why?

Note that we're choosing x = 1 and an even y, which means we remove the factor 1\cdot y from the product and replace it with y+1. So, the product is multiplied by \frac{y+1}{y}.

\frac{y+1}{y} is larger the smaller y is, so it’s optimal to choose the smallest y we can (while ensuring it’s even).
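As a small worked illustration, take the hypothetical array [1, 2, 4, 6] (product 48) with a single 1. The sketch below (the helper `product_after` is illustrative) prints each candidate y, the ratio \frac{y+1}{y}, and the resulting product. The smallest even number gives the largest product.

```python
from fractions import Fraction
from math import prod

def product_after(a, chosen_even):
    """Exact product after replacing one 1 and `chosen_even` by 1 ^ chosen_even."""
    b = list(a)
    b.remove(1)
    b.remove(chosen_even)
    b.append(1 ^ chosen_even)
    return prod(b)

a = [1, 2, 4, 6]  # hypothetical example array, product 48
for y in (2, 4, 6):
    print(y, Fraction(y + 1, y), product_after(a, y))
# 2 3/2 72
# 4 5/4 60
# 6 7/6 56
```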

Implementing this is fairly straightforward: count the number of 1's in the sequence, then sort the even numbers and keep operating on the smallest remaining one while there are 1's left.
Note that operating on an even number y turns it into the odd number y+1 \geq 3, so it can never be chosen again and the list doesn't need to be updated.

TIME COMPLEXITY:

\mathcal{O}(N\log N) per testcase.

CODE:

Setter's code (C++)
#include <bits/stdc++.h>
#define mod 998244353
using namespace std;

int main() {
    //freopen("inp4.in", "r", stdin);
    //freopen("out4.out", "w", stdout);
    int t;
    cin >> t;
    assert(t > 0 && t <= 50000);
    while(t--) {
        int n;
        cin >> n;
        assert(n > 0 && n <= 100000);
        long long int a[n];
        int ones = 0;
        priority_queue<long long int> pq;
        long long int ans = 1;
        for(int i = 0; i < n; i++) {
            cin >> a[i];
            assert(a[i] > 0 && a[i] <= 1000000000);
            if(a[i]&1) {
                if(a[i] == 1) ones++;
                ans *= a[i];
                ans %= mod;
            } else pq.push(-a[i]);
        }
        while(ones && !pq.empty()) {
            int top = -pq.top();
            pq.pop();
            ones--;
            ans *= (top + 1);
            ans %= mod;
        }
        while(!pq.empty()) {
            ans *= (-pq.top());
            pq.pop();
            ans %= mod;
        }
        cout << ans << "\n";
    }
}
Tester's code (C++)
#include <bits/stdc++.h>

using namespace std;

struct input_checker {
    string buffer;
    int pos;

    const string all = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
    const string number = "0123456789";
    const string upper = "ABCDEFGHIJKLMNOPQRSTUVWXYZ";
    const string lower = "abcdefghijklmnopqrstuvwxyz";

    input_checker() {
        pos = 0;
        while (true) {
            int c = cin.get();
            if (c == -1) {
                break;
            }
            buffer.push_back((char) c);
        }
    }

    int nextDelimiter() {
        int now = pos;
        while (now < (int) buffer.size() && buffer[now] != ' ' && buffer[now] != '\n') {
            now++;
        }
        return now;
    }

    string readOne() {
        assert(pos < (int) buffer.size());
        int nxt = nextDelimiter();
        string res;
        while (pos < nxt) {
            res += buffer[pos];
            pos++;
        }
        // cerr << res << endl;
        return res;
    }

    string readString(int minl, int maxl, const string &pattern = "") {
        assert(minl <= maxl);
        string res = readOne();
        assert(minl <= (int) res.size());
        assert((int) res.size() <= maxl);
        for (int i = 0; i < (int) res.size(); i++) {
            assert(pattern.empty() || pattern.find(res[i]) != string::npos);
        }
        return res;
    }

    int readInt(int minv, int maxv) {
        assert(minv <= maxv);
        int res = stoi(readOne());
        assert(minv <= res);
        assert(res <= maxv);
        return res;
    }

    long long readLong(long long minv, long long maxv) {
        assert(minv <= maxv);
        long long res = stoll(readOne());
        assert(minv <= res);
        assert(res <= maxv);
        return res;
    }

    void readSpace() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == ' ');
        pos++;
    }

    void readEoln() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == '\n');
        pos++;
    }

    void readEof() {
        assert((int) buffer.size() == pos);
    }
};

template<long long mod>
struct modular {
    long long value;

    modular(long long x = 0) {
        value = x % mod;
        if (value < 0) value += mod;
    }

    modular &operator+=(const modular &other) {
        if ((value += other.value) >= mod) value -= mod;
        return *this;
    }

    modular &operator-=(const modular &other) {
        if ((value -= other.value) < 0) value += mod;
        return *this;
    }

    modular &operator*=(const modular &other) {
        value = value * other.value % mod;
        return *this;
    }

    modular &operator/=(const modular &other) {
        long long a = 0, b = 1, c = other.value, m = mod;
        while (c != 0) {
            long long t = m / c;
            m -= t * c;
            swap(c, m);
            a -= t * b;
            swap(a, b);
        }
        a %= mod;
        if (a < 0) a += mod;
        value = value * a % mod;
        return *this;
    }

    friend modular operator+(const modular &lhs, const modular &rhs) { return modular(lhs) += rhs; }

    friend modular operator-(const modular &lhs, const modular &rhs) { return modular(lhs) -= rhs; }

    friend modular operator*(const modular &lhs, const modular &rhs) { return modular(lhs) *= rhs; }

    friend modular operator/(const modular &lhs, const modular &rhs) { return modular(lhs) /= rhs; }

    modular &operator++() { return *this += 1; }

    modular &operator--() { return *this -= 1; }

    modular operator++(int) {
        modular res(*this);
        *this += 1;
        return res;
    }

    modular operator--(int) {
        modular res(*this);
        *this -= 1;
        return res;
    }

    modular operator-() const { return modular(-value); }

    bool operator==(const modular &rhs) const { return value == rhs.value; }

    bool operator!=(const modular &rhs) const { return value != rhs.value; }

    bool operator<(const modular &rhs) const { return value < rhs.value; }
};

template<long long mod>
string to_string(const modular<mod> &x) {
    return to_string(x.value);
}

template<long long mod>
ostream &operator<<(ostream &stream, const modular<mod> &x) {
    return stream << x.value;
}

template<long long mod>
istream &operator>>(istream &stream, modular<mod> &x) {
    stream >> x.value;
    x.value %= mod;
    if (x.value < 0) x.value += mod;
    return stream;
}

constexpr long long mod = 998244353;
using mint = modular<mod>;

mint power(mint a, long long n) {
    mint res = 1;
    while (n > 0) {
        if (n & 1) {
            res *= a;
        }
        a *= a;
        n >>= 1;
    }
    return res;
}

vector<mint> fact(1, 1);
vector<mint> finv(1, 1);

mint C(int n, int k) {
    if (n < k || k < 0) {
        return mint(0);
    }
    while ((int) fact.size() < n + 1) {
        fact.emplace_back(fact.back() * (int) fact.size());
        finv.emplace_back(mint(1) / fact.back());
    }
    return fact[n] * finv[k] * finv[n - k];
}

int main() {
    input_checker in;
    int tt = in.readInt(1, 5e4);
    in.readEoln();
    int sn = 0;
    while (tt--) {
        int n = in.readInt(2, 1e5);
        in.readEoln();
        sn += n;
        vector<int> a(n);
        for (int i = 0; i < n; i++) {
            a[i] = in.readInt(1, 1e9);
            (i == n - 1 ? in.readEoln() : in.readSpace());
        }
        vector<vector<int>> b(2);
        int c = 0;
        for (int i = 0; i < n; i++) {
            if (a[i] == 1) {
                c++;
            } else {
                b[a[i] % 2].emplace_back(a[i]);
            }
        }
        sort(b[0].begin(), b[0].end());
        for (int i = 0; i < min(c, (int) b[0].size()); i++) {
            b[0][i]++;
        }
        mint ans = 1;
        for (int i = 0; i < 2; i++) {
            for (int j: b[i]) {
                ans *= j;
            }
        }
        cout << ans << '\n';
    }
    cerr << sn << endl;
    assert(sn <= 3e5);
    in.readEof();
    return 0;
}
Editorialist's code (Python)
mod = 998244353
for _ in range(int(input())):
    n = int(input())
    a = list(map(int, input().split()))
    evens, odds = sorted([x for x in a if x%2 == 0]), sorted([x for x in a if x%2 == 1])
    p = q = 0
    while p < len(odds) and q < len(evens):
        if odds[p] != 1: break
        evens[q] += 1
        p += 1
        q += 1
    ans = 1
    for i in range(p, len(odds)): ans = (ans * odds[i])%mod
    for x in evens: ans = (ans * x)%mod
    print(ans)

#include<bits/stdc++.h>
using namespace std;

#define mod 998244353
typedef long long ll;
int main() {

int t;
cin >> t;
while (t--) {
	int n;
	cin >> n;
	int arr[n];

	int count1 = 0;
	ll product = 1;
	for (int i = 0 ; i < n; i++) {
		cin >> arr[i];
		if (arr[i] == 1) {
			count1++;
		}

		product = ((product * arr[i])) % mod;

	}


	sort(arr, arr + n);



	for (int i = 0 ; i < n; i++) {
		if (arr[i] % 2 == 0) {
			product = (product / arr[i]) % mod;
			product = (product * (arr[i] + 1)) % mod;
			count1--;
		}
		if (count1 <= 0) {
			break;
		}
	}

	cout << (product) % mod << endl;

}
return 0;

}

Can somebody please tell me why this code gives a wrong answer on submission,
even though I used the same approach mentioned above?

You can't divide directly when you're working modulo a prime; you need to multiply by the modular inverse instead.
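Since 998244353 is prime, Fermat's little theorem gives y^{-1} \equiv y^{998244351} \pmod{998244353}. So "dividing" by y means multiplying by `pow(y, MOD - 2, MOD)`. Below is a minimal Python sketch of the posted divide-then-multiply approach with that change. The function name `solve` is just illustrative. The sketch also checks that a 1 is still available *before* each operation. The posted loop decrements `count1` only after modifying the product, so an input with no 1's whose smallest element is even, such as [2, 3], would still be changed.

```python
MOD = 998244353  # prime, so every nonzero residue has an inverse

def solve(a):
    """Full product, then for each 1 paired with the smallest remaining even y,
    divide out y (via its modular inverse) and multiply by y + 1."""
    product = 1
    for x in a:
        product = product * x % MOD
    ones = a.count(1)
    for y in sorted(a):
        if ones == 0:  # check before operating, not after
            break
        if y % 2 == 0:
            # an even y <= 10^9 is never a multiple of MOD, so the inverse exists
            inv_y = pow(y, MOD - 2, MOD)  # Fermat: y^(MOD-2) == y^(-1) mod MOD
            product = product * inv_y % MOD * (y + 1) % MOD
            ones -= 1
    return product

print(solve([1, 1, 4, 2, 7]))  # 105
```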

Thanks for the editorial @iceknight1093! Could you please explain the quoted part in a bit more detail, if possible with an example of why we get a strictly higher product? I only have difficulty understanding this part, and it seems crucial.

(Also, how do you quote parts from the comments with LaTeX? I had to manually edit the quoted part for proper LaTeX formatting :sweat_smile: )

It’s mildly annoying, but you just need to write out all the information you have.
Let x = A_{i_1}.
Then, B_i \oplus x \geq B_i - x, since XORing with x can at most clear the set bits of x from B_i.

Let the product of the other elements be P.
In the first case, the overall product was P\cdot B_i; while in the second, it’s at least P\cdot (B_i - x) \cdot x.

Essentially, we want to compare B_i with (B_i-x)\cdot x.
Also note that B_i \geq 2x, since x was a factor in its product.
Now a bit of algebra:

B_i \geq 2x \geq x\times \frac{x}{x-1} \implies (x-1)B_i \geq x^2 \implies B_ix - x^2 \geq B_i

which is what we wanted.

The inequality technically isn’t strict, but you’ll notice that equality only holds when x = 2 and B_i = 2x = 4 (because you need both B_i = 2x and \frac{x}{x-1} = 2); in every other case the inequality is strict.

In that one case you get B_i \oplus x = 4\oplus 2 = 6 so the product increases anyway.
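Both inequalities used in this reply can be brute-forced on small values. This is a minimal sketch. The helper `check_reply_claims` is hypothetical and the limit 200 is arbitrary. It checks B_i \oplus x \geq B_i - x for all positive values, and (B_i - x)\cdot x \geq B_i whenever x \geq 2 and B_i \geq 2x. It also collects the equality cases.

```python
def check_reply_claims(limit: int):
    """Brute-force the two inequalities from the reply for values up to `limit`."""
    for B in range(1, limit + 1):
        for x in range(1, limit + 1):
            # XOR with x clears at most the set bits of x, so B ^ x >= B - x
            assert (B ^ x) >= B - x
    equality_cases = []
    for x in range(2, limit + 1):
        for B in range(2 * x, limit + 1):
            lhs = (B - x) * x
            assert lhs >= B  # holds whenever x >= 2 and B >= 2x
            if lhs == B:
                equality_cases.append((x, B))
    return equality_cases

print(check_reply_claims(200))  # [(2, 4)]
```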

No clue, I fix the formatting by hand whenever I quote something.

Sorry to bother you again, but could you please explain the above line? I get that x was used in the XOR to calculate B_i, but not how x is a factor of B_i.

Oh you’re right, B_i \geq 2x isn’t necessarily going to be true.
I was thinking of something else and messed up the model in my head.

It needs a couple more cases to be fixed; I hope this is final:

  • If B_i \geq 2x, I've already shown that separating x gives a higher product.
  • If B_i \lt x, it's obviously better to separate x: you get x times B_i \oplus x \geq 1, which is already larger than B_i.
  • If B_i = x, separate out x and any other component that's \geq 2 (recall that we're considering the case when there are at least two such components), and you'll get a higher product that's at least 2x.
  • Now we're left with x \lt B_i \lt 2x.
    • If some other component falls into one of the above cases, we can operate on that one instead. So, we only need to care about the case where every component satisfies this inequality.
  • When every component satisfies this, simply separate them all. Since there are \geq 2 components, we get a product of at least 2\times M, where M is the largest component; this is strictly larger than B_i, since B_i \lt 2M.