LCASQRT - Editorial

PROBLEM LINK:

Practice
Div1
Div2

Setter: Ivan Safonov
Tester: Alexander Morozov
Editorialist: Ajit Sharma Kasturi

DIFFICULTY:

MEDIUM

PREREQUISITES:

Trees, Depth first search, Quadratic Residue, Tonelli-Shanks Algorithm

PROBLEM:

We are given a rooted tree with N vertices rooted at 1 where p_v is the parent of vertex v.
LCA(u, v) is the lowest common ancestor of u and v. \mathbb{L}_{v, s} is defined as the set of all
vertices u such that LCA(u, v) = s.

Let A_P be the set of all sequences of length N whose elements are integers between 0 and P-1 (inclusive). The LCA convolution of two sequences a, b \in A_P is the sequence c = a*b where c_x = ( \displaystyle\sum_{i=1}^{N} \displaystyle\sum_{j \in \mathbb{L}_{i, x} } a_i \cdot b_j ) \bmod P .

We are given a sequence c. We need to find the number of sequences a \in A_P such that c = a*a, modulo 998,244,353.
If this number is greater than 0, we also need to output any such sequence a; otherwise, as both solutions below do, we print -1 on the second line.

QUICK EXPLANATION:

  • To understand the convolution c = a*a, note that for every subtree, the sum of c over that subtree equals the square of the sum of a over it. So a solution exists if and only if every subtree sum of c is 0 or a quadratic residue modulo P.

  • For each 1\leq i \leq N we can find a candidate value b_i, the sum of values of a in the subtree of i, using the Tonelli-Shanks algorithm, and the total number of such sequences is 2^{\textbf{number of nonzero } b_i \textbf{ values}} .

EXPLANATION:

The definition of the LCA convolution looks complicated, so let us demystify it. Since j \in \mathbb{L}_{i, x} means exactly LCA(i, j) = x, we can express the convolution of two arrays a and b as c_x = (\displaystyle\sum_{i,j \text{ : LCA(i, j)=x } } a_i \cdot b_j) \bmod P, where 1 \leq i,j \leq N .

Now let sum_i be the sum of all values c_j in the subtree of i, modulo P. The array sum can be calculated with a simple DFS. Suppose c = a*a for some array a \in A_P. I claim that then sum_i = (\displaystyle\sum_{j \in subtree(i) } a_j)^2 \bmod P for every i. Try to prove it.
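The array sum can also be computed without recursion, which matters on deep (path-like) trees where a recursive DFS may overflow the stack. The sketch below, assuming 1-indexed vertices with parent[1] = 0 and a hypothetical function name subtreeSums, lists vertices in BFS order from the root and then adds each vertex's sum into its parent in reverse BFS order, so every child's subtree is complete before it is added.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>
using namespace std;

// sum[v] = (sum of c over the subtree of v) mod P, computed iteratively.
// parent[v] is the parent of v for v = 2..n; parent[1] = 0; index 0 is unused.
vector<int64_t> subtreeSums(const vector<int>& parent, const vector<int64_t>& c, int64_t P) {
    int n = (int)parent.size() - 1;
    vector<vector<int>> children(n + 1);
    for (int v = 2; v <= n; v++) children[parent[v]].push_back(v);

    // BFS order: every parent appears before all of its children.
    vector<int> order = {1};
    for (size_t k = 0; k < order.size(); k++)
        for (int ch : children[order[k]]) order.push_back(ch);

    vector<int64_t> sum(n + 1, 0);
    for (int v = 1; v <= n; v++) sum[v] = c[v] % P;
    // Reverse BFS order: children are folded into parents; position 0 is the root.
    for (int k = n - 1; k >= 1; k--) {
        int v = order[k];
        sum[parent[v]] = (sum[parent[v]] + sum[v]) % P;
    }
    return sum;
}
```

The setter's do_conv below achieves the same effect with a single loop over decreasing vertex indices, which works when every parent has a smaller index than its child.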

Proof

sum_i = (\displaystyle\sum_{j \in subtree(i) } c_j) \bmod P = ( \displaystyle\sum_{x \in subtree(i)} \displaystyle\sum_{y \in subtree(i) } a_x \cdot a_y ) \bmod P

The second equality holds because every ordered pair (x, y) contributes a_x \cdot a_y to exactly one entry, c_{LCA(x, y)}, and LCA(x, y) lies in the subtree of i if and only if both x and y lie in the subtree of i.

The double sum on the right is just the expansion of a squared sum, so it is equal to

(\displaystyle\sum_{j \in subtree(i) } a_j)^2 \bmod P .

This completes the proof.
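As a small worked check of the claim, take the three-vertex tree with p_2 = p_3 = 1. The only pairs whose LCA is a leaf are (2, 2) and (3, 3); every other ordered pair has LCA 1, so:

```latex
\begin{aligned}
c_2 &= a_2^2, \qquad c_3 = a_3^2,\\
c_1 &= a_1^2 + 2a_1a_2 + 2a_1a_3 + 2a_2a_3 \pmod P,\\
sum_1 &= c_1 + c_2 + c_3 = (a_1 + a_2 + a_3)^2 \pmod P, \qquad sum_2 = a_2^2, \qquad sum_3 = a_3^2.
\end{aligned}
```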

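Let b_i be the sum of all values a_j in the subtree of i, modulo P. Then we need sum_i = b_i^2 \bmod P for every i; in other words, b_i is a square root of sum_i modulo P. So a solution can exist only if every sum_i is 0 or a quadratic residue modulo P: if some sum_i \neq 0 is not a quadratic residue modulo P, no such sequence exists. Otherwise, we can find a valid b_i for each i from 1 to N with the Tonelli-Shanks algorithm (taking b_i = 0 when sum_i = 0). Conversely, every such b yields a valid a (built below): its subtree sums are exactly b, so a*a has subtree sums b_i^2 = sum_i, and since c_i = (sum_i - \sum_{j \in child(i)} sum_j) \bmod P, the subtree sums determine c uniquely; different b also give different a. The number of sequences is therefore 2^{\text{number of nonzero }b_i } : if some b_i \neq 0 satisfies sum_i = b_i^2 \bmod P, then P-b_i \neq b_i also satisfies sum_i = (P-b_i)^2 \bmod P, and since P is an odd prime (as both solutions below assume), these are the only two square roots.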
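Both solutions below detect non-residues with Euler's criterion: a nonzero s is a quadratic residue modulo an odd prime P exactly when s^{(P-1)/2} \equiv 1 \pmod P. A minimal sketch of the existence check and the count, assuming P is an odd prime below 2^{31} and that the subtree sums sum[1..n] are already known (the names powMod and countSequences are hypothetical):

```cpp
#include <cassert>
#include <cstdint>
#include <vector>
using namespace std;

// Binary exponentiation; assumes mod < 2^32 so that products fit in uint64_t.
uint64_t powMod(uint64_t base, uint64_t exp, uint64_t mod) {
    uint64_t result = 1 % mod;
    base %= mod;
    while (exp > 0) {
        if (exp & 1) result = result * base % mod;
        base = base * base % mod;
        exp >>= 1;
    }
    return result;
}

// Number of sequences a with a*a = c, modulo 998244353, given the subtree sums of c
// (index 0 unused). Returns 0 if some nonzero subtree sum is a quadratic non-residue.
int64_t countSequences(const vector<uint64_t>& sum, uint64_t P) {
    const int64_t MOD = 998244353;
    int64_t ways = 1;
    for (size_t v = 1; v < sum.size(); v++) {
        uint64_t s = sum[v] % P;
        if (s == 0) continue;                          // only b_v = 0 works
        if (powMod(s, (P - 1) / 2, P) != 1) return 0;  // Euler's criterion fails: no b_v
        ways = ways * 2 % MOD;                         // two roots: b_v and P - b_v
    }
    return ways;
}
```

For example, modulo 7 the nonzero quadratic residues are 1, 2 and 4, so subtree sums (2, 0, 4) give 2^2 = 4 sequences, while a subtree sum of 3 gives 0.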

Now, using this sequence b, we can find a with a DFS as follows
(here j is called a child of i if p_j = i):

a_i = (b_i - \displaystyle\sum_{j \in child(i) } b_j) \bmod P .
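Since every vertex j \geq 2 is a child of exactly one vertex p_j, the formula above can also be applied without any traversal: start from a = b and subtract each b_j from a_{p_j}, in any order. A minimal sketch, assuming 1-indexed vertices with parent[1] = 0 and values b_v already reduced to [0, P) (recoverA is a hypothetical name):

```cpp
#include <cassert>
#include <cstdint>
#include <vector>
using namespace std;

// a_v = (b_v - sum of b_j over children j of v) mod P.
// parent[v] is the parent of v for v = 2..n; parent[1] = 0; index 0 is unused.
vector<int64_t> recoverA(const vector<int>& parent, const vector<int64_t>& b, int64_t P) {
    int n = (int)parent.size() - 1;
    vector<int64_t> a(b.begin(), b.end());
    // Each non-root vertex j subtracts its own b_j from its parent; the order does not matter.
    for (int j = 2; j <= n; j++) {
        int64_t& target = a[parent[j]];
        target = ((target - b[j]) % P + P) % P;  // keep the result in [0, P)
    }
    return a;
}
```

This touches each edge once, so it runs in O(N) and sidesteps the deep recursion a DFS can hit on a long path. With p = (-, 1, 1, 2), P = 7 and b = (3, 5, 1, 6), it yields a = (4, 6, 1, 6), whose subtree sums modulo 7 are again (3, 5, 1, 6).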

TIME COMPLEXITY:

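The DFS takes O(N) time. Applying the Tonelli-Shanks algorithm to each sum_i takes O(\log P) modular exponentiations plus at most O(\log^2 P) squarings in its main loop, so O(\log^2 P) time per value in the worst case. Thus the overall time complexity is O(N \cdot \log^2 P) .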

SOLUTION:

Editorialist's solution
#include <bits/stdc++.h>
#define int long long int
using namespace std;

//*********************** TONELLI SHANKS ALGORITHM ****************
uint64_t modpow(uint64_t a, uint64_t b, uint64_t n)
{
	uint64_t x = 1, y = a;
	while (b > 0)
	{
		if (b % 2 == 1)
		{
			x = (x * y) % n; 
		}
		y = (y * y) % n; 
		b /= 2;
	}
	return x % n;
}

struct Solution
{
	uint64_t root1, root2;
	bool exists;
};

struct Solution makeSolution(uint64_t root1, uint64_t root2, bool exists)
{
	struct Solution sol;
	sol.root1 = root1;
	sol.root2 = root2;
	sol.exists = exists;
	return sol;
}

struct Solution ts(uint64_t n, uint64_t p)
{
	uint64_t q = p - 1;
	uint64_t ss = 0;
	uint64_t z = 2;
	uint64_t c, r, t, m;

	if (modpow(n, (p - 1) / 2, p) != 1)
	{
		return makeSolution(0, 0, false);
	}

	while ((q & 1) == 0)
	{
		ss += 1;
		q >>= 1;
	}

	if (ss == 1)
	{
		uint64_t r1 = modpow(n, (p + 1) / 4, p);
		return makeSolution(r1, p - r1, true);
	}

	while (modpow(z, (p - 1) / 2, p) != p - 1)
	{
		z++;
	}

	c = modpow(z, q, p);
	r = modpow(n, (q + 1) / 2, p);
	t = modpow(n, q, p);
	m = ss;

	while (true)
	{
		uint64_t i = 0, zz = t;
		uint64_t b = c, e;
		if (t == 1)
		{
			return makeSolution(r, p - r, true);
		}
		while (zz != 1 && i < (m - 1))
		{
			zz = zz * zz % p;
			i++;
		}
		e = m - i - 1;
		while (e > 0)
		{
			b = b * b % p;
			e--;
		}
		r = r * b % p;
		c = b * b % p;
		t = t * c % p;
		m = i;
	}
}

int test(uint64_t n, uint64_t p)
{
	struct Solution sol = ts(n, p);
	if (sol.exists)
		return sol.root1;
	return -1;
}

//*********************** END ***********************************

vector<int> v[500005];
int sum[500005];
int c[500005];
int ans[500005];
int arr[500005];

void dfs(int i, int prev = -1)
{
	sum[i] = c[i];
	for (int nxt : v[i])
	{
		if (nxt != prev)
		{
			dfs(nxt, i);
			sum[i] += sum[nxt];
		}
	}
}

void dfs1(int i, int p, int prev = -1)
{
	ans[i] = arr[i];
	for (int nxt : v[i])
	{
		if (nxt != prev)
		{
			dfs1(nxt, p, i);
			ans[i] -= arr[nxt];
			ans[i] += p;
			while (ans[i] >= p)
				ans[i] -= p;
		}
	}
}

int32_t main()
{
	int t;
	cin >> t;
	while (t--)
	{
		int n, p;
		cin >> n >> p;
		for (int i = 2; i <= n; i++)
		{
			int x;
			cin >> x;
			v[x].push_back(i);
			v[i].push_back(x);
		}

		for (int i = 1; i <= n; i++)
			cin >> c[i];
		dfs(1);

		int ways = 1;
		bool possible = true;
		int MOD = 998244353;

		for (int i = 1; i <= n; i++)
		{
			arr[i] = test(sum[i] % p, p);
			if (sum[i] % p == 0)
				arr[i] = 0;
			else if (arr[i] == -1)
			{
				possible = false;
				break;
			}
			else
			{
				ways *= 2;
				while (ways >= MOD)
					ways -= MOD;
			}
		}

		if (possible)
		{
			cout << ways << endl;
			dfs1(1, p);
			for (int i = 1; i <= n; i++)
				cout << ans[i] << " ";
			cout << endl;
		}
		else
			cout << 0 << "\n"
				 << -1 << endl;

		for (int i = 1; i <= n; i++)
			v[i].clear();
		for (int i = 1; i <= n; i++)
			sum[i] = 0;
	}
}
Setter's solution
#include <bits/stdc++.h>

using namespace std;

const int MOD = 998244353;

mt19937 rnd(239);

int n, p;
vector<int> pr;
vector<int> c;

void do_conv(vector<int> &a)
{
	for (int i = n - 1; i > 0; i--)
	{
		a[pr[i]] += a[i];
		if (a[pr[i]] >= p)
		{
			a[pr[i]] -= p;
		}
	}
}

void undo_conv(vector<int> &a)
{
	for (int i = 1; i < n; i++)
	{
		a[pr[i]] -= a[i];
		if (a[pr[i]] < 0)
		{
			a[pr[i]] += p;
		}
	}
}

int power(int a, int k)
{
	if (k == 0)
	{
		return 1;
	}
	int t = power(a, k >> 1);
	t = 1LL * t * t % p;
	if (k & 1)
	{
		t = 1LL * t * a % p;
	}
	return t;
}

pair<int, int> mult_pair(const pair<int, int> &a, const pair<int, int> &b, int t)
{
	return make_pair((1LL * a.first * b.second + 1LL * a.second * b.first) % p,
					 (1LL * a.second * b.second + 1LL * t * (1LL * a.first * b.first % p)) % p);
}

pair<int, int> power_pair(pair<int, int> a, int t, int k)
{
	if (k == 1)
	{
		return a;
	}
	pair<int, int> u = power_pair(a, t, k / 2);
	u = mult_pair(u, u, t);
	if (k & 1)
	{
		u = mult_pair(u, a, t);
	}
	return u;
}

int mysqrt(int t)
{
	if (t == 0)
	{
		return 0;
	}
	while (true)
	{
		int d = rnd() % p;
		pair<int, int> u = power_pair(make_pair(1, d), t, (p - 1) / 2);
		int res = 1 - u.second;
		if (res < 0)
		{
			res += p;
		}
		res = 1LL * res * power(u.first, p - 2) % p;
		if ((1LL * res * res) % p == t)
		{
			return res;
		}
	}
}

void solve()
{
	cin >> n >> p;
	pr.resize(n);
	c.resize(n);
	for (int i = 1; i < n; i++)
	{
		cin >> pr[i];
		pr[i]--;
	}
	for (int i = 0; i < n; i++)
	{
		cin >> c[i];
	}
	do_conv(c);
	int ans = 1;
	for (int i = 0; i < n; i++)
	{
		if (c[i] == 0)
		{
			continue;
		}
		if (power(c[i], (p - 1) / 2) != 1)
		{
			cout << "0\n-1\n";
			return;
		}
		ans += ans;
		if (ans >= MOD)
		{
			ans -= MOD;
		}
	}
	for (int i = 0; i < n; i++)
	{
		c[i] = mysqrt(c[i]);
	}
	undo_conv(c);
	cout << ans << "\n";
	for (int i = 0; i < n; i++)
	{
		cout << c[i] << " ";
	}
	cout << "\n";
}

int main()
{
	ios::sync_with_stdio(0);
	cin.tie(0);
	cout.tie(0);
	int t;
	cin >> t;
	while (t--)
		solve();
	return 0;
}

Please comment below if you have any questions, alternate solutions, or suggestions.

VIDEO EDITORIAL:


Can anyone please share its video editorial? I can’t find it.


I am glad I was able to solve it without knowing what Tonelli-Shanks or a quadratic residue is!!
I passed the first test case but got TLE on the others. This is the only medium problem I managed to partially solve.

It will be uploaded in a while. Sorry for the Delay.


How did you solve the quadratic modular equation without Tonelli-Shanks? Did you just check all possible x to find the solution of the equation, or did you use another approach?

So basically “c = (f(a) % P)”, so f(a) = c + kP for some k >= 0; and since 0 <= a <= P-1,
k <= (f(P-1) - c)/P,
and then I searched over all values of k from 0 up to that bound; those give the valid values of ‘a’.

I started from the leaf nodes and progressed up to the root, and I was able to solve the first test case this way. I reduced the computation by keeping track only of the sum of the subtree vertices and a constant term, and used these to solve the quadratic equation associated with every subtree. Here is my solution:

https://www.codechef.com/viewsolution/40442735

OK, I also did the same thing, but finding k this way is slow, so I learned how to solve quadratic modular equations quickly (the Tonelli-Shanks algorithm). You are on the right path; if you use Tonelli-Shanks you will get AC on all test cases.


Yes. Happy to learn a new algorithm today. Will solve using that!


I am glad that I got AC… It took some time to understand the question. That finally led to quadratic equations going from the leaf nodes up to the root, with each a_i depending on the values a_i of its children. My solution was timing out. Then I explored a bit, stumbled upon the concept of quadratic residues, and landed on Cipolla’s algorithm to compute square roots. I implemented Cipolla’s algorithm to get AC.
Link to my solution in Java
https://www.codechef.com/viewsolution/40395744


I am having difficulty understanding the Tonelli-Shanks algorithm. Does anyone have some resources for learning it?

Except the Wiki page.

If you want the steps and the proof, check this out


Has anyone written an editorial for problem DGMATRIX?


Thanks a lot!
