TREEREMOVAL - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: etherinmatic
Tester: airths
Editorialist: iceknight1093

DIFFICULTY:

TBD

PREREQUISITES:

DFS/BFS

PROBLEM:

You’re given a tree on N vertices. Vertex i has value A_i.
Repeatedly perform the following process:

  • Choose a vertex u whose degree is odd, and add A_u to your score.
    Then, delete vertex u.

Find a sequence of operations that maximizes your score.

EXPLANATION:

All the A_i values are positive, so it’s in our best interest to delete as many vertices as possible.

It’s not possible to delete all N vertices: when only a single vertex remains, it’ll have degree 0 (which isn’t odd) and so can’t be deleted.
The next best option is to try and delete N-1 vertices, leaving only a single vertex remaining.
It turns out that this is always possible - in fact, we can choose any vertex u to be the last remaining vertex!

Proof

Fix the vertex u that must remain in the end.

It’s well-known that any tree with \geq 2 vertices has at least two leaves. (For a simple proof: the degrees sum to 2(N-1) and every degree is at least 1; if N-1 of them were \geq 2, the sum would be at least 2(N-1)+1, a contradiction.)
In particular, as long as at least two vertices remain, there will definitely exist a leaf that isn’t u.
Let this be v.

Note that v, being a leaf, has degree 1 (which is odd).
Delete v and repeat the process - since we’re deleting a leaf, the graph continues to remain a tree.
This process will continue till the tree has only a single vertex remaining, which by construction is u.

Since the final score equals the sum of all A_i minus the value of the single remaining vertex, it’s best to leave the vertex with the smallest A_i value and delete everything else.

A fairly simple way to implement this is as follows:

  • Find m, the vertex such that A_m = \min(A).
  • Root the tree at m, and perform a DFS or BFS traversal of the tree.
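  • Then, delete every vertex except m in reverse order of the traversal. In BFS order (or DFS preorder), every vertex appears before all of its descendants. So when a vertex is deleted, its whole subtree is already gone and it is a leaf, with degree 1. A DFS post-order, as in the author's code, works for the same reason.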

TIME COMPLEXITY:

\mathcal{O}(N) per testcase.

CODE:

Author's code (C++)
#include<bits/stdc++.h>
using namespace std;
#define int long long
#define sz(x) static_cast<int>((x).size())
#define all(x) begin(x), end(x)
const int mod = 1e9 + 7;
 
 
// Reads one test case (n, values A_1..A_n, then n-1 tree edges) and prints a
// sequence of n-1 deletions that removes every vertex except one of minimum
// value (ties: the last such index). Each printed vertex has all of its
// children (w.r.t. the kept root) already deleted, so it is a leaf, i.e. has
// odd degree 1, at the moment it is deleted.
void solve(){

    int n; std::cin >> n;

    std::vector<int> a(n);
    for (auto &x : a) std::cin >> x;

    // Adjacency list of the tree, 0-indexed.
    std::vector<std::vector<int>> g(n);
    for (int e = 0; e < n - 1; ++e){
        int u, v; std::cin >> u >> v;
        --u, --v;
        g[u].push_back(v);
        g[v].push_back(u);
    }

    // The vertex that survives: the LAST index holding the minimum value.
    const int mn = *std::min_element(a.begin(), a.end());
    int root = -1;
    for (int i = 0; i < n; ++i) {
        if (a[i] == mn) root = i;
    }

    std::cout << n - 1 << "\n";

    // Iterative post-order DFS from root. A recursive DFS would recurse to
    // depth n on a path graph (n up to 2e5) and can overflow the call stack.
    // next_edge[u] is the position in g[u] of the next neighbour to explore,
    // which reproduces exactly the visiting order of the recursive version.
    std::vector<int> parent(n, -1), next_edge(n, 0), todo;
    todo.reserve(n);
    todo.push_back(root);
    while (!todo.empty()) {
        const int u = todo.back();
        if (next_edge[u] < static_cast<int>(g[u].size())) {
            const int v = g[u][next_edge[u]++];
            if (v == parent[u]) continue;
            parent[v] = u;
            todo.push_back(v);
        } else {
            // Every child subtree of u is finished: u is now a leaf.
            todo.pop_back();
            if (u != root) std::cout << u + 1 << ' ';
        }
    }
    std::cout << "\n";

}
 
 
// Entry point: enable fast stream I/O, read the number of test cases, and
// run solve() once per test case.
signed main(){
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);

    int tests = 1;
    std::cin >> tests;
    while (tests-- != 0) {
        solve();
    }
}
Tester's code (C++)
/*
 * 
 * 	^v^
 * 
 */
#include <iostream>
#include <numeric>
#include <set>
#include <cctype>
#include <iomanip>
#include <chrono>
#include <queue>
#include <string>
#include <vector>
#include <functional>
#include <tuple>
#include <map>
#include <bitset>
#include <algorithm>
#include <array>
#include <random>
#include <cassert>

using namespace std;

// Shorthands: 64-bit signed integer and extended-precision floating point.
using ll = long long int;
using ld = long double;

// Fast I/O: stop syncing C++ streams with C stdio and untie cin from cout.
#define iamtefu ios_base::sync_with_stdio(false); cin.tie(0);

// Mersenne Twister seeded from the high-resolution clock.
// NOTE(review): `rng` is not referenced anywhere in the code shown here.
mt19937 rng(chrono::high_resolution_clock::now().time_since_epoch().count());

// Strict input validator: slurps all of standard input up front, then offers
// readers that assert the exact token format, value bounds and whitespace
// layout (single spaces, '\n' line ends, nothing after the last line).
struct input_checker {
	std::string buffer;  // the entire standard input
	int pos;             // index in `buffer` of the next unread character

	const std::string all = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
	const std::string number = "0123456789";
	const std::string upper = "ABCDEFGHIJKLMNOPQRSTUVWXYZ";
	const std::string lower = "abcdefghijklmnopqrstuvwxyz";

	// Reads every character of standard input into `buffer`.
	input_checker() {
		pos = 0;
		while (true) {
			int c = std::cin.get();
			if (c == -1) {
				break;
			}
			buffer.push_back(static_cast<char>(c));
		}
	}

	// Returns the index of the first whitespace character at or after `pos`,
	// or buffer.size() if there is none.
	int nextDelimiter() {
		int now = pos;
		// std::isspace requires a value representable as unsigned char (or
		// EOF); passing a negative char (e.g. a stray non-ASCII byte) is UB.
		while (now < (int) buffer.size() && !std::isspace(static_cast<unsigned char>(buffer[now]))) {
			now++;
		}
		return now;
	}

	// Consumes and returns the (possibly empty) run of non-whitespace
	// characters starting at `pos`.
	std::string readOne() {
		assert(pos < (int) buffer.size());
		int nxt = nextDelimiter();
		std::string res = buffer.substr(pos, nxt - pos);
		pos = nxt;
		return res;
	}

	// Reads a token whose length is in [minl, maxl] and whose characters all
	// belong to `pattern` (any character is allowed if `pattern` is empty).
	std::string readString(int minl, int maxl, const std::string &pattern = "") {
		assert(minl <= maxl);
		std::string res = readOne();
		assert(minl <= (int) res.size());
		assert((int) res.size() <= maxl);
		for (int i = 0; i < (int) res.size(); i++) {
			assert(pattern.empty() || pattern.find(res[i]) != std::string::npos);
		}
		return res;
	}

	// Reads a canonical decimal integer token and asserts minv <= value <= maxv.
	int readInt(int minv, int maxv) {
		assert(minv <= maxv);
		std::string tok = readOne();
		assert(isCanonicalInteger(tok));
		int res = std::stoi(tok);
		assert(minv <= res);
		assert(res <= maxv);
		return res;
	}

	// 64-bit counterpart of readInt.
	long long readLong(long long minv, long long maxv) {
		assert(minv <= maxv);
		std::string tok = readOne();
		assert(isCanonicalInteger(tok));
		long long res = std::stoll(tok);
		assert(minv <= res);
		assert(res <= maxv);
		return res;
	}

	// Reads n integers in [minv, maxv] separated by single spaces (no
	// trailing space is consumed).
	auto readInts(int n, int minv, int maxv) {
		assert(n >= 0);
		std::vector<int> v(n);
		for (int i = 0; i < n; ++i) {
			v[i] = readInt(minv, maxv);
			if (i+1 < n) readSpace();
		}
		return v;
	}

	// 64-bit counterpart of readInts.
	auto readLongs(int n, long long minv, long long maxv) {
		assert(n >= 0);
		std::vector<long long> v(n);
		for (int i = 0; i < n; ++i) {
			v[i] = readLong(minv, maxv);
			if (i+1 < n) readSpace();
		}
		return v;
	}

	// Consumes exactly one ' '.
	void readSpace() {
		assert((int) buffer.size() > pos);
		assert(buffer[pos] == ' ');
		pos++;
	}

	// Consumes exactly one '\n'.
	void readEoln() {
		assert((int) buffer.size() > pos);
		assert(buffer[pos] == '\n');
		pos++;
	}

	// Asserts that the whole input has been consumed.
	void readEof() {
		assert((int) buffer.size() == pos);
	}

private:
	// True iff `s` is a canonical decimal integer: an optional '-', then one
	// or more digits, with no leading zeros and no "-0". std::stoi/std::stoll
	// alone would silently accept malformed tokens such as "12abc" (parsed as
	// 12), "+5" or "007".
	static bool isCanonicalInteger(const std::string &s) {
		const std::size_t start = (!s.empty() && s[0] == '-') ? 1 : 0;
		if (start == s.size()) {
			return false;  // empty token or a lone '-'
		}
		for (std::size_t j = start; j < s.size(); j++) {
			if (!std::isdigit(static_cast<unsigned char>(s[j]))) {
				return false;
			}
		}
		// Only the plain token "0" may begin with the digit 0.
		if (s[start] == '0') {
			return s.size() == 1;
		}
		return true;
	}
};

void scn(){
	// not necessarily distinct
	// right down ytdm
	
	input_checker inp = input_checker();
	int t;
	t = inp.readInt(1, 10'000);
	inp.readEoln();
	int totn = 0;
	while (t--){
		ll n;
		n = inp.readInt(2, 200'000);
		totn+=n;
		inp.readEoln();
		vector <ll> a(n);
		int mn = 1e9, wh = 0;
		for (int i=0; i<n; i++){
			a[i] = inp.readInt(1, 1'000'000'000);
			if (i+1<n){
				inp.readSpace();
			}
			if (mn>a[i]){
				mn = a[i];
				wh = i+1;
			}
		}
		vector <ll> ads(n+1, 0), szds(n+1, 1);
		iota(ads.begin(), ads.end(), 0);
		auto pr=[&](ll i){
			while (ads[i]!=i){
				ads[i]=ads[ads[i]];
				i = ads[i];
			}
			return i;
		};
		auto un=[&](ll u, ll v){
			u = pr(u), v = pr(v);
			if (u!=v){
				if (szds[u]>szds[v]){
					szds[u]+=szds[v];
					ads[v] = u;
				} else {
					szds[v]+=szds[u];
					ads[u] = v;
				}
			}
		};
		inp.readEoln();
		vector <vector <int>> ed(n+1);
		vector <int> deg(n+1);
		for (int i=0; i+1<n; i++){
			int u, v;
			u = inp.readInt(1, n);
			inp.readSpace();
			v = inp.readInt(1, n);
			inp.readEoln();
			un(u, v);
			ed[u].push_back(v);
			ed[v].push_back(u);
			deg[u]++;
			deg[v]++;
		}
		assert(szds[pr(1)]==n);
		set <pair<int,int>> st;
		for (int i=1; i<=n; i++){
			st.insert({deg[i], i});
		}
		vector <int> ans;
		while (!st.empty()){
			auto [d, ind] = *st.begin();
			st.erase(st.begin());
			if (deg[ind]!=d || ind==wh){
				continue;
			}
			ans.push_back(ind);
			for (auto &x:ed[ind]){
				deg[x]--;
				if (deg[x]==1){
					st.insert({d, x});
				}
			}
		}
		cout<<ans.size()<<'\n';
		for (int i=0; i<ans.size(); i++){
			cout<<ans[i]<<" \n"[i+1==ans.size()];
		}
	}
	inp.readEof();
	assert(totn>=2 && totn<=200'000);
}

int main(){
	iamtefu;
#if defined(airths)
	// Local runs only: time the program and redirect I/O to files.
	const auto started = chrono::high_resolution_clock::now();
	freopen("input.txt", "r", stdin);
	freopen("output.txt", "w", stdout);
#endif
	// A single call validates and solves every test case in the input.
	scn();
#if defined(airths)
	const auto finished = chrono::high_resolution_clock::now();
	ld elapsed = chrono::duration_cast<chrono::nanoseconds>(finished - started).count();
	elapsed *= 1e-6; // nanoseconds -> milliseconds
	cerr << "Time: " << setprecision(12) << elapsed << "ms\n";
#endif
	return 0;
}
Editorialist's code (Python)
# For each test case: keep the (first) minimum-value vertex, BFS from it, and
# delete the other vertices in reverse BFS order. Every vertex then comes after
# all of its descendants, so it is a leaf (odd degree 1) when it is deleted.
for _ in range(int(input())):
    n = int(input())
    # Index 0 is a dummy sentinel larger than any A_i, so vertices stay 1-indexed.
    a = [10**9 + 7, *map(int, input().split())]
    adj = [[] for _ in range(n + 1)]
    for _ in range(n - 1):
        u, v = map(int, input().split())
        adj[u].append(v)
        adj[v].append(u)
    # min() returns the first minimal index, just like list.index(min(...)).
    root = min(range(len(a)), key=a.__getitem__)
    seen = [False] * (n + 1)
    seen[root] = True
    order = [root]
    head = 0
    while head < len(order):
        u = order[head]
        head += 1
        for v in adj[u]:
            if not seen[v]:
                seen[v] = True
                order.append(v)

    print(n - 1)
    # Everything except the root, last-discovered first.
    print(*order[:0:-1])

my approach:
We will take all values except lowest value.
Push the degree-1 nodes into a queue. For each popped node, visit its adjacent node and decrease that node's degree; if its degree becomes 1, push it into the queue, and repeat (always ignoring the lowest-value node).
But it gives TLE. Can anyone please help me understand why? According to my code, I am visiting each edge only twice.
My code

@iceknight1093 CodeChef users would be much happier if CodeChef looked into this issue at least once: CHEATING by "type_7_shady"

Channel - https://www.youtube.com/watch?v=zUOzgAUU7bs&ab_channel=Type7Shady

There have been 50+ blogs on this same topic here; why don’t you guys pay any attention?

For the last few months, the codes for every contest have been shown and shared live on this YouTube channel, completely destroying the worth of CodeChef contests. Will you please take strict action against this? Look at today’s contest: 1000+ people solved the last problem. All limits of cheating have been crossed.

2 Likes

Btw, do you see any way to stop this? I always report the channels, but it still continues.

2 Likes

The only solution is for the CodeChef authorities to formally complain about it to YouTube. They need to take it seriously if they want CodeChef to stay relevant; otherwise it is just a code-copying contest: who can copy and cheat the fastest.

2 Likes

Can we use a priority queue to greedily delete the highest-value odd-degree vertex, and keep adding odd-degree vertices as they appear? Since we only process odd-degree vertices, we just need to keep the condition size < n-1, because the lowest-value vertex might have been added to the priority queue at some point.
Also, I was wondering: in the 2nd test case we are given a forest, not a tree, and we don't know how many edges we will get as input, because the input format says n-1.

#include<bits/stdc++.h>
using namespace std;
typedef long long int lli;
// Use DEBUG to print variables
    //DEBUG(n);
    //DEBUG(a);

// DEBUG(a, b, ...) prints "(a, b, ...) = ( <a> , <b> , ... )" followed by a newline.
#define DEBUG(...) { cout << "(" << #__VA_ARGS__ << ")" << " = ( "; Print(__VA_ARGS__); }
// Base case: print the final value and close the parenthesised list.
template <typename Last> void Print(const Last& last) { cout << last << " )" << endl; }
// Recursive case: print one value plus a separator, then the remaining values.
template <typename First, typename... Rest>
void Print(const First& first, const Rest&... rest) { cout << first << " , "; Print(rest...); }
// Builds an undirected adjacency list over vertices 0..n-1: each edge {x, y}
// adds y to x's list and x to y's list, in edge order.
vector<vector<int>> adjlist(vector<pair<int,int>>& edges,int n)
{
    vector<vector<int>> neighbours(n);
    for (const pair<int,int>& e : edges)
    {
        neighbours[e.first].emplace_back(e.second);
        neighbours[e.second].emplace_back(e.first);
    }
    return neighbours;
}
// Reads one test case (n, values a_1..a_n, n-1 tree edges) and prints n-1
// vertices to delete, in order, so that every vertex except one of minimum
// value is removed. This maximises the score because all values are positive.
//
// The previous greedy (pop the highest-value odd-degree vertex) was incorrect:
//  - a vertex pushed while its degree was odd could be popped after its degree
//    had become even, which is an invalid move;
//  - it could delete the minimum vertex;
//  - it did not guarantee n-1 deletions.
// Here we only ever delete a current leaf (degree 1, odd) other than the kept
// vertex. The remaining graph is always a tree containing the kept vertex, and
// such a tree with >= 2 vertices always has another leaf, so exactly n-1
// vertices get deleted.
void solve()
{
    int n;
    std::cin >> n;
    std::vector<long long> a(n);
    for (int i = 0; i < n; i++) std::cin >> a[i];

    std::vector<std::vector<int>> adj(n);
    for (int i = 0; i < n - 1; i++)
    {
        int x, y;
        std::cin >> x >> y;
        x--;
        y--;
        adj[x].push_back(y);
        adj[y].push_back(x);
    }
    if (n <= 0)
    {
        std::cout << 0 << "\n\n";
        return;
    }

    // Vertex that is never deleted: the first one with minimum value.
    int keep = 0;
    for (int i = 1; i < n; i++)
        if (a[i] < a[keep]) keep = i;

    // degree[v] = number of not-yet-deleted neighbours of v.
    std::vector<int> degree(n);
    std::queue<int> leaves;
    for (int i = 0; i < n; i++)
    {
        degree[i] = static_cast<int>(adj[i].size());
        if (degree[i] == 1 && i != keep) leaves.push(i);
    }

    std::vector<char> removed(n, 0);
    std::vector<int> ans;
    ans.reserve(n);
    while (!leaves.empty())
    {
        const int node = leaves.front();
        leaves.pop();
        removed[node] = 1;
        ans.push_back(node + 1);
        for (int nb : adj[node])
        {
            if (removed[nb]) continue;
            // A degree only reaches 1 once, so each vertex is queued at most once.
            if (--degree[nb] == 1 && nb != keep) leaves.push(nb);
        }
    }

    // '\n' instead of endl: flushing after every test case is slow.
    std::cout << ans.size() << '\n';
    for (int v : ans) std::cout << v << ' ';
    std::cout << '\n';
}

// Entry point: read the number of test cases and solve each one.
int main()
{
    int testCases;
    cin >> testCases;
    while (testCases-- != 0)
        solve();
    return 0;
}

1 Like

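That does not work: it doesn’t guarantee that N-1 nodes will be printed, whereas the correct solution always prints N-1 nodes. It can also delete the lowest-value node. There is a second problem too: a vertex is pushed while its degree is odd, but its degree may become even before it is popped, and deleting it then is an invalid move.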

Can someone please tell me whether my approach is right or wrong?

In line 15 of your code:

      for(ll i=1;i<=n;i++){
      ll temp;cin>>temp;vc.push_back({temp,i});
      sort(vc.begin(),vc.end()); // Here
      }

You need to sort the vector after the loop completes:

      for(ll i=1;i<=n;i++){
      ll temp;cin>>temp;vc.push_back({temp,i});
      }
      sort(vc.begin(),vc.end()); // Here

Tip: ensure proper indentation throughout your code. This will help you identify blocks of code more quickly so you don’t spend time debugging where the issue is.

Accepted Submission: 1067659837

1 Like

Thank you so much!! Now I feel a little sad: I tried tons of ways to fix this, but always in the part of the code after the sorting.