PROBLEM LINK:
Practice
Div-2 Contest
Div-1 Contest
Author: Anadi Agrawal
Setter: Krzysztof Boryczka
Tester: Istvan Nagy
Editorialist: Krzysztof Boryczka
DIFFICULTY:
MEDIUM-HARD
PREREQUISITES:
trees, sqrt-decomposition, pre-computation
PROBLEM:
Given a tree rooted at vertex 1, answer Q queries of the following form: given two vertices u and v at the same depth, calculate the scalar (dot) product of their vectors, where the vector of a vertex v is the sequence of weights of the vertices on the path from v to the root.
QUICK EXPLANATION:
Use sqrt decomposition by depth: divide the tree into blocks, each consisting of \sqrt N consecutive layers. In each block, choose the layer with the fewest vertices and pre-process the scalar products of every pair of vertices in that layer. Observe that there are at most N pre-processed pairs in the whole tree. Answer queries naively by moving both vertices up until they meet or reach a pre-processed pair. Complexity: O((N+Q) \sqrt N).
EXPLANATION:
Define the depth of a vertex v as the length of the (unique) path from v to the root, vertex 1. Define a layer as the set of all vertices with the same depth.
Divide the vertices into blocks by depth. The first block contains the vertices with depths in [0, \sqrt N), the second those with depths in [\sqrt N, 2 \sqrt N), and so on.
In each block, choose the layer with the fewest vertices and select every pair of vertices from this layer for pre-processing. Suppose a block contains K vertices. Since the block consists of \sqrt N layers, its smallest layer has at most \frac{K}{\sqrt N} vertices (pigeonhole principle). Hence we choose at most \frac{K^2}{N} pairs for a block of size K. Since K \leqslant N, we have \frac{K^2}{N} \leqslant K, so at most K pairs are chosen per block. Summing over all blocks, we choose at most N pairs in total.
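Pre-process the answers for the chosen pairs from top to bottom, so that each computation can stop at the chosen layer of a block above it. Queries are answered the same way: take the two vertices from the query and move them up in lockstep until they coincide or reach an already pre-processed pair. The chosen layer of the current block or of the block directly above is always fewer than 2\sqrt N layers higher. Therefore each pre-processing computation and each query takes at most 2\sqrt N steps.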
Complexity O((N+Q) \sqrt N).
SOLUTIONS:
Author's Solution
#include <bits/stdc++.h>
using namespace std;
// All weights, products and sums are unsigned, so results wrap around modulo
// 2^(bits of unsigned int) instead of overflowing.
using uint = unsigned int;

constexpr int N = 1000007;     // capacity for vertices and depths (1e6 + 7)
constexpr int P = 40;          // block height: consecutive layers per block

int n, q;                      // vertex count, query count
uint w[N];                     // w[v]: weight of vertex v (1-indexed)
std::vector<int> G[N];         // adjacency lists of the tree
int id[N];                     // id[v]: position of v inside its chosen layer
int off[N];                    // off[d]: first memo slot of layer d; answer() treats 0 as "not pre-processed"
int to_add[N];                 // memo slots claimed during a single answer() call
bool mem[10 * N];              // mem[s]: memo slot s has been claimed
uint ans[10 * N];              // ans[s]: memoised value of slot s
uint dot[N];                   // dot[v]: sum of w^2 over the path from v up to the root
int lvl[N], par[N];            // depth and parent of each vertex (the root's parent is 0)
std::vector<int> ver_list[N];  // ver_list[d]: vertices at depth d, in DFS order
// Pre-order traversal from u (whose parent is p) that fills:
//   par[x]       - parent of x,
//   lvl[x]       - depth of x (lvl[u] must already be set by the caller),
//   dot[x]       - dot[parent] + w[x]^2, i.e. the sum of squared weights from
//                  x up to the root (dot[u] must hold the prefix above u),
//   ver_list[d]  - vertices at depth d, in pre-order.
// An explicit stack replaces recursion. A path-shaped tree with up to N
// vertices would otherwise recurse about N levels deep and can overflow the
// call stack. Children are pushed in reverse adjacency order, so vertices are
// visited in exactly the same order as the recursive version.
void dfs(int u, int p){
    vector<pair<int, int>> todo;   // (vertex, parent) still to visit
    todo.push_back({u, p});
    while(!todo.empty()){
        int x = todo.back().first;
        int px = todo.back().second;
        todo.pop_back();

        par[x] = px;
        ver_list[lvl[x]].push_back(x);
        dot[x] += w[x] * w[x];

        for(auto it = G[x].rbegin(); it != G[x].rend(); ++it){
            int v = *it;
            if(v == px)
                continue;
            // dot[x] is final here, so the child can inherit its prefix now.
            dot[v] = dot[x];
            lvl[v] = lvl[x] + 1;
            todo.push_back({v, x});
        }
    }
}
// Reads the input from stdin: the vertex count n and query count q, then the
// weights w[1..n], then n - 1 undirected edges, which are appended to both
// endpoints' adjacency lists.
void read(){
    scanf("%d %d", &n, &q);

    for(int v = 1; v <= n; v++)
        scanf("%u", w + v);

    for(int left = n - 1; left > 0; left--){
        int a, b;
        scanf("%d %d", &a, &b);
        G[a].push_back(b);
        G[b].push_back(a);
    }
}
// Reads the input, runs the DFS and lays out the memo tables used by answer().
//
// Depths are grouped into blocks of `height` consecutive layers. In every
// block the smallest layer (the shallowest one on ties) is chosen, and its
// vertices are numbered 0..t-1 in id[]. Layer `best` owns the t*(t-1)/2 memo
// slots starting at off[best], one per unordered pair of its vertices.
//
// Two fixes compared with a fixed height of P and offsets starting at 0:
//  * Slot 0 is reserved, so off[] == 0 unambiguously means "not pre-processed"
//    (answer() tests off[level] > 0). Before, the first chosen layer got
//    offset 0 and was silently never memoised.
//  * The block height starts at P and is doubled until all slots fit below
//    10 * n. With a fixed height of 40, a block of wide layers (e.g. 40
//    layers of 7500 vertices) needs about 2.8e7 slots and overflows the memo
//    arrays. Once height >= sqrt(n), the smallest layer of a block with K
//    vertices has at most K / height vertices. The total is then at most
//    n / 2 + 1, so the doubling loop always terminates.
void init(){
    read();
    dfs(1, 0);

    // Every depth is < n, so any layer at depth >= n is empty.
    auto layer_size = [](int d) -> long long {
        return d < n ? (long long)ver_list[d].size() : 0LL;
    };
    // Smallest layer (shallowest on ties) among depths [start, start + height).
    auto pick = [&](int start, int height){
        int best = start;
        for(int d = start + 1; d < start + height; ++d)
            if(layer_size(d) < layer_size(best))
                best = d;
        return best;
    };
    // Memo slots (including reserved slot 0) needed for a given block height.
    auto slots_needed = [&](int height){
        long long total = 1;
        for(int i = 0; i < n; i += height){
            long long t = layer_size(pick(i, height));
            total += t * (t - 1) / 2;
        }
        return total;
    };

    // NOTE(review): assumes mem[] / ans[] hold at least 10 * n + 1 slots
    // (they are declared with 10 * N entries and n < N). Confirm if resized.
    int height = P;
    while(slots_needed(height) > 10LL * n)
        height *= 2;

    int off_count = 1;   // slot 0 reserved: off[] == 0 means "not chosen"
    for(int i = 0; i < n; i += height){
        int best = pick(i, height);
        if(best >= n)
            continue;    // empty layer past the deepest vertex: nothing to store
        int t = 0;
        for(auto &v: ver_list[best])
            id[v] = t++;
        off[best] = off_count;
        off_count += t * (t - 1) / 2;
    }
}
// Returns the dot product (in wrapping unsigned arithmetic) of the root paths
// of u and v, which are expected to be at the same depth.
//
// u and v are walked up in lockstep. Whenever the walk is on a pre-processed
// layer (off[level] > 0), the memo slot of the pair {id[u], id[v]} is checked:
//  * hit:  the stored value is added and the walk stops;
//  * miss: the slot is claimed, temporarily set to -ret, and recorded in
//          to_add[]. Once the final total is known, it is added to every
//          claimed slot. Each slot then holds total - (ret at that point),
//          i.e. the contribution of that pair and everything above it.
// When u and v meet, the shared path contributes dot[u].
uint answer(int u, int v){
    uint ret = 0;
    int it = 0;   // number of slots claimed in this call

    // Completes every claimed slot with the final total and returns it.
    auto settle = [&](uint total){
        for(int i = 0; i < it; ++i)
            ans[to_add[i]] += total;
        return total;
    };

    while(u != v){
        int cur_lvl = lvl[u];
        if(off[cur_lvl] > 0){
            int pu = id[u], pv = id[v];
            if(pu < pv)
                swap(pu, pv);
            // Triangular index of the unordered pair (pu > pv) within the layer.
            int place = off[cur_lvl] + pu * (pu - 1) / 2 + pv;
            if(mem[place])
                return settle(ret + ans[place]);
            mem[place] = true;
            ans[place] = -ret;
            to_add[it++] = place;
        }
        ret += w[u] * w[v];
        u = par[u], v = par[v];
    }
    return settle(ret + dot[u]);
}
void solve(){
while(q--){
int u, v;
scanf("%d %d", &u, &v);
printf("%u\n", answer(u, v));
}
}
// Entry point: build all pre-computed tables, then answer the queries.
int main(){
    init();
    solve();
}
Setter's Solution
#include <bits/stdc++.h>
using namespace std;
// Shorthand types. In solve(), ii holds a (layer size, depth) key when the
// smallest layer of a block is chosen.
typedef pair<int, int> ii;
typedef vector<int> vi;
// Sentinel above any layer size or depth; initial minimum in solve().
const int INF=0x3f3f3f3f;
// Shorthand macros: half-open index loop, range-for by reference,
// container size as int, push_back, and pair member access.
#define FOR(i, b, e) for(int i = (b); i < (e); i++)
#define TRAV(x, a) for(auto &x: (a))
#define SZ(x) ((int)(x).size())
#define PB push_back
#define X first
#define Y second
// Capacity for vertices and depths.
const int N = 3e5+5;
// Block height: number of consecutive depths grouped into one block.
const int K = 80;
// Adjacency lists of the tree.
vi G[N];
// p: parent (the root is its own parent), dpth: depth,
// num: index of a vertex inside its chosen layer.
int p[N], dpth[N], num[N];
// dot[v]: sum of val^2 from v up to the root; val[v]: weight of v.
unsigned int dot[N], val[N];
// ondpth[d]: vertices at depth d, in DFS order.
vector<vi> ondpth;
// chosen[v]: v lies in a pre-processed layer.
bool chosen[N];
// memo[d]: for a chosen layer d of size s, a row-major s x s matrix whose
// entry num[a]*s + num[b] is the answer for the pair (a, b).
vector<unsigned int> memo[N];
void dfs(int v, int par, int dpt){
p[v] = par;
dpth[v] = dpt;
dot[v] = dot[par] + val[v]*val[v];
if(SZ(ondpth) == dpt) ondpth.PB({});
ondpth[dpt].PB(v);
TRAV(x, G[v]){
if(x == par) continue;
dfs(x, v, dpt+1);
}
}
// Dot product (in wrapping unsigned arithmetic) of the root paths of a and b,
// which must lie at the same depth. Climbs both vertices in lockstep until they
// meet, in which case the shared path contributes dot[a], or until they reach
// a pre-processed layer, in which case the memoised answer for the pair is used.
unsigned int query(int a, int b){
    unsigned int acc = 0;
    for(; a != b; a = p[a], b = p[b]){
        if(chosen[a]) break;
        acc += val[a]*val[b];
    }
    if(a == b)
        return acc + dot[a];
    const int d = dpth[a];
    return acc + memo[d][num[a]*SZ(ondpth[d])+num[b]];
}
// Reads the tree and the queries from stdin, pre-processes one layer per block
// of K consecutive depths, and prints the answer to each query on its own line.
void solve(){
int n, q;
cin >> n >> q;
// Vertex weights, 1-indexed.
FOR(i, 1, n+1) cin >> val[i];
// n - 1 undirected edges.
FOR(i, 0, n-1){
int a, b;
cin >> a >> b;
G[a].PB(b), G[b].PB(a);
}
// Root the tree at vertex 1, which is recorded as its own parent.
dfs(1, 1, 0);
// Extra empty layer after the deepest one. When the last block has fewer than
// K real layers, this empty layer is its minimum, so nothing is pre-processed
// there (presumably the intent of adding it).
ondpth.PB({});
// Blocks are processed top to bottom. The query() calls made while filling a
// layer's memo can therefore stop at layers chosen in shallower blocks.
for(int i = 0; i < SZ(ondpth); i += K){
// Smallest layer among depths [i, i+K); ties go to the shallower depth
// because pairs compare the size first and the depth second.
ii akt = {INF, INF};
FOR(j, i, min(i+K, SZ(ondpth))) akt = min(akt, {SZ(ondpth[j]), j});
int lev = 0;
// Number the chosen layer's vertices 0..lev-1.
TRAV(x, ondpth[akt.Y]) num[x] = lev++;
// Fill the lev x lev matrix row by row (row num[x], column num[y]). An entry
// with num[x] > num[y] copies its symmetric entry, which lies in an earlier,
// already-filled row.
TRAV(x, ondpth[akt.Y]) TRAV(y, ondpth[akt.Y]){
if(num[x] <= num[y]) memo[akt.Y].PB(query(x, y));
else memo[akt.Y].PB(memo[akt.Y][num[y]*lev+num[x]]);
}
// Mark the layer only after its memo is complete, so the query() calls
// above never tried to read this layer's unfinished matrix.
TRAV(x, ondpth[akt.Y]) chosen[x] = 1;
}
// Answer the queries.
FOR(i, 0, q){
int a, b;
cin >> a >> b;
cout << query(a, b) << '\n';
}
}
// Entry point. Stops synchronising iostreams with C stdio and unties cin from
// cout for fast bulk I/O, then runs the solver.
int main(){
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);
    solve();
}
Tester's Solution
indent whole code by 4 spaces