SUBMEXS - Editorial

PROBLEM LINK:

Practice
Contest

Setter: Mohammed Ehab
Tester: Ramazan Rakhmatullin
Editorialist: Ishmeet Singh Saggu

DIFFICULTY:

Simple

PREREQUISITES:

Greedy and DFS/BFS

PROBLEM:

Chef has a tree with N nodes (numbered 1 through N). The tree is rooted at node 1. Chef wants to assign a non-negative integer to each node in such a way that each integer between 0 and N−1 (inclusive) is assigned to exactly one node.

For each node u, consider the integers assigned to the nodes in the subtree of u (including u); let a_u denote the MEX of these integers. Chef wants a_1+a_2+…+a_N to be as large as possible. Find the maximum possible value of this sum.

EXPLANATION:

If you assign the value 0 to a node x, then every node that is neither x nor an ancestor of x has no 0 in its subtree, so its MEX is a_i = 0, and the answer reduces to the sum of the MEX values of the nodes on the simple path from the root to x. This hints that it is optimal to assign 0 to a leaf: moving the 0 further down only lengthens that path. Also, the MEX a_u of a node u can be at most the size of its subtree (the subtree contains only that many distinct values), and on the path from the root to x every node can reach this maximum at the same time: give 0 to the leaf x, then give the smallest unused values to the remaining nodes in the subtree of x's parent, then to the remaining nodes in the subtree of the next ancestor, and so on up to the root, so that the subtree of each path node u contains exactly the values 0, 1, …, size(u) − 1. Now the problem reduces to finding a path from the root to a leaf that maximises the sum of the subtree sizes of its nodes; that sum is the answer. This can be done by first computing the subtree size of each node and then running a DFS that maintains, for each node, the sum of the subtree sizes of all its ancestors including itself; the maximum of these values is the answer. (Note that the answer can be large — up to N(N+1)/2 for a path-shaped tree — so store it in a 64-bit integer.)

TIME COMPLEXITY:

  • Time complexity per test case is O(N).

SOLUTIONS:

Setter's Solution
#include <bits/stdc++.h>
using namespace std;
vector<int> v[100005];

// Returns {best, size} for the subtree rooted at `node`, where `size` is the
// number of nodes in that subtree and `best` is the largest sum of subtree
// sizes along a downward path that starts at `node` (called on the root, this
// is the answer). Children lists are read from the global `v`.
// Uses an explicit work list instead of recursion, so a long path-shaped tree
// cannot exhaust the call stack.
pair<long long,int> dfs(int node)
{
	// Pre-order listing of the subtree: order[i] is a node and parent_pos[i]
	// is the index of its parent inside `order` (-1 for `node` itself).
	vector<int> order;
	vector<int> parent_pos;
	vector<pair<int,int>> todo(1, make_pair(node, -1));
	while (!todo.empty())
	{
		const pair<int,int> item = todo.back();
		todo.pop_back();
		const int pos = static_cast<int>(order.size());
		order.push_back(item.first);
		parent_pos.push_back(item.second);
		for (int child : v[item.first])
			todo.push_back(make_pair(child, pos));
	}

	// Each node sits after its parent in `order`, so sweeping backwards
	// finishes every child before the parent that aggregates it.
	vector<long long> best(order.size(), 0);  // max over children, then + own size
	vector<int> sub_size(order.size(), 1);
	for (int i = static_cast<int>(order.size()) - 1; i >= 0; --i)
	{
		best[i] += sub_size[i];
		const int up = parent_pos[i];
		if (up >= 0)
		{
			sub_size[up] += sub_size[i];
			best[up] = max(best[up], best[i]);
		}
	}
	return {best[0], sub_size[0]};
}
int main()
{
	// Reads one integer from stdin.
	auto read_int = [] { int value; scanf("%d",&value); return value; };

	int t = read_int();
	while (t--)
	{
		const int n = read_int();
		// Drop the child lists left over from the previous test case.
		for (int node = 1; node <= n; ++node)
			v[node].clear();
		// Line 2 lists the parent of each node 2..n.
		for (int child = 2; child <= n; ++child)
			v[read_int()].push_back(child);
		printf("%lld\n",dfs(1).first);
	}
}
Tester's Solution
#include <bits/stdc++.h>
 
// Silences clang's unused-const-variable warning, presumably for template
// constants (maxw, sq, LG, mod, ...) that a given solution does not use.
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wunused-const-variable"

// Shorthand macros. Every expression-like body is fully parenthesized so the
// macro behaves as a single operand inside a larger expression
// (e.g. `low_bo(a, x) * 2`).

// Number of set bits of an unsigned int (GCC/Clang builtin).
#define popcnt(x) __builtin_popcount(x)

#define fr first

#define sc second

#define m_p make_pair

// Index of the first element >= x (low_bo) / > x (up_bo) in the sorted
// random-access container `a`.
#define low_bo(a, x) (lower_bound((a).begin(), (a).end(), x) - (a).begin())

#define up_bo(a, x) (upper_bound((a).begin(), (a).end(), x) - (a).begin())

// Removes consecutive duplicates from `a` in place (sort `a` first to drop all
// duplicates). NOTE: because this is a one-argument function-like macro named
// `unique`, any later two-argument call such as std::unique(first, last)
// written outside this macro is rejected by the preprocessor.
#define unique(a) (a).resize(unique((a).begin(), (a).end()) - (a).begin())

//#include <ext/pb_ds/assoc_container.hpp>

//using namespace __gnu_pbds;

//gp_hash_table<int, int> table;

//#pragma GCC optimize("O3")
//#pragma GCC optimize("Ofast,no-stack-protector")
//#pragma GCC target("sse,sse2,sse3,ssse3,sse4,popcnt,abm,mmx")
//#pragma GCC target("avx,tune=native")
//float __attribute__((aligned(32)))

/*char memory[(int)1e8];
char memorypos;

inline void * operator new(size_t n){
    char * ret = memory + memorypos;
    memorypos += n;
    return (void *)ret;
}

inline void operator delete(void *){}
*/

using namespace std;

typedef long long ll;

typedef unsigned long long ull;

typedef long double ld;

typedef unsigned int uint;
 
// Integer modulo T::value. T is expected to expose a static `value` member
// holding the modulus (e.g. a std::integral_constant); Type is the decayed type
// of that constant. Stored values are kept in [0, mod()) via normalize().
template<typename T>
class Modular {
public:
    using Type = typename decay<decltype(T::value)>::type;

    // Zero-initialized value.
    constexpr Modular() : value() {}

    // Implicit conversion from any integer-like value, reduced with normalize().
    template<typename U>
    Modular(const U &x) {
        value = normalize(x);
    }

    // Modular inverse of `a` modulo `mod` using the iterative extended Euclidean
    // algorithm; a negative coefficient is shifted up by `mod`.
    // NOTE(review): only meaningful when gcd(a, mod) == 1 — this is not checked.
    static Type inverse(Type a, Type mod) {
        Type b = mod, x = 0, y = 1;
        while (a != 0) {
            Type t = b / a;
            b -= a * t;
            x -= t * y;
            swap(a, b);
            swap(x, y);
        }
        if (x < 0)
            x += mod;
        return x;
    }

    // Reduces x into [0, mod()). Values already in [-mod(), mod()) skip the
    // (slower) % operation; negatives are then shifted up by mod().
    template<typename U>
    static Type normalize(const U &x) {
        Type v;
        if (-mod() <= x && x < mod()) v = static_cast<Type>(x);
        else v = static_cast<Type>(x % mod());
        if (v < 0) v += mod();
        return v;
    }

    // Raw stored representative in [0, mod()).
    const Type &operator()() const { return value; }

    // Explicit conversion of the stored representative to another type.
    template<typename U>
    explicit operator U() const { return static_cast<U>(value); }

    // The modulus, taken from T::value.
    constexpr static Type mod() { return T::value; }

    // Both operands are in [0, mod()), so one conditional subtraction suffices.
    Modular &operator+=(const Modular &other) {
        if ((value += other.value) >= mod()) value -= mod();
        return *this;
    }

    // Both operands are in [0, mod()), so one conditional addition suffices.
    Modular &operator-=(const Modular &other) {
        if ((value -= other.value) < 0) value += mod();
        return *this;
    }

    // Mixed-type forms: convert `other` to Modular first.
    template<typename U>
    Modular &operator+=(const U &other) { return *this += Modular(other); }

    template<typename U>
    Modular &operator-=(const U &other) { return *this -= Modular(other); }

    Modular &operator++() { return *this += 1; }

    Modular &operator--() { return *this -= 1; }

    // Postfix forms return the value before the update.
    Modular operator++(int) {
        Modular result(*this);
        *this += 1;
        return result;
    }

    Modular operator--(int) {
        Modular result(*this);
        *this -= 1;
        return result;
    }

    // Additive inverse (normalized by the converting constructor).
    Modular operator-() const { return Modular(-value); }

    // Multiplication when Type is int: the product is formed in 64 bits.
    // On _WIN32 the reduction uses an x86 `divl` of the 64-bit product by mod()
    // (quotient in eax -> d, remainder in edx -> m); the quotient fits in 32 bits
    // because both factors are below mod(). Elsewhere normalize() reduces it.
    template<typename U = T>
    typename enable_if<is_same<typename Modular<U>::Type, int>::value, Modular>::type &operator*=(const Modular &rhs) {
#ifdef _WIN32
        uint64_t x = static_cast<int64_t>(value) * static_cast<int64_t>(rhs.value);
        uint32_t xh = static_cast<uint32_t>(x >> 32), xl = static_cast<uint32_t>(x), d, m;
        asm(
        "divl %4; \n\t"
        : "=a" (d), "=d" (m)
        : "d" (xh), "a" (xl), "r" (mod())
        );
        value = m;
#else
        value = normalize(static_cast<int64_t>(value) * static_cast<int64_t>(rhs.value));
#endif
        return *this;
    }

    // Multiplication when Type is int64_t: estimate the quotient q with long
    // double, so value * rhs.value - q * mod() is a small remainder that
    // normalize() fixes up.
    // NOTE(review): the int64 products can overflow (formally UB for signed
    // types); the trick relies on the wrapped low bits cancelling.
    template<typename U = T>
    typename enable_if<is_same<typename Modular<U>::Type, int64_t>::value, Modular>::type &
    operator*=(const Modular &rhs) {
        int64_t q = static_cast<int64_t>(static_cast<long double>(value) * rhs.value / mod());
        value = normalize(value * rhs.value - q * mod());
        return *this;
    }

    // Multiplication for non-integral Type: plain product, then normalize().
    template<typename U = T>
    typename enable_if<!is_integral<typename Modular<U>::Type>::value, Modular>::type &operator*=(const Modular &rhs) {
        value = normalize(value * rhs.value);
        return *this;
    }

    // Division = multiplication by the modular inverse (see inverse()).
    Modular &operator/=(const Modular &other) { return *this *= Modular(inverse(other.value, mod())); }

    // abs() is the identity: stored representatives are never negative.
    template<typename U>
    friend const Modular<U> &abs(const Modular<U> &v) { return v; }

    // Free operators below read `value` directly.
    template<typename U>
    friend bool operator==(const Modular<U> &lhs, const Modular<U> &rhs);

    template<typename U>
    friend bool operator<(const Modular<U> &lhs, const Modular<U> &rhs);

    template<typename U>
    friend std::istream &operator>>(std::istream &stream, Modular<U> &number);

private:
    Type value;  // representative in [0, mod())
};
 
// Free comparison and arithmetic operators for Modular<T>.
// Every binary operator comes in three overloads: Modular op Modular,
// Modular op U and U op Modular. The U side is converted with the Modular<T>
// constructor, so it is reduced into [0, mod()) before the operation.

// Equality compares stored representatives (private `value`, via friendship).
template<typename T>
bool operator==(const Modular<T> &lhs, const Modular<T> &rhs) { return lhs.value == rhs.value; }

template<typename T, typename U>
bool operator==(const Modular<T> &lhs, U rhs) { return lhs == Modular<T>(rhs); }

template<typename T, typename U>
bool operator==(U lhs, const Modular<T> &rhs) { return Modular<T>(lhs) == rhs; }

template<typename T>
bool operator!=(const Modular<T> &lhs, const Modular<T> &rhs) { return !(lhs == rhs); }

template<typename T, typename U>
bool operator!=(const Modular<T> &lhs, U rhs) { return !(lhs == rhs); }

template<typename T, typename U>
bool operator!=(U lhs, const Modular<T> &rhs) { return !(lhs == rhs); }

// Orders by stored representative in [0, mod()) (not a numeric order on residues).
template<typename T>
bool operator<(const Modular<T> &lhs, const Modular<T> &rhs) { return lhs.value < rhs.value; }

// Arithmetic: copy the left operand as Modular<T> and apply the matching
// compound-assignment member.
template<typename T>
Modular<T> operator+(const Modular<T> &lhs, const Modular<T> &rhs) { return Modular<T>(lhs) += rhs; }

template<typename T, typename U>
Modular<T> operator+(const Modular<T> &lhs, U rhs) { return Modular<T>(lhs) += rhs; }

template<typename T, typename U>
Modular<T> operator+(U lhs, const Modular<T> &rhs) { return Modular<T>(lhs) += rhs; }

template<typename T>
Modular<T> operator-(const Modular<T> &lhs, const Modular<T> &rhs) { return Modular<T>(lhs) -= rhs; }

template<typename T, typename U>
Modular<T> operator-(const Modular<T> &lhs, U rhs) { return Modular<T>(lhs) -= rhs; }

template<typename T, typename U>
Modular<T> operator-(U lhs, const Modular<T> &rhs) { return Modular<T>(lhs) -= rhs; }

template<typename T>
Modular<T> operator*(const Modular<T> &lhs, const Modular<T> &rhs) { return Modular<T>(lhs) *= rhs; }

template<typename T, typename U>
Modular<T> operator*(const Modular<T> &lhs, U rhs) { return Modular<T>(lhs) *= rhs; }

template<typename T, typename U>
Modular<T> operator*(U lhs, const Modular<T> &rhs) { return Modular<T>(lhs) *= rhs; }

// Division multiplies by the modular inverse of the right operand.
template<typename T>
Modular<T> operator/(const Modular<T> &lhs, const Modular<T> &rhs) { return Modular<T>(lhs) /= rhs; }

template<typename T, typename U>
Modular<T> operator/(const Modular<T> &lhs, U rhs) { return Modular<T>(lhs) /= rhs; }

template<typename T, typename U>
Modular<T> operator/(U lhs, const Modular<T> &rhs) { return Modular<T>(lhs) /= rhs; }
 
// a raised to the non-negative power b, by binary exponentiation: the bits of
// the exponent are scanned from least to most significant, and the running
// square of the base is multiplied in for every set bit.
template<typename T, typename U>
Modular<T> power(const Modular<T> &a, const U &b) {
    assert(b >= 0);
    Modular<T> result = 1;
    Modular<T> base = a;
    for (U e = b; e > 0; e >>= 1) {
        if (e & 1)
            result *= base;
        base *= base;
    }
    return result;
}
 
template<typename T>
bool IsZero(const Modular<T> &number) {
    return number() == 0;
}
 
template<typename T>
string to_string(const Modular<T> &number) {
    return to_string(number());
}
 
template<typename T>
std::ostream &operator<<(std::ostream &stream, const Modular<T> &number) {
    return stream << number();
}
 
template<typename T>
std::istream &operator>>(std::istream &stream, Modular<T> &number) {
    typename common_type<typename Modular<T>::Type, int64_t>::type x;
    stream >> x;
    number.value = Modular<T>::normalize(x);
    return stream;
}
 
// Default prime modulus 10^9 + 7 and the modular-integer type built on it.
const int md = 1000000007;

using Mint = Modular<std::integral_constant<int, md>>;
 
// x squared (no overflow check).
ll sqr(ll x) {
    const ll squared = x * x;
    return squared;
}

// floor(sqrt(x)) by binary search over [0, 10^9]. Maintains lo * lo <= x for
// x >= 0; a negative x yields 0, and any x >= 10^18 yields 10^9 (the cap).
int mysqrt(ll x) {
    int lo = 0;
    int hi = 1000000001;
    while (hi - lo > 1) {
        const int mid = lo + (hi - lo) / 2;
        if (static_cast<ll>(mid) * mid > x)
            hi = mid;
        else
            lo = mid;
    }
    return lo;
}
 
// Random engines: clock-based seeds normally, fixed seeds when ONPC (the
// local-testing switch also checked in main) is defined, so local runs repeat.
#ifndef ONPC
mt19937 rnd(chrono::high_resolution_clock::now().time_since_epoch().count());
mt19937_64 rndll(chrono::high_resolution_clock::now().time_since_epoch().count());
#else
mt19937 rnd(513);
mt19937_64 rndll(231);
#endif
 
// Greatest common divisor by the Euclidean algorithm, as a loop: the pair
// (a, b) is replaced by (b % a, a) until a is 0, and then b is returned.
template<typename T>
T gcd(T a, T b) {
    while (a) {
        T rest = b % a;
        b = a;
        a = rest;
    }
    return b;
}

// Extended Euclid: returns g = gcd(a, b) and sets x, y so that a*x + b*y == g.
int gcdex(int a, int b, int &x, int &y) {
    if (a != 0) {
        int x_sub, y_sub;
        const int g = gcdex(b % a, a, x_sub, y_sub);
        // (b % a) * x_sub + a * y_sub == g, with b % a == b - (b / a) * a.
        x = y_sub - (b / a) * x_sub;
        y = x_sub;
        return g;
    }
    x = 0;
    y = 1;
    return b;
}
 
// In-place min/max updates: x is overwritten only when y is strictly smaller
// (setmin) or strictly larger (setmax); on ties x is left as is.
void setmin(int &x, int y) {
    if (y < x) x = y;
}

void setmax(int &x, int y) {
    if (x < y) x = y;
}

void setmin(ll &x, ll y) {
    if (y < x) x = y;
}

void setmax(ll &x, ll y) {
    if (x < y) x = y;
}
 
// Equal to the double expression 4e18 + 100: doubles near 4e18 are spaced 512
// apart, so the +100 rounds away and the value is exactly 4 * 10^18.
const ll llinf = 4000000000000000000LL;

const ld eps = 1e-9;
const ld PI = atan2(0, -1);  // pi, computed in double and widened to long double

const int maxn = 200100;      // 2e5 + 100: bound for the per-node arrays below
const int maxw = 2001111;     // 2e6 + 1111
const int inf = 1000000100;   // 1e9 + 100
const int sq = 450;
const int LG = 18;
const int mod = 1000000933;   // 1e9 + 933
const int mod1 = 1000000993;  // 1e9 + 993

int n;                // number of nodes in the current test case
vector<int> e[maxn];  // e[u]: children of node u (nodes are 0-indexed)
ll q[maxn];           // q[u]: best downward sum of subtree sizes from u (set by dfs)
int sz[maxn];         // sz[u]: subtree size of u (set by dfs)
 
// Fills sz[] and q[] for every node in the subtree of v (child lists in e[]):
// sz[u] is u's subtree size and q[u] is the largest sum of subtree sizes along
// a downward path starting at u. Iterative, so a long chain cannot overflow
// the call stack.
void dfs(int v) {
    // BFS order of the subtree: every node appears after its parent.
    vector<int> order(1, v);
    for (size_t i = 0; i < order.size(); ++i)
        for (int child : e[order[i]])
            order.push_back(child);

    // Walking the order backwards finishes all children before their parent.
    for (auto it = order.rbegin(); it != order.rend(); ++it) {
        const int u = *it;
        sz[u] = 1;
        q[u] = 0;
        for (int child : e[u]) {
            setmax(q[u], q[child]);
            sz[u] += sz[child];
        }
        q[u] += sz[u];
    }
}
 
int main() {
#ifdef ONPC
    freopen("../a.in", "r", stdin);
    freopen("../a.out", "w", stdout);
#endif // ONPC
    // Unsynchronised, untied C++ streams for fast I/O.
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    cout.tie(nullptr);

    int tests;
    cin >> tests;
    while (tests--) {
        cin >> n;
        // Forget the child lists of the previous test case.
        for (int u = 0; u < n; ++u)
            e[u].clear();
        // Node `child` (0-indexed) has 1-indexed parent `parent`.
        for (int child = 1; child < n; ++child) {
            int parent;
            cin >> parent;
            e[parent - 1].push_back(child);
        }
        dfs(0);
        cout << q[0] << '\n';
    }
}
Editorialist's Solution
#include <bits/stdc++.h>
using namespace std;
 
int N;                       // number of nodes in the current test case
long long ans;               // best root-to-leaf sum found so far (set by dfs)
vector<long long> subtree;   // subtree[u]: size of u's subtree (1-indexed)
vector<vector<int>> graph;   // undirected adjacency lists (1-indexed)

// Sets subtree[u] for every u reachable from `node` without passing through
// `parent`, and returns subtree[node].
long long computeSubtree(int node, int parent) {
	long long total = 1;  // the node itself
	for (int next : graph[node])
		if (next != parent)
			total += computeSubtree(next, node);
	subtree[node] = total;
	return total;
}

// Walks down from `node` (came from `parent`) carrying ancestor_sum, the sum of
// subtree sizes of the strict ancestors of `node`. At every leaf the full
// root-to-leaf sum is compared against the global `ans`.
void dfs(int node, int parent, long long ancestor_sum) {
	const long long path_sum = ancestor_sum + subtree[node];
	bool is_leaf = true;
	for (int next : graph[node]) {
		if (next != parent) {
			is_leaf = false;
			dfs(next, node, path_sum);
		}
	}
	// Checking only at leaves is not strictly needed: subtree sizes are
	// positive, so a leaf's path_sum always beats those of its ancestors.
	if (is_leaf)
		ans = max(ans, path_sum);
}
 
// Handles one test case: reads the tree, computes subtree sizes, and prints
// the best root-to-leaf sum of subtree sizes.
void Solve() {
	cin >> N;
	ans = 0;
	graph.clear();
	subtree.assign(N+1, 0);
	graph.resize(N+1);
	// The input lists the parent of every node 2..N; store edges both ways.
	for (int child = 2; child <= N; child++) {
		int parent;
		cin >> parent;
		graph[child].push_back(parent);
		graph[parent].push_back(child);
	}
	computeSubtree(1, 0);
	dfs(1, 1, 0);
	cout << ans << "\n";
}
 
int main() {
	// Fast, unsynchronised C++ streams.
	ios_base::sync_with_stdio(false);
	cin.tie(nullptr);
	cout.tie(nullptr);

	int test_case = 1;
	cin >> test_case;
	while (test_case-- > 0)
		Solve();

	return 0;
}

VIDEO EDITORIAL (Hindi):

VIDEO EDITORIAL (English):

Feel free to share your approach. If you have any doubts or anything is unclear, please ask in the comment section. Any suggestions are welcome. :smile:

8 Likes

great contest btw

5 Likes

Can anyone help me find what I have done wrong? My approach is:
For each child of the root:
1) Do a DFS, and during the DFS calculate the subtree size of each node and also find the bottommost node in that subtree.
2) Traverse bottom-up from the bottommost node of the subtree to the root of that subtree, adding the subtree size of each visited node to a variable sum.
3) Add n (the number of vertices) to sum.
4) Take the maximum of all the sums obtained by repeating steps 1 to 3 for each child of node 1.
Link to my solution
https://www.codechef.com/viewsolution/39244783
Thanks in advance…

  1. Count the subtree size of each node.
  2. For each node, insert the pairs {subtree size, node} of its children and sort them in descending order.
  3. Start adding the MEX from the root and move to the child node with the greater subtree size until we reach a leaf node.

What’s wrong with this approach?
Link: https://www.codechef.com/viewsolution/39242889

Hi, can we try to find the mistakes in each other’s solutions? I’ll find yours… you find mine.

okay,I’ll try!

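Hey, I have found yours…
Try this input:
n = 13
1 2 3 4 1 6 6 6 6 6 6 6 (the parents of nodes 2…13)
The output should be 23, but yours is 22. Node 6 has the larger subtree (8 nodes, versus 4 for node 2), so always moving to the larger child gives 13 + 8 + 1 = 22, while the path 1–2–3–4–5 gives 13 + 4 + 3 + 2 + 1 = 23.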

2 Likes

@bootcoder

I did dp on trees.

// Tree DP. cnt[u] = size of u's subtree; dp[u] = largest sum of subtree sizes
// along a downward path starting at u (the answer is dp[1]):
//     dp[u] = cnt[u] + max(dp[child]), with the max taken as 0 for a leaf.
// dp[u] is computed from u's children directly instead of detecting leaves via
// adj[u].size() == 1 && u != 1. That test never set dp for a lone root (N = 1,
// where the answer is 1). dp[u] is assigned rather than max-ed into, so it does
// not depend on dp[] having been zeroed before the call.
void dfs(ll u, ll p)
{
    par[u] = p;
    cnt[u] = 1;
    for (int v : adj[u])
    {
        if (v != p)
        {
            dfs(v, u);
            cnt[u] += cnt[v];
        }
    }
    ll best_child = 0;
    for (int v : adj[u])
        if (v != p && dp[v] > best_child)
            best_child = dp[v];
    dp[u] = cnt[u] + best_child;
}

$dp_i$ is the answer for the subtree of $i$ (including $i$ itself).
So $dp_u = \max_{v \in \text{children}(u)} dp_v + \text{size}(u)$, where the maximum is taken as $0$ when $u$ is a leaf.

5 Likes

thanks!

Hello, can anyone tell me what I’m doing wrong?
My solution
What the code does: it calculates how many vertices there are in the subtree of each node, then finds the vertex with maximum depth and goes up the tree from that vertex until it reaches vertex 1.
At every vertex visited, it adds the number of vertices in that vertex's subtree to the answer.
Why I thought this was correct:
```
    1
   / \
  2   3
 / \
4   5
```
If we assign 0 to 5, 1 to 4, 2 to 2, 3 to 3 and 4 to 1, the answer will be 1 + 3 + 5 = 9.
Please help me ;-;

1 Like

My solution is:
1. Calculate the maximum subtree size on every level; let it be maxSiz[i] for level i (e.g. maxSiz[1] = n).
2. Sum maxSiz[i] over all levels; the result is the answer.

But it gives a wrong answer, and I can’t think of any test case that fails. Please point out the mistake, or just provide a test case.

My solution

When I submitted in C++ (https://www.codechef.com/viewsolution/39249621) it passed, but earlier I submitted the same code in Python (https://www.codechef.com/viewsolution/39247646) and got a runtime error. I don’t know why;
please help.

Try this test case; your code gives a wrong answer on it.

3 Likes

How much time did your solution take (execution time)?
I used DFS twice and it took 0.20 s.
https://www.codechef.com/viewsolution/39250093

1 Like

You should assign 1 to the leaf nodes and then calculate the subtree sizes.
This is my submission link; check it out:
https://www.codechef.com/viewsolution/39250093

2 Likes

0.08s.

2 Likes

Nice, I’m studying your solution.

1 Like

There’s some unnecessary stuff like par that I forgot to remove, so don’t confuse yourself.

2 Likes

use dfs twice

1 Like