PROBLEM LINK:
Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4
Author: am_aadvik
Tester: raysh07
Editorialist: iceknight1093
DIFFICULTY:
Medium
PREREQUISITES:
Dynamic programming, combinatorics
PROBLEM:
For an array A and a value M, define f(A, M) to be the maximum possible bitwise OR of A after XOR-ing at most one subarray of A with the value 2^M - 1.
You’re given N and M, compute the sum of f(A, M) across all arrays of length N with elements in [0, 2^M).
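As a baseline, the definition of f(A, M) can be sketched directly as a brute force. The function name `bruteF` and the 0-indexed layout are assumptions made for illustration; this is only meant for checking small cases, not as the intended solution.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Brute-force f(A, M): try "no operation" and every subarray [l, r],
// XOR-ing each element inside the subarray with 2^M - 1.
// Assumes 1 <= M <= 30 and every element is < 2^M.
std::uint32_t bruteF(const std::vector<std::uint32_t>& a, int m) {
    assert(m >= 1 && m <= 30);
    const std::uint32_t full = (std::uint32_t{1} << m) - 1;
    const int n = static_cast<int>(a.size());

    std::uint32_t best = 0;
    for (std::uint32_t x : a) best |= x;  // the "no operation" option

    for (int l = 0; l < n; ++l) {
        for (int r = l; r < n; ++r) {
            std::uint32_t cur = 0;
            for (int i = 0; i < n; ++i) {
                const bool inside = (l <= i && i <= r);
                cur |= inside ? (a[i] ^ full) : a[i];
            }
            best = std::max(best, cur);
        }
    }
    return best;
}
```

For example, with M = 2 the single-element array [1] gives max(1, 1 XOR 3) = 2, while [0] gives 3.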
EXPLANATION:
Let’s understand how to compute f(A, M) for a fixed array A.
Observe that since A_i \lt 2^M, XOR-ing a segment by 2^M - 1 essentially just flips all the bits of all elements in that segment.
Since we’re aiming to maximize the bitwise OR, it’s optimal to try and keep the higher bits set (even at the expense of lower bits).
So, let’s look at bit (M-1).
We want bit (M-1) to be set in the overall OR, which means at least one element must have it set.
So,
- If initially no element has it set, we need to perform the operation on some non-empty subarray.
It doesn’t really matter which non-empty subarray is chosen though.
- If instead some elements do have it set, we need to ensure that whichever operation we perform doesn’t turn them all off (or, if it does turn them all off, that it flips some other index to have the bit set).
In particular, for the second case: if the elements that have (M-1) set form a contiguous segment of A, then we cannot operate on exactly that segment, but any other segment will be fine.
(Note that this can also be thought of as including the first case, since there the banned segment is the empty one.)
On the other hand, if the elements that have (M-1) set don’t form a contiguous segment, we actually have no restriction: any operation (or even no operation) will allow (M-1) to be set.
Observe that the above analysis really applies to any single bit we look at: if all of its occurrences form a contiguous range, then operating on that segment exactly will turn this bit off in the final OR whereas any other segment will be fine; and if the occurrences aren’t contiguous then our operations aren’t restricted at all and any operation at all will do.
So, rather than bits, we’re somewhat bound by segments.
Thus, let’s compute for each contiguous segment [L, R] the set of bits that will be turned off if this segment is operated on.
(We also include the empty segment here.)
Let this value be S_{L, R} for the segment [L, R]. (We can treat S_{1, 0} as belonging to the empty segment.)
Then, observe that if we operate on segment [L, R], the final OR value will be exactly
(2^M - 1) \oplus S_{L, R}
That is, take all bits from 0 to M-1, and then turn off exactly the ones in S_{L, R}.
Note that this is equivalent to 2^M - 1 - S_{L, R}, so clearly the optimal choice is just whichever S_{L, R} is minimum.
In particular,
- If S_{L, R} = 0 for some segment (meaning that there’s a segment which is not banned by any bit), then the answer is just 2^M - 1.
- Otherwise, if S_{L, R} \gt 0 for all segments, the minimum S_{L, R} is the one whose highest set bit is minimized (since different S_{L, R} don’t share any bits).
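The characterization above can be sketched in code as follows. The name `fBySegments`, the 0-indexed segments, and the separate `sEmpty` variable for the empty segment are assumptions for illustration; the sketch simply builds every S_{L, R} and subtracts the minimum from 2^M - 1.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Computes f(A, M) via the banned-segment masks S_{L,R}.
// s[l][r] (l <= r, 0-indexed) holds the bits whose occurrences are exactly [l, r];
// sEmpty holds the bits that occur in no element (banning the empty segment).
// Assumes 1 <= M <= 30 and every element is < 2^M.
std::uint32_t fBySegments(const std::vector<std::uint32_t>& a, int m) {
    assert(m >= 1 && m <= 30);
    const int n = static_cast<int>(a.size());
    std::vector<std::vector<std::uint32_t>> s(n, std::vector<std::uint32_t>(n, 0));
    std::uint32_t sEmpty = 0;

    for (int b = 0; b < m; ++b) {
        int first = -1, last = -1, count = 0;
        for (int i = 0; i < n; ++i) {
            if ((a[i] >> b) & 1u) {
                if (first < 0) first = i;
                last = i;
                ++count;
            }
        }
        if (count == 0) {
            sEmpty |= std::uint32_t{1} << b;          // bit bans the empty segment
        } else if (last - first + 1 == count) {
            s[first][last] |= std::uint32_t{1} << b;  // bit bans exactly [first, last]
        }
        // Otherwise the occurrences are not contiguous and the bit bans nothing.
    }

    std::uint32_t minS = sEmpty;
    for (int l = 0; l < n; ++l)
        for (int r = l; r < n; ++r)
            minS = std::min(minS, s[l][r]);

    const std::uint32_t full = (std::uint32_t{1} << m) - 1;
    return full - minS;
}
```

For A = [10, 12] and M = 4, the four bits ban the four segments (empty, [0,0], [1,1], [0,1]) one each, the minimum mask is 1, and the result is 15 - 1 = 14.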
Let’s now use the above observation to solve the problem.
There are (2^{M})^{N} = 2^{MN} arrays in total.
For each of them, let’s assume the answer is initially 2^M - 1, i.e. has all bits set.
This gives us a starting value of 2^{MN}\cdot (2^M - 1).
From here, we can try to deal with the subtraction part via counting contributions.
Specifically, let’s fix a bit b, and try to count the number of arrays in which 2^b is subtracted from the answer.
For this to happen,
- Every segment [L, R] along with the empty segment must have some bit banning it.
- b itself must ban some segment.
- The segment that b bans must have the lowest possible highest bit, to ensure that this is the S_{L, R} that’s being subtracted.
Let’s try to count the number of arrays where this happens.
There are \frac{N\cdot (N+1)}{2} + 1 possible subarrays, including the empty one.
Let K = \frac{N\cdot (N+1)}{2} + 1 for convenience.
For each bit \gt b, there are a few options:
- It doesn’t ban any range.
This means it must not appear as a contiguous segment in A.
There are 2^N - K ways this can happen: 2^N ways to distribute the bit across all elements, from which we remove the K ways that correspond to segments (including the empty one).
Importantly, this value is independent of the bit entirely.
- It does ban some range.
On paper there are K choices for distributing it across A, but the situation is not so simple here since things aren’t independent across bits.
This is because we need to ensure that every range gets banned at least once by the time we get to b, to ensure that 2^b is subtracted.
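The count 2^N - K for a bit that bans no range can be sanity-checked by enumeration. The sketch below is an illustration only; `isRun` and `countFreePlacements` are hypothetical helper names, and a placement of one bit across the N elements is encoded as an N-bit mask.

```cpp
#include <cassert>
#include <cstdint>

// True if the set bits of mask form one contiguous run (the empty mask counts too).
bool isRun(std::uint32_t mask) {
    if (mask == 0) return true;
    const std::uint32_t lowest = mask & (~mask + 1);  // lowest set bit
    const std::uint32_t shifted = mask / lowest;      // strip trailing zeros
    return (shifted & (shifted + 1)) == 0;            // shifted is of the form 2^k - 1
}

// Number of ways to place a single bit across n elements so that its
// occurrences do NOT form a contiguous (possibly empty) segment.
// Assumes 0 <= n <= 20 so that plain enumeration stays cheap.
std::uint32_t countFreePlacements(int n) {
    assert(n >= 0 && n <= 20);
    std::uint32_t count = 0;
    for (std::uint32_t mask = 0; mask < (std::uint32_t{1} << n); ++mask)
        if (!isRun(mask)) ++count;
    return count;
}
```

For N = 3 this finds exactly one such placement (the pattern 101), matching 2^3 - 7 = 1.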
Let’s forget the first type of bit for now, and just assume that every bit \gt b bans some range.
Then, observe that what we need to do is separate these higher bits into subsets; where each subset corresponds to the same range being banned.
Further,
- If there are K non-empty subsets among higher bits, the subset for b is chosen uniquely (it must go into the ‘last’ one, i.e. the subset whose largest bit is smallest).
- If there are K-1 non-empty subsets among higher bits, again the subset for b is chosen uniquely (it must form its own, to be the last one.)
- If there are \lt K-1 non-empty subsets among higher bits, there’s no situation in which we subtract 2^b.
So, we only really need to count the number of ways to partition the higher bits into either K or K-1 subsets.
Note that after deciding a partition into subsets, and also including b, there are exactly K! ways to assign segments to the subsets, since it doesn’t really matter which subset gets which range - just that every range is taken care of.
Thus, if we define f(x, y) to be the number of ways to partition \{1, 2, \ldots, x\} into y subsets (where the subsets are unordered), the values we’re looking for are f(M-b-1, K) and f(M-b-1, K-1).
Let’s now consider computing f(x, y), the number of ways of partitioning \{1, 2, \ldots, x\} into y non-empty subsets.
This has a simple recurrence:
f(x, y) = f(x-1, y)\cdot y + f(x-1, y-1).
This is because we can either partition the first x-1 elements into y subsets and then choose which one x goes to; or we can partition the first x-1 elements into y-1 subsets and then make x its own subset.
(You may notice these are exactly the Stirling numbers of the second kind).
This allows for all the f(x, y) values to be precomputed in \mathcal{O}(M^2) time, which is fast enough for the given constraints.
Once this is done, we can simply look them up as needed.
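A minimal sketch of this precomputation, assuming plain 64-bit arithmetic with an explicit modulus rather than the modular-integer class used in the full solution; the name `stirlingTable` is hypothetical.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// table[x][y] = number of ways to partition {1, ..., x} into y non-empty
// unordered subsets, taken modulo mod (Stirling numbers of the second kind).
// Assumes maxX >= 0 and mod small enough that mod * maxX fits in 64 bits.
std::vector<std::vector<std::uint64_t>> stirlingTable(int maxX, std::uint64_t mod) {
    assert(maxX >= 0 && mod >= 2);
    std::vector<std::vector<std::uint64_t>> table(
        maxX + 1, std::vector<std::uint64_t>(maxX + 1, 0));
    table[0][0] = 1;
    for (int x = 1; x <= maxX; ++x) {
        for (int y = 1; y <= x; ++y) {
            // Either x joins one of the y existing subsets, or it forms a new one.
            table[x][y] = (table[x - 1][y] * static_cast<std::uint64_t>(y)
                           + table[x - 1][y - 1]) % mod;
        }
    }
    return table;
}
```

For instance, the entry for x = 4, y = 2 is 7: the four elements split either 1 + 3 (four ways) or 2 + 2 (three ways).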
We return to actually solving the problem.
Our initial assumption was that every higher bit would correspond to some range, resulting in the values f(M-1-b, K) and f(M-1-b, K-1) respectively (multiplied by K!).
However, this need not be the case: some higher bits could correspond to no range.
Luckily, since we observed that the combinatorics for such bits doesn’t really depend on which ones they are, they can be handled easily.
Let’s fix r to be the number of higher bits that do not correspond to any range.
There are \binom{M-1-b}{r} ways to choose which bits these are, and then each of these r bits has 2^N - K options, for (2^N - K)^r options overall.
There are M-1-b-r higher bits remaining, which must then be distributed to the ranges: this can be done in f(M-1-b-r, K) + f(M-1-b-r, K-1) ways, as done above.
Finally, we also need to deal with the lower bits.
However, once we’ve ensured that 2^b is subtracted, lower bits can’t influence this at all; so they’re basically completely free.
This gives 2^N options for each lower bit, for 2^{Nb} options overall.
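So, combining the K! assignments of ranges, the choice of the r higher bits that ban nothing, and the free lower bits, the number of arrays in which 2^b is subtracted equals
2^{Nb} \cdot K! \cdot \sum_{r=0}^{M-1-b} \binom{M-1-b}{r} \cdot (2^N - K)^r \cdot \left( f(M-1-b-r, K) + f(M-1-b-r, K-1) \right)
Each such array contributes -2^b to the answer.
This sum can be computed in \mathcal{O}(M) for a single bit, giving a solution in \mathcal{O}(M^2) overall.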
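Putting the pieces together, the whole count can be sketched with plain 64-bit modular arithmetic. This is a compact illustration, not the editorialist's code: the names `modPow` and `sumOverAllArrays` are hypothetical, the modulus 998244353 is taken from the code below, and the early return when K > M is an added shortcut (in that case no partition into K-1 or K subsets exists, so nothing is ever subtracted).

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

constexpr std::uint64_t kMod = 998244353;  // modulus used in the editorialist's code

// Fast exponentiation modulo kMod.
std::uint64_t modPow(std::uint64_t base, std::uint64_t exp) {
    std::uint64_t result = 1;
    base %= kMod;
    while (exp > 0) {
        if (exp & 1) result = result * base % kMod;
        base = base * base % kMod;
        exp >>= 1;
    }
    return result;
}

// Sum of f(A, M) over all arrays of length n with elements in [0, 2^m), mod kMod.
std::uint64_t sumOverAllArrays(int n, int m) {
    assert(n >= 1 && m >= 1);
    const std::uint64_t un = n, um = m;
    const std::uint64_t k = un * (un + 1) / 2 + 1;  // segments, including the empty one
    // Start from "every array scores 2^m - 1".
    std::uint64_t answer = modPow(2, un * um) * ((modPow(2, um) + kMod - 1) % kMod) % kMod;
    if (k > um) return answer;  // too few bits to ban every segment

    // Stirling numbers of the second kind and binomial coefficients up to m.
    std::vector<std::vector<std::uint64_t>> st(m + 1, std::vector<std::uint64_t>(m + 1, 0));
    auto binom = st;
    st[0][0] = binom[0][0] = 1;
    for (int x = 1; x <= m; ++x) {
        binom[x][0] = 1;
        for (int y = 1; y <= x; ++y) {
            st[x][y] = (st[x - 1][y] * y + st[x - 1][y - 1]) % kMod;
            binom[x][y] = (binom[x - 1][y] + binom[x - 1][y - 1]) % kMod;
        }
    }
    std::uint64_t kFactorial = 1;
    for (std::uint64_t i = 2; i <= k; ++i) kFactorial = kFactorial * i % kMod;
    const std::uint64_t freeWays = (modPow(2, un) + kMod - k % kMod) % kMod;  // 2^n - K

    for (int b = 0; b < m; ++b) {
        const int higher = m - 1 - b;  // number of bits above b
        std::uint64_t sum = 0, freePow = 1;  // freePow = (2^n - K)^r
        for (int r = 0; r <= higher; ++r) {
            const std::uint64_t rest = static_cast<std::uint64_t>(higher - r);
            std::uint64_t parts = 0;  // f(rest, K) + f(rest, K-1)
            if (k <= rest) parts += st[rest][k];
            if (k - 1 <= rest) parts += st[rest][k - 1];
            parts %= kMod;
            sum = (sum + binom[higher][r] * freePow % kMod * parts) % kMod;
            freePow = freePow * freeWays % kMod;
        }
        const std::uint64_t count = modPow(2, un * b) * kFactorial % kMod * sum % kMod;
        answer = (answer + kMod - modPow(2, b) * count % kMod) % kMod;
    }
    return answer;
}
```

As a small check, N = 1 and M = 2 gives 10: the four one-element arrays score 3, 2, 2, 3.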
TIME COMPLEXITY:
\mathcal{O}(M^2) per testcase.
CODE:
Editorialist's code (C++)
// #include <bits/allocator.h>
// #pragma GCC optimize("O3,unroll-loops")
// #pragma GCC target("avx2,bmi,bmi2,lzcnt,popcnt")
#include "bits/stdc++.h"
using namespace std;
using ll = long long int;
mt19937_64 RNG(chrono::high_resolution_clock::now().time_since_epoch().count());
/**
* Integers modulo p, where p is a prime
* Source: Aeren (modified from tourist?)
* Modmul for 64-bit mod from kactl:ModMulLL
* Works with p < 7.2e18 with x87 80-bit long double, and p < 2^52 ~ 4.5e12 with 64-bit
*/
template<typename T>
struct Z_p{
using Type = typename decay<decltype(T::value)>::type;
static vector<Type> MOD_INV;
constexpr Z_p(): value(){ }
template<typename U> Z_p(const U &x){ value = normalize(x); }
template<typename U> static Type normalize(const U &x){
Type v;
if(-mod() <= x && x < mod()) v = static_cast<Type>(x);
else v = static_cast<Type>(x % mod());
if(v < 0) v += mod();
return v;
}
const Type& operator()() const{ return value; }
template<typename U> explicit operator U() const{ return static_cast<U>(value); }
constexpr static Type mod(){ return T::value; }
Z_p &operator+=(const Z_p &otr){ if((value += otr.value) >= mod()) value -= mod(); return *this; }
Z_p &operator-=(const Z_p &otr){ if((value -= otr.value) < 0) value += mod(); return *this; }
template<typename U> Z_p &operator+=(const U &otr){ return *this += Z_p(otr); }
template<typename U> Z_p &operator-=(const U &otr){ return *this -= Z_p(otr); }
Z_p &operator++(){ return *this += 1; }
Z_p &operator--(){ return *this -= 1; }
Z_p operator++(int){ Z_p result(*this); *this += 1; return result; }
Z_p operator--(int){ Z_p result(*this); *this -= 1; return result; }
Z_p operator-() const{ return Z_p(-value); }
template<typename U = T>
typename enable_if<is_same<typename Z_p<U>::Type, int>::value, Z_p>::type &operator*=(const Z_p& rhs){
#ifdef _WIN32
uint64_t x = static_cast<int64_t>(value) * static_cast<int64_t>(rhs.value);
uint32_t xh = static_cast<uint32_t>(x >> 32), xl = static_cast<uint32_t>(x), d, m;
asm(
"divl %4; \n\t"
: "=a" (d), "=d" (m)
: "d" (xh), "a" (xl), "r" (mod())
);
value = m;
#else
value = normalize(static_cast<int64_t>(value) * static_cast<int64_t>(rhs.value));
#endif
return *this;
}
template<typename U = T>
typename enable_if<is_same<typename Z_p<U>::Type, int64_t>::value, Z_p>::type &operator*=(const Z_p &rhs){
uint64_t ret = static_cast<uint64_t>(value) * static_cast<uint64_t>(rhs.value) - static_cast<uint64_t>(mod()) * static_cast<uint64_t>(1.L / static_cast<uint64_t>(mod()) * static_cast<uint64_t>(value) * static_cast<uint64_t>(rhs.value));
value = normalize(static_cast<int64_t>(ret + static_cast<uint64_t>(mod()) * (ret < 0) - static_cast<uint64_t>(mod()) * (ret >= static_cast<uint64_t>(mod()))));
return *this;
}
template<typename U = T>
typename enable_if<!is_integral<typename Z_p<U>::Type>::value, Z_p>::type &operator*=(const Z_p &rhs){
value = normalize(value * rhs.value);
return *this;
}
template<typename U>
Z_p &operator^=(U e){
if(e < 0) *this = 1 / *this, e = -e;
Z_p res = 1;
for(; e; *this *= *this, e >>= 1) if(e & 1) res *= *this;
return *this = res;
}
template<typename U>
Z_p operator^(U e) const{
return Z_p(*this) ^= e;
}
Z_p &operator/=(const Z_p &otr){
Type a = otr.value, m = mod(), u = 0, v = 1;
if(a < (int)MOD_INV.size()) return *this *= MOD_INV[a];
while(a){
Type t = m / a;
m -= t * a; swap(a, m);
u -= t * v; swap(u, v);
}
assert(m == 1);
return *this *= u;
}
template<typename U> friend const Z_p<U> &abs(const Z_p<U> &v){ return v; }
Type value;
};
template<typename T> bool operator==(const Z_p<T> &lhs, const Z_p<T> &rhs){ return lhs.value == rhs.value; }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> bool operator==(const Z_p<T>& lhs, U rhs){ return lhs == Z_p<T>(rhs); }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> bool operator==(U lhs, const Z_p<T> &rhs){ return Z_p<T>(lhs) == rhs; }
template<typename T> bool operator!=(const Z_p<T> &lhs, const Z_p<T> &rhs){ return !(lhs == rhs); }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> bool operator!=(const Z_p<T> &lhs, U rhs){ return !(lhs == rhs); }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> bool operator!=(U lhs, const Z_p<T> &rhs){ return !(lhs == rhs); }
template<typename T> bool operator<(const Z_p<T> &lhs, const Z_p<T> &rhs){ return lhs.value < rhs.value; }
template<typename T> bool operator>(const Z_p<T> &lhs, const Z_p<T> &rhs){ return lhs.value > rhs.value; }
template<typename T> bool operator<=(const Z_p<T> &lhs, const Z_p<T> &rhs){ return lhs.value <= rhs.value; }
template<typename T> bool operator>=(const Z_p<T> &lhs, const Z_p<T> &rhs){ return lhs.value >= rhs.value; }
template<typename T> Z_p<T> operator+(const Z_p<T> &lhs, const Z_p<T> &rhs){ return Z_p<T>(lhs) += rhs; }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> Z_p<T> operator+(const Z_p<T> &lhs, U rhs){ return Z_p<T>(lhs) += rhs; }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> Z_p<T> operator+(U lhs, const Z_p<T> &rhs){ return Z_p<T>(lhs) += rhs; }
template<typename T> Z_p<T> operator-(const Z_p<T> &lhs, const Z_p<T> &rhs){ return Z_p<T>(lhs) -= rhs; }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> Z_p<T> operator-(const Z_p<T>& lhs, U rhs){ return Z_p<T>(lhs) -= rhs; }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> Z_p<T> operator-(U lhs, const Z_p<T> &rhs){ return Z_p<T>(lhs) -= rhs; }
template<typename T> Z_p<T> operator*(const Z_p<T> &lhs, const Z_p<T> &rhs){ return Z_p<T>(lhs) *= rhs; }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> Z_p<T> operator*(const Z_p<T>& lhs, U rhs){ return Z_p<T>(lhs) *= rhs; }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> Z_p<T> operator*(U lhs, const Z_p<T> &rhs){ return Z_p<T>(lhs) *= rhs; }
template<typename T> Z_p<T> operator/(const Z_p<T> &lhs, const Z_p<T> &rhs) { return Z_p<T>(lhs) /= rhs; }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> Z_p<T> operator/(const Z_p<T>& lhs, U rhs) { return Z_p<T>(lhs) /= rhs; }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> Z_p<T> operator/(U lhs, const Z_p<T> &rhs) { return Z_p<T>(lhs) /= rhs; }
template<typename T> istream &operator>>(istream &in, Z_p<T> &number){
typename common_type<typename Z_p<T>::Type, int64_t>::type x;
in >> x;
number.value = Z_p<T>::normalize(x);
return in;
}
template<typename T> ostream &operator<<(ostream &out, const Z_p<T> &number){ return out << number(); }
/*
using ModType = int;
struct VarMod{ static ModType value; };
ModType VarMod::value;
ModType &mod = VarMod::value;
using Zp = Z_p<VarMod>;
*/
// constexpr int mod = 1e9 + 7; // 1000000007
constexpr int mod = (119 << 23) + 1; // 998244353
// constexpr int mod = 1e9 + 9; // 1000000009
using Zp = Z_p<integral_constant<decay<decltype(mod)>::type, mod>>;
template<typename T> vector<typename Z_p<T>::Type> Z_p<T>::MOD_INV;
template<typename T = integral_constant<decay<decltype(mod)>::type, mod>>
void precalc_inverse(int SZ){
auto &inv = Z_p<T>::MOD_INV;
if(inv.empty()) inv.assign(2, 1);
for(; inv.size() <= SZ; ) inv.push_back((mod - 1LL * mod / (int)inv.size() * inv[mod % (int)inv.size()]) % mod);
}
template<typename T>
vector<T> precalc_power(T base, int SZ){
vector<T> res(SZ + 1, 1);
for(auto i = 1; i <= SZ; ++ i) res[i] = res[i - 1] * base;
return res;
}
template<typename T>
vector<T> precalc_factorial(int SZ){
vector<T> res(SZ + 1, 1); res[0] = 1;
for(auto i = 1; i <= SZ; ++ i) res[i] = res[i - 1] * i;
return res;
}
const int N = 3005;
Zp dp[N][N]; // dp[i][j] = ways to partition [1...i] into j distinct non-empty subsets
Zp C[N][N]; // C[i][j] = binomial coefficient "i choose j"
int main()
{
ios::sync_with_stdio(false); cin.tie(0);
dp[0][0] = C[0][0] = 1;
for (int n = 1; n < N; ++n) {
C[n][0] = 1;
for (int r = 1; r <= n; ++r) {
dp[n][r] = dp[n-1][r] * r + dp[n-1][r-1];
C[n][r] = C[n-1][r] + C[n-1][r-1];
}
}
int t; cin >> t;
while (t--) {
int n, m; cin >> n >> m;
Zp ans = 0;
int tot = n*(n+1)/2 + 1;
Zp other = (Zp(2)^n) - tot;
Zp fac = 1;
for (int i = 1; i <= tot; ++i) fac *= i;
for (int i = m-1; i >= 0; --i) {
int before = m-1-i, after = i;
Zp val = (Zp(2) ^ i) * fac * (Zp(2) ^ (n * after));
Zp skipways = 1;
for (int skip = 0; skip <= before; ++skip) {
// skip things are not segments
// (before-skip) things are segments
for (int k : {tot-1, tot}) {
Zp ways = 0;
if (before-skip >= k) ways = dp[before-skip][k];
ways *= skipways;
ways *= C[before][skip];
ans -= ways * val;
}
skipways *= other;
}
ans += (Zp(2) ^ (n*m + i));
}
cout << ans << '\n';
}
}