C8KBFTREE - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: c8kbf
Testers: wuhudsm, satyam_343
Editorialist: iceknight1093

DIFFICULTY:

TBD

PREREQUISITES:

Pigeonhole principle

PROBLEM:

Given a weighted tree on N vertices, find two distinct paths with the same bitwise XOR, or report that no such pair exists.

EXPLANATION:

The weights A_i are between 0 and 2^{20}, which means that the bitwise XOR of any path is at most 2^{21}-1; i.e., there are at most 2^{21} distinct possible values for the bitwise XOR of a path.
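The bound holds because XOR never sets a bit that is clear in both of its operands. Every weight uses only bits 0 through 20, so by induction along the path the XOR of any number of weights does too. The bound is attained, for example by 2^{20} \oplus (2^{20}-1) = 2^{21}-1:

$$
0 \le A_i \le 2^{20} < 2^{21}, \qquad a, b < 2^{21} \implies a \oplus b < 2^{21}, \qquad \text{so } A_{e_1} \oplus A_{e_2} \oplus \cdots \oplus A_{e_k} \in \{0, 1, \ldots, 2^{21}-1\}.
$$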

Now, recall the pigeonhole principle: if the number of pigeons (paths) exceeds the number of holes (XOR values), then two pigeons must share a hole (two paths must share a XOR value).

In particular, this means that if we take 2^{21} + 1 different paths and compute their bitwise XORs, we will definitely find two paths with the same XOR.

This gives us a very simple solution: brute force!
Simply look at paths in the tree one at a time, each time computing its bitwise XOR.
You will either find two paths with equal bitwise XOR, or run out of paths.

The latter can only happen when the number of paths in the tree doesn't exceed 2^{21}. Either way, at most 2^{21}+1 paths are examined, so each test case needs on the order of 2^{21} operations, which is fast enough.

Computing the bitwise XOR of a given path is not hard, and can be done in a couple of ways.

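  • The simplest way is to do it on the fly. Fix one endpoint s of the path, then run a DFS/BFS from s to compute the bitwise XOR of the path from s to every other vertex, stopping immediately once a duplicate XOR value appears. Repeat this for every choice of s, and record each path only once (for example, only when the other endpoint is larger than s, as the setter's code does) so the same path is not counted twice from its two ends.
  • Alternatively, there's a more general technique. Fix a root of the tree, say vertex 1, and compute the bitwise XOR of the 1 \to u path for every u; call this value X_u.
    Then the bitwise XOR of the path between u and v is simply X_u \oplus X_v, since the edges shared by the 1 \to u and 1 \to v paths (those above the lowest common ancestor of u and v) appear twice and cancel out.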

TIME COMPLEXITY:

\mathcal{O}(\min(N^2, \text{MAX})) per test case, where \text{MAX} = 2^{21} for this problem.

CODE:

Setter's code (C++)

/*
 template by c8kbf
 */

// macOS doesn't have <bits/stdc++.h> (shame)
#include <cstdlib>

#include <iostream>
#include <cstdio>
#include <iomanip>
#include <fstream>

#include <cmath>
#include <cstring>
#include <ctime>

#include <deque>
#include <string>
#include <stack>
#include <vector>
#include <map>
#include <queue>
#include <list>
#include <set>
#include <unordered_map>
#include <unordered_set>
#include <bitset>

#include <algorithm>
#include <numeric>
#include <random>
#include <functional>

//dont worry bout me, i'm not high
#define ef else if
#define leave exit(0);

#define v(x) vector<x >
#define v2(x) vector<vector<x > >
#define v3(x) vector<vector<vector<x > > >

#define q(x) queue<x >
#define dq(x) deque<x >
#define s(x) set<x >
#define st(x) stack<x >
#define ms(x) multiset<x >
#define m(x, y) map<x , y >
#define b(x) bitset<x >
#define l(x) list<x >

#define ss(x) v(_)(x+8, 0)
#define ssz(type, x) v(type)(x+8, 0)
#define s2(x, y) v2(_)(x+8, v(_)(y+8, 0))
#define s2z(type, x, y) v2(type)(x+8, v(type)(y+8, 0))
#define s3(x, y, z) v3(_)(x+8, v2(_)(y+8, v(_)(z+8, 0)))
#define s3z(type, x, y, z) v3(type)(x+8, v2(type)(y+8, v(type)(z+8, 0)))
#define rd(a, sz) for(_ i = 1; i <= sz; ++i) a[i] = read();
#define wr(a, sz) for(_ i = 1; i <= sz; ++i) writesc(a[i]); clr();

#define i(x) x::iterator

#define pr(x, y) pair< x, y >
#define mp(x, y) make_pair(x, y)

using namespace std;

//weirdest typedefs ever??
typedef long long _;
typedef int _0;
typedef double _D;
typedef unsigned long long u_;
typedef string str;
typedef vector<_> v_;
typedef pair<_, _> _p;
typedef const long long constant;

//fastIO cos why not
inline _ read() {
    _ x = 0, f = 1;
    char ch = getchar();
    for(; !(ch >= '0' && ch <= '9'); ch = getchar()) if(ch == '-') f *= -1;
    for(; (ch >= '0' && ch <= '9'); ch = getchar()) x = (x<<3)+(x<<1)+(ch^48);
    return x*f;
}

inline bool read(_ & x, v(char) tl = {'\n', EOF}) {
    x = 0;
    _ f = 1;
    char ch = getchar();
    for(; !(ch >= '0' && ch <= '9'); ch = getchar()) if(ch == '-') f *= -1;
    for(; (ch >= '0' && ch <= '9'); ch = getchar()) x = (x<<3)+(x<<1)+(ch^48);
    x *= f;
    if(ch == '\r') ch = getchar();
    return !count(tl.begin(), tl.end(), ch);
}

inline void read(char * a, v(char) tl = {' ', '\n', '\r', '\t', '\0', EOF}, v(char) skp = {' ', '\n', '\r', '\t'}) {
    char ch = getchar();
    for(; count(skp.begin(), skp.end(), ch); ) ch = getchar();
    for(; !count(tl.begin(), tl.end(), ch); ch = getchar()) {
        *a = ch;
        ++a;
    }
    *a = '\0';
    return;
}

inline void read(str & a, v(char) tl = {' ', '\n', '\r', '\t', '\0', EOF}, v(char) skp = {' ', '\n', '\r', '\t'}) {
    a.clear();
    char ch = getchar();
    for(; count(skp.begin(), skp.end(), ch); ) ch = getchar();
    for(; !count(tl.begin(), tl.end(), ch); ch = getchar()) a += ch;
    return;
}

inline void read(vector<reference_wrapper<_> > a) {
    for(_ & i : a) i = read();
    return;
}

inline void read(_p & x) {
    x.first = read();
    x.second = read();
    return;
}

inline char getDg() {
    char ch = getchar();
    for(; !(ch >= '0' && ch <= '9'); ) ch = getchar();
    return ch;
}

inline char getLw() {
    char ch = getchar();
    for(; !(ch >= 'a' && ch <= 'z'); ) ch = getchar();
    return ch;
}

inline char getUp() {
    char ch = getchar();
    for(; !(ch >= 'A' && ch <= 'Z'); ) ch = getchar();
    return ch;
}

inline char getLtr() {
    char ch = getchar();
    for(; !((ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z')); ) ch = getchar();
    return ch;
}

inline char gc() {
    char ch = getchar();
    for(; ch == '\n' || ch == '\r' || ch == ' '; ) ch = getchar();
    return ch;
}

inline void write(_ x) {
    if(x < 0) {
        putchar('-');
        write(-x);
        return;
    }
    if(x > 9) write(x/10);
    putchar((x%10)^48);
    return;
}

inline void write(_0 x) {
    write((_)x);
    return;
}

inline void write(char const * a) {
    for(_ i = 0; a[i]; ++i) putchar(a[i]);
    return;
}

inline void write(const str a) {
    write(a.c_str());
    return;
}

inline void write(char ch) {
    putchar(ch);
    return;
}

inline void write(_p a, char const * b = " ") {
    write(a.first);
    write(b);
    write(a.second);
    return;
}

inline void write(v_ a, char const * b = " ") {
    bool fs = false;
    for(_ i : a) {
        if(!fs) fs = true;
        else write(b);
        write(i);
    }
    return;
}

inline void clr() {
    putchar(10);
    return;
}

inline void flsh(bool nl = true) {
    if(nl) clr();
    fflush(stdout);
    return;
}

inline void spc() {
    putchar(32);
    return;
}

template <class tp>
inline void writeln(tp x) {
    write(x);
    clr();
}

inline void writeln(_p a, char const * b = " ") {
    write(a, b);
    clr();
    return;
}

inline void writeln(v_ a, char const * b = " ") {
    write(a, b);
    clr();
    return;
}

template <class tp>
inline void writesc(tp x) {
    write(x);
    spc();
}

inline void writesc(_p a, char const * b = " ") {
    write(a);
    spc();
    return;
}

template <class tp>
inline void writeflsh(tp x, bool nl = true) {
    write(x);
    flsh(nl);
}

inline void writeflsh(_p a, char const * b = " ", bool nl = true) {
    write(a, b);
    flsh(nl);
    return;
}

inline void yes(_ a = 1) {
    write(a & 1 ? 'Y' : 'y');
    write(a & 2 ? 'E' : 'e');
    write(a & 4 ? 'S' : 's');
    clr();
    return;
}

inline void no(_ a = 1) {
    write(a & 1 ? 'N' : 'n');
    write(a & 2 ? 'O' : 'o');
    clr();
    return;
}

//loop systems
inline v_ rg(_ r, _ l = 1, _ d = 1) {
    v_ rv;
    for(_ i = l; i <= r; i += d) rv.push_back(i);
    return rv;
}

inline v_ dg(_ r, _ l = 1, _ d = -1) {
    v_ rv;
    for(_ i = r; i >= l; i += d) rv.push_back(i);
    return rv;
}

inline void AC();
int main(int argc, char * argv[]) {

    // freopen("/Users/ryanzhang/Dropbox/Problemsetting/Problems In Progress/Codechef - C8KBFTREE/data/3.in", "r", stdin);
    
//    #define file_IO
#ifdef file_IO
    str fileN = "";
    freopen((fileN+".in").c_str(), "r", stdin);
    freopen((fileN+".out").c_str(), "w", stdout);
#endif

    #define multiple_testcases
#ifdef multiple_testcases
    _ tc = read();
    for(; tc--; ) AC(); // good luck!
#else
    AC(); // good luck!
#endif

    return 0;
}

// ----- End of Template -----




constant maxn = 1E6+8;
constant maxm = (1<<21)+8; // a[] is indexed by path XOR, which can be as large as 2^21 - 1

_ n, x, y, z;
vector<_p> g[maxn];
_p a[maxm];
bool ok;

void dfs(_ x, _ fa, _ tp, _ vl);
inline void AC() {
    
    n = read();
    for(_ i = 1; i <= n; ++i) g[i].clear();
    for(_ i = 0; i <= maxm-1; ++i) a[i] = mp(-1, -1);
    for(_ i = 1; i <= n-1; ++i) {
        read({x, y, z});
        g[x].push_back(mp(y, z));
        g[y].push_back(mp(x, z));
    }
    ok = false;
    for(_ i = 1; i <= n; ++i) dfs(i, -1, i, 0);
    if(!ok) writeln(-1);
 
    return;
}

void dfs(_ x, _ fa, _ tp, _ vl) {
    if(ok) return;
    if(x > tp) {
        if(!~a[vl].first) a[vl] = mp(x, tp);
        else {
            ok = true;
            if(a[vl].first > a[vl].second) swap(a[vl].first, a[vl].second);
            if(x > tp) swap(x, tp);
            writeln({a[vl].first, a[vl].second, x, tp});
            return;
        }
    }
    for(_p i : g[x]) if(i.first != fa) {
        dfs(i.first, x, tp, vl^i.second);
        if(ok) return;
    }
    return;
}
Tester's code (C++)
#pragma GCC optimize("O3")
#pragma GCC target("avx,avx2,fma")
#pragma GCC optimize("Ofast,unroll-loops")
#include <bits/stdc++.h>   
#include <ext/pb_ds/tree_policy.hpp>
#include <ext/pb_ds/assoc_container.hpp>
using namespace __gnu_pbds;
using namespace std;
#define ll long long
const ll INF_ADD=1e18;
#define pb push_back               
#define mp make_pair        
#define nline "\n"                           
#define f first                                          
#define s second                                               
#define pll pair<ll,ll> 
#define all(x) x.begin(),x.end()     
#define vl vector<ll>       
#define vvl vector<vector<ll>>    
#define vvvl vector<vector<vector<ll>>>          
#ifndef ONLINE_JUDGE    
#define debug(x) cerr<<#x<<" "; _print(x); cerr<<nline;
#else
#define debug(x);  
#endif     
void _print(ll x){cerr<<x;}   
void _print(string x){cerr<<x;}     
mt19937 rng(chrono::steady_clock::now().time_since_epoch().count()); 
template<class T,class V> void _print(pair<T,V> p) {cerr<<"{"; _print(p.first);cerr<<","; _print(p.second);cerr<<"}";}
template<class T>void _print(vector<T> v) {cerr<<" [ "; for (T i:v){_print(i);cerr<<" ";}cerr<<"]";}
template<class T>void _print(set<T> v) {cerr<<" [ "; for (T i:v){_print(i); cerr<<" ";}cerr<<"]";}
template<class T>void _print(multiset<T> v) {cerr<< " [ "; for (T i:v){_print(i);cerr<<" ";}cerr<<"]";}
template<class T,class V>void _print(map<T, V> v) {cerr<<" [ "; for(auto i:v) {_print(i);cerr<<" ";} cerr<<"]";} 
typedef tree<ll, null_type, less<ll>, rb_tree_tag, tree_order_statistics_node_update> ordered_set;
typedef tree<ll, null_type, less_equal<ll>, rb_tree_tag, tree_order_statistics_node_update> ordered_multiset;
typedef tree<pair<ll,ll>, null_type, less<pair<ll,ll>>, rb_tree_tag, tree_order_statistics_node_update> ordered_pset;
//--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
const ll MOD=1e9+7;    
const ll MAX=500500;  
void solve(){               
    ll n; cin>>n;
    vector<vector<pair<ll,ll>>> adj(n+5); // standard C++ instead of a variable-length array
    for(ll i=1;i<n;i++){
        ll u,v,w; cin>>u>>v>>w;
        adj[u].push_back({v,w});
        adj[v].push_back({u,w});
    }  
    vector<ll> pref(n+5,-1);
    pref[1]=0;
    queue<ll> track; track.push(1);
    while(!track.empty()){
        auto it=track.front();
        track.pop();
        for(auto chld:adj[it]){
            if(pref[chld.f]==-1){  
                pref[chld.f]=pref[it]^chld.s;
                track.push(chld.f);
            }
        }
    } 
    map<ll,pair<ll,ll>> use;
    for(ll i=1;i<=n;i++){
        for(ll j=i+1;j<=n;j++){
            ll now=pref[i]^pref[j];
            if(use.find(now)==use.end()){
                use[now]={i,j};
            }
            else{
                auto it=use[now];
                cout<<it.f<<" "<<it.s<<" "<<i<<" "<<j<<nline;
                return;
            }
        }
    }
    cout<<"-1\n";
    return;          
}                                               
int main()                                                                            
{                              
    ios_base::sync_with_stdio(false);                            
    cin.tie(NULL);                       
    #ifndef ONLINE_JUDGE               
    freopen("input.txt", "r", stdin);                                           
    freopen("output.txt", "w", stdout);    
    freopen("error.txt", "w", stderr);                        
    #endif        
    ll test_cases=1;                 
    cin>>test_cases; 
    while(test_cases--){
        solve();
    }
    cout<<fixed<<setprecision(9);
    cerr<<"Time:"<<1000*((double)clock())/(double)CLOCKS_PER_SEC<<"ms\n"; 
} 

How is the O(N^2) brute force accepted in this question?
Why doesn't it get TLE?

How does the brute force solution work? In the worst case, when there are no matching paths, the complexity will be O(N^2).

@jjjjjj @abhinav_152
Please read through the editorial first; it explains why the ‘brute force’ is in fact fast enough.

It explains that if you take more than 2^{21} paths, you will definitely find an answer because of the pigeonhole principle.
So if N^2 \gt 2^{21}, you'll always find an answer well before actually reaching N^2 paths, as long as you break out as soon as you find one.

2^{21} is approximately 2\cdot 10^6, for reference.


Feeling bad: I didn't see the a<b and c<d part, and my solution failed on test cases 0 and 4. It took many wrong submissions to finally find the condition on the pairs :<

Is there any solution better than O(n^2)?

In the worst case, the number of paths would exceed 2^{21}, and hence there will be two paths with the same XOR.