PROBLEM LINK:
Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4
Author: raysh07
Tester: sushil2006
Editorialist: iceknight1093
DIFFICULTY:
Medium
PREREQUISITES:
SOS DP
PROBLEM:
Given an array A of length N, consider the complete weighted graph on vertices 1, \ldots, N, where the weight of edge (i, j) is A_i\oplus A_j.
The length of a path is defined to be the sum of weights of edges on it.
Find the minimum length of a path from 1 to N, and the number of paths that achieve this minimum.
EXPLANATION:
First, we need to find the minimum path length.
An obvious candidate for this is to just jump directly from 1 to N, which has a cost of A_1\oplus A_N.
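This is indeed optimal. For any three non-negative integers x, y, z we have (x\oplus y) + (y\oplus z) \geq x\oplus z, so removing an intermediate vertex from a path never increases its length; repeating this shows that every path from 1 to N has length at least A_1\oplus A_N.
(To prove the inequality, consider each bit independently: if x and z differ at some bit, then exactly one of x\oplus y and y\oplus z has that bit set. So every set bit of x\oplus z is also set in the bitwise OR of x\oplus y and y\oplus z, which is at most (x\oplus y) + (y\oplus z).)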
With the minimum known, our goal is now to compute the number of paths that attain the minimum.
To this end, let’s analyze what such a path must look like.
The length of a path equals \sum_b 2^b times the number of times bit b changes between consecutive vertices, so we can reason about each bit separately.
Consider a fixed bit b. Then,
- If b is unset in both A_1 and A_N (so it’s unset in A_1 \oplus A_N), then we cannot visit any vertex that has b set.
This is because visiting such a vertex would add at least 2\cdot 2^b to the cost, since bit b would have to change from unset to set, and later from set back to unset, at least once each.
So, all values that have b set can be ignored.
- Similarly, if b is set in both A_1 and A_N, we cannot visit any vertex that has this bit unset.
- Next, suppose b is unset in A_1 and set in A_N.
Then, some prefix of our path must have b unset, and the remaining suffix must have b set.
Once again, not following this adds at least 2\cdot 2^b to the cost, which is not what we want.
- Similarly, if b is set in A_1 but unset in A_N, the opposite is true: some prefix of the path must have b set, and the remaining suffix must have it unset.
Now, the first two points tell us that if A_i differs from A_1 at any bit that’s unset in (A_1 \oplus A_N), no shortest path from 1 to N can include vertex i, so we can ignore this element altogether.
Let’s discard all such elements, leaving only the ones that can appear on a shortest path.
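We can simplify things further.
For every bit b that’s set in A_1 and unset in A_N, flip the state of b in every remaining A_i. This doesn’t change any edge weight, since flipping the same bit in both endpoints of an edge leaves their XOR unchanged.
Now every bit where A_1 and A_N differ behaves the same way: it must be off on some prefix of the path, and on in the remaining suffix. (Bits set in both endpoints are set in every remaining value, and bits unset in both are unset in every remaining value.)
In particular, this means that if the path visits the (flipped) values x_1 \to x_2 \to x_3 \to\cdots\to x_k, where x_1 comes from A_1 and x_k from A_N, then each x_i must be a submask of x_{i+1}.
Conversely, any such path is valid: each bit changes at most once along it, and only if it’s set in A_1\oplus A_N, so its length is exactly A_1\oplus A_N. Our aim is therefore to count these paths.
(The author’s code does both steps at once by XOR-ing every value with A_1: A_1 becomes 0, A_N becomes A_1\oplus A_N, and an element can lie on a shortest path exactly when its new value is a submask of A_1\oplus A_N.)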
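This immediately lends itself to a dynamic programming solution, where we build up paths by their prefixes.
Since each value is a submask of the next, equal values appear consecutively, so a valid path is a sequence of blocks of equal values whose masks strictly increase.
That is, define dp[\text{mask}] to be the number of valid path prefixes (starting at vertex 1) whose last value is \text{mask}.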
We have the following transitions:
- To end at \text{mask}, we must choose a non-empty subset of the elements equal to \text{mask}, and then choose an order for them.
If there are k occurrences of \text{mask}, there are \sum_{x=1}^k \binom{k}{x} x! ways of doing this.
Computing this in \mathcal{O}(k) is fine, since the frequencies sum to at most N.
- Next, we look at the previous elements on the path.
The block of \text{mask} values must be preceded by a valid prefix ending at some proper submask \text{sub} of \text{mask}, and any such prefix can be extended by the occurrences of \text{mask} we’ve chosen.
So, the number of ways here is the sum of dp[\text{sub}] across all proper submasks \text{sub} of \text{mask}.
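The two choices are independent, so dp[\text{mask}] is the product of these two counts.
The masks \text{mask} = A_1 and \text{mask} = A_N must be treated separately, since the path must start at vertex 1 and end at vertex N.
Only the numbers change. Vertex 1 always comes first, so the k other vertices with value A_1 (excluding 1 and N) form an optional ordered block right after it, and dp[A_1] = 1 + \sum_{x=1}^k \binom{k}{x} x! (there is no earlier block). Likewise, vertex N always comes last, so the other vertices with value A_N form an optional ordered block right before it, and the first factor for \text{mask} = A_N also gets an extra 1 for the empty block. If A_1 = A_N, both endpoints share a single block and the answer is just dp[A_1].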
Our DP essentially iterates over all pairs (\text{mask}, \text{sub}) with \text{sub} a submask of \text{mask}, which is well known to take \mathcal{O}(3^b) time over b-bit integers. (Simple proof: each bit is either in neither of them, in \text{mask} only, or in both.)
So, the complexity of just implementing this directly, using submask enumeration, is \mathcal{O}(3^{\log_2 N}).
Unfortunately, under the given constraints \log_2 N can be as large as 20, and 3^{20} \approx 3.5\cdot 10^9 is too slow, so we need to do better.
Observe that dp[\text{mask}] is the product of a term that depends only on \text{mask} and the sum of dp[\text{sub}] across all proper submasks of \text{mask}. The latter is what’s slow to compute.
One obvious optimization comes to mind: this is quite literally a sum over subsets, so perhaps we can use SOS DP.
However, we immediately run into an issue: SOS DP doesn’t work “online”. It needs to know the values at all masks before it starts, which is not the case for us, since dp[\text{mask}] depends on the results for its submasks.
Still, with a bit of work we can make it fit.
Let’s start with dp[0] initialized to its correct value, and everything else initialized to 0.
Suppose we run SOS DP on this; let the results be stored in the \text{result} array.
Observe that \text{result}[0] = dp[0] will hold, but every other mask will be undercounted, since the dp values of its other submasks aren’t known yet.
However, there’s one set of exceptions: the masks with exactly one bit set.
For all of these masks, \text{result}[\text{mask}] will be exactly the sum of dp[\text{sub}] across all the proper submasks (which is trivially just the 0 mask in this case).
Now, let’s go one step further.
Compute dp[\text{mask}] from \text{result}[\text{mask}] (multiplying by the term that depends only on \text{mask}) for every mask with exactly one bit set, and keep dp[\text{mask}] = 0 for all masks with two or more bits set, since those values aren’t known yet.
Now, once again run SOS DP on this new dp array, with the result being stored in \text{result}.
Observe that on this second run, \text{result}[\text{mask}] will be computed correctly (that is, equal the sum of dp over all proper submasks of \text{mask}) for every mask with exactly two bits set, because all of its proper submasks have at most one bit set.
So, we can use these values to compute the corresponding dp[\text{mask}] entries, again leaving every mask with more bits set at 0.
Running SOS DP a third time then gives the correct values for all masks with three bits set, and so on.
Repeating this process once per bit, that is, b = \mathcal{O}(\log N) times, correctly computes every value of dp, at which point we’re done. Each run of SOS DP costs \mathcal{O}(b\cdot 2^b), and 2^b = \mathcal{O}(N) here, so the total is \mathcal{O}(N\log^2 N).
TIME COMPLEXITY:
\mathcal{O}(N\log^2 N) per testcase.
CODE:
Author's code (C++)
#include <bits/stdc++.h>
using namespace std;
#define int long long
#define INF (int)1e18
mt19937_64 RNG(chrono::steady_clock::now().time_since_epoch().count());
const int mod = 998244353;
struct mint{
int x;
mint (){ x = 0;}
mint (int32_t xx){ x = xx % mod; if (x < 0) x += mod;}
mint (long long xx){ x = xx % mod; if (x < 0) x += mod;}
int val(){
return x;
}
mint &operator++(){
x++;
if (x == mod) x = 0;
return *this;
}
mint &operator--(){
if (x == 0) x = mod;
x--;
return *this;
}
mint operator++(int32_t){
mint result = *this;
++*this;
return result;
}
mint operator--(int32_t){
mint result = *this;
--*this;
return result;
}
mint& operator+=(const mint &b){
x += b.x;
if (x >= mod) x -= mod;
return *this;
}
mint& operator-=(const mint &b){
x -= b.x;
if (x < 0) x += mod;
return *this;
}
mint& operator*=(const mint &b){
long long z = x;
z *= b.x;
z %= mod;
x = (int)z;
return *this;
}
mint operator+() const {
return *this;
}
mint operator-() const {
return mint() - *this;
}
mint operator/=(const mint &b){
return *this = *this * b.inv();
}
mint power(long long n) const {
mint ok = *this, r = 1;
while (n){
if (n & 1){
r *= ok;
}
ok *= ok;
n >>= 1;
}
return r;
}
mint inv() const {
return power(mod - 2);
}
friend mint operator+(const mint& a, const mint& b){ return mint(a) += b;}
friend mint operator-(const mint& a, const mint& b){ return mint(a) -= b;}
friend mint operator*(const mint& a, const mint& b){ return mint(a) *= b;}
friend mint operator/(const mint& a, const mint& b){ return mint(a) /= b;}
friend bool operator==(const mint& a, const mint& b){ return a.x == b.x;}
friend bool operator!=(const mint& a, const mint& b){ return a.x != b.x;}
mint power(mint a, long long n){
return a.power(n);
}
friend ostream &operator<<(ostream &os, const mint &m) {
os << m.x;
return os;
}
explicit operator bool() const {
return x != 0;
}
};
// Remember to check MOD
struct factorials{
int n;
vector <mint> ff, iff;
factorials(int nn){
n = nn;
ff.resize(n + 1);
iff.resize(n + 1);
ff[0] = 1;
for (int i = 1; i <= n; i++){
ff[i] = ff[i - 1] * i;
}
iff[n] = ff[n].inv();
for (int i = n - 1; i >= 0; i--){
iff[i] = iff[i + 1] * (i + 1);
}
}
mint C(int n, int r){
if (n == r) return mint(1);
if (n < 0 || r < 0 || r > n) return mint(0);
return ff[n] * iff[r] * iff[n - r];
}
mint P(int n, int r){
if (n < 0 || r < 0 || r > n) return mint(0);
return ff[n] * iff[n - r];
}
mint solutions(int n, int r){
// Solutions to x1 + x2 + ... + xn = r, xi >= 0
return C(n + r - 1, n - 1);
}
mint catalan(int n){
return ff[2 * n] * iff[n] * iff[n + 1];
}
};
const int PRECOMP = 3e6 + 69;
factorials F(PRECOMP);
// REMEMBER To check MOD and PRECOMP
void Solve()
{
int n; cin >> n;
vector <int> a(n + 1);
for (int i = 1; i <= n; i++){
cin >> a[i];
}
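// Number of bits used for the masks: one per power of two not exceeding n.
int bits = 0;
for (int i = 0; i < 30; i++){
if ((1 << i) <= n){
bits++;
}
}
// XOR everything with A_1: A_1 becomes 0, A_N becomes A_1 ^ A_N, and a
// shortest path becomes a submask chain 0 -> ... -> a[n].
for (int i = 2; i <= n; i++){
a[i] ^= a[1];
}
a[1] = 0;
int m = 1 << bits;
// f[v] = number of intermediate vertices (2..n-1) with transformed value v.
vector <int> f(m, 0);
for (int i = 2; i < n; i++){
f[a[i]]++;
}
// g[v] = number of non-empty ordered selections among those vertices.
vector <mint> g(m, 0);
for (int i = 0; i < m; i++){
for (int j = 1; j <= f[i]; j++){
g[i] += F.C(f[i], j) * F.ff[j];
}
}
// Vertex 1 and vertex n are forced, so the blocks next to them may be empty.
g[0]++;
if (a[n]) g[a[n]]++;
vector <mint> dp(m, 0);
dp[0] = g[0];
// Group masks by popcount; round i finalizes every mask with i bits set.
vector<vector<int>> at(bits + 1);
for (int i = 0; i < m; i++){
at[__builtin_popcount(i)].push_back(i);
}
for (int i = 1; i <= bits; i++){
// SOS DP on a copy: dp is still 0 at every mask with popcount >= i, so
// sos[x] is the sum of dp over the proper submasks of x when popcount(x) == i.
vector <mint> sos = dp;
for (int k = 0; k < bits; k++){
for (int j = 0; j < m; j++) if (j >> k & 1){
sos[j] += sos[j ^ (1 << k)];
}
}
for (int x : at[i]){
dp[x] = sos[x] * g[x];
}
}
// Minimum length is A_1 ^ A_N; dp[a[n]] is the number of shortest paths (mod 998244353).
cout << a[n] << " " << dp[a[n]] << "\n";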
}
int32_t main()
{
auto begin = std::chrono::high_resolution_clock::now();
ios_base::sync_with_stdio(0);
cin.tie(0);
int t = 1;
// freopen("in", "r", stdin);
// freopen("out", "w", stdout);
cin >> t;
for(int i = 1; i <= t; i++)
{
//cout << "Case #" << i << ": ";
Solve();
}
auto end = std::chrono::high_resolution_clock::now();
auto elapsed = std::chrono::duration_cast<std::chrono::nanoseconds>(end - begin);
cerr << "Time measured: " << elapsed.count() * 1e-9 << " seconds.\n";
return 0;
}