PROBLEM LINK:
Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4
Author: still_me
Testers: the_hyp0cr1t3, rivalq
Editorialist: iceknight1093
DIFFICULTY:
TBD
PREREQUISITES:
Tries, basic probability, DFS
PROBLEM:
There is a tree on N vertices, the i-th vertex has value A_i. Alice starts at vertex 1 and randomly moves to a child of the current vertex till she can no longer do so.
Her score is the bitwise xor of all but one of the values she visited along the way, and she can choose which one to drop in order to maximize her score.
What's her expected final score?
EXPLANATION:
Applying the definition of expected value, we see that the answer is \sum P_u S_u, where:
- P_u is the probability that Alice reaches u
- S_u is the score obtained by Alice if she reaches u
and the above summation is taken across all terminal vertices (i.e, leaves) in the tree.
Let's compute them individually.
Computing probabilities
Let P_u be the probability that Alice visits vertex u during her journey.
Since she starts at 1, we have P_1 = 1.
Now, consider some vertex u (possibly the root) that has c children. Let v be one such child.
If we already know P_u, computing P_v is easy: it's just \frac{P_u}{c}. This is because Alice must reach u and then choose to go into v (which happens with probability \frac{1}{c}), so we multiply the probabilities together.
This gives us a way to compute P_u for every vertex u using DFS, starting from the root to the leaves; and the whole process takes \mathcal{O}(N \log{10^9}) time (the log factor because we perform modulo divisions).
Computing score
Let u be a leaf vertex. Let's compute Alice's score when she reaches it.
Suppose the values on the path to u are x_1, \ldots, x_k. We want the maximum value of
x_1 \oplus \ldots \oplus x_{i-1} \oplus x_{i+1} \oplus \ldots \oplus x_k
across all i.
Notice that this is just
(x_1 \oplus x_2 \oplus \ldots \oplus x_k) \oplus x_i
and the first value is a constant, being the xor-sum of all the values.
So, we have a list [x_1, \ldots, x_k] and a constant C, and we'd like to find the maximum possible value of C\oplus x_i for some i.
This is a well-known problem, and can be solved using a trie: a tutorial can be found here.
However, the values in the trie keep changing, and we can't afford to insert the entire path each time we process a leaf.
Instead, we maintain a single trie and reuse it across our DFS, as follows:
- Let T denote our trie. Initially, it's empty.
- When we enter node u, insert A_u into T.
- Then, if u is a leaf perform the relevant query; and if not continue the DFS into the children of u
- Finally, when exiting u, remove (one copy of) A_u from T.
This ensures we perform only N insertions and deletions each across the whole process, making the time complexity \mathcal{O}(N\log{10^9}).
Once both the probabilities and the scores have been computed, calculating the final answer is trivial using the summation above.
TIME COMPLEXITY:
\mathcal{O}(N\log{10^9}) per testcase.
CODE:
Setter's code (C++)
// Code by Sahil Tiwari (still_me)
#include<bits/stdc++.h>
// Entry-point alias: every `still_me` token is replaced by `main`.
#define still_me main
// `endl` becomes the string "\n", so `cout << endl` writes a newline without flushing.
#define endl "\n"
// Every later `int` token expands to `long long int`
// (presumably why the entry point below is declared `signed`).
#define int long long int
// Iterator pair covering a whole container. Not referenced in the setter's code shown here.
#define all(a) (a).begin() , (a).end()
// Prints each element of `a` followed by a space, then a newline.
// NOTE: expands to TWO statements (the loop, then `cout<<endl;`), so it is unsafe
// as the body of an un-braced if/for. Not referenced in the setter's code shown here.
#define print(a) for(auto TEMPORARY: a) cout<<TEMPORARY<<" ";cout<<endl;
// Reads the number of test cases and loops that many times; the statement or
// block that follows the macro is the loop body.
#define tt int TESTCASE;cin>>TESTCASE;while(TESTCASE--)
// Reads `n` values from cin into a[0..n-1].
#define arrin(a,n) for(int INPUT=0;INPUT<n;INPUT++)cin>>a[INPUT]
using namespace std;
// Modulus used for all probability arithmetic (1e9+7).
const int mod = 1e9+7;
// Large sentinel value. Not referenced in the setter's code shown here.
const int inf = 1e18;
// Computes a^b modulo `mod` by binary exponentiation.
//
// a   - base; reduced modulo `mod` up front so that no intermediate product
//       overflows even when |a| >= mod.
// b   - exponent. For a negative b the loop (like the original recursion)
//       effectively uses |b|, because `/` and `%` truncate toward zero.
// mod - modulus (> 0). mod * mod must fit in a signed 64-bit integer.
//
// Returns the result in (-mod, mod); it is non-negative whenever a is.
// power(a, 0, 1) yields 0, since every value is congruent to 0 modulo 1.
long long power(long long a , long long b , long long mod){
    long long base = a % mod;      // keep every factor below mod in magnitude
    long long result = 1 % mod;    // 0 when mod == 1, otherwise 1
    while(b != 0) {
        if(b % 2 != 0)
            result = result * base % mod;
        base = base * base % mod;
        b /= 2;
    }
    return result;
}
// Modular inverse of `a` modulo `mod`, obtained as a^(mod-2) via power().
// This is valid because `mod` (1e9+7) is prime and assumes a is not a multiple of it.
int inverse(int a){
    const int exponent = mod - 2;
    return power(a , exponent , mod);
}
// p[leaf] = probability (as a residue modulo `mod`) that the walk ends at that
// leaf. Written by dfs(), read by tdfs(), cleared at the start of every test case.
map<int,int> p;
// Not referenced anywhere else in the setter's code shown here.
int cnt = 0;
// Fills the global map `p` with the probability of ending the walk at each leaf.
//
// adj  - undirected adjacency lists of the tree (0-indexed, vertex 0 is the root).
// j    - current vertex.
// prev - vertex we arrived from (the root is called with prev == 0).
// prob - probability, modulo `mod`, of reaching j.
//
// A non-root vertex with a single neighbour is a leaf and simply records `prob`.
// Otherwise every child is reached with probability prob / (#children); that
// value is identical for all children, so it is computed once before the loop
// instead of paying for a modular inverse per child.
void dfs(vector<vector<int>> &adj , int j , int prev , int prob) {
    const bool is_root = (j == 0);
    if(!is_root && adj[j].size() == 1) {
        p[j] = prob;
        return;
    }
    // Every neighbour except the parent is a child; the root has no parent.
    const int children = adj[j].size() - (is_root ? 0 : 1);
    if(children == 0)
        return;   // single-vertex tree: nothing to descend into
    const int child_prob = prob * inverse(children) % mod;
    for(int next : adj[j]) {
        if(next == prev)
            continue;
        dfs(adj , next , j , child_prob);
    }
}
// Running answer for the current test case: sum over leaves of
// p[leaf] * (best score at that leaf), kept modulo `mod`. Reset in chal_bsdk().
int ans = 0;
// Binary trie over the low 30 bits of non-negative integers, with stack-like
// ("rollback") deletion: the nodes created by one insert() are tagged with the
// caller-supplied id and can later be removed again with Delete(id), provided
// they are still the most recently created nodes.
struct Trie{
    // node[i][b] = index of the child of node i along bit b, or -1 if absent.
    // Node 0 is the root.
    vector<array<int, 2>> node;
    // last[i] = tag `n` of the insert() call that created node i (-1 for the root).
    vector<int> last;
    // bck[i] = (parent index, bit) of the edge leading to node i, so that
    // Delete() can unlink the node from its parent.
    vector<pair<int, int>> bck;

