ATWNT - Editorial

Author: Tursynbay Dinmukhamed
Testers: Felipe Mota
Editorialist: Vichitr Gandas

DIFFICULTY:

Medium

PREREQUISITES:

Graphs, DFS, Sieve of Eratosthenes

PROBLEM:

Given a rooted tree and Q queries. In each query, W tasks are given to node V. If V is a leaf, it executes all of the tasks. If it has K children and receives A tasks, then if A \% K = 0 it divides the tasks equally among its children; otherwise all A tasks are ignored. For each query, find the number of tasks that are ignored.

QUICK EXPLANATION

If a node has only one child, it passes all tasks to that child. If the node is a leaf, it performs all tasks. If it has more than one child, it divides the tasks equally when possible and otherwise drops all of them. First merge every node v with its child if v has exactly one child. After this step, each leaf receives tasks from at most \log(W) ancestors. For each node, compute how much it contributes to the leaves of its subtree and store this in a map. This can be done with a DFS that returns the map \{X: number of leaves in the subtree that receive a 1/X share of the node's tasks\}. The queries on node v can be answered when the DFS visits v.

EXPLANATION

Lemma 1: Tasks are actually performed by leaves only

When a node has one child, it gives all tasks to that child. If it has more than one child, it divides the tasks equally among its children, or, if an equal division is not possible, all the given tasks are ignored. So tasks are either ignored or passed down to the children. This stops when we reach the leaves, and the leaves perform the tasks they receive.
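
The process described in Lemma 1 can be simulated directly, which is handy as a reference when checking a faster solution on small inputs. The sketch below is only an illustration and is not one of the contest solutions; the function name `ignoredBruteForce` and the 1-indexed adjacency list `children` (with `children[v]` holding the children of `v`) are assumptions made for this example.

```cpp
#include <cassert>
#include <vector>

// Directly simulates one query: node v receives w tasks.
// Returns the number of tasks that end up ignored.
// `children` is a hypothetical 1-indexed adjacency list of the rooted tree.
long long ignoredBruteForce(int v, long long w,
                            const std::vector<std::vector<int>>& children) {
    assert(w >= 0);
    const long long k = static_cast<long long>(children[v].size());
    if (k == 0) return 0;      // a leaf executes everything it receives
    if (w % k != 0) return w;  // cannot split evenly: all w tasks are ignored
    long long ignored = 0;
    for (int u : children[v]) {
        // every child receives an equal share of w / k tasks
        ignored += ignoredBruteForce(u, w / k, children);
    }
    return ignored;
}
```

For a root 1 with children 2, 3, 4, where node 2 has two leaf children, `ignoredBruteForce(1, 3, children)` returns 1: each child receives one task, the two leaf children execute theirs, and node 2 cannot split a single task between two children. A single call may visit the whole subtree, so this is too slow for the full constraints.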

Now merge every node v with its child whenever v has exactly one child. Every remaining non-leaf node has at least 2 children, so it either divides the tasks among its children or ignores them, and each division reduces the number of tasks by a factor of at least 2. Hence W becomes 1 after at most \log_2 W steps, so each leaf can receive tasks from at most \log_2 W ancestors.
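
The merging step can be sketched as follows. This is a minimal illustration under stated assumptions, not the setter's implementation: `compressChains` and the 1-indexed `children` list are names chosen for this example. For every node it finds the first node at or below it that does not have exactly one child, which is the node that actually decides what happens to the tasks.

```cpp
#include <cassert>
#include <vector>

// rep[v] = first node at or below v whose child count is not exactly 1.
// `children` is a hypothetical 1-indexed adjacency list (index 0 unused).
std::vector<int> compressChains(const std::vector<std::vector<int>>& children) {
    assert(!children.empty());
    const int n = static_cast<int>(children.size()) - 1;
    std::vector<int> rep(n + 1, 0);  // 0 means "not computed yet"
    std::vector<int> path;
    for (int v = 1; v <= n; ++v) {
        if (rep[v] != 0) continue;
        path.clear();
        int u = v;
        // walk down the single-child chain until it ends or meets a known answer
        while (rep[u] == 0 && children[u].size() == 1) {
            path.push_back(u);
            u = children[u][0];
        }
        const int target = (rep[u] != 0) ? rep[u] : u;
        rep[u] = target;
        for (int x : path) rep[x] = target;  // every node on the chain shares it
    }
    return rep;
}
```

Each node is placed on a chain at most once, so the whole pass is linear in the number of nodes. A query on `v` can then be redirected to `rep[v]` without changing its answer, because a node with one child always divides evenly.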

Now let's come back to the problem. In each query, node v receives w tasks. To calculate how many of the w tasks are actually performed, we need to know what share of the tasks each leaf in the subtree of node v receives. See the figure below.

So for each node we maintain the count of leaves with each particular contribution. Take the tree in the figure above as an example. The map of node 2 contains \{2: 2\}, which means that there are 2 leaves in the subtree of node 2, each with contribution 1/2. Similarly, the map of node 1 contains \{3: 2, 6: 2\}: there are 2 leaves in the subtree of node 1 with contribution 1/3 each (leaves 3 and 4), and 2 leaves with contribution 1/6 each (leaves 5 and 6).

Now how do we construct this map? For a leaf, the map just contains \{1:1\}, because there is only 1 leaf, with contribution 1. Now if we have the maps of the children, how do we build the map of node v?
Say the map of child u has an entry \{a:b\}, meaning that b leaves each receive 1/a of the tasks given to u. If the current node v has k children, we add \{a*k: b\} to the map of v, because each child receives only 1/k of the tasks given to v. We add \{a*k: b\} only if a*k <= 10^6, because W <= 10^6 by the constraints, so a larger value can never divide W.
This way we iterate over all child maps and build the map of the current node.

Pseudocode
k = graph[v].size
curMap = {}
for u in graph[v]:
    childMap = dfs(u)
    for [a, b] in childMap:
        if a*k <= 10^6:
            curMap[a*k] += b


In the pseudocode, dfs(u) returns the map formed for the child u.
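
As a runnable counterpart to the pseudocode, here is a minimal sketch of the map construction. The names `LeafMap` and `buildMaps`, the `limit` parameter and the 1-indexed `children` list are assumptions made for this example. It stores the map of every node instead of returning it, which costs more memory than the DFS described above, and it uses `long long` keys so that the product a*k cannot overflow.

```cpp
#include <cassert>
#include <map>
#include <vector>

// key: divisor X, value: number of leaves that receive a 1/X share
using LeafMap = std::map<long long, int>;

// Fills maps[u] for every node u in the subtree of v.
// `children` is a hypothetical 1-indexed adjacency list; `limit` is the
// largest possible W, so larger divisors are dropped.
void buildMaps(int v, const std::vector<std::vector<int>>& children,
               std::vector<LeafMap>& maps, long long limit) {
    assert(limit >= 1);
    const long long k = static_cast<long long>(children[v].size());
    LeafMap& cur = maps[v];
    cur.clear();
    if (k == 0) {  // a leaf: one leaf that receives the whole amount
        cur[1] = 1;
        return;
    }
    for (int u : children[v]) {
        buildMaps(u, children, maps, limit);
        for (const auto& entry : maps[u]) {
            const long long d = entry.first * k;  // share shrinks by a factor of k
            if (d <= limit) cur[d] += entry.second;
        }
    }
}
```

On a tree that is consistent with the maps quoted in the text (node 1 with children 2, 3, 4 and node 2 with children 5, 6; this shape is inferred from the text, not taken from the figure), the call `buildMaps(1, children, maps, 1000000)` leaves {2: 2} in `maps[2]` and {3: 2, 6: 2} in `maps[1]`. A node with a single child multiplies by 1, so it simply inherits the map of its child. The recursion depth equals the tree height, so a very deep tree may need an explicit stack.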

Now to answer a query on node v, let w be the number of tasks received, and let \{a:b\} be an entry in the map of v. If w \% a = 0, each of these b leaves performs w/a tasks, so together they perform w/a * b tasks. Summing over all entries gives the number x of tasks that are performed out of w. The answer to the query is then w - x.

Pseudocode
for [w, idx] in queries[v]:
    cnt = 0
    for [a, b] in curMap:
        if w%a == 0:
            cnt += (w/a) * b
    ans[idx] = w - cnt
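
The query step can be written as a small stand-alone function. This is a sketch under the assumption that the map of the queried node is already available; the name `ignoredTasks` is chosen for this example.

```cpp
#include <cassert>
#include <map>

// leafMap: divisor a -> number b of leaves that receive a 1/a share.
// Returns how many of the w tasks are ignored.
long long ignoredTasks(const std::map<long long, int>& leafMap, long long w) {
    assert(w >= 1);
    long long performed = 0;
    for (const auto& entry : leafMap) {
        const long long a = entry.first;   // each such leaf would get w / a tasks
        const long long b = entry.second;  // number of leaves with this divisor
        if (w % a == 0) performed += (w / a) * b;
    }
    return w - performed;
}
```

With the map {3: 2, 6: 2} from the example, w = 6 gives 0 ignored tasks (2 * 2 + 1 * 2 = 6 performed), w = 3 gives 1 ignored task (only the two leaves with divisor 3 receive work), and w = 4 gives 4 ignored tasks because neither 3 nor 6 divides 4.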


Complexity Analysis
Overall time complexity of this solution is \mathcal{O}((N+Q) \sqrt{N} \log{N}).
Space complexity is \mathcal{O}(N \sqrt{N} + Q) to store the tree, queries, answers for the queries and the maps.
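
The tester's solution avoids scanning a whole map per query: it enumerates only the divisors of w and looks each one up in a count array. A minimal sketch of that divisor enumeration, using a smallest-prime-factor sieve, is given below. The function names `smallestPrimeFactor` and `divisorsOf` are assumptions made for this example and do not appear in the tester's code.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// spf[x] = smallest prime factor of x for 2 <= x <= limit (sieve).
std::vector<int> smallestPrimeFactor(int limit) {
    assert(limit >= 1);
    std::vector<int> spf(limit + 1, 0);
    for (int i = 2; i <= limit; ++i) {
        if (spf[i] != 0) continue;  // already marked, so i is composite
        for (int j = i; j <= limit; j += i)
            if (spf[j] == 0) spf[j] = i;
    }
    return spf;
}

// All divisors of w (not sorted), built from its prime factorization.
std::vector<int> divisorsOf(int w, const std::vector<int>& spf) {
    assert(w >= 1 && w < static_cast<int>(spf.size()));
    std::vector<int> divs(1, 1);
    while (w > 1) {
        const int p = spf[w];
        int e = 0;
        while (w % p == 0) { w /= p; ++e; }
        const std::size_t base = divs.size();
        int pw = 1;
        for (int i = 0; i < e; ++i) {
            pw *= p;  // extend every earlier divisor by p^1, ..., p^e
            for (std::size_t j = 0; j < base; ++j) divs.push_back(divs[j] * pw);
        }
    }
    return divs;
}
```

For w = 12 this yields the six divisors 1, 2, 4, 3, 6, 12 in that order. No value up to 10^6 has more than 240 divisors, so a query handled this way needs at most 240 array lookups.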

SOLUTIONS

Setter's Solution

//#pragma GCC optimize("Ofast,no-stack-protector,unroll-loops")
//#pragma GCC target("avx,avx2")
//#pragma GCC target("avx2")
//#pragma GCC optimize("O3")

//# include <x86intrin.h>
# include <bits/stdc++.h>

# include <ext/pb_ds/assoc_container.hpp>
# include <ext/pb_ds/tree_policy.hpp>

using namespace __gnu_pbds;
using namespace std;

template<typename T> using ordered_set = tree <T, null_type, less<T>, rb_tree_tag, tree_order_statistics_node_update>;

#define _USE_MATH_DEFINES_
#define ll long long
#define ld long double
#define Accepted 0
#define pb push_back
#define mp make_pair
#define sz(x) (int)(x.size())
#define every(x) x.begin(),x.end()
#define F first
#define S second
#define lb lower_bound
#define ub upper_bound
#define For(i,x,y)  for (ll i = x; i <= y; i ++)
#define FOr(i,x,y)  for (ll i = x; i >= y; i --)
#define SpeedForce ios_base::sync_with_stdio(0), cin.tie(0), cout.tie(0)

void setIn(string s) { freopen(s.c_str(),"r",stdin); }
void setOut(string s) { freopen(s.c_str(),"w",stdout); }
void setIO(string s = "") {
// cin.exceptions(cin.failbit);
// throws exception when do smth illegal
// ex. try to read letter into int
if (sz(s)) { setIn(s+".in"), setOut(s+".out"); } // for USACO
}

const double eps = 0.000001;
const ld pi = acos(-1);
const int maxn = 1e7 + 9;
const int mod = 1e9 + 7;
const ll MOD = 1e18 + 9;
const ll INF = 1e18 + 123;
const int inf = 2e9 + 11;
const int mxn = 1e6 + 9;
const int N = 1e5+5;
const int M = 22;
const int pri = 997;
const int Magic = 2101;

const int dx[] = {-1, 0, 1, 0};
const int dy[] = {0, -1, 0, 1};

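// random generator used by rnd(); this declaration is assumed, since the
// listing refers to gen without declaring it
mt19937 gen(chrono::steady_clock::now().time_since_epoch().count());

int rnd (int l, int r) {
    return uniform_int_distribution<int> (l, r)(gen);
}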
int n;
int p[N];
int deg[N];
bool isleaf[N];
int id[N], ans[N];
int nxt[N];
vector < int > g[N];
vector < pair<int,int> > query[mxn];
int pA[mxn], pB[mxn];

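// Follow id[] links to the node that represents v after merging
// single-child chains, compressing the path on the way.
int fix (int v) {
    if(id[v] == v)
        return v;
    return id[v] = fix(id[v]);
}

// Walk from leaf v towards the root and record at every node on the way
// the product of the child counts from that node down to the leaf's parent.
void calcUp (int v) {
    ll x = 1; //product
    while(v && x <= 1e6) {
        g[v].pb(x);
        v = p[v];
        x *= deg[v];
    }
}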

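void solve() {
    cin >> n;
    fill(isleaf+1, isleaf+n+1, 1);
    // read parents; deg[a] counts the children of a
    for (int i = 1; i < n; ++i) {
        int a, b = i+1;
        cin >> a;
        p[b] = a;
        ++deg[a];
        isleaf[a] = false;
    }

    // merge every single-child node into its child
    iota(id+1, id+n+1, 1);
    for (int i = 2; i <= n; ++i) {
        int par = p[i];
        if(deg[par] == 1) {
            id[par] = i;
            p[i] = p[par];
        }
    }

    for (int v = 1; v <= n; ++v) {
        id[v] = fix(v);
    }

    // g[v] collects the divisors (at most 1e6) of the leaves below v
    for (int v = 1; v <= n; ++v) if(isleaf[v]) {
        calcUp(v);
    }

    for (int v = 1; v <= n; ++v)
        if(id[v] == v)
            sort(every(g[v]));

    int m;
    cin >> m;
    for (int i = 1; i <= m; ++i) {
        int v, w;
        cin >> v >> w;
        // a leaf executes everything, so nothing is ignored
        if(deg[v] == 0) {
            ans[i] = 0;
            continue;
        }
        v = id[v];

        // start from "all ignored" and subtract the performed tasks below
        ans[i] = w;
        query[w].pb({v, i});
    }

    // for every divisor x and every multiple y of x, handle the queries with w = y
    for (int x = 1; x < mxn; ++x)
        for (int y = x; y < mxn; y += x) {
            for (auto it : query[y]) {
                int v = it.first;

                // pA[v]: first position in g[v] with value >= x
                while(pA[v] < g[v].size() && g[v][pA[v]] < x) ++pA[v];

                if(pA[v] > pB[v])
                    pB[v] = pA[v];

                // pB[v]: first position in g[v] with value > x
                while(pB[v] < g[v].size() && g[v][pB[v]] <= x) ++pB[v];

                // pB[v] - pA[v] leaves have divisor exactly x
                ans[it.second] -= (y/x) * (pB[v] - pA[v]);
            }
        }

    for (int i = 1; i <= m; ++i)
        cout << ans[i] << '\n';
}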

int main () {
    SpeedForce;

    int T = 1;
    //cin >> T;
    while(T--) solve();

    return Accepted;
}

Tester's Solution
#include <bits/stdc++.h>
using namespace std;
template<typename T = int> vector<T> create(size_t n){ return vector<T>(n); }
template<typename T, typename... Args> auto create(size_t n, Args... args){ return vector<decltype(create<T>(args...))>(n, create<T>(args...)); }
long long readInt(long long l,long long r,char endd){
long long a;
cin >> a;
return a;
long long x=0;
int cnt=0;
int fi=-1;
bool is_neg=false;
while(true){
char g=getchar();
if(g=='-'){
assert(fi==-1);
is_neg=true;
continue;
}
if('0'<=g && g<='9'){
x*=10;
x+=g-'0';
if(cnt==0){
fi=g-'0';
}
cnt++;
assert(fi!=0 || cnt==1);
assert(fi!=0 || is_neg==false);

assert(!(cnt>19 || ( cnt==19 && fi>1) ));
} else if(g==endd){
if(is_neg){
x= -x;
}
assert(l<=x && x<=r);
return x;
} else {
assert(false);
}
}
}
string ret="";
int cnt=0;
while(true){
char g=getchar();
assert(g!=-1);
if(g==endd){
break;
}
cnt++;
ret+=g;
}
assert(l<=cnt && cnt<=r);
return ret;
}
long long readIntSp(long long l,long long r){
}
long long readIntLn(long long l,long long r){
}
}
}
int main(){
ios::sync_with_stdio(false);
cin.tie(0);
const int maxn = 100000, maxq = 100000, maxw = 1000000;
int n = readInt(1, maxn, '\n');
for(int i = 1; i < n; i++){
int p;
if(i != n - 1) p = readInt(1, n, ' ');
else p = readInt(1, n, '\n');
}
vector<int> real(n), parent(n, -1);
function<void(int)> dfs = [&](int u){
} else {
real[u] = u;
dfs(v);
parent[real[v]] = real[u];
}
}
};
dfs(0);
vector<vector<int>> facs(n);
for(int i = 0; i < n; i++){
int cur = i;
long long dx = 1;
while(cur != -1 && dx <= maxw){
facs[cur].push_back(dx);
cur = parent[cur];
}
}
}
vector<vector<pair<int,int>>> qry(n);
int q = readInt(1, maxq, '\n');
vector<int> ans(q), seenw(maxw + 1, 0);
for(int i = 0; i < q; i++){
int v, w;
v = readInt(1, n, ' ');
qry[real[v - 1]].push_back({i, w});
seenw[w] = 1;
}
vector<int> factor(maxw + 1, -1), nxt(maxw + 1, -1);
for(int i = 2; i <= maxw; i++){
if(factor[i] == -1){
for(int j = i; j <= maxw; j += i){
if(factor[j] == -1)
factor[j] = i;
}
}
nxt[i] = i / factor[i];
}
vector<vector<int>> divs(maxw + 1);
for(int i = 1; i <= maxw; i++){
if (seenw[i]) {
vector<pair<int,int>> f;
int c = i;
while(c != 1){
if(!f.empty() && f.back().first == factor[c]) f.back().second += 1;
else f.push_back({factor[c], 1});
c = nxt[c];
}
function<void(int,int)> brute = [&](int at, int dv){
if(at == f.size()) divs[i].push_back(dv);
else {
for(int j = 0; j <= f[at].second; j++){
brute(at + 1, dv);
dv *= f[at].first;
}
}
};
brute(0, 1);
}
}
vector<int> cnt(maxw + 1, 0);
for(int i = 0; i < n; i++){
if(qry[i].size() > 0){
for(int fac : facs[i]) cnt[fac] += 1;
for(auto e : qry[i]){
int idx, w; tie(idx, w) = e;
int done = 0;
for(int div : divs[w]){
assert(w % div == 0);
done += cnt[div] * (w / div);
}
ans[idx] = w - done;
}
for(int fac : facs[i]) cnt[fac] -= 1;
}
}
for(int i = 0; i < q; i++)
cout << ans[i] << '\n';
return 0;
}

Editorialist's Solution
/***************************************************

@author: vichitr
Compiled On: 13 Feb 2021

*****************************************************/
#include<bits/stdc++.h>
using namespace std;

const int N = 1e5 + 7;

vector<int> graph[N];
vector<pair<int, int>> queries[N];
int ans[N];

map<int, int> dfs(int v){
    int k = graph[v].size();
    map<int, int> curMap;

    // if no child, leaf performs whole tasks
    if(k == 0)
        curMap[1] = 1;
    // if one child, merge it with parent
    else if(k == 1)
        curMap = dfs(graph[v][0]);
    // if more childs, get their maps and build curMap
    else{
        for(int u: graph[v]){
            map<int, int> childMap = dfs(u);
            for(auto ab : childMap){
                int a = ab.first;
                int b = ab.second;
                // a * k can exceed the int range, so compute it in 64 bits
                long long contribution = 1LL * a * k;
                // if contribution > 10^6, it wont pick up any tasks
                if(contribution <= 1000000)
                    curMap[(int)contribution] += b;
            }
        }
    }

    // answer queries for node v
    for(auto widx : queries[v]){
        int w = widx.first;
        int idx = widx.second;
        int cnt = 0;
        for(auto ab : curMap){
            int a = ab.first;
            int b = ab.second;
            if(w % a == 0)
                cnt += w / a * b;
        }
        ans[idx] = w - cnt;
    }

    return curMap;
}

signed main(){
    #ifndef ONLINE_JUDGE
    freopen("input.txt", "r", stdin);
    freopen("output.txt", "w", stdout);
    #endif

    ios_base::sync_with_stdio(0);
    cin.tie(0);

    int n; cin>>n;
    for(int i = 1; i < n; i++){
        int p; cin >> p;
        graph[p].push_back(i + 1);
    }

    int q; cin>>q;
    for(int i = 1; i <= q; i++){
        int v, w; cin >> v >> w;
        queries[v].push_back({w, i});
    }

    // call dfs for node 1
    dfs(1);

    // output queries offline
    for(int i = 1; i <= q; i++)
        cout << ans[i] << '\n';

    return 0;
}




Hi, I understand that each leaf contributes to at most 20 (log 1e6) nodes. But what if we have 10^5 queries, all on the same node? (There is no guarantee that a node has at most K (some upper limit) leaves in its subtree.) I mean, is there any upper limit on the size of map[some node]? Am I missing something?

Isn't the complexity O(Q x sqrt(1e6)) for queries, and is it necessary to process the queries offline?
This is my code getting TLE. Can anyone help me?

You haven't compressed the tree.

Got it. This is my code,
but I think the complexity is Q * sqrt(1e6).

Might be, I am also unable to prove the complexity part.


Well, I believe the complexity of the editorial's solution is q*root(n)*log(n). The log(n) comes from the use of map, and there exists a case where the size of the map can be pushed to root(n). The case is simply when the root has approximately root(n) children and each child has children of its own, with the counts in AP (arithmetic progression): the first child has 2 children, the second has 3, and so on. In this case the size of the root's map will be root(n). I don't think the solution would pass in this case. I had thought of two optimizations that would help. The first is to use a frequency array instead of a map to get rid of the log(n); this alone was enough, giving complexity ((n+q)*(root(w)+log(n))) by iterating over the factors of w in the array, hence the root(w). The other optimization would be to find all the factors of w exactly, using a form of sieve. As the maximum number of factors is approximately w^(1/3), it would have given a significant boost, though it was not required. As for alternate approaches, the one that I applied during the contest was a form of midway caching on recursion. I will spare the details.


Well, I was also not able to prove it. Yes, it's definitely more than \mathcal{O}((N+Q) \log{10^6}), as the case mentioned shows. It looks like it's \mathcal{O}((N+Q) \sqrt{N} \log{N}).


I believe this solution is almost optimal and would take, in the worst case, (n+q)*log(n)*log(n) + max(w)*log(w). The first log(n) is for the sizes of the leaf sets and the second is for sorting. A binary tree would probably make the worst case, which could have been avoided using merge sort, but that is not needed as std::sort is quite fast. Really commendable and elegant improvement from q*w^(1/3) to max(w)*log(w) by using a form of sieve and two pointers.

Can you please give more details on this? I know it's performing a sieve, but I couldn't understand the two pointers pA and pB.

for (int x = 1; x < mxn; ++x)
    for (int y = x; y < mxn; y += x) {
        for (auto it : query[y]) {
            int v = it.first;

            while(pA[v] < g[v].size() && g[v][pA[v]] < x) ++pA[v];

            if(pA[v] > pB[v])
                pB[v] = pA[v];

            while(pB[v] < g[v].size() && g[v][pB[v]] <= x) ++pB[v];

            ans[it.second] -= (y/x) * (pB[v] - pA[v]);
        }
    }
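
One way to read the two pointers in the snippet above, offered as a sketch rather than as the setter's own explanation: since `g[v]` is sorted and `x` only grows in the outer loop, `pA[v]` can be kept at the first position whose value is at least `x` and `pB[v]` at the first position whose value is greater than `x`, so `pB[v] - pA[v]` is the number of entries equal to `x`. The struct `EqualCounter` below is a hypothetical stand-alone version of that idea for a single sorted vector.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Counts how many entries of a sorted vector equal x, for a sequence of
// non-decreasing x values, moving two pointers forward only.
struct EqualCounter {
    const std::vector<int>& sorted;  // must be sorted in non-decreasing order
    std::size_t lo;                  // plays the role of pA[v]
    std::size_t hi;                  // plays the role of pB[v]
    bool used;
    int lastX;

    explicit EqualCounter(const std::vector<int>& s)
        : sorted(s), lo(0), hi(0), used(false), lastX(0) {}

    std::size_t count(int x) {
        assert(!used || x >= lastX);  // x may never decrease between calls
        used = true;
        lastX = x;
        while (lo < sorted.size() && sorted[lo] < x) ++lo;   // first value >= x
        if (hi < lo) hi = lo;
        while (hi < sorted.size() && sorted[hi] <= x) ++hi;  // first value > x
        return hi - lo;
    }
};
```

Because both pointers only move forward, the total pointer movement for one vector is bounded by twice its length, no matter how many values of `x` are asked. In the snippet, the count is multiplied by `y/x`, the number of tasks each such leaf performs when `w = y`.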