SPLITXORNOTX - Editorial

PROBLEM LINK:

Contest Division 1
Contest Division 2
Contest Division 3
Contest Division 4

Setter: Anton Trygub
Tester: Harris Leung
Editorialist: Trung Dang

DIFFICULTY:

3435

PREREQUISITES:

XOR, Probabilities

PROBLEM:

For an array A and a number X, let’s define f(A, X) as follows:

  • If it is not possible to split A into several consecutive subarrays in such a way, that the XOR of all elements in each subarray is not equal to X, f(A, X) = 0.
  • Otherwise, f(A, X) is equal to the largest possible number of subarrays in such a split.

You are given integers N, K, and X, where 0 \leq X \lt 2^K. Consider array A of length N, where each element is an integer generated uniformly from 0 to 2^K - 1. Find the expected value of f(A, X).

EXPLANATION:

If X = 0, then the answer is just the expected number of non-zero elements (i.e. N \cdot (1 - \frac{1}{2^K})). We now assume that X > 0.

Let’s first see how an array can be split optimally. Call a value good if it is neither 0 nor X. Excluding the case where the whole array consists of only 0's and X's, the array consists of some leading 0's and X's, then good values with runs of 0's and X's between them, and finally some trailing 0's and X's. We note the following:

  • Each good element adds 1 to the answer.
  • Consider two consecutive good elements with L values, each either 0 or X, between them. Take the prefix XORs of these L values, P = [P_0 = 0, P_1, P_2, \dots, P_L]. Every piece lying strictly between the two good elements must have XOR 0, so all cuts inside the gap fall at positions that share one prefix value. The contribution of the gap to the answer is therefore \max(occ_0, occ_X) - 1, where occ_t is the number of times t appears in P.
  • Consider the first B elements, each either 0 or X, that precede the first good element. We again take their prefix XORs; the contribution is occ_0, counted among P_1, \dots, P_B.
  • Similarly, the contribution of the trailing elements is occ_0 of their suffix XORs.
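The decomposition above can be checked against a direct computation on small arrays. The following is a minimal brute-force sketch, not part of any of the solutions below; the name `brute_f` and the quadratic DP are assumptions made for illustration.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Brute-force f(A, X): the largest number of consecutive pieces such that
// no piece has XOR equal to X, or 0 if no valid split exists.
// best[i] = max number of pieces for the prefix A[0..i), or -1 if impossible.
int brute_f(const std::vector<int>& a, int x) {
    const int n = static_cast<int>(a.size());
    assert(n >= 1);
    std::vector<int> best(n + 1, -1);
    best[0] = 0;  // the empty prefix is split into zero pieces
    for (int i = 1; i <= n; ++i) {
        int piece_xor = 0;
        for (int j = i - 1; j >= 0; --j) {  // the last piece is A[j..i)
            piece_xor ^= a[j];
            if (piece_xor != x && best[j] >= 0) {
                best[i] = std::max(best[i], best[j] + 1);
            }
        }
    }
    return std::max(best[n], 0);  // map "impossible" to 0
}
```

For example, with X = 1 the array [2, 0, 1, 3] has two good elements with the values 0, 1 between them, so P = [0, 0, 1] and the answer is 2 + (\max(2, 1) - 1) = 3, which is what `brute_f({2, 0, 1, 3}, 1)` returns.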

We now simply calculate the expectation of each of these contributions and sum them up (this is valid by linearity of expectation). We use the fact that the prefix (and suffix) XORs of independent random values from \{0, X\} behave like independent random values themselves, so for the second, third, and fourth parts, we simply calculate the (max) number of occurrences of 0 (and X) in a random array of length L.
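The uniformity claim can be checked in one line. As a sketch, assume a_1, a_2, \dots are independent and each is 0 or X with probability \frac{1}{2} (the names a_i and t are introduced here for illustration only), and let P_i = P_{i-1} \oplus a_i:

```latex
\Pr[P_i = t \mid P_1, \dots, P_{i-1}]
  = \Pr[a_i = t \oplus P_{i-1}]
  = \frac{1}{2}
  \qquad \text{for } t \in \{0, X\}.
```

So every P_i with i \geq 1 is uniform on \{0, X\} and independent of the earlier prefix values, whatever the fixed starting value P_0 is.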

  • The expected number of good elements is N \cdot (1 - \frac{2}{2^K}).
  • For this part, we split the subproblem further. Here L denotes the distance between the two good elements, so L - 1 values lie between them and the prefix XOR array has L entries. For each L, we calculate the expectation of \max(occ_0, occ_X) - 1 and multiply it by the expected number of such segments in the array. The counting part is simple: we need L - 1 consecutive values, each either 0 or X, bounded by two good values, and there are N - L possible positions. Therefore, the expected number of such segments is (N - L) \cdot (1 - \frac{2}{2^K})^2 \cdot (\frac{2}{2^K})^{L - 1}. The expected max part is more involved:
    • We use dynamic programming to calculate the expected \max(occ_0, occ_X) - 1. Let f_L be this expectation; then we only need the probability that the L-th element increases the maximum (this probability is added to f_{L - 1} to give f_L). That happens exactly when at least \frac{L - 1}{2} of the first L - 1 values are equal to the L-th element. This probability is \frac{1}{2} when L - 1 is odd, or \frac{1}{2} + \frac{\binom{L - 1}{(L - 1)/2}}{2^L} when L - 1 is even.
  • The third and fourth points are trivial compared to the second: the expected number of occurrences of 0 in a random array of length L is \frac{L}{2}.
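The recurrence for the expected maximum can be sketched in floating point and compared with direct enumeration. This is an illustration under stated assumptions rather than the reference implementation: the names `expected_max_table` and `brute_expected_max` are hypothetical, doubles replace the modular arithmetic of the real solutions, and the table stores E[\max(occ_0, occ_X)] itself (subtract 1 to get f_L).

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// expected_max[len] = E[max(occ_0, occ_X)] over len independent fair draws from {0, X}.
std::vector<double> expected_max_table(int max_len) {
    assert(max_len >= 0);
    std::vector<double> expected_max(max_len + 1, 0.0);
    double central = 1.0;  // C(m, m/2) / 2^m for the current even m, starting at m = 0
    for (int len = 1; len <= max_len; ++len) {
        const int m = len - 1;  // number of earlier draws
        double gain;            // probability that draw number len raises the maximum
        if (m % 2 == 1) {
            gain = 0.5;  // no tie is possible: the new draw must match the leader
        } else {
            gain = 0.5 + central / 2.0;  // a tie (probability `central`) always raises it
            central *= (m + 1.0) / (m + 2.0);  // move from m to m + 2
        }
        expected_max[len] = expected_max[len - 1] + gain;
    }
    return expected_max;
}

// Direct enumeration over all 2^len outcomes, intended for small len only.
double brute_expected_max(int len) {
    assert(len >= 0 && len <= 20);
    double total = 0.0;
    for (unsigned mask = 0; mask < (1u << len); ++mask) {
        int ones = 0;
        for (int b = 0; b < len; ++b) ones += static_cast<int>((mask >> b) & 1u);
        total += std::max(ones, len - ones);
    }
    return total / static_cast<double>(1u << len);
}
```

For instance, `expected_max_table(4)` holds 1, 1.5, 2.25, 2.75 at indices 1 to 4, matching the enumeration.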

Finally, there is the case where the array consists of only 0's and X's. Handling this case is similar to handling the third and fourth cases (do keep in mind that the last element of the prefix XOR array has to be 0 here, otherwise no valid split exists).
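Putting the four parts together, the whole sum can be sketched in doubles for the case X > 0. This mirrors the structure of the editorialist's solution below but is only an illustration: the function name `expected_f` is hypothetical, and floating point is used instead of arithmetic modulo 998244353. The value of X does not appear as a parameter because, for X > 0, the result depends only on N and K.

```cpp
#include <cassert>
#include <cmath>

// Expected f(A, X) for X > 0, computed in doubles.
double expected_f(int n, int k) {
    assert(n >= 1 && k >= 1);
    const double q = 2.0 / std::pow(2.0, k);  // P(element is 0 or X)
    const double g = 1.0 - q;                 // P(element is good)

    double ans = g * n;  // every good element contributes one piece

    double exp_max = 0.0;  // E[max(occ_0, occ_X)] over i fair draws
    double central = 1.0;  // C(m, m/2) / 2^m for even m
    double q_pow = 1.0;    // q^(i - 1)
    for (int i = 1; i <= n; ++i) {
        const int m = i - 1;
        if (m % 2 == 1) {
            exp_max += 0.5;
        } else {
            exp_max += 0.5 + central / 2.0;
            central *= (m + 1.0) / (m + 2.0);
        }
        // middle gaps: two good elements i apart, i - 1 values from {0, X} between them
        ans += (n - i) * g * g * q_pow * (exp_max - 1.0);
        // leading and trailing runs of i - 1 values: 2 * (i - 1) / 2 expected pieces
        ans += g * q_pow * (i - 1);
        q_pow *= q;
    }
    // whole array drawn from {0, X}: the total XOR must be 0 (probability 1/2)
    ans += q_pow * (1.0 + (n - 1) / 2.0) / 2.0;
    return ans;
}
```

As a sanity check, for N = 2, K = 2, X = 1 enumerating all 16 arrays by hand gives a total of 23, and `expected_f(2, 2)` evaluates to 23/16 = 1.4375.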

TIME COMPLEXITY:

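Time complexity is O(N) when factorials and powers are precomputed. Solutions that call modular exponentiation or modular inverse inside the loop spend an extra O(\log MOD) factor.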

SOLUTION:

Setter's Solution
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>
#include <bits/stdc++.h>
#pragma GCC target ("avx2")
#pragma GCC optimize ("O3")
#pragma GCC optimize ("unroll-loops")

using namespace __gnu_pbds;
using namespace std;

using ll = long long;
using ld = long double;

typedef tree<
        pair<int, int>,
        null_type,
        less<pair<int, int>>,
        rb_tree_tag,
        tree_order_statistics_node_update>
        ordered_set;

#define mp make_pair

int MOD =  998244353;

int mul(int a, int b) {
    return (1LL * a * b) % MOD;
}

int add(int a, int b) {
    int s = (a+b);
    if (s>=MOD) s-=MOD;
    return s;
}

int sub(int a, int b) {
    int s = (a+MOD-b);
    if (s>=MOD) s-=MOD;
    return s;
}

int po(int a, ll deg)
{
    if (deg==0) return 1;
    if (deg%2==1) return mul(a, po(a, deg-1));
    int t = po(a, deg/2);
    return mul(t, t);
}

int inv(int n)
{
    return po(n, MOD-2);
}


mt19937 rnd(time(0));


const int LIM = 1000005;

vector<int> facs(LIM), invfacs(LIM), invs(LIM);

void init()
{
    facs[0] = 1;
    for (int i = 1; i<LIM; i++) facs[i] = mul(facs[i-1], i);
    invfacs[LIM-1] = inv(facs[LIM-1]);
    for (int i = LIM-2; i>=0; i--) invfacs[i] = mul(invfacs[i+1], i+1);

    for (int i = 1; i<LIM; i++) invs[i] = mul(invfacs[i], facs[i-1]);

}

int C(int n, int k)
{
    if (n<k) return 0;
    if (n<0 || k<0) return 0;
    return mul(facs[n], mul(invfacs[k], invfacs[n-k]));
}


struct DSU
{
    vector<int> sz;
    vector<int> parent;
    void make_set(int v) {
        parent[v] = v;
        sz[v] = 1;
    }

    int find_set(int v) {
        if (v == parent[v])
            return v;
        return find_set(parent[v]);
    }

    void union_sets(int a, int b) {
        a = find_set(a);
        b = find_set(b);

        if (a != b) {
            if (sz[a] < sz[b])
                swap(a, b);
            parent[b] = a;
            sz[a] += sz[b];
        }
    }

    DSU (int n)
    {
        parent.resize(n);
        sz.resize(n);
        for (int i = 0; i<n; i++) make_set(i);
    }
};

void print(vector<int> a)
{
    for (auto it: a) cout<<it<<' ';
    cout<<endl;
}

/*const int mod = 998244353;

template<int mod>
struct NTT {
    static constexpr int max_lev = __builtin_ctz(mod - 1);

    int prod[2][max_lev - 1];

    NTT() {
        int root = find_root();//(mod == 998244353) ? 31 : find_root();
        int rroot = power(root, mod - 2);
        vector<vector<int>> roots(2, vector<int>(max_lev - 1));
        roots[0][max_lev - 2] = root;
        roots[1][max_lev - 2] = rroot;
        for (int tp = 0; tp < 2; ++tp) {
            for (int i = max_lev - 3; i >= 0; --i) {
                roots[tp][i] = mul(roots[tp][i + 1], roots[tp][i + 1]);
            }
        }
        for (int tp = 0; tp < 2; ++tp) {
            int cur = 1;
            for (int i = 0; i < max_lev - 1; ++i) {
                prod[tp][i] = mul(cur, roots[tp][i]);
                cur = mul(cur, roots[tp ^ 1][i]);
            }
        }
    }

    template<bool inv>
    void fft(int *a, int lg) const {
        const int n = 1 << lg;
        int pos = max_lev - 1;
        for (int it = 0; it < lg; ++it) {
            const int h = inv ? lg - 1 - it : it;
            const int shift = (1 << (lg - h - 1));
            int coef = 1;
            for (int start = 0; start < (1 << h); ++start) {
                for (int i = start << (lg - h); i < (start << (lg - h)) + shift; ++i) {
                    if (!inv) {
                        const int y = mul(a[i + shift], coef);
                        a[i + shift] = a[i];
                        inc(a[i], y);
                        dec(a[i + shift], y);
                    } else {
                        const int y = mul(a[i] + mod - a[i + shift], coef);
                        inc(a[i], a[i + shift]);
                        a[i + shift] = y;
                    }
                }
                coef = mul(coef, prod[inv][__builtin_ctz(~start)]);
            }
        }
    }

    vector<int> product(vector<int> a, vector<int> b) const {
        if (a.empty() || b.empty()) {
            return {};
        }
        const int sz = a.size() + b.size() - 1;
        const int lg = 32 - __builtin_clz(sz - 1), n = 1 << lg;
        a.resize(n);
        b.resize(n);
        fft<false>(a.data(), lg);
        fft<false>(b.data(), lg);
        for (int i = 0; i < n; ++i) {
            a[i] = mul(a[i], b[i]);
        }
        fft<true>(a.data(), lg);
        a.resize(sz);
        const int rn = power(n, mod - 2);
        for (int &x : a) {
            x = mul(x, rn);
        }
        return a;
    }

private:
    static inline void inc(int &x, int y) {
        x += y;
        if (x >= mod) {
            x -= mod;
        }
    }

    static inline void dec(int &x, int y) {
        x -= y;
        if (x < 0) {
            x += mod;
        }
    }

    static inline int mul(int x, int y) {
        return (1LL * x * y) % mod;
    }

    static int power(int x, int y) {
        if (y == 0) {
            return 1;
        }
        if (y % 2 == 0) {
            return power(mul(x, x), y / 2);
        }
        return mul(x, power(x, y - 1));
    }

    static int find_root() {
        for (int root = 2; ; ++root) {
            if (power(root, (1 << max_lev)) == 1 && power(root, (1 << (max_lev - 1))) != 1) {
                return root;
            }
        }
    }
};

NTT<mod> ntt;
*/

int inv2 = inv(2);

vector<int> deg(LIM), invdeg(LIM);

int at_least(int n, int k)
{
    int ans, cur;
    if (n%2 == 1)
    {
        cur = (n+1)/2; ans = inv2;
    }
    else
    {
        cur = n/2; ans = mul(add(deg[n], C(n, cur)), invdeg[n+1]);
    }

    while (cur>k)
    {
        cur--; ans = add(ans, mul(C(n, cur), invdeg[n]));
    }

    while (cur<k)
    {
        ans = sub(ans, mul(C(n, cur), invdeg[n])); cur++;
    }

    return ans;
}

int exp_max(int n)
{
    int prb = add(at_least(n-1, (n-1)/2), at_least(n-1, n/2));

    prb = mul(prb, inv2);

    //if element is 0: need at least
    return mul(n, prb);
}

int main()
{
    ios_base::sync_with_stdio(0);
    cin.tie(nullptr);

    init();

    deg[0] = 1;
    invdeg[0] = 1;

    for (int i = 1; i<LIM; i++)
    {
        deg[i] = mul(deg[i-1], 2);
        invdeg[i] = mul(invdeg[i-1], inv2);
    }

    ll n, k, x;
    //n = 2; k = 2; x = 1;
    cin>>n>>k>>x;

    ll M = (1ll<<k);

    ll invM = po(inv2, k);
    init();

    if (x==0)
    {
        cout<<mul(n, sub(1, invM))<<endl; return 0;
    }

    int not0notX = sub(1, mul(invM, 2));

    int ans = 0;

    ans = add(ans, mul(n, not0notX));

    //cout<<mul(ans, po(M, n))<<endl;

    vector<int> prob_len(n+1);

    vector<int> validdeg(LIM);
    validdeg[0] = 1;
    for (int i = 1; i<LIM; i++) validdeg[i] = mul(validdeg[i-1], add(invM, invM));

    for (int len = 3; len<=n; len++)
    {
        prob_len[len] = validdeg[len-2];
        prob_len[len] = mul(prob_len[len], mul(not0notX, not0notX));
        prob_len[len] = mul(prob_len[len], sub(exp_max(len-1), 1));

        ans = add(ans, mul(prob_len[len], n + 1 - len));
    }

    //cout<<mul(ans, po(M, n))<<endl;

    for (int pref = 2; pref<=n; pref++)
    {
        int prob = mul(validdeg[pref-1], not0notX);
        prob = mul(prob, mul(pref-1, inv2));
        ans = add(ans, mul(prob, 2));
    }

    //cout<<mul(ans, po(M, n))<<endl;

    int prob = mul(validdeg[n-1], invM);

    ans = add(ans, prob);
    ans = add(ans, mul(prob, mul(n-1, inv2)));

    cout<<ans<<endl;

    //cout<<mul(ans, po(M, n))<<endl;
}

/*
2 2 1
 */
Tester's Solution
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
#define fi first
#define se second
const ll mod=998244353;
const int N=1e6+5;
ll n,k,x;
ll pw(ll x,ll y){
	if(y==0) return 1;
	if(y%2) return x*pw(x,y-1)%mod;
	ll res=pw(x,y/2);
	return res*res%mod;
}
ll f[N],inf[N];
ll C(ll x,ll y){
	return f[x]*inf[y]%mod*inf[x-y]%mod;
}
int main(){
	ios::sync_with_stdio(false);cin.tie(0);
	cin >> n >> k >> x;
	ll king=pw(2,k);
	if(x==0){
		ll f=(king-1)*pw(king,mod-2)%mod;
		ll ans=n*f%mod;
		cout << ans << '\n';
		return 0;
	}
	f[0]=1;
	for(int i=1; i<=n ;i++) f[i]=f[i-1]*i%mod;
	inf[n]=pw(f[n],mod-2);
	for(int i=n; i>=1 ;i--) inf[i-1]=inf[i]*i%mod;
	
	ll py=(king-2)*pw(king,mod-2)%mod;
	ll pz=2*pw(king,mod-2)%mod;
	ll ans=n*py%mod;
	for(int i=2; i<n ;i++){
		ll ways=(n-i)*pw(py,2)%mod*pw(pz,i-1)%mod;
		ll guys=(C(i-1,(i-1)/2)*pw(2,(mod-2)*i)+(mod+1)/2)%mod;
		guys=(guys*i+mod-1)%mod;
		ans=(ans+ways*guys)%mod;
	}
	ans=(ans+pw(pz,n)*(n+1)%mod*pw(4,mod-2))%mod;
	for(int i=2; i<=n ;i++){
		ll ways=pw(pz,i-1)*py*2%mod;
		ll guys=(i-1)*(mod+1)/2%mod;
		ans=(ans+ways*guys)%mod;
	}
	cout << ans << '\n';
}
Editorialist's Solution
#include <bits/stdc++.h>
#include <atcoder/modint>
using namespace std;
using namespace atcoder;

using mint = modint998244353;

int main() {
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);
    int n; cin >> n;
    long long k, x; cin >> k >> x;
    mint singleton = 1 / mint(2).pow(k);
    if (x == 0) {
        cout << ((1 - singleton) * n).val();
        return 0;
    }

    // init
    vector<mint> fct(n + 1);
    fct[0] = 1;
    for (int i = 1; i <= n; i++) {
        fct[i] = fct[i - 1] * i;
    }
    auto C = [&](int n, int k) {
        return n < k || k < 0 ? mint(0) : fct[n] / fct[k] / fct[n - k];
    };
    mint ans = 0;

    // good elements
    ans += (1 - 2 * singleton) * n;
    
    // middle segments
    vector<mint> expected_max(n + 1);
    for (int i = 1; i <= n; i++) {
        // if we choose the i-th element, (i - 1) needs to have at least (i - 1) / 2 similar elements
        mint prob;
        if ((i - 1) % 2 == 1) {
            prob = mint(1) / 2;
        } else {
            mint mid = C(i - 1, (i - 1) / 2) / mint(2).pow(i - 1);
            prob = (1 - mid) / 2 + mid;
        }
        expected_max[i] = expected_max[i - 1] + prob;
        // contribution of ex_max into the answer
        // endpoints are i apart, each w.p. 1 - 2 * singleton
        // middle parts have i - 1, each w.p. 2 * singleton
        ans += (n - i) * (1 - 2 * singleton).pow(2) * (2 * singleton).pow(i - 1) * (expected_max[i] - 1);
    }

    // ending segments
    for (int i = 1; i <= n; i++) {
        ans += (1 - 2 * singleton) * (2 * singleton).pow(i - 1) * (i - 1);
    }

    // full length
    ans += (2 * singleton).pow(n) * ((n - 1) / mint(2) + 1) / 2;

    cout << ans.val();
}