TREEPERM - Editorial

PROBLEM LINK:

Practice
Div-3 Contest
Div-2 Contest
Div-1 Contest

Author & Editorialist: Ahmed Zaher
Testers: Shubham Jain, Aryan Choudhary

DIFFICULTY:

Easy-medium

PREREQUISITES:

Greedy, DP, Trees

Note: I’ll refer to a “good vertical partition” as a “solution”.

QUICK EXPLANATION:

To find whether a solution exists, we attempt to construct a particular solution S. Repeatedly pick a leaf and walk up until we reach the first node at which the path from the leaf has the same multiset of a values as of b values. If no such node exists, there is no solution. Otherwise, such a path can be safely removed (why?), so we add it to S, remove it from the tree, and check whether what remains of the tree has a solution. To count the solutions, note that any solution T can be constructed as follows: start with T=S, then apply the following operation any number of times: take any 2 paths P_1, P_2\in T such that P_1 is directly above P_2 in the original tree (the lower endpoint of P_1 is the parent of the upper endpoint of P_2), remove them from T and add P'=P_1 \cup P_2 to T. The answer is the number of different T's we can obtain, which can be computed with dynamic programming on trees.

EXPLANATION:

Finding whether a solution exists

This can be done by analyzing the structure of a solution and looking for properties that hold in any solution, or that would hold after some tweaks to the solution, assuming those tweaks preserve its validity. That’s exactly what we’re going to do. We will incrementally construct a “special solution”, which we’ll denote by S, such that any other solution can be reduced to it. If we fail to construct it, we conclude that there are no solutions.

Let’s first look at when it is possible to permute a sequence s = s_1, s_2, .., s_k to obtain another sequence t = t_1, t_2, .., t_k. This is possible exactly when every value x that appears y times in s also appears exactly y times in t. In other words, the multiset of elements in s equals the multiset of elements in t, which we denote by multiset_{1 \leq i \leq k}(s_i) = multiset_{1 \leq i \leq k}(t_i).
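A minimal sketch of this test for integer sequences. `sameMultiset` is a hypothetical helper name. Sorting copies of both sequences and comparing them is one standard way to check multiset equality. `std::is_permutation` from `<algorithm>` is another.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Returns true iff t is a rearrangement of s, i.e. both sequences have the
// same multiset of values. Takes copies so the callers' data stays intact;
// sorting costs O(k log k).
bool sameMultiset(std::vector<int> s, std::vector<int> t) {
    if (s.size() != t.size()) return false;
    std::sort(s.begin(), s.end());
    std::sort(t.begin(), t.end());
    return s == t;
}
```

For example, `{3, 1, 3}` and `{3, 3, 1}` are rearrangements of each other, while `{1, 2}` and `{1, 1}` are not.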

Now let’s pick any leaf u in the tree. We know that u must be included in a path whose endpoints are u and v, where v is an ancestor of u or u=v. Let’s denote this path by P_{u, v}=p_1, p_2,..,p_m where p_1=u, p_m=v and p_{i+1} is the parent of p_i for all 1 \leq i <m. Moreover, multiset_{1 \leq i \leq m}(a_{p_i}) = multiset_{1 \leq i \leq m}(b_{p_i}) must hold. Let’s start from u and go up the tree, looking for a v that satisfies these conditions. If no such v exists, then clearly there is no solution and we output 0. Otherwise, let x be the v nearest to u that satisfies these conditions. We can show that any solution either contains P_{u,x} or can be transformed into one that contains P_{u,x}.
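As a sketch of this upward search, assume the a and b values along the chain u = p_1, p_2, ... (up to the root) have already been collected into two arrays. The hypothetical function below returns m such that x = p_m, or 0 when no valid v exists. It keeps a running difference count per value and the number of values whose count is non-zero, so each step costs O(1) expected time. The hash map is an assumption made for generality. The setter's code uses a plain array instead.

```cpp
#include <cassert>
#include <cstddef>
#include <unordered_map>
#include <vector>

// av[i] and bv[i] are a and b of p_{i+1}, listed from the leaf u upwards.
// Returns the smallest m >= 1 such that the first m nodes carry equal
// multisets of a- and b-values (so x = p_m), or 0 if no such m exists.
int nearestBalancedPrefix(const std::vector<int>& av, const std::vector<int>& bv) {
    std::unordered_map<int, int> diff;  // diff[x] = (#x among a) - (#x among b)
    int nonZero = 0;                    // number of values x with diff[x] != 0
    auto update = [&](int x, int delta) {
        int& d = diff[x];
        if (d == 0) ++nonZero;  // may leave zero
        d += delta;
        if (d == 0) --nonZero;  // became zero
    };
    for (std::size_t i = 0; i < av.size(); ++i) {
        update(av[i], +1);
        update(bv[i], -1);
        if (nonZero == 0) return static_cast<int>(i) + 1;
    }
    return 0;
}
```

For instance, with a = (1, 2, 5) and b = (2, 1, 5) along the chain, the first two nodes already balance, so x = p_2.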

Proof

First, note that for any two sequences s,t of length k, if multiset_{1 \leq i \leq k}(s_i) = multiset_{1 \leq i \leq k}(t_i) and there is some k' < k with multiset_{1 \leq i \leq k'}(s_i) = multiset_{1 \leq i \leq k'}(t_i), then multiset_{k' <i\leq k}(s_i)=multiset_{k' <i\leq k}(t_i).
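Written out with multiset difference \setminus (which subtracts multiplicities), the step is that removing a common prefix from two equal multisets leaves equal multisets. The first and last equalities hold because each prefix is a sub-multiset of its whole sequence.

$$
\begin{aligned}
\operatorname{multiset}_{k' < i \leq k}(s_i)
  &= \operatorname{multiset}_{1 \leq i \leq k}(s_i) \setminus \operatorname{multiset}_{1 \leq i \leq k'}(s_i) \\
  &= \operatorname{multiset}_{1 \leq i \leq k}(t_i) \setminus \operatorname{multiset}_{1 \leq i \leq k'}(t_i) \\
  &= \operatorname{multiset}_{k' < i \leq k}(t_i)
\end{aligned}
$$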

Now consider any solution that, for a leaf u, picks a path P_{u, v}=p_1, p_2,..,p_m. If v=x, we are done. Otherwise, since x is the valid endpoint nearest to u, x lies strictly below v, so x=p_{m'} for some 1 \leq m' < m. By our definition of x, we have multiset_{1 \leq i \leq m'}(a_{p_i}) = multiset_{1 \leq i \leq m'}(b_{p_i}), which by the observation above implies multiset_{m' <i\leq m}(a_{p_i}) = multiset_{m' <i\leq m}(b_{p_i}). Therefore, we can replace P_{u, v} with the 2 paths P_{u, p_{m'}}=P_{u,x} and P_{p_{m'+1}, v} and get a new valid solution that contains P_{u,x}.

Therefore, we can safely add P_{u, x} to S and remove it from the tree so that we’re no longer concerned with it. Then, we check the validity of the rest of the tree using the same method. Note that we might get a forest of trees after such removal. We can illustrate our solution for this part of the problem with the following pseudocode:

let deg[i] be the number of children of node i for i in 1..N
leaves = set of all nodes i with deg[i] = 0
while leaves is not empty:
    pop a node u from leaves
    path = {u} // path initially contains u
    v = u
    while multiset(a[i]) != multiset(b[i]) for nodes i in path:
        // the path cannot be extended upwards: there's no solution
        if v has no parent or parent[v] was already removed: return 0
        v = parent[v]
        add v to path // extend path
    // current path from u to v is now safe to remove
    add path to S and remove its nodes from the tree
    if v has a parent:
        decrement deg[parent[v]]
        if deg[parent[v]] = 0: add parent[v] to leaves
return 1
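A runnable C++ sketch of the pseudocode above, under an assumed input format: nodes are 0..n-1, `parent[root] == -1`, and `a`/`b` hold the node values. It also fails when the parent was already removed by an earlier path, mirroring the `color[]` check in the setter's code. On success it labels every node with the starting leaf of its path, as the setter's `color[]` does. `buildSpecialSolution` and `pathOf` are hypothetical names.

```cpp
#include <cassert>
#include <queue>
#include <unordered_map>
#include <vector>

// Greedy construction of S. On success, pathOf[v] is the starting leaf
// (lower endpoint) of the path of S that contains v.
bool buildSpecialSolution(const std::vector<int>& parent, const std::vector<int>& a,
                          const std::vector<int>& b, std::vector<int>& pathOf) {
    const int n = static_cast<int>(parent.size());
    std::vector<int> deg(n, 0);  // number of children not removed yet
    for (int v = 0; v < n; ++v)
        if (parent[v] != -1) ++deg[parent[v]];

    std::queue<int> leaves;
    for (int v = 0; v < n; ++v)
        if (deg[v] == 0) leaves.push(v);

    pathOf.assign(n, -1);               // -1 marks a node that is still in the tree
    std::unordered_map<int, int> freq;  // freq[x] = #x among a - #x among b on the path
    int nonZero = 0;                    // number of values x with freq[x] != 0
    auto update = [&](int x, int delta) {
        int& d = freq[x];
        if (d == 0) ++nonZero;
        d += delta;
        if (d == 0) --nonZero;
    };

    while (!leaves.empty()) {
        const int u = leaves.front();
        leaves.pop();
        int v = u;
        while (true) {
            pathOf[v] = u;  // v joins the path that starts at u
            update(a[v], +1);
            update(b[v], -1);
            if (nonZero == 0) break;  // the path u..v is balanced: remove it
            const int up = parent[v];
            if (up == -1 || pathOf[up] != -1) return false;  // cannot extend: no solution
            v = up;
        }
        const int above = parent[v];
        if (above != -1 && --deg[above] == 0 && pathOf[above] == -1) leaves.push(above);
    }
    return true;
}
```

A balanced path leaves every `freq` entry at zero, so the map only needs clearing between calls, which happens automatically here because it is local.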

Counting the number of solutions

Note that every solution reduces to S by repeatedly taking a path and dividing it into two paths. Going in the opposite direction, we can construct any solution T from S as follows: start with T=S, then apply the following operation any number of times: take any 2 paths P_{u,v}, P_{u',v'}\in T such that u, the lower endpoint of P_{u,v}, is the parent of v', the upper endpoint of P_{u',v'}, and join them into one vertical path P'=P_{u',v}. That is, we update T as follows: T:=(T \setminus \{P_{u,v}, P_{u',v'}\}) \cup \{P'\}. The updated T is still valid, and hence the number of solutions is exactly the number of different T's we can get.
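Why the joined path is still valid, written out. P' is the disjoint union of the two paths, and \uplus adds multiplicities. P' is vertical because u is the parent of v'.

$$
\begin{aligned}
\operatorname{multiset}_{w \in P'}(a_w)
  &= \operatorname{multiset}_{w \in P_{u,v}}(a_w) \uplus \operatorname{multiset}_{w \in P_{u',v'}}(a_w) \\
  &= \operatorname{multiset}_{w \in P_{u,v}}(b_w) \uplus \operatorname{multiset}_{w \in P_{u',v'}}(b_w)
   = \operatorname{multiset}_{w \in P'}(b_w)
\end{aligned}
$$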

We can construct a graph F where each node represents one path in S. For any two paths in S that can be joined, we add an edge between their corresponding nodes in F. F is a forest, since a path can be attached below at most one other path (the one whose lower endpoint is the parent of its upper endpoint). For each tree in the forest, we pick as its root the node whose corresponding path in S has no other path directly above it.

Now we have a simpler problem: counting the vertical partitions of F. Consider any tree of F, and let f(u) be the number of vertical partitions of the subtree of u. How do we compute f(u)? We can either put u in a single-node path, which gives \prod_{v\in children(u)} f(v) vertical partitions, or join it with one of its children v' (which is always the top of its own path in any partition of its subtree), which gives f(v')\cdot (\prod_{v\in children(u), v \neq v'} f(v))=\prod_{v\in children(u)} f(v) vertical partitions. Summing over all options, f(u) reduces to:
f(u)=\prod_{v\in children(u)} f(v)+\sum_{v'\in children(u)}\prod_{v\in children(u)} f(v)=(|children(u)| + 1)\cdot \prod_{v\in children(u)} f(v).
If the trees of the forest F have roots r_1,r_2,..,r_k, then the answer is \prod_{i=1}^k f(r_i).
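A minimal iterative sketch of this DP, assuming F is given as child lists (`children[w]`) plus a list of roots. The function and parameter names are hypothetical. It avoids recursion by processing each tree in reverse BFS order, so every child is finished before its parent. Unrolling the recurrence shows the answer also equals the product of (|children(w)| + 1) over all nodes w of F. The tester's solution computes this form directly, using the number of children of each path's lower endpoint in the original tree.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

constexpr std::int64_t kMod = 1000000007;  // the modulus used in both solutions

// Evaluates f(u) = (|children(u)| + 1) * prod f(child) bottom-up and returns
// the product of f(r) over all roots, modulo kMod.
std::int64_t countVerticalPartitions(const std::vector<std::vector<int>>& children,
                                     const std::vector<int>& roots) {
    std::vector<std::int64_t> f(children.size(), 1);
    std::int64_t answer = 1;
    for (int r : roots) {
        // BFS order of r's tree: every node appears after its parent, so
        // scanning it backwards handles children before parents.
        std::vector<int> order{r};
        for (std::size_t i = 0; i < order.size(); ++i)
            for (int c : children[order[i]]) order.push_back(c);
        for (auto it = order.rbegin(); it != order.rend(); ++it) {
            const int u = *it;
            std::int64_t value = static_cast<std::int64_t>(children[u].size()) + 1;
            for (int c : children[u]) value = value * f[c] % kMod;
            f[u] = value;
        }
        answer = answer * f[r] % kMod;
    }
    return answer;
}
```

For a root with two leaf children, this gives 3: all singletons, or the root joined with either child. A chain of three nodes gives 4.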

IMPLEMENTATION DETAILS:

For finding a solution, we’ll largely follow the pseudocode above, but we need an efficient way to tell whether two multisets are equal. Since the values of a[] and b[] are small enough to index an array, we maintain one frequency array freq[], where freq[x] = (number of times x appears in a[]) - (number of times x appears in b[]), counting only the nodes u on the path we are currently constructing. The multisets are equal if and only if freq[] is all zeros. To check this in O(1), we maintain a variable nonZero that counts the positions x with freq[x] \neq 0. It starts at 0. We increment it when some freq[x] changes from 0 to a non-zero value, and decrement it when some freq[x] changes from a non-zero value to 0. Do not forget to reset freq[] after each test case.
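A minimal sketch of this bookkeeping, assuming non-negative values bounded by a known `maxValue` (the tester's input checks use 1'000'000). The update rule is the same as the setter's `upd()`. The `BalanceCounter` struct and its `touched` list are hypothetical additions that make the reset cost proportional to the work done rather than to the array size. The setter achieves the same by zeroing `freq[a[i]]` and `freq[b[i]]` for the nodes of the test case.

```cpp
#include <cassert>
#include <vector>

struct BalanceCounter {
    std::vector<int> freq;     // freq[x] = (#x among a) - (#x among b) on the path
    std::vector<int> touched;  // values updated since the last reset()
    int nonZero = 0;           // number of x with freq[x] != 0

    explicit BalanceCounter(int maxValue) : freq(maxValue + 1, 0) {}

    void update(int x, int delta) {
        assert(x >= 0 && x < static_cast<int>(freq.size()));
        if (freq[x] == 0) ++nonZero;  // 0 -> non-zero, undone below if still 0
        freq[x] += delta;
        if (freq[x] == 0) --nonZero;  // non-zero -> 0
        touched.push_back(x);
    }
    // Adds one node with values a[u] = aValue and b[u] = bValue to the path.
    void add(int aValue, int bValue) { update(aValue, +1); update(bValue, -1); }
    bool balanced() const { return nonZero == 0; }

    // Clears only the entries touched since the previous reset.
    void reset() {
        for (int x : touched) freq[x] = 0;
        touched.clear();
        nonZero = 0;
    }
};
```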

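For the counting part, we first modify the previous code to uniquely identify each path we add to S (the setter’s code labels every node with the starting leaf of its path). Then we construct F as follows: for every tree edge from a node to its parent that connects two different paths in S, if the parent is the lower endpoint of its path, we add an edge between the nodes corresponding to those two paths in F; otherwise, the lower path becomes a root of F. We then do a simple DFS to compute the final answer.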
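A sketch of building F from such labels, following the same rule as the setter’s `color[par[i]] == par[i]` check. It assumes `pathOf[v]` is the lower endpoint of the path of S containing v and `parent[root] == -1`. These names and the 0-based format are hypothetical. Nodes of F are identified by the lower endpoints of their paths.

```cpp
#include <cassert>
#include <vector>

// A path becomes a child of another path in F exactly when the parent of its
// upper endpoint is the lower endpoint of that other path; otherwise it is a
// root of F.
void buildPathForest(const std::vector<int>& parent, const std::vector<int>& pathOf,
                     std::vector<std::vector<int>>& children, std::vector<int>& roots) {
    const int n = static_cast<int>(parent.size());
    children.assign(n, std::vector<int>());
    roots.clear();
    for (int v = 0; v < n; ++v) {
        const int p = parent[v];
        if (p != -1 && pathOf[p] == pathOf[v]) continue;  // v is not an upper endpoint
        if (p != -1 && pathOf[p] == p)
            children[p].push_back(pathOf[v]);  // joinable: attach below p's path
        else
            roots.push_back(pathOf[v]);        // nothing can be joined directly above
    }
}
```

In the first test below, the path {2} hangs off node 0, which is the top of the path {1, 0} rather than its bottom. Both paths therefore become roots, and only one solution exists.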

Time complexity: O(N + \max A_i).

PROBLEM VARIANTS:

Find a solution that minimizes the length of the longest vertical path

This is simply the solution S we constructed, since every path of any solution is a union of paths of S.

Find a solution that maximizes the length of the longest vertical path

In F, we assign each node a weight equal to the length of the vertical path it corresponds to in the original tree. Joining the paths along a downward chain of F yields one vertical path in the original tree. So in each tree of F, we pick the leaf whose root-to-leaf path has the maximum sum of weights, and join all nodes on that path into one vertical path. The answer is the maximum over all trees of F.
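A minimal sketch of this variant, assuming F is given as child lists and roots, and `weight[w]` holds the length of the path represented by F-node w. The names are hypothetical. Because all weights are positive, the maximum over all nodes of the root-to-node weight sum is attained at a leaf. An explicit stack replaces recursion.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Returns the largest total weight of a root-to-leaf chain in F, i.e. the
// length of the longest vertical path obtainable by joining paths of S.
std::int64_t longestJoinablePath(const std::vector<std::vector<int>>& children,
                                 const std::vector<int>& roots,
                                 const std::vector<std::int64_t>& weight) {
    std::int64_t best = 0;
    std::vector<std::int64_t> sumFromRoot(children.size(), 0);  // chain weight down to w
    for (int r : roots) {
        std::vector<int> stack{r};
        sumFromRoot[r] = weight[r];
        while (!stack.empty()) {
            const int u = stack.back();
            stack.pop_back();
            best = std::max(best, sumFromRoot[u]);
            for (int c : children[u]) {
                sumFromRoot[c] = sumFromRoot[u] + weight[c];
                stack.push_back(c);
            }
        }
    }
    return best;
}
```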

SOLUTIONS:

Setter's Solution
#include <bits/stdc++.h>

using namespace std;

const int OO = 1e9;
const double EPS = 1e-9;

const int MX = 1e6 + 5;
const int MOD = 1000000007;
int N, S, freq[MX], a[MX], b[MX], color[MX], par[MX], deg[MX];
vector<vector<int>> adj, F;

void dfs(int u, int p) {
	par[u] = p;
	deg[u] = 0;

	for (auto &v : adj[u]) {
		if (v != p) {
			dfs(v, u);
			++deg[u];
		}
	}
}
void upd(int x, int delta, int& nonZero) {
	if (!freq[x])
		++nonZero;
	freq[x] += delta;
	if (!freq[x])
		--nonZero;
}

int f(int u) {
	int ret = 1 + F[u].size();

	for (auto &v : F[u])
		ret = (ret * 1LL * f(v)) % MOD;

	return ret;
}

int main() {
	ios::sync_with_stdio(false);
	cout.precision(10);
	cin.tie(0);

	int T;

	cin >> T;
	int inv = 0;

	for (int tc = 1; tc <= T; ++tc) {

		// taking input

		cin >> N >> S;


		adj.assign(N + 1, vector<int>());

		for (int i = 1; i <= N - 1; ++i) {
			int u, v;
			cin >> u >> v;
			adj[u].push_back(v);
			adj[v].push_back(u);
		}




		for (int i = 1; i <= N; ++i) {
			cin >> a[i];
		}

		for (int i = 1; i <= N; ++i) {
			cin >> b[i];
		}

		// building deg[] and par[]

		dfs(1, 0);

		// finding if a solution exists

		queue<int> leaves;

		for (int i = 1; i <= N; ++i) {
			if (!deg[i])
				leaves.push(i);
		}

		for (int i = 1; i <= N; ++i) {
			color[i] = -1;
		}

		int nonZero = 0;



		bool invalid = false;

		while (!leaves.empty()) {
			int cur = leaves.front();
			leaves.pop();

			int id = cur;
			bool good = false;

			while (true) {
				color[cur] = id;

				// nonZero is passed by reference, could be changed by upd
				upd(a[cur], 1, nonZero);
				upd(b[cur], -1, nonZero);

				--deg[par[cur]];

				if (!nonZero) {
					good = true;
					if (color[par[cur]] == -1 && !deg[par[cur]])
						leaves.push(par[cur]);
					break;
				}

				if (color[par[cur]] == -1)
					cur = par[cur];
				else
					break;
			}

			if (!good) {
				invalid = true;
				break;
			}
		}

		inv += invalid;

		if (invalid)
			cout << "0\n";
		else {

			if (S == 1)
				cout << "1\n";
			else {

				// building F

				F.assign(N + 1, vector<int>());
				vector<int> roots;
				for (int i = 1; i <= N; ++i) {
					if (color[i] != color[par[i]]) {
						if (i > 1 && color[par[i]] == par[i])
							F[color[par[i]]].push_back(color[i]);
						else
							roots.push_back(color[i]);
					}
				}

				int ans = 1;

				for (auto &root : roots)
					ans = (ans * 1LL * f(root)) % MOD;

				cout << ans << '\n';
			}
		}

		for (int i = 1; i <= N; ++i) {
			freq[a[i]] = freq[b[i]] = 0;
		}
	}

	return 0;
}
Tester's Solution
//By TheOneYouWant
#pragma GCC optimize ("-O2")
#include <bits/stdc++.h>
using namespace std;
#define fastio ios_base::sync_with_stdio(0);cin.tie(0)
#define pb push_back
#define mp make_pair
#define fi first
#define se second
#define all(x) x.begin(),x.end()
#define memreset(a) memset(a,0,sizeof(a))
#define testcase(t) int t;cin>>t;while(t--)
#define forstl(i,v) for(auto &i: v)
#define forn(i,e) for(int i=0;i<e;++i)
#define forsn(i,s,e) for(int i=s;i<e;++i)
#define rforn(i,s) for(int i=s;i>=0;--i)
#define rforsn(i,s,e) for(int i=s;i>=e;--i)
#define bitcount(a) __builtin_popcount(a) // set bits (add ll)
#define ln '\n'
#define getcurrtime() cerr<<"Time = "<<((double)clock()/CLOCKS_PER_SEC)<<endl
#define dbgarr(v,s,e) cerr<<#v<<" = "; forsn(i,s,e) cerr<<v[i]<<", "; cerr<<endl
#define inputfile freopen("input.txt", "r", stdin)
#define outputfile freopen("output.txt", "w", stdout)
#define dbg(args...) { string _s = #args; replace(_s.begin(), _s.end(), ',', ' '); \
stringstream _ss(_s); istream_iterator<string> _it(_ss); err(_it, args); }
void err(istream_iterator<string> it) { cerr<<endl; }
template<typename T, typename... Args>
void err(istream_iterator<string> it, T a, Args... args) {
	cerr << *it << " = " << a << "\t"; err(++it, args...);
}
template<typename T1,typename T2>
ostream& operator <<(ostream& c,pair<T1,T2> &v){
	c<<"("<<v.fi<<","<<v.se<<")"; return c;
}
template <template <class...> class TT, class ...T>
ostream& operator<<(ostream& out,TT<T...>& c){
    out<<"{ ";
    forstl(x,c) out<<x<<" ";
    out<<"}"; return out;
}
typedef long long ll;
typedef unsigned long long ull;
typedef long double ld;
typedef pair<ll,ll> p64;
typedef pair<int,int> p32;
typedef pair<int,p32> p96;
typedef vector<ll> v64;
typedef vector<int> v32; 
typedef vector<v32> vv32;
typedef vector<v64> vv64;
typedef vector<p32> vp32;
typedef vector<p64> vp64;
typedef vector<vp32> vvp32;
typedef map<int,int> m32;
const int LIM=2e5+5,MOD=1e9+7;
const ld EPS = 1e-9;

int read(){
    int xx=0,ff=1;char ch=getchar();
    while(ch>'9'||ch<'0'){if(ch=='-')ff=-1;ch=getchar();}
    while(ch>='0'&&ch<='9'){xx=(xx<<3)+(xx<<1)+ch-'0';ch=getchar();}
    return xx*ff;
}

mt19937 rng(chrono::steady_clock::now().time_since_epoch().count());

long long readInt(long long l, long long r, char endd) {
	long long x=0;
	int cnt=0;
	int fi=-1;
	bool is_neg=false;
	while(true) {
		char g=getchar();
		if(g=='-') {
			assert(fi==-1);
			is_neg=true;
			continue;
		}
		if('0'<=g&&g<='9') {
			x*=10;
			x+=g-'0';
			if(cnt==0) {
				fi=g-'0';
			}
			cnt++;
			assert(fi!=0 || cnt==1);
			assert(fi!=0 || is_neg==false);
 
			assert(!(cnt>19 || ( cnt==19 && fi>1) ));
		} else if(g==endd) {
			if(is_neg) {
				x=-x;
			}
			assert(l<=x&&x<=r);
			return x;
		} else {
			assert(false);
		}
	}
}
string readString(int l, int r, char endd) {
	string ret="";
	int cnt=0;
	while(true) {
		char g=getchar();
		assert(g!=-1);
		if(g==endd) {
			break;
		}
		cnt++;
		ret+=g;
	}
	assert(l<=cnt&&cnt<=r);
	return ret;
}
long long readIntSp(long long l, long long r) {
	return readInt(l,r,' ');
}
long long readIntLn(long long l, long long r) {
	return readInt(l,r,'\n');
}
string readStringLn(int l, int r) {
	return readString(l,r,'\n');
}
string readStringSp(int l, int r) {
	return readString(l,r,' ');
}

v32 adj[LIM];



int link[LIM] = {0};
int sz[LIM] = {0};
int deg[LIM] = {0};
int par[LIM] = {0};
int taken[LIM] = {0};
int h[LIM] = {0};
int a[LIM] = {0};
int b[LIM] = {0};
map<int,int> val;
int to_correct = 0;
bool pos;
int n, s;

priority_queue<pair<int,int>> consider;
v32 rem;

int find(int x){
	if(x == link[x]) return x;
	return link[x] = find(link[x]);
}

void unite(int a, int b){
	a = find(a);
	b = find(b);
	if(sz[a]<sz[b]) swap(a,b);
	sz[a]+=sz[b];
	link[b] = a;
}

void dfs(int node){
	forstl(r, adj[node]){
		if(r == par[node]) continue;
		h[r] = h[node]+1;
		par[r] = node;
		dfs(r);
	}
}

void check_add(int node){
	if(deg[node] == 1){
		consider.push(mp(h[node], node));
	}
	return;
}

void solve(int node){
	if(taken[node]){
		pos = 0;
		return;
	}
	taken[node] = 1;
	if(val[a[node]]<0){
		to_correct--;
	}
	else{
		to_correct++;
	}
	val[a[node]]++;
	if(val[b[node]]>0){
		to_correct--;
	}
	else{
		to_correct++;
	}
	val[b[node]]--;
	deg[node] = 0;
	forstl(r, adj[node]){
		deg[r]--;
		check_add(r);
	}
	if(to_correct == 0){
		return;
	}
	int p = par[node];
	if(p == node){
		pos = 0;
		return;
	}
	solve(p);
}



int main(){
	fastio;

	int tests;
	tests = readIntLn(1, 1'000'000);
	ll sum_n = 0;

	while(tests--){
		n = readIntSp(1, 100'000);
		s = readIntLn(1, 2);

		sum_n += n;

		forn(i,n){
			adj[i].clear();
			link[i] = i;
			sz[i] = 1;
			h[i] = 0;
			par[i] = 0;
			deg[i] = 0;
			taken[i] = 0;
		}

		forn(i,n-1){
			int u, v;
			u = readIntSp(1, n);
			v = readIntLn(1, n);
			u--;
			v--;
			if(find(u) == find(v)){
				cout<<"not a tree! \n";
				assert(0);
			}
			unite(u, v);
			adj[u].pb(v);
			adj[v].pb(u);
			deg[u]++;
			deg[v]++;
		}

		forn(i,n){
			if(i < n-1) a[i] = readIntSp(1, 1'000'000);
			else a[i] = readIntLn(1, 1'000'000);
		}
		forn(i,n){
			if(i < n-1) b[i] = readIntSp(1, 1'000'000);
			else b[i] = readIntLn(1, 1'000'000);
		}

		dfs(0);
		pos = 1;
		v32 special;
		forn(i,n){
			if(deg[i] <= 1){
				consider.push(mp(h[i], i));
			}
		}
		while(!consider.empty()){
			auto t = consider.top();
			consider.pop();
			if(taken[t.se]) continue;
			special.pb(t.se);
			// solve for t
			// cout<<t<<ln;
			to_correct = 0;
			val.clear();
			solve(t.se);
		}

		if(pos){
			ll ans = 1;
			forstl(r,special){
				ll choice = 1;
				forstl(k,adj[r]){
					if(k == par[r]) continue;
					choice++;
				}
				ans = (ans * choice) % MOD;
			}
			if(s == 2) cout<<ans<<ln;
			else cout<<1<<ln;
		}
		else{
			cout<<0<<ln;
		}
	}

	assert(sum_n <= 1'000'000);
	assert(getchar()==EOF);

	return 0;
}

The key to the solution is just the fact that
1: When we move down the tree (from the root to the leaves), we have multiple paths (multiple children) to choose from, BUT

2: When we move up the tree (from the leaves to the root), we have only one path to follow (a node has a single parent).


Ok I am finally saying it. Why do Setter’s and Tester’s solutions contain all that bloatware code used by seasoned coders? The solutions are meant for people who struggle at it. It should be sweet and simple code with readability in mind. I could hardly find them useful.


I’m sorry. As a setter, the code shown is a modified version in which I tried to make everything readable. I don’t think that it contains anything complex enough for it to be understandable only by seasoned coders. I also assumed in my code that it will be read by people having some familiarity with the topics mentioned in the prerequisites (e.g. the adjacency list representation of the tree using vectors - here). The other thing that might be confusing is the for-each loop syntax - here and the auto keyword - here. If there is a particular part of the code that is still confusing, feel free to point it out.


Thank you so much! The code looks very clean now.
Also, thanks for your efforts for setting this nice problem. Loved solving it.


Hey @a7med1080 Can you please tell the variants of this Tree Permutation question that @alei suggested you during the question reviewing phase.


Done
