SCALSUM - Editorial


Div-2 Contest
Div-1 Contest

Author: Anadi Agrawal
Setter: Krzysztof Boryczka
Tester: Istvan Nagy
Editorialist: Krzysztof Boryczka




trees, sqrt-decomposition, pre-computation


Given a rooted tree, answer Q queries of the form: given two vertices u and v with the same depth, calculate the scalar product of their vectors, where the vector of v is the sequence of weights of the vertices on the path from v to the root.
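To make the definition concrete, here is a minimal brute-force sketch. It is not the intended solution, only the per-query baseline that walks both paths to the root. The names `naiveScalar`, `parent` and `weight` are assumptions made for this illustration, with `parent[root] = 0` used as a sentinel and unsigned 32-bit arithmetic as in the solutions below.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Brute-force scalar product of the root paths of u and v (same depth).
// Vertices are 1-indexed; parent[root] must be 0, index 0 is a sentinel.
// Unsigned 32-bit overflow wraps around, matching the use of unsigned int.
std::uint32_t naiveScalar(const std::vector<int>& parent,
                          const std::vector<std::uint32_t>& weight,
                          int u, int v) {
    std::uint32_t sum = 0;
    while (u != 0) {  // u and v have equal depth, so both hit the sentinel together
        sum += weight[u] * weight[v];
        u = parent[u];
        v = parent[v];
    }
    return sum;
}
```

For a root 1 with children 2 and 3, where 4 hangs under 2 and 5 under 3, and weights 1..5, the query (4, 5) gives 4*5 + 2*3 + 1*1 = 27.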


Do sqrt-decomposition by depth: divide the tree into blocks, each consisting of \sqrt N layers. In each block choose the layer with the fewest vertices and pre-process the scalar products for every pair of vertices in this layer. Observe that there are at most N pre-processed pairs in the whole tree. Answer queries naively, moving both vertices up until a pre-processed pair is found. Complexity O((N+Q) \sqrt N).


Let’s define the depth of vertex v as the length of the path from v to vertex 1 (the root), and a layer as the set of vertices with the same depth.
Let’s divide the vertices into blocks by their depths: the first block holds the vertices with depths in the range [0, \sqrt N), the second those in [\sqrt N, 2 \sqrt N), and so on.
In each block choose the layer with the fewest vertices, and select every pair of vertices from this layer for pre-processing. Observe that if a block of \sqrt N layers has K vertices, then its smallest layer has at most \frac{K}{\sqrt N} vertices, so at most \frac{K^2}{N} pairs are chosen for a block of size K. Since K \leqslant N, we get \frac{K^2}{N} \leqslant K, i.e. at most K pairs per block. Summing over all blocks, at most N pairs are chosen. This bound needs a full block of \sqrt N layers; a shorter last block can be left without a selected layer, since its vertices are still fewer than 2 \sqrt N steps below the selected layer of the previous block.
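The selection step can be sketched as follows. This is an illustration under assumptions, not the setter's code: `layerSize[d]` (a hypothetical input) holds the number of vertices at depth d, `blockLen` plays the role of \sqrt N, and a trailing block with fewer than `blockLen` layers is skipped.

```cpp
#include <cassert>
#include <vector>

// For every full block of `blockLen` consecutive layers, return the depth of
// the layer with the fewest vertices (ties go to the shallowest layer).
std::vector<int> chooseLayers(const std::vector<int>& layerSize, int blockLen) {
    std::vector<int> chosen;
    const int layers = static_cast<int>(layerSize.size());
    for (int start = 0; start + blockLen <= layers; start += blockLen) {
        int best = start;
        for (int d = start + 1; d < start + blockLen; ++d)
            if (layerSize[d] < layerSize[best]) best = d;
        chosen.push_back(best);
    }
    return chosen;
}

// Number of ordered pairs stored if each chosen layer keeps a full table.
long long storedPairs(const std::vector<int>& layerSize,
                      const std::vector<int>& chosen) {
    long long total = 0;
    for (int d : chosen)
        total += 1LL * layerSize[d] * layerSize[d];
    return total;
}
```

For layer sizes {1, 2, 3, 3, 3, 3, 2, 3, 3, 2} (N = 25) and `blockLen = 5`, this selects depths 0 and 6 and stores 1 + 4 = 5 ordered pairs, well within the bound of N.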

Pre-process the answers for the chosen pairs from top to bottom, in the same way as queries are answered: take the two vertices and move them up naively until they reach an already pre-processed pair, or until they meet in a common vertex, where the rest of the sum is the precomputed sum of squared weights up to the root. This takes at most 2\sqrt N steps.
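A compact sketch of the pre-processing and the query walk together. It is a simplified illustration, not either reference solution: the struct `LayeredDot` and its fields are hypothetical names, the input is assumed to satisfy `parent[v] < v` with `parent[1] = 0`, each selected layer keeps a full square table, and a trailing short block gets no selected layer.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <utility>
#include <vector>

using u32 = std::uint32_t;

// Sketch only. Vertices are 1-indexed, vertex 1 is the root, parent[1] == 0,
// and parent[v] < v is assumed so that one forward pass computes all depths.
struct LayeredDot {
    std::vector<int> parent, depth, slot;  // slot[v]: index in its selected layer, or -1
    std::vector<u32> weight, selfDot;      // selfDot[v]: sum of weight^2 from v to the root
    std::vector<std::vector<u32>> table;   // table[d]: m x m products for selected layer d
    std::vector<int> width;                // width[d]: m, the size of selected layer d

    LayeredDot(std::vector<int> par, std::vector<u32> w, int blockLen)
        : parent(std::move(par)), weight(std::move(w)) {
        const int n = static_cast<int>(parent.size()) - 1;
        depth.assign(n + 1, 0);
        slot.assign(n + 1, -1);
        selfDot.assign(n + 1, 0);
        std::vector<std::vector<int>> layer(1);
        for (int v = 1; v <= n; ++v) {
            if (v > 1) depth[v] = depth[parent[v]] + 1;
            selfDot[v] = selfDot[parent[v]] + weight[v] * weight[v];
            if (depth[v] == static_cast<int>(layer.size())) layer.emplace_back();
            layer[depth[v]].push_back(v);
        }
        const int layers = static_cast<int>(layer.size());
        table.resize(layers);
        width.assign(layers, 0);
        // Top to bottom: only full blocks get a selected layer.
        for (int start = 0; start + blockLen <= layers; start += blockLen) {
            int best = start;
            for (int d = start + 1; d < start + blockLen; ++d)
                if (layer[d].size() < layer[best].size()) best = d;
            const std::size_t m = layer[best].size();
            std::vector<u32> t(m * m);
            // slot[] of this layer is still -1, so query() climbs to the layer above.
            for (std::size_t i = 0; i < m; ++i)
                for (std::size_t j = 0; j < m; ++j)
                    t[i * m + j] = query(layer[best][i], layer[best][j]);
            for (std::size_t i = 0; i < m; ++i)
                slot[layer[best][i]] = static_cast<int>(i);
            table[best] = std::move(t);
            width[best] = static_cast<int>(m);
        }
    }

    u32 query(int u, int v) const {
        u32 sum = 0;
        while (u != v && slot[u] < 0) {  // climb until the paths merge or hit a selected layer
            sum += weight[u] * weight[v];
            u = parent[u];
            v = parent[v];
        }
        if (u == v) return sum + selfDot[u];
        const std::size_t m = static_cast<std::size_t>(width[depth[u]]);
        return sum + table[depth[u]][static_cast<std::size_t>(slot[u]) * m
                                     + static_cast<std::size_t>(slot[v])];
    }
};
```

Storing both (i, j) and (j, i) doubles the table size but keeps the indexing trivial; a triangular layout would halve the memory.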

Complexity O((N+Q) \sqrt N).


Author's Solution
#include <bits/stdc++.h>

using namespace std;

typedef unsigned int uint;

const int N = 1e6 + 7;
const int P = 40;	// number of layers in one block

int n, q;
uint w[N];
vector <int> G[N];

int id[N];	// index of a vertex inside its selected layer
int off[N];	// offset of a selected layer in ans[]; layers with 0 are treated as not selected
int to_add[N];	// slots opened during the current query

bool mem[10 * N];	// mem[place]: ans[place] is already known
uint ans[10 * N];

uint dot[N];	// sum of squared weights from the vertex up to the root
int lvl[N], par[N];
vector <int> ver_list[N];	// vertices grouped by depth

void dfs(int u, int p){
	par[u] = p;
	ver_list[lvl[u]].push_back(u);

	dot[u] += w[u] * w[u];
	for(auto v: G[u])
		if(v != p){
			dot[v] = dot[u];
			lvl[v] = lvl[u] + 1;
			dfs(v, u);
		}
}

void read(){
	scanf("%d %d", &n, &q);
	for(int i = 1; i <= n; ++i)
		scanf("%u", &w[i]);

	for(int i = 1; i < n; ++i){
		int u, v;
		scanf("%d %d", &u, &v);
		G[u].push_back(v);
		G[v].push_back(u);
	}
}

void init(){
	dfs(1, 0);

	int off_count = 0;
	for(int i = 0; i < n; i += P){
		// pick the smallest layer among depths i .. i + P - 1
		int best = i;
		for(int j = 0; j < P; ++j)
			if(ver_list[i + j].size() < ver_list[best].size())
				best = i + j;

		int t = 0;
		for(auto &v: ver_list[best])
			id[v] = t++;

		off[best] = off_count;
		off_count += t * (t - 1) / 2;
	}
}

uint answer(int u, int v){
	uint ret = 0;
	int it = 0;

	while(u != v){
		int cur_lvl = lvl[u];
		if(off[cur_lvl] > 0){
			int pu = id[u], pv = id[v];
			if(pu < pv)
				swap(pu, pv);

			int place = off[cur_lvl] + pu * (pu - 1) / 2 + pv;

			if(mem[place]){
				// known pair: finish here and complete the slots opened on the way
				ret += ans[place];
				for(int i = 0; i < it; ++i)
					ans[to_add[i]] += ret;
				return ret;
			}

			// store minus the part below this layer; the full sum is added later
			mem[place] = true;
			ans[place] = -ret;

			to_add[it++] = place;
			ret += w[u] * w[v];
		}
		else
			ret += w[u] * w[v];

		u = par[u], v = par[v];
	}

	ret += dot[u];
	for(int i = 0; i < it; ++i)
		ans[to_add[i]] += ret;
	return ret;
}

void solve(){
	while(q--){
		int u, v;
		scanf("%d %d", &u, &v);
		printf("%u\n", answer(u, v));
	}
}

int main(){
	read();
	init();
	solve();
	return 0;
}
Setter's Solution
#include <bits/stdc++.h>
using namespace std;

typedef pair<int, int> ii;
typedef vector<int> vi;
const int INF=0x3f3f3f3f;

#define FOR(i, b, e) for(int i = (b); i < (e); i++)
#define TRAV(x, a) for(auto &x: (a))
#define SZ(x) ((int)(x).size())
#define PB push_back
#define X first
#define Y second

const int N = 3e5+5;
const int K = 80;	// number of layers in one block

vi G[N];
int p[N], dpth[N], num[N];	// num[v]: index of v inside its chosen layer
unsigned int dot[N], val[N];	// dot[v]: sum of squared weights from v to the root
vector<vi> ondpth;	// vertices grouped by depth
bool chosen[N];
vector<unsigned int> memo[N];	// memo[d]: pairwise answers for the chosen layer d

void dfs(int v, int par, int dpt){
	p[v] = par;
	dpth[v] = dpt;
	dot[v] = dot[par] + val[v]*val[v];
	if(SZ(ondpth) == dpt) ondpth.PB({});
	ondpth[dpt].PB(v);
	TRAV(x, G[v]){
		if(x == par) continue;
		dfs(x, v, dpt+1);
	}
}

unsigned int query(int a, int b){
	unsigned int ret = 0;
	// climb until the paths merge or a chosen layer is reached
	while(a != b && !chosen[a]){
		ret += val[a]*val[b];
		a = p[a];
		b = p[b];
	}
	if(a == b) ret += dot[a];
	else ret += memo[dpth[a]][num[a]*SZ(ondpth[dpth[a]])+num[b]];
	return ret;
}

void solve(){
	int n, q;
	cin >> n >> q;
	FOR(i, 1, n+1) cin >> val[i];
	FOR(i, 0, n-1){
		int a, b;
		cin >> a >> b;
		G[a].PB(b), G[b].PB(a);
	}
	dfs(1, 1, 0);
	for(int i = 0; i < SZ(ondpth); i += K){
		// smallest layer of the block, as a pair (size, depth)
		ii akt = {INF, INF};
		FOR(j, i, min(i+K, SZ(ondpth))) akt = min(akt, {SZ(ondpth[j]), j});
		int lev = 0;
		TRAV(x, ondpth[akt.Y]) num[x] = lev++;
		TRAV(x, ondpth[akt.Y]) TRAV(y, ondpth[akt.Y]){
			if(num[x] <= num[y]) memo[akt.Y].PB(query(x, y));
			else memo[akt.Y].PB(memo[akt.Y][num[y]*lev+num[x]]);
		}
		TRAV(x, ondpth[akt.Y]) chosen[x] = 1;
	}
	FOR(i, 0, q){
		int a, b;
		cin >> a >> b;
		cout << query(a, b) << '\n';
	}
}

int main(){
	solve();
	return 0;
}
Tester's Solution




Can we use MO’s algorithm?


An alternative approach with the same time complexity uses Mo’s algorithm.
The idea is similar to FCTRE - Editorial.
Implementation:

  • We can flatten the tree with an Euler tour and then use Mo’s algorithm to answer the queries.
  • Each query L,R is converted to the range end(L),start(R), where end(L) denotes the last occurrence of node L and start(R) denotes the first occurrence of R.
  • We maintain the scalar product while moving between queries with Mo’s algorithm; this gives the scalar product for the nodes on the path from node L to node R, excluding LCA(L,R).
  • Remember, if a node appears twice in the given range of the Euler tour, we remove it. For a better understanding refer to FCTRE - Editorial.
  • We precompute the sum of squares (SOS) of the values from the root to each node and add SOS[LCA(L,R)] to each query’s answer.
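The toggling step in the list above needs some bookkeeping so that a node can be removed without modular division. Below is one possible sketch, not the commenter's code: it keeps, for every depth, the count and the XOR of the ids of the active nodes, and assumes that at most two active nodes ever share a depth (which holds when the active set is a path). The struct `DepthPairs` and its fields are hypothetical names.

```cpp
#include <cassert>
#include <cstdint>
#include <utility>
#include <vector>

// Bookkeeping for the toggle step of Mo's algorithm on the Euler tour.
// Assumption: at most two active vertices share a depth at any time.
struct DepthPairs {
    std::vector<int> depth;             // depth[v], vertices are 1-indexed
    std::vector<std::uint32_t> weight;  // weight[v]
    std::vector<char> active;           // active[v]: v is currently counted
    std::vector<int> count, xorId;      // per depth: active count, XOR of active ids
    std::uint32_t sum = 0;              // sum of products over depths with two active vertices

    DepthPairs(std::vector<int> d, std::vector<std::uint32_t> w)
        : depth(std::move(d)), weight(std::move(w)),
          active(depth.size(), 0), count(depth.size(), 0), xorId(depth.size(), 0) {}

    // Insert v if it is absent, remove it if it is present.
    void toggle(int v) {
        const int d = depth[v];
        if (!active[v]) {
            if (count[d] == 1) sum += weight[v] * weight[xorId[d]];  // xorId[d] is the partner
            xorId[d] ^= v;
            ++count[d];
        } else {
            xorId[d] ^= v;
            --count[d];
            if (count[d] == 1) sum -= weight[v] * weight[xorId[d]];  // partner stays active
        }
        active[v] = !active[v];
    }
};
```

After toggling in the nodes on the path from L to R (without the LCA), `sum` plus SOS[LCA(L,R)] is the answer to the query. Unsigned arithmetic makes the subtraction safe modulo 2^32.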

Yes we can. Please check my comment.

I was also thinking along these lines but could not implement it. Can you make a video editorial for your solution?


Hi I_returns check my solution -


Ah that’s nice. I tried implementing this in contest but I wasn’t able to figure out how to deal with nodes appearing twice in the euler tour path… I couldn’t simply multiply them and divide since a_i <= 1e9 and MOD was non prime.



If we just do the naive approach but cache all the products we come across then the time complexity should be N*root(N), right?

For example:

Suppose that from the root to n1 and n2 we have tuples (xi, yi, ui, vi), where
xi, yi -> the i-th vertices on the paths from the root to n1 and n2 respectively
ui, vi -> the weights of xi and yi respectively
(1,1, 10, 10) (2,2, 11, 11) (3,3, 12, 12) (4,5, 13, 14) (10,11, 15, 16)
Now when we calculate the scalar product we know that (ni,nj) = wi*wj + (ni-1,nj-1),
so this way we store all values on this path in a map.
I think at most this would result in N*root(N) complexity,
but my code gets TLE.
It would be great if someone could help me understand why this logic wouldn’t pass.


You should try unordered_map with a good hash function.


I tried with unordered map as well
still TLE

I stored the numbers that take part in the dot product in a vector, and after removing an element I subtracted the old value and added the new value, since we cannot use a modular inverse.

Can someone please explain what was the difference between subtask 2 and 3 ? why having 10^5 nodes and queries is easier than 3*10^5 ? and how were people able to get AC for subtask 2 but not 3 if there isn’t a huge difference.

Where am I going wrong? Help me out.

this is not the problem being discussed in this thread

If you have already sorted the queries by depth, then why are you storing the results all the way to the root? It’s useless; just store the result for the current (u, v) pair.

I tried sorting by height of nodes and then implementing with an unordered map. Only 1 case TLE.
What could I have optimized more?
@dtu_amritkumar my solution is a bit similar to yours, except im using recursion. my hash function is from gfg. What is up with case 17.

Try this hash function maybe it will work.

struct pair_hash
{
    template <class T1, class T2>
    size_t operator()(const pair<T1, T2> &p) const
    {
        // packs both halves into one word; assumes a 64-bit size_t
        size_t h = (size_t(p.first) << 32) + size_t(p.second);
        h *= 1231231557ull;
        h ^= (h >> 32);
        return h;
    }
};

Source: Quora


Can anyone provide me a good source/video to learn square root decomposition on trees?

It is possible that if I don’t store all the pairs till root, it might take n^2
1 -> 2
1-> 3
3-> 5
4 -> 6
and so on
edge between n->2+n for all n>1

in this case if queries are
n-2, n

and so on
i will get TLE then because everytime i have to traverse the path and hence n^2


Try CodeNCode