ENODE_EASY - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: wuhudsm
Testers: Utkarsh Gupta, Jatin Garg
Editorialist: Nishank Suresh

DIFFICULTY:

2707

PREREQUISITES:

Depth first search, Segment/fenwick trees, Binary search

PROBLEM:

You have a tree T, each node of which may or may not be activated.
Each node also has an energy, initially 0.
At time i, each activated node gives i energy to itself and its neighbors.

Answer Q queries of the form (u, T, K), with the answer being the minimum time after which at least T nodes on the path 1 \to u have an energy of \geq K.

EXPLANATION

This editorial will be for the hard version of the problem, with K \leq 10^{18}.

Let’s focus on a single query (u, T, K). How to solve it?

To do that, let’s look at how the energy of a vertex changes.
For vertex x, let m_x be the sum of A_v across all vertices v such that v = x or v is a neighbor of x, where A_v = 1 if v is activated and A_v = 0 otherwise.
Then, vertex x will receive exactly m_x \cdot i energy at time i.
In particular, at time i, vertex x will have m_x\cdot (1 + 2 + \ldots + i) =m_x\cdot i\cdot (i+1)/2 energy.

So, suppose we know the m_x values for all vertices. To answer a query (u, T, K):

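  • First, note that the number of vertices with energy \geq K is a non-decreasing function of time, so we can binary search on time to find the first moment at which this number is at least T.
    Be careful when choosing the upper limit of the binary search: with m_x = 1 and K = 10^{18} the answer is about \sqrt{2K} \approx 1.42 \cdot 10^9, which is larger than 10^9.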
  • This requires us to be able to quickly calculate the number of nodes on the 1 \to u path that have \geq K energy at a given time t.
  • This means we want the count of nodes x such that m_x \cdot t \cdot (t+1)/2 \geq K.
    t is a constant, so moving it to the other side gives an inequality of the form m_x \geq c, where c = \lceil 2K / (t(t+1)) \rceil is a constant.
  • Keeping all the values of m_x from the root to u in an appropriate data structure (for example, a segment tree) allows us to find this count in \mathcal{O}(\log N).
  • So, a single query can be solved in \mathcal{O}(\log N\log M) where M = 10^9.

This solves a single query. Now we extend this to multiple queries.

Note that a query at u only depends on the m_x values on the 1 \to u path. So, we can do the following and solve queries offline:

  • Group queries by their u.
  • Run a dfs on the tree.
  • When you enter a vertex u, insert m_u into the data structure you are using.
  • Then, solve all queries involving u.
  • Next, continue on with the dfs to u's children.
  • Finally, when exiting u, remove m_u from the data structure.

You can see that this process ensures that whenever answering a query at u, the data structure contains only those m_x values for x on the 1\to u path, which is exactly what we want.

It is also possible to solve the queries online (for example, with a persistent segment tree).

TIME COMPLEXITY

\mathcal{O}(N + Q\log N\log M) per test case.

CODE:

Setter's code (C++)
#include <map>
#include <set>
#include <cmath>
#include <ctime>
#include <queue>
#include <stack>
#include <cstdio>
#include <cstdlib>
#include <vector>
#include <cstring>
#include <algorithm>
using namespace std;
// Shorthand type aliases (db and ull are not referenced by the code below).
typedef double db; 
typedef long long ll;
typedef unsigned long long ull;
// Capacity of every per-node / per-query array (indices are 1-based).
const int N=1000010;
// LOGN, TMD and INF are not referenced by the code below.
const int LOGN=28;
const ll  TMD=0;
const ll  INF=2147483647;
// n = number of nodes, q = number of queries (cur is not referenced below).
int n,q,cur;
// a[v]      : activation flag read from the input
// key[v]    : a[v] plus a[] of every neighbour of v (filled in solve)
// tag[k]    : 1 once key value k has been recorded in val
// enough[v] : 1 once key[v] is at least the threshold currently processed
// tot[v]    : number of nodes with enough[] set on the path root..v
int a[N],key[N],tag[N],enough[N],tot[N];
// ans[id] : answer of query id, -1 while unanswered.
ll  ans[N];
// Distinct non-zero key values, sorted in decreasing order by solve().
vector<int> val;
// Adjacency lists of the tree.
vector<int> G[N];

struct query
{
	int t,id;
	ll  k;
	
	query() {}
	
	query(int t,ll k,int id):t(t),k(k),id(id) {}
};
vector<query> Q[N];

void DFS(int x,int pre)
{
	tot[x]=tot[pre]+enough[x];
	for(int i=0;i<G[x].size();i++)
	{
		int y=G[x][i];
		if(y==pre) continue;
		DFS(y,x);
	}
}

void init()
{
	scanf("%d",&n);	
	for(int i=1;i<=n;i++) scanf("%d",&a[i]);
	for(int i=1;i<n;i++)
	{
		int u,v;
		scanf("%d%d",&u,&v);
		G[u].push_back(v);
		G[v].push_back(u);
	}
	scanf("%d",&q);
	for(int i=1;i<=q;i++)
	{
		int x,t;
		ll  k;
		scanf("%d%d%lld",&x,&t,&k);
		Q[x].push_back(query(t,k,i));
	}
} 

void solve()
{
	for(int i=1;i<=q;i++) ans[i]=-1;
	for(int i=1;i<=n;i++)
	{
		key[i]+=a[i];
		for(int j=0;j<G[i].size();j++) key[i]+=a[G[i][j]];
		if(!tag[key[i]]&&key[i]) tag[key[i]]=1,val.push_back(key[i]);
	}
	sort(val.begin(),val.end(),greater<int>() );
	for(int i=0;i<val.size();i++)
	{
		for(int j=1;j<=n;j++) if(key[j]>=val[i]) enough[j]=1;
		DFS(1,0);
		for(int j=1;j<=n;j++)
		{
			for(int k=0;k<Q[j].size();k++)
			{
				query qu=Q[j][k];
				if(ans[qu.id]==-1&&tot[j]>=qu.t)
				{
					ll L=0,R=(ll)sqrt(qu.k/val[i]*2+1)+2,M;
					while(L+1!=R)
					{
						M=(L+R)>>1;
						if(M*(1+M)/2*val[i]>=qu.k) R=M;
						else L=M;
					}
					ans[qu.id]=R;
				}	
			}
		}
	}
	for(int i=1;i<=q;i++) printf("%lld\n",ans[i]);
} 

// Entry point of the setter's solution: load the input, then answer and print
// every query.
int main()
{
	init();
	solve();
	return 0;
}
Tester (rivalq)'s code (C++)
// Jai Shree Ram  
  
#include<bits/stdc++.h>
using namespace std;

#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>
using namespace __gnu_pbds;
// Order-statistics red-black tree from GNU pb_ds, with short names for its two
// rank operations.
template <class T> using Tree = tree<T, null_type, less<T>, rb_tree_tag, tree_order_statistics_node_update>; 
#define ook order_of_key
#define fbo find_by_order

// Half-open counting loop: i runs over a, a+1, ..., n-1.
#define rep(i,a,n)     for(int i=a;i<n;i++)
#define ll             long long
// NOTE: from here on every `int` in this program is a 64-bit long long, which
// is why the entry point is spelled `signed main`.
#define int            long long
#define pb             push_back
#define all(v)         v.begin(),v.end()
// Replaces std::endl with a plain newline string (no flush).
#define endl           "\n"
// NOTE: every identifier named x / y is rewritten to first / second, so
// variables with these names are silently renamed as well.
#define x              first
#define y              second
#define gcd(a,b)       __gcd(a,b)
#define mem1(a)        memset(a,-1,sizeof(a))
#define mem0(a)        memset(a,0,sizeof(a))
// Container size cast to (the macro-redefined) int.
#define sz(a)          (int)a.size()
#define pii            pair<int,int>
#define hell           1000000007
// CPU time consumed so far, in seconds.
#define elasped_time   1.0 * clock() / CLOCKS_PER_SEC



// Stream extraction for pairs: reads the first member, then the second.
template<typename T1,typename T2>
istream& operator>>(istream& in,pair<T1,T2> &a){
    in>>a.first>>a.second;
    return in;
}

// Stream insertion for pairs: both members separated by a single space.
template<typename T1,typename T2>
ostream& operator<<(ostream& out,pair<T1,T2> a){
    out<<a.first<<" "<<a.second;
    return out;
}

// Raises a to b when b is larger; returns the resulting value of a by value.
template<typename T,typename T1>
T maxs(T &a,T1 b){
    if(b>a){
        a=b;
    }
    return a;
}

// Lowers a to b when b is smaller; returns the resulting value of a by value.
template<typename T,typename T1>
T mins(T &a,T1 b){
    if(b<a){
        a=b;
    }
    return a;
}

// -------------------- Input Checker Start --------------------
 
/**
 * Strictly parses one integer token from stdin and validates it.
 *
 * The token must consist of an optional single leading '-', at least one
 * digit, no superfluous leading zero, no "-0", and must be terminated by
 * exactly the character `endd`. Any violation, including end of input or a
 * value outside [l, r], fails an assertion.
 *
 * @param l    smallest accepted value
 * @param r    largest accepted value
 * @param endd character that must directly follow the token
 * @return the parsed value, guaranteed to lie in [l, r]
 */
long long readInt(long long l, long long r, char endd)
{
    long long x = 0;
    int cnt = 0, fi = -1;
    bool is_neg = false;
    while(true)
    {
        // getchar() returns an int so that EOF stays distinct from any byte.
        const int g = getchar();
        if(g == '-')
        {
            // A sign is legal only once and only before the first digit.
            assert(!is_neg && cnt == 0);
            is_neg = true;
            continue;
        }
        if('0' <= g && g <= '9')
        {
            const int digit = g - '0';
            if(cnt == 0)
                fi = digit;
            cnt++;
            assert(fi != 0 || cnt == 1);
            assert(fi != 0 || is_neg == false);
            // Reject over-long tokens BEFORE accumulating so that x can never
            // overflow: at most 19 digits, and then only with leading digit 1.
            assert(!(cnt > 19 || (cnt == 19 && fi > 1)));
            x = x * 10 + digit;
        }
        else if(g != EOF && static_cast<char>(g) == endd)
        {
            // An empty token or a lone '-' is not an integer.
            assert(cnt > 0);
            if(is_neg)
                x = -x;
            if(!(l <= x && x <= r))
            {
                cerr << l << ' ' << r << ' ' << x << '\n';
                assert(false);
            }
            return x;
        }
        else
        {
            // Unexpected character or premature end of input.
            assert(false);
        }
    }
}
 
/**
 * Reads characters from stdin up to (and consuming) the terminator `endd`.
 *
 * Fails an assertion when the input ends before the terminator is seen or
 * when the number of characters read is outside [l, r].
 *
 * @param l    minimum accepted length
 * @param r    maximum accepted length
 * @param endd terminating character (not part of the result)
 * @return the characters read before the terminator
 */
string readString(int l, int r, char endd)
{
    string ret = "";
    int cnt = 0;
    while(true)
    {
        // Keep the int result: storing it in a char makes EOF undetectable on
        // platforms where char is unsigned, which would loop forever here.
        const int g = getchar();
        assert(g != EOF);
        const char c = static_cast<char>(g);
        if(c == endd)
            break;
        cnt++;
        ret += c;
    }
    assert(l <= cnt && cnt <= r);
    return ret;
}
 
// Reads an integer in [l, r] that must be followed by a single space.
long long readIntSp(long long l, long long r)
{
    return readInt(l, r, ' ');
}

// Reads an integer in [l, r] that must be followed by a newline.
long long readIntLn(long long l, long long r)
{
    return readInt(l, r, '\n');
}

// Reads a string of length in [l, r] that is terminated by a newline.
string readStringLn(int l, int r)
{
    return readString(l, r, '\n');
}

// Reads a string of length in [l, r] that is terminated by a single space.
string readStringSp(int l, int r)
{
    return readString(l, r, ' ');
}

// Asserts that the input has been consumed completely.
void readEOF()
{
    assert(getchar() == EOF);
}
 
/**
 * Reads one line holding exactly n integers, each in [l, r], separated by
 * single spaces and terminated by a newline.
 *
 * @param n number of values on the line; must be at least 1
 * @param l smallest accepted value
 * @param r largest accepted value
 * @return the n values in input order
 */
vector<int> readVectorInt(int n, long long l, long long r)
{
    // With n == 0 the final read below would write to a[-1].
    assert(n >= 1);
    vector<int> a(n);
    for(int i = 0; i + 1 < n; i++)
        a[i] = readIntSp(l, r);
    // The last value is the one that ends the line.
    a[n - 1] = readIntLn(l, r);
    return a;
}
 
// -------------------- Input Checker End --------------------

// Disjoint-set storage: p[i] is the parent link of i, and sz[i] is the size of
// the set whose representative is i.
const int maxn=2000005;
int p[maxn];
int sz[maxn];

// Resets the first n elements (all maxn by default) to singleton sets.
// Values of n above maxn are clamped to maxn; n <= 0 resets nothing.
void clear(int n=maxn){
    if(n>maxn)n=maxn;
    for(int i=0;i<n;i++){
        p[i]=i;
        sz[i]=1;
    }
}
// Returns the representative of the set containing x. Every node visited on
// the way up is re-linked to its grandparent before the walk moves on.
int root(int x){
    for(;p[x]!=x;x=p[x]){
        const int grand=p[p[x]];
        p[x]=grand;
    }
    return x;
}
void merge(int x,int y){
    int p1=root(x);
    int p2=root(y);
    if(p1==p2)return;
    if(sz[p1]>=sz[p2]){
        p[p2]=p1;
        sz[p1]+=sz[p2];
    }
    else{
        p[p1]=p2;
        sz[p2]+=sz[p1];
    }
}

/**
 * Reads and validates one test case, then answers all queries offline.
 *
 * Input (validated with the strict readers): n, the n activation flags, the
 * n-1 edges (checked to form a tree with the disjoint-set structure), q, and
 * q queries "node t k".
 *
 * During a depth-first walk from node 0 the ordered set `st` holds the
 * (val, node) pairs with val > 0 of the nodes on the current root path. For a
 * query, the t-th largest val on the path is d, and the answer is the
 * smallest M with d * M * (M + 1) / 2 >= k, or -1 when the path has fewer
 * than t nodes with positive val.
 *
 * Note: `int` is a macro for long long in this program, so all arithmetic
 * below is 64-bit.
 *
 * @return always 0
 */
int solve(){
    int n = readIntLn(1,4e4);
    vector<int> a = readVectorInt(n,0,1);
    vector<int> val(n);
    vector<vector<int>> g(n);
    clear();
    for(int i = 2; i <= n; i++){
        int u = readIntSp(1,n);
        int v = readIntLn(1,n);
        u--;v--;
        g[u].push_back(v);
        g[v].push_back(u);
        // An edge between two already-connected nodes would close a cycle.
        assert(root(u) != root(v));
        merge(u,v);
    }

    // val[j] = number of activated nodes among j and its neighbours.
    for(int i = 0; i < n; i++){
        val[i] += a[i];
        for(auto j:g[i])val[j] += a[i];
    }

    int q = readIntLn(1,4e4);
    vector<vector<array<int,3>>> queries(n);
    vector<int> ans(q,-1);
    for(int i = 0; i < q; i++){
        int node = readIntSp(1,n);
        node--;
        int t = readIntSp(1,n);
        int k = readIntLn(1,1e18);
        queries[node].push_back({t,k,i});
    }

    Tree<pii> st;
    function<void(int,int)> dfs = [&](int u,int par){
        if(val[u] > 0) st.insert({val[u],u});
        for(auto [t,k,j]:queries[u]){
            if(sz(st) < t){
                ans[j] = -1;
                continue;
            }
            // t-th largest per-second gain on the path root..u
            const int d = st.fbo(sz(st) - t)->first;

            // d*M*(M+1)/2 >= k  <=>  M*(M+1) >= ceil(2k/d).
            // The ceiling matters: flooring 2k/d accepts times that are too
            // small (d = 4, k = 5 would give M = 1 although 4*1 < 5).
            const int doubled = 2*k;
            const int reqd = doubled/d + (doubled % d != 0);

            int L = 1, R = 2e9;
            while(L <= R){
                int M = (L + R)/2;
                if(M*(M + 1) >= reqd){
                    ans[j] = M;
                    R = M - 1;
                }else{
                    L = M + 1;
                }
            }
        }
        for(auto j:g[u]){
            if(j != par)dfs(j,u);
        }
        // Harmless when val[u] == 0: the pair was never inserted.
        st.erase({val[u],u});
    };
    dfs(0,0);
    for(int i = 0; i < q; i++)cout << ans[i] << endl;

    return 0;
}
// Entry point (spelled `signed` because `int` is a macro in this program).
// Runs solve() once; the SIEVE / NCR hooks are compiled in only when the
// corresponding macro is defined.
signed main(){
    ios_base::sync_with_stdio(0);
    cin.tie(0);
    cout.tie(0);
    // Local debugging: redirect stdin/stdout with freopen("input.txt", "r", stdin)
    // and freopen("output.txt", "w", stdout).
    #ifdef SIEVE
    sieve();
    #endif
    #ifdef NCR
    init();
    #endif
    for(int remaining = 1; remaining > 0; remaining--){
        solve();
    }
    return 0;
}
Tester (utkarsh_25dec)'s code (C++)
//Utkarsh.25dec
#include <bits/stdc++.h>
// Shorthand macros used throughout this solution.
// ll is the 64-bit integer type used for thresholds and answers.
#define ll long long int
#define pb push_back
#define mp make_pair
// Modulus used by power() and modInverse().
#define mod 1000000007
#define vl vector <ll>
#define all(c) (c).begin(),(c).end()
using namespace std;
#include <ext/pb_ds/assoc_container.hpp> 
#include <ext/pb_ds/tree_policy.hpp> 
using namespace __gnu_pbds; 
// Order-statistics tree over (int, int) pairs, kept in increasing order.
#define ordered_set tree<pair<int,int>, null_type,less<pair<int,int>>, rb_tree_tag,tree_order_statistics_node_update>
// ordered_set s; s.order_of_key(val) gives the number of elements strictly less than val
// s.find_by_order(i) gives an iterator to the i-th element (0-indexed)
// Returns a^b modulo `mod` using square-and-multiply; b must be non-negative.
ll power(ll a,ll b) {
    assert(b>=0);
    a%=mod;
    ll res=1;
    while(b){
        if(b&1) res=res*a%mod;
        a=a*a%mod;
        b>>=1;
    }
    return res;
}
// Returns a^(mod-2) modulo `mod`.
ll modInverse(ll a){
    const ll exponent=mod-2;
    return power(a,exponent);
}
// Strictly parses one integer token from stdin and validates it.
//
// The token must consist of an optional single leading '-', at least one
// digit, no superfluous leading zero, no "-0", and must be terminated by
// exactly the character `endd`. Any violation, including end of input or a
// value outside [l, r], fails an assertion.
//
// l, r  : inclusive bounds of the accepted value
// endd  : character that must directly follow the token
// Returns the parsed value, guaranteed to lie in [l, r].
long long readInt(long long l,long long r,char endd){
    long long x=0;
    int cnt=0;
    int fi=-1;
    bool is_neg=false;
    while(true){
        // getchar() returns an int so that EOF stays distinct from any byte.
        const int g=getchar();
        if(g=='-'){
            // A sign is legal only once and only before the first digit.
            assert(!is_neg && cnt==0);
            is_neg=true;
            continue;
        }
        if('0'<=g && g<='9'){
            const int digit=g-'0';
            if(cnt==0){
                fi=digit;
            }
            cnt++;
            assert(fi!=0 || cnt==1);
            assert(fi!=0 || is_neg==false);

            // Reject over-long tokens BEFORE accumulating so that x can never
            // overflow: at most 19 digits, and then only with leading digit 1.
            assert(!(cnt>19 || ( cnt==19 && fi>1) ));
            x=x*10+digit;
        } else if(g!=EOF && static_cast<char>(g)==endd){
            // An empty token or a lone '-' is not an integer.
            assert(cnt>0);
            if(is_neg){
                x= -x;
            }

            if(!(l <= x && x <= r))
            {
                cerr << l << ' ' << r << ' ' << x << '\n';
                assert(1 == 0);
            }

            return x;
        } else {
            // Unexpected character or premature end of input.
            assert(false);
        }
    }
}
// Reads characters from stdin up to (and consuming) the terminator `endd`.
// Fails an assertion when the input ends before the terminator is seen or
// when the number of characters read is outside [l, r].
// Returns the characters read, without the terminator.
string readString(int l,int r,char endd){
    string ret="";
    int cnt=0;
    while(true){
        // Keep the int result: storing it in a char makes EOF undetectable on
        // platforms where char is unsigned, which would loop forever here.
        const int g=getchar();
        assert(g!=EOF);
        const char c=static_cast<char>(g);
        if(c==endd){
            break;
        }
        cnt++;
        ret+=c;
    }
    assert(l<=cnt && cnt<=r);
    return ret;
}
// Reads an integer in [l, r] that must be followed by a single space.
long long readIntSp(long long l,long long r){
    const char delimiter=' ';
    return readInt(l,r,delimiter);
}
// Reads an integer in [l, r] that must be followed by a newline.
long long readIntLn(long long l,long long r){
    const char delimiter='\n';
    return readInt(l,r,delimiter);
}
// Reads a string of length in [l, r] that is terminated by a newline.
string readStringLn(int l,int r){
    const char delimiter='\n';
    return readString(l,r,delimiter);
}
// Reads a string of length in [l, r] that is terminated by a single space.
string readStringSp(int l,int r){
    const char delimiter=' ';
    return readString(l,r,delimiter);
}
// Capacity of the per-node arrays (nodes are 1-based).
const int N=500023;
// vis[v] : set once dfs has entered v.
bool vis[N];
// Adjacency lists of the tree.
vector <int> adj[N];
// energetic[v] : activation flag of v as read from the input (0 or 1).
int energetic[N];
// nodequ[v] : the (t, k) pairs of all queries whose endpoint is v.
vector <pair<ll,ll>> nodequ[N];
// deg[v] : energetic[v] plus energetic[] of every neighbour of v. Despite the
// name this is not the graph degree.
int deg[N];
// (deg, node) pairs of the nodes on the path currently being walked by dfs.
ordered_set s;
// Answer of each query, keyed by its (node, t, k) triple.
map <tuple<ll,ll,ll>,ll> ans;
void dfs(int curr)
{
    vis[curr]=1;
    s.insert(mp(deg[curr],curr));
    for(auto p:nodequ[curr])
    {
        ll nodesreq=p.first;
        ll valreq=p.second;
        if(s.size()<nodesreq)
        {
            ans[make_tuple(curr,nodesreq,valreq)]=-1;
            continue;
        }
        int tmp=s.size();
        auto it=s.find_by_order(tmp-nodesreq);
        auto val=(*it);
        ll d=val.first;
        if(d==0)
        {
            ans[make_tuple(curr,nodesreq,valreq)]=-1;
            continue;
        }
        else
        {
            ll low=1,high=sqrt(4e18/d);
            while(low<=high)
            {
                ll mid=(low+high)/2;
                if((d*mid*(mid+1))/2 >= valreq)
                    high=mid-1;
                else
                    low=mid+1;
            }
            ans[make_tuple(curr,nodesreq,valreq)]=low;
        }
    }
    for(auto it:adj[curr])
    {
        if(vis[it])
            continue;
        dfs(it);
    }
    s.erase(mp(deg[curr],curr));
}
void solve()
{
    int n=readInt(1,40000,'\n');
    for(int i=1;i<=n;i++)
    {
        vis[i]=0;
        adj[i].clear();
        nodequ[i].clear();
        deg[i]=0;
    }
    s.clear();
    ans.clear();
    for(int i=1;i<=n;i++)
    {
        if(i!=n)
            energetic[i]=readInt(0,1,' ');
        else
            energetic[i]=readInt(0,1,'\n');
    }
    for(int i=1;i<n;i++)
    {
        int u,v;
        u=readInt(1,n,' ');
        v=readInt(1,n,'\n');
        assert(u!=v);
        adj[u].pb(v);
        adj[v].pb(u);
    }
    for(int i=1;i<=n;i++)
    {
        deg[i]+=energetic[i];
        for(auto j:adj[i])
            deg[i]+=energetic[j];
    }
    int q=readInt(1,40000,'\n');
    vector <tuple<ll,ll,ll>> queries;
    while(q--)
    {
        ll x,t,k;
        x=readInt(1,n,' ');
        t=readInt(1,n,' ');
        k=readInt(1,1e18,'\n');
        nodequ[x].pb(mp(t,k));
        queries.pb(make_tuple(x,t,k));
    }
    dfs(1);
    for(int i=1;i<=n;i++)
        assert(vis[i]==1);
    for(auto q:queries)
        cout<<ans[q]<<'\n';
}
// Entry point: redirects stdin/stdout to local files unless ONLINE_JUDGE is
// defined, solves the single test case, and reports the CPU time on stderr.
int main()
{
    #ifndef ONLINE_JUDGE
    freopen("input.txt", "r", stdin);
    freopen("output.txt", "w", stdout);
    #endif
    ios_base::sync_with_stdio(false);
    cin.tie(NULL);
    cout.tie(NULL);

    // Single test case; reading T from the input is intentionally disabled.
    for(int T=1;T>0;T--)
        solve();

    const double elapsed_ms = 1000 * ((double)clock()) / (double)CLOCKS_PER_SEC;
    cerr << "Time : " << elapsed_ms << "ms\n";
}
Editorialist's code (C++)
#include "bits/stdc++.h"
// #pragma GCC optimize("O3,unroll-loops")
// #pragma GCC target("avx2,bmi,bmi2,lzcnt,popcnt")
using namespace std;
// 64-bit integer type used for times, thresholds and answers.
using ll = long long int;
// 64-bit Mersenne Twister seeded from the high-resolution clock; it is not
// referenced by the rest of the code shown here.
mt19937_64 rng(chrono::high_resolution_clock::now().time_since_epoch().count());

// One query attached to a node of the tree.
struct query {
	ll t;   // number of nodes on the root path that must qualify
	ll k;   // energy threshold a node has to reach
	ll id;  // position of the query in the input, indexes the answer array
};

/**
 * Point-update segment tree (iterative, bottom-up).
 * Source: adapted from kactl.
 * Description: stores n values in the leaves s[n..2n-1]; every inner node
 *              holds f of its two children, where f is addition. Queries use
 *              half-open ranges [b, e). Unlike a "set" update, update() ADDS
 *              val to the stored value at pos.
 * Time: O(log n) per update and per query.
 */

template<class T, T unit = T()>
struct SegTree {
	vector<T> s;
	int n;

	// Builds a tree over _n positions, every node starting at def.
	SegTree(int _n = 0, T def = unit) : s(2*_n, def), n(_n) {}

	// Combining function of the tree.
	T f(T a, T b) { return a+b; }

	// Adds val to position pos and refreshes all ancestors of that leaf.
	void update(int pos, T val) {
		pos += n;
		s[pos] += val;
		for (pos /= 2; pos != 0; pos /= 2)
			s[pos] = f(s[2 * pos], s[2 * pos + 1]);
	}

	// Returns f over the positions in [b, e); `unit` for an empty range.
	T query(int b, int e) {
		T left = unit, right = unit;
		b += n;
		e += n;
		while (b < e) {
			if (b % 2 != 0) left = f(left, s[b++]);
			if (e % 2 != 0) right = f(s[--e], right);
			b /= 2;
			e /= 2;
		}
		return f(left, right);
	}
};

// Editorialist's solution: offline queries answered during a DFS.
//
// b[u] is the number of activated nodes among u and its neighbours. While the
// DFS is at node u, the segment tree T stores, for every value v, how many
// nodes on the path root..u have b == v. For each query the answer time is
// binary searched: at time `mid` a node qualifies iff
// b >= ceil(k / (mid*(mid+1)/2)), and the number of such nodes on the path is
// a suffix sum of T.
int main()
{
	ios::sync_with_stdio(false); cin.tie(0);

	int n; cin >> n;
	vector<int> a(n), b(n);
	for (int &x : a) cin >> x;
	vector<vector<int>> g(n);
	for (int i = 0; i < n-1; ++i) {
		int u, v; cin >> u >> v;
		g[--u].push_back(--v);
		g[v].push_back(u);
	}
	for (int i = 0; i < n; ++i) {
		b[i] = a[i];
		for (int u : g[i]) b[i] += a[u];
	}
	vector<vector<query>> queries(n);
	int q; cin >> q;
	for (int i = 0; i < q; ++i) {
		query cur;
		int u; cin >> u; --u;
		cin >> cur.t >> cur.k;
		cur.id = i;
		queries[u].push_back(cur);
	}
	vector<ll> ans(q, -1);
	// b[u] ranges over 0..n inclusive (a node with n-1 neighbours, all of them
	// and itself activated, has b == n), so the tree needs n+1 positions.
	SegTree<int> T(n + 1);
	// Sentinel upper bound: larger than any reachable answer (about 1.42e9 for
	// k = 1e18 and b = 1) while mid*(mid+1)/2 still fits in 64 bits.
	const ll inf = 2.5e9;
	auto dfs = [&] (const auto &self, int u, int p) -> void {
		T.update(b[u], 1);
		for (const auto &qry : queries[u]) {
			ll lo = 1, hi = inf;
			while (lo < hi) {
				ll mid = (lo + hi)/2;
				ll den = (mid * (mid+1) / 2);
				// smallest b that reaches energy k by time mid
				ll want = (qry.k + den - 1) / den;

				// No node can have b > n; otherwise count path nodes with
				// b in [want, n].
				if (want > n or T.query(want, n + 1) < qry.t) lo = mid+1;
				else hi = mid;
			}
			// Reaching the sentinel means no time satisfies the query.
			if (lo == inf) lo = -1;
			ans[qry.id] = lo;
		}

		for (int v : g[u]) {
			if (v == p) continue;
			self(self, v, u);
		}
		// Remove u again so that T only describes the current root path.
		T.update(b[u], -1);
	};
	dfs(dfs, 0, 0);
	for (auto x : ans) cout << x << '\n';
}

What is different in solution of easy and hard versions of this problem?

1 Like

For easy version:

We notice that after O(\sqrt{maxK}) seconds, the energy of each node is either 0 or not less than maxK.

For each of the times 1, 2, \ldots, 144, calculate the energy of every node of the tree, and use the resulting 144 trees to answer each query.

Can we solve the easy version of the problem by only having the knowledge of DFS and Binary Search without knowing the segment tree?

Please help me figure out my assumption is correct or not!!
For a query (X, T, K), In the path from 1 → X, we have to figure out the T’th maximum node, and we will further calculate the minimum time using binary search.
Now the main problem, is to calculate the T’th maximum node in the path 1 → X,
For online queries, can this be solved in O(n√n), as described in a Codeforces blog? Please explain a bit about the square-root approach.

YES.

In fact, even the hard and ex-hard versions can be solved without any data structure.

I’ve mentioned them in ENODE_HARD-Editorial - #5.