QUEFNX-Editorial

PROBLEM LINK:

Practice
Hospital in Berland

Author: hellb0y_suru
Editorialist: hellb0y_suru

PREREQUISITES:

  • Euler Totient Function
  • Sieve of Eratosthenes
  • DFS
  • Euler Tour of a Tree
  • Range-query data structures (Segment Tree, Fenwick Tree)

DIFFICULTY:

EASY-MEDIUM

PROBLEM:

You are given a function F(x), defined as:

$$F(x) = \sum_{i=1}^{x} i \cdot \gcd(i, x)$$

You are also given a tree rooted at node 1, where every node holds a value, and you need to answer queries of two types:

Type-1 (1 i): Find the sum of F(x) over all nodes in the subtree rooted at node i, where x is the value of each node. (The setter's solutions print this sum modulo 10^9 + 7.)

Type-2 (2 i val): Update the value of node i to val.

EXPLANATION:

First of all, we need to simplify the given function:

$$F(x) = \sum_{i=1}^{x} i \cdot \gcd(i, x)$$

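A brute-force approach works when node values are small (1 <= value <= 10^3). Compute F(x) directly from the definition once for every possible value, and store the results so that every query can reuse them. Each F(x) costs x gcd computations, so the whole table takes roughly 5·10^5 gcd calls.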
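A minimal sketch of this brute force in Python. The names F_bruteforce and build_table are illustrative and do not come from the setter's code:

```python
from math import gcd


def F_bruteforce(x):
    """F(x) = sum_{i=1}^{x} i * gcd(i, x), straight from the definition."""
    return sum(i * gcd(i, x) for i in range(1, x + 1))


def build_table(max_value):
    """Precompute table[x] = F(x) for 1 <= x <= max_value (table[0] is unused)."""
    return [0] + [F_bruteforce(x) for x in range(1, max_value + 1)]


table = build_table(1000)  # every value allowed by the small limit
print(table[1:7])  # [1, 5, 12, 24, 35, 63]
```

Every query then reads F(x) from the table in O(1).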

How to Simplify F(x)?

Under the original constraints (node values up to 10^6), this brute force is too slow, so we need to do something better.

For the full simplification of F(x), see the linked write-up, which walks through how to compute F(x) efficiently: Proof

After simplification, F(x) reduces to:

$$F(x) = \frac{x}{2}\left(\sum_{d \mid x} \frac{x}{d}\,\Phi(d) + x\right)$$
where Φ(d) is Euler's phi (totient) function.
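Here is one way to derive this closed form (a sketch; the linked proof may argue differently). Group the indices i by g = gcd(i, x) and write i = g·k, where 1 <= k <= x/g and gcd(k, x/g) = 1. For m >= 2, k is coprime to m exactly when m - k is. So the sum S of the integers in [1, m] coprime to m satisfies 2S = Σ (k + (m - k)) = m·Φ(m), i.e. S = m·Φ(m)/2. For m = 1 the sum is 1. Writing [m = 1] for the indicator of m = 1:

$$
\begin{aligned}
F(x) &= \sum_{g \mid x} \sum_{\substack{1 \le i \le x \\ \gcd(i,x) = g}} i \cdot g
      = \sum_{g \mid x} g^{2} \sum_{\substack{1 \le k \le x/g \\ \gcd(k,\,x/g) = 1}} k \\
     &= \sum_{g \mid x} g^{2} \cdot \frac{(x/g)\,\Phi(x/g) + [x/g = 1]}{2} \\
     &= \frac{x}{2} \sum_{g \mid x} g\,\Phi(x/g) + \frac{x^{2}}{2}
      = \frac{x}{2}\left(\sum_{d \mid x} \frac{x}{d}\,\Phi(d) + x\right)
\end{aligned}
$$

The indicator term only appears for g = x, where it contributes x^2/2. The last step substitutes d = x/g.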

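This closed form can be precomputed for every 1 <= x <= 10^6 in O(N log N) time. First compute Φ with a sieve in the style of the Sieve of Eratosthenes. Then, for every d, add (x/d)·Φ(d) to every multiple x of d, which costs N/1 + N/2 + ... + N/N = O(N log N) operations in total.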
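A minimal Python sketch of this precomputation. It mirrors the pre() routine of the setter's solutions but returns F(x) directly; precompute_F is an illustrative name:

```python
def precompute_F(n):
    """Return F with F[x] = sum_{i=1}^{x} i * gcd(i, x) for 1 <= x <= n (F[0] = 0).

    Uses the closed form F(x) = x/2 * (sum_{d | x} (x/d) * phi(d) + x).
    """
    # Euler's phi with an Eratosthenes-style sieve: for each prime p,
    # apply phi[m] -= phi[m] // p (i.e. multiply by 1 - 1/p) to its multiples.
    phi = list(range(n + 1))
    for p in range(2, n + 1):
        if phi[p] == p:  # never reduced by a smaller prime, so p is prime
            for m in range(p, n + 1, p):
                phi[m] -= phi[m] // p
    # res[x] = sum over divisors d of x of (x/d) * phi(d): visit every d and
    # its multiples x = d*j, which costs n/1 + n/2 + ... + n/n = O(n log n).
    res = [0] * (n + 1)
    for d in range(1, n + 1):
        for j in range(1, n // d + 1):
            res[d * j] += j * phi[d]
    # x * (res[x] + x) is always even because F(x) is an integer.
    return [0] + [x * (res[x] + x) // 2 for x in range(1, n + 1)]


print(precompute_F(10)[1:7])  # [1, 5, 12, 24, 35, 63]
```

For n = 10^6 the inner loop of the divisor sum runs roughly 14 million times. That is fine under PyPy but noticeably slow under CPython.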

Answering the queries:

Now we need the sum of F(x) over all nodes in the subtree rooted at node i. Since F(x) is already precomputed for every value, a brute-force approach is:

Type-1: Run a DFS from node i (never stepping back to its parent) and add up F(x) over every node reached, which takes O(V+E) time.
Type-2: Just overwrite the value of the node, which takes O(1) time.
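For reference, here is a minimal Python sketch of this brute force. The names are illustrative: brute_F_table stands in for the precomputed F table, and the tree is given as 1-indexed adjacency lists. The traversal uses an explicit stack so that a long path does not hit Python's recursion limit:

```python
from math import gcd


def brute_F_table(max_value):
    # table[x] = F(x); stands in for the precomputed table from the sieve.
    return [0] + [sum(i * gcd(i, x) for i in range(1, x + 1))
                  for x in range(1, max_value + 1)]


def build_parents(adj, root=1):
    """parent[u] for the tree rooted at root (parent[root] = 0), computed once."""
    parent = [0] * len(adj)
    seen = [False] * len(adj)
    seen[root] = True
    stack = [root]
    while stack:
        u = stack.pop()
        for v in adj[u]:
            if not seen[v]:
                seen[v] = True
                parent[v] = u
                stack.append(v)
    return parent


def subtree_sum_naive(adj, parent, values, F_table, u):
    """Type-1 by brute force: visit every node of u's subtree, O(V + E)."""
    total, stack = 0, [u]
    while stack:
        w = stack.pop()
        total += F_table[values[w]]
        # Children of w are its neighbours other than its parent.
        stack.extend(c for c in adj[w] if c != parent[w])
    return total


adj = [[], [2, 3], [1], [1, 4], [3]]  # edges 1-2, 1-3, 3-4
parent = build_parents(adj)
values = [0, 1, 2, 3, 4]              # values[u] for nodes 1..4
table = brute_F_table(10)
print(subtree_sum_naive(adj, parent, values, table, 3))  # 36 = F(3) + F(4)
values[4] = 1                         # Type-2 update in O(1)
print(subtree_sum_naive(adj, parent, values, table, 3))  # 13 = F(3) + F(1)
```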

But this needs a full DFS for every Type-1 query, i.e. O(V+E) per query, which is not fast enough when there are many queries!! :confused:

To compute these subtree sums efficiently, we can do an EULER TOUR of the tree. :smiley:
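A minimal Python sketch of such an Euler tour, using the same numbering as the setter's dfs: the counter is bumped on entry and again on exit. Unlike the setter's version, it uses an explicit stack instead of recursion. Children are therefore visited in a different order, which changes the positions but not the property that a subtree occupies one contiguous range. The names euler_tour, tin, tout and euler are illustrative:

```python
def euler_tour(adj, root=1):
    """Euler tour of a tree given as 1-indexed adjacency lists (adj[0] unused).

    The counter is bumped both when a node is entered and when it is left,
    so each node appears twice and positions run from 1 to 2n. Returns
    (tin, tout, euler) where euler[pos] is the node visited at position pos.
    """
    n = len(adj) - 1
    tin, tout = [0] * (n + 1), [0] * (n + 1)
    euler = [0] * (2 * n + 1)
    timer = 0
    # Explicit stack of (node, parent, leaving) frames instead of recursion,
    # so deep (path-like) trees do not hit Python's recursion limit.
    stack = [(root, 0, False)]
    while stack:
        u, parent, leaving = stack.pop()
        timer += 1
        euler[timer] = u
        if leaving:
            tout[u] = timer
            continue
        tin[u] = timer
        stack.append((u, parent, True))  # popped only after all children
        for v in adj[u]:
            if v != parent:
                stack.append((v, u, False))
    return tin, tout, euler


adj = [[], [2, 3], [1], [1, 4], [3]]  # edges 1-2, 1-3, 3-4
tin, tout, euler = euler_tour(adj)
print(euler[tin[3]:tout[3] + 1])  # [3, 4, 4, 3]: node 3's subtree is {3, 4}
```

Every node of the subtree of u, and no other node, appears in the slice euler[tin[u] .. tout[u]], each exactly twice.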

The key property of an Euler tour is that the nodes of any subtree occupy a contiguous subarray of the tour. So we store F(x) of every node in an Euler array at the node's position(s) in the tour, and build a segment tree or Fenwick tree over that array. The subtree sum for node i then becomes a range sum over the interval between the positions where the DFS enters and leaves i. In the setter's solutions the counter is incremented both on entry and on exit, and F(x) is stored at both positions. Every node of the subtree is therefore counted twice, and the range sum is halved. Now, for every query:

Type-1: Compute the range sum over the interval of the Euler array that corresponds to the subtree of node i. Time complexity: O(log N).

Type-2: Changing a node's value only changes its entry (or entries) in the Euler array, so we apply a point update that adds F(val) - F(old value). On a segment tree or Fenwick tree this also takes O(log N) time. :slight_smile:
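Putting the pieces together, here is a small end-to-end Python sketch that follows the setter's scheme: F is stored at both Euler positions and the range sum is halved. It is an illustrative reimplementation, not the setter's code. F is computed by brute force to keep the example short, whereas for values up to 10^6 you would use the sieve-based table instead. The function solve and its input format are hypothetical: values[u] for 1-indexed nodes, edges as pairs, queries as tuples.

```python
from math import gcd

MOD = 10**9 + 7


def F(x):
    # Brute-force F(x) to keep the demo short; use the sieve table for 10^6.
    return sum(i * gcd(i, x) for i in range(1, x + 1))


class Fenwick:
    """Fenwick tree (BIT) on positions 1..size: point add and range sum."""

    def __init__(self, size):
        self.tree = [0] * (size + 1)

    def add(self, pos, delta):
        while pos < len(self.tree):
            self.tree[pos] += delta
            pos += pos & -pos

    def prefix(self, pos):
        total = 0
        while pos > 0:
            total += self.tree[pos]
            pos -= pos & -pos
        return total

    def range_sum(self, lo, hi):
        return self.prefix(hi) - self.prefix(lo - 1)


def solve(n, values, edges, queries):
    """values[u] is node u's value (1-indexed); returns the Type-1 answers."""
    adj = [[] for _ in range(n + 1)]
    for u, v in edges:
        adj[u].append(v)
        adj[v].append(u)
    # Euler tour with the counter bumped on entry and on exit, so the
    # subtree of u is exactly the position range [tin[u], tout[u]].
    tin, tout, timer = [0] * (n + 1), [0] * (n + 1), 0
    stack = [(1, 0, False)]
    while stack:
        u, p, leaving = stack.pop()
        timer += 1
        if leaving:
            tout[u] = timer
            continue
        tin[u] = timer
        stack.append((u, p, True))
        stack.extend((v, u, False) for v in adj[u] if v != p)
    bit = Fenwick(2 * n)
    for u in range(1, n + 1):
        # F(value) sits at both positions of u, so range sums count u twice.
        bit.add(tin[u], F(values[u]))
        bit.add(tout[u], F(values[u]))
    answers = []
    for q in queries:
        if q[0] == 1:
            u = q[1]
            answers.append(bit.range_sum(tin[u], tout[u]) // 2 % MOD)
        else:
            u, val = q[1], q[2]
            delta = F(val) - F(values[u])  # point update by the change in F
            bit.add(tin[u], delta)
            bit.add(tout[u], delta)
            values[u] = val
    return answers


queries = [(1, 1), (1, 3), (2, 4, 1), (1, 3), (1, 1)]
print(solve(4, [0, 1, 2, 3, 4], [(1, 2), (1, 3), (3, 4)], queries))
# [42, 36, 13, 19]
```

Take the four-node tree with edges 1-2, 1-3, 3-4 and values 1, 2, 3, 4, so the F values are 1, 5, 12, 24. The subtree of node 3 sums to 12 + 24 = 36. After node 4 is set to 1, it becomes 12 + 1 = 13.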

SOLUTIONS:

Setter's Solution in C++
// Suryansh Kumar
#include <bits/stdc++.h>
using namespace std;

const int mod = 1e9 + 7;
int a[1000010];
long long int res[1000010] , phi[1000010];
vector<pair<int,int>> ind(1000010);
vector<int> g[1000010];
int idx;
int inv2;

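// Fast modular exponentiation: x^n mod p (used for the inverse of 2 via Fermat's little theorem).
int power(int x, int n){
   if(x==0) return 0;
   else if(x==1 || n==0) return 1;
   int r = power(x,n/2);
   r = ((long long int)r*r)%mod;
   // Widen before multiplying: (long long)(r * x) would overflow in int first.
   if(n&1) return ((long long int)r * x)%mod;
   else return r;
}
// Fenwick tree (BIT) over Euler-tour positions; all sums are kept modulo mod.
struct Fenwick{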
   vector<long long int> v;
   int sz;
   Fenwick(int n){
       sz = n;
       v.assign(sz+1,0);
   }
   void update(int id, long long int val){
       while(id<=sz) v[id] = ((v[id] + val)%mod + mod)%mod , id+=id & -id;
   }
   long long int sum(int id){
       long long int sum = 0;
       while(id>0) sum = (sum + v[id])%mod , id-= id & -id;
       return sum%mod;
   }
   long long int query(int l, int r){
       return (sum(r) - sum(l-1) + mod)%mod;
   }
};

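// Euler tour: idx is incremented on entry and on exit, so the subtree of node
// occupies exactly the positions ind[node].first .. ind[node].second.
void dfs(int node, int par){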
   idx++;
   ind[node].first = idx;
   for(auto &i: g[node]){
       if(i!=par) dfs(i,node);
   }
   idx++;
   ind[node].second = idx;
}

// phi[x]: Euler's totient via a sieve; res[x] = sum over d | x of (x/d) * phi[d].
void pre(){
   for(int i=1;i<=1000000;i++) phi[i]=i , res[i] = 0;  
   for(int i=2;i<=1000000;i++){
       if(phi[i]==i){
           for(int j=1; i*j<=1000000; j++) phi[i*j]-=phi[i*j]/i;
       }
   }
   for(long long int i=1;i<=1000000;i++){
       for(long long int j=1;j*i<=1000000;j++){
           res[i*j] += j*phi[i];
       }
   }
}

// F(val) = val * (res[val] + val) / 2, reduced modulo mod (division by 2 via inv2).
long long int Fun(int val){
   long long int ans = ((res[val] + val)*val)%mod;
   return (inv2*ans)%mod;
}


void _sol(){
   int n; cin >> n;
   for(int i=1;i<=n;i++) cin >> a[i] , g[i].clear();
   for(int i=0;i<n-1;i++){
       int u,v; cin >> u >> v;
       g[u].push_back(v);
       g[v].push_back(u);
   }
   idx=0;
   dfs(1,0);
   inv2 = power(2,mod-2);
   Fenwick f(2*n+2);
   // F(a[i]) is stored at both the entry and the exit position of node i.
   for(int i=1;i<=n;i++) f.update(ind[i].first,Fun(a[i])) , f.update(ind[i].second,Fun(a[i]));
   int q; cin >> q;
   while(q--){
       int type; cin >> type;
       if(type==1){
           int node; cin >> node;
           // Each subtree node is counted twice (entry + exit), so halve the sum.
           cout << (f.query(ind[node].first,ind[node].second)*inv2)%mod << "\n";
       }
       else{
           long long int node,val; cin >> node  >> val;
           // Point update by the change in F at both positions of the node.
           f.update(ind[node].first , Fun(val) - Fun(a[node]));
           f.update(ind[node].second , Fun(val) - Fun(a[node]));
           a[node] = val;
       }
   }
}

int main(){
   ios_base::sync_with_stdio(0);
   cin.tie(0);
   cout.tie(0);
   pre();
   int t=1; // cin >> t;
   while(t--) _sol();
}

Setter's Solution in PyPy (Python 2)
# Suryansh Kumar
import sys
import atexit
import io

sys.setrecursionlimit(100000)
 
buff = io.BytesIO()
sys.stdout = buff
 
@atexit.register
def write():
    sys.__stdout__.write(buff.getvalue())
 
 
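class Fenwick:
    # Fenwick tree (BIT) over Euler-tour positions, using exact Python integers.
    def __init__(self,n):
        self.sz = n+10
        # update() writes to indices up to sz, so the list needs sz + 1 slots.
        self.bit=[0]*(n+11)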
    def update(self,idx,val):
        while(idx<=self.sz):
            self.bit[idx]+=val
            idx += idx & -idx
    def summ(self,idx):
        ans=0
        while(idx>0):
            ans+=self.bit[idx]
            idx -= idx & -idx
        return ans
    def query(self,l,r):
        return self.summ(r) - self.summ(l-1)
 
 
lim = 1000009
phi=[0]*lim
res=[0]*lim
ind = []
a = []
g = []
 
# phi[x]: Euler's totient via a sieve; res[x] = sum over d | x of (x/d) * phi[d]
def pre():
    for i in range(lim):
        phi[i]=i
        res[i]=0
    for i in range(2,lim):
        if phi[i] == i:
            j=1
            while(j*i<lim):
                phi[i*j]-=phi[i*j]//i
                j+=1
    for i in range(1,lim):
        j=1
        while(i*j<lim):
            res[i*j]+= phi[i]*j
            j+=1
    for i in range(lim):
        ind.append([])
 
# Euler tour: ind[node] = [entry, exit]; the subtree occupies positions entry..exit
def dfs(node,par):
    global idxx
    idxx = idxx + 1
    ind[node].append(idxx)
    for i in range(len(g[node])):
        if g[node][i] != par:
            dfs(g[node][i],node)
    idxx = idxx + 1
    ind[node].append(idxx)
 
# Exact F(val) = val * (res[val] + val) / 2 (the product is always even)
def Fun(val):
    ans = (res[val] + val)*val
    return ans//2
    
 
 
n = int(raw_input())
 
for i in range(0,n+10):
    g.append([])
 
a = list(map(int,raw_input().split()))
 
 
for i in range(n-1):
    u,v = map(int,raw_input().split())
    g[u].append(v)
    g[v].append(u)
 
 
pre()
idxx=0
dfs(1,0)
 
f = Fenwick(2*n+20)
 
for i in range(n):
    f.update((ind[i+1][0]),Fun(a[i]))
    f.update((ind[i+1][1]),Fun(a[i]))
q = int(raw_input())
 
for i in range(q):
    qr = list(map(int,raw_input().split()))
    if qr[0]==1:
        # each node is counted twice (entry + exit), hence the // 2
        print((f.query(ind[qr[1]][0],ind[qr[1]][1])//2)%1000000007)
    else:
        f.update((ind[qr[1]][0]),Fun(qr[2]) - Fun(a[qr[1]-1]))
        f.update((ind[qr[1]][1]),Fun(qr[2]) - Fun(a[qr[1]-1]))
        a[qr[1]-1]=qr[2]



Hey @hellb0y_suru, can you please check the link to the proof? It isn't working…