REC19E - Editorial

PROBLEM LINK:
Battle of Sages

Setter: Prasann_kumar_Gupta
Tester: Vishal_Mahavar

DIFFICULTY:
Medium

PREREQUISITES:
Tree, Binary Lifting, Greedy

EXPLANATION
First, observe that each sage, in order to maximize his own count of resurrected cities, will try to minimize the other's count. To do so, the Radiant sage and the Kingdom sage will first resurrect the cities that lie on the simple path between them, because the cities on the simple path between the starting cities a and b are the cities that both the Radiant sage and the Kingdom sage can resurrect. Therefore, if len is the number of cities on the simple path between the Radiant sage and the Kingdom sage, excluding a and b, then (len+1)/2 of them (integer division) will be resurrected by the Radiant sage, since he starts resurrecting 1 day before the Kingdom sage, and the remaining len/2 cities will be resurrected by the Kingdom sage. So, the last city resurrected by the Radiant sage on the path between a and b is at a distance of (len+1)/2 from a, and the last city resurrected by the Kingdom sage on that path is at a distance of len/2 from b. The problem can be solved by dividing it into 2 cases -

  1. If the last city resurrected by the Radiant sage is at a distance from a smaller than that of the LCA (lowest common ancestor) of a and b - Here, we can find the last city resurrected by the Radiant sage on the simple path. The subtree size of that particular city is exactly the number of cities the Radiant sage can resurrect, and the rest of the cities will be resurrected by the Kingdom sage.

  2. If the last city resurrected by the Radiant sage is at a distance from a greater than or equal to that of the LCA - Here, we can find the last city resurrected by the Kingdom sage on the simple path. The subtree size of that city is exactly the number of cities resurrected by the Kingdom sage, and the rest of the cities will be resurrected by the Radiant sage.

To find the last city resurrected by the Radiant sage on their simple path, we just need the ((len+1)/2)-th city on that path, counting from a. This can easily be done in linear time per query, but that will not fit into the given constraints. Instead, we can use binary lifting, which finds the city in O(logn) time. Similarly, the LCA can also be found using the binary lifting technique in O(logn) time.

TIME COMPLEXITY
O(nlogn + qlogn)

Setter's Solution
// GCC-specific tuning hints. The directive GCC recognizes is `#pragma GCC optimize`;
// `#pragma GCC optimization` is an unknown pragma and is ignored, so flags written
// with that spelling never take effect.
#pragma GCC target ("avx2")
//#pragma GCC optimize "trapv"
#pragma GCC optimize ("O3")
#pragma GCC optimize ("unroll-loops")
#include <bits/stdc++.h>
#include<ext/pb_ds/assoc_container.hpp>
#include<ext/pb_ds/tree_policy.hpp>
// Contest shorthand macros.
// NOTE: `ll` is defined as plain `int` in this solution, so ll/vll/pll/... are all int-based.
// `ordered_set` mentions `ll` before its #define; that is fine because macros expand at the point of use.
#define ordered_set tree<ll, null_type,less<ll>, rb_tree_tag,tree_order_statistics_node_update>
// Reads n values from cin into a[0..n-1]. The bound is parenthesized so that an
// expression argument (e.g. a conditional) is compared as a whole.
#define input(a,n) for(ll i1=0;i1<(n);i1++)cin>>a[i1]
#define ll int
// Parenthesized so that expressions such as 1/pi or pi*r group as intended.
#define pi (2 * acos(0.0))
// usll/umll rely on the uset/umap alias templates declared further down in this file.
#define usll uset<ll>
#define sll set<ll>
#define mll map<ll,ll>
#define pll pair<ll,ll>
#define vll vector<ll>
#define vpll vector<pll>
#define umll umap<ll,ll>
// Single-letter and very short macros: every later token S, F, Y, N or sz is rewritten,
// so those names cannot be used as identifiers after this point.
#define S second
#define sz size()
#define all(v) v.begin(),v.end()
#define Y cout<< "YES"<< "\n"
#define N cout<< "NO"<< "\n"
#define F first
#define pb push_back
#define pf push_front
#define ld long double
// Draws a uniformly distributed T in [l, r] from the global engine `rng`.
#define random(l,r,T)    uniform_int_distribution<T>(l,r)(rng)
using namespace __gnu_pbds;
using namespace std;
// Global Mersenne Twister engine, seeded with the steady clock's current tick count.
// It is the engine the `random(l,r,T)` macro draws from.
mt19937 rng(chrono::steady_clock::now().time_since_epoch().count());
// Hash functor for integer keys: adds an offset sampled once from the steady
// clock and then runs the splitmix64 mixing function over the sum.
struct custom_hash {
    // splitmix64 mixing step, see http://xorshift.di.unimi.it/splitmix64.c
    static uint64_t splitmix64(uint64_t x) {
        uint64_t z = x + 0x9e3779b97f4a7c15;
        z ^= z >> 30;
        z *= 0xbf58476d1ce4e5b9;
        z ^= z >> 27;
        z *= 0x94d049bb133111eb;
        z ^= z >> 31;
        return z;
    }

