COLORDIS - EDITORIAL

PROBLEM LINK:

Practice

Contest: Division 1

Contest: Division 2

Setter: Pritom Kundu

Tester: Teja Vardhan Reddy

Editorialist: Teja Vardhan Reddy

DIFFICULTY:

Medium Hard

PREREQUISITES:

Square root decomposition, BFS, fast LCA queries

PROBLEM:

Given a colored, labelled tree that initially consists of a single node, process the following queries:

  • Add a new vertex with color c, attached to a vertex u that already exists in the tree.

  • Find the smallest distance between vertex u and any vertex colored c (or report -1 if no such vertex exists).

EXPLANATION:

We classify every vertex as either processed or unprocessed. We also maintain, for every vertex, its 2^k-th ancestors for all k (binary lifting), which lets us answer LCA queries and hence find the distance between two vertices. Let B be a parameter (we will see later how to choose its value). For each vertex, we maintain a vector of pairs (color, distance) that stores minimum distances from that vertex to the processed vertices of each color (let's call this the memory of that vertex).
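The ancestor table described above can be built incrementally, one vertex at a time. Below is a minimal sketch, not the setter's or tester's code: the names `AncestorTable`, `LOG` and `addChild` are hypothetical, vertices are 0-indexed with vertex 0 as the root, and `LOG = 18` assumes at most 2^18 vertices.

```cpp
#include <array>
#include <cassert>
#include <vector>

constexpr int LOG = 18;  // assumption: the tree never exceeds 2^18 vertices

// Hypothetical helper: depth and 2^k-th ancestors for a tree that only grows by leaves.
struct AncestorTable {
    std::vector<int> depth;                // depth[v]; the root has depth 0
    std::vector<std::array<int, LOG>> up;  // up[v][k] = 2^k-th ancestor of v

    // Start with the root (vertex 0), which is treated as its own ancestor.
    AncestorTable() : depth(1, 0), up(1) { up[0].fill(0); }

    // Attach a new leaf under `parent` and return its index. O(LOG) work.
    int addChild(int parent) {
        assert(parent >= 0 && parent < static_cast<int>(depth.size()));
        int v = static_cast<int>(depth.size());
        std::array<int, LOG> row;
        row[0] = parent;
        // The 2^k-th ancestor is the 2^(k-1)-th ancestor of the 2^(k-1)-th ancestor.
        for (int k = 1; k < LOG; ++k) row[k] = up[row[k - 1]][k - 1];
        depth.push_back(depth[parent] + 1);
        up.push_back(row);
        return v;
    }
};
```

Because every ancestor of a new leaf already exists when the leaf is added, each row can be filled from rows that are already complete, which is why the table never needs to be rebuilt.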

Add query:

  • Let the color of the new vertex be c.
  • Mark the new vertex as unprocessed.
  • Build the power-of-2 ancestors of the new vertex.
  • Propagate the distances to processed vertices from the parent to the new vertex, increasing every distance by 1. This is correct because the new vertex is a leaf, so every path from a processed vertex to it must pass through its parent (and the new vertex itself is unprocessed, so it is not one of the sources).
  • If there are now B unprocessed vertices of color c, run a BFS with all of these vertices as sources (alternatively, use all vertices colored c as sources, whether processed or unprocessed) and find the distances to all vertices. For every vertex, add the pair (c, distance obtained) to its memory. Then mark all B unprocessed vertices of color c as processed.
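The last two steps above can be sketched as follows. This is an illustrative sketch under assumptions, not the reference implementation: `ColorMemory`, `inheritMemory`, `multiSourceBfs` and `recordColor` are hypothetical names, and the tree is given as a 0-indexed adjacency list.

```cpp
#include <cassert>
#include <queue>
#include <utility>
#include <vector>

using ColorMemory = std::vector<std::pair<int, int>>;  // (color, distance) pairs

// A new leaf copies its parent's memory with every distance increased by 1.
ColorMemory inheritMemory(const ColorMemory& parentMemory) {
    ColorMemory child = parentMemory;
    for (auto& entry : child) entry.second += 1;
    return child;
}

// Multi-source BFS: dist[v] = distance from v to the nearest source vertex.
std::vector<int> multiSourceBfs(const std::vector<std::vector<int>>& adj,
                                const std::vector<int>& sources) {
    const int unreached = static_cast<int>(adj.size());  // exceeds any tree distance
    std::vector<int> dist(adj.size(), unreached);
    std::queue<int> q;
    for (int s : sources) {
        dist[s] = 0;
        q.push(s);
    }
    while (!q.empty()) {
        int u = q.front();
        q.pop();
        for (int v : adj[u]) {
            if (dist[v] == unreached) {
                dist[v] = dist[u] + 1;
                q.push(v);
            }
        }
    }
    return dist;
}

// Rebuild step: append (color, distance) to the memory of every vertex.
void recordColor(int color, const std::vector<std::vector<int>>& adj,
                 const std::vector<int>& sources, std::vector<ColorMemory>& memory) {
    assert(memory.size() == adj.size());
    std::vector<int> dist = multiSourceBfs(adj, sources);
    for (std::size_t v = 0; v < adj.size(); ++v)
        memory[v].emplace_back(color, dist[v]);
}
```

After `recordColor` runs, the caller would clear its list of unprocessed vertices of that color, which corresponds to marking them processed.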

Answer query:
We need the minimum distance from vertex u to any vertex colored c. We split the solution into 2 cases.

Case 1: All the processed vertices which are colored c.

The shortest distances from all these vertices are already present in the memory: iterate through the memory of u and take the minimum of all distances that are paired with color c.

Case 2: All the unprocessed vertices which are colored c.

Iterate through all the unprocessed vertices of color c and take the minimum of their distances to u. To find the distance between two vertices u and v, we find their LCA: distance(u, v) = depth[u] + depth[v] - 2 * depth[lca(u, v)].
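The distance formula can be illustrated with a small binary-lifting structure. This is a sketch under assumptions: `LiftedTree`, `LOG_LIFT` and `addChild` are hypothetical names, vertices are 0-indexed with vertex 0 as the root, and the root is stored as its own ancestor so that no sentinel value is needed.

```cpp
#include <array>
#include <cassert>
#include <utility>
#include <vector>

constexpr int LOG_LIFT = 18;  // assumption: at most 2^18 vertices

struct LiftedTree {
    std::vector<int> depth;
    std::vector<std::array<int, LOG_LIFT>> up;  // up[v][k] = 2^k-th ancestor of v

    LiftedTree() : depth(1, 0), up(1) { up[0].fill(0); }

    int addChild(int parent) {
        assert(parent >= 0 && parent < static_cast<int>(depth.size()));
        std::array<int, LOG_LIFT> row;
        row[0] = parent;
        for (int k = 1; k < LOG_LIFT; ++k) row[k] = up[row[k - 1]][k - 1];
        depth.push_back(depth[parent] + 1);
        up.push_back(row);
        return static_cast<int>(depth.size()) - 1;
    }

    int lca(int u, int v) const {
        if (depth[u] < depth[v]) std::swap(u, v);
        // Lift the deeper vertex until both are at the same depth.
        int diff = depth[u] - depth[v];
        for (int k = 0; k < LOG_LIFT; ++k)
            if ((diff >> k) & 1) u = up[u][k];
        if (u == v) return u;
        // Lift both vertices as long as their ancestors differ.
        for (int k = LOG_LIFT - 1; k >= 0; --k) {
            if (up[u][k] != up[v][k]) {
                u = up[u][k];
                v = up[v][k];
            }
        }
        return up[u][0];
    }

    int distance(int u, int v) const {
        return depth[u] + depth[v] - 2 * depth[lca(u, v)];
    }
};
```

Each `distance` call costs O(log(n)), which is the source of the log(n) factor in Case 2 of the complexity analysis.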

We take the minimum over both cases as the answer. For the -1 case, we can maintain an array that holds true for every color present in the tree and false for the rest, and use it to check whether any vertex with color c exists.
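Putting the two cases and the -1 check together, a query could look roughly like the sketch below. The function name `answerQuery` and its parameters are hypothetical; the distance routine is passed in as a callable so that the sketch stays independent of how LCA is implemented.

```cpp
#include <algorithm>
#include <cassert>
#include <functional>
#include <limits>
#include <utility>
#include <vector>

// memoryOfU: (color, distance) pairs stored at u (Case 1).
// unprocessedOfColor: unprocessed vertices colored c (Case 2).
// colorPresent: whether any vertex of color c exists in the tree.
int answerQuery(int u, int c,
                const std::vector<std::pair<int, int>>& memoryOfU,
                const std::vector<int>& unprocessedOfColor,
                bool colorPresent,
                const std::function<int(int, int)>& dist) {
    if (!colorPresent) return -1;
    const int inf = std::numeric_limits<int>::max();
    int best = inf;
    // Case 1: processed vertices, whose distances are already in the memory of u.
    for (const auto& entry : memoryOfU)
        if (entry.first == c) best = std::min(best, entry.second);
    // Case 2: unprocessed vertices, handled one by one with the distance routine.
    for (int v : unprocessedOfColor)
        best = std::min(best, dist(u, v));
    // If the color is present, at least one of the two cases found a vertex.
    assert(best != inf);
    return best;
}
```

For example, on a path where the distance between vertices a and b is |a - b|, a query from u = 5 with memory entries (2, 4) and (2, 6) and unprocessed vertices 9 and 7 of color 2 takes the minimum of 4, 6, 4 and 2.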

TIME COMPLEXITY:

Add query:

  • Building the power-of-2 ancestors of a new vertex takes O(log(n)) time. This happens once per vertex, so the total contribution of this step is O(n log(n)).
  • Propagating memory from parent to child takes O(size(memory of a vertex)) time. The total contribution of this step is O(n * size(memory of a vertex)).
  • When there are B unprocessed vertices of color c, the BFS step takes O(n) time. Every vertex starts out unprocessed, and each time this step runs exactly B unprocessed vertices become processed, so it can run at most O(n/B) times. Hence its contribution to the total time complexity is O(n^2/B). The size of each vertex's memory also grows by 1 each time this step runs, so O(size(memory of a vertex)) = O(n/B).

Answer query:
Case 1: This step takes O(size(memory of a vertex)) = O(n/B). Across all queries it takes O(q*n/B).

Case 2: The number of unprocessed vertices of color c is at most B (because we process them as soon as B is reached). Hence this step takes O(B*log(n)) per query, and O(q*B*log(n)) across all queries.

Every vertex other than the initial one is created by a query, so n is at most q + 1. Let us therefore assume n = q to estimate the worst-case complexity.

Total time complexity: O(n log(n)) + O(n^2/B) + O(n^2/B) + O(q*n/B) + O(q*B*log(n))

= O(n log(n) + n^2/B + n*B*log(n)) (assuming q = n).

Now we need to choose the value of the parameter B that minimises the above expression. Only the terms n^2/B and n*B*log(n) depend on B, so we differentiate their sum with respect to B and set the result to zero: $-n^2/B^2 + n\log(n) = 0$.
This gives $B = \sqrt{n/\log(n)}$.

Hence the total complexity is $O(n\log(n) + n\sqrt{n\log(n)}) = O(n\sqrt{n\log(n)})$ for $B = O(\sqrt{n/\log(n)})$.
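To get a feel for the numbers, the formula can be evaluated directly. The helper below is a hypothetical sketch (`chooseBlockSize` is not part of any reference solution) and assumes the logarithm is base 2, matching the binary-lifting table.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>

// Theoretical block size B = sqrt(n / log2(n)), rounded and clamped to at least 1.
int chooseBlockSize(int n) {
    assert(n >= 1);
    if (n < 2) return 1;  // log2(1) = 0, so avoid dividing by zero
    double b = std::sqrt(n / std::log2(static_cast<double>(n)));
    return std::max(1, static_cast<int>(std::lround(b)));
}
```

For n = 200000 this evaluates to about 107. The solutions below use fixed thresholds of 300 and 350 instead, which suggests that in practice the constant is tuned by hand, since the analysis ignores constant factors.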

SOLUTIONS:

Setter's Solution
#include<bits/stdc++.h>
using namespace std;
 
const int N = 2e5+7, K = 20;
const int RT = 300; // block size B: pending vertices of one color that trigger a BFS rebuild
 
int n;
vector<int> adj[N];
vector<int> group[N];
vector<int> rem[N];
int col[N], par[N], level[N];
int anc[N][K];
vector<int> dis[N];
vector<int> big;
 
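// Multi-source BFS from every vertex of color c. Afterwards dis[c][v] is the
// distance from v to the nearest such vertex, and the pending list rem[c] is empty.
void recalc(int c) {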
    if (dis[c].empty())     big.push_back(c);
    dis[c].assign(n+1, N);
 
    queue<int> q;
    for (int x: group[c]) {
        q.push(x);
        dis[c][x] = 0;
    }
 
    while (q.size()) {
        int u = q.front();
        q.pop();
 
        for (int v: adj[u])
            if (dis[c][v] == N) {
                dis[c][v] = dis[c][u] + 1;
                q.push(v);
            }
    }
    rem[c].clear();
}
 
void add(int p, int c) {
    n++;
    col[n] = c;
    par[n] = p;
 
    adj[p].push_back(n);
    adj[n].push_back(p);
 
    level[n] = 1 + level[p];
    anc[n][0] = p;
    for (int k=1; k<K; k++)
        anc[n][k] = anc[anc[n][k-1]][k-1];
 
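    // Extend every precomputed distance array: the new leaf is one edge
    // farther from each processed vertex than its parent is.
    for (int x: big)
        dis[x].push_back(1+dis[x][p]);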
 
    group[c].push_back(n);
    rem[c].push_back(n);
    if (rem[c].size() >= RT)
        recalc(c);
}
 
int lca(int u, int v) {
    if(level[u] > level[v]) swap(u, v);
    for (int k=K-1; k>=0; k--) {
        if (level[u]+(1<<k) <= level[v])
            v = anc[v][k];
    }
    assert(level[u] == level[v]);
    if (u==v)   return u;
 
    for (int k=K-1; k>=0; k--) {
        if (anc[u][k] != anc[v][k]) {
            u = anc[u][k];
            v = anc[v][k];
        }
    }
    assert(par[u] == par[v]);
    return par[u];
}
 
int distance(int u, int v) {
    return level[u] + level[v] - 2*level[lca(u, v)];
}
 
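// Combine the precomputed BFS distance (if any) with a direct LCA-based
// distance to each pending vertex of color c; N acts as "infinity".
int query(int u, int c) {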
    int ans = (dis[c].size() ? dis[c][u] : N);
    for (int x: rem[c])
        ans = min(ans, distance(u, x));
    if (ans == N)   ans = -1;
    return ans;
}
 
void clear() {
    for (int i=0; i<N; i++) {
        adj[i].clear();
        group[i].clear();
        dis[i].clear();
        rem[i].clear();
    }
    big.clear();
    n=1;
}
 
int main() {
    ios::sync_with_stdio(0);
    cin.tie(0);
 
    int t;
    cin>>t;
    while (t--) {
        clear();
        int q, c;
        cin>>q>>c;
        col[1] = c;
        group[c].push_back(1);
        rem[c].push_back(1);
        
        int lastans = -1;
        
        while (q--) {
            char ch;
            int u, c;
            cin>>ch>>u>>c;
            
            u^=lastans+1;
            c^=lastans+1;
 
            if (ch == '?')  cout<<(lastans = query(u, c))<<"\n";
            else            add(u, c);
        }
    }
}
Tester's Solution
//teja349
#include <bits/stdc++.h>
#include <vector>
#include <set>
#include <map>
#include <string>
#include <cstdio>
#include <cstdlib>
#include <climits>
#include <utility>
#include <algorithm>
#include <cmath>
#include <queue>
#include <stack>
#include <iomanip>
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp> 
//setbase - cout << setbase (16); cout << 100 << endl; Prints 64
//setfill -   cout << setfill ('x') << setw (5); cout << 77 << endl; prints xxx77
//setprecision - cout << setprecision (14) << f << endl; Prints x.xxxx
//cout.precision(x)  cout<<fixed<<val;  // prints x digits after decimal in val
 
using namespace std;
using namespace __gnu_pbds;
 
#define f(i,a,b) for(i=a;i<b;i++)
#define rep(i,n) f(i,0,n)
#define fd(i,a,b) for(i=a;i>=b;i--)
#define pb push_back
#define mp make_pair
#define vi vector< int >
#define vl vector< ll >
#define ss second
#define ff first
#define ll long long
#define pii pair< int,int >
#define pll pair< ll,ll >
#define sz(a) a.size()
#define inf (1000*1000*1000+5)
#define all(a) a.begin(),a.end()
#define tri pair<int,pii>
#define vii vector<pii>
#define vll vector<pll>
#define viii vector<tri>
#define mod (1000*1000*1000+7)
#define pqueue priority_queue< int >
#define pdqueue priority_queue< int,vi ,greater< int > >
#define flush fflush(stdout) 
#define primeDEN 727999983
mt19937 rng(chrono::steady_clock::now().time_since_epoch().count());
 
// find_by_order()  // order_of_key
typedef tree<
int,
null_type,
less<int>,
rb_tree_tag,
tree_order_statistics_node_update>
ordered_set;
 
int BLOCKSIZE = 350; // block size B: unprocessed vertices of one color that trigger a BFS
int n=1;
int par[412345][20],c[412345],dep[412345];
vector<vii> remem(412345);
vector<vi> vec(412345),adj(412345);
 
int pushdown(int node){
    remem[node].clear();
    int pp=par[node][0];
    pii papa;
    int i;
    rep(i,remem[pp].size()){
        papa=remem[pp][i];
        papa.ss++;
        remem[node].pb(papa);
    }
    return 0;
}
int dis[412345],que[412345];
 
int doupdate(int col){
    int i;
    rep(i,n){
        dis[i]=inf;
    }
    int st=0,en=0,cur;
    rep(i,vec[col].size()){
        que[en++]=vec[col][i];
        dis[vec[col][i]]=0;
    }
 
    while(st!=en){
        cur=que[st];
        st++;
        rep(i,adj[cur].size()){
            if(dis[adj[cur][i]]==inf){
                dis[adj[cur][i]]=dis[cur]+1;
                que[en++]=adj[cur][i];
            }
        }
    }
    rep(i,n){
        remem[i].pb(mp(col,dis[i]));
    }
    vec[col].clear();
    return 0;
}
int boo[412345];
int getlca(int u,int v){
    if(dep[u]>dep[v]){
        swap(u,v);
    }
    int i;
    fd(i,19,0){
        if(dep[v]-(1<<i)>=dep[u])
            v=par[v][i];
    }
    if(u==v)
        return u;
    fd(i,19,0){
        if(par[u][i]!=par[v][i]){
            u=par[u][i];
            v=par[v][i];
        }
    }
    return par[u][0];
}
int getdist(int u,int v){
    return dep[u]+dep[v]-2*dep[getlca(u,v)];
}
int main(){
    std::ios::sync_with_stdio(false); cin.tie(NULL);
    int t;
    cin>>t;
    int hh;
    rep(hh,412345){
        boo[hh]=inf;
    }
    while(t--){
        int q,val;
        cin>>q>>val;
        int col,u;
        int i,j;
        c[0]=val;
        vec[c[0]].clear();
        adj[0].clear();
        boo[c[0]]=t;
        dep[0]=0;
        vec[c[0]].pb(0);
        n=1;
        int lastans=-1;
        rep(i,20){
            par[0][i]=-1;
        }
        string ch;
        rep(i,q){
            cin>>ch>>u>>col;
            u^=lastans+1;
            col^=lastans+1;
            u--;
            if(ch=="+"){
                if(boo[col]!=t){
                    vec[col].clear();
                    boo[col]=t;
                }
                dep[n]=dep[u]+1;
                adj[n].clear();
                c[n]=col;
                adj[u].pb(n);
                adj[n].pb(u);
                //cout<<n<<" "<<u<<endl;
                par[n][0]=u;
                f(j,1,20){
                    if(par[n][j-1]!=-1)
                        par[n][j]=par[par[n][j-1]][j-1];
                    else
                        par[n][j]=-1;
                }
                vec[col].pb(n);
                pushdown(n);
                n++;
                if(vec[col].size()==BLOCKSIZE){
                    doupdate(col);
                }
            }
            else{
                int mini=inf;
                if(boo[col]!=t){
                    lastans=-1;
                    cout<<lastans<<endl;
                    continue;
                }
                rep(j,remem[u].size()){
                    if(remem[u][j].ff==col){
                        mini=min(mini,remem[u][j].ss);
                    }
                }
                rep(j,vec[col].size()){
                    mini=min(getdist(u,vec[col][j]),mini);
                }
                if(mini==inf)
                    mini=-1;
                cout<<mini<<endl;
                lastans=mini;
            }
        }
        rep(i,n+10){
            remem[i].clear();
        }       
    }
    cerr << "\nTime elapsed: " << 1000 * clock() / CLOCKS_PER_SEC << "ms\n";
    return 0;
}
Optimised Solution
#include<bits/stdc++.h>
using namespace std;
const int MAXN = 200007;
const int LOGN = 18;
 
int lvl[MAXN], pr[MAXN][LOGN];
 
int lca(int u, int v)
{
    if (lvl[u] < lvl[v]) swap(u, v);
    for (int k = LOGN-1; k >= 0; k--) {
        if (lvl[u]-(1<<k) >= lvl[v]) {
            u = pr[u][k];
        }
    }
    if (u==v) return u;
    assert(lvl[u]==lvl[v]);
    for (int k = LOGN-1; k >= 0; k--) {
        if (pr[u][k] != pr[v][k]) {
            u = pr[u][k];
            v = pr[v][k];
        }
    }
    u = pr[u][0];
    v = pr[v][0];
    assert(u==v);
    return u;
}
 
int distance(int u, int v)
{
    return lvl[u]+lvl[v]-2*lvl[lca(u, v)];
}
 
vector<int>edg[MAXN];
vector<int>occur[MAXN];
int color[MAXN];
 
const int TOLERANCE = 350; // block size B: pending vertices of one color tolerated before a BFS
vector<int>dst[MAXN];
set<int>calculated;
int n;
 
void bfs(int c)
{
    if (dst[c].empty()) {
        dst[c].assign(n+1, MAXN);
        calculated.insert(c);
    }
 
    queue<int>q;
    for (int u : occur[c]) {
        q.push(u);
        dst[c][u] = 0;
    }
    occur[c].clear();
 
    while (!q.empty()) {
        int u = q.front(); q.pop();
        for (int v : edg[u]) {
            if (dst[c][v] > dst[c][u]+1) {
                dst[c][v] = dst[c][u]+1;
                q.push(v);
            }
        }
    }
}
 
void addNode(int u, int p, int c)
{
    edg[u].clear();
    color[u] = c;
    occur[c].push_back(u);
    lvl[u] = lvl[p]+1;
 
    if (p > 0) {
        edg[u].push_back(p);
        edg[p].push_back(u);
    }
 
    for (int x : calculated) {
        dst[x].push_back(dst[x][p]+1);
    }
 
    pr[u][0] = p;
    for (int k = 1; k < LOGN; k++) {
        pr[u][k] = pr[pr[u][k-1]][k-1];
    }
}
 
int query(int u, int c)
{
    if (occur[c].empty() && dst[c].empty()) return -1;
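    // Lazy rebuild: the BFS runs only when a query hits a color with more than
    // TOLERANCE pending vertices, and it relaxes the existing dst[c] in place.
    if (occur[c].size() > TOLERANCE) bfs(c);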
    int ans = MAXN;
    for (int v : occur[c]) ans = min(ans, distance(u, v));
    if (!dst[c].empty()) ans = min(ans, dst[c][u]);
 
    return ans;
}
 
int main()
{
    ios::sync_with_stdio(false);
    cin.tie(0);
 
    int t;
    cin >> t;
 
    for (int ti = 1; ti <= t; ti++) {
        int q, c;
        cin >> q >> c;
 
        set<int>cl;
        addNode(1, 0, c);
        n = 1;
 
        cl.insert(c);
        int last = -1;
        while (q--) {
            char cmd;
            int u, c;
            cin >> cmd >> u >> c;
 
            u ^= last+1;
            c ^= last+1;
            if (cmd=='+') {
                ++n;
                addNode(n, u, c);
                cl.insert(c);
            } else {
                last = query(u, c);
                cout << last << "\n";
            }
        }
 
        for (int c : cl) {
            occur[c].clear();
            dst[c].clear();
        }
        calculated.clear();
    }
 
 
    return 0;
}

Feel free to share your approach if it differs. Suggestions are always welcome.