FIREWORKS - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: wuhudsm
Tester: apoorv_me
Editorialist: iceknight1093

DIFFICULTY:

TBD

PREREQUISITES:

DFS

PROBLEM:

You have a tree on N vertices.
An x-rooted fireworks is a connected subset of the tree containing x, such that:

  • Every vertex other than x in the subset has at most two neighbors in the subset.
  • Every vertex other than x in the subset has a degree (in the original tree) different from the degree of x.

For each x, find the maximum size of an x-rooted fireworks.

EXPLANATION:

Let’s first try to solve the task for a single vertex x.
An x-rooted fireworks cannot include any other vertex whose degree equals the degree of x, so as a first step we can simply delete all such vertices (except x itself) from the tree.
This leaves us with a smaller tree containing x.

In this smaller tree, we need to now choose a connected set containing x, such that every vertex other than x has \leq 2 of its neighbors chosen.
It’s not hard to see that such a set looks like a star centered at x whose arms are paths - that is, a collection of paths that start at x and share no vertex other than x.

Computing the maximum size of such a set is not too hard in linear time.
For instance, one method is as follows:

  • First, root the (smaller) tree at x.
  • Now, let \text{mx}_u denote the maximum length of a path (where length = number of vertices) that starts at u and goes into its subtree.
    It’s easy to see that \text{mx}_u = 1 + \max(\text{mx}_v) across all children v of u (taking the maximum over an empty set as 0, so a leaf has \text{mx}_u = 1), since a path starting at u can only extend into one child.
  • Finally, the maximum size of an x-rooted fireworks equals the sum of \text{mx}_u across all children of x, plus 1 (to include x itself).

This is easily implemented in linear time with a simple DFS.
Of course, running this separately for every vertex takes \mathcal{O}(N^2) time in total, which is much too slow.


To optimize this idea, we’ll need to parallelize some computations.
Specifically, notice that when solving for two vertices u and v with the same degree, we remove almost the same set of vertices from the tree - the only difference is that when processing u we remove v, and vice versa.
Another way to look at this is that we remove every vertex of the given degree, and then add back only the one we’re currently dealing with.

So, that’s exactly what we’ll do!
Fix a degree d, and delete all vertices with degree d from the tree.
This breaks the tree into a bunch of smaller trees.
Now, when processing a vertex u with degree d, we want to do the following:

  • Let v be a neighbor of u (whose degree isn’t d).
    v will lie in one of the smaller trees.
    Note that different neighbors of u will lie inside different small trees, so they’re essentially all independent.
  • We now want to find the longest path starting at v and lying fully within this smaller tree.

The latter is a rather standard task (CSES Tree Distances 1), and has a variety of solutions.
One solution is to find the diameter of the tree, and then note that the longest distance from any vertex must be one of the diameter endpoints, so you can precompute distances from the diameter endpoints.
Alternately, tree DP along with rerooting (tutorial) will get you the answer too.

Note that for a fixed degree d, we can solve for all vertices with that degree in \mathcal{O}(N) time total - precomputation of maximum distances for each smaller tree takes \mathcal{O}(N) time overall (since the sum of their sizes is \lt N), and then finding the answer for each vertex with degree d takes \mathcal{O}(N) time too, since we only look at their neighbors (so each edge is looked at twice at most).

While this seems like it’s \mathcal{O}(N^2), it in fact isn’t!
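Note that the sum of the degrees of all vertices of the tree equals 2\cdot (N-1). Suppose there are k distinct degrees. For N \geq 2 every degree is at least 1, and each distinct degree is held by at least one vertex. So the degree sum is at least 1+2+\ldots+k = \frac{k\cdot (k+1)}{2}. Hence \frac{k^2}{2} \lt 2N, i.e. there can be at most \sqrt{4N} = 2\sqrt N distinct degrees among the vertices.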

So, by just skipping degrees that don’t have any vertices, the complexity is \mathcal{O}(N\sqrt N), which is fast enough.

TIME COMPLEXITY:

\mathcal{O}(N\sqrt N) per testcase.

CODE:

Author's code (C++)
#include <map>
#include <set>
#include <cmath>
#include <ctime>
#include <queue>
#include <stack>
#include <cstdio>
#include <cstdlib>
#include <vector>
#include <cstring>
#include <algorithm>
#include <iostream>
using namespace std;
// Short type aliases used by the author's solution.
using db  = double;
using ll  = long long;
using ull = unsigned long long;

// Capacity of the per-vertex arrays (vertices are 1-indexed).
const int N    = 1000010;
// Template constants; not referenced anywhere in this solution.
const int LOGN = 28;
const ll  TMD  = 0;
const ll  INF  = 2147483647;

// Number of vertices in the current test case.
int n;
// Per-vertex arrays, indexed by vertex id:
//   mx1, mx2 - largest / second-largest downward path length (in vertices), filled by DFS
//   dep      - depth below vertex 1, filled by DFS
//   deg      - vertex degrees, sorted in main to enumerate distinct degrees
//   ans      - answer for each vertex, filled by DFS2
int mx1[N], mx2[N], dep[N], deg[N], ans[N];
// Adjacency lists of the tree.
std::vector<int> G[N];

// Post-order traversal from x (parent pr, depth d) over the whole tree.
// Records dep[x] and, in mx1[x] / mx2[x], the largest and second-largest
// lengths (counted in vertices, x included) of downward paths starting at x.
// A vertex whose degree equals `key` is treated as deleted: after its subtree
// has been processed, its mx1/mx2 are forced to 0, so it adds nothing to its
// parent. Its own subtree is still visited, so the vertices below it get values too.
void DFS(int x,int pr,int d,int key)
{
	dep[x] = d;

	int best = 1;    // x alone
	int second = 1;
	for (int y : G[x])
	{
		if (y == pr)
			continue;
		DFS(y, x, d + 1, key);

		const int viaChild = mx1[y] + 1;
		if (viaChild > best)
		{
			second = best;
			best = viaChild;
		}
		else if (viaChild > second)
		{
			second = viaChild;
		}
	}

	const bool deleted = static_cast<int>(G[x].size()) == key;
	mx1[x] = deleted ? 0 : best;
	mx2[x] = deleted ? 0 : second;
}

// Pre-order pass that fills ans[x] for every vertex x whose degree equals `key`.
// Requires DFS(1,0,0,key) to have run first (it reads mx1, mx2 and dep).
// `mxv` is handed down from the parent.
//   - A degree-`key` vertex passes -dep[x] to its children.
//   - Any other vertex passes the maximum of the value it received and
//     (longest downward path from x avoiding the child's branch) - dep[x].
// For a degree-`key` vertex, ans[x] is the sum of mx1 over its children plus dep[x] + mxv.
void DFS2(int x,int pr,int mxv,int key)
{
	const bool deleted = static_cast<int>(G[x].size()) == key;

	int childTotal = 0;
	for (int y : G[x])
	{
		if (y == pr)
			continue;
		childTotal += mx1[y];

		int handDown;
		if (deleted)
		{
			handDown = -dep[x];
		}
		else
		{
			// If y's branch realises x's best path, fall back to the runner-up.
			const int sideBranch = (mx1[y] + 1 == mx1[x]) ? mx2[x] : mx1[x];
			handDown = max(mxv, sideBranch - dep[x]);
		}
		DFS2(y, x, handDown, key);
	}

	if (deleted)
		ans[x] = childTotal + dep[x] + mxv;
}

// Reads T test cases, each a tree on n vertices, and prints the maximum
// x-rooted fireworks size for every vertex x = 1..n.
int main()
{
	int T;
	scanf("%d",&T);

	while(T--)
	{
		scanf("%d",&n);
		for(int i=1;i<=n;i++) G[i].clear();
		for(int i=1;i<n;i++)
		{
			int u,v;
			scanf("%d%d",&u,&v);
			G[u].push_back(v);
			G[v].push_back(u);
		}

		// Sort the degrees so that each distinct degree is processed exactly once.
		// Each pass (DFS + DFS2) fills ans[] for all vertices of that degree.
		for(int i=1;i<=n;i++) deg[i]=G[i].size();
		sort(deg+1,deg+n+1);
		for(int i=1;i<=n;i++)
		{
			// Check i==1 explicitly instead of comparing against deg[0].
			// deg[0] is never written and stays 0, so degree 0 (the lone vertex
			// when n==1) would be skipped. ans[1] would then be printed
			// unset (0, or a stale value from an earlier test case) instead of 1.
			if(i==1 || deg[i]!=deg[i-1])
			{
				DFS(1,0,0,deg[i]);
				DFS2(1,0,1,deg[i]);
			}
		}

		for(int i=1;i<=n;i++) printf("%d%c",ans[i],i==n?'\n':' ');
	}

	return 0;
}

Tester's code (C++)
#include<bits/stdc++.h>
using namespace std;

// Strict input validator. The whole input is held in `buffer` and consumed
// left to right through `pos`. Every violation fails an assert().
struct input_checker {
    std::string buffer;
    int pos;

    const std::string all = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
    const std::string number = "0123456789";
    const std::string upper = "ABCDEFGHIJKLMNOPQRSTUVWXYZ";
    const std::string lower = "abcdefghijklmnopqrstuvwxyz";

    // Reads all of standard input into `buffer`.
    input_checker() {
        pos = 0;
        while (true) {
            int c = std::cin.get();
            if (c == std::char_traits<char>::eof()) {
                break;
            }
            buffer.push_back((char) c);
        }
    }

    // Validates an in-memory string instead of standard input (e.g. in tests).
    explicit input_checker(std::string data) : buffer(std::move(data)), pos(0) {}

    // True iff `s` is an optional '-' followed by one or more decimal digits and
    // nothing else. std::stoi / std::stoll on their own silently accept inputs
    // such as "12abc", "+5" or a tab-prefixed number, which a strict validator must reject.
    static bool isIntegerToken(const std::string &s) {
        std::size_t i = (!s.empty() && s[0] == '-') ? 1 : 0;
        if (i == s.size()) {
            return false;
        }
        for (; i < s.size(); ++i) {
            if (s[i] < '0' || s[i] > '9') {
                return false;
            }
        }
        return true;
    }

    // Index of the first ' ' or '\n' at or after `pos` (or buffer.size()).
    int nextDelimiter() {
        int now = pos;
        while (now < (int) buffer.size() && buffer[now] != ' ' && buffer[now] != '\n') {
            now++;
        }
        return now;
    }

    // Consumes and returns the token starting at `pos` (possibly empty).
    std::string readOne() {
        assert(pos < (int) buffer.size());
        int nxt = nextDelimiter();
        std::string res;
        while (pos < nxt) {
            res += buffer[pos];
            pos++;
        }
        return res;
    }

    // Reads a token of length in [minl, maxl] whose characters all belong to
    // `pattern` (any character is allowed when `pattern` is empty).
    std::string readString(int minl, int maxl, const std::string &pattern = "") {
        assert(minl <= maxl);
        std::string res = readOne();
        assert(minl <= (int) res.size());
        assert((int) res.size() <= maxl);
        for (int i = 0; i < (int) res.size(); i++) {
            assert(pattern.empty() || pattern.find(res[i]) != std::string::npos);
        }
        return res;
    }

    // Reads a well-formed integer token in [minv, maxv].
    int readInt(int minv, int maxv) {
        assert(minv <= maxv);
        const std::string token = readOne();
        assert(isIntegerToken(token));
        int res = std::stoi(token);
        assert(minv <= res);
        assert(res <= maxv);
        return res;
    }

    // Reads a well-formed 64-bit integer token in [minv, maxv].
    long long readLong(long long minv, long long maxv) {
        assert(minv <= maxv);
        const std::string token = readOne();
        assert(isIntegerToken(token));
        long long res = std::stoll(token);
        assert(minv <= res);
        assert(res <= maxv);
        return res;
    }

    // Reads n space-separated ints in [minv, maxv] (no trailing delimiter consumed).
    std::vector<int> readInts(int n, int minv, int maxv) {
        assert(n >= 0);
        std::vector<int> v(n);
        for (int i = 0; i < n; ++i) {
            v[i] = readInt(minv, maxv);
            if (i+1 < n) readSpace();
        }
        return v;
    }

    // Reads n space-separated long longs in [minv, maxv].
    std::vector<long long> readLongs(int n, long long minv, long long maxv) {
        assert(n >= 0);
        std::vector<long long> v(n);
        for (int i = 0; i < n; ++i) {
            v[i] = readLong(minv, maxv);
            if (i+1 < n) readSpace();
        }
        return v;
    }

    // Consumes exactly one ' '.
    void readSpace() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == ' ');
        pos++;
    }

    // Consumes exactly one '\n'.
    void readEoln() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == '\n');
        pos++;
    }

    // Asserts that the whole input has been consumed.
    void readEof() {
        assert((int) buffer.size() == pos);
    }
};

// Debug printing helpers. When LOCAL is defined, dbg(a, b, ...) writes
// "[a, b, ...] = [<value of a>, <value of b>, ...]" to stderr; otherwise
// dbg(...) expands to nothing.
// Notes:
//   - `x...` is the GCC/Clang named-variadic-macro extension (standard C++ uses __VA_ARGS__).
//   - The LOCAL expansion is two statements, so `if (c) dbg(v);` only guards the
//     first one. NOTE(review): wrap in braces when used after if/else.
#ifdef LOCAL
#define dbg(x...) cerr << "[" << #x << "] = ["; _print(x)
#else
#define dbg(...)
#endif
 
// Scalar overloads. Characters and strings are printed with quotes, bools as true/false.
// NOTE(review): there is no overload for plain `long long` or `short`. Types without a
// dedicated overload fall through to the generic container overload below, which
// requires begin()/end(). Confirm before using dbg on such types (on platforms where
// int64_t is `long`, a `long long` argument is affected).
void __print(int32_t x) {cerr << x;}
void __print(int64_t x) {cerr << x;}
void __print(unsigned x) {cerr << x;}
void __print(unsigned long x) {cerr << x;}
void __print(unsigned long long x) {cerr << x;}
void __print(float x) {cerr << x;}
void __print(double x) {cerr << x;}
void __print(long double x) {cerr << x;}
void __print(char x) {cerr << '\'' << x << '\'';}
void __print(const char *x) {cerr << '\"' << x << '\"';}
void __print(string x) {cerr << '\"' << x << '\"';}
void __print(bool x) {cerr << (x ? "true" : "false");}
// complex<T> is printed as {real,imag}.
template<typename T>void __print(complex<T> x) {cerr << '{'; __print(x.real()); cerr << ','; __print(x.imag()); cerr << '}';}
 
// Forward declaration of the generic container overload. It is placed before the
// pair overload so that a pair holding a container can call it by name.
template<typename T>
void __print(const T &x);
// pair<T, V> is printed as {first,second}.
template<typename T, typename V>
void __print(const pair<T, V> &x) {cerr << '{'; __print(x.first); cerr << ','; __print(x.second); cerr << '}';}
// Generic overload: any type with begin()/end() is printed as {e1,e2,...}.
template<typename T>
void __print(const T &x) {int f = 0; cerr << '{'; for (auto it = x.begin() ; it != x.end() ; it++) cerr << (f++ ? "," : ""), __print(*it); cerr << "}";}
// Recursion terminator: closes the value list opened by dbg.
void _print() {cerr << "]\n";}
// Prints each argument via __print, separated by ", ".
template <typename T, typename... V>
void _print(T t, V... v) {__print(t); if (sizeof...(v)) cerr << ", "; _print(v...);}

// input_checker inp;

// Tester's solution. For every distinct degree d, it blocks all vertices of degree d.
// For each remaining component it finds a diameter with three BFS runs; the distance
// from any vertex to the farther diameter endpoint is its eccentricity. Each vertex x
// of degree d then gets 1 + sum over unblocked neighbours u of (eccentricity(u) + 1).
int32_t main() {
	ios_base::sync_with_stdio(0);
	cin.tie(0);

	// Total of N over all test cases, checked against the constraint at the end.
	int sumN = 0;
	auto __solve_testcase = [&](int test) {
// 		int N = inp.readInt(1, (int)1e5);	inp.readEoln();		sumN += N;

		int N;	cin >> N;
		// Accumulate on the live path too, so the final assert is not vacuous
		// while the input_checker line above is commented out.
		sumN += N;

		// adj: 0-indexed adjacency lists. D[d]: every vertex of degree d.
		vector<vector<int>> adj(N), D(N);
		for(int i = 1 ; i < N ; ++i) {
			// int u = inp.readInt(1, N);	inp.readSpace();
			// int v = inp.readInt(1, N);	inp.readEoln();
			int u, v;	cin >> u >> v;
			adj[u - 1].push_back(v - 1);
			adj[v - 1].push_back(u - 1);
		}
		for(int i = 0 ; i < N ; ++i)
			D[adj[i].size()].push_back(i);

		// blocked[v] == 1 while v has the degree currently being processed.
		// dis / da / db: BFS distances in the remaining forest (-1 = unvisited).
		vector<int> blocked(N), dis(N, -1), da(N, -1), db(N, -1);

		// BFS from `node` over non-blocked vertices, writing distances into `dist`
		// (entries are expected to be -1 beforehand). Returns the last vertex
		// dequeued, which is a vertex farthest from `node`.
		auto bfs = [&](int node, vector<int> &dist) {
			vector<int> que(1, node);	dist[node] = 0;
			for(int i = 0 ; i < (int)que.size() ; ++i) {
				int nd = que[i];
				assert(!blocked[nd]);
				for(auto &u: adj[nd]) if(dist[u] == -1) {
					if(blocked[u])	continue;
					dist[u] = dist[nd] + 1;
					que.push_back(u);
				}
			}
			return que.back();
		};

		// In node's component: find one diameter endpoint a, then fill da from a,
		// finding the other endpoint b, then fill db from b.
		auto solve = [&](int node) {
			bfs(bfs(bfs(node, dis), da), db);
		};

		// Each vertex always counts itself.
		vector<int> res(N, 1);

		for(int d = 1 ; d < N ; ++d) if(!D[d].empty()) {
			for(auto &x: D[d])	blocked[x] = 1;

			for(int v = 0 ; v < N ; ++v) if(!blocked[v] && dis[v] == -1) {
				solve(v);
			}

			// Each unblocked neighbour u of x lies in its own component. The best
			// arm through u has max(da[u], db[u]) + 1 vertices (excluding x).
			for(auto &x: D[d]) {
				for(auto &u: adj[x]) if(!blocked[u]) {
					res[x] += max(da[u], db[u]) + 1;
				}
			}
			// Unblock only after every vertex of degree d has been answered, so an
			// adjacent same-degree vertex is never treated as part of a component.
			for(auto &x: D[d])	blocked[x] = 0;

			fill(dis.begin(), dis.end(), -1);
			fill(da.begin(), da.end(), -1);
			fill(db.begin(), db.end(), -1);
		}
		for(int i = 0 ; i < N ; ++i)
			cout << res[i] << " \n"[i == N - 1];
	};
	
	int NumTest = 1;	cin >> NumTest;
	// NumTest = inp.readInt(1, (int)1e4); inp.readEoln();
	for(int testno = 1; testno <= NumTest ; ++testno) {
		__solve_testcase(testno);
	}
	assert(sumN <= (int)5e5);

	// inp.readEof();
	
	return 0;
}

Editorialist's code (C++)

Dealing with each degree as a separate group is what I missed. Thanks for the problem.