    size_t operator()(uint64_t x) const {
        // Function-local static: the offset is taken on the first call and reused afterwards.
        static const uint64_t FIXED_RANDOM = chrono::steady_clock::now().time_since_epoch().count();
        const uint64_t shifted = x + FIXED_RANDOM;
        return splitmix64(shifted);
    }
};
// Unordered map / set aliases that hash their keys with custom_hash.
template<class Key, class Value>
using umap = unordered_map<Key, Value, custom_hash>;
template<class Key>
using uset = unordered_set<Key, custom_hash>;
// Modulus constant; the solution code visible in this file does not reference it.
const ll mod = 1000000007;
// par[x][y] is the ancestor of node x that lies 2^y edges above it.
// Recurrence used in dfs: par[i][j] = par[par[i][j-1]][j-1], i.e. the 2^j-th ancestor of i
// is the 2^(j-1)-th ancestor of i's 2^(j-1)-th ancestor.
ll par[100005][21];
ll d[100005];// d[x] is the depth of node x, with the root at depth 0
// subsz[x]: size of the subtree rooted at x (set to 1 per node in main, children added in dfs).
// Log[k]: floor(log2(k)) for k >= 1 and 0 for k = 0 (table filled in main).
ll subsz[100005],Log[100005];
vll adj[100005];// adjacency list of every node, indexed by node id
// Depth-first walk from v, whose parent is p and whose depth is h.
// Records d[v], fills v's row of the ancestor table and adds each child's
// subtree size into subsz[v] (subsz[] must already hold 1 for every node).
// The recursion goes as deep as the tree is tall.
void dfs(ll v,ll p,ll h)
{
    d[v] = h;
    par[v][0] = p;
    // Only jumps of length 2^k <= h stay inside the tree; higher entries are left untouched.
    const ll top = Log[h];
    ll k = 1;
    while (k <= top) {
        const ll half = par[v][k - 1];
        par[v][k] = par[half][k - 1];
        ++k;
    }
    for (ll child : adj[v]) {
        if (child != p) {
            dfs(child, v, h + 1);
            subsz[v] += subsz[child];
        }
    }
}
// Lowest common ancestor of x and y using the par[][] jump table.
ll lca(ll x,ll y)
{
    // Bring both nodes to the same depth. Each jump covers the largest power of two
    // not exceeding the remaining depth gap, so the deeper node never overshoots.
    while (d[x] != d[y]) {
        if (d[x] > d[y])
            x = par[x][Log[d[x] - d[y]]];
        else
            y = par[y][Log[d[y] - d[x]]];
    }
    if (x == y)
        return x;
    // Lift both nodes together for as long as their ancestors differ;
    // they end up as two distinct children of the answer.
    for (ll bit = Log[d[x]]; bit >= 0; --bit) {
        const ll ax = par[x][bit];
        const ll ay = par[y][bit];
        if (ax != ay) {
            x = ax;
            y = ay;
        }
    }
    return par[x][0];
}
// Climbs len edges up from node: for every set bit k of len (lowest bit first)
// one par[][k] jump is taken. Returns the node reached.
ll find_node(ll node,ll len){
    for (ll bit = 0; len != 0; ++bit, len /= 2) {
        if (len % 2 != 0)
            node = par[node][bit];
    }
    return node;
}
int main()
{
    // freopen("05.in","r",stdin);
    // freopen("05.out","w",stdout);
    clock_t clk = clock();
    ios_base::sync_with_stdio(false);
    cin.tie(NULL);
    cout.tie(NULL);
    ll test=1;
    //cin>>test;
    Log[0]=0;
    Log[1]=0;
    for(ll i=2;i<=100001;i++)
        Log[i] = 1 + Log[i/2];
    for(ll tc=1;tc<=test;tc++)
    {
        ll n,q;
        cin>>n>>q;
        assert(2<=n && n<=1e5);
        assert(1<=q && q<=1e5);
        for(ll i=1;i<=n-1;i++){
            ll x,y;
            cin>>x>>y;
            assert(1<=x && x<=n);
            assert(1<=y && y<=n);
            adj[x].pb(y);
            adj[y].pb(x);
            subsz[i]=1;
        }
        subsz[n]=1;
        dfs(1,-1,0);
        while(q--){
            ll a,b;
            cin>>a>>b;
            assert(1<=a && a<=n);
            assert(1<=b && b<=n);
            assert(a!=b);
            ll LCA = lca(a,b);
            ll len = (d[a]-d[LCA]-1) + (d[b]-d[LCA]-1) + 1;
            if((len+1)/2 >= (d[a]-d[LCA])){
                ll node = find_node(b,len/2);
                ll count_b = subsz[node];
                cout<<n-count_b<< " "<<count_b<< "\n";
            }
            else{
                ll node = find_node(a,(len+1)/2);
                ll count_a = subsz[node];
                cout<<count_a<< " "<<n-count_a<< "\n";
            }
        }
        //cout<< "Case #"<<tc<< ": "<<ans<< "\n";
    }
    cerr << '\n'<<"Time (in s): " << double(clock() - clk) * 1.0 / CLOCKS_PER_SEC << '\n';
    return 0;
}

