PROBLEM LINK:
Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4
Author: nskybytskyi
Tester: raysh07
Editorialist: iceknight1093
DIFFICULTY:
Easy-Medium
PREREQUISITES:
Dynamic programming, small-to-large merging, implicit segment trees
PROBLEM:
You’re given a tree on N vertices. Vertex i has the value A_i written on it.
In the hard version, A_i \leq N.
Count the number of zig-zag sequences in the tree.
A zig-zag sequence is a sequence of vertices (v_1, v_2, \ldots, v_k) such that k \geq 2, the concatenation of the simple paths (v_i, v_{i+1}) is the full path (v_1, v_k), and the values of the vertices alternate in size.
EXPLANATION:
It is recommended that you read the solution to ZIGZAGTREE first; this editorial continues from there.
We now have a more general array, so zig-zag sequences no longer just alternate between the same two values.
However, the core of the solution remains quite similar.
Let’s define \text{dp}_1[u] to be the number of zig-zag sequences that start at u, go into the subtree of u, and have A_u be smaller than the next element of the sequence.
Let \text{dp}_2[u] be defined similarly, but for sequences where A_u is larger than the next element.
Suppose we’re able to (somehow) compute these values for all u.
Let’s try to extend the idea of the easy version to this one: we’ll fix the LCA to be u and count the number of sequences.
In particular, if u is fixed,
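- For sequences containing u, with u being a ‘valley’, we want to find the sum of \text{dp}_2[x_1] \cdot \text{dp}_2[x_2] across all (x_1, x_2) such that (both neighbours of u are larger than u, so each must in turn be larger than its own next element, which is why \text{dp}_2 is used):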
- A_{x_1} \gt A_u and A_{x_2} \gt A_u.
- x_1, x_2 lie in the subtrees of different children of u.
- A similar formulation is obtained for sequences containing u with it being a ‘mountain’.
- Finally, we have sequences that don’t contain u at all.
Here, we want something like the sum of \text{dp}_1[x_1] \cdot \text{dp}_2[x_2] across all pairs (x_1, x_2) from different child subtrees of u, with A_{x_1} \lt A_{x_2}.
The last part seems especially hard to compute quickly, and optimizing that is what will allow us to solve the problem.
For each vertex u, let’s maintain a segment tree S_u.
The i-th index of this segment tree will hold a pair of values: the sum of \text{dp}_1[v] and the sum of \text{dp}_2[v] across all v in the subtree of u such that A_v = i.
If we’re able to maintain these segment trees, observe that all the computations we want to do become quite tractable with the help of small-to-large merging.
Indeed, when merging one child of u into another, for each vertex v of the smaller child we only really want to know something like “what is the sum of \text{dp}_1[x] across all vertices in the larger child, for whom A_x \lt A_v?”
(Of course, there are 2-3 such queries we want per merge, but they’re all like this: a range sum of either \text{dp}_1 or \text{dp}_2 based on values.)
As it turns out, we can indeed maintain segment trees like this - if we store them implicitly and only create the nodes we really need to.
Since we’re using small-to-large merging, there are \mathcal{O}(N\log N) updates and queries across all the segment trees.
Each of them creates \mathcal{O}(\log N) new nodes and requires \mathcal{O}(\log N) time, so we use \mathcal{O}(N\log^2 N) time and memory in total.
If you dynamically allocate memory using `new`, you can potentially run into memory issues, but a predeclared large static pool will work just fine - alternatively, you can try using a bump allocator.
Finally, you may recall that we skipped the details on computing \text{dp}_1 and \text{dp}_2 at the start - however, it should be quite clear now that doing so is trivial with the segment trees we have, each requiring a single query.
TIME COMPLEXITY:
\mathcal{O}(N\log^2 N) per testcase.
CODE:
Tester's code (C++)
#include <bits/stdc++.h>
using namespace std;
mt19937_64 RNG(chrono::steady_clock::now().time_since_epoch().count());
// Number of vertices of the current test case (read in Solve()); also the
// upper end of the value range [1, n] covered by every segment tree.
int n;
// Capacity of the shared segment-tree node pool.
const int M = 5e7 + 69;
// Capacity of the per-vertex arrays.
const int N = 1e5 + 69;
// One node of an implicit segment tree stored inside the pool `seg`.
struct node{
// Pool indices of the left / right child; -1 means the child does not exist yet.
int l, r;
int sum; // sum of the values added below this node, kept reduced modulo `mod`
};
// Result returned by dp() for a subtree.
struct S{
// Pool indices of the roots of the two implicit segment trees of the subtree:
// one accumulates dp1 values, the other dp2 values, both indexed by a[vertex].
int dp1v, dp2v;
// Vertices whose dp values have been inserted into the two trees.
vector <int> a;
};
// Node pool shared by all segment trees; main() sets every l/r to -1 once.
node seg[M];
// sub[u] = size of the subtree of u (filled by getSub()).
int sub[N];
// adj[u] = children of u (Solve() only adds parent -> child edges).
vector <int> adj[N];
// Index of the next unused node in `seg`; it is never reset between test cases.
int ptr = 0;
// a[u] = value written on vertex u (1-indexed).
int a[N];
// Per-vertex DP values computed by dp().
int dp1[N], dp2[N];
// Running answer of the current test case; reduced modulo `mod` by its users.
long long ans = 0;
const int mod = 1e9 + 7;
void add(int l, int r, int pos, int qp, int v){
seg[pos].sum += v;
if (seg[pos].sum >= mod) seg[pos].sum -= mod;
if (l == r){
return;
}
int mid = (l + r) / 2;
if (qp <= mid){
if (seg[pos].l == -1){
seg[pos].l = ptr++;
}
add(l, mid, seg[pos].l, qp, v);
} else {
if (seg[pos].r == -1){
seg[pos].r = ptr++;
}
add(mid + 1, r, seg[pos].r, qp, v);
}
}
// Range-sum query on an implicit segment tree living in the global pool `seg`.
//   [l, r]   - value range covered by node `pos`
//   pos      - pool index of the current node
//   [ql, qr] - inclusive query range; an empty range (ql > qr) yields 0
// Returns the sum over [ql, qr], reduced with one conditional subtraction of
// `mod`. Children that were never created are skipped, so a query never
// allocates nodes.
int query(int l, int r, int pos, int ql, int qr){
    const bool disjoint = (qr < l) || (r < ql);
    if (disjoint) return 0;

    const bool covered = (ql <= l) && (r <= qr);
    if (covered) return seg[pos].sum;

    const int mid = (l + r) / 2;
    const int leftChild = seg[pos].l;
    const int rightChild = seg[pos].r;

    int total = 0;
    if (ql <= mid && leftChild != -1) {
        total += query(l, mid, leftChild, ql, qr);
    }
    if (mid < qr && rightChild != -1) {
        total += query(mid + 1, r, rightChild, ql, qr);
    }
    return total >= mod ? total - mod : total;
}
// Computes sub[u] (subtree size) for every vertex below u and moves the child
// with the largest subtree to the front of adj[u], so that dp() can adopt that
// child's data structures and merge the smaller ones into it.
//   u - current vertex
//   p - parent of u (skipped if it appears in adj[u])
// Among children with equal subtree size the one with the largest index in
// adj[u] is chosen.
void getSub(int u, int p) {
    sub[u] = 1;
    int heavyIdx = -1;
    int heavySize = -1;
    const int deg = static_cast<int>(adj[u].size());
    for (int i = 0; i < deg; ++i) {
        const int v = adj[u][i];
        if (v == p) continue;
        getSub(v, u);
        sub[u] += sub[v];
        // ">=" keeps the later child on ties.
        if (sub[v] >= heavySize) {
            heavySize = sub[v];
            heavyIdx = i;
        }
    }
    if (heavyIdx != -1) {
        swap(adj[u][0], adj[u][heavyIdx]);
    }
}
// Processes the subtree of u and returns its merged data structures.
//
// As computed here (values are kept modulo `mod`):
//   dp1[u] = 1 + sum of dp2[y] over vertices y strictly below u with a[y] < a[u]
//   dp2[u] = 1 + sum of dp1[y] over vertices y strictly below u with a[y] > a[u]
//
// The returned S holds two implicit segment trees indexed by vertex value, one
// accumulating dp1 and one accumulating dp2 of every vertex in the subtree,
// plus the list of those vertices. The first child in adj[u] (made the heaviest
// by getSub()) donates its structures; every other child is merged into them
// vertex by vertex. While merging, the global `ans` receives the contribution
// of pairs taken from different child subtrees and of pairs that involve u.
//
//   u - current vertex
//   p - parent of u
S dp(int u, int p) {
    S merged;  // tree roots are assigned when a child is adopted, or at the end for a leaf
    dp1[u] = 1;
    dp2[u] = 1;
    bool adopted = false;

    for (int child : adj[u]) {
        if (child == p) continue;
        S part = dp(child, u);

        if (!adopted) {
            // First (heaviest) child: take over its trees and vertex list.
            adopted = true;
            swap(part, merged);
            const int below = query(1, n, merged.dp2v, 1, a[u] - 1);
            const int above = query(1, n, merged.dp1v, a[u] + 1, n);
            ans += below + above;
            ans %= mod;
            dp1[u] += below;
            dp2[u] += above;
            continue;
        }

        // Pairs (x, y) with x in this child and y in the already merged part.
        // All queries run before any insertion so that vertices of the same
        // child are never paired with each other.
        for (int x : part.a) {
            const int below = query(1, n, merged.dp2v, 1, a[x] - 1);
            const int above = query(1, n, merged.dp1v, a[x] + 1, n);
            ans += 1LL * dp1[x] * below;
            ans += 1LL * dp2[x] * above;
            ans %= mod;
            merged.a.push_back(x);
        }
        for (int x : part.a) {
            add(1, n, merged.dp1v, a[x], dp1[x]);
            add(1, n, merged.dp2v, a[x], dp2[x]);
        }

        // Contribution of this child relative to u, using the child's own
        // (unmodified) trees and the dp values of u accumulated so far.
        const int below = query(1, n, part.dp2v, 1, a[u] - 1);
        const int above = query(1, n, part.dp1v, a[u] + 1, n);
        ans += 1LL * dp1[u] * below;
        ans += 1LL * dp2[u] * above;
        dp1[u] += below;
        dp2[u] += above;
        if (dp1[u] >= mod) dp1[u] -= mod;
        if (dp2[u] >= mod) dp2[u] -= mod;
        ans %= mod;
    }

    if (!adopted) {
        // Leaf: start two fresh trees from the pool.
        merged.dp1v = ptr++;
        merged.dp2v = ptr++;
    }
    merged.a.push_back(u);
    add(1, n, merged.dp1v, a[u], dp1[u]);
    add(1, n, merged.dp2v, a[u], dp2[u]);
    return merged;
}
// Reads one test case (n, the values a[1..n], then the parent of each vertex
// 2..n), resets the per-vertex state, runs the two tree passes from vertex 1
// and prints the answer modulo `mod`.
void Solve()
{
    cin >> n;

    // Reset per-vertex state left over from the previous test case.
    for (int v = 1; v <= n; ++v) {
        adj[v].clear();
    }
    fill(dp1 + 1, dp1 + n + 1, 0);
    fill(dp2 + 1, dp2 + n + 1, 0);
    fill(sub + 1, sub + n + 1, 0);
    ans = 0;

    for (int v = 1; v <= n; ++v) {
        cin >> a[v];
    }
    for (int v = 2; v <= n; ++v) {
        int parent;
        cin >> parent;
        adj[parent].push_back(v);  // only parent -> child edges are stored
    }

    getSub(1, -1);
    dp(1, -1);

    ans %= mod;
    cout << ans << "\n";
}
int32_t main()
{
auto begin = std::chrono::high_resolution_clock::now();
ios_base::sync_with_stdio(0);
cin.tie(0);
int t = 1;
// freopen("in", "r", stdin);
// freopen("out", "w", stdout);
for (int i = 0; i < M; i++){
seg[i].l = seg[i].r = -1;
}
cin >> t;
for(int i = 1; i <= t; i++)
{
//cout << "Case #" << i << ": ";
Solve();
}
auto end = std::chrono::high_resolution_clock::now();
auto elapsed = std::chrono::duration_cast<std::chrono::nanoseconds>(end - begin);
cerr << "Time measured: " << elapsed.count() * 1e-9 << " seconds.\n";
return 0;
}
Editorialist's code (C++)
// #include <bits/allocator.h>
// #pragma GCC optimize("O3,unroll-loops")
// #pragma GCC target("avx2,bmi,bmi2,lzcnt,popcnt")
#include "bits/stdc++.h"
using namespace std;
using ll = long long int;
mt19937_64 RNG(chrono::high_resolution_clock::now().time_since_epoch().count());
/**
* Integers modulo p, where p is a prime
* Source: Aeren (modified from tourist?)
* Modmul for 64-bit mod from kactl:ModMulLL
* Works with p < 7.2e18 with x87 80-bit long double, and p < 2^52 ~ 4.5e12 with 64-bit
*/
// Modular integer whose modulus is T::value. The stored `value` is kept in
// [0, mod()) by every constructor and operator below.
template<typename T>
struct Z_p{
// Underlying integer type, deduced from the type of T::value.
using Type = typename decay<decltype(T::value)>::type;
// Optional table of precomputed inverses, consulted by operator/=.
static vector<Type> MOD_INV;
constexpr Z_p(): value(){ }
// Converting constructor: reduces any x into [0, mod()).
template<typename U> Z_p(const U &x){ value = normalize(x); }
// Reduces x into [0, mod()); the `%` is skipped when x already lies in [-mod(), mod()).
template<typename U> static Type normalize(const U &x){
Type v;
if(-mod() <= x && x < mod()) v = static_cast<Type>(x);
else v = static_cast<Type>(x % mod());
if(v < 0) v += mod();
return v;
}
// Read access to the stored representative.
const Type& operator()() const{ return value; }
template<typename U> explicit operator U() const{ return static_cast<U>(value); }
constexpr static Type mod(){ return T::value; }
// Addition / subtraction use a single conditional correction, which relies on
// both operands already being reduced.
Z_p &operator+=(const Z_p &otr){ if((value += otr.value) >= mod()) value -= mod(); return *this; }
Z_p &operator-=(const Z_p &otr){ if((value -= otr.value) < 0) value += mod(); return *this; }
template<typename U> Z_p &operator+=(const U &otr){ return *this += Z_p(otr); }
template<typename U> Z_p &operator-=(const U &otr){ return *this -= Z_p(otr); }
Z_p &operator++(){ return *this += 1; }
Z_p &operator--(){ return *this -= 1; }
Z_p operator++(int){ Z_p result(*this); *this += 1; return result; }
Z_p operator--(int){ Z_p result(*this); *this -= 1; return result; }
Z_p operator-() const{ return Z_p(-value); }
// Multiplication, overload selected when Type is int: the product is formed in
// 64 bits and reduced (with a hand-written 32-bit division on _WIN32).
template<typename U = T>
typename enable_if<is_same<typename Z_p<U>::Type, int>::value, Z_p>::type &operator*=(const Z_p& rhs){
#ifdef _WIN32
uint64_t x = static_cast<int64_t>(value) * static_cast<int64_t>(rhs.value);
uint32_t xh = static_cast<uint32_t>(x >> 32), xl = static_cast<uint32_t>(x), d, m;
asm(
"divl %4; \n\t"
: "=a" (d), "=d" (m)
: "d" (xh), "a" (xl), "r" (mod())
);
value = m;
#else
value = normalize(static_cast<int64_t>(value) * static_cast<int64_t>(rhs.value));
#endif
return *this;
}
// Multiplication, overload selected when Type is int64_t: subtracts mod() times a
// long-double estimate of the quotient from the wrapped 64-bit product, then
// lets normalize() repair an estimate that is off by a small amount.
template<typename U = T>
typename enable_if<is_same<typename Z_p<U>::Type, int64_t>::value, Z_p>::type &operator*=(const Z_p &rhs){
uint64_t ret = static_cast<uint64_t>(value) * static_cast<uint64_t>(rhs.value) - static_cast<uint64_t>(mod()) * static_cast<uint64_t>(1.L / static_cast<uint64_t>(mod()) * static_cast<uint64_t>(value) * static_cast<uint64_t>(rhs.value));
value = normalize(static_cast<int64_t>(ret + static_cast<uint64_t>(mod()) * (ret < 0) - static_cast<uint64_t>(mod()) * (ret >= static_cast<uint64_t>(mod()))));
return *this;
}
// Multiplication, overload selected when Type is not an integral type.
template<typename U = T>
typename enable_if<!is_integral<typename Z_p<U>::Type>::value, Z_p>::type &operator*=(const Z_p &rhs){
value = normalize(value * rhs.value);
return *this;
}
// Exponentiation by squaring; a negative exponent first replaces *this by its inverse.
template<typename U>
Z_p &operator^=(U e){
if(e < 0) *this = 1 / *this, e = -e;
Z_p res = 1;
for(; e; *this *= *this, e >>= 1) if(e & 1) res *= *this;
return *this = res;
}
template<typename U>
Z_p operator^(U e) const{
return Z_p(*this) ^= e;
}
// Division: uses MOD_INV when the divisor is covered by the table, otherwise
// runs the extended Euclidean algorithm. The assert fires when the divisor is
// not invertible (the final gcd is not 1).
Z_p &operator/=(const Z_p &otr){
Type a = otr.value, m = mod(), u = 0, v = 1;
if(a < (int)MOD_INV.size()) return *this *= MOD_INV[a];
while(a){
Type t = m / a;
m -= t * a; swap(a, m);
u -= t * v; swap(u, v);
}
assert(m == 1);
return *this *= u;
}
// abs() is the identity, since the stored value is never negative.
template<typename U> friend const Z_p<U> &abs(const Z_p<U> &v){ return v; }
// Representative in [0, mod()).
Type value;
};
template<typename T> bool operator==(const Z_p<T> &lhs, const Z_p<T> &rhs){ return lhs.value == rhs.value; }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> bool operator==(const Z_p<T>& lhs, U rhs){ return lhs == Z_p<T>(rhs); }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> bool operator==(U lhs, const Z_p<T> &rhs){ return Z_p<T>(lhs) == rhs; }
template<typename T> bool operator!=(const Z_p<T> &lhs, const Z_p<T> &rhs){ return !(lhs == rhs); }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> bool operator!=(const Z_p<T> &lhs, U rhs){ return !(lhs == rhs); }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> bool operator!=(U lhs, const Z_p<T> &rhs){ return !(lhs == rhs); }
template<typename T> bool operator<(const Z_p<T> &lhs, const Z_p<T> &rhs){ return lhs.value < rhs.value; }
template<typename T> bool operator>(const Z_p<T> &lhs, const Z_p<T> &rhs){ return lhs.value > rhs.value; }
template<typename T> bool operator<=(const Z_p<T> &lhs, const Z_p<T> &rhs){ return lhs.value <= rhs.value; }
template<typename T> bool operator>=(const Z_p<T> &lhs, const Z_p<T> &rhs){ return lhs.value >= rhs.value; }
// Binary arithmetic for Z_p. Every operator copies the left operand (converting
// it to Z_p first when it is a plain integer) and applies the matching compound
// assignment, so all reduction logic lives in Z_p's member operators.
template<typename T> Z_p<T> operator+(const Z_p<T> &lhs, const Z_p<T> &rhs){
    Z_p<T> out(lhs);
    out += rhs;
    return out;
}
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> Z_p<T> operator+(const Z_p<T> &lhs, U rhs){
    Z_p<T> out(lhs);
    out += rhs;
    return out;
}
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> Z_p<T> operator+(U lhs, const Z_p<T> &rhs){
    Z_p<T> out(lhs);
    out += rhs;
    return out;
}
template<typename T> Z_p<T> operator-(const Z_p<T> &lhs, const Z_p<T> &rhs){
    Z_p<T> out(lhs);
    out -= rhs;
    return out;
}
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> Z_p<T> operator-(const Z_p<T>& lhs, U rhs){
    Z_p<T> out(lhs);
    out -= rhs;
    return out;
}
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> Z_p<T> operator-(U lhs, const Z_p<T> &rhs){
    Z_p<T> out(lhs);
    out -= rhs;
    return out;
}
template<typename T> Z_p<T> operator*(const Z_p<T> &lhs, const Z_p<T> &rhs){
    Z_p<T> out(lhs);
    out *= rhs;
    return out;
}
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> Z_p<T> operator*(const Z_p<T>& lhs, U rhs){
    Z_p<T> out(lhs);
    out *= rhs;
    return out;
}
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> Z_p<T> operator*(U lhs, const Z_p<T> &rhs){
    Z_p<T> out(lhs);
    out *= rhs;
    return out;
}
template<typename T> Z_p<T> operator/(const Z_p<T> &lhs, const Z_p<T> &rhs) {
    Z_p<T> out(lhs);
    out /= rhs;
    return out;
}
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> Z_p<T> operator/(const Z_p<T>& lhs, U rhs) {
    Z_p<T> out(lhs);
    out /= rhs;
    return out;
}
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> Z_p<T> operator/(U lhs, const Z_p<T> &rhs) {
    Z_p<T> out(lhs);
    out /= rhs;
    return out;
}
template<typename T> istream &operator>>(istream &in, Z_p<T> &number){
typename common_type<typename Z_p<T>::Type, int64_t>::type x;
in >> x;
number.value = Z_p<T>::normalize(x);
return in;
}
template<typename T> ostream &operator<<(ostream &out, const Z_p<T> &number){ return out << number(); }
/*
using ModType = int;
struct VarMod{ static ModType value; };
ModType VarMod::value;
ModType &mod = VarMod::value;
using Zp = Z_p<VarMod>;
*/
// Modulus used by this solution; the commented-out lines are alternative moduli.
constexpr int mod = 1e9 + 7; // 1000000007
// constexpr int mod = (119 << 23) + 1; // 998244353
// constexpr int mod = 1e9 + 9; // 1000000009
// Modular integer type bound to the compile-time constant `mod`.
using Zp = Z_p<integral_constant<decay<decltype(mod)>::type, mod>>;
// Out-of-class definition of the static inverse table; empty until precalc_inverse() fills it.
template<typename T> vector<typename Z_p<T>::Type> Z_p<T>::MOD_INV;
template<typename T = integral_constant<decay<decltype(mod)>::type, mod>>
void precalc_inverse(int SZ){
auto &inv = Z_p<T>::MOD_INV;
if(inv.empty()) inv.assign(2, 1);
for(; inv.size() <= SZ; ) inv.push_back((mod - 1LL * mod / (int)inv.size() * inv[mod % (int)inv.size()]) % mod);
}
// Returns the table {base^0, base^1, ..., base^SZ} (SZ + 1 entries).
//   base - value whose powers are tabulated
//   SZ   - largest exponent; SZ == -1 yields an empty table
template<typename T>
vector<T> precalc_power(T base, int SZ){
    vector<T> table(SZ + 1, 1);
    for (int exp = 1; exp <= SZ; ++exp) {
        const T &prev = table[exp - 1];
        table[exp] = prev * base;
    }
    return table;
}
// Returns the table {0!, 1!, ..., SZ!} (SZ + 1 entries) computed in type T.
//   SZ - largest argument; SZ == -1 yields an empty table, matching
//        precalc_power().
// Every entry starts at 1, so 0! needs no separate assignment. Writing
// res[0] unconditionally would be out of bounds when the table is empty.
template<typename T>
vector<T> precalc_factorial(int SZ){
    vector<T> res(SZ + 1, 1);
    for(int i = 1; i <= SZ; ++ i) res[i] = res[i - 1] * i;
    return res;
}
// Number of segment-tree nodes constructed so far (incremented by Node's constructor).
int nodes = 0;

// Backing storage for the bump allocator below.
static char buf[490 << 21];

// Replacement for the global allocation function: hands out memory from the
// end of `buf` towards its beginning and never reuses it.
// NOTE(review): the returned address is not rounded to any alignment, so it
// depends on the sizes of all earlier requests - confirm this is acceptable on
// the target platform.
void* operator new(size_t s) {
    static size_t remaining = sizeof buf;
    assert(s < remaining);  // pool exhausted
    remaining -= s;
    return static_cast<void*>(buf + remaining);
}

// Deallocation is a no-op: memory taken from the pool is never returned.
void operator delete(void*) {}
void operator delete(void*, size_t) {}
// Lazily built segment tree over the half-open index range [lo, hi) storing
// sums of Zp values. Children are created only when an update has to descend
// below a node.
//
// Invariant relied on by query(): an internal node without children has never
// been updated, so everything below it sums to `unit`.
struct Node {
    using T = Zp;
    T unit = 0;                        // identity of f
    T f(T a, T b) { return a+b; }      // how two child results are combined
    Node *l = 0, *r = 0;               // children; both null or both set
    int lo, hi;                        // this node covers [lo, hi)
    T val = unit;                      // sum of everything stored in [lo, hi)
    Node(int _lo,int _hi):lo(_lo),hi(_hi){++nodes;}

    // Sum of the values at positions in [L, R).
    // Never allocates: a node without children that only partially overlaps
    // the query cannot be a leaf (a one-element range is either disjoint from
    // or fully covered by [L, R)), so by the invariant above its contribution
    // is `unit`. As a consequence the `nodes` counter only reflects nodes
    // created by add().
    T query(int L, int R) {
        if (R <= lo || hi <= L) return unit;
        if (L <= lo && hi <= R) return val;
        if (!l) return unit;
        return f(l->query(L, R), r->query(L, R));
    }

    // Adds x to position pos; positions outside [lo, hi) are ignored.
    void add(int pos, T x) {
        if (pos >= hi or pos < lo) return;
        if (lo+1 == hi) val += x;
        else {
            push();
            // Exactly one child contains pos; the other call returns at once.
            l->add(pos, x), r->add(pos, x);
            val = l->val + r->val;
        }
    }

    // Creates both children if they do not exist yet.
    void push() {
        if (!l) {
            int mid = lo + (hi - lo)/2;
            l = new Node(lo, mid); r = new Node(mid, hi);
        }
    }
};
int main()
{
ios::sync_with_stdio(false); cin.tie(0);
int t; cin >> t;
while (t--) {
int n; cin >> n;
vector<int> a(n);
for (int &x : a) cin >> x;
vector adj(n, vector<int>());
for (int i = 1; i < n; ++i) {
int x; cin >> x;
adj[--x].push_back(i);
}
vector<Zp> dp1(n), dp2(n);
// dp1 -> a[u] is smaller than the next value
// dp2 -> a[u] is larger than the next value
vector<int> subsz(n), in(n), order;
int timer = 0;
vector<Node*> seg1(n), seg2(n);
Zp ans = 0;
auto dfs = [&] (const auto &self, int u) -> void {
subsz[u] = 1;
in[u] = timer++;
order.push_back(u);
for (int v : adj[u]) {
self(self, v);
subsz[u] += subsz[v];
}
if (subsz[u] == 1) {
seg1[u] = new Node(0, n+1);
seg1[u] -> add(a[u], 1);
seg2[u] = new Node(0, n+1);
seg2[u] -> add(a[u], 1);
dp1[u] = dp2[u] = 1;
return;
}
sort(begin(adj[u]), end(adj[u]), [&] (int x, int y) {return subsz[x] > subsz[y];});
seg1[u] = seg1[adj[u][0]];
seg2[u] = seg2[adj[u][0]];
for (int v : adj[u]) {
if (v == adj[u][0]) continue;
// for (int i = 0; i <= n; ++i) cout << seg1[u]->query(i,i+1) << ' ';
// cout << '\n';
// for (int i = 0; i <= n; ++i) cout << seg2[u]->query(i,i+1) << ' ';
// cout << '\n';
// cout << '\n';
for (int i = in[v]; i < in[v] + subsz[v]; ++i) {
int x = order[i];
// not including u
ans += dp1[x] * seg2[u]->query(a[x]+1, n+1);
ans += dp2[x] * seg1[u]->query(0, a[x]);
// including u
if (a[x] < a[u]) ans += dp1[x] * seg1[u]->query(0, a[u]);
if (a[x] > a[u]) ans += dp2[x] * seg2[u]->query(a[u]+1, n+1);
}
for (int i = in[v]; i < in[v] + subsz[v]; ++i) {
int x = order[i];
seg1[u]->add(a[x], dp1[x]);
seg2[u]->add(a[x], dp2[x]);
}
}
dp1[u] = 1 + seg2[u]->query(a[u]+1, n+1);
dp2[u] = 1 + seg1[u]->query(0, a[u]);
seg1[u] -> add(a[u], dp1[u]);
seg2[u] -> add(a[u], dp2[u]);
ans += dp1[u] + dp2[u] - 2;
};
dfs(dfs, 0);
cout << ans << '\n';
}
cerr << nodes << '\n';
}