TREE_GAME - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Authors: krypto_ray, gunpoint_88, shubham_grg
Testers: iceknight1093, tabr
Editorialist: iceknight1093

DIFFICULTY:

TBD

PREREQUISITES:

DFS

PROBLEM:

There’s a tree on N vertices. Alice and Bob start at vertices A and B of this tree and play a game.
The i-th vertex has a population of P_i.

Each turn proceeds as follows:

  • If either Alice or Bob cannot move, the game ends.
  • Otherwise, Alice moves to an adjacent vertex she has not visited yet, and then Bob moves to an adjacent vertex he has not visited yet.
  • Alice receives a point if the total population of the vertices visited by her so far exceeds the total population of the vertices visited by Bob.

Alice moves to maximize her score, while Bob moves to minimize it.
Find Alice’s final score.

EXPLANATION:

Let’s define a state of the game as a pair (x, y), denoting that Alice is at vertex x and Bob is at y.
The initial state of the game is (A, B).

Let SA_u denote the sum of populations on the A\to u path, and SB_u similarly denote the sum of populations on the B\to u path.
These can be precomputed with DFS.
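As an illustration, both arrays can be filled by rooting the tree once at A and once at B. This is a minimal Python sketch (Python to match the editorialist's code), assuming 0-indexed vertices and an adjacency list adj; path_sums is a hypothetical helper name. It uses an explicit stack instead of recursion so deep trees don't hit Python's recursion limit, and it also records parent pointers, which the transitions below rely on.

```python
def path_sums(adj, pop, root):
    """Root the tree at `root` with an iterative DFS.

    Returns (par, pref): par[v] is v's parent (-1 for the root) and
    pref[v] is the total population on the root -> v path.
    """
    n = len(adj)
    par = [-1] * n
    pref = [0] * n
    pref[root] = pop[root]
    stack = [root]
    while stack:
        u = stack.pop()
        for v in adj[u]:
            if v == par[u]:
                continue  # don't walk back up to the parent
            par[v] = u
            pref[v] = pref[u] + pop[v]
            stack.append(v)
    return par, pref


# Example tree: edges 0-1, 1-2, 1-3 (0-indexed), populations 5, 1, 2, 7.
adj = [[1], [0, 2, 3], [1], [1]]
pop = [5, 1, 2, 7]
parA, SA = path_sums(adj, pop, 0)  # Alice starts at vertex 0
parB, SB = path_sums(adj, pop, 2)  # Bob starts at vertex 2
print(SA)  # [5, 6, 8, 13]
print(SB)  # [8, 3, 2, 10]
```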

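Notice that a state of the game (x, y) uniquely defines both Alice’s and Bob’s paths so far, and hence their scores so far: since neither player may revisit a vertex, Alice must have walked the unique A\to x path in the tree, and Bob the unique B\to y path.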
So, for each state, it suffices for us to find the best move Alice can make.

Let f(x, y) denote Alice’s best score if the game starts at state (x, y).
Our objective is to compute f(A, B).

From a state (x, y), the next state can be any (u, v) such that:

  • u is a neighbor of x and v is a neighbor of y
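  • u doesn’t lie on the A\to x path and v doesn’t lie on the B\to y path. Since u is adjacent to x, this just means u is not x’s parent when the tree is rooted at A; similarly, v is not y’s parent when the tree is rooted at B.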
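In code, these two conditions only need the parent arrays of the tree rooted at A and of the tree rooted at B. A small Python sketch, assuming 0-indexed vertices; next_states and the hard-coded example arrays are hypothetical, for illustration only.

```python
def next_states(x, y, adj, parA, parB):
    """All states (u, v) reachable from (x, y) in one turn.

    parA / parB are parent arrays of the tree rooted at A / at B.
    A neighbour of x lies on the A -> x path only if it is x's parent,
    so "not on the path" reduces to "not the parent" (same for Bob).
    """
    alice = [u for u in adj[x] if u != parA[x]]
    bob = [v for v in adj[y] if v != parB[y]]
    return [(u, v) for u in alice for v in bob]


# Example tree: edges 0-1, 1-2, 1-3 (0-indexed), with A = 0 and B = 2.
adj = [[1], [0, 2, 3], [1], [1]]
parA = [-1, 0, 1, 1]  # tree rooted at A = 0
parB = [1, 2, -1, 1]  # tree rooted at B = 2
print(next_states(0, 2, adj, parA, parB))  # [(1, 1)]
print(next_states(1, 1, adj, parA, parB))  # [(2, 0), (2, 3), (3, 0), (3, 3)]
```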

In particular, we can see that:

  • If Alice fixes her choice of u, then Bob will choose v such that f(u, v) is minimized.
  • So, across all possible choices of u, to maximize her own score, Alice will choose the u such that \min_v f(u, v) is maximized.

Rewriting this in terms of f(x, y), we have

f(x, y) = \max_u(\min_v f(u, v)) + (SA_x \gt SB_y)

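where the choice is across all valid neighbors u and v of x and y, and (SA_x \gt SB_y) counts as 1 if the inequality holds and 0 otherwise.
If either player has no valid move from (x, y), the game ends there and f(x, y) is simply (SA_x \gt SB_y).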
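The recurrence translates almost line for line into code. Below is a minimal Python sketch, assuming 0-indexed vertices and that the parent arrays and path sums are already known (hard-coded here for a small example tree); the name f and its argument list are illustrative rather than taken from the reference solutions. It uses plain recursion with no cache, so it is only meant for small inputs.

```python
def f(x, y, adj, parA, parB, SA, SB):
    """Alice's best score from state (x, y):
    f(x, y) = max_u min_v f(u, v) + (SA[x] > SB[y]),
    with f(x, y) = (SA[x] > SB[y]) when either player cannot move."""
    gain = 1 if SA[x] > SB[y] else 0
    alice = [u for u in adj[x] if u != parA[x]]  # Alice's valid moves
    bob = [v for v in adj[y] if v != parB[y]]    # Bob's valid moves
    if not alice or not bob:
        return gain  # the game ends in this state
    # Bob minimises over v for each fixed u; Alice maximises over u.
    best = max(min(f(u, v, adj, parA, parB, SA, SB) for v in bob) for u in alice)
    return gain + best


# Example: edges 0-1, 1-2, 1-3, populations 5, 1, 2, 7, A = 0, B = 2.
adj = [[1], [0, 2, 3], [1], [1]]
parA, SA = [-1, 0, 1, 1], [5, 6, 8, 13]
parB, SB = [1, 2, -1, 1], [8, 3, 2, 10]
print(f(0, 2, adj, parA, parB, SA, SB))  # 3
```

In this example Alice is ahead at the start (5 > 2), both players are forced onto vertex 1 (6 > 3), and Alice then moves to vertex 3 (sum 13), which beats both of Bob's options (8 or 10), for a score of 3.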

The ‘brute force’ method of computing this is, of course, to just iterate across all neighbors u and v of x and y and recursively compute their f(u, v) values.
Let’s say we also cache the values of f in a 2D array so that states aren’t recomputed.

Let’s analyze the time complexity of this.

  • There are \mathcal{O}(N^2) possible states. Not all of them are necessarily reachable, but the number of reachable ones can definitely be \Theta(N^2).
  • For each state, we do \mathcal{O}(N^2) work by iterating across all pairs of neighbors, giving us a total complexity of \mathcal{O}(N^4).

However, we can do a better analysis!
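Notice that each transition essentially considers a pair of edges: the edge from x to u that Alice moves along, and the edge from y to v that Bob moves along.
Two edges u_1 \leftrightarrow v_1 and u_2 \leftrightarrow v_2 can only be considered against each other when the state is one of (u_1, u_2), (u_1, v_2), (v_1, u_2), (v_1, v_2). Since (with caching) each state is processed only once, each pair of edges is examined at most four times.
This is a tree, so there are only (N-1) edges.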

This means the total number of transitions we make, across all states, is bounded by 4(N-1)^2.

In other words, our ‘brute force’ algorithm is really \mathcal{O}(N^2), and is already fast enough!
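To sanity-check the 4(N-1)^2 bound empirically, here is a small Python experiment. It is a sketch under stated assumptions: vertices are 0-indexed, test trees come from the hypothetical helper random_tree, and the brute force's work is measured by charging deg(x)·deg(y) per reachable state, i.e. one unit per pair of incident edges the double loop can scan (parent edges included). Summed over all pairs (x, y), that charge is (\sum_x deg(x))^2 = (2(N-1))^2 = 4(N-1)^2, so the measured work can never exceed the bound.

```python
import random


def random_tree(n, seed):
    """Random tree on vertices 0..n-1: each vertex v >= 1 is attached to a
    uniformly random earlier vertex."""
    rng = random.Random(seed)
    adj = [[] for _ in range(n)]
    for v in range(1, n):
        p = rng.randrange(v)
        adj[v].append(p)
        adj[p].append(v)
    return adj


def parents(adj, root):
    """Parent array of the tree rooted at `root` (-1 for the root)."""
    par = [-1] * len(adj)
    stack = [root]
    while stack:
        u = stack.pop()
        for v in adj[u]:
            if v != par[u]:
                par[v] = u
                stack.append(v)
    return par


def loop_work(adj, a, b):
    """Explore every reachable state from (a, b), charging deg(x) * deg(y)
    per state: one unit for each (edge at x, edge at y) pair the brute-force
    double loop can look at, parent edges included."""
    parA, parB = parents(adj, a), parents(adj, b)
    work = 0
    stack = [(a, b)]
    while stack:
        x, y = stack.pop()
        work += len(adj[x]) * len(adj[y])
        for u in adj[x]:
            if u == parA[x]:
                continue
            for v in adj[y]:
                if v != parB[y]:
                    stack.append((u, v))
    return work


for n in (1, 2, 10, 200):
    adj = random_tree(n, seed=n)
    print(n, loop_work(adj, 0, n - 1), 4 * (n - 1) ** 2)
```

A star is a useful extreme case: with both players at the center, the first state alone costs (N-1)^2, and the (N-1)^2 leaf-leaf states cost 1 each.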

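You might notice that even caching the values of f(x, y) is unnecessary: the only state that can move into (x, y) is the one where Alice is at x’s parent (with the tree rooted at A) and Bob is at y’s parent (with the tree rooted at B), so each state is going to be visited at most once anyway.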
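A quick way to convince yourself that caching really is unnecessary is to run the brute force with no cache and count how many times each state is entered. The Python sketch below does exactly that; brute_force_visits is a hypothetical helper name, vertices are 0-indexed, and the plain recursion limits it to small trees.

```python
from collections import Counter


def brute_force_visits(adj, pop, a, b):
    """Evaluate f(a, b) with no cache at all and count how many times each
    state (x, y) is entered. Plain recursion, so keep the tree small."""
    n = len(adj)

    def root_at(r):
        # parent array and root -> v population sums for the tree rooted at r
        par, pref = [-1] * n, [0] * n
        pref[r] = pop[r]
        stack = [r]
        while stack:
            u = stack.pop()
            for v in adj[u]:
                if v != par[u]:
                    par[v], pref[v] = u, pref[u] + pop[v]
                    stack.append(v)
        return par, pref

    (parA, SA), (parB, SB) = root_at(a), root_at(b)
    visits = Counter()

    def f(x, y):
        visits[(x, y)] += 1
        gain = 1 if SA[x] > SB[y] else 0
        alice = [u for u in adj[x] if u != parA[x]]
        bob = [v for v in adj[y] if v != parB[y]]
        if not alice or not bob:
            return gain
        return gain + max(min(f(u, v) for v in bob) for u in alice)

    return f(a, b), visits


adj = [[1], [0, 2, 3], [1], [1]]  # edges 0-1, 1-2, 1-3
pop = [5, 1, 2, 7]
answer, visits = brute_force_visits(adj, pop, 0, 2)
print(answer)                 # 3
print(max(visits.values()))   # 1: no state is entered twice
```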

TIME COMPLEXITY:

\mathcal{O}(N^2) per test case.

CODE:

Setter's code (C++)
#include<bits/stdc++.h>
using namespace std;
using ll=long long;
const ll inf=1e16;

#ifdef ANI
#include "D:/DUSTBIN/local_inc.h"
#else
#define dbg(...) 0
#endif

void solve(int &tot) {
    ll n,x,y;
    cin>>n>>x>>y;
    assert(n<=5000 && x<=n && y<=n);
    x--;y--;tot+=n;
    vector<ll> a(n);
    ll nax=1e9;
    for(ll i=0;i<n;i++) {
    	cin>>a[i];
    	assert(a[i]<=nax && a[i]>=1);
    }
    vector<vector<ll>> e(n);
    for(ll i=0;i<n-1;i++) {
    	ll u,v;
    	cin>>u>>v;
        e[u-1].push_back(v-1);
        e[v-1].push_back(u-1);
        assert(u<=n && v<=n && u>=1 && v>=1 && u!=v);
    }

    auto dfs=[&](ll u,ll v,ll su,ll sv,ll pu,ll pv,ll score,auto &&dfs)->ll{ // compute game states
        su+=a[u],sv+=a[v]; score+=su>sv;
        if((e[u].size()==1&&u!=x)||(e[v].size()==1&&v!=y)) 
            return score;
        ll res=0;
        for(ll p:e[u]) {
            if(p==pu) continue;
            ll cur=inf;
            for(ll q:e[v]) {
                if(q!=pv) 
                    cur=min(cur,dfs(p,q,su,sv,u,v,score,dfs));
            }
            res=max(res,cur);
        }
        return res;
    };
    cout<<dfs(x,y,0,0,-1,-1,0,dfs)<<"\n";
}

int main() {
	ios_base::sync_with_stdio(false);cin.tie(NULL);cout.tie(NULL);
	int t;
	cin>>t;
	assert(t<=1000);
	int tot=0;
	while(t--) {
		solve(tot);
	}	
	assert(tot<=5000);
}
Tester's code (C++)
#include <bits/stdc++.h>
using namespace std;
#ifdef tabr
#include "library/debug.cpp"
#else
#define debug(...)
#endif

struct input_checker {
    string buffer;
    int pos;

    const string all = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
    const string number = "0123456789";
    const string upper = "ABCDEFGHIJKLMNOPQRSTUVWXYZ";
    const string lower = "abcdefghijklmnopqrstuvwxyz";

    input_checker() {
        pos = 0;
        while (true) {
            int c = cin.get();
            if (c == -1) {
                break;
            }
            buffer.push_back((char) c);
        }
    }

    string readOne() {
        assert(pos < (int) buffer.size());
        string res;
        while (pos < (int) buffer.size() && buffer[pos] != ' ' && buffer[pos] != '\n') {
            res += buffer[pos];
            pos++;
        }
        return res;
    }

    string readString(int min_len, int max_len, const string& pattern = "") {
        assert(min_len <= max_len);
        string res = readOne();
        assert(min_len <= (int) res.size());
        assert((int) res.size() <= max_len);
        for (int i = 0; i < (int) res.size(); i++) {
            assert(pattern.empty() || pattern.find(res[i]) != string::npos);
        }
        return res;
    }

    int readInt(int min_val, int max_val) {
        assert(min_val <= max_val);
        int res = stoi(readOne());
        assert(min_val <= res);
        assert(res <= max_val);
        return res;
    }

    long long readLong(long long min_val, long long max_val) {
        assert(min_val <= max_val);
        long long res = stoll(readOne());
        assert(min_val <= res);
        assert(res <= max_val);
        return res;
    }

    vector<int> readInts(int size, int min_val, int max_val) {
        assert(min_val <= max_val);
        vector<int> res(size);
        for (int i = 0; i < size; i++) {
            res[i] = readInt(min_val, max_val);
            if (i != size - 1) {
                readSpace();
            }
        }
        return res;
    }

    vector<long long> readLongs(int size, long long min_val, long long max_val) {
        assert(min_val <= max_val);
        vector<long long> res(size);
        for (int i = 0; i < size; i++) {
            res[i] = readLong(min_val, max_val);
            if (i != size - 1) {
                readSpace();
            }
        }
        return res;
    }

    void readSpace() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == ' ');
        pos++;
    }

    void readEoln() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == '\n');
        pos++;
    }

    void readEof() {
        assert((int) buffer.size() == pos);
    }
};

struct dsu {
    vector<int> p;
    vector<int> sz;
    int n;

    dsu(int _n) : n(_n) {
        p = vector<int>(n);
        iota(p.begin(), p.end(), 0);
        sz = vector<int>(n, 1);
    }

    inline int get(int x) {
        if (p[x] == x) {
            return x;
        } else {
            return p[x] = get(p[x]);
        }
    }

    inline bool unite(int x, int y) {
        x = get(x);
        y = get(y);
        if (x == y) {
            return false;
        }
        p[x] = y;
        sz[y] += sz[x];
        return true;
    }

    inline bool same(int x, int y) {
        return (get(x) == get(y));
    }

    inline int size(int x) {
        return sz[get(x)];
    }

    inline bool root(int x) {
        return (x == get(x));
    }
};

int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    input_checker in;
    int tt = in.readInt(1, 1000);
    in.readEoln();
    int sn = 0;
    while (tt--) {
        int n = in.readInt(1, 5000);
        sn += n;
        in.readSpace();
        int a = in.readInt(1, n);
        in.readSpace();
        int b = in.readInt(1, n);
        in.readEoln();
        a--;
        b--;
        vector<long long> p = in.readLongs(n, 1, 1e9);
        in.readEoln();
        vector<vector<int>> g(n);
        dsu uf(n);
        for (int i = 0; i < n - 1; i++) {
            int x = in.readInt(1, n);
            in.readSpace();
            int y = in.readInt(1, n);
            in.readEoln();
            x--;
            y--;
            uf.unite(x, y);
            g[x].emplace_back(y);
            g[y].emplace_back(x);
        }
        assert(uf.size(0) == n);
        function<int(int, int, int, int, long long, long long)> Dfs = [&](int va, int vb, int pa, int pb, long long ca, long long cb) {
            int res = 0;
            for (int toa : g[va]) {
                if (toa == pa) {
                    continue;
                }
                long long da = ca + p[toa];
                int t = 1e9;
                int s = 0;
                for (int tob : g[vb]) {
                    if (tob == pb) {
                        continue;
                    }
                    s = 1;
                    long long db = cb + p[tob];
                    t = min(t, (da > db) + Dfs(toa, tob, va, vb, da, db));
                }
                res = max(res, t * s);
            }
            return res;
        };
        cout << (p[a] > p[b]) + Dfs(a, b, -1, -1, p[a], p[b]) << '\n';
    }
    assert(sn <= 5000);
    in.readEof();
    return 0;
}
Editorialist's code (Python)
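import sys
sys.setrecursionlimit(10 ** 4)  # go() recurses up to ~N levels deep (e.g. on a path), N <= 5000

def bfs(adj, par, pref, val, src):
	# Root the tree at src: par[v] = parent of v, pref[v] = population sum on the src -> v path
	par[src] = -1
	pref[src] = val[src]
	vertices = [src]
	for u in vertices:
		for v in adj[u]:
			if par[u] == v: continue
			par[v] = u
			pref[v] = pref[u] + val[v]
			vertices.append(v)

for _ in range(int(input())):
	n, a, b = map(int, input().split())
	val = list(map(int, input().split()))
	adj = [[] for _ in range(n)]
	
	for i in range(n-1):
		x, y = map(int, input().split())
		adj[x-1].append(y-1)
		adj[y-1].append(x-1)
	
	# prefA[v] = SA_v and prefB[v] = SB_v
	parA, parB = [0]*n, [0]*n
	prefA, prefB = [0]*n, [0]*n
	bfs(adj, parA, prefA, val, a-1)
	bfs(adj, parB, prefB, val, b-1)

	def go(x, y):
		# f(x, y): Alice's best score from state (x, y)
		add = 0
		if prefA[x] > prefB[y]: add = 1

		mx = 0
		for u in adj[x]:
			if parA[x] == u: continue  # Alice can't step back onto her own path
			mn = 10 ** 6  # sentinel larger than any possible score
			for v in adj[y]:
				if parB[y] == v: continue  # same restriction for Bob
				mn = min(mn, go(u, v))
			if mn == 10 ** 6: mn = 0  # Bob has no move: the game ends here
			mx = max(mx, mn)

		return add + mx
	
	print(go(a-1, b-1))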

Excellent problem set. :slight_smile:


@iceknight1093,

and SB[y] similarly denote the sum of populations on the B→y path.

For the A → x path and the B → y path, this will make more sense, since we are going to state (x, y) from state (A, B).

Explanation section, 2nd paragraph.

The x in SA_x is just a variable, its name doesn’t really matter.
I’ll change it to u if that makes it more clear.

@iceknight1093 no doubt you’re great, but this line was not clear to me. Any help?

Do you get the fact that each transition corresponds to a pair of edges?

Once you have that, this fact becomes obvious.
If you have two edges u_1 \leftrightarrow v_1 and u_2 \leftrightarrow v_2, the only time they’ll be considered against each other is when your state is one of (u_1, u_2), (u_1, v_2), (v_1, u_2), (v_1, v_2) right?
After all, you need to be at an endpoint of both edges for them to be considered against each other, and there are only four possible such pairs.


Great.
Last question: how do we make sure that we don’t need to memoize the states?

You’re basically performing two simultaneous DFS-es on the tree, starting from nodes A and B respectively, right?
Since it’s a tree, how many times will you visit a vertex from its parent during this DFS?


Exactly once.
Thanks.