HEALTHYTREE - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: shubham_grg
Tester: iceknight1093
Editorialist: iceknight1093

DIFFICULTY:

TBD

PREREQUISITES:

DFS, LCA, point-add range-sum data structures (like segment trees/Fenwick trees)

PROBLEM:

Given a tree rooted at vertex 1, answer Q queries on it.

  1. Given a vertex u, mark it if it is unmarked and vice versa.
  2. Given a vertex u, find the result of the following process:
    • Consider a person standing at each marked vertex.
    • In one second, they all simultaneously move one step up to the parent of their current vertex (or don’t move, if they’re already at 1)
    • If multiple people enter a vertex at the same time, nothing happens.
    • If a single person enters a vertex at a given time, add 1 to the score of this vertex.
    • Find the score of vertex u at the end of this process.

EXPLANATION:

Let’s try to solve subtask 1 first, which will give us intuition for the rest of the problem.

Suppose we have a certain set of marked vertices. Which of them will contribute to the score of 1?
Let’s look at a person standing at node u.
Let d_u denote the depth of vertex u, i.e., its distance from 1.
Then, the person starting at u will reach 1 at exactly time d_u.

So, the score of 1 will increase at time d_u if and only if there’s no other marked vertex at depth d_u.
Putting it another way, \text{score}_1 will simply equal the number of depths such that there’s exactly one marked vertex at that depth.

This is enough to solve subtask 1: keep a count of the number of marked vertices at each depth, and how many of the depths have a count of 1.
When a vertex is marked/unmarked, only the contribution of this depth to the answer can change, so the answer can be updated in constant time.

Note that there’s one catch here: vertex 1 itself (or rather, depth 0) shouldn’t be counted in the answer, even if it’s marked. This is because a person starting at vertex 1 is already there and never enters it, so they can’t contribute to its score.
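A minimal sketch of this bookkeeping follows. It assumes depths are precomputed with vertex 1 at depth 0, and that the caller knows whether the toggled vertex is being marked or unmarked. RootScore and its methods are hypothetical names, not taken from the solutions below.

```cpp
#include <cassert>
#include <vector>

// Subtask 1 bookkeeping: score_1 = number of depths d >= 1 holding exactly
// one marked vertex. Depth 0 (vertex 1 itself) is ignored entirely.
struct RootScore {
    std::vector<int> cnt;   // cnt[d] = number of marked vertices at depth d
    int singles = 0;        // number of depths d >= 1 with cnt[d] == 1

    explicit RootScore(int maxDepth) : cnt(maxDepth + 1, 0) {}

    // delta = +1 when a vertex at depth d gets marked, -1 when it gets unmarked.
    void update(int d, int delta) {
        if (d == 0) return;              // vertex 1 never contributes
        if (cnt[d] == 1) --singles;      // drop this depth's old contribution
        cnt[d] += delta;
        if (cnt[d] == 1) ++singles;      // add its new contribution
    }

    int score() const { return singles; }
};
```

Each toggle only touches one depth, so both `update` and `score` run in constant time.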


The above idea easily generalizes to any vertex u.
Given a state of the tree, \text{score}_u will be the number of depths such that there’s exactly one marked vertex at this depth within the subtree of u.
This subtree should exclude u itself, since a person starting at u immediately moves up and never enters u.
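To sanity-check this characterization on small trees, one can compare it against a literal simulation of the process. The sketch below is a hypothetical brute-force checker, not part of the original solutions. It makes the following assumptions:

- Vertices are numbered 1..n, with par[1] = 0.
- As in the statement, people who reach vertex 1 stay there without re-entering it.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Literal simulation: returns score[v] for every vertex v.
std::vector<int> simulateScores(const std::vector<int>& par,
                                const std::vector<bool>& marked) {
    int n = (int)par.size() - 1;
    std::vector<int> pos;                          // current vertex of each person
    for (int v = 1; v <= n; ++v) if (marked[v]) pos.push_back(v);
    std::vector<int> score(n + 1, 0);
    for (int step = 0; step < n; ++step) {         // every depth is < n
        std::vector<int> entered(n + 1, 0);
        for (int& p : pos) {
            if (p == 1) continue;                  // people at vertex 1 stay put
            p = par[p];
            ++entered[p];
        }
        for (int v = 1; v <= n; ++v) if (entered[v] == 1) ++score[v];
    }
    return score;
}

// Closed form: number of levels strictly below u (inside u's subtree) that
// contain exactly one marked vertex. Relative depth k = dep_w - dep_u.
int formulaScore(const std::vector<int>& par, const std::vector<bool>& marked, int u) {
    int n = (int)par.size() - 1;
    std::vector<int> cnt(n + 1, 0);
    for (int w = 1; w <= n; ++w) {
        if (!marked[w] || w == u) continue;        // u itself is excluded
        int k = 0, x = w;
        while (x != 0 && x != u) { x = par[x]; ++k; }
        if (x == u) ++cnt[k];                      // w lies strictly inside u's subtree
    }
    return (int)std::count(cnt.begin(), cnt.end(), 1);
}
```

For example, on the tree 1-2, 1-3, 2-4, 2-5, 3-6 with vertices 1, 2, 4, 6 marked, both functions give a score of 1 to vertices 1, 2, 3 and 0 to the rest.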

This gives us a way to solve subtask 2.
Since all queries happen only after all updates, the state of the tree is fixed.

This allows us to answer queries offline, i.e., in any order we want: as long as we save all the answers, we can print them in the correct order at the end.

So, suppose we knew \text{freq}[u][d] - the number of marked vertices within the subtree of u that are at depth d - for every u and d.
Then, answering queries would be fairly easy: for u, just find the number of d such that \text{freq}[u][d] = 1.

Ok, but how do I compute this fast?

Use small-to-large merging!

That is, let \text{freq}[u] be a map, which is initially empty for all u.
Now, perform a DFS on the tree.
When you’re at a vertex u, for each of its children v,

  • First, recursively solve for v.
  • Then, merge the frequency table of v into that of u.
    • Quite simply, for each d such that \text{freq}[v][d] \gt 0, do the operation \text{freq}[u][d] += \text{freq}[v][d].

Done naively, this is still quadratic in the worst case (for example, on a path). This is where small-to-large merging comes into play: instead of always merging v into u, merge whichever table is smaller into the larger one.
This makes the total number of element insertions we perform \mathcal{O}(N\log N), which is good enough for our purposes.
This blog is a tutorial for the technique, if you haven’t seen it before.

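Note that you can also store the number of depths d such that \text{freq}[u][d] = 1. Updating this count is easy while merging, so once the entire process is done for u, you already have its answer! Just remember to read it off before adding u's own depth (if u is marked), since u must not count towards its own score.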

This way, subtask 2 is solved in expected \mathcal{O}(N\log N) time with a hash map, or \mathcal{O}(N\log^2 N) time with a balanced-tree map such as std::map.
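A minimal sketch of the offline small-to-large pass is shown below. It makes the following assumptions:

- The final set of marked vertices is known in advance.
- Vertex 1 has depth 0.
- `freq` is a std::map, so this is the \mathcal{O}(N\log^2 N) variant.
- The DFS is recursive, which may overflow the stack on very deep trees.

SmallToLarge and its members are illustrative names, not taken from the original solutions.

```cpp
#include <cassert>
#include <map>
#include <utility>
#include <vector>

struct SmallToLarge {
    const std::vector<std::vector<int>>& adj;
    const std::vector<bool>& marked;
    std::vector<int> answer;                 // answer[u] = score_u
    std::vector<std::map<int, int>> freq;    // freq[u][d] = marked vertices at depth d in u's table
    std::vector<int> ones;                   // ones[u] = number of d with freq[u][d] == 1

    SmallToLarge(const std::vector<std::vector<int>>& g, const std::vector<bool>& m)
        : adj(g), marked(m), answer(g.size()), freq(g.size()), ones(g.size()) {}

    // freq[u][d] += c, keeping ones[u] in sync.
    void add(int u, int d, int c) {
        int& f = freq[u][d];
        if (f == 1) --ones[u];
        f += c;
        if (f == 1) ++ones[u];
    }

    void dfs(int u, int parent, int depth) {
        for (int v : adj[u]) {
            if (v == parent) continue;
            dfs(v, u, depth + 1);
            // Keep the larger table at u; the counter travels with its table.
            if (freq[v].size() > freq[u].size()) {
                std::swap(freq[u], freq[v]);
                std::swap(ones[u], ones[v]);
            }
            for (auto [d, c] : freq[v]) add(u, d, c);
            freq[v].clear();
        }
        answer[u] = ones[u];                 // read before u itself is added
        if (marked[u]) add(u, depth, 1);     // u counts only for its ancestors
    }
};
```

Calling `dfs(1, 0, 0)` fills `answer` for every vertex at once. Each query is then a lookup.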


The final subtask, unfortunately, doesn’t directly generalize from the second.

Instead, we take a different perspective.
As we observed, for each vertex u, \text{score}_u is determined by some of the marked vertices in the subtree of u.
Let’s turn this around. Say we fix a marked vertex v: to the scores of which vertices does v contribute?

Clearly, v can contribute to \text{score}_u only if u is an ancestor of v.
So, let u be an ancestor of v.
Then, if v doesn’t contribute to \text{score}_u, that would mean that some other marked vertex in the subtree of u has the same depth as v.
However, this would also mean that v will not contribute to the score of anything above u either!

In other words, v will contribute to the scores of par(v), par(par(v)), \ldots up to a certain point, after which it stops. More specifically, v will contribute to the score of every vertex on some path that starts at par(v) and goes upwards.

The next question is, how do we find this path?
Specifically, we want to find the lowest ancestor of v, such that the subtree of this ancestor contains some other vertex whose depth equals \text{dep}_v.

This lowest ancestor can be found by looking at the DFS in-times of all the marked vertices with depth \text{dep}_v.

How?

Consider the set of all marked vertices at depth \text{dep}_v, sorted by their DFS in-times.

Let v_1 and v_2 be the immediate predecessor and successor of v in this order (either of them may not exist).
Then, when moving upwards, the first same-depth vertex that v will encounter will definitely be either v_1 or v_2.

Proof

To prove this, we’ll use the Euler tour of a tree.
That is, compute DFS in- and out-times for every vertex.

Then, it’s well known that for two vertices u and v,

  • If u is an ancestor of v, the range [in_u, out_u] will completely contain [in_v, out_v].
  • If v is an ancestor of u, the opposite happens.
  • If neither u nor v is an ancestor of the other, the ranges [in_u, out_u] and [in_v, out_v] will be completely disjoint.

Now, look at all the vertices at depth \text{dep}_v, and the ranges corresponding to them.
None of these vertices can be ancestors of any other, so all these ranges will be disjoint.

Then, if the range of some vertex u contains two of these ranges (including [in_v, out_v]), it must also contain the range immediately to the left or immediately to the right of [in_v, out_v]. This is because [in_u, out_u] is a single contiguous interval.
But the vertices corresponding to those ranges are exactly v_1 and v_2 as described above anyway!

Further, clearly v will first meet v_1 at \text{lca}(v, v_1), and v_2 at \text{lca}(v, v_2).
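So, the required vertex is simply whichever of these two LCAs has greater depth. If only one of v_1, v_2 exists, use its LCA with v. If neither exists, the path goes all the way up to vertex 1.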
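The neighbor-based lookup can be sketched as follows. For brevity, this sketch makes the following assumptions:

- It uses a naive climbing LCA instead of the Euler tour + RMQ used in the actual solutions.
- Each S_d is stored as a std::set of (in-time, vertex) pairs.
- It returns 0 when v has no same-depth marked neighbor, meaning the path extends all the way to vertex 1.

All names here are hypothetical.

```cpp
#include <cassert>
#include <iterator>
#include <set>
#include <utility>
#include <vector>

// Naive LCA by climbing; fine for a sketch, too slow for the real constraints.
int naiveLca(const std::vector<int>& par, const std::vector<int>& dep, int a, int b) {
    while (dep[a] > dep[b]) a = par[a];
    while (dep[b] > dep[a]) b = par[b];
    while (a != b) { a = par[a]; b = par[b]; }
    return a;
}

// sameDepth holds (tin, vertex) for the marked vertices at v's depth, v included.
// Returns the lowest ancestor whose subtree holds another of them, or 0 if none.
int pathTop(const std::set<std::pair<int, int>>& sameDepth,
            const std::vector<int>& par, const std::vector<int>& dep,
            const std::vector<int>& tin, int v) {
    auto it = sameDepth.find({tin[v], v});
    int best = 0;
    if (it != sameDepth.begin()) {                 // predecessor v_1
        int l = naiveLca(par, dep, v, std::prev(it)->second);
        if (best == 0 || dep[l] > dep[best]) best = l;
    }
    if (std::next(it) != sameDepth.end()) {        // successor v_2
        int l = naiveLca(par, dep, v, std::next(it)->second);
        if (best == 0 || dep[l] > dep[best]) best = l;
    }
    return best;                                   // deeper of the two LCAs
}
```

Vertex v then adds 1 to par(v), par(par(v)), \ldots, stopping just below the returned vertex.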

Let’s recap.
For a fixed vertex v, we know that:

  • v will add 1 to the score of some vertices in a path starting at par(v) and going upwards.
  • To find the top of this path, we only need to look at the neighbors of v in the list of marked vertices of depth \text{dep}_v, when sorted by DFS in-time.

Now, observe that when vertex v is marked or unmarked, the only vertices whose paths can possibly change are v and its same-depth marked neighbors, which is at most three vertices!

With this idea in mind, let’s move on to the full solution.
Let \text{score}_u denote the score of u. Initially, this is 0 for all vertices.
Let S_d denote the sorted (by DFS in-time) list of marked vertices at depth d.
Then,

  • When v gets marked, look at the neighbors of v in S_{\text{dep}_v}.
    Remove their existing paths; then insert v into S_{\text{dep}_v} and then recompute the paths for these (at most) three vertices.
  • When v gets unmarked, do the opposite: remove the existing paths for v and its neighbors, remove v from its corresponding set, and then recompute the paths for its neighbors.
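Putting the update rule together, here is a slow but hopefully clear sketch. To keep it short, it makes two simplifying assumptions:

- The LCA is computed by naive climbing.
- Each path is applied by walking up the tree one vertex at a time, so an update costs O(depth) rather than O(log N).

The point is the order of operations: remove the neighbors' paths, toggle v, then re-add them. The names are illustrative. The editorialist's code below performs the same steps with a segment tree.

```cpp
#include <cassert>
#include <iterator>
#include <set>
#include <utility>
#include <vector>

struct DynamicScores {
    std::vector<int> par, dep, tin, score;
    std::vector<bool> marked;
    std::vector<std::set<std::pair<int, int>>> byDepth;  // S_d as (tin, vertex) pairs

    DynamicScores(std::vector<int> p, std::vector<int> d, std::vector<int> t)
        : par(p), dep(d), tin(t), score(p.size(), 0), marked(p.size(), false),
          byDepth(p.size()) {}

    int lca(int a, int b) const {                 // naive climbing LCA
        while (dep[a] > dep[b]) a = par[a];
        while (dep[b] > dep[a]) b = par[b];
        while (a != b) { a = par[a]; b = par[b]; }
        return a;
    }

    // Lowest ancestor of v whose subtree holds another marked vertex at v's
    // depth, or 0 if none. v must currently be in its S_d.
    int top(int v) const {
        const auto& s = byDepth[dep[v]];
        auto it = s.find({tin[v], v});
        int best = 0;
        if (it != s.begin()) {
            int l = lca(v, std::prev(it)->second);
            if (best == 0 || dep[l] > dep[best]) best = l;
        }
        if (std::next(it) != s.end()) {
            int l = lca(v, std::next(it)->second);
            if (best == 0 || dep[l] > dep[best]) best = l;
        }
        return best;
    }

    // Add delta to every vertex from par(v) up to, but excluding, top(v).
    void applyPath(int v, int delta) {
        int stop = top(v);
        for (int x = par[v]; x != stop; x = par[x]) score[x] += delta;
    }

    void toggle(int u) {
        auto& s = byDepth[dep[u]];
        std::pair<int, int> key{tin[u], u};
        std::vector<int> nb;                      // same-depth marked neighbors of u
        auto hi = s.upper_bound(key);
        if (hi != s.end()) nb.push_back(hi->second);
        auto lo = s.lower_bound(key);
        if (lo != s.begin()) nb.push_back(std::prev(lo)->second);

        for (int w : nb) applyPath(w, -1);        // remove the neighbors' old paths
        if (marked[u]) { applyPath(u, -1); s.erase(key); }
        else           { s.insert(key); applyPath(u, +1); }
        marked[u] = !marked[u];
        for (int w : nb) applyPath(w, +1);        // re-add them with the new S_d
    }
};
```

After each `toggle`, `score[y]` should match the characterization given earlier for every vertex y.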

S_d needs to be a sorted structure that supports quick insertion, deletion, and finding the next/previous element of a given element.
The obvious choice is std::set here.

Finally, note that we maintain a set of paths, and when asked, need to know how many of them cover a given vertex u.
That’s a very standard problem, with a variety of solutions.
For instance,

  • Path additions and point queries can be converted into point additions and subtree queries. Subtree queries can then be done easily using a segment tree (or Fenwick tree) after building the Euler tour of the tree.
How?

Suppose you want to add x to the path from u to v, where u is an ancestor of v.
Then, add x to the value of v and subtract x from the value of par(u).

To get the value of a vertex y, simply query for the subtree sum of y.
This works because:

  • For any path that doesn’t pass through y, the two updated vertices (v and par(u)) either both lie in its subtree, or both lie outside it.
    Either way, the overall contribution of such a path to the subtree sum of y will be 0.
  • For any path that does pass through y, only the lower vertex will lie in its subtree; so its value will be added to the subtree sum as desired.
  • Alternatively, you can use heavy-light decomposition or centroid decomposition for \mathcal{O}(\log^2 N) per query.
    Depending on your implementation, you may have constant-factor issues from the extra log factor.
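A minimal sketch of the path-add/point-query trick follows, with a Fenwick tree in place of the segment tree. It assumes a virtual root 0 above vertex 1, as both solutions below use, so that par(u) always exists. Out-times are exclusive. PathAdder and Fenwick are hypothetical names.

```cpp
#include <cassert>
#include <utility>
#include <vector>

// Fenwick tree over Euler-tour positions 0..n-1: point add, range sum.
struct Fenwick {
    std::vector<long long> t;
    explicit Fenwick(int n) : t(n + 1, 0) {}
    void add(int i, long long x) { for (++i; i < (int)t.size(); i += i & -i) t[i] += x; }
    long long prefix(int i) const {          // sum over positions [0, i)
        long long s = 0;
        for (; i > 0; i -= i & -i) s += t[i];
        return s;
    }
    long long range(int l, int r) const { return prefix(r) - prefix(l); }  // [l, r)
};

struct PathAdder {
    std::vector<int> par, tin, tout;
    Fenwick fw;
    // children[x] lists the children of x; vertex 0 is the virtual root.
    PathAdder(const std::vector<std::vector<int>>& children, std::vector<int> p)
        : par(std::move(p)), tin(par.size()), tout(par.size()), fw((int)par.size()) {
        int timer = 0;
        dfs(children, 0, timer);
    }
    void dfs(const std::vector<std::vector<int>>& ch, int u, int& timer) {
        tin[u] = timer++;
        for (int v : ch[u]) dfs(ch, v, timer);
        tout[u] = timer;                     // exclusive end of u's subtree
    }
    // Add x to every vertex on the path from v up to its ancestor u (u >= 1).
    void addPath(int u, int v, long long x) {
        fw.add(tin[v], x);
        fw.add(tin[par[u]], -x);
    }
    // Value of y = subtree sum of y.
    long long value(int y) const { return fw.range(tin[y], tout[y]); }
};
```

In this problem, a marked vertex v with top vertex t contributes to the path from par(v) up to the child of t. So in this notation, the update is an add at par(v) and a subtract at t itself. This matches the two `seg.update` calls in the editorialist's `upd` lambda.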

TIME COMPLEXITY:

\mathcal{O}((N+Q)\log N) per testcase.

CODE:

Author's code (C++)
#include <bits/stdc++.h>
using namespace std;
// #include<ext/pb_ds/assoc_container.hpp>
// #include<ext/pb_ds/tree_policy.hpp>
// using namespace __gnu_pbds;
// template<class T> using oset =tree<T, null_type, less<T>, rb_tree_tag,tree_order_statistics_node_update> ;//oset name name.order_of_key(#ele<k) *name.find_by_order(index) less_equal greater greater_equal

#define vi vector<int>
#define pii pair<int, int>
#define mii map<int, int>
#define int long long
#define ld long double
#define pb push_back
#define all(v) v.begin(), v.end()
#define yes cout << "YES" \
                 << "\n";
#define no cout << "NO" \
                << "\n";
#define nl "\n"
#define FastIO                    \
    ios_base::sync_with_stdio(0); \
    cin.tie(0);                   \
    cout.tie(0)
#define mod 1000000007
const int oo = 1e18;
int jj = 0;
vi euler;
vi tin;
vi tout;
struct LCA
{
    vector<int> height, euler, first, segtree;
    vector<bool> visited;
    int n;

    LCA(vector<vector<int>> &adj, int root = 0)
    {
        n = adj.size();
        height.resize(n);
        first.resize(n);
        euler.reserve(n * 2);
        visited.assign(n, false);
        dfs(adj, root);
        int m = euler.size();
        segtree.resize(m * 4);
        build(1, 0, m - 1);
    }

    void dfs(vector<vector<int>> &adj, int node, int h = 0)
    {
        visited[node] = true;
        height[node] = h;
        first[node] = euler.size();
        euler.push_back(node);
        for (auto to : adj[node])
        {
            if (!visited[to])
            {
                dfs(adj, to, h + 1);
                euler.push_back(node);
            }
        }
    }

    void build(int node, int b, int e)
    {
        if (b == e)
        {
            segtree[node] = euler[b];
        }
        else
        {
            int mid = (b + e) / 2;
            build(node << 1, b, mid);
            build(node << 1 | 1, mid + 1, e);
            int l = segtree[node << 1], r = segtree[node << 1 | 1];
            segtree[node] = (height[l] < height[r]) ? l : r;
        }
    }

    int query(int node, int b, int e, int L, int R)
    {
        if (b > R || e < L)
            return -1;
        if (b >= L && e <= R)
            return segtree[node];
        int mid = (b + e) >> 1;

        int left = query(node << 1, b, mid, L, R);
        int right = query(node << 1 | 1, mid + 1, e, L, R);
        if (left == -1)
            return right;
        if (right == -1)
            return left;
        return height[left] < height[right] ? left : right;
    }

    int lca(int u, int v)
    {
        int left = first[u], right = first[v];
        if (left > right)
            swap(left, right);
        return query(1, 0, euler.size() - 1, left, right);
    }
};
int cnt=0;
void dfs(vector<vi> &adj, vi &vis, vector<int> &depth, vi &par, int node, int dep)
{
    // cerr<<node<<" ";
    // cnt++;cerr<<cnt<<" ";
    vis[node] = 1;
    depth[node] = dep;
    euler.pb(node);
    tin[node] = jj;
    jj++;
    for (auto it : adj[node])
    {
        if (vis[it] == 0)
        {
            par[it] = node;
            dfs(adj, vis, depth, par, it, dep + 1);
        }
    }
    euler.pb(node);
    tout[node] = jj;
    jj++;
}
//-----------------------------------SEGMENT TREE----------------------------------------//
class SegmentTree
{
private:
    vector<int> tree;
    int n;

    // Function to build the segment tree
    void build(const vector<int> &arr, int v, int tl, int tr)
    {
        if (tl == tr)
        {
            tree[v] = arr[tl];
        }
        else
        {
            int tm = (tl + tr) / 2;
            build(arr, v * 2, tl, tm);
            build(arr, v * 2 + 1, tm + 1, tr);
            tree[v] = tree[v * 2] + tree[v * 2 + 1]; // Modify this line based on the query
        }
    }

public:
    SegmentTree(const vector<int> &arr)
    {
        n = arr.size();
        tree.resize(4 * n); // Adjust the size based on the maximum size of your input array
        build(arr, 1, 0, n - 1);
    }

    // Function to update a value at index idx to val
    void update(int idx, int val)
    {
        update(1, 0, n - 1, idx, val);
    }

    // Function to query the range [l, r]
    int query(int l, int r)
    {
        return query(1, 0, n - 1, l, r);
    }

private:
    // Function to update a value at index idx to val in the segment tree
    void update(int v, int tl, int tr, int idx, int val)
    {
        if (tl == tr)
        {
            tree[v] = val;
        }
        else
        {
            int tm = (tl + tr) / 2;
            if (idx <= tm)
            {
                update(v * 2, tl, tm, idx, val);
            }
            else
            {
                update(v * 2 + 1, tm + 1, tr, idx, val);
            }
            tree[v] = tree[v * 2] + tree[v * 2 + 1]; // Modify this line based on the query
        }
    }

    // Function to query the range [l, r] in the segment tree
    int query(int v, int tl, int tr, int l, int r)
    {
        if (l > r)
        {
            return 0; // Modify this line based on the query
        }
        if (l == tl && r == tr)
        {
            return tree[v];
        }
        int tm = (tl + tr) / 2;
        return query(v * 2, tl, tm, l, min(r, tm)) + query(v * 2 + 1, tm + 1, tr, max(l, tm + 1), r);
        // Modify the line above based on the query
    }
};

void solve()
{
    int n, q;
    cin >> n >> q;
    jj = 1;
    tin.resize(n + 2);
    tout.resize(n + 2);
    euler.clear();
    tin[0] = 0;
    tout[0] = 0;
    vector<vi> adj(n + 1);
    for (int i = 0; i < n - 1; i++)
    {
        int a, b;
        cin >> a >> b;
        adj[a].pb(b);
        adj[b].pb(a);
    }
    adj[0].pb(1);
    adj[1].pb(0);
    vi vis(n + 1, 0);
    vi par(n + 1, 0);
    vector<int> depth(n + 1,0);
    dfs(adj, vis, depth, par, 0, 0);
    // cerr<<"hi";
    LCA tree(adj, 0);
    vi eu(euler.size() + 1, 0);
    SegmentTree eul(eu);
    vector<set<vi>> node_depth(n + 2);
    vi rev(n + 1, 0);int p=0;
    auto myLambda = [&](auto ptr,int dep, int add) {
            
            int p1 = 0, p2 = 0, p3 = 0;
            if (ptr != node_depth[dep].begin())
            {
                ptr--;
                p1 = (*ptr)[1];
                ptr++;
            }
            p2 = (*ptr)[1];
            if (ptr != (--node_depth[dep].end()))
            {
                ptr++;
                p3 = (*ptr)[1];
                ptr--;
            }

            int lca1 = tree.lca(p1, p2);
            int lca2 = tree.lca(p2, p3);
            if (depth[lca1] > depth[lca2])
            {   
                int cur_lca = lca1;
                if (cur_lca == p2)
                {
                }
                else
                {
                    p1 = par[p1];
                    p2 = par[p2];

                // cout<<p1<<p2<<"helo"<<nl;
                    // eu[tin[p1]] += (*ptr)[2]*add;
                    eu[tin[p2]] += (*ptr)[2]*add;
                    eu[tin[cur_lca]] -= 1 * (*ptr)[2]*add;
                    eul.update(tin[p1], eu[tin[p1]]);
                    eul.update(tin[p2], eu[tin[p2]]);
                    eul.update(tin[cur_lca], eu[tin[cur_lca]]);
                }
            }
            else
            {
                int cur_lca = lca2;
                if (cur_lca == p2)
                {
                }
                else
                {
                    p3 = par[p3];
                    p2 = par[p2];

                    // eu[tin[p3]] += (*ptr)[2]*add;
                    eu[tin[p2]] += (*ptr)[2]*add;
                    eu[tin[cur_lca]] -= 1 * (*ptr)[2]*add;
                    eul.update(tin[p2], eu[tin[p2]]);
                    eul.update(tin[p3], eu[tin[p3]]);
                    eul.update(tin[cur_lca], eu[tin[cur_lca]]);
                }
            }

    };

    vi person_presence(n+1,0);
    for (int i = 0; i < q; i++)
    {
        int type;
        cin >> type;
        if (type == 1)
        {
            int cq;
            cin >> cq;
            // cerr<<tin[cq]<<" "<<tout[cq];
            cout << eul.query(tin[cq], tout[cq]) << nl;
// cerr<<"hi";
        }
        else
        {
                int node_p, val_p=1;
                cin >> node_p;

            char c;
          if(person_presence[node_p]) {person_presence[node_p]=0; c='-';}
          else{person_presence[node_p]=1;c='+';}
            if (c == '+')
            {
                // cerr<<q<<nl;
                int dep = depth[node_p];

                auto ptr = node_depth[dep].upper_bound({tin[node_p], node_p, val_p, 0});
                if(ptr!=node_depth[dep].end()){myLambda(ptr,dep,-1);}
                if(ptr!=node_depth[dep].begin()){ptr--;myLambda(ptr,dep,-1);ptr++;}

                node_depth[dep].insert({tin[node_p], node_p, val_p, 0});


                ptr=node_depth[dep].upper_bound({tin[node_p], node_p, val_p, 0});
                if(ptr!=node_depth[dep].end()){myLambda(ptr,dep,1);}
                
                // cout<<i<<" "<<eul.query(tin[1], tout[1])<<" "<<eul.query(tin[2], tout[2])<<" "<<eul.query(tin[3], tout[3])<<" "<<eul.query(tin[4], tout[4])<<" "<<eul.query(tin[5], tout[5])<<nl;
                ptr--;myLambda(ptr,dep,1);
                if(ptr!=node_depth[dep].begin()){ptr--;myLambda(ptr,dep,1);ptr++;}

                


            }
            else
            {
                // cerr<<"hi";
                int dep = depth[node_p];
                auto ptr=node_depth[dep].upper_bound({tin[node_p], node_p, val_p, 0});
                if(ptr!=node_depth[dep].end()){myLambda(ptr,dep,-1);}
                ptr--;myLambda(ptr,dep,-1);
                if(ptr!=node_depth[dep].begin()){ptr--;myLambda(ptr,dep,-1);ptr++;}

                node_depth[dep].erase(ptr);

                ptr=node_depth[dep].upper_bound({tin[node_p], node_p, val_p, 0});
                if(ptr!=node_depth[dep].end()){myLambda(ptr,dep,1);}
                if(ptr!=node_depth[dep].begin()){ptr--;myLambda(ptr,dep,1);ptr++;}


            }
        }
    }
}
signed main()
{
    FastIO;
    // freopen("P9_8.in", "r", stdin);
    // freopen("P9_8.out", "w", stdout);
    int test = 1;
    cin >> test;
    while (test--)
    {
        solve();
    }
    return 0;
}
 
Editorialist's code (C++)
// #pragma GCC optimize("O3,unroll-loops")
// #pragma GCC target("avx2,bmi,bmi2,lzcnt,popcnt")
#include "bits/stdc++.h"
using namespace std;
using ll = long long int;
mt19937_64 RNG(chrono::high_resolution_clock::now().time_since_epoch().count());

template<class T>
struct RMQ {
    vector<vector<T>> jmp;
    RMQ(const vector<T>& V) : jmp(1, V) {
        for (int pw = 1, k = 1; pw * 2 <= (int)size(V); pw *= 2, ++k) {
            jmp.emplace_back(size(V) - pw * 2 + 1);
            for (int j = 0; j < (int)size(jmp[k]); ++j)
                jmp[k][j] = min(jmp[k - 1][j], jmp[k - 1][j + pw]);
        }
    }
    T query(int a, int b) {
        assert(a < b); // or return inf if a == b
        int dep = 31 - __builtin_clz(b - a);
        return min(jmp[dep][a], jmp[dep][b - (1 << dep)]);
    }
};

struct LCA {
    int T = 0;
    vector<int> time, path, ret, out, depth;
    RMQ<int> rmq;

    LCA(vector<vector<int>>& C) : time(size(C)), out(size(C)), depth(size(C)), rmq((dfs(C,0,-1), ret)) {}
    void dfs(vector<vector<int>>& C, int v, int par) {
        time[v] = T++;
        for (int y : C[v]) if (y != par) {
            depth[y] = 1 + depth[v];
            path.push_back(v), ret.push_back(time[v]);
            dfs(C, y, v);
        }
        out[v] = T;
    }

    int lca(int a, int b) {
        if (a == b) return a;
        tie(a, b) = minmax(time[a], time[b]);
        return path[rmq.query(a, b)];
    }
    int dist(int a, int b) {
        return depth[a] + depth[b] - 2*depth[lca(a,b)];
    }
};

template<class T, T unit = T()>
struct SegTree {
    T f(T a, T b) { return a+b; }
    vector<T> s; int n;
    SegTree(int _n = 0, T def = unit) : s(2*_n, def), n(_n) {}
    void update(int pos, T val) {
        for (s[pos += n] += val; pos /= 2;)
            s[pos] = f(s[pos * 2], s[pos * 2 + 1]);
    }
    T query(int b, int e) {
        T ra = unit, rb = unit;
        for (b += n, e += n; b < e; b /= 2, e /= 2) {
            if (b % 2) ra = f(ra, s[b++]);
            if (e % 2) rb = f(s[--e], rb);
        }
        return f(ra, rb);
    }
};

int main()
{
    ios::sync_with_stdio(false); cin.tie(0);

    int t; cin >> t;
    while (t--) {
        int n, q; cin >> n >> q;
        vector adj(n+1, vector<int>());
        for (int i = 0; i < n-1; ++i) {
            int u, v; cin >> u >> v;
            adj[u].push_back(v);
            adj[v].push_back(u);
        }
        adj[0].push_back(1);
        adj[1].push_back(0);
        LCA L(adj);
        vector<set<array<int, 2>>> active(n+1);
        SegTree<int> seg(n+1);
        vector<int> mark(n+1);
        auto &in = L.time, &out = L.out, &dep = L.depth;
        vector<int> par(n+1);
        for (int i = 0; i <= n; ++i) for (int u : adj[i])
            if (dep[u] == dep[i]-1) par[i] = u;

        auto upd = [&] (int u, int type) {
            int d = L.depth[u];
            int till = 0;
            active[d].insert({in[u], u});
            auto it = active[d].find({in[u], u});
            if (it != begin(active[d])) {
                auto [pos, v] = *prev(it);
                int l = L.lca(u, v);
                if (dep[l] > dep[till]) till = l;
            }
            if (next(it) != end(active[d])) {
                auto [pos, v] = *next(it);
                int l = L.lca(u, v);
                if (dep[l] > dep[till]) till = l;
            }
            seg.update(in[par[u]], type);
            seg.update(in[till], -type);
        };

        while (q--) {
            int type, u; cin >> type >> u;
            if (type == 1) {
                cout << seg.query(in[u], out[u]) << '\n';
            }
            else {
                int d = L.depth[u];
                auto it1 = active[d].upper_bound({L.time[u], u});
                auto it2 = active[d].lower_bound({L.time[u], u});
                int v1 = -1, v2 = -1;
                if (it1 != end(active[d])) v1 = it1->at(1);
                if (it2 != begin(active[d])) v2 = prev(it2)->at(1);

                if (v1 != -1) upd(v1, -1);
                if (v2 != -1) upd(v2, -1);

                if (mark[u]) {
                    upd(u, -1);
                    active[d].erase({in[u], u});
                }
                else {
                    active[d].insert({in[u], u});
                    upd(u, 1);
                }
                mark[u] ^= 1;

                if (v1 != -1) upd(v1, 1);
                if (v2 != -1) upd(v2, 1);
            }
        }
    }
}

I might be wrong, but I think we should add x to the value of par(v) and we should subtract it from the value of u?

This is quoted from the last spoiler above time complexity.

Yep, you’re right - it seems I accidentally switched up which of u and v was the ancestor.
Fixed now, thanks!