    // Creates a trie that contains only the root.
    Trie() {
        node.push_back({-1, -1});
        last.push_back(-1);
        bck.push_back({-1, -1});
    }

    // Inserts the low 30 bits of `val`; every node created on the way is tagged with `n`.
    // Inserting a value whose path already exists creates no nodes.
    void insert(int val, int n) {
        int cur = 0;
        for(int i = 29 ; i >= 0 ; i--) {
            const int bit = (val >> i) & 1;
            if(node[cur][bit] == -1) {
                node[cur][bit] = static_cast<int>(node.size());
                node.push_back({-1, -1});
                last.push_back(n);
                bck.push_back({cur, bit});
            }
            cur = node[cur][bit];
        }
    }

    // Removes the trailing run of nodes that were created with tag `n`.
    // The root (index 0) is never removed, even if `n` equals its sentinel tag.
    void Delete(int n) {
        while(last.size() > 1 && last.back() == n) {
            node[bck.back().first][bck.back().second] = -1;
            bck.pop_back();
            last.pop_back(), node.pop_back();
        }
    }

    // Returns the maximum of (v xor s) over all stored values s, looking at the
    // low 30 bits only. Returns 0 when the trie holds no values; without that
    // guard the walk would step to index -1.
    int query(int v) {
        if(node.size() == 1)
            return 0;
        int cur = 0, ans = 0;
        for(int i = 29 ; i >= 0 ; i--) {
            const int bit = (v >> i) & 1;
            // Child indices are always >= 1, so "> 0" means "exists".
            if(node[cur][1 ^ bit] > 0)
                ans ^= 1 << i, cur = node[cur][1 ^ bit];
            else cur = node[cur][bit];
        }
        return ans;
    }
};
// Second traversal: keeps the trie equal to the multiset of values on the
// root-to-j path and, at every leaf, adds p[leaf] * (best score) to `ans`.
//
// adj  - adjacency lists of the tree (vertex 0 is the root).
// a    - vertex values.
// j    - current vertex; prev - vertex we came from.
// curr - xor of the values on the path strictly above j.
// T    - shared trie; a[j] is inserted on entry and rolled back on exit.
void tdfs(vector<vector<int>> &adj , vector<int> &a, int j , int prev, int curr, Trie &T) {
    T.insert(a[j] , j);
    const int path_xor = curr ^ a[j];

    const bool at_leaf = (j != 0) && (adj[j].size() == 1);
    if(at_leaf) {
        // Best score = max over path values x of (path_xor xor x).
        ans = (ans + p[j] * T.query(path_xor)) % mod;
    }

    for(int next : adj[j]) {
        if(next != prev)
            tdfs(adj , a , next , j , path_xor , T);
    }

    T.Delete(j);
}
// Solves one test case: reads the tree from cin, computes the leaf
// probabilities with dfs(), accumulates the expected score with tdfs(), and
// prints the answer followed by a newline.
void chal_bsdk() {
    // Reset the per-test global state.
    ans = 0;
    p.clear();

    int n;
    cin >> n;

    vector<int> a(n);
    for(int idx = 0; idx < n; ++idx)
        cin >> a[idx];

    // n - 1 undirected edges, given 1-indexed and stored 0-indexed.
    vector<vector<int>> adj(n);
    for(int edge = 1; edge < n; ++edge) {
        int u , v;
        cin >> u >> v;
        --u;
        --v;
        adj[u].push_back(v);
        adj[v].push_back(u);
    }

    dfs(adj , 0 , 0 , 1);

    Trie T;
    tdfs(adj , a , 0 , 0 , 0 , T);

    cout << ans << endl;
}
signed still_me()
{
ios_base::sync_with_stdio(false);cin.tie(NULL);cout.tie(NULL);
// freopen("15.in" , "r" , stdin);
// freopen("15.out" , "w" , stdout);
tt{
chal_bsdk();
}
return 0;
}
Tester's code (C++)
// Jai Shree Ram
#include<bits/stdc++.h>
using namespace std;
// Half-open counting loop: i runs over a, a+1, ..., n-1.
#define rep(i,a,n) for(int i=a;i<n;i++)
#define ll long long
// Every later `int` token expands to `long long` (hence `signed main` and the
// `int32_t` parameters of the postfix ++/-- operators further down).
#define int long long
#define pb push_back
#define all(v) v.begin(),v.end()
// `endl` becomes the string "\n": a newline without a flush.
#define endl "\n"
// CAUTION: object-like macros. EVERY identifier `x` / `y` that follows
// (variables and parameters included) is rewritten to `first` / `second`.
#define x first
#define y second
#define gcd(a,b) __gcd(a,b)
// Fill an array with all-ones bytes (-1 for integer types) / with zero bytes.
#define mem1(a) memset(a,-1,sizeof(a))
#define mem0(a) memset(a,0,sizeof(a))
// Function-like macro: only `sz(` is expanded, so the array named `sz` declared
// later in this program is not affected.
#define sz(a) (int)a.size()
#define pii pair<int,int>
// The modulus 1e9+7 (used to define MOD below).
#define hell 1000000007
// CPU time used so far, in seconds.
#define elasped_time 1.0 * clock() / CLOCKS_PER_SEC
// Reads a pair as two consecutive values: first, then second.
// The members are named explicitly so the operator does not depend on the
// object-like macros `x` / `y` being defined.
template<typename T1,typename T2>
istream& operator>>(istream& in,pair<T1,T2> &a){
    in >> a.first >> a.second;
    return in;
}
// Writes a pair as "first second" (single space, no trailing newline).
// The members are named explicitly so the operator does not depend on the
// object-like macros `x` / `y` being defined.
template<typename T1,typename T2>
ostream& operator<<(ostream& out,pair<T1,T2> a){
    out << a.first << " " << a.second;
    return out;
}
// Raises `a` to `b` when b is larger; returns a copy of the resulting `a`.
template<typename T,typename T1>
T maxs(T &a,T1 b){
    if(b>a){
        a=b;
    }
    return a;
}
// Lowers `a` to `b` when b is smaller; returns a copy of the resulting `a`.
template<typename T,typename T1>
T mins(T &a,T1 b){
    if(b<a){
        a=b;
    }
    return a;
}
// -------------------- Input Checker Start --------------------
// Strict input-validator read of one integer from stdin.
//
// Consumes characters up to and including the terminator `endd` and returns
// the parsed value. The run aborts through assert when the token is malformed:
//   * any character other than an optional single leading '-', decimal digits
//     and the terminator (this includes end of input);
//   * a second '-' or a '-' after the first digit;
//   * no digits at all (empty token or a lone '-');
//   * leading zeros, or "-0";
//   * more digits than are guaranteed to fit in a long long;
//   * a value outside [l, r] (the bounds and the value are printed to cerr first).
long long readInt(long long l, long long r, char endd)
{
    long long value = 0;
    int digits = 0, first_digit = -1;
    bool is_neg = false;
    while(true)
    {
        const int g = getchar();
        if(g == '-')
        {
            // The sign may appear once, and only before the first digit.
            assert(first_digit == -1 && !is_neg);
            is_neg = true;
            continue;
        }
        if('0' <= g && g <= '9')
        {
            if(digits == 0)
                first_digit = g - '0';
            digits++;
            assert(first_digit != 0 || digits == 1);      // no leading zeros
            assert(first_digit != 0 || is_neg == false);  // no "-0"
            // Checked BEFORE accumulating so the multiplication cannot overflow:
            // at most 19 digits, and a 19-digit number must start with 1.
            assert(!(digits > 19 || (digits == 19 && first_digit > 1)));
            value = value * 10 + (g - '0');
        }
        else if(g == endd)
        {
            assert(digits > 0);   // reject an empty token and a lone '-'
            if(is_neg)
                value = -value;
            if(!(l <= value && value <= r))
            {
                cerr << l << ' ' << r << ' ' << value << '\n';
                assert(false);
            }
            return value;
        }
        else
        {
            assert(false);
        }
    }
}
// Strict input-validator read of one string from stdin.
//
// Collects characters up to (not including) the terminator `endd`, which is
// consumed. Asserts that end of input is not reached first and that the
// collected length lies in [l, r]. Returns the collected string.
//
// The result of getchar() is kept as an int so that end of input can be told
// apart from every real byte regardless of whether plain `char` is signed.
string readString(int l, int r, char endd)
{
    string ret;
    int cnt = 0;
    while(true)
    {
        const int g = getchar();
        assert(g != EOF);
        if(g == EOF)
            break;   // only reachable when asserts are disabled: do not spin forever
        const char c = static_cast<char>(g);
        if(c == endd)
            break;
        cnt++;
        ret += c;
    }
    assert(l <= cnt && cnt <= r);
    return ret;
}
// Reads one integer in [l, r] that must be followed by a single space.
long long readIntSp(long long l, long long r)
{
    const char terminator = ' ';
    return readInt(l, r, terminator);
}
// Reads one integer in [l, r] that must be followed by a newline.
long long readIntLn(long long l, long long r)
{
    const char terminator = '\n';
    return readInt(l, r, terminator);
}
// Reads a string whose length is in [l, r] and that is terminated by a newline.
string readStringLn(int l, int r)
{
    const char terminator = '\n';
    return readString(l, r, terminator);
}
// Reads a string whose length is in [l, r] and that is terminated by a space.
string readStringSp(int l, int r)
{
    const char terminator = ' ';
    return readString(l, r, terminator);
}
// Asserts that stdin has no further characters.
// The read happens outside the assert so that it is still performed (and the
// stream still advances) when asserts are compiled out with NDEBUG.
void readEOF()
{
    const int next = getchar();
    assert(next == EOF);
    (void)next;   // silence "unused variable" when asserts are disabled
}
// Reads one line of exactly n integers, each in [l, r]: the first n-1 must be
// followed by a space and the last one by a newline.
// Precondition: n >= 1 (the last element is always read into index n-1).
vector<int> readVectorInt(int n, long long l, long long r)
{
    vector<int> values(n);
    int idx = 0;
    while(idx + 1 < n)
    {
        values[idx] = readIntSp(l, r);
        ++idx;
    }
    values[n - 1] = readIntLn(l, r);
    return values;
}
// -------------------- Input Checker End --------------------
// One vertex of the binary trie.
struct node{
    // Children along bit 0 / bit 1. Initialised to null so that even a
    // default-initialised node never exposes indeterminate pointers.
    node* sons[2] = {nullptr, nullptr};
    // Number of stored values whose path passes through this vertex.
    int cnt=0;
};
// Allocates a fresh trie vertex with no children and a zero count.
// The vertex is created with `new`; no matching `delete` appears in this function.
node* create(){
    node* fresh = new node();
    fresh->cnt = 0;
    fresh->sons[0] = nullptr;
    fresh->sons[1] = nullptr;
    return fresh;
}
// Counting binary trie over bits 30..0 of non-negative integers (a multiset).
//
// `node` must provide `node* sons[2]` and an integer `cnt`. cnt is the number
// of stored values whose path passes through the vertex; vertices are kept
// when their count drops to zero and are treated as absent by query().
//
// The trie owns all of its vertices: they are released in the destructor, and
// copying is disabled so that two objects can never own the same vertices.
template<typename node>
struct trie{
    node* root=make_node();

