SMQRY - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: kingmessi
Tester & Editorialist: iceknight1093

DIFFICULTY:

3229

PREREQUISITES:

Binary lifting, small-to-large merging

PROBLEM:

There are N vertices, initially all disconnected.
Process the following online:

  • Given vertices u and v, add an edge between u and v.
    The resulting graph is guaranteed to be a forest.
  • Find the sum of diameters of all components of this forest.

At the end, output the XOR of the answers to all updates.

EXPLANATION:

Our first step is to figure out what information we need to keep.

The quantity we need to maintain is the sum of diameters of all components.
Each time we merge two components, this quantity changes in a simple way: the diameters of the two merged components are subtracted from it, and the diameter of the new component is added to it.
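The bookkeeping described above can be sketched as a one-line update. The helper name mergedSum and its parameters are hypothetical and only serve as an illustration; they do not appear in either solution below.

```cpp
#include <cassert>

// Hypothetical helper: returns the new sum of diameters after two components
// with diameters dA and dB are merged into one component with diameter dNew.
long long mergedSum(long long total, int dA, int dB, int dNew) {
    // Merging never shrinks a diameter: both old diameters survive as paths.
    assert(dNew >= dA && dNew >= dB);
    return total - dA - dB + dNew;
}
```

For instance, joining two isolated vertices (both diameters 0) by an edge gives a component of diameter 1, so the sum grows by exactly 1.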

So, we need to be able to do the following:

  • Given a component, know its diameter,
  • Given two components and an edge joining them, quickly find the diameter of the new component.

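The first part follows from the second: every component starts as a single vertex with diameter 0, so if each merge gives us the diameter of the new component, we can simply store it alongside the component.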

To do the second part quickly, we can use the following (well-known) observation:
Let x and y be the endpoints of a diameter of a tree. Then, for any vertex u in the tree, one farthest vertex from it is either x or y.
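To build some confidence in this observation, here is a small brute-force check on a single tree. It is only a sanity test under the assumption that the input is a connected, unweighted tree with 0-indexed vertices; the function names are made up for this sketch.

```cpp
#include <algorithm>
#include <queue>
#include <vector>

// BFS distances from src in an unweighted tree given as adjacency lists.
std::vector<int> bfsDist(const std::vector<std::vector<int>>& adj, int src) {
    std::vector<int> d(adj.size(), -1);
    std::queue<int> q;
    d[src] = 0;
    q.push(src);
    while (!q.empty()) {
        int u = q.front();
        q.pop();
        for (int v : adj[u])
            if (d[v] < 0) { d[v] = d[u] + 1; q.push(v); }
    }
    return d;
}

// Returns true if, for every vertex u, the distance to its farthest vertex
// equals max(dist(u, x), dist(u, y)) for one diameter (x, y) of the tree.
bool farthestIsDiameterEndpoint(const std::vector<std::vector<int>>& adj) {
    int n = adj.size();
    std::vector<std::vector<int>> d(n);
    for (int u = 0; u < n; ++u) d[u] = bfsDist(adj, u);

    // Pick any pair at maximum distance as the diameter (x, y).
    int x = 0, y = 0;
    for (int a = 0; a < n; ++a)
        for (int b = 0; b < n; ++b)
            if (d[a][b] > d[x][y]) { x = a; y = b; }

    for (int u = 0; u < n; ++u) {
        int far = *std::max_element(d[u].begin(), d[u].end());
        if (far != std::max(d[u][x], d[u][y])) return false;
    }
    return true;
}
```

This runs in quadratic time, so it is useful only for checking small random trees, not as part of the solution.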

In particular, this means the following:
Consider two trees T_1 and T_2, with diameters (x_1, y_1) and (x_2, y_2) respectively.
Suppose T_1 is joined to T_2 via an edge to form a new tree T_3.
Then, there will always exist a diameter of T_3 that has two among \{x_1, x_2, y_1, y_2\} as its endpoints.

Proof

There are three cases here.

  • If the diameter lies entirely in T_1, we can choose (x_1, y_1).
  • If the diameter lies entirely in T_2, we can choose (x_2, y_2).
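  • Otherwise, the diameter contains the newly added edge, say u\leftrightarrow v (where u\in T_1 and v\in T_2).
    In this case, the path splits into a part inside T_1 ending at u, the edge itself, and a part inside T_2 starting at v. The best we can do is take the farthest vertex from u within T_1 (which is either x_1 or y_1) and the farthest vertex from v within T_2 (which is either x_2 or y_2), so we’re done!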
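The claim translates into a short merge routine: try all 6 pairs among the four old endpoints and keep the farthest one. The sketch below assumes a distance oracle for the joined tree T_3 is passed in as a callback; the names Diameter and mergeDiameters are illustrative, not taken from the solutions below.

```cpp
#include <array>
#include <functional>

struct Diameter { int x, y, len; };

// Given diameter endpoints (x1, y1) of T_1 and (x2, y2) of T_2, and a
// distance oracle for the joined tree T_3, try all 6 candidate pairs.
Diameter mergeDiameters(int x1, int y1, int x2, int y2,
                        const std::function<int(int, int)>& dist) {
    std::array<int, 4> cand = {x1, y1, x2, y2};
    Diameter best{x1, y1, -1};
    for (int i = 0; i < 4; ++i)
        for (int j = i + 1; j < 4; ++j) {
            int d = dist(cand[i], cand[j]);
            if (d > best.len) best = {cand[i], cand[j], d};
        }
    return best;
}
```

As a quick example, take the path 0, 1, ..., 9 where the distance between a and b is |a - b|. Joining the piece 0..3 (diameter endpoints 0 and 3) to the piece 4..9 (endpoints 4 and 9) yields the pair (0, 9) with length 9.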

In particular, this means that as long as we can compute distances between vertices quickly, merging can be done fast.

Unfortunately, most of the standard methods of quickly computing distances between vertices don’t work very well online because they rely on the structure of the tree being known beforehand (link-cut trees work, though aren’t the intended solution).

Let’s look at one standard method, and what goes wrong with it: binary lifting.
Normally, we’d perform a DFS on the tree, then build the binary lifting table for each vertex.
Of course, this fails because we can’t perform a DFS on the tree when we don’t know its structure.

A natural modification of this to work online would be to attempt to rebuild the table whenever we add an edge: when adding an edge between u and v, treat v as a child of u, then rebuild the lifting table for all the vertices in the component of v.

In fact, this is enough!
In particular, we can perform this recomputation in a small-to-large fashion: recompute the lifting table for whichever tree is smaller.
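This leads to a total complexity of \mathcal{O}(N\log^2 N): each vertex is part of a lifting computation at most \mathcal{O}(\log N) times (its component at least doubles in size every time its table is rebuilt), and each such recomputation costs \mathcal{O}(\log N) for that vertex.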
This is enough to get AC.
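A minimal sketch of this idea follows, assuming 0-indexed vertices and fewer than 2^18 of them. The struct OnlineLifting and its members are hypothetical names for this illustration. Instead of a DSU, it relabels the smaller component during the rebuild, which is a design choice of this sketch and not what the solutions below do.

```cpp
#include <array>
#include <utility>
#include <vector>

struct OnlineLifting {
    static constexpr int LOG = 18;  // assumption: fewer than 2^18 vertices
    std::vector<std::vector<int>> adj;
    std::vector<std::array<int, LOG>> up;  // up[v][k]: 2^k-th ancestor, roots map to themselves
    std::vector<int> dep, comp, compSize;  // comp[v]: label of v's component

    explicit OnlineLifting(int n)
        : adj(n), up(n), dep(n, 0), comp(n), compSize(n, 1) {
        for (int v = 0; v < n; ++v) { comp[v] = v; up[v].fill(v); }
    }

    // Adds edge u-v; u and v must lie in different components.
    void link(int u, int v) {
        if (compSize[comp[u]] < compSize[comp[v]]) std::swap(u, v);
        int label = comp[u];
        compSize[label] += compSize[comp[v]];
        // Re-hang the smaller component below u with an explicit stack.
        // A parent is always processed before its children.
        std::vector<std::pair<int, int>> stack = {{v, u}};  // (vertex, parent)
        while (!stack.empty()) {
            auto [w, p] = stack.back();
            stack.pop_back();
            comp[w] = label;
            dep[w] = dep[p] + 1;
            up[w][0] = p;
            for (int k = 1; k < LOG; ++k) up[w][k] = up[up[w][k - 1]][k - 1];
            for (int x : adj[w])
                if (x != p) stack.push_back({x, w});
        }
        adj[u].push_back(v);
        adj[v].push_back(u);
    }

    // Distance between a and b; both must lie in the same component.
    int dist(int a, int b) const {
        int total = dep[a] + dep[b];
        if (dep[a] < dep[b]) std::swap(a, b);
        int diff = dep[a] - dep[b];
        for (int k = 0; k < LOG; ++k)
            if ((diff >> k) & 1) a = up[a][k];
        if (a != b) {
            for (int k = LOG - 1; k >= 0; --k)
                if (up[a][k] != up[b][k]) { a = up[a][k]; b = up[b][k]; }
            a = up[a][0];  // a is now the LCA
        }
        return total - 2 * dep[a];
    }
};
```

Usage is just a sequence of link calls followed by dist queries, for example linking 0-1, 2-3 and then 1-2 builds the path 0-1-2-3, after which dist(0, 3) is 3. The explicit stack avoids deep recursion on long paths.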

Putting everything together, our solution is as follows:

  • Maintain the components, and also the endpoints of a diameter for each component.
  • When merging components, recompute the lifting table for the smaller one.
  • Then, use binary lifting to compute distances between all 6 pairs among the four candidate endpoints, and take the farthest pair as the diameter of the new component.

TIME COMPLEXITY:

\mathcal{O}(N \log^2 N) per testcase.
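For reference, the bound can be written out as follows. Here Q denotes the number of updates (notation assumed for this derivation); since the graph stays a forest, every update joins two different components, so Q \le N - 1. A vertex has its table rebuilt only when it is on the smaller side of a merge, which at least doubles the size of its component, so this happens at most \log_2 N times.

```latex
\underbrace{\sum_{v=1}^{N} (\text{rebuilds of } v) \cdot \mathcal{O}(\log N)}_{\text{lifting tables}}
\;+\; \underbrace{Q \cdot 6 \cdot \mathcal{O}(\log N)}_{\text{distance queries}}
\;\le\; N \cdot \log_2 N \cdot \mathcal{O}(\log N) + \mathcal{O}(N \log N)
\;=\; \mathcal{O}(N \log^2 N)
```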

CODE:

Author's code (C++)
#include<bits/stdc++.h>
#include <ext/pb_ds/assoc_container.hpp> // Common file
#include <ext/pb_ds/tree_policy.hpp>
#define ll long long
// #define int long long
#define rep(i,a,b) for(int i=a;i<b;i++)
#define rrep(i,a,b) for(int i=a;i>=b;i--)
#define repin rep(i,0,n)
#define di(a) int a;cin>>a;
#define precise(i) cout<<fixed<<setprecision(i)
#define vi vector<int>
#define si set<int>
#define mii map<int,int>
#define take(a,n) for(int j=0;j<n;j++) cin>>a[j];
#define give(a,n) for(int j=0;j<n;j++) cout<<a[j]<<' ';
#define vpii vector<pair<int,int>>
#define sis string s;
#define sin string s;cin>>s;
#define db double
#define be(x) x.begin(),x.end()
#define pii pair<int,int>
#define pb push_back
#define pob pop_back
#define ff first
#define ss second
#define lb lower_bound
#define ub upper_bound
#define bpc(x) __builtin_popcountll(x) 
#define btz(x) __builtin_ctz(x)
using namespace std;

using namespace __gnu_pbds;

typedef tree<int, null_type, less<int>, rb_tree_tag,tree_order_statistics_node_update> ordered_set;
typedef tree<pair<int, int>, null_type,less<pair<int, int> >, rb_tree_tag,tree_order_statistics_node_update> ordered_multiset;

const long long INF=1e18;
const long long M=1e9+7;
const long long MM=998244353;
  
int power( int N, int M){
    int power = N, sum = 1;
    if(N == 0) sum = 0;
    while(M > 0){if((M & 1) == 1){sum *= power;}
    power = power * power;M = M >> 1;}
    return sum;
}



const int N = 100005;

int depth[N];
int visited[N];
int up[N][20];
vi adj[N];
bool rndomvar = 0;

void dfs(int v) {
    if(rndomvar==false){
        rndomvar = true;
        rep(i,0,N){
            rep(j,0,20){
                up[i][j]=-1;
            }
        }
    }
    visited[v]=true;
    rep(i,1,20){
        if(up[v][i-1]!=-1)up[v][i] = up[up[v][i-1]][i-1];
        if(up[v][i-1]==-1)up[v][i] = -1;
    }

    for(int x : adj[v]) {
        if(!visited[x]) {
            depth[x] = depth[up[x][0] = v]+1;
            dfs(x);
        }
    }
}

int jump(int x, int d) {
    rep(i,0,20){
        if((d >> i) & 1)
            {if(x==-1)break;
            x = up[x][i];}
    }
    return x;
}

int LCA(int a, int b) {
    if(depth[a] < depth[b]) swap(a, b);

    a = jump(a, depth[a] - depth[b]);
    if(a == b) return a;

    rrep(i,19,0){
        int aT = up[a][i], bT = up[b][i];
        if(aT != bT) a = aT, b = bT;
    }

    return up[a][0];
}

int dist(int a,int b) {
    return depth[a]+depth[b]-2*depth[LCA(a,b)];
}

void rdfs(int cur,int par){
    visited[cur]=0;depth[cur]=0;
    for(auto x : adj[cur]){
        if(x!=par)rdfs(x,cur);
    }
}

int parent[N];
int siz[N];
int d1[N],d2[N],d[N];
int ans;

void make_set(int v) {
    parent[v] = v;
    siz[v] = 1;
    d1[v]=v;d2[v]=v;d[v]=0;
}

int find_set(int v) {
    if (v == parent[v])
        return v;
    return parent[v] = find_set(parent[v]);
}

void union_sets(int a, int b) {
    int na = a,nb = b;
    a = find_set(a);
    b = find_set(b);
    if (a != b) {
        if (siz[a] < siz[b]){
            swap(a,b);swap(na,nb);
        }
        rdfs(nb,0);
        up[nb][0]=na;
        depth[nb]=depth[na]+1;
        dfs(nb);
        adj[na].pb(nb);
        adj[nb].pb(na);
        int prev = d[a]+d[b];
        int cur = max(d[a],d[b]);
        int e1=d1[a],e2=d2[a];        
        if(d[b]>d[a]){e1=d1[b],e2=d2[b];}
        // cout<<a<<" "<<b<<" "<<na<<" "<<nb<<" "<<"hi"<<endl;
        // cout<<d1[a]<<" "<<d2[a]<<" "<<d1[b]<<" "<<d2[b]<<endl;
        // cout<<dist(d1[a],na)<<" "<<dist(d2[a],na)<<" "<<dist(d1[b],nb)<<" "<<dist(d2[b],nb)<<endl;
        if(dist(d1[a],na)+1+dist(d1[b],nb)>cur){
            cur=dist(d1[a],na)+1+dist(d1[b],nb);
            e1=d1[a],e2=d1[b];
        }    
        if(dist(d1[a],na)+1+dist(d2[b],nb)>cur){
            cur=dist(d1[a],na)+1+dist(d2[b],nb);
            e1=d1[a],e2=d2[b];
        } 
        if(dist(d2[a],na)+1+dist(d1[b],nb)>cur){
            cur=dist(d2[a],na)+1+dist(d1[b],nb);
            e1=d2[a],e2=d1[b];
        } 
        if(dist(d2[a],na)+1+dist(d2[b],nb)>cur){
            cur=dist(d2[a],na)+1+dist(d2[b],nb);
            e1=d2[a],e2=d2[b];
        } 
        d1[a]=e1;d2[a]=e2;
        d[a]=cur;
        // cout<<cur<<" "<<prev<<"\n";
        ans+=cur-prev;
        parent[b] = a;
        siz[a] += siz[b];
    }
}
int minn = INT_MAX;
int maxn = INT_MIN;
int sumn = 0;


void solve()
{
    di(n) di(q)
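    // Chained comparisons such as 1<=q<=100000 are always true in C++; check both bounds explicitly.
    assert(1<=q && q<=100000);
    assert(1<=n && n<=100000);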
    // cout<<n<<" "<<q<<endl;
    if(q>=n || q<=0){cout<<"-1"<<endl;}
    sumn+=n;
    // cout<<n<<" "<<q<<endl
    // cout<<n<<" "<<q<<endl;
    rep(i,1,n+1)make_set(i);
    // rep(i,0,q){
    //  di(u) di(v)
    //  union_sets(u,v);
    //  cout<<ans<<endl;
    // }
    int cnt = 0;
    int ax=0;
    while(q--){
        // di(op)
        // cout<<op<<endl;
        // if(op==1){
            di(u) di(v)
            assert(0<=u && u<=2*n);
            assert(0<=v && v<=2*n);
            u=u^ax;v=v^ax;
            // cout<<u<<" "<<v<<endl;
            if(find_set(u)==find_set(v)){cout<<"-1"<<endl;}
            union_sets(u,v);
            ax^=ans;
            // cout<<ans<<endl;
            // cout<<u<<" "<<v<<endl;
        // }
        // else{
        //     cout<<ans<<endl;
        // }
    }
    cout<<ax<<endl;
    // int x;cin>>x;
    // if(x==-1)return;
    ans=0;
    rep(i,1,n+1){
        adj[i].clear();
        visited[i]=0;
        depth[i]=0;
        rep(j,0,20){
            up[i][j]=-1;
        }
    }
}

signed main(){
    ios_base::sync_with_stdio(0);
    // freopen("temp.txt","r",stdin);
    // freopen("AA.in","w",stdout);
    cin.tie(0);
    cout.tie(0);
    #ifdef NCR
        init();
    #endif
    #ifdef SIEVE
        sieve();
    #endif
        di(t)
        if(t>100 || t<0){cout<<"-1"<<endl;}
        // cout<<t<<endl;
        while(t--)
        solve();
    if(minn<=1){cout<<"-1"<<endl;}
    if(maxn>100000){cout<<"-1"<<endl;}
    if(sumn>100000){cout<<"-1"<<endl;}
    // cout<<1<<endl;
    assert(sumn<=100000);
    return 0;
}
Editorialist's code (C++)
// #pragma GCC optimize("O3,unroll-loops")
// #pragma GCC target("avx2,bmi,bmi2,lzcnt,popcnt")
#include "bits/stdc++.h"
using namespace std;
using ll = long long int;
mt19937_64 rng(chrono::high_resolution_clock::now().time_since_epoch().count());

struct DSU {
private:
    std::vector<int> parent_or_size;
public:
    std::vector<array<int, 2>> endpoints;
    DSU(int n = 1): parent_or_size(n, -1), endpoints(n) {
        for (int i = 0; i < n; ++i) endpoints[i] = {i, i};
    }
    int get_root(int u) {
        if (parent_or_size[u] < 0) return u;
        return parent_or_size[u] = get_root(parent_or_size[u]);
    }
    int size(int u) { return -parent_or_size[get_root(u)]; }
    bool same_set(int u, int v) {return get_root(u) == get_root(v); }
    bool merge(int u, int v) {
        u = get_root(u), v = get_root(v);
        if (u == v) return false;
        if (parent_or_size[u] > parent_or_size[v]) std::swap(u, v);
        parent_or_size[u] += parent_or_size[v];
        parent_or_size[v] = u;
        return true;
    }
    std::vector<std::vector<int>> group_up() {
        int n = parent_or_size.size();
        std::vector<std::vector<int>> groups(n);
        for (int i = 0; i < n; ++i) {
            groups[get_root(i)].push_back(i);
        }
        groups.erase(std::remove_if(groups.begin(), groups.end(), [&](auto &s) { return s.empty(); }), groups.end());
        return groups;
    }
};

int main()
{
    int t; cin >> t;
    while (t--) {
        int n, q; cin >> n >> q;
        
        const int LOG = 18;
        vector<array<int, LOG>> anc(n);
        for (int i = 0; i < n; ++i) for (int j = 0; j < LOG; ++j)
            anc[i][j] = i;
        vector adj(n, vector<int>());
        vector<int> dep(n);

        auto dist = [&] (int x, int y) {
            int ret = dep[x] + dep[y];

            int diff = abs(dep[x] - dep[y]);
            for (int i = 0; i < LOG; ++i) {
                if (diff & (1 << i)) {
                    if (dep[x] > dep[y]) x = anc[x][i];
                    else y = anc[y][i];
                }
            }
            assert(dep[x] == dep[y]);
            if (x == y) return ret - 2*dep[x];

            for (int i = LOG-1; i >= 0; --i) {
                if (anc[x][i] == anc[y][i]) continue;
                x = anc[x][i];
                y = anc[y][i];
            }
            return ret - 2*dep[x] + 2;
        };

        auto go = [&] (const auto &go, int u, int par, int d) -> void {
            anc[u][0] = par;
            dep[u] = d;
            for (int i = 1; i < LOG; ++i) anc[u][i] = anc[anc[u][i-1]][i-1];

            for (int v : adj[u]) {
                if (v == par) continue;
                go(go, v, u, d+1);
            }
        };

        DSU dsu(n);
        int ans = 0, diams = 0;
        for (int i = 0; i < q; ++i) {
            int x, y; cin >> x >> y;
            int u = x^ans, v = y^ans;
            --u, --v;
        
            auto [x1, y1] = dsu.endpoints[dsu.get_root(u)];
            auto [x2, y2] = dsu.endpoints[dsu.get_root(v)];

            if (dsu.size(u) < dsu.size(v)) swap(u, v);
            go(go, v, u, dep[u] + 1);

            diams -= dist(x1, y1) + dist(x2, y2);

            array<int, 4> vert = {x1, x2, y1, y2};
            int diam = -1, ex = -1, ey = -1;
            for (int j = 0; j < 4; ++j) for (int k = j+1; k < 4; ++k) {
                int cur = dist(vert[j], vert[k]);
                if (cur > diam) {
                    diam = cur;
                    ex = vert[j];
                    ey = vert[k];
                }
            }
            diams += diam;
            ans ^= diams;
            
            dsu.merge(u, v);
            adj[u].push_back(v);
            adj[v].push_back(u);
            dsu.endpoints[dsu.get_root(u)] = {ex, ey};
        }
        cout << ans << '\n';
    }
}