# BOFP - Editorial

Practice

Contest: Division 1

Contest: Division 2

Setter: Farhod

Tester: Teja Vardhan Reddy

Editorialist: Taranpreet Singh

# DIFFICULTY:

Medium-Hard

# PREREQUISITES:

Observations, Fenwick Tree/Segment Tree, Rerooting.

# PROBLEM:

Given a tree with N nodes, described by N-1 edges (A_i, B_i), and a sequence W of length N-1. For a sequence S of length N-1, define F(S) = \displaystyle\sum_{i = 1}^{N-1} S_i * W_i.

A reordering (P, R) consists of any permutation P of the integers from 1 to N-1 and any subset R of these integers. It defines two sequences C and D of length N-1 as follows, for each i:

• if P_i \notin R, then C_i = A_{P_i} and D_i = B_{P_i}
• if P_i \in R, then C_i = B_{P_i} and D_i = A_{P_i}.

A reordering is valid if D is strictly increasing. Find the maximum value of F(C) over all valid reorderings, and the number of valid reorderings that attain this maximum.
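For very small N, the definition can be checked directly. Below is a minimal brute-force sketch in C++17, assuming the edges are passed as pairs (A_i, B_i) and W as a 0-indexed vector; the name `bruteForce` and the (maximum, count) return convention are hypothetical. It enumerates every permutation P with `std::next_permutation` and every subset R as a bitmask, so it is only useful as a reference to test faster solutions against.

```cpp
#include <algorithm>
#include <cassert>
#include <numeric>
#include <utility>
#include <vector>

// Brute force over all reorderings (P, R); only feasible for tiny N.
// edges[j] = {A_{j+1}, B_{j+1}}, w[i] = W_{i+1} (0-indexed storage).
// Returns {max F(C) over valid reorderings, number of valid reorderings reaching it}.
std::pair<long long, long long> bruteForce(const std::vector<std::pair<int, int>>& edges,
                                           const std::vector<long long>& w) {
    assert(edges.size() == w.size());
    const int m = static_cast<int>(edges.size());  // m = N - 1
    std::vector<int> perm(m);                       // P, stored 0-indexed
    std::iota(perm.begin(), perm.end(), 0);
    long long best = 0, count = 0;
    bool any = false;
    do {
        for (int mask = 0; mask < (1 << m); ++mask) {  // bit e set <=> edge e+1 is in R
            long long f = 0;
            bool valid = true;
            int prevD = 0;  // node labels start at 1
            for (int i = 0; i < m && valid; ++i) {
                const int e = perm[i];
                const bool inR = (mask >> e) & 1;
                const int c = inR ? edges[e].second : edges[e].first;  // C_i
                const int d = inR ? edges[e].first : edges[e].second;  // D_i
                valid = d > prevD;  // D must be strictly increasing
                prevD = d;
                f += static_cast<long long>(c) * w[i];
            }
            if (!valid) continue;
            if (!any || f > best) { best = f; count = 1; any = true; }
            else if (f == best) ++count;
        }
    } while (std::next_permutation(perm.begin(), perm.end()));
    return {best, count};
}
```

For the tree with edges (1, 2), (2, 3), (2, 4) and W = (1, 2, 3), it reports a maximum of 16 reached by exactly one reordering.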

# QUICK EXPLANATION

• Observe that only N different sequences D can appear in a valid reordering: each consists of N-1 distinct values from 1 to N, that is, every value except one.
• Root the tree at each node r in turn. Every non-root node x has a parent p, so each edge becomes a pair (p, x). Sorting these pairs by x and writing p into C and x into D makes D strictly increasing, because the values x are distinct and sorted. Every valid reordering arises from exactly one root in this way, so it suffices to compute F(C) for each root.
• For the tree rooted at r, F(C) can be written as \displaystyle\sum_{x = 1}^{r-1}W_x*V_x + \displaystyle\sum_{x=r+1}^{N}W_{x-1}*V_x where V_x = p for each edge (p, x) and V_r = 0. This holds because r itself is skipped in C, so each node x > r sits one position to the left, at position x-1.
• To maintain this sum while V changes, use two BITs or segment trees: the first stores V_i*W_i at leaf i and the second stores V_i*W_{i-1} at leaf i. The value for root r is the sum over [1, r-1] in the first tree plus the sum over [r+1, N] in the second, where an empty range contributes 0 (equivalently, treat W_0 and W_N as 0).

# EXPLANATION

There are far too many possibilities: 2^{N-1}*(N-1)! reorderings in total, so no brute force over them can pass.
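Plugging small values of N into 2^{N-1}*(N-1)! shows how quickly this count grows:

$$
\begin{aligned}
N = 4 &: \quad 2^{3}\cdot 3! = 8 \cdot 6 = 48,\\
N = 11 &: \quad 2^{10}\cdot 10! = 1024 \cdot 3\,628\,800 = 3\,715\,891\,200.
\end{aligned}
$$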

Let us focus on sequence D first. D is strictly increasing, so its N-1 elements are distinct, and every element is a node label from 1 to N. Hence D is determined by the single value in [1, N] that it leaves out, and at most N distinct sequences D can exist.

For N = 4, only the following four sequences D are possible:

```
1 2 3
1 2 4
1 3 4
2 3 4
```


Each sequence misses exactly one value from 1 to N.

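Now root the tree at some node r. Every non-root node x has a unique parent p, so every edge can be written as (p, x), and the child endpoints x of different edges are distinct. Writing the N-1 edges in increasing order of x, with p going into C and x into D, gives a valid reordering, since D is strictly increasing. This reordering is unique for the chosen root: the orientation of each edge decides whether its index belongs to R, and the increasing order of the child endpoints decides the permutation P. Conversely, take any valid reordering and let r be the value missing from D. Orient every edge from its C-endpoint to its D-endpoint: every node except r then has exactly one incoming edge, and following these edges backwards from any node can never revisit a node (a tree has no cycles), so the walk must end at r. This orientation is therefore exactly the tree rooted at r. Hence valid reorderings correspond one-to-one to roots, and each root r gives a unique sequence C.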
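Going from a root to its reordering can be sketched in C++17. `reorderingForRoot` and the `Reordering` struct are illustrative names, not taken from the official solutions; the sketch assumes edges are given as pairs (A_i, B_i), uses 0-based edge ids internally, and reports P 1-indexed.

```cpp
#include <cassert>
#include <utility>
#include <vector>

// Hypothetical helper: the unique valid reordering produced by root r.
// edges[e] = {A_{e+1}, B_{e+1}}; P is reported 1-indexed; inR[e] says whether
// edge e+1 belongs to R (true when its A endpoint is the child).
struct Reordering {
    std::vector<int> P, C, D;
    std::vector<bool> inR;
};

Reordering reorderingForRoot(int n, const std::vector<std::pair<int, int>>& edges, int r) {
    assert(1 <= r && r <= n);
    const int m = static_cast<int>(edges.size());
    std::vector<std::vector<std::pair<int, int>>> adj(n + 1);  // (neighbour, edge id)
    for (int e = 0; e < m; ++e) {
        adj[edges[e].first].push_back({edges[e].second, e});
        adj[edges[e].second].push_back({edges[e].first, e});
    }
    // parent[x] and parentEdge[x] describe the edge (parent[x], x) of the rooted tree.
    std::vector<int> parent(n + 1, 0), parentEdge(n + 1, -1), todo = {r};
    std::vector<bool> seen(n + 1, false);
    seen[r] = true;
    while (!todo.empty()) {
        const int u = todo.back();
        todo.pop_back();
        for (auto [v, e] : adj[u]) {
            if (seen[v]) continue;
            seen[v] = true;
            parent[v] = u;
            parentEdge[v] = e;
            todo.push_back(v);
        }
    }
    Reordering res;
    res.inR.assign(m, false);
    for (int x = 1; x <= n; ++x) {  // increasing child label, so D is strictly increasing
        if (x == r) continue;
        const int e = parentEdge[x];
        res.P.push_back(e + 1);
        res.C.push_back(parent[x]);
        res.D.push_back(x);
        res.inR[e] = (edges[e].first == x);  // child is A_e: the edge is flipped
    }
    return res;
}
```

For example, with edges (1, 2), (2, 3), (2, 4) and root 4, it yields P = (1, 3, 2), C = (2, 4, 2), D = (1, 2, 3), and R containing edges 1 and 3.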

So the problem reduces to the following.
Given a tree with N nodes and a sequence W of length N-1, for every r \in [1, N], root the tree at r, visit the nodes x \in [1, N] with x \neq r in increasing order of x, and append the parent of each visited node to C. Find the maximum value of F(C) over all roots and the number of roots that attain this maximum.
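Before optimizing, the reduced problem can be solved naively by rooting the tree from scratch at every r, which takes O(N^2) overall. A C++17 sketch, assuming nodes 1..N, edges as pairs and W stored 0-indexed; `allRootsQuadratic` is a hypothetical name. It is too slow for large N but handy as a cross-check.

```cpp
#include <cassert>
#include <utility>
#include <vector>

// O(N^2) reference for the reduced problem: for every root r, compute parents
// with an iterative DFS, list the parents of x = 1..n (x != r) as C, evaluate F(C).
// w[i] holds W_{i+1}; returns {max F(C), number of roots reaching it}.
std::pair<long long, long long> allRootsQuadratic(int n, const std::vector<std::pair<int, int>>& edges,
                                                  const std::vector<long long>& w) {
    assert(static_cast<int>(w.size()) == n - 1);
    std::vector<std::vector<int>> adj(n + 1);
    for (auto [a, b] : edges) {
        adj[a].push_back(b);
        adj[b].push_back(a);
    }
    long long best = 0, count = 0;
    for (int r = 1; r <= n; ++r) {
        std::vector<int> parent(n + 1, -1), todo = {r};  // -1 marks "not visited yet"
        parent[r] = 0;                                   // V_r = 0
        while (!todo.empty()) {
            const int u = todo.back();
            todo.pop_back();
            for (int v : adj[u])
                if (parent[v] == -1) { parent[v] = u; todo.push_back(v); }
        }
        long long f = 0;
        int pos = 0;  // next position in C (0-indexed)
        for (int x = 1; x <= n; ++x)
            if (x != r) f += static_cast<long long>(parent[x]) * w[pos++];
        if (r == 1 || f > best) { best = f; count = 1; }
        else if (f == best) ++count;
    }
    return {best, count};
}
```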

Now let V_x denote the parent of node x when the tree is rooted at r, with V_r = 0. With root r, node x occupies position x in C if x < r and position x-1 if x > r, because r itself is skipped. Hence each node x \in [1, r-1] contributes V_x*W_x to the sum, and each node x \in [r+1, N] contributes V_x*W_{x-1}, since all V_x with x > r are shifted one place to the left.

Why does this happen? Consider the following example with N = 4.

• if root is 1, required sum is V_2*W_1 + V_3*W_2 + V_4*W_3 = \displaystyle\sum_{x = 2}^{4} V_x*W_{x-1}
• if root is 2, required sum is V_1*W_1 + V_3*W_2 + V_4*W_3 = \displaystyle\sum_{x = 1}^{1}V_x*W_x + \displaystyle\sum_{x = 3}^{4} V_x*W_{x-1}
• if root is 3, required sum is V_1*W_1 + V_2*W_2 + V_4*W_3 = \displaystyle\sum_{x = 1}^{2}V_x*W_x + \displaystyle\sum_{x = 4}^{4} V_x*W_{x-1}
• if root is 4, required sum is V_1*W_1 + V_2*W_2 + V_3*W_3 = \displaystyle\sum_{x = 1}^{3} V_x*W_x
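The split formula translates directly into two loops. A minimal sketch, assuming V and W are 1-indexed vectors (index 0 unused) and `fForRoot` is an illustrative name:

```cpp
#include <cassert>
#include <vector>

// F(C) for root r from the split formula. V[1..n] holds parents (V[r] == 0)
// and W[1..n-1] holds the weights; index 0 of both vectors is unused.
long long fForRoot(const std::vector<long long>& V, const std::vector<long long>& W, int r) {
    const int n = static_cast<int>(V.size()) - 1;
    assert(1 <= r && r <= n && V[r] == 0 && static_cast<int>(W.size()) == n);
    long long f = 0;
    for (int x = 1; x < r; ++x) f += V[x] * W[x];           // x stays at position x of C
    for (int x = r + 1; x <= n; ++x) f += V[x] * W[x - 1];  // x moves to position x-1
    return f;
}
```

For the tree with edges (1, 2), (2, 3), (2, 4) and W = (1, 2, 3), the roots 1, 2, 3 and 4 give 11, 12, 14 and 16 respectively, which matches writing out C for each root directly.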

So, for ranges L \leq R, we need both \displaystyle\sum_{x = L}^{R}V_x*W_x and \displaystyle\sum_{x = L}^{R}V_x*W_{x-1} while V changes through point updates, i.e., a point-update, range-query data structure. Both the Fenwick tree and the segment tree work here.
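A Fenwick tree with point add and range sum is enough for either of the two trees. A minimal C++ sketch; the struct name `Fenwick` and its method names are illustrative:

```cpp
#include <cassert>
#include <vector>

// Fenwick tree (BIT) over positions 1..n: point add and range sum, O(log n) each.
struct Fenwick {
    std::vector<long long> t;
    explicit Fenwick(int n) : t(n + 1, 0) {}
    void add(int pos, long long delta) {
        assert(pos >= 1);  // pos = 0 would loop forever
        for (; pos < static_cast<int>(t.size()); pos += pos & -pos) t[pos] += delta;
    }
    long long prefix(int pos) const {  // sum of positions 1..pos (0 if pos <= 0)
        long long s = 0;
        for (; pos > 0; pos -= pos & -pos) s += t[pos];
        return s;
    }
    long long range(int l, int r) const {  // sum of positions l..r, 0 for an empty range
        return l > r ? 0 : prefix(r) - prefix(l - 1);
    }
};
```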

Whenever V_x changes to some value y, set leaf x of the first tree to W_x*y and leaf x of the second tree to W_{x-1}*y. A Fenwick tree only supports adding to a position, so there we add (y - V_x)*W_x and (y - V_x)*W_{x-1} instead.
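A sketch of the two-tree bookkeeping with Fenwick trees. Since Fenwick trees only add, it assigns by adding the difference, much like the setter's remove-old/add-new `upd`, but it is not the setter's code. `RootSums`, `setParent` and `valueForRoot` are hypothetical names, and W is padded with W_0 = W_N = 0 so the boundary cases need no special handling.

```cpp
#include <cassert>
#include <vector>

// Two Fenwick trees over leaves 1..n: t1 weights leaf x by W_x, t2 by W_{x-1}.
// Assignment V_x := y is applied as adding (y - V_x) * weight.
struct RootSums {
    int n;
    std::vector<long long> W, V, t1, t2;  // W[0] = W[n] = 0 pad the boundaries
    RootSums(int nodes, const std::vector<long long>& w)  // w[i] = W_{i+1}
        : n(nodes), W(nodes + 1, 0), V(nodes + 1, 0), t1(nodes + 1, 0), t2(nodes + 1, 0) {
        assert(static_cast<int>(w.size()) == nodes - 1);
        for (int i = 1; i < nodes; ++i) W[i] = w[i - 1];
    }
    static void add(std::vector<long long>& t, int p, long long d) {
        for (; p < static_cast<int>(t.size()); p += p & -p) t[p] += d;
    }
    static long long prefix(const std::vector<long long>& t, int p) {
        long long s = 0;
        for (; p > 0; p -= p & -p) s += t[p];
        return s;
    }
    void setParent(int x, long long y) {  // V_x := y in both trees
        assert(1 <= x && x <= n);
        const long long d = y - V[x];
        V[x] = y;
        add(t1, x, d * W[x]);
        add(t2, x, d * W[x - 1]);
    }
    long long valueForRoot(int r) const {  // range [1, r-1] of t1 plus [r+1, n] of t2
        return prefix(t1, r - 1) + (prefix(t2, n) - prefix(t2, r));
    }
};
```

After setting the parents for root 1 of the tree (1, 2), (2, 3), (2, 4) with W = (1, 2, 3), `valueForRoot(1)` is 11; moving the root to 2 with `setParent(2, 0)` and `setParent(1, 2)` makes `valueForRoot(2)` equal 12.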

Now the important part is to try all roots, and rerooting makes this cheap. Suppose r is the current root and s is a child of r. Moving the root to s only reverses the edge (r, s) into (s, r): by our definition, V_s becomes 0 and V_r becomes s, while every other V_x stays the same. So each reroot costs two point updates in each of the two trees.

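So we can try all roots with a DFS that reroots when it descends into a child and undoes the reroot when it returns, as in the following pseudo-code. Every edge is reversed exactly twice, so there are O(N) point updates in total, i.e., O(N log N) time.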

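```
def tryRoot(u, par):
    # u is the current root: query [1, u-1] in the first tree and [u+1, N] in the second
    for v in adj[u]:
        if v == par: continue
        # reroot to v: reverse edge (u, v) into (v, u), i.e. set V_v = 0 and V_u = v
        tryRoot(v, u)  # recursively tries all roots in the subtree of v
        # undo: reverse edge (v, u) back into (u, v), i.e. set V_u = 0 and V_v = u

# start: fill V for the tree rooted at 1, then call tryRoot(1, 0)
```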
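The recursive `tryRoot` can recurse N levels deep on a path-shaped tree, which may overflow the call stack for large N. One possible alternative is an explicit stack. The following C++17 sketch is such a version, not taken from the official solutions; `solveBOFP` is a hypothetical name, nodes are 1..N, W is stored 0-indexed, and all sums are assumed to fit in `long long`.

```cpp
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

// Iterative rerooting. Nodes are 1..n, w[i] = W_{i+1}.
// Returns {max F(C), number of roots (= valid reorderings) reaching it}.
std::pair<long long, long long> solveBOFP(int n, const std::vector<long long>& w,
                                          const std::vector<std::pair<int, int>>& edges) {
    assert(n >= 1 && static_cast<int>(w.size()) == n - 1);
    // W[x] = W_x with W_0 = W_n = 0; t1 weights leaf x by W_x, t2 by W_{x-1}.
    std::vector<long long> W(n + 1, 0), V(n + 1, 0), t1(n + 1, 0), t2(n + 1, 0);
    for (int i = 1; i < n; ++i) W[i] = w[i - 1];
    auto add = [](std::vector<long long>& t, int p, long long d) {
        for (; p < static_cast<int>(t.size()); p += p & -p) t[p] += d;
    };
    auto prefix = [](const std::vector<long long>& t, int p) {
        long long s = 0;
        for (; p > 0; p -= p & -p) s += t[p];
        return s;
    };
    auto setParent = [&](int x, long long y) {  // V_x := y in both Fenwick trees
        const long long d = y - V[x];
        V[x] = y;
        add(t1, x, d * W[x]);
        add(t2, x, d * W[x - 1]);
    };
    std::vector<std::vector<int>> adj(n + 1);
    for (auto [a, b] : edges) {
        adj[a].push_back(b);
        adj[b].push_back(a);
    }
    // Root the tree at node 1 and store V_x = parent of x.
    std::vector<int> parent(n + 1, 0), todo = {1};
    std::vector<bool> seen(n + 1, false);
    seen[1] = true;
    while (!todo.empty()) {
        const int u = todo.back();
        todo.pop_back();
        for (int v : adj[u])
            if (!seen[v]) {
                seen[v] = true;
                parent[v] = u;
                setParent(v, u);
                todo.push_back(v);
            }
    }
    long long best = 0, count = 0;
    bool any = false;
    auto evaluate = [&](int r) {  // r is the current root, so V_r == 0
        const long long f = prefix(t1, r - 1) + (prefix(t2, n) - prefix(t2, r));
        if (!any || f > best) { best = f; count = 1; any = true; }
        else if (f == best) ++count;
    };
    evaluate(1);
    // Explicit DFS stack of (node, index of next neighbour) replaces recursion.
    std::vector<std::pair<int, std::size_t>> stk = {{1, 0}};
    while (!stk.empty()) {
        const int u = stk.back().first;
        std::size_t& i = stk.back().second;
        if (i < adj[u].size()) {
            const int v = adj[u][i++];
            if (v == parent[u]) continue;
            setParent(v, 0);  // move the root from u to v: reverse (u, v) into (v, u)
            setParent(u, v);
            evaluate(v);
            stk.push_back({v, 0});  // i may dangle after this; it is not used again
        } else {
            stk.pop_back();
            const int p = parent[u];
            if (p != 0) {  // move the root back from u to its parent p
                setParent(p, 0);
                setParent(u, p);
            }
        }
    }
    return {best, count};
}
```

On the tree (1, 2), (2, 3), (2, 4) with W = (1, 2, 3) it returns (16, 1); on the path (1, 2), (2, 3) with W = (-1, 1) it returns (1, 2), because roots 1 and 3 tie.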


This gives F(C) for every root, and hence for every valid reordering, so we can find the maximum F(C) as well as the number of reorderings that attain it.

For the implementation, it is convenient to keep a second copy of W shifted one step to the right, i.e. W'_x = W_{x-1} with W'_1 = 0. Then both trees are instances of the same segment tree, one built over W and the other over W'; refer to my code for details.
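A sketch of this trick with a bottom-up (iterative) segment tree, which is a different segment-tree style from the recursive query in the editorialist's `SegTree`; `WeightedSegTree` is an illustrative name. Each leaf pos stores value*weight[pos], so building one instance over W and another over the shifted copy W2 (W2[x] = W[x-1]) gives both sums from a single class.

```cpp
#include <cassert>
#include <vector>

// Bottom-up segment tree over positions 0..size-1 whose leaf pos stores
// value * weight[pos]. Works for any size because addition is commutative.
struct WeightedSegTree {
    int size;
    std::vector<long long> weight, sum;
    explicit WeightedSegTree(const std::vector<long long>& wt)
        : size(static_cast<int>(wt.size())), weight(wt), sum(2 * wt.size(), 0) {}
    void assign(int pos, long long value) {  // leaf pos := value * weight[pos]
        assert(0 <= pos && pos < size);
        int p = pos + size;
        sum[p] = value * weight[pos];
        for (p >>= 1; p > 0; p >>= 1) sum[p] = sum[2 * p] + sum[2 * p + 1];
    }
    long long query(int l, int r) const {  // sum over [l, r]; 0 when l > r
        long long s = 0;
        for (int lo = l + size, hi = r + size + 1; lo < hi; lo >>= 1, hi >>= 1) {
            if (lo & 1) s += sum[lo++];
            if (hi & 1) s += sum[--hi];
        }
        return s;
    }
};
```

For N = 4 rooted at 3 (V = (2, 3, 0, 2), W = (1, 2, 3)), summing [1, 2] in the tree over W and [4, 4] in the tree over W2 gives 8 + 6 = 14.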

# TIME COMPLEXITY

The time complexity is O(N*logN) per test case.

# SOLUTIONS:

Setter's Solution
```cpp
#include <bits/stdc++.h>

#define fi first
#define se second

const int N = 200200;

using namespace std;

// Fenwick tree: point add, prefix sum
struct fenwick
{
    long long t[N];

    void upd(int x, long long y)
    {
        while(x < N){
            t[x] += y;
            x += x & -x;
        }
    }
    long long get(int x)
    {
        long long res = 0;
        while(x > 0){
            res += t[x];
            x -= x & -x;
        }
        return res;
    }
};

int n;
int a[N];           // a[x] = V_x, the parent of x for the current root (0 for the root)
int w[N];
vector < int > v[N];
fenwick T1, T2;     // T1 stores a[x] * w[x], T2 stores a[x] * w[x - 1]

// set V_x = p: remove the old contribution, then add the new one (in 64-bit)
void upd(int x, int p)
{
    T1.upd(x, - 1LL * a[x] * w[x]);
    T2.upd(x, - 1LL * a[x] * w[x - 1]);

    a[x] = p;

    T1.upd(x, + 1LL * a[x] * w[x]);
    T2.upd(x, + 1LL * a[x] * w[x - 1]);
}

// root the tree at node 1 and fill a[] with parents
void dfs(int x, int p)
{
    upd(x, p);
    for(int y: v[x]){
        if(y == p){
            continue;
        }
        dfs(y, x);
    }
}

int cnt;
long long res;

// x is the current root: evaluate it, then reroot into each child and undo
void trace(int x, int p)
{
    long long cur = T1.get(x - 1) + (T2.get(n) - T2.get(x));
    if(cur > res){
        res = cur;
        cnt = 1;
    } else if(cur == res){
        cnt += 1;
    }

    for(int y: v[x]){
        if(y == p){
            continue;
        }
        upd(y, 0);
        upd(x, y);
        trace(y, x);
        upd(x, 0);
        upd(y, x);
    }
}

void solve()
{
    cin >> n;

    for(int i = 1; i < N; i++){
        T1.t[i] = T2.t[i] = 0;
        v[i].clear();
        a[i] = 0;
    }

    for(int i = 1; i < n; i++){
        cin >> w[i];
    }
    for(int i = 1; i < n; i++){
        int x, y;
        cin >> x >> y;
        v[x].push_back(y);
        v[y].push_back(x);
    }

    cnt = 0;
    res = -3e18;

    dfs(1, 0);
    trace(1, 0);

    cout << res << " " << cnt << "\n";
}

int main()
{
    ios_base::sync_with_stdio(0);

    int T;
    cin >> T;
    while(T--){
        solve();
    }
}
```

Tester's Solution
//teja349
#include <bits/stdc++.h>
#include <vector>
#include <set>
#include <map>
#include <string>
#include <cstdio>
#include <cstdlib>
#include <climits>
#include <utility>
#include <algorithm>
#include <cmath>
#include <queue>
#include <stack>
#include <iomanip>
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>
//setbase - cout << setbase (16); cout << 100 << endl; Prints 64
//setfill -   cout << setfill ('x') << setw (5); cout << 77 << endl; prints xxx77
//setprecision - cout << setprecision (14) << f << endl; Prints x.xxxx
//cout.precision(x)  cout<<fixed<<val;  // prints x digits after decimal in val

using namespace std;
using namespace __gnu_pbds;

#define f(i,a,b) for(i=a;i<b;i++)
#define rep(i,n) f(i,0,n)
#define fd(i,a,b) for(i=a;i>=b;i--)
#define pb push_back
#define mp make_pair
#define vi vector< int >
#define vl vector< ll >
#define ss second
#define ff first
#define ll long long
#define pii pair< int,int >
#define pll pair< ll,ll >
#define sz(a) a.size()
#define inf (1000*1000*1000+5)
#define all(a) a.begin(),a.end()
#define tri pair<int,pii>
#define vii vector<pii>
#define vll vector<pll>
#define viii vector<tri>
#define mod (1000*1000*1000+7)
#define pqueue priority_queue< int >
#define pdqueue priority_queue< int,vi ,greater< int > >
#define flush fflush(stdout)
#define primeDEN 727999983

// find_by_order()  // order_of_key
typedef tree<
int,
null_type,
less<int>,
rb_tree_tag,
tree_order_statistics_node_update>
ordered_set;

#define int ll

int intim,outim;
int paren;

int timer=3;
vector<vi> child(212345);

int dfs(int cur,int par){
int i;
intim[cur]=timer++;
paren[cur]=par;
if(par!=-1){
child[par].pb(cur);
}
}
}
outim[cur]=timer++;
return 0;
}

int bit;

int update(int typ,int pos,int val){
typ--;
while(pos<timer+5){
bit[typ][pos]+=val;
pos+=pos&(-pos);
}
return 0;
}

int query(int typ,int pos){
typ--;
int ans=0;
while(pos>0){
ans+=bit[typ][pos];
pos-=pos&(-pos);
}
return ans;
}
int w,a,b;

main(){
std::ios::sync_with_stdio(false); cin.tie(NULL);
int t;
cin>>t;
while(t--){
timer=3;
int n;
cin>>n;
int i,j;
f(i,1,n){
cin>>w[i];
}
rep(i,n+10){
child[i].clear();
}
rep(i,n-1){
cin>>a[i]>>b[i];
}
dfs(0,-1);
vii vec;
f(i,1,n+1){
vec.pb(mp(i,paren[i]));
}
sort(all(vec));
int ans=0;
f(i,2,n+1){
//cout<<
ans+=paren[i]*w[i-1];
}
f(i,2,n+1){
update(1,intim[i],paren[i]*w[i-1]);
update(1,outim[i],-1*paren[i]*w[i-1]);
}
f(i,1,n+1){
if(paren[i]==1)
continue;
update(2,intim[i],i*w[paren[i]-1]);
update(2,outim[i],-1*i*w[paren[i]-1]);
}
int maxi=ans;
int cnt=1,val;
//cout<<maxi<<endl;
f(i,2,n+1){
ans-=paren[i]*w[i-1];
ans+=paren[i-1]*w[i-1];
update(1,intim[i],-1*paren[i]*w[i-1]);
update(1,outim[i],paren[i]*w[i-1]);

update(1,intim[i-1],paren[i-1]*w[i-1]);
update(1,outim[i-1],-1*paren[i-1]*w[i-1]);

rep(j,child[i].size()){
update(2,intim[child[i][j]],-1*child[i][j]*w[i-1]);
update(2,outim[child[i][j]],child[i][j]*w[i-1]);
}

rep(j,child[i-1].size()){
update(2,intim[child[i-1][j]],child[i-1][j]*w[i-1]);
update(2,outim[child[i-1][j]],-1*child[i-1][j]*w[i-1]);
}
val=ans-query(1,intim[paren[i]])+query(2,intim[i]);
if(val>maxi){
maxi=val;
cnt=1;
}
else if(val==maxi){
cnt++;
}
}
cout<<maxi<<" "<<cnt<<endl;
rep(i,timer+10){
bit[i]=0;
bit[i]=0;
}
}
return 0;
}

Editorialist's Solution
import java.util.*;
import java.io.*;
import java.text.*;
class BOFP{
//SOLUTION BEGIN
void pre() throws Exception{}
void solve(int TC) throws Exception{
int n = ni();
long[] w1 = new long[1+n], w2 = new long[1+n];
for(int i = 1; i< n; i++)w1[i] = w2[i+1] = nl();
SegTree t1 = new SegTree(n+1, w1), t2 = new SegTree(n+1, w2);
int[][] e = new int[n-1][];
for(int i = 0; i< n-1; i++)e[i] = new int[]{ni(), ni()};
int[][] g = makeU(n+1, e);
dfs(g, t1, t2, 1, 0);
ans = (long)-1e18;count = 0;
calc(g, t1, t2, n, 1, 0);
pn(ans+" "+count);
}
long ans,count;
void calc(int[][] g, SegTree t1, SegTree t2, int n, int u, int p){
long cur = 0;
if(u > 1)cur += t1.query(1, u-1);
if(u < n)cur += t2.query(u+1, n);
if(cur > ans){
ans = cur;count = 0;
}
if(cur == ans)count++;

for(int v:g[u]){
if(v == p)continue;
t1.assign(v, 0);
t2.assign(v, 0);
t1.assign(u, v);
t2.assign(u, v);
calc(g, t1, t2, n, v, u);
t1.assign(u, 0);
t2.assign(u, 0);
t1.assign(v, u);
t2.assign(v, u);
}
}
void dfs(int[][] g, SegTree t1, SegTree t2, int u, int p){
for(int v:g[u]){
if(v == p)continue;
t1.assign(v, u);
t2.assign(v, u);
dfs(g, t1, t2, v, u);
}
}
class SegTree{
int m = 1;
long[] w;
long[] sum;
public SegTree(int n, long[] W){
while(m<= n)m<<=1;
w = W;
sum = new long[m<<1];
}
void reset(int p){assign(p, 0);}
void assign(int pos, long x){
int p = pos+m;
sum[p] = x*w[pos];
for(p>>=1; p>0; p>>=1)sum[p] = sum[p<<1]+sum[p<<1|1];
}
long query(int l, int r){return query(l, r, 0, m-1, 1);}
long query(int l, int r, int ll, int rr, int i){
if(l == ll && r == rr)return sum[i];
int mid = (ll+rr)/2;
if(r <= mid)return query(l, r, ll, mid, i<<1);
else if(l > mid)return query(l, r, mid+1, rr, i<<1|1);
else return query(l, mid, ll, mid, i<<1)+query(mid+1, r, mid+1, rr, i<<1|1);
}
}
int[][] makeU(int n, int[][] edge){
int[][] g = new int[n][];int[] cnt = new int[n];
for(int i = 0; i< edge.length; i++){cnt[edge[i][0]]++;cnt[edge[i][1]]++;}
for(int i = 0; i< n; i++)g[i] = new int[cnt[i]];
for(int i = 0; i< edge.length; i++){
g[edge[i][0]][--cnt[edge[i][0]]] = edge[i][1];
g[edge[i][1]][--cnt[edge[i][1]]] = edge[i][0];
}
return g;
}
//SOLUTION END
void hold(boolean b)throws Exception{if(!b)throw new Exception("Hold right there, Sparky!");}
DecimalFormat df = new DecimalFormat("0.00000000000");
static boolean multipleTC = true;
void run() throws Exception{
out = new PrintWriter(System.out);
//Solution Credits: Taranpreet Singh
int T = (multipleTC)?ni():1;
pre();for(int t = 1; t<= T; t++)solve(t);
out.flush();
out.close();
}
public static void main(String[] args) throws Exception{
new BOFP().run();
}
int bit(long n){return (n==0)?0:(1+bit(n&(n-1)));}
void p(Object o){out.print(o);}
void pn(Object o){out.println(o);}
void pni(Object o){out.println(o);out.flush();}
String n()throws Exception{return in.next();}
String nln()throws Exception{return in.nextLine();}
int ni()throws Exception{return Integer.parseInt(in.next());}
long nl()throws Exception{return Long.parseLong(in.next());}
double nd()throws Exception{return Double.parseDouble(in.next());}

StringTokenizer st;
}

}

String next() throws Exception{
while (st == null || !st.hasMoreElements()){
try{
}catch (IOException  e){
throw new Exception(e.toString());
}
}
return st.nextToken();
}

String nextLine() throws Exception{
String str = "";
try{

Feel free to share your approach, even if it is the same. Suggestions are welcome, as always.