# TOKTREE - Editorial

Author: erdosnumber
Testers: kingmessi
Editorialist: iceknight1093

# DIFFICULTY:

TBD

# PREREQUISITES:

DP on trees, combinatorics

# PROBLEM:

You are given a directed tree rooted at vertex $1$, with edges directed from parent to child.
Vertex $i$ contains a token labelled $i$.
In one move, you do the following:

1. Choose an edge $u\to v$.
2. Move all tokens on $u$ to $v$.
3. Delete this edge from the tree.

Find the number of different final configurations of tokens, modulo $10^9+7$.
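For intuition on tiny trees, here is a brute-force sketch (not part of the intended solution). It assumes that a configuration is final once every edge has been used and deleted, so exactly $n-1$ moves are made in some order; this is what the coloring argument below implicitly uses, since every non-leaf vertex sends its own token out exactly once. The function name `count_final_configs_brute` and the `parents` list (holding $p_2, \ldots, p_n$) are illustrative only.

```python
from itertools import permutations


def count_final_configs_brute(n, parents):
    """Count distinct final token configurations by trying every edge order.

    parents[i - 2] is the parent of vertex i (vertices are 1-indexed).
    Exponential time: only meant for very small n.
    """
    edges = [(parents[i - 2], i) for i in range(2, n + 1)]
    seen = set()
    for order in permutations(edges):
        pos = list(range(n + 1))  # pos[t] = vertex currently holding token t
        for u, v in order:
            # move every token currently on u over to v; the edge is then gone
            for t in range(1, n + 1):
                if pos[t] == u:
                    pos[t] = v
        seen.add(tuple(pos[1:]))
    return len(seen)


print(count_final_configs_brute(4, [1, 2, 3]))     # path 1-2-3-4 -> 4
print(count_final_configs_brute(5, [1, 2, 2, 1]))  # -> 8
```

On the path $1\to 2\to 3\to 4$, for example, the six edge orders collapse into four distinct configurations.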

# EXPLANATION:

Consider a fixed non-leaf vertex $u$.
It can be observed that at most two edges going out of $u$ can result in a transfer of tokens:

1. The first time an edge of the form $u\to v$ is chosen, the token labelled $u$ will definitely move across it.
After that, $u$ has no tokens left, so the only way it can receive more tokens is from its parent $p_u$.
2. The edge $p_u\to u$ is used only once, so $u$ receives tokens from its parent at most once, as a single batch.
If this batch is non-empty, the first edge out of $u$ chosen after it arrives carries the whole batch; any later edge out of $u$ moves nothing.

Let’s color the edge $u\to v$ green if token $u$ moves along it, and red if the batch of tokens that $u$ received from its parent $p_u$ moves along it.
(An edge can be colored both red and green: this happens when $u$ receives its batch from $p_u$ before any edge out of $u$ has been used.)

Let’s try to characterize configurations in terms of red/green edge colorings instead.
Consider any coloring of the edges of the tree. This coloring can correspond to a valid configuration if and only if:

1. Every non-leaf vertex has exactly one outgoing edge colored green.
2. Every non-leaf vertex has at most one outgoing edge colored red.
3. If $u$ has an outgoing edge colored red, then the edge $p_u \to u$ must be colored (green, red, or both), since otherwise $u$ never receives any tokens from its parent.

Here’s the nice part: every final configuration of tokens corresponds to exactly one coloring of the edges satisfying the above properties, and vice versa!

Proof

First, suppose we have a coloring that satisfies the above properties.
It corresponds to a unique final configuration as follows:

• For each non-leaf vertex $u$, follow the green edge out of $u$, then keep following red edges as long as they exist.
The first vertex with no outgoing red edge is where token $u$ ends up. (A token on a leaf never moves.)

Conversely, take any final configuration of tokens.
For each edge $u\to v$,

• If token $u$ lies inside the subtree of $v$, color the edge $u\to v$ green.
• If some token of a proper ancestor of $u$ lies inside the subtree of $v$, color the edge $u\to v$ red.
It’s not hard to see that this token configuration ↔ coloring mapping is indeed a bijection, hence proving our claim.
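To sanity-check the characterization on small trees, here is a sketch that enumerates every assignment of {uncolored, green, red, both} to the edges, keeps those satisfying conditions 1 to 3, and decodes each one with the "green edge, then red edges" rule from the proof. Each edge $p_v \to v$ is identified by its child $v$; the function names `valid_colorings`, `decode` and `check` are hypothetical. If the bijection holds, the number of valid colorings equals the number of distinct decoded configurations.

```python
from itertools import product


def build_children(n, parents):
    children = {u: [] for u in range(1, n + 1)}
    for v in range(2, n + 1):
        children[parents[v - 2]].append(v)
    return children


def valid_colorings(n, parents):
    """Yield (green, red) dicts for every coloring satisfying conditions 1-3.

    green[v] / red[v] describe the edge p_v -> v, for v = 2..n.
    """
    children = build_children(n, parents)
    kids = list(range(2, n + 1))
    for states in product(range(4), repeat=len(kids)):
        # bit 0 of a state means green, bit 1 means red
        green = {v: bool(s & 1) for v, s in zip(kids, states)}
        red = {v: bool(s & 2) for v, s in zip(kids, states)}
        ok = True
        for u, ch in children.items():
            if not ch:
                continue  # leaves have no outgoing edges
            n_green = sum(green[c] for c in ch)
            n_red = sum(red[c] for c in ch)
            # the root has no parent edge, so it can never have a red edge
            parent_colored = u != 1 and (green[u] or red[u])
            if n_green != 1 or n_red > 1 or (n_red == 1 and not parent_colored):
                ok = False
                break
        if ok:
            yield green, red


def decode(n, parents, green, red):
    """Final position of every token: follow the green edge, then red edges."""
    children = build_children(n, parents)

    def out_edge(u, color):
        picked = [c for c in children[u] if color[c]]
        return picked[0] if picked else None

    pos = []
    for u in range(1, n + 1):
        v = out_edge(u, green)
        if v is None:
            pos.append(u)  # a leaf's token never moves
            continue
        w = out_edge(v, red)
        while w is not None:
            v = w
            w = out_edge(v, red)
        pos.append(v)
    return tuple(pos)


def check(n, parents):
    colorings = list(valid_colorings(n, parents))
    configs = {decode(n, parents, g, r) for g, r in colorings}
    return len(colorings), len(configs)


print(check(5, [1, 2, 2, 1]))  # -> (8, 8)
```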

So, all we need to do is count the number of valid colorings of edges.
That can be done with dynamic programming.
Define:

• $\text{dp}_1[u]$ to be the number of ways of coloring edges in the subtree of $u$ such that $u$ has an outgoing green edge but no outgoing red edge.
• $\text{dp}_2[u]$ to be the number of ways of coloring edges in the subtree of $u$ such that $u$ has both an outgoing green edge and an outgoing red edge (possibly the same edge).

For a leaf $u$, the empty coloring is the only option. It is counted in $\text{dp}_1$, so $\text{dp}_1[u] = 1$ and $\text{dp}_2[u] = 0$.

For transitions:

• To compute $\text{dp}_1[u]$, fix the green edge $u \to v$. Then:
  • The subtree of $v$ can be colored in any valid way, giving $\text{dp}_1[v] + \text{dp}_2[v]$ options.
  • Any other child $c$ of $u$ has its edge $u\to c$ uncolored, so it isn’t allowed to have an outgoing red edge. It has only $\text{dp}_1[c]$ choices.
  • The number of ways for this fixed edge is the product of all of these values, which can be computed in $\mathcal{O}(1)$ time with a bit of precalculation (for example, prefix and suffix products over the children).
  Summing over all choices of $v$ gives $\text{dp}_1[u]$.
• $\text{dp}_2[u]$ is similar, though a bit more involved.
  • Fixing both the red and the green edge explicitly would result in quadratic complexity.
  • However, with a bit of algebra (left as an exercise to the reader), the required value can also be found in linear time.
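One way to carry out that algebra is sketched here; it agrees with the formulas used in the editorialist’s code. For each child $c$ of $u$, write $a_c = \text{dp}_1[c]$ (edge $u\to c$ uncolored) and $b_c = \text{dp}_1[c] + \text{dp}_2[c]$ (edge $u\to c$ colored); $P$, $r_c$, $S$ and $Q$ are shorthand introduced only for this derivation. $\text{dp}_2[u]$ splits into the case where the green and red edges coincide (which is exactly $\text{dp}_1[u]$) and the case where they are distinct edges $v \ne w$, counted as ordered pairs since the two roles differ:

$$
\begin{aligned}
\text{dp}_1[u] &= \sum_{v} b_v \prod_{c \ne v} a_c = P\,S, \\
\text{dp}_2[u] &= \text{dp}_1[u] + \sum_{v \ne w} b_v\, b_w \prod_{c \ne v,\,w} a_c = P\,S + P\,(S^2 - Q),
\end{aligned}
\qquad
P = \prod_c a_c,\quad r_c = \frac{b_c}{a_c},\quad S = \sum_c r_c,\quad Q = \sum_c r_c^2 .
$$

Substituting $r_c = 1 + x_c$ with $x_c = \text{dp}_2[c]/\text{dp}_1[c]$ gives $S = \text{ch} + \text{sm}$ and $S^2 - Q = \text{ch}(\text{ch}-1) + 2(\text{ch}-1)\,\text{sm} + (\text{sm}^2 - \text{sqsm})$, which are the expressions in the editorialist’s code. The division by $a_c$ needs $a_c \not\equiv 0 \pmod{10^9+7}$; keeping running sums for "no special child", "one special child" and "two special children" while scanning the children (as the `knap` array in the author’s code does) avoids inverses altogether.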

All the $\text{dp}_1$ and $\text{dp}_2$ values can be found in linear time, after which the answer is just $\text{dp}_1[1]$ (the root has no parent edge, so it cannot have an outgoing red edge).
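Putting the transitions together, here is a linear-time sketch in Python. Instead of modular inverses it keeps three running values per vertex while scanning its children, in the same spirit as the `knap` array in the author’s code: `k0` (no child chosen), `k1` (one child holds the green edge) and `p2` (an unordered pair of distinct children holds the green and red edges). It assumes $p_i < i$ for every $i \ge 2$, which the editorialist’s code also relies on, so scanning vertices from $n$ down to $1$ finishes every child before its parent. The name `count_configs` is illustrative.

```python
MOD = 10**9 + 7


def count_configs(n, parents):
    """dp_1[1] for the tree with parents[i - 2] = p_i (1-indexed, assumes p_i < i)."""
    par = [0, 0] + list(parents)
    # Running values over the children of u processed so far:
    k0 = [1] * (n + 1)  # product of dp1[c]: no child is special
    k1 = [0] * (n + 1)  # exactly one child carries the green edge
    p2 = [0] * (n + 1)  # unordered pair of distinct children carry green and red
    has_child = [False] * (n + 1)
    dp1 = [0] * (n + 1)
    dp2 = [0] * (n + 1)

    for u in range(n, 0, -1):
        # every child of u has a larger label, so it has already been folded in
        if not has_child[u]:
            dp1[u], dp2[u] = 1, 0  # leaf: only the empty coloring
        else:
            dp1[u] = k1[u]
            # green and red on the same edge (= dp1[u]) + distinct edges in either role
            dp2[u] = (k1[u] + 2 * p2[u]) % MOD
        if u == 1:
            break
        p = par[u]
        a = dp1[u]                   # edge p -> u uncolored
        b = (dp1[u] + dp2[u]) % MOD  # edge p -> u colored: any valid coloring
        # update in this order so each line still sees the previous values
        p2[p] = (p2[p] * a + k1[p] * b) % MOD
        k1[p] = (k1[p] * a + k0[p] * b) % MOD
        k0[p] = k0[p] * a % MOD
        has_child[p] = True
    return dp1[1]


print(count_configs(5, [1, 2, 2, 1]))  # -> 8
```

On a path with $n \ge 2$ vertices this gives $2^{n-2}$: every edge below the root may independently be red or not.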

# TIME COMPLEXITY:

$\mathcal{O}(N)$ per testcase.

# CODE:

Author's code (C++)
```cpp
#include<bits/stdc++.h>
#define ll long long
using namespace std;
const ll mod = 1e9 + 7;

void solve() {
    int n;
    cin >> n;
    vector<int> p(n + 1);
    vector<int> ch(n + 1);               // ch[u] = number of children of u
    vector<vector<int>> adj(n + 1);      // adj[u] = children of u
    for(int i = 2; i <= n; i ++) cin >> p[i];
    for(int i = 2; i <= n; i ++) {
        ch[p[i]] ++;
        adj[p[i]].push_back(i);
    }

    // Here dp2[u] counts only colorings whose green and red edges out of u differ.
    vector<ll> dp1(n + 1), dp2(n + 1);
    // knap[u][k]: over the children processed so far, ways to pick k "special"
    // children (carrying the green / red edges); a special child x contributes
    // temp, every other child contributes dp1[x]. k = 2 counts unordered pairs.
    vector<vector<ll>> knap(n + 1, vector<ll>(3));

    auto func = [&](const auto& self, int cur) -> void {
        dp1[cur] = 1;
        dp2[cur] = 0;
        knap[cur][0] = 1;

        for(int x : adj[cur]) {
            self(self, x);
            // ways to color the subtree of x when the edge cur -> x is colored
            ll temp = ((ch[x] > 0) ? (2 * dp1[x] + dp2[x]) : (dp1[x] + dp2[x])) % mod;
            knap[cur][2] = (knap[cur][1] * temp) % mod
                         + (knap[cur][2] * dp1[x]) % mod;
            knap[cur][2] %= mod;
            knap[cur][1] = (knap[cur][0] * temp) % mod
                         + (knap[cur][1] * dp1[x]) % mod;
            knap[cur][1] %= mod;
            knap[cur][0] = (knap[cur][0] * dp1[x]) % mod;
        }

        if(ch[cur] > 0) {
            dp1[cur] = knap[cur][1];
            dp2[cur] = (2 * knap[cur][2]) % mod;   // the two roles can be swapped
        }
    };

    func(func, 1);
    cout << dp1[1] << '\n';
}
signed main(int argc, char* argv[]) {

    int t;  cin >> t;
    while(t --) solve();
}
```
Tester's code (C++)
```cpp
#include<bits/stdc++.h>

#include <cassert>
#include <numeric>
#include <type_traits>

#ifdef _MSC_VER
#include <intrin.h>
#endif

#include <utility>

#ifdef _MSC_VER
#include <intrin.h>
#endif

namespace atcoder {

namespace internal {

constexpr long long safe_mod(long long x, long long m) {
    x %= m;
    if (x < 0) x += m;
    return x;
}

struct barrett {
    unsigned int _m;
    unsigned long long im;

    explicit barrett(unsigned int m) : _m(m), im((unsigned long long)(-1) / m + 1) {}

    unsigned int umod() const { return _m; }

    unsigned int mul(unsigned int a, unsigned int b) const {

        unsigned long long z = a;
        z *= b;
#ifdef _MSC_VER
        unsigned long long x;
        _umul128(z, im, &x);
#else
        unsigned long long x =
            (unsigned long long)(((unsigned __int128)(z)*im) >> 64);
#endif
        unsigned long long y = x * _m;
        return (unsigned int)(z - y + (z < y ? _m : 0));
    }
};

constexpr long long pow_mod_constexpr(long long x, long long n, int m) {
    if (m == 1) return 0;
    unsigned int _m = (unsigned int)(m);
    unsigned long long r = 1;
    unsigned long long y = safe_mod(x, m);
    while (n) {
        if (n & 1) r = (r * y) % _m;
        y = (y * y) % _m;
        n >>= 1;
    }
    return r;
}

constexpr bool is_prime_constexpr(int n) {
    if (n <= 1) return false;
    if (n == 2 || n == 7 || n == 61) return true;
    if (n % 2 == 0) return false;
    long long d = n - 1;
    while (d % 2 == 0) d /= 2;
    constexpr long long bases[3] = {2, 7, 61};
    for (long long a : bases) {
        long long t = d;
        long long y = pow_mod_constexpr(a, t, n);
        while (t != n - 1 && y != 1 && y != n - 1) {
            y = y * y % n;
            t <<= 1;
        }
        if (y != n - 1 && t % 2 == 0) {
            return false;
        }
    }
    return true;
}
template <int n> constexpr bool is_prime = is_prime_constexpr(n);

constexpr std::pair<long long, long long> inv_gcd(long long a, long long b) {
    a = safe_mod(a, b);
    if (a == 0) return {b, 0};

    long long s = b, t = a;
    long long m0 = 0, m1 = 1;

    while (t) {
        long long u = s / t;
        s -= t * u;
        m0 -= m1 * u;  // |m1 * u| <= |m1| * s <= b

        auto tmp = s;
        s = t;
        t = tmp;
        tmp = m0;
        m0 = m1;
        m1 = tmp;
    }
    if (m0 < 0) m0 += b / s;
    return {s, m0};
}

constexpr int primitive_root_constexpr(int m) {
    if (m == 2) return 1;
    if (m == 167772161) return 3;
    if (m == 469762049) return 3;
    if (m == 754974721) return 11;
    if (m == 998244353) return 3;
    int divs[20] = {};
    divs[0] = 2;
    int cnt = 1;
    int x = (m - 1) / 2;
    while (x % 2 == 0) x /= 2;
    for (int i = 3; (long long)(i)*i <= x; i += 2) {
        if (x % i == 0) {
            divs[cnt++] = i;
            while (x % i == 0) {
                x /= i;
            }
        }
    }
    if (x > 1) {
        divs[cnt++] = x;
    }
    for (int g = 2;; g++) {
        bool ok = true;
        for (int i = 0; i < cnt; i++) {
            if (pow_mod_constexpr(g, (m - 1) / divs[i], m) == 1) {
                ok = false;
                break;
            }
        }
        if (ok) return g;
    }
}
template <int m> constexpr int primitive_root = primitive_root_constexpr(m);

unsigned long long floor_sum_unsigned(unsigned long long n,
                                      unsigned long long m,
                                      unsigned long long a,
                                      unsigned long long b) {
    unsigned long long ans = 0;
    while (true) {
        if (a >= m) {
            ans += n * (n - 1) / 2 * (a / m);
            a %= m;
        }
        if (b >= m) {
            ans += n * (b / m);
            b %= m;
        }

        unsigned long long y_max = a * n + b;
        if (y_max < m) break;
        n = (unsigned long long)(y_max / m);
        b = (unsigned long long)(y_max % m);
        std::swap(m, a);
    }
    return ans;
}

}  // namespace internal

}  // namespace atcoder

#include <cassert>
#include <numeric>
#include <type_traits>

namespace atcoder {

namespace internal {

#ifndef _MSC_VER
template <class T>
using is_signed_int128 =
    typename std::conditional<std::is_same<T, __int128_t>::value ||
                                  std::is_same<T, __int128>::value,
                              std::true_type,
                              std::false_type>::type;

template <class T>
using is_unsigned_int128 =
    typename std::conditional<std::is_same<T, __uint128_t>::value ||
                                  std::is_same<T, unsigned __int128>::value,
                              std::true_type,
                              std::false_type>::type;

template <class T>
using make_unsigned_int128 =
    typename std::conditional<std::is_same<T, __int128_t>::value,
                              __uint128_t,
                              unsigned __int128>;

template <class T>
using is_integral = typename std::conditional<std::is_integral<T>::value ||
                                                  is_signed_int128<T>::value ||
                                                  is_unsigned_int128<T>::value,
                                              std::true_type,
                                              std::false_type>::type;

template <class T>
using is_signed_int = typename std::conditional<(is_integral<T>::value &&
                                                 std::is_signed<T>::value) ||
                                                    is_signed_int128<T>::value,
                                                std::true_type,
                                                std::false_type>::type;

template <class T>
using is_unsigned_int =
    typename std::conditional<(is_integral<T>::value &&
                               std::is_unsigned<T>::value) ||
                                  is_unsigned_int128<T>::value,
                              std::true_type,
                              std::false_type>::type;

template <class T>
using to_unsigned = typename std::conditional<
    is_signed_int128<T>::value,
    make_unsigned_int128<T>,
    typename std::conditional<std::is_signed<T>::value,
                              std::make_unsigned<T>,
                              std::common_type<T>>::type>::type;

#else

template <class T> using is_integral = typename std::is_integral<T>;

template <class T>
using is_signed_int =
    typename std::conditional<is_integral<T>::value && std::is_signed<T>::value,
                              std::true_type,
                              std::false_type>::type;

template <class T>
using is_unsigned_int =
    typename std::conditional<is_integral<T>::value &&
                                  std::is_unsigned<T>::value,
                              std::true_type,
                              std::false_type>::type;

template <class T>
using to_unsigned = typename std::conditional<is_signed_int<T>::value,
                                              std::make_unsigned<T>,
                                              std::common_type<T>>::type;

#endif

template <class T>
using is_signed_int_t = std::enable_if_t<is_signed_int<T>::value>;

template <class T>
using is_unsigned_int_t = std::enable_if_t<is_unsigned_int<T>::value>;

template <class T> using to_unsigned_t = typename to_unsigned<T>::type;

}  // namespace internal

}  // namespace atcoder

namespace atcoder {

namespace internal {

struct modint_base {};
struct static_modint_base : modint_base {};

template <class T> using is_modint = std::is_base_of<modint_base, T>;
template <class T> using is_modint_t = std::enable_if_t<is_modint<T>::value>;

}  // namespace internal

template <int m, std::enable_if_t<(1 <= m)>* = nullptr>
struct static_modint : internal::static_modint_base {
    using mint = static_modint;

  public:
    static constexpr int mod() { return m; }
    static mint raw(int v) {
        mint x;
        x._v = v;
        return x;
    }

    static_modint() : _v(0) {}
    template <class T, internal::is_signed_int_t<T>* = nullptr>
    static_modint(T v) {
        long long x = (long long)(v % (long long)(umod()));
        if (x < 0) x += umod();
        _v = (unsigned int)(x);
    }
    template <class T, internal::is_unsigned_int_t<T>* = nullptr>
    static_modint(T v) {
        _v = (unsigned int)(v % umod());
    }

    unsigned int val() const { return _v; }

    mint& operator++() {
        _v++;
        if (_v == umod()) _v = 0;
        return *this;
    }
    mint& operator--() {
        if (_v == 0) _v = umod();
        _v--;
        return *this;
    }
    mint operator++(int) {
        mint result = *this;
        ++*this;
        return result;
    }
    mint operator--(int) {
        mint result = *this;
        --*this;
        return result;
    }

    mint& operator+=(const mint& rhs) {
        _v += rhs._v;
        if (_v >= umod()) _v -= umod();
        return *this;
    }
    mint& operator-=(const mint& rhs) {
        _v -= rhs._v;
        if (_v >= umod()) _v += umod();
        return *this;
    }
    mint& operator*=(const mint& rhs) {
        unsigned long long z = _v;
        z *= rhs._v;
        _v = (unsigned int)(z % umod());
        return *this;
    }
    mint& operator/=(const mint& rhs) { return *this = *this * rhs.inv(); }

    mint operator+() const { return *this; }
    mint operator-() const { return mint() - *this; }

    mint pow(long long n) const {
        assert(0 <= n);
        mint x = *this, r = 1;
        while (n) {
            if (n & 1) r *= x;
            x *= x;
            n >>= 1;
        }
        return r;
    }
    mint inv() const {
        if (prime) {
            assert(_v);
            return pow(umod() - 2);
        } else {
            auto eg = internal::inv_gcd(_v, m);
            assert(eg.first == 1);
            return eg.second;
        }
    }

    friend mint operator+(const mint& lhs, const mint& rhs) {
        return mint(lhs) += rhs;
    }
    friend mint operator-(const mint& lhs, const mint& rhs) {
        return mint(lhs) -= rhs;
    }
    friend mint operator*(const mint& lhs, const mint& rhs) {
        return mint(lhs) *= rhs;
    }
    friend mint operator/(const mint& lhs, const mint& rhs) {
        return mint(lhs) /= rhs;
    }
    friend bool operator==(const mint& lhs, const mint& rhs) {
        return lhs._v == rhs._v;
    }
    friend bool operator!=(const mint& lhs, const mint& rhs) {
        return lhs._v != rhs._v;
    }

  private:
    unsigned int _v;
    static constexpr unsigned int umod() { return m; }
    static constexpr bool prime = internal::is_prime<m>;
};

template <int id> struct dynamic_modint : internal::modint_base {
    using mint = dynamic_modint;

  public:
    static int mod() { return (int)(bt.umod()); }
    static void set_mod(int m) {
        assert(1 <= m);
        bt = internal::barrett(m);
    }
    static mint raw(int v) {
        mint x;
        x._v = v;
        return x;
    }

    dynamic_modint() : _v(0) {}
    template <class T, internal::is_signed_int_t<T>* = nullptr>
    dynamic_modint(T v) {
        long long x = (long long)(v % (long long)(mod()));
        if (x < 0) x += mod();
        _v = (unsigned int)(x);
    }
    template <class T, internal::is_unsigned_int_t<T>* = nullptr>
    dynamic_modint(T v) {
        _v = (unsigned int)(v % mod());
    }

    unsigned int val() const { return _v; }

    mint& operator++() {
        _v++;
        if (_v == umod()) _v = 0;
        return *this;
    }
    mint& operator--() {
        if (_v == 0) _v = umod();
        _v--;
        return *this;
    }
    mint operator++(int) {
        mint result = *this;
        ++*this;
        return result;
    }
    mint operator--(int) {
        mint result = *this;
        --*this;
        return result;
    }

    mint& operator+=(const mint& rhs) {
        _v += rhs._v;
        if (_v >= umod()) _v -= umod();
        return *this;
    }
    mint& operator-=(const mint& rhs) {
        _v += mod() - rhs._v;
        if (_v >= umod()) _v -= umod();
        return *this;
    }
    mint& operator*=(const mint& rhs) {
        _v = bt.mul(_v, rhs._v);
        return *this;
    }
    mint& operator/=(const mint& rhs) { return *this = *this * rhs.inv(); }

    mint operator+() const { return *this; }
    mint operator-() const { return mint() - *this; }

    mint pow(long long n) const {
        assert(0 <= n);
        mint x = *this, r = 1;
        while (n) {
            if (n & 1) r *= x;
            x *= x;
            n >>= 1;
        }
        return r;
    }
    mint inv() const {
        auto eg = internal::inv_gcd(_v, mod());
        assert(eg.first == 1);
        return eg.second;
    }

    friend mint operator+(const mint& lhs, const mint& rhs) {
        return mint(lhs) += rhs;
    }
    friend mint operator-(const mint& lhs, const mint& rhs) {
        return mint(lhs) -= rhs;
    }
    friend mint operator*(const mint& lhs, const mint& rhs) {
        return mint(lhs) *= rhs;
    }
    friend mint operator/(const mint& lhs, const mint& rhs) {
        return mint(lhs) /= rhs;
    }
    friend bool operator==(const mint& lhs, const mint& rhs) {
        return lhs._v == rhs._v;
    }
    friend bool operator!=(const mint& lhs, const mint& rhs) {
        return lhs._v != rhs._v;
    }

  private:
    unsigned int _v;
    static internal::barrett bt;
    static unsigned int umod() { return bt.umod(); }
};
template <int id> internal::barrett dynamic_modint<id>::bt(998244353);

using modint998244353 = static_modint<998244353>;
using modint1000000007 = static_modint<1000000007>;
using modint = dynamic_modint<-1>;

namespace internal {

template <class T>
using is_static_modint = std::is_base_of<internal::static_modint_base, T>;

template <class T>
using is_static_modint_t = std::enable_if_t<is_static_modint<T>::value>;

template <class> struct is_dynamic_modint : public std::false_type {};
template <int id>
struct is_dynamic_modint<dynamic_modint<id>> : public std::true_type {};

template <class T>
using is_dynamic_modint_t = std::enable_if_t<is_dynamic_modint<T>::value>;

}  // namespace internal

}  // namespace atcoder

#include <ext/pb_ds/assoc_container.hpp> // Common file
#include <ext/pb_ds/tree_policy.hpp>
#define ll long long
#define int long long
#define rep(i,a,b) for(int i=a;i<b;i++)
#define rrep(i,a,b) for(int i=a;i>=b;i--)
#define repin rep(i,0,n)
#define precise(i) cout<<fixed<<setprecision(i)
#define vi vector<int>
#define si set<int>
#define mii map<int,int>
#define take(a,n) for(int j=0;j<n;j++) cin>>a[j];
#define give(a,n) for(int j=0;j<n;j++) cout<<a[j]<<' ';
#define vpii vector<pair<int,int>>
#define db double
#define be(x) x.begin(),x.end()
#define pii pair<int,int>
#define pb push_back
#define pob pop_back
#define ff first
#define ss second
#define lb lower_bound
#define ub upper_bound
#define bpc(x) __builtin_popcountll(x)
#define btz(x) __builtin_ctz(x)
using namespace std;
using namespace atcoder;
using namespace __gnu_pbds;

typedef tree<int, null_type, less<int>, rb_tree_tag,tree_order_statistics_node_update> ordered_set;
typedef tree<pair<int, int>, null_type,less<pair<int, int> >, rb_tree_tag,tree_order_statistics_node_update> ordered_multiset;

const long long INF=1e18;
const long long M=1e9+7;
const long long MM=998244353;
using mint = static_modint<M>;

int power( int N, int M){
    int power = N, sum = 1;
    if(N == 0) sum = 0;
    while(M > 0){
        if((M & 1) == 1){sum *= power;}
        power = power * power;
        M = M >> 1;
    }
    return sum;
}

struct input_checker {
    string buffer;
    int pos;

    const string all = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
    const string number = "0123456789";
    const string upper = "ABCDEFGHIJKLMNOPQRSTUVWXYZ";
    const string lower = "abcdefghijklmnopqrstuvwxyz";

    input_checker() {
        pos = 0;
        while (true) {
            int c = cin.get();
            if (c == -1) {
                break;
            }
            buffer.push_back((char) c);
        }
    }

    int nextDelimiter() {
        int now = pos;
        while (now < (int) buffer.size() && !isspace(buffer[now])) {
            now++;
        }
        return now;
    }

    string readOne() {
        assert(pos < (int) buffer.size());
        int nxt = nextDelimiter();
        string res;
        while (pos < nxt) {
            res += buffer[pos];
            pos++;
        }
        return res;
    }

    string readString(int minl, int maxl, const string &pattern = "") {
        assert(minl <= maxl);
        string res = readOne();
        assert(minl <= (int) res.size());
        assert((int) res.size() <= maxl);
        for (int i = 0; i < (int) res.size(); i++) {
            assert(pattern.empty() || pattern.find(res[i]) != string::npos);
        }
        return res;
    }

    int readInt(int minv, int maxv) {
        assert(minv <= maxv);
        int res = stoi(readOne());
        assert(minv <= res);
        assert(res <= maxv);
        return res;
    }

    long long readLong(long long minv, long long maxv) {
        assert(minv <= maxv);
        long long res = stoll(readOne());
        assert(minv <= res);
        assert(res <= maxv);
        return res;
    }

    auto readInts(int n, int minv, int maxv) {
        assert(n >= 0);
        vector<int> v(n);
        for (int i = 0; i < n; ++i) {
            v[i] = readInt(minv, maxv);
            if (i + 1 < n) readSpace();
        }
        return v;
    }

    auto readLongs(int n, long long minv, long long maxv) {
        assert(n >= 0);
        vector<long long> v(n);
        for (int i = 0; i < n; ++i) {
            v[i] = readLong(minv, maxv);
            if (i + 1 < n) readSpace();
        }
        return v;
    }

    void readSpace() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == ' ');
        pos++;
    }

    void readEoln() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == '\n');
        pos++;
    }

    void readEof() {
        assert((int) buffer.size() == pos);
    }
}inp;

const int MX = 5e5 + 5;
vi adj[MX];     // adj[u] = children of u
// dp[u]  = colorings with no red edge out of u (editorial dp_1[u])
// dp1[u] = any valid coloring of u's subtree (editorial dp_1[u] + dp_2[u])
mint dp[MX];
mint dp1[MX];
vector<bool> vis(MX);

void dfs(int cur){
    vis[cur] = true;
    vector<mint> b;
    for(auto x : adj[cur]){
        dfs(x);
        b.pb(dp[x]);
    }
    if(b.empty()){
        dp[cur] = 1;
        dp1[cur] = 1;
    }
    else{
        int sz = b.size();
        // prefix / suffix products of dp over the children
        vector<mint> pf = b,sf = b;
        rep(i,1,1ll*pf.size())pf[i] *= pf[i-1];
        rrep(i,1ll*pf.size()-2,0)sf[i] *= sf[i+1];
        rep(i,0,sz){
            int x = adj[cur][i];
            mint res = (i?pf[i-1]:1)*(i+1<sz?sf[i+1]:1);
            dp[cur] += dp1[x]*res;
        }

        if(cur == 0)return;
        // mul[i][k]: first i+1 children, k of them carrying a colored edge
        mint mul[sz][3];
        mul[0][0] = dp[adj[cur][0]];
        mul[0][1] = dp1[adj[cur][0]];
        mul[0][2] = 0;
        rep(i,1,sz){
            int x = adj[cur][i];
            mul[i][0] = mul[i-1][0]*dp[x];
            mul[i][1] = mul[i-1][1]*dp[x] + mul[i-1][0]*dp1[x];
            mul[i][2] = mul[i-1][2]*dp[x] + mul[i-1][1]*dp1[x];
        }
        dp1[cur] = dp[cur]*2 + mul[sz-1][2]*2;
    }
}

int smn = 0;

void solve()
{
    int n;
    // cin >> n;
    // assert(n >= 1);
    // assert(n <= 1e5);
    n = inp.readInt(1, 100'000);
    inp.readEoln();
    smn += n;
    rep(i,1,n+1){
        adj[i].clear();
        vis[i] = 0;
        dp[i] = 0;
        dp1[i] = 0;
    }
    rep(i,2,n+1){
        // int p;
        // cin >> p;
        // assert(p >= 1);
        // assert(p < i);
        int p = inp.readInt(1, i - 1);
        if(i < n) inp.readSpace();
        adj[p].pb(i);
    }
    inp.readEoln();
    dfs(1);
    rep(i,1,n+1){
        assert(vis[i]);
    }
    // rep(i,1,n+1)cout << dp[i].val() << ' ';cout << "\n";
    // rep(i,1,n+1)cout << dp1[i].val() << ' ';cout << "\n";
    cout << dp[1].val() << "\n";
}

signed main(){
    ios_base::sync_with_stdio(0);
    cin.tie(0);
    cout.tie(0);
#ifdef NCR
    init();
#endif
#ifdef SIEVE
    sieve();
#endif
    int t;
    // cin >> t;
    // assert(t >= 1);
    // assert(t <= 1e5);
    t = inp.readInt(1, 100'000);
    inp.readEoln();
    while(t--)
        solve();
    inp.readEof();
    assert(smn <= 500'000);
    return 0;
}
```
Editorialist's code (Python)
```python
mod = 10**9 + 7
for _ in range(int(input())):
    n = int(input())
    par = [0] + list(map(int, input().split()))
    for i in range(1, n): par[i] -= 1

    dp1, dp2, prod, sm, sqsm, ch = [0]*n, [0]*n, [1]*n, [0]*n, [0]*n, [0]*n
    for i in reversed(range(n)):
        if ch[i] == 0:
            dp1[i] = 1
            dp2[i] = 0
        else:
            dp1[i] = prod[i] * (ch[i] + sm[i]) % mod
            dp2[i] = dp1[i] + prod[i] * (ch[i] * (ch[i] - 1) + 2 * (ch[i] - 1) * sm[i] + (sm[i] * sm[i] - sqsm[i]))
            dp2[i] %= mod

        prod[par[i]] *= dp1[i]
        prod[par[i]] %= mod
        x = dp2[i] * pow(dp1[i], mod-2, mod)
        sm[par[i]] += x
        sqsm[par[i]] += x*x
        ch[par[i]] += 1
    print(dp1[0])
```

Thanks for the detailed editorial. One minor issue in the editorialist’s code: it uses a modular multiplicative inverse. Is it possible to construct a case where some dp1 value is 0 modulo 1e9+7?

Unfortunately, I don’t yet know of a testcase under these constraints that makes some $\text{dp}_1$ value 0, though I imagine it should be possible to create one.