REDZ - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: pols_agyi_pols
Tester & Editorialist: iceknight1093

DIFFICULTY:

9999

PREREQUISITES:

XOR basis, dfs, rerooting

PROBLEM:

You’re given a tree on N vertices, where vertex i has value A_i.
Answer Q queries of the following type:

  • Given x, Y, L, R: does there exist a subset of vertices whose XOR equals Y, such that every vertex u in the subset satisfies dist(x, u) \lt L or dist(x, u) \gt R?

EXPLANATION:

Since we need to check whether a subset with a given XOR exists, the natural approach is to build a XOR basis of all the valid elements.
(Here’s a tutorial on the concept if it’s new to you)
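For reference, here is a minimal sketch of a plain XOR basis over values below 2^{20}. That bound is an assumption chosen to match the 20-bit arrays in both solutions below, and the struct name XorBasis is hypothetical. slot[b] holds the basis element whose highest set bit is b.

```cpp
#include <array>
#include <cassert>

// Minimal XOR basis over values below 2^20 (hypothetical helper).
struct XorBasis {
    std::array<int, 20> slot{};  // slot[b] is 0 or an element whose highest set bit is b

    // Inserts x; returns true if x was independent of the current basis.
    bool insert(int x) {
        for (int b = 19; b >= 0; --b) {
            if (!(x >> b & 1)) continue;
            if (!slot[b]) { slot[b] = x; return true; }
            x ^= slot[b];  // eliminate bit b and keep reducing
        }
        return false;  // x reduced to 0: it was already in the span
    }

    // True if some subset of the inserted values XORs to y.
    bool inSpan(int y) const {
        for (int b = 19; b >= 0; --b)
            if ((y >> b & 1) && slot[b]) y ^= slot[b];
        return y == 0;  // any bit with no matching slot (or above bit 19) survives
    }
};
```

For example, after inserting 5 and 3, inserting 6 = 5 \oplus 3 returns false, and inSpan(6) is true while inSpan(8) is false.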

However, we do need to deal with the distance constraints.
For this, we maintain two separate bases for each vertex u: \text{near}[u] and \text{far}[u].
As their names suggest, \text{near}[u] is a basis in which we pick elements as close to u as possible, while \text{far}[u] picks them as far from u as possible.

If we can compute both of these, a query (Y, L, R) at vertex u can be answered as follows:

  • Consider all elements of \text{near}[u] whose distance to u is \lt L, and elements of \text{far}[u] whose distance to u is \gt R.
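  • Construct a new basis from only these elements, then check whether Y lies in its span. This is correct because the elements of \text{near}[u] at distance \lt L span the values of all vertices at distance \lt L from u, and likewise the elements of \text{far}[u] at distance \gt R span the values of all vertices at distance \gt R.
  • This takes \mathcal{O}((\log\max A)^2) time per query, which is fast enough.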
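A sketch of this query step follows. It assumes \text{near}[u] and \text{far}[u] are available as lists of (value, distance) pairs and that values fit in 20 bits. The names AnnotatedBasis and answerQuery are hypothetical.

```cpp
#include <array>
#include <cassert>
#include <utility>
#include <vector>

// Hypothetical representation of near[u] / far[u]: (value, distance from u) pairs.
using AnnotatedBasis = std::vector<std::pair<int, int>>;

// Keep near elements with distance < L and far elements with distance > R,
// build a fresh 20-bit XOR basis from them, and test whether Y is in its span.
bool answerQuery(const AnnotatedBasis& nearU, const AnnotatedBasis& farU,
                 int Y, int L, int R) {
    std::array<int, 20> slot{};
    auto insert = [&](int x) {
        for (int b = 19; b >= 0; --b) {
            if (!(x >> b & 1)) continue;
            if (!slot[b]) { slot[b] = x; return; }
            x ^= slot[b];
        }
    };
    for (auto [val, d] : nearU) if (d < L) insert(val);
    for (auto [val, d] : farU) if (d > R) insert(val);
    for (int b = 19; b >= 0; --b)
        if ((Y >> b & 1) && slot[b]) Y ^= slot[b];
    return Y == 0;
}
```

For instance, on the path 0-1-2 with values 3, 1, 2 and u = 0, one gets \text{near}[0] = \{(3,0),(1,1)\} and \text{far}[0] = \{(2,2),(1,1)\}. The query Y = 1, L = R = 1 is answered YES via 3 \oplus 2. \text{near}[0] alone could not certify this, since it has no element at distance \gt 1.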

Now, we focus on computing \text{near}[u] quickly for all u: \text{far}[u] can be done similarly.

First off, a natural algorithm to do this is as follows:

  • Let \text{near}[u] be empty initially.
  • Perform a bfs starting from node u, and each time you reach a new vertex v, try to insert A_v into \text{near}[u].

This way, we visit vertices in increasing order of distance, which gives exactly what we want: for any distance d, the basis built so far spans the values of all vertices at distance \leq d from u.
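A minimal sketch of this naive construction for a single root u follows. It assumes 0-indexed vertices, an adjacency-list tree, and values below 2^{20}. The function name nearByBfs and its output format, a list of (basis element, distance) pairs ordered by leading bit, are hypothetical.

```cpp
#include <array>
#include <cassert>
#include <queue>
#include <utility>
#include <vector>

// Naive near[u]: BFS from u visits vertices by nondecreasing distance, and each
// value is offered to a plain 20-bit XOR basis. An element that gets a slot
// remembers the distance of the vertex whose insertion created it.
std::vector<std::pair<int, int>> nearByBfs(const std::vector<std::vector<int>>& adj,
                                           const std::vector<int>& A, int u) {
    int n = adj.size();
    std::vector<int> dist(n, -1);
    std::array<int, 20> slot{}, slotDist{};
    std::queue<int> q;
    dist[u] = 0;
    q.push(u);
    while (!q.empty()) {
        int v = q.front(); q.pop();
        int x = A[v];
        for (int b = 19; b >= 0; --b) {
            if (!(x >> b & 1)) continue;
            if (!slot[b]) { slot[b] = x; slotDist[b] = dist[v]; break; }
            x ^= slot[b];
        }
        for (int w : adj[v])
            if (dist[w] == -1) { dist[w] = dist[v] + 1; q.push(w); }
    }
    std::vector<std::pair<int, int>> result;  // highest leading bit first
    for (int b = 19; b >= 0; --b)
        if (slot[b]) result.push_back({slot[b], slotDist[b]});
    return result;
}
```

On the path 0-1-2 with values 3, 1, 2, the call from vertex 0 yields \{(3,0),(1,1)\}: vertex 2's value 2 = 3 \oplus 1 adds nothing new.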

Doing this from every vertex takes \mathcal{O}(N^2\log\max A) time, which is too slow, so some optimization is required.
One way to do that is to use rerooting.

First, let’s consider only subtrees.
That is, root the tree arbitrarily and, for each vertex u, build \text{near}[u] only from the vertices in the subtree of u.
That can be done using a DFS, via the following process:

  • First, recursively compute \text{near}[v] for all children v of u.
  • Then, insert A_u into \text{near}[u] (at distance 0) and merge every \text{near}[v] into \text{near}[u], increasing each element's distance by 1, since it is one edge farther from u than from v.
    Distances need care here, so do the following:
    • As an initial step, keep the basis in row echelon form, as Gaussian elimination would produce it: each element has a distinct leading bit, and the elements are sorted in descending order.
      Also maintain, for each element, its distance from u.
    • When adding a new element x to the basis at distance d:
      • Go over bits from highest to lowest.
      • If x doesn’t have the current bit set, do nothing.
      • Otherwise, if no element is stored for this bit yet, store x there with distance d and stop.
      • Otherwise, if d is less than the distance of the element stored for this bit, replace that element with x; if not, perform the Gaussian elimination step (XOR x with the stored element) and continue.
      • If x does replace the element at the current bit, continue the insertion with the evicted element instead (since it might still be useful for lower bits).
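Here is a sketch of a single insertion following these rules. It uses the \mathcal{O}(\log\max A) variant: after a swap, XOR the evicted element with the new one and keep going downward. This mirrors the author's update and update2 functions. The struct DistBasis and the preferFar flag are hypothetical. With preferFar = false it builds \text{near}[u]; with preferFar = true it flips the comparison to build \text{far}[u].

```cpp
#include <array>
#include <cassert>
#include <utility>

// Distance-aware XOR basis over 20-bit values (hypothetical helper).
// slot[b] holds an element with highest set bit b, and dist[b] its distance.
struct DistBasis {
    std::array<int, 20> slot{}, dist{};
    bool preferFar = false;  // false: keep the closest element per bit (near)

    void insert(int x, int d) {
        for (int b = 19; b >= 0; --b) {
            if (!(x >> b & 1)) continue;
            if (!slot[b]) { slot[b] = x; dist[b] = d; return; }
            bool better = preferFar ? d > dist[b] : d < dist[b];
            if (better) {
                // x takes over this bit; carry on with the evicted element.
                std::swap(x, slot[b]);
                std::swap(d, dist[b]);
            }
            // Clear bit b. The carried element keeps the worse of the two
            // distances (larger for near, smaller for far), since it depends on both.
            x ^= slot[b];
        }
    }
};
```

Inserting (3, 2), (1, 1), (2, 0) into a near basis leaves 2 at distance 0 and 1 at distance 1. The closer value 2 takes over bit 1, and the evicted 3 reduces to 3 \oplus 2 \oplus 1 = 0.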

The last step can be done simply in \mathcal{O}((\log\max A)^2) by restarting the insertion process for the evicted element. It can also be implemented in \mathcal{O}(\log\max A) with a little care: XOR the evicted element with x to clear the current bit and keep going downward, as the author's code does.
Depending on which version you use, merging one basis into another takes either \mathcal{O}((\log\max A)^3) or \mathcal{O}((\log\max A)^2) time.
The latter is certainly fast enough; the former is likely still fine with a reasonable implementation.
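The subtree-only pass can be sketched as follows. It assumes 0-indexed vertices and a recursive DFS; deep trees may need an iterative version or a larger stack. NearBasis, absorb and subtreeNear are hypothetical names. insert uses the swap-and-XOR rule, and absorb adds a whole basis with every distance shifted.

```cpp
#include <array>
#include <cassert>
#include <utility>
#include <vector>

// Distance-aware near basis over 20-bit values (hypothetical helper).
struct NearBasis {
    std::array<int, 20> slot{}, dist{};
    void insert(int x, int d) {
        for (int b = 19; b >= 0; --b) {
            if (!(x >> b & 1)) continue;
            if (!slot[b]) { slot[b] = x; dist[b] = d; return; }
            if (d < dist[b]) { std::swap(x, slot[b]); std::swap(d, dist[b]); }
            x ^= slot[b];
        }
    }
    // Insert every element of o with its distance increased by shift.
    void absorb(const NearBasis& o, int shift) {
        for (int b = 19; b >= 0; --b)
            if (o.slot[b]) insert(o.slot[b], o.dist[b] + shift);
    }
};

// sub[u] = near basis of the subtree of u. Call with par = -1 at the root.
void subtreeNear(int u, int par, const std::vector<std::vector<int>>& adj,
                 const std::vector<int>& A, std::vector<NearBasis>& sub) {
    sub[u].insert(A[u], 0);              // u itself is at distance 0
    for (int v : adj[u]) if (v != par) {
        subtreeNear(v, u, adj, A, sub);  // children first
        sub[u].absorb(sub[v], 1);        // the child's elements are one edge farther
    }
}
```

Each absorb performs at most 20 insertions of \mathcal{O}(\log\max A) each, so one merge costs \mathcal{O}((\log\max A)^2).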

So far, \text{near}[u] only accounts for vertices in the subtree of u, but we also need to consider the vertices outside it.

This is exactly what rerooting accomplishes.
When moving from a vertex u to a child c, we need to add in everything outside the subtree of c: the vertices above u, u itself, and the subtrees of all the other children of u.
To do that quickly, precompute bases for each prefix and suffix of the children, so that each child needs only a constant number of merges instead of a pass over all of its siblings.
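The rerooting step for \text{near} can be sketched as follows. \text{far} is symmetric, with the comparison flipped. The sketch repeats the hypothetical NearBasis and subtreeNear helpers so that it compiles on its own. pre[i] holds the parent side, A_u and the first i children; suf[i] holds children i through k-1. Child i receives pre[i] merged with suf[i+1], shifted by one edge.

```cpp
#include <array>
#include <cassert>
#include <utility>
#include <vector>

using Tree = std::vector<std::vector<int>>;

// Distance-aware near basis over 20-bit values (hypothetical helper).
struct NearBasis {
    std::array<int, 20> slot{}, dist{};
    void insert(int x, int d) {
        for (int b = 19; b >= 0; --b) {
            if (!(x >> b & 1)) continue;
            if (!slot[b]) { slot[b] = x; dist[b] = d; return; }
            if (d < dist[b]) { std::swap(x, slot[b]); std::swap(d, dist[b]); }
            x ^= slot[b];
        }
    }
    // Insert every element of o with its distance increased by shift.
    void absorb(const NearBasis& o, int shift) {
        for (int b = 19; b >= 0; --b)
            if (o.slot[b]) insert(o.slot[b], o.dist[b] + shift);
    }
};

// Subtree-only pass: sub[u] covers the subtree of u (root call uses par = -1).
void subtreeNear(int u, int par, const Tree& adj, const std::vector<int>& A,
                 std::vector<NearBasis>& sub) {
    sub[u].insert(A[u], 0);
    for (int v : adj[u]) if (v != par) {
        subtreeNear(v, u, adj, A, sub);
        sub[u].absorb(sub[v], 1);
    }
}

// Rerooting pass. `up` is the near basis of every vertex outside the subtree
// of u, with distances measured from u; full[u] receives the complete basis.
void reroot(int u, int par, const NearBasis& up, const Tree& adj,
            const std::vector<int>& A, const std::vector<NearBasis>& sub,
            std::vector<NearBasis>& full) {
    full[u] = sub[u];
    full[u].absorb(up, 0);

    std::vector<int> ch;
    for (int v : adj[u]) if (v != par) ch.push_back(v);
    int k = ch.size();

    // pre[i]: up, A[u] and children ch[0..i-1]; suf[i]: children ch[i..k-1].
    // All distances are measured from u.
    std::vector<NearBasis> pre(k + 1), suf(k + 1);
    pre[0] = up;
    pre[0].insert(A[u], 0);
    for (int i = 0; i < k; ++i) {
        pre[i + 1] = pre[i];
        pre[i + 1].absorb(sub[ch[i]], 1);
    }
    for (int i = k - 1; i >= 0; --i) {
        suf[i] = suf[i + 1];
        suf[i].absorb(sub[ch[i]], 1);
    }

    for (int i = 0; i < k; ++i) {
        NearBasis outside = pre[i];     // everything outside subtree(ch[i]),
        outside.absorb(suf[i + 1], 0);  // measured from u
        for (int b = 0; b < 20; ++b)    // one more edge to reach ch[i]
            if (outside.slot[b]) ++outside.dist[b];
        reroot(ch[i], u, outside, adj, A, sub, full);
    }
}
```

Folding the parent side into pre[0] is a design choice of this sketch. The editorialist's code instead seeds each child's basis with the parent-side basis and adds the prefix and suffix into it, which amounts to the same thing. The basis elements may differ from a BFS-built basis while the distances agree. On the path 0-1-2 with values 3, 1, 2, full[1] stores 2 at distance 1 where BFS from vertex 1 stores 3 at distance 1.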

The overall complexity of this is \mathcal{O}(N(\log\max A)^2) or \mathcal{O}(N(\log\max A)^3) depending on implementation; after which each query is answered in \mathcal{O}((\log\max A)^2).

TIME COMPLEXITY:

\mathcal{O}((N+Q) (\log\max A)^2) per testcase.

CODE:

Author's code (C++)
#include <bits/stdc++.h>
using namespace std;
#define ll long long

vector <ll> v[100005];
vector <ll> child[100005];
ll a[100005];
ll bit_close[100005][20];
ll dist_close[100005][20];
ll bit_far[100005][20];
ll dist_far[100005][20];
ll dp[20];
ll dist_dp[20];
ll current_dp[20];
ll dist_current_dp[20];

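// Insert x at distance f into near[pos] (bit_close/dist_close): per leading bit, keep the closer element and continue with the evicted one.
void update(ll x,ll pos,ll f){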
    ll dist=f;
    for(int i=19;i>=0;i--){
        if(x&(1<<i)){
            if(!bit_close[pos][i]){
                bit_close[pos][i]=x;
                dist_close[pos][i]=dist;
                return;
            }
            if(dist_close[pos][i]>dist){
                swap(dist,dist_close[pos][i]);
                swap(x,bit_close[pos][i]);
            }
            x^=bit_close[pos][i];
        }
    }
    return;
}

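// Same as update, but into far[pos] (bit_far/dist_far): per leading bit, keep the farther element.
void update2(ll x,ll pos,ll f){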
    ll dist=f;
    for(int i=19;i>=0;i--){
        if(x&(1<<i)){
            if(!bit_far[pos][i]){
                bit_far[pos][i]=x;
                dist_far[pos][i]=dist;
                return;
            }
            if(dist_far[pos][i]<dist){
                swap(dist,dist_far[pos][i]);
                swap(x,bit_far[pos][i]);
            }
            x^=bit_far[pos][i];
        }
    }
    return;
}

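// Insert x into dp, the far basis of the vertices outside the current subtree (keeps the farther element).
void update3(ll x,ll f){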
    ll dist=f;
    for(int i=19;i>=0;i--){
        if(x&(1<<i)){
            if(!dp[i]){
                dp[i]=x;
                dist_dp[i]=dist;
                return;
            }
            if(dist_dp[i]<dist){
                swap(dist,dist_dp[i]);
                swap(x,dp[i]);
            }
            x^=dp[i];
        }
    }
    return;
}

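// Insert x into current_dp, a scratch far basis used to build the prefix/suffix bases over the children.
void update4(ll x,ll f){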
    ll dist=f;
    for(int i=19;i>=0;i--){
        if(x&(1<<i)){
            if(!current_dp[i]){
                current_dp[i]=x;
                dist_current_dp[i]=dist;
                return;
            }
            if(dist_current_dp[i]<dist){
                swap(dist,dist_current_dp[i]);
                swap(x,current_dp[i]);
            }
            x^=current_dp[i];
        }
    }
    return;
}

// First pass: near[pos] and far[pos] restricted to the subtree of pos.
void dfs(ll pos,ll par){
    update(a[pos],pos,0);
    update2(a[pos],pos,0);
    for(auto it:v[pos]){
        if(it==par){
            continue;
        }
        child[pos].push_back(it);
        dfs(it,pos);
        for(int i=19;i>=0;i--){
            if(bit_close[it][i]){
                update(bit_close[it][i],pos,dist_close[it][i]+1);
            }
        }
        for(int i=19;i>=0;i--){
            if(bit_far[it][i]){
                update2(bit_far[it][i],pos,dist_far[it][i]+1);
            }
        }
    }
    return;
}

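// Rerooting pass. On entry, dp is the far basis of the vertices outside the subtree of pos, with distances measured from par.
void dfs2(ll pos,ll par){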
    for(int i=19;i>=0;i--){
        if(dp[i]){
            dist_dp[i]++;
        }
    }
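    // Merge near[par] (+1) into near[pos] and dp into far[pos]. near[par] is already complete here; adding it with
    // distances +1 only overestimates vertices of pos's own subtree, which near[pos] already holds exactly.
    for(int i=19;i>=0;i--){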
        if(bit_close[par][i]){
            update(bit_close[par][i],pos,dist_close[par][i]+1);
        }
        if(dp[i]){
            update2(dp[i],pos,dist_dp[i]);
        }
    }
    ll x=a[pos],dist;
    update3(x,0);
    if(child[pos].size()==0){
        return;
    }
    ll past_dp[20];
    ll dist_past_dp[20];
    for(int i=0;i<=19;i++){
        past_dp[i]=dp[i];
        dist_past_dp[i]=dist_dp[i];
    }
    ll last[child[pos].size()+1][20];
    ll dist_last[child[pos].size()+1][20];
    for(int i=0;i<=19;i++){
        last[child[pos].size()][i]=0;
        dist_last[child[pos].size()][i]=0;
        current_dp[i]=0;
        dist_current_dp[i]=0;
    }
    for(int i=child[pos].size()-1;i>=0;i--){
        x=child[pos][i];
        for(int j=19;j>=0;j--){
            update4(bit_far[x][j],dist_far[x][j]+1);
        }
        for(int j=19;j>=0;j--){
            last[i][j]=current_dp[j];
            dist_last[i][j]=dist_current_dp[j];
        }
    }
    ll pre[child[pos].size()+1][20];
    ll dist_pre[child[pos].size()+1][20];
    for(int i=0;i<=19;i++){
        pre[0][i]=0;
        dist_pre[0][i]=0;
        current_dp[i]=0;
        dist_current_dp[i]=0;
    }
    for(int i=0;i<child[pos].size();i++){
        x=child[pos][i];
        for(int j=19;j>=0;j--){
            update4(bit_far[x][j],dist_far[x][j]+1);
        }
        for(int j=19;j>=0;j--){
            pre[i+1][j]=current_dp[j];
            dist_pre[i+1][j]=dist_current_dp[j];
        }
    }
    for(int i=0;i<child[pos].size();i++){
        for(int j=0;j<=19;j++){
            dp[j]=past_dp[j];
            dist_dp[j]=dist_past_dp[j];
        }
        for(int j=19;j>=0;j--){
            x=pre[i][j];
            dist=dist_pre[i][j];
            update3(x,dist);
        }
        for(int j=19;j>=0;j--){
            x=last[i+1][j];
            dist=dist_last[i+1][j];
            update3(x,dist);
        }
        x=child[pos][i];
        dfs2(x,pos);
    }
    return;
}

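// Answer one query at pos: keep near elements with distance < l and far elements with distance > r, then test whether x lies in their span.
void answer(ll pos,ll x,ll l,ll r){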
    for(int i=19;i>=0;i--){
        if(dist_close[pos][i]<l && bit_close[pos][i]){
            dp[i]=bit_close[pos][i];
        }else{
            dp[i]=0;
        }
    }
    ll y;
    for(int i=19;i>=0;i--){
        if(dist_far[pos][i]>r && bit_far[pos][i]){
            y=bit_far[pos][i];
            for(int j=19;j>=0;j--){
                if(y&(1<<j)){
                    if(!dp[j]){
                        dp[j]=y;
                        break;
                    }
                    y^=dp[j];
                }
            }
        }
    }
    for(int i=19;i>=0;i--){
        if(x&(1<<i)){
            if(!dp[i]){
                break;
            }
            x^=dp[i];
        }
    }
    if(x==0){
        cout<<"YES\n";
    }else{
        cout<<"NO\n";
    }
    return;
}

int main(){
	ios_base::sync_with_stdio(false);
	cin.tie(NULL);
    ll kitne_cases_hain;
    kitne_cases_hain=1;
    //freopen("input.txt","r",stdin);freopen("output.txt","w",stdout);
    cin>>kitne_cases_hain;
    while(kitne_cases_hain--){          
        ll n,q;
        cin>>n>>q;
        for(int i=0;i<=19;i++){
            dp[i]=0;
        }
        for(int i=1;i<=n;i++){
            v[i].clear();
            child[i].clear();
            for(int j=19;j>=0;j--){
                bit_close[i][j]=0;
                bit_far[i][j]=0;
            }
        }
        for(int i=1;i<=n;i++){
            cin>>a[i];
        }
        ll x,y,l,r;
        for(int i=1;i<n;i++){
            cin>>x>>y;
            v[x].push_back(y);
            v[y].push_back(x);
        }
        dfs(1,0);
        dfs2(1,0);
        while(q--){
            cin>>x>>y>>l>>r;
            answer(x,y,l,r);
        }
    }
	return 0;
}
Editorialist's code (C++)
// #pragma GCC optimize("O3,unroll-loops")
// #pragma GCC target("avx2,bmi,bmi2,lzcnt,popcnt")
#include "bits/stdc++.h"
using namespace std;
using ll = long long int;
mt19937_64 rng(chrono::high_resolution_clock::now().time_since_epoch().count());

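// Basis kept as a list of {value, distance} entries sorted by leading bit (descending).
// Add(x, dep, comp): comp = 0 prefers smaller distances (near), comp = 1 larger ones (far);
// an evicted entry is re-inserted from scratch.
struct Basis {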
	static const int SZ = 20;
	array<array<int, 2>, SZ> basis{};
	Basis() {}
	void Add(int x, int dep, bool comp = 0) {
        if (x == 0) return;
        int cur = x;
        for (int i = 0; i < SZ; ++i) {
            auto [elem, d] = basis[i];
            if (elem == 0) {
                basis[i] = {cur, dep};
                break;
            }
            int lead1 = __builtin_clz(cur), lead2 = __builtin_clz(elem);
            if (lead1 < lead2 or (lead1 == lead2 and ((comp == 0 and dep < d) or (comp == 1 and dep > d)))) {
                basis[i] = {cur, dep};
                Add(elem, d, comp);
                break;
            }
            cur = min(cur, cur ^ elem);
            if (cur == 0) break;
        }
	}
	void Merge(Basis &other, bool type) {
		for (const auto &y : other.basis) {
            if (y[0] == 0) break;
			Add(y[0], y[1]+1, type);
		}
	}
    bool inSpan(int x) {
        for (auto &y : basis) {
            if (y[0] == 0) break;
            x = min(x, x ^ y[0]);
        }
        return x == 0;
    }
};

int main()
{
    int t; cin >> t;
    while (t--) {
        int n, q; cin >> n >> q;

        vector<int> a(n);
        for (int &x : a) cin >> x;
        vector adj(n, vector<int>());
        for (int i= 0; i < n-1; ++i) {
            int u, v; cin >> u >> v;

            adj[--u].push_back(--v);
            adj[v].push_back(u);
        }

        vector queries(n, vector<array<int, 4>>());
        for (int i = 0; i < q; ++i) {
            int u, y, l, r; cin >> u >> y >> l >> r;
            if (y < (1 << 20)) queries[u-1].push_back({i, y, l, r});
        }

        vector<Basis> closest(n), farthest(n);
        auto init = [&] (const auto &init, int u, int par) -> void {
            closest[u].Add(a[u], 0, 0);
            farthest[u].Add(a[u], 0, 1);

            for (int v : adj[u]) if (v != par) {
                init(init, v, u);
                closest[u].Merge(closest[v], 0);
                farthest[u].Merge(farthest[v], 1);
            }
        };
        init(init, 0, 0);

        vector<int> ans(q);
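        // above[0] / above[1]: near / far bases of the vertices outside u's subtree, distances measured from par.
        auto answer = [&] (const auto &answer, int u, int par, array<Basis, 2> above) -> void {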
            if (u) {
                remove(begin(adj[u]), end(adj[u]), par);
                adj[u].pop_back();
            }
            auto close = closest[u]; close.Merge(above[0], 0);
            auto far = farthest[u]; far.Merge(above[1], 1);

            for (auto [id, target, l, r] : queries[u]) {
                Basis cur;
                for (auto &y : close.basis) if (y[1] < l) cur.Add(y[0], y[1]);
                for (auto &y : far.basis) if (y[1] > r) cur.Add(y[0], y[1]);
                ans[id] = cur.inSpan(target);
            }

            int ch = adj[u].size();
            for (int c : {0, 1}) {
                for (auto &y : above[c].basis) ++y[1];
                above[c].Add(a[u], 0, c);
            }
            vector<array<Basis, 2>> send(ch, above);
            Basis prefclose, preffar;
            for (int i = 1; i < ch; ++i) {
                prefclose.Merge(closest[adj[u][i-1]], 0);
                for (auto &y : prefclose.basis) send[i][0].Add(y[0], y[1]);
                preffar.Merge(farthest[adj[u][i-1]], 1);
                for (auto &y : preffar.basis) send[i][1].Add(y[0], y[1], 1);
            }
            prefclose = preffar = Basis();
            for (int i = ch-2; i >= 0; --i) {
                prefclose.Merge(closest[adj[u][i+1]], 0);
                for (auto &y : prefclose.basis) send[i][0].Add(y[0], y[1]);
                preffar.Merge(farthest[adj[u][i+1]], 1);
                for (auto &y : preffar.basis) send[i][1].Add(y[0], y[1], 1);
            }

            for (int i = 0; i < ch; ++i) answer(answer, adj[u][i], u, send[i]);
        };
        array<Basis, 2> above;
        answer(answer, 0, 0, above);

        for (auto x : ans) {
            cout << (x ? "YES" : "NO") << '\n';
        }
    }
}

This problem is a combination of AtCoder's H. Xor Query and Codeforces' E. The tree has fallen!