    trie() = default;
    trie(const trie&) = delete;
    trie& operator=(const trie&) = delete;
    ~trie(){ destroy(root); }

    // Adds one copy of p.
    void insert(int p){
        node* cur=root;
        for(int j=30;j>=0;j--){
            cur->cnt++;
            const int k=(((1LL<<j)&p)!=0);
            if(cur->sons[k]==nullptr){
                cur->sons[k]=make_node();
            }
            cur=cur->sons[k];
        }
        cur->cnt++;
    }

    // Removes one copy of p. Does nothing when p is not currently stored
    // (instead of walking through a null child).
    void erase(int p){
        if(!present(p)) return;
        node* cur=root;
        for(int j=30;j>=0;j--){
            cur->cnt--;
            cur=cur->sons[(p>>j)&1];
        }
        cur->cnt--;
    }

    // Returns the maximum of (x xor s) over all stored values s,
    // or 0 when the trie holds no values.
    int query(int x){
        if(root->cnt<=0) return 0;
        node* cur = root;
        int ans = 0;
        for(int j = 30; j >= 0; j--){
            const int bit = (x >> j) & 1;
            node* opposite = cur->sons[bit ^ 1];
            if(opposite != nullptr && opposite->cnt){
                // A stored value differs from x in this bit: the bit is set in the xor.
                ans += 1 << j;
                cur = opposite;
            }else{
                // cur has a positive count, so the same-bit child exists and is non-empty.
                cur = cur->sons[bit];
            }
        }
        return ans;
    }

private:
    // Allocates a vertex with no children and a zero count.
    static node* make_node(){
        node* fresh=new node();
        fresh->sons[0]=nullptr;
        fresh->sons[1]=nullptr;
        fresh->cnt=0;
        return fresh;
    }

