GGTREE - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: still_me
Testers: the_hyp0cr1t3, rivalq
Editorialist: iceknight1093

DIFFICULTY:

TBD

PREREQUISITES:

Tries, basic probability, DFS

PROBLEM:

There is a tree on N vertices; the i-th vertex has value A_i. Alice starts at vertex 1 and repeatedly moves to a uniformly random child of the current vertex until she can no longer do so.

Her score is the bitwise xor of all but one of the values she visited along the way, and she can choose which one to drop in order to maximize her score.
What's her expected final score?

EXPLANATION:

Applying the definition of expected value, we see that the answer is \sum P_u S_u, where:

  • P_u is the probability that Alice reaches u
  • S_u is the score obtained by Alice if she reaches u

and the above summation is taken across all terminal vertices (i.e., leaves) in the tree.
Let's compute them individually.
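As a small illustration (a hypothetical instance, not one of the official samples), take N = 3 with A = (1, 2, 3), where vertices 2 and 3 are both children of vertex 1. Each leaf is reached with probability 1/2, and the summation works out as follows. The final residue assumes the answer is reported modulo 10^9 + 7, as the reference solutions below do.

```latex
\begin{align*}
S_2 &= \max(A_2,\ A_1) = \max(2,\ 1) = 2 \\
S_3 &= \max(A_3,\ A_1) = \max(3,\ 1) = 3 \\
\sum_{u \in \{2, 3\}} P_u S_u &= \tfrac{1}{2} \cdot 2 + \tfrac{1}{2} \cdot 3 = \tfrac{5}{2}
  \equiv 500000006 \pmod{10^9 + 7}
\end{align*}
```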

Computing probabilities

Let P_u be the probability that Alice visits vertex u during her journey.
Since she starts at 1, we have P_1 = 1.

Now, consider some vertex u that has c children. Let v be one such child.
If we already know P_u, computing P_v is easy: it's just \frac{P_u}{c}. This is because Alice must reach u and then choose to go into v, which she does with probability \frac{1}{c}, so we multiply the probabilities together.

This gives us a way to compute P_u for every vertex u using DFS, from the root down to the leaves. The whole process takes \mathcal{O}(N \log{10^9}) time (the log factor appears because each division is a modular division, i.e. a multiplication by a modular inverse).
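The probability computation can be sketched as below. This is a minimal sketch, not the reference implementation: the names `power_mod` and `reach_probabilities` are made up for illustration, vertices are 0-indexed with the root at 0, and the modulus 10^9 + 7 is taken from the reference solutions further down. It uses an explicit stack instead of recursion, which is just one possible design choice.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

const int64_t MOD = 1000000007;  // modulus used by the reference solutions

// Fast exponentiation: base^exp modulo MOD.
int64_t power_mod(int64_t base, int64_t exp) {
    int64_t result = 1;
    base %= MOD;
    while (exp > 0) {
        if (exp & 1) result = result * base % MOD;
        base = base * base % MOD;
        exp >>= 1;
    }
    return result;
}

// Returns P[u] (mod MOD) for every vertex of a tree rooted at vertex 0,
// given as an undirected adjacency list. Hypothetical helper.
std::vector<int64_t> reach_probabilities(const std::vector<std::vector<int>>& adj) {
    int n = (int)adj.size();
    assert(n >= 1);
    std::vector<int64_t> prob(n, 0);
    std::vector<int> parent(n, -1);
    std::vector<int> stack = {0};  // explicit stack instead of recursion
    prob[0] = 1;                   // Alice always starts at the root
    while (!stack.empty()) {
        int u = stack.back();
        stack.pop_back();
        // Every neighbour except the parent is a child.
        int64_t children = (int64_t)adj[u].size() - (u != 0 ? 1 : 0);
        if (children == 0) continue;  // leaf: nothing to propagate
        // P_v = P_u / c, with the division done via Fermat's little theorem.
        int64_t share = prob[u] * power_mod(children, MOD - 2) % MOD;
        for (int v : adj[u]) {
            if (v == parent[u]) continue;
            parent[v] = u;
            prob[v] = share;
            stack.push_back(v);
        }
    }
    return prob;
}
```

Computing the inverse once per vertex rather than once per child keeps the number of modular exponentiations at most N.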

Computing score

Let u be a leaf vertex. Let's compute Alice's score when she reaches it.

Suppose the values on the path to u are x_1, \ldots, x_k. We want the maximum value of

x_1 \oplus \ldots \oplus x_{i-1} \oplus x_{i+1} \oplus \ldots \oplus x_k

across all i.
Notice that this is just

(x_1 \oplus \ldots \oplus x_{i-1} \oplus x_i \oplus x_{i+1} \oplus \ldots \oplus x_k) \oplus x_i

and the parenthesized value is a constant, being the xor-sum of all the values (xoring with x_i a second time cancels it, since x_i \oplus x_i = 0).

So, we have a list [x_1, \ldots, x_k] and a constant C, and we'd like to find the maximum possible value of C\oplus x_i over all i.
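The reduction can be checked on small inputs with a brute force. The sketch below is only for illustration; both function names are hypothetical. It compares the direct definition (drop each index and xor the rest) against the C \oplus x_i form, assuming all values are non-negative.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Direct definition: try every index to drop and xor the remaining values. O(k^2).
int best_score_direct(const std::vector<int>& path) {
    assert(!path.empty());
    int best = 0;
    for (std::size_t drop = 0; drop < path.size(); ++drop) {
        int value = 0;
        for (std::size_t i = 0; i < path.size(); ++i)
            if (i != drop) value ^= path[i];
        best = std::max(best, value);
    }
    return best;
}

// Using the identity: the xor of everything except x_i equals C ^ x_i,
// where C is the xor of the whole path. O(k).
int best_score_identity(const std::vector<int>& path) {
    assert(!path.empty());
    int total = 0;
    for (int x : path) total ^= x;
    int best = 0;
    for (int x : path) best = std::max(best, total ^ x);
    return best;
}
```

Neither version is fast enough to run at every leaf of a large tree, which is why the trie is needed.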

This is a well-known problem, and can be solved using a trie: a tutorial can be found here.
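For readers who have not seen the technique, here is a minimal sketch of a binary trie that answers "maximum of x xor a over all stored a". The struct name `XorTrie` and its method names are made up for this illustration, and it assumes all values are below 2^30. Each node stores how many values pass through it, so a value can be removed by decrementing counts rather than freeing nodes.

```cpp
#include <array>
#include <cassert>
#include <vector>

const int TOP_BIT = 29;  // assumption: all values are below 2^30

struct XorTrie {
    std::vector<std::array<int, 2>> child;  // child[node][bit], -1 if absent
    std::vector<int> count;                 // stored values passing through node

    XorTrie() : child(1, {-1, -1}), count(1, 0) {}

    // Adds delta (+1 or -1) copies of value along its root-to-leaf path.
    void add(int value, int delta) {
        int node = 0;
        count[node] += delta;
        for (int bit = TOP_BIT; bit >= 0; --bit) {
            int b = (value >> bit) & 1;
            if (child[node][b] == -1) {
                child[node][b] = (int)child.size();
                child.push_back({-1, -1});
                count.push_back(0);
            }
            node = child[node][b];
            count[node] += delta;
        }
    }

    void insert(int value) { add(value, +1); }
    void erase(int value) { add(value, -1); }  // assumes value is present

    // Maximum of (x ^ a) over all values a currently stored.
    int max_xor(int x) const {
        assert(count[0] > 0);  // the trie must be non-empty
        int node = 0, best = 0;
        for (int bit = TOP_BIT; bit >= 0; --bit) {
            int b = (x >> bit) & 1;
            int want = child[node][b ^ 1];  // the opposite bit sets this bit in the xor
            if (want != -1 && count[want] > 0) {
                best |= 1 << bit;
                node = want;
            } else {
                node = child[node][b];
            }
        }
        return best;
    }
};
```

The query greedily prefers the opposite bit at every level, from the most significant bit down, because a higher set bit outweighs all lower bits combined.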

However, the values in the trie keep changing, and we can't afford to insert the entire path each time we process a leaf.
Instead, we maintain a single trie and reuse it across our DFS, as follows:

  • Let T denote our trie. Initially, it's empty.
  • When we enter node u, insert A_u into T.
  • Then, if u is a leaf, perform the relevant query; if not, continue the DFS into the children of u.
  • Finally, when exiting u, remove (one copy of) A_u from T.

This ensures we perform only N insertions and N deletions across the whole process, each taking \mathcal{O}(\log{10^9}) time, making the time complexity \mathcal{O}(N\log{10^9}).
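Putting the pieces together, the whole procedure can be sketched in one DFS. This is a condensed illustration rather than any of the reference solutions: `CountingTrie`, `power_mod` and `expected_score` are hypothetical names, vertices are 0-indexed with the root at 0, values are assumed to be below 2^30, and the modulus 10^9 + 7 is the one used in the code below. The recursion is as deep as the tree, so very deep trees may need a larger stack.

```cpp
#include <array>
#include <cstdint>
#include <functional>
#include <vector>

const int64_t MOD = 1000000007;

int64_t power_mod(int64_t base, int64_t exp) {
    int64_t result = 1;
    for (base %= MOD; exp > 0; exp >>= 1, base = base * base % MOD)
        if (exp & 1) result = result * base % MOD;
    return result;
}

struct CountingTrie {
    std::vector<std::array<int, 2>> child{{-1, -1}};  // starts with the root node
    std::vector<int> count{0};
    void add(int value, int delta) {  // delta = +1 inserts, -1 removes one copy
        int node = 0;
        count[node] += delta;
        for (int bit = 29; bit >= 0; --bit) {
            int b = (value >> bit) & 1;
            if (child[node][b] == -1) {
                child[node][b] = (int)child.size();
                child.push_back({-1, -1});
                count.push_back(0);
            }
            node = child[node][b];
            count[node] += delta;
        }
    }
    int max_xor(int x) const {  // requires a non-empty trie
        int node = 0, best = 0;
        for (int bit = 29; bit >= 0; --bit) {
            int b = (x >> bit) & 1, want = child[node][b ^ 1];
            if (want != -1 && count[want] > 0) { best |= 1 << bit; node = want; }
            else node = child[node][b];
        }
        return best;
    }
};

// Expected score modulo MOD for a tree rooted at vertex 0.
int64_t expected_score(const std::vector<int>& a,
                       const std::vector<std::vector<int>>& adj) {
    CountingTrie trie;
    int64_t answer = 0;
    std::function<void(int, int, int, int64_t)> dfs =
        [&](int u, int parent, int prefix, int64_t prob) {
            trie.add(a[u], +1);  // entering u
            prefix ^= a[u];      // xor of the whole path so far
            int children = (int)adj[u].size() - (u != 0 ? 1 : 0);
            if (children == 0) {
                // Leaf: add P_u * S_u to the answer.
                answer = (answer + prob * trie.max_xor(prefix)) % MOD;
            } else {
                int64_t share = prob * power_mod(children, MOD - 2) % MOD;
                for (int v : adj[u])
                    if (v != parent) dfs(v, u, prefix, share);
            }
            trie.add(a[u], -1);  // leaving u
        };
    dfs(0, -1, 0, 1);
    return answer;
}
```

For example, with values (1, 2, 3) and both other vertices attached to the root, the two leaves score 2 and 3, so the function returns 5/2 modulo 10^9 + 7.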

Once both the probabilities and the scores have been computed, calculating the final answer is trivial using the summation above.

TIME COMPLEXITY:

\mathcal{O}(N\log{10^9}) per testcase.

CODE:

Setter's code (C++)
//	Code by Sahil Tiwari (still_me)

#include<bits/stdc++.h>
#define still_me main
#define endl "\n"
#define int long long int
#define all(a) (a).begin() , (a).end()
#define print(a) for(auto TEMPORARY: a) cout<<TEMPORARY<<" ";cout<<endl;
#define tt int TESTCASE;cin>>TESTCASE;while(TESTCASE--)
#define arrin(a,n) for(int INPUT=0;INPUT<n;INPUT++)cin>>a[INPUT]

using namespace std;
const int mod = 1e9+7;
const int inf = 1e18;

long long power(long long a , long long b , long long mod){
    if(b==0)
        return 1;
    long long res = power(a , b/2 , mod);
    res = res*res%mod;
    if(b%2)
        res = res*a % mod;
    return res;
}

int inverse(int a){
    return power(a , mod-2 , mod);
}
map<int,int> p;
int cnt = 0;
void dfs(vector<vector<int>> &adj , int j , int prev , int prob) {
    if(adj[j].size() == 1 && j != 0) {
        p[j] = prob;
        return;
    }
    for(int &i: adj[j]) {
        if(i == prev)
            continue;
        dfs(adj , i , j , prob * inverse(adj[j].size() - (j == 0 ? 0 : 1)) % mod);
    }
}
int ans = 0;
struct Trie{
	vector<array<int, 2>> node;
	vector<int> last;
	vector<pair<int, int>> bck;
	Trie() {
		node.push_back({-1, -1});
		last.push_back(-1);
		bck.push_back({-1, -1});
	}
	void insert(int val, int n) {
		int cur = 0;
		for(int i = 29 ; i >= 0 ; i--) {
			int p = (val >> i) & 1;
			if(node[cur][p] == -1) {
				node[cur][p] = node.size();
				node.push_back({-1, -1});
				last.push_back(n);
				bck.push_back({cur, p});
			}
			cur = node[cur][p];
		}
	}
	void Delete(int n) {
		while(last.back() == n) {
			node[bck.back().first][bck.back().second] = -1;
			bck.pop_back();
			last.pop_back(), node.pop_back();
		}
	}
	int query(int v) {
		int cur = 0, ans = 0;
		for(int i = 29 ; i >= 0 ; i--) {
			int p = (v >> i) & 1;
			if(node[cur][1 ^ p] > 0) 
				ans ^= 1 << i, cur = node[cur][1 ^ p];
			else cur = node[cur][p];
		}
		return ans;
	}
};

void tdfs(vector<vector<int>> &adj , vector<int> &a, int j , int prev, int curr, Trie &T) {
    T.insert(a[j] , j);
    curr ^= a[j];
    // cout<<curr<<endl;
    if(adj[j].size() == 1 && j != 0) {
        ans += p[j] * T.query(curr);
        ans %= mod;
    }
    for(int &i: adj[j]) {
        if(i == prev)
            continue;
        tdfs(adj , a , i , j , curr, T);
    }
    T.Delete(j);
}

void chal_bsdk() {
    p.clear();
    ans = 0;
    int n;
    cin>>n;
    vector<int> a(n);
    arrin(a , n);
    vector<vector<int>> adj(n);
    for(int i=0;i<n-1;i++) {
        int u , v;
        cin>>u>>v;
        u--;v--;
        adj[u].push_back(v);
        adj[v].push_back(u);
    }

    dfs(adj , 0 , 0 , 1);
    Trie T;
    tdfs(adj , a , 0 , 0 , 0 , T);
    cout<<ans<<endl;

}

signed still_me()
{
    ios_base::sync_with_stdio(false);cin.tie(NULL);cout.tie(NULL);

    // freopen("15.in" , "r" , stdin);
    // freopen("15.out" , "w" , stdout);
    tt{
        chal_bsdk();
    }
    return 0;
}

Tester's code (C++)
// Jai Shree Ram  
  
#include<bits/stdc++.h>
using namespace std;

#define rep(i,a,n)     for(int i=a;i<n;i++)
#define ll             long long
#define int            long long
#define pb             push_back
#define all(v)         v.begin(),v.end()
#define endl           "\n"
#define x              first
#define y              second
#define gcd(a,b)       __gcd(a,b)
#define mem1(a)        memset(a,-1,sizeof(a))
#define mem0(a)        memset(a,0,sizeof(a))
#define sz(a)          (int)a.size()
#define pii            pair<int,int>
#define hell           1000000007
#define elasped_time   1.0 * clock() / CLOCKS_PER_SEC



template<typename T1,typename T2>istream& operator>>(istream& in,pair<T1,T2> &a){in>>a.x>>a.y;return in;}
template<typename T1,typename T2>ostream& operator<<(ostream& out,pair<T1,T2> a){out<<a.x<<" "<<a.y;return out;}
template<typename T,typename T1>T maxs(T &a,T1 b){if(b>a)a=b;return a;}
template<typename T,typename T1>T mins(T &a,T1 b){if(b<a)a=b;return a;}

// -------------------- Input Checker Start --------------------
 
long long readInt(long long l, long long r, char endd)
{
    long long x = 0;
    int cnt = 0, fi = -1;
    bool is_neg = false;
    while(true)
    {
        char g = getchar();
        if(g == '-')
        {
            assert(fi == -1);
            is_neg = true;
            continue;
        }
        if('0' <= g && g <= '9')
        {
            x *= 10;
            x += g - '0';
            if(cnt == 0)
                fi = g - '0';
            cnt++;
            assert(fi != 0 || cnt == 1);
            assert(fi != 0 || is_neg == false);
            assert(!(cnt > 19 || (cnt == 19 && fi > 1)));
        }
        else if(g == endd)
        {
            if(is_neg)
                x = -x;
            if(!(l <= x && x <= r))
            {
                cerr << l << ' ' << r << ' ' << x << '\n';
                assert(false);
            }
            return x;
        }
        else
        {
            assert(false);
        }
    }
}
 
string readString(int l, int r, char endd)
{
    string ret = "";
    int cnt = 0;
    while(true)
    {
        char g = getchar();
        assert(g != -1);
        if(g == endd)
            break;
        cnt++;
        ret += g;
    }
    assert(l <= cnt && cnt <= r);
    return ret;
}
 
long long readIntSp(long long l, long long r) { return readInt(l, r, ' '); }
long long readIntLn(long long l, long long r) { return readInt(l, r, '\n'); }
string readStringLn(int l, int r) { return readString(l, r, '\n'); }
string readStringSp(int l, int r) { return readString(l, r, ' '); }
void readEOF() { assert(getchar() == EOF); }
 
vector<int> readVectorInt(int n, long long l, long long r)
{
    vector<int> a(n);
    for(int i = 0; i < n - 1; i++)
        a[i] = readIntSp(l, r);
    a[n - 1] = readIntLn(l, r);
    return a;
}
 
// -------------------- Input Checker End --------------------

struct node{
   node* sons[2];
   int cnt=0;
};
node* create(){
       node* temp=new node();
       temp->sons[0]=NULL;
       temp->sons[1]=NULL;
       temp->cnt=0;
       return temp;
} 
template<typename node>
struct trie{
   node* root=new node();
   void insert(int p){
       node* temp=root;
       for(int j=30;j>=0;j--){
           temp->cnt++;
           int k=(((1LL<<j)&p)!=0);
           if(temp->sons[k]==NULL){
               temp->sons[k]=create();
               temp=temp->sons[k];
           }
           else{
               temp=temp->sons[k];
           } 
       }
       temp->cnt++;
   }
   void erase(int p){
       node* temp=root;
       for(int j=30;j>=0;j--){
           temp->cnt--;
           if(p&(1<<j))temp=temp->sons[1];
           else temp=temp->sons[0];
       }
       temp->cnt--;
   }
   int query(int x){
       node* temp = root;
       int ans = 0;
       for(int j = 30; j >= 0; j--){
	       	if((1 << j) & x){
	       		if(temp->sons[0] and temp->sons[0]->cnt) {
	       			ans += 1 << j;
	       			temp = temp -> sons[0];
	       		}else{
	       			temp = temp -> sons[1];
	       		}
	       	}else{
	       		if(temp->sons[1] and temp->sons[1]->cnt) {
	       			ans += 1 << j;
	       			temp = temp -> sons[1];
	       		}else{
	       			temp = temp -> sons[0];
	       		}
	       	}
       }	
       return ans;
      // function of query
   }
};
const int maxn  = 1e5 + 5;
int p[maxn];
int sz[maxn];
void clear(int n=maxn){
    rep(i,0,n + 1)p[i]=i,sz[i]=1;
}
int root(int x){
   while(x!=p[x]){
       p[x]=p[p[x]];
       x=p[x];
   }
   return x;  
}
void merge(int x,int y){
    int p1=root(x);
    int p2=root(y);
    if(p1==p2)return;
    if(sz[p1]>=sz[p2]){
        p[p2]=p1;
        sz[p1]+=sz[p2];
    }
    else{
        p[p1]=p2;
        sz[p2]+=sz[p1];
    }
}

const int MOD = hell;
 
struct mod_int {
    int val;
 
    mod_int(long long v = 0) {
        if (v < 0)
            v = v % MOD + MOD;
 
        if (v >= MOD)
            v %= MOD;
 
        val = v;
    }
 
    static int mod_inv(int a, int m = MOD) {
        int g = m, r = a, x = 0, y = 1;
 
        while (r != 0) {
            int q = g / r;
            g %= r; swap(g, r);
            x -= q * y; swap(x, y);
        }
 
        return x < 0 ? x + m : x;
    }
 
    explicit operator int() const {
        return val;
    }
 
    mod_int& operator+=(const mod_int &other) {
        val += other.val;
        if (val >= MOD) val -= MOD;
        return *this;
    }
 
    mod_int& operator-=(const mod_int &other) {
        val -= other.val;
        if (val < 0) val += MOD;
        return *this;
    }
 
    static unsigned fast_mod(uint64_t x, unsigned m = MOD) {
           #if !defined(_WIN32) || defined(_WIN64)
                return x % m;
           #endif
           unsigned x_high = x >> 32, x_low = (unsigned) x;
           unsigned quot, rem;
           asm("divl %4\n"
            : "=a" (quot), "=d" (rem)
            : "d" (x_high), "a" (x_low), "r" (m));
           return rem;
    }
 
    mod_int& operator*=(const mod_int &other) {
        val = fast_mod((uint64_t) val * other.val);
        return *this;
    }
 
    mod_int& operator/=(const mod_int &other) {
        return *this *= other.inv();
    }
 
    friend mod_int operator+(const mod_int &a, const mod_int &b) { return mod_int(a) += b; }
    friend mod_int operator-(const mod_int &a, const mod_int &b) { return mod_int(a) -= b; }
    friend mod_int operator*(const mod_int &a, const mod_int &b) { return mod_int(a) *= b; }
    friend mod_int operator/(const mod_int &a, const mod_int &b) { return mod_int(a) /= b; }
 
    mod_int& operator++() {
        val = val == MOD - 1 ? 0 : val + 1;
        return *this;
    }
 
    mod_int& operator--() {
        val = val == 0 ? MOD - 1 : val - 1;
        return *this;
    }
 
    mod_int operator++(int32_t) { mod_int before = *this; ++*this; return before; }
    mod_int operator--(int32_t) { mod_int before = *this; --*this; return before; }
 
    mod_int operator-() const {
        return val == 0 ? 0 : MOD - val;
    }
 
    bool operator==(const mod_int &other) const { return val == other.val; }
    bool operator!=(const mod_int &other) const { return val != other.val; }
 
    mod_int inv() const {
        return mod_inv(val);
    }
 
    mod_int pow(long long p) const {
        assert(p >= 0);
        mod_int a = *this, result = 1;
 
        while (p > 0) {
            if (p & 1)
                result *= a;
 
            a *= a;
            p >>= 1;
        }
 
        return result;
    }
 
    friend ostream& operator<<(ostream &stream, const mod_int &m) {
        return stream << m.val;
    }
    friend istream& operator >> (istream &stream, mod_int &m) {
        return stream>>m.val;   
    }
};

int solve(){
 		int n = readIntLn(1,1e5);
 		static int sum_n = 0;
 		sum_n += n;

 		assert(sum_n <= 1e5);

 		vector<vector<int>> g(n + 1);
 		vector<int> a = readVectorInt(n,1,1e9);

 		clear(n + 1);

 		for(int i = 2; i <= n; i++){
 			int u = readIntSp(1,n);
 			int v = readIntLn(1,n);

 			assert(root(u) != root(v));
 			merge(u,v);

 			g[u].push_back(v);
 			g[v].push_back(u);
 		}
 		mod_int ans = 0;
 		trie<node> tr;
 		function<void(int,int,int, mod_int)> dfs = [&](int u,int v,int xor_val, mod_int p){
 			
 			tr.insert(a[u - 1]);
 			xor_val ^= a[u - 1];
 			mod_int childs = g[u].size();
 			if(u != 1) childs--;

 			if(childs == 0){
 				ans += p * tr.query(xor_val);
 			}else{
 				p *= childs.inv();

	 			for(auto i: g[u]){
	 				if(i != v){
	 					dfs(i,u,xor_val,p);
	 				}
	 			}
 			}
 			tr.erase(a[u - 1]);
 			xor_val ^= a[u - 1];
 		};
 		dfs(1,1,0,1);
 		cout << ans << endl;




 return 0;
}
signed main(){
    ios_base::sync_with_stdio(0);cin.tie(0);cout.tie(0);
    //freopen("input.txt", "r", stdin);
    //freopen("output.txt", "w", stdout);
    #ifdef SIEVE
    sieve();
    #endif
    #ifdef NCR
    init();
    #endif
    int t = readIntLn(1,10000);
    while(t--){
        solve();
    }
    return 0;
}

Editorialist's code (C++)
#include "bits/stdc++.h"
// #pragma GCC optimize("O3,unroll-loops")
// #pragma GCC target("avx2,bmi,bmi2,lzcnt,popcnt")
using namespace std;
using ll = long long int;
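mt19937_64 rng(chrono::high_resolution_clock::now().time_since_epoch().count());

// Minimal stand-in for the Zp modint class that was omitted from this listing.
// This is an assumption added so the code compiles, not the original template.
struct Zp {
	static constexpr ll MOD = 1000000007;
	ll val;
	Zp(ll v = 0) : val((v % MOD + MOD) % MOD) {}
	Zp inv() const { // Fermat's little theorem, valid because MOD is prime
		Zp r = 1, b = *this;
		for (ll e = MOD - 2; e > 0; e >>= 1, b *= b) if (e & 1) r *= b;
		return r;
	}
	Zp& operator+=(const Zp &o) { val = (val + o.val) % MOD; return *this; }
	Zp& operator*=(const Zp &o) { val = val * o.val % MOD; return *this; }
	Zp& operator/=(const Zp &o) { return *this *= o.inv(); }
	friend Zp operator*(Zp a, const Zp &b) { return a *= b; }
	friend ostream& operator<<(ostream &os, const Zp &z) { return os << z.val; }
};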

struct Trie {
	vector<int> v;
	vector<array<int, 2>> ch;
	int id = 0;
	Trie() : v(1, 0), ch(1, {-1, -1}) {}
	void create() {
		v.push_back(0);
		ch.push_back({-1, -1});
		++id;
	}
	void add(int x, int dif) {
		int node = 0;
		for (int bit = 30; bit >= 0; --bit) {
			int b = (x >> bit) & 1;
			v[node] += dif;
			if (ch[node][b] == -1) {
				create();
				ch[node][b] = id;
			}
			node = ch[node][b];
		}
		v[node] += dif;
	}
	int query (int x) { // Maximum value of a^x for a in the trie
		int node = 0, ret = 0;
		for (int bit = 30; bit >= 0; --bit) {
			int b = (x >> bit) & 1;
			if (ch[node][b^1] == -1 or v[ch[node][b^1]] == 0) node = ch[node][b];
			else {
				ret += 1 << bit;
				node = ch[node][b^1];
			}
		}
		return ret;
	}
};

int main()
{
	ios::sync_with_stdio(false); cin.tie(0);

	int t; cin >> t;
	while (t--) {
		int n; cin >> n;
		vector<int> a(n);
		for (int i = 0; i < n; ++i) cin >> a[i];
		vector<vector<int>> adj(n);
		for (int i = 0; i < n-1; ++i) {
			int u, v; cin >> u >> v;
			adj[--u].push_back(--v);
			adj[v].push_back(u);
		}

		Zp ans = 0; // Zp is a modint class, I removed the template to allow for easier reading
		Trie T;
		auto dfs = [&] (const auto &self, int u, int p, int pref, Zp prob) -> void {
			T.add(a[u], 1);
			int children = adj[u].size() - (u > 0);
			if (children) prob /= children;
			pref ^= a[u];
			for (int v : adj[u]) {
				if (v == p) continue;
				self(self, v, u, pref, prob);
			}
			if (children == 0) ans += prob * T.query(pref);
			T.add(a[u], -1);
		};
		dfs(dfs, 0, 0, 0, 1);
		cout << ans << '\n';
	}
}

Was using a map with masking instead of Trie supposed to give TLE?
Also were there some weird time or space constraints involved?
My code with map is getting RTE on all test cases, but there doesn't seem to be any problem with the indexing.
My Submission.

The issue with your code is that the acolyte function isn't returning anything (but is defined to have int return type).
A function not returning anything when it should be is undefined behavior, so anything can happen: for example, it might work correctly on your system but not on codechef's servers.

I recommend using compiler flags when testing your code locally: there's a short guide on them in this blog.
For example, simply compiling your code locally gave me the error

Other.cpp: In function ‘int acolyte(int, int)’:
Other.cpp:57:1: warning: no return statement in function returning non-void [-Wreturn-type]
   57 | }

so I didn't even need to read the code to figure out where the issue was.

As an aside, fixing that issue makes your code get either WA or TLE on every testcase so there's still some issues there.

If your algorithm is \mathcal{O}(N\log{10^9}) with a reasonable constant factor, it'll pass.
\mathcal{O}(N\log N\log{10^9}) most likely won't, but you're welcome to try; we had to increase the limits because \mathcal{O}(N^2) ran surprisingly fast.

The time limit for every problem can be found at the bottom in the “More info” section.
The memory limit for every problem is 1.5GB

I recommend using compiler flags when testing your code locally: there's a short guide on them in this blog.

Thanks.

As an aside, fixing that issue makes your code get either WA or TLE on every testcase so there's still some issues there.

Will try to re-implement it with a Trie.
Map with masking isn't the best way of doing this anyway.

It was a very nice multi-concept problem. Couldn't solve it in contest. Took me over 2h+ to think and implement. Loved it.