SUBXORCNT - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: raysh07
Tester: sushil2006
Editorialist: iceknight1093

DIFFICULTY:

Medium

PREREQUISITES:

Combinatorics

PROBLEM:

Consider a tree with N vertices, rooted at vertex 1, where vertex i has value A_i.
Vertex u is good if:

  • u is a leaf; or
  • A_u = A_{v_1} \oplus A_{v_2} \oplus\cdots\oplus A_{v_k}, where v_1, v_2, \ldots, v_k are the vertices in the subtree of u, excluding u itself.

The tree is beautiful if all its nodes are good.
f(A) denotes the minimum number of elements of A that need to be changed to make the tree beautiful.


Given the tree and an integer M, compute the sum of f(A) over all arrays A with elements in [0, M], modulo 998244353.

EXPLANATION:

As always, in such problems the first step should be to analyze when exactly a tree is beautiful (and from there, figure out f(A)).

To start, let’s rewrite the condition for a vertex to be good.
Leaves are always good. For a non-leaf vertex u, A_u = A_{v_1} \oplus A_{v_2} \oplus\cdots\oplus A_{v_k} is equivalent to
A_u \oplus A_{v_1} \oplus A_{v_2} \oplus\cdots\oplus A_{v_k} = 0, since p = q exactly when p \oplus q = 0. In other words, the XOR of all values in the subtree of u should be 0.

Now, suppose every vertex is good, and let u be a non-leaf vertex.
Some children of u might be leaves, while others are not.
Let the children that are leaves be l_1, l_2, \ldots, l_x and the non-leaves be c_1, c_2, \ldots, c_y.

The subtree XOR of u equals A_u XORed with the subtree XORs of all its children.
Each child must itself be good. In particular, every non-leaf child c_j has a subtree XOR of 0.
So the non-leaf children contribute nothing to the subtree XOR of u, and can be ignored entirely.

Leaves impose no condition of their own, and a leaf child’s subtree XOR is simply its own value. So the only real condition at u is
A_u \oplus A_{l_1} \oplus A_{l_2} \oplus \cdots \oplus A_{l_x} = 0.

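Conversely, suppose this condition holds at every non-leaf vertex. Processing vertices from the bottom up then shows that every non-leaf vertex has subtree XOR 0, so every vertex is good. Thus, we obtain the following result: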

A tree is beautiful if and only if for every non-leaf vertex u, the bitwise XOR of the value of u and the values of its leaf children equals 0.
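The characterization can be checked mechanically on small cases. Below is a minimal Python sketch that is not part of the original editorial. It assumes the tree is rooted at vertex 1 and given as an edge list. Values are 1-indexed, with `A[0]` as an unused placeholder, and all function names are hypothetical. The sketch compares the definition (every non-leaf vertex has subtree XOR 0) with the leaf-children rule, over every array on a few random small trees.

```python
from itertools import product
import random

def rooted_children(n, edges, root=1):
    """Hypothetical helper: children lists and BFS order, assuming root = 1."""
    adj = [[] for _ in range(n + 1)]
    for u, v in edges:
        adj[u].append(v)
        adj[v].append(u)
    children = [[] for _ in range(n + 1)]
    seen = [False] * (n + 1)
    seen[root] = True
    order = [root]
    for u in order:  # BFS: `order` grows while we iterate over it
        for v in adj[u]:
            if not seen[v]:
                seen[v] = True
                children[u].append(v)
                order.append(v)
    return children, order

def beautiful_by_definition(n, edges, A):
    children, order = rooted_children(n, edges)
    sub = A[:]  # sub[u] becomes the XOR of all values in u's subtree
    for u in reversed(order):  # every child is finished before its parent
        for v in children[u]:
            sub[u] ^= sub[v]
    # A non-leaf u is good iff its whole-subtree XOR is 0.
    return all(sub[u] == 0 for u in range(1, n + 1) if children[u])

def beautiful_by_leaf_rule(n, edges, A):
    children, _ = rooted_children(n, edges)
    for u in range(1, n + 1):
        if children[u]:
            x = A[u]
            for v in children[u]:
                if not children[v]:  # only leaf children contribute
                    x ^= A[v]
            if x != 0:
                return False
    return True

# Exhaustive comparison over all arrays in [0, 2]^n on random small trees.
random.seed(1)
for _ in range(20):
    n = random.randint(1, 5)
    edges = [(random.randint(1, i - 1), i) for i in range(2, n + 1)]
    for vals in product(range(3), repeat=n):
        A = [0] + list(vals)
        assert beautiful_by_definition(n, edges, A) == beautiful_by_leaf_rule(n, edges, A)
print("characterization agrees on all tested cases")
```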


With this in mind, keep only the edges between each leaf and its parent, and delete all other edges from the tree.
We’ll then be left with several connected components of various sizes. Each component looks like a star: a non-leaf vertex together with its (possibly zero) leaf children.

By the result above, the tree is beautiful if and only if the values within each component have a bitwise XOR of 0. The components share no vertices, so these conditions are independent of each other.
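The decomposition can be sketched in Python as follows. The helper `star_components` is hypothetical, and the sketch assumes the tree is rooted at vertex 1 and given as an edge list. Each returned list is a non-leaf vertex followed by its leaf children. A root with no children, which only happens when N = 1, forms no such component; that vertex is trivially good.

```python
def star_components(n, edges, root=1):
    """Hypothetical helper: each non-leaf vertex plus its leaf children (root assumed 1)."""
    adj = [[] for _ in range(n + 1)]
    for u, v in edges:
        adj[u].append(v)
        adj[v].append(u)
    children = [[] for _ in range(n + 1)]
    seen = [False] * (n + 1)
    seen[root] = True
    order = [root]
    for u in order:  # BFS from the root to orient every edge parent -> child
        for v in adj[u]:
            if not seen[v]:
                seen[v] = True
                children[u].append(v)
                order.append(v)
    comps = []
    for u in range(1, n + 1):
        if children[u]:  # keep only the edges from u to its leaf children
            comps.append([u] + [v for v in children[u] if not children[v]])
    return comps

print(star_components(5, [(1, 2), (1, 3), (3, 4), (3, 5)]))  # [[1, 2], [3, 4, 5]]
```

In this example, vertex 3 is not a leaf, so the edge 1-3 is dropped and vertex 1 is grouped only with the leaf 2.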

Now, let’s look at f(A).
Changing a value affects only the component containing it, so each component can be handled independently.
If some component has a bitwise XOR of 0, nothing needs to be done to it.
If its bitwise XOR is non-zero, at least one change is necessary, and one change suffices: set any one vertex of the component to the XOR of the others.

So, f(A) simply equals the number of components with non-zero XOR.
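Given the components, computing f(A) is a single pass. Here is a minimal sketch, in which `f_value` is a hypothetical name. It assumes the components are given as vertex lists (non-leaf vertex first) and that A is 1-indexed with `A[0]` unused.

```python
def f_value(comps, A):
    """f(A) = number of star components whose XOR is non-zero (A is 1-indexed)."""
    count = 0
    for comp in comps:
        x = 0
        for v in comp:
            x ^= A[v]
        if x != 0:
            count += 1  # one change (to any vertex of the component) fixes it
    return count

comps = [[1, 2], [3, 4, 5]]  # stars of the tree 1-2, 1-3, 3-4, 3-5 rooted at 1
print(f_value(comps, [0, 1, 1, 7, 2, 5]))  # 0: 1^1 = 0 and 7^2^5 = 0
print(f_value(comps, [0, 1, 2, 7, 2, 5]))  # 1: only 1^2 != 0
print(f_value(comps, [0, 1, 0, 0, 1, 0]))  # 2: both XORs equal 1
```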


We want to compute the sum of f(A) across all arrays A.
We’ll do that by contribution counting.

That is, let’s fix a certain component, say of size K, and count the number of arrays A in which it has non-zero XOR (and hence adds 1 to f(A)).

First off, values outside the component don’t matter at all, and can be anything.
So, there’s always a multiplier of (M+1)^{N-K}.
That leaves us with just the component itself: our goal is really to find the number of ways to choose K integers, each in [0, M], such that their XOR is non-zero.

To do this, we’ll instead compute the number of choices whose XOR is 0, and subtract that from (M+1)^K.
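The contribution-counting identity can be checked by brute force on tiny inputs. In this sketch (all names hypothetical), `zero_xor_count_brute(K, M)` enumerates every K-tuple. The per-component formula is then compared with a direct sum of f(A) over all (M+1)^N arrays. The component sizes [2, 3] correspond to the example tree 1-2, 1-3, 3-4, 3-5 rooted at vertex 1.

```python
from itertools import product

def xor_all(values):
    x = 0
    for a in values:
        x ^= a
    return x

def zero_xor_count_brute(K, M):
    """Number of K-tuples over [0, M] with XOR 0 (exponential; tiny inputs only)."""
    return sum(1 for t in product(range(M + 1), repeat=K) if xor_all(t) == 0)

def sum_f_by_contribution(N, M, sizes):
    """Each size-K component has non-zero XOR in (M+1)^K - Z(K) ways,
    and the other N-K values are unconstrained: (M+1)^(N-K) ways."""
    return sum(((M + 1) ** K - zero_xor_count_brute(K, M)) * (M + 1) ** (N - K)
               for K in sizes)

def sum_f_direct(N, M, comps):
    """Sum of f(A) by enumerating all arrays (1-indexed, A[0] unused)."""
    total = 0
    for vals in product(range(M + 1), repeat=N):
        A = (0,) + vals
        total += sum(1 for comp in comps if xor_all(A[v] for v in comp) != 0)
    return total

comps = [[1, 2], [3, 4, 5]]
print(sum_f_by_contribution(5, 2, [len(c) for c in comps]))  # 342
print(sum_f_direct(5, 2, comps))                             # 342
```

Here the size-2 component contributes (9 - 3) \cdot 27 = 162 and the size-3 component contributes (27 - 7) \cdot 9 = 180.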


Our aim is now to count the number of ways of choosing K elements in [0, M] whose XOR is 0.
If we fix K-1 of the elements, the last one is uniquely determined - so at first glance there are (M+1)^{K-1} options.
However, there’s no guarantee that the last element will be \leq M if decided this way, so it doesn’t quite work - we need to modify the idea a bit.
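A quick enumeration shows where the naive (M+1)^{K-1} guess goes wrong. This is a small illustrative sketch with a hypothetical helper name.

```python
from itertools import product

def zero_xor_count_brute(K, M):
    """Count K-tuples over [0, M] whose XOR is 0, by enumeration."""
    count = 0
    for t in product(range(M + 1), repeat=K):
        x = 0
        for a in t:
            x ^= a
        count += (x == 0)
    return count

for K, M in [(2, 2), (3, 2), (3, 3)]:
    print(K, M, zero_xor_count_brute(K, M), (M + 1) ** (K - 1))
# 2 2 3 3
# 3 2 7 9
# 3 3 16 16
```

The guess is exact when M + 1 is a power of two, because [0, M] is then closed under XOR. For K = 3, M = 2 it overcounts: for example, fixing the first two elements as 1 and 2 would force the third to be 3 > M.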

Observe that “bitwise XOR equal to 0” is equivalent to saying “every bit appears an even number of times in total”.

Let 2^b be the largest power of two that doesn’t exceed M, i.e. b is the highest set bit of M. (M = 0 is an edge case here: every value must be 0, so the count is just 1.)
No element can have a bit higher than b set, and bit b must occur an even number of times in total.
Suppose it occurs 2x times. Then,

  • There are \binom{K}{2x} choices for which elements contain this bit.
  • Each of these elements must lie in the range [2^b, M], meaning there are (M-2^b+1) choices.
    That makes (M-2^b+1)^{2x} choices overall.
  • The remaining K-2x elements must all lie in [0, 2^b-1].
    Once all but one of them are chosen, the last one is forced to equal the XOR of the other K-1 elements. This time it is guaranteed to be \lt 2^b, since bit b occurs an even number of times among the others and no higher bit is set. So it definitely can’t exceed M, which is exactly what we want!
    So, there are (2^b)^{K-2x-1} choices for the elements that don’t have bit b set.

Adding this up across all choices of x gives us the total count, i.e.

\sum_{2x \leq K} \binom{K}{2x} \cdot (M-2^b + 1)^{2x}\cdot (2^b)^{K-2x-1}
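As a quick sanity check of this formula (a worked example, not from the original), take K = 3 and M = 2. Then 2^b = 2 and M - 2^b + 1 = 1, and 2x ranges over \{0, 2\}:

\binom{3}{0}\cdot 1^{0}\cdot 2^{2} + \binom{3}{2}\cdot 1^{2}\cdot 2^{0} = 4 + 3 = 7

Direct enumeration agrees: the triples over \{0, 1, 2\} with XOR 0 are (0,0,0) and the three orderings each of (0,1,1) and (0,2,2).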

This is easily computed in \mathcal{O}(K\log K) time with fast exponentiation for each term, or in \mathcal{O}(K) time with precomputed powers and factorials, so we’re done!
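Here is a minimal Python sketch of evaluating this sum in \mathcal{O}(K), using precomputed factorials and powers. The function name is hypothetical, and the modulus 998244353 is the one used by the reference solutions. The sketch deliberately includes only the terms with 2x < K, since the 2x = K term would have exponent K - 2x - 1 = -1. It assumes M \geq 1 and K \geq 1.

```python
MOD = 998244353

def partial_sum(K, M, mod=MOD):
    """sum over 0 <= 2x < K of C(K,2x) * (M-2^b+1)^(2x) * (2^b)^(K-2x-1), mod `mod`.
    2^b is the highest power of two <= M. Assumes M >= 1 and K >= 1."""
    p = 1 << (M.bit_length() - 1)      # 2^b
    hi = M - p + 1                     # number of values in [2^b, M]
    fact = [1] * (K + 1)
    for i in range(1, K + 1):
        fact[i] = fact[i - 1] * i % mod
    inv_fact = [1] * (K + 1)
    inv_fact[K] = pow(fact[K], mod - 2, mod)
    for i in range(K, 0, -1):
        inv_fact[i - 1] = inv_fact[i] * i % mod
    pow_hi = [1] * (K + 1)             # pow_hi[i] = hi^i
    pow_lo = [1] * (K + 1)             # pow_lo[i] = (2^b)^i
    for i in range(1, K + 1):
        pow_hi[i] = pow_hi[i - 1] * hi % mod
        pow_lo[i] = pow_lo[i - 1] * p % mod
    total = 0
    for j in range(0, K, 2):           # j = 2x, strictly less than K
        binom = fact[K] * inv_fact[j] % mod * inv_fact[K - j] % mod
        total += binom * pow_hi[j] % mod * pow_lo[K - j - 1]
    return total % mod
```

For odd K this already equals the number of zero-XOR tuples: `partial_sum(3, 2)` gives 7 and `partial_sum(5, 2)` gives 61, both matching enumeration. For even K it falls short. `partial_sum(2, 2)` gives 2, yet the three pairs (0,0), (1,1) and (2,2) all have XOR 0.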


Well, not quite: there’s one small detail missing.
The reason the above solution “worked” is that when the last element got fixed automatically, it was definitely not larger than M.
However, what if there’s no “last element” to fix?
That is, what if 2x = K, so that every element has bit b set?

Our counting fails in that case. The K elements were each chosen freely from [2^b, M], and nothing forces their lower bits to XOR to 0; indeed, the corresponding term would have exponent K-2x-1 = -1.
Luckily, this is not hard to fix at all!

Observe that when 2x = K (which can only happen when K is even), every element has bit b set, so we can simply clear bit b from all of them. Since bit b occurs an even number of times, this doesn’t change the XOR, and it maps the range [2^b, M] one-to-one onto [0, M-2^b].

In other words, we’ve simply reduced to a smaller version of the exact same problem!
So, all we need to do is solve this smaller problem, and add it to the answer as well (in lieu of the expression for 2x = K).
This may in turn require removing another bit and solving an even smaller version, and so on. However, each reduction clears the highest set bit of the current bound, so we solve only \mathcal{O}(\log M) versions in total, each in \mathcal{O}(K) time, which is fast enough.
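A minimal Python sketch of the complete count, including this reduction, follows. Function names are hypothetical, and the modulus 998244353 is assumed, as in the reference solutions. Each loop iteration handles the current highest bit; for even K, it then continues with M - 2^b. A brute-force comparison on small inputs is included.

```python
from itertools import product

MOD = 998244353

def zero_xor_count(K, M, mod=MOD):
    """Number of K-tuples over [0, M] with XOR 0, modulo `mod` (K >= 1, M >= 0)."""
    fact = [1] * (K + 1)
    for i in range(1, K + 1):
        fact[i] = fact[i - 1] * i % mod
    inv_fact = [1] * (K + 1)
    inv_fact[K] = pow(fact[K], mod - 2, mod)
    for i in range(K, 0, -1):
        inv_fact[i - 1] = inv_fact[i] * i % mod

    total = 0
    while M > 0:
        p = 1 << (M.bit_length() - 1)   # 2^b, the highest set bit of M
        hi = M - p + 1                  # number of values in [2^b, M]
        pw_hi = [1] * K                 # pw_hi[i] = hi^i
        pw_lo = [1] * K                 # pw_lo[i] = (2^b)^i
        for i in range(1, K):
            pw_hi[i] = pw_hi[i - 1] * hi % mod
            pw_lo[i] = pw_lo[i - 1] * p % mod
        for j in range(0, K, 2):        # terms with 2x = j < K
            binom = fact[K] * inv_fact[j] % mod * inv_fact[K - j] % mod
            total = (total + binom * pw_hi[j] % mod * pw_lo[K - j - 1]) % mod
        if K % 2 == 1:
            return total                # 2x = K is impossible for odd K
        M -= p                          # 2x = K: clear bit b from every element
    return (total + 1) % mod            # bound is 0: only the all-zero tuple

def zero_xor_count_brute(K, M):
    count = 0
    for t in product(range(M + 1), repeat=K):
        x = 0
        for a in t:
            x ^= a
        count += (x == 0)
    return count

for K in range(1, 6):
    for M in range(0, 8):
        assert zero_xor_count(K, M) == zero_xor_count_brute(K, M)
print(zero_xor_count(2, 2), zero_xor_count(4, 5))  # 3 168
```

For K = 2, M = 2, the first round contributes 2 (only 2x = 0). The reduced problem with bound 0 then adds the missing 1, for a total of 3.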


We’re able to compute the answer for a component of size K in \mathcal{O}(K\log M) time.
Since the sum of sizes of all components is N, we have a solution in \mathcal{O}(N\log M) overall which is more than fast enough.
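Putting the pieces together, here is a compact Python sketch of the whole computation. It is an independent illustration, not the editorialist’s or author’s code. It assumes vertex 1 is the root and uses the modulus 998244353 from the reference solutions; the helper names are hypothetical. The sketch is checked against brute-force enumeration over all arrays on small random trees. The brute force counts components with non-zero XOR, so it validates the counting, while the characterization itself rests on the argument above.

```python
import random
from itertools import product
from math import comb

MOD = 998244353

def zero_xor_count(K, M):
    """K-tuples over [0, M] with XOR 0, modulo MOD, by stripping the top bit of M."""
    total = 0
    while M > 0:
        p = 1 << (M.bit_length() - 1)  # 2^b
        total += sum(comb(K, j) * pow(M - p + 1, j, MOD) * pow(p, K - j - 1, MOD)
                     for j in range(0, K, 2))
        if K % 2 == 1:
            return total % MOD
        M -= p
    return (total + 1) % MOD

def leaf_stars(n, edges):
    """Star components (non-leaf vertex plus its leaf children), assuming root 1."""
    adj = [[] for _ in range(n + 1)]
    for u, v in edges:
        adj[u].append(v)
        adj[v].append(u)
    children = [[] for _ in range(n + 1)]
    seen = [False] * (n + 1)
    seen[1] = True
    order = [1]
    for u in order:
        for v in adj[u]:
            if not seen[v]:
                seen[v] = True
                children[u].append(v)
                order.append(v)
    return [[u] + [v for v in children[u] if not children[v]]
            for u in range(1, n + 1) if children[u]]

def sum_f(n, M, edges):
    ans = 0
    for comp in leaf_stars(n, edges):
        K = len(comp)
        bad = (pow(M + 1, K, MOD) - zero_xor_count(K, M)) % MOD
        ans = (ans + bad * pow(M + 1, n - K, MOD)) % MOD
    return ans

def sum_f_brute(n, M, edges):
    comps = leaf_stars(n, edges)
    total = 0
    for vals in product(range(M + 1), repeat=n):
        A = (0,) + vals
        for comp in comps:
            x = 0
            for v in comp:
                x ^= A[v]
            total += (x != 0)
    return total % MOD

random.seed(7)
for _ in range(30):
    n = random.randint(1, 5)
    M = random.randint(0, 3)
    edges = [(random.randint(1, i - 1), i) for i in range(2, n + 1)]
    assert sum_f(n, M, edges) == sum_f_brute(n, M, edges)
print(sum_f(5, 2, [(1, 2), (1, 3), (3, 4), (3, 5)]))  # 342
```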

TIME COMPLEXITY:

\mathcal{O}(N\log M) per testcase.

CODE:

Editorialist's code (PyPy3)
mod = 998244353
fac = [1]
for i in range(1, 200005): fac.append(fac[-1] * i % mod)
ifac = fac[:]
ifac[-1] = pow(ifac[-1], mod-2, mod)
for i in reversed(range(200004)): ifac[i] = ifac[i+1] * (i+1) % mod
def C(n, r):
    if n < r or r < 0: return 0
    return fac[n] * ifac[r] % mod * ifac[n-r] % mod

for _ in range(int(input())):
    n, m = map(int, input().split())
    adj = [ [] for _ in range(n+1) ]
    for i in range(n-1):
        u, v = map(int, input().split())
        adj[u].append(v)
        adj[v].append(u)
    
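    # BFS from vertex 1: leaf[u] = 1 iff u has no unvisited neighbor (no children)
    leaf, mark = [0]*(n+1), [0]*(n+1)
    qu = [1]
    mark[1] = 1
    for u in qu:
        leaf[u] = 1
        for v in adj[u]:
            if not mark[v]:
                mark[v] = 1
                qu.append(v)
                leaf[u] = 0
    
    # ms = [m, m without its top bit, m without its top two bits, ..., 0]
    ms = [m]
    for b in reversed(range(30)):
        if ms[-1] & 2**b:
            ms.append(ms[-1] - 2**b)
    
    ans = 0
    for u in range(1, n+1):
        if leaf[u]: continue

        # component size: u plus its leaf children (the parent is never a leaf)
        ct = 1
        for v in adj[u]: ct += leaf[v]

        # cur = number of ct-tuples over [0, m] whose XOR is 0
        cur = 0
        for i in range(len(ms)):
            M = ms[i]
            if M == 0:
                cur += 1
                break
            
            # terms with 2x < ct: here ms[i+1]+1 = M-2^b+1 and M-ms[i+1] = 2^b
            for x in range((ct+1)//2):
                cur += C(ct, 2*x) * pow(ms[i+1] + 1, 2*x, mod) % mod * pow(M - ms[i+1], ct - 2*x - 1, mod) % mod
            # odd ct: the 2x = ct case cannot occur; even ct: continue with ms[i+1]
            if ct%2 == 1: break
        # arrays where this component has non-zero XOR; other n-ct values are free
        cur = pow(m+1, ct, mod) - cur
        cur = cur * pow(m+1, n-ct, mod) % mod
        ans += cur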
    print(ans % mod)
Author's code (C++)
#include <bits/stdc++.h>
using namespace std;
#define int long long
#define INF (int)1e18

mt19937_64 RNG(chrono::steady_clock::now().time_since_epoch().count());

const int mod = 998244353;

struct mint{
    int x;

    mint (){ x = 0;}
    mint (int32_t xx){ x = xx % mod; if (x < 0) x += mod;}
    mint (long long xx){ x = xx % mod; if (x < 0) x += mod;}

    int val(){
        return x;
    }
    mint &operator++(){
        x++;
        if (x == mod) x = 0;
        return *this;
    }
    mint &operator--(){
        if (x == 0) x = mod;
        x--;
        return *this;
    }
    mint operator++(int32_t){
        mint result = *this;
        ++*this;
        return result;
    }
    
    mint operator--(int32_t){
        mint result = *this;
        --*this;
        return result;
    }
    mint& operator+=(const mint &b){
        x += b.x;
        if (x >= mod) x -= mod;
        return *this;
    }
    mint& operator-=(const mint &b){
        x -= b.x;
        if (x < 0) x += mod;
        return *this;
    }
    mint& operator*=(const mint &b){
        long long z = x;
        z *= b.x;
        z %= mod;
        x = (int)z;
        return *this;
    }
    mint operator+() const {
        return *this;
    }
    mint operator-() const {
        return mint() - *this;
    }
    mint operator/=(const mint &b){
        return *this = *this * b.inv();
    }
    mint power(long long n) const {
        mint ok = *this, r = 1;
        while (n){
            if (n & 1){
                r *= ok;
            }
            ok *= ok;
            n >>= 1;
        }
        return r;
    }
    mint inv() const {
        return power(mod - 2);
    }
    friend mint operator+(const mint& a, const mint& b){ return mint(a) += b;}
    friend mint operator-(const mint& a, const mint& b){ return mint(a) -= b;}
    friend mint operator*(const mint& a, const mint& b){ return mint(a) *= b;}
    friend mint operator/(const mint& a, const mint& b){ return mint(a) /= b;}
    friend bool operator==(const mint& a, const mint& b){ return a.x == b.x;}
    friend bool operator!=(const mint& a, const mint& b){ return a.x != b.x;}
    mint power(mint a, long long n){
        return a.power(n);
    }
    friend ostream &operator<<(ostream &os, const mint &m) {
        os << m.x;
        return os;
    }
    explicit operator bool() const {
        return x != 0;
    }
};

struct factorials{
    int n;
    vector <mint> ff, iff;
    
    factorials(int nn){
        n = nn;
        ff.resize(n + 1);
        iff.resize(n + 1);
        
        ff[0] = 1;
        for (int i = 1; i <= n; i++){
            ff[i] = ff[i - 1] * i;
        }
        
        iff[n] = ff[n].inv();
        for (int i = n - 1; i >= 0; i--){
            iff[i] = iff[i + 1] * (i + 1);
        }
    }
    
    mint C(int n, int r){
        if (n == r) return mint(1);
        if (n < 0 || r < 0 || r > n) return mint(0);
        return ff[n] * iff[r] * iff[n - r];
    }
    
    mint P(int n, int r){
        if (n < 0 || r < 0 || r > n) return mint(0);
        return ff[n] * iff[n - r];
    }
    
    mint solutions(int n, int r){
        // Solutions to x1 + x2 + ... + xn = r, xi >= 0 
        return C(n + r - 1, n - 1);
    }
    
    mint catalan(int n){
        return ff[2 * n] * iff[n] * iff[n + 1];
    }
};

const int PRECOMP = 3e6 + 69;
factorials F(PRECOMP);

// REMEMBER To check MOD and PRECOMP

int ceil_pow2(int n) {
    int x = 0;
    while ((1U << x) < (unsigned int)(n)) x++;
    return x;
}
 
int bsf(unsigned int n){
    return __builtin_ctz(n);
}
 
void butterfly(std::vector<mint>& a) {
    static constexpr int g = 3; // primitive root 
    int n = (int)(a.size());
    int h = ceil_pow2(n);
 
    static bool first = true;
    static mint sum_e[30];  // sum_e[i] = ies[0] * ... * ies[i - 1] * es[i]
    if (first) {
        first = false;
        mint es[30], ies[30];  // es[i]^(2^(2+i)) == 1
        int cnt2 = bsf(mod - 1);
        mint e = mint(g).power((mod - 1) >> cnt2), ie = e.inv();
        for (int i = cnt2; i >= 2; i--) {
            // e^(2^i) == 1
            es[i - 2] = e;
            ies[i - 2] = ie;
            e *= e;
            ie *= ie;
        }
        mint now = 1;
        for (int i = 0; i <= cnt2 - 2; i++) {
            sum_e[i] = es[i] * now;
            now *= ies[i];
        }
    }
    for (int ph = 1; ph <= h; ph++) {
        int w = 1 << (ph - 1), p = 1 << (h - ph);
        mint now = 1;
        for (int s = 0; s < w; s++) {
            int offset = s << (h - ph + 1);
            for (int i = 0; i < p; i++) {
                auto l = a[i + offset];
                auto r = a[i + offset + p] * now;
                a[i + offset] = l + r;
                a[i + offset + p] = l - r;
            }
            now *= sum_e[bsf(~(unsigned int)(s))];
        }
    }
}
 
void butterfly_inv(std::vector<mint>& a) {
    static constexpr int g = 3; // primitive root 
    int n = (int)(a.size());
    int h = ceil_pow2(n);
 
    static bool first = true;
    static mint sum_ie[30];  // sum_ie[i] = es[0] * ... * es[i - 1] * ies[i]
    if (first) {
        first = false;
        mint es[30], ies[30];  // es[i]^(2^(2+i)) == 1
        int cnt2 = bsf(mod - 1);
        mint e = mint(g).power((mod - 1) >> cnt2), ie = e.inv();
        for (int i = cnt2; i >= 2; i--) {
            // e^(2^i) == 1
            es[i - 2] = e;
            ies[i - 2] = ie;
            e *= e;
            ie *= ie;
        }
        mint now = 1;
        for (int i = 0; i <= cnt2 - 2; i++) {
            sum_ie[i] = ies[i] * now;
            now *= es[i];
        }
    }
 
    for (int ph = h; ph >= 1; ph--) {
        int w = 1 << (ph - 1), p = 1 << (h - ph);
        mint inow = 1;
        for (int s = 0; s < w; s++) {
            int offset = s << (h - ph + 1);
            for (int i = 0; i < p; i++) {
                auto l = a[i + offset];
                auto r = a[i + offset + p];
                a[i + offset] = l + r;
                a[i + offset + p] = (mod + l.val() - r.val()) * inow.val();
            }
            inow *= sum_ie[bsf(~(unsigned int)(s))];
        }
    }
}
 
std::vector<mint> convolution_p2(std::vector<mint> a, std::vector<mint> b) {
    int n = (int)(a.size()), m = (int)(b.size());
    if (!n || !m) return {};
    if (std::min(n, m) <= 60) {
        if (n < m) {
            std::swap(n, m);
            std::swap(a, b);
        }
        std::vector<mint> ans(n + m - 1);
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < m; j++) {
                ans[i + j] += a[i] * b[j];
            }
        }
        return ans;
    }
    int z = 1 << ceil_pow2(n + m - 1);
    a.resize(z);
    butterfly(a);
    b.resize(z);
    butterfly(b);
    for (int i = 0; i < z; i++) {
        a[i] *= b[i];
    }
    butterfly_inv(a);
    a.resize(n + m - 1);
    mint iz = mint(z).inv();
    for (int i = 0; i < n + m - 1; i++) a[i] *= iz;
    return a;
}
 
std::vector<mint> convolution(const std::vector<mint>& a, const std::vector<mint>& b) {
    int n = (int)(a.size()), m = (int)(b.size());
    if (!n || !m) return {};
 
    std::vector<mint> a2(n), b2(m);
    for (int i = 0; i < n; i++) {
        a2[i] = mint(a[i]);
    }
    for (int i = 0; i < m; i++) {
        b2[i] = mint(b[i]);
    }
    auto c2 = convolution_p2(move(a2), move(b2));
    return c2;
}

// Remember to check MOD

void Solve() 
{
    int n, m; cin >> n >> m;
    m++;
    
    vector<vector<int>> adj(n + 1);
    for (int i = 1; i < n; i++){
        int u, v; cin >> u >> v;
        
        adj[u].push_back(v);
        adj[v].push_back(u);
    }
    
    mint ans = 0;
    
    auto solve = [&](int x){
        vector <mint> dp(x + 1, 0);
        dp[x] = 1;
        
        vector <mint> f(x + 1, 0);
        for (int i = 0; i <= x; i++){
            f[i] = F.iff[x - i];
        }
        
        vector <mint> ndp(x + 1, 0);
        
        for (int bit = 30; bit >= 0; bit--){
            for (int i = 0; i <= x; i++){
                ndp[i] = 0;
            }
            
            if (m >> bit & 1){
                for (int i = 0; i <= x; i++){
                    if (i % 2 == 0){
                        ndp[i] += dp[x] * F.C(x, i);
                    }
                }
                
                dp[x] = 0;
                for (int i = 0; i < x; i++){
                    dp[i] *= mint(2).power(x - i - 1);
                    dp[i] *= F.ff[i];
                }
                
                auto c = convolution(dp, f);
                for (int j = 0; j <= x; j++){
                    ndp[j] += c[j + x] * F.iff[j];
                }
                
                for (int i = 0; i <= x; i++){
                    dp[i] = ndp[i];
                }
            } else {
                for (int i = 0; i <= x; i++){
                    mint w;
                    if (i == x){
                        w = 1;
                    } else {
                        w = mint(2).power(x - i - 1);
                    }
                    
                    dp[i] *= w;
                }
            }
        }
        
        mint ans = dp[0];
        mint tot = mint(m).power(x);
        tot -= ans;
        tot *= mint(m).power(n - x);
        
        return tot;
    };

    auto dfs = [&](auto self, int u, int par) -> bool{
        bool leaf = true;
        int cnt = 0;
        
        for (int v : adj[u]){
            if (v != par){
                leaf = false;
                cnt += self(self, v, u);
            }
        }
        
        if (!leaf){
            ans += solve(cnt + 1);
        }
        return leaf;
    };
    
    dfs(dfs, 1, -1);
    
    cout << ans << "\n";
}

int32_t main() 
{
    auto begin = std::chrono::high_resolution_clock::now();
    ios_base::sync_with_stdio(0);
    cin.tie(0);
    int t = 1;
    // freopen("in",  "r", stdin);
    // freopen("out", "w", stdout);
    
    cin >> t;
    for(int i = 1; i <= t; i++) 
    {
        //cout << "Case #" << i << ": ";
        Solve();
    }
    auto end = std::chrono::high_resolution_clock::now();
    auto elapsed = std::chrono::duration_cast<std::chrono::nanoseconds>(end - begin);
    cerr << "Time measured: " << elapsed.count() * 1e-9 << " seconds.\n"; 
    return 0;
}