    // Frees the subtree rooted at cur (recursion depth is at most 32).
    static void destroy(node* cur){
        if(cur==nullptr) return;
        destroy(cur->sons[0]);
        destroy(cur->sons[1]);
        delete cur;
    }

    // True when at least one copy of p is stored.
    bool present(int p) const{
        const node* cur=root;
        if(cur->cnt<=0) return false;
        for(int j=30;j>=0;j--){
            cur=cur->sons[(p>>j)&1];
            if(cur==nullptr || cur->cnt<=0) return false;
        }
        return true;
    }
};
// Capacity of the disjoint-set arrays below (valid indices are 0 .. maxn-1).
const int maxn = 1e5 + 5;
// p[v]: parent of v in the disjoint-set forest (p[v] == v for a representative).
int p[maxn];
// sz[v]: size of the set whose representative is v (maintained by merge()).
int sz[maxn];
void clear(int n=maxn){
rep(i,0,n + 1)p[i]=i,sz[i]=1;
}
// Returns the representative of x's set. While climbing, each visited vertex
// is re-pointed to its grandparent, which shortens later lookups.
int root(int x){
    for(;p[x]!=x;x=p[x]){
        p[x]=p[p[x]];
    }
    return x;
}
void merge(int x,int y){
int p1=root(x);
int p2=root(y);
if(p1==p2)return;
if(sz[p1]>=sz[p2]){
p[p2]=p1;
sz[p1]+=sz[p2];
}
else{
p[p1]=p2;
sz[p2]+=sz[p1];
}
}
// Modulus for mod_int arithmetic; `hell` is the macro for 1000000007.
const int MOD = hell;
// Integer modulo MOD with the usual arithmetic operators.
// Invariant: val is always the canonical representative in [0, MOD).
struct mod_int {
    int val;

