PROBLEM LINK:
Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4
Author: am_aadvik
Tester: raysh07
Editorialist: iceknight1093
DIFFICULTY:
Easy
PREREQUISITES:
DFS, binary exponentiation
PROBLEM:
You’re given a tree with N vertices, where each vertex i has a value A_i that is either positive or negative (but never 0).
The score of a walk on the tree is the product of:
- 2^{A_i} for all vertices i with A_i \gt 0 that are visited at least once; and
- -2^{-A_i} for all vertices i with A_i \lt 0 that are visited at least once.
Find the maximum possible score of a walk on the tree.
EXPLANATION:
Since we’re only multiplying powers of 2 (and maybe -1), the score of any walk will also be some power of 2 in magnitude; though it may be positive or negative.
So, although the actual values can become very large, we can still compare two scores by looking only at their signs and their exponents of 2.
This means our goal is as follows:
- Ensure that the product is positive, and not negative.
- Among positive products, maximize the exponent of 2.
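A minimal C++ sketch of this representation follows. It assumes the values of the distinct visited vertices are already collected in a vector. The names `Score`, `scoreOf` and `betterScore` are illustrative and do not come from the original solutions. A score is stored as a (sign, exponent) pair and is never expanded into an actual power of 2.

```cpp
#include <cassert>
#include <cstdint>
#include <utility>
#include <vector>

// Illustrative representation: {sign, exponent} means sign * 2^exponent.
using Score = std::pair<int, std::int64_t>;

// Score of a walk, given the values A_i of the distinct vertices it visits.
Score scoreOf(const std::vector<std::int64_t>& visited) {
    int sign = 1;
    std::int64_t exponent = 0;
    for (std::int64_t x : visited) {
        if (x < 0) sign = -sign;       // a factor -2^{|A_i|} flips the sign
        exponent += (x < 0 ? -x : x);  // and contributes |A_i| to the exponent
    }
    return {sign, exponent};
}

// True if score s is strictly larger than score t.
bool betterScore(const Score& s, const Score& t) {
    if (s.first != t.first) return s.first > t.first;  // positive beats negative
    // Same sign: a larger exponent is better for positives, worse for negatives.
    return s.first > 0 ? s.second > t.second : s.second < t.second;
}
```

For example, visiting values 3, -2 and -1 gives sign +1 and exponent 6, matching 8 \cdot (-4) \cdot (-2) = 64 = 2^6.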
Note that the final answer can never be negative. If some A_i is positive, the length-1 walk consisting of just that vertex has a positive score. If all values are negative, then N \ge 2 guarantees that any walk along a single edge visits exactly two negative-valued vertices and so has a positive score.
Let’s first focus on making the product positive.
The product of several non-zero numbers is positive if and only if an even number of them are negative.
So, we’re only interested in walks that touch an even number of negative values.
At the same time, since every |A_i| \ge 1, each newly visited vertex at least doubles the magnitude of the product (ignoring the sign), so we also want to visit as many vertices as possible.
In particular, if the number of negative elements in the entire array A is even, it’s optimal to simply visit every vertex of the tree.
This way, we multiply absolutely everything together, so the final exponent is
|A_1| + |A_2| + |A_3| + \ldots + |A_N|.
Further, the resulting value is positive; so this is clearly optimal.
That leaves us with the case where the number of negative elements is odd.
Here, visiting all vertices would make the product negative, so we need to decide which vertices to skip.
In particular, to obtain a positive product in the end, we need to skip visiting an odd number of negative-valued vertices.
Ideally, we’d skip exactly one negative-valued vertex.
However, it isn’t obvious that this is always possible: because we have a tree, if we never visit some vertex v, we also can’t visit anything on the other ‘side’ of v, and that ‘other side’ may contain both positive and negative values (so there’s no guarantee that the product ends up positive either).
Luckily for us, it can be proved that there always exists an optimal solution that leaves exactly one negative-valued vertex unvisited.
Proof
Suppose we skip v, and this also prevents us from reaching some other negative-valued vertices.
Then there exists an unvisited negative-valued vertex w such that the path from v to w contains no other negative-valued vertex.
We can then extend our walk to visit both v and w (via the path between them). This adds exactly two negative factors plus some positive ones, so it strictly increases the exponent while keeping the product positive.
So, an optimal solution must skip exactly one negative-valued vertex.
(This is technically not a formal proof since we didn’t talk about reaching v from the walk, but it’s easy enough to extend it to a formal proof by taking an optimal walk, and then assuming there are \ge 2 unvisited negative-valued vertices, taking a couple of the closest ones to the walk, and doing a bit of casework on their orientation with respect to the walk. The details are left as an exercise.)
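On small trees, the claim can be sanity-checked by brute force. The sketch below relies on two facts: the set of vertices visited by a walk is always connected, and every connected set can be covered by some walk. It therefore enumerates connected vertex subsets that contain an even number of negative values. This is hypothetical test code (0-indexed vertices, adjacency list `adj`, N \lt 32, illustrative name `bruteForceExponent`), not part of the original solutions.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Brute force for tiny trees: best exponent over all connected vertex sets
// containing an even number of negative values (-1 if there is none).
std::int64_t bruteForceExponent(const std::vector<std::int64_t>& a,
                                const std::vector<std::vector<int>>& adj) {
    const int n = static_cast<int>(a.size());
    assert(n >= 1 && n < 32);  // subsets are stored in a 32-bit mask
    std::int64_t best = -1;
    for (std::uint32_t mask = 1; mask < (1u << n); ++mask) {
        // Flood-fill from the lowest chosen vertex, staying inside the mask.
        int start = 0;
        while (!((mask >> start) & 1u)) ++start;
        std::uint32_t seen = 1u << start;
        std::vector<int> stack = {start};
        while (!stack.empty()) {
            int u = stack.back();
            stack.pop_back();
            for (int v : adj[u]) {
                if (((mask >> v) & 1u) && !((seen >> v) & 1u)) {
                    seen |= 1u << v;
                    stack.push_back(v);
                }
            }
        }
        if (seen != mask) continue;  // the chosen set is not connected

        int negatives = 0;
        std::int64_t exponent = 0;
        for (int i = 0; i < n; ++i) {
            if ((mask >> i) & 1u) {
                if (a[i] < 0) ++negatives;
                exponent += (a[i] < 0 ? -a[i] : a[i]);
            }
        }
        if (negatives % 2 == 0 && exponent > best) best = exponent;
    }
    return best;
}
```

Comparing this against a fast solution on random small trees is a cheap way to gain confidence in the "skip exactly one negative-valued vertex" argument.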
With this in mind, let’s look at some vertex v such that A_v \lt 0.
If we don’t visit v, what’s the best we can do?
Well, if we delete v from the tree, several smaller components will form (each of which is itself a tree).
Since we’re not visiting v, the walk must start in one of these components and can never leave it.
We already showed that an optimal solution skips exactly one negative-valued vertex; here that vertex is v, so the walk should cover the entire component it starts in.
So, the best component to start in is simply the one with the largest sum of |A_i|.
(Technically, a component must contain an even number of negative values for its product to be positive, but we don’t need to check this. If a component has an odd number of negative values, extending the walk to also visit v gives an even count and a strictly larger exponent, so such a component can never beat the optimal solution, which is itself an even-count component for some choice of v.)
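As a reference point, evaluating a single removed vertex v directly takes O(N) time, because each neighbour of v roots exactly one component of the tree with v deleted. A sketch, assuming 0-indexed vertices and an illustrative helper name `bestComponentWithout`. Running it for every negative-valued v is correct but quadratic overall, which is what the rest of this section avoids.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Largest sum of |A_i| over the components left after deleting vertex v.
// Parity is deliberately ignored, as argued above.
std::int64_t bestComponentWithout(int v, const std::vector<std::int64_t>& a,
                                  const std::vector<std::vector<int>>& adj) {
    std::int64_t best = 0;
    std::vector<char> seen(a.size(), 0);
    seen[v] = 1;  // never walk through the deleted vertex
    for (int start : adj[v]) {
        // Each neighbour of v starts a separate component.
        std::int64_t sum = 0;
        std::vector<int> stack = {start};
        seen[start] = 1;
        while (!stack.empty()) {
            int u = stack.back();
            stack.pop_back();
            sum += (a[u] < 0 ? -a[u] : a[u]);
            for (int w : adj[u]) {
                if (!seen[w]) {
                    seen[w] = 1;
                    stack.push_back(w);
                }
            }
        }
        if (sum > best) best = sum;
    }
    return best;
}
```

On a star with center value -1 and leaf values 3, -4, -2, the three negative-valued vertices give 4, 6 and 8, so the answer there is 2^8.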
The question now is how to compute the above quickly, for a removed vertex v.
To do that, observe that if we root the tree at some vertex (say 1), then upon deleting v the components formed are exactly:
- The entire subtrees of each of the children of v, and
- The entire tree minus the subtree of v.
So, if we define dp_u to be the sum of absolute values within the subtree of u (which is easy to compute for all vertices with a simple DFS), the answer when removing a vertex v is the maximum of:
- dp_c for each child c of v, and
- dp_1 - dp_v, corresponding to taking everything except the subtree of v.
Taking the maximum of these answers over all v with A_v \lt 0 gives the optimal exponent. Every vertex except the root is a child of exactly one other vertex, so the total work is linear.
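Putting the pieces together, here is a sketch of the linear-time computation. It is 0-indexed and rooted at vertex 0 rather than 1, and the function name `optimalExponent` is illustrative. It uses an explicit stack instead of recursion, a choice made here to avoid stack overflow on deep, path-like trees; the tester’s code below uses a recursive lambda instead.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Optimal exponent of 2 for the whole problem (0-indexed tree, root 0).
std::int64_t optimalExponent(const std::vector<std::int64_t>& a,
                             const std::vector<std::vector<int>>& adj) {
    assert(!a.empty());
    const int n = static_cast<int>(a.size());
    int negatives = 0;
    std::int64_t total = 0;
    for (std::int64_t x : a) {
        if (x < 0) ++negatives;
        total += (x < 0 ? -x : x);
    }
    if (negatives % 2 == 0) return total;  // even count: visit everything

    // Iterative DFS; every vertex appears in `order` after its parent.
    std::vector<int> parent(n, -1), order;
    order.reserve(n);
    std::vector<int> stack = {0};
    while (!stack.empty()) {
        int u = stack.back();
        stack.pop_back();
        order.push_back(u);
        for (int v : adj[u]) {
            if (v != parent[u]) {
                parent[v] = u;
                stack.push_back(v);
            }
        }
    }

    // dp[u] = sum of |A_i| over the subtree of u (children finish first).
    std::vector<std::int64_t> dp(n, 0);
    for (int i = n - 1; i >= 0; --i) {
        int u = order[i];
        dp[u] += (a[u] < 0 ? -a[u] : a[u]);
        if (parent[u] != -1) dp[parent[u]] += dp[u];
    }

    std::int64_t best = 0;
    for (int v = 0; v < n; ++v) {
        if (a[v] >= 0) continue;            // only negative vertices are skipped
        std::int64_t cand = dp[0] - dp[v];  // everything outside v's subtree
        for (int c : adj[v]) {
            if (c != parent[v] && dp[c] > cand) cand = dp[c];  // child subtrees
        }
        if (cand > best) best = cand;
    }
    return best;
}
```

On small random trees, this should agree with an exhaustive search over connected vertex sets.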
Once the optimal exponent M is known, we want to compute the actual answer 2^M \pmod{998244353}.
This can be done in \mathcal{O}(\log M) time using binary exponentiation.
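A sketch of binary exponentiation for this step, assuming the exponent fits in an unsigned 64-bit integer. The helper name `pow2Mod` is illustrative. Every intermediate product is below 998244353^2 \lt 2^{60}, so 64-bit arithmetic does not overflow.

```cpp
#include <cassert>
#include <cstdint>

// 2^m mod 998244353 using O(log m) multiplications.
std::uint64_t pow2Mod(std::uint64_t m) {
    const std::uint64_t mod = 998244353;
    std::uint64_t result = 1;
    std::uint64_t base = 2;  // after k steps, base = 2^(2^k) mod 998244353
    while (m > 0) {
        if (m & 1) result = result * base % mod;  // current bit of m is set
        base = base * base % mod;
        m >>= 1;
    }
    return result;
}
```

For example, 2^{30} = 1073741824 reduces to 75497471 modulo 998244353.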
TIME COMPLEXITY:
\mathcal{O}(N + \log(N\cdot\max|A|)) per testcase.
CODE:
Tester's code (C++)
#include <bits/stdc++.h>
using namespace std;
#define int long long
#define INF (int)1e18
mt19937_64 RNG(chrono::steady_clock::now().time_since_epoch().count());

void Solve()
{
    int n; cin >> n;
    vector <int> a(n + 1);
    for (int i = 1; i <= n; i++){
        cin >> a[i];
    }
    vector<vector<int>> adj(n + 1);
    for (int i = 1; i < n; i++){
        int u, v; cin >> u >> v;
        adj[u].push_back(v);
        adj[v].push_back(u);
    }

    // par = parity of the number of negative values
    int par = 0;
    for (int i = 1; i <= n; i++){
        par ^= a[i] < 0;
    }
    int best = 0;
    int sum = 0;
    for (int i = 1; i <= n; i++){
        sum += abs(a[i]);
    }
    if (par == 0){
        // even number of negatives: visit every vertex
        best = sum;
    } else {
        // b[u] = sum of |a| over the subtree of u, c[u] = parity of negatives in it.
        // For u != root, subtree(u) and the rest of the tree are the two sides of
        // the edge from u to its parent; keep whichever side has an even count.
        vector <int> b(n + 1), c(n + 1);
        auto dfs = [&](auto self, int u, int par) -> void{
            b[u] += abs(a[u]);
            c[u] ^= a[u] < 0;
            for (int v : adj[u]){
                if (v != par){
                    self(self, v, u);
                    b[u] += b[v];
                    c[u] ^= c[v];
                }
            }
            if (c[u] == 0){
                best = max(best, b[u]);
            }
            if (c[u] == 1){
                best = max(best, sum - b[u]);
            }
        };
        dfs(dfs, 1, -1);
    }

    // binary exponentiation over the low 60 bits of best: ans = 2^best mod 998244353
    const int mod = 998244353;
    int ans = 1, x = 2;
    for (int i = 0; i < 60; i++){
        if (best >> i & 1){
            ans *= x;
            ans %= mod;
        }
        x *= x;
        x %= mod;
    }
    cout << ans << "\n";
}

int32_t main()
{
    auto begin = std::chrono::high_resolution_clock::now();
    ios_base::sync_with_stdio(0);
    cin.tie(0);
    int t = 1;
    // freopen("in", "r", stdin);
    // freopen("out", "w", stdout);
    cin >> t;
    for(int i = 1; i <= t; i++)
    {
        //cout << "Case #" << i << ": ";
        Solve();
    }
    auto end = std::chrono::high_resolution_clock::now();
    auto elapsed = std::chrono::duration_cast<std::chrono::nanoseconds>(end - begin);
    cerr << "Time measured: " << elapsed.count() * 1e-9 << " seconds.\n";
    return 0;
}