PROBLEM LINK:
Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4
Authors: krypto_ray, gunpoint_88, shubham_grg
Testers: iceknight1093, tabr
Editorialist: iceknight1093
DIFFICULTY:
TBD
PREREQUISITES:
DFS
PROBLEM:
There’s a tree on N vertices. Alice and Bob start at vertices A and B of this tree and play a game.
The i-th vertex has a population of P_i.
Each turn proceeds as follows:
- If either Alice or Bob cannot move, the game ends.
- Otherwise, Alice moves to an unvisited (by her) vertex, and then Bob moves to an unvisited (by him) vertex.
- Alice receives a point if the total population of the vertices visited by her so far exceeds the total population of the vertices visited by Bob.
Alice moves to maximize her score, while Bob moves to minimize it.
Find Alice’s final score.
EXPLANATION:
Let’s define a state of the game as a pair (x, y), denoting that Alice is at vertex x and Bob is at y.
The initial state of the game is (A, B).
Let SA_u denote the sum of populations on the A\to u path, and SB_u similarly denote the sum of populations on the B\to u path.
These can be precomputed with DFS.
Notice that a state of the game (x, y) uniquely defines both Alice’s and Bob’s paths so far, and hence their scores so far.
So, for each state, it suffices for us to find the best move Alice can make.
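Let f(x, y) denote Alice’s best score if the game starts at state (x, y), including the point (if any) awarded for the state (x, y) itself — all the solutions below also count the starting position (A, B) this way.
Our objective is to compute f(A, B).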
From a state (x, y), the next state can be any (u, v) such that:
- u is a neighbor of x and v is a neighbor of y
- u doesn’t lie on the A-x path and v doesn’t lie on the B-y path.
In particular, we can see that:
- If Alice fixes her choice of u, then Bob will choose v such that f(u, v) is minimized.
- So, across all possible choices of u, to maximize her own score, Alice will choose the u such that \min_v f(u, v) is maximized.
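Rewriting this in terms of f(x, y), we have
f(x, y) = [SA_x > SB_y] + \max_{u} \min_{v} f(u, v),
where [SA_x > SB_y] is 1 if SA_x > SB_y and 0 otherwise (the point for the current state), and the max and min range over all valid neighbors u of x and v of y, respectively. If either Alice or Bob has no valid move, the game ends there and the \max\min term is taken to be 0.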
The ‘brute force’ method of computing this is, of course, to just iterate across all neighbors u and v of x and y and recursively compute their f(u, v) values.
Let’s say we also cache the values of f in a 2D array so that states aren’t recomputed.
Let’s analyze the time complexity of this.
- There are \mathcal{O}(N^2) possible states. Not all of them are necessarily reachable, but the number of reachable ones can definitely be \Theta(N^2).
- For each state, we do \mathcal{O}(N^2) work by iterating across all pairs of neighbors, giving us a total complexity of \mathcal{O}(N^4).
However, we can do a better analysis!
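Notice that each transition (x, y) \to (u, v) uses a pair of edges: the edge x-u for Alice and the edge y-v for Bob.
Since the state (u, v) can only be reached from a single state (Alice at the neighbor of u on the A-u path, Bob at the neighbor of v on the B-v path), each transition is made at most once. Also, a fixed pair of edges can appear as a transition in at most four ways, one for each choice of which endpoint of each edge is the ‘from’ vertex.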
This is a tree, so there are only (N-1) edges.
This means the total number of transitions we make, across all states, is bounded by 4(N-1)^2.
In other words, our ‘brute force’ algorithm is really \mathcal{O}(N^2), and is already fast enough!
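You might notice that even caching the values of f(x, y) is unnecessary. The only state that can move to (u, v) is the one where Alice is at the neighbor of u on the A-u path and Bob is at the neighbor of v on the B-v path, so each state is visited at most once anyway.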
TIME COMPLEXITY
\mathcal{O}(N^2) per test case.
CODE:
Setter's code (C++)
#include<bits/stdc++.h>
using namespace std;
using ll=long long;
const ll inf=1e16;
#ifdef ANI
#include "D:/DUSTBIN/local_inc.h"
#else
#define dbg(...) 0
#endif
// Reads one test case from stdin and prints Alice's final score.
// `tot` accumulates N over all test cases so the caller can check the sum-of-N limit.
void solve(int &tot) {
    ll n,x,y;
    cin>>n>>x>>y;
    // Validate n, x, y before x/y become 0-based indices used below.
    assert(1<=n && n<=5000);
    assert(1<=x && x<=n && 1<=y && y<=n);
    x--;y--;tot+=n;
    vector<ll> a(n);
    const ll nax=1e9;
    for(ll i=0;i<n;i++) {
        cin>>a[i];
        assert(a[i]<=nax && a[i]>=1);
    }
    vector<vector<ll>> e(n);
    for(ll i=0;i<n-1;i++) {
        ll u,v;
        cin>>u>>v;
        // Check the endpoints *before* they are used as indices into e.
        assert(u<=n && v<=n && u>=1 && v>=1 && u!=v);
        e[u-1].push_back(v-1);
        e[v-1].push_back(u-1);
    }
    // Computes the game states: Alice is at u (came from pu), Bob is at v (came from pv),
    // su/sv are the population sums collected before entering u/v, and score is the number
    // of points Alice earned before this state. Returns Alice's final score under optimal play.
    // The point for the current state is added first. Along a path in a tree every neighbour
    // except the parent is unvisited, so a non-start vertex of degree 1 means that player is stuck.
    auto dfs=[&](ll u,ll v,ll su,ll sv,ll pu,ll pv,ll score,auto &&dfs)->ll{
        su+=a[u],sv+=a[v];
        score+=su>sv;
        if((e[u].size()==1&&u!=x)||(e[v].size()==1&&v!=y))
            return score;  // one player cannot move: the game ends
        ll res=0;
        for(ll p:e[u]) {  // Alice picks the move maximizing ...
            if(p==pu) continue;
            ll cur=inf;
            for(ll q:e[v]) {  // ... Bob's minimizing reply
                if(q!=pv)
                    cur=min(cur,dfs(p,q,su,sv,u,v,score,dfs));
            }
            res=max(res,cur);
        }
        return res;
    };
    cout<<dfs(x,y,0,0,-1,-1,0,dfs)<<"\n";
}
// Entry point: reads the number of test cases and runs solve() on each one.
// It asserts the global limits: at most 1000 test cases, and at most 5000 for the
// running total that solve() accumulates.
int main() {
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);
    cout.tie(nullptr);
    int testCases;
    cin >> testCases;
    assert(testCases <= 1000);
    int sumN = 0;
    for (; testCases != 0; --testCases) {
        solve(sumN);
    }
    assert(sumN <= 5000);
}
Tester's code (C++)
#include <bits/stdc++.h>
using namespace std;
#ifdef tabr
#include "library/debug.cpp"
#else
#define debug(...)
#endif
// Strict input validator. The constructor reads all of stdin into `buffer`. Each
// read*() call then consumes exactly one expected token or separator and asserts
// that it has the requested format and range.
struct input_checker {
    string buffer;  // the whole of stdin
    int pos;        // index of the next unread character in buffer
    const string all = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
    const string number = "0123456789";
    const string upper = "ABCDEFGHIJKLMNOPQRSTUVWXYZ";
    const string lower = "abcdefghijklmnopqrstuvwxyz";
    input_checker() {
        pos = 0;
        while (true) {
            int c = cin.get();
            if (c == -1) {
                break;
            }
            buffer.push_back((char) c);
        }
    }
    // Returns the characters from pos up to (not including) the next ' ' or '\n'.
    // The result may be empty. Asserts that at least one unread character exists.
    string readOne() {
        assert(pos < (int) buffer.size());
        string res;
        while (pos < (int) buffer.size() && buffer[pos] != ' ' && buffer[pos] != '\n') {
            res += buffer[pos];
            pos++;
        }
        return res;
    }
    // Reads one token whose length lies in [min_len, max_len]. If `pattern` is
    // non-empty, every character of the token must belong to `pattern`.
    string readString(int min_len, int max_len, const string& pattern = "") {
        assert(min_len <= max_len);
        string res = readOne();
        assert(min_len <= (int) res.size());
        assert((int) res.size() <= max_len);
        for (int i = 0; i < (int) res.size(); i++) {
            assert(pattern.empty() || pattern.find(res[i]) != string::npos);
        }
        return res;
    }
    // Reads one int token in [min_val, max_val]. The entire token must be the
    // number: stoi on its own parses only a prefix and would accept "12x" as 12.
    int readInt(int min_val, int max_val) {
        assert(min_val <= max_val);
        const string token = readOne();
        assert(!token.empty());
        size_t used = 0;
        int res = stoi(token, &used);
        assert(used == token.size());
        assert(min_val <= res);
        assert(res <= max_val);
        return res;
    }
    // Same as readInt, but for a long long token (stoll has the same prefix issue).
    long long readLong(long long min_val, long long max_val) {
        assert(min_val <= max_val);
        const string token = readOne();
        assert(!token.empty());
        size_t used = 0;
        long long res = stoll(token, &used);
        assert(used == token.size());
        assert(min_val <= res);
        assert(res <= max_val);
        return res;
    }
    // Reads `size` ints in [min_val, max_val], separated by single spaces, with no
    // trailing separator.
    vector<int> readInts(int size, int min_val, int max_val) {
        assert(min_val <= max_val);
        vector<int> res(size);
        for (int i = 0; i < size; i++) {
            res[i] = readInt(min_val, max_val);
            if (i != size - 1) {
                readSpace();
            }
        }
        return res;
    }
    // Reads `size` long longs in [min_val, max_val], separated by single spaces,
    // with no trailing separator.
    vector<long long> readLongs(int size, long long min_val, long long max_val) {
        assert(min_val <= max_val);
        vector<long long> res(size);
        for (int i = 0; i < size; i++) {
            res[i] = readLong(min_val, max_val);
            if (i != size - 1) {
                readSpace();
            }
        }
        return res;
    }
    // Asserts that the next character is ' ' and consumes it.
    void readSpace() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == ' ');
        pos++;
    }
    // Asserts that the next character is '\n' and consumes it.
    void readEoln() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == '\n');
        pos++;
    }
    // Asserts that the whole input has been consumed.
    void readEof() {
        assert((int) buffer.size() == pos);
    }
};
// Disjoint-set union over elements 0.._n-1 with path compression. There is no union
// by size/rank: unite() always hangs the first root under the second.
struct dsu {
    vector<int> p;   // p[v] == v exactly when v is a root
    vector<int> sz;  // sz[r] = number of elements in the set of root r (meaningful for roots only)
    int n;
    dsu(int _n) : p(_n), sz(_n, 1), n(_n) {
        iota(p.begin(), p.end(), 0);
    }
    // Returns the root of x's set and points every element on the walked path
    // directly at that root.
    inline int get(int x) {
        int top = x;
        while (p[top] != top) {
            top = p[top];
        }
        int cur = x;
        while (p[cur] != top) {
            const int up = p[cur];
            p[cur] = top;
            cur = up;
        }
        return top;
    }
    // Merges the sets of x and y. Returns false if they were already the same set.
    inline bool unite(int x, int y) {
        const int rootX = get(x);
        const int rootY = get(y);
        if (rootX == rootY) return false;
        p[rootX] = rootY;
        sz[rootY] += sz[rootX];
        return true;
    }
    inline bool same(int x, int y) { return get(x) == get(y); }
    inline int size(int x) { return sz[get(x)]; }
    inline bool root(int x) { return get(x) == x; }
};
// Tester entry point: strictly validates every test case with input_checker/dsu and
// prints Alice's final score for each one.
int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    input_checker in;
    int tt = in.readInt(1, 1000);
    in.readEoln();
    int sn = 0;  // sum of N over all test cases; checked after the loop
    while (tt--) {
        int n = in.readInt(1, 5000);
        sn += n;  // must be accumulated, otherwise assert(sn <= 5000) below can never fire
        in.readSpace();
        int a = in.readInt(1, n);
        in.readSpace();
        int b = in.readInt(1, n);
        in.readEoln();
        a--;
        b--;
        vector<long long> p = in.readLongs(n, 1, 1e9);
        in.readEoln();
        vector<vector<int>> g(n);
        dsu uf(n);
        for (int i = 0; i < n - 1; i++) {
            int x = in.readInt(1, n);
            in.readSpace();
            int y = in.readInt(1, n);
            in.readEoln();
            x--;
            y--;
            uf.unite(x, y);
            g[x].emplace_back(y);
            g[y].emplace_back(x);
        }
        // n - 1 edges that connect all n vertices form a tree.
        assert(uf.size(0) == n);
        // Dfs(va, vb, pa, pb, ca, cb): Alice is at va (came from pa), Bob is at vb (came
        // from pb), and ca/cb are their population sums including va/vb. Returns the points
        // Alice still earns after this state; the caller adds the point for the state itself.
        // If Bob has no move (s == 0), the game ends and that branch contributes 0.
        function<int(int, int, int, int, long long, long long)> Dfs = [&](int va, int vb, int pa, int pb, long long ca, long long cb) {
            int res = 0;
            for (int toa : g[va]) {
                if (toa == pa) {
                    continue;
                }
                long long da = ca + p[toa];
                int t = 1e9;
                int s = 0;
                for (int tob : g[vb]) {
                    if (tob == pb) {
                        continue;
                    }
                    s = 1;
                    long long db = cb + p[tob];
                    t = min(t, (da > db) + Dfs(toa, tob, va, vb, da, db));
                }
                res = max(res, t * s);
            }
            return res;
        };
        cout << (p[a] > p[b]) + Dfs(a, b, -1, -1, p[a], p[b]) << '\n';
    }
    assert(sn <= 5000);
    in.readEof();
    return 0;
}
Editorialist's code (Python)
def bfs(adj, par, pref, val, src):
    """Root the tree at ``src`` with a breadth-first traversal, filling ``par`` and ``pref`` in place.

    After the call, ``par[v]`` is the parent of ``v`` (with ``par[src] == -1``) and
    ``pref[v]`` is the sum of ``val`` over the ``src -> v`` path, both endpoints
    included. Nothing is returned.
    """
    par[src] = -1
    pref[src] = val[src]
    queue = [src]
    head = 0
    while head < len(queue):
        node = queue[head]
        head += 1
        for nxt in adj[node]:
            if nxt == par[node]:
                continue
            par[nxt] = node
            pref[nxt] = pref[node] + val[nxt]
            queue.append(nxt)
import sys

# go() below recurses once per step of the game, i.e. up to N (<= 5000) levels deep,
# which is more than CPython's default recursion limit of 1000.
sys.setrecursionlimit(20000)

for _ in range(int(input())):
    n, a, b = map(int, input().split())
    val = list(map(int, input().split()))
    adj = [[] for _ in range(n)]
    for i in range(n-1):
        x, y = map(int, input().split())
        adj[x-1].append(y-1)
        adj[y-1].append(x-1)
    # parX[v]: parent of v in the tree rooted at X's start;
    # prefX[v]: population sum on the path from X's start to v.
    parA, parB = [0]*n, [0]*n
    prefA, prefB = [0]*n, [0]*n
    bfs(adj, parA, prefA, val, a-1)
    bfs(adj, parB, prefB, val, b-1)

    def go(x, y):
        """Alice's best final score from state (x, y), including the point for (x, y) itself."""
        add = 1 if prefA[x] > prefB[y] else 0
        bob_moves = [v for v in adj[y] if v != parB[y]]
        if not bob_moves:
            return add  # Bob cannot move: the game ends here
        mx = 0  # stays 0 if Alice cannot move either
        for u in adj[x]:
            if u == parA[x]:
                continue
            mn = n + 1  # larger than any achievable score (at most n states are visited)
            for v in bob_moves:
                mn = min(mn, go(u, v))
            mx = max(mx, mn)
        return add + mx

    print(go(a-1, b-1))