PROBLEM LINK:
Author: hellb0y_suru
Editorialist: hellb0y_suru
PREREQUISITES:
- Euler Totient Function
- Sieve of Eratosthenes
- DFS
- Euler Tour of Tree
- Range Data Structure (Segment, Fenwick)
DIFFICULTY:
EASY-MEDIUM
PROBLEM:
Given a function F(x), defined as
F(x) = \displaystyle\sum_{i=1}^{x} i \cdot \gcd(i,x)
You are given a tree rooted at node 1 and must answer queries of two types:
Type-1 → 1 i - Find the sum of F(x) over all nodes in the subtree rooted at node i,
where x is the value of each node.
Type-2 → 2 i val - Update the value of node i to val.
EXPLANATION:
First, we need to simplify the given function F(x), where
F(x) = \displaystyle\sum_{i=1}^{x} i \cdot \gcd(i,x)
A brute-force approach is to iterate over all possible node values (1 <= value <= 10^3), compute F(x) directly for each of them, and store the results so that they
can be reused later.
How to simplify F(x)?
To compute F(x) efficiently for all node values allowed by the original constraints,
we need something better.
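For the full simplification of F(x), see the linked Proof. In short:
- Because \gcd(i,x) = \gcd(x-i,x), pairing each i with x-i gives \sum_{i=1}^{x-1} i\gcd(i,x) = \frac{x}{2}\sum_{i=1}^{x-1}\gcd(i,x).
- Write G(x) = \sum_{i=1}^{x}\gcd(i,x). Exactly \varphi(x/g) values of i in [1, x] have \gcd(i,x) = g, so G(x) = \sum_{d \mid x} \frac{x}{d}\,\varphi(d).
- Adding back the i = x term (x^2) gives F(x) = \frac{x}{2}(G(x) - x) + x^2 = \frac{x}{2}(G(x) + x).
After this simplification, F(x) reduces to
F(x) = \frac{x}{2}\left(\displaystyle\sum_{d \mid x} \frac{x}{d}\,\varphi(d) + x\right)
where \varphi(d) is Euler's totient (phi) function.
Both \varphi and the divisor sum can be precomputed for every 1 <= x <= 10^6 in O(N log N) time using a sieve in the style of the Sieve of Eratosthenes.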
For each query:
We now need to calculate the sum of F(x) over all nodes in the subtree rooted at node i.
The values of F(x) are already precomputed, so a brute-force way to answer queries is:
Type-1: Perform a DFS for every query and sum F(x) over all nodes in the subtree of the queried node, i.e. O(V+E) time.
Type-2: Updating a given node takes O(1) time.
However, this requires a DFS, i.e. O(V+E), for each query, which is not fast enough!
To compute this value efficiently, we can perform an EULER TOUR of the tree.
In the Euler-tour array, a subtree always occupies a contiguous subarray. We store F(x) of each node at its tour positions, and the problem reduces to range-sum queries with point updates. A segment tree or Fenwick tree handles this easily. (The setter's code records each node at both its entry and its exit position. Every subtree node is therefore counted twice, and the range sum is halved.) For every query:
Type-1: Compute the range sum over the interval that corresponds to the queried node. Time Complexity: O(log N)
Type-2: We update a single element of the Euler array over which the tree is built, so the update also takes O(log N) time.
SOLUTIONS:
Setter's Solution in C++
// Suryansh Kumar
#include <bits/stdc++.h>
using namespace std;
// Modulus applied to every reported answer.
const int mod = 1000000007;
// a[v]: current value stored at vertex v (vertices are 1-based).
int a[1000010];
// res[x] = sum over divisors d of x of (x/d) * phi[d]; filled by pre().
long long int res[1000010];
// phi[x]: Euler's totient of x; filled by pre().
long long int phi[1000010];
// ind[v] = {entry, exit} Euler-tour positions written by dfs().
std::vector<std::pair<int, int>> ind(1000010);
// g[v]: adjacency list of vertex v.
std::vector<int> g[1000010];
// Running Euler-tour position counter advanced by dfs().
int idx;
// Modular inverse of 2, assigned in _sol() before any F value is computed.
int inv2;
// Modular exponentiation by recursive squaring: returns x^n mod `mod` for n >= 0.
// In this file it is used to obtain the inverse of 2 (n = mod - 2, Fermat).
// Every product is formed in 64-bit arithmetic: multiplying two int residues
// close to `mod` before widening would overflow int.
int power(int x, int n){
    if(n==0) return 1;            // x^0 == 1 (0^0 is taken as 1)
    if(x==0) return 0;
    if(x==1) return 1;
    const long long int half = power(x, n/2);
    long long int r = (half * half) % mod;
    if(n&1) r = (r * x) % mod;    // r is 64-bit, so x is widened before the multiply
    return static_cast<int>(r);
}
// Fenwick (binary indexed) tree over positions 1..sz. Every stored partial sum
// is kept reduced into [0, mod), so negative deltas are accepted by update().
struct Fenwick{
    vector<long long int> v;   // v[1..sz] hold partial sums; v[0] is unused
    int sz;
    Fenwick(int n) : v(n + 1, 0), sz(n) {}

    // Add val (which may be negative) at position id.
    void update(int id, long long int val){
        for(; id <= sz; id += id & -id){
            const long long int cell = (v[id] + val) % mod;
            v[id] = cell < 0 ? cell + mod : cell;
        }
    }

    // Sum of positions 1..id, reduced modulo mod.
    long long int sum(int id){
        long long int acc = 0;
        for(; id > 0; id -= id & -id) acc = (acc + v[id]) % mod;
        return acc % mod;
    }

    // Sum of positions l..r inclusive, in [0, mod).
    long long int query(int l, int r){
        const long long int hi = sum(r);
        const long long int lo = sum(l - 1);
        return (hi - lo + mod) % mod;
    }
};
// Euler tour of the subtree rooted at `node`, entered from `par` (0 for the root).
// Uses the running counter idx:
//   - ind[v].first is assigned when v is entered;
//   - ind[v].second is assigned once v's whole subtree is finished;
// so the subtree of v covers positions ind[v].first .. ind[v].second.
// The traversal uses an explicit stack instead of recursion. A recursive
// version nests one call per tree level and can overflow the call stack on a
// path-shaped tree (the arrays here are sized for ~10^6 vertices).
// Children are visited in g[] order, matching the recursive traversal exactly.
void dfs(int node, int par){
    struct Frame {
        int v;        // vertex being expanded
        int p;        // its parent, skipped while scanning neighbours
        size_t next;  // next position in g[v] to inspect
    };
    vector<Frame> st;
    ind[node].first = ++idx;
    st.push_back({node, par, 0});
    while(!st.empty()){
        Frame &top = st.back();
        const vector<int> &adj = g[top.v];
        while(top.next < adj.size() && adj[top.next] == top.p) ++top.next;
        if(top.next < adj.size()){
            const int child = adj[top.next++];
            const int cur = top.v;
            ind[child].first = ++idx;
            st.push_back({child, cur, 0});  // may invalidate `top`; it is not used afterwards
        }
        else{
            ind[top.v].second = ++idx;
            st.pop_back();
        }
    }
}
// One-time precomputation for every value 1..10^6:
//   phi[x] = Euler's totient of x (sieve over primes),
//   res[x] = sum over divisors d of x of (x/d) * phi[d],
// so that F(x) = x * (res[x] + x) / 2 (see Fun).
void pre(){
    const int LIMIT = 1000000;
    for(int x = 1; x <= LIMIT; ++x){
        phi[x] = x;
        res[x] = 0;
    }
    // If phi[p] still equals p when p is reached, no smaller prime divides p,
    // so p is prime. Scale every multiple of p by (1 - 1/p).
    for(int p = 2; p <= LIMIT; ++p){
        if(phi[p] != p) continue;
        for(int m = p; m <= LIMIT; m += p) phi[m] -= phi[m] / p;
    }
    // Add the contribution of divisor d to each multiple m = d * q.
    for(long long int d = 1; d <= LIMIT; ++d){
        long long int q = 1;
        for(long long int m = d; m <= LIMIT; m += d, ++q) res[m] += q * phi[d];
    }
}
// F(val) modulo mod, using F(x) = x * (res[x] + x) / 2.
// The division by 2 is done by multiplying with the modular inverse inv2,
// which must already have been assigned.
long long int Fun(int val){
    const long long int twice_f = (res[val] + val) * val % mod;
    return inv2 * twice_f % mod;
}
void _sol(){
int n; cin >> n;
for(int i=1;i<=n;i++) cin >> a[i] , g[i].clear();
for(int i=0;i<n-1;i++){
int u,v; cin >> u >> v;
g[u].push_back(v);
g[v].push_back(u);
}
idx=0;
dfs(1,0);
inv2 = power(2,mod-2);
Fenwick f(2*n+2);
for(int i=1;i<=n;i++) f.update(ind[i].first,Fun(a[i])) , f.update(ind[i].second,Fun(a[i]));
int q; cin >> q;
while(q--){
int type; cin >> type;
if(type==1){
int node; cin >> node;
cout << (f.query(ind[node].first,ind[node].second)*inv2)%mod << "\n";
}
else{
long long int node,val; cin >> node >> val;
f.update(ind[node].first , Fun(val) - Fun(a[node]));
f.update(ind[node].second , Fun(val) - Fun(a[node]));
a[node] = val;
}
}
}
// Entry point: sets up fast stream I/O, runs the sieve precomputation once,
// then solves the test case(s).
int main(){
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);
    cout.tie(nullptr);
    pre();
    // Single test case; read t from cin instead for multi-test input.
    int t = 1;
    for(; t > 0; --t) _sol();
}
Setter's Solution in PyPy (Python 2)
# Suryansh Kumar
import sys
import atexit
import io
sys.setrecursionlimit(100000)
# Collect everything printed during the run in memory and emit it once at
# interpreter exit. Presumably targets Python 2 (the script uses raw_input),
# where print writes byte strings that BytesIO accepts.
buff = io.BytesIO()
sys.stdout = buff


def write():
    """Copy the buffered output to the real standard output."""
    sys.__stdout__.write(buff.getvalue())


atexit.register(write)
class Fenwick:
    """1-indexed Fenwick (binary indexed) tree over plain Python integers.

    Fenwick(n) supports positions 1 .. n+10 (sz = n+10), leaving a little slack
    beyond n for callers.
    """

    def __init__(self,n):
        self.sz = n+10
        # Positions 1..sz must all be addressable. With only sz slots, an
        # update whose climb lands exactly on sz raises IndexError.
        self.bit=[0]*(self.sz+1)

    def update(self,idx,val):
        """Add val at position idx (1 <= idx <= sz)."""
        while idx <= self.sz:
            self.bit[idx] += val
            idx += idx & -idx

    def summ(self,idx):
        """Return the sum of positions 1..idx (0 when idx <= 0)."""
        ans = 0
        while idx > 0:
            ans += self.bit[idx]
            idx -= idx & -idx
        return ans

    def query(self,l,r):
        """Return the sum of positions l..r inclusive."""
        return self.summ(r) - self.summ(l-1)
# Sieve tables cover indices 0 .. lim-1.
lim = 1000009
# phi: Euler's totient; res[x] = sum over d | x of (x/d) * phi[d]. Both are filled by pre().
phi, res = [0] * lim, [0] * lim
# ind[v]: Euler-tour stamps (one list per vertex, created by pre(), appended by dfs()).
# a: node values (rebound when the input is read); g: adjacency lists.
ind, a, g = [], [], []
def pre():
    """Fill the sieve tables and allocate the Euler-tour slots.

    After the call:
    - phi[x] is Euler's totient of x;
    - res[x] is sum over d | x of (x/d) * phi[d], for 1 <= x < lim;
    - ind has one fresh empty list appended per index below lim.
    """
    for k in range(lim):
        phi[k] = k
        res[k] = 0
    # An entry still equal to its index when reached is prime.
    for p in range(2, lim):
        if phi[p] != p:
            continue
        for m in range(p, lim, p):
            phi[m] -= phi[m] // p
    # Divisor d contributes (m // d) * phi[d] to each multiple m.
    for d in range(1, lim):
        pd = phi[d]
        for m in range(d, lim, d):
            res[m] += pd * (m // d)
    ind.extend([] for _ in range(lim))
def dfs(node,par):
    """Euler-tour the subtree of `node`, entered from `par` (0 for the root).

    For every vertex v reached, appends an entry stamp and later an exit stamp
    to ind[v], both taken from the global counter idxx. Children are visited in
    the order they appear in g[v]. The traversal uses an explicit stack, so a
    path-shaped tree deeper than the recursion limit does not raise
    RecursionError.
    """
    global idxx
    idxx = idxx + 1
    ind[node].append(idxx)
    # Each frame is [vertex, parent, index of the next neighbour to inspect].
    stack = [[node, par, 0]]
    while stack:
        frame = stack[-1]
        v, p, pos = frame
        nbrs = g[v]
        while pos < len(nbrs) and nbrs[pos] == p:
            pos += 1
        if pos < len(nbrs):
            child = nbrs[pos]
            frame[2] = pos + 1
            idxx = idxx + 1
            ind[child].append(idxx)
            stack.append([child, v, 0])
        else:
            stack.pop()
            idxx = idxx + 1
            ind[v].append(idxx)
def Fun(val):
    """Return F(val) = val * (res[val] + val) / 2, computed with floor division."""
    doubled = val * (res[val] + val)
    return doubled // 2
# Driver, in order:
#   1. read the tree and its node values;
#   2. build the Euler tour;
#   3. store each node's F value at both its entry and its exit stamp in the
#      Fenwick tree;
#   4. answer the queries.
# raw_input() is used throughout. Python 2's input() would evaluate the line as
# an expression. split() (not split(" ")) tolerates repeated or trailing
# whitespace in the input.
n = int(raw_input())
for i in range(0,n+10):
    g.append([])
a = list(map(int,raw_input().split()))
for i in range(n-1):
    u,v = map(int,raw_input().split())
    g[u].append(v)
    g[v].append(u)
pre()
idxx=0
dfs(1,0)
f = Fenwick(2*n+20)
for i in range(n):
    val = Fun(a[i])
    f.update(ind[i+1][0], val)
    f.update(ind[i+1][1], val)
q = int(raw_input())
for i in range(q):
    qr = list(map(int,raw_input().split()))
    if qr[0]==1:
        # Every subtree node is stored twice (entry and exit), hence // 2.
        print((f.query(ind[qr[1]][0],ind[qr[1]][1])//2)%1000000007)
    else:
        delta = Fun(qr[2]) - Fun(a[qr[1]-1])
        f.update(ind[qr[1]][0], delta)
        f.update(ind[qr[1]][1], delta)
        a[qr[1]-1]=qr[2]