SMXOR - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: kingmessi
Tester: watoac2001
Editorialist: iceknight1093

DIFFICULTY:

TBD

PREREQUISITES:

Dynamic programming

PROBLEM:

You’re given two arrays L and R, both of length N.
For every array A of length N with L_i \leq A_i \leq R_i for all i, take the sum of the bitwise XORs of all its subarrays; find the total of these sums over all such arrays, modulo 998244353.

EXPLANATION:

Since we’re dealing with bitwise XOR, it helps to treat each bit independently.
The target value is also a sum over all subarrays, so we can handle each subarray independently as well.

So, fix a bit b and the endpoints l and r of a subarray, and count the number of ways of assigning values to A_l, \ldots, A_r such that the XOR of this subarray has bit b set.

Let s_i be the number of values in [L_i, R_i] that have bit b set, i.e., the number of ways to choose A_i so that it has bit b set.
This can be computed in constant time with a bit of math.

How?

We’ll solve a simplified version: given N and b, find the number of integers in [0, N] that have bit b set.
Let this be f(N, b).

Observe that whether bit b is set is periodic, with period 2^{b+1}.
Specifically,

  • 0, 1, 2, \ldots, 2^{b}-1 don’t have it set.
  • 2^b, 2^b+1, 2^b+2, \ldots, 2^{b+1}-1 do have it set.
  • 2^{b+1}, 2^{b+1}+1, \ldots, 2^{b+2}-1 don’t have it set.
    \vdots

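So, split [0, N] into blocks of length 2^{b+1}. The first \left\lfloor N / 2^{b+1} \right\rfloor blocks are complete, and each of them contributes a count of 2^b.
We’re then left with the remaining integers, whose offsets within their block are 0, 1, \ldots, N \bmod 2^{b+1} (that is, (N \bmod 2^{b+1}) + 1 of them).
Within a block the first half doesn’t have b set and the second half does, so exactly \max(0, (N \bmod 2^{b+1}) - 2^b + 1) of these remaining integers have b set. Altogether,
f(N, b) = 2^b \left\lfloor \frac{N}{2^{b+1}} \right\rfloor + \max\left(0, (N \bmod 2^{b+1}) - 2^b + 1\right)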


The number of integers in [L_i, R_i] that have b set is then s_i = f(R_i, b) - f(L_i - 1, b), where f(-1, b) = 0 covers the case L_i = 0.
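A minimal Python sketch of f(N, b) and of s_i, following the block argument above. The function names `count_upto` and `count_in_range` are illustrative and not taken from either solution below; `count_upto` returns 0 for negative N so that the f(L_i - 1, b) term works when L_i = 0. The loop at the end cross-checks the formula against direct enumeration on small ranges.

```python
def count_upto(n: int, b: int) -> int:
    """f(n, b): how many integers in [0, n] have bit b set (0 if n < 0)."""
    if n < 0:
        return 0
    half = 1 << b          # 2^b: each block is 2^b "off" values, then 2^b "on" values
    period = half << 1     # 2^(b+1): length of one block
    full = (n // period) * half            # each complete block contributes 2^b
    rest = max(0, n % period - half + 1)   # offsets 0..n%period; those >= 2^b have b set
    return full + rest


def count_in_range(lo: int, hi: int, b: int) -> int:
    """s_i for L_i = lo, R_i = hi: values in [lo, hi] with bit b set."""
    return count_upto(hi, b) - count_upto(lo - 1, b)


# Cross-check against brute force on small ranges.
for b in range(5):
    for lo in range(40):
        for hi in range(lo, 40):
            assert count_in_range(lo, hi, b) == sum((x >> b) & 1 for x in range(lo, hi + 1))

print(count_upto(7, 1), count_in_range(3, 10, 0))  # 4 4
```

In C++ the same formula works, but the L_i = 0 case should be handled explicitly, since integer division and `%` on negative numbers round toward zero there.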


So, compute s_i for each l \leq i \leq r.
Now, note that this subarray will have b set in its bitwise XOR if and only if an odd number of its elements have b set.
This means we must choose some odd-sized set of indices; then, for each index i in the set, choose one of the s_i values that have b set, and for every other index choose one of the R_i - L_i + 1 - s_i values that don’t have b set.
Let t_i = R_i - L_i + 1 - s_i.
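The counting argument can be made concrete with a deliberately exponential Python sketch: it enumerates every odd-sized set of positions in the subarray and multiplies s_i for the chosen positions by t_i for the rest. It only illustrates the formula on tiny inputs (it is not part of either solution), and the name `ways_odd_subsets` is hypothetical.

```python
from itertools import combinations
from math import prod


def ways_odd_subsets(s: list[int], t: list[int]) -> int:
    """Ways to pick A_l..A_r so that an odd number of them have bit b set.

    s[k] / t[k] count the allowed values with / without bit b at the
    k-th position of the subarray. Exponential time: tiny inputs only.
    """
    m = len(s)
    total = 0
    for size in range(1, m + 1, 2):                  # odd subset sizes only
        for chosen in combinations(range(m), size):
            picked = set(chosen)
            total += prod(s[k] if k in picked else t[k] for k in range(m))
    return total


# Ranges [1, 3] and [2, 2] with b = 0: s = [2, 0], t = [1, 1].
print(ways_odd_subsets([2, 0], [1, 1]))  # 2
```

Here the valid arrays are (1, 2) and (3, 2), whose XORs 3 and 1 are odd; (2, 2) gives 0.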

This seems pretty hard to do: there are exponentially many ways to choose a subset of odd size and we definitely can’t try all of them.
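Instead, build the subarray from left to right: whether the XOR of A_l, \ldots, A_r has bit b set depends only on whether the XOR of A_l, \ldots, A_{r-1} has it set, and on the choice of A_r.
Specifically, if the XOR so far has b set, keeping it set requires one of the t_r values for A_r (while one of the s_r values clears it); if it doesn’t, setting it requires one of the s_r values (while one of the t_r values keeps it clear).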

This lets us solve the problem using dynamic programming.
Let dp[l][r][0] be the number of ways to choose the values of A_l, A_{l+1}, \ldots, A_r such that their XOR does not have bit b set, and dp[l][r][1] the number of ways such that it does.
Then,

dp[l][r][0] = dp[l][r-1][0] \cdot t_r + dp[l][r-1][1] \cdot s_r
dp[l][r][1] = dp[l][r-1][0] \cdot s_r + dp[l][r-1][1] \cdot t_r

with base case dp[l][l][0] = t_l and dp[l][l][1] = s_l.

There are \mathcal{O}(N^2) states for a fixed b, each computed in constant time, so this takes \mathcal{O}(N^2) time per bit.
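A direct Python transcription of this recurrence for one fixed bit b, storing the whole table, might look like the following. It assumes the per-index counts s_i and t_i have already been computed as described above, uses 0-indexed positions, and reduces modulo a caller-supplied `mod`; `odd_xor_table` is an illustrative name.

```python
def odd_xor_table(s: list[int], t: list[int], mod: int) -> list[list[int]]:
    """Return dp1 where dp1[l][r] = dp[l][r][1] (0-indexed, entries with r < l are 0).

    dp0 / dp1 count the ways to choose A_l..A_r whose XOR has bit b clear / set.
    """
    n = len(s)
    dp0 = [[0] * n for _ in range(n)]
    dp1 = [[0] * n for _ in range(n)]
    for l in range(n):
        dp0[l][l] = t[l] % mod   # base case: single element without bit b
        dp1[l][l] = s[l] % mod   # base case: single element with bit b
        for r in range(l + 1, n):
            # Extend [l, r-1] by A_r, exactly as in the recurrence above.
            dp0[l][r] = (dp0[l][r - 1] * t[r] + dp1[l][r - 1] * s[r]) % mod
            dp1[l][r] = (dp0[l][r - 1] * s[r] + dp1[l][r - 1] * t[r]) % mod
    return dp1


# Every A_i in [0, 1] and b = 0, so s_i = t_i = 1.
print(odd_xor_table([1, 1, 1], [1, 1, 1], 998244353)[0])  # [1, 2, 4]
```

With every A_i in [0, 1], exactly half of the 2^k choices for the first k elements have odd parity, which matches the printed row.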
Now, for the subarray [l, r],

  • There are dp[l][r][1] ways to choose its values so that its XOR has bit b set.
  • Every index outside the subarray can take any value in its range.
    • So, multiply by the product of (R_i - L_i + 1) over all indices outside the subarray.
    • This is a prefix product (indices before l) times a suffix product (indices after r), so it takes constant time once those are precomputed.
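A small Python sketch of the prefix/suffix trick, assuming `sizes[i]` holds R_i - L_i + 1 and that products are taken modulo the problem’s modulus. The name `outside_products` is hypothetical; using arrays of length N + 1 with a 1 at the boundary avoids special cases when l = 0 or r = N - 1.

```python
def outside_products(sizes: list[int], mod: int):
    """Return outside(l, r) = product of sizes[i] for i < l or i > r, modulo mod.

    pref[i] is the product of sizes[0..i-1], suf[i] the product of
    sizes[i..n-1], so every query is O(1) after O(n) preprocessing.
    """
    n = len(sizes)
    pref = [1] * (n + 1)
    for i in range(n):
        pref[i + 1] = pref[i] * sizes[i] % mod
    suf = [1] * (n + 1)
    for i in range(n - 1, -1, -1):
        suf[i] = suf[i + 1] * sizes[i] % mod

    def outside(l: int, r: int) -> int:
        return pref[l] * suf[r + 1] % mod

    return outside


outside = outside_products([2, 3, 4, 5], 998244353)
print(outside(1, 2), outside(0, 0))  # 10 60
```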

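Multiplying these two quantities by 2^b gives the contribution of bit b for the subarray [l, r]. Sum this over all 1 \leq l \leq r \leq N, repeat for each bit b from 0 to 29 (all values are at most 10^9 < 2^{30}), and add everything up.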


As a final note, since dp[l][r] depends only on dp[l][r-1], you don’t need to store a DP table at all.
Fix a value of l and keep just two variables for dp[l][r][0] and dp[l][r][1], which can be updated in constant time as r goes from l to N.
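Combining the pieces, here is a self-contained Python sketch of the whole computation that keeps only the two rolling variables per left endpoint, in the same spirit as the editorialist’s code below. It assumes the answer is taken modulo 998244353 (as both solutions below do) and that all values fit in 30 bits; `sum_subarray_xors` and `count_upto` are illustrative names.

```python
MOD = 998244353


def count_upto(n: int, b: int) -> int:
    """Integers in [0, n] with bit b set (0 when n < 0)."""
    if n < 0:
        return 0
    half, period = 1 << b, 1 << (b + 1)
    return (n // period) * half + max(0, n % period - half + 1)


def sum_subarray_xors(lo: list[int], hi: list[int], bits: int = 30) -> int:
    """Sum of all subarray XORs over all arrays with lo[i] <= A_i <= hi[i], mod MOD."""
    n = len(lo)
    sizes = [hi_i - lo_i + 1 for lo_i, hi_i in zip(lo, hi)]
    pref = [1] * (n + 1)                  # pref[i] = sizes[0] * ... * sizes[i-1]
    for i in range(n):
        pref[i + 1] = pref[i] * sizes[i] % MOD
    suf = [1] * (n + 1)                   # suf[i] = sizes[i] * ... * sizes[n-1]
    for i in range(n - 1, -1, -1):
        suf[i] = suf[i + 1] * sizes[i] % MOD
    ans = 0
    for b in range(bits):
        s = [count_upto(hi_i, b) - count_upto(lo_i - 1, b) for lo_i, hi_i in zip(lo, hi)]
        for l in range(n):
            on, off = 0, 1                # empty range: XOR is 0, so bit b is clear
            for r in range(l, n):
                t_r = sizes[r] - s[r]
                # Rolling form of dp[l][r][1], dp[l][r][0] from dp[l][r-1][*].
                on, off = (off * s[r] + on * t_r) % MOD, (off * t_r + on * s[r]) % MOD
                # 2^b times the ways inside [l, r] times free choices outside it.
                ans = (ans + on * pref[l] % MOD * suf[r + 1] % MOD * (1 << b)) % MOD
    return ans


print(sum_subarray_xors([1, 2], [3, 2]))  # 16
```

For this input the arrays are (1, 2), (2, 2) and (3, 2), whose subarray-XOR sums are 6, 4 and 6.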

TIME COMPLEXITY:

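\mathcal{O}(N^2 \log{10^9}) per test case, since an \mathcal{O}(N^2) DP is run for each of the 30 bits. The author’s code asserts N \leq 1000 and that the sum of N over all test cases is at most 1000, so this is about 30 \cdot \frac{N(N+1)}{2} \approx 1.5 \cdot 10^7 DP updates in the worst case.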

CODE:

Author's code (C++)
#include<bits/stdc++.h>

#include <cassert>
#include <numeric>
#include <type_traits>

#ifdef _MSC_VER
#include <intrin.h>
#endif


#include <utility>

#ifdef _MSC_VER
#include <intrin.h>
#endif

namespace atcoder {

namespace internal {

constexpr long long safe_mod(long long x, long long m) {
    x %= m;
    if (x < 0) x += m;
    return x;
}

struct barrett {
    unsigned int _m;
    unsigned long long im;

    explicit barrett(unsigned int m) : _m(m), im((unsigned long long)(-1) / m + 1) {}

    unsigned int umod() const { return _m; }

    unsigned int mul(unsigned int a, unsigned int b) const {

        unsigned long long z = a;
        z *= b;
#ifdef _MSC_VER
        unsigned long long x;
        _umul128(z, im, &x);
#else
        unsigned long long x =
            (unsigned long long)(((unsigned __int128)(z)*im) >> 64);
#endif
        unsigned long long y = x * _m;
        return (unsigned int)(z - y + (z < y ? _m : 0));
    }
};

constexpr long long pow_mod_constexpr(long long x, long long n, int m) {
    if (m == 1) return 0;
    unsigned int _m = (unsigned int)(m);
    unsigned long long r = 1;
    unsigned long long y = safe_mod(x, m);
    while (n) {
        if (n & 1) r = (r * y) % _m;
        y = (y * y) % _m;
        n >>= 1;
    }
    return r;
}

constexpr bool is_prime_constexpr(int n) {
    if (n <= 1) return false;
    if (n == 2 || n == 7 || n == 61) return true;
    if (n % 2 == 0) return false;
    long long d = n - 1;
    while (d % 2 == 0) d /= 2;
    constexpr long long bases[3] = {2, 7, 61};
    for (long long a : bases) {
        long long t = d;
        long long y = pow_mod_constexpr(a, t, n);
        while (t != n - 1 && y != 1 && y != n - 1) {
            y = y * y % n;
            t <<= 1;
        }
        if (y != n - 1 && t % 2 == 0) {
            return false;
        }
    }
    return true;
}
template <int n> constexpr bool is_prime = is_prime_constexpr(n);

constexpr std::pair<long long, long long> inv_gcd(long long a, long long b) {
    a = safe_mod(a, b);
    if (a == 0) return {b, 0};

    long long s = b, t = a;
    long long m0 = 0, m1 = 1;

    while (t) {
        long long u = s / t;
        s -= t * u;
        m0 -= m1 * u;  // |m1 * u| <= |m1| * s <= b


        auto tmp = s;
        s = t;
        t = tmp;
        tmp = m0;
        m0 = m1;
        m1 = tmp;
    }
    if (m0 < 0) m0 += b / s;
    return {s, m0};
}

constexpr int primitive_root_constexpr(int m) {
    if (m == 2) return 1;
    if (m == 167772161) return 3;
    if (m == 469762049) return 3;
    if (m == 754974721) return 11;
    if (m == 998244353) return 3;
    int divs[20] = {};
    divs[0] = 2;
    int cnt = 1;
    int x = (m - 1) / 2;
    while (x % 2 == 0) x /= 2;
    for (int i = 3; (long long)(i)*i <= x; i += 2) {
        if (x % i == 0) {
            divs[cnt++] = i;
            while (x % i == 0) {
                x /= i;
            }
        }
    }
    if (x > 1) {
        divs[cnt++] = x;
    }
    for (int g = 2;; g++) {
        bool ok = true;
        for (int i = 0; i < cnt; i++) {
            if (pow_mod_constexpr(g, (m - 1) / divs[i], m) == 1) {
                ok = false;
                break;
            }
        }
        if (ok) return g;
    }
}
template <int m> constexpr int primitive_root = primitive_root_constexpr(m);

unsigned long long floor_sum_unsigned(unsigned long long n,
                                      unsigned long long m,
                                      unsigned long long a,
                                      unsigned long long b) {
    unsigned long long ans = 0;
    while (true) {
        if (a >= m) {
            ans += n * (n - 1) / 2 * (a / m);
            a %= m;
        }
        if (b >= m) {
            ans += n * (b / m);
            b %= m;
        }

        unsigned long long y_max = a * n + b;
        if (y_max < m) break;
        n = (unsigned long long)(y_max / m);
        b = (unsigned long long)(y_max % m);
        std::swap(m, a);
    }
    return ans;
}

}  // namespace internal

}  // namespace atcoder


#include <cassert>
#include <numeric>
#include <type_traits>

namespace atcoder {

namespace internal {

#ifndef _MSC_VER
template <class T>
using is_signed_int128 =
    typename std::conditional<std::is_same<T, __int128_t>::value ||
                                  std::is_same<T, __int128>::value,
                              std::true_type,
                              std::false_type>::type;

template <class T>
using is_unsigned_int128 =
    typename std::conditional<std::is_same<T, __uint128_t>::value ||
                                  std::is_same<T, unsigned __int128>::value,
                              std::true_type,
                              std::false_type>::type;

template <class T>
using make_unsigned_int128 =
    typename std::conditional<std::is_same<T, __int128_t>::value,
                              __uint128_t,
                              unsigned __int128>;

template <class T>
using is_integral = typename std::conditional<std::is_integral<T>::value ||
                                                  is_signed_int128<T>::value ||
                                                  is_unsigned_int128<T>::value,
                                              std::true_type,
                                              std::false_type>::type;

template <class T>
using is_signed_int = typename std::conditional<(is_integral<T>::value &&
                                                 std::is_signed<T>::value) ||
                                                    is_signed_int128<T>::value,
                                                std::true_type,
                                                std::false_type>::type;

template <class T>
using is_unsigned_int =
    typename std::conditional<(is_integral<T>::value &&
                               std::is_unsigned<T>::value) ||
                                  is_unsigned_int128<T>::value,
                              std::true_type,
                              std::false_type>::type;

template <class T>
using to_unsigned = typename std::conditional<
    is_signed_int128<T>::value,
    make_unsigned_int128<T>,
    typename std::conditional<std::is_signed<T>::value,
                              std::make_unsigned<T>,
                              std::common_type<T>>::type>::type;

#else

template <class T> using is_integral = typename std::is_integral<T>;

template <class T>
using is_signed_int =
    typename std::conditional<is_integral<T>::value && std::is_signed<T>::value,
                              std::true_type,
                              std::false_type>::type;

template <class T>
using is_unsigned_int =
    typename std::conditional<is_integral<T>::value &&
                                  std::is_unsigned<T>::value,
                              std::true_type,
                              std::false_type>::type;

template <class T>
using to_unsigned = typename std::conditional<is_signed_int<T>::value,
                                              std::make_unsigned<T>,
                                              std::common_type<T>>::type;

#endif

template <class T>
using is_signed_int_t = std::enable_if_t<is_signed_int<T>::value>;

template <class T>
using is_unsigned_int_t = std::enable_if_t<is_unsigned_int<T>::value>;

template <class T> using to_unsigned_t = typename to_unsigned<T>::type;

}  // namespace internal

}  // namespace atcoder


namespace atcoder {

namespace internal {

struct modint_base {};
struct static_modint_base : modint_base {};

template <class T> using is_modint = std::is_base_of<modint_base, T>;
template <class T> using is_modint_t = std::enable_if_t<is_modint<T>::value>;

}  // namespace internal

template <int m, std::enable_if_t<(1 <= m)>* = nullptr>
struct static_modint : internal::static_modint_base {
    using mint = static_modint;

  public:
    static constexpr int mod() { return m; }
    static mint raw(int v) {
        mint x;
        x._v = v;
        return x;
    }

    static_modint() : _v(0) {}
    template <class T, internal::is_signed_int_t<T>* = nullptr>
    static_modint(T v) {
        long long x = (long long)(v % (long long)(umod()));
        if (x < 0) x += umod();
        _v = (unsigned int)(x);
    }
    template <class T, internal::is_unsigned_int_t<T>* = nullptr>
    static_modint(T v) {
        _v = (unsigned int)(v % umod());
    }

    unsigned int val() const { return _v; }

    mint& operator++() {
        _v++;
        if (_v == umod()) _v = 0;
        return *this;
    }
    mint& operator--() {
        if (_v == 0) _v = umod();
        _v--;
        return *this;
    }
    mint operator++(int) {
        mint result = *this;
        ++*this;
        return result;
    }
    mint operator--(int) {
        mint result = *this;
        --*this;
        return result;
    }

    mint& operator+=(const mint& rhs) {
        _v += rhs._v;
        if (_v >= umod()) _v -= umod();
        return *this;
    }
    mint& operator-=(const mint& rhs) {
        _v -= rhs._v;
        if (_v >= umod()) _v += umod();
        return *this;
    }
    mint& operator*=(const mint& rhs) {
        unsigned long long z = _v;
        z *= rhs._v;
        _v = (unsigned int)(z % umod());
        return *this;
    }
    mint& operator/=(const mint& rhs) { return *this = *this * rhs.inv(); }

    mint operator+() const { return *this; }
    mint operator-() const { return mint() - *this; }

    mint pow(long long n) const {
        assert(0 <= n);
        mint x = *this, r = 1;
        while (n) {
            if (n & 1) r *= x;
            x *= x;
            n >>= 1;
        }
        return r;
    }
    mint inv() const {
        if (prime) {
            assert(_v);
            return pow(umod() - 2);
        } else {
            auto eg = internal::inv_gcd(_v, m);
            assert(eg.first == 1);
            return eg.second;
        }
    }

    friend mint operator+(const mint& lhs, const mint& rhs) {
        return mint(lhs) += rhs;
    }
    friend mint operator-(const mint& lhs, const mint& rhs) {
        return mint(lhs) -= rhs;
    }
    friend mint operator*(const mint& lhs, const mint& rhs) {
        return mint(lhs) *= rhs;
    }
    friend mint operator/(const mint& lhs, const mint& rhs) {
        return mint(lhs) /= rhs;
    }
    friend bool operator==(const mint& lhs, const mint& rhs) {
        return lhs._v == rhs._v;
    }
    friend bool operator!=(const mint& lhs, const mint& rhs) {
        return lhs._v != rhs._v;
    }

  private:
    unsigned int _v;
    static constexpr unsigned int umod() { return m; }
    static constexpr bool prime = internal::is_prime<m>;
};

template <int id> struct dynamic_modint : internal::modint_base {
    using mint = dynamic_modint;

  public:
    static int mod() { return (int)(bt.umod()); }
    static void set_mod(int m) {
        assert(1 <= m);
        bt = internal::barrett(m);
    }
    static mint raw(int v) {
        mint x;
        x._v = v;
        return x;
    }

    dynamic_modint() : _v(0) {}
    template <class T, internal::is_signed_int_t<T>* = nullptr>
    dynamic_modint(T v) {
        long long x = (long long)(v % (long long)(mod()));
        if (x < 0) x += mod();
        _v = (unsigned int)(x);
    }
    template <class T, internal::is_unsigned_int_t<T>* = nullptr>
    dynamic_modint(T v) {
        _v = (unsigned int)(v % mod());
    }

    unsigned int val() const { return _v; }

    mint& operator++() {
        _v++;
        if (_v == umod()) _v = 0;
        return *this;
    }
    mint& operator--() {
        if (_v == 0) _v = umod();
        _v--;
        return *this;
    }
    mint operator++(int) {
        mint result = *this;
        ++*this;
        return result;
    }
    mint operator--(int) {
        mint result = *this;
        --*this;
        return result;
    }

    mint& operator+=(const mint& rhs) {
        _v += rhs._v;
        if (_v >= umod()) _v -= umod();
        return *this;
    }
    mint& operator-=(const mint& rhs) {
        _v += mod() - rhs._v;
        if (_v >= umod()) _v -= umod();
        return *this;
    }
    mint& operator*=(const mint& rhs) {
        _v = bt.mul(_v, rhs._v);
        return *this;
    }
    mint& operator/=(const mint& rhs) { return *this = *this * rhs.inv(); }

    mint operator+() const { return *this; }
    mint operator-() const { return mint() - *this; }

    mint pow(long long n) const {
        assert(0 <= n);
        mint x = *this, r = 1;
        while (n) {
            if (n & 1) r *= x;
            x *= x;
            n >>= 1;
        }
        return r;
    }
    mint inv() const {
        auto eg = internal::inv_gcd(_v, mod());
        assert(eg.first == 1);
        return eg.second;
    }

    friend mint operator+(const mint& lhs, const mint& rhs) {
        return mint(lhs) += rhs;
    }
    friend mint operator-(const mint& lhs, const mint& rhs) {
        return mint(lhs) -= rhs;
    }
    friend mint operator*(const mint& lhs, const mint& rhs) {
        return mint(lhs) *= rhs;
    }
    friend mint operator/(const mint& lhs, const mint& rhs) {
        return mint(lhs) /= rhs;
    }
    friend bool operator==(const mint& lhs, const mint& rhs) {
        return lhs._v == rhs._v;
    }
    friend bool operator!=(const mint& lhs, const mint& rhs) {
        return lhs._v != rhs._v;
    }

  private:
    unsigned int _v;
    static internal::barrett bt;
    static unsigned int umod() { return bt.umod(); }
};
template <int id> internal::barrett dynamic_modint<id>::bt(998244353);

using modint998244353 = static_modint<998244353>;
using modint1000000007 = static_modint<1000000007>;
using modint = dynamic_modint<-1>;

namespace internal {

template <class T>
using is_static_modint = std::is_base_of<internal::static_modint_base, T>;

template <class T>
using is_static_modint_t = std::enable_if_t<is_static_modint<T>::value>;

template <class> struct is_dynamic_modint : public std::false_type {};
template <int id>
struct is_dynamic_modint<dynamic_modint<id>> : public std::true_type {};

template <class T>
using is_dynamic_modint_t = std::enable_if_t<is_dynamic_modint<T>::value>;

}  // namespace internal

}  // namespace atcoder

#include <ext/pb_ds/assoc_container.hpp> // Common file
#include <ext/pb_ds/tree_policy.hpp>
#define ll long long
#define int long long
#define rep(i,a,b) for(int i=a;i<b;i++)
#define rrep(i,a,b) for(int i=a;i>=b;i--)
#define repin rep(i,0,n)
#define di(a) int a;cin>>a;
#define precise(i) cout<<fixed<<setprecision(i)
#define vi vector<int>
#define si set<int>
#define mii map<int,int>
#define take(a,n) for(int j=0;j<n;j++) cin>>a[j];
#define give(a,n) for(int j=0;j<n;j++) cout<<a[j]<<' ';
#define vpii vector<pair<int,int>>
#define sis string s;
#define sin string s;cin>>s;
#define db double
#define be(x) x.begin(),x.end()
#define pii pair<int,int>
#define pb push_back
#define pob pop_back
#define ff first
#define ss second
#define lb lower_bound
#define ub upper_bound
#define bpc(x) __builtin_popcountll(x) 
#define btz(x) __builtin_ctz(x)
using namespace std;
using namespace atcoder;
using namespace __gnu_pbds;

typedef tree<int, null_type, less<int>, rb_tree_tag,tree_order_statistics_node_update> ordered_set;
typedef tree<pair<int, int>, null_type,less<pair<int, int> >, rb_tree_tag,tree_order_statistics_node_update> ordered_multiset;

const long long INF=1e18;
const long long M=1e9+7;
const long long MM=998244353;
using mint = static_modint<MM>;
  
int power( int N, int M){
    int power = N, sum = 1;
    if(N == 0) sum = 0;
    while(M > 0){if((M & 1) == 1){sum *= power;}
    power = power * power;M = M >> 1;}
    return sum;
}

// sb(l, r, k): number of integers in [l, r] whose bit k is set
int sb(int l,int r,int k){
    l--;
    int p = (1LL<<(k+1));
    if(l/p == r/p){
        return max((r%p),p/2-1) - max((l%p),p/2-1);
    }
    int ans = 0;
    if(r%p){
        ans += max((r%p)-p/2+1,0LL);
        r = (r/p)*p;
    }
    if(l%p){
        ans += p-1-max(p/2-1,l%p);
        l = ((l/p)+1)*p;
    }
    int num = (r-l)/p;
    ans += num*(p/2);
    return ans;
}

int smn;
 
void solve()
{
    int n;
    cin >> n;
    assert(n >= 1);
    assert(n <= 1000);
    smn += n;
    vi l(n),r(n);
    take(l,n);
    take(r,n);
    for(auto x : l)assert(x >= 0),assert(x <= 1e9);
    for(auto x : r)assert(x >= 0),assert(x <= 1e9);
    repin assert(l[i] <= r[i]);
    // cout << sb(0,5,20) << "\n";
    mint ans = 0;
    rep(k,0,30){
        vector<int> f(n);
        rep(j,0,n){
            f[j] = sb(l[j],r[j],k);
        }
        vector<vector<mint>> prod(n,vector<mint>(n,1));
        vector<vector<mint>> dp(n,vector<mint>(n,0));
        repin{
            mint pr = 1;
            rep(j,i,n){
                pr *= (r[j]-l[j]+1);
                prod[i][j] = pr;
            }
        }
        repin{
            dp[i][i] = f[i];
        }
        repin{
            rep(j,i+1,n){
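                // dp[i][j]: ways to fill A_i..A_j with bit k set in their XOR;
                // prod[i][j-1] - dp[i][j-1] is the count with bit k clear
                dp[i][j] = dp[i][j-1]*(r[j]-l[j]+1-f[j]) + (prod[i][j-1] - dp[i][j-1])*f[j];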
            }
        }
        repin{
            rep(j,i,n){
                ans += dp[i][j] * (1<<k) * (i?prod[0][i-1]:1) * (j+1<n?prod[j+1][n-1]:1);
            }
        }
    }
    cout << ans.val() << "\n";

}

signed main(){
    ios_base::sync_with_stdio(0);
    cin.tie(0);
    cout.tie(0);
    #ifdef NCR
        init();
    #endif
    #ifdef SIEVE
        sieve();
    #endif
        int t; cin >> t; 
        assert(t >= 1);
        assert(t <= 100);
        while(t--)
        solve();
        assert(smn <= 1000);
        // cout << "hi\n";
    return 0;
}
Editorialist's code (Python)
mod = 998244353

def calc(N, b):
    # number of integers <= N that have bit b set
    pw = 2**b
    res = N // (2*pw)
    res = pw*res
    res += max(0, N%(2*pw) - pw + 1)
    return res

for _ in range(int(input())):
    n = int(input())
    l = list(map(int, input().split()))
    r = list(map(int, input().split()))
    
    ans = 0
    # pref[i]: product of (r[j] - l[j] + 1) over j <= i; suf[i]: same over j >= i
    pref, suf = [1]*n, [1]*n
    for i in range(n):
        pref[i] = suf[i] = r[i] - l[i] + 1
        if i > 0: pref[i] = pref[i] * pref[i-1] % mod
    for i in reversed(range(n-1)):
        suf[i] = suf[i+1] * suf[i] % mod
    
    for b in range(30):
        for L in range(n):
            # empty range so far: its XOR is 0, so bit b is off in exactly one way
            on, off = 0, 1
            for R in range(L, n):
                onways = calc(r[R], b) - calc(l[R]-1, b)
                offways = r[R] - l[R] + 1 - onways
                # DP transition: extend [L, R-1] by A_R
                on, off = (onways * off + offways * on) % mod, (onways * on + offways * off) % mod

                outside = 1
                if L > 0: outside = pref[L-1]
                if R+1 < n: outside = outside * suf[R+1] % mod
                ans += on * outside % mod * (2**b) % mod
    print(ans % mod)

O(n*log(10^9)*log(10^9)) solution:
https://www.codechef.com/viewsolution/1055604900

It might be improved further by speeding up the “countSetBits” function: precalculating the countSetBits values gets rid of a log factor.

Okay, I improved it further to O(n*log(10^9)):
https://www.codechef.com/viewsolution/1055614354