NDANDANDOR editorial

PROBLEM LINK:

Contest Division 1
Contest Division 2
Contest Division 3
Contest Division 4
Practice

Setter: Valerio Stancanelli
Testers: Takuki Kurokawa and Nishank Suresh
Editorialist: Utkarsh Gupta

DIFFICULTY

3148

PREREQUISITES

Combinatorics

PROBLEM

You are given three positive integers N, M and K.

An array of integers A_1, A_2, \dots, A_N is good if the following statements hold:

  • 0 \leq A_i \leq M for each (1 \leq i \leq N)
  • (A_i \& K) \leq (A_{i+1} \& K) for each (1 \leq i \leq N-1)
  • (A_i \; | \; K) \leq (A_{i+1} \; | \; K) for each (1 \leq i \leq N-1)

Here, \& denotes the bitwise AND operation and | denotes the bitwise OR operation.

Count the number of good arrays A_1, A_2, \dots, A_N. As the result can be very large, you should print it modulo 998\,244\,353.

EXPLANATION

First, increment M by 1, so that we only have to consider 0 \leq A_i \lt M.

Small Observation

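Let us consider only the lowest 21 bits.
Let P be the set of bit positions that are set in K, and Q the set of bit positions that are unset in K.
A_i \& K keeps exactly the bits of A_i that lie in P, so the first condition says that the number formed by the P-bits of A is non-decreasing. Since (A_i \; | \; K) = K + (A_i \& \sim K), the second condition says the same for the number formed by the Q-bits. Hence the array must be non-decreasing with respect to the bits in P and, independently, with respect to the bits in Q.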
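To make the observation concrete, here is a small C++ sketch (not taken from the official solutions). The names `compressBits` and `observationHolds` are hypothetical: the first packs the bits of a value at the positions set in a mask into a contiguous number, keeping their order, and the second brute-forces, for all pairs of values below 2^{bits}, that the two conditions of the statement are equivalent to the P-values and the Q-values being non-decreasing.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helper: collect the bits of `a` that lie at positions set in
// `mask`, keeping their relative order (higher bits stay higher).
uint32_t compressBits(uint32_t a, uint32_t mask, int bits = 21) {
    uint32_t r = 0;
    for (int i = bits - 1; i >= 0; --i) {
        if ((mask >> i) & 1u) {
            r = (r << 1) | ((a >> i) & 1u);
        }
    }
    return r;
}

// Brute-force check, for every pair (a, b) below 2^bits, that
//   (a & k) <= (b & k) && (a | k) <= (b | k)
// is equivalent to
//   P-value(a) <= P-value(b) && Q-value(a) <= Q-value(b),
// where P = bits set in k and Q = bits unset in k.
bool observationHolds(uint32_t k, int bits) {
    const uint32_t full = (1u << bits) - 1;
    const uint32_t kk = k & full;   // P as a mask
    const uint32_t q = full & ~kk;  // Q as a mask
    for (uint32_t a = 0; a <= full; ++a) {
        for (uint32_t b = 0; b <= full; ++b) {
            bool original = (a & kk) <= (b & kk) && (a | kk) <= (b | kk);
            bool compressed = compressBits(a, kk, bits) <= compressBits(b, kk, bits) &&
                              compressBits(a, q, bits) <= compressBits(b, q, bits);
            if (original != compressed) return false;
        }
    }
    return true;
}
```

Because compressing keeps the order of the bits, comparing compressed values is the same as comparing the masked values, which is exactly the argument above.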

General Approach

Since we want A_i \lt M, we choose a prefix of M that ends in a 1 bit and change that last bit to 0. The resulting prefix is strictly smaller than the corresponding prefix of M, so the remaining (lower) bits can be filled independently.
For example, if the binary representation of M is 110110, we consider the prefixes 1, 11, 1101 and 11011; changing their last bit gives 0, 10, 1100 and 11010.
We skip the prefixes 110 and 110110, since they end in 0 and cannot be made smaller by changing the last bit.
After fixing a prefix S, we require the prefix of every A_i to be \leq S (compared separately on the bits of P and on the bits of Q), and the prefix of A_N, the largest element of the array, to be exactly S, so that nothing is counted twice later.
With the prefix fixed this way, every A_i is \lt M. Indeed, the P-value and the Q-value of each A_i are at most those of A_N, which forces A_i \leq A_N: at the highest bit where A_i and A_N differ, A_N must have the 1, otherwise the value of that bit's set would be larger in A_i. Hence the remaining suffix bits can be filled independently.

For a fixed prefix S, we need to count the good arrays A such that

  • prefix of A_i \leq S
  • prefix of A_N = S
  • A is non-decreasing with respect to the bits of sets P and Q defined earlier.

Summing these counts over all prefixes gives the answer. There is no overcounting: since A_N \lt M, A_N differs from M for the first time at exactly one bit, where M has a 1 and A_N has a 0, so each array is counted for exactly one prefix.
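The prefix enumeration can be sketched as follows, assuming M has already been incremented. `enumeratePrefixes`, `owningBit` and the `Prefix` record are illustrative names, not taken from the official solutions: for every set bit i of M the first returns the prefix S (bits above i copied from M, bit i cleared) and the number i of free suffix bits, and the second finds, for a value v \lt M, the unique bit that decides which prefix v belongs to, which is why no array is counted twice.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// One candidate prefix per set bit i of M.
struct Prefix {
    uint32_t value;  // the prefix S as a plain binary number (bit i cleared)
    int freeBits;    // number of suffix bits that can be filled freely (= i)
};

std::vector<Prefix> enumeratePrefixes(uint32_t m) {
    std::vector<Prefix> out;
    for (int i = 31; i >= 0; --i) {
        if ((m >> i) & 1u) {
            // (m >> i) ends in the 1 at bit i; XOR with 1 turns it into a 0.
            out.push_back({(m >> i) ^ 1u, i});
        }
    }
    return out;
}

// For v < m: the highest bit where v and m differ. M has a 1 there and v a 0,
// so v falls under exactly the prefix cut at this bit.
int owningBit(uint32_t v, uint32_t m) {
    for (int i = 31; i >= 0; --i) {
        if (((v ^ m) >> i) & 1u) return i;
    }
    return -1;  // v == m, which is excluded by the condition v < m
}
```

For M = 110110 this yields the prefixes 0, 10, 1100 and 11010 with 5, 4, 2 and 1 free bits, and 2^5 + 2^4 + 2^2 + 2^1 = 54 = M, so the blocks cover [0, M) exactly once.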

Now consider M = \textcolor{red}{1}\textcolor{green}{0}\textcolor{red}{1}\textcolor{green}{1}\textcolor{green}{1}\textcolor{red}{0}\textcolor{red}{0}\textcolor{green}{1}\textcolor{red}{1}\textcolor{green}{1} in binary, where red marks the bits in P and green marks the bits in Q.

Take the prefix of length 5, which is 10111. Changing its last bit gives S = 10110. Now compress the bits of each color within the prefix. The red bits in the prefix of every A_i must be \leq \textcolor{red}{1}\textcolor{red}{1} (i.e. 3 in decimal), and the green bits in the prefix of every A_i must be \leq \textcolor{green}{0}\textcolor{green}{1}\textcolor{green}{0} (i.e. 2 in decimal). The suffix has no inequality restriction with respect to M, so the red bits in the suffix of A can take any value \leq 2^3-1 (there are 3 red bits in the suffix) and the green bits in the suffix any value \leq 2^2-1 (there are 2 green bits in the suffix).

Let the red and green values of the prefix of A_N be x and y respectively (in this example, x=3 and y=2).
Let the red and green values of the suffix of A_N be p and q respectively (in this example, 0 \leq p \leq 2^3-1 and 0 \leq q \leq 2^2-1).
We first update x to 2^3 \cdot x (since there are 3 red bits in the suffix) and y to 2^2 \cdot y (since there are 2 green bits in the suffix).
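The quantities of the worked example can be computed mechanically. In this sketch (the name `splitAt` and the `Split` record are hypothetical), red bits are the bits set in K and green bits the ones unset. For the example, the red positions are bits 9, 7, 4, 3 and 1, i.e. K = 666, M = 1011100111_2 = 743, and the prefix of length 5 corresponds to cutting at bit i = 5.

```cpp
#include <cassert>
#include <cstdint>

// Quantities for a cut at bit i of M (M must have bit i set).
struct Split {
    uint64_t x, y;           // red / green value of the prefix S (bit i cleared)
    int redFree, greenFree;  // number of red / green bits strictly below i
};

Split splitAt(uint32_t m, uint32_t k, int i, int bits = 21) {
    Split s{0, 0, 0, 0};
    for (int j = bits - 1; j >= 0; --j) {
        bool red = (k >> j) & 1u;
        if (j > i) {
            // Prefix bit copied from M.
            uint64_t bit = (m >> j) & 1u;
            if (red) s.x = s.x * 2 + bit; else s.y = s.y * 2 + bit;
        } else if (j == i) {
            // The cut bit: M has 1 here, S has 0.
            if (red) s.x *= 2; else s.y *= 2;
        } else if (red) {
            ++s.redFree;
        } else {
            ++s.greenFree;
        }
    }
    return s;
}
```

`splitAt(743, 666, 5)` gives x = 3 and y = 2 with 3 red and 2 green free bits; after the update, x = 3 \cdot 2^3 = 24 and y = 2 \cdot 2^2 = 8.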
Now the total red value is x+p and the total green value is y+q.
We want both the red values and the green values along A to be non-decreasing.
Fix the red value of A_N to x+p and the green value of A_N to y+q. Then the number of non-decreasing arrays satisfying the above conditions is:

\left( \binom{n+x+p}{n} - \binom{n+x+p-1}{n} \right) \cdot \left( \binom{n+y+q}{n} - \binom{n+y+q-1}{n} \right)
This formula follows from stars and bars. The first factor is the contribution of the red bits and the second factor the contribution of the green bits. The term \binom{n+x+p}{n} counts the non-decreasing arrays of length n whose elements all lie in [0, x+p]. Subtracting \binom{n+x+p-1}{n}, the number of those whose elements are all \leq x+p-1, leaves the non-decreasing arrays whose last element is exactly x+p.
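The stars and bars claims can be checked by brute force on small inputs. This sketch (helper names are hypothetical) counts non-decreasing arrays with a simple prefix-sum DP and compares them with \binom{n+v}{n} and \binom{n+v}{n} - \binom{n+v-1}{n}. It uses exact 64-bit arithmetic without a modulus, which is only safe for small arguments.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Exact binomial coefficient for small arguments; 0 when r < 0 or r > n.
uint64_t binom(int64_t n, int64_t r) {
    if (n < 0 || r < 0 || r > n) return 0;
    uint64_t res = 1;
    // After step i, res == C(n - r + i, i), so every division is exact.
    for (int64_t i = 1; i <= r; ++i) res = res * (n - r + i) / i;
    return res;
}

// Non-decreasing arrays of length len with entries >= 0 and last entry exactly v.
uint64_t countEndingAt(int len, int v) {
    // ways[t] = number of non-decreasing arrays of the current length ending in t
    std::vector<uint64_t> ways(v + 1, 1);  // length 1: one array per value
    for (int step = 2; step <= len; ++step) {
        uint64_t running = 0;
        for (int t = 0; t <= v; ++t) {  // new ways[t] = sum of old ways[0..t]
            running += ways[t];
            ways[t] = running;
        }
    }
    return ways[v];
}

// Non-decreasing arrays of length len with every entry in [0, v].
uint64_t countAtMost(int len, int v) {
    uint64_t total = 0;
    for (int t = 0; t <= v; ++t) total += countEndingAt(len, t);
    return total;
}
```

For instance, with n = 2 and v = 1 the arrays are (0,0), (0,1), (1,1), so `countAtMost(2, 1)` is \binom{3}{2} = 3 and `countEndingAt(2, 1)` is 3 - \binom{2}{2} = 2.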

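For a fixed prefix, p ranges over [0, P] and q over [0, Q], where P = 2^3-1 and Q = 2^2-1 in the example. Summing the formula over these ranges makes each factor telescope:
\sum_{p=0}^{P} \left( \binom{n+x+p}{n} - \binom{n+x+p-1}{n} \right) = \binom{n+x+P}{n} - \binom{n+x-1}{n}
and similarly for q. So each prefix contributes \left( \binom{n+x+P}{n} - \binom{n+x-1}{n} \right) \cdot \left( \binom{n+y+Q}{n} - \binom{n+y-1}{n} \right), and the answer is the sum of these contributions over all prefixes. See the editorialist's code for an implementation.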
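Putting the pieces together, here is a compact sketch of the whole counting scheme, written for illustration rather than copied from the setter, testers or editorialist. It follows the O(\log^2 M) structure described above (recomputing x, y and the free bit counts for every prefix), assumes M, K \lt 2^{21} and a factorial table of size at least n + M, and includes a brute force straight from the definition for cross-checking on small inputs. The names `Binomial`, `blockCount`, `countGood` and `bruteForce` are hypothetical.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

const int64_t MOD = 998244353;

int64_t power(int64_t b, int64_t e) {
    int64_t r = 1;
    b %= MOD;
    for (; e > 0; e >>= 1, b = b * b % MOD)
        if (e & 1) r = r * b % MOD;
    return r;
}

// Factorial table; assumes limit >= n + M so every argument below fits.
struct Binomial {
    std::vector<int64_t> fact, invFact;
    explicit Binomial(int limit) : fact(limit + 1, 1), invFact(limit + 1, 1) {
        for (int i = 1; i <= limit; ++i) fact[i] = fact[i - 1] * i % MOD;
        invFact[limit] = power(fact[limit], MOD - 2);  // Fermat inverse
        for (int i = limit; i > 0; --i) invFact[i - 1] = invFact[i] * i % MOD;
    }
    int64_t C(int64_t n, int64_t r) const {
        if (n < 0 || r < 0 || r > n) return 0;
        return fact[n] * invFact[r] % MOD * invFact[n - r] % MOD;
    }
};

// Telescoped sum over p in [0, 2^freeBits - 1] of the arrays whose last
// value is exactly X + p: C(n + X + 2^freeBits - 1, n) - C(n + X - 1, n).
int64_t blockCount(const Binomial& bc, int64_t n, int64_t X, int freeBits) {
    int64_t top = X + (int64_t(1) << freeBits) - 1;
    return (bc.C(n + top, n) - bc.C(n + X - 1, n) + MOD) % MOD;
}

// Sum over all prefixes, recomputing x and y for each one (O(log^2 M)).
int64_t countGood(int64_t n, int64_t m, int64_t k, const Binomial& bc) {
    m += 1;  // from now on the condition is A_i < m
    int64_t ans = 0;
    for (int i = 21; i >= 0; --i) {
        if (!((m >> i) & 1)) continue;
        int64_t x = 0, y = 0;  // red / green prefix values
        int redFree = 0, greenFree = 0;
        for (int j = 21; j >= 0; --j) {
            bool red = (k >> j) & 1;
            if (j >= i) {
                int64_t bit = (j > i) ? ((m >> j) & 1) : 0;  // bit i of S is 0
                if (red) x = 2 * x + bit; else y = 2 * y + bit;
            } else if (red) {
                ++redFree;
            } else {
                ++greenFree;
            }
        }
        ans = (ans + blockCount(bc, n, x << redFree, redFree) *
                         blockCount(bc, n, y << greenFree, greenFree)) % MOD;
    }
    return ans;
}

// Brute force straight from the definition, for small inputs only.
int64_t bruteForce(int n, int m, int k) {
    std::vector<int64_t> dp(m + 1, 1);  // dp[v]: good arrays of current length ending in v
    for (int len = 2; len <= n; ++len) {
        std::vector<int64_t> nxt(m + 1, 0);
        for (int a = 0; a <= m; ++a)
            for (int b = 0; b <= m; ++b)
                if ((a & k) <= (b & k) && (a | k) <= (b | k))
                    nxt[b] = (nxt[b] + dp[a]) % MOD;
        dp = nxt;
    }
    int64_t total = 0;
    for (int64_t v : dp) total = (total + v) % MOD;
    return total;
}
```

With a small table such as `Binomial bc(64)`, `countGood(n, m, k, bc)` is meant to agree with `bruteForce(n, m, k)` for all n \leq 3 and m, k \leq 15.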

Some common Mistakes

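Make sure the factorial table is large enough: the binomial coefficients are evaluated at arguments of the form n+x+P, which can be much larger than n alone. The editorialist keeps MaxN = 4 \cdot 10^6 for the precomputation of factorials.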

TIME COMPLEXITY

O(\log^2 M) per test case, after a one-time O(MaxN) precomputation of factorials.
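The inner loop over the prefix is what makes the editorialist's code O(\log^2 M). The setter's solution avoids it by updating the prefix values while scanning the bits of M from high to low, which gives O(\log M) per test case. A sketch of that incremental bookkeeping follows; the name `splitsIncremental` and the `Split` record are hypothetical. For every set bit of M, high to low, it returns the red and green prefix values x, y (with the cut bit cleared) and the number of red and green free bits below the cut.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

struct Split {
    uint64_t x, y;           // red / green prefix values, cut bit cleared
    int redFree, greenFree;  // red / green bits strictly below the cut
};

std::vector<Split> splitsIncremental(uint32_t m, uint32_t k, int bits = 21) {
    std::vector<Split> out;
    int redBelow = 0, greenBelow = 0;
    for (int j = 0; j < bits; ++j) {
        if ((k >> j) & 1u) ++redBelow; else ++greenBelow;
    }
    uint64_t x = 0, y = 0;  // prefix values built from the bits of M seen so far
    for (int j = bits - 1; j >= 0; --j) {
        bool red = (k >> j) & 1u;
        uint64_t bit = (m >> j) & 1u;
        if (red) --redBelow; else --greenBelow;  // counts now cover bits < j only
        if (bit) {
            // Cut here: the prefix so far, followed by a 0 at bit j.
            out.push_back(red ? Split{2 * x, y, redBelow, greenBelow}
                              : Split{x, 2 * y, redBelow, greenBelow});
        }
        if (red) x = 2 * x + bit; else y = 2 * y + bit;  // extend with M's bit
    }
    return out;
}
```

For M = 743 and K = 666 (the worked example) with 10 bits, the fourth record, the cut at bit 5, is x = 3, y = 2 with 3 red and 2 green free bits. Each step does constant work, so all prefixes are produced in O(\log M).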

SOLUTIONS

Setter's Solution
#include <bits/stdc++.h>
using namespace std;

#define nl "\n"
#define nf endl
#define ll long long
#define pb push_back
#define _ << ' ' <<

#define INF (ll)1e18
#define mod 998244353
#define maxn 2097162

ll i, i1, j, k, k1, t, n, m, res, flag[10], a, b;
ll fc[maxn], nv[maxn], c[2], d[2], bb;

ll fxp(ll b, ll e) {
    ll r = 1, k = b;
    while (e != 0) {
        if (e % 2) r = (r * k) % mod;
        k = (k * k) % mod; e /= 2;
    }
    return r;
}

ll inv(ll x) {
    return fxp(x, mod - 2);
}

ll bnm(ll a, ll b) {
    if (a < b || b < 0) return 0;
    ll r = (fc[a] * nv[b]) % mod;
    r = (r * nv[a - b]) % mod;
    return r;
}


int main() {
    ios::sync_with_stdio(0);
    cin.tie(0);

    fc[0] = 1; nv[0] = 1;
    for (i = 1; i < maxn; i++) {
        fc[i] = (i * fc[i - 1]) % mod; nv[i] = inv(fc[i]);
    }

    cin >> t;
    while (t--) {
        cin >> n >> m >> k;
        c[0] = 0; c[1] = 0; d[0] = 1; d[1] = 1; res = 0;
        for (i = 0; i <= 19; i++) d[(k >> i) & 1] *= 2;
        for (i = 19; i >= -1; i--) {
            if (i != -1) {
                c[(k >> i) & 1] *= 2; d[(k >> i) & 1] /= 2;
                if (((m >> i) & 1) == 0) continue;
            }
            res = (res + (bnm((c[0] + 1) * d[0] + n - 1, n) - bnm(c[0] * d[0] + n - 1, n) + mod) *
                    (bnm((c[1] + 1) * d[1] + n - 1, n) - bnm(c[1] * d[1] + n - 1, n) + mod)) % mod;
            // cout << "i, c[0], c[1], d[0], d[1], res =" _ i _ c[0] _ c[1] _ d[0] _ d[1] _ res << nl;
            if (i != -1) c[(k >> i) & 1]++;
        }

        cout << res << nl;
    }

    return 0;
}
Tester's Solution 1
#include <bits/stdc++.h>
using namespace std;

int main() {
    const long long mod = 998244353;
    const int N = (int) 3e6;
    vector<long long> fact(N), inv(N), inv_fact(N);
    fact[0] = inv[0] = inv_fact[0] = 1;
    fact[1] = inv[1] = inv_fact[1] = 1;
    for (int i = 2; i < N; i++) {
        fact[i] = fact[i - 1] * i % mod;
        inv[i] = (mod - inv[mod % i] * (mod / i) % mod) % mod;
        inv_fact[i] = inv_fact[i - 1] * inv[i] % mod;
    }
    auto C = [&](int i, int j) {
        if (j < 0 || i < j) {
            return 0LL;
        }
        return fact[i] * inv_fact[j] % mod * inv_fact[i - j] % mod;
    };
    int tt;
    cin >> tt;
    while (tt--) {
        int n, m, k;
        cin >> n >> m >> k;
        if (m == 0) {
            cout << 1 << endl;
            continue;
        }
        int t = 0;
        for (int i = 0; i < 20; i++) {
            if (m & (1 << i)) {
                t = i;
            }
        }
        auto Get = [&](int low, int high) {
            return (C(n + high, n) - C(n + low - 1, n) + mod) % mod;
        };
        int a = 0, b = 0;
        int x = 0, y = 0;
        for (int i = 0; i <= t; i++) {
            if (k & (1 << i)) {
                x++;
            } else {
                y++;
            }
        }
        long long ans = 0;
        for (int i = t; i >= 0; i--) {
            if (m & (1 << i)) {
                if (k & (1 << i)) {
                    ans += Get(a << x, (a << x) + (1 << (x - 1)) - 1) * Get(b << y, (b << y) + (1 << y) - 1) % mod;
                } else {
                    ans += Get(a << x, (a << x) + (1 << x) - 1) * Get(b << y, (b << y) + (1 << (y - 1)) - 1) % mod;
                }
                ans %= mod;
            }
            if (k & (1 << i)) {
                x--;
                a <<= 1;
                if (m & (1 << i)) {
                    a |= 1;
                }
            } else {
                y--;
                b <<= 1;
                if (m & (1 << i)) {
                    b |= 1;
                }
            }
        }
        ans += Get(a, a) * Get(b, b) % mod;
        ans %= mod;
        cout << ans << endl;
    }
    return 0;
}
Tester's Solution 2
#include "bits/stdc++.h"
// #pragma GCC optimize("O3,unroll-loops")
// #pragma GCC target("avx2,bmi,bmi2,lzcnt,popcnt")
using namespace std;
using ll = long long int;
mt19937_64 rng(chrono::high_resolution_clock::now().time_since_epoch().count());

/**
 * Integers modulo p, where p is a prime
 * Source: Aeren (modified from tourist?)
 *         Modmul for 64-bit mod from kactl:ModMulLL
 * Works with p < 7.2e18 with x87 80-bit long double, and p < 2^52 ~ 4.5e12 with 64-bit
 */
template<typename T>
struct Z_p{
	using Type = typename decay<decltype(T::value)>::type;
	static vector<Type> MOD_INV;
	constexpr Z_p(): value(){ }
	template<typename U> Z_p(const U &x){ value = normalize(x); }
	template<typename U> static Type normalize(const U &x){
		Type v;
		if(-mod() <= x && x < mod()) v = static_cast<Type>(x);
		else v = static_cast<Type>(x % mod());
		if(v < 0) v += mod();
		return v;
	}
	const Type& operator()() const{ return value; }
	template<typename U> explicit operator U() const{ return static_cast<U>(value); }
	constexpr static Type mod(){ return T::value; }
	Z_p &operator+=(const Z_p &otr){ if((value += otr.value) >= mod()) value -= mod(); return *this; }
	Z_p &operator-=(const Z_p &otr){ if((value -= otr.value) < 0) value += mod(); return *this; }
	template<typename U> Z_p &operator+=(const U &otr){ return *this += Z_p(otr); }
	template<typename U> Z_p &operator-=(const U &otr){ return *this -= Z_p(otr); }
	Z_p &operator++(){ return *this += 1; }
	Z_p &operator--(){ return *this -= 1; }
	Z_p operator++(int){ Z_p result(*this); *this += 1; return result; }
	Z_p operator--(int){ Z_p result(*this); *this -= 1; return result; }
	Z_p operator-() const{ return Z_p(-value); }
	template<typename U = T>
	typename enable_if<is_same<typename Z_p<U>::Type, int>::value, Z_p>::type &operator*=(const Z_p& rhs){
		#ifdef _WIN32
		uint64_t x = static_cast<int64_t>(value) * static_cast<int64_t>(rhs.value);
		uint32_t xh = static_cast<uint32_t>(x >> 32), xl = static_cast<uint32_t>(x), d, m;
		asm(
			"divl %4; \n\t"
			: "=a" (d), "=d" (m)
			: "d" (xh), "a" (xl), "r" (mod())
		);
		value = m;
		#else
		value = normalize(static_cast<int64_t>(value) * static_cast<int64_t>(rhs.value));
		#endif
		return *this;
	}
	template<typename U = T>
	typename enable_if<is_same<typename Z_p<U>::Type, int64_t>::value, Z_p>::type &operator*=(const Z_p &rhs){
		uint64_t ret = static_cast<uint64_t>(value) * static_cast<uint64_t>(rhs.value) - static_cast<uint64_t>(mod()) * static_cast<uint64_t>(1.L / static_cast<uint64_t>(mod()) * static_cast<uint64_t>(value) * static_cast<uint64_t>(rhs.value));
		value = normalize(static_cast<int64_t>(ret + static_cast<uint64_t>(mod()) * (ret < 0) - static_cast<uint64_t>(mod()) * (ret >= static_cast<uint64_t>(mod()))));
		return *this;
	}
	template<typename U = T>
	typename enable_if<!is_integral<typename Z_p<U>::Type>::value, Z_p>::type &operator*=(const Z_p &rhs){
		value = normalize(value * rhs.value);
		return *this;
	}
	template<typename U>
	Z_p &operator^=(U e){
		if(e < 0) *this = 1 / *this, e = -e;
		Z_p res = 1;
		for(; e; *this *= *this, e >>= 1) if(e & 1) res *= *this;
		return *this = res;
	}
	template<typename U>
	Z_p operator^(U e) const{
		return Z_p(*this) ^= e;
	}
	Z_p &operator/=(const Z_p &otr){
		Type a = otr.value, m = mod(), u = 0, v = 1;
		if(a < (int)MOD_INV.size()) return *this *= MOD_INV[a];
		while(a){
			Type t = m / a;
			m -= t * a; swap(a, m);
			u -= t * v; swap(u, v);
		}
		assert(m == 1);
		return *this *= u;
	}
	template<typename U> friend const Z_p<U> &abs(const Z_p<U> &v){ return v; }
	Type value;
};
template<typename T> bool operator==(const Z_p<T> &lhs, const Z_p<T> &rhs){ return lhs.value == rhs.value; }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> bool operator==(const Z_p<T>& lhs, U rhs){ return lhs == Z_p<T>(rhs); }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> bool operator==(U lhs, const Z_p<T> &rhs){ return Z_p<T>(lhs) == rhs; }
template<typename T> bool operator!=(const Z_p<T> &lhs, const Z_p<T> &rhs){ return !(lhs == rhs); }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> bool operator!=(const Z_p<T> &lhs, U rhs){ return !(lhs == rhs); }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> bool operator!=(U lhs, const Z_p<T> &rhs){ return !(lhs == rhs); }
template<typename T> bool operator<(const Z_p<T> &lhs, const Z_p<T> &rhs){ return lhs.value < rhs.value; }
template<typename T> bool operator>(const Z_p<T> &lhs, const Z_p<T> &rhs){ return lhs.value > rhs.value; }
template<typename T> bool operator<=(const Z_p<T> &lhs, const Z_p<T> &rhs){ return lhs.value <= rhs.value; }
template<typename T> bool operator>=(const Z_p<T> &lhs, const Z_p<T> &rhs){ return lhs.value >= rhs.value; }
template<typename T> Z_p<T> operator+(const Z_p<T> &lhs, const Z_p<T> &rhs){ return Z_p<T>(lhs) += rhs; }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> Z_p<T> operator+(const Z_p<T> &lhs, U rhs){ return Z_p<T>(lhs) += rhs; }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> Z_p<T> operator+(U lhs, const Z_p<T> &rhs){ return Z_p<T>(lhs) += rhs; }
template<typename T> Z_p<T> operator-(const Z_p<T> &lhs, const Z_p<T> &rhs){ return Z_p<T>(lhs) -= rhs; }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> Z_p<T> operator-(const Z_p<T>& lhs, U rhs){ return Z_p<T>(lhs) -= rhs; }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> Z_p<T> operator-(U lhs, const Z_p<T> &rhs){ return Z_p<T>(lhs) -= rhs; }
template<typename T> Z_p<T> operator*(const Z_p<T> &lhs, const Z_p<T> &rhs){ return Z_p<T>(lhs) *= rhs; }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> Z_p<T> operator*(const Z_p<T>& lhs, U rhs){ return Z_p<T>(lhs) *= rhs; }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> Z_p<T> operator*(U lhs, const Z_p<T> &rhs){ return Z_p<T>(lhs) *= rhs; }
template<typename T> Z_p<T> operator/(const Z_p<T> &lhs, const Z_p<T> &rhs) { return Z_p<T>(lhs) /= rhs; }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> Z_p<T> operator/(const Z_p<T>& lhs, U rhs) { return Z_p<T>(lhs) /= rhs; }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> Z_p<T> operator/(U lhs, const Z_p<T> &rhs) { return Z_p<T>(lhs) /= rhs; }
template<typename T> istream &operator>>(istream &in, Z_p<T> &number){
	typename common_type<typename Z_p<T>::Type, int64_t>::type x;
	in >> x;
	number.value = Z_p<T>::normalize(x);
	return in;
}
template<typename T> ostream &operator<<(ostream &out, const Z_p<T> &number){ return out << number(); }

/*
using ModType = int;
struct VarMod{ static ModType value; };
ModType VarMod::value;
ModType &mod = VarMod::value;
using Zp = Z_p<VarMod>;
*/

// constexpr int mod = 1e9 + 7; // 1000000007
constexpr int mod = (119 << 23) + 1; // 998244353
// constexpr int mod = 1e9 + 9; // 1000000009
using Zp = Z_p<integral_constant<decay<decltype(mod)>::type, mod>>;

template<typename T> vector<typename Z_p<T>::Type> Z_p<T>::MOD_INV;
template<typename T = integral_constant<decay<decltype(mod)>::type, mod>>
void precalc_inverse(int SZ){
	auto &inv = Z_p<T>::MOD_INV;
	if(inv.empty()) inv.assign(2, 1);
	for(; inv.size() <= SZ; ) inv.push_back((mod - 1LL * mod / (int)inv.size() * inv[mod % (int)inv.size()]) % mod);
}

template<typename T>
vector<T> precalc_power(T base, int SZ){
	vector<T> res(SZ + 1, 1);
	for(auto i = 1; i <= SZ; ++ i) res[i] = res[i - 1] * base;
	return res;
}

template<typename T>
vector<T> precalc_factorial(int SZ){
	vector<T> res(SZ + 1, 1); res[0] = 1;
	for(auto i = 1; i <= SZ; ++ i) res[i] = res[i - 1] * i;
	return res;
}

int main()
{
	ios::sync_with_stdio(false); cin.tie(0);

	const int MX = 3e6 + 100;
	auto fac = precalc_factorial<Zp>(MX), invf = fac;
	for (auto &x : invf) x = Zp(1)/x;
	auto C = [&] (int n, int r) {
		if (n < 0 or r < 0 or n < r) return Zp(0);
		return fac[n] * invf[r] * invf[n-r];
	};
	int t; cin >> t;
	while (t--) {
		int n, m, k; cin >> n >> m >> k;
		Zp ans = 0;
		int p1 = 1 << __builtin_popcount(k), p2 = 1 << (20 - __builtin_popcount(k));
		int s1 = 0, s2 = 0;
		for (int bit = 19; bit >= 0; --bit) {
			int fl1 = (k >> bit) & 1, fl2 = (m >> bit) & 1;
			if (fl2) {
				// cerr << bit << ": " << p1 << ' ' << p2 << ' ';
				Zp x1 = C(n - 1 + p1*s1 + (1 + !fl1)*p1/2, n) - C(n - 1 + p1*s1, n);
				Zp x2 = C(n - 1 + p2*s2 + (1 + fl1)*p2/2, n) - C(n - 1 + p2*s2, n);
				ans += x1*x2;
				// ans += (C(n - 1 + s1*p1, n) - C(n - 1 + p1*s1 + p1/2*(1 + !fl1), n)) * (C(n - 1 + s2*p2, n) - C(n - 1 + p2*s2 + p2/2*(1 + fl1), n));
			}
			p1 >>= fl1;
			p2 >>= !fl1;
			s1 = (s1 << fl1) + (fl1 && fl2);
			s2 = (s2 << !fl1) + (!fl1 && fl2);
			// cerr << bit << ": " << p1 << ' ' << p2 << ' ' << s1 << ' ' << s2 << '\n';
		}
		ans += C(n+s1-1, n-1)*C(n+s2-1, n-1);
		cout << ans << '\n';
	}
}
Editorialist's Solution
//Utkarsh.25dec
#include <bits/stdc++.h>
#define ll long long int
#define pb push_back
#define mp make_pair
#define mod 998244353 
#define vl vector <ll>
#define all(c) (c).begin(),(c).end()
using namespace std;
ll power(ll a,ll b) {ll res=1;a%=mod; assert(b>=0); for(;b;b>>=1){if(b&1)res=res*a%mod;a=a*a%mod;}return res;}
ll modInverse(ll a){return power(a,mod-2);}
const int N=4000023;
ll fact[N];
ll invfact[N];
ll inv[N];
void factorialsComputation()
{
    inv[0]=inv[1]=1;
    fact[0]=fact[1]=1;
    invfact[0]=invfact[1]=1;
    for(int i=2;i<N;i++)
    {
        inv[i]=(inv[mod%i]*(mod-mod/i))%mod;
        fact[i]=(fact[i-1]*i)%mod;
        invfact[i]=(invfact[i-1]*inv[i])%mod;
    }
}
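ll ncr(ll n,ll r)
{
    // Out-of-range arguments (e.g. ncr(n-1,n) for the first prefix, where x=0)
    // must return 0; otherwise invfact[n-r] is read at a negative index.
    if(n<0||r<0||r>n)
        return 0;
    ll ans=fact[n]*invfact[r];
    ans%=mod;
    ans*=invfact[n-r];
    ans%=mod;
    return ans;
}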
void solve()
{
    ll n,m,k;
    cin>>n>>m>>k;
    m++;
    int high=0;
    for(int i=22;i>=0;i--)
    {
    	if((m&(1<<i))!=0)
    	{
    		high=i;
    		break;
    	}
    }
    ll ans=0;
    for(int i=high;i>=0;i--)
    {
    	if((m&(1<<i))==0)
    		continue;
    	ll x=0,y=0;
    	for(int j=high;j>=i;j--)
    	{
    		if(j==i)
    		{
    			if((k&(1<<j))==0)
    				x*=2;
    			else
    				y*=2;
    			continue;
    		}
    		if((k&(1<<j))==0)
    		{
    			x*=2;
    			if((m&(1<<j))!=0)
    				x++;
    		}
    		else
    		{
    			y*=2;
    			if((m&(1<<j))!=0)
    				y++;
    		}
    	}
    	ll p=0,q=0;
    	for(int j=i-1;j>=0;j--)
    	{
    		if((k&(1<<j))==0)
    		{
    			p*=2;
    			x*=2;
    			p++;
    		}
    		else
    		{
    			q*=2;
    			y*=2;
    			q++;
    		}
    	}
    	ll tmp=(ncr(n+x+p,n)+mod-ncr(n+x-1,n))*(ncr(n+y+q,n)+mod-ncr(n+y-1,n));
    	tmp%=mod;
    	ans+=tmp;
    	ans%=mod;
    }
    cout<<ans<<'\n';
}
int main()
{
    #ifndef ONLINE_JUDGE
    freopen("input.txt", "r", stdin);
    freopen("output.txt", "w", stdout);
    #endif
    ios_base::sync_with_stdio(false);
    cin.tie(NULL),cout.tie(NULL);
    factorialsComputation();
    int T=1;
    cin>>T;
    while(T--)
        solve();
    cerr << "Time : " << 1000 * ((double)clock()) / (double)CLOCKS_PER_SEC << "ms\n";
}

Feel free to share your approach. Suggestions are welcome as always.