COUNTARR343 - Editorial


Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: satyam_343
Tester: tabr
Editorialist: iceknight1093




Prerequisites: Elementary combinatorics


An array B is called good if all of its elements can be made zero by repeating the following operation some number of times:

  • Pick two elements x and y of B (at different positions), and replace x with (x\mid y) - x and y with (x\mid y) - y.

Count the number of good arrays of length N whose elements are between 0 and 2^{K} - 1 (inclusive), modulo 998244353.


First, let’s understand what the given operation really does, and when an array is good.
Consider two non-negative integers x and y.
Let x' = (x\mid y) - x and y' = (x\mid y) - y be their values after the operation.

Note that (x\mid y) is a supermask of both x and y, so subtracting x from it simply clears the bits of x, which is exactly what XOR-ing with x does.
That is, x' = (x\mid y)\oplus x and y' = (x\mid y)\oplus y.

Since we’re dealing only with bitwise operations, let’s look at what happens to a fixed bit b.

  • If b is set in neither x nor y, it won’t be set in x' or y' either.
  • If b is set in both x and y, it will be unset in both x' and y'.
    This is because b is set in (x\mid y), so XOR-ing with x or with y cancels it out.
  • If b is set in x but not in y, it will be set in y' but not in x'.
    Similarly, if b is set in y but not x, it will be set in x' but not y'.
    Essentially, b moves from x to y (or vice versa).
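To sanity-check the case analysis, here is a small brute-force sketch in Python (the helper names `apply_op` and `apply_op_xor` are illustrative, not part of the problem). It checks, over all pairs of 4-bit values, that the subtraction form and the XOR form of the operation agree, and that the result follows the per-bit rules: x' consists of the bits of y that are not in x, and y' of the bits of x that are not in y.

```python
def apply_op(x, y):
    """The operation from the statement: x -> (x|y) - x, y -> (x|y) - y."""
    o = x | y
    return o - x, o - y


def apply_op_xor(x, y):
    """The same operation written with XOR (valid since x|y is a supermask of x and y)."""
    o = x | y
    return o ^ x, o ^ y


# Exhaustive check over all pairs of 4-bit values.
for x in range(16):
    for y in range(16):
        nx, ny = apply_op(x, y)
        assert (nx, ny) == apply_op_xor(x, y)
        # Shared bits vanish; a bit set in only one of x, y moves to the other value.
        assert nx == y & ~x
        assert ny == x & ~y

# 6 = 110b and 3 = 011b share bit 1 (cleared); bit 2 moves to y', bit 0 moves to x'.
print(apply_op(6, 3))  # (1, 4)
```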

Our aim is to make every bit 0, since only then will every number be 0.
Notice that, for each bit, the operation either moves it from one element to the other or clears it from two elements at once; it never creates new set bits.
So, for an array A and bit b:

  • Let s_b denote the number of elements of A that have b set.
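  • If s_b is odd, the array can never be good: each operation changes s_b by either 0 or -2, so the parity of s_b never changes and s_b can never reach 0.
  • If every s_b is even, all set bits can be removed two at a time. While the array is non-zero, some s_b is positive and even, hence at least 2, so two elements share a set bit.
    Applying the operation to those two elements clears their shared bits, strictly decreasing the total number of set bits while keeping every s_b even, so the array eventually becomes all zeros.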
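The constructive half of this argument can be written as a short greedy procedure. The Python sketch below (`reduce_to_zero` is a hypothetical helper name) repeatedly applies the operation to the first pair of elements that share a set bit. It returns the list of index pairs used, or `None` if it gets stuck; by the parity argument, getting stuck happens exactly when some bit is set in an odd number of elements.

```python
def reduce_to_zero(arr):
    """Greedy sketch: while the array is non-zero, apply the operation to two
    elements that share a set bit. Returns the list of (i, j) index pairs used,
    or None if no such pair exists while the array is still non-zero."""
    a = list(arr)
    moves = []
    while any(a):
        pair = None
        # Find the first pair of distinct indices whose values share a set bit.
        for i in range(len(a)):
            for j in range(i + 1, len(a)):
                if a[i] & a[j]:
                    pair = (i, j)
                    break
            if pair is not None:
                break
        if pair is None:
            return None  # every set bit appears in exactly one element: some count is odd
        i, j = pair
        o = a[i] | a[j]
        # Shared bits are cleared, so the total number of set bits strictly decreases.
        a[i], a[j] = o - a[i], o - a[j]
        moves.append(pair)
    return moves


print(reduce_to_zero([1, 2, 3]))  # [(0, 2), (0, 1)]
print(reduce_to_zero([1, 0, 0]))  # None
```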

So, the array A is good if and only if, for each bit b, an even number of elements have b set.
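Equivalently, A is good exactly when the XOR of all its elements is 0. For tiny sizes this characterization can be cross-checked against an exhaustive search. The Python sketch below (helper names are illustrative) runs a BFS over every array reachable from a starting array and compares the result with the parity criterion for all arrays with N = 3 and K = 2.

```python
from collections import deque
from functools import reduce
from itertools import product
from operator import xor


def is_good_bruteforce(arr):
    """BFS over all arrays reachable from arr; the state space is finite
    because the operation never creates new set bits."""
    start = tuple(arr)
    seen = {start}
    queue = deque([start])
    while queue:
        cur = queue.popleft()
        if not any(cur):
            return True
        # The operation is symmetric in x and y, so unordered pairs i < j suffice.
        for i in range(len(cur)):
            for j in range(i + 1, len(cur)):
                o = cur[i] | cur[j]
                nxt = list(cur)
                nxt[i], nxt[j] = o - cur[i], o - cur[j]
                nxt = tuple(nxt)
                if nxt not in seen:
                    seen.add(nxt)
                    queue.append(nxt)
    return False


def is_good_parity(arr):
    """Editorial criterion: every bit is set in an even number of elements (XOR is 0)."""
    return reduce(xor, arr, 0) == 0


N, K = 3, 2
for arr in product(range(1 << K), repeat=N):
    assert is_good_bruteforce(arr) == is_good_parity(arr)
```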
Since we want 0 \leq A_i \lt 2^K, we have K bits to work with.

For each of these K bits, the elements that have it set can be chosen independently.
For a fixed bit b, we need to choose an even number of the N elements to have this bit set.
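It’s well-known that for an N-element set with N \geq 1, the number of subsets of even size is 2^{N-1}: fix one element and choose any subset of the other N-1 elements; then exactly one of the two choices for the fixed element (include it or not) makes the size even.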
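The identity is easy to confirm numerically. The sketch below sums the binomial coefficients \binom{N}{s} over even s using Python's `math.comb` (available since Python 3.8); it also shows why N \geq 1 is needed, since for N = 0 the count is 1.

```python
from math import comb


def count_even_subsets(n):
    """Number of even-size subsets of an n-element set: sum of C(n, s) over even s."""
    return sum(comb(n, s) for s in range(0, n + 1, 2))


for n in range(1, 21):
    assert count_even_subsets(n) == 2 ** (n - 1)

print(count_even_subsets(0))  # 1, so the formula 2^(N-1) only applies for N >= 1
```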

So, we have 2^{N-1} options for each bit, independent of the other bits.
With K bits in total, this gives (2^{N-1})^K = 2^{(N-1)K} good arrays in total.
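For small N and K the closed form can be checked against direct enumeration. The sketch below (the function name is illustrative) counts arrays in [0, 2^K)^N whose elements XOR to 0, which is the goodness criterion derived above, and compares the count with (2^{N-1})^K.

```python
from functools import reduce
from itertools import product
from operator import xor


def count_good_bruteforce(n, k):
    """Count arrays of length n with values in [0, 2^k) whose XOR is 0."""
    return sum(
        1 for arr in product(range(1 << k), repeat=n) if reduce(xor, arr, 0) == 0
    )


for n in range(1, 5):
    for k in range(0, 4):
        assert count_good_bruteforce(n, k) == (2 ** (n - 1)) ** k
```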

This value, taken modulo 998244353, can be computed in \mathcal{O}(\log N + \log K) using binary exponentiation.
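A minimal square-and-multiply sketch in Python, shown only to make the loop explicit; Python's built-in three-argument `pow`, which the editorialist's code uses, computes the same value. Since (2^{N-1})^K = 2^{(N-1)K}, a single exponentiation with exponent (N-1)K suffices, and it takes \mathcal{O}(\log((N-1)K)) steps. The function names here are illustrative.

```python
MOD = 998244353


def binpow(base, exp, mod=MOD):
    """Compute base**exp % mod with O(log exp) multiplications (square-and-multiply)."""
    result = 1
    base %= mod
    while exp > 0:
        if exp & 1:  # lowest remaining bit of the exponent is set
            result = result * base % mod
        base = base * base % mod  # after t rounds, base holds the original base^(2^t)
        exp >>= 1
    return result


def count_good_arrays(n, k):
    """(2^(n-1))^k modulo 998244353, computed as 2^((n-1)k)."""
    return binpow(2, (n - 1) * k)


# Agrees with Python's built-in pow on the largest inputs the tester's checker allows.
assert count_good_arrays(10**6, 10**6) == pow(2, (10**6 - 1) * 10**6, MOD)
print(count_good_arrays(3, 2))  # 16
```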


Time complexity: \mathcal{O}(\log N + \log K) per testcase.


Author's code (C++)
#pragma GCC optimize("O3,unroll-loops")
#include <bits/stdc++.h>
using namespace std;
#define ll long long
#define nline "\n"
const ll MOD=998244353;

// Computes a^b mod MOD with binary exponentiation.
ll binpow(ll a,ll b,ll MOD){
    ll ans=1;
    a%=MOD;
    while(b){
        if(b&1)
            ans=(ans*a)%MOD;
        a=(a*a)%MOD;
        b>>=1;
    }
    return ans;
}
// Modular inverse via Fermat's little theorem (MOD is prime); not needed for this problem.
ll inverse(ll a,ll MOD){
    return binpow(a,MOD-2,MOD);
}
void solve(){
    ll n,k; cin>>n>>k;
    // 2^(n-1) choices for each bit and k independent bits: answer is (2^(n-1))^k.
    ll ans=binpow(2,n-1,MOD);
    ans=binpow(ans,k,MOD);
    cout<<ans<<nline;
}
int main()
{
    ios_base::sync_with_stdio(false);
    cin.tie(NULL);
    #ifndef ONLINE_JUDGE
    freopen("input.txt", "r", stdin);
    freopen("output.txt", "w", stdout);
    freopen("error.txt", "w", stderr);
    #endif
    ll test_cases=1;
    cin>>test_cases;
    while(test_cases--){
        solve();
    }
}

Tester's code (C++)
#include <bits/stdc++.h>
using namespace std;
#ifdef tabr
#include "library/debug.cpp"
#else
#define debug(...)
#endif

#define IGNORE_CR

struct input_checker {
    string buffer;
    int pos;

    const string all = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
    const string number = "0123456789";
    const string upper = "ABCDEFGHIJKLMNOPQRSTUVWXYZ";
    const string lower = "abcdefghijklmnopqrstuvwxyz";

    // Reads the whole input into buffer, dropping '\r' when IGNORE_CR is defined.
    input_checker() {
        pos = 0;
        while (true) {
            int c = cin.get();
            if (c == -1) {
                break;
            }
#ifdef IGNORE_CR
            if (c == '\r') {
                continue;
            }
#endif
            buffer.push_back((char) c);
        }
    }

    // Reads one token; stops at a space or newline without consuming it.
    string readOne() {
        assert(pos < (int) buffer.size());
        string res;
        while (pos < (int) buffer.size() && buffer[pos] != ' ' && buffer[pos] != '\n') {
            res += buffer[pos];
            pos++;
        }
        return res;
    }

    string readString(int min_len, int max_len, const string& pattern = "") {
        assert(min_len <= max_len);
        string res = readOne();
        assert(min_len <= (int) res.size());
        assert((int) res.size() <= max_len);
        for (int i = 0; i < (int) res.size(); i++) {
            assert(pattern.empty() || pattern.find(res[i]) != string::npos);
        }
        return res;
    }

    int readInt(int min_val, int max_val) {
        assert(min_val <= max_val);
        int res = stoi(readOne());
        assert(min_val <= res);
        assert(res <= max_val);
        return res;
    }

    long long readLong(long long min_val, long long max_val) {
        assert(min_val <= max_val);
        long long res = stoll(readOne());
        assert(min_val <= res);
        assert(res <= max_val);
        return res;
    }

    vector<int> readInts(int size, int min_val, int max_val) {
        assert(min_val <= max_val);
        vector<int> res(size);
        for (int i = 0; i < size; i++) {
            res[i] = readInt(min_val, max_val);
            if (i != size - 1) {
                readSpace();
            }
        }
        return res;
    }

    vector<long long> readLongs(int size, long long min_val, long long max_val) {
        assert(min_val <= max_val);
        vector<long long> res(size);
        for (int i = 0; i < size; i++) {
            res[i] = readLong(min_val, max_val);
            if (i != size - 1) {
                readSpace();
            }
        }
        return res;
    }

    void readSpace() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == ' ');
        pos++;
    }

    void readEoln() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == '\n');
        pos++;
    }

    void readEof() {
        assert((int) buffer.size() == pos);
    }
};

template <long long mod>
struct modular {
    long long value;
    modular(long long x = 0) {
        value = x % mod;
        if (value < 0) value += mod;
    }
    modular& operator+=(const modular& other) {
        if ((value += other.value) >= mod) value -= mod;
        return *this;
    }
    modular& operator-=(const modular& other) {
        if ((value -= other.value) < 0) value += mod;
        return *this;
    }
    modular& operator*=(const modular& other) {
        value = value * other.value % mod;
        return *this;
    }
    // Division via the modular inverse from the extended Euclidean algorithm.
    modular& operator/=(const modular& other) {
        long long a = 0, b = 1, c = other.value, m = mod;
        while (c != 0) {
            long long t = m / c;
            m -= t * c;
            swap(c, m);
            a -= t * b;
            swap(a, b);
        }
        a %= mod;
        if (a < 0) a += mod;
        value = value * a % mod;
        return *this;
    }
    friend modular operator+(const modular& lhs, const modular& rhs) { return modular(lhs) += rhs; }
    friend modular operator-(const modular& lhs, const modular& rhs) { return modular(lhs) -= rhs; }
    friend modular operator*(const modular& lhs, const modular& rhs) { return modular(lhs) *= rhs; }
    friend modular operator/(const modular& lhs, const modular& rhs) { return modular(lhs) /= rhs; }
    modular& operator++() { return *this += 1; }
    modular& operator--() { return *this -= 1; }
    modular operator++(int) {
        modular res(*this);
        *this += 1;
        return res;
    }
    modular operator--(int) {
        modular res(*this);
        *this -= 1;
        return res;
    }
    modular operator-() const { return modular(-value); }
    bool operator==(const modular& rhs) const { return value == rhs.value; }
    bool operator!=(const modular& rhs) const { return value != rhs.value; }
    bool operator<(const modular& rhs) const { return value < rhs.value; }
};
template <long long mod>
string to_string(const modular<mod>& x) {
    return to_string(x.value);
}
template <long long mod>
ostream& operator<<(ostream& stream, const modular<mod>& x) {
    return stream << x.value;
}
template <long long mod>
istream& operator>>(istream& stream, modular<mod>& x) {
    stream >> x.value;
    x.value %= mod;
    if (x.value < 0) x.value += mod;
    return stream;
}

constexpr long long mod = 998244353;
using mint = modular<mod>;

// Binary exponentiation: a^n in O(log n) multiplications.
mint power(mint a, long long n) {
    mint res = 1;
    while (n > 0) {
        if (n & 1) {
            res *= a;
        }
        a *= a;
        n >>= 1;
    }
    return res;
}

vector<mint> fact(1, 1);
vector<mint> finv(1, 1);

mint C(int n, int k) {
    if (n < k || k < 0) {
        return mint(0);
    }
    while ((int) fact.size() < n + 1) {
        fact.emplace_back(fact.back() * (int) fact.size());
        finv.emplace_back(mint(1) / fact.back());
    }
    return fact[n] * finv[k] * finv[n - k];
}

int main() {
    input_checker in;
    int tt = in.readInt(1, 1e5);
    in.readEoln();
    while (tt--) {
        int n = in.readInt(2, 1e6);
        in.readSpace();
        int k = in.readInt(0, 1e6);
        in.readEoln();
        // (2^(n-1))^k = 2^(k(n-1)); the exponent can reach about 1e12, hence 1LL.
        cout << power(2, k * 1LL * (n - 1)) << '\n';
    }
    cerr << in.pos << " " << in.buffer.size() << endl;
    return 0;
}

Editorialist's code (Python)
mod = 998244353
for _ in range(int(input())):
    n, k = map(int, input().split())
    print(pow(2, (n-1)*k, mod))