    // Implicit on purpose, so plain integers can be mixed into expressions.
    // Any value (negative included) is normalised into [0, MOD).
    mod_int(long long v = 0) {
        if (v < 0)
            v = v % MOD + MOD;   // now in [1, MOD]; the next step folds MOD to 0
        if (v >= MOD)
            v %= MOD;
        val = v;
    }

    // Modular inverse of a modulo m through the extended Euclidean algorithm.
    // Meaningful only when gcd(a, m) == 1.
    static int mod_inv(int a, int m = MOD) {
        int g = m, r = a, x = 0, y = 1;
        while (r != 0) {
            int q = g / r;
            g %= r; swap(g, r);
            x -= q * y; swap(x, y);
        }
        return x < 0 ? x + m : x;
    }

    explicit operator int() const {
        return val;
    }

    mod_int& operator+=(const mod_int &other) {
        val += other.val;
        if (val >= MOD) val -= MOD;
        return *this;
    }

    mod_int& operator-=(const mod_int &other) {
        val -= other.val;
        if (val < 0) val += MOD;
        return *this;
    }

    // x mod m. On 32-bit Windows builds a single `divl` instruction is used;
    // everywhere else this is a plain `%`. The inline assembly sits in the
    // #else branch so that it is not compiled on targets where it is never
    // executed (and where it may not even assemble).
    static unsigned fast_mod(uint64_t x, unsigned m = MOD) {
#if !defined(_WIN32) || defined(_WIN64)
        return x % m;
#else
        unsigned x_high = x >> 32, x_low = (unsigned) x;
        unsigned quot, rem;
        asm("divl %4\n"
            : "=a" (quot), "=d" (rem)
            : "d" (x_high), "a" (x_low), "r" (m));
        return rem;
#endif
    }

