PROBLEM LINK:
Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4
Author: etherinmatic
Tester: airths
Editorialist: iceknight1093
DIFFICULTY:
TBD
PREREQUISITES:
DFS/BFS
PROBLEM:
You’re given a tree on N vertices. Vertex i has value A_i.
Repeatedly perform the following process:
- Choose a vertex u whose degree is odd, and add A_u to your score.
Then, delete vertex u.
Find a sequence of operations that maximizes your score.
EXPLANATION:
All the A_i values are positive, so it’s in our best interest to delete as many vertices as possible.
It’s not possible to delete all N vertices: when only a single vertex remains, it’ll have degree 0 (which isn’t odd) and so can’t be deleted.
The next best option is to try and delete N-1 vertices, leaving only a single vertex remaining.
It turns out that this is always possible - in fact, we can choose any vertex u to be the last remaining vertex!
Proof
Fix the vertex u that must remain in the end.
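It’s well-known that any tree with \geq 2 vertices has at least two leaves. (For a simple proof, note that the degrees sum to 2(N-1) and every degree is \geq 1; if at most one vertex had degree 1, then at least N-1 of the degrees would be \geq 2 and the sum would be at least 2(N-1) + 1, a contradiction.)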
In particular, as long as at least two vertices remain, there will definitely exist a leaf that isn’t u.
Let this be v.
Note that v, being a leaf, has degree 1 (which is odd).
Delete v and repeat the process - since we’re deleting a leaf, the graph remains a tree.
This process will continue till the tree has only a single vertex remaining, which by construction is u.
To maximize the score, it’s clearly best to leave the vertex with smallest A_i value and delete everything else.
A fairly simple way to implement this is as follows:
- Find m, the vertex such that A_m = \min(A).
- Root the tree at m, and perform a DFS or BFS traversal of the tree.
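- Then, delete vertices in reverse order of the traversal. Every vertex is visited after its parent, so in reverse order each non-root vertex is deleted only after all of its children are gone - at that moment its only remaining neighbor is its parent, so its degree is 1 (odd).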
TIME COMPLEXITY:
\mathcal{O}(N) per testcase.
CODE:
Author's code (C++)
#include<bits/stdc++.h>
using namespace std;
// Shorthand: every `int` token after this line is rewritten to `long long`
// (which is why the entry point further down is spelled `signed main`).
#define int long long
// Container size converted to the (macro-widened) signed `int`.
#define sz(x) static_cast<int>((x).size())
// Expands to the begin/end pair of a container, for passing to algorithms.
#define all(x) begin(x), end(x)
// 1e9 + 7; not referenced anywhere in the visible code.
const int mod = 1e9 + 7;
// Solves one test case read from standard input.
//
// Input : n, then n values A_1..A_n, then n-1 edges "u v" (1-based).
// Output: a line with n-1, then the 1-based labels of the deleted vertices in
//         deletion order, each followed by a space, then a newline.
//
// The vertex that is kept is a minimum-value one (the last index on ties).
// All other vertices are emitted in post-order of a DFS rooted at the kept
// vertex, so each vertex is deleted after its whole subtree and is a leaf
// (degree 1) at that moment.
//
// The traversal uses an explicit stack instead of recursion: on a path-shaped
// tree the depth equals n, which can exhaust the call stack. The emitted
// order is exactly the recursive post-order.
void solve(){
    int n; std::cin >> n;
    std::vector<int> a(n);
    for (auto &x : a) std::cin >> x;

    std::vector<std::vector<int>> g(n);
    for (int e = 0; e < n - 1; ++e){
        int u, v; std::cin >> u >> v;
        --u, --v;
        g[u].push_back(v);
        g[v].push_back(u);
    }

    // `<=` keeps the last index among equal minima.
    int root = 0;
    for (int i = 1; i < n; ++i) {
        if (a[i] <= a[root]) root = i;
    }

    std::cout << n - 1 << "\n";

    std::vector<int> parent(n, -1);      // DFS parent; -1 for the root
    std::vector<int> next_child(n, 0);   // next adjacency slot to explore per vertex
    std::vector<int> stack;
    stack.reserve(n);
    stack.push_back(root);
    while (!stack.empty()){
        const int u = stack.back();
        if (next_child[u] < static_cast<int>(g[u].size())){
            const int v = g[u][next_child[u]++];
            if (v == parent[u]) continue;   // do not walk back up the tree
            parent[v] = u;
            stack.push_back(v);
        } else {
            // All children of u are finished: u is now a leaf of what remains.
            stack.pop_back();
            if (u != root) std::cout << u + 1 << ' ';
        }
    }
    std::cout << "\n";
}
// Entry point: switches the C++ streams to unsynchronised, untied mode, reads
// the number of test cases and runs solve() once per case.
signed main(){
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);

    int t = 1;
    std::cin >> t;
    for (; t != 0; --t) {
        solve();
    }
}
Tester's code (C++)
/*
*
* ^v^
*
*/
#include <iostream>
#include <numeric>
#include <set>
#include <cctype>
#include <iomanip>
#include <chrono>
#include <queue>
#include <string>
#include <vector>
#include <functional>
#include <tuple>
#include <map>
#include <bitset>
#include <algorithm>
#include <array>
#include <random>
#include <cassert>
using namespace std;
// Short aliases for the wide integer and floating-point types used below.
using ll = long long int;
using ld = long double;
// Expands to two statements: disable C-stdio synchronisation and untie cin.
#define iamtefu ios_base::sync_with_stdio(false); cin.tie(0);
// Mersenne Twister seeded from the high-resolution clock; not referenced in the visible code.
mt19937 rng(chrono::high_resolution_clock::now().time_since_epoch().count());
// Strict reader used to validate contest input.
//
// The constructor slurps all of standard input into `buffer`; the read*
// members then consume it token by token, asserting on every deviation
// (wrong separator, out-of-range value, malformed number, trailing data).
//
// Differences from a naive implementation:
//  * characters are converted to `unsigned char` before `std::isspace`, since
//    passing a negative `char` (any byte >= 0x80) is undefined behaviour;
//  * numeric tokens must be consumed completely, so "12abc" is rejected
//    instead of being silently read as 12.
struct input_checker {
    std::string buffer;  // entire contents of standard input
    int pos;             // index of the next unread character in `buffer`

    // Character classes callers may pass as `pattern` to readString().
    const std::string all = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
    const std::string number = "0123456789";
    const std::string upper = "ABCDEFGHIJKLMNOPQRSTUVWXYZ";
    const std::string lower = "abcdefghijklmnopqrstuvwxyz";

    // Reads standard input until end-of-file.
    input_checker() {
        pos = 0;
        while (true) {
            const int c = std::cin.get();
            if (c == std::char_traits<char>::eof()) {
                break;
            }
            buffer.push_back(static_cast<char>(c));
        }
    }

    // Returns the index of the first whitespace character at or after `pos`,
    // or buffer.size() when there is none.
    int nextDelimiter() {
        int now = pos;
        while (now < (int) buffer.size()
               && !std::isspace(static_cast<unsigned char>(buffer[now]))) {
            now++;
        }
        return now;
    }

    // Consumes and returns the maximal run of non-whitespace characters at
    // `pos` (empty when `pos` sits on whitespace). Asserts input is not exhausted.
    std::string readOne() {
        assert(pos < (int) buffer.size());
        const int nxt = nextDelimiter();
        std::string res;
        while (pos < nxt) {
            res += buffer[pos];
            pos++;
        }
        return res;
    }

    // Reads one token, asserting its length lies in [minl, maxl] and, when
    // `pattern` is non-empty, that every character occurs in `pattern`.
    std::string readString(int minl, int maxl, const std::string &pattern = "") {
        assert(minl <= maxl);
        std::string res = readOne();
        assert(minl <= (int) res.size());
        assert((int) res.size() <= maxl);
        for (int i = 0; i < (int) res.size(); i++) {
            assert(pattern.empty() || pattern.find(res[i]) != std::string::npos);
        }
        return res;
    }

    // Reads one token as an int in [minv, maxv]. std::stoi throws if the token
    // is not numeric or does not fit; trailing non-digit characters assert.
    int readInt(int minv, int maxv) {
        assert(minv <= maxv);
        const std::string token = readOne();
        std::size_t used = 0;
        const int res = std::stoi(token, &used);
        assert(used == token.size());
        assert(minv <= res);
        assert(res <= maxv);
        return res;
    }

    // 64-bit counterpart of readInt().
    long long readLong(long long minv, long long maxv) {
        assert(minv <= maxv);
        const std::string token = readOne();
        std::size_t used = 0;
        const long long res = std::stoll(token, &used);
        assert(used == token.size());
        assert(minv <= res);
        assert(res <= maxv);
        return res;
    }

    // Reads n ints in [minv, maxv] separated by exactly one space each
    // (no separator is consumed after the last value).
    auto readInts(int n, int minv, int maxv) {
        assert(n >= 0);
        std::vector<int> v(n);
        for (int i = 0; i < n; ++i) {
            v[i] = readInt(minv, maxv);
            if (i+1 < n) readSpace();
        }
        return v;
    }

    // 64-bit counterpart of readInts().
    auto readLongs(int n, long long minv, long long maxv) {
        assert(n >= 0);
        std::vector<long long> v(n);
        for (int i = 0; i < n; ++i) {
            v[i] = readLong(minv, maxv);
            if (i+1 < n) readSpace();
        }
        return v;
    }

    // Consumes exactly one ' ' character.
    void readSpace() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == ' ');
        pos++;
    }

    // Consumes exactly one '\n' character.
    void readEoln() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == '\n');
        pos++;
    }

    // Asserts that the whole input has been consumed.
    void readEof() {
        assert((int) buffer.size() == pos);
    }
};
// Validates the whole input file and prints an answer for every test case.
//
// Per test case it checks the format (n, the A_i line, n-1 edge lines), uses
// a union-find to assert that the edges connect all n vertices, and then
// peels leaves until only the protected vertex `wh` is left. `wh` is the
// 1-based index of the first minimum of A (values need not be distinct).
//
// The running minimum is seeded from the first element rather than from a
// 1e9 sentinel: with the sentinel and a strict comparison, an input where
// every A_i equals 1e9 never set `wh`, leaving no vertex protected.
void scn(){
    input_checker inp = input_checker();
    int t = inp.readInt(1, 10'000);
    inp.readEoln();
    int totn = 0;   // sum of n over all test cases, checked at the end
    while (t--){
        const int n = inp.readInt(2, 200'000);
        totn += n;
        inp.readEoln();

        std::vector<long long> a(n);
        long long mn = 0;
        int wh = 0;   // 0 means "no minimum chosen yet"
        for (int i = 0; i < n; i++){
            a[i] = inp.readInt(1, 1'000'000'000);
            if (i + 1 < n){
                inp.readSpace();
            }
            if (wh == 0 || a[i] < mn){
                mn = a[i];
                wh = i + 1;
            }
        }
        inp.readEoln();

        // Union-find over vertices 1..n (slot 0 unused), union by size.
        std::vector<long long> ads(n + 1, 0), szds(n + 1, 1);
        std::iota(ads.begin(), ads.end(), 0);
        auto pr = [&](long long i){
            while (ads[i] != i){
                ads[i] = ads[ads[i]];   // path halving
                i = ads[i];
            }
            return i;
        };
        auto un = [&](long long u, long long v){
            u = pr(u), v = pr(v);
            if (u != v){
                if (szds[u] > szds[v]){
                    szds[u] += szds[v];
                    ads[v] = u;
                } else {
                    szds[v] += szds[u];
                    ads[u] = v;
                }
            }
        };

        std::vector<std::vector<int>> ed(n + 1);
        std::vector<int> deg(n + 1);
        for (int i = 0; i + 1 < n; i++){
            const int u = inp.readInt(1, n);
            inp.readSpace();
            const int v = inp.readInt(1, n);
            inp.readEoln();
            un(u, v);
            ed[u].push_back(v);
            ed[v].push_back(u);
            deg[u]++;
            deg[v]++;
        }
        // n-1 edges forming a single component of size n: the graph is a tree.
        assert(szds[pr(1)] == n);

        // Candidates ordered by (degree, index). Entries are never updated in
        // place; an entry whose stored degree no longer matches is stale.
        std::set<std::pair<int,int>> st;
        for (int i = 1; i <= n; i++){
            st.insert({deg[i], i});
        }
        std::vector<int> ans;
        while (!st.empty()){
            const auto [d, ind] = *st.begin();
            st.erase(st.begin());
            if (deg[ind] != d || ind == wh){
                continue;   // stale entry, or the vertex that must survive
            }
            ans.push_back(ind);
            for (auto &x : ed[ind]){
                deg[x]--;
                if (deg[x] == 1){
                    // Key the entry by x's own current degree.
                    st.insert({deg[x], x});
                }
            }
        }

        std::cout << ans.size() << '\n';
        for (std::size_t i = 0; i < ans.size(); i++){
            std::cout << ans[i] << " \n"[i + 1 == ans.size()];
        }
    }
    inp.readEof();
    assert(totn >= 2 && totn <= 200'000);
}
// Entry point of the tester's program: configures the streams, runs scn()
// once, and - only when built with `airths` defined - redirects stdin/stdout
// to files and reports the elapsed wall time on stderr.
int main(){
    // Written out in full; same statements the `iamtefu` shorthand expands to.
    std::ios_base::sync_with_stdio(false);
    std::cin.tie(nullptr);

#ifdef airths
    const auto started = std::chrono::high_resolution_clock::now();
    freopen("input.txt", "r", stdin);
    freopen("output.txt", "w", stdout);
#endif

    scn();

#ifdef airths
    const auto finished = std::chrono::high_resolution_clock::now();
    // Nanosecond tick count, scaled to milliseconds.
    long double elapsed_ms =
        std::chrono::duration_cast<std::chrono::nanoseconds>(finished - started).count();
    elapsed_ms *= 1e-6;
    std::cerr << "Time: " << std::setprecision(12) << elapsed_ms;
    std::cerr << "ms\n";
#endif
    return 0;
}
Editorialist's code (Python)
def deletion_order(n, values, edges):
    """Return the vertices to delete, in order, leaving one minimum-value vertex.

    n      -- number of vertices, labelled 1..n
    values -- list of the n vertex values A_1..A_n
    edges  -- iterable of (u, v) pairs with 1-based labels

    The kept vertex is the first one holding the minimum value. The tree is
    traversed breadth-first from it; the reversed visiting order (without the
    root) is returned, so every vertex appears after all of its children.
    """
    # Sentinel at index 0 (larger than any value read by the tester's bounds)
    # makes list indices coincide with vertex labels.
    a = [10**9 + 7] + list(values)
    adj = [[] for _ in range(n + 1)]
    for u, v in edges:
        adj[u].append(v)
        adj[v].append(u)

    root = a.index(min(a))
    mark = [0] * (n + 1)
    mark[root] = 1
    que = [root]
    # The list doubles as the BFS queue: it grows while being iterated.
    for u in que:
        for v in adj[u]:
            if mark[v] == 0:
                mark[v] = 1
                que.append(v)
    return list(reversed(que[1:]))


def main():
    """Read all test cases from standard input and print each answer."""
    for _ in range(int(input())):
        n = int(input())
        values = list(map(int, input().split()))
        edges = [tuple(map(int, input().split())) for _ in range(n - 1)]
        print(n - 1)
        print(*deletion_order(n, values, edges))


if __name__ == "__main__":
    main()