TXOR - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Authors: kunjrp_1402
Tester: tabr
Editorialist: iceknight1093

DIFFICULTY:

2916

PREREQUISITES:

Dynamic programming on trees, rerooting

PROBLEM:

You’re given a tree on N vertices whose edges are weighted.
The distance between vertices u and v is defined as follows:

  • Let the edges on the (unique) u \to v path be w_1, w_2, \ldots, w_k in order.
  • Then, d(u, v) = f(w_1, 0) \oplus f(w_2, 1) \oplus f(w_3, 2) \oplus\ldots\oplus f(w_k, k-1)
    Here, f(x, b) denotes the integer x circularly shifted (rotated) to the left b times.
    All integers are considered to have 32 bits.
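To make the definition concrete, here is a minimal C++ sketch of the distance along a single path, assuming the edge weights are listed in order starting from u. The helper names `rotl32` and `path_distance` are hypothetical and are not taken from any of the solutions below.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

using std::uint32_t;

// f(x, b): circular left shift of a 32-bit value x by b positions (b taken mod 32).
uint32_t rotl32(uint32_t x, unsigned b) {
    b %= 32;
    if (b == 0) return x;  // x >> 32 would be undefined behaviour
    return (x << b) | (x >> (32 - b));
}

// d(u, v) for a path whose edge weights, read from u towards v, are w[0..k-1]:
// the edge at position i (0-indexed) is rotated left i times, and all are XOR-ed.
uint32_t path_distance(const std::vector<uint32_t>& w) {
    uint32_t d = 0;
    for (std::size_t i = 0; i < w.size(); ++i) {
        d ^= rotl32(w[i], static_cast<unsigned>(i % 32));
    }
    return d;
}
```

For example, three edges of weight 1 give 1 \oplus 2 \oplus 4 = 7. With C++20, `std::rotl` from `<bit>` performs the same rotation; the explicit `b == 0` branch here avoids the undefined shift by 32 that a naive `x >> (32 - b)` would hit.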

For each vertex v, compute \sum_{i=1}^N d(v, i).
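When testing a faster solution on small inputs, a brute force that follows the definition directly can be handy. The sketch below is \mathcal{O}(N^2), so it is only meant for small trees: from every source vertex it walks the whole tree, carrying the current depth and the XOR accumulated so far. It assumes 0-indexed vertices and a hypothetical adjacency type `Adj`; it is not one of the official solutions.

```cpp
#include <cassert>
#include <cstdint>
#include <utility>
#include <vector>

using std::uint32_t;

// Hypothetical adjacency type: adj[v] lists (neighbor, edge weight) pairs.
using Adj = std::vector<std::vector<std::pair<int, uint32_t>>>;

// f(x, b): circular left shift by b positions (b taken mod 32).
uint32_t rotl32(uint32_t x, unsigned b) {
    b %= 32;
    if (b == 0) return x;
    return (x << b) | (x >> (32 - b));
}

// O(N^2) reference: for every source s, walk the tree once. An edge entered
// from a vertex at depth k is the (k+1)-th edge from s, so it is rotated k times.
std::vector<long long> brute_force(int n, const Adj& adj) {
    struct Item { int v, parent; unsigned depth; uint32_t dist; };
    std::vector<long long> ans(n, 0);
    for (int s = 0; s < n; ++s) {
        std::vector<Item> stack;
        stack.push_back({s, -1, 0u, 0u});
        while (!stack.empty()) {
            Item cur = stack.back();
            stack.pop_back();
            ans[s] += cur.dist;  // cur.dist == d(s, cur.v)
            for (auto [to, w] : adj[cur.v]) {
                if (to == cur.parent) continue;
                stack.push_back({to, cur.v, cur.depth + 1, cur.dist ^ rotl32(w, cur.depth)});
            }
        }
    }
    return ans;
}
```

For the path 0 - 1 - 2 with weights 1 and 2, this returns {6, 3, 2}. Note that d(0, 2) = 1 \oplus f(2, 1) = 5 while d(2, 0) = 2 \oplus f(1, 1) = 0, so d is not symmetric.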

EXPLANATION:

Problems like this one, which ask you to compute something for every vertex of a tree, can very often be solved with a technique known as rerooting.
The name comes from the fact that the problem is first solved for a single fixed root; we then keep enough information to quickly recompute the answer when the root moves from a vertex to one of its neighbors.

As such, rerooting is generally done in two parts:

  • First, root the tree at an arbitrary vertex, say 1. Then, for each vertex u, compute the answer for the subtree of u, usually using a DFS.
  • Once this is done, a second DFS reroots the tree: for each u, it accounts for all vertices that are not in the subtree of u. Together with the first part, this covers every vertex.

So, let’s look at each of those parts individually for our problem.

Let’s root the tree at vertex 1.
First, for each u, we need to compute \sum_{v} d(u, v) over all v that lie in the subtree of u.
This can be done with a DFS that stores the relevant information for each bit separately.

Details

Let’s perform our DFS, and suppose we’re at vertex u. We want to find \text{val}[u], the answer for u within its subtree.
Let c be a child of u, and w be the weight of the edge joining them. First, recursively solve for c to compute \text{val}[c].

Now, let’s consider all paths from u going into the subtree of c.
Each such path corresponds to a path from c going into its subtree.
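Specifically, we can take any path starting at c and going down, circularly rotate every edge weight on this path left once (since every edge is one step further from u than from c), and then XOR with w (the weight of the u\to c edge).
Because circular rotation distributes over XOR, rotating every edge weight once is the same as rotating the whole value once, so d(u, v) = w \oplus f(d(c, v), 1).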
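The identity d(u, v) = w \oplus f(d(c, v), 1) can be checked numerically. The sketch below re-declares hypothetical helpers (`rotl32`, `path_distance`) so that it is self-contained, and compares the direct path computation with the one-step extension; it is an illustration, not part of any official solution.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

using std::uint32_t;

uint32_t rotl32(uint32_t x, unsigned b) {
    b %= 32;
    if (b == 0) return x;
    return (x << b) | (x >> (32 - b));
}

// XOR of f(w[i], i) over the path's edge weights, read from its start.
uint32_t path_distance(const std::vector<uint32_t>& w) {
    uint32_t d = 0;
    for (std::size_t i = 0; i < w.size(); ++i) d ^= rotl32(w[i], static_cast<unsigned>(i % 32));
    return d;
}

// One child-to-parent step: d(u, v) from the (u, c) edge weight w and d(c, v).
uint32_t extend_to_parent(uint32_t w, uint32_t d_cv) {
    return w ^ rotl32(d_cv, 1);
}

// `rest` holds the c -> v edge weights; the u -> v path is w followed by rest.
bool identity_holds(uint32_t w, const std::vector<uint32_t>& rest) {
    std::vector<uint32_t> full{w};
    full.insert(full.end(), rest.begin(), rest.end());
    return path_distance(full) == extend_to_parent(w, path_distance(rest));
}
```

The check succeeds for any input, since f(f(x, i), 1) = f(x, i+1) and f(a \oplus b, 1) = f(a, 1) \oplus f(b, 1).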

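However, this only describes how a single distance changes. We need the sum of the distances, and a sum does not commute with rotation or XOR, so \text{val}[c] alone is not enough to compute \text{val}[u].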

Instead, let’s look at what happens bit-by-bit.
Let \text{ct}[u][b] be the number of paths from u into its subtree that have bit b set, i.e., the number of vertices v in the subtree of u such that bit b of d(u, v) is set.
If we are able to calculate \text{ct}[u][b] for each 0 \leq b \lt 32, then we’ll simply have \text{val}[u] = \sum_{b=0}^{31} \text{ct}[u][b]\cdot 2^b.
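As a small illustration of this formula, assuming the counts are stored in a `std::array<long long, 32>` (the function name `value_from_counts` is hypothetical):

```cpp
#include <array>
#include <cassert>

// ct[b] = number of vertices v in the subtree of u such that bit b of d(u, v)
// is set. Each of them contributes 2^b, so val[u] = sum over b of ct[b] * 2^b.
long long value_from_counts(const std::array<long long, 32>& ct) {
    long long val = 0;
    for (int b = 0; b < 32; ++b) {
        val += ct[b] * (1LL << b);
    }
    return val;
}
```

For vertex 0 of the path 0 - 1 - 2 with weights 1 and 2, the distances are 0, 1 and 5, so \text{ct}[0] = 2, \text{ct}[2] = 1 and the sum is 2 \cdot 1 + 1 \cdot 4 = 6. Since every count is at most N \le 10^5, the result stays below 10^5 \cdot 2^{32} and fits in a 64-bit integer.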

Let’s recursively compute \text{ct}[c][b] for the child c of u.
Then, notice that bit b of d(c, v) turns into bit (b+1) \bmod 32 of d(u, v) (before the XOR with w), since every edge on the path is one step further from u than from c and is therefore rotated left once more.
In particular,

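  • If w has bit b unset, then \text{ct}[u][b] increases by \text{ct}[c][b-1].
  • If w has bit b set, then \text{ct}[u][b] increases by \text{sz}[c] - \text{ct}[c][b-1], where \text{sz}[c] denotes the subtree size of c.
    This is because we XOR with w, which flips bit b on every u \to v path: paths that had it set (after the rotation) now have it unset, and vice versa.

Here b-1 is taken modulo 32, so it equals 31 when b = 0. The vertex v = c itself is included in \text{sz}[c]: d(c, c) = 0, so this correctly gives d(u, c) = w.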
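The two cases above translate directly into an \mathcal{O}(32) update per child. This is a sketch with hypothetical names (`Counts`, `add_child`); the editorialist's code below performs the same update inline, indexing from the child's bit b to the parent's bit b+1 instead.

```cpp
#include <array>
#include <cassert>
#include <cstdint>

using Counts = std::array<long long, 32>;

// Adds the contribution of child c to ct_u. sz_c is the subtree size of c and
// w the weight of the edge (u, c). Bit b of d(u, v) comes from bit (b - 1) mod 32
// of d(c, v) (one extra rotation), flipped when bit b of w is set (XOR with w).
void add_child(Counts& ct_u, const Counts& ct_c, long long sz_c, std::uint32_t w) {
    for (int b = 0; b < 32; ++b) {
        long long prev = ct_c[(b + 31) % 32];  // ct[c][b-1], index taken mod 32
        ct_u[b] += ((w >> b) & 1u) ? sz_c - prev : prev;
    }
}
```

On the path 0 - 1 - 2 (weights 1 and 2) rooted at 0, adding leaf 2 to vertex 1 gives one vertex with bit 1 set (d(1, 2) = 2), and then adding vertex 1 (subtree size 2) to vertex 0 gives two vertices with bit 0 set and one with bit 2 set, matching d(0, 1) = 1 and d(0, 2) = 5.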

In this way, we can compute the contribution of child c to vertex u in \mathcal{O}(32).

This results in an algorithm in \mathcal{O}(32 \cdot N) in total.

Now that we’ve computed the subtree answer for each u, we need to reroot.

The rerooting process follows basically the same procedure.
When moving from u to its child c, we have the following:

  • When we reach u, the rerooting DFS has already computed, for each bit, the path counts from u to all vertices outside the subtree of u. We want to extend this to the path counts from c to all vertices outside the subtree of c.
  • The vertices outside the subtree of c are those outside the subtree of u, plus u itself and the subtrees of the other children of u.
    This is not hard to handle: our initial DFS already computed the path counts from u into its whole subtree, so we add those and then remove the part corresponding to child c, i.e., reverse the child-to-parent step for c.
  • Now we know the path counts from u to all N - \text{sz}[c] vertices outside the subtree of c. Moving them across the edge (u, c) of weight w is exactly the same transition as the child-to-parent step (with \text{sz}[c] replaced by N - \text{sz}[c]), so it also takes \mathcal{O}(32). Adding these to the subtree counts of c from the first DFS gives the answer for c.
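The rerooting step can be sketched the same way. This standalone illustration (hypothetical names `cross_edge` and `reroot_to_child`) starts from u's counts over all N vertices, which is how the setter's `DFS1` organizes it with its `temp` array: subtract c's contribution, push the remaining N - \text{sz}[c] vertices across the edge, and add c's own subtree counts.

```cpp
#include <array>
#include <cassert>
#include <cstdint>

using Counts = std::array<long long, 32>;

// Counts for a set of `cnt` vertices seen from one endpoint of an edge of
// weight w, re-expressed from the other endpoint: rotate by one bit, XOR with w.
Counts cross_edge(const Counts& ct, long long cnt, std::uint32_t w) {
    Counts res{};
    for (int b = 0; b < 32; ++b) {
        long long prev = ct[(b + 31) % 32];
        res[b] = ((w >> b) & 1u) ? cnt - prev : prev;
    }
    return res;
}

// full_u: counts from u to ALL n vertices; ct_c, sz_c: subtree counts and size
// of child c from the first DFS; w: weight of the edge (u, c).
// Returns the counts from c to all n vertices.
Counts reroot_to_child(const Counts& full_u, const Counts& ct_c, long long sz_c,
                       long long n, std::uint32_t w) {
    // 1. Undo c's contribution: what is left are the n - sz_c vertices
    //    outside the subtree of c, seen from u.
    Counts outside = full_u;
    Counts contrib = cross_edge(ct_c, sz_c, w);
    for (int b = 0; b < 32; ++b) outside[b] -= contrib[b];
    // 2. Move those vertices across the edge, so they are seen from c.
    Counts up = cross_edge(outside, n - sz_c, w);
    // 3. Add c's own subtree.
    Counts full_c = ct_c;
    for (int b = 0; b < 32; ++b) full_c[b] += up[b];
    return full_c;
}
```

On the path 0 - 1 - 2 with weights 1 and 2, rerooting from vertex 0 (two vertices with bit 0 set, one with bit 2 set) to vertex 1 yields one vertex with bit 0 set and one with bit 1 set, i.e. 1 + 2 = 3, and rerooting once more to vertex 2 yields a single vertex with bit 1 set, i.e. 2.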

We run two DFS passes, each in \mathcal{O}(32 \cdot N), which is fast enough since the sum of N over all test cases is at most 10^5.
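Putting the pieces together, here is a compact sketch of the whole algorithm. It assumes 0-indexed vertices, N \ge 1, and edges given as (u, v, w) triples; input parsing is omitted, and `solve_all` and `shift` are hypothetical names. Unlike the recursive solutions below, it processes vertices in BFS order, so a path-shaped tree with N = 10^5 vertices does not need 10^5 nested calls.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <tuple>
#include <utility>
#include <vector>

using std::uint32_t;
using Counts = std::array<long long, 32>;

// Counts for `cnt` vertices seen from one endpoint of an edge of weight w,
// re-expressed from the other endpoint (rotate by one bit, then XOR with w).
static Counts shift(const Counts& ct, long long cnt, uint32_t w) {
    Counts res{};
    for (int b = 0; b < 32; ++b) {
        long long prev = ct[(b + 31) % 32];
        res[b] = ((w >> b) & 1u) ? cnt - prev : prev;
    }
    return res;
}

// Returns, for every vertex v (0-indexed), the sum of d(v, i) over all i.
std::vector<long long> solve_all(int n, const std::vector<std::tuple<int, int, uint32_t>>& edges) {
    std::vector<std::vector<std::pair<int, uint32_t>>> adj(n);
    for (const auto& [u, v, w] : edges) {
        adj[u].push_back({v, w});
        adj[v].push_back({u, w});
    }
    // BFS from vertex 0: every parent appears before its children in `order`.
    std::vector<int> order{0}, par(n, -1);
    std::vector<uint32_t> pw(n, 0);  // weight of the edge to the parent
    for (std::size_t i = 0; i < order.size(); ++i) {
        int u = order[i];
        for (auto [v, w] : adj[u]) {
            if (v == par[u]) continue;
            par[v] = u;
            pw[v] = w;
            order.push_back(v);
        }
    }
    // Pass 1, children before parents: subtree counts and subtree sizes.
    std::vector<Counts> sub(n, Counts{});
    std::vector<long long> sz(n, 1);
    for (int i = n - 1; i > 0; --i) {
        int c = order[i], u = par[c];
        Counts add = shift(sub[c], sz[c], pw[c]);
        for (int b = 0; b < 32; ++b) sub[u][b] += add[b];
        sz[u] += sz[c];
    }
    // Pass 2, parents before children: counts from each vertex to all n vertices.
    std::vector<Counts> full(n, Counts{});
    full[0] = sub[0];
    for (int i = 1; i < n; ++i) {
        int c = order[i], u = par[c];
        Counts outside = full[u];
        Counts add = shift(sub[c], sz[c], pw[c]);
        for (int b = 0; b < 32; ++b) outside[b] -= add[b];
        Counts up = shift(outside, n - sz[c], pw[c]);
        full[c] = sub[c];
        for (int b = 0; b < 32; ++b) full[c][b] += up[b];
    }
    std::vector<long long> ans(n, 0);
    for (int v = 0; v < n; ++v)
        for (int b = 0; b < 32; ++b) ans[v] += full[v][b] << b;
    return ans;
}
```

On the path 0 - 1 - 2 with weights 1 and 2, it returns {6, 3, 2}, matching a direct evaluation of the definition.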

TIME COMPLEXITY:

\mathcal{O}(32\cdot N) per test case.

CODE:

Setter's code (C++)
#include <bits/stdc++.h>
#define ull long long int  // note: despite the name, a signed 64-bit integer
using namespace std;

#define maxlen 100005
#define nbits 32


vector<vector<pair<ull,ull>>> adj(maxlen+1);
vector<vector<ull>> bits(maxlen+1, vector<ull> (nbits, 0));
vector<ull> subtree(maxlen+1, 0);


void DFS(ull root, ull parent){
	ull sum = 0;
	for (ull i = 0; i < adj[root].size(); i++){
		if (adj[root][i].first != parent){
			DFS(adj[root][i].first, root);
			ull w = adj[root][i].second;
			sum += subtree[adj[root][i].first];
			// bit j of d(root, v) is bit j-1 (mod 32) of d(child, v), flipped if bit j of w is set
			for (ull j = 0; j < nbits; j++){
				ull b = ((w&(1LL<<j))!=0);
				if (b==0)bits[root][j]+=bits[adj[root][i].first][(j-1+nbits)%nbits];
				else bits[root][j]+=(subtree[adj[root][i].first] - bits[adj[root][i].first][(j-1+nbits)%nbits]);
			}
		}
	}
	subtree[root] = sum+1;
	
}

void DFS1(ull root, ull parent, ull w, ull n){
	ull temp[nbits] = {0};
	// temp: counts from parent to the vertices outside root's subtree
	// (the parent's full counts minus root's contribution from the first DFS)
	for (ull i = 0; i < nbits; i++){
		ull b = ((w&(1LL<<i))!=0);
		if (b==0) temp[i] = bits[parent][i] - bits[root][(i-1+nbits)%nbits];
		else temp[i] = bits[parent][i] - (subtree[root]- bits[root][(i-1+nbits)%nbits]);
	}
	// move those res = n - subtree[root] vertices across the edge (parent, root)
	ull res = n - subtree[root];
	for (ull j = 0; j < nbits; j++){
		ull b = ((w&(1LL<<j))!=0);
		if (b==0)bits[root][j]+=temp[(j-1+nbits)%nbits];
		else bits[root][j]+=(res - temp[(j-1+nbits)%nbits]);
	}
	for (ull i = 0; i < adj[root].size(); i++){
		if (adj[root][i].first!=parent){
			DFS1(adj[root][i].first, root, adj[root][i].second, n);
		}
	}
}


void solve (){
	ull n;
	cin>>n;
	for (ull i = 0; i <= n; i++){
		adj[i].clear();
		for (ull j = 0; j < nbits; j++)bits[i][j] = 0;
		subtree[i] = 0;
	}
	ull u, v, w;
	for (ull i = 0; i < n-1; i++){
		cin>>u>>v>>w;
		adj[u].push_back({v, w});
		adj[v].push_back({u, w});
	}
	DFS(1, 0);
	ull root = 1, parent = 0;
	for (ull i = 0; i < adj[root].size(); i++){
		if (adj[root][i].first!=parent){
			DFS1(adj[root][i].first, root, adj[root][i].second, n);
		}
	}
	for (ull i = 1; i <= n; i++){
		ull ans = 0;
		for (ull j =0 ; j < nbits; j++){
			ans += ((1LL<<j)*bits[i][j]);
		}
		cout<<ans<<" ";
	}
	cout<<"\n";
}


int main(){
    //ios_base::sync_with_stdio(false);
    //cin.tie(NULL);      
    int t;
    cin>>t;
    while(t--){
        solve();
    }
    return 0;
}
Tester's code (C++)
#include <bits/stdc++.h>
using namespace std;
#ifdef tabr
#include "library/debug.cpp"
#else
#define debug(...)
#endif

struct input_checker {
    string buffer;
    int pos;

    const string all = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
    const string number = "0123456789";
    const string upper = "ABCDEFGHIJKLMNOPQRSTUVWXYZ";
    const string lower = "abcdefghijklmnopqrstuvwxyz";

    input_checker() {
        pos = 0;
        while (true) {
            int c = cin.get();
            if (c == -1) {
                break;
            }
            buffer.push_back((char) c);
        }
    }

    string readOne() {
        assert(pos < (int) buffer.size());
        string res;
        while (pos < (int) buffer.size() && buffer[pos] != ' ' && buffer[pos] != '\n') {
            res += buffer[pos];
            pos++;
        }
        return res;
    }

    string readString(int min_len, int max_len, const string& pattern = "") {
        assert(min_len <= max_len);
        string res = readOne();
        assert(min_len <= (int) res.size());
        assert((int) res.size() <= max_len);
        for (int i = 0; i < (int) res.size(); i++) {
            assert(pattern.empty() || pattern.find(res[i]) != string::npos);
        }
        return res;
    }

    int readInt(int min_val, int max_val) {
        assert(min_val <= max_val);
        int res = stoi(readOne());
        assert(min_val <= res);
        assert(res <= max_val);
        return res;
    }

    long long readLong(long long min_val, long long max_val) {
        assert(min_val <= max_val);
        long long res = stoll(readOne());
        assert(min_val <= res);
        assert(res <= max_val);
        return res;
    }

    vector<int> readInts(int size, int min_val, int max_val) {
        assert(min_val <= max_val);
        vector<int> res(size);
        for (int i = 0; i < size; i++) {
            res[i] = readInt(min_val, max_val);
            if (i != size - 1) {
                readSpace();
            }
        }
        return res;
    }

    vector<long long> readLongs(int size, long long min_val, long long max_val) {
        assert(min_val <= max_val);
        vector<long long> res(size);
        for (int i = 0; i < size; i++) {
            res[i] = readLong(min_val, max_val);
            if (i != size - 1) {
                readSpace();
            }
        }
        return res;
    }

    void readSpace() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == ' ');
        pos++;
    }

    void readEoln() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == '\n');
        pos++;
    }

    void readEof() {
        assert((int) buffer.size() == pos);
    }
};

struct dsu {
    vector<int> p;
    vector<int> sz;
    int n;

    dsu(int _n) : n(_n) {
        p = vector<int>(n);
        iota(p.begin(), p.end(), 0);
        sz = vector<int>(n, 1);
    }

    inline int get(int x) {
        if (p[x] == x) {
            return x;
        } else {
            return p[x] = get(p[x]);
        }
    }

    inline bool unite(int x, int y) {
        x = get(x);
        y = get(y);
        if (x == y) {
            return false;
        }
        p[x] = y;
        sz[y] += sz[x];
        return true;
    }

    inline bool same(int x, int y) {
        return (get(x) == get(y));
    }

    inline int size(int x) {
        return sz[get(x)];
    }

    inline bool root(int x) {
        return (x == get(x));
    }
};

int main() {
    input_checker in;
    int tt = in.readInt(1, 10000);
    in.readEoln();
    int sn = 0;
    while (tt--) {
        int n = in.readInt(1, 1e5);
        in.readEoln();
        sn += n;
        dsu uf(n);
        vector<vector<pair<int, unsigned int>>> g(n);
        for (int i = 0; i < n - 1; i++) {
            int x = in.readInt(1, n);
            in.readSpace();
            int y = in.readInt(1, n);
            in.readSpace();
            int z = in.readInt(1, 1e9);
            in.readEoln();
            x--;
            y--;
            assert(uf.unite(x, y));
            g[x].emplace_back(y, z);
            g[y].emplace_back(x, z);
        }
        vector<int> sz(n, 1);
        vector<long long> a(n);
        vector<vector<int>> f(n, vector<int>(32));
        function<void(int, int)> Dfs = [&](int v, int p) {
            for (auto [to, w] : g[v]) {
                if (to == p) {
                    continue;
                }
                Dfs(to, v);
                auto t = f[to];
                rotate(t.begin(), t.begin() + 31, t.end());
                for (int i = 0; i < 32; i++) {
                    if (w & (1U << i)) {
                        f[v][i] += sz[to] - t[i];
                    } else {
                        f[v][i] += t[i];
                    }
                }
                sz[v] += sz[to];
            }
        };
        Dfs(0, -1);
        function<void(int, int)> Reroot = [&](int v, int p) {
            for (int i = 0; i < 32; i++) {
                a[v] += ((long long) f[v][i]) << i;
            }
            for (auto [to, w] : g[v]) {
                if (to == p) {
                    continue;
                }
                auto fv = f[v], ft = f[to];
                {
                    auto t = f[to];
                    rotate(t.begin(), t.begin() + 31, t.end());
                    for (int i = 0; i < 32; i++) {
                        if (w & (1U << i)) {
                            f[v][i] -= sz[to] - t[i];
                        } else {
                            f[v][i] -= t[i];
                        }
                    }
                }
                sz[v] -= sz[to];
                sz[to] += sz[v];
                {
                    auto t = f[v];
                    rotate(t.begin(), t.begin() + 31, t.end());
                    for (int i = 0; i < 32; i++) {
                        if (w & (1U << i)) {
                            f[to][i] += sz[v] - t[i];
                        } else {
                            f[to][i] += t[i];
                        }
                    }
                }
                Reroot(to, v);
                sz[to] -= sz[v];
                sz[v] += sz[to];
                f[v] = fv;
                f[to] = ft;
            }
        };
        Reroot(0, -1);
        for (int i = 0; i < n; i++) {
            cout << a[i] << " \n"[i == n - 1];
        }
    }
    assert(sn <= 1e5);
    in.readEof();
    return 0;
}
Editorialist's code (C++)
// #pragma GCC optimize("O3,unroll-loops")
// #pragma GCC target("avx2,bmi,bmi2,lzcnt,popcnt")
#include "bits/stdc++.h"
using namespace std;
using ll = long long int;
mt19937_64 rng(chrono::high_resolution_clock::now().time_since_epoch().count());

int main()
{
	ios::sync_with_stdio(false); cin.tie(0);

	int t; cin >> t;
	while (t--) {
		int n; cin >> n;
		vector<vector<array<ll, 2>>> adj(n);
		for (int i = 0; i < n-1; ++i) {
			int u, v, w; cin >> u >> v >> w;
			adj[--u].push_back({--v, w});
			adj[v].push_back({u, w});
		}

		vector<int> subsz(n);
		vector<array<ll, 32>> path_ct(n);
		vector<ll> ans(n);

		auto nxt = [&] (int x) {return (x+1) % 32;};
		auto dfs = [&] (const auto &self, int u, int p) -> void {
			auto it = find_if(begin(adj[u]), end(adj[u]), [&] (auto a) {return a[0] == p;});
			if (it != end(adj[u])) adj[u].erase(it);
			
			subsz[u] = 1;
			for (auto [v, w] : adj[u]) {
				self(self, v, u);
				subsz[u] += subsz[v];
				for (int b = 0; b < 32; ++b) {
					int b2 = nxt(b);
					if ((w >> b2) & 1) path_ct[u][b2] += subsz[v] - path_ct[v][b];
					else path_ct[u][b2] += path_ct[v][b];
				}
			}

			for (int b = 0; b < 32; ++b) {
				ll val = 1LL << b;
				ans[u] += val * path_ct[u][b];
			}
		};
		auto reroot = [&] (const auto &self, int u, auto from_up) -> void {

			for (int b = 0; b < 32; ++b) {
				ll val = 1LL << b;
				ans[u] += val * from_up[b];

				from_up[b] += path_ct[u][b];
			}

			for (auto [v, w] : adj[u]) {
				auto tmp = from_up;
				array<int, 32> send{};
				for (int b = 0; b < 32; ++b) {
					int b2 = nxt(b);
					if ((w >> b2) & 1) tmp[b2] -= subsz[v] - path_ct[v][b];
					else tmp[b2] -= path_ct[v][b];
					
					int b3 = nxt(b2);
					if ((w >> b3) & 1) send[b3] = (n - subsz[v]) - tmp[b2];
					else send[b3] = tmp[b2]; 
				}
				self(self, v, send);
			}
		};
		dfs(dfs, 0, 0);
		reroot(reroot, 0, array<int, 32>{});
		for (auto x : ans) cout << x << ' ';
		
		cout << '\n';
	}
}