    mod_int& operator*=(const mod_int &other) {
        val = fast_mod((uint64_t) val * other.val);
        return *this;
    }

    mod_int& operator/=(const mod_int &other) {
        return *this *= other.inv();
    }

    friend mod_int operator+(const mod_int &a, const mod_int &b) { return mod_int(a) += b; }
    friend mod_int operator-(const mod_int &a, const mod_int &b) { return mod_int(a) -= b; }
    friend mod_int operator*(const mod_int &a, const mod_int &b) { return mod_int(a) *= b; }
    friend mod_int operator/(const mod_int &a, const mod_int &b) { return mod_int(a) /= b; }

    mod_int& operator++() {
        val = val == MOD - 1 ? 0 : val + 1;
        return *this;
    }

    mod_int& operator--() {
        val = val == 0 ? MOD - 1 : val - 1;
        return *this;
    }

    // The dummy parameter is spelled int32_t because the `int` token is redefined by a macro.
    mod_int operator++(int32_t) { mod_int before = *this; ++*this; return before; }
    mod_int operator--(int32_t) { mod_int before = *this; --*this; return before; }

    mod_int operator-() const {
        return val == 0 ? 0 : MOD - val;
    }

    bool operator==(const mod_int &other) const { return val == other.val; }
    bool operator!=(const mod_int &other) const { return val != other.val; }

