CHGORAM2 - Editorial

Can you prove the time complexity of your solution?

If you prefer a video explanation: CHGORAM Video Solution - February Long Challenge 2020

PROBLEM LINK:

Practice
Div-1 Contest

Author: Simon St James
Tester: Radoslav Dimitrov
Editorialist: William Lin

DIFFICULTY:

Medium-Hard

PREREQUISITES:

Trees, Observations, Dynamic Programming

PROBLEM:

We are given a tree in which some of the nodes are marked. Find the number of triples of distinct nodes such that all three nodes are marked and the three pairwise distances within the triple are equal.
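To make the statement concrete, here is a brute-force reference that is only meant for tiny trees and for cross-checking faster solutions. It is a sketch under assumptions that this editorial does not state: nodes are 0-indexed, the tree is an adjacency list, and the result counts ordered triples (the unordered count times 3!), mirroring the numTripletPermutations factor in the setter's code below. The function names are hypothetical.

```cpp
#include <cassert>
#include <cstdint>
#include <queue>
#include <vector>

// Distances (in edges) from src to every node, via BFS.
std::vector<int> bfsDistances(const std::vector<std::vector<int>>& adj, int src) {
    std::vector<int> dist(adj.size(), -1);
    std::queue<int> q;
    dist[src] = 0;
    q.push(src);
    while (!q.empty()) {
        int u = q.front();
        q.pop();
        for (int v : adj[u]) {
            if (dist[v] == -1) {
                dist[v] = dist[u] + 1;
                q.push(v);
            }
        }
    }
    return dist;
}

// O(n^3) check of every triple of distinct marked nodes.
int64_t bruteForceOrderedTriples(const std::vector<std::vector<int>>& adj,
                                 const std::vector<bool>& marked) {
    const int n = static_cast<int>(adj.size());
    std::vector<std::vector<int>> dist(n);
    for (int s = 0; s < n; ++s) dist[s] = bfsDistances(adj, s);
    int64_t unordered = 0;
    for (int a = 0; a < n; ++a)
        for (int b = a + 1; b < n; ++b)
            for (int c = b + 1; c < n; ++c)
                if (marked[a] && marked[b] && marked[c] &&
                    dist[a][b] == dist[b][c] && dist[a][b] == dist[a][c])
                    ++unordered;
    return unordered * 6;  // 3! orderings of each unordered triple (assumed output format)
}
```

For a star with centre 0 and marked leaves 1, 2, 3, the three leaves are pairwise at distance 2, so the function returns 6.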

QUICK EXPLANATION:

For any such triple, there exists a center c such that the distances from c to all three nodes of the triple are equal.
For each node u, count the number of triples whose LCA is u. There are two cases: when u=c and when u \ne c.
We can calculate dp1_{u,i}, the number of marked nodes in u's subtree at distance i from u, and dp2_{u, i}, the number of pairs of marked nodes (a, b) with the same depth in u's subtree which satisfy (depth(a)-depth(lca(a, b)))-(depth(lca(a, b))-depth(u))=i.
If merging each child subtree into u takes O(k) time, where k is the smaller of the two depths being merged, then the total time complexity is guaranteed to be O(N).

EXPLANATION:

Observation 1. There exists a node c such that c is equidistant to all nodes in the triple.

Proof

On the path between two of the nodes, there is a node c at which the path to the third node branches off:

We can set variables for the distances from c to each of the three nodes and write a system of linear equations:
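Writing the triple as x, y, z (names introduced here only for illustration), let c be the node where the path from z meets the path between x and y; because c is the first point where z's path reaches the x-y path, c lies on all three pairwise paths. With p, q, r the distances from c to x, y, z and D the common pairwise distance, the system is:

```latex
\begin{aligned}
\mathrm{dist}(x, y) &= p + q = D \\
\mathrm{dist}(x, z) &= p + r = D \\
\mathrm{dist}(y, z) &= q + r = D
\end{aligned}
```

Subtracting the equations pairwise gives q = r and p = r, so p = q = r = D/2; in particular, D must be even.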

Solving the system of linear equations shows that c is equidistant to all three nodes.

Let’s root the tree at an arbitrary node and perform a DFS. For each node u that we visit, let’s find the number of triples whose lowest common ancestor is u. For example, this is one such triple:

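Notice that we can split the component containing the triple into a leg and a fork, as shown below. The leg is a path of length a going down from u to one node of the triple; the fork is a path of length b going down from u to some node w, which then splits into two paths of length d, each ending at one of the other two nodes.

Note that in order for a leg to pair with a fork, a+b=d (i.e. a=d-b) must hold; w is then at distance d from all three nodes. So if we can somehow count the legs and forks coming out of u, we can find the answer for u!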

This suggests the following DP: dp1_{u, i} is the number of legs with a=i which start at node u (that is, marked nodes in u's subtree at distance i from u), and dp2_{u, i} is the number of forks with d-b=i which start at node u.

We need to consider how we can perform the transitions to find dp1_u and dp2_u. The following procedure will be useful for this problem:

  • We start with dp1_u and dp2_u representing the single node u: dp1_{u, 0}=1 if u is marked (a leg with a=0) and 0 otherwise, and dp2_u is all zeros.
  • Let v_1, v_2, \dots, v_k be the children of u. Let’s merge dp1_{v_1} and dp2_{v_1} to dp1_u and dp2_u somehow, so that 1. the answer will be updated with new triples which form from the union of u and subtree v_1 and 2. dp1_u and dp2_u will represent the union of the node u and the subtree v_1.
  • For each subsequent child v_i (i=2, \dots, k), let’s merge dp1_{v_i} and dp2_{v_i} into dp1_u and dp2_u somehow, so that 1. the answer is updated with the new triples formed from the union of the part of u's subtree processed so far and the subtree v_i, and 2. dp1_u and dp2_u represent the union of the node u and the first i subtrees.
  • At the end, dp1_u and dp2_u will represent the entire subtree u.

So now, we need to focus on how we can merge dp1_v and dp2_v with dp1_u and dp2_u.

A. First, because v is one level deeper than u, we need to shift dp1_v and dp2_v. All legs from v get one edge longer (a increases by 1), so the new dp1_{v, j+1} equals the old dp1_{v, j}, and the new dp1_{v, 0} is 0. All forks from v have b increased by 1, so d-b decreases by 1: the new dp2_{v, j-1} equals the old dp2_{v, j}, and the old dp2_{v, 0} is simply deleted, since a negative d-b can never be matched by a leg.

B. Next, we need to count the triples which form when merging v into u and add them to the total answer. Each such triple consists of either a leg from v and a fork from u, or a fork from v and a leg from u. So, for each j, we add dp1_{v, j}\cdot dp2_{u, j}+dp2_{v, j}\cdot dp1_{u, j} to the total answer.

C. Next, merging v into u creates new forks that branch at u itself (b=0). Each consists of a leg from v and a leg from u of the same length, so for each j, we add dp1_{v, j}\cdot dp1_{u, j} to dp2_{u, j}.

D. Lastly, when v becomes part of subtree u, all legs and forks from v should be added to u. So, for each j, we will add dp1_{v, j} to dp1_{u, j} and dp2_{v, j} to dp2_{u, j}.
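Steps A-D translate almost line by line into code. The sketch below is the straightforward version, with both arrays stored front to back (index j at position j), so the shifts in step A cost linear time and every merge touches the full length. It assumes 0-indexed nodes, an adjacency list, and multiplies by 3! at the end as the setter's code does; all names are hypothetical.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// dp1[j]: legs with a = j (marked nodes j edges below u).
// dp2[j]: forks with d - b = j.
struct NaiveDp {
    std::vector<int64_t> dp1, dp2;
};

// Processes u's subtree and adds every triple whose LCA is u to `unordered`.
NaiveDp naiveSolve(int u, int parent, const std::vector<std::vector<int>>& adj,
                   const std::vector<bool>& marked, int64_t& unordered) {
    NaiveDp cur;
    cur.dp1.push_back(marked[u] ? 1 : 0);  // the single node u: a leg with a = 0 if marked
    cur.dp2.push_back(0);                  // no forks yet
    for (int c : adj[u]) {
        if (c == parent) continue;
        NaiveDp v = naiveSolve(c, u, adj, marked, unordered);
        // Step A: legs get one edge longer; forks get b + 1, so d - b drops by one.
        v.dp1.insert(v.dp1.begin(), int64_t{0});
        if (!v.dp2.empty()) v.dp2.erase(v.dp2.begin());
        // Pad all four arrays to a common length so every index can be read safely.
        const size_t len = std::max({cur.dp1.size(), cur.dp2.size(), v.dp1.size(), v.dp2.size()});
        cur.dp1.resize(len);
        cur.dp2.resize(len);
        v.dp1.resize(len);
        v.dp2.resize(len);
        for (size_t j = 0; j < len; ++j)  // Step B: leg from v + fork from u, and vice versa
            unordered += v.dp1[j] * cur.dp2[j] + v.dp2[j] * cur.dp1[j];
        for (size_t j = 0; j < len; ++j)  // Step C: leg from v + leg from u = new fork at u
            cur.dp2[j] += v.dp1[j] * cur.dp1[j];
        for (size_t j = 0; j < len; ++j) {  // Step D: v's legs and forks now belong to u
            cur.dp1[j] += v.dp1[j];
            cur.dp2[j] += v.dp2[j];
        }
    }
    return cur;
}

int64_t naiveCountOrderedTriples(const std::vector<std::vector<int>>& adj,
                                 const std::vector<bool>& marked) {
    int64_t unordered = 0;
    if (!adj.empty()) naiveSolve(0, -1, adj, marked, unordered);
    return unordered * 6;  // 3! orderings per unordered triple (assumed output format)
}
```

Step B runs before step C on purpose: otherwise the forks just created from a leg of v would be paired with other legs of v, whose LCA is not u.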

This DP solution works in O(n^2) time. Let’s try to optimize it.

Let d be the smaller of two depths: the depth of subtree v_i (counting the edge to its parent) and the depth of the part of u’s subtree processed so far (the union of u with its first i-1 child subtrees).

For step A, let’s view dp1_v and dp2_v both as lists. The transition on dp1_v inserts a 0 at the front of the list, and the transition on dp2_v removes the first element of the list. On an array, both are linear-time operations, unless we store the list in reverse, because adding or removing at the end of an array takes (amortized) constant time. So that’s what we’ll do.

For steps B and C, we only have to iterate j from 0 to d. Why don’t we need to go beyond d? Any j>d exceeds the depth of one of the two parts, so that part has no legs or forks with index j: its dp1 and dp2 entries at j are both 0. In steps B and C, every term is a product of an entry of dp_v and an entry of dp_u with the same index j, so all terms with j>d evaluate to 0. This makes steps B and C run in O(d) time.

For step D, if v has the smaller depth, we add dp_v into dp_u directly, which only touches indices up to d. Otherwise, we add dp_u into dp_v instead and let dp_v replace dp_u (swapping two arrays takes O(1) time). In both cases, we only iterate j up to d, which makes step D run in O(d) time.

In steps B, C, and D, we applied what’s known as the merge-small-to-large trick on depths: we find the smaller depth d of the two parts before merging them, then merge the shallower part into the deeper one, so the merge takes O(d) time.

Observation 2. The merge-small-to-large trick on depths optimizes the solution to O(n) time.

Proof

Whenever we merge v into u, find the part with the minimum depth (remember, for v we count its parent edge, and for u we only consider the union of u with its first i-1 child subtrees). If both parts have the same depth, choose the one whose deepest path ends at the leaf with the greatest id. Then, in that part, color the deepest path, and if there are several, color the one ending at the leaf with the smallest id. This colors exactly d edges, so the total number of edges that we color matches the time complexity of our solution.

Below are some examples:

In the example above, subtree u has the minimum depth, so we color the deepest path of u (which is nothing).

In the example above, we merge another child subtree v to u. v has the minimum depth, so we color its deepest path.

In the example above, we finish processing u, and now it becomes a child subtree v to be merged with its parent u. We can see that v has the minimum depth, so we color its deepest path.

We could try more cases, but it seems that no edge is ever colored twice!

What can we say about a path when we color it? It lies in the part with the minimum depth, so once the two parts are merged, the colored path can no longer be the deepest path (ties are broken by leaf id as described above). Since we only ever color deepest paths, an edge that is already colored is never colored again!

A tree has n-1 edges, so if no edge is colored twice, the merges color at most n-1 edges in total. Each merge costs O(d) plus a constant, and there are n-1 merges, so our solution works in O(n) time.
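As an empirical sanity check of Observation 2 (not a proof), the sketch below adds up d over every merge, using only subtree heights, so the bound "at most n-1" can be checked on small trees. The function names are hypothetical and the tree is assumed to be rooted at node 0.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Returns the height (in edges) of u's subtree and adds, for every merge at u,
// d = min(depth of the processed part, depth of the child including its parent edge).
int heightAndMergeCost(int u, int parent, const std::vector<std::vector<int>>& adj,
                       int64_t& totalCost) {
    int processed = 0;  // u alone has depth 0
    for (int c : adj[u]) {
        if (c == parent) continue;
        const int childDepth = heightAndMergeCost(c, u, adj, totalCost) + 1;  // include edge (u, c)
        totalCost += std::min(processed, childDepth);
        processed = std::max(processed, childDepth);
    }
    return processed;
}

int64_t totalMergeCost(const std::vector<std::vector<int>>& adj) {
    int64_t total = 0;
    if (!adj.empty()) heightAndMergeCost(0, -1, adj, total);
    return total;
}
```

On a star with four leaves the first merge costs 0 and the other three cost 1 each, for a total of 3 = n-2; on a path rooted at one end every merge costs 0.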

RELATED PROBLEMS:

Merge-small-to-large:
https://codeforces.com/problemset/problem/1193/B
https://agc007.contest.atcoder.jp/tasks/agc007_e

SOLUTIONS:

Setter's Solution
// Simon St James (ssjgz) - Setter's solution for "Equilateral Treeangles" - 2019-07-17.
// Problem was later renamed to "Chef and Gordon Ramsay 2" when I saw CHGORAM in AUG19
// and noticed how my Problem could be turned into a sequel to it :)
#include <iostream>
#include <vector>
#include <map>
#include <algorithm>

#include <cassert>
#include <cstdint>

using namespace std;

const int numTripletPermutations = 1 * 2 * 3; // i.e. == factorial(3).

struct DescendantWithHeightInfo
{
    // The number of suitable pairs of descendants of Node with the given height
    // that have Node as their Lowest Common Ancestor.
    int64_t numPairsWithHeightWithNodeAsLCA = 0;
    // The name "number" is not strictly correct: if numPairsWithHeightWithNodeAsLCA > 0, then this
    // will indeed be the number of suitable descendants of this Node which have this height,
    // but if numPairsWithHeightWithNodeAsLCA == 0, "number" will also be 0.
    // 
    // The "number" count is used to avoid the overcount when performing completeTripletsOfTypeACentroidDecomposition
    // (Centroid Decomposition has no notion of our parent-child relationship, so
    // will count "descendants" of a Node alongside "non-descendants" of a Node.)
    int number = 0;
};

struct Node
{
    vector<Node*> neighbours;
    bool isSuitable = false;
    int height = 0;

    // The total number of entries into descendantWithHeightInfo, summed across *all* nodes,
    // will be O(n log n).
    map<int, DescendantWithHeightInfo> descendantWithHeightInfo;
};

struct HeightInfo
{
    int numWithHeight = 0;
    // Make a note of which Node this info has been incorporated into (i.e. which Node's
    // descendantWithHeightInfo it has been used to update) so we don't accidentally
    // incorporate it into the same Node twice!
    Node* lastIncorporatedIntoNode = nullptr;
};

class DistTracker
{
    public:
        // O(maxDist).
        DistTracker(int maxDist)
            : m_numWithDist(2 * maxDist + 1), 
            m_maxDist(maxDist)
        {
        }
        // O(1).
        void insertDist(const int newDist)
        {
            numWithDistValue(newDist)++;
            m_largestDist = max(m_largestDist, newDist);
        };
        // O(1).
        int numWithDist(int dist)
        {
            return numWithDistValue(dist);
        }
        // O(1).
        void adjustAllDists(int distDiff)
        {
            m_cumulativeDistAdjustment += distDiff;
            assert(m_cumulativeDistAdjustment >= 0);
            if (m_largestDist != -1)
                m_largestDist += distDiff;
        }
        // O(1).
        int largestDist() const
        {
            return m_largestDist;
        }
    private:
        int m_cumulativeDistAdjustment = 0;
        vector<int> m_numWithDist;
        int m_maxDist = -1;

        int m_largestDist = -1;

        int& numWithDistValue(int dist)
        {
            return m_numWithDist[dist - m_cumulativeDistAdjustment + m_maxDist];
        }
};

enum DistTrackerAdjustment {DoNotAdjust, AdjustWithDepth};

template <typename NodeProcessor>
void doDfs(Node* node, Node* parentNode, int depth, DistTracker& distTracker, DistTrackerAdjustment distTrackerAdjustment, NodeProcessor& processNode)
{
    if (distTrackerAdjustment == AdjustWithDepth)
        distTracker.adjustAllDists(1);

    processNode(node, depth, distTracker);

    for (auto child : node->neighbours)
    {
        if (child == parentNode)
            continue;
        doDfs(child, node, depth + 1, distTracker, distTrackerAdjustment, processNode);
    }

    if (distTrackerAdjustment == AdjustWithDepth)
        distTracker.adjustAllDists(-1);
}

int countDescendants(Node* node, Node* parentNode)
{
    auto numDescendants = 1; // Current node.

    for (const auto& child : node->neighbours)
    {
        if (child == parentNode)
            continue;

        numDescendants += countDescendants(child, node);
    }

    return numDescendants;
}

int findCentroidAux(Node* currentNode, Node* parentNode, const int totalNodes, Node** centroid)
{
    auto numDescendants = 1;

    auto childHasTooManyDescendants = false;

    for (const auto& child : currentNode->neighbours)
    {
        if (child == parentNode)
            continue;

        const auto numChildDescendants = findCentroidAux(child, currentNode, totalNodes, centroid);
        if (numChildDescendants > totalNodes / 2)
        {
            // Not the centroid, but can't break here - must continue processing children.
            childHasTooManyDescendants = true;
        }

        numDescendants += numChildDescendants;
    }

    if (!childHasTooManyDescendants)
    {
        // No child has more than totalNodes/2 descendants, but what about the remainder of the graph?
        const auto nonChildDescendants = totalNodes - numDescendants;
        if (nonChildDescendants <= totalNodes / 2)
        {
            assert(centroid);
            *centroid = currentNode;
        }
    }

    return numDescendants;
}

Node* findCentroid(Node* startNode)
{
    const auto totalNumNodes = countDescendants(startNode, nullptr);
    Node* centroid = nullptr;
    findCentroidAux(startNode, nullptr, totalNumNodes, &centroid);
    assert(centroid);
    return centroid;
}

void doCentroidDecomposition(Node* startNode, int64_t& numTriplets)
{
    Node* centroid = findCentroid(startNode);

    auto addSomeTypeATripletsForNode = [&numTriplets](Node* node, DistTracker& distTracker)
    {
        // This will be called O(log2 n) times for each node before that node's
        // Type A Triplets are fully completed.
        for (const auto& heightPair : node->descendantWithHeightInfo)
        {
            const auto descendantHeight = heightPair.first;
            const auto requiredNonDescendantDist = (descendantHeight - node->height);
            if (requiredNonDescendantDist > distTracker.largestDist())
                break; // Optimisation - no point continuing with larger descendantHeights.

            const auto numPairsWithHeightWithNodeAsLCA = heightPair.second.numPairsWithHeightWithNodeAsLCA;
            const auto numNewTriplets = numPairsWithHeightWithNodeAsLCA * distTracker.numWithDist(requiredNonDescendantDist) * numTripletPermutations;
            assert(numNewTriplets >= 0);
            numTriplets += numNewTriplets;
        }
    };
    auto propagateDists = [&addSomeTypeATripletsForNode](Node* node, int depth, DistTracker& distTracker)
                        {
                            addSomeTypeATripletsForNode(node, distTracker);
                        };
    auto collectDists = [](Node* node, int depth, DistTracker& distTracker)
                        {
                            if (node->isSuitable)
                                distTracker.insertDist(depth);
                        };

    const auto numNodesInComponent = countDescendants(startNode, nullptr);

    {
        DistTracker distTracker(numNodesInComponent);
        for (auto& child : centroid->neighbours)
        {
            doDfs(child, centroid, 1, distTracker, AdjustWithDepth, propagateDists );
            doDfs(child, centroid, 1, distTracker, DoNotAdjust, collectDists );
        }
    }
    {
        DistTracker distTracker(numNodesInComponent);
        // Do it again, this time backwards ...
        reverse(centroid->neighbours.begin(), centroid->neighbours.end());
        // ... and also include the centre, this time.
        if (centroid->isSuitable)
            distTracker.insertDist(0);
        for (auto& child : centroid->neighbours)
        {
            doDfs(child, centroid, 1, distTracker, AdjustWithDepth, propagateDists );
            doDfs(child, centroid, 1, distTracker, DoNotAdjust, collectDists );
        }
        addSomeTypeATripletsForNode(centroid, distTracker);
    }

    for (auto& neighbour : centroid->neighbours)
    {
        assert(std::find(neighbour->neighbours.begin(), neighbour->neighbours.end(), centroid) != neighbour->neighbours.end());
        // Erase the edge from the centroid's neighbour to the centroid, essentially "chopping off" each child into its own
        // component ...
        neighbour->neighbours.erase(std::find(neighbour->neighbours.begin(), neighbour->neighbours.end(), centroid));
        // ... and recurse.
        doCentroidDecomposition(neighbour, numTriplets);
    }
}

void completeTripletsOfTypeACentroidDecomposition(vector<Node>& nodes, Node* rootNode, int64_t& numTriplets)
{
    doCentroidDecomposition(rootNode, numTriplets);
    // Fix the overcount caused by Centroid Decomposition (over-)counting descendants of a node as non-descendants
    // of a node!
    for (auto& node : nodes)
    {
        for (const auto descendantHeightPair : node.descendantWithHeightInfo)
        {
            const auto height = descendantHeightPair.first;
            const auto numPairsWithHeightWithNodeAsLCA = descendantHeightPair.second.numPairsWithHeightWithNodeAsLCA;

            // Centroid decomposition would have (wrongly) added numPairsWithHeightWithNodeAsLCA[height] * numTripletPermutations 
            // for each suitable descendant of node with height "height" - correct for this.
            numTriplets -= numPairsWithHeightWithNodeAsLCA * node.descendantWithHeightInfo[height].number * numTripletPermutations;
        }
    }
}

map<int, HeightInfo> buildDescendantHeightInfo(Node* currentNode, Node* parentNode, int height, int64_t& numTriplets)
{
    currentNode->height = height;

    map<int, HeightInfo> persistentInfoForDescendantHeight;

    for (auto child : currentNode->neighbours)
    {
        if (child == parentNode)
            continue;
        // Quick C++ performance note: in C++11 onwards, capturing a returned std::map
        // in a local variable is O(1), due to Move Semantics.  Prior to this, though,
        // it could have been O(size of std::map) (if the Return Value Optimisation ended up
        // not being used), which would (silently!) lead to asymptotically worse performance!
        //
        // Luckily, this code uses C++11 features so we can't accidentally fall into this trap.
        auto transientInfoForDescendantHeight = buildDescendantHeightInfo(child, currentNode, height + 1, numTriplets);
        if (transientInfoForDescendantHeight.size() > persistentInfoForDescendantHeight.size())
        {
            // We'll be copying transientInfoForDescendantHeight into persistentInfoForDescendantHeight.
            // Ensure that the former is smaller than the latter so that we can make use of the Small-to-Large
            // trick. NB: std::swap'ing is O(1).
            swap(persistentInfoForDescendantHeight, transientInfoForDescendantHeight);
        }

        for (auto transientDescendantHeightPair : transientInfoForDescendantHeight)
        {
            // This block of code (i.e. the body of the containing for... loop) 
            // is executed O(n log2 n) times over the whole run.
            // It is guaranteed to be executed with descendantHeight if the current
            // child has a descendant with height descendantHeight that isSuitable and a previous child of this
            // node also has a descendant with height descendantHeight that isSuitable, but may also
            // be executed under different circumstances.
            //
            // Since this block of code adds at most one entry into currentNode's descendantWithHeightInfo member,
            // the sum of node->descendantWithHeightInfo.size() over all nodes is O(n log2 n).
            const auto descendantHeight = transientDescendantHeightPair.first;

            const auto& transientHeightInfo = transientDescendantHeightPair.second;
            auto& heightInfoForNode = persistentInfoForDescendantHeight[descendantHeight];

            assert (descendantHeight > currentNode->height);

            auto numUnprocessedDescendantsWithHeight = -1;
            auto numKnownDescendantsWithHeight = -1;
            assert(transientHeightInfo.lastIncorporatedIntoNode != nullptr);
            if (transientHeightInfo.lastIncorporatedIntoNode == currentNode)
            {
                assert(heightInfoForNode.lastIncorporatedIntoNode != currentNode);
                numUnprocessedDescendantsWithHeight = heightInfoForNode.numWithHeight;
                numKnownDescendantsWithHeight = transientHeightInfo.numWithHeight;
            }
            else
            {
                assert(transientHeightInfo.lastIncorporatedIntoNode != currentNode);
                numUnprocessedDescendantsWithHeight = transientHeightInfo.numWithHeight;
                numKnownDescendantsWithHeight = heightInfoForNode.numWithHeight;
            }

            const auto earlierChildHasThisHeight = (numKnownDescendantsWithHeight > 0);
            if (earlierChildHasThisHeight)
            {
                // Incorporate any un-incorporated HeightInfo into this Node's descendantWithHeightInfo.
                auto& descendantHeightInfo = currentNode->descendantWithHeightInfo[descendantHeight];
                auto& numPairsWithHeightWithNodeAsLCA = descendantHeightInfo.numPairsWithHeightWithNodeAsLCA;
                auto& numberDescendantsWithThisHeight = descendantHeightInfo.number;

                if (numUnprocessedDescendantsWithHeight * numPairsWithHeightWithNodeAsLCA > 0)
                {
                    // Found a triple where all three nodes have currentNode as their LCA: a "Type B" triple.
                    const auto numNewTriplets = numPairsWithHeightWithNodeAsLCA * numUnprocessedDescendantsWithHeight * numTripletPermutations;
                    assert(numNewTriplets >= 0);
                    numTriplets += numNewTriplets;
                }

                // These numPairsWithHeightWithNodeAsLCA would, when combined with a non-descendant of currentNode that isSuitable and is
                // (descendantHeight - currentNode->height) distance away from currentNode, form a "Type A" triple.
                // We store numPairsWithHeightWithNodeAsLCA for this descendantHeight inside currentNode: the required non-descendants of
                // currentNode will be found by completeTripletsOfTypeACentroidDecomposition() later on.
                numPairsWithHeightWithNodeAsLCA += numUnprocessedDescendantsWithHeight * numKnownDescendantsWithHeight;

                if (numberDescendantsWithThisHeight == 0)
                {
                    // This hasn't been updated yet, so has missed the numKnownDescendantsWithHeight; incorporate it now.
                    numberDescendantsWithThisHeight += numKnownDescendantsWithHeight;
                }
                numberDescendantsWithThisHeight += numUnprocessedDescendantsWithHeight;

            }

            // "Copy" the transient info into persistent info, and make a note that this HeightInfo has been incorporated
            // into currentNode.
            heightInfoForNode.numWithHeight += transientHeightInfo.numWithHeight;
            heightInfoForNode.lastIncorporatedIntoNode = currentNode;
        }
    }

    if (currentNode->isSuitable)
    {
        persistentInfoForDescendantHeight[currentNode->height].numWithHeight++;
        persistentInfoForDescendantHeight[currentNode->height].lastIncorporatedIntoNode = currentNode;
    }

    return persistentInfoForDescendantHeight;
}

int64_t findNumTriplets(vector<Node>& nodes)
{
    int64_t numTriplets = 0;

    auto rootNode = &(nodes.front());

    // Fills in numPairsWithHeightWithNodeAsLCA for each node, and 
    // additionally counts all "Type B" triples and adds them to results.
    buildDescendantHeightInfo(rootNode, nullptr, 0, numTriplets);

    // Finishes off the computation of the number of "Type A" triples
    // that we began in buildDescendantHeightInfo.
    completeTripletsOfTypeACentroidDecomposition(nodes, rootNode, numTriplets);

    return numTriplets;
}

template <typename T>
T read()
{
    T toRead;
    cin >> toRead;
    assert(cin);
    return toRead;
}

int main(int argc, char* argv[])
{
    // This solution is the first one that occurred to me (well, ish - the first one
    // didn't use Centroid Decomposition but a technique with similar capabilities
    // based around Heavy-Light Decomposition which seems a little faster in practice,
    // and was a little simpler for this Problem).  Based on the solutions of the Tester
    // (and basically everyone who got 100pts on this Problem during the Contest :)),
    // it's become clear that this is an unusually clunky and fiddly way of doing things,
    // so I'm only going to give a brief overview of it, here - the official Editorial 
    // solution will doubtless be better to learn from :) This explanation is adapted from
    // the "Brief (ha!) Description" that is required when submitting a new Problem Idea.
    //
    // Anyway, if you're still reading (which you shouldn't be :p), here goes:
    //
    // Imagine we pick some arbitrary node (I pick node number 1) as the root R of the tree
    // and do a DFS from there.  Then it can be shown that any triple (p, q, r) will be
    // of one of the following two types:
    //
    //
    // Type A - p, q, r have the same lca, from which they are equidistant e.g.
    //
    //                        R
    //                       /|\
    //                        ...
    //                           X = lca(p, q, r)
    //                         / | \
    //                        .......
    //                       /   |   \
    //                      p    q    r   <-- dist(p, x) = dist(q, x) = dist(r, x); 
    //                                        p, q and r all have same distance from X,
    //                                        and so the same distance from R (aka "height").
    //
    // Type B - the two "lowest" nodes - say q and r wlog, are equidistant from their lca = X.  
    // p is *not* a descendent of X, but dist(p, X) = dist(q, X) (= dist(r, X)) e.g.
    //
    //                        R
    //                       /|\
    //                        ...
    //                       /\
    //                      ... ...
    //                     /    \
    //                    p     ...
    //                           |
    //                           X = lca(q, r)
    //                         /   \
    //                        .......
    //                       /       \
    //                      q         r   <-- dist(p, X) = dist(r, X); p, q have the same distance from R.
    //
    // Type A is the easiest to compute; Type B is harder and is computed in two phases: the 
    // first phase is shared with the computation of the number of Type A triplets; the 
    // second is separate and uses Centroid Decomposition.
    // 
    // There's a reasonably well known algorithm for calculating, for each node v, the set of 
    // all descendents of v in O(N log N):
    // 
    //    findDescendents(root)
    //        set_of_descendents = empty-set
    //        for each child c of root:
    //            set_of_descendents_of_child = findDescendents(c)
    //            if |set_of_descendents_of_child| > |set_of_descendents|:
    //                swap set_of_descendents with set_of_descendents_of_child (O(1)).
    //    
    //            for each node in set_of_descendents_of_child:
    //                add node to set_of_descendents
    //    
    //        return set_of_descendents
    // 
    // The algorithm looks like it's O(N^2) worst case, but the fact that we always "copy" 
    // the smaller set into the larger actually makes it asymptotically better (O(N log N)), 
    // for the same reason as disjoint-union-by-size is O(N log N): 
    // https://en.wikipedia.org/wiki/Disjoint-set_data_structure#by_size
    // I've also seen this technique referred to as the "Small-To-Large Trick".
    // 
    // For a node v, let height(v) = dist(R, v).  We adapt the algorithm to compute not the 
    // set of descendents of each node v, but a count (map) of *heights of descendents of 
    // suitable nodes of v* i.e. a map where the keys are heights and the values are the 
    // number of descendents of v which are suitable and have that height; it can be shown 
    // that this also is achievable in worst case O(N log N).
    // 
    // With some more book-keeping, it can be shown that we can extract, in O(N log N) time 
    // and space, for all nodes v, for all heights h [NB: obviously such heights are stored 
    // sparsely, else the space requirements would be O(N^2)]:
    // 
    //     the number of pairs of descendents (u', v') of v such that:
    //         * u' and v' are both suitable;
    //         * lca(u', v') = v; and
    //         * height(u') = height(v') = h (and so: dist(u', v) = dist(v', v))
    // 
    // (this is stored as descendantWithHeightInfo.numPairsWithHeightWithNodeAsLCA) and also
    // 
    //     the number of triples (u', v', w') of v such that
    //         * u', v' and w' are all suitable;
    //         * lca(u', v') = lca(u', w') = lca(v', w') = v; and
    //         * height(u') = height(v') = height(w') = h (and so: 
    //           dist(u', v) = dist(v', v) = dist(w', v))
    // 
    // This completes Phase one of two.
    // 
    // We see that the latter count of triples is the number of triples of Type A (more 
    // precisely - it is the number of such triples divided by the number of permutations 
    // of a triple, 3! - see later).
    // 
    // The former is a step towards the computation of the number of Type B triples - 
    // referring to the definition of Type B, we've found, for each X, the number of q's 
    // and r's with a given height h (which equates to a distance d from X: height(q) - height(x))
    // and lca X, and now we just need to find for each X and h the number of p's such that 
    // dist(p, X) = height(q) - height(X).  We do this latter step via Centroid Decomposition.
    // 
    // More precisely, Phase one gives us, for each node v, a map from heights to number 
    // of pairs of descendents called v.numPairsWithHeightWithNodeAsLCA - all in O(N log N) 
    // time and space.
    // 
    // We then perform Centroid Decomposition with an efficient DistTracker class that 
    // implements the following API:
    // 
    //     * addDistance(newDistance) - adds the distance to the list of tracked distances
    //                                  in O(1).
    //     * addToAllDistances(distanceIncrease) - adds distanceIncrease to each of the 
    //                                             tracked distances in O(1).
    //     * getNumWithDistance(distance) - returns the number of tracked distances whose
    //                                      current value is precisely distance in O(1).
    // 
    // A simple combination of CD and DistTracker allows us to finally "complete" the "q's" 
    // and "r's" found in Phase one with the "p's" necessary to form a complete, Type B-triplet 
    // in O(N x (log N) x (log N)).
    // 
    // It can be further shown that if a triplet (p, q, r) of either type A or B is 
    // counted by the algorithm, then none of its (3! - 1) other permutations - (p, r, q); 
    // (r, p, q) etc - will be counted, so simply multiplying the count of Type A and 
    // Type B triplets by 3! gives us the final result.
    // 
    // Well - not quite :) The Centroid Decomposition step doesn't know about parent and 
    // children in our original DFS from R, so it will overcount triples.  For example, 
    // consider the simple example:
    //
    //                     R
    //                     |
    //                     X
    //                    / \
    //                   Y   Z
    //
    // Imagine that Y and Z are the only suitable nodes, so there are no valid triples in 
    // this example.  
    //
    // Note that X's descendantWithHeightInfo[2].numPairsWithHeightWithNodeAsLCA will be 1.  
    // The Centroid Decomposition step will then, unfortunately, treat Y as a "completer" 
    // of this pair (Y, Z) (and will treat Z the same way), resulting in it reporting *2*
    // triples, instead of the correct answer of 0.  Luckily, it's easy (though clunky) to 
    // remove this overcount: see DescendantWithHeightInfo::number for more information.  
    // The proper Editorial solution, and the original HLD-based solution, don't have this 
    // clunky "overcount correction" step, which is another reason why they are superior :)
    //
    // The whole algorithm runs in O(N x (log N) x (log N)) time with O(N log N) space.
    ios::sync_with_stdio(false);

    const auto numTestcases = read<int>();

    for (auto t = 0; t < numTestcases; t++)
    {
        const auto numNodes = read<int>();
        assert(1 <= numNodes && numNodes <= 200'000);

        vector<Node> nodes(numNodes);

        for (auto i = 0; i < numNodes - 1; i++)
        {
            const auto u = read<int>();
            const auto v = read<int>();

            assert(1 <= u && u <= numNodes);
            assert(1 <= v && v <= numNodes);

            nodes[u - 1].neighbours.push_back(&(nodes[v - 1]));
            nodes[v - 1].neighbours.push_back(&(nodes[u - 1]));
        }
        for (auto i = 0; i < numNodes; i++)
        {
            const auto isSuitable = read<int>();

            assert(isSuitable == 0 || isSuitable == 1);

            nodes[i].isSuitable = (isSuitable == 1);
        }

        const auto numTriplets = findNumTriplets(nodes);
        cout << numTriplets << endl;
    }

    assert(cin);
}
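The comment block above specifies DistTracker only through its two O(1) operations. Below is a minimal sketch of one way to get them, assuming each distance is stored relative to a shared offset; the class layout and the `insertDistance` helper are my own assumptions, not the setter's code.

```cpp
#include <cassert>
#include <unordered_map>

// Hypothetical sketch of a DistTracker-like structure (not the setter's class).
// Each tracked distance is stored as (value - offset), so adding the same amount
// to every tracked distance only changes offset.
class DistTracker {
public:
    // Assumed helper (not named in the original comments): start tracking a distance.
    void insertDistance(long long distance) {
        assert(distance >= 0);  // tree distances are never negative
        ++countByStored[distance - offset];
    }

    // Adds distanceIncrease to every tracked distance in O(1).
    void addToAllDistances(long long distanceIncrease) { offset += distanceIncrease; }

    // Number of tracked distances whose current value is exactly distance.
    int getNumWithDistance(long long distance) const {
        const auto it = countByStored.find(distance - offset);
        return it == countByStored.end() ? 0 : it->second;
    }

private:
    long long offset = 0;
    std::unordered_map<long long, int> countByStored;
};
```

An `unordered_map` gives O(1) only on average; a plain array over a known, bounded range of stored values would make both operations worst-case O(1).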
Tester's Solution
#include <bits/stdc++.h>
#define endl '\n'
 
#define SZ(x) ((int)x.size())
#define ALL(V) V.begin(), V.end()
#define L_B lower_bound
#define U_B upper_bound
#define pb push_back
 
using namespace std;
template<class T, class T1> int chkmin(T &x, const T1 &y) { return x > y ? x = y, 1 : 0; }
template<class T, class T1> int chkmax(T &x, const T1 &y) { return x < y ? x = y, 1 : 0; }
const int MAXN = (1 << 18);
 
struct mydeque {
	vector<int> vec;
	int& operator[](const unsigned &pos) { return vec[vec.size() - pos - 1]; } 
	unsigned size() { return vec.size(); }
	void push_front(int x) { vec.pb(x); }
};
 
void merge(mydeque &a, mydeque &b) {
	if(SZ(a) < SZ(b)) swap(a, b);
	for(int i = 0; i < SZ(b); i++) {
		a[i] += b[i];
	}
}
 
int n, s[MAXN];
vector<int> adj[MAXN];
 
int read_int();
 
void read() {
	n = read_int();
	for(int i = 1; i <= n; i++) {
		adj[i].clear();
	}
 
	for(int i = 0; i < n - 1; i++) {
		int u, v;
		u = read_int();
		v = read_int();
		adj[u].pb(v);
		adj[v].pb(u);
	}
 
	for(int i = 1; i <= n; i++) {
		s[i] = read_int();
	}
}
 
int tr_sz[MAXN], cnt_vers;
bool used[MAXN];
 
void pre_dfs(int u, int pr)
{
	cnt_vers++;
	tr_sz[u] = 1;
	for(int v: adj[u])
		if(!used[v] && v != pr)
		{
			pre_dfs(v, u);
			tr_sz[u] += tr_sz[v];
		}
}
 
int centroid(int u, int pr)
{
	for(int v: adj[u])
		if(!used[v] && v != pr && tr_sz[v] > cnt_vers / 2)
			return centroid(v, u);
 
	return u;
}
 
int64_t answer = 0;
 
int dep[MAXN];
int curr_cnt[MAXN], gen_cnt[MAXN];
int64_t cnt[MAXN][2];
 
void fill_curr_cnt(int u, int pr, int d = 1) {
	curr_cnt[d] += s[u];
	for(int v: adj[u]) {
		if(v != pr && !used[v]) {
			fill_curr_cnt(v, u, d + 1);
		}
	}
}
 
void dfs1(int u, int pr) {
	dep[u] = 0;
	for(int v: adj[u]) {
		if(v != pr && !used[v]) {
			dfs1(v, u);
			chkmax(dep[u], dep[v] + 1);
		}
	}
}
 
mydeque dp[MAXN];
 
void add(pair<int, int> &mx, int x) {
	if(chkmax(mx.second, x)) {
		if(mx.first < mx.second) {
			swap(mx.first, mx.second); 
		}
	}
}
 
int64_t two[MAXN];
int tmp[MAXN];
 
void dfs2(int u, int pr, int curr_d = 1) {
	dp[u].vec.clear();
	dp[u].push_front(s[u]);
 
	pair<int, int> mx = {0, 0};
	for(int v: adj[u]) {
		if(v != pr && !used[v]) {
			dfs2(v, u, curr_d + 1);
			dp[v].push_front(0);
			add(mx, dep[v] + 1);
		}
	}
 
	int d = mx.second;
	for(int i = 0; i <= d; i++) {
		tmp[i] = 0;
		two[i] = 0;
	}
 
	for(int v: adj[u]) {
		if(v == pr || used[v]) continue;
 
		int cd = min(d, dep[v] + 1);
		for(int i = 1; i <= cd; i++) {
			two[i] += tmp[i] * 1ll * dp[v][i];
			tmp[i] += dp[v][i];
		}
	}
 
	for(int i = 1; i <= d; i++) {
		int od = i - curr_d;
		if(od < 0) continue;
		answer += (gen_cnt[od] - curr_cnt[od]) * 1ll * two[i];
	}
	
	for(int v: adj[u]) {
		if(v != pr && !used[v]) {
			merge(dp[u], dp[v]);
		}
	}
}
 
void decompose(int u)
{
	cnt_vers = 0;
	pre_dfs(u, u);
	int cen = centroid(u, u);
 
	used[cen] = true;
	for(int v: adj[cen])
		if(!used[v]) 
			decompose(v);
	used[cen] = false;
 
	dfs1(cen, -1);
 
	for(int i = 0; i <= dep[cen] + 1; i++) {
		cnt[i][0] = cnt[i][1] = 0;
		curr_cnt[i] = 0;
	}
 
	fill_curr_cnt(cen, -1, 0);
	for(int i = 0; i <= dep[cen] + 1; i++) {
		gen_cnt[i] = curr_cnt[i];
		curr_cnt[i] = 0;
	}
 
	cnt[0][0] = 1;
 
	for(int v: adj[cen]) {
		if(!used[v]) {
			for(int i = 0; i <= dep[v] + 1; i++) {
				curr_cnt[i] = 0;
				tmp[i] = 0;
			}
 
			fill_curr_cnt(v, cen);
			for(int i = 1; i <= dep[v] + 1; i++) {
				if(curr_cnt[i]) {
					answer += curr_cnt[i] * 1ll * cnt[i][1];
					cnt[i][1] += curr_cnt[i] * 1ll * cnt[i][0];
					cnt[i][0] += curr_cnt[i];
				}	
			}
 
			dfs2(v, cen);
 
			for(int i = 1; i <= dep[v] + 1; i++) {
				curr_cnt[i] = 0;
			}
		}
	}
}
 
void solve() {	
	answer = 0;
	decompose(1);
	cout << answer * 6ll << endl;
}
 
int main() {
	ios_base::sync_with_stdio(false);
	cin.tie(NULL);
 
	int T;
	T = read_int();
	while(T--) {
		read();
		solve();
	}
 
	return 0;
}
 
const int maxl = 100000;
char buff[maxl];
int ret_int, pos_buff = 0;
 
void next_char() { if(++pos_buff == maxl) fread(buff, 1, maxl, stdin), pos_buff = 0; }
 
int read_int()
{
	ret_int = 0;
	for(; buff[pos_buff] < '0' || buff[pos_buff] > '9'; next_char());
	for(; buff[pos_buff] >= '0' && buff[pos_buff] <= '9'; next_char())
		ret_int = ret_int * 10 + buff[pos_buff] - '0';
	return ret_int;
}
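The tester's `mydeque` keeps its elements back to front, so `push_front` is just an amortised O(1) `push_back`, and `merge` always adds the shorter array into the longer one after an O(1) vector swap. A stand-alone sketch of the same idea (the names `ReversedDeque` and `mergeInto` are mine; it mirrors `mydeque` and `merge` above rather than reproducing them):

```cpp
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

// Depth-indexed counts stored back to front: logical index i lives at
// vec[size - 1 - i], so prepending a new index 0 is a push_back.
struct ReversedDeque {
    std::vector<long long> vec;
    long long& operator[](std::size_t pos) {
        assert(pos < vec.size());
        return vec[vec.size() - pos - 1];
    }
    std::size_t size() const { return vec.size(); }
    void push_front(long long x) { vec.push_back(x); }  // amortised O(1)
};

// Adds b into a index by index. Swapping first (O(1), it only moves the
// vectors' buffers) makes a the longer one, so the loop costs O(min(|a|, |b|)).
void mergeInto(ReversedDeque& a, ReversedDeque& b) {
    if (a.size() < b.size()) std::swap(a, b);
    for (std::size_t i = 0; i < b.size(); ++i) a[i] += b[i];
}
```

As in the tester's `merge`, after the call `b` may hold the shorter array's old contents; the caller simply stops using it.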
Editorialist's Solution
#include <bits/stdc++.h>
using namespace std;

#define ll long long

const int mxN=2e5;
int n, s[mxN];
vector<int> adj[mxN];
vector<ll> dp1[mxN], dp2[mxN];
ll ans;

void dfs(int u=0, int p=-1) {
	dp1[u]={s[u]};
	dp2[u]={0};
	for(int v : adj[u]) {
		if(v==p)
			continue;
		dfs(v, u);
		//now merge the dp for v to u and update ans
		//increase the depths by 1
		dp1[v].push_back(0);
		dp2[v].pop_back();
		//if v is deeper, switch the dp to make it faster
		if(dp1[v].size()>dp1[u].size()) {
			swap(dp1[u], dp1[v]);
			swap(dp2[u], dp2[v]);
		}
		//make sure dp2[u] is big enough
		if(dp2[u].size()<dp1[v].size()) {
			//pad dp2[u] with 0s
			//named pad (not p) so it does not shadow the parent parameter p
			vector<ll> pad(dp1[v].size()-dp2[u].size(), 0);
			dp2[u].insert(dp2[u].begin(), pad.begin(), pad.end());
		}
		//update ans
		//leg from v, fork from u
		for(int i=1; i<=dp1[v].size(); ++i)
			ans+=dp1[v][dp1[v].size()-i]*dp2[u][dp2[u].size()-i];
		//fork from v, leg from u
		for(int i=1; i<=dp2[v].size(); ++i)
			ans+=dp2[v][dp2[v].size()-i]*dp1[u][dp1[u].size()-i];
		//combine the dp
		//new forks by combining 2 legs
		for(int i=1; i<=dp1[v].size(); ++i)
			dp2[u][dp2[u].size()-i]+=dp1[u][dp1[u].size()-i]*dp1[v][dp1[v].size()-i];
		//just add the rest together
		for(int i=1; i<=dp1[v].size(); ++i)
			dp1[u][dp1[u].size()-i]+=dp1[v][dp1[v].size()-i];
		for(int i=1; i<=dp2[v].size(); ++i)
			dp2[u][dp2[u].size()-i]+=dp2[v][dp2[v].size()-i];
	}
	adj[u].clear();
}

void solve() {
	//input
	cin >> n;
	for(int i=1, u, v; i<n; ++i) {
		cin >> u >> v, --u, --v;
		adj[u].push_back(v);
		adj[v].push_back(u);
	}
	for(int i=0; i<n; ++i)
		cin >> s[i];

	//dfs to calculate the answer
	ans=0;
	dfs();
	cout << 6*ans << "\n";
}

int main() {
	ios::sync_with_stdio(0);
	cin.tie(0);

	int t;
	cin >> t;
	while(t--)
		solve();
}
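Why the editorialist's `dfs` runs in O(N) overall: a short derivation written by me, not quoted from the editorial. Let h(v) be the height of v's subtree, so `dp1[v]` has h(v)+2 entries after the `push_back(0)` shift, and let deep(u) be the child of u whose array is the one finally kept in `dp1[u]` (a child of maximum height). Every loop in a merge, including the padding of `dp2[u]`, runs over the shorter array only, and because the longer array is always kept, each child's array other than deep(u)'s is the shorter one at most once (u's own initial one-entry array contributes O(1)):

$$
\mathrm{cost}(u) \;\le\; O\bigl(1 + \deg(u)\bigr) \;+ \sum_{\substack{c \in \mathrm{children}(u) \\ c \neq \mathrm{deep}(u)}} \bigl(h(c) + 1\bigr)
$$

$$
\sum_{u} \mathrm{cost}(u) \;\le\; O(N) \;+\; \Bigl(N - \bigl(h(\mathrm{root}) + 1\bigr)\Bigr) \;=\; O(N)
$$

The second step holds because following deep(·) downwards from the root, or from any vertex c that is not its parent's deep child, gives a path of exactly h(c)+1 vertices, and these paths partition the tree.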

Updated Setter’s solution here - I haven’t been that detailed with the overview for the reasons described in it (i.e. - my solution was very overcomplicated :))

Problem History here - warning: long and rambling :slight_smile:

Problem History

@shivank98 came up with the idea of writing a little bit about how the Problem came about a while back, and I thought it was a good one, so I think I’ll take a stab at it :slight_smile:

Oddly, CHGORAM2 arose out of my dissatisfaction with a Hackerrank problem called “Find the Nearest Clone”. FTNC would have been a neat Easy-Medium question except for the fact that it is utterly let down by its legendarily poor testcases - or rather, more fundamentally, by the fact that the Setter/Editorialist did not realise that his purportedly O(N) solution was in fact O(N^2).

I brainstormed ways of coming up with a tweak to the Problem that would enable me to re-do it (this time with decent testcases!) - perhaps count the number of ways of forming this closest pair? Or finding not the nearest pair of clones, but the nearest triples - triples of nodes who all have the same (minimum) distance from each other? Or perhaps counting all such triples? Or - maybe! - counting all triples of the same type??

This latter seemed more plausible (for a tree, at least - not for an arbitrary graph as in the original Find the Nearest Clone!), and eventually I stripped it down to just two types: at the time, as you can see from the commit, I had a lame backstory in mind where the nodes were houses, some houses had people in them, and you had to count the number of equidistant triples of people - something about the people in this town preferring to gather in triples and, out of a sense of fairness, wanting the journey time between any pair of friends’ houses to be the same, or somesuch nonsense :slight_smile: It was to be called “Equilateral Treeangles” (groan).

Unusually for me, I came up with the basic approach - the Type A arrangement, countable using a DFS with the Small-to-Large optimisation used with descendant heights - almost straight away. My git logs from around the time explicitly mention this Hackerrank Problem, so I suppose by chance I was thinking about it at this time, and that steered my thoughts in the right direction. Of course, I didn’t get it quite right initially - I completely overlooked the existence of Type B arrangements (ha!) and also assumed that the “p” in the Type A case would always be a direct ancestor of X (lol). But the fact that I got this latter one wrong was actually good - it made the problem much deeper, as now it seemed to require Centroid Decomposition, too!

Actually implementing it didn’t take too long, though there was one heart-stopping moment where a test-run of my implementation took about 5 seconds on a testcase I generated (the git log at the time contains “I’ve gravely misjudged something, but don’t yet know what. Panic stations!”), but it turns out that I just had the direction of a “<” reversed and was doing the Large-to-Small Pessimisation instead of the Small-to-Large Optimisation XD

I then put the Problem on hold for a while, and took part in my first Codechef contest (AUG19B). One of the most heartbreaking things that can happen when you’ve invested a lot of time and effort in creating a Problem is to find that someone has beaten you to it, and so I nearly fell off my chair when I came across an AUG19B Problem where you had to count the number of ways of choosing 3 nodes (ulp!) on a tree (gulp!) satisfying certain constraints (yikes!), but after taking some deep breaths and forcing myself to read the Problem carefully, I saw that this “CHGORAM” Problem was completely different to mine and breathed a sigh of relief. Even better, as I realised a little later - I could ditch my naff backstory about triples of fair-minded friends and piggy-back on this one instead!

So: I had a working implementation with the “<” pointing the right way, and a backstory - time to get some proper tests written, which would prove to be by far the hardest part of the Problem creation.

My shiny new testcases were being passed with ease by my Editorial implementation, so I tested them against an obvious O(N^2) implementation: one where we consider each node in turn as the “centre” of the triple (the “X”), find all nodes at a given distance from X reachable via different neighbours of X, and count the triples of suitable nodes that can be formed from them. Happily, the testcases TLE’d heavily on this naive approach, but then I suddenly (much later than most of you, I’m sure) noticed the Achilles’ heel for this Problem: if, at any point, we cannot find 3 nodes at distance d from X reachable by different children of X, then we can stop our search there - subsequent d's will reveal no further triples.
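A minimal sketch of this naive centre-based count with the early exit, for trying out on small trees. Assumptions (mine, not the setter's): the tree is given as 0-indexed adjacency lists, suitability as a `vector<int>` of 0/1, and the name `countTriplesNaive` is hypothetical. With c_1, ..., c_k suitable nodes at distance d in the k branches of X, the number of triples drawn from three different branches is the elementary symmetric sum e_3(c_1, ..., c_k), maintained incrementally. Each equidistant triple has exactly one centre (the branching point of its three paths), so summing over X and d counts it once.

```cpp
#include <cassert>
#include <utility>
#include <vector>

// Hypothetical helper: naive count of UNORDERED triples of suitable nodes that
// are pairwise equidistant. adj: 0-indexed adjacency lists of a tree;
// marked[i] is 1 if node i is suitable, else 0. Worst case O(N^2).
long long countTriplesNaive(const std::vector<std::vector<int>>& adj,
                            const std::vector<int>& marked) {
    assert(adj.size() == marked.size());
    const int n = static_cast<int>(adj.size());
    long long total = 0;
    for (int centre = 0; centre < n; ++centre) {
        // One frontier per branch of the centre: (node, parent) pairs at distance d.
        std::vector<std::vector<std::pair<int, int>>> frontiers;
        for (int nb : adj[centre])
            frontiers.push_back({std::make_pair(nb, centre)});
        while (true) {
            // Early exit: once fewer than 3 branches reach distance d,
            // no larger d can have 3 either.
            int nonEmptyBranches = 0;
            for (const auto& frontier : frontiers)
                if (!frontier.empty()) ++nonEmptyBranches;
            if (nonEmptyBranches < 3) break;
            // e1, e2, e3: elementary symmetric sums of the per-branch suitable
            // counts, so e3 = triples at distance d from 3 different branches.
            long long e1 = 0, e2 = 0, e3 = 0;
            for (auto& frontier : frontiers) {
                long long c = 0;
                std::vector<std::pair<int, int>> next;
                for (const auto& [node, parent] : frontier) {
                    c += marked[node];
                    for (int child : adj[node])
                        if (child != parent) next.push_back({child, node});
                }
                e3 += e2 * c;
                e2 += e1 * c;
                e1 += c;
                frontier = std::move(next);  // advance this branch to distance d + 1
            }
            total += e3;
        }
    }
    return total;
}
```

It returns unordered triples; the solutions above print 6 times that count, so multiply by 6 before comparing outputs.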

I added this small optimisation to the “cheat” O(N^2) solution and … it absolutely wrecked my testcases. Just blasted through all of them, taking barely more than 1 sec for the hardest. Back to the drawing board!

I had to exercise a bit of ingenuity (and up the node limit to 200,000 XD) to break this Cheat Solution, but this was still a bit of a crisis - if such a simple ploy could have obliterated my original testcases, then maybe a slightly less simple one would destroy the new ones? I’m forever complaining about weak testcases - most recently in this post - so I couldn’t submit a Problem that could be so easily cracked by sub-optimal solutions! So I tried to be as proactive as I could about anticipating avenues that could be exploited, and also wrote a few “blue-sky” tests (006-008) that violated some core assumptions that a sub-optimal solution might depend on. I couldn’t really come up with any concrete implementations that could use these assumptions to cheat, and so was quite close to deleting tests 006-008 but, as it happens, 008 at least turned out to be oddly effective :slight_smile: 015 was also a last-minute reaction to a possible attack that I thought of, but proved less useful - I don’t think anyone chose that possible means of assault.

Anyway, I had all the ingredients ready now so I submitted the Problem Idea, which was Approved very quickly and fast-tracked to appear in FEB20 Long where it would be one of the Div 1-only Problems - which made me nervous, as I was still unsure about how easily the “Achilles’ Heel” would permit sub-optimal solutions. Watching it get Tested by the Tester was very nerve-wracking, and for a while it looked like the Tester was going to trounce it in very short order, but in fact, it was here that the almost-deleted 008 first proved its mettle by being the Last Man Standing for one of the Tester’s early attempts!

After a while, I couldn’t stand the tension anymore and went to sleep. As expected, the Tester had broken it by the time I woke up, but it had put up a good fight, at least, and made him work for it! His assessment was “medium or hard (although probably closer to hard)”, which was reassuring. He did suggest toughening up some of the tests, and this is where the mysterious testfiles 013 and 014 were born - they came about through setting up the testcase generator to generate trees with certain parameters, and then doing a random parameter search to see which combination caused the Setter’s “cheat” implementation to time-out the hardest. To this day, I still don’t have the slightest clue how 013 and 014 actually work, but they proved very effective against the Tester’s solution, and indeed many solutions during the Contest!

My confidence in the Problem had been bolstered by now, so I was quite panicked when the first few hours of the Contest consisted of what seemed like an endless parade of people effortlessly dunking on it XD I was debating writing this Problem History as a long “sorry my testcases sucked” apology, but eventually the parade dried up a little (and digging through these early submitters’ histories revealed them as either official 7*'s, or people who will surely be 7* after a few more contests - unless they elect to take the @just1star route, of course :)) and in the end far, far fewer people managed to complete it - approx 1/4 as many as the original CHGORAM, in fact :slight_smile: The 008 testfile again proved surprisingly effective, and I’m very glad I didn’t delete it.

From looking at some of the solutions, it quickly became clear that my solution was unnecessarily clumsy, so I’ve only given it a little bit of documentation compared to my normal solutions - please refer to the official Editorial instead :slight_smile:

Anyway, that’s that - all a bit of a roller-coaster, but I’m glad I finally got a Problem published somewhere and it didn’t turn out to be a disgrace :wink:

I tried your “cheap” O(n^2) algorithm; I thought it actually works faster and that I just needed to optimise the data structures.
Also, if you could, please don’t gloss over the implementation details, because I struggled to implement the code for LCAs.

I didn’t understand half of what you wrote🤣
But it was pretty fun to read XD

Which implementation details - the O(N^2) version? If so, here it is :slight_smile:

No, I meant that part of the editorial, sorry. I managed to implement the O(n^2) version but struggled to implement the LCA DP.

Oh right - hopefully my commented cpp implementation plus @tmwilliamlin’s pending write-up will help to clear up any confusion :slight_smile:

Hotels from POI XXI is very similar to this problem, and although an O(N^2) solution is good enough to solve it, one of the tester’s solutions is O(NlogN). This solution can be found here. I’ve used this code to solve the problem, and on skimming through some other contestants’ solutions, it seems a few of them have as well. I just want to bring this up before I’m accused of plagiarism.

The POI problem has an increased data range version on BZOJ.

Well written. I never would have thought that the story behind setting a problem could be so interesting! :slight_smile:

For those who haven’t seen it, I created a video solution as a substitute until I finish writing up the editorial.

Editorial finished and updated!

That’s not mine - that’s the Tester’s again (which, admittedly, I submitted - but only to check its progress against the current testcases) :slight_smile: Mine is in the first reply to this thread.

Oops, I took whatever the best submissions page on campus showed me lol

Hehe :slight_smile:

Updated with related problems

+1 happens wayy too often with me XD

Is the current setter’s solution yours @ssjgz?

Yes

In the editorial, I did not understand the following:

dp1_{v,j+1} = dp1_{v,j} (and dp1_{v,0} = 0).

In the above line, if dp1_{v,0} = 0, then by the equality dp1_{v,j+1} = dp1_{v,j} we get dp1_{v,j} = 0 for all j, which is not possible.

I also did not understand how you got dp1_{v,j+1} = dp1_{v,j}. For example, take the following tree:

Suppose vertex 1 is the root.

For the marked vertex v, the number of legs of length 1 from v is 2, whereas the number of legs of length 2 from v is 3.
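One reading of the statement being asked about, consistent with the editorialist's code above (where `dp1[v].push_back(0)` carries the comment "increase the depths by 1"): it is a one-off re-indexing of v's whole array when it is handed to v's parent, not a recurrence in j. A node at distance j from v is at distance j+1 from the parent, and nothing in v's subtree is at distance 0 from the parent. A small sketch with plain front-to-back indexing (`shiftForParent` is a hypothetical name):

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper. counts[j] = number of marked nodes at distance j from v,
// stored front to back. Returns the same nodes counted by distance from v's parent.
std::vector<long long> shiftForParent(const std::vector<long long>& counts) {
    std::vector<long long> shifted(counts.size() + 1, 0);
    // Every entry moves up one index at once: new[j + 1] = old[j].
    // No entry is computed from a neighbouring new entry, so this is not a recurrence.
    for (std::size_t j = 0; j < counts.size(); ++j)
        shifted[j + 1] = counts[j];
    assert(shifted[0] == 0);  // new[0] = 0: no node of v's subtree is the parent itself
    return shifted;
}
```

With the numbers above (v marked, 2 legs of length 1, 3 of length 2), {1, 2, 3} becomes {0, 1, 2, 3} from the parent's point of view; nothing forces neighbouring entries to be equal. Storing the array back to front, as the editorialist's vectors and the tester's `mydeque` do, turns this shift into a single `push_back(0)`.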