Tester's Solution
#include<bits/stdc++.h>
//#pragma GCC optimize "trapv"
//#include<ext/pb_ds/assoc_container.hpp>
//#include<ext/pb_ds/tree_policy.hpp>
// Contest shorthand macros for the tester's solution.
// Here `ll` is `long long`, so vll/pll/dll/... are all built on 64-bit integers.
// fast_az_fuk: turns off iostream/stdio synchronisation and unties cin and cout.
#define fast_az_fuk      ios_base::sync_with_stdio(false); cin.tie(NULL); cout.tie(NULL);
#define ll               long long
#define lll              __int128
#define ull              unsigned ll
#define ld               long double 
#define pb               push_back 
#define pf               push_front
#define dll              deque<ll> 
#define vll              vector<ll>
#define vvll             vector<vll> 
#define pll              pair<ll,ll> 
#define vpll             vector<pll>
#define dpll             deque<pll>
#define mapll            map<ll,ll>
// NOTE(review): umapll expands to a `umap` template; none is declared in the tester's
// solution as shown - confirm one is in scope before using this macro.
#define umapll           umap<ll,ll>
// After this line every `endl` token expands to the string "\n" (a newline, no flush).
#define endl             "\n" 
#define all(v)           v.begin(),v.end() 
// memset over the whole array `a`; sizeof(a) is only the full size for true arrays, not pointers.
#define ms(a,x)          memset(a,x,sizeof(a))
// Draws a uniformly distributed T in [l, r] from the global engine `rng`.
#define random(l,r,T)    uniform_int_distribution<T>(l,r)(rng)


//#define ordered_set tree<ll, null_type,less<ll>, rb_tree_tag,tree_order_statistics_node_update>

using namespace std;


// Global Mersenne Twister engine, seeded with the steady clock's current tick count.
// It is the engine the `random(l,r,T)` macro draws from.
mt19937 rng(chrono::steady_clock::now().time_since_epoch().count());
//using namespace __gnu_pbds;

// Reads one value per existing element; the vector must already have its final size.
template<typename T>
istream& operator >>(istream& is, vector<T>& v) {
    for (auto& item : v) {
        is >> item;
    }
    return is;
}

// Writes every element followed by one space (so the output ends with a space).
template<typename T>
ostream& operator <<(ostream& os, const vector<T>& v) {
    for (const auto& item : v) {
        os << item << ' ';
    }
    return os;
}

// Reads `first` and then `second`.
template<typename T1, typename T2>
istream& operator >>(istream& is, pair<T1, T2>& p) {
    is >> p.first;
    is >> p.second;
    return is;
}

// Writes "first second" separated by a single space.
template<typename T1, typename T2>
ostream& operator <<(ostream& os, const pair<T1, T2>& p) {
    os << p.first << ' ' << p.second;
    return os;
}