    mod_int inv() const {
        return mod_inv(val);
    }

    // this^p by binary exponentiation; p must be non-negative.
    mod_int pow(long long p) const {
        assert(p >= 0);
        mod_int a = *this, result = 1;
        while (p > 0) {
            if (p & 1)
                result *= a;
            a *= a;
            p >>= 1;
        }
        return result;
    }

    friend ostream& operator<<(ostream &stream, const mod_int &m) {
        return stream << m.val;
    }

    // Reads an integer and normalises it, so that input which is negative or
    // >= MOD cannot break the [0, MOD) invariant the other operators rely on.
    friend istream& operator >> (istream &stream, mod_int &m) {
        long long raw = m.val;
        stream >> raw;
        m = mod_int(raw);
        return stream;
    }
};
int solve(){
int n = readIntLn(1,1e5);
static int sum_n = 0;
sum_n += n;
assert(sum_n <= 1e5);
vector<vector<int>> g(n + 1);
vector<int> a = readVectorInt(n,1,1e9);
clear(n + 1);
for(int i = 2; i <= n; i++){
int u = readIntSp(1,n);
int v = readIntLn(1,n);
assert(root(u) != root(v));
merge(u,v);
g[u].push_back(v);
g[v].push_back(u);
}
mod_int ans = 0;
trie<node> tr;
function<void(int,int,int, mod_int)> dfs = [&](int u,int v,int xor_val, mod_int p){
tr.insert(a[u - 1]);
xor_val ^= a[u - 1];
mod_int childs = g[u].size();
if(u != 1) childs--;
if(childs == 0){
ans += p * tr.query(xor_val);
}else{
p *= childs.inv();
for(auto i: g[u]){
if(i != v){
dfs(i,u,xor_val,p);
}
}
}
tr.erase(a[u - 1]);
xor_val ^= a[u - 1];
};
dfs(1,1,0,1);
cout << ans << endl;
return 0;
}
// Entry point of the tester's program: reads the number of test cases
// (validated to lie in [1, 10000]) and runs solve() once per test case.
signed main(){
    ios_base::sync_with_stdio(0);
    cin.tie(0);
    cout.tie(0);
    //freopen("input.txt", "r", stdin);
    //freopen("output.txt", "w", stdout);
    #ifdef SIEVE
    sieve();
    #endif
    #ifdef NCR
    init();
    #endif
    int remaining = readIntLn(1,10000);
    while(remaining--){
        solve();
    }
    return 0;
}
Editorialist's code (C++)
#include "bits/stdc++.h"
// #pragma GCC optimize("O3,unroll-loops")
// #pragma GCC target("avx2,bmi,bmi2,lzcnt,popcnt")
using namespace std;
using ll = long long int;
// 64-bit Mersenne Twister seeded from the high-resolution clock.
// Not referenced in the editorialist's code shown here.
mt19937_64 rng(chrono::high_resolution_clock::now().time_since_epoch().count());
// Counting binary trie over bits 30..0 of non-negative integers (a multiset).
// Nodes live in flat arrays and are never freed; removing a value only lowers
// the counts along its path.
struct Trie {
    vector<int> v;              // v[i]: number of stored values whose path passes through node i
    vector<array<int, 2>> ch;   // ch[i][b]: index of the child of node i along bit b, or -1
    int id = 0;                 // index of the most recently created node

