PROBLEM LINK:
Setter: Ivan Safonov
Tester: Alexander Morozov
Editorialist: Ajit Sharma Kasturi
DIFFICULTY:
MEDIUM
PREREQUISITES:
Trees, Depth first search, Quadratic Residue, Tonelli-Shanks Algorithm
PROBLEM:
We are given a rooted tree with N vertices rooted at 1 where p_v is the parent of vertex v.
LCA(u, v) is the lowest common ancestor of u and v . \mathbb{L_{v, s}} is defined as the set of all
vertices u where LCA(u, v) = s .
Let A_P be the set of all sequences of length N whose elements are integers between 0 and P-1 (inclusive) . LCA convolution for two sequences a, b \in A_P is defined as a sequence c = a*b where c_x = ( \displaystyle\sum_{i=1}^{N} \displaystyle\sum_{j \in \mathbb{L_{i, x}} } a_i \cdot b_j ) \bmod P .
We are given a sequence c . We need to find the number of sequences (modulo 998,244,353)
a \in A_P such that c = a*a . If the number of sequences is greater than 0, we also need to output any one sequence a \in A_P such that c = a*a .
QUICK EXPLANATION:
-
For c = a*a, the sum of the values of c over the subtree of a vertex equals the square of the sum of the values of a over that subtree (modulo P). So a solution exists exactly when, for every vertex, the subtree sum of c is either 0 or a quadratic residue modulo P.
-
We can find a candidate value b_i which is the sum of values of a in the subtree of i for each 1\leq i \leq N by using Tonelli-Shanks Algorithm and the total possible such sequences are 2^{\textbf{the number of nonzero } b_i \textbf{values}} .
EXPLANATION:
The definition of LCA convolution seems very complicated, so let us demystify it. We can express the convolution of two arrays a and b as c_x = (\displaystyle\sum_{i,j \text{ : LCA(i, j)=x } } a_i \cdot b_j) \bmod P for 1 \leq x \leq N, where i and j range over 1 \leq i,j \leq N .
Now consider the following, let sum_i be the sum of all values of c_j in the subtree of i modulo P . The array sum can be calculated by a simple dfs. Let c = a*a for some array a \in A_P . Now I claim that if such array a exists, then sum_i = (\displaystyle\sum_{j \in subtree(i) } a_j)^2 \bmod P . Try to prove it .
Proof
sum_i = (\displaystyle\sum_{j \in subtree(i) } c_j) \bmod P
\ \ \ \ \ \ \ \ \ \ = ( \displaystyle\sum_{x \in subtree(i)} \displaystyle\sum_{y \in subtree(i) } a_x \cdot a_y ) \bmod P
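This is because each c_j collects exactly the pairs (x, y) with LCA(x, y) = j, and
LCA(x, y) \in subtree(i) holds if and only if both x and y lie in the subtree of i, so every pair of vertices of the subtree is counted exactly once.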
The last step is a well known multinomial expansion which is equal to
(\displaystyle\sum_{j \in subtree(i) } a_j)^2 \bmod P .
This completes the proof.
Let us consider an array b_i as the sum of all values of a_j in the subtree of i modulo P. Then we need to have sum_i = b_i^2 \bmod P . Thus the solution exists if every sum_i is 0 or is a quadratic residue modulo P . In other words, we are finding a square root modulo P . If we have any sum_i \neq 0 which is not a quadratic residue modulo P, then no such sequence exists. Else we can find a b sequence by using Tonelli-Shanks Algorithm for each i from 1 to N. The number of sequences possible will be 2^{\text{number of nonzero }b_i } . This is because if some b_i > 0 satisfies sum_i =b_i^2 \bmod P, then P-b_i also satisfies sum_i = (P-b_i)^2 \bmod P, and for odd P these two roots are distinct .
Now using this sequence b_i, we can easily find a_i using dfs as follows :
(here j is said to be a child of i if p_j = i) .
a_i = (b_i - \displaystyle\sum_{j \in child(i) } b_j) \bmod P .
TIME COMPLEXITY:
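DFS takes O(N) time. Applying the Tonelli-Shanks algorithm to each sum_i takes O(\log(P) + S^2) modular multiplications, where 2^S is the largest power of two dividing P-1; this is O(\log(P)) when S is small . Thus the overall time complexity is O(N \cdot \log(P)) for such P .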
SOLUTION:
Editorialist's solution
#include <bits/stdc++.h>
#define int long long int
using namespace std;
//*********************** TONELLI SHANKS ALGORITHM ****************
/**
 * Computes (a^b) mod n by binary (square-and-multiply) exponentiation.
 *
 * The base is reduced modulo n before the loop so that every product is formed
 * from two operands smaller than n. Without that reduction a base of 2^32 or
 * more overflows 64 bits on the first squaring and yields a wrong result.
 *
 * @param a base (any 64-bit value)
 * @param b exponent
 * @param n modulus; must be non-zero, and below 2^32 for the intermediate
 *          products to fit in 64 bits
 * @return a^b mod n (0 when n == 1)
 */
uint64_t modpow(uint64_t a, uint64_t b, uint64_t n)
{
    uint64_t result = 1;
    uint64_t base = a % n;
    while (b > 0)
    {
        if ((b & 1) != 0)
        {
            result = (result * base) % n;
        }
        base = (base * base) % n;
        b >>= 1;
    }
    // The final reduction only matters for b == 0 with n == 1.
    return result % n;
}
// Result of a modular square-root query, as produced by ts().
struct Solution
{
    // One square root of the queried value.
    uint64_t root1;
    // The other square root; ts() stores p - root1 here.
    uint64_t root2;
    // False when ts() found that the queried value has no square root.
    bool exists;
};
struct Solution makeSolution(uint64_t root1, uint64_t root2, bool exists)
{
struct Solution sol;
sol.root1 = root1;
sol.root2 = root2;
sol.exists = exists;
return sol;
}
/**
 * Tonelli-Shanks: finds the square roots of n modulo a prime p.
 *
 * @param n value whose square root is wanted; it is reduced modulo p first
 * @param p modulus; must be prime (the search for a non-residue and the main
 *          loop are not guaranteed to terminate otherwise) and small enough
 *          that modpow's products fit in 64 bits
 * @return roots r and p - r with exists == true, or (0, 0, false) when n is a
 *         non-residue. n congruent to 0 yields the single root 0 (returned in
 *         both fields), and p == 2 yields the root 1 for odd n.
 */
struct Solution ts(uint64_t n, uint64_t p)
{
    n %= p;
    if (n == 0)
    {
        // 0 is its own square root. Euler's criterion below would otherwise
        // classify it as a non-residue.
        return makeSolution(0, 0, true);
    }
    if (p == 2)
    {
        // The only non-zero residue modulo 2 is 1, and 1 * 1 = 1.
        return makeSolution(1, 1, true);
    }

    // Euler's criterion: n is a quadratic residue iff n^((p-1)/2) == 1 (mod p).
    if (modpow(n, (p - 1) / 2, p) != 1)
    {
        return makeSolution(0, 0, false);
    }

    // Write p - 1 = q * 2^ss with q odd.
    uint64_t q = p - 1;
    uint64_t ss = 0;
    while ((q & 1) == 0)
    {
        ss += 1;
        q >>= 1;
    }

    if (ss == 1)
    {
        // p = 3 (mod 4): the root is given directly by n^((p+1)/4).
        const uint64_t r1 = modpow(n, (p + 1) / 4, p);
        return makeSolution(r1, p - r1, true);
    }

    // Find a quadratic non-residue z by trial, starting from 2.
    uint64_t z = 2;
    while (modpow(z, (p - 1) / 2, p) != p - 1)
    {
        z++;
    }

    uint64_t c = modpow(z, q, p);
    uint64_t r = modpow(n, (q + 1) / 2, p);
    uint64_t t = modpow(n, q, p);
    uint64_t m = ss;
    while (true)
    {
        if (t == 1)
        {
            return makeSolution(r, p - r, true);
        }

        // Smallest i with t^(2^i) == 1, capped at m - 1.
        uint64_t i = 0;
        uint64_t zz = t;
        while (zz != 1 && i < (m - 1))
        {
            zz = zz * zz % p;
            i++;
        }

        // b = c^(2^(m - i - 1))
        uint64_t b = c;
        uint64_t e = m - i - 1;
        while (e > 0)
        {
            b = b * b % p;
            e--;
        }

        r = r * b % p;
        c = b * b % p;
        t = t * c % p;
        m = i;
    }
}
int test(uint64_t n, uint64_t p)
{
struct Solution sol = ts(n, p);
if (sol.exists)
return sol.root1;
return -1;
}
//*********************** END ***********************************
// Adjacency lists of the tree; main() stores every parent-child edge in both directions.
vector<int> v[500005];
// sum[i]: values of c added up over the subtree of i, filled by dfs() (not reduced modulo p there).
int sum[500005];
// c[i]: the input sequence, read by main() for i = 1..n.
int c[500005];
// ans[i]: the recovered value a_i, filled by dfs1().
int ans[500005];
// arr[i]: the square root chosen by main() for sum[i] modulo p (0 when sum[i] is divisible by p).
int arr[500005];
/**
 * Fills sum[x] with the total of c over the subtree of x, for every vertex x
 * reachable from i without walking back through prev.
 *
 * Implemented with an explicit visiting order instead of recursion: the
 * recursion depth would equal the height of the tree, which can exhaust the
 * call stack on a path-shaped tree with hundreds of thousands of vertices.
 *
 * @param i    root of the (sub)tree to process
 * @param prev neighbour of i to treat as its parent; -1 for the tree root
 */
void dfs(int i, int prev = -1)
{
    // (vertex, parent) pairs in discovery order; a parent always appears
    // before all of its descendants.
    std::vector<std::pair<int, int>> order;
    order.emplace_back(i, prev);
    for (std::size_t k = 0; k < order.size(); ++k)
    {
        // Copy out first: emplace_back below may reallocate the vector.
        const int node = order[k].first;
        const int parent = order[k].second;
        sum[node] = c[node];
        for (int nxt : v[node])
        {
            if (nxt != parent)
            {
                order.emplace_back(nxt, node);
            }
        }
    }
    // Walk the order backwards so each vertex is complete before it is added
    // to its parent. Entry 0 is the root and has no parent to update.
    for (std::size_t k = order.size(); k-- > 1;)
    {
        sum[order[k].second] += sum[order[k].first];
    }
}
/**
 * Fills ans[x] = (arr[x] - sum of arr over the children of x) mod p for every
 * vertex x reachable from i without walking back through prev.
 *
 * Each ans[x] depends only on arr[x] and the arr values of its children, so
 * the vertices are visited with an explicit stack rather than recursion; this
 * avoids a call depth equal to the height of the tree.
 *
 * @param i    root of the (sub)tree to process
 * @param p    modulus; arr values are expected to lie in [0, p)
 * @param prev neighbour of i to treat as its parent; -1 for the tree root
 */
void dfs1(int i, int p, int prev = -1)
{
    std::vector<std::pair<int, int>> pending; // (vertex, parent) still to visit
    pending.emplace_back(i, prev);
    while (!pending.empty())
    {
        const int node = pending.back().first;
        const int parent = pending.back().second;
        pending.pop_back();

        ans[node] = arr[node];
        for (int nxt : v[node])
        {
            if (nxt == parent)
            {
                continue;
            }
            // Subtract modulo p, adding p first so the value stays non-negative.
            ans[node] -= arr[nxt];
            ans[node] += p;
            while (ans[node] >= p)
            {
                ans[node] -= p;
            }
            pending.emplace_back(nxt, node);
        }
    }
}
/**
 * Driver. For each test case: reads n, p, the parents of vertices 2..n and the
 * sequence c, computes the subtree sums of c, takes a square root of every sum
 * modulo p, and prints the number of valid sequences modulo 998244353 followed
 * by one valid sequence, or "0" and "-1" when some subtree sum has no root.
 */
int32_t main()
{
    // Only iostream is used, so untied unsynchronised streams are safe and
    // much faster for inputs of this size.
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);

    const int MOD = 998244353;
    int t;
    cin >> t;
    while (t--)
    {
        int n, p;
        cin >> n >> p;
        for (int i = 2; i <= n; i++)
        {
            int x;
            cin >> x;
            v[x].push_back(i);
            v[i].push_back(x);
        }
        for (int i = 1; i <= n; i++)
            cin >> c[i];
        dfs(1);

        int ways = 1;
        bool possible = true;
        for (int i = 1; i <= n; i++)
        {
            const int residue = sum[i] % p;
            if (residue == 0)
            {
                // The only root of 0 is 0: no root search needed, and the
                // count of sequences is not doubled.
                arr[i] = 0;
                continue;
            }
            arr[i] = test(residue, p);
            if (arr[i] == -1)
            {
                possible = false;
                break;
            }
            // A non-zero root r comes paired with p - r.
            ways = ways * 2 % MOD;
        }

        if (possible)
        {
            cout << ways << "\n";
            dfs1(1, p);
            for (int i = 1; i <= n; i++)
                cout << ans[i] << " ";
            cout << "\n";
        }
        else
        {
            cout << 0 << "\n"
                 << -1 << "\n";
        }

        // Reset the per-test state held in the globals.
        for (int i = 1; i <= n; i++)
            v[i].clear();
        for (int i = 1; i <= n; i++)
            sum[i] = 0;
    }
    return 0;
}
Setter's solution
#include <bits/stdc++.h>
using namespace std;
const int MOD = 998244353;
mt19937 rnd(239);
int n, p;
vector<int> pr;
vector<int> c;
// Adds every a[i] into its parent's slot a[pr[i]] modulo p, visiting i from
// n - 1 down to 1. Presumably relies on pr[i] < i so that each vertex is
// complete before it is pushed upwards - confirm against the input format.
void do_conv(vector<int> &a)
{
    int child = n;
    while (--child > 0)
    {
        int &up = a[pr[child]];
        up += a[child];
        if (up >= p)
        {
            up -= p;
        }
    }
}
// Reverses do_conv: subtracts every a[i] from its parent's slot a[pr[i]]
// modulo p, visiting i from 1 up to n - 1.
void undo_conv(vector<int> &a)
{
    for (int child = 1; child < n; ++child)
    {
        const int parent = pr[child];
        const int diff = a[parent] - a[child];
        a[parent] = (diff < 0) ? diff + p : diff;
    }
}
// Returns a^k modulo the global p for k >= 0, by square-and-multiply over the
// bits of k from the most significant one down. power(a, 0) is 1.
int power(int a, int k)
{
    if (k == 0)
    {
        return 1;
    }

    // Index of the highest set bit of k.
    int top = 0;
    while ((k >> top) > 1)
    {
        ++top;
    }

    int acc = 1;
    for (int bit = top; bit >= 0; --bit)
    {
        acc = 1LL * acc * acc % p;
        if ((k >> bit) & 1)
        {
            acc = 1LL * acc * a % p;
        }
    }
    return acc;
}
// Multiplies two elements of the form first * x + second, reducing with
// x^2 = t and taking coefficients modulo the global p.
pair<int, int> mult_pair(const pair<int, int> &a, const pair<int, int> &b, int t)
{
    const long long linear = 1LL * a.first * b.second + 1LL * a.second * b.first;
    const long long squared = 1LL * a.first * b.first % p; // coefficient of x^2
    const long long constant = 1LL * a.second * b.second + 1LL * t * squared;
    return {static_cast<int>(linear % p), static_cast<int>(constant % p)};
}
// Raises a to the k-th power using mult_pair (so with x^2 = t), for k >= 0.
// k == 0 yields the multiplicative identity 0 * x + 1, matching power(),
// which also defines the zeroth power; previously k == 0 recursed without end.
pair<int, int> power_pair(pair<int, int> a, int t, int k)
{
    if (k == 0)
    {
        return pair<int, int>(0, 1);
    }
    if (k == 1)
    {
        return a;
    }
    pair<int, int> u = power_pair(a, t, k / 2);
    u = mult_pair(u, u, t);
    if (k & 1)
    {
        u = mult_pair(u, a, t);
    }
    return u;
}
int mysqrt(int t)
{
if (t == 0)
{
return 0;
}
while (true)
{
int d = rnd() % p;
pair<int, int> u = power_pair(make_pair(1, d), t, (p - 1) / 2);
int res = 1 - u.second;
if (res < 0)
{
res += p;
}
res = 1LL * res * power(u.first, p - 2) % p;
if ((1LL * res * res) % p == t)
{
return res;
}
}
}
// Handles one test case: reads n, p, the parents and c; turns c into subtree
// sums; prints "0" and "-1" if some non-zero sum fails Euler's criterion,
// otherwise prints the count (doubled once per non-zero sum, modulo MOD) and
// one sequence obtained by taking square roots and undoing the subtree sums.
void solve()
{
    cin >> n >> p;
    pr.resize(n);
    c.resize(n);
    for (int vertex = 1; vertex < n; ++vertex)
    {
        int parent;
        cin >> parent;
        pr[vertex] = parent - 1; // input is 1-based
    }
    for (int &value : c)
    {
        cin >> value;
    }

    do_conv(c);

    int ans = 1;
    for (int i = 0; i < n; ++i)
    {
        if (c[i] != 0)
        {
            if (power(c[i], (p - 1) / 2) != 1)
            {
                cout << "0\n-1\n";
                return;
            }
            ans = (ans + ans) % MOD;
        }
    }

    for (int &value : c)
    {
        value = mysqrt(value);
    }
    undo_conv(c);

    cout << ans << "\n";
    for (int value : c)
    {
        cout << value << " ";
    }
    cout << "\n";
}
// Entry point: switches the streams to fast mode, reads the number of test
// cases and runs solve() once per case.
int main()
{
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    cout.tie(nullptr);

    int remaining;
    cin >> remaining;
    while (remaining != 0)
    {
        --remaining;
        solve();
    }
    return 0;
}
Please comment below if you have any questions, alternate solutions, or suggestions.