// Shared state of the tester's solution; every vector is indexed by node id and sized in
// binary_lifitng_prepare:
//   adj    - adjacency lists
//   up     - up[v][0] is the parent passed to dfsTimer, up[v][i] = up[up[v][i-1]][i-1]
//   level  - depth value passed to dfsTimer
//   in/out - timestamps taken when dfsTimer enters / leaves a node
//   timer  - running timestamp counter
//   siz    - subtree sizes accumulated in dfsTimer
vvll adj,up; vll level,in,out; ll timer; vll siz;
// Sizes all per-node containers for node ids 0..n and resets the timestamp counter.
// The adjacency lists are emptied and the ancestor table is rebuilt as (n+1) rows of
// m zeros; the remaining vectors are only resized, so entries that already existed
// keep their previous values.
void binary_lifitng_prepare(int n,int m=20){
    const size_t slots = n+1;
    timer = 0;
    adj.clear();
    adj.resize(slots);
    up.assign(slots, vll(m,0));
    level.resize(slots);
    in.resize(slots);
    out.resize(slots);
    siz.resize(slots);
}
// Depth-first walk that records, for `node`: entry/exit timestamps, depth `lev`,
// its row of the ancestor table and its subtree size. `par` is the node it was
// reached from and is skipped when iterating the neighbours.
void dfsTimer(ll node,ll par,ll lev){
    siz[node] = 1;
    level[node] = lev;
    in[node] = ++timer;
    up[node][0] = par;
    const ll levels = static_cast<ll>(up[0].size());
    for (ll k = 1; k < levels; ++k) {
        const ll mid = up[node][k-1];
        up[node][k] = up[mid][k-1];
    }
    for (ll child : adj[node]) {
        if (child != par) {
            dfsTimer(child, node, lev+1);
            siz[node] += siz[child];
        }
    }
    out[node] = ++timer;
}

// True when u's [in, out] timestamp interval encloses v's, which holds when
// u is v itself or an ancestor of v.
inline bool isParent(ll u,ll v){
    return in[u] <= in[v] && out[u] >= out[v];
}

// Lowest common ancestor of u and v, using the ancestor table and the
// timestamp-based ancestor test.
ll LCA(ll u,ll v){
    if (isParent(u,v))
        return u;
    if (isParent(v,u))
        return v;
    // Move u to its highest ancestor that is still not an ancestor of v;
    // the parent of that node is the answer.
    for (ll i = static_cast<ll>(up[0].size()) - 1; i >= 0; --i) {
        const ll cand = up[u][i];
        if (!isParent(cand,v))
            u = cand;
    }
    return up[u][0];
}

// Number of edges on the path between u and v, measured through their LCA.
ll dist(ll u,ll v){
    const ll top = LCA(u,v);
    const ll from_u = level[u] - level[top];
    const ll from_v = level[v] - level[top];
    return from_u + from_v;
}
// Climbs `dis` edges up from u with the ancestor table, always taking the
// largest power-of-two jump that still fits into the remaining distance.
ll find(ll u,ll dis){
    int i = up[0].size();
    while (i-- > 0) {
        const ll step = 1<<i;
        if (step <= dis) {
            u = up[u][i];
            dis -= step;
        }
    }
    return u;
}
// vis[x] becomes 1 once proper_tree has entered node x.
vll vis;
// Walks the graph from `node` (reached from `par`) and reports false as soon as a
// neighbour other than `par` has already been visited, i.e. a cycle was found.
// Returns true when the part of the graph reachable from `node` has no such cycle.
bool proper_tree(ll node,ll par){
    vis[node] = 1;
    for (ll nxt : adj[node]) {
        if (nxt == par)
            continue;
        if (vis[nxt] || !proper_tree(nxt,node))
            return false;
    }
    return true;
}

const bool tests = 0;
void solve_case(){
	int n,q; cin>>n>>q;
	assert(n>1 && n<=1e5 && q>=1 && q <= 1e5);
	binary_lifitng_prepare(n);
	for(int i=1;i<n;i++){
		ll u,v; cin>>u>>v; adj[u].pb(v); adj[v].pb(u);
		assert(u>=1 && u<=n && v>=1 && v<=n);
	} dfsTimer(1,1,0); vis.resize(n+1,0); assert(proper_tree(1,1));
	while(q--){
		ll a,b; cin>>a>>b; assert(a != b);
		ll L = LCA(a,b); ll len = dist(a,b)+1;
		if((len+1)/2 >= dist(L,a)+1){
			ll node = find(b,len - (len+1)/2-1);
			cout<<n-siz[node]<<' '<<siz[node]<<endl;
		}
		else{
			ll node = find(a,(len+1)/2-1);
			cout<<siz[node]<<' '<<n-siz[node]<<endl;
		} 
	}
}	

int32_t main()
{
    #ifdef LOCAL
        freopen("error.txt", "w", stderr);
        clock_t clk = clock();
    #endif
    fast_az_fuk
    ll testcase=1; if(tests) cin>>testcase;
    cout<<fixed<<setprecision(10);
    for(ll test=1;test<=testcase;test++)
    {//cout<<"Case #"<<test<<": ";
        solve_case();
    }
    #ifdef LOCAL
        cerr << '\n'<<"Time (in s): " << double(clock() - clk) * 1.0 / CLOCKS_PER_SEC << '\n';
    #endif
    return 0;
}