    // Starts with just the root (node 0).
    Trie() : v(1, 0), ch(1, {-1, -1}) {}

    // Appends a fresh node with no children; afterwards `id` is its index.
    void create() {
        v.push_back(0);
        ch.push_back({-1, -1});
        ++id;
    }

    // Adds `dif` to the count of every node on x's path, creating missing
    // nodes. dif = +1 inserts one copy of x; dif = -1 removes one copy.
    void add(int x, int dif) {
        int node = 0;
        for (int bit = 30; bit >= 0; --bit) {
            int b = (x >> bit) & 1;
            v[node] += dif;
            if (ch[node][b] == -1) {
                create();
                ch[node][b] = id;
            }
            node = ch[node][b];
        }
        v[node] += dif;
    }

    // Maximum value of a^x for a in the trie; 0 when the trie holds no values
    // (without this guard the walk below would step to index -1).
    int query (int x) {
        if (v[0] <= 0) return 0;
        int node = 0, ret = 0;
        for (int bit = 30; bit >= 0; --bit) {
            const int b = (x >> bit) & 1;
            const int opposite = ch[node][b^1];
            if (opposite != -1 && v[opposite] != 0) {
                // Some stored value differs from x in this bit.
                ret += 1 << bit;
                node = opposite;
            } else {
                if (ch[node][b] == -1) break;   // defensive: counts are inconsistent
                node = ch[node][b];
            }
        }
        return ret;
    }
};
// Editorialist's solution: for every test case, one DFS carries the trie of the
// values on the current root-to-vertex path, the prefix xor and the
// probability of reaching the vertex, and sums prob * best score over leaves.
int main()
{
    ios::sync_with_stdio(false);
    cin.tie(0);

    int t;
    cin >> t;
    while (t--) {
        int n;
        cin >> n;

        vector<int> a(n);
        for (int &value : a) cin >> value;

        // n - 1 undirected edges, given 1-indexed and stored 0-indexed.
        vector<vector<int>> adj(n);
        for (int edge = 1; edge < n; ++edge) {
            int u, v;
            cin >> u >> v;
            --u;
            --v;
            adj[u].push_back(v);
            adj[v].push_back(u);
        }

        Zp ans = 0; // Zp is a modint class, I removed the template to allow for easier reading
        Trie T;

        // u: current vertex, p: its parent, pref: xor of the values above u,
        // prob: probability of reaching u.
        auto dfs = [&] (const auto &self, int u, int p, int pref, Zp prob) -> void {
            // Every neighbour except the parent is a child; the root (0) has no parent.
            const int children = adj[u].size() - (u > 0);
            const bool is_leaf = (children == 0);

            T.add(a[u], 1);
            pref ^= a[u];
            if (!is_leaf) prob /= children;

            for (int v : adj[u]) {
                if (v != p) self(self, v, u, pref, prob);
            }

            if (is_leaf) ans += prob * T.query(pref);
            T.add(a[u], -1);
        };

        dfs(dfs, 0, 0, 0, 1);
        cout << ans << '\n';
    }
}