PROBLEM LINK:
Author: hellb0y_suru
Editorialist: hellb0y_suru
PREREQUISITES:
- Euler Totient Function
- Sieve of Eratosthenes
- DFS
- Euler Tour of Tree
- Range Data Structure (Segment, Fenwick)
DIFFICULTY:
EASY-MEDIUM
PROBLEM:
Let the function F(x) be defined as:
F(x) = \displaystyle\sum_{i=1}^{x} i * gcd(i,x)
You are given a tree rooted at node 1, where every node holds a value. You need to answer
queries of two types:
Type-1 → 1 i - Find the sum of F(x) over all the nodes in the subtree rooted at node i,
where x is the value of the node.
Type-2 → 2 i val - Update the value of the node i to val.
EXPLANATION:
First of all we need to simplify the given function F(x) , where F(x) -
F(x) = \displaystyle\sum_{i=1}^{x} i * gcd(i,x)
A brute-force approach is to evaluate F(x) directly, in O(x) time, for every possible node
value (1 <= value <= 10^3) and store the results so that they
can be reused later.
How to simplify F(x)?
For efficient computation of F(x) over values of nodes according to the original constraints
we need to do something better.
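The linked Proof gives the full derivation of the simplified form of F(x). In short, it goes as follows. Since gcd(i, x) = gcd(x − i, x), pairing the term i with the term x − i gives 2\sum_{i=1}^{x-1} i \cdot \gcd(i,x) = x \sum_{i=1}^{x-1} \gcd(i,x). Next, add back the term i = x, which contributes x^2. Finally, use \sum_{i=1}^{x} \gcd(i,x) = \sum_{d \mid x} \frac{x}{d}\,\Phi(d), which holds because exactly \Phi(d) values of i in [1, x] have gcd(i, x) = x/d. Together these give the closed form below.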
Now after simplification F(x) reduces to -
F(x) = \frac{x}{2} \times \left( \displaystyle\sum_{d \mid x} \frac{x}{d} \times \Phi(d) + x \right)
where \displaystyle\Phi(d) is Euler’s phi function.
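Both Φ and the divisor sum \sum_{d \mid x} (x/d)\,\Phi(d) can be precomputed for every 1 <= x <= 10^6. Φ comes from a sieve of Eratosthenes. The divisor sum comes from adding (x/d)·Φ(d) to every multiple x of each d. The whole precomputation takes O(N log N) time.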
For each query -
Now , we need to calculate sum of F(x) over all the nodes in the subtree rooted at Node i.
To answer a query, we already have the precomputed values of F(x). The brute-force approach is then:
Type-1: perform a DFS for every query and sum F(x) over all nodes in the subtree of the queried node, i.e. O(V+E) time per query.
Type-2: updating the value of a node takes O(1) time.
However, this performs a full DFS for every query, i.e. O(V+E) per query, which is not fast enough.
To compute these sums efficiently, we take an Euler tour of the tree.
The problem then reduces to range-sum queries. In an Euler tour, the subtree of any node occupies a contiguous subarray of the tour array. Store F(x) for every node at its position(s) in that array and build a segment tree or Fenwick tree over it. The setter's code records each node at both its entry and its exit position, so a subtree range counts every node twice and the range sum is multiplied by the modular inverse of 2. For every query:
Type-1: the answer is the range sum over the interval that corresponds to the queried node. Time complexity: O(log N).
Type-2: this is a point update on the Euler array that the tree is built over, which also takes O(log N) time.
SOLUTIONS:
Setter's Solution in C++
// Suryansh Kumar
#include <bits/stdc++.h>
using namespace std;
// Modulus for every reported answer (1e9 + 7).
const int mod = 1000000007;
// a[v]: current value of node v (nodes are 1-indexed).
int a[1000010];
// res[x] = sum over d | x of (x / d) * phi(d); filled by pre().
long long int res[1000010];
// phi[x]: Euler's totient of x; filled by pre().
long long int phi[1000010];
// ind[v] = {entry time, exit time} of node v in the Euler tour built by dfs().
vector<pair<int,int>> ind(1000010);
// g[v]: adjacency list of node v.
vector<int> g[1000010];
// Euler-tour timer shared by the dfs() calls.
int idx;
// Modular inverse of 2, set in _sol().
int inv2;
// Returns x^n modulo `mod` by recursive binary exponentiation (O(log n) calls).
// Intended for 0 <= x < mod and n >= 0. Special cases kept from the original
// implementation: power(0, n) returns 0 for every n (even n == 0), while
// power(1, n) and power(x, 0) return 1.
int power(int x, int n){
    if(x==0) return 0;
    if(x==1 || n==0) return 1;
    const long long int half = power(x, n/2);
    long long int r = half * half % mod;
    // Multiply in 64-bit: r and x can both be close to mod (~1e9), and the
    // previous `(long long int)(r * x)` overflowed in int before the cast.
    if(n&1) r = r * x % mod;
    return static_cast<int>(r);
}
// Fenwick (binary indexed) tree over positions 1..sz. Every stored cell is kept
// normalised to [0, mod), so all sums it returns are residues modulo `mod`.
struct Fenwick{
    vector<long long int> v;
    int sz;

    // Creates an all-zero tree with positions 1..n (v[0] is unused).
    Fenwick(int n) : v(n + 1, 0LL), sz(n) {}

    // Adds val (which may be negative) at position id, reducing modulo `mod`.
    void update(int id, long long int val){
        for(; id <= sz; id += id & -id){
            v[id] = ((v[id] + val) % mod + mod) % mod;
        }
    }

    // Sum of positions 1..id modulo `mod` (0 when id <= 0).
    long long int sum(int id){
        long long int acc = 0;
        for(; id > 0; id -= id & -id){
            acc = (acc + v[id]) % mod;
        }
        return acc % mod;
    }

    // Sum of positions l..r inclusive, modulo `mod`.
    long long int query(int l, int r){
        const long long int upto_r = sum(r);
        const long long int before_l = sum(l - 1);
        return (upto_r - before_l + mod) % mod;
    }
};
// Assigns Euler-tour times to the subtree rooted at `node`, whose parent is `par`.
// Times come from the global counter idx, and every node consumes two ticks:
//   ind[x].first  = counter value when x is entered,
//   ind[x].second = counter value when x is left.
// Every node in x's subtree therefore has both of its times inside
// [ind[x].first, ind[x].second].
// Uses an explicit stack instead of recursion. Neighbours are visited in adjacency-list
// order, exactly as the recursive walk does, and a path-shaped tree cannot exhaust
// the call stack.
void dfs(int node, int par){
    struct Frame { int v; int p; size_t next; };
    vector<Frame> work;
    ind[node].first = ++idx;
    work.push_back({node, par, 0});
    while(!work.empty()){
        const int cur = work.back().v;
        const int from = work.back().p;
        size_t &pos = work.back().next;
        if(pos < g[cur].size()){
            const int nxt = g[cur][pos++];
            // `pos` is not touched after this push_back, which may reallocate `work`.
            if(nxt != from){
                ind[nxt].first = ++idx;
                work.push_back({nxt, cur, 0});
            }
        }else{
            ind[cur].second = ++idx;
            work.pop_back();
        }
    }
}
// Precomputes, for every 1 <= x <= 1000000:
//   phi[x] - Euler's totient, via a sieve of Eratosthenes;
//   res[x] - sum over d | x of (x / d) * phi[d].
// Runs in O(N log N) for N = 1000000.
void pre(){
    const int LIM = 1000000;
    for(int x = 1; x <= LIM; ++x){
        phi[x] = x;
        res[x] = 0;
    }
    // When p is reached, phi[p] == p holds exactly when p is prime:
    // every composite has already been reduced by one of its smaller prime factors.
    for(int p = 2; p <= LIM; ++p){
        if(phi[p] != p) continue;
        for(int m = p; m <= LIM; m += p){
            phi[m] -= phi[m] / p;
        }
    }
    // For each d, add (m / d) * phi[d] to every multiple m of d.
    for(long long int d = 1; d <= LIM; ++d){
        long long int k = 1;  // k == m / d
        for(long long int m = d; m <= LIM; m += d, ++k){
            res[m] += k * phi[d];
        }
    }
}
// Returns F(val) = sum_{i=1..val} i * gcd(i, val) modulo `mod`, using the closed
// form F(x) = x * (res[x] + x) / 2. The halving is done by multiplying with inv2,
// so inv2 must already hold the inverse of 2 (set in _sol) and res must be filled
// by pre(). val must index res (pre() fills 1..1000000).
long long int Fun(int val){
    const long long int doubled = (res[val] + val) * val % mod;
    return doubled * inv2 % mod;
}
void _sol(){
int n; cin >> n;
for(int i=1;i<=n;i++) cin >> a[i] , g[i].clear();
for(int i=0;i<n-1;i++){
int u,v; cin >> u >> v;
g[u].push_back(v);
g[v].push_back(u);
}
idx=0;
dfs(1,0);
inv2 = power(2,mod-2);
Fenwick f(2*n+2);
for(int i=1;i<=n;i++) f.update(ind[i].first,Fun(a[i])) , f.update(ind[i].second,Fun(a[i]));
int q; cin >> q;
while(q--){
int type; cin >> type;
if(type==1){
int node; cin >> node;
cout << (f.query(ind[node].first,ind[node].second)*inv2)%mod << "\n";
}
else{
long long int node,val; cin >> node >> val;
f.update(ind[node].first , Fun(val) - Fun(a[node]));
f.update(ind[node].second , Fun(val) - Fun(a[node]));
a[node] = val;
}
}
}
// Entry point: enables fast stream I/O, builds the phi/res tables once, then solves
// the test cases (a single one; reading a test count is disabled).
int main(){
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);
    cout.tie(nullptr);
    pre();
    int t = 1;  // multi-test input disabled: cin >> t;
    while(t--){
        _sol();
    }
    return 0;
}
Setter's Solution in PyPy (Python 2)
# Suryansh Kumar
import sys
import atexit
import io
# dfs() below is recursive, so allow deeper recursion than the default.
sys.setrecursionlimit(100000)

# Everything printed is collected in an in-memory buffer instead of being written
# line by line; the buffer is copied to the real stdout once, at interpreter exit.
buff = io.BytesIO()
sys.stdout = buff


def write():
    """Flush the buffered output to the original standard output."""
    sys.__stdout__.write(buff.getvalue())


atexit.register(write)
class Fenwick:
    """Fenwick (binary indexed) tree over positions 1..sz with exact integer sums.

    The requested size ``n`` is padded by 10, so valid positions are 1..n+10.
    """

    def __init__(self, n):
        self.sz = n + 10
        # update() may write bit[self.sz], so the list needs sz + 1 slots
        # (slot 0 is unused). With only sz slots, any update chain that landed
        # exactly on sz raised IndexError (e.g. Fenwick(22): sz = 32, chain 1->...->32).
        self.bit = [0] * (self.sz + 1)

    def update(self, idx, val):
        """Add ``val`` (may be negative) at 1-based position ``idx``."""
        while idx <= self.sz:
            self.bit[idx] += val
            idx += idx & -idx

    def summ(self, idx):
        """Return the sum of positions 1..idx (0 when idx <= 0)."""
        ans = 0
        while idx > 0:
            ans += self.bit[idx]
            idx -= idx & -idx
        return ans

    def query(self, l, r):
        """Return the sum of positions l..r inclusive."""
        return self.summ(r) - self.summ(l - 1)
# Size of the precomputed tables (indices 0..1000008).
lim = 1000009
# phi[x]: Euler's totient of x; res[x]: sum over d | x of (x // d) * phi[d].
# Both are filled in place by pre().
phi = [0] * lim
res = [0] * lim
# ind[v]: [entry time, exit time] of node v, appended by dfs();
# a: node values (a[i] belongs to node i + 1); g: adjacency lists.
ind, a, g = [], [], []
def pre():
    """Fill phi (Euler's totient) and res for every index below lim, and give ind
    one empty list per index.

    res[x] = sum over d | x of (x // d) * phi[d]. The module-level lists are
    mutated in place, because this function does not rebind those names.
    """
    phi[:] = list(range(lim))
    res[:] = [0] * lim
    # When p is reached, phi[p] == p exactly when p is prime: each composite has
    # already been reduced by one of its smaller prime factors.
    for p in range(2, lim):
        if phi[p] != p:
            continue
        for m in range(p, lim, p):
            phi[m] -= phi[m] // p
    # For each d, add (m // d) * phi[d] to every multiple m of d.
    for d in range(1, lim):
        for k, m in enumerate(range(d, lim, d), 1):
            res[m] += phi[d] * k
    ind.extend([] for _ in range(lim))
def dfs(node, par):
    """Record Euler-tour times for the subtree of ``node`` (whose parent is ``par``).

    Appends the entry time and then the exit time to ind[node], drawing both from
    the global counter idxx, so every node in the subtree has times inside
    [ind[node][0], ind[node][1]].
    """
    global idxx
    idxx += 1
    ind[node].append(idxx)
    for child in g[node]:
        if child != par:
            dfs(child, node)
    idxx += 1
    ind[node].append(idxx)
def Fun(val):
    """Return F(val) = sum_{i=1..val} i * gcd(i, val) exactly (not reduced mod p).

    Uses F(x) = x * (res[x] + x) / 2; the product is always 2 * F(x), so the
    floor division is exact. Requires pre() to have filled res.
    """
    doubled = val * (res[val] + val)
    return doubled // 2
# Input parsing. raw_input() is used throughout: under Python 2, input() eval()s the
# line, which is unsafe and inconsistent with the rest of this script. The bare
# split() also tolerates repeated or trailing whitespace, whereas split(" ") produced
# empty tokens that made int('') raise.
n = int(raw_input())
g.extend([] for _ in range(n + 10))
a = list(map(int, raw_input().split()))
for _ in range(n - 1):
    u, v = map(int, raw_input().split())
    g[u].append(v)
    g[v].append(u)

pre()
idxx = 0
dfs(1, 0)

# Every node's F(value) is stored at both its entry and exit positions, so a
# subtree range contains each node twice; answers are halved below.
f = Fenwick(2 * n + 20)
for i in range(n):
    fi = Fun(a[i])
    f.update(ind[i + 1][0], fi)
    f.update(ind[i + 1][1], fi)

q = int(raw_input())
for _ in range(q):
    qr = list(map(int, raw_input().split()))
    node = qr[1]
    if qr[0] == 1:
        print((f.query(ind[node][0], ind[node][1]) // 2) % 1000000007)
    else:
        # Apply the same delta at both positions so the node stays counted twice.
        delta = Fun(qr[2]) - Fun(a[node - 1])
        f.update(ind[node][0], delta)
        f.update(ind[node][1], delta)
        a[node - 1] = qr[2]