Editorial - CNTIT

PROBLEM LINK:

Practice

Contest: Division 1

Contest: Division 2

Setter: Kasra Mazaheri

Tester: Radoslav Dimitrov

Editorialist: Raja Vardhan Reddy

DIFFICULTY:

Medium-Hard

PREREQUISITES:

BIT, centroid of a tree, basic math, DFS.

PROBLEM:

Roger has a tree with N vertices (numbered 1 through N), where each edge is coloured in one of K colours (numbered 1 through K). He wants to pick an unordered pair of distinct vertices in the tree and create a multiset S=(S_1,S_2,…,S_K), where S_i is the number of edges on the path between the chosen vertices which have colour i. Now, Roger is wondering: in how many ways can he pick a pair of vertices such that the resulting multiset S is polygonable? Find the number of such pairs. A multiset S is polygonable if its elements can form the sides of a polygon (the polygon may be degenerate, and side lengths may be zero).

EXPLANATION

A multiset of values is polygonable if the values can form the sides of a polygon (the polygon may be degenerate, i.e. have zero area, and side lengths may be zero).
Let the values be a_1,a_2,...,a_k, and let the maximum among them be a_{max}.

The values are polygonable if and only if 2*a_{max} <= \sum_{i=1}^{k} a_i: if the maximum side length is greater than the sum of the other sides, there is no way to join the two ends of the longest side using the remaining sides, and otherwise a (possibly degenerate) polygon can always be formed. So we want to count the number of paths for which this condition holds.
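As a small illustration, the condition can be written as a C++ helper. The name `isPolygonable` is hypothetical and does not appear in the setter's or tester's code.

```cpp
#include <algorithm>
#include <cassert>
#include <numeric>
#include <vector>

// True if the values can be the side lengths of a (possibly degenerate)
// polygon, i.e. twice the maximum does not exceed the total sum.
bool isPolygonable(const std::vector<long long>& sides) {
    assert(!sides.empty());
    long long mx = *std::max_element(sides.begin(), sides.end());
    long long sum = std::accumulate(sides.begin(), sides.end(), 0LL);
    return 2 * mx <= sum;
}
```

For example, `{1, 1, 2}` passes (a degenerate polygon), while `{5, 1, 2}` fails because 5 > 1 + 2.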

Let us count the bad (non-polygonable) paths and subtract their number from the total number of paths, n*(n-1)/2. On a bad path, the maximum count is strictly greater than the sum of the other counts.
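To test a faster solution on small inputs, a brute-force reference that checks every pair directly may help. The sketch below uses hypothetical names `walk` and `bruteForcePairs`, assumes 1-based vertices with edges given as `{u, v, colour}` triples, and runs in O(n^2 * K), so it is only meant for tiny trees.

```cpp
#include <algorithm>
#include <array>
#include <cassert>
#include <utility>
#include <vector>

using Graph = std::vector<std::vector<std::pair<int, int>>>;  // (neighbour, colour)

// DFS from `start`, keeping cnt[i] = number of colour-i edges on the path start..v.
void walk(const Graph& g, int v, int parent, int start, std::vector<int>& cnt, long long& good) {
    if (v > start) {  // count each unordered pair {start, v} only once
        int mx = *std::max_element(cnt.begin(), cnt.end());
        int sum = 0;
        for (int x : cnt) sum += x;
        if (2 * mx <= sum) ++good;
    }
    for (auto [u, c] : g[v]) {
        if (u == parent) continue;
        ++cnt[c];
        walk(g, u, v, start, cnt, good);
        --cnt[c];
    }
}

// Number of unordered pairs of distinct vertices whose colour counts are polygonable.
long long bruteForcePairs(int n, int k, const std::vector<std::array<int, 3>>& edges) {
    assert(n >= 1 && k >= 1);
    Graph g(n + 1);
    for (const auto& e : edges) {
        g[e[0]].push_back({e[1], e[2]});
        g[e[1]].push_back({e[0], e[2]});
    }
    long long good = 0;
    for (int s = 1; s <= n; ++s) {
        std::vector<int> cnt(k + 1, 0);  // index 0 is unused and stays 0
        walk(g, s, 0, s, cnt, good);
    }
    return good;
}
```

For the path 1-2-3 with edge colours 1 and 2, only the pair (1,3) is polygonable (counts (1,1)), so the result is 1.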

Lemma: Each bad path has a unique color with the maximum count.
Proof: Suppose two or more colors share the maximum count. Pick one of them as the maximum; another color with the same count is among the remaining colors, so the maximum is at most the sum of the remaining counts. But on a bad path the maximum is strictly greater than the sum of the rest, a contradiction.

So let us iterate over the colors and find the number of bad paths with the current color’s count being the maximum count on that bad path.

By the lemma, a bad path has exactly one color with the maximum count, so it is counted under exactly one color and never counted twice. Hence, summing over all colors the number of bad paths whose maximum count belongs to that color gives the total number of bad paths.

Now we discuss how to count the bad paths whose maximum count belongs to a particular color c.

Give every edge of color c weight +1 and every edge of any other color weight -1.

Then the paths with a positive edge weight sum are exactly the bad paths whose maximum count belongs to color c.
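Writing L for the number of edges on a path and w_c(e) for the weight just assigned to edge e (notation introduced here), the equivalence follows from:

```latex
\sum_{e \in \text{path}} w_c(e) \;=\; S_c - \sum_{i \neq c} S_i \;=\; 2S_c - L,
\qquad
2S_c - L > 0 \iff S_c > \sum_{i \neq c} S_i .
```

The right-hand condition says exactly that the path is bad and that color c holds its (unique) maximum count.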

So now we want to count the number of positive edge sum paths in the tree.

Fix a node x.

Every path in the tree either passes through x or lies completely inside the subtree of one of x's children.

Finding the number of positive edge sum paths passing through x.

Let (a,b) be such a path. Then either one of a, b is x itself, or a and b lie in the subtrees of two different children of x.

Case 1: one endpoint is x, say a = x. A DFS rooted at x computes the edge weight sum of every path (x,b), and we count the b for which this sum is positive. This takes O(n) time.
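A possible sketch of Case 1, assuming an adjacency list of `(neighbour, colour)` pairs indexed from 1; `Graph` and `countPositiveFromRoot` are hypothetical names. It works on the whole tree, so it ignores the removed centroids that a full decomposition would skip, and it uses an explicit stack so that deep, path-like trees do not overflow the call stack.

```cpp
#include <cassert>
#include <tuple>
#include <utility>
#include <vector>

using Graph = std::vector<std::vector<std::pair<int, int>>>;  // (neighbour, colour), 1-based

// Case 1: counts vertices b != x whose path (x, b) has a positive weight sum,
// where colour-c edges weigh +1 and all other edges weigh -1. O(n).
int countPositiveFromRoot(const Graph& g, int x, int c) {
    assert(x >= 1 && x < (int)g.size());
    int positive = 0;
    std::vector<std::tuple<int, int, int>> stack;  // (vertex, parent, sum from x)
    stack.emplace_back(x, 0, 0);
    while (!stack.empty()) {
        auto [v, p, h] = stack.back();  // copy, so popping is safe
        stack.pop_back();
        if (v != x && h > 0) ++positive;
        for (auto [u, col] : g[v])
            if (u != p) stack.emplace_back(u, v, h + (col == c ? 1 : -1));
    }
    return positive;
}
```

On the path 1-2-3-4 with colours 1, 1, 2 and x = 2, c = 1, the sums to vertices 1, 3, 4 are +1, +1, 0, so two paths are counted.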

Case 2: From case 1 we already have the edge weight sum of every path (x,b), and every path (a,b) can be split into (a,x) and (x,b). So let us maintain a BIT indexed by edge weight sum, initially holding the sums of all paths (x,b) (i.e. add 1 at the edge weight sum of each path; since sums can be negative, they are shifted by an offset).
Now iterate over the children of x, and for each child count the bad paths (a,b) with a in that child's subtree. First remove the edge weight sums of all paths ending in this subtree from the BIT, then iterate over all the nodes of the subtree. If the path from x to the current node has edge weight sum w, then every path (x,b) still in the BIT with edge weight sum greater than -w forms a bad path (a,b), and the BIT counts these paths. Since a subtree's sums are never re-inserted, each pair is counted exactly once.
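The BIT step of Case 2 might look like the sketch below. It assumes the path sums h(x, v) of each child subtree have already been collected into `groups` (one vector per child); `Fenwick` and `countCrossPairs` are hypothetical names, and keys are shifted by `lim + 1` so that negative sums map to positive BIT indices.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdlib>
#include <vector>

// Fenwick tree (BIT) over integer keys in [-lim, lim]; key k lives at index k + lim + 1.
struct Fenwick {
    int lim, total = 0;
    std::vector<int> t;
    explicit Fenwick(int lim_) : lim(lim_), t(2 * lim_ + 2, 0) {}
    void add(int key, int delta) {
        assert(std::abs(key) <= lim);
        total += delta;
        for (int i = key + lim + 1; i < (int)t.size(); i += i & -i) t[i] += delta;
    }
    int countLeq(int key) const {  // stored keys <= key; key must lie in [-lim, lim]
        int s = 0;
        for (int i = key + lim + 1; i > 0; i -= i & -i) s += t[i];
        return s;
    }
    int countGreater(int key) const { return total - countLeq(key); }
};

// groups[j] = path sums h(x, v) for all v in the j-th child subtree of x.
// Counts pairs (a, b) from different subtrees with h(a) + h(b) > 0.
long long countCrossPairs(const std::vector<std::vector<int>>& groups) {
    int lim = 0;
    for (const auto& g : groups)
        for (int h : g) lim = std::max(lim, std::abs(h));
    Fenwick bit(lim);
    for (const auto& g : groups)
        for (int h : g) bit.add(h, 1);
    long long pairs = 0;
    for (const auto& g : groups) {
        for (int h : g) bit.add(h, -1);                 // remove this subtree, never re-added
        for (int h : g) pairs += bit.countGreater(-h);  // partner b needs h(b) > -h(a)
    }
    return pairs;
}
```

With groups `{1, 2}` and `{0, -1}`, the pairs (1,0), (2,0) and (2,-1) have positive sums, so the result is 3.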

Each node's edge weight sum is inserted into and deleted from the BIT once, and for each node we query the BIT once. Hence, we spend O(log(n)) time per node.

Finding the number of bad paths inside each child's subtree.

This is the same problem on a smaller tree, so we can solve it recursively; a tree with only one node contains no paths.

To compute the time complexity, let us bound the time spent on a particular node and then sum this over all nodes.

In every tree of the recursion that contains the node, we spend O(log(n)) time on it.

Now we need an upper bound on the number of these trees that contain a particular node.

To get one, always choose x to be the centroid of the current tree: this guarantees that each child's subtree has at most half as many nodes as the current tree.

So each time a node is processed, the next tree that contains it has at most half the size of the current tree, which means a particular node is processed at most O(log(n)) times.

Hence, we will spend O(log(n)^2) time on a particular node when we choose x to be the centroid of the tree.

Hence, counting the bad paths whose maximum count belongs to a particular color c takes O(n*log(n)^2) time.
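The same bound can be written as a sum over recursion levels, with T_c(n) (notation introduced here) denoting the time for one color. The trees at one level are disjoint, so together they hold at most n nodes, and a tree at level d has at most n/2^d nodes:

```latex
T_c(n) \;\le\; \sum_{d=0}^{\lfloor \log_2 n \rfloor} \;\sum_{\text{trees } t \text{ at level } d} O\big(|t| \log n\big)
\;\le\; \big(\lfloor \log_2 n \rfloor + 1\big) \cdot O(n \log n) \;=\; O\big(n \log^2 n\big)
```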

Repeating this for all k colors takes O(n*k*log(n)^2) time.

The above complexity is enough to get AC on the problem.
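Putting the pieces together, a compact sketch of the O(n*k*log(n)^2) approach could look as follows. It is not the setter's or tester's code (both appear below and use counting sort rather than a BIT); `BadPathCounter` and `countPolygonablePairs` are hypothetical names, vertices are converted to 0-based internally, and recursion depth can reach n on path-like trees, so a contest version may need an iterative DFS.

```cpp
#include <array>
#include <cassert>
#include <utility>
#include <vector>

using Adj = std::vector<std::vector<std::pair<int, int>>>;  // (neighbour, colour), 0-based

// Centroid decomposition; for every centroid x and colour c, counts paths
// through x whose +1/-1 weight sum (colour c edges weigh +1) is positive.
struct BadPathCounter {
    const Adj& adj;
    int k;
    std::vector<char> removed;  // vertices already used as centroids
    std::vector<int> sz, bit;   // subtree sizes, Fenwick tree array
    int off = 0;                // shift that maps path sums to BIT indices >= 1

    BadPathCounter(const Adj& a, int k_) : adj(a), k(k_), removed(a.size(), 0), sz(a.size(), 0) {}

    int calcSize(int v, int p) {
        sz[v] = 1;
        for (auto [u, c] : adj[v])
            if (u != p && !removed[u]) sz[v] += calcSize(u, v);
        return sz[v];
    }
    int findCentroid(int v, int p, int total) {
        for (auto [u, c] : adj[v])
            if (u != p && !removed[u] && 2 * sz[u] > total) return findCentroid(u, v, total);
        return v;
    }
    void collect(int v, int p, int h, int col, std::vector<int>& out) {
        out.push_back(h);  // h = weight sum of the path centroid -> v
        for (auto [u, c] : adj[v])
            if (u != p && !removed[u]) collect(u, v, h + (c == col ? 1 : -1), col, out);
    }
    void bitAdd(int key, int d) {
        for (int i = key + off; i < (int)bit.size(); i += i & -i) bit[i] += d;
    }
    int bitLeq(int key) const {  // number of stored sums <= key
        int s = 0;
        for (int i = key + off; i > 0; i -= i & -i) s += bit[i];
        return s;
    }
    long long solve(int root) {
        int total = calcSize(root, -1);
        int x = findCentroid(root, -1, total);
        long long bad = 0;
        for (int col = 1; col <= k; ++col) {
            std::vector<std::vector<int>> groups;  // one vector of sums per child subtree
            for (auto [u, c] : adj[x])
                if (!removed[u]) {
                    groups.emplace_back();
                    collect(u, x, c == col ? 1 : -1, col, groups.back());
                }
            off = total + 1;  // sums lie in [-(total - 1), total - 1]
            bit.assign(2 * total + 2, 0);
            int inBit = 0;
            for (const auto& g : groups)
                for (int h : g) {
                    bitAdd(h, 1);
                    ++inBit;
                    if (h > 0) ++bad;  // Case 1: path (x, b)
                }
            for (const auto& g : groups) {  // Case 2: a, b in different subtrees
                for (int h : g) { bitAdd(h, -1); --inBit; }
                for (int h : g) bad += inBit - bitLeq(-h);
            }
        }
        removed[x] = 1;  // recurse into the pieces left after deleting x
        for (auto [u, c] : adj[x])
            if (!removed[u]) bad += solve(u);
        return bad;
    }
};

// Total pairs minus bad pairs; edges are {u, v, colour} with 1-based vertices.
long long countPolygonablePairs(int n, int k, const std::vector<std::array<int, 3>>& edges) {
    assert(n >= 1 && (int)edges.size() == n - 1);
    Adj adj(n);
    for (const auto& e : edges) {
        adj[e[0] - 1].push_back({e[1] - 1, e[2]});
        adj[e[1] - 1].push_back({e[0] - 1, e[2]});
    }
    BadPathCounter counter(adj, k);
    return 1LL * n * (n - 1) / 2 - counter.solve(0);
}
```

Rebuilding the BIT for every color and centroid costs O(size of the current tree), which stays within the stated bound.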

Bonus: try to solve the problem in O(n*k*log(n)) complexity.

Hint1
try to get rid of the BIT.
Hint2
try to use counting sort instead of BIT to reduce time spent on a node in a tree to O(1).
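One way to follow Hint2, sketched under assumptions: the (path sum, child index) pairs of all non-centroid vertices are counting-sorted by sum, two pointers count all pairs with a positive sum, and pairs from the same child subtree are subtracted (a sorted list stays sorted when split by child). This complement trick is a swapped-in variant; the setter and tester instead keep per-subtree counters while moving their pointers. The names are hypothetical, and all sums are assumed to lie in [-lim, lim].

```cpp
#include <cassert>
#include <cstdlib>
#include <utility>
#include <vector>

// Stable counting sort of (sum, child) pairs by sum; sums must lie in [-lim, lim].
std::vector<std::pair<int, int>> countingSort(const std::vector<std::pair<int, int>>& a, int lim) {
    std::vector<int> start(2 * lim + 2, 0);
    for (const auto& pr : a) {
        assert(std::abs(pr.first) <= lim);
        ++start[pr.first + lim + 1];
    }
    for (int i = 1; i < (int)start.size(); ++i) start[i] += start[i - 1];  // start[s + lim] = #sums < s
    std::vector<std::pair<int, int>> out(a.size());
    for (const auto& pr : a) out[start[pr.first + lim]++] = pr;
    return out;
}

// Pairs i < j of an ascending list with s[i] + s[j] > 0, by two pointers in O(m).
long long positivePairs(const std::vector<int>& s) {
    long long cnt = 0;
    int lo = 0, hi = (int)s.size() - 1;
    while (lo < hi) {
        if (s[lo] + s[hi] > 0) { cnt += hi - lo; --hi; }  // s[hi] pairs with every i in [lo, hi)
        else ++lo;                                         // s[lo] pairs with nothing left
    }
    return cnt;
}

// Cross-subtree positive pairs = all positive pairs - positive pairs inside one child.
long long countCrossPairsSorted(const std::vector<std::pair<int, int>>& items, int children, int lim) {
    std::vector<std::pair<int, int>> sorted = countingSort(items, lim);
    std::vector<int> all;
    std::vector<std::vector<int>> byChild(children);
    for (const auto& [h, ch] : sorted) {  // each byChild list inherits the ascending order
        all.push_back(h);
        byChild[ch].push_back(h);
    }
    long long res = positivePairs(all);
    for (const auto& v : byChild) res -= positivePairs(v);
    return res;
}
```

Each call costs O(m + lim), and lim is at most the size of the current tree, so the work per node in each tree becomes O(1) amortized, in line with Hint2.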

TIME COMPLEXITY

Time complexity = O(n*k*log(n)^2).

SOLUTIONS:

Setter's Solution
// In The Name Of The Queen
#include<bits/stdc++.h>
#define x first
#define y second
using namespace std;
const int N = 200005;
int n, k, M[N], P[N], E[N], H[N], SZ[N], CS[N];
int ts, to[N], nxt[N], head[N];
long long tot;
vector < pair < int , int > > vec, tmp;
vector < pair < int , int > > Adj[N];
inline void Add(int v, int u)
{
    to[ts] = u;
    nxt[ts] = head[v];
    head[v] = ts ++;
}
inline void CountingSort()
{
    int Mn = 0, Mx = 0;
    for (auto u : vec)
        Mn = min(Mn, H[u.x]),
        Mx = max(Mx, H[u.x]);
    ts = 0;
    for (int i = 0; i <= Mx - Mn; i ++)
        head[i] = -1;
    for (int i = 0; i < (int)vec.size(); i ++)
        Add(H[vec[i].x] - Mn, i);
    int ptr = 0;
    tmp.resize(vec.size());
    for (int i = 0; i <= Mx - Mn; i ++)
        for (; ~ head[i]; head[i] = nxt[head[i]])
            tmp[ptr ++] = vec[to[head[i]]];
}
void DFS(int v, int p, int root)
{
    P[v] = p;
    vec.push_back({v, root});
    for (auto u : Adj[v])
        if (!M[u.x] && u.x != p)
            E[u.x] = u.y, DFS(u.x, v, root);
}
inline void Solve(int v)
{
    for (auto u : Adj[v])
        if (!M[u.x])
            E[u.x] = u.y, DFS(u.x, v, u.x);
    H[v] = 0;
    for (int color = 1; color <= k; color ++)
    {
        for (auto u : vec)
            H[u.x] = H[P[u.x]] + (E[u.x] == color ? 1 : -1);
        CountingSort();
        int l = (int)tmp.size() - 1;
        for (int i = 0; i < (int)tmp.size(); i ++)
        {
            pair < int , int > u = tmp[i];
            while (l > i && H[tmp[l].x] >= 1 - H[u.x])
                CS[tmp[l].y] ++, l --;
            while (l < i)
                l ++, CS[tmp[l].y] --;
            tot += (int)tmp.size() - l - 1 - CS[u.y];
            if (H[u.x] >= 1) tot ++;
        }
    }
    vec.clear();
}
void DFSSZ(int v, int p)
{
    SZ[v] = 1;
    for (auto u : Adj[v])
        if (!M[u.x] && u.x != p)
            DFSSZ(u.x, v), SZ[v] += SZ[u.x];
}
inline int Centroid(int v)
{
    DFSSZ(v, 0);
    int p = 0, szn = SZ[v], w = 1;
    while (w)
    {
        w = 0;
        for (auto u : Adj[v])
            if (!M[u.x] && u.x != p && SZ[u.x] * 2 >= szn)
                {p = v; v = u.x; w = 1; break;}
    }
    return (v);
}
inline void Decompose()
{
    queue < int > qu;
    qu.push(1);
    while (qu.size())
    {
        int v = qu.front();
        qu.pop();
        v = Centroid(v);
        M[v] = 1;
        Solve(v);
        for (auto u : Adj[v])
            if (!M[u.x])
                qu.push(u.x);
    }
}
int main()
{
    int q;
    scanf("%d", &q);
    for (; q; q --)
    {
        scanf("%d%d", &n, &k);
        for (int i = 1; i < n; i ++)
        {
            int v, u, c;
            scanf("%d%d%d", &v, &u, &c);
            Adj[v].push_back({u, c});
            Adj[u].push_back({v, c});
        }
        Decompose();
        printf("%lld\n", 1LL * n * (n - 1) / 2 - tot);
 
        for (int i = 1; i <= n; i ++)
            M[i] = 0, Adj[i].clear();
        tot = 0;
    }
    return 0;
} 
Tester's Solution
#include <bits/stdc++.h>
#define endl '\n'
 
#define SZ(x) ((int)x.size())
#define ALL(V) V.begin(), V.end()
#define L_B lower_bound
#define U_B upper_bound
#define pb push_back
 
using namespace std;
template<class T, class T1> int chkmin(T &x, const T1 &y) { return x > y ? x = y, 1 : 0; }
template<class T, class T1> int chkmax(T &x, const T1 &y) { return x < y ? x = y, 1 : 0; }
const int MAXN = (1 << 19);
 
int n, k;
vector<pair<int, int> > adj[MAXN];
 
int read_int();
 
void read() {
	n = read_int();
	k = read_int();
	for(int i = 1; i <= n; i++) adj[i].clear();
	for(int i = 1; i <= n - 1; i++) {
		int u, v, c;
		u = read_int();
		v = read_int();
		c = read_int() - 1;
 
		adj[u].pb({v, c});
		adj[v].pb({u, c});
	}
}
 
int tr_sz[MAXN], cnt_vers;
bool used[MAXN];
 
void pre_dfs(int u, int pr) {
	cnt_vers++;
	tr_sz[u] = 1;
	for(auto v: adj[u])
		if(!used[v.first] && v.first != pr) {
			pre_dfs(v.first, u);
			tr_sz[u] += tr_sz[v.first];
		}
}
 
int centroid(int u, int pr) {
	for(auto v: adj[u])
		if(!used[v.first] && v.first != pr && tr_sz[v.first] > cnt_vers / 2)
			return centroid(v.first, u);
 
	return u;
}
 
int DOM_COL;
inline int get_weight(int col) {
	return col == DOM_COL ? 1 : -1;
}
 
int col_ed[MAXN], dp[MAXN];
int repr[MAXN], par[MAXN];
vector<int> vers;
 
void dfs_add(int u, int pr, int repr) {
	vers.pb(u);
	par[u] = pr;
	::repr[u] = repr;
 
	for(auto v: adj[u]) {
		if(!used[v.first] && v.first != pr) {
			col_ed[v.first] = v.second;
			dfs_add(v.first, u, repr);
		}
	}
}
 
int cnt_subtree[MAXN];
vector<pair<int, int> > buck[MAXN];
 
void counting_sort(vector<pair<int, int> > &vec) {
	int mn = (int)1e9;
	int mx = -(int)1e9;
	for(int i = 0; i < SZ(vec); i++) {
		chkmin(mn, vec[i].first);
		chkmax(mx, vec[i].first);
	}
 
	for(int id = mn; id <= mx; id++) {
		buck[id - mn].clear();
	}
 
	for(int i = 0; i< SZ(vec); i++) {
		buck[vec[i].first - mn].pb(vec[i]);
	}
 
	int pos = 0;
	for(int id = mn; id <= mx; id++) {
		for(auto val: buck[id - mn]) {
			vec[pos++] = val;
		}
	}
}
 
int64_t decompose(int u) {
	cnt_vers = 0;
	pre_dfs(u, u);
	int cen = centroid(u, u);
 
	int64_t ans = 0;
 
	used[cen] = true;
	for(auto v: adj[cen])
		if(!used[v.first]) 
			ans += decompose(v.first);
	used[cen] = false;
 
	vers = {cen};
	repr[cen] = cen;
	dp[cen] = 0;
	for(auto v: adj[cen])
		if(!used[v.first]) {
			col_ed[v.first] = v.second;
			dfs_add(v.first, cen, v.first);
		}
 
	for(int c = 0; c < k; c++)
	{
		DOM_COL = c;
		vector<pair<int, int> > tmp;
		for(int i: vers) {
			if(i != cen) {
				dp[i] = dp[par[i]] + get_weight(col_ed[i]);
			}
 
			tmp.pb({dp[i], repr[i]});
		}
 
		counting_sort(tmp);
 
		int cnt_all = SZ(tmp);
		for(auto it: tmp) cnt_subtree[it.second] = 0;
		for(auto it: tmp) cnt_subtree[it.second]++;
 
		int pos = 0;
		for(int i = SZ(tmp) - 1; i > pos; i--) {
			while(pos < i && tmp[i].first + tmp[pos].first <= 0) {
				cnt_subtree[tmp[pos].second]--;
				cnt_all--;
				pos++;
			}
 
			ans += cnt_all - cnt_subtree[tmp[i].second];
			cnt_subtree[tmp[i].second]--;
			cnt_all--;
		}
	}
 
	return ans;
}
 
void solve() {
	int64_t answer = n * 1ll * (n - 1) / 2;
	if(k >= 2) answer -= decompose(1);
	else answer = 0;
	cout << answer << endl;
}
 
int main() {
	ios_base::sync_with_stdio(false);
	cin.tie(NULL);
 
	int T;
	T = read_int();
	while(T--) {
		read();
		solve();
	}
	return 0;
}
 
const int maxl = 100000;
char buff[maxl];
int ret_int, pos_buff = 0;
 
void next_char() { if(++pos_buff == maxl) fread(buff, 1, maxl, stdin), pos_buff = 0; }
 
int read_int()
{
	ret_int = 0;
	for(; buff[pos_buff] < '0' || buff[pos_buff] > '9'; next_char());
	for(; buff[pos_buff] >= '0' && buff[pos_buff] <= '9'; next_char())
		ret_int = ret_int * 10 + buff[pos_buff] - '0';
	return ret_int;
}

Feel free to share your approach if it differs. Suggestions are always welcome. :slight_smile:
