ENODE_HARD-Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: wuhudsm
Testers: Utkarsh Gupta, Jatin Garg
Editorialist: Nishank Suresh

DIFFICULTY:

2604

PREREQUISITES:

Depth-first search, Segment/Fenwick trees, Binary search

PROBLEM:

You have a tree with N nodes, each of which may or may not be activated.
Each node also has an energy, initially 0.
At time i, each activated node gives i energy to itself and its neighbors.

Answer Q queries of the form (u, T, K). The answer is the minimum time after which at least T nodes on the path 1 \to u have energy \geq K, or -1 if this never happens.
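To cross-check faster solutions on tiny inputs, the process can be simulated directly. This is only an illustrative sketch: `bruteForce`, its parameters and the `maxTime` cutoff are hypothetical, vertices are 0-indexed, and the path 1 \to u is assumed to be given as a list of vertices.

```cpp
#include <vector>
using namespace std;
using ll = long long;

// Brute force for tiny inputs only (hypothetical helper).
// active[v] is 1 if v is activated, adj is a 0-indexed adjacency list, and
// path lists the vertices on the path 1 -> u. Returns the first time at which
// at least T path vertices have energy >= K, or -1 if that has not happened by maxTime.
ll bruteForce(const vector<int>& active, const vector<vector<int>>& adj,
              const vector<int>& path, int T, ll K, ll maxTime) {
    int n = static_cast<int>(active.size());
    vector<ll> energy(n, 0);
    for (ll i = 1; i <= maxTime; ++i) {
        // At time i, every activated vertex gives i energy to itself and to each neighbour.
        for (int v = 0; v < n; ++v) {
            if (!active[v]) continue;
            energy[v] += i;
            for (int w : adj[v]) energy[w] += i;
        }
        int cnt = 0;
        for (int x : path) cnt += (energy[x] >= K) ? 1 : 0;
        if (cnt >= T) return i;
    }
    return -1;
}
```

On the path 0 - 1 - 2 with only vertex 0 activated, vertices 0 and 1 both reach energy 3 at time 2, while vertex 2 never gains energy.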

EXPLANATION

This editorial will be for the hard version of the problem, with K \leq 10^{18}.

Let’s focus on a single query (u, T, K). How to solve it?

To do that, let’s look at how the energy of a vertex changes.
For vertex x, let m_x be the sum of A_v across all vertices v such that v = x or v is a neighbor of x, where A_v = 1 if v is activated and A_v = 0 otherwise.
Then, vertex x will receive exactly m_x \cdot i energy at time i.
Summing over times 1, 2, \ldots, i, after time i vertex x has m_x\cdot (1 + 2 + \ldots + i) = m_x\cdot i\cdot (i+1)/2 energy.
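Both formulas translate directly into code. A minimal sketch, assuming 0-indexed vertices, an adjacency list `adj` and an array `A` with `A[v] = 1` for activated vertices; `computeM` and `energyAt` are hypothetical helper names.

```cpp
#include <vector>
using namespace std;
using ll = long long;

// m[x] = A[x] + sum of A[v] over the neighbours v of x.
vector<int> computeM(const vector<int>& A, const vector<vector<int>>& adj) {
    int n = static_cast<int>(A.size());
    vector<int> m(n);
    for (int x = 0; x < n; ++x) {
        m[x] = A[x];
        for (int v : adj[x]) m[x] += A[v];
    }
    return m;
}

// Energy of a vertex with value m after time t: m * (1 + 2 + ... + t) = m * t * (t + 1) / 2.
// Assumes the result fits in a signed 64-bit integer.
ll energyAt(ll m, ll t) {
    return m * (t * (t + 1) / 2);
}
```

Since A_v \in \{0, 1\}, each m_x is at most \deg(x) + 1 \leq N.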

So, suppose we know the m_x values for all vertices. To answer a query (u, T, K):

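  • First, note that the number of vertices with energy \geq K is a non-decreasing function of time, so we can binary search on time to find the first moment this number is at least T.
    Be careful when choosing the upper limit of the binary search: with K \leq 10^{18} and m_x \geq 1, the answer can be as large as about \sqrt{2\cdot 10^{18}} \approx 1.42\cdot 10^9, and t\cdot(t+1) must not overflow a 64-bit integer at that limit.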
  • This requires us to be able to quickly calculate the number of nodes on the 1 \to u path that have \geq K energy at a given time t.
  • This means we want the count of nodes x such that m_x \cdot t \cdot (t+1)/2 \geq K.
    Since t is fixed, rearranging gives an inequality of the form m_x \geq c, where c = \lceil K / (t(t+1)/2) \rceil.
  • Keeping all the values of m_x from the root to u in an appropriate data structure (for example, a segment tree) allows us to find this count in \mathcal{O}(\log N).
  • So, a single query can be solved in \mathcal{O}(\log N\log M), where M is the upper limit of the binary search (on the order of 10^9 here).
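A single query can be sketched as follows. To keep it short, the sketch replaces the segment tree with a sorted vector of the m_x values on the path, so each count is one binary search. The names `singleQuery` and `countAtLeast` and the bound `HI = 2e9` are illustrative choices, not taken from the reference solutions.

```cpp
#include <algorithm>
#include <vector>
using namespace std;
using ll = long long;

// Number of values >= c in an ascending vector (stand-in for the segment tree).
int countAtLeast(const vector<ll>& sortedM, ll c) {
    return static_cast<int>(sortedM.end() - lower_bound(sortedM.begin(), sortedM.end(), c));
}

// Smallest t >= 1 such that at least T values m in pathM satisfy m * t(t+1)/2 >= K,
// or -1 if no such t exists. pathM holds m_x for every x on the path 1 -> u.
ll singleQuery(vector<ll> pathM, int T, ll K) {
    sort(pathM.begin(), pathM.end());
    // With K <= 1e18 and m >= 1, t = 2e9 already gives t(t+1)/2 > 1e18,
    // and 2e9 * (2e9 + 1) still fits in a signed 64-bit integer.
    const ll HI = 2000000000LL;
    auto ok = [&](ll t) {  // monotone in t: larger t means a smaller threshold c
        ll tri = t * (t + 1) / 2;
        ll c = (K + tri - 1) / tri;  // m * tri >= K  <=>  m >= ceil(K / tri)
        return countAtLeast(pathM, c) >= T;
    };
    if (!ok(HI)) return -1;  // fewer than T vertices with m >= 1
    ll lo = 1, hi = HI;
    while (lo < hi) {  // invariant: ok(hi) holds
        ll mid = lo + (hi - lo) / 2;
        if (ok(mid)) hi = mid;
        else lo = mid + 1;
    }
    return lo;
}
```

Only vertices with m_x = 0 can never reach energy K, which is why the -1 check only needs to look at the upper limit.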

This solves a single query. Now we extend this to multiple queries.

Note that a query at u only depends on the m_x values on the 1 \to u path. So, we can do the following and solve queries offline:

  • Group queries by their u.
  • Run a DFS on the tree, starting from vertex 1.
  • When you enter a vertex u, insert m_u into the data structure you are using.
  • Then, solve all queries involving u.
  • Next, continue on with the DFS to u's children.
  • Finally, when exiting u, remove m_u from the data structure.

You can see that this process ensures that whenever answering a query at u, the data structure contains only those m_x values for x on the 1\to u path, which is exactly what we want.
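Putting the pieces together, one possible offline implementation is sketched below. It uses a Fenwick tree indexed by m_x (the values lie in [0, N]) instead of a segment tree, and an explicit stack instead of recursion. `solveOffline`, `Query` and `Fenwick` are hypothetical names, and vertices (including the root, vertex 1 in the statement) are 0-indexed.

```cpp
#include <vector>
using namespace std;
using ll = long long;

struct Query { int u; int T; ll K; };  // u is 0-indexed; the root is vertex 0

// Fenwick tree over positions [0, size).
struct Fenwick {
    vector<int> f;
    explicit Fenwick(int size) : f(size + 1, 0) {}
    void add(int pos, int delta) {
        for (int i = pos + 1; i < (int)f.size(); i += i & -i) f[i] += delta;
    }
    int prefix(int len) const {  // sum of positions [0, len)
        int s = 0;
        for (int i = len; i > 0; i -= i & -i) s += f[i];
        return s;
    }
};

vector<ll> solveOffline(const vector<int>& A, const vector<vector<int>>& adj,
                        const vector<Query>& qs) {
    int n = static_cast<int>(A.size());
    vector<int> m(n);
    for (int x = 0; x < n; ++x) {
        m[x] = A[x];
        for (int v : adj[x]) m[x] += A[v];
    }
    vector<vector<int>> at(n);  // query indices grouped by vertex
    for (int i = 0; i < (int)qs.size(); ++i) at[qs[i].u].push_back(i);

    Fenwick fw(n + 1);  // m values lie in [0, n]
    auto countAtLeast = [&](ll c) {  // inserted values >= c, for c >= 1
        return c > n ? 0 : fw.prefix(n + 1) - fw.prefix(static_cast<int>(c));
    };
    const ll HI = 2000000000LL;  // t(t+1)/2 exceeds 1e18 here
    vector<ll> ans(qs.size(), -1);
    auto answerAt = [&](int u) {
        for (int id : at[u]) {
            const Query& q = qs[id];
            auto ok = [&](ll t) {
                ll tri = t * (t + 1) / 2;
                return countAtLeast((q.K + tri - 1) / tri) >= q.T;
            };
            if (!ok(HI)) continue;  // never enough vertices: answer stays -1
            ll lo = 1, hi = HI;
            while (lo < hi) {
                ll mid = lo + (hi - lo) / 2;
                if (ok(mid)) hi = mid; else lo = mid + 1;
            }
            ans[id] = lo;
        }
    };
    // Iterative DFS: insert m[u] on entry, answer u's queries, remove m[u] on exit.
    vector<int> parent(n, -1), nextChild(n, 0), stk{0};
    fw.add(m[0], 1);
    answerAt(0);
    while (!stk.empty()) {
        int u = stk.back();
        if (nextChild[u] < (int)adj[u].size()) {
            int v = adj[u][nextChild[u]++];
            if (v == parent[u]) continue;
            parent[v] = u;
            fw.add(m[v], 1);
            answerAt(v);
            stk.push_back(v);
        } else {
            fw.add(m[u], -1);
            stk.pop_back();
        }
    }
    return ans;
}
```

Because m_x can equal N (for example, the centre of a star whose vertices are all activated), the Fenwick tree is sized N + 1.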

It is also possible to solve the queries online (for example, with a persistent segment tree).
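For the online variant, a persistent segment tree can keep, for every vertex v, a version containing the m_x values on the path 1 \to v. Each version is built from its parent's by path copying. The sketch below shows one possible layout. The names `PersistentCount` and `buildVersions` are hypothetical, vertices are 0-indexed, and m_x is assumed to lie in [0, N]. A query at u would then run the same binary search on time, calling `countAtLeast(version[u], c)`.

```cpp
#include <vector>
using namespace std;

// Persistent segment tree counting values in [0, maxV]; node 0 is a shared empty node.
struct PersistentCount {
    vector<int> lc{0}, rc{0}, cnt{0};
    int maxV;
    explicit PersistentCount(int maxValue) : maxV(maxValue) {}

    // New version equal to `prev` plus one copy of `pos` (copies O(log maxV) nodes).
    int insert(int prev, int pos) { return insertRec(prev, 0, maxV, pos); }
    // Number of stored values >= c in the version rooted at `root`.
    int countAtLeast(int root, int c) const { return countRec(root, 0, maxV, c); }

private:
    int insertRec(int prev, int l, int r, int pos) {
        int a = lc[prev], b = rc[prev], k = cnt[prev];
        int cur = static_cast<int>(cnt.size());
        lc.push_back(a); rc.push_back(b); cnt.push_back(k + 1);
        if (l == r) return cur;
        int mid = (l + r) / 2;
        if (pos <= mid) { int child = insertRec(a, l, mid, pos); lc[cur] = child; }
        else            { int child = insertRec(b, mid + 1, r, pos); rc[cur] = child; }
        return cur;
    }
    int countRec(int node, int l, int r, int c) const {
        if (node == 0 || r < c) return 0;
        if (l >= c) return cnt[node];
        int mid = (l + r) / 2;
        return countRec(lc[node], l, mid, c) + countRec(rc[node], mid + 1, r, c);
    }
};

// version[v] holds m_x for every x on the path root (vertex 0) -> v.
vector<int> buildVersions(PersistentCount& pc, const vector<int>& m,
                          const vector<vector<int>>& adj) {
    int n = static_cast<int>(m.size());
    vector<int> version(n, 0), parent(n, -1), order{0};
    version[0] = pc.insert(0, m[0]);
    for (size_t i = 0; i < order.size(); ++i) {  // BFS: parents get versions first
        int u = order[i];
        for (int v : adj[u]) {
            if (v == parent[u]) continue;
            parent[v] = u;
            version[v] = pc.insert(version[u], m[v]);
            order.push_back(v);
        }
    }
    return version;
}
```

Each vertex adds \mathcal{O}(\log N) nodes, so memory is \mathcal{O}(N \log N).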

TIME COMPLEXITY

\mathcal{O}(N\log N + Q\log N\log M) per test case (each of the N insertions and removals costs \mathcal{O}(\log N)).

CODE:

Setter's code (C++)
#include <map>
#include <set>
#include <cmath>
#include <ctime>
#include <queue>
#include <stack>
#include <cstdio>
#include <cstdlib>
#include <vector>
#include <cstring>
#include <algorithm>
using namespace std;
typedef double db; 
typedef long long ll;
typedef unsigned long long ull;
const int N=1000010;
const int LOGN=28;
const ll  TMD=0;
const ll  INF=2147483647;
int n,q,cur;
int a[N],key[N],tag[N],enough[N],tot[N];
ll  ans[N];
vector<int> val;
vector<int> G[N];

struct query
{
	int t,id;
	ll  k;
	
	query() {}
	
	query(int t,ll k,int id):t(t),k(k),id(id) {}
};
vector<query> Q[N];

void DFS(int x,int pre)
{
	tot[x]=tot[pre]+enough[x];
	for(int i=0;i<G[x].size();i++)
	{
		int y=G[x][i];
		if(y==pre) continue;
		DFS(y,x);
	}
}

void init()
{
	scanf("%d",&n);	
	for(int i=1;i<=n;i++) scanf("%d",&a[i]);
	for(int i=1;i<n;i++)
	{
		int u,v;
		scanf("%d%d",&u,&v);
		G[u].push_back(v);
		G[v].push_back(u);
	}
	scanf("%d",&q);
	for(int i=1;i<=q;i++)
	{
		int x,t;
		ll  k;
		scanf("%d%d%lld",&x,&t,&k);
		Q[x].push_back(query(t,k,i));
	}
} 

void solve()
{
	for(int i=1;i<=q;i++) ans[i]=-1;
	for(int i=1;i<=n;i++)
	{
		key[i]+=a[i];
		for(int j=0;j<G[i].size();j++) key[i]+=a[G[i][j]];
		if(!tag[key[i]]&&key[i]) tag[key[i]]=1,val.push_back(key[i]);
	}
	sort(val.begin(),val.end(),greater<int>() );
	for(int i=0;i<val.size();i++)
	{
		for(int j=1;j<=n;j++) if(key[j]>=val[i]) enough[j]=1;
		DFS(1,0);
		for(int j=1;j<=n;j++)
		{
			for(int k=0;k<Q[j].size();k++)
			{
				query qu=Q[j][k];
				if(ans[qu.id]==-1&&tot[j]>=qu.t)
				{
					ll L=0,R=(ll)sqrt(qu.k/val[i]*2+1)+2,M;
					while(L+1!=R)
					{
						M=(L+R)>>1;
						if(M*(1+M)/2*val[i]>=qu.k) R=M;
						else L=M;
					}
					ans[qu.id]=R;
				}	
			}
		}
	}
	for(int i=1;i<=q;i++) printf("%lld\n",ans[i]);
} 

int main()
{
	init();
	solve();
	
	return 0;
}
Tester (rivalq)'s code (C++)
// Jai Shree Ram  
  
#include<bits/stdc++.h>
using namespace std;

#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>
using namespace __gnu_pbds;
template <class T> using Tree = tree<T, null_type, less<T>, rb_tree_tag, tree_order_statistics_node_update>; 
#define ook order_of_key
#define fbo find_by_order

#define rep(i,a,n)     for(int i=a;i<n;i++)
#define ll             long long
#define int            long long
#define pb             push_back
#define all(v)         v.begin(),v.end()
#define endl           "\n"
#define x              first
#define y              second
#define gcd(a,b)       __gcd(a,b)
#define mem1(a)        memset(a,-1,sizeof(a))
#define mem0(a)        memset(a,0,sizeof(a))
#define sz(a)          (int)a.size()
#define pii            pair<int,int>
#define hell           1000000007
#define elasped_time   1.0 * clock() / CLOCKS_PER_SEC



template<typename T1,typename T2>istream& operator>>(istream& in,pair<T1,T2> &a){in>>a.x>>a.y;return in;}
template<typename T1,typename T2>ostream& operator<<(ostream& out,pair<T1,T2> a){out<<a.x<<" "<<a.y;return out;}
template<typename T,typename T1>T maxs(T &a,T1 b){if(b>a)a=b;return a;}
template<typename T,typename T1>T mins(T &a,T1 b){if(b<a)a=b;return a;}

// -------------------- Input Checker Start --------------------
 
long long readInt(long long l, long long r, char endd)
{
    long long x = 0;
    int cnt = 0, fi = -1;
    bool is_neg = false;
    while(true)
    {
        char g = getchar();
        if(g == '-')
        {
            assert(fi == -1);
            is_neg = true;
            continue;
        }
        if('0' <= g && g <= '9')
        {
            x *= 10;
            x += g - '0';
            if(cnt == 0)
                fi = g - '0';
            cnt++;
            assert(fi != 0 || cnt == 1);
            assert(fi != 0 || is_neg == false);
            assert(!(cnt > 19 || (cnt == 19 && fi > 1)));
        }
        else if(g == endd)
        {
            if(is_neg)
                x = -x;
            if(!(l <= x && x <= r))
            {
                cerr << l << ' ' << r << ' ' << x << '\n';
                assert(false);
            }
            return x;
        }
        else
        {
            assert(false);
        }
    }
}
 
string readString(int l, int r, char endd)
{
    string ret = "";
    int cnt = 0;
    while(true)
    {
        char g = getchar();
        assert(g != -1);
        if(g == endd)
            break;
        cnt++;
        ret += g;
    }
    assert(l <= cnt && cnt <= r);
    return ret;
}
 
long long readIntSp(long long l, long long r) { return readInt(l, r, ' '); }
long long readIntLn(long long l, long long r) { return readInt(l, r, '\n'); }
string readStringLn(int l, int r) { return readString(l, r, '\n'); }
string readStringSp(int l, int r) { return readString(l, r, ' '); }
void readEOF() { assert(getchar() == EOF); }
 
vector<int> readVectorInt(int n, long long l, long long r)
{
    vector<int> a(n);
    for(int i = 0; i < n - 1; i++)
        a[i] = readIntSp(l, r);
    a[n - 1] = readIntLn(l, r);
    return a;
}
 
// -------------------- Input Checker End --------------------

const int maxn=2000005;
int p[maxn];
int sz[maxn];
void clear(int n=maxn){
    rep(i,0,maxn)p[i]=i,sz[i]=1;
}
int root(int x){
   while(x!=p[x]){
       p[x]=p[p[x]];
       x=p[x];
   }
   return x;  
}
void merge(int x,int y){
    int p1=root(x);
    int p2=root(y);
    if(p1==p2)return;
    if(sz[p1]>=sz[p2]){
        p[p2]=p1;
        sz[p1]+=sz[p2];
    }
    else{
        p[p1]=p2;
        sz[p2]+=sz[p1];
    }
}

int solve(){
		
                int n = readIntLn(1,4e4);
                vector<int> a = readVectorInt(n,0,1);
                vector<int> val(n);
                vector<vector<int>> g(n);
                clear();
                for(int i = 2; i <= n; i++){
                        int u = readIntSp(1,n);
                        int v = readIntLn(1,n);
                        u--;v--;
                        g[u].push_back(v);
                        g[v].push_back(u);
                        assert(root(u) != root(v));
                        merge(u,v);
                } 

                for(int i = 0; i < n; i++){
                        val[i] += a[i];
                        for(auto j:g[i])val[j] += a[i];
                }

                int q = readIntLn(1,4e4);
                vector<vector<array<int,3>>> queries(n);
                vector<int> ans(q,-1);
                for(int i = 0; i < q; i++){
                        int x = readIntSp(1,n);
                        x--;
                        int t = readIntSp(1,n);
                        int k = readIntLn(1,1e18);
                        queries[x].push_back({t,k,i});
                }
                Tree<pii> st;
                function<void(int,int)> dfs = [&](int u,int p){
                        if(val[u] > 0) st.insert({val[u],u});
                        for(auto [t,k,j]:queries[u]){
                                if(sz(st) < t){
                                        ans[j] = -1;
                                }else{
                                        auto itr = st.fbo(sz(st) - t);
                                        int d = itr->x;
                                        //d*(i*(i + 1)) >= 2*k
                                        int L = 1, R = 2e9;
                                        k *= 2;
                                        while(L <= R){
                                                int M = (L + R)/2;
                                                int reqd = k/d + (k % d != 0); // ceil(2K / d)
                                                if(M*(M + 1) >= reqd){
                                                        ans[j] = M;
                                                        R = M - 1;
                                                }else{
                                                        L = M + 1;
                                                }
                                        }
                                }
                        }
                        for(auto j:g[u]){
                                if(j != p)dfs(j,u);
                        }
                        st.erase({val[u],u});
                };
                dfs(0,0);
                for(int i = 0; i < q; i++)cout << ans[i] << endl;

 return 0;
}
signed main(){
    ios_base::sync_with_stdio(0);cin.tie(0);cout.tie(0);
    //freopen("input.txt", "r", stdin);
    //freopen("output.txt", "w", stdout);
    #ifdef SIEVE
    sieve();
    #endif
    #ifdef NCR
    init();
    #endif
    int t = 1;
    while(t--){
        solve();
    }
    return 0;
}
Tester (utkarsh_25dec)'s code (C++)
//Utkarsh.25dec
#include <bits/stdc++.h>
#define ll long long int
#define pb push_back
#define mp make_pair
#define mod 1000000007
#define vl vector <ll>
#define all(c) (c).begin(),(c).end()
using namespace std;
#include <ext/pb_ds/assoc_container.hpp> 
#include <ext/pb_ds/tree_policy.hpp> 
using namespace __gnu_pbds; 
#define ordered_set tree<pair<int,int>, null_type,less<pair<int,int>>, rb_tree_tag,tree_order_statistics_node_update>
// ordered_set s ; s.order_of_key(val)  no. of elements strictly less than val
// s.find_by_order(i)  itertor to ith element (0 indexed)
ll power(ll a,ll b) {ll res=1;a%=mod; assert(b>=0); for(;b;b>>=1){if(b&1)res=res*a%mod;a=a*a%mod;}return res;}
ll modInverse(ll a){return power(a,mod-2);}
long long readInt(long long l,long long r,char endd){
    long long x=0;
    int cnt=0;
    int fi=-1;
    bool is_neg=false;
    while(true){
        char g=getchar();
        if(g=='-'){
            assert(fi==-1);
            is_neg=true;
            continue;
        }
        if('0'<=g && g<='9'){
            x*=10;
            x+=g-'0';
            if(cnt==0){
                fi=g-'0';
            }
            cnt++;
            assert(fi!=0 || cnt==1);
            assert(fi!=0 || is_neg==false);

            assert(!(cnt>19 || ( cnt==19 && fi>1) ));
        } else if(g==endd){
            if(is_neg){
                x= -x;
            }

            if(!(l <= x && x <= r))
            {
                cerr << l << ' ' << r << ' ' << x << '\n';
                assert(1 == 0);
            }

            return x;
        } else {
            assert(false);
        }
    }
}
string readString(int l,int r,char endd){
    string ret="";
    int cnt=0;
    while(true){
        char g=getchar();
        assert(g!=-1);
        if(g==endd){
            break;
        }
        cnt++;
        ret+=g;
    }
    assert(l<=cnt && cnt<=r);
    return ret;
}
long long readIntSp(long long l,long long r){
    return readInt(l,r,' ');
}
long long readIntLn(long long l,long long r){
    return readInt(l,r,'\n');
}
string readStringLn(int l,int r){
    return readString(l,r,'\n');
}
string readStringSp(int l,int r){
    return readString(l,r,' ');
}
const int N=500023;
bool vis[N];
vector <int> adj[N];
int energetic[N];
vector <pair<ll,ll>> nodequ[N];
int deg[N];
ordered_set s;
map <tuple<ll,ll,ll>,ll> ans;
void dfs(int curr)
{
    vis[curr]=1;
    s.insert(mp(deg[curr],curr));
    for(auto p:nodequ[curr])
    {
        ll nodesreq=p.first;
        ll valreq=p.second;
        if(s.size()<nodesreq)
        {
            ans[make_tuple(curr,nodesreq,valreq)]=-1;
            continue;
        }
        int tmp=s.size();
        auto it=s.find_by_order(tmp-nodesreq);
        auto val=(*it);
        ll d=val.first;
        if(d==0)
        {
            ans[make_tuple(curr,nodesreq,valreq)]=-1;
            continue;
        }
        else
        {
            ll low=1,high=sqrt(4e18/d);
            while(low<=high)
            {
                ll mid=(low+high)/2;
                if((d*mid*(mid+1))/2 >= valreq)
                    high=mid-1;
                else
                    low=mid+1;
            }
            ans[make_tuple(curr,nodesreq,valreq)]=low;
        }
    }
    for(auto it:adj[curr])
    {
        if(vis[it])
            continue;
        dfs(it);
    }
    s.erase(mp(deg[curr],curr));
}
void solve()
{
    int n=readInt(1,40000,'\n');
    for(int i=1;i<=n;i++)
    {
        vis[i]=0;
        adj[i].clear();
        nodequ[i].clear();
        deg[i]=0;
    }
    s.clear();
    ans.clear();
    for(int i=1;i<=n;i++)
    {
        if(i!=n)
            energetic[i]=readInt(0,1,' ');
        else
            energetic[i]=readInt(0,1,'\n');
    }
    for(int i=1;i<n;i++)
    {
        int u,v;
        u=readInt(1,n,' ');
        v=readInt(1,n,'\n');
        assert(u!=v);
        adj[u].pb(v);
        adj[v].pb(u);
    }
    for(int i=1;i<=n;i++)
    {
        deg[i]+=energetic[i];
        for(auto j:adj[i])
            deg[i]+=energetic[j];
    }
    int q=readInt(1,40000,'\n');
    vector <tuple<ll,ll,ll>> queries;
    while(q--)
    {
        ll x,t,k;
        x=readInt(1,n,' ');
        t=readInt(1,n,' ');
        k=readInt(1,1e18,'\n');
        nodequ[x].pb(mp(t,k));
        queries.pb(make_tuple(x,t,k));
    }
    dfs(1);
    for(int i=1;i<=n;i++)
        assert(vis[i]==1);
    for(auto q:queries)
        cout<<ans[q]<<'\n';
}
int main()
{
    #ifndef ONLINE_JUDGE
    freopen("input.txt", "r", stdin);
    freopen("output.txt", "w", stdout);
    #endif
    ios_base::sync_with_stdio(false);
    cin.tie(NULL),cout.tie(NULL);
    int T=1;
    //cin>>T;
    while(T--)
        solve();
    cerr << "Time : " << 1000 * ((double)clock()) / (double)CLOCKS_PER_SEC << "ms\n";
}
Editorialist's code (C++)
#include "bits/stdc++.h"
// #pragma GCC optimize("O3,unroll-loops")
// #pragma GCC target("avx2,bmi,bmi2,lzcnt,popcnt")
using namespace std;
using ll = long long int;
mt19937_64 rng(chrono::high_resolution_clock::now().time_since_epoch().count());

struct query {
	ll t, k, id;
};

/**
 * Point-update Segment Tree
 * Source: kactl
 * Description: Iterative point-update segment tree, ranges are half-open i.e [L, R).
 *              f is any associative function.
 * Time: O(logn) update/query
 */

template<class T, T unit = T()>
struct SegTree {
	T f(T a, T b) { return a+b; }
	vector<T> s; int n;
	SegTree(int _n = 0, T def = unit) : s(2*_n, def), n(_n) {}
	void update(int pos, T val) {
		for (s[pos += n] += val; pos /= 2;)
			s[pos] = f(s[pos * 2], s[pos * 2 + 1]);
	}
	T query(int b, int e) {
		T ra = unit, rb = unit;
		for (b += n, e += n; b < e; b /= 2, e /= 2) {
			if (b % 2) ra = f(ra, s[b++]);
			if (e % 2) rb = f(s[--e], rb);
		}
		return f(ra, rb);
	}
};

int main()
{
	ios::sync_with_stdio(false); cin.tie(0);

	int n; cin >> n;
	vector<int> a(n), b(n);
	for (int &x : a) cin >> x;
	vector<vector<int>> g(n);
	for (int i = 0; i < n-1; ++i) {
		int u, v; cin >> u >> v;
		g[--u].push_back(--v);
		g[v].push_back(u);
	}
	for (int i = 0; i < n; ++i) {
		b[i] = a[i];
		for (int u : g[i]) b[i] += a[u];
	}
	vector<vector<query>> queries(n);
	int q; cin >> q;
	for (int i = 0; i < q; ++i) {
		query cur;
		int u; cin >> u; --u;
		cin >> cur.t >> cur.k;
		cur.id = i;
		queries[u].push_back(cur);
	}
	vector<ll> ans(q, -1);
	SegTree<int> T(n+1); // m values lie in [0, n], so n+1 positions are needed
	const ll inf = 2.5e9;
	auto dfs = [&] (const auto &self, int u, int p) -> void {
		T.update(b[u], 1);
		for (auto qry : queries[u]) {
			ll lo = 1, hi = inf;
			while (lo < hi) {
				ll mid = (lo + hi)/2;
				ll den = (mid * (mid+1) / 2);
				ll want = (qry.k + den - 1) / den;

				if (want > n or T.query(want, n+1) < qry.t) lo = mid+1;
				else hi = mid;
			}
			if (lo == inf) lo = -1;
			ans[qry.id] = lo;
		}

		for (int v : g[u]) {
			if (v == p) continue;
			self(self, v, u);
		}
		T.update(b[u], -1);
	};
	dfs(dfs, 0, 0);
	for (auto x : ans) cout << x << '\n';
}

I think the overall time complexity is O(N log N + Q log N)? Maintaining a Fenwick/segment tree costs O(N log N), and then binary searching for the T-th maximum costs another O(Q log N).
Besides, I didn’t see much difference between the easy and hard versions. The only difference is K, and it does not affect the overall time complexity much: one can solve m_x\times t \times (t+1)/2 \ge K for t in O(1) time. Even with a binary search like in my solution, it costs only an extra O(Q log K), which is not too hard for a contestant who can solve the easy version.
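One way to get the O(1) computation mentioned above is to start from the floating-point root of t(t+1)/2 = \lceil K/m_x \rceil and then correct any rounding error by a step or two. A sketch, assuming m_x \geq 1 and 1 \leq K \leq 10^{18}; `minTime` is a hypothetical name.

```cpp
#include <cmath>
using ll = long long;

// Smallest t >= 1 with m * t * (t + 1) / 2 >= K, assuming m >= 1 and 1 <= K <= 1e18.
ll minTime(ll m, ll K) {
    // t(t+1)/2 is an integer, so m * t(t+1)/2 >= K  <=>  t(t+1)/2 >= ceil(K / m).
    ll need = (K + m - 1) / m;
    // Positive root of t^2 + t - 2*need = 0, computed in long double.
    ll t = static_cast<ll>((std::sqrt(8.0L * need + 1.0L) - 1.0L) / 2.0L);
    if (t < 1) t = 1;
    // Fix rounding in either direction; these loops run only a few times.
    while (t > 1 && (t - 1) * t / 2 >= need) --t;
    while (t * (t + 1) / 2 < need) ++t;
    return t;
}
```

The two correction loops guarantee the exact minimum even if the floating-point estimate is slightly off.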

According to my original idea, this problem shouldn’t be solvable so easily with data structures + binary search: an important property has not been used.

In retrospect, the information on the path from node 1 to x is too easy to maintain.

Let’s focus on a harder version: each query asks about the union of several subtrees.

That is, find the minimum number of seconds after which there are at least T_i nodes with energy not less than K_i in the union of the subtrees of X_1, X_2, \ldots, X_m.

(Although this can still be maintained using a persistent LCT, at least it encourages us to look for properties rather than reaching directly for data structures…)

Solution (no data structure)

For vertex x, let m_x be the sum of A_v across all vertices v such that v = x or v is a neighbor of x. Then, vertex x will receive exactly m_x \cdot i energy at time i (the same as in the editorial above).

Conclusion: there are only O(\sqrt{N}) distinct values of m_x.

Proof:

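Suppose there are k distinct values of m_x. They are distinct non-negative integers, and m_i \leq \deg(i)+1 because every A_v is 0 or 1, so

0+1+\ldots+(k-1) \leq \sum_i m_i \leq \sum_i (\deg(i)+1) = n+\sum_i \deg(i) = n+2(n-1) < 3n.

Hence k(k-1)/2 < 3n, i.e. k = O(\sqrt{n}).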

q.e.d

Enumerate each distinct value m_i, and let value_j = 1 if m_j \geq m_i and value_j = 0 otherwise. Then we can simply calculate the answer on each of these 0/1-weighted trees with one traversal per distinct value.

Time complexity:O(N\sqrt{N}+Q\sqrt{N}).
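For the original root-path queries, this idea can be sketched as follows (close in spirit to the setter's code). The distinct positive m values are processed in decreasing order. For each value v, the sketch counts the vertices with m \geq v on every root path. It then answers each still-open query whose path already has at least T such vertices, since that v is the T-th largest m on the path. Names are hypothetical and vertices are 0-indexed.

```cpp
#include <algorithm>
#include <vector>
using namespace std;
using ll = long long;

struct Query { int u; int T; ll K; };  // u is 0-indexed; the path is root (vertex 0) -> u

vector<ll> solveSqrt(const vector<int>& A, const vector<vector<int>>& adj,
                     const vector<Query>& qs) {
    int n = static_cast<int>(A.size());
    vector<int> m(n);
    for (int x = 0; x < n; ++x) {
        m[x] = A[x];
        for (int v : adj[x]) m[x] += A[v];
    }
    // BFS order from the root, so every parent is handled before its children.
    vector<int> order{0}, parent(n, -1);
    for (size_t i = 0; i < order.size(); ++i) {
        int u = order[i];
        for (int v : adj[u])
            if (v != parent[u]) { parent[v] = u; order.push_back(v); }
    }
    // Distinct positive m values, largest first: only O(sqrt(N)) of them.
    vector<int> vals;
    for (int x : m) if (x > 0) vals.push_back(x);
    sort(vals.rbegin(), vals.rend());
    vals.erase(unique(vals.begin(), vals.end()), vals.end());

    vector<ll> ans(qs.size(), -1);
    vector<int> tot(n);
    for (int v : vals) {
        // tot[x] = number of vertices y on the path root -> x with m[y] >= v.
        for (int x : order)
            tot[x] = (parent[x] >= 0 ? tot[parent[x]] : 0) + (m[x] >= v ? 1 : 0);
        for (size_t i = 0; i < qs.size(); ++i) {
            if (ans[i] != -1 || tot[qs[i].u] < qs[i].T) continue;
            ll need = (qs[i].K + v - 1) / v;  // t(t+1)/2 must reach ceil(K / v)
            ll lo = 1, hi = 2000000000LL;
            while (lo < hi) {
                ll mid = lo + (hi - lo) / 2;
                if (mid * (mid + 1) / 2 >= need) hi = mid;
                else lo = mid + 1;
            }
            ans[i] = lo;
        }
    }
    return ans;
}
```

Queries that are still -1 after the last value never have T path vertices with m \geq 1.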


It would also be a great problem if each query asked for the answer on the path from u to v.

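This version (querying the path from u to v) is also DS-bashable.
We can split a query into the paths root->u, root->v and root->lca(u, v): the count on the u-v path is cnt(u) + cnt(v) - cnt(lca) - cnt(parent(lca)), and each of these counts comes from a persistent segment tree.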

I just solved it with a DFS + PBDS Set. Submission.
AFAIK, the time complexity is O(V log V + E + Q log V).
Sadly, couldn’